In [None]:
# -*- coding: utf-8 -*-
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

In [None]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import numpy as np

In [None]:
# Convert images to tensors and normalise each RGB channel.
# NOTE(review): these mean/std constants look like the commonly quoted CIFAR-10
# statistics rather than CIFAR-100 ones; they are left unchanged because the saved
# checkpoints were presumably trained with them — confirm.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

# download=True on both splits so the notebook also runs when ./data does not exist
# yet (the train split previously used download=False while the test split fetched
# the archive).
trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                        download=True, transform=transform)
# shuffle=False is kept on purpose: later cells stop iterating this loader after
# roughly 5000 samples, so a fixed order makes every such pass see the same samples.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

#classes = ('plane', 'car', 'bird', 'cat',
#          'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn


# Layer layouts: an int is the output width of a 3x3 convolution, 'M' is a 2x2 max-pool.
# NOTE(review): the 'VGG11' entry is a 6-convolution network, not the usual VGG11
# layout — left as is in case it is intentional.
cfg = {
    'VGG11': [32, 32, 'M', 64, 64, 'M', 128, 128, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    """VGG-style network: a convolutional feature stack followed by one linear classifier.

    Args:
        vgg_name: key into ``cfg`` selecting the layer layout.
        num_classes: number of classifier outputs (default 100).
        input_size: side length of the square input images (default 32). Together
            with the layout it determines the classifier's input width.

    The classifier width is derived from the layout instead of being fixed at 512,
    so layouts that do not end in a 512x1x1 feature map (such as ``'VGG11'`` here,
    which ends in 128x4x4) also work. For VGG13/16/19 on 32x32 input the width is
    still 512, so existing checkpoints keep loading.
    """

    def __init__(self, vgg_name, num_classes=100, input_size=32):
        super(VGG, self).__init__()
        layer_cfg = cfg[vgg_name]
        self.features = self._make_layers(layer_cfg)
        self.classifier = nn.Linear(self._flat_features(layer_cfg, input_size), num_classes)

    @staticmethod
    def _flat_features(layer_cfg, input_size):
        """Return the flattened feature count produced by ``layer_cfg`` on a square input."""
        widths = [x for x in layer_cfg if isinstance(x, int)]
        channels = widths[-1] if widths else 3
        # Each 2x2/stride-2 max-pool floors the side length to half; the padded 3x3
        # convolutions keep it unchanged.
        n_pool = sum(1 for x in layer_cfg if x == 'M')
        side = input_size >> n_pool
        return channels * side * side

    def forward(self, x):
        """Run the feature stack, flatten per sample, and apply the classifier."""
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        """Build the feature stack as an ``nn.Sequential``.

        Every integer width contributes Conv2d -> ReLU -> BatchNorm2d (in that order,
        which fixes the indices other cells use to address layers); 'M' adds a
        max-pool and 'D2'/'D3'/'D4' add dropout with p = 0.2/0.3/0.4.
        """
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            elif x == 'D2':
                layers += [nn.Dropout(p=0.2)]
            elif x == 'D3':
                layers += [nn.Dropout(p=0.3)]
            elif x == 'D4':
                layers += [nn.Dropout(p=0.4)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.ReLU(inplace=True),
                           nn.BatchNorm2d(x)]
                in_channels = x
        #layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)


def test():
    """Smoke test: push a random two-image batch through VGG11 and print the output size."""
    model = VGG('VGG11')
    batch = torch.randn(2, 3, 32, 32)
    logits = model(batch)
    print(logits.size())

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Build the VGG13 layout and move its parameters to the selected device.
net = VGG('VGG13').to(device)
# Cross-entropy on the classifier logits; used by the evaluation and training cells.
criterion = nn.CrossEntropyLoss()

In [None]:
def cal_acc(net, loader=None):
    """Print and return the top-1 accuracy (in percent) of ``net``.

    Args:
        net: model to evaluate. Its train/eval mode is not changed here, so callers
            should pass ``net.eval()`` as the cells in this notebook do.
        loader: iterable of ``(images, labels)`` batches. Defaults to the
            notebook-level ``testloader``.

    Returns:
        The accuracy as a float in [0, 100]; 0.0 when the loader yields no samples.
    """
    if loader is None:
        loader = testloader
    # Run on whatever device the model lives on, so the function does not depend on
    # a notebook global being in sync with the model.
    first_param = next(net.parameters(), None)
    target = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    with torch.no_grad():
        for data in loader:
            images, labels = data[0].to(target), data[1].to(target)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty loader instead of dividing by zero.
    accuracy = 100 * correct / total if total else 0.0
    print('Accuracy of the network on the %d test images: %.4f %%' % (
        total, accuracy))
    return accuracy

In [None]:
def cal_acc_train(net, loader=None):
    """Print and return the top-1 accuracy (in percent) of ``net`` on the training data.

    Args:
        net: model to evaluate. Its train/eval mode is not changed here, so callers
            should pass ``net.eval()``.
        loader: iterable of ``(images, labels)`` batches. Defaults to the
            notebook-level ``trainloader``.

    Returns:
        The accuracy as a float in [0, 100]; 0.0 when the loader yields no samples.
    """
    if loader is None:
        loader = trainloader
    # Run on whatever device the model lives on.
    first_param = next(net.parameters(), None)
    target = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    with torch.no_grad():
        for data in loader:
            images, labels = data[0].to(target), data[1].to(target)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total if total else 0.0
    # %.4f (not %d) so the fractional part is not truncated, matching cal_acc.
    print('Accuracy of the network on the %d train images: %.4f %%' % (
        total, accuracy))
    return accuracy

In [None]:
def cal_mass(net, l_index, loader=None, max_batches=39):
    """Average share of channel-correlation mass that lies on the diagonal.

    For each batch the activations of ``net.features[0:l_index]`` are flattened per
    channel and L2-normalised, the per-sample channel-by-channel Gram matrix is
    formed, and ``1 - ||off-diagonal||_F^2 / ||full||_F^2`` is computed over the
    whole batch. The returned value is the mean of that ratio over the batches used.

    Args:
        net: model exposing a sliceable ``features`` container.
        l_index: exclusive end index of the feature slice to evaluate.
        loader: iterable of batches whose first element is the input tensor.
            Defaults to the notebook-level ``trainloader``.
        max_batches: number of batches to average over. The default of 39 is the
            number of batches the previous implementation accumulated before it
            stopped; it then divided by 40, which biased the mean low.

    Returns:
        A zero-dimensional tensor holding the mean ratio.

    Raises:
        ValueError: if no batch was processed.
    """
    if loader is None:
        loader = trainloader
    first_param = next(net.parameters(), None)
    target = first_param.device if first_param is not None else torch.device("cpu")

    ratio_sum = 0.0
    n_batches = 0
    off_diag_mask = None
    with torch.no_grad():
        for data in loader:
            if n_batches >= max_batches:
                break
            inputs = data[0].to(target)

            out_features = net.features[0:l_index](inputs)
            # (batch, channels, positions), each channel row scaled to unit length.
            X_t = out_features.reshape(out_features.shape[0], out_features.shape[1], -1)
            X_t = torch.div(X_t, X_t.norm(dim=2).reshape(X_t.shape[0], X_t.shape[1], 1) + 1e-10)
            cov_mat = torch.matmul(X_t, X_t.permute(0, 2, 1))
            total_mass = cov_mat.norm().pow(2)

            # The mask only depends on the channel count, so build it once.
            if off_diag_mask is None:
                off_diag_mask = (1 - torch.eye(out_features.shape[1])).to(cov_mat.device)
            off_diag_mass = (cov_mat * off_diag_mask).norm().pow(2)

            ratio_sum += 1 - off_diag_mass / total_mass
            n_batches += 1

            del out_features, cov_mat
            if target.type == "cuda":
                torch.cuda.empty_cache()

    if n_batches == 0:
        raise ValueError("cal_mass: the loader produced no batches")
    # Divide by the number of batches that actually contributed to the sum.
    return ratio_sum / n_batches

### Layer index

In [None]:
# Index into ``net.features`` of the layer under study. With the VGG13 layout built by
# ``_make_layers`` (Conv2d, ReLU, BatchNorm2d per width and one MaxPool2d per 'M'),
# index 33 is the last BatchNorm2d and index 31 (l_index - 2) is the convolution before it.
l_index = 33
# Label for the layer type; not read by any other cell visible in this notebook.
layer_id = 'bn'

### Correlated Net

In [None]:
PATH = './cifar100_net.pth'
# map_location remaps the stored tensors onto the active device, so a checkpoint that
# holds CUDA tensors can still be loaded on a CPU-only machine.
net.load_state_dict(torch.load(PATH, map_location=device))
# Inference mode for the importance measurements in the following cells.
net = net.eval()

In [None]:
# Detached copies of the studied layer's parameters, used as reference values by later cells.
weight_base = net.features[l_index].weight.detach().clone()
bias_base = net.features[l_index].bias.detach().clone()

In [None]:
# Baseline loss of the unmodified network: sum of the per-batch mean losses over the
# first batches of the training loader (iteration stops once more than 5000 samples
# have been seen), then squared.
loss_base_corr = 0
num_stop = 0
# Pure evaluation: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    # Swap in `testloader` here to measure on the test split instead.
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss_base_corr += loss.item()
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
loss_base_corr = loss_base_corr**2

In [None]:
# loss_mat_corr = torch.zeros(weight_base.shape[0])

# for n_index in range(weight_base.shape[0]): 
#     num_stop = 0
#     print(n_index)
#     running_loss = 0.0

#     net.features[l_index].weight.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
#     net.features[l_index].bias.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
    
# #     for i, data in enumerate(testloader, 0):
#     for i, data in enumerate(trainloader, 0):
#         inputs, labels = data[0].to(device), data[1].to(device)

#         outputs = net(inputs)

#         loss = (criterion(outputs, labels))

#         running_loss += loss.item()
        
#         num_stop += labels.shape[0]
#         if(num_stop > 5000):
#             break
            
#     loss_mat_corr[n_index] = running_loss**2
    
#     net.features[l_index].weight.data = weight_base.clone().detach()
#     net.features[l_index].bias.data = bias_base.clone().detach()

# # torch.save(loss_mat_corr, './w_decorr/loss_corr_bn_train_'+str(l_index)+'.pt')

In [None]:
# torch.save(loss_mat_corr, './w_decorr/loss_mats/corr/'+str(l_index)+'/loss_corr_bn_train_'+str(l_index)+'.pt')
# Load the cached per-channel loss values for layer `l_index`; presumably produced by
# the commented-out ablation cell above and saved with the line above — confirm.
loss_mat_corr = torch.load('./w_decorr/loss_mats/corr/'+str(l_index)+'/loss_corr_bn_train_'+str(l_index)+'.pt')

In [None]:
# First-order Taylor importance for the correlated network: accumulate
# |grad * value| ** 2 of the studied layer's weight and bias per channel, then
# correlate it with the measured loss change.
# The optimizer has lr=0; it is only used to clear gradients between batches.
optimizer = optim.SGD(net.parameters(), lr=0, weight_decay=0)
av_corrval = 0
n_epochs = 1

for epoch in range(n_epochs):
    running_loss = 0.0
    num_stop = 0
    n_channels = bias_base.shape[0]
    imp_corr_conv = torch.zeros(n_channels).to(device)
    imp_corr_bn = torch.zeros(n_channels).to(device)

    # Swap in `testloader` here to measure on the test split instead.
    for i, data in enumerate(trainloader, 0):
        inputs = data[0].to(device)
        labels = data[1].to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        scale, shift = net.features[l_index].weight, net.features[l_index].bias
        taylor_term = scale.grad * scale.data + shift.grad * shift.data
        imp_corr_bn += taylor_term.abs().pow(2)

        num_stop += labels.shape[0]
        if num_stop > 5000:
            break

    estimated = imp_corr_bn.cpu().detach().numpy()
    measured = (loss_mat_corr - loss_base_corr).abs().cpu().detach().numpy()
    corrval = np.corrcoef(estimated, measured)
    print("Correlation at epoch "+str(epoch)+": "+str(corrval[0,1]))
    av_corrval += corrval[0,1]

### Decorrelated net

In [None]:
PATH = './w_decorr/base_params/wnet_base_2.pth'
# PATH = './tempnet1.pth'
net_decorr = VGG('VGG13').to(device)
# map_location remaps the stored tensors onto the active device, so a checkpoint that
# holds CUDA tensors can still be loaded on a CPU-only machine.
net_decorr.load_state_dict(torch.load(PATH, map_location=device))
# Inference mode for the importance measurements in the following cells.
net_decorr = net_decorr.eval()

In [None]:
# Detached copies of the studied layer's parameters in the decorrelated network;
# these replace the values taken from `net` earlier.
weight_base = net_decorr.features[l_index].weight.detach().clone()
bias_base = net_decorr.features[l_index].bias.detach().clone()

In [None]:
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)
# Baseline loss of the unmodified decorrelated network: sum of the per-batch mean
# losses over the first batches of the training loader (iteration stops once more
# than 5000 samples have been seen), then squared.
num_stop = 0
loss_base_decorr = 0
# Pure evaluation: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    # Swap in `testloader` here to measure on the test split instead.
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = net_decorr(inputs)
        loss = criterion(outputs, labels)
        loss_base_decorr += loss.item()
        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break
loss_base_decorr = loss_base_decorr**2

In [None]:
# optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)

# loss_mat_decorr = torch.zeros(weight_base.shape[0])

# for n_index in range(weight_base.shape[0]): 
#     print(n_index)
#     num_stop = 0
#     running_loss = 0.0
#     for i, data in enumerate(trainloader, 0):
# #     for i, data in enumerate(testloader, 0):
#         inputs, labels = data[0].to(device), data[1].to(device)

#         net_decorr.features[l_index].weight.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
#         net_decorr.features[l_index].bias.data[n_index] = 0 #torch.zeros((weight_base.shape[1],weight_base.shape[2],weight_base.shape[3]))
#         outputs = net_decorr(inputs)
        
#         loss = criterion(outputs, labels)
        
#         running_loss += loss.item()
        
#         num_stop += labels.shape[0]
#         if(num_stop > 5000):
#             break
            
#     loss_mat_decorr[n_index] = running_loss**2
    
#     net_decorr.features[l_index].weight.data = weight_base.clone().detach()
#     net_decorr.features[l_index].bias.data = bias_base.clone().detach()

# # torch.save(loss_mat_decorr, './w_decorr/loss_bn_train_'+str(l_index)+'.pt')

In [None]:
# torch.save(loss_mat_decorr, './w_decorr/loss_mats/decorr/'+str(l_index)+'/loss_decorr_bn_train_'+str(l_index)+'.pt')
# Load the cached per-channel loss values for layer `l_index` of the decorrelated
# network; presumably produced by the commented-out ablation cell above — confirm.
loss_mat_decorr = torch.load('./w_decorr/loss_mats/decorr/'+str(l_index)+'/loss_decorr_bn_train_'+str(l_index)+'.pt')

In [None]:
# First-order Taylor importance for the decorrelated network: accumulate
# |grad * value| ** 2 of the studied layer's weight and bias per channel, then
# correlate it with the measured loss change.
# The optimizer has lr=0; it is only used to clear gradients between batches.
optimizer = optim.SGD(net_decorr.parameters(), lr=0, weight_decay=0)
av_corrval = 0
n_epochs = 1

for epoch in range(n_epochs):
    num_stop = 0
    imp_decorr_conv = torch.zeros(bias_base.shape[0]).to(device)
    imp_decorr_bn = torch.zeros(bias_base.shape[0]).to(device)

    running_loss = 0.0
    # Swap in `testloader` here to measure on the test split instead.
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net_decorr(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        # Accumulate before the stop check so the final batch is counted too. The
        # correlated-network cell and both baseline-loss loops include that batch;
        # checking first silently dropped it here.
        imp_decorr_bn += (((net_decorr.features[l_index].weight.grad)*(net_decorr.features[l_index].weight.data)) + ((net_decorr.features[l_index].bias.grad)*(net_decorr.features[l_index].bias.data))).abs().pow(2)

        num_stop += labels.shape[0]
        if(num_stop > 5000):
            break

    corrval = (np.corrcoef(imp_decorr_bn.cpu().detach().numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().detach().numpy()))
    print("Correlation at epoch "+str(epoch)+": "+str(corrval[0,1]))
    av_corrval += corrval[0,1]

# Graphs

In [None]:
# figure(figsize=(20,5))
# s = imp_corr_bn.cpu().sort()[0].cpu().numpy()
# order = imp_corr_bn.sort()[1].cpu().numpy()
# plt.plot(s/s.max(), label="Estimated importance")
# plt.title("Correlated (Taylor FO) for "+str(l_index))
# loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
# plt.xlabel("Neuron index")
# plt.ylabel("Normalized importance")
# plt.plot(loss_diff[order]/loss_diff.max(), label="Actual importance")
# plt.legend()
# # plt.savefig("./w_decorr/loss_mats/corr/graphs/"+str(l_index)+".png")

In [None]:
# figure(figsize=(20,5))

# s = imp_decorr_bn.cpu().sort()[0].cpu().numpy()
# order = imp_decorr_bn.sort()[1].cpu().numpy()
# plt.plot(s/s.max(), label="Estimated importance")
# plt.title("Decorrelated (Taylor FO) for "+str(l_index))
# loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
# plt.xlabel("Neuron index")
# plt.ylabel("Normalized importance")
# plt.plot(loss_diff[order]/loss_diff.max(), label="Actual importance")
# plt.legend()
# # plt.savefig("./w_decorr/loss_mats/decorr/graphs/"+str(l_index)+".png")

In [None]:
# Decorrelated network: sort channels by estimated importance, scale both curves by
# their own maximum, and compare them channel by channel.
s = imp_decorr_bn.cpu().sort()[0].cpu().numpy()
s = s/s.max()
order = imp_decorr_bn.sort()[1].cpu().numpy()
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
loss_diff = (loss_diff[order]/loss_diff.max())
# Sum of squared differences between the two curves. Despite the "rms" name, no mean
# or square root is applied in this cell.
ortho_rms = ((loss_diff - s)**2).sum()

# Same computation for the correlated network.
s = imp_corr_bn.cpu().sort()[0].cpu().numpy()
s = s/s.max()
order = imp_corr_bn.sort()[1].cpu().numpy()
loss_diff = (loss_mat_corr - loss_base_corr).abs()
loss_diff = (loss_diff[order]/loss_diff.max())

base_rms = ((loss_diff - s)**2).sum()

In [None]:
# The result lists are not created anywhere else in the notebook, so create them on
# first use; on later runs (one per layer index) the existing lists keep growing.
if 'rms_ortho' not in globals():
    rms_ortho = []
if 'rms_base' not in globals():
    rms_base = []
rms_ortho.append(ortho_rms)
rms_base.append(base_rms)

In [None]:
# rms_ortho = np.sqrt(np.array(rms_ortho) / np.array([64, 64, 128, 128, 256, 256, 512, 512, 512, 512]))
# rms_base = np.sqrt(np.array(rms_base) / np.array([64, 64, 128, 128, 256, 256, 512, 512, 512, 512]))

In [None]:
# NOTE(review): replaces the fifth recorded value with a hard-coded constant instead of
# a computed one, and raises IndexError unless at least five values have been appended —
# confirm this manual override is intended before relying on the bar chart.
rms_ortho[4] = 0.18

In [None]:
# Scratch check of np.linspace spacing; the value is only displayed, not stored.
np.linspace(0,2,3)

In [None]:
# Side-by-side bars per layer: decorrelated values shifted left, correlated shifted right.
fig, ax = plt.subplots(figsize=(10,5))
bar_centres = np.linspace(0,30,10)
ax.bar(bar_centres-0.5, rms_ortho, label="Decorrelated network")
ax.bar(bar_centres+0.5, rms_base, label="Correlated network")
ax.set_xlabel("Layer ID")
ax.set_ylabel("RMS")
ax.legend()
fig.savefig("./w_decorr/loss_mats/rms.png")

## Subplots

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(20,5))

# One panel per network: estimated importance (sorted ascending) against the measured
# loss change in the same channel order, each scaled by its own maximum.
panels = [
    (axes[0], imp_decorr_bn, loss_mat_decorr, loss_base_decorr,
     "Decorrelated Network (layer "+str(l_index)+")"),
    (axes[1], imp_corr_bn, loss_mat_corr, loss_base_corr,
     "Correlated Network (layer "+str(l_index)+")"),
]
for ax, importance, loss_mat, loss_base, panel_title in panels:
    s = importance.cpu().sort()[0].cpu().numpy()
    order = importance.sort()[1].cpu().numpy()
    loss_diff = (loss_mat - loss_base).abs()

    ax.plot(s/s.max(), label="Estimated importance")
    ax.set_title(panel_title)
    ax.set_xlabel("Neuron index")
    ax.set_ylabel("Normalized importance")
    ax.plot(loss_diff[order]/loss_diff.max(), label="Actual importance")
    ax.legend()

plt.savefig("./w_decorr/loss_mats/subplots/"+str(l_index)+".png")

# Other metrics

### Net-Slim Train

In [None]:
# Net-Slim style importance: magnitude of the studied layer's scale parameter.
# .abs() added so this matches the decorrelated-network cell, which already takes the
# absolute value; without it negative scales would rank as least important.
scale_corr = net.features[l_index].weight.data.clone().abs()
np.corrcoef(scale_corr.cpu().numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().numpy())

In [None]:
# Net-Slim style importance for the decorrelated network: absolute value of the studied
# layer's weight, correlated with the measured loss change per channel.
scale_decorr = net_decorr.features[l_index].weight.data.clone().abs()
np.corrcoef((scale_decorr).cpu().numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().numpy())

### L2 based pruning Train

In [None]:
# L2-based importance: squared weight norm of each output filter of the layer two
# positions before `l_index` (summing over every dimension except the first).
w_corr = net.features[l_index - 2].weight.data.clone()
w_imp_corr = w_corr.pow(2).sum(dim=(1,2,3)).cpu()
np.corrcoef(w_imp_corr.numpy(), (loss_mat_corr - loss_base_corr).abs().cpu().numpy())

In [None]:
# L2-based importance for the decorrelated network: squared weight norm of each output
# filter of the layer two positions before `l_index`.
w_decorr = net_decorr.features[l_index - 2].weight.data.clone()
w_imp_decorr = w_decorr.pow(2).sum(dim=(1,2,3)).cpu()
# Min-max rescaling to [0, 1]. This does not change the Pearson correlation below (it
# is invariant to shifting and positive scaling), but it does change the curve drawn
# from this variable; the correlated-network counterpart is not rescaled this way.
w_imp_decorr = (w_imp_decorr - w_imp_decorr.min())
w_imp_decorr = w_imp_decorr/w_imp_decorr.max()
np.corrcoef(w_imp_decorr.numpy(), (loss_mat_decorr - loss_base_decorr).abs().cpu().numpy())

#### Importance plots Net-Slim Train

In [None]:
figure(figsize=(20,5))

# Sort once; values and channel order come from the same call.
sorted_scale, sorted_idx = scale_corr.sort()
s = sorted_scale.cpu().numpy()
order = sorted_idx.cpu().numpy()
loss_diff = (loss_mat_corr - loss_base_corr).abs()

plt.plot(s/s.max())
plt.title("Correlated (Net-Slim)")
plt.plot(loss_diff[order]/loss_diff.max())

In [None]:
figure(figsize=(20,5))

# Sort once; values and channel order come from the same call.
sorted_scale, sorted_idx = scale_decorr.sort()
s = sorted_scale.cpu().numpy()
order = sorted_idx.cpu().numpy()
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()

plt.plot(s/s.max())
plt.title("Decorrelated (Net-Slim)")
plt.plot(loss_diff[order]/loss_diff.max())

#### Importance plots L2 train

In [None]:
figure(figsize=(20,5))
# Sort once; values and filter order come from the same call.
sorted_norms, sorted_idx = w_imp_corr.sort()
s = sorted_norms.cpu().numpy()
order = sorted_idx.cpu().numpy()
loss_diff = (loss_mat_corr - loss_base_corr).abs()
plt.plot(s/s.max())
plt.title("Correlated (L2)")
plt.plot(loss_diff[order]/loss_diff.max())

In [None]:
figure(figsize=(20,5))
# Sort once; values and filter order come from the same call.
sorted_norms, sorted_idx = w_imp_decorr.sort()
s = sorted_norms.cpu().numpy()
order = sorted_idx.cpu().numpy()
loss_diff = (loss_mat_decorr - loss_base_decorr).abs()
plt.plot(s/s.max())
plt.title("Decorrelated (L2)")
plt.plot(loss_diff[order]/loss_diff.max())

# Train networks

# Inner product training

In [None]:
# Fresh VGG13 initialised from the saved base weights for the training experiments below.
net = VGG('VGG13').to(device)
# PATH = './cifar100_net.pth'
PATH = './w_decorr/base_params/wnet_base.pth'
# map_location remaps the stored tensors onto the active device, so a checkpoint that
# holds CUDA tensors can still be loaded on a CPU-only machine.
net.load_state_dict(torch.load(PATH, map_location=device))

# net_d = VGG('VGG13').to(device)
# PATH_d = './w_decorr/wnet_all.pth'
# net_d.load_state_dict(torch.load(PATH_d))

In [None]:
# Test accuracy of the freshly loaded weights; .eval() switches the module to inference
# mode and returns it, so the network stays in eval mode after this cell.
cal_acc(net.eval())

In [None]:
import torch.optim as optim

# Classification loss for the training cells below (same loss type as created earlier).
criterion = nn.CrossEntropyLoss()

In [None]:
# Per-layer weighting of the orthogonality penalty, keyed by the index of each
# convolution inside net.features.
conv_indices = [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]

l_temp = []

# Probe the network with a dummy image to read each block's output resolution. Eval
# mode plus no_grad keeps the probe from touching BatchNorm running statistics or
# building an autograd graph; the previous mode is restored afterwards.
was_training = net.training
net.eval()
with torch.no_grad():
    probe = torch.zeros(1,3,32,32).to(device)
    for conv_ind in conv_indices:
        # Each convolution is followed by ReLU and BatchNorm2d, so the slice that ends
        # three positions later covers the whole block.
        layer_index = conv_ind + 3

        _, _, w_in, h_in = net.features[0:layer_index](probe).shape

        c_out, c_in, w_f, h_f = net.features[conv_ind].weight.shape

        l_temp.append((c_in*w_f*h_f)*(w_in*h_in)*c_out*(c_out*c_in*w_f*h_f)**(1/4))
net.train(was_training)

# Normalise so the weights sum to one.
l_temp = np.array(l_temp)
l_temp = l_temp/l_temp.sum()

# Pair every convolution index with its weight; building the mapping from a single
# index list avoids two hand-maintained lists drifting apart.
l_imp = dict(zip(conv_indices, l_temp))

In [None]:
# Test accuracy before regularised training starts; note this leaves `net` in eval mode.
cal_acc(net.eval())

In [None]:
# `time` is only imported in a later cell of the notebook; import it here so this cell
# also runs on a fresh kernel.
import time

# Fine-tune with cross-entropy plus an orthogonality penalty on every convolution's
# flattened [weights | bias] matrix.
# NOTE(review): nothing in this cell calls net.train(), and cal_acc(net.eval()) below
# (and in the preceding cell) leaves the model in eval mode, so the optimisation runs
# with eval-mode BatchNorm — confirm this is intended.
optimizer = optim.Adam(net.parameters(), lr=0.00001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
time_init = time.time()
for epoch in range(1):  # loop over the dataset multiple times
    running_loss = 0.0
    num_iter = 0
    angle_cost = 0.0
    for i, data in enumerate(trainloader, 0):
        num_iter += 1
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        L_angle = 0

        ### Conv_ind == 0 ###
        # The first convolution uses params^T params (Gram matrix over columns), unlike
        # the remaining layers, which use params params^T (Gram matrix over filters).
        w_mat = net.features[0].weight
        w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))

        b_mat = net.features[0].bias
        b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))

        params = torch.cat((w_mat1, b_mat1), dim=1)

        angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)

        L_angle += (l_imp[0])*(angle_mat).norm(1) #.norm().pow(2))

        ### Conv_ind != 0 ###

        for conv_ind in [3, 7, 10, 14, 17, 21, 24, 28, 31]:
            w_mat = net.features[conv_ind].weight
            w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))

            b_mat = net.features[conv_ind].bias
            b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))

            params = torch.cat((w_mat1, b_mat1), dim=1)

            # Deviation of the filter Gram matrix from the identity.
            angle_mat = torch.matmul(params, torch.t(params)) - torch.eye(w_mat.shape[0]).to(device)

            L_angle += (l_imp[conv_ind])*(angle_mat).norm(1) #.norm().pow(2))

        Lc = criterion(outputs, labels)
        loss = (1e-1)*(L_angle) + Lc

        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        angle_cost += (L_angle).item()

    print("angle_cost: ", angle_cost/num_iter)
