# Model Comparison: DyT vs TFT

This notebook compares the predictions of the **Dynamic Transformer (DyT)** against the **Temporal Fusion Transformer (TFT)** baseline.

We will:
1. Load both models.
2. Run inference on the test set.
3. Identify sepsis cases where DyT outperforms TFT, i.e. DyT assigns a high risk around sepsis onset while TFT's risk stays low.
4. Visualize the risk trajectories side-by-side.

In [1]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
from tqdm import tqdm

# Add project root
sys.path.append(os.path.abspath('../../'))

from src.models.dyt import DyTTransformer
from src.models.tft import TFTBaseline
from src.data.loader import get_loader

In [2]:
# Configuration: compute device, test split location and checkpoint locations
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

test_path = '../../data/processed_splits/test.parquet'
dyt_ckpt = '../../models/test_run/dyt_best.pth'
tft_ckpt = '../../models/test_run/tft_best.pth'

# Read the test split; every column that is not explicitly excluded below is
# treated as a model input, so the input dimension is derived from the data.
df = pd.read_parquet(test_path)
non_feature_cols = {'PatientID', 'SepsisLabel', 'Unit1', 'Unit2', 'HospAdmTime', 'ICULOS'}
feature_cols = [col for col in df.columns if col not in non_feature_cols]
input_dim = len(feature_cols)

print(f"Input Dimension: {input_dim}")

Input Dimension: 20


In [3]:
# Load Models
def _restore_for_inference(model_cls, ckpt_path):
    """Instantiate ``model_cls``, load the weights at ``ckpt_path``, switch to eval mode.

    Both models are built with the same hyper-parameters (``d_model=64``,
    ``n_heads=4``, ``num_layers=2``) and the notebook-level ``input_dim`` and
    ``device``.
    """
    net = model_cls(input_dim=input_dim, d_model=64, n_heads=4, num_layers=2).to(device)
    # NOTE(review): depending on the PyTorch version, torch.load may unpickle
    # arbitrary objects -- only point this at checkpoints from a trusted source.
    state = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(state)
    net.eval()
    return net


dyt_model = _restore_for_inference(DyTTransformer, dyt_ckpt)
tft_model = _restore_for_inference(TFTBaseline, tft_ckpt)

print("Models Loaded.")

Models Loaded.


In [4]:
# Helper to get predictions for a patient
def get_predictions(patient_id, df, model, max_len=200, columns=None, target_device=None):
    """Run ``model`` on one patient's rows and return its per-step probabilities.

    Parameters
    ----------
    patient_id
        Value matched against the ``PatientID`` column of ``df``.
    df : pandas.DataFrame
        Must contain ``PatientID``, ``SepsisLabel`` and the feature columns.
        If an ``ICULOS`` column is present, rows are ordered by it and it is
        used to derive the time gaps and the returned time axis.
    model : callable
        Called as ``model(x, t)`` with ``x`` of shape ``(1, T, F)`` and ``t``
        of shape ``(1, T, 1)``; must return a 2-tuple whose first item is the
        logits.
    max_len : int, default 200
        Only the last ``max_len`` rows of the patient are fed to the model.
    columns : sequence of str, optional
        Feature columns to use. Defaults to the notebook-level ``feature_cols``.
    target_device : torch.device or str, optional
        Device the input tensors are moved to. Defaults to the notebook-level
        ``device``.

    Returns
    -------
    probs : numpy.ndarray
        Flattened ``sigmoid(logits)``.
    labels : numpy.ndarray
        ``SepsisLabel`` values for the rows that were fed to the model.
    time_vals : numpy.ndarray
        ``ICULOS`` values for those rows, or ``0..T-1`` when the column is absent.

    Raises
    ------
    ValueError
        If ``max_len`` is smaller than 1 or no row matches ``patient_id``.
    """
    if max_len < 1:
        # A non-positive length would make the ``[-max_len:]`` slices below
        # silently keep everything (``[-0:]``) or drop the wrong end.
        raise ValueError(f"max_len must be >= 1, got {max_len}")
    if columns is None:
        columns = feature_cols
    if target_device is None:
        target_device = device

    p_data = df[df['PatientID'] == patient_id]
    if len(p_data) == 0:
        raise ValueError(f"No rows found for PatientID {patient_id!r}")

    has_time = 'ICULOS' in p_data.columns
    if has_time:
        p_data = p_data.sort_values('ICULOS')

    features = p_data[columns].values
    labels = p_data['SepsisLabel'].values

    # Gaps between consecutive rows: first row gets 0 when times are known,
    # otherwise every step is assumed to be one unit apart.
    if has_time:
        time_vals = p_data['ICULOS'].values
        gaps = np.zeros(len(features))
        gaps[1:] = np.diff(time_vals)
    else:
        time_vals = None
        gaps = np.ones(len(features))

    # Keep only the most recent ``max_len`` rows. Gaps were computed before
    # truncation, so the first retained gap still refers to the dropped row
    # that preceded it.
    if len(features) > max_len:
        features = features[-max_len:]
        gaps = gaps[-max_len:]
        labels = labels[-max_len:]
        if has_time:
            time_vals = time_vals[-max_len:]
    if not has_time:
        time_vals = np.arange(len(features))

    x = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(target_device)
    t = torch.tensor(gaps, dtype=torch.float32).unsqueeze(0).unsqueeze(-1).to(target_device)

    with torch.no_grad():
        logits, _ = model(x, t)
        probs = torch.sigmoid(logits).cpu().numpy().flatten()

    return probs, labels, time_vals

