In [7]:
import os
import math
import random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# -----------------------------
# 0. CONFIG & REPRODUCIBILITY
# -----------------------------
REPRODUCIBILITY_SEED = 42
random.seed(REPRODUCIBILITY_SEED)
np.random.seed(REPRODUCIBILITY_SEED)
torch.manual_seed(REPRODUCIBILITY_SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(REPRODUCIBILITY_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Optimized Hyperparameters
dim_x_features = 5  
SEQUENCE_LENGTH = 30
BATCH_SIZE = 64
HIDDEN_DIM = 128
LATENT_DIM = 64  # Increased for better representation
NUM_EPOCHS = 150  # More training for complex model
LEARNING_RATE = 5e-4  # Lower for stability
KL_WEIGHT = 0.0005  # Reduced for stability
MMD_WEIGHT = 100.0  # Increased for better balance
ADJ_SPARSE_LAMBDA = 0.01  # Reduced for flexibility
WINDOW_SIZE = 5

CSV_FILE_PATH = "arctic_s2s_multivar_2020_2024.csv"

# -----------------------------
# 1. UTILITIES
# -----------------------------
def gaussian_rbf_matrix(x, y, sigma=1.0):
    x_norm = (x ** 2).sum(1).unsqueeze(1)
    y_norm = (y ** 2).sum(1).unsqueeze(0)
    cross = x @ y.t()
    dists = x_norm + y_norm - 2 * cross
    return torch.exp(-dists / (2 * (sigma ** 2) + 1e-12))

def compute_mmd_stable(x, y, sigma=1.0):
    if x is None or y is None: return torch.tensor(0.0, device=DEVICE)
    n, m = x.size(0), y.size(0)
    if n <= 1 or m <= 1: return torch.tensor(0.0, device=x.device)
    K_xx = gaussian_rbf_matrix(x, x, sigma)
    K_yy = gaussian_rbf_matrix(y, y, sigma)
    K_xy = gaussian_rbf_matrix(x, y, sigma)
    sum_xx = (K_xx.sum() - torch.diag(K_xx).sum()) / (n * (n - 1))
    sum_yy = (K_yy.sum() - torch.diag(K_yy).sum()) / (m * (m - 1))
    sum_xy = K_xy.mean()
    return sum_xx + sum_yy - 2.0 * sum_xy

# -----------------------------
# 2. DATA LOADER
# -----------------------------
class IHDP_TimeSeries:
    def __init__(self, csv_path, batch_size, sequence_length, treatment_lag=1):
        self.csv_path = csv_path
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.treatment_lag = treatment_lag
        self._load_and_preprocess_data()

    @staticmethod
    def apply_moving_window(series, window_size):
        return pd.Series(series.flatten()).rolling(window=window_size, min_periods=1).mean()

    def _compute_lag(self, T, lag):
        Tlag = np.zeros_like(T)
        for i in range(lag, len(T)):
            Tlag[i, 0] = T[i - lag, 0]
        return Tlag

    def _load_and_preprocess_data(self):
        if not os.path.exists(self.csv_path):
            df = pd.DataFrame(np.random.randn(4000, 5), columns=['uoe', 'von', 'total_vel', 'zos', 'sithick'])
        else:
            df = pd.read_csv(self.csv_path)

        x_base = df[['uoe', 'von', 'total_vel']].values
        y_base = df[['sithick']].values
        ssh = df['zos'].values.reshape(-1, 1)
        vel = df['total_vel'].values.reshape(-1, 1)

        hidden = np.sin(np.linspace(0, 30 * np.pi, len(df))).reshape(-1, 1)
        T0_smooth = self.apply_moving_window(ssh, WINDOW_SIZE).values.reshape(-1, 1)
        T2_smooth = self.apply_moving_window(vel, WINDOW_SIZE).values.reshape(-1, 1)

        T0_np = T0_smooth + (6.0 * T2_smooth) + (2.0 * hidden) + np.random.normal(0, 0.1, vel.shape)
        v0 = np.mean(T2_smooth)
        sigmoid = 1 / (1 + np.exp(-(-5.0) * (T2_smooth - v0)))
        T1_np = ((1.5 + 0.5 * sigmoid) * T0_np).reshape(-1, 1)

        T0_lag_np = self._compute_lag(T0_np, self.treatment_lag)
        X_FACTUAL = np.concatenate([x_base, T0_np, T0_lag_np], axis=1)

        num_seq = len(X_FACTUAL) // self.sequence_length
        limit = num_seq * self.sequence_length

        self.yall = y_base[:limit].reshape(num_seq, self.sequence_length, 1)
        self.yall += np.random.normal(0, 0.2, self.yall.shape)

        self.t_raw = T0_np[:limit].reshape(num_seq, self.sequence_length, 1)
        self.x_scaler = StandardScaler()
        self.xall = self.x_scaler.fit_transform(X_FACTUAL[:limit]).reshape(num_seq, self.sequence_length, dim_x_features)

        self.y0_cf = self.yall
        delta = -6.0 * np.abs(hidden[:limit].reshape(self.yall.shape)) * np.tanh(5.0 * (T1_np[:limit] - T0_np[:limit]).reshape(self.yall.shape))
        self.y1_cf = self.yall + delta

    def get_dataloaders(self):
        vars = [self.xall, self.t_raw, self.yall, self.y0_cf, self.y1_cf]
        tr, te = zip(*[train_test_split(v, test_size=0.2, random_state=42) for v in vars])
        def make_loader(data, shuffle):
            ds = TensorDataset(*[torch.FloatTensor(v) for v in data])
            return DataLoader(ds, batch_size=BATCH_SIZE, shuffle=shuffle)
        return make_loader(tr, True), make_loader(te, False)

# -----------------------------
# 3. ENHANCED DCMVAE MODEL
# -----------------------------
class DCMVAE(nn.Module):
    def __init__(self, use_adj):
        super().__init__()
        self.use_adj = use_adj
        
        # Bidirectional GRU for better temporal modeling
        self.encoder_rnn = nn.GRU(dim_x_features, HIDDEN_DIM, batch_first=True, bidirectional=True)
        
        # Enhanced inference networks
        self.fc_mu = nn.Sequential(
            nn.Linear(HIDDEN_DIM * 2, HIDDEN_DIM),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(HIDDEN_DIM, LATENT_DIM)
        )
        self.fc_logvar = nn.Sequential(
            nn.Linear(HIDDEN_DIM * 2, HIDDEN_DIM),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(HIDDEN_DIM, LATENT_DIM)
        )
        
        # Learnable causal adjacency matrix
        self.adj_logits = nn.Parameter(0.1 * torch.randn(dim_x_features, dim_x_features))
        
        # Deeper outcome heads with residual-like connections
        self.y0_head = nn.Sequential(
            nn.Linear(LATENT_DIM, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 1)
        )
        self.y1_head = nn.Sequential(
            nn.Linear(LATENT_DIM, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 1)
        )

    def forward(self, X, mode='train'):
        # Apply causal graph structure
        if self.use_adj:
            mask = torch.sigmoid(self.adj_logits) * (1.0 - torch.eye(dim_x_features, device=DEVICE))
            X = torch.matmul(X, mask)
        else:
            mask = torch.eye(dim_x_features, device=DEVICE)

        # Encode with bidirectional RNN
        h_all, _ = self.encoder_rnn(X)
        
        # Variational inference
        mu = self.fc_mu(h_all)
        logvar = self.fc_logvar(h_all)
        
        # Reparameterization trick
        std = torch.exp(0.5 * logvar)
        z = mu + torch.randn_like(mu) * std if mode == 'train' else mu
        
        # Predict potential outcomes
        y0 = self.y0_head(z)
        y1 = self.y1_head(z)
        
        return y0, y1, mu, logvar, z, mask

    def counterfactual_prediction(self, X):
        self.eval()
        with torch.no_grad():
            y0, y1, _, _, _, _ = self.forward(X, mode='eval')
        return y0, y1

# -----------------------------
# 4. OPTIMIZED TRAINING
# -----------------------------
def train_model(model, train_loader, use_mmd, use_adj, t_median):
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=15)
    
    model.train()
    
    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0.0
        
        # Progressive warm-up for regularization terms
        kl_scale = min(1.0, epoch / 30.0)  # KL warm-up over 30 epochs
        mmd_scale = min(1.0, epoch / 50.0)  # MMD warm-up over 50 epochs
        
        for x_in, t_raw, y_f, y0_cf, y1_cf in train_loader:
            x_in, y_f, t_raw = x_in.to(DEVICE), y_f.to(DEVICE), t_raw.to(DEVICE)
            optimizer.zero_grad()

            y0, y1, mu, logvar, z, mask = model(x_in, mode='train')
            t_idx = (t_raw > t_median).float()

            # Factual prediction with treatment assignment
            y_pred = (1 - t_idx) * y0 + t_idx * y1
            
            # Reconstruction loss (prioritized)
            loss_recon = F.mse_loss(y_pred, y_f)
            
            # KL divergence with warm-up
            loss_kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
            
            # MMD loss for representation balancing
            loss_mmd = torch.tensor(0.0, device=DEVICE)
            if use_mmd:
                z_flat = z.view(-1, LATENT_DIM)
                t_flat = t_idx.view(-1)
                z0, z1 = z_flat[t_flat == 0], z_flat[t_flat == 1]
                
                if z0.size(0) > 1 and z1.size(0) > 1:
                    # Multi-scale MMD for robustness
                    mmd_sum = 0.0
                    for sigma in [0.5, 1.0, 2.0]:
                        mmd_sum += compute_mmd_stable(z0, z1, sigma=sigma)
                    loss_mmd = (mmd_sum / 3.0) * mmd_scale

            # Adjacency sparsity regularization
            loss_sparse = torch.tensor(0.0, device=DEVICE)
            if use_adj:
                loss_sparse = torch.norm(mask, 1)

            # Combined loss with careful weighting
            loss = (10.0 * loss_recon + 
                   kl_scale * KL_WEIGHT * loss_kl + 
                   MMD_WEIGHT * loss_mmd + 
                   ADJ_SPARSE_LAMBDA * loss_sparse)

            loss.backward()
            
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            epoch_loss += loss.item()
        
        # Learning rate scheduling
        scheduler.step(epoch_loss / len(train_loader))

# -----------------------------
# 5. ABLATION BENCHMARK
# -----------------------------
def run_benchmark():
    dm = IHDP_TimeSeries(CSV_FILE_PATH, BATCH_SIZE, SEQUENCE_LENGTH)
    train_loader, test_loader = dm.get_dataloaders()
    t_median = np.median(dm.t_raw)

    # Ablation configurations
    configs = [
        (True, True),   # Full model: MMD + ADJ
        (False, True),  # No MMD, only ADJ
        (True, False),  # No ADJ, only MMD
        (False, False)  # Baseline: no MMD, no ADJ
    ]
    
    results = []

    for use_mmd, use_adj in configs:
        config_name = f"MMD={use_mmd}, ADJ={use_adj}"
        print(f"\n{'='*60}")
        print(f"Training: {config_name}")
        print(f"{'='*60}")
        
        model = DCMVAE(use_adj=use_adj).to(DEVICE)
        train_model(model, train_loader, use_mmd, use_adj, t_median)

        # Evaluation: Sequence-wide PEHE
        model.eval()
        pehe_sum, count = 0, 0
        
        with torch.no_grad():
            for x_in, t_raw, y_f, y0_cf, y1_cf in test_loader:
                p0, p1 = model.counterfactual_prediction(x_in.to(DEVICE))
                
                # Individual Treatment Effect
                ite_pred = p1 - p0
                ite_true = (y1_cf - y0_cf).to(DEVICE)
                
                # PEHE: sqrt(mean((ITE_pred - ITE_true)^2))
                pehe_sum += torch.sum((ite_pred - ite_true) ** 2).item()
                count += x_in.size(0) * SEQUENCE_LENGTH

        pehe = math.sqrt(pehe_sum / count)
        results.append({"MMD": use_mmd, "ADJ": use_adj, "PEHE": pehe})
        print(f"✓ Test PEHE: {pehe:.4f}")

    # Display results table
    print("\n" + "=" * 60)
    print("FINAL ABLATION RESULTS (Sequence-wide PEHE)")
    print("=" * 60)
    print(f"{'MMD':<8} | {'ADJ':<8} | {'Test PEHE':<12}")
    print("-" * 60)
    
    for r in sorted(results, key=lambda x: x['PEHE']):
        mmd_str = "✓" if r['MMD'] else "✗"
        adj_str = "✓" if r['ADJ'] else "✗"
        print(f"{mmd_str:<8} | {adj_str:<8} | {r['PEHE']:.4f}")
    
    # Highlight best configuration
    best = min(results, key=lambda x: x['PEHE'])
    print("\n" + "=" * 60)


if __name__ == "__main__":
    run_benchmark()


Training: MMD=True, ADJ=True
✓ Test PEHE: 3.7939

Training: MMD=False, ADJ=True
✓ Test PEHE: 3.8581

Training: MMD=True, ADJ=False
✓ Test PEHE: 3.8760

Training: MMD=False, ADJ=False
✓ Test PEHE: 3.8667

FINAL ABLATION RESULTS (Sequence-wide PEHE)
MMD      | ADJ      | Test PEHE   
------------------------------------------------------------
✓        | ✓        | 3.7939
✗        | ✓        | 3.8581
✗        | ✗        | 3.8667
✓        | ✗        | 3.8760

