# VGG-16 on CIFAR-10
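This notebook experiments with compressing a VGG-16 model trained on the CIFAR-10 dataset, using structured pruning and static quantization. It compares each compressed model against the uncompressed baseline.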

## Setup

In [1]:
# Automatically reload edited modules (e.g. the project's src/ package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import sys

# Make the project's `src` package importable when the notebook runs from its own
# directory (presumably two levels below the repo root; adjust if the notebook moves).
sys.path.append('../../')
sys.path.append('../../src/')

# NOTE: importing these project modules has side effects. The recorded output of this
# cell shows datasets (including a cached imagenet-1k) being prepared at import time.
import src.general as general
import src.dataset_models as data
import src.metrics as metrics
# NOTE(review): the alias `eval` shadows Python's builtin eval() for the whole notebook;
# a name such as `evaluation` would be safer (every `eval.` call site would need updating).
import src.evaluation as eval
import src.plot as plot
import src.compression.distillation as distill
import src.compression.pruning as prune
import src.compression.quantization as quant

Files already downloaded and verified
Files already downloaded and verified


Found cached dataset imagenet-1k (/workspace/volume/cache/imagenet-1k/default-212aff79ee65f848/1.0.0/a1e9bfc56c3a7350165007d1176b15e9128fcaf9ab972147840529aed3ae52bc)


  0%|          | 0/3 [00:00<?, ?it/s]

Files already downloaded and verified
Files already downloaded and verified


Found cached dataset imagenet-1k (/workspace/volume/cache/imagenet-1k/default-212aff79ee65f848/1.0.0/a1e9bfc56c3a7350165007d1176b15e9128fcaf9ab972147840529aed3ae52bc)


  0%|          | 0/3 [00:00<?, ?it/s]

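Load the CIFAR-10 dataset and attach the VGG-16 input transforms: resize to 224 px, light augmentation, and ImageNet normalization.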

In [3]:
# CIFAR-10 wrapper taken from the project's registry of supported datasets.
dataset = data.supported_datasets["CIFAR-10"]

# VGG-16 expects 224 px inputs normalized with ImageNet statistics, so the small
# CIFAR-10 images are upscaled first. The flip and padded crop add light augmentation.
# NOTE(review): if set_transforms also applies to the test split, the random flip/crop
# make evaluation non-deterministic. Confirm against src.dataset_models.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
transform_steps = [
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
vgg_cifar10_transform = transforms.Compose(transform_steps)

# Attach the pipeline to the dataset.
dataset.set_transforms(vgg_cifar10_transform)

Load the VGG-16 model already trained on CIFAR-10.

In [4]:
# Load the trained VGG-16. The file holds a whole pickled nn.Module (it is used directly
# as a model below), so it must be unpickled with weights_only=False. PyTorch >= 2.6
# defaults to weights_only=True and would refuse to load it.
# NOTE: weights_only=False executes arbitrary pickle code; only load trusted files.
vgg16 = torch.load("/workspace/volume/models/vgg16_cifar10.pt", weights_only=False)

Evaluate the uncompressed model to establish a baseline for the comparisons below.

In [6]:
# Baseline metrics of the uncompressed model. Judging by the recorded output, these cover
# loss, score, time per data point, model size, and parameter/MAC counts.
# `before_results` is reused later for the before/after comparisons.
before_results = eval.get_results(vgg16, dataset)
plot.print_results(**before_results)

Test: 100%|██████████| 157/157 [00:25<00:00,  6.12it/s]


Test loss: 0.2119
Test score: 93.9889
Could not calculate FLOPS
Loss: 0.211898
Score: 93.988854
Time per data point: 10.2065 ms
Model Size: 512.33 MB
Number of parameters: 134301514
Number of MACs: 15499459200


## Pruning
This section experiments with structured magnitude pruning of the classifier layers, followed by fine-tuning.

In [10]:
# NOTE(review): in the recorded run this call raised
# "AttributeError: 'Conv2d' object has no attribute 'split'". Module fusion also belongs to
# the Quantization section, where it is repeated. Consider removing this cell or fixing the call.
quant.fuse_modules(vgg16)

AttributeError: 'Conv2d' object has no attribute 'split'

In [None]:
# Move the model to the device chosen by the project helper before pruning.
device = general.get_device()
vgg16 = vgg16.to(device)  # nn.Module.to moves parameters in place and returns the same module

# Structured magnitude pruning restricted to the classifier head, with fineTune disabled
# and iterative_steps=10. The third argument (0.5) is presumably the target pruning ratio;
# confirm against src.compression.pruning.
# NOTE(review): this may modify vgg16 in place. The Quantization section reloads a fresh copy.
pruned_model = prune.magnitude_pruning_structured(
    vgg16,
    dataset,
    0.5,
    fineTune=False,
    iterative_steps=10,
    layers=vgg16.classifier,
)

In [None]:
# Metrics of the pruned model, printed side by side with the uncompressed baseline.
after_results = eval.get_results(pruned_model, dataset)
plot.print_before_after_results(before_results, after_results)

In [None]:
# Fine-tune the pruned model, presumably to recover accuracy lost to pruning.
# NOTE(review): target=99 is well above the ~94 baseline score measured earlier, so this
# presumably always runs the full max_it=10 iterations. Confirm the intended target.
general.finetune(pruned_model, dataset, target=99, max_it=10)

In [None]:
# Persist the pruned and fine-tuned model as a whole pickled module.
# NOTE(review): the "_10" suffix presumably refers to iterative_steps=10, but the pruning
# call used 0.5 and a later cell saves `pruned_model` again as "..._pruned_50.pt".
# Confirm which file name is intended.
torch.save(pruned_model, "/workspace/volume/models/vgg16_cifar10_pruned_10.pt")

In [None]:
# Load a previously saved pruned VGG-16. The file holds a whole pickled nn.Module, so it must
# be unpickled with weights_only=False (PyTorch >= 2.6 defaults to weights_only=True).
# NOTE: weights_only=False executes arbitrary pickle code; only load trusted files.
old_pruned_model = torch.load("/workspace/volume/models/vgg16_cifar10_pruned_v.pt", weights_only=False)

In [None]:
# Metrics of the previously saved pruned model.
old_results = eval.get_results(old_pruned_model, dataset)
plot.print_results(**old_results)

In [None]:
# Train the previously saved pruned model further. Only the model and dataset are passed,
# so all other training settings come from general.train's defaults.
general.train(old_pruned_model, dataset)

In [None]:
# NOTE(review): this saves `pruned_model`, not `old_pruned_model`, which the previous cell
# just trained. Confirm which model is meant to be stored as "..._pruned_50.pt".
torch.save(pruned_model, "/workspace/volume/models/vgg16_cifar10_pruned_50.pt")

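## Quantization
This section experiments with static quantization, starting from a freshly loaded copy of the model.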

In [None]:
# Load a fresh copy of the uncompressed model, since the pruning section may have modified
# `vgg16` in place. The file holds a whole pickled nn.Module, so it must be unpickled with
# weights_only=False (PyTorch >= 2.6 defaults to weights_only=True).
# NOTE: weights_only=False executes arbitrary pickle code; only load trusted files.
vgg16 = torch.load("/workspace/volume/models/vgg16_cifar10.pt", weights_only=False)

In [None]:
# Static quantization of the freshly loaded model. The dataset is passed presumably for
# calibration; confirm against src.compression.quantization.
quantized_model = quant.static_quantization(vgg16, dataset)

In [None]:
# Sample input taken from the training loader (presumably one batch) for the fusion smoke test.
example_inputs = general.get_example_inputs(dataset.train_loader)

In [None]:
# Smoke-test module fusion: fuse the eval-mode model and push the example input through it.
# NOTE(review): quant.fuse_modules raised "AttributeError: 'Conv2d' object has no attribute
# 'split'" when called earlier in this notebook. Confirm it works before relying on this cell.
# NOTE(review): torch.load restores tensors to the device they were saved from. If vgg16 ends
# up on a GPU while example_inputs are on CPU, the forward pass will fail.
vgg16.eval()
fused_model = quant.fuse_modules(vgg16)
with torch.no_grad():  # inference-only check: don't build an autograd graph
    fused_output = fused_model(example_inputs)
fused_output.shape  # show the output shape instead of dumping the whole output tensor

In [None]:
# Evaluate the quantized model on CPU, presumably because PyTorch's eager-mode quantized
# kernels (fbgemm/qnnpack backends) only run on CPU.
# NOTE: this rebinds the notebook-wide `device` variable also set in the pruning section.
device = 'cpu'
quantized_results = eval.get_results(quantized_model, dataset, device)
plot.print_results(**quantized_results)

In [None]:
# Compare the quantized model against the uncompressed baseline.
# NOTE(review): the baseline call passed no device while the quantized run was forced to
# 'cpu', so per-data-point timings may not be directly comparable.
plot.print_before_after_results(before_results, quantized_results)

In [None]:
# Persist the quantized model. It is produced by quant.static_quantization, so the file
# name says "static" to match the method actually used.
torch.save(quantized_model, "/workspace/volume/models/vgg16_cifar10_static_quantized.pt")