### Sachs data

[Sachs Protein Data](https://perso.univ-rennes1.fr/valerie.monbet/GM/Sachs.html)

In [1]:
#pip install ../. 
#pip install cdt
from cdt.data import load_dataset
import networkx as nx
import jax
import jax.numpy as jnp
import jax.random as random
import numpy as np
import pandas as pd

# Root PRNG key for the notebook; later cells split sub-keys off it.
key = random.PRNGKey(seed=343)

# Report which backend JAX selected for this session.
print(f"JAX backend: {jax.default_backend()}")

def permute_random_rows(key, X, p=0.5):
    """Apply one random column permutation to a random subset of the rows of X.

    A single permutation of the column indices is drawn, and every row is
    independently selected with probability ``p``. Selected rows have their
    entries reordered by that permutation; all other rows are left as they are.

    Args:
        key: JAX PRNG key used for both random draws.
        X: Array with at least two dimensions; rows are indexed by axis 0 and
            the permuted entries by axis 1.
        p: Probability that a given row is permuted.

    Returns:
        Tuple ``(X_out, permute, permutation)``:
        ``X_out`` has the shape and dtype of ``X``;
        ``permute`` is a boolean array of shape ``(X.shape[0],)`` marking the
        permuted rows;
        ``permutation`` is an integer array of shape ``(X.shape[1],)`` such that
        ``X_out[i, j] == X[i, permutation[j]]`` for every permuted row ``i``.
    """
    X = jnp.asarray(X)
    n_rows, n_cols = X.shape[0], X.shape[1]

    key, subk = random.split(key)
    permutation = random.permutation(subk, jnp.arange(n_cols))

    key, subk = random.split(key)
    permute = random.bernoulli(subk, p=jnp.float32(p), shape=(n_rows,))

    # Broadcast the per-row mask over the remaining axes and pick, row by row,
    # either the permuted or the original entries in a single vectorized op
    # (instead of one array update per selected row).
    row_mask = permute.reshape((n_rows,) + (1,) * (X.ndim - 1))
    X = jnp.where(row_mask, X[:, permutation], X)

    return X, permute, permutation

def permute_directed_graph(graph, permutation):
    """Relabel the nodes of a directed graph given as an adjacency matrix.

    Entry ``(i, j)`` of the result equals
    ``graph[permutation[i], permutation[j]]``, i.e. rows and columns are
    reordered by the same permutation.

    Args:
        graph: Square adjacency matrix of shape ``(d, d)``.
        permutation: Integer array (or sequence) of length ``d`` holding the
            node index that each new position takes its row/column from.

    Returns:
        Array of shape ``(d, d)`` with the dtype of ``graph`` containing the
        relabelled adjacency matrix.
    """
    graph = jnp.asarray(graph)
    perm = jnp.asarray(permutation)
    # A single broadcasted gather replaces the d*d element-wise array updates.
    return graph[perm[:, None], perm[None, :]]

No GPU automatically detected. Setting SETTINGS.GPU to 0, and SETTINGS.NJOBS to cpu_count.
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
I0000 00:00:1737034938.771853    1799 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


JAX backend: cpu


In [2]:
# Fetch the Sachs benchmark from cdt: the measurements and the reference graph
sachs_data, sachs_graph = load_dataset("sachs")

# Dense adjacency matrix of the reference graph, and the measurements as a JAX array
adjacency = jnp.array(nx.to_numpy_array(sachs_graph))
data = jnp.array(sachs_data.values)

print('Adjacency matrix: \n', adjacency)
print('Data shape: \n', data.shape)

Adjacency matrix: 
 [[0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 1. 0. 1. 1. 0. 0.]
 [1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 1. 0. 1. 1. 1. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
Data shape: 
 (7466, 11)


In [3]:
# Process data: each row has its columns permuted with probability 0.5
key, subk = random.split(key)
X, indicator, permutation = permute_random_rows(key=subk, X=data, p=0.5)
#X = jax.nn.standardize(X, axis=0)

# One ground-truth graph per group of rows: index 0 is the original graph
# (rows with indicator False), index 1 the relabelled graph (indicator True).
ground_truth_graphs = [
    adjacency,
    permute_directed_graph(graph=adjacency, permutation=permutation),
]

print('Processed data shape: \n', X.shape)
print('Permutation: \n', permutation)
print('Ground truth graphs: \n', ground_truth_graphs)

Processed data shape: 
 (7466, 11)
Permutation: 
 [ 2  4  9  6  7  8 10  5  3  0  1]
Ground truth graphs: 
 [Array([[0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 0., 1., 1., 0., 0.],
       [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 0., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32), Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 1., 0., 1., 1., 1., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0.

In [4]:
# Summary statistics of the rows that kept their original column order
pd.DataFrame(X[jnp.logical_not(indicator)]).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
count,3706.0,3706.0,3706.0,3706.0,3706.0,3706.0,3706.0,3706.0,3706.0,3706.0,3706.0
mean,126.226578,149.241562,51.623028,145.906372,27.600828,26.573629,79.875435,636.472717,30.360804,132.743011,72.06279
std,257.605286,391.166779,149.812042,266.308319,43.727592,50.692448,139.785034,672.737854,99.237152,511.181976,221.737915
min,1.0,1.0,1.0,1.0,1.0,1.0,1.01,1.0,1.0,1.0,1.0
25%,30.799999,16.4,9.39,17.9,9.47,8.51,23.299999,274.0,4.18,19.299999,7.84
50%,53.799999,26.4,16.4,51.900002,17.799999,17.200001,37.200001,449.0,12.6,30.200001,18.1
75%,103.0,65.949999,26.9,172.0,33.700001,32.5,72.300003,757.0,23.45,50.0,50.0
max,3820.0,5829.0,2267.0,3619.0,1084.0,2571.0,3555.0,8896.0,1611.0,7499.0,4740.0


In [5]:
# Summary statistics of the rows whose columns were permuted
pd.DataFrame(X[indicator]).describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10
count,3760.0,3760.0,3760.0,3760.0,3760.0,3760.0,3760.0,3760.0,3760.0,3760.0,3760.0
mean,58.037865,26.477222,137.253372,82.440437,615.198242,30.322784,74.45491,26.687933,156.260208,121.948219,141.575821
std,194.643814,42.366215,478.097717,135.753296,615.223389,86.14949,209.520645,40.469086,328.633026,237.192398,362.621826
min,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
25%,9.47,9.65,19.299999,23.299999,279.0,4.73,8.2,8.51,18.799999,31.1,16.700001
50%,16.700001,17.6,30.799999,37.200001,449.0,13.05,18.799999,17.299999,54.200001,53.799999,26.9
75%,27.4,31.9,49.599998,72.300003,750.0,23.5,54.700001,31.9,170.0,103.0,62.599998
max,6208.0,1275.0,6378.0,1715.0,8354.0,1155.0,3398.0,1715.0,9058.0,4614.0,7105.0


### Create VaMSL model

In [6]:
from vamsl.models.graph import ErdosReniDAGDistribution, ScaleFreeDAGDistribution, UniformDAGDistributionRejection
from vamsl.models import (
    MixtureLinearGaussian,
    MixtureDenseNonlinearGaussian,
    LinearGaussian,
    DenseNonlinearGaussian,
)
from vamsl.target import make_graph_model

# --- BN settings ---
# number of variables in each component BN
n_vars = X.shape[1]
# BN function class: 'linear' or 'nonlinear'
struct_eq_type = 'linear'
# Random graph structure: 'sf' (scale-free) or 'er' (Erdos-Renyi)
graph_type = 'er'

# --- Derived variables ---
# One mixture component per distinct value of the ground-truth indicator
n_components = int(jnp.unique(indicator).size)
linear = struct_eq_type == 'linear'

# --- Model specification ---
# NOTE(review): both likelihoods below are the linear-Gaussian variants regardless
# of `struct_eq_type`; setting it to 'nonlinear' only flips the `linear` flag.
graph_model = make_graph_model(n_vars=n_vars, graph_prior_str=graph_type, edges_per_node=2)
lik = MixtureLinearGaussian(n_vars=n_vars, obs_noise=0.1)
component_lik = LinearGaussian(n_vars=n_vars, obs_noise=0.1)#, hidden_layers=(5,))

In [7]:
from vamsl.inference import VaMSL

# Build the VaMSL object on the processed data and initialize its posteriors
vamsl = VaMSL(
    x=X,
    graph_model=graph_model,
    mixture_likelihood_model=lik,
    component_likelihood_model=component_lik,
)
key, subk = random.split(key)
vamsl.initialize_posteriors(key=subk, n_components=n_components, n_particles=7, linear=linear)

# Report the shape of each posterior returned by get_posteriors()
print('Posterior shapes:')
posts = vamsl.get_posteriors()
print(f'q_z:     {posts[0].shape}')  # [n_components, n_particles, d, l, 2]
if linear:
    print(f'q_theta: {posts[1].shape}')
else:
    # leading dim of n_components
    print(f'q_theta: {len(posts[1])}')
print(f'log_q_c: {posts[2].shape}')  # [n_observations, n_components]
print(f'q_pi:    {posts[3].shape}')  # [n_components,]

Posterior shapes:
q_z:     (2, 7, 11, 11, 2)
q_theta: (2, 7, 11, 11)
log_q_c: (7466, 2)
q_pi:    (2,)


In [13]:
from sklearn.metrics import classification_report, confusion_matrix
# CAVI and SVGD vars:
#   n_cavi_updates - number of iterations of the outer loop below
#   steps          - value passed as `steps` (and `callback_every`) to each particle update
# NOTE(review): n_cavi_updates is 0, so the loop body below never executes. Together with
# the commented-out particle update at the end, this cell as written only calls
# update_responsibilities_and_weights() once. The recorded outputs further down presumably
# come from earlier executions with different values - TODO confirm and restore them.
n_cavi_updates, steps = 0, 100

# CAVI-loop: alternate between the particle update and the responsibility/weight update
for cavi_update in range(n_cavi_updates):
    # Fresh sub-key for this iteration's particle update
    key, subk = random.split(key)
    # Optimize q(Z, \Theta)
    vamsl.update_particle_posteriors(key=subk, steps=steps, callback_every=steps,
                                     callback=vamsl.visualize_callback(), linear=linear)

    # Update to optimal q(c) and q(\pi)
    vamsl.update_responsibilities_and_weights()
    print(f'CAVI update number {cavi_update+1}/{n_cavi_updates}')
    
    # Print current clustering: relabel each observation's argmax component through
    # `order` and compare against the ground-truth indicator
    order = vamsl.identify_MAP_classification_ordering(ground_truth_indicator=indicator)
    y_pred = [order[k] for k in [jnp.argmax(c_i) for c_i in vamsl.get_posteriors()[2]]]
    print('MAP clustering: \n', confusion_matrix(indicator, y_pred))
    
# Final CAVI update with more SVGD steps to ensure annealing onto acyclic graphs
# NOTE(review): the particle update itself is commented out below, so the sub-key split
# here is currently unused (it still advances `key`).
key, subk = random.split(key)
#vamsl.update_particle_posteriors(key=subk, steps=1000, callback_every=200, callback=vamsl.visualize_callback(), linear=linear)
vamsl.update_responsibilities_and_weights()

In [14]:
# Ordering of the components that is optimal with respect to MAP classification
# accuracy against the ground-truth indicator
order = vamsl.identify_MAP_classification_ordering(ground_truth_indicator=indicator)
print('Optimal order:', order, sep='\n')

Optimal order:
[0 1]


In [15]:
from sklearn.metrics import classification_report, confusion_matrix

# Log-responsibilities log q(c): one row per observation, one column per component.
# Fetched once and reused below.
log_q_c = vamsl.get_posteriors()[2]
q_c = jnp.exp(log_q_c)

print('Sums of responsibilities:')
print(jnp.sum(q_c, axis=0))

# Entropy is -sum(q * log q). Entries with q == 0 contribute 0 by convention and are
# masked so that 0 * (-inf) cannot produce NaN.
print('Sum of entropy of responsibilities:')
print(jnp.sum(jnp.where(q_c > 0, -q_c * log_q_c, 0.0)))

y_true = indicator
order = vamsl.identify_MAP_classification_ordering(ground_truth_indicator=indicator)
# MAP component of every observation in one vectorized argmax, relabelled through `order`
# so that predicted labels are comparable with the ground-truth indicator.
map_component = np.asarray(jnp.argmax(log_q_c, axis=1))
y_pred = np.asarray(order)[map_component]

if n_components==2:
    print('Classification report:')
    print(classification_report(y_true=y_true, y_pred=y_pred, target_names=['Component 1', 'Component 2']))
confusion_matrix(y_true, y_pred)

Sums of responsibilities:
[3733. 3733.]
Sum of entropy of responsibilities:
0.0
Classification report:
              precision    recall  f1-score   support

 Component 1       0.99      1.00      0.99      3706
 Component 2       1.00      0.99      0.99      3760

    accuracy                           0.99      7466
   macro avg       0.99      0.99      0.99      7466
weighted avg       0.99      0.99      0.99      7466



array([[3700,    6],
       [  33, 3727]])

In [10]:
# NOTE(review): this cell repeats the classification-report cell directly above it; its
# execution count and recorded output (1000 observations) come from a different run.
# Consider removing it so the notebook reads top to bottom.
from sklearn.metrics import classification_report, confusion_matrix

# Log-responsibilities log q(c): one row per observation, one column per component.
# Fetched once and reused below.
log_q_c = vamsl.get_posteriors()[2]
q_c = jnp.exp(log_q_c)

print('Sums of responsibilities:')
print(jnp.sum(q_c, axis=0))

# Entropy is -sum(q * log q). Entries with q == 0 contribute 0 by convention and are
# masked so that 0 * (-inf) cannot produce NaN.
print('Sum of entropy of responsibilities:')
print(jnp.sum(jnp.where(q_c > 0, -q_c * log_q_c, 0.0)))

y_true = indicator
order = vamsl.identify_MAP_classification_ordering(ground_truth_indicator=indicator)
# MAP component of every observation in one vectorized argmax, relabelled through `order`
# so that predicted labels are comparable with the ground-truth indicator.
map_component = np.asarray(jnp.argmax(log_q_c, axis=1))
y_pred = np.asarray(order)[map_component]

if n_components==2:
    print('Classification report:')
    print(classification_report(y_true=y_true, y_pred=y_pred, target_names=['Component 1', 'Component 2']))
confusion_matrix(y_true, y_pred)

Sums of responsibilities:
[522. 478.]
Sum of entropy of responsibilities:
0.0
Classification report:
              precision    recall  f1-score   support

 Component 1       0.98      0.96      0.97       489
 Component 2       0.96      0.98      0.97       511

    accuracy                           0.97      1000
   macro avg       0.97      0.97      0.97      1000
weighted avg       0.97      0.97      0.97      1000



array([[470,  19],
       [  8, 503]])

In [16]:
from vamsl.metrics import expected_shd, threshold_metrics, neg_ave_log_likelihood

# Get component datasets: rows of X grouped by ground-truth indicator value
datas = [X[(indicator==k).flatten(),:] for k in range(n_components)]

# Fetch the posteriors once; entries 0 and 1 are the per-component q_z and q_theta
posteriors = vamsl.get_posteriors()

# Loop over components and calculate metrics.
# The per-component dataset is bound to `data_k` (not `data`) so that the full dataset
# loaded earlier in the notebook under the name `data` is not overwritten.
for k, (data_k, q_z_k, q_theta_k) in enumerate(zip(datas, posteriors[0], posteriors[1])):
    # Get particle distribution for component
    # NOTE(review): the particles are those of component k, while get_E() and the
    # ground truth are indexed by order[k] - confirm this indexing is intended.
    q_g_k = vamsl.particle_to_g_lim(q_z_k, vamsl.get_E()[order[k]])
    dist = vamsl.get_empirical(q_g_k, q_theta_k)

    # Calculate metrics against the ground-truth graph matched to this component
    eshd = expected_shd(dist=dist, g=ground_truth_graphs[order[k]])
    auroc = threshold_metrics(dist=dist, g=ground_truth_graphs[order[k]])['roc_auc']
    # Negative average log-likelihood is disabled: it needs held-out data (`x_ho`),
    # which the visible cells of this notebook do not construct.
    #negll = neg_ave_log_likelihood(dist=dist, eltwise_log_likelihood=vamsl.eltwise_component_log_likelihood_observ, x=data_k.x_ho)

    print(f' Component {k+1:4d} |  E-SHD: {eshd:4.1f}    AUROC: {auroc:5.2f}')#    neg. LL {negll:5.2f}')

 Component    1 |  E-SHD: 29.0    AUROC:  0.55
 Component    2 |  E-SHD: 31.5    AUROC:  0.48


### Performance of DiBS

In [10]:
# NOTE(review): this cell repeats the VaMSL metrics cell above verbatim and still evaluates
# the `vamsl` object; its recorded output (a single component) comes from a different run.
# If it is meant to report DiBS, it needs its own fitted model - TODO confirm.
from vamsl.metrics import expected_shd, threshold_metrics, neg_ave_log_likelihood

# Get component datasets: rows of X grouped by ground-truth indicator value
datas = [X[(indicator==k).flatten(),:] for k in range(n_components)]

# Fetch the posteriors once; entries 0 and 1 are the per-component q_z and q_theta
posteriors = vamsl.get_posteriors()

# Loop over components and calculate metrics.
# The per-component dataset is bound to `data_k` (not `data`) so that the full dataset
# loaded earlier in the notebook under the name `data` is not overwritten.
for k, (data_k, q_z_k, q_theta_k) in enumerate(zip(datas, posteriors[0], posteriors[1])):
    # Get particle distribution for component
    q_g_k = vamsl.particle_to_g_lim(q_z_k, vamsl.get_E()[order[k]])
    dist = vamsl.get_empirical(q_g_k, q_theta_k)

    # Calculate metrics against the ground-truth graph matched to this component
    eshd = expected_shd(dist=dist, g=ground_truth_graphs[order[k]])
    auroc = threshold_metrics(dist=dist, g=ground_truth_graphs[order[k]])['roc_auc']
    # Negative average log-likelihood is disabled: it needs held-out data (`x_ho`),
    # which the visible cells of this notebook do not construct.
    #negll = neg_ave_log_likelihood(dist=dist, eltwise_log_likelihood=vamsl.eltwise_component_log_likelihood_observ, x=data_k.x_ho)

    print(f' Component {k+1:4d} |  E-SHD: {eshd:4.1f}    AUROC: {auroc:5.2f}')#    neg. LL {negll:5.2f}')

 Component    1 |  E-SHD: 21.0    AUROC:  0.54
