In [8]:
import torch
from torch.distributions import (
    Distribution, Normal, Laplace, Cauchy, Gamma, Uniform
)
from csuite import SCMS, SCM_DIMS, SCM_MASKS
from architectures import get_stock_transforms
from zuko.flows import UnconditionalDistribution
from causalflows.flows import CausalFlow
from causal_cocycle.causalflow_helper import select_and_train_flow, sample_do, sample_cf
from causal_cocycle.helper_functions import ks_statistic, wasserstein1_repeat, rmse

from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.flows import UnconditionalDistribution
from zuko.transforms import MonotonicRQSTransform

In [4]:
from torch.distributions.transforms import Transform
class ShiftTransform(Transform):
    bijective = True
    sign = +1

    def __init__(self, b):
        # b will be a tensor of shape (batch,1)
        super().__init__(cache_size=0)
        self.b = b

    def _call(self, x):
        # x: (B,), b: (B,1) ⇒ squeeze to (B,)
        shift = self.b.squeeze(-1)
        return x + shift

    def _inverse(self, y):
        shift = self.b.squeeze(-1)
        return y - shift

    def log_abs_det_jacobian(self, x, y):
        # d(x+shift)/dx = 1  ⇒ log|1| = 0
        return torch.zeros_like(x)

In [81]:

def evaluate_models(
    models_dict: dict,
    index_dict: dict,
    X: torch.Tensor,
    Y: torch.Tensor,
    noisedist: Distribution,
    noisetransform: callable,
    sig_noise_ratio: float,
    seed: int = None
) -> dict:
    """
    Adapted to work with a joint causal flow in models_dict.
    Expects models_dict = {'Flow': (flow, 'flow')},
             index_dict  = {'Flow': (idx, 'flow')}.
    Returns the same keys: KS_int, CF_RMSE, index under 'Flow'.
    """
    if seed is not None:
        torch.manual_seed(seed)

    device = X.device
    N, D = X.shape
    _, P = Y.shape
    assert D == 1 and P == 1  # 2D joint (X,Y)

    # “true” counterfactual shift ΔY = +1
    Z = torch.cat([X, Y], dim=1).to(device)             # shape (N,2)
    X_cf = X*0 + 1.0
    ΔY_true = Y - X + 1.0 

    # “true” interventional Y distribution:
    # Y_true = X_cf + U    with U = noisetransform(noisedist.sample)
    m = 10**5
    # sample X* from Normal(1,1) then +1, and add noise:
    Y_true = (
        Normal(1, 1).sample((m, D)).to(device)
        + 1.0
        + noisetransform(noisedist.sample((m, 1)).to(device))
    )  # shape (m,1)

    results = {}
    for name, (flow, _) in models_dict.items():

        # ---- interventional estimate via Alg 1 ----
        # sample m draws from the *joint* under do(X = 1)
        Y_do = sample_do(
            flow.to(device),
            index=0,
            intervention_fn=lambda old: old + 1.0,
            sample_shape=torch.Size([m])
        )  # shape (m, 2)
        Y_int = Y_do[:, 1].unsqueeze(-1)  # shape (m,1)
        KS_int = ks_statistic(Y_int[:,0].cpu(), Y_true[:,0].cpu())

        # ---- counterfactual via Alg 2 ----
        Z_cf = sample_cf(
            flow.to(device),
            x_obs=Z,
            index=0,
            intervention_fn=lambda old: old*0 + 1.0
        )  # shape (N,2)
        ΔY_model = (Z_cf[:,1] - Z[:,1]).unsqueeze(-1)  # (N,1)
        CF_RMSE = rmse(Z_cf[:,1].cpu(), ΔY_true[:,0].cpu())

        results[name] = {
            'KS_int':  KS_int,
            'CF_RMSE': CF_RMSE,
            'index': index_dict[name][0]
        }

    # add noise info if you like (mirroring your old script)
    results['noise_distribution'] = noisedist.__class__.__name__
    return results

In [6]:
seed = 0
N = 1000
noise_dist = "normal"

In [10]:
"""
Configs
"""
# Experimental set up
D,P = 1,1
sig_noise_ratio = 1

# Model setup
width = 32
bins = 8

