In [1]:
import os, math, random

os.environ["CUDA_VISIBLE_DEVICES"] = "5" 


import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Subset, random_split, ConcatDataset
import torchvision
import torchvision.transforms as T
import torchvision.datasets as datasets

print("python:", os.sys.version.splitlines()[0])
print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
print(torch.cuda.device_count())

python: 3.10.18 (main, Jun  5 2025, 13:14:17) [GCC 11.2.0]
torch: 2.6.0+cu124
torchvision: 0.21.0+cu124
Device: cuda
1


In [2]:
DATA_ROOT = "./data"
IMG_SIZE = 224
BATCH_M2 = 64
NUM_WORKERS = 8
PIN_MEMORY = True
SELECTED_CSV = "selected_low_confidence.csv"
VAL_RATIO = 0.10
SEED = 42

In [3]:
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

train_transform_m2 = T.Compose([
    T.RandomResizedCrop(IMG_SIZE, scale=(0.08, 1.0), ratio=(0.75, 1.3333)),
    T.RandomHorizontalFlip(p=0.5),
    # T.RandAugment(num_ops=2, magnitude=9),  # optional
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])


train_ds_for_concat = datasets.CIFAR100(root=DATA_ROOT, train=True, download=False, transform=train_transform_m2)
test_ds_for_concat  = datasets.CIFAR100(root=DATA_ROOT, train=False, download=False, transform=train_transform_m2)
concat_aug = ConcatDataset([train_ds_for_concat, test_ds_for_concat])
print("Concat dataset length:", len(concat_aug))


df_sel=pd.read_csv(SELECTED_CSV)
sel_global_idx=df_sel["global_idx"].astype(int).tolist()
print(f"Loaded {len(sel_global_idx)} selected indices from {SELECTED_CSV}")

selected_subset = Subset(concat_aug, sel_global_idx)
n_total = len(selected_subset)
n_val = max(1, int(round(VAL_RATIO * n_total)))
n_train = n_total - n_val
print("Selected subset size:", n_total, " -> train:", n_train, " val:", n_val)


torch.manual_seed(SEED)
train_subset, val_subset = random_split(selected_subset, [n_train, n_val])

train_loader_m2 = DataLoader(train_subset, batch_size=BATCH_M2, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=True)
val_loader_m2   = DataLoader(val_subset,   batch_size=BATCH_M2, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print("Train loader batches:", len(train_loader_m2), " Val loader batches:", len(val_loader_m2))

Concat dataset length: 60000
Loaded 22500 selected indices from selected_low_confidence.csv
Selected subset size: 22500  -> train: 20250  val: 2250
Train loader batches: 316  Val loader batches: 36


In [4]:
import os, time, math
import numpy as np
import torch, torch.nn.functional as F
import timm


MODEL_NAME = "resnext101_32x8d"
CHECKPOINT_PATH = "./checkpoints_linear/best_checkpoint.pth"   # path to your original best model checkpoint
M2_SAVE_DIR = "m2_retrained_checkpoints"
os.makedirs(M2_SAVE_DIR, exist_ok=True)

M2_EPOCHS = 40
USE_MIXUP_M2 = True
MIXUP_ALPHA_M2 = 0.8
BASE_LR = 0.01 * (BATCH_M2 / 256.0)
MIN_LR = 1e-6
WEIGHT_DECAY = 1e-4
MOMENTUM = 0.9
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading model and checkpoint:", CHECKPOINT_PATH)

model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=100)
model = model.to(DEVICE)

ck = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

if "model_state" in ck:
    state_dict = ck["model_state"]
elif "state_dict" in ck:
    state_dict = ck["state_dict"]
else:
    try:
        model.load_state_dict(ck)
        print("Loaded raw state_dict from checkpoint.")
        state_dict = None
    except Exception as e:
        raise RuntimeError("Could not find model_state or state_dict in checkpoint. Inspect keys: " + str(list(ck.keys())))


if state_dict is not None:
    new_state = {}
    for k, v in state_dict.items():
        new_k = k
        if k.startswith("module."):
            new_k = k[len("module."):]
        new_state[new_k] = v
    model.load_state_dict(new_state)
    print("Loaded model weights from checkpoint keys.")

  from .autonotebook import tqdm as notebook_tqdm


Loading model and checkpoint: ./checkpoints_linear/best_checkpoint.pth
Loaded model weights from checkpoint keys.


In [5]:

optimizer = torch.optim.SGD(model.parameters(), lr=BASE_LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=="cuda"))

best_val_top1=0.0
start_epoch=0


