In [None]:
import os
import argparse
import json
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from typing import *
from dibs.utils import visualize_ground_truth
from dibs.models import ErdosReniDAGDistribution, ScaleFreeDAGDistribution, BGe
from dibs.inference import JointDiBS, MarginalDiBS
from dibs.graph_utils import elwise_acyclic_constr_nograd
from jax.scipy.special import logsumexp
# Ask the XLA client to use a 0.9 fraction of device memory for JAX.
# NOTE(review): this is set after `import jax`; confirm it is still picked up
# (it must be in place before JAX initializes its backend).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".9"

import myUtils
from myUtils import *



In [None]:
# Identifier of the model under analysis; not referenced again in the visible cells.
model_name = "4"

# NOTE(review): `open()` is called without a path here, so this cell cannot run
# as written — the JSON file location has to be supplied. The file handle is
# also never closed explicitly.
# presumably the JSON holds a list of column names (it is concatenated with
# the `cm_cols` list below) — TODO confirm.
ling_cols = json.load(
    open(
    )
)
# NOTE(review): `pd.read_csv()` is likewise missing its path argument.
original_df = pd.read_csv(
)

# Additional columns selected by name alongside `ling_cols`.
cm_cols = [
    "black",
    "syntaxError_rate",
    "sta_codeBleu",
    "sta_Bleu",
    "sim_codeBleu",
    "sim_Bleu",
    "pass_rate",
    "error_rate",
    "timeout_rate",
]

# Restrict the raw frame to the columns of interest, `ling_cols` first.
selected_cols = ling_cols + cm_cols
selected_df = original_df[selected_cols]

In [None]:

def matrix_to_dgraph(matrix: np.ndarray, columns: List[str], threshold: float = 1.0) -> List[str]:
    """
    Lists the edges of a (weighted) adjacency matrix as ``"src -> dst"`` strings.

    Args:
        matrix: 2-D array; entry ``[i, j]`` is the weight of the edge from
            variable ``i`` to variable ``j``
        columns: variable names, indexed like the rows/columns of ``matrix``
        threshold: an edge is reported when its weight is ``>= threshold``

    Returns:
        Edge descriptions in row-major order (all edges leaving the first
        variable come first, and so on).
    """
    return [
        f"{columns[src]} -> {columns[dst]}"
        for src in range(matrix.shape[0])
        for dst in range(matrix.shape[1])
        if matrix[src, dst] >= threshold
    ]


def compute_expected_graph(*, dist):
    """
    Computes the expected graph: the weighted average of the acyclic particles
    in ``dist``, with weights given by their renormalized probabilities.

    Args:
        dist (:class:`dibs.metrics.ParticleDistribution`): particle distribution;
            ``dist.g`` stacks the graph matrices along axis 0 and ``dist.logp``
            holds one log-weight per particle

    Returns:
        expected graph, an array with the shape of a single particle whose
        entries are the weighted average of the corresponding particle entries

    Raises:
        AssertionError: if none of the particles is acyclic
    """
    n_vars = dist.g.shape[1]

    # select acyclic graphs (the constraint evaluates to 0 for a DAG)
    is_dag = elwise_acyclic_constr_nograd(dist.g, n_vars) == 0
    assert is_dag.sum() > 0,  "No acyclic graphs found"

    particles = dist.g[is_dag, :, :]

    # renormalize the log-weights over the retained particles only
    dag_logp = dist.logp[is_dag]
    weights = jnp.exp(dag_logp - logsumexp(dag_logp))

    # weighted sum over the particle axis in one broadcasted operation instead
    # of a Python loop that issues one device op per particle
    return (weights[:, None, None] * particles).sum(axis=0)



In [None]:
# Root JAX PRNG key with a fixed seed of 0; the sampling cell splits it rather
# than consuming it directly.
rand_key = jax.random.PRNGKey(0)

In [None]:
# Build the modelling table from `selected_df` without mutating any frame in
# place, so re-running this cell always starts from the same input.
collected_df = selected_df.copy()
# Drop columns whose variance does not exceed 1e-5.
collected_df = collected_df.loc[:, collected_df.var() > 1e-5]
collected_df = (
    collected_df
    # Treat +/-inf as missing so the affected rows are removed by dropna().
    .replace([np.inf, -np.inf], np.nan)
    .dropna()
    # Shuffle rows with a fixed seed so the run is reproducible.
    .sample(frac=1, random_state=0)
    .reset_index(drop=True)
)
print(f"Collected data shape: {collected_df.shape}")
print(f"Collected data columns: {collected_df.columns}")

# Standardize every remaining column (zero mean, unit variance).
scaler = StandardScaler()
collected_data = scaler.fit_transform(collected_df)

# Graph prior over one node per data column, BGe inference model, and the
# marginal DiBS sampler on purely observational data (no intervention mask).
model_graph = ScaleFreeDAGDistribution(collected_data.shape[1], n_edges_per_node=5)
# model_graph = ErdosReniDAGDistribution(collected_data.shape[1], n_edges_per_node=5)
model = BGe(graph_dist=model_graph)
dibs = MarginalDiBS(x=collected_data, interv_mask=None, inference_model=model)


In [None]:
# Split the root key so this sampling run gets its own subkey and `rand_key`
# stays usable for later draws.
rand_key, subk = jax.random.split(rand_key)
# Longer configuration kept for reference (50 particles, 13000 steps):
# steps = 13000
# gs = dibs.sample(key=subk, n_particles=50, steps=13000, callback_every=1000, callback=dibs.visualize_callback())
# Current run: 10 particles for 4000 steps, invoking the callback every 500 steps.
gs = dibs.sample(key=subk, n_particles=10, steps=4000, callback_every=500, callback=dibs.visualize_callback())

In [None]:
# Turn the sampled particles into a distribution object exposing `.g` and
# `.logp` (both are read by compute_expected_graph and the next cell).
dibs_output = dibs.get_mixture(gs)
# Alternative particle weighting, kept for reference:
# dibs_output = dibs.get_empirical(gs)
# Weighted average of the acyclic particles.
expected_g = compute_expected_graph(dist=dibs_output)


In [None]:
# Display the per-particle log-weights as the cell output.
dibs_output.logp

In [None]:
# Reuse the dibs ground-truth plotting helper to draw the expected graph
# matrix (the input here is the inferred expectation, not a ground truth).
visualize_ground_truth(jnp.array(expected_g), )

In [None]:
# List every edge whose expected weight is at least 0.1, labelled with the
# column names of the data the model was fit on.
dgraph = matrix_to_dgraph(expected_g, collected_df.columns, threshold=0.1)
# Number of reported edges, followed by one "src -> dst" line per edge.
print(len(dgraph))
for line in dgraph:
    print(line)
