# VGG-16 on CIFAR-10
This notebook experiments with VGG-16 on the CIFAR-10 dataset.

## Setup

In [None]:
# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before each cell execution, so edits to the local project modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import sys

# Extend the import path with the directory two levels up and its src/
# subdirectory so the `src.*` imports below can be resolved.
# NOTE: these are relative paths, so they are resolved against the kernel's
# current working directory — presumably the notebook's own folder; confirm
# if the notebook is launched from elsewhere.
sys.path.append('../../')
sys.path.append('../../src/')

import src.general as general
import src.dataset_models as data
import src.metrics as metrics
import src.evaluation as eval
import src.plot as plot
import src.compression.distillation as distill
import src.compression.pruning as prune


In [None]:
# Load VGG-16 initialised with torchvision's default pretrained weights
vgg_weights = tv.models.VGG16_Weights.DEFAULT
vgg16 = tv.models.vgg16(weights=vgg_weights)

# Swap the last classifier layer (index 6) for a fresh linear head with
# 10 outputs, one per CIFAR-10 class. The input width of 4096 stays the same.
cifar10_head = nn.Linear(in_features=4096, out_features=10)
vgg16.classifier[6] = cifar10_head


In [None]:
# Get dataset: look up the "CIFAR-10" entry in the project's mapping of
# supported datasets (defined in src.dataset_models).
dataset = data.supported_datasets["CIFAR-10"]

In [None]:
# Per-channel statistics handed to Normalize. These are the standard ImageNet
# values that torchvision's pretrained models expect.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Steps are applied in list order.
vgg_cifar10_steps = [
    transforms.Resize(224),                 # scale images up to 224 px
    transforms.RandomHorizontalFlip(),      # random left/right mirror
    transforms.RandomCrop(224, padding=4),  # pad by 4 px, then random 224 crop
    transforms.ToTensor(),                  # PIL image -> float tensor
    transforms.Normalize(mean=norm_mean, std=norm_std),
]
vgg_cifar10_transform = transforms.Compose(vgg_cifar10_steps)

# Attach the pipeline to the dataset
dataset.set_transforms(vgg_cifar10_transform)

In [None]:
# Query the project's device helper. The result is not assigned, so it is
# only shown as this cell's output.
# NOTE(review): presumably reports the CPU/GPU device used for training — confirm in src.general.
general.get_device()

In [None]:
# Run the project's test routine on the model before the training cell,
# i.e. with pretrained weights and a freshly initialised 10-class head.
# NOTE(review): presumably reports baseline metrics on the test split — confirm in src.general.
general.test(vgg16, dataset)

In [None]:
# Run the project's training routine on the model with the CIFAR-10 dataset.
# No hyperparameters are passed here, so whatever defaults general.train
# defines are used.
general.train(vgg16, dataset)