In [13]:
import os

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision.datasets
import torchvision

from mean_teacher import datasets, architectures

In [6]:
dataset_config = datasets.__dict__['sslMini']()

In [10]:
evaldir = "/scratch/ijh216/ssl_mini/supervised/val"

eval_loader = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(evaldir, dataset_config['eval_transformation']),
                                              batch_size=256,
                                              shuffle=False,
                                              num_workers=2 * 2,  # Needs images twice as fast
                                              pin_memory=True,
                                              drop_last=False)

In [11]:
import os

In [17]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [26]:
model_dir = "/scratch/ijh216/ssl/ssl_shake_mini_continue/2019-05-02_22-56-48/10/transient/checkpoint.75.ckpt" 
model = architectures.__dict__['cifar_shakeshake26']().to(device)
model

ResNet32x32(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1): Sequential(
    (0): ShakeShakeBlock(
      (conv_a1): Conv2d(16, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_a2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_a2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b1): Conv2d(16, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_b2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn_b2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1

In [28]:
def load_weights(model_arch, pretrained_model_path, cuda=True):
        # Load pretrained model
        pretrained_model = torch.load(f=pretrained_model_path, map_location="cuda" if cuda else "cpu")

        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in pretrained_model['state_dict'].items():
            name = k[7:] # remove `module.`
            new_state_dict[name] = v

        # Load pre-trained weights in current model
        with torch.no_grad():
            model_arch.load_state_dict(new_state_dict, strict=True)

        # Debug loading
        print('Parameters found in pretrained model:')
        pretrained_layers = new_state_dict.keys()
        for l in pretrained_layers:
            print('\t' + l)
        print('')

        for name, module in model_arch.state_dict().items():
            if name in pretrained_layers:
                assert torch.equal(new_state_dict[name].cpu(), module.cpu())
                print('{} have been loaded correctly in current model.'.format(name))
            else:
                raise ValueError("state_dict() keys do not match")


In [29]:
load_weights(model, model_dir)

Parameters found in pretrained model:
	conv1.weight
	layer1.0.conv_a1.weight
	layer1.0.bn_a1.weight
	layer1.0.bn_a1.bias
	layer1.0.bn_a1.running_mean
	layer1.0.bn_a1.running_var
	layer1.0.bn_a1.num_batches_tracked
	layer1.0.conv_a2.weight
	layer1.0.bn_a2.weight
	layer1.0.bn_a2.bias
	layer1.0.bn_a2.running_mean
	layer1.0.bn_a2.running_var
	layer1.0.bn_a2.num_batches_tracked
	layer1.0.conv_b1.weight
	layer1.0.bn_b1.weight
	layer1.0.bn_b1.bias
	layer1.0.bn_b1.running_mean
	layer1.0.bn_b1.running_var
	layer1.0.bn_b1.num_batches_tracked
	layer1.0.conv_b2.weight
	layer1.0.bn_b2.weight
	layer1.0.bn_b2.bias
	layer1.0.bn_b2.running_mean
	layer1.0.bn_b2.running_var
	layer1.0.bn_b2.num_batches_tracked
	layer1.0.downsample.0.weight
	layer1.0.downsample.1.weight
	layer1.0.downsample.1.bias
	layer1.0.downsample.1.running_mean
	layer1.0.downsample.1.running_var
	layer1.0.downsample.1.num_batches_tracked
	layer1.1.conv_a1.weight
	layer1.1.bn_a1.weight
	layer1.1.bn_a1.bias
	layer1.1.bn_a1.running_mean


In [30]:
for inp, target in eval_loader:
    inp, target = inp.to(device), target.to(device)
    break

In [34]:
with torch.no_grad():
    out = model(inp)

In [51]:
probs = out[0]
probs_sftmx = F.softmax(probs, dim=1)

In [41]:
_, pred = probs.topk(5, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))

In [49]:
pred

tensor([[687, 395, 114,  ..., 425, 261, 798],
        [168, 347, 204,  ..., 349, 781, 136],
        [ 63, 219, 203,  ..., 617, 338,   8],
        [175, 379, 986,  ..., 645, 391,  21],
        [ 66, 515, 959,  ..., 273, 780,   9]], device='cuda:0')