# Checkpointing

Your task is to implement checkpointing for a MLP using NumPy.

You are free to use the implementation of a MLP and the backpropagation algorithm that you have developed during lab sessions.

The key takeaway from this task is that with checkpointing we can trade off the computational resources needed to compute the forward pass of the network for the memory requirement needed to perform a backward pass in the network, which is often a major bottleneck when training large networks. In plain english, we can slightly increase the time required for training our network to save some of our GPU's precious memory.

## What is checkpointing?

The aim of checkpointing is to save every $n$-th layer's (e.g. every 2-nd layer's) forward result (instead of saving every layer's forward result as in plain backpropagation) and use these checkpoints for recomputing the forward pass of the network upon doing a backward pass. Checkpoint layers are kept in memory after the forward pass, while the remaining activations are recomputed at most once. After being recomputed, the non-checkpoint layers are kept in memory until they are no longer required.

# What should be done

1. Take the implementation a MLP trained with backpropagation. Analyze the algorithm with respect to the memory that is used by the algorithm with respect to the number of hidden layers.

2. Implement a class NetworkWithCheckpointing that inherits from the Network class defined during lab sessions by:
    a) implementing a method `forward_between_checkpoints` that will recompute the forward pass of the network using one of the checkpointed layers
    b) override the method `backprop` to use only checkpointed layers and otherwise compute the activations using `forward_between_checkpoints` method and keep it in memory until no longer needed.

3. Train your network with checkpoinintg on MNIST. Compare running times and memory usage with respect to the network without checkpointing.


# Implement Checkpointing for a MLP

In [90]:
import random
import numpy as np
from torchvision import datasets, transforms

In [91]:
!wget -O mnist.npz https://s3.amazonaws.com/img-datasets/mnist.npz

--2022-11-21 02:30:38--  https://s3.amazonaws.com/img-datasets/mnist.npz
Resolving s3.amazonaws.com (s3.amazonaws.com)... 54.231.202.240, 54.231.202.8, 52.216.108.181, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|54.231.202.240|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11490434 (11M) [application/octet-stream]
Saving to: ‘mnist.npz’


2022-11-21 02:30:38 (83.1 MB/s) - ‘mnist.npz’ saved [11490434/11490434]



In [92]:
# Let's read the mnist dataset

def load_mnist(path='mnist.npz'):
    with np.load(path) as f:
        x_train, _y_train = f['x_train'], f['y_train']
        x_test, _y_test = f['x_test'], f['y_test']

    x_train = x_train.reshape(-1, 28 * 28) / 255.
    x_test = x_test.reshape(-1, 28 * 28) / 255.

    y_train = np.zeros((_y_train.shape[0], 10))
    y_train[np.arange(_y_train.shape[0]), _y_train] = 1

    y_test = np.zeros((_y_test.shape[0], 10))
    y_test[np.arange(_y_test.shape[0]), _y_test] = 1

    return (x_train, y_train), (x_test, y_test)

(x_train, y_train), (x_test, y_test) = load_mnist()

In [93]:
def sigmoid(z):
    return 1.0/(1.0+np.exp(-z))

def sigmoid_prime(z):
    # Derivative of the sigmoid
    return sigmoid(z)*(1-sigmoid(z))

