In [25]:
# improved_training_breast_patient_level.py
import os
import random
import math
from collections import Counter, defaultdict

import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

import torchvision
from torchvision import transforms, models

In [26]:
# Settings / Hyperparams
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

DATASET_DIR = r"Dataset/BreaKHis/breast"
MAGNIFICATION = "400X"
BATCH_SIZE = 32               
NUM_WORKERS = 0
NUM_EPOCHS = 150
USE_AMP = True
VALID_SPLIT = 0.15
TEST_SPLIT = 0.15

RESIZE_SIZE = (224, 224)      
NUM_CLASSES = 8

USE_PRETRAINED = True         
USE_MIXUP = True              
MIXUP_ALPHA = 0.4
EARLY_STOPPING_PATIENCE = 20
GRAD_CLIP = 2.0

CHECKPOINT_PATH = "best_model.pt"

In [27]:
# Class mapping

CLASS_ABBR_MAP = {
    "adenosis": "A", "fibroadenoma": "F", "phyllodes_tumor": "PT", "tubular_adenoma": "TA",
    "ductal_carcinoma": "DC", "lobular_carcinoma": "LC", "mucinous_carcinoma": "MC", "papillary_carcinoma": "PC"
}
class_index_map = {name: idx for idx, name in enumerate(CLASS_ABBR_MAP.keys())}

