# Train CIFAR-10 with the `policy` module

Let's install the latest version of Poutyne (if it is not already installed) and import all the packages we need.

In [1]:
%pip install --upgrade poutyne

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import resnet18

from poutyne import Model, OptimizerPolicy, one_cycle_phases

## Training constants
But first, let's set the training constants: the CUDA device used for training (if one is present), the batch size (i.e. the number of elements to see before updating the model) and the number of epochs (i.e. the number of times we see the full dataset).

In [2]:
cuda_device = 0
device = torch.device(f"cuda:{cuda_device}" if torch.cuda.is_available() else "cpu")

batch_size = 1024
epochs = 5

# Seed PyTorch's global RNG. Weight initialisation, DataLoader shuffling and the
# random augmentations (flip, color jitter) all draw from it, so seeding here makes
# "Restart & Run All" reproducible (up to non-deterministic GPU kernels).
seed = 42
torch.manual_seed(seed)

# Load the data

In [3]:
# Per-channel normalisation statistics.
# NOTE(review): these are the standard ImageNet channel statistics, not CIFAR-10's
# own; harmless for a model trained from scratch, but worth knowing.
_mean = [0.485, 0.456, 0.406]
_std = [0.229, 0.224, 0.225]

# Validation images are only converted to tensors and normalised.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_mean, _std),
])
# Training images first go through random augmentation, then the same
# tensor conversion + normalisation steps as the validation pipeline.
train_transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(), transforms.ColorJitter(.3, .3, .3)]
    + list(val_transform.transforms)
)

In [4]:
root = "data"
# With download=True the archive is fetched and extracted on first use; later
# calls only verify the files already present under `root` (see output below).
train_ds = datasets.CIFAR10(root=root, train=True, download=True, transform=train_transform)
val_ds = datasets.CIFAR10(root=root, train=False, download=True, transform=val_transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified


In [5]:
# Mini-batches of `batch_size` images: the training set is reshuffled every
# epoch, while the validation set keeps a fixed order.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

# The model
We'll train a simple `ResNet-18` network.
This takes a while without a GPU but is pretty quick with one.

In [6]:
def get_module(num_classes=10):
    """Build a randomly initialised ResNet-18 adapted to small (CIFAR-sized) images.

    Args:
        num_classes: number of logits produced by the final linear layer
            (defaults to 10, the number of CIFAR-10 classes).

    Returns:
        A fresh torchvision ResNet-18 ``nn.Module`` with a global-average-pooling
        head and a ``num_classes``-way classifier.
    """
    # Calling resnet18() with no arguments yields random weights on both old
    # (``pretrained=False`` default) and new (``weights=None`` default) torchvision
    # versions, without triggering the ``pretrained`` deprecation warning.
    model = resnet18()
    # Global average pooling makes the head independent of the spatial size of the
    # final feature map, so 32x32 inputs work.
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    # Reuse the feature width of the existing head instead of hard-coding 512.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


# Training without the `policies` module

In [7]:
# Baseline: plain SGD with a fixed learning rate of 0.01 for the whole run
# (no scheduling callback is passed to fit_generator).
pytorch_network = get_module().to(device)
optimizer = optim.SGD(pytorch_network.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

model = Model(pytorch_network, optimizer, criterion, batch_metrics=["acc"]).to(device)

history = model.fit_generator(train_dl, val_dl, epochs=epochs)





[93mEpoch: [94m1/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.58s [93mloss:[96m 2.115964[93m acc:[96m 22.938000[93m val_loss:[96m 1.869150[93m val_acc:[96m 32.870000[0m
[93mEpoch: [94m2/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.36s [93mloss:[96m 1.805687[93m acc:[96m 34.560000[93m val_loss:[96m 1.682006[93m val_acc:[96m 39.280000[0m
[93mEpoch: [94m3/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.38s [93mloss:[96m 1.658098[93m acc:[96m 40.008000[93m val_loss:[96m 1.610190[93m val_acc:[96m 41.670000[0m
[93mEpoch: [94m4/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.36s [93mloss:[96m 1.563534[93m acc:[96m 43.264000[93m val_loss:[96m 1.502107[93m val_acc:[96m 45.430000[0m
[93mEpoch: [94m5/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.41s [93mloss:[96m 1.504496[93

# Training with the `policies` module

In [8]:
# Number of mini-batches (i.e. optimizer steps) in one epoch; the one-cycle
# schedule in the next cell is sized in steps, not in epochs.
steps_per_epoch = len(train_dl)
steps_per_epoch

49

In [9]:
pytorch_network = get_module().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(pytorch_network.parameters(), lr=0.01)

model = Model(
    pytorch_network,
    optimizer,
    criterion,
    batch_metrics=["acc"],
)
model = model.to(device)

# One-cycle schedule spanning the whole run (epochs * steps_per_epoch optimizer
# steps): the learning rate ramps up from 0.01 to 0.1 and back, then a final
# fine-tuning phase anneals it towards `finetune_lr`.
# `lr` only takes a (start, peak) pair; the fine-tuning target must be passed as
# `finetune_lr`. A third tuple element would be silently ignored, and the library
# default would be used instead.
policy = OptimizerPolicy(
    one_cycle_phases(epochs * steps_per_epoch, lr=(0.01, 0.1), finetune_lr=0.008),
)
history = model.fit_generator(
    train_dl,
    val_dl,
    epochs=epochs,
    callbacks=[policy],
)

[93mEpoch: [94m1/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.50s [93mloss:[96m 1.847516[93m acc:[96m 33.342000[93m val_loss:[96m 1.675540[93m val_acc:[96m 42.400000[0m
[93mEpoch: [94m2/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.56s [93mloss:[96m 1.400317[93m acc:[96m 49.748000[93m val_loss:[96m 1.323423[93m val_acc:[96m 52.840000[0m
[93mEpoch: [94m3/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.53s [93mloss:[96m 1.142495[93m acc:[96m 59.492000[93m val_loss:[96m 1.117616[93m val_acc:[96m 61.920000[0m
[93mEpoch: [94m4/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.53s [93mloss:[96m 0.946500[93m acc:[96m 66.758000[93m val_loss:[96m 0.941687[93m val_acc:[96m 67.420000[0m
[93mEpoch: [94m5/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.51s [93mloss:[96m 0.782936[93m ac