In [1]:
import os

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision.datasets
import torchvision

from mean_teacher import datasets, architectures

In [2]:
dataset_config = datasets.__dict__['sslMini']()

In [37]:
def load_weights(model_arch, pretrained_model_path, cuda=True):
        # Load pretrained model
        pretrained_model = torch.load(f=pretrained_model_path, map_location="cuda" if cuda else "cpu")

        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in pretrained_model['state_dict'].items():
            name = k[7:] # remove `module.`
            new_state_dict[name] = v

        # Load pre-trained weights in current model
        with torch.no_grad():
            model_arch.load_state_dict(new_state_dict, strict=True)

        # Debug loading
        #print('Parameters found in pretrained model:')
        pretrained_layers = new_state_dict.keys()
        #for l in pretrained_layers:
        #    print('\t' + l)
        #print('')

        for name, module in model_arch.state_dict().items():
            if name in pretrained_layers:
                assert torch.equal(new_state_dict[name].cpu(), module.cpu())
                #print('{} have been loaded correctly in current model.'.format(name))
            else:
                raise ValueError("state_dict() keys do not match")
                
        return model_arch


In [10]:
evaldir = "/scratch/ijh216/ssl_mini/supervised/val"

eval_loader = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(evaldir, dataset_config['eval_transformation']),
                                              batch_size=32,
                                              shuffle=False,
                                              num_workers=2,  # Needs images twice as fast
                                              #pin_memory=True,
                                              drop_last=False)

In [4]:
import os

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = 'cpu'

In [39]:
model_dir = "/scratch/ijh216/ssl/ssl_shake_mini/2019-05-01_19-04-25/10/transient/checkpoint.230.ckpt" 
model = architectures.__dict__['cifar_shakeshake26']().to(device)
model = load_weights(model, model_dir, cuda=False)

In [None]:
n_samples = 0.
n_correct_top_1 = 0
n_correct_top_k = 0

for img, target in eval_loader:
    img, target = img.to(device), target.to(device)
    batch_size = img.size(0)
    n_samples += batch_size

        # Forward
    output = model(img)[0]

        # Top 1 accuracy
    pred_top_1 = torch.topk(output, k=1, dim=1)[1]
    n_correct_top_1 += pred_top_1.eq(target.view_as(pred_top_1)).int().sum().item()

        # Top k accuracy
    pred_top_k = torch.topk(output, k=5, dim=1)[1]
    target_top_k = target.view(-1, 1).expand(32, 5)
    n_correct_top_k += pred_top_k.eq(target_top_k).int().sum().item()

    # Accuracy
top_1_acc = n_correct_top_1/n_samples
top_k_acc = n_correct_top_k/n_samples

In [41]:
with torch.no_grad():
    out = model(inp)[0]

In [50]:
pred_top_1 = torch.topk(out, k=1, dim=1)[1]
n_correct_top_1 = pred_top_1.eq(target.view_as(pred_top_1)).int().sum().item()

        # Top k accuracy
pred_top_k = torch.topk(out, k=5, dim=1)[1]
target_top_k = target.view(-1, 1).expand(32, 5)
n_correct_top_k = pred_top_k.eq(target_top_k).int().sum().item()

In [53]:
torch.topk(out, k=5, dim=1)[1]

tensor([[173, 175, 989, 166, 995],
        [299, 362, 472, 685, 293],
        [573, 606, 721, 835, 145],
        [842, 257, 262, 403,  67],
        [325, 745,  54,  67, 266],
        [278, 371, 993, 460, 431],
        [117, 993, 730, 703, 561],
        [108, 344, 322,  44, 201],
        [144, 216, 823, 541, 725],
        [114, 241, 873, 498, 178],
        [320, 542, 510, 630, 545],
        [505, 366, 489, 698, 339],
        [528, 297, 728, 413, 540],
        [588, 806, 801, 585, 500],
        [772, 656, 594, 785, 214],
        [186, 297, 410, 846, 694],
        [335, 564, 363, 286, 419],
        [568, 342, 471, 238,  82],
        [ 84, 386, 336, 846, 541],
        [525, 432, 641, 565, 233],
        [242, 188, 193, 389, 113],
        [114, 312, 730,  89,  77],
        [693, 568, 792, 160, 754],
        [138, 289, 970, 594, 842],
        [145, 729, 512, 792, 715],
        [ 84,   2, 328, 281, 832],
        [459, 946, 521,  29, 819],
        [423, 496, 387, 683, 570],
        [150,  47,  

In [42]:
accuracy(out, target, topk=(1,5))

[tensor([0.]), tensor([0.])]

In [45]:
out.topk(5, 1, True, True)[1]

tensor([[173, 175, 989, 166, 995],
        [299, 362, 472, 685, 293],
        [573, 606, 721, 835, 145],
        [842, 257, 262, 403,  67],
        [325, 745,  54,  67, 266],
        [278, 371, 993, 460, 431],
        [117, 993, 730, 703, 561],
        [108, 344, 322,  44, 201],
        [144, 216, 823, 541, 725],
        [114, 241, 873, 498, 178],
        [320, 542, 510, 630, 545],
        [505, 366, 489, 698, 339],
        [528, 297, 728, 413, 540],
        [588, 806, 801, 585, 500],
        [772, 656, 594, 785, 214],
        [186, 297, 410, 846, 694],
        [335, 564, 363, 286, 419],
        [568, 342, 471, 238,  82],
        [ 84, 386, 336, 846, 541],
        [525, 432, 641, 565, 233],
        [242, 188, 193, 389, 113],
        [114, 312, 730,  89,  77],
        [693, 568, 792, 160, 754],
        [138, 289, 970, 594, 842],
        [145, 729, 512, 792, 715],
        [ 84,   2, 328, 281, 832],
        [459, 946, 521,  29, 819],
        [423, 496, 387, 683, 570],
        [150,  47,  

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])

In [31]:
?out.topk

tensor(32)