In [None]:
# Baseline latencies used as denominators of the speed-up figures printed further down
# (1 - t / (N * t_corr) and 1 - t / (N * t_decorr)).
# NOTE(review): presumably seconds per forward pass of the unpruned correlated /
# decorrelated networks, measured on the author's hardware — confirm before reuse.
t_corr = 0.0002101660919189453
t_decorr = 0.0002142816114425659

In [None]:
# -*- coding: utf-8 -*-
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import numpy as np
import torch.optim as optim
from imp_baselines import *

In [None]:
from ptflops import get_model_complexity_info

In [None]:
# Preprocessing shared by both splits: convert to tensor, then normalise per channel.
# NOTE(review): these mean/std values look like the commonly quoted CIFAR-10 statistics,
# while the dataset below is CIFAR-100 — confirm they match how the checkpoints were trained.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

# download=False: the dataset must already be present under ./data.
trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                        download=False, transform=transform)
# shuffle=False: training batches are served in the same fixed order on every pass.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

#classes = ('plane', 'car', 'bird', 'cat',
#          'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn


# Layer layouts: an int is the output-channel count of a 3x3 conv block, 'M' is a
# 2x2 max-pool. Note that the 'VGG11' entry here is a reduced variant with six conv
# blocks and three pools.
cfg = {
    'VGG11': [32, 32, 'M', 64, 64, 'M', 128, 128, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    """VGG-style classifier: conv/ReLU/BatchNorm blocks followed by one linear layer.

    Args:
        vgg_name: key into the module-level ``cfg`` dictionary.
        num_classes: number of output logits (default 100).
        input_size: height/width of the square input images (default 32). It is
            only used to size the classifier.

    The classifier width is derived from the layout instead of being fixed at 512,
    so layouts with fewer than five pools (such as this file's 'VGG11') also work.
    For 'VGG13'/'VGG16'/'VGG19' on 32x32 inputs it is still 512.
    """

    def __init__(self, vgg_name, num_classes=100, input_size=32):
        super(VGG, self).__init__()
        layer_cfg = cfg[vgg_name]
        self.features = self._make_layers(layer_cfg)
        self.classifier = nn.Linear(self._flat_features(layer_cfg, input_size), num_classes)

    def forward(self, x):
        """Run the feature extractor, flatten per sample, and apply the classifier."""
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    @staticmethod
    def _flat_features(layer_cfg, input_size):
        """Return the flattened feature count produced by ``layer_cfg``.

        The convs use kernel 3 / padding 1 and keep the spatial size; every 'M'
        (kernel 2, stride 2) halves it with floor division.
        """
        spatial = input_size
        channels = 3
        for x in layer_cfg:
            if x == 'M':
                spatial //= 2
            else:
                channels = x
        return channels * spatial * spatial

    def _make_layers(self, cfg):
        """Build the ``nn.Sequential`` feature extractor described by ``cfg``.

        Each int becomes Conv2d(3x3, padding 1) -> ReLU -> BatchNorm2d, in that
        order; each 'M' becomes MaxPool2d(2, 2). The input has 3 channels.
        """
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.ReLU(inplace=True),
                           nn.BatchNorm2d(x)]
                in_channels = x
        return nn.Sequential(*layers)


def test():
    """Smoke-test ``VGG``: push a random 2-image batch through it and print the output size.

    Uses 'VGG13', the layout used throughout this notebook. The previous 'VGG11'
    choice produces 2048 flattened features here, which does not fit a 512-input
    classifier.
    """
    net = VGG('VGG13')
    x = torch.randn(2,3,32,32)
    y = net(x)
    print(y.size())

In [None]:
class VGG_p(nn.Module):
    """VGG-style classifier built from an explicit (typically pruned) layer layout.

    Args:
        vgg_name: label for the network; accepted for call compatibility and not
            used to look anything up.
        cfg: layer layout; an int is a conv block's output-channel count, 'M' a
            2x2 max-pool.
        num_classes: number of output logits (default 100).
        input_size: height/width of the square input images (default 32). It is
            only used to size the classifier.

    The classifier width is derived from the layout. For layouts ending in
    ``[..., C, 'M']`` with five pools on 32x32 inputs this equals ``cfg[-2]``, as
    before; other layouts no longer mis-size the classifier.
    """

    def __init__(self, vgg_name, cfg, num_classes=100, input_size=32):
        super(VGG_p, self).__init__()
        self.features = self._make_layers(cfg)
        self.classifier = nn.Linear(self._flat_features(cfg, input_size), num_classes)

    def forward(self, x):
        """Run the feature extractor, flatten per sample, and apply the classifier."""
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    @staticmethod
    def _flat_features(cfg, input_size):
        """Return the flattened feature count produced by ``cfg``.

        Convs (kernel 3, padding 1) keep the spatial size; each 'M' halves it.
        """
        spatial = input_size
        channels = 3
        for x in cfg:
            if x == 'M':
                spatial //= 2
            else:
                channels = x
        return channels * spatial * spatial

    def _make_layers(self, cfg):
        """Build the feature extractor: Conv2d -> ReLU -> BatchNorm2d per int, MaxPool2d per 'M'."""
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.ReLU(inplace=True),
                           nn.BatchNorm2d(x)]
                in_channels = x
        return nn.Sequential(*layers)


def test():
    """Smoke-test ``VGG_p``: build it from the 'VGG13' layout and print the output size.

    This cell previously re-declared the ``VGG`` smoke test verbatim, which shadowed
    the earlier ``test`` and never exercised ``VGG_p``.
    """
    net = VGG_p('VGG13_p', cfg['VGG13'])
    x = torch.randn(2,3,32,32)
    y = net(x)
    print(y.size())

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Two VGG13 instances with identical architecture; their weights are loaded from
# separate checkpoints in the next cell.
net_corr = VGG('VGG13').to(device)
net_decorr = VGG('VGG13').to(device)
# Shared classification loss used by the importance and retraining code below.
criterion = nn.CrossEntropyLoss()

In [None]:
# Checkpoints of the two baseline networks.
PATH_corr = './cifar100_net.pth'
PATH_decorr = './w_decorr/base_params/wnet_base_2.pth'
# PATH_decorr = './tempnet1.pth'

# map_location keeps the load working on the CPU fallback chosen for `device`,
# even when the checkpoint was saved from a CUDA run.
net_corr.load_state_dict(torch.load(PATH_corr, map_location=device))
net_decorr.load_state_dict(torch.load(PATH_decorr, map_location=device))

