In [1]:
import torch.nn as nn
import torch.nn.init as init
import torch
from math import sqrt
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
import gzip
import cPickle
from sklearn.decomposition import PCA
import random
from sklearn import preprocessing
from torchvision import datasets, transforms
import torch.optim as optim

In [2]:
class SaakNet(nn.Module):
    """Five 2x2 / stride-2 convolution stages followed by three linear layers.

    Every convolution stage applies two bias-free convolutions to the same
    input: a bank of learned "AC" filters and a single-channel "DC"
    convolution whose kernel is initialised to a constant filter with unit
    L2 norm.  The two results are merged as ``relu(cat(ac, -ac, dc))``, so a
    stage with ``c`` AC filters emits ``2 * c + 1`` channels.  ``fc6`` and
    ``fc7`` are augmented the same way, using the sum of the layer outputs
    as the DC term.

    Only the first stage pads (2 pixels per side), so a 1x28x28 input becomes
    32x32 and is halved by each of the five stages down to 1x1.

    Args:
        channels: number of AC filters in each of the five conv stages.
        fc_sizes: output widths of ``fc6`` and ``fc7``.
        num_classes: output width of ``fc8``.

    The defaults reproduce the original hard-coded layout (``fc6``/``fc7``/
    ``fc8`` input widths 4001/667/111).
    """

    def __init__(self, channels=(3, 27, 220, 1800, 2000), fc_sizes=(333, 55),
                 num_classes=10):
        super(SaakNet, self).__init__()
        if len(channels) != 5 or len(fc_sizes) != 2:
            raise ValueError('SaakNet expects 5 conv widths and 2 fc widths')

        # Every convolution uses a 2x2 kernel with stride 2.
        self.kernel_size = 2

        # Each stage consumes the augmented output (2 * c + 1 channels) of the
        # previous one; the network input has a single channel.
        c_in = 1
        self.conv1, self.conv1_b, self.f1 = self._make_stage(c_in, channels[0], padding=2)
        c_in = 2 * channels[0] + 1
        self.conv2, self.conv2_b, self.f2 = self._make_stage(c_in, channels[1])
        c_in = 2 * channels[1] + 1
        self.conv3, self.conv3_b, self.f3 = self._make_stage(c_in, channels[2])
        c_in = 2 * channels[2] + 1
        self.conv4, self.conv4_b, self.f4 = self._make_stage(c_in, channels[3])
        c_in = 2 * channels[3] + 1
        self.conv5, self.conv5_b, self.f5 = self._make_stage(c_in, channels[4])

        # Fully connected head.  Input widths follow the same 2 * c + 1 rule
        # instead of being hand-computed constants.
        self.fc6 = nn.Linear(2 * channels[4] + 1, fc_sizes[0])
        self._xavier_uniform(self.fc6.weight)

        self.fc7 = nn.Linear(2 * fc_sizes[0] + 1, fc_sizes[1])
        self._xavier_uniform(self.fc7.weight)

        self.fc8 = nn.Linear(2 * fc_sizes[1] + 1, num_classes)
        self._xavier_uniform(self.fc8.weight)

    @staticmethod
    def _xavier_uniform(weight, gain=1.0):
        """Xavier-uniform initialise ``weight`` in place.

        Uses ``init.xavier_uniform_`` when the installed PyTorch provides it
        and falls back to the older ``init.xavier_uniform`` name otherwise.
        """
        fill = getattr(init, 'xavier_uniform_', None)
        if fill is None:
            fill = init.xavier_uniform
        return fill(weight, gain)

    def _make_stage(self, in_channels, out_channels, padding=0):
        """Build one stage and return ``(ac_conv, dc_conv, dc_kernel)``.

        ``dc_kernel`` is the initial DC filter as a double tensor; a float
        copy of it is installed as ``dc_conv.weight``.  Because that weight
        is an ``nn.Parameter`` it is included in ``parameters()``.
        """
        k = self.kernel_size
        ac_conv = nn.Conv2d(in_channels, out_channels, k, stride=k,
                            bias=False, padding=padding)
        self._xavier_uniform(ac_conv.weight, init.calculate_gain('relu'))

        dc_conv = nn.Conv2d(in_channels, 1, k, stride=k,
                            bias=False, padding=padding)
        ones = np.ones([1, in_channels, k, k])
        dc_kernel = torch.from_numpy(ones / np.linalg.norm(ones))
        dc_conv.weight = torch.nn.Parameter(dc_kernel.float())
        return ac_conv, dc_conv, dc_kernel

    def In_DC(self, ac):
        """Reshape ``ac`` to 4x the channels at half the height/width and sum
        over the channel dimension.

        The regrouping is a plain ``view`` of the underlying memory.  This
        helper is not called by ``forward``.
        """
        temp = ac.view(ac.size(0), ac.size(1), -1)
        temp = temp.view(temp.size(0), temp.size(1) * 4, -1)
        # Floor division keeps the sizes integral on Python 3 as well.
        temp = temp.view(ac.size(0), temp.size(1), ac.size(2) // 2, ac.size(3) // 2)
        temp = torch.sum(temp, 1)
        return temp

    def Augment(self, ac, dc):
        """Return ``relu(cat(ac, -ac, dc))`` along the channel dimension."""
        return F.relu(torch.cat((ac, -ac, dc), 1))

    def Augment_fc(self, ac):
        """Like ``Augment``, with the DC term taken as the sum of ``ac`` over dim 1."""
        dc = torch.sum(ac, 1, keepdim=True)
        return F.relu(torch.cat((ac, -ac, dc), 1))

    def forward(self, input):
        """Map a batch of images to ``num_classes`` scores per sample."""
        stages = (
            (self.conv1, self.conv1_b),
            (self.conv2, self.conv2_b),
            (self.conv3, self.conv3_b),
            (self.conv4, self.conv4_b),
            (self.conv5, self.conv5_b),
        )
        feat = input
        for ac_conv, dc_conv in stages:
            feat = self.Augment(ac_conv(feat), dc_conv(feat))

        # Flatten per sample.  If the spatial size is not 1x1 here, fc6 raises
        # a shape error instead of samples being silently mixed together.
        flat = feat.view(feat.size(0), -1)
        hidden = self.Augment_fc(self.fc6(flat))
        hidden = self.Augment_fc(self.fc7(hidden))
        return self.fc8(hidden)
    
    

In [3]:
# Load the pickled MNIST splits.  Each split is indexed as [0] = images and
# [1] = labels below.
# NOTE(review): unpickling can execute arbitrary code -- only load a
# mnist.pkl.gz obtained from a trusted source.
with gzip.open('mnist.pkl.gz', 'rb') as f:  # closed even if unpickling fails
    train_set, valid_set, test_set = cPickle.load(f)

# Reshape the images to (N, 1, 28, 28), the layout the convolutions consume.
# presumably each image is stored as a flat 784-vector -- TODO confirm
train_data = train_set[0].reshape(-1, 1, 28, 28)
valid_data = valid_set[0].reshape(-1, 1, 28, 28)
test_data = test_set[0].reshape(-1, 1, 28, 28)

train_label = train_set[1]
valid_label = valid_set[1]
test_label = test_set[1]

In [4]:
def accuracy(output, label):
    """Return the fraction of samples whose arg-max over dim 1 equals ``label``.

    Args:
        output: score tensor/Variable with one row per sample; the predicted
            class is the index of the largest value along dim 1.
        label: tensor/Variable of class indices; its dim-0 size is the
            denominator.

    Returns:
        The accuracy as a Python float.
    """
    predicted = torch.max(output, 1)[1]
    acc = (predicted == label).float().sum() / float(label.size(0))
    value = acc.data
    # Newer PyTorch returns a 0-dim tensor here, which cannot be indexed with
    # [0]; older releases return a 1-element tensor that has no .item().
    return value.item() if hasattr(value, 'item') else value[0]

In [5]:
def acc_batch(net, data, label, batch_size):
    """Evaluate ``net`` on ``data`` in mini-batches and return its accuracy.

    Args:
        net: model called on each batch; its output's arg-max over dim 1 is
            taken as the predicted class.
        data: numpy array of inputs, one sample per row of axis 0.
        label: numpy array of integer class labels aligned with ``data``.
        batch_size: maximum number of samples per forward pass.

    Returns:
        Correct predictions divided by the number of samples in ``data``
        (0.0 for empty input).  Every sample is evaluated, including a
        trailing partial batch.
    """
    num_samples = data.shape[0]
    if num_samples == 0:
        return 0.0

    # Feed batches on the device the model lives on; for a model without
    # parameters fall back to CUDA whenever it is available.
    first_param = next(iter(net.parameters()), None)
    if first_param is not None:
        use_cuda = first_param.is_cuda
    else:
        use_cuda = torch.cuda.is_available()

    correct = 0
    for start in range(0, num_samples, batch_size):
        X = data[start:start + batch_size]
        Y = label[start:start + batch_size]
        X_ = Variable(torch.from_numpy(X))
        if use_cuda:
            X_ = X_.cuda()
        predicted = torch.max(net(X_), 1)[1]
        # Count hits on the host; normalising by the samples actually seen
        # keeps the result independent of any notebook-level globals.
        hits = predicted.data.cpu().numpy().reshape(-1) == np.asarray(Y)
        correct += int(hits.sum())
    return correct / float(num_samples)

In [7]:
# Build the model, optimizer and loss, then train for num_epoch epochs,
# reporting validation and test accuracy after every epoch.
net = SaakNet()
net = net.cuda()
optimizer = optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
criterion = nn.CrossEntropyLoss()

batch_size = 100
num_epoch = 100
num_training = train_data.shape[0]
num_testing = test_data.shape[0]
log_step = 100  # print training loss/accuracy every log_step iterations
step = 0

for i in range(num_epoch):
    print("Train for epoch %d" % i)
    # Mini-batches are fixed contiguous slices of the training set; only the
    # order in which they are visited is shuffled each epoch.
    inda = np.random.permutation(num_training // batch_size)
    for j in range(num_training // batch_size):
        lo = inda[j] * batch_size
        X = train_data[lo: lo + batch_size, :]
        Y = train_label[lo: lo + batch_size]
        X_ = Variable(torch.from_numpy(X)).cuda()
        Y_ = Variable(torch.from_numpy(Y.astype(int))).cuda()
        Y_O = net(X_)
        loss = criterion(Y_O, Y_)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % log_step == 0:
            # Only pay for the accuracy computation on logging steps.
            acc = accuracy(Y_O, Y_)
            # Newer PyTorch yields a 0-dim loss that cannot be indexed with
            # [0]; older releases have no .item().
            loss_data = loss.data
            loss_value = loss_data.item() if hasattr(loss_data, 'item') else loss_data[0]
            print('iteration (%d): loss = %.3f, accuracy = %.3f' % (step, loss_value, acc))
        step += 1

    print("validation for epoch %d" % i)
    val_accuracy = acc_batch(net, valid_data, valid_label, batch_size)
    print("validation result is %.5f" % val_accuracy)
    test_accuracy = acc_batch(net, test_data, test_label, batch_size)
    print("test result is %.5f" % test_accuracy)

Train for epoch 0
iteration (0): loss = 2.357, accuracy = 0.100
iteration (100): loss = 0.152, accuracy = 0.950
iteration (200): loss = 0.483, accuracy = 0.850
iteration (300): loss = 0.121, accuracy = 0.960
iteration (400): loss = 0.181, accuracy = 0.940
validation for epoch 0
validation result is 0.96930
test result is 0.96900
Train for epoch 1
iteration (500): loss = 0.082, accuracy = 0.960
iteration (600): loss = 0.187, accuracy = 0.950
iteration (700): loss = 0.028, accuracy = 1.000
iteration (800): loss = 0.104, accuracy = 0.970
iteration (900): loss = 0.049, accuracy = 0.980
validation for epoch 1
validation result is 0.97800
test result is 0.97470
Train for epoch 2
iteration (1000): loss = 0.129, accuracy = 0.980
iteration (1100): loss = 0.023, accuracy = 1.000
iteration (1200): loss = 0.062, accuracy = 0.980
iteration (1300): loss = 0.003, accuracy = 1.000
iteration (1400): loss = 0.016, accuracy = 0.990
validation for epoch 2
validation result is 0.98050
test result is 0.9797

iteration (12100): loss = 0.022, accuracy = 0.980
iteration (12200): loss = 0.000, accuracy = 1.000
iteration (12300): loss = 0.000, accuracy = 1.000
iteration (12400): loss = 0.000, accuracy = 1.000
validation for epoch 24
validation result is 0.98730
test result is 0.98710
Train for epoch 25
iteration (12500): loss = 0.000, accuracy = 1.000
iteration (12600): loss = 0.002, accuracy = 1.000
iteration (12700): loss = 0.002, accuracy = 1.000
iteration (12800): loss = 0.000, accuracy = 1.000
iteration (12900): loss = 0.000, accuracy = 1.000
validation for epoch 25
validation result is 0.98360
test result is 0.98400
Train for epoch 26
iteration (13000): loss = 0.001, accuracy = 1.000
iteration (13100): loss = 0.016, accuracy = 0.990
iteration (13200): loss = 0.001, accuracy = 1.000
iteration (13300): loss = 0.000, accuracy = 1.000
iteration (13400): loss = 0.003, accuracy = 1.000
validation for epoch 26
validation result is 0.98710
test result is 0.98790
Train for epoch 27
iteration (1350

iteration (24100): loss = 0.000, accuracy = 1.000
iteration (24200): loss = 0.006, accuracy = 1.000
iteration (24300): loss = 0.007, accuracy = 1.000
iteration (24400): loss = 0.024, accuracy = 0.990
validation for epoch 48
validation result is 0.98580
test result is 0.98760
Train for epoch 49
iteration (24500): loss = 0.000, accuracy = 1.000
iteration (24600): loss = 0.001, accuracy = 1.000
iteration (24700): loss = 0.000, accuracy = 1.000
iteration (24800): loss = 0.000, accuracy = 1.000
iteration (24900): loss = 0.010, accuracy = 0.990
validation for epoch 49
validation result is 0.98460
test result is 0.98520
Train for epoch 50
iteration (25000): loss = 0.000, accuracy = 1.000
iteration (25100): loss = 0.000, accuracy = 1.000
iteration (25200): loss = 0.001, accuracy = 1.000
iteration (25300): loss = 0.000, accuracy = 1.000
iteration (25400): loss = 0.000, accuracy = 1.000
validation for epoch 50
validation result is 0.98810
test result is 0.98770
Train for epoch 51
iteration (2550

iteration (36100): loss = 0.000, accuracy = 1.000
iteration (36200): loss = 0.000, accuracy = 1.000
iteration (36300): loss = 0.000, accuracy = 1.000
iteration (36400): loss = 0.000, accuracy = 1.000
validation for epoch 72
validation result is 0.99060
test result is 0.99000
Train for epoch 73
iteration (36500): loss = 0.000, accuracy = 1.000
iteration (36600): loss = 0.000, accuracy = 1.000
iteration (36700): loss = 0.000, accuracy = 1.000
iteration (36800): loss = 0.000, accuracy = 1.000
iteration (36900): loss = 0.000, accuracy = 1.000
validation for epoch 73
validation result is 0.99080
test result is 0.98990
Train for epoch 74
iteration (37000): loss = 0.000, accuracy = 1.000
iteration (37100): loss = 0.000, accuracy = 1.000
iteration (37200): loss = 0.000, accuracy = 1.000
iteration (37300): loss = 0.000, accuracy = 1.000
iteration (37400): loss = 0.000, accuracy = 1.000
validation for epoch 74
validation result is 0.99080
test result is 0.98990
Train for epoch 75
iteration (3750

iteration (48100): loss = 0.000, accuracy = 1.000
iteration (48200): loss = 0.000, accuracy = 1.000
iteration (48300): loss = 0.000, accuracy = 1.000
iteration (48400): loss = 0.000, accuracy = 1.000
validation for epoch 96
validation result is 0.99060
test result is 0.99000
Train for epoch 97
iteration (48500): loss = 0.000, accuracy = 1.000
iteration (48600): loss = 0.000, accuracy = 1.000
iteration (48700): loss = 0.000, accuracy = 1.000
iteration (48800): loss = 0.000, accuracy = 1.000
iteration (48900): loss = 0.000, accuracy = 1.000
validation for epoch 97
validation result is 0.99080
test result is 0.99000
Train for epoch 98
iteration (49000): loss = 0.000, accuracy = 1.000
iteration (49100): loss = 0.000, accuracy = 1.000
iteration (49200): loss = 0.000, accuracy = 1.000
iteration (49300): loss = 0.000, accuracy = 1.000
iteration (49400): loss = 0.000, accuracy = 1.000
validation for epoch 98
validation result is 0.99090
test result is 0.98990
Train for epoch 99
iteration (4950

In [50]:
# Smoke test: push the first 100 validation images through a freshly
# initialised (untrained) SaakNet.
model = SaakNet()
model = model.cuda()
# Slice valid_data (defined in the data-loading cell) into a new name so the
# cell does not depend on leftover kernel state and can be re-run safely.
valid_ima = Variable(torch.from_numpy(valid_data[:100, :]))
valid_ima = valid_ima.cuda()
output = model(valid_ima)

torch.Size([100, 4001, 1, 1])
torch.Size([100, 4001])
torch.Size([100, 667])
torch.Size([100, 111])
torch.Size([100, 10])
