In [1]:
from torch.distributions import Normal,Laplace,Uniform,Gamma, Beta, Cauchy
import os
import time
import torch
from copy import deepcopy

# Causal_cocycle imports
from causal_cocycle.model_factory import CocycleFactory
from causal_cocycle.model_factory import FlowFactory
from causal_cocycle.loss_factory import CocycleLossFactory
from causal_cocycle.loss import FlowLoss
from causal_cocycle.optimise_new import validate, optimise
from causal_cocycle.kernels import gaussian_kernel
from causal_cocycle.helper_functions import kolmogorov_distance
from lm_config import opt_config, model_config

"""
Configs
"""
# Experimental set up
seed = 3
N,D,P = 1000,1,1
sig_noise_ratio = 1

"""
Main
"""
# Object storage
names = ["L2","L1","HSIC","URR","CMMD-V","CMMD-U","True"]
Coeffs = torch.zeros((1,len(names),P))

# Data generation

# Drawing data
torch.manual_seed(seed)
X = Normal(1,1).sample((N,D))
X *= 1/(D)**0.5
B = torch.ones((D,1))*(torch.linspace(0,D-1,D)<P)[:,None]
F = X @ B
Utot = Normal(0,1).sample((2*N,1))/sig_noise_ratio**0.5
#Utot = torch.sign(Uniform(-1,1).sample((2*N,1)))/sig_noise_ratio**0.5
#Utot = Cauchy(0,1).sample((2*N,1))/sig_noise_ratio**0.5
#Utot = (1/Gamma(1,1).sample((2*N,1))-1)/sig_noise_ratio**0.5
U,Uint = Utot[:N], Utot[N:]
Y = F + U

In [2]:
# Cocycle model construction
factory = CocycleFactory(1, model_config)
models, hyper_args = factory.build_models()
print(f"Constructed {len(models)} candidate cocycle models.")

gauss_config,laplace_config = model_config.copy(),model_config.copy()
gauss_config['base_distribution_configs'], laplace_config['base_distribution_configs'] = ["Normal"]*4, ["Laplace"]*4
models_gauss, hyper_args  = FlowFactory(1, gauss_config).build_models()
models_laplace, hyper_args  = FlowFactory(1, laplace_config).build_models()

models_urr_gauss,models_urr_laplace = deepcopy(models_gauss),deepcopy(models_laplace)
for i in range(len(models_urr_gauss)):
    models_urr_gauss[i].transformer.logdet = False
    models_urr_laplace[i].transformer.logdet = False

Constructed 4 candidate cocycle models.


In [3]:
opt_config['learn_rate'] = 0.01
opt_config['epochs'] = 500
opt_config['scheduler'] = False

In [4]:
Xtild = torch.column_stack((torch.ones((N,1)),X))
torch.linalg.solve( Xtild.T @ Xtild, Xtild.T @ Y)

tensor([[-0.0036],
        [ 0.9737]])

In [5]:
# Training with L2
loss= FlowLoss()
final_model_l2, (best_index_l2, val_loss_l2) = validate(
        models_gauss,
        loss,
        X,
        Y,
        loss_val=loss,
        method="CV",
        train_val_split=0.5,
        opt_kwargs=opt_config,
        hyper_kwargs = hyper_args,
        choose_best_model="overall",
        retrain=True,
    )
print(f"Best overall model index: {best_index_l2} with average validation loss {val_loss_l2:.4f}")

Epoch 1/500, Training Loss: 196.2502
Epoch 2/500, Training Loss: 172.4278
Epoch 3/500, Training Loss: 159.1290
Epoch 4/500, Training Loss: 96.6615
Epoch 5/500, Training Loss: 82.0629
Epoch 6/500, Training Loss: 121.5600
Epoch 7/500, Training Loss: 101.8954
Epoch 8/500, Training Loss: 168.0363
Epoch 9/500, Training Loss: 73.9710
Epoch 10/500, Training Loss: 19.7233
Epoch 11/500, Training Loss: 87.5817
Epoch 12/500, Training Loss: 141.6195
Epoch 13/500, Training Loss: 44.3104
Epoch 14/500, Training Loss: 66.3462
Epoch 15/500, Training Loss: 78.2021
Epoch 16/500, Training Loss: 97.8843
Epoch 17/500, Training Loss: 94.7571
Epoch 18/500, Training Loss: 99.0725
Epoch 19/500, Training Loss: 22.5608
Epoch 20/500, Training Loss: 107.7922
Epoch 21/500, Training Loss: 49.6228
Epoch 22/500, Training Loss: 83.8626
Epoch 23/500, Training Loss: 97.0127
Epoch 24/500, Training Loss: 58.2120
Epoch 25/500, Training Loss: 46.9381
Epoch 26/500, Training Loss: 45.7694
Epoch 27/500, Training Loss: 26.0870
Ep

