In [82]:
import os

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision.datasets
import torchvision.transforms as transforms
import torchvision

from mean_teacher import datasets, architectures

from IPython.display import clear_output

device = "cuda" if torch.cuda.is_available() else "cpu"

import matplotlib.pyplot as plt

to_image = transforms.ToPILImage()

In [81]:
!squeue -u ijh216

             JOBID PARTITION     NAME     USER ST       TIME  NODES NODELIST(REASON)
           1112955 p40_4,p10    ssm_c   ijh216 PD       0:00      1 (Priority)
           1112956 p40_4,p10     ss_c   ijh216 PD       0:00      1 (Priority)
           1112947    c32_38 jupyterC   ijh216  R      46:14      1 c36-03


In [61]:
!scancel 1079531  

In [83]:
dataset_config = datasets.__dict__['sslMini']()

In [84]:
def load_weights(model_arch, pretrained_model_path, state_dict, cuda=True):
        # Load pretrained model
        pretrained_model = torch.load(f=pretrained_model_path, map_location="cuda" if cuda else "cpu")

        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in pretrained_model[state_dict].items():
            name = k[7:] # remove `module.`
            new_state_dict[name] = v

        # Load pre-trained weights in current model
        with torch.no_grad():
            model_arch.load_state_dict(new_state_dict, strict=True)

        # Debug loading
        #print('Parameters found in pretrained model:')
        pretrained_layers = new_state_dict.keys()
        #for l in pretrained_layers:
        #    print('\t' + l)
        #print('')

        for name, module in model_arch.state_dict().items():
            if name in pretrained_layers:
                assert torch.equal(new_state_dict[name].cpu(), module.cpu())
                #print('{} have been loaded correctly in current model.'.format(name))
            else:
                raise ValueError("state_dict() keys do not match")
                
        return model_arch


In [85]:
BATCH_SIZE = 64

evaldir = "/scratch/ehd255/ssl_data_96/supervised/val"

eval_loader = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(evaldir, dataset_config['eval_transformation']),
                                              batch_size=BATCH_SIZE,
                                              shuffle=True,
                                              num_workers=2,
                                              #pin_memory=True,
                                              drop_last=False)


In [88]:
pretrained_model_path = "/scratch/ijh216/ssl/ssl_shake_mini_augment/2019-05-06_18-04-18/10/transient/checkpoint.325.ckpt"

In [92]:
aa = torch.load(f=pretrained_model_path, map_location="cuda" if False else "cpu")

In [None]:
class Shake(Function):
    @classmethod
    def forward(cls, ctx, inp1, inp2, training):
        assert inp1.size() == inp2.size()
        gate_size = [inp1.size()[0], *itertools.repeat(1, inp1.dim() - 1)]
        gate = inp1.new(*gate_size)
        if training:
            gate.uniform_(0, 1)
        else:
            gate.fill_(0.5)
        return inp1 * gate + inp2 * (1. - gate)

    @classmethod
    def backward(cls, ctx, grad_output):
        grad_inp1 = grad_inp2 = grad_training = None
        gate_size = [grad_output.size()[0], *itertools.repeat(1,
                                                              grad_output.dim() - 1)]
        gate = grad_output.detach().new(*gate_size).uniform_(0, 1)
        if ctx.needs_input_grad[0]:
            grad_inp1 = grad_output * gate
        if ctx.needs_input_grad[1]:
            grad_inp2 = grad_output * (1 - gate)
        assert not ctx.needs_input_grad[2]
        return grad_inp1, grad_inp2, grad_training


In [95]:

class Shake(Function):
    @classmethod
    def forward(cls, ctx, inp1, inp2, training):
        assert inp1.size() == inp2.size()
        gate_size = [inp1.size()[0], *itertools.repeat(1, inp1.dim() - 1)]
        gate = inp1.new(*gate_size)
        if training:
            gate.uniform_(0, 1)
        else:
            gate.fill_(0.5)
        return inp1 * gate + inp2 * (1. - gate)

    @classmethod
    def backward(cls, ctx, grad_output):
        grad_inp1 = grad_inp2 = grad_training = None
        gate_size = [grad_output.size()[0], *itertools.repeat(1,
                                                              grad_output.dim() - 1)]
        gate = Variable(grad_output.data.new(*gate_size).uniform_(0, 1))
        if ctx.needs_input_grad[0]:
            grad_inp1 = grad_output * gate
        if ctx.needs_input_grad[1]:
            grad_inp2 = grad_output * (1 - gate)
        assert not ctx.needs_input_grad[2]
        return grad_inp1, grad_inp2, grad_training


def shake(inp1, inp2, training=False):
    return Shake.apply(inp1, inp2, training)

dict_keys(['epoch', 'global_step', 'arch', 'state_dict', 'ema_state_dict', 'best_prec1', 'optimizer'])

In [65]:
model_dir = "/scratch/ijh216/ssl/ssl_shake_mini_augment/2019-05-06_18-04-18/10/transient/checkpoint.325.ckpt" 
model = architectures.__dict__['cifar_shakeshake26']().to(device)
model = load_weights(model, model_dir, state_dict="ema_state_dict", cuda=True)

n_samples = 0.
n_correct_top_1 = 0
n_correct_top_k = 0

for i, (img, target) in enumerate(eval_loader):
    img, target = img.to(device), target.to(device)
    n_samples += BATCH_SIZE

        # Forward
    output = model(img)[0]

        # Top 1 accuracy
    pred_top_1 = torch.topk(output, k=1, dim=1)[1]
    n_correct_top_1 += pred_top_1.eq(target.view_as(pred_top_1)).int().sum().item()

        # Top k accuracy
    pred_top_k = torch.topk(output, k=5, dim=1)[1]
    target_top_k = target.view(-1, 1).expand(BATCH_SIZE, 5)
    n_correct_top_k += pred_top_k.eq(target_top_k).int().sum().item()
    
    if i % 100 == 0:
        print("******************************")
        print("Acc@1", n_correct_top_1/n_samples)
        print("Acc@5", n_correct_top_k/n_samples)
        print("******************************")
    
    # Accuracy
top_1_acc = n_correct_top_1/n_samples
top_k_acc = n_correct_top_k/n_samples

print("******************************")
print("Acc@1", top_1_acc)
print("Acc@5", top_k_acc)
print("******************************")

******************************
Acc@1 0.40625
Acc@5 0.640625
******************************
******************************
Acc@1 0.29718440594059403
Acc@5 0.5383663366336634
******************************
******************************
Acc@1 0.3023165422885572
Acc@5 0.5359141791044776
******************************
******************************
Acc@1 0.30632267441860467
Acc@5 0.5371677740863787
******************************
******************************
Acc@1 0.303927680798005
Acc@5 0.5353413341645885
******************************
******************************
Acc@1 0.3029565868263473
Acc@5 0.5362400199600799
******************************
******************************
Acc@1 0.30355657237936773
Acc@5 0.5358257071547421
******************************
******************************
Acc@1 0.3045203281027104
Acc@5 0.5385164051355207
******************************
******************************
Acc@1 0.304229088639201
Acc@5 0.5373946629213483
******************************
************