# Independent vs. Non-independent performance testing

In [1]:
import torch
from torch.distributions import Distribution
import torch.nn as nn
from torch.distributions import Normal, Laplace, Bernoulli, Gamma,Cauchy
from causalflows.flows import CausalFlow
from typing import Callable
from zuko.flows import UnconditionalDistribution
from sklearn.model_selection import KFold
import copy
from torch.utils.data import DataLoader, TensorDataset
from architectures import get_stock_transforms
from csuite import SCMS, SCM_DIMS, SCM_MASKS
from causal_cocycle.causalflow_helper import select_and_train_flow, sample_do, sample_cf
from causal_cocycle.helper_functions import ks_statistic, wasserstein1_repeat, rmse

In [2]:
# Look up the 5-variable linear chain simulator and its dimension in the csuite registries.
scm = SCMS['chain5_linear']
dims = SCM_DIMS['chain5_linear']
# ----------------------------------------
# 1) Define sample size and SCM dimension
# ----------------------------------------
N = 1000  # number of training samples drawn from the SCM
# The dimension (5 for this chain) is carried by `scm` / `dims`; it is not passed separately.

# --------------------------------------------------
# 2) Build noise distributions and transforms
# --------------------------------------------------
# The noise is specified as two groups: U1 on its own, and (U2, ..., U5) jointly.
# Each group gets a base distribution and a transform applied to the raw draws.

# (a) U1 ~ N(0,1)
dist_u1 = Normal(0,1)

# (b) Base distribution for (U2, ..., U5) is also N(0,1).
#     Presumably the simulator draws an (N, 4) block from it - TODO confirm in csuite.
dist_base = Normal(0,1)

# (c) tf_u1: reshapes the raw draw into a column, i.e. any shape -> (-1, 1); values are unchanged.
tf_u1 = lambda raw1: raw1.view(-1, 1)

# (d) tf_rest: linear map on the 4-dimensional raw noise that induces an equicorrelation
#     structure with off-diagonal correlation `rho` among U2..U5.
#     rho = 0.0 is the independent-noise setting: Σ is the identity, so L is the identity and
#     tf_rest leaves the raw draws unchanged. Cholesky needs Σ positive definite, which for a
#     4x4 equicorrelation matrix means -1/3 < rho < 1.
rho = 0.0
Σ = (1-rho) * torch.eye(4) + rho * torch.ones(4, 4)
L = torch.linalg.cholesky(Σ)  # lower-triangular Cholesky factor, shape (4,4)

# U_rest = raw_rest @ L^T has covariance L L^T = Σ when the rows of raw_rest are standard normal.
tf_rest = lambda raw_rest: raw_rest @ L.t()

# Package both groups into parallel lists of length 2 (same order: U1 group, then U2..U5 group).
noise_dists      = [dist_u1, dist_base]
noise_transforms = [tf_u1,   tf_rest]

# --------------------------------------------------
# 3) Draw data from the 5-chain linear SCM
# --------------------------------------------------
#    - No intervention (intervention_node=None, intervention_value=None)
#    - return_u=True so the call returns the pair (X, U)

X_full, U_full = scm(
    N=N,
    seed=42,
    intervention_node=None,
    intervention_value=None,
    return_u=True,
    noise_dists=noise_dists,
    noise_transforms=noise_transforms
)

# Both X_full and U_full print as (1000, 5) in the recorded output.
# The correlation check below reads U_full[:, 1] as U2 and U_full[:, 2] as U3.

# Note: the label printed here is "X.shape", but the tensor is X_full.
print("X.shape =", X_full.shape)
print("U_full.shape =", U_full.shape)

# --------------------------------------------------
# 4) Quick sanity check: empirical correlation between U2 and U3
# --------------------------------------------------
# torch.corrcoef treats each row of the stacked (2, N) tensor as a variable;
# entry [0, 1] is the U2-U3 correlation, which should be close to `rho`.
U2 = U_full[:, 1]
U3 = U_full[:, 2]
emp_corr_23 = torch.corrcoef(torch.stack([U2, U3]))[0,1].item()
print(f"Empirical corr(U₂, U₃) ≈ {emp_corr_23:.3f}  (target ρ = {rho})")


