In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import os
import math
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# ============================================================
# 0. CONFIG & REPRODUCIBILITY
# ============================================================

REPRODUCIBILITY_SEED = 42
random.seed(REPRODUCIBILITY_SEED)
np.random.seed(REPRODUCIBILITY_SEED)
torch.manual_seed(REPRODUCIBILITY_SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(REPRODUCIBILITY_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
dim_x_features  = 5        # updated dynamically from data
SEQUENCE_LENGTH  = 30
BATCH_SIZE       = 32
HIDDEN_DIM       = 128
LATENT_DIM       = 64
NUM_EPOCHS       = 150
LEARNING_RATE    = 5e-4
KL_WEIGHT        = 0.001
MMD_WEIGHT       = 1.0
WINDOW_SIZE      = 5
TREATMENT_LAG    = 9

CSV_FILE_PATH = "arctic_s2s_multivar_2020_2024.csv"

# ============================================================
# 1. MMD UTILITIES
# ============================================================

def gaussian_rbf_matrix(x, y, sigma=1.0):
    x_norm = (x ** 2).sum(1).unsqueeze(1)
    y_norm = (y ** 2).sum(1).unsqueeze(0)
    dists  = x_norm + y_norm - 2 * (x @ y.t())
    return torch.exp(-dists / (2 * sigma**2 + 1e-12))

def compute_mmd_stable(x, y):
    if x is None or y is None or x.size(0) <= 1 or y.size(0) <= 1:
        return torch.tensor(0.0, device=DEVICE)
    mmd = 0.0
    for sigma in [0.5, 1.0, 2.0]:
        K_xx = gaussian_rbf_matrix(x, x, sigma)
        K_yy = gaussian_rbf_matrix(y, y, sigma)
        K_xy = gaussian_rbf_matrix(x, y, sigma)
        n, m = x.size(0), y.size(0)
        sum_xx = (K_xx.sum() - torch.diag(K_xx).sum()) / (n * (n - 1))
        sum_yy = (K_yy.sum() - torch.diag(K_yy).sum()) / (m * (m - 1))
        sum_xy = K_xy.mean()
        mmd += sum_xx + sum_yy - 2.0 * sum_xy
    return mmd / 3.0

# ============================================================
# 2. DATA MODULE
# ============================================================

class IHDP_TimeSeries:

    def __init__(self, csv_path, batch_size, sequence_length,
                 treatment_lag=TREATMENT_LAG):
        self.csv_path        = csv_path
        self.batch_size      = batch_size
        self.sequence_length = sequence_length
        self.treatment_lag   = treatment_lag
        self._load_and_preprocess_data()

    @staticmethod
    def apply_moving_window(series, window_size):
        return pd.Series(series.flatten()).rolling(
            window=window_size, min_periods=1
        ).mean().values.reshape(-1, 1)

    def _compute_lag(self, T, lag):
        """Shift treatment by `lag` steps; fill leading entries with zero."""
        Tlag = np.zeros_like(T)
        Tlag[lag:] = T[:-lag]
        return Tlag

    def _load_and_preprocess_data(self):
        if os.path.exists(self.csv_path):
            df = pd.read_csv(self.csv_path)
        else:
            df = pd.DataFrame(
                np.random.randn(1621, 5),
                columns=['uoe', 'von', 'total_vel', 'zos', 'sithick']
            )

        x_base = df[['uoe', 'von', 'total_vel']].values
        y_base = df[['sithick']].values
        ssh    = df['zos'].values.reshape(-1, 1)
        vel    = df['total_vel'].values.reshape(-1, 1)
        hidden = np.sin(np.linspace(0, 30 * np.pi, len(df))).reshape(-1, 1)

        # Control treatment T0: SSH-based with seasonal signal
        T0_smooth = self.apply_moving_window(ssh, WINDOW_SIZE)
        T0_np     = T0_smooth + (2.0 * hidden) + np.random.normal(0, 0.1, ssh.shape)

        # Treated treatment T1: regime-dependent scaling using raw velocity
        np.random.seed(REPRODUCIBILITY_SEED)
        v0      = np.mean(vel)
        sigmoid = 1 / (1 + np.exp(-(-5.0) * (vel - v0)))
        T1_np   = ((1.0 + 1.5 * sigmoid) * T0_np).reshape(-1, 1)

        # Treatment lags
        T0_lag_np = self._compute_lag(T0_np, self.treatment_lag)
        T1_lag_np = self._compute_lag(T1_np, self.treatment_lag)

        # Covariates: 3 base + T0 + T0_lag = 5 features
        X_RAW = np.concatenate([x_base, T0_np, T0_lag_np], axis=1)

        num_seq = len(df) // self.sequence_length
        limit   = num_seq * self.sequence_length

        # Update dim_x_features dynamically
        global dim_x_features
        dim_x_features = X_RAW.shape[1]

        scaler   = StandardScaler()
        X_scaled = scaler.fit_transform(X_RAW[:limit])

        self.xall      = X_scaled.reshape(num_seq, self.sequence_length, dim_x_features)
        self.t_factual = T0_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t_counter = T1_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t0_lag    = T0_lag_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t1_lag    = T1_lag_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.y_factual = y_base[:limit].reshape(num_seq, self.sequence_length, 1)

        # Counterfactual outcomes
        hidden_seq = hidden[:limit].reshape(num_seq, self.sequence_length, 1)
        T0_seq     = T0_np[:limit].reshape(num_seq, self.sequence_length, 1)
        T1_seq     = T1_np[:limit].reshape(num_seq, self.sequence_length, 1)
        delta      = -6.0 * np.abs(hidden_seq) * np.tanh(2.0 * (T1_seq - T0_seq))
        self.y0_cf = self.y_factual
        self.y1_cf = self.y_factual + delta

        # Diagnostics
        ite    = self.y1_cf - self.y0_cf
        t_corr = np.corrcoef(T0_np.flatten(), T1_np.flatten())[0, 1]
        print("=" * 55)
        print("DATA DIAGNOSTICS")
        print("=" * 55)
        print(f"Num sequences:      {self.xall.shape[0]}")
        print(f"Sequence length:    {self.xall.shape[1]}")
        print(f"Feature dim:        {self.xall.shape[2]}")
        print(f"Mean |ITE|:         {np.abs(ite).mean():.4f}")
        print(f"Std  ITE:           {ite.std():.4f}")
        print(f"% near-zero ITE:    {(np.abs(ite) < 0.01).mean()*100:.1f}%")
        print(f"T0-T1 correlation:  {t_corr:.4f}")
        print("=" * 55)

    def get_dataloaders(self):
        # Split indices ONCE to guarantee alignment across all arrays
        indices        = np.arange(len(self.xall))
        tr_idx, te_idx = train_test_split(
            indices, test_size=0.2, random_state=REPRODUCIBILITY_SEED
        )

        # Train: factual only — no counterfactuals (prevents leakage)
        # Includes t0_lag for baseline models
        train_ds = TensorDataset(
            torch.FloatTensor(self.xall[tr_idx]),
            torch.FloatTensor(self.t_factual[tr_idx]),
            torch.FloatTensor(self.t_counter[tr_idx]),
            torch.FloatTensor(self.t0_lag[tr_idx]),
            torch.FloatTensor(self.t1_lag[tr_idx]),
            torch.FloatTensor(self.y_factual[tr_idx])
        )

        # Test: include y0_cf and y1_cf for PEHE evaluation only
        test_ds = TensorDataset(
            torch.FloatTensor(self.xall[te_idx]),
            torch.FloatTensor(self.t_factual[te_idx]),
            torch.FloatTensor(self.t_counter[te_idx]),
            torch.FloatTensor(self.t0_lag[te_idx]),
            torch.FloatTensor(self.t1_lag[te_idx]),
            torch.FloatTensor(self.y_factual[te_idx]),
            torch.FloatTensor(self.y0_cf[te_idx]),
            torch.FloatTensor(self.y1_cf[te_idx])
        )

        return (
            DataLoader(train_ds, BATCH_SIZE, shuffle=True),
            DataLoader(test_ds,  BATCH_SIZE, shuffle=False)
        )

# ============================================================
# 3. DCMVAE MODEL
# ============================================================

class DCMVAE(nn.Module):

    def __init__(self, use_mmd=True):
        super().__init__()
        self.use_mmd = use_mmd

        # Encoder receives covariates + treatment so latent space
        # learns treatment-dependent representations for MMD balancing
        self.encoder_rnn = nn.GRU(
            dim_x_features + 1, HIDDEN_DIM,
            batch_first=True, bidirectional=True
        )
        self.fc_mu     = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)
        self.fc_logvar = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)

        # Treatment projection: amplifies treatment signal (1 → 16 dims)
        self.t_proj = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU()
        )

        # Outcome head conditioned on latent z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(LATENT_DIM + 16, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, X, t, mode='train'):
        x_and_t = torch.cat([X, t], dim=-1)
        h, _    = self.encoder_rnn(x_and_t)
        mu      = self.fc_mu(h)
        logvar  = self.fc_logvar(h)
        std     = torch.exp(0.5 * logvar)
        z       = mu + torch.randn_like(mu) * std if mode == 'train' else mu
        t_emb   = self.t_proj(t)
        z_and_t = torch.cat([z, t_emb], dim=-1)
        y_pred  = self.outcome_head(z_and_t)
        return y_pred, mu, logvar, z

    def counterfactual_prediction(self, X, t0, t1):
        self.eval()
        with torch.no_grad():
            p0, _, _, _ = self.forward(X, t0, mode='eval')
            p1, _, _, _ = self.forward(X, t1, mode='eval')
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, X, t):
        self.eval()
        with torch.no_grad():
            y_hat, _, _, _ = self.forward(X, t, mode='eval')
        return y_hat.squeeze(-1)

