In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
import os
import pandas as pd
import math
import seaborn as sns

from scipy.special import logsumexp
from pathlib import Path
from tqdm.auto import tqdm, trange

from bge_score_jax import BGe

In [None]:
# Load every labeled DAG on NUM_NODES nodes, as written by "generate_all_dags.py".
NUM_NODES = 7
dags_compressed = np.load(f'your_folder/dags_{NUM_NODES}.npy')

# Sanity check: the number of labeled DAGs on n nodes (OEIS A003024) is
# 1, 3, 25, 543, 29281, 3781503, 1138779265, 783702329343 for n = 1..8.
NUM_LABELED_DAGS = {1: 1, 2: 3, 3: 25, 4: 543, 5: 29281, 6: 3781503,
                    7: 1138779265, 8: 783702329343}
expected_num_dags = NUM_LABELED_DAGS.get(NUM_NODES)
if expected_num_dags is not None:
    assert dags_compressed.shape[0] == expected_num_dags, (
        f'expected {expected_num_dags} DAGs for {NUM_NODES} nodes, '
        f'got {dags_compressed.shape[0]}')
dags_compressed.shape

In [None]:
# Read the data: one row per sample, one column per variable.
dfs_raw = pd.read_csv('your_data.csv')
# Standardize each column to zero mean and unit (sample, ddof=1) standard deviation.
dfs = (dfs_raw - dfs_raw.mean()) / dfs_raw.std()
# A constant column (std == 0) or a missing value becomes NaN here and would
# likely propagate NaN into every BGe score downstream -- fail early instead.
assert not dfs.isna().to_numpy().any(), (
    'standardized data contains NaN (constant column or missing values?)')
observations = np.asarray(dfs)

In [None]:
# define a few functions
def compute_exact_posterior(dags_compressed, observations, batch_size=1, verbose=True):
    """Exactly normalized log posterior over every DAG in ``dags_compressed``.

    Each DAG is scored with ``BGe.log_prob``. The scores are then normalized with
    log-sum-exp, so that ``exp`` of the result sums to 1 over all DAGs.

    Args:
        dags_compressed: bit-packed adjacency matrices, one row per DAG. Each row
            is unpacked along axis 1 into its first ``num_variables ** 2`` bits and
            reshaped to ``(num_variables, num_variables)``.
        observations: 2-D data matrix; ``observations.shape[1]`` is taken as the
            number of variables (columns = variables).
        batch_size: number of DAGs scored per jitted call.
        verbose: show a tqdm progress bar.

    Returns:
        NumPy array with one normalized log-probability per row of ``dags_compressed``.
    """
    num_variables = observations.shape[1]
    model = BGe(num_variables=num_variables)

    @jax.jit
    def log_prob(observations, adjacencies_compressed):
        adjacencies = jnp.unpackbits(adjacencies_compressed, axis=1, count=num_variables ** 2)
        adjacencies = adjacencies.reshape(-1, num_variables, num_variables)

        v_log_prob = jax.vmap(model.log_prob, in_axes=(None, 0))
        log_probs = v_log_prob(observations, adjacencies)
        # NOTE(review): model.log_prob presumably returns per-node local scores
        # (shape (num_variables,)); summing over axis 1 gives each graph's score.
        return jnp.sum(log_probs, axis=1)

    # The data matrix is identical for every batch: convert/transfer it to the
    # device once, instead of once per jitted call inside the loop.
    observations_device = jnp.asarray(observations)

    num_dags = dags_compressed.shape[0]
    log_probs = np.zeros((num_dags,), dtype=np.float32)
    for i in trange(0, num_dags, batch_size, disable=(not verbose)):
        # Get a batch of (compressed) DAGs
        batch_compressed = dags_compressed[i:i + batch_size]

        # Compute the BGe scores
        log_probs[i:i + batch_size] = log_prob(observations_device, batch_compressed)

    # Normalize the log-marginal probabilities
    log_probs = log_probs - logsumexp(log_probs)
    return log_probs

def edge_log_marginal(dags_compressed, log_joint, num_variables, batch_size=1, verbose=True):
    """Log posterior mass of each individual edge slot, summed over all DAGs.

    Args:
        dags_compressed: bit-packed adjacency matrices, one row per DAG. Each row
            is unpacked along axis 1 into its first ``num_variables ** 2`` bits.
        log_joint: log-probability of each DAG, aligned with the rows of
            ``dags_compressed`` (e.g. the output of ``compute_exact_posterior``).
        num_variables: number of nodes per graph.
        batch_size: number of DAGs processed per jitted call.
        verbose: show a tqdm progress bar.

    Returns:
        ``(num_variables, num_variables)`` array whose entry ``[i, j]`` is the
        log-sum-exp of ``log_joint`` over all DAGs with adjacency bit ``[i, j]``
        set (``-inf`` when no DAG has that bit set).
    """
    flat_size = num_variables ** 2

    @jax.jit
    def batch_edge_logsumexp(batch_log_probs, batch_packed):
        edge_bits = jnp.unpackbits(batch_packed, axis=1, count=flat_size)
        # Keep each DAG's log-probability only in the slots where it has an edge.
        masked = jnp.where(edge_bits == 1, batch_log_probs[:, None], -jnp.inf)
        return jax.nn.logsumexp(masked, axis=0)

    per_batch = [
        batch_edge_logsumexp(log_joint[start:start + batch_size],
                             dags_compressed[start:start + batch_size])
        for start in trange(0, dags_compressed.shape[0], batch_size, disable=(not verbose))
    ]

    # Combine the per-batch partial sums, still in log space.
    total = logsumexp(np.stack(per_batch, axis=0), axis=0)
    return total.reshape(num_variables, num_variables)

