In [1]:
# improved_training_breast
import os
import random
import math
from collections import Counter

import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

import torchvision
from torchvision import transforms, models

In [2]:
# Settings / Hyperparams
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

DATASET_DIR = r"Dataset/BreaKHis/breast"
MAGNIFICATION = "100X"
BATCH_SIZE = 32               # bigger batch if GPU mem allows
NUM_WORKERS = 0
NUM_EPOCHS = 150
USE_AMP = True
VALID_SPLIT = 0.15
TEST_SPLIT = 0.15

RESIZE_SIZE = (224, 224)      # use 224 if using pretrained backbones
NUM_CLASSES = 8

USE_PRETRAINED = True         # set False if you want to force your IRRCNN architecture
USE_MIXUP = True              # try MixUp training
MIXUP_ALPHA = 0.4
EARLY_STOPPING_PATIENCE = 20
GRAD_CLIP = 2.0

CHECKPOINT_PATH = "best_model.pt"

In [3]:
# Class mapping
CLASS_ABBR_MAP = {
    "adenosis": "A", "fibroadenoma": "F", "phyllodes_tumor": "PT", "tubular_adenoma": "TA",
    "ductal_carcinoma": "DC", "lobular_carcinoma": "LC", "mucinous_carcinoma": "MC", "papillary_carcinoma": "PC"
}
class_index_map = {name: idx for idx, name in enumerate(CLASS_ABBR_MAP.keys())}

In [4]:
# Dataset
class BreakHisDataset(Dataset):
    def __init__(self, root_dir, magnification, transform=None):
        self.samples = []
        self.transform = transform
        for binary_class in ["benign", "malignant"]:
            sob_path = os.path.join(root_dir, binary_class, "SOB")
            if not os.path.exists(sob_path):
                continue
            for class_folder in os.listdir(sob_path):
                class_name = class_folder.lower()
                if class_name not in class_index_map:
                    # sometimes folder names differ - try to map by key contains
                    matched = None
                    for key in class_index_map:
                        if key in class_name:
                            matched = key
                            break
                    if matched:
                        class_name = matched
                    else:
                        continue
                label = class_index_map[class_name]
                class_path = os.path.join(sob_path, class_folder)
                if not os.path.isdir(class_path):
                    continue
                for patient_folder in os.listdir(class_path):
                    patient_path = os.path.join(class_path, patient_folder)
                    if not os.path.isdir(patient_path):
                        continue
                    mag_path = os.path.join(patient_path, magnification)
                    if os.path.exists(mag_path):
                        images = [f for f in os.listdir(mag_path) if f.lower().endswith((".png", ".jpg", ".jpeg"))]
                        for img_name in images:
                            img_path = os.path.join(mag_path, img_name)
                            self.samples.append((img_path, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert("RGB")
        image = image.resize(RESIZE_SIZE, Image.LANCZOS)
        if self.transform:
            image = self.transform(image)
        return image, label

    def get_labels(self):
        return [label for _, label in self.samples]

In [5]:
# Augmentations / Transforms
# Stronger train augmentation set
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(RESIZE_SIZE[0], scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.2),
    transforms.RandomRotation(180),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.02),
    # If torchvision supports AutoAugment/RandAugment you can add:
    # transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_transform = transforms.Compose([
    transforms.Resize(RESIZE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [6]:
# MixUp helper
def mixup_data(x, y, alpha=1.0, device='cpu'):
    """Returns mixed inputs, pairs of targets, and lambda"""
    if alpha <= 0:
        return x, y, None, 1.0
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, preds, y_a, y_b, lam):
    return lam * criterion(preds, y_a) + (1 - lam) * criterion(preds, y_b)

In [7]:
# Prepare dataset + stratified split
full_dataset = BreakHisDataset(DATASET_DIR, MAGNIFICATION, transform=None)
all_labels = full_dataset.get_labels()
n_total = len(full_dataset)
print(f"[DATA] total samples: {n_total}")

# stratified split indices
def stratified_split_indices(labels, test_frac=0.15, val_frac=0.15, seed=SEED):
    labels = np.array(labels)
    np.random.seed(seed)
    unique_classes = np.unique(labels)
    train_idx, val_idx, test_idx = [], [], []
    for c in unique_classes:
        idxs = np.where(labels == c)[0]
        np.random.shuffle(idxs)
        n = len(idxs)
        n_test = int(math.ceil(n * test_frac))
        n_val = int(math.ceil(n * val_frac))
        n_train = n - n_test - n_val
        if n_train < 1:
            # ensure at least one train sample
            n_train = max(1, n - n_test - n_val)
        train_idx.extend(idxs[:n_train])
        val_idx.extend(idxs[n_train:n_train + n_val])
        test_idx.extend(idxs[n_train + n_val:])
    # shuffle each split
    np.random.shuffle(train_idx)
    np.random.shuffle(val_idx)
    np.random.shuffle(test_idx)
    return np.array(train_idx), np.array(val_idx), np.array(test_idx)

train_idx, valid_idx, test_idx = stratified_split_indices(all_labels, test_frac=TEST_SPLIT, val_frac=VALID_SPLIT)
print(f"[DATA] train/val/test sizes: {len(train_idx)}, {len(valid_idx)}, {len(test_idx)}")

train_subset = Subset(full_dataset, train_idx)
valid_subset = Subset(full_dataset, valid_idx)
test_subset = Subset(full_dataset, test_idx)

# wrap to attach transforms
class WrappedSubset(Dataset):
    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform
    def __len__(self):
        return len(self.subset)
    def __getitem__(self, idx):
        x, y = self.subset[idx]
        if self.transform:
            x = self.transform(x)
        return x, y

train_dataset = WrappedSubset(train_subset, train_transform)
valid_dataset = WrappedSubset(valid_subset, test_transform)
test_dataset = WrappedSubset(test_subset, test_transform)

# Weighted sampler to balance classes during training
train_labels = [y for _, y in train_subset]
class_counts = Counter(train_labels)
print(f"[DATA] train class counts: {class_counts}")
class_weights = {cls: 1.0 / count for cls, count in class_counts.items()}
sample_weights = [class_weights[y] for y in train_labels]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler,
                          num_workers=NUM_WORKERS, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=True)

