In [None]:
# (Output cleared due to dataset privacy limitations)

In [None]:
## Import Packages

In [None]:
from iscan import est_node_shifts, est_struct_shifts
import torch
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib as mlp
import matplotlib.pyplot as plt
from dagma.linear import DagmaLinear
from dagma.nonlinear import DagmaMLP, DagmaNonlinear
import dagma.utils
import math

In [None]:
## Load and process the dataset

In [None]:
# Imputed feature matrix; the CSV's first column becomes the row index.
df = pd.read_csv("data/df_new_imputed.csv", header=0, index_col=0)
n = len(df)  # row count (the column drop at the end of this cell does not change it)

# Raw dataset; only its demographic columns (Race, Ethnicity) are used here.
full_df = pd.read_csv("data/CarpeDiem_dataset.csv", header=0)
# COVID_status column, indexed by the file's first column; later used as a row mask.
covid_status = pd.read_csv("data/covid_status.csv", index_col=0)["COVID_status"]

# Row mask: Race is "Black or African American" OR Ethnicity is "Hispanic or Latino".
# NOTE(review): this mask carries full_df's default RangeIndex while df / covid_status use
# their CSVs' first column; pandas aligns boolean masks by index — confirm the indices match.
black_latino = full_df["Race"].eq("Black or African American") | full_df["Ethnicity"].eq(
    "Hispanic or Latino"
)

# Urine_output is removed before the variable list and group matrices are built.
df = df.drop(columns="Urine_output")

In [None]:
## Split the dataset into two groups: Black or Latino patients vs. all other patients (both restricted by the `COVID_status` mask)

In [None]:
# Variable (column) names, in the column order used by every array built below.
# NOTE: `vars` shadows the Python builtin of the same name; it is kept because later cells index it.
vars = list(df.columns)

# Cap on rows per group. The FIRST rows of each subset are taken (not a random sample);
# presumably this bounds ISCAN/DAGMA runtime — TODO confirm the intent.
MAX_ROWS_PER_GROUP = 2000

# Group A: Black or Latino patients; group B: all other patients. Both are restricted by the
# covid_status mask. Slicing before .to_numpy() copies only the rows that are kept.
group_a = df[black_latino & covid_status].iloc[:MAX_ROWS_PER_GROUP].to_numpy()
group_b = df[~black_latino & covid_status].iloc[:MAX_ROWS_PER_GROUP].to_numpy()

# Fail fast: both causal-discovery steps below need samples from each group.
assert len(group_a) > 0 and len(group_b) > 0, "one of the groups is empty"

In [None]:
## Call ISCAN to discover nodes that are part of shifted causal structures

In [None]:
# ISCAN: predict which variables' causal mechanisms differ between group A and group B.
# eta_G / eta_H are ISCAN hyperparameters (both 1e-3 here; see the iscan docs for their meaning).
# Outputs: the predicted shifted nodes, a node ordering (`order`, consumed by
# est_struct_shifts below), and `ratio_dict` (used to colour nodes when rendering).
predict_shifted_nodes, order, ratio_dict = est_node_shifts(
    group_a,
    group_b,
    eta_G=0.001,
    eta_H=0.001,
)

In [None]:
## Use FOCI to find shifted structures

In [None]:
# Given ISCAN's shifted nodes and node ordering, estimate the shifted edges between the groups.
# The result is treated downstream as an adjacency matrix whose rows/columns presumably
# follow the order of `vars` — TODO confirm against the iscan docs.
est_ddag = est_struct_shifts(
    group_a,
    group_b,
    predict_shifted_nodes,
    order,
)

In [None]:
## Render the estimated shifted structure (nodes coloured by log ISCAN ratio)

