<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><ul class="toc-item"><li><span><a href="#Get-Pretrained-Params" data-toc-modified-id="Get-Pretrained-Params-0.1"><span class="toc-item-num">0.1&nbsp;&nbsp;</span>Get Pretrained Params</a></span></li><li><span><a href="#Transfer-Weights" data-toc-modified-id="Transfer-Weights-0.2"><span class="toc-item-num">0.2&nbsp;&nbsp;</span>Transfer Weights</a></span></li><li><span><a href="#Prune-w/-SBP" data-toc-modified-id="Prune-w/-SBP-0.3"><span class="toc-item-num">0.3&nbsp;&nbsp;</span>Prune w/ SBP</a></span></li><li><span><a href="#Resume-Training" data-toc-modified-id="Resume-Training-0.4"><span class="toc-item-num">0.4&nbsp;&nbsp;</span>Resume Training</a></span></li><li><span><a href="#Get-Mask-&amp;-Prune-Network" data-toc-modified-id="Get-Mask-&amp;-Prune-Network-0.5"><span class="toc-item-num">0.5&nbsp;&nbsp;</span>Get Mask &amp; Prune Network</a></span></li><li><span><a href="#PTFLOPS" data-toc-modified-id="PTFLOPS-0.6"><span class="toc-item-num">0.6&nbsp;&nbsp;</span>PTFLOPS</a></span></li><li><span><a href="#Flop-weighted-importance" data-toc-modified-id="Flop-weighted-importance-0.7"><span class="toc-item-num">0.7&nbsp;&nbsp;</span>Flop weighted importance</a></span></li><li><span><a href="#Layer-index" data-toc-modified-id="Layer-index-0.8"><span class="toc-item-num">0.8&nbsp;&nbsp;</span>Layer index</a></span></li><li><span><a href="#Correlated-Net" data-toc-modified-id="Correlated-Net-0.9"><span class="toc-item-num">0.9&nbsp;&nbsp;</span>Correlated Net</a></span></li><li><span><a href="#Decorrelated-net" data-toc-modified-id="Decorrelated-net-0.10"><span class="toc-item-num">0.10&nbsp;&nbsp;</span>Decorrelated net</a></span></li></ul></li><li><span><a href="#Subplots" data-toc-modified-id="Subplots-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Subplots</a></span><ul class="toc-item"><li><span><a href="#Net-Slim-Train" data-toc-modified-id="Net-Slim-Train-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Net-Slim Train</a></span></li><li><span><a href="#L2-based-pruning-Train" data-toc-modified-id="L2-based-pruning-Train-1.2"><span class="toc-item-num">1.2&nbsp;&nbsp;</span>L2 based pruning Train</a></span><ul class="toc-item"><li><span><a href="#Importance-plots-Netslim-Train" data-toc-modified-id="Importance-plots-Netslim-Train-1.2.1"><span class="toc-item-num">1.2.1&nbsp;&nbsp;</span>Importance plots Netslim Train</a></span></li><li><span><a href="#Importance-plots-L2-train" data-toc-modified-id="Importance-plots-L2-train-1.2.2"><span class="toc-item-num">1.2.2&nbsp;&nbsp;</span>Importance plots L2 train</a></span></li></ul></li></ul></li></ul></div>

In [1]:
# -*- coding: utf-8 -*-/
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from tqdm.autonotebook import tqdm

In [2]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import numpy as np
import os
from utils import progress_bar

In [3]:
from SBP_alexnet import SBPConv_AlexNet
import SBP_utils_gpu as SBP_utils
import torch
import torch.nn as nn

In [4]:
%load_ext tensorboard
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [5]:
transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomRotation(45),
     transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
     ])

transform_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
     ])

trainset = torchvision.datasets.CIFAR100(root='./../data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./../data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
class AlexNet(nn.Module):

    def __init__(self, cfg, classes=100):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, cfg[0], kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[0]),
            
            nn.MaxPool2d(kernel_size=3, stride=2),
            
            nn.Conv2d(cfg[0], cfg[1], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[1]),
            
            nn.MaxPool2d(kernel_size=3, stride=2),
            
            nn.Conv2d(cfg[1], cfg[2], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[2]),
            
            nn.Conv2d(cfg[2], cfg[3], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[3]),
            
            nn.Conv2d(cfg[3], cfg[4], kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(cfg[4]),
            
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(cfg[4] * 1 * 1, cfg[5]),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(cfg[5], cfg[6]),
            nn.ReLU(inplace=True),
            nn.Linear(cfg[6], classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [7]:
def net_test(epoch,net, criterion = nn.CrossEntropyLoss()):
    global best_base_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_base_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'best_acc': acc
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/alex_ckpt.pth')
        best_base_acc = acc

In [8]:
# Training
def SBP_net_train(epoch,net,optimizer,criterion = nn.CrossEntropyLoss(),lr_adjust=None, scheduler=None):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs,kl = net(inputs)
        
        #have to add the KL divergence while training for the SBP layers
        loss = criterion(outputs, targets) + kl
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    if lr_adjust:
        lr_adjust(optimizer,epoch)
    if scheduler:
        scheduler.step()
    return 
def SBP_net_test(epoch,net,criterion= nn.CrossEntropyLoss()):
    global best_sbp_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item() 
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_sbp_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'best_acc': acc
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/sbp_ckpt.pth')
        best_spb_acc = acc
         
    net.get_sparsity()
    return acc



In [9]:
from __future__ import print_function, absolute_import

__all__ = ['accuracy']

def kaccuracy(output, target, topk=(5,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def top5cal(net):
    net.eval()
    correct = 0
    total = 0
    top1 = 0
    top5 = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            acc1, acc5 = kaccuracy(outputs, targets, topk=(1, 5))
            top1 += (acc1.item()*inputs.shape[0])
            top5 += (acc5.item()*inputs.shape[0])
    top1 /= 10000
    top5 /= 10000
    
    print("top5", top5)

In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


### Get Pretrained Params

In [11]:
cfg = [64, 192, 384, 256, 256, 4096, 4096]
best_sbp_acc = best_base_acc = 0

In [12]:
best_net = AlexNet(cfg).to(device)

In [13]:
net_dict = torch.load('./pretrained_alex.pth')
best_net.load_state_dict(net_dict['net'])
best_acc = net_dict['best_acc']

In [14]:
best_net

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Conv2d(2

In [15]:
## check pretrained accuracy 51%
criterion = nn.CrossEntropyLoss()

net_test(1,best_net)

Saving..


In [16]:
print('best_base_acc: ',best_base_acc)

best_base_acc:  50.96


### Transfer Weights 

In [17]:
best_sbp_acc = best_base_acc = 0 #reset best accuracy to save after running SBP

In [18]:
kl_weights = cfg
all_ones = [1] * 10
sbp_net = SBPConv_AlexNet(cfg,kl_weights=all_ones).to(device)

In [19]:
best_net

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Conv2d(2

In [27]:
#block 1

def transfer_weights(sbp_net):
    sbp_net.block1.conv1.weight = best_net.features[0].weight
    sbp_net.block1.conv1.bias = best_net.features[0].bias
    sbp_net.block1.bn1.weight = best_net.features[2].weight
    sbp_net.block1.bn1.bias = best_net.features[2].bias

    #block 2
    sbp_net.block2.conv1.weight = best_net.features[4].weight
    sbp_net.block2.conv1.bias = best_net.features[4].bias
    sbp_net.block2.bn1.weight = best_net.features[6].weight
    sbp_net.block2.bn1.bias = best_net.features[6].bias

    #block 3
    sbp_net.block3.conv1.weight = best_net.features[8].weight
    sbp_net.block3.conv1.bias = best_net.features[8].bias
    sbp_net.block3.bn1.weight = best_net.features[10].weight
    sbp_net.block3.bn1.bias = best_net.features[10].bias

    #block 4
    sbp_net.block4.conv1.weight = best_net.features[11].weight
    sbp_net.block4.conv1.bias = best_net.features[11].bias
    sbp_net.block4.bn1.weight = best_net.features[13].weight
    sbp_net.block4.bn1.bias = best_net.features[13].bias


    #block 5
    sbp_net.block5.conv1.weight = best_net.features[14].weight
    sbp_net.block5.conv1.bias = best_net.features[14].bias
    sbp_net.block5.bn1.weight = best_net.features[16].weight
    sbp_net.block5.bn1.bias = best_net.features[16].bias

    #note! Im not using SBP layers in the classifier right now. 
    #Otherwise, would need to transfer code here as well!
    sbp_net.lsbp1.weight = best_net.classifier[1].weight
    sbp_net.lsbp1.bias = best_net.classifier[1].bias

    sbp_net.lsbp2.weight = best_net.classifier[4].weight
    sbp_net.lsbp2.bias = best_net.classifier[4].bias
    
    return sbp_net

### Test Weights After Transferring

In [22]:
#its not that terrifying that the accuracy is trash 
criterion = nn.CrossEntropyLoss()
SBP_net_test(1,sbp_net,criterion=criterion)

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(246.9508, device='cuda:0'), tensor(246.9508, device='cuda:0'), tensor(246.9508, device='cuda:0'), tensor(246.9508, device='cuda:0'), tensor(246.9508, device='cuda:0'), tensor(246.9508, device='cuda:0'), tensor(246.9508, device='cuda:0')]


1.0

### Prune w/ SBP 

In [24]:
#First attempt w/ lr. This is no longer being used. Here for completeness.

sbp_learningrate = 1e-5
finetune_epoch = 300 ## that seems excessive

kl_weights = cfg
all_ones = [1] * 10
sbp_net = SBPConv_AlexNet(cfg,kl_weights=cfg).to(device)

optimizer = optim.Adam(sbp_net.parameters(), lr=sbp_learningrate, betas=[0.95,0.999])
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size= 250,gamma=0.1)

In [25]:
for epoch in range(0,300):
    SBP_net_train(epoch,sbp_net,optimizer=optimizer,criterion=nn.CrossEntropyLoss())
    SBP_net_test(epoch,sbp_net)


Epoch: 0
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(207.3631, device='cuda:0'), tensor(207.3631, device='cuda:0'), tensor(207.3631, device='cuda:0'), tensor(207.3631, device='cuda:0'), tensor(207.3631, device='cuda:0'), tensor(207.9739, device='cuda:0'), tensor(207.9592, device='cuda:0')]

Epoch: 1
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(181.1531, device='cuda:0'), tensor(181.1544, device='cuda:0'), tensor(181.1531, device='cuda:0'), tensor(181.1531, device='cuda:0'), tensor(181.1531, device='cuda:0'), tensor(180.7624, device='cuda:0'), tensor(180.7803, device='cuda:0')]

Epoch: 2
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(164.1757, device='cuda:0'), tensor(164.1472, device='cuda:0'), tensor(164.1658, device='cuda:0'), tensor(164.1537, device='cuda:0'), tensor(164.1472, device='cuda:0'), tensor(165.5558, device='cuda:0'), tensor(165.5764, device='cuda:0')]

Epoch: 3
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(15

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(149.0000, device='cuda:0'), tensor(149.0020, device='cuda:0'), tensor(148.9927, device='cuda:0'), tensor(149.0003, device='cuda:0'), tensor(149.0062, device='cuda:0'), tensor(148.6904, device='cuda:0'), tensor(148.6919, device='cuda:0')]

Epoch: 6
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(146.6345, device='cuda:0'), tensor(146.6403, device='cuda:0'), tensor(146.6255, device='cuda:0'), tensor(146.6537, device='cuda:0'), tensor(146.6520, device='cuda:0'), tensor(146.6573, device='cuda:0'), tensor(146.6468, device='cuda:0')]