"""
Data gen
"""
torch.manual_seed(seed)
X = Normal(1,1).sample((N,D))
X *= 1/(D)**0.5
B = torch.ones((D,1))*(torch.linspace(0,D-1,D)<P)[:,None]
F = X @ B
if noise_dist == "normal":
    noisedist = Normal(0,1)
    noisetransform = lambda x : x
elif noise_dist == "rademacher": 
    noisedist = Uniform(-1,1)
    noisetransform = lambda x : torch.sign(x)
elif noise_dist == "cauchy":
    noisedist = Cauchy(0,1)
    noisetransform = lambda x : x
elif noise_dist == "gamma":
    noisedist = Gamma(1,1)
    noisetransform = lambda x : x
elif noise_dist == "inversegamma":
    noisedist = Gamma(1,1)
    noisetransform = lambda x : 1/x
U = noisetransform(noisedist.sample((N,1)))/sig_noise_ratio**0.5
Y = F + U
XY = torch.cat([X,Y],dim = 1)

"""
Defining list of flows to CV over
"""

# shared base distribution
baseg = UnconditionalDistribution(
    Normal,
    loc=torch.zeros(2),
    scale=torch.ones(2),
    buffer=True,
)

basel = UnconditionalDistribution(
    Laplace,
    loc=torch.zeros(2),
    scale=torch.ones(2),
    buffer=True,
)


# 1) MAF only, shift linear conditioner
maf_shift_only = MaskedAutoregressiveTransform(
    features=2,
    context=0,
    hidden_features=(),   # you can also use () for purely linear shift
    univariate=ShiftTransform,  # use our shift‐only bijector
    shapes=([1],),              # one shift‐parameter per dimension
)
flow_shift_only_g = CausalFlow(
    transform=[maf_shift_only],
    base=baseg,
)
flow_shift_only_l = CausalFlow(
    transform=[maf_shift_only],
    base=basel,
)

# 3) MAF only, shift neural‑net conditioner
maf_nn_shift_only = MaskedAutoregressiveTransform(
    features=2,
    context=0,
    hidden_features=(width, width),   # you can also use () for purely linear shift
    univariate=ShiftTransform,  # use our shift‐only bijector
    shapes=([1],),              # one shift‐parameter per dimension
)
flow_nn_shift_only_g = CausalFlow(
    transform=[maf_nn_shift_only],
    base=baseg,
)
flow_nn_shift_only_l = CausalFlow(
    transform=[maf_nn_shift_only],
    base=basel,
)

# 3) MAF only, neural‑net conditioner
maf_nn = MaskedAutoregressiveTransform(
    features=2,
    context=0,
    hidden_features=(width, width),  # two hidden layers of size 32
)
flow_maf_nn_g = CausalFlow(transform=[maf_nn], base=baseg)
flow_maf_nn_l = CausalFlow(transform=[maf_nn], base=basel)

# 4) MAF → RQS, both neural‑net conditioners
maf_nn = MaskedAutoregressiveTransform(
    features=2,
    context=0,
    hidden_features=(width, width),
)
rqs_nn = MaskedAutoregressiveTransform(
    features=2,
    context=0,
    hidden_features=(width, width),
    univariate=MonotonicRQSTransform,
    shapes=([bins], [bins], [bins + 1]),
)
flow_maf_rqs_nn_g = CausalFlow(transform=[maf_nn, rqs_nn], base=baseg)
flow_maf_rqs_nn_l = CausalFlow(transform=[maf_nn, rqs_nn], base=basel)

# Gaussian + Laplace base flows
flows_g = [
flow_shift_only_g,
flow_nn_shift_only_g,
flow_maf_nn_g,
flow_maf_rqs_nn_g
    ]
flows_l = [
flow_shift_only_l,
flow_nn_shift_only_l,
flow_maf_nn_l,
flow_maf_rqs_nn_l
    ]


"""
Training
"""
# Gaussian‐base
best_g, test_nll_g, idx_g, cv_scores_g = select_and_train_flow(
    flows_g,
    XY,
    train_fraction=1.0,
    k_folds=2,
    num_epochs=500,
    batch_size=128,
    lr=1e-2,
    device=XY.device
)

