This notebook illustrates how to train Fully Connected models with PEPITA. We train and test the model on CIFAR-10.

#### Import libraries

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.autograd import Variable

import copy

import matplotlib.pyplot as plt
import numpy as np

#### Define Network architecture

In [8]:
# models with Dropout
class NetFC1x1024DOcust(nn.Module):
    """Bias-free fully connected network with one hidden layer and custom dropout.

    The architecture is ``n_in -> n_hidden -> n_out`` with a ReLU after the first
    layer and a softmax on the output. Dropout is not an ``nn.Dropout`` module:
    the caller passes the mask explicitly so that the very same mask can be
    presented in two consecutive forward passes.

    Args:
        n_in: number of input features (default ``32*32*3``, i.e. flattened CIFAR-10;
            use ``28*28*1`` for MNIST).
        n_hidden: width of the hidden layer (default 1024).
        n_out: number of output classes (default 10).
    """

    def __init__(self, n_in=32*32*3, n_hidden=1024, n_out=10):
        super().__init__()
        self.fc1 = nn.Linear(n_in, n_hidden, bias=False)
        self.fc2 = nn.Linear(n_hidden, n_out, bias=False)

        # initialize the layers using the He uniform initialization scheme:
        # weights ~ U(-sqrt(6 / fan_in), +sqrt(6 / fan_in))
        for layer, fan_in in ((self.fc1, n_in), (self.fc2, n_hidden)):
            limit = np.sqrt(6.0 / fan_in)
            torch.nn.init.uniform_(layer.weight, a=-limit, b=limit)

    def forward(self, x, do_masks=None):
        """Run the network on a batch of flattened inputs.

        Args:
            x: input tensor whose last dimension has ``n_in`` features.
            do_masks: ``None`` to disable dropout, otherwise an indexable whose
                element 0 is multiplied element-wise with the hidden activations
                (a tensor mask, or a scalar such as ``1``).

        Returns:
            Softmax probabilities over the last dimension.
        """
        x = F.relu(self.fc1(x))
        # apply dropout --> we use a custom dropout implementation because we need
        # to present the same dropout mask in the two forward passes
        if do_masks is not None:
            x = x * do_masks[0]
        # explicit dim: relying on the implicit softmax dimension is deprecated and
        # emits a UserWarning on every call
        x = F.softmax(self.fc2(x), dim=-1)
        return x
    



#### Set hyperparameters and train+test the model

In [39]:
# set hyperparameters
## learning rate
eta = 0.01
## dropout keep rate (probability of keeping a hidden unit); 1 disables dropout
keep_rate = 0.9
if not 0 < keep_rate <= 1:
    # the masks are rescaled by 1/keep_rate, so 0 (or anything outside (0, 1]) is invalid
    raise ValueError('keep_rate must be in (0, 1], got {}'.format(keep_rate))
## loss --> used to monitor performance, but not for parameter updates (PEPITA does not backpropagate the loss)
criterion = nn.CrossEntropyLoss()
## optimizer (choose 'SGD' or 'mom')
## NOTE: this string rebinds the name `optim`, hiding the `torch.optim` module imported above
optim = 'mom' # --> default in the paper
if optim == 'SGD':
    gamma = 0
elif optim == 'mom':
    gamma = 0.9
else:
    # fail early: an unknown value would otherwise leave gamma undefined and
    # silently skip every weight update in the training loop
    raise ValueError("optim must be 'SGD' or 'mom', got {!r}".format(optim))
## batch size
batch_size = 64 # --> default in the paper

# initialize the network
net = NetFC1x1024DOcust()

# load the dataset
transform = transforms.Compose(
    [transforms.ToTensor()]) # this normalizes to [0,1]
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)


# registry of forward outputs --> we need this to compare the activations in the two forward passes
# (maps a caller-chosen name to the latest detached output recorded under that name)
activation = {}
def get_activation(name):
    """Build a forward hook that saves a module's detached output in ``activation[name]``."""
    def _record(module, inputs, output):
        detached = output.detach()  # stored tensor carries no autograd history
        activation[name] = detached
    return _record
# attach a recording hook to every module yielded by net.named_modules(); after each
# forward pass the module's output is available in `activation` under the module's name
for name, layer in net.named_modules():
    layer.register_forward_hook(get_activation(name))


