# VGG-16 on CIFAR-10
This notebook experiments with compressing VGG-16 on the CIFAR-10 dataset using pruning, knowledge distillation, and quantization.

## Setup

In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to the project's `src`
# package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import sys

sys.path.append('../../')
sys.path.append('../../src/')

import src.general as general
import src.dataset_models as data
import src.metrics as metrics
import src.evaluation as eval
import src.plot as plot
import src.compression.distillation as distill
import src.compression.pruning as prune
import src.compression.quantization as quant

Files already downloaded and verified
Files already downloaded and verified


Found cached dataset imagenet-1k (/workspace/volume/cache/imagenet-1k/default-212aff79ee65f848/1.0.0/a1e9bfc56c3a7350165007d1176b15e9128fcaf9ab972147840529aed3ae52bc)


  0%|          | 0/3 [00:00<?, ?it/s]

Files already downloaded and verified
Files already downloaded and verified


Found cached dataset imagenet-1k (/workspace/volume/cache/imagenet-1k/default-212aff79ee65f848/1.0.0/a1e9bfc56c3a7350165007d1176b15e9128fcaf9ab972147840529aed3ae52bc)


  0%|          | 0/3 [00:00<?, ?it/s]

Load the dataset and attach the preprocessing transforms.

In [3]:
# Preprocessing pipeline for feeding CIFAR-10 images to VGG-16:
# resize to 224, random horizontal flip, random 224 crop with 4 px padding,
# tensor conversion, then per-channel normalisation with the mean/std below.
# NOTE(review): the pipeline contains random augmentations and is installed
# on the dataset object as a whole. If `set_transforms` also applies it to
# the test split, evaluation is not deterministic (the same checkpoint
# scores 93.99 and 94.06 in two evaluations in this notebook) — confirm
# and consider a separate, deterministic evaluation transform.
vgg_cifar10_transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(224, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Look up the project's CIFAR-10 dataset wrapper and install the pipeline.
dataset = data.supported_datasets["CIFAR-10"]
dataset.set_transforms(vgg_cifar10_transform)

Load the model.

In [4]:
# Load the VGG-16 CIFAR-10 checkpoint. The file holds a whole pickled model
# object (later cells call `.to(...)` and read `.classifier` on it), not a
# state_dict. Unpickling can execute arbitrary code, so only load
# checkpoints from a trusted source.
# NOTE(review): on PyTorch releases where `torch.load` defaults to
# `weights_only=True`, loading a full model object requires passing
# `weights_only=False` — confirm against the installed version.
vgg16 = torch.load("/workspace/volume/models/vgg16_cifar10.pt")

Evaluate before compression.

In [5]:
# Baseline metrics for the model before any compression. The results are
# kept in `before_results` because the Quantization section prints a
# before/after comparison against them.
# Note: `eval` here is the project's `src.evaluation` module (import alias),
# not the Python builtin.
before_results = eval.get_results(vgg16, dataset)
plot.print_results(**before_results)

Test: 100%|██████████| 157/157 [00:26<00:00,  5.97it/s]


Test loss: 0.2099
Test score: 93.9889
Loss: 0.209863
Score: 93.988854
Time per data point: 10.4772 ms
Model Size: 512.33 MB
Number of parameters: 134301514
Number of MACs: 15499459200


## Pruning
This section is used to experiment with pruning.

In [None]:
# Run the project's module-fusion helper on the loaded model.
# The return value is discarded, so this presumably modifies `vgg16` in
# place — TODO confirm in src.compression.quantization.
# NOTE(review): this is a quantization helper invoked at the start of the
# Pruning section; confirm that fusing before pruning is intended.
quant.fuse_modules(vgg16)

In [None]:
# Move the model to the device chosen by the project helper, then run
# channel pruning with fine-tuning enabled over 10 iterative steps.
# The positional 0.5 is presumably the pruning ratio, and `layers` receives
# the model's classifier head — TODO confirm both against
# src.compression.pruning.channel_pruning.
device = general.get_device()
vgg16.to(device)
pruned_model = prune.channel_pruning(
    vgg16,
    dataset,
    0.5,
    fineTune=True,
    iterative_steps=10,
    layers=vgg16.classifier,
)

In [None]:
# Persist the pruned model as a whole pickled object. The Distillation
# section loads this same file as the student model.
torch.save(pruned_model, "/workspace/volume/models/vgg16_cifar10_pruned_50.pt")

In [None]:
# Reload the pruned checkpoint from disk, rebinding `pruned_model`
# (presumably so the cells below can run without repeating the pruning
# step). Whole-model pickle: only load files from a trusted source.
pruned_model = torch.load("/workspace/volume/models/vgg16_cifar10_pruned_50.pt")

In [None]:
# Run the project's validation helper on the pruned model. As the last
# expression in the cell, any non-None return value is displayed.
general.validate(pruned_model, dataset)

## Distillation

Load teacher and student.

In [14]:
# Teacher: the original VGG-16 CIFAR-10 checkpoint.
# Student: the `pruned_50` checkpoint written by the Pruning section.
# Both files are whole-model pickles; only load them from a trusted source.
teacher = torch.load("/workspace/volume/models/vgg16_cifar10.pt")
student = torch.load("/workspace/volume/models/vgg16_cifar10_pruned_50.pt")

Evaluate performance before distillation.

In [5]:
# Teacher metrics; reused later as the "before" side of the
# before/after comparison printed once distillation has run.
teacher_results = eval.get_results(teacher, dataset)
plot.print_results(**teacher_results)

Test: 100%|██████████| 157/157 [00:26<00:00,  5.90it/s]


Test loss: 0.2146
Test score: 94.0585
Could not calculate FLOPS
Loss: 0.214559
Score: 94.058519
Time per data point: 10.5894 ms
Model Size: 512.33 MB
Number of parameters: 134301514
Number of MACs: 15499459200


In [6]:
# Student metrics before distillation. `student_results` is rebound by the
# post-distillation evaluation cell, so these numbers survive only in this
# cell's printed output.
student_results = eval.get_results(student, dataset)
plot.print_results(**student_results)

Test: 100%|██████████| 157/157 [00:16<00:00,  9.62it/s]


Test loss: 0.4197
Test score: 86.4351
Could not calculate FLOPS
Loss: 0.419710
Score: 86.435111
Time per data point: 6.4991 ms
Model Size: 242.15 MB
Number of parameters: 63475626
Number of MACs: 3934750048


Perform distillation.

In [7]:
# Distillation configuration handed to distill.perform_distillation:
# number of epochs, the distillation technique, and the distillation loss.
# NOTE(review): `F.kl_div` is passed as a bare callable. PyTorch documents
# that its default reduction ("mean") does not return the true KL
# divergence and recommends "batchmean" — confirm how
# src.compression.distillation invokes this loss.
settings = dict(
    epochs=3,
    distil_technique=distill.soft_target_distillation,
    distil_loss=F.kl_div,
)

In [None]:
# Distil the teacher into the pruned student using `settings`.
# `student` is rebound to the returned model, so re-running this cell
# continues from the already-distilled student rather than from the
# checkpoint on disk.
student = distill.perform_distillation(teacher, dataset, student_model=student, settings=settings)

In [28]:
# Re-evaluate the student after distillation and print it side by side
# with the teacher's metrics (teacher = "before", student = "after").
student_results = eval.get_results(student, dataset)
plot.print_before_after_results(teacher_results, student_results)

Test: 100%|██████████| 157/157 [00:15<00:00,  9.81it/s]


Test loss: 0.3126
Test score: 90.0279
Loss: 0.214559 -> 0.312607 (45.70%)
Score: 94.058519 -> 90.027866 (-4.29%)
Time per data point: 10.5894 ms -> 6.3684 ms (-39.86%)
Model Size: 512.33 MB -> 242.15 MB (-52.74%)
Number of parameters: 134301514 -> 63475626 (-52.74%)
Number of MACs: 15499459200 -> 3934750048 (-74.61%)


In [None]:
# NOTE(review): this cell is an unexecuted, byte-identical copy of the
# previous evaluation cell. It re-evaluates the student and prints the
# teacher-vs-student comparison again; consider deleting it or moving it
# after the fine-tuning step if that is what it is meant to measure.
student_results = eval.get_results(student, dataset)
plot.print_before_after_results(teacher_results, student_results)

In [None]:
# Fine-tune the distilled student, which is the model this section
# evaluates and then saves. `pruned_model` is bound only by the Pruning
# section and is a separate object, so fine-tuning it here would not
# affect the saved student and would fail on a kernel where the Pruning
# cells were not run.
# `target` and `max_it` are presumably the target score and the maximum
# number of fine-tuning iterations — TODO confirm in src.general.finetune.
# The return value is discarded, so the helper is assumed to update the
# model in place — TODO confirm.
general.finetune(student, dataset, target=99, max_it=10)

In [29]:
# Persist the distilled student as a whole pickled model object.
torch.save(student, "/workspace/volume/models/vgg16_cifar10_distilled.pt")

## Quantization

Reload the original model from disk so quantization starts from a fresh copy.

In [None]:
# Rebind `vgg16` to a fresh copy of the original checkpoint; the earlier
# `vgg16` object was passed to the fusion and pruning helpers in the
# Pruning section. Whole-model pickle: only load from a trusted source.
vgg16 = torch.load("/workspace/volume/models/vgg16_cifar10.pt")

In [None]:
# Statically quantize the model with the project helper. The dataset is
# passed in presumably for calibration — TODO confirm in
# src.compression.quantization.static_quantization.
quantized_model = quant.static_quantization(vgg16, dataset)

In [None]:
# Evaluate the quantized model with the device forced to CPU
# (presumably because the quantized model only runs on CPU — confirm).
# Note: this rebinds the notebook-global `device` that the Pruning section
# assigned from general.get_device(); re-running earlier cells afterwards
# will see 'cpu' unless that cell is run again.
device = 'cpu'
quantized_results = eval.get_results(quantized_model, dataset, device)
plot.print_results(**quantized_results)

In [None]:
# Compare the uncompressed baseline with the quantized model.
# NOTE(review): `before_results` was computed without an explicit device
# argument while the quantized run is forced to CPU, so the timing figures
# may not be like-for-like — confirm which device get_results uses by default.
plot.print_before_after_results(before_results, quantized_results)

In [None]:
# Persist the quantized model as a whole pickled object. The file name says
# "static" because `quantized_model` is produced by
# quant.static_quantization, not by dynamic quantization.
torch.save(quantized_model, "/workspace/volume/models/vgg16_cifar10_static_quantized.pt")