#### Env: Minh

In [1]:
import pandas as pd
from utils.upgrade import *
from causallearn.utils.cit import CIT
import numpy as np
from pathlib import Path
import os
from tqdm import tqdm
from copy import deepcopy

# dataname = "munin1"
dataname = "erdos_renyi/d100_p0.02"
mi = 5      # The number of values a variable can take is ranged in [2, mi-1]
di = 1.0      # The dirichlet alpha that controls the data distribution
n = 10      # The number of data silos

folderpath = f"./data/distributed/{dataname}/m{mi}_d{di}_n{n}"

# Fail fast with a clear message: continuing with no silos would only crash
# later inside pd.concat with an unrelated "No objects to concatenate" error.
if not Path(folderpath).exists():
    raise FileNotFoundError(f"Folder {folderpath} not exist!")

groundtruth = np.loadtxt(f"./data/distributed/{dataname}/adj.txt")

silos = []
for file in sorted(os.listdir(folderpath)):
    # Skip anything that is not a silo CSV (checkpoint folders, OS metadata, ...).
    if not file.endswith(".csv"):
        continue
    filename = os.path.join(folderpath, file)
    silo_data = pd.read_csv(filename)
    silos.append(silo_data)
    print("Loaded file:", filename, end="\t")
    print(len(silo_data), " Instances", len(silo_data.columns), "Variables")

if not silos:
    raise FileNotFoundError(f"No CSV silo files found in {folderpath}")

# Stack all silos row-wise and order the columns by their numeric suffix
# (column names are parsed as one leading character followed by an integer).
merged_df = pd.concat(silos, axis=0)
merged_df = merged_df.reindex(sorted(merged_df.columns, key=lambda item: int(item[1:])), axis=1)
all_vars = list(merged_df.columns)

print("Num edges:", np.sum(groundtruth))

Loaded file: ./data/distributed/erdos_renyi/d100_p0.02/m5_d1.0_n10/silo-0.csv	2500  Instances 100 Variables
Loaded file: ./data/distributed/erdos_renyi/d100_p0.02/m5_d1.0_n10/silo-1.csv	2500  Instances 100 Variables
Loaded file: ./data/distributed/erdos_renyi/d100_p0.02/m5_d1.0_n10/silo-2.csv	2500  Instances 100 Variables
Loaded file: ./data/distributed/erdos_renyi/d100_p0.02/m5_d1.0_n10/silo-3.csv	2500  Instances 100 Variables
Loaded file: ./data/distributed/erdos_renyi/d100_p0.02/m5_d1.0_n10/silo-4.csv	2500  Instances 100 Variables
Loaded file: ./data/distributed/erdos_renyi/d100_p0.02/m5_d1.0_n10/silo-5.csv	2500  Instances 100 Variables
Loaded file: ./data/distributed/erdos_renyi/d100_p0.02/m5_d1.0_n10/silo-6.csv	2500  Instances 100 Variables
Loaded file: ./data/distributed/erdos_renyi/d100_p0.02/m5_d1.0_n10/silo-7.csv	2500  Instances 100 Variables
Loaded file: ./data/distributed/erdos_renyi/d100_p0.02/m5_d1.0_n10/silo-8.csv	2500  Instances 100 Variables
Loaded file: ./data/distribu

In [2]:
from causallearn.search.ConstraintBased.CDNOD import cdnod
from baselines.FL_FedCDH.mycausallearn.utils.data_utils import get_cpdag_from_cdnod, get_dag_from_pdag
from causallearn.utils.cit import fisherz

# Domain index passed to CD-NOD: one row per sample holding the 1-based id of
# the silo it came from. Built from the actual silo sizes so it always lines up
# with merged_df, even when silos differ in length or len(silos) != n.
c_indx = np.concatenate(
    [np.full(len(silo_rows), silo_id, dtype=float) for silo_id, silo_rows in enumerate(silos, start=1)]
).reshape(-1, 1)
cg = cdnod(merged_df.to_numpy(), c_indx, 0.05, fisherz)