X.shape = torch.Size([1000, 5])
U_full.shape = torch.Size([1000, 5])
Empirical corr(U₂, U₃) ≈ -0.014  (target ρ = 0.0)


# Training flow model

In [3]:
# Training hyperparameters forwarded to select_and_train_flow below.
num_epochs = 1000
k_folds = 2
batch_size = 64
lr = 1e-2
# Candidate transforms for an unconditional model over all `dims` variables (x_dim=0: no context).
# The recorded cv_scores output has 4 entries, consistent with 4 candidates.
transforms = get_stock_transforms(x_dim=0, y_dim=dims)  # original note: a list of 4 different MAF instances
# Base distribution class; instantiated lazily by UnconditionalDistribution with the given loc/scale.
Base = Normal
base = UnconditionalDistribution(
    Base, loc=torch.zeros(dims), scale=torch.ones(dims), buffer=True
)
# One CausalFlow per candidate transform; every flow shares the same `base` object.
flows = [CausalFlow(transform=maf, base=base)
         for maf in transforms
        ]
# CV + retrain across all transforms.
# train_fraction=1.0 is passed, so presumably no hold-out test split is carved out of X_full
# (TODO confirm how `test_nll` is computed in that case).
best_flow, test_nll, best_idx, cv_scores = select_and_train_flow(
    flows, X_full, train_fraction=1.0, k_folds=k_folds,
    num_epochs=num_epochs, batch_size=batch_size, lr=lr,
    device=X_full.device
)

In [4]:

# ───────────────────────────────────────────────────────────────
# 1. Re-generate observational data (columns: X1 first, then the 4 downstream variables)
# ───────────────────────────────────────────────────────────────
# NOTE: this rebinds the global N (it was 1000 for the training sample).
N = 10000
X_obs_full, _ = scm(
    N=N,
    seed=0,
    intervention_node=None,
    return_u=True,
    noise_dists=noise_dists,
    noise_transforms=noise_transforms
)

X_obs = X_obs_full[:, :1]     # first column, kept 2-D
Y_obs = X_obs_full[:, 1:]     # remaining columns

# ───────────────────────────────────────────────────────────────
# 2. Generate ground-truth interventional data under do(X1 = c)
# ───────────────────────────────────────────────────────────────
# The seed (0) is the same as for the observational draw above; presumably this makes the
# simulator reuse the same noise, so these rows are also the true counterfactuals of
# X_obs_full - TODO confirm in csuite.
# NOTE(review): the simulator is called with intervention_node=1 while sample_do / sample_cf
# below use index=0; this looks like 1-based vs 0-based numbering of the same variable - confirm.
c = 0.0
X_int_full = scm(
    N=N,
    seed=0,
    intervention_node=1,
    intervention_value=c,
    return_u=False,
    noise_dists=noise_dists,
    noise_transforms=noise_transforms
)
Y_true_int = X_int_full[:, 1:]  # downstream columns only

# ───────────────────────────────────────────────────────────────
# 3. Generate model interventional samples using sample_do
# ───────────────────────────────────────────────────────────────
device = X_obs.device
# NOTE(review): `flow_model` is not used again in this cell; the calls below pass `best_flow`.
# nn.Module.to moves the module in place, so both names refer to the same moved module.
flow_model = best_flow.to(device)

with torch.no_grad():
    # intervention_fn replaces the old value of the intervened coordinate by the constant c.
    X_model_int = sample_do(
        flow=best_flow,
        index=0,
        intervention_fn=lambda old: torch.full_like(old, c),
        sample_shape=torch.Size([N])
    )
Y_model_int = X_model_int[:, 1:]

# ───────────────────────────────────────────────────────────────
# 4. Compute marginal KS statistics
# ───────────────────────────────────────────────────────────────
# One KS statistic per downstream variable (4 columns, hard-coded), model vs. simulator samples.
ks_vals = []
for j in range(4):
    ks_j = ks_statistic(Y_model_int[:, j].cpu(), Y_true_int[:, j].cpu())
    ks_vals.append(ks_j)
print("Interventional KS per dimension:", ks_vals)

