In [None]:
# -*- coding: utf-8 -*-
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import os
from utils import progress_bar
from imp_baselines import *

In [None]:
from ptflops import get_model_complexity_info

In [None]:
# Both pipelines share the same per-channel normalisation constants.
_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training pipeline: random flip + rotation of up to 45 degrees, then normalise.
transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomRotation(45),
     transforms.ToTensor(),
     _normalize
     ])

# Evaluation pipeline: no augmentation.
transform_test = transforms.Compose(
    [transforms.ToTensor(),
     _normalize
     ])

# download=True on BOTH splits: the training set is constructed first, so with
# download=False it raised on a machine where ./../data was not yet populated.
trainset = torchvision.datasets.CIFAR100(root='./../data', train=True,
                                        download=True, transform=transform)
# Shuffle the training batches every epoch; the test loader keeps a fixed order.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./../data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

In [None]:
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn


# Layer layouts. An int is the output-channel count of one
# Conv3x3 -> ReLU -> BatchNorm block; 'M' is a 2x2 max-pool with stride 2.
cfg = {
    'VGG11': [32, 32, 'M', 64, 64, 'M', 128, 128, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    """VGG-style classifier with 100 output logits.

    Args:
        vgg_name: key into the module-level ``cfg`` dict.
        input_size: side length of the (square) input image. It is only used
            to size the classifier; the default 32 reproduces the previous
            ``nn.Linear(512, 100)`` for the VGG13/16/19 layouts.

    Raises:
        KeyError: if ``vgg_name`` is not in ``cfg``.
        ValueError: if the layout pools the input down to nothing.
    """

    def __init__(self, vgg_name, input_size=32):
        super(VGG, self).__init__()
        layout = cfg[vgg_name]
        self.features = self._make_layers(layout)
        # The classifier width used to be hard-coded to 512, which does not
        # match layouts such as 'VGG11' above (128 channels on a 4x4 map).
        self.classifier = nn.Linear(self._flat_features(layout, input_size), 100)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)  # flatten everything but the batch axis
        out = self.classifier(out)
        return out

    @staticmethod
    def _flat_features(layout, input_size):
        """Number of features left after ``layout`` for a square input."""
        widths = [entry for entry in layout if entry != 'M']
        # The 3x3/padding-1 convolutions keep the spatial size; every 'M' halves it.
        side = input_size // (2 ** layout.count('M'))
        if not widths or side < 1:
            raise ValueError("layout leaves no features for input_size=%d" % input_size)
        return widths[-1] * side * side

    def _make_layers(self, cfg):
        """Translate a layout list into an ``nn.Sequential`` feature extractor."""
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.ReLU(inplace=True),
                           nn.BatchNorm2d(x)]
                in_channels = x
        return nn.Sequential(*layers)


def test():
    """Smoke test: push a random batch of two 32x32 RGB images through VGG13.

    'VGG13' is used because its layout ends with 512 channels on a 1x1 map,
    which is what a 512-input classifier expects. The 'VGG11' layout defined
    in this notebook ends with 128 channels on a 4x4 map instead.
    """
    net = VGG('VGG13')
    x = torch.randn(2,3,32,32)
    y = net(x)
    print(y.size())

In [None]:
class VGG_p(nn.Module):
    """VGG-style network built from an explicit layout list.

    Args:
        vgg_name: accepted for signature compatibility; not read.
        cfg: layout list where an int is the output-channel count of a
            Conv3x3 -> ReLU -> BatchNorm block and 'M' is a 2x2 max-pool.
            The classifier input width is taken from ``cfg[-2]``.
    """

    def __init__(self, vgg_name, cfg):
        super(VGG_p, self).__init__()
        self.features = self._make_layers(cfg)
        last_width = cfg[-2]
        self.classifier = nn.Linear(last_width, 100)

    def forward(self, x):
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)

    def _make_layers(self, cfg):
        """Build the feature extractor described by ``cfg``."""
        blocks = []
        prev_width = 3
        for entry in cfg:
            if entry != 'M':
                blocks.extend([
                    nn.Conv2d(prev_width, entry, kernel_size=3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.BatchNorm2d(entry),
                ])
                prev_width = entry
                continue
            blocks.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*blocks)


def test():
    """Smoke test for ``VGG_p``: forward a random batch through the VGG13 layout.

    This cell used to repeat the ``VGG('VGG11')`` check and never exercised
    ``VGG_p``, the class defined alongside it.
    """
    net = VGG_p('VGG13_p', cfg['VGG13'])
    x = torch.randn(2,3,32,32)
    y = net(x)
    print(y.size())

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Two separate VGG13 instances; their weights are loaded from PATH_corr and
# PATH_decorr in the next cell.
net_corr = VGG('VGG13').to(device)
net_decorr = VGG('VGG13').to(device)
# Loss shared by the importance computation and the retraining loops.
criterion = nn.CrossEntropyLoss()