# Keep only the block over the observed variables (drops the extra index node).
est_graph = cg.G.graph[0:len(all_vars), 0:len(all_vars)]
est_cpdag = get_cpdag_from_cdnod(est_graph) # est_graph[i,j]=-1 & est_graph[j,i]=1  ->  est_graph_cpdag[i,j]=1
est_dag_from_pdag = get_dag_from_pdag(est_cpdag) # return a DAG from a PDAG in causaldag
# Reuse the DAG computed above instead of running the same conversion twice.
adj_mtx = est_dag_from_pdag

  0%|          | 0/101 [00:00<?, ?it/s]

In [None]:
def true_markov_blanket(adj_matrix, var_idx):
    """Split the Markov blanket of a node into its structural roles.

    Arguments:
      adj_matrix: square adjacency matrix where a non-zero ``adj_matrix[i, j]``
        means an edge ``i -> j``.
      var_idx: index of the node whose blanket is requested.

    Returns:
      A 5-tuple ``(parents, pa_sp, spouses, ch_sp, children)`` of index lists:
        parents  -- all parents of the node,
        pa_sp    -- parents that are also co-parents of one of its children,
        spouses  -- co-parents that are neither a parent nor a child,
        ch_sp    -- children that are also co-parents of another child,
        children -- all children of the node.
      The node itself is never reported as its own spouse.
    """
    parents = np.where(adj_matrix[:, var_idx])[0].tolist()
    children = np.where(adj_matrix[var_idx])[0].tolist()

    # Every parent of every child is a co-parent ("spouse") candidate.
    co_parents = set()
    for child in children:
        co_parents.update(np.where(adj_matrix[:, child])[0].tolist())
    # The node is trivially a parent of its own children; it is not a spouse.
    co_parents.discard(var_idx)

    parent_set = set(parents)
    child_set = set(children)

    # The previous expressions (`set(parents)&spouses - set(parents)`) were
    # always empty, so the overlap groups were never populated.
    pa_sp = sorted(parent_set & co_parents)
    ch_sp = sorted(child_set & co_parents)
    spouses = sorted(co_parents - parent_set - child_set)

    return parents, pa_sp, spouses, ch_sp, children


def to_list(all_vars, mb_idx_list):
    """Map a list of positions to the corresponding entries of ``all_vars``."""
    names = []
    for position in mb_idx_list:
        names.append(all_vars[position])
    return names

In [None]:
# Pairwise (unconditional) dependence map: two variables are "connected" when
# the CIT "chisq" test with an empty conditioning set gives p <= confidence.
confidence = 0.05
connectivity = {var: [] for var in all_vars}
chisq_obj = CIT(merged_df, "chisq")

for X in connectivity.keys():
    x_pos = all_vars.index(X)
    # Candidates are fixed before the inner loop; pairs already linked from the
    # other side are skipped.
    other_vars = list(set(all_vars) - set(connectivity[X]) - {X})
    for Y in other_vars:
        pval = chisq_obj(x_pos, all_vars.index(Y), [])
        if pval <= confidence:
            # Record the link symmetrically.
            connectivity[X] = list(set(connectivity[X]) | {Y})
            connectivity[Y] = list(set(connectivity[Y]) | {X})

In [None]:
# connectivity['X2']

In [None]:
# Greedy selection of a set of mutually unconnected variables: repeatedly take
# the variable at the front of the ordering, then drop everything connected to
# it and re-rank what is left.
basis = []
ordering = sorted(all_vars, key=lambda item: len(connectivity[item]))

while ordering:
    x = ordering.pop(0)
    basis.append(x)
    discard_vars = connectivity[x]
    # Rank the survivors by how many of their connections lie outside the
    # variables discarded in this round (fewest first).
    ordering = sorted(
        set(ordering) - set(discard_vars),
        key=lambda item: len(set(connectivity[item]) - set(discard_vars)),
    )

# basis

In [None]:
from copy import deepcopy

