In [None]:
# -*- coding: utf-8 -*-
# Libraries we need
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from matplotlib.pyplot import figure

# Cross-entropy loss (default settings) shared by the analysis and training cells below.
criterion = nn.CrossEntropyLoss()

In [None]:
# Dataloader objects: CIFAR-100 train and test splits, served in fixed-order
# batches of 128 by two worker processes each.

# Convert images to tensors, then normalise every RGB channel with
# mean 0.5 and std 0.5.
# Alternative statistics kept for reference:
#   (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR100(
    root="./data", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=False, num_workers=2
)

testset = torchvision.datasets.CIFAR100(
    root="./data", train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=128, shuffle=False, num_workers=2
)

#classes = ('plane', 'car', 'bird', 'cat',
#          'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
'''VGG based networks that work well on CIFAR datasets.'''

# Layer recipes. An int is the output-channel count of one conv stage
# (3x3 Conv2d -> ReLU -> BatchNorm2d), 'M' is a 2x2 max-pool with stride 2,
# and 'D2' / 'D3' / 'D4' are dropout layers with p = 0.2 / 0.3 / 0.4.
cfg = {
    'VGG11': [32, 32, 'M', 64, 64, 'M', 128, 128, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    """VGG-style feature extractor followed by a single linear classifier.

    Args:
        vgg_name: Key into the module-level ``cfg`` dict selecting the layer recipe.
        num_classes: Number of classifier outputs (default 100).
        input_size: Height/width of the square input images (default 32).

    The classifier's input width is derived from the recipe instead of being
    fixed at 512, so recipes that do not end in a 512x1x1 feature map (such as
    the 'VGG11' entry above) can run a forward pass. For the 'VGG13', 'VGG16'
    and 'VGG19' recipes at 32x32 input the derived width is still 512, so
    parameter names and shapes are unchanged for those models.
    """

    def __init__(self, vgg_name, num_classes=100, input_size=32):
        super(VGG, self).__init__()
        layer_cfg = cfg[vgg_name]
        self.features = self._make_layers(layer_cfg)
        self.classifier = nn.Linear(self._flat_features(layer_cfg, input_size), num_classes)

    def forward(self, x):
        """Return class scores of shape (batch, num_classes) for an image batch ``x``."""
        out = self.features(x)
        out = out.view(out.size(0), -1)  # flatten everything but the batch dimension
        out = self.classifier(out)
        return out

    @staticmethod
    def _flat_features(layer_cfg, input_size):
        """Size of the flattened feature map produced by ``layer_cfg`` on a square input."""
        channels = [entry for entry in layer_cfg if isinstance(entry, int)]
        if not channels:
            raise ValueError("layer recipe contains no convolution stage")
        # The padded 3x3 convs keep the spatial size; every 'M' halves it (floor).
        spatial = input_size // (2 ** layer_cfg.count('M'))
        if spatial < 1:
            raise ValueError("input_size=%d is too small for this layer recipe" % input_size)
        return channels[-1] * spatial * spatial

    def _make_layers(self, cfg):
        """Translate a layer recipe into an ``nn.Sequential``.

        The Conv2d -> ReLU -> BatchNorm2d order within a stage is kept as is:
        other cells of this notebook address layers by their position in
        ``features``.
        """
        dropout_p = {'D2': 0.2, 'D3': 0.3, 'D4': 0.4}
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            elif x in dropout_p:
                layers.append(nn.Dropout(p=dropout_p[x]))
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.ReLU(inplace=True),
                           nn.BatchNorm2d(x)]
                in_channels = x
        return nn.Sequential(*layers)


def test():
    """Smoke check: push a random batch of two 3x32x32 inputs through VGG11 and print the output size."""
    model = VGG('VGG11')
    sample_batch = torch.randn(2, 3, 32, 32)
    logits = model(sample_batch)
    print(logits.size())

In [None]:
# Setting up the GPU: use the first CUDA device when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
def cal_acc(net, loader=None):
    """Print and return the top-1 accuracy of ``net`` in percent.

    The model is switched to eval mode for the measurement, so layers such as
    BatchNorm and Dropout use their inference behaviour and BatchNorm running
    statistics are not updated with evaluation data. The previous
    train/eval mode is restored afterwards.

    Args:
        net: Model mapping an image batch to per-class scores.
        loader: Iterable of ``(images, labels)`` batches. Defaults to the
            notebook-level ``testloader``.

    Returns:
        Accuracy as a float percentage (the printed value is truncated to an int).
    """
    if loader is None:
        loader = testloader
    # Evaluate on the device the model lives on; parameter-free models fall
    # back to the notebook-level ``device``.
    first_param = next(net.parameters(), None)
    target = first_param.device if first_param is not None else device

    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for data in loader:
                images, labels = data[0].to(target), data[1].to(target)
                outputs = net(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        net.train(was_training)

    accuracy = 100 * correct / total
    print('Accuracy of the network on the %d test images: %d %%' % (
        total, accuracy))
    return accuracy

In [None]:
def cal_acc_train(net, loader=None):
    """Print and return the top-1 accuracy of ``net`` on the training data, in percent.

    The model is switched to eval mode for the measurement, so layers such as
    BatchNorm and Dropout use their inference behaviour and BatchNorm running
    statistics are not updated by the measurement. The previous train/eval
    mode is restored afterwards.

    Args:
        net: Model mapping an image batch to per-class scores.
        loader: Iterable of ``(images, labels)`` batches. Defaults to the
            notebook-level ``trainloader``.

    Returns:
        Accuracy as a float percentage (the printed value is truncated to an int).
    """
    if loader is None:
        loader = trainloader
    # Evaluate on the device the model lives on; parameter-free models fall
    # back to the notebook-level ``device``.
    first_param = next(net.parameters(), None)
    target = first_param.device if first_param is not None else device

    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for data in loader:
                images, labels = data[0].to(target), data[1].to(target)
                outputs = net(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        net.train(was_training)

    accuracy = 100 * correct / total
    print('Accuracy of the network on the %d train images: %d %%' % (
        total, accuracy))
    return accuracy

## Define layer for analysis

In [None]:
# l_index: position inside `net.features` of the layer analysed below.
# layer_id: label string for that layer (not read by any cell visible in this notebook).
l_index, layer_id = 2, 'bn'

# Baseline network analysis

### Define the network

In [None]:
# Baseline model: the VGG13 recipe, moved onto the selected device.
net = VGG('VGG13')
net = net.to(device)

In [None]:
# Load the baseline weights. map_location remaps the stored tensors onto the
# active device, so a checkpoint written from a GPU also loads on a CPU-only machine.
PATH = './corr/cifar100_net.pth'
net.load_state_dict(torch.load(PATH, map_location=device))

In [None]:
# Detached copies of the analysed layer's weight and bias; the ablation loop
# below restores the layer from these after each channel has been zeroed.
weight_base = net.features[l_index].weight.detach().clone()
bias_base = net.features[l_index].bias.detach().clone()

### Ground truth calculation for importance in baseline network

In [None]:
# Reference loss of the unmodified baseline network: the per-batch mean
# cross-entropy, summed over one pass of the loader.
loss_base_corr = 0
# Only loss values are needed here, so no autograd graph is built.
with torch.no_grad():
    for data in trainloader:  # swap in `testloader` to measure on the test split
        inputs, labels = data[0].to(device), data[1].to(device)
        loss_base_corr += criterion(net(inputs), labels).item()

In [None]:
# Ground-truth importance for the baseline network: for every channel of the
# analysed layer, zero that channel's weight and bias, sum the per-batch mean
# cross-entropy over one pass of the loader, then restore the layer.
ablated_layer = net.features[l_index]
loss_mat = torch.zeros(weight_base.shape[0])

# Pure evaluation: no autograd graph is needed, and the in-place parameter
# edits below are permitted inside no_grad.
with torch.no_grad():
    for n_index in range(weight_base.shape[0]):
        print(n_index)
        running_loss = 0.0

        ablated_layer.weight[n_index] = 0
        ablated_layer.bias[n_index] = 0
        try:
            for data in trainloader:  # swap in `testloader` to measure on the test split
                inputs, labels = data[0].to(device), data[1].to(device)
                running_loss += criterion(net(inputs), labels).item()
        finally:
            # Restore the layer even if this long loop is interrupted, so the
            # network is never left with a zeroed channel.
            ablated_layer.weight.copy_(weight_base)
            ablated_layer.bias.copy_(bias_base)
        loss_mat[n_index] = running_loss

torch.save(loss_mat, './corr/loss_bn_'+str(l_index)+'.pt')

In [None]:
# loss_mat = torch.load('./corr/loss_bn_'+str(l_index)+'.pt')

### Taylor FO calculation for importance in baseline network

In [None]:
# First-order Taylor importance for the baseline network. Per batch, the
# per-channel term grad(weight) * weight + grad(bias) * bias of the analysed
# layer is squared and accumulated; the result is then correlated with the
# ground-truth loss change |loss_mat - loss_base_corr|.
n_epochs = 1
scored_layer = net.features[l_index]

for epoch in range(n_epochs):
    imp_corr_bn = torch.zeros(bias_base.shape[0]).to(device)

    for data in trainloader:  # swap in `testloader` to measure on the test split
        inputs, labels = data[0].to(device), data[1].to(device)
        # No parameters are updated in this cell, so gradients are cleared on
        # the module directly instead of through a dummy optimizer.
        net.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()

        taylor_term = (scored_layer.weight.grad * scored_layer.weight.detach()
                       + scored_layer.bias.grad * scored_layer.bias.detach())
        imp_corr_bn += taylor_term.pow(2)

    corrval = (np.corrcoef(imp_corr_bn.cpu().detach().numpy(), (loss_mat - loss_base_corr).abs().cpu().detach().numpy()))
    print("Correlation at epoch "+str(epoch)+": "+str(corrval[0,1]))

# Decorrelated network analysis

### Define the network

In [None]:
# Build a second VGG13 and load the decorrelated weights into it.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written from a GPU also loads on a CPU-only machine.
PATH = './decorr/cifar100_decorrnet.pth'
net_decorr = VGG('VGG13').to(device)
net_decorr.load_state_dict(torch.load(PATH, map_location=device))

In [None]:
# Rebind weight_base / bias_base (previously taken from `net`) to detached
# copies of the analysed layer of `net_decorr`; the ablation loop below
# restores the layer from these.
weight_base = net_decorr.features[l_index].weight.detach().clone()
bias_base = net_decorr.features[l_index].bias.detach().clone()

### Ground truth calculation for importance in decorrelated network

In [None]:
# Reference loss of the unmodified decorrelated network: the per-batch mean
# cross-entropy, summed over one pass of the loader.
loss_base_decorr = 0
# Only loss values are needed here, so no autograd graph (and no optimizer) is required.
with torch.no_grad():
    for data in trainloader:  # swap in `testloader` to measure on the test split
        inputs, labels = data[0].to(device), data[1].to(device)
        loss_base_decorr += criterion(net_decorr(inputs), labels).item()

In [None]:
# Ground-truth importance for the decorrelated network: for every channel of
# the analysed layer, zero that channel's weight and bias, sum the per-batch
# mean cross-entropy over one pass of the loader, then restore the layer.
ablated_layer_decorr = net_decorr.features[l_index]
loss_mat_decorr = torch.zeros(weight_base.shape[0])

# Pure evaluation: no autograd graph is needed, and the in-place parameter
# edits below are permitted inside no_grad.
with torch.no_grad():
    for n_index in range(weight_base.shape[0]):
        print(n_index)
        running_loss = 0.0

        # Nothing in the batch loop modifies the parameters, so the channel
        # is zeroed once per channel rather than once per batch.
        ablated_layer_decorr.weight[n_index] = 0
        ablated_layer_decorr.bias[n_index] = 0
        try:
            for data in trainloader:  # swap in `testloader` to measure on the test split
                inputs, labels = data[0].to(device), data[1].to(device)
                running_loss += criterion(net_decorr(inputs), labels).item()
        finally:
            # Restore the layer even if this long loop is interrupted, so the
            # network is never left with a zeroed channel.
            ablated_layer_decorr.weight.copy_(weight_base)
            ablated_layer_decorr.bias.copy_(bias_base)
        loss_mat_decorr[n_index] = running_loss

torch.save(loss_mat_decorr, './decorr/loss_bn_'+str(l_index)+'.pt')

In [None]:
# loss_mat_decorr = torch.load('./decorr/loss_bn_'+str(l_index)+'.pt')

### Taylor FO calculation for importance in decorrelated network

In [None]:
# First-order Taylor importance for the decorrelated network. Per batch, the
# per-channel term grad(weight) * weight + grad(bias) * bias of the analysed
# layer is squared and accumulated; the result is then correlated with the
# ground-truth loss change |loss_mat_decorr - loss_base_decorr|.
n_epochs = 1
scored_layer_decorr = net_decorr.features[l_index]

for epoch in range(n_epochs):
    imp_decorr_bn = torch.zeros(bias_base.shape[0]).to(device)

    for data in trainloader:  # swap in `testloader` to measure on the test split
        inputs, labels = data[0].to(device), data[1].to(device)
        # No parameters are updated in this cell, so gradients are cleared on
        # the module directly instead of through a dummy optimizer.
        net_decorr.zero_grad()
        loss = criterion(net_decorr(inputs), labels)
        loss.backward()

        taylor_term = (scored_layer_decorr.weight.grad * scored_layer_decorr.weight.detach()
                       + scored_layer_decorr.bias.grad * scored_layer_decorr.bias.detach())
        imp_decorr_bn += taylor_term.pow(2)

    corrval = (np.corrcoef(imp_decorr_bn.cpu().detach().numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().detach().numpy()))
    print("Correlation at epoch "+str(epoch)+": "+str(corrval[0,1]))

# Net-Slim

### train dataset

In [None]:
# Net-Slim score for the baseline network: magnitude of the analysed layer's
# per-channel weight, correlated with the ground-truth loss change.
# The absolute value matches the decorrelated-network cell; without it,
# negative weights would be scored by sign rather than by magnitude.
scale_corr = net.features[l_index].weight.detach().clone().abs()
np.corrcoef(scale_corr.cpu().numpy(), (loss_mat - loss_base_corr).abs().cpu().numpy())

In [None]:
# Net-Slim score for the decorrelated network: magnitude of the analysed
# layer's per-channel weight, correlated with the ground-truth loss change.
scale_decorr = net_decorr.features[l_index].weight.detach().abs()
np.corrcoef(
    scale_decorr.cpu().numpy(),
    (loss_mat_decorr - loss_base_decorr).abs().cpu().numpy(),
)

### test dataset

In [None]:
# Net-Slim score for the baseline network: magnitude of the analysed layer's
# per-channel weight, correlated with the ground-truth loss change.
# The absolute value matches the decorrelated-network cell; without it,
# negative weights would be scored by sign rather than by magnitude.
# NOTE(review): this cell reads the same loss_mat / loss_base_corr as the
# train-dataset cell; presumably the ground-truth cells are re-run with
# `testloader` first - confirm.
scale_corr = net.features[l_index].weight.detach().clone().abs()
np.corrcoef(scale_corr.cpu().numpy(), (loss_mat - loss_base_corr).abs().cpu().numpy())

In [None]:
# Net-Slim score for the decorrelated network: magnitude of the analysed
# layer's per-channel weight, correlated with the ground-truth loss change.
# NOTE(review): reads the same loss_mat_decorr / loss_base_decorr as the
# train-dataset cell; presumably those are recomputed with `testloader`
# first - confirm.
scale_decorr = net_decorr.features[l_index].weight.detach().abs()
np.corrcoef(
    scale_decorr.cpu().numpy(),
    (loss_mat_decorr - loss_base_decorr).abs().cpu().numpy(),
)

# L2 based pruning 

### train dataset

In [None]:
# L2 score for the baseline network: squared L2 norm of each output filter of
# the layer two positions before the analysed one, correlated with the
# ground-truth loss change.
w_corr = net.features[l_index - 2].weight.detach().clone()
w_imp_corr = w_corr.pow(2).sum(dim=(1, 2, 3)).cpu()
np.corrcoef(
    w_imp_corr.numpy(),
    (loss_mat - loss_base_corr).abs().cpu().numpy(),
)

In [None]:
# L2 score for the decorrelated network: squared L2 norm of each output
# filter of the layer two positions before the analysed one, correlated with
# the ground-truth loss change.
w_decorr = net_decorr.features[l_index - 2].weight.detach().clone()
w_imp_decorr = w_decorr.pow(2).sum(dim=(1, 2, 3)).cpu()
np.corrcoef(
    w_imp_decorr.numpy(),
    (loss_mat_decorr - loss_base_decorr).abs().cpu().numpy(),
)

### test dataset

In [None]:
# L2 score for the baseline network: squared L2 norm of each output filter of
# the layer two positions before the analysed one, correlated with the
# ground-truth loss change.
# NOTE(review): reads the same loss_mat / loss_base_corr as the train-dataset
# cell; presumably those are recomputed with `testloader` first - confirm.
w_corr = net.features[l_index - 2].weight.detach().clone()
w_imp_corr = w_corr.pow(2).sum(dim=(1, 2, 3)).cpu()
np.corrcoef(
    w_imp_corr.numpy(),
    (loss_mat - loss_base_corr).abs().cpu().numpy(),
)

In [None]:
# L2 score for the decorrelated network: squared L2 norm of each output
# filter of the layer two positions before the analysed one, correlated with
# the ground-truth loss change.
# NOTE(review): reads the same loss_mat_decorr / loss_base_decorr as the
# train-dataset cell; presumably those are recomputed with `testloader`
# first - confirm.
w_decorr = net_decorr.features[l_index - 2].weight.detach().clone()
w_imp_decorr = w_decorr.pow(2).sum(dim=(1, 2, 3)).cpu()
np.corrcoef(
    w_imp_decorr.numpy(),
    (loss_mat_decorr - loss_base_decorr).abs().cpu().numpy(),
)

# Training

### Weight decorrelated training

In [None]:
# Reset `net` to the stored baseline weights before fine-tuning.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written from a GPU also loads on a CPU-only machine.
# NOTE(review): this path differs from the './corr/cifar100_net.pth' loaded
# in the analysis section - confirm both files are the intended checkpoints.
PATH = './cifar100_net.pth'
net.load_state_dict(torch.load(PATH, map_location=device))

In [None]:
import torch.optim as optim

# Loss and optimizer for fine-tuning `net`: cross-entropy and Adam with
# learning rate 1e-4 and no weight decay.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    net.parameters(),
    lr=0.0001,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
    amsgrad=False,
)

In [None]:
# Weight-decorrelated training. For every convolution in `net.features`, the
# filters are flattened to rows of a matrix W; the penalty is
# ||W W^T||_F^2 minus the squared norm of the vector of per-filter squared
# norms (the diagonal of W W^T), i.e. the off-diagonal mass of the filter
# Gram matrix. It is added to the cross-entropy loss with weight 1e-3.

# Positions of the convolutions, derived from the model instead of being
# hard-coded for one architecture.
conv_indices = [idx for idx, layer in enumerate(net.features)
                if isinstance(layer, nn.Conv2d)]

for epoch in range(1):
    net.train()  # make the mode explicit, whatever an earlier cell left behind
    running_loss = 0.0
    num_iter = 0
    angle_cost = 0.0
    for i, data in enumerate(trainloader, 0):
        num_iter += 1
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        L_angle = 0
        L_diag = 0
        for conv_ind in conv_indices:
            w_mat = net.features[conv_ind].weight
            w_mat1 = w_mat.reshape(w_mat.shape[0], -1)
            angle_mat = torch.matmul(w_mat1, torch.t(w_mat1))
            L_diag += w_mat1.pow(2).sum(dim=1).norm().pow(2)
            L_angle += angle_mat.norm().pow(2)

        Lc = criterion(outputs, labels)
        loss = (1e-3)*(L_angle - L_diag) + Lc

        loss.backward()
        optimizer.step()

        # running statistics (plain floats, detached from the graph)
        running_loss += loss.item()
        angle_cost += (L_angle - L_diag).item()

    print("angle_cost: ", angle_cost/num_iter)
    # Ratio computed from the last batch of the epoch.
    print("diag_mass_ratio: ", L_diag.detach().cpu().numpy()/(L_angle.detach().cpu().numpy()))
    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, i + 1, running_loss / num_iter))
    cal_acc(net)