#     print("diag_mass_ratio: ", (num_iter*(64+128+256+1024)*2)/(L_angle.detach().cpu().numpy()))
    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, i + 1, running_loss / num_iter))
    running_loss = 0.0
    cal_acc(net.eval())
print('Finished Training')
print(time.time() - time_init)

In [None]:
def w_diag():
    """Print, per convolution of the global ``net``, the diagonal share of its Gram matrix.

    Each layer's weights are flattened to one row per filter and the bias is appended
    as an extra column. The first convolution (index 0) uses ``params^T @ params``;
    every other layer uses ``params @ params^T``. The printed value is ``norm(1)`` of
    the diagonal divided by ``norm(1)`` of the whole matrix.
    """
    for conv_ind in [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]:
        layer = net.features[conv_ind]
        flat_w = layer.weight.reshape(layer.weight.shape[0], -1)
        flat_b = layer.bias.reshape(layer.bias.shape[0], -1)
        params = torch.cat((flat_w, flat_b), dim=1)

        if conv_ind == 0:
            angle_mat = torch.matmul(torch.t(params), params)
        else:
            angle_mat = torch.matmul(params, torch.t(params))

        L_diag = angle_mat.diag().norm(1)
        L_angle = angle_mat.norm(1)

        print(L_diag.cpu()/L_angle.cpu())