In [28]:
# Dataset
class BreakHisDataset(Dataset):
    def __init__(self, root_dir, magnification, transform=None):
        self.samples = []
        self.transform = transform
        for binary_class in ["benign", "malignant"]:
            sob_path = os.path.join(root_dir, binary_class, "SOB")
            if not os.path.exists(sob_path):
                continue
            for class_folder in os.listdir(sob_path):
                class_name = class_folder.lower()
                if class_name not in class_index_map:
                    # sometimes folder names differ - try to map by key contains
                    matched = None
                    for key in class_index_map:
                        if key in class_name:
                            matched = key
                            break
                    if matched:
                        class_name = matched
                    else:
                        continue
                label = class_index_map[class_name]
                class_path = os.path.join(sob_path, class_folder)
                if not os.path.isdir(class_path):
                    continue
                for patient_folder in os.listdir(class_path):
                    patient_path = os.path.join(class_path, patient_folder)
                    if not os.path.isdir(patient_path):
                        continue
                    mag_path = os.path.join(patient_path, magnification)
                    if os.path.exists(mag_path):
                        images = [f for f in os.listdir(mag_path) if f.lower().endswith((".png", ".jpg", ".jpeg"))]
                        for img_name in images:
                            img_path = os.path.join(mag_path, img_name)
                            self.samples.append((img_path, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert("RGB")
        image = image.resize(RESIZE_SIZE, Image.LANCZOS)
        if self.transform:
            image = self.transform(image)
        return image, label

    def get_labels(self):
        return [label for _, label in self.samples]

In [29]:
# Augmentations / Transforms
# Stronger train augmentation set
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(RESIZE_SIZE[0], scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.2),
    transforms.RandomRotation(180),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.02),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_transform = transforms.Compose([
    transforms.Resize(RESIZE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [30]:
# MixUp helper
def mixup_data(x, y, alpha=1.0, device='cpu'):
    """Returns mixed inputs, pairs of targets, and lambda"""
    if alpha <= 0:
        return x, y, None, 1.0
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, preds, y_a, y_b, lam):
    return lam * criterion(preds, y_a) + (1 - lam) * criterion(preds, y_b)

In [31]:
# Prepare dataset + stratified split
full_dataset = BreakHisDataset(DATASET_DIR, MAGNIFICATION, transform=None)
all_labels = full_dataset.get_labels()
n_total = len(full_dataset)
print(f"[DATA] total samples: {n_total}")

# stratified split indices
def stratified_split_indices(labels, test_frac=0.15, val_frac=0.15, seed=SEED):
    labels = np.array(labels)
    np.random.seed(seed)
    unique_classes = np.unique(labels)
    train_idx, val_idx, test_idx = [], [], []
    for c in unique_classes:
        idxs = np.where(labels == c)[0]
        np.random.shuffle(idxs)
        n = len(idxs)
        n_test = int(math.ceil(n * test_frac))
        n_val = int(math.ceil(n * val_frac))
        n_train = n - n_test - n_val
        if n_train < 1:
            # ensure at least one train sample
            n_train = max(1, n - n_test - n_val)
        train_idx.extend(idxs[:n_train])
        val_idx.extend(idxs[n_train:n_train + n_val])
        test_idx.extend(idxs[n_train + n_val:])
    # shuffle each split
    np.random.shuffle(train_idx)
    np.random.shuffle(val_idx)
    np.random.shuffle(test_idx)
    return np.array(train_idx), np.array(val_idx), np.array(test_idx)

train_idx, valid_idx, test_idx = stratified_split_indices(all_labels, test_frac=TEST_SPLIT, val_frac=VALID_SPLIT)
print(f"[DATA] train/val/test sizes: {len(train_idx)}, {len(valid_idx)}, {len(test_idx)}")

train_subset = Subset(full_dataset, train_idx)
valid_subset = Subset(full_dataset, valid_idx)
test_subset = Subset(full_dataset, test_idx)

# wrap to attach transforms
class WrappedSubset(Dataset):
    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform
    def __len__(self):
        return len(self.subset)
    def __getitem__(self, idx):
        x, y = self.subset[idx]
        if self.transform:
            x = self.transform(x)
        return x, y

train_dataset = WrappedSubset(train_subset, train_transform)
valid_dataset = WrappedSubset(valid_subset, test_transform)
test_dataset = WrappedSubset(test_subset, test_transform)

# Weighted sampler to balance classes during training
train_labels = [y for _, y in train_subset]
class_counts = Counter(train_labels)
print(f"[DATA] train class counts: {class_counts}")
class_weights = {cls: 1.0 / count for cls, count in class_counts.items()}
sample_weights = [class_weights[y] for y in train_labels]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler,
                          num_workers=NUM_WORKERS, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=True)

[DATA] total samples: 1820
[DATA] train/val/test sizes: 1266, 277, 277
[DATA] train class counts: Counter({4: 550, 1: 165, 6: 117, 7: 96, 5: 95, 3: 90, 2: 79, 0: 74})


In [32]:
# Model (pretrained ResNet18 by default)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_model(num_classes=NUM_CLASSES, pretrained=USE_PRETRAINED):
    if pretrained:
        model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1) if hasattr(models, "ResNet18_Weights") else models.resnet18(pretrained=True)
        # replace final fc
        in_f = model.fc.in_features
        model.fc = nn.Linear(in_f, num_classes)
        return model
    else:
        # fallback: simple custom small CNN (or you can paste your IRRCNN class here)
        class SmallCNN(nn.Module):
            def __init__(self, num_classes):
                super().__init__()
                self.features = nn.Sequential(
                    nn.Conv2d(3, 32, 3, padding=1),
                    nn.BatchNorm2d(32), nn.ReLU(),
                    nn.MaxPool2d(2),
                    nn.Conv2d(32, 64, 3, padding=1),
                    nn.BatchNorm2d(64), nn.ReLU(),
                    nn.MaxPool2d(2),
                    nn.Conv2d(64, 128, 3, padding=1),
                    nn.BatchNorm2d(128), nn.ReLU(),
                    nn.AdaptiveAvgPool2d((1,1))
                )
                self.fc = nn.Linear(128, num_classes)
            def forward(self, x):
                x = self.features(x)
                x = x.view(x.size(0), -1)
                return self.fc(x)
        return SmallCNN(num_classes)

model = get_model(NUM_CLASSES, USE_PRETRAINED).to(device)
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [33]:
# Loss, optimizer, scheduler
# Use label smoothing if available
try:
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
except TypeError:
    criterion = nn.CrossEntropyLoss()

optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True, weight_decay=5e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)


