# Summary

This notebook shows how to apply the LinSEPAL-ADMM, LinSEPAL-PG, and CLinSEPAL algorithms to synthetic data in order to learn linear causal abstractions (CAs), with either full or partial prior knowledge.

In [1]:
import autograd.numpy as anp
import numpy as np
import pandas as pd

from sklearn.metrics import confusion_matrix
from src.CLinSEPAL_fullprior import CLinSEPAL_fp
from src.CLinSEPAL_partialprior import CLinSEPAL_pp
from src.LinSEPAL_ADMM import LinSEPAL_ADMM
from src.LinSEPAL_PG import LinSEPAL_PG
from src.utils import constructiveness, masked_stiefel_matrix, gen_covariances, data_dir, save_obj_parquet, load_obj_parquet, stiefel_arc_length, frobenious_abs_distance

## Data generation

Here we generate data used in both settings.

In [2]:
generatingseed = 0  # seed used for data generation
l = 12  # low-level dimension
h = 6   # high-level dimension

# lengths of the column-major vectorizations stored in the result frames
dim_l = l * l
dim_h = h * h
dim_lh = l * h

max_iterations = 1000

# ground-truth support of the abstraction, shape (l, h)
S_gt = anp.array(
    [
        [0, 1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 1],
        [0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 1],
        [1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 1, 0, 0, 0, 0],
    ],
    dtype=float,
)

# ground-truth abstraction supported on S_gt, and the two covariance matrices derived from it
V = masked_stiefel_matrix(S_gt, generatingseed)
covlow, covhigh = gen_covariances(V, generatingseed)

# every matrix is stored as one row containing its column-major ('F') vectorization
S_matrix = pd.DataFrame(S_gt.reshape((1, dim_lh), order='F'), index=['S_gt'], dtype=float)
V_matrices = pd.DataFrame(V.reshape((1, dim_lh), order='F'), index=['V_gt'], dtype=float)
covlow_matrices = pd.DataFrame(covlow.reshape((1, dim_l), order='F'), index=['covlow_gt'], dtype=float)
covhigh_matrices = pd.DataFrame(covhigh.reshape((1, dim_h), order='F'), index=['covhigh_gt'], dtype=float)

# empty containers that the experiment loop fills in
columns = [
    "(l,h)", "seed gen", "seed algo", "method", "iterations",
    "D_KL", "constructiveness", "stiefel distance", "frobenious distance",
    "tn", "fp", "fn", "tp",
    "fpr", "tpr", "fdr", "f1",
]
metrics_df = pd.DataFrame(columns=columns)
obj_val_df, primal_res_series_df, dual_res_series_df = (
    pd.DataFrame(columns=range(max_iterations+1), dtype=float) for _ in range(3)
)


## Full-prior

In this setting the structure of the linear CA - i.e., the above array `S_gt` - is available as prior knowledge.

In [3]:
# the ground-truth support is handed to the algorithms as the prior
B = S_gt.copy()

###############################
# Detect "certain" abstractions,
# i.e. columns of the prior that
# contain a single 1. Such columns
# do not need to be learned, so the
# problem can be restricted to the
# remaining rows/columns.
# This occurs when h>l//2
# (cf. the paper for more details.)

all_row_indices = anp.arange(l)
all_col_indices = anp.arange(h)

# columns with more than one candidate row, and the rows they involve
uncertain_cols = anp.nonzero(B.sum(axis=0) > 1)[0]
uncertain_rows = anp.nonzero(B[:, uncertain_cols].sum(axis=1) > 0)[0]

# complement of the index sets above
excluded_rows = anp.setdiff1d(all_row_indices, uncertain_rows)
excluded_cols = anp.setdiff1d(all_col_indices, uncertain_cols)

# prior restricted to the part that has to be learned,
# and its column-major vectorization used as classification target
B_restr = B[anp.ix_(uncertain_rows, uncertain_cols)].copy()
y_true = B_restr.flatten(order='F')

# covariances restricted to the same rows / columns
covlow_restr = covlow[anp.ix_(uncertain_rows, uncertain_rows)]
covhigh_restr = covhigh[anp.ix_(uncertain_cols, uncertain_cols)]

# In this example there are no certain CAs,
# so the restricted matrices coincide with the generated ones.
(
    anp.allclose(B, B_restr),
    anp.allclose(covlow, covlow_restr),
    anp.allclose(covhigh, covhigh_restr),
)

(True, True, True)

In [4]:
# The problem is nonconvex, so the outcome depends on the initialization:
# each algorithm is run from "ntrials" different random initializations.
ntrials = 10
# seeds generatingseed+1, ..., generatingseed+ntrials (never equal to generatingseed)
seeds_algos = generatingseed + 1 + anp.arange(ntrials)

In [5]:
# Algorithms hyper-params

# --- shared ---
max_iterations = 1000          # all methods
initialization = 'structural'  # LinSEPAL-ADMM and LinSEPAL-PG
rho = 1.0                      # LinSEPAL-ADMM and CLinSEPAL
lambda_reg = 1.0               # LinSEPAL-ADMM and LinSEPAL-PG
tau_abs = 1e-4                 # LinSEPAL-ADMM and CLinSEPAL
tau_rel = 1e-4                 # LinSEPAL-ADMM and CLinSEPAL
adaptive_stepsize = True       # LinSEPAL-ADMM and CLinSEPAL

# --- LinSEPAL-ADMM ---
verbose_LinSEPAL_ADMM = 0

# --- LinSEPAL-PG ---
# twice the squared Frobenius norm of the (restricted) low-level covariance
L = 2 * anp.linalg.norm(covlow_restr, ord='fro')**2
gamma_line = 0.5
tau_line = 1.01
how = "exactly"
verbose_LinSEPAL_PG = False
tol_Dkl = 1e-4

# --- CLinSEPAL ---
verbose_CLinSEPAL = 0
tau = 1e-3
sca_exp_ord = float(np.log10(tau))  # base-10 exponent of tau, as a Python float
sca_iter = 1000
sca_tol = 1e-3
epsilon = 0.1

In [None]:
# True  -> rerun the experiment and overwrite the stored results
# False -> load the stored results
replace = False

In [7]:
# Testing
#
# Full-prior experiment.
# If `replace` is True, LinSEPAL-ADMM, LinSEPAL-PG and CLinSEPAL are run once for
# every seed in `seeds_algos`. The results of a seed are appended to the result
# frames only when all three methods completed for that seed, and the frames are
# written to disk after every such seed. A seed for which any step raises is
# recorded in `failed_combos` and skipped.
# If `replace` is False, previously stored results are loaded instead.

