# Learning Digital Circuits: A Journey Through Weight Invariant Self-Pruning Neural Networks

This notebook contains the source code for the paper ["Learning Digital Circuits: A Journey Through Weight Invariant Self-Pruning Neural Networks"](https://arxiv.org/pdf/1909.00052.pdf).

To cite:

```
@ARTICLE{2019arXiv190900052A,
       author = {{Agrawal}, Amey and {Karlupia}, Rohit},
        title = "{Learning Digital Circuits: A Journey Through Weight Invariant Self-Pruning Neural Networks}",
      journal = {arXiv e-prints},
         year = "2019",
        month = "Aug",
          eid = {arXiv:1909.00052},
archivePrefix = {arXiv},
       eprint = {1909.00052}
}

```

### Install Dependencies

We need TensorBoardX for PyTorch TensorBoard support and TensorFlow 2 for the TensorBoard notebook plugin.

In [0]:
# tensorboardX provides the SummaryWriter used for logging from PyTorch below.
# NOTE(review): these installs are not reproducible as written -- tensorboardx
# is unpinned and tensorflow==2.0.0-rc0 is a pre-release; pin versions that
# match your runtime (and consider %pip so the kernel's environment is used).
!pip install tensorboardx
# TensorFlow 2 is only needed for the TensorBoard notebook extension.
!pip install -q tensorflow==2.0.0-rc0
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [0]:
# Clear old TensorBoard logs (the ./runs directory that the %tensorboard cell
# reads) so each run of the notebook starts from a clean slate
!rm -rf runs

In [0]:
import argparse
import copy
import math
import random

import numpy as np

import torch
from torch.autograd import Function, Variable
from torch.distributions.bernoulli import Bernoulli
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torchvision import datasets, transforms
from tensorboardX import SummaryWriter

### Define some useful utility functions

In [0]:
# Tensor-friendly version of `x1 if cond else x2`
def where(cond, x1, x2):
    """Element-wise ``x1 if cond else x2``.

    ``cond`` is cast to float and used as a 0/1 mask, so the result is the
    arithmetic blend ``mask * x1 + (1 - mask) * x2``. ``x1`` and ``x2`` may
    be tensors or Python scalars; with scalar branches the result is a float
    tensor shaped like ``cond``.
    """
    mask = cond.float()
    from_x1 = mask * x1
    from_x2 = (1 - mask) * x2
    return from_x1 + from_x2

In [0]:
# Extend the tuple below to binarize additional layer types
def should_binarize(x):
    """Return True if module ``x`` is a layer whose weights get binarized."""
    binarizable_types = (nn.Conv2d, nn.Linear)
    return isinstance(x, binarizable_types)

In [0]:
# A singleton which allows loading the shared writer from anywhere
class TensorBoard:
    """Process-wide holder for a single, lazily created ``SummaryWriter``."""

    _writer = None

    @classmethod
    def get_writer(cls):
        """Return the shared writer, creating ``SummaryWriter()`` on first use."""
        if cls._writer is None:
            cls._writer = SummaryWriter()
        return cls._writer

    @classmethod
    def reset_writer(cls):
        """Close the current writer (if any) and forget it.

        Closing flushes pending events and releases the event file; the next
        ``get_writer()`` call then creates a fresh ``SummaryWriter``.
        """
        if cls._writer is not None:
            cls._writer.close()
        cls._writer = None

In [0]:
# Prints some info about network layers
def print_net_stats(net=None):
    """Print per-parameter magnitude statistics for ``net``.

    For every named parameter this prints its element count, the mean and
    standard deviation of its absolute values, and how many entries satisfy
    ``|w| >= 0.5``. Parameters whose name ends in ``invert`` also print
    their (absolute) scalar value.

    Args:
        net: the ``nn.Module`` to inspect. If omitted, a notebook-global
            named ``net`` is used (the original calling convention).

    Raises:
        NameError: if ``net`` is omitted and no global ``net`` exists.
    """
    if net is None:
        try:
            net = globals()["net"]
        except KeyError:
            raise NameError(
                "print_net_stats() needs a model: pass one explicitly "
                "or define a global `net`") from None

    print("--------------------- Network Statistics --------------------------")
    for name, param in net.named_parameters():
        print(name)
        # Statistics only: detach so no autograd graph is built.
        magnitude = param.detach().abs()
        n_elems = magnitude.nelement()
        n_on = int((magnitude >= 0.5).sum().item())
        print(f"""\
            Number of units: {n_elems},
            Mean: {magnitude.mean()},
            Std: {magnitude.std()},
            Units >= 0.5: {n_on} ({n_on * 100 / n_elems:.2f}%)""")
        if name.endswith("invert"):
            print(name, magnitude.item())

### Define Normalization layers

In [0]:
class HardInvertLayer(nn.Module):
    """Parameter-free layer that flips every activation: ``x -> 1 - x``."""

    def forward(self, x):
        flipped = -x + 1
        return flipped

In [0]:
class SoftInvertLayer(nn.Module):
    """Learnable blend between identity and inversion.

    ``forward`` returns ``x * (1 - p) + (1 - x) * p``, where ``p`` is a single
    learnable scalar drawn from ``torch.rand`` at construction and clamped in
    place to [0, 1] on every call: ``p == 0`` passes ``x`` through and
    ``p == 1`` yields ``1 - x``.
    """

    def __init__(self):
        super().__init__()
        self.invert = nn.Parameter(torch.rand(1))

    def forward(self, x):
        # Project the parameter back into [0, 1] in place (outside autograd).
        self.invert.data.clamp_(0, 1)
        p = self.invert
        passthrough = x * (1 - p)
        inverted = (1 - x) * p
        return passthrough + inverted

### Define a helper class to convert a regular network into a binarized network

In [0]:
class BinaryConnect():
    """Temporarily replaces a model's real-valued weights with binary values.

    The real-valued weights stay the master copy. A training step uses it as::

        bc.binarize(on, off, random_mag, epsilon)  # snapshot, then binarize
        loss = criterion(model(x), y); loss.backward()
        bc.restore()                                # put real weights back
        optimizer.step()
        bc.clip()                                   # keep weights in [low, high]

    Only modules selected by ``should_binarize`` are touched. When
    ``binarize_bias`` is true their biases are handled as well; layers built
    without a bias (``bias is None``) are skipped for the bias part.
    """

    def __init__(self, model, binarize_bias):
        self.model = model
        self.binarize_bias = binarize_bias

    # Add any layers with learnable params to `should_binarize`
    def _get_modules(self):
        """All submodules of the model (recursively) that should be binarized."""
        return [x for x in self.model.modules() if should_binarize(x)]

    def _get_bias_modules(self):
        """Binarizable modules whose bias is handled (empty unless binarize_bias)."""
        if not self.binarize_bias:
            return []
        return [x for x in self._get_modules() if x.bias is not None]

    # For weight clipping
    def clip(self, low=0, high=1.0):
        """Clamp the real-valued weights (and handled biases) to [low, high] in place."""
        for module in self._get_modules():
            module.weight.data.clamp_(low, high)
        for module in self._get_bias_modules():
            module.bias.data.clamp_(low, high)

    # Binarizes values based on |x| >= 0.5
    def _get_binary(self, x, on, off):
        """Map each entry of ``x`` to ``on`` if ``|x| >= 0.5`` else ``off``."""
        return where(torch.abs(x) >= 0.5, on, off)

    # Call during training loop to convert all weights to binary values
    def binarize(self, on, off, random_mag, epsilon):
        """Snapshot the real-valued params, then overwrite them with binary values.

        Args:
            on: value for entries with ``|w| >= 0.5``.
            off: value for the remaining entries.
            random_mag: upper bound of the random jitter added to ``on``.
            epsilon: probability that ``on`` is increased by
                ``random.uniform(0, random_mag)`` for this call.

        ``restore()`` must be called before the next ``binarize()``; otherwise
        the snapshot is overwritten with already-binarized values.
        """
        self._save_params()

        # random.random() is always drawn; uniform() only when the jitter fires.
        if random.random() < epsilon:
            on = on + random.uniform(0, random_mag)

        for module in self._get_modules():
            module.weight.data.copy_(self._get_binary(module.weight.data, on, off))
        for module in self._get_bias_modules():
            module.bias.data.copy_(self._get_binary(module.bias.data, on, off))

    # Stores weights during binarization so that they can be restored later
    def _save_params(self):
        """Clone weights (and handled biases) so ``restore()`` can undo binarization."""
        self.saved_weights = [x.weight.data.clone() for x in self._get_modules()]

        if self.binarize_bias:
            self.saved_biases = [x.bias.data.clone() for x in self._get_bias_modules()]

    # Restores original weights back
    def restore(self):
        """Copy the snapshot taken by the last ``binarize()`` back into the model."""
        for module, saved in zip(self._get_modules(), self.saved_weights):
            module.weight.data.copy_(saved)

        if self.binarize_bias:
            for module, saved in zip(self._get_bias_modules(), self.saved_biases):
                module.bias.data.copy_(saved)

### Define a simple MLP model

In [0]:
class SimpleModel(nn.Module):
    """Fully connected classifier made of ``norm(tanh(linear(x)))`` blocks.

    Hidden blocks ``fc1``..``fc4`` map ``in_features -> hidden_units -> ...
    -> hidden_units``. The output block ``fcx`` maps ``hidden_units ->
    out_features``, and its result goes through ``log_softmax`` over dim 1.
    Every ``hist_freq`` forward calls, each block logs TensorBoard histograms
    of its ``tanh`` activations and of the normalized output.

    Args:
        in_features: width of the (flattened) input.
        out_features: number of output classes.
        hidden_units: width of every hidden layer.
        bias: whether the Linear layers have a bias (zero-initialised).
        binary: if true, Linear weights are initialised with
            ``bernoulli_(0.001)``, i.e. almost all 0 with rare 1s.
        norm_type: "soft_invert", "hard_invert", or anything else for
            ``BatchNorm1d``.
        hist_freq: forward calls between histogram logs; 0 disables them.
    """

    # Number of hidden blocks (fc1..fc4 / norm1..norm4) applied before fcx/normx.
    _NUM_HIDDEN_BLOCKS = 4

    def __init__(self,
                 in_features,
                 out_features,
                 hidden_units,
                 bias,
                 binary,
                 norm_type,
                 hist_freq):

        super(SimpleModel, self).__init__()

        # Config Options
        self._bias = bias
        self._binary = binary
        self._norm_type = norm_type
        # Tensorboard related variables
        self._num_batches = 0
        self._hist_freq = hist_freq
        self._writer = TensorBoard.get_writer()

        # Define layers: four hidden blocks followed by the output block "x"
        self.fc1 = nn.Linear(in_features, hidden_units, bias=bias)
        self.norm1 = self._get_norm(hidden_units)

        self.fc2 = nn.Linear(hidden_units, hidden_units, bias=bias)
        self.norm2 = self._get_norm(hidden_units)

        self.fc3 = nn.Linear(hidden_units, hidden_units, bias=bias)
        self.norm3 = self._get_norm(hidden_units)

        self.fc4 = nn.Linear(hidden_units, hidden_units, bias=bias)
        self.norm4 = self._get_norm(hidden_units)

        self.fcx = nn.Linear(hidden_units, out_features, bias=bias)
        self.normx = self._get_norm(out_features)

        self._reset_parameters()

    # Returns normalization function
    def _get_norm(self, in_features):
        """Build the layer applied after each tanh, selected by ``norm_type``."""
        if self._norm_type == "soft_invert":
            return SoftInvertLayer()
        elif self._norm_type == "hard_invert":
            return HardInvertLayer()
        else:
            return nn.BatchNorm1d(in_features, eps=1e-4, momentum=0.15)

    def _reset_parameters(self):
        """Re-initialise the direct children selected by ``should_binarize``."""
        layers = [x for x in self._modules.values() if should_binarize(x)]

        if self._binary:
            # Sparse 0/1 initialisation: each weight is 1 with probability 0.001.
            for layer in layers:
                layer.weight.data.bernoulli_(0.001)

        if self._bias:
            for layer in layers:
                layer.bias.data.zero_()

    def _forward_block(self, x, i):
        """Apply block ``i`` (1..4 or "x") as ``norm_i(tanh(fc_i(x)))``, logging histograms."""
        a = torch.tanh(self._modules[f"fc{i}"](x))

        x = self._modules[f"norm{i}"](a)

        # hist_freq == 0 disables histogram logging instead of dividing by zero.
        if self._hist_freq and self._num_batches % self._hist_freq == 0:
            self._writer.add_histogram(f"Activation {i}",
                                       a.detach().cpu().numpy(),
                                       self._num_batches)
            self._writer.add_histogram(f"Normalized Activation {i}",
                                       x.detach().cpu().numpy(),
                                       self._num_batches)

        return x

    def forward(self, x):
        """Return log-probabilities over ``out_features`` (log_softmax on dim 1).

        Assumes ``x`` is ``(batch, in_features)`` -- TODO confirm for other callers.
        """
        self._num_batches += 1

        # Apply every hidden block, fc1..fc4 inclusive; fc4/norm4 are built in
        # __init__ and would otherwise be dead parameters.
        for i in range(1, self._NUM_HIDDEN_BLOCKS + 1):
            x = self._forward_block(x, i)

        x = self._forward_block(x, 'x')
        x = F.log_softmax(x, dim=1)

        return x

In [0]:
def train(args):
    """Build a SimpleModel from ``args``, train it, and return it.

    Seeds torch, numpy and ``random`` with 0, loads MNIST or Fashion-MNIST
    through torchvision, and optimizes with Adam on an NLL loss. After every
    epoch the model is validated with ``on=1``. On epochs divisible by
    ``args.full_eval_freq`` in binary mode, validation instead sweeps
    ``on`` over ``np.arange(0, 4.1, 0.1)``.
    """
    print("--------------------- Starting Training ---------------------------")

    # Set seeds
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(0)

    # Init model and (binary mode only) the binarization helper
    net = SimpleModel(args.in_features, args.out_features, args.hidden_units,
                      args.bias, args.binary, args.norm, args.hist_freq)
    bc = None
    if args.binary:
        bc = BinaryConnect(net, args.bias)

    if args.cuda:
        net.cuda()

    print("--------------------- Training Parameters --------------------------")
    print(args)
    print(net)
    print("--------------------------------------------------------------------")

    # Init dataset
    if args.dataset == "fashion":
        dataset_provider = datasets.FashionMNIST
    else:
        dataset_provider = datasets.MNIST

    loader_kwargs = {'num_workers': 2, 'pin_memory': True} if args.cuda else {}

    train_set = dataset_provider('./data', train=True, download=True,
                                 transform=transforms.ToTensor())
    train_loader = data.DataLoader(train_set, batch_size=args.batch_size,
                                   shuffle=True, **loader_kwargs)
    test_set = dataset_provider('./data', train=False,
                                transform=transforms.ToTensor())
    test_loader = data.DataLoader(test_set, batch_size=args.test_batch_size,
                                  shuffle=True, **loader_kwargs)

    # Define optimizer and loss function
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    criterion = nn.NLLLoss()

    for epoch in range(1, args.epochs + 1):
        if epoch % args.print_freq == 0:
            print(f"Epoch {epoch}:")

        train_epoch(epoch, net, bc, criterion, optimizer, train_loader, args,
                    on=1, off=0, random_mag=0, epsilon=1)

        # Full evaluation sweeps the "on" weight value; otherwise just on=1.
        if epoch % args.full_eval_freq == 0 and args.binary:
            on_values = np.arange(0, 4.1, 0.1)
        else:
            on_values = [1]

        for on in on_values:
            val_epoch(epoch, net, bc, criterion, test_loader, args,
                      on=on, off=0, random_mag=0, epsilon=1)

    print("--------------------- Training Completed --------------------------")

    return net

In [0]:
def train_epoch(epoch,
                net,
                bc,
                creterion,
                optimizer,
                train_loader,
                args,
                on,
                off,
                random_mag,
                epsilon):
    """Train ``net`` for one epoch and log mean loss/accuracy to TensorBoard.

    In binary mode (``args.binary``), inputs are thresholded to {0, 1} at
    0.5. The forward and backward passes run on weights binarized by ``bc``;
    the real-valued weights are restored before ``optimizer.step()`` and
    clipped afterwards.

    Args:
        epoch: epoch number, used as the TensorBoard step.
        net: model to train (switched to train mode).
        bc: ``BinaryConnect`` helper; only used when ``args.binary``.
        creterion: loss applied to ``(net(inputs), target)``.
        optimizer: optimizer over ``net``'s parameters.
        train_loader: iterable of ``(inputs, target)`` batches.
        args: parsed options; reads ``cuda``, ``binary`` and ``print_freq``.
        on, off, random_mag, epsilon: forwarded to ``bc.binarize``.
    """
    losses = 0
    accs = 0
    num_batches = 0

    net.train()

    # `inputs` (not `data`) avoids shadowing the torch.utils.data module alias.
    for inputs, target in train_loader:
        num_batches += 1
        if args.cuda:
            inputs, target = inputs.cuda(), target.cuda()

        # Flatten each sample into a vector.
        inputs = inputs.view(inputs.shape[0], -1)

        if args.binary:
            # Binarize input
            inputs = where(torch.abs(inputs) >= 0.5, 1, 0)

        optimizer.zero_grad()

        if args.binary:
            bc.binarize(on, off, random_mag, epsilon)
        try:
            output = net(inputs)
            loss = creterion(output, target)
            loss.backward()
        finally:
            # Always put the real-valued weights back, so an exception or an
            # interrupt here cannot leave the model holding binarized weights.
            if args.binary:
                bc.restore()

        optimizer.step()

        if args.binary:
            bc.clip()

        y_pred = torch.max(output, 1)[1]
        accs += (torch.mean((y_pred == target).float())).item()
        losses += loss.item()

    # Average over the number of batches seen (not the last 0-based index,
    # which is one too small and inflates both averages).
    n_batches = max(num_batches, 1)
    mean_acc = accs / n_batches
    mean_loss = losses / n_batches

    writer = TensorBoard.get_writer()

    writer.add_scalar("Train Accuracy", mean_acc, epoch)
    writer.add_scalar("Train Loss", mean_loss, epoch)

    for name, param in net.named_parameters():
        writer.add_histogram(name, param.detach().cpu().numpy(), epoch)

    if epoch % args.print_freq == 0:
        print("Train Loss={0:.3f}, Train Accuracy={1:.3f}".format(mean_loss, mean_acc))

In [0]:
def val_epoch(epoch,
              net,
              bc,
              creterion,
              test_loader,
              args,
              on,
              off,
              random_mag,
              epsilon):
    """Evaluate ``net`` on ``test_loader``; log and optionally print mean loss/accuracy.

    In binary mode the weights are binarized once, using ``on``/``off``/
    ``random_mag``/``epsilon``, for the whole pass and restored at the end.
    The TensorBoard tags include ``on`` formatted to one decimal, e.g.
    "Validation Accuracy (1.0)".

    Args:
        epoch: epoch number, used as the TensorBoard step.
        net: model to evaluate.
        bc: ``BinaryConnect`` helper; only used when ``args.binary``.
        creterion: loss applied to ``(net(inputs), target)``.
        test_loader: iterable of ``(inputs, target)`` batches.
        args: parsed options; reads ``cuda``, ``binary`` and ``print_freq``.
        on, off, random_mag, epsilon: forwarded to ``bc.binarize``.
    """
    # NOTE(review): eval() is only called outside binary mode, so in binary
    # mode BatchNorm layers (norm == "batch_norm") normalise with per-batch
    # statistics and update their running stats on test data. Kept as-is;
    # confirm this is intended.
    if not args.binary:
        net.eval()

    if args.binary:
        bc.binarize(on, off, random_mag, epsilon)

    losses = 0
    accs = 0
    num_batches = 0

    try:
        # Evaluation needs no gradients; skip building autograd graphs.
        with torch.no_grad():
            for inputs, target in test_loader:
                num_batches += 1
                if args.cuda:
                    inputs, target = inputs.cuda(), target.cuda()
                inputs = inputs.view(inputs.shape[0], -1)

                if args.binary:
                    inputs = where(torch.abs(inputs) >= 0.5, 1, 0)
                output = net(inputs)
                loss = creterion(output, target)

                y_pred = torch.max(output, 1)[1]
                accs += (torch.mean((y_pred == target).float())).item()
                losses += loss.item()
    finally:
        # Restore the real-valued weights even if evaluation is interrupted.
        if args.binary:
            bc.restore()

    n_batches = max(num_batches, 1)
    mean_acc = accs / n_batches
    mean_loss = losses / n_batches

    writer = TensorBoard.get_writer()

    writer.add_scalar(f"Validation Accuracy ({on:.1f})", mean_acc, epoch)
    writer.add_scalar(f"Validation Loss ({on:.1f})", mean_loss, epoch)

    if epoch % args.print_freq == 0:
        print("    Weight={0:.1f}, Validation Loss={1:.3f}, Validation Accuracy={2:.3f}"
              .format(on, mean_loss, mean_acc))

In [0]:
def parse_args(argv):
    """Parse the training options from ``argv`` (a list of CLI-style strings).

    Pass ``argv`` explicitly: argparse falls back to ``sys.argv[1:]`` when it
    is None, which inside a notebook holds the kernel's own arguments.

    Returns:
        argparse.Namespace with the training configuration.
    """
    parser = argparse.ArgumentParser(description='Binary Neural Networks')

    parser.add_argument('--binary',
                        default=False,
                        action="store_true",
                        help='Binarize weights and inputs')
    parser.add_argument('--cuda',
                        default=False,
                        action="store_true",
                        help='Use cuda or not')
    parser.add_argument('--bias',
                        default=False,
                        action="store_true",
                        help='Use bias')
    parser.add_argument('--in_features',
                        type=int,
                        default=784,
                        help='Input features dim')
    parser.add_argument('--out_features',
                        type=int,
                        default=10,
                        help='Output features dim')
    parser.add_argument('--hidden_units',
                        type=int,
                        default=4000,
                        help='Network Hidden Units')
    parser.add_argument('--batch_size',
                        type=int,
                        default=200,
                        help='Batch size')
    parser.add_argument('--test_batch_size',
                        type=int,
                        default=1000,
                        help='Test batch size')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='Learning rate')
    parser.add_argument('--epochs',
                        type=int,
                        default=60,
                        help='Epochs')
    parser.add_argument('--print_freq',
                        type=int,
                        default=1,
                        help='Print frequency')
    parser.add_argument('--hist_freq',
                        type=int,
                        default=100,
                        help='Number of batches between activation histogram')
    parser.add_argument('--full_eval_freq',
                        type=int,
                        default=10,
                        help='Number of epochs between full evaluation')
    parser.add_argument('--norm',
                        default='batch_norm',
                        choices=['soft_invert', 'hard_invert', 'batch_norm'],
                        help='Normalization function')
    parser.add_argument('--dataset',
                        default='mnist',
                        choices=['fashion', 'mnist'],
                        help='Dataset')

    args = parser.parse_args(argv)

    return args

### Load up TensorBoard

In [0]:
# The notebook TensorBoard plugin is somewhat buggy; for a better experience,
# use ngrok instead. Shows the logs stored under ./runs.
%tensorboard --logdir runs

### Start Training!

In [0]:
def main():
    """Train a binarized SimpleModel in a fresh TensorBoard run, then print weight stats.

    ``--cuda`` is passed only when ``torch.cuda.is_available()``, so this cell
    also runs on CPU-only runtimes instead of failing on ``.cuda()``.
    """
    TensorBoard.reset_writer()
    cli = ["--binary"]
    if torch.cuda.is_available():
        cli = ["--cuda"] + cli
    args = parse_args(cli)
    net = train(args)
    print_net_stats(net)

main()