In [6]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.compose import ColumnTransformer

# 1. Load Data
DATA_FILE = '../data/processed/OAI_model_ready_data.parquet'
df = pd.read_parquet(DATA_FILE)

# 2. Define Features & Target
categorical_cols = ['KL_Grade', 'Sex']
numerical_cols = ['Age', 'BMI', 'WOMAC_Score']
feature_cols = numerical_cols + categorical_cols

X = df[feature_cols]
y_event = df['event'].values
y_time = df['time_to_event'].values

# 3. Split Data
X_train, X_test, y_event_train, y_event_test, y_time_train, y_time_test = train_test_split(
    X, y_event, y_time, test_size=0.25, random_state=42
)

# 4. Preprocessing (Scaling + OneHot)
preprocessor = ColumnTransformer(
    transformers=[
        ('num', StandardScaler(), numerical_cols),
        ('cat', OneHotEncoder(drop='first', sparse_output=False), categorical_cols)
    ]
)

X_train_processed = preprocessor.fit_transform(X_train)
X_test_processed = preprocessor.transform(X_test)

# 5. Convert to PyTorch Tensors
# Note: PyTorch uses Float32 by default
X_train_tensor = torch.tensor(X_train_processed, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test_processed, dtype=torch.float32)

y_event_train_tensor = torch.tensor(y_event_train, dtype=torch.bool) # Events are boolean
y_time_train_tensor = torch.tensor(y_time_train, dtype=torch.float32)

print(f"Features processed and converted to Tensors.")
print(f"Input Shape: {X_train_tensor.shape}")

Features processed and converted to Tensors.
Input Shape: torch.Size([2644, 8])


In [7]:
class WideSurvivalModel(nn.Module):
    def __init__(self, input_dim):
        super(WideSurvivalModel, self).__init__()
        
        # A simple MLP for tabular data
        self.network = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.2),
            
            nn.Linear(32, 16),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            
            # Final layer: Outputs a SINGLE scalar (log-risk)
            # No activation function here (CoxPH expects linear log-risk)
            nn.Linear(16, 1) 
        )

    def forward(self, x):
        return self.network(x)

# Initialize the model
input_dim = X_train_tensor.shape[1]
model = WideSurvivalModel(input_dim)

print("WideSurvivalModel initialized.")
print(model)

WideSurvivalModel initialized.
WideSurvivalModel(
  (network): Sequential(
    (0): Linear(in_features=8, out_features=32, bias=True)
    (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=32, out_features=16, bias=True)
    (5): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Linear(in_features=16, out_features=1, bias=True)
  )
)


In [8]:
import torchsurv.loss
import torchsurv
from sksurv.metrics import concordance_index_censored

# --- DEBUG STEP: Find the correct function name ---
print("Available attributes in torchsurv.loss:")
print(dir(torchsurv.loss))

# --- DYNAMIC FIX ---
# We will check what exists and assign it to 'cox_loss_func'
if hasattr(torchsurv.loss, 'cox'):
    cox_loss_func = torchsurv.loss.cox
elif hasattr(torchsurv.loss, 'cox_loss'):
    cox_loss_func = torchsurv.loss.cox_loss
elif hasattr(torchsurv.loss, 'CoxPH'): # Some versions use a class
    cox_loss_func = torchsurv.loss.CoxPH()
else:
    # If all else fails, we use a custom concise implementation of Cox Loss
    # This ensures your project proceeds even if the library is acting up.
    print("WARNING: Could not find 'cox' in torchsurv. Using custom implementation.")
    def custom_cox_loss(risk_scores, events, times):
        # Sort by time (descending) for efficient calculation
        # Note: robust implementation requires sorting
        order = torch.argsort(times, descending=True)
        risk_scores = risk_scores[order]
        events = events[order]
        
        # Calculate log-sum-exp (partial likelihood)
        log_cumsum_h = torch.logcumsumexp(risk_scores, dim=0)
        
        # Calculate loss (negative log likelihood)
        # We only sum loss for patients who had the event
        loss = -torch.sum(events * (risk_scores - log_cumsum_h)) / events.sum()
        return loss
    cox_loss_func = custom_cox_loss

print(f"\nUsing Loss Function: {cox_loss_func}")

# --- TRAINING LOOP ---
optimizer = optim.Adam(model.parameters(), lr=0.005)
epochs = 500

print(f"Starting training on {len(X_train_tensor)} samples...")

model.train()
for epoch in range(epochs):
    optimizer.zero_grad()
    
    # Forward pass
    risk_scores = model(X_train_tensor).squeeze()
    
    # Calculate Loss using the function we found
    loss = cox_loss_func(
        risk_scores, 
        y_event_train_tensor, 
        y_time_train_tensor
    )
    
    # Backward pass
    loss.backward()
    optimizer.step()
    
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}/{epochs} | Loss: {loss.item():.4f}")

print("Training complete.")

Available attributes in torchsurv.loss:
['__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__']

Using Loss Function: <function custom_cox_loss at 0x7fa371cda430>
Starting training on 2644 samples...
Epoch 50/500 | Loss: 6.0267
Epoch 100/500 | Loss: 5.9853
Epoch 150/500 | Loss: 5.9441
Epoch 200/500 | Loss: 5.9367
Epoch 250/500 | Loss: 5.9081
Epoch 300/500 | Loss: 5.9124
Epoch 350/500 | Loss: 5.8501
Epoch 400/500 | Loss: 5.8656
Epoch 450/500 | Loss: 5.8475
Epoch 500/500 | Loss: 5.8475
Training complete.


In [9]:
# 1. Evaluate on Test Data
model.eval()
with torch.no_grad():
    # Get risk scores for test set
    test_risk_scores_tensor = model(X_test_tensor)
    test_risk_scores = test_risk_scores_tensor.numpy().flatten()

# 2. Calculate C-index
# We use the same sksurv metric as before for a fair comparison
c_index = concordance_index_censored(
    y_event_test.astype(bool),
    y_time_test,
    test_risk_scores
)

print(f"--- PyTorch Wide-Model Evaluation ---")
print(f"Concordance Index (C-index): {c_index[0]:.4f}")
print(f"Baseline (CoxPH) Target:     0.7468")

if c_index[0] >= 0.74:
    print("SUCCESS: PyTorch implementation matches statistical baseline.")
else:
    print("NOTE: Deep Learning on small tabular data can be unstable. Tune LR/Dropout.")

--- PyTorch Wide-Model Evaluation ---
Concordance Index (C-index): 0.7106
Baseline (CoxPH) Target:     0.7468
NOTE: Deep Learning on small tabular data can be unstable. Tune LR/Dropout.
