More details [here1](https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/), [here2](https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/).

In [1]:
import torch
from torch import nn
import torchvision
import numpy as np
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR

In [3]:
BS = 2048

transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=BS, shuffle=True)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
valid_loader = torch.utils.data.DataLoader(
    testset, batch_size=BS, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
model = models.resnet18(pretrained=True).cuda(0)
model.fc = torch.nn.Linear(in_features=512, out_features=10).cuda(0)
model.fc.requires_grad_(True)

Linear(in_features=512, out_features=10, bias=True)

In [12]:
%%time

LR = 0.003
EPOCHS = 10
swa_start = 8
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
swa_model = AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=100)
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

cost_function = torch.nn.CrossEntropyLoss()
TRAIN_STEP = len(trainset)/BS
VALID_STEP = len(testset)/BS

TRAIN_LOSS = []
VAL_LOSS = []
VAL_ACC = []

for epoch in range(1,EPOCHS+1):
  print("Epoch", epoch)
  running_loss = 0.0
  model.train()
  for (x, y) in train_loader:
    optimizer.zero_grad()
    x, y = x.cuda(0), y.cuda(0)
    z = model(x)
    loss = cost_function(z, y)
    running_loss+=loss.detach()
    loss.backward()
    optimizer.step()

  train_l = running_loss/TRAIN_STEP
  TRAIN_LOSS.append(train_l.item())
  total, correct = 0, 0 
  print("Training loss:", train_l.item())

  if epoch > swa_start:
    print('Updating swa')
    swa_model.update_parameters(model)
    swa_scheduler.step()
  else:
    running_loss = 0.0
    model.eval()
    for (x,y) in valid_loader:
      x, y = x.cuda(0), y.cuda(0)
      with torch.no_grad():
        z = model(x)
      loss = cost_function(z, y)
      running_loss+=loss.detach()
      _, yhat = torch.max(z,1)
      total += y.size(0)
      correct += (yhat == y).sum().item()
    scheduler.step()
    valid_l = running_loss/VALID_STEP
    valid_a = 100 * correct / total
    print("Valid loss:", valid_l.item())
    VAL_LOSS.append(valid_l.item())
    print('VAL_Accuracy: ', valid_a)
    VAL_ACC.append(valid_a)
  print('')
print("Best accuracy", max(VAL_ACC))    

Epoch 1
Training loss: 0.9224651455879211
Valid loss: 0.9948633909225464
VAL_Accuracy:  70.4

Epoch 2
Training loss: 0.5696503520011902
Valid loss: 0.7134029865264893
VAL_Accuracy:  77.25

Epoch 3
Training loss: 0.4732469618320465
Valid loss: 0.8396474123001099
VAL_Accuracy:  74.68

Epoch 4
Training loss: 0.39830276370048523
Valid loss: 0.662550151348114
VAL_Accuracy:  79.16

Epoch 5
Training loss: 0.35848575830459595
Valid loss: 0.7014987468719482
VAL_Accuracy:  78.68

Epoch 6
Training loss: 0.30550530552864075
Valid loss: 0.6869193911552429
VAL_Accuracy:  79.56

Epoch 7
Training loss: 0.26629510521888733
Valid loss: 0.7214446067810059
VAL_Accuracy:  79.44

Epoch 8
Training loss: 0.23825792968273163
Valid loss: 0.6263319849967957
VAL_Accuracy:  82.36

Epoch 9
Training loss: 0.20299969613552094
Updating swa
Epoch 10
Training loss: 0.24596136808395386
Updating swa
Best accuracy 82.36
CPU times: user 11min 47s, sys: 7min 56s, total: 19min 43s
Wall time: 19min 43s


In [13]:
# Update batch norm statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(train_loader, swa_model, device='cuda:0')  

In [14]:
correct, total = 0,0
for (x,y) in valid_loader:
  x, y = x.cuda(0), y.cuda(0)
  with torch.no_grad():
    z = swa_model(x)
  loss = cost_function(z, y)
  running_loss+=loss.detach()
  _, yhat = torch.max(z,1)
  total += y.size(0)
  correct += (yhat == y).sum().item()
valid_a = 100 * correct / total
print("VAL_Accuracy:", valid_a)

VAL_Accuracy: 84.55