# ───────────────────────────────────────────────────────────────
# 5. Generate counterfactual predictions
# ───────────────────────────────────────────────────────────────
# Counterfactuals are computed for the observed rows themselves (x_obs=Z_obs), not fresh samples.
Z_obs = X_obs_full.to(device)
with torch.no_grad():
    Z_cf = sample_cf(
        flow=best_flow,
        x_obs=Z_obs,
        index=0,
        intervention_fn=lambda old: torch.full_like(old, c)
    )
Y_cf_pred = Z_cf[:, 1:]

# Ground-truth counterfactuals for the same rows.
# NOTE(review): this call has exactly the same arguments as the one that produced X_int_full
# in step 2, so it appears to be a duplicate simulation - confirm before removing, in case
# the simulator call also affects global RNG state.
Z_cf_true = scm(
    N=N,
    seed=0,  # same seed as the observational draw
    intervention_node=1,
    intervention_value=c,
    return_u=False,
    noise_dists=noise_dists,
    noise_transforms=noise_transforms
).to(device)
Y_cf_true = Z_cf_true[:, 1:]

# ───────────────────────────────────────────────────────────────
# 6. Compute RMSE per dimension
# ───────────────────────────────────────────────────────────────
# nanmean skips NaN entries, so rows where the model returns NaN do not count towards the RMSE.
rmse_vals = []
for j in range(4):
    rmse_j = torch.sqrt(((Y_cf_pred[:, j] - Y_cf_true[:, j]) ** 2).nanmean()).item()
    rmse_vals.append(rmse_j)
print("Counterfactual RMSE per dimension:", rmse_vals)


Interventional KS per dimension: [0.018599987030029297, 0.1550999879837036, 0.11640000343322754, 0.0974000096321106]
Counterfactual RMSE per dimension: [0.007982463575899601, 0.003497605910524726, 0.051112715154886246, 0.1640557199716568]


In [5]:
# Cross-validation scores returned by select_and_train_flow; presumably one per candidate
# flow, in the order of `flows` - TODO confirm.
cv_scores

[8.52140998840332, 9.243063926696777, 12.75081491470337, 11.118023872375488]

## Training Cocycle Models

In [6]:
# Split the training sample into a 1-column context and the remaining 4 output columns.
# NOTE: this rebinds Y_obs, which an earlier cell used for the evaluation sample.
X_ctx = X_full[:, 0:1]   # first column, kept 2-D
Y_obs = X_full[:, 1:]    # remaining columns

# 2) Stock autoregressive transforms for a conditional model of 4 outputs given 1 context variable
#    (x_dim = context size, y_dim = output size).
transforms = get_stock_transforms(x_dim=1, y_dim=4)  # original note: a list of 4 different MAF instances

# 3) Wrap each candidate into a ZukoCocycleModel
from causal_cocycle.model_new import ZukoCocycleModel

# One model per candidate architecture. nn.ModuleList(maf) iterates over `maf`, so each entry
# of `transforms` must itself be an iterable of modules.
cocycle_models = [
    ZukoCocycleModel(nn.ModuleList(maf))  
    for maf in transforms
]
# Original alternative, kept for reference (it wraps the entry in a list instead of iterating it):
# cocycle_model = ZukoCocycleModel(nn.ModuleList([transforms[0]]))
# Only the third candidate (index 2) is trained below; no cross-validation is run in this cell.
model0 = cocycle_models[2]

# NOTE(review): random_split is imported here but not used in this cell.
from torch.utils.data import random_split

# 4) Build a CMMD-V loss via CocycleLossFactory
from causal_cocycle.loss_factory import CocycleLossFactory
from causal_cocycle.kernels import gaussian_kernel

# Two Gaussian kernels are supplied (original note: one on X, one on U).
kernel = [gaussian_kernel(), gaussian_kernel()]
loss_factory = CocycleLossFactory(kernel)

# No train/validation split is performed: the whole training sample is used both for
# building the loss and for optimisation.
X_train = X_ctx
Y_train = Y_obs

# Build the CMMD-V loss. Per the original note, build_loss sets the kernel lengthscales with
# a median heuristic on (X_train, Y_train) - TODO confirm in loss_factory.
# subsamples=10000 is larger than the 1000 rows of X_train.
cmmdv_loss = loss_factory.build_loss("CMMD_V", X_train, Y_train, subsamples=10000)