In [None]:
PATH_corr = './w_decorr/base_params/cifar100_net.pth'
PATH_decorr = './w_decorr/base_params/wnet_base.pth'


def _load_base_state(path):
    """Return the state_dict stored at ``path``.

    The notebook reads these files both as a bare state_dict (this cell) and
    as a dict holding it under 'net' (the pruning cells), so both layouts are
    accepted. ``map_location`` keeps the load working on the CPU fallback.
    """
    checkpoint = torch.load(path, map_location=device)
    return checkpoint['net'] if 'net' in checkpoint else checkpoint


net_corr.load_state_dict(_load_base_state(PATH_corr))
net_decorr.load_state_dict(_load_base_state(PATH_decorr))

### Accuracies

In [None]:
def cal_acc(net_acc):
    """Return the top-1 accuracy, in percent, of ``net_acc`` over ``testloader``.

    The network's train/eval mode is not changed here; callers pass
    ``net.eval()`` when they want evaluation behaviour.
    """
    n_seen = 0
    n_hit = 0
    with torch.no_grad():
        for batch in testloader:
            images = batch[0].to(device)
            labels = batch[1].to(device)
            # Index of the largest logit per sample.
            predicted = torch.max(net_acc(images).data, 1)[1]
            n_seen += labels.size(0)
            n_hit += (predicted == labels).sum().item()

    return (100 * n_hit / n_seen)

### Importance

In [None]:
def cal_importance(net, l_index):
    """Accumulate a per-channel importance score for ``net.features[l_index]``.

    For every batch of ``trainloader`` the loss gradient is computed and
    ``(weight.grad * weight + bias.grad * bias) ** 2`` of the selected layer is
    added to a running total, one entry per channel.

    Args:
        net: network whose ``features[l_index]`` has 1-D ``weight`` and ``bias``.
        l_index: index of that layer inside ``net.features``.

    Returns:
        ``[channel_indices, scores]``: a float numpy array ``0 .. n-1`` and a
        tensor of the accumulated scores on ``device``.
    """
    layer = net.features[l_index]
    n_channels = layer.bias.shape[0]
    imp_corr_bn = torch.zeros(n_channels).to(device)

    for data in trainloader:
        inputs, labels = data[0].to(device), data[1].to(device)
        # Clear gradients on the network itself. This used to go through a
        # global `optimizer`, which silently accumulated gradients whenever
        # that optimizer was bound to a different model.
        net.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        # NOTE(review): the network's train/eval mode is left as the caller set
        # it; in train mode these forward passes also update BatchNorm running
        # statistics. Confirm that this is intended.
        imp_corr_bn += ((layer.weight.grad * layer.weight.data)
                        + (layer.bias.grad * layer.bias.data)).abs().pow(2)

    neuron_order = [np.arange(n_channels, dtype=np.float64), imp_corr_bn]

    return neuron_order

### Timer

In [None]:
import time

In [None]:
def cal_time(net_acc, n_warmup=5, n_runs=5):
    """Return the mean wall-clock seconds of one forward pass on a 1x3x32x32 input.

    Args:
        net_acc: module to time.
        n_warmup: untimed forward passes run first.
        n_runs: timed forward passes that are averaged.

    Raises:
        ValueError: if ``n_runs`` is smaller than 1.
    """
    if n_runs < 1:
        raise ValueError("n_runs must be at least 1")

    # Put the sample on the model's own device (CPU for parameter-free modules).
    first_param = next(iter(net_acc.parameters()), None)
    sample_device = first_param.device if first_param is not None else torch.device("cpu")
    testsamp = torch.rand(1, 3, 32, 32).to(sample_device)

    def _sync():
        # CUDA kernels run asynchronously; wait for them so the timer is honest.
        if sample_device.type == "cuda":
            torch.cuda.synchronize(sample_device)

    # Time in eval mode: train-mode forwards on random data would overwrite
    # the BatchNorm running statistics of the model being measured.
    was_training = net_acc.training
    net_acc.eval()
    t_end = 0.0  # was never initialised before the first `+=`
    try:
        with torch.no_grad():
            for _ in range(n_warmup):
                net_acc(testsamp)

            for _ in range(n_runs):
                _sync()
                # Restart the clock for every run; a single start before the
                # loop summed cumulative elapsed times.
                t_s = time.perf_counter()
                net_acc(testsamp)
                _sync()
                t_end += time.perf_counter() - t_s
    finally:
        net_acc.train(was_training)

    return (t_end / n_runs)