Epoch: 7
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(145.2052, device='cuda:0'), tensor(145.2205, device='cuda:0'), tensor(145.2202, device='cuda:0'), tensor(145.2542, device='cuda:0'), tensor(145.2223, device='cuda:0'), tensor(145.3037, device='cuda:0'), tensor(145.3162, device='cuda:0')]

Epoch: 8
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(145.6074, de

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(135.2156, device='cuda:0'), tensor(135.2143, device='cuda:0'), tensor(135.0398, device='cuda:0'), tensor(135.0377, device='cuda:0'), tensor(135.1791, device='cuda:0'), tensor(134.5666, device='cuda:0'), tensor(134.3785, device='cuda:0')]

Epoch: 22
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(132.5202, device='cuda:0'), tensor(131.8779, device='cuda:0'), tensor(132.6906, device='cuda:0'), tensor(133.2493, device='cuda:0'), tensor(131.7125, device='cuda:0'), tensor(133.3901, device='cuda:0'), tensor(133.2096, device='cuda:0')]

Epoch: 23
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(134.3783, device='cuda:0'), tensor(133.8230, device='cuda:0'), tensor(133.9697, device='cuda:0'), tensor(134.0416, device='cuda:0'), tensor(134.0424, device='cuda:0'), tensor(133.8369, device='cuda:0'), tensor(133.5646, device='cuda:0')]

Epoch: 24
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(133.6478,

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(132.5656, device='cuda:0'), tensor(132.4066, device='cuda:0'), tensor(132.1649, device='cuda:0'), tensor(131.9333, device='cuda:0'), tensor(132.6419, device='cuda:0'), tensor(130.8342, device='cuda:0'), tensor(130.6447, device='cuda:0')]

Epoch: 31
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(127.8364, device='cuda:0'), tensor(127.6127, device='cuda:0'), tensor(127.9818, device='cuda:0'), tensor(128.0538, device='cuda:0'), tensor(127.3882, device='cuda:0'), tensor(128.0643, device='cuda:0'), tensor(128.0743, device='cuda:0')]

Epoch: 32
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(128.3497, device='cuda:0'), tensor(128.5489, device='cuda:0'), tensor(128.4160, device='cuda:0'), tensor(127.9779, device='cuda:0'), tensor(128.7033, device='cuda:0'), tensor(127.6083, device='cuda:0'), tensor(127.4878, device='cuda:0')]

Epoch: 33
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(128.1081,

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(126.9304, device='cuda:0'), tensor(126.5320, device='cuda:0'), tensor(126.7699, device='cuda:0'), tensor(127.0817, device='cuda:0'), tensor(126.3421, device='cuda:0'), tensor(127.3886, device='cuda:0'), tensor(127.3848, device='cuda:0')]

Epoch: 38
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(126.4336, device='cuda:0'), tensor(126.6670, device='cuda:0'), tensor(126.6535, device='cuda:0'), tensor(126.6290, device='cuda:0'), tensor(126.7180, device='cuda:0'), tensor(126.8485, device='cuda:0'), tensor(126.8391, device='cuda:0')]

Epoch: 39


Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(126.3112, device='cuda:0'), tensor(126.0520, device='cuda:0'), tensor(126.2500, device='cuda:0'), tensor(126.3692, device='cuda:0'), tensor(126.2307, device='cuda:0'), tensor(126.3785, device='cuda:0'), tensor(126.3878, device='cuda:0')]

Epoch: 40
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(125.7079, device='cuda:0'), tensor(125.5470, device='cuda:0'), tensor(125.7993, device='cuda:0'), tensor(125.8800, device='cuda:0'), tensor(125.8184, device='cuda:0'), tensor(125.9055, device='cuda:0'), tensor(125.8959, device='cuda:0')]

Epoch: 41
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(125.3650, device='cuda:0'), tensor(125.3574, device='cuda:0'), tensor(125.2772, device='cuda:0'), tensor(125.2268, device='cuda:0'), tensor(125.4461, device='cuda:0'), tensor(125.3610, device='cuda:0'), tensor(125.3766, device='cuda:0')]

Epoch: 42
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(126.9880,

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(121.0106, device='cuda:0'), tensor(120.8518, device='cuda:0'), tensor(120.8799, device='cuda:0'), tensor(120.8073, device='cuda:0'), tensor(120.8225, device='cuda:0'), tensor(120.8744, device='cuda:0'), tensor(120.8779, device='cuda:0')]

Epoch: 51
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(120.4518, device='cuda:0'), tensor(120.2635, device='cuda:0'), tensor(120.4918, device='cuda:0'), tensor(120.5464, device='cuda:0'), tensor(120.6669, device='cuda:0'), tensor(120.5145, device='cuda:0'), tensor(120.4946, device='cuda:0')]

Epoch: 52
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(120.1386, device='cuda:0'), tensor(119.7568, device='cuda:0'), tensor(119.8390, device='cuda:0'), tensor(120.0591, device='cuda:0'), tensor(119.8318, device='cuda:0'), tensor(119.9685, device='cuda:0'), tensor(119.9897, device='cuda:0')]

Epoch: 53
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(121.4385,

KeyboardInterrupt: 

In [42]:
equal_optimizer.param_groups[0]['lr']

1e-05

## Train Equal KL Weights

In [22]:
#This is similar to the original paper's hyper-parameters for VGG. Also fails. 


best_sbp_acc = best_base_acc = 0 #reset best accuracy to save after running SBP

equal_weights = [1,1,1,1,1,1,1]
lr_x = [] #learning rate decay


def learning_rate_calc(optimizer, epoch):
    if epoch < 250:
        return 1e-5
    else: 
        return 1e-5 * (300-epoch)/(300-250)

sbp_learningrate = 1e-5
finetune_epoch = 300 ## that seems excessive
equal_optimizer = optim.Adam(sbp_equal_net.parameters(),lr=sbp_learningrate, betas=[0.95,0.999])
#equal_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_x)

for epoch in range(0,300):
    SBP_net_train(epoch,sbp_equal_net,optimizer=equal_optimizer,criterion=nn.CrossEntropyLoss(),lr_adjust=learning_rate_calc)
    SBP_net_test(epoch,sbp_equal_net)


Epoch: 0
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(207.3795, device='cuda:0'), tensor(207.5979, device='cuda:0'), tensor(207.8234, device='cuda:0'), tensor(207.5694, device='cuda:0'), tensor(207.5939, device='cuda:0'), tensor(208.0278, device='cuda:0'), tensor(208.0318, device='cuda:0')]

Epoch: 1
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(181.1606, device='cuda:0'), tensor(181.1507, device='cuda:0'), tensor(180.9419, device='cuda:0'), tensor(181.1448, device='cuda:0'), tensor(181.1411, device='cuda:0'), tensor(181.5152, device='cuda:0'), tensor(181.5367, device='cuda:0')]

Epoch: 2
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(164.5153, device='cuda:0'), tensor(165.5872, device='cuda:0'), tensor(165.3192, device='cuda:0'), tensor(165.6361, device='cuda:0'), tensor(165.7697, device='cuda:0'), tensor(165.6310, device='cuda:0'), tensor(165.6275, device='cuda:0')]

Epoch: 3
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(15

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(144.2218, device='cuda:0'), tensor(144.1597, device='cuda:0'), tensor(144.3626, device='cuda:0'), tensor(144.7958, device='cuda:0'), tensor(144.9916, device='cuda:0'), tensor(144.3100, device='cuda:0'), tensor(144.3037, device='cuda:0')]

Epoch: 9
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(143.8009, device='cuda:0'), tensor(145.1761, device='cuda:0'), tensor(145.5326, device='cuda:0'), tensor(143.3979, device='cuda:0'), tensor(143.4100, device='cuda:0'), tensor(143.3314, device='cuda:0'), tensor(143.3410, device='cuda:0')]

Epoch: 10
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(143.8509, device='cuda:0'), tensor(142.7272, device='cuda:0'), tensor(143.4786, device='cuda:0'), tensor(143.9166, device='cuda:0'), tensor(143.8954, device='cuda:0'), tensor(142.4787, device='cuda:0'), tensor(142.5112, device='cuda:0')]

Epoch: 11
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(143.3564, 

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(138.2236, device='cuda:0'), tensor(138.5338, device='cuda:0'), tensor(138.0969, device='cuda:0'), tensor(138.4467, device='cuda:0'), tensor(138.4867, device='cuda:0'), tensor(137.4068, device='cuda:0'), tensor(137.3909, device='cuda:0')]

Epoch: 19
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(136.0600, device='cuda:0'), tensor(136.8817, device='cuda:0'), tensor(136.7387, device='cuda:0'), tensor(136.2519, device='cuda:0'), tensor(136.1008, device='cuda:0'), tensor(136.8454, device='cuda:0'), tensor(136.8059, device='cuda:0')]

Epoch: 20
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(135.4910, device='cuda:0'), tensor(136.0259, device='cuda:0'), tensor(135.3560, device='cuda:0'), tensor(135.4102, device='cuda:0'), tensor(135.3918, device='cuda:0'), tensor(136.2519, device='cuda:0'), tensor(136.2741, device='cuda:0')]

Epoch: 21
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(134.5009,

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(128.5370, device='cuda:0'), tensor(128.3341, device='cuda:0'), tensor(133.3353, device='cuda:0'), tensor(128.3745, device='cuda:0'), tensor(128.3614, device='cuda:0'), tensor(130.0380, device='cuda:0'), tensor(130.0564, device='cuda:0')]

Epoch: 33
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(127.9086, device='cuda:0'), tensor(127.6834, device='cuda:0'), tensor(127.6335, device='cuda:0'), tensor(127.6951, device='cuda:0'), tensor(127.7316, device='cuda:0'), tensor(129.5060, device='cuda:0'), tensor(129.4686, device='cuda:0')]

Epoch: 34
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(130.1107, device='cuda:0'), tensor(129.8537, device='cuda:0'), tensor(129.8993, device='cuda:0'), tensor(129.7938, device='cuda:0'), tensor(129.8949, device='cuda:0'), tensor(128.9569, device='cuda:0'), tensor(128.9605, device='cuda:0')]

Epoch: 35
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(125.6803,

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(114.8551, device='cuda:0'), tensor(114.9987, device='cuda:0'), tensor(114.8402, device='cuda:0'), tensor(114.9087, device='cuda:0'), tensor(114.8295, device='cuda:0'), tensor(115.1983, device='cuda:0'), tensor(115.2155, device='cuda:0')]

Epoch: 64
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(112.5431, device='cuda:0'), tensor(112.6619, device='cuda:0'), tensor(112.3979, device='cuda:0'), tensor(112.4643, device='cuda:0'), tensor(112.4480, device='cuda:0'), tensor(114.7320, device='cuda:0'), tensor(114.7754, device='cuda:0')]

Epoch: 65
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(114.1087, device='cuda:0'), tensor(114.1377, device='cuda:0'), tensor(113.9020, device='cuda:0'), tensor(114.2302, device='cuda:0'), tensor(114.1401, device='cuda:0'), tensor(114.3199, device='cuda:0'), tensor(114.2841, device='cuda:0')]

Epoch: 66
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(114.0480,

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(102.6249, device='cuda:0'), tensor(102.8007, device='cuda:0'), tensor(102.6809, device='cuda:0'), tensor(102.4243, device='cuda:0'), tensor(102.5638, device='cuda:0'), tensor(101.2654, device='cuda:0'), tensor(101.2947, device='cuda:0')]

Epoch: 97
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(99.1454, device='cuda:0'), tensor(99.2163, device='cuda:0'), tensor(99.1006, device='cuda:0'), tensor(99.2581, device='cuda:0'), tensor(99.1312, device='cuda:0'), tensor(100.8847, device='cuda:0'), tensor(100.8867, device='cuda:0')]

Epoch: 98
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(100.3374, device='cuda:0'), tensor(100.1768, device='cuda:0'), tensor(100.2058, device='cuda:0'), tensor(100.5047, device='cuda:0'), tensor(100.3146, device='cuda:0'), tensor(100.4885, device='cuda:0'), tensor(100.4759, device='cuda:0')]