# 5) Call optimise(...)
from causal_cocycle.optimise_new import optimise

# Hyperparameters forwarded verbatim to optimise below.
learn_rate       = 1e-3
epochs           = 1000
weight_decay     = 0.0
batch_size       = 64
val_batch_size   = 256
scheduler        = False
schedule_milestone = 10
lr_mult          = 0.90
print_           = True
plot             = False
likelihood_param_opt = False
likelihood_param_lr  = 0.01

# Run training. No validation data is supplied (inputs_val / outputs_val / loss_val are None),
# which matches the recorded log showing only a training loss per epoch.
best_model, history = optimise(
    model            = model0,
    loss_tr          = cmmdv_loss,
    inputs           = X_train,
    outputs          = Y_train,
    inputs_val       = None,
    outputs_val      = None,
    learn_rate       = learn_rate,
    epochs           = epochs,
    weight_decay     = weight_decay,
    batch_size       = batch_size,
    val_batch_size   = val_batch_size,
    scheduler        = scheduler,
    schedule_milestone = schedule_milestone,
    lr_mult          = lr_mult,
    print_           = print_,
    plot             = plot,
    likelihood_param_opt = likelihood_param_opt,
    likelihood_param_lr  = likelihood_param_lr,
    loss_val         = None  # no separate validation loss is passed
)

Epoch 1/1000, Training Loss: -0.4279
Epoch 2/1000, Training Loss: -0.4169
Epoch 3/1000, Training Loss: -0.4347
Epoch 4/1000, Training Loss: -0.4793
Epoch 5/1000, Training Loss: -0.4791
Epoch 6/1000, Training Loss: -0.5147
Epoch 7/1000, Training Loss: -0.5290
Epoch 8/1000, Training Loss: -0.5600
Epoch 9/1000, Training Loss: -0.5818
Epoch 10/1000, Training Loss: -0.6013
Epoch 11/1000, Training Loss: -0.6205
Epoch 12/1000, Training Loss: -0.6151
Epoch 13/1000, Training Loss: -0.6289
Epoch 14/1000, Training Loss: -0.6443
Epoch 15/1000, Training Loss: -0.6544
Epoch 16/1000, Training Loss: -0.6714
Epoch 17/1000, Training Loss: -0.7002
Epoch 18/1000, Training Loss: -0.7623
Epoch 19/1000, Training Loss: -0.8149
Epoch 20/1000, Training Loss: -0.8136
Epoch 21/1000, Training Loss: -0.8186
Epoch 22/1000, Training Loss: -0.8136
Epoch 23/1000, Training Loss: -0.8179
Epoch 24/1000, Training Loss: -0.8289
Epoch 25/1000, Training Loss: -0.8194
Epoch 26/1000, Training Loss: -0.8145
Epoch 27/1000, Traini

In [7]:
import torch
from torch.distributions import Normal

# ──────────────────────────────────────────────────────────────────────────────
# 1) Re-generate "observational" data (no intervention)
# ──────────────────────────────────────────────────────────────────────────────
# NOTE: this rebinds the global N.
N = 10000

# Draw N observations without intervention.
X_obs_full,U_obs = scm(
    N=N,
    seed=0,
    intervention_node=None,
    return_u=True,            # U_obs is returned but not used again in this cell
    noise_dists=noise_dists,
    noise_transforms=noise_transforms
)
X_obs = X_obs_full[:, :1]   # first column, kept 2-D
Y_obs = X_obs_full[:, 1:]   # remaining columns


# ──────────────────────────────────────────────────────────────────────────────
# 2) Draw ground-truth data under the hard intervention by calling the SCM with intervention_node=1
# ──────────────────────────────────────────────────────────────────────────────
c = 0.0
X_int_full, U_int = scm(
    N=N,
    seed=0,                     # SAME seed as the observational draw above, not a fresh one;
                                # presumably this reuses the same noise so that rows are paired
                                # with X_obs_full, which the per-row RMSE in step 5 relies on
    intervention_node=1,        # hard-set X1 = c
    intervention_value=c,
    return_u=True,
    noise_dists=noise_dists,
    noise_transforms=noise_transforms
)
Y_true_int = X_int_full[:, 1:]  # downstream columns only