if replace:

    failed_combos = []

    for n,trial in enumerate(seeds_algos):
        try:
            print("\nDoing {}/{}.\n".format(n+1, ntrials))

            #### LinSEPAL-ADMM ###
            V_LinSEPAL_ADMM, Y_LinSEPAL_ADMM, iter_LinSEPAL_ADMM, primal_res_series_LinSEPAL_ADMM, dual_res_series_LinSEPAL_ADMM, obj_val_series_LinSEPAL_ADMM = LinSEPAL_ADMM(covlow_restr, covhigh_restr, B_restr, lambda_reg, rho, initialization, adaptive_stepsize=adaptive_stepsize, max_iter=max_iterations, seed=trial.item(), verbosity=verbose_LinSEPAL_ADMM)
            # entries with magnitude <= 1e-3 are zeroed; the surviving entries define the predicted support
            V_LinSEPAL_ADMM = anp.where(anp.abs(V_LinSEPAL_ADMM)>1.e-3, V_LinSEPAL_ADMM, 0.)
            y_pred_LinSEPAL_ADMM = (anp.abs(V_LinSEPAL_ADMM)>1.e-3).flatten(order='F')

            # embed the learned block into a matrix shaped like B; columns excluded from learning are copied from B
            V_reconstructed_LinSEPAL_ADMM = anp.zeros_like(B)
            V_reconstructed_LinSEPAL_ADMM[anp.ix_(uncertain_rows, uncertain_cols)] = V_LinSEPAL_ADMM
            V_reconstructed_LinSEPAL_ADMM[:,excluded_cols] = B[:,excluded_cols].copy()

            # one-row frames (column-major vectorization), indexed by "<method>_<(l,h)>_<seed gen>_<seed algo>"
            V_LinSEPAL_ADMM_df = pd.DataFrame(V_reconstructed_LinSEPAL_ADMM.flatten(order='F').reshape((1, dim_lh)), index=['LinSEPAL-ADMM_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            primal_res_series_LinSEPAL_ADMM_df = pd.DataFrame(primal_res_series_LinSEPAL_ADMM.flatten(order='F').reshape((1, max_iterations+1)), index=['LinSEPAL-ADMM_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            dual_res_series_LinSEPAL_ADMM_df = pd.DataFrame(dual_res_series_LinSEPAL_ADMM.flatten(order='F').reshape((1, max_iterations+1)), index=['LinSEPAL-ADMM_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            obj_val_LinSEPAL_ADMM_df = pd.DataFrame(obj_val_series_LinSEPAL_ADMM.flatten(order='F').reshape((1, max_iterations+1)), index=['LinSEPAL-ADMM_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)

            # support-recovery metrics from the confusion matrix of true vs. predicted support
            # NOTE(review): `.data` is a memoryview, so the four counts presumably unpack as Python ints;
            # a zero denominator below would then raise and the trial would be recorded as failed -- confirm.
            tn_LinSEPAL_ADMM, fp_LinSEPAL_ADMM, fn_LinSEPAL_ADMM, tp_LinSEPAL_ADMM = confusion_matrix(y_true, y_pred_LinSEPAL_ADMM).ravel().data
            fpr_LinSEPAL_ADMM = fp_LinSEPAL_ADMM/(fp_LinSEPAL_ADMM+tn_LinSEPAL_ADMM)
            tpr_LinSEPAL_ADMM = tp_LinSEPAL_ADMM/(tp_LinSEPAL_ADMM+fn_LinSEPAL_ADMM)
            fdr_LinSEPAL_ADMM = fp_LinSEPAL_ADMM/(tp_LinSEPAL_ADMM+fp_LinSEPAL_ADMM)
            f1_LinSEPAL_ADMM = 2*tp_LinSEPAL_ADMM/(2*tp_LinSEPAL_ADMM+fp_LinSEPAL_ADMM+fn_LinSEPAL_ADMM)

            # metrics row; D_KL is the objective value at the last performed iteration
            LinSEPAL_ADMM_df = pd.DataFrame([[str((l,h)), generatingseed, trial.item(), "LinSEPAL-ADMM", iter_LinSEPAL_ADMM,
                                            obj_val_series_LinSEPAL_ADMM[iter_LinSEPAL_ADMM].item(),
                                            constructiveness(V_reconstructed_LinSEPAL_ADMM),
                                            stiefel_arc_length(V_reconstructed_LinSEPAL_ADMM,V).item(),
                                            frobenious_abs_distance(V_reconstructed_LinSEPAL_ADMM,V).item(),
                                            tn_LinSEPAL_ADMM, fp_LinSEPAL_ADMM, fn_LinSEPAL_ADMM, tp_LinSEPAL_ADMM, fpr_LinSEPAL_ADMM, tpr_LinSEPAL_ADMM, fdr_LinSEPAL_ADMM, f1_LinSEPAL_ADMM]],
                                            columns=columns)

            #### LinSEPAL-PG ###
            V_LinSEPAL_PG, iter_LinSEPAL_PG, obj_val_series_LinSEPAL_PG = LinSEPAL_PG(covlow_restr, covhigh_restr, lambda_reg, B_restr, how, L, gamma_line, tau_line, max_iter=max_iterations, tol=tol_Dkl, initialization=initialization, V_init=None, seed=trial.item(), verbose=verbose_LinSEPAL_PG)
            # same thresholding and reconstruction as for LinSEPAL-ADMM
            V_LinSEPAL_PG = anp.where(anp.abs(V_LinSEPAL_PG)>1.e-3, V_LinSEPAL_PG, 0.)
            y_pred_LinSEPAL_PG = (anp.abs(V_LinSEPAL_PG)>1.e-3).flatten(order='F')

            V_reconstructed_LinSEPAL_PG = anp.zeros_like(B)
            V_reconstructed_LinSEPAL_PG[anp.ix_(uncertain_rows, uncertain_cols)] = V_LinSEPAL_PG
            V_reconstructed_LinSEPAL_PG[:,excluded_cols] = B[:,excluded_cols].copy()

            V_LinSEPAL_PG_df = pd.DataFrame(V_reconstructed_LinSEPAL_PG.flatten(order='F').reshape((1, dim_lh)), index=['LinSEPAL_PG_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            obj_val_LinSEPAL_PG_df = pd.DataFrame(obj_val_series_LinSEPAL_PG.flatten(order='F').reshape((1, max_iterations+1)), index=['LinSEPAL_PG_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)

            tn_LinSEPAL_PG, fp_LinSEPAL_PG, fn_LinSEPAL_PG, tp_LinSEPAL_PG = confusion_matrix(y_true, y_pred_LinSEPAL_PG).ravel().data
            fpr_LinSEPAL_PG = fp_LinSEPAL_PG/(fp_LinSEPAL_PG+tn_LinSEPAL_PG)
            tpr_LinSEPAL_PG = tp_LinSEPAL_PG/(tp_LinSEPAL_PG+fn_LinSEPAL_PG)
            fdr_LinSEPAL_PG = fp_LinSEPAL_PG/(tp_LinSEPAL_PG+fp_LinSEPAL_PG)
            f1_LinSEPAL_PG = 2*tp_LinSEPAL_PG/(2*tp_LinSEPAL_PG+fp_LinSEPAL_PG+fn_LinSEPAL_PG)

            LinSEPAL_PG_df = pd.DataFrame([[str((l,h)), generatingseed, trial.item(), "LinSEPAL_PG", iter_LinSEPAL_PG,
                                        obj_val_series_LinSEPAL_PG[iter_LinSEPAL_PG].item(),
                                        constructiveness(V_reconstructed_LinSEPAL_PG),
                                        stiefel_arc_length(V_reconstructed_LinSEPAL_PG,V).item(),
                                        frobenious_abs_distance(V_reconstructed_LinSEPAL_PG,V).item(),
                                        tn_LinSEPAL_PG, fp_LinSEPAL_PG, fn_LinSEPAL_PG, tp_LinSEPAL_PG, fpr_LinSEPAL_PG, tpr_LinSEPAL_PG, fdr_LinSEPAL_PG, f1_LinSEPAL_PG]],
                                        columns=columns)

            ### CLinSEPAL ###
            V_CLinSEPAL, Y_CLinSEPAL, iter_CLinSEPAL, primal_res_seriesY_CLinSEPAL, dual_res_seriesY_CLinSEPAL, obj_val_series_CLinSEPAL= CLinSEPAL_fp(covlow_restr, covhigh_restr, B_restr, rho, epsilon=epsilon, adaptive_stepsize=adaptive_stepsize, tau=tau, tau_abs=tau_abs, tau_rel=tau_rel, max_iter=max_iterations, sca_iter=sca_iter, sca_tol=sca_tol, seed=trial.item(), verbosity=verbose_CLinSEPAL)
            # here the output is first masked entrywise by the prior, then thresholded as above
            BV = B_restr*V_CLinSEPAL
            BV = anp.where(anp.abs(BV)>1.e-3, BV, 0.)
            y_pred_CLinSEPAL = (anp.abs(BV)>1.e-3).flatten(order='F')

            V_reconstructed_CLinSEPAL = anp.zeros_like(B)
            V_reconstructed_CLinSEPAL[anp.ix_(uncertain_rows, uncertain_cols)] = BV
            V_reconstructed_CLinSEPAL[:,excluded_cols] = B[:,excluded_cols].copy()

            V_CLinSEPAL_df = pd.DataFrame(V_reconstructed_CLinSEPAL.flatten(order='F').reshape((1, dim_lh)), index=['CLinSEPAL_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            primal_res_seriesY_CLinSEPAL_df = pd.DataFrame(primal_res_seriesY_CLinSEPAL.flatten(order='F').reshape((1, max_iterations+1)), index=['CLinSEPAL_Y_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            dual_res_seriesY_CLinSEPAL_df = pd.DataFrame(dual_res_seriesY_CLinSEPAL.flatten(order='F').reshape((1, max_iterations+1)), index=['CLinSEPAL_Y_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            obj_val_CLinSEPAL_df = pd.DataFrame(obj_val_series_CLinSEPAL.flatten(order='F').reshape((1, max_iterations+1)), index=['CLinSEPAL_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)

            tn_CLinSEPAL, fp_CLinSEPAL, fn_CLinSEPAL, tp_CLinSEPAL = confusion_matrix(y_true, y_pred_CLinSEPAL).ravel().data
            fpr_CLinSEPAL = fp_CLinSEPAL/(fp_CLinSEPAL+tn_CLinSEPAL)
            tpr_CLinSEPAL = tp_CLinSEPAL/(tp_CLinSEPAL+fn_CLinSEPAL)
            fdr_CLinSEPAL = fp_CLinSEPAL/(tp_CLinSEPAL+fp_CLinSEPAL)
            f1_CLinSEPAL = 2*tp_CLinSEPAL/(2*tp_CLinSEPAL+fp_CLinSEPAL+fn_CLinSEPAL)

            CLinSEPAL_df = pd.DataFrame([[str((l,h)), generatingseed, trial.item(), "CLinSEPAL", iter_CLinSEPAL,
                                        obj_val_series_CLinSEPAL[iter_CLinSEPAL].item(),
                                        constructiveness(V_reconstructed_CLinSEPAL),
                                        stiefel_arc_length(V_reconstructed_CLinSEPAL,V).item(),
                                        frobenious_abs_distance(V_reconstructed_CLinSEPAL,V).item(),
                                        tn_CLinSEPAL, fp_CLinSEPAL, fn_CLinSEPAL, tp_CLinSEPAL, fpr_CLinSEPAL, tpr_CLinSEPAL, fdr_CLinSEPAL, f1_CLinSEPAL]],
                                        columns=columns)

            ### only in case all methods succeed, concat the results

            #LinSEPAL-ADMM
            V_matrices = pd.concat((V_matrices, V_LinSEPAL_ADMM_df))
            primal_res_series_df = pd.concat((primal_res_series_df, primal_res_series_LinSEPAL_ADMM_df))
            dual_res_series_df = pd.concat((dual_res_series_df, dual_res_series_LinSEPAL_ADMM_df))
            obj_val_df = pd.concat((obj_val_df, obj_val_LinSEPAL_ADMM_df))

            metrics_df = pd.concat((metrics_df, LinSEPAL_ADMM_df))

            #LinSEPAL_PG
            V_matrices = pd.concat((V_matrices, V_LinSEPAL_PG_df))
            obj_val_df = pd.concat((obj_val_df, obj_val_LinSEPAL_PG_df))

            metrics_df = pd.concat((metrics_df, LinSEPAL_PG_df))

            #CLinSEPAL
            V_matrices = pd.concat((V_matrices, V_CLinSEPAL_df))
            primal_res_series_df = pd.concat((primal_res_series_df, primal_res_seriesY_CLinSEPAL_df))
            dual_res_series_df = pd.concat((dual_res_series_df, dual_res_seriesY_CLinSEPAL_df))
            obj_val_df = pd.concat((obj_val_df, obj_val_CLinSEPAL_df))

            metrics_df = pd.concat((metrics_df, CLinSEPAL_df))

            # always True at this point; kept as a manual switch to disable writing to disk
            to_save=True

            if to_save:
                save_obj_parquet(V_matrices, "full_prior_{}_{}_{}_V_matrices".format(str((l,h)), generatingseed, ntrials), data_dir=data_dir)
                save_obj_parquet(metrics_df, "full_prior_{}_{}_{}_metrics_df".format(str((l,h)), generatingseed, ntrials), data_dir=data_dir)
                save_obj_parquet(obj_val_df, "full_prior_{}_{}_{}_obj_val_df".format(str((l,h)), generatingseed, ntrials), data_dir=data_dir)
                save_obj_parquet(primal_res_series_df, "full_prior_{}_{}_{}_primal_res_series_df".format(str((l,h)), generatingseed, ntrials), data_dir=data_dir)
                save_obj_parquet(dual_res_series_df, "full_prior_{}_{}_{}_dual_res_series_df".format(str((l,h)), generatingseed, ntrials), data_dir=data_dir)

        # Catch Exception (not a bare except) so that interrupting the kernel still stops the loop,
        # and report the error so that a failed trial can be diagnosed.
        except Exception as exc:
            failed_combos.append((generatingseed, trial))
            print("Failed at ({},{},{})".format(str((l,h)), generatingseed, trial))
            print("Reason: {}: {}".format(type(exc).__name__, exc))
            save_obj_parquet(failed_combos, "full_prior_{}_{}_{}_failed_combos".format(str((l,h)), generatingseed, ntrials), data_dir=data_dir)
            to_save=False
            continue

else:
    V_matrices = load_obj_parquet("full_prior_{}_{}_{}_V_matrices".format(str((l,h)), generatingseed, ntrials), data_dir=data_dir)
    metrics_df = load_obj_parquet("full_prior_{}_{}_{}_metrics_df".format(str((l,h)), generatingseed, ntrials), data_dir=data_dir)
    obj_val_df = load_obj_parquet("full_prior_{}_{}_{}_obj_val_df".format(str((l,h)), generatingseed, ntrials), data_dir=data_dir)
    primal_res_series_df = load_obj_parquet("full_prior_{}_{}_{}_primal_res_series_df".format(str((l,h)), generatingseed, ntrials), data_dir=data_dir)
    dual_res_series_df = load_obj_parquet("full_prior_{}_{}_{}_dual_res_series_df".format(str((l,h)), generatingseed, ntrials), data_dir=data_dir)

    # the failed-combos file is only written when a trial failed, so it may not be loadable
    try:
        failed_combos = load_obj_parquet("full_prior_{}_{}_{}_failed_combos".format(str((l,h)), generatingseed, ntrials), data_dir=data_dir)
    except Exception:
        failed_combos = []



Doing 1/10.

Residuals convergence at iteration 1: (objective, primal, dual)=(1.0587086762825493e-12,1.7659996829664377e-06,0.0)
Objective convergence at iteration 650: objective=8.459234121982462e-05
Residuals convergence at iteration 57: (objective, primal, dual)=(9.607042438908309e-06,0.00045308735384731307,0.0007384705002361652)


  metrics_df = pd.concat((metrics_df, LinSEPAL_ADMM_df))



Doing 2/10.

Residuals convergence at iteration 34: (objective, primal, dual)=(2.8035469679110747e-07,0.00017791896801804888,0.0)
Max number of iterations reached: objective 0.4246044788275869
Residuals convergence at iteration 45: (objective, primal, dual)=(0.08908313835312676,0.0004435348679851518,0.0008034401075825103)

Doing 3/10.

Residuals convergence at iteration 223: (objective, primal, dual)=(9.198695138934454e-09,8.10910327006872e-05,0.0)
Max number of iterations reached: objective 0.4493839127936141
Max number of iterations reached: (objective, primal, dual)=(1.1566908383109364e-05,0.0025124206348526856,0.008508138805780881)

Doing 4/10.

Residuals convergence at iteration 1: (objective, primal, dual)=(5.577760475716786e-13,1.1473388502181266e-06,0.0)
Max number of iterations reached: objective 0.08905487351672381
Residuals convergence at iteration 48: (objective, primal, dual)=(0.0891498239537869,0.001063460695100466,0.00021550030445814612)

Doing 5/10.

Residuals converge

  return f_raw(*args, **kwargs)



Doing 8/10.

Residuals convergence at iteration 79: (objective, primal, dual)=(1.6524562793662767e-06,0.0005812749691647562,0.0)
Objective convergence at iteration 685: objective=7.239508505119119e-05
Residuals convergence at iteration 48: (objective, primal, dual)=(8.361505918763612e-06,0.00016027351267430527,0.0007833614995102041)

Doing 9/10.

Residuals convergence at iteration 49: (objective, primal, dual)=(9.845888726545127e-08,0.0001154048025661998,0.0)
Max number of iterations reached: objective 0.08905487426605152
Residuals convergence at iteration 46: (objective, primal, dual)=(0.08905267821876084,0.0010235614144088088,0.0003037274558819702)

Doing 10/10.

Residuals convergence at iteration 255: (objective, primal, dual)=(9.517307653084117e-08,0.000504438803720083,0.0)
Max number of iterations reached: objective 0.6480992023181216
Residuals convergence at iteration 50: (objective, primal, dual)=(1.6149805758480795e-05,0.0010928293988197672,0.0006743200900433807)


  return f_raw(*args, **kwargs)


In [9]:
# Display all results
# Give every result row a unique 0..n-1 label (rows were appended with pd.concat,
# which keeps each appended frame's own index), then display all results.
# Rebinding instead of inplace=True keeps the cell free of in-place mutation.
metrics_df = metrics_df.reset_index(drop=True)
metrics_df

Unnamed: 0,"(l,h)",seed gen,seed algo,method,iterations,D_KL,constructiveness,stiefel distance,frobenious distance,tn,fp,fn,tp,fpr,tpr,fdr,f1
0,"(12, 6)",0,1,LinSEPAL-ADMM,1,1.058709e-12,1.0,0.0,5e-06,60,0,0,12,0.0,1.0,0.0,1.0
1,"(12, 6)",0,1,LinSEPAL_PG,649,0.0001217378,1.0,0.0,0.011793,60,0,0,12,0.0,1.0,0.0,1.0
2,"(12, 6)",0,1,CLinSEPAL,57,9.607042e-06,1.0,0.0,0.002254,60,0,0,12,0.0,1.0,0.0,1.0
3,"(12, 6)",0,2,LinSEPAL-ADMM,34,2.803547e-07,1.0,0.0,0.000392,60,0,0,12,0.0,1.0,0.0,1.0
4,"(12, 6)",0,2,LinSEPAL_PG,1000,0.4246045,0.541667,1.41842,0.357862,36,24,0,12,0.4,1.0,0.666667,0.5
5,"(12, 6)",0,2,CLinSEPAL,45,0.08908314,1.0,0.319779,0.130601,60,0,0,12,0.0,1.0,0.0,1.0
6,"(12, 6)",0,3,LinSEPAL-ADMM,223,9.198695e-09,1.0,0.0,0.000192,60,0,0,12,0.0,1.0,0.0,1.0
7,"(12, 6)",0,3,LinSEPAL_PG,1000,0.4493839,0.875,1.270691,0.358542,54,6,0,12,0.1,1.0,0.333333,0.8
8,"(12, 6)",0,3,CLinSEPAL,1000,1.156691e-05,1.0,0.100012,0.001374,60,0,0,12,0.0,1.0,0.0,1.0
9,"(12, 6)",0,4,LinSEPAL-ADMM,1,5.57776e-13,1.0,0.0,3e-06,60,0,0,12,0.0,1.0,0.0,1.0


In [10]:
# Restrict to constructive CAs (constructiveness == 1), then,
# for each ((l,h), seed gen, method) group, keep the run
# with the smallest KL divergence (alignment metric).
group_keys = ["(l,h)", "seed gen", "method"]
metrics_constructiveCA = metrics_df.loc[metrics_df["constructiveness"] == 1].copy()
best_idx = metrics_constructiveCA.groupby(group_keys)['D_KL'].idxmin()
metrics_constructiveCA = metrics_constructiveCA.loc[best_idx].copy()
metrics_constructiveCA

Unnamed: 0,"(l,h)",seed gen,seed algo,method,iterations,D_KL,constructiveness,stiefel distance,frobenious distance,tn,fp,fn,tp,fpr,tpr,fdr,f1
23,"(12, 6)",0,8,CLinSEPAL,48,8.361506e-06,1.0,0.0,0.002467,60,0,0,12,0.0,1.0,0.0,1.0
9,"(12, 6)",0,4,LinSEPAL-ADMM,1,5.57776e-13,1.0,0.0,3e-06,60,0,0,12,0.0,1.0,0.0,1.0
22,"(12, 6)",0,8,LinSEPAL_PG,684,0.0001056358,1.0,0.0,0.013609,60,0,0,12,0.0,1.0,0.0,1.0


In [11]:
# Display the CAs
print("Ground truth")
display(V.round(3))

for idx in metrics_constructiveCA.index:
    row = metrics_constructiveCA.loc[idx]
    print("{}".format(row['method']))
    display(V_matrices.loc['{}_{}_{}_{}'.format(row['method'], row['(l,h)'], row['seed gen'], row['seed algo'])].values.reshape((l,h), order='F').round(3))

Ground truth


array([[ 0.   ,  0.471,  0.   ,  0.   , -0.   ,  0.   ],
       [-0.32 ,  0.   ,  0.   , -0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   , -0.695, -0.   , -0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   , -0.753,  0.   ],
       [-0.   ,  0.   , -0.   ,  0.776,  0.   , -0.   ],
       [-0.   ,  0.   , -0.   ,  0.   , -0.   ,  0.302],
       [ 0.   ,  0.   ,  0.719,  0.   , -0.   , -0.   ],
       [ 0.   , -0.   ,  0.   ,  0.   , -0.   , -0.953],
       [-0.947, -0.   ,  0.   , -0.   ,  0.   , -0.   ],
       [-0.   , -0.   ,  0.   , -0.   , -0.658, -0.   ],
       [-0.   , -0.   ,  0.   , -0.631, -0.   , -0.   ],
       [ 0.   , -0.882,  0.   , -0.   ,  0.   , -0.   ]])

CLinSEPAL


array([[ 0.   ,  0.472,  0.   ,  0.   ,  0.   ,  0.   ],
       [-0.317,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   , -0.694,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   , -0.752,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.772,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ,  0.302],
       [ 0.   ,  0.   ,  0.72 ,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   , -0.953],
       [-0.948,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   , -0.659,  0.   ],
       [ 0.   ,  0.   ,  0.   , -0.635,  0.   ,  0.   ],
       [ 0.   , -0.882,  0.   ,  0.   ,  0.   ,  0.   ]])

LinSEPAL-ADMM


array([[ 0.   , -0.471,  0.   ,  0.   ,  0.   ,  0.   ],
       [ 0.32 ,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.695,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.753,  0.   ],
       [ 0.   ,  0.   ,  0.   , -0.776,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   , -0.302],
       [ 0.   ,  0.   , -0.719,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ,  0.953],
       [ 0.947,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.658,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.631,  0.   ,  0.   ],
       [ 0.   ,  0.882,  0.   ,  0.   ,  0.   ,  0.   ]])

LinSEPAL_PG


array([[ 0.   ,  0.468,  0.   ,  0.   ,  0.   ,  0.   ],
       [-0.321,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   , -0.699,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   , -0.754,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.786,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ,  0.276],
       [ 0.   ,  0.   ,  0.716,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   , -0.961],
       [-0.947,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   , -0.657,  0.   ],
       [ 0.   ,  0.   ,  0.   , -0.618,  0.   ,  0.   ],
       [ 0.   , -0.884,  0.   ,  0.   ,  0.   ,  0.   ]])

## Partial-prior

In this setting we only have partial prior (pp) structural knowledge of the CA of some nodes. 
As an example, we consider pp knowledge for $3$ nodes out of $12$.
For each of these $3$ nodes, we specify $2$ nodes of the high-level model as possible abstractions. 
The setting can be changed easily by specifying different values for `nnodes` $\leq \ell$ and `uncertain_nodes` $<h$.

In [12]:
import random

In [13]:
# seed Python's RNG so that the sampled partial prior is reproducible
random.seed(generatingseed)

# number of low-level nodes that only get a partial prior (3 out of 12 here)
nnodes = 3
pp_nodes = random.sample(range(l), nnodes)

# start from the full prior and, for every row in pp_nodes,
# mark `uncertain_nodes` additional randomly chosen high-level nodes
B_pp = B.copy()
uncertain_nodes = 1

for pp_node in pp_nodes:
    # the ground-truth high-level node is removed from the candidates
    gt_node = B_pp[pp_node, :].argmax()
    possible_nodes = anp.delete(anp.arange(h), gt_node)
    nodes_to_add = random.sample(sorted(possible_nodes), uncertain_nodes)
    B_pp[pp_node, nodes_to_add] += 1

print("Indices of nodes with partial prior: {}".format(pp_nodes))
print("Partial prior matrix B:\n {}".format(B_pp))

Indices of nodes with partial prior: [6, 11, 0]
Partial prior matrix B:
 [[0. 1. 0. 0. 1. 0.]
 [1. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 1.]
 [0. 0. 1. 1. 0. 0.]
 [0. 0. 0. 0. 0. 1.]
 [1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0. 1.]]


In [14]:
# columns of the partial prior with more than one candidate row, and the rows they involve
uncertain_cols = anp.nonzero(B_pp.sum(axis=0) > 1)[0]
uncertain_rows = anp.nonzero(B_pp[:, uncertain_cols].sum(axis=1) > 0)[0]

# complement of the index sets above
excluded_rows = anp.setdiff1d(all_row_indices, uncertain_rows)
excluded_cols = anp.setdiff1d(all_col_indices, uncertain_cols)

# the algorithms receive the restricted partial prior;
# the classification target is still the ground-truth support B
B_restr = B_pp[anp.ix_(uncertain_rows, uncertain_cols)].copy()
y_true = B[anp.ix_(uncertain_rows, uncertain_cols)].flatten(order='F')

# covariances restricted to the same rows / columns
covlow_restr = covlow[anp.ix_(uncertain_rows, uncertain_rows)]
covhigh_restr = covhigh[anp.ix_(uncertain_cols, uncertain_cols)]

# In this example there are no certain CAs,
# so the restricted matrices coincide with the generated ones.
(
    anp.allclose(B_pp, B_restr),
    anp.allclose(covlow, covlow_restr),
    anp.allclose(covhigh, covhigh_restr),
)

(True, True, True)

In [15]:
# The problem is nonconvex, so the outcome depends on the initialization:
# each algorithm is run from "ntrials" different random initializations.
ntrials = 10
# seeds generatingseed+1, ..., generatingseed+ntrials (never equal to generatingseed)
seeds_algos = generatingseed + 1 + anp.arange(ntrials)

In [16]:
# hyper-params specific to the partial-prior setting
solver_CLinSEPAL = "OSQP"  # solver for the quadratic programming subproblem
# bounds on the CA coefficients, useful to balance the product S*V
a = -1  # lower bound
b = 1   # upper bound
initialization = "notstructural"

In [17]:
# column layout of the partial-prior metrics table
columns = [
    "(l,h)", "idx pp nodes", "n uncertain nodes for pp nodes",
    "seed gen", "seed algo", "method", "iterations",
    "constructiveness", "D_KL", "stiefel distance", "frobenious distance",
    "nnz", "tn", "fp", "fn", "tp",
    "fpr", "tpr", "fdr", "f1",
]

# fresh, empty containers for the partial-prior results
S_matrices, V_matrices = (pd.DataFrame(columns=range(dim_lh)) for _ in range(2))
metrics_df = pd.DataFrame(columns=columns)
obj_val_df, primal_res_series_df, dual_res_series_df = (
    pd.DataFrame(columns=range(max_iterations+1), dtype=float) for _ in range(3)
)

failed_combos = []

In [None]:
# True  -> rerun the experiment and overwrite the stored results
# False -> load the stored results
replace = False

In [19]:
# Testing

if replace:

    failed_combos = []

    for n,trial in enumerate(seeds_algos):
        try:
            print("\nDoing {}/{}.\n".format(n+1, ntrials))

            #### LinSEPAL-ADMM ###
            V_LinSEPAL_ADMM, Y_LinSEPAL_ADMM, iter_LinSEPAL_ADMM, primal_res_series_LinSEPAL_ADMM, dual_res_series_LinSEPAL_ADMM, obj_val_series_LinSEPAL_ADMM = LinSEPAL_ADMM(covlow_restr, covhigh_restr, B_restr, lambda_reg, rho, initialization, adaptive_stepsize=adaptive_stepsize, max_iter=max_iterations, seed=trial.item(), verbosity=verbose_LinSEPAL_ADMM)
            V_LinSEPAL_ADMM = anp.where(anp.abs(V_LinSEPAL_ADMM)>1.e-2, V_LinSEPAL_ADMM, 0.)
            y_pred_LinSEPAL_ADMM = (anp.abs(V_LinSEPAL_ADMM)>1.e-2).flatten(order='F')

            V_reconstructed_LinSEPAL_ADMM = anp.zeros_like(B_pp)
            V_reconstructed_LinSEPAL_ADMM[anp.ix_(uncertain_rows, uncertain_cols)] = V_LinSEPAL_ADMM
            V_reconstructed_LinSEPAL_ADMM[:,excluded_cols] = B_pp[:,excluded_cols].copy()

            V_LinSEPAL_ADMM_df = pd.DataFrame(V_reconstructed_LinSEPAL_ADMM.flatten(order='F').reshape((1, dim_lh)), index=['LinSEPAL-ADMM_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            primal_res_series_LinSEPAL_ADMM_df = pd.DataFrame(primal_res_series_LinSEPAL_ADMM.flatten(order='F').reshape((1, max_iterations+1)), index=['LinSEPAL-ADMM_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            dual_res_series_LinSEPAL_ADMM_df = pd.DataFrame(dual_res_series_LinSEPAL_ADMM.flatten(order='F').reshape((1, max_iterations+1)), index=['LinSEPAL-ADMM_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            obj_val_LinSEPAL_ADMM_df = pd.DataFrame(obj_val_series_LinSEPAL_ADMM.flatten(order='F').reshape((1, max_iterations+1)), index=['LinSEPAL-ADMM_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            
            nnz_LinSEPAL_ADMM = y_pred_LinSEPAL_ADMM.sum()
            tn_LinSEPAL_ADMM, fp_LinSEPAL_ADMM, fn_LinSEPAL_ADMM, tp_LinSEPAL_ADMM = confusion_matrix(y_true, y_pred_LinSEPAL_ADMM).ravel().data
            fpr_LinSEPAL_ADMM = fp_LinSEPAL_ADMM/(fp_LinSEPAL_ADMM+tn_LinSEPAL_ADMM)
            tpr_LinSEPAL_ADMM = tp_LinSEPAL_ADMM/(tp_LinSEPAL_ADMM+fn_LinSEPAL_ADMM)
            fdr_LinSEPAL_ADMM = fp_LinSEPAL_ADMM/(tp_LinSEPAL_ADMM+fp_LinSEPAL_ADMM)
            f1_LinSEPAL_ADMM = 2*tp_LinSEPAL_ADMM/(2*tp_LinSEPAL_ADMM+fp_LinSEPAL_ADMM+fn_LinSEPAL_ADMM)  

            LinSEPAL_ADMM_df = pd.DataFrame([[str((l,h)), str(pp_nodes), uncertain_nodes+1, generatingseed, trial.item(), "LinSEPAL-ADMM", iter_LinSEPAL_ADMM,
                                            constructiveness(V_reconstructed_LinSEPAL_ADMM), 
                                            obj_val_series_LinSEPAL_ADMM[iter_LinSEPAL_ADMM].item(),
                                            stiefel_arc_length(V_reconstructed_LinSEPAL_ADMM,V).item(),
                                            frobenious_abs_distance(V_reconstructed_LinSEPAL_ADMM,V).item(),
                                            nnz_LinSEPAL_ADMM, tn_LinSEPAL_ADMM, fp_LinSEPAL_ADMM, fn_LinSEPAL_ADMM, tp_LinSEPAL_ADMM, fpr_LinSEPAL_ADMM, tpr_LinSEPAL_ADMM, fdr_LinSEPAL_ADMM, f1_LinSEPAL_ADMM]], 
                                            columns=columns)     

            #### LinSEPAL-PG ###
            # Run the LinSEPAL-PG solver on the restricted problem (covlow_restr,
            # covhigh_restr, B_restr), seeded with the current trial. The call returns
            # three values: the learned matrix, an iteration index, and the series of
            # objective values (the index is used below to pick the reported objective).
            V_LinSEPAL_PG, iter_LinSEPAL_PG, obj_val_series_LinSEPAL_PG = LinSEPAL_PG(covlow_restr, covhigh_restr, lambda_reg, B_restr, how, L, gamma_line, tau_line, max_iter=max_iterations, tol=tol_Dkl, initialization=initialization, V_init=None, seed=trial.item(), verbose=verbose_LinSEPAL_PG)
            # Hard-threshold: entries with magnitude <= 1e-2 are set to exactly 0.
            V_LinSEPAL_PG = anp.where(anp.abs(V_LinSEPAL_PG)>1.e-2, V_LinSEPAL_PG, 0.)
            # Predicted support (boolean), flattened in column-major order.
            # NOTE(review): this is computed on the restricted matrix, not on the
            # reconstructed one -- confirm y_true (defined outside this excerpt) has
            # the same length and ordering.
            y_pred_LinSEPAL_PG = (anp.abs(V_LinSEPAL_PG)>1.e-2).flatten(order='F')

            # Embed the restricted solution into a matrix shaped like B_pp: the learned
            # block goes to (uncertain_rows, uncertain_cols), and the excluded columns
            # are copied verbatim from B_pp. Every other entry stays 0.
            V_reconstructed_LinSEPAL_PG = anp.zeros_like(B_pp)
            V_reconstructed_LinSEPAL_PG[anp.ix_(uncertain_rows, uncertain_cols)] = V_LinSEPAL_PG
            V_reconstructed_LinSEPAL_PG[:,excluded_cols] = B_pp[:,excluded_cols].copy()
            
            # Single-row frames: the reconstructed matrix (column-major flattened to
            # dim_lh values) and the objective series. The reshape of the objective
            # series only succeeds if it holds exactly max_iterations+1 values.
            # The row label encodes method, (l,h), generating seed and trial seed.
            V_LinSEPAL_PG_df = pd.DataFrame(V_reconstructed_LinSEPAL_PG.flatten(order='F').reshape((1, dim_lh)), index=['LinSEPAL_PG_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            obj_val_LinSEPAL_PG_df = pd.DataFrame(obj_val_series_LinSEPAL_PG.flatten(order='F').reshape((1, max_iterations+1)), index=['LinSEPAL_PG_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            
            # Support-recovery metrics: number of predicted non-zeros, confusion counts,
            # and the derived rates.
            # NOTE(review): the unpacking assumes the raveled confusion matrix has
            # exactly four entries, and the ratios are not guarded against a zero
            # denominator (e.g. tp+fp == 0 for fdr) -- confirm this cannot occur.
            nnz_LinSEPAL_PG = y_pred_LinSEPAL_PG.sum()
            tn_LinSEPAL_PG, fp_LinSEPAL_PG, fn_LinSEPAL_PG, tp_LinSEPAL_PG = confusion_matrix(y_true, y_pred_LinSEPAL_PG).ravel().data
            fpr_LinSEPAL_PG = fp_LinSEPAL_PG/(fp_LinSEPAL_PG+tn_LinSEPAL_PG)
            tpr_LinSEPAL_PG = tp_LinSEPAL_PG/(tp_LinSEPAL_PG+fn_LinSEPAL_PG)
            fdr_LinSEPAL_PG = fp_LinSEPAL_PG/(tp_LinSEPAL_PG+fp_LinSEPAL_PG)
            f1_LinSEPAL_PG = 2*tp_LinSEPAL_PG/(2*tp_LinSEPAL_PG+fp_LinSEPAL_PG+fn_LinSEPAL_PG)  

            # One metrics row for this (method, trial). The value order must match
            # `columns`, which is defined outside this excerpt. The reported objective
            # is the series entry at the returned iteration index; the two distances
            # compare the reconstructed matrix against the ground truth V.
            LinSEPAL_PG_df = pd.DataFrame([[str((l,h)), str(pp_nodes), uncertain_nodes+1, generatingseed, trial.item(), "LinSEPAL_PG", iter_LinSEPAL_PG, 
                                        constructiveness(V_reconstructed_LinSEPAL_PG),
                                        obj_val_series_LinSEPAL_PG[iter_LinSEPAL_PG].item(),
                                        stiefel_arc_length(V_reconstructed_LinSEPAL_PG,V).item(),
                                        frobenious_abs_distance(V_reconstructed_LinSEPAL_PG,V).item(),
                                        nnz_LinSEPAL_PG, tn_LinSEPAL_PG, fp_LinSEPAL_PG, fn_LinSEPAL_PG, tp_LinSEPAL_PG, fpr_LinSEPAL_PG, tpr_LinSEPAL_PG, fdr_LinSEPAL_PG, f1_LinSEPAL_PG]], 
                                        columns=columns)

            ### CLinSEPAL ###
            # Run the partial-prior CLinSEPAL solver on the restricted problem, seeded
            # with the current trial. It returns the learned values V and support S,
            # auxiliary variables (Y1, Y2, X -- not used in the visible part of this
            # cell), an iteration index, three primal and three dual residual series,
            # and the objective series.
            V_CLinSEPAL, S_CLinSEPAL, Y1_CLinSEPAL, Y2_CLinSEPAL, X_CLinSEPAL, iter_CLinSEPAL, primal_res_seriesY1_CLinSEPAL, primal_res_seriesY2_CLinSEPAL, primal_res_seriesX_CLinSEPAL, dual_res_seriesY1_CLinSEPAL, dual_res_seriesY2_CLinSEPAL, dual_res_seriesX_CLinSEPAL, obj_val_series_CLinSEPAL= CLinSEPAL_pp(covlow_restr, covhigh_restr, B_restr, a, b, rho, epsilon=epsilon, adaptive_stepsize=adaptive_stepsize, tau=tau, tau_abs=tau_abs, tau_rel=tau_rel, max_iter=max_iterations, sca_iter=sca_iter, sca_tol=sca_tol, seed=trial.item(), solver=solver_CLinSEPAL, verbosity=verbose_CLinSEPAL)
            # Combine prior mask, learned support and learned values with `*`
            # (element-wise if all three are ndarrays -- TODO confirm), then
            # hard-threshold: entries with magnitude <= 1e-2 are set to exactly 0.
            BSV = B_restr*S_CLinSEPAL*V_CLinSEPAL
            BSV = anp.where(anp.abs(BSV)>1.e-2, BSV, 0.)
            # Predicted support (boolean), flattened in column-major order.
            y_pred_CLinSEPAL = (anp.abs(BSV)>1.e-2).flatten(order='F')

            # Embed the thresholded product into a matrix shaped like B_pp: learned
            # block at (uncertain_rows, uncertain_cols), excluded columns copied from B_pp.
            V_reconstructed_CLinSEPAL = anp.zeros_like(B_pp)
            V_reconstructed_CLinSEPAL[anp.ix_(uncertain_rows, uncertain_cols)] = BSV
            V_reconstructed_CLinSEPAL[:,excluded_cols] = B_pp[:,excluded_cols].copy()

            # Same embedding for the learned support (un-thresholded S_CLinSEPAL).
            S_reconstructed_CLinSEPAL = anp.zeros_like(B_pp)
            S_reconstructed_CLinSEPAL[anp.ix_(uncertain_rows, uncertain_cols)] = S_CLinSEPAL
            S_reconstructed_CLinSEPAL[:,excluded_cols] = B_pp[:,excluded_cols].copy()
            
            # Single-row frames for the matrices and for every convergence series.
            # Each series reshape only succeeds if the series holds exactly
            # max_iterations+1 values. Residual rows are labelled with the variable
            # they belong to (Y1 / Y2 / X); primal and dual frames reuse the same
            # labels and are told apart only by the frame they are stored in.
            V_CLinSEPAL_df = pd.DataFrame(V_reconstructed_CLinSEPAL.flatten(order='F').reshape((1, dim_lh)), index=['CLinSEPAL_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            S_CLinSEPAL_df = pd.DataFrame(S_reconstructed_CLinSEPAL.flatten(order='F').reshape((1, dim_lh)), index=['CLinSEPAL_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            primal_res_seriesY1_CLinSEPAL_df = pd.DataFrame(primal_res_seriesY1_CLinSEPAL.flatten(order='F').reshape((1, max_iterations+1)), index=['CLinSEPAL_Y1_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            primal_res_seriesY2_CLinSEPAL_df = pd.DataFrame(primal_res_seriesY2_CLinSEPAL.flatten(order='F').reshape((1, max_iterations+1)), index=['CLinSEPAL_Y2_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            primal_res_seriesX_CLinSEPAL_df = pd.DataFrame(primal_res_seriesX_CLinSEPAL.flatten(order='F').reshape((1, max_iterations+1)), index=['CLinSEPAL_X_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            dual_res_seriesY1_CLinSEPAL_df = pd.DataFrame(dual_res_seriesY1_CLinSEPAL.flatten(order='F').reshape((1, max_iterations+1)), index=['CLinSEPAL_Y1_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            dual_res_seriesY2_CLinSEPAL_df = pd.DataFrame(dual_res_seriesY2_CLinSEPAL.flatten(order='F').reshape((1, max_iterations+1)), index=['CLinSEPAL_Y2_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            dual_res_seriesX_CLinSEPAL_df = pd.DataFrame(dual_res_seriesX_CLinSEPAL.flatten(order='F').reshape((1, max_iterations+1)), index=['CLinSEPAL_X_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)
            obj_val_CLinSEPAL_df = pd.DataFrame(obj_val_series_CLinSEPAL.flatten(order='F').reshape((1, max_iterations+1)), index=['CLinSEPAL_{}_{}_{}'.format(str((l,h)), generatingseed, trial.item())], dtype=float)

            # Support-recovery metrics: number of predicted non-zeros, confusion counts,
            # and the derived rates.
            # NOTE(review): the unpacking assumes the raveled confusion matrix has
            # exactly four entries, and the ratios are not guarded against a zero
            # denominator (e.g. tp+fp == 0 for fdr) -- confirm this cannot occur.
            nnz_CLinSEPAL = y_pred_CLinSEPAL.sum()
            tn_CLinSEPAL, fp_CLinSEPAL, fn_CLinSEPAL, tp_CLinSEPAL = confusion_matrix(y_true, y_pred_CLinSEPAL).ravel().data
            fpr_CLinSEPAL = fp_CLinSEPAL/(fp_CLinSEPAL+tn_CLinSEPAL)
            tpr_CLinSEPAL = tp_CLinSEPAL/(tp_CLinSEPAL+fn_CLinSEPAL)
            fdr_CLinSEPAL = fp_CLinSEPAL/(tp_CLinSEPAL+fp_CLinSEPAL)
            f1_CLinSEPAL = 2*tp_CLinSEPAL/(2*tp_CLinSEPAL+fp_CLinSEPAL+fn_CLinSEPAL)  

            # One metrics row for this (method, trial). The value order must match
            # `columns`, which is defined outside this excerpt. The reported objective
            # is the series entry at the returned iteration index; the two distances
            # compare the reconstructed matrix against the ground truth V.
            CLinSEPAL_df = pd.DataFrame([[str((l,h)), str(pp_nodes), uncertain_nodes+1, generatingseed, trial.item(), "CLinSEPAL", iter_CLinSEPAL, 
                                        constructiveness(V_reconstructed_CLinSEPAL),
                                        obj_val_series_CLinSEPAL[iter_CLinSEPAL].item(), 
                                        stiefel_arc_length(V_reconstructed_CLinSEPAL,V).item(), 
                                        frobenious_abs_distance(V_reconstructed_CLinSEPAL,V).item(), 
                                        nnz_CLinSEPAL, tn_CLinSEPAL, fp_CLinSEPAL, fn_CLinSEPAL, tp_CLinSEPAL, fpr_CLinSEPAL, tpr_CLinSEPAL, fdr_CLinSEPAL, f1_CLinSEPAL]], 
                                        columns=columns)  

            ### only in case all methods succeed, concat the results
            # All accumulation happens here, after the three solvers have run, so an
            # exception raised by any of them leaves the accumulators untouched for
            # this trial. Each pd.concat appends one single-row frame.

            #LinSEPAL-ADMM
            V_matrices = pd.concat((V_matrices, V_LinSEPAL_ADMM_df))
            primal_res_series_df = pd.concat((primal_res_series_df, primal_res_series_LinSEPAL_ADMM_df))
            dual_res_series_df = pd.concat((dual_res_series_df, dual_res_series_LinSEPAL_ADMM_df))
            obj_val_df = pd.concat((obj_val_df, obj_val_LinSEPAL_ADMM_df))         
            
            metrics_df = pd.concat((metrics_df, LinSEPAL_ADMM_df))

            #LinSEPAL_PG
            # No residual series are appended for this method.
            V_matrices = pd.concat((V_matrices, V_LinSEPAL_PG_df))
            obj_val_df = pd.concat((obj_val_df, obj_val_LinSEPAL_PG_df))

            metrics_df = pd.concat((metrics_df, LinSEPAL_PG_df))

            #CLinSEPAL
            V_matrices = pd.concat((V_matrices, V_CLinSEPAL_df))
            S_matrices = pd.concat((S_matrices, S_CLinSEPAL_df))
            primal_res_series_df = pd.concat((primal_res_series_df, primal_res_seriesY1_CLinSEPAL_df))
            primal_res_series_df = pd.concat((primal_res_series_df, primal_res_seriesY2_CLinSEPAL_df))
            primal_res_series_df = pd.concat((primal_res_series_df, primal_res_seriesX_CLinSEPAL_df))
            # NOTE(review): dual residuals are appended Y2 first, then Y1, whereas the
            # primal residuals above go Y1, Y2, X. Rows stay identifiable through their
            # index labels, but any positional access downstream should be checked.
            dual_res_series_df = pd.concat((dual_res_series_df, dual_res_seriesY2_CLinSEPAL_df))
            dual_res_series_df = pd.concat((dual_res_series_df, dual_res_seriesY1_CLinSEPAL_df))
            dual_res_series_df = pd.concat((dual_res_series_df, dual_res_seriesX_CLinSEPAL_df))
            obj_val_df = pd.concat((obj_val_df, obj_val_CLinSEPAL_df))

            metrics_df = pd.concat((metrics_df, CLinSEPAL_df))

            # to_save is set to True immediately before it is tested, so the
            # checkpoint below runs after every successful trial.
            to_save=True

            # Checkpoint all accumulators. The file names depend on (l,h), pp_nodes,
            # generatingseed and ntrials but not on the current trial, so every
            # successful trial writes to the same names (presumably overwriting the
            # previous checkpoint -- confirm against save_obj_parquet).
            if to_save:
                save_obj_parquet(V_matrices, "partial_prior_{}_{}_{}_{}_V_matrices".format(str((l,h)), str(pp_nodes), generatingseed, ntrials), data_dir=data_dir)
                save_obj_parquet(S_matrices, "partial_prior_{}_{}_{}_{}_S_matrices".format(str((l,h)), str(pp_nodes), generatingseed, ntrials), data_dir=data_dir)
                save_obj_parquet(metrics_df, "partial_prior_{}_{}_{}_{}_metrics_df".format(str((l,h)), str(pp_nodes), generatingseed, ntrials), data_dir=data_dir)
                save_obj_parquet(obj_val_df, "partial_prior_{}_{}_{}_{}_obj_val_df".format(str((l,h)), str(pp_nodes), generatingseed, ntrials), data_dir=data_dir)
                save_obj_parquet(primal_res_series_df, "partial_prior_{}_{}_{}_{}_primal_res_series_df".format(str((l,h)), str(pp_nodes),  generatingseed, ntrials), data_dir=data_dir)
                save_obj_parquet(dual_res_series_df, "partial_prior_{}_{}_{}_{}_dual_res_series_df".format(str((l,h)), str(pp_nodes),  generatingseed, ntrials), data_dir=data_dir)

        except:
            failed_combos.append((generatingseed, trial))
            print("Failed at ({},{},{})".format(str((l,h)), generatingseed, trial))
            save_obj_parquet(failed_combos, "partial_prior_{}_{}_{}_{}_failed_combos".format(str((l,h)), str(pp_nodes),  generatingseed, ntrials), data_dir=data_dir)
            to_save=False
            continue

else:
    # Cached path: reload every artefact written by the training branch instead
    # of re-running the solvers. All file names share one prefix built from the
    # experiment configuration, so it is formatted once and reused.
    _pp_prefix = "partial_prior_{}_{}_{}_{}".format(str((l,h)), str(pp_nodes), generatingseed, ntrials)

    V_matrices = load_obj_parquet(_pp_prefix + "_V_matrices", data_dir=data_dir)
    S_matrices = load_obj_parquet(_pp_prefix + "_S_matrices", data_dir=data_dir)
    metrics_df = load_obj_parquet(_pp_prefix + "_metrics_df", data_dir=data_dir)
    obj_val_df = load_obj_parquet(_pp_prefix + "_obj_val_df", data_dir=data_dir)
    primal_res_series_df = load_obj_parquet(_pp_prefix + "_primal_res_series_df", data_dir=data_dir)
    dual_res_series_df = load_obj_parquet(_pp_prefix + "_dual_res_series_df", data_dir=data_dir)

    # The failed-combos file is only written when at least one trial failed, so
    # a load error is treated as "no failures". `except Exception` (rather than
    # a bare `except`) keeps KeyboardInterrupt / SystemExit propagating.
    try:
        failed_combos = load_obj_parquet(_pp_prefix + "_failed_combos", data_dir=data_dir)
    except Exception:
        failed_combos = []



Doing 1/10.

Residuals convergence at iteration 28: (objective, primal, dual)=(4.3601426744643845e-06,0.0007225052287403336,0.0)
Max number of iterations reached: objective 0.08908329139345561
Max number of iterations reached: (objective, primal Y1 St, primal Y2 St, primal  X Sp, dual Y1 St, dual Y1 St, dual X Sp)=(0.12061169620445966,0.0004508163723223487,0.0066757829810875444,0.010439957026288542,4.352002460993237e-05,0.018657933682304964,0.0)


  V_matrices = pd.concat((V_matrices, V_LinSEPAL_ADMM_df))
  metrics_df = pd.concat((metrics_df, LinSEPAL_ADMM_df))
  S_matrices = pd.concat((S_matrices, S_CLinSEPAL_df))



Doing 2/10.

Residuals convergence at iteration 503: (objective, primal, dual)=(1.0891592951978168e-06,0.0003900650059426947,0.0)
Max number of iterations reached: objective 0.40522322961731394
Max number of iterations reached: (objective, primal Y1 St, primal Y2 St, primal  X Sp, dual Y1 St, dual Y1 St, dual X Sp)=(0.2807877252252178,0.0001713680398060651,0.0004066201843235896,0.010949042122560988,0.0001272052029043697,0.0060889401795859634,0.0)

Doing 3/10.

Residuals convergence at iteration 704: (objective, primal, dual)=(2.237875128763278e-06,0.0005817163173440098,0.0)
Max number of iterations reached: objective 0.5547217827951716
Max number of iterations reached: (objective, primal Y1 St, primal Y2 St, primal  X Sp, dual Y1 St, dual Y1 St, dual X Sp)=(0.5748621531309732,0.00038571986093591004,0.013534545290358304,0.010940204210445439,0.0003292681689295755,0.004489180571700411,0.0)

Doing 4/10.

Residuals convergence at iteration 804: (objective, primal, dual)=(1.0589244707759349

  return f_raw(*args, **kwargs)



Doing 7/10.

Residuals convergence at iteration 57: (objective, primal, dual)=(2.790344186820448e-07,0.00017344058071683323,0.0)
Max number of iterations reached: objective 0.35784314626489255
Max number of iterations reached: (objective, primal Y1 St, primal Y2 St, primal  X Sp, dual Y1 St, dual Y1 St, dual X Sp)=(0.4987470375501113,0.00711195842989032,0.010072579827974742,0.010869906467472748,0.0054234844235554166,0.0055993007663576785,0.0)

Doing 8/10.

Residuals convergence at iteration 124: (objective, primal, dual)=(2.11776591818591e-06,0.0007657343057365502,0.0)
Max number of iterations reached: objective 0.0887622900606111
Max number of iterations reached: (objective, primal Y1 St, primal Y2 St, primal  X Sp, dual Y1 St, dual Y1 St, dual X Sp)=(0.4346476979097096,0.0017581007804134605,0.012854298182336568,0.01120464452570237,0.0011926749837925395,0.01205297708573134,0.0)

Doing 9/10.

Residuals convergence at iteration 617: (objective, primal, dual)=(0.0897174472145732,0.00075

In [20]:
# Display all results.
# The accumulated rows all carry index 0 (one single-row frame per method and
# trial), so renumber them 0..n-1. Rebinding instead of `inplace=True` leaves
# the previous object untouched and gives the same displayed table.
metrics_df = metrics_df.reset_index(drop=True)
metrics_df

Unnamed: 0,"(l,h)",idx pp nodes,n uncertain nodes for pp nodes,seed gen,seed algo,method,iterations,constructiveness,D_KL,stiefel distance,frobenious distance,nnz,tn,fp,fn,tp,fpr,tpr,fdr,f1
0,"(12, 6)","[6, 11, 0]",2,0,1,LinSEPAL-ADMM,28,1.0,4.360143e-06,0.0,0.000881,12,60,0,0,12,0.0,1.0,0.0,1.0
1,"(12, 6)","[6, 11, 0]",2,0,1,LinSEPAL_PG,1000,1.0,0.08908329,0.322911,0.132487,12,60,0,0,12,0.0,1.0,0.0,1.0
2,"(12, 6)","[6, 11, 0]",2,0,1,CLinSEPAL,1000,1.0,0.1206117,0.564912,0.232162,12,59,1,1,11,0.016667,0.916667,0.083333,0.916667
3,"(12, 6)","[6, 11, 0]",2,0,2,LinSEPAL-ADMM,503,1.0,1.089159e-06,0.0,0.000433,12,60,0,0,12,0.0,1.0,0.0,1.0
4,"(12, 6)","[6, 11, 0]",2,0,2,LinSEPAL_PG,1000,0.791667,0.4052232,1.66252,0.4718,18,54,6,0,12,0.1,1.0,0.333333,0.8
5,"(12, 6)","[6, 11, 0]",2,0,2,CLinSEPAL,1000,1.0,0.2807877,0.877641,0.351287,12,59,1,1,11,0.016667,0.916667,0.083333,0.916667
6,"(12, 6)","[6, 11, 0]",2,0,3,LinSEPAL-ADMM,704,1.0,2.237875e-06,0.0,0.000636,12,60,0,0,12,0.0,1.0,0.0,1.0
7,"(12, 6)","[6, 11, 0]",2,0,3,LinSEPAL_PG,1000,0.916667,0.5547218,1.277001,0.519032,14,58,2,0,12,0.033333,1.0,0.142857,0.923077
8,"(12, 6)","[6, 11, 0]",2,0,3,CLinSEPAL,1000,1.0,0.5748622,1.251356,0.514705,12,59,1,1,11,0.016667,0.916667,0.083333,0.916667
9,"(12, 6)","[6, 11, 0]",2,0,4,LinSEPAL-ADMM,804,1.0,1.058924e-06,0.0,0.000665,12,60,0,0,12,0.0,1.0,0.0,1.0


In [21]:
# Restrict the metrics table to rows whose constructiveness equals 1, then,
# within every ((l,h), generating seed, method) group, keep the single row
# with the smallest D_KL (the alignment metric).
is_constructive = metrics_df["constructiveness"] == 1
constructive_rows = metrics_df.loc[is_constructive]
best_row_labels = constructive_rows.groupby(["(l,h)", "seed gen", "method"])["D_KL"].idxmin()
metrics_constructiveCA = constructive_rows.loc[best_row_labels].copy()
metrics_constructiveCA

Unnamed: 0,"(l,h)",idx pp nodes,n uncertain nodes for pp nodes,seed gen,seed algo,method,iterations,constructiveness,D_KL,stiefel distance,frobenious distance,nnz,tn,fp,fn,tp,fpr,tpr,fdr,f1
17,"(12, 6)","[6, 11, 0]",2,0,6,CLinSEPAL,1000,1.0,0.0001851643,0.126576,0.004581,12,60,0,0,12,0.0,1.0,0.0,1.0
27,"(12, 6)","[6, 11, 0]",2,0,10,LinSEPAL-ADMM,255,1.0,1.494169e-07,0.0,0.000526,12,60,0,0,12,0.0,1.0,0.0,1.0
10,"(12, 6)","[6, 11, 0]",2,0,4,LinSEPAL_PG,807,1.0,0.000171218,0.0,0.012666,12,60,0,0,12,0.0,1.0,0.0,1.0


In [22]:
# Show the ground-truth matrix first, then the selected matrix of each method.
print("Ground truth")
display(V.round(3))

for _, best_run in metrics_constructiveCA.iterrows():
    print("{}".format(best_run['method']))
    # Row label under which this run's flattened matrix was stored in V_matrices.
    run_label = '{}_{}_{}_{}'.format(best_run['method'], best_run['(l,h)'], best_run['seed gen'], best_run['seed algo'])
    flat_values = V_matrices.loc[run_label].values
    # Stored column-major, so restore the (l, h) shape with the same order.
    display(flat_values.reshape((l,h), order='F').round(3))

Ground truth


array([[ 0.   ,  0.471,  0.   ,  0.   , -0.   ,  0.   ],
       [-0.32 ,  0.   ,  0.   , -0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   , -0.695, -0.   , -0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   , -0.753,  0.   ],
       [-0.   ,  0.   , -0.   ,  0.776,  0.   , -0.   ],
       [-0.   ,  0.   , -0.   ,  0.   , -0.   ,  0.302],
       [ 0.   ,  0.   ,  0.719,  0.   , -0.   , -0.   ],
       [ 0.   , -0.   ,  0.   ,  0.   , -0.   , -0.953],
       [-0.947, -0.   ,  0.   , -0.   ,  0.   , -0.   ],
       [-0.   , -0.   ,  0.   , -0.   , -0.658, -0.   ],
       [-0.   , -0.   ,  0.   , -0.631, -0.   , -0.   ],
       [ 0.   , -0.882,  0.   , -0.   ,  0.   , -0.   ]])

CLinSEPAL


array([[ 0.   ,  0.472,  0.   ,  0.   ,  0.   ,  0.   ],
       [-0.322,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   , -0.695,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   , -0.753,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.773,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ,  0.302],
       [ 0.   ,  0.   ,  0.709,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   , -0.953],
       [-0.947,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   , -0.658,  0.   ],
       [ 0.   ,  0.   ,  0.   , -0.634,  0.   ,  0.   ],
       [ 0.   , -0.885,  0.   ,  0.   ,  0.   ,  0.   ]])

LinSEPAL-ADMM


array([[ 0.   , -0.471,  0.   ,  0.   ,  0.   ,  0.   ],
       [ 0.32 ,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.696,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.753,  0.   ],
       [ 0.   ,  0.   ,  0.   , -0.776,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   , -0.301],
       [ 0.   ,  0.   , -0.718,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ,  0.953],
       [ 0.947,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.658,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.631,  0.   ,  0.   ],
       [ 0.   ,  0.882,  0.   ,  0.   ,  0.   ,  0.   ]])

LinSEPAL_PG


array([[ 0.   ,  0.469,  0.   ,  0.   ,  0.   ,  0.   ],
       [-0.326,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   , -0.695,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   , -0.752,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.781,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ,  0.274],
       [ 0.   ,  0.   ,  0.719,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   , -0.962],
       [-0.946,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ],
       [ 0.   ,  0.   ,  0.   ,  0.   , -0.659,  0.   ],
       [ 0.   ,  0.   ,  0.   , -0.625,  0.   ,  0.   ],
       [ 0.   , -0.883,  0.   ,  0.   ,  0.   ,  0.   ]])