In [24]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [25]:
import os
import math
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# ============================================================
# 0. CONFIG & REPRODUCIBILITY
# ============================================================

REPRODUCIBILITY_SEED = 42
random.seed(REPRODUCIBILITY_SEED)
np.random.seed(REPRODUCIBILITY_SEED)
torch.manual_seed(REPRODUCIBILITY_SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(REPRODUCIBILITY_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
dim_x_features  = 5        # updated dynamically from data
SEQUENCE_LENGTH  = 30
BATCH_SIZE       = 32
HIDDEN_DIM       = 128
LATENT_DIM       = 64
NUM_EPOCHS       = 150
LEARNING_RATE    = 5e-4
KL_WEIGHT        = 0.001
MMD_WEIGHT       = 1.0
WINDOW_SIZE      = 5
TREATMENT_LAG    = 3

CSV_FILE_PATH = "arctic_s2s_multivar_2020_2024.csv"

# ============================================================
# 1. MMD UTILITIES
# ============================================================

def gaussian_rbf_matrix(x, y, sigma=1.0):
    x_norm = (x ** 2).sum(1).unsqueeze(1)
    y_norm = (y ** 2).sum(1).unsqueeze(0)
    dists  = x_norm + y_norm - 2 * (x @ y.t())
    return torch.exp(-dists / (2 * sigma**2 + 1e-12))

def compute_mmd_stable(x, y):
    if x is None or y is None or x.size(0) <= 1 or y.size(0) <= 1:
        return torch.tensor(0.0, device=DEVICE)
    mmd = 0.0
    for sigma in [0.5, 1.0, 2.0]:
        K_xx = gaussian_rbf_matrix(x, x, sigma)
        K_yy = gaussian_rbf_matrix(y, y, sigma)
        K_xy = gaussian_rbf_matrix(x, y, sigma)
        n, m = x.size(0), y.size(0)
        sum_xx = (K_xx.sum() - torch.diag(K_xx).sum()) / (n * (n - 1))
        sum_yy = (K_yy.sum() - torch.diag(K_yy).sum()) / (m * (m - 1))
        sum_xy = K_xy.mean()
        mmd += sum_xx + sum_yy - 2.0 * sum_xy
    return mmd / 3.0

# ============================================================
# 2. DATA MODULE
# ============================================================

class IHDP_TimeSeries:

    def __init__(self, csv_path, batch_size, sequence_length,
                 treatment_lag=TREATMENT_LAG):
        self.csv_path        = csv_path
        self.batch_size      = batch_size
        self.sequence_length = sequence_length
        self.treatment_lag   = treatment_lag
        self._load_and_preprocess_data()

    @staticmethod
    def apply_moving_window(series, window_size):
        return pd.Series(series.flatten()).rolling(
            window=window_size, min_periods=1
        ).mean().values.reshape(-1, 1)

    def _compute_lag(self, T, lag):
        """Shift treatment by `lag` steps; fill leading entries with zero."""
        Tlag = np.zeros_like(T)
        Tlag[lag:] = T[:-lag]
        return Tlag

    def _load_and_preprocess_data(self):
        if os.path.exists(self.csv_path):
            df = pd.read_csv(self.csv_path)
        else:
            df = pd.DataFrame(
                np.random.randn(1621, 5),
                columns=['uoe', 'von', 'total_vel', 'zos', 'sithick']
            )

        x_base = df[['uoe', 'von', 'total_vel']].values
        y_base = df[['sithick']].values
        ssh    = df['zos'].values.reshape(-1, 1)
        vel    = df['total_vel'].values.reshape(-1, 1)
        hidden = np.sin(np.linspace(0, 30 * np.pi, len(df))).reshape(-1, 1)

        # Control treatment T0: SSH-based with seasonal signal
        T0_smooth = self.apply_moving_window(ssh, WINDOW_SIZE)
        T0_np     = T0_smooth + (2.0 * hidden) + np.random.normal(0, 0.1, ssh.shape)

        # Treated treatment T1: regime-dependent scaling using raw velocity
        np.random.seed(REPRODUCIBILITY_SEED)
        v0      = np.mean(vel)
        sigmoid = 1 / (1 + np.exp(-(-5.0) * (vel - v0)))
        T1_np   = ((1.0 + 1.5 * sigmoid) * T0_np).reshape(-1, 1)

        # Treatment lags
        T0_lag_np = self._compute_lag(T0_np, self.treatment_lag)
        T1_lag_np = self._compute_lag(T1_np, self.treatment_lag)

        # Covariates: 3 base + T0 + T0_lag = 5 features
        X_RAW = np.concatenate([x_base, T0_np, T0_lag_np], axis=1)

        num_seq = len(df) // self.sequence_length
        limit   = num_seq * self.sequence_length

        # Update dim_x_features dynamically
        global dim_x_features
        dim_x_features = X_RAW.shape[1]

        scaler   = StandardScaler()
        X_scaled = scaler.fit_transform(X_RAW[:limit])

        self.xall      = X_scaled.reshape(num_seq, self.sequence_length, dim_x_features)
        self.t_factual = T0_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t_counter = T1_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t0_lag    = T0_lag_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t1_lag    = T1_lag_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.y_factual = y_base[:limit].reshape(num_seq, self.sequence_length, 1)

        # Counterfactual outcomes
        hidden_seq = hidden[:limit].reshape(num_seq, self.sequence_length, 1)
        T0_seq     = T0_np[:limit].reshape(num_seq, self.sequence_length, 1)
        T1_seq     = T1_np[:limit].reshape(num_seq, self.sequence_length, 1)
        delta      = -6.0 * np.abs(hidden_seq) * np.tanh(2.0 * (T1_seq - T0_seq))
        self.y0_cf = self.y_factual
        self.y1_cf = self.y_factual + delta

        # Diagnostics
        ite    = self.y1_cf - self.y0_cf
        t_corr = np.corrcoef(T0_np.flatten(), T1_np.flatten())[0, 1]
        print("=" * 55)
        print("DATA DIAGNOSTICS")
        print("=" * 55)
        print(f"Num sequences:      {self.xall.shape[0]}")
        print(f"Sequence length:    {self.xall.shape[1]}")
        print(f"Feature dim:        {self.xall.shape[2]}")
        print(f"Mean |ITE|:         {np.abs(ite).mean():.4f}")
        print(f"Std  ITE:           {ite.std():.4f}")
        print(f"% near-zero ITE:    {(np.abs(ite) < 0.01).mean()*100:.1f}%")
        print(f"T0-T1 correlation:  {t_corr:.4f}")
        print("=" * 55)

    def get_dataloaders(self):
        # Split indices ONCE to guarantee alignment across all arrays
        indices        = np.arange(len(self.xall))
        tr_idx, te_idx = train_test_split(
            indices, test_size=0.2, random_state=REPRODUCIBILITY_SEED
        )

        # Train: factual only — no counterfactuals (prevents leakage)
        # Includes t0_lag for baseline models
        train_ds = TensorDataset(
            torch.FloatTensor(self.xall[tr_idx]),
            torch.FloatTensor(self.t_factual[tr_idx]),
            torch.FloatTensor(self.t_counter[tr_idx]),
            torch.FloatTensor(self.t0_lag[tr_idx]),
            torch.FloatTensor(self.t1_lag[tr_idx]),
            torch.FloatTensor(self.y_factual[tr_idx])
        )

        # Test: include y0_cf and y1_cf for PEHE evaluation only
        test_ds = TensorDataset(
            torch.FloatTensor(self.xall[te_idx]),
            torch.FloatTensor(self.t_factual[te_idx]),
            torch.FloatTensor(self.t_counter[te_idx]),
            torch.FloatTensor(self.t0_lag[te_idx]),
            torch.FloatTensor(self.t1_lag[te_idx]),
            torch.FloatTensor(self.y_factual[te_idx]),
            torch.FloatTensor(self.y0_cf[te_idx]),
            torch.FloatTensor(self.y1_cf[te_idx])
        )

        return (
            DataLoader(train_ds, BATCH_SIZE, shuffle=True),
            DataLoader(test_ds,  BATCH_SIZE, shuffle=False)
        )

# ============================================================
# 3. DCMVAE MODEL
# ============================================================

class DCMVAE(nn.Module):

    def __init__(self, use_mmd=True):
        super().__init__()
        self.use_mmd = use_mmd

        # Encoder receives covariates + treatment so latent space
        # learns treatment-dependent representations for MMD balancing
        self.encoder_rnn = nn.GRU(
            dim_x_features + 1, HIDDEN_DIM,
            batch_first=True, bidirectional=True
        )
        self.fc_mu     = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)
        self.fc_logvar = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)

        # Treatment projection: amplifies treatment signal (1 → 16 dims)
        self.t_proj = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU()
        )

        # Outcome head conditioned on latent z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(LATENT_DIM + 16, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, X, t, mode='train'):
        x_and_t = torch.cat([X, t], dim=-1)
        h, _    = self.encoder_rnn(x_and_t)
        mu      = self.fc_mu(h)
        logvar  = self.fc_logvar(h)
        std     = torch.exp(0.5 * logvar)
        z       = mu + torch.randn_like(mu) * std if mode == 'train' else mu
        t_emb   = self.t_proj(t)
        z_and_t = torch.cat([z, t_emb], dim=-1)
        y_pred  = self.outcome_head(z_and_t)
        return y_pred, mu, logvar, z

    def counterfactual_prediction(self, X, t0, t1):
        self.eval()
        with torch.no_grad():
            p0, _, _, _ = self.forward(X, t0, mode='eval')
            p1, _, _, _ = self.forward(X, t1, mode='eval')
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, X, t):
        self.eval()
        with torch.no_grad():
            y_hat, _, _, _ = self.forward(X, t, mode='eval')
        return y_hat.squeeze(-1)

