In [1]:
import networkx as nx
import warnings
import numpy as np
import os
from pandas import DataFrame
import csv
import os.path
from causalnex.structure import StructureModel
from causalnex.plots import plot_structure
from numpy import array, save, load
from networkx import to_numpy_matrix
from cdt.causality.graph import CAM
import cdt
from pandas import DataFrame
from numpy import float32
from os import path
warnings.filterwarnings("ignore")  # silence warnings

# Rscript executable used by cdt (needed for R-backed algorithms such as CAM below).
# Set the RSCRIPT_PATH environment variable to override; otherwise fall back to the
# original Windows install location. The raw string keeps the backslashes literal
# instead of relying on invalid escape sequences such as "\P" and "\R".
cdt.SETTINGS.rpath = os.getenv("RSCRIPT_PATH") or r"C:\Program Files\R\R-4.2.1\bin\Rscript"



Detecting 1 CUDA device(s).


In [2]:
DATA_DIR = 'CLeaR_2023_Dataset'  # dataset root, relative to the notebook's working directory

In [3]:
# Mechanism folders under DATA_DIR. Sorted for a deterministic processing order;
# plain files (e.g. .DS_Store) are skipped so they are not treated as folders.
method_dirs = sorted(
    entry for entry in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, entry))
)
method_dirs

['linear_mechanism', 'mix_mechanism']

# Plot the DAGs and record their node/edge counts

In [4]:
def generate_plot(dag_path, plot_path, csv_path):
    """Render a DAG stored as a ``.npy`` adjacency matrix and log its size to a CSV.

    Parameters
    ----------
    dag_path : str
        Path to the ``.npy`` adjacency matrix. It is converted with
        ``nx.from_numpy_array`` into a directed graph (non-zero entries become edges).
    plot_path : str
        Destination of the rendered JPEG image.
    csv_path : str
        CSV file that receives one ``nodes,edges`` row per call. A header row is
        written first when the file is missing or empty. Rows are always appended,
        so re-running the notebook adds duplicate rows.
    """
    dag = np.load(dag_path)
    graph = nx.from_numpy_array(dag, create_using=nx.DiGraph)

    # Copy the nodes explicitly: building the StructureModel from the edge list
    # alone silently drops isolated variables, so the plot would disagree with
    # the node count written to the CSV below.
    causal_nex_graph = StructureModel()
    causal_nex_graph.add_nodes_from(graph.nodes)
    causal_nex_graph.add_edges_from(graph.edges)
    viz = plot_structure(causal_nex_graph)  # Default CausalNex visualisation
    viz.draw(plot_path, format="jpg")

    # Decide on the header before opening: opening in append mode creates the file.
    write_header = not os.path.isfile(csv_path) or os.path.getsize(csv_path) == 0

    with open(csv_path, "a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        if write_header:
            writer.writerow(["nodes", "edges"])
        writer.writerow([graph.number_of_nodes(), graph.number_of_edges()])

In [5]:
# For every dataset that ships a confounder DAG, render both DAG1 and the
# confounder DAG into <dataset>/plot/ and append their sizes to <dataset>/details.csv.
for meth_dir in sorted(method_dirs):
    method_path = os.path.join(DATA_DIR, meth_dir)
    if not os.path.isdir(method_path):
        continue  # skip stray files next to the mechanism folders

    for dataset_dir in sorted(os.listdir(method_path)):
        dataset_path = os.path.join(method_path, dataset_dir)
        dag_path = os.path.join(dataset_path, "DAG1.npy")
        dag_confounder_path = os.path.join(dataset_path, "confounder_DAG1.npy")

        # Only datasets with a confounder DAG are plotted (this also skips plain files).
        if not os.path.exists(dag_confounder_path):
            continue

        plot_dir = os.path.join(dataset_path, "plot")
        os.makedirs(plot_dir, exist_ok=True)
        csv_path = os.path.join(dataset_path, "details.csv")

        generate_plot(dag_path, os.path.join(plot_dir, "plot.jpg"), csv_path)
        generate_plot(
            dag_confounder_path, os.path.join(plot_dir, "confounder_plot.jpg"), csv_path
        )


# Run CAM (Causal Additive Models) structure learning on one dataset

In [6]:
def run_cam(data: np.ndarray, output_path: str):
    """Fit CDT's CAM on ``data`` and save the predicted adjacency matrix.

    Parameters
    ----------
    data : np.ndarray
        Observations wrapped in a DataFrame; each column becomes one graph node
        (presumably shaped (n_samples, n_variables) — confirm against the dataset).
    output_path : str
        Directory under which ``cam/result.npy`` is written. Missing directories
        are created.
    """
    print("=================")
    print("Running CAM: ", output_path)
    print("=================")
    cam_result_dir = os.path.join(output_path, "cam")
    cam_result_path = os.path.join(cam_result_dir, "result.npy")
    os.makedirs(cam_result_dir, exist_ok=True)

    df = DataFrame(data).astype(float32)
    output = CAM().predict(df)
    # to_numpy_matrix was removed in networkx 3.0; to_numpy_array returns a plain
    # ndarray. Pin the node order to the DataFrame columns so that row/column i
    # of the saved matrix always refers to data column i.
    pred = nx.to_numpy_array(output, nodelist=list(df.columns))
    save(cam_result_path, pred)

In [7]:
# Load the samples of one mixed-mechanism dataset; they are passed to CAM below.
data_mix_mechanism_path = path.join(DATA_DIR, "mix_mechanism")
dag_path = path.join(data_mix_mechanism_path, "small_mixed_all_issues_1")
dag_data_path = path.join(dag_path, "data1.npy")
data = np.load(dag_data_path)
data

array([[ 0.13430775,  0.64757251, -0.33689519, ..., -1.6386685 ,
         0.50566282,  0.13430775],
       [-0.22075541, -0.31490199,  1.71283693, ..., -0.27932448,
        -1.16027789, -0.22075541],
       [ 0.38953108, -0.74322304,  1.52297946, ...,  0.80036491,
         1.43703873,  0.38953108],
       ...,
       [ 0.93526312, -1.7625656 ,  0.34059972, ...,  1.00118443,
         0.56828119,  0.93526312],
       [ 0.35476308,  0.8611057 , -0.70068245, ..., -0.40309276,
         1.69394289,  0.35476308],
       [ 0.97727662, -0.58569472, -0.07302278, ...,  0.89481651,
         2.80531204,  0.97727662]])

In [8]:
# Store the CAM result next to the dataset (<dag_path>/cam/result.npy) instead of
# in the notebook's working directory, where every dataset's result would collide.
run_cam(data, dag_path)

Running CAM:  
