# Summary

This notebook demonstrates how to replicate the results from Section 5 of the paper _"Learning Block-sparse Partial Correlation Graphs"_. To ensure manageable computational time, we apply the models to $3$ out of $50$ data sets, using the setting $\bar{N}_{\mathcal{Y}}=10$. However, the code can be easily adjusted to reproduce other settings described in the article, as long as the hyper-parameters are modified according to the information provided in the supplemental material. Lastly, we compare the obtained results with those reported in the paper.

# Libraries

In [1]:
import jax 
import jax.numpy as jnp

from copy import deepcopy

from jax.numpy import array, expand_dims, hanning, ones
from jax.numpy.fft import rfft
from jax.numpy.linalg import norm

from scipy.signal import fftconvolve

from src.models.convex import TSGLASSO
from src.models.nonconvex import CF_method, IA_method
from src.utils import partial_coherence, hpinv, load_obj
from src.metrics import count_accuracy, blocks_skeleton

# Enable double precision. The config object is read from the top-level `jax`
# package: the `jax.config` submodule import (`from jax.config import config`)
# is not available on recent JAX releases. The name `config` is kept bound for
# any code that still refers to it.
config = jax.config
config.update("jax_enable_x64", True)

In [2]:
# Directory handed to `load_obj` for every stored object read in this notebook
# (relative path, resolved from the current working directory).
data_dir="data/"

# Load the ground truth.

In [3]:
# Problem sizes. T1 = T//2 + 1 is the length of a one-sided FFT of T samples.
# NOTE(review): N and T1 are re-bound later from the periodogram shape.
N = 6
T = 128
T1 = T // 2 + 1

# Stored ground-truth quantities.
gt_est = load_obj('1.0_gt_estimation', data_dir)  # ground truth

# Scale the stored interval bounds by 2*(T1-1), take the second column and
# drop its last entry; the remaining values are cast to integer indices.
rs_intervals = jnp.array(gt_est['rs_time_intervals'])
upper_bounds = (rs_intervals * (T1 - 1) * 2)[:, 1]
K = upper_bounds[:-1].astype(jnp.int64)

# Hard-coded index pair handed to IA_method for the block-sparse run.
K_ = jnp.array([8, 56], dtype=jnp.int64)

R_true = gt_est['Rl2_norm']
F_true = gt_est['F_best']



In [4]:
# number of edges per layer used by CF-fk
r_bar = 5.e-2
d, _, g = R_true.shape

# Binary support of the ground truth: 1 where an entry exceeds r_bar, else 0.
B_true = jnp.where(R_true > r_bar, 1., 0.)

# For each slice along the last axis, count the nonzeros strictly below the
# main diagonal.
edges_per_layer = [
    len(jnp.flatnonzero(jnp.tril(B_true[..., layer], k=-1)))
    for layer in range(g)
]
k_true = jnp.array(edges_per_layer, dtype=jnp.int16)
k_true

Array([0, 5, 7, 7, 7, 6, 4, 2], dtype=int16)

# Test example

As an example, we apply the considered methods to $3$ out of $50$ data sets, for the setting $\bar{N}_{\mathcal{Y}}=10$.
To reproduce other settings, please change the hyper-parameters according to the table below.

In [5]:
# Load the stored hyper-parameter table and render it with the notebook's
# rich display.
hyperparams=load_obj("1.0_hyperparameters", data_dir)
display(hyperparams)

nsamples,model,lmbd,coeff
5,IA-bs,0.3,0.01
5,IA-gs,0.5,0.01
5,TSGLASSO$_{\alpha=0.0}$,1.0,
5,TSGLASSO$_{\alpha=0.5}$,0.7,
10,IA-bs,0.3,0.01
10,IA-gs,0.5,0.01
10,TSGLASSO$_{\alpha=0.0}$,1.0,
10,TSGLASSO$_{\alpha=0.5}$,0.5,
20,IA-bs,0.3,0.01
20,IA-gs,0.5,0.01


