In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt

import os
import datetime

In [3]:
class NN_MNIST(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    ``self.model`` maps an input batch to 10 raw logits; ``forward`` converts
    those logits to probabilities with a softmax over dim 1. The training
    loop feeds the raw logits (not the probabilities) to ``CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 8, (3,3), padding='same'),     # 28 x 28 x 8
            nn.MaxPool2d(kernel_size=(2,2), stride=2),   # 14 x 14 x 8
            nn.ReLU(),
            nn.Conv2d(8, 32, (3,3), padding='same'),     # 14 x 14 x 32
            nn.MaxPool2d(kernel_size=(2,2), stride=2),   #  7 x  7 x 32
            nn.ReLU(),
            nn.Flatten(),                               # 1568
            nn.Linear(1568, 200),                       # 200
            nn.Linear(200, 10),                         # 10
        )

    def forward(self, x):
        """Return class probabilities: softmax over dim 1 of the logits."""
        return torch.softmax(self.model(x), dim=1)

    def zero_grad(self, *args, **kwargs):
        """Clear gradients of the wrapped network.

        Arguments (e.g. ``set_to_none``) are forwarded unchanged so this
        override stays call-compatible with ``nn.Module.zero_grad``.
        """
        return self.model.zero_grad(*args, **kwargs)

    def train(self, dataset=True, args=None):
        """Either switch train/eval mode or run the training loop.

        This method shadows ``nn.Module.train(mode)``, which PyTorch itself
        calls from ``eval()`` and when a parent module propagates its mode to
        children. To keep those calls working, a boolean first argument is
        treated as the mode flag and delegated to ``nn.Module.train``.

        Args:
            dataset: ``(train_dataset, test_dataset)`` pair to fit on, or a
                bool mode flag (see above).
            args: mapping with keys ``base_epochs``, ``base_lr``,
                ``base_batch_size``, ``base_weight_decay``, ``save_freq`` and
                ``save_dir``. Required when ``dataset`` is a dataset pair.

        Returns:
            ``self`` when used as a mode switch, otherwise ``None``.

        Raises:
            TypeError: if a dataset pair is given without ``args``.
        """
        if isinstance(dataset, bool):
            return super().train(dataset)
        if args is None:
            raise TypeError("train(dataset, args) requires 'args' when fitting on a dataset")
        return self._fit(dataset, args)

    def _fit(self, dataset, args):
        """Optimise ``self.model`` with Adam, checkpointing and plotting losses."""
        train_dataset, test_dataset = dataset
        epochs = args['base_epochs']
        lr = args['base_lr']
        batch_size = args['base_batch_size']
        weight_decay = args['base_weight_decay']
        save_freq = args['save_freq']

        # Minute-resolution run directory. strftime yields the same text the
        # old str(now)[:-10] slice produced, but stays correct when the
        # microsecond field is 0; exist_ok allows a rerun within the minute.
        timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
        save_dir = os.path.join(args['save_dir'], 'MNIST', timestamp)
        os.makedirs(save_dir, exist_ok=True)

        loss_list = []        # summed training loss per epoch
        test_loss_list = []   # summed validation loss per epoch
        batch_loss_list = []  # training loss per optimisation step

        train_dl = DataLoader(train_dataset, batch_size=batch_size)
        test_dl = DataLoader(test_dataset, batch_size=batch_size)
        loss_fn = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)

        for i in range(epochs):
            epoch_loss = 0
            total = 0
            correct = 0
            for data in train_dl:
                inputs, targets = data[0], data[1]
                # Raw logits: CrossEntropyLoss applies log-softmax itself.
                outputs = self.model(inputs)

                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

                loss = loss_fn(outputs, targets)
                # Store a plain float: keeping the tensor would retain every
                # step's autograd graph and cannot be plotted afterwards.
                batch_loss = loss.item()
                batch_loss_list.append(batch_loss)
                epoch_loss += batch_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_list.append(epoch_loss)

            with torch.no_grad():
                test_total = 0
                test_correct = 0
                test_epoch_loss = 0
                for data in test_dl:
                    inputs, targets = data[0], data[1]
                    outputs = self.model(inputs)

                    _, predicted = torch.max(outputs, 1)
                    test_total += targets.size(0)
                    test_correct += (predicted == targets).sum().item()

                    test_epoch_loss += loss_fn(outputs, targets).item()

            test_loss_list.append(test_epoch_loss)
            print(f'Epoch: {i+1}/{epochs}  |  Training Accuracy: {round(correct/total*100,2)},   Validation Accuracy: {round(test_correct/test_total*100,2)}')

            if (i+1)%save_freq == 0:
                model_path = os.path.join(save_dir, f'model_{i+1}')
                torch.save(self.model.state_dict(), model_path)
                print('Saved model at: ', model_path)

        print('Finished Training')

        plt.figure()
        plt.plot(batch_loss_list)
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.savefig(os.path.join(save_dir, 'loss_vs_iters.png'))

        plt.figure()
        plt.plot(loss_list, label='train')
        plt.plot(test_loss_list, label='val')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.savefig(os.path.join(save_dir, 'loss_vs_epochs.png'))