[DATA] total samples: 2081
[DATA] train/val/test sizes: 1449, 316, 316
[DATA] train class counts: Counter({4: 631, 1: 182, 6: 154, 5: 118, 3: 104, 7: 98, 2: 83, 0: 79})


In [8]:
# Model (pretrained ResNet18 by default)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_model(num_classes=NUM_CLASSES, pretrained=USE_PRETRAINED):
    if pretrained:
        model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1) if hasattr(models, "ResNet18_Weights") else models.resnet18(pretrained=True)
        # replace final fc
        in_f = model.fc.in_features
        model.fc = nn.Linear(in_f, num_classes)
        return model
    else:
        # fallback: simple custom small CNN (or you can paste your IRRCNN class here)
        class SmallCNN(nn.Module):
            def __init__(self, num_classes):
                super().__init__()
                self.features = nn.Sequential(
                    nn.Conv2d(3, 32, 3, padding=1),
                    nn.BatchNorm2d(32), nn.ReLU(),
                    nn.MaxPool2d(2),
                    nn.Conv2d(32, 64, 3, padding=1),
                    nn.BatchNorm2d(64), nn.ReLU(),
                    nn.MaxPool2d(2),
                    nn.Conv2d(64, 128, 3, padding=1),
                    nn.BatchNorm2d(128), nn.ReLU(),
                    nn.AdaptiveAvgPool2d((1,1))
                )
                self.fc = nn.Linear(128, num_classes)
            def forward(self, x):
                x = self.features(x)
                x = x.view(x.size(0), -1)
                return self.fc(x)
        return SmallCNN(num_classes)

model = get_model(NUM_CLASSES, USE_PRETRAINED).to(device)
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [9]:
# Loss, optimizer, scheduler
# Use label smoothing if available
try:
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
except TypeError:
    criterion = nn.CrossEntropyLoss()

optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True, weight_decay=5e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)


