# Train CIFAR-10 with the `policies` module

In [None]:
import torch

In [None]:
# Use the first CUDA GPU when PyTorch reports one as available; otherwise
# fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device  # displayed as the cell output

# Load the data

In [None]:
from torchvision import transforms

# Per-channel mean / standard deviation handed to `Normalize` in both pipelines.
# NOTE(review): these look like the commonly used ImageNet statistics rather
# than values computed on CIFAR-10 -- confirm that this is intended.
_mean = [0.485, 0.456, 0.406]
_std = [0.229, 0.224, 0.225]


# Validation pipeline: tensor conversion and normalisation only, no augmentation.
val_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_mean, std=_std),
    ]
)

# Training pipeline: random flip and colour jitter first, then the same
# conversion and normalisation steps as the validation pipeline.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
        transforms.ToTensor(),
        transforms.Normalize(mean=_mean, std=_std),
    ]
)

In [None]:
import torchvision.datasets as datasets

# Directory that holds the CIFAR-10 files for both splits.
root = "data"

# Training split, using the augmenting transform.
train_ds = datasets.CIFAR10(
    root=root,
    train=True,
    download=True,
    transform=train_transform,
)
# Held-out split (train=False), using the augmentation-free transform.
val_ds = datasets.CIFAR10(
    root=root,
    train=False,
    download=True,
    transform=val_transform,
)

In [None]:
from torch.utils.data import DataLoader

BATCH_SIZE = 1024

# Settings shared by the training and validation loaders.
_loader_options = {"batch_size": BATCH_SIZE, "num_workers": 8}

# Only the training data is reshuffled; validation batches keep dataset order.
train_dl = DataLoader(train_ds, shuffle=True, **_loader_options)
val_dl = DataLoader(val_ds, shuffle=False, **_loader_options)

# The model
We'll train a simple ResNet-18 network.
Training takes a while on a CPU but is fairly quick on a GPU.

In [None]:
from torchvision.models import resnet18
import torch.nn as nn
import torch.optim as optim


def get_module(num_classes=10):
    """Build an untrained ResNet-18 whose head outputs `num_classes` logits.

    Args:
        num_classes: Width of the final linear layer. Defaults to 10, which
            matches the CIFAR-10 dataset loaded in this notebook.

    Returns:
        The torchvision ResNet-18 module with randomly initialised weights,
        an adaptive average pool and a freshly created classification head.
    """
    # Random initialisation is already the default, so no argument is needed;
    # this also avoids the `pretrained=` keyword that recent torchvision
    # releases deprecate.
    model = resnet18()
    # Always pool down to a 1x1 feature map, whatever the spatial size of the
    # feature map that reaches this layer.
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    # Size the new head from the layer it replaces instead of hard-coding 512.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

In [None]:
# Number of training epochs; used by both training runs in this notebook and
# to compute the total number of steps of the learning-rate policy.
epochs = 5

# Training without the `policies` module

In [None]:
from pytoune.framework import Model

In [None]:
# Baseline run: SGD at lr=0.01 and no callbacks.
criterion = nn.CrossEntropyLoss()
pytorch_module = get_module()
pytorch_module = pytorch_module.to(device)
optimizer = optim.SGD(pytorch_module.parameters(), lr=0.01)

# Wrap the network, optimizer and loss in a pytoune `Model` that tracks
# accuracy, and place the wrapper on the same device as the network.
model = Model(pytorch_module, optimizer, criterion, metrics=["acc"]).to(device)

# Train on `train_dl`, validating on `val_dl`, for `epochs` epochs.
history = model.fit_generator(train_dl, val_dl, epochs=epochs)

# Training with the `policies` module

In [None]:
# Number of batches `train_dl` yields per epoch. It is multiplied by `epochs`
# further down to get the total step count handed to `one_cycle_phases`.
steps_per_epoch = len(train_dl)
steps_per_epoch  # displayed as the cell output

In [None]:
from pytoune.framework import OptimizerPolicy, one_cycle_phases


# Same setup as the baseline run (fresh network, SGD at lr=0.01, accuracy
# metric); the only difference is the `OptimizerPolicy` callback added below.
criterion = nn.CrossEntropyLoss()
pytorch_module = get_module()
pytorch_module = pytorch_module.to(device)
optimizer = optim.SGD(pytorch_module.parameters(), lr=0.01)

model = Model(pytorch_module, optimizer, criterion, metrics=["acc"]).to(device)

# The schedule is expressed in optimizer steps, so it spans every batch of
# every epoch.
total_steps = epochs * steps_per_epoch
# NOTE(review): `lr` is given three values here -- confirm against the
# installed `one_cycle_phases` signature what each of them controls.
phases = one_cycle_phases(total_steps, lr=(0.01, 0.1, 0.008))
policy = OptimizerPolicy(phases)

history = model.fit_generator(train_dl, val_dl, epochs=epochs, callbacks=[policy])