In [None]:
# Forward-pass timing of the two unpruned networks.
t_corr = cal_time(net_corr)
t_decorr = cal_time(net_decorr)

### TFO importance

In [None]:
import pickle

In [None]:
# Read a previously saved importance table for net_corr; the next cell
# recomputes the table and writes it back to this same path.
# NOTE(review): pickle.load executes arbitrary code from the file; only use it
# on files produced by this notebook.
with open("./w_decorr/base_params/tfo_corr.pkl", 'rb') as f:
    imp_order_corr = pickle.load(f)

In [None]:
# Zero-learning-rate optimizer bound to net_corr; it is never stepped and only
# exists so gradients can be cleared through it.
optimizer = optim.SGD(net_corr.parameters(), lr=0, weight_decay=0)

# One (n_channels, 3) table per layer with columns
# [layer index, channel index, importance]; they are stacked once at the end
# instead of re-concatenating the growing array on every iteration.
layer_tables = []
for l_index in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]:
    print(l_index)
    nlist = cal_importance(net_corr, l_index)
    scores = nlist[1].detach().cpu().numpy()
    layer_tables.append(np.column_stack((np.full(scores.shape[0], l_index), nlist[0], scores)))
imp_order_corr = np.concatenate(layer_tables, 0)

with open("./w_decorr/base_params/tfo_corr.pkl", 'wb') as f:
    pickle.dump(imp_order_corr, f)

In [None]:
# Read a previously saved importance table for net_decorr; the next cell
# recomputes the table and writes it back to this same path.
# NOTE(review): pickle.load executes arbitrary code from the file; only use it
# on files produced by this notebook.
with open("./w_decorr/base_params/tfo_w_decorr.pkl", 'rb') as f:
    imp_order_decorr = pickle.load(f)

In [None]:
# Zero-learning-rate optimizer bound to net_decorr; it is never stepped and
# only exists so gradients can be cleared through it.
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)

# One (n_channels, 3) table per layer with columns
# [layer index, channel index, importance]; they are stacked once at the end
# instead of re-concatenating the growing array on every iteration.
layer_tables = []
for l_index in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]:
    print(l_index)
    nlist = cal_importance(net_decorr, l_index)
    scores = nlist[1].detach().cpu().numpy()
    layer_tables.append(np.column_stack((np.full(scores.shape[0], l_index), nlist[0], scores)))
imp_order_decorr = np.concatenate(layer_tables, 0)

with open("./w_decorr/base_params/tfo_w_decorr.pkl", 'wb') as f:
    pickle.dump(imp_order_decorr, f)

### Orders and ratios

#### Global

In [None]:
def order_and_ratios(imp_order, prune_ratio,
                     layer_indices=(2, 5, 9, 12, 16, 19, 23, 26, 30, 33)):
    """Rank all channels globally and count how many to prune in each layer.

    Args:
        imp_order: array of shape (n, 3) with columns
            [layer index, channel index, importance].
        prune_ratio: fraction of all rows to prune; the ``int(prune_ratio * n)``
            least important rows are selected across layers.
        layer_indices: layers to report, in output order. The default is the
            list that used to be hard-coded here.

    Returns:
        ``(imp_order_tfo, ratios)``: a dict mapping each layer index to its
        channel indices sorted by ascending importance, and a list with the
        number of selected channels per layer.
    """
    ranked = imp_order[np.argsort(imp_order[:, 2])]
    n_prune = int(prune_ratio * imp_order.shape[0])
    pruned_layer_ids = ranked[:n_prune, 0]

    imp_order_tfo = {}
    ratios = []
    for l_index in layer_indices:
        imp_order_tfo[l_index] = ranked[ranked[:, 0] == l_index, 1].astype(int)
        # Only the count is needed, so the selected rows are not sorted again.
        # NOTE(review): nothing stops a layer from losing all its channels.
        ratios.append(int(np.count_nonzero(pruned_layer_ids == l_index)))
    return imp_order_tfo, ratios

#### Local

In [None]:
# def order_and_ratios(imp_order, prune_ratio):
#     imp_sort = np.argsort(imp_order[:,2])
#     temp_order = imp_order[imp_sort]
    
#     n_prune = int(prune_ratio * imp_order.shape[0])

#     prune_list = temp_order[0:n_prune]

#     imp_order_tfo = {}
#     ratios = []