In [None]:
def cal_acc(net_acc):
    """Return the top-1 accuracy, in percent, of ``net_acc`` over ``testloader``.

    The model is used as passed in: this function does not switch it to eval mode,
    so callers hand over ``net.eval()``. Gradients are disabled for the whole pass.
    """
    n_seen = 0
    n_right = 0
    with torch.no_grad():
        for batch in testloader:
            images = batch[0].to(device)
            labels = batch[1].to(device)
            # Index of the largest logit per sample.
            predicted = torch.max(net_acc(images).data, 1)[1]
            n_seen += labels.size(0)
            n_right += (predicted == labels).sum().item()

    return (100 * n_right / n_seen)

In [None]:
def cal_importance(net, l_index):
    """Accumulate a first-order Taylor importance score per channel of one layer.

    For every mini-batch of ``trainloader`` the loss is back-propagated and
    ``(weight.grad * weight + bias.grad * bias).abs() ** 2`` of
    ``net.features[l_index]`` is added to a running per-channel total.

    Args:
        net: model exposing an indexable ``features`` container.
        l_index: index into ``net.features`` of a layer whose ``weight`` and ``bias``
            are both 1-D with one entry per channel (the BatchNorm2d positions of
            this notebook's VGG layout).

    Returns:
        A two-element list ``[indices, scores]``: ``indices`` is a float NumPy array
        ``0 .. n-1`` and ``scores`` is a 1-D tensor on ``device`` with the accumulated
        importance of each channel.

    Gradients are cleared with ``net.zero_grad()``, so the result does not depend on
    a global ``optimizer`` happening to wrap the same network. No weights are updated.
    """
    layer = net.features[l_index]
    imp_corr_bn = torch.zeros(layer.bias.shape[0]).to(device)

    for data in trainloader:
        inputs, labels = data[0].to(device), data[1].to(device)
        # Drop gradients from the previous batch so each term is a per-batch gradient.
        net.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        imp_corr_bn += ((layer.weight.grad * layer.weight.data)
                        + (layer.bias.grad * layer.bias.data)).abs().pow(2)

    n_channels = imp_corr_bn.shape[0]
    neuron_order = [np.arange(n_channels, dtype=float), imp_corr_bn]

    return neuron_order

In [None]:
import time

In [None]:
def cal_time(net_acc, n_warmup=5, n_runs=5):
    """Return the mean wall-clock time, in seconds, of one single-image forward pass.

    Args:
        net_acc: the model to time; it is called on a random ``(1, 3, 32, 32)`` input.
        n_warmup: untimed forward passes run first (default 5).
        n_runs: timed forward passes that are averaged (default 5).

    Returns:
        Elapsed time of the timed passes divided by ``n_runs``, as a float.

    Fixes over the previous version: the accumulator is initialised (it used to raise
    UnboundLocalError), elapsed time is measured once around the timed loop instead of
    re-adding a growing interval, and CUDA work is synchronised before reading the clock.
    The model's train/eval mode is left untouched.
    """
    # Put the sample on the model's own device; fall back to the notebook-wide device
    # for parameter-free callables.
    first_param = next(net_acc.parameters(), None) if hasattr(net_acc, "parameters") else None
    sample_device = first_param.device if first_param is not None else device
    testsamp = torch.rand(1, 3, 32, 32).to(sample_device)

    def _sync():
        # CUDA kernels run asynchronously; wait for them so the clock sees finished work.
        if testsamp.is_cuda:
            torch.cuda.synchronize(testsamp.device)

    with torch.no_grad():
        for _ in range(n_warmup):
            net_acc(testsamp)
        _sync()

        started = time.perf_counter()
        for _ in range(n_runs):
            net_acc(testsamp)
        _sync()
        elapsed = time.perf_counter() - started

    return (elapsed / n_runs)

In [None]:
def cal_mass(net, l_index, max_batches=39):
    """Average diagonal-mass ratio of the channel Gram matrices after ``features[0:l_index]``.

    For each batch, the activations of ``net.features[0:l_index]`` are flattened per
    channel and L2-normalised, the per-sample channel-by-channel Gram matrix is formed,
    and ``1 - ||off-diagonal||^2 / ||full||^2`` (norms taken over the whole batch) is
    accumulated.

    Args:
        net: model exposing a sliceable ``features`` container.
        l_index: exclusive end index of the feature slice to evaluate.
        max_batches: number of ``trainloader`` batches to use (default 39, the number
            the previous version actually processed).

    Returns:
        The mean ratio over the processed batches, as a 0-d tensor.

    Raises:
        ValueError: if no batch was processed.

    The previous version stopped after 39 batches but divided by 40; the divisor is
    now the number of batches actually accumulated.
    """
    processed = 0
    r = 0.0
    with torch.no_grad():
        for data in trainloader:
            if processed == max_batches:
                break
            inputs = data[0].to(device)

            out_features = net.features[0:l_index](inputs)
            n_samples, n_channels = out_features.shape[0], out_features.shape[1]

            # One unit-norm row per channel (epsilon guards all-zero channels).
            X_t = out_features.reshape(n_samples, n_channels, -1)
            X_t = torch.div(X_t, X_t.norm(dim=2).reshape(n_samples, n_channels, 1) + 1e-10)

            cov_mat = torch.matmul(X_t, X_t.permute(0, 2, 1))
            L_mat = cov_mat.norm().pow(2)

            # Zero the diagonal to keep only cross-channel terms.
            off_diag = 1 - torch.eye(n_channels, device=cov_mat.device)
            L_self = (cov_mat * off_diag).norm().pow(2)

            r += 1 - L_self / L_mat
            processed += 1

    if processed == 0:
        raise ValueError("cal_mass processed no batches; trainloader is empty or max_batches is 0")
    return r / processed

### Ground importance

In [None]:
# loss_base_corr = 0
# num_stop = 0

# for epoch in range(1):
#     for i, data in enumerate(trainloader, 0):
#         inputs, labels = data[0].to(device), data[1].to(device)
#         outputs = net_corr(inputs)
#         loss = criterion(outputs, labels)
#         loss_base_corr += loss.item()
#         num_stop += labels.shape[0]
#         if(num_stop > 5000):
#             break

