In [None]:
# Mount Google Drive (Colab only) so files stored in Drive are reachable
# under /content/drive.
# NOTE(review): CSV_FILE_PATH in the config cells is a relative path, so it
# does not point into this mount unless the working directory is changed to
# a Drive folder (or the path is made absolute).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import math
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# ============================================================
# 0. CONFIG & REPRODUCIBILITY
# ============================================================

REPRODUCIBILITY_SEED = 42

# Seed every RNG the notebook relies on: Python's `random`, NumPy's legacy
# global generator, and PyTorch's default generator.
random.seed(REPRODUCIBILITY_SEED)
np.random.seed(REPRODUCIBILITY_SEED)
torch.manual_seed(REPRODUCIBILITY_SEED)

# Pick the compute device. On a GPU, also seed CUDA and force deterministic
# cuDNN kernels (with autotuning disabled) so runs are repeatable.
if not torch.cuda.is_available():
    DEVICE = torch.device('cpu')
else:
    torch.cuda.manual_seed_all(REPRODUCIBILITY_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    DEVICE = torch.device('cuda')

# ---- Hyperparameters -------------------------------------------------------
# Data layout
dim_x_features   = 5            # placeholder; IHDP_TimeSeries overwrites it from the data
SEQUENCE_LENGTH  = 30           # rows per sequence
BATCH_SIZE       = 32
# Model size
HIDDEN_DIM       = 128          # GRU hidden units per direction
LATENT_DIM       = 64
# Optimisation
NUM_EPOCHS       = 150
LEARNING_RATE    = 5e-4
# Loss weights
KL_WEIGHT        = 0.001
MMD_WEIGHT       = 1.0
# Treatment construction
WINDOW_SIZE      = 5            # rolling-mean window used to smooth T0
TREATMENT_LAG    = 1            # delay (rows) of the lagged-T0 covariate
# NOTE(review): this notebook repeats this cell with different TREATMENT_LAG
# values; a single parameterised run over the lags would avoid the copies.

# NOTE(review): relative path, resolved against the kernel's working
# directory; unless that is inside /content/drive it does not read from the
# mounted Drive, and IHDP_TimeSeries then falls back to random data.
CSV_FILE_PATH = "arctic_s2s_multivar_2020_2024.csv"

# ============================================================
# 1. MMD UTILITIES
# ============================================================

def gaussian_rbf_matrix(x, y, sigma=1.0):
    """Pairwise Gaussian RBF kernel between the rows of two 2-D tensors.

    K[i, j] = exp(-||x_i - y_j||^2 / (2 * sigma^2 + 1e-12))

    Args:
        x: (n, d) tensor.
        y: (m, d) tensor.
        sigma: kernel bandwidth; the 1e-12 term keeps the denominator non-zero.

    Returns:
        (n, m) tensor with entries in [0, 1].
    """
    x_norm = (x ** 2).sum(1).unsqueeze(1)        # (n, 1)
    y_norm = (y ** 2).sum(1).unsqueeze(0)        # (1, m)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; floating-point cancellation can
    # make this slightly negative, which would push kernel values above 1.
    dists  = (x_norm + y_norm - 2 * (x @ y.t())).clamp_min(0.0)
    return torch.exp(-dists / (2 * sigma**2 + 1e-12))

def compute_mmd_stable(x, y, sigmas=(0.5, 1.0, 2.0)):
    """Multi-bandwidth Gaussian-kernel MMD^2 between two samples.

    Within-sample kernel means exclude the diagonal (unbiased estimate); the
    cross-sample term is the plain mean. The result is averaged over `sigmas`.

    Args:
        x: (n, d) tensor of samples from the first group, or None.
        y: (m, d) tensor of samples from the second group, or None.
        sigmas: kernel bandwidths to average over (default (0.5, 1.0, 2.0)).

    Returns:
        Scalar tensor. A zero tensor on DEVICE when either group is None or
        has fewer than two rows (the unbiased terms divide by n * (n - 1)).

    Raises:
        ValueError: if `sigmas` is empty.

    NOTE(review): for high-dimensional latents (e.g. 64-d z), squared
    distances may be far larger than 2 * sigma^2 at these bandwidths, so the
    kernels saturate near zero and the penalty carries little signal;
    consider larger `sigmas` or a median-distance heuristic.
    """
    if x is None or y is None or x.size(0) <= 1 or y.size(0) <= 1:
        return torch.tensor(0.0, device=DEVICE)
    sigmas = tuple(sigmas)
    if not sigmas:
        raise ValueError("sigmas must contain at least one bandwidth")

    n, m = x.size(0), y.size(0)
    mmd = 0.0
    for sigma in sigmas:
        K_xx = gaussian_rbf_matrix(x, x, sigma)
        K_yy = gaussian_rbf_matrix(y, y, sigma)
        K_xy = gaussian_rbf_matrix(x, y, sigma)
        # Drop self-similarity (the diagonal) for unbiased within-group means.
        sum_xx = (K_xx.sum() - torch.diag(K_xx).sum()) / (n * (n - 1))
        sum_yy = (K_yy.sum() - torch.diag(K_yy).sum()) / (m * (m - 1))
        sum_xy = K_xy.mean()
        mmd += sum_xx + sum_yy - 2.0 * sum_xy
    return mmd / len(sigmas)

# ============================================================
# 2. DATA MODULE
# ============================================================

class IHDP_TimeSeries:
    """Semi-synthetic sequence dataset for individual treatment effect (ITE) evaluation.

    Reads the columns ``uoe``, ``von``, ``total_vel``, ``zos`` and ``sithick``
    from ``csv_path``. If the file is missing, it falls back to standard-normal
    stand-in data. It then synthesises a factual treatment T0 and a
    counterfactual treatment T1 and builds potential outcomes. All arrays are
    cut into non-overlapping sequences of ``sequence_length`` rows; trailing
    rows that do not fill a whole sequence are dropped.

    Arrays set on construction, each shaped (num_seq, sequence_length, k):
        xall      -- standardised covariates [uoe, von, total_vel, T0, lagged T0]
        t_factual -- T0, unscaled (k = 1)
        t_counter -- T1, unscaled (k = 1)
        y_factual -- sithick (k = 1)
        y0_cf     -- the same array object as y_factual
        y1_cf     -- y_factual plus a synthetic effect of (T1 - T0)

    NOTE(review): y_factual is the raw ``sithick`` column and is not generated
    from T0, whereas y1_cf adds an effect driven by (T1 - T0). The factual
    training data never show that effect, so a model of y given (x, t) has
    nothing to learn it from. This likely explains a predicted ITE near zero.
    Generating y_factual from the same response surface as y1_cf
    (y0 = f(x, T0), y1 = f(x, T1)) would make the effect learnable.
    """

    def __init__(self, csv_path, batch_size, sequence_length,
                 treatment_lag=TREATMENT_LAG):
        """Store settings, then immediately load, build and summarise the data.

        Args:
            csv_path: path to the input CSV.
            batch_size: batch size used by ``get_dataloaders``.
            sequence_length: rows per sequence.
            treatment_lag: non-negative delay (rows) of the lagged-T0 covariate.
        """
        self.csv_path         = csv_path
        self.batch_size       = batch_size
        self.sequence_length  = sequence_length
        self.treatment_lag    = treatment_lag
        self._load_and_preprocess_data()

    @staticmethod
    def apply_moving_window(series, window_size):
        """Trailing rolling mean of ``series`` (flattened), returned as an (N, 1) array.

        ``min_periods=1`` averages over the rows available at the start instead
        of emitting NaN.
        """
        return pd.Series(series.flatten()).rolling(
            window=window_size, min_periods=1
        ).mean().values.reshape(-1, 1)

    def _compute_lag(self, T, lag):
        """Return ``T`` delayed by ``lag`` rows along axis 0, zero-filling the first rows.

        ``lag == 0`` returns an unshifted copy. It is handled explicitly because
        ``T[:-0]`` is an empty slice. ``lag >= len(T)`` yields all zeros.

        Raises:
            ValueError: if ``lag`` is negative.
        """
        if lag < 0:
            raise ValueError(f"treatment lag must be non-negative, got {lag}")
        Tlag = np.zeros_like(T)
        if lag == 0:
            Tlag[:] = T
        elif lag < len(T):
            Tlag[lag:] = T[:-lag]
        return Tlag

    def _load_and_preprocess_data(self):
        """Build covariates, treatments and potential outcomes, then print diagnostics.

        Raises:
            ValueError: if there are fewer rows than one sequence needs, or if
                ``treatment_lag`` is negative.
        """
        # Private RNG seeded like the notebook's global NumPy generator. In a
        # freshly seeded kernel it yields the same draws. It never resets or
        # advances the global stream, and repeated instantiation is
        # deterministic.
        rng = np.random.RandomState(REPRODUCIBILITY_SEED)

        if os.path.exists(self.csv_path):
            df = pd.read_csv(self.csv_path)
        else:
            # Surface the fallback; otherwise every downstream number would
            # silently come from random noise instead of the real dataset.
            print(f"WARNING: '{self.csv_path}' not found - using random stand-in data")
            df = pd.DataFrame(
                rng.randn(1621, 5),
                columns=['uoe', 'von', 'total_vel', 'zos', 'sithick']
            )

        x_base = df[['uoe', 'von', 'total_vel']].values
        y_base = df[['sithick']].values
        ssh    = df['zos'].values.reshape(-1, 1)
        vel    = df['total_vel'].values.reshape(-1, 1)
        # Deterministic sinusoid that drives both T0 and the effect size but is
        # not part of the covariates (an unobserved confounder).
        hidden = np.sin(np.linspace(0, 30 * np.pi, len(df))).reshape(-1, 1)

        # --- Factual (control) treatment T0: smoothed SSH + confounder + noise ---
        T0_smooth = self.apply_moving_window(ssh, WINDOW_SIZE)
        T0_np     = T0_smooth + (2.0 * hidden) + rng.normal(0, 0.1, ssh.shape)

        # --- Counterfactual treatment T1: velocity-dependent rescaling of T0 ---
        # sigmoid = 1 / (1 + exp(5 * (vel - v0))) *decreases* as velocity rises:
        #   vel well above its mean -> sigmoid ~ 0 -> T1 ~ 1.0 * T0
        #   vel well below its mean -> sigmoid ~ 1 -> T1 ~ 2.5 * T0
        # Because the factor varies per row, T0 and T1 are not perfectly correlated.
        v0 = np.mean(vel)
        sigmoid = 1 / (1 + np.exp(-(-5.0) * (vel - v0)))
        T1_np = ((1.0 + 1.5 * sigmoid) * T0_np).reshape(-1, 1)

        # --- Lagged factual treatment used as an extra covariate ---
        T0_lag_np = self._compute_lag(T0_np, self.treatment_lag)

        # Covariates: 3 base + T0 + T0_lag = 5 features
        X_RAW = np.concatenate([x_base, T0_np, T0_lag_np], axis=1)

        num_seq = len(df) // self.sequence_length
        if num_seq == 0:
            raise ValueError(
                f"need at least {self.sequence_length} rows for one sequence, "
                f"got {len(df)}"
            )
        limit = num_seq * self.sequence_length

        # DCMVAE sizes its encoder input from this module-level global.
        global dim_x_features
        dim_x_features = X_RAW.shape[1]

        # NOTE(review): the scaler is fit on every sequence, including those
        # that get_dataloaders later assigns to the test split.
        scaler   = StandardScaler()
        X_scaled = scaler.fit_transform(X_RAW[:limit])

        seq_shape = (num_seq, self.sequence_length, 1)
        self.xall      = X_scaled.reshape(num_seq, self.sequence_length, dim_x_features)
        self.t_factual = T0_np[:limit].reshape(seq_shape)
        self.t_counter = T1_np[:limit].reshape(seq_shape)
        self.y_factual = y_base[:limit].reshape(seq_shape)

        # --- Potential outcomes ---
        # Effect size scales with |hidden| and saturates in (T1 - T0) via tanh.
        hidden_seq = hidden[:limit].reshape(seq_shape)
        delta      = -6.0 * np.abs(hidden_seq) * np.tanh(
            2.0 * (self.t_counter - self.t_factual)
        )
        self.y0_cf = self.y_factual
        self.y1_cf = self.y_factual + delta

        # Diagnostics
        ite = self.y1_cf - self.y0_cf
        t_corr = np.corrcoef(T0_np.flatten(), T1_np.flatten())[0, 1]
        print("=" * 55)
        print("DATA DIAGNOSTICS")
        print("=" * 55)
        print(f"Num sequences:      {self.xall.shape[0]}")
        print(f"Sequence length:    {self.xall.shape[1]}")
        print(f"Feature dim:        {self.xall.shape[2]}")
        print(f"Mean |ITE|:         {np.abs(ite).mean():.4f}")
        print(f"Std  ITE:           {ite.std():.4f}")
        print(f"% near-zero ITE:    {(np.abs(ite) < 0.01).mean()*100:.1f}%")
        print(f"T0-T1 correlation:  {t_corr:.4f}  (target: < 1.0)")
        print("=" * 55)

    def get_dataloaders(self):
        """Split sequences 80/20 (seeded) and return ``(train_loader, test_loader)``.

        Train batches are ``(x, t_factual, y_factual)`` only. Test batches add
        the counterfactual treatment and both potential outcomes,
        ``(x, t_factual, t_counter, y0_cf, y1_cf)``, for PEHE evaluation.
        Both loaders use the ``batch_size`` passed to the constructor; only
        the training loader shuffles.
        """
        # Split indices once so every array is indexed identically.
        indices        = np.arange(len(self.xall))
        tr_idx, te_idx = train_test_split(
            indices, test_size=0.2, random_state=REPRODUCIBILITY_SEED
        )

        train_ds = TensorDataset(
            torch.FloatTensor(self.xall[tr_idx]),
            torch.FloatTensor(self.t_factual[tr_idx]),
            torch.FloatTensor(self.y_factual[tr_idx])
        )

        test_ds = TensorDataset(
            torch.FloatTensor(self.xall[te_idx]),
            torch.FloatTensor(self.t_factual[te_idx]),
            torch.FloatTensor(self.t_counter[te_idx]),
            torch.FloatTensor(self.y0_cf[te_idx]),
            torch.FloatTensor(self.y1_cf[te_idx])
        )

        return (
            DataLoader(train_ds, self.batch_size, shuffle=True),
            DataLoader(test_ds,  self.batch_size, shuffle=False)
        )

# ============================================================
# 3. DCMVAE MODEL
# ============================================================

class DCMVAE(nn.Module):

    def __init__(self, use_mmd=True):
        super().__init__()
        self.use_mmd = use_mmd

        # Encoder receives covariates + treatment so latent space
        # learns treatment-dependent representations for MMD balancing
        self.encoder_rnn = nn.GRU(
            dim_x_features + 1, HIDDEN_DIM,
            batch_first=True, bidirectional=True
        )

        self.fc_mu     = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)
        self.fc_logvar = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)

        # Treatment projection: amplifies treatment signal (1 → 16 dims)
        # Preserves T0/T1 values unchanged
        self.t_proj = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU()
        )

        # Outcome head conditioned on latent z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(LATENT_DIM + 16, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, X, t, mode='train'):
        # Concatenate treatment with covariates before encoding
        x_and_t = torch.cat([X, t], dim=-1)         # (B, T, dim_x+1)
        h, _    = self.encoder_rnn(x_and_t)          # (B, T, HIDDEN*2)

        mu     = self.fc_mu(h)                        # (B, T, LATENT)
        logvar = self.fc_logvar(h)                    # (B, T, LATENT)

        std = torch.exp(0.5 * logvar)
        z   = mu + torch.randn_like(mu) * std if mode == 'train' else mu

        # Project treatment and concatenate with latent
        t_emb  = self.t_proj(t)                       # (B, T, 16)
        z_and_t = torch.cat([z, t_emb], dim=-1)       # (B, T, LATENT+16)
        y_pred  = self.outcome_head(z_and_t)           # (B, T, 1)

        return y_pred, mu, logvar, z