# define B --> this is the F projection matrix in the paper (here named B because F is torch.nn.functional)
# B has shape (nin, 10): it maps the 10-dimensional output error back onto the input
nin = 32*32*3                 # number of input features (flattened 32x32x3 image)
sd = np.sqrt(6/nin)           # He uniform limit for fan-in = nin (like the forward weights)
B = torch.rand(nin,10)        # uniform samples in [0, 1)
B = (B*2*sd-sd)*0.05          # rescale to [-sd, sd), then shrink by a factor 0.05


# diagnostics before training: snapshot every weight matrix, print its norm,
# and print the size of the running product of the transposed weights
angles = []
w_all = []
norm_w0 = []
with torch.no_grad():
    for l_idx,w in enumerate(net.parameters()):
        # deep copy so the snapshot is not affected by later in-place weight updates
        w_all.append(copy.deepcopy(w))
        if l_idx == 0:
            # remember the initial norm of the first layer only
            norm_w0.append(torch.norm(w))
        print('norm of w at layer {} is {}'.format(l_idx,torch.norm(w)))

# chain the transposed weights layer by layer: W_0^T @ W_1^T @ ...
w_prod = w_all[0].T
for idx in range(1,len(w_all)):
    w_prod = torch.matmul(w_prod,w_all[idx].T)
    print(w_prod.size())

# do one forward pass to get the activation size needed for setting up the dropout masks
dataiter = iter(trainloader)
# use the built-in next(): DataLoader iterators implement __next__ and no longer
# provide a .next() method
images, labels = next(dataiter)
images = torch.flatten(images, 1) # flatten all dimensions except batch
outputs = net(images,do_masks=None)
# collect the (rectified) outputs recorded by the forward hooks for the fc/conv layers;
# only their sizes are used later, to shape the dropout masks
layers_act = []
for key in activation:
    if 'fc' in key or 'conv' in key:
        layers_act.append(F.relu(activation[key]))
        
# set up for momentum: one zero-initialised velocity buffer per weight matrix
if optim == 'mom':
    gamma = 0.9
    with torch.no_grad():
        # only parameters with more than one dimension (weight matrices) get a buffer
        v_w_all = [torch.zeros(p.shape) for p in net.parameters() if len(p.shape)>1]