# ──────────────────────────────────────────────────────────────────────────────
# 3) Use the trained cocycle model to predict counterfactuals under do(X1 = c)
# ──────────────────────────────────────────────────────────────────────────────
#    Original note: Y_pred_int[i] = f( x_do = c,  f^{-1}( x_obs[i],  Y_obs[i] ) )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this rebinds `best_model` to `model0`, discarding the object returned by
# optimise in the training cell - confirm the two are the same trained model.
best_model = model0.to(device)
x_obs      = X_obs.to(device)
y_obs      = Y_obs.to(device)
x_do       = torch.full_like(x_obs, c)  # same shape as x_obs, every entry equal to c

with torch.no_grad():
    Y_pred_int = best_model.cocycle(x_do, x_obs, y_obs)

# Both tensors are moved to CPU so the metrics below never mix devices.
Y_pred_int_cpu = Y_pred_int.cpu()
Y_true_int_cpu = Y_true_int.cpu()


# ──────────────────────────────────────────────────────────────────────────────
# 4) Compute the marginal KS statistic for each output column
# ──────────────────────────────────────────────────────────────────────────────

ks_vals = []
for j in range(Y_true_int_cpu.size(1)):
    ks_j = ks_statistic(Y_pred_int_cpu[:, j], Y_true_int_cpu[:, j])
    ks_vals.append(ks_j)

print("KS distances (hard intervention) per dimension:", ks_vals)


# ──────────────────────────────────────────────────────────────────────────────
# 5) Compute counterfactual RMSE (per output column)
# ──────────────────────────────────────────────────────────────────────────────
# Row i of the prediction is compared with row i of the simulator output.
rmse_vals = []
for j in range(Y_true_int_cpu.size(1)):
    diff   = Y_pred_int_cpu[:, j] - Y_true_int_cpu[:, j]
    rmse_j = torch.sqrt((diff ** 2).mean()).item()
    rmse_vals.append(rmse_j)

print("Counterfactual RMSE (hard intervention) per dimension:", rmse_vals)

KS distances (hard intervention) per dimension: [0.0340999960899353, 0.0317000150680542, 0.045399993658065796, 0.035800039768218994]
Counterfactual RMSE (hard intervention) per dimension: [0.17734958231449127, 0.3007647693157196, 0.5304532647132874, 0.5358552932739258]


In [8]:
# Inspect the first masked layer of the trained single cocycle model.
# `model0.transforms[0]` is the first transform of the model; `hyper[0]` is the first layer of
# its conditioner network (expected to be a masked linear layer exposing `weight` and `mask`).
maf = model0.transforms[0]
masked_linear = maf.hyper[0]

# Raw (unmasked) parameter and the binary mask of the same shape.
# The recorded output shows 5 columns per row; the number of rows is presumably the hidden width.
raw_W, mask = masked_linear.weight, masked_linear.mask

print("raw_W (un‐masked parameter):", raw_W, sep="\n")
print("\nmask (0/1 binary tensor):", mask, sep="\n")

# Element-wise product: entries where mask == 0 are zeroed out.
effective_W = raw_W * mask
print("\neffective_W = raw_W * mask (i.e. the masked‐out values):", effective_W, sep="\n")


