In [1]:
import os
import time
import math
import random
from datetime import datetime


os.environ["CUDA_VISIBLE_DEVICES"] = "3" 

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
import torchvision.datasets as datasets
import timm
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda")  

torch.cuda.memory_allocated(device)   # bytes allocated by tensors
torch.cuda.memory_reserved(device)    # bytes reserved by allocator
torch.cuda.empty_cache()              # frees cached memory (non-deterministic improvement)

In [3]:
print("python:", os.sys.version.splitlines()[0])
print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)

python: 3.10.18 (main, Jun  5 2025, 13:14:17) [GCC 11.2.0]
torch: 2.6.0+cu124
torchvision: 0.21.0+cu124


In [4]:
def print_gpu_info():
    if torch.cuda.is_available():
        dev = torch.device("cuda")
        n = torch.cuda.device_count()
        print(f"CUDA available. {n} device(s):")
        for i in range(n):
            name = torch.cuda.get_device_name(i)
            cap = torch.cuda.get_device_capability(i)
            total = torch.cuda.get_device_properties(i).total_memory / (1024**3)
            print(f"  [{i}] {name}  (cap={cap}, mem={total:.1f} GB)")
    else:
        print("CUDA not available, using CPU.")
print_gpu_info()

CUDA available. 1 device(s):
  [0] Tesla V100-SXM2-32GB  (cap=(7, 0), mem=31.7 GB)


In [5]:
def set_seed(seed=42, deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        # may slow training, but makes some operations deterministic
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.benchmark = True  # good for speed on fixed-size inputs
set_seed(42, deterministic=False)


In [6]:
DATA_ROOT = "./data"
MODEL_NAME = "resnext101_32x8d"
PRETRAINED = True
NUM_CLASSES = 100
IMG_SIZE = 224
BATCH_SIZE = 64        # reduce if OOM: try 128 or 64
EPOCHS = 120
BASE_LR = 0.1           # intended for BATCH_SIZE=256 baseline
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4
WARMUP_EPOCHS = 5
NUM_WORKERS = 8
PIN_MEMORY = True
MIXUP_ALPHA = 0.8       # set 0.0 to disable mixup
LABEL_SMOOTHING = 0.1   # used only when mixup disabled
EMA_DECAY = 0.9999      # <=0 to disable EMA
GRAD_CLIP = None        # set number like 1.0 to enable
SEED = 42
USE_AMP = True          # mixed precision
PRINT_FREQ = 50
SAVE_DIR = "./checkpoints_linear"
device = "cuda" if torch.cuda.is_available() else "cpu"
os.makedirs(SAVE_DIR, exist_ok=True)

print("Python:", os.sys.version.splitlines()[0])
print("PyTorch:", torch.__version__)
print("Timm:", timm.__version__)

Python: 3.10.18 (main, Jun  5 2025, 13:14:17) [GCC 11.2.0]
PyTorch: 2.6.0+cu124
Timm: 1.0.21


In [7]:
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)


train_transform = T.Compose([
    T.RandomResizedCrop(IMG_SIZE, scale=(0.08, 1.0), ratio=(0.75, 1.3333)),
    T.RandomHorizontalFlip(p=0.5),
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

test_transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])


train_full = datasets.CIFAR100(root=DATA_ROOT, train=True, download=True, transform=train_transform)
test_set    = datasets.CIFAR100(root=DATA_ROOT, train=False, download=True, transform=test_transform)

val_size=5000
train_size=len(train_full)-val_size
torch.manual_seed(SEED)
train_set,val_set=torch.utils.data.random_split(train_full,[train_size,val_size],generator=torch.Generator().manual_seed(SEED))

val_set.dataset.transform=test_transform

train_loader=DataLoader(train_set,batch_size=BATCH_SIZE,shuffle=True,pin_memory=PIN_MEMORY,num_workers=NUM_WORKERS,drop_last=True)
val_loader=DataLoader(val_set,batch_size=BATCH_SIZE,shuffle=False,pin_memory=PIN_MEMORY,num_workers=NUM_WORKERS)
test_loader=DataLoader(test_set,batch_size=BATCH_SIZE,num_workers=NUM_WORKERS,pin_memory=PIN_MEMORY,shuffle=False)

print("Train/Val/Test batches : ", len(train_loader), len(val_loader), len(test_loader))

Train/Val/Test batches :  703 79 157


In [8]:
model=timm.create_model(MODEL_NAME,pretrained=PRETRAINED,num_classes=NUM_CLASSES)
model=model.to(device)

total_params=sum(p.numel() for p in model.parameters())
trainable_parameters=sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters : {total_params}, trainable parameters : {trainable_parameters}")

Total parameters : 86947236, trainable parameters : 86947236


