In [1]:
import os, math, random

# Restrict this process to physical GPU 5 by default. `setdefault` keeps a device
# selection that the shell or a job scheduler has already exported instead of
# silently overwriting it. This runs ahead of the `import torch` further down so the
# variable is in place before torch is loaded.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "5")


import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Subset, random_split, ConcatDataset
import torchvision
import torchvision.transforms as T
import torchvision.datasets as datasets

# Report the interpreter / library versions and the compute device for this run.
# `sys` is imported explicitly: `os.sys` only works because the `os` module happens
# to import `sys` internally, which is not a documented attribute.
import sys

print("python:", sys.version.splitlines()[0])
print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)

# Number of CUDA devices visible to this process.
print(torch.cuda.device_count())

python: 3.10.18 (main, Jun  5 2025, 13:14:17) [GCC 11.2.0]
torch: 2.6.0+cu124
torchvision: 0.21.0+cu124
Device: cuda
1


In [4]:
# ---- Inference configuration ------------------------------------------------
DATA_ROOT = "./data"    # root directory passed to the CIFAR-100 dataset loaders
IMG_SIZE = 224          # images are resized to IMG_SIZE x IMG_SIZE before inference
BATCH_PRED = 256        # batch size of the prediction DataLoader
NUM_WORKERS = 8         # worker processes used by the prediction DataLoader

# Re-derived here so this cell yields a valid DEVICE even if the setup cell's value
# was changed in between.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_NAME = "resnext101_32x8d"  # architecture name handed to timm.create_model

# Directory and file name are passed as separate components; a single-argument
# os.path.join() is a no-op and hides the intended path structure.
BEST_CKPT = os.path.join("m2_retrained_checkpoints", "best_m2_retrained.pth")

SAVE_CSV = "predictions_60k_m2.csv"             # per-image predictions output file
SAVE_FULL_PROBS = False                         # when True, also store every softmax vector
SAVE_FULL_PROBS_PATH = "full_probs_60k_m2.npz"  # destination used when SAVE_FULL_PROBS is True

In [5]:
# Per-channel mean / std applied by the Normalize step below.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

# Inference-time preprocessing: resize to IMG_SIZE x IMG_SIZE, convert to a tensor,
# then normalise each channel. No random augmentation is involved.
_infer_steps = [
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
]
infer_transform = T.Compose(_infer_steps)

# Both CIFAR-100 splits share everything except the `train` flag. The data must
# already be present under DATA_ROOT because download is disabled.
_cifar_kwargs = dict(root=DATA_ROOT, download=False, transform=infer_transform)
train_ds = datasets.CIFAR100(train=True, **_cifar_kwargs)
test_ds = datasets.CIFAR100(train=False, **_cifar_kwargs)

# Concatenate train first, then test - matches the ordering used earlier, and the
# inference cell relies on it to map a global index back to its split.
concat_ds = ConcatDataset([train_ds, test_ds])
len_train, len_test = len(train_ds), len(test_ds)
print("Dataset sizes -> train:", len_train, " test:", len_test, " total:", len(concat_ds))

Dataset sizes -> train: 50000  test: 10000  total: 60000


In [7]:
import timm
def _unwrap_state_dict(checkpoint):
    """Return the parameter mapping held by a loaded checkpoint object.

    A dict checkpoint may wrap the weights under a ``"model_state"`` or
    ``"state_dict"`` key (checked in that order); otherwise the checkpoint itself
    is taken to be the state dict.

    Raises:
        TypeError: if the resulting object has no ``items()`` method and therefore
            cannot be iterated as a state dict.
    """
    state = checkpoint
    if isinstance(checkpoint, dict):
        for wrapper_key in ("model_state", "state_dict"):
            if wrapper_key in checkpoint:
                state = checkpoint[wrapper_key]
                break
    if not hasattr(state, "items"):
        raise TypeError(
            f"Checkpoint does not contain a state dict (got {type(state).__name__})"
        )
    return state


def _strip_module_prefix(state):
    """Return a plain dict copy of ``state`` with a leading ``"module."`` removed from each key."""
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }


# Sequential (unshuffled) loader so batch order follows concat_ds order.
pred_loader = DataLoader(concat_ds, batch_size=BATCH_PRED, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# Fail fast on a missing checkpoint before spending time building the model.
if not os.path.exists(BEST_CKPT):
    raise RuntimeError(f"Checkpoint not found at: {BEST_CKPT}")

# create model and load checkpoint
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=100)
model = model.to(DEVICE)