In [None]:
# Train / Eval loops
def train_one_epoch(model, loader, optimizer, criterion, scaler, device, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(loader, desc=f"Epoch {epoch} [train]", leave=False)
    for batch_idx, (images, labels) in enumerate(pbar):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        if USE_MIXUP:
            images, y_a, y_b, lam = mixup_data(images, labels, MIXUP_ALPHA, device=device)
            with torch.amp.autocast("cuda", enabled=USE_AMP):
                outputs = model(images)
                loss = mixup_criterion(criterion, outputs, y_a, y_b, lam)
        else:
            with torch.amp.autocast("cuda", enabled=USE_AMP):
                outputs = model(images)
                loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        # gradient clipping
        if GRAD_CLIP is not None:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        # For MixUp accuracy counting we use y_a (approx); this is approximate
        if USE_MIXUP:
            # compute hard preds accuracy against y_a (biased but okay)
            correct += (preds == y_a).sum().item()
            total += labels.size(0)
        else:
            correct += (preds == labels).sum().item()
            total += labels.size(0)
        pbar.set_postfix(loss=loss.item(), acc=correct / total)
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

def evaluate(model, loader, criterion, device, desc="Eval"):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(loader, desc=desc, leave=False)
    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            with torch.amp.autocast("cuda", enabled=USE_AMP):
                outputs = model(images)
                loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            pbar.set_postfix(acc=correct / total)
    return running_loss / total, correct / total

In [11]:
# Main training loop (no early stopping, only save best model)
best_val_acc = 0.0

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, scaler, device, epoch)
    val_loss, val_acc = evaluate(model, valid_loader, criterion, device, desc=f"Epoch {epoch} [val]")
    scheduler.step()

    print(f"Epoch {epoch}: train_loss={train_loss:.4f} | val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")
    #print(f"Epoch {epoch}: train_loss={train_loss:.4f}, train_acc={train_acc:.4f} | val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'val_acc': val_acc
        }, CHECKPOINT_PATH)
        print(f"[CHECKPOINT] Saved new best model (val_acc={val_acc:.4f})")

  with torch.cuda.amp.autocast(enabled=USE_AMP):
  with torch.cuda.amp.autocast(enabled=USE_AMP):
                                                                         

Epoch 1: train_loss=1.8399 | val_loss=1.8566, val_acc=0.2943
[CHECKPOINT] Saved new best model (val_acc=0.2943)


                                                                                      

Epoch 2: train_loss=1.5641 | val_loss=1.1484, val_acc=0.6867
[CHECKPOINT] Saved new best model (val_acc=0.6867)


                                                                                       

Epoch 3: train_loss=1.5195 | val_loss=1.4081, val_acc=0.5253


                                                                                      

Epoch 4: train_loss=1.5063 | val_loss=1.1495, val_acc=0.6741


                                                                                       

Epoch 5: train_loss=1.3860 | val_loss=1.7441, val_acc=0.4335


                                                                                       

Epoch 6: train_loss=1.3714 | val_loss=1.1564, val_acc=0.6930
[CHECKPOINT] Saved new best model (val_acc=0.6930)


                                                                                       

Epoch 7: train_loss=1.3332 | val_loss=0.9745, val_acc=0.7880
[CHECKPOINT] Saved new best model (val_acc=0.7880)


                                                                                       

Epoch 8: train_loss=1.2141 | val_loss=1.1284, val_acc=0.6930


                                                                                       

Epoch 9: train_loss=1.2979 | val_loss=1.1877, val_acc=0.6930


                                                                                        

Epoch 10: train_loss=1.2849 | val_loss=0.9779, val_acc=0.7816


                                                                                        

Epoch 11: train_loss=1.2454 | val_loss=0.9153, val_acc=0.8196
[CHECKPOINT] Saved new best model (val_acc=0.8196)


                                                                                        

Epoch 12: train_loss=1.2706 | val_loss=0.9498, val_acc=0.7848


                                                                                        

Epoch 13: train_loss=1.2044 | val_loss=1.1675, val_acc=0.7247


                                                                                        

Epoch 14: train_loss=1.2042 | val_loss=1.0828, val_acc=0.7342


                                                                                        

Epoch 15: train_loss=1.2023 | val_loss=0.9022, val_acc=0.8323
[CHECKPOINT] Saved new best model (val_acc=0.8323)


                                                                                        