def GSMB(indexes, confidence=0.01):
    """Estimate a Markov blanket for every variable with a grow/shrink search.

    Arguments:
      indexes: row positions of ``merged_df`` to run the search on.
      confidence: p-value threshold; ``p <= confidence`` counts as dependent.

    Returns:
      dict mapping each variable name in ``all_vars`` to the list of variable
      names retained in its blanket.
    """
    data = merged_df.iloc[indexes].reset_index().drop(columns=['index'])
    chisq_obj = CIT(data, "chisq")  # conditional-independence tester on the selected rows
    all_var_idx = list(range(len(data.columns)))
    markov_blankets_idx = {i: [] for i in all_var_idx}

    for X in all_var_idx:
        S = []
        prev_length = 0
        # At most three grow/shrink cycles; stop early once |S| stops changing.
        for _ in range(3):
            # Grow: add every variable still dependent on X given the current S.
            for Y in list(set(all_var_idx) - set(S) - {X}):
                pval = chisq_obj(X, Y, S)
                if pval <= confidence:
                    S.append(Y)

            # Shrink: drop members that are independent of X given the rest of S.
            for Y in list(S):
                pval = chisq_obj(X, Y, list(set(S) - {Y}))
                if pval > confidence:
                    S.remove(Y)

            if len(S) == prev_length:
                break
            prev_length = len(S)
        markov_blankets_idx[X] = list(set(markov_blankets_idx[X])|set(S))

    # Translate column positions back into variable names.
    markov_blankets = {var: [] for var in all_vars}
    for idx, mb_idxes in markov_blankets_idx.items():
        markov_blankets[all_vars[idx]] = [all_vars[i] for i in mb_idxes]

    return markov_blankets

In [None]:
# Switch between the ground-truth Markov blankets (1) and the GSMB estimate (0).
TMB_activated = 1
markov_blankets = {var: [] for var in all_vars}

if not TMB_activated:
    markov_blankets = GSMB(list(range(len(merged_df))))
else:
    for var in markov_blankets.keys():
        # The numeric suffix of the variable name (minus one) is used as its
        # row/column position in the ground-truth adjacency matrix.
        pa, pa_sp, sp, ch_sp, ch = true_markov_blanket(groundtruth, int(var[1:]) - 1)
        blanket_names = to_list(all_vars, pa + pa_sp + sp + ch_sp + ch)
        markov_blankets[var] = list(set(blanket_names) - {var})

In [None]:
def generate_uniform_distributions(P0: np.ndarray, num_gen=100, gamma2=0.8):
    """Generate ``num_gen`` distributions spread around ``P0``.

    For each coordinate ``i`` a boundary point is built between ``P0`` and the
    i-th one-hot vector; random Dirichlet-weighted mixtures of those boundary
    points are then clustered with KMeans and the cluster centres are returned.

    Arguments:
      P0: 1-D array (presumably a probability vector -- confirm with callers).
      num_gen: number of distributions to return (KMeans cluster count).
      gamma2: threshold used to decide how far each boundary moves from ``P0``.

    Returns:
      Array of KMeans cluster centres, one row per generated distribution.
    """
    n_states = P0.shape[0]
    vertices = np.eye(n_states)

    def boundary_towards(i):
        # Move from P0 towards the i-th one-hot vector unless P0[i] already
        # reaches gamma2, in which case the vertex itself is the boundary.
        if P0[i]/gamma2 < 1:
            alpha_i = 1/(1 - P0[i]) * (1 - P0[i]/(gamma2 + 0.001))
            return alpha_i * P0 + (1 - alpha_i) * vertices[i]
        return vertices[i]

    boundaries = np.stack([boundary_towards(i) for i in range(n_states)])

    # Mixture weights: num_gen Dirichlet draws for each concentration 0.5, 1.0, ..., 4.5.
    weight_draws = [np.random.dirichlet([alpha/2] * n_states, size=num_gen) for alpha in range(1, 10)]
    w = np.concatenate(weight_draws)

    kmeans = KMeans(n_clusters=num_gen, n_init="auto")
    kmeans.fit(w @ boundaries)
    return kmeans.cluster_centers_

def multivariate_sampling(data: pd.DataFrame, variables: list, sample_dis: dict, instance_index):
    """Run ``univariate_sampling`` for each variable and return the sampled index.

    Arguments:
      data: frame handed unchanged to ``univariate_sampling``.
      variables: variable names to sample on, processed in order.
      sample_dis: maps a variable name to a sequence of target distributions.
      instance_index: which distribution of each variable to use.

    Returns:
      The second element returned by ``univariate_sampling`` for the last variable.
    """
    # NOTE(review): every call receives the full `data`, so only the result for
    # the last variable survives -- confirm whether chaining was intended.
    for sampling_var in deepcopy(variables):
        distribution = sample_dis[sampling_var][instance_index]
        target = {i: distribution[i] for i in range(distribution.shape[0])}
        _, all_index = univariate_sampling(data, sampling_var, target)
    return all_index

