# Train CIFAR with the `policy` module

Let's install the latest version of Poutyne (if it's not already installed) and import all the needed packages.

In [1]:
%pip install --upgrade poutyne

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import resnet18

from poutyne import Model, OptimizerPolicy, one_cycle_phases

## Training constants
But first, let's set the training constants: the CUDA device used for training (if one is present), the batch size (i.e. the number of elements to see before updating the model) and the number of epochs (i.e. the number of times we see the full dataset).

In [2]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
cuda_device = 0
device = torch.device(f"cuda:{cuda_device}" if torch.cuda.is_available() else "cpu")

batch_size = 1024
epochs = 5

# Seed torch's global RNG so weight initialisation, data shuffling and the random
# augmentations are repeatable on Restart & Run All. The baseline-vs-policy
# comparison below is otherwise confounded by run-to-run noise.
# (Some GPU kernels may still be nondeterministic.)
seed = 42
torch.manual_seed(seed)

# Load the data

In [3]:
# Per-channel (R, G, B) normalisation statistics.
# NOTE(review): these values match the commonly used ImageNet statistics. The
# network is trained from scratch on CIFAR-10, so CIFAR-10's own per-channel
# statistics may be a better fit — consider measuring them.
_mean = [0.485, 0.456, 0.406]
_std = [0.229, 0.224, 0.225]

# Shared tail of both pipelines: image -> tensor -> per-channel standardisation.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(_mean, _std),
]

# Training images get light random augmentation before normalisation.
# Validation images are only normalised, so evaluation sees the data unchanged.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=.3, contrast=.3, saturation=.3),
    ]
    + _to_normalized_tensor
)
val_transform = transforms.Compose(list(_to_normalized_tensor))

In [4]:
# Both CIFAR-10 splits live under `root`; with download=True the archive is
# fetched on the first run and only verified afterwards.
# Each split is paired with its own preprocessing pipeline.
root = "datasets"
train_ds, val_ds = (
    datasets.CIFAR10(root, train=is_train, transform=split_transform, download=True)
    for is_train, split_transform in ((True, train_transform), (False, val_transform))
)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Both loaders use the same batch size and number of worker processes.
# Only the training set is reshuffled every epoch; validation order stays fixed.
loader_kwargs = dict(batch_size=batch_size, num_workers=8)

train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_ds, shuffle=False, **loader_kwargs)

# The model
We'll train a simple `ResNet-18` network.
This takes a while without a GPU but is pretty quick with one.

In [6]:
def get_module(num_classes=10):
    """Build a randomly initialised ResNet-18 with a ``num_classes``-way output layer.

    Args:
        num_classes: Number of output logits of the final linear layer
            (defaults to 10, matching CIFAR-10).

    Returns:
        A freshly constructed torchvision ResNet-18 ``nn.Module``; each call
        returns a new, independently initialised network.
    """
    # Passing `num_classes` lets torchvision size the final `fc` layer itself, so
    # the backbone's feature width is no longer hard-coded here. The model
    # already ends with an adaptive 1x1 average pool, so no `avgpool` override
    # is needed. Leaving out `pretrained`/`weights` gives random initialisation
    # on every torchvision version and avoids the deprecation warning that
    # `pretrained=False` triggers on torchvision >= 0.13.
    return resnet18(num_classes=num_classes)


# Training without the `policies` module