In [9]:
model.eval()
with torch.no_grad():
    dummy = torch.zeros((1, 3, IMG_SIZE, IMG_SIZE), device=device)
    out = model(dummy)
print("Forward OK. Output shape:", tuple(out.shape))

Forward OK. Output shape: (1, 100)


In [10]:
lr=BASE_LR *(BATCH_SIZE/256.0)
optimizer=torch.optim.SGD(model.parameters(),lr=lr,momentum=MOMENTUM,weight_decay=WEIGHT_DECAY)

scaler=torch.cuda.amp.GradScaler(enabled=(USE_AMP and device=="cuda"))

if EMA_DECAY > 0:
    ema_model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)
    ema_model.load_state_dict(model.state_dict())   # copy initial weights
    ema_model.to(device)
    for p in ema_model.parameters():
        p.requires_grad_(False)
else:
    ema_model = None


def one_hot_smooth(labels, num_classes, smoothing, device):
    if smoothing > 0:
        off_value = smoothing / (num_classes - 1)
        on_value = 1.0 - smoothing
    else:
        off_value = 0.0
        on_value = 1.0
    y = torch.full((labels.size(0), num_classes), off_value, device=device)
    y.scatter_(1, labels.unsqueeze(1), on_value)
    return y

  scaler=torch.cuda.amp.GradScaler(enabled=(USE_AMP and device=="cuda"))


In [12]:
best_val_top1 = 0.0
min_lr = 1e-6


for epoch in range(EPOCHS):
    # ---- compute LR for this epoch (linear warmup then cosine anneal) ----
    if epoch < WARMUP_EPOCHS:
        cur_lr = lr * float(epoch + 1) / float(max(1, WARMUP_EPOCHS))
    else:
        t = float(epoch - WARMUP_EPOCHS) / float(max(1, EPOCHS - WARMUP_EPOCHS))
        cur_lr = min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(math.pi * t))
    for g in optimizer.param_groups:
        g['lr'] = cur_lr

    print(f"\nEpoch {epoch+1}/{EPOCHS} - lr: {cur_lr:.6f} - starting training...")
    model.train()
    running_loss = 0.0
    seen = 0
    t0 = time.time()

    pbar = enumerate(train_loader)
    for i, (imgs, targets) in pbar:
        imgs = imgs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # ----- MIXUP inline (or fall back to label smoothed one-hot) -----
        if MIXUP_ALPHA > 0:
            lam = np.random.beta(MIXUP_ALPHA, MIXUP_ALPHA)
            idx = torch.randperm(imgs.size(0)).to(device)
            imgs_mixed = lam * imgs + (1.0 - lam) * imgs[idx]
            # build soft targets
            y_a = one_hot_smooth(targets, NUM_CLASSES, 0.0, device)      # no smoothing in mixup construction
            y_b = one_hot_smooth(targets[idx], NUM_CLASSES, 0.0, device)
            soft_targets = lam * y_a + (1.0 - lam) * y_b
        else:
            imgs_mixed = imgs
            soft_targets = one_hot_smooth(targets, NUM_CLASSES, LABEL_SMOOTHING, device)

        optimizer.zero_grad()

        # forward + backward with AMP
        with torch.cuda.amp.autocast(enabled=(USE_AMP and device=="cuda")):
            logits = model(imgs_mixed)
            log_probs = F.log_softmax(logits, dim=1)
            loss = -(soft_targets * log_probs).sum(dim=1).mean()

        if USE_AMP and device=="cuda":
            scaler.scale(loss).backward()
            if GRAD_CLIP is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if GRAD_CLIP is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            optimizer.step()

        # EMA update inline (if enabled)
        if ema_model is not None:
            with torch.no_grad():
                decay = EMA_DECAY
                msd = model.state_dict()
                esd = ema_model.state_dict()
                for k in esd.keys():
                    target = esd[k]
                    source = msd[k].to(target.device)
                    # Only do EMA math for floating tensors; for integer/bool buffers just copy
                    if target.dtype.is_floating_point:
                        # ensure same dtype then ema update
                        source = source.type_as(target)
                        target.mul_(decay).add_(source, alpha=(1.0 - decay))
                    else:
                        # direct copy for non-float tensors (e.g., num_batches_tracked)
                        target.copy_(source)

        bs = imgs.size(0)
        running_loss += loss.item() * bs
        seen += bs

        if (i + 1) % PRINT_FREQ == 0 or (i+1) == len(train_loader):
            avg_loss = running_loss / seen
            elapsed = time.time() - t0
            print(f"Epoch {epoch+1} Batch {i+1}/{len(train_loader)}  AvgLoss: {avg_loss:.4f}  Time: {elapsed:.1f}s")

    # ---- end of epoch: validation using EMA model if exists, else student model ----
    eval_model = ema_model if (ema_model is not None) else model
    eval_model.eval()
    val_running_loss = 0.0
    total = 0
    top1_count = 0
    top5_count = 0
    with torch.no_grad():
        for imgs, targets in val_loader:
            imgs = imgs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            logits = eval_model(imgs)
            loss = F.cross_entropy(logits, targets)   # hard-label loss for reporting
            bs = imgs.size(0)
            val_running_loss += loss.item() * bs
            total += bs
            # top-k
            _, pred = logits.topk(5, dim=1, largest=True, sorted=True)
            correct = pred.eq(targets.view(-1,1).expand_as(pred))
            top1_count += correct[:, :1].reshape(-1).float().sum().item()
            top5_count += correct[:, :5].reshape(-1).float().sum().item()

    val_avg_loss = val_running_loss / total
    val_top1 = 100.0 * top1_count / total
    val_top5 = 100.0 * top5_count / total
    print(f"Epoch {epoch+1} VALID -> Loss: {val_avg_loss:.4f}  Top1: {val_top1:.3f}  Top5: {val_top5:.3f}")

    # Save checkpoint (student weights; also keep best)
    ckpt = {
        "epoch": epoch+1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "val_top1": val_top1,
        "cfg": {
            "MODEL_NAME": MODEL_NAME, "IMG_SIZE": IMG_SIZE, "BATCH_SIZE": BATCH_SIZE
        }
    }
    if epoch%5==0:
        torch.save(ckpt, os.path.join(SAVE_DIR, f"checkpoint_epoch_{epoch+1}.pth"))
    if val_top1 > best_val_top1:
        best_val_top1 = val_top1
        torch.save(ckpt, os.path.join(SAVE_DIR, "best_checkpoint.pth"))
        print("Saved new best checkpoint.")