Epoch 16: train_loss=1.1412 | val_loss=1.0649, val_acc=0.7437


                                                                                        

Epoch 17: train_loss=1.1129 | val_loss=0.8882, val_acc=0.8291


                                                                                        

Epoch 18: train_loss=1.2566 | val_loss=0.8539, val_acc=0.8291


                                                                                        

Epoch 19: train_loss=1.1777 | val_loss=1.0147, val_acc=0.7753


                                                                                        

Epoch 20: train_loss=1.1277 | val_loss=0.9880, val_acc=0.7753


                                                                                        

Epoch 21: train_loss=1.1459 | val_loss=0.8254, val_acc=0.8513
[CHECKPOINT] Saved new best model (val_acc=0.8513)


                                                                                        

Epoch 22: train_loss=1.1459 | val_loss=0.8481, val_acc=0.8481


                                                                                        

Epoch 23: train_loss=1.1003 | val_loss=0.8054, val_acc=0.8671
[CHECKPOINT] Saved new best model (val_acc=0.8671)


                                                                                        

Epoch 24: train_loss=1.1250 | val_loss=1.0538, val_acc=0.7595


                                                                                        

Epoch 25: train_loss=1.0831 | val_loss=1.0312, val_acc=0.7532


                                                                                        

Epoch 26: train_loss=0.9951 | val_loss=0.8215, val_acc=0.8639


                                                                                        

Epoch 27: train_loss=1.0603 | val_loss=0.9636, val_acc=0.7880


                                                                                        

Epoch 28: train_loss=1.0065 | val_loss=1.0140, val_acc=0.7627


                                                                                        

Epoch 29: train_loss=1.0316 | val_loss=0.9836, val_acc=0.8038


                                                                                        

Epoch 30: train_loss=1.0469 | val_loss=1.1352, val_acc=0.7089


                                                                                        

Epoch 31: train_loss=1.0771 | val_loss=0.8799, val_acc=0.8354


                                                                                        

Epoch 32: train_loss=1.1726 | val_loss=0.9429, val_acc=0.8038


                                                                                        

Epoch 33: train_loss=1.1465 | val_loss=0.8645, val_acc=0.8544


                                                                                        

Epoch 34: train_loss=1.0976 | val_loss=0.8570, val_acc=0.8544


                                                                                        

Epoch 35: train_loss=1.1122 | val_loss=0.8534, val_acc=0.8323


                                                                                        

Epoch 36: train_loss=1.1333 | val_loss=0.7871, val_acc=0.8829
[CHECKPOINT] Saved new best model (val_acc=0.8829)


                                                                                        

Epoch 37: train_loss=1.1350 | val_loss=0.9435, val_acc=0.8070


                                                                                        

Epoch 38: train_loss=1.0887 | val_loss=0.9032, val_acc=0.8449


                                                                                        

Epoch 39: train_loss=1.0397 | val_loss=0.8738, val_acc=0.8576


                                                                                        

Epoch 40: train_loss=1.0366 | val_loss=0.9601, val_acc=0.7975


                                                                                        

Epoch 41: train_loss=1.1333 | val_loss=0.8219, val_acc=0.8703


                                                                                        

Epoch 42: train_loss=1.0026 | val_loss=0.8031, val_acc=0.8576


                                                                                        

Epoch 43: train_loss=1.0668 | val_loss=0.9270, val_acc=0.8196


                                                                                        

Epoch 44: train_loss=1.0326 | val_loss=0.7677, val_acc=0.9082
[CHECKPOINT] Saved new best model (val_acc=0.9082)


                                                                                        

Epoch 45: train_loss=1.0483 | val_loss=0.8164, val_acc=0.8703


                                                                                        

Epoch 46: train_loss=1.0782 | val_loss=0.9480, val_acc=0.7975


                                                                                        

Epoch 47: train_loss=1.0747 | val_loss=0.7788, val_acc=0.8703


                                                                                        

Epoch 48: train_loss=1.0063 | val_loss=0.7490, val_acc=0.8892


                                                                                        

Epoch 49: train_loss=1.0174 | val_loss=0.8257, val_acc=0.8576


                                                                                        

