
# Checkpointing

Your task is to implement checkpointing for an MLP using NumPy.

You are free to use the implementation of an MLP and the backpropagation algorithm that you developed during lab sessions.

The key takeaway from this task is that checkpointing trades extra computation for lower memory use: some forward-pass activations are recomputed during the backward pass instead of being stored, and storing activations is often a major bottleneck when training large networks. In plain English, we can slightly increase the time required to train our network to save some of our GPU's precious memory.

## What is checkpointing?

The aim of checkpointing is to save only every $n$-th layer's forward result (e.g. every 2nd layer's), instead of every layer's forward result as in plain backpropagation, and to use these checkpoints to recompute the missing activations during the backward pass. Checkpointed activations are kept in memory after the forward pass, while each remaining activation is recomputed at most once. Once recomputed, a non-checkpoint activation is kept in memory only until it is no longer required.
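
As a rough illustration of this bookkeeping, the sketch below lists which layer outputs would stay in memory after the forward pass for a given $n$. It assumes that layer 0 is the input, that the input and the network output are always kept, and that checkpoints fall on multiples of $n$. The helper names `checkpoint_layers` and `stored_bytes` and the layer widths are hypothetical and not part of the assignment.

```python
def checkpoint_layers(num_layers, n):
    """Indices of layer outputs kept after the forward pass.

    Assumption: layers are numbered 0..num_layers-1, layer 0 is the input,
    and the input and the final output are always kept.
    """
    if n < 1:
        raise ValueError("n must be at least 1")
    kept = set(range(0, num_layers, n))
    kept.add(num_layers - 1)
    return sorted(kept)


def stored_bytes(sizes, batch_size, n, bytes_per_value=8):
    """Bytes held by the kept activations for one mini-batch (float64 by default)."""
    kept = checkpoint_layers(len(sizes), n)
    return sum(sizes[i] * batch_size * bytes_per_value for i in kept)


sizes = [784, 64, 64, 64, 64, 64, 10]    # hypothetical layer widths
print(checkpoint_layers(len(sizes), 1))  # plain backpropagation keeps every layer
print(checkpoint_layers(len(sizes), 3))  # every 3rd layer plus the output
print(stored_bytes(sizes, 100, 1), stored_bytes(sizes, 100, 3))
```

During the backward pass the recomputed activations of one segment are also held temporarily, so the peak is somewhat higher than this forward-pass figure.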

# What should be done

1. Take the implementation of an MLP trained with backpropagation. Analyze how the memory used by the algorithm depends on the number of hidden layers.

2. Implement a class `NetworkWithCheckpointing` that inherits from the `Network` class defined during lab sessions by:
    a) implementing a method `forward_between_checkpoints` that recomputes the forward pass of the network starting from one of the checkpointed layers
    b) overriding the method `backprop` to store only checkpointed layers, compute the other activations with the `forward_between_checkpoints` method, and keep them in memory only until they are no longer needed.

3. Train your network with checkpointing on MNIST. Compare its running time and memory usage with those of the network without checkpointing.
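
One possible way to carry out the comparison in step 3 is a small measuring helper built on the standard library's `time` and `tracemalloc` modules. This is only a sketch: the helper name `measure` and the stand-in workload `allocate` are hypothetical, and it assumes that NumPy's array buffers are visible to `tracemalloc` on your installation, which is worth verifying before trusting the numbers.

```python
import time
import tracemalloc


def measure(fn, *args, **kwargs):
    """Run fn once and return (result, elapsed_seconds, peak_traced_bytes)."""
    tracemalloc.start()
    start = time.perf_counter()
    try:
        result = fn(*args, **kwargs)
        elapsed = time.perf_counter() - start
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return result, elapsed, peak


# Stand-in workload; replace it with e.g. one call to a network's backprop.
def allocate(n):
    return [bytearray(1024) for _ in range(n)]


_, seconds, peak = measure(allocate, 1000)
print(f"{seconds:.4f} s, peak {peak / 1e6:.2f} MB")
```

Tracing allocations slows the code down, so it may be cleaner to time one run without tracing and measure memory in a separate run.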


# Implement Checkpointing for an MLP

In [11]:
import numpy as np

In [12]:
!wget -O mnist.npz https://s3.amazonaws.com/img-datasets/mnist.npz

--2022-11-19 17:08:24--  https://s3.amazonaws.com/img-datasets/mnist.npz
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.204.56, 52.217.201.104, 52.217.198.184, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.204.56|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11490434 (11M) [application/octet-stream]
Saving to: ‘mnist.npz’


2022-11-19 17:08:27 (5.97 MB/s) - ‘mnist.npz’ saved [11490434/11490434]



In [13]:
# Let's read the MNIST dataset

def load_mnist(path='mnist.npz'):
    with np.load(path) as f:
        x_train, _y_train = f['x_train'], f['y_train']
        x_test, _y_test = f['x_test'], f['y_test']

    # Flatten each 28x28 image into a 784-vector and scale pixels to [0, 1]
    x_train = x_train.reshape(-1, 28 * 28) / 255.
    x_test = x_test.reshape(-1, 28 * 28) / 255.

    # One-hot encode the labels: row i has a 1 in column _y_train[i]
    y_train = np.zeros((_y_train.shape[0], 10))
    y_train[np.arange(_y_train.shape[0]), _y_train] = 1

    y_test = np.zeros((_y_test.shape[0], 10))
    y_test[np.arange(_y_test.shape[0]), _y_test] = 1

    return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test) = load_mnist()
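
The one-hot encoding inside `load_mnist` relies on NumPy's integer array indexing. A minimal standalone sketch of the same idea, using three made-up labels instead of the real dataset:

```python
import numpy as np

labels = np.array([3, 0, 9])             # hypothetical class labels
one_hot = np.zeros((labels.shape[0], 10))
# Row i gets a 1 in column labels[i]; both index arrays are paired element-wise.
one_hot[np.arange(labels.shape[0]), labels] = 1

print(one_hot.argmax(axis=1))            # recovers the original labels
```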

In [14]:
def sigmoid(z):
    return 1.0/(1.0+np.exp(-z))

def sigmoid_prime(z):
    # Derivative of the sigmoid
    return sigmoid(z)*(1-sigmoid(z))

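def softmax(z):
    # Numerically stable: subtracting each column's maximum leaves the result
    # unchanged but keeps np.exp from overflowing. Columns are examples.
    exps = np.exp(z - np.max(z, axis=0, keepdims=True))
    return exps / np.sum(exps, axis=0, keepdims=True)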
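
To see why the shift inside `softmax` matters, the sketch below compares a naive softmax with a shifted one on deliberately large inputs. The function names `softmax_naive` and `softmax_shifted` are hypothetical, and the shifted variant here subtracts the column-wise maximum, which is one common choice.

```python
import numpy as np


def softmax_naive(z):
    exps = np.exp(z)
    return exps / np.sum(exps, axis=0)


def softmax_shifted(z):
    # Subtracting the column-wise maximum does not change the result,
    # because the common factor exp(-max) cancels in the ratio.
    exps = np.exp(z - np.max(z, axis=0, keepdims=True))
    return exps / np.sum(exps, axis=0, keepdims=True)


z = np.array([[1000.0], [1001.0], [1002.0]])   # one column = one example

with np.errstate(over="ignore", invalid="ignore"):
    naive = softmax_naive(z)

shifted = softmax_shifted(z)
print(np.isnan(naive).any())    # the naive version overflows to inf / inf
print(shifted.sum(axis=0))      # each column of the shifted version sums to 1
```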

In [15]:
class Network(object):
    def __init__(self, sizes, l2=0.0, momentum=0.0):
        # initialize biases and weights with random normal distr.
        # weights are indexed by target node first
        self.num_layers = len(sizes)
        self.sizes = sizes
        self.l2 = l2
        self.momentum = momentum
        self.biases = [np.random.randn(y, 1) for y in sizes[1:]]
        self.weights = [np.random.randn(y, x)
                        for x, y in zip(sizes[:-1], sizes[1:])]
        self.weights_momentum = [np.zeros_like(x) for x in self.weights]
        self.biases_momentum = [np.zeros_like(x) for x in self.biases]
    
    def activation(self, z, l):
        if l == (self.num_layers - 1):
            return softmax(z)

        return sigmoid(z)
    
    def feedforward(self, a):
        # Run the network on a batch
        a = a.T
        for l, b, w in zip(range(1, self.num_layers), self.biases, self.weights):
            a = self.activation(np.matmul(w, a)+b, l)
        return a
    
    def update_mini_batch(self, mini_batch, eta):
        # Update the network's weights and biases by applying a single step
        # of gradient descent, using backpropagation to compute the gradient.
        # mini_batch is a pair (inputs, targets) with one example per row,
        # as in the TensorFlow API. eta is the learning rate.
        nabla_b, nabla_w = self.backprop(mini_batch[0].T, mini_batch[1].T)

        ### Momentum equation for parameter p
        ### m_(t+1) = lambda_momentum * m_t - eta * gradient(p_t)
        ### p_(t+1) = p_t + m_(t+1)
        self.weights_momentum = [(self.momentum*wm)-(eta/len(mini_batch[0]))*nw
                                 for wm, nw in zip(self.weights_momentum, nabla_w)]
        self.biases_momentum = [(self.momentum*bm)-(eta/len(mini_batch[0]))*nb
                                for bm, nb in zip(self.biases_momentum, nabla_b)]

        self.weights = [w+wm for w, wm in zip(self.weights, self.weights_momentum)]
        self.biases = [b+bm for b, bm in zip(self.biases, self.biases_momentum)]
        
    def backprop(self, x, y):
        g = x
        gs = [g] # list to store all the gs, layer by layer
        fs = [] # list to store all the fs, layer by layer
        for l, b, w in zip(range(1, self.num_layers), self.biases, self.weights):
            f = np.dot(w, g)+b
            fs.append(f)
            g = self.activation(f, l)
            gs.append(g)
            
        dLdf = self.cost_derivative(gs[-1], y)
        dLdfs = []
        for l,w,g in reversed(list(zip(range(1, self.num_layers), self.weights, gs[1:]))):
            if l < (self.num_layers - 1):
                dLdf = np.multiply(dLdg, np.multiply(g,1-g))
            dLdfs.append(dLdf)
            dLdg = np.matmul(w.T, dLdf)
        
        dLdWs = [np.matmul(dLdf,g.T) for dLdf,g in zip(reversed(dLdfs),gs[:-1])] 
        dLdBs = [np.sum(dLdf,axis=1).reshape(dLdf.shape[0],1) for dLdf in reversed(dLdfs)] 

        dLdWs = [dLdW + (self.l2 * w) for dLdW, w in zip(dLdWs, self.weights)]

        return (dLdBs, dLdWs)

    def evaluate(self, test_data):
        # Count the number of correct answers for test_data
        pred = np.argmax(self.feedforward(test_data[0]),axis=0)
        corr = np.argmax(test_data[1],axis=1).T
        return np.mean(pred==corr)
    
    def cost_derivative(self, output_activations, y):
        return (output_activations-y) 
    
    def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
        x_train, y_train = training_data
        if test_data:
            x_test, y_test = test_data
        for j in range(epochs):
            for i in range(x_train.shape[0] // mini_batch_size):
                x_mini_batch = x_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                y_mini_batch = y_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                self.update_mini_batch((x_mini_batch, y_mini_batch), eta)
            if test_data:
                print("Epoch: {0}, Accuracy: {1}".format(j, self.evaluate((x_test, y_test))))
            else:
                print("Epoch: {0}".format(j))
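
The momentum update used in `update_mini_batch` can be tried in isolation on a toy problem. The sketch below applies the same two equations to a single scalar parameter; the function name `momentum_step`, the quadratic objective and the hyperparameter values are assumptions chosen only for illustration.

```python
def momentum_step(p, m, grad, eta, lam):
    """One update: m <- lam * m - eta * grad, then p <- p + m."""
    m = lam * m - eta * grad
    return p + m, m


# Minimise f(p) = p**2 / 2, whose gradient is simply p (a toy objective).
p, m = 1.0, 0.0
for _ in range(200):
    p, m = momentum_step(p, m, grad=p, eta=0.1, lam=0.7)

print(abs(p))   # close to zero after enough steps
```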

In [16]:
network = Network([784,30,10], l2=0.001, momentum=0.7)
%time network.SGD((x_train, y_train), epochs=20, mini_batch_size=100, eta=3., test_data=(x_test, y_test))

Epoch: 0, Accuracy: 0.9063
Epoch: 1, Accuracy: 0.9186
Epoch: 2, Accuracy: 0.9283
Epoch: 3, Accuracy: 0.9321
Epoch: 4, Accuracy: 0.9351
Epoch: 5, Accuracy: 0.9447
Epoch: 6, Accuracy: 0.9473
Epoch: 7, Accuracy: 0.9397
Epoch: 8, Accuracy: 0.9421
Epoch: 9, Accuracy: 0.9441
Epoch: 10, Accuracy: 0.9393
Epoch: 11, Accuracy: 0.9452
Epoch: 12, Accuracy: 0.9396
Epoch: 13, Accuracy: 0.9431
Epoch: 14, Accuracy: 0.9388
Epoch: 15, Accuracy: 0.9499
Epoch: 16, Accuracy: 0.9451
Epoch: 17, Accuracy: 0.9506
Epoch: 18, Accuracy: 0.9492
Epoch: 19, Accuracy: 0.946
CPU times: user 33.3 s, sys: 14.2 s, total: 47.5 s
Wall time: 27.2 s


In [22]:
class NetworkWithCheckpointing(Network):
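    def __init__(self, sizes, checkpoint_every_nth_layer: int = 1, *args, **kwargs):
        super().__init__(sizes, *args, **kwargs)
        # The frequency is used as a range() step and as a modulus in delete(),
        # so it must be a positive integer.
        if checkpoint_every_nth_layer < 1:
            raise ValueError("checkpoint_every_nth_layer must be at least 1")
        self.checkpoint_freq = checkpoint_every_nth_layer
        self.layers_nr = len(sizes)  # same value as self.num_layers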

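    def get(self, d: dict, idx: int):
        # Return the activation of layer idx. If it is not stored, recompute it
        # from the closest stored activation below it and cache the result in d.
        if idx in d:
            return d[idx]
        d[idx] = sigmoid(np.matmul(self.weights[idx - 1], self.get(d, idx - 1)) + self.biases[idx - 1])
        return d[idx]

    def delete(self, d: dict, idx: int):
        # Free a recomputed activation; checkpointed layers stay in memory.
        if idx % self.checkpoint_freq != 0 and idx in d:
            del d[idx]

    def forward_between_checkpoints(self, a, start, end):
        # Propagate the activation a of layer start up to layer end. Hidden
        # layers use the sigmoid; the output layer uses softmax.
        is_last = (end == self.layers_nr - 1)
        if is_last:
            end -= 1
        for b, w in zip(self.biases[start:end], self.weights[start:end]):
            a = sigmoid(np.dot(w, a) + b)
        if is_last:
            return softmax(np.matmul(self.weights[-1], a) + self.biases[-1])

        return a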

    def backprop(self, x, y):
        # Forward pass: keep only the input, every checkpoint_freq-th activation
        # and the network output.
        gs = {0: x}
        g = x
        for i in range(0, self.layers_nr - 1, self.checkpoint_freq):
            end = min(self.layers_nr - 1, i + self.checkpoint_freq)
            g = self.forward_between_checkpoints(g, i, end)
            gs[end] = g

        # Backward pass: the weight gradient of layer idx is computed in the same
        # iteration as dLdf, so every missing activation is recomputed only once.
        dLdf = self.cost_derivative(gs[self.layers_nr - 1], y)
        dLdfs = []
        dLdWs = []
        for idx in range(self.num_layers - 1, 0, -1):
            w = self.weights[idx - 1]
            if idx < (self.num_layers - 1):
                g = self.get(gs, idx)
                dLdf = np.multiply(dLdg, np.multiply(g, 1 - g))
            dLdfs.append(dLdf)
            # get() recomputes the activations between the checkpoint below and idx - 1
            dLdWs.append(np.matmul(dLdf, self.get(gs, idx - 1).T))
            dLdg = np.matmul(w.T, dLdf)
            # The activation of layer idx is no longer needed (checkpoints are kept)
            self.delete(gs, idx)

        # Both lists were built from the last layer to the first, so reverse them
        dLdWs = [dLdW + (self.l2 * w) for dLdW, w in zip(reversed(dLdWs), self.weights)]
        dLdBs = [np.sum(dLdf, axis=1).reshape(dLdf.shape[0], 1) for dLdf in reversed(dLdfs)]
        return (dLdBs, dLdWs)
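
A useful sanity check for any checkpointing implementation is that it returns the same gradients as plain backpropagation while holding fewer activations at once. The sketch below does this on a tiny standalone network and is independent of the classes above. It makes simplifying assumptions that differ from the notebook: every layer uses the sigmoid, the loss is the squared error, and the function `gradients` and all sizes are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gradients(weights, biases, x, y, every):
    """Backprop for an all-sigmoid MLP with loss 0.5 * ||output - y||^2.

    Only every `every`-th activation (plus input and output) is kept after
    the forward pass; every=1 is plain backpropagation.
    Returns (dW, dB, peak number of activations held at once).
    """
    n = len(weights)                      # activations are numbered 0..n
    stored = {0: x}
    a = x
    for i in range(1, n + 1):
        a = sigmoid(weights[i - 1] @ a + biases[i - 1])
        if i % every == 0 or i == n:
            stored[i] = a
    peak = len(stored)

    def get(i):
        # Recompute a missing activation from the nearest stored one below it.
        if i not in stored:
            stored[i] = sigmoid(weights[i - 1] @ get(i - 1) + biases[i - 1])
        return stored[i]

    dW, dB = [None] * n, [None] * n
    delta = None
    for i in range(n, 0, -1):
        a_i, a_prev = get(i), get(i - 1)
        peak = max(peak, len(stored))
        dL_da = (a_i - y) if i == n else weights[i].T @ delta
        delta = dL_da * a_i * (1.0 - a_i)
        dW[i - 1] = delta @ a_prev.T
        dB[i - 1] = delta.sum(axis=1, keepdims=True)
        if i % every != 0:
            del stored[i]                 # recomputed values are freed after use
    return dW, dB, peak


rng = np.random.default_rng(0)
sizes = [5, 4, 4, 4, 4, 4, 3]             # hypothetical tiny network
weights = [rng.standard_normal((o, i)) for i, o in zip(sizes[:-1], sizes[1:])]
biases = [rng.standard_normal((o, 1)) for o in sizes[1:]]
x = rng.standard_normal((5, 8))
y = rng.standard_normal((3, 8))

dW_plain, dB_plain, peak_plain = gradients(weights, biases, x, y, every=1)
dW_ckpt, dB_ckpt, peak_ckpt = gradients(weights, biases, x, y, every=3)

same = all(np.allclose(p, c) for p, c in zip(dW_plain + dB_plain, dW_ckpt + dB_ckpt))
print(same, peak_plain, peak_ckpt)
```

Counting stored arrays is only a proxy for memory use, but it makes the trade-off visible without any profiling tools.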

        


In [23]:
net_check = NetworkWithCheckpointing([784,30,10], 2)
%time net_check.SGD((x_train, y_train), epochs=20, mini_batch_size=100, eta=3., test_data=(x_test, y_test))

Epoch: 0, Accuracy: 0.8923
Epoch: 1, Accuracy: 0.9132
Epoch: 2, Accuracy: 0.9227
Epoch: 3, Accuracy: 0.9272
Epoch: 4, Accuracy: 0.9341
Epoch: 5, Accuracy: 0.9347
Epoch: 6, Accuracy: 0.9362
Epoch: 7, Accuracy: 0.9398
Epoch: 8, Accuracy: 0.9418
Epoch: 9, Accuracy: 0.9419
Epoch: 10, Accuracy: 0.9429
Epoch: 11, Accuracy: 0.9425
Epoch: 12, Accuracy: 0.944
Epoch: 13, Accuracy: 0.9449
Epoch: 14, Accuracy: 0.9452


KeyboardInterrupt: ignored

In [21]:
network = Network([784,30,30,10], l2=0.001, momentum=0.6)
%time network.SGD((x_train, y_train), epochs=20, mini_batch_size=100, eta=3., test_data=(x_test, y_test))

network_check = NetworkWithCheckpointing([784,30,30,10], 2, l2=0.001, momentum=0.6)
%time network_check.SGD((x_train, y_train), epochs=20, mini_batch_size=100, eta=3., test_data=(x_test, y_test))

Epoch: 0, Accuracy: 0.8771
Epoch: 1, Accuracy: 0.9167
Epoch: 2, Accuracy: 0.9125
Epoch: 3, Accuracy: 0.928
Epoch: 4, Accuracy: 0.9285
Epoch: 5, Accuracy: 0.9352
Epoch: 6, Accuracy: 0.9345
Epoch: 7, Accuracy: 0.9361
Epoch: 8, Accuracy: 0.9377
Epoch: 9, Accuracy: 0.9408
Epoch: 10, Accuracy: 0.9339
Epoch: 11, Accuracy: 0.9371
Epoch: 12, Accuracy: 0.9423
Epoch: 13, Accuracy: 0.9421
Epoch: 14, Accuracy: 0.942
Epoch: 15, Accuracy: 0.939
Epoch: 16, Accuracy: 0.9396
Epoch: 17, Accuracy: 0.9423
Epoch: 18, Accuracy: 0.9423
Epoch: 19, Accuracy: 0.9456
CPU times: user 37.6 s, sys: 14.7 s, total: 52.3 s
Wall time: 27.7 s


ValueError: ignored