In [6]:
#utils to compute the raw periodogram
def raw_periodogram(ts):
    """Raw periodogram of `ts`: one outer-product matrix per one-sided FFT bin.

    The orthonormal real FFT of `ts` is taken along axis 0. For each
    resulting frequency row `y`, the matrix `y y^H / (2*pi)` is formed, and
    the matrices are stacked along the last axis of the returned array.
    """
    def _scaled_outer(freq_row):
        # Column vector times its conjugate transpose, scaled by 1/(2*pi).
        column = freq_row.reshape(-1, 1)
        return 1./(2*jnp.pi)*(column@column.conj().T)

    spectrum = rfft(ts, norm="ortho", axis=0)
    return jax.vmap(_scaled_outer, in_axes=0, out_axes=2)(spectrum)

In [7]:
# ----------------------------------------------------------------------
# Experiment configuration
# ----------------------------------------------------------------------
# szs=[5, 10, 20, 50, 100, 1000]
szs=[10] #consider only this setting
ndatasets=50
performances = dict()

# Arguments forwarded positionally to IA_method.solve_vmap; their meaning is
# defined in src.models.nonconvex (not visible from this notebook).
tauF=1.e-3
tauP=1.e-3
kind=1
kind1=-99
c1=.9
c2=.99
c3=.99
c4=.99

# Iteration budget and stopping tolerances shared by TSGLASSO and IA.
max_iter=2000
step=500 #frequency for printing info
tol=5.e-4
tolp_abs=tol
tolp_rel=tol
told_abs=tol
told_rel=tol
only_primal=False

# TSGLASSO settings.
alpha_TSGLASSO00=0.
alpha_TSGLASSO05=0.5
rho_TSGLASSO=2.
penalize_diag=False
varying_rho=False

#set these hyper-params
# according to supplemental material of the article
coeff_IAbs=1.e-2
lmbd_IAbs=3.e-1

coeff_IAgs=1.e-2
lmbd_IAgs=5.e-1

lmbd_TSGLASSO00=1.
lmbd_TSGLASSO05=0.5

# Half-width of the smoothing window (the window has 2*p+1 taps).
p=3

# Window builders. They read the global N when called, so N must have been
# set from the current periodogram before they are used.
# Box window: constant weights summing to one along the last axis.
window_box=lambda x: ones((N,N,2*x+1))/(2*x+1)
# Hann window normalised to unit sum, broadcast to shape (N, N, 2*x+1).
window_hann=lambda x: ones((N,N,1))*expand_dims(hanning(2*x+1)/(hanning(2*x+1)).sum(), (0,1))

# Smooth along the last axis, keeping the input length.
smoothing=lambda x,win: fftconvolve(x, win, mode='same', axes=-1)


def _normalized_block_scores(P):
    """Return (R, Rl2): partial coherence of P and its normalised block skeleton.

    Rl2 is `blocks_skeleton(R, K)` divided, slice by slice along the last
    axis, by the maximum of that slice.
    """
    R = array(partial_coherence(P))
    Rl2 = blocks_skeleton(R, K)
    Rl2 /= Rl2.max(axis=(0,1)).reshape((1,1,-1))
    return R, Rl2


def _record_performance(sz, trial, method, Rl2, iterations=jnp.nan, objective=jnp.nan):
    """Score Rl2 against R_true (threshold r_bar) and store it in `performances`.

    `iterations` and `objective` default to NaN for methods that do not
    report them.
    """
    scores = count_accuracy(R_true, Rl2, tau=r_bar, already_blocks=True)
    scores['iterations'] = iterations
    scores['objective value'] = objective
    performances[sz][trial][method] = scores