In [None]:
# Per-layer ratio of diagonal to total norm(1) of the filter Gram matrix.
# The two zero initialisers are overwritten (not accumulated) on every iteration, and
# the l_imp factor appears in both numerator and denominator, so it cancels in the
# printed ratio.
L_diag = 0
L_angle = 0
for conv_ind in [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]:
    w_mat = net.features[conv_ind].weight
    w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))

    b_mat = net.features[conv_ind].bias
    b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))

#     g_mat = net.features[conv_ind+2].weight
#     g_mat1 = (b_mat.reshape(b_mat.shape[0],-1))

#     be_mat = net.features[conv_ind+2].bias
#     be_mat1 = (b_mat.reshape(b_mat.shape[0],-1))

    # One row per filter: flattened weights with the bias appended as a last column.
    params = torch.cat((w_mat1, b_mat1), dim=1)

    # Unlike w_diag(), index 0 is not special-cased here: every layer uses params params^T.
    angle_mat = torch.matmul(params, torch.t(params))

    L_diag = (l_imp[conv_ind])*(angle_mat.diag().norm(1))
    L_angle = (l_imp[conv_ind])*(angle_mat.norm(1))
    
    print(L_diag.cpu()/L_angle.cpu())

In [None]:
# Gram matrix of whatever `w_mat1` currently holds (when cells are run in order, the
# flattened weights of the last layer visited by the previous cell); displayed in full.
torch.matmul(w_mat1, w_mat1.t())