#     for l_index in [2, 6, 10, 13, 16, 1, 4]:
#         nlist = temp_order[(temp_order[:,0] == l_index), 1].astype(int)
#         imp_order_tfo.update({l_index: nlist})
#         ratios.append(int(nlist.shape[0] * prune_ratio))
#     return imp_order_tfo, ratios

### Pruning

In [None]:
def cfg_p(prune_ratio, orig_size, save_cfg_corr=0, save_cfg=0):
    """Build the layout list of the pruned network.

    The ten layer widths ``orig_size[i] - prune_ratio[i]`` are emitted in five
    pairs, each followed by an 'M' entry.

    Args:
        prune_ratio: number of channels removed per layer (10 entries).
        orig_size: current number of channels per layer (10 entries).
        save_cfg_corr: not read.
        save_cfg: 1 pickles the list under the ``corr`` cfg directory, 2 under
            the ``decorr`` one (file name uses the global ``prune_iter``);
            any other value skips saving.

    Returns:
        The layout list.
    """
    cfg_list = []
    for pair in range(5):
        first, second = 2 * pair, 2 * pair + 1
        cfg_list.extend([
            orig_size[first] - prune_ratio[first],
            orig_size[second] - prune_ratio[second],
            'M',
        ])

    if save_cfg == 1:
        target = "./w_decorr/pruned_nets/corr/cfgs/net_p_corr_iter"+str(prune_iter)+".pkl"
    elif save_cfg == 2:
        target = "./w_decorr/pruned_nets/decorr/cfgs/net_p_decorr_iter"+str(prune_iter)+".pkl"
    else:
        target = None

    if target is not None:
        with open(target, 'wb') as f:
            pickle.dump(cfg_list, f)

    return cfg_list

In [None]:
def pruner(net, imp_order, prune_ratio, orig_size, net_type=0):
    """Build a smaller ``VGG_p`` and copy the surviving channels of ``net`` into it.

    Args:
        net: source network.
        imp_order: dict mapping each layer index in ``bn`` below to its channel
            indices sorted by ascending importance.
        prune_ratio: per-layer number of leading (least important) channels to drop.
        orig_size: per-layer channel counts of ``net``.
        net_type: 1 or 2 makes ``cfg_p`` save the layout (corr / decorr path);
            anything else skips saving.

    Returns:
        The pruned network on ``device``.
    """
    if net_type == 1:
        save_flag = 1
    elif net_type == 2:
        save_flag = 2
    else:
        save_flag = 0
    cfg = cfg_p(prune_ratio, orig_size, save_cfg=save_flag)

    net_pruned = VGG_p('VGG13_p', cfg).to(device)
    bn = [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]

    def _kept(pos):
        # Surviving channel indices of the pos-th layer, in ascending order.
        return np.sort(imp_order[bn[pos]][prune_ratio[pos]:])

    for pos, bn_idx in enumerate(bn):
        # The convolution feeding each of these layers sits two slots earlier.
        src_conv, dst_conv = net.features[bn_idx - 2], net_pruned.features[bn_idx - 2]
        src_bn, dst_bn = net.features[bn_idx], net_pruned.features[bn_idx]
        order_c = _kept(pos)

        conv_w = src_conv.weight[order_c]
        if pos > 0:
            # Also drop the input channels pruned from the previous layer.
            conv_w = conv_w[:, _kept(pos - 1)]
        dst_conv.weight.data = conv_w.detach().clone()
        dst_conv.bias.data = src_conv.bias[order_c].detach().clone()

        dst_bn.weight.data = src_bn.weight[order_c].detach().clone()
        dst_bn.bias.data = src_bn.bias[order_c].detach().clone()
        dst_bn.running_var.data = src_bn.running_var[order_c].detach().clone()
        dst_bn.running_mean.data = src_bn.running_mean[order_c].detach().clone()

    # The classifier keeps only the input columns of the surviving last-layer channels.
    order_33 = np.sort(imp_order[33][prune_ratio[-1]:])
    net_pruned.classifier.weight.data = net.classifier.weight[:, order_33].detach().clone()
    net_pruned.classifier.bias.data = net.classifier.bias.detach().clone()

    return net_pruned

## Retraining

In [None]:
# Pruning-round counter; it is embedded in the cfg and checkpoint file names.
prune_iter = 1

## Correlated network pruning

In [None]:
# Channel count of each prunable layer of net_corr, in layer order.
orig_size = np.array([
    net_corr.features[i].bias.shape[0]
    for i in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]
])