Epoch 50: train_loss=1.0278 | val_loss=0.7105, val_acc=0.9114
[CHECKPOINT] Saved new best model (val_acc=0.9114)


                                                                                        

Epoch 51: train_loss=1.0484 | val_loss=0.8565, val_acc=0.8576


                                                                                        

Epoch 52: train_loss=1.0974 | val_loss=0.8537, val_acc=0.8576


                                                                                        

Epoch 53: train_loss=0.9770 | val_loss=0.9978, val_acc=0.7911


                                                                                        

Epoch 54: train_loss=0.9455 | val_loss=0.7967, val_acc=0.8639


                                                                                        

Epoch 55: train_loss=1.0085 | val_loss=0.9330, val_acc=0.8259


                                                                                        

Epoch 56: train_loss=1.0298 | val_loss=0.7458, val_acc=0.9051


                                                                                        

Epoch 57: train_loss=1.1390 | val_loss=0.7627, val_acc=0.8956


                                                                                        

Epoch 58: train_loss=1.1111 | val_loss=0.7440, val_acc=0.9146
[CHECKPOINT] Saved new best model (val_acc=0.9146)


                                                                                        

Epoch 59: train_loss=0.9029 | val_loss=0.7195, val_acc=0.9019


                                                                                        

Epoch 60: train_loss=0.9672 | val_loss=0.7048, val_acc=0.9272
[CHECKPOINT] Saved new best model (val_acc=0.9272)


                                                                                        

Epoch 61: train_loss=1.0449 | val_loss=0.7955, val_acc=0.8734


                                                                                        

Epoch 62: train_loss=0.9658 | val_loss=0.7701, val_acc=0.8861


                                                                                        

Epoch 63: train_loss=0.9327 | val_loss=0.7556, val_acc=0.8861


                                                                                        

Epoch 64: train_loss=1.0081 | val_loss=0.7022, val_acc=0.9241


                                                                                        

Epoch 65: train_loss=0.9521 | val_loss=0.7283, val_acc=0.9146


                                                                                        

Epoch 66: train_loss=1.0743 | val_loss=0.7122, val_acc=0.9209


                                                                                        

Epoch 67: train_loss=0.9964 | val_loss=0.7101, val_acc=0.9114


                                                                                        

Epoch 68: train_loss=0.9985 | val_loss=0.7447, val_acc=0.9241


                                                                                        

Epoch 69: train_loss=1.0669 | val_loss=0.8109, val_acc=0.8924


                                                                                        

Epoch 70: train_loss=0.9942 | val_loss=0.7159, val_acc=0.9209


                                                                                        

Epoch 71: train_loss=0.9540 | val_loss=0.7397, val_acc=0.9019


                                                                                        

Epoch 72: train_loss=1.0662 | val_loss=0.8268, val_acc=0.8829


                                                                                        

Epoch 73: train_loss=1.0404 | val_loss=0.7883, val_acc=0.8924


                                                                                        

Epoch 74: train_loss=0.8825 | val_loss=0.7325, val_acc=0.9177


                                                                                        

Epoch 75: train_loss=0.9608 | val_loss=0.7864, val_acc=0.9019


                                                                                        

Epoch 76: train_loss=0.9544 | val_loss=0.8090, val_acc=0.8861


                                                                                        

Epoch 77: train_loss=0.9667 | val_loss=0.7247, val_acc=0.9241


                                                                                        

Epoch 78: train_loss=1.0195 | val_loss=0.7762, val_acc=0.8924


                                                                                        

Epoch 79: train_loss=0.9694 | val_loss=0.7324, val_acc=0.9051


                                                                                        

Epoch 80: train_loss=0.9594 | val_loss=0.7277, val_acc=0.9241


                                                                                        

Epoch 81: train_loss=1.0090 | val_loss=0.7623, val_acc=0.9019


                                                                                        

Epoch 82: train_loss=1.0532 | val_loss=0.6999, val_acc=0.9241


                                                                                        

Epoch 83: train_loss=1.0590 | val_loss=0.8105, val_acc=0.8892


                                                                                        

Epoch 84: train_loss=0.9635 | val_loss=0.7200, val_acc=0.9114


                                                                                        

