In [None]:
import os
import glob

import numpy as np
import torch
import torch.nn.functional as F
from sklearn import metrics
from torch.cuda.amp import autocast
from tqdm import tqdm

from distillistic import ImageNet_loader, set_seed, resnet18

In [None]:
# --- Evaluation configuration ---
data_path = "../data/imagenet"
batch_size = 32
workers = 12

# Use the first GPU when available; the same flag drives AMP (loader `use_amp`, autocast later).
cuda = torch.cuda.is_available()
if cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# set_seed presumably seeds the RNGs; its return value is passed to the loader as `generator`.
g = set_seed(42)

# Only the validation split is evaluated here, so no training loader is built.
test_loader = ImageNet_loader(
    data_path,
    batch_size,
    device,
    train=False,
    generator=g,
    workers=workers,
    use_amp=cuda,
    use_ffcv=False,
)

In [None]:
# Select the experiment run whose saved student checkpoint will be evaluated.
algo = "baseline"
load_dir = "./experiments/imagenet/session2/" + algo + "000/"

# Architecture only (pretrained=False); the trained weights come from the checkpoint in load_dir.
model = resnet18(1000, pretrained=False)

In [None]:
# Resolve the checkpoint path: "dml" runs are searched for student*.pt (presumably one file
# per student), every other algorithm saves a single student.pt.
if algo == "dml":
    # Sort so the choice is deterministic when several students were saved.
    candidates = sorted(glob.glob(os.path.join(load_dir, "student*.pt")))
    if not candidates:
        raise FileNotFoundError(f"No student*.pt checkpoint found in {load_dir!r}")
    model_pt = candidates[0]
else:
    model_pt = os.path.join(load_dir, "student.pt")

# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
state_dict = torch.load(model_pt, map_location=device)
model.load_state_dict(state_dict)

In [None]:
# Run the model over the validation set, collecting logits and labels on the CPU.
# MAX_BATCHES limits the pass for quick debugging; None evaluates the whole set.
# A truncated pass only sees a subset of classes (a biased one if the loader is not
# shuffled), which makes the per-class ROC AUC below undefined.
MAX_BATCHES = None

# Inputs are moved to `device` below, so make sure the model lives there too.
model.to(device)
model.eval()
outputs = []
targets = []

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tqdm(test_loader)):
        if MAX_BATCHES is not None and batch_idx >= MAX_BATCHES:
            break
        data = data.to(device)

        with autocast(enabled=cuda):
            output = model(data)

        # Presumably some models return (logits, ...); keep only the first element.
        if isinstance(output, tuple):
            output = output[0]

        # Autocast may yield fp16 GPU tensors; store float32 CPU copies so the later
        # NumPy / sklearn steps work and GPU memory does not grow across batches.
        outputs.append(output.float().cpu())
        targets.append(target.cpu())

In [None]:
# Stack the per-batch results into single tensors on the CPU (logits as float32),
# the form the softmax / sklearn cells below expect.
if not outputs:
    raise RuntimeError("No batches were collected from test_loader; nothing to evaluate.")
out_tensor = torch.cat(outputs, dim=0).float().cpu()
target_tensor = torch.cat(targets, dim=0).cpu()

In [None]:
# Turn labels into a one-hot indicator matrix and logits into class probabilities.
# The class count is taken from the logits so both matrices always have matching columns;
# softmax runs in float32 on the CPU because autocast may have produced fp16 GPU logits.
num_classes = out_tensor.shape[-1]
one_hot = F.one_hot(target_tensor.cpu(), num_classes=num_classes)
out_prob = F.softmax(out_tensor.float().cpu(), dim=-1)
one_hot.shape, out_prob.shape

In [None]:
# Macro-averaged one-vs-rest ROC AUC over all classes.
# y_true is a one-hot indicator matrix, which sklearn treats as multilabel: it computes a
# binary AUC per column and averages them, so `multi_class` (only used for 1-D multiclass
# labels) has no effect and is omitted. sklearn rejects columns with no positive example,
# so every class must appear in the evaluated labels (i.e. run the full validation set).
auc = metrics.roc_auc_score(one_hot.cpu().numpy(), out_prob.cpu().numpy(), average="macro")
auc

In [None]:
# Sanity check for the AUC above: count the distinct labels that were evaluated.
# For a full pass this should match the number of logit columns; a smaller count means
# some classes were never seen and their per-class AUC is undefined.
classes_seen = torch.unique(target_tensor.cpu())
classes_seen.numel(), classes_seen[:10]