# MRI vs Breast Histopathology - Part 2: Model & Training (One-Layer CNN)

This notebook implements the model and training pipeline for a binary classifier that distinguishes MRI images from breast histopathology (BreastHisto) images, using a simple one-layer CNN.

**Workflow:**
- Part 1 (04_mri_vs_breasthisto_part1_data_preparation.ipynb): Data organization, grayscale conversion, and train/val/test splitting
- Part 2 (this notebook): Model training with preprocessed grayscale images converted to 3-channel format

**Data Pipeline:**
1. Load preprocessed grayscale PNG images from data/processed
2. Convert grayscale to 3-channel (repeat across RGB channels)
3. Apply ImageNet normalization for model compatibility
4. Train a one-layer CNN with binary cross-entropy loss on logits (`BCEWithLogitsLoss`) for binary classification

In [None]:
# Standard libs
from __future__ import annotations
import os, sys, json, random
from dataclasses import dataclass, asdict
from pathlib import Path
from datetime import datetime
from typing import Dict

import numpy as np
from PIL import Image
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

# Reproducibility: seed every RNG in use (Python, NumPy, PyTorch CPU and CUDA)
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn
# Deterministic cuDNN kernels, with autotuning switched off
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Paths: the project root is the parent of the current working directory
PROJECT_ROOT = Path.cwd().parent
DATA_PROCESSED = PROJECT_ROOT.joinpath('data', 'processed')
MODELS_DIR = PROJECT_ROOT.joinpath('models')
OUTPUTS_DIR = PROJECT_ROOT.joinpath('outputs')
# Create the artifact directories up front so later saves cannot fail on a missing folder
for _artifact_dir in (MODELS_DIR, OUTPUTS_DIR):
    _artifact_dir.mkdir(parents=True, exist_ok=True)
del _artifact_dir

# Device: use CUDA when available, otherwise the CPU
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', DEVICE)

In [None]:
# Load preparation metadata from the processed-data directory,
# falling back to built-in defaults when the file is missing.
meta_path = DATA_PROCESSED / 'preparation_metadata.json'
if not meta_path.exists():
    PREP_META = {
        'processed_data_path': str(DATA_PROCESSED),
        'ready_for_training': False
    }
    print('Warning: preparation_metadata.json not found. Using defaults.')
else:
    PREP_META = json.loads(meta_path.read_text())
    print('Loaded metadata:', PREP_META.get('processed_data_path'))

@dataclass
class TrainConfig:
    """Hyper-parameters and bookkeeping options for a training run.

    Raises:
        ValueError: if ``scheduler`` or ``save_best_metric`` is not one of the
            supported names. Without this check a typo would silently select
            the fallback branch of the training loop (plateau scheduler /
            AUC-based model selection).
    """

    # Data configuration
    image_size: int = 224  # images are resized to image_size x image_size
    batch_size: int = 32
    num_workers: int = 0  # Windows-safe

    # Training configuration
    epochs: int = 15
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4

    # Scheduler configuration
    scheduler: str = 'cosine'  # 'cosine' | 'plateau'

    # Early stopping and model saving
    save_best_metric: str = 'val_f1'  # 'val_accuracy' | 'val_f1' | 'val_auc'
    patience: int = 5  # epochs without improvement before stopping

    def __post_init__(self):
        """Reject unsupported option names early, at construction time."""
        valid_schedulers = ('cosine', 'plateau')
        if self.scheduler not in valid_schedulers:
            raise ValueError(
                f"scheduler must be one of {valid_schedulers}, got {self.scheduler!r}"
            )
        valid_metrics = ('val_accuracy', 'val_f1', 'val_auc')
        if self.save_best_metric not in valid_metrics:
            raise ValueError(
                f"save_best_metric must be one of {valid_metrics}, got {self.save_best_metric!r}"
            )

