# Install Repository

In [None]:
#first time run
!git clone https://github.com/antonioverdi/MLReproChallenge.git
#At the moment all requirements are met by Google Colab already I believe
import os
os.chdir("MLReproChallenge")
!pip install -r requirements.txt

In [None]:
#subsequent runs
!git pull
!pip install -r requirements.txt

# Train Models
### main.py Arguments
<table>
    <tr><th>Argument</th><th>Default</th><th>help</th></tr>
    <tr><td>--arch, -a</td><td>str: 'resnet56'</td><td>model architecture</td></tr>
    <tr><td>-j, --workers</td><td>int: 4</td><td>number of data loading workers</td></tr>
    <tr><td>--epochs</td><td>int: 182</td><td>number of total epochs to run</td></tr>
    <tr><td>--start-epoch</td><td>int: 0</td><td>manual epoch number, for restarts</td></tr>
    <tr><td>-b, --batch-size</td><td>int: 128</td><td>mini-batch size</td></tr>
    <tr><td>--lr, --learning-rate</td><td>float: 0.1</td><td>initial learning rate</td></tr>
    <tr><td>--momentum</td><td>float: 0.9</td><td>momentum</td></tr>
    <tr><td>--weight-decay, --wd</td><td>float: 2e-4</td><td>weight decay</td></tr>
    <tr><td>--print-freq, -p</td><td>int: 50</td><td>print frequency</td></tr>
    <tr><td>--resume</td><td>str: ' '</td><td>path to latest checkpoint</td></tr>
    <tr><td>-e, --evaluate</td><td>bool: False</td><td>evaluate model on validation set</td></tr>
    <tr><td>--pretrained</td><td>bool: False</td><td>use pre-trained model</td></tr>
    <tr><td>--half</td><td>bool: False</td><td>use half-precision (16-bit) floats</td></tr>
    <tr><td>--save-dir</td><td>str: 'save_temp'</td><td>Directory used to save trained models</td></tr>
    <tr><td>--save-every</td><td>int: 10</td><td>Save checkpoint at every specified number of epochs</td></tr>
    <tr><td>--colab</td><td>bool: False</td><td>Set this to true when running in Google Colab</td></tr>
    <tr><td>--snip</td><td>bool: False</td><td>Set this to true to run SNIP experiments</td></tr>
    <tr><td>--snip_compression</td><td>float: 0.5</td><td>fraction of weights to keep, e.g. 0.25 retains 25% of the weights</td></tr>
   </table>

In [None]:
%%time
# Train with main.py's defaults (see the argument table above); --colab is set
# because this notebook is meant to run in Google Colab.
!python main.py --colab

# Test Models
### test.py Arguments
(Currently no arguments are supported; test.py still needs to be made more Colab-friendly.)
<table>
    <tr><th>Argument</th><th>Default</th><th>help</th></tr>
    <tr><td>--arch, -a</td><td>str: 'resnet56'</td><td>model architecture</td></tr>
   </table>

<h4>SNIP experiments</h4>

In [None]:
# SNIP experiment: train the resnet56_snip architecture with --snip enabled
# (default --snip_compression per the table above), writing checkpoints to
# SNIP_checkpoints every 25 epochs.
!python main.py --colab --snip --arch="resnet56_snip" --save-dir="SNIP_checkpoints" --save-every=25

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import types

In [2]:
# Toy batch for exercising the SNIP code: 10 random inputs with 5 features and
# integer labels (first five samples are class 0, the rest class 1).
# Seeding makes X - and the model initialisation in later cells - reproducible.
SEED = 0
torch.manual_seed(SEED)
X = torch.rand((10, 5))
y = torch.ones(10, dtype=torch.long)
y[:5] = 0

In [3]:
class TestModel(nn.Module):
    """Tiny two-layer MLP (5 -> 3 -> 5) used to sanity-check the SNIP code.

    Args:
        init_strategy: "kaiming" or "xavier" re-initialises the weights of every
            Linear/Conv2d submodule with the matching helper; any other value
            leaves PyTorch's default initialisation in place.
    """

    def __init__(self, init_strategy="kaiming"):
        super().__init__()
        self.layer1 = nn.Linear(5, 3)
        self.layer2 = nn.Linear(3, 5)

        # Resolve the initialiser lazily so unknown strategies never touch the helpers.
        if init_strategy == "kaiming":
            initializer = _kaiming_weights_init
        elif init_strategy == "xavier":
            initializer = _xavier_weights_init
        else:
            initializer = None
        if initializer is not None:
            self.apply(initializer)

    def forward(self, x):
        """Return un-normalised scores (last dimension 5); no softmax is applied."""
        hidden = F.relu(self.layer1(x))
        return self.layer2(hidden)

In [4]:
def _kaiming_weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)
        
def _xavier_weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.xavier_normal_(m.weight)

def snip_forward_conv2d(self, x):
    """Conv2d forward that convolves with ``weight * weight_mask``.

    Meant to be bound onto an ``nn.Conv2d`` instance (via ``types.MethodType``)
    after a ``weight_mask`` attribute has been attached to it.
    NOTE(review): calls F.conv2d directly, so non-'zeros' padding_mode is ignored.
    """
    masked_weight = self.weight * self.weight_mask
    return F.conv2d(
        x,
        masked_weight,
        self.bias,
        self.stride,
        self.padding,
        self.dilation,
        self.groups,
    )

def snip_forward_linear(self, x):
    """Linear forward that multiplies by ``weight * weight_mask``.

    Meant to be bound onto an ``nn.Linear`` instance (via ``types.MethodType``)
    after a ``weight_mask`` attribute has been attached to it.
    """
    masked_weight = self.weight * self.weight_mask
    return F.linear(x, masked_weight, self.bias)