print('Finished Training')

### Inner product training

In [None]:
# Inner-product training. At each tap depth k, the activations after
# `net.features[0:k]` are arranged as (batch * height * width, channels); the
# channels x channels matrix of inner products is formed and half of its
# off-diagonal squared Frobenius mass is added to the penalty, which enters
# the loss with weight 1e-11.
cov_taps = (3, 6, 10, 13, 17, 20, 24, 27, 31, 34)

for epoch in range(2):
    net.train()  # make the mode explicit, whatever an earlier cell left behind
    running_loss = 0.0
    cov_loss = 0.0
    num_iter = 0
    av_cov_mass = 0
    for i, data in enumerate(trainloader, 0):
        num_iter += 1
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # One pass through the feature stack, reading the activations at each
        # tap depth on the way, instead of re-running the network prefix once
        # per tap. The loop variable is named `depth` so the notebook-level
        # `l_index` is not overwritten.
        L_cov = 0
        out_features = inputs
        for depth, layer in enumerate(net.features, start=1):
            out_features = layer(out_features)
            if depth in cov_taps:
                act_mat = out_features.permute(0, 2, 3, 1).reshape(-1, out_features.shape[1])
                cov_mat = torch.matmul(torch.t(act_mat), act_mat)
                L_cov += (cov_mat.norm().pow(2) - cov_mat.diag().norm().pow(2))/2
        # Same flatten + classifier steps as VGG.forward.
        outputs = net.classifier(out_features.view(out_features.size(0), -1))

        Lc = criterion(outputs, labels)
        loss = Lc + (1e-11)*L_cov

        loss.backward()
        optimizer.step()

        # running statistics; .item() keeps the accumulators free of autograd history
        running_loss += loss.item()
        cov_loss += L_cov.item()
        # diagonal-to-total norm ratio of the last tap's matrix
        av_cov_mass += ((cov_mat).diag().norm().cpu().detach().numpy() / cov_mat.norm().cpu().detach().numpy())

        del L_cov, act_mat, cov_mat, out_features
        torch.cuda.empty_cache()

    print("Covariance loss: " + str(cov_loss/num_iter))
    print(av_cov_mass/num_iter)
    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, i + 1, running_loss / num_iter))
    cal_acc(net)