Epoch 85: train_loss=0.9168 | val_loss=0.6953, val_acc=0.9241


                                                                                        

Epoch 86: train_loss=0.9464 | val_loss=0.7402, val_acc=0.9051


                                                                                        

Epoch 87: train_loss=1.1144 | val_loss=0.7382, val_acc=0.9146


                                                                                        

Epoch 88: train_loss=1.0682 | val_loss=0.7567, val_acc=0.9146


                                                                                        

Epoch 89: train_loss=0.9424 | val_loss=0.7072, val_acc=0.9209


                                                                                        

Epoch 90: train_loss=1.0316 | val_loss=0.7034, val_acc=0.9177


                                                                                        

Epoch 91: train_loss=1.0445 | val_loss=0.7193, val_acc=0.9177


                                                                                        

Epoch 92: train_loss=0.9131 | val_loss=0.7273, val_acc=0.8987


                                                                                        

Epoch 93: train_loss=0.9070 | val_loss=0.7039, val_acc=0.9209


                                                                                        

Epoch 94: train_loss=0.9762 | val_loss=0.7058, val_acc=0.9209


                                                                                        

Epoch 95: train_loss=0.9213 | val_loss=0.6976, val_acc=0.9209


                                                                                        

Epoch 96: train_loss=1.0591 | val_loss=0.7213, val_acc=0.9241


                                                                                        

Epoch 97: train_loss=0.9788 | val_loss=0.7130, val_acc=0.9177


                                                                                        

Epoch 98: train_loss=0.9608 | val_loss=0.6992, val_acc=0.9272


                                                                                        

Epoch 99: train_loss=0.9191 | val_loss=0.6954, val_acc=0.9272


                                                                                         

Epoch 100: train_loss=0.9858 | val_loss=0.7053, val_acc=0.9177


                                                                                         

Epoch 101: train_loss=0.8847 | val_loss=0.6745, val_acc=0.9209


                                                                                         

Epoch 102: train_loss=0.9634 | val_loss=0.7125, val_acc=0.9177


                                                                                         

Epoch 103: train_loss=0.9165 | val_loss=0.7043, val_acc=0.9114


                                                                                         

Epoch 104: train_loss=0.9970 | val_loss=0.6924, val_acc=0.9209


                                                                                         

Epoch 105: train_loss=1.0214 | val_loss=0.7206, val_acc=0.9241


                                                                                         

Epoch 106: train_loss=0.9898 | val_loss=0.7359, val_acc=0.9177


                                                                                         

Epoch 107: train_loss=0.9166 | val_loss=0.7026, val_acc=0.9146


                                                                                         

Epoch 108: train_loss=0.9499 | val_loss=0.7688, val_acc=0.9051


                                                                                         

Epoch 109: train_loss=0.9773 | val_loss=0.6838, val_acc=0.9241


                                                                                         

Epoch 110: train_loss=0.9368 | val_loss=0.7229, val_acc=0.9209


                                                                                         

Epoch 111: train_loss=0.9530 | val_loss=0.7131, val_acc=0.9114


                                                                                         

Epoch 112: train_loss=0.9764 | val_loss=0.7100, val_acc=0.9114


                                                                                         

Epoch 113: train_loss=0.9204 | val_loss=0.7148, val_acc=0.9209


                                                                                         

Epoch 114: train_loss=0.9416 | val_loss=0.7333, val_acc=0.9146


                                                                                         

Epoch 115: train_loss=0.9135 | val_loss=0.7115, val_acc=0.9177


                                                                                         

Epoch 116: train_loss=0.9290 | val_loss=0.6851, val_acc=0.9177


                                                                                         

Epoch 117: train_loss=0.9660 | val_loss=0.6910, val_acc=0.9114


                                                                                         

Epoch 118: train_loss=0.9591 | val_loss=0.7322, val_acc=0.9209


                                                                                         

Epoch 119: train_loss=0.9887 | val_loss=0.6906, val_acc=0.9272


                                                                                         

Epoch 120: train_loss=0.9384 | val_loss=0.6690, val_acc=0.9272


                                                                                         

