In [None]:
import torch
from torch.utils import data
from torchvision import transforms
from torchvision.models import resnet18
from torchvision.datasets import cifar

import copy
import time


sample_data = cifar.CIFAR10('./', download=True)

means = sample_data.data.mean(axis=(0, 1, 2)) / 255
stds = sample_data.data.std(axis=(0, 1, 2)) / 255

In [None]:
jitter_transform = transforms.RandomApply([
                                            transforms.ColorJitter(
                                            brightness=(0.5, 1.5),
                                            contrast=(0.5, 1.5),
                                            saturation=(0.5, 1.5),
                                            hue=(-0.5, 0.5))],
                                           p=0.5)

train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(),
                                       transforms.RandomResizedCrop((32, 32),
                                                                    (0.5, 1.0)),
                                       jitter_transform,
                                       transforms.ToTensor(),
                                       transforms.Normalize(means, stds)])

test_transforms = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize(means, stds)])

train_data = cifar.CIFAR10('./', download=True, transform=train_transforms)
test_data = cifar.CIFAR10('./', train=False, download=True,
                             transform=test_transforms)

In [3]:
TRAIN_RATIO = 0.9
num_train_samples = int(len(train_data) * TRAIN_RATIO)
num_valid_samples = len(train_data) - num_train_samples
split = [num_train_samples, num_valid_samples]

train_data, valid_data = data.random_split(train_data, lengths=split)

In [4]:
valid_data = copy.deepcopy(valid_data)
valid_data.dataset.transforms = test_transforms

In [5]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 256

train_iterator = data.DataLoader(train_data, BATCH_SIZE, shuffle=True)
valid_iterator = data.DataLoader(valid_data, BATCH_SIZE)
test_iterator = data.DataLoader(test_data, BATCH_SIZE)

In [6]:
dataloaders = {"train": train_iterator, "val": valid_iterator}
dataset_sizes = {"train": len(train_data.indices),
                 'val': len(valid_data.indices)}

In [7]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(DEVICE)
                labels = labels.to(DEVICE)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item()
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            if phase == 'train':  # take scheduler step on train acc
                scheduler.step(epoch_acc)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [None]:
NUM_EPOCHS = 200
LR = 0.1
LR_DECAY = 0.1

model = resnet18()
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 10)

criterion = torch.nn.CrossEntropyLoss()

# same parameters as the ResNet paper
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.99,
                            weight_decay=1e-4)

# patience is not known from the paper
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       factor=LR_DECAY,
                                                       patience=NUM_EPOCHS / 10,
                                                       verbose=True)

model.to(DEVICE)
criterion.to(DEVICE)

In [None]:
train_model(model, criterion, optimizer, scheduler, num_epochs=NUM_EPOCHS)

Epoch 0/199
----------
train Loss: 0.0102 Acc: 0.1635
val Loss: 0.0086 Acc: 0.2032

Epoch 1/199
----------
train Loss: 0.0079 Acc: 0.2370
val Loss: 0.0080 Acc: 0.2288

Epoch 2/199
----------
train Loss: 0.0076 Acc: 0.2760
val Loss: 0.0078 Acc: 0.2722

Epoch 3/199
----------
train Loss: 0.0075 Acc: 0.2827
val Loss: 0.0076 Acc: 0.2950

Epoch 4/199
----------
train Loss: 0.0084 Acc: 0.2024
val Loss: 0.0090 Acc: 0.1508

Epoch 5/199
----------
train Loss: 0.0082 Acc: 0.1780
val Loss: 0.0083 Acc: 0.1940

Epoch 6/199
----------
train Loss: 0.0080 Acc: 0.2067
val Loss: 0.0083 Acc: 0.2022

Epoch 7/199
----------
train Loss: 0.0078 Acc: 0.2306
val Loss: 0.0079 Acc: 0.2454

Epoch 8/199
----------
train Loss: 0.0075 Acc: 0.2517
val Loss: 0.0079 Acc: 0.2308

Epoch 9/199
----------
train Loss: 0.0076 Acc: 0.2531
val Loss: 0.0076 Acc: 0.2636

Epoch 10/199
----------
train Loss: 0.0074 Acc: 0.2663
val Loss: 0.0076 Acc: 0.2528

Epoch 11/199
----------
train Loss: 0.0072 Acc: 0.2890
val Loss: 0.0073 Acc