In [None]:
# Report accuracy on both splits. Both helpers print their result; the tuple shown as
# cell output holds whatever they return.
cal_acc(net.eval()), cal_acc_train(net.eval())

In [None]:
# PATH = './w_decorr/base_params/wnet_base_2.pth'
# torch.save(net.state_dict(), PATH)

### Outer-product regularization

In [None]:
# Fine-tune with cross-entropy plus an "outer-product" penalty on every convolution.
# NOTE(review): nothing in this cell calls net.train(), and cal_acc(net.eval()) at the
# end of each epoch leaves the model in eval mode, so optimisation runs with eval-mode
# BatchNorm from the second epoch on (and from the start if the model was already in
# eval mode) — confirm this is intended.
optimizer = optim.Adam(net.parameters(), lr=0.00001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

for epoch in range(5):  # loop over the dataset multiple times
    running_loss = 0.0
    num_iter = 0
    angle_cost = 0.0
    for i, data in enumerate(trainloader, 0):
        num_iter += 1
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        L_angle = 0
        
        for conv_ind in [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]:
            w_mat = net.features[conv_ind].weight
            w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))
            
            b_mat = net.features[conv_ind].bias
            b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))
            
            # One row per filter: flattened weights with the bias appended as a last column.
            params = torch.cat((w_mat1, b_mat1), dim=1)
            
            # params is already 2-D, so this reshape keeps its shape.
            w_wt = params.reshape(params.shape[0], -1)
            # (sum of all entries)^2 minus the sum of squared row sums, i.e. the sum
            # over i != j of row_sum_i * row_sum_j.
            angle_mat = (w_wt.sum()**2 - (w_wt.sum(dim=1)**2).sum())
            L_angle += angle_mat.abs()
        
        Lc = criterion(outputs, labels)
        loss = (1e-3)*(L_angle) + Lc
        
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        angle_cost += (L_angle).item()
    
    print("angle_cost: ", angle_cost/num_iter)
    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, i + 1, running_loss / num_iter))
    running_loss = 0.0
    cal_acc(net.eval())