# ============================================================
# 4. TRAINING
# ============================================================

def train_model(model, train_loader, t_median):
    """Train ``model`` on factual (x, t, y) batches.

    Per-batch objective:
        10 * MSE(y_pred, y)
        + KL_WEIGHT * kl_scale * min(KL, 1)
        + MMD_WEIGHT * mmd_scale * MMD(z | t <= t_median, z | t > t_median)

    kl_scale ramps linearly from 0 to 1 over the first 30 epochs and mmd_scale
    over the first 50. The MMD term is only added when ``model.use_mmd`` is
    true. Gradients are clipped to norm 1.0. The learning rate is halved after
    15 epochs without improvement of the epoch-mean loss.

    Args:
        model: module returning (y_pred, mu, logvar, z) from
            ``model(x, t, mode='train')`` and exposing ``use_mmd``.
        train_loader: sized iterable of (x, t, y) batches.
        t_median: threshold that splits time steps into two treatment groups
            for the MMD term.

    Raises:
        ValueError: if ``train_loader`` has no batches.
    """
    n_batches = len(train_loader)
    if n_batches == 0:
        # With no batches, the epoch means below would divide by zero.
        raise ValueError("train_loader has no batches")

    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=15
    )

    model.train()

    for epoch in range(NUM_EPOCHS):
        kl_scale  = min(1.0, epoch / 30.0)   # KL warmup
        mmd_scale = min(1.0, epoch / 50.0)   # MMD warmup
        log_epoch = epoch % 30 == 0 or epoch == NUM_EPOCHS - 1
        # Running totals, so the log line reports epoch means for every term,
        # consistent with the epoch-mean total loss.
        totals = {'loss': 0.0, 'recon': 0.0, 'kl': 0.0, 'mmd': 0.0}

        for batch_idx, (x_in, t_fact, y_fact) in enumerate(train_loader):
            x_in, t_fact, y_fact = (
                x_in.to(DEVICE), t_fact.to(DEVICE), y_fact.to(DEVICE)
            )
            optimizer.zero_grad()

            y_pred, mu, logvar, z = model(x_in, t_fact, mode='train')

            # Reconstruction: factual outcome only
            loss_recon = F.mse_loss(y_pred, y_fact)

            # KL to N(0, I), capped at 1.0. Where the cap is active, this term
            # contributes no gradient.
            loss_kl = torch.clamp(
                -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()),
                max=1.0
            )

            # MMD between latent codes of time steps above vs. at/below the
            # treatment median (a proxy "treated" vs. "control" split).
            loss_mmd = torch.tensor(0.0, device=DEVICE)
            if model.use_mmd:
                z_flat  = z.reshape(-1, z.size(-1))
                treated = t_fact.reshape(-1) > t_median
                z0      = z_flat[~treated]
                z1      = z_flat[treated]

                # Group sizes, printed on the first batch of each logged epoch.
                if batch_idx == 0 and log_epoch:
                    print(f"  [MMD] z0={z0.size(0)}, z1={z1.size(0)}")

                loss_mmd = compute_mmd_stable(z0, z1) * mmd_scale

            loss = (10.0 * loss_recon
                    + KL_WEIGHT * kl_scale * loss_kl
                    + MMD_WEIGHT * loss_mmd)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            totals['loss']  += loss.item()
            totals['recon'] += loss_recon.item()
            totals['kl']    += loss_kl.item()
            totals['mmd']   += loss_mmd.item()

        means = {key: value / n_batches for key, value in totals.items()}
        scheduler.step(means['loss'])

        if log_epoch:
            print(f"Epoch {epoch:03d} | loss={means['loss']:.4f} | "
                  f"recon={means['recon']:.4f} "
                  f"kl={means['kl']:.4f} "
                  f"mmd={means['mmd']:.6f}")

# ============================================================
# 5. EVALUATION
# ============================================================

def evaluate(model, test_loader):
    """Compute test PEHE: the root-mean-square error of predicted vs. true ITE.

    For each batch, the model is run in 'eval' mode twice: once with the
    factual treatment T0 and once with the counterfactual T1. The predicted
    effect p1 - p0 is compared element-wise with y1 - y0 at every time step.

    Args:
        model: module called as ``model(x, t, mode='eval')`` whose first return
            value is the outcome prediction.
        test_loader: iterable of (x, t0, t1, y0, y1) batches.

    Returns:
        float PEHE = sqrt(mean((ite_pred - ite_true) ** 2)).

    Raises:
        ValueError: if the loader yields no elements.
    """
    model.eval()
    # Follow the model's own device instead of relying on a global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    sq_err_sum = 0.0
    n_elements = 0

    with torch.no_grad():
        for batch_idx, (x_in, t_f, t_c, y0_gt, y1_gt) in enumerate(test_loader):
            x_in = x_in.to(device)
            t_f  = t_f.to(device)   # T0: factual / control treatment
            t_c  = t_c.to(device)   # T1: counterfactual treatment

            # Predict the outcome under T0 and under T1
            p0, _, _, _ = model(x_in, t_f, mode='eval')
            p1, _, _, _ = model(x_in, t_c, mode='eval')

            ite_pred = p1 - p0
            ite_true = (y1_gt - y0_gt).to(device)

            # ITE diagnostic on first batch
            if batch_idx == 0:
                print(f"  p0 mean={p0.mean().item():.4f}  "
                      f"p1 mean={p1.mean().item():.4f}")
                print(f"  Pred ITE mean={ite_pred.mean().item():.4f}  "
                      f"True ITE mean={ite_true.mean().item():.4f}")

            sq_err = (ite_pred - ite_true) ** 2
            sq_err_sum += torch.sum(sq_err).item()
            # Count the elements actually summed rather than assuming
            # batch_size * SEQUENCE_LENGTH.
            n_elements += sq_err.numel()

    if n_elements == 0:
        raise ValueError("test_loader yielded no elements; cannot compute PEHE")
    return math.sqrt(sq_err_sum / n_elements)

# ============================================================
# 6. ABLATION BENCHMARK
# ============================================================

def run_benchmark():
    """Run the MMD on/off ablation: train, evaluate and print a PEHE table.

    Both configurations share one data split and one pair of loaders. The
    torch RNG is re-seeded before each model is built, so both runs start
    from identical weights and the same RNG stream. PEHE differences then
    reflect the MMD term rather than RNG state left over by the previous run.

    Returns:
        list of (label, test PEHE) tuples, in run order.
    """
    dm = IHDP_TimeSeries(CSV_FILE_PATH, BATCH_SIZE, SEQUENCE_LENGTH)
    train_loader, test_loader = dm.get_dataloaders()
    # Threshold that splits time steps into two treatment groups for MMD.
    # NOTE(review): computed over all sequences, including the test split.
    t_median = float(np.median(dm.t_factual))

    # Full ablation: MMD on/off
    configs = [True, False]
    results = []

    for use_mmd in configs:
        label = "MMD=True " if use_mmd else "MMD=False"
        print(f"\n{'='*55}")
        print(f"Training: {label}")
        print(f"{'='*55}")

        torch.manual_seed(REPRODUCIBILITY_SEED)
        model = DCMVAE(use_mmd=use_mmd).to(DEVICE)
        train_model(model, train_loader, t_median)

        print(f"\nEvaluating {label}...")
        pehe = evaluate(model, test_loader)
        print(f"PEHE: {pehe:.4f}")
        results.append((label, pehe))

    best_pehe = min(score for _, score in results)
    print(f"\n{'='*55}")
    print("FINAL ABLATION RESULTS")
    print(f"{'='*55}")
    print(f"{'Config':<12} | {'Test PEHE'}")
    print("-" * 35)
    for label, pehe in sorted(results, key=lambda r: r[1]):
        marker = " ← best" if pehe == best_pehe else ""
        print(f"{label:<12} | {pehe:.4f}{marker}")

    return results

# ============================================================

# Run the ablation when this code executes as the top-level module
# (a notebook kernel runs cells with __name__ == "__main__").
if __name__ == "__main__":
    run_benchmark()


DATA DIAGNOSTICS
Num sequences:      54
Sequence length:    30
Feature dim:        5
Mean |ITE|:         2.4961
Std  ITE:           3.3482
% near-zero ITE:    12.3%
T0-T1 correlation:  0.9443  (target: < 1.0)

Training: MMD=True 
  [MMD] z0=418, z1=542
Epoch 000 | loss=10.1360 | recon=1.0490 kl=0.0055 mmd=0.000000
  [MMD] z0=439, z1=521
Epoch 030 | loss=9.7046 | recon=0.9419 kl=0.0377 mmd=0.000001
  [MMD] z0=453, z1=507
Epoch 060 | loss=9.8188 | recon=0.9927 kl=0.1189 mmd=0.000001
  [MMD] z0=501, z1=459