# ============================================================
# 4. BASELINE MODELS
# ============================================================

class Baseline1(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder
        self.encoder_rnn = nn.GRU(4, HIDDEN_DIM, batch_first=True)

        # Decoder
        self.decoder_rnn = nn.GRU(1, HIDDEN_DIM, batch_first=True) # Inputs future treatment

        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())

        self.outcome_head = nn.Sequential(
            nn.Linear(HIDDEN_DIM + 16, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x, t, t_lag):
        # x: (batch, seq, 3), t: (batch, seq, 1), t_lag: (batch, seq, 1)
        _, h_n = self.encoder_rnn(torch.cat([x, t_lag], dim=2))

        # Decoder uses final hidden state (h_n) and future treatments (t)
        # Fix: If h_n is unexpectedly a tuple (e.g., from an LSTM output), take the first element.
        # Although encoder_rnn is a GRU and should return a tensor, the error indicates h_n is a tuple.
        if isinstance(h_n, tuple):
            h_n_for_decoder = h_n[0]
        else:
            h_n_for_decoder = h_n
        z_seq, _ = self.decoder_rnn(t, h_n_for_decoder)

        t_emb = self.t_proj(t)
        y_pred = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))
        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        # Encode history once
        _, h_n = self.encoder_rnn(torch.cat([x, t0_lag], dim=2))

        # Apply the same fix as in forward method for h_n
        if isinstance(h_n, tuple):
            h_n_for_decoder = h_n[0]
        else:
            h_n_for_decoder = h_n

        # Decode for treatment path 0
        z0, _ = self.decoder_rnn(t0, h_n_for_decoder)
        p0 = self.outcome_head(torch.cat([z0, self.t_proj(t0)], dim=-1))

        # Decode for treatment path 1
        z1, _ = self.decoder_rnn(t1, h_n_for_decoder)
        p1 = self.outcome_head(torch.cat([z1, self.t_proj(t1)], dim=-1))

        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)

class Baseline2(nn.Module):
    def __init__(self, n_post=24, nb_features=4, output_dim=1, n_hidden=HIDDEN_DIM):
        super().__init__()
        self.n_post = n_post

        # Encoder: LSTM
        self.encoder_rnn = nn.LSTM(nb_features, n_hidden, batch_first=True)

        # Decoder: GRU
        self.decoder_rnn = nn.GRU(n_hidden, n_hidden, batch_first=True)

        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())

        self.outcome_head = nn.Sequential(
            nn.Linear(n_hidden + 16, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        )

    def forward(self, x, t, t_lag):
        # 1. Encode History (x: 3 dims + t_lag: 1 dim = 4 features)
        history = torch.cat([x, t_lag], dim=2)
        _, (h_n, _) = self.encoder_rnn(history)   # LSTM returns (output, (h_n, c_n))
        context_vector = h_n[-1]                   # (batch, n_hidden)

        # 2. Repeat context vector across sequence length
        batch_size, seq_len, _ = t.size()
        repeat_vector = context_vector.unsqueeze(1).repeat(1, seq_len, 1)

        # 3. Decode
        z_seq, _ = self.decoder_rnn(repeat_vector)

        # 4. Predict
        t_emb  = self.t_proj(t)
        y_pred = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))

        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        # Encode history once using factual lag
        history = torch.cat([x, t0_lag], dim=2)
        _, (h_n, _) = self.encoder_rnn(history)    # LSTM returns (output, (h_n, c_n))
        context_vector = h_n[-1]

        # Path 0 (factual)
        rep0   = context_vector.unsqueeze(1).repeat(1, t0.size(1), 1)
        z0, _  = self.decoder_rnn(rep0)
        p0     = self.outcome_head(torch.cat([z0, self.t_proj(t0)], dim=-1))

        # Path 1 (counterfactual)
        rep1   = context_vector.unsqueeze(1).repeat(1, t1.size(1), 1)
        z1, _  = self.decoder_rnn(rep1)
        p1     = self.outcome_head(torch.cat([z1, self.t_proj(t1)], dim=-1))

        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)


class Baseline3(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn    = nn.LSTM(4, HIDDEN_DIM, batch_first=True)
        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())
        self.outcome_head = nn.Sequential(
            nn.Linear(HIDDEN_DIM + 16, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x, t, t_lag):
        z_seq, _ = self.rnn(torch.cat([x, t_lag], dim=2))  # LSTM: _ = (h_n, c_n)
        t_emb    = self.t_proj(t)
        y_pred   = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))
        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        z, _ = self.rnn(torch.cat([x, t0_lag], dim=2))     # LSTM: _ = (h_n, c_n)
        p0   = self.outcome_head(torch.cat([z, self.t_proj(t0)], dim=-1))
        p1   = self.outcome_head(torch.cat([z, self.t_proj(t1)], dim=-1))
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)

# ============================================================
# 5. TRAINING
# ============================================================

def train_dcmvae(model, train_loader, t_median):
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=15
    )
    model.train()

    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0.0
        kl_scale   = min(1.0, epoch / 30.0)
        mmd_scale  = min(1.0, epoch / 50.0)

        for batch_idx, (x_in, t0, t1, t0_lag, t1_lag, y_fact) in enumerate(train_loader):
            x_in, t0, y_fact = x_in.to(DEVICE), t0.to(DEVICE), y_fact.to(DEVICE)
            optimizer.zero_grad()

            y_pred, mu, logvar, z = model(x_in, t0, mode='train')

            loss_recon = F.mse_loss(y_pred, y_fact)
            loss_kl    = torch.clamp(
                -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()),
                max=1.0
            )

            loss_mmd = torch.tensor(0.0, device=DEVICE)
            if model.use_mmd:
                z_flat = z.view(-1, LATENT_DIM)
                t_flat = (t0.view(-1) > t_median).float()
                z0_mmd = z_flat[t_flat == 0]
                z1_mmd = z_flat[t_flat == 1]
                if batch_idx == 0 and (epoch % 30 == 0 or epoch == NUM_EPOCHS - 1):
                    print(f"  [MMD] z0={z0_mmd.size(0)}, z1={z1_mmd.size(0)}")
                loss_mmd = compute_mmd_stable(z0_mmd, z1_mmd) * mmd_scale

            loss = (10.0 * loss_recon
                    + KL_WEIGHT * kl_scale * loss_kl
                    + MMD_WEIGHT * loss_mmd)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            epoch_loss += loss.item()

        avg = epoch_loss / len(train_loader)
        scheduler.step(avg)
        if epoch % 30 == 0 or epoch == NUM_EPOCHS - 1:
            print(f"  Epoch {epoch:03d} | loss={avg:.4f} | "
                  f"recon={loss_recon.item():.4f} "
                  f"kl={loss_kl.item():.4f} "
                  f"mmd={loss_mmd.item():.6f}")