print('Finished Training')

### Outer product training

In [None]:
# Start the outer-product experiment from a fresh VGG13 holding the stored
# baseline weights. map_location remaps the stored tensors onto the active
# device, so a checkpoint written from a GPU also loads on a CPU-only machine.
PATH = './cifar100_net.pth'
net = VGG('VGG13').to(device)
net.load_state_dict(torch.load(PATH, map_location=device))

In [None]:
import torch.optim as optim

# Loss and optimizer for the freshly loaded `net`: cross-entropy and Adam with
# learning rate 1e-4 and no weight decay.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    net.parameters(),
    lr=0.0001,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
    amsgrad=False,
)

In [None]:
def m_ratio(net):
    """Print the diagonal-to-total mass ratio reported by ``cal_mass`` for each tap depth of ``net.features``."""
    tap_depths = (3, 6, 10, 13, 17, 20, 24, 27, 31, 34)
    for depth in tap_depths:
        total_mass, diag_mass = cal_mass(net, depth)
        ratio_text = str(diag_mass / total_mass)
        print("Diagonal mass ratio for layer {}: {}".format(depth, ratio_text))

In [None]:
def cal_mass(net, l_index, loader=None, max_batches=39):
    """Accumulate the Gram-matrix mass of the activations after ``net.features[0:l_index]``.

    For every sample the activations are reshaped to (channels, positions)
    and the channels x channels matrix of inner products is formed. Two sums
    are accumulated over all processed samples: the squared Frobenius norm of
    that matrix (total mass) and the squared norm of its diagonal (diagonal mass).

    The model's train/eval mode is left untouched.

    Args:
        net: Model exposing an indexable ``features`` container.
        l_index: Number of leading ``features`` layers to run.
        loader: Iterable of batches whose first element is the input tensor.
            Defaults to the notebook-level ``trainloader``.
        max_batches: Number of batches to process (default 39).

    Returns:
        ``(total_mass, diagonal_mass)``.
    """
    if loader is None:
        loader = trainloader
    # Run on the device the model lives on; parameter-free models fall back
    # to the notebook-level ``device``.
    first_param = next(net.parameters(), None)
    target = first_param.device if first_param is not None else device

    other_loss = 0.0  # total mass
    self_loss = 0.0   # diagonal mass
    feature_stack = net.features[0:l_index]

    with torch.no_grad():
        for batch_idx, data in enumerate(loader):
            if batch_idx >= max_batches:
                break
            inputs = data[0].to(target)

            # Only the feature prefix is evaluated; the full forward pass and
            # the labels are not needed for this statistic.
            out_features = feature_stack(inputs)
            act_mat = out_features.reshape(out_features.shape[0], out_features.shape[1], -1)
            mul_mat = torch.matmul(act_mat, act_mat.permute(0, 2, 1))

            self_loss += mul_mat.diagonal(dim1=1, dim2=2).norm(dim=1).pow(2).sum()
            other_loss += mul_mat.norm(dim=(1, 2)).pow(2).sum()

            del act_mat, mul_mat, out_features
            torch.cuda.empty_cache()

    return other_loss, self_loss

