In [30]:
import pandas as pd
from upgrade import *
import numpy as np
from pathlib import Path
import os
from tqdm import tqdm

dataname = "erdos_renyi/d40_p0.1"
mi = 5      # The number of values a variable can take is ranged in [2, mi-1]
di = 1.0      # The dirichlet alpha that controls the data distribution
n = 10      # The number of data silos

silos = []

# folderpath = f"./data/distributed/erdos_renyi/d20_p0.2/m3_d1_n10"
# groundtruth = np.loadtxt(f"./data/distributed/erdos_renyi/d20_p0.2/adj.txt")

folderpath = f"./data/distributed/{dataname}/m{mi}_d{di}_n{n}"
groundtruth = np.loadtxt(f"./data/distributed/{dataname}/adj.txt")

if not Path(folderpath).exists():
    print("Folder", folderpath, "not exist!")
else:
    for file in sorted(os.listdir(folderpath)):
        filename = os.path.join(folderpath, file)
        silo_data = pd.read_csv(filename)
        silos.append(silo_data)
        print("Loaded file:", filename, end="\t")
        all_vars = silos[0].columns
        print(len(silo_data), " Instances\t", len(all_vars), "Variables")

all_vars = list(all_vars)
merged_df = pd.concat(silos, axis=0)
# merged_df['count'] = [1] * len(merged_df)

Loaded file: ./data/distributed/erdos_renyi/d40_p0.1/m5_d1.0_n10/silo-0.csv	2500  Instances	 40 Variables
Loaded file: ./data/distributed/erdos_renyi/d40_p0.1/m5_d1.0_n10/silo-1.csv	2500  Instances	 40 Variables
Loaded file: ./data/distributed/erdos_renyi/d40_p0.1/m5_d1.0_n10/silo-2.csv	2500  Instances	 40 Variables
Loaded file: ./data/distributed/erdos_renyi/d40_p0.1/m5_d1.0_n10/silo-3.csv	2500  Instances	 40 Variables
Loaded file: ./data/distributed/erdos_renyi/d40_p0.1/m5_d1.0_n10/silo-4.csv	2500  Instances	 40 Variables
Loaded file: ./data/distributed/erdos_renyi/d40_p0.1/m5_d1.0_n10/silo-5.csv	2500  Instances	 40 Variables
Loaded file: ./data/distributed/erdos_renyi/d40_p0.1/m5_d1.0_n10/silo-6.csv	2500  Instances	 40 Variables
Loaded file: ./data/distributed/erdos_renyi/d40_p0.1/m5_d1.0_n10/silo-7.csv	2500  Instances	 40 Variables
Loaded file: ./data/distributed/erdos_renyi/d40_p0.1/m5_d1.0_n10/silo-8.csv	2500  Instances	 40 Variables
Loaded file: ./data/distributed/erdos_renyi/d4

In [31]:
confidence = 0.01
connectivity = {var: [] for var in all_vars}
chisq_obj = CIT(merged_df, "chisq")

for X in connectivity.keys():
    other_vars = list(set(all_vars) - set(connectivity[X]) - set([X]))
    for Y in other_vars:
        pval = chisq_obj(all_vars.index(X), all_vars.index(Y), []) # type: ignore
        if pval <= confidence: # type: ignore
            connectivity[X] = list(set(connectivity[X]) | set([Y]))
            connectivity[Y] = list(set(connectivity[Y]) | set([X]))

# for X, club_X in connectivity.items():
#     print(X, club_X, len(club_X))

In [32]:
basis = []
ordering = sorted(list(connectivity.keys()), key=lambda item: len(connectivity[item]), reverse=False)
while len(ordering):
    var = ordering.pop(0)
    ordering = list(set(ordering) - set(connectivity[var]))
    basis.append(var)

print(basis)

['X39', 'X30', 'X18', 'X34', 'X21', 'X5', 'X27']


