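# Well-tempered backpropagation

Train a small CNN on MNIST with SGD where, at the end of every epoch, the signal-to-noise ratio (SNR) of each gradient component over that epoch's mini-batches is estimated; during the next epoch, components with SNR < 1 are scaled down by their SNR. Setting `WTB = False` gives plain backpropagation for comparison.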

In [25]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchsummary import summary
from matplotlib import pyplot as plt

import numpy as np
import os
from os.path import join

# Base directory of the experiment. Set the WTSGD_ROOT environment variable
# to run on another machine instead of editing this machine-specific default.
ROOT = os.environ.get('WTSGD_ROOT', '/home/ansuini/repos/WellTemperedSGD/MNIST')
RES = join(ROOT, 'results')   # where the train/test stats are saved
datum = 'data_shuffled'       # sub-directory of ROOT holding the MNIST files

In [2]:
def init_tensors(model, verbose=False):
    '''
    Return a list of zero tensors, one per parameter of `model`.

    Each tensor is built with torch.zeros_like, so it matches the shape,
    dtype and device of the corresponding entry of model.parameters().
    The list follows the same order, so it can be zipped against
    model.parameters() (e.g. to accumulate gradients).

    Args:
        model: a torch.nn.Module.
        verbose: if True, print the shape of every tensor.

    Returns:
        list of torch.Tensor filled with zeros.
    '''
    tensors = [torch.zeros_like(p) for p in model.parameters()]

    if verbose:
        print('Tensors shapes:')
        # plain loop: a comprehension used only for its side effect would
        # build and throw away a list of None values
        for t in tensors:
            print(t.shape)

    return tensors

In [3]:
def init_tensors_one(model, verbose=False):
    '''
    Return a list of tensors filled with ones, one per parameter of `model`.

    Each tensor is built with torch.ones_like, so it matches the shape,
    dtype and device of the corresponding entry of model.parameters(),
    and the list follows the same order.

    Args:
        model: a torch.nn.Module.
        verbose: if True, print the shape of every tensor.

    Returns:
        list of torch.Tensor filled with ones.
    '''
    tensors = [torch.ones_like(p) for p in model.parameters()]

    if verbose:
        print('Tensors shapes:')
        # plain loop instead of a side-effect-only list comprehension
        for t in tensors:
            print(t.shape)

    return tensors

In [4]:
def acc_grad(grad, model):
    '''
    Add the current gradients of `model` into `grad`, in place.

    Args:
        grad: list of tensors with the same shapes and order as
              model.parameters() (e.g. from init_tensors); modified in place.
        model: module whose parameters' .grad have been filled by backward().

    Returns:
        The same `grad` list, so the call can be written grad = acc_grad(...).

    Parameters whose .grad is None (not reached by the last backward pass)
    contribute nothing, instead of raising a TypeError.
    '''
    for g, p in zip(grad, model.parameters()):
        if p.grad is not None:
            g += p.grad
    return grad
        
def acc_grad2(grad2, model):
    '''
    Add the element-wise squared gradients of `model` into `grad2`, in place.

    Args:
        grad2: list of tensors with the same shapes and order as
               model.parameters() (e.g. from init_tensors); modified in place.
        model: module whose parameters' .grad have been filled by backward().

    Returns:
        The same `grad2` list.

    Parameters whose .grad is None contribute nothing, instead of raising.
    '''
    for g, p in zip(grad2, model.parameters()):
        if p.grad is not None:
            g += torch.mul(p.grad, p.grad)
    return grad2

def clone_tensors(tensors):
    '''
    Snapshot gradients, e.g. for tests.

    Args:
        tensors: iterable of tensors/parameters whose .grad is set
                 (for instance model.parameters()).

    Returns:
        list holding a copy (Tensor.clone) of each element's .grad, in the
        same order; later in-place changes to the .grad tensors do not
        affect the copies.
    '''
    copies = []
    for tensor in tensors:
        copies.append(tensor.grad.clone())
    return copies

def compute_snr(grad, grad2, B, device):
    '''
    Compute the per-component signal-to-noise ratio (snr) of the gradient.

    For every parameter tensor, given sums accumulated over B mini-batches
    (B is the same as in the paper):

        mean = grad / B
        var  = grad2 / B - mean**2        (clamped at 0)
        err  = sqrt(var / B)              (standard error of the mean)
        snr  = |mean| / err               (err == 0 replaced by 1e-8)

    Args:
        grad:   list of tensors, sum over B mini-batches of the gradient
                (as accumulated by acc_grad).
        grad2:  list of tensors, sum over B mini-batches of the squared
                gradient (as accumulated by acc_grad2); same shapes as grad.
        B:      number of accumulated mini-batches; must be positive.
        device: unused; kept so existing callers keep working.

    Returns:
        list of tensors with the same structure as `grad`. The inputs are
        not modified.

    Raises:
        ValueError: if B is not positive (the averages would be NaN/inf).
    '''
    if B <= 0:
        raise ValueError(
            'compute_snr needs at least one mini-batch (B > 0), got {}'.format(B))

    epsilon = 1e-8  # substituted for a zero err to avoid division by zero

    snr = []  # list of tensors with the same structure as model.parameters()

    for g, g2 in zip(grad, grad2):
        # the divisions allocate new tensors, so grad/grad2 are left untouched
        mean = g / B
        mean2 = g2 / B

        # E[g^2] - E[g]^2 can come out slightly negative through floating-point
        # cancellation when all mini-batch gradients are (nearly) equal; clamp
        # so that sqrt never produces NaN. (The previous
        # assert(torch.sum(var >= 0)) only checked that SOME entry was >= 0.)
        var = torch.clamp(mean2 - mean * mean, min=0)

        err = torch.sqrt(var / B)
        err[err == 0] = epsilon

        # ratio between the abs value of the averaged gradient and its error
        snr.append(torch.abs(mean) / err)

    return snr

In [5]:
class Net(nn.Module):
    '''
    CNN for 1x28x28 MNIST images: two conv -> ReLU -> 2x2 max-pool stages,
    then two fully connected layers, returning log-probabilities over the
    10 classes.

    Args:
        p1: dropout probability on the flattened conv features (before fc1).
            Default 0.0, i.e. no dropout.
        p2: dropout probability on the fc1 activations (before fc2).
            Default 0.0.
    '''

    def __init__(self, p1=0.0, p2=0.0):
        super(Net, self).__init__()
        # Submodules are created in this exact order so that the seeded
        # weight initialisation and the order of model.parameters() do not
        # change.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)
        self.dropout1 = nn.Dropout(p=p1)
        self.dropout2 = nn.Dropout(p=p2)

    def forward(self, x):
        # 1x28x28 -> conv5x5 -> 20x24x24 -> pool -> 20x12x12
        features = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        # -> conv5x5 -> 50x8x8 -> pool -> 50x4x4
        features = F.max_pool2d(F.relu(self.conv2(features)), 2, 2)
        flat = features.view(-1, 4 * 4 * 50)
        hidden = F.relu(self.fc1(self.dropout1(flat)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

In [6]:
def stats(model, loader, device):
    '''
    Evaluate `model` on every batch of `loader` and print a one-line summary.

    The model is put in eval mode (and left there) and no gradients are
    computed.

    Args:
        model: module expected to return log-probabilities of shape
               (batch, classes), since F.nll_loss is applied to its output.
        loader: DataLoader yielding (data, target) batches.
        device: device the batches are moved to.

    Returns:
        (loss, acc): mean negative log-likelihood per sample (float) and
        accuracy in percent (float).

    Raises:
        ValueError: if the loader's dataset is empty.
    '''
    n_samples = len(loader.dataset)
    if n_samples == 0:
        raise ValueError('stats() called with an empty dataset')

    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # summed (not averaged) per batch, so a short last batch is
            # weighted by its real size
            total_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    loss = total_loss / n_samples
    acc = 100. * correct / n_samples

    # torchvision MNIST exposes a `train` flag; generic datasets
    # (Subset, TensorDataset, ...) may not, so do not crash on them
    train_flag = getattr(loader.dataset, 'train', None)
    if train_flag is None:
        datatype = 'evaluation'
    elif train_flag:
        datatype = 'training'
    else:
        datatype = 'test'

    print(datatype + ' set: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
           loss, correct, n_samples, acc))

    return loss, acc

In [7]:
def train(model, train_loader, optimizer, epoch, device, nsamples=None):
    '''
    Run one pass over `train_loader`, accumulating per-parameter sums of the
    gradient and of the squared gradient.

    NOTE(review): optimizer.step() is never called and the parameters are
    not updated here; the optimizer is only used for zero_grad(). The epoch
    loop later in the notebook does the update inline. Confirm this
    function is meant as a gradient-statistics pass only.

    Args:
        model: the network; put in train mode.
        train_loader: DataLoader over the training set. It should not
            shuffle, so that truncating to `nsamples` always keeps the same
            samples.
        optimizer: used only to zero the gradients.
        epoch: unused; kept for interface compatibility.
        device: device the batches are moved to.
        nsamples: stop once this many samples have been taken (rounded up
            to whole mini-batches). Defaults to the notebook-level
            `nsamples_train`.

    Returns:
        (grad, grad2, B): lists of gradient sums and squared-gradient sums
        shaped like model.parameters(), and the number of mini-batches B.
    '''
    if nsamples is None:
        nsamples = nsamples_train  # notebook-level global, as before

    # init tensors to store gradients
    grad = init_tensors(model)
    grad2 = init_tensors(model)

    model.train()
    B = 0  # count mini-batches

    # iterate on 1 epoch accumulating grad and grad2
    for batch_idx, (data, target) in enumerate(train_loader):

        # control the number of training samples: stop at the first batch
        # starting at or beyond `nsamples` (with `>` one extra batch was
        # used whenever nsamples is a multiple of the batch size)
        if batch_idx * train_loader.batch_size >= nsamples:
            break

        B += 1

        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()

        grad = acc_grad(grad, model)
        grad2 = acc_grad2(grad2, model)

    print('N.of minibatches: {}'.format(B))
    return grad, grad2, B

In [15]:
# ---- hyper-parameters ------------------------------------------------------
seed = 1101             # torch RNG seed
batch_size = 6000       # mini-batch size of both loaders
nsamples_train = 60000  # training samples used per epoch (all of MNIST's 60000)
epochs = 3
lr = 0.01               # step size of the manual SGD update in the training loop
momentum = 0.0          # passed to optim.SGD; only its zero_grad() is used, so no effect

# ---- switches ---------------------------------------------------------------
WTB = True   # if False use normal backpropagation
SAVE = True  # save the train/test stats into RES after training

In [16]:
use_cuda = torch.cuda.is_available()
torch.manual_seed(seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

# One preprocessing pipeline shared by both loaders, so train and test data
# are always normalised identically (0.1307 / 0.3081: the usual MNIST pixel
# mean / std).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Please notice that shuffle is False here in the train_loader.
# This is essential if we want to restrict the training dataset
# to nsamples_train < 60000: the training loop keeps only the first
# mini-batches, so they must contain the same samples at every epoch.
# It is not essential to set it to False in the test_loader.

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(join(ROOT, datum), train=True, download=True,
                   transform=mnist_transform),
    batch_size=batch_size, shuffle=False, **kwargs)


test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(join(ROOT, datum), train=False,
                   transform=mnist_transform),
    batch_size=batch_size, shuffle=False, **kwargs)

In [17]:
model = Net().to(device)
# torchsummary's summary() prints the table itself and returns None, so it
# is not wrapped in print(). Its device argument defaults to "cuda"; pass the
# actual device type so this also works on CPU-only machines.
summary(model, (1, 28, 28), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 20, 24, 24]             520
            Conv2d-2             [-1, 50, 8, 8]          25,050
           Dropout-3                  [-1, 800]               0
            Linear-4                  [-1, 500]         400,500
           Dropout-5                  [-1, 500]               0
            Linear-6                   [-1, 10]           5,010
Total params: 431,080
Trainable params: 431,080
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.13
Params size (MB): 1.64
Estimated Total Size (MB): 1.77
----------------------------------------------------------------
None


In [18]:
# Only optimizer.zero_grad() is used below: the parameter update is done by
# hand in the training loop, so lr/momentum given here do not take effect.
optimizer = optim.SGD(params=model.parameters(), lr=lr, momentum=momentum)

In [19]:
if WTB:
    print('Well tempered backprop!')
else:
    print('Normal backprop!')

# training
train_stats = []
test_stats = []

# init snr to 1 the first time: torch.where(s < 1, ...) below is then a
# no-op, so epoch 1 (and every epoch when WTB is False) is plain SGD
snr = init_tensors_one(model)

# iterate over epochs
for epoch in range(1, epochs + 1):
    print('\nEpoch: {}'.format(epoch))
    # ----------------------------- train
    # per-epoch sums of the raw gradient and squared gradient (used for snr)
    grad = init_tensors(model)
    grad2 = init_tensors(model)

    model.train()
    B = 0  # count mini-batches
    # iterate on 1 epoch accumulating grad and grad2
    for batch_idx, (data, target) in enumerate(train_loader):

        # control the number of training samples: stop at the first batch
        # starting at or beyond nsamples_train (with `>` one extra batch was
        # used whenever nsamples_train is a multiple of the batch size)
        if batch_idx * train_loader.batch_size >= nsamples_train:
            break
        B += 1
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()

        # accumulate the raw gradients, before they are rescaled below
        grad = acc_grad(grad, model)
        grad2 = acc_grad2(grad2, model)

        with torch.no_grad():
            # gradient modification and update
            for p, s in zip(model.parameters(), snr):

                # modify grad with snr computed on the previous epoch:
                # components with snr < 1 are scaled down by their snr
                p.grad = torch.where(s < 1, s * p.grad, p.grad)

                # plain SGD step, applied in place to the parameter itself
                p.sub_(lr * p.grad)

    # if WTB compute snr, otherwise it will remain 1 and will not affect backprop
    if WTB:
        # update snr at the end of the epoch
        with torch.no_grad():
            snr = compute_snr(grad, grad2, B, device)

    train_stats.append(stats(model, train_loader, device))
    test_stats.append(stats(model, test_loader, device))


if SAVE:
    # np.save does not create missing directories
    os.makedirs(RES, exist_ok=True)
    if WTB:
        np.save(join(RES, 'train_stats_wtb'), train_stats)
        np.save(join(RES, 'test_stats_wtb'), test_stats)
    else:
        np.save(join(RES, 'train_stats_norm'), train_stats)
        np.save(join(RES, 'test_stats_norm'), test_stats)

Well tempered backprop!

Epoch: 1
training set: average loss: 2.2810, accuracy: 10020/60000 (17%)
test set: average loss: 2.2791, accuracy: 1717/10000 (17%)

Epoch: 2
training set: average loss: 2.2514, accuracy: 16647/60000 (28%)
test set: average loss: 2.2490, accuracy: 2920/10000 (29%)

Epoch: 3
training set: average loss: 2.2190, accuracy: 24978/60000 (42%)
test set: average loss: 2.2161, accuracy: 4297/10000 (43%)


In [23]:
def _load_result(fname):
    # fname is one of the files written by the training cell (np.save adds
    # the '.npy' extension); both the WTB=False ('norm') and WTB=True ('wtb')
    # runs must have been saved for all four loads to succeed
    return np.load(join(RES, fname))

test_stats_norm = _load_result('test_stats_norm.npy')
train_stats_norm = _load_result('train_stats_norm.npy')
test_stats_wtb = _load_result('test_stats_wtb.npy')
train_stats_wtb = _load_result('train_stats_wtb.npy')