# Binary Networks using ```PyTorch```

Herein, we will show how to binarize neural networks using PyTorch following:

1. **XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks**, *Mohammad Rastegari, Vicente Ordonez, Joseph Redmon, Ali Farhadi*, ECCV 2016 [arxiv](https://arxiv.org/abs/1603.05279)

2. **Improved training of binary networks for human pose estimation and image recognition**, *Adrian Bulat, Georgios Tzimiropoulos, Jean Kossaifi, Maja Pantic*, arxiv 2019 [arxiv](https://arxiv.org/pdf/1904.05868)

In order to run this code python 3.7 and PyTorch version '1.1.0' or later is recommanded.

As in the previous demos, we demonstrate the method on the MNIST dataset, which consists of images of digits between 0 and 9 (60,000 images for for training and 10,000 for testing). The task is to predict, given an image, which digit it represents. The neural architecture that will use is based on the one used in the previous part of this hands-on demo.

## Binarization function (sign)

The weights and features of a NN are typically binarized using the ```sign(x)``` function. Since the derivative of the sign is an impulse, in practice we approximate its derivative using an STE.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F

import matplotlib.pyplot as plt

In [2]:
class BinaryActivation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        #TODO: Update 
        return grad_input

## Prepare the data

Define the batch size alongside the target device on which the network is going to be trained and tested. Herein we also instantiate and prepare our data. Notice that we are going to substract the mean and divide by the std all of our input images.

In [3]:
# choose the size of your minibatch
batch_size = 32

device = 'cpu' # to run on GPU use 'cuda'

transformation = transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('~/datasets/', train=True, download=True,
                   transform=transformation), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('~/datasets/', train=False, 
                   transform=transformation), batch_size=batch_size, shuffle=True)

## Create the binary helper

This class will take our of the weights binarization process.

In [4]:
class Binarizer():
    def __init__(self, model, ignore_first_last=True):
        # create a buffer for the weights and a list of the to be binarized modules
        self.saved_params = []
        self.target_modules = []
        for m in model.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                self.saved_params.append(m.weight.data.clone())
                self.target_modules.append(m.weight)

        # Don't binarize the first and the last module
        if ignore_first_last and len(self.target_modules)>2:
            self.saved_params = self.saved_params[1:-1]
            self.target_modules = self.target_modules[1:-1]

        self.num_of_params = len(self.target_modules)
         
    def binarize_weights(self):
        for m in self.target_modules:
            # mean center the real params
            neg_mean = m.data.mean(1, keepdim=True).mul(-1)
            m.data = m.data.add(neg_mean)

            # clamp the real params
            m.data = m.data.clamp(-1.0, 1.0)

        # store the real paramters in a buffer before binarizing them
        self._save_params()

        # binarize the parameters inside the module using the sign function
        self._binarize_weights()

    def _save_params(self):
        for i in range(len(self.saved_params)):
            self.saved_params[i].copy_(self.target_modules[i].data)

    def _binarize_weights(self):
        for m in self.target_modules:
            n = m.data[0].nelement()
            s = m.data.size()
            if len(s) == 4:
                mean = m.data.norm(1, 3, keepdim=True).sum([1,2], keepdim=True).div(n)
            elif len(s) == 2:
                mean = m.data.norm(1, 1, keepdim=True).div(n)
            m.data = m.data.sign().mul(mean)

    def restore_weights(self):
        for index in range(self.num_of_params):
            self.target_modules[index].data.copy_(self.saved_params[index])

    def update_weight_gradients(self):
        for module in self.target_modules:
            weight = module.data
            # Use a simple STE and multiple the grad with a large number

## Define the network

Here we define a simple neural network. Notice that the first and last layer are kept real.
As opposed to the previous variant, herein we add a series of batch norm and quantization layers removing at the same time the dropout.

In [5]:
#TODO: Use the batchnorm and the sign function in the right place (2 batch norm, 2 sign)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(10) # for binarization
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn_fc1 = nn.BatchNorm1d(320) # for binarization
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10) # We output log-probabilities for 10 classes

    def forward(self, x):
        # the input is (bs, 1, 28, 28)
        x = self.conv1(x) # Loose 2 pixels on each side
        
        # x is now (bs, 10, 24, 24)
        x = F.max_pool2d(x, 2) # divide resolution  by two
        x = F.relu(x)
        
        x = self.conv2(x)
        # x is (bs, 20, 8, 8)
        
        x = F.max_pool2d(x, 2)
        # x is (bs, 20, 4, 4)
        x = F.relu(x)
        
        x = x.view(-1, 320) 
        # we flattened x (320 = 20*4*4)
        
        x = F.relu(self.fc1(x))
        # x is (bs, 50)
        #x = F.dropout(x, training=self.training, p=0.1)

        x = self.fc2(x)
        # x is (bs, 10)
        return F.log_softmax(x, dim=1)

In [6]:
# instantiate the network
model = Net()
model = model.to(device)

# binarize it
bin = Binarizer(model)

Define the criterion and the optimizer. While we could use SGD, it is generally easier to train binary networks using Adam.

In [7]:
# define the optimizer, choose the best parameters
#optimizer = ??
criterion = nn.CrossEntropyLoss()

Define the training and testing loop. Notice the differences that arrive due to the binarization process.

In [8]:
# Making use of the above defined class modify this code to binarize the weights
n_epoch = 5 # Number of epochs

def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Send the data and label to the correct device
        data, target = data.to(device), target.to(device)
        
        # Important: do not forget to reset the gradients
        optimizer.zero_grad()
        
        # Pass the data through the networks
        output = model(data)
        
        # Compute the loss
        loss = criterion(output, target)
        
        # Backprogragate the gradient
        loss.backward()
        
        # Update the weights
        optimizer.step()
        
        # That's just printing some info...
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss))

def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        test_loss = criterion(output,target)
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('mean: {}'.format(test_loss))
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
       100. * correct / len(test_loader.dataset)))

Train the network for ```n_epoch```s.

In [9]:
for epoch in range(n_epoch):
    train(epoch)
    test()

mean: 4.802381954505108e-05

Test set: Average loss: 0.0000, Accuracy: 9088/10000 (90%)

mean: 2.6715517378761433e-05

Test set: Average loss: 0.0000, Accuracy: 9285/10000 (92%)



KeyboardInterrupt: 