In [None]:
# imp_order_ground = {}
# for l_index in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]:
#     loss_mat = torch.load("./w_decorr/loss_corr_bn_train_"+str(l_index)+".pt")
#     imp_order_ground.update({l_index: ((loss_mat - loss_base_corr).abs().sort()[1])})#.sort()[0]})

In [None]:
# loss_base_decorr = 0
# num_stop = 0

# for epoch in range(1):
#     for i, data in enumerate(trainloader, 0):
#         inputs, labels = data[0].to(device), data[1].to(device)
#         outputs = net_decorr(inputs)
#         loss = criterion(outputs, labels)
#         loss_base_decorr += loss.item()
#         num_stop += labels.shape[0]
#         if(num_stop > 5000):
#             break

In [None]:
# imp_order_ground_decorr = {}
# for l_index in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]:
#     loss_mat = torch.load("./w_decorr/loss_bn_train_"+str(l_index)+".pt")
#     imp_order_ground_decorr.update({l_index: ((loss_mat - loss_base_decorr).abs().sort()[1])})#.sort()[0]})

### TFO importance

In [None]:
import pickle

In [None]:
# Load a previously saved importance table for the correlated network.
# NOTE: pickle.load executes arbitrary code from the file; only load files you created.
# NOTE(review): the following cell recomputes and rebinds imp_order_corr, so running
# both makes this load redundant — confirm which one is intended.
with open("./w_decorr/base_params/tfo_corr.pkl", 'rb') as f:
    imp_order_corr = pickle.load(f)

In [None]:
# Importance table for the correlated network: one row [layer index, channel index,
# importance score] per channel of every scored layer.
# lr=0 / weight_decay=0: this optimizer is only a handle for clearing gradients.
optimizer = optim.SGD(net_corr.parameters(), lr=0, weight_decay=0)
layer_rows = []
for l_index in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]:
    print(l_index)
    nlist = cal_importance(net_corr, l_index)
    scores = nlist[1].detach().cpu().numpy()
    layer_rows.append(np.column_stack((np.full(scores.shape[0], l_index), nlist[0], scores)))
# Concatenate once instead of growing the array inside the loop.
imp_order_corr = np.concatenate(layer_rows, 0)
    
# with open("./w_decorr/base_params/tfo_corr.pkl", 'wb') as f:
#     pickle.dump(imp_order_corr, f)

In [None]:
# Importance table for the decorrelated network: one row [layer index, channel index,
# importance score] per channel of every scored layer.
# lr=0 / weight_decay=0: this optimizer is only a handle for clearing gradients.
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)
layer_rows = []
for l_index in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]:
    print(l_index)
    nlist = cal_importance(net_decorr, l_index)
    scores = nlist[1].detach().cpu().numpy()
    layer_rows.append(np.column_stack((np.full(scores.shape[0], l_index), nlist[0], scores)))
# Concatenate once instead of growing the array inside the loop.
imp_order_decorr = np.concatenate(layer_rows, 0)
    
# with open("./w_decorr/base_params/tfo_w_decorr_temp.pkl", 'wb') as f:
#     pickle.dump(imp_order_decorr, f)

In [None]:
# Load a saved importance table for the decorrelated network; this rebinds
# imp_order_decorr, replacing whatever the previous cell computed.
# NOTE: pickle.load executes arbitrary code from the file; only load files you created.
with open("./w_decorr/base_params/tfo_w_decorr_temp.pkl", 'rb') as f:
    imp_order_decorr = pickle.load(f)

In [None]:
def order_and_ratios(imp_order, prune_ratio):
    """Rank channels by importance and count how many to prune from each layer.

    Args:
        imp_order: 2-D array whose columns are read as [layer index, channel index,
            importance score].
        prune_ratio: fraction of all rows to prune; the ``int(prune_ratio * rows)``
            lowest-scoring rows are selected.

    Returns:
        ``(imp_order_tfo, ratios)``: a dict mapping each layer index to its channel
        indices (ints) in ascending order of importance, and a list with the number of
        selected (to-be-pruned) channels per layer, in layer order.
    """
    layer_ids = [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]

    # All rows, least important first, and the globally weakest slice of them.
    ranked = imp_order[np.argsort(imp_order[:, 2])]
    doomed = ranked[0:int(prune_ratio * imp_order.shape[0])]

    imp_order_tfo = {
        layer: ranked[(ranked[:, 0] == layer), 1].astype(int)
        for layer in layer_ids
    }
    ratios = [doomed[(doomed[:, 0] == layer), 1].shape[0] for layer in layer_ids]
    return imp_order_tfo, ratios

### Taylor FO importance vs. stored per-channel loss values

In [None]:
# Copy of the correlated importance table whose score column is replaced, layer by
# layer, with values loaded from disk.
# NOTE(review): presumably the stored tensors hold one measured loss value per channel
# in channel-index order and live on the CPU — confirm against the script that wrote them.
temp_corr = imp_order_corr.copy()
for l_index in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]:
    temp_corr[temp_corr[:,0]==l_index, 2] = torch.load("./w_decorr/loss_mats/corr/"+str(l_index)+"/loss_corr_bn_train_"+str(l_index)+".pt")

In [None]:
# Same as the previous cell, for the decorrelated network: replace the score column
# with per-layer values loaded from disk.
# NOTE(review): presumably one measured loss value per channel, stored on the CPU — confirm.
temp_decorr = imp_order_decorr.copy()
for l_index in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]:
    temp_decorr[temp_decorr[:,0]==l_index, 2] = torch.load("./w_decorr/loss_mats/decorr/"+str(l_index)+"/loss_decorr_bn_train_"+str(l_index)+".pt")

In [None]:
# Plot the sorted, max-normalised importance scores of the correlated network, and on
# the same axes the stored per-channel values re-ordered by that same ranking.
figure(figsize=(20,5))
s = torch.tensor(imp_order_corr[:,2])
vals_corr, order = s.sort()
plt.plot(vals_corr/vals_corr.max())
plt.title("Correlated (Taylor FO)")
loss_corr_curve = torch.tensor(temp_corr[order,2]/temp_corr[:,2].max())
plt.plot(loss_corr_curve)

In [None]:
(vals_corr / vals_corr.max() - loss_corr_curve).norm()

