In [1]:
import argparse
import json
import os

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from model import Model
from torchvision import datasets, transforms

In [2]:
def load_data(data_dir, batch_size, split):
    """ Build a non-shuffling DataLoader over the labeled images in `<data_dir>/supervised/<split>`.

    Each image goes through `transforms.Resize(32)` (an int size scales the shorter
    edge to 32 px), is converted to a tensor, and is normalized channel-wise with the
    fixed RGB statistics below. Class labels come from the ImageFolder sub-directories.
    Batches are loaded in the main process (num_workers=0).
    """
    channel_means = (0.5011, 0.4727, 0.4229)  # RGB means
    channel_stds = (0.2835, 0.2767, 0.2950)   # RGB stds
    preprocess = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(channel_means, channel_stds),
    ])
    split_dir = f'{data_dir}/supervised/{split}'
    labeled_images = datasets.ImageFolder(split_dir, transform=preprocess)
    return DataLoader(labeled_images, batch_size=batch_size, shuffle=False, num_workers=0)

def evaluate(model, data_loader, device, split, top_k=5):
    """ Compute, print and return accuracy@1 and accuracy@top_k of `model` on `data_loader`.

    Args:
        model: network for which `model(img)[0]` yields the class scores; assumed to be
            shaped (batch, n_classes) — TODO confirm against Model's forward().
        data_loader: iterable of (img, target) batches (e.g. a DataLoader).
        device: device each batch is moved to; must match the model's device.
        split: name used in the log lines (e.g. 'Validation').
        top_k: k for the top-k accuracy (default 5). If the model has fewer than
            `top_k` classes, all classes are considered.

    Returns:
        (top_1_acc, top_k_acc) as floats; both are NaN when the loader is empty.
    """
    print(f'\nEvaluating {split} set...')
    model.eval()
    n_samples = 0
    n_correct_top_1 = 0
    n_correct_top_k = 0

    # Pure inference: disabling autograd avoids building a graph per batch,
    # which saves a large amount of memory and time.
    with torch.no_grad():
        for img, target in data_loader:
            img, target = img.to(device), target.to(device)
            n_samples += img.size(0)

            # Forward
            output = model(img)[0]

            # One topk call serves both metrics: indices come back sorted by score,
            # so column 0 is the top-1 prediction. k is clamped because torch.topk
            # raises when k exceeds the size of the class dimension.
            k = min(top_k, output.size(1))
            pred = torch.topk(output, k=k, dim=1)[1]
            hits = pred.eq(target.view(-1, 1))  # broadcast target across the k predictions

            n_correct_top_1 += hits[:, 0].sum().item()
            # Predicted indices are distinct, so a row has at most one hit.
            n_correct_top_k += hits.any(dim=1).sum().item()

    # Accuracy (undefined for an empty split; report NaN instead of dividing by zero)
    if n_samples == 0:
        top_1_acc = top_k_acc = float('nan')
    else:
        top_1_acc = n_correct_top_1 / n_samples
        top_k_acc = n_correct_top_k / n_samples

    # Log
    print(f'{split} top 1 accuracy: {top_1_acc:.4f}')
    print(f'{split} top {top_k} accuracy: {top_k_acc:.4f}')
    return top_1_acc, top_k_acc

In [3]:
# Seed the CPU RNG and every CUDA device's RNG, then prefer the GPU when one exists
SEED = 10
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [5]:
# Build the model and move it to `device`; judging by the cell output, construction
# loads pretrained weights and logs them — TODO confirm in model.Model
model = Model().to(device)
# Counts elements of model.parameters(); buffers (e.g. BatchNorm running stats) are not included
n_params = sum(p.numel() for p in model.parameters())
print('n parameters: %d' % n_params)

Parameters found in pretrained model:
	conv1.weight
	layer1.0.conv_a1.weight
	layer1.0.bn_a1.weight
	layer1.0.bn_a1.bias
	layer1.0.bn_a1.running_mean
	layer1.0.bn_a1.running_var
	layer1.0.bn_a1.num_batches_tracked
	layer1.0.conv_a2.weight
	layer1.0.bn_a2.weight
	layer1.0.bn_a2.bias
	layer1.0.bn_a2.running_mean
	layer1.0.bn_a2.running_var
	layer1.0.bn_a2.num_batches_tracked
	layer1.0.conv_b1.weight
	layer1.0.bn_b1.weight
	layer1.0.bn_b1.bias
	layer1.0.bn_b1.running_mean
	layer1.0.bn_b1.running_var
	layer1.0.bn_b1.num_batches_tracked
	layer1.0.conv_b2.weight
	layer1.0.bn_b2.weight
	layer1.0.bn_b2.bias
	layer1.0.bn_b2.running_mean
	layer1.0.bn_b2.running_var
	layer1.0.bn_b2.num_batches_tracked
	layer1.0.downsample.0.weight
	layer1.0.downsample.1.weight
	layer1.0.downsample.1.bias
	layer1.0.downsample.1.running_mean
	layer1.0.downsample.1.running_var
	layer1.0.downsample.1.num_batches_tracked
	layer1.1.conv_a1.weight
	layer1.1.bn_a1.weight
	layer1.1.bn_a1.bias
	layer1.1.bn_a1.running_mean


conv1.weight have been loaded correctly in current model.
layer1.0.conv_a1.weight have been loaded correctly in current model.
layer1.0.bn_a1.weight have been loaded correctly in current model.
layer1.0.bn_a1.bias have been loaded correctly in current model.
layer1.0.bn_a1.running_mean have been loaded correctly in current model.
layer1.0.bn_a1.running_var have been loaded correctly in current model.
layer1.0.bn_a1.num_batches_tracked have been loaded correctly in current model.
layer1.0.conv_a2.weight have been loaded correctly in current model.
layer1.0.bn_a2.weight have been loaded correctly in current model.
layer1.0.bn_a2.bias have been loaded correctly in current model.
layer1.0.bn_a2.running_mean have been loaded correctly in current model.
layer1.0.bn_a2.running_var have been loaded correctly in current model.
layer1.0.bn_a2.num_batches_tracked have been loaded correctly in current model.
layer1.0.conv_b1.weight have been loaded correctly in current model.
layer1.0.bn_b1.weight

n parameters: 26959616


In [6]:
# Validation loader. The dataset root (expects <root>/supervised/val) can be overridden
# through the SSL_DATA_DIR environment variable instead of editing a hard-coded user path.
DATA_DIR = os.environ.get('SSL_DATA_DIR', '/scratch/ehd255/ssl_data_96')
BATCH_SIZE = 32
data_loader_val = load_data(DATA_DIR, BATCH_SIZE, split='val')

In [8]:
# Report top-1 and top-5 accuracy on the validation split
evaluate(model, data_loader_val, device, split='Validation', top_k=5)


Evaluating Validation set...


KeyboardInterrupt: 