Epoch 121: train_loss=1.0060 | val_loss=0.7120, val_acc=0.9335
[CHECKPOINT] Saved new best model (val_acc=0.9335)


                                                                                         

Epoch 122: train_loss=0.9754 | val_loss=0.7161, val_acc=0.9304


                                                                                         

Epoch 123: train_loss=0.8391 | val_loss=0.6969, val_acc=0.9335


                                                                                         

Epoch 124: train_loss=0.9602 | val_loss=0.7209, val_acc=0.9177


                                                                                         

Epoch 125: train_loss=0.9294 | val_loss=0.6953, val_acc=0.9367
[CHECKPOINT] Saved new best model (val_acc=0.9367)


                                                                                         

Epoch 126: train_loss=1.0327 | val_loss=0.7410, val_acc=0.9209


                                                                                         

Epoch 127: train_loss=0.9920 | val_loss=0.6918, val_acc=0.9304


                                                                                         

Epoch 128: train_loss=0.9261 | val_loss=0.6945, val_acc=0.9367


                                                                                         

Epoch 129: train_loss=0.9554 | val_loss=0.6776, val_acc=0.9209


                                                                                         

Epoch 130: train_loss=0.9248 | val_loss=0.6674, val_acc=0.9367


                                                                                         

Epoch 131: train_loss=0.9950 | val_loss=0.6966, val_acc=0.9304


                                                                                         

Epoch 132: train_loss=1.0117 | val_loss=0.7107, val_acc=0.9272


                                                                                         

Epoch 133: train_loss=0.9092 | val_loss=0.7034, val_acc=0.9272


                                                                                         

Epoch 134: train_loss=1.0452 | val_loss=0.7173, val_acc=0.9209


                                                                                         

Epoch 135: train_loss=0.8449 | val_loss=0.6931, val_acc=0.9272


                                                                                         

Epoch 136: train_loss=1.0056 | val_loss=0.6894, val_acc=0.9304


                                                                                         

Epoch 137: train_loss=0.9706 | val_loss=0.6935, val_acc=0.9335


                                                                                         

Epoch 138: train_loss=0.8853 | val_loss=0.6954, val_acc=0.9430
[CHECKPOINT] Saved new best model (val_acc=0.9430)


                                                                                         

Epoch 139: train_loss=0.9168 | val_loss=0.7009, val_acc=0.9367


                                                                                         

Epoch 140: train_loss=0.8573 | val_loss=0.6817, val_acc=0.9335


                                                                                         

Epoch 141: train_loss=0.8917 | val_loss=0.7011, val_acc=0.9272


                                                                                         

Epoch 142: train_loss=0.8846 | val_loss=0.6873, val_acc=0.9241


                                                                                         

Epoch 143: train_loss=0.8854 | val_loss=0.6723, val_acc=0.9272


                                                                                         

Epoch 144: train_loss=1.0090 | val_loss=0.6804, val_acc=0.9335


                                                                                         

Epoch 145: train_loss=0.8218 | val_loss=0.6769, val_acc=0.9304


                                                                                         

Epoch 146: train_loss=0.8410 | val_loss=0.6721, val_acc=0.9335


                                                                                         

Epoch 147: train_loss=0.9406 | val_loss=0.6947, val_acc=0.9304


                                                                                         

Epoch 148: train_loss=0.9326 | val_loss=0.7150, val_acc=0.9335


                                                                                         

Epoch 149: train_loss=1.0228 | val_loss=0.7098, val_acc=0.9399


                                                                                         

Epoch 150: train_loss=0.8985 | val_loss=0.6691, val_acc=0.9367




In [13]:
# Load best and test
if os.path.exists(CHECKPOINT_PATH):
    ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(ckpt['model_state'])
    print(f"[LOAD] loaded checkpoint from epoch {ckpt.get('epoch', '?')} with val_acc={ckpt.get('val_acc', 0):.4f}")

test_loss, test_acc = evaluate(model, test_loader, criterion, device, desc="Test")
print(f"Final Test: acc={test_acc:.4f}")


[LOAD] loaded checkpoint from epoch 138 with val_acc=0.9430


  with torch.cuda.amp.autocast(enabled=USE_AMP):
                                                                

Final Test: acc=0.9272