Epoch 090 | loss=10.2375 | recon=1.1260 kl=0.1378 mmd=0.000002
  [MMD] z0=482, z1=478
Epoch 120 | loss=9.9621 | recon=1.0302 kl=0.1528 mmd=0.000002
  [MMD] z0=509, z1=451
Epoch 149 | loss=9.8733 | recon=1.0181 kl=0.1580 mmd=0.000001

Evaluating MMD=True ...
  p0 mean=0.0071  p1 mean=0.0075
  Pred ITE mean=0.0004  True ITE mean=0.6843
PEHE: 3.0883

Training: MMD=False
Epoch 000 | loss=10.0251 | recon=1.0152 kl=0.0056 mmd=0.000000
Epoch 030 | loss=10.0002 | recon=1.0277 kl=0.0150 mmd=0.00

In [None]:
import os
import math
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# ============================================================
# 0. CONFIG & REPRODUCIBILITY
# ============================================================

REPRODUCIBILITY_SEED = 42

# Seed every RNG the notebook relies on: Python's `random`, NumPy's legacy
# global generator, and PyTorch's default generator.
random.seed(REPRODUCIBILITY_SEED)
np.random.seed(REPRODUCIBILITY_SEED)
torch.manual_seed(REPRODUCIBILITY_SEED)

# Pick the compute device. On a GPU, also seed CUDA and force deterministic
# cuDNN kernels (with autotuning disabled) so runs are repeatable.
if not torch.cuda.is_available():
    DEVICE = torch.device('cpu')
else:
    torch.cuda.manual_seed_all(REPRODUCIBILITY_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    DEVICE = torch.device('cuda')

# ---- Hyperparameters -------------------------------------------------------
# Data layout
dim_x_features   = 5            # placeholder; IHDP_TimeSeries overwrites it from the data
SEQUENCE_LENGTH  = 30           # rows per sequence
BATCH_SIZE       = 32
# Model size
HIDDEN_DIM       = 128          # GRU hidden units per direction
LATENT_DIM       = 64
# Optimisation
NUM_EPOCHS       = 150
LEARNING_RATE    = 5e-4
# Loss weights
KL_WEIGHT        = 0.001
MMD_WEIGHT       = 1.0
# Treatment construction
WINDOW_SIZE      = 5            # rolling-mean window used to smooth T0
TREATMENT_LAG    = 3            # delay (rows) of the lagged-T0 covariate
# NOTE(review): this notebook repeats this cell with different TREATMENT_LAG
# values; a single parameterised run over the lags would avoid the copies.

# NOTE(review): relative path, resolved against the kernel's working
# directory; unless that is inside /content/drive it does not read from the
# mounted Drive, and IHDP_TimeSeries then falls back to random data.
CSV_FILE_PATH = "arctic_s2s_multivar_2020_2024.csv"

# ============================================================
# 1. MMD UTILITIES
# ============================================================

def gaussian_rbf_matrix(x, y, sigma=1.0):
    """Pairwise Gaussian RBF kernel between the rows of two 2-D tensors.

    K[i, j] = exp(-||x_i - y_j||^2 / (2 * sigma^2 + 1e-12))

    Args:
        x: (n, d) tensor.
        y: (m, d) tensor.
        sigma: kernel bandwidth; the 1e-12 term keeps the denominator non-zero.

    Returns:
        (n, m) tensor with entries in [0, 1].
    """
    x_norm = (x ** 2).sum(1).unsqueeze(1)        # (n, 1)
    y_norm = (y ** 2).sum(1).unsqueeze(0)        # (1, m)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; floating-point cancellation can
    # make this slightly negative, which would push kernel values above 1.
    dists  = (x_norm + y_norm - 2 * (x @ y.t())).clamp_min(0.0)
    return torch.exp(-dists / (2 * sigma**2 + 1e-12))

def compute_mmd_stable(x, y, sigmas=(0.5, 1.0, 2.0)):
    """Multi-bandwidth Gaussian-kernel MMD^2 between two samples.

    Within-sample kernel means exclude the diagonal (unbiased estimate); the
    cross-sample term is the plain mean. The result is averaged over `sigmas`.

    Args:
        x: (n, d) tensor of samples from the first group, or None.
        y: (m, d) tensor of samples from the second group, or None.
        sigmas: kernel bandwidths to average over (default (0.5, 1.0, 2.0)).

    Returns:
        Scalar tensor. A zero tensor on DEVICE when either group is None or
        has fewer than two rows (the unbiased terms divide by n * (n - 1)).

    Raises:
        ValueError: if `sigmas` is empty.

    NOTE(review): for high-dimensional latents (e.g. 64-d z), squared
    distances may be far larger than 2 * sigma^2 at these bandwidths, so the
    kernels saturate near zero and the penalty carries little signal;
    consider larger `sigmas` or a median-distance heuristic.
    """
    if x is None or y is None or x.size(0) <= 1 or y.size(0) <= 1:
        return torch.tensor(0.0, device=DEVICE)
    sigmas = tuple(sigmas)
    if not sigmas:
        raise ValueError("sigmas must contain at least one bandwidth")

    n, m = x.size(0), y.size(0)
    mmd = 0.0
    for sigma in sigmas:
        K_xx = gaussian_rbf_matrix(x, x, sigma)
        K_yy = gaussian_rbf_matrix(y, y, sigma)
        K_xy = gaussian_rbf_matrix(x, y, sigma)
        # Drop self-similarity (the diagonal) for unbiased within-group means.
        sum_xx = (K_xx.sum() - torch.diag(K_xx).sum()) / (n * (n - 1))
        sum_yy = (K_yy.sum() - torch.diag(K_yy).sum()) / (m * (m - 1))
        sum_xy = K_xy.mean()
        mmd += sum_xx + sum_yy - 2.0 * sum_xy
    return mmd / len(sigmas)

# ============================================================
# 2. DATA MODULE
# ============================================================

class IHDP_TimeSeries:
    """Semi-synthetic sequence dataset for individual treatment effect (ITE) evaluation.

    Reads the columns ``uoe``, ``von``, ``total_vel``, ``zos`` and ``sithick``
    from ``csv_path``. If the file is missing, it falls back to standard-normal
    stand-in data. It then synthesises a factual treatment T0 and a
    counterfactual treatment T1 and builds potential outcomes. All arrays are
    cut into non-overlapping sequences of ``sequence_length`` rows; trailing
    rows that do not fill a whole sequence are dropped.

    Arrays set on construction, each shaped (num_seq, sequence_length, k):
        xall      -- standardised covariates [uoe, von, total_vel, T0, lagged T0]
        t_factual -- T0, unscaled (k = 1)
        t_counter -- T1, unscaled (k = 1)
        y_factual -- sithick (k = 1)
        y0_cf     -- the same array object as y_factual
        y1_cf     -- y_factual plus a synthetic effect of (T1 - T0)

    NOTE(review): y_factual is the raw ``sithick`` column and is not generated
    from T0, whereas y1_cf adds an effect driven by (T1 - T0). The factual
    training data never show that effect, so a model of y given (x, t) has
    nothing to learn it from. This likely explains a predicted ITE near zero.
    Generating y_factual from the same response surface as y1_cf
    (y0 = f(x, T0), y1 = f(x, T1)) would make the effect learnable.
    """

    def __init__(self, csv_path, batch_size, sequence_length,
                 treatment_lag=TREATMENT_LAG):
        """Store settings, then immediately load, build and summarise the data.

        Args:
            csv_path: path to the input CSV.
            batch_size: batch size used by ``get_dataloaders``.
            sequence_length: rows per sequence.
            treatment_lag: non-negative delay (rows) of the lagged-T0 covariate.
        """
        self.csv_path         = csv_path
        self.batch_size       = batch_size
        self.sequence_length  = sequence_length
        self.treatment_lag    = treatment_lag
        self._load_and_preprocess_data()

    @staticmethod
    def apply_moving_window(series, window_size):
        """Trailing rolling mean of ``series`` (flattened), returned as an (N, 1) array.

        ``min_periods=1`` averages over the rows available at the start instead
        of emitting NaN.
        """
        return pd.Series(series.flatten()).rolling(
            window=window_size, min_periods=1
        ).mean().values.reshape(-1, 1)

    def _compute_lag(self, T, lag):
        """Return ``T`` delayed by ``lag`` rows along axis 0, zero-filling the first rows.

        ``lag == 0`` returns an unshifted copy. It is handled explicitly because
        ``T[:-0]`` is an empty slice. ``lag >= len(T)`` yields all zeros.

        Raises:
            ValueError: if ``lag`` is negative.
        """
        if lag < 0:
            raise ValueError(f"treatment lag must be non-negative, got {lag}")
        Tlag = np.zeros_like(T)
        if lag == 0:
            Tlag[:] = T
        elif lag < len(T):
            Tlag[lag:] = T[:-lag]
        return Tlag

    def _load_and_preprocess_data(self):
        """Build covariates, treatments and potential outcomes, then print diagnostics.

        Raises:
            ValueError: if there are fewer rows than one sequence needs, or if
                ``treatment_lag`` is negative.
        """
        # Private RNG seeded like the notebook's global NumPy generator. In a
        # freshly seeded kernel it yields the same draws. It never resets or
        # advances the global stream, and repeated instantiation is
        # deterministic.
        rng = np.random.RandomState(REPRODUCIBILITY_SEED)

        if os.path.exists(self.csv_path):
            df = pd.read_csv(self.csv_path)
        else:
            # Surface the fallback; otherwise every downstream number would
            # silently come from random noise instead of the real dataset.
            print(f"WARNING: '{self.csv_path}' not found - using random stand-in data")
            df = pd.DataFrame(
                rng.randn(1621, 5),
                columns=['uoe', 'von', 'total_vel', 'zos', 'sithick']
            )

        x_base = df[['uoe', 'von', 'total_vel']].values
        y_base = df[['sithick']].values
        ssh    = df['zos'].values.reshape(-1, 1)
        vel    = df['total_vel'].values.reshape(-1, 1)
        # Deterministic sinusoid that drives both T0 and the effect size but is
        # not part of the covariates (an unobserved confounder).
        hidden = np.sin(np.linspace(0, 30 * np.pi, len(df))).reshape(-1, 1)

        # --- Factual (control) treatment T0: smoothed SSH + confounder + noise ---
        T0_smooth = self.apply_moving_window(ssh, WINDOW_SIZE)
        T0_np     = T0_smooth + (2.0 * hidden) + rng.normal(0, 0.1, ssh.shape)

        # --- Counterfactual treatment T1: velocity-dependent rescaling of T0 ---
        # sigmoid = 1 / (1 + exp(5 * (vel - v0))) *decreases* as velocity rises:
        #   vel well above its mean -> sigmoid ~ 0 -> T1 ~ 1.0 * T0
        #   vel well below its mean -> sigmoid ~ 1 -> T1 ~ 2.5 * T0
        # Because the factor varies per row, T0 and T1 are not perfectly correlated.
        v0 = np.mean(vel)
        sigmoid = 1 / (1 + np.exp(-(-5.0) * (vel - v0)))
        T1_np = ((1.0 + 1.5 * sigmoid) * T0_np).reshape(-1, 1)

        # --- Lagged factual treatment used as an extra covariate ---
        T0_lag_np = self._compute_lag(T0_np, self.treatment_lag)

        # Covariates: 3 base + T0 + T0_lag = 5 features
        X_RAW = np.concatenate([x_base, T0_np, T0_lag_np], axis=1)

        num_seq = len(df) // self.sequence_length
        if num_seq == 0:
            raise ValueError(
                f"need at least {self.sequence_length} rows for one sequence, "
                f"got {len(df)}"
            )
        limit = num_seq * self.sequence_length

        # DCMVAE sizes its encoder input from this module-level global.
        global dim_x_features
        dim_x_features = X_RAW.shape[1]

        # NOTE(review): the scaler is fit on every sequence, including those
        # that get_dataloaders later assigns to the test split.
        scaler   = StandardScaler()
        X_scaled = scaler.fit_transform(X_RAW[:limit])

        seq_shape = (num_seq, self.sequence_length, 1)
        self.xall      = X_scaled.reshape(num_seq, self.sequence_length, dim_x_features)
        self.t_factual = T0_np[:limit].reshape(seq_shape)
        self.t_counter = T1_np[:limit].reshape(seq_shape)
        self.y_factual = y_base[:limit].reshape(seq_shape)

        # --- Potential outcomes ---
        # Effect size scales with |hidden| and saturates in (T1 - T0) via tanh.
        hidden_seq = hidden[:limit].reshape(seq_shape)
        delta      = -6.0 * np.abs(hidden_seq) * np.tanh(
            2.0 * (self.t_counter - self.t_factual)
        )
        self.y0_cf = self.y_factual
        self.y1_cf = self.y_factual + delta

        # Diagnostics
        ite = self.y1_cf - self.y0_cf
        t_corr = np.corrcoef(T0_np.flatten(), T1_np.flatten())[0, 1]
        print("=" * 55)
        print("DATA DIAGNOSTICS")
        print("=" * 55)
        print(f"Num sequences:      {self.xall.shape[0]}")
        print(f"Sequence length:    {self.xall.shape[1]}")
        print(f"Feature dim:        {self.xall.shape[2]}")
        print(f"Mean |ITE|:         {np.abs(ite).mean():.4f}")
        print(f"Std  ITE:           {ite.std():.4f}")
        print(f"% near-zero ITE:    {(np.abs(ite) < 0.01).mean()*100:.1f}%")
        print(f"T0-T1 correlation:  {t_corr:.4f}  (target: < 1.0)")
        print("=" * 55)

    def get_dataloaders(self):
        """Split sequences 80/20 (seeded) and return ``(train_loader, test_loader)``.

        Train batches are ``(x, t_factual, y_factual)`` only. Test batches add
        the counterfactual treatment and both potential outcomes,
        ``(x, t_factual, t_counter, y0_cf, y1_cf)``, for PEHE evaluation.
        Both loaders use the ``batch_size`` passed to the constructor; only
        the training loader shuffles.
        """
        # Split indices once so every array is indexed identically.
        indices        = np.arange(len(self.xall))
        tr_idx, te_idx = train_test_split(
            indices, test_size=0.2, random_state=REPRODUCIBILITY_SEED
        )

        train_ds = TensorDataset(
            torch.FloatTensor(self.xall[tr_idx]),
            torch.FloatTensor(self.t_factual[tr_idx]),
            torch.FloatTensor(self.y_factual[tr_idx])
        )

        test_ds = TensorDataset(
            torch.FloatTensor(self.xall[te_idx]),
            torch.FloatTensor(self.t_factual[te_idx]),
            torch.FloatTensor(self.t_counter[te_idx]),
            torch.FloatTensor(self.y0_cf[te_idx]),
            torch.FloatTensor(self.y1_cf[te_idx])
        )

        return (
            DataLoader(train_ds, self.batch_size, shuffle=True),
            DataLoader(test_ds,  self.batch_size, shuffle=False)
        )

# ============================================================
# 3. DCMVAE MODEL
# ============================================================

class DCMVAE(nn.Module):

    def __init__(self, use_mmd=True):
        super().__init__()
        self.use_mmd = use_mmd

        # Encoder receives covariates + treatment so latent space
        # learns treatment-dependent representations for MMD balancing
        self.encoder_rnn = nn.GRU(
            dim_x_features + 1, HIDDEN_DIM,
            batch_first=True, bidirectional=True
        )

        self.fc_mu     = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)
        self.fc_logvar = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)

        # Treatment projection: amplifies treatment signal (1 → 16 dims)
        # Preserves T0/T1 values unchanged
        self.t_proj = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU()
        )

        # Outcome head conditioned on latent z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(LATENT_DIM + 16, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, X, t, mode='train'):
        # Concatenate treatment with covariates before encoding
        x_and_t = torch.cat([X, t], dim=-1)         # (B, T, dim_x+1)
        h, _    = self.encoder_rnn(x_and_t)          # (B, T, HIDDEN*2)

        mu     = self.fc_mu(h)                        # (B, T, LATENT)
        logvar = self.fc_logvar(h)                    # (B, T, LATENT)

        std = torch.exp(0.5 * logvar)
        z   = mu + torch.randn_like(mu) * std if mode == 'train' else mu

        # Project treatment and concatenate with latent
        t_emb  = self.t_proj(t)                       # (B, T, 16)
        z_and_t = torch.cat([z, t_emb], dim=-1)       # (B, T, LATENT+16)
        y_pred  = self.outcome_head(z_and_t)           # (B, T, 1)

        return y_pred, mu, logvar, z

