# ResNet on CIFAR-10
This notebook is used to experiment with ResNet-50 on the CIFAR-10 dataset.

## Setup

In [2]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell executes, so edits to the project's source
# files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision as tv
import torchvision.transforms as transforms
from tqdm import tqdm

# Extend the import path so the project's `src` package can be imported below.
# Both entries are relative paths, so they resolve against the kernel's current
# working directory rather than against this notebook file.
sys.path.append('../../')
# NOTE(review): presumably added so modules inside src/ can import each other
# by bare name — confirm this second entry is still needed.
sys.path.append('../../src/')

# Project modules, aliased to short names used throughout the notebook.
import src.general as general
import src.dataset_models as data
import src.metrics as metrics
# NOTE: this alias shadows Python's built-in eval() for the rest of the notebook.
import src.evaluation as eval
import src.plot as plot
import src.compression.distillation as distill
import src.compression.pruning as prune
import src.compression.quantization as quant


Files already downloaded and verified
Files already downloaded and verified


Found cached dataset imagenet-1k (/workspace/volume/cache/imagenet-1k/default-212aff79ee65f848/1.0.0/a1e9bfc56c3a7350165007d1176b15e9128fcaf9ab972147840529aed3ae52bc)


  0%|          | 0/3 [00:00<?, ?it/s]

Files already downloaded and verified
Files already downloaded and verified


Found cached dataset imagenet-1k (/workspace/volume/cache/imagenet-1k/default-212aff79ee65f848/1.0.0/a1e9bfc56c3a7350165007d1176b15e9128fcaf9ab972147840529aed3ae52bc)


  0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
# Preprocessing pipeline for the ResNet input: resize to 256, take the central
# 224x224 crop, convert to a tensor, then normalise each colour channel.
resnet_cifar10_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Per-channel mean/std: the ImageNet statistics that torchvision
        # documents for its pretrained models.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Look up the CIFAR-10 entry in the project's dataset registry and attach the
# preprocessing pipeline to it.
dataset = data.supported_datasets["CIFAR-10"]
dataset.set_transforms(resnet_cifar10_transform)

In [5]:
# Directory that holds the saved model checkpoints.
# Defaults to the shared volume path; set the MODEL_DIR environment variable to
# point the notebook at a different location (an unset or empty variable falls
# back to the default).
# os.path.join(..., "") guarantees a trailing path separator, which matters
# because cells build checkpoint paths as MODEL_DIR + "<file name>".
MODEL_DIR = os.path.join(os.environ.get("MODEL_DIR") or "/workspace/volume/models/", "")

In [4]:
# Load the fine-tuned ResNet checkpoint from the model directory. The file holds
# a pickled model object (later cells call .state_dict() on it), so only load
# checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects full pickled models — confirm the installed version still loads this.
resnet = torch.load(f"{MODEL_DIR}resnet_cifar_finetuned.pth")

In [6]:
# Baseline: evaluate the fine-tuned model before any compression. The returned
# mapping is printed here and kept for the before/after comparison further down.
before_results = eval.get_results(
    resnet,
    dataset,
)
plot.print_results(**before_results)

Test: 100%|██████████| 157/157 [00:18<00:00,  8.39it/s]


Test loss: 0.1748
Test score: 94.8248
Could not calculate FLOPS
Loss: 0.174775
Score: 94.824841
Time per data point: 7.4499 ms
Model Size: 90.04 MB
Number of parameters: 23528522
Number of MACs: 4119896576


## Pruning

In [None]:
# Structured magnitude pruning of the fine-tuned ResNet, requesting sparsity 0.5
# with fine-tuning enabled and 5 iterative steps (exact semantics are defined in
# src.compression.pruning).
# NOTE(review): confirm whether this modifies `resnet` in place — a later cell
# saves resnet.state_dict() after pruning has run.
pruned_model = prune.magnitude_pruning_structured(
    resnet,
    dataset,
    sparsity=0.5,
    fineTune=True,
    iterative_steps=5,
)

In [None]:
# Fine-tune the pruned model further; the return value is not captured here.
# NOTE(review): target=98 is above the 94.82 baseline score recorded earlier in
# this notebook; if `target` is a score threshold, this will presumably run for
# all max_it=10 rounds — confirm.
general.finetune(pruned_model, dataset, target=98, max_it=10)

In [None]:
# Save only the weights (state_dict) of `resnet`. The path is built from
# MODEL_DIR (defined in the setup cells) so every checkpoint lives in the same
# configurable directory instead of repeating the hard-coded path.
# NOTE(review): this runs after the pruning cells; if pruning modifies `resnet`
# in place, these are no longer the original fine-tuned weights — confirm.
torch.save(resnet.state_dict(), MODEL_DIR + "resnet_cifar10.pt")

In [None]:
# Save the whole pruned model object (not just a state_dict). The path is built
# from MODEL_DIR (defined in the setup cells) rather than repeating the
# hard-coded directory.
# NOTE(review): presumably the full object is saved because structured pruning
# changes layer shapes — confirm.
torch.save(pruned_model, MODEL_DIR + "resnet_cifar_pruned.pt")

In [None]:
# Evaluate the pruned model with the same routine used for the baseline, so the
# two result sets can be compared directly.
after_results = eval.get_results(pruned_model, dataset)

In [None]:
# Print the baseline metrics next to the post-pruning metrics.
# Requires both `before_results` and `after_results` from the cells above.
plot.print_before_after_results(before_results, after_results)

## Quantization

In [4]:
# Load the baseline model for the quantization experiments. The path is built
# from MODEL_DIR, so the cell that defines MODEL_DIR must have been run first.
# NOTE(review): no cell visible in this notebook writes resnet_cifar10_full.pt —
# confirm where this checkpoint is produced.
model = torch.load(MODEL_DIR + "resnet_cifar10_full.pt")

In [7]:
# Evaluate the freshly loaded model to get pre-quantization reference metrics.
# NOTE(review): this rebinds `before_results`, replacing the pruning baseline
# computed earlier; re-running the pruning comparison afterwards would use
# these values instead.
# NOTE(review): the recorded run of this cell failed with
# "cuDNN error: CUDNN_STATUS_NOT_INITIALIZED" — check GPU memory/driver state
# before re-running.
before_results = eval.get_results(
    model,
    dataset,
)
plot.print_results(**before_results)

Test:   0%|          | 0/157 [00:00<?, ?it/s]


RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED

In [None]:
# Run the project's static quantization on the model with calibrate=False
# (presumably skipping the calibration pass — confirm).
# NOTE(review): the return value is discarded — confirm static_quantization
# modifies `model` in place, otherwise the quantized model is lost.
quant.static_quantization(model, dataset, calibrate=False)

In [None]:
# Call the project's module-fusion helper on the model. As the cell's last
# expression, any return value is displayed but not stored.
# NOTE(review): confirm whether fusion happens in place or returns a new model.
quant.fuse_modules(model)

In [None]:
# Inspect which modules the project helper selects for fusion; the result is
# shown as the cell output and not assigned to a variable.
quant.get_modules_to_fuse(model)

In [None]:
# Run the project's validation routine on the model after the quantization
# steps above; the result is displayed as the cell output and not stored.
general.validate(model, dataset)