In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision

import numpy as np
import matplotlib.pyplot as plt

# MNIST data preparation

In [2]:
# Number of leading samples kept from each split.
N_TRAIN = 10000
N_TEST = 2000

# NOTE(review): `train_data`/`train_labels`/`test_data`/`test_labels` and
# `Variable(..., volatile=True)` are legacy spellings kept so this cell runs on
# the same library versions as the rest of the notebook. On PyTorch >= 0.4
# `volatile` has no effect beyond a warning.

# train data
train_data = torchvision.datasets.MNIST(root='./mnist', train=True)
# Slice the raw images *before* the float conversion so that only the samples
# actually used are cast, instead of casting the whole split and discarding
# most of it afterwards. The resulting values are identical.
train_x = Variable(torch.unsqueeze(train_data.train_data[:N_TRAIN], dim=1), volatile=True).type(torch.FloatTensor)
train_x = train_x/255.  # scale pixel values by 1/255

train_y = train_data.train_labels[:N_TRAIN]

# test data
test_data = torchvision.datasets.MNIST(root='./mnist', train=False)
test_x = Variable(torch.unsqueeze(test_data.test_data[:N_TEST], dim=1), volatile=True).type(torch.FloatTensor)
test_x = test_x/255.  # scale pixel values by 1/255

test_y = test_data.test_labels[:N_TEST]

# declare model

