# Model A: Master Training Notebook
**Role:** Senior Computer Vision Engineer & MLOps Specialist
**Objective:** Train a Multi-Task Learning (MTL) model for Cancer Cell Feature Detection.

## Tasks (Based on 3-Class YOLO Dataset)
1.  **TVNT:** Abnormality Detection (Binary Classification - Any abnormal cells detected vs Normal)
2.  **Mitotic Figures:** Count of mitotic figure instances (Regression)
3.  **Multiple Nucleol:** Count of multiple-nucleoli instances (Regression)
4.  **Nuclear Hyperchromatism:** Count of nuclear hyperchromatism instances (Regression)

## Classes
- Class 0: Mitotic Figures
- Class 1: Multiple Nucleol  
- Class 2: Nuclear Hyperchromatism

In [16]:
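import os
from roboflow import Roboflow

# Read the API key from an environment variable rather than committing it to the notebook
rf = Roboflow(api_key=os.environ["ROBOFLOW_API_KEY"])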
project = rf.workspace("segp-fcn6m").project("oral-cancer-1mnve-n5yij")
version = project.version(2)
dataset = version.download("yolov8")

print(f"\n✅ Dataset downloaded to: {dataset.location}")

loading Roboflow workspace...
loading Roboflow project...
Downloading Dataset Version Zip in Oral-Cancer-2 to yolov8:: 100%|██████████| 41777/41777 [00:04<00:00, 10072.75it/s]
Extracting Dataset Version Zip to Oral-Cancer-2 in yolov8:: 100%|██████████| 1100/1100 [00:01<00:00, 1011.96it/s]

✅ Dataset downloaded to: c:\Users\user\Desktop\SEGP\Oral-Health-Computer-Vision-Model\Model A\Oral-Cancer-2


In [18]:
# 0.5 Convert YOLO Dataset to Model A Format
# This parses the YOLO annotations and creates labels.csv with correct counts

import os
import shutil
import yaml
import pandas as pd
from pathlib import Path

# Configuration - Update this if your download folder has a different name
YOLO_DATASET_PATH = "Oral-Cancer-2"  # Roboflow downloads with version number
OUTPUT_DATASET_PATH = "dataset"
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}

# Class IDs
CLASS_MITOTIC = 0       # Mitotic Figures
CLASS_NUCLEOL = 1       # Multiple Nucleol
CLASS_HYPERCHROM = 2    # Nuclear Hyperchromatism

def parse_yolo_label(label_path):
    """Parse YOLO label file and count each class."""
    result = {'has_objects': False, 'mitotic': 0, 'nucleol': 0, 'hyperchrom': 0}
    
    if not os.path.exists(label_path):
        return result
    
    with open(label_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) >= 5:
                class_id = int(parts[0])
                result['has_objects'] = True
                if class_id == CLASS_MITOTIC:
                    result['mitotic'] += 1
                elif class_id == CLASS_NUCLEOL:
                    result['nucleol'] += 1
                elif class_id == CLASS_HYPERCHROM:
                    result['hyperchrom'] += 1
    return result

# Create output directories
os.makedirs(os.path.join(OUTPUT_DATASET_PATH, "images"), exist_ok=True)

# Process all splits
records = []
for split in ['train', 'valid', 'test']:
    images_dir = os.path.join(YOLO_DATASET_PATH, split, 'images')
    labels_dir = os.path.join(YOLO_DATASET_PATH, split, 'labels')
    
    if not os.path.exists(images_dir):
        print(f"  Skipping {split} (not found)")
        continue
    
    image_files = [f for f in os.listdir(images_dir) if Path(f).suffix.lower() in IMAGE_EXTENSIONS]
    print(f"  Processing {split}: {len(image_files)} images")
    
    for img_file in image_files:
        # Copy image
        src_path = os.path.join(images_dir, img_file)
        dst_path = os.path.join(OUTPUT_DATASET_PATH, "images", img_file)
        shutil.copy2(src_path, dst_path)
        
        # Parse labels
        label_file = Path(img_file).stem + '.txt'
        label_path = os.path.join(labels_dir, label_file)
        counts = parse_yolo_label(label_path)
        
        # TVNT: 1 if any abnormality detected
        tvnt = 1 if counts['has_objects'] else 0
        
        records.append({
            'filename': img_file,
            'tvnt': tvnt,
            'mitotic': counts['mitotic'],
            'nucleol': counts['nucleol'],
            'hyperchrom': counts['hyperchrom']
        })

# Save labels.csv
df = pd.DataFrame(records)
csv_path = os.path.join(OUTPUT_DATASET_PATH, "labels.csv")
df.to_csv(csv_path, index=False)

print(f"\n✅ Dataset Conversion Complete!")
print(f"   Total images: {len(df)}")
print(f"   CSV saved to: {csv_path}")

if len(df) > 0:
    print(f"\n📊 Class Distribution:")
    print(f"   TVNT: Normal={sum(df['tvnt']==0)}, Abnormal={sum(df['tvnt']==1)}")
    print(f"   Mitotic Figures:     {sum(df['mitotic'] > 0)} images with annotations, Total count: {df['mitotic'].sum()}")
    print(f"   Multiple Nucleol:    {sum(df['nucleol'] > 0)} images with annotations, Total count: {df['nucleol'].sum()}")
    print(f"   Hyperchromatism:     {sum(df['hyperchrom'] > 0)} images with annotations, Total count: {df['hyperchrom'].sum()}")
else:
    print(f"\n⚠️ WARNING: No images found!")
    print(f"   Please check that '{YOLO_DATASET_PATH}' contains train/valid/test folders with images.")

  Processing train: 474 images
  Processing valid: 44 images
  Processing test: 26 images

✅ Dataset Conversion Complete!
   Total images: 544
   CSV saved to: dataset\labels.csv

📊 Class Distribution:
   TVNT: Normal=18, Abnormal=526
   Mitotic Figures:     37 images with annotations, Total count: 48
   Multiple Nucleol:    369 images with annotations, Total count: 2523
   Hyperchromatism:     428 images with annotations, Total count: 1848

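The label rule used in cell 0.5 can be checked in isolation on an in-memory label string. Each YOLO label line has the form `class_id x_center y_center width height` (normalised coordinates); a tile is marked abnormal (TVNT = 1) as soon as it has any valid box, and each box increments its class count. The function name `yolo_text_to_targets` and the sample label text below are hypothetical, written only to illustrate the mapping.

```python
# Minimal sketch of the cell 0.5 label rule, assuming this notebook's class IDs:
# 0 = Mitotic Figures, 1 = Multiple Nucleol, 2 = Nuclear Hyperchromatism.
CLASS_TO_COLUMN = {0: 'mitotic', 1: 'nucleol', 2: 'hyperchrom'}

def yolo_text_to_targets(label_text):
    """Return TVNT and per-class counts for one image's YOLO label text (hypothetical helper)."""
    targets = {'tvnt': 0, 'mitotic': 0, 'nucleol': 0, 'hyperchrom': 0}
    for line in label_text.splitlines():
        parts = line.split()
        if len(parts) < 5:              # skip blank or malformed lines
            continue
        targets['tvnt'] = 1             # any valid box marks the tile abnormal
        column = CLASS_TO_COLUMN.get(int(parts[0]))
        if column is not None:          # unknown class IDs are not counted
            targets[column] += 1
    return targets

example = """1 0.41 0.52 0.05 0.06
1 0.63 0.18 0.04 0.05
2 0.22 0.71 0.07 0.07
"""
print(yolo_text_to_targets(example))   # {'tvnt': 1, 'mitotic': 0, 'nucleol': 2, 'hyperchrom': 1}
print(yolo_text_to_targets(""))        # {'tvnt': 0, 'mitotic': 0, 'nucleol': 0, 'hyperchrom': 0}
```

An image with no label file, or an empty one, yields all zeros, which is how the 18 Normal tiles above arise.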

In [19]:
# 1. Imports & Setup
import os
import time
import random
import numpy as np
import pandas as pd  # Added pandas for CSV handling
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import cv2

# Set seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

Using device: cuda


In [20]:
# 2. Model Definition (DenseNet169 Backbone) - Simplified for 3-Class Detection

class OSCCMultiTaskModel(nn.Module):
    """
    Multi-Task Model for OSCC Cell Feature Detection
    
    Tasks:
    1. TVNT: Binary classification (Abnormality detected vs Normal)
    2. Mitotic Figures Count: Regression
    3. Multiple Nucleol Count: Regression
    4. Nuclear Hyperchromatism Count: Regression
    """
    def __init__(self):
        super().__init__()
        # Backbone: DenseNet169
        self.backbone = models.densenet169(weights=models.DenseNet169_Weights.IMAGENET1K_V1)
        num_ftrs = self.backbone.classifier.in_features
        
        # Remove original classifier
        self.backbone.classifier = nn.Identity()
        
        # --- HEADS ---
        
        # 1. TVNT (Binary: Abnormality Detected vs Normal)
        self.head_tvnt = nn.Sequential(
            nn.Linear(num_ftrs, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 2) 
        )
        
        # 2. Mitotic Figures Count (Regression)
        self.head_mitotic = nn.Sequential(
            nn.Linear(num_ftrs, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)
        )
        
        # 3. Multiple Nucleol Count (Regression)
        self.head_nucleol = nn.Sequential(
            nn.Linear(num_ftrs, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)
        )
        
        # 4. Nuclear Hyperchromatism Count (Regression)
        self.head_hyperchrom = nn.Sequential(
            nn.Linear(num_ftrs, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # Extract features
        features = self.backbone.features(x)
        
        # Global Average Pooling for Classification/Regression Heads
        pooled = F.relu(features, inplace=True)
        pooled = F.adaptive_avg_pool2d(pooled, (1, 1))
        pooled = torch.flatten(pooled, 1)
        
        # Task Outputs
        return {
            'tvnt': self.head_tvnt(pooled),
            'mitotic': self.head_mitotic(pooled),
            'nucleol': self.head_nucleol(pooled),
            'hyperchrom': self.head_hyperchrom(pooled)
        }

print("Model Architecture Defined (3-Class Cell Feature Detection).")

Model Architecture Defined (3-Class Cell Feature Detection).
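The four task heads are small MLPs on the pooled backbone features. As a worked size check, the arithmetic below counts their trainable parameters, assuming `num_ftrs == 1664` (the classifier input width of torchvision's `densenet169`). `ReLU` and `Dropout` contribute no parameters, so only the `nn.Linear` layers are counted; `linear_params` is a hypothetical helper.

```python
# Worked arithmetic: trainable parameters in the task heads of OSCCMultiTaskModel.
# Assumption: num_ftrs == 1664 for torchvision's DenseNet-169.
NUM_FTRS = 1664

def linear_params(in_features, out_features):
    """Weights plus biases of an nn.Linear(in_features, out_features) layer."""
    return in_features * out_features + out_features

tvnt_head = linear_params(NUM_FTRS, 256) + linear_params(256, 2)
count_head = linear_params(NUM_FTRS, 128) + linear_params(128, 1)  # mitotic, nucleol, hyperchrom each
total_heads = tvnt_head + 3 * count_head

print(tvnt_head, count_head, total_heads)  # 426754 213249 1066501
```

So the heads add roughly 1.07M parameters on top of the pretrained backbone.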


In [11]:
# 3. Class Names Mapping

CLASS_NAMES = {
    0: 'Mitotic Figures',
    1: 'Multiple Nucleol',
    2: 'Nuclear Hyperchromatism'
}

print("Class Names Defined:", CLASS_NAMES)

Class Names Defined: {0: 'Mitotic Figures', 1: 'Multiple Nucleol', 2: 'Nuclear Hyperchromatism'}


In [21]:
# 4. Real Dataset Loader (Updated for 3-Class Detection)

from torch.utils.data import random_split

class OSCCRealDataset(Dataset):
    """
    Dataset loader for OSCC Cell Feature Detection
    
    Expected CSV columns:
    - filename: image filename
    - tvnt: 0 (Normal) or 1 (Abnormality detected)
    - mitotic: Count of mitotic figures
    - nucleol: Count of multiple nucleol
    - hyperchrom: Count of nuclear hyperchromatism
    """
    def __init__(self, img_dir, csv_file=None, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        
        # Load Labels
        if csv_file and os.path.exists(csv_file):
            self.df = pd.read_csv(csv_file)
            print(f"Loaded {len(self.df)} samples from {csv_file}")
            
            # Map old column names to new if needed
            column_mapping = {
                'mi': 'mitotic',  # Old adapter used 'mi' for mitotic figures
            }
            for old_col, new_col in column_mapping.items():
                if old_col in self.df.columns and new_col not in self.df.columns:
                    self.df[new_col] = self.df[old_col]
            
            # Ensure all required columns exist
            for col in ['tvnt', 'mitotic', 'nucleol', 'hyperchrom']:
                if col not in self.df.columns:
                    self.df[col] = 0
                    
        else:
            # Fallback: List all images, set labels to default/dummy
            self.image_files = [f for f in os.listdir(img_dir) if f.endswith(('.jpg', '.png'))] if os.path.exists(img_dir) else []
            self.df = pd.DataFrame({'filename': self.image_files})
            for col in ['tvnt', 'mitotic', 'nucleol', 'hyperchrom']:
                self.df[col] = 0
            print(f"No CSV found. Found {len(self.df)} images in '{img_dir}'. Using placeholder labels.")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_name = row['filename']
        img_path = os.path.join(self.img_dir, img_name)
        
        # Load Image
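        try:
            image = Image.open(img_path).convert("RGB")
        except OSError:
            # Missing or unreadable file (FileNotFoundError and PIL's UnidentifiedImageError
            # are both OSError subclasses): warn and fall back to a black placeholder
            print(f"Warning: could not read {img_path}; using a blank 224x224 image")
            image = Image.new('RGB', (224, 224))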
        
        # Load Labels
        label_tvnt = int(row.get('tvnt', 0))
        label_mitotic = float(row.get('mitotic', 0.0))
        label_nucleol = float(row.get('nucleol', 0.0))
        label_hyperchrom = float(row.get('hyperchrom', 0.0))

        if self.transform:
            image = self.transform(image)
        
        return {
            'image': image,
            'tvnt': torch.tensor(label_tvnt, dtype=torch.long),
            'mitotic': torch.tensor(label_mitotic, dtype=torch.float),
            'nucleol': torch.tensor(label_nucleol, dtype=torch.float),
            'hyperchrom': torch.tensor(label_hyperchrom, dtype=torch.float)
        }

# Configuration
DATASET_ROOT = "dataset"
IMG_DIR = os.path.join(DATASET_ROOT, "images")
CSV_FILE = os.path.join(DATASET_ROOT, "labels.csv")

# Create directory if it doesn't exist
os.makedirs(IMG_DIR, exist_ok=True)

# Transforms
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Initialize Full Dataset (with train transforms initially for splitting)
full_dataset = OSCCRealDataset(IMG_DIR, CSV_FILE, transform=train_transform)

if len(full_dataset) > 0:
    # Split into train (80%) and validation (20%)
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    
    train_dataset, val_dataset = random_split(
        full_dataset, 
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )
    
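    # Caveat: both subsets share full_dataset, so val_dataset also applies train_transform
    # (random flips, rotation, colour jitter). Validation metrics are therefore computed on
    # augmented images; applying val_transform needs a per-split wrapper (not done here).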
    
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)
    
    print("✅ Dataset Split Successfully.")
    print(f"   Training samples:   {len(train_dataset)}")
    print(f"   Validation samples: {len(val_dataset)}")
    
    # Print class distribution
    print(f"\n📊 Dataset Statistics:")
    df = full_dataset.df
    print(f"   TVNT Distribution: Normal={sum(df['tvnt']==0)}, Abnormal={sum(df['tvnt']==1)}")
    print(f"   Mitotic Figures:   Min={df['mitotic'].min()}, Max={df['mitotic'].max()}, Mean={df['mitotic'].mean():.2f}")
    print(f"   Multiple Nucleol:  Min={df['nucleol'].min()}, Max={df['nucleol'].max()}, Mean={df['nucleol'].mean():.2f}")
    print(f"   Hyperchromatism:   Min={df['hyperchrom'].min()}, Max={df['hyperchrom'].max()}, Mean={df['hyperchrom'].mean():.2f}")
else:
    print("⚠️ Dataset folder is empty. Please run the conversion cell (0.5) or adapter_script.py first.")
    train_loader = None
    val_loader = None

Loaded 544 samples from dataset\labels.csv
✅ Dataset Split Successfully.
   Training samples:   435
   Validation samples: 109

📊 Dataset Statistics:
   TVNT Distribution: Normal=18, Abnormal=526
   Mitotic Figures:   Min=0, Max=4, Mean=0.09
   Multiple Nucleol:  Min=0, Max=36, Mean=4.64
   Hyperchromatism:   Min=0, Max=42, Mean=3.40

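Cell 4 notes that the validation subset should use `val_transform`. One way to do that, sketched under the assumption that `OSCCRealDataset` is created with `transform=None` (so it returns PIL images), is a thin wrapper that applies a per-split transform to the `'image'` entry. `TransformedSubset` is a hypothetical helper, not a PyTorch class; `DataLoader` accepts any map-style object that implements `__len__` and `__getitem__`.

```python
class TransformedSubset:
    """Hypothetical helper: a view of `base` restricted to `indices`,
    applying `transform` to each sample's 'image' entry."""

    def __init__(self, base, indices, transform=None):
        self.base = base
        self.indices = list(indices)
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        sample = dict(self.base[self.indices[i]])  # shallow copy; the base sample is not mutated
        if self.transform is not None:
            sample['image'] = self.transform(sample['image'])
        return sample


# Usage sketch with the notebook's names (not executed here):
# base = OSCCRealDataset(IMG_DIR, CSV_FILE, transform=None)
# split_train, split_val = random_split(base, [train_size, val_size],
#                                       generator=torch.Generator().manual_seed(42))
# train_dataset = TransformedSubset(base, split_train.indices, train_transform)
# val_dataset = TransformedSubset(base, split_val.indices, val_transform)
```

Because `random_split` draws its permutation from the dataset length and the seeded generator, this should reproduce the same index split as cell 4 while keeping augmentation out of validation.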

## 4. Dataset Configuration (3-Class Cell Feature Detection)
**Instructions for User:**
To train on real data, first run the conversion cell (0.5) above or `adapter_script.py`. Either one will:
1. Copy all images to `dataset/images/`
2. Create `dataset/labels.csv` with the following columns:
    *   `filename`: e.g., "OSCC_400x_1_jpg.rf.xxxxx.jpg"
    *   `tvnt`: 0 (Normal) or 1 (Abnormality detected - any of the 3 classes present)
    *   `mitotic`: Count of Mitotic Figures in the image
    *   `nucleol`: Count of Multiple Nucleol in the image
    *   `hyperchrom`: Count of Nuclear Hyperchromatism in the image

**Classes:**
- **Mitotic Figures**: Cells undergoing mitosis (cell division)
- **Multiple Nucleol**: Cells with multiple nucleoli
- **Nuclear Hyperchromatism**: Cells with abnormally dark nuclei

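Before training, the CSV schema above can be sanity-checked. The sketch below uses only the standard library and assumes that, as in cell 0.5, `tvnt` should be 1 exactly when at least one of the three counts is positive (cell 0.5 sets it from "any box present", which is equivalent here because the dataset has only classes 0 to 2). `check_labels_csv` is a hypothetical helper.

```python
import csv
import io

REQUIRED_COLUMNS = ['filename', 'tvnt', 'mitotic', 'nucleol', 'hyperchrom']
COUNT_COLUMNS = ['mitotic', 'nucleol', 'hyperchrom']

def check_labels_csv(fileobj):
    """Hypothetical sanity check for dataset/labels.csv.
    Returns a list of problem descriptions (empty if none were found)."""
    reader = csv.DictReader(fileobj)
    missing = [c for c in REQUIRED_COLUMNS if c not in (reader.fieldnames or [])]
    if missing:
        return [f"missing columns: {missing}"]
    problems = []
    for row in reader:
        counts = [int(row[c]) for c in COUNT_COLUMNS]
        if any(n < 0 for n in counts):
            problems.append(f"{row['filename']}: negative count")
        # tvnt should be 1 exactly when at least one of the three classes is present
        if int(row['tvnt']) != int(any(n > 0 for n in counts)):
            problems.append(f"{row['filename']}: tvnt inconsistent with counts")
    return problems

sample = io.StringIO(
    "filename,tvnt,mitotic,nucleol,hyperchrom\n"
    "a.jpg,1,0,3,2\n"
    "b.jpg,0,0,0,0\n"
    "c.jpg,0,1,0,0\n"
)
print(check_labels_csv(sample))  # ['c.jpg: tvnt inconsistent with counts']
```

In the notebook, the same function could be called as `check_labels_csv(open(CSV_FILE, newline=''))`.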
In [22]:
# 5. Training Loop (Updated for 3-Class Detection with Validation)

model = OSCCMultiTaskModel().to(DEVICE)

# Training starts from ImageNet weights: checkpoints from the earlier architecture
# do not match the new 3-class heads, so no resume is attempted.
print("🆕 Starting training with new 3-class architecture...")

optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

# Loss Functions
criterion_cls = nn.CrossEntropyLoss()  # For TVNT (binary classification)
criterion_reg = nn.MSELoss()           # For count regression tasks

NUM_EPOCHS = 50
best_val_loss = float('inf')

def calculate_batch_loss(outputs, batch, device):
    """Calculate total loss for a batch."""
    target_tvnt = batch['tvnt'].to(device)
    target_mitotic = batch['mitotic'].to(device).unsqueeze(1)
    target_nucleol = batch['nucleol'].to(device).unsqueeze(1)
    target_hyperchrom = batch['hyperchrom'].to(device).unsqueeze(1)
    
    loss_tvnt = criterion_cls(outputs['tvnt'], target_tvnt)
    loss_mitotic = criterion_reg(outputs['mitotic'], target_mitotic)
    loss_nucleol = criterion_reg(outputs['nucleol'], target_nucleol)
    loss_hyperchrom = criterion_reg(outputs['hyperchrom'], target_hyperchrom)
    
    total_loss = (2.0 * loss_tvnt + 
                  0.5 * loss_mitotic + 
                  0.5 * loss_nucleol + 
                  0.5 * loss_hyperchrom)
    
    return total_loss, loss_tvnt, loss_mitotic + loss_nucleol + loss_hyperchrom

if train_loader is None:
    print("❌ Cannot train: No data loaded. Please run the conversion cell (0.5) or adapter_script.py first.")
else:
    print(f"Starting Training Loop for {NUM_EPOCHS} epochs...")
    print(f"Training on {len(train_dataset)} samples, Validating on {len(val_dataset)} samples\n")
    
    history = {'train_loss': [], 'val_loss': [], 'train_tvnt': [], 'val_tvnt': []}
    
    for epoch in range(NUM_EPOCHS):
        # ============ Training Phase ============
        model.train()
        running_loss = 0.0
        running_tvnt_loss = 0.0
        running_count_loss = 0.0
        
        for batch in train_loader:
            images = batch['image'].to(DEVICE)
            optimizer.zero_grad()
            
            outputs = model(images)
            total_loss, loss_tvnt, loss_counts = calculate_batch_loss(outputs, batch, DEVICE)
            
            total_loss.backward()
            optimizer.step()
            
            running_loss += total_loss.item()
            running_tvnt_loss += loss_tvnt.item()
            running_count_loss += loss_counts.item()
        
        avg_train_loss = running_loss / len(train_loader)
        avg_train_tvnt = running_tvnt_loss / len(train_loader)
        avg_train_count = running_count_loss / len(train_loader)
        
        # ============ Validation Phase ============
        model.eval()
        val_loss = 0.0
        val_tvnt_loss = 0.0
        
        with torch.no_grad():
            for batch in val_loader:
                images = batch['image'].to(DEVICE)
                outputs = model(images)
                total_loss, loss_tvnt, _ = calculate_batch_loss(outputs, batch, DEVICE)
                val_loss += total_loss.item()
                val_tvnt_loss += loss_tvnt.item()
        
        avg_val_loss = val_loss / len(val_loader)
        avg_val_tvnt = val_tvnt_loss / len(val_loader)
        
        # Track history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['train_tvnt'].append(avg_train_tvnt)
        history['val_tvnt'].append(avg_val_tvnt)
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), "model_a_best.pth")
            best_marker = " ⭐ (Best)"
        else:
            best_marker = ""
        
        scheduler.step()
        
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] | "
              f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | "
              f"TVNT: {avg_train_tvnt:.4f}/{avg_val_tvnt:.4f}{best_marker}")

    print("\n✅ Training Complete!")
    print(f"   Best Validation Loss: {best_val_loss:.4f}")

🆕 Starting training with new 3-class architecture...
Starting Training Loop for 50 epochs...
Training on 435 samples, Validating on 109 samples

Epoch [1/50] | Train Loss: 33.5083 | Val Loss: 28.9263 | TVNT: 0.1675/0.1584 ⭐ (Best)
Epoch [2/50] | Train Loss: 24.2585 | Val Loss: 25.3663 | TVNT: 0.1353/0.1613 ⭐ (Best)
Epoch [3/50] | Train Loss: 21.7368 | Val Loss: 24.9232 | TVNT: 0.1418/0.1842 ⭐ (Best)
Epoch [4/50] | Train Loss: 23.5644 | Val Loss: 22.2236 | TVNT: 0.1406/0.1491 ⭐ (Best)
Epoch [5/50] | Train Loss: 20.7442 | Val Loss: 22.3582 | TVNT: 0.1320/0.1774
Epoch [6/50] | Train Loss: 16.9058 | Val Loss: 22.6454 | TVNT: 0.1164/0.1378
Epoch [7/50] | Train Loss: 15.4374 | Val Loss: 19.8306 | TVNT: 0.1076/0.1655 ⭐ (Best)
Epoch [8/50] | Train Loss: 12.2294 | Val Loss: 18.1739 | TVNT: 0.1105/0.1396 ⭐ (Best)
Epoch [9/50] | Train Loss: 13.0885 | Val Loss: 15.8438 | TVNT: 0.1061/0.1394 ⭐ (Best)
Epoch [10/50] | Train Loss: 11.9894 | Val Loss: 17.3265 | TVNT: 0.1069/0.1412
Epoc
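`StepLR(step_size=3, gamma=0.5)` halves the learning rate every three epochs, and over 50 epochs that compounds quickly. The pure-Python replica below evaluates StepLR's closed-form rule, `lr = base_lr * gamma ** (epoch // step_size)`, assuming `scheduler.step()` is called once per epoch as in the loop above; `step_lr` is a hypothetical helper.

```python
def step_lr(base_lr, epoch, step_size=3, gamma=0.5):
    """LR in effect during 0-indexed `epoch` when StepLR.step() runs once per epoch."""
    return base_lr * gamma ** (epoch // step_size)

base_lr, epochs = 1e-4, 50
schedule = [step_lr(base_lr, e) for e in range(epochs)]

first_below_1e6 = next(e for e, lr in enumerate(schedule) if lr < 1e-6)
print(f"LR first drops below 1e-6 in epoch {first_below_1e6 + 1}")  # epoch 22
print(f"LR in final epoch: {schedule[-1]:.3e}")                      # 1.526e-09
```

With this schedule the last 29 epochs run at a learning rate below 1e-6 and are unlikely to change the weights much; a larger `step_size` or a different schedule may be worth trying, though that is a judgement call not tested here.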

In [23]:
# 6. Evaluation Metrics (Comprehensive Metrics for Each Task)

from sklearn.metrics import (
    roc_auc_score, f1_score, accuracy_score, precision_score, recall_score,
    confusion_matrix, classification_report,
    mean_absolute_error, mean_squared_error, r2_score
)
import warnings
warnings.filterwarnings('ignore')

def evaluate_model(model, dataloader, device, dataset_name="Validation"):
    """
    Comprehensive evaluation of the multi-task model.
    
    Metrics by Task:
    ================
    TVNT (Binary Classification):
        - AUC-ROC Score: Area under ROC curve (tile-level)
        - F1 Score: Harmonic mean of precision and recall
        - Accuracy, Precision, Recall
        - Confusion Matrix
    
    Count Regression (Mitotic, Nucleol, Hyperchromatism):
        - MAE: Mean Absolute Error
        - RMSE: Root Mean Squared Error
        - R² Score: Coefficient of determination
        - Count Accuracy: Exact match, within ±1, within ±2
        - Mean Count Error: Average signed error (bias indicator)
    """
    model.eval()
    
    # Storage for predictions and targets
    all_tvnt_probs = []
    all_tvnt_preds = []
    all_tvnt_targets = []
    
    all_mitotic_preds = []
    all_mitotic_targets = []
    
    all_nucleol_preds = []
    all_nucleol_targets = []
    
    all_hyperchrom_preds = []
    all_hyperchrom_targets = []
    
    with torch.no_grad():
        for batch in dataloader:
            images = batch['image'].to(device)
            
            # Get model outputs
            outputs = model(images)
            
            # TVNT predictions
            tvnt_probs = F.softmax(outputs['tvnt'], dim=1)[:, 1].cpu().numpy()
            tvnt_preds = torch.argmax(outputs['tvnt'], dim=1).cpu().numpy()
            tvnt_targets = batch['tvnt'].numpy()
            
            all_tvnt_probs.extend(tvnt_probs)
            all_tvnt_preds.extend(tvnt_preds)
            all_tvnt_targets.extend(tvnt_targets)
            
            # Count predictions (regression)
            mitotic_preds = outputs['mitotic'].cpu().numpy().flatten()
            mitotic_targets = batch['mitotic'].numpy()
            all_mitotic_preds.extend(mitotic_preds)
            all_mitotic_targets.extend(mitotic_targets)
            
            nucleol_preds = outputs['nucleol'].cpu().numpy().flatten()
            nucleol_targets = batch['nucleol'].numpy()
            all_nucleol_preds.extend(nucleol_preds)
            all_nucleol_targets.extend(nucleol_targets)
            
            hyperchrom_preds = outputs['hyperchrom'].cpu().numpy().flatten()
            hyperchrom_targets = batch['hyperchrom'].numpy()
            all_hyperchrom_preds.extend(hyperchrom_preds)
            all_hyperchrom_targets.extend(hyperchrom_targets)
    
    # Convert to numpy arrays
    all_tvnt_probs = np.array(all_tvnt_probs)
    all_tvnt_preds = np.array(all_tvnt_preds)
    all_tvnt_targets = np.array(all_tvnt_targets)
    
    all_mitotic_preds = np.array(all_mitotic_preds)
    all_mitotic_targets = np.array(all_mitotic_targets)
    
    all_nucleol_preds = np.array(all_nucleol_preds)
    all_nucleol_targets = np.array(all_nucleol_targets)
    
    all_hyperchrom_preds = np.array(all_hyperchrom_preds)
    all_hyperchrom_targets = np.array(all_hyperchrom_targets)
    
    # ============ TVNT Metrics (Classification) ============
    print("=" * 65)
    print(f"üìä TVNT (Abnormality Detection) - {dataset_name} Set")
    print("=" * 65)
    
    # AUC-ROC (requires both classes present)
    try:
        if len(np.unique(all_tvnt_targets)) > 1:
            tvnt_auc = roc_auc_score(all_tvnt_targets, all_tvnt_probs)
            print(f"   üéØ Tile-level AUC-ROC:   {tvnt_auc:.4f}")
        else:
            print(f"   üéØ Tile-level AUC-ROC:   N/A (only one class in data)")
            tvnt_auc = None
    except:
        tvnt_auc = None
        print(f"   üéØ Tile-level AUC-ROC:   N/A")
    
    tvnt_f1 = f1_score(all_tvnt_targets, all_tvnt_preds, average='binary', zero_division=0)
    tvnt_acc = accuracy_score(all_tvnt_targets, all_tvnt_preds)
    tvnt_precision = precision_score(all_tvnt_targets, all_tvnt_preds, zero_division=0)
    tvnt_recall = recall_score(all_tvnt_targets, all_tvnt_preds, zero_division=0)
    
    print(f"   üéØ F1 Score:             {tvnt_f1:.4f}")
    print(f"   üìà Accuracy:             {tvnt_acc:.4f} ({int(tvnt_acc*len(all_tvnt_targets))}/{len(all_tvnt_targets)})")
    print(f"   üìà Precision:            {tvnt_precision:.4f}")
    print(f"   üìà Recall/Sensitivity:   {tvnt_recall:.4f}")
    
    # Confusion Matrix
    cm = confusion_matrix(all_tvnt_targets, all_tvnt_preds)
    print(f"\n   Confusion Matrix:")
    print(f"                      Predicted")
    print(f"                    Normal  Abnormal")
    if len(cm) > 1:
        print(f"   Actual Normal     {cm[0][0]:5d}    {cm[0][1]:5d}")
        print(f"   Actual Abnormal   {cm[1][0]:5d}    {cm[1][1]:5d}")
    else:
        print(f"   (Only one class present in targets)")
    
    # ============ Regression Metrics Function ============
    def calc_regression_metrics(preds, targets, task_name, class_label):
        """Calculate regression metrics for count prediction tasks."""
        print(f"\n{'=' * 65}")
        print(f"üìä {task_name} ({class_label}) - Count Metrics")
        print("=" * 65)
        
        # Clip predictions to non-negative (counts can't be negative)
        preds_clipped = np.maximum(preds, 0)
        preds_rounded = np.round(preds_clipped)
        
        # Mean Absolute Error, in cells per image
        mae = mean_absolute_error(targets, preds_clipped)
        print(f"   🎯 Mean Absolute Error (MAE):     {mae:.4f} cells")
        
        # Root Mean Squared Error
        rmse = np.sqrt(mean_squared_error(targets, preds_clipped))
        print(f"   📈 Root Mean Squared Error:       {rmse:.4f}")
        
        # R² (coefficient of determination); negative values mean worse than predicting the mean count
        try:
            if np.var(targets) > 0:
                r2 = r2_score(targets, preds_clipped)
                print(f"   📈 R² Score (Agreement):          {r2:.4f}")
            else:
                r2 = None
                print(f"   📈 R² Score:                      N/A (no variance)")
        except ValueError:
            r2 = None
            print(f"   📈 R² Score:                      N/A")
        
        # Count tolerance metrics on rounded predictions
        exact_matches = np.sum(preds_rounded == targets)
        within_1 = np.sum(np.abs(preds_rounded - targets) <= 1)
        within_2 = np.sum(np.abs(preds_rounded - targets) <= 2)
        total = len(targets)
        
        print(f"\n   Count Precision Metrics:")
        print(f"   🎯 Exact Match Rate:              {exact_matches/total*100:.2f}% ({exact_matches}/{total})")
        print(f"   📈 Within ±1 Count:               {within_1/total*100:.2f}% ({within_1}/{total})")
        print(f"   📈 Within ±2 Count:               {within_2/total*100:.2f}% ({within_2}/{total})")
        
        # Mean Count Error (bias indicator)
        mean_error = np.mean(preds_rounded - targets)
        print(f"   📊 Mean Count Error (Bias):       {mean_error:+.4f}")
        
        # Distribution summary
        print(f"\n   Distribution Summary:")
        print(f"   Ground Truth: [{targets.min():.0f} - {targets.max():.0f}], Mean: {targets.mean():.2f}, Std: {targets.std():.2f}")
        print(f"   Predictions:  [{preds_clipped.min():.2f} - {preds_clipped.max():.2f}], Mean: {preds_clipped.mean():.2f}")
        
        return {
            'mae': mae,
            'rmse': rmse,
            'r2': r2,
            'exact_match_rate': exact_matches/total,
            'within_1_rate': within_1/total,
            'within_2_rate': within_2/total,
            'mean_error': mean_error
        }
    
    # Calculate metrics for each count task
    mitotic_metrics = calc_regression_metrics(
        all_mitotic_preds, all_mitotic_targets, 
        "Mitotic Figures Count", "Class 0"
    )
    nucleol_metrics = calc_regression_metrics(
        all_nucleol_preds, all_nucleol_targets, 
        "Multiple Nucleol Count", "Class 1"
    )
    hyperchrom_metrics = calc_regression_metrics(
        all_hyperchrom_preds, all_hyperchrom_targets, 
        "Nuclear Hyperchromatism Count", "Class 2"
    )
    
    # ============ Summary Table ============
    print(f"\n{'=' * 65}")
    print(f"üìã EVALUATION SUMMARY ({dataset_name} Set)")
    print("=" * 65)
    print(f"\n   {'Task':<35} | {'Metric':<15} | {'Score':>10}")
    print(f"   {'-'*35}-+-{'-'*15}-+-{'-'*10}")
    print(f"   {'TVNT (Classification)':<35} | {'F1 Score':<15} | {tvnt_f1:>10.4f}")
    if tvnt_auc:
        print(f"   {'TVNT (Classification)':<35} | {'AUC-ROC':<15} | {tvnt_auc:>10.4f}")
    print(f"   {'TVNT (Classification)':<35} | {'Accuracy':<15} | {tvnt_acc:>10.4f}")
    print(f"   {'-'*35}-+-{'-'*15}-+-{'-'*10}")
    print(f"   {'Mitotic Figures (Regression)':<35} | {'MAE':<15} | {mitotic_metrics['mae']:>10.4f}")
    print(f"   {'Mitotic Figures (Regression)':<35} | {'Exact Match':<15} | {mitotic_metrics['exact_match_rate']*100:>9.2f}%")
    print(f"   {'-'*35}-+-{'-'*15}-+-{'-'*10}")
    print(f"   {'Multiple Nucleol (Regression)':<35} | {'MAE':<15} | {nucleol_metrics['mae']:>10.4f}")
    print(f"   {'Multiple Nucleol (Regression)':<35} | {'Exact Match':<15} | {nucleol_metrics['exact_match_rate']*100:>9.2f}%")
    print(f"   {'-'*35}-+-{'-'*15}-+-{'-'*10}")
    print(f"   {'Nuclear Hyperchrom (Regression)':<35} | {'MAE':<15} | {hyperchrom_metrics['mae']:>10.4f}")
    print(f"   {'Nuclear Hyperchrom (Regression)':<35} | {'Exact Match':<15} | {hyperchrom_metrics['exact_match_rate']*100:>9.2f}%")
    
    return {
        'tvnt': {
            'auc': tvnt_auc,
            'f1': tvnt_f1,
            'accuracy': tvnt_acc,
            'precision': tvnt_precision,
            'recall': tvnt_recall
        },
        'mitotic': mitotic_metrics,
        'nucleol': nucleol_metrics,
        'hyperchrom': hyperchrom_metrics
    }

# Run evaluation on validation set
if val_loader is not None:
    print("\nüîç Running Model Evaluation on Validation Set...\n")
    
    # Load best model for evaluation
    if os.path.exists("model_a_best.pth"):
        model.load_state_dict(torch.load("model_a_best.pth", map_location=DEVICE))
        print("‚úÖ Loaded best model (model_a_best.pth) for evaluation\n")
    
    val_metrics = evaluate_model(model, val_loader, DEVICE, "Validation")
    
    # Optionally evaluate on training set too
    print("\n" + "="*65)
    print("üìä Training Set Evaluation (for comparison)")
    print("="*65 + "\n")
    train_metrics = evaluate_model(model, train_loader, DEVICE, "Training")
else:
    print("‚ùå Cannot evaluate: No data loaded.")


🔍 Running Model Evaluation on Validation Set...

✅ Loaded best model (model_a_best.pth) for evaluation

📊 TVNT (Abnormality Detection) - Validation Set
   🎯 Tile-level AUC-ROC:   0.6476
   🎯 F1 Score:             0.9813
   📈 Accuracy:             0.9633 (105/109)
   📈 Precision:            0.9633
   📈 Recall/Sensitivity:   1.0000

   Confusion Matrix:
                      Predicted
                    Normal  Abnormal
   Actual Normal         0        4
   Actual Abnormal       0      105

📊 Mitotic Figures Count (Class 0) - Count Metrics
   🎯 Mean Absolute Error (MAE):     0.1436 cells
   📈 Root Mean Squared Error:       0.2961
   📈 R² Score (Agreement):          -0.0154

   Count Precision Metrics:
   🎯 Exact Match Rate:              93.58% (102/109)
   📈 Within ±1 Count:               99.08% (108/109)
   📈 Within ±2 Count:               100.00% (109/109)
   📊 Mean Count Error (Bias):       -0.0734

   Distribution Summary:
   Groun
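The validation confusion matrix shows all four Normal tiles predicted as Abnormal, so the 0.9633 accuracy equals the majority-class rate (105/109) and the TVNT head has not yet separated the two classes. One common mitigation, not used in this notebook, is to weight the classification loss by inverse class frequency, `w_c = N / (K * n_c)`. The sketch below computes those weights from the 18 Normal / 526 Abnormal counts reported by the conversion step; `inverse_frequency_weights` is a hypothetical helper, and in practice the counts would ideally come from the training split only.

```python
def inverse_frequency_weights(class_counts):
    """Weight for class c = N / (K * n_c): rarer classes get proportionally larger weights."""
    total = sum(class_counts)
    k = len(class_counts)
    return [total / (k * n) for n in class_counts]

# Counts from the conversion step: Normal = 18, Abnormal = 526
weights = inverse_frequency_weights([18, 526])
print([round(w, 4) for w in weights])  # [15.1111, 0.5171]

# In the notebook the list could be passed to the existing loss, e.g.:
# criterion_cls = nn.CrossEntropyLoss(weight=torch.tensor(weights, device=DEVICE))
```

A Normal tile then contributes about 29 times as much to the TVNT loss as an Abnormal one; whether that improves AUC on this small validation set would need to be checked empirically.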

In [24]:
# 7. Export Model
save_path = "model_a.pth"
torch.save(model.state_dict(), save_path)
print(f"✅ Final model saved to {save_path}")

# Also confirm best model exists
if os.path.exists("model_a_best.pth"):
    print(f"✅ Best validation model saved to model_a_best.pth")
    print(f"\n📝 Note: Use 'model_a_best.pth' for inference (best validation performance)")
else:
    print(f"📝 Note: No separate best model saved (use model_a.pth)")

✅ Final model saved to model_a.pth
✅ Best validation model saved to model_a_best.pth

📝 Note: Use 'model_a_best.pth' for inference (best validation performance)