In [4]:
# Training with L1
loss= FlowLoss()
final_model_l1, (best_index_l1, val_loss_l1) = validate(
        models_laplace,
        loss,
        X,
        Y,
        loss_val=loss,
        method="CV",
        train_val_split=0.5,
        opt_kwargs=opt_config,
        hyper_kwargs=hyper_args,
        choose_best_model="overall",
        retrain=True,
    )
print(f"Best overall model index: {best_index_l1} with average validation loss {val_loss_l1:.4f}")

Epoch 1/100, Training Loss: 1.5519
Epoch 2/100, Training Loss: 1.4694
Epoch 3/100, Training Loss: 1.5214
Epoch 4/100, Training Loss: 1.4979
Epoch 5/100, Training Loss: 1.4953
Epoch 6/100, Training Loss: 1.5075
Epoch 7/100, Training Loss: 1.4582
Epoch 8/100, Training Loss: 1.3975
Epoch 9/100, Training Loss: 1.5174
Epoch 10/100, Training Loss: 1.5019
Epoch 11/100, Training Loss: 1.4288
Epoch 12/100, Training Loss: 1.5300
Epoch 13/100, Training Loss: 1.5052
Epoch 14/100, Training Loss: 1.5069
Epoch 15/100, Training Loss: 1.5267
Epoch 16/100, Training Loss: 1.4294
Epoch 17/100, Training Loss: 1.4520
Epoch 18/100, Training Loss: 1.4511
Epoch 19/100, Training Loss: 1.5124
Epoch 20/100, Training Loss: 1.4165
Epoch 21/100, Training Loss: 1.5160
Epoch 22/100, Training Loss: 1.4815
Epoch 23/100, Training Loss: 1.4918
Epoch 24/100, Training Loss: 1.5093
Epoch 25/100, Training Loss: 1.4453
Epoch 26/100, Training Loss: 1.4344
Epoch 27/100, Training Loss: 1.4524
Epoch 28/100, Training Loss: 1.4364
E

In [5]:
# Training with cocycles
kernel = [gaussian_kernel()] * 2
loss_factory = CocycleLossFactory(kernel)
loss= loss_factory.build_loss("CMMD_V", X, Y, subsamples=10**4)
final_model_cmmdv, (best_index_cmmdv, val_loss_cmmdv) = validate(
        models,
        loss,
        X,
        Y,
        loss_val=loss,
        method="CV",
        train_val_split=0.5,
        opt_kwargs=opt_config,
        hyper_kwargs=hyper_args,
        choose_best_model="overall",
        retrain=True,
    )
print(f"Best overall model index: {best_index_cmmdv} with average validation loss {val_loss_cmmdv:.4f}")

Epoch 1/100, Training Loss: -0.5418
Epoch 2/100, Training Loss: -0.5472
Epoch 3/100, Training Loss: -0.5720
Epoch 4/100, Training Loss: -0.5525
Epoch 5/100, Training Loss: -0.5360
Epoch 6/100, Training Loss: -0.5446
Epoch 7/100, Training Loss: -0.5683
Epoch 8/100, Training Loss: -0.5473
Epoch 9/100, Training Loss: -0.5424
Epoch 10/100, Training Loss: -0.5376
Epoch 11/100, Training Loss: -0.5637
Epoch 12/100, Training Loss: -0.5566
Epoch 13/100, Training Loss: -0.5452
Epoch 14/100, Training Loss: -0.5445
Epoch 15/100, Training Loss: -0.5534
Epoch 16/100, Training Loss: -0.5652
Epoch 17/100, Training Loss: -0.5590
Epoch 18/100, Training Loss: -0.5445
Epoch 19/100, Training Loss: -0.5489
Epoch 20/100, Training Loss: -0.5600
Epoch 21/100, Training Loss: -0.5519
Epoch 22/100, Training Loss: -0.5523
Epoch 23/100, Training Loss: -0.5604
Epoch 24/100, Training Loss: -0.5493
Epoch 25/100, Training Loss: -0.5648
Epoch 26/100, Training Loss: -0.5391
Epoch 27/100, Training Loss: -0.5404
Epoch 28/1

