## Short example of how to use pretrained models to perform Causal Discovery

In [25]:
import pandas as pd
import torch
import torch.multiprocessing
from helpers.tools import load_river_data
from helpers.tools import lagged_batch_corr
from model.model_wrapper import Architecture_PL
import torchmetrics

In [60]:
# Example data (Rivers). Download it beforehand as described in readme.md.
# `data[0]` is used below as the time series and `data[1]` as the labels.
data = load_river_data()


# Alternatively, use random data just to check the input/output relation:
# a (1, 600, 3) series tensor together with a (1, 3, 3) all-zero label tensor.
#data = (torch.rand((1,600,3)), torch.zeros((1,3,3)))

In [61]:
# Checkpoints of the pretrained networks; these are the weights that were
# selected from the experiments with synthetic data.
path = "pretrained_weights"

# One checkpoint file per architecture, in the order they are evaluated.
_checkpoint_files = (
    "/mlp.ckpt",
    "/unidirectional.ckpt",
    "/bidirectional.ckpt",
    "/convMixer.ckpt",
    "/transformer.ckpt",
)
best = [path + file_name for file_name in _checkpoint_files]

# Keep the individual paths available under their short names as well.
mlp, uni, bi, conv, trf = best

In [62]:
# Threshold-free metric: AUROC evaluates the predicted link probabilities
# directly against the binary labels, so no decision cutoff has to be chosen.
auroc = torchmetrics.classification.BinaryAUROC()

In [None]:

# The networks were trained with exactly 5 input variables and accept at most
# 600 time steps, so the example data has to be brought into that format.
N_MODEL_VARS = 5
MAX_STEPS = 600

# First (and only) sample of the batch, indexed as [time, variable].
series = data[0][0]
n_missing = N_MODEL_VARS - series.shape[1]
if n_missing < 0:
    raise ValueError(
        f"The pretrained networks accept at most {N_MODEL_VARS} variables, "
        f"got {series.shape[1]}."
    )

# Pad samples with fewer than 5 variables with low-variance Gaussian noise
# columns. The number of columns is derived from the data instead of being
# fixed to 2, so inputs with 1 to 5 variables are handled correctly.
# The noise is unseeded, so the padding differs between runs.
if n_missing:
    noise = torch.normal(0, 0.1, (series.shape[0], n_missing))
    series = torch.concat([series, noise], dim=1)

# Restore the leading batch dimension expected by the models.
X = series.unsqueeze(0)

# For some models, the lagged batch correlation is required as an additional
# input. Generate it here.
# NOTE(review): this is computed on the full-length series, i.e. before X is
# cut to 600 steps below -- confirm that this is intended. The second argument
# is presumably the number of lags; verify against helpers.tools.
corr = lagged_batch_corr(X, 3)

# Networks only take a maximum of 600 steps. Here we simply take the first 600.
# Alternatively one could also attempt to weigh multiple windows of the same
# time series.
X = X[:, :MAX_STEPS, :]

In [64]:
# Number of real (non-padded) variables in the data set; used to strip the
# predictions that belong to the padded noise series.
n_vars = data[0].shape[-1]

# One link-probability matrix of shape (n_vars, n_vars) per pretrained model,
# in the same order as `best`.
preds = []

# Pure inference: disabling autograd avoids building a graph for every forward
# pass and yields plain tensors that do not carry gradient history.
with torch.no_grad():
    for ckpt_path in best:
        # Load the specific model from its checkpoint. Mapping to the CPU lets
        # checkpoints that were saved on a GPU load on CPU-only machines too.
        model = Architecture_PL.load_from_checkpoint(ckpt_path, map_location="cpu")
        # Use the wrapped network directly, on the CPU and in evaluation mode.
        M = model.model.to("cpu").eval()
        # Run the model and transform the raw outputs to probabilities.
        Y = torch.sigmoid(M((X, corr)))
        # Remove the predictions for the padded time series.
        Y = Y[0, :n_vars, :n_vars]
        # As the labels in this case do not specify a lag, reduce the lag
        # dimension by keeping only its last entry.
        Y = Y[:, :, -1]
        preds.append(Y)

In [66]:
# Now we can take the output as probabilities for a certain link to exist.
# Thresholding gives a binary estimate; shown here for preds[1], i.e. the
# unidirectional model (second entry of `best`), with a cutoff of 0.05.
print(preds[1] > 0.05)

# Or calculate the AUROC against the labels in data[1].
# NOTE: `Y` still holds the prediction of the last model evaluated in the loop
# (the transformer), so this scores that model only, not preds[1].
auroc(Y, data[1])

tensor([[ True,  True, False],
        [ True,  True,  True],
        [False,  True,  True]])


tensor(1.0000)