In [None]:
def unnested(input: list):
    """Flatten an arbitrarily nested list.

    A single-element list is returned as is (or unwrapped recursively when that
    element is itself a list). Longer lists are flattened in place and returned
    as a de-duplicated list in unspecified order.
    """
    if len(input) == 1:
        only = input[0]
        return unnested(only) if isinstance(only, list) else input

    while True:
        # Position of the first element that is still a list, if any.
        pos = next((k for k, item in enumerate(input) if isinstance(item, list)), None)
        if pos is None:
            break
        # Replace the nested list by its members, appended at the end.
        input.extend(input.pop(pos))
    return list(set(input))

# test_input = ['X66']
# unnested(test_input)

# test_input = ['X66']
# unnested(test_input)

In [None]:
buffers = {}   # memo: tuple(sorted neighbor subset) -> result of recursive_conn on it
visited = []   # sorted variable tuples already emitted
def recursive_conn(neighbors):
    """Recursively expand ``neighbors`` into nested groups.

    For each variable ``i`` in ``neighbors`` the function recurses on the
    neighbors that are also in ``markov_blankets[i]`` and emits
    ``[i] + <result>``, skipping results whose flattened variable set was
    already emitted. Relies on the module-level ``buffers`` and ``visited``.
    """
    if len(neighbors) <= 1:
        return [neighbors]

    output = []
    for i in neighbors:
        key = sorted(list(set(neighbors)&set(markov_blankets[i])))
        memo_key = tuple(key)
        # Compute the sub-result once per distinct subset.
        if memo_key not in buffers:
            buffers[memo_key] = recursive_conn(key)
        res_i = [i] + buffers[memo_key]

        # Only keep a result the first time its flattened variable set shows up.
        visit_key = tuple(sorted(unnested(deepcopy(res_i))))
        if visit_key not in visited:
            output.append(res_i)
            visited.append(visit_key)
    return output


def unfold(input):
    """Distribute the leading non-list elements over each trailing list.

    Arguments:
      input: [var, var, ..., [var, ...], [var, ...]]

    that has a number of non-list element and a number of list element

    Returns:
      One list per trailing element, each made of the leading non-list
      elements followed by the members of that element. Returns ``[]`` when
      ``input`` contains no list element (this used to raise IndexError).
    """
    # Index of the first list element; len(input) when there is none. The scan
    # starts at 0 so an input that begins with a list has an empty prefix.
    cut_index = next(
        (pos for pos, item in enumerate(input) if isinstance(item, list)),
        len(input),
    )
    prefix = input[:cut_index]
    return [[*prefix, *group] for group in input[cut_index:]]

In [None]:
# Expand the Markov blanket of every variable into nested groups.
recursive_outputs = {}

for anchor_var in all_vars:
    # The memo and the visited list are reset for each anchor variable.
    visited.clear()
    buffers.clear()
    blanket_copy = deepcopy(markov_blankets[anchor_var])
    recursive_outputs[anchor_var] = recursive_conn(blanket_copy)

In [None]:
def removes_irrelevant(df, var, plausible_set, confidence=0.01):
    """Filter ``plausible_set`` down to the variables still dependent on ``var``.

    Runs a shrink/grow search with the CIT "chisq" test on the columns
    ``[var, *plausible_set]`` of ``df``.

    Arguments:
      df: data frame containing ``var`` and every member of ``plausible_set``.
      var: target variable name.
      plausible_set: iterable of candidate variable names.
      confidence: p-value threshold; ``p <= confidence`` counts as dependent.

    Returns:
      List of retained variable names, in the order they were accepted.
    """
    subdata = df[[var, *plausible_set]]
    all_var = list(subdata.columns)
    all_var_idx = list(range(len(all_var)))
    chisq_obj = CIT(subdata, 'chisq')

    X = all_var.index(var)
    S = []
    prev_length = 0
    # At most eleven shrink/grow cycles; stop early once |S| stops changing.
    for _ in range(11):
        # Shrink: drop members independent of X given the rest of S.
        for Y in list(S):
            pval = chisq_obj(X, Y, list(set(S) - {Y}))
            if pval > confidence:
                S.remove(Y)

        # Grow: add candidates dependent on X given the current S.
        for Y in list(set(all_var_idx) - set(S) - {X}):
            pval = chisq_obj(X, Y, S)
            if pval <= confidence:
                S.append(Y)

        if len(S) == prev_length:
            break
        prev_length = len(S)

    return [all_var[i] for i in S]