# Laplace‐base
best_l, test_nll_l, idx_l, cv_scores_l = select_and_train_flow(
    flows_l,
    XY,
    train_fraction=1.0,
    k_folds=2,
    num_epochs=500,
    batch_size=128,
    lr=1e-2,
    device=XY.device
)

In [82]:
"""
Evaluating
"""
# 6) Wrap into the same my_models / my_indexes format
# --------------------------------------------------------------------------------
my_models = {
    'GaussianFlow': (best_g, 'flow'),
    'LaplaceFlow':  (best_l, 'flow'),
}
my_indexes = {
    'GaussianFlow': (idx_g, 'flow'),
    'LaplaceFlow':  (idx_l, 'flow'),
}

# 7) Evaluate with your original evaluate_models
# --------------------------------------------------------------------------------
metrics = evaluate_models(
    my_models,
    my_indexes,
    X, Y,
    noisedist,
    noisetransform,
    sig_noise_ratio,
    seed=seed
)
print(metrics)

{'GaussianFlow': {'KS_int': 0.012010008096694946, 'CF_RMSE': 0.024647604674100876, 'index': 0}, 'LaplaceFlow': {'KS_int': 0.11251002550125122, 'CF_RMSE': 0.2849721610546112, 'index': 2}, 'noise_distribution': 'Normal'}


In [63]:
from architectures import get_stock_transforms
from csuite import generate_2var_linear, _get_noises
def generate_2var_linear(
    N: int,
    seed: int | None = None,
    intervention_node: int | None = None,
    intervention_fn: callable = None,
    intervention_value: float | None = None,
    return_u: bool = False,
    noise_dists: list[Distribution] = None,
    noise_transforms: list[callable] = None,
):
    """
    2-VAR linear SCM:
      u1; x1 = u1
      u2; x2 = x1 + u1
    noise_dists: list of 2 Distribution objects for u1,u2.
    noise_transform: applied to each sampled noise.
    """
    if noise_dists is None or len(noise_dists) != 2:
        raise ValueError("noise_dists must be a list of length 2")
    u1, u2 = _get_noises(N, 2, noise_dists, noise_transforms, seed)
    x1 = u1.clone() + 1
    if intervention_node == 1:
        x1 = intervention_fn(x1, intervention_value)
    x2 = x1 + u2
    if intervention_node == 2:
        x2 = intervention_fn(x2, intervention_value)
    X = torch.cat([x1, x2], dim=1)
    if return_u:
        return X, torch.cat([u1, u2], dim=1)
    return X

