In [39]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision

import numpy as np
import matplotlib.pyplot as plt

# MNIST data preparation

In [36]:
# train data
train_data = torchvision.datasets.MNIST(root='./mnist', train=True)
train_x = Variable(torch.unsqueeze(train_data.train_data, dim=1), volatile=True).type(torch.FloatTensor)
train_x = train_x[:10000]/255.

train_y = train_data.train_labels[:10000]

# test data
test_data = torchvision.datasets.MNIST(root='./mnist', train=False)
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)
test_x = test_x[:2000]/255.

test_y = test_data.test_labels[:2000]

# declare model

In [3]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(# input size(1,28,28)
            nn.Conv2d( 
                in_channels = 1, # gray(if RGB, then in_channel=3)
                out_channels = 16, # num_filter
                kernel_size=5, # filter's height and width = 5
                stride=1, # shift 1 distance each time
                padding=2, # if stride=1, padding=(kernel_size-1)/2, filled with zero
            ),# -->(16, 28, 28)
            nn.ReLU(),# -->(16, 28, 28)
            nn.MaxPool2d(kernel_size=2,),# -->(16, 14, 14)
        )
        self.conv2 = nn.Sequential(# input size(16,14,14)
            nn.Conv2d(16, 32, 5, 1, 2), #(in, out, kernel, strid, padding) # -->(32, 14, 14)
            nn.ReLU(),# -->(32, 14, 14)
            nn.MaxPool2d(2), # -->(32, 7, 7)
        )
        self.out = nn.Linear(32*7*7, 10)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x) # (batch, 32, 7, 7)
        x = x.view(x.size(0), -1) #(batch, 32*7*7)
        output = self.out(x)
        return output

In [23]:
# load weight
CNN_b50=CNN()
CNN_b50.load_state_dict(torch.load('./output/batch_50_weight'))

CNN_b1000=CNN()
CNN_b1000.load_state_dict(torch.load('./output/batch_1000_weight'))

In [40]:
# six tensor should be update

# cnn.state_dict()['conv1.0.weight']
# cnn.state_dict()['conv1.0.bias']
# cnn.state_dict()['conv2.0.weight']
# cnn.state_dict()['conv2.0.bias']
# cnn.state_dict()['out.weight']
# cnn.state_dict()['out.bias']

# evaluation

In [44]:
def interpolation(cnn, b_50, b_1000, alpha): # (model, state_dict, state_dict, ratio)
    keys = ['conv1.0.weight', 'conv1.0.bias', 'conv2.0.weight', 'conv2.0.bias', 'out.weight', 'out.bias']
    for key in keys:
        cnn.state_dict()[key] = (1-alpha) * b_1000[key] + alpha * b_1000[key]
    
    return cnn
        
        
def evaluate(cnn):
    loss_func = nn.CrossEntropyLoss()

    ### training
    train_pred = cnn.forward(train_x)
    # train_loss
    train_loss = loss_func(train_pred, Variable(train_y)).data[0]
    # test_acc
    train_pred = torch.max(train_pred, 1)[1].data.squeeze()
    train_acc = sum(train_pred ==train_y)/ train_y.size(0)

    ### testing
    test_pred = cnn.forward(test_x)
    # test_loss
    test_loss = loss_func(test_pred, Variable(test_y)).data[0]
    # test_acc
    test_pred = torch.max(test_pred, 1)[1].data.squeeze()
    test_acc = sum(test_pred ==test_y)/ test_y.size(0)
    
    return train_acc, train_loss, test_acc, test_loss

In [46]:
result = {}

result['alpha'] = []
result['train_acc'] = []
result['test_acc'] = []
result['train_loss'] = []
result['test_loss'] = []

In [45]:
alphaList = np.arange(-1, 2, 0.01)
for alpha in alphaList:
    customCNN = interpolation(CNN(), CNN_b50.state_dict(), CNN_b1000.state_dict(), alpha)
    train_acc, train_loss, test_acc, test_loss = evaluate(customCNN)
    
    print('train_loss %.4f' % train_loss,
        ' | train_acc %.4f' % train_acc, 
        ' |test_loss %.4f' % test_loss,
        ' |test_acc %.4f' % test_acc)
    
    result['alpha'].append(round(alpha, 4))
    result['train_acc'].append(round(train_acc, 4))
    result['test_acc'].append(round(test_acc, 4))
    result['train_loss'].append(round(train_loss, 4))
    result['test_loss'].append(round(test_loss, 4))

train_loss 2.3071
train_acc 0.0789
test_loss 2.3077
test_acc 0.1015
train_loss 2.2965
train_acc 0.1329
test_loss 2.2986
test_acc 0.1155
train_loss 2.2971
train_acc 0.1015
test_loss 2.3006
test_acc 0.0890
train_loss 2.3119
train_acc 0.0750
test_loss 2.3112
test_acc 0.0750
train_loss 2.3053
train_acc 0.1141
test_loss 2.3039
test_acc 0.1200
train_loss 2.3050
train_acc 0.1452
test_loss 2.3085
test_acc 0.1280
train_loss 2.2979
train_acc 0.1128
test_loss 2.2968
test_acc 0.1380
train_loss 2.3084
train_acc 0.0877
test_loss 2.3069
test_acc 0.0890


KeyboardInterrupt: 