def train_baseline(name, model, train_loader):
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    model.train()
    t_median = 0.0

    for epoch in tqdm(range(100), desc=name):
        for x_in, t0, t1, t0_lag, t1_lag, y_fact in train_loader:
            x_in   = x_in.to(DEVICE)
            t0     = t0.to(DEVICE)
            t0_lag = t0_lag.to(DEVICE)
            y_fact = y_fact.to(DEVICE)
            optimizer.zero_grad()

            y_p, z_s = model(x_in[:, :, :3], t0, t0_lag)
            loss_recon = F.mse_loss(y_p, y_fact)
            loss_mmd   = torch.tensor(0.0, device=DEVICE)

            if name != 'TS-TARNet':
                zf = z_s.reshape(-1, HIDDEN_DIM)
                tf = (t0.view(-1) > t_median).float()
                z0, z1 = zf[tf == 0], zf[tf == 1]
                if z0.size(0) > 1 and z1.size(0) > 1:
                    loss_mmd = compute_mmd_stable(z0, z1) * MMD_WEIGHT

            loss = loss_recon + loss_mmd
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

# ============================================================
# 6. EVALUATION
# ============================================================

def calculate_factual_rmse(name, model, loader):
    model.eval()
    se, count = 0, 0
    with torch.no_grad():
        for x, t0, t1, t0l, t1l, yf, y0c, y1c in loader:
            yf = yf.to(DEVICE).squeeze(-1)
            if name == 'DCMVAE':
                yh = model.factual_prediction(x.to(DEVICE), t0.to(DEVICE))
            else:
                yh = model.factual_prediction(
                    x[:, :, :3].to(DEVICE), t0.to(DEVICE), t0l.to(DEVICE)
                )
            se    += torch.sum((yh - yf) ** 2).item()
            count += yf.numel()
    return math.sqrt(se / count)


def calculate_pehe_ate(name, model, loader):
    model.eval()
    ite_se, ate_err, count = 0, 0, 0
    with torch.no_grad():
        for x, t0, t1, t0l, t1l, yf, y0c, y1c in loader:
            y0t = y0c.to(DEVICE).squeeze(-1)
            y1t = y1c.to(DEVICE).squeeze(-1)

            if name == 'DCMVAE':
                y0h, y1h = model.counterfactual_prediction(
                    x.to(DEVICE), t0.to(DEVICE), t1.to(DEVICE)
                )
            else:
                y0h, y1h = model.counterfactual_prediction(
                    x[:, :, :3].to(DEVICE),
                    t0.to(DEVICE), t1.to(DEVICE),
                    t0l.to(DEVICE), t1l.to(DEVICE)
                )

            ite_h    = y1h - y0h
            ite_t    = y1t - y0t
            ite_se  += torch.sum((ite_h - ite_t) ** 2).item()
            ate_err += torch.sum(ite_h - ite_t).item()
            count   += y0t.numel()

    return math.sqrt(ite_se / count), abs(ate_err / count)

# ============================================================
# 7. BENCHMARK
# ============================================================

def run_benchmark():
    dm = IHDP_TimeSeries(CSV_FILE_PATH, BATCH_SIZE, SEQUENCE_LENGTH)
    train_loader, test_loader = dm.get_dataloaders()
    t_median = float(np.median(dm.t_factual))

    model_list = {
        'DCMVAE'   : DCMVAE(use_mmd=True).to(DEVICE),
        'Baseline1'    : Baseline1().to(DEVICE),
        'Baseline2'   : Baseline2().to(DEVICE),
        'Baseline3': Baseline3().to(DEVICE),
    }

    # Train all models
    for name, model in model_list.items():
        print(f"\n{'='*55}")
        print(f"Training: {name}")
        print(f"{'='*55}")
        if name == 'DCMVAE':
            train_dcmvae(model, train_loader, t_median)
        else:
            train_baseline(name, model, train_loader)

    # Evaluate all models
    print(f"\n{'='*70}")
    print(f"{'Model':<15} | {'Test RMSE':<12} | {'Test PEHE':<12} ")
    print("-" * 70)

    for name, model in model_list.items():
        rmse        = calculate_factual_rmse(name, model, test_loader)
        pehe, ate   = calculate_pehe_ate(name, model, test_loader)
        marker      = " ←" if name == 'DCMVAE' else ""
        print(f"{name:<15} | {rmse:.4f}       | {pehe:.4f}       ")

# ============================================================

if __name__ == '__main__':
    run_benchmark()

DATA DIAGNOSTICS
Num sequences:      54
Sequence length:    30
Feature dim:        5
Mean |ITE|:         2.4961
Std  ITE:           3.3482
% near-zero ITE:    12.3%
T0-T1 correlation:  0.9443

Training: DCMVAE
  [MMD] z0=436, z1=524
  Epoch 000 | loss=9.9177 | recon=0.9821 kl=0.0058 mmd=0.000000
  [MMD] z0=472, z1=488
  Epoch 030 | loss=9.9863 | recon=1.0345 kl=0.0348 mmd=0.000001
  [MMD] z0=447, z1=513
  Epoch 060 | loss=9.5032 | recon=0.8884 kl=0.0772 mmd=0.000001
  [MMD] z0=528, z1=432
  Epoch 090 | loss=9.9871 | recon=1.0476 kl=0.1439 mmd=0.000002
  [MMD] z0=494, z1=466
  Epoch 120 | loss=9.8331 | recon=1.0022 kl=0.1120 mmd=0.000001
  [MMD] z0=498, z1=462
  Epoch 149 | loss=9.8227 | recon=0.9973 kl=0.1202 mmd=0.000002

Training: Baseline1


Baseline1: 100%|██████████| 100/100 [00:11<00:00,  8.88it/s]



Training: Baseline2


Baseline2: 100%|██████████| 100/100 [00:09<00:00, 10.00it/s]



Training: Baseline3


Baseline3: 100%|██████████| 100/100 [00:07<00:00, 13.58it/s]


Model           | Test RMSE    | Test PEHE    
----------------------------------------------------------------------
DCMVAE          | 0.9883       | 3.0804       
Baseline1       | 0.9814       | 3.1353       
Baseline2       | 0.9933       | 3.2119       
Baseline3       | 0.9994       | 3.1660       





In [26]:
import os
import math
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# ============================================================
# 0. CONFIG & REPRODUCIBILITY
# ============================================================