In [None]:
# Flatten the nested groups of every anchor variable into distinct candidate
# sets, then prune each candidate set with removes_irrelevant.
potential_parents = {}
for anchor_var in markov_blankets.keys():
    recursive_output = recursive_outputs[anchor_var]
    final_output = set()
    for nested_case in recursive_output:
        test_case = deepcopy(nested_case)
        unique_elements = set()
        if len(test_case) <= 1:
            unique_elements.add(tuple(test_case))
        else:
            first_element = test_case.pop(0)
            # Work queue: groups that still end in nested lists are expanded
            # and pushed back; flat groups are recorded together with the head.
            while test_case:
                examine_group = test_case.pop(0)
                needs_unfolding = (
                    len(examine_group)
                    and not isinstance(examine_group[0], list)
                    and isinstance(examine_group[-1], list)
                )
                if needs_unfolding:
                    test_case.extend(unfold(examine_group))
                else:
                    unique_elements.add(tuple(sorted(examine_group + [first_element])))

        final_output = final_output|unique_elements
    potential_parents[anchor_var] = [removes_irrelevant(merged_df, anchor_var, j, 0.05) for j in final_output]

In [None]:
# Reference set per variable: the first two groups (pa, pa_sp) returned by
# true_markov_blanket on the ground-truth graph.
groundtruth_dict = {}
for var_id, var in enumerate(all_vars):
    pa, pa_sp, sp, ch_sp, ch = true_markov_blanket(groundtruth, var_id)
    groundtruth_dict[var] = to_list(all_vars, pa + pa_sp)

In [None]:
# Sanity check: report every variable for which no candidate set contains all
# of its reference variables.
for var in all_vars:
    required = set(groundtruth_dict[var])
    checked = [required & set(po) == required for po in potential_parents[var]]

    if True not in checked:
        print(var, "False", groundtruth_dict[var])

In [None]:
def compute_variance_viaindexesv2(indexes: list, variable: str, parents: list):
    """Collect per-environment statistics of ``variable`` given ``parents``.

    Arguments:
      indexes: one collection of ``merged_df`` row positions per environment.
      variable: name of the child variable.
      parents: names of the conditioning variables.

    Returns:
      ``(var_avg, mean_mll, conditional_probs_record)`` where ``var_avg`` is
      the mean row-wise variance over the merged per-environment columns,
      ``mean_mll`` the mean of the values returned by ``compute_mll`` and
      ``conditional_probs_record`` the frame holding all merged outputs.
    """
    group_cols = parents + [variable]
    # One row per (parents, variable) value combination observed in merged_df.
    conditional_probs_record = merged_df[group_cols].groupby(group_cols).count().reset_index()

    mll_list = []
    for env, index in enumerate(indexes):
        vertical_sampled_data = merged_df.iloc[index].reset_index().drop(columns=['index'])
        # Helper column of ones so that a group-wise sum yields the counts.
        vertical_sampled_data.insert(0, 'count', [1] * len(vertical_sampled_data))

        summary_with_ch = vertical_sampled_data.groupby(group_cols)['count'].sum().reset_index()
        mll, output = compute_mll(summary_with_ch, parents, env)
        # Left merge keeps combinations that this environment never produced (as NaN).
        conditional_probs_record = conditional_probs_record.merge(output, on=group_cols, how='left')
        mll_list.append(mll)

    mean_mll = np.mean(mll_list)
    per_env_columns = conditional_probs_record.iloc[:, len(parents) + 1:]
    var_avg = per_env_columns.var(axis=1, skipna=True).mean()
    return var_avg, mean_mll, conditional_probs_record