In [None]:
# Rank the channels of net_corr globally with a 0.3 prune fraction; the tuple
# on the last line is shown as the cell output.
order_corr, prune_ratio = order_and_ratios(imp_order_corr, 0.3)
prune_ratio, orig_size

#### Define pruned network

In [None]:
# Reload the base weights before pruning. PATH_corr is read as a bare
# state_dict earlier in the notebook and as {'net': state_dict} here, so both
# layouts are accepted. map_location keeps the load working on the CPU fallback.
net_dict = torch.load(PATH_corr, map_location=device)
net_corr.load_state_dict(net_dict['net'] if 'net' in net_dict else net_dict)
net_p = pruner(net_corr, order_corr, prune_ratio, orig_size, net_type=1)

In [None]:
# Test accuracy of the pruned network versus the unpruned one. Note that
# .eval() leaves both networks in eval mode afterwards.
cal_acc(net_p.eval()), cal_acc(net_corr.eval())

#### Retraining

In [None]:
# Training
def net_p_train(epoch):
    """Run one training epoch of the global ``net_p`` over ``trainloader``.

    Uses the global ``optimizer`` and ``criterion`` and reports the running
    loss and accuracy through ``progress_bar``.
    """
    print('\nEpoch: %d' % epoch)
    net_p.train()
    train_loss, correct, total = 0, 0, 0
    n_batches = len(trainloader)
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = net_p(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        predicted = outputs.max(1)[1]
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        summary = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
            train_loss/(batch_idx+1), 100.*correct/total, correct, total)
        progress_bar(batch_idx, n_batches, summary)
        
def net_p_test(epoch):
    """Evaluate the global ``net_p`` on ``testloader`` and checkpoint on improvement.

    When the accuracy beats the global ``best_p_acc``, the state_dict and the
    accuracy are saved to ``./net_p_checkpoint/ckpt<prune_iter>.pth`` and
    ``best_p_acc`` is updated.
    """
    global best_p_acc
    global prune_iter
    # The cell that initialises best_p_acc is commented out, so on a fresh
    # kernel the comparison below raised NameError. Start from 0 in that case.
    globals().setdefault('best_p_acc', 0)
    net_p.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_p(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_p_acc:
        print('Saving..')
        state = {
            'net_p': net_p.state_dict(),
            'best_p_acc': acc
        }
        os.makedirs('net_p_checkpoint', exist_ok=True)
        torch.save(state, './net_p_checkpoint/ckpt'+str(prune_iter)+'.pth')
        best_p_acc = acc

#### Retraining

In [None]:
import torch.optim as optim
# Loss and optimizer for retraining net_p: Adam with learning rate 1e-6 and
# weight decay 5e-4. This rebinds the global `optimizer` used by net_p_train.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net_p.parameters(), lr=0.000001, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)

In [None]:
# best_p_acc = 0

In [None]:
# Ten epochs of retraining; each epoch is followed by a test pass that
# checkpoints net_p when the accuracy improves.
for epoch in range(10):
    net_p_train(epoch)
    net_p_test(epoch)

#### Load correlated pruned network

In [None]:
# Display the pruning round whose checkpoint is loaded in the next cell.
prune_iter

In [None]:
# Restore the best checkpoint saved by net_p_test for this pruning round.
# map_location keeps the load working when the notebook falls back to the CPU.
net_dict = torch.load('./net_p_checkpoint/ckpt'+str(prune_iter)+'.pth', map_location=device)
net_p.load_state_dict(net_dict['net_p'])
best_p_acc = net_dict['best_p_acc']

In [None]:
# Re-evaluate the restored network; the epoch argument is not read by net_p_test.
net_p_test(0)

## Decorrelated network pruning

In [None]:
# Channel count of each prunable layer of net_decorr, in layer order.
orig_size = np.array([
    net_decorr.features[i].bias.shape[0]
    for i in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]
])

In [None]:
# Rank the channels of net_decorr globally with a 0.3 prune fraction; the
# tuple on the last line is shown as the cell output.
order_decorr, prune_ratio = order_and_ratios(imp_order_decorr, 0.3)
prune_ratio, orig_size

#### Define pruned network

In [None]:
# Reload the base weights before pruning. PATH_decorr is read as a bare
# state_dict earlier in the notebook and as {'net': state_dict} here, so both
# layouts are accepted. map_location keeps the load working on the CPU fallback.
net_dict = torch.load(PATH_decorr, map_location=device)
net_decorr.load_state_dict(net_dict['net'] if 'net' in net_dict else net_dict)
net_p_ortho = pruner(net_decorr, order_decorr, prune_ratio, orig_size, net_type=2)