REPRODUCIBILITY_SEED = 42
random.seed(REPRODUCIBILITY_SEED)
np.random.seed(REPRODUCIBILITY_SEED)
torch.manual_seed(REPRODUCIBILITY_SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(REPRODUCIBILITY_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
dim_x_features  = 5        # updated dynamically from data
SEQUENCE_LENGTH  = 30
BATCH_SIZE       = 32
HIDDEN_DIM       = 128
LATENT_DIM       = 64
NUM_EPOCHS       = 150
LEARNING_RATE    = 5e-4
KL_WEIGHT        = 0.001
MMD_WEIGHT       = 1.0
WINDOW_SIZE      = 5
TREATMENT_LAG    = 6

CSV_FILE_PATH = "arctic_s2s_multivar_2020_2024.csv"

# ============================================================
# 1. MMD UTILITIES
# ============================================================

def gaussian_rbf_matrix(x, y, sigma=1.0):
    x_norm = (x ** 2).sum(1).unsqueeze(1)
    y_norm = (y ** 2).sum(1).unsqueeze(0)
    dists  = x_norm + y_norm - 2 * (x @ y.t())
    return torch.exp(-dists / (2 * sigma**2 + 1e-12))

def compute_mmd_stable(x, y):
    if x is None or y is None or x.size(0) <= 1 or y.size(0) <= 1:
        return torch.tensor(0.0, device=DEVICE)
    mmd = 0.0
    for sigma in [0.5, 1.0, 2.0]:
        K_xx = gaussian_rbf_matrix(x, x, sigma)
        K_yy = gaussian_rbf_matrix(y, y, sigma)
        K_xy = gaussian_rbf_matrix(x, y, sigma)
        n, m = x.size(0), y.size(0)
        sum_xx = (K_xx.sum() - torch.diag(K_xx).sum()) / (n * (n - 1))
        sum_yy = (K_yy.sum() - torch.diag(K_yy).sum()) / (m * (m - 1))
        sum_xy = K_xy.mean()
        mmd += sum_xx + sum_yy - 2.0 * sum_xy
    return mmd / 3.0

# ============================================================
# 2. DATA MODULE
# ============================================================

class IHDP_TimeSeries:

    def __init__(self, csv_path, batch_size, sequence_length,
                 treatment_lag=TREATMENT_LAG):
        self.csv_path        = csv_path
        self.batch_size      = batch_size
        self.sequence_length = sequence_length
        self.treatment_lag   = treatment_lag
        self._load_and_preprocess_data()

    @staticmethod
    def apply_moving_window(series, window_size):
        return pd.Series(series.flatten()).rolling(
            window=window_size, min_periods=1
        ).mean().values.reshape(-1, 1)

    def _compute_lag(self, T, lag):
        """Shift treatment by `lag` steps; fill leading entries with zero."""
        Tlag = np.zeros_like(T)
        Tlag[lag:] = T[:-lag]
        return Tlag

    def _load_and_preprocess_data(self):
        if os.path.exists(self.csv_path):
            df = pd.read_csv(self.csv_path)
        else:
            df = pd.DataFrame(
                np.random.randn(1621, 5),
                columns=['uoe', 'von', 'total_vel', 'zos', 'sithick']
            )

        x_base = df[['uoe', 'von', 'total_vel']].values
        y_base = df[['sithick']].values
        ssh    = df['zos'].values.reshape(-1, 1)
        vel    = df['total_vel'].values.reshape(-1, 1)
        hidden = np.sin(np.linspace(0, 30 * np.pi, len(df))).reshape(-1, 1)

        # Control treatment T0: SSH-based with seasonal signal
        T0_smooth = self.apply_moving_window(ssh, WINDOW_SIZE)
        T0_np     = T0_smooth + (2.0 * hidden) + np.random.normal(0, 0.1, ssh.shape)

        # Treated treatment T1: regime-dependent scaling using raw velocity
        np.random.seed(REPRODUCIBILITY_SEED)
        v0      = np.mean(vel)
        sigmoid = 1 / (1 + np.exp(-(-5.0) * (vel - v0)))
        T1_np   = ((1.0 + 1.5 * sigmoid) * T0_np).reshape(-1, 1)

        # Treatment lags
        T0_lag_np = self._compute_lag(T0_np, self.treatment_lag)
        T1_lag_np = self._compute_lag(T1_np, self.treatment_lag)

        # Covariates: 3 base + T0 + T0_lag = 5 features
        X_RAW = np.concatenate([x_base, T0_np, T0_lag_np], axis=1)

        num_seq = len(df) // self.sequence_length
        limit   = num_seq * self.sequence_length

        # Update dim_x_features dynamically
        global dim_x_features
        dim_x_features = X_RAW.shape[1]

        scaler   = StandardScaler()
        X_scaled = scaler.fit_transform(X_RAW[:limit])

        self.xall      = X_scaled.reshape(num_seq, self.sequence_length, dim_x_features)
        self.t_factual = T0_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t_counter = T1_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t0_lag    = T0_lag_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t1_lag    = T1_lag_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.y_factual = y_base[:limit].reshape(num_seq, self.sequence_length, 1)

        # Counterfactual outcomes
        hidden_seq = hidden[:limit].reshape(num_seq, self.sequence_length, 1)
        T0_seq     = T0_np[:limit].reshape(num_seq, self.sequence_length, 1)
        T1_seq     = T1_np[:limit].reshape(num_seq, self.sequence_length, 1)
        delta      = -6.0 * np.abs(hidden_seq) * np.tanh(2.0 * (T1_seq - T0_seq))
        self.y0_cf = self.y_factual
        self.y1_cf = self.y_factual + delta

        # Diagnostics
        ite    = self.y1_cf - self.y0_cf
        t_corr = np.corrcoef(T0_np.flatten(), T1_np.flatten())[0, 1]
        print("=" * 55)
        print("DATA DIAGNOSTICS")
        print("=" * 55)
        print(f"Num sequences:      {self.xall.shape[0]}")
        print(f"Sequence length:    {self.xall.shape[1]}")
        print(f"Feature dim:        {self.xall.shape[2]}")
        print(f"Mean |ITE|:         {np.abs(ite).mean():.4f}")
        print(f"Std  ITE:           {ite.std():.4f}")
        print(f"% near-zero ITE:    {(np.abs(ite) < 0.01).mean()*100:.1f}%")
        print(f"T0-T1 correlation:  {t_corr:.4f}")
        print("=" * 55)

    def get_dataloaders(self):
        # Split indices ONCE to guarantee alignment across all arrays
        indices        = np.arange(len(self.xall))
        tr_idx, te_idx = train_test_split(
            indices, test_size=0.2, random_state=REPRODUCIBILITY_SEED
        )

        # Train: factual only — no counterfactuals (prevents leakage)
        # Includes t0_lag for baseline models
        train_ds = TensorDataset(
            torch.FloatTensor(self.xall[tr_idx]),
            torch.FloatTensor(self.t_factual[tr_idx]),
            torch.FloatTensor(self.t_counter[tr_idx]),
            torch.FloatTensor(self.t0_lag[tr_idx]),
            torch.FloatTensor(self.t1_lag[tr_idx]),
            torch.FloatTensor(self.y_factual[tr_idx])
        )

        # Test: include y0_cf and y1_cf for PEHE evaluation only
        test_ds = TensorDataset(
            torch.FloatTensor(self.xall[te_idx]),
            torch.FloatTensor(self.t_factual[te_idx]),
            torch.FloatTensor(self.t_counter[te_idx]),
            torch.FloatTensor(self.t0_lag[te_idx]),
            torch.FloatTensor(self.t1_lag[te_idx]),
            torch.FloatTensor(self.y_factual[te_idx]),
            torch.FloatTensor(self.y0_cf[te_idx]),
            torch.FloatTensor(self.y1_cf[te_idx])
        )

        return (
            DataLoader(train_ds, BATCH_SIZE, shuffle=True),
            DataLoader(test_ds,  BATCH_SIZE, shuffle=False)
        )

# ============================================================
# 3. DCMVAE MODEL
# ============================================================

class DCMVAE(nn.Module):

    def __init__(self, use_mmd=True):
        super().__init__()
        self.use_mmd = use_mmd

        # Encoder receives covariates + treatment so latent space
        # learns treatment-dependent representations for MMD balancing
        self.encoder_rnn = nn.GRU(
            dim_x_features + 1, HIDDEN_DIM,
            batch_first=True, bidirectional=True
        )
        self.fc_mu     = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)
        self.fc_logvar = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)

        # Treatment projection: amplifies treatment signal (1 → 16 dims)
        self.t_proj = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU()
        )

        # Outcome head conditioned on latent z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(LATENT_DIM + 16, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, X, t, mode='train'):
        x_and_t = torch.cat([X, t], dim=-1)
        h, _    = self.encoder_rnn(x_and_t)
        mu      = self.fc_mu(h)
        logvar  = self.fc_logvar(h)
        std     = torch.exp(0.5 * logvar)
        z       = mu + torch.randn_like(mu) * std if mode == 'train' else mu
        t_emb   = self.t_proj(t)
        z_and_t = torch.cat([z, t_emb], dim=-1)
        y_pred  = self.outcome_head(z_and_t)
        return y_pred, mu, logvar, z

    def counterfactual_prediction(self, X, t0, t1):
        self.eval()
        with torch.no_grad():
            p0, _, _, _ = self.forward(X, t0, mode='eval')
            p1, _, _, _ = self.forward(X, t1, mode='eval')
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, X, t):
        self.eval()
        with torch.no_grad():
            y_hat, _, _, _ = self.forward(X, t, mode='eval')
        return y_hat.squeeze(-1)

# ============================================================
# 4. BASELINE MODELS
# ============================================================

class Baseline1(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder
        self.encoder_rnn = nn.GRU(4, HIDDEN_DIM, batch_first=True)

        # Decoder
        self.decoder_rnn = nn.GRU(1, HIDDEN_DIM, batch_first=True) # Inputs future treatment

        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())

        self.outcome_head = nn.Sequential(
            nn.Linear(HIDDEN_DIM + 16, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x, t, t_lag):
        # x: (batch, seq, 3), t: (batch, seq, 1), t_lag: (batch, seq, 1)
        _, h_n = self.encoder_rnn(torch.cat([x, t_lag], dim=2))

        # Decoder uses final hidden state (h_n) and future treatments (t)
        # Fix: If h_n is unexpectedly a tuple (e.g., from an LSTM output), take the first element.
        # Although encoder_rnn is a GRU and should return a tensor, the error indicates h_n is a tuple.
        if isinstance(h_n, tuple):
            h_n_for_decoder = h_n[0]
        else:
            h_n_for_decoder = h_n
        z_seq, _ = self.decoder_rnn(t, h_n_for_decoder)

        t_emb = self.t_proj(t)
        y_pred = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))
        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        # Encode history once
        _, h_n = self.encoder_rnn(torch.cat([x, t0_lag], dim=2))

        # Apply the same fix as in forward method for h_n
        if isinstance(h_n, tuple):
            h_n_for_decoder = h_n[0]
        else:
            h_n_for_decoder = h_n

        # Decode for treatment path 0
        z0, _ = self.decoder_rnn(t0, h_n_for_decoder)
        p0 = self.outcome_head(torch.cat([z0, self.t_proj(t0)], dim=-1))

        # Decode for treatment path 1
        z1, _ = self.decoder_rnn(t1, h_n_for_decoder)
        p1 = self.outcome_head(torch.cat([z1, self.t_proj(t1)], dim=-1))

        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)