In [None]:
# Plot the sorted, max-normalised importance scores of the decorrelated network, and on
# the same axes the stored per-channel values re-ordered by that same ranking.
figure(figsize=(20,5))
s = torch.tensor(imp_order_decorr[:,2])
vals_decorr, order = s.sort()
plt.plot(vals_decorr/vals_decorr.max())
plt.title("Decorrelated (Taylor FO)")
loss_decorr_curve = torch.tensor(temp_decorr[order,2]/temp_decorr[:,2].max())
plt.plot(loss_decorr_curve)

In [None]:
(vals_decorr / vals_decorr.max() - loss_decorr_curve).norm()

### Pruning

In [None]:
def cfg_p(prune_ratio, orig_size, save_cfg_corr=0, save_cfg=0):
    """Build the layer layout of a pruned ten-conv network and optionally pickle it.

    Args:
        prune_ratio: per-layer number of channels to remove (ten entries are read).
        orig_size: per-layer current channel counts (ten entries are read).
        save_cfg_corr: accepted but not used.
        save_cfg: 1 pickles the layout under the ``corr`` cfgs directory, 2 under the
            ``decorr`` one; any other value writes nothing. The file name embeds the
            global ``prune_iter``.

    Returns:
        A list of five ``[channels, channels, 'M']`` groups, where each channel count
        is ``orig_size[k] - prune_ratio[k]``.
    """
    cfg_list = []
    for group in range(5):
        first, second = 2 * group, 2 * group + 1
        cfg_list.append(orig_size[first] - prune_ratio[first])
        cfg_list.append(orig_size[second] - prune_ratio[second])
        cfg_list.append('M')

    # Pick the destination (if any) first, then write in one place.
    dump_path = None
    if save_cfg == 1:
        dump_path = "./w_decorr/pruned_nets/corr/cfgs/net_p_corr_iter"+str(prune_iter)+".pkl"
    elif save_cfg == 2:
        dump_path = "./w_decorr/pruned_nets/decorr/cfgs/net_p_decorr_iter"+str(prune_iter)+".pkl"

    if dump_path is not None:
        with open(dump_path, 'wb') as f:
            pickle.dump(cfg_list, f)

    return cfg_list

In [None]:
def pruner(net, imp_order, prune_ratio, orig_size, net_type=0):
    """Build a smaller ``VGG_p`` by copying the surviving channels out of ``net``.

    Args:
        net: source network with the conv/ReLU/BatchNorm layout used in this notebook.
        imp_order: dict mapping each BatchNorm index to that layer's channel indices in
            ascending order of importance.
        prune_ratio: per-layer number of channels to drop; the first ``prune_ratio[k]``
            entries of the layer's ordering are discarded.
        orig_size: per-layer channel counts of ``net``, forwarded to ``cfg_p``.
        net_type: 1 or 2 makes ``cfg_p`` also pickle the layout (corr / decorr
            directory); any other value only builds it.

    Returns:
        The pruned network on ``device``, with conv weights/biases, BatchNorm
        weights/biases/running statistics and the classifier sliced from ``net``.
    """
    save_flag = 1 if net_type == 1 else (2 if net_type == 2 else 0)
    layout = cfg_p(prune_ratio, orig_size, save_cfg=save_flag)

    net_pruned = VGG_p('VGG13_p', layout).to(device)
    bn_layers = [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]

    def survivors(pos):
        """Sorted channel indices kept in the pos-th scored layer."""
        return np.sort(imp_order[bn_layers[pos]][prune_ratio[pos]:])

    for pos, bn_idx in enumerate(bn_layers):
        # Each BatchNorm sits two positions after the conv that feeds it.
        src_conv, dst_conv = net.features[bn_idx - 2], net_pruned.features[bn_idx - 2]
        src_bn, dst_bn = net.features[bn_idx], net_pruned.features[bn_idx]
        keep_out = survivors(pos)

        conv_weight = src_conv.weight[keep_out]
        if pos > 0:
            # Input channels follow whatever survived in the previous layer;
            # the first conv keeps all of its image input channels.
            conv_weight = conv_weight[:, survivors(pos - 1)]
        dst_conv.weight.data = conv_weight.detach().clone()
        dst_conv.bias.data = src_conv.bias[keep_out].detach().clone()

        dst_bn.weight.data = src_bn.weight[keep_out].detach().clone()
        dst_bn.bias.data = src_bn.bias[keep_out].detach().clone()
        dst_bn.running_var.data = src_bn.running_var[keep_out].detach().clone()
        dst_bn.running_mean.data = src_bn.running_mean[keep_out].detach().clone()

    # The classifier keeps every output row and only the columns of surviving
    # channels of the last scored layer.
    keep_last = np.sort(imp_order[33][prune_ratio[-1]:])
    net_pruned.classifier.weight.data = net.classifier.weight[:, keep_last].detach().clone()
    net_pruned.classifier.bias.data = net.classifier.bias.detach().clone()

    return net_pruned

## Retraining

In [None]:
# Index of the current pruning round; it is embedded in the file names of the saved
# layouts, checkpoints and importance tables.
prune_iter = 1

### Computational importance

In [None]:
# Analytic per-layer cost weight, one entry per conv layer, normalised to sum to 1.
c_imp = []

# layer_index is the position just past each conv/ReLU/BatchNorm triple, so
# features[0:layer_index] ends with that block and features[layer_index-3] is its conv.
for layer_index in [3, 6, 10, 13, 17, 20, 24, 27, 31, 34]:
    
    # Spatial size of the block's output for a 32x32 input.
    _, _, w_in, h_in = net_corr.features[0:layer_index](torch.zeros(1,3,32,32).to(device)).shape
    
    c_out, c_in, w_f, h_f = net_corr.features[layer_index-3].weight.shape
    
    # (inputs per output value) * (output positions) * (output channels) * (conv weight count)
    c_imp.append((c_in*w_f*h_f)*(w_in*h_in)*c_out*(c_out*(c_in*w_f*h_f)))
    
c_imp = np.array(c_imp)
c_imp = c_imp/c_imp.sum()

In [None]:
c_imp

### Correlated network pruning

In [None]:
# Output-channel count of each of the ten conv layers of the correlated network.
orig_size = np.array([net_corr.features[conv_idx].weight.shape[0] for conv_idx in [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]])

#### Pruning order