In [34]:
# Train / Eval loops
def train_one_epoch(model, loader, optimizer, criterion, scaler, device, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(loader, desc=f"Epoch {epoch} [train]", leave=False)
    for batch_idx, (images, labels) in enumerate(pbar):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        if USE_MIXUP:
            images, y_a, y_b, lam = mixup_data(images, labels, MIXUP_ALPHA, device=device)
            with torch.amp.autocast("cuda", enabled=USE_AMP):
                outputs = model(images)
                loss = mixup_criterion(criterion, outputs, y_a, y_b, lam)
        else:
            with torch.amp.autocast("cuda", enabled=USE_AMP):
                outputs = model(images)
                loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        # gradient clipping
        if GRAD_CLIP is not None:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        # For MixUp accuracy counting we use y_a (approx); this is approximate
        if USE_MIXUP:
            # compute hard preds accuracy against y_a (biased but okay)
            correct += (preds == y_a).sum().item()
            total += labels.size(0)
        else:
            correct += (preds == labels).sum().item()
            total += labels.size(0)
        pbar.set_postfix(loss=loss.item(), acc=correct / total)
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

def evaluate(model, loader, criterion, device, desc="Eval"):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(loader, desc=desc, leave=False)
    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            with torch.amp.autocast("cuda", enabled=USE_AMP):

                outputs = model(images)
                loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            pbar.set_postfix(acc=correct / total)
    return running_loss / total, correct / total

In [35]:
# Main training loop (no early stopping, only save best model)
best_val_acc = 0.0

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, scaler, device, epoch)
    val_loss, val_acc = evaluate(model, valid_loader, criterion, device, desc=f"Epoch {epoch} [val]")
    scheduler.step()

    print(f"Epoch {epoch}: train_loss={train_loss:.4f} | val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'val_acc': val_acc
        }, CHECKPOINT_PATH)
        print(f"[CHECKPOINT] Saved new best model (val_acc={val_acc:.4f})")

                                                                                      

Epoch 1: train_loss=1.9792 | val_loss=1.9179, val_acc=0.2924
[CHECKPOINT] Saved new best model (val_acc=0.2924)


                                                                                      

Epoch 2: train_loss=1.6775 | val_loss=1.4150, val_acc=0.5596
[CHECKPOINT] Saved new best model (val_acc=0.5596)


                                                                                      

Epoch 3: train_loss=1.5458 | val_loss=1.2927, val_acc=0.6282
[CHECKPOINT] Saved new best model (val_acc=0.6282)


                                                                                       

Epoch 4: train_loss=1.4851 | val_loss=1.2144, val_acc=0.6209


                                                                                       

Epoch 5: train_loss=1.4261 | val_loss=1.1921, val_acc=0.6498
[CHECKPOINT] Saved new best model (val_acc=0.6498)


                                                                                       

Epoch 6: train_loss=1.4063 | val_loss=1.1849, val_acc=0.6390


                                                                                      

Epoch 7: train_loss=1.4590 | val_loss=1.2459, val_acc=0.6390


                                                                                       

Epoch 8: train_loss=1.4272 | val_loss=1.1267, val_acc=0.6895
[CHECKPOINT] Saved new best model (val_acc=0.6895)


                                                                                       

Epoch 9: train_loss=1.3645 | val_loss=1.1720, val_acc=0.6751


                                                                                        

Epoch 10: train_loss=1.2739 | val_loss=1.1973, val_acc=0.6679


                                                                                        

Epoch 11: train_loss=1.4085 | val_loss=1.0944, val_acc=0.6823


                                                                                        

Epoch 12: train_loss=1.2830 | val_loss=1.0255, val_acc=0.7256
[CHECKPOINT] Saved new best model (val_acc=0.7256)


                                                                                        

Epoch 13: train_loss=1.1291 | val_loss=1.1176, val_acc=0.7040


                                                                                        

Epoch 14: train_loss=1.2613 | val_loss=1.1386, val_acc=0.6895


                                                                                        

Epoch 15: train_loss=1.2352 | val_loss=1.0730, val_acc=0.7329
[CHECKPOINT] Saved new best model (val_acc=0.7329)


                                                                                        

Epoch 16: train_loss=1.2795 | val_loss=1.1500, val_acc=0.7112


                                                                                        

Epoch 17: train_loss=1.2582 | val_loss=1.0207, val_acc=0.7726
[CHECKPOINT] Saved new best model (val_acc=0.7726)


                                                                                        

Epoch 18: train_loss=1.2542 | val_loss=1.0851, val_acc=0.7256


                                                                                        