for epoch in range(start_epoch, M2_EPOCHS):
    t = epoch / float(max(1, M2_EPOCHS - 1))
    cur_lr = MIN_LR + 0.5 * (BASE_LR - MIN_LR) * (1.0 + math.cos(math.pi * t))
    for g in optimizer.param_groups:
        g['lr'] = cur_lr

    model.train()
    running_loss = 0.0
    seen = 0
    t0 = time.time()
    for i, (imgs, targets) in enumerate(train_loader_m2):
        imgs = imgs.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        if USE_MIXUP_M2:
            lam = np.random.beta(MIXUP_ALPHA_M2, MIXUP_ALPHA_M2)
            idx = torch.randperm(imgs.size(0)).to(DEVICE)
            imgs_m = lam * imgs + (1.0 - lam) * imgs[idx]
            y_a = torch.zeros((imgs.size(0), 100), device=DEVICE).scatter_(1, targets.unsqueeze(1), 1.0)
            y_b = torch.zeros_like(y_a).scatter_(1, targets[idx].unsqueeze(1), 1.0)
            soft_targets = lam * y_a + (1.0 - lam) * y_b
            loss_fn = lambda logits, soft: -(F.log_softmax(logits, dim=1) * soft).sum(dim=1).mean()
            inputs = imgs_m
        else:
            inputs = imgs
            loss_fn = lambda logits, targs: F.cross_entropy(logits, targs)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=(DEVICE=="cuda")):
            logits = model(inputs)
            if USE_MIXUP_M2:
                loss = loss_fn(logits, soft_targets)
            else:
                loss = loss_fn(logits, targets)

        if DEVICE == "cuda":
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        running_loss += float(loss.item()) * imgs.size(0)
        seen += imgs.size(0)
        if (i + 1) % 20 == 0 or (i+1) == len(train_loader_m2):
            print(f"Epoch {epoch+1}/{M2_EPOCHS} Batch {i+1}/{len(train_loader_m2)} AvgLoss:{running_loss/seen:.4f} Time:{time.time()-t0:.1f}s LR:{cur_lr:.6f}")

    model.eval()
    val_running_loss = 0.0
    val_total = 0
    val_top1_count = 0
    val_top5_count = 0
    with torch.no_grad():
        for imgs, targets in val_loader_m2:
            imgs = imgs.to(DEVICE, non_blocking=True)
            targets = targets.to(DEVICE, non_blocking=True)
            logits = model(imgs)
            loss_v = F.cross_entropy(logits, targets)
            bs = imgs.size(0)
            val_running_loss += float(loss_v.item()) * bs
            val_total += bs
            _, pred = logits.topk(5, dim=1, largest=True, sorted=True)
            correct = pred.eq(targets.view(-1,1).expand_as(pred))
            val_top1_count += correct[:, :1].reshape(-1).float().sum().item()
            val_top5_count += correct[:, :5].reshape(-1).float().sum().item()

    val_loss = val_running_loss / val_total
    val_top1 = 100.0 * val_top1_count / val_total
    val_top5 = 100.0 * val_top5_count / val_total
    print(f"Epoch {epoch+1} VALID -> Loss: {val_loss:.4f} Top1: {val_top1:.3f} Top5: {val_top5:.3f}")

    ckpt = {
        "epoch": epoch+1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "val_top1": val_top1,
        "cfg": {"BASE_LR": BASE_LR, "M2_EPOCHS": M2_EPOCHS}
    }
    if epoch%5==0:
        torch.save(ckpt, os.path.join(M2_SAVE_DIR, f"retrain_epoch_{epoch+1}.pth"))
    if val_top1 > best_val_top1:
        best_val_top1 = val_top1
        torch.save(ckpt, os.path.join(M2_SAVE_DIR, "best_m2_retrained.pth"))
        print("Saved new best_m2_retrained.pth (val_top1 improved to {:.3f})".format(val_top1))

print("Retraining finished. Best val top1:", best_val_top1)

  scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=="cuda"))
  with torch.cuda.amp.autocast(enabled=(DEVICE=="cuda")):


Epoch 1/40 Batch 20/316 AvgLoss:3.0787 Time:11.4s LR:0.002500
Epoch 1/40 Batch 40/316 AvgLoss:2.9974 Time:20.1s LR:0.002500
Epoch 1/40 Batch 60/316 AvgLoss:2.9132 Time:28.7s LR:0.002500
Epoch 1/40 Batch 80/316 AvgLoss:2.8637 Time:37.7s LR:0.002500
Epoch 1/40 Batch 100/316 AvgLoss:2.7912 Time:46.5s LR:0.002500
Epoch 1/40 Batch 120/316 AvgLoss:2.7663 Time:55.6s LR:0.002500
Epoch 1/40 Batch 140/316 AvgLoss:2.7342 Time:64.2s LR:0.002500
Epoch 1/40 Batch 160/316 AvgLoss:2.6946 Time:73.1s LR:0.002500
Epoch 1/40 Batch 180/316 AvgLoss:2.7127 Time:82.1s LR:0.002500
Epoch 1/40 Batch 200/316 AvgLoss:2.7076 Time:91.2s LR:0.002500
Epoch 1/40 Batch 220/316 AvgLoss:2.6900 Time:100.0s LR:0.002500
Epoch 1/40 Batch 240/316 AvgLoss:2.6599 Time:108.5s LR:0.002500
Epoch 1/40 Batch 260/316 AvgLoss:2.6407 Time:117.1s LR:0.002500
Epoch 1/40 Batch 280/316 AvgLoss:2.6165 Time:125.7s LR:0.002500
Epoch 1/40 Batch 300/316 AvgLoss:2.6057 Time:134.4s LR:0.002500
Epoch 1/40 Batch 316/316 AvgLoss:2.6091 Time:141.3s LR