In [None]:
# Test accuracy of the pruned network versus the unpruned one. Note that
# .eval() leaves both networks in eval mode afterwards.
cal_acc(net_p_ortho.eval()), cal_acc(net_decorr.eval())

In [None]:
def net_p_test_ortho(epoch):
    """Evaluate the global ``net_p_ortho`` on ``testloader`` and checkpoint on improvement.

    When the accuracy beats the global ``best_p_ortho_acc``, the state_dict and
    the accuracy are saved to ``./ortho_p_checkpoint/ortho_ckpt<prune_iter>.pth``
    and ``best_p_ortho_acc`` is updated.
    """
    global best_p_ortho_acc
    global prune_iter
    # The cell that initialises best_p_ortho_acc is commented out, so on a
    # fresh kernel the comparison below raised NameError. Start from 0 then.
    globals().setdefault('best_p_ortho_acc', 0)
    net_p_ortho.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_p_ortho(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    print(acc)
    if acc > best_p_ortho_acc:
        print('Saving..')
        state = {
            'net_p_ortho': net_p_ortho.state_dict(),
            'best_p_ortho_acc': acc
        }
        os.makedirs('ortho_p_checkpoint', exist_ok=True)
        torch.save(state, './ortho_p_checkpoint/ortho_ckpt'+str(prune_iter)+'.pth')
        best_p_ortho_acc = acc

In [None]:
def net_p_train_ortho(epoch):
    """Run one training epoch of ``net_p_ortho`` with an added Gram-matrix penalty.

    For every convolution the filters and biases are flattened into one matrix
    ``P`` (one row per filter). The entrywise 1-norm of ``Gram - I`` is weighted
    by ``l_imp`` and summed over layers; the training loss is
    ``0.1 * penalty + cross-entropy``. The first convolution uses ``P^T P``,
    all others use ``P P^T``.
    """
    print('\nEpoch: %d' % epoch)
    net_p_ortho.train()
    correct = 0
    total = 0
    running_loss = 0.0
    angle_cost = 0.0

    def _filter_matrix(conv_ind):
        # One row per output filter: flattened kernel followed by its bias.
        layer = net_p_ortho.features[conv_ind]
        flat_w = layer.weight.reshape(layer.weight.shape[0], -1)
        flat_b = layer.bias.reshape(layer.bias.shape[0], -1)
        return torch.cat((flat_w, flat_b), dim=1)

    for batch_idx, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net_p_ortho(inputs)
        L_angle = 0

        ### Conv_ind == 0 ###
        params = _filter_matrix(0)
        angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)
        L_angle += (l_imp[0])*(angle_mat).norm(1)

        ### Conv_ind != 0 ###
        for conv_ind in [3, 7, 10, 14, 17, 21, 24, 28, 31]:
            params = _filter_matrix(conv_ind)
            angle_mat = torch.matmul(params, torch.t(params)) - torch.eye(params.shape[0]).to(device)
            L_angle += (l_imp[conv_ind])*(angle_mat).norm(1)

        Lc = criterion(outputs, labels)
        loss = (1e-1)*(L_angle) + Lc
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        angle_cost += (L_angle).item()

        predicted = outputs.max(1)[1]
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (running_loss/(batch_idx+1), 100.*correct/total, correct, total))

    print("angle_cost: ", angle_cost/total)

#### Retraining

In [None]:
import torch.optim as optim

# Loss and optimizer for retraining net_p_ortho: Adam with learning rate 1e-4
# and no weight decay. This rebinds the global `optimizer` used by
# net_p_train_ortho.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net_p_ortho.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

In [None]:
# Per-convolution penalty weights: each layer's filter count divided by the
# total filter count over all ten convolutions.
conv_layers = [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]
l_imp = {conv_ind: net_p_ortho.features[conv_ind].bias.shape[0] for conv_ind in conv_layers}

normalizer = sum(l_imp.values())
l_imp = {key: val / normalizer for key, val in l_imp.items()}

In [None]:
# best_p_ortho_acc = 0

In [None]:
# Five epochs of retraining; each epoch is followed by a test pass that
# checkpoints net_p_ortho when the accuracy improves.
for epoch in range(5):
    net_p_train_ortho(epoch)
    net_p_test_ortho(epoch)

#### Load decorrelated pruned network

In [None]:
# Display the pruning round whose checkpoint is loaded in the next cell.
prune_iter 