In [6]:
# Training with cocycles
kernel = [gaussian_kernel()] * 2
loss_factory = CocycleLossFactory(kernel)
loss= loss_factory.build_loss("CMMD_U", X, Y, subsamples=10**4)
final_model_cmmdu, (best_index_cmmdu, val_loss_cmmdu) = validate(
        models,
        loss,
        X,
        Y,
        loss_val=loss,
        method="CV",
        train_val_split=0.5,
        opt_kwargs=opt_config,
        hyper_kwargs=hyper_args,
        choose_best_model="overall",
        retrain=True,
    )
print(f"Best overall model index: {best_index_cmmdu} with average validation loss {val_loss_cmmdu:.4f}")

Epoch 1/100, Training Loss: -0.5368
Epoch 2/100, Training Loss: -0.5318
Epoch 3/100, Training Loss: -0.5337
Epoch 4/100, Training Loss: -0.5448
Epoch 5/100, Training Loss: -0.5687
Epoch 6/100, Training Loss: -0.5579
Epoch 7/100, Training Loss: -0.5561
Epoch 8/100, Training Loss: -0.5567
Epoch 9/100, Training Loss: -0.5317
Epoch 10/100, Training Loss: -0.5418
Epoch 11/100, Training Loss: -0.5531
Epoch 12/100, Training Loss: -0.5414
Epoch 13/100, Training Loss: -0.5447
Epoch 14/100, Training Loss: -0.5619
Epoch 15/100, Training Loss: -0.5526
Epoch 16/100, Training Loss: -0.5497
Epoch 17/100, Training Loss: -0.5620
Epoch 18/100, Training Loss: -0.5393
Epoch 19/100, Training Loss: -0.5401
Epoch 20/100, Training Loss: -0.5449
Epoch 21/100, Training Loss: -0.5438
Epoch 22/100, Training Loss: -0.5591
Epoch 23/100, Training Loss: -0.5573
Epoch 24/100, Training Loss: -0.5537
Epoch 25/100, Training Loss: -0.5253
Epoch 26/100, Training Loss: -0.5593
Epoch 27/100, Training Loss: -0.5531
Epoch 28/1

In [34]:
loss.kernel[1].lengthscale

tensor(0.7824)

In [23]:
# Training with hsic
kernel = [gaussian_kernel()] * 2
loss_factory = CocycleLossFactory(kernel)
Uhat = final_model_l2.inverse_transformation(X,Y)[0].detach()
loss= loss_factory.build_loss("HSIC", X, Uhat, subsamples=10**4)
final_model_hsic, (best_index_hsic, val_loss_hsic) = validate(
        models,
        loss,
        X,
        Y,
        loss_val=loss,
        method="CV",
        train_val_split=0.5,
        opt_kwargs=opt_config,
        hyper_kwargs=hyper_args,
        choose_best_model="overall",
        retrain=True,
    )
print(f"Best overall model index: {best_index_hsic} with average validation loss {val_loss_hsic:.4f}")

Epoch 1/500, Training Loss: 0.0191
Epoch 2/500, Training Loss: 0.0185
Epoch 3/500, Training Loss: 0.0221
Epoch 4/500, Training Loss: 0.0207
Epoch 5/500, Training Loss: 0.0210
Epoch 6/500, Training Loss: 0.0170
Epoch 7/500, Training Loss: 0.0197
Epoch 8/500, Training Loss: 0.0115
Epoch 9/500, Training Loss: 0.0191
Epoch 10/500, Training Loss: 0.0191
Epoch 11/500, Training Loss: 0.0173
Epoch 12/500, Training Loss: 0.0232
Epoch 13/500, Training Loss: 0.0194
Epoch 14/500, Training Loss: 0.0175
Epoch 15/500, Training Loss: 0.0191
Epoch 16/500, Training Loss: 0.0178
Epoch 17/500, Training Loss: 0.0179
Epoch 18/500, Training Loss: 0.0180
Epoch 19/500, Training Loss: 0.0154
Epoch 20/500, Training Loss: 0.0141
Epoch 21/500, Training Loss: 0.0135
Epoch 22/500, Training Loss: 0.0164
Epoch 23/500, Training Loss: 0.0151
Epoch 24/500, Training Loss: 0.0148
Epoch 25/500, Training Loss: 0.0131
Epoch 26/500, Training Loss: 0.0145
Epoch 27/500, Training Loss: 0.0133
Epoch 28/500, Training Loss: 0.0146
E

KeyboardInterrupt: 