raw_W (un‐masked parameter):
Parameter containing:
tensor([[ 8.0902e-02, -4.2480e-01,  1.0662e-01, -2.3866e-01,  2.2698e-01],
        [ 5.1630e-02, -3.2764e-01, -2.2654e-01,  3.4222e-01, -4.6735e-01],
        [ 3.6205e-01, -3.1706e-01,  3.0608e-01,  1.5165e-01, -3.6487e-01],
        [ 2.8724e-01,  1.2983e-01, -2.0013e-01, -2.0294e-01, -6.0159e-02],
        [-4.4647e-01,  4.1890e-02, -1.3083e-01,  3.2318e-01, -4.9815e-01],
        [ 2.5704e-01, -1.0225e-01, -3.1755e-01, -4.1455e-01, -9.4470e-02],
        [ 2.6518e-01, -6.9268e-02,  2.5378e-01,  1.7207e-01, -5.4861e-01],
        [ 8.2895e-02, -8.6746e-02, -1.1993e-01, -2.7021e-01, -2.0238e-01],
        [ 2.7290e-01, -1.9106e-01,  1.2508e-01,  1.5714e-01, -3.9317e-01],
        [-1.2485e-01, -3.8506e-01,  3.3927e-01, -2.6481e-01,  2.6963e-01],
        [ 1.8183e-01, -3.7577e-01,  4.1329e-01,  3.4634e-01, -5.5728e-01],
        [ 1.4495e-01, -6.3691e-01,  2.4880e-01, -4.1497e-01, -5.3599e-01],
        [ 3.6794e-01, -1.6546e-01,  3.3083e-01, -

In [9]:
# Multi-cocycle Training

In [10]:
import torch
import torch.nn as nn
from torch.utils.data import random_split, TensorDataset
from causal_cocycle.model_new import ZukoCocycleModel
from causal_cocycle.loss_factory import CocycleLossFactory
from causal_cocycle.kernels import gaussian_kernel
from causal_cocycle.optimise_new import optimise
from architectures import get_stock_transforms

# -----------------------------
# Step 1: Setup from full data
# -----------------------------
# Full data matrix (N, 5) from the chain5 SCM. NOTE: this rebinds the global N to the
# number of training rows and moves X_full to the GPU when one is available.
N, d = X_full.shape
assert d == 5
X_full = X_full.to("cuda" if torch.cuda.is_available() else "cpu")

# One cocycle model and one CMMD-V loss per downstream coordinate: V_j | V_{<j}, j = 1..d-1.
models = []
losses = []
factory = CocycleLossFactory([gaussian_kernel(), gaussian_kernel()])
subsamples = 1000

# ------------------------------------------
# Step 2: Construct model/loss per coordinate
# ------------------------------------------
for j in range(1, d):
    x_ctx = X_full[:, :j]         # context: the first j columns
    y_out = X_full[:, j:j+1]      # target: column j, kept 2-D

    # "Split" train/val. The train fraction is 1, so n_val == 0 and the split only
    # permutes the rows (fixed generator seed => same permutation for every j).
    dataset = TensorDataset(x_ctx, y_out)
    n_train = int(1 * N)
    n_val = N - n_train
    train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # Materialise the training rows with one vectorised index instead of 2*N Python-level
    # __getitem__ calls; Subset.indices holds the row order random_split chose, so the
    # resulting tensors are identical to stacking train_ds[i] row by row.
    train_idx = torch.as_tensor(train_ds.indices, dtype=torch.long, device=X_full.device)
    X_train = x_ctx[train_idx]
    Y_train = y_out[train_idx]

    # Third stock architecture (index 2) for a 1-D output with j context variables.
    maf = get_stock_transforms(x_dim=j, y_dim=1)[2]
    model = ZukoCocycleModel(nn.ModuleList(maf)).to(X_full.device)
    models.append(model)

    # Build the CMMD-V loss for this coordinate from its own training tensors.
    loss_j = factory.build_loss("CMMD_V", X_train, Y_train, subsamples=subsamples)
    losses.append(loss_j)

# ------------------------------------------
# Step 3: Define aggregate loss over models
# ------------------------------------------
class CMMDMultiLoss(nn.Module):
    """Sum of per-coordinate losses, each evaluated with its own model.

    The k-th (model, loss) pair (k = 1, 2, ...) is called as
    ``loss(model, inputs[:, :k], outputs[:, k-1:k])``: the context is the first k
    input columns and the target is the k-th output column, kept 2-D.

    ``models`` and ``losses`` are stored as given (plain attributes); they are not
    registered as submodules of this module.
    """

    def __init__(self, models, losses):
        super().__init__()
        self.models = models
        self.losses = losses

    def forward(self, _, inputs, outputs):
        """Return the summed loss; the first positional argument is ignored.

        Returns the float ``0.0`` when there are no (model, loss) pairs.
        """
        loss_total = 0.0
        pairs = zip(self.models, self.losses)
        for width, (coord_model, coord_loss) in enumerate(pairs, start=1):
            context = inputs[:, :width]
            target = outputs[:, width - 1:width]
            loss_total += coord_loss(coord_model, context, target)
        return loss_total

# ------------------------------------------
# Step 4: Training
# ------------------------------------------
# multi_loss evaluates every per-coordinate model internally and ignores the model argument
# it is called with. composite_model wraps the same model objects in an nn.ModuleList,
# presumably so that optimise can reach their parameters - TODO confirm in optimise_new.
multi_loss = CMMDMultiLoss(models, losses)
composite_model = nn.ModuleList(models)

# NOTE: this rebinds `best_model` and `history` from the single-cocycle training cell.
# No validation data or validation loss is supplied.
best_model, history = optimise(
    model              = composite_model,
    loss_tr            = multi_loss,
    inputs             = X_full[:, :d-1],  # first d-1 columns: every possible context variable
    outputs            = X_full[:, 1:],    # last d-1 columns: the four targets
    inputs_val         = None,
    outputs_val        = None,
    learn_rate         = 1e-2,
    epochs             = 500,
    weight_decay       = 0.0,
    batch_size         = 64,
    val_batch_size     = 256,
    scheduler          = False,
    schedule_milestone = 10,
    lr_mult            = 0.90,
    print_             = True,
    plot               = False,
    likelihood_param_opt = False,
    likelihood_param_lr  = 0.01,
    loss_val           = None
)

Epoch 1/500, Training Loss: -2.0410
Epoch 2/500, Training Loss: -2.6992
Epoch 3/500, Training Loss: -3.0456
Epoch 4/500, Training Loss: -3.0979
Epoch 5/500, Training Loss: -3.1252
Epoch 6/500, Training Loss: -3.1520
Epoch 7/500, Training Loss: -3.1277
Epoch 8/500, Training Loss: -3.1515
Epoch 9/500, Training Loss: -3.1833
Epoch 10/500, Training Loss: -3.1683
Epoch 11/500, Training Loss: -3.1481
Epoch 12/500, Training Loss: -3.1753
Epoch 13/500, Training Loss: -3.1549
Epoch 14/500, Training Loss: -3.1967
Epoch 15/500, Training Loss: -3.1921
Epoch 16/500, Training Loss: -3.1882
Epoch 17/500, Training Loss: -3.1739
Epoch 18/500, Training Loss: -3.1663
Epoch 19/500, Training Loss: -3.1849
Epoch 20/500, Training Loss: -3.1972
Epoch 21/500, Training Loss: -3.1926
Epoch 22/500, Training Loss: -3.2005
Epoch 23/500, Training Loss: -3.1932
Epoch 24/500, Training Loss: -3.1966
Epoch 25/500, Training Loss: -3.1852
Epoch 26/500, Training Loss: -3.1805
Epoch 27/500, Training Loss: -3.2081
Epoch 28/5

In [11]:
import torch
from torch.distributions import Normal

# 1) Simulate paired observational / interventional samples (same seed for both draws).
N = 5000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Observational sample, moved to the evaluation device.
X_obs_full, U = scm(
    N=N,
    seed=0,
    intervention_node=None,
    return_u=True,
    noise_dists=noise_dists,
    noise_transforms=noise_transforms,
)
X_obs_full = X_obs_full.to(device)

# Ground truth under do(X1 = c); this tensor stays on whatever device the simulator returns.
c = 0.0
X_true_cf_full, U_cf = scm(
    N=N,
    seed=0,
    intervention_node=1,
    intervention_value=c,
    return_u=True,
    noise_dists=noise_dists,
    noise_transforms=noise_transforms,
)
Y_true_cf = X_true_cf_full[:, 1:]

# 2) Push the intervention through the per-coordinate cocycle models.
models = [cocycle.to(device) for cocycle in models]

# Copy of the observations with the first column overwritten by the intervention value.
X_obs_cf = X_obs_full.clone()
X_obs_cf[:, 0] = c

# Column j receives the counterfactual of the (j+1)-th variable.
predicted_cf = torch.zeros((N, 4), device=device)

for j, model in enumerate(models):
    # New context: intervened first column plus the j counterfactuals predicted so far
    # (a zero-width slice when j == 0). Old context: the first j+1 observed columns.
    x_new = torch.cat([X_obs_cf[:, :1], predicted_cf[:, :j]], dim=1)
    x_old = X_obs_full[:, :j + 1]
    y_old = X_obs_full[:, j + 1:j + 2]

    with torch.no_grad():
        y_pred = model.cocycle(x_new, x_old, y_old)

    # Stored so that later coordinates condition on it.
    predicted_cf[:, j:j + 1] = y_pred

# 3) Compute KS distances
def ks_statistic(a: torch.Tensor, b: torch.Tensor) -> float:
    """Two-sample Kolmogorov-Smirnov statistic between the flattened samples `a` and `b`.

    Both empirical CDFs are evaluated on the sorted set of distinct values occurring in
    either sample, and the largest absolute gap between them is returned as a Python float.
    """
    sample_a = torch.sort(a.flatten()).values
    sample_b = torch.sort(b.flatten()).values
    grid = torch.unique(torch.cat((sample_a, sample_b)))

    def _ecdf(sorted_sample: torch.Tensor) -> torch.Tensor:
        # right=True counts the entries <= each grid point, i.e. a right-continuous ECDF.
        counts = torch.bucketize(grid, sorted_sample, right=True)
        return counts.float() / sorted_sample.numel()

    gap = torch.abs(_ecdf(sample_a) - _ecdf(sample_b))
    return gap.max().item()

# Bring both tensors to the CPU once. `predicted_cf` lives on `device` while `Y_true_cf`
# was never moved there, so subtracting them directly fails with a device-mismatch error
# whenever `device` is CUDA.
pred_cpu = predicted_cf.cpu()
true_cpu = Y_true_cf.cpu()
n_outputs = pred_cpu.size(1)

# Marginal KS statistic per output column (prediction vs. simulator ground truth).
ks_vals = [
    ks_statistic(pred_cpu[:, j], true_cpu[:, j])
    for j in range(n_outputs)
]
print("KS distances (hard intervention) per dimension:", ks_vals)

# 4) Compute RMSE: row i of the prediction is compared with row i of the ground truth.
rmse_vals = [
    torch.sqrt(((pred_cpu[:, j] - true_cpu[:, j])**2).mean()).item()
    for j in range(n_outputs)
]
print("Counterfactual RMSE (hard intervention) per dimension:", rmse_vals)

KS distances (hard intervention) per dimension: [0.0690000057220459, 0.12020003795623779, 0.1340000033378601, 0.13679999113082886]
Counterfactual RMSE (hard intervention) per dimension: [0.26215505599975586, 1.0929704904556274, 1.2617816925048828, 1.7701361179351807]


In [12]:
# For each per-coordinate model, print the first conditioner layer's weight next to the
# least-squares coefficients (normal equations, no intercept) of column i+1 regressed on
# the first i+1 columns of X_full.
for i in range(4):
    design = X_full[:, :i+1]
    response = X_full[:, i+1:i+2]
    first_layer = models[i].transforms[0].hyper[0]
    print(first_layer.weight)
    print(torch.linalg.solve(design.T @ design, design.T @ response))

Parameter containing:
tensor([[ 3.4044e-01],
        [-5.4944e-01],
        [-3.7757e-01],
        [ 1.4378e-01],
        [-1.2887e-01],
        [ 4.0049e-01],
        [ 1.3131e-03],
        [ 2.6580e-01],
        [-1.2055e-01],
        [ 2.3000e-03],
        [ 4.8790e-02],
        [-7.7819e-01],
        [-1.3907e-01],
        [ 6.3595e-01],
        [-2.2286e-02],
        [ 1.8039e-01],
        [ 1.9533e-01],
        [-1.5045e-01],
        [ 4.8664e-01],
        [-9.8210e-02],
        [ 4.8804e-01],
        [ 3.3773e-03],
        [-2.4974e-01],
        [-1.9627e-01],
        [ 6.7053e-01],
        [ 8.0588e-01],
        [-4.0051e-01],
        [-9.2777e-03],
        [-2.4872e-01],
        [ 2.4762e-01],
        [-8.1879e-02],
        [-8.3429e-04],
        [ 5.4720e-01],
        [ 4.5696e-01],
        [-3.0364e-05],
        [ 2.1999e-02],
        [ 2.5275e-02],
        [-3.2750e-01],
        [-5.8912e-05],
        [ 3.4219e-01],
        [-5.9961e-03],
        [-4.5564e-01],
        [ 3.