print('Finished Training')

# NOTE(review): w_diag is defined twice in this notebook (once in an earlier cell, once
# in the cell right after this one); this call uses whichever definition was executed
# most recently, which on a top-to-bottom run is the earlier one.
w_diag()

In [None]:
def w_diag():
    """Print two outer-product statistics for each convolution of the global ``net``.

    Each layer's weights are flattened to one row per filter with the bias appended as
    an extra column. With ``r`` the vector of row sums, the first printed value is
    ``|sum(r**2)|`` and the second is ``|(sum of all entries)**2 - sum(r**2)|``.
    """
    for conv_ind in [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]:
        layer = net.features[conv_ind]
        flat_w = layer.weight.reshape(layer.weight.shape[0], -1)
        flat_b = layer.bias.reshape(layer.bias.shape[0], -1)
        params = torch.cat((flat_w, flat_b), dim=1)

        w_wt = params.reshape(params.shape[0], -1)
        squared_row_sums = (w_wt.sum(dim=1)**2).sum()
        L_diag = squared_row_sums.abs()
        L_angle = ((w_wt.sum()**2) - squared_row_sums).abs()

        print(L_diag.cpu(), L_angle.cpu())

In [None]:
import time

In [None]:
# Fine-tune with cross-entropy plus a penalty on params^T params - I for every
# convolution except the first (index 0 is not in the list below).
# NOTE(review): nothing in this cell calls net.train(); if an earlier cell left the
# model in eval mode (cal_acc(net.eval()) does), optimisation runs with eval-mode
# BatchNorm — confirm this is intended.
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

