# main function for decomposition
### Author: Yiming Fang

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import StepLR

import torchvision
import torchvision.transforms as transforms
from torchvision import models

import tensorly as tl
import tensorly
from itertools import chain
from tensorly.decomposition import parafac, partial_tucker

import os
import matplotlib.pyplot as plt
import numpy as np
import time

from nets import *
from decomp import *

In [2]:
# load data
def load_mnist():
    """Build the MNIST training and test data loaders.

    The dataset is downloaded into ``./data`` when it is not already there.
    Both splits get the same preprocessing: conversion to a tensor followed
    by normalization with mean 0.1307 and std 0.3081.

    Returns:
        A ``(trainloader, testloader)`` tuple. The training loader shuffles
        and uses batches of 128; the test loader keeps dataset order and
        uses batches of 100.
    """
    print('==> Loading data..')

    def make_pipeline():
        # One fresh pipeline per split, identical for train and test.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    train_data = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=make_pipeline())
    train_batches = torch.utils.data.DataLoader(
        train_data, batch_size=128, shuffle=True, num_workers=0)

    test_data = torchvision.datasets.MNIST(
        root='./data', train=False, download=True, transform=make_pipeline())
    test_batches = torch.utils.data.DataLoader(
        test_data, batch_size=100, shuffle=False, num_workers=0)

    return train_batches, test_batches

def load_cifar10():
    """Build the CIFAR-10 training and test data loaders.

    The dataset is downloaded into ``./data`` when it is not already there.
    Training images are augmented (random 32x32 crop with 4 pixels of padding,
    random horizontal flip) before tensor conversion and per-channel
    normalization; test images are only converted and normalized.

    Returns:
        A ``(trainloader, testloader)`` tuple. The training loader shuffles
        and uses batches of 128; the test loader keeps dataset order and
        uses batches of 100.
    """
    print('==> Loading data..')

    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    augment_and_normalize = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    normalize_only = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    train_data = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True,
        transform=augment_and_normalize)
    train_batches = torch.utils.data.DataLoader(
        train_data, batch_size=128, shuffle=True, num_workers=0)

    test_data = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=normalize_only)
    test_batches = torch.utils.data.DataLoader(
        test_data, batch_size=100, shuffle=False, num_workers=0)

    return train_batches, test_batches

# ImageNet is no longer publicly available
def load_imagenet():
    """Build the ImageNet training and validation data loaders.

    ``torchvision.datasets.ImageNet`` selects the partition with the
    ``split`` keyword (``'train'`` or ``'val'``) and cannot download the
    data itself, so the official archives must already be present under
    ``./data``.

    Training images get a random resized crop to 224 and a random horizontal
    flip; validation images are resized to 256 and center-cropped to 224.
    Both are then converted to tensors and normalized per channel.

    Returns:
        A ``(trainloader, testloader)`` tuple. The training loader shuffles
        and uses batches of 128; the validation loader keeps dataset order
        and uses batches of 100.
    """
    print('==> Loading data..')
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    # ImageNet takes `split=...`; it has no `train=` / `download=` arguments.
    trainset = torchvision.datasets.ImageNet(root='./data', split='train', transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)

    testset = torchvision.datasets.ImageNet(root='./data', split='val', transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)

    return trainloader, testloader

def load_cifar100():
    """Build the CIFAR-100 training and test data loaders.

    The dataset is downloaded into ``./data`` when it is not already there.
    Training images are augmented (random 32x32 crop with 4 pixels of padding,
    random horizontal flip) before tensor conversion and per-channel
    normalization; test images are only converted and normalized.

    Returns:
        A ``(trainloader, testloader)`` tuple. The training loader shuffles
        and uses batches of 128; the test loader keeps dataset order and
        uses batches of 100.
    """
    print('==> Loading data..')

    # NOTE(review): these are the same constants load_cifar10 uses; confirm
    # whether CIFAR-100-specific statistics were intended.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    augment_and_normalize = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    normalize_only = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    train_data = torchvision.datasets.CIFAR100(
        root='./data', train=True, download=True,
        transform=augment_and_normalize)
    train_batches = torch.utils.data.DataLoader(
        train_data, batch_size=128, shuffle=True, num_workers=0)

    test_data = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True, transform=normalize_only)
    test_batches = torch.utils.data.DataLoader(
        test_data, batch_size=100, shuffle=False, num_workers=0)

    return train_batches, test_batches

In [3]:
# build model
def build(model, decomp='cp'):
    """Checkpoint the model, optionally decompose it, and reload it.

    The model is moved to the module-level ``device`` and saved to
    ``models/model``. When ``decomp`` is truthy, ``decompose`` rewrites that
    checkpoint in place. The checkpoint is then loaded back onto ``device``.

    Args:
        model: The network to build from.
        decomp: Decomposition name forwarded to ``decompose`` (``'cp'``,
            ``'tucker'`` or ``'tt'``); a falsy value skips decomposition.

    Returns:
        The network reloaded from ``models/model`` on ``device``.
    """
    print('==> Building model..')
    tl.set_backend('pytorch')
    # The checkpoint directory is not guaranteed to exist on a fresh checkout.
    os.makedirs('models', exist_ok=True)
    full_net = model.to(device)
    torch.save(full_net, 'models/model')
    if decomp:
        decompose(decomp)
    # Reload onto the selected device rather than calling .cuda(), so the
    # CPU fallback chosen by the caller actually works.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which rejects whole pickled modules - confirm
    # against the installed version.
    net = torch.load('models/model', map_location=device).to(device)
    print(net)
    print('==> Done')
    return net
    