In [None]:
class NN_MNIST_linear(nn.Module):
    """Two-stage MLP split at the hidden pre-activation.

    ``model1`` flattens the input and applies the first linear layer (784 ->
    200); ``model2`` applies ReLU, the second linear layer (200 -> 10) and a
    softmax. The split exposes the tensor just before the ReLU so gradients
    can be taken there separately (see ``linbp_grad``).
    """

    def __init__(self):
        super().__init__()
        self.model1 = nn.Sequential(
            nn.Flatten(),                               # 784
            nn.Linear(784, 200)                         # 200
        )
        self.model2 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(200, 10),                         # 10
            # Explicit class axis; the implicit-dim form is deprecated.
            nn.Softmax(dim=-1)
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward1(self, inputs):
        """Return the hidden pre-activation (output of ``model1``)."""
        return self.model1(inputs)

    def forward2(self, inputs):
        """Map a hidden pre-activation to class probabilities via ``model2``."""
        return self.model2(inputs)

    def forward(self, inputs):
        """Full forward pass: ``forward2(forward1(inputs))``.

        Without this method, calling the module directly raises
        ``NotImplementedError``.
        """
        return self.forward2(self.forward1(inputs))

    def linbp_grad(self, x, y):
        """Accumulate gradients in two backward passes split at the ReLU.

        The first pass back-propagates from the hidden pre-activation; the
        second back-propagates the loss through the whole network. Both
        accumulate into the ``.grad`` fields of the parameters (and of ``x``
        if it requires grad). Nothing is returned.

        Args:
            x: input batch accepted by ``forward1``.
            y: targets accepted by ``self.loss_fn``.
        """
        hidden = self.forward1(x)
        # The hidden tensor is not a scalar, so backward() needs an explicit
        # seed gradient; all-ones is what a one-element tensor gets implicitly.
        # NOTE(review): confirm an all-ones seed is the intended weighting.
        hidden.backward(torch.ones_like(hidden), retain_graph=True)
        probs = self.forward2(hidden)
        # NOTE(review): CrossEntropyLoss applies log-softmax itself, while
        # model2 already ends in Softmax — confirm this is intended.
        loss = self.loss_fn(probs, y)
        loss.backward()


In [None]:
# Build both networks, then probe the linear one with a constant -1 image.
model, model_linear = NN_MNIST(), NN_MNIST_linear()
x = torch.full((1, 28, 28), -1.0)
model_linear(x)

tensor([[0.1064, 0.0928, 0.0928, 0.1185, 0.0928, 0.0970, 0.0928, 0.1015, 0.1009,
         0.1047]], grad_fn=<SoftmaxBackward0>)

In [None]:
# len(list(nn_linear))
# Count the modules that are neither a Sequential container nor the model itself.
sum(
    1
    for layer in model_linear.modules()
    if not isinstance(layer, (nn.Sequential, NN_MNIST_linear))
)
# nn_linear.modules()

# for layer in model.modules():
#     if not isinstance(layer, nn.Sequential) and not isinstance(layer, NN_MNIST):
#         print(layer)

5