class Baseline2(nn.Module):
    def __init__(self, n_post=24, nb_features=4, output_dim=1, n_hidden=HIDDEN_DIM):
        super().__init__()
        self.n_post = n_post

        # Encoder: LSTM
        self.encoder_rnn = nn.LSTM(nb_features, n_hidden, batch_first=True)

        # Decoder: GRU
        self.decoder_rnn = nn.GRU(n_hidden, n_hidden, batch_first=True)

        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())

        self.outcome_head = nn.Sequential(
            nn.Linear(n_hidden + 16, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        )

    def forward(self, x, t, t_lag):
        # 1. Encode History (x: 3 dims + t_lag: 1 dim = 4 features)
        history = torch.cat([x, t_lag], dim=2)
        _, (h_n, _) = self.encoder_rnn(history)   # LSTM returns (output, (h_n, c_n))
        context_vector = h_n[-1]                   # (batch, n_hidden)

        # 2. Repeat context vector across sequence length
        batch_size, seq_len, _ = t.size()
        repeat_vector = context_vector.unsqueeze(1).repeat(1, seq_len, 1)

        # 3. Decode
        z_seq, _ = self.decoder_rnn(repeat_vector)

        # 4. Predict
        t_emb  = self.t_proj(t)
        y_pred = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))

        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        # Encode history once using factual lag
        history = torch.cat([x, t0_lag], dim=2)
        _, (h_n, _) = self.encoder_rnn(history)    # LSTM returns (output, (h_n, c_n))
        context_vector = h_n[-1]

        # Path 0 (factual)
        rep0   = context_vector.unsqueeze(1).repeat(1, t0.size(1), 1)
        z0, _  = self.decoder_rnn(rep0)
        p0     = self.outcome_head(torch.cat([z0, self.t_proj(t0)], dim=-1))

        # Path 1 (counterfactual)
        rep1   = context_vector.unsqueeze(1).repeat(1, t1.size(1), 1)
        z1, _  = self.decoder_rnn(rep1)
        p1     = self.outcome_head(torch.cat([z1, self.t_proj(t1)], dim=-1))

        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)


class Baseline3(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn    = nn.LSTM(4, HIDDEN_DIM, batch_first=True)
        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())
        self.outcome_head = nn.Sequential(
            nn.Linear(HIDDEN_DIM + 16, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x, t, t_lag):
        z_seq, _ = self.rnn(torch.cat([x, t_lag], dim=2))  # LSTM: _ = (h_n, c_n)
        t_emb    = self.t_proj(t)
        y_pred   = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))
        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        z, _ = self.rnn(torch.cat([x, t0_lag], dim=2))     # LSTM: _ = (h_n, c_n)
        p0   = self.outcome_head(torch.cat([z, self.t_proj(t0)], dim=-1))
        p1   = self.outcome_head(torch.cat([z, self.t_proj(t1)], dim=-1))
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)

# ============================================================
# 5. TRAINING
# ============================================================

def train_dcmvae(model, train_loader, t_median):
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=15
    )
    model.train()

    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0.0
        kl_scale   = min(1.0, epoch / 30.0)
        mmd_scale  = min(1.0, epoch / 50.0)

        for batch_idx, (x_in, t0, t1, t0_lag, t1_lag, y_fact) in enumerate(train_loader):
            x_in, t0, y_fact = x_in.to(DEVICE), t0.to(DEVICE), y_fact.to(DEVICE)
            optimizer.zero_grad()

            y_pred, mu, logvar, z = model(x_in, t0, mode='train')

            loss_recon = F.mse_loss(y_pred, y_fact)
            loss_kl    = torch.clamp(
                -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()),
                max=1.0
            )

            loss_mmd = torch.tensor(0.0, device=DEVICE)
            if model.use_mmd:
                z_flat = z.view(-1, LATENT_DIM)
                t_flat = (t0.view(-1) > t_median).float()
                z0_mmd = z_flat[t_flat == 0]
                z1_mmd = z_flat[t_flat == 1]
                if batch_idx == 0 and (epoch % 30 == 0 or epoch == NUM_EPOCHS - 1):
                    print(f"  [MMD] z0={z0_mmd.size(0)}, z1={z1_mmd.size(0)}")
                loss_mmd = compute_mmd_stable(z0_mmd, z1_mmd) * mmd_scale

            loss = (10.0 * loss_recon
                    + KL_WEIGHT * kl_scale * loss_kl
                    + MMD_WEIGHT * loss_mmd)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            epoch_loss += loss.item()

        avg = epoch_loss / len(train_loader)
        scheduler.step(avg)
        if epoch % 30 == 0 or epoch == NUM_EPOCHS - 1:
            print(f"  Epoch {epoch:03d} | loss={avg:.4f} | "
                  f"recon={loss_recon.item():.4f} "
                  f"kl={loss_kl.item():.4f} "
                  f"mmd={loss_mmd.item():.6f}")


def train_baseline(name, model, train_loader):
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    model.train()
    t_median = 0.0

    for epoch in tqdm(range(100), desc=name):
        for x_in, t0, t1, t0_lag, t1_lag, y_fact in train_loader:
            x_in   = x_in.to(DEVICE)
            t0     = t0.to(DEVICE)
            t0_lag = t0_lag.to(DEVICE)
            y_fact = y_fact.to(DEVICE)
            optimizer.zero_grad()

            y_p, z_s = model(x_in[:, :, :3], t0, t0_lag)
            loss_recon = F.mse_loss(y_p, y_fact)
            loss_mmd   = torch.tensor(0.0, device=DEVICE)

            if name != 'TS-TARNet':
                zf = z_s.reshape(-1, HIDDEN_DIM)
                tf = (t0.view(-1) > t_median).float()
                z0, z1 = zf[tf == 0], zf[tf == 1]
                if z0.size(0) > 1 and z1.size(0) > 1:
                    loss_mmd = compute_mmd_stable(z0, z1) * MMD_WEIGHT

            loss = loss_recon + loss_mmd
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

# ============================================================
# 6. EVALUATION
# ============================================================

def calculate_factual_rmse(name, model, loader):
    model.eval()
    se, count = 0, 0
    with torch.no_grad():
        for x, t0, t1, t0l, t1l, yf, y0c, y1c in loader:
            yf = yf.to(DEVICE).squeeze(-1)
            if name == 'DCMVAE':
                yh = model.factual_prediction(x.to(DEVICE), t0.to(DEVICE))
            else:
                yh = model.factual_prediction(
                    x[:, :, :3].to(DEVICE), t0.to(DEVICE), t0l.to(DEVICE)
                )
            se    += torch.sum((yh - yf) ** 2).item()
            count += yf.numel()
    return math.sqrt(se / count)


def calculate_pehe_ate(name, model, loader):
    model.eval()
    ite_se, ate_err, count = 0, 0, 0
    with torch.no_grad():
        for x, t0, t1, t0l, t1l, yf, y0c, y1c in loader:
            y0t = y0c.to(DEVICE).squeeze(-1)
            y1t = y1c.to(DEVICE).squeeze(-1)

            if name == 'DCMVAE':
                y0h, y1h = model.counterfactual_prediction(
                    x.to(DEVICE), t0.to(DEVICE), t1.to(DEVICE)
                )
            else:
                y0h, y1h = model.counterfactual_prediction(
                    x[:, :, :3].to(DEVICE),
                    t0.to(DEVICE), t1.to(DEVICE),
                    t0l.to(DEVICE), t1l.to(DEVICE)
                )

            ite_h    = y1h - y0h
            ite_t    = y1t - y0t
            ite_se  += torch.sum((ite_h - ite_t) ** 2).item()
            ate_err += torch.sum(ite_h - ite_t).item()
            count   += y0t.numel()

    return math.sqrt(ite_se / count), abs(ate_err / count)

# ============================================================
# 7. BENCHMARK
# ============================================================

