In [None]:
# ===== 1) Setup & Config =====
import os, random, json, math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch import amp

SEED = 42
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PROJECT_DIR = os.getcwd()
OUT_DIR = os.path.join(PROJECT_DIR, "cifar10_out"); os.makedirs(OUT_DIR, exist_ok=True)

BATCH_SIZE = 128
NUM_WORKERS = 4  # if Windows notebook crashes on test, set to 0 for test loader

# toggles
USE_MIXUP_CUTMIX = True     # generalization trick (optional)
USE_FOCAL_LOSS   = False    # imbalance method 1 (set True to use focal instead of CE)
USE_CLASS_WEIGHTS= True     # imbalance method 2 (class-weighted CE or focal weighting)
USE_WEIGHTED_SAMPLER = True # imbalance method 3 (resampling minority classes)

EPOCHS = 20
LR = 3e-4
WD = 1e-4
 

In [2]:
# ===== 2) Data: CIFAR-10 Download + Augmentations =====
# We resize CIFAR-10 (32x32) to EfficientNet-B3's expected size ~ 300
IMG_SIZE = 96

train_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914,0.4822,0.4465), std=(0.2023,0.1994,0.2010)),
])

test_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914,0.4822,0.4465), std=(0.2023,0.1994,0.2010)),
])

DATA_DIR = os.path.join(PROJECT_DIR, "datasets_cifar10")
os.makedirs(DATA_DIR, exist_ok=True)

train_set = datasets.CIFAR10(root=DATA_DIR, train=True,  download=True, transform=train_tf)
test_set  = datasets.CIFAR10(root=DATA_DIR, train=False, download=True, transform=test_tf)

num_classes = 10
class_names = train_set.classes  # ['airplane','automobile',...]
class_names


Files already downloaded and verified
Files already downloaded and verified


['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [3]:
# ===== 3) Imbalance utilities =====
from collections import Counter
import torch.utils.data as tud

# Count training samples per class
targets = np.array(train_set.targets)
counts = Counter(targets.tolist())
counts_list = [counts[i] for i in range(num_classes)]
print("Train counts per class:", dict(counts))

# Class weights = inverse frequency (normalized)
inv = 1.0 / (np.array(counts_list) + 1e-8)
class_weights = inv / inv.sum() * num_classes
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32, device=DEVICE)
print("Class weights:", class_weights)

# WeightedRandomSampler (higher prob for minority)
if USE_WEIGHTED_SAMPLER:
    sample_weights = torch.tensor([1.0 / (counts[t] + 1e-8) for t in targets], dtype=torch.float32)
    sampler = tud.WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, sampler=sampler,
                              num_workers=NUM_WORKERS, pin_memory=True)
else:
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=True)

test_loader  = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True)


Train counts per class: {6: 5000, 9: 5000, 4: 5000, 1: 5000, 2: 5000, 7: 5000, 8: 5000, 3: 5000, 5: 5000, 0: 5000}
Class weights: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [4]:
# ===== 4) Model & Loss =====
import timm
from timm.data import Mixup
from timm.loss import SoftTargetCrossEntropy

model = timm.create_model('efficientnet_b3', pretrained=True, num_classes=num_classes).to(DEVICE)

# Focal loss (simple implementation)
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha  # Tensor of shape [C] or scalar
        self.gamma = gamma
        self.reduction = reduction
    def forward(self, logits, target):
        # logits: [B, C], target: [B] (int64)
        ce = nn.functional.cross_entropy(logits, target, weight=self.alpha, reduction='none')
        pt = torch.exp(-ce)  # pt = 1 - ce' (approx via CE)
        loss = ((1 - pt) ** self.gamma) * ce
        if self.reduction == 'mean': return loss.mean()
        if self.reduction == 'sum':  return loss.sum()
        return loss

# Loss selection (class-weighted variants if enabled)
if USE_FOCAL_LOSS:
    alpha = class_weights_tensor if USE_CLASS_WEIGHTS else None
    criterion = FocalLoss(alpha=alpha, gamma=2.0, reduction='mean')
    print("Using FocalLoss (gamma=2.0)", "with class weights" if USE_CLASS_WEIGHTS else "")
