In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import torch 
import torchvision
from torch.utils.data import random_split, DataLoader, Subset
import torchvision.transforms as transforms   

# --- configuration -------------------------------------------------------------
DATA_ROOT = './data'
BATCH_SIZE = 48
NUM_WORKERS = 4
TRAIN_FRACTION = 0.9   # share of the 50k CIFAR-10 training images kept for training
SPLIT_SEED = 56        # fixed so the train/val split is reproducible across runs

# NOTE: no Normalize step is applied, so images stay in [0, 1] after ToTensor().
# Presumably the hub models loaded below expect un-normalized inputs (TODO confirm).
# The CIFAR-10 statistics previously kept here commented out were
# mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010).

# Training-time augmentation: random 32x32 crop from a 4px-padded image + horizontal flip.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Deterministic transform used for evaluation (validation and test).
test_transform = transforms.Compose([
    transforms.ToTensor(),
])

# --- datasets ------------------------------------------------------------------
trainset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=train_transform)
# Same training images without augmentation; the validation split is drawn from this view
# so validation metrics are deterministic rather than computed on randomly cropped/flipped images.
trainset_eval = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=test_transform)
testset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=test_transform)

# Split train into train/val (indices are drawn once, with a seeded generator)
train_size = int(TRAIN_FRACTION * len(trainset))
val_size = len(trainset) - train_size
train_subset, val_split = random_split(trainset, [train_size, val_size], generator=torch.Generator().manual_seed(SPLIT_SEED))
# Re-use exactly the same validation indices, but on the un-augmented view.
val_subset = Subset(trainset_eval, val_split.indices)

assert len(train_subset) + len(val_subset) == len(trainset)

# --- loaders -------------------------------------------------------------------
trainloader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
valloader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

print(f"Train: {len(train_subset)}, Val: {len(val_subset)}, Test: {len(testset)}")

Train: 45000, Val: 5000, Test: 10000


### Load Models 

In [3]:
import model_helper as helper

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model-hub architecture names to evaluate; add entries (e.g. "efficientvit_b0") to compare more models.
target_models_args = [
    "gcvit_tiny",
    # "efficientvit_b0"
]

# Load each architecture, place it on `device`, and switch it to inference mode.
target_models = []
for arch in target_models_args:
    net = helper.load_model_hub(arch).to(device)
    net.eval()
    target_models.append(net)

In [4]:
import torch
import torch.nn.functional as F
from tqdm import tqdm 

# Top-1 accuracy (in percent) of each target model on the full test set, keyed by model name.
accuracies = {}

with torch.no_grad():
    for model_name, model in zip(target_models_args, target_models):
        model.to(device)
        model.eval()

        n_hits, n_seen = 0, 0
        for images, labels in tqdm(testloader):
            images = images.to(device)
            labels = labels.to(device)
            hits = model(images).argmax(dim=1).eq(labels)
            n_hits += hits.sum().item()
            n_seen += labels.size(0)

        acc = n_hits / n_seen * 100
        accuracies[model_name] = acc
        print(f"{model_name}: {acc:.2f}%")

100%|███████████████████████████████████████████████████████| 209/209 [00:21<00:00,  9.51it/s]

gcvit_tiny: 93.05%





### Select 1000 test images that every target model classifies correctly

In [11]:
import torch
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

N_SELECTED = 1000  # size of the curated test subset

# Dataset indices are reconstructed from batch order, which is only valid when the
# loader walks `testset` sequentially (shuffle=False).
assert isinstance(testloader.sampler, torch.utils.data.SequentialSampler), \
    "testloader must not shuffle: sample indices are derived from batch position"

# Prepare every model once rather than once per batch.
for model in target_models:
    model.to(device)
    model.eval()

# Indices (into `testset`) of samples correctly predicted by all models
all_correct_indices = []
offset = 0  # dataset index of the first sample in the current batch

with torch.no_grad():
    for images, labels in tqdm(testloader):
        images, labels = images.to(device), labels.to(device)
        batch_size = labels.size(0)

        all_correct = torch.ones(batch_size, dtype=torch.bool, device=device)
        for model in target_models:
            preds = model(images).argmax(dim=1)
            all_correct &= (preds == labels)  # keep only those correct for all models

        batch_indices = torch.arange(offset, offset + batch_size)
        all_correct_indices.extend(batch_indices[all_correct.cpu()].tolist())
        offset += batch_size

# Keep the first N_SELECTED all-correct images; everything else is "remaining".
selected_indices = all_correct_indices[:N_SELECTED]
selected_set = set(selected_indices)  # O(1) membership instead of scanning a list
remaining_indices = [i for i in range(len(testset)) if i not in selected_set]

print(f"Selected {len(selected_indices)} images correctly classified by all models")
if len(selected_indices) < N_SELECTED:
    # Only reachable when fewer than N_SELECTED test images are correct for every model.
    print(f"Warning: only {len(all_correct_indices)} images were correct for all models "
          f"(requested {N_SELECTED})")
print(f"Remaining images: {len(remaining_indices)}")


100%|███████████████████████████████████████████████████████| 209/209 [03:57<00:00,  1.14s/it]

Selected 1000 images correctly classified by all models
Remaining images: 9000





In [12]:
from torch.utils.data import Subset, ConcatDataset
import torch

# `selected_indices`  = indices into `testset` of the curated, all-correct test images
# `remaining_indices` = every other index into `testset`

# New test split: the curated subset of the original test set.
new_testset = Subset(testset, selected_indices)

# New validation split = original validation images + the remaining test images.
# The two index lists live in DIFFERENT index spaces (`val_subset.indices` index the
# training data, `remaining_indices` index `testset`), so each must wrap its own dataset
# and the parts are concatenated. Indexing the training set with test indices would
# silently pull training images (overlapping `train_subset`) instead of test images.
# The validation part keeps whatever transform `val_subset` was built with.
new_valset = ConcatDataset([val_subset, Subset(testset, remaining_indices)])

# Sanity checks: curated test images must not leak into validation, and sizes must add up.
assert set(selected_indices).isdisjoint(remaining_indices)
assert len(new_valset) == len(val_subset) + len(remaining_indices)

# Optional: create DataLoaders
batch_size = 48
new_testloader = DataLoader(new_testset, batch_size=batch_size, shuffle=False, num_workers=4)
new_valloader = DataLoader(new_valset, batch_size=batch_size, shuffle=False, num_workers=4)

# Save subsets.
# NOTE(review): this pickles the full dataset objects (including the underlying image
# arrays); loading them with torch.load needs weights_only=False on recent torch versions.
# Saving only the index lists would be lighter and safer to load.
torch.save(new_testset, "data/cifar10_selected_test.pt")
torch.save(new_valset, "data/cifar10_extended_val.pt")

print(f"✅ Train: {len(train_subset)}, New Val: {len(new_valset)}, New Test: {len(new_testset)}")


✅ Train: 45000, New Val: 14000, New Test: 1000
