## Stable Differentiable Causal Discovery (SDCD) Tutorial

In this brief tutorial, we will go over the main API of the `sdcd` package. SDCD is a differentiable causal discovery method designed to scale reliably to hundreds or thousands of variables. The algorithm is implemented in vanilla PyTorch and also depends on the `networkx` package.

(Note: the default pip installation only installs the packages needed to run the `SDCD` model. To also run the other causal discovery methods implemented in this package, install the `benchmark` extra with `pip install sdcd[benchmark]`.)

In [1]:
# Imports
from sdcd.models import SDCD
from sdcd.utils import create_intervention_dataset

We will start by simulating some data for the tutorial. If you have your own data, you can ignore this step.

In [2]:
# Simulate Data
from sdcd.simulated_data import random_model_gaussian_global_variance # For demonstration

n = 200
n_per_intervention = 50
d = 20
n_edges = 20

true_causal_model = random_model_gaussian_global_variance(
    d,
    n_edges,
    dag_type="ER",
    scale=0.5,
    hard=True,
)
X_df = true_causal_model.generate_dataframe_from_all_distributions(
    n_samples_control=n,
    n_samples_per_intervention=n_per_intervention,
)
# Normalize the data: z-score every column except the last one (perturbation_label)
feature_cols = X_df.columns[:-1]
X_df[feature_cols] = (X_df[feature_cols] - X_df[feature_cols].mean()) / X_df[feature_cols].std()
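The normalization above relies on the label column being the last column of the DataFrame. If your own data orders its columns differently, a minimal sketch that selects the feature columns by name instead might look like the following. The helper `standardize_features` is hypothetical (it is not part of `sdcd`), and it assumes the label column is called `perturbation_label`, as in this tutorial.

```python
import pandas as pd


def standardize_features(df: pd.DataFrame, label_col: str = "perturbation_label") -> pd.DataFrame:
    """Return a copy of df with every column except label_col z-scored."""
    out = df.copy()
    features = out.columns.drop(label_col)  # all columns except the label
    # pandas .std() uses ddof=1 by default, matching the cell above
    out[features] = (out[features] - out[features].mean()) / out[features].std()
    return out


# Toy example with the label column deliberately placed in the middle.
toy = pd.DataFrame({
    "0": [1.0, 2.0, 3.0],
    "perturbation_label": ["obs", "obs", "obs"],
    "1": [10.0, 20.0, 60.0],
})
toy_std = standardize_features(toy)
print(toy_std)
```

Because the function works on a copy and selects columns by name, the original DataFrame and its column order are left untouched.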

The input data should be formatted as a pandas DataFrame in which each row corresponds to one observation and each column corresponds to one variable. There should be one additional column that reports which variable(s) were intervened on for the given observation; here it is labeled `perturbation_label`. For rows without any intervention, the value should be set to `"obs"`.
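If you are bringing your own data, a minimal sketch of a DataFrame in this layout is shown below. The toy values are random, and labeling an intervened row with the name of the intervened column (here `"1"`) is an assumption on our part: the tutorial only shows `"obs"` rows, so check the expected label format (including how multi-variable interventions are encoded) in the package documentation.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Hypothetical toy data: 3 variables, 4 observational rows, and 2 rows in
# which variable "1" was (by assumption) intervened on.
n_obs, n_int = 4, 2
my_df = pd.DataFrame(rng.normal(size=(n_obs + n_int, 3)), columns=["0", "1", "2"])

# One extra column naming the intervened variable, or "obs" for no intervention.
my_df["perturbation_label"] = ["obs"] * n_obs + ["1"] * n_int

print(my_df["perturbation_label"].value_counts())
```

A DataFrame built this way would then be passed to `create_intervention_dataset(my_df, perturbation_colname="perturbation_label")`, exactly as in the cells that follow.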

In [3]:
X_df.head()

|   | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | perturbation_label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.871384 | -1.298243 | -0.541056 | -0.839735 | 0.385079 | -0.172131 | 0.417135 | 0.439012 | 1.084618 | -0.942723 | ... | -1.29324 | -0.20941 | -1.180888 | -0.450958 | 2.148383 | 3.076456 | -1.035478 | -1.569835 | 0.8642 | obs |
| 1 | 1.21953 | -1.217384 | -0.719318 | 0.067425 | 0.049899 | 1.223961 | -0.048282 | -0.314315 | -0.728916 | -0.923227 | ... | -1.319596 | -0.001921 | -1.191708 | 0.904951 | 1.143647 | -0.79736 | -1.061556 | 0.214472 | -0.110571 | obs |
| 2 | -1.060095 | -0.105745 | 0.132579 | -0.561208 | -0.673707 | 1.175806 | 1.349005 | 1.170219 | -0.030465 | -0.292935 | ... | -0.647847 | -0.284626 | -0.109206 | -0.166495 | -0.533553 | -1.087932 | -0.174844 | -0.665772 | -0.233534 | obs |
| 3 | 0.386371 | -0.478593 | -0.286767 | -0.131907 | -0.096226 | 0.747864 | -0.326589 | -0.49389 | 0.80401 | -0.317977 | ... | -0.308801 | -0.133661 | -0.283226 | -0.407182 | 0.38518 | -0.711048 | -0.355081 | 0.509291 | 0.335167 | obs |
| 4 | -0.459966 | 1.277083 | 1.056286 | -0.432633 | 0.393007 | -1.427799 | -0.608757 | -0.218208 | 1.930627 | 0.540725 | ... | 1.256948 | -0.097415 | 0.927841 | -1.633421 | 1.263645 | 0.780632 | 0.906074 | -0.496829 | 1.064864 | obs |


Next, to construct a PyTorch dataset with the appropriate tensors, we use the utility function `sdcd.utils.create_intervention_dataset`.

(Note: for users working with the AnnData format, the package also provides `sdcd.utils.create_dataset_anndata`.)

In [4]:
X_dataset = create_intervention_dataset(X_df, perturbation_colname="perturbation_label")
X_dataset

<torch.utils.data.dataset.TensorDataset at 0x28fe46bb0>

Now we are ready to train the SDCD model. Here we set `finetune=True`, which runs the algorithm until the final adjacency matrix has been estimated and binarized, and then adds a final stage that trains the model to convergence with that fixed, binarized adjacency matrix. If you only care about the predicted graph, set it to `False` to shorten the runtime.

In [5]:
model = SDCD()
model.train(X_dataset, finetune=True)

Epoch 0: loss=27.59, gamma=0.00
Epoch 100: loss=7.35, gamma=0.00
Epoch 200: loss=4.89, gamma=0.00
Epoch 300: loss=3.45, gamma=0.00
Epoch 400: loss=2.71, gamma=0.00
Epoch 500: loss=2.15, gamma=0.00
Epoch 600: loss=1.88, gamma=0.00
Epoch 700: loss=1.62, gamma=0.00
Epoch 800: loss=1.48, gamma=0.00
Epoch 900: loss=1.28, gamma=0.00
Epoch 1000: loss=1.16, gamma=0.00
Epoch 1100: loss=1.03, gamma=0.00
Epoch 1200: loss=0.93, gamma=0.00
Epoch 1300: loss=0.83, gamma=0.00
Epoch 1400: loss=0.72, gamma=0.00
Epoch 1500: loss=0.66, gamma=0.00
Epoch 1600: loss=0.57, gamma=0.00
Epoch 1700: loss=0.49, gamma=0.00
Epoch 1800: loss=0.44, gamma=0.00
Epoch 1900: loss=0.36, gamma=0.00
Fraction of possible edges in mask: 0.1775
Epoch 0: loss=28.14, gamma=0.00
Epoch 100: loss=14.32, gamma=0.50
Epoch 200: loss=10.18, gamma=1.00
Epoch 300: loss=10.33, gamma=1.50
Epoch 400: loss=10.56, gamma=2.00
Epoch 500: loss=10.73, gamma=2.50
Epoch 600: loss=11.13, gamma=3.00
Epoch 700: loss=11.63, gamma=3.50
Epoch 800: loss=12

Lastly, we can recover the predicted adjacency matrix, both with and without thresholding, and compute the negative log-likelihood of a dataset under the fitted model.

In [6]:
adj_matrix = model.get_adjacency_matrix(threshold=True)
print(adj_matrix)

[[0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0 0 1 0 1 0 0 0 0 0 0]
 [0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [1 0 1 0 1 1 0 0 0 1 0 0 1 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1]
 [0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0]]
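Since the package already depends on `networkx`, one natural next step is to load the binary matrix into a `networkx.DiGraph` for downstream analysis (checking acyclicity, listing parents, ordering variables). The sketch below uses a small hypothetical 4-node matrix in place of `adj_matrix`, and assumes that a 1 in row `i`, column `j` encodes the edge `i -> j`; the tutorial does not state the orientation convention, so verify it before interpreting parents and children.

```python
import numpy as np
import networkx as nx

# Hypothetical 4-variable binary adjacency matrix standing in for
# model.get_adjacency_matrix(threshold=True).
# Assumption: entry [i, j] == 1 encodes an edge i -> j.
adj = np.array([
    [0, 1, 0, 0],
    [0, 0, 1, 1],
    [0, 0, 0, 1],
    [0, 0, 0, 0],
])

# Rows are treated as sources and columns as targets.
G = nx.from_numpy_array(adj, create_using=nx.DiGraph)

print(G.number_of_edges())              # 4
print(nx.is_directed_acyclic_graph(G))  # True
print(list(nx.topological_sort(G)))     # [0, 1, 2, 3]
print(sorted(G.predecessors(3)))        # parents of node 3: [1, 2]
```

With the real output, the same call would be `nx.from_numpy_array(adj_matrix, create_using=nx.DiGraph)`, provided `adj_matrix` is a NumPy array as the printed output suggests.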


In [7]:
adj_matrix = model.get_adjacency_matrix(threshold=False)
print(adj_matrix)

[[0.00000000e+00 0.00000000e+00 0.00000000e+00 1.82328773e+00
  0.00000000e+00 0.00000000e+00 2.19770358e-03 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 2.68300413e-04 0.00000000e+00 4.15186951e-04
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 6.29337847e-01 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 2.77099991e+00
  0.00000000e+00 2.14771581e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 2.14851931e-01 0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 3.13774276e+00
  0.00000000e+00 0.00000000e+00 3.80197715e-04 0.00000000e+00
  0.00000000e+00 2.43971852e-04 1.91857851e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 3.68017762e-04
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  3.6
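To see how the two outputs relate, the sketch below copies the non-zero entries of the first three rows of the unthresholded matrix printed above and binarizes them with a fixed absolute cutoff. The cutoff of 0.3 is our assumption, not a value taken from `sdcd`: the tutorial does not state how `threshold=True` chooses its cutoff. On these three rows, any cutoff between roughly 0.22 and 0.62 reproduces the thresholded rows printed earlier.

```python
import numpy as np

# First three rows of the unthresholded matrix printed above;
# every entry not listed here is 0.
W = np.zeros((3, 20))
W[0, [3, 6, 13, 15]] = [1.82328773, 2.19770358e-03, 2.68300413e-04, 4.15186951e-04]
W[1, [5, 11, 13, 17]] = [6.29337847e-01, 2.77099991, 2.14771581, 2.14851931e-01]
W[2, [3, 6, 9, 10, 15]] = [3.13774276, 3.80197715e-04, 2.43971852e-04, 1.91857851, 3.68017762e-04]

# Hypothetical absolute cutoff (not necessarily what SDCD uses internally).
cutoff = 0.3
A = (W > cutoff).astype(int)

# Print the column indices of the retained edges in each row.
for i, row in enumerate(A):
    print(i, np.flatnonzero(row).tolist())
# 0 [3]
# 1 [5, 11, 13]
# 2 [3, 10]
```

Note that small weights such as the 0.21 in row 1 (column 17) are dropped by the binarization, which is why the thresholded graph is sparser than the set of non-zero weights.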

In [8]:
model.compute_nll(X_dataset) # Reports average negative log-likelihood

7.453802591959636
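To put this number in context, one crude reference point is the average NLL under a model with no edges at all, in which every standardized variable is an independent standard Gaussian. The tutorial does not define exactly how `compute_nll` aggregates (per sample or per variable, and how intervened variables are handled), so the comparison below assumes it reports a per-sample NLL summed over variables; treat it as a rough sanity check only. The helper `independent_gaussian_nll` is hypothetical and not part of `sdcd`.

```python
import numpy as np


def independent_gaussian_nll(X: np.ndarray) -> float:
    """Average over rows of the NLL of X under independent N(0, 1) variables."""
    n, d = X.shape
    # Per row: sum_j [0.5 * x_j^2 + 0.5 * log(2*pi)]
    per_row = 0.5 * np.sum(X**2, axis=1) + 0.5 * d * np.log(2 * np.pi)
    return float(per_row.mean())


# For standardized data, mean(x^2) is close to 1 in each column, so the
# baseline is roughly d/2 * (1 + log(2*pi)); for d = 20 that is about 28.4.
d = 20
approx_baseline = d / 2 * (1 + np.log(2 * np.pi))
print(round(approx_baseline, 2))  # 28.38
```

On the tutorial data, `independent_gaussian_nll(X_df.iloc[:, :-1].to_numpy())` would give the empirical version of this baseline; a fitted graph that captures real dependencies should score well below it.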

This concludes the tutorial. Please open an issue on the GitHub repo if you run into any problems. We hope you find this package useful in your research!