In [1]:
import os, time, math
import numpy as np
import pandas as pd
import torch
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset, Subset
import torchvision.transforms as T
import torchvision.datasets as datasets
import timm


print("python:", os.sys.version.splitlines()[0])
print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)

  from .autonotebook import tqdm as notebook_tqdm


python: 3.10.18 (main, Jun  5 2025, 13:14:17) [GCC 11.2.0]
torch: 2.6.0+cu124
torchvision: 0.21.0+cu124


In [3]:
DATA_ROOT = "./data"
IMG_SIZE = 224
BATCH_PRED = 64
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "resnext101_32x8d"
PRETRAINED = False
CHECKPOINT_PATH = "./checkpoints_linear/best_checkpoint.pth"

SELECT_PCT = 0.375     # 37.5% of the lowest-confidence images

print("Device:", DEVICE)
print(torch.cuda.device_count())

Device: cuda
8


In [4]:
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)


infer_transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])


train_full = datasets.CIFAR100(root=DATA_ROOT, train=True, download=True, transform=infer_transform)
test_set   = datasets.CIFAR100(root=DATA_ROOT, train=False, download=True, transform=infer_transform)


concat_ds = ConcatDataset([train_full, test_set])
len_train = len(train_full)   
len_test  = len(test_set)     
print("train:", len_train, " test:", len_test, " concat:", len(concat_ds))

train: 50000  test: 10000  concat: 60000


In [6]:
model=timm.create_model(MODEL_NAME,pretrained=False,num_classes=100)
model=model.to(DEVICE)

ckpt=torch.load(CHECKPOINT_PATH,map_location=DEVICE)

if "model_state" in ckpt:
    model.load_state_dict(ckpt["model_state"])
elif "state_dict" in ckpt:
    model.load_state_dict(ckpt["state_dict"])
else:
    try:
        model.load_state_dict(ckpt)
    except Exception as e:
        raise RuntimeError("Couldn't find model weights key in checkpoint. Inspect the ckpt dict keys: " + str(ckpt.keys()))

model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act2): ReLU(inplace=True)
      (aa): Identity()
      (conv3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_

In [7]:
pred_loader=DataLoader(concat_ds,batch_size=BATCH_PRED,shuffle=False,num_workers=8,pin_memory=True)

rows=[]

SAVE_FULL_PROBS=True
probs_list=[]

start_idx=0
with torch.no_grad():
    for batch_imgs,batch_labels in pred_loader:
        batch_imgs=batch_imgs.to(DEVICE,non_blocking=True)
        logits=model(batch_imgs)
        probs=F.softmax(logits,dim=1)
        top1_probs,top1_preds=probs.max(dim=1)


        top1p_np=top1_probs.cpu().numpy()
        top1preds_np=top1_preds.cpu().numpy()
        labels_np=batch_labels.numpy()

        batch_size=labels_np.shape[0]
        for i in range(batch_size):
            global_idx=start_idx+i

            if global_idx < len_train:
                origin="train"
                orig_index=global_idx
            else:
                origin="test"
                orig_index=global_idx-len_train
            
            rows.append({
                "global_idx":int(global_idx),
                "set":origin,
                "orgi_index":int(orig_index),
                "true_label":int(labels_np[i]),
                "pred_label":int(top1preds_np[i]),
                "top1_prob":float(top1p_np[i])
            })
        if SAVE_FULL_PROBS:
            probs_list.append(probs.cpu().numpy())

        start_idx+=batch_size
    
df_all=pd.DataFrame(rows)
print("DataFrame shape : ",df_all.shape)

print(f"Top1_prob min : { df_all.top1_prob.min()}")
print(f"max :{df_all.top1_prob.max()} ")
print(f"mean : {df_all.top1_prob.mean()}")

csv_path=os.path.join("predictions_60k.csv")
df_all.to_csv(csv_path,index=False)

if SAVE_FULL_PROBS:
    all_probs = np.vstack(probs_list)
    np.savez_compressed("full_probs_60k.npz", probs=all_probs)
    print("Saved full probs NPZ: full_probs_60k.npz")

DataFrame shape :  (60000, 6)
Top1_prob min : 0.07220954447984695
max :0.9999936819076538 
mean : 0.8777095582171033
Saved full probs NPZ: full_probs_60k.npz


In [8]:
df_all.head(10)

Unnamed: 0,global_idx,set,orgi_index,true_label,pred_label,top1_prob
0,0,train,0,19,19,0.94611
1,1,train,1,29,29,0.802069
2,2,train,2,0,0,0.880476
3,3,train,3,11,11,0.949247
4,4,train,4,1,1,0.910771
5,5,train,5,86,86,0.947762
6,6,train,6,90,90,0.954904
7,7,train,7,28,28,0.926097
8,8,train,8,23,23,0.951782
9,9,train,9,31,31,0.972207


In [9]:
n_total=len(df_all)
n_select=int(math.ceil(n_total*SELECT_PCT))
print(f"Selecting : {n_select} images out of {n_total}")

df_sorted=df_all.sort_values(by="top1_prob",ascending=True).reset_index(drop=True)
selected_df=df_sorted.iloc[:n_select].copy()
selected_csv=os.path.join("selected_low_confidence.csv")
selected_df.to_csv(selected_csv,index=False)
print("Saved selected indices CSV",selected_csv)

selected_global_indices=selected_df["global_idx"].astype(int).tolist()

Selecting : 22500 images out of 60000
Saved selected indices CSV selected_low_confidence.csv


In [13]:
selected_df.head()

Unnamed: 0,global_idx,set,orgi_index,true_label,pred_label,top1_prob
0,54840,test,4840,77,44,0.07221
1,51165,test,1165,44,93,0.086189
2,51603,test,1603,27,96,0.088341
3,50403,test,403,91,32,0.091682
4,10272,train,10272,7,6,0.093572
