# Novel Approach

In [None]:
import pandas as pd

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
import torch.nn as nn
import torchvision.models as models
from transformers import AutoTokenizer, DistilBertModel
from tqdm.notebook import tqdm # Progress bars

import matplotlib.pyplot as plt
import numpy as np

from PIL import Image

import os
import random

from src.data import load_omnimed_dataset

# Compute device for the model and batches: the first CUDA GPU when present, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Load the Dataset

### Load the Base Dataset

In [None]:
train_df, val_df, test_df = load_omnimed_dataset()

print("Train size:", len(train_df))
print("Validation size:", len(val_df))
print("Test size:", len(test_df))

# Check for image overlap between every pair of splits. Any non-zero count
# means the same image appears in two splits: train overlap leaks training
# data into evaluation, and val-test overlap makes the test score depend on
# the same images used for model selection.
train_paths = set(train_df['image_path'])
val_paths = set(val_df['image_path'])
test_paths = set(test_df['image_path'])
print("Overlap train-test:", len(train_paths & test_paths))
print("Overlap train-val:", len(train_paths & val_paths))
print("Overlap val-test:", len(val_paths & test_paths))

## OmniMed Dataset Setup

### Define Image Transforms

In [None]:
# Training transforms (includes augmentation)
train_image_transform = transforms.Compose([
    transforms.Resize((224, 224)),                 # Resize image
    models.ResNet18_Weights.DEFAULT.transforms()  # Use ResNet18 default transforms
])

# Validation / Test transforms (no augmentation)
val_image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    models.ResNet18_Weights.DEFAULT.transforms()
])

### Set Up the Tokenizer

In [None]:
# Tokenizer for the question/options text. The checkpoint name must match the
# one passed to DistilBertModel.from_pretrained(...) in the model class, so
# the token ids line up with the text encoder's vocabulary.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

### Create Novel Dataset

In [None]:
class OmniMedNovelDataset(Dataset):
    """Multiple-choice image question dataset built from an OmniMed split DataFrame.

    Each item is a dict with keys ``"image"``, ``"input_ids"``,
    ``"attention_mask"`` and ``"label"``.

    Args:
        df: DataFrame with columns ``image_path``, ``question``,
            ``option_A``..``option_D`` and ``gt_label`` (one of
            ``"option_A"``..``"option_D"``).
        image_transform: Optional callable applied to the loaded RGB PIL image.
        tokenizer: Optional Hugging Face-style tokenizer. When ``None``, the
            ``input_ids`` / ``attention_mask`` entries of each item are ``None``.
        max_length: Token length the text is padded / truncated to.
    """

    def __init__(self, df, image_transform=None, tokenizer=None, max_length=100):
        self.df = df
        self.image_transform = image_transform
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Maps the ground-truth option column name to a class index.
        self.label_map = {"option_A": 0, "option_B": 1, "option_C": 2, "option_D": 3}

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _build_text_input(row):
        """Format a row as ``"Question: <q> Options: A. <a> B. <b> ..."``.

        Missing options are skipped. pandas represents missing cells as NaN
        (or None in object columns), so ``pd.isna`` is used. A plain
        ``is None`` test lets NaN through as the literal text ``"nan"``.
        """
        options_text = " ".join(
            f"{letter}. {row[f'option_{letter}']}"
            for letter in ("A", "B", "C", "D")
            if not pd.isna(row[f"option_{letter}"])
        )
        return f"Question: {row['question']} Options: {options_text}"

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # ---- Image ----
        # The context manager releases the file handle. convert() returns an
        # independent copy, so the image remains usable after the file closes.
        with Image.open(row['image_path']) as img:
            image = img.convert("RGB")
        if self.image_transform is not None:
            image = self.image_transform(image)

        # ---- Text ----
        text_input = self._build_text_input(row)

        if self.tokenizer is not None:
            tokens = self.tokenizer(
                text_input,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
            )
            # Drop the batch dimension the tokenizer adds; the DataLoader re-batches.
            input_ids = tokens.input_ids.squeeze(0)
            attention_mask = tokens.attention_mask.squeeze(0)
        else:
            # NOTE(review): PyTorch's default collate cannot batch None values,
            # so tokenizer=None likely needs a custom collate_fn. Confirm before use.
            input_ids, attention_mask = None, None

        label = self.label_map[row['gt_label']]

        return {
            "image": image,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "label": label
        }

In [None]:
# One dataset per split. Only the training split gets `train_image_transform`;
# validation and test share `val_image_transform`. All splits share one tokenizer.
train_dataset = OmniMedNovelDataset(
    train_df,
    image_transform=train_image_transform,
    tokenizer=tokenizer,
)
val_dataset = OmniMedNovelDataset(
    val_df,
    image_transform=val_image_transform,
    tokenizer=tokenizer,
)
test_dataset = OmniMedNovelDataset(
    test_df,
    image_transform=val_image_transform,
    tokenizer=tokenizer,
)