In [None]:
# Restore the best checkpoint saved by net_p_test_ortho for this pruning round.
# map_location keeps the load working when the notebook falls back to the CPU.
net_dict = torch.load('./ortho_p_checkpoint/ortho_ckpt'+str(prune_iter)+'.pth', map_location=device)
net_p_ortho.load_state_dict(net_dict['net_p_ortho'])
best_p_ortho_acc = net_dict['best_p_ortho_acc']

In [None]:
# Re-evaluate the restored network; the epoch argument is not read by
# net_p_test_ortho.
net_p_test_ortho(0)

#### Evaluate orthogonality of filters in pruned network

In [None]:
# For every convolution print the share of the Gram matrix's entrywise 1-norm
# that lies on its diagonal. The first convolution uses P^T P, all others use
# P P^T, where each row of P is one filter's flattened kernel plus its bias.
for conv_ind in [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]:
    w_mat = net_p_ortho.features[conv_ind].weight
    b_mat = net_p_ortho.features[conv_ind].bias
    params = torch.cat((w_mat.reshape(w_mat.shape[0], -1),
                        b_mat.reshape(b_mat.shape[0], -1)), dim=1)
    if conv_ind == 0:
        angle_mat = torch.matmul(torch.t(params), params)
    else:
        angle_mat = torch.matmul(params, torch.t(params))
    L_diag = angle_mat.diag().norm(1)
    L_angle = angle_mat.norm(1)
    print(L_diag.cpu()/L_angle.cpu())

### Subsequent pruning

#### Importance

In [None]:
# # ''' Correlated network '''
# with open("./w_decorr/pruned_nets/corr/tfo_order/tfo_corr_p"+str(prune_iter)+".pkl", 'rb') as f:
#     imp_order_p = pickle.load(f)

In [None]:
# optimizer = optim.SGD(net_p.parameters(), lr=0, weight_decay=0)
# imp_order_p = np.array([[],[],[]]).transpose()
# i = 0
# for l_index in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]:
#     print(l_index)
#     nlist = cal_importance(net_p, l_index)
#     imp_order_p = np.concatenate((imp_order_p,np.array([np.repeat([l_index],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
#     i+=1
    
# with open("./w_decorr/pruned_nets/corr/tfo_order/tfo_corr_p"+str(prune_iter)+".pkl", 'wb') as f:
#     pickle.dump(imp_order_p, f)

In [None]:
''' De-Correlated network '''
# Read a previously saved importance table for the pruned network; the next
# cell recomputes the table and writes it back to this same path.
# NOTE(review): pickle.load executes arbitrary code from the file; only use it
# on files produced by this notebook.
with open("./w_decorr/pruned_nets/decorr/tfo_order/tfo_w_decorr_p_ortho"+str(prune_iter)+".pkl", 'rb') as f:
    imp_order_p_ortho = pickle.load(f)

In [None]:
# Zero-learning-rate optimizer bound to net_p_ortho; it is never stepped and
# only exists so gradients can be cleared through it.
optimizer = optim.SGD(net_p_ortho.parameters(), lr=0, weight_decay=0)

# One (n_channels, 3) table per layer with columns
# [layer index, channel index, importance]; they are stacked once at the end
# instead of re-concatenating the growing array on every iteration.
layer_tables = []
for l_index in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]:
    print(l_index)
    nlist = cal_importance(net_p_ortho, l_index)
    scores = nlist[1].detach().cpu().numpy()
    layer_tables.append(np.column_stack((np.full(scores.shape[0], l_index), nlist[0], scores)))
imp_order_p_ortho = np.concatenate(layer_tables, 0)

with open("./w_decorr/pruned_nets/decorr/tfo_order/tfo_w_decorr_p_ortho"+str(prune_iter)+".pkl", 'wb') as f:
    pickle.dump(imp_order_p_ortho, f)

#### Pruned network pruning

In [None]:
# ''' Correlated network '''
# orig_size = []
# for i in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]:
#     orig_size.append(net_p.features[i].bias.shape[0])
# orig_size = np.array(orig_size)

In [None]:
''' Decorrelated network '''
# Channel count of each prunable layer of the already-pruned net_p_ortho.
orig_size = np.array([
    net_p_ortho.features[i].bias.shape[0]
    for i in [2, 5, 9, 12, 16, 19, 23, 26, 30, 33]
])

#### Pruning order

In [None]:
# ''' Correlated network '''
# order_p, prune_ratio = order_and_ratios(imp_order_p, 0.3)
# prune_ratio, orig_size

In [None]:
''' De-Correlated network '''
# Rank the channels of the pruned decorrelated network with a 0.05 prune
# fraction; the tuple on the last line is shown as the cell output.
order_p, prune_ratio = order_and_ratios(imp_order_p_ortho, 0.05)
prune_ratio, orig_size

