In [1]:
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
import sys
import os

# Make the project's 'src' directory importable so `from model import ...` and
# `from dataset import ...` below resolve to OUR modules. Insert at the front so
# these generic names win over any installed packages with the same name (e.g.
# the PyPI `dataset` package), and skip the insert when the cell is re-run so
# sys.path does not accumulate duplicates.
SRC_DIR = os.path.abspath('../src')
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

from model import WideAndDeepSurvivalModel
from dataset import TriModalDataset

# 1. Setup transforms: deterministic evaluation-time preprocessing (no augmentation).
# Per-channel normalization statistics; these are the ImageNet values that
# torchvision's pretrained backbones are documented to expect.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (224, 224)  # (height, width) every image is resized to

_val_steps = [
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
]
val_transform = transforms.Compose(_val_steps)

# 2. Load data (processed parquet)
df_raw = pd.read_parquet('../data/processed/OAI_model_ready_data.parquet')

# 3. Quick preprocessing (re-doing one-hot just for this test).
# In the real pipeline, the fitted preprocessor should be saved and reused.
#
# Encode against FIXED category sets rather than whatever values happen to be
# present. Plain get_dummies(drop_first=True) drops the first *observed* level:
# if KL grade 0 is missing it drops grade 1 instead, and filling "missing"
# expected columns with 0 would then silently zero out a real feature. The
# same fill would also hide a dtype mismatch (e.g. Sex stored as 2.0 -> 'Sex_2.0').
# NOTE(review): levels are inferred from the expected column names
# (KL_Grade_1.0..KL_Grade_4.0, Sex_2); confirm grade 0 and Sex 1 are the
# intended reference (dropped) levels.
CATEGORY_LEVELS = {
    'KL_Grade': [0.0, 1.0, 2.0, 3.0, 4.0],
    'Sex': [1, 2],
}

df = df_raw.astype(
    {col: pd.CategoricalDtype(levels) for col, levels in CATEGORY_LEVELS.items()}
)
for col, levels in CATEGORY_LEVELS.items():
    # Non-null values outside the known levels become NaN here and would
    # otherwise turn into silent all-zero dummy rows.
    unknown = df_raw[col].notna() & df[col].isna()
    assert not unknown.any(), (
        f"{col} has values outside {levels}: {df_raw.loc[unknown, col].unique().tolist()}"
    )

# dtype=float keeps the dummy columns numeric and homogeneous
# (recent pandas otherwise returns bool columns).
df = pd.get_dummies(df, columns=list(CATEGORY_LEVELS), drop_first=True, dtype=float)

# The fixed categories guarantee these columns; fail loudly if the level
# lists and the expected names ever drift apart instead of fabricating zeros.
expected_cols = ['KL_Grade_1.0', 'KL_Grade_2.0', 'KL_Grade_3.0', 'KL_Grade_4.0', 'Sex_2']
missing_cols = [col for col in expected_cols if col not in df.columns]
assert not missing_cols, f"missing one-hot columns: {missing_cols}"

# 4. Initialize dataset (sandbox mode) and loader.
N_TEST_ROWS = 10  # a handful of rows is enough to prove the pipeline connects
BATCH_SIZE = 2
SEED = 42

# Seed torch so the shuffled batch order is identical on every Restart & Run All.
# NOTE(review): per the cell output, sandbox mode samples from a pool of images
# found in image_dir; if TriModalDataset does that sampling with `random` or
# numpy rather than torch, seed those generators too.
torch.manual_seed(SEED)

dataset = TriModalDataset(
    dataframe=df.head(N_TEST_ROWS),
    image_dir='../data/sandbox',
    transform=val_transform,
    mode='sandbox',  # <--- Important!
)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),  # reproducible shuffle order
)

# 5. Initialize model.
# Must equal the number of clinical features TriModalDataset emits per row;
# checked against the actual batch below.
WIDE_INPUT_DIM = 8
model = WideAndDeepSurvivalModel(wide_input_dim=WIDE_INPUT_DIM)

# 6. Run integration test: a single forward pass on one batch.
# Assertions (instead of "# Expect" comments) make the cell fail loudly on any
# shape drift, so "Success!" is only printed when it is actually true.
EXPECTED_IMAGE_CHW = (3, 224, 224)  # per Resize((224, 224)) + ToTensor on RGB images

print("Running Integration Test (Forward Pass)...")
# The loader yields (images, clinical, events, times); events/times are unused here.
images, clinical, _events, _times = next(iter(loader))
n_batch = images.shape[0]

print(f"Image Batch Shape: {images.shape}")
print(f"Clinical Batch Shape: {clinical.shape}")
assert tuple(images.shape[1:]) == EXPECTED_IMAGE_CHW, (
    f"unexpected image batch shape {tuple(images.shape)}"
)
assert tuple(clinical.shape) == (n_batch, WIDE_INPUT_DIM), (
    f"clinical batch shape {tuple(clinical.shape)} does not match "
    f"wide_input_dim={WIDE_INPUT_DIM}"
)

# Evaluation-style forward pass: eval() disables dropout / uses BatchNorm running
# stats, no_grad() skips building the autograd graph. Assumes the model is a
# torch.nn.Module (TODO confirm in model.py).
model.eval()
with torch.no_grad():
    risk_scores = model(images, clinical)
model.train()  # restore training mode (nn.Module default) for any later training cells

print(f"Risk Score Output: {risk_scores.shape}")
assert tuple(risk_scores.shape) == (n_batch, 1), (
    f"unexpected risk score shape {tuple(risk_scores.shape)}"
)
print("Success! Pipeline is connected.")

Dataset initialized in SANDBOX mode. Found 9786 images to sample from.
Running Integration Test (Forward Pass)...
Image Batch Shape: torch.Size([2, 3, 224, 224])
Clinical Batch Shape: torch.Size([2, 8])
Risk Score Output: torch.Size([2, 1])
Success! Pipeline is connected.