In [None]:
def linbp_grad(x, y):
    """Placeholder for a module-level linear-backpropagation gradient routine.

    The original cell had no body, which makes the definition a syntax error.
    Until the routine is written, calling it fails loudly.

    Args:
        x: input batch (unused for now).
        y: targets (unused for now).

    Raises:
        NotImplementedError: always.
    """
    # Sketch left by the author: collect the leaf layers of model_linear, i.e.
    # every module that is neither nn.Sequential nor NN_MNIST_linear.
    raise NotImplementedError('linbp_grad is not implemented yet')

In [None]:
# def init_weights(m):
#     if isinstance(m, nn.Linear):
#         torch.nn.init.uniform_(m.weight, a=0, b=1)
#         m.bias.data.fill_(0.01)

# nn_linear.apply(init_weights)
# nn_linear(x)

# Linear Backpropagation on Toy examples

### Method 1
Call `backward()` just before the activation, then explicitly multiply the result by the gradient of the loss with respect to the logits.

In [None]:
# Baseline (ordinary backprop): the pre-activation W*x is negative, so the
# ReLU outputs 0 and no gradient reaches x.
x = torch.tensor([-2.0], requires_grad=True)
W = torch.tensor([3.0], requires_grad=True)
output = torch.relu(W * x)
loss = torch.tensor([2.0]) * output  ## say
loss.backward()
print('output: ', output)
with torch.no_grad():
    print('grad x:', x.grad)

output:  tensor([0.], grad_fn=<ReluBackward0>)
grad x: tensor([0.])


In [None]:
# Method 1: back-propagate once from the pre-activation, once from the loss,
# then combine x.grad with the gradient stored on the ReLU output.
x = torch.tensor([-2.0], requires_grad=True)
W = torch.tensor([3.0], requires_grad=True)
output = W * x
# First pass stops before the ReLU, so x.grad receives d(W*x)/dx = W.
output.backward(retain_graph=True)
output = torch.relu(output)
# The ReLU output is a non-leaf tensor; ask autograd to keep its gradient.
output.retain_grad()
loss = torch.tensor([2.0]) * output  ## say
# Second pass runs through the ReLU and accumulates into the same .grad fields.
loss.backward()
print('output: ', output)
with torch.no_grad():
    print('grad x:', x.grad)
    print('final grad: ', x.grad*output.grad)

output:  tensor([0.], grad_fn=<ReluBackward0>)
grad x: tensor([3.])
final grad:  tensor([6.])


In [None]:
def linbp_relu(x):
    """ReLU in the forward pass, identity in the backward pass.

    Adds ``relu(-x)`` to ``x`` as a constant cut off from the autograd graph:
    the returned values equal ``relu(x)``, while the gradient with respect to
    ``x`` is 1 everywhere, including where ``x`` is negative.

    Args:
        x: input tensor.

    Returns:
        Tensor with the same shape as ``x`` holding ``relu(x)``.
    """
    # detach() is the autograd-safe replacement for the legacy `.data` access.
    correction = F.relu(-x).detach()
    return x + correction

# Trials

In [1]:
import torch
import torch.nn as nn

In [4]:
class Normalize(nn.Module):
    """Per-channel standardisation layer: ``(x - mean) / std`` along dim 1.

    ``self.ms`` stores ``[means, stds]`` with three entries each; only as many
    entries as the input has channels are used.
    """

    def __init__(self):
        super().__init__()
        self.ms = [(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)]

    def forward(self, input):
        """Return a standardised copy of ``input``; the argument is not modified."""
        means, stds = self.ms[0], self.ms[1]
        normalized = input.clone()
        n_channels = normalized.shape[1]
        for channel in range(n_channels):
            shifted = normalized[:, channel] - means[channel]
            normalized[:, channel] = shifted / stds[channel]
        return normalized

In [5]:
# Put the normalisation layer in front of a freshly initialised classifier.
model = nn.Sequential(Normalize(), NN_MNIST())

In [6]:
# Index 0 of the Sequential wrapper: the Normalize stage.
model[0]

Normalize()

In [7]:
# Index 1 of the Sequential wrapper: the NN_MNIST classifier.
model[1]

NN_MNIST(
  (model): Sequential(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU()
    (3): Conv2d(8, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (4): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): ReLU()
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1568, out_features=200, bias=True)
    (8): Linear(in_features=200, out_features=10, bias=True)
  )
)