# Dimensionality Reduction of VGG16
In this tutorial we show how to create a reduced version of VGG16 using the techniques described in the article "A Dimensionality Reduction Approach for Convolutional Neural Networks", Meneghetti L., Demo N., Rozza G., https://arxiv.org/abs/2110.09163 (2021).

In [1]:
import torch
import numpy as np
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import pandas as pd

## Loading of the model
First of all, we need to load the model we want to reduce (in this case VGG16) from a checkpoint file, i.e. a file containing the state of the model after it has been trained on a chosen dataset. Here we use the CIFAR10 dataset, but we also show how to generalize the procedure to a custom dataset.

It is important to highlight that the VGG models provided by PyTorch (https://pytorch.org/hub/pytorch_vision_vgg/), e.g.
```
model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16', pretrained=True)
```
are pre-trained on the ImageNet dataset, whose images are fed to the network at a resolution of 224x224. Therefore, in order to use datasets like CIFAR10, composed of 32x32 images, we need to change the architecture of the VGG nets, as was done in the file 'vgg.py'.
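To see why the architecture has to change, it helps to follow the spatial size of the feature maps through the network. The plain-Python sketch below assumes the standard VGG16 layer configuration (thirteen 3x3 convolutions with padding 1 and five 2x2 max-pooling layers with stride 2); the names `VGG16_CFG` and `feature_map_size` are illustrative and are not part of 'vgg.py'.

```python
# Standard VGG16 configuration (assumed): integers are the output channels
# of 3x3 convolutions with padding 1, 'M' is a 2x2 max-pool with stride 2.
VGG16_CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
             512, 512, 512, 'M', 512, 512, 512, 'M']


def feature_map_size(input_size, cfg=VGG16_CFG):
    """Return (channels, side) of the last feature map for a square RGB input."""
    side, channels = input_size, 3
    for layer in cfg:
        if layer == 'M':
            side //= 2          # each max-pool halves height and width
        else:
            channels = layer    # 3x3 conv with padding 1 keeps the spatial size
    return channels, side


for size in (224, 32):
    c, s = feature_map_size(size)
    print(f'{size}x{size} input -> {c}x{s}x{s} = {c * s * s} features')
```

With 224x224 inputs the last feature map holds 512x7x7 = 25088 values, which is the input size of the first fully connected layer of the ImageNet classifier; with 32x32 inputs only 512x1x1 = 512 values are left, which is why a different classifier is needed for CIFAR10.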

In [11]:
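import sys
# Path to a local clone of the Smithers repository: change it to match your system
sys.path.insert(0, '/scratch/lmeneghe/Smithers')
from smithers.ml.vgg import VGG

# Use the GPU if available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoint of VGG16 trained for 60 epochs on CIFAR10:
# the file must be present in the working directory
pretrained = 'checkpoint_vgg16_cifar10_60epochs.pth.tar'
model = VGG(None, classifier='cifar', num_classes=10, init_weights=False, pretrain_weights=pretrained)
model = model.to(device)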


## Loading of the dataset
### CIFAR10 Dataset
As stated before, we use the CIFAR10 dataset (available in torchvision) to test our technique. It is a computer-vision dataset used for object recognition, consisting of 60000 32 × 32 colour images divided into 10 mutually exclusive classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck. The images are split into 50000 training images and 10000 test images.

See https://www.cs.toronto.edu/~kriz/cifar.html for more details on this dataset and on how to download it.

In [None]:
# Load the CIFAR10 dataset for training and testing
batch_size = 8  # this can be changed
data_path = 'datasets/'
# Transform functions: each takes a PIL image as input and applies
# the following transformations in sequence
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = datasets.CIFAR10(root=data_path + 'CIFAR10/',
                                 train=True,
                                 download=True,
                                 transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_dataset = datasets.CIFAR10(root=data_path + 'CIFAR10/',
                                train=False,
                                download=True,
                                transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)  # no shuffling needed for evaluation
train_labels = torch.tensor(train_loader.dataset.targets)
targets = list(train_labels)
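Both `transforms` pipelines end with `ToTensor()` and `Normalize(mean, std)`. As a sketch of what these two steps compute, the NumPy function below rescales an 8-bit image to [0, 1], moves the channel axis first, then subtracts the per-channel mean and divides by the per-channel standard deviation, using the CIFAR10 statistics from the cell above. The function name is hypothetical and the code is an illustration, not a replacement for the torchvision transforms.

```python
import numpy as np

# Per-channel statistics used in the transforms above
CIFAR10_MEAN = np.array([0.4914, 0.4822, 0.4465])
CIFAR10_STD = np.array([0.2023, 0.1994, 0.2010])


def to_tensor_and_normalize(img_uint8):
    """Mimic ToTensor() + Normalize(mean, std) on an HxWx3 uint8 image."""
    x = img_uint8.astype(np.float32) / 255.0   # ToTensor: scale to [0, 1]
    x = np.transpose(x, (2, 0, 1))             # ToTensor: HWC -> CHW
    # Normalize: (x - mean) / std per channel, broadcast over H and W
    return (x - CIFAR10_MEAN[:, None, None]) / CIFAR10_STD[:, None, None]


img = np.full((32, 32, 3), 255, dtype=np.uint8)  # a white 32x32 image
out = to_tensor_and_normalize(img)
print(out.shape)      # (3, 32, 32)
print(out[:, 0, 0])   # normalized value of a white pixel in each channel
```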

### Custom dataset
If we want to use a custom dataset, we first need to construct it, following for example the tutorial on the construction of a custom dataset for the problem of image recognition. The previous cell is then replaced by the following one.

In [None]:
from torch.utils.data.sampler import SubsetRandomSampler

# Load the custom dataset for training and testing
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

data = pd.read_csv('dataset_imagerec/dataframe.csv')
data_path = 'dataset_imagerec/'
# SPLIT OF THE DATASET
batch_size = 128
validation_split = .2
shuffle_dataset = True
random_seed = 42

dataset_size = len(data)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
print('train data', len(train_indices))
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
resize_dim = [32, 32]

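# Imagerec_Dataset is the class built in the custom-dataset tutorial:
# it must be defined (or imported) before running this cell
dataset_imagerec = Imagerec_Dataset(data, data_path, resize_dim, transform)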
train_dataset = dataset_imagerec.getdata(train_indices)
train_loader = torch.utils.data.DataLoader(dataset_imagerec,
                                           batch_size=batch_size,
                                           sampler=train_sampler)
test_loader = torch.utils.data.DataLoader(dataset_imagerec,
                                          batch_size=batch_size,
                                          sampler=valid_sampler)

classes = ('class_1', 'class_2', 'class_3', 'class_4')  # replace with the class names of your dataset
n_class = len(classes)
targets = list(dataset_imagerec.targets[train_indices])
train_labels = torch.tensor(targets)
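The split above shuffles the indices once with a fixed seed and reserves the first `validation_split` fraction for validation; `SubsetRandomSampler` then reshuffles the samples within each subset at every epoch. A minimal sketch of the same split as a reusable function is shown below. `split_indices` is a hypothetical helper, and since it uses NumPy's `default_rng` instead of `np.random.seed`, it produces a different permutation from the cell above.

```python
import numpy as np


def split_indices(dataset_size, validation_split=0.2, shuffle=True, seed=42):
    """Return disjoint (train_indices, val_indices) lists covering range(dataset_size)."""
    indices = np.arange(dataset_size)
    if shuffle:
        # A local Generator leaves NumPy's global random state untouched
        rng = np.random.default_rng(seed)
        rng.shuffle(indices)
    split = int(np.floor(validation_split * dataset_size))
    return indices[split:].tolist(), indices[:split].tolist()


train_idx, val_idx = split_indices(1000)
print(len(train_idx), len(val_idx))  # 800 200
```

Fixing the seed makes the train/validation partition reproducible across runs, which matters when comparing the full and the reduced network on the same validation images.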

## Reduction of VGG16
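We now perform the reduction of VGG16 using the NetAdapter module. In this case we use 5 as the cut-off index, i.e. the index of the layer at which the original network is cut (the layers before the cut are kept as the pre-model, while the remaining ones are replaced by the reduction layer and the input-output mapping), and 50 as the dimension of the reduced space. For the reduction method and for the input-output mapping there are two choices each: 'POD' (Proper Orthogonal Decomposition) or 'AS' (Active Subspaces) for the former, and 'PCE' (Polynomial Chaos Expansion) or 'FNN' (Feedforward Neural Network) for the latter.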
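Both reduction methods replace the high-dimensional output of the cut-off layer with `red_dim` coordinates obtained through a linear projection. The following NumPy sketch illustrates only the linear algebra, under our own assumptions: POD as a truncated SVD of a snapshot matrix whose columns are the flattened intermediate outputs, and AS as the leading eigenvectors of the uncentered covariance of the gradients of a scalar output with respect to those features. The data are random stand-ins and the function names are hypothetical; this is not the NetAdapter implementation.

```python
import numpy as np


def pod_basis(snapshots, r):
    """POD: snapshots has shape (n_features, n_samples), one flattened
    intermediate output per column. Returns the first r left singular
    vectors (an orthonormal basis) and all singular values."""
    u, s, _ = np.linalg.svd(snapshots, full_matrices=False)
    return u[:, :r], s


def as_basis(gradients, r):
    """Active subspace: gradients has shape (n_samples, n_features), one
    gradient of a scalar output w.r.t. the intermediate features per row.
    Returns the r leading eigenvectors of C = (1/N) * sum_i g_i g_i^T."""
    cov = gradients.T @ gradients / gradients.shape[0]
    eigvals, eigvecs = np.linalg.eigh(cov)   # eigh: ascending eigenvalues
    order = np.argsort(eigvals)[::-1]        # reorder in descending order
    return eigvecs[:, order[:r]], eigvals[order]


rng = np.random.default_rng(0)
n_features, n_samples, r = 200, 100, 50      # r plays the role of red_dim

# Synthetic stand-ins for the flattened pre-model outputs and their gradients
features = rng.standard_normal((n_features, n_samples))
gradients = rng.standard_normal((n_samples, n_features))

pod_vectors, sing_vals = pod_basis(features, r)
as_vectors, eig_vals = as_basis(gradients, r)

# Both methods map a feature vector x to r coordinates V^T x
reduced_pod = pod_vectors.T @ features       # shape (r, n_samples)
reduced_as = as_vectors.T @ features         # shape (r, n_samples)
print(reduced_pod.shape, reduced_as.shape)   # (50, 100) (50, 100)
```

POD keeps the directions along which the intermediate outputs vary most, while AS keeps the directions along which the output of the network is most sensitive; in both cases the basis is orthonormal, so the projection is a single matrix product.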

In [None]:
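from smithers.ml.netadapter import NetAdapter

cutoff_idx = 5        # index of the layer at which VGG16 is cut
red_dim = 50          # dimension of the reduced space
red_method = 'POD'    # reduction method: 'POD' or 'AS'
inout_method = 'FNN'  # input-output mapping: 'PCE' or 'FNN'
netadapter = NetAdapter(cutoff_idx, red_dim, red_method, inout_method)
rednet = netadapter(model)  # reduced net, referred to as `rednet` in the next cell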
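For the 'PCE' option, the input-output mapping expresses each output of the network as a polynomial in the reduced coordinates. A minimal sketch, assuming a total-degree basis of probabilists' Hermite polynomials (a natural choice for standard Gaussian inputs) and coefficients fitted by least squares, is given below. The helper names are hypothetical, and the degree, basis, and fitting procedure used by Smithers may differ.

```python
import itertools

import numpy as np
from numpy.polynomial.hermite_e import hermeval


def multi_indices(dim, degree):
    """Total-degree multi-indices alpha (|alpha| <= degree) for dim variables."""
    indices = []
    for d in range(degree + 1):
        for combo in itertools.combinations_with_replacement(range(dim), d):
            alpha = [0] * dim
            for var in combo:
                alpha[var] += 1
            indices.append(tuple(alpha))
    return indices


def design_matrix(z, alphas):
    """Evaluate products of Hermite polynomials He_k at the reduced
    coordinates z (n_samples, dim): one column per multi-index."""
    columns = []
    for alpha in alphas:
        col = np.ones(z.shape[0])
        for var, deg in enumerate(alpha):
            if deg > 0:
                coeffs = np.zeros(deg + 1)
                coeffs[deg] = 1.0               # selects He_deg
                col = col * hermeval(z[:, var], coeffs)
        columns.append(col)
    return np.stack(columns, axis=1)


def fit_pce(z, y, degree=2):
    """Least-squares PCE coefficients; y has shape (n_samples, n_outputs)."""
    alphas = multi_indices(z.shape[1], degree)
    coeffs, *_ = np.linalg.lstsq(design_matrix(z, alphas), y, rcond=None)
    return alphas, coeffs


def predict_pce(z, alphas, coeffs):
    return design_matrix(z, alphas) @ coeffs


# Toy check: outputs that are exactly degree-2 polynomials of z
rng = np.random.default_rng(0)
z = rng.standard_normal((200, 3))               # 3 reduced coordinates
y = np.stack([1 + 2 * z[:, 0] + (z[:, 1] ** 2 - 1),
              z[:, 0] * z[:, 2]], axis=1)        # 2 outputs (e.g. class scores)
alphas, coeffs = fit_pce(z, y, degree=2)
print(len(alphas))                               # 10
print(np.allclose(predict_pce(z, alphas, coeffs), y))  # True
```

The number of basis terms grows quickly with the number of reduced coordinates: with `red_dim = 50` and degree 2 there are already 1 + 50 + 1275 = 1326 terms, each with one coefficient per output.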

In [None]:
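# Total_param and Total_flops (parameter and FLOP counters) are helper
# functions assumed to be already defined or imported in the notebook
rednet = rednet.to(device)
rednet_storage = torch.zeros(3)
rednet_flops = torch.zeros(3)

# Accuracy of the reduced net on the test set, before retraining
rednet.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        predicted = rednet(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

rednet_storage[0], rednet_storage[1], rednet_storage[2] = [
    Total_param(rednet.premodel),
    Total_param(rednet.POD_model),
    Total_param(rednet.ANN)]

rednet_flops[0], rednet_flops[1], rednet_flops[2] = [
    Total_flops(rednet.premodel, device),
    Total_flops(rednet.POD_model, device),
    Total_flops(rednet.ANN, device)]
print(
    'acc: {:.2f} Pre nnz = {:.2f}, POD_model nnz={:.2f}, ANN nnz={:.4f}'.format(
        100 * correct / total, rednet_storage[0], rednet_storage[1],
        rednet_storage[2]))
print(
    'flops:  Pre = {:.2f}, POD_model = {:.2f}, ANN = {:.2f}'.format(
        rednet_flops[0], rednet_flops[1], rednet_flops[2]))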

optimizer = torch.optim.Adam([{
            'params': rednet.premodel.parameters(),
            'lr': 1e-4
            }, {
            'params': rednet.POD_model.parameters(),
            'lr': 1e-5
            }, {
            'params': rednet.ANN.parameters(),
            'lr': 1e-5
            }])

import os
from time import time

# compute_loss and train_kd are helper functions assumed to be already
# defined or imported; train_max_batch must also be set beforehand
epochs = 10
filename = './cifar10_VGG16_RedNet' + \
    '_epoch_%d_cutID_%d.pth' % (epochs, cutoff_idx)

if os.path.isfile(filename):
    # Load the reduced net and the loss history saved by a previous run
    [rednet_pretrained, train_loss, test_loss] = torch.load(filename)
    rednet.load_state_dict(rednet_pretrained)
    print('rednet trained for {} epochs is loaded'.format(epochs))
else:
    start_time_REDNET = time()
    train_loss = []
    test_loss = []
    train_loss.append(compute_loss(rednet, device, train_loader))
    test_loss.append(compute_loss(rednet, device, test_loader))
    for epoch in range(1, epochs + 1):
        print('EPOCH {}'.format(epoch))
        # Retrain with knowledge distillation (train_kd), with the
        # original VGG16 (model) presumably acting as the teacher
        train_loss.append(
            train_kd(rednet,
                     model,
                     device,
                     train_loader,
                     optimizer,
                     train_max_batch,
                     alpha=0.1))
        test_loss.append(compute_loss(rednet, device, test_loader))
        # Save a checkpoint after every epoch
        torch.save([rednet.state_dict(), train_loss, test_loss], filename)
    end_time_REDNET = time()
    time_REDNET = end_time_REDNET - start_time_REDNET
    print('Time needed to compute and train the reduced net', time_REDNET)


# Number of parameters and FLOPs of the three blocks of the retrained net
rednet_storage[0], rednet_storage[1], rednet_storage[2] = [
    Total_param(rednet.premodel),
    Total_param(rednet.POD_model),
    Total_param(rednet.ANN)
]

rednet_flops[0], rednet_flops[1], rednet_flops[2] = [
    Total_flops(rednet.premodel, device),
    Total_flops(rednet.POD_model, device),
    Total_flops(rednet.ANN, device)
]
print(
    'test loss: {:.4f} Pre nnz = {:.2f}, POD_model nnz={:.2f}, ANN nnz={:.4f}'.format(
        test_loss[-1], rednet_storage[0], rednet_storage[1],
        rednet_storage[2]))

print(
    'flops:  Pre = {:.2f}, POD_model = {:.2f}, ANN = {:.2f}'.format(
        rednet_flops[0], rednet_flops[1], rednet_flops[2]))
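The retraining above calls `train_kd` with `alpha=0.1`. As a sketch of what a knowledge-distillation objective typically looks like, the NumPy code below implements the common formulation attributed to Hinton et al., which mixes a temperature-softened KL divergence between teacher and student outputs with the usual cross-entropy on the true labels. The temperature `T = 4.0`, the function names, and which of the two terms `alpha` weights are our assumptions; `train_kd` may follow a different convention.

```python
import numpy as np


def log_softmax(logits, T=1.0):
    """Row-wise log-softmax of logits / T, shifted for numerical stability."""
    z = logits / T
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def kd_loss(student_logits, teacher_logits, labels, alpha=0.1, T=4.0):
    """alpha * T^2 * KL(teacher || student) on softened outputs
    plus (1 - alpha) * cross-entropy of the student on the true labels."""
    log_p_s = log_softmax(student_logits, T)
    log_p_t = log_softmax(teacher_logits, T)
    kl = np.sum(np.exp(log_p_t) * (log_p_t - log_p_s), axis=1).mean()
    ce = -log_softmax(student_logits)[np.arange(len(labels)), labels].mean()
    return alpha * T ** 2 * kl + (1 - alpha) * ce


rng = np.random.default_rng(0)
teacher = rng.standard_normal((8, 10))   # batch of 8, 10 CIFAR10 classes
student = rng.standard_normal((8, 10))
labels = rng.integers(0, 10, size=8)
print(kd_loss(student, teacher, labels))
```

The factor `T ** 2` keeps the gradient magnitude of the soft term roughly independent of the temperature; when the student reproduces the teacher exactly, the KL term vanishes and only the weighted cross-entropy remains.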