# Train CIFAR with the `policies` module

Let's install the latest version of Poutyne (if it's not already installed) and import all the needed packages.

In [1]:
%pip install --upgrade poutyne

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import resnet18

from poutyne import Model, OptimizerPolicy, one_cycle_phases

Requirement already up-to-date: poutyne in /home/fredy/.venv/base/lib/python3.8/site-packages (1.0.0)
Note: you may need to restart the kernel to use updated packages.


## Training constants
First, let's set the training constants: the CUDA device used for training (if one is available), the batch size (i.e. the number of examples processed before each model update), and the number of epochs (i.e. the number of passes over the full dataset).

In [2]:
# Train on the first CUDA device when one is available; otherwise use the CPU.
cuda_device = 0
if torch.cuda.is_available():
    device = torch.device(f"cuda:{cuda_device}")
else:
    device = torch.device("cpu")

batch_size = 1024  # examples per optimisation step
epochs = 5  # full passes over the training set

# Load the data

In [3]:
# Per-channel normalisation statistics shared by both pipelines.
# NOTE(review): these values match the commonly used ImageNet statistics rather
# than ones computed on CIFAR-10 -- confirm this is intentional.
_mean = [0.485, 0.456, 0.406]
_std = [0.229, 0.224, 0.225]

# Training pipeline: random flip and colour jitter for augmentation, then
# tensor conversion and normalisation.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
        transforms.ToTensor(),
        transforms.Normalize(mean=_mean, std=_std),
    ]
)

# Validation pipeline: no augmentation, only tensor conversion and normalisation.
val_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_mean, std=_std),
    ]
)

In [4]:
# CIFAR-10 is downloaded into `root` on first use; the train split gets the
# augmenting transform and the test split serves as the validation set.
root = "datasets"
train_ds = datasets.CIFAR10(
    root=root,
    train=True,
    transform=train_transform,
    download=True,
)
val_ds = datasets.CIFAR10(
    root=root,
    train=False,
    transform=val_transform,
    download=True,
)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Options common to both loaders; only the shuffling differs between them.
loader_options = dict(batch_size=batch_size, num_workers=8)

# Shuffle the training data every epoch; keep the validation order fixed.
train_dl = DataLoader(train_ds, shuffle=True, **loader_options)
val_dl = DataLoader(val_ds, shuffle=False, **loader_options)

# The model
We'll train a simple `ResNet-18` network.
This takes a while without a GPU but is pretty quick with one.

In [6]:
def get_module(num_classes=10):
    """Build a randomly initialised ResNet-18 with a fresh classification head.

    Args:
        num_classes: Number of output classes of the final linear layer.
            Defaults to 10, the number of CIFAR-10 classes.

    Returns:
        The ``resnet18`` network with its ``avgpool`` and ``fc`` layers replaced.
    """
    # Calling `resnet18()` with no arguments yields an untrained network and
    # avoids the `pretrained` keyword, which newer torchvision releases deprecate.
    model = resnet18()
    # Pool every feature map down to 1x1 regardless of its spatial size.
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    # Size the new head from the layer it replaces instead of hard-coding 512.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


# Training without the `policies` module

In [7]:
# Baseline run: plain SGD with a constant learning rate and no callbacks.
pytorch_network = get_module()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(pytorch_network.parameters(), lr=0.01)

# Wrap the network in a Poutyne Model, tracking accuracy on every batch, and
# move it to the selected device.
model = Model(pytorch_network, optimizer, criterion, batch_metrics=["acc"]).to(device)

# `history` holds what `fit_generator` returns for this run.
history = model.fit_generator(train_dl, val_dl, epochs=epochs)

[93mEpoch: [94m1/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.71s [93mloss:[96m 2.115406[93m acc:[96m 22.852000[93m val_loss:[96m 1.870449[93m val_acc:[96m 32.520000[0m
[93mEpoch: [94m2/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.59s [93mloss:[96m 1.789251[93m acc:[96m 34.850000[93m val_loss:[96m 1.664176[93m val_acc:[96m 39.490000[0m
[93mEpoch: [94m3/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.58s [93mloss:[96m 1.638004[93m acc:[96m 40.164000[93m val_loss:[96m 1.560369[93m val_acc:[96m 42.310000[0m
[93mEpoch: [94m4/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.76s [93mloss:[96m 1.558112[93m acc:[96m 43.102000[93m val_loss:[96m 1.494177[93m val_acc:[96m 45.360000[0m
[93mEpoch: [94m5/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.70s [93mloss:[96m 1.493087[93m ac

# Training with the `policies` module

In [8]:
# Number of batches in one pass over the training loader; used below to size
# the one-cycle schedule.
steps_per_epoch = len(train_dl)
steps_per_epoch

49

In [9]:
# One-cycle run: same network, loss and optimizer setup as the baseline run, so
# the learning-rate policy is the only difference.
pytorch_network = get_module()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(pytorch_network.parameters(), lr=0.01)

model = Model(
    pytorch_network,
    optimizer,
    criterion,
    batch_metrics=["acc"],
)
model = model.to(device)

# The schedule spans every optimisation step of the whole run.
# `one_cycle_phases` takes `lr` as a (min, max) pair and the final annealing
# rate through `finetune_lr`; a third element in the `lr` tuple is ignored, so
# the final rate of 0.008 is passed explicitly.
policy = OptimizerPolicy(
    one_cycle_phases(
        epochs * steps_per_epoch,
        lr=(0.01, 0.1),
        finetune_lr=0.008,
    ),
)
history = model.fit_generator(
    train_dl,
    val_dl,
    epochs=epochs,
    callbacks=[policy],
)

[93mEpoch: [94m1/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.62s [93mloss:[96m 1.848409[93m acc:[96m 33.260000[93m val_loss:[96m 1.755679[93m val_acc:[96m 42.600000[0m
[93mEpoch: [94m2/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.73s [93mloss:[96m 1.380307[93m acc:[96m 50.424000[93m val_loss:[96m 1.286768[93m val_acc:[96m 54.810000[0m
[93mEpoch: [94m3/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.61s [93mloss:[96m 1.170797[93m acc:[96m 58.460000[93m val_loss:[96m 1.094710[93m val_acc:[96m 60.010000[0m
[93mEpoch: [94m4/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.62s [93mloss:[96m 0.970011[93m acc:[96m 65.824000[93m val_loss:[96m 0.938677[93m val_acc:[96m 67.060000[0m
[93mEpoch: [94m5/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m8.80s [93mloss:[96m 0.809022[93m ac