In [9]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import random
from sklearn.preprocessing import StandardScaler
from lifelines.utils import concordance_index
from TLSTM import TLSTM

from data_loader import get_data_loaders


# Build the training and validation DataLoaders via the project helper.
train_loader, val_loader = get_data_loaders()

# Run on CUDA when torch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Use the device name:", device)


def cox_ph_loss(risks, durations, events):
    """
    Cox partial likelihood loss for survival analysis (Breslow handling of ties).

    For every subject i with an observed event, the risk set is every subject j
    in the batch with ``durations[j] >= durations[i]``, so subjects with tied
    times are in each other's risk sets. The loss is the mean, over observed
    events, of ``logsumexp(risks[risk_set_i]) - risks[i]``.

    The computation is vectorised through a [B, B] boolean at-risk matrix
    instead of a Python loop over events, so memory grows quadratically with
    the batch size.

    Args:
        risks: Tensor with B elements (e.g. [B] or [B, 1]), log hazard ratios
            (higher = higher risk).
        durations: Tensor with B elements, survival times.
        events: Tensor with B elements, 1 = event occurred, 0 = censored.

    Returns:
        Scalar loss: mean negative log partial likelihood over observed events.
        If the batch has no usable event, a zero scalar with
        ``requires_grad=True`` is returned so ``loss.backward()`` still works;
        it is not connected to the model graph, so no parameter receives a
        gradient from such a batch.
    """
    # reshape (rather than view) also accepts non-contiguous inputs.
    risks = risks.reshape(-1)
    durations = durations.reshape(-1)
    events = events.reshape(-1)

    # Only observed (uncensored) events contribute terms to the likelihood.
    event_mask = events == 1

    # at_risk[i, j] is True when subject j is still at risk at subject i's time.
    at_risk = durations.unsqueeze(0) >= durations.unsqueeze(1)

    # A risk set can only be empty when durations[i] is NaN (NaN >= NaN is
    # False); such events are skipped, exactly as the per-event loop did.
    usable = event_mask & at_risk.any(dim=1)
    if not usable.any():
        # No events observed, return zero loss
        return torch.tensor(0.0, device=risks.device, requires_grad=True)

    # Row i holds risks[j] for j in the risk set of i and -inf elsewhere, so
    # the row-wise logsumexp is the log of the partial-likelihood denominator.
    risk_matrix = risks.unsqueeze(0).expand(risks.numel(), -1)
    log_denominator = torch.logsumexp(
        risk_matrix.masked_fill(~at_risk, float("-inf")), dim=1
    )

    # -(numerator - log denominator), averaged over the usable events.
    return (log_denominator[usable] - risks[usable]).mean()


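# Earlier vectorised variant, kept for reference only. NOTE: it is NOT
# equivalent to cox_ph_loss above: (1) with tied durations the risk set of a
# subject depends on the (unstable) argsort order, so tied subjects can be left
# out of each other's risk sets; (2) it averages over all B subjects (censored
# ones contribute 0) instead of over observed events only.
# def cox_ph_loss(risks, durations, events):
#
#     order = torch.argsort(durations, descending=True)
#     risks = risks[order]
#     events = events[order]
#
#     log_cumsum_exp = torch.logcumsumexp(risks, dim=0)
#     losses = risks - log_cumsum_exp
#     return -torch.mean(losses * events)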

    

# Model hyper-parameters.
input_dim = 5    # number of input features
hidden_dim = 64
fc_dim = 32
output_dim = 1   # one score per sequence (squeezed to [B] before the loss)

# Build the network on the selected device and attach an Adam optimiser.
model = TLSTM(input_dim, hidden_dim, output_dim, fc_dim)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)


# === 5. train ===
def train_epoch(model, loader, optimizer):
    """
    Run one optimisation pass over ``loader`` and return the mean training loss.

    Each batch is unpacked as ``(t_seq, lengths, x_seq, durations, events, pids)``.
    The model is called as ``model(x_seq, t_seq)``, its output is squeezed on
    dim 1, and it is trained with ``cox_ph_loss``.

    Args:
        model: network producing risk scores for ``(x_seq, t_seq)``
            (presumably shaped [B, 1] — TODO confirm against TLSTM).
        loader: iterable of batches in the layout above.
        optimizer: optimiser over ``model``'s parameters.

    Returns:
        Per-batch losses averaged with each batch weighted by its size, over
        the samples actually seen this epoch; ``nan`` if the loader yielded
        no batches.
    """
    model.train()
    total_loss = 0.0
    n_seen = 0

    # NOTE(review): ``lengths`` and ``pids`` are unpacked but unused, so the
    # model receives padded sequences without their lengths — confirm that
    # TLSTM copes with padding.
    for t_seq, _lengths, x_seq, durations, events, _pids in loader:
        x_seq = x_seq.to(device)
        t_seq = t_seq.to(device)
        durations = durations.to(device)
        events = events.to(device)

        optimizer.zero_grad()
        risk_pred = model(x_seq, t_seq).squeeze(1)
        loss = cox_ph_loss(risk_pred, durations, events)
        loss.backward()
        optimizer.step()

        batch_size = x_seq.size(0)
        total_loss += loss.item() * batch_size
        # Count samples actually seen: len(loader.dataset) over-counts when the
        # loader uses a subset sampler or drop_last=True.
        n_seen += batch_size

    if n_seen == 0:
        return float("nan")
    return total_loss / n_seen


# === 6. evaluate ===

def evaluate(model, loader):
    """
    Compute the concordance index (C-index) of the model's risk scores on ``loader``.

    Batches are unpacked as ``(t_seq, lengths, x_seq, durations, events, pids)``.
    Risk predictions from ``model(x_seq, t_seq)`` are gathered on the CPU for
    the whole loader and scored with lifelines' ``concordance_index``.

    Args:
        model: network producing risk scores (higher = higher risk).
        loader: iterable of batches in the layout above.

    Returns:
        The C-index as returned by ``lifelines.utils.concordance_index``.

    Raises:
        ValueError: if the loader yields no batches.
    """
    model.eval()
    risk_chunks, duration_chunks, event_chunks = [], [], []
    with torch.no_grad():
        for t_seq, _lengths, x_seq, durations, events, _pids in loader:
            # Only the model inputs need to live on ``device``; durations and
            # events are consumed on the CPU below, so they are not moved.
            risk_pred = model(x_seq.to(device), t_seq.to(device)).squeeze(1)
            risk_chunks.append(risk_pred.cpu())
            duration_chunks.append(durations.cpu())
            event_chunks.append(events.cpu())

    if not risk_chunks:
        raise ValueError("evaluate() received an empty loader")

    risks_all = torch.cat(risk_chunks)
    durations_all = torch.cat(duration_chunks)
    events_all = torch.cat(event_chunks)
    # lifelines treats larger predicted scores as longer survival, so the
    # risks (higher = higher hazard) are negated before scoring.
    c_index = concordance_index(durations_all.numpy(), -risks_all.numpy(), events_all.numpy())
    return c_index



# === 7. main ===
num_epochs = 50
# Start below any achievable C-index so the first epoch is always checkpointed.
best_cindex = float("-inf")
for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, optimizer)
    val_cindex = evaluate(model, val_loader)

    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f} - Val C-index: {val_cindex:.4f}")

    # Checkpoint only when validation C-index improves, so the file really
    # holds the best weights rather than those of the last epoch.
    if val_cindex > best_cindex:
        best_cindex = val_cindex
        torch.save(model.state_dict(), 'best_tlstm_model.pth')

print(f"Best Val C-index: {best_cindex:.4f}")


Use the device name: cpu
Epoch 1/50 - Train Loss: 3.3261 - Val C-index: 0.8620
Epoch 2/50 - Train Loss: 3.2361 - Val C-index: 0.8949
Epoch 3/50 - Train Loss: 3.0735 - Val C-index: 0.8962
Epoch 4/50 - Train Loss: 2.9105 - Val C-index: 0.9011
Epoch 5/50 - Train Loss: 2.7561 - Val C-index: 0.9190
Epoch 6/50 - Train Loss: 2.4297 - Val C-index: 0.9197
Epoch 7/50 - Train Loss: 2.5410 - Val C-index: 0.9159
Epoch 8/50 - Train Loss: 2.7044 - Val C-index: 0.9277
Epoch 9/50 - Train Loss: 2.3949 - Val C-index: 0.9277
Epoch 10/50 - Train Loss: 2.4922 - Val C-index: 0.9199
Epoch 11/50 - Train Loss: 2.2692 - Val C-index: 0.9298
Epoch 12/50 - Train Loss: 2.1442 - Val C-index: 0.9252
Epoch 13/50 - Train Loss: 2.2184 - Val C-index: 0.9341
Epoch 14/50 - Train Loss: 2.1818 - Val C-index: 0.9341
Epoch 15/50 - Train Loss: 2.0520 - Val C-index: 0.9264
Epoch 16/50 - Train Loss: 1.9122 - Val C-index: 0.9321
Epoch 17/50 - Train Loss: 1.9147 - Val C-index: 0.9324
Epoch 18/50 - Train Loss: 1.7615 - Val C-index: 0