In [1]:
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.nn as nn
import time
import random
import os

In [2]:
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
    print("using cuda")
else:
    DEVICE = torch.device('cpu')
    print("using cpu")

using cuda


In [3]:
def switch_to_device(dataset, device = None):
    tensor_list_X, tensor_list_Y = [], []
    for x, y in dataset:
        tensor_list_X.append(x)
        tensor_list_Y.append(y)
    
    X = torch.stack(tensor_list_X)
    Y = torch.tensor(tensor_list_Y)
    if device is not None:
        X = X.to(device)
        Y = Y.to(device)
    return torch.utils.data.TensorDataset(X, Y)

In [4]:
def get_mnist_dl(batch_size_train = 256, batch_size_valid = 1024, device = None):
    #transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5),(0.5))])
    transform = transforms.Compose([transforms.ToTensor()])
    
    data_train = MNIST('./datasets', train = True, download = True, transform = transform)
    data_train = switch_to_device(data_train, device)
    data_train, data_valid = torch.utils.data.random_split(data_train, [55000, 5000])

    data_test = MNIST('./datasets', train = False, download = True, transform = transform)
    data_test = switch_to_device(data_test, device)

    train_dl = DataLoader(data_train, batch_size = batch_size_train, shuffle = True)
    valid_dl = DataLoader(data_valid, batch_size = batch_size_valid, shuffle = False)
    test_dl = DataLoader(data_test, batch_size = batch_size_valid, shuffle = False)

    return train_dl, valid_dl, test_dl

In [5]:
def print_stats(stats):

  fig, (ax1, ax2) = plt.subplots(1,2,figsize=(7,3), dpi=110)
  ax1.grid()
  ax2.grid()

  ax1.set_title("ERM loss")
  ax2.set_title("Valid Acc")
  
  ax1.set_xlabel("iterations")
  ax2.set_xlabel("iterations")

  itrs = [x[0] for x in stats['train-loss']]
  loss = [x[1] for x in stats['train-loss']]
  ax1.plot(itrs, loss)

  itrs = [x[0] for x in stats['valid-acc']]
  acc = [x[1] for x in stats['valid-acc']]
  ax2.plot(itrs, acc)

  ax1.set_ylim(0.0, 14.05)
  ax2.set_ylim(0.0, 1.05)