# ----------------------------------------------------------------------
# Run every method on each selected data set
# ----------------------------------------------------------------------
for sz in szs:

    performances[sz]=dict()

    periodograms=load_obj("1.0_dataset_{}_{}".format(sz,ndatasets),data_dir)

    # ntrials=periodograms.shape[0]
    ntrials=3 #restrict the number of data sets for testing

    for trial in range(ntrials):
        print("########## N SAMPLES {} DATA SET {} ##########".format(sz,trial+1))

        periodogram = deepcopy(periodograms[trial])
        performances[sz][trial]=dict()

        N,_,T1=periodogram.shape

        win = window_hann(p)
        smoothed_periodogram=smoothing(periodogram,win)

        # Independent copy used as the spectral estimate by every method.
        F_hat=deepcopy(smoothed_periodogram)

        # eta for the IA runs scales with the norm of the spectral estimate.
        F_hat_norm=norm(F_hat)
        eta_IAbs=coeff_IAbs*F_hat_norm
        eta_IAgs=coeff_IAgs*F_hat_norm

        print("\nNaive baseline\n\n")
        P_naive=jax.vmap(hpinv,2,2)(smoothed_periodogram)
        R_naive,Rl2_naive=_normalized_block_scores(P_naive)
        _record_performance(sz, trial, 'Naive', Rl2_naive)

        print("\n'TSGLASSO alpha={}'\n\n".format(alpha_TSGLASSO00))
        results_TSGLASSO0=TSGLASSO(F_hat, lmbd_TSGLASSO00, alpha_TSGLASSO00, rho_TSGLASSO, tolp_abs, tolp_rel, told_abs, told_rel, max_iter, step, penalize_diag, varying_rho)
        R_TSGLASSO0,Rl2_TSGLASSO0=_normalized_block_scores(results_TSGLASSO0['P'])
        _record_performance(sz, trial, 'TSGLASSO alpha={}'.format(alpha_TSGLASSO00), Rl2_TSGLASSO0,
                            iterations=results_TSGLASSO0['iterations'],
                            objective=results_TSGLASSO0['objective value'])

        print("\n'TSGLASSO alpha={}'\n\n".format(alpha_TSGLASSO05))
        results_TSGLASSO05=TSGLASSO(F_hat, lmbd_TSGLASSO05, alpha_TSGLASSO05, rho_TSGLASSO, tolp_abs, tolp_rel, told_abs, told_rel, max_iter, step, penalize_diag, varying_rho)
        R_TSGLASSO05,Rl2_TSGLASSO05=_normalized_block_scores(results_TSGLASSO05['P'])
        _record_performance(sz, trial, 'TSGLASSO alpha={}'.format(alpha_TSGLASSO05), Rl2_TSGLASSO05,
                            iterations=results_TSGLASSO05['iterations'],
                            objective=results_TSGLASSO05['objective value'])

        print("\n CF max num. of nonzero entries\n\n")
        # NOTE(review): 7 equals the largest entry of k_true shown earlier in
        # the notebook; confirm before changing the ground truth.
        Psk=CF_method(array(F_hat),k=array([7], dtype=jnp.int16), K=K)
        R_convex,Rl2_convex=_normalized_block_scores(Psk)
        _record_performance(sz, trial, 'CF-nz', Rl2_convex)

        print("\n CF method full knowledge\n\n")
        Psk=CF_method(array(F_hat),k=k_true, K=K)
        R_convex,Rl2_convex=_normalized_block_scores(Psk)
        _record_performance(sz, trial, 'CF-fk', Rl2_convex)

        print("\nIA method block-sparse\n\n")
        algo=IA_method(K=K_ ,F_hat=F_hat, P_init="identity")
        F,P,vF,vP,vU,vX,vV,alpha,beta,mu,phi = algo.initialization(check_init=True)
        results=algo.solve_vmap(F,P,vF,vP,vU,vX,vV,alpha,beta,mu,phi, lmbd_IAbs, tauF, tauP, eta_IAbs, kind, kind1, c1, c2, c3, c4, tolp_abs=tolp_abs, tolp_rel=tolp_rel, told_abs=told_abs, told_rel=told_rel, max_iter=max_iter, only_primal=only_primal, step=step)
        R_ours,Rl2_ours=_normalized_block_scores(results['P'])
        _record_performance(sz, trial, 'IA-bs', Rl2_ours,
                            iterations=results['iterations'],
                            objective=results['objective value'])

        print("\nIA method group-sparse\n\n")
        algo1=IA_method(K=None ,F_hat=F_hat, P_init="identity")
        F1,P1,vF1,vP1,vU1,vX1,vV1,alpha1,beta1,mu1,phi1 = algo1.initialization(check_init=True)
        results1=algo1.solve_vmap(F1,P1,vF1,vP1,vU1,vX1,vV1,alpha1,beta1,mu1,phi1, lmbd_IAgs, tauF, tauP, eta_IAgs, kind, kind1, c1, c2, c3, c4, tolp_abs=tolp_abs, tolp_rel=tolp_rel, told_abs=told_abs, told_rel=told_rel, max_iter=max_iter, only_primal=only_primal, step=step)
        R_ours1,Rl2_ours1=_normalized_block_scores(results1['P'])
        _record_performance(sz, trial, 'IA-gs', Rl2_ours1,
                            iterations=results1['iterations'],
                            objective=results1['objective value'])