In [94]:
class Network(object):
    def __init__(self, sizes):
        # initialize biases and weights with random normal distr.
        # weights are indexed by target node first
        self.num_layers = len(sizes)
        self.sizes = sizes
        self.biases = [np.random.randn(y, 1) for y in sizes[1:]]
        self.weights = [np.random.randn(y, x) 
                        for x, y in zip(sizes[:-1], sizes[1:])]
    def feedforward(self, a):
        # Run the network on a batch
        a = a.T
        for b, w in zip(self.biases, self.weights):
            a = sigmoid(np.matmul(w, a)+b)
        return a
    
    def update_mini_batch(self, mini_batch, eta):
        # Update networks weights and biases by applying a single step
        # of gradient descent using backpropagation to compute the gradient.
        # The gradient is computed for a mini_batch which is as in tensorflow API.
        # eta is the learning rate      
        nabla_b, nabla_w = self.backprop(mini_batch[0].T,mini_batch[1].T)
            
        self.weights = [w-(eta/len(mini_batch[0]))*nw 
                        for w, nw in zip(self.weights, nabla_w)]
        self.biases = [b-(eta/len(mini_batch[0]))*nb 
                       for b, nb in zip(self.biases, nabla_b)]
        
    def backprop(self, x, y):
        # For a single input (x,y) return a pair of lists.
        # First contains gradients over biases, second over weights.
        g = x
        gs = [g] # list to store all the gs, layer by layer
        fs = [] # list to store all the fs, layer by layer
        for b, w in zip(self.biases, self.weights):
            f = np.dot(w, g)+b
            fs.append(f)
            g = sigmoid(f)
            gs.append(g)
        # backward pass <- both steps at once
        dLdg = self.cost_derivative(gs[-1], y)
        dLdfs = []
        for w,g in reversed(list(zip(self.weights,gs[1:]))):
            dLdf = np.multiply(dLdg,np.multiply(g,1-g))
            dLdfs.append(dLdf)
            dLdg = np.matmul(w.T, dLdf)
        
        dLdWs = [np.matmul(dLdf,g.T) for dLdf,g in zip(reversed(dLdfs),gs[:-1])] 
        dLdBs = [np.sum(dLdf,axis=1).reshape(dLdf.shape[0],1) for dLdf in reversed(dLdfs)] 
        return (dLdBs,dLdWs)

    def evaluate(self, test_data):
        # Count the number of correct answers for test_data
        pred = np.argmax(self.feedforward(test_data[0]),axis=0)
        corr = np.argmax(test_data[1],axis=1).T
        return np.mean(pred==corr)
    
    def cost_derivative(self, output_activations, y):
        #logloss
        return (1-y)/(1-output_activations)-y/output_activations
    
    def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
        x_train, y_train = training_data
        if test_data:
            x_test, y_test = test_data
        for j in range(epochs):
            for i in range(x_train.shape[0] // mini_batch_size):
                x_mini_batch = x_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                y_mini_batch = y_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                self.update_mini_batch((x_mini_batch, y_mini_batch), eta)
            if test_data:
                print("Epoch: {0}, Accuracy: {1}".format(j, self.evaluate((x_test, y_test))))
            else:
                print("Epoch: {0}".format(j))

In [97]:
class NetworkWithCheckpointing(object):

    def __init__(self, sizes, checkpoint_every_nth_layer: int = 1):
        self.num_layers = len(sizes)
        self.sizes = sizes
        self.biases = [np.random.randn(y, 1) for y in sizes[1:]]
        self.weights = [np.random.randn(y, x) 
                        for x, y in zip(sizes[:-1], sizes[1:])]
        self.checkpoint = checkpoint_every_nth_layer

    def feedforward(self, a):
        # Run the network on a batch
        a = a.T
        for b, w in zip(self.biases, self.weights):
            a = sigmoid(np.matmul(w, a)+b)
        return a

    def forward_between_checkpoints(self, a, checkpoint_idx_start, layer_idx_end):
        recomputed = []
        for i in range(checkpoint_idx_start, layer_idx_end):
            a = sigmoid(np.matmul(self.weights[i], a)+self.biases[i])
            recomputed.append(a)
        return recomputed

    def update_mini_batch(self, mini_batch, eta):    
        nabla_b, nabla_w = self.backprop(mini_batch[0].T,mini_batch[1].T)
            
        self.weights = [w-(eta/len(mini_batch[0]))*nw 
                        for w, nw in zip(self.weights, nabla_w)]
        self.biases = [b-(eta/len(mini_batch[0]))*nb 
                       for b, nb in zip(self.biases, nabla_b)]

    def backprop(self, x, y):
        g = x
        gs = [g] # list of some gs
        fs = [] # list of some fs
        dLdWs = []
        dLdBs = []
        count = 0
        for b, w in zip(self.biases, self.weights):
            f = np.dot(w, g)+b
            g = sigmoid(f)
            count += 1
            if count%self.checkpoint == 0:
              fs.append(f)
              gs.append(g)
        dLdg = self.cost_derivative(g, y)
        num_checkpoints = len(gs)
        if self.num_layers%self.checkpoint != 1:
          #this part is needed when the last layer is not saved
          recomputed = self.forward_between_checkpoints(gs[-1],(num_checkpoints-1)*self.checkpoint,self.num_layers-1)
          for i in range(len(recomputed)-1,0,-1):
            dLdf = np.multiply(dLdg,np.multiply(recomputed[i],1-recomputed[i]))
            dLdg = np.matmul(self.weights[(num_checkpoints-1)*self.checkpoint+i].T, dLdf)
            dLdWs.append(np.matmul(dLdf,recomputed[i-1].T))
            dLdBs.append(np.sum(dLdf,axis=1).reshape(dLdf.shape[0],1))
          dLdf = np.multiply(dLdg,np.multiply(recomputed[0],1-recomputed[0]))
          dLdg = np.matmul(self.weights[(num_checkpoints-1)*self.checkpoint].T, dLdf)
          dLdWs.append(np.matmul(dLdf,gs[-1].T))
          dLdBs.append(np.sum(dLdf,axis=1).reshape(dLdf.shape[0],1))
        for i in range(num_checkpoints-2,-1,-1):
          recomputed = self.forward_between_checkpoints(gs[i],i*self.checkpoint,(i+1)*self.checkpoint-1)
          recomputed.append(gs[i+1])
          for j in range(len(recomputed)-1,0,-1):
            dLdf = np.multiply(dLdg,np.multiply(recomputed[j],1-recomputed[j]))
            dLdg = np.matmul(self.weights[i*self.checkpoint+j].T, dLdf)
            dLdWs.append(np.matmul(dLdf,recomputed[j-1].T))
            dLdBs.append(np.sum(dLdf,axis=1).reshape(dLdf.shape[0],1))
          dLdf = np.multiply(dLdg,np.multiply(recomputed[0],1-recomputed[0]))
          dLdg = np.matmul(self.weights[i*self.checkpoint].T, dLdf)
          dLdWs.append(np.matmul(dLdf,gs[i].T))
          dLdBs.append(np.sum(dLdf,axis=1).reshape(dLdf.shape[0],1))
        return (reversed(dLdBs),reversed(dLdWs))

    def evaluate(self, test_data):
        # Count the number of correct answers for test_data
        pred = np.argmax(self.feedforward(test_data[0]),axis=0)
        corr = np.argmax(test_data[1],axis=1).T
        return np.mean(pred==corr)
    
    def cost_derivative(self, output_activations, y):
        return (1-y)/(1-output_activations)-y/output_activations
    
    def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
        x_train, y_train = training_data
        if test_data:
            x_test, y_test = test_data
        for j in range(epochs):
            for i in range(x_train.shape[0] // mini_batch_size):
                x_mini_batch = x_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                y_mini_batch = y_train[(mini_batch_size*i):(mini_batch_size*(i+1))]
                self.update_mini_batch((x_mini_batch, y_mini_batch), eta)
            if test_data:
                print("Epoch: {0}, Accuracy: {1}".format(j, self.evaluate((x_test, y_test))))
            else:
                print("Epoch: {0}".format(j))

In [87]:
network = Network([784,30,20,25,10])
network.SGD((x_train, y_train), epochs=100, mini_batch_size=100, eta=2.5, test_data=(x_test, y_test))

Epoch: 0, Accuracy: 0.8733
Epoch: 1, Accuracy: 0.8954
Epoch: 2, Accuracy: 0.9249
Epoch: 3, Accuracy: 0.9297
Epoch: 4, Accuracy: 0.9345
Epoch: 5, Accuracy: 0.9388
Epoch: 6, Accuracy: 0.9369
Epoch: 7, Accuracy: 0.9399
Epoch: 8, Accuracy: 0.94
Epoch: 9, Accuracy: 0.9386
Epoch: 10, Accuracy: 0.9383
Epoch: 11, Accuracy: 0.941
Epoch: 12, Accuracy: 0.9393
Epoch: 13, Accuracy: 0.9439
Epoch: 14, Accuracy: 0.9426
Epoch: 15, Accuracy: 0.9451
Epoch: 16, Accuracy: 0.9455
Epoch: 17, Accuracy: 0.9446
Epoch: 18, Accuracy: 0.944
Epoch: 19, Accuracy: 0.9467
Epoch: 20, Accuracy: 0.9455
Epoch: 21, Accuracy: 0.9447
Epoch: 22, Accuracy: 0.9467
Epoch: 23, Accuracy: 0.9461
Epoch: 24, Accuracy: 0.946
Epoch: 25, Accuracy: 0.9452
Epoch: 26, Accuracy: 0.9468
Epoch: 27, Accuracy: 0.9462
Epoch: 28, Accuracy: 0.9465
Epoch: 29, Accuracy: 0.9476
Epoch: 30, Accuracy: 0.9472
Epoch: 31, Accuracy: 0.9453
Epoch: 32, Accuracy: 0.9473
Epoch: 33, Accuracy: 0.9473
Epoch: 34, Accuracy: 0.947
Epoch: 35, Accuracy: 0.9477
Epoch: 3

In [98]:
network = NetworkWithCheckpointing([784,30,20,25,10],3)
network.SGD((x_train, y_train), epochs=100, mini_batch_size=100, eta=2.5, test_data=(x_test, y_test))

Epoch: 0, Accuracy: 0.8779
Epoch: 1, Accuracy: 0.9105
Epoch: 2, Accuracy: 0.9227
Epoch: 3, Accuracy: 0.9279
Epoch: 4, Accuracy: 0.9312
Epoch: 5, Accuracy: 0.9341
Epoch: 6, Accuracy: 0.9349
Epoch: 7, Accuracy: 0.9382
Epoch: 8, Accuracy: 0.9293
Epoch: 9, Accuracy: 0.9392
Epoch: 10, Accuracy: 0.9397
Epoch: 11, Accuracy: 0.9394
Epoch: 12, Accuracy: 0.9455
Epoch: 13, Accuracy: 0.9445
Epoch: 14, Accuracy: 0.9393
Epoch: 15, Accuracy: 0.9406
Epoch: 16, Accuracy: 0.9447
Epoch: 17, Accuracy: 0.9398
Epoch: 18, Accuracy: 0.944
Epoch: 19, Accuracy: 0.9426
Epoch: 20, Accuracy: 0.9438
Epoch: 21, Accuracy: 0.9418
Epoch: 22, Accuracy: 0.9418
Epoch: 23, Accuracy: 0.9439
Epoch: 24, Accuracy: 0.9456
Epoch: 25, Accuracy: 0.9464
Epoch: 26, Accuracy: 0.9472
Epoch: 27, Accuracy: 0.9458
Epoch: 28, Accuracy: 0.9455
Epoch: 29, Accuracy: 0.9454
Epoch: 30, Accuracy: 0.9444
Epoch: 31, Accuracy: 0.9445
Epoch: 32, Accuracy: 0.9439
Epoch: 33, Accuracy: 0.9469
Epoch: 34, Accuracy: 0.9416
Epoch: 35, Accuracy: 0.9492
Epo

# Odpowiedzi do zadań opisowych

1. W algorytmie musimy trzymać dane dotyczące wag (dla każdej warstwy jest to macierz rozmiarów sizes[i] x sizes[i-1]) i biasów (dla każdej warstwy jest to wektor), ponieważ są to wartości, które optymalizujemy. Stosując propagację wsteczną wykorzystujemy liczby, które są w neuronach, czyli f i g (gdzie g = sigmoid(f)) oraz pochodne funkcji straty po f. Są one trzymane w trzech listach wektorów. Użycie checkpointingu pozwala na kilkukrotne skrócenie tych list, ponieważ większość z tych liczb będzie kalkulowana na bieżąco.

2. Użycie checkpointingu sprawiło, że uczenie sieci trwa zauważalnie dłużej (około 40% dłużej), ponieważ niektóre wartości muszą być obliczane kilkukrotnie. Pozwoliło to jednak na zmniejszenie zużycia pamięci. Wartości f i g przechowywane były w listach o długości num_layers. Po zastosowaniu checkpointingu długość list wynosi około num_layers/checkpoint_every_nth_layer. Dodatkowo, nie trzeba trzymać wszystkich pochodnych po f, wystarczy tylko ostatnia.

PS Odruchowo napisałem odpowiedzi po polsku. Jeśli będzie potrzeba to mogę przetłumaczyć je na język angielski.