def run_benchmark():
    dm = IHDP_TimeSeries(CSV_FILE_PATH, BATCH_SIZE, SEQUENCE_LENGTH)
    train_loader, test_loader = dm.get_dataloaders()
    t_median = float(np.median(dm.t_factual))

    model_list = {
        'DCMVAE'   : DCMVAE(use_mmd=True).to(DEVICE),
        'Baseline1'    : Baseline1().to(DEVICE),
        'Baseline2'   : Baseline2().to(DEVICE),
        'Baseline3': Baseline3().to(DEVICE),
    }

    # Train all models
    for name, model in model_list.items():
        print(f"\n{'='*55}")
        print(f"Training: {name}")
        print(f"{'='*55}")
        if name == 'DCMVAE':
            train_dcmvae(model, train_loader, t_median)
        else:
            train_baseline(name, model, train_loader)

    # Evaluate all models
    print(f"\n{'='*70}")
    print(f"{'Model':<15} | {'Test RMSE':<12} | {'Test PEHE':<12} ")
    print("-" * 70)

    for name, model in model_list.items():
        rmse        = calculate_factual_rmse(name, model, test_loader)
        pehe, ate   = calculate_pehe_ate(name, model, test_loader)
        marker      = " ←" if name == 'DCMVAE' else ""
        print(f"{name:<15} | {rmse:.4f}       | {pehe:.4f}       ")

# ============================================================

if __name__ == '__main__':
    run_benchmark()

DATA DIAGNOSTICS
Num sequences:      54
Sequence length:    30
Feature dim:        5
Mean |ITE|:         2.4961
Std  ITE:           3.3482
% near-zero ITE:    12.3%
T0-T1 correlation:  0.9443

Training: DCMVAE
  [MMD] z0=436, z1=524
  Epoch 000 | loss=9.9184 | recon=0.9822 kl=0.0058 mmd=0.000000
  [MMD] z0=472, z1=488
  Epoch 030 | loss=9.9840 | recon=1.0343 kl=0.0357 mmd=0.000001
  [MMD] z0=447, z1=513
  Epoch 060 | loss=9.4980 | recon=0.8875 kl=0.0850 mmd=0.000001
  [MMD] z0=528, z1=432
  Epoch 090 | loss=9.9922 | recon=1.0483 kl=0.1582 mmd=0.000002
  [MMD] z0=494, z1=466
  Epoch 120 | loss=9.8228 | recon=1.0004 kl=0.1181 mmd=0.000001
  [MMD] z0=498, z1=462
  Epoch 149 | loss=9.8191 | recon=0.9969 kl=0.1279 mmd=0.000002

Training: Baseline1


Baseline1: 100%|██████████| 100/100 [00:11<00:00,  8.73it/s]



Training: Baseline2


Baseline2: 100%|██████████| 100/100 [00:11<00:00,  8.65it/s]



Training: Baseline3


Baseline3: 100%|██████████| 100/100 [00:07<00:00, 13.57it/s]


Model           | Test RMSE    | Test PEHE    
----------------------------------------------------------------------
DCMVAE          | 0.9884       | 3.0816       
Baseline1       | 0.9792       | 3.1583       
Baseline2       | 0.9928       | 3.2136       
Baseline3       | 1.0004       | 3.1683       





In [27]:
import os
import math
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# ============================================================
# 0. CONFIG & REPRODUCIBILITY
# ============================================================

REPRODUCIBILITY_SEED = 42
random.seed(REPRODUCIBILITY_SEED)
np.random.seed(REPRODUCIBILITY_SEED)
torch.manual_seed(REPRODUCIBILITY_SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(REPRODUCIBILITY_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
dim_x_features  = 5        # updated dynamically from data
SEQUENCE_LENGTH  = 30
BATCH_SIZE       = 32
HIDDEN_DIM       = 128
LATENT_DIM       = 64
NUM_EPOCHS       = 150
LEARNING_RATE    = 5e-4
KL_WEIGHT        = 0.001
MMD_WEIGHT       = 1.0
WINDOW_SIZE      = 5
TREATMENT_LAG    = 9

CSV_FILE_PATH = "arctic_s2s_multivar_2020_2024.csv"

# ============================================================
# 1. MMD UTILITIES
# ============================================================

def gaussian_rbf_matrix(x, y, sigma=1.0):
    x_norm = (x ** 2).sum(1).unsqueeze(1)
    y_norm = (y ** 2).sum(1).unsqueeze(0)
    dists  = x_norm + y_norm - 2 * (x @ y.t())
    return torch.exp(-dists / (2 * sigma**2 + 1e-12))

def compute_mmd_stable(x, y):
    if x is None or y is None or x.size(0) <= 1 or y.size(0) <= 1:
        return torch.tensor(0.0, device=DEVICE)
    mmd = 0.0
    for sigma in [0.5, 1.0, 2.0]:
        K_xx = gaussian_rbf_matrix(x, x, sigma)
        K_yy = gaussian_rbf_matrix(y, y, sigma)
        K_xy = gaussian_rbf_matrix(x, y, sigma)
        n, m = x.size(0), y.size(0)
        sum_xx = (K_xx.sum() - torch.diag(K_xx).sum()) / (n * (n - 1))
        sum_yy = (K_yy.sum() - torch.diag(K_yy).sum()) / (m * (m - 1))
        sum_xy = K_xy.mean()
        mmd += sum_xx + sum_yy - 2.0 * sum_xy
    return mmd / 3.0

# ============================================================
# 2. DATA MODULE
# ============================================================

class IHDP_TimeSeries:

    def __init__(self, csv_path, batch_size, sequence_length,
                 treatment_lag=TREATMENT_LAG):
        self.csv_path        = csv_path
        self.batch_size      = batch_size
        self.sequence_length = sequence_length
        self.treatment_lag   = treatment_lag
        self._load_and_preprocess_data()

    @staticmethod
    def apply_moving_window(series, window_size):
        return pd.Series(series.flatten()).rolling(
            window=window_size, min_periods=1
        ).mean().values.reshape(-1, 1)

    def _compute_lag(self, T, lag):
        """Shift treatment by `lag` steps; fill leading entries with zero."""
        Tlag = np.zeros_like(T)
        Tlag[lag:] = T[:-lag]
        return Tlag

    def _load_and_preprocess_data(self):
        if os.path.exists(self.csv_path):
            df = pd.read_csv(self.csv_path)
        else:
            df = pd.DataFrame(
                np.random.randn(1621, 5),
                columns=['uoe', 'von', 'total_vel', 'zos', 'sithick']
            )

        x_base = df[['uoe', 'von', 'total_vel']].values
        y_base = df[['sithick']].values
        ssh    = df['zos'].values.reshape(-1, 1)
        vel    = df['total_vel'].values.reshape(-1, 1)
        hidden = np.sin(np.linspace(0, 30 * np.pi, len(df))).reshape(-1, 1)

        # Control treatment T0: SSH-based with seasonal signal
        T0_smooth = self.apply_moving_window(ssh, WINDOW_SIZE)
        T0_np     = T0_smooth + (2.0 * hidden) + np.random.normal(0, 0.1, ssh.shape)

        # Treated treatment T1: regime-dependent scaling using raw velocity
        np.random.seed(REPRODUCIBILITY_SEED)
        v0      = np.mean(vel)
        sigmoid = 1 / (1 + np.exp(-(-5.0) * (vel - v0)))
        T1_np   = ((1.0 + 1.5 * sigmoid) * T0_np).reshape(-1, 1)

        # Treatment lags
        T0_lag_np = self._compute_lag(T0_np, self.treatment_lag)
        T1_lag_np = self._compute_lag(T1_np, self.treatment_lag)

        # Covariates: 3 base + T0 + T0_lag = 5 features
        X_RAW = np.concatenate([x_base, T0_np, T0_lag_np], axis=1)

        num_seq = len(df) // self.sequence_length
        limit   = num_seq * self.sequence_length

        # Update dim_x_features dynamically
        global dim_x_features
        dim_x_features = X_RAW.shape[1]

        scaler   = StandardScaler()
        X_scaled = scaler.fit_transform(X_RAW[:limit])

        self.xall      = X_scaled.reshape(num_seq, self.sequence_length, dim_x_features)
        self.t_factual = T0_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t_counter = T1_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t0_lag    = T0_lag_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t1_lag    = T1_lag_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.y_factual = y_base[:limit].reshape(num_seq, self.sequence_length, 1)

        # Counterfactual outcomes
        hidden_seq = hidden[:limit].reshape(num_seq, self.sequence_length, 1)
        T0_seq     = T0_np[:limit].reshape(num_seq, self.sequence_length, 1)
        T1_seq     = T1_np[:limit].reshape(num_seq, self.sequence_length, 1)
        delta      = -6.0 * np.abs(hidden_seq) * np.tanh(2.0 * (T1_seq - T0_seq))
        self.y0_cf = self.y_factual
        self.y1_cf = self.y_factual + delta

        # Diagnostics
        ite    = self.y1_cf - self.y0_cf
        t_corr = np.corrcoef(T0_np.flatten(), T1_np.flatten())[0, 1]
        print("=" * 55)
        print("DATA DIAGNOSTICS")
        print("=" * 55)
        print(f"Num sequences:      {self.xall.shape[0]}")
        print(f"Sequence length:    {self.xall.shape[1]}")
        print(f"Feature dim:        {self.xall.shape[2]}")
        print(f"Mean |ITE|:         {np.abs(ite).mean():.4f}")
        print(f"Std  ITE:           {ite.std():.4f}")
        print(f"% near-zero ITE:    {(np.abs(ite) < 0.01).mean()*100:.1f}%")
        print(f"T0-T1 correlation:  {t_corr:.4f}")
        print("=" * 55)

    def get_dataloaders(self):
        # Split indices ONCE to guarantee alignment across all arrays
        indices        = np.arange(len(self.xall))
        tr_idx, te_idx = train_test_split(
            indices, test_size=0.2, random_state=REPRODUCIBILITY_SEED
        )

        # Train: factual only — no counterfactuals (prevents leakage)
        # Includes t0_lag for baseline models
        train_ds = TensorDataset(
            torch.FloatTensor(self.xall[tr_idx]),
            torch.FloatTensor(self.t_factual[tr_idx]),
            torch.FloatTensor(self.t_counter[tr_idx]),
            torch.FloatTensor(self.t0_lag[tr_idx]),
            torch.FloatTensor(self.t1_lag[tr_idx]),
            torch.FloatTensor(self.y_factual[tr_idx])
        )

        # Test: include y0_cf and y1_cf for PEHE evaluation only
        test_ds = TensorDataset(
            torch.FloatTensor(self.xall[te_idx]),
            torch.FloatTensor(self.t_factual[te_idx]),
            torch.FloatTensor(self.t_counter[te_idx]),
            torch.FloatTensor(self.t0_lag[te_idx]),
            torch.FloatTensor(self.t1_lag[te_idx]),
            torch.FloatTensor(self.y_factual[te_idx]),
            torch.FloatTensor(self.y0_cf[te_idx]),
            torch.FloatTensor(self.y1_cf[te_idx])
        )

        return (
            DataLoader(train_ds, BATCH_SIZE, shuffle=True),
            DataLoader(test_ds,  BATCH_SIZE, shuffle=False)
        )

# ============================================================
# 3. DCMVAE MODEL
# ============================================================

class DCMVAE(nn.Module):

    def __init__(self, use_mmd=True):
        super().__init__()
        self.use_mmd = use_mmd

        # Encoder receives covariates + treatment so latent space
        # learns treatment-dependent representations for MMD balancing
        self.encoder_rnn = nn.GRU(
            dim_x_features + 1, HIDDEN_DIM,
            batch_first=True, bidirectional=True
        )
        self.fc_mu     = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)
        self.fc_logvar = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)

        # Treatment projection: amplifies treatment signal (1 → 16 dims)
        self.t_proj = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU()
        )

        # Outcome head conditioned on latent z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(LATENT_DIM + 16, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, X, t, mode='train'):
        x_and_t = torch.cat([X, t], dim=-1)
        h, _    = self.encoder_rnn(x_and_t)
        mu      = self.fc_mu(h)
        logvar  = self.fc_logvar(h)
        std     = torch.exp(0.5 * logvar)
        z       = mu + torch.randn_like(mu) * std if mode == 'train' else mu
        t_emb   = self.t_proj(t)
        z_and_t = torch.cat([z, t_emb], dim=-1)
        y_pred  = self.outcome_head(z_and_t)
        return y_pred, mu, logvar, z

    def counterfactual_prediction(self, X, t0, t1):
        self.eval()
        with torch.no_grad():
            p0, _, _, _ = self.forward(X, t0, mode='eval')
            p1, _, _, _ = self.forward(X, t1, mode='eval')
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, X, t):
        self.eval()
        with torch.no_grad():
            y_hat, _, _, _ = self.forward(X, t, mode='eval')
        return y_hat.squeeze(-1)