def get_transitive_closure(adjacency):
    """Boolean reachability matrix of a directed graph (Warshall's algorithm).

    ``closure[a, b]`` is True iff a chain of non-zero entries
    ``adjacency[a, k1], adjacency[k1, k2], ..., adjacency[km, b]`` exists, i.e. a
    path of length >= 1. Any non-zero entry counts as an edge. A diagonal entry
    is True only when its node lies on a cycle.
    """
    num_nodes = adjacency.shape[0]

    def relax_through(reach, pivot):
        # a reaches b if a already reaches the pivot and the pivot reaches b.
        via_pivot = reach[:, pivot][:, None] & reach[pivot][None, :]
        return reach | via_pivot, None

    initial = adjacency.astype(jnp.bool_)
    reach, _ = jax.lax.scan(relax_through, initial, jnp.arange(num_nodes))
    return reach

def path_log_marginal(dags_compressed, log_joint, num_variables, batch_size=1, verbose=True):
    """Log posterior mass of each ordered node pair being connected by a directed path.

    Args:
        dags_compressed: bit-packed adjacency matrices, one row per DAG. Each row
            is unpacked along axis 1 into its first ``num_variables ** 2`` bits and
            reshaped to ``(num_variables, num_variables)``.
        log_joint: log-probability of each DAG, aligned with the rows of
            ``dags_compressed`` (e.g. the output of ``compute_exact_posterior``).
        num_variables: number of nodes per graph.
        batch_size: number of DAGs processed per jitted call.
        verbose: show a tqdm progress bar.

    Returns:
        ``(num_variables, num_variables)`` array whose entry ``[i, j]`` is the
        log-sum-exp of ``log_joint`` over all DAGs whose transitive closure
        (``get_transitive_closure``) has entry ``[i, j]`` set.
    """
    n = num_variables

    def batch_path_logsumexp(batch_log_probs, batch_packed):
        bits = jnp.unpackbits(batch_packed, axis=1, count=n * n)
        graphs = bits.reshape(-1, n, n)
        reachable = jax.vmap(get_transitive_closure)(graphs)
        # Keep each DAG's log-probability only where a path exists.
        masked = jnp.where(reachable, batch_log_probs[:, None, None], -jnp.inf)
        return jax.nn.logsumexp(masked, axis=0)

    batch_path_logsumexp = jax.jit(batch_path_logsumexp)

    partials = []
    progress = trange(0, dags_compressed.shape[0], batch_size, disable=(not verbose))
    for lo in progress:
        hi = lo + batch_size
        partials.append(batch_path_logsumexp(log_joint[lo:hi], dags_compressed[lo:hi]))

    # Combine the per-batch partial sums, still in log space.
    combined = logsumexp(np.stack(partials, axis=0), axis=0)
    return combined.reshape(n, n)

In [None]:
root = Path('your_output_folder/exact_posteriors')
# Writing into a missing directory raises FileNotFoundError; create it up front.
root.mkdir(parents=True, exist_ok=True)

# Exact (normalized) log posterior of every DAG.
log_joint = compute_exact_posterior(dags_compressed, observations, batch_size=2048)

# Log posterior marginal of each individual edge.
log_edge_marginals = edge_log_marginal(
    dags_compressed, log_joint, observations.shape[1], batch_size=4096)

np.save(root / 'log_joint.npy', log_joint)
np.save(root / 'log_edge_marginal.npy', log_edge_marginals)

In [None]:
# Heatmap of the posterior probability of each edge.
log_edge_marginal = np.load(root / 'log_edge_marginal.npy')
edge_marginal = pd.DataFrame(np.exp(log_edge_marginal),
                             index=dfs.columns, columns=dfs.columns)

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(edge_marginal, cmap='gray', annot=True, fmt='.2f', cbar=False, ax=ax)
# NOTE(review): rows/columns follow the adjacency layout of the DAG encoding
# (presumably entry [i, j] is the edge i -> j) -- confirm before labelling axes.
ax.set_title('Posterior edge marginals')
plt.show()

In [None]:
# Path marginals: log posterior mass of a directed path between each ordered
# pair of nodes, computed from the DAG posterior saved by the earlier cell.
log_joint = np.load(root / 'log_joint.npy')
log_path_marginals = path_log_marginal(
    dags_compressed, log_joint, observations.shape[1], batch_size=1024)

np.save(root / 'log_path_marginal.npy', log_path_marginals)

In [None]:
# Heatmap of the posterior probability that a directed path connects each pair.
log_path_marginal = np.load(root / 'log_path_marginal.npy')
path_marginal = pd.DataFrame(np.exp(log_path_marginal),
                             index=dfs.columns, columns=dfs.columns)

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(path_marginal, cmap='gray', annot=True, fmt='.2f', cbar=False, ax=ax)
# NOTE(review): rows/columns follow the adjacency layout of the DAG encoding
# (presumably entry [i, j] means a path i -> j) -- confirm before labelling axes.
ax.set_title('Posterior path marginals')
plt.show()