# ============================================================
# 4. TRAINING
# ============================================================

def train_model(model, train_loader, t_median):
    """Train ``model`` on factual (x, t, y) batches.

    Per-batch objective:
        10 * MSE(y_pred, y)
        + KL_WEIGHT * kl_scale * min(KL, 1)
        + MMD_WEIGHT * mmd_scale * MMD(z | t <= t_median, z | t > t_median)

    kl_scale ramps linearly from 0 to 1 over the first 30 epochs and mmd_scale
    over the first 50. The MMD term is only added when ``model.use_mmd`` is
    true. Gradients are clipped to norm 1.0. The learning rate is halved after
    15 epochs without improvement of the epoch-mean loss.

    Args:
        model: module returning (y_pred, mu, logvar, z) from
            ``model(x, t, mode='train')`` and exposing ``use_mmd``.
        train_loader: sized iterable of (x, t, y) batches.
        t_median: threshold that splits time steps into two treatment groups
            for the MMD term.

    Raises:
        ValueError: if ``train_loader`` has no batches.
    """
    n_batches = len(train_loader)
    if n_batches == 0:
        # With no batches, the epoch means below would divide by zero.
        raise ValueError("train_loader has no batches")

    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=15
    )

    model.train()

    for epoch in range(NUM_EPOCHS):
        kl_scale  = min(1.0, epoch / 30.0)   # KL warmup
        mmd_scale = min(1.0, epoch / 50.0)   # MMD warmup
        log_epoch = epoch % 30 == 0 or epoch == NUM_EPOCHS - 1
        # Running totals, so the log line reports epoch means for every term,
        # consistent with the epoch-mean total loss.
        totals = {'loss': 0.0, 'recon': 0.0, 'kl': 0.0, 'mmd': 0.0}

        for batch_idx, (x_in, t_fact, y_fact) in enumerate(train_loader):
            x_in, t_fact, y_fact = (
                x_in.to(DEVICE), t_fact.to(DEVICE), y_fact.to(DEVICE)
            )
            optimizer.zero_grad()

            y_pred, mu, logvar, z = model(x_in, t_fact, mode='train')

            # Reconstruction: factual outcome only
            loss_recon = F.mse_loss(y_pred, y_fact)

            # KL to N(0, I), capped at 1.0. Where the cap is active, this term
            # contributes no gradient.
            loss_kl = torch.clamp(
                -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()),
                max=1.0
            )

            # MMD between latent codes of time steps above vs. at/below the
            # treatment median (a proxy "treated" vs. "control" split).
            loss_mmd = torch.tensor(0.0, device=DEVICE)
            if model.use_mmd:
                z_flat  = z.reshape(-1, z.size(-1))
                treated = t_fact.reshape(-1) > t_median
                z0      = z_flat[~treated]
                z1      = z_flat[treated]

                # Group sizes, printed on the first batch of each logged epoch.
                if batch_idx == 0 and log_epoch:
                    print(f"  [MMD] z0={z0.size(0)}, z1={z1.size(0)}")

                loss_mmd = compute_mmd_stable(z0, z1) * mmd_scale

            loss = (10.0 * loss_recon
                    + KL_WEIGHT * kl_scale * loss_kl
                    + MMD_WEIGHT * loss_mmd)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            totals['loss']  += loss.item()
            totals['recon'] += loss_recon.item()
            totals['kl']    += loss_kl.item()
            totals['mmd']   += loss_mmd.item()

        means = {key: value / n_batches for key, value in totals.items()}
        scheduler.step(means['loss'])

        if log_epoch:
            print(f"Epoch {epoch:03d} | loss={means['loss']:.4f} | "
                  f"recon={means['recon']:.4f} "
                  f"kl={means['kl']:.4f} "
                  f"mmd={means['mmd']:.6f}")

# ============================================================
# 5. EVALUATION
# ============================================================

def evaluate(model, test_loader):
    """Compute test PEHE: the root-mean-square error of predicted vs. true ITE.

    For each batch, the model is run in 'eval' mode twice: once with the
    factual treatment T0 and once with the counterfactual T1. The predicted
    effect p1 - p0 is compared element-wise with y1 - y0 at every time step.

    Args:
        model: module called as ``model(x, t, mode='eval')`` whose first return
            value is the outcome prediction.
        test_loader: iterable of (x, t0, t1, y0, y1) batches.

    Returns:
        float PEHE = sqrt(mean((ite_pred - ite_true) ** 2)).

    Raises:
        ValueError: if the loader yields no elements.
    """
    model.eval()
    # Follow the model's own device instead of relying on a global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    sq_err_sum = 0.0
    n_elements = 0

    with torch.no_grad():
        for batch_idx, (x_in, t_f, t_c, y0_gt, y1_gt) in enumerate(test_loader):
            x_in = x_in.to(device)
            t_f  = t_f.to(device)   # T0: factual / control treatment
            t_c  = t_c.to(device)   # T1: counterfactual treatment

            # Predict the outcome under T0 and under T1
            p0, _, _, _ = model(x_in, t_f, mode='eval')
            p1, _, _, _ = model(x_in, t_c, mode='eval')

            ite_pred = p1 - p0
            ite_true = (y1_gt - y0_gt).to(device)

            # ITE diagnostic on first batch
            if batch_idx == 0:
                print(f"  p0 mean={p0.mean().item():.4f}  "
                      f"p1 mean={p1.mean().item():.4f}")
                print(f"  Pred ITE mean={ite_pred.mean().item():.4f}  "
                      f"True ITE mean={ite_true.mean().item():.4f}")

            sq_err = (ite_pred - ite_true) ** 2
            sq_err_sum += torch.sum(sq_err).item()
            # Count the elements actually summed rather than assuming
            # batch_size * SEQUENCE_LENGTH.
            n_elements += sq_err.numel()

    if n_elements == 0:
        raise ValueError("test_loader yielded no elements; cannot compute PEHE")
    return math.sqrt(sq_err_sum / n_elements)

# ============================================================
# 6. ABLATION BENCHMARK
# ============================================================

def run_benchmark():
    """Run the MMD on/off ablation: train, evaluate and print a PEHE table.

    Both configurations share one data split and one pair of loaders. The
    torch RNG is re-seeded before each model is built, so both runs start
    from identical weights and the same RNG stream. PEHE differences then
    reflect the MMD term rather than RNG state left over by the previous run.

    Returns:
        list of (label, test PEHE) tuples, in run order.
    """
    dm = IHDP_TimeSeries(CSV_FILE_PATH, BATCH_SIZE, SEQUENCE_LENGTH)
    train_loader, test_loader = dm.get_dataloaders()
    # Threshold that splits time steps into two treatment groups for MMD.
    # NOTE(review): computed over all sequences, including the test split.
    t_median = float(np.median(dm.t_factual))

    # Full ablation: MMD on/off
    configs = [True, False]
    results = []

    for use_mmd in configs:
        label = "MMD=True " if use_mmd else "MMD=False"
        print(f"\n{'='*55}")
        print(f"Training: {label}")
        print(f"{'='*55}")

        torch.manual_seed(REPRODUCIBILITY_SEED)
        model = DCMVAE(use_mmd=use_mmd).to(DEVICE)
        train_model(model, train_loader, t_median)

        print(f"\nEvaluating {label}...")
        pehe = evaluate(model, test_loader)
        print(f"PEHE: {pehe:.4f}")
        results.append((label, pehe))

    best_pehe = min(score for _, score in results)
    print(f"\n{'='*55}")
    print("FINAL ABLATION RESULTS")
    print(f"{'='*55}")
    print(f"{'Config':<12} | {'Test PEHE'}")
    print("-" * 35)
    for label, pehe in sorted(results, key=lambda r: r[1]):
        marker = " ← best" if pehe == best_pehe else ""
        print(f"{label:<12} | {pehe:.4f}{marker}")

    return results

# ============================================================

# Run the ablation when this code executes as the top-level module
# (a notebook kernel runs cells with __name__ == "__main__").
if __name__ == "__main__":
    run_benchmark()


