# Checkpointing

Your task is to implement checkpointing for an MLP using NumPy.

You are free to use the implementation of an MLP and of the backpropagation algorithm that you developed during the lab sessions.

The key takeaway from this task is that checkpointing lets us trade extra computation for a smaller memory footprint in the backward pass. Parts of the forward pass are recomputed during the backward pass, and memory is often a major bottleneck when training large networks. In plain English: we accept a somewhat longer training time to save some of our GPU's precious memory.

## What is checkpointing?

The idea of checkpointing is to store the forward result (activation) of only every $n$-th layer (e.g. every 2nd layer), instead of every layer's activation as in plain backpropagation. These checkpoints are then used to recompute the missing parts of the forward pass during the backward pass. Checkpointed activations are kept in memory after the forward pass, and each remaining activation is recomputed at most once. After being recomputed, a non-checkpointed activation is kept in memory until it is no longer required.
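The "recomputed at most once" property can be seen with a small simulation that only tracks which activation indices are held in memory during a backward pass. This is a sketch: the function `checkpoint_schedule` is hypothetical and not part of the lab code. It does follow the same indexing as a `gs` dictionary, where index 0 is the input, and the same recomputation rule as the checkpointed `backprop` in this notebook.

```python
def checkpoint_schedule(num_layers, c):
    """Simulate which activations are held in memory during checkpointed backprop.

    Activations are indexed 0..num_layers-1 (index 0 is the input batch).
    Returns the peak number of activations held at once and, for every index,
    how many times that activation had to be recomputed.
    """
    last = num_layers - 1
    # Forward pass: keep every c-th activation plus the network output.
    stored = {i for i in range(last) if i % c == 0}
    stored.add(last)
    peak = len(stored)
    recomputed = {i: 0 for i in range(num_layers)}
    # Backward pass: step `index` needs activation index+1 (for dL/df)
    # and activation index (for dL/dW).
    for index in reversed(range(last)):
        stored.discard(index + 1)  # activation index+1 is no longer needed
        if index not in stored:
            start = index - index % c  # nearest checkpoint at or below `index`
            for j in range(start + 1, index + 1):
                stored.add(j)
                recomputed[j] += 1
            peak = max(peak, len(stored))
    return peak, recomputed


# 12 activations correspond to the [784] + [30]*10 + [10] network used below.
peak, recomputed = checkpoint_schedule(12, 4)
print(peak)  # 5
print(sorted(j for j, n in recomputed.items() if n))  # [1, 2, 3, 5, 6, 7, 9, 10]
```

With a distance of 1 (every activation is a checkpoint), the same function reports a peak of 12 activations and no recomputation. This matches plain backpropagation.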

# What should be done

1. Take the implementation of an MLP trained with backpropagation. Analyze how much memory the algorithm uses as a function of the number of hidden layers.

2. Implement a class `NetworkWithCheckpointing` that inherits from the `Network` class defined during the lab sessions by:
    a) implementing a method `forward_between_checkpoints` that recomputes the forward pass of the network starting from one of the checkpointed layers;
    b) overriding the method `backprop` so that it stores only the checkpointed layers, computes the other activations with the `forward_between_checkpoints` method, and keeps them in memory only until they are no longer needed.

3. Train your network with checkpointing on MNIST. Compare its running time and memory usage with those of the network without checkpointing.


# Implement Checkpointing for a MLP

```python
import random
import time

import numpy as np
# torchvision is not needed: MNIST is loaded from mnist.npz below.
```

```
!wget -O mnist.npz https://s3.amazonaws.com/img-datasets/mnist.npz
```

```
--2022-11-16 22:40:23--  https://s3.amazonaws.com/img-datasets/mnist.npz
Resolving s3.amazonaws.com (s3.amazonaws.com)... 54.231.199.112
Connecting to s3.amazonaws.com (s3.amazonaws.com)|54.231.199.112|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11490434 (11M) [application/octet-stream]
Saving to: ‘mnist.npz’

2022-11-16 22:40:24 (35.6 MB/s) - ‘mnist.npz’ saved [11490434/11490434]
```
```python
# Read the MNIST dataset: flatten each image to 784 values in [0, 1]
# and one-hot encode the labels (one example per row).

def load_mnist(path='mnist.npz'):
    with np.load(path) as f:
        x_train, _y_train = f['x_train'], f['y_train']
        x_test, _y_test = f['x_test'], f['y_test']

    x_train = x_train.reshape(-1, 28 * 28) / 255.
    x_test = x_test.reshape(-1, 28 * 28) / 255.

    y_train = np.zeros((_y_train.shape[0], 10))
    y_train[np.arange(_y_train.shape[0]), _y_train] = 1

    y_test = np.zeros((_y_test.shape[0], 10))
    y_test[np.arange(_y_test.shape[0]), _y_test] = 1

    return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test) = load_mnist()
```
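The one-hot encoding in `load_mnist` relies on NumPy fancy indexing. The index arrays `np.arange(n)` and the label vector are paired element-wise, so row `i` receives a 1 in column `labels[i]`. Here is a tiny illustration on hypothetical labels:

```python
import numpy as np

labels = np.array([3, 0, 9])  # hypothetical digit labels
one_hot = np.zeros((labels.shape[0], 10))
# Pairs (0, 3), (1, 0), (2, 9) are set to 1.
one_hot[np.arange(labels.shape[0]), labels] = 1
print(one_hot.argmax(axis=1))  # [3 0 9]
```

Note that the arrays returned by `load_mnist` store one example per row, while the network below works with one example per column. This is why `update_mini_batch` transposes the mini-batch before calling `backprop`.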

```python
def sigmoid(z):
    return 1.0/(1.0+np.exp(-z))

def sigmoid_prime(z):
    # Derivative of the sigmoid
    return sigmoid(z)*(1-sigmoid(z))

def softmax(x):
    # Normalize over axis 0 (examples are columns); subtracting the
    # column maximum avoids overflow in np.exp.
    max_x = np.max(x, axis=0)
    exp_x = np.exp(x-max_x)
    return exp_x / exp_x.sum(axis=0)
```
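Here is a quick numerical sanity check of these helpers, written as a sketch. A central finite difference should agree with `sigmoid_prime`, and subtracting the column maximum keeps `softmax` finite even for very large logits. The step size `eps` and the test inputs are arbitrary choices, and the helpers are redefined so that the snippet runs on its own.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sigmoid_prime(z):
    return sigmoid(z) * (1 - sigmoid(z))

def softmax(x):
    # Subtract the per-column maximum so np.exp never overflows.
    exp_x = np.exp(x - np.max(x, axis=0))
    return exp_x / exp_x.sum(axis=0)

z = np.linspace(-5, 5, 11)
eps = 1e-6
numeric = (sigmoid(z + eps) - sigmoid(z - eps)) / (2 * eps)  # central difference
print(np.max(np.abs(numeric - sigmoid_prime(z))))  # small, well below 1e-8

logits = np.array([[1000.0, 0.0], [1001.0, 1.0]])  # 2 classes x 2 examples (columns)
print(softmax(logits).sum(axis=0))  # each column sums to 1, no overflow
```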

```python
class Network(object):
    def __init__(self, sizes):
        # Initialize biases and weights from a standard normal distribution.
        # Weights are indexed by target node first: weights[l] has shape (sizes[l+1], sizes[l]).
        self.num_layers = len(sizes)
        self.sizes = sizes
        self.biases = [np.random.randn(y, 1) for y in sizes[1:]]
        self.biases_velocity = [np.zeros_like(b) for b in self.biases]
        self.weights = [np.random.randn(y, x)
                        for x, y in zip(sizes[:-1], sizes[1:])]
        self.weights_velocity = [np.zeros_like(w) for w in self.weights]

    def feedforward(self, a):
        # Run the network on a batch given with one example per row;
        # returns the output with one example per column.
        a = a.T
        for b, w in zip(self.biases[:-1], self.weights[:-1]):
            a = sigmoid(np.matmul(w, a)+b)

        a = softmax(np.matmul(self.weights[-1], a) + self.biases[-1])
        return a

    def update_mini_batch(self, mini_batch, eta, momentum=0, l2=0):
        # Apply a single step of gradient descent (with momentum and L2 weight decay),
        # using backpropagation to compute the gradient.
        # mini_batch is a tuple (x, y) with one example per row; eta is the learning rate.
        nabla_b, nabla_w = self.backprop(mini_batch[0].T, mini_batch[1].T)

        self.weights_velocity = [momentum*w_v + (eta/len(mini_batch[0]))*nw
                                 for w_v, nw in zip(self.weights_velocity, nabla_w)]
        self.weights = [(1-l2)*w - w_v for w, w_v in zip(self.weights, self.weights_velocity)]

        self.biases_velocity = [momentum*b_v + (eta/len(mini_batch[0]))*nb
                                for b_v, nb in zip(self.biases_velocity, nabla_b)]
        self.biases = [b - b_v for b, b_v in zip(self.biases, self.biases_velocity)]

    def backprop(self, x, y):
        # For a mini-batch (x, y) with one example per column, return a pair of lists:
        # the first contains gradients over biases, the second over weights.
        # All layers, including the output layer, use sigmoid here, and
        # cost_derivative corresponds to the quadratic cost.
        dLdWs = [np.zeros_like(p) for p in self.weights]
        dLdBs = [np.zeros_like(p) for p in self.biases]

        # Forward pass: every activation is stored, so memory grows linearly with depth.
        g = x
        gs = [g]  # list to store all the activations, layer by layer
        for b, w in zip(self.biases, self.weights):
            f = np.matmul(w, g)+b
            g = sigmoid(f)
            gs.append(g)

        # Backward pass: gradients w.r.t. pre-activations (dLdf) and parameters in one sweep.
        dLdg = self.cost_derivative(gs[-1], y)
        for index in reversed(range(1, len(gs))):
            dLdf = np.multiply(dLdg, np.multiply(gs[index], 1-gs[index]))
            dLdWs[index-1] = np.matmul(dLdf, gs[index-1].T)
            dLdBs[index-1] = np.sum(dLdf, axis=1).reshape(dLdf.shape[0], 1)
            dLdg = np.matmul(self.weights[index-1].T, dLdf)

        return (dLdBs, dLdWs)

    def evaluate(self, test_data):
        # Fraction of correct answers on test_data = (x, y) with one example per row.
        pred = np.argmax(self.feedforward(test_data[0]), axis=0)
        corr = np.argmax(test_data[1], axis=1)
        return np.mean(pred == corr)

    def cost_derivative(self, output_activations, y):
        return (output_activations-y)

    def SGD(self, training_data, epochs, mini_batch_size, eta, momentum=0, l2=0, test_data=None):
        x_train, y_train = training_data
        if test_data:
            x_test, y_test = test_data
        for j in range(epochs):
            for i in range(x_train.shape[0] // mini_batch_size):
                x_mini_batch = x_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                y_mini_batch = y_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                self.update_mini_batch((x_mini_batch, y_mini_batch), eta, momentum, l2)
            if test_data:
                print("Epoch: {0}, Accuracy: {1}".format(j, self.evaluate((x_test, y_test))))
            else:
                print("Epoch: {0}".format(j))

start = time.time()
network = Network([784] + [30]*10 + [10])
network.SGD((x_train, y_train), epochs=100, mini_batch_size=100, eta=3.0, momentum=0, l2=0, test_data=(x_test, y_test))
end = time.time()
print(f"execution time = {end - start}")
```

```
Epoch: 0, Accuracy: 0.5021
Epoch: 1, Accuracy: 0.723
Epoch: 2, Accuracy: 0.8031
Epoch: 3, Accuracy: 0.8497
Epoch: 4, Accuracy: 0.8859
Epoch: 5, Accuracy: 0.8906
Epoch: 6, Accuracy: 0.9023
Epoch: 7, Accuracy: 0.9079
Epoch: 8, Accuracy: 0.9128
Epoch: 9, Accuracy: 0.9136
Epoch: 10, Accuracy: 0.9141
Epoch: 11, Accuracy: 0.9111
Epoch: 12, Accuracy: 0.9217
Epoch: 13, Accuracy: 0.9173
Epoch: 14, Accuracy: 0.9237
Epoch: 15, Accuracy: 0.926
Epoch: 16, Accuracy: 0.9259
Epoch: 17, Accuracy: 0.9263
Epoch: 18, Accuracy: 0.928
Epoch: 19, Accuracy: 0.9284
Epoch: 20, Accuracy: 0.9315
Epoch: 21, Accuracy: 0.9315
Epoch: 22, Accuracy: 0.9329
Epoch: 23, Accuracy: 0.9319
Epoch: 24, Accuracy: 0.9324
Epoch: 25, Accuracy: 0.9327
Epoch: 26, Accuracy: 0.9329
Epoch: 27, Accuracy: 0.9344
Epoch: 28, Accuracy: 0.9344
Epoch: 29, Accuracy: 0.9358
Epoch: 30, Accuracy: 0.9357
Epoch: 31, Accuracy: 0.9358
Epoch: 32, Accuracy: 0.9347
Epoch: 33, Accuracy: 0.9379
Epoch: 34, Accuracy: 0.9348
Epoch: 35, Accuracy: 0.9359
Epoch
```
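`backprop` trains the output layer as a sigmoid with the quadratic cost (`cost_derivative` returns `a - y`), while `feedforward` applies `softmax` to the last layer. Both functions preserve the ordering of the entries within a column, so `evaluate` predicts the same class either way. The sketch below checks this on hypothetical random pre-activations:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(x):
    exp_x = np.exp(x - np.max(x, axis=0))
    return exp_x / exp_x.sum(axis=0)

rng = np.random.default_rng(0)
z = rng.standard_normal((10, 5))  # hypothetical pre-activations: 10 classes x 5 examples
same = np.array_equal(np.argmax(sigmoid(z), axis=0), np.argmax(softmax(z), axis=0))
print(same)  # True
```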

```python
class NetworkWithCheckpointing(Network):

    def __init__(self, sizes, checkpoint_every_nth_layer: int = 0, *args, **kwargs):
        super().__init__(sizes, *args, **kwargs)
        # A non-positive distance falls back to 1: every activation is a checkpoint,
        # which is equivalent to plain backpropagation.
        if checkpoint_every_nth_layer <= 0:
            self.checkpoint_every_nth_layer = 1
        else:
            self.checkpoint_every_nth_layer = checkpoint_every_nth_layer

    def forward_between_checkpoints(self, checkpoint_idx_start, layer_idx_end, gs):
        # Recompute activations checkpoint_idx_start+1 .. layer_idx_end-1 in place,
        # starting from the checkpointed activation gs[checkpoint_idx_start].
        for index in range(checkpoint_idx_start+1, layer_idx_end):
            gs[index] = sigmoid(np.dot(self.weights[index-1], gs[index-1])+self.biases[index-1])

    def backprop(self, x, y):
        # Same contract as Network.backprop, but after the forward pass only every
        # checkpoint_distance-th activation (plus the output) is kept in memory.
        dLdWs = [np.zeros_like(p) for p in self.weights]
        dLdBs = [np.zeros_like(p) for p in self.biases]

        checkpoint_distance = self.checkpoint_every_nth_layer

        # Forward pass: store only the checkpointed activations.
        g = x
        gs = {}  # dict: layer index -> activation currently held in memory
        for index in range(self.num_layers-1):
            if index % checkpoint_distance == 0:
                gs[index] = g
            f = np.matmul(self.weights[index], g)+self.biases[index]
            g = sigmoid(f)
        gs[self.num_layers-1] = g

        # Backward pass. No list of dLdf values is kept: storing them would cost
        # as much memory as the activations that checkpointing avoids storing.
        dLdg = self.cost_derivative(gs[self.num_layers-1], y)
        for index in reversed(range(self.num_layers-1)):
            dLdf = np.multiply(dLdg, np.multiply(gs[index+1], 1-gs[index+1]))
            del gs[index+1]  # activation index+1 is no longer needed

            if index not in gs:
                # Recompute the whole segment above the nearest checkpoint once;
                # the following iterations reuse the recomputed activations.
                self.forward_between_checkpoints(
                    index-(index % checkpoint_distance),
                    index+1,
                    gs)

            dLdWs[index] = np.matmul(dLdf, gs[index].T)
            dLdBs[index] = np.sum(dLdf, axis=1).reshape(dLdf.shape[0], 1)

            dLdg = np.matmul(self.weights[index].T, dLdf)

        return (dLdBs, dLdWs)


start = time.time()
network = NetworkWithCheckpointing([784] + [30]*10 + [10], checkpoint_every_nth_layer=4)
network.SGD((x_train, y_train), epochs=100, mini_batch_size=100, eta=0.01, momentum=0.9, l2=0.00001, test_data=(x_test, y_test))
end = time.time()
print(f"execution time = {end - start}")
```

```
Epoch: 0, Accuracy: 0.0939
Epoch: 1, Accuracy: 0.0958
Epoch: 2, Accuracy: 0.0956
Epoch: 3, Accuracy: 0.1337
Epoch: 4, Accuracy: 0.1425
Epoch: 5, Accuracy: 0.1538
Epoch: 6, Accuracy: 0.1673
Epoch: 7, Accuracy: 0.1804
Epoch: 8, Accuracy: 0.2059
Epoch: 9, Accuracy: 0.2231
Epoch: 10, Accuracy: 0.2336
Epoch: 11, Accuracy: 0.2482
Epoch: 12, Accuracy: 0.2715
Epoch: 13, Accuracy: 0.2869
Epoch: 14, Accuracy: 0.3028
Epoch: 15, Accuracy: 0.3235
Epoch: 16, Accuracy: 0.344
Epoch: 17, Accuracy: 0.3624
Epoch: 18, Accuracy: 0.3813
Epoch: 19, Accuracy: 0.3912
Epoch: 20, Accuracy: 0.4043
Epoch: 21, Accuracy: 0.4239
Epoch: 22, Accuracy: 0.443
Epoch: 23, Accuracy: 0.4574
Epoch: 24, Accuracy: 0.4762
Epoch: 25, Accuracy: 0.4908
Epoch: 26, Accuracy: 0.5084
Epoch: 27, Accuracy: 0.5253
Epoch: 28, Accuracy: 0.5405
Epoch: 29, Accuracy: 0.5548
Epoch: 30, Accuracy: 0.5703
Epoch: 31, Accuracy: 0.585
Epoch: 32, Accuracy: 0.5982
Epoch: 33, Accuracy: 0.6131
Epoch: 34, Accuracy: 0.6276
Epoch: 35, Accuracy: 0.6395
Epoch
```
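Checkpointing only changes *when* activations are computed, so it must return the same gradients as plain backpropagation. The standalone sketch below checks this. It re-implements both passes as free functions on a small random network with the same sigmoid layers and quadratic cost as the classes above. `init_params`, `plain_backprop` and `checkpointed_backprop` are hypothetical names and are not part of the lab code.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def init_params(sizes, seed=0):
    rng = np.random.default_rng(seed)
    weights = [rng.standard_normal((y, x)) for x, y in zip(sizes[:-1], sizes[1:])]
    biases = [rng.standard_normal((y, 1)) for y in sizes[1:]]
    return weights, biases


def plain_backprop(weights, biases, x, y):
    # Store every activation during the forward pass.
    gs = [x]
    for w, b in zip(weights, biases):
        gs.append(sigmoid(w @ gs[-1] + b))
    dLdg = gs[-1] - y  # quadratic cost
    dWs, dBs = [None] * len(weights), [None] * len(weights)
    for i in reversed(range(len(weights))):
        dLdf = dLdg * gs[i + 1] * (1 - gs[i + 1])
        dWs[i] = dLdf @ gs[i].T
        dBs[i] = dLdf.sum(axis=1, keepdims=True)
        dLdg = weights[i].T @ dLdf
    return dBs, dWs


def checkpointed_backprop(weights, biases, x, y, c):
    n = len(weights)  # activations are indexed 0..n
    gs, g = {}, x
    for i in range(n):
        if i % c == 0:
            gs[i] = g  # checkpoint
        g = sigmoid(weights[i] @ g + biases[i])
    gs[n] = g  # the output is always kept
    dLdg = gs[n] - y
    dWs, dBs = [None] * n, [None] * n
    for i in reversed(range(n)):
        dLdf = dLdg * gs[i + 1] * (1 - gs[i + 1])
        del gs[i + 1]
        if i not in gs:
            # Recompute activations start+1..i from the nearest checkpoint.
            start = i - i % c
            for j in range(start + 1, i + 1):
                gs[j] = sigmoid(weights[j - 1] @ gs[j - 1] + biases[j - 1])
        dWs[i] = dLdf @ gs[i].T
        dBs[i] = dLdf.sum(axis=1, keepdims=True)
        dLdg = weights[i].T @ dLdf
    return dBs, dWs


sizes = [8] + [5] * 10 + [3]  # small hypothetical network with 10 hidden layers
weights, biases = init_params(sizes)
rng = np.random.default_rng(1)
x = rng.random((8, 4))                        # 4 examples, one per column
y = np.eye(3)[:, rng.integers(0, 3, size=4)]  # one-hot targets, one per column

ref_b, ref_w = plain_backprop(weights, biases, x, y)
for c in (1, 2, 3, 4, 11):
    cb, cw = checkpointed_backprop(weights, biases, x, y, c)
    ok = all(np.allclose(r, g) for r, g in zip(ref_w + ref_b, cw + cb))
    print(f"checkpoint distance {c}: gradients match = {ok}")
```

A distance of 11 leaves only the input as a checkpoint. The whole forward pass is then recomputed once, in a single segment, during the first backward step.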

Notation:

- $h$: number of hidden layers
- $l_s$: size of each hidden layer
- $b_s$: batch size
- $c_d$: checkpoint distance (`checkpoint_every_nth_layer`)

Standard network memory usage. In `backprop` we need (counting floats and ignoring the input and output layers):

- $h \cdot l_s^2$ for `dLdWs`
- $h \cdot l_s$ for `dLdBs`
- $h \cdot l_s \cdot b_s$ for `gs`

Standard network execution time = 294.5 sec  
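The formulas above keep only the leading terms. For the concrete architecture trained here (`[784] + [30]*10 + [10]` with mini-batches of 100), the exact counts can be computed directly. This sketch assumes float64 arrays (8 bytes per entry), which is what `np.random.randn` and the `/ 255.` scaling produce.

```python
sizes = [784] + [30] * 10 + [10]  # the architecture trained above
batch_size = 100

weight_grads = sum(n_in * n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))  # dLdWs
bias_grads = sum(sizes[1:])  # dLdBs
activations = sum(sizes) * batch_size  # gs, including the input batch gs[0]

for name, count in [("dLdWs", weight_grads), ("dLdBs", bias_grads), ("gs", activations)]:
    print(f"{name}: {count} floats = {count * 8} bytes as float64")
```

This gives 31920 floats for `dLdWs`, 310 for `dLdBs` and 109400 for `gs`. Of the activations, 78400 belong to the input batch `gs[0]`, which the checkpointed version also keeps (index 0 is always a checkpoint). For such a shallow and narrow network, the achievable saving on `gs` is therefore limited to the hidden and output activations.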

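Checkpointing network memory usage. In `backprop` we need (same conventions):

- $h \cdot l_s^2$ for `dLdWs`
- $h \cdot l_s$ for `dLdBs`
- $\frac{h \cdot l_s \cdot b_s}{c_d} + l_s \cdot b_s \cdot c_d$ for `gs`: the checkpoints plus at most one recomputed segment of at most $c_d$ activations

The memory needed for `gs` is minimal when $c_d = \sqrt{h}$.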
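This short derivation of the minimum treats $c_d$ as a continuous variable and writes $M(c_d)$ for the memory needed for `gs`:

$$
M(c_d) = \frac{h\,l_s\,b_s}{c_d} + l_s\,b_s\,c_d,
\qquad
\frac{dM}{dc_d} = -\frac{h\,l_s\,b_s}{c_d^{2}} + l_s\,b_s = 0
\;\Longrightarrow\; c_d = \sqrt{h},
$$

$$
\frac{d^{2}M}{dc_d^{2}} = \frac{2\,h\,l_s\,b_s}{c_d^{3}} > 0,
\qquad
M(\sqrt{h}) = 2\,l_s\,b_s\sqrt{h}.
$$

As a worked example with the values used in this notebook ($h = 10$, $l_s = 30$, $b_s = 100$), plain backpropagation stores $h\,l_s\,b_s = 30000$ activation values. With $c_d = 4$, as in the run above, the model gives $7500 + 12000 = 19500$. With $c_d = 3$ it gives $10000 + 9000 = 19000$, and the continuous optimum $c_d = \sqrt{10} \approx 3.16$ gives about $18974$. Under this model, $c_d = 3$ is therefore the best integer choice.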

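Checkpointing network execution time = 405.7 sec, i.e. roughly 38% slower than the standard network. The two runs use different hyperparameters ($\eta = 3.0$ without momentum or L2 for the standard network; $\eta = 0.01$, momentum $0.9$ and L2 $= 10^{-5}$ for the checkpointing network), so their accuracy curves are not directly comparable.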
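The notebook measures only running times. One way to also compare memory usage is Python's `tracemalloc` module, since NumPy reports its array allocations to it. The sketch below measures the peak traced memory of a forward pass that keeps every activation versus one that keeps only every 4th. It covers the forward pass only: the checkpointed backward pass additionally holds at most one recomputed segment. The helper `peak_forward_bytes` is hypothetical, and the weights and the input batch are allocated before tracing starts, so they are not counted.

```python
import tracemalloc

import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def peak_forward_bytes(sizes, batch_size, keep_every):
    """Peak traced memory of a forward pass keeping every `keep_every`-th activation.

    keep_every=1 mimics plain backprop (all activations kept); larger values
    mimic the forward pass of the checkpointed network.
    """
    rng = np.random.default_rng(0)
    weights = [rng.standard_normal((y, x)) for x, y in zip(sizes[:-1], sizes[1:])]
    biases = [rng.standard_normal((y, 1)) for y in sizes[1:]]
    x = rng.random((sizes[0], batch_size))

    tracemalloc.start()
    kept, g = {}, x
    for i, (w, b) in enumerate(zip(weights, biases)):
        if i % keep_every == 0:
            kept[i] = g  # this activation stays alive until the function returns
        g = sigmoid(w @ g + b)
    kept[len(weights)] = g  # the output is always kept
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return peak


sizes = [784] + [30] * 10 + [10]
plain = peak_forward_bytes(sizes, 100, keep_every=1)
ckpt = peak_forward_bytes(sizes, 100, keep_every=4)
print(f"plain: {plain} bytes, checkpointed (every 4th): {ckpt} bytes")
```

To measure the real training code instead, the same `tracemalloc.start()` / `get_traced_memory()` pair could wrap a single call to `backprop` of each network class.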