# Heuristic search for patients where DyT is confident around onset and TFT is not
has_sepsis_row = df['SepsisLabel'] == 1
sepsis_patients = df.loc[has_sepsis_row, 'PatientID'].unique()
print(f"Searching through {len(sepsis_patients)} sepsis patients for interesting cases...")

interesting_cases = []

# Only the first 50 candidates are screened
for pid in sepsis_patients[:50]:
    dyt_probs, labels, _ = get_predictions(pid, df, dyt_model)
    tft_probs, _, _ = get_predictions(pid, df, tft_model)

    # Positions (within the returned window) whose label is positive
    positive_steps = np.flatnonzero(labels == 1)
    if positive_steps.size == 0:
        continue

    # Window: up to 6 steps before the first positive label, through that step
    first_positive = positive_steps[0]
    window = slice(max(0, first_positive - 6), first_positive + 1)

    dyt_peak = dyt_probs[window].max()
    tft_peak = tft_probs[window].max()

    # Keep the patient when DyT peaks high while TFT stays low in the window
    if dyt_peak > 0.7 and tft_peak < 0.4:
        interesting_cases.append(pid)

print(f"Found {len(interesting_cases)} interesting cases: {interesting_cases}")

Searching through 423 sepsis patients for interesting cases...


Found 0 interesting cases: []


In [5]:
# Plot both risk trajectories for the first flagged patient, if there is one
if not interesting_cases:
    print("No strong divergence found in the first 50 patients. Try checking more or different criteria.")
else:
    pid = interesting_cases[0]
    dyt_probs, labels, time_vals = get_predictions(pid, df, dyt_model)
    tft_probs, _, _ = get_predictions(pid, df, tft_model)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(time_vals, dyt_probs, label='DyT Risk', color='blue', linewidth=2)
    ax.plot(time_vals, tft_probs, label='TFT Risk', color='gray', linestyle='--')

    # Mark the first time step whose label is positive
    positive_steps = np.flatnonzero(labels == 1)
    if positive_steps.size > 0:
        ax.axvline(x=time_vals[positive_steps[0]], color='red', label='Sepsis Onset')

    ax.set_title(f"Patient {pid}: DyT vs TFT")
    ax.set_ylabel("Sepsis Probability")
    ax.set_xlabel("Time (Hours)")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()

No strong divergence found in the first 50 patients. Try checking more or different criteria.