# ============================================================
# 4. BASELINE MODELS
# ============================================================

class R_CRN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn    = nn.GRU(4, HIDDEN_DIM, batch_first=True)
        # Treatment projection: same pattern as DCMVAE (1 → 16 dims)
        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())
        # Single outcome head conditioned on z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(HIDDEN_DIM + 16, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x, t, t_lag):
        z_seq, _ = self.rnn(torch.cat([x, t_lag], dim=2))
        t_emb    = self.t_proj(t)
        y_pred   = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))
        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        z, _  = self.rnn(torch.cat([x, t0_lag], dim=2))  # single z
        p0    = self.outcome_head(torch.cat([z, self.t_proj(t0)], dim=-1))
        p1    = self.outcome_head(torch.cat([z, self.t_proj(t1)], dim=-1))
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)


class CF_RNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn    = nn.GRU(4, HIDDEN_DIM, batch_first=True)
        # Treatment projection: same pattern as DCMVAE (1 → 16 dims)
        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())
        # Single outcome head conditioned on z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(HIDDEN_DIM + 16, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x, t, t_lag):
        z_seq, _ = self.rnn(torch.cat([x, t_lag], dim=2))
        t_emb    = self.t_proj(t)
        y_pred   = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))
        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        z, _  = self.rnn(torch.cat([x, t0_lag], dim=2))  # single z
        p0    = self.outcome_head(torch.cat([z, self.t_proj(t0)], dim=-1))
        p1    = self.outcome_head(torch.cat([z, self.t_proj(t1)], dim=-1))
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)


class TS_TARNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn    = nn.GRU(4, HIDDEN_DIM, batch_first=True)
        # Treatment projection: same pattern as DCMVAE (1 → 16 dims)
        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())
        # Single outcome head conditioned on z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(HIDDEN_DIM + 16, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x, t, t_lag):
        z_seq, _ = self.rnn(torch.cat([x, t_lag], dim=2))
        t_emb    = self.t_proj(t)
        y_pred   = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))
        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        z, _  = self.rnn(torch.cat([x, t0_lag], dim=2))  # single z
        p0    = self.outcome_head(torch.cat([z, self.t_proj(t0)], dim=-1))
        p1    = self.outcome_head(torch.cat([z, self.t_proj(t1)], dim=-1))
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)

# ============================================================
# 5. TRAINING
# ============================================================

def train_dcmvae(model, train_loader, t_median):
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=15
    )
    model.train()

    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0.0
        kl_scale   = min(1.0, epoch / 30.0)
        mmd_scale  = min(1.0, epoch / 50.0)

        for batch_idx, (x_in, t0, t1, t0_lag, t1_lag, y_fact) in enumerate(train_loader):
            x_in, t0, y_fact = x_in.to(DEVICE), t0.to(DEVICE), y_fact.to(DEVICE)
            optimizer.zero_grad()

            y_pred, mu, logvar, z = model(x_in, t0, mode='train')

            loss_recon = F.mse_loss(y_pred, y_fact)
            loss_kl    = torch.clamp(
                -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()),
                max=1.0
            )

            loss_mmd = torch.tensor(0.0, device=DEVICE)
            if model.use_mmd:
                z_flat = z.view(-1, LATENT_DIM)
                t_flat = (t0.view(-1) > t_median).float()
                z0_mmd = z_flat[t_flat == 0]
                z1_mmd = z_flat[t_flat == 1]
                if batch_idx == 0 and (epoch % 30 == 0 or epoch == NUM_EPOCHS - 1):
                    print(f"  [MMD] z0={z0_mmd.size(0)}, z1={z1_mmd.size(0)}")
                loss_mmd = compute_mmd_stable(z0_mmd, z1_mmd) * mmd_scale

            loss = (10.0 * loss_recon
                    + KL_WEIGHT * kl_scale * loss_kl
                    + MMD_WEIGHT * loss_mmd)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            epoch_loss += loss.item()

        avg = epoch_loss / len(train_loader)
        scheduler.step(avg)
        if epoch % 30 == 0 or epoch == NUM_EPOCHS - 1:
            print(f"  Epoch {epoch:03d} | loss={avg:.4f} | "
                  f"recon={loss_recon.item():.4f} "
                  f"kl={loss_kl.item():.4f} "
                  f"mmd={loss_mmd.item():.6f}")


def train_baseline(name, model, train_loader):
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    model.train()
    t_median = 0.0

    for epoch in tqdm(range(100), desc=name):
        for x_in, t0, t1, t0_lag, t1_lag, y_fact in train_loader:
            x_in   = x_in.to(DEVICE)
            t0     = t0.to(DEVICE)
            t0_lag = t0_lag.to(DEVICE)
            y_fact = y_fact.to(DEVICE)
            optimizer.zero_grad()

            y_p, z_s = model(x_in[:, :, :3], t0, t0_lag)
            loss_recon = F.mse_loss(y_p, y_fact)
            loss_mmd   = torch.tensor(0.0, device=DEVICE)

            if name != 'TS-TARNet':
                zf = z_s.reshape(-1, HIDDEN_DIM)
                tf = (t0.view(-1) > t_median).float()
                z0, z1 = zf[tf == 0], zf[tf == 1]
                if z0.size(0) > 1 and z1.size(0) > 1:
                    loss_mmd = compute_mmd_stable(z0, z1) * MMD_WEIGHT

            loss = loss_recon + loss_mmd
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

# ============================================================
# 6. EVALUATION
# ============================================================