Epoch 19: train_loss=1.1078 | val_loss=0.9993, val_acc=0.7726


                                                                                        

Epoch 20: train_loss=1.1749 | val_loss=1.0390, val_acc=0.7401


                                                                                        

Epoch 21: train_loss=1.2369 | val_loss=1.0472, val_acc=0.7329


                                                                                        

Epoch 22: train_loss=1.1256 | val_loss=1.0262, val_acc=0.7437


                                                                                        

Epoch 23: train_loss=1.1152 | val_loss=1.0039, val_acc=0.7798
[CHECKPOINT] Saved new best model (val_acc=0.7798)


                                                                                        

Epoch 24: train_loss=1.2815 | val_loss=0.9992, val_acc=0.7870
[CHECKPOINT] Saved new best model (val_acc=0.7870)


                                                                                        

Epoch 25: train_loss=1.1267 | val_loss=0.9836, val_acc=0.8014
[CHECKPOINT] Saved new best model (val_acc=0.8014)


                                                                                        

Epoch 26: train_loss=1.1578 | val_loss=1.0156, val_acc=0.7653


                                                                                        

Epoch 27: train_loss=1.1584 | val_loss=0.9334, val_acc=0.7978


                                                                                        

Epoch 28: train_loss=1.1814 | val_loss=1.0010, val_acc=0.7726


                                                                                        

Epoch 29: train_loss=1.1641 | val_loss=0.9758, val_acc=0.7870


                                                                                        

Epoch 30: train_loss=1.0587 | val_loss=0.9374, val_acc=0.7978


                                                                                        

Epoch 31: train_loss=1.0897 | val_loss=0.8368, val_acc=0.8448
[CHECKPOINT] Saved new best model (val_acc=0.8448)


                                                                                        

Epoch 32: train_loss=1.1648 | val_loss=0.9551, val_acc=0.7726


                                                                                        

Epoch 33: train_loss=1.0276 | val_loss=0.9885, val_acc=0.7653


                                                                                        

Epoch 34: train_loss=1.0217 | val_loss=1.0205, val_acc=0.7690


                                                                                        

Epoch 35: train_loss=1.0556 | val_loss=0.9205, val_acc=0.8087


                                                                                        

Epoch 36: train_loss=0.9829 | val_loss=1.0351, val_acc=0.7509


                                                                                        

Epoch 37: train_loss=1.0324 | val_loss=1.0034, val_acc=0.7762


                                                                                        

Epoch 38: train_loss=1.0157 | val_loss=0.8849, val_acc=0.8267


                                                                                        

Epoch 39: train_loss=1.0736 | val_loss=0.9479, val_acc=0.8014


                                                                                        

Epoch 40: train_loss=1.1130 | val_loss=0.9206, val_acc=0.8303


                                                                                        

Epoch 41: train_loss=1.1548 | val_loss=0.9067, val_acc=0.8195


                                                                                        

Epoch 42: train_loss=1.1018 | val_loss=0.9034, val_acc=0.8267


                                                                                        

Epoch 43: train_loss=1.0817 | val_loss=0.8456, val_acc=0.8339


                                                                                        

Epoch 44: train_loss=1.1515 | val_loss=1.0149, val_acc=0.7690


                                                                                        

Epoch 45: train_loss=1.1021 | val_loss=0.9119, val_acc=0.8267


                                                                                        

Epoch 46: train_loss=1.1095 | val_loss=0.9056, val_acc=0.8195


                                                                                        

Epoch 47: train_loss=1.0381 | val_loss=0.8526, val_acc=0.8375


                                                                                        

Epoch 48: train_loss=1.0879 | val_loss=0.8545, val_acc=0.8412


                                                                                        

Epoch 49: train_loss=1.0266 | val_loss=0.8986, val_acc=0.8195


                                                                                        

Epoch 50: train_loss=1.0040 | val_loss=0.9205, val_acc=0.8014


                                                                                        

Epoch 51: train_loss=1.1497 | val_loss=0.8861, val_acc=0.8267


                                                                                        

Epoch 52: train_loss=1.0166 | val_loss=0.9123, val_acc=0.8195


                                                                                        