# Global run configuration and the label encoding shared by the dataset and metrics
CONFIG = TrainConfig()
CLASS_NAMES = ['MRI', 'BreastHisto']
# Class name -> integer label (position in CLASS_NAMES) and the inverse mapping
CLASS_TO_IDX = dict(zip(CLASS_NAMES, range(len(CLASS_NAMES))))
IDX_TO_CLASS = dict(enumerate(CLASS_NAMES))

# Print a short, human-readable summary of the configuration
print(
    'Training Configuration:',
    f'  Image Size: {CONFIG.image_size}x{CONFIG.image_size}',
    f'  Batch Size: {CONFIG.batch_size}',
    f'  Max Epochs: {CONFIG.epochs}',
    f'  Classes: {CLASS_NAMES} -> {CLASS_TO_IDX}',
    sep='\n',
)

In [None]:
# ImageNet normalization values for 3-channel compatibility.
# One mean / standard deviation per channel; to_tensor() applies entry i to
# channel i of the replicated grayscale image as (pixel - mean) / std.
RGB_MEAN = [0.485, 0.456, 0.406]
RGB_STD = [0.229, 0.224, 0.225]

def load_image(path: str) -> np.ndarray:
    """Read an image file as a single-channel (grayscale) array.

    OpenCV is tried first; when it cannot decode the file (``cv2.imread``
    returns ``None``) the image is opened with Pillow and converted to
    mode ``'L'`` instead.
    """
    gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if gray is not None:
        return gray
    # OpenCV signals failure with None rather than raising, so fall back to Pillow.
    pil_image = Image.open(path)
    return np.array(pil_image.convert('L'))

def to_tensor(img: np.ndarray, size: int) -> torch.Tensor:
    """Turn a grayscale image into a normalized 3-channel float tensor.

    The image is resized to ``size`` x ``size`` with bilinear interpolation,
    scaled by 1/255, copied into three channels, and each channel is
    standardized with the matching ``RGB_MEAN`` / ``RGB_STD`` entry.

    Returns:
        A float32 tensor with the channel axis first: (3, size, size) for a
        2-D input image.
    """
    resized = cv2.resize(img, (size, size), interpolation=cv2.INTER_LINEAR)
    scaled = resized.astype(np.float32) / 255.0

    # Fill each output channel directly with its standardized copy of the image.
    channels = np.empty((3,) + scaled.shape, dtype=np.float32)
    for ch, (mean, std) in enumerate(zip(RGB_MEAN, RGB_STD)):
        channels[ch] = (scaled - mean) / std

    return torch.from_numpy(channels)