### Create Data Loaders

In [None]:
# Loader settings shared by every split
# TODO: Add to config.py as constant with optional override
batch_size = 64
num_workers = 0


def _make_loader(dataset, shuffle):
    """Wrap *dataset* in a DataLoader using the shared batch size and worker count."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers
    )


# Only the training split is shuffled; evaluation order stays fixed.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

# Quick check: number of batches per split
for _caption, _loader in (
    ("Train batches:", train_loader),
    ("Validation batches:", val_loader),
    ("Test batches:", test_loader),
):
    print(_caption, len(_loader))

In [None]:
# Sanity check
for i, batch in enumerate(train_loader):
    images = batch['image']
    labels = batch['label']
    print(i, images.shape, labels.shape)
    if i == 2:
        break

## Model Setup

### Create Model Class

In [None]:
class AttentionGatedMultimodalClassifier(nn.Module):
    """
    Attention-gated multimodal classifier.

    A pretrained ResNet18 encodes the image and a pretrained DistilBERT encodes
    the text. Both are projected to ``fusion_dim`` and blended with a single
    learnable scalar gate ``g = sigmoid(gate_param)``::

        fused = g * image_features + (1 - g) * text_features

    Args:
        fusion_dim: Width of the shared projection space.
        num_classes: Number of output classes.
        init_image_bias: Initial image weight ``g``; must lie in the open
            interval (0, 1). Values above 0.5 start training biased toward
            the image branch.

    Raises:
        ValueError: If ``init_image_bias`` is not strictly between 0 and 1.
    """

    def __init__(self, fusion_dim=512, num_classes=4, init_image_bias=0.7):
        super().__init__()

        # ----- Vision Encoder -----
        # Pretrained ResNet18 with its final fc layer removed; output: (batch, 512, 1, 1)
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.vision_encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.vision_proj = nn.Linear(512, fusion_dim)

        # ----- Text Encoder -----
        # Pretrained DistilBERT transformer. Token states are mean-pooled over
        # non-padding positions in forward().
        self.text_encoder = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.text_proj = nn.Linear(self.text_encoder.config.hidden_size, fusion_dim)

        # ----- Attention-Gated Fusion -----
        # The gate is stored in logit space so that sigmoid(gate_param) equals
        # init_image_bias at initialization. Storing the raw value would make
        # the initial gate sigmoid(0.7) ~= 0.67 instead of 0.7.
        self.gate_param = nn.Parameter(self._gate_logit(init_image_bias))

        # Final classifier
        self.classifier = nn.Linear(fusion_dim, num_classes)
        self.relu = nn.ReLU()

    @staticmethod
    def _gate_logit(prob):
        """Return ``logit(prob)`` as a float32 scalar tensor; ``prob`` must be in (0, 1)."""
        if not 0.0 < prob < 1.0:
            raise ValueError(
                f"init_image_bias must be in the open interval (0, 1), got {prob}"
            )
        return torch.logit(torch.tensor(float(prob)))

    @staticmethod
    def _masked_mean(hidden_states, attention_mask=None):
        """Average (batch, seq, hidden) states over seq, ignoring masked-out positions.

        Without a mask this is a plain mean over the sequence dimension. Rows
        whose mask is all zeros produce a zero vector instead of NaN.
        """
        if attention_mask is None:
            return hidden_states.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq, 1)
        summed = (hidden_states * mask).sum(dim=1)                    # (batch, hidden)
        counts = mask.sum(dim=1).clamp(min=1.0)                       # avoid divide-by-zero
        return summed / counts

    def forward(self, images, input_ids=None, attention_mask=None):
        """
        Args:
            images: (batch, 3, H, W) image tensor.
            input_ids: (batch, seq_len) token ids, or None to run image-only.
                In that case the text branch contributes a zero vector.
            attention_mask: (batch, seq_len) mask with 1 for real tokens and
                0 for padding. It is forwarded to DistilBERT and excludes
                padding from the pooled text feature, so pass it whenever
                input_ids is padded.
        Returns:
            logits: (batch, num_classes)
            gate: scalar tensor, the current image weight sigmoid(gate_param).
        """
        # --- Image features ---
        x_img = self.vision_encoder(images)          # (batch, 512, 1, 1)
        x_img = x_img.view(x_img.size(0), -1)        # (batch, 512)
        x_img = self.relu(self.vision_proj(x_img))   # (batch, fusion_dim)

        # --- Text features ---
        if input_ids is not None:
            text_outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
            # Padding to max_length would otherwise dilute a plain mean.
            x_text = self._masked_mean(text_outputs.last_hidden_state, attention_mask)
            x_text = self.relu(self.text_proj(x_text))
        else:
            x_text = torch.zeros_like(x_img)

        # --- Attention-gated fusion ---
        gate = torch.sigmoid(self.gate_param)  # scalar in (0, 1)
        x = gate * x_img + (1 - gate) * x_text

        # --- Classification ---
        logits = self.classifier(x)

        return logits, gate

### Define the Model

In [None]:
# Model
# Build the gated classifier and place it on the selected device
# (nn.Module.to moves parameters in place and returns the same module).
model = AttentionGatedMultimodalClassifier(
    fusion_dim=512, num_classes=4, init_image_bias=0.7
).to(device)

# Objective and optimization: cross-entropy loss, AdamW over all parameters,
# learning rate halved every 5 scheduler steps (one step per epoch).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

### Training Loop

In [None]:
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

# TODO: Add to config.py as constant with optional override
num_epochs = 1

# Start below any reachable accuracy so the first epoch always writes a
# checkpoint. Starting at 0.0 with a strict `>` meant a run scoring exactly
# 0 accuracy saved nothing, and torch.load(best_model_path) later failed.
best_val_acc = float("-inf")
os.makedirs("models", exist_ok=True)
best_model_path = os.path.join("models", "novel_model.pth")


def _batch_to_device(batch):
    """Return (images, input_ids, attention_mask, labels) from a collated batch, moved to `device`."""
    return (
        batch['image'].to(device),
        batch['input_ids'].to(device),
        batch['attention_mask'].to(device),
        batch['label'].to(device),
    )


for epoch in range(num_epochs):
    print(f"\n=== Epoch {epoch+1}/{num_epochs} ===")

    # Training
    model.train()
    train_loss = 0.0
    correct_train = 0
    total_train = 0

    for batch in tqdm(train_loader, desc="Training"):
        images, input_ids, attention_mask, labels = _batch_to_device(batch)

        optimizer.zero_grad()
        logits, _ = model(images, input_ids, attention_mask=attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        preds = logits.argmax(dim=1)
        correct_train += (preds == labels).sum().item()
        total_train += labels.size(0)

    train_loss /= len(train_loader)  # mean of per-batch losses
    train_acc = correct_train / total_train
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")

    # Validation
    model.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            images, input_ids, attention_mask, labels = _batch_to_device(batch)

            # Pass the padding mask exactly as in training. Without it, the
            # text encoder would attend to padding tokens only at evaluation time.
            logits, _ = model(images, input_ids, attention_mask=attention_mask)
            preds = logits.argmax(dim=1)

            val_loss += criterion(logits, labels).item()
            correct_val += (preds == labels).sum().item()
            total_val += labels.size(0)

    val_loss /= len(val_loader)
    val_acc = correct_val / total_val
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"Saved best model with val_acc={best_val_acc:.4f}")

    # Step the scheduler once per epoch
    scheduler.step()

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)


In [None]:
# One figure per tracked metric: (y-axis label, title, [(series, legend label), ...]).
# Loss is plotted first, then accuracy, each with train and validation curves.
_curve_specs = [
    ("Loss", "Loss Curve",
     [(train_losses, "Train Loss"), (val_losses, "Validation Loss")]),
    ("Accuracy", "Accuracy Curve",
     [(train_accuracies, "Train Accuracy"), (val_accuracies, "Validation Accuracy")]),
]

for _ylabel, _title, _series in _curve_specs:
    plt.figure(figsize=(10, 4))
    for _values, _legend_label in _series:
        plt.plot(_values, label=_legend_label)
    plt.xlabel("Epoch")
    plt.ylabel(_ylabel)
    plt.title(_title)
    plt.legend()
    plt.show()


### Testing Loop

In [None]:
# Load the best checkpoint before testing. map_location remaps tensors onto
# the current device, so a checkpoint saved on GPU also loads on a CPU-only machine.
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.eval()
print(f"Loaded best model from {best_model_path}")

y_true = []
y_pred = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        images = batch['image'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label']  # only used for bookkeeping, so it stays on CPU

        # Forward pass. Pass the padding mask so evaluation matches training.
        logits, _ = model(images, input_ids, attention_mask=attention_mask)

        # Predictions
        preds = torch.argmax(logits, dim=1)

        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

# Convert to numpy arrays
y_true = np.array(y_true)
y_pred = np.array(y_pred)

# Accuracy: fraction of test samples whose predicted class matches the label
accuracy = np.mean(y_true == y_pred)
print(f"Test Accuracy: {accuracy:.4f}")

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score

# Macro averaging weights every class equally. zero_division=0 states
# explicitly the value sklearn already uses when a class is never predicted
# (or never present). The default "warn" also yields 0 but emits an
# UndefinedMetricWarning for each metric.
_macro_kwargs = dict(average='macro', zero_division=0)

print(f"Test Accuracy:  {accuracy:.4f}")
print(f"Test Precision (macro): {precision_score(y_true, y_pred, **_macro_kwargs):.4f}")
print(f"Test Recall (macro):    {recall_score(y_true, y_pred, **_macro_kwargs):.4f}")
print(f"Test F1 (macro):        {f1_score(y_true, y_pred, **_macro_kwargs):.4f}")