Epoch 53: train_loss=1.0151 | val_loss=0.8421, val_acc=0.8556
[CHECKPOINT] Saved new best model (val_acc=0.8556)


                                                                                        

Epoch 54: train_loss=1.0101 | val_loss=0.9154, val_acc=0.8267


                                                                                        

Epoch 55: train_loss=1.1141 | val_loss=0.9067, val_acc=0.8087


                                                                                        

Epoch 56: train_loss=1.0292 | val_loss=0.8866, val_acc=0.8375


                                                                                        

Epoch 57: train_loss=1.1072 | val_loss=0.8546, val_acc=0.8267


                                                                                        

Epoch 58: train_loss=1.0000 | val_loss=0.9326, val_acc=0.8159


                                                                                        

Epoch 59: train_loss=0.9988 | val_loss=0.9392, val_acc=0.8159


                                                                                        

Epoch 60: train_loss=0.9509 | val_loss=0.9095, val_acc=0.8339


                                                                                        

Epoch 61: train_loss=1.0917 | val_loss=0.9420, val_acc=0.8159


                                                                                        

Epoch 62: train_loss=0.9301 | val_loss=0.8741, val_acc=0.8448


                                                                                        

Epoch 63: train_loss=1.0917 | val_loss=0.9332, val_acc=0.8159


                                                                                        

Epoch 64: train_loss=1.0437 | val_loss=0.8692, val_acc=0.8556


                                                                                        

Epoch 65: train_loss=1.0039 | val_loss=0.8853, val_acc=0.8520


                                                                                        

Epoch 66: train_loss=0.8878 | val_loss=0.7985, val_acc=0.8556


                                                                                        

Epoch 67: train_loss=0.9941 | val_loss=0.9193, val_acc=0.8375


                                                                                        

Epoch 68: train_loss=1.0614 | val_loss=0.8344, val_acc=0.8484


                                                                                        

Epoch 69: train_loss=1.1419 | val_loss=0.8758, val_acc=0.8448


                                                                                        

Epoch 70: train_loss=1.0773 | val_loss=0.8656, val_acc=0.8520


                                                                                        

Epoch 71: train_loss=1.0273 | val_loss=0.8698, val_acc=0.8412


                                                                                        

Epoch 72: train_loss=0.8473 | val_loss=0.8237, val_acc=0.8520


                                                                                        

Epoch 73: train_loss=1.0171 | val_loss=0.8425, val_acc=0.8700
[CHECKPOINT] Saved new best model (val_acc=0.8700)


                                                                                        

Epoch 74: train_loss=1.0249 | val_loss=0.8661, val_acc=0.8484


                                                                                        

Epoch 75: train_loss=0.9351 | val_loss=0.8360, val_acc=0.8592


                                                                                        

Epoch 76: train_loss=0.9994 | val_loss=0.7905, val_acc=0.8773
[CHECKPOINT] Saved new best model (val_acc=0.8773)


                                                                                        

Epoch 77: train_loss=1.0152 | val_loss=0.8444, val_acc=0.8736


                                                                                        

Epoch 78: train_loss=0.9880 | val_loss=0.8326, val_acc=0.8556


                                                                                        

Epoch 79: train_loss=0.9702 | val_loss=0.8118, val_acc=0.8520


                                                                                        

Epoch 80: train_loss=1.0968 | val_loss=0.8292, val_acc=0.8773


                                                                                        

Epoch 81: train_loss=0.9033 | val_loss=0.8028, val_acc=0.8773


                                                                                        

Epoch 82: train_loss=1.0371 | val_loss=0.8227, val_acc=0.8700


                                                                                        

Epoch 83: train_loss=1.0288 | val_loss=0.8496, val_acc=0.8700


                                                                                        

Epoch 84: train_loss=1.0296 | val_loss=0.8235, val_acc=0.8773


                                                                                        

Epoch 85: train_loss=0.9716 | val_loss=0.8445, val_acc=0.8556


                                                                                        

Epoch 86: train_loss=0.9328 | val_loss=0.8913, val_acc=0.8628


                                                                                        

Epoch 87: train_loss=1.0535 | val_loss=0.8470, val_acc=0.8664


                                                                                        