else:
    if USE_CLASS_WEIGHTS:
        criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
        print("Using class-weighted CrossEntropyLoss")
    else:
        criterion = nn.CrossEntropyLoss()
        print("Using standard CrossEntropyLoss")

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# Optional Mixup/CutMix
mixup_fn = None
if USE_MIXUP_CUTMIX:
    mixup_fn = Mixup(
        mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
        prob=1.0, switch_prob=0.5, mode='batch', label_smoothing=0.0,
        num_classes=num_classes
    )
    soft_criterion = SoftTargetCrossEntropy()


Using class-weighted CrossEntropyLoss


In [5]:
# ===== 5) Train / Evaluate =====

from tqdm import tqdm
from torch import amp

def evaluate(model, loader):
    model.eval()
    correct, total, loss_sum = 0, 0, 0.0
    with torch.no_grad(), amp.autocast('cuda', enabled=(DEVICE=='cuda')):
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            logits = model(x)
            loss = criterion(logits, y)   # use hard-label CE for eval
            loss_sum += loss.item() * x.size(0)
            pred = logits.argmax(1)
            correct += (pred == y).sum().item()
            total += x.size(0)
    return loss_sum/total, correct/total

scaler = amp.GradScaler('cuda', enabled=(DEVICE=='cuda'))
best_acc = 0.0
best_path = os.path.join(OUT_DIR, "cifar10_efficientnet_b3_best.pt")

for epoch in range(1, EPOCHS+1):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
    for x, y in pbar:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad(set_to_none=True)
        with amp.autocast('cuda', enabled=(DEVICE=='cuda')):
            if mixup_fn is not None:
                # ✅ pass raw int labels to mixup; it returns soft labels
                x, y_soft = mixup_fn(x, y)
                logits = model(x)
                loss = soft_criterion(logits, y_soft)
            else:
                logits = model(x)
                loss = criterion(logits, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        pbar.set_postfix(loss=float(loss.item()))
    scheduler.step()

    val_loss, val_acc = evaluate(model, test_loader)
    print(f"[Test] loss={val_loss:.4f} acc={val_acc*100:.2f}%")
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), best_path)
        print("==> Saved new best:", best_path)

print("Best Test Acc:", best_acc)


Epoch 1/20: 100%|██████████| 391/391 [01:07<00:00,  5.81it/s, loss=0.74] 


[Test] loss=0.5230 acc=86.41%
==> Saved new best: d:\A\IE4483_project\cifar10_out\cifar10_efficientnet_b3_best.pt


Epoch 2/20: 100%|██████████| 391/391 [01:05<00:00,  5.96it/s, loss=1.62] 


[Test] loss=0.4753 acc=91.00%
==> Saved new best: d:\A\IE4483_project\cifar10_out\cifar10_efficientnet_b3_best.pt


Epoch 3/20: 100%|██████████| 391/391 [01:05<00:00,  5.95it/s, loss=1.11] 


[Test] loss=0.3373 acc=92.51%
==> Saved new best: d:\A\IE4483_project\cifar10_out\cifar10_efficientnet_b3_best.pt


Epoch 4/20: 100%|██████████| 391/391 [01:05<00:00,  5.94it/s, loss=1.26] 


[Test] loss=0.3797 acc=93.32%
==> Saved new best: d:\A\IE4483_project\cifar10_out\cifar10_efficientnet_b3_best.pt


Epoch 5/20: 100%|██████████| 391/391 [01:05<00:00,  5.97it/s, loss=1.09] 


[Test] loss=0.3380 acc=93.58%
==> Saved new best: d:\A\IE4483_project\cifar10_out\cifar10_efficientnet_b3_best.pt


Epoch 6/20: 100%|██████████| 391/391 [01:05<00:00,  5.96it/s, loss=1.38] 


[Test] loss=0.3884 acc=94.13%
==> Saved new best: d:\A\IE4483_project\cifar10_out\cifar10_efficientnet_b3_best.pt


Epoch 7/20: 100%|██████████| 391/391 [01:05<00:00,  5.95it/s, loss=1.46] 


[Test] loss=0.3562 acc=94.80%
==> Saved new best: d:\A\IE4483_project\cifar10_out\cifar10_efficientnet_b3_best.pt


Epoch 8/20: 100%|██████████| 391/391 [01:05<00:00,  5.97it/s, loss=1.17] 


[Test] loss=0.3413 acc=95.20%
==> Saved new best: d:\A\IE4483_project\cifar10_out\cifar10_efficientnet_b3_best.pt