In [33]:
def generate_uniform_distributions(P0: np.ndarray, num_gen=100, gamma2=0.8):
    Ulist = list(np.eye(P0.shape[0]))
    # Compute the boundary points
    boundaries = []
    for i in range(len(Ulist)):
        if P0[i]/gamma2 < 1:
            alpha_i = 1/(1 - P0[i]) * (1 - P0[i]/(gamma2 + 0.001))
            boundary_i = alpha_i * P0 + (1 - alpha_i) * Ulist[i]
        else:
            boundary_i = Ulist[i]
        boundaries.append(boundary_i)
    
    boundaries = np.stack(boundaries)
    w = np.random.uniform(0, 1, (num_gen * 30, len(Ulist)))
    w = w/w.sum(axis=1, keepdims=True)
    
    kmeans = KMeans(n_clusters=num_gen, n_init="auto")
    kmeans.fit(w @ boundaries)
    res = kmeans.cluster_centers_
    
    return res
    # return None

In [34]:
markov_blankets = {var: [] for var in all_vars}
confidence = 0.01

num_env = 50
gamma2 = 0.4

sample_dis = {var: generate_uniform_distributions(P0=marginal_prob(merged_df, [var]),
                                                num_gen=num_env, 
                                                gamma2=np.power(gamma2, 1./len(basis))) for var in basis}

In [35]:
def multivariate_sampling(data: pd.DataFrame, variables: list, sample_dis: dict, instance_index):
    remains = deepcopy(variables)
    while len(remains):
        sampling_var = remains.pop(0)
        distribution = sample_dis[sampling_var][instance_index]
        _, all_index = univariate_sampling(data, sampling_var, {i: distribution[i] for i in range(distribution.shape[0])})
    return all_index

In [36]:
def GSMB(indexes, confidence=0.01):
    data = merged_df.iloc[indexes].reset_index().drop(columns=['index'])
    chisq_obj = CIT(data, "chisq") # construct a CIT instance with data and method name
    all_var_idx = [i for i in range(len(data.columns))]
    markov_blankets_idx = {i: [] for i in range(len(data.columns))}

    for X in all_var_idx:
        S = markov_blankets_idx[X]
        # X = 6
        prev_length = 0
        count = 0
        while True:
            count += 1
            # print("==============New cycle==================")
            for Y in list(set(all_var_idx) - set(S) - set([X])):
                if Y != X:
                    pval = chisq_obj(X, Y, S) # type:ignore
                    if pval <= confidence: # type:ignore
                        S.append(Y)
            
            for Y in deepcopy(S):
                pval = chisq_obj(X, Y, list(set(S) - set([Y]))) # type:ignore
                if pval > confidence: # type:ignore
                    S.remove(Y)
            
            if (len(S) - prev_length == 0) or (count > 2):
                break
            else:
                prev_length = len(S)
        
        # markov_blankets[data.columns[X]] = [data.columns[i] for i in S]
        markov_blankets_idx[X] = list(set(markov_blankets_idx[X])|set(S))
        for i in S:
            if X not in markov_blankets_idx[i]:
                markov_blankets_idx[i].append(X)
    
    markov_blankets = {var: [] for var in all_vars}
    for idx, mb_idxes in markov_blankets_idx.items():
        var = all_vars[idx]
        markov_blankets[var] = [all_vars[i] for i in mb_idxes]
    
    return markov_blankets

In [37]:
### Parallel

import concurrent.futures
from concurrent.futures import ProcessPoolExecutor

# Number of parallel executions
num_parallel_executions = num_env
silos_index = [multivariate_sampling(merged_df, basis, sample_dis, i) for i in range(num_parallel_executions)]

In [38]:
results = []

def run_in_parallel(func, args_list, max_parallel_executions):
    with ProcessPoolExecutor(max_workers=max_parallel_executions) as executor:
        future_to_arg = {executor.submit(func, arg): arg for arg in args_list}
        for future in concurrent.futures.as_completed(future_to_arg):
            arg = future_to_arg[future]
            try:
                result = future.result()
                results.append(result)
            except Exception as exc:
                print(f'{arg} generated an exception: {exc}')
    return results

# Running F in parallel and storing results
results = run_in_parallel(GSMB, silos_index, max_parallel_executions=num_env)

In [39]:
for res in results:
    for var, blanket in res.items(): #type:ignore
        markov_blankets[var] += blanket

In [41]:
max_size = 8

mk_with_freq = {var: [] for var in all_vars}
for var in markov_blankets.keys():
    mk_with_freq[var] = []
    vals, freqs = np.unique(markov_blankets[var], return_counts=True)
    # print("Variable", var)
    for val, freq in zip(vals, freqs):
        # print(f"\t[{val}, {freq:>2}]", end="")
        if freq >= int(0.8*num_env):
            if val not in mk_with_freq[var]:
                mk_with_freq[var].append((val, freq))
            # if var not in mk_with_freq[val]:
            #     mk_with_freq[val].append((var, freq))