def compute_weighted_variance_viaindexesv2(indexes: list, variable: str, parents: list):
    """Weighted spread of the ``probs_*`` columns across environments.

    With at least one parent, the squared deviation of each ``probs_i`` entry
    from its row mean is weighted by the matching ``joint_i`` entry, averaged
    over all non-NaN cells and square-rooted. Without parents, the plain
    variance from ``compute_variance_viaindexesv2`` is returned instead.

    Returns:
      ``(score, parents)``.
    """
    variance, _, df = compute_variance_viaindexesv2(indexes, variable, parents)
    if not len(parents):
        return variance, parents

    env_ids = range(len(indexes))
    joint_mat = np.array([df[f'joint_{i}'] for i in env_ids]).T
    probs_mat = np.array([df[f'probs_{i}'] for i in env_ids]).T

    def _row_mean(row):
        # Mean over the environments that actually observed this row; 0 if none did.
        observed = row[~np.isnan(row)]
        return np.mean(observed).item() if len(observed) else 0

    probs_mean = np.expand_dims(np.array([_row_mean(row) for row in probs_mat]), 1)
    # joint_mat = joint_mat.shape[1] * joint_mat/joint_mat.sum(axis=1, keepdims=True)
    prod = joint_mat * (probs_mat - probs_mean)**2
    return np.power(np.mean(prod[~np.isnan(prod)]), 0.5), parents

#### Version 2 -- Given the leaves

In [None]:
from multiprocessing import Pool
from typing import List, Tuple


def individual_causal_search(var, silos_index):
    """Score every Markov-blanket member of ``var`` as a single-parent candidate.

    Returns:
      ``{var: {(member,): score}}`` with the score taken from
      ``compute_weighted_variance_viaindexesv2``.
    """
    record = {
        (mb_var,): compute_weighted_variance_viaindexesv2(silos_index, var, [mb_var])[0]
        for mb_var in markov_blankets[var]
    }
    return {var: record}


# Run individual_causal_search over all argument tuples with a process pool
def execute_in_parallel(args_list: List[Tuple]):
    """Apply ``individual_causal_search`` to each tuple in ``args_list`` in parallel.

    Returns the list of results in the same order as ``args_list``.
    """
    with Pool() as workers:
        # starmap unpacks each tuple into the positional arguments of the worker.
        return workers.starmap(individual_causal_search, args_list)

In [None]:
# Variables used to build the sampling distributions and the resampled row
# indexes in the next cell (presumably the known leaf nodes, per the section
# title -- TODO confirm).
leaves = ['X8']

In [None]:
num_env = 10   # number of synthetic environments
gamma2 = 0.5

# One family of num_env target distributions per leaf variable.
sample_dis = {
    x: generate_uniform_distributions(
        P0=marginal_prob(merged_df, [x]),
        num_gen=num_env,
        gamma2=np.power(gamma2, 1./len(leaves)),
    )
    for x in leaves
}
# Row indexes of merged_df selected for each environment.
silos_index = [multivariate_sampling(merged_df, leaves, sample_dis, i) for i in range(num_env)]

inputs = [(var, silos_index) for var in markov_blankets.keys()]
outputs = execute_in_parallel(inputs)

# Merge the single-key dictionaries returned by the workers.
results = {}
for out_dict in outputs:
    results.update(out_dict)

In [None]:
# weighted_mtx[p][c] holds the score of the selected edge p -> c; 1 means "no edge".
weighted_mtx = np.ones([len(all_vars), len(all_vars)])

for var, candidates in results.items():
    if not len(candidates.items()):
        continue
    var_id = all_vars.index(var)
    # Candidate parent set with the smallest score.
    best_comb, best_variance = min(candidates.items(), key=lambda item: item[1])
    for parent in best_comb:
        pa_id = all_vars.index(parent)
        # Keep parent -> var only if it beats whatever is stored for var -> parent,
        # and clear the opposite direction.
        if best_variance < weighted_mtx[var_id][pa_id]:
            weighted_mtx[pa_id][var_id] = best_variance
            weighted_mtx[var_id][pa_id] = 1

In [None]:
# Entries still at the placeholder value 1 carry no edge; zero them in place,
# binarise the remaining scores and transpose the result.
weighted_mtx[weighted_mtx == 1] = 0
adj_mtx = ((weighted_mtx > 0) * 1).T
# adj_mtx