In [None]:
# Work on a copy so the raw importance table stays untouched, then add a per-layer
# offset of c_imp ** 6 to every channel score of that layer.
imp_order_temp = np.copy(imp_order_corr)
for i, l_index in enumerate([2, 5, 9, 12, 16, 19, 23, 26, 30, 33]):
    imp_order_temp[imp_order_corr[:,0] == l_index,2] += (c_imp[i])**(6)

In [None]:
(c_imp)**6

In [None]:
# Rank the cost-adjusted scores and select the weakest 10% of all channels; prune_ratio
# then holds the number of channels to remove from each layer.
# order_corr, prune_ratio = order_and_ratios(imp_order_corr, 0.4)
order_corr, prune_ratio = order_and_ratios(imp_order_temp, 0.1)

In [None]:
prune_ratio, orig_size

#### Define pruned network

In [None]:
# Build the pruned correlated network. With this cell commented out, the following
# cells used `net_p` before it was ever defined on a fresh kernel.
# net_type=1 also makes cfg_p() pickle the layer layout under
# ./w_decorr/pruned_nets/corr/cfgs/ for the current prune_iter.
net_corr.load_state_dict(torch.load(PATH_corr, map_location=device))
net_p = pruner(net_corr, order_corr, prune_ratio, orig_size, net_type=1)

In [None]:
# Test accuracy (percent) of the unpruned and the pruned network. Note that .eval()
# switches both models to eval mode in place, and they stay that way afterwards.
print(cal_acc(net_corr.eval()), cal_acc(net_p.eval()))

In [None]:
# Relative speed-up of the pruned network: total of 100 cal_time() measurements against
# 100 times the stored baseline t_corr.
t = sum(cal_time(net_p) for _ in range(100))
print(1 - t/(100*t_corr))

#### Load correlated pruned network

In [None]:
# Restore previously saved weights into the pruned network; net_p must already have
# the matching layer layout.
# PATH = './w_decorr/pruned_nets/corr/nets/net_p_iter'+str(prune_iter)+'.pth'
PATH = './w_decorr/pruned_nets/net_temp1.pth'
# map_location keeps the load working on the CPU fallback chosen for `device`.
net_p.load_state_dict(torch.load(PATH, map_location=device))

#### Retraining

In [None]:
import torch.optim as optim

# Loss and optimizer for fine-tuning the pruned correlated network: Adam with a very
# small learning rate (1e-6) and L2 weight decay of 5e-4. This rebinds the global
# `criterion` and `optimizer`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net_p.parameters(), lr=0.000001, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)

In [None]:
# Fine-tune the pruned correlated network for five passes over the training set,
# printing the mean training loss and the test accuracy after every pass.
for epoch in range(5):
    # cal_acc(net_p.eval()) below (and in earlier cells) leaves the model in eval mode.
    # Switch back so BatchNorm uses and updates batch statistics while training.
    net_p.train()
    running_loss = 0.0
    num_iter = 0
    for i, data in enumerate(trainloader, 0):
        num_iter += 1
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()
        outputs = net_p(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, i + 1, running_loss / num_iter))
    print(cal_acc(net_p.eval()))
print('Finished Training')

#### Save correlated pruned network

In [None]:
prune_iter

In [None]:
# Save the fine-tuned pruned correlated network; the file name carries prune_iter, so
# saving again within the same round overwrites the earlier file.
PATH = './w_decorr/pruned_nets/corr/nets/net_p_iter'+str(prune_iter)+'.pth'
# PATH = './w_decorr/pruned_nets/net_temp1.pth'
torch.save(net_p.state_dict(), PATH)

### Decorrelated network pruning

In [None]:
# Output-channel count of each of the ten conv layers of the decorrelated network.
orig_size = np.array([net_decorr.features[conv_idx].weight.shape[0] for conv_idx in [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]])

#### Pruning order

In [None]:
# Work on a copy so the raw importance table stays untouched, then add a per-layer
# offset of c_imp ** 6 to every channel score of that layer.
imp_order_temp = np.copy(imp_order_decorr)
for i, l_index in enumerate([2, 5, 9, 12, 16, 19, 23, 26, 30, 33]):
    imp_order_temp[imp_order_decorr[:,0] == l_index,2] += (c_imp[i])**(6)

In [None]:
(c_imp)**6

In [None]:
# Rank the cost-adjusted scores and select the weakest 10% of all channels; prune_ratio
# then holds the number of channels to remove from each layer.
# order_decorr, prune_ratio = order_and_ratios(imp_order_decorr, 0.4)
order_decorr, prune_ratio = order_and_ratios(imp_order_temp, 0.1)

In [None]:
prune_ratio, orig_size

#### Define pruned network

In [None]:
# Reload the baseline decorrelated weights, then build the pruned network from them.
# net_type=2 also makes cfg_p() pickle the layer layout under
# ./w_decorr/pruned_nets/decorr/cfgs/ for the current prune_iter.
# map_location keeps the load working on the CPU fallback chosen for `device`.
net_decorr.load_state_dict(torch.load(PATH_decorr, map_location=device))
net_p = pruner(net_decorr, order_decorr, prune_ratio, orig_size, net_type=2)

In [None]:
# Test accuracy (percent) of the unpruned and the pruned decorrelated network;
# .eval() leaves both models in eval mode.
cal_acc(net_decorr.eval()), cal_acc(net_p.eval()) #, cal_acc(net_p1.eval())

In [None]:
# Relative speed-up of the pruned network: total of 5 cal_time() measurements against
# 5 times the stored baseline t_decorr.
t = sum(cal_time(net_p) for _ in range(5))
print(1 - t/(5*t_decorr))

#### Load decorrelated pruned network

In [None]:
# PATH = './w_decorr/pruned_nets/decorr/nets/net_p_iter'+str(prune_iter)+'.pth'
# # PATH = './w_decorr/pruned_nets/net_temp.pth'
# net_p.load_state_dict(torch.load(PATH))

#### Computational Importance

In [None]:
# Measured share of forward time spent in each of the ten feature slices of net_p
# for a single 32x32 input, normalised to sum to 1.
l_imp_p = []

# Consecutive entries delimit the slices features[a:b] that are timed one after another.
l_inds = [0, 3, 6, 10, 13, 17, 20, 24, 27, 31, 34]
out = torch.zeros(1,3,32,32).to(device)
for i in range(len(l_inds)-1):
    time_init = time.time()
    # Each slice consumes the previous slice's output.
    out = net_p.features[l_inds[i]:l_inds[i+1]](out)
    l_imp_p.append(time.time() - time_init)
    