def calculate_factual_rmse(name, model, loader):
    model.eval()
    se, count = 0, 0
    with torch.no_grad():
        for x, t0, t1, t0l, t1l, yf, y0c, y1c in loader:
            yf = yf.to(DEVICE).squeeze(-1)
            if name == 'DCMVAE':
                yh = model.factual_prediction(x.to(DEVICE), t0.to(DEVICE))
            else:
                yh = model.factual_prediction(
                    x[:, :, :3].to(DEVICE), t0.to(DEVICE), t0l.to(DEVICE)
                )
            se    += torch.sum((yh - yf) ** 2).item()
            count += yf.numel()
    return math.sqrt(se / count)


def calculate_pehe_ate(name, model, loader):
    model.eval()
    ite_se, ate_err, count = 0, 0, 0
    with torch.no_grad():
        for x, t0, t1, t0l, t1l, yf, y0c, y1c in loader:
            y0t = y0c.to(DEVICE).squeeze(-1)
            y1t = y1c.to(DEVICE).squeeze(-1)

            if name == 'DCMVAE':
                y0h, y1h = model.counterfactual_prediction(
                    x.to(DEVICE), t0.to(DEVICE), t1.to(DEVICE)
                )
            else:
                y0h, y1h = model.counterfactual_prediction(
                    x[:, :, :3].to(DEVICE),
                    t0.to(DEVICE), t1.to(DEVICE),
                    t0l.to(DEVICE), t1l.to(DEVICE)
                )

            ite_h    = y1h - y0h
            ite_t    = y1t - y0t
            ite_se  += torch.sum((ite_h - ite_t) ** 2).item()
            ate_err += torch.sum(ite_h - ite_t).item()
            count   += y0t.numel()

    return math.sqrt(ite_se / count), abs(ate_err / count)

# ============================================================
# 7. BENCHMARK
# ============================================================

def run_benchmark():
    dm = IHDP_TimeSeries(CSV_FILE_PATH, BATCH_SIZE, SEQUENCE_LENGTH)
    train_loader, test_loader = dm.get_dataloaders()
    t_median = float(np.median(dm.t_factual))

    model_list = {
        'DCMVAE'   : DCMVAE(use_mmd=True).to(DEVICE),
        'R-CRN'    : R_CRN().to(DEVICE),
        'CF-RNN'   : CF_RNN().to(DEVICE),
        'TS-TARNet': TS_TARNet().to(DEVICE),
    }

    # Train all models
    for name, model in model_list.items():
        print(f"\n{'='*55}")
        print(f"Training: {name}")
        print(f"{'='*55}")
        if name == 'DCMVAE':
            train_dcmvae(model, train_loader, t_median)
        else:
            train_baseline(name, model, train_loader)

    # Evaluate all models
    print(f"\n{'='*70}")
    print(f"{'Model':<15} | {'Test RMSE':<12} | {'Test PEHE':<12} ")
    print("-" * 70)

    for name, model in model_list.items():
        rmse        = calculate_factual_rmse(name, model, test_loader)
        pehe, ate   = calculate_pehe_ate(name, model, test_loader)
        marker      = " ←" if name == 'DCMVAE' else ""
        print(f"{name:<15} | {rmse:.4f}       | {pehe:.4f}     ")

# ============================================================

if __name__ == '__main__':
    run_benchmark()


DATA DIAGNOSTICS
Num sequences:      54
Sequence length:    30
Feature dim:        5
Mean |ITE|:         2.4961
Std  ITE:           3.3482
% near-zero ITE:    12.3%
T0-T1 correlation:  0.9443

Training: DCMVAE
  [MMD] z0=489, z1=471
  Epoch 000 | loss=9.8574 | recon=0.9682 kl=0.0052 mmd=0.000000
  [MMD] z0=498, z1=462
  Epoch 030 | loss=10.0448 | recon=1.0517 kl=0.0401 mmd=0.000001
  [MMD] z0=551, z1=409
  Epoch 060 | loss=9.9402 | recon=1.0297 kl=0.0988 mmd=0.000001
  [MMD] z0=452, z1=508
  Epoch 090 | loss=9.8511 | recon=0.9968 kl=0.1253 mmd=0.000002
  [MMD] z0=458, z1=502
  Epoch 120 | loss=9.6630 | recon=0.9428 kl=0.1034 mmd=0.000001
  [MMD] z0=430, z1=530
  Epoch 149 | loss=9.8466 | recon=0.9993 kl=0.1126 mmd=0.000001

Training: R-CRN


R-CRN: 100%|██████████| 100/100 [00:19<00:00,  5.25it/s]



Training: CF-RNN


CF-RNN: 100%|██████████| 100/100 [00:12<00:00,  7.85it/s]



Training: TS-TARNet


TS-TARNet: 100%|██████████| 100/100 [00:04<00:00, 22.73it/s]



Model           | Test RMSE    | Test PEHE    
----------------------------------------------------------------------
DCMVAE          | 0.9897       | 3.0804     
R-CRN           | 1.0169       | 3.1424     
CF-RNN          | 1.0101       | 3.1167     
TS-TARNet       | 1.0281       | 3.1721     


In [7]:
import os
import math
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# ============================================================
# 0. CONFIG & REPRODUCIBILITY
# ============================================================