Epoch: 99
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(99.9724, devic

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(88.3110, device='cuda:0'), tensor(88.5687, device='cuda:0'), tensor(88.4902, device='cuda:0'), tensor(88.5415, device='cuda:0'), tensor(88.4946, device='cuda:0'), tensor(88.6914, device='cuda:0'), tensor(88.6607, device='cuda:0')]

Epoch: 131
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(86.8320, device='cuda:0'), tensor(87.1393, device='cuda:0'), tensor(89.1920, device='cuda:0'), tensor(87.2185, device='cuda:0'), tensor(87.1493, device='cuda:0'), tensor(88.3345, device='cuda:0'), tensor(88.3533, device='cuda:0')]

Epoch: 132
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(86.4913, device='cuda:0'), tensor(86.8046, device='cuda:0'), tensor(86.7944, device='cuda:0'), tensor(86.8901, device='cuda:0'), tensor(86.7512, device='cuda:0'), tensor(88.0021, device='cuda:0'), tensor(88.0078, device='cuda:0')]

Epoch: 133
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(87.4993, device='cuda:0'), 

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(78.0188, device='cuda:0'), tensor(78.1589, device='cuda:0'), tensor(78.1759, device='cuda:0'), tensor(78.0672, device='cuda:0'), tensor(78.1101, device='cuda:0'), tensor(78.2816, device='cuda:0'), tensor(78.2777, device='cuda:0')]

Epoch: 163
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(77.8934, device='cuda:0'), tensor(77.7592, device='cuda:0'), tensor(77.7658, device='cuda:0'), tensor(77.8319, device='cuda:0'), tensor(77.8144, device='cuda:0'), tensor(77.9586, device='cuda:0'), tensor(77.9692, device='cuda:0')]

Epoch: 164
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(76.7634, device='cuda:0'), tensor(76.9354, device='cuda:0'), tensor(76.8912, device='cuda:0'), tensor(76.7770, device='cuda:0'), tensor(76.9482, device='cuda:0'), tensor(77.6684, device='cuda:0'), tensor(77.6717, device='cuda:0')]

Epoch: 165
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(77.9696, device='cuda:0'), 

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(68.9357, device='cuda:0'), tensor(68.8724, device='cuda:0'), tensor(68.9611, device='cuda:0'), tensor(69.0181, device='cuda:0'), tensor(68.9685, device='cuda:0'), tensor(69.0786, device='cuda:0'), tensor(69.0873, device='cuda:0')]

Epoch: 195
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(68.9715, device='cuda:0'), tensor(68.8581, device='cuda:0'), tensor(68.6562, device='cuda:0'), tensor(68.9426, device='cuda:0'), tensor(68.9136, device='cuda:0'), tensor(68.8136, device='cuda:0'), tensor(68.8181, device='cuda:0')]

Epoch: 196
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(68.5109, device='cuda:0'), tensor(68.3475, device='cuda:0'), tensor(68.4378, device='cuda:0'), tensor(68.4446, device='cuda:0'), tensor(68.4695, device='cuda:0'), tensor(68.5489, device='cuda:0'), tensor(68.5504, device='cuda:0')]

Epoch: 197
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(68.0825, device='cuda:0'), 

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(60.3478, device='cuda:0'), tensor(60.2791, device='cuda:0'), tensor(60.3083, device='cuda:0'), tensor(60.3074, device='cuda:0'), tensor(60.3065, device='cuda:0'), tensor(60.7296, device='cuda:0'), tensor(60.7454, device='cuda:0')]

Epoch: 228
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(60.2721, device='cuda:0'), tensor(60.4248, device='cuda:0'), tensor(60.3602, device='cuda:0'), tensor(60.4114, device='cuda:0'), tensor(60.3937, device='cuda:0'), tensor(60.4917, device='cuda:0'), tensor(60.5079, device='cuda:0')]

Epoch: 229
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(60.1535, device='cuda:0'), tensor(60.0711, device='cuda:0'), tensor(60.1814, device='cuda:0'), tensor(60.1726, device='cuda:0'), tensor(60.1480, device='cuda:0'), tensor(60.2573, device='cuda:0'), tensor(60.2708, device='cuda:0')]

Epoch: 230
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(59.6186, device='cuda:0'), 

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(58.7593, device='cuda:0'), tensor(58.7679, device='cuda:0'), tensor(58.7824, device='cuda:0'), tensor(58.7446, device='cuda:0'), tensor(58.7435, device='cuda:0'), tensor(58.8639, device='cuda:0'), tensor(58.8778, device='cuda:0')]

Epoch: 236
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(58.2165, device='cuda:0'), tensor(58.2444, device='cuda:0'), tensor(58.2463, device='cuda:0'), tensor(58.2545, device='cuda:0'), tensor(58.2711, device='cuda:0'), tensor(58.6380, device='cuda:0'), tensor(58.6512, device='cuda:0')]

Epoch: 237
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(58.3199, device='cuda:0'), tensor(58.2864, device='cuda:0'), tensor(58.3008, device='cuda:0'), tensor(58.3079, device='cuda:0'), tensor(58.2898, device='cuda:0'), tensor(58.4041, device='cuda:0'), tensor(58.4193, device='cuda:0')]

Epoch: 238
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(58.1475, device='cuda:0'), 

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(56.9628, device='cuda:0'), tensor(56.9101, device='cuda:0'), tensor(56.9836, device='cuda:0'), tensor(56.9525, device='cuda:0'), tensor(56.9361, device='cuda:0'), tensor(57.0540, device='cuda:0'), tensor(57.0679, device='cuda:0')]

Epoch: 244
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(57.0320, device='cuda:0'), tensor(57.0057, device='cuda:0'), tensor(57.0084, device='cuda:0'), tensor(56.9910, device='cuda:0'), tensor(56.9835, device='cuda:0'), tensor(56.8317, device='cuda:0'), tensor(56.8482, device='cuda:0')]

Epoch: 245
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(56.5807, device='cuda:0'), tensor(56.4806, device='cuda:0'), tensor(56.4789, device='cuda:0'), tensor(56.5544, device='cuda:0'), tensor(56.4983, device='cuda:0'), tensor(56.6105, device='cuda:0'), tensor(56.6265, device='cuda:0')]

Epoch: 246
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(56.0377, device='cuda:0'), 

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(55.8840, device='cuda:0'), tensor(55.9200, device='cuda:0'), tensor(55.7622, device='cuda:0'), tensor(55.9270, device='cuda:0'), tensor(55.8737, device='cuda:0'), tensor(55.7348, device='cuda:0'), tensor(55.7537, device='cuda:0')]

Epoch: 250
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(55.3955, device='cuda:0'), tensor(55.4078, device='cuda:0'), tensor(55.4104, device='cuda:0'), tensor(55.4239, device='cuda:0'), tensor(55.4226, device='cuda:0'), tensor(55.5161, device='cuda:0'), tensor(55.5371, device='cuda:0')]

Epoch: 251
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(55.4119, device='cuda:0'), tensor(55.4572, device='cuda:0'), tensor(55.4720, device='cuda:0'), tensor(55.4558, device='cuda:0'), tensor(55.4296, device='cuda:0'), tensor(55.3023, device='cuda:0'), tensor(55.3192, device='cuda:0')]

Epoch: 252
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(54.9985, device='cuda:0'), 

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(54.1130, device='cuda:0'), tensor(54.1477, device='cuda:0'), tensor(54.1323, device='cuda:0'), tensor(54.1786, device='cuda:0'), tensor(54.1372, device='cuda:0'), tensor(54.0222, device='cuda:0'), tensor(54.0428, device='cuda:0')]

Epoch: 258
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(53.7054, device='cuda:0'), tensor(53.7020, device='cuda:0'), tensor(53.6925, device='cuda:0'), tensor(53.7303, device='cuda:0'), tensor(53.7162, device='cuda:0'), tensor(53.8143, device='cuda:0'), tensor(53.8361, device='cuda:0')]

Epoch: 259
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(53.4791, device='cuda:0'), tensor(53.5026, device='cuda:0'), tensor(53.4710, device='cuda:0'), tensor(53.5162, device='cuda:0'), tensor(53.5219, device='cuda:0'), tensor(53.6062, device='cuda:0'), tensor(53.6238, device='cuda:0')]

Epoch: 260
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(53.5427, device='cuda:0'), 

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(51.6611, device='cuda:0'), tensor(51.6401, device='cuda:0'), tensor(51.6273, device='cuda:0'), tensor(51.6669, device='cuda:0'), tensor(51.6416, device='cuda:0'), tensor(51.5581, device='cuda:0'), tensor(51.5816, device='cuda:0')]

Epoch: 270
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(51.2039, device='cuda:0'), tensor(51.2268, device='cuda:0'), tensor(51.2316, device='cuda:0'), tensor(51.2349, device='cuda:0'), tensor(51.2502, device='cuda:0'), tensor(51.3606, device='cuda:0'), tensor(51.3814, device='cuda:0')]

Epoch: 271
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(50.8044, device='cuda:0'), tensor(50.8483, device='cuda:0'), tensor(50.8493, device='cuda:0'), tensor(50.8447, device='cuda:0'), tensor(50.8647, device='cuda:0'), tensor(51.1607, device='cuda:0'), tensor(51.1847, device='cuda:0')]

Epoch: 272
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(50.6501, device='cuda:0'), 

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(50.6206, device='cuda:0'), tensor(50.6058, device='cuda:0'), tensor(50.6519, device='cuda:0'), tensor(50.6182, device='cuda:0'), tensor(50.6309, device='cuda:0'), tensor(50.5682, device='cuda:0'), tensor(50.5898, device='cuda:0')]

Epoch: 275
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(50.4513, device='cuda:0'), tensor(50.4386, device='cuda:0'), tensor(50.4510, device='cuda:0'), tensor(50.4458, device='cuda:0'), tensor(50.4301, device='cuda:0'), tensor(50.3692, device='cuda:0'), tensor(50.3956, device='cuda:0')]

Epoch: 276
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(50.0895, device='cuda:0'), tensor(50.0720, device='cuda:0'), tensor(50.0524, device='cuda:0'), tensor(50.0756, device='cuda:0'), tensor(50.0504, device='cuda:0'), tensor(50.1740, device='cuda:0'), tensor(50.1983, device='cuda:0')]

Epoch: 277
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(49.8974, device='cuda:0'), 

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(48.6835, device='cuda:0'), tensor(48.7361, device='cuda:0'), tensor(48.7108, device='cuda:0'), tensor(48.6899, device='cuda:0'), tensor(48.7019, device='cuda:0'), tensor(48.8275, device='cuda:0'), tensor(48.8536, device='cuda:0')]

Epoch: 284
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(48.3403, device='cuda:0'), tensor(48.3299, device='cuda:0'), tensor(48.3710, device='cuda:0'), tensor(48.3545, device='cuda:0'), tensor(48.3694, device='cuda:0'), tensor(48.6380, device='cuda:0'), tensor(48.6651, device='cuda:0')]

Epoch: 285
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(48.3264, device='cuda:0'), tensor(48.3246, device='cuda:0'), tensor(48.3241, device='cuda:0'), tensor(48.3335, device='cuda:0'), tensor(48.3169, device='cuda:0'), tensor(48.4498, device='cuda:0'), tensor(48.4769, device='cuda:0')]

Epoch: 286
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(48.1606, device='cuda:0'), 

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(47.9443, device='cuda:0'), tensor(47.9431, device='cuda:0'), tensor(47.9569, device='cuda:0'), tensor(47.9590, device='cuda:0'), tensor(47.9479, device='cuda:0'), tensor(48.0738, device='cuda:0'), tensor(48.1033, device='cuda:0')]