DATA DIAGNOSTICS
Num sequences:      54
Sequence length:    30
Feature dim:        5
Mean |ITE|:         2.4961
Std  ITE:           3.3482
% near-zero ITE:    12.3%
T0-T1 correlation:  0.9443  (target: < 1.0)

Training: MMD=True 
  [MMD] z0=418, z1=542
Epoch 000 | loss=10.1363 | recon=1.0491 kl=0.0055 mmd=0.000000
  [MMD] z0=439, z1=521
Epoch 030 | loss=9.7016 | recon=0.9416 kl=0.0388 mmd=0.000001
  [MMD] z0=453, z1=507
Epoch 060 | loss=9.8095 | recon=0.9924 kl=0.1246 mmd=0.000001
  [MMD] z0=501, z1=459
Epoch 090 | loss=10.2231 | recon=1.1245 kl=0.1293 mmd=0.000002
  [MMD] z0=482, z1=478
Epoch 120 | loss=9.9424 | recon=1.0275 kl=0.1523 mmd=0.000001
  [MMD] z0=509, z1=451
Epoch 149 | loss=9.8597 | recon=1.0165 kl=0.1527 mmd=0.000001

Evaluating MMD=True ...
  p0 mean=0.0081  p1 mean=0.0080
  Pred ITE mean=-0.0001  True ITE mean=0.6843
PEHE: 3.0810

Training: MMD=False
Epoch 000 | loss=10.0249 | recon=1.0152 kl=0.0055 mmd=0.000000
Epoch 030 | loss=9.9999 | recon=1.0276 kl=0.0159 mmd=0.00

In [None]:
import os
import math
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# ============================================================
# 0. CONFIG & REPRODUCIBILITY
# ============================================================

# Reproducibility: seed Python, NumPy and PyTorch from a single constant.
REPRODUCIBILITY_SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(REPRODUCIBILITY_SEED)
del _seed_fn

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(REPRODUCIBILITY_SEED)
    # Prefer repeatable cuDNN kernels over autotuned ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ---- Data / model / optimisation hyperparameters ----
dim_x_features = 5           # placeholder; overwritten with the real feature count when data loads
SEQUENCE_LENGTH = 30         # timesteps per sequence
BATCH_SIZE = 32
HIDDEN_DIM = 128             # GRU hidden size per direction
LATENT_DIM = 64
NUM_EPOCHS = 150
LEARNING_RATE = 5e-4
KL_WEIGHT = 0.001
MMD_WEIGHT = 1.0
WINDOW_SIZE = 5              # rolling-mean window applied to SSH (zos)
TREATMENT_LAG = 6            # shift (in steps) for the lagged-T0 covariate

CSV_FILE_PATH = "arctic_s2s_multivar_2020_2024.csv"

# ============================================================
# 1. MMD UTILITIES
# ============================================================

def gaussian_rbf_matrix(x, y, sigma=1.0):
    """Gaussian RBF kernel matrix K[i, j] = exp(-||x_i - y_j||^2 / (2 * sigma^2)).

    Args:
        x: (n, d) tensor.
        y: (m, d) tensor.
        sigma: kernel bandwidth.

    Returns:
        (n, m) tensor of kernel values in [0, 1].
    """
    x_norm = (x ** 2).sum(1).unsqueeze(1)            # (n, 1)
    y_norm = (y ** 2).sum(1).unsqueeze(0)            # (1, m)
    # Expanded ||x - y||^2 can dip slightly below zero from floating-point
    # cancellation, which would yield kernel values > 1; clamp it.
    dists = (x_norm + y_norm - 2 * (x @ y.t())).clamp(min=0.0)
    return torch.exp(-dists / (2 * sigma**2 + 1e-12))


def compute_mmd_stable(x, y, sigmas=(0.5, 1.0, 2.0)):
    """Unbiased squared MMD between samples x and y, averaged over RBF bandwidths.

    Within-group kernel means exclude the diagonal (unbiased estimator), so the
    result can be slightly negative when the two samples are very similar.

    Args:
        x: (n, d) tensor, or None.
        y: (m, d) tensor, or None.
        sigmas: non-empty sequence of RBF bandwidths to average over.

    Returns:
        Scalar tensor. A zero tensor is returned when either input is None or
        has fewer than two rows (the unbiased estimator is undefined there).

    Raises:
        ValueError: if `sigmas` is empty.

    NOTE(review): with high-dimensional inputs whose pairwise distances are
    large relative to these bandwidths, every off-diagonal kernel value is ~0
    and the MMD collapses to ~0 (consistent with the ~1e-6 values logged during
    training). Consider scaling `sigmas` to the data, e.g. a median-distance
    heuristic.
    """
    if x is None or y is None or x.size(0) <= 1 or y.size(0) <= 1:
        ref = x if x is not None else y
        device = ref.device if ref is not None else DEVICE
        return torch.tensor(0.0, device=device)
    if len(sigmas) == 0:
        raise ValueError("sigmas must contain at least one bandwidth")

    n, m = x.size(0), y.size(0)
    mmd = 0.0
    for sigma in sigmas:
        K_xx = gaussian_rbf_matrix(x, x, sigma)
        K_yy = gaussian_rbf_matrix(y, y, sigma)
        K_xy = gaussian_rbf_matrix(x, y, sigma)
        sum_xx = (K_xx.sum() - torch.diag(K_xx).sum()) / (n * (n - 1))
        sum_yy = (K_yy.sum() - torch.diag(K_yy).sum()) / (m * (m - 1))
        sum_xy = K_xy.mean()
        mmd += sum_xx + sum_yy - 2.0 * sum_xy
    return mmd / len(sigmas)

# ============================================================
# 2. DATA MODULE
# ============================================================

class IHDP_TimeSeries:
    """Semi-synthetic sequential causal dataset built from the Arctic S2S CSV.

    Loads `csv_path` (or, if that file does not exist, a 1621x5 standard-normal
    placeholder frame with the same column names) and constructs:

    * covariates X: uoe, von, total_vel, T0 and T0 lagged by `treatment_lag`,
      standardized and cut into non-overlapping sequences of `sequence_length`;
    * control treatment T0: rolling-mean zos + 2*hidden + N(0, 0.1) noise, where
      `hidden` is a sine wave that is NOT included in X;
    * treated treatment T1: a velocity-dependent rescaling of T0;
    * potential outcomes: Y0 = sithick and Y1 = Y0 + delta, where delta depends
      on |hidden| and on (T1 - T0).

    Trailing rows that do not fill a whole sequence are dropped.
    """

    def __init__(self, csv_path, batch_size, sequence_length,
                 treatment_lag=TREATMENT_LAG):
        """
        Args:
            csv_path: path to the CSV with columns uoe, von, total_vel, zos, sithick.
            batch_size: batch size for the DataLoaders from get_dataloaders().
            sequence_length: timesteps per sequence.
            treatment_lag: non-negative shift (in rows) for the lagged-T0 covariate.
        """
        self.csv_path         = csv_path
        self.batch_size       = batch_size
        self.sequence_length  = sequence_length
        self.treatment_lag    = treatment_lag
        self._load_and_preprocess_data()

    @staticmethod
    def apply_moving_window(series, window_size):
        """Trailing rolling mean over `window_size` samples, returned as an (N, 1) column.

        The first window_size-1 entries average over the samples available so far
        (min_periods=1), so no NaNs are produced.
        """
        return pd.Series(series.flatten()).rolling(
            window=window_size, min_periods=1
        ).mean().values.reshape(-1, 1)

    def _compute_lag(self, T, lag):
        """Shift `T` forward by `lag` rows along axis 0, zero-filling the first `lag` rows.

        lag == 0 returns an unshifted copy. (The slice `T[:-0]` is empty, so the
        generic path would raise a broadcast error.) lag >= len(T) yields all
        zeros.

        Raises:
            ValueError: if lag is negative.
        """
        if lag < 0:
            raise ValueError(f"treatment lag must be non-negative, got {lag}")
        if lag == 0:
            return T.copy()
        Tlag = np.zeros_like(T)
        Tlag[lag:] = T[:-lag]
        return Tlag

    def _load_and_preprocess_data(self):
        """Load the CSV, synthesize treatments/outcomes, and populate the arrays.

        Sets self.xall (num_seq, sequence_length, n_features) and self.t_factual,
        self.t_counter, self.y_factual, self.y0_cf, self.y1_cf, each of shape
        (num_seq, sequence_length, 1). It also overwrites the module-level
        `dim_x_features` and prints data diagnostics.

        Raises:
            ValueError: if there are fewer rows than `sequence_length`.
        """
        if os.path.exists(self.csv_path):
            df = pd.read_csv(self.csv_path)
        else:
            # Placeholder data so the pipeline still runs without the CSV.
            df = pd.DataFrame(
                np.random.randn(1621, 5),
                columns=['uoe', 'von', 'total_vel', 'zos', 'sithick']
            )

        x_base = df[['uoe', 'von', 'total_vel']].values
        y_base = df[['sithick']].values
        ssh    = df['zos'].values.reshape(-1, 1)
        vel    = df['total_vel'].values.reshape(-1, 1)
        # Unobserved driver (15 sine periods over the record). It enters T0 and
        # the treatment effect but is never given to the model as a covariate.
        hidden = np.sin(np.linspace(0, 30 * np.pi, len(df))).reshape(-1, 1)

        # --- Control treatment T0: smoothed SSH + hidden driver + small noise ---
        T0_smooth = self.apply_moving_window(ssh, WINDOW_SIZE)
        T0_np     = T0_smooth + (2.0 * hidden) + np.random.normal(0, 0.1, ssh.shape)

        # Re-seeds the global NumPy RNG; any later np.random draw (here or in
        # other cells) restarts from the seed.
        np.random.seed(REPRODUCIBILITY_SEED)

        # --- Treated treatment T1: velocity-regime-dependent rescaling of T0 ---
        # `regime` is a DECREASING sigmoid of velocity (slope 5 around the mean):
        #   velocity well below mean -> regime ~ 1   -> T1 ~ 2.5  * T0
        #   velocity at the mean     -> regime = 0.5 -> T1 = 1.75 * T0
        #   velocity well above mean -> regime ~ 0   -> T1 ~ 1.0  * T0
        v0 = np.mean(vel)
        regime = 1 / (1 + np.exp(5.0 * (vel - v0)))
        T1_np = ((1.0 + 1.5 * regime) * T0_np).reshape(-1, 1)

        # Lagged copy of T0, used as an extra covariate.
        T0_lag_np = self._compute_lag(T0_np, self.treatment_lag)

        # Covariates: 3 base + T0 + T0_lag = 5 features
        X_RAW = np.concatenate([x_base, T0_np, T0_lag_np], axis=1)

        num_seq = len(df) // self.sequence_length
        if num_seq == 0:
            raise ValueError(
                f"need at least {self.sequence_length} rows to build one "
                f"sequence, got {len(df)}"
            )
        limit = num_seq * self.sequence_length

        # The model reads this module-level value to size its encoder input.
        global dim_x_features
        dim_x_features = X_RAW.shape[1]

        # NOTE(review): the scaler is fit on all sequences before the
        # train/test split, so test-set statistics leak into the scaling.
        scaler   = StandardScaler()
        X_scaled = scaler.fit_transform(X_RAW[:limit])

        def _to_seq(arr):
            # Trim to whole sequences and reshape to (num_seq, seq_len, features).
            return arr[:limit].reshape(num_seq, self.sequence_length, -1)

        self.xall      = _to_seq(X_scaled)
        self.t_factual = _to_seq(T0_np)
        self.t_counter = _to_seq(T1_np)
        self.y_factual = _to_seq(y_base)

        # --- Potential outcomes: the effect depends on |hidden| and T1 - T0 ---
        hidden_seq = _to_seq(hidden)
        delta      = -6.0 * np.abs(hidden_seq) * np.tanh(
            2.0 * (self.t_counter - self.t_factual)
        )
        self.y0_cf = self.y_factual
        self.y1_cf = self.y_factual + delta

        # Diagnostics
        ite = self.y1_cf - self.y0_cf
        t_corr = np.corrcoef(T0_np.flatten(), T1_np.flatten())[0, 1]
        print("=" * 55)
        print("DATA DIAGNOSTICS")
        print("=" * 55)
        print(f"Num sequences:      {self.xall.shape[0]}")
        print(f"Sequence length:    {self.xall.shape[1]}")
        print(f"Feature dim:        {self.xall.shape[2]}")
        print(f"Mean |ITE|:         {np.abs(ite).mean():.4f}")
        print(f"Std  ITE:           {ite.std():.4f}")
        print(f"% near-zero ITE:    {(np.abs(ite) < 0.01).mean()*100:.1f}%")
        print(f"T0-T1 correlation:  {t_corr:.4f}  (target: < 1.0)")
        print("=" * 55)

    def get_dataloaders(self):
        """Split sequences 80/20 (seeded) and return (train_loader, test_loader).

        The train loader is shuffled and yields only factual data:
        (x, t_factual, y_factual). The test loader is not shuffled and also
        yields the counterfactual treatment and both potential outcomes for
        PEHE evaluation: (x, t_factual, t_counter, y0, y1). Both loaders use
        `self.batch_size`.
        """
        indices = np.arange(len(self.xall))
        # Split indices once so every array is sliced with the same rows.
        tr_idx, te_idx = train_test_split(
            indices, test_size=0.2, random_state=REPRODUCIBILITY_SEED
        )

        def _dataset(idx, *arrays):
            return TensorDataset(*(torch.FloatTensor(a[idx]) for a in arrays))

        train_ds = _dataset(tr_idx, self.xall, self.t_factual, self.y_factual)
        test_ds = _dataset(te_idx, self.xall, self.t_factual, self.t_counter,
                           self.y0_cf, self.y1_cf)

        return (
            DataLoader(train_ds, self.batch_size, shuffle=True),
            DataLoader(test_ds,  self.batch_size, shuffle=False),
        )

