# ResNet on CIFAR-10
This notebook experiments with fine-tuning an ImageNet-pretrained ResNet-50 on the CIFAR-10 dataset and compares its test results before and after fine-tuning.

## Setup

In [17]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
import gc
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision as tv
import torchvision.transforms as transforms
from tqdm import tqdm

sys.path.append('../../')
sys.path.append('../../src/')

import src.general as general
import src.dataset_models as data
import src.metrics as metrics
import src.evaluation as eval
import src.plot as plot
import src.compression.distillation as distill
import src.compression.pruning as prune


In [19]:
# Get model: ImageNet-pretrained ResNet-50 with a new classification head for CIFAR-10.
SEED = 42
NUM_CLASSES = 10  # CIFAR-10 has 10 classes

# Seed torch's global RNG before creating the new head so its random initialisation
# (and presumably any later torch-driven shuffling) is repeatable across runs.
torch.manual_seed(SEED)

resnet_weights = tv.models.ResNet50_Weights.DEFAULT
resnet = tv.models.resnet50(weights=resnet_weights)

# Replace the pretrained ImageNet classifier with a freshly initialised NUM_CLASSES-way layer;
# in_features is read from the existing head so the backbone output size is not hard-coded.
resnet.fc = nn.Linear(resnet.fc.in_features, NUM_CLASSES)


In [20]:
# Look up the CIFAR-10 wrapper in the project's dataset registry (src.dataset_models).
DATASET_NAME = "CIFAR-10"
dataset = data.supported_datasets[DATASET_NAME]

In [21]:
# Preprocessing for the ImageNet-pretrained ResNet-50.
# CIFAR-10 images are 32x32, so upscale them straight to the 224x224 input size.
# The ImageNet eval recipe (Resize(256) -> CenterCrop(224)) trims a border meant for
# large photos; on CIFAR it would discard ~23% (1 - 224²/256²) of every tiny image.
INPUT_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

resnet_cifar10_transform = transforms.Compose([
    transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
    transforms.ToTensor(),
    # Normalise with the ImageNet statistics the pretrained weights expect.
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Attach the preprocessing to the dataset (return value unused — presumably updates
# `dataset` in place; confirm in src.dataset_models).
dataset.set_transforms(resnet_cifar10_transform)

In [22]:
# Show the device reported by src.general (the recorded output is a CUDA device).
# NOTE(review): presumably this is the device general.test / general.finetune run on — confirm in src.general.
general.get_device()

device(type='cuda')

In [24]:
# Release memory cached by PyTorch's CUDA allocator. Collect garbage first so tensors
# kept alive only by unreachable reference cycles are freed and their blocks can be released.
# NOTE(review): the OOM recorded for the baseline test cell reports only ~104 MiB reserved
# by PyTorch out of 9.78 GiB total with 3 MiB free, so the GPU memory is presumably held by
# another process/kernel; this cell cannot release that.
gc.collect()
torch.cuda.empty_cache()

In [27]:
# Report the model size via src.evaluation.get_size.
# NOTE(review): the recorded value (90.04) looks like MiB — ResNet-50 with a 10-way head has
# roughly 23.5M float32 parameters ≈ 90 MiB — confirm the unit in src.evaluation.
# NOTE(review): `eval` is the import alias for src.evaluation and shadows the built-in eval().
eval.get_size(resnet)

90.04

In [25]:
# Baseline evaluation: pretrained backbone with the freshly initialised 10-way head,
# before any fine-tuning. Keep the result for the before/after comparison at the end.
# NOTE(review): assumes general.test returns the metrics that
# plot.print_before_after_results expects — confirm in src.general.
before_results = general.test(resnet, dataset)
before_results

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 9.78 GiB total capacity; 94.86 MiB already allocated; 3.12 MiB free; 104.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [9]:
# Fine-tune the modified ResNet-50 (`resnet`, the model built above) on CIFAR-10.
# NOTE(review): target=99 presumably is a target accuracy (%) for early stopping and
# max_it an iteration/epoch cap — confirm in src.general.finetune.
general.finetune(resnet, dataset, target=99, max_it=10)

Train:  91%|█████████▏| 715/782 [04:43<00:26,  2.54it/s]

In [1]:
# Evaluate the fine-tuned ResNet-50 and keep the result for the before/after comparison.
# NOTE(review): assumes general.test returns the metrics that
# plot.print_before_after_results expects — confirm in src.general.
after_results = general.test(resnet, dataset)
after_results

NameError: name 'general' is not defined

In [11]:
# Compare evaluation results before and after fine-tuning.
# NOTE(review): requires `before_results` and `after_results` to be bound by earlier cells
# (e.g. by assigning the return values of the two general.test calls); otherwise this
# raises NameError on a fresh kernel.
plot.print_before_after_results(before_results, after_results)

