# Train CIFAR-10 with the `policies` module

Let's import all the needed packages first.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import resnet18

from poutyne.framework import Model
from poutyne.framework import OptimizerPolicy, one_cycle_phases

## Training constants
Next, let's set the training constants: the CUDA device to train on (if one is available), the batch size (i.e. the number of examples processed before each update of the model) and the number of epochs (i.e. the number of passes over the full training dataset).

In [2]:
# Index of the GPU to use when CUDA is available; otherwise everything runs on the CPU.
cuda_device = 0
device = torch.device("cuda:%d" % cuda_device if torch.cuda.is_available() else "cpu")

# Examples per optimizer update, and number of full passes over the training set.
batch_size = 1024
epochs = 5

# Seed PyTorch's RNGs (weight init, shuffling, augmentation) so a Restart & Run All
# starts from the same random state. GPU kernels may still be nondeterministic.
seed = 42
torch.manual_seed(seed)

# Load the data

In [3]:
# Per-channel mean/std used by Normalize (the commonly used ImageNet statistics,
# not values computed on CIFAR-10 itself).
_mean = [0.485, 0.456, 0.406]
_std = [0.229, 0.224, 0.225]

# Training pipeline: light augmentation (random horizontal flips, colour jitter),
# followed by tensor conversion and per-channel normalisation.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=.3, contrast=.3, saturation=.3),
        transforms.ToTensor(),
        transforms.Normalize(mean=_mean, std=_std),
    ]
)

# Validation pipeline: no augmentation, only the deterministic conversion + normalisation.
val_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=_mean, std=_std)]
)

In [4]:
# Local directory holding the CIFAR-10 archives; download=True fetches them only if missing.
root = "data"
train_ds = datasets.CIFAR10(root=root, train=True, transform=train_transform, download=True)
val_ds = datasets.CIFAR10(root=root, train=False, transform=val_transform, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Reshuffle the training split every epoch; keep the validation order fixed.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

# The model
We'll train a simple ResNet-18 network from scratch.
This takes a while without a GPU but is fairly quick with one.

In [6]:
def get_module(num_classes=10):
    """Build a randomly initialised ResNet-18 with a fresh classification head.

    Args:
        num_classes: Number of output logits of the final linear layer.
            Defaults to 10, the number of CIFAR-10 classes.

    Returns:
        The ``nn.Module`` to wrap in a Poutyne ``Model``.
    """
    model = resnet18(pretrained=False)
    # Adaptive pooling collapses whatever spatial size the backbone produces to 1x1,
    # so small inputs such as CIFAR images still yield a flat feature vector.
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    # Read the feature width from the existing head instead of hard-coding 512.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


# Training without the `policies` module

In [7]:
# Baseline: plain SGD with a constant learning rate of 0.01 for the whole run.
pytorch_network = get_module().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(pytorch_network.parameters(), lr=0.01)

# Poutyne's Model bundles network, optimizer and loss; accuracy ("acc") is a batch metric.
model = Model(pytorch_network, optimizer, criterion, batch_metrics=["acc"]).to(device)

history = model.fit_generator(train_dl, val_dl, epochs=epochs)

[93mEpoch: [94m1/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m20s [93mloss:[96m 2.111655[93m acc:[96m 23.260000[93m val_loss:[96m 1.889516[93m val_acc:[96m 32.000000[0m
[93mEpoch: [94m2/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m20s [93mloss:[96m 1.794558[93m acc:[96m 34.396000[93m val_loss:[96m 1.656629[93m val_acc:[96m 39.540000[0m
[93mEpoch: [94m3/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m20s [93mloss:[96m 1.651485[93m acc:[96m 39.688000[93m val_loss:[96m 1.547896[93m val_acc:[96m 43.050000[0m
[93mEpoch: [94m4/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m20s [93mloss:[96m 1.560221[93m acc:[96m 43.178000[93m val_loss:[96m 1.487496[93m val_acc:[96m 45.410000[0m
[93mEpoch: [94m5/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m20s [93mloss:[96m 1.498565[93m acc:[96m 45

# Training with the `policies` module

In [8]:
# One optimizer step per batch, so an epoch is len(train_dl) steps. DataLoader's
# drop_last is left at its default, so the last partial batch counts too (49 here).
steps_per_epoch = len(train_dl)
steps_per_epoch

49

In [9]:
# Same architecture, loss and optimizer setup as the baseline; the difference is the
# OptimizerPolicy callback that adjusts the optimizer during training.
pytorch_network = get_module().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(pytorch_network.parameters(), lr=0.01)
model = Model(pytorch_network, optimizer, criterion, batch_metrics=["acc"]).to(device)

# One-cycle schedule spanning every training step of the run. The lr tuple is
# presumably (start, peak, end) -- confirm against the poutyne docs.
# NOTE(review): one_cycle_phases may also schedule momentum by default, which the
# baseline SGD (momentum=0) does not use -- confirm before reading the comparison as lr-only.
total_steps = epochs * steps_per_epoch
policy = OptimizerPolicy(one_cycle_phases(total_steps, lr=(0.01, 0.1, 0.008)))

history = model.fit_generator(train_dl, val_dl, epochs=epochs, callbacks=[policy])

[93mEpoch: [94m1/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m20s [93mloss:[96m 1.817947[93m acc:[96m 34.648000[93m val_loss:[96m 1.542678[93m val_acc:[96m 46.480000[0m
[93mEpoch: [94m2/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m20s [93mloss:[96m 1.385061[93m acc:[96m 50.588000[93m val_loss:[96m 1.384743[93m val_acc:[96m 52.750000[0m
[93mEpoch: [94m3/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m20s [93mloss:[96m 1.142981[93m acc:[96m 59.612000[93m val_loss:[96m 1.107310[93m val_acc:[96m 61.340000[0m
[93mEpoch: [94m4/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m20s [93mloss:[96m 0.947667[93m acc:[96m 66.580000[93m val_loss:[96m 0.944277[93m val_acc:[96m 66.610000[0m
[93mEpoch: [94m5/5 [93mStep: [94m49/49 [93m100.00% |[92m█████████████████████████[93m|[32m20s [93mloss:[96m 0.801159[93m acc:[96m 71