In [7]:
# Baseline: plain SGD with a constant learning rate of 0.01 for the whole run.
pytorch_network = get_module()
optimizer = optim.SGD(pytorch_network.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Wrap network, optimizer and loss in a Poutyne Model that tracks accuracy per batch.
model = Model(pytorch_network, optimizer, criterion, batch_metrics=["acc"], device=device)

# Train on `train_dl`, evaluating on `val_dl` at the end of every epoch.
history = model.fit_generator(train_dl, val_dl, epochs=epochs)

[35mEpoch: [36m1/5 [35mTrain steps: [36m49 [35mVal steps: [36m10 [32m8.36s [35mloss:[94m 2.117263[35m acc:[94m 23.430000[35m val_loss:[94m 1.857028[35m val_acc:[94m 33.690000[0m
[35mEpoch: [36m2/5 [35mTrain steps: [36m49 [35mVal steps: [36m10 [32m8.35s [35mloss:[94m 1.797844[35m acc:[94m 34.568000[35m val_loss:[94m 1.663795[35m val_acc:[94m 39.050000[0m
[35mEpoch: [36m3/5 [35mTrain steps: [36m49 [35mVal steps: [36m10 [32m8.42s [35mloss:[94m 1.643785[35m acc:[94m 40.474000[35m val_loss:[94m 1.546482[35m val_acc:[94m 44.070000[0m
[35mEpoch: [36m4/5 [35mTrain steps: [36m49 [35mVal steps: [36m10 [32m8.38s [35mloss:[94m 1.550369[35m acc:[94m 43.554000[35m val_loss:[94m 1.469080[35m val_acc:[94m 46.690000[0m
[35mEpoch: [36m5/5 [35mTrain steps: [36m49 [35mVal steps: [36m10 [32m8.39s [35mloss:[94m 1.481794[35m acc:[94m 46.146000[35m val_loss:[94m 1.425816[35m val_acc:[94m 48.110000[0m


# Training with the `policies` module

In [8]:
# The schedule below is defined per optimizer step. There is one step per
# training batch, so an epoch lasts len(train_dl) steps.
steps_per_epoch = len(train_dl)
steps_per_epoch

49

In [9]:
# Same network, loss and optimizer setup as the baseline. Here the OptimizerPolicy
# callback overrides the optimizer's hyper-parameters step by step during training.
pytorch_network = get_module()
optimizer = optim.SGD(pytorch_network.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

model = Model(pytorch_network, optimizer, criterion, batch_metrics=["acc"], device=device)

# A one-cycle schedule stretched over every optimizer step of the run.
# NOTE(review): the baseline SGD above runs without momentum. If one_cycle_phases
# also schedules momentum by default, part of the gain may come from momentum
# rather than from the learning-rate cycle — confirm against the Poutyne docs.
# NOTE(review): confirm how the installed Poutyne version interprets the third
# element of `lr`.
total_steps = epochs * steps_per_epoch
policy = OptimizerPolicy(one_cycle_phases(total_steps, lr=(0.01, 0.1, 0.008)))

history = model.fit_generator(train_dl, val_dl, epochs=epochs, callbacks=[policy])

[35mEpoch: [36m1/5 [35mTrain steps: [36m49 [35mVal steps: [36m10 [32m8.49s [35mloss:[94m 1.864204[35m acc:[94m 32.568000[35m val_loss:[94m 1.752385[35m val_acc:[94m 41.630000[0m
[35mEpoch: [36m2/5 [35mTrain steps: [36m49 [35mVal steps: [36m10 [32m8.48s [35mloss:[94m 1.403389[35m acc:[94m 49.632000[35m val_loss:[94m 1.403605[35m val_acc:[94m 49.820000[0m
[35mEpoch: [36m3/5 [35mTrain steps: [36m49 [35mVal steps: [36m10 [32m8.64s [35mloss:[94m 1.172077[35m acc:[94m 58.390000[35m val_loss:[94m 1.093589[35m val_acc:[94m 61.370000[0m
[35mEpoch: [36m4/5 [35mTrain steps: [36m49 [35mVal steps: [36m10 [32m8.55s [35mloss:[94m 0.983761[35m acc:[94m 65.396000[35m val_loss:[94m 0.971778[35m val_acc:[94m 65.950000[0m
[35mEpoch: [36m5/5 [35mTrain steps: [36m49 [35mVal steps: [36m10 [32m8.53s [35mloss:[94m 0.826081[35m acc:[94m 70.952000[35m val_loss:[94m 0.898928[35m val_acc:[94m 68.810000[0m