In [3]:
class CNN(nn.Module):
    """Two conv -> ReLU -> max-pool stages followed by one linear layer.

    The shape comments assume a (batch, 1, 28, 28) input; the final linear
    layer is sized for exactly that case (32 * 7 * 7 input features) and
    produces 10 output scores per sample.
    """

    def __init__(self):
        super(CNN, self).__init__()

        # Stage 1: (1, 28, 28) -> conv -> (16, 28, 28) -> pool -> (16, 14, 14).
        # With stride 1, padding = (kernel_size - 1) / 2 keeps height/width unchanged.
        first_conv = nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2)
        self.conv1 = nn.Sequential(
            first_conv,
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Stage 2: (16, 14, 14) -> conv -> (32, 14, 14) -> pool -> (32, 7, 7).
        second_conv = nn.Conv2d(
            in_channels=16,
            out_channels=32,
            kernel_size=5,
            stride=1,
            padding=2,
        )
        self.conv2 = nn.Sequential(
            second_conv,
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Classifier over the flattened feature map.
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        """Run both conv stages, flatten per sample, and apply the linear layer."""
        features = self.conv2(self.conv1(x))  # (batch, 32, 7, 7)
        flat = features.view(features.size(0), -1)  # (batch, 32*7*7)
        return self.out(flat)

In [15]:
# load weight
def _restore_cnn(weight_path):
    """Build a fresh CNN and fill it with the state dict stored at `weight_path`."""
    model = CNN()
    model.load_state_dict(torch.load(weight_path))
    return model


# One network per saved weight file.
CNN_b50 = _restore_cnn('./output/batch_50_weight')
CNN_b1000 = _restore_cnn('./output/batch_1000_weight')

# evaluation

In [140]:
def interpolation(cnn, b_50, b_1000, alpha): # (model, model or state_dict, model or state_dict, ratio)
    """Overwrite `cnn`'s parameters with a linear blend of two weight sets.

    Every parameter of `cnn` is set to
    ``(1 - alpha) * b_50[name] + alpha * b_1000[name]``.

    Args:
        cnn: module whose parameters are overwritten in place.
        b_50: weights used at ``alpha == 0``; either an ``nn.Module`` or a
            state_dict-style mapping from parameter name to tensor.
        b_1000: weights used at ``alpha == 1``; same accepted forms as `b_50`.
        alpha: blend ratio. Values outside [0, 1] extrapolate along the same line.

    Returns:
        The same `cnn` object, for convenience.

    Raises:
        KeyError: if either weight set lacks one of `cnn`'s parameter names.
    """
    # Use the weight sets that were actually passed in rather than reaching for
    # notebook globals, so the function does what its signature says.
    state_50 = b_50.state_dict() if isinstance(b_50, nn.Module) else b_50
    state_1000 = b_1000.state_dict() if isinstance(b_1000, nn.Module) else b_1000

    # Match tensors by parameter name instead of by position.
    for name, param in cnn.named_parameters():
        param.data = (1 - alpha) * state_50[name] + alpha * state_1000[name]

    return cnn
        
        
def _loss_and_accuracy(cnn, loss_func, inputs, labels):
    """Return ``(accuracy, loss)`` of `cnn` on one full batch as Python floats."""
    scores = cnn(inputs)
    # `.item()` extracts the scalar loss; indexing a 0-dim tensor with `[0]`
    # raises on current PyTorch releases.
    loss = loss_func(scores, labels).item()
    predicted = torch.max(scores, 1)[1]
    # Count matches with the tensor's own reduction rather than Python's builtin
    # `sum`, which would loop over every element one at a time.
    correct = (predicted == labels).sum().item()
    return float(correct) / labels.size(0), loss


def evaluate(cnn):
    """Measure `cnn` on the notebook's train and test subsets.

    Reads the notebook-level tensors `train_x`, `train_y`, `test_x` and
    `test_y`, and feeds each subset through the model as a single batch.
    Requires PyTorch >= 0.4 (`torch.no_grad`, `Tensor.item`).

    Args:
        cnn: model mapping the image tensors to per-class scores.

    Returns:
        Tuple ``(train_acc, train_loss, test_acc, test_loss)`` of Python floats;
        the losses are `nn.CrossEntropyLoss` values.
    """
    loss_func = nn.CrossEntropyLoss()

    # Pure evaluation: no gradients are needed, so skip autograd bookkeeping
    # for these large single-batch forward passes.
    with torch.no_grad():
        ### training
        train_acc, train_loss = _loss_and_accuracy(cnn, loss_func, train_x, train_y)
        ### testing
        test_acc, test_loss = _loss_and_accuracy(cnn, loss_func, test_x, test_y)

    return train_acc, train_loss, test_acc, test_loss

In [141]:
# STORAGE FOR RENDER STAGE

# One independent, initially empty list per recorded series.
result = {
    series: []
    for series in ('alpha', 'train_acc', 'test_acc', 'train_loss', 'test_loss')
}

In [143]:
result

{'alpha': [-1.0,
  -0.98999999999999999,
  -0.97999999999999998,
  -0.96999999999999997,
  -0.95999999999999996,
  -0.94999999999999996,
  -0.93999999999999995,
  -0.93000000000000005,
  -0.92000000000000004,
  -0.91000000000000003,
  -0.90000000000000002,
  -0.89000000000000001,
  -0.88,
  -0.87,
  -0.85999999999999999,
  -0.84999999999999998,
  -0.83999999999999997,
  -0.82999999999999996,
  -0.81999999999999995,
  -0.81000000000000005,
  -0.80000000000000004,
  -0.79000000000000004,
  -0.78000000000000003,
  -0.77000000000000002,
  -0.76000000000000001,
  -0.75,
  -0.73999999999999999,
  -0.72999999999999998,
  -0.71999999999999997,
  -0.70999999999999996,
  -0.69999999999999996,
  -0.68999999999999995,
  -0.68000000000000005,
  -0.67000000000000004,
  -0.66000000000000003,
  -0.65000000000000002,
  -0.64000000000000001,
  -0.63,
  -0.62,
  -0.60999999999999999,
  -0.59999999999999998,
  -0.58999999999999997,
  -0.57999999999999996,
  -0.56999999999999995,
  -0.56000000000000005,
  

In [145]:
import pickle

# Serialize the `result` dict to disk in binary pickle format.
result_path = "./output/result.pkl"
with open(result_path, "wb") as pkl_out:
    pickle.dump(result, pkl_out)