# Model A: Master Training Notebook (OSCC Analysis)
**Role:** Senior Computer Vision Engineer & MLOps Specialist
**Objective:** Train a Multi-Task Learning (MTL) model for OSCC WSI Analysis.
**Status:** Pre-data phase (Using Dummy Data for pipeline verification).

## Tasks
1.  **TVNT:** Tumour vs Non-Tumour (Binary Classification)
2.  **DOI:** Depth of Invasion (Segmentation -> Depth Calc)
3.  **POI:** Pattern of Invasion (5-Class Classification)
4.  **TB:** Tumour Budding (Count Regression)
5.  **PNI:** Perineural Invasion (Binary Classification)
6.  **MI:** Mitotic Index (Count Regression)

In [None]:
# 1. Imports & Setup
import os
import time
import random
import numpy as np
import pandas as pd  # Added pandas for CSV handling
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import cv2

# Seed every random number generator the notebook relies on
def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch (CPU, plus CUDA when available).

    Args:
        seed: Integer seed handed to each generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

set_seed(42)
# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {DEVICE}")

Using device: cuda


In [2]:
# 2. Model Definition (DenseNet169 Backbone)

class UpsampleBlock(nn.Module):
    """Segmentation-decoder stage: two (3x3 conv, BatchNorm, ReLU) units, then 2x bilinear upsampling."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first unit maps in_channels -> out_channels; the second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Attribute names and layer order define the state_dict keys, so keep them stable.
        self.conv = nn.Sequential(*stages)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Refine `x` with the conv stack, then double its spatial resolution."""
        refined = self.conv(x)
        return self.upsample(refined)

class OSCCMultiTaskModel(nn.Module):
    """Multi-task OSCC network: one DenseNet169 encoder shared by six task heads.

    Heads (all fed from the same encoder feature map):
        * ``tvnt`` - 2-way classification logits (Tumour vs Non-Tumour)
        * ``poi``  - 5-way classification logits (Pattern of Invasion)
        * ``pni``  - 2-way classification logits (Perineural Invasion)
        * ``tb``   - single regression output (tumour bud count)
        * ``mi``   - single regression output (mitotic figure count)
        * ``doi``  - 1-channel segmentation logits (Depth of Invasion mask)
    """

    def __init__(self):
        super().__init__()
        # Backbone: DenseNet169 with ImageNet weights.
        # `pretrained=True` is the legacy torchvision argument; prefer the
        # weights enum when this torchvision release provides it and fall
        # back to the old flag otherwise so older installs keep working.
        weights_enum = getattr(models, "DenseNet169_Weights", None)
        if weights_enum is not None:
            self.backbone = models.densenet169(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.backbone = models.densenet169(pretrained=True)
        num_ftrs = self.backbone.classifier.in_features

        # Remove original classifier
        self.backbone.classifier = nn.Identity()

        # --- HEADS ---
        # NOTE: attribute names and layer order below define the checkpoint
        # (state_dict) layout; renaming them invalidates saved weights.

        # 1. TVNT (Binary: Tumour vs Non-Tumour)
        self.head_tvnt = nn.Sequential(
            nn.Linear(num_ftrs, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 2)
        )

        # 2. POI (5 Classes: Pattern of Invasion)
        self.head_poi = nn.Sequential(
            nn.Linear(num_ftrs, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 5)
        )

        # 3. PNI (Binary: Present vs Absent)
        self.head_pni = nn.Sequential(
            nn.Linear(num_ftrs, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 2)
        )

        # 4. TB (Regression: Count of Tumour Buds)
        self.head_tb = nn.Sequential(
            nn.Linear(num_ftrs, 128),
            nn.ReLU(),
            nn.Linear(128, 1) # Output raw count
        )

        # 5. MI (Regression: Count of Mitotic Figures)
        self.head_mi = nn.Sequential(
            nn.Linear(num_ftrs, 128),
            nn.ReLU(),
            nn.Linear(128, 1) # Output raw count
        )

        # 6. DOI (Segmentation: Depth of Invasion Mask)
        # Simplified decoder: five 2x upsampling stages (32x in total) on top
        # of the encoder feature map, followed by a 1x1 conv to one channel.
        self.decoder = nn.Sequential(
            UpsampleBlock(num_ftrs, 512), # /32 -> /16
            UpsampleBlock(512, 256),      # /16 -> /8
            UpsampleBlock(256, 128),      # /8 -> /4
            UpsampleBlock(128, 64),       # /4 -> /2
            UpsampleBlock(64, 32),        # /2 -> /1
            nn.Conv2d(32, 1, kernel_size=1) # Output: 1 channel mask
        )

    def forward(self, x):
        """Run all heads on a batch of images.

        Args:
            x: Image batch of shape (B, 3, H, W).

        Returns:
            dict with keys ``tvnt`` (B, 2), ``poi`` (B, 5), ``pni`` (B, 2),
            ``tb`` (B, 1), ``mi`` (B, 1) and ``doi`` (B, 1, H, W). The
            classification and segmentation entries are raw logits.
        """
        # DenseNet features are (B, 1664, H/32, W/32).
        # The ReLU is in-place, so the pooled heads AND the decoder both see
        # the rectified map (this was already the case before, only implicitly,
        # because the in-place op mutated the tensor later fed to the decoder).
        features = F.relu(self.backbone.features(x), inplace=True)

        # Global Average Pooling for Classification/Regression Heads
        pooled = torch.flatten(F.adaptive_avg_pool2d(features, (1, 1)), 1)

        # Task Outputs
        out_tvnt = self.head_tvnt(pooled)
        out_poi = self.head_poi(pooled)
        out_pni = self.head_pni(pooled)
        out_tb = self.head_tb(pooled)
        out_mi = self.head_mi(pooled)

        # Segmentation Output
        out_doi = self.decoder(features)
        # The decoder upsamples by exactly 32x, which only reproduces the
        # input resolution when H and W are multiples of 32; resize the logits
        # otherwise so the mask always aligns with the input image.
        if out_doi.shape[-2:] != x.shape[-2:]:
            out_doi = F.interpolate(out_doi, size=x.shape[-2:],
                                    mode='bilinear', align_corners=False)

        return {
            'tvnt': out_tvnt,
            'poi': out_poi,
            'pni': out_pni,
            'tb': out_tb,
            'mi': out_mi,
            'doi': out_doi
        }

print("Model Architecture Defined.")

Model Architecture Defined.


In [3]:
# 3. Logic & Metrics (DOI Calculation)

def calculate_doi_from_mask(mask_tensor, pixel_spacing_mm=0.00025):
    """
    Estimate Depth of Invasion (DOI) from a binary segmentation mask.

    Heuristic: DOI is the vertical distance between the top-most (superficial)
    and bottom-most (deepest) rows that contain tumour pixels. A real
    measurement would need a reference mucosal line; here the image is assumed
    to be oriented so that 'up' is superficial.

    Args:
        mask_tensor (torch.Tensor): A single mask of shape (H, W) or (1, H, W)
            (any singleton dimensions are ignored). Values are cast to uint8
            and every non-zero pixel counts as tumour, so pass an already
            binarised 0/1 mask rather than probabilities.
        pixel_spacing_mm (float): Physical size of one pixel in mm.

    Returns:
        float: Depth in mm; 0.0 when the mask contains no tumour pixels.

    Raises:
        ValueError: If a non-empty mask cannot be interpreted as one 2-D mask
            (for example a batch of several masks).
    """
    # detach() lets callers pass tensors that still track gradients.
    raw = mask_tensor.detach().cpu().numpy().astype(np.uint8)

    if not raw.any():
        return 0.0

    mask = raw.squeeze()
    if mask.ndim != 2:
        # squeeze() also drops a genuine height/width of 1; recover the 2-D
        # layout from the last two axes when everything in front is singleton.
        if raw.ndim >= 2 and raw.size == raw.shape[-2] * raw.shape[-1]:
            mask = raw.reshape(raw.shape[-2:])
        else:
            raise ValueError(
                f"Expected a single 2-D mask, got shape {tuple(raw.shape)}"
            )

    # Only the vertical extent matters, so reduce each row to "has tumour?".
    tumour_rows = np.flatnonzero(mask.any(axis=1))
    pixel_depth = int(tumour_rows[-1] - tumour_rows[0])

    return float(pixel_depth * pixel_spacing_mm)

print("DOI Calculation Function Defined.")

DOI Calculation Function Defined.


In [None]:
# 4. Real Dataset Loader

class OSCCRealDataset(Dataset):
    """Patch dataset yielding an image, five task labels and a DOI mask.

    Labels come from a CSV file with a ``filename`` column plus optional
    ``tvnt``, ``poi``, ``pni``, ``tb`` and ``mi`` columns. Without a CSV, every
    image found in ``img_dir`` is listed with placeholder labels of 0.

    NOTE: ``transform`` is applied to the image only; the mask is merely
    resized to ``mask_size``. Geometric augmentations inside ``transform``
    (flips, rotations, crops) therefore leave image and mask misaligned.

    Args:
        img_dir: Directory containing the image patches.
        mask_dir: Optional directory with masks named ``<image stem>.png``.
        csv_file: Optional path to the label CSV.
        transform: Optional callable applied to the PIL image.
        mask_size: ``(height, width)`` of the returned mask; also the size of
            the black placeholder image used when an image cannot be read.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    LABEL_COLUMNS = ('tvnt', 'poi', 'pni', 'tb', 'mi')

    def __init__(self, img_dir, mask_dir=None, csv_file=None, transform=None,
                 mask_size=(224, 224)):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_size = tuple(mask_size)

        # Load Labels
        if csv_file and os.path.exists(csv_file):
            self.df = pd.read_csv(csv_file)
            # Fail here with a clear message instead of a KeyError on first access.
            if 'filename' not in self.df.columns:
                raise ValueError(f"'{csv_file}' must contain a 'filename' column")
            print(f"Loaded {len(self.df)} samples from {csv_file}")
        else:
            # Fallback: List all images, set labels to default/dummy.
            # Matching is case-insensitive and the listing is sorted so the
            # sample order does not depend on the filesystem.
            if os.path.isdir(img_dir):
                self.image_files = sorted(
                    f for f in os.listdir(img_dir)
                    if f.lower().endswith(self.IMAGE_EXTENSIONS)
                )
            else:
                self.image_files = []
            self.df = pd.DataFrame({'filename': self.image_files})
            # Add default columns if missing
            for col in self.LABEL_COLUMNS:
                self.df[col] = 0
            print(f"No CSV found. Found {len(self.df)} images in '{img_dir}'. Using placeholder labels (0).")

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _read_label(row, column, cast):
        """Return ``cast(row[column])``, using 0 for a missing column or empty cell.

        Empty CSV cells are read as NaN, which would otherwise crash ``int()``
        or inject NaN into the regression loss.
        """
        value = row.get(column, 0)
        if pd.isna(value):
            value = 0
        return cast(value)

    def __getitem__(self, idx):
        """Return one sample as a dict of image, label tensors and mask.

        Keys: ``image`` (output of ``transform``, or the PIL image when no
        transform is set), ``tvnt``/``poi``/``pni`` (long scalars),
        ``tb``/``mi`` (float scalars) and ``doi`` (float mask of shape
        ``(1, height, width)`` with values scaled to 0-1).
        """
        row = self.df.iloc[idx]
        img_name = row['filename']
        img_path = os.path.join(self.img_dir, img_name)
        height, width = self.mask_size

        # 1. Load Image
        try:
            with Image.open(img_path) as handle:
                image = handle.convert("RGB")
        except (OSError, ValueError) as exc:
            # Best-effort fallback for a missing/unreadable file: keep the
            # pipeline running on a black image, but make the problem visible.
            import warnings
            warnings.warn(
                f"Could not read image '{img_path}' ({exc}); using a black placeholder.",
                RuntimeWarning,
            )
            image = Image.new('RGB', (width, height))

        # 2. Load Mask (if exists); an all-zero mask is used otherwise
        mask = np.zeros((height, width), dtype=np.float32)
        if self.mask_dir:
            # Assume mask has same name but png
            mask_name = os.path.splitext(img_name)[0] + ".png"
            mask_path = os.path.join(self.mask_dir, mask_name)
            if os.path.exists(mask_path):
                m = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                if m is not None:
                    # cv2.resize takes (width, height); nearest keeps the mask binary.
                    m = cv2.resize(m, (width, height), interpolation=cv2.INTER_NEAREST)
                    mask = m.astype(np.float32) / 255.0 # Normalize to 0-1

        # 3. Load Labels
        label_tvnt = self._read_label(row, 'tvnt', int)
        label_poi = self._read_label(row, 'poi', int)
        label_pni = self._read_label(row, 'pni', int)
        label_tb = self._read_label(row, 'tb', float)
        label_mi = self._read_label(row, 'mi', float)

        if self.transform:
            image = self.transform(image)

        mask = torch.as_tensor(mask, dtype=torch.float32).unsqueeze(0)

        return {
            'image': image,
            'tvnt': torch.tensor(label_tvnt, dtype=torch.long),
            'poi': torch.tensor(label_poi, dtype=torch.long),
            'pni': torch.tensor(label_pni, dtype=torch.long),
            'tb': torch.tensor(label_tb, dtype=torch.float),
            'mi': torch.tensor(label_mi, dtype=torch.float),
            'doi': mask
        }

# Configuration
DATASET_ROOT = "dataset"
IMG_DIR = os.path.join(DATASET_ROOT, "images")
MASK_DIR = os.path.join(DATASET_ROOT, "masks")
CSV_FILE = os.path.join(DATASET_ROOT, "labels.csv")

# Create directories if they don't exist (Helper for user)
os.makedirs(IMG_DIR, exist_ok=True)
os.makedirs(MASK_DIR, exist_ok=True)

# Transforms
# Only photometric augmentation is used here. The dataset applies this
# transform to the image alone, so a random flip or rotation would move the
# image without moving its DOI mask and corrupt the segmentation targets.
# Geometric augmentation has to be applied to image and mask jointly before
# it can be re-enabled.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    # ImageNet channel statistics, matching the pretrained backbone.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Initialize Dataset
train_dataset = OSCCRealDataset(IMG_DIR, MASK_DIR, CSV_FILE, transform=train_transform)

if len(train_dataset) > 0:
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    print("‚úÖ Real Dataset Loaded Successfully.")
else:
    print("‚ö†Ô∏è Dataset folder is empty. Please add images to 'dataset/images' and a 'labels.csv'.")
    # Create dummy loader to prevent crash if user runs without data.
    # 'dummy.jpg' does not need to exist: the dataset substitutes a black
    # image for files it cannot open.
    dummy_ds = OSCCRealDataset(IMG_DIR, csv_file=None, transform=train_transform)
    dummy_ds.df = pd.DataFrame({'filename': ['dummy.jpg'], 'tvnt':[0], 'poi':[0], 'pni':[0], 'tb':[0], 'mi':[0]})
    train_loader = DataLoader(dummy_ds, batch_size=1)

Dummy Dataset & DataLoader Created.


## 4. Dataset Configuration (Real Data)
**Instructions for User:**
To train on real data, organize your files as follows:
1.  **Images:** Put your `.jpg` or `.png` image patches in `dataset/images/`.
2.  **Masks (Optional):** Put segmentation masks in `dataset/masks/` (same filename as image, but `.png`).
3.  **Labels:** Create a `dataset/labels.csv` with the following columns:
    *   `filename`: e.g., "slide1_patch_0_0.jpg"
    *   `tvnt`: 0 (Normal) or 1 (Tumour)
    *   `poi`: 0-4 (Pattern of Invasion type)
    *   `pni`: 0 (No) or 1 (Yes)
    *   `tb`: Integer count of tumour buds
    *   `mi`: Integer count of mitotic figures

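*If no data is found, the code prints a warning and falls back to a single all-black placeholder sample with zero labels, so the pipeline still runs instead of crashing.*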

In [None]:
# 5. Training Loop

model = OSCCMultiTaskModel().to(DEVICE)

# Check if pretrained model exists and load it (Resume Training)
# NOTE(security): torch.load unpickles the file, so only resume from a
# checkpoint you created yourself or otherwise trust.
if os.path.exists("model_a.pth"):
    try:
        model.load_state_dict(torch.load("model_a.pth", map_location=DEVICE))
        print("‚úÖ Loaded existing model weights from model_a.pth. Resuming training...")
    except Exception as e:
        print(f"‚ö†Ô∏è Could not load model weights: {e}. Starting from scratch.")
else:
    print("üÜï No existing model found. Starting training from scratch.")

optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Loss Functions
criterion_cls = nn.CrossEntropyLoss() # For TVNT, POI, PNI
criterion_reg = nn.MSELoss()          # For TB, MI
criterion_seg = nn.BCEWithLogitsLoss() # For DOI Mask

NUM_EPOCHS = 2

print("Starting Training Loop...")

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0

    for batch in train_loader:
        images = batch['image'].to(DEVICE)

        # Move targets to device
        target_tvnt = batch['tvnt'].to(DEVICE)
        target_poi = batch['poi'].to(DEVICE)
        target_pni = batch['pni'].to(DEVICE)
        # Regression targets arrive as (B,); add a trailing axis so they match
        # the (B, 1) head outputs instead of broadcasting inside MSELoss.
        target_tb = batch['tb'].to(DEVICE).unsqueeze(1) # (B, 1)
        target_mi = batch['mi'].to(DEVICE).unsqueeze(1) # (B, 1)
        target_doi = batch['doi'].to(DEVICE)

        optimizer.zero_grad()

        # Forward Pass
        outputs = model(images)

        # Calculate Losses
        loss_tvnt = criterion_cls(outputs['tvnt'], target_tvnt)
        loss_poi = criterion_cls(outputs['poi'], target_poi)
        loss_pni = criterion_cls(outputs['pni'], target_pni)
        loss_tb = criterion_reg(outputs['tb'], target_tb)
        loss_mi = criterion_reg(outputs['mi'], target_mi)

        # Resize the mask logits to the target resolution if they differ.
        # With 224x224 inputs and masks the sizes already match and this is a
        # no-op; it only matters for other image or mask sizes.
        logits_doi = outputs['doi']
        if logits_doi.shape[-2:] != target_doi.shape[-2:]:
            logits_doi = F.interpolate(logits_doi, size=target_doi.shape[-2:],
                                       mode='bilinear', align_corners=False)
        loss_doi = criterion_seg(logits_doi, target_doi)

        # Total Loss (Weighted Sum - can tune weights later)
        total_loss = (loss_tvnt + loss_poi + loss_pni +
                      0.5 * loss_tb + 0.5 * loss_mi +
                      1.0 * loss_doi)

        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()

    # max(..., 1) keeps the report well-defined even for an empty loader.
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {running_loss/max(len(train_loader), 1):.4f}")

print("Training Complete.")



Starting Training Loop...
Epoch [1/2], Loss: 91.3662
Epoch [1/2], Loss: 91.3662
Epoch [2/2], Loss: 86.5585
Training Complete.
Epoch [2/2], Loss: 86.5585
Training Complete.


In [6]:
# 6. Export Model
# Persist only the learned parameters (state_dict) rather than the whole
# module object. The training cell looks for this same "model_a.pth" file
# and, when it exists, loads it to resume training.
save_path = "model_a.pth"
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")

Model saved to model_a.pth
