In [49]:
import torch
import numpy as np
import torch.nn as nn
import math
from torch.optim.optimizer import Optimizer

In [76]:
def version_at_least(version, minimum):
    """Return True when a dotted version string is numerically >= ``minimum``.

    Comparing version strings directly is lexicographic, so "1.10.0" would
    sort *before* "1.5.0". This helper compares the leading integer
    components instead.

    Arguments:
        version: version string such as "1.10.0" or "2.1.0+cu118"; anything
            after a '+' (local build tag) is ignored, and non-digit suffixes
            inside a component (e.g. "0a0") are dropped.
        minimum: tuple of ints, e.g. (1, 5, 0).

    Returns:
        bool: whether the parsed version is at least ``minimum``.
    """
    wanted = tuple(minimum)
    pieces = str(version).split('+')[0].split('.')[:len(wanted)]
    parsed = []
    for piece in pieces:
        digits = ''
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parsed.append(int(digits) if digits else 0)
    # Pad short versions ("2.1") so the tuples compare component-wise.
    parsed.extend([0] * (len(wanted) - len(parsed)))
    return tuple(parsed) >= wanted


# True when the installed torch is 1.5.0 or newer (numeric, not lexicographic, comparison).
version_higher = version_at_least(torch.__version__, (1, 5, 0))

class NestAdam(Optimizer):
    """Adam-style optimizer with a Nesterov-like look-ahead gradient.

    Per parameter it keeps an exponential moving average of the (look-ahead)
    gradient, ``exp_avg``, and an exponential moving average of the squared
    residual ``(grad - exp_avg)``, ``exp_avg_var``. The update is either a
    bias-corrected Adam-style step or, when ``rectify`` is True, a rectified
    step whose step size is cached per group in ``group['buffer']``.

    Arguments:
        params: iterable of parameters or list of parameter-group dicts.
        lr (float): learning rate.
        betas (Tuple[float, float]): decay rates for the two running averages.
        eps (float): small constant added to the variance estimate and to the
            denominator.
        weight_decay (float): weight decay coefficient.
        amsgrad (bool): keep a running maximum of ``exp_avg_var`` and use it
            for the denominator of the non-rectified update.
        weight_decouple (bool): if True, decay multiplies the weights directly;
            if False, ``weight_decay * p`` is added to the gradient.
        fixed_decay (bool): with decoupled decay, multiply by
            ``1 - weight_decay`` instead of ``1 - lr * weight_decay``.
        rectify (bool): use the rectified update branch.
        degenerated_to_sgd (bool): in the rectified branch, fall back to a
            momentum-only step while the variance estimate is not yet trusted
            (``N_sma < 5``); otherwise such steps are skipped.

    Raises:
        ValueError: if ``lr``, ``eps``, ``betas`` or ``weight_decay`` is out
            of range.
    """

    def __init__(self, params, lr=1e-03, betas=(0.9, 0.999), eps=1e-16,
                 weight_decay=0, amsgrad=False, weight_decouple=True, fixed_decay=False,
                 rectify=True, degenerated_to_sgd=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        # The rectification cache stores step sizes that depend on betas, so a
        # group that overrides betas must get its own cache instead of sharing
        # the default one.
        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
            for param in params:
                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
                    param['buffer'] = [[None, None, None] for _ in range(10)]

        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, buffer=[[None, None, None] for _ in range(10)])
        super(NestAdam, self).__init__(params, defaults)

        self.degenerated_to_sgd = degenerated_to_sgd
        self.weight_decouple = weight_decouple
        self.rectify = rectify
        self.fixed_decay = fixed_decay

    def __setstate__(self, state):
        """Restore optimizer state, defaulting ``amsgrad`` to False for groups that lack it."""
        super(NestAdam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @staticmethod
    def _zeros_like(tensor):
        """Return a zero tensor shaped like ``tensor``, preserving memory format when supported.

        Feature detection is used instead of comparing version strings, which
        mis-orders versions such as "1.10.0" and "1.5.0".
        """
        try:
            return torch.zeros_like(tensor, memory_format=torch.preserve_format)
        except (AttributeError, TypeError):
            # Older torch: no ``preserve_format`` / no ``memory_format`` argument.
            return torch.zeros_like(tensor)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        Returns:
            The value returned by ``closure``, or None when no closure is given.
        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr = group['lr']
            eps = group['eps']
            weight_decay = group['weight_decay']
            amsgrad = group['amsgrad']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'NestAdam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = self._zeros_like(p.data)
                    # Exponential moving average of squared gradient residuals
                    state['exp_avg_var'] = self._zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. residual values
                        state['max_exp_avg_var'] = self._zeros_like(p.data)

                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

                state['step'] += 1
                step = state['step']
                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step

                # Coupled (L2) weight decay must enter the gradient *before* the
                # moment estimates are updated, otherwise it never reaches the
                # parameter update. Done out of place so p.grad is left intact.
                if not self.weight_decouple and weight_decay != 0:
                    grad = grad.add(p.data, alpha=weight_decay)

                # Approximate Nesterov's Accelerated Gradient: blend the raw
                # gradient with a one-step look-ahead of the first moment.
                # A fresh tensor is produced so the caller's p.grad is not modified.
                lookahead = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
                grad = grad.mul(1 - beta1).add_(lookahead, alpha=beta1)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # Update the running average of the squared residual
                grad_residual = grad - exp_avg
                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

                # NOTE: add_(eps) below is in place, so eps is accumulated into
                # the stored state on every step (kept as in the original code).
                if amsgrad:
                    max_exp_avg_var = state['max_exp_avg_var']
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_var, exp_avg_var, out=max_exp_avg_var)

                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_var.add_(eps).sqrt() / math.sqrt(bias_correction2)).add_(eps)
                else:
                    denom = (exp_avg_var.add_(eps).sqrt() / math.sqrt(bias_correction2)).add_(eps)

                # Decoupled weight decay shrinks the weights directly.
                if self.weight_decouple:
                    if not self.fixed_decay:
                        p.data.mul_(1.0 - lr * weight_decay)
                    else:
                        p.data.mul_(1.0 - weight_decay)

                if not self.rectify:
                    # Default update
                    step_size = lr / bias_correction1
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
                    continue

                # Rectified update, forked from RAdam. The (N_sma, step_size)
                # pair depends only on betas and the step count, so it is
                # cached in a 10-slot ring keyed by step % 10.
                buffered = group['buffer'][int(step % 10)]
                if step == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = step
                    beta2_t = beta2 ** step
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                    N_sma_max - 2)) / (1 - beta1 ** step)
                    elif self.degenerated_to_sgd:
                        step_size = 1.0 / (1 - beta1 ** step)
                    else:
                        # Negative sentinel: skip the update for this step.
                        step_size = -1
                    buffered[2] = step_size

                if N_sma >= 5:
                    denom = exp_avg_var.sqrt().add_(eps)
                    p.data.addcdiv_(exp_avg, denom, value=-step_size * lr)
                elif step_size > 0:
                    p.data.add_(exp_avg, alpha=-step_size * lr)

        return loss