l_imp_p = np.array(l_imp_p)
l_imp_p = l_imp_p/l_imp_p.sum()

In [None]:
# Map each conv layer index to the measured time share of its slice.
l_impd = {conv_ind: l_imp_p[i] for i, conv_ind in enumerate([0, 3, 7, 10, 14, 17, 21, 24, 28, 31])}

In [None]:
# l_impd[0] = 0 #l_impd[31]

#### Load decorrelated pruned network

In [None]:
# Restore the saved pruned decorrelated network for the current round; net_p must
# already have the matching layer layout.
PATH = './w_decorr/pruned_nets/decorr/nets/net_p_iter'+str(prune_iter)+'.pth'
# PATH = './w_decorr/pruned_nets/net_temp.pth'
# map_location keeps the load working on the CPU fallback chosen for `device`.
net_p.load_state_dict(torch.load(PATH, map_location=device))

#### Retraining

In [None]:
import torch.optim as optim

# Loss and optimizer for fine-tuning the pruned decorrelated network: Adam with
# learning rate 1e-5 and no weight decay. This rebinds the global `criterion` and
# `optimizer`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net_p.parameters(), lr=0.00001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

In [None]:
# Fine-tune the pruned decorrelated network with cross-entropy plus an orthogonality
# penalty on every conv layer's [weights | bias] matrix.
for epoch in range(3):  # loop over the dataset multiple times
    # cal_acc(net_p.eval()) below (and in earlier cells) leaves the model in eval mode.
    # Switch back so BatchNorm uses and updates batch statistics while training.
    net_p.train()
    running_loss = 0.0
    num_iter = 0
    angle_cost = 0.0
    for i, data in enumerate(trainloader, 0):
        num_iter += 1
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = net_p(inputs)
        L_angle = 0

        for conv_ind in [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]:
            w_mat = net_p.features[conv_ind].weight
            w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))

            b_mat = net_p.features[conv_ind].bias
            b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))

            # One row per filter: flattened kernel followed by its bias.
            params = torch.cat((w_mat1, b_mat1), dim=1)

            if conv_ind == 0:
                # First conv: Gram matrix over columns (params^T params).
                gram = torch.matmul(torch.t(params), params)
            else:
                # Other convs: Gram matrix over filters (params params^T).
                gram = torch.matmul(params, torch.t(params))

            angle_mat = gram - torch.eye(gram.shape[0], device=gram.device)

            # Sum of absolute deviations from the identity.
            L_angle += (angle_mat).norm(1)

        Lc = criterion(outputs, labels)
        loss = (1e-1)*(L_angle) + Lc

        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        angle_cost += (L_angle).item()

    print("angle_cost: ", angle_cost/num_iter)
    # Constant over the mean per-batch penalty. The previous version divided the
    # num_iter-scaled constant by the last batch's penalty only, which inflated the
    # value by roughly num_iter.
    print("diag_mass_ratio: ", (num_iter*(64+128+256+1024)*2)/angle_cost)
    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, i + 1, running_loss / num_iter))
    print(cal_acc(net_p.eval()))

print('Finished Training')

#### Save decorrelated pruned network

In [None]:
prune_iter 

In [None]:
# Save the fine-tuned pruned decorrelated network; the file name carries prune_iter,
# so saving again within the same round overwrites the earlier file.
PATH = './w_decorr/pruned_nets/decorr/nets/net_p_iter'+str(prune_iter)+'.pth'
# PATH = './w_decorr/pruned_nets/net_temp.pth'
torch.save(net_p.state_dict(), PATH)

#### Evaluate orthogonality of filters in pruned network

In [None]:
# For every conv layer of net_p, print the share of the Gram matrix's absolute mass
# that lies on its diagonal (1.0 would mean no off-diagonal mass at all).
for conv_ind in [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]:
    w_mat = net_p.features[conv_ind].weight
    w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))

    b_mat = net_p.features[conv_ind].bias
    b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))

    # One row per filter: flattened kernel followed by its bias.
    params = torch.cat((w_mat1, b_mat1), dim=1)

    if conv_ind == 0:
        # First conv: Gram matrix over columns.
        angle_mat = torch.matmul(torch.t(params), params)
    else:
        # Other convs: Gram matrix over filters.
        angle_mat = torch.matmul(params, torch.t(params))

    L_diag = (angle_mat.diag().norm(1))
    L_angle = (angle_mat.norm(1))

    print(L_diag.cpu()/L_angle.cpu())

### Subsequent pruning

#### Importance

In [None]:
# optimizer = optim.SGD(net_p.parameters(), lr=0, weight_decay=0)
# imp_order_p = np.array([[],[],[]]).transpose()
# i = 0
# for l_index in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]:
#     print(l_index)
#     nlist = cal_importance(net_p, l_index)
#     imp_order_p = np.concatenate((imp_order_p,np.array([np.repeat([l_index],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
#     i+=1
    
# with open("./w_decorr/pruned_nets/corr/tfo_order/tfo_corr_p"+str(prune_iter)+".pkl", 'wb') as f:
#     pickle.dump(imp_order_tfo_p, f)

In [None]:
# Importance table for the current pruned network: one row [layer index, channel
# index, importance score] per channel of every scored layer.
# lr=0 / weight_decay=0: this optimizer is only a handle for clearing gradients.
optimizer = optim.SGD(net_p.parameters(), lr=0, weight_decay=0)
layer_rows = []
for l_index in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]:
    print(l_index)
    nlist = cal_importance(net_p, l_index)
    scores = nlist[1].detach().cpu().numpy()
    layer_rows.append(np.column_stack((np.full(scores.shape[0], l_index), nlist[0], scores)))
# Concatenate once instead of growing the array inside the loop.
imp_order_p = np.concatenate(layer_rows, 0)

# Persist the table for this pruning round (decorrelated variant).
with open("./w_decorr/pruned_nets/decorr/tfo_order/tfo_w_decorr_p"+str(prune_iter)+".pkl", 'wb') as f:
    pickle.dump(imp_order_p, f)

# with open("./w_decorr/pruned_nets/corr/tfo_order/tfo_corr_p"+str(prune_iter)+".pkl", 'wb') as f:
#     pickle.dump(imp_order_p, f)