Epoch 9/20: 100%|██████████| 391/391 [01:05<00:00,  5.96it/s, loss=1.23] 


[Test] loss=0.3237 acc=95.02%


Epoch 10/20: 100%|██████████| 391/391 [01:05<00:00,  5.94it/s, loss=1.43] 


[Test] loss=0.2944 acc=95.23%
==> Saved new best: d:\A\IE4483_project\cifar10_out\cifar10_efficientnet_b3_best.pt


Epoch 11/20: 100%|██████████| 391/391 [01:05<00:00,  5.95it/s, loss=1.26]  


[Test] loss=0.2913 acc=95.51%
==> Saved new best: d:\A\IE4483_project\cifar10_out\cifar10_efficientnet_b3_best.pt


Epoch 12/20: 100%|██████████| 391/391 [01:05<00:00,  5.97it/s, loss=1.39]  


[Test] loss=0.2537 acc=95.77%
==> Saved new best: d:\A\IE4483_project\cifar10_out\cifar10_efficientnet_b3_best.pt


Epoch 13/20: 100%|██████████| 391/391 [01:05<00:00,  5.94it/s, loss=0.345]


[Test] loss=0.2317 acc=96.08%
==> Saved new best: d:\A\IE4483_project\cifar10_out\cifar10_efficientnet_b3_best.pt


Epoch 14/20: 100%|██████████| 391/391 [01:05<00:00,  5.97it/s, loss=1.39]  


[Test] loss=0.2805 acc=96.01%


Epoch 15/20: 100%|██████████| 391/391 [01:05<00:00,  5.97it/s, loss=1.29]  


[Test] loss=0.3013 acc=96.00%


Epoch 16/20: 100%|██████████| 391/391 [01:06<00:00,  5.92it/s, loss=1.25] 


[Test] loss=0.2634 acc=96.08%


Epoch 17/20: 100%|██████████| 391/391 [01:05<00:00,  5.97it/s, loss=1.15]  


[Test] loss=0.2552 acc=96.29%
==> Saved new best: d:\A\IE4483_project\cifar10_out\cifar10_efficientnet_b3_best.pt


Epoch 18/20: 100%|██████████| 391/391 [01:05<00:00,  5.95it/s, loss=0.985] 


[Test] loss=0.2568 acc=96.26%


Epoch 19/20: 100%|██████████| 391/391 [01:05<00:00,  5.97it/s, loss=1.2]   


[Test] loss=0.2665 acc=96.28%


Epoch 20/20: 100%|██████████| 391/391 [01:05<00:00,  5.94it/s, loss=1.48] 


[Test] loss=0.2733 acc=96.31%
==> Saved new best: d:\A\IE4483_project\cifar10_out\cifar10_efficientnet_b3_best.pt
Best Test Acc: 0.9631


In [6]:
# ===== 6) Per-class accuracy & confusion matrix =====
import torch

def eval_per_class(model, loader, num_classes, names):
    model.eval()
    correct = torch.zeros(num_classes, dtype=torch.long)
    total   = torch.zeros(num_classes, dtype=torch.long)
    with torch.no_grad(), amp.autocast('cuda', enabled=(DEVICE=='cuda')):
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            logits = model(x)
            pred = logits.argmax(1)
            for c in range(num_classes):
                mask = (y==c)
                total[c]   += mask.sum().item()
                correct[c] += ((pred==c)&mask).sum().item()
    acc_pc = (correct.float() / total.clamp(min=1).float()) * 100.0
    return {names[i]: float(acc_pc[i].item()) for i in range(num_classes)}

best_model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=num_classes).to(DEVICE)
best_model.load_state_dict(torch.load(best_path, map_location=DEVICE))
per_class = eval_per_class(best_model, test_loader, num_classes, class_names)
print(json.dumps(per_class, indent=2))


  best_model.load_state_dict(torch.load(best_path, map_location=DEVICE))


{
  "airplane": 98.0999984741211,
  "automobile": 98.0999984741211,
  "bird": 96.69999694824219,
  "cat": 90.30000305175781,
  "deer": 96.20000457763672,
  "dog": 92.9000015258789,
  "frog": 98.19999694824219,
  "horse": 96.9000015258789,
  "ship": 98.29999542236328,
  "truck": 97.39999389648438
}