REPRODUCIBILITY_SEED = 42
random.seed(REPRODUCIBILITY_SEED)
np.random.seed(REPRODUCIBILITY_SEED)
torch.manual_seed(REPRODUCIBILITY_SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(REPRODUCIBILITY_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
dim_x_features  = 5        # updated dynamically from data
SEQUENCE_LENGTH  = 30
BATCH_SIZE       = 32
HIDDEN_DIM       = 128
LATENT_DIM       = 64
NUM_EPOCHS       = 150
LEARNING_RATE    = 5e-4
KL_WEIGHT        = 0.001
MMD_WEIGHT       = 1.0
WINDOW_SIZE      = 5
TREATMENT_LAG    = 6

CSV_FILE_PATH = "arctic_s2s_multivar_2020_2024.csv"

# ============================================================
# 1. MMD UTILITIES
# ============================================================

def gaussian_rbf_matrix(x, y, sigma=1.0):
    x_norm = (x ** 2).sum(1).unsqueeze(1)
    y_norm = (y ** 2).sum(1).unsqueeze(0)
    dists  = x_norm + y_norm - 2 * (x @ y.t())
    return torch.exp(-dists / (2 * sigma**2 + 1e-12))

def compute_mmd_stable(x, y):
    if x is None or y is None or x.size(0) <= 1 or y.size(0) <= 1:
        return torch.tensor(0.0, device=DEVICE)
    mmd = 0.0
    for sigma in [0.5, 1.0, 2.0]:
        K_xx = gaussian_rbf_matrix(x, x, sigma)
        K_yy = gaussian_rbf_matrix(y, y, sigma)
        K_xy = gaussian_rbf_matrix(x, y, sigma)
        n, m = x.size(0), y.size(0)
        sum_xx = (K_xx.sum() - torch.diag(K_xx).sum()) / (n * (n - 1))
        sum_yy = (K_yy.sum() - torch.diag(K_yy).sum()) / (m * (m - 1))
        sum_xy = K_xy.mean()
        mmd += sum_xx + sum_yy - 2.0 * sum_xy
    return mmd / 3.0

# ============================================================
# 2. DATA MODULE
# ============================================================

class IHDP_TimeSeries:

    def __init__(self, csv_path, batch_size, sequence_length,
                 treatment_lag=TREATMENT_LAG):
        self.csv_path        = csv_path
        self.batch_size      = batch_size
        self.sequence_length = sequence_length
        self.treatment_lag   = treatment_lag
        self._load_and_preprocess_data()

    @staticmethod
    def apply_moving_window(series, window_size):
        return pd.Series(series.flatten()).rolling(
            window=window_size, min_periods=1
        ).mean().values.reshape(-1, 1)

    def _compute_lag(self, T, lag):
        """Shift treatment by `lag` steps; fill leading entries with zero."""
        Tlag = np.zeros_like(T)
        Tlag[lag:] = T[:-lag]
        return Tlag

    def _load_and_preprocess_data(self):
        if os.path.exists(self.csv_path):
            df = pd.read_csv(self.csv_path)
        else:
            df = pd.DataFrame(
                np.random.randn(1621, 5),
                columns=['uoe', 'von', 'total_vel', 'zos', 'sithick']
            )

        x_base = df[['uoe', 'von', 'total_vel']].values
        y_base = df[['sithick']].values
        ssh    = df['zos'].values.reshape(-1, 1)
        vel    = df['total_vel'].values.reshape(-1, 1)
        hidden = np.sin(np.linspace(0, 30 * np.pi, len(df))).reshape(-1, 1)

        # Control treatment T0: SSH-based with seasonal signal
        T0_smooth = self.apply_moving_window(ssh, WINDOW_SIZE)
        T0_np     = T0_smooth + (2.0 * hidden) + np.random.normal(0, 0.1, ssh.shape)

        # Treated treatment T1: regime-dependent scaling using raw velocity
        np.random.seed(REPRODUCIBILITY_SEED)
        v0      = np.mean(vel)
        sigmoid = 1 / (1 + np.exp(-(-5.0) * (vel - v0)))
        T1_np   = ((1.0 + 1.5 * sigmoid) * T0_np).reshape(-1, 1)

        # Treatment lags
        T0_lag_np = self._compute_lag(T0_np, self.treatment_lag)
        T1_lag_np = self._compute_lag(T1_np, self.treatment_lag)

        # Covariates: 3 base + T0 + T0_lag = 5 features
        X_RAW = np.concatenate([x_base, T0_np, T0_lag_np], axis=1)

        num_seq = len(df) // self.sequence_length
        limit   = num_seq * self.sequence_length

        # Update dim_x_features dynamically
        global dim_x_features
        dim_x_features = X_RAW.shape[1]

        scaler   = StandardScaler()
        X_scaled = scaler.fit_transform(X_RAW[:limit])

        self.xall      = X_scaled.reshape(num_seq, self.sequence_length, dim_x_features)
        self.t_factual = T0_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t_counter = T1_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t0_lag    = T0_lag_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t1_lag    = T1_lag_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.y_factual = y_base[:limit].reshape(num_seq, self.sequence_length, 1)

        # Counterfactual outcomes
        hidden_seq = hidden[:limit].reshape(num_seq, self.sequence_length, 1)
        T0_seq     = T0_np[:limit].reshape(num_seq, self.sequence_length, 1)
        T1_seq     = T1_np[:limit].reshape(num_seq, self.sequence_length, 1)
        delta      = -6.0 * np.abs(hidden_seq) * np.tanh(2.0 * (T1_seq - T0_seq))
        self.y0_cf = self.y_factual
        self.y1_cf = self.y_factual + delta

        # Diagnostics
        ite    = self.y1_cf - self.y0_cf
        t_corr = np.corrcoef(T0_np.flatten(), T1_np.flatten())[0, 1]
        print("=" * 55)
        print("DATA DIAGNOSTICS")
        print("=" * 55)
        print(f"Num sequences:      {self.xall.shape[0]}")
        print(f"Sequence length:    {self.xall.shape[1]}")
        print(f"Feature dim:        {self.xall.shape[2]}")
        print(f"Mean |ITE|:         {np.abs(ite).mean():.4f}")
        print(f"Std  ITE:           {ite.std():.4f}")
        print(f"% near-zero ITE:    {(np.abs(ite) < 0.01).mean()*100:.1f}%")
        print(f"T0-T1 correlation:  {t_corr:.4f}")
        print("=" * 55)

    def get_dataloaders(self):
        # Split indices ONCE to guarantee alignment across all arrays
        indices        = np.arange(len(self.xall))
        tr_idx, te_idx = train_test_split(
            indices, test_size=0.2, random_state=REPRODUCIBILITY_SEED
        )

        # Train: factual only — no counterfactuals (prevents leakage)
        # Includes t0_lag for baseline models
        train_ds = TensorDataset(
            torch.FloatTensor(self.xall[tr_idx]),
            torch.FloatTensor(self.t_factual[tr_idx]),
            torch.FloatTensor(self.t_counter[tr_idx]),
            torch.FloatTensor(self.t0_lag[tr_idx]),
            torch.FloatTensor(self.t1_lag[tr_idx]),
            torch.FloatTensor(self.y_factual[tr_idx])
        )

        # Test: include y0_cf and y1_cf for PEHE evaluation only
        test_ds = TensorDataset(
            torch.FloatTensor(self.xall[te_idx]),
            torch.FloatTensor(self.t_factual[te_idx]),
            torch.FloatTensor(self.t_counter[te_idx]),
            torch.FloatTensor(self.t0_lag[te_idx]),
            torch.FloatTensor(self.t1_lag[te_idx]),
            torch.FloatTensor(self.y_factual[te_idx]),
            torch.FloatTensor(self.y0_cf[te_idx]),
            torch.FloatTensor(self.y1_cf[te_idx])
        )

        return (
            DataLoader(train_ds, BATCH_SIZE, shuffle=True),
            DataLoader(test_ds,  BATCH_SIZE, shuffle=False)
        )

# ============================================================
# 3. DCMVAE MODEL
# ============================================================

class DCMVAE(nn.Module):

    def __init__(self, use_mmd=True):
        super().__init__()
        self.use_mmd = use_mmd

        # Encoder receives covariates + treatment so latent space
        # learns treatment-dependent representations for MMD balancing
        self.encoder_rnn = nn.GRU(
            dim_x_features + 1, HIDDEN_DIM,
            batch_first=True, bidirectional=True
        )
        self.fc_mu     = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)
        self.fc_logvar = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)

        # Treatment projection: amplifies treatment signal (1 → 16 dims)
        self.t_proj = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU()
        )

        # Outcome head conditioned on latent z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(LATENT_DIM + 16, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, X, t, mode='train'):
        x_and_t = torch.cat([X, t], dim=-1)
        h, _    = self.encoder_rnn(x_and_t)
        mu      = self.fc_mu(h)
        logvar  = self.fc_logvar(h)
        std     = torch.exp(0.5 * logvar)
        z       = mu + torch.randn_like(mu) * std if mode == 'train' else mu
        t_emb   = self.t_proj(t)
        z_and_t = torch.cat([z, t_emb], dim=-1)
        y_pred  = self.outcome_head(z_and_t)
        return y_pred, mu, logvar, z

    def counterfactual_prediction(self, X, t0, t1):
        self.eval()
        with torch.no_grad():
            p0, _, _, _ = self.forward(X, t0, mode='eval')
            p1, _, _, _ = self.forward(X, t1, mode='eval')
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, X, t):
        self.eval()
        with torch.no_grad():
            y_hat, _, _, _ = self.forward(X, t, mode='eval')
        return y_hat.squeeze(-1)

# ============================================================
# 4. BASELINE MODELS
# ============================================================

class R_CRN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn    = nn.GRU(4, HIDDEN_DIM, batch_first=True)
        # Treatment projection: same pattern as DCMVAE (1 → 16 dims)
        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())
        # Single outcome head conditioned on z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(HIDDEN_DIM + 16, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x, t, t_lag):
        z_seq, _ = self.rnn(torch.cat([x, t_lag], dim=2))
        t_emb    = self.t_proj(t)
        y_pred   = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))
        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        z, _  = self.rnn(torch.cat([x, t0_lag], dim=2))  # single z
        p0    = self.outcome_head(torch.cat([z, self.t_proj(t0)], dim=-1))
        p1    = self.outcome_head(torch.cat([z, self.t_proj(t1)], dim=-1))
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)