In [None]:
# Import from the `utils` package, consistent with the other cells of this
# notebook (`utils.upgrade`, `utils.plot_utils`).
from utils.plot_utils import true_edge, spur_edge, fals_edge, miss_edge, swap_pos

# Compare the estimated adjacency matrix against the ground truth.
etrue = true_edge(groundtruth, adj_mtx)
espur = spur_edge(groundtruth, adj_mtx)
efals = fals_edge(groundtruth, adj_mtx)
emiss = miss_edge(groundtruth, adj_mtx)

# Printed order: true, spurious, missing, false.
print(len(etrue), len(espur), len(emiss), len(efals))

In [None]:
# Columns of adj_mtx that sum to zero, minus the variables already in `basis`.
# dtype=int keeps the index array usable when no such column exists (an empty
# list would otherwise become a float array and raise IndexError when indexing).
sources_idx = np.array([i for i in range(len(all_vars)) if np.sum(adj_mtx[:, i]) == 0], dtype=int)
sources = np.array(all_vars)[sources_idx].tolist()
sources = list(set(sources) - set(basis))
sources

#### Version 1 -- Given the sources

In [None]:
# sources_idx = [i for i in range(groundtruth.shape[0]) if np.sum(groundtruth[:,i]) == 0]
# sources = np.array(all_vars)[sources_idx].tolist()
# sources

In [None]:
from multiprocessing import Pool
from typing import List, Tuple


def individual_causal_searchv2(var, silos_index):
    """Score every candidate parent set of ``var``.

    Each group in ``potential_parents[var]`` is restricted to the variables
    connected to ``var``, pruned with ``removes_irrelevant`` and, if anything
    is left, scored with ``compute_weighted_variance_viaindexesv2``.

    Returns:
      ``{var: {tuple(cleaned_group): score}}``.
    """
    scores = {}
    print(var, len(potential_parents[var]))
    for group in potential_parents[var]:
        # Keep only group members that are pairwise connected to var.
        conn_group = list(set(connectivity[var])&set(group))
        cleaned_group = removes_irrelevant(merged_df, var, conn_group)
        if not len(cleaned_group):
            continue
        variance, _ = compute_weighted_variance_viaindexesv2(silos_index, var, cleaned_group)
        scores[tuple(cleaned_group)] = variance
    return {var: scores}


# Run individual_causal_searchv2 over all argument tuples with a process pool
def execute_in_parallel(args_list: List[Tuple]):
    """Apply ``individual_causal_searchv2`` to each tuple in ``args_list`` in parallel.

    Returns the list of results in the same order as ``args_list``.
    """
    with Pool() as workers:
        # starmap unpacks each tuple into the positional arguments of the worker.
        return workers.starmap(individual_causal_searchv2, args_list)

In [None]:
num_env = 10   # number of synthetic environments
gamma2 = 0.5

# One family of num_env target distributions per basis variable.
sample_dis = {
    x: generate_uniform_distributions(
        P0=marginal_prob(merged_df, [x]),
        num_gen=num_env,
        gamma2=np.power(gamma2, 1./len(basis)),
    )
    for x in basis
}
# Row indexes of merged_df selected for each environment.
silos_index = [
    multivariate_sampling(merged_df, basis, sample_dis, i)
    for i in range(num_env)
]

In [None]:
# Score all variables in parallel, then merge the per-variable dictionaries.
inputs = [(var, silos_index) for var in markov_blankets.keys()]
outputs = execute_in_parallel(inputs)

results = {}
for out_dict in outputs:
    results.update(out_dict)

In [None]:
# weighted_mtx[p][c] holds the score of the selected edge p -> c; 1 means "no edge".
weighted_mtx = np.ones([len(all_vars), len(all_vars)])

for var, candidates in results.items():
    var_id = all_vars.index(var)
    if not len(candidates.items()):
        continue
    # Candidate parent set with the smallest score.
    best_comb, best_variance = min(candidates.items(), key=lambda item: item[1])

    for parent in best_comb:
        pa_id = all_vars.index(parent)
        # Keep parent -> var only if it beats whatever is stored for var -> parent,
        # and clear the opposite direction.
        if best_variance < weighted_mtx[var_id][pa_id]:
            weighted_mtx[pa_id][var_id] = best_variance
            weighted_mtx[var_id][pa_id] = 1

