# Experiment 07: ResNet Slice Classifier (ResClass)

This notebook implements the first stage of the RAovSeg pipeline: the slice selection classifier. The goal is to train a ResNet-18 model to distinguish between MRI slices that contain an ovary and those that do not.

### **Model Configuration**

*   **Objective**: Train a binary classifier to identify ovary-containing slices.
*   **Model Architecture**: **ResNet-18** (from `torchvision.models`).
*   **Dataset**: D2_TCPW, using **all slices** from eligible patients, with binary labels (1=ovary, 0=no ovary).
*   **Preprocessing**: RAovSeg custom preprocessing.
*   **Data Augmentation**: Simple `RandomAffine` and `RandomHorizontalFlip`.
*   **Loss Function**: **`BCEWithLogitsLoss`** (standard for binary classification).
*   **Optimizer**: Adam.
*   **Learning Rate**: 1e-4 (constant).
*   **Epochs**: 20.
*   **Batch Size**: **16** (we can use a larger batch size for classification).
*   **Image Size**: 256x256.
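*   **Data Split**: 80% train / 20% validation. The split is currently made at the **slice** level on the dataset's slice ordering, not by patient ID, so slices from the same patient can appear in both sets. **TODO**: split patient IDs first to avoid train/validation leakage.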
*   **Class Imbalance Strategy**: Subsampling negative examples to a 2:1 ratio (negative:positive).

In [4]:
# --- Imports and Setup ---
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision.models as models
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.optim import Adam
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.append(project_root)

# Project-local dataset yielding (image, label) pairs for slice classification.
from src.data_loader import SliceClassifierDataset

# --- Configuration ---
manifest_path = '../data/d2_manifest_t2fs_ovary_eligible.csv'
image_size = 256
batch_size = 16 # We can use a larger batch size for classification
num_epochs = 20
lr = 1e-4
train_fraction = 0.8  # fraction of the slice list assigned to training
seed = 42  # one seed for every RNG used in this notebook, for reproducible runs


def _seed_all(seed_value):
    """Seed Python's `random`, NumPy and PyTorch (CPU and CUDA) RNGs with one value."""
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)


# --- Data Loading ---
print("--- Loading Full Slice Data for Classification ---")
# Two instances over the SAME manifest: one augmented (train), one not (val). The index split
# below is only leak-free if both list the same slices in the same order. Re-seeding before
# each construction makes any global-RNG shuffling/subsampling inside the dataset identical.
# NOTE(review): this assumes SliceClassifierDataset draws randomness from the global RNGs —
# confirm in src/data_loader.
_seed_all(seed)
train_full_dataset = SliceClassifierDataset(manifest_path=manifest_path, image_size=image_size, augment=True)
_seed_all(seed)
val_full_dataset = SliceClassifierDataset(manifest_path=manifest_path, image_size=image_size, augment=False)

num_slices = len(train_full_dataset.slice_data)
assert len(val_full_dataset.slice_data) == num_slices, (
    "train/val dataset instances disagree on the number of slices; index split would be invalid"
)
split_idx = int(num_slices * train_fraction)

# This is a SLICE-level split: both instances read the same manifest, so slices from one
# patient can fall on both sides of split_idx, which leaks information into validation.
# TODO: split patient IDs first (as described in the notebook header), then select slices.
train_dataset = Subset(train_full_dataset, range(split_idx))
val_dataset = Subset(val_full_dataset, range(split_idx, num_slices))

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    generator=torch.Generator().manual_seed(seed),  # reproducible batch order
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
print(f"Data successfully split:\nTraining samples: {len(train_dataset)}\nValidation samples: {len(val_dataset)}")


# --- Training and Validation Functions for CLASSIFICATION ---
def train_one_epoch(model, loader, optimizer, criterion, device, threshold=0.5):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network producing one logit per sample. Its output is assumed to be
            shaped (batch, 1), matching the ``unsqueeze(1)`` applied to the labels.
        loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss called as ``criterion(logits, targets)``; targets are passed as float.
        device: Device every batch is moved to.
        threshold: Probability above which a sample is predicted positive (default 0.5).

    Returns:
        ``(mean_loss, accuracy)``: loss averaged over the samples actually processed,
        and the accuracy of ``sigmoid(logit) > threshold`` against the labels.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    num_samples = 0
    all_preds, all_labels = [], []
    for images, labels in tqdm(loader, desc="Training"):
        images = images.to(device)
        # BCEWithLogitsLoss requires float targets shaped like the (batch, 1) logits;
        # integer class labels would otherwise raise a dtype error.
        labels = labels.to(device).float().unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_n = images.size(0)
        running_loss += loss.item() * batch_n  # undo the criterion's per-batch mean
        num_samples += batch_n

        preds = (torch.sigmoid(outputs) > threshold).long()
        # Flatten to 1-D int lists so sklearn sees plain binary label vectors.
        all_preds.extend(preds.cpu().numpy().ravel().tolist())
        all_labels.extend(labels.long().cpu().numpy().ravel().tolist())

    if num_samples == 0:
        raise ValueError("train_one_epoch received a loader that yielded no samples")

    accuracy = accuracy_score(all_labels, all_preds)
    # Divide by samples seen (not len(loader.dataset)) so drop_last/samplers are handled.
    return running_loss / num_samples, accuracy


def validate(model, loader, criterion, device, threshold=0.5):
    """Evaluate ``model`` on ``loader`` without updating its weights.

    Args:
        model: Network producing one logit per sample. Its output is assumed to be
            shaped (batch, 1), matching the ``unsqueeze(1)`` applied to the labels.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, targets)``; targets are passed as float.
        device: Device every batch is moved to.
        threshold: Probability above which a sample is predicted positive (default 0.5).

    Returns:
        ``(mean_loss, accuracy, precision, recall)``. Loss is averaged over the samples
        processed. Precision and recall are for the positive class (label 1) and are
        0 when undefined.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    num_samples = 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validation"):
            images = images.to(device)
            # BCEWithLogitsLoss requires float targets shaped like the (batch, 1) logits.
            labels = labels.to(device).float().unsqueeze(1)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_n = images.size(0)
            running_loss += loss.item() * batch_n  # undo the criterion's per-batch mean
            num_samples += batch_n

            preds = (torch.sigmoid(outputs) > threshold).long()
            # Flatten to 1-D int lists so sklearn sees plain binary label vectors.
            all_preds.extend(preds.cpu().numpy().ravel().tolist())
            all_labels.extend(labels.long().cpu().numpy().ravel().tolist())

    if num_samples == 0:
        raise ValueError("validate received a loader that yielded no samples")

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)

    return running_loss / num_samples, accuracy, precision, recall

# --- Main Training Loop ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nUsing device: {device}")

# Load a pretrained ResNet-18 and adapt it for our single-channel, binary task
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Single-channel input: build a 1-channel stem with the same geometry as the pretrained one.
# Initialise it from the pretrained RGB filters summed over the input-channel axis rather than
# discarding them. For a grayscale image replicated to 3 channels, this gives the same response.
pretrained_conv1 = model.conv1
model.conv1 = nn.Conv2d(
    1,
    pretrained_conv1.out_channels,
    kernel_size=pretrained_conv1.kernel_size,
    stride=pretrained_conv1.stride,
    padding=pretrained_conv1.padding,
    bias=False,
)
with torch.no_grad():
    model.conv1.weight.copy_(pretrained_conv1.weight.sum(dim=1, keepdim=True))

# Binary output: a single logit, consumed by BCEWithLogitsLoss below.
model.fc = nn.Linear(model.fc.in_features, 1)
model = model.to(device)

optimizer = Adam(model.parameters(), lr=lr)
criterion = nn.BCEWithLogitsLoss() # Standard loss for binary classification

train_loss_history, val_loss_history, val_acc_history = [], [], []
best_val_acc = -1.0  # below any achievable accuracy, so epoch 1 always checkpoints
best_epoch = -1
# NOTE(review): file name uses "17_" while the notebook title says Experiment 07 — confirm.
model_save_path = "../models/17_resclass_best.pth"
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)

print("\n--- Starting ResNet Classifier (ResClass) Training ---")
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, val_prec, val_recall = validate(model, val_loader, criterion, device)

    train_loss_history.append(train_loss)
    val_loss_history.append(val_loss)
    val_acc_history.append(val_acc)

    print(f"Epoch {epoch+1}/{num_epochs} -> Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val Precision: {val_prec:.4f}, Val Recall: {val_recall:.4f}")

    # Checkpoint on strictly better validation accuracy (earliest epoch wins ties).
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1
        torch.save(model.state_dict(), model_save_path)
        print(f"  -> New best model saved at epoch {best_epoch} with Val Acc: {best_val_acc:.4f}")

print("--- Finished Training ---")
print(f"Best model was from epoch {best_epoch} with a validation accuracy of {best_val_acc:.4f}")
print(f"Model saved to {model_save_path}\n")

# --- Visualization ---
# Two-panel summary: loss curves on the left, validation accuracy with the checkpointed epoch on the right.
epochs_axis = range(1, num_epochs + 1)
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 6))

loss_ax.plot(epochs_axis, train_loss_history, label='Training Loss', marker='.')
loss_ax.plot(epochs_axis, val_loss_history, label='Validation Loss', marker='.')
loss_ax.set_title('Training and Validation Loss (ResClass)')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.grid(True)

acc_ax.plot(epochs_axis, val_acc_history, label='Validation Accuracy', color='green', marker='.')
acc_ax.set_title('Validation Accuracy (ResClass)')
acc_ax.set_xlabel('Epochs')
acc_ax.set_ylabel('Accuracy')
acc_ax.axvline(x=best_epoch, color='r', linestyle='--', label=f'Best Acc @ Epoch {best_epoch}')
acc_ax.legend()
acc_ax.grid(True)

fig.suptitle('ResNet Classifier (ResClass) Results', fontsize=16)
fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave headroom for the suptitle
plt.show()

--- Loading Full Slice Data for Classification ---
Loading manifest from ../data/d2_manifest_t2fs_ovary_eligible.csv and creating slice map for classifier...


IndexError: index 35 is out of bounds for axis 0 with size 35