# Load onto the CPU: `ck` stays referenced in the notebook namespace, so mapping it
# to the GPU would keep a second copy of the weights in device memory.
# load_state_dict() copies the values into the model's existing parameters.
# weights_only=True limits unpickling to tensors and plain containers.
ck = torch.load(BEST_CKPT, map_location="cpu", weights_only=True)

state_dict = _unwrap_state_dict(ck)
new_state = _strip_module_prefix(state_dict)

# Default strict loading: raises if keys are missing or unexpected.
model.load_state_dict(new_state)
model.eval()
print("Loaded checkpoint weights into model. Starting inference...")

  from .autonotebook import tqdm as notebook_tqdm


Loaded checkpoint weights into model. Starting inference...


In [9]:
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset
import torchvision.transforms as T
import torchvision.datasets as datasets
import timm
import time
# ---- Inference over the concatenated train+test dataset ---------------------
# Produces:
#   rows        - one dict per image (global index, split, index within the split,
#                 true label, predicted label, top-1 probability)
#   df          - DataFrame built from `rows`, written to SAVE_CSV
#   probs_accum - per-batch softmax arrays, only collected when SAVE_FULL_PROBS is set
rows = []
probs_accum = [] if SAVE_FULL_PROBS else None

start_idx = 0
total_top1_correct = 0
total_top5_correct = 0
total_samples = 0
t0 = time.time()

with torch.no_grad():
    for imgs, labels in pred_loader:
        bs = imgs.size(0)
        imgs = imgs.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)

        logits = model(imgs)
        probs = F.softmax(logits, dim=1)
        top1_probs, top1_preds = probs.max(dim=1)
        top5_preds = probs.topk(5, dim=1, largest=True, sorted=True).indices  # [B,5]

        total_top1_correct += (top1_preds == labels).sum().item()
        # A sample counts as a top-5 hit when its label appears among its 5 best classes.
        total_top5_correct += (top5_preds == labels.unsqueeze(1)).any(dim=1).sum().item()

        # tolist() converts each tensor to native Python ints / floats in one call,
        # instead of converting NumPy scalars one element at a time.
        batch_records = zip(
            range(start_idx, start_idx + bs),
            labels.tolist(),
            top1_preds.tolist(),
            top1_probs.tolist(),
        )
        for global_idx, true_label, pred_label, top1_prob in batch_records:
            # The loader is unshuffled over [train, test], so the first len_train
            # global indices belong to the train split.
            in_train = global_idx < len_train
            rows.append({
                "global_idx": global_idx,
                "set": "train" if in_train else "test",
                "orig_index": global_idx if in_train else global_idx - len_train,
                "true_label": true_label,
                "pred_label": pred_label,
                "top1_prob": top1_prob
            })

        if SAVE_FULL_PROBS:
            probs_accum.append(probs.cpu().numpy())

        start_idx += bs
        total_samples += bs

# An empty loader would otherwise surface as a bare ZeroDivisionError below.
if total_samples == 0:
    raise RuntimeError("pred_loader yielded no samples; nothing to evaluate")

elapsed = time.time() - t0
top1_acc = 100.0 * total_top1_correct / total_samples
top5_acc = 100.0 * total_top5_correct / total_samples

print(f"Inference done in {elapsed:.1f}s over {total_samples} samples")
print(f"Final results over full 60k -> Top1: {top1_acc:.3f}  Top5: {top5_acc:.3f}")

df = pd.DataFrame(rows)

# The figure above pools both splits; also report top-1 accuracy per split so the
# train and test numbers can be read separately.
for split_name, split_df in df.groupby("set", sort=False):
    split_top1 = 100.0 * (split_df["true_label"] == split_df["pred_label"]).mean()
    print(f"  {split_name} -> Top1: {split_top1:.3f} over {len(split_df)} samples")

# save CSV
df.to_csv(SAVE_CSV, index=False)
print("Saved per-image predictions CSV:", SAVE_CSV)

if SAVE_FULL_PROBS:
    # Stack the per-batch [B, num_classes] arrays into a single [N, num_classes] array.
    all_probs = np.vstack(probs_accum)
    np.savez_compressed(SAVE_FULL_PROBS_PATH, probs=all_probs)
    print("Saved full probs NPZ:", SAVE_FULL_PROBS_PATH)

Inference done in 337.7s over 60000 samples
Final results over full 60k -> Top1: 97.750  Top5: 99.785
Saved per-image predictions CSV: predictions_60k_m2.csv