for var in markov_blankets.keys():
    mk_with_freq[var] = [var for var, freq in sorted(mk_with_freq[var], key=lambda item: item[1], reverse=True)[:max_size]]
mk_with_freq

{'X10': ['X15', 'X16', 'X18', 'X2', 'X20', 'X24', 'X29', 'X3'],
 'X12': ['X20', 'X38', 'X35'],
 'X16': ['X10', 'X30', 'X2', 'X19', 'X21', 'X32', 'X38'],
 'X23': ['X17', 'X18', 'X20', 'X24', 'X30', 'X37', 'X5', 'X32'],
 'X33': ['X11', 'X15', 'X18', 'X19', 'X2', 'X20', 'X26', 'X29'],
 'X4': ['X17', 'X37', 'X33', 'X18', 'X20', 'X30', 'X32', 'X2'],
 'X8': ['X35', 'X17', 'X5', 'X10', 'X37', 'X38', 'X31', 'X22'],
 'X11': ['X15', 'X18', 'X2', 'X20', 'X24', 'X3', 'X30', 'X34'],
 'X28': ['X18', 'X2', 'X20', 'X30', 'X31', 'X32', 'X34', 'X37'],
 'X38': ['X17', 'X18', 'X20', 'X30', 'X32', 'X35', 'X37', 'X40'],
 'X25': ['X22', 'X37', 'X17', 'X31', 'X20', 'X24', 'X10'],
 'X35': ['X13', 'X15', 'X17', 'X18', 'X2', 'X20', 'X30', 'X31'],
 'X13': ['X18', 'X20', 'X30', 'X31', 'X32', 'X34', 'X35', 'X37'],
 'X22': ['X17', 'X18', 'X2', 'X20', 'X30', 'X31', 'X32', 'X37'],
 'X24': ['X17', 'X18', 'X2', 'X27', 'X30', 'X31', 'X37', 'X40'],
 'X6': ['X24', 'X17', 'X20', 'X2', 'X3', 'X37', 'X5'],
 'X7': ['X17', 'X34

In [None]:
def recursive_conn(anchor_var, start_var, track):
    candidates = list(set(mk_with_freq[anchor_var]) & set(mk_with_freq[start_var]) - set(track))
    if len(candidates):
        return [start_var] + [recursive_conn(anchor_var, can, track + [start_var]) for can in candidates]
    else:
        return start_var
    

anchor_var = 'X6'
potential_clusters = []
print(recursive_conn('X6', 'X8', potential_clusters))

In [12]:
def compute_variance_viaindexesv2(indexes: list, variable: str, parents: list):
    conditional_probs_record = merged_df[parents + [variable]].groupby(parents + [variable]).count().reset_index()
    mll_list = []
    env = 0
    for index in indexes:
        vertical_sampled_data = merged_df.iloc[index].reset_index()
        vertical_sampled_data = vertical_sampled_data.drop(columns=['index'])
        vertical_sampled_data.insert(0, 'count', [1] * len(vertical_sampled_data))
        
        summary_with_ch = vertical_sampled_data.groupby(parents + [variable])['count'].sum().reset_index()
        mll, output = compute_mll(summary_with_ch, parents, env)
        conditional_probs_record = conditional_probs_record.merge(output, on=parents + [variable], how='left')
        mll_list.append(mll)
        env += 1
    
    mean_mll = np.mean(mll_list)
    var_avg = conditional_probs_record.iloc[:, len(parents) + 1:].var(axis=1, skipna=True).mean()
    return var_avg, mean_mll, conditional_probs_record


def compute_weighted_variance_viaindexesv2(indexes: list, variable: str, parents: list):
    variance, _, df = compute_variance_viaindexesv2(indexes, variable, parents)
    if len(parents):
        joint_mat = np.array([df[f'joint_{i}'] for i in range(len(indexes))]).T
        probs_mat = np.array([df[f'probs_{i}'] for i in range(len(indexes))]).T
        probs_mean = []
        for i in range(probs_mat.shape[0]):
            if len(probs_mat[i][~np.isnan(probs_mat[i])]):
                probs_mean.append(np.mean(probs_mat[i][~np.isnan(probs_mat[i])]).item())
            else:
                probs_mean.append(0)
                
        probs_mean = np.expand_dims(np.array(probs_mean), 1)
        prod = joint_mat * (probs_mat - probs_mean)**2
        return np.mean(prod[~np.isnan(prod)]), parents
    else:
        return variance, parents
    
def F_wrapper(args):
    return compute_weighted_variance_viaindexesv2(*args)

# Function to execute F in parallel with limited concurrency
def run_in_parallel2(func, args_list, max_parallel_executions):
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel_executions) as executor:
        future_to_arg = {executor.submit(func, args): args for args in args_list}
        for future in concurrent.futures.as_completed(future_to_arg):
            arg = future_to_arg[future]
            try:
                result = future.result()
                results.append(result)
            except Exception as exc:
                print(f'{arg} generated an exception: {exc}')
    return results

In [14]:
np.mean([len(mb) for var, mb in mk_with_freq.items()])

7.425

In [None]:
def recursive_conn(neighbors):
  if len(neighbors) <= 1:
    return [neighbors]
  else:
    return [[i] + recursive_conn(list(set(neighbors)&set(mk_with_freq[i]))) for i in neighbors]


def unfold(input):
  """
  Arguments:
    input: [var, var, ..., [var, ...], [var, ...]]
  
  that has a number of non-list element and a number of list element
  """
  cut_index = 0
  while cut_index < len(input):
    cut_index += 1
    if isinstance(input[cut_index], list):
      break
  
  out = []
  for i in range(cut_index, len(input)):
    out.append([*input[:cut_index], *input[i]])
  return out

In [None]:
anchor_var = 'X3'
recursive_output = recursive_conn(deepcopy(mk_with_freq[anchor_var]))
unique_elements = set()

while len(recursive_output):
  examine_group = recursive_output.pop(0)
  if not isinstance(examine_group[0], list) and isinstance(examine_group[-1], list):
    unfolded = unfold(examine_group)
    recursive_output += [*unfold(examine_group)]
  else:
    unique_elements.add(tuple(sorted(examine_group)))

result = [list(item) for item in unique_elements]
result

In [31]:
from itertools import combinations
import numpy as np


potential_parents = {var: [] for var in all_vars}
children = {var: [] for var in all_vars}
invariance_hardcap = 0.001
max_size = 8

repeat = 1
adj_record = []


for _ in range(repeat):
    for anchor_var in mk_with_freq.keys():
        markov_variables = list(set(mk_with_freq[anchor_var]) - set(children[anchor_var]))
        if len(markov_variables) < 1:
            continue
        
        lowest_variance = 1e2
        best_comb = []
        
        for l in range(1, len(markov_variables) + 1):
            results = []
            inputs = [[silos_index, anchor_var, list(comb)] for comb in list(combinations(markov_variables, l))]
            results = run_in_parallel2(F_wrapper, inputs, max_parallel_executions=64)
            lowest_variance, best_comb = sorted(results, key=lambda item: item[0])[0]

        potential_parents[anchor_var] = (best_comb, lowest_variance) # type:ignore
        
    adj_mtx = np.zeros([len(all_vars), len(all_vars)])
    for var in potential_parents.keys():
        if len(potential_parents[var]):
            parents, invariance = potential_parents[var]
            var_id = int(var[1:]) - 1
            for pa in parents:
                pa_id = int(pa[1:]) - 1
                if adj_mtx[var_id][pa_id] == 0:
                    adj_mtx[pa_id][var_id] = invariance
                elif adj_mtx[var_id][pa_id] > adj_mtx[pa_id][var_id]:
                    adj_mtx[pa_id][var_id] = invariance
                    adj_mtx[var_id][pa_id] = 0


    for i in range(len(all_vars)):
        children[f'X{i+1}'] = []
        for j in range(len(all_vars)):
            if adj_mtx[i][j] > 0:
                children[f'X{i+1}'].append(f'X{j+1}')
                
    adj_record.append(adj_mtx)
    if len(adj_record) >= 2:
        if np.sum(adj_record[-1] - adj_record[-2]) == 0:
            break

In [41]:
from plot_utils import true_edge, spur_edge, fals_edge, miss_edge, swap_pos

for fin_adjmtx in adj_record:
    # fin_adjmtx = adj_record[-1]

    etrue = true_edge(groundtruth, fin_adjmtx)
    espur = spur_edge(groundtruth, fin_adjmtx)
    efals = fals_edge(groundtruth, fin_adjmtx)
    emiss = miss_edge(groundtruth, fin_adjmtx)

    # print(etrue)
    print(len(etrue), len(espur), len(emiss), len(efals))

8 3 0 0
8 3 0 0
8 3 0 0


In [None]:
import matplotlib.pyplot as plt
import networkx as nx

G = nx.DiGraph()

fin_adjmtx = adj_record[0]

for i in range(fin_adjmtx.shape[0]):
    for j in range(fin_adjmtx.shape[1]):
        if fin_adjmtx[i][j] > 0:
            G.add_edge(f"X{i+1}", f"X{j+1}", weight=np.round(1/fin_adjmtx[i][j],2))
            # print("Here add edge", f"X{i+1}", f"X{j+1}")
    G.add_node(f"X{i+1}")
    

etrue = true_edge(groundtruth, fin_adjmtx)
espur = spur_edge(groundtruth, fin_adjmtx)
efals = fals_edge(groundtruth, fin_adjmtx)
emiss = miss_edge(groundtruth, fin_adjmtx)

# print(etrue)
print(len(etrue), len(espur), len(emiss), len(efals))

pos = nx.shell_layout(G)
pos = swap_pos(pos, 'X4', 'X3')
# pos = swap_pos(pos, 'X3', 'X5')
# pos = swap_pos(pos, 'X10', 'X5')
# pos = swap_pos(pos, 'X10', 'X8')
# pos = swap_pos(pos, 'X8', 'X3')
# pos = swap_pos(pos, 'X6', 'X5')
# pos = swap_pos(pos, 'X1', 'X2')
# pos = swap_pos(pos, 'X11', 'X8')
# pos = swap_pos(pos, 'X7', 'X19')
# pos = swap_pos(pos, 'X20', 'X7')
# pos = swap_pos(pos, 'X18', 'X16')


# nodes
nx.draw_networkx_nodes(G, pos, node_size=400, node_color="#1f78b4")

# edges
nx.draw_networkx_edges(G, pos, edgelist=espur, width=2, arrowstyle='->', arrowsize=20, edge_color="orange", label="Spurious Edges")
nx.draw_networkx_edges(G, pos, edgelist=emiss, width=2, arrowstyle='->', arrowsize=20, edge_color="purple", label="Missing Edges")
nx.draw_networkx_edges(G, pos, edgelist=efals, width=2, arrowstyle='->', arrowsize=20, edge_color="red", label="Anti-Causal Edges")
nx.draw_networkx_edges(G, pos, edgelist=etrue, width=2, arrowstyle='->', arrowsize=20, edge_color="green", label="Causal Edges")

# node labels
nx.draw_networkx_labels(G, pos, font_size=12, font_family="sans-serif", font_color='white')

# edge weight labels
# edge_labels = nx.get_edge_attributes(G, "weight")
# nx.draw_networkx_edge_labels(G, pos, edge_labels)

ax = plt.gca()
ax.margins(0.08)
plt.axis("off")
plt.tight_layout()
# plt.box()
plt.title(dataname.upper())
# plt.legend()
plt.show()

# plt.savefig("res/asia.plot.svg", format="svg")

In [None]:
# for anchor_var in mk_with_freq.keys():
anchor_var = 'X6'
print("anchor_var:", anchor_var, end="")
print(mk_with_freq[anchor_var])
markov_variables = mk_with_freq[anchor_var]
if len(markov_variables) < 1:
    pass
else:
    lowest_variance = 1e2
    best_comb = []

    print("anchor var:", anchor_var, "Len(markov) = ", len(markov_variables))
    for l in range(1, len(markov_variables) + 1):
        for comb in list(combinations(markov_variables, l)):
            comb_variance = compute_weighted_variance_viasilos(silos, anchor_var, list(comb)) # type:ignore
            print("\tdoing comb", comb, "variance:", comb_variance)
            if comb_variance < lowest_variance and comb_variance < invariance_hardcap:
                lowest_variance = comb_variance
                best_comb = list(comb)

    print("\tParents:", best_comb, "\tVariance:", lowest_variance)
    potential_parents[anchor_var] = (best_comb, lowest_variance) # type:ignore

In [None]:
# Parents: ['X4'] 	Variance: 1.1418060575318796e-06