Epoch: 288
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(47.9216, device='cuda:0'), tensor(47.9293, device='cuda:0'), tensor(47.9132, device='cuda:0'), tensor(47.9111, device='cuda:0'), tensor(47.9164, device='cuda:0'), tensor(47.8864, device='cuda:0'), tensor(47.9166, device='cuda:0')]

Epoch: 289
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(47.7347, device='cuda:0'), tensor(47.7183, device='cuda:0'), tensor(47.7421, device='cuda:0'), tensor(47.7269, device='cuda:0'), tensor(47.7293, device='cuda:0'), tensor(47.7001, device='cuda:0'), tensor(47.7308, device='cuda:0')]

Epoch: 290
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(47.4014, device='cuda:0'), 

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(47.0567, device='cuda:0'), tensor(47.0600, device='cuda:0'), tensor(47.0497, device='cuda:0'), tensor(47.0481, device='cuda:0'), tensor(47.0734, device='cuda:0'), tensor(47.3320, device='cuda:0'), tensor(47.3623, device='cuda:0')]

Epoch: 292
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(47.0205, device='cuda:0'), tensor(47.0345, device='cuda:0'), tensor(46.8689, device='cuda:0'), tensor(47.0183, device='cuda:0'), tensor(46.9989, device='cuda:0'), tensor(47.1494, device='cuda:0'), tensor(47.1784, device='cuda:0')]

Epoch: 293
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(46.7047, device='cuda:0'), tensor(46.6762, device='cuda:0'), tensor(46.6926, device='cuda:0'), tensor(46.6629, device='cuda:0'), tensor(46.6997, device='cuda:0'), tensor(46.9661, device='cuda:0'), tensor(46.9940, device='cuda:0')]

Epoch: 294
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(46.6786, device='cuda:0'), 

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(45.9756, device='cuda:0'), tensor(45.9276, device='cuda:0'), tensor(45.9425, device='cuda:0'), tensor(45.9191, device='cuda:0'), tensor(45.9265, device='cuda:0'), tensor(46.0629, device='cuda:0'), tensor(46.0943, device='cuda:0')]

Epoch: 299
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(45.7534, device='cuda:0'), tensor(45.7535, device='cuda:0'), tensor(45.7590, device='cuda:0'), tensor(45.7318, device='cuda:0'), tensor(45.7591, device='cuda:0'), tensor(45.8827, device='cuda:0'), tensor(45.9142, device='cuda:0')]


## Try New Initialize Scheme

In [23]:
best_sbp_acc = best_base_acc = 0 #reset best accuracy to save after running SBP


In [None]:
#This is the approach I've had the most success with in lowering the sparsity gradually from 0 to 1. 
#But it monotonically increases sparsity while destryoing the accuracy. 

sbp_learningrate = 2e-5
sbp_parameters = [
    {'params': sbp_net.block1.conv1.weight},
    {'params': sbp_net.block2.conv1.weight},
    {'params': sbp_net.block3.conv1.weight},
    {'params': sbp_net.block4.conv1.weight},
    {'params': sbp_net.block5.conv1.weight},
    
    {'params': sbp_net.block1.bn1.weight},
    {'params': sbp_net.block2.bn1.weight},
    {'params': sbp_net.block3.bn1.weight},
    {'params': sbp_net.block4.bn1.weight},
    {'params': sbp_net.block5.bn1.weight},
    
    {'params': sbp_net.block1.conv1.log_sigma, 'lr': 10*sbp_learningrate},
    {'params': sbp_net.block2.conv1.log_sigma, 'lr': 10*sbp_learningrate},
    {'params': sbp_net.block3.conv1.log_sigma, 'lr': 10*sbp_learningrate},
    {'params': sbp_net.block4.conv1.log_sigma, 'lr': 10*sbp_learningrate},
    {'params': sbp_net.block5.conv1.log_sigma, 'lr': 10*sbp_learningrate},

    {'params': sbp_net.block1.conv1.mu, 'lr': 10*sbp_learningrate},
    {'params': sbp_net.block2.conv1.mu, 'lr': 10*sbp_learningrate},
    {'params': sbp_net.block3.conv1.mu, 'lr': 10*sbp_learningrate},
    {'params': sbp_net.block4.conv1.mu, 'lr': 10*sbp_learningrate},
    {'params': sbp_net.block5.conv1.mu, 'lr': 10*sbp_learningrate},

    {'params': sbp_net.lsbp1.mu, 'lr': 10*sbp_learningrate},
    {'params': sbp_net.lsbp1.log_sigma, 'lr': 10*sbp_learningrate},
    {'params': sbp_net.lsbp2.mu, 'lr': 10*sbp_learningrate},
    {'params': sbp_net.lsbp2.log_sigma, 'lr': 10*sbp_learningrate},
    
    {'params': sbp_net.last.weight},

    
 ]




finetune_epoch = 300 ## that seems excessive

sbp_net = SBPConv_AlexNet(cfg,kl_weights=all_ones).to(device)
sbp_net = transfer_weights(sbp_net)

sbp_optimizer = optim.Adam(sbp_parameters,lr=sbp_learningrate, betas=[0.95,0.999])
sbp_scheduler = optim.lr_scheduler.StepLR(sbp_optimizer,step_size=250, gamma=0.1)

#epoch 45 is close to the last epoch of death. 
for epoch in range(0,45):
    SBP_net_train(epoch,sbp_net,optimizer=sbp_optimizer,criterion=nn.CrossEntropyLoss(),scheduler=sbp_scheduler)
    SBP_net_test(epoch,sbp_net)


Epoch: 0
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(136.4969, device='cuda:0'), tensor(136.6378, device='cuda:0'), tensor(135.9762, device='cuda:0'), tensor(135.9727, device='cuda:0'), tensor(135.8758, device='cuda:0'), tensor(136.9909, device='cuda:0'), tensor(136.4814, device='cuda:0')]

Epoch: 1
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(126.5271, device='cuda:0'), tensor(126.4764, device='cuda:0'), tensor(126.6217, device='cuda:0'), tensor(126.4917, device='cuda:0'), tensor(126.4697, device='cuda:0'), tensor(126.0857, device='cuda:0'), tensor(125.8198, device='cuda:0')]

Epoch: 2
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(116.2558, device='cuda:0'), tensor(116.3695, device='cuda:0'), tensor(116.0863, device='cuda:0'), tensor(116.0046, device='cuda:0'), tensor(116.0401, device='cuda:0'), tensor(116.5868, device='cuda:0'), tensor(117.3422, device='cuda:0')]

Epoch: 3
Saving..
Sparsity:  [0.0, 0.0, 0.

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(78.2654, device='cuda:0'), tensor(78.2457, device='cuda:0'), tensor(78.2459, device='cuda:0'), tensor(78.2703, device='cuda:0'), tensor(78.2636, device='cuda:0'), tensor(78.8594, device='cuda:0'), tensor(78.9793, device='cuda:0')]

Epoch: 8
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(72.4186, device='cuda:0'), tensor(72.5711, device='cuda:0'), tensor(72.5119, device='cuda:0'), tensor(72.5217, device='cuda:0'), tensor(72.5427, device='cuda:0'), tensor(72.9359, device='cuda:0'), tensor(73.0867, device='cuda:0')]

Epoch: 9
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(66.8745, device='cuda:0'), tensor(66.8835, device='cuda:0'), tensor(66.9159, device='cuda:0'), tensor(66.9118, device='cuda:0'), tensor(66.9136, device='cuda:0'), tensor(67.4447, device='cuda:0'), tensor(67.6218, device='cuda:0')]

Epoch: 10
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(41.8070, device='cuda:0'), tensor(41.8368, device='cuda:0'), tensor(41.8425, device='cuda:0'), tensor(41.8552, device='cuda:0'), tensor(41.8408, device='cuda:0'), tensor(42.2194, device='cuda:0'), tensor(42.3814, device='cuda:0')]

Epoch: 16
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(38.7008, device='cuda:0'), tensor(38.6870, device='cuda:0'), tensor(38.6970, device='cuda:0'), tensor(38.6814, device='cuda:0'), tensor(38.7104, device='cuda:0'), tensor(39.0567, device='cuda:0'), tensor(39.1804, device='cuda:0')]

Epoch: 17
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(35.7580, device='cuda:0'), tensor(35.7920, device='cuda:0'), tensor(35.7816, device='cuda:0'), tensor(35.7862, device='cuda:0'), tensor(35.7791, device='cuda:0'), tensor(36.1370, device='cuda:0'), tensor(36.2707, device='cuda:0')]

Epoch: 18
Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(22.3765, device='cuda:0'), tensor(22.3736, device='cuda:0'), tensor(22.3725, device='cuda:0'), tensor(22.3743, device='cuda:0'), tensor(22.3742, device='cuda:0'), tensor(22.7271, device='cuda:0'), tensor(22.8541, device='cuda:0')]

Epoch: 24

In [26]:
SBP_net_test(44,sbp_net)

Saving..
Sparsity:  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
SNRS:  [tensor(4.2766, device='cuda:0'), tensor(4.2796, device='cuda:0'), tensor(4.2981, device='cuda:0'), tensor(4.2887, device='cuda:0'), tensor(4.2776, device='cuda:0'), tensor(4.8969, device='cuda:0'), tensor(4.8912, device='cuda:0')]


49.54

In [29]:
with torch.cuda.device(-1):
    flops, params = get_model_complexity_info(best_net, (3, 32, 32), as_strings=True, print_per_layer_stat=False)
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops))

Computational complexity:       0.04 GMac


In [28]:
from ptflops import get_model_complexity_info

#have to remove the sbp layers before calling this! B/c ptflops will not support the custom layers. 
with torch.cuda.device(-1):
    flops, params = get_model_complexity_info(sbp_net, (3, 32, 32), as_strings=True, print_per_layer_stat=False)
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops))

Computational complexity:       0.0 GMac


## Try Warm-up learning Rate 

In [32]:
sbp_net.kl_weights = [1] * 10

In [37]:
best_sbp_acc = best_base_acc = 0 #reset best accuracy to save after running SBP

equal_weights = [1,1,1,1,1,1,1]
lr_x = [] #learning rate decay

def learning_rate_calc(optimizer, epoch):
    if epoch < 350:
        return 1e-5
    else: 
        return 1e-5 * (400-epoch)/(400-350)


sbp_learningrate = 1e-5
finetune_epoch = 300 ## that seems excessive
warmup_optimizer = optim.Adam(sbp_net.parameters(),lr=sbp_learningrate, betas=[0.95,0.999])
#equal_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_x)
warmup_scheduler = optim.lr_scheduler.MultiStepLR(warmup_optimizer,milestones=[10,20,50],gamma=0.1)

for epoch in range(0,100):
    SBP_net_train(epoch,sbp_net,optimizer=warmup_optimizer,criterion=nn.CrossEntropyLoss(),lr_adjust=None,scheduler=warmup_scheduler)
    SBP_net_test(epoch,sbp_net)


Epoch: 0
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8206, device='cuda:0'), tensor(0.7660, device='cuda:0'), tensor(0.8003, device='cuda:0'), tensor(0.8122, device='cuda:0'), tensor(0.7609, device='cuda:0'), tensor(1.6786, device='cuda:0'), tensor(1.3710, device='cuda:0')]

Epoch: 1
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8193, device='cuda:0'), tensor(0.7646, device='cuda:0'), tensor(0.7988, device='cuda:0'), tensor(0.8108, device='cuda:0'), tensor(0.7595, device='cuda:0'), tensor(1.6753, device='cuda:0'), tensor(1.3681, device='cuda:0')]

Epoch: 2
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8180, device='cuda:0'), tensor(0.7632, device='cuda:0'), tensor(0.7974, device='cuda:0'), tensor(0.8093, device='cuda:0'), tensor(0.7580, device='cuda:0'), tensor(1.6720, device='cuda:0'), tensor(1.3653, device='cuda:0')]

