In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.transforms import (
    MonotonicAffineTransform,
    MonotonicRQSTransform,
)
from zuko.flows import UnconditionalDistribution
from torch.distributions import Cauchy, Normal, Laplace, Bernoulli, Uniform
from causalflows.flows import CausalFlow
from causal_cocycle.causalflow_helper import select_and_train_flow
import copy
import numpy as np
from scipy.stats import betaprime, norm
from architectures import get_nsf_transforms, get_maf_transforms

In [2]:
# ── Data ─────────────────────────────────────────────────────
# Seed torch's global RNG before anything else in this cell runs.
seed = 0
torch.manual_seed(seed)

N_train = 1000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Indicator vector: the first half of the entries are 1, the second half 0.
half = N_train // 2
B = torch.cat((torch.ones(half), torch.zeros(half)))

# Outcome vectors derived from B (no noise is added here).
Y0 = B        # same tensor object as B, not a copy
Y1 = B * 2    # values in {0, 2}

# ── Build base and transforms ───────────────────────────────
d = 1   # dim(Y)

# Two base distributions, one per outcome; both are Laplace with loc 0 and
# scale 1. Swap the distribution class here to try e.g. Normal or Cauchy.
# NOTE(review): buffer=True presumably registers loc/scale as buffers rather
# than trainable parameters — confirm against the zuko documentation.
base0 = UnconditionalDistribution(Laplace, loc=torch.zeros(d), scale=torch.ones(d), buffer=True)
base1 = UnconditionalDistribution(Laplace, loc=torch.zeros(d), scale=torch.ones(d), buffer=True)

# Candidate transforms; each helper returns a pair, unpacked as (…0, …1).
nsf0, nsf1 = get_nsf_transforms()
maf0, maf1 = get_maf_transforms()


def _build_candidate_flows(transforms, base):
    """Wrap each transform in a CausalFlow on `device` and return independent copies.

    Every candidate is deep-copied so that no two entries of the returned list
    share a module instance (the base in particular), and so that the
    transforms passed in are not the very objects held by the candidates.

    Args:
        transforms: iterable of transform modules, one per candidate.
        base: base distribution module paired with every transform.

    Returns:
        A list of CausalFlow objects in the same order as `transforms`.
    """
    return [
        copy.deepcopy(CausalFlow(transform=transform, base=base).to(device))
        for transform in transforms
    ]


# Candidate order is [MAF, NSF] in both lists, so index i refers to the same
# architecture for flows0 and flows1.
flows0 = _build_candidate_flows([maf0, nsf0], base0)
flows1 = _build_candidate_flows([maf1, nsf1], base1)

In [3]:

# CV + retrain across all transforms, once per outcome.
# Both runs use exactly the same settings, so they are kept in one place.
cv_train_config = dict(
    train_fraction=0.5,
    k_folds=2,
    num_epochs=1000,
    batch_size=128,
    lr=1e-2,
)

# unsqueeze(-1) appends a trailing axis to the outcome tensor before it is
# handed to the helper.
# NOTE(review): the flows were moved to `device` but the data tensors are not;
# confirm select_and_train_flow handles the transfer when running on GPU.
flow0, test_nll0, best_idx0, cv_scores0 = select_and_train_flow(
    flows0, Y0.unsqueeze(-1), **cv_train_config
)

flow1, test_nll1, best_idx1, cv_scores1 = select_and_train_flow(
    flows1, Y1.unsqueeze(-1), **cv_train_config
)

In [4]:
# Add the two CV score vectors element-wise and take the index of the
# smallest total as the jointly selected architecture.
joint_cv_scores = np.add(np.array(cv_scores1), np.array(cv_scores0))
best_architecture = joint_cv_scores.argmin()
print(best_architecture)

1