Epoch 88: train_loss=1.0554 | val_loss=0.8614, val_acc=0.8592


                                                                                        

Epoch 89: train_loss=0.8705 | val_loss=0.8234, val_acc=0.8700


                                                                                        

Epoch 90: train_loss=0.9480 | val_loss=0.8281, val_acc=0.8664


                                                                                        

Epoch 91: train_loss=1.0171 | val_loss=0.8382, val_acc=0.8520


                                                                                        

Epoch 92: train_loss=0.9332 | val_loss=0.8521, val_acc=0.8556


                                                                                        

Epoch 93: train_loss=1.0124 | val_loss=0.8163, val_acc=0.8700


                                                                                        

Epoch 94: train_loss=1.0047 | val_loss=0.8340, val_acc=0.8700


                                                                                        

Epoch 95: train_loss=0.9323 | val_loss=0.7975, val_acc=0.8845
[CHECKPOINT] Saved new best model (val_acc=0.8845)


                                                                                        

Epoch 96: train_loss=0.9465 | val_loss=0.8003, val_acc=0.8773


                                                                                        

Epoch 97: train_loss=0.9921 | val_loss=0.8307, val_acc=0.8809


                                                                                        

Epoch 98: train_loss=1.0984 | val_loss=0.8101, val_acc=0.8953
[CHECKPOINT] Saved new best model (val_acc=0.8953)


                                                                                        

Epoch 99: train_loss=0.9980 | val_loss=0.8154, val_acc=0.8881


                                                                                         

Epoch 100: train_loss=0.9599 | val_loss=0.7958, val_acc=0.8881


                                                                                         

Epoch 101: train_loss=1.0333 | val_loss=0.8412, val_acc=0.8736


                                                                                         

Epoch 102: train_loss=0.8876 | val_loss=0.8273, val_acc=0.8773


                                                                                         

Epoch 103: train_loss=0.9926 | val_loss=0.8212, val_acc=0.8773


                                                                                         

Epoch 104: train_loss=1.0769 | val_loss=0.8171, val_acc=0.8736


                                                                                         

Epoch 105: train_loss=1.0831 | val_loss=0.8416, val_acc=0.8881


                                                                                         

Epoch 106: train_loss=0.9532 | val_loss=0.8260, val_acc=0.8773


                                                                                         

Epoch 107: train_loss=1.0725 | val_loss=0.8675, val_acc=0.8809


                                                                                         

Epoch 108: train_loss=0.9951 | val_loss=0.8670, val_acc=0.8592


                                                                                         

Epoch 109: train_loss=0.9254 | val_loss=0.8255, val_acc=0.8773


                                                                                         

Epoch 110: train_loss=0.9310 | val_loss=0.8386, val_acc=0.8736


                                                                                         

Epoch 111: train_loss=0.9027 | val_loss=0.8560, val_acc=0.8700


                                                                                         

Epoch 112: train_loss=0.9261 | val_loss=0.8368, val_acc=0.8628


                                                                                         

Epoch 113: train_loss=0.9263 | val_loss=0.8537, val_acc=0.8700


                                                                                         

Epoch 114: train_loss=1.0383 | val_loss=0.8735, val_acc=0.8628


                                                                                         

Epoch 115: train_loss=0.9177 | val_loss=0.8218, val_acc=0.8700


                                                                                         

Epoch 116: train_loss=0.9506 | val_loss=0.8483, val_acc=0.8592


                                                                                         

Epoch 117: train_loss=0.9971 | val_loss=0.8168, val_acc=0.8700


                                                                                         

Epoch 118: train_loss=0.9410 | val_loss=0.8295, val_acc=0.8773


                                                                                         

Epoch 119: train_loss=0.9886 | val_loss=0.7954, val_acc=0.8736


                                                                                         

Epoch 120: train_loss=0.8636 | val_loss=0.7906, val_acc=0.8736


                                                                                         

Epoch 121: train_loss=0.9683 | val_loss=0.8133, val_acc=0.8700


                                                                                         

Epoch 122: train_loss=0.9345 | val_loss=0.7886, val_acc=0.8881


                                                                                         

