# In [1]:
import os

# Prepend the directory containing the `julia` executable to PATH so pyjulia
# can locate it; prepending gives this install precedence over any other
# `julia` found later on PATH.
JULIA_BIN_DIR = '/home/mila/m/mehrab.hamidi/.local/bin'
_existing_path = os.environ.get('PATH', '')
# Avoid a trailing separator when PATH is empty or unset: on POSIX an empty
# PATH entry means "current directory", which would then be searched as well.
os.environ['PATH'] = (
    JULIA_BIN_DIR + os.pathsep + _existing_path if _existing_path else JULIA_BIN_DIR
)

# Initialize Julia with compiled_modules=False. pyjulia documents this as the
# workaround when the Python interpreter is statically linked to libpython;
# it disables Julia's precompiled-module cache, so package loading is slower.
# NOTE(review): confirm that this is the reason it is needed in this environment.
from julia.api import Julia
jl = Julia(compiled_modules=False)



# In [None]:
def perform_knockoff_filtering(X, Y, q=0.1):
    """
    Select columns of X that are relevant to Y using the knockoff filter in ``ko``.

    Pipeline:
      1. Build model-X Gaussian knockoff copies of X.
      2. Compute the knockoff (MK) statistics W from X, its knockoffs and Y.
      3. Ask ``ko.mk_threshold`` for the cutoff T at target level ``q``.
      4. Keep every index whose statistic reaches the cutoff (W >= T).

    Parameters
    ----------
    X : feature matrix passed to ``ko.modelX_gaussian_knockoffs``.
    Y : response passed to ``ko.MK_statistics``.
    q : level forwarded to ``ko.mk_threshold`` (default 0.1).

    Returns
    -------
    numpy.ndarray
        Indices along the first axis of W where ``W >= T``.

    NOTE(review): the original comments marked the ``ko`` call signatures below
    as provisional ("adjust as needed"); confirm them against the installed
    knockoffspy / Knockoffs.jl API before relying on the selection.
    """
    knockoff_copy = ko.modelX_gaussian_knockoffs(X)
    stats = ko.MK_statistics(X, knockoff_copy, Y)
    cutoff = ko.mk_threshold(stats, q)

    passes_cutoff = stats >= cutoff
    return np.nonzero(passes_cutoff)[0]

def visualize_comparison(G_true, W_est, save_path, title):
    """
    Plot an estimated graph against the true graph and save the figure.

    The node layout is computed from the true graph only, with a fixed seed.
    Calls that share the same ``G_true`` therefore produce comparable layouts.
    Edges are drawn by category:

    * green, width 2   -- true positives (edge in both graphs)
    * red, dashed      -- false positives (edge only in ``W_est``)
    * orange, width 3  -- false negatives (edge only in ``G_true``)

    Parameters
    ----------
    G_true : array_like, shape (d, d)
        Ground-truth adjacency; entry (i, j) != 0 means an edge i -> j.
    W_est : array_like, shape (d, d)
        Estimated adjacency. Only its support (nonzero entries) is used, so
        weighted matrices (e.g. raw NOTEARS output), 0/1 matrices and boolean
        masks are all accepted.
    save_path : str or path-like
        Destination passed to ``plt.savefig``.
    title : str
        Figure title.

    Raises
    ------
    ValueError
        If ``G_true`` and ``W_est`` have different shapes.
    """
    G_true = np.asarray(G_true)
    W_est = np.asarray(W_est)
    if G_true.shape != W_est.shape:
        raise ValueError(
            f"G_true and W_est must have the same shape, "
            f"got {G_true.shape} and {W_est.shape}"
        )

    # Compare supports, not exact values: `W_est == 1` would miss every edge of a
    # weighted estimate.
    true_mask = G_true != 0
    est_mask = W_est != 0

    def _edges(mask):
        # (row, col) pairs of True entries as plain ints, used as networkx node keys.
        return [(int(i), int(j)) for i, j in zip(*np.nonzero(mask))]

    edges_true = _edges(true_mask)
    edges_tp = _edges(true_mask & est_mask)
    edges_fp = _edges(~true_mask & est_mask)
    edges_fn = _edges(true_mask & ~est_mask)

    node_size = 500
    fig = plt.figure(figsize=(8, 8))
    try:
        G = nx.DiGraph()
        # Add every variable, not only those touching a true edge. Otherwise
        # isolated nodes get no layout position, and any false-positive edge
        # touching them makes draw_networkx_edges fail on the missing position.
        G.add_nodes_from(range(G_true.shape[0]))
        G.add_edges_from(edges_true)
        pos = nx.spring_layout(G, seed=42)
        # False positives are added only after the layout, so it depends on the truth alone.
        G.add_edges_from(edges_fp)

        nx.draw_networkx_nodes(G, pos, node_color='lightblue', node_size=node_size)
        nx.draw_networkx_labels(G, pos)

        edge_groups = (
            (edges_tp, {'edge_color': 'green', 'width': 2}),
            (edges_fp, {'edge_color': 'red', 'style': 'dashed'}),
            (edges_fn, {'edge_color': 'orange', 'width': 3}),
        )
        for edgelist, style in edge_groups:
            if edgelist:
                # Passing node_size lets networkx stop arrowheads at the node border.
                nx.draw_networkx_edges(
                    G, pos, edgelist=edgelist, arrows=True, arrowsize=20,
                    node_size=node_size, **style,
                )

        plt.title(title)
        plt.savefig(save_path)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

