This notebook illustrates how to train fully connected models with PEPITA. We train and test the model on MNIST.

#### Import libraries

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.autograd import Variable

import copy

import matplotlib.pyplot as plt
import numpy as np
import psutil
import time
import multiprocessing as mp

#### Define Network architecture

In [17]:
# models with Dropout
class NetFC784(nn.Module):
    """Two-layer fully connected classifier (784 -> 128 -> 10) with externally supplied dropout.

    Dropout is not applied through ``nn.Dropout``: the caller passes explicit masks so
    that the very same mask can be presented in two consecutive forward passes.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 128, bias=True)
        self.fc2 = nn.Linear(128, 10, bias=True)

        # Re-initialize the weights with the He uniform scheme: U(-limit, limit),
        # limit = sqrt(6 / fan_in). Biases keep the default nn.Linear initialization.
        for layer in (self.fc1, self.fc2):
            limit = np.sqrt(6.0 / layer.in_features)
            torch.nn.init.uniform_(layer.weight, a=-limit, b=limit)

    def forward(self, x, do_masks=None):
        """Return class probabilities for a batch of flattened inputs.

        Args:
            x: tensor fed to ``fc1``; its last dimension must be 28*28.
            do_masks: optional sequence whose first element multiplies the hidden
                activations (the dropout mask). ``None`` or an empty sequence
                disables dropout. Elements after the first are ignored here.

        Returns:
            Softmax over dim 1 of the ``fc2`` output.
        """
        hidden = F.relu(self.fc1(x))
        # Custom dropout: an empty list is treated like None instead of raising
        # IndexError, so callers that build "no masks" as [] keep working.
        if do_masks is not None and len(do_masks) > 0:
            hidden = hidden * do_masks[0]
        # normalize the outputs to a probability distribution
        return F.softmax(self.fc2(hidden), dim=1)
    



#### Set hyperparameters and train+test the model

In [19]:
def _make_dropout_masks(n_samples, hidden_widths, keep_rate):
    """Draw one inverted-dropout mask per hidden layer and append a scalar 1 for the output layer."""
    masks = [torch.ones(n_samples, width).bernoulli_(keep_rate) / keep_rate
             for width in hidden_widths]
    masks.append(1)  # the last layer never uses dropout
    return masks


def _masked_activations(activation, masks):
    """Return relu(activation) * mask for every hooked 'fc'/'conv' layer, in registration order."""
    keys = [key for key in activation if 'fc' in key or 'conv' in key]
    return [F.relu(activation[key]) * mask for key, mask in zip(keys, masks)]


def _pepita_weight_deltas(layers_act, mod_layers_act, mod_inputs, mod_error):
    """Compute one PEPITA weight update per layer from the two forward passes.

    Hidden layers use the difference between the activations of the two passes;
    the last layer uses the output error of the second pass. The presynaptic term
    is always taken from the second (modulated) pass.
    """
    presynaptic = [mod_inputs] + mod_layers_act[:-1]
    last = len(layers_act) - 1
    deltas = []
    for l, pre in enumerate(presynaptic):
        post = mod_error if l == last else layers_act[l] - mod_layers_act[l]
        deltas.append(-post.T @ pre)
    return deltas


def _evaluate_accuracy(net, loader):
    """Return the accuracy (in percent) of `net` on `loader`, without dropout."""
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = torch.flatten(images, 1)  # flatten all dimensions except batch
            outputs = net(images, do_masks=None)
            # the class with the highest probability is the prediction
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


def _plot_training_curves(train_losses, train_accs, test_accs):
    """Plot the loss curve and the train/test accuracy curves collected per epoch."""
    if not test_accs:
        return  # nothing was trained (epochs == 0): min()/max() below would fail

    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss', color="green")
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(test_accs, label='Test Accuracy', color="green")
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.yticks(np.arange(min(test_accs)-1, max(test_accs)+2, 1))
    plt.legend()
    plt.show()

    plt.figure(figsize=(5, 5))
    plt.plot(test_accs, label='Test Accuracy', color="green")
    plt.plot(train_accs, label='Train Accuracy', color="red")
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.show()

    plt.figure(figsize=(5, 5))
    plt.plot(train_accs, label='Train Accuracy', color="red")
    plt.xlabel('Epoch')
    plt.ylabel('Train accuracy (%)')
    plt.legend()
    plt.show()


def train_model(learning_rate=0.1, keep_rate=0.9, optimizer='mom', batch_size=64, epochs=100):
    """Train NetFC784 on MNIST with PEPITA, evaluate after every epoch and plot the curves.

    PEPITA uses two forward passes per batch: one with the original input and one
    with the input modulated by the output error projected through a fixed random
    matrix. No loss is backpropagated; the loss is only monitored.

    Args:
        learning_rate: initial step size; multiplied by 0.1 at epochs 30, 60 and 90.
        keep_rate: dropout keep probability for the hidden layer; values >= 1 disable dropout.
        optimizer: 'mom' (momentum with gamma = 0.9) or 'SGD' (gamma = 0).
        batch_size: mini-batch size; also the divisor of the summed batch update.
        epochs: number of passes over the training set.

    Raises:
        ValueError: if `optimizer` is not 'SGD'/'mom' or `keep_rate` is not positive.
    """
    print('Learning rate:',learning_rate)
    print('Batch size: ',batch_size)

    if optimizer == 'SGD':
        gamma = 0
    elif optimizer == 'mom':
        gamma = 0.9
    else:
        raise ValueError("optimizer must be 'SGD' or 'mom', got {!r}".format(optimizer))
    if keep_rate <= 0:
        raise ValueError('keep_rate must be positive, got {}'.format(keep_rate))
    use_dropout = keep_rate < 1

    # The network already outputs softmax probabilities, so the monitored loss is the
    # negative log-likelihood of those probabilities (nn.CrossEntropyLoss expects logits
    # and would apply a second softmax).
    criterion = nn.NLLLoss()

    # initialize the network
    net = NetFC784()

    # load the dataset
    transform = transforms.Compose(
        [transforms.ToTensor()]) # this normalizes to [0,1]
    trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                            shuffle=True, num_workers=2)
    testset = torchvision.datasets.MNIST(root='./data', train=False,
                                        download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                            shuffle=False, num_workers=2)

    # Register forward hooks that store every module's latest output, so the
    # activations of the two forward passes can be compared.
    activation = {}
    def get_activation(name):
        def hook(module, hook_inputs, output):
            activation[name] = output.detach()
        return hook
    for name, layer in net.named_modules():
        layer.register_forward_hook(get_activation(name))

    # define B --> this is the F projection matrix in the paper (here named B because F is torch.nn.functional)
    nin = 28*28         # B has shape 784 x 10
    sd = np.sqrt(6/nin)
    B = (torch.rand(nin,10)*2*sd-sd)*0.05  # He-uniform range (like the forward weights), scaled by 0.05

    # Weight matrices only: PEPITA as implemented here does not update the biases.
    weights = [w for name, w in net.named_parameters() if 'bias' not in name]
    with torch.no_grad():
        for l_idx, w in enumerate(weights):
            print('norm of w at layer {} is {}'.format(l_idx,torch.norm(w)))

    # do one forward pass to get the activation sizes needed for setting up the dropout masks
    images, _ = next(iter(trainloader))
    with torch.no_grad():
        net(torch.flatten(images, 1), do_masks=None)
    layer_keys = [key for key in activation if 'fc' in key or 'conv' in key]
    hidden_widths = [activation[key].shape[1] for key in layer_keys[:-1]]
    n_layers = len(layer_keys)

    # Velocity buffers, one per weight matrix. With gamma = 0 the update below
    # reduces to plain SGD, so both optimizers share a single code path.
    velocities = [torch.zeros(w.shape) for w in weights]

    # Train and test the model
    test_accs = []
    train_accs = []
    train_losses = []
    for epoch in range(epochs):  # loop over the dataset multiple times

        # learning rate decay
        if epoch in [30,60,90]:
            learning_rate = learning_rate*0.1
            print('learning_rate decreased to ',learning_rate)

        running_loss = 0.0
        batch_count = 0
        total_train = 0
        correct_train = 0
        for inputs, target in trainloader:
            inputs = torch.flatten(inputs, 1) # flatten all dimensions except batch
            target_onehot = F.one_hot(target,num_classes=10)

            # The same dropout masks must be used in both forward passes.
            if use_dropout:
                act_masks = _make_dropout_masks(inputs.shape[0], hidden_widths, keep_rate)
                net_masks = act_masks
            else:
                act_masks = [1] * n_layers  # identity multipliers when recording activations
                net_masks = None

            # No gradients are ever backpropagated, so skip building the autograd graph.
            with torch.no_grad():
                # forward pass 1 with the original input --> keep track of activations
                outputs = net(inputs, net_masks)
                layers_act = _masked_activations(activation, act_masks)

                # the class with the highest probability is the prediction
                _, predicted = torch.max(outputs, 1)
                total_train += target.size(0)
                correct_train += (predicted == target).sum().item()

                # compute the error and modulate the input with it
                error = outputs - target_onehot
                mod_inputs = inputs + error @ B.T

                # forward pass 2 with the modified input --> keep track of modulated activations
                mod_outputs = net(mod_inputs, net_masks)
                mod_layers_act = _masked_activations(activation, act_masks)
                mod_error = mod_outputs - target_onehot

                # compute and apply the weight change for the batch
                delta_w_all = _pepita_weight_deltas(layers_act, mod_layers_act, mod_inputs, mod_error)
                for idx, w in enumerate(weights):
                    velocities[idx] = gamma * velocities[idx] + learning_rate * delta_w_all[idx]/batch_size
                    w += velocities[idx]

                # keep track of the loss (computed from the first-pass outputs);
                # clamp avoids log(0) for saturated probabilities
                loss = criterion(torch.log(torch.clamp(outputs, min=1e-12)), target)

            running_loss += loss.item()
            batch_count += 1

        curr_loss = running_loss / batch_count
        print('[%d, %5d] loss: %.3f' % (epoch, batch_count, curr_loss))
        train_losses.append(curr_loss)
        print('Train accuracy epoch {}: {} %'.format(epoch, 100 * correct_train / total_train))
        train_accs.append(100 * correct_train / total_train)

        print('Testing...')
        test_acc = _evaluate_accuracy(net, testloader)
        print('Test accuracy epoch {}: {} %'.format(epoch, test_acc))
        test_accs.append(test_acc)

    print('Finished Training')

    _plot_training_curves(train_losses, train_accs, test_accs)

In [None]:
# Entry point: run the full training/evaluation only when this code is executed as the main module.
# NOTE(review): train_model builds DataLoaders with num_workers=2; presumably this guard also keeps
# worker processes that re-import the module from re-launching training -- confirm for your platform.
if __name__ == '__main__':
    train_model()