In [51]:
import torchvision
import torchvision.transforms as transforms

In [52]:
# Per-channel normalisation applied after conversion to a tensor.
# NOTE(review): these values look like the commonly published ImageNet channel
# statistics — confirm they are the intended ones for this dataset.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Steps shared by both pipelines: PIL image -> tensor -> normalised tensor.
_to_normalized_tensor = [transforms.ToTensor(), normalize]

# Training pipeline: resize, then random crop + random horizontal flip.
train_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]
    + _to_normalized_tensor
)

# Evaluation pipeline: resize and centre crop, no random steps.
test_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
    + _to_normalized_tensor
)

In [53]:
batch_size = 100

# Settings common to the training and test loaders.
_loader_kwargs = dict(batch_size=batch_size, num_workers=2)

# CIFAR-10 splits; downloaded into ./data when not already present.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform)

# Only the training loader shuffles.
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_kwargs)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [54]:
import time

def timeSince(since):
    """Format the wall-clock time elapsed since ``since`` as '<minutes>m <seconds>s'.

    Arguments:
        since: a timestamp as returned by ``time.time()``.

    Returns:
        str: elapsed whole minutes and the remaining whole seconds.
    """
    elapsed = time.time() - since
    minutes, seconds = divmod(elapsed, 60)
    return '%dm %ds' % (minutes, seconds)

In [55]:
import torchvision.models as models

