# Snapshot Ensemble, modify

In [1]:
%cd 'drive/MyDrive/snapshot_ensembles'
!pip install wandb

/content/drive/MyDrive/snapshot_ensembles
Collecting wandb
  Downloading wandb-0.12.7-py2.py3-none-any.whl (1.7 MB)
[K     |████████████████████████████████| 1.7 MB 14.5 MB/s 
[?25hCollecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.8-py3-none-any.whl (9.5 kB)
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.5.0-py2.py3-none-any.whl (140 kB)
[K     |████████████████████████████████| 140 kB 63.1 MB/s 
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting configparser>=3.8.1
  Downloading configparser-5.2.0-py3-none-any.whl (19 kB)
Collecting GitPython>=1.0.0
  Downloading GitPython-3.1.24-py3-none-any.whl (180 kB)
[K     |████████████████████████████████| 180 kB 69.6 MB/s 
[?25hCollecting subprocess32>=3.5.3
  Downloading subprocess32-3.5.4.tar.gz (97 kB)
[K     |████████████████████████████████| 97 kB 7.7 MB/s 
Collecting yaspin>=1.0.0
  Downloading yaspin-2.1.0-py3-none-any.whl (18 kB)
Collecting pathtools
  Downl

In [2]:
from math import pi
from math import cos
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from resnet import *
# from utils import progress_bar
import wandb
import time
import copy

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = datasets.CIFAR10(
    root='data/cifar10', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = datasets.CIFAR10(
    root='data/cifar10', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [4]:
net1 = ResNet18().cuda()

In [5]:
wandb.init(project="Snapshot-cifar10",
           name='resnet_snapshot_100_ext2')

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


## proposed learning rate schedule

This paper modifies learning rate. In my experiments, I also focus on modifing the learning rate schedule.

In [8]:
def proposed_lr(initial_lr , epochs, stages, epoch, burnin=0.1, func=None, gamma=0.2):
    # proposed learning late function
    # func = None(cosine), 3steplr, 
    num_burnin = epochs * burnin
    epoch_per_cycle = (epochs - num_burnin) // stages
    percent = ((epoch-num_burnin) % epoch_per_cycle) / epoch_per_cycle
    if epoch < num_burnin:
        res = initial_lr
    # elif func is None:
    #     return initial_lr * (cos(pi * percent + 1)) / 2
    elif func == '3steplr':
        if percent < 0.33:
            res = initial_lr 
        elif percent < 0.66:
            res = initial_lr * gamma
        else:
            res = initial_lr * gamma * gamma
    else:
        res = initial_lr * (cos(pi * percent + 1)) / 2
    return res

## Optimizer and scheduler

In [9]:
lr = 0.1
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net1.parameters(), lr=lr,
                      momentum=0.9, weight_decay=5e-4)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, eta_min=1e-4)

## Train and test function

In [10]:
def train_se(epochs, stages, model, criterion, optimizer, 
             train_loader, test_loader, 
             scheduler=None, burnin=0.1,
             path='save_model/'):
    
    # train_errs = []
    # train_loss = []
    snapshots = []
    epochs_per_cycle = epochs // stages
    wandb.watch(model)
    for epoch in range(epochs):
        start = time.time()
        train_err, loss1 = train_epoch(model, criterion, optimizer, train_loader)
        test_err, loss2 = test(model, test_loader, criterion)
        print('Epoch {:03d}/{:03d}, train error: {:.2%} || test error {:.2%}'.format(epoch, epochs, train_err, test_err))
        # train_errs.append(train_err)
        # train_loss.append(loss1)
        
        if scheduler is None: 
            lr_epoch = proposed_lr(lr, epochs, stages, epoch, func='3steplr')
            optimizer.param_groups[0]['lr'] = lr_epoch
        else:
            scheduler.step()

        if (epoch+1) % epochs_per_cycle == 0:
            # torch.save(model.state_dict(), path+'ext_epoch=%d.pt'%epoch)
            snapshots.append(copy.deepcopy(model.state_dict()))
        # Log training..
        wandb.log({'train_loss': loss1, 'val_loss': loss2, 
                   "train_err":train_err, "val_err": test_err, 
                   "lr": optimizer.param_groups[0]["lr"],
                   "epoch_time": time.time()-start})
    return snapshots
    
def train_epoch(model, criterion, optimizer, loader):
    total_correct = 0.
    total_samples = 0.
    loss_sum = 0.
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.cuda(), target.cuda()
        output = model(data)
        loss = criterion(output, target)

        predictions = output.data.max(1, keepdim=True)[1]
        total_correct += predictions.eq(target.data.view_as(predictions)).sum().item()
        total_samples += len(target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()

    return 1 - total_correct/total_samples, loss_sum / (batch_idx+1)

def test(model, loader, criterion):
    total_correct = 0.
    total_samples = 0.
    loss_sum = 0.
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.cuda(), target.cuda()

            output = model(data)
            loss = criterion(output, target)
            loss_sum += loss.item()

            predictions = output.data.max(1, keepdim=True)[1]
            total_correct += predictions.eq(target.data.view_as(predictions)).sum().item()
            total_samples += len(target)

    return 1 - total_correct/total_samples, loss_sum / (batch_idx+1)

In [11]:
def test_se(snapshots, use_model_num, test_loader, path='save_model/', ensemble='average'):
    index = len(snapshots) - use_model_num
    snapshots = snapshots[index:]
    model_list = [ResNet18() for _ in snapshots]

    for model, weight in zip(model_list, snapshots):
        model.load_state_dict(weight)
        model.eval()
        if device=='cuda':
            model.cuda()

    total_correct = 0
    total_samples = 0
    loss_sum = 0
    for batch_idx, (data, target) in enumerate(test_loader):
        if device == 'cuda':
            data, target = data.cuda(), target.cuda()

        output_list = [model(data) for model in model_list]
        loss_ = [criterion(output, target).item() for output in output_list]

        # predictions = output.data.max(1, keepdim=True)[1]
        pred_list = [nn.Softmax(dim=1)(output) for output in output_list]
        # todo add more ensemble strategy
        if ensemble == 'average':
            predictions = sum(pred_list) / use_model_num
            
        else:  # ensemble == 'vote'
            pred_label_list = []
            for pred in pred_list:
                pred_m = torch.zeros_like(pred)
                pred_m[torch.arange(len(pred)), pred.argmax(1)] = 1
                pred_label_list.append(pred_m)
            predictions = sum(pred_label_list)

        total_correct += (predictions.argmax(1) == target).type(torch.float).sum().item()
        total_samples += len(target)
        # todo add more ensemble strategy
        loss_sum += sum(loss_) / len(model_list)

    test_loss = loss_sum/ (batch_idx+1)
    test_err = 1-total_correct/total_samples
    print('\nTest set: Average loss: {:.4f}, Error rate: {:.2%}\n'.format(
        test_loss, test_err))

    return test_loss, test_err

## Run

In [12]:
snapshots = train_se(epochs=100, stages=5, model=net1, 
                     criterion=criterion, optimizer=optimizer, 
                     train_loader=trainloader, test_loader=testloader, burnin=0.1)
                    #  scheduler=scheduler)

Epoch 000/100, train error: 72.32% || test error 60.64%
Epoch 001/100, train error: 53.50% || test error 50.22%
Epoch 002/100, train error: 42.35% || test error 39.77%
Epoch 003/100, train error: 33.70% || test error 36.54%
Epoch 004/100, train error: 28.82% || test error 31.95%
Epoch 005/100, train error: 24.12% || test error 26.79%
Epoch 006/100, train error: 21.24% || test error 25.23%
Epoch 007/100, train error: 19.65% || test error 22.49%
Epoch 008/100, train error: 18.25% || test error 23.72%
Epoch 009/100, train error: 17.53% || test error 20.66%
Epoch 010/100, train error: 16.73% || test error 19.89%
Epoch 011/100, train error: 16.14% || test error 21.58%
Epoch 012/100, train error: 15.91% || test error 20.59%
Epoch 013/100, train error: 15.33% || test error 19.13%
Epoch 014/100, train error: 14.73% || test error 18.30%
Epoch 015/100, train error: 14.54% || test error 18.64%
Epoch 016/100, train error: 14.25% || test error 17.84%
Epoch 017/100, train error: 8.52% || test error 

In [17]:
len(snapshots)

5

In [13]:
# average
test_se(snapshots, use_model_num=5, test_loader=testloader, ensemble='average')


Test set: Average loss: 0.2863, Error rate: 7.07%



(0.28628383238613603, 0.07069999999999999)

In [14]:
# average
test_se(snapshots, use_model_num=4, test_loader=testloader, ensemble='average')


Test set: Average loss: 0.2748, Error rate: 6.92%



(0.274834364131093, 0.06920000000000004)

In [15]:
# majority vote
test_se(snapshots, use_model_num=5, test_loader=testloader, ensemble='vote')


Test set: Average loss: 0.2863, Error rate: 7.40%



(0.28628383238613603, 0.07399999999999995)

In [16]:
# majority vote
test_se(snapshots, use_model_num=4, test_loader=testloader, ensemble='vote')


Test set: Average loss: 0.2748, Error rate: 7.41%



(0.274834364131093, 0.07410000000000005)

In [18]:
testnet1 = ResNet18()
testnet1.load_state_dict(snapshots[0])
testnet1.eval()
testnet1.cuda()
test(testnet1, testloader, criterion)

(0.11329999999999996, 0.33208170540630816)