In [None]:
# Keep only variables that take part in at least one estimated shifted edge.
# Testing `!= 0` (instead of row-sum + column-sum == 0) cannot be fooled by positive and
# negative weights cancelling each other out.
_ddag_nonzero = np.asarray(est_ddag) != 0
empty = ~(_ddag_nonzero.any(axis=0) | _ddag_nonzero.any(axis=1))
trim = [i for i, is_empty in enumerate(empty) if not is_empty]  # original indices of kept nodes
trim_vars = [vars[i] for i in trim]
dag = np.asarray(est_ddag)[np.ix_(~empty, ~empty)]

# Global setting: affects this figure and every later figure in the notebook.
mlp.rcParams['figure.dpi'] = 600

G = nx.from_numpy_array(dag, create_using=nx.DiGraph)  # nodes are 0..len(trim)-1
pos = nx.spring_layout(G, seed=0)  # fixed seed -> reproducible layout

fig, ax = plt.subplots(figsize=(8, 8))

# Colour each node by log(ratio_dict[original index]); nodes absent from ratio_dict get 1.
# NOTE(review): math.log raises for non-positive values — confirm ratio_dict values are > 0.
node_colors = [math.log(ratio_dict[i]) if i in ratio_dict else 1 for i in trim]
nx.draw_networkx_nodes(G, pos, node_size=800, node_color=node_colors, cmap=plt.cm.spring, ax=ax)
# Same node_size as the nodes so networkx stops arrow heads at the node boundary.
nx.draw_networkx_edges(G, pos, arrowstyle="->", arrowsize=10, node_size=800, ax=ax)
labels = {i: name for i, name in enumerate(trim_vars)}  # trimmed node index -> variable name
nx.draw_networkx_labels(G, pos, labels, font_size=6, font_color="black", ax=ax)
ax.set_title("Estimated shifted edges: group A vs. group B")

plt.show()

In [None]:
## Estimate the full causal structure of group A with DAGMA (nonlinear MLP model)

In [None]:
# The MLP's input width is taken from `vars`; make sure it matches the data matrix.
assert group_a.shape[1] == len(vars), "group_a columns do not match vars"

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Seed before building the MLP so its random initial weights (and hence W_covid) are reproducible.
torch.manual_seed(0)

# DagmaMLP dims = [d, 50, 20, 1]: d = len(vars) inputs, hidden widths 50 and 20, one output
# (per the DagmaMLP `dims` convention).
eq_model = DagmaMLP(dims=[len(vars), 50, 20, 1])
eq_model.to(device)
# NOTE(review): if the installed DagmaMLP stores plain tensors (not parameters/buffers),
# `.to()` will not move them — verify that fitting on CUDA works with this dagma version.
model = DagmaNonlinear(eq_model)
# Fit on group A only. T=6 DAGMA iterations; w_threshold=0.2 zeroes small estimated weights.
W_covid = model.fit(torch.tensor(group_a, device=device), T=6, w_threshold=0.2)
dagma.utils.is_dag(W_covid)


In [None]:
## Render the DAGMA-estimated DAG for group A

In [None]:
# Keep only variables with at least one incident edge in the DAGMA estimate.
# Testing `!= 0` (instead of summing weights) cannot be fooled by signed weights cancelling.
_w_nonzero = np.asarray(W_covid) != 0
empty = ~(_w_nonzero.any(axis=0) | _w_nonzero.any(axis=1))
dag = np.asarray(W_covid)[np.ix_(~empty, ~empty)]
trim_vars = [vars[i] for i, is_empty in enumerate(empty) if not is_empty]

labels = {i: name for i, name in enumerate(trim_vars)}  # trimmed node index -> variable name

G = nx.from_numpy_array(dag, create_using=nx.DiGraph)  # edge 'weight' = DAGMA weight
pos = nx.spring_layout(G, seed=0)  # fixed seed -> reproducible layout

fig, ax = plt.subplots(figsize=(12, 12))

nx.draw_networkx_nodes(G, pos, ax=ax)
nx.draw_networkx_edges(G, pos, arrowstyle="->", arrowsize=10, ax=ax)
nx.draw_networkx_labels(G, pos, labels, font_size=6, font_color="black", ax=ax)
ax.set_title("DAGMA-estimated DAG (group A)")

plt.show()