In [5]:
def snip_mask(model, batch, labels, compression):
    """Compute SNIP connection-sensitivity keep-masks for every Conv2d/Linear layer.

    Each prunable layer gets an all-ones ``weight_mask`` Parameter and its
    ``forward`` is monkey-patched to use ``weight * weight_mask``. One
    forward/backward pass on ``batch`` yields d(loss)/d(mask); its absolute
    value is the connection saliency. Saliencies are normalised over the whole
    network and the top ``compression`` fraction of connections is kept.

    Args:
        model: network to score; modified in place (masks attached, forwards patched).
        batch: input batch fed to ``model``.
        labels: integer class targets for ``batch``.
        compression: fraction of connections to keep, in (0, 1];
            e.g. 0.25 retains 25% of the weights.

    Returns:
        list of float tensors (1. = keep, 0. = prune), one per prunable layer in
        ``model.modules()`` order, each shaped like that layer's weight. Ties at
        the threshold are kept, so slightly more than the requested fraction
        may survive.

    Raises:
        ValueError: if ``compression`` is outside (0, 1] or ``model`` has no
            Conv2d/Linear layers.
    """
    if not 0 < compression <= 1:
        raise ValueError(f"compression must be in (0, 1], got {compression}")

    prunable = [m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
    if not prunable:
        raise ValueError("model has no Conv2d or Linear layers to prune")

    for layer in prunable:
        # A Parameter so autograd computes the gradient w.r.t. each connection's mask.
        layer.weight_mask = nn.Parameter(torch.ones_like(layer.weight))
        if isinstance(layer, nn.Conv2d):
            layer.forward = types.MethodType(snip_forward_conv2d, layer)
        else:
            layer.forward = types.MethodType(snip_forward_linear, layer)

    # Compute gradients of weight_mask (connection sensitivities).
    model.zero_grad()
    out = model(batch)
    # cross_entropy applies log_softmax itself, so it is correct for models that
    # return raw logits (TestModel does). For models that already return
    # log-probabilities it equals nll_loss, because log_softmax is idempotent.
    loss = F.cross_entropy(out, labels)
    loss.backward()

    absolute_saliency = [torch.abs(layer.weight_mask.grad) for layer in prunable]

    saliency_scores = torch.cat([torch.flatten(s) for s in absolute_saliency])
    denominator = torch.sum(saliency_scores)
    saliency_scores.div_(denominator)

    # Keep at least one connection so topk/indexing never sees an empty result.
    kappa = max(1, int(len(saliency_scores) * compression))
    top_scores, _ = torch.topk(saliency_scores, kappa, sorted=True)
    threshold = top_scores[-1]

    return [((s / denominator) >= threshold).float() for s in absolute_saliency]

def _masked_grad_hook(keep_mask):
    """Build a gradient hook that multiplies incoming gradients by ``keep_mask``.

    Creating the hook through a factory binds ``keep_mask`` at call time; a
    closure defined directly in a loop would late-bind and every hook would
    see the last mask. (Pattern from https://github.com/mil-ad/snip/blob/master/train.py)
    """
    def hook(grads):
        return grads * keep_mask

    return hook


def apply_snip(model, connection_masks):
    """Prune ``model`` in place using keep-masks such as those from ``snip_mask``.

    For each Conv2d/Linear layer (in ``model.modules()`` order) the weights
    where the mask is 0 are set to zero, and a gradient hook is registered so
    those weights receive zero gradient and stay pruned during training.
    Biases are not masked.

    Args:
        model: network to prune; modified in place.
        connection_masks: iterable of 0/1 tensors, one per prunable layer,
            each shaped like that layer's weight.

    Raises:
        ValueError: if the number of masks differs from the number of prunable
            layers, or a mask's shape does not match its layer's weight. All
            checks run before any layer is modified.
    """
    prunable_layers = [layer for layer in model.modules()
                       if isinstance(layer, (nn.Conv2d, nn.Linear))]
    connection_masks = list(connection_masks)

    if len(prunable_layers) != len(connection_masks):
        raise ValueError(
            f"got {len(connection_masks)} masks for {len(prunable_layers)} prunable layers"
        )
    for index, (layer, mask) in enumerate(zip(prunable_layers, connection_masks)):
        if layer.weight.shape != mask.shape:
            raise ValueError(
                f"mask {index} has shape {tuple(mask.shape)}, "
                f"expected {tuple(layer.weight.shape)}"
            )

    for layer, mask in zip(prunable_layers, connection_masks):
        # Set the masked weights to zero (biases are ignored).
        layer.weight.data[mask == 0.] = 0.
        # register_hook runs on every gradient computed for this weight,
        # keeping the pruned entries' gradients at zero.
        layer.weight.register_hook(_masked_grad_hook(mask))

In [6]:
# Score the toy model's connections with SNIP on (X, y) and prune it in place,
# keeping about half of the weights.
# The optimizer is built BEFORE snip_mask on purpose: snip_mask attaches
# weight_mask Parameters, which would otherwise be handed to the optimizer too.
model = TestModel(init_strategy="xavier")
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
masks = snip_mask(model, batch=X, labels=y, compression=0.50)
apply_snip(model, connection_masks=masks)


In [9]:
# Lists have no .next() method in Python 3 (that was the Python 2 iterator API);
# the first element is obtained with the built-in next() on an iterator.
testlist = [5, 6, 7]
print(next(iter(testlist)))