In [5]:
# Training with urr
kernel = [gaussian_kernel()] * 2
loss_factory = CocycleLossFactory(kernel)
loss= loss_factory.build_loss("URR", X, Y, subsamples=10**4)
loss_val= loss_factory.build_loss("URR_N", X, Y, subsamples=10**4)
final_model_urr, (best_index_urr, val_loss_urr) = validate(
        models_urr_gauss,
        loss,
        X,
        Y,
        loss_val=loss_val,
        method="CV",
        train_val_split=0.5,
        opt_kwargs=opt_config,
        hyper_kwargs=hyper_args,
        choose_best_model="overall",
        retrain=True,
    )
print(f"Best overall model index: {best_index_urr} with average validation loss {val_loss_urr:.4f}")

Epoch 1/500, Training Loss: 0.0719
Epoch 2/500, Training Loss: -0.0094
Epoch 3/500, Training Loss: 0.0100
Epoch 4/500, Training Loss: -0.0226
Epoch 5/500, Training Loss: -0.0157
Epoch 6/500, Training Loss: -0.0751
Epoch 7/500, Training Loss: -0.0490
Epoch 8/500, Training Loss: -0.0620
Epoch 9/500, Training Loss: -0.1245
Epoch 10/500, Training Loss: -0.0488
Epoch 11/500, Training Loss: -0.1238
Epoch 12/500, Training Loss: -0.1351
Epoch 13/500, Training Loss: -0.1149
Epoch 14/500, Training Loss: -0.1577
Epoch 15/500, Training Loss: -0.1238
Epoch 16/500, Training Loss: -0.1746
Epoch 17/500, Training Loss: -0.2348
Epoch 18/500, Training Loss: -0.2857
Epoch 19/500, Training Loss: -0.2371
Epoch 20/500, Training Loss: -0.2513
Epoch 21/500, Training Loss: -0.2302
Epoch 22/500, Training Loss: -0.3016
Epoch 23/500, Training Loss: -0.3102
Epoch 24/500, Training Loss: -0.2946
Epoch 25/500, Training Loss: -0.3664
Epoch 26/500, Training Loss: -0.4091
Epoch 27/500, Training Loss: -0.3874
Epoch 28/500

In [None]:
# Training with urr
kernel = [gaussian_kernel()] * 2
loss_factory = CocycleLossFactory(kernel)
loss= loss_factory.build_loss("URR", X, Y, subsamples=10**4)
final_model_urr_l, (best_index_urr_l, val_loss_urr_l) = validate(
        models_urr_laplace,
        loss,
        X,
        Y,
        loss_val=loss,
        method="CV",
        train_val_split=0.5,
        opt_kwargs=opt_config,
        hyper_kwargs=hyper_args,
        choose_best_model="overall",
        retrain=True,
    )
print(f"Best overall model index: {best_index_urr} with average validation loss {val_loss_urr:.4f}")

Epoch 1/500, Training Loss: -0.1647
Epoch 2/500, Training Loss: -0.2069
Epoch 3/500, Training Loss: -0.2138
Epoch 4/500, Training Loss: -0.3055
Epoch 5/500, Training Loss: -0.2902
Epoch 6/500, Training Loss: -0.3436
Epoch 7/500, Training Loss: -0.2951
Epoch 8/500, Training Loss: -0.3222
Epoch 9/500, Training Loss: -0.3448
Epoch 10/500, Training Loss: -0.3523
Epoch 11/500, Training Loss: -0.3962
Epoch 12/500, Training Loss: -0.4142
Epoch 13/500, Training Loss: -0.4054
Epoch 14/500, Training Loss: -0.4282
Epoch 15/500, Training Loss: -0.4452
Epoch 16/500, Training Loss: -0.4808
Epoch 17/500, Training Loss: -0.4172
Epoch 18/500, Training Loss: -0.5139
Epoch 19/500, Training Loss: -0.4656
Epoch 20/500, Training Loss: -0.4897
Epoch 21/500, Training Loss: -0.5277
Epoch 22/500, Training Loss: -0.5335
Epoch 23/500, Training Loss: -0.5784
Epoch 24/500, Training Loss: -0.5389
Epoch 25/500, Training Loss: -0.4829
Epoch 26/500, Training Loss: -0.5273
Epoch 27/500, Training Loss: -0.5680
Epoch 28/5

In [25]:
import torch