In [None]:
### De-Correlated network loading ###
# with open("./w_decorr/pruned_nets/decorr/tfo_order/tfo_w_decorr_p"+str(prune_iter)+".pkl", 'rb') as f:
#     imp_order_tfo_p = pickle.load(f)

### Correlated network loading ###
# with open("./w_decorr/pruned_nets/corr/tfo_order/tfo_corr_p"+str(prune_iter)+".pkl", 'wb') as f:
#     imp_order_tfo_p = pickle.load(f)

#### Computational importance

In [None]:
def t_imp(net):
    """Time each of the ten consecutive feature slices of ``net`` on one zero image.

    The slices are ``net.features[a:b]`` for consecutive boundaries in
    ``[0, 3, 6, 10, 13, 17, 20, 24, 27, 31, 34]``; every slice is fed the previous
    slice's output, starting from a ``(1, 3, 32, 32)`` zero tensor on ``device``.

    Returns:
        NumPy array of ten wall-clock durations in seconds, measured with ``time.time()``.
    """
    bounds = [0, 3, 6, 10, 13, 17, 20, 24, 27, 31, 34]
    out = torch.zeros(1,3,32,32).to(device)
    durations = []
    for start, stop in zip(bounds[:-1], bounds[1:]):
        began = time.time()
        out = net.features[start:stop](out)
        durations.append(time.time() - began)
    return np.array(durations)

In [None]:
# First of ten timing samples; rebinds c_imp from the analytic cost weights to
# measured per-slice times.
c_imp = t_imp(net_p)

In [None]:
# Add nine more timing samples on top of the first one.
# Not idempotent: re-running this cell keeps adding to c_imp.
for i in range(9):
    c_imp += t_imp(net_p)

In [None]:
# Turn the ten-sample total into a mean. Not idempotent: running it twice divides twice.
c_imp = c_imp / 10

#### Pruned network pruning

In [None]:
# Output-channel count of each of the ten conv layers of the current pruned network.
orig_size = np.array([net_p.features[conv_idx].weight.shape[0] for conv_idx in [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]])

#### Pruning order

In [None]:
# Work on a copy of the pruned network's importance table and add a per-layer offset
# of (1 - 100 * c_imp) ** 500, where c_imp is the measured mean time of that slice.
imp_order_temp = np.copy(imp_order_p)
for i, l_index in enumerate([2, 5, 9, 12, 16, 19, 23, 26, 30, 33]):
    imp_order_temp[imp_order_p[:,0] == l_index,2] += (1 - 100*c_imp[i])**(500)

In [None]:
(1-c_imp*100)**500

In [None]:
imp_order_p[:,2].mean()

In [None]:
# Rank the time-adjusted scores of the already-pruned network and select the weakest
# 10% of its channels for the next pruning round.
# order_corr, prune_ratio = order_and_ratios(imp_order_corr, 0.4)
order_p, prune_ratio = order_and_ratios(imp_order_temp, 0.1)

In [None]:
prune_ratio, orig_size

#### Define pruned network

In [None]:
# Prune the already-pruned network once more. net_type=2 makes cfg_p() pickle the new
# layout under the decorr cfgs directory, named after the current prune_iter.
net_p1 = pruner(net_p, order_p, prune_ratio, orig_size, net_type=2)

In [None]:
# Test accuracy (percent) after and before this extra pruning step; .eval() leaves
# both models in eval mode.
print(cal_acc(net_p1.eval()), cal_acc(net_p.eval()))

In [None]:
# Relative speed-up of the twice-pruned network over the unpruned decorrelated one,
# both timed here with five cal_time() measurements each.
t = 0
for i in range(5):
    t += (cal_time(net_p1))

# Kept under its own name: rebinding `t_decorr` (the stored per-pass baseline) to a
# 5-run total would silently corrupt the `1 - t/(5*t_decorr)` cell when it is re-run
# in the next pruning round.
t_decorr_total = 0

for i in range(5):
    t_decorr_total += (cal_time(net_decorr))

print(1 - t/(t_decorr_total))

#### Prune the pruned network again

In [None]:
# Advance the round counter; files saved from here on are named after round 2.
# NOTE(review): hard-coded, so further rounds need this edited by hand — confirm intent.
prune_iter = 2

In [None]:
# Promote the newly pruned network to be the working network. This is a rebind, not a
# copy: net_p and net_p1 now refer to the same module.
net_p = net_p1

In [None]:
### De-Correlated network saving ###
# Writes the weights of the newly pruned network under the current prune_iter.
PATH = './w_decorr/pruned_nets/decorr/nets/net_p_iter'+str(prune_iter)+'.pth'
torch.save(net_p1.state_dict(), PATH)

### Correlated network saving ###
# PATH = './w_decorr/pruned_nets/corr/nets/net_p_iter'+str(prune_iter)+'.pth'
# torch.save(net_p1.state_dict(), PATH)

### Load saved network

In [None]:
# cfg_p1 = [1, 1, 'M', 1, 1, 'M', 1, 1, 'M', 1, 1, 'M', 1, 1, 'M']
# cfg_p1 = []
# for layer_index in [3, 6, 10, 13, 17, 20, 24, 27, 31, 34]:
#     cfg_p.append(net_p.features[layer_index-1].weight.shape[0])

In [None]:
# Load the pickled layer layout of the round-3 pruned decorrelated network.
# NOTE: pickle.load executes arbitrary code from the file; only load files you created.
# NOTE(review): a later cell rebinds cfg_p1 to a literal list, which makes this load
# redundant when both run — confirm which one is intended.
with open("./w_decorr/pruned_nets/decorr/cfgs/net_p_decorr_iter3.pkl", 'rb') as f:
    cfg_p1 = pickle.load(f)
    
# for i in [0, 1, 3, 4, 6, 7, 9, 10, 12, 13]:
#     cfg_p1[i] += 1

In [None]:
# Hard-coded pruned layout (channel counts per conv, 'M' = max-pool); overrides any
# layout loaded from disk in the preceding cell.
# NOTE(review): must match the checkpoint loaded below — confirm it is the round-3 layout.
cfg_p1 = [32, 64, 'M', 91, 102, 'M', 127, 132, 'M', 114, 112, 'M', 132, 213, 'M']

In [None]:
# Fresh, randomly initialised network with the pruned layout; weights are loaded next.
net_p = VGG_p('VGG13_p', cfg_p1).to(device)