In [None]:
# Outer-product training. At each tap depth k, the activations after
# `net.features[0:k]` are reshaped per sample to (channels, positions); the
# per-sample channels x channels matrix of inner products is formed and its
# off-diagonal squared Frobenius mass, summed over the batch, is added to the
# penalty, which enters the loss with weight `reg`.
reg = 1e-3
act_taps = (3, 6, 10, 13, 17, 20, 24, 27, 31, 34)

for epoch in range(1):
    net.train()  # make the mode explicit, whatever an earlier cell left behind
    running_loss = 0.0
    mat_loss = 0.0
    num_iter = 0

    for i, data in enumerate(trainloader, 0):
        num_iter += 1
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # One pass through the feature stack, reading the activations at each
        # tap depth on the way, instead of re-running the network prefix once
        # per tap. The loop variable is named `depth` so the notebook-level
        # `l_index` is not overwritten.
        L_mat = 0.0
        out_features = inputs
        for depth, layer in enumerate(net.features, start=1):
            out_features = layer(out_features)
            if depth in act_taps:
                act_mat = out_features.reshape(out_features.shape[0], out_features.shape[1], -1)
                mul_mat = torch.matmul(act_mat, act_mat.permute(0, 2, 1))
                L_mat += (mul_mat.norm(dim=(1, 2)).pow(2) - mul_mat.diagonal(dim1=1, dim2=2).norm(dim=1).pow(2)).sum()
        # Same flatten + classifier steps as VGG.forward.
        outputs = net.classifier(out_features.view(out_features.size(0), -1))

        Lc = criterion(outputs, labels)
        loss = Lc + reg*(L_mat)

        loss.backward()
        optimizer.step()

        # running statistics; .item() keeps the accumulators free of autograd history
        running_loss += loss.item()
        mat_loss += L_mat.item()

        del L_mat, act_mat, mul_mat, out_features
        torch.cuda.empty_cache()

    print("Regularization loss: " + str(mat_loss/num_iter))

    m_ratio(net)

    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, i + 1, running_loss / num_iter))
    cal_acc(net)

    print("reg: ", reg)
print('Finished Training')

In [None]:
# Save the trained weights. The target directory is created first if it is
# missing, so the save cannot fail on a fresh checkout after training has
# already been run.
PATH = './decorr/cifar100_decorrnet_8.pth'
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(net.state_dict(), PATH)