In [None]:
# -*- coding: utf-8 -*-
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

In [None]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import numpy as np
import os
from utils import progress_bar

In [None]:
# Training-time preprocessing: random horizontal flip and random rotation
# (up to 45 degrees either way) for augmentation, then tensor conversion and
# per-channel normalisation.
# NOTE(review): the mean/std triples look like the values usually quoted for
# CIFAR-10 rather than CIFAR-100 -- confirm before changing them, since any
# saved checkpoint was trained with exactly these numbers.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=45),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
])

# Evaluation-time preprocessing: no augmentation, same normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
])

# CIFAR-100 training split. download=True fetches the archive into the data
# root only when it is not already there, so this cell also works on a fresh
# checkout (the training split is constructed before the test split, so it
# cannot rely on the test split's download).
trainset = torchvision.datasets.CIFAR100(root='./../data', train=True,
                                         download=True, transform=transform)
# Reshuffle the training samples every epoch so that mini-batch composition
# changes between epochs.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

# CIFAR-100 test split, evaluated in a fixed order without augmentation.
testset = torchvision.datasets.CIFAR100(root='./../data', train=False,
                                        download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

In [None]:
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn


# Layer layouts keyed by name. An integer adds a 3x3 convolution with that
# many output channels; 'M' adds a 2x2 max-pool (see VGG._make_layers).
# Note that the 'VGG11' entry has three stages and ends at 128 channels,
# whereas the other entries have five stages and end at 512 channels.
cfg = {
    'VGG11': [
        32, 32, 'M',
        64, 64, 'M',
        128, 128, 'M',
    ],
    'VGG13': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'VGG16': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'VGG19': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
}


class VGG(nn.Module):
    """VGG-style convolutional classifier for square image inputs.

    The classifier width is derived from the chosen layout instead of being
    fixed at 512, so layouts that do not end in a 512x1x1 feature map (such
    as the reduced 'VGG11' entry of ``cfg``) can run a forward pass. For the
    five-stage layouts on 32x32 inputs the classifier is still
    ``nn.Linear(512, 100)``, so existing checkpoints keep loading.

    Args:
        vgg_name: Key into the module-level ``cfg`` dict (e.g. ``'VGG13'``),
            or an explicit layer list in the same format (integers for
            convolution widths, ``'M'`` for max-pooling).
        num_classes: Number of output logits. Defaults to 100.
        input_size: Height/width of the square input images. Defaults to 32.
            Only used to size the classifier.

    Raises:
        KeyError: If ``vgg_name`` is a string that is not a key of ``cfg``.
        ValueError: If the layout pools the input down to nothing.
    """

    def __init__(self, vgg_name, num_classes=100, input_size=32):
        super(VGG, self).__init__()
        if isinstance(vgg_name, str):
            layer_cfg = cfg[vgg_name]
        else:
            layer_cfg = list(vgg_name)
        self.features = self._make_layers(layer_cfg)
        self.classifier = nn.Linear(
            self._flat_features(layer_cfg, input_size), num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        out = self.features(x)
        # Flatten everything except the batch dimension for the linear layer.
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    @staticmethod
    def _flat_features(layer_cfg, input_size):
        """Return the flattened feature count produced by ``layer_cfg``.

        Each 'M' halves the spatial side (2x2 pooling with stride 2, rounding
        down); the padded 3x3 convolutions keep it unchanged.
        """
        channels = 3
        side = input_size
        for item in layer_cfg:
            if item == 'M':
                side //= 2
            else:
                channels = item
        if side < 1:
            raise ValueError(
                'input_size=%d is too small for a layout with this many '
                'pooling layers' % input_size)
        return channels * side * side

    def _make_layers(self, cfg):
        """Build the feature extractor described by the layer list ``cfg``.

        Each integer becomes Conv2d -> ReLU -> BatchNorm2d, in that order.
        The order is kept as is because other cells in this notebook address
        the convolutions by their position inside ``features``.
        """
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.ReLU(inplace=True),
                           nn.BatchNorm2d(x)]
                in_channels = x
        return nn.Sequential(*layers)


def test():
    """Smoke-test a forward pass on a random batch of two 32x32 RGB images.

    Builds the 'VGG13' layout (the one this notebook trains), prints the
    output size and checks that it is one row of 100 logits per image.
    """
    net = VGG('VGG13')
    x = torch.randn(2, 3, 32, 32)
    y = net(x)
    print(y.size())
    assert y.size() == (2, 100), 'unexpected output size: %s' % (tuple(y.size()),)

In [None]:
def w_diag(model=None, conv_indices=None):
    """Print, for each conv layer, how diagonal the Gram matrix of its parameters is.

    For every selected layer the weights are flattened to one row per output
    filter and the bias is appended as an extra column, giving a matrix P.
    The printed ratio is ``sum(|diag(G)|) / sum(|G|)``; it equals 1 when all
    off-diagonal entries of G are zero.

    The layer at index 0 uses ``G = P^T P`` (Gram matrix of the columns) and
    every other layer uses ``G = P P^T`` (Gram matrix of the rows), which is
    the same split used by the penalty in ``net_train_ortho``.

    Weights and biases are both read from ``model``; nothing is taken from
    any other network.

    Args:
        model: Network exposing a ``features`` sequence. Defaults to the
            global ``net_ortho`` so that a zero-argument call still works.
        conv_indices: Positions of the conv layers inside ``model.features``.
            Defaults to every ``Conv2d`` found there, in order.

    Returns:
        list[float]: One ratio per inspected layer, in the order printed.
    """
    if model is None:
        model = net_ortho
    if conv_indices is None:
        conv_indices = [idx for idx, layer in enumerate(model.features)
                        if isinstance(layer, torch.nn.Conv2d)]

    ratios = []
    # Pure diagnostic: no autograd graph is needed for these statistics.
    with torch.no_grad():
        for conv_ind in conv_indices:
            layer = model.features[conv_ind]
            w_mat = layer.weight
            w_mat1 = w_mat.reshape(w_mat.shape[0], -1)
            b_mat1 = layer.bias.reshape(-1, 1)
            params = torch.cat((w_mat1, b_mat1), dim=1)
            if conv_ind == 0:
                angle_mat = torch.matmul(torch.t(params), params)
            else:
                angle_mat = torch.matmul(params, torch.t(params))
            ratio = angle_mat.diag().norm(1) / angle_mat.norm(1)
            print(ratio.cpu())
            ratios.append(ratio.item())
    return ratios

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Baseline network: the 'VGG13' layout, moved onto the selected device.
net = VGG('VGG13')
net = net.to(device)

# Classification loss shared by the training and evaluation loops.
criterion = nn.CrossEntropyLoss()

### Train Baseline

In [None]:
# Training
def net_train(epoch):
    """Run one training pass of the global ``net`` over ``trainloader``.

    Uses the global ``optimizer``, ``criterion`` and ``device``. After every
    batch the progress bar shows the running mean loss and the running
    accuracy over the samples seen so far.

    Args:
        epoch: Epoch number, used only for the printed header.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_sum, n_correct, n_seen = 0, 0, 0
    n_batches = len(trainloader)

    for step, (images, labels) in enumerate(trainloader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = net(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        predicted = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += predicted.eq(labels).sum().item()

        status = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
            loss_sum/(step+1), 100.*n_correct/n_seen, n_correct, n_seen)
        progress_bar(step, n_batches, status)
        
def net_test(epoch):
    """Evaluate the global ``net`` on ``testloader`` and checkpoint on improvement.

    When the test accuracy (in percent) exceeds the global ``best_acc``, the
    model weights and that accuracy are written to ``./checkpoint/ckpt.pth``
    and ``best_acc`` is updated.

    Args:
        epoch: Epoch number; accepted for symmetry with ``net_train`` and not
            used in the body.
    """
    global best_acc
    net.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    n_batches = len(testloader)

    with torch.no_grad():
        for step, (images, labels) in enumerate(testloader):
            images = images.to(device)
            labels = labels.to(device)
            logits = net(images)

            loss_sum += criterion(logits, labels).item()
            predicted = logits.max(1)[1]
            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()

            status = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                loss_sum/(step+1), 100.*n_correct/n_seen, n_correct, n_seen)
            progress_bar(step, n_batches, status)

    # Save checkpoint only when this evaluation beats the best so far.
    accuracy = 100.*n_correct/n_seen
    if not accuracy > best_acc:
        return

    print('Saving..')
    os.makedirs('checkpoint', exist_ok=True)
    torch.save({'net': net.state_dict(), 'best_acc': accuracy},
               './checkpoint/ckpt.pth')
    best_acc = accuracy

In [None]:
# Adam for the baseline network, with every hyper-parameter spelled out and
# a weight decay of 5e-4.
optimizer = optim.Adam(
    net.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=5e-4,
    amsgrad=False,
)

In [None]:
# Best test accuracy (percent) seen so far. net_test() compares against this
# global before it ever assigns it, so it must exist before the first
# evaluation; otherwise a fresh kernel fails with NameError.
best_acc = 0

In [None]:
# Train the baseline for 5 epochs. After each epoch the test set is
# evaluated, and net_test saves a checkpoint whenever accuracy improves.
for epoch in range(5):
    net_train(epoch)
    net_test(epoch)

In [None]:
# Restore the best baseline weights and their accuracy from disk.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU machine also loads on a CPU-only one.
net_dict = torch.load('./checkpoint/ckpt.pth', map_location=device)
net.load_state_dict(net_dict['net'])
best_acc = net_dict['best_acc']

### Inner product training

In [None]:
# Start the orthogonality-regularised run from the best baseline weights.
# The checkpoint was written from `net`, a VGG('VGG13'), so the same
# architecture is needed for load_state_dict to match (no AlexNet class is
# defined in this notebook).
net_ortho = VGG('VGG13').to(device)
# map_location keeps a GPU-saved checkpoint loadable on a CPU-only machine.
net_dict = torch.load('./checkpoint/ckpt.pth', map_location=device)
net_ortho.load_state_dict(net_dict['net'])
best_acc_ortho = net_dict['best_acc']

In [None]:
# Per-layer weighting for the orthogonality penalty: each listed conv layer
# gets its share of the total number of biases (one per output channel)
# summed over all listed layers, so the weights add up to 1.
l_imp = {
    conv_ind: net_ortho.features[conv_ind].bias.shape[0]
    for conv_ind in [0, 3, 7, 10, 14, 17, 21, 24, 28, 31]
}

normalizer = sum(l_imp.values())
for key in l_imp:
    l_imp[key] = l_imp[key] / normalizer

In [None]:
def net_test_ortho(epoch):
    """Evaluate the global ``net_ortho`` on ``testloader`` and checkpoint on improvement.

    Prints the test accuracy (in percent). When it exceeds the global
    ``best_acc_ortho``, the weights and that accuracy are written to
    ``./ortho_checkpoint/ckpt.pth`` and ``best_acc_ortho`` is updated.

    Args:
        epoch: Epoch number; accepted for symmetry with the training function
            and not used in the body.
    """
    global best_acc_ortho
    net_ortho.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    n_batches = len(testloader)

    with torch.no_grad():
        for step, (images, labels) in enumerate(testloader):
            images = images.to(device)
            labels = labels.to(device)
            logits = net_ortho(images)

            loss_sum += criterion(logits, labels).item()
            predicted = logits.max(1)[1]
            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()

            status = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                loss_sum/(step+1), 100.*n_correct/n_seen, n_correct, n_seen)
            progress_bar(step, n_batches, status)

    # Save checkpoint only when this evaluation beats the best so far.
    accuracy = 100.*n_correct/n_seen
    print(accuracy)
    if not accuracy > best_acc_ortho:
        return

    print('Saving..')
    os.makedirs('ortho_checkpoint', exist_ok=True)
    torch.save({'net': net_ortho.state_dict(), 'best_acc': accuracy},
               './ortho_checkpoint/ckpt.pth')
    best_acc_ortho = accuracy

In [None]:
def _ortho_penalty(layer, use_columns):
    """Return the entrywise L1 distance between a layer's Gram matrix and the identity.

    The layer's weights are flattened to one row per output filter and the
    bias is appended as a final column, giving a matrix P. With
    ``use_columns`` the Gram matrix is ``P^T P``; otherwise it is ``P P^T``.
    """
    weight = layer.weight
    params = torch.cat(
        (weight.reshape(weight.shape[0], -1), layer.bias.reshape(-1, 1)), dim=1)
    if use_columns:
        gram = torch.matmul(torch.t(params), params)
    else:
        gram = torch.matmul(params, torch.t(params))
    # Build the identity directly on the parameters' device and dtype.
    identity = torch.eye(gram.shape[0], device=gram.device, dtype=gram.dtype)
    return (gram - identity).norm(1)


def net_train_ortho(epoch):
    """Run one training pass of ``net_ortho`` with an orthogonality penalty.

    The optimised loss is ``criterion(outputs, labels) + 0.1 * L_angle``,
    where ``L_angle`` is the ``l_imp``-weighted sum of the per-layer Gram
    penalties. Uses the global ``trainloader``, ``optimizer``, ``criterion``,
    ``device`` and ``l_imp``.

    At the end of the epoch the mean penalty per batch is printed. The
    penalty depends only on the weights, not on the samples in a batch, so
    it is averaged over the number of batches rather than the number of
    samples.

    Args:
        epoch: Epoch number, used only for the printed header.
    """
    print('\nEpoch: %d' % epoch)
    net_ortho.train()
    correct = 0
    total = 0
    running_loss = 0.0
    angle_cost = 0.0
    num_batches = 0

    for batch_idx, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net_ortho(inputs)

        # The first conv layer is penalised on the Gram matrix of its
        # columns, all later ones on the Gram matrix of their rows.
        # NOTE(review): presumably because the first layer has more filters
        # than parameters per filter, so its rows cannot all be orthogonal.
        L_angle = l_imp[0] * _ortho_penalty(net_ortho.features[0], use_columns=True)
        for conv_ind in [3, 7, 10, 14, 17, 21, 24, 28, 31]:
            L_angle = L_angle + l_imp[conv_ind] * _ortho_penalty(
                net_ortho.features[conv_ind], use_columns=False)

        Lc = criterion(outputs, labels)
        loss = (1e-1)*(L_angle) + Lc
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        angle_cost += (L_angle).item()
        num_batches += 1

        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (running_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # max(..., 1) keeps the report from dividing by zero on an empty loader.
    print("angle_cost: ", angle_cost/max(num_batches, 1))

In [None]:
import torch.optim as optim
# Cross-entropy classification loss for the regularised run; this is the same
# construction as the criterion created for the baseline network.
criterion = nn.CrossEntropyLoss()

In [None]:
# Adam for the regularised network. weight_decay is 0 here, unlike the
# baseline optimizer, which uses 5e-4.
optimizer = optim.Adam(
    net_ortho.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
    amsgrad=False,
)

In [None]:
# Fine-tune with the orthogonality penalty for 5 epochs. After each epoch the
# test set is evaluated (checkpointing on improvement) and the per-layer
# diagonal ratios of the regularised network are printed.
for epoch in range(5):
    net_train_ortho(epoch)
    net_test_ortho(epoch)
    w_diag(net_ortho)

In [None]:
# Restore the best regularised weights and their accuracy from disk.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU machine also loads on a CPU-only one.
net_dict = torch.load('./ortho_checkpoint/ckpt.pth', map_location=device)
net_ortho.load_state_dict(net_dict['net'])
best_acc_ortho = net_dict['best_acc']