# training
def train(epoch, train_acc, model):
    """Run one training epoch and record its accuracy.

    Reads the module-level ``trainloader``, ``optimizer`` and ``device``.
    Prints a progress bar (one ``=`` every 10 batches) and the epoch accuracy.

    Args:
        epoch: Epoch index, used only for the printed header.
        train_acc: List of per-epoch training accuracies; the new accuracy
            (a fraction in [0, 1]) is appended in place.
        model: The network to train.

    Returns:
        The same ``train_acc`` list, with this epoch's accuracy appended.
    """
    print('\nEpoch: ', epoch)
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    print('|', end='')
    for step, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = loss_fn(logits, targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        # Index 1 of max(dim=1) is the predicted class per sample.
        guesses = logits.max(1)[1]
        n_seen += targets.size(0)
        n_correct += guesses.eq(targets).sum().item()
        if not step % 10:
            print('=', end='')
    percent = 100. * n_correct / n_seen
    print('|', 'Accuracy:', percent, '% ', n_correct, '/', n_seen)
    train_acc.append(n_correct / n_seen)
    return train_acc

# testing
def test(test_acc, model):
    """Evaluate the model on the test set and record its accuracy.

    Reads the module-level ``testloader`` and ``device``. Runs with gradients
    disabled, printing a progress bar (one ``=`` every 10 batches) and the
    resulting accuracy.

    Args:
        test_acc: List of per-epoch test accuracies; the new accuracy
            (a fraction in [0, 1]) is appended in place.
        model: The network to evaluate.

    Returns:
        The same ``test_acc`` list, with this evaluation's accuracy appended.
    """
    model.eval()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    loss_fn = nn.CrossEntropyLoss()
    with torch.no_grad():
        print('|', end='')
        for step, (inputs, targets) in enumerate(testloader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            running_loss += loss_fn(logits, targets).item()

            # Index 1 of max(dim=1) is the predicted class per sample.
            guesses = logits.max(1)[1]
            n_seen += targets.size(0)
            n_correct += guesses.eq(targets).sum().item()
            if not step % 10:
                print('=', end='')
    percent = 100. * n_correct / n_seen
    print('|', 'Accuracy:', percent, '% ', n_correct, '/', n_seen)
    test_acc.append(n_correct / n_seen)
    return test_acc

# decompose
def decompose(decomp):
    """Replace the conv layers of the checkpointed model with decomposed ones.

    Loads ``models/model`` onto the CPU, walks ``model.features`` and swaps
    each ``Conv2d`` (except in the last two modules, which are left
    untouched) for the layer returned by the chosen decomposition helper,
    then writes the result back to ``models/model``.

    Args:
        decomp: ``'cp'``, ``'tucker'`` or ``'tt'``.

    Returns:
        The decomposed model (on the CPU, in eval mode).

    Raises:
        ValueError: If ``decomp`` is not one of the supported names.
    """
    if decomp not in ('cp', 'tucker', 'tt'):
        # Previously an unknown name silently left the model undecomposed.
        raise ValueError(
            "unknown decomposition {!r}; expected 'cp', 'tucker' or 'tt'".format(decomp))

    # Load straight onto the CPU; going through .cuda() first would require
    # a GPU for no benefit.
    model = torch.load("models/model", map_location='cpu')
    model.eval()
    model.cpu()

    layers = model.features._modules
    keys = list(layers.keys())
    # Skip the last two modules of `features`.
    for key in keys[:max(len(keys) - 2, 0)]:
        conv_layer = layers[key]
        if not isinstance(conv_layer, torch.nn.modules.conv.Conv2d):
            continue
        shape = tuple(conv_layer.weight.shape)
        # Rank is a tenth of the largest weight dimension, but never 0.
        rank = max(1, max(shape) // 10)
        if decomp == 'cp':
            layers[key] = cp_decomposition_conv_layer(conv_layer, rank)
        elif decomp == 'tucker':
            # A third (rounded up) of the first two weight dimensions.
            ranks = [int(np.ceil(shape[0] / 3)),
                     int(np.ceil(shape[1] / 3))]
            layers[key] = tucker_decomposition_conv_layer(conv_layer, ranks)
        else:
            layers[key] = tt_decomposition_conv_layer(conv_layer, rank)

    # Write the checkpoint once, after all layers have been replaced.
    torch.save(model, 'models/model')
    return model

# Run functions
def run_train(i, model):
    """Train and evaluate the model for ``i`` epochs.

    Each epoch calls ``train`` and ``test``, steps the module-level
    ``scheduler``, and prints the epoch duration and current learning rate.

    Args:
        i: Number of epochs to run.
        model: The network to train.

    Returns:
        A ``(train_acc, test_acc)`` tuple of per-epoch accuracy lists
        (fractions in [0, 1]); both are empty when ``i`` is 0.
    """
    train_acc = []
    test_acc = []
    for epoch in range(i):
        started = time.time()
        train_acc = train(epoch, train_acc, model)
        test_acc = test(test_acc, model)
        scheduler.step()
        print('This epoch took', time.time() - started, 'seconds')
        # get_last_lr() reports the rate the scheduler last set; get_lr() is
        # not meant to be called outside step() and can misreport it.
        print('Current learning rate: ', scheduler.get_last_lr()[0])
    if test_acc:
        # The value is the best *test* accuracy; guard against zero epochs.
        print('Best test accuracy overall: ', max(test_acc))
    return train_acc, test_acc

In [4]:
# main function
def run_all(dataset, decomp=None, iterations=100, rate=0.05):
    """Run a full experiment: load data, build, train, and save results.

    Sets the module-level ``trainloader``, ``testloader``, ``device``,
    ``optimizer`` and ``scheduler`` that ``train``, ``test`` and
    ``run_train`` read. The trained network is saved to
    ``models/<dataset>_<decomp>`` and the accuracy curves to
    ``curves/<dataset>_<decomp>_train`` / ``_test`` (``decomp`` is written
    as ``full`` when no decomposition is used).

    Args:
        dataset: ``'mnist'``, ``'cifar10'`` or ``'cifar100'``.
        decomp: Decomposition name passed to ``build`` (``'cp'``,
            ``'tucker'`` or ``'tt'``), or ``None`` for the full model.
        iterations: Number of training epochs.
        rate: Initial SGD learning rate.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    global trainloader, testloader, device, optimizer, scheduler

    # choose dataset from (MNIST, CIFAR10, CIFAR100)
    if dataset == 'mnist':
        trainloader, testloader = load_mnist()
        model = Net()
    elif dataset == 'cifar10':
        trainloader, testloader = load_cifar10()
        model = VGG('VGG19')
    elif dataset == 'cifar100':
        trainloader, testloader = load_cifar100()
        model = VGG('VGG19')
    else:
        # Fail with a clear message instead of a NameError on `model` later.
        raise ValueError(
            "unknown dataset {!r}; expected 'mnist', 'cifar10' or 'cifar100'".format(dataset))

    # check GPU availability
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Output directories are not guaranteed to exist on a fresh checkout.
    os.makedirs('models', exist_ok=True)
    os.makedirs('curves', exist_ok=True)

    # choose decomposition algorithm from (CP, Tucker, TT)
    net = build(model, decomp)
    optimizer = optim.SGD(net.parameters(), lr=rate, momentum=0.9, weight_decay=5e-4)
    # Multiply the learning rate by 0.9 every 5 epochs.
    scheduler = StepLR(optimizer, step_size=5, gamma=0.9)
    train_acc, test_acc = run_train(iterations, net)

    if not decomp:
        decomp = 'full'

    filename = dataset + '_' + decomp
    torch.save(net, 'models/' + filename)
    np.save('curves/' + filename + '_train', train_acc)
    np.save('curves/' + filename + '_test', test_acc)

In [None]:
%%time
# Baseline: the full (undecomposed) MNIST network with the default
# 100 iterations and learning rate 0.05.
run_all('mnist')

==> Loading data..
==> Building model..
Net(
  (features): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.25, inplace=False)
    (1): Linear(in_features=9216, out_features=128, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=128, out_features=10, bias=True)
  )
)
==> Done

Epoch:  0
This epoch took 13.191731214523315 seconds
Current learning rate:  0.05

Epoch:  1
This epoch took 12.04975700378418 seconds
Current learning rate:  0.05

Epoch:  2
This epoch took 12.05876111984253 seconds
Current learning rate:  0.05

Epoch:  3
This epoch took 12.11062240600586 seconds
Current learning rate:  0.05

Epoch:  4
This epoch took 12.165474891662598 seconds
Current lea



In [None]:
%%time
# CP-decomposed MNIST network, with a learning rate of 0.01 instead of
# the 0.05 default.
run_all('mnist', 'cp', rate=0.01)

In [None]:
%%time
# Tucker-decomposed MNIST network with default iterations and learning rate.
run_all('mnist', 'tucker')

In [None]:
%%time
# 'tt'-decomposed MNIST network with default iterations and learning rate.
run_all('mnist', 'tt')

In [None]:
%%time
# Baseline: the full (undecomposed) VGG19 on CIFAR-10 for 200 iterations.
run_all('cifar10', iterations=200)

In [None]:
%%time
# Tucker-decomposed VGG19 on CIFAR-10 for 200 iterations.
run_all('cifar10', 'tucker', iterations=200)

In [None]:
%%time
# 'tt'-decomposed VGG19 on CIFAR-100 for 200 iterations.
run_all('cifar100', 'tt', iterations=200)