In [1]:
%reload_ext autoreload
%autoreload 2
import torch 
import sys; sys.path.append('../')
from src.data_utils import get_imagenet
from src.model_utils import load_resnet50_from_checkpoint
from tqdm import tqdm
import os
import numpy as np

CHECKPOINT_DIR = 'checkpoints/resnet50'
# Training steps with a saved checkpoint, one integer per line of confs/steps.txt.
# Blank lines (e.g. a trailing empty line) are skipped instead of crashing int('').
with open('confs/steps.txt') as f:
    STEPS = [int(line) for line in f if line.strip()]  # int() tolerates surrounding whitespace
assert STEPS, 'confs/steps.txt lists no checkpoint steps'

In [2]:
def get_model_at_step(step):
    """Load the checkpoint saved at training step ``step``.

    The file is ``CHECKPOINT_DIR/imagenet-step={step}.ckpt`` and is loaded
    with ``load_resnet50_from_checkpoint``; its return value is passed through.
    """
    checkpoint_file = f'imagenet-step={step}.ckpt'
    checkpoint_path = os.path.join(CHECKPOINT_DIR, checkpoint_file)
    return load_resnet50_from_checkpoint(checkpoint_path)

# Models from the first and last entries of STEPS (loaded in that order).
first_step, last_step = STEPS[0], STEPS[-1]
initial_model = get_model_at_step(first_step)
final_model = get_model_at_step(last_step)


In [3]:
# Checkpoint to evaluate. Loading through get_model_at_step keeps the path in
# sync with CHECKPOINT_DIR (resolves to checkpoints/resnet50/imagenet-step=450359.ckpt).
EVAL_STEP = 450359
model = get_model_at_step(EVAL_STEP)
# Only move to the GPU when one exists; an unconditional .cuda() fails on CPU-only machines.
if torch.cuda.is_available():
    model.cuda()

In [25]:
# Root of the ImageNet data. Set the IMAGENET_PATH environment variable to point
# elsewhere instead of editing the notebook; defaults to the original '/data/'.
IMAGENET_PATH = os.environ.get('IMAGENET_PATH', '/data/')
imagenet_data = get_imagenet(imagenet_path=IMAGENET_PATH, train=False, no_transform=False)

In [26]:
# Batched loader over imagenet_data, consumed by evaluate_model below.
loader = torch.utils.data.DataLoader(
    imagenet_data,
    batch_size=2048,
    shuffle=False,      # fixed sample order
    num_workers=4,
    pin_memory=True,
    prefetch_factor=2,
)

In [28]:
def evaluate_model(model, data_loader, cuda=torch.cuda.is_available(), progress=True):
    """
    Compute top-1 accuracy (argmax over dim 1 equals the target) of a model.

    Args:
        model: The model to evaluate. It is switched to eval mode in place (and
            moved to the GPU when ``cuda`` is true) and left in that state.
        data_loader: Iterable of ``(data, target)`` batches, e.g. a
            ``torch.utils.data.DataLoader``.
        cuda: Move the model to the GPU before evaluating. The default is
            whether CUDA was available when this function was defined.
        progress: Wrap ``data_loader`` in a tqdm progress bar.

    Returns:
        The fraction of samples whose prediction matches the target.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    if cuda:
        model.cuda()
    model.eval()
    # Send each batch to wherever the model's weights live, so cuda=False still
    # works on a model that an earlier cell already moved to the GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    batches = tqdm(data_loader) if progress else data_loader
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in batches:
            if device is not None:
                data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
    if total == 0:
        raise ZeroDivisionError('data_loader yielded no samples; accuracy is undefined')
    return correct / total

top1_accuracy = evaluate_model(model, loader)
top1_accuracy

100%|██████████| 25/25 [01:06<00:00,  2.67s/it]


0.75722