In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
import models
from tqdm import tqdm




In [3]:
class CSVImageDataset(Dataset):
    def __init__(self, csv_path, img_dir, transform=None):
        self.df = pd.read_csv(csv_path)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row["filename"])
        img = Image.open(img_path).convert("RGB")
        label = row.get("class_id", -1)  # test set may not have labels

        if self.transform:
            img = self.transform(img)

        return img, label, row["filename"]


In [4]:
checkpoint_path = "/scratch/sp7007/bigrun_resumed/checkpoint0500.pth"

train_csv = "/scratch/sp7007/imagenet/kaggle_data_miniimagenett/train_labels.csv"
val_csv   = "/scratch/sp7007/imagenet/kaggle_data_miniimagenett/val_labels.csv"
test_csv  = "/scratch/sp7007/imagenet/kaggle_data_miniimagenett/test_labels.csv"

train_dir = "/scratch/sp7007/imagenet/kaggle_data_miniimagenett/train"
val_dir   = "/scratch/sp7007/imagenet/kaggle_data_miniimagenett/val"
test_dir  = "/scratch/sp7007/imagenet/kaggle_data_miniimagenett/test"

num_labels = 100

epochs = 50
batch_size = 512
extract_batch_size = 512
lr = 0.01
weight_decay = 1e-5
optimizer_name = "adamw"
normalization = "l2"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cuda')

In [7]:
transform = transforms.Compose([
    transforms.Resize(96),
    transforms.CenterCrop(96),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])


In [8]:
model = models.__dict__["vit_small"](patch_size=8, return_all_tokens=False)

state = torch.load(checkpoint_path, map_location="cpu")
if "teacher" in state: state = state["teacher"]
if "student" in state: state = state["student"]

# Remove keys not part of backbone
state = {k.replace("module.", "").replace("backbone.", ""): v 
         for k, v in state.items() if not k.startswith("head")}

msg = model.load_state_dict(state, strict=False)
print(msg)

model = model.to(device).eval()


  state = torch.load(checkpoint_path, map_location="cpu")


<All keys matched successfully>


In [9]:
train_ds = CSVImageDataset(train_csv, train_dir, transform)
val_ds   = CSVImageDataset(val_csv, val_dir, transform)
test_ds  = CSVImageDataset(test_csv, test_dir, transform)

train_loader = DataLoader(train_ds, batch_size=extract_batch_size, shuffle=False, num_workers=8)
val_loader   = DataLoader(val_ds, batch_size=extract_batch_size, shuffle=False, num_workers=8)
test_loader  = DataLoader(test_ds, batch_size=extract_batch_size, shuffle=False, num_workers=8)


FileNotFoundError: [Errno 2] No such file or directory: '/scratch/sp7007/imagenet/kaggle_data_miniimagenett/train_labels.csv'