# ============================================================
# 3. DCMVAE MODEL
# ============================================================

class DCMVAE(nn.Module):
    """Sequential VAE-style outcome model for continuous treatments.

    A bidirectional GRU encodes [covariates, treatment] at every timestep into a
    diagonal-Gaussian latent (mu, logvar). The outcome head predicts y at each
    step from the latent z concatenated with a learned embedding of the
    treatment. The treatment therefore enters twice: in the encoder input and
    in the outcome head.

    Args:
        use_mmd: flag stored on the module for the training loop, which uses it
            to toggle an MMD balancing term. forward() does not use it.
    """

    # Width of the learned treatment embedding fed to the outcome head.
    _T_EMB_DIM = 16

    def __init__(self, use_mmd=True):
        super().__init__()
        self.use_mmd = use_mmd

        # Encoder input width reads the module-level `dim_x_features` at
        # construction time. The encoder is bidirectional, so the code at step k
        # also summarizes later steps of the same sequence.
        encoder_in = dim_x_features + 1
        self.encoder_rnn = nn.GRU(encoder_in, HIDDEN_DIM,
                                  batch_first=True, bidirectional=True)

        encoder_out = 2 * HIDDEN_DIM
        self.fc_mu = nn.Linear(encoder_out, LATENT_DIM)
        self.fc_logvar = nn.Linear(encoder_out, LATENT_DIM)

        # Learned treatment embedding: Linear(1 -> 16) followed by ReLU on the
        # raw treatment value. This is not an identity map, and ReLU zeroes
        # negative pre-activations.
        self.t_proj = nn.Sequential(nn.Linear(1, self._T_EMB_DIM), nn.ReLU())

        head_in = LATENT_DIM + self._T_EMB_DIM
        self.outcome_head = nn.Sequential(
            nn.Linear(head_in, 128), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, X, t, mode='train'):
        """Encode (X, t) and predict the outcome at every timestep.

        Args:
            X: covariates, (batch, seq, dim_x_features).
            t: treatment, (batch, seq, 1).
            mode: 'train' samples z = mu + eps * exp(0.5 * logvar)
                (reparameterization). Any other value uses z = mu.

        Returns:
            Tuple (y_pred, mu, logvar, z). y_pred has shape (batch, seq, 1); the
            other three have shape (batch, seq, LATENT_DIM).
        """
        encoded, _ = self.encoder_rnn(torch.cat((X, t), dim=-1))
        mu, logvar = self.fc_mu(encoded), self.fc_logvar(encoded)

        if mode == 'train':
            z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        else:
            z = mu

        head_input = torch.cat((z, self.t_proj(t)), dim=-1)
        return self.outcome_head(head_input), mu, logvar, z

# ============================================================
# 4. TRAINING
# ============================================================

def train_model(model, train_loader, t_median):
    """Fit `model` on factual (x, t, y) batches.

    Per-batch objective:
        10 * MSE(y_pred, y) + KL_WEIGHT * kl_scale * KL + MMD_WEIGHT * mmd_scale * MMD
    kl_scale ramps linearly from 0 to 1 over the first 30 epochs, and mmd_scale
    over the first 50. The MMD term is used only when `model.use_mmd` is set. It
    compares latent codes of timesteps with treatment <= t_median (control)
    against those with treatment > t_median (treated).

    Every 30 epochs and on the last epoch, it prints the epoch-mean total loss
    and the epoch-mean of each component. The printed MMD includes mmd_scale.

    Args:
        model: module called as model(x, t, mode='train') that returns
            (y_pred, mu, logvar, z) and exposes a boolean `use_mmd`.
        train_loader: non-empty loader of (x, t_factual, y_factual) batches.
        t_median: scalar treatment threshold for the MMD control/treated split.

    Raises:
        ValueError: if `train_loader` yields no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; nothing to train on")

    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    # Halve the LR when the epoch-mean training loss stalls for 15 epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=15
    )

    model.train()

    for epoch in range(NUM_EPOCHS):
        kl_scale  = min(1.0, epoch / 30.0)   # KL warmup
        mmd_scale = min(1.0, epoch / 50.0)   # MMD warmup
        log_epoch = epoch % 30 == 0 or epoch == NUM_EPOCHS - 1
        totals = {"loss": 0.0, "recon": 0.0, "kl": 0.0, "mmd": 0.0}

        for batch_idx, (x_in, t_fact, y_fact) in enumerate(train_loader):
            x_in, t_fact, y_fact = (
                x_in.to(DEVICE), t_fact.to(DEVICE), y_fact.to(DEVICE)
            )
            optimizer.zero_grad()

            y_pred, mu, logvar, z = model(x_in, t_fact, mode='train')

            # Reconstruction: factual outcome only
            loss_recon = F.mse_loss(y_pred, y_fact)

            # KL is clamped at 1.0 to avoid blow-ups. Note that the KL gradient
            # is zero whenever the clamp is active.
            loss_kl = torch.clamp(
                -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()),
                max=1.0
            )

            # MMD: balance treated vs control latent representations
            loss_mmd = torch.tensor(0.0, device=DEVICE)
            if model.use_mmd:
                z_flat = z.reshape(-1, z.size(-1))
                treated = t_fact.reshape(-1) > t_median
                z0 = z_flat[~treated]
                z1 = z_flat[treated]

                if batch_idx == 0 and log_epoch:
                    print(f"  [MMD] z0={z0.size(0)}, z1={z1.size(0)}")

                loss_mmd = compute_mmd_stable(z0, z1) * mmd_scale

            loss = (10.0 * loss_recon
                    + KL_WEIGHT * kl_scale * loss_kl
                    + MMD_WEIGHT * loss_mmd)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            totals["loss"]  += loss.item()
            totals["recon"] += loss_recon.item()
            totals["kl"]    += loss_kl.item()
            totals["mmd"]   += loss_mmd.item()

        means = {key: value / num_batches for key, value in totals.items()}
        scheduler.step(means["loss"])

        if log_epoch:
            print(f"Epoch {epoch:03d} | loss={means['loss']:.4f} | "
                  f"recon={means['recon']:.4f} "
                  f"kl={means['kl']:.4f} "
                  f"mmd={means['mmd']:.6f}")

# ============================================================
# 5. EVALUATION
# ============================================================

def evaluate(model, test_loader):
    """Return sqrt-PEHE of `model` on `test_loader`.

    Each batch is (x, t_factual, t_counter, y0, y1). The model is queried twice
    with the same covariates: once under t_factual (T0) and once under
    t_counter (T1). The predicted ITE p1 - p0 is compared with y1 - y0 at every
    timestep, and the function returns the root mean squared error over all
    elements. This is independent of any global sequence-length setting.
    Inference runs on the device of the model's first parameter (CPU if the
    model has none). Diagnostics are printed for the first batch.

    Raises:
        ValueError: if `test_loader` yields no elements.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    pehe_sum = 0.0
    count    = 0

    with torch.no_grad():
        for batch_idx, (x_in, t_f, t_c, y0_gt, y1_gt) in enumerate(test_loader):
            x_in = x_in.to(device)
            t_f  = t_f.to(device)   # T0: control treatment
            t_c  = t_c.to(device)   # T1: treated treatment

            # Predict outcome under T0 and T1 with the same covariates.
            p0, _, _, _ = model(x_in, t_f, mode='eval')
            p1, _, _, _ = model(x_in, t_c, mode='eval')

            ite_pred = p1 - p0
            ite_true = (y1_gt - y0_gt).to(device)

            if batch_idx == 0:
                print(f"  p0 mean={p0.mean().item():.4f}  "
                      f"p1 mean={p1.mean().item():.4f}")
                print(f"  Pred ITE mean={ite_pred.mean().item():.4f}  "
                      f"True ITE mean={ite_true.mean().item():.4f}")

            pehe_sum += torch.sum((ite_pred - ite_true) ** 2).item()
            # Count the elements actually compared, not batch * global length.
            count    += ite_true.numel()

    if count == 0:
        raise ValueError("test_loader is empty; cannot compute PEHE")
    return math.sqrt(pehe_sum / count)

# ============================================================
# 6. ABLATION BENCHMARK
# ============================================================

def run_benchmark():
    """Run the MMD on/off ablation and print a test-PEHE table.

    The dataset and train/test split are built once. For each configuration,
    the torch RNG is re-seeded, a fresh DCMVAE is trained, and it is evaluated
    on the shared test split.

    Returns:
        List of (label, pehe) tuples in training order.
    """
    dm = IHDP_TimeSeries(CSV_FILE_PATH, BATCH_SIZE, SEQUENCE_LENGTH)
    train_loader, test_loader = dm.get_dataloaders()
    # Threshold for the control/treated split used by the MMD term.
    # NOTE(review): computed over all sequences, including the test split.
    t_median = float(np.median(dm.t_factual))

    # Full ablation: MMD on/off
    results = []
    for use_mmd in (True, False):
        label = "MMD=True " if use_mmd else "MMD=False"
        print(f"\n{'='*55}")
        print(f"Training: {label}")
        print(f"{'='*55}")

        # Re-seed so every configuration starts from the same initial weights
        # and the same random stream. Without this, the second arm's init
        # depends on how many draws the first arm consumed.
        torch.manual_seed(REPRODUCIBILITY_SEED)
        model = DCMVAE(use_mmd=use_mmd).to(DEVICE)
        train_model(model, train_loader, t_median)

        print(f"\nEvaluating {label}...")
        pehe = evaluate(model, test_loader)
        print(f"PEHE: {pehe:.4f}")
        results.append((label, pehe))

    print(f"\n{'='*55}")
    print("FINAL ABLATION RESULTS")
    print(f"{'='*55}")
    print(f"{'Config':<12} | {'Test PEHE'}")
    print("-" * 35)
    best_pehe = min(pehe for _, pehe in results)
    for label, pehe in sorted(results, key=lambda r: r[1]):
        marker = " ← best" if pehe == best_pehe else ""
        print(f"{label:<12} | {pehe:.4f}{marker}")

    return results

# ============================================================

# Running this cell executes the full MMD on/off ablation (run_benchmark).
if __name__ == "__main__":
    run_benchmark()


DATA DIAGNOSTICS
Num sequences:      54
Sequence length:    30
Feature dim:        5
Mean |ITE|:         2.4961
Std  ITE:           3.3482
% near-zero ITE:    12.3%
T0-T1 correlation:  0.9443  (target: < 1.0)

Training: MMD=True 
  [MMD] z0=418, z1=542
Epoch 000 | loss=10.1364 | recon=1.0492 kl=0.0055 mmd=0.000000
  [MMD] z0=439, z1=521
Epoch 030 | loss=9.6995 | recon=0.9412 kl=0.0403 mmd=0.000001
  [MMD] z0=453, z1=507
Epoch 060 | loss=9.8161 | recon=0.9927 kl=0.1363 mmd=0.000001
  [MMD] z0=501, z1=459
Epoch 090 | loss=10.2237 | recon=1.1241 kl=0.1346 mmd=0.000001
  [MMD] z0=482, z1=478