In [6]:
class Conv_2D():
    def __init__(self, input_dim, output_dim, k_size = 3, stride = 1, padding = 1):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.k_size = k_size
        self.stride = stride
        self.padding = padding

        #The dimension of weight is (OIKK)
        self.weight = torch.normal(mean = torch.full((self.output_dim, self.input_dim, self.k_size, self.k_size), 0.), std = torch.full((self.output_dim, self.input_dim, self.k_size, self.k_size), 0.1)).to(DEVICE)
        #The dimension of bias is (O1)
        self.bias = torch.normal(mean = torch.full([self.output_dim], 0.), std = torch.full([self.output_dim], 0.1)).to(DEVICE)
        self.weights_grad = torch.zeros(self.weight.shape).to(DEVICE)
        self.bias_grad = torch.zeros(self.bias.shape).to(DEVICE)

        self.Jacobi = None
        self.input = None
        self.output_h = None
        self.output_w = None

    def forward(self, input):
        '''
        input : (N,I,H,W)
        '''
        m = nn.ZeroPad2d(self.padding)
        input = m(input)
        self.input = input
        self.Jacobi = torch.zeros(input.shape)
        N, C, H, W = input.shape
        self.output_h = int((H - self.k_size) / self.stride + 1)
        self.output_w = int((W - self.k_size) /self.stride + 1)
        
        unfold_input = nn.functional.unfold(input,(self.k_size, self.k_size))  #(N*(I*K*K)*-1)
        output = unfold_input.transpose(1,2).matmul(self.weight.view(self.weight.shape[0], -1).t()).transpose(1,2) #(N*O*-1)
        output = nn.functional.fold(output, (self.output_h, self.output_w), (1,1))
        output = output + self.bias.view(1, -1, 1, 1)
        return output
    
    '''
    def backward(self, Next_Jacobi):
        
        #Next_Jacobi: (N, output_dim, output_h output_w)
        
        def judge_h(x):
            if x % 1 == 0 and x < self.output_h and x > -1:
                return int(x)
            else:
                return -1
        
        def judge_w(x):
            if x % 1 == 0 and x < self.output_w and x > -1:
                return int(x)
            else:
                return -1

        #for the value at one position in input
        for i in range(self.Jacobi.shape[2]):
            for j in range(self.Jacobi.shape[3]):
                mask = torch.zeros((self.input.shape[0], self.output_dim, self.k_size, self.k_size)).to(DEVICE)
                #What positions in the output tensor can this vale effect
                index_h = [(i - k) / self.stride for k in range(self.k_size)]
                index_w = [(j - k) / self.stride for k in range(self.k_size)]
                index_h_ = list(map(judge_h, index_h))
                index_w_ = list(map(judge_w, index_w))

                for m in range(self.k_size):
                    for n in range(self.k_size):
                        if index_h_[m] != -1 and index_w_[n] != -1:
                            mask[:, :, m, n] = Next_Jacobi[:, :, index_h_[m], index_w_[n]] 
                        else:
                            continue
                
                mask = mask.reshape(self.input.shape[0], 1, self.output_dim, self.k_size, self.k_size)
                Jacobi_t = mask * self.weight.permute(1, 0, 2, 3) #(N, 1, O, K, K)*(I, O, K, K) = (N, I, O, K, K)
                Jacobi_s_t = torch.sum(Jacobi_t, dim = (2, 3, 4))
                self.Jacobi[:, :, i, j] = Jacobi_s_t
        
        #Get rid of padding
        self.Jacobi = self.Jacobi[:, :, self.padding:self.input.shape[2]-self.padding, self.padding:self.input.shape[3]-self.padding].to(DEVICE)

        N, C, K, H, W = self.input.shape[0], self.input.shape[1], self.k_size**2, self.output_h, self.output_w
        tmp = torch.zeros((N,C,K,H,W)).to(DEVICE)
        for i in range(self.k_size):
            for j in range(self.k_size):
                tmp[:, :, i*self.k_size + j, :, :] = self.input[:, :, i : self.output_h + i : self.stride, j : self.output_w + j : self.stride]

        tmp_new = torch.sum(Next_Jacobi.reshape(N, self.output_dim, 1, 1, H, W)*tmp.reshape(N,1,C,K,H,W), dim = (4, 5))

        self.weights_grad = torch.sum(tmp_new.reshape(N, self.output_dim, C, self.k_size, self.k_size).permute(1, 2, 0, 3, 4), dim = 2)

        tmp_bias = torch.sum(Next_Jacobi, dim = (2,3))
        self.bias_grad = torch.sum(tmp_bias, dim = 0)

        return self.Jacobi
    '''

    def backward(self, Next_Jacobi):
        #Next_Jacobi: (N, output_dim, output_h output_w)
        #Only consider output_h = output_w
        padding_num = self.stride*self.input.shape[2] + self.k_size - self.stride - Next_Jacobi.shape[2]
        padding_num /= 2
        m = nn.ZeroPad2d(int(padding_num))
        padded_Next_Jacobi = m(Next_Jacobi)

        rotate_weight = self.weight.transpose(0, 1)
        rotate_weight = torch.rot90(rotate_weight, k = 2, dims = [2, 3])
        unfold_next_Jacobi = nn.functional.unfold(padded_Next_Jacobi,(self.k_size, self.k_size))
        self.Jacobi = unfold_next_Jacobi.transpose(1,2).matmul(rotate_weight.reshape(rotate_weight.shape[0], -1).t()).transpose(1,2) #(N*O*-1)
        self.Jacobi = nn.functional.fold(self.Jacobi, (self.input.shape[2], self.input.shape[3]), (1,1))

        tmp_bias = torch.sum(Next_Jacobi, dim = (2,3))
        self.bias_grad = torch.sum(tmp_bias, dim = 0)
        
        self.input = self.input.transpose(0,1)
        Next_Jacobi = Next_Jacobi.transpose(0,1)
        unfold_input = nn.functional.unfold(self.input, (Next_Jacobi.shape[2], Next_Jacobi.shape[3]))
        tmp_weights_grad = unfold_input.transpose(1, 2).matmul(Next_Jacobi.reshape(Next_Jacobi.shape[0], -1).t()).transpose(1,2)
        tmp_weights_grad = nn.functional.fold(tmp_weights_grad, (self.k_size, self.k_size), (1,1))
        self.weights_grad = tmp_weights_grad.transpose(0,1)

        return self.Jacobi

    
    def update(self, lr):
        self.weight -= lr*self.weights_grad
        self.bias -= lr*self.bias_grad

In [7]:

class max_pooling_2D():
    def __init__(self, k_size = 2, stride = 2, padding = 0):
        self.k_size = k_size
        self.stride = stride
    
    def forward(self, input):

        out = input.reshape(input.shape[0], input.shape[1], input.shape[2]//self.k_size, self.k_size, input.shape[3]//self.k_size, self.k_size)
        out = out.max(dim = 3)[0].max(dim = 4)[0]
        self.index = out.repeat_interleave(self.k_size, dim = 2).repeat_interleave(self.k_size, dim = 3) == input
        return out

    def backward(self, Next_Jacobi):      
        return Next_Jacobi.repeat_interleave(self.k_size, dim = 2).repeat_interleave(self.k_size, dim = 3) * self.index


In [8]:
class Relu():
    def __init__(self):
        self.Jacobi = None

    def forward(self, input):
        output = (torch.abs(input) + input)/2.0
        self.Jacobi = output.clone()
        self.Jacobi[self.Jacobi>0] = 1
        return output
    
    def backward(self, Next_Jacobi):
        self.Jacobi = self.Jacobi*Next_Jacobi
        return self.Jacobi

In [9]:
class Linear():
    def __init__ (self, input_num, output_num):
        self.input_num, self.output_num = input_num, output_num
        self.weights = torch.normal(mean = torch.full((self.output_num, self.input_num), 0.), std = torch.full((self.output_num, self.input_num), 0.1)).to(DEVICE)
        self.bias = torch.normal(mean = torch.full((1, self.output_num), 0.), std = torch.full((1, self.output_num), 0.1)).to(DEVICE)
        self.weights_grad = torch.zeros(self.weights.shape).to(DEVICE)
        self.bias_grad = torch.zeros(self.bias.shape).to(DEVICE)

        self.Jacobi = None
        self.input = None
    
    def forward(self, input):
        output = torch.matmul(self.weights, input.transpose(1, 0))
        self.input = input
        return output.T + self.bias

    def backward(self, Next_Jacobi):
        self.Jacobi = torch.matmul(self.weights.T, Next_Jacobi.T).T

        self.bias_grad = torch.sum(Next_Jacobi, dim = 0)
        self.weights_grad = torch.matmul(Next_Jacobi.T, self.input)
        return self.Jacobi
        
    def update(self, lr):
        self.bias -= lr*self.bias_grad
        self.weights -= lr*self.weights_grad

In [10]:
class Softmax_CrossEntropy():
    def __init__(self):
        self.Jacobi = None
        self.loss = None

    def forward(self, input, labels):
        batch_size = input.shape[0]
        self.Jacobi = torch.zeros(input.shape).to(DEVICE)
        x = torch.exp(input)
        y = torch.sum(x, dim = 1).reshape(batch_size, 1)
        softmax_output = x/y
        loss = torch.sum(-(labels*torch.log(softmax_output))) / batch_size
        self.loss = loss
        self.Jacobi = (softmax_output - labels)/batch_size
        return self.loss
    
    def backward(self):
        return self.Jacobi
        

In [11]:
class CNN_Nets():
    def __init__(self):
        self.conv1 = Conv_2D(input_dim=1, output_dim=26, k_size=5, stride=1, padding=0) #(N, 26, 24, 24)
        self.Relu_1 = Relu()
        self.maxpooling_1 = max_pooling_2D(k_size=2, stride=2) #(N, 26, 12, 12)

        self.conv2 = Conv_2D(input_dim=26, output_dim=52, k_size=3, stride=1, padding=0) #(N, 52, 10, 10)
        self.Relu_2 = Relu()

        self.conv3 = Conv_2D(input_dim=52, output_dim=10, k_size=1, stride=1, padding=0) #(N, 10, 10, 10)
        self.Relu_3 = Relu()
        self.maxpooling_3 = max_pooling_2D(k_size=2, stride=2) #(N, 10, 5, 5)

        self.fc_1 = Linear(input_num=5*5*10, output_num=1000)
        self.Relu_4 = Relu()

        self.fc_2 = Linear(input_num=1000, output_num=10)
        self.softmax_CrossEntropy = Softmax_CrossEntropy()

        self.output = None
        self.loss = None

    def forward(self, input, labels):
        N, C, H, W = input.shape

        output = self.conv1.forward(input)
        output = self.Relu_1.forward(output)
        output = self.maxpooling_1.forward(output)

        output = self.conv2.forward(output)
        output = self.Relu_2.forward(output)

        output = self.conv3.forward(output)
        output = self.Relu_3.forward(output)
        output = self.maxpooling_3.forward(output)

        output = torch.reshape(output, (N, -1))
        output = self.fc_1.forward(output)
        output = self.Relu_4.forward(output)

        output = self.fc_2.forward(output)
        self.output = output
        loss = self.softmax_CrossEntropy.forward(output, labels)
        self.loss = loss
    
    def backward(self):
        grad = self.softmax_CrossEntropy.Jacobi
        grad = self.fc_2.backward(grad)
        grad = self.Relu_4.backward(grad)
        grad = self.fc_1.backward(grad)
        grad = grad.reshape(grad.shape[0], 10, 5, 5)

        grad = self.maxpooling_3.backward(grad)
        grad = self.Relu_3.backward(grad)
        grad = self.conv3.backward(grad)
        grad = self.Relu_2.backward(grad)
        grad = self.conv2.backward(grad)
        grad = self.maxpooling_1.backward(grad)
        grad = self.Relu_1.backward(grad)
        grad = self.conv1.backward(grad)

    def update(self, lr):
        self.conv1.update(lr)
        self.conv2.update(lr)
        self.conv3.update(lr)
        self.fc_1.update(lr)
        self.fc_2.update(lr)

In [12]:
class MLP_Nets():
    def __init__(self, device = DEVICE):
        self.device = device
        self.fc_1 = Linear(input_num = 28*28, output_num = 1024)
        self.sigmoid_1 = Relu()
        self.fc_2 = Linear(input_num = 1024, output_num = 1024)
        self.sigmoid_2 = Relu()
        self.fc_3 = Linear(input_num = 1024, output_num = 10)
        self.softmax_CrossEntropy = Softmax_CrossEntropy()
        self.output = None
        self.loss = None

    def forward(self, input, labels):
        input = torch.reshape(input, (input.shape[0], 28*28))
        output = self.fc_1.forward(input)
        output = self.sigmoid_1.forward(output)
        output = self.fc_2.forward(output)
        output = self.sigmoid_2.forward(output)
        output = self.fc_3.forward(output)
        self.output = output
        loss = self.softmax_CrossEntropy.forward(output, labels)
        self.loss = loss

    def backward(self):
        grad = self.softmax_CrossEntropy.Jacobi
        grad = self.fc_3.backward(grad)
        grad = self.sigmoid_2.backward(grad)
        grad = self.fc_2.backward(grad)
        grad = self.sigmoid_1.backward(grad)
        grad = self.fc_1.backward(grad)
    
    def update(self, lr):
        self.fc_1.update(lr)
        self.fc_2.update(lr)
        self.fc_3.update(lr)

In [13]:
@torch.no_grad()
def get_acc(model, dl):
  acc = []
  i= 0
  for X, y in dl:
    one_hot_y = torch.zeros(X.shape[0], 10).to(DEVICE)
    one_hot_y[[i for i in range(X.shape[0])], [k.item() for k in y]] = 1
    model.forward(X, one_hot_y)
    acc.append(torch.argmax(model.output, dim=1) == y)
    i+=1
    if i == 3:
      break
  acc = torch.cat(acc)
  acc = torch.sum(acc)/len(acc)
  return acc.item()

In [14]:
@torch.no_grad()
def run_experiment(model, train_dl, valid_dl, test_dl, max_epochs=20, lr = 1e-3):

  itr = -1
  stats = {'train-loss': [], 'valid-acc':[]}
  time_list = []
  memory_list = []
  for epoch in range(max_epochs):
    for X, y in train_dl:
        itr += 1
        one_hot_y = torch.zeros(X.shape[0], 10).to(DEVICE)
        one_hot_y[[i for i in range(X.shape[0])], [k.item() for k in y]] = 1
        start = time.time()
        model.forward(X, one_hot_y)
        memory_list.append(torch.cuda.memory_allocated()/1024/1024)
        model.backward()
        model.update(lr)
        time_list.append(time.time()-start)
        stats['train-loss'].append((itr, model.loss.item()))

        if itr % 20 == 0:

          valid_acc = get_acc(model, valid_dl)
          stats['valid-acc'].append((itr, valid_acc))
          s = f"{epoch}:{itr} [train] loss:{model.loss.item():.3f}, [valid] acc:{valid_acc:.3f}, time: {np.sum(time_list)/20}, memory: {np.sum(memory_list)/len(memory_list)} "
          print(s)
          time_list = []
          memory_list = []

  test_acc = get_acc(model, test_dl)
  print(f"[test] acc:{test_acc:.3f}")
  return stats

In [15]:
max_epochs = 20
train_batch = 1024
valid_batch = 256
lr = 1e-4

In [16]:
train_dl, valid_dl, test_dl = get_mnist_dl(batch_size_train=train_batch, batch_size_valid=valid_batch, device=DEVICE)

model = MLP_Nets()

stats = run_experiment(model, train_dl, valid_dl, test_dl, max_epochs=max_epochs, lr = lr)

print_stats(stats)

0:0 [train] loss:5.612, [valid] acc:0.132, time: 0.03099067211151123, memory: 243.9423828125 
0:20 [train] loss:5.139, [valid] acc:0.133, time: 0.0016548514366149902, memory: 256.0361328125 
0:40 [train] loss:4.799, [valid] acc:0.135, time: 0.0016459584236145019, memory: 255.5205078125 


KeyboardInterrupt: 