# ResNet on CIFAR-10
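This notebook experiments with compressing a ResNet-50 on the CIFAR-10 dataset using pruning, quantization, and knowledge distillation. In the recorded runs, L1 channel pruning at 80% sparsity alone drops test accuracy to chance level (~10%), while distillation from the original model raises a pruned, fine-tuned student from 74.8% to 86.6% accuracy (teacher: 94.7%) with ~96% fewer parameters.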

## Setup

In [1]:
# Automatically reload edited modules (e.g. the project's src/ package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [9]:
import torch
import sys
import torch.nn.functional as F

sys.path.append('../../')
sys.path.append('../../src/')

import src.general as general
import src.interfaces.dataset_models as data
import src.metrics as metrics
import src.evaluation as eval
import src.plot as plot
import src.compression.distillation as distill
import src.compression.pruning as prune
import src.compression.quantization as quant


In [1]:
# Look up the CIFAR-10 entry in the project's supported-datasets mapping.
# NOTE(review): the NameError captured below is stale — this cell was run (In [1]) before the
# imports cell (In [9]); it should succeed once the imports cell has executed.
dataset = data.supported_datasets["CIFAR-10"]

NameError: name 'data' is not defined

In [4]:
# Load the fine-tuned baseline. The checkpoint is a whole pickled model object (it is used
# directly as a model below), not a state_dict. Since PyTorch 2.6, torch.load defaults to
# weights_only=True and refuses to unpickle such objects, so opt out explicitly
# (the argument exists since PyTorch 1.13). Full unpickling can execute arbitrary code:
# only load trusted checkpoints this way.
# NOTE(review): the Distillation section loads "resnet_cifar_finetuned.pt" (no "h") as the
# teacher, and its evaluation score differs from this one — confirm which file is intended.
resnet = torch.load("/workspace/volume/models/resnet_cifar_finetuned.pth", weights_only=False)

In [7]:
# Baseline metrics for the unpruned model (loss, score, latency, size, parameter and MAC counts).
before_results = eval.get_results(resnet, dataset)
plot.print_results(**before_results)

Test: 100%|██████████| 157/157 [00:19<00:00,  7.97it/s]


Test loss: 0.1898
Test score: 94.4367
Loss: 0.189813
Score: 94.436704
Time per data point: 7.8383 ms
Model Size: 90.04 MB
Number of parameters: 23528522
Number of MACs: 4119896576


In [8]:
# Save the whole baseline model object (not just a state_dict).
torch.save(resnet, "/workspace/volume/models/resnet_cifar10.pt")

## Pruning

In [7]:
# Channel pruning with the L1 technique at sparsity 0.8.
# NOTE(review): confirm whether channel_pruning modifies `resnet` in place or returns a copy;
# if in place, `resnet` no longer holds the unpruned baseline after this cell.
pruned_model = prune.channel_pruning(resnet, dataset, prune.PruningTechnique.L1, sparsity=0.8)

In [8]:
# Evaluate the pruned model with no fine-tuning in between. The captured score (~10%) is
# chance level for CIFAR-10's 10 classes, so pruning at this sparsity needs recovery training.
after_results = eval.get_results(pruned_model, dataset)

Test: 100%|██████████| 157/157 [00:12<00:00, 12.27it/s]


Test loss: 18.2151
Test score: 9.9821


In [12]:
# Compare the baseline against the pruned model.
# NOTE(review): the "before" values in the captured output (loss 0.191720, 13.87 ms) differ
# from the baseline printed earlier (loss 0.189813, 7.84 ms), so this output comes from a
# different kernel run — re-run the notebook top to bottom before trusting the comparison.
plot.print_before_after_results(before_results, after_results)

Loss: 0.191720 -> 18.218641 (9402.73%)
Score: 94.207803 -> 9.982086 (-89.40%)
Time per data point: 13.8665 ms -> 6.6640 ms (-51.94%)
Model Size: 90.04 MB -> 3.78 MB (-95.80%)
Number of parameters: 23528522 -> 956641 (-95.93%)
Number of MACs: 4119896576 -> 291470594 (-92.93%)


In [None]:
# Persist the pruned model. The "_80" suffix records the sparsity used in the pruning cell
# (sparsity=0.8); keep the two in sync if the sparsity changes.
torch.save(pruned_model, "/workspace/volume/models/resnet_cifar10_pruned_80.pth")

In [None]:
# Apply dynamic quantization on top of the pruned model.
# NOTE(review): this cell sits in the Pruning section and was never executed (In [None]);
# consider moving it under "## Quantization" and evaluating the result.
quantized_model = quant.dynamic_quantization(pruned_model, dataset)

## Quantization

In [None]:
# Load the basic (uncompressed) model for the quantization experiments.
# The checkpoint is a whole pickled model object, which torch.load refuses to unpickle under
# its PyTorch >= 2.6 default weights_only=True, so opt out explicitly (argument available since
# PyTorch 1.13). Only load trusted checkpoints this way: full unpickling can run arbitrary code.
# NOTE(review): the Setup section saves the baseline as "resnet_cifar10.pt", not
# "resnet_cifar10_full.pt" — confirm which checkpoint is intended here.
model = torch.load("/workspace/volume/models/resnet_cifar10_full.pt", weights_only=False)

In [None]:
# Baseline metrics for the model loaded above, before quantization.
# NOTE(review): this reuses (and overwrites) `before_results` from the Pruning section;
# a distinct name would keep both comparisons valid in one kernel session.
before_results = eval.get_results(model, dataset)
plot.print_results(**before_results)

### Dynamic quantization

In [None]:
# Dynamically quantize the uncompressed model.
# NOTE(review): the result is never evaluated below; compare it against `before_results`
# with eval.get_results / plot.print_before_after_results.
dynamic_quantized_model = quant.dynamic_quantization(model, dataset)

In [None]:
# Fuse modules in `model` (project helper).
# NOTE(review): the return value is discarded — if fuse_modules returns a fused copy rather
# than fusing in place, this cell has no effect. It also runs before get_modules_to_fuse
# (next cell), which looks like the reverse of the intended order — confirm.
quant.fuse_modules(model)

In [None]:
# Display the module groups the project helper would fuse (presumably for inspection only).
quant.get_modules_to_fuse(model)

In [None]:
# Validate `model` on `dataset` — presumably to check accuracy after the fusion step; confirm.
general.validate(model, dataset)

## Distillation

In [4]:
# Teacher: the fine-tuned full-size model (23.5M parameters in the output below).
# Student: a much smaller checkpoint (~0.96M parameters; the file name suggests pruned + fine-tuned).
# Both are whole pickled model objects, which torch.load refuses to unpickle under its
# PyTorch >= 2.6 default weights_only=True, so opt out explicitly (argument available since
# PyTorch 1.13). Only load trusted checkpoints this way: full unpickling can run arbitrary code.
# NOTE(review): the Setup section loads "resnet_cifar_finetuned.pth" (with an "h") as its
# baseline and reports a different score — confirm both paths point at the intended checkpoint.
teacher = torch.load("/workspace/volume/models/resnet_cifar_finetuned.pt", weights_only=False)
student = torch.load("/workspace/volume/models/resnet_cifar_pruned_finetuned_v2.pt", weights_only=False)

In [7]:
# Evaluate teacher and student before distillation and show the gap between them.
teacher_results = eval.get_results(teacher, dataset)
student_results = eval.get_results(student, dataset)
plot.print_before_after_results(teacher_results, student_results)

Test: 100%|██████████| 157/157 [00:19<00:00,  8.09it/s]


Test loss: 0.1884
Test score: 94.6955


Test: 100%|██████████| 157/157 [00:13<00:00, 11.80it/s]


Test loss: 0.7471
Test score: 74.7711
Loss: 0.188398 -> 0.747112 (296.56%)
Score: 94.695462 -> 74.771099 (-21.04%)
Time per data point: 7.7226 ms -> 5.2971 ms (-31.41%)
Model Size: 90.04 MB -> 3.78 MB (-95.80%)
Number of parameters: 23528522 -> 956641 (-95.93%)
Number of MACs: 4119896576 -> 291470594 (-92.93%)


In [16]:
# Knowledge-distillation configuration passed to distill.perform_distillation.
# The optimizer is constructed over the current `student`'s parameters, so this cell must be
# re-run whenever `student` is reloaded.
# NOTE(review): torch.nn.functional.kl_div expects its first argument to be log-probabilities;
# confirm distill.combined_loss_distillation feeds it log_softmax outputs.
# NOTE(review): the captured training log shows 20 "Distillation Training" passes although
# epochs=10 — confirm how the library interprets "epochs" / "performance_target".
settings = dict(
    temperature=3,
    alpha=0.5,
    epochs=10,
    performance_target=90,
    distil_technique=distill.combined_loss_distillation,
    distil_criterion=F.kl_div,
    optimizer=torch.optim.Adam(student.parameters(), lr=0.001),
)

In [17]:
# Distill the teacher into the student using `settings` (long-running: ~8 min per epoch in the log below).
# NOTE(review): the optimizer in `settings` holds `student.parameters()`, so `student` is
# presumably trained in place and `distilled_model` may be the same object — confirm before
# reusing `student` afterwards.
distilled_model = distill.perform_distillation(teacher, dataset, student, settings)

Validate: 100%|██████████| 79/79 [00:06<00:00, 11.90it/s]


Test loss: 0.7350
Test score: 75.4153


Distillation Training: 100%|██████████| 11250/11250 [08:01<00:00, 23.35it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.01it/s]


Test loss: 0.8549
Test score: 71.7761


Distillation Training: 100%|██████████| 11250/11250 [08:07<00:00, 23.05it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.67it/s]


Test loss: 0.7612
Test score: 74.5253


Distillation Training: 100%|██████████| 11250/11250 [08:03<00:00, 23.27it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.47it/s]


Test loss: 0.6546
Test score: 78.8172


Distillation Training: 100%|██████████| 11250/11250 [08:03<00:00, 23.27it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.72it/s]


Test loss: 0.6978
Test score: 78.9359


Distillation Training: 100%|██████████| 11250/11250 [08:01<00:00, 23.38it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.76it/s]


Test loss: 0.5791
Test score: 81.1907


Distillation Training: 100%|██████████| 11250/11250 [08:08<00:00, 23.01it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.34it/s]


Test loss: 0.5523
Test score: 81.3884


Distillation Training: 100%|██████████| 11250/11250 [08:05<00:00, 23.16it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.67it/s]


Test loss: 0.5114
Test score: 82.8718


Distillation Training: 100%|██████████| 11250/11250 [08:21<00:00, 22.45it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 11.98it/s]


Test loss: 0.5315
Test score: 82.2983


Distillation Training: 100%|██████████| 11250/11250 [13:21<00:00, 14.03it/s] 
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.37it/s]


Test loss: 0.4707
Test score: 84.4937


Distillation Training: 100%|██████████| 11250/11250 [08:37<00:00, 21.75it/s]  
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.45it/s]


Test loss: 0.4760
Test score: 84.3552


Distillation Training: 100%|██████████| 11250/11250 [12:46<00:00, 14.67it/s]  
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.19it/s]


Test loss: 0.5317
Test score: 84.5926


Distillation Training: 100%|██████████| 11250/11250 [08:10<00:00, 22.93it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.53it/s]


Test loss: 0.5044
Test score: 84.0981


Distillation Training: 100%|██████████| 11250/11250 [08:09<00:00, 22.98it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.44it/s]


Test loss: 0.4698
Test score: 85.1859


Distillation Training: 100%|██████████| 11250/11250 [08:05<00:00, 23.19it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.68it/s]


Test loss: 0.4717
Test score: 84.8497


Distillation Training: 100%|██████████| 11250/11250 [07:56<00:00, 23.61it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 13.04it/s]


Test loss: 0.4618
Test score: 85.0870


Distillation Training: 100%|██████████| 11250/11250 [08:06<00:00, 23.14it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.37it/s]


Test loss: 0.4274
Test score: 86.2342


Distillation Training: 100%|██████████| 11250/11250 [08:06<00:00, 23.12it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.53it/s]


Test loss: 0.3914
Test score: 86.9462


Distillation Training: 100%|██████████| 11250/11250 [08:05<00:00, 23.18it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.59it/s]


Test loss: 0.4378
Test score: 85.6013


Distillation Training: 100%|██████████| 11250/11250 [08:13<00:00, 22.81it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.35it/s]


Test loss: 0.3990
Test score: 86.9066


Distillation Training: 100%|██████████| 11250/11250 [08:04<00:00, 23.23it/s]
Validate: 100%|██████████| 79/79 [00:06<00:00, 12.49it/s]

Test loss: 0.4220
Test score: 86.6495





In [18]:
# Effect of distillation on the student: compare with its pre-distillation metrics.
distilled_results = eval.get_results(distilled_model, dataset)
plot.print_before_after_results(student_results, distilled_results)

Test: 100%|██████████| 157/157 [00:12<00:00, 12.53it/s]


Test loss: 0.4179
Test score: 86.5744
Loss: 0.747112 -> 0.417887 (-44.07%)
Score: 74.771099 -> 86.574443 (15.79%)
Time per data point: 5.2971 ms -> 4.9891 ms (-5.81%)
Model Size: 3.78 MB -> 3.78 MB (-0.00%)
Number of parameters: 956641 -> 956641 (-0.00%)
Number of MACs: 291470594 -> 291470594 (-0.00%)


In [19]:
# Overall trade-off: original teacher vs. the distilled student.
plot.print_before_after_results(teacher_results, distilled_results)

Loss: 0.188398 -> 0.417887 (121.81%)
Score: 94.695462 -> 86.574443 (-8.58%)
Time per data point: 7.7226 ms -> 4.9891 ms (-35.40%)
Model Size: 90.04 MB -> 3.78 MB (-95.80%)
Number of parameters: 23528522 -> 956641 (-95.93%)
Number of MACs: 4119896576 -> 291470594 (-92.93%)


In [None]:
# Persist the distilled student as a whole model object, like the other checkpoints saved here.
torch.save(distilled_model, "/workspace/volume/models/resnet_cifar_distilled.pt")