# Train and test the model
# Each batch is presented twice: once as is, and once with the input modulated by the
# output error projected through B. The weight changes are computed from the difference
# between the activations of the two passes; no loss is backpropagated.
test_accs = []
for epoch in range(100):  # loop over the dataset multiple times

    # learning rate decay
    if epoch in [60,90]:
        eta = eta*0.1
        print('eta decreased to ',eta)

    # loop over batches
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, target = data
        inputs = torch.flatten(inputs, 1) # flatten all dimensions except batch
        target_onehot = F.one_hot(target,num_classes=10)

        # nothing below ever calls backward(), so skip building autograd graphs
        with torch.no_grad():

            # create dropout mask for the two forward passes --> we need to use the same mask for the two passes
            if keep_rate < 1:
                do_masks = []
                for l in layers_act[:-1]:
                    # Bernoulli(keep_rate) mask of shape (batch, layer width), rescaled by 1/keep_rate
                    do_mask = torch.ones(inputs.shape[0], l.shape[1]).bernoulli_(keep_rate)/keep_rate
                    do_masks.append(do_mask)
                do_masks.append(1) # for the last layer we don't use dropout --> just set a scalar 1 (needed for when we register activation layer)
            else:
                # no dropout: use a neutral scalar mask for every layer so that the
                # indexing below (and in net.forward) stays valid
                do_masks = [1] * len(layers_act)

            # forward pass 1 with original input --> keep track of activations
            outputs = net(inputs,do_masks)
            layers_act = []
            cnt_act = 0
            for key in activation:
                if 'fc' in key or 'conv' in key:
                    layers_act.append(F.relu(activation[key])* do_masks[cnt_act]) # Note: we need to register the activations taking into account non-linearity and dropout mask
                    cnt_act += 1

            # compute the error
            error = outputs - target_onehot

            # modify the input with the error
            error_input = error @ B.T
            mod_inputs = inputs + error_input

            # forward pass 2 with modified input --> keep track of modulated activations
            mod_outputs = net(mod_inputs,do_masks)
            mod_layers_act = []
            cnt_act = 0
            for key in activation:
                if 'fc' in key or 'conv' in key:
                    mod_layers_act.append(F.relu(activation[key])* do_masks[cnt_act]) # Note: we need to register the activations taking into account non-linearity and dropout mask
                    cnt_act += 1
            mod_error = mod_outputs - target_onehot

            # compute the delta_w for the batch
            delta_w_all = []
            for l in range(len(layers_act)):

                # update for the last layer
                if l == len(layers_act)-1:

                    if len(layers_act)>1:
                        delta_w = -mod_error.T @ mod_layers_act[-2]
                    else:
                        delta_w = -mod_error.T @ mod_inputs

                # update for the first layer
                elif l == 0:
                    delta_w = -(layers_act[l] - mod_layers_act[l]).T @ mod_inputs

                # update for the hidden layers (not first, not last)
                elif l>0 and l<len(layers_act)-1:
                    delta_w = -(layers_act[l] - mod_layers_act[l]).T @ mod_layers_act[l-1]

                delta_w_all.append(delta_w)

            # apply the weight change
            if optim == 'SGD':
                for l_idx,w in enumerate(net.parameters()):
                    w += eta * delta_w_all[l_idx]/batch_size # specify for which layer

            elif optim == 'mom':
                for l_idx,w in enumerate(net.parameters()):
                    v_w_all[l_idx] = gamma * v_w_all[l_idx] + eta * delta_w_all[l_idx]/batch_size
                    w += v_w_all[l_idx]

            # keep track of the loss
            # NOTE(review): `outputs` are already softmax probabilities and CrossEntropyLoss
            # applies log-softmax again, so this is a monitoring proxy rather than the true
            # cross-entropy; left unchanged so logged values stay comparable with earlier runs
            loss = criterion(outputs, target)

        # print statistics
        running_loss += loss.item()
        if i%500 == 499:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 500))
            running_loss = 0.0

    print('Testing...')
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for test_data in testloader:
            test_images, test_labels = test_data
            test_images = torch.flatten(test_images, 1) # flatten all dimensions except batch
            # calculate outputs by running images through the network (no dropout at test time)
            test_outputs = net(test_images,do_masks=None)
            # the class with the highest probability is what we choose as prediction
            _, predicted = torch.max(test_outputs, 1)
            total += test_labels.size(0)
            correct += (predicted == test_labels).sum().item()

    print('Test accuracy epoch {}: {} %'.format(epoch, 100 * correct / total))
    test_accs.append(100 * correct / total)

print('Finished Training')

Files already downloaded and verified
Files already downloaded and verified
norm of w at layer 0 is 45.241451263427734
norm of w at layer 1 is 4.473217010498047
torch.Size([3072, 10])




[1,   500] loss: 2.210
Testing...
Test accuracy epoch 0: 35.12 %
[2,   500] loss: 2.167
Testing...
Test accuracy epoch 1: 38.07 %
[3,   500] loss: 2.150
Testing...
Test accuracy epoch 2: 40.9 %
[4,   500] loss: 2.137
Testing...
Test accuracy epoch 3: 40.41 %
[5,   500] loss: 2.126
Testing...
Test accuracy epoch 4: 41.4 %
[6,   500] loss: 2.117
Testing...
Test accuracy epoch 5: 39.76 %
[7,   500] loss: 2.109
Testing...
Test accuracy epoch 6: 43.73 %
[8,   500] loss: 2.101
Testing...
Test accuracy epoch 7: 45.08 %
[9,   500] loss: 2.096
Testing...
Test accuracy epoch 8: 45.28 %
[10,   500] loss: 2.090
Testing...
Test accuracy epoch 9: 45.05 %
[11,   500] loss: 2.085
Testing...
Test accuracy epoch 10: 44.66 %
[12,   500] loss: 2.079
Testing...
Test accuracy epoch 11: 45.89 %
[13,   500] loss: 2.074
Testing...
Test accuracy epoch 12: 47.39 %
[14,   500] loss: 2.071
Testing...
Test accuracy epoch 13: 46.95 %
[15,   500] loss: 2.065
Testing...
Test accuracy epoch 14: 47.41 %
[16,   500] loss