Epoch: 3
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8168, device='cuda:0'), tensor(0.7619, dev

Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8119, device='cuda:0'), tensor(0.7566, device='cuda:0'), tensor(0.7905, device='cuda:0'), tensor(0.8021, device='cuda:0'), tensor(0.7513, device='cuda:0'), tensor(1.6557, device='cuda:0'), tensor(1.3512, device='cuda:0')]

Epoch: 8
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8108, device='cuda:0'), tensor(0.7553, device='cuda:0'), tensor(0.7892, device='cuda:0'), tensor(0.8008, device='cuda:0'), tensor(0.7499, device='cuda:0'), tensor(1.6525, device='cuda:0'), tensor(1.3484, device='cuda:0')]

Epoch: 9
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8096, device='cuda:0'), tensor(0.7541, device='cuda:0'), tensor(0.7879, device='cuda:0'), tensor(0.7994, device='cuda:0'), tensor(0.7486, device='cuda:0'), tensor(1.6492, device='cuda:0'), tensor(1.3456, device='cuda:0')]

Epoch: 10
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8095, device='cuda:0'), tensor(0.7539, device='cuda

Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8089, device='cuda:0'), tensor(0.7533, device='cuda:0'), tensor(0.7871, device='cuda:0'), tensor(0.7986, device='cuda:0'), tensor(0.7479, device='cuda:0'), tensor(1.6473, device='cuda:0'), tensor(1.3439, device='cuda:0')]

Epoch: 16
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8088, device='cuda:0'), tensor(0.7532, device='cuda:0'), tensor(0.7870, device='cuda:0'), tensor(0.7984, device='cuda:0'), tensor(0.7477, device='cuda:0'), tensor(1.6469, device='cuda:0'), tensor(1.3436, device='cuda:0')]

Epoch: 17
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8087, device='cuda:0'), tensor(0.7531, device='cuda:0'), tensor(0.7868, device='cuda:0'), tensor(0.7983, device='cuda:0'), tensor(0.7476, device='cuda:0'), tensor(1.6466, device='cuda:0'), tensor(1.3433, device='cuda:0')]

Epoch: 18
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8086, device='cuda:0'), tensor(0.7529, device='cu

Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8084, device='cuda:0'), tensor(0.7528, device='cuda:0'), tensor(0.7865, device='cuda:0'), tensor(0.7980, device='cuda:0'), tensor(0.7473, device='cuda:0'), tensor(1.6459, device='cuda:0'), tensor(1.3427, device='cuda:0')]

Epoch: 24
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8084, device='cuda:0'), tensor(0.7528, device='cuda:0'), tensor(0.7865, device='cuda:0'), tensor(0.7980, device='cuda:0'), tensor(0.7473, device='cuda:0'), tensor(1.6458, device='cuda:0'), tensor(1.3427, device='cuda:0')]

Epoch: 25
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8084, device='cuda:0'), tensor(0.7528, device='cuda:0'), tensor(0.7865, device='cuda:0'), tensor(0.7980, device='cuda:0'), tensor(0.7473, device='cuda:0'), tensor(1.6458, device='cuda:0'), tensor(1.3426, device='cuda:0')]

Epoch: 26
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8084, device='cuda:0'), tensor(0.7528, device='cu

Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8084, device='cuda:0'), tensor(0.7527, device='cuda:0'), tensor(0.7865, device='cuda:0'), tensor(0.7979, device='cuda:0'), tensor(0.7473, device='cuda:0'), tensor(1.6456, device='cuda:0'), tensor(1.3425, device='cuda:0')]

Epoch: 31
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8083, device='cuda:0'), tensor(0.7527, device='cuda:0'), tensor(0.7865, device='cuda:0'), tensor(0.7979, device='cuda:0'), tensor(0.7473, device='cuda:0'), tensor(1.6456, device='cuda:0'), tensor(1.3425, device='cuda:0')]

Epoch: 32
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8083, device='cuda:0'), tensor(0.7527, device='cuda:0'), tensor(0.7865, device='cuda:0'), tensor(0.7979, device='cuda:0'), tensor(0.7472, device='cuda:0'), tensor(1.6456, device='cuda:0'), tensor(1.3425, device='cuda:0')]

Epoch: 33
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8083, device='cuda:0'), tensor(0.7527, device='cu

Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8083, device='cuda:0'), tensor(0.7526, device='cuda:0'), tensor(0.7864, device='cuda:0'), tensor(0.7978, device='cuda:0'), tensor(0.7472, device='cuda:0'), tensor(1.6454, device='cuda:0'), tensor(1.3423, device='cuda:0')]

Epoch: 39
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8083, device='cuda:0'), tensor(0.7526, device='cuda:0'), tensor(0.7864, device='cuda:0'), tensor(0.7978, device='cuda:0'), tensor(0.7472, device='cuda:0'), tensor(1.6454, device='cuda:0'), tensor(1.3423, device='cuda:0')]

Epoch: 40
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8083, device='cuda:0'), tensor(0.7526, device='cuda:0'), tensor(0.7864, device='cuda:0'), tensor(0.7978, device='cuda:0'), tensor(0.7472, device='cuda:0'), tensor(1.6453, device='cuda:0'), tensor(1.3423, device='cuda:0')]

