In [None]:
from model import Model
import argparse
import json
import torch
import time

from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

In [None]:
def load_data(data_dir, batch_size, split, image_size=32, shuffle=True, num_workers=8):
    """Build a DataLoader over the labeled images in ``{data_dir}/supervised/{split}``.

    Images are read with ``torchvision.datasets.ImageFolder`` (one sub-directory
    per class), resized, converted to tensors and normalized per channel.

    Args:
        data_dir: Dataset root; images are read from ``<data_dir>/supervised/<split>``.
        batch_size: Number of images per batch.
        split: Sub-directory under ``supervised`` to load (e.g. ``'val'``).
        image_size: Size passed to ``transforms.Resize``. With an int, the
            shorter image edge is resized to this value and the aspect ratio is
            kept, so non-square inputs stay non-square (and cannot be batched
            together unless they share a shape); pass an ``(h, w)`` tuple to
            force an exact size. Defaults to 32, the previously hard-coded value.
        shuffle: Reshuffle samples every epoch. Defaults to True for backward
            compatibility; evaluation-only callers can pass False because
            accuracy totals do not depend on sample order.
        num_workers: Number of DataLoader worker processes (default 8).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, targets)`` batches.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # RGB means, RGB stds. Presumably computed on this dataset — TODO confirm provenance.
        transforms.Normalize((0.5011, 0.4727, 0.4229), (0.2835, 0.2767, 0.2950)),
    ])
    dataset = datasets.ImageFolder(f'{data_dir}/supervised/{split}', transform=transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

def evaluate(model, data_loader, device, split, top_k=5):
    """Compute, print and return top-1 and top-``top_k`` accuracy of ``model``.

    The model is put in eval mode and run under ``torch.no_grad()``, so callers
    do not need to disable autograd themselves (doing so as well is harmless).

    Args:
        model: Module called as ``model(img)``; its output is indexed with ``[0]``
            to obtain the class scores. NOTE(review): assumes that first element
            is a ``(batch, num_classes)`` score tensor — confirm against ``Model``.
        data_loader: Iterable of ``(img, target)`` batches where ``target``
            holds integer class indices.
        device: Device the batches are moved to before the forward pass.
        split: Label used only in the printed messages.
        top_k: k for the top-k accuracy (default 5). Must not exceed the number
            of classes, otherwise ``torch.topk`` raises.

    Returns:
        Tuple ``(top_1_acc, top_k_acc)`` as fractions in [0, 1]. Both are
        ``nan`` when ``data_loader`` yields no samples.
    """
    print(f'\nEvaluating {split} set...')
    model.eval()
    n_samples = 0
    n_correct_top_1 = 0
    n_correct_top_k = 0

    with torch.no_grad():
        for i, (img, target) in enumerate(data_loader):
            img, target = img.to(device), target.to(device)
            n_samples += img.size(0)

            # Forward
            output = model(img)[0]

            # topk returns indices sorted by descending score, so column 0 is
            # the top-1 prediction. Broadcasting target to (batch, 1) compares it
            # against every candidate; indices in a row are distinct, so each
            # row has at most one hit.
            pred_top_k = torch.topk(output, k=top_k, dim=1)[1]
            hits = pred_top_k.eq(target.view(-1, 1))
            n_correct_top_1 += hits[:, 0].sum().item()
            n_correct_top_k += hits.sum().item()

            if i % 100 == 0:
                print(f"Iteration {i}: {n_samples}")
                print(f"Top 1 {n_correct_top_1 / n_samples}")
                print(f"Top k {n_correct_top_k / n_samples}")
                print("*****************************************")

    # Guard against an empty loader instead of raising ZeroDivisionError.
    top_1_acc = n_correct_top_1 / n_samples if n_samples else float('nan')
    top_k_acc = n_correct_top_k / n_samples if n_samples else float('nan')

    # Log
    print(f'{split} top 1 accuracy: {top_1_acc:.4f}')
    print(f'{split} top {top_k} accuracy: {top_k_acc:.4f}')
    return top_1_acc, top_k_acc

In [None]:
# Set random seed and device.
# Seeding torch's CPU and all-CUDA generators makes torch-driven randomness
# (e.g. the DataLoader shuffling in load_data) reproducible across runs.
SEED = 1
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load the model.
# Reload the local `model` module so edits to model.py are picked up without a
# kernel restart. The module is bound as `model_module` so that `model` below
# refers only to the network instance instead of shadowing the module.
# NOTE(review): no checkpoint is loaded in this cell; presumably Model.__init__
# loads the pre-trained weights — confirm in model.py.
import importlib
import model as model_module
importlib.reload(model_module)
from model import Model
model = Model().to(device)
n_params = sum(p.numel() for p in model.parameters())
print(f'n parameters: {n_params}')

In [None]:
# Validation data. The dataset root defaults to the original cluster path, and
# the SSL_DATA_DIR environment variable can override it so the notebook runs
# elsewhere without editing a user-specific absolute path.
import os

DATA_DIR = os.environ.get('SSL_DATA_DIR', '/scratch/ehd255/ssl_data_96')
VAL_BATCH_SIZE = 32
data_loader_val = load_data(DATA_DIR, VAL_BATCH_SIZE, split='val')

In [None]:
# Time the validation pass. Wall-clock timestamps are printed for the log. The
# elapsed time uses the monotonic perf_counter, so system clock adjustments
# during a long run cannot skew it. time.ctime(t) is equivalent to
# time.asctime(time.localtime(t)).
with torch.no_grad():
    t0 = time.time()
    start = time.perf_counter()
    print('Start time: {}'.format(time.ctime(t0)))
    evaluate(model, data_loader_val, device, 'Validation')
    elapsed = time.perf_counter() - start
    t1 = time.time()
    print('Validation time: {:.3f} s finished {}'.format(elapsed, time.ctime(t1)))

In [None]:
import matplotlib.pyplot as plt

# Accuracy (%) per checkpoint epoch. These values are hard-coded, presumably
# copied from runs of the evaluation cell above on each checkpoint.
# TODO: record which split/run they came from.
ckpts = [0, 15, 30, 45, 60, 75, 90, 105, 120]
top_1 = [0.09, 19.7, 27.9, 31.2, 32.7, 33.6, 33.7, 33.6, 33.9]
top_5 = [0.5, 40.9, 51.7, 55.8, 57.3, 58.1, 58.1, 57.7, 57.9]
assert len(ckpts) == len(top_1) == len(top_5), "each checkpoint needs one top-1 and one top-5 value"

fig, ax = plt.subplots()
ax.plot(ckpts, top_1, label="top-1", color="green", marker="o")
ax.plot(ckpts, top_5, label="top-5", color="blue", marker="o")
ax.set(title="Top-1 / top-5 accuracy by checkpoint", xlabel="epochs", ylabel="accuracy (%)")
ax.legend()
plt.show()