In [56]:
def evaluateImageModel(model, data, criterion, device):
    """Evaluate ``model`` on every batch of ``data`` without tracking gradients.

    The model is switched to eval mode and is left in eval mode on return.

    Arguments:
        model: callable module mapping a batch of inputs to class scores.
        data: sized iterable yielding ``(inputs, labels)`` batches.
        criterion: loss function called as ``criterion(output, labels)``.
        device: device the batches are moved to before the forward pass.

    Returns:
        tuple: (sum of per-batch losses divided by the number of batches,
        fraction of misclassified samples).
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in data:
            inputs = inputs.to(device)
            targets = targets.to(device)
            scores = model(inputs)
            predicted = scores.argmax(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_seen += len(targets)
            loss_sum += criterion(scores, targets).item()

    mean_batch_loss = loss_sum / len(data)
    error_rate = 1 - n_correct / n_seen
    return mean_batch_loss, error_rate

def trainImageModel(model, train_data, val_data, n_epochs, criterion, optimizer, device, path):
    """Train ``model`` for ``n_epochs`` passes over ``train_data``.

    Every ``len(train_data)`` optimisation steps the model is evaluated on
    ``val_data`` with ``evaluateImageModel``, a progress line is printed, and
    the whole model object is saved to ``path`` whenever the validation error
    is lower than any seen so far (or on the first evaluation).

    Arguments:
        model: module to train.
        train_data: sized iterable of ``(inputs, labels)`` training batches.
        val_data: iterable of validation batches, passed to ``evaluateImageModel``.
        n_epochs: number of passes over ``train_data``.
        criterion: loss function called as ``criterion(output, labels)``.
        optimizer: optimizer exposing ``zero_grad()`` and ``step()``.
        device: device the batches are moved to.
        path: destination given to ``torch.save`` for the best model.

    Returns:
        tuple: (list of training error rates, list of validation error rates),
        one entry per evaluation.
    """
    model.train()
    train_errors, val_errors = [], []
    best_val_error = None
    batches_done = 0
    report_interval = len(train_data)
    started_at = time.time()

    for epoch_idx in range(n_epochs):
        loss_acc = 0.0
        hits = 0
        seen = 0
        for inputs, targets in train_data:
            optimizer.zero_grad()
            batches_done += 1
            inputs, targets = inputs.to(device), targets.to(device)
            scores = model(inputs)
            hits += (torch.argmax(scores, dim=1) == targets).sum().item()
            seen += len(targets)
            batch_loss = criterion(scores, targets)
            loss_acc += batch_loss.item()
            batch_loss.backward()
            optimizer.step()

            if batches_done % report_interval != 0:
                continue

            # Reporting point: evaluate, log, and checkpoint on improvement.
            val_loss, val_error = evaluateImageModel(model, val_data, criterion, device)
            train_error = 1 - hits / seen
            print(('%d/%d (%s) train loss: %.3f, train error: %.2f%%, val loss: %.3f, val error: %.2f%%') %
                  (epoch_idx + 1, n_epochs, timeSince(started_at), loss_acc / len(train_data),
                   100 * train_error, val_loss, 100 * val_error))
            train_errors.append(train_error)
            val_errors.append(val_error)

            if best_val_error is None or best_val_error > val_error:
                if best_val_error is None:
                    print(('Validation error rate in first epoch: %.2f%%') % (100 * val_error))
                else:
                    print(('Validation error rate is decreasing: %.2f%% --> %.2f%%') %
                          (100 * best_val_error, 100 * val_error))
                best_val_error = val_error
                print('Saving model...')
                torch.save(model, path)

            # Evaluation left the model in eval mode; switch back and reset
            # the running statistics for the next reporting window.
            model.train()
            loss_acc = 0.0
            hits = 0
            seen = 0

    return train_errors, val_errors

In [77]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the model to its device *before* building the optimizer, as the
# torch.optim documentation recommends, so the optimizer is constructed
# against the parameters in their final location.
# NOTE(review): resnet34() is built with its default output size while
# CIFAR-10 has 10 classes — confirm this is intended.
model = models.resnet34()
model.to(device)

criterion = nn.CrossEntropyLoss()
n_epochs = 50

# Reset the result names so values from an earlier run cannot survive if
# training is interrupted.
nadam_Res_train_list = None
nadam_Res_test_list = None

optimizer = NestAdam(model.parameters())
path = 'nadamRes.pt'
nadam_Res_train_list, nadam_Res_test_list = trainImageModel(model, trainloader, testloader,
                                                            n_epochs, criterion, optimizer,
                                                            device, path)

# Persist the per-evaluation error curves.
torch.save(nadam_Res_train_list, 'nadamres_train.pt')
torch.save(nadam_Res_test_list, 'nadamres_test.pt')

1/50 (2m 49s) train loss: 2.058, train error: 70.78%, val loss: 1.634, val error: 60.43%
Validation error rate in first epoch: 60.43%
Saving model...
2/50 (5m 39s) train loss: 1.685, train error: 61.83%, val loss: 1.499, val error: 55.53%
Validation error rate is decreasing: 60.43% --> 55.53%
Saving model...
3/50 (8m 28s) train loss: 1.452, train error: 52.28%, val loss: 1.388, val error: 50.15%
Validation error rate is decreasing: 55.53% --> 50.15%
Saving model...
4/50 (11m 17s) train loss: 1.296, train error: 46.50%, val loss: 1.027, val error: 35.97%
Validation error rate is decreasing: 50.15% --> 35.97%
Saving model...
5/50 (14m 7s) train loss: 1.169, train error: 41.72%, val loss: 0.985, val error: 33.17%
Validation error rate is decreasing: 35.97% --> 33.17%
Saving model...
6/50 (16m 55s) train loss: 1.075, train error: 37.96%, val loss: 0.851, val error: 28.69%
Validation error rate is decreasing: 33.17% --> 28.69%
Saving model...
7/50 (19m 44s) train loss: 1.009, train error: 3

KeyboardInterrupt: ignored