# -------------------------------------------------------------------
# 1) KS and RMSE helpers (pure Torch)
# -------------------------------------------------------------------
def ks_statistic(a: torch.Tensor, b: torch.Tensor) -> float:
    a = a.flatten(); b = b.flatten()
    a_s, _ = torch.sort(a); b_s, _ = torch.sort(b)
    all_vs = torch.cat([a_s, b_s]).unique()
    cdf_a = torch.bucketize(all_vs, a_s, right=True).float() / a_s.numel()
    cdf_b = torch.bucketize(all_vs, b_s, right=True).float() / b_s.numel()
    return (torch.abs(cdf_a - cdf_b).max()).item()

def rmse(a: torch.Tensor, b: torch.Tensor) -> float:
    return torch.sqrt(((a - b)**2).mean()).item()

# -------------------------------------------------------------------
# 2) Evaluation loop
# -------------------------------------------------------------------
def evaluate_models(
    models_dict: dict,
    index_dict: dict,
    X: torch.Tensor,
    Y: torch.Tensor,
    B: torch.Tensor,
    U: torch.Tensor,
    sig_noise_ratio: float,
    seed: int = None
) -> dict:
    """
    models_dict: { name: (model, model_type) }
      model_type in {'cocycle','hsic','cmmd_v','cmmd_u','l2','l1','urr'}

    Returns:
      {
        name: {
          'KS_int':   KS between Y_model(X+1) and Y_true(X+1),
          'KS_cf':    KS between (Y_cf-Y) and (Y_true-Y),
          'CF_RMSE':  RMSE between (Y_cf-Y) and (Y_true-Y)
        },
        ...
      }
    """
    N, D = X.shape
    _, P = Y.shape
    device = X.device

    # “true” interventional outcome
    Xp      = X + 1.0  
    Y_true  = Xp + U            # (N,P)
    ΔY_true = torch.ones((N,P))          # (N,P)

    results = {}
    for name, (model, mtype) in models_dict.items():
        # ---- interventional estimate ----
        if mtype in ('hsic','cmmdv','cmmdu'):
            Y_int = model.cocycle(Xp, X, Y)             # (N,P)
        else:  # l2, l1, urr
            Uhat = model.base_distribution.sample((N,))
            out = model.transformation(Xp, Uhat)           # either y or (y,logdet)
            Y_int = out[0] if isinstance(out, tuple) else out

        KS_int = ks_statistic(Y_int[:,0],   Y_true[:,0])

        # ---- counterfactual via cocycle ----
        model.transformer.logdet = False
        Y_cf    = model.cocycle(Xp, X, Y)     # (N,P)
        ΔY_model= Y_cf - Y                    # (N,P)

        RMSE_cf = rmse(    ΔY_model[:,0],   ΔY_true[:,0])

        results[name] = {
            'KS_int':  KS_int,
            'CF_RMSE': RMSE_cf,
            'index': index_dict[name][0]
        }

    return results

# -------------------------------------------------------------------
# 3) Example usage
# -------------------------------------------------------------------
my_models = {
    'L2' : (final_model_l2, 'l2'),
    'L1' : (final_model_l1, 'l1'),
    'URR L2': (final_model_urr,'urr'),
    'URR L1': (final_model_urr_l,'urr'),
    'CMMD_V': (final_model_cmmdv,'cmmdv'),
    'CMMD_U': (final_model_cmmdu,'cmmdu'),
}

my_indexes = {
    'L2' : (best_index_l2, 'l2'),
    'L1' : (best_index_l1, 'l1'),
    'URR L2': (best_index_urr,'urr'),
    'URR L1': (best_index_urr_l,'urr'),
    'CMMD_V': (best_index_cmmdv,'cmmdv'),
    'CMMD_U': (best_index_cmmdu,'cmmdu'),
}

final_model_l2.inverse_transformation(X,Y).detach()

metrics = evaluate_models(
    my_models,
    my_indexes,
    X, Y,
    B=torch.ones((1,1)), 
    U = Uint,
    sig_noise_ratio=sig_noise_ratio,
    seed=2025
)

# Print nicely
for name, m in metrics.items():
     print(f"{name:6s}  KS_int={m['KS_int']:.3f}  CF_RMSE={m['CF_RMSE']:.3f}")


torch.Size([1000, 1]) torch.Size([1000, 1])
L2      KS_int=0.114  CF_RMSE=0.221
L1      KS_int=0.054  CF_RMSE=0.047
URR L2  KS_int=0.109  CF_RMSE=0.177
URR L1  KS_int=0.054  CF_RMSE=0.095
CMMD_V  KS_int=0.045  CF_RMSE=0.022
CMMD_U  KS_int=0.045  CF_RMSE=0.020