# ============================================================
# 4. BASELINE MODELS
# ============================================================

class Baseline1(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder
        self.encoder_rnn = nn.GRU(4, HIDDEN_DIM, batch_first=True)

        # Decoder
        self.decoder_rnn = nn.GRU(1, HIDDEN_DIM, batch_first=True) # Inputs future treatment

        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())

        self.outcome_head = nn.Sequential(
            nn.Linear(HIDDEN_DIM + 16, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x, t, t_lag):
        # x: (batch, seq, 3), t: (batch, seq, 1), t_lag: (batch, seq, 1)
        _, h_n = self.encoder_rnn(torch.cat([x, t_lag], dim=2))

        # Decoder uses final hidden state (h_n) and future treatments (t)
        # Fix: If h_n is unexpectedly a tuple (e.g., from an LSTM output), take the first element.
        # Although encoder_rnn is a GRU and should return a tensor, the error indicates h_n is a tuple.
        if isinstance(h_n, tuple):
            h_n_for_decoder = h_n[0]
        else:
            h_n_for_decoder = h_n
        z_seq, _ = self.decoder_rnn(t, h_n_for_decoder)

        t_emb = self.t_proj(t)
        y_pred = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))
        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        # Encode history once
        _, h_n = self.encoder_rnn(torch.cat([x, t0_lag], dim=2))

        # Apply the same fix as in forward method for h_n
        if isinstance(h_n, tuple):
            h_n_for_decoder = h_n[0]
        else:
            h_n_for_decoder = h_n

        # Decode for treatment path 0
        z0, _ = self.decoder_rnn(t0, h_n_for_decoder)
        p0 = self.outcome_head(torch.cat([z0, self.t_proj(t0)], dim=-1))

        # Decode for treatment path 1
        z1, _ = self.decoder_rnn(t1, h_n_for_decoder)
        p1 = self.outcome_head(torch.cat([z1, self.t_proj(t1)], dim=-1))

        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)

class Baseline2(nn.Module):
    def __init__(self, n_post=24, nb_features=4, output_dim=1, n_hidden=HIDDEN_DIM):
        super().__init__()
        self.n_post = n_post

        # Encoder: LSTM
        self.encoder_rnn = nn.LSTM(nb_features, n_hidden, batch_first=True)

        # Decoder: GRU
        self.decoder_rnn = nn.GRU(n_hidden, n_hidden, batch_first=True)

        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())

        self.outcome_head = nn.Sequential(
            nn.Linear(n_hidden + 16, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        )

    def forward(self, x, t, t_lag):
        # 1. Encode History (x: 3 dims + t_lag: 1 dim = 4 features)
        history = torch.cat([x, t_lag], dim=2)
        _, (h_n, _) = self.encoder_rnn(history)   # LSTM returns (output, (h_n, c_n))
        context_vector = h_n[-1]                   # (batch, n_hidden)

        # 2. Repeat context vector across sequence length
        batch_size, seq_len, _ = t.size()
        repeat_vector = context_vector.unsqueeze(1).repeat(1, seq_len, 1)

        # 3. Decode
        z_seq, _ = self.decoder_rnn(repeat_vector)

        # 4. Predict
        t_emb  = self.t_proj(t)
        y_pred = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))

        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        # Encode history once using factual lag
        history = torch.cat([x, t0_lag], dim=2)
        _, (h_n, _) = self.encoder_rnn(history)    # LSTM returns (output, (h_n, c_n))
        context_vector = h_n[-1]

        # Path 0 (factual)
        rep0   = context_vector.unsqueeze(1).repeat(1, t0.size(1), 1)
        z0, _  = self.decoder_rnn(rep0)
        p0     = self.outcome_head(torch.cat([z0, self.t_proj(t0)], dim=-1))

        # Path 1 (counterfactual)
        rep1   = context_vector.unsqueeze(1).repeat(1, t1.size(1), 1)
        z1, _  = self.decoder_rnn(rep1)
        p1     = self.outcome_head(torch.cat([z1, self.t_proj(t1)], dim=-1))

        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)


class Baseline3(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn    = nn.LSTM(4, HIDDEN_DIM, batch_first=True)
        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())
        self.outcome_head = nn.Sequential(
            nn.Linear(HIDDEN_DIM + 16, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x, t, t_lag):
        z_seq, _ = self.rnn(torch.cat([x, t_lag], dim=2))  # LSTM: _ = (h_n, c_n)
        t_emb    = self.t_proj(t)
        y_pred   = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))
        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        z, _ = self.rnn(torch.cat([x, t0_lag], dim=2))     # LSTM: _ = (h_n, c_n)
        p0   = self.outcome_head(torch.cat([z, self.t_proj(t0)], dim=-1))
        p1   = self.outcome_head(torch.cat([z, self.t_proj(t1)], dim=-1))
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)

