In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import numpy as np

from models import *
from utils import progress_bar
from imp_baselines import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
from ptflops import get_model_complexity_info

In [3]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR100(root='./../data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=False, num_workers=4)

testset = torchvision.datasets.CIFAR100(root='./../data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


### Build and load base networks

In [4]:
print('==> Building model..')
net_corr = MobileNet()
net_decorr = MobileNet().to(device)
criterion = nn.CrossEntropyLoss()

net_corr = net_corr.to(device)
net_decorr = net_decorr.to(device)
if device == 'cuda':
    net_corr = torch.nn.DataParallel(net_corr)
    net_decorr = torch.nn.DataParallel(net_decorr)
    cudnn.benchmark = True

==> Building model..


In [5]:
PATH_corr = './w_decorr/base_params/cifar100_net.pth'
net_dict = torch.load(PATH_corr)
net_corr.load_state_dict(net_dict['net'])

PATH_decorr = './w_decorr/base_params/wnet_base.pth'
net_dict = torch.load(PATH_decorr)
net_decorr.load_state_dict(net_dict['net_ortho'])

<All keys matched successfully>

### Accuracy Calculator

In [6]:
def cal_acc(net_test):
    net_test.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_test(inputs)

            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        print(100 * correct / total)
        
    return 100 * correct / total

### Importance Calculator

In [7]:
def cal_importance(net, l_id):
    bias_base = l_id.bias.data.clone().detach()
    av_corrval = 0

    running_loss = 0.0
    imp_corr_bn = torch.zeros(bias_base.shape[0]).to(device)

    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        imp_corr_bn += (((l_id.weight.grad)*(l_id.weight.data)) + ((l_id.bias.grad)*(l_id.bias.data))).abs().pow(2)
        break
    imp_norm = imp_corr_bn
    
    neuron_order = [np.linspace(0, imp_norm.shape[0]-1, imp_norm.shape[0]), imp_norm]
    
    return neuron_order

### Time Calculator

In [8]:
import time

In [9]:
def cal_time(net_acc):
    net_acc.eval()
    testsamp = torch.rand(1,3,32,32).to(device)
    
    for i in range(5):
        net_acc(testsamp)    
    t_end = 0
    t_s = time.time()
    for i in range(25):
        net_acc(testsamp)
        t_end += time.time() - t_s
    
    return (t_end / 25)

In [10]:
t_corr = cal_time(net_corr)
t_decorr = cal_time(net_decorr)

### TFO importance

In [11]:
import pickle

In [None]:
with open("./w_decorr/base_params/tfo_corr.pkl", 'rb') as f:
    imp_order_corr = pickle.load(f)

In [None]:
optimizer = optim.SGD(net_corr.parameters(), lr=0, weight_decay=0)
imp_order_corr = np.array([[],[],[]]).transpose()
i = 0

print(i)
nlist = cal_importance(net_corr, net_corr.module.bn1)
imp_order_corr = np.concatenate((imp_order_corr,np.array([np.repeat([i],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
i+=1

for l_index in range(13):
    print(i)
    nlist = cal_importance(net_corr, net_corr.module.layers[l_index].bn1)
    imp_order_corr = np.concatenate((imp_order_corr,np.array([np.repeat([i],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
    i+=1
    
    print(i)
    nlist = cal_importance(net_corr, net_corr.module.layers[l_index].bn2)
    imp_order_corr = np.concatenate((imp_order_corr,np.array([np.repeat([i],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
    i+=1
    
with open("./w_decorr/base_params/tfo_corr.pkl", 'wb') as f:
    pickle.dump(imp_order_corr, f)

In [12]:
with open("./w_decorr/base_params/tfo_w_decorr.pkl", 'rb') as f:
    imp_order_decorr = pickle.load(f)

In [None]:
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)
imp_order_decorr = np.array([[],[],[]]).transpose()
i = 0

print(i)
nlist = cal_importance(net_decorr, net_decorr.module.bn1)
imp_order_decorr = np.concatenate((imp_order_decorr,np.array([np.repeat([i],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
i+=1

for l_index in range(13):
    print(i)
    nlist = cal_importance(net_decorr, net_decorr.module.layers[l_index].bn1)
    imp_order_decorr = np.concatenate((imp_order_decorr,np.array([np.repeat([i],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
    i+=1
    
    print(i)
    nlist = cal_importance(net_decorr, net_decorr.module.layers[l_index].bn2)
    imp_order_decorr = np.concatenate((imp_order_decorr,np.array([np.repeat([i],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
    i+=1
    
with open("./w_decorr/base_params/tfo_w_decorr.pkl", 'wb') as f:
    pickle.dump(imp_order_decorr, f)

### Order and ratios for pruning

In [13]:
def order_and_ratios(imp_order, prune_ratio):
    imp_sort = np.argsort(imp_order[:,2])
    temp_order = imp_order[imp_sort]

    n_prune = int(prune_ratio * imp_order.shape[0])

    prune_list = temp_order[0:n_prune]

    imp_order_tfo = {}
    ratios = []

    for l_index in range(27):
        nlist = temp_order[(temp_order[:,0] == l_index), 1].astype(int)
        imp_order_tfo.update({l_index: nlist})
        nlist = np.sort(prune_list[(prune_list[:,0] == l_index), 1].astype(int))
        ratios.append(nlist.shape[0])
    return imp_order_tfo, ratios

In [14]:
def cal_size(net_size):
    orig_size = [net_size.module.bn1.bias.shape[0]]

    for l_index in range(13):
        orig_size.append(net_size.module.layers[l_index].bn1.bias.shape[0])
        orig_size.append(net_size.module.layers[l_index].bn2.bias.shape[0])

    orig_size = np.array(orig_size)
    
    return orig_size

### Pruned MobileNet Class

In [15]:
def cfg_p(prune_ratio, orig_size, save_cfg=0):
    cfg_list = []

    for i in range(0,27,2):
        cfg_list.append(int(orig_size[i] - prune_ratio[i]))
    
    for i in [2, 4, 6, 12]:
        cfg_list[i] = (cfg_list[i], 2)
    
    if(save_cfg == 1):
        with open("./w_decorr/pruned_nets/corr/cfgs/net_p_corr_iter"+str(prune_iter)+".pkl", 'wb') as f:
            pickle.dump(cfg_list, f)

    elif(save_cfg == 2):
        with open("./w_decorr/pruned_nets/decorr/cfgs/net_p_decorr_iter"+str(prune_iter)+".pkl", 'wb') as f:
            pickle.dump(cfg_list, f)
    
    return cfg_list

In [16]:
class Block(nn.Module):
    '''Depthwise conv + Pointwise conv'''
    def __init__(self, in_planes, out_planes, stride=1):
        super(Block, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False)
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        return out


class MobileNet_p(nn.Module):

    def __init__(self, in_first, cfg, num_classes=100):
        super(MobileNet_p, self).__init__()
        self.cfg = cfg
        self.conv1 = nn.Conv2d(3, in_first, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(in_first)
        self.layers = self._make_layers(in_planes=in_first)
        self.linear = nn.Linear(cfg[-1], num_classes)

    def _make_layers(self, in_planes):
        layers = []
        for x in self.cfg:
            out_planes = x if isinstance(x, int) else x[0]
            stride = 1 if isinstance(x, int) else x[1]
            layers.append(Block(in_planes, out_planes, stride))
            in_planes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layers(out)
        out = F.avg_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

### Pruner

In [38]:
def pruner(net, imp_order, prune_ratio, orig_size, net_type=0):
    
    cfg = cfg_p(prune_ratio, orig_size, save_cfg=net_type)
    
    net_pruned = torch.nn.DataParallel(MobileNet_p(cfg[0], cfg[1:]))
    n_c = 0
    order_c = []
    order_p = []

    for l in range(0, 27, 2):
        if(l == 0):
            n_c = prune_ratio[l]
            order_c = np.sort(imp_order[l][n_c:])
            net_pruned.module.conv1.weight.data = net.module.conv1.weight[order_c].data.detach().clone()
            
            net_pruned.module.bn1.weight.data = net.module.bn1.weight[order_c].data.detach().clone()
            net_pruned.module.bn1.bias.data = net.module.bn1.bias[order_c].data.detach().clone()
            net_pruned.module.bn1.running_var.data = net.module.bn1.running_var[order_c].data.detach().clone()
            net_pruned.module.bn1.running_mean.data = net.module.bn1.running_mean[order_c].data.detach().clone()
            
            continue
        
        else:
            ind = int((l+1)/2) - 1
            # n_p = prune_ratio[ind]
            # order_c = np.sort(imp_order[ind][n_p:])
            net_pruned.module.layers[ind].conv1.weight.data = net.module.layers[ind].conv1.weight[order_c].data.detach().clone()
            
            net_pruned.module.layers[ind].bn1.weight.data = net.module.layers[ind].bn1.weight[order_c].data.detach().clone()
            net_pruned.module.layers[ind].bn1.bias.data = net.module.layers[ind].bn1.bias[order_c].data.detach().clone()
            net_pruned.module.layers[ind].bn1.running_var.data = net.module.layers[ind].bn1.running_var[order_c].data.detach().clone()
            net_pruned.module.layers[ind].bn1.running_mean.data = net.module.layers[ind].bn1.running_mean[order_c].data.detach().clone()
            order_p = order_c.copy()

            n_c = prune_ratio[l]
            order_c = np.sort(imp_order[l][n_c:])

            net_pruned.module.layers[ind].conv2.weight.data = net.module.layers[ind].conv2.weight[order_c][:,order_p].data.detach().clone()
            
            net_pruned.module.layers[ind].bn2.weight.data = net.module.layers[ind].bn2.weight[order_c].data.detach().clone()
            net_pruned.module.layers[ind].bn2.bias.data = net.module.layers[ind].bn2.bias[order_c].data.detach().clone()
            net_pruned.module.layers[ind].bn2.running_var.data = net.module.layers[ind].bn2.running_var[order_c].data.detach().clone()
            net_pruned.module.layers[ind].bn2.running_mean.data = net.module.layers[ind].bn2.running_mean[order_c].data.detach().clone()
    
    n_linear = prune_ratio[-1]
    order_linear = np.sort(imp_order[26][n_linear:])

    net_pruned.module.linear.weight.data = net.module.linear.weight[:,order_linear].detach().clone()
    net_pruned.module.linear.bias.data = net.module.linear.bias.detach().clone()
    
    return net_pruned

## Retraining

In [18]:
prune_iter = 1

### Correlated network pruning

In [None]:
orig_size = cal_size(net_corr)

In [None]:
order_corr, prune_ratio = order_and_ratios(imp_order_corr, 0.2)
np.array(prune_ratio), orig_size

In [None]:
net_dict = torch.load(PATH_corr)
net_corr.load_state_dict(net_dict['net'])
net_p = pruner(net_corr, order_corr, prune_ratio, orig_size, net_type=1)

### Accs

In [None]:
cal_acc(net_corr.eval()), cal_acc(net_p.eval())

In [None]:
t_corr = cal_time(net_corr)
t_p = cal_time(net_p)

In [None]:
100*(1 - (t_p / t_corr))

#### Retraining

In [None]:
# Training
def net_p_train(epoch):
    print('\nEpoch: %d' % epoch)
    net_p.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net_p(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
        
def net_p_test(epoch):
    global best_p_acc
    global prune_iter
    net_p.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_p(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_p_acc:
        print('Saving..')
        state = {
            'net_p': net_p.state_dict(),
            'best_p_acc': acc
        }
        if not os.path.isdir('net_p_checkpoint'):
            os.mkdir('net_p_checkpoint')
        torch.save(state, './net_p_checkpoint/ckpt'+str(prune_iter)+'.pth')
        best_p_acc = acc

In [None]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net_p.parameters(), lr=0.000001, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)

In [None]:
best_p_acc = 0

In [None]:
for epoch in range(1):
    net_p_train(epoch)
    net_p_test(epoch)

#### Load correlated pruned network

In [None]:
prune_iter

In [None]:
net_dict = torch.load('./net_p_checkpoint/ckpt1.pth')
net_p.load_state_dict(net_dict['net_p'])
best_p_acc = net_dict['best_p_acc']

## Decorrelated network pruning

In [19]:
orig_size = cal_size(net_decorr)

In [20]:
order_decorr, prune_ratio = order_and_ratios(imp_order_decorr, 0.1)
np.array(prune_ratio), orig_size

(array([ 15,  13,   8,   7,  10,   5,   8,   3,  17,  11,  61,  24, 101,
         64, 144, 102, 103,  77,  48,  22,  26,  14,  26,  21,  91,  68,
          5]),
 array([  32,   32,   64,   64,  128,  128,  128,  128,  256,  256,  256,
         256,  512,  512,  512,  512,  512,  512,  512,  512,  512,  512,
         512,  512, 1024, 1024, 1024]))

#### Define pruned network

In [21]:
net_dict = torch.load(PATH_decorr)
net_decorr.load_state_dict(net_dict['net_ortho'])
net_p_ortho = pruner(net_decorr, order_decorr, prune_ratio, orig_size, net_type=2)

In [22]:
cal_acc(net_p_ortho.eval()), cal_acc(net_decorr.eval())

43.93
44.42


(43.93, 44.42)

In [23]:
t_decorr = cal_time(net_decorr)
t_p_ortho = cal_time(net_p_ortho)

In [24]:
100*(1 - (t_p_ortho / t_corr))

1.6359625594480565

#### Retraining

In [25]:
# Training
def net_p_train_ortho(epoch):
    print('\nEpoch: %d' % epoch)
    net_p_ortho.train()
    running_loss = 0
    correct = 0
    total = 0
    angle_cost = 0.0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, labels = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net_p_ortho(inputs)

        L_angle = 0

        ### Conv_ind == 0 ###
        w_mat = net_p_ortho.module.conv1.weight
        params = w_mat.reshape(w_mat.shape[0],-1)
        angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)
        L_angle += l_imp['conv1'] * (angle_mat).norm(1) #.norm().pow(2))
        
        ### Conv_ind != 0 ###        
        for lnum in range(13):
            w_mat = net_p_ortho.module.layers[lnum].conv1.weight
            params = (w_mat.reshape(w_mat.shape[0],-1))
            angle_mat = torch.matmul(params.t(), params) - torch.eye(params.shape[1]).to(device)
            L_angle += l_imp[lnum] * (angle_mat).norm(1)

            w_mat = net_p_ortho.module.layers[lnum].conv2.weight
            params = (w_mat.reshape(w_mat.shape[0],-1))
            angle_mat = torch.matmul(params.t(), params) - torch.eye(params.shape[1]).to(device)
            L_angle += l_imp[lnum] * (angle_mat).norm(1)
                
        Lc = criterion(outputs, labels)
        loss = (1e-1)*(L_angle) + Lc

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        angle_cost += (L_angle).item()

        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (running_loss/(batch_idx+1), 100.*correct/total, correct, total))

    print("angle_cost: ", angle_cost/batch_idx+1)

In [26]:
def net_p_test_ortho(epoch):
    global best_p_ortho_acc
    global prune_iter
    net_p_ortho.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net_p_ortho(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    print(acc)
    if acc > best_p_ortho_acc:
        print('Saving..')
        state = {
            'net_p_ortho': net_p_ortho.state_dict(),
            'best_p_ortho_acc': acc
        }
        if not os.path.isdir('ortho_p_checkpoint'):
            os.mkdir('ortho_p_checkpoint')
        torch.save(state, './ortho_p_checkpoint/ortho_ckpt'+str(prune_iter)+'.pth')
        best_p_ortho_acc = acc

#### Computational importance

In [27]:
l_imp = {}

l_imp.update({'conv1': net_p_ortho.module.conv1.weight.shape[0]})

for conv_ind in range(13):
    l_imp.update({conv_ind: net_p_ortho.module.layers[conv_ind].conv1.weight.shape[0]})
    
normalizer = 0
for key, val in l_imp.items():
    normalizer += val
for key, val in l_imp.items():
    l_imp[key] = val / normalizer

In [28]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net_p_ortho.parameters(), lr=0.00001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

In [None]:
# best_p_ortho_acc = 0

In [None]:
for epoch in range(1):
    net_p_train_ortho(epoch)
    net_p_test_ortho(epoch)

#### Save decorrelated pruned network

In [29]:
prune_iter

1

In [None]:
net_dict = torch.load('./ortho_p_checkpoint/ortho_ckpt'+str(prune_iter)+'.pth')
net_p_ortho.load_state_dict(net_dict['net_p_ortho'])

#### Evaluate orthogonality of filters in pruned network

In [30]:
w_mat = net_p_ortho.module.conv1.weight
params = (w_mat.reshape(w_mat.shape[0],-1))
angle_mat = torch.matmul(torch.t(params), params) # - torch.eye(params.shape[1]).to(device)
L_diag = (angle_mat.diag().norm(1))
L_angle = (angle_mat.norm(1))
print(L_diag.cpu()/L_angle.cpu())

for lnum in range(13):
    w_mat = net_p_ortho.module.layers[lnum].conv1.weight
    params = (w_mat.reshape(w_mat.shape[0],-1))
    angle_mat = torch.matmul(params.t(), params)# - torch.eye(params.shape[0]).to(device)
    L_diag = (angle_mat.diag().norm(1))
    L_angle = (angle_mat.norm(1))
    print(lnum,"-a: ", L_diag.cpu()/L_angle.cpu())

    w_mat = net_p_ortho.module.layers[lnum].conv2.weight
    params = (w_mat.reshape(w_mat.shape[0],-1))
    angle_mat = torch.matmul(params.t(), params)# - torch.eye(params.shape[0]).to(device)
    L_diag = (angle_mat.diag().norm(1))
    L_angle = (angle_mat.norm(1))
    print(lnum,"-b: ", L_diag.cpu()/L_angle.cpu())

tensor(0.1255, grad_fn=<DivBackward0>)
0 -a:  tensor(0.2985, grad_fn=<DivBackward0>)
0 -b:  tensor(0.3180, grad_fn=<DivBackward0>)
1 -a:  tensor(0.4291, grad_fn=<DivBackward0>)
1 -b:  tensor(0.1325, grad_fn=<DivBackward0>)
2 -a:  tensor(0.4088, grad_fn=<DivBackward0>)
2 -b:  tensor(0.0655, grad_fn=<DivBackward0>)
3 -a:  tensor(0.2347, grad_fn=<DivBackward0>)
3 -b:  tensor(0.0719, grad_fn=<DivBackward0>)
4 -a:  tensor(0.3798, grad_fn=<DivBackward0>)
4 -b:  tensor(0.0404, grad_fn=<DivBackward0>)
5 -a:  tensor(0.2231, grad_fn=<DivBackward0>)
5 -b:  tensor(0.0571, grad_fn=<DivBackward0>)
6 -a:  tensor(0.4774, grad_fn=<DivBackward0>)
6 -b:  tensor(0.0313, grad_fn=<DivBackward0>)
7 -a:  tensor(0.4358, grad_fn=<DivBackward0>)
7 -b:  tensor(0.0336, grad_fn=<DivBackward0>)
8 -a:  tensor(0.4617, grad_fn=<DivBackward0>)
8 -b:  tensor(0.0288, grad_fn=<DivBackward0>)
9 -a:  tensor(0.5335, grad_fn=<DivBackward0>)
9 -b:  tensor(0.0249, grad_fn=<DivBackward0>)
10 -a:  tensor(0.6149, grad_fn=<DivBackwa

### Subsequent pruning

#### Importance

In [None]:
''' Correlated network '''
# with open("./w_decorr/pruned_nets/corr/tfo_order/tfo_corr_p"+str(prune_iter)+".pkl", 'rb') as f:
#     imp_order_p = pickle.load(f)

In [None]:
optimizer = optim.SGD(net_p.parameters(), lr=0, weight_decay=0)
imp_order_p = np.array([[],[],[]]).transpose()
i = 0

print(i)
nlist = cal_importance(net_p, net_p.module.bn1)
imp_order_p = np.concatenate((imp_order_p,np.array([np.repeat([i],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
i+=1

for l_index in range(13):
    print(i)
    nlist = cal_importance(net_p, net_p.module.layers[l_index].bn1)
    imp_order_p = np.concatenate((imp_order_p,np.array([np.repeat([i],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
    i+=1
    
    print(i)
    nlist = cal_importance(net_p, net_p.module.layers[l_index].bn2)
    imp_order_p = np.concatenate((imp_order_p,np.array([np.repeat([i],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
    i+=1
    
with open("./w_decorr/pruned_nets/corr/tfo_order/tfo_corr_p"+str(prune_iter)+".pkl", 'wb') as f:
    pickle.dump(imp_order_p, f)

In [32]:
''' De-Correlated network '''
with open("./w_decorr/pruned_nets/decorr/tfo_order/tfo_w_decorr_p"+str(prune_iter)+".pkl", 'rb') as f:
    imp_order_p_ortho = pickle.load(f)

In [31]:
optimizer = optim.SGD(net_p_ortho.parameters(), lr=0, weight_decay=0)
imp_order_p_ortho = np.array([[],[],[]]).transpose()
i = 0

print(i)
nlist = cal_importance(net_p_ortho, net_p_ortho.module.bn1)
imp_order_p_ortho = np.concatenate((imp_order_p_ortho,np.array([np.repeat([i],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
i+=1

for l_index in range(13):
    print(i)
    nlist = cal_importance(net_p_ortho, net_p_ortho.module.layers[l_index].bn1)
    imp_order_p_ortho = np.concatenate((imp_order_p_ortho,np.array([np.repeat([i],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
    i+=1
    
    print(i)
    nlist = cal_importance(net_p_ortho, net_p_ortho.module.layers[l_index].bn2)
    imp_order_p_ortho = np.concatenate((imp_order_p_ortho,np.array([np.repeat([i],nlist[1].shape[0]).tolist(), nlist[0].tolist(), nlist[1].detach().cpu().numpy().tolist()]).transpose()), 0)
    i+=1
    
with open("./w_decorr/pruned_nets/decorr/tfo_order/tfo_w_decorr_p"+str(prune_iter)+".pkl", 'wb') as f:
    pickle.dump(imp_order_p_ortho, f)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19


KeyboardInterrupt: 

#### Pruned network pruning

In [None]:
# ''' Correlated network '''
# orig_size = cal_size(net_p)

In [33]:
''' De-Correlated network '''
orig_size = cal_size(net_p_ortho)

#### Pruning order

In [None]:
# ''' Correlated network '''
# order_p, prune_ratio = order_and_ratios(imp_order_p, 0.1)
# np.array(prune_ratio), orig_size

In [34]:
''' De-Correlated network '''
order_p, prune_ratio = order_and_ratios(imp_order_p_ortho, 0.1)
np.array(prune_ratio), orig_size

(array([ 15,  13,   8,   7,  11,   5,   8,   4,  17,  10,  59,  24,  98,
         62, 143, 105, 103,  78,  48,  27,  26,  12,  28,  21,  88,  69,
          5]),
 array([  17,   17,   56,   56,  118,  118,  120,  120,  239,  239,  195,
         195,  411,  411,  368,  368,  409,  409,  464,  464,  486,  486,
         486,  486,  933,  933, 1019]))

#### Define pruned network

In [35]:
prune_iter = 2

In [None]:
# ''' Correlated network pruning '''
net_p1 = pruner(net_p, order_p, prune_ratio, orig_size, net_type=1)

print("Accs:", cal_acc(net_p1.eval()), cal_acc(net_p.eval()))
print("Time:", cal_time(net_p1), t_corr)

In [36]:
''' De-Correlated network pruning '''
net_p1_ortho = pruner(net_p_ortho, order_p, prune_ratio, orig_size, net_type=2)

print("Accs:", cal_acc(net_p1_ortho.eval()), cal_acc(net_p_ortho.eval()))
print("Time:", cal_time(net_p1_ortho), t_decorr)

RuntimeError: CUDA error: device-side assert triggered

#### Save pruned network

In [None]:
# ''' Correlated network saving '''
net_p = net_p1

print('Saving..')
state = {
    'net_p': net_p.state_dict(),
    'best_p_acc': cal_acc(net_p.eval())
}
if not os.path.isdir('net_p_checkpoint'):
    os.mkdir('net_p_checkpoint')
torch.save(state, './net_p_checkpoint/ckpt'+str(prune_iter)+'.pth')

In [None]:
''' De-Correlated network saving '''
net_p_ortho = net_p1_ortho

print('Saving..')
state = {
    'net_p_ortho': net_p_ortho.state_dict(),
    'best_p_ortho_acc': cal_acc(net_p_ortho.eval())
}
if not os.path.isdir('ortho_p_checkpoint'):
    os.mkdir('ortho_p_checkpoint')
torch.save(state, './ortho_p_checkpoint/ortho_ckpt'+str(prune_iter)+'.pth')

### Load saved network

In [None]:
# ''' Correlated network loading '''
# with open("./w_decorr/pruned_nets/corr/cfgs/net_p_corr_iter"+str(1)+".pkl", 'rb') as f:
#     cfg_p1 = pickle.load(f)
    
# net_p = torch.nn.DataParallel(MobileNet_p(cfg_p1[0], cfg_p1[1:]))
# PATH = './net_p_checkpoint/ckpt'+str(1)+'.pth'
# net_p.load_state_dict(torch.load(PATH)['net_p'])

In [None]:
# cal_acc(net_p.eval()), cal_acc(net_corr.eval())

In [None]:
''' De-Correlated network loading '''
with open("./w_decorr/pruned_nets/decorr/cfgs/net_p_decorr_iter"+str(1)+".pkl", 'rb') as f:
    cfg_p1 = pickle.load(f)

net_p_ortho = torch.nn.DataParallel(MobileNet_p(cfg_p1[0], cfg_p1[1:]))
PATH = './ortho_p_checkpoint/ortho_ckpt'+str(1)+'.pth'
net_p_ortho.load_state_dict(torch.load(PATH)['net_p_ortho'])    

In [None]:
cal_acc(net_p_ortho.eval()), cal_acc(net_decorr.eval())

### FLOPS calculator

In [None]:
with torch.cuda.device(0):
    flops, params = get_model_complexity_info(net_p_ortho, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops))
    
# with torch.cuda.device(0):
#     flops, params = get_model_complexity_info(net_p, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
#     print('{:<30}  {:<8}'.format('Computational complexity: ', flops))    

In [None]:
with torch.cuda.device(0):
    flops, params = get_model_complexity_info(net_decorr, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
    print('{:<30}  {:<8}'.format('Computational complexity: ', flops))

# with torch.cuda.device(0):
#     flops, params = get_model_complexity_info(net_corr, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
#     print('{:<30}  {:<8}'.format('Computational complexity: ', flops))