class BinaryMedicalDataset(Dataset):
    """Dataset of preprocessed images laid out as ``root/<split>/<class>/<file>``.

    Each item is a pair ``(x, y)``: ``x`` is the tensor produced by
    ``to_tensor`` for the requested ``size`` and ``y`` is the class index as a
    float32 scalar tensor (suitable as a BCE target).

    Args:
        root: Directory containing the split folders.
        split: Split folder name, e.g. ``'train'``, ``'val'`` or ``'test'``.
        size: Target side length passed to ``to_tensor``.
    """

    def __init__(self, root: Path, split: str, size: int = 224):
        self.root = Path(root)
        self.split = split
        self.size = size
        self.samples = []

        for cls in CLASS_NAMES:
            cdir = self.root / split / cls
            if not cdir.exists():
                # A missing class folder simply contributes no samples.
                continue
            for ext in ('*.png', '*.jpg', '*.jpeg', '*.bmp', '*.tif', '*.tiff'):
                # glob() order depends on the filesystem; sorting first makes the
                # seeded shuffle below reproducible across machines.
                for p in sorted(cdir.glob(ext)):
                    self.samples.append((str(p), CLASS_TO_IDX[cls]))

        random.shuffle(self.samples)
        self.counts = {c: 0 for c in CLASS_NAMES}
        for _, lbl in self.samples:
            self.counts[IDX_TO_CLASS[lbl]] += 1
        print(f"Loaded {split}: {len(self.samples)} | {self.counts}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        try:
            img = load_image(path)
        except Exception as exc:
            # Best effort: keep the batch intact with a blank image, but report
            # the failure instead of hiding it.
            print(f"Warning: failed to load {path} ({exc}); using a blank image")
            img = np.zeros((self.size, self.size), dtype=np.uint8)
        # Honor the size given to this dataset rather than the global CONFIG.
        x = to_tensor(img, self.size)
        y = torch.tensor(float(label), dtype=torch.float32)
        return x, y

def create_loaders(data_root: Path, cfg: TrainConfig) -> Dict[str, DataLoader]:
    """Build the train / val / test data loaders.

    Args:
        data_root: Directory holding the ``train``, ``val`` and ``test`` folders.
        cfg: Supplies ``image_size``, ``batch_size`` and ``num_workers``.

    Returns:
        Dict with keys ``'train'``, ``'val'`` and ``'test'``. Only the training
        loader shuffles.

    Raises:
        ValueError: if the training split contains no images.
    """
    ds_tr = BinaryMedicalDataset(data_root, 'train', cfg.image_size)
    ds_va = BinaryMedicalDataset(data_root, 'val', cfg.image_size)
    ds_te = BinaryMedicalDataset(data_root, 'test', cfg.image_size)

    if len(ds_tr) == 0:
        raise ValueError(f"No training images found under {Path(data_root) / 'train'}")

    shared = dict(
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        # Pinned host memory only helps host-to-GPU copies; requesting it
        # without CUDA just triggers a warning.
        pin_memory=torch.cuda.is_available(),
    )
    # Drop the ragged last batch only when at least one full batch remains;
    # otherwise a small training set would produce zero batches per epoch.
    drop_last = len(ds_tr) >= cfg.batch_size

    dl_tr = DataLoader(ds_tr, shuffle=True, drop_last=drop_last, **shared)
    dl_va = DataLoader(ds_va, shuffle=False, **shared)
    dl_te = DataLoader(ds_te, shuffle=False, **shared)
    return {'train': dl_tr, 'val': dl_va, 'test': dl_te}

class OneLayerCNN(nn.Module):
    """Simple CNN with one convolutional block and a linear classifier.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits (1 for binary classification).
        image_size: Side length of the square input images. The default of 224
            gives the same layer sizes as before; other values size the linear
            layer to match instead of failing with a shape mismatch.

    Raises:
        ValueError: if ``image_size`` is smaller than 2 (nothing would be left
            after the 2x2 pooling).
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 1, image_size: int = 224):
        super().__init__()

        if image_size < 2:
            raise ValueError(f"image_size must be >= 2, got {image_size}")

        # Feature extraction: 1 convolutional block
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),  # keeps H x W
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # halves H and W (floor), e.g. 224 -> 112
        )

        # Spatial side after pooling; 2x2 / stride-2 pooling floors odd sizes.
        pooled = image_size // 2

        # Classifier
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(16 * pooled * pooled, num_classes)
        )

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal conv weights, unit BatchNorm scale, small-normal linear weights; all biases zero."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return raw logits; with ``num_classes=1`` the output has shape (batch,)."""
        x = self.features(x)
        x = self.classifier(x)
        # squeeze(-1) only removes the last axis when it has length 1,
        # turning (batch, 1) into (batch,) to match the float targets.
        return x.squeeze(-1)

In [None]:
def compute_metrics(y_true: torch.Tensor, y_prob: torch.Tensor):
    """Compute binary classification metrics from labels and probabilities.

    Probabilities are thresholded at 0.5 to obtain hard predictions. AUC is
    computed on the raw probabilities and reported as 0.0 when
    ``roc_auc_score`` raises.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall``, ``f1``, ``auc``.
    """
    labels = y_true.cpu().numpy()
    probs = y_prob.cpu().numpy()
    preds = (y_prob >= 0.5).float().cpu().numpy()

    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _support = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )

    try:
        auc = roc_auc_score(labels, probs)
    except Exception:
        auc = 0.0

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc,
    }