class CF_RNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn    = nn.GRU(4, HIDDEN_DIM, batch_first=True)
        # Treatment projection: same pattern as DCMVAE (1 → 16 dims)
        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())
        # Single outcome head conditioned on z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(HIDDEN_DIM + 16, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x, t, t_lag):
        z_seq, _ = self.rnn(torch.cat([x, t_lag], dim=2))
        t_emb    = self.t_proj(t)
        y_pred   = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))
        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        z, _  = self.rnn(torch.cat([x, t0_lag], dim=2))  # single z
        p0    = self.outcome_head(torch.cat([z, self.t_proj(t0)], dim=-1))
        p1    = self.outcome_head(torch.cat([z, self.t_proj(t1)], dim=-1))
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)


class TS_TARNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn    = nn.GRU(4, HIDDEN_DIM, batch_first=True)
        # Treatment projection: same pattern as DCMVAE (1 → 16 dims)
        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())
        # Single outcome head conditioned on z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(HIDDEN_DIM + 16, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x, t, t_lag):
        z_seq, _ = self.rnn(torch.cat([x, t_lag], dim=2))
        t_emb    = self.t_proj(t)
        y_pred   = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))
        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        z, _  = self.rnn(torch.cat([x, t0_lag], dim=2))  # single z
        p0    = self.outcome_head(torch.cat([z, self.t_proj(t0)], dim=-1))
        p1    = self.outcome_head(torch.cat([z, self.t_proj(t1)], dim=-1))
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)

# ============================================================
# 5. TRAINING
# ============================================================

def train_dcmvae(model, train_loader, t_median):
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=15
    )
    model.train()

    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0.0
        kl_scale   = min(1.0, epoch / 30.0)
        mmd_scale  = min(1.0, epoch / 50.0)

        for batch_idx, (x_in, t0, t1, t0_lag, t1_lag, y_fact) in enumerate(train_loader):
            x_in, t0, y_fact = x_in.to(DEVICE), t0.to(DEVICE), y_fact.to(DEVICE)
            optimizer.zero_grad()

            y_pred, mu, logvar, z = model(x_in, t0, mode='train')

            loss_recon = F.mse_loss(y_pred, y_fact)
            loss_kl    = torch.clamp(
                -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()),
                max=1.0
            )

            loss_mmd = torch.tensor(0.0, device=DEVICE)
            if model.use_mmd:
                z_flat = z.view(-1, LATENT_DIM)
                t_flat = (t0.view(-1) > t_median).float()
                z0_mmd = z_flat[t_flat == 0]
                z1_mmd = z_flat[t_flat == 1]
                if batch_idx == 0 and (epoch % 30 == 0 or epoch == NUM_EPOCHS - 1):
                    print(f"  [MMD] z0={z0_mmd.size(0)}, z1={z1_mmd.size(0)}")
                loss_mmd = compute_mmd_stable(z0_mmd, z1_mmd) * mmd_scale

            loss = (10.0 * loss_recon
                    + KL_WEIGHT * kl_scale * loss_kl
                    + MMD_WEIGHT * loss_mmd)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            epoch_loss += loss.item()

        avg = epoch_loss / len(train_loader)
        scheduler.step(avg)
        if epoch % 30 == 0 or epoch == NUM_EPOCHS - 1:
            print(f"  Epoch {epoch:03d} | loss={avg:.4f} | "
                  f"recon={loss_recon.item():.4f} "
                  f"kl={loss_kl.item():.4f} "
                  f"mmd={loss_mmd.item():.6f}")


def train_baseline(name, model, train_loader):
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    model.train()
    t_median = 0.0

    for epoch in tqdm(range(100), desc=name):
        for x_in, t0, t1, t0_lag, t1_lag, y_fact in train_loader:
            x_in   = x_in.to(DEVICE)
            t0     = t0.to(DEVICE)
            t0_lag = t0_lag.to(DEVICE)
            y_fact = y_fact.to(DEVICE)
            optimizer.zero_grad()

            y_p, z_s = model(x_in[:, :, :3], t0, t0_lag)
            loss_recon = F.mse_loss(y_p, y_fact)
            loss_mmd   = torch.tensor(0.0, device=DEVICE)

            if name != 'TS-TARNet':
                zf = z_s.reshape(-1, HIDDEN_DIM)
                tf = (t0.view(-1) > t_median).float()
                z0, z1 = zf[tf == 0], zf[tf == 1]
                if z0.size(0) > 1 and z1.size(0) > 1:
                    loss_mmd = compute_mmd_stable(z0, z1) * MMD_WEIGHT

            loss = loss_recon + loss_mmd
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

# ============================================================
# 6. EVALUATION
# ============================================================

def calculate_factual_rmse(name, model, loader):
    model.eval()
    se, count = 0, 0
    with torch.no_grad():
        for x, t0, t1, t0l, t1l, yf, y0c, y1c in loader:
            yf = yf.to(DEVICE).squeeze(-1)
            if name == 'DCMVAE':
                yh = model.factual_prediction(x.to(DEVICE), t0.to(DEVICE))
            else:
                yh = model.factual_prediction(
                    x[:, :, :3].to(DEVICE), t0.to(DEVICE), t0l.to(DEVICE)
                )
            se    += torch.sum((yh - yf) ** 2).item()
            count += yf.numel()
    return math.sqrt(se / count)


def calculate_pehe_ate(name, model, loader):
    model.eval()
    ite_se, ate_err, count = 0, 0, 0
    with torch.no_grad():
        for x, t0, t1, t0l, t1l, yf, y0c, y1c in loader:
            y0t = y0c.to(DEVICE).squeeze(-1)
            y1t = y1c.to(DEVICE).squeeze(-1)

            if name == 'DCMVAE':
                y0h, y1h = model.counterfactual_prediction(
                    x.to(DEVICE), t0.to(DEVICE), t1.to(DEVICE)
                )
            else:
                y0h, y1h = model.counterfactual_prediction(
                    x[:, :, :3].to(DEVICE),
                    t0.to(DEVICE), t1.to(DEVICE),
                    t0l.to(DEVICE), t1l.to(DEVICE)
                )

            ite_h    = y1h - y0h
            ite_t    = y1t - y0t
            ite_se  += torch.sum((ite_h - ite_t) ** 2).item()
            ate_err += torch.sum(ite_h - ite_t).item()
            count   += y0t.numel()

    return math.sqrt(ite_se / count), abs(ate_err / count)

# ============================================================
# 7. BENCHMARK
# ============================================================

def run_benchmark():
    dm = IHDP_TimeSeries(CSV_FILE_PATH, BATCH_SIZE, SEQUENCE_LENGTH)
    train_loader, test_loader = dm.get_dataloaders()
    t_median = float(np.median(dm.t_factual))

    model_list = {
        'DCMVAE'   : DCMVAE(use_mmd=True).to(DEVICE),
        'R-CRN'    : R_CRN().to(DEVICE),
        'CF-RNN'   : CF_RNN().to(DEVICE),
        'TS-TARNet': TS_TARNet().to(DEVICE),
    }

    # Train all models
    for name, model in model_list.items():
        print(f"\n{'='*55}")
        print(f"Training: {name}")
        print(f"{'='*55}")
        if name == 'DCMVAE':
            train_dcmvae(model, train_loader, t_median)
        else:
            train_baseline(name, model, train_loader)

    # Evaluate all models
    print(f"\n{'='*70}")
    print(f"{'Model':<15} | {'Test RMSE':<12} | {'Test PEHE':<12} ")
    print("-" * 70)

    for name, model in model_list.items():
        rmse        = calculate_factual_rmse(name, model, test_loader)
        pehe, ate   = calculate_pehe_ate(name, model, test_loader)
        marker      = " ←" if name == 'DCMVAE' else ""
        print(f"{name:<15} | {rmse:.4f}       | {pehe:.4f}      ")