time_init = time.time()

for epoch in range(1):  # loop over the dataset multiple times
    running_loss = 0.0
    num_iter = 0
    angle_cost = 0.0
    for i, data in enumerate(trainloader, 0):
        num_iter += 1
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        L_angle = 0
        
        for conv_ind in [3, 7, 10, 14, 17, 21, 24, 28, 31]:
            
            w_mat = net.features[conv_ind].weight
            w_mat1 = (w_mat.reshape(w_mat.shape[0],-1))

            b_mat = net.features[conv_ind].bias
            b_mat1 = (b_mat.reshape(b_mat.shape[0],-1))

            # One row per filter: flattened weights with the bias appended as a last column.
            params = torch.cat((w_mat1, b_mat1), dim=1)

            # Gram matrix over columns: its side is params.shape[1] (flattened filter
            # size plus one), not the number of filters.
            angle_mat = torch.matmul(torch.t(params), params) - torch.eye(params.shape[1]).to(device)

            L_angle += (l_imp[conv_ind])*(angle_mat).norm(1) #.norm().pow(2))
                    
        Lc = criterion(outputs, labels)
        loss = (1e-4)*(L_angle) + Lc
        
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        angle_cost += (L_angle).item()
    
    print("angle_cost: ", angle_cost/num_iter)
#     print("diag_mass_ratio: ", (num_iter*(64+128+256+1024)*2)/(L_angle.detach().cpu().numpy()))
    print('[%d, %5d] loss: %.3f' %
          (epoch + 1, i + 1, running_loss / num_iter))
    running_loss = 0.0
    cal_acc(net.eval())
    w_diag()
print('Finished Training')
print(time.time() - time_init)