Epoch 120 | loss=9.9390 | recon=1.0272 kl=0.1626 mmd=0.000001
  [MMD] z0=509, z1=451
Epoch 149 | loss=9.8553 | recon=1.0153 kl=0.1617 mmd=0.000001

Evaluating MMD=True ...
  p0 mean=0.0009  p1 mean=-0.0010
  Pred ITE mean=-0.0019  True ITE mean=0.6843
PEHE: 3.0819

Training: MMD=False
Epoch 000 | loss=10.0247 | recon=1.0151 kl=0.0055 mmd=0.000000
Epoch 030 | loss=9.9966 | recon=1.0274 kl=0.0157 mmd=0.0

In [None]:
import os
import math
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# ============================================================
# 0. CONFIG & REPRODUCIBILITY
# ============================================================

# Reproducibility: seed Python, NumPy and PyTorch from a single constant.
REPRODUCIBILITY_SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(REPRODUCIBILITY_SEED)
del _seed_fn

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(REPRODUCIBILITY_SEED)
    # Prefer repeatable cuDNN kernels over autotuned ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ---- Data / model / optimisation hyperparameters ----
dim_x_features = 5           # placeholder; overwritten with the real feature count when data loads
SEQUENCE_LENGTH = 30         # timesteps per sequence
BATCH_SIZE = 32
HIDDEN_DIM = 128             # GRU hidden size per direction
LATENT_DIM = 64
NUM_EPOCHS = 150
LEARNING_RATE = 5e-4
KL_WEIGHT = 0.001
MMD_WEIGHT = 1.0
WINDOW_SIZE = 5              # rolling-mean window applied to SSH (zos)
TREATMENT_LAG = 9            # shift (in steps) for the lagged-T0 covariate

CSV_FILE_PATH = "arctic_s2s_multivar_2020_2024.csv"

# ============================================================
# 1. MMD UTILITIES
# ============================================================

def gaussian_rbf_matrix(x, y, sigma=1.0):
    """Gaussian RBF kernel matrix K[i, j] = exp(-||x_i - y_j||^2 / (2 * sigma^2)).

    Args:
        x: (n, d) tensor.
        y: (m, d) tensor.
        sigma: kernel bandwidth.

    Returns:
        (n, m) tensor of kernel values in [0, 1].
    """
    x_norm = (x ** 2).sum(1).unsqueeze(1)            # (n, 1)
    y_norm = (y ** 2).sum(1).unsqueeze(0)            # (1, m)
    # Expanded ||x - y||^2 can dip slightly below zero from floating-point
    # cancellation, which would yield kernel values > 1; clamp it.
    dists = (x_norm + y_norm - 2 * (x @ y.t())).clamp(min=0.0)
    return torch.exp(-dists / (2 * sigma**2 + 1e-12))


def compute_mmd_stable(x, y, sigmas=(0.5, 1.0, 2.0)):
    """Unbiased squared MMD between samples x and y, averaged over RBF bandwidths.

    Within-group kernel means exclude the diagonal (unbiased estimator), so the
    result can be slightly negative when the two samples are very similar.

    Args:
        x: (n, d) tensor, or None.
        y: (m, d) tensor, or None.
        sigmas: non-empty sequence of RBF bandwidths to average over.

    Returns:
        Scalar tensor. A zero tensor is returned when either input is None or
        has fewer than two rows (the unbiased estimator is undefined there).

    Raises:
        ValueError: if `sigmas` is empty.

    NOTE(review): with high-dimensional inputs whose pairwise distances are
    large relative to these bandwidths, every off-diagonal kernel value is ~0
    and the MMD collapses to ~0 (consistent with the ~1e-6 values logged during
    training). Consider scaling `sigmas` to the data, e.g. a median-distance
    heuristic.
    """
    if x is None or y is None or x.size(0) <= 1 or y.size(0) <= 1:
        ref = x if x is not None else y
        device = ref.device if ref is not None else DEVICE
        return torch.tensor(0.0, device=device)
    if len(sigmas) == 0:
        raise ValueError("sigmas must contain at least one bandwidth")

    n, m = x.size(0), y.size(0)
    mmd = 0.0
    for sigma in sigmas:
        K_xx = gaussian_rbf_matrix(x, x, sigma)
        K_yy = gaussian_rbf_matrix(y, y, sigma)
        K_xy = gaussian_rbf_matrix(x, y, sigma)
        sum_xx = (K_xx.sum() - torch.diag(K_xx).sum()) / (n * (n - 1))
        sum_yy = (K_yy.sum() - torch.diag(K_yy).sum()) / (m * (m - 1))
        sum_xy = K_xy.mean()
        mmd += sum_xx + sum_yy - 2.0 * sum_xy
    return mmd / len(sigmas)

# ============================================================
# 2. DATA MODULE
# ============================================================

class IHDP_TimeSeries:
    """Semi-synthetic sequential causal dataset built from the Arctic S2S CSV.

    Loads `csv_path` (or, if that file does not exist, a 1621x5 standard-normal
    placeholder frame with the same column names) and constructs:

    * covariates X: uoe, von, total_vel, T0 and T0 lagged by `treatment_lag`,
      standardized and cut into non-overlapping sequences of `sequence_length`;
    * control treatment T0: rolling-mean zos + 2*hidden + N(0, 0.1) noise, where
      `hidden` is a sine wave that is NOT included in X;
    * treated treatment T1: a velocity-dependent rescaling of T0;
    * potential outcomes: Y0 = sithick and Y1 = Y0 + delta, where delta depends
      on |hidden| and on (T1 - T0).

    Trailing rows that do not fill a whole sequence are dropped.
    """

    def __init__(self, csv_path, batch_size, sequence_length,
                 treatment_lag=TREATMENT_LAG):
        """
        Args:
            csv_path: path to the CSV with columns uoe, von, total_vel, zos, sithick.
            batch_size: batch size for the DataLoaders from get_dataloaders().
            sequence_length: timesteps per sequence.
            treatment_lag: non-negative shift (in rows) for the lagged-T0 covariate.
        """
        self.csv_path         = csv_path
        self.batch_size       = batch_size
        self.sequence_length  = sequence_length
        self.treatment_lag    = treatment_lag
        self._load_and_preprocess_data()

    @staticmethod
    def apply_moving_window(series, window_size):
        """Trailing rolling mean over `window_size` samples, returned as an (N, 1) column.

        The first window_size-1 entries average over the samples available so far
        (min_periods=1), so no NaNs are produced.
        """
        return pd.Series(series.flatten()).rolling(
            window=window_size, min_periods=1
        ).mean().values.reshape(-1, 1)

    def _compute_lag(self, T, lag):
        """Shift `T` forward by `lag` rows along axis 0, zero-filling the first `lag` rows.

        lag == 0 returns an unshifted copy. (The slice `T[:-0]` is empty, so the
        generic path would raise a broadcast error.) lag >= len(T) yields all
        zeros.

        Raises:
            ValueError: if lag is negative.
        """
        if lag < 0:
            raise ValueError(f"treatment lag must be non-negative, got {lag}")
        if lag == 0:
            return T.copy()
        Tlag = np.zeros_like(T)
        Tlag[lag:] = T[:-lag]
        return Tlag

    def _load_and_preprocess_data(self):
        """Load the CSV, synthesize treatments/outcomes, and populate the arrays.

        Sets self.xall (num_seq, sequence_length, n_features) and self.t_factual,
        self.t_counter, self.y_factual, self.y0_cf, self.y1_cf, each of shape
        (num_seq, sequence_length, 1). It also overwrites the module-level
        `dim_x_features` and prints data diagnostics.

        Raises:
            ValueError: if there are fewer rows than `sequence_length`.
        """
        if os.path.exists(self.csv_path):
            df = pd.read_csv(self.csv_path)
        else:
            # Placeholder data so the pipeline still runs without the CSV.
            df = pd.DataFrame(
                np.random.randn(1621, 5),
                columns=['uoe', 'von', 'total_vel', 'zos', 'sithick']
            )

        x_base = df[['uoe', 'von', 'total_vel']].values
        y_base = df[['sithick']].values
        ssh    = df['zos'].values.reshape(-1, 1)
        vel    = df['total_vel'].values.reshape(-1, 1)
        # Unobserved driver (15 sine periods over the record). It enters T0 and
        # the treatment effect but is never given to the model as a covariate.
        hidden = np.sin(np.linspace(0, 30 * np.pi, len(df))).reshape(-1, 1)

        # --- Control treatment T0: smoothed SSH + hidden driver + small noise ---
        T0_smooth = self.apply_moving_window(ssh, WINDOW_SIZE)
        T0_np     = T0_smooth + (2.0 * hidden) + np.random.normal(0, 0.1, ssh.shape)

        # Re-seeds the global NumPy RNG; any later np.random draw (here or in
        # other cells) restarts from the seed.
        np.random.seed(REPRODUCIBILITY_SEED)

        # --- Treated treatment T1: velocity-regime-dependent rescaling of T0 ---
        # `regime` is a DECREASING sigmoid of velocity (slope 5 around the mean):
        #   velocity well below mean -> regime ~ 1   -> T1 ~ 2.5  * T0
        #   velocity at the mean     -> regime = 0.5 -> T1 = 1.75 * T0
        #   velocity well above mean -> regime ~ 0   -> T1 ~ 1.0  * T0
        v0 = np.mean(vel)
        regime = 1 / (1 + np.exp(5.0 * (vel - v0)))
        T1_np = ((1.0 + 1.5 * regime) * T0_np).reshape(-1, 1)

        # Lagged copy of T0, used as an extra covariate.
        T0_lag_np = self._compute_lag(T0_np, self.treatment_lag)

        # Covariates: 3 base + T0 + T0_lag = 5 features
        X_RAW = np.concatenate([x_base, T0_np, T0_lag_np], axis=1)

        num_seq = len(df) // self.sequence_length
        if num_seq == 0:
            raise ValueError(
                f"need at least {self.sequence_length} rows to build one "
                f"sequence, got {len(df)}"
            )
        limit = num_seq * self.sequence_length

        # The model reads this module-level value to size its encoder input.
        global dim_x_features
        dim_x_features = X_RAW.shape[1]

        # NOTE(review): the scaler is fit on all sequences before the
        # train/test split, so test-set statistics leak into the scaling.
        scaler   = StandardScaler()
        X_scaled = scaler.fit_transform(X_RAW[:limit])

        def _to_seq(arr):
            # Trim to whole sequences and reshape to (num_seq, seq_len, features).
            return arr[:limit].reshape(num_seq, self.sequence_length, -1)

        self.xall      = _to_seq(X_scaled)
        self.t_factual = _to_seq(T0_np)
        self.t_counter = _to_seq(T1_np)
        self.y_factual = _to_seq(y_base)

        # --- Potential outcomes: the effect depends on |hidden| and T1 - T0 ---
        hidden_seq = _to_seq(hidden)
        delta      = -6.0 * np.abs(hidden_seq) * np.tanh(
            2.0 * (self.t_counter - self.t_factual)
        )
        self.y0_cf = self.y_factual
        self.y1_cf = self.y_factual + delta

        # Diagnostics
        ite = self.y1_cf - self.y0_cf
        t_corr = np.corrcoef(T0_np.flatten(), T1_np.flatten())[0, 1]
        print("=" * 55)
        print("DATA DIAGNOSTICS")
        print("=" * 55)
        print(f"Num sequences:      {self.xall.shape[0]}")
        print(f"Sequence length:    {self.xall.shape[1]}")
        print(f"Feature dim:        {self.xall.shape[2]}")
        print(f"Mean |ITE|:         {np.abs(ite).mean():.4f}")
        print(f"Std  ITE:           {ite.std():.4f}")
        print(f"% near-zero ITE:    {(np.abs(ite) < 0.01).mean()*100:.1f}%")
        print(f"T0-T1 correlation:  {t_corr:.4f}  (target: < 1.0)")
        print("=" * 55)

    def get_dataloaders(self):
        """Split sequences 80/20 (seeded) and return (train_loader, test_loader).

        The train loader is shuffled and yields only factual data:
        (x, t_factual, y_factual). The test loader is not shuffled and also
        yields the counterfactual treatment and both potential outcomes for
        PEHE evaluation: (x, t_factual, t_counter, y0, y1). Both loaders use
        `self.batch_size`.
        """
        indices = np.arange(len(self.xall))
        # Split indices once so every array is sliced with the same rows.
        tr_idx, te_idx = train_test_split(
            indices, test_size=0.2, random_state=REPRODUCIBILITY_SEED
        )

        def _dataset(idx, *arrays):
            return TensorDataset(*(torch.FloatTensor(a[idx]) for a in arrays))

        train_ds = _dataset(tr_idx, self.xall, self.t_factual, self.y_factual)
        test_ds = _dataset(te_idx, self.xall, self.t_factual, self.t_counter,
                           self.y0_cf, self.y1_cf)

        return (
            DataLoader(train_ds, self.batch_size, shuffle=True),
            DataLoader(test_ds,  self.batch_size, shuffle=False),
        )