# ---- Experiment configuration ----
d = 30                 # number of variables; estimated graphs below are d x d
n = 200                # samples drawn per dataset
prob = 0.2             # forwarded to random_dag_generation (presumably edge probability)
graph_type = 'er'      # forwarded to random_dag_generation
noise_types = 'gauss'  # forwarded to generate_single_dataset
K = 3                  # number of independently generated datasets

save_path = 'experiment_knockoffspy_ko_example'
os.makedirs(save_path, exist_ok=True)

# ---- Simulate K (DAG, X, Y) triples and persist each one as CSV ----
dags, X_list, Y_list = [], [], []
for k in range(K):
    g_dag, adj_dag = random_dag_generation(d, prob, graph_type)
    X = generate_single_dataset(g_dag, n, noise_types, 1)
    # Synthetic response that depends only on columns 0 and 1, plus Gaussian
    # noise with standard deviation 0.1; ideally knockoff filtering recovers those two.
    Y = X[:, 0] + 0.5 * X[:, 1] + 0.1 * np.random.randn(n)

    dags.append(adj_dag)
    X_list.append(X)
    Y_list.append(Y)

    np.savetxt(f'{save_path}/W_true_{k}.csv', adj_dag, delimiter=',')
    np.savetxt(f'{save_path}/X_{k}.csv', X, delimiter=',')
    np.savetxt(f'{save_path}/Y_{k}.csv', Y, delimiter=',')

# Knockoff + NOTEARS union approach:
# run knockoff filtering on each dataset separately, take the union of the
# selected variables, and fit NOTEARS only on those columns of the pooled data.
S_list = [
    set(perform_knockoff_filtering(X_k, Y_k, q=0.1))
    for X_k, Y_k in zip(X_list, Y_list)
]
S_union = sorted(set().union(*S_list))

# Restrict every dataset to the selected columns and stack the samples.
X_union_list = [X_k[:, S_union] for X_k in X_list]
X_combined_selected = np.vstack(X_union_list)

# perf_counter is monotonic and intended for measuring elapsed time.
start_time_knockoff = time.perf_counter()
if S_union:
    W_est_restricted = notears_linear(X_combined_selected, lambda1=0.1, loss_type='l2')
else:
    # No variable was selected in any dataset: report an empty estimate rather
    # than calling NOTEARS on a zero-column matrix.
    W_est_restricted = np.zeros((0, 0))
end_time_knockoff = time.perf_counter()
notears_time_knockoff = end_time_knockoff - start_time_knockoff

# Scatter the |S_union| x |S_union| estimate back into a full d x d matrix;
# variables that were never selected keep no edges.
W_est_full_selected = np.zeros((d, d))
if S_union:
    W_est_full_selected[np.ix_(S_union, S_union)] = W_est_restricted
np.savetxt(f'{save_path}/W_est_union_selected.csv', W_est_full_selected, delimiter=',')

# Baseline: NOTEARS on the pooled data with all d variables (no selection).
X_combined_full = np.vstack(X_list)

# perf_counter is monotonic and intended for measuring elapsed time,
# unlike the wall-clock time.time().
start_time_full = time.perf_counter()
W_est_full = notears_linear(X_combined_full, lambda1=0.1, loss_type='l2')
end_time_full = time.perf_counter()
notears_time_full = end_time_full - start_time_full

np.savetxt(f'{save_path}/W_est_full_no_selection.csv', W_est_full, delimiter=',')

# ---- Evaluation ----
# Reference graph for the pooled fits: element-wise maximum of the K true DAGs,
# i.e. an edge is present if it appears in any of them.
W_true_union = np.zeros((d, d))
for adj_k in dags:
    W_true_union = np.maximum(W_true_union, adj_k)

# Only the support of the weighted NOTEARS estimates matters for evaluation.
# Binarise once and reuse it for plotting and scoring. The plotting helper
# identifies estimated edges by comparing entries with 1, so it must receive
# a 0/1 (boolean) support rather than raw weights.
B_est_selected = W_est_full_selected != 0
B_est_full = W_est_full != 0

visualize_comparison(W_true_union, B_est_selected, os.path.join(save_path, 'compare_union_selected.png'),
                     "Knockoff+NO TEARS (Union Selected) vs True Union")
visualize_comparison(W_true_union, B_est_full, os.path.join(save_path, 'compare_full_no_selection.png'),
                     "NO TEARS (No Selection) vs True Union")

# Score each pooled estimate against every individual true DAG, keyed by dataset index.
results_selected = {k: count_accuracy(adj_k, B_est_selected) for k, adj_k in enumerate(dags)}
results_full = {k: count_accuracy(adj_k, B_est_full) for k, adj_k in enumerate(dags)}

comparison = {
    "knockoff_selected_results": results_selected,
    "full_no_selection_results": results_full,
    "notears_time_knockoff_selected": notears_time_knockoff,
    "notears_time_full_no_selection": notears_time_full
}

# Serialize once; the same text is written to disk and printed.
report = json.dumps(comparison, indent=2)
with open(f'{save_path}/comparison.txt', "w", encoding="utf-8") as f:
    f.write(report)

print("Comparison Results:")
print(report)