# ============================================================
# 5. TRAINING
# ============================================================

def train_dcmvae(model, train_loader, t_median):
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=15
    )
    model.train()

    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0.0
        kl_scale   = min(1.0, epoch / 30.0)
        mmd_scale  = min(1.0, epoch / 50.0)

        for batch_idx, (x_in, t0, t1, t0_lag, t1_lag, y_fact) in enumerate(train_loader):
            x_in, t0, y_fact = x_in.to(DEVICE), t0.to(DEVICE), y_fact.to(DEVICE)
            optimizer.zero_grad()

            y_pred, mu, logvar, z = model(x_in, t0, mode='train')

            loss_recon = F.mse_loss(y_pred, y_fact)
            loss_kl    = torch.clamp(
                -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()),
                max=1.0
            )

            loss_mmd = torch.tensor(0.0, device=DEVICE)
            if model.use_mmd:
                z_flat = z.view(-1, LATENT_DIM)
                t_flat = (t0.view(-1) > t_median).float()
                z0_mmd = z_flat[t_flat == 0]
                z1_mmd = z_flat[t_flat == 1]
                if batch_idx == 0 and (epoch % 30 == 0 or epoch == NUM_EPOCHS - 1):
                    print(f"  [MMD] z0={z0_mmd.size(0)}, z1={z1_mmd.size(0)}")
                loss_mmd = compute_mmd_stable(z0_mmd, z1_mmd) * mmd_scale

            loss = (10.0 * loss_recon
                    + KL_WEIGHT * kl_scale * loss_kl
                    + MMD_WEIGHT * loss_mmd)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            epoch_loss += loss.item()

        avg = epoch_loss / len(train_loader)
        scheduler.step(avg)
        if epoch % 30 == 0 or epoch == NUM_EPOCHS - 1:
            print(f"  Epoch {epoch:03d} | loss={avg:.4f} | "
                  f"recon={loss_recon.item():.4f} "
                  f"kl={loss_kl.item():.4f} "
                  f"mmd={loss_mmd.item():.6f}")


def train_baseline(name, model, train_loader):
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    model.train()
    t_median = 0.0

    for epoch in tqdm(range(100), desc=name):
        for x_in, t0, t1, t0_lag, t1_lag, y_fact in train_loader:
            x_in   = x_in.to(DEVICE)
            t0     = t0.to(DEVICE)
            t0_lag = t0_lag.to(DEVICE)
            y_fact = y_fact.to(DEVICE)
            optimizer.zero_grad()

            y_p, z_s = model(x_in[:, :, :3], t0, t0_lag)
            loss_recon = F.mse_loss(y_p, y_fact)
            loss_mmd   = torch.tensor(0.0, device=DEVICE)

            if name != 'TS-TARNet':
                zf = z_s.reshape(-1, HIDDEN_DIM)
                tf = (t0.view(-1) > t_median).float()
                z0, z1 = zf[tf == 0], zf[tf == 1]
                if z0.size(0) > 1 and z1.size(0) > 1:
                    loss_mmd = compute_mmd_stable(z0, z1) * MMD_WEIGHT

            loss = loss_recon + loss_mmd
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

# ============================================================
# 6. EVALUATION
# ============================================================

def calculate_factual_rmse(name, model, loader):
    model.eval()
    se, count = 0, 0
    with torch.no_grad():
        for x, t0, t1, t0l, t1l, yf, y0c, y1c in loader:
            yf = yf.to(DEVICE).squeeze(-1)
            if name == 'DCMVAE':
                yh = model.factual_prediction(x.to(DEVICE), t0.to(DEVICE))
            else:
                yh = model.factual_prediction(
                    x[:, :, :3].to(DEVICE), t0.to(DEVICE), t0l.to(DEVICE)
                )
            se    += torch.sum((yh - yf) ** 2).item()
            count += yf.numel()
    return math.sqrt(se / count)


def calculate_pehe_ate(name, model, loader):
    model.eval()
    ite_se, ate_err, count = 0, 0, 0
    with torch.no_grad():
        for x, t0, t1, t0l, t1l, yf, y0c, y1c in loader:
            y0t = y0c.to(DEVICE).squeeze(-1)
            y1t = y1c.to(DEVICE).squeeze(-1)

            if name == 'DCMVAE':
                y0h, y1h = model.counterfactual_prediction(
                    x.to(DEVICE), t0.to(DEVICE), t1.to(DEVICE)
                )
            else:
                y0h, y1h = model.counterfactual_prediction(
                    x[:, :, :3].to(DEVICE),
                    t0.to(DEVICE), t1.to(DEVICE),
                    t0l.to(DEVICE), t1l.to(DEVICE)
                )

            ite_h    = y1h - y0h
            ite_t    = y1t - y0t
            ite_se  += torch.sum((ite_h - ite_t) ** 2).item()
            ate_err += torch.sum(ite_h - ite_t).item()
            count   += y0t.numel()

    return math.sqrt(ite_se / count), abs(ate_err / count)

# ============================================================
# 7. BENCHMARK
# ============================================================

def run_benchmark():
    dm = IHDP_TimeSeries(CSV_FILE_PATH, BATCH_SIZE, SEQUENCE_LENGTH)
    train_loader, test_loader = dm.get_dataloaders()
    t_median = float(np.median(dm.t_factual))

    model_list = {
        'DCMVAE'   : DCMVAE(use_mmd=True).to(DEVICE),
        'Baseline1'    : Baseline1().to(DEVICE),
        'Baseline2'   : Baseline2().to(DEVICE),
        'Baseline3': Baseline3().to(DEVICE),
    }

    # Train all models
    for name, model in model_list.items():
        print(f"\n{'='*55}")
        print(f"Training: {name}")
        print(f"{'='*55}")
        if name == 'DCMVAE':
            train_dcmvae(model, train_loader, t_median)
        else:
            train_baseline(name, model, train_loader)

    # Evaluate all models
    print(f"\n{'='*70}")
    print(f"{'Model':<15} | {'Test RMSE':<12} | {'Test PEHE':<12} ")
    print("-" * 70)

    for name, model in model_list.items():
        rmse        = calculate_factual_rmse(name, model, test_loader)
        pehe, ate   = calculate_pehe_ate(name, model, test_loader)
        marker      = " ←" if name == 'DCMVAE' else ""
        print(f"{name:<15} | {rmse:.4f}       | {pehe:.4f}       ")

# ============================================================

if __name__ == '__main__':
    run_benchmark()

DATA DIAGNOSTICS
Num sequences:      54
Sequence length:    30
Feature dim:        5
Mean |ITE|:         2.4961
Std  ITE:           3.3482
% near-zero ITE:    12.3%
T0-T1 correlation:  0.9443

Training: DCMVAE
  [MMD] z0=436, z1=524
  Epoch 000 | loss=9.9192 | recon=0.9823 kl=0.0058 mmd=0.000000
  [MMD] z0=472, z1=488
  Epoch 030 | loss=9.9809 | recon=1.0341 kl=0.0360 mmd=0.000001
  [MMD] z0=447, z1=513
  Epoch 060 | loss=9.4974 | recon=0.8875 kl=0.0813 mmd=0.000001
  [MMD] z0=528, z1=432
  Epoch 090 | loss=9.9822 | recon=1.0465 kl=0.1524 mmd=0.000002
  [MMD] z0=494, z1=466
  Epoch 120 | loss=9.8183 | recon=0.9997 kl=0.1146 mmd=0.000001
  [MMD] z0=498, z1=462
  Epoch 149 | loss=9.8136 | recon=0.9955 kl=0.1261 mmd=0.000002

Training: Baseline1


Baseline1: 100%|██████████| 100/100 [00:11<00:00,  8.80it/s]



Training: Baseline2


Baseline2: 100%|██████████| 100/100 [00:10<00:00,  9.52it/s]



Training: Baseline3


Baseline3: 100%|██████████| 100/100 [00:07<00:00, 14.12it/s]


Model           | Test RMSE    | Test PEHE    
----------------------------------------------------------------------
DCMVAE          | 0.9880       | 3.0840       
Baseline1       | 0.9790       | 3.1522       
Baseline2       | 0.9931       | 3.2155       
Baseline3       | 1.0017       | 3.1679       



