<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><ul class="toc-item"><li><span><a href="#Get-Pretrained-Params" data-toc-modified-id="Get-Pretrained-Params-0.1"><span class="toc-item-num">0.1&nbsp;&nbsp;</span>Get Pretrained Params</a></span></li><li><span><a href="#Transfer-Weights" data-toc-modified-id="Transfer-Weights-0.2"><span class="toc-item-num">0.2&nbsp;&nbsp;</span>Transfer Weights</a></span></li><li><span><a href="#Prune-w/-SBP" data-toc-modified-id="Prune-w/-SBP-0.3"><span class="toc-item-num">0.3&nbsp;&nbsp;</span>Prune w/ SBP</a></span></li><li><span><a href="#Resume-Training" data-toc-modified-id="Resume-Training-0.4"><span class="toc-item-num">0.4&nbsp;&nbsp;</span>Resume Training</a></span></li><li><span><a href="#Get-Mask-&amp;-Prune-Network" data-toc-modified-id="Get-Mask-&amp;-Prune-Network-0.5"><span class="toc-item-num">0.5&nbsp;&nbsp;</span>Get Mask &amp; Prune Network</a></span></li><li><span><a href="#PTFLOPS" data-toc-modified-id="PTFLOPS-0.6"><span class="toc-item-num">0.6&nbsp;&nbsp;</span>PTFLOPS</a></span></li><li><span><a href="#Flop-weighted-importance" data-toc-modified-id="Flop-weighted-importance-0.7"><span class="toc-item-num">0.7&nbsp;&nbsp;</span>Flop weighted importance</a></span></li><li><span><a href="#Layer-index" data-toc-modified-id="Layer-index-0.8"><span class="toc-item-num">0.8&nbsp;&nbsp;</span>Layer index</a></span></li><li><span><a href="#Correlated-Net" data-toc-modified-id="Correlated-Net-0.9"><span class="toc-item-num">0.9&nbsp;&nbsp;</span>Correlated Net</a></span></li><li><span><a href="#Decorrelated-net" data-toc-modified-id="Decorrelated-net-0.10"><span class="toc-item-num">0.10&nbsp;&nbsp;</span>Decorrelated net</a></span></li></ul></li><li><span><a href="#Subplots" data-toc-modified-id="Subplots-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Subplots</a></span><ul class="toc-item"><li><span><a href="#Net-Slim-Train" data-toc-modified-id="Net-Slim-Train-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Net-Slim Train</a></span></li><li><span><a href="#L2-based-pruning-Train" data-toc-modified-id="L2-based-pruning-Train-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>L2 based pruning Train</a></span><ul class="toc-item"><li><span><a href="#Importance-plots-Netslim-Train" data-toc-modified-id="Importance-plots-Netslim-Train-1.2.1"><span class="toc-item-num">1.2.1&nbsp;&nbsp;</span>Importance plots Netslim Train</a></span></li><li><span><a href="#Importance-plots-L2-train" data-toc-modified-id="Importance-plots-L2-train-1.2.2"><span class="toc-item-num">1.2.2&nbsp;&nbsp;</span>Importance plots L2 train</a></span></li></ul></li></ul></li></ul></div>

In [1]:
# -*- coding: utf-8 -*-
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import tqdm 

In [2]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import numpy as np
import os
from utils import progress_bar

In [3]:
from SBP_alexnet import SBP_AlexNet
import SBP_utils
import torch
import torch.nn as nn

In [4]:
%load_ext tensorboard
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [5]:
transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomRotation(45),
     transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
     ])

transform_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
     ])

trainset = torchvision.datasets.CIFAR100(root='./../data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./../data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

Files already downloaded and verified


In [6]:
class AlexNet(nn.Module):

    def __init__(self, cfg, classes=100):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, cfg[0], kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[0]),
            
            nn.MaxPool2d(kernel_size=3, stride=2),
            
            nn.Conv2d(cfg[0], cfg[1], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[1]),
            
            nn.MaxPool2d(kernel_size=3, stride=2),
            
            nn.Conv2d(cfg[1], cfg[2], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[2]),
            
            nn.Conv2d(cfg[2], cfg[3], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[3]),
            
            nn.Conv2d(cfg[3], cfg[4], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[4]),
            
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(cfg[4] * 1 * 1, cfg[5]),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(cfg[5], cfg[6]),
            nn.ReLU(inplace=True),
            nn.Linear(cfg[6], classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [7]:
def SBP_w_diag(net):
    ### Conv_ind == 0 ###
    w_mat = net.features[0].conv1.weight
    w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
    b_mat = net.features[0].conv1.bias
    b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
    params = torch.cat((w_mat1, b_mat1), dim=1)
    angle_mat = torch.matmul(torch.t(params), params)
    L_diag = (angle_mat.diag().norm(1))
    L_angle = (angle_mat.norm(1))
    print(L_diag.cpu()/L_angle.cpu())
    
    for conv_ind in [2, 4, 5, 6]:
        w_mat = net.features[conv_ind].conv1.weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net.features[conv_ind].conv1.bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(params, torch.t(params)) 
        L_diag = (angle_mat.diag().norm(1))
        L_angle = (angle_mat.norm(1))
        print(L_diag.cpu()/L_angle.cpu())

        
    '''
    IMPT! Untested with the linear SBP Layers in the classifier.
     
    '''
    
    ### lin_ind = 1 ###        
    w_mat = net.classifier[1].weight
    w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
    b_mat = net.classifier[1].bias
    b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
    params = torch.cat((w_mat1, b_mat1), dim=1)
    angle_mat = torch.matmul(torch.t(params), params)
    L_diag = (angle_mat.diag().norm(1))
    L_angle = (angle_mat.norm(1))
    print(L_diag.cpu()/L_angle.cpu())

    ### lin_ind = 4 ###        
    w_mat = net.classifier[4].weight
    w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
    b_mat = net.classifier[4].bias
    b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
    params = torch.cat((w_mat1, b_mat1), dim=1)
    angle_mat = torch.matmul(params, torch.t(params))
    L_diag = (angle_mat.diag().norm(1))
    L_angle = (angle_mat.norm(1))
    print(L_diag.cpu()/L_angle.cpu())

In [8]:
def net_test(epoch,net):
    global best_acc_net
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc_net:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'best_acc': acc
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc_net = acc

In [9]:
# Training
def SBP_net_train(epoch,net):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs,kl = net(inputs)
        loss = criterion(outputs, targets) + kl
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

        
def SBP_net_test(epoch,net):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item() 
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'best_acc': acc
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc
         
    net.get_sparsity()
    return acc

def SBP_net_train_ortho(epoch,net_ortho):
    print('\nEpoch: %d' % epoch)
    net_ortho.train()
    correct = 0
    total = 0
    running_loss = 0.0
    angle_cost = 0.0
            
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs, kl = net_ortho(inputs)
        L_angle = 0
        
        ### Conv_ind == 0 ###
        w_mat = net_ortho.features[0].conv1.weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_ortho.features[0].conv1.bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)
        L_angle += (l_imp[2])*(angle_mat).norm(1) #.norm().pow(2))

        ### Conv_ind != 0 ###
        for conv_ind in [2, 4, 5, 6]:
            w_mat = net_ortho.features[conv_ind].conv1.weight
            w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
            b_mat = net_ortho.features[conv_ind].conv1.bias
            b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
            params = torch.cat((w_mat1, b_mat1), dim=1)
            angle_mat = torch.matmul(params, torch.t(params)) - torch.eye(w_mat.shape[0]).to(device)            
            L_angle += (l_imp[conv_ind])*(angle_mat).norm(1) #.norm().pow(2))
    
        ### lin_ind = 1 ###        
        w_mat = net_ortho.classifier[1].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_ortho.classifier[1].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)
        L_angle += (l_imp[1])*(angle_mat).norm(1) #.norm().pow(2))
        
        ### lin_ind = 4 ###        
        w_mat = net_ortho.classifier[4].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_ortho.classifier[4].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(params, torch.t(params)) - torch.eye(params.shape[0]).to(device)
        L_angle += (l_imp[4])*(angle_mat).norm(1) #.norm().pow(2))        
        
        Lc = criterion(outputs, labels)
        loss = (1e-1)*(L_angle) + Lc + kl #from the sparsity inducer
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        angle_cost += (L_angle).item()
    
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (running_loss/(batch_idx+1), 100.*correct/total, correct, total))
    
    print("angle_cost: ", angle_cost/total)
    
def SBP_net_test_ortho(epoch,net_ortho):
    global best_acc_ortho
    net_ortho.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_ortho(inputs)
            loss = criterion(outputs, targets) 

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    print(acc)
    if acc > best_acc_ortho:
        print('Saving..')
        state = {
            'net': net_ortho.state_dict(),
            'best_acc': acc
        }
        if not os.path.isdir('ortho_checkpoint'):
            os.mkdir('ortho_checkpoint')
        torch.save(state, './ortho_checkpoint/ckpt.pth')
        best_acc_ortho = acc
    return acc 

In [10]:
from __future__ import print_function, absolute_import

__all__ = ['accuracy']

def kaccuracy(output, target, topk=(5,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def top5cal(net):
    net.eval()
    correct = 0
    total = 0
    top1 = 0
    top5 = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            acc1, acc5 = kaccuracy(outputs, targets, topk=(1, 5))
            top1 += (acc1.item()*inputs.shape[0])
            top5 += (acc5.item()*inputs.shape[0])
    top1 /= 10000
    top5 /= 10000
    
    print("top5", top5)

In [11]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


### Get Pretrained Params

In [12]:
cfg = [64, 192, 384, 256, 256, 4096, 4096]
best_acc = 0

In [13]:
best_net = AlexNet(cfg)

In [14]:
net_dict = torch.load('./pretrained_alex.pth', map_location=torch.device('cpu'))
best_net.load_state_dict(net_dict['net'])
best_acc = net_dict['best_acc']

In [15]:
## now we must manually load in the weights from the best layer
best_net

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Conv2d(2

In [16]:
global best_acc_net
best_acc_net = 0

In [17]:
## check pretrained accuracy 51%
criterion = nn.CrossEntropyLoss()
net_test(1,best_net)

Saving..


In [17]:
print('best_net acc: ',best_acc_net)

best_net acc:  0


### Transfer Weights 

In [18]:
best_acc = 0 #reset best accuracy to save after running SBP

In [19]:
sbp_net = SBP_AlexNet(cfg,conv=True).to(device)

In [20]:
for i in best_net.modules():
    if(isinstance(i,nn.Conv2d)):
        print(i)

Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


In [21]:
best_net

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Conv2d(2

In [22]:
sbp_net 

SBP_AlexNet(
  (block1): SBP_ConvBlock(
    (conv1): Conv2d_SBP()
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (mp1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block2): SBP_ConvBlock(
    (conv1): Conv2d_SBP()
    (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (mp2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block3): SBP_ConvBlock(
    (conv1): Conv2d_SBP()
    (bn1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (block4): SBP_ConvBlock(
    (conv1): Conv2d_SBP()
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (block5): SBP_ConvBlock(
    (conv1): Conv2d_SBP()
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1

In [22]:
sbp_net.block1.conv1.weight = best_net.features[0].weight
sbp_net.block1.conv1.bias = best_net.features[0].bias

sbp_net.block1.bn1.weight = best_net.features[2].weight
sbp_net.block1.bn1.bias = best_net.features[2].bias


sbp_net.block2.conv1.weight = best_net.features[4].weight
sbp_net.block2.conv1.bias = best_net.features[4].bias

sbp_net.block2.bn1.weight = best_net.features[6].weight
sbp_net.block2.bn1.bias = best_net.features[6].bias

sbp_net.block3.conv1.weight = best_net.features[8].weight
sbp_net.block3.conv1.bias = best_net.features[8].bias

sbp_net.block3.bn1.weight = best_net.features[10].weight
sbp_net.block3.bn1.bias = best_net.features[10].bias

sbp_net.block4.conv1.weight = best_net.features[11].weight
sbp_net.block4.conv1.bias = best_net.features[11].bias

sbp_net.block4.bn1.weight = best_net.features[13].weight
sbp_net.block4.bn1.bias = best_net.features[13].bias

sbp_net.block5.conv1.weight = best_net.features[14].weight
sbp_net.block5.conv1.bias = best_net.features[14].bias

sbp_net.block5.bn1.weight = best_net.features[16].weight
sbp_net.block5.bn1.bias = best_net.features[16].bias

sbp_net.classifier = best_net.classifier

In [24]:
#its not that terrifying that the accuracy is trash 
# because the sbp layers need to be trained... 
# hopefully it works lol
criterion = nn.CrossEntropyLoss()

SBP_net_test(1,sbp_net)

Saving..
[0.0, 0.0, 0.0, 0.0, 0.0]
[tensor(246.9508), tensor(246.9508), tensor(246.9507), tensor(246.9508), tensor(246.9508)]


1.0

### Prune w/ SBP 

In [28]:
sbp_learningrate = 1e-4
finetune_epoch = 300 ## that seems excessive
optimizer = optim.Adam(sbp_net.parameters(), lr=sbp_learningrate, betas=[0.95,0.999])
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size= 250,gamma=0.1)

In [29]:
sbp_net

SBP_AlexNet(
  (block1): SBP_ConvBlock(
    (conv1): Conv2d_SBP()
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (mp1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block2): SBP_ConvBlock(
    (conv1): Conv2d_SBP()
    (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (mp2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block3): SBP_ConvBlock(
    (conv1): Conv2d_SBP()
    (bn1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (block4): SBP_ConvBlock(
    (conv1): Conv2d_SBP()
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (block5): SBP_ConvBlock(
    (conv1): Conv2d_SBP()
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1

In [None]:
for epoch in range(20):
    SBP_net_train(epoch,sbp_net)
    SBP_net_test(epoch,sbp_net)


Epoch: 0
[0.0, 0.0, 0.0, 0.0, 0.0]
[tensor(133.4652), tensor(133.6114), tensor(133.7856), tensor(133.9306), tensor(133.4543)]

Epoch: 1
[0.0, 0.0, 0.0, 0.0, 0.0]
[tensor(134.8893), tensor(133.6431), tensor(134.6861), tensor(135.4410), tensor(133.3140)]

Epoch: 2
[0.0, 0.0, 0.0, 0.0, 0.0]
[tensor(127.7240), tensor(127.5423), tensor(127.6600), tensor(127.5935), tensor(127.6219)]

Epoch: 3
[0.0, 0.0, 0.0, 0.0, 0.0]
[tensor(119.3935), tensor(119.1167), tensor(119.5392), tensor(119.6653), tensor(118.9387)]

Epoch: 4
[0.0, 0.0, 0.0, 0.0, 0.0]
[tensor(118.8847), tensor(118.5076), tensor(118.9281), tensor(118.8893), tensor(118.3586)]

Epoch: 5
[0.0, 0.0, 0.0, 0.0, 0.0]
[tensor(113.3901), tensor(113.6175), tensor(113.2370), tensor(113.0592), tensor(113.9350)]

Epoch: 6
[0.0, 0.0, 0.0, 0.0, 0.0]
[tensor(108.7222), tensor(109.0070), tensor(108.3961), tensor(108.4231), tensor(109.5331)]

Epoch: 7
[0.0, 0.0, 0.0, 0.0, 0.0]
[tensor(102.3439), tensor(102.1550), tensor(102.3085), tensor(102.4023), te

### Resume Training

In [23]:
print('==> Resuming from checkpoint..')
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
checkpoint = torch.load('./checkpoint/ckpt.pth')
sbp_net.load_state_dict(checkpoint['net'])
best_acc = checkpoint['best_acc']
start_epoch = 20 # set manually from previous training. 

==> Resuming from checkpoint..


In [29]:
x = torch.randn(2,3,32,32)
sbp_net.eval()
sbp_net(x)
print(sbp_net.features[0].sbp.mask)
print(sbp_net.features[0].sbp.multiplicator)

SNR:  tensor(7289.9297)
MASK:  tensor(64.)
SNR:  tensor(21809.4375)
MASK:  tensor(192.)
SNR:  tensor(43668.1445)
MASK:  tensor(384.)
SNR:  tensor(29090.8965)
MASK:  tensor(256.)
SNR:  tensor(29085.9785)
MASK:  tensor(256.)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([[[0.9588, 0.9588, 0.9588,  ..., 0.9588, 0.9588, 0.9588],
         [0.9588, 0.9588, 0.9588,  ..., 0.9588, 0.9588, 0.9588],
         [0.9588, 0.9588, 0.9588,  ..., 0.9588, 0.9588, 0.9588],
         ...,
         [0.9588, 0.9588, 0.9588,  ..., 0.9588, 0.9588, 0.9588],
         [0.9588, 0.9588, 0.9588,  ..., 0.9588, 0.9588, 0.9588],
         [0.9588, 0.9588, 0.9588,  ..., 0.9588, 0.9588, 0.9588]],

        [[0.9588, 0.9588, 0.9588,  ..., 0.9588, 0.9588, 0.9588],
         [0.9588,

In [45]:
print(best_acc)

50.39


In [None]:
for epoch in range(20,45):
    SBP_net_train(epoch,sbp_net)
    SBP_net_test(epoch,sbp_net)


Epoch: 20

Epoch: 21

Epoch: 22

Epoch: 23

Epoch: 24
Saving..

Epoch: 25

Epoch: 26

Epoch: 27

Epoch: 28

Epoch: 29

Epoch: 30

Epoch: 31

Epoch: 32

Epoch: 33
Saving..

Epoch: 34

Epoch: 35

Epoch: 36

Epoch: 37

Epoch: 38
 [=>.......................]  Step: 261ms | Tot: 5s400ms | Loss: 7.590 | Acc: 63.594% (1628/256 20/391 

### Get Mask & Prune Network

[SBP_Block(
   (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (sbp): SBP_layer()
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (rel): ReLU(inplace=True)
 ),
 MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False),
 SBP_Block(
   (conv1): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (sbp): SBP_layer()
   (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (rel): ReLU(inplace=True)
 ),
 MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False),
 SBP_Block(
   (conv1): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (sbp): SBP_layer()
   (bn1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (rel): ReLU(inplace=True)
 ),
 SBP_Block(
   (conv1): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (sbp): SBP_layer()
   (bn1): BatchNorm2d(2

In [70]:
def cal_importance_0(net, l_id, num_stop=100):
    layer_weights = net.features[0].sbp.weight.data
    imp_corr_bn = layer_weights.abs().sum(dim=(1,2,3))        
    neuron_order = [np.linspace(0, imp_corr_bn.shape[0]-1, imp_corr_bn.shape[0]), imp_corr_bn]
    return neuron_order

In [1]:
def snr_truncated_log_normal(mu, sigma, a, b):
    alpha = (a - mu)/sigma
    beta = (b - mu)/sigma
    z = phi(beta) - phi(alpha)
    ratio = erfcx((sigma-beta)/(2 ** 0.5))*torch.exp((b-mu)-beta**2/2.0)
    ratio = ratio - erfcx((sigma-alpha)/2 ** 0.5)*torch.exp((a-mu)-alpha**2/2.0)
    denominator = 2*z*erfcx((2.0*sigma-beta)/2 ** 0.5)*torch.exp(2.0*(b-mu)-beta**2/2.0)
    denominator = denominator - 2*z*erfcx((2.0*sigma-alpha)/(2 ** 0.5))*torch.exp(2.0*(a-mu)-alpha**2/2.0)
    denominator = denominator - ratio**2
    ratio = ratio/torch.sqrt(1e-8 + denominator)
    print(denominator)
    return ratio

In [2]:
## get the sbp layer mask
def get_mask(mu,sigma,min_log=-20, max_log=0):
    multiplicator = SBP_utils.mean_truncated_log_normal_reduced(mu, sigma, min_log, max_log)
    print("multiplicator: ",multiplicator)
    snr = snr_truncated_log_normal(mu, sigma, min_log, max_log)
    print("snr: ", snr)
    mask = snr
    
    mask[snr <= 1.0] = 0.0
    mask[snr > 1.0] = 1.
    
    print("mask: ",mask)
    multiplicator = multiplicator * mask
    
    return mask, multiplicator

In [107]:
sbp_net.features[0].sbp.log_sigma.detach()

tensor([-4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441])

In [3]:
mask_list = []
multiplicator_list = []
for i in sbp_net.features:
    if not (isinstance(i,nn.MaxPool2d)):
        #print(i)
        mu = i.sbp.mu.detach()
        sigma = i.sbp.log_sigma.detach()
        mask,multiplicator = get_mask(mu,sigma)
        mask_list.append(mask)
        multiplicator_list.append(multiplicator)

NameError: name 'sbp_net' is not defined

In [118]:
x = torch.randn(2,3,32,32)
sbp_net.eval()
y = sbp_net.forward(x)
print(y,y.size())


tensor([[ -6.5035,  -8.3538,  -6.6634, -10.2934, -17.9784,  -5.0302, -11.5280,
         -10.6385,  -1.8071,  -4.3939,   4.3780, -12.2435, -13.1021, -10.5908,
          -7.5355, -12.7850,   2.3540, -12.5417,  -9.5258, -16.9257,  -8.0328,
         -16.6011,   4.9747,  -3.3861, -12.9907,  -4.4500,  -5.9589, -14.0933,
          -4.2610, -14.5732, -16.0110, -13.5321,  -3.9859,  -2.0124, -13.0010,
          -9.6966,  -9.1011,  -3.9229, -16.7927,   8.4844,  -8.6888, -13.6158,
          -8.8288,  -9.8294, -15.8931,  -3.8235, -12.4232,  -6.1532, -10.6135,
         -11.1450,  -8.6385,  -7.6694, -15.9593,  -3.3235,  -4.4470, -18.4937,
         -20.6754,  -0.6448, -18.1565,  -9.3899, -13.1982,   7.3392,  -3.7231,
          -9.8469, -13.8983,  -7.2742, -11.2102, -11.6250, -19.0181, -17.3108,
           3.8214, -17.8009, -10.2750, -17.4043, -15.8597, -20.3146, -11.4058,
         -14.2033,   1.2470,  -9.8833, -10.9950,  -7.2665,  -4.5578,  -1.9062,
          -7.2311, -17.2258,  -2.7540,  -8.5210,  -3

In [137]:
sbp_net.eval()
out_1 = sbp_net.features[0].conv1(x)
print(out.shape)
out_2 = sbp_net.features[0].sbp(out_1)
mu = sbp_net.features[0].sbp.mu.detach()
sigma = sbp_net.features[0].sbp.log_sigma.detach()
print(sbp_net.features[0].sbp.mu)
print(sbp_net.features[0].sbp.log_sigma)
print(get_mask(mu,sigma))
#print(out_1)

torch.Size([2, 64, 16, 16])
Parameter containing:
tensor([-0.0501, -0.0499, -0.0500, -0.0501, -0.0501, -0.0500, -0.0497, -0.0501,
        -0.0500, -0.0500, -0.0502, -0.0500, -0.0501, -0.0500, -0.0498, -0.0500,
        -0.0501, -0.0509, -0.0499, -0.0500, -0.0500, -0.0500, -0.0501, -0.0500,
        -0.0501, -0.0500, -0.0498, -0.0500, -0.0500, -0.0500, -0.0500, -0.0500,
        -0.0500, -0.0499, -0.0501, -0.0500, -0.0501, -0.0500, -0.0501, -0.0499,
        -0.0500, -0.0500, -0.0497, -0.0501, -0.0500, -0.0501, -0.0495, -0.0500,
        -0.0501, -0.0499, -0.0500, -0.0500, -0.0500, -0.0501, -0.0500, -0.0505,
        -0.0500, -0.0500, -0.0500, -0.0498, -0.0502, -0.0500, -0.0499, -0.0513],
       requires_grad=True)
Parameter containing:
tensor([-4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.

tensor([[[[-4.3465e-01, -4.3465e-01, -1.7600e-01,  ..., -4.3465e-01,
            1.7735e+00, -4.3465e-01],
          [ 1.3551e+00,  5.0437e-01, -4.3465e-01,  ..., -4.3465e-01,
           -3.1233e-01,  1.1398e+00],
          [-4.3465e-01, -4.3465e-01, -4.3465e-01,  ...,  5.5210e-01,
           -4.3465e-01, -4.3465e-01],
          ...,
          [ 2.9487e+00, -4.3465e-01, -4.3465e-01,  ..., -4.3465e-01,
            5.1174e-01, -4.3465e-01],
          [-4.3465e-01, -3.3155e-01, -4.3465e-01,  ..., -4.3465e-01,
           -4.3465e-01, -4.3465e-01],
          [-4.3465e-01,  1.4996e+00,  3.1659e+00,  ..., -4.3465e-01,
            9.4603e-01,  3.1426e+00]],

         [[-5.0862e-01,  1.2093e+00, -5.0862e-01,  ...,  9.4067e-02,
           -5.0862e-01, -5.0862e-01],
          [-5.0862e-01,  3.7695e-01,  4.3916e-01,  ..., -5.0862e-01,
            3.3940e+00,  4.3031e-01],
          [ 2.3318e+00,  1.4081e+00,  4.7294e+00,  ..., -5.0862e-01,
           -5.0862e-01, -5.0862e-01],
          ...,
     

In [71]:
cal_importance_0(sbp_net,l_id=1,num_stop=100)

AttributeError: 'SBP_layer' object has no attribute 'weight'

### PTFLOPS

In [53]:
from ptflops import get_model_complexity_info

In [57]:
with torch.cuda.device(-1):
    flops, params = get_model_complexity_info(sbp_net, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops))

SBP_AlexNet(
  20.5 M, 100.000% Params, 0.044 GMac, 100.000% MACs, 
  (block1): SBP_Block(
    0.002 M, 0.010% Params, 0.001 GMac, 1.167% MACs, 
    (conv1): Conv2d(0.002 M, 0.009% Params, 0.0 GMac, 1.054% MACs, 3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (sbp): SBP_layer(0.0 M, 0.001% Params, 0.0 GMac, 0.000% MACs, )
    (bn1): BatchNorm2d(0.0 M, 0.001% Params, 0.0 GMac, 0.075% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.038% MACs, inplace=True)
  )
  (mp1): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.038% MACs, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block2): SBP_Block(
    0.112 M, 0.544% Params, 0.005 GMac, 12.541% MACs, 
    (conv1): Conv2d(0.111 M, 0.540% Params, 0.005 GMac, 12.476% MACs, 64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (sbp): SBP_layer(0.0 M, 0.002% Params, 0.0 GMac, 0.000% MACs, )
    (bn1): BatchNorm2d(0.0 M, 0.002% Param

In [62]:
with torch.cuda.device(-1):
    #flops, params = get_model_complexity_info(sbp_net, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
    flops, params = get_model_complexity_info(best_net, (3, 32, 32), as_strings=True, print_per_layer_stat=True)

AlexNet(
  20.498 M, 100.000% Params, 0.044 GMac, 100.000% MACs, 
  (features): Sequential(
    2.254 M, 10.996% Params, 0.025 GMac, 58.072% MACs, 
    (0): Conv2d(0.002 M, 0.009% Params, 0.0 GMac, 1.054% MACs, 3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.038% MACs, inplace=True)
    (2): BatchNorm2d(0.0 M, 0.001% Params, 0.0 GMac, 0.075% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.038% MACs, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(0.111 M, 0.540% Params, 0.005 GMac, 12.476% MACs, 64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.022% MACs, inplace=True)
    (6): BatchNorm2d(0.0 M, 0.002% Params, 0.0 GMac, 0.043% MACs, 192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.022% MAC

In [None]:
sbp_net.

### Flop weighted importance

In [28]:
### importance for SBP Net

l_imp = {}

for conv_ind in [0, 2, 4, 5, 6]:
    l_imp.update({conv_ind: net.features[conv_ind].bn1.bias.shape[0]})
## do not need to update the classifer indices b/c no blocks 
for lin_ind in [1, 4]:
    l_imp.update({lin_ind: net.classifier[lin_ind].bias.shape[0]})
    
normalizer = 0
for key, val in l_imp.items():
    normalizer += val
for key, val in l_imp.items():
    l_imp[key] = val / normalizer

In [29]:
optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)

In [30]:
net = SBP_AlexNet(cfg)
net_ortho = SBP_AlexNet(cfg)

In [31]:
for epoch in range(30):
    SBP_net_train(epoch,net)
    SBP_net_test(epoch,net)
    


Epoch: 0
Saving..

Epoch: 1

Epoch: 2
Saving..

Epoch: 3

Epoch: 4
 [=>.......................]  Step: 279ms | Tot: 4s887ms | Loss: 11.956 | Acc: 0.873% (19/217 17/391 

KeyboardInterrupt: 

In [24]:
for epoch in range(30):
    SBP_net_train_ortho(epoch,net_ortho)
    w_diag(net_ortho)
    SBP_net_test_ortho(epoch,net_ortho)


Epoch: 0

Epoch: 0
angle_cost:  283.252374375


TypeError: net_test() takes 1 positional argument but 2 were given

# Inner product training

In [16]:
alex_net = AlexNet(cfg)
print(alex_net)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Conv2d(2

In [18]:
alex_net.features[2].bias

Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       requires_grad=True)

In [None]:
net_ortho = AlexNet(cfg).to(device)
net_dict = torch.load('./ortho_checkpoint/ckpt.pth')
net_ortho.load_state_dict(net_dict['net'])
best_acc_ortho = net_dict['best_acc']

In [25]:
### importance for SBP Net

l_imp = {}

for conv_ind in [0, 2, 4, 5, 6]:
    l_imp.update({conv_ind: net.features[conv_ind].bn1.bias.shape[0]})
## do not need to update the classifer indices b/c no blocks 
for lin_ind in [1, 4]:
    l_imp.update({lin_ind: net.classifier[lin_ind].bias.shape[0]})
    
normalizer = 0
for key, val in l_imp.items():
    normalizer += val
for key, val in l_imp.items():
    l_imp[key] = val / normalizer

In [None]:
### normal AlexNet
l_imp = {}

for conv_ind in [2, 6, 10, 13, 16]:
    l_imp.update({conv_ind: net.features[conv_ind].bias.shape[0]})
    
for lin_ind in [1, 4]:
    l_imp.update({lin_ind: net.classifier[lin_ind].bias.shape[0]})
    
normalizer = 0
for key, val in l_imp.items():
    normalizer += val
for key, val in l_imp.items():
    l_imp[key] = val / normalizer

In [18]:
def net_test_ortho(epoch):
    global best_acc_ortho
    net_ortho.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_ortho(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    print(acc)
    if acc > best_acc_ortho:
        print('Saving..')
        state = {
            'net': net_ortho.state_dict(),
            'best_acc': acc
        }
        if not os.path.isdir('ortho_checkpoint'):
            os.mkdir('ortho_checkpoint')
        torch.save(state, './ortho_checkpoint/ckpt.pth')
        best_acc_ortho = acc

In [20]:

for conv_ind in [6, 10, 13, 16]:
    print(alex_net.features[conv_ind-2])

Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


In [19]:
def SBP_net_train_ortho(epoch,net_ortho):
    print('\nEpoch: %d' % epoch)
    net_ortho.train()
    correct = 0
    total = 0
    running_loss = 0.0
    angle_cost = 0.0
            
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs, kl = net_ortho(inputs)
        L_angle = 0
        
        ### Conv_ind == 0 ###
        w_mat = net_ortho.features[0].conv1.weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_ortho.features[0].conv1.bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)
        L_angle += (l_imp[2])*(angle_mat).norm(1) #.norm().pow(2))

        ### Conv_ind != 0 ###
        for conv_ind in [2, 4, 5, 6]:
            w_mat = net_ortho.features[conv_ind].conv1.weight
            w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
            b_mat = net_ortho.features[conv_ind].conv1.bias
            b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
            params = torch.cat((w_mat1, b_mat1), dim=1)
            angle_mat = torch.matmul(params, torch.t(params)) - torch.eye(w_mat.shape[0]).to(device)            
            L_angle += (l_imp[conv_ind])*(angle_mat).norm(1) #.norm().pow(2))
    
        ### lin_ind = 1 ###        
        w_mat = net_ortho.classifier[1].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_ortho.classifier[1].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)
        L_angle += (l_imp[1])*(angle_mat).norm(1) #.norm().pow(2))
        
        ### lin_ind = 4 ###        
        w_mat = net_ortho.classifier[4].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_ortho.classifier[4].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(params, torch.t(params)) - torch.eye(params.shape[0]).to(device)
        L_angle += (l_imp[4])*(angle_mat).norm(1) #.norm().pow(2))        
        
        Lc = criterion(outputs, labels)
        loss = (1e-1)*(L_angle) + Lc + kl #from the sparsity inducer
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        angle_cost += (L_angle).item()
    
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (running_loss/(batch_idx+1), 100.*correct/total, correct, total))
    
    print("angle_cost: ", angle_cost/total)

In [None]:
def net_train_ortho(epoch):
    print('\nEpoch: %d' % epoch)
    net_ortho.train()
    correct = 0
    total = 0
    running_loss = 0.0
    angle_cost = 0.0
            
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net_ortho(inputs)
        L_angle = 0
        
        ### Conv_ind == 0 ###
        w_mat = net_ortho.features[0].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_ortho.features[0].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)
        L_angle += (l_imp[2])*(angle_mat).norm(1) #.norm().pow(2))

        ### Conv_ind != 0 ###
        for conv_ind in [6, 10, 13, 16]:
            w_mat = net_ortho.features[conv_ind-2].weight
            w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
            b_mat = net_ortho.features[conv_ind-2].bias
            b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
            params = torch.cat((w_mat1, b_mat1), dim=1)
            angle_mat = torch.matmul(params, torch.t(params)) - torch.eye(w_mat.shape[0]).to(device)            
            L_angle += (l_imp[conv_ind])*(angle_mat).norm(1) #.norm().pow(2))
    
        ### lin_ind = 1 ###        
        w_mat = net_ortho.classifier[1].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_ortho.classifier[1].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)
        L_angle += (l_imp[1])*(angle_mat).norm(1) #.norm().pow(2))
        
        ### lin_ind = 4 ###        
        w_mat = net_ortho.classifier[4].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_ortho.classifier[4].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(params, torch.t(params)) - torch.eye(params.shape[0]).to(device)
        L_angle += (l_imp[4])*(angle_mat).norm(1) #.norm().pow(2))        
        
        Lc = criterion(outputs, labels)
        loss = (1e-1)*(L_angle) + Lc
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        angle_cost += (L_angle).item()
    
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (running_loss/(batch_idx+1), 100.*correct/total, correct, total))
    
    print("angle_cost: ", angle_cost/total)

In [22]:
import torch.optim as optim
criterion = nn.CrossEntropyLoss()

In [15]:
SBP_net_ortho = SBP_AlexNet(cfg).to(device)

In [23]:
optimizer = optim.Adam(SBP_net_ortho.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

In [17]:
SBP_net_ortho

SBP_AlexNet(
  (block1): SBP_Block(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (sbp): SBP_layer()
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (mp1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block2): SBP_Block(
    (conv1): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (sbp): SBP_layer()
    (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (mp2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block3): SBP_Block(
    (conv1): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (sbp): SBP_layer()
    (bn1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (block4): SBP_Block(
    (conv1): Conv2d(384, 256, kernel_size=

In [25]:
for epoch in range(2):
    #SBP_net_train(epoch,SBP_net_ortho)
    SBP_net_train_ortho(epoch,SBP_net_ortho)
    #net_test_ortho(epoch)
    #w_diag(net_ortho)

NameError: name 'SBP_net_ortho' is not defined

In [None]:
# PATH = './w_decorr/base_params/wnet_base_2.pth'
# torch.save(net.state_dict(), PATH)

# Importance analysis

### Layer index

In [35]:
alex_net

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Conv2d(2

In [37]:
SBP_net_ortho

SBP_AlexNet(
  (block1): SBP_Block(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (sbp): SBP_layer()
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (mp1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block2): SBP_Block(
    (conv1): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (sbp): SBP_layer()
    (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (mp2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block3): SBP_Block(
    (conv1): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (sbp): SBP_layer()
    (bn1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (block4): SBP_Block(
    (conv1): Conv2d(384, 256, kernel_size=

In [39]:
l_index = 4
layer_id = 'bn'

### Correlated Net

In [None]:
net = AlexNet(cfg).to(device)
net_dict = torch.load('./checkpoint/ckpt.pth')
net.load_state_dict(net_dict['net'])
net = net.eval()

In [40]:
#changed for SBP_Aexnet
weight_base = net.features[l_index].bn1.weight.data.clone().detach()
bias_base = net.features[l_index].bn1.bias.data.clone().detach()

In [41]:
loss_base_corr = 0
num_stop = 0
for epoch in range(1):
#     for i, data in enumerate(testloader, 0):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs, kl = net(inputs)
        loss = criterion(outputs, labels) + kl
        loss_base_corr += loss.item()
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
loss_base_corr = loss_base_corr**2

In [46]:
## changed for SBP_Alexnet

loss_mat_corr = torch.zeros(weight_base.shape[0])

for n_index in tqdm(range(weight_base.shape[0])): 
    num_stop = 0
    print(n_index)
    running_loss = 0.0

    net.features[l_index].bn1.weight.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
    net.features[l_index].bn1.bias.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
    
#     for i, data in enumerate(testloader, 0):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        outputs, kl = net(inputs)

        loss = (criterion(outputs, labels)) + kl

        running_loss += loss.item()
        
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
            
    loss_mat_corr[n_index] = running_loss**2
    
    net.features[l_index].bn1.weight.data = weight_base.clone().detach()
    net.features[l_index].bn1.bias.data = bias_base.clone().detach()

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

In [44]:
#normal alex net w/ no KL loss
loss_base_corr = 0
num_stop = 0
for epoch in range(1):
#     for i, data in enumerate(testloader, 0):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss_base_corr += loss.item()
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
loss_base_corr = loss_base_corr**2

In [None]:
### SBP_Alexnet
### Change bn1 to access other layers in the block.
loss_mat_corr = torch.zeros(weight_base.shape[0])

for n_index in tqdm(range(weight_base.shape[0])): 
    num_stop = 0
    print(n_index)
    running_loss = 0.0

    net.features[l_index].bn1.weight.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
    net.features[l_index].bn1.bias.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
    
#     for i, data in enumerate(testloader, 0):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        outputs, kl = net(inputs)

        loss = (criterion(outputs, labels)) + kl

        running_loss += loss.item()
        
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
            
    loss_mat_corr[n_index] = running_loss**2
    
    net.features[l_index].bn1.weight.data = weight_base.clone().detach()
    net.features[l_index].bn1.bias.data = bias_base.clone().detach()

In [None]:
### normal AlexNet

loss_mat_corr = torch.zeros(weight_base.shape[0])

for n_index in range(weight_base.shape[0]): 
    num_stop = 0
    print(n_index)
    running_loss = 0.0

    net.features[l_index].weight.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
    net.features[l_index].bias.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
    
#     for i, data in enumerate(testloader, 0):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        outputs = net(inputs)

        loss = (criterion(outputs, labels))

        running_loss += loss.item()
        
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
            
    loss_mat_corr[n_index] = running_loss**2
    
    net.features[l_index].weight.data = weight_base.clone().detach()
    net.features[l_index].bias.data = bias_base.clone().detach()

In [None]:
# torch.save(loss_mat_corr, './w_decorr/loss_mats/corr/'+str(l_index)+'/loss_corr_bn_train_'+str(l_index)+'.pt')
loss_mat_corr = torch.load('./w_decorr/loss_mats/corr/'+str(l_index)+'/loss_corr_bn_train_'+str(l_index)+'.pt')

In [None]:

### SBP AlexNet
optimizer = optim.SGD(net.parameters(), lr=0, weight_decay=0)
av_corrval = 0
n_epochs = 1

for epoch in tqdm(range(n_epochs)):
    num_stop = 0
    running_loss = 0.0
    imp_corr_conv = torch.zeros(bias_base.shape[0]).to(device)
    imp_corr_bn = torch.zeros(bias_base.shape[0]).to(device)
    
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs, kl = net(inputs) + kl
        loss = criterion(outputs, labels)
        loss.backward()
        
        imp_corr_bn += (((net.features[l_index].bn1.weight.grad)*(net.features[l_index].bn1.weight.data)) + ((net.features[l_index].bn1.bias.grad)*(net.features[l_index].bn1.bias.data))).abs().pow(2)
        
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
         
    corrval = (np.corrcoef(imp_corr_bn.cpu().detach().numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().detach().numpy()))
    print("Correlation at epoch "+str(epoch)+": "+str(corrval[0,1]))
    av_corrval += corrval[0,1]

In [None]:

### Normal AlexNet
optimizer = optim.SGD(net.parameters(), lr=0, weight_decay=0)
av_corrval = 0
n_epochs = 1

for epoch in range(n_epochs):
    num_stop = 0
    running_loss = 0.0
    imp_corr_conv = torch.zeros(bias_base.shape[0]).to(device)
    imp_corr_bn = torch.zeros(bias_base.shape[0]).to(device)
    
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        imp_corr_bn += (((net.features[l_index].weight.grad)*(net.features[l_index].weight.data)) + ((net.features[l_index].bias.grad)*(net.features[l_index].bias.data))).abs().pow(2)
        
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
         
    corrval = (np.corrcoef(imp_corr_bn.cpu().detach().numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().detach().numpy()))
    print("Correlation at epoch "+str(epoch)+": "+str(corrval[0,1]))
    av_corrval += corrval[0,1]

### Decorrelated net

In [None]:
net_decorr = AlexNet(cfg).to(device)
net_dict = torch.load('./ortho_checkpoint/ckpt.pth')
net_decorr.load_state_dict(net_dict['net'])
net_decorr = net_decorr.eval() 

In [None]:
weight_base = net_decorr.features[l_index].weight.data.clone().detach()
bias_base = net_decorr.features[l_index].bias.data.clone().detach()

In [None]:
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)
num_stop = 0
loss_base_decorr = 0
for epoch in range(1):
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):        
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = net_decorr(inputs)
        loss = criterion(outputs, labels)
        loss_base_decorr += loss.item()
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
loss_base_decorr = loss_base_decorr**2

In [None]:
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)

loss_mat_decorr = torch.zeros(weight_base.shape[0])

for n_index in range(weight_base.shape[0]): 
    print(n_index)
    num_stop = 0
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        net_decorr.features[l_index].weight.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
        net_decorr.features[l_index].bias.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
        outputs = net_decorr(inputs)
        
        loss = criterion(outputs, labels)
        
        running_loss += loss.item()
        
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
            
    loss_mat_decorr[n_index] = running_loss**2
    
    net_decorr.features[l_index].weight.data = weight_base.clone().detach()
    net_decorr.features[l_index].bias.data = bias_base.clone().detach()

In [None]:
# torch.save(loss_mat_decorr, './w_decorr/loss_mats/decorr/'+str(l_index)+'/loss_decorr_bn_train_'+str(l_index)+'.pt')
loss_mat_decorr = torch.load('./w_decorr/loss_mats/decorr/'+str(l_index)+'/loss_decorr_bn_train_'+str(l_index)+'.pt')

In [None]:
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)
av_corrval = 0
n_epochs = 1

for epoch in range(n_epochs):
    num_stop = 0
    imp_decorr_conv = torch.zeros(bias_base.shape[0]).to(device)
    imp_decorr_bn = torch.zeros(bias_base.shape[0]).to(device)

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net_decorr(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
        
        imp_decorr_bn += (((net_decorr.features[l_index].weight.grad)*(net_decorr.features[l_index].weight.data)) + ((net_decorr.features[l_index].bias.grad)*(net_decorr.features[l_index].bias.data))).abs().pow(2)
    
    corrval = (np.corrcoef(imp_decorr_bn.cpu().detach().numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().detach().numpy()))
    print("Correlation at epoch "+str(epoch)+": "+str(corrval[0,1]))
    av_corrval += corrval[0,1]

# Graphs

In [None]:
figure(figsize=(20,5))
s = imp_corr_bn.cpu().sort()[0].cpu().numpy()
order = imp_corr_bn.sort()[1].cpu().numpy()
plt.plot(s/s.max(), label="Estimated importance")
plt.title("Correlated (Taylor FO) for "+str(l_index))
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
plt.xlabel("Neuron index")
plt.ylabel("Normalized importance")
plt.plot(loss_diff[order]/loss_diff.max(), label="Actual importance")
plt.legend()
plt.savefig("./w_decorr/loss_mats/corr/graphs/"+str(l_index)+".png")

In [None]:
figure(figsize=(20,5))
s = imp_decorr_bn.cpu().sort()[0].cpu().numpy()
order = imp_decorr_bn.sort()[1].cpu().numpy()
plt.plot(s/s.max(), label="Estimated importance")
plt.title("Decorrelated (Taylor FO) for "+str(l_index))
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
plt.xlabel("Neuron index")
plt.ylabel("Normalized importance")
plt.plot(loss_diff[order]/loss_diff.max(), label="Actual importance")
plt.legend()
plt.savefig("./w_decorr/loss_mats/decorr/graphs/"+str(l_index)+".png")

In [None]:
s = imp_decorr_bn.cpu().sort()[0].cpu().numpy()
s = s/s.max()
order = imp_decorr_bn.sort()[1].cpu().numpy()
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
loss_diff = (loss_diff[order]/loss_diff.max())
ortho_rms = ((loss_diff - s)**2).sum()

s = imp_corr_bn.cpu().sort()[0].cpu().numpy()
s = s/s.max()
order = imp_corr_bn.sort()[1].cpu().numpy()
loss_diff = (loss_mat_corr - loss_base_corr).abs()
loss_diff = (loss_diff[order]/loss_diff.max())

base_rms = ((loss_diff - s)**2).sum()

In [None]:
(ortho_rms, base_rms)

In [None]:
# rms_ortho = np.sqrt(np.array(rms_ortho) / np.array([64, 64, 128, 128, 256, 256, 512, 512, 512, 512]))
# rms_base = np.sqrt(np.array(rms_base) / np.array([64, 64, 128, 128, 256, 256, 512, 512, 512, 512]))

In [None]:
plt.figure(figsize=(10,5))
plt.bar(np.linspace(0,30,10)-0.5, rms_ortho, label="Decorrelated network")
plt.bar(np.linspace(0,30,10)+0.5, rms_base, label="Correlated network")
plt.xlabel("Layer ID")
plt.ylabel("RMS")
plt.legend()
plt.savefig("./w_decorr/loss_mats/rms.png")

## Subplots

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(20,5))

s = imp_decorr_bn.cpu().sort()[0].cpu().numpy()
order = imp_decorr_bn.sort()[1].cpu().numpy()
axes[0].plot(s/s.max(), label="Estimated importance")
axes[0].set_title("Decorrelated Network (layer "+str(l_index)+")")
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
axes[0].set_xlabel("Neuron index")
axes[0].set_ylabel("Normalized importance")
axes[0].plot(loss_diff[order]/loss_diff.max(), label="Actual importance")
axes[0].legend()

s = imp_corr_bn.cpu().sort()[0].cpu().numpy()
order = imp_corr_bn.sort()[1].cpu().numpy()
axes[1].plot(s/s.max(), label="Estimated importance")
axes[1].set_title("Correlated Network (layer "+str(l_index)+")")
loss_diff = (loss_mat_corr - loss_base_corr).abs()
axes[1].set_xlabel("Neuron index")
axes[1].set_ylabel("Normalized importance")
axes[1].plot(loss_diff[order]/loss_diff.max(), label="Actual importance")
axes[1].legend()

plt.savefig("./w_decorr/loss_mats/subplots/"+str(l_index)+".png")

# Other metrics

### Net-Slim Train

In [None]:
scale_corr = net.features[l_index].weight.data.clone()
np.corrcoef(scale_corr.cpu().numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().numpy())

In [None]:
scale_decorr = net_decorr.features[l_index].weight.data.clone().abs()
np.corrcoef((scale_decorr).cpu().numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().numpy())

### L2 based pruning Train

In [None]:
w_corr = net.features[l_index - 2].weight.data.clone()
w_imp_corr = w_corr.pow(2).sum(dim=(1,2,3)).cpu()
np.corrcoef(w_imp_corr.numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().numpy())

In [None]:
w_decorr = net_decorr.features[l_index - 2].weight.data.clone()
w_imp_decorr = w_decorr.pow(2).sum(dim=(1,2,3)).cpu()
w_imp_decorr = (w_imp_decorr - w_imp_decorr.min())
w_imp_decorr = w_imp_decorr/w_imp_decorr.max()
np.corrcoef(w_imp_decorr.numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().numpy())

#### Importance plots Netslim Train

In [None]:
figure(figsize=(20,5))

s = scale_corr.cpu().sort()[0].cpu().numpy()
order = scale_corr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Correlated (Net-Slim)")
loss_diff = (loss_mat_corr - loss_base_corr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

In [None]:
figure(figsize=(20,5))

s = scale_decorr.cpu().sort()[0].cpu().numpy()
order = scale_decorr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Decorrelated (Net-Slim)")
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

#### Importance plots L2 train

In [None]:
figure(figsize=(20,5))
s = w_imp_corr.sort()[0].cpu().numpy()
order = w_imp_corr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Correlated (L2)")
loss_diff = (loss_mat_corr - loss_base_corr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

In [None]:
figure(figsize=(20,5))
s = w_imp_decorr.sort()[0].cpu().numpy()
order = w_imp_decorr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Decorrelated (L2)")
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
plt.plot(loss_diff[order]/loss_diff.max())