In [None]:
# Restore the saved round-3 weights into the freshly built pruned network.
# PATH = './w_decorr/pruned_nets/decorr/nets/net_p_iter'+str(5)+'.pth'
PATH = './w_decorr/pruned_nets/decorr/nets/net_p_iter3.pth'
# map_location keeps the load working on the CPU fallback chosen for `device`.
net_p.load_state_dict(torch.load(PATH, map_location=device))

In [None]:
cal_acc(net_p)

In [None]:
# Test accuracy (percent) of the restored pruned network and the unpruned
# decorrelated network, both in eval mode.
cal_acc(net_p.eval()), cal_acc(net_decorr.eval())

### FLOPS calculator

In [None]:
# Per-layer and total complexity of the pruned network for a 3x32x32 input.
# A negative index makes torch.cuda.device a no-op, so this also runs on the CPU
# fallback instead of failing when CUDA is unavailable.
with torch.cuda.device(0 if torch.cuda.is_available() else -1):
    flops, params = get_model_complexity_info(net_p, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops))

In [None]:
# Per-layer and total complexity of the unpruned decorrelated network for a 3x32x32 input.
# A negative index makes torch.cuda.device a no-op, so this also runs on the CPU
# fallback instead of failing when CUDA is unavailable.
with torch.cuda.device(0 if torch.cuda.is_available() else -1):
    flops, params = get_model_complexity_info(net_decorr, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops))

In [None]:
# Accuracy of the correlated network right after pruning (no retraining) for
# increasing global prune fractions.
accs_corr = []
# Channel counts of the *unpruned* correlated net. The global `orig_size` has by now
# been rebound to an already-pruned network's sizes, which would give wrong (possibly
# negative) layer widths here.
orig_size_corr = np.array([net_corr.features[conv_idx].weight.shape[0] for conv_idx in [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]])
# pruner() only reads from net_corr, so loading the baseline weights once is enough.
net_corr.load_state_dict(torch.load(PATH_corr, map_location=device))
for prunemuch in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]:
    order_corr, prune_ratio = order_and_ratios(imp_order_corr, prunemuch)
    net_p = pruner(net_corr, order_corr, prune_ratio, orig_size_corr)
    accs_corr.append(cal_acc(net_p.eval()))

In [None]:
# Accuracy of the decorrelated network right after pruning (no retraining) for
# increasing global prune fractions.
accs_decorr = []
# Channel counts of the *unpruned* decorrelated net. The global `orig_size` has by now
# been rebound to an already-pruned network's sizes, which would give wrong (possibly
# negative) layer widths here.
orig_size_decorr = np.array([net_decorr.features[conv_idx].weight.shape[0] for conv_idx in [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]])
# pruner() only reads from net_decorr, so loading the baseline weights once is enough.
net_decorr.load_state_dict(torch.load(PATH_decorr, map_location=device))
for prunemuch in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]:
    order_decorr, prune_ratio = order_and_ratios(imp_order_decorr, prunemuch)
    net_p = pruner(net_decorr, order_decorr, prune_ratio, orig_size_decorr)
    accs_decorr.append(cal_acc(net_p.eval()))

In [None]:
# Post-pruning accuracy of the correlated (first line) and decorrelated (second line)
# network; the x axis is the position in the prune-fraction list, not the fraction itself.
plt.plot(accs_corr)
plt.plot(accs_decorr)

In [None]:
# Print, per BatchNorm layer (index i-1), the norm of the difference between net_p's
# and net_corr's scale parameters.
# NOTE(review): this requires both layers to have the same number of channels, which
# will not hold once net_p has been pruned — confirm when this cell is meant to run.
for i in [3, 6, 10, 13, 17, 20, 24, 27, 31, 34]:
    print((net_p.features[i-1].weight - net_corr.features[i-1].weight).norm())

In [None]:
# Scatter of importance score (x) against layer index (y) for the pruned network,
# zoomed in on the smallest scores.
plt.figure(figsize=(20,10))
plt.scatter(imp_order_p[:,2], imp_order_p[:,0])
plt.xlim(-0.000002, 0.0002)

In [None]:
# Fraction removed relative to the first entry (37.9); used as the "Compression ratio"
# x axis of the plots below.
# NOTE(review): the unit of these hand-entered values is not stated — presumably model
# size or parameter count per pruning round; confirm.
size = 1 - np.array([37.9, 30.9, 23.2, 13.2, 11.2, 6.5, 5.5, 4.4, 3.9]) / 37.9

In [None]:
# Hand-entered results, one entry per point of `size`: accuracy and speed-up.
# NOTE(review): presumably test accuracy in percent and speed-up in percent for the
# decorrelated pruning runs — confirm the source of these numbers.
accs = np.array([60.55, 60.8, 60.85, 61.14, 61.12, 61.16, 60.97, 60.61, 60.41])
speedups = np.array([0, 9, 15, 22, 31, 38, 45, 48, 52])

In [None]:
# Hand-entered baseline results; these arrays are not used by the plots below.
# NOTE(review): the last three accuracies and the last speed-up are identical to those
# in `accs` / `speedups` — check for a copy-paste slip.
accs_base = np.array([60.52, 60.59, 60.72, 60.83, 60.74, 61.56, 60.97, 60.61, 60.41])
speedups_base = np.array([0, 8, 18, 25, 32, 36, 42, 47, 52])

In [None]:
# Accuracy of the pruned network against compression, with a horizontal reference line
# at 60.55 (the first entry of `accs`).
plt.figure(figsize=(14,6))
plt.plot(size, accs, label="Pruned network's accuracy")
plt.xlabel("Compression ratio")
plt.ylabel("Accuracy")
plt.hlines(xmin=0,xmax=0.9,y=60.55, color='r',label="Baseline accuracy")
plt.legend()
# plt.savefig("ortho_prune.png")

In [None]:
# Speed-up of the pruned network against compression, with a horizontal reference line at 55.
# NOTE(review): the y label says "% Inference time" while the plotted series is
# `speedups`; the origin of the 55 "maximum possible speedup" is not shown — confirm both.
plt.figure(figsize=(14,6))
plt.plot(size, speedups, label="Pruned network's speedup")
plt.xlabel("Compression ratio")
plt.ylabel("% Inference time")
plt.hlines(xmin=0,xmax=0.9,y=55, color='r', label="Maximum possible speedup")
plt.legend()
# plt.savefig("ortho_time.png")