Epoch 123: train_loss=0.9995 | val_loss=0.8237, val_acc=0.8736


                                                                                         

Epoch 124: train_loss=0.9998 | val_loss=0.8357, val_acc=0.8773


                                                                                         

Epoch 125: train_loss=0.9853 | val_loss=0.8219, val_acc=0.8736


                                                                                         

Epoch 126: train_loss=1.0186 | val_loss=0.8489, val_acc=0.8736


                                                                                         

Epoch 127: train_loss=0.8246 | val_loss=0.8043, val_acc=0.8809


                                                                                         

Epoch 128: train_loss=0.9629 | val_loss=0.8381, val_acc=0.8628


                                                                                         

Epoch 129: train_loss=1.0177 | val_loss=0.8281, val_acc=0.8809


                                                                                         

Epoch 130: train_loss=0.9134 | val_loss=0.8213, val_acc=0.8736


                                                                                         

Epoch 131: train_loss=0.9598 | val_loss=0.8084, val_acc=0.8881


                                                                                         

Epoch 132: train_loss=0.9239 | val_loss=0.8337, val_acc=0.8773


                                                                                         

Epoch 133: train_loss=0.9386 | val_loss=0.8090, val_acc=0.8809


                                                                                         

Epoch 134: train_loss=0.9526 | val_loss=0.8123, val_acc=0.8700


                                                                                         

Epoch 135: train_loss=0.9205 | val_loss=0.7894, val_acc=0.8773


                                                                                         

Epoch 136: train_loss=0.9059 | val_loss=0.8055, val_acc=0.8736


                                                                                         

Epoch 137: train_loss=0.9807 | val_loss=0.8377, val_acc=0.8664


                                                                                         

Epoch 138: train_loss=0.9434 | val_loss=0.7922, val_acc=0.8773


                                                                                         

Epoch 139: train_loss=0.9283 | val_loss=0.7982, val_acc=0.8736


                                                                                         

Epoch 140: train_loss=1.0319 | val_loss=0.8241, val_acc=0.8736


                                                                                         

Epoch 141: train_loss=0.9614 | val_loss=0.8157, val_acc=0.8809


                                                                                         

Epoch 142: train_loss=0.9018 | val_loss=0.8136, val_acc=0.8628


                                                                                         

Epoch 143: train_loss=1.0308 | val_loss=0.8395, val_acc=0.8700


                                                                                         

Epoch 144: train_loss=0.9692 | val_loss=0.8456, val_acc=0.8773


                                                                                         

Epoch 145: train_loss=0.8454 | val_loss=0.7942, val_acc=0.8773


                                                                                         

Epoch 146: train_loss=0.9620 | val_loss=0.8128, val_acc=0.8809


                                                                                         

Epoch 147: train_loss=0.9253 | val_loss=0.7928, val_acc=0.8809


                                                                                         

Epoch 148: train_loss=0.9737 | val_loss=0.8189, val_acc=0.8809


                                                                                         

Epoch 149: train_loss=1.0760 | val_loss=0.8354, val_acc=0.8773


                                                                                         

Epoch 150: train_loss=0.9209 | val_loss=0.7939, val_acc=0.8773




