In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax
import os
import pandas as pd
import math

from scipy.special import logsumexp
from pathlib import Path
from tqdm.auto import tqdm, trange

In [2]:
def get_transitive_closure(adjacency):
    """Return the (non-reflexive) transitive closure of a square adjacency matrix.

    Implements Warshall's algorithm as a `jax.lax.scan` over intermediate nodes k:
    after processing k, entry (i, j) is True iff j is reachable from i using only
    intermediate nodes 0..k. Works under `jax.jit` / `jax.vmap`.

    Args:
        adjacency: square (n, n) array; nonzero entries are treated as edges.

    Returns:
        Boolean (n, n) array; entry (i, j) is True iff there is a directed path of
        length >= 1 from i to j (so the diagonal is True only on cycles).
    """
    def relax_through(reach, k):
        # i reaches j via k  <=>  i reaches k and k reaches j.
        via_k = jnp.logical_and(reach[:, k][:, None], reach[k][None, :])
        return jnp.logical_or(reach, via_k), None

    edges = adjacency.astype(jnp.bool_)
    intermediates = jnp.arange(edges.shape[0])
    reach, _ = jax.lax.scan(relax_through, edges, intermediates)
    return reach

In [3]:
def path_log_marginal(dags_compressed, log_joint, num_variables, batch_size=1, verbose=True):
    """Log of the summed DAG scores over DAGs that contain a directed path i -> j.

    For every ordered pair (i, j) this computes
        log( sum over DAGs G with a path i -> j of exp(log_joint[G]) ),
    where "path" means a directed path of length >= 1 as given by
    `get_transitive_closure`. Whether the result is a normalized log-probability
    depends on whether `log_joint` is normalized, which is not established here.

    Args:
        dags_compressed: (num_dags, num_bytes) uint8 array; row g is DAG g's
            adjacency matrix flattened row-major and bit-packed (unpacked here with
            `jnp.unpackbits` using its default bit order). A memory-mapped array works.
        log_joint: length-num_dags array of log-scores aligned with the rows of
            `dags_compressed`.
        num_variables: number of nodes per DAG.
        batch_size: number of DAGs processed per jitted call (must be >= 1).
        verbose: whether to show a tqdm progress bar.

    Returns:
        (num_variables, num_variables) numpy array. Entries are -inf where no DAG has
        the path, and everywhere when `dags_compressed` is empty.

    Raises:
        ValueError: if `batch_size` < 1, or if `log_joint` and `dags_compressed`
            differ in length (batches would otherwise silently misalign).
    """
    num_dags = dags_compressed.shape[0]
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    if len(log_joint) != num_dags:
        raise ValueError(
            f'log_joint has {len(log_joint)} entries but dags_compressed has {num_dags} rows')

    @jax.jit
    def accumulate(total, log_probs, adjacencies_compressed):
        # Unpack each row back into a (num_variables, num_variables) 0/1 matrix.
        adjacencies = jnp.unpackbits(adjacencies_compressed, axis=1, count=num_variables ** 2)
        adjacencies = adjacencies.reshape(-1, num_variables, num_variables)
        closures = jax.vmap(get_transitive_closure)(adjacencies)
        # Keep each DAG's score only at the (i, j) entries where it has a path i -> j.
        masked = jnp.where(closures, log_probs[:, None, None], -jnp.inf)
        # Fold this batch into the running total in log space; logaddexp(-inf, -inf) = -inf.
        return jnp.logaddexp(total, jax.nn.logsumexp(masked, axis=0))

    # Keep a running log-sum instead of stacking one (V, V) result per batch, so
    # memory stays O(V^2) no matter how many batches there are.
    total = jnp.full((num_variables, num_variables), -jnp.inf)
    for start in trange(0, num_dags, batch_size, disable=(not verbose)):
        stop = start + batch_size
        total = accumulate(total, log_joint[start:stop], dags_compressed[start:stop])

    return np.array(total)

In [4]:
# Data root is taken from the SLURM_TMPDIR environment variable. Indexing os.environ
# fails fast with KeyError('SLURM_TMPDIR') when it is unset, instead of the opaque
# TypeError that Path(None) would raise.
root = Path(os.environ['SLURM_TMPDIR'])

In [5]:
# Bit-packed adjacency matrices, one row per DAG (unpacked batch-by-batch when the
# path marginals are computed). Memory-map the file rather than reading it all into
# RAM: only the rows of the current batch need to be resident, and they are paged
# in on demand.
dags_compressed = np.load(root / 'dags_7_final.npy', mmap_mode='r')

In [6]:
# Expected (num_dags, bytes_per_dag): one bit-packed adjacency matrix per row.
np.shape(dags_compressed)

(1138779265, 7)

In [7]:
datasets = ['ell', 'len', 'spr']
dfs, observations = {}, {}

for dataset in datasets:
    # Load each dataset and center every column. This is the vectorized equivalent
    # of a column-wise apply(col - col.mean()).
    dfs[dataset] = (
        pd.read_csv(f'data/causal_BH_{dataset}.csv')
        .pipe(lambda df: df - df.mean())
    )
    observations[dataset] = np.asarray(dfs[dataset])

In [None]:
log_path_marginals = {}
for dataset in datasets:
    # Memory-map the per-DAG scores instead of reading them into RAM. The array has
    # one entry per compressed DAG, and rebinding `log_joint` on the next iteration
    # would otherwise keep two full copies alive while the new file loads.
    log_joint = np.load(root / 'causalbh' / f'log_joint_{dataset}.npy', mmap_mode='r')
    log_path_marginals[dataset] = path_log_marginal(
        dags_compressed, log_joint, observations[dataset].shape[1], batch_size=1024)

    # np.save takes the path directly; the name already ends in '.npy'.
    np.save(root / 'causalbh' / f'log_path_marginal_{dataset}.npy', log_path_marginals[dataset])

  0%|          | 0/1112090 [00:00<?, ?it/s]