def evaluate_models_2(
    models_dict: dict,
    index_dict: dict,
    sc_fun: callable,
    X: torch.Tensor,
    Y: torch.Tensor,
    noisedist: Distribution,
    noisetransform: callable,
    seed: int = None,
    N_true: int = 10**5,
    intervention_node: int = 1,
    intervention_value: float = 1.0,
) -> dict:
    """
    Evaluates each model by comparing its interventional, counterfactual,
    and paired-difference estimates to the ground-truth SCM (`sc_fun`).
    """
    if seed is not None:
        torch.manual_seed(seed)
    device = X.device
    N, d = X.shape
    N, p = Y.shape

    # --- True paired (Y, Y_cf) via SCM generator ---
    torch.manual_seed(seed)
    X_obs_true, U = sc_fun(
        N_true,
        seed=seed,
        intervention_node=None,
        return_u=True,
        noise_dists=[Normal(0,1)] + [noisedist],
        noise_transforms = [lambda x : x] + [noisetransform]
    )
    X_obs_true = X_obs_true.to(device)
    torch.manual_seed(seed)
    X_cf_true, Ucf = sc_fun(
        N_true,
        seed=seed,
        intervention_node=intervention_node,
        intervention_fn = lambda x,a : x+a,
        intervention_value=intervention_value,
        return_u=True,
        noise_dists=[Normal(0,1)] + [noisedist],
        noise_transforms = [lambda x : x] + [noisetransform]
    )
    X_cf_true = X_cf_true.to(device)
    # check same noise
    assert((U-Ucf).sum()==0)
    
    # Extract paired Y variables (columns 1 onward)
    Y_true = X_obs_true[:, 1:]
    Y_cf_true = X_cf_true[:, 1:]
    Y_dim = Y_true.shape[1]

    # paired-difference
    diff_true = Y_cf_true - Y_true

    results = {}
    with torch.no_grad():
        for name, (flow, _) in models_dict.items():
            # --- Model paired (Y, Y_cf) via sample_do ---
            torch.manual_seed(seed)
            X_model = sample_do(
                flow.to(device),
                index=intervention_node-1,
                intervention_fn=lambda old: old,
                sample_shape=torch.Size([N_true])
            )
            torch.manual_seed(seed)
            X_cf_model = sample_do(
                flow.to(device),
                index=intervention_node-1,
                intervention_fn=lambda old: old + intervention_value,
                sample_shape=torch.Size([N_true])
            )
            Y_model = X_model[:, 1:]
            Y_cf_model = X_cf_model[:, 1:]
            diff_model = Y_cf_model - Y_model
    
            # --- Marginal KS values for each Y dimension ---
            w1_vals = [
                wasserstein1_repeat(diff_model[:, j].cpu(), diff_true[:, j].cpu())
                for j in range(Y_dim)
            ]
    
            # --- Marginal KS values for each Y dimension ---
            ks_vals = [
                ks_statistic(diff_model[:, j].cpu(), diff_true[:, j].cpu())
                for j in range(Y_dim)
            ]
    
            # --- Interventional marginal KS ---
            # Reuse Y_cf_model for marginal KS
            w1_int_vals = [
                wasserstein1_repeat(Y_cf_model[:, j].cpu(), Y_cf_true[:, j].cpu())
                for j in range(Y_dim)
            ]
            ks_int_vals = [
                ks_statistic(Y_cf_model[:, j].cpu(), Y_cf_true[:, j].cpu())
                for j in range(Y_dim)
            ]
    
            # --- Counterfactual RMSE ---
            torch.manual_seed(seed)
            X_obs, U = sc_fun(
                N,
                seed=seed,
                intervention_node=None,
                return_u=True,
                noise_dists=[Normal(0,1)] + [noisedist],
                noise_transforms = [lambda x : x] + [noisetransform]
            )
            X_obs = X_obs.to(device)
            Y_obs = X_obs[:, 1:].to(device)
            torch.manual_seed(seed)
            X_cf, Ucf = sc_fun(
                N,
                seed=seed,
                intervention_node=intervention_node,
                intervention_fn = lambda x,a : x+a,
                intervention_value=intervention_value,
                return_u=True,
                noise_dists=[Normal(0,1)] + [noisedist],
                noise_transforms = [lambda x : x] + [noisetransform]
            )
            X_cf = X_cf.to(device)
            Y_cf = X_cf[:, 1:]
            
            Z_cf = sample_cf(
                flow.to(device),
                x_obs=X_obs,
                index=intervention_node - 1,
                intervention_fn=lambda old: old + intervention_value
            )
            rmse_cf_vals = [
                rmse(Z_cf[:, 1:][:,j].cpu(), Y_cf[:, j].cpu())
                for j in range(Y_dim)
            ]
            results[name] = {
                'KS_CF': ks_vals,
                'KS_int': ks_int_vals,
                'W1_CF': w1_vals,
                'W1_int': w1_int_vals,
                'RMSE_CF': rmse_cf_vals,
                'index': index_dict[name][0]
            }
    return results

In [83]:
metrics = evaluate_models_2(
    my_models,
    my_indexes,
    generate_2var_linear,
    X, Y,
    noisedist,
    noisetransform,
    seed = 0,
    N_true = 10**5,
    intervention_node = 1,
    intervention_value = 1.0,
)

print(metrics)

{'GaussianFlow': {'KS_CF': [1.0], 'KS_int': [0.01184999942779541], 'W1_CF': [tensor(0.0240)], 'W1_int': [tensor(0.0328)], 'RMSE_CF': [0.02396257221698761], 'index': 0}, 'LaplaceFlow': {'KS_CF': [0.7042700052261353], 'KS_int': [0.11078000068664551], 'W1_CF': [tensor(0.8956)], 'W1_int': [tensor(0.5245)], 'RMSE_CF': [0.35234159231185913], 'index': 2}}