def train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimization pass over ``loader`` and report training metrics.

    Args:
        model: Network returning one logit per sample.
        loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss taking ``(logits, targets)``; assumed to return the
            batch mean (the default reduction), which the per-sample
            weighting below relies on.
        optimizer: Optimizer updating ``model``'s parameters.

    Returns:
        Dict from ``compute_metrics`` (accuracy, precision, recall, f1, auc)
        plus ``'loss'``, the mean loss per sample over the epoch.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    loss_sum = 0.0
    n_samples = 0
    y_true_list = []
    y_prob_list = []

    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)

        # Forward pass
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Weight each batch by its size so a ragged batch does not count as
        # much as a full one in the epoch average.
        batch_n = y.size(0)
        loss_sum += loss.item() * batch_n
        n_samples += batch_n
        y_true_list.append(y.detach().cpu())
        y_prob_list.append(torch.sigmoid(logits).detach().cpu())

    if n_samples == 0:
        raise ValueError("train_one_epoch received an empty loader (no batches to train on)")

    # Compute epoch metrics
    metrics = compute_metrics(torch.cat(y_true_list), torch.cat(y_prob_list))
    metrics['loss'] = loss_sum / n_samples

    return metrics

def evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without updating any parameters.

    Args:
        model: Network returning one logit per sample.
        loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss taking ``(logits, targets)``; assumed to return the
            batch mean (the default reduction), which the per-sample
            weighting below relies on.

    Returns:
        Dict from ``compute_metrics`` (accuracy, precision, recall, f1, auc)
        plus ``'loss'``, the mean loss per sample.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    n_samples = 0
    y_true_list = []
    y_prob_list = []

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)

            # Forward pass
            logits = model(x)
            loss = criterion(logits, y)

            # The last batch is usually smaller here; weighting by batch size
            # keeps the reported loss a true per-sample mean.
            batch_n = y.size(0)
            loss_sum += loss.item() * batch_n
            n_samples += batch_n
            y_true_list.append(y.detach().cpu())
            y_prob_list.append(torch.sigmoid(logits).detach().cpu())

    if n_samples == 0:
        raise ValueError("evaluate received an empty loader (no batches to evaluate)")

    # Compute metrics
    metrics = compute_metrics(torch.cat(y_true_list), torch.cat(y_prob_list))
    metrics['loss'] = loss_sum / n_samples

    return metrics