# ============================================================

if __name__ == '__main__':
    run_benchmark()


DATA DIAGNOSTICS
Num sequences:      54
Sequence length:    30
Feature dim:        5
Mean |ITE|:         2.4961
Std  ITE:           3.3482
% near-zero ITE:    12.3%
T0-T1 correlation:  0.9443

Training: DCMVAE
  [MMD] z0=489, z1=471
  Epoch 000 | loss=9.8572 | recon=0.9682 kl=0.0052 mmd=0.000000
  [MMD] z0=498, z1=462
  Epoch 030 | loss=10.0408 | recon=1.0505 kl=0.0416 mmd=0.000001
  [MMD] z0=551, z1=409
  Epoch 060 | loss=9.9497 | recon=1.0302 kl=0.0997 mmd=0.000001
  [MMD] z0=452, z1=508
  Epoch 090 | loss=9.8532 | recon=0.9972 kl=0.1296 mmd=0.000002
  [MMD] z0=458, z1=502
  Epoch 120 | loss=9.6595 | recon=0.9425 kl=0.1096 mmd=0.000001
  [MMD] z0=430, z1=530
  Epoch 149 | loss=9.8475 | recon=0.9988 kl=0.1183 mmd=0.000001

Training: R-CRN


R-CRN: 100%|██████████| 100/100 [00:17<00:00,  5.76it/s]



Training: CF-RNN


CF-RNN: 100%|██████████| 100/100 [00:12<00:00,  7.71it/s]



Training: TS-TARNet


TS-TARNet: 100%|██████████| 100/100 [00:03<00:00, 27.57it/s]


Model           | Test RMSE    | Test PEHE    
----------------------------------------------------------------------
DCMVAE          | 0.9892       | 3.0804      
R-CRN           | 1.0156       | 3.1620      
CF-RNN          | 1.0099       | 3.1133      
TS-TARNet       | 1.0137       | 3.1671      





In [8]:
import os
import math
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# ============================================================
# 0. CONFIG & REPRODUCIBILITY
# ============================================================

REPRODUCIBILITY_SEED = 42
random.seed(REPRODUCIBILITY_SEED)
np.random.seed(REPRODUCIBILITY_SEED)
torch.manual_seed(REPRODUCIBILITY_SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(REPRODUCIBILITY_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
dim_x_features  = 5        # updated dynamically from data
SEQUENCE_LENGTH  = 30
BATCH_SIZE       = 32
HIDDEN_DIM       = 128
LATENT_DIM       = 64
NUM_EPOCHS       = 150
LEARNING_RATE    = 5e-4
KL_WEIGHT        = 0.001
MMD_WEIGHT       = 1.0
WINDOW_SIZE      = 5
TREATMENT_LAG    = 3

CSV_FILE_PATH = "arctic_s2s_multivar_2020_2024.csv"

# ============================================================
# 1. MMD UTILITIES
# ============================================================

def gaussian_rbf_matrix(x, y, sigma=1.0):
    x_norm = (x ** 2).sum(1).unsqueeze(1)
    y_norm = (y ** 2).sum(1).unsqueeze(0)
    dists  = x_norm + y_norm - 2 * (x @ y.t())
    return torch.exp(-dists / (2 * sigma**2 + 1e-12))

def compute_mmd_stable(x, y):
    if x is None or y is None or x.size(0) <= 1 or y.size(0) <= 1:
        return torch.tensor(0.0, device=DEVICE)
    mmd = 0.0
    for sigma in [0.5, 1.0, 2.0]:
        K_xx = gaussian_rbf_matrix(x, x, sigma)
        K_yy = gaussian_rbf_matrix(y, y, sigma)
        K_xy = gaussian_rbf_matrix(x, y, sigma)
        n, m = x.size(0), y.size(0)
        sum_xx = (K_xx.sum() - torch.diag(K_xx).sum()) / (n * (n - 1))
        sum_yy = (K_yy.sum() - torch.diag(K_yy).sum()) / (m * (m - 1))
        sum_xy = K_xy.mean()
        mmd += sum_xx + sum_yy - 2.0 * sum_xy
    return mmd / 3.0

# ============================================================
# 2. DATA MODULE
# ============================================================

class IHDP_TimeSeries:

    def __init__(self, csv_path, batch_size, sequence_length,
                 treatment_lag=TREATMENT_LAG):
        self.csv_path        = csv_path
        self.batch_size      = batch_size
        self.sequence_length = sequence_length
        self.treatment_lag   = treatment_lag
        self._load_and_preprocess_data()

    @staticmethod
    def apply_moving_window(series, window_size):
        return pd.Series(series.flatten()).rolling(
            window=window_size, min_periods=1
        ).mean().values.reshape(-1, 1)

    def _compute_lag(self, T, lag):
        """Shift treatment by `lag` steps; fill leading entries with zero."""
        Tlag = np.zeros_like(T)
        Tlag[lag:] = T[:-lag]
        return Tlag

    def _load_and_preprocess_data(self):
        if os.path.exists(self.csv_path):
            df = pd.read_csv(self.csv_path)
        else:
            df = pd.DataFrame(
                np.random.randn(1621, 5),
                columns=['uoe', 'von', 'total_vel', 'zos', 'sithick']
            )

        x_base = df[['uoe', 'von', 'total_vel']].values
        y_base = df[['sithick']].values
        ssh    = df['zos'].values.reshape(-1, 1)
        vel    = df['total_vel'].values.reshape(-1, 1)
        hidden = np.sin(np.linspace(0, 30 * np.pi, len(df))).reshape(-1, 1)

        # Control treatment T0: SSH-based with seasonal signal
        T0_smooth = self.apply_moving_window(ssh, WINDOW_SIZE)
        T0_np     = T0_smooth + (2.0 * hidden) + np.random.normal(0, 0.1, ssh.shape)

        # Treated treatment T1: regime-dependent scaling using raw velocity
        np.random.seed(REPRODUCIBILITY_SEED)
        v0      = np.mean(vel)
        sigmoid = 1 / (1 + np.exp(-(-5.0) * (vel - v0)))
        T1_np   = ((1.0 + 1.5 * sigmoid) * T0_np).reshape(-1, 1)

        # Treatment lags
        T0_lag_np = self._compute_lag(T0_np, self.treatment_lag)
        T1_lag_np = self._compute_lag(T1_np, self.treatment_lag)

        # Covariates: 3 base + T0 + T0_lag = 5 features
        X_RAW = np.concatenate([x_base, T0_np, T0_lag_np], axis=1)

        num_seq = len(df) // self.sequence_length
        limit   = num_seq * self.sequence_length

        # Update dim_x_features dynamically
        global dim_x_features
        dim_x_features = X_RAW.shape[1]

        scaler   = StandardScaler()
        X_scaled = scaler.fit_transform(X_RAW[:limit])

        self.xall      = X_scaled.reshape(num_seq, self.sequence_length, dim_x_features)
        self.t_factual = T0_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t_counter = T1_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t0_lag    = T0_lag_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t1_lag    = T1_lag_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.y_factual = y_base[:limit].reshape(num_seq, self.sequence_length, 1)

        # Counterfactual outcomes
        hidden_seq = hidden[:limit].reshape(num_seq, self.sequence_length, 1)
        T0_seq     = T0_np[:limit].reshape(num_seq, self.sequence_length, 1)
        T1_seq     = T1_np[:limit].reshape(num_seq, self.sequence_length, 1)
        delta      = -6.0 * np.abs(hidden_seq) * np.tanh(2.0 * (T1_seq - T0_seq))
        self.y0_cf = self.y_factual
        self.y1_cf = self.y_factual + delta

        # Diagnostics
        ite    = self.y1_cf - self.y0_cf
        t_corr = np.corrcoef(T0_np.flatten(), T1_np.flatten())[0, 1]
        print("=" * 55)
        print("DATA DIAGNOSTICS")
        print("=" * 55)
        print(f"Num sequences:      {self.xall.shape[0]}")
        print(f"Sequence length:    {self.xall.shape[1]}")
        print(f"Feature dim:        {self.xall.shape[2]}")
        print(f"Mean |ITE|:         {np.abs(ite).mean():.4f}")
        print(f"Std  ITE:           {ite.std():.4f}")
        print(f"% near-zero ITE:    {(np.abs(ite) < 0.01).mean()*100:.1f}%")
        print(f"T0-T1 correlation:  {t_corr:.4f}")
        print("=" * 55)

    def get_dataloaders(self):
        # Split indices ONCE to guarantee alignment across all arrays
        indices        = np.arange(len(self.xall))
        tr_idx, te_idx = train_test_split(
            indices, test_size=0.2, random_state=REPRODUCIBILITY_SEED
        )

        # Train: factual only — no counterfactuals (prevents leakage)
        # Includes t0_lag for baseline models
        train_ds = TensorDataset(
            torch.FloatTensor(self.xall[tr_idx]),
            torch.FloatTensor(self.t_factual[tr_idx]),
            torch.FloatTensor(self.t_counter[tr_idx]),
            torch.FloatTensor(self.t0_lag[tr_idx]),
            torch.FloatTensor(self.t1_lag[tr_idx]),
            torch.FloatTensor(self.y_factual[tr_idx])
        )

        # Test: include y0_cf and y1_cf for PEHE evaluation only
        test_ds = TensorDataset(
            torch.FloatTensor(self.xall[te_idx]),
            torch.FloatTensor(self.t_factual[te_idx]),
            torch.FloatTensor(self.t_counter[te_idx]),
            torch.FloatTensor(self.t0_lag[te_idx]),
            torch.FloatTensor(self.t1_lag[te_idx]),
            torch.FloatTensor(self.y_factual[te_idx]),
            torch.FloatTensor(self.y0_cf[te_idx]),
            torch.FloatTensor(self.y1_cf[te_idx])
        )

        return (
            DataLoader(train_ds, BATCH_SIZE, shuffle=True),
            DataLoader(test_ds,  BATCH_SIZE, shuffle=False)
        )

# ============================================================
# 3. DCMVAE MODEL
# ============================================================

class DCMVAE(nn.Module):

    def __init__(self, use_mmd=True):
        super().__init__()
        self.use_mmd = use_mmd

        # Encoder receives covariates + treatment so latent space
        # learns treatment-dependent representations for MMD balancing
        self.encoder_rnn = nn.GRU(
            dim_x_features + 1, HIDDEN_DIM,
            batch_first=True, bidirectional=True
        )
        self.fc_mu     = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)
        self.fc_logvar = nn.Linear(HIDDEN_DIM * 2, LATENT_DIM)

        # Treatment projection: amplifies treatment signal (1 → 16 dims)
        self.t_proj = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU()
        )

        # Outcome head conditioned on latent z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(LATENT_DIM + 16, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, X, t, mode='train'):
        x_and_t = torch.cat([X, t], dim=-1)
        h, _    = self.encoder_rnn(x_and_t)
        mu      = self.fc_mu(h)
        logvar  = self.fc_logvar(h)
        std     = torch.exp(0.5 * logvar)
        z       = mu + torch.randn_like(mu) * std if mode == 'train' else mu
        t_emb   = self.t_proj(t)
        z_and_t = torch.cat([z, t_emb], dim=-1)
        y_pred  = self.outcome_head(z_and_t)
        return y_pred, mu, logvar, z

    def counterfactual_prediction(self, X, t0, t1):
        self.eval()
        with torch.no_grad():
            p0, _, _, _ = self.forward(X, t0, mode='eval')
            p1, _, _, _ = self.forward(X, t1, mode='eval')
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, X, t):
        self.eval()
        with torch.no_grad():
            y_hat, _, _, _ = self.forward(X, t, mode='eval')
        return y_hat.squeeze(-1)

# ============================================================
# 4. BASELINE MODELS
# ============================================================

class R_CRN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn    = nn.GRU(4, HIDDEN_DIM, batch_first=True)
        # Treatment projection: same pattern as DCMVAE (1 → 16 dims)
        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())
        # Single outcome head conditioned on z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(HIDDEN_DIM + 16, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x, t, t_lag):
        z_seq, _ = self.rnn(torch.cat([x, t_lag], dim=2))
        t_emb    = self.t_proj(t)
        y_pred   = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))
        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        z, _  = self.rnn(torch.cat([x, t0_lag], dim=2))  # single z
        p0    = self.outcome_head(torch.cat([z, self.t_proj(t0)], dim=-1))
        p1    = self.outcome_head(torch.cat([z, self.t_proj(t1)], dim=-1))
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)


class CF_RNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn    = nn.GRU(4, HIDDEN_DIM, batch_first=True)
        # Treatment projection: same pattern as DCMVAE (1 → 16 dims)
        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())
        # Single outcome head conditioned on z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(HIDDEN_DIM + 16, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x, t, t_lag):
        z_seq, _ = self.rnn(torch.cat([x, t_lag], dim=2))
        t_emb    = self.t_proj(t)
        y_pred   = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))
        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        z, _  = self.rnn(torch.cat([x, t0_lag], dim=2))  # single z
        p0    = self.outcome_head(torch.cat([z, self.t_proj(t0)], dim=-1))
        p1    = self.outcome_head(torch.cat([z, self.t_proj(t1)], dim=-1))
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)


class TS_TARNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn    = nn.GRU(4, HIDDEN_DIM, batch_first=True)
        # Treatment projection: same pattern as DCMVAE (1 → 16 dims)
        self.t_proj = nn.Sequential(nn.Linear(1, 16), nn.ReLU())
        # Single outcome head conditioned on z + projected treatment
        self.outcome_head = nn.Sequential(
            nn.Linear(HIDDEN_DIM + 16, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, x, t, t_lag):
        z_seq, _ = self.rnn(torch.cat([x, t_lag], dim=2))
        t_emb    = self.t_proj(t)
        y_pred   = self.outcome_head(torch.cat([z_seq, t_emb], dim=-1))
        return y_pred, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        z, _  = self.rnn(torch.cat([x, t0_lag], dim=2))  # single z
        p0    = self.outcome_head(torch.cat([z, self.t_proj(t0)], dim=-1))
        p1    = self.outcome_head(torch.cat([z, self.t_proj(t1)], dim=-1))
        return p0.squeeze(-1), p1.squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)

# ============================================================
# 5. TRAINING
# ============================================================