########## N SAMPLES 10 DATA SET 1 ##########

Naive baseline



'TSGLASSO alpha=0.0'




Iteration: 500.0
objective: (435.27602172910895+0j)
Residual for primal feasibility: 1.291 vs 0.044
Residual for dual feasibility: 0.608 vs 0.025
Primal feasibility: False, dual feasibility: False
ADMM stepsize value: 2.0





Iteration: 1000.0
objective: (435.5915357526938+0j)
Residual for primal feasibility: 1.314 vs 0.044
Residual for dual feasibility: 0.607 vs 0.025
Primal feasibility: False, dual feasibility: False
ADMM stepsize value: 2.0





Iteration: 1500.0
objective: (435.6305117582795+0j)
Residual for primal feasibility: 1.317 vs 0.044
Residual for dual feasibility: 0.607 vs 0.025
Primal feasibility: False, dual feasibility: False
ADMM stepsize value: 2.0





Iteration: 2000.0
objective: (435.6356623442488+0j)
Residual for primal feasibility: 1.318 vs 0.044
Residual for dual feasibility: 0.607 vs 0.025
Primal feasibility: False, dual feasibility: False
ADMM stepsize value: 2.0




'TS

## Check if the values obtained coincide with those in the article.

In [8]:
# Reference results reported in the article.
b_performances = load_obj("1.1_performances_synth_experiments", data_dir)

# Iterate over the results dictionary itself rather than relying on the loop
# variable `sz` left over from the experiment cell, so every setting that was
# run is compared, not only the last one.
for sz, per_trial in performances.items():
    if len(performances) > 1:
        # Only needed to tell settings apart when several were run.
        print("\n############ N SAMPLES {} ############".format(sz))
    for trial, per_method in per_trial.items():
        print("\n############ RESULTS ON DATA SET N.{} ############\n".format(trial+1))
        for method, scores in per_method.items():
            here=scores['hamming_kpcg']
            if 'TSGLASSO' in method:
                # 'TSGLASSO alpha=0.5' -> 'TSGLASSO$_{\alpha=0.5}$', the label
                # used in the reference table.
                get_alpha = method.split()[1]
                model_label = 'TSGLASSO$_{\\'+get_alpha+'}$'
            else:
                model_label = method
            # Rows for this setting and model; the trial index selects the
            # data set by position.
            reference_rows = b_performances[(b_performances['nsamples']==str(sz)) & (b_performances['model']==model_label)]
            paper=int(reference_rows.iloc[trial]['hamming'])
            print("Method: {}".format(method))
            print("Hamming obtained here: {}".format(here))
            print("Hamming in paper results: {}\n".format(paper))


############ RESULTS ON DATA SET N.1 ############

Method: Naive
Hamming obtained here: 81
Hamming in paper results: 81

Method: TSGLASSO alpha=0.0
Hamming obtained here: 35
Hamming in paper results: 35

Method: TSGLASSO alpha=0.5
Hamming obtained here: 40
Hamming in paper results: 40

Method: CF-nz
Hamming obtained here: 38
Hamming in paper results: 38

Method: CF-fk
Hamming obtained here: 24
Hamming in paper results: 24

Method: IA-bs
Hamming obtained here: 25
Hamming in paper results: 25

Method: IA-gs
Hamming obtained here: 28
Hamming in paper results: 28


############ RESULTS ON DATA SET N.2 ############

Method: Naive
Hamming obtained here: 82
Hamming in paper results: 82

Method: TSGLASSO alpha=0.0
Hamming obtained here: 38
Hamming in paper results: 38

Method: TSGLASSO alpha=0.5
Hamming obtained here: 41
Hamming in paper results: 41

Method: CF-nz
Hamming obtained here: 46
Hamming in paper results: 46

Method: CF-fk
Hamming obtained here: 32
Hamming in paper results: 32

Meth