def train_pipeline():
    """Train the one-layer CNN end to end and persist the results.

    Steps: build the data loaders, create the model, loss, AdamW optimizer and
    LR scheduler, then train for up to ``CONFIG.epochs`` epochs with early
    stopping. Whenever the selected validation metric improves by more than
    1e-4 a checkpoint is written to ``MODELS_DIR``. A JSON summary of the run
    is written to ``OUTPUTS_DIR`` at the end.

    Returns:
        The history dict: one list of per-epoch floats for each of
        train_loss/train_accuracy/train_f1 and val_loss/val_accuracy/val_f1/val_auc.
    """
    print("Starting Training Pipeline")
    print("=" * 40)

    # Load data
    print("Loading datasets...")
    data_root = Path(PREP_META.get('processed_data_path', DATA_PROCESSED))
    loaders = create_loaders(data_root, CONFIG)

    # Initialize model
    print("Initializing model...")
    model = OneLayerCNN(in_channels=3, num_classes=1).to(DEVICE)

    # Initialize training components
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=CONFIG.learning_rate, weight_decay=CONFIG.weight_decay)

    if CONFIG.scheduler == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=CONFIG.epochs, eta_min=CONFIG.learning_rate*0.01
        )
    else:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3
        )

    # Training setup
    history = {
        'train_loss': [], 'train_accuracy': [], 'train_f1': [],
        'val_loss': [], 'val_accuracy': [], 'val_f1': [], 'val_auc': []
    }

    # Which validation metric drives model selection; anything other than the
    # two named options falls back to AUC.
    metric_key = {'val_f1': 'f1', 'val_accuracy': 'accuracy'}.get(CONFIG.save_best_metric, 'auc')

    best_metric = 0.0
    patience_counter = 0
    epochs_run = 0        # stays 0 when CONFIG.epochs == 0 (loop body never runs)
    checkpoint_saved = False
    best_path = MODELS_DIR / 'best_onelayercnn_binary.pth'

    print("Starting training...")

    # Training loop
    for epoch in range(CONFIG.epochs):
        epochs_run = epoch + 1
        print(f"Epoch {epochs_run}/{CONFIG.epochs}")

        # Training phase
        train_metrics = train_one_epoch(model, loaders['train'], criterion, optimizer)

        # Validation phase
        val_metrics = evaluate(model, loaders['val'], criterion)

        # Learning rate update (the plateau scheduler needs the monitored value)
        if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(val_metrics['loss'])
        else:
            scheduler.step()

        # Update history. Values are cast to plain floats so that neither the
        # checkpoint nor the JSON summary carries NumPy scalar types.
        history['train_loss'].append(float(train_metrics['loss']))
        history['train_accuracy'].append(float(train_metrics['accuracy']))
        history['train_f1'].append(float(train_metrics['f1']))

        history['val_loss'].append(float(val_metrics['loss']))
        history['val_accuracy'].append(float(val_metrics['accuracy']))
        history['val_f1'].append(float(val_metrics['f1']))
        history['val_auc'].append(float(val_metrics['auc']))

        current_metric = float(val_metrics[metric_key])

        # Best model saving and early stopping
        if current_metric > best_metric + 1e-4:
            best_metric = current_metric
            patience_counter = 0

            # Save best model
            checkpoint = {
                'epoch': epochs_run,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'config': asdict(CONFIG),
                'history': history,
                'best_metric': best_metric,
                'best_metric_name': CONFIG.save_best_metric
            }
            torch.save(checkpoint, best_path)
            checkpoint_saved = True
            print(f"New best {CONFIG.save_best_metric}: {current_metric:.4f} - Model saved!")
        else:
            patience_counter += 1

        # Epoch summary
        print(f"Train -> Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.3f}, F1: {train_metrics['f1']:.3f}")
        print(f"Val   -> Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.3f}, F1: {val_metrics['f1']:.3f}, AUC: {val_metrics['auc']:.3f}")
        print(f"Best {CONFIG.save_best_metric}: {best_metric:.4f}, Patience: {patience_counter}/{CONFIG.patience}")

        # Early stopping check
        if patience_counter >= CONFIG.patience:
            print(f"Early stopping triggered after {patience_counter} epochs without improvement")
            break

    # Training completion
    print("\nTraining Complete!")
    print(f"Best {CONFIG.save_best_metric}: {best_metric:.4f}")
    if checkpoint_saved:
        print(f"Best model saved to: {best_path}")
    else:
        # Do not claim a checkpoint exists when the metric never improved.
        print(f"No model was saved: {CONFIG.save_best_metric} never improved on the initial value")

    # Save training history
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    history_path = OUTPUTS_DIR / f"training_history_onelayer_{timestamp}.json"

    history_serializable = {key: [float(v) for v in values] for key, values in history.items()}

    training_summary = {
        'config': asdict(CONFIG),
        'training_history': history_serializable,
        'best_metric': float(best_metric),
        'best_metric_name': CONFIG.save_best_metric,
        'total_epochs': epochs_run,
        'timestamp': timestamp
    }

    with open(history_path, 'w', encoding='utf-8') as f:
        json.dump(training_summary, f, indent=2)

    print(f"Training history saved to: {history_path}")

    return history

In [None]:
# Run training
# The guard keeps an import of this file from starting a run; training only
# starts when the code is executed directly (__name__ == "__main__").
if __name__ == "__main__":
    # train_pipeline() returns the per-epoch history dict; keep it for later inspection.
    history = train_pipeline()
    print('\nTraining complete. Best model saved in models/.')