#### Define pruned network

In [None]:
# Pruning round used in the file names written and read by the cells below.
prune_iter = 12

In [None]:
# ''' Correlated network pruning '''
# net_p1 = pruner(net_p, order_p, prune_ratio, orig_size, net_type=1)

# print("Accs:", cal_acc(net_p1.eval()), cal_acc(net_p.eval()))

In [None]:
''' De-Correlated network pruning '''
# Prune net_p_ortho once more; net_type=2 makes cfg_p save the new layout
# under the decorr cfg directory for the current prune_iter.
net_p1_ortho = pruner(net_p_ortho, order_p, prune_ratio, orig_size, net_type=2)

# Accuracy right after pruning versus the network it was pruned from.
print("Accs:", cal_acc(net_p1_ortho.eval()), cal_acc(net_p_ortho.eval()))

#### Save pruned network

In [None]:
''' Correlated network saving '''
# NOTE(review): net_p1 is only assigned in the commented-out
# "Correlated network pruning" cell above, so on a fresh Run All this line
# raises NameError. The active pruning cell produces net_p1_ortho, whose save
# cell below is commented out. Confirm which branch is meant to be active.
net_p = net_p1

print('Saving..')
# The checkpoint stores the weights together with the freshly measured test accuracy.
state = {
    'net_p': net_p.state_dict(),
    'best_p_acc': cal_acc(net_p.eval())
}
if not os.path.isdir('net_p_checkpoint'):
    os.mkdir('net_p_checkpoint')
torch.save(state, './net_p_checkpoint/ckpt'+str(prune_iter)+'.pth')

In [None]:
# ''' De-Correlated network saving '''
# net_p_ortho = net_p1_ortho

# print('Saving..')
# state = {
#     'net_p_ortho': net_p_ortho.state_dict(),
#     'best_p_ortho_acc': cal_acc(net_p_ortho.eval())
# }
# if not os.path.isdir('ortho_p_checkpoint'):
#     os.mkdir('ortho_p_checkpoint')
# torch.save(state, './ortho_p_checkpoint/ortho_ckpt'+str(prune_iter)+'.pth')

### Load pruned network

In [None]:
''' Correlated network loading '''
# Rebuild the pruned architecture from its saved layout, then load the weights.
# NOTE(review): pickle.load executes arbitrary code from the file; only use it
# on files produced by this notebook.
with open("./w_decorr/pruned_nets/corr/cfgs/net_p_corr_iter"+str(prune_iter)+".pkl", 'rb') as f:
    cfg_p1 = pickle.load(f)
    
net_p = VGG_p('VGG13_p', cfg_p1).to(device)
PATH = './net_p_checkpoint/ckpt'+str(prune_iter)+'.pth'
# map_location keeps the load working when the notebook falls back to the CPU.
net_p.load_state_dict(torch.load(PATH, map_location=device)['net_p'])

In [None]:
# ''' De-Correlated network loading '''
# with open("./w_decorr/pruned_nets/decorr/cfgs/net_p_decorr_iter"+str(prune_iter)+".pkl", 'rb') as f:
#     cfg_p1 = pickle.load(f)

# net_p_ortho = VGG_p('VGG13_p', cfg_p1).to(device)
# PATH = './ortho_p_checkpoint/ortho_ckpt'+str(prune_iter)+'.pth'
# net_p_ortho.load_state_dict(torch.load(PATH)['net_p_ortho'])    

### FLOPS calculator

In [None]:
# with torch.cuda.device(0):
#     flops, params = get_model_complexity_info(net_p_ortho, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
#     print('{:<30}  {:<8}'.format('Computational complexity: ', flops))
    
# Per-layer complexity report for the pruned correlated network at a
# 3x32x32 input, followed by the total.
# NOTE(review): this selects CUDA device 0 unconditionally, while the rest of
# the notebook falls back to the CPU; confirm this cell is only run with a GPU.
with torch.cuda.device(0):
    flops, params = get_model_complexity_info(net_p, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops))    

In [None]:
# Same report for the unpruned decorrelated network, as the baseline.
# NOTE(review): this selects CUDA device 0 unconditionally, while the rest of
# the notebook falls back to the CPU; confirm this cell is only run with a GPU.
with torch.cuda.device(0):
    flops, params = get_model_complexity_info(net_decorr, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops))

# with torch.cuda.device(0):
#     flops, params = get_model_complexity_info(net_corr, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
#     print('{:<30}  {:<8}'.format('Computational complexity: ', flops))