# End of epoch loop
print("Training finished. Best val top1:", best_val_top1)


Epoch 1/120 - lr: 0.005000 - starting training...


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device=="cuda")):


Epoch 1 Batch 50/703  AvgLoss: 4.2151  Time: 25.1s
Epoch 1 Batch 100/703  AvgLoss: 3.5584  Time: 49.1s
Epoch 1 Batch 150/703  AvgLoss: 3.1548  Time: 72.9s
Epoch 1 Batch 200/703  AvgLoss: 3.0058  Time: 96.6s
Epoch 1 Batch 250/703  AvgLoss: 2.8722  Time: 120.4s
Epoch 1 Batch 300/703  AvgLoss: 2.7712  Time: 144.0s
Epoch 1 Batch 350/703  AvgLoss: 2.6912  Time: 166.0s
Epoch 1 Batch 400/703  AvgLoss: 2.6200  Time: 190.1s
Epoch 1 Batch 450/703  AvgLoss: 2.5913  Time: 214.2s
Epoch 1 Batch 500/703  AvgLoss: 2.5573  Time: 238.0s
Epoch 1 Batch 550/703  AvgLoss: 2.5109  Time: 262.0s
Epoch 1 Batch 600/703  AvgLoss: 2.4774  Time: 285.9s
Epoch 1 Batch 650/703  AvgLoss: 2.4508  Time: 310.2s
Epoch 1 Batch 700/703  AvgLoss: 2.4189  Time: 334.2s
Epoch 1 Batch 703/703  AvgLoss: 2.4154  Time: 335.1s
Epoch 1 VALID -> Loss: 4.4649  Top1: 3.540  Top5: 13.260
Saved new best checkpoint.

Epoch 2/120 - lr: 0.010000 - starting training...
Epoch 2 Batch 50/703  AvgLoss: 2.1309  Time: 25.2s
Epoch 2 Batch 100/703  A

In [13]:
final_model = ema_model if (ema_model is not None) else model
final_model.eval()
test_running_loss = 0.0
test_total = 0
test_top1 = 0
test_top5 = 0
with torch.no_grad():
    for imgs, targets in test_loader:
        imgs = imgs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        logits = final_model(imgs)
        loss = F.cross_entropy(logits, targets)
        bs = imgs.size(0)
        test_running_loss += loss.item() * bs
        test_total += bs
        _, pred = logits.topk(5, dim=1, largest=True, sorted=True)
        correct = pred.eq(targets.view(-1,1).expand_as(pred))
        test_top1 += correct[:, :1].reshape(-1).float().sum().item()
        test_top5 += correct[:, :5].reshape(-1).float().sum().item()

test_loss = test_running_loss / test_total
test_top1 = 100.0 * test_top1 / test_total
test_top5 = 100.0 * test_top5 / test_total
print(f"Final Test -> Loss: {test_loss:.4f}  Top1: {test_top1:.3f}  Top5: {test_top5:.3f}")

Final Test -> Loss: 0.7919  Top1: 82.200  Top5: 95.200