In [36]:
# Patient-level evaluation (new)
def compute_patient_level_prt(model, full_dataset, indices, transform, device, batch_size=32):
    """
    Compute patient-level recognition rate (Prt) on the samples indexed by `indices`.
    - full_dataset: BreakHisDataset instance (has samples = list of (img_path, label))
    - indices: array/list of indices into full_dataset.samples (these are test_idx)
    - transform: test_transform (to be applied before model)
    """
    model.eval()
    # Build list of (img_path, label, original_index)
    items = [(full_dataset.samples[int(i)][0], int(full_dataset.samples[int(i)][1]), int(i)) for i in indices]
    # Group by patient (we use the parent folder of magnification as patient identifier)
    # Given path .../<patient_folder>/<MAGNIFICATION>/<image>
    patient_groups = defaultdict(list)  # patient_key -> list of (img_path, label)
    for img_path, label, orig_idx in items:
        # patient folder path is two levels up from image: dirname(dirname(img_path))
        patient_folder_path = os.path.normpath(os.path.dirname(os.path.dirname(img_path)))
        patient_groups[patient_folder_path].append((img_path, label))

    # Prepare structures to count correct and total per patient
    patient_correct = {}
    patient_total = {}
    # Flatten all image paths for batched inference, but keep mapping to patient key
    all_entries = []  # list of tuples (img_path, label, patient_key)
    for patient_key, imgs in patient_groups.items():
        patient_total[patient_key] = len(imgs)
        patient_correct[patient_key] = 0
        for img_path, label in imgs:
            all_entries.append((img_path, label, patient_key))

    # Batched inference
    with torch.no_grad():
        for i in range(0, len(all_entries), batch_size):
            batch = all_entries[i:i+batch_size]
            tensors = []
            labels = []
            patient_keys = []
            for img_path, label, patient_key in batch:
                img = Image.open(img_path).convert("RGB")
                # resize done by transform; but ensure it receives PIL image
                x = transform(img)
                tensors.append(x)
                labels.append(label)
                patient_keys.append(patient_key)
            batch_tensor = torch.stack(tensors, dim=0).to(device)
            with torch.cuda.amp.autocast(enabled=USE_AMP):
                outputs = model(batch_tensor)
            preds = outputs.argmax(dim=1).detach().cpu().numpy()
            labels_np = np.array(labels)
            for p, gt, pk in zip(preds, labels_np, patient_keys):
                if int(p) == int(gt):
                    patient_correct[pk] += 1

    # compute Ps per patient and Prt
    patient_scores = {}
    for pk in patient_groups.keys():
        n_tp = patient_correct[pk]
        n_cp = patient_total[pk]
        if n_cp > 0:
            ps = float(n_tp) / float(n_cp)
        else:
            ps = 0.0
        patient_scores[pk] = ps

    # Prt: average of patient Ps
    N_np = len(patient_scores)
    if N_np > 0:
        Prt = sum(patient_scores.values()) / float(N_np)
    else:
        Prt = 0.0

    return Prt, patient_scores, patient_total, patient_correct

In [37]:
# Load best checkpoint if exists
if os.path.exists(CHECKPOINT_PATH):
    ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(ckpt['model_state'])
    # print(f"[LOAD] loaded checkpoint from epoch {ckpt.get('epoch', '?')} with val_acc={ckpt.get('val_acc', 0):.4f}")

# Run patient-level evaluation on test_idx
Prt, patient_scores, patient_total, patient_correct = compute_patient_level_prt(
    model, full_dataset, test_idx, test_transform, device, batch_size=BATCH_SIZE
)

print(f"[PATIENT-LEVEL] Number of patients in test set: {len(patient_scores)}")
print(f"[PATIENT-LEVEL] Global patient recognition rate (Prt): {Prt:.4f}")

# Optionally: print summary statistics (median, mean, etc.)
try:
    import statistics
    scores_list = list(patient_scores.values())
    if len(scores_list) > 0:
        print(f"[PATIENT-LEVEL] mean Ps={statistics.mean(scores_list):.4f}, median Ps={statistics.median(scores_list):.4f}, min Ps={min(scores_list):.4f}, max Ps={max(scores_list):.4f}")
except Exception:
    pass

# If you want to save per-patient breakdown to a CSV:
try:
    import csv
    out_csv = "patient_level_results.csv"
    with open(out_csv, "w", newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["patient_key", "num_images", "num_correct", "Ps"])
        for pk in patient_scores:
            writer.writerow([pk, patient_total[pk], patient_correct[pk], f"{patient_scores[pk]:.4f}"])
    print(f"[PATIENT-LEVEL] saved per-patient breakdown to {out_csv}")
except Exception as e:
    print(f"[PATIENT-LEVEL] couldn't save CSV: {e}")


  with torch.cuda.amp.autocast(enabled=USE_AMP):


[PATIENT-LEVEL] Number of patients in test set: 77
[PATIENT-LEVEL] Global patient recognition rate (Prt): 0.9142
[PATIENT-LEVEL] mean Ps=0.9142, median Ps=1.0000, min Ps=0.0000, max Ps=1.0000
[PATIENT-LEVEL] saved per-patient breakdown to patient_level_results.csv