In [None]:
# Drop (in place) every score above the cap, then binarise what remains.
hardcap_invariance = 1e-3
above_cap = weighted_mtx > hardcap_invariance
weighted_mtx[above_cap] = 0
adj_mtx = (weighted_mtx > 0) * 1
# adj_mtx

In [3]:
from utils.plot_utils import true_edge, spur_edge, fals_edge, miss_edge, swap_pos

# Compare the current adj_mtx against the ground truth, one helper per category.
etrue, espur, efals, emiss = (
    edge_fn(groundtruth, adj_mtx)
    for edge_fn in (true_edge, spur_edge, fals_edge, miss_edge)
)

# Printed order: true, spurious, missing, false.
print(len(etrue), len(espur), len(emiss), len(efals))

63 139 70 61


#### Left-out code

In [None]:
# Print the ground-truth Markov blanket groups of one variable.
inv_var = 'X2'
pa, pa_sp, sp, ch_sp, ch = true_markov_blanket(groundtruth, int(inv_var[1:]) - 1)
for label, members in (("Pa:", pa), ("Pa-Sp:", pa_sp), ("Sp:", sp), ("Ch-Sp:", ch_sp), ("Ch:", ch)):
    print(label, to_list(all_vars, members))

In [None]:
# Show, per variable, the reference set next to the best-scoring candidate set.
for var in results.keys(): #type:ignore
    # Variables without any scored candidate are skipped, as in the matrix
    # construction cells; min() on an empty dict would raise ValueError.
    if not len(results[var].items()):
        continue
    best_comb, best_variance = min(results[var].items(), key=lambda item: item[1])
    print(var, "\t", groundtruth_dict[var], "\t", best_comb, "\t", best_variance)

#### Plot -- Graphs

In [None]:
import matplotlib.pyplot as plt
import networkx as nx

G = nx.DiGraph()

fin_adjmtx = adj_mtx

# Insert edges row by row, then the row's own node: the resulting node order
# is the order shell_layout receives the nodes in.
for i in range(fin_adjmtx.shape[0]):
    source_name = f"X{i+1}"
    for j in range(fin_adjmtx.shape[1]):
        entry = fin_adjmtx[i][j]
        if entry > 0:
            G.add_edge(source_name, f"X{j+1}", weight=np.round(1/entry,2))
    G.add_node(source_name)


etrue = true_edge(groundtruth, fin_adjmtx)
espur = spur_edge(groundtruth, fin_adjmtx)
efals = fals_edge(groundtruth, fin_adjmtx)
emiss = miss_edge(groundtruth, fin_adjmtx)

# print(etrue)
print(len(etrue), len(espur), len(emiss), len(efals))

pos = nx.shell_layout(G)
pos = swap_pos(pos, 'X4', 'X3')
# pos = swap_pos(pos, 'X3', 'X5')

# nodes
nx.draw_networkx_nodes(G, pos, node_size=400, node_color="#1f78b4")

# edges, drawn one category at a time (later categories are painted on top)
edge_layers = (
    (espur, "orange", "Spurious Edges"),
    (emiss, "purple", "Missing Edges"),
    (efals, "red", "Anti-Causal Edges"),
    (etrue, "green", "Causal Edges"),
)
for layer_edges, layer_color, layer_label in edge_layers:
    nx.draw_networkx_edges(G, pos, edgelist=layer_edges, width=2, arrowstyle='->', arrowsize=20,
                           edge_color=layer_color, label=layer_label)

# node labels
nx.draw_networkx_labels(G, pos, font_size=12, font_family="sans-serif", font_color='white')

# edge weight labels
# edge_labels = nx.get_edge_attributes(G, "weight")
# nx.draw_networkx_edge_labels(G, pos, edge_labels)

ax = plt.gca()
ax.margins(0.08)
plt.axis("off")
plt.tight_layout()
# plt.box()
plt.title(dataname.upper())
# plt.legend()
plt.show()

# plt.savefig("res/asia.plot.svg", format="svg")