Epoch: 41
Saving..
Sparsity:  [1.0, 1.0, 1.0, 1.0, 1.0]
SNRS:  [tensor(0.8083, device='cuda:0'), tensor(0.7526, device='cu



KeyboardInterrupt: 

In [None]:

for epoch in range(300,400):
    SBP_net_train(epoch,sbp_equal_net,optimizer=equal_optimizer,criterion=nn.CrossEntropyLoss(),lr_adjust=learning_rate_calc)
    SBP_net_test(epoch,sbp_equal_net)

## Restore and Train

In [18]:
best_sbp_acc = best_base_acc = 0 #reset best accuracy to save after running SBP
equal_weights = [1,1,1,1,1,1,1]
sbp_equal_net = SBPConv_AlexNet(cfg,kl_weights=equal_weights).to(device)

In [19]:
net_dict = torch.load('./checkpoint/sbp_ckpt.pth')
sbp_equal_net.load_state_dict(net_dict['net'])
best_acc = net_dict['best_acc']

RuntimeError: Error(s) in loading state_dict for SBPConv_AlexNet:
	Missing key(s) in state_dict: "block1.conv1.log_sigma", "block1.conv1.mu", "block1.conv1.weight", "block1.bn1.weight", "block1.bn1.bias", "block1.bn1.running_mean", "block1.bn1.running_var", "block2.conv1.log_sigma", "block2.conv1.mu", "block2.conv1.weight", "block2.bn1.weight", "block2.bn1.bias", "block2.bn1.running_mean", "block2.bn1.running_var", "block3.conv1.log_sigma", "block3.conv1.mu", "block3.conv1.weight", "block3.bn1.weight", "block3.bn1.bias", "block3.bn1.running_mean", "block3.bn1.running_var", "block4.conv1.log_sigma", "block4.conv1.mu", "block4.conv1.weight", "block4.bn1.weight", "block4.bn1.bias", "block4.bn1.running_mean", "block4.bn1.running_var", "block5.conv1.log_sigma", "block5.conv1.mu", "block5.conv1.weight", "block5.bn1.weight", "block5.bn1.bias", "block5.bn1.running_mean", "block5.bn1.running_var", "lsbp1.weight", "lsbp1.bias", "lsbp1.log_sigma", "lsbp1.mu", "lsbp2.weight", "lsbp2.bias", "lsbp2.log_sigma", "lsbp2.mu", "last.weight", "last.bias". 
	Unexpected key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.2.running_mean", "features.2.running_var", "features.2.num_batches_tracked", "features.4.weight", "features.4.bias", "features.6.weight", "features.6.bias", "features.6.running_mean", "features.6.running_var", "features.6.num_batches_tracked", "features.8.weight", "features.8.bias", "features.10.weight", "features.10.bias", "features.10.running_mean", "features.10.running_var", "features.10.num_batches_tracked", "features.11.weight", "features.11.bias", "features.13.weight", "features.13.bias", "features.13.running_mean", "features.13.running_var", "features.13.num_batches_tracked", "features.14.weight", "features.14.bias", "features.16.weight", "features.16.bias", "features.16.running_mean", "features.16.running_var", "features.16.num_batches_tracked", "classifier.1.weight", "classifier.1.bias", "classifier.4.weight", "classifier.4.bias", "classifier.6.weight", "classifier.6.bias". 

In [31]:
net_dict['net'].keys()

odict_keys(['features.0.weight', 'features.0.bias', 'features.2.weight', 'features.2.bias', 'features.2.running_mean', 'features.2.running_var', 'features.2.num_batches_tracked', 'features.4.weight', 'features.4.bias', 'features.6.weight', 'features.6.bias', 'features.6.running_mean', 'features.6.running_var', 'features.6.num_batches_tracked', 'features.8.weight', 'features.8.bias', 'features.10.weight', 'features.10.bias', 'features.10.running_mean', 'features.10.running_var', 'features.10.num_batches_tracked', 'features.11.weight', 'features.11.bias', 'features.13.weight', 'features.13.bias', 'features.13.running_mean', 'features.13.running_var', 'features.13.num_batches_tracked', 'features.14.weight', 'features.14.bias', 'features.16.weight', 'features.16.bias', 'features.16.running_mean', 'features.16.running_var', 'features.16.num_batches_tracked', 'classifier.1.weight', 'classifier.1.bias', 'classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias'])

In [27]:
list(sbp_equal_net.parameters())

[Parameter containing:
 tensor([-5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5.,
         -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5.,
         -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5.,
         -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5., -5.,
         -5., -5., -5., -5., -5., -5., -5., -5.], device='cuda:0',
        requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([[[[-0.1025, -0.4885,  0.4237],
           [-0.0737,  0.5720,  0.1400],
           [ 0.4433,  0.4696,  0.2519]],
 
          [[ 0.5657, -0.2519,  0.2215],
           [-0.1572,  0.09

In [32]:
for epoch in range(0,300):
    SBP_net_train(epoch,sbp_equal_net,optimizer=equal_optimizer,criterion=nn.CrossEntropyLoss())
    SBP_net_test(epoch,sbp_equal_net)

In [33]:
lr_x

[1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,


In [40]:
for epoch in range(0,300):
    SBP_net_train(epoch,sbp_net,optimizer=optimizer,criterion=nn.CrossEntropyLoss())
    SBP_net_test(epoch,sbp_net)


Epoch: 30
[0.0, 0.0, 0.0, 0.0, 0.0]
[tensor(43.5962, device='cuda:0'), tensor(43.5865, device='cuda:0'), tensor(43.5885, device='cuda:0'), tensor(43.5875, device='cuda:0'), tensor(43.5872, device='cuda:0')]

Epoch: 31
Saving..
[0.0, 0.0, 0.0, 0.0, 0.0]
[tensor(41.9232, device='cuda:0'), tensor(41.9329, device='cuda:0'), tensor(41.9390, device='cuda:0'), tensor(41.9393, device='cuda:0'), tensor(41.9408, device='cuda:0')]

Epoch: 32
Saving..
[0.0, 0.0, 0.0, 0.0, 0.0]
[tensor(40.3647, device='cuda:0'), tensor(40.3841, device='cuda:0'), tensor(40.3931, device='cuda:0'), tensor(40.3912, device='cuda:0'), tensor(40.3956, device='cuda:0')]

Epoch: 33
Saving..
[0.0, 0.0, 0.0, 0.0, 0.0]
[tensor(38.8237, device='cuda:0'), tensor(38.8300, device='cuda:0'), tensor(38.8299, device='cuda:0'), tensor(38.8293, device='cuda:0'), tensor(38.8302, device='cuda:0')]

Epoch: 34
[0.0, 0.0, 0.0, 0.0, 0.0]
[tensor(37.3455, device='cuda:0'), tensor(37.3460, device='cuda:0'), tensor(37.3422, device='cuda:0'), t

### Resume Training

In [23]:
print('==> Resuming from checkpoint..')
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
checkpoint = torch.load('./checkpoint/ckpt.pth')
sbp_net.load_state_dict(checkpoint['net'])
best_acc = checkpoint['best_acc']
start_epoch = 20 # set manually from previous training. 

==> Resuming from checkpoint..


In [29]:
x = torch.randn(2,3,32,32)
sbp_net.eval()
sbp_net(x)
print(sbp_net.features[0].sbp.mask)
print(sbp_net.features[0].sbp.multiplicator)

SNR:  tensor(7289.9297)
MASK:  tensor(64.)
SNR:  tensor(21809.4375)
MASK:  tensor(192.)
SNR:  tensor(43668.1445)
MASK:  tensor(384.)
SNR:  tensor(29090.8965)
MASK:  tensor(256.)
SNR:  tensor(29085.9785)
MASK:  tensor(256.)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([[[0.9588, 0.9588, 0.9588,  ..., 0.9588, 0.9588, 0.9588],
         [0.9588, 0.9588, 0.9588,  ..., 0.9588, 0.9588, 0.9588],
         [0.9588, 0.9588, 0.9588,  ..., 0.9588, 0.9588, 0.9588],
         ...,
         [0.9588, 0.9588, 0.9588,  ..., 0.9588, 0.9588, 0.9588],
         [0.9588, 0.9588, 0.9588,  ..., 0.9588, 0.9588, 0.9588],
         [0.9588, 0.9588, 0.9588,  ..., 0.9588, 0.9588, 0.9588]],

        [[0.9588, 0.9588, 0.9588,  ..., 0.9588, 0.9588, 0.9588],
         [0.9588,

In [45]:
print(best_acc)

50.39


In [None]:
for epoch in range(20,45):
    SBP_net_train(epoch,sbp_net)
    SBP_net_test(epoch,sbp_net)


Epoch: 20

Epoch: 21

Epoch: 22

Epoch: 23

Epoch: 24
Saving..

Epoch: 25

Epoch: 26

Epoch: 27

Epoch: 28

Epoch: 29

Epoch: 30

Epoch: 31

Epoch: 32

Epoch: 33
Saving..

Epoch: 34

Epoch: 35

Epoch: 36

Epoch: 37

Epoch: 38
 [=>.......................]  Step: 261ms | Tot: 5s400ms | Loss: 7.590 | Acc: 63.594% (1628/256 20/391 

### Get Mask & Prune Network

[SBP_Block(
   (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (sbp): SBP_layer()
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (rel): ReLU(inplace=True)
 ),
 MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False),
 SBP_Block(
   (conv1): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (sbp): SBP_layer()
   (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (rel): ReLU(inplace=True)
 ),
 MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False),
 SBP_Block(
   (conv1): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (sbp): SBP_layer()
   (bn1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (rel): ReLU(inplace=True)
 ),
 SBP_Block(
   (conv1): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   (sbp): SBP_layer()
   (bn1): BatchNorm2d(2

In [70]:
def cal_importance_0(net, l_id, num_stop=100):
    layer_weights = net.features[0].sbp.weight.data
    imp_corr_bn = layer_weights.abs().sum(dim=(1,2,3))        
    neuron_order = [np.linspace(0, imp_corr_bn.shape[0]-1, imp_corr_bn.shape[0]), imp_corr_bn]
    return neuron_order

In [1]:
def snr_truncated_log_normal(mu, sigma, a, b):
    alpha = (a - mu)/sigma
    beta = (b - mu)/sigma
    z = phi(beta) - phi(alpha)
    ratio = erfcx((sigma-beta)/(2 ** 0.5))*torch.exp((b-mu)-beta**2/2.0)
    ratio = ratio - erfcx((sigma-alpha)/2 ** 0.5)*torch.exp((a-mu)-alpha**2/2.0)
    denominator = 2*z*erfcx((2.0*sigma-beta)/2 ** 0.5)*torch.exp(2.0*(b-mu)-beta**2/2.0)
    denominator = denominator - 2*z*erfcx((2.0*sigma-alpha)/(2 ** 0.5))*torch.exp(2.0*(a-mu)-alpha**2/2.0)
    denominator = denominator - ratio**2
    ratio = ratio/torch.sqrt(1e-8 + denominator)
    print(denominator)
    return ratio

In [2]:
## get the sbp layer mask
def get_mask(mu,sigma,min_log=-20, max_log=0):
    multiplicator = SBP_utils.mean_truncated_log_normal_reduced(mu, sigma, min_log, max_log)
    print("multiplicator: ",multiplicator)
    snr = snr_truncated_log_normal(mu, sigma, min_log, max_log)
    print("snr: ", snr)
    mask = snr
    
    mask[snr <= 1.0] = 0.0
    mask[snr > 1.0] = 1.
    
    print("mask: ",mask)
    multiplicator = multiplicator * mask
    
    return mask, multiplicator

In [107]:
sbp_net.features[0].sbp.log_sigma.detach()

tensor([-4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441])

In [3]:
mask_list = []
multiplicator_list = []
for i in sbp_net.features:
    if not (isinstance(i,nn.MaxPool2d)):
        #print(i)
        mu = i.sbp.mu.detach()
        sigma = i.sbp.log_sigma.detach()
        mask,multiplicator = get_mask(mu,sigma)
        mask_list.append(mask)
        multiplicator_list.append(multiplicator)

NameError: name 'sbp_net' is not defined

In [118]:
x = torch.randn(2,3,32,32)
sbp_net.eval()
y = sbp_net.forward(x)
print(y,y.size())


tensor([[ -6.5035,  -8.3538,  -6.6634, -10.2934, -17.9784,  -5.0302, -11.5280,
         -10.6385,  -1.8071,  -4.3939,   4.3780, -12.2435, -13.1021, -10.5908,
          -7.5355, -12.7850,   2.3540, -12.5417,  -9.5258, -16.9257,  -8.0328,
         -16.6011,   4.9747,  -3.3861, -12.9907,  -4.4500,  -5.9589, -14.0933,
          -4.2610, -14.5732, -16.0110, -13.5321,  -3.9859,  -2.0124, -13.0010,
          -9.6966,  -9.1011,  -3.9229, -16.7927,   8.4844,  -8.6888, -13.6158,
          -8.8288,  -9.8294, -15.8931,  -3.8235, -12.4232,  -6.1532, -10.6135,
         -11.1450,  -8.6385,  -7.6694, -15.9593,  -3.3235,  -4.4470, -18.4937,
         -20.6754,  -0.6448, -18.1565,  -9.3899, -13.1982,   7.3392,  -3.7231,
          -9.8469, -13.8983,  -7.2742, -11.2102, -11.6250, -19.0181, -17.3108,
           3.8214, -17.8009, -10.2750, -17.4043, -15.8597, -20.3146, -11.4058,
         -14.2033,   1.2470,  -9.8833, -10.9950,  -7.2665,  -4.5578,  -1.9062,
          -7.2311, -17.2258,  -2.7540,  -8.5210,  -3

In [137]:
sbp_net.eval()
out_1 = sbp_net.features[0].conv1(x)
print(out.shape)
out_2 = sbp_net.features[0].sbp(out_1)
mu = sbp_net.features[0].sbp.mu.detach()
sigma = sbp_net.features[0].sbp.log_sigma.detach()
print(sbp_net.features[0].sbp.mu)
print(sbp_net.features[0].sbp.log_sigma)
print(get_mask(mu,sigma))
#print(out_1)

torch.Size([2, 64, 16, 16])
Parameter containing:
tensor([-0.0501, -0.0499, -0.0500, -0.0501, -0.0501, -0.0500, -0.0497, -0.0501,
        -0.0500, -0.0500, -0.0502, -0.0500, -0.0501, -0.0500, -0.0498, -0.0500,
        -0.0501, -0.0509, -0.0499, -0.0500, -0.0500, -0.0500, -0.0501, -0.0500,
        -0.0501, -0.0500, -0.0498, -0.0500, -0.0500, -0.0500, -0.0500, -0.0500,
        -0.0500, -0.0499, -0.0501, -0.0500, -0.0501, -0.0500, -0.0501, -0.0499,
        -0.0500, -0.0500, -0.0497, -0.0501, -0.0500, -0.0501, -0.0495, -0.0500,
        -0.0501, -0.0499, -0.0500, -0.0500, -0.0500, -0.0501, -0.0500, -0.0505,
        -0.0500, -0.0500, -0.0500, -0.0498, -0.0502, -0.0500, -0.0499, -0.0513],
       requires_grad=True)
Parameter containing:
tensor([-4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441, -4.6441,
        -4.6441, -4.

tensor([[[[-4.3465e-01, -4.3465e-01, -1.7600e-01,  ..., -4.3465e-01,
            1.7735e+00, -4.3465e-01],
          [ 1.3551e+00,  5.0437e-01, -4.3465e-01,  ..., -4.3465e-01,
           -3.1233e-01,  1.1398e+00],
          [-4.3465e-01, -4.3465e-01, -4.3465e-01,  ...,  5.5210e-01,
           -4.3465e-01, -4.3465e-01],
          ...,
          [ 2.9487e+00, -4.3465e-01, -4.3465e-01,  ..., -4.3465e-01,
            5.1174e-01, -4.3465e-01],
          [-4.3465e-01, -3.3155e-01, -4.3465e-01,  ..., -4.3465e-01,
           -4.3465e-01, -4.3465e-01],
          [-4.3465e-01,  1.4996e+00,  3.1659e+00,  ..., -4.3465e-01,
            9.4603e-01,  3.1426e+00]],

         [[-5.0862e-01,  1.2093e+00, -5.0862e-01,  ...,  9.4067e-02,
           -5.0862e-01, -5.0862e-01],
          [-5.0862e-01,  3.7695e-01,  4.3916e-01,  ..., -5.0862e-01,
            3.3940e+00,  4.3031e-01],
          [ 2.3318e+00,  1.4081e+00,  4.7294e+00,  ..., -5.0862e-01,
           -5.0862e-01, -5.0862e-01],
          ...,
     

In [71]:
cal_importance_0(sbp_net,l_id=1,num_stop=100)

AttributeError: 'SBP_layer' object has no attribute 'weight'

### PTFLOPS

In [None]:
pt

In [29]:
from ptflops import get_model_complexity_info

with torch.cuda.device(-1):
    flops, params = get_model_complexity_info(sbp_net, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops))

SBPConv_AlexNet(
  20.516 M, 100.000% Params, 0.001 GMac, 100.000% MACs, 
  (block1): SBP_ConvBlock(
    0.002 M, 0.010% Params, 0.0 GMac, 8.978% MACs, 
    (conv1): Conv2d_SBP(0.002 M, 0.009% Params, 0.0 GMac, 0.000% MACs, )
    (bn1): BatchNorm2d(0.0 M, 0.001% Params, 0.0 GMac, 5.986% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 2.993% MACs, inplace=True)
  )
  (mp1): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 2.993% MACs, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block2): SBP_ConvBlock(
    0.112 M, 0.544% Params, 0.0 GMac, 5.155% MACs, 
    (conv1): Conv2d_SBP(0.111 M, 0.542% Params, 0.0 GMac, 0.000% MACs, )
    (bn1): BatchNorm2d(0.0 M, 0.002% Params, 0.0 GMac, 3.437% MACs, 192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 1.718% MACs, inplace=True)
  )
  (mp2): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 1.718% MACs

In [39]:
with torch.cuda.device(-1):
    #flops, params = get_model_complexity_info(sbp_net, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
    flops, params = get_model_complexity_info(best_net, (3, 32, 32), as_strings=True, print_per_layer_stat=True)

AlexNet(
  20.498 M, 100.000% Params, 0.044 GMac, 100.000% MACs, 
  (features): Sequential(
    2.254 M, 10.996% Params, 0.025 GMac, 58.072% MACs, 
    (0): Conv2d(0.002 M, 0.009% Params, 0.0 GMac, 1.054% MACs, 3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.038% MACs, inplace=True)
    (2): BatchNorm2d(0.0 M, 0.001% Params, 0.0 GMac, 0.075% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.038% MACs, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(0.111 M, 0.540% Params, 0.005 GMac, 12.476% MACs, 64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.022% MACs, inplace=True)
    (6): BatchNorm2d(0.0 M, 0.002% Params, 0.0 GMac, 0.043% MACs, 192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.022% MAC

In [None]:
sbp_net.

### Flop weighted importance

In [28]:
### importance for SBP Net

l_imp = {}

for conv_ind in [0, 2, 4, 5, 6]:
    l_imp.update({conv_ind: net.features[conv_ind].bn1.bias.shape[0]})
## do not need to update the classifer indices b/c no blocks 
for lin_ind in [1, 4]:
    l_imp.update({lin_ind: net.classifier[lin_ind].bias.shape[0]})
    
normalizer = 0
for key, val in l_imp.items():
    normalizer += val
for key, val in l_imp.items():
    l_imp[key] = val / normalizer

In [29]:
optimizer = optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)

In [30]:
net = SBP_AlexNet(cfg)
net_ortho = SBP_AlexNet(cfg)

In [31]:
for epoch in range(30):
    SBP_net_train(epoch,net)
    SBP_net_test(epoch,net)
    


Epoch: 0
Saving..

Epoch: 1

Epoch: 2
Saving..

Epoch: 3

Epoch: 4
 [=>.......................]  Step: 279ms | Tot: 4s887ms | Loss: 11.956 | Acc: 0.873% (19/217 17/391 

KeyboardInterrupt: 

In [24]:
for epoch in range(30):
    SBP_net_train_ortho(epoch,net_ortho)
    w_diag(net_ortho)
    SBP_net_test_ortho(epoch,net_ortho)


Epoch: 0

Epoch: 0
angle_cost:  283.252374375


TypeError: net_test() takes 1 positional argument but 2 were given

# Inner product training

In [16]:
alex_net = AlexNet(cfg)
print(alex_net)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Conv2d(2

In [18]:
alex_net.features[2].bias

Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       requires_grad=True)

In [None]:
net_ortho = AlexNet(cfg).to(device)
net_dict = torch.load('./ortho_checkpoint/ckpt.pth')
net_ortho.load_state_dict(net_dict['net'])
best_acc_ortho = net_dict['best_acc']

In [25]:
### importance for SBP Net

l_imp = {}

for conv_ind in [0, 2, 4, 5, 6]:
    l_imp.update({conv_ind: net.features[conv_ind].bn1.bias.shape[0]})
## do not need to update the classifer indices b/c no blocks 
for lin_ind in [1, 4]:
    l_imp.update({lin_ind: net.classifier[lin_ind].bias.shape[0]})
    
normalizer = 0
for key, val in l_imp.items():
    normalizer += val
for key, val in l_imp.items():
    l_imp[key] = val / normalizer

In [None]:
### normal AlexNet
l_imp = {}

for conv_ind in [2, 6, 10, 13, 16]:
    l_imp.update({conv_ind: net.features[conv_ind].bias.shape[0]})
    
for lin_ind in [1, 4]:
    l_imp.update({lin_ind: net.classifier[lin_ind].bias.shape[0]})
    
normalizer = 0
for key, val in l_imp.items():
    normalizer += val
for key, val in l_imp.items():
    l_imp[key] = val / normalizer

In [18]:
def net_test_ortho(epoch):
    global best_acc_ortho
    net_ortho.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_ortho(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    print(acc)
    if acc > best_acc_ortho:
        print('Saving..')
        state = {
            'net': net_ortho.state_dict(),
            'best_acc': acc
        }
        if not os.path.isdir('ortho_checkpoint'):
            os.mkdir('ortho_checkpoint')
        torch.save(state, './ortho_checkpoint/ckpt.pth')
        best_acc_ortho = acc

In [20]:

for conv_ind in [6, 10, 13, 16]:
    print(alex_net.features[conv_ind-2])

Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


In [19]:
def SBP_net_train_ortho(epoch,net_ortho):
    print('\nEpoch: %d' % epoch)
    net_ortho.train()
    correct = 0
    total = 0
    running_loss = 0.0
    angle_cost = 0.0
            
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs, kl = net_ortho(inputs)
        L_angle = 0
        
        ### Conv_ind == 0 ###
        w_mat = net_ortho.features[0].conv1.weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_ortho.features[0].conv1.bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)
        L_angle += (l_imp[2])*(angle_mat).norm(1) #.norm().pow(2))

        ### Conv_ind != 0 ###
        for conv_ind in [2, 4, 5, 6]:
            w_mat = net_ortho.features[conv_ind].conv1.weight
            w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
            b_mat = net_ortho.features[conv_ind].conv1.bias
            b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
            params = torch.cat((w_mat1, b_mat1), dim=1)
            angle_mat = torch.matmul(params, torch.t(params)) - torch.eye(w_mat.shape[0]).to(device)            
            L_angle += (l_imp[conv_ind])*(angle_mat).norm(1) #.norm().pow(2))
    
        ### lin_ind = 1 ###        
        w_mat = net_ortho.classifier[1].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_ortho.classifier[1].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)
        L_angle += (l_imp[1])*(angle_mat).norm(1) #.norm().pow(2))
        
        ### lin_ind = 4 ###        
        w_mat = net_ortho.classifier[4].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_ortho.classifier[4].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(params, torch.t(params)) - torch.eye(params.shape[0]).to(device)
        L_angle += (l_imp[4])*(angle_mat).norm(1) #.norm().pow(2))        
        
        Lc = criterion(outputs, labels)
        loss = (1e-1)*(L_angle) + Lc + kl #from the sparsity inducer
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        angle_cost += (L_angle).item()
    
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (running_loss/(batch_idx+1), 100.*correct/total, correct, total))
    
    print("angle_cost: ", angle_cost/total)

In [None]:
def net_train_ortho(epoch):
    print('\nEpoch: %d' % epoch)
    net_ortho.train()
    correct = 0
    total = 0
    running_loss = 0.0
    angle_cost = 0.0
            
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net_ortho(inputs)
        L_angle = 0
        
        ### Conv_ind == 0 ###
        w_mat = net_ortho.features[0].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_ortho.features[0].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)
        L_angle += (l_imp[2])*(angle_mat).norm(1) #.norm().pow(2))

        ### Conv_ind != 0 ###
        for conv_ind in [6, 10, 13, 16]:
            w_mat = net_ortho.features[conv_ind-2].weight
            w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
            b_mat = net_ortho.features[conv_ind-2].bias
            b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
            params = torch.cat((w_mat1, b_mat1), dim=1)
            angle_mat = torch.matmul(params, torch.t(params)) - torch.eye(w_mat.shape[0]).to(device)            
            L_angle += (l_imp[conv_ind])*(angle_mat).norm(1) #.norm().pow(2))
    
        ### lin_ind = 1 ###        
        w_mat = net_ortho.classifier[1].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_ortho.classifier[1].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)
        L_angle += (l_imp[1])*(angle_mat).norm(1) #.norm().pow(2))
        
        ### lin_ind = 4 ###        
        w_mat = net_ortho.classifier[4].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
        b_mat = net_ortho.classifier[4].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))            
        params = torch.cat((w_mat1, b_mat1), dim=1)
        angle_mat = torch.matmul(params, torch.t(params)) - torch.eye(params.shape[0]).to(device)
        L_angle += (l_imp[4])*(angle_mat).norm(1) #.norm().pow(2))        
        
        Lc = criterion(outputs, labels)
        loss = (1e-1)*(L_angle) + Lc
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        angle_cost += (L_angle).item()
    
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (running_loss/(batch_idx+1), 100.*correct/total, correct, total))
    
    print("angle_cost: ", angle_cost/total)

In [22]:
import torch.optim as optim
criterion = nn.CrossEntropyLoss()

In [15]:
SBP_net_ortho = SBP_AlexNet(cfg).to(device)

In [23]:
optimizer = optim.Adam(SBP_net_ortho.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

In [17]:
SBP_net_ortho

SBP_AlexNet(
  (block1): SBP_Block(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (sbp): SBP_layer()
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (mp1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block2): SBP_Block(
    (conv1): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (sbp): SBP_layer()
    (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (mp2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block3): SBP_Block(
    (conv1): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (sbp): SBP_layer()
    (bn1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (block4): SBP_Block(
    (conv1): Conv2d(384, 256, kernel_size=

In [25]:
for epoch in range(2):
    #SBP_net_train(epoch,SBP_net_ortho)
    SBP_net_train_ortho(epoch,SBP_net_ortho)
    #net_test_ortho(epoch)
    #w_diag(net_ortho)

NameError: name 'SBP_net_ortho' is not defined

In [None]:
# PATH = './w_decorr/base_params/wnet_base_2.pth'
# torch.save(net.state_dict(), PATH)

# Importance analysis

### Layer index

In [35]:
alex_net

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): Conv2d(2

In [37]:
SBP_net_ortho

SBP_AlexNet(
  (block1): SBP_Block(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (sbp): SBP_layer()
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (mp1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block2): SBP_Block(
    (conv1): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (sbp): SBP_layer()
    (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (mp2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block3): SBP_Block(
    (conv1): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (sbp): SBP_layer()
    (bn1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (rel): ReLU(inplace=True)
  )
  (block4): SBP_Block(
    (conv1): Conv2d(384, 256, kernel_size=

In [39]:
l_index = 4
layer_id = 'bn'

### Correlated Net

In [None]:
net = AlexNet(cfg).to(device)
net_dict = torch.load('./checkpoint/ckpt.pth')
net.load_state_dict(net_dict['net'])
net = net.eval()

In [40]:
#changed for SBP_Aexnet
weight_base = net.features[l_index].bn1.weight.data.clone().detach()
bias_base = net.features[l_index].bn1.bias.data.clone().detach()

In [41]:
loss_base_corr = 0
num_stop = 0
for epoch in range(1):
#     for i, data in enumerate(testloader, 0):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs, kl = net(inputs)
        loss = criterion(outputs, labels) + kl
        loss_base_corr += loss.item()
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
loss_base_corr = loss_base_corr**2

In [46]:
## changed for SBP_Alexnet

loss_mat_corr = torch.zeros(weight_base.shape[0])

for n_index in tqdm(range(weight_base.shape[0])): 
    num_stop = 0
    print(n_index)
    running_loss = 0.0

    net.features[l_index].bn1.weight.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
    net.features[l_index].bn1.bias.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
    
#     for i, data in enumerate(testloader, 0):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        outputs, kl = net(inputs)

        loss = (criterion(outputs, labels)) + kl

        running_loss += loss.item()
        
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
            
    loss_mat_corr[n_index] = running_loss**2
    
    net.features[l_index].bn1.weight.data = weight_base.clone().detach()
    net.features[l_index].bn1.bias.data = bias_base.clone().detach()

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

In [44]:
#normal alex net w/ no KL loss
loss_base_corr = 0
num_stop = 0
for epoch in range(1):
#     for i, data in enumerate(testloader, 0):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss_base_corr += loss.item()
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
loss_base_corr = loss_base_corr**2

In [None]:
### SBP_Alexnet
### Change bn1 to access other layers in the block.
loss_mat_corr = torch.zeros(weight_base.shape[0])

for n_index in tqdm(range(weight_base.shape[0])): 
    num_stop = 0
    print(n_index)
    running_loss = 0.0

    net.features[l_index].bn1.weight.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
    net.features[l_index].bn1.bias.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
    
#     for i, data in enumerate(testloader, 0):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        outputs, kl = net(inputs)

        loss = (criterion(outputs, labels)) + kl

        running_loss += loss.item()
        
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
            
    loss_mat_corr[n_index] = running_loss**2
    
    net.features[l_index].bn1.weight.data = weight_base.clone().detach()
    net.features[l_index].bn1.bias.data = bias_base.clone().detach()

In [None]:
### normal AlexNet

loss_mat_corr = torch.zeros(weight_base.shape[0])

for n_index in range(weight_base.shape[0]): 
    num_stop = 0
    print(n_index)
    running_loss = 0.0

    net.features[l_index].weight.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
    net.features[l_index].bias.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
    
#     for i, data in enumerate(testloader, 0):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        outputs = net(inputs)

        loss = (criterion(outputs, labels))

        running_loss += loss.item()
        
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
            
    loss_mat_corr[n_index] = running_loss**2
    
    net.features[l_index].weight.data = weight_base.clone().detach()
    net.features[l_index].bias.data = bias_base.clone().detach()

In [None]:
# torch.save(loss_mat_corr, './w_decorr/loss_mats/corr/'+str(l_index)+'/loss_corr_bn_train_'+str(l_index)+'.pt')
loss_mat_corr = torch.load('./w_decorr/loss_mats/corr/'+str(l_index)+'/loss_corr_bn_train_'+str(l_index)+'.pt')

In [None]:

### SBP AlexNet
optimizer = optim.SGD(net.parameters(), lr=0, weight_decay=0)
av_corrval = 0
n_epochs = 1

for epoch in tqdm(range(n_epochs)):
    num_stop = 0
    running_loss = 0.0
    imp_corr_conv = torch.zeros(bias_base.shape[0]).to(device)
    imp_corr_bn = torch.zeros(bias_base.shape[0]).to(device)
    
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs, kl = net(inputs) + kl
        loss = criterion(outputs, labels)
        loss.backward()
        
        imp_corr_bn += (((net.features[l_index].bn1.weight.grad)*(net.features[l_index].bn1.weight.data)) + ((net.features[l_index].bn1.bias.grad)*(net.features[l_index].bn1.bias.data))).abs().pow(2)
        
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
         
    corrval = (np.corrcoef(imp_corr_bn.cpu().detach().numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().detach().numpy()))
    print("Correlation at epoch "+str(epoch)+": "+str(corrval[0,1]))
    av_corrval += corrval[0,1]

In [None]:

### Normal AlexNet
optimizer = optim.SGD(net.parameters(), lr=0, weight_decay=0)
av_corrval = 0
n_epochs = 1

for epoch in range(n_epochs):
    num_stop = 0
    running_loss = 0.0
    imp_corr_conv = torch.zeros(bias_base.shape[0]).to(device)
    imp_corr_bn = torch.zeros(bias_base.shape[0]).to(device)
    
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        imp_corr_bn += (((net.features[l_index].weight.grad)*(net.features[l_index].weight.data)) + ((net.features[l_index].bias.grad)*(net.features[l_index].bias.data))).abs().pow(2)
        
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
         
    corrval = (np.corrcoef(imp_corr_bn.cpu().detach().numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().detach().numpy()))
    print("Correlation at epoch "+str(epoch)+": "+str(corrval[0,1]))
    av_corrval += corrval[0,1]

### Decorrelated net

In [None]:
net_decorr = AlexNet(cfg).to(device)
net_dict = torch.load('./ortho_checkpoint/ckpt.pth')
net_decorr.load_state_dict(net_dict['net'])
net_decorr = net_decorr.eval() 

In [None]:
weight_base = net_decorr.features[l_index].weight.data.clone().detach()
bias_base = net_decorr.features[l_index].bias.data.clone().detach()

In [None]:
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)
num_stop = 0
loss_base_decorr = 0
for epoch in range(1):
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):        
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = net_decorr(inputs)
        loss = criterion(outputs, labels)
        loss_base_decorr += loss.item()
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
loss_base_decorr = loss_base_decorr**2

In [None]:
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)

loss_mat_decorr = torch.zeros(weight_base.shape[0])

for n_index in range(weight_base.shape[0]): 
    print(n_index)
    num_stop = 0
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        net_decorr.features[l_index].weight.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
        net_decorr.features[l_index].bias.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
        outputs = net_decorr(inputs)
        
        loss = criterion(outputs, labels)
        
        running_loss += loss.item()
        
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
            
    loss_mat_decorr[n_index] = running_loss**2
    
    net_decorr.features[l_index].weight.data = weight_base.clone().detach()
    net_decorr.features[l_index].bias.data = bias_base.clone().detach()

In [None]:
# torch.save(loss_mat_decorr, './w_decorr/loss_mats/decorr/'+str(l_index)+'/loss_decorr_bn_train_'+str(l_index)+'.pt')
loss_mat_decorr = torch.load('./w_decorr/loss_mats/decorr/'+str(l_index)+'/loss_decorr_bn_train_'+str(l_index)+'.pt')

In [None]:
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)
av_corrval = 0
n_epochs = 1

for epoch in range(n_epochs):
    num_stop = 0
    imp_decorr_conv = torch.zeros(bias_base.shape[0]).to(device)
    imp_decorr_bn = torch.zeros(bias_base.shape[0]).to(device)

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
#     for i, data in enumerate(testloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net_decorr(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
        
        imp_decorr_bn += (((net_decorr.features[l_index].weight.grad)*(net_decorr.features[l_index].weight.data)) + ((net_decorr.features[l_index].bias.grad)*(net_decorr.features[l_index].bias.data))).abs().pow(2)
    
    corrval = (np.corrcoef(imp_decorr_bn.cpu().detach().numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().detach().numpy()))
    print("Correlation at epoch "+str(epoch)+": "+str(corrval[0,1]))
    av_corrval += corrval[0,1]

# Graphs

In [None]:
figure(figsize=(20,5))
s = imp_corr_bn.cpu().sort()[0].cpu().numpy()
order = imp_corr_bn.sort()[1].cpu().numpy()
plt.plot(s/s.max(), label="Estimated importance")
plt.title("Correlated (Taylor FO) for "+str(l_index))
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
plt.xlabel("Neuron index")
plt.ylabel("Normalized importance")
plt.plot(loss_diff[order]/loss_diff.max(), label="Actual importance")
plt.legend()
plt.savefig("./w_decorr/loss_mats/corr/graphs/"+str(l_index)+".png")

In [None]:
figure(figsize=(20,5))
s = imp_decorr_bn.cpu().sort()[0].cpu().numpy()
order = imp_decorr_bn.sort()[1].cpu().numpy()
plt.plot(s/s.max(), label="Estimated importance")
plt.title("Decorrelated (Taylor FO) for "+str(l_index))
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
plt.xlabel("Neuron index")
plt.ylabel("Normalized importance")
plt.plot(loss_diff[order]/loss_diff.max(), label="Actual importance")
plt.legend()
plt.savefig("./w_decorr/loss_mats/decorr/graphs/"+str(l_index)+".png")

In [None]:
s = imp_decorr_bn.cpu().sort()[0].cpu().numpy()
s = s/s.max()
order = imp_decorr_bn.sort()[1].cpu().numpy()
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
loss_diff = (loss_diff[order]/loss_diff.max())
ortho_rms = ((loss_diff - s)**2).sum()

s = imp_corr_bn.cpu().sort()[0].cpu().numpy()
s = s/s.max()
order = imp_corr_bn.sort()[1].cpu().numpy()
loss_diff = (loss_mat_corr - loss_base_corr).abs()
loss_diff = (loss_diff[order]/loss_diff.max())

base_rms = ((loss_diff - s)**2).sum()

In [None]:
(ortho_rms, base_rms)

In [None]:
# rms_ortho = np.sqrt(np.array(rms_ortho) / np.array([64, 64, 128, 128, 256, 256, 512, 512, 512, 512]))
# rms_base = np.sqrt(np.array(rms_base) / np.array([64, 64, 128, 128, 256, 256, 512, 512, 512, 512]))

In [None]:
plt.figure(figsize=(10,5))
plt.bar(np.linspace(0,30,10)-0.5, rms_ortho, label="Decorrelated network")
plt.bar(np.linspace(0,30,10)+0.5, rms_base, label="Correlated network")
plt.xlabel("Layer ID")
plt.ylabel("RMS")
plt.legend()
plt.savefig("./w_decorr/loss_mats/rms.png")

## Subplots

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(20,5))

s = imp_decorr_bn.cpu().sort()[0].cpu().numpy()
order = imp_decorr_bn.sort()[1].cpu().numpy()
axes[0].plot(s/s.max(), label="Estimated importance")
axes[0].set_title("Decorrelated Network (layer "+str(l_index)+")")
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
axes[0].set_xlabel("Neuron index")
axes[0].set_ylabel("Normalized importance")
axes[0].plot(loss_diff[order]/loss_diff.max(), label="Actual importance")
axes[0].legend()

s = imp_corr_bn.cpu().sort()[0].cpu().numpy()
order = imp_corr_bn.sort()[1].cpu().numpy()
axes[1].plot(s/s.max(), label="Estimated importance")
axes[1].set_title("Correlated Network (layer "+str(l_index)+")")
loss_diff = (loss_mat_corr - loss_base_corr).abs()
axes[1].set_xlabel("Neuron index")
axes[1].set_ylabel("Normalized importance")
axes[1].plot(loss_diff[order]/loss_diff.max(), label="Actual importance")
axes[1].legend()

plt.savefig("./w_decorr/loss_mats/subplots/"+str(l_index)+".png")

# Other metrics

### Net-Slim Train

In [None]:
scale_corr = net.features[l_index].weight.data.clone()
np.corrcoef(scale_corr.cpu().numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().numpy())

In [None]:
scale_decorr = net_decorr.features[l_index].weight.data.clone().abs()
np.corrcoef((scale_decorr).cpu().numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().numpy())

### L2 based pruning Train

In [None]:
w_corr = net.features[l_index - 2].weight.data.clone()
w_imp_corr = w_corr.pow(2).sum(dim=(1,2,3)).cpu()
np.corrcoef(w_imp_corr.numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().numpy())

In [None]:
w_decorr = net_decorr.features[l_index - 2].weight.data.clone()
w_imp_decorr = w_decorr.pow(2).sum(dim=(1,2,3)).cpu()
w_imp_decorr = (w_imp_decorr - w_imp_decorr.min())
w_imp_decorr = w_imp_decorr/w_imp_decorr.max()
np.corrcoef(w_imp_decorr.numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().numpy())

#### Importance plots Netslim Train

In [None]:
figure(figsize=(20,5))

s = scale_corr.cpu().sort()[0].cpu().numpy()
order = scale_corr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Correlated (Net-Slim)")
loss_diff = (loss_mat_corr - loss_base_corr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

In [None]:
figure(figsize=(20,5))

s = scale_decorr.cpu().sort()[0].cpu().numpy()
order = scale_decorr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Decorrelated (Net-Slim)")
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

#### Importance plots L2 train

In [None]:
figure(figsize=(20,5))
s = w_imp_corr.sort()[0].cpu().numpy()
order = w_imp_corr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Correlated (L2)")
loss_diff = (loss_mat_corr - loss_base_corr).abs()
plt.plot(loss_diff[order]/loss_diff.max())

In [None]:
figure(figsize=(20,5))
s = w_imp_decorr.sort()[0].cpu().numpy()
order = w_imp_decorr.sort()[1].cpu().numpy()
plt.plot(s/s.max())
plt.title("Decorrelated (L2)")
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
plt.plot(loss_diff[order]/loss_diff.max())