# In [None]:
# train_fast.py
import os, math, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from lifelines.utils import concordance_index

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
  DEVICE = "cuda"
else:
  DEVICE = "cpu"

# Seed every RNG the script draws from (torch, numpy, Python's random).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# -----------------------------
# 1) Load the train / val / test CSV splits exported from R
# -----------------------------
train, val, test = (
  pd.read_csv(csv_path)
  for csv_path in ("data/train.csv", "data/val.csv", "data/test.csv")
)

FEATS = ["age", "sbp"]  # covariate columns fed to the survival model
def to_tensor(df):
  """Convert one data split into float32 tensors on DEVICE.

  Returns (x, a, t, e):
    x: the FEATS columns, shape (N, len(FEATS))
    a: column "A" reshaped to (N, 1) (the attribute the fairness metrics split on)
    t: column "Y" (the time passed to the Cox loss), shape (N,)
    e: column "Delta" (the event indicator passed to the Cox loss), shape (N,)
  """
  def _as_float_tensor(values):
    return torch.tensor(values, dtype=torch.float32, device=DEVICE)

  covariates = _as_float_tensor(df[FEATS].values)
  group = _as_float_tensor(df["A"].values.reshape(-1, 1))
  times = _as_float_tensor(df["Y"].values)
  events = _as_float_tensor(df["Delta"].values)
  return covariates, group, times, events

# Tensors for each split, in train / val / test order.
(x_tr, a_tr, t_tr, e_tr), (x_va, a_va, t_va, e_va), (x_te, a_te, t_te, e_te) = (
  to_tensor(split) for split in (train, val, test)
)

# -----------------------------
# 2) Baseline model (DeepSurv style: MLP -> risk score)
# -----------------------------
class DeepSurv(nn.Module):
  """MLP mapping covariates to one risk score (log hazard ratio) per row.

  Architecture: in_dim -> 32 -> 16 -> 1 with ReLU between linear layers.
  """
  def __init__(self, in_dim):
    super().__init__()
    widths = [in_dim, 32, 16]
    layers = []
    for width_in, width_out in zip(widths[:-1], widths[1:]):
      layers += [nn.Linear(width_in, width_out), nn.ReLU()]
    layers.append(nn.Linear(widths[-1], 1))  # risk score (log hazard ratio)
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    """Return risk scores of shape (N, 1)."""
    return self.net(x)

# -----------------------------
# 3) Cox partial log-likelihood (full-batch version, numerically stable)
#    Formula: sum_{i: e_i=1} (r_i - logsumexp(r_j, j in R_i)),  R_i = {j : t_j >= t_i}
#    Implemented by sorting time in descending order and taking a forward
#    log-cumulative-sum-exp, with Breslow handling of tied times.
# -----------------------------
def cox_ph_loss(risk, time, event, eps=1e-8):
  """Negative Cox partial log-likelihood (summed over events, Breslow ties).

  Args:
    risk: predicted log hazard ratios, shape (N, 1) or (N,).
    time: observed times, shape (N,).
    event: event indicator (1 = event, 0 = censored), shape (N,); it is used
      as a weight, so censored rows contribute no term of their own.
    eps: unused; kept so existing callers passing it keep working (the
      log-sum-exp is now computed stably without an additive epsilon).

  Returns:
    0-dim tensor: -sum_i e_i * (r_i - log sum_{j: t_j >= t_i} exp(r_j)).
    Note the sum (not mean) scales with the number of events.
  """
  risk = risk.reshape(-1)
  # Descending time: the risk set of position i is every position 0..i.
  order = torch.argsort(time, descending=True)
  t_ord = time[order]
  e_ord = event[order]
  r_ord = risk[order]
  # Forward running logsumexp = log sum_{j <= i} exp(r_j); stable for large r.
  log_cum = torch.logcumsumexp(r_ord, dim=0)
  # Breslow ties: every member of a tied group shares the risk set that ends at
  # the group's last sorted position (so it includes all tied subjects).
  _, counts = torch.unique_consecutive(t_ord, return_counts=True)
  last_idx = torch.cumsum(counts, dim=0) - 1
  log_riskset = torch.repeat_interleave(log_cum[last_idx], counts)
  ll = torch.sum(e_ord * (r_ord - log_riskset))
  return -ll  # minimize

# -----------------------------
# 4) MINE mutual-information estimator (Donsker-Varadhan lower bound + moving average)
#    Inputs are (A, Z) pairs; Z can be the risk score (or a hidden-layer representation).
# -----------------------------
class MINE(nn.Module):
  """Statistics network T(a, z) for a MINE estimate of I(A; Z).

  forward() returns the Donsker-Varadhan (DV) estimate
      E_joint[T] - log E_marginal[e^T]
  on the given batch. Its *gradient*, however, follows the bias-corrected
  MINE rule: in grad log E[e^T] = E[grad(T) e^T] / E[e^T] the denominator is
  replaced by an exponential moving average of E[e^T] over recent batches.

  Args:
    in_dim: width of the concatenated (a, z) input.
    hidden: width of the two hidden layers.
    ma_rate: update rate of the moving average of E[e^T].
  """
  def __init__(self, in_dim=2, hidden=64, ma_rate=0.01):
    super().__init__()
    self.tnet = nn.Sequential(
      nn.Linear(in_dim, hidden), nn.ReLU(),
      nn.Linear(hidden, hidden), nn.ReLU(),
      nn.Linear(hidden, 1)
    )
    self.ma_rate = ma_rate
    # Moving average of E[e^T] (detached tensor); plain attribute, so it is
    # not saved in state_dict and does not follow .to().
    self.ma_et = None

  def forward(self, joint, marginal):
    """Estimate I(A; Z) from joint and marginal (a, z) samples.

    Side effect: updates self.ma_et with this batch's E[e^T].

    Returns:
      (mi, mean_t, mean_et): mi is a 0-dim tensor whose value is the DV
      estimate and whose gradient is the bias-corrected MINE gradient;
      mean_t and mean_et are Python floats for logging.
    """
    t_joint = self.tnet(joint)          # (N,1)
    t_marg  = self.tnet(marginal)       # (N,1)
    mean_t = torch.mean(t_joint)
    mean_et = torch.mean(torch.exp(t_marg))
    # log E[e^T] via logsumexp so the reported value cannot overflow.
    log_mean_et = torch.logsumexp(t_marg.reshape(-1), dim=0) - math.log(t_marg.numel())

    batch_et = mean_et.detach()
    if self.ma_et is None:
      self.ma_et = batch_et
    else:
      self.ma_et = (1 - self.ma_rate) * self.ma_et + self.ma_rate * batch_et

    dv = mean_t - log_mean_et
    # grad(surrogate) = grad(mean_t) - grad(mean_et) / ma_et, i.e. the DV
    # gradient with the moving average in the denominator. (Using
    # log(ma_et.detach()) directly would give the marginal term no gradient
    # at all, and T would just grow without bound.)
    surrogate = mean_t - mean_et / self.ma_et
    # Value of the DV bound, gradient of the surrogate.
    mi = dv.detach() + (surrogate - surrogate.detach())
    return mi, mean_t.item(), mean_et.item()

def sample_joint_and_marginal(a, z):
  """Build MINE inputs from paired samples.

  a and z are (N, 1). The joint batch keeps the original pairing (a_i, z_i);
  the marginal batch pairs each a_i with a randomly permuted z_pi, breaking
  the dependence between A and Z. Both results are (N, 2).
  """
  perm = torch.randperm(a.shape[0], device=a.device)
  shuffled_z = z[perm]
  return torch.cat([a, z], dim=1), torch.cat([a, shuffled_z], dim=1)

# -----------------------------
# 5) Training: Baseline & FAST
# -----------------------------
def evaluate_cindex(model, x, t, e):
  """Concordance index of the model's risk scores on (x, t, e).

  The risk scores are negated before calling lifelines' concordance_index,
  because a higher risk should mean a shorter time. Puts the model in eval
  mode and leaves it there.
  """
  model.eval()
  with torch.no_grad():
    scores = model(x).squeeze(-1).detach().cpu().numpy()
  durations = t.cpu().numpy()
  observed = e.cpu().numpy()
  return concordance_index(durations, -scores, observed)

def group_cindex(model, x, t, e, a):
  """C-index within the a == 0 and a == 1 groups, plus their absolute gap.

  Returns (c0, c1, |c0 - c1|). Rows whose a is neither 0 nor 1 are ignored.
  NOTE(review): a group too small to contain comparable pairs may make the
  underlying concordance computation fail — confirm with the data at hand.
  """
  groups = a.squeeze()
  per_group = []
  for g in (0, 1):
    mask = groups == g
    per_group.append(evaluate_cindex(model, x[mask], t[mask], e[mask]))
  c0, c1 = per_group
  return c0, c1, abs(c0 - c1)

# ---- Baseline ----
# DeepSurv trained on the Cox partial likelihood only (no fairness term),
# full-batch, for EPOCHS steps.
baseline = DeepSurv(in_dim=len(FEATS)).to(DEVICE)
opt_b = torch.optim.Adam(baseline.parameters(), lr=1e-3)
EPOCHS = 200
for ep in range(EPOCHS):
  baseline.train()
  opt_b.zero_grad()
  risk = baseline(x_tr)
  loss = cox_ph_loss(risk, t_tr, e_tr)
  loss.backward(); opt_b.step()
  if (ep+1) % 50 == 0:
    # Monitor on the validation split; the test split is reserved for the
    # final report below, so it never influences training decisions.
    c = evaluate_cindex(baseline, x_va, t_va, e_va)
    print(f"[Baseline] epoch {ep+1:03d} loss={loss.item():.3f} val c-index={c:.3f}")

# Final held-out evaluation: overall C-index and per-group gap on the test split.
c_overall_base = evaluate_cindex(baseline, x_te, t_te, e_te)
c0, c1, gap_base = group_cindex(baseline, x_te, t_te, e_te, a_te)
print(f"[Baseline] test c-index={c_overall_base:.3f} | group c0={c0:.3f} c1={c1:.3f} | ΔC={gap_base:.3f}")

# ---- FAST (with MINE) ----
# Alternating training: MINE estimates I(A; risk score), and the survival
# model minimizes Cox loss + gamma * that estimate.
# NOTE(review): cox_ph_loss is a sum over events, so its scale grows with the
# training-set size and gamma's effective strength depends on N.
gamma = 0.5               # fairness strength (try 0.1, 0.5, 1.0)
mine_steps = 5            # MINE updates per epoch, done before the model update
model = DeepSurv(in_dim=len(FEATS)).to(DEVICE)
mine  = MINE(in_dim=2, hidden=64).to(DEVICE)
opt_m = torch.optim.Adam(mine.parameters(),   lr=1e-3)
opt_f = torch.optim.Adam(model.parameters(), lr=1e-3)

EPOCHS = 200
for ep in range(EPOCHS):
  # (a) Update MINE first: maximize the MI lower bound (gradient ascent on
  #     MINE, i.e. minimize -MI). Risk scores are treated as fixed data here.
  model.eval()
  with torch.no_grad():
    z = model(x_tr)
  for _ in range(mine_steps):
    opt_m.zero_grad()
    joint, marg = sample_joint_and_marginal(a_tr, z)
    mi, _, _ = mine(joint, marg)
    loss_m = -mi
    loss_m.backward(); opt_m.step()

  # (b) Update the survival model: minimize Cox + gamma * MI.
  #     MINE's parameters are frozen, but the MI estimate must stay connected
  #     to z: detaching joint/marg here would make gamma * MI a constant and
  #     the fairness penalty would have no effect on the model.
  model.train()
  opt_f.zero_grad()
  mine.requires_grad_(False)
  z = model(x_tr)
  joint, marg = sample_joint_and_marginal(a_tr, z)
  mi, _, _ = mine(joint, marg)
  loss_cox = cox_ph_loss(z, t_tr, e_tr)
  loss_total = loss_cox + gamma * mi
  loss_total.backward(); opt_f.step()
  mine.requires_grad_(True)

  if (ep+1) % 50 == 0:
    # Monitor on the validation split; the test split is reserved for the final report.
    c = evaluate_cindex(model, x_va, t_va, e_va)
    print(f"[FAST] epoch {ep+1:03d} total={loss_total.item():.3f} val c-index={c:.3f} mi={mi.item():.3f}")

# Final held-out evaluation on the test split.
c_overall_fast = evaluate_cindex(model, x_te, t_te, e_te)
c0f, c1f, gap_fast = group_cindex(model, x_te, t_te, e_te, a_te)
print(f"[FAST γ={gamma}] test c-index={c_overall_fast:.3f} | group c0={c0f:.3f} c1={c1f:.3f} | ΔC={gap_fast:.3f}")

# Results summary: test-split C-index and group gap ΔC = |c0 - c1| for the
# baseline and the FAST model.
print("\n=== Summary ===")
print(f"Baseline: c-index={c_overall_base:.3f}, ΔC={gap_base:.3f}")
print(f"FAST(γ={gamma}): c-index={c_overall_fast:.3f}, ΔC={gap_fast:.3f}")