# ============================================================
# 3. DCMVAE MODEL
# ============================================================

class DCMVAE(nn.Module):
    """Sequential VAE-style outcome model for continuous treatments.

    A bidirectional GRU encodes [covariates, treatment] at every timestep into a
    diagonal-Gaussian latent (mu, logvar). The outcome head predicts y at each
    step from the latent z concatenated with a learned embedding of the
    treatment. The treatment therefore enters twice: in the encoder input and
    in the outcome head.

    Args:
        use_mmd: flag stored on the module for the training loop, which uses it
            to toggle an MMD balancing term. forward() does not use it.
    """

    # Width of the learned treatment embedding fed to the outcome head.
    _T_EMB_DIM = 16

    def __init__(self, use_mmd=True):
        super().__init__()
        self.use_mmd = use_mmd

        # Encoder input width reads the module-level `dim_x_features` at
        # construction time. The encoder is bidirectional, so the code at step k
        # also summarizes later steps of the same sequence.
        encoder_in = dim_x_features + 1
        self.encoder_rnn = nn.GRU(encoder_in, HIDDEN_DIM,
                                  batch_first=True, bidirectional=True)

        encoder_out = 2 * HIDDEN_DIM
        self.fc_mu = nn.Linear(encoder_out, LATENT_DIM)
        self.fc_logvar = nn.Linear(encoder_out, LATENT_DIM)

        # Learned treatment embedding: Linear(1 -> 16) followed by ReLU on the
        # raw treatment value. This is not an identity map, and ReLU zeroes
        # negative pre-activations.
        self.t_proj = nn.Sequential(nn.Linear(1, self._T_EMB_DIM), nn.ReLU())

        head_in = LATENT_DIM + self._T_EMB_DIM
        self.outcome_head = nn.Sequential(
            nn.Linear(head_in, 128), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, X, t, mode='train'):
        """Encode (X, t) and predict the outcome at every timestep.

        Args:
            X: covariates, (batch, seq, dim_x_features).
            t: treatment, (batch, seq, 1).
            mode: 'train' samples z = mu + eps * exp(0.5 * logvar)
                (reparameterization). Any other value uses z = mu.

        Returns:
            Tuple (y_pred, mu, logvar, z). y_pred has shape (batch, seq, 1); the
            other three have shape (batch, seq, LATENT_DIM).
        """
        encoded, _ = self.encoder_rnn(torch.cat((X, t), dim=-1))
        mu, logvar = self.fc_mu(encoded), self.fc_logvar(encoded)

        if mode == 'train':
            z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        else:
            z = mu

        head_input = torch.cat((z, self.t_proj(t)), dim=-1)
        return self.outcome_head(head_input), mu, logvar, z

# ============================================================
# 4. TRAINING
# ============================================================

def train_model(model, train_loader, t_median):
    """Fit `model` on factual (x, t, y) batches.

    Per-batch objective:
        10 * MSE(y_pred, y) + KL_WEIGHT * kl_scale * KL + MMD_WEIGHT * mmd_scale * MMD
    kl_scale ramps linearly from 0 to 1 over the first 30 epochs, and mmd_scale
    over the first 50. The MMD term is used only when `model.use_mmd` is set. It
    compares latent codes of timesteps with treatment <= t_median (control)
    against those with treatment > t_median (treated).

    Every 30 epochs and on the last epoch, it prints the epoch-mean total loss
    and the epoch-mean of each component. The printed MMD includes mmd_scale.

    Args:
        model: module called as model(x, t, mode='train') that returns
            (y_pred, mu, logvar, z) and exposes a boolean `use_mmd`.
        train_loader: non-empty loader of (x, t_factual, y_factual) batches.
        t_median: scalar treatment threshold for the MMD control/treated split.

    Raises:
        ValueError: if `train_loader` yields no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; nothing to train on")

    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    # Halve the LR when the epoch-mean training loss stalls for 15 epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=15
    )

    model.train()

    for epoch in range(NUM_EPOCHS):
        kl_scale  = min(1.0, epoch / 30.0)   # KL warmup
        mmd_scale = min(1.0, epoch / 50.0)   # MMD warmup
        log_epoch = epoch % 30 == 0 or epoch == NUM_EPOCHS - 1
        totals = {"loss": 0.0, "recon": 0.0, "kl": 0.0, "mmd": 0.0}

        for batch_idx, (x_in, t_fact, y_fact) in enumerate(train_loader):
            x_in, t_fact, y_fact = (
                x_in.to(DEVICE), t_fact.to(DEVICE), y_fact.to(DEVICE)
            )
            optimizer.zero_grad()

            y_pred, mu, logvar, z = model(x_in, t_fact, mode='train')

            # Reconstruction: factual outcome only
            loss_recon = F.mse_loss(y_pred, y_fact)

            # KL is clamped at 1.0 to avoid blow-ups. Note that the KL gradient
            # is zero whenever the clamp is active.
            loss_kl = torch.clamp(
                -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()),
                max=1.0
            )

            # MMD: balance treated vs control latent representations
            loss_mmd = torch.tensor(0.0, device=DEVICE)
            if model.use_mmd:
                z_flat = z.reshape(-1, z.size(-1))
                treated = t_fact.reshape(-1) > t_median
                z0 = z_flat[~treated]
                z1 = z_flat[treated]

                if batch_idx == 0 and log_epoch:
                    print(f"  [MMD] z0={z0.size(0)}, z1={z1.size(0)}")

                loss_mmd = compute_mmd_stable(z0, z1) * mmd_scale

            loss = (10.0 * loss_recon
                    + KL_WEIGHT * kl_scale * loss_kl
                    + MMD_WEIGHT * loss_mmd)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            totals["loss"]  += loss.item()
            totals["recon"] += loss_recon.item()
            totals["kl"]    += loss_kl.item()
            totals["mmd"]   += loss_mmd.item()

        means = {key: value / num_batches for key, value in totals.items()}
        scheduler.step(means["loss"])

        if log_epoch:
            print(f"Epoch {epoch:03d} | loss={means['loss']:.4f} | "
                  f"recon={means['recon']:.4f} "
                  f"kl={means['kl']:.4f} "
                  f"mmd={means['mmd']:.6f}")

# ============================================================
# 5. EVALUATION
# ============================================================

def evaluate(model, test_loader):
    """Return sqrt-PEHE of `model` on `test_loader`.

    Each batch is (x, t_factual, t_counter, y0, y1). The model is queried twice
    with the same covariates: once under t_factual (T0) and once under
    t_counter (T1). The predicted ITE p1 - p0 is compared with y1 - y0 at every
    timestep, and the function returns the root mean squared error over all
    elements. This is independent of any global sequence-length setting.
    Inference runs on the device of the model's first parameter (CPU if the
    model has none). Diagnostics are printed for the first batch.

    Raises:
        ValueError: if `test_loader` yields no elements.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    pehe_sum = 0.0
    count    = 0

    with torch.no_grad():
        for batch_idx, (x_in, t_f, t_c, y0_gt, y1_gt) in enumerate(test_loader):
            x_in = x_in.to(device)
            t_f  = t_f.to(device)   # T0: control treatment
            t_c  = t_c.to(device)   # T1: treated treatment

            # Predict outcome under T0 and T1 with the same covariates.
            p0, _, _, _ = model(x_in, t_f, mode='eval')
            p1, _, _, _ = model(x_in, t_c, mode='eval')

            ite_pred = p1 - p0
            ite_true = (y1_gt - y0_gt).to(device)

            if batch_idx == 0:
                print(f"  p0 mean={p0.mean().item():.4f}  "
                      f"p1 mean={p1.mean().item():.4f}")
                print(f"  Pred ITE mean={ite_pred.mean().item():.4f}  "
                      f"True ITE mean={ite_true.mean().item():.4f}")

            pehe_sum += torch.sum((ite_pred - ite_true) ** 2).item()
            # Count the elements actually compared, not batch * global length.
            count    += ite_true.numel()

    if count == 0:
        raise ValueError("test_loader is empty; cannot compute PEHE")
    return math.sqrt(pehe_sum / count)

# ============================================================
# 6. ABLATION BENCHMARK
# ============================================================

def run_benchmark():
    """Run the MMD on/off ablation and print a test-PEHE table.

    The dataset and train/test split are built once. For each configuration,
    the torch RNG is re-seeded, a fresh DCMVAE is trained, and it is evaluated
    on the shared test split.

    Returns:
        List of (label, pehe) tuples in training order.
    """
    dm = IHDP_TimeSeries(CSV_FILE_PATH, BATCH_SIZE, SEQUENCE_LENGTH)
    train_loader, test_loader = dm.get_dataloaders()
    # Threshold for the control/treated split used by the MMD term.
    # NOTE(review): computed over all sequences, including the test split.
    t_median = float(np.median(dm.t_factual))

    # Full ablation: MMD on/off
    results = []
    for use_mmd in (True, False):
        label = "MMD=True " if use_mmd else "MMD=False"
        print(f"\n{'='*55}")
        print(f"Training: {label}")
        print(f"{'='*55}")

        # Re-seed so every configuration starts from the same initial weights
        # and the same random stream. Without this, the second arm's init
        # depends on how many draws the first arm consumed.
        torch.manual_seed(REPRODUCIBILITY_SEED)
        model = DCMVAE(use_mmd=use_mmd).to(DEVICE)
        train_model(model, train_loader, t_median)

        print(f"\nEvaluating {label}...")
        pehe = evaluate(model, test_loader)
        print(f"PEHE: {pehe:.4f}")
        results.append((label, pehe))

    print(f"\n{'='*55}")
    print("FINAL ABLATION RESULTS")
    print(f"{'='*55}")
    print(f"{'Config':<12} | {'Test PEHE'}")
    print("-" * 35)
    best_pehe = min(pehe for _, pehe in results)
    for label, pehe in sorted(results, key=lambda r: r[1]):
        marker = " ← best" if pehe == best_pehe else ""
        print(f"{label:<12} | {pehe:.4f}{marker}")

    return results

# ============================================================

# Running this cell executes the full MMD on/off ablation (run_benchmark).
if __name__ == "__main__":
    run_benchmark()


DATA DIAGNOSTICS
Num sequences:      54
Sequence length:    30
Feature dim:        5
Mean |ITE|:         2.4961
Std  ITE:           3.3482
% near-zero ITE:    12.3%
T0-T1 correlation:  0.9443  (target: < 1.0)

Training: MMD=True 
  [MMD] z0=418, z1=542
Epoch 000 | loss=10.1361 | recon=1.0491 kl=0.0055 mmd=0.000000
  [MMD] z0=439, z1=521
Epoch 030 | loss=9.6986 | recon=0.9412 kl=0.0380 mmd=0.000001
  [MMD] z0=453, z1=507
Epoch 060 | loss=9.8087 | recon=0.9922 kl=0.1193 mmd=0.000001
  [MMD] z0=501, z1=459
Epoch 090 | loss=10.2201 | recon=1.1230 kl=0.1224 mmd=0.000002
  [MMD] z0=482, z1=478
Epoch 120 | loss=9.9264 | recon=1.0252 kl=0.1480 mmd=0.000001
  [MMD] z0=509, z1=451
Epoch 149 | loss=9.8448 | recon=1.0133 kl=0.1438 mmd=0.000001

Evaluating MMD=True ...
  p0 mean=0.0049  p1 mean=0.0025
  Pred ITE mean=-0.0024  True ITE mean=0.6843
PEHE: 3.0795

Training: MMD=False
Epoch 000 | loss=10.0251 | recon=1.0152 kl=0.0054 mmd=0.000000
Epoch 030 | loss=9.9966 | recon=1.0277 kl=0.0146 mmd=0.00