In [9]:
import os
import math
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# -----------------------------
# 0. CONFIG & REPRODUCIBILITY
# -----------------------------
REPRODUCIBILITY_SEED = 42
random.seed(REPRODUCIBILITY_SEED)
np.random.seed(REPRODUCIBILITY_SEED)
torch.manual_seed(REPRODUCIBILITY_SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(REPRODUCIBILITY_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Optimized Hyperparameters for DCMVAE
dim_x_features = 5  
SEQUENCE_LENGTH = 30
BATCH_SIZE = 64
HIDDEN_DIM = 128
LATENT_DIM = 64  # Increased for better representation
NUM_EPOCHS = 150  # More epochs for DCMVAE
LEARNING_RATE = 5e-3
KL_WEIGHT = 0.0005  # Reduced for stability
MMD_WEIGHT = 100.0  # Increased for better balance
ADJ_SPARSE_LAMBDA = 0.01  # Reduced for less aggressive sparsity
WINDOW_SIZE = 5

CSV_FILE_PATH = "arctic_s2s_multivar_2020_2024.csv"

# -----------------------------
# 1. UTILITIES
# -----------------------------
def gaussian_rbf_matrix(x, y, sigma=1.0):
    x_norm = (x ** 2).sum(1).unsqueeze(1)
    y_norm = (y ** 2).sum(1).unsqueeze(0)
    cross = x @ y.t()
    dists = x_norm + y_norm - 2 * cross
    return torch.exp(-dists / (2 * (sigma ** 2) + 1e-12))

def compute_mmd_stable(x, y, sigma=1.0):
    if x is None or y is None: return torch.tensor(0.0, device=DEVICE)
    n, m = x.size(0), y.size(0)
    if n <= 1 or m <= 1: return torch.tensor(0.0, device=x.device)
    K_xx = gaussian_rbf_matrix(x, x, sigma)
    K_yy = gaussian_rbf_matrix(y, y, sigma)
    K_xy = gaussian_rbf_matrix(x, y, sigma)
    sum_xx = (K_xx.sum() - torch.diag(K_xx).sum()) / (n * (n - 1))
    sum_yy = (K_yy.sum() - torch.diag(K_yy).sum()) / (m * (m - 1))
    sum_xy = K_xy.mean()
    return sum_xx + sum_yy - 2.0 * sum_xy

# -----------------------------
# 2. DATA LOADER
# -----------------------------
class IHDP_TimeSeries:
    def __init__(self, csv_path, batch_size, sequence_length, treatment_lag=9):
        self.csv_path = csv_path
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.treatment_lag = treatment_lag
        self._load_and_preprocess_data()

    @staticmethod
    def apply_moving_window(series, window_size):
        return pd.Series(series.flatten()).rolling(window=window_size, min_periods=1).mean()

    def _compute_lag(self, T, lag):
        Tlag = np.zeros_like(T)
        for i in range(lag, len(T)):
            Tlag[i, 0] = T[i - lag, 0]
        return Tlag

    def _load_and_preprocess_data(self):
        if not os.path.exists(self.csv_path):
            df = pd.DataFrame(np.random.randn(4000, 5), columns=['uoe', 'von', 'total_vel', 'zos', 'sithick'])
        else:
            df = pd.read_csv(self.csv_path)

        x_base = df[['uoe', 'von', 'total_vel']].values
        y_base = df[['sithick']].values
        ssh = df['zos'].values.reshape(-1, 1)
        vel = df['total_vel'].values.reshape(-1, 1)

        hidden = np.sin(np.linspace(0, 30 * np.pi, len(df))).reshape(-1, 1)
        T0_smooth = self.apply_moving_window(ssh, WINDOW_SIZE).values.reshape(-1, 1)
        T2_smooth = self.apply_moving_window(vel, WINDOW_SIZE).values.reshape(-1, 1)

        T0_np = T0_smooth + (6.0 * T2_smooth) + (2.0 * hidden) + np.random.normal(0, 0.1, vel.shape)
        v0 = np.mean(T2_smooth)
        sigmoid = 1 / (1 + np.exp(-(-5.0) * (T2_smooth - v0)))
        T1_np = ((1.5 + 0.5 * sigmoid) * T0_np).reshape(-1, 1)

        T0_lag_np = self._compute_lag(T0_np, self.treatment_lag)
        T1_lag_np = self._compute_lag(T1_np, self.treatment_lag)
        
        X_FACTUAL = np.concatenate([x_base, T0_np, T0_lag_np], axis=1)

        num_seq = len(X_FACTUAL) // self.sequence_length
        limit = num_seq * self.sequence_length

        self.yall = y_base[:limit].reshape(num_seq, self.sequence_length, 1)
        self.yall += np.random.normal(0, 0.2, self.yall.shape)

        self.t0_raw = T0_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t1_raw = T1_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t0_lag = T0_lag_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.t1_lag = T1_lag_np[:limit].reshape(num_seq, self.sequence_length, 1)
        
        self.x_scaler = StandardScaler()
        self.xall = self.x_scaler.fit_transform(X_FACTUAL[:limit]).reshape(num_seq, self.sequence_length, dim_x_features)

        self.y0_cf = self.yall
        delta = -6.0 * np.abs(hidden[:limit].reshape(self.yall.shape)) * np.tanh(5.0 * (self.t1_raw - self.t0_raw))
        self.y1_cf = self.yall + delta

    def get_dataloaders(self):
        vars = [self.xall, self.t0_raw, self.t1_raw, self.t0_lag, self.t1_lag, self.yall, self.y0_cf, self.y1_cf]
        tr, te = zip(*[train_test_split(v, test_size=0.2, random_state=42) for v in vars])
        def make_loader(data, shuffle):
            ds = TensorDataset(*[torch.FloatTensor(v) for v in data])
            return DataLoader(ds, batch_size=BATCH_SIZE, shuffle=shuffle)
        return make_loader(tr, True), make_loader(te, False)

# -----------------------------
# 3. ENHANCED DCMVAE MODEL
# -----------------------------
class DCMVAE(nn.Module):
    def __init__(self, use_adj):
        super().__init__()
        self.use_adj = use_adj
        
        # Bidirectional GRU for better temporal modeling
        self.encoder_rnn = nn.GRU(dim_x_features, HIDDEN_DIM, batch_first=True, bidirectional=True)
        
        # Inference networks (increased capacity)
        self.fc_mu = nn.Sequential(
            nn.Linear(HIDDEN_DIM * 2, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, LATENT_DIM)
        )
        self.fc_logvar = nn.Sequential(
            nn.Linear(HIDDEN_DIM * 2, HIDDEN_DIM),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM, LATENT_DIM)
        )
        
        # Learnable adjacency matrix
        self.adj_logits = nn.Parameter(0.1 * torch.randn(dim_x_features, dim_x_features))
        
        # Separate outcome heads with deeper networks
        self.y0_head = nn.Sequential(
            nn.Linear(LATENT_DIM, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        self.y1_head = nn.Sequential(
            nn.Linear(LATENT_DIM, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, X, mode='train'):
        # Apply causal graph structure
        if self.use_adj:
            mask = torch.sigmoid(self.adj_logits) * (1.0 - torch.eye(dim_x_features, device=DEVICE))
            X = torch.matmul(X, mask)
        else:
            mask = torch.eye(dim_x_features, device=DEVICE)
        
        # Encode with bidirectional RNN
        h_all, _ = self.encoder_rnn(X)
        
        # Variational inference
        mu = self.fc_mu(h_all)
        logvar = self.fc_logvar(h_all)
        
        # Reparameterization trick
        std = torch.exp(0.5 * logvar)
        z = mu + torch.randn_like(mu) * std if mode == 'train' else mu
        
        # Predict outcomes
        y0 = self.y0_head(z)
        y1 = self.y1_head(z)
        
        return y0, y1, mu, logvar, z, mask

    def counterfactual_prediction(self, X0, X1):
        self.eval()
        with torch.no_grad():
            y0, _, _, _, _, _ = self.forward(X0, mode='eval')
            _, y1, _, _, _, _ = self.forward(X1, mode='eval')
        return y0.squeeze(-1), y1.squeeze(-1)

    def factual_prediction(self, X):
        self.eval()
        with torch.no_grad():
            y0, y1, _, _, _, _ = self.forward(X, mode='eval')
            t_val = X[:, :, 3:4]
            y_hat = torch.where(t_val > 0, y1, y0)
        return y_hat.squeeze(-1)

# -----------------------------
# 4. BASELINE MODELS
# -----------------------------
class R_CRN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.GRU(4, HIDDEN_DIM, batch_first=True)
        self.y0_out = nn.Sequential(nn.Linear(HIDDEN_DIM, 64), nn.ReLU(), nn.Linear(64, 1))
        self.ite_out = nn.Sequential(nn.Linear(HIDDEN_DIM, 64), nn.ReLU(), nn.Linear(64, 1))

    def forward(self, x, t, t_lag):
        z_seq, _ = self.rnn(torch.cat([x, t_lag], dim=2))
        y0_h = self.y0_out(z_seq)
        ite_h = self.ite_out(z_seq)
        return y0_h + t * ite_h, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        z0, _ = self.rnn(torch.cat([x, t0_lag], dim=2))
        y0_h = self.y0_out(z0)
        ite_h = self.ite_out(z0)
        return y0_h.squeeze(-1), (y0_h + ite_h).squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)

class CF_RNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.GRU(4, HIDDEN_DIM, batch_first=True)
        self.y0_head = nn.Sequential(nn.Linear(HIDDEN_DIM, 64), nn.ReLU(), nn.Linear(64, 1))
        self.y1_head = nn.Sequential(nn.Linear(HIDDEN_DIM, 64), nn.ReLU(), nn.Linear(64, 1))

    def forward(self, x, t, t_lag):
        z_seq, _ = self.rnn(torch.cat([x, t_lag], dim=2))
        y0_h, y1_h = self.y0_head(z_seq), self.y1_head(z_seq)
        return (1 - t) * y0_h + t * y1_h, z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        z0, _ = self.rnn(torch.cat([x, t0_lag], dim=2))
        return self.y0_head(z0).squeeze(-1), self.y1_head(z0).squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)

class TS_TARNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.GRU(4, HIDDEN_DIM, batch_first=True)
        self.y0_head = nn.Sequential(nn.Linear(HIDDEN_DIM, 64), nn.ReLU(), nn.Linear(64, 1))
        self.y1_head = nn.Sequential(nn.Linear(HIDDEN_DIM, 64), nn.ReLU(), nn.Linear(64, 1))

    def forward(self, x, t, t_lag):
        z_seq, _ = self.rnn(torch.cat([x, t_lag], dim=2))
        return (1 - t) * self.y0_head(z_seq) + t * self.y1_head(z_seq), z_seq

    def counterfactual_prediction(self, x, t0, t1, t0_lag, t1_lag):
        z0, _ = self.rnn(torch.cat([x, t0_lag], dim=2))
        return self.y0_head(z0).squeeze(-1), self.y1_head(z0).squeeze(-1)

    def factual_prediction(self, x, t, t_lag):
        y_h, _ = self.forward(x, t, t_lag)
        return y_h.squeeze(-1)

# -----------------------------
# 5. OPTIMIZED TRAINING
# -----------------------------
def train_model(name, model, train_loader, device):
    # DCMVAE gets special treatment
    if name == 'DCMVAE':
        lr = 3e-4
        num_epochs = NUM_EPOCHS
    else:
        lr = LEARNING_RATE
        num_epochs = 100
    
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=10)
    
    model.train()
    t_median = 0.0
    
    for epoch in tqdm(range(num_epochs), desc=name):
        epoch_loss = 0.0
        
        # Progressive weighting for DCMVAE
        if name == 'DCMVAE':
            # Warm-up phases
            kl_scale = min(1.0, epoch / 30.0)  # KL warm-up
            mmd_scale = min(1.0, epoch / 50.0)  # MMD warm-up
        
        for x_in, t0, t1, t0_lag, t1_lag, y_out, y0_cf, y1_cf in train_loader:
            opt.zero_grad()
            x_in, y_f, t0, t0l = x_in.to(device), y_out.to(device), t0.to(device), t0_lag.to(device)
            
            if name == 'DCMVAE':
                y0_p, y1_p, mu, logvar, z, mask = model(x_in)
                
                # Treatment assignment (use actual treatment from data)
                t_idx = (t0 > t0.median()).float()
                y_pred = (1 - t_idx) * y0_p + t_idx * y1_p
                
                # Multi-objective loss
                loss_recon = F.mse_loss(y_pred, y_f)
                loss_kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
                
                # Improved representation balancing
                zf = z.reshape(-1, LATENT_DIM)
                tf = t_idx.reshape(-1)
                z0, z1 = zf[tf == 0], zf[tf == 1]
                
                if z0.size(0) > 1 and z1.size(0) > 1:
                    # Multi-scale MMD for robustness
                    loss_mmd = 0.0
                    for sigma in [0.5, 1.0, 2.0]:
                        loss_mmd += compute_mmd_stable(z0, z1, sigma=sigma)
                    loss_mmd = loss_mmd / 3.0
                else:
                    loss_mmd = torch.tensor(0.0, device=device)
                
                # Adjacency sparsity
                loss_adj = torch.norm(torch.sigmoid(model.adj_logits), 1)
                
                # Balanced combination with progressive scaling
                loss = (10.0 * loss_recon + 
                       kl_scale * KL_WEIGHT * loss_kl + 
                       mmd_scale * MMD_WEIGHT * loss_mmd + 
                       ADJ_SPARSE_LAMBDA * loss_adj)
                
            else:
                # Baseline models
                y_p, z_s = model(x_in[:, :, :3], t0, t0l)
                loss_recon = F.mse_loss(y_p, y_f)
                loss_mmd = torch.tensor(0.0, device=device)
                
                if name != 'TS-TARNet':
                    zf = z_s.reshape(-1, HIDDEN_DIM)
                    tf = (t0 > t_median).reshape(-1)
                    z0, z1 = zf[tf == 0], zf[tf == 1]
                    if z0.size(0) > 1 and z1.size(0) > 1:
                        loss_mmd = compute_mmd_stable(z0, z1) * MMD_WEIGHT
                
                loss = loss_recon + loss_mmd
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
            epoch_loss += loss.item()
        
        # Learning rate scheduling for DCMVAE
        if name == 'DCMVAE':
            scheduler.step(epoch_loss / len(train_loader))

# -----------------------------
# 6. EVALUATION
# -----------------------------
def calculate_factual_rmse(model, loader, device, model_type):
    model.eval()
    se, count = 0, 0
    with torch.no_grad():
        for x, t0, t1, t0l, t1l, yf, y0c, y1c in loader:
            yf = yf.to(device).squeeze(-1)
            if model_type == 'DCMVAE':
                yh = model.factual_prediction(x.to(device))
            else:
                yh = model.factual_prediction(x[:, :, :3].to(device), t0.to(device), t0l.to(device))
            se += torch.sum((yh - yf) ** 2).item()
            count += yf.numel()
    return math.sqrt(se / count)

def calculate_pehe_ate(model, loader, device, model_type):
    model.eval()
    ite_se, ate_err, count = 0, 0, 0
    with torch.no_grad():
        for x, t0, t1, t0l, t1l, yf, y0c, y1c in loader:
            y0t, y1t = y0c.to(device).squeeze(-1), y1c.to(device).squeeze(-1)
            
            if model_type == 'DCMVAE':
                x0, x1 = x.clone().to(device), x.clone().to(device)
                x0[:, :, 3], x0[:, :, 4] = t0.squeeze(-1), t0l.squeeze(-1)
                x1[:, :, 3], x1[:, :, 4] = t1.squeeze(-1), t1l.squeeze(-1)
                y0h, y1h = model.counterfactual_prediction(x0, x1)
            else:
                y0h, y1h = model.counterfactual_prediction(
                    x[:, :, :3].to(device), t0.to(device), t1.to(device), 
                    t0l.to(device), t1l.to(device)
                )
            
            ite_h, ite_t = (y1h - y0h), (y1t - y0t)
            ite_se += torch.sum((ite_h - ite_t) ** 2).item()
            ate_err += torch.sum(ite_h - ite_t).item()
            count += y0t.numel()
    
    return math.sqrt(ite_se / count), abs(ate_err / count)

# -----------------------------
# 7. BENCHMARK
# -----------------------------
def run_benchmark():
    dm = IHDP_TimeSeries(CSV_FILE_PATH, BATCH_SIZE, SEQUENCE_LENGTH)
    train_loader, test_loader = dm.get_dataloaders()
    
    model_list = {
        'DCMVAE': DCMVAE(use_adj=True).to(DEVICE),
        'R-CRN': R_CRN().to(DEVICE),
        'CF-RNN': CF_RNN().to(DEVICE),
        'TS-TARNet': TS_TARNet().to(DEVICE)
    }
    
    for name, model in model_list.items():
        train_model(name, model, train_loader, DEVICE)
    
    print("\n" + "=" * 70)
    print(f"{'Model':<15} | {'Test RMSE':<12} | {'Test PEHE':<12} | {'Test ATE':<12}")
    print("-" * 70)
    
    for name, model in model_list.items():
        rmse = calculate_factual_rmse(model, test_loader, DEVICE, name)
        pehe, ate = calculate_pehe_ate(model, test_loader, DEVICE, name)
        print(f"{name:<15} | {rmse:.4f}       | {pehe:.4f}       | {ate:.4f}")

if __name__ == '__main__':
    run_benchmark()

DCMVAE: 100%|██████████| 150/150 [00:01<00:00, 142.12it/s]
R-CRN: 100%|██████████| 100/100 [00:00<00:00, 276.78it/s]
CF-RNN: 100%|██████████| 100/100 [00:00<00:00, 273.32it/s]
TS-TARNet: 100%|██████████| 100/100 [00:00<00:00, 434.26it/s]


Model           | Test RMSE    | Test PEHE    | Test ATE    
----------------------------------------------------------------------
DCMVAE          | 0.3206       | 3.8390       | 0.9700
R-CRN           | 0.2079       | 3.8548       | 0.9160
CF-RNN          | 0.2242       | 3.8663       | 0.9597
TS-TARNet       | 0.2031       | 3.8582       | 0.9279