def train_dcmvae(model, train_loader, t_median):
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=15
    )
    model.train()

    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0.0
        kl_scale   = min(1.0, epoch / 30.0)
        mmd_scale  = min(1.0, epoch / 50.0)

        for batch_idx, (x_in, t0, t1, t0_lag, t1_lag, y_fact) in enumerate(train_loader):
            x_in, t0, y_fact = x_in.to(DEVICE), t0.to(DEVICE), y_fact.to(DEVICE)
            optimizer.zero_grad()

            y_pred, mu, logvar, z = model(x_in, t0, mode='train')

            loss_recon = F.mse_loss(y_pred, y_fact)
            loss_kl    = torch.clamp(
                -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()),
                max=1.0
            )

            loss_mmd = torch.tensor(0.0, device=DEVICE)
            if model.use_mmd:
                z_flat = z.view(-1, LATENT_DIM)
                t_flat = (t0.view(-1) > t_median).float()
                z0_mmd = z_flat[t_flat == 0]
                z1_mmd = z_flat[t_flat == 1]
                if batch_idx == 0 and (epoch % 30 == 0 or epoch == NUM_EPOCHS - 1):
                    print(f"  [MMD] z0={z0_mmd.size(0)}, z1={z1_mmd.size(0)}")
                loss_mmd = compute_mmd_stable(z0_mmd, z1_mmd) * mmd_scale

            loss = (10.0 * loss_recon
                    + KL_WEIGHT * kl_scale * loss_kl
                    + MMD_WEIGHT * loss_mmd)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            epoch_loss += loss.item()

        avg = epoch_loss / len(train_loader)
        scheduler.step(avg)
        if epoch % 30 == 0 or epoch == NUM_EPOCHS - 1:
            print(f"  Epoch {epoch:03d} | loss={avg:.4f} | "
                  f"recon={loss_recon.item():.4f} "
                  f"kl={loss_kl.item():.4f} "
                  f"mmd={loss_mmd.item():.6f}")


def train_baseline(name, model, train_loader):
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5
    )
    model.train()
    t_median = 0.0

    for epoch in tqdm(range(100), desc=name):
        for x_in, t0, t1, t0_lag, t1_lag, y_fact in train_loader:
            x_in   = x_in.to(DEVICE)
            t0     = t0.to(DEVICE)
            t0_lag = t0_lag.to(DEVICE)
            y_fact = y_fact.to(DEVICE)
            optimizer.zero_grad()

            y_p, z_s = model(x_in[:, :, :3], t0, t0_lag)
            loss_recon = F.mse_loss(y_p, y_fact)
            loss_mmd   = torch.tensor(0.0, device=DEVICE)

            if name != 'TS-TARNet':
                zf = z_s.reshape(-1, HIDDEN_DIM)
                tf = (t0.view(-1) > t_median).float()
                z0, z1 = zf[tf == 0], zf[tf == 1]
                if z0.size(0) > 1 and z1.size(0) > 1:
                    loss_mmd = compute_mmd_stable(z0, z1) * MMD_WEIGHT

            loss = loss_recon + loss_mmd
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

# ============================================================
# 6. EVALUATION
# ============================================================

def calculate_factual_rmse(name, model, loader):
    model.eval()
    se, count = 0, 0
    with torch.no_grad():
        for x, t0, t1, t0l, t1l, yf, y0c, y1c in loader:
            yf = yf.to(DEVICE).squeeze(-1)
            if name == 'DCMVAE':
                yh = model.factual_prediction(x.to(DEVICE), t0.to(DEVICE))
            else:
                yh = model.factual_prediction(
                    x[:, :, :3].to(DEVICE), t0.to(DEVICE), t0l.to(DEVICE)
                )
            se    += torch.sum((yh - yf) ** 2).item()
            count += yf.numel()
    return math.sqrt(se / count)


def calculate_pehe_ate(name, model, loader):
    model.eval()
    ite_se, ate_err, count = 0, 0, 0
    with torch.no_grad():
        for x, t0, t1, t0l, t1l, yf, y0c, y1c in loader:
            y0t = y0c.to(DEVICE).squeeze(-1)
            y1t = y1c.to(DEVICE).squeeze(-1)

            if name == 'DCMVAE':
                y0h, y1h = model.counterfactual_prediction(
                    x.to(DEVICE), t0.to(DEVICE), t1.to(DEVICE)
                )
            else:
                y0h, y1h = model.counterfactual_prediction(
                    x[:, :, :3].to(DEVICE),
                    t0.to(DEVICE), t1.to(DEVICE),
                    t0l.to(DEVICE), t1l.to(DEVICE)
                )

            ite_h    = y1h - y0h
            ite_t    = y1t - y0t
            ite_se  += torch.sum((ite_h - ite_t) ** 2).item()
            ate_err += torch.sum(ite_h - ite_t).item()
            count   += y0t.numel()

    return math.sqrt(ite_se / count), abs(ate_err / count)

# ============================================================
# 7. BENCHMARK
# ============================================================

def run_benchmark():
    dm = IHDP_TimeSeries(CSV_FILE_PATH, BATCH_SIZE, SEQUENCE_LENGTH)
    train_loader, test_loader = dm.get_dataloaders()
    t_median = float(np.median(dm.t_factual))

    model_list = {
        'DCMVAE'   : DCMVAE(use_mmd=True).to(DEVICE),
        'R-CRN'    : R_CRN().to(DEVICE),
        'CF-RNN'   : CF_RNN().to(DEVICE),
        'TS-TARNet': TS_TARNet().to(DEVICE),
    }

    # Train all models
    for name, model in model_list.items():
        print(f"\n{'='*55}")
        print(f"Training: {name}")
        print(f"{'='*55}")
        if name == 'DCMVAE':
            train_dcmvae(model, train_loader, t_median)
        else:
            train_baseline(name, model, train_loader)

    # Evaluate all models
    print(f"\n{'='*70}")
    print(f"{'Model':<15} | {'Test RMSE':<12} | {'Test PEHE':<12} ")
    print("-" * 70)

    for name, model in model_list.items():
        rmse        = calculate_factual_rmse(name, model, test_loader)
        pehe, ate   = calculate_pehe_ate(name, model, test_loader)
        marker      = " ←" if name == 'DCMVAE' else ""
        print(f"{name:<15} | {rmse:.4f}       | {pehe:.4f}       ")

# ============================================================

if __name__ == '__main__':
    run_benchmark()


DATA DIAGNOSTICS
Num sequences:      54
Sequence length:    30
Feature dim:        5
Mean |ITE|:         2.4961
Std  ITE:           3.3482
% near-zero ITE:    12.3%
T0-T1 correlation:  0.9443

Training: DCMVAE
  [MMD] z0=489, z1=471
  Epoch 000 | loss=9.8575 | recon=0.9682 kl=0.0052 mmd=0.000000
  [MMD] z0=498, z1=462
  Epoch 030 | loss=10.0400 | recon=1.0503 kl=0.0404 mmd=0.000001
  [MMD] z0=551, z1=409
  Epoch 060 | loss=9.9559 | recon=1.0332 kl=0.0826 mmd=0.000001
  [MMD] z0=452, z1=508
  Epoch 090 | loss=9.8509 | recon=0.9970 kl=0.1271 mmd=0.000002
  [MMD] z0=458, z1=502
  Epoch 120 | loss=9.6686 | recon=0.9443 kl=0.1125 mmd=0.000001
  [MMD] z0=430, z1=530
  Epoch 149 | loss=9.8513 | recon=0.9987 kl=0.1230 mmd=0.000001

Training: R-CRN


R-CRN: 100%|██████████| 100/100 [00:17<00:00,  5.64it/s]



Training: CF-RNN


CF-RNN: 100%|██████████| 100/100 [00:12<00:00,  8.24it/s]



Training: TS-TARNet


TS-TARNet: 100%|██████████| 100/100 [00:04<00:00, 21.34it/s]


Model           | Test RMSE    | Test PEHE    
----------------------------------------------------------------------
DCMVAE          | 0.9892       | 3.0782       
R-CRN           | 1.0126       | 3.1418       
CF-RNN          | 1.0026       | 3.1274       
TS-TARNet       | 1.0120       | 3.1292       



