In [1]:
%matplotlib inline

In [2]:
import glob
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from model import *
from swd_optim import *

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Project root; the CIFAR-10 files are stored under f'{base_path}/data'.
base_path = "/mnt/workspace/stable-weight-decay-regularization"

# Initial learning rate for each optimizer that can be selected through `mode`.
mode_lr = {
    'SGDS': 0.1,
    'SGD': 0.1,
    'Adam': 1e-3,
    'AdamW': 1e-3,
    'AdamS': 1e-3
}
# Optimizer to train with. It must be a key of `mode_lr`: the previous empty
# string made `mode_lr[mode]` raise KeyError on a fresh "Run All". 'SGDS' is
# the value the recorded run of this notebook used.
mode = 'SGDS'
# Passed as the `depth` argument of the ViT constructor in define_models().
depth = 2

In [4]:
print('==> Preparing data..')

# ViT preprocessing: both pipelines end with a 224x224 tensor normalised with
# the same per-channel mean / standard deviation.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training images: resize straight to 224x224, then random horizontal and
# vertical flips as augmentation.
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])

# Test images: resize to 256, take the central 224x224 crop, no augmentation.
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])

==> Preparing data..


In [5]:
def optimizers(net, opti_name, lr, weight_decay):
    """Build the optimizer called ``opti_name`` for the parameters of ``net``.

    Args:
        net: module whose ``parameters()`` are optimised.
        opti_name: one of 'VanillaSGD', 'SGD', 'SGDS', 'Adam', 'AMSGrad',
            'AdamW', 'AdamS', 'Adai', 'AdaiS'.
        lr: initial learning rate.
        weight_decay: weight-decay coefficient.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: if ``opti_name`` is not one of the supported names.
            (The previous ``raise '<str>'`` is itself invalid in Python 3 and
            surfaced as an unrelated TypeError.)
    """
    # Each factory is only evaluated for the selected name, so optimizers
    # provided by swd_optim are looked up lazily, exactly as in an if/elif chain.
    factories = {
        'VanillaSGD': lambda params: optim.SGD(
            params, lr=lr, momentum=0, weight_decay=weight_decay),
        'SGD': lambda params: optim.SGD(
            params, lr=lr, momentum=0.9, weight_decay=weight_decay, nesterov=False),
        'SGDS': lambda params: SGDS(
            params, lr=lr, momentum=0.9, weight_decay=weight_decay, nesterov=False),
        'Adam': lambda params: optim.Adam(
            params, lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay),
        'AMSGrad': lambda params: optim.Adam(
            params, lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay,
            amsgrad=True),
        # AdamW is the only optimizer that receives weight_decay / lr.
        # NOTE(review): presumably this compensates for AdamW scaling its
        # decoupled decay by the learning rate - confirm against the paper.
        'AdamW': lambda params: optim.AdamW(
            params, lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay/lr),
        'AdamS': lambda params: AdamS(
            params, lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay,
            amsgrad=False),
        'Adai': lambda params: Adai(
            params, lr=lr, betas=(0.1, 0.99), eps=1e-03, weight_decay=weight_decay),
        'AdaiS': lambda params: AdaiS(
            params, lr=lr, betas=(0.1, 0.99), eps=1e-03, weight_decay=weight_decay),
    }
    build = factories.get(opti_name)
    if build is None:
        raise ValueError(
            'Unspecified optimizer. Got %r, expected one of: %s'
            % (opti_name, ', '.join(factories))
        )
    return build(net.parameters())
    

In [6]:
# Loss shared by train() and test(): cross-entropy averaged over the batch
# (reduction='mean'). train() multiplies it by the batch size again so that its
# running total is a per-sample sum.
criterion = nn.CrossEntropyLoss(reduction='mean')

In [7]:
def train(net, optimizer, epoch):
    """Run one optimisation pass over the notebook-level ``trainloader``.

    Args:
        net: model to train; switched to training mode here.
        optimizer: optimizer that updates ``net`` after every batch.
        epoch: zero-based epoch index, only used for the progress print.

    Returns:
        ``(training_error, mean_training_loss)`` where both values are averaged
        over all samples seen during the pass.
    """
    print('Epoch: %d' % (epoch+1))
    net.train()

    loss_sum, n_correct, n_seen = 0, 0, 0
    for images, labels in trainloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = net(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        n_batch = labels.size(0)
        # criterion returns a batch mean; weight it by the batch size so the
        # final division yields a per-sample average.
        loss_sum += batch_loss.item() * n_batch
        predictions = logits.max(1)[1]
        n_seen += n_batch
        n_correct += predictions.eq(labels).sum().item()

    mean_loss = loss_sum/n_seen
    error_rate = 1 - n_correct/n_seen
    print("Training Loss: ", mean_loss)
    print("Training error:", error_rate)
    return error_rate, mean_loss

def test(net):
    """Evaluate ``net`` on the notebook-level ``testloader``.

    The model is put in eval mode and gradients are disabled. Only the
    classification error is computed: the loss that used to be accumulated here
    was never printed or returned, so that dead work has been removed.

    Args:
        net: model to evaluate.

    Returns:
        The test error, i.e. ``1 - correct / total`` over the whole loader.

    Raises:
        ZeroDivisionError: if ``testloader`` yields no samples.
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            # Index of the highest logit per sample is the predicted class.
            predicted = net(inputs).max(1)[1]
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    test_error = 1 - correct/total
    print("Test error:", test_error)
    return test_error

In [8]:
def define_models():
    """Create a new ``ViT`` with 10 output classes.

    The ``depth`` argument is taken from the notebook-level ``depth`` setting.
    """
    vit_kwargs = {'n_classes': 10, 'depth': depth}
    return ViT(**vit_kwargs)

In [9]:
# CIFAR-10 train / test splits, downloaded into f'{base_path}/data' when missing.
# The train split is built first, then the test split, each with its own
# transform pipeline.
trainset, testset = (
    torchvision.datasets.CIFAR10(
        root=f'{base_path}/data',
        train=is_train_split,
        download=True,
        transform=split_transform,
    )
    for is_train_split, split_transform in (
        (True, transform_train),
        (False, transform_test),
    )
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:26<00:00, 6320778.87it/s]


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [10]:
def optimizer_peformance(model, learning_rate, batch_size, weight_decay, epochs, N, mode):
    """Train a fresh model with one optimizer and record its learning curves.

    Args:
        model: model label, used only as a prefix of the saved CSV name.
        learning_rate: initial learning rate for the optimizer.
        batch_size: forwarded to ``save_err`` for the CSV name.
        weight_decay: weight-decay coefficient for the optimizer.
        epochs: number of epochs to train.
        N: forwarded to ``save_err`` for the CSV name.
        mode: optimizer name understood by ``optimizers``.

    Returns:
        ``(train_loss, train_err, test_err)``: three lists with one entry per
        epoch. The same curves are also written to disk through ``save_err``.
    """
    net = define_models().to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    optimizer = optimizers(net, mode, learning_rate, weight_decay)

    def step_decay(epoch):
        # Multiply the base learning rate by 0.1 after every 80 epochs.
        return 0.1 ** (epoch // 80)

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=step_decay)

    train_loss, train_err, test_err = [], [], []

    start_time = time.time()
    print("-" * 30 + " Mode " + mode + " starts")
    for epoch in range(epochs):
        epoch_err, epoch_loss = train(net, optimizer, epoch)
        train_err.append(epoch_err)
        train_loss.append(epoch_loss)
        test_err.append(test(net))
        scheduler.step()
        elapsed = time.time() - start_time
        print("--- %s seconds ---" % elapsed)

    save_err(
        {mode: train_loss}, {mode: train_err}, {mode: test_err},
        model + '_' + mode, learning_rate, batch_size, weight_decay, epochs, N,
    )
    return train_loss, train_err, test_err

In [11]:
def optimizer_performance_comparison(model, batch_size, weight_decay, epochs, N):
    """Run the experiment for the notebook-level ``mode`` and key the curves by it.

    The learning rate is looked up in ``mode_lr[mode]``; the optimizer name
    handed to ``optimizer_peformance`` is the first whitespace-separated token
    of ``mode``.

    Returns:
        ``(train_loss, train_err, test_err)``: three dicts, each mapping
        ``mode`` to the corresponding per-epoch list.
    """
    optimizer_name = mode.split(None, 1)[0]
    losses, train_errors, test_errors = optimizer_peformance(
        model, mode_lr[mode], batch_size, weight_decay, epochs, N, optimizer_name
    )
    return {mode: losses}, {mode: train_errors}, {mode: test_errors}
    

In [12]:
def plot_figure(model, train_loss, train_err, test_err, batch_size, weight_decay, epochs, N):
    """Plot and save the test-error and training-loss curves.

    Only optimizers that are actually present in ``test_err`` / ``train_loss``
    are drawn. Previously a fixed list of five optimizers was indexed
    unconditionally, which raised KeyError for the single-optimizer results
    produced by ``optimizer_performance_comparison``.

    Args:
        model: model label, used in the output file names.
        train_loss: dict mapping optimizer name to per-epoch training loss.
        train_err: unused; kept for interface compatibility.
        test_err: dict mapping optimizer name to per-epoch test error.
        batch_size: used in the output file names.
        weight_decay: unused; kept for interface compatibility.
        epochs: upper limit of the x axis; also used in the output file names.
        N: used in the output file names.

    Side effects:
        Updates three matplotlib rcParams and writes
        ``Test_errors_<name>.png/.pdf`` and ``Training_loss_<name>.png/.pdf``
        to the current working directory.
    """
    figure_name = model + '_B'+str(batch_size) + '_N'+ str(N) + '_E' + str(epochs)

    plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
    plt.rcParams['image.interpolation'] = 'nearest'
    plt.rcParams['image.cmap'] = 'gray'

    # Canonical ordering first, so each optimizer keeps the colour it had when
    # all five were plotted; any other optimizer found in the results follows.
    preferred = ['SGD', 'SGDS', 'Adam', 'AdamW', 'AdamS']
    colors = ['red','blue','green','orange','pink','cyan','brown','yellow','black']
    ordered = list(dict.fromkeys(preferred + list(test_err) + list(train_loss)))
    color_of = {name: colors[idx % len(colors)] for idx, name in enumerate(ordered)}

    def _plot_curves(curves, ylabel, ylim, file_prefix, log_scale=False):
        # Draw one figure with a line per available optimizer, then save it.
        fig, axes = plt.subplots()
        if log_scale:
            axes.set_yscale('log')
        axes.set_ylim(ylim)
        axes.set_xlim([0, epochs])
        for name in ordered:
            if name not in curves:
                continue
            values = curves[name]
            # Size the x axis from the curve itself so a run that recorded
            # fewer than `epochs` points can still be plotted.
            axes.plot(np.arange(1, len(values)+1), values, label=name,
                      ls='solid', linewidth=2, color=color_of[name])
        axes.set_ylabel(ylabel)
        axes.set_xlabel('Epochs')
        axes.grid()
        axes.legend()

        # Save before show(): some backends release the figure once shown.
        fig.savefig(file_prefix + figure_name + '.png')
        fig.savefig(file_prefix + figure_name + '.pdf', format='pdf', bbox_inches = 'tight')
        plt.show()
        plt.close(fig)

    _plot_curves(test_err, 'Test Error', [0., 0.2], 'Test_errors_')
    _plot_curves(train_loss, 'Training Loss', [1e-4, 1.], 'Training_loss_', log_scale=True)
    
    
    


In [13]:
def save_err(train_loss, train_err, test_err, model, learning_rate, batch_size, weight_decay, epochs, N):
    """Write the learning curves to a CSV file without overwriting older runs.

    The file is named
    ``Curves_<model>_LR<learning_rate>_B<batch_size>_N<N>_E<epochs>.csv`` in the
    current working directory. If that name is taken, a warning is printed once
    and the first free ``..._<i>.csv`` (i = 1, 2, ...) is used instead.
    Previously the search stopped at 29 and the results were silently
    discarded when every candidate was taken.

    Args:
        train_loss: dict mapping mode name to per-epoch training loss.
        train_err: dict mapping mode name to per-epoch training error.
        test_err: dict mapping mode name to per-epoch test error; its keys
            decide which modes are written.
        model: label used in the file name.
        learning_rate: used in the file name.
        batch_size: used in the file name.
        weight_decay: unused; kept for interface compatibility.
        epochs: used in the file name.
        N: used in the file name.

    Returns:
        None.
    """
    # Local import keeps this cell self-contained; os is part of the standard library.
    import os

    csvname = model + '_LR'+str(learning_rate) + '_B'+str(batch_size) + '_N'+ str(N) + '_E' + str(epochs)
    csvname = 'Curves_' + csvname

    # Three columns per mode: test error, training error, training loss.
    data_dict = {}
    for mode in test_err:
        data_dict[mode+'_test_err'] = test_err[mode]
        data_dict[mode+'_training_err'] = train_err[mode]
        data_dict[mode+'_training_loss'] = train_loss[mode]
    df = pd.DataFrame(data=data_dict)

    # lexists() checks the literal name; a glob pattern would misread names
    # that contain characters such as '[' or '*'.
    target = csvname + '.csv'
    if os.path.lexists(target):
        print('WARNING: This file already exists!')
        suffix = 1
        while os.path.lexists(csvname + '_' + str(suffix) + '.csv'):
            suffix += 1
        target = csvname + '_' + str(suffix) + '.csv'

    df.to_csv(target, sep=',', header=True, index=False)
    return None
    

In [14]:
# Mini-batch size used for both the training and the test DataLoader.
batch_size = 128
# Weight-decay coefficient handed to optimizers() (AdamW receives weight_decay / lr).
weight_decay = 5e-4
# Number of training epochs; the learning rate is scaled by 0.1 every 80 epochs.
epochs = 200

# Training-set size. In this notebook it is only used to tag output file names.
N = 50000 

In [15]:
# Training batches are reshuffled every epoch.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
# Evaluation makes one full pass and only counts correct predictions, so the
# order of the samples does not matter; shuffling the test set only costs time
# and draws extra random numbers between training epochs.
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

In [16]:
# Label for this experiment; it is only used as a prefix of the CSV and figure
# file names (the network itself is built by define_models()).
model = 'vit'

In [17]:
# Long-running cell: trains for `epochs` epochs with the optimizer selected by
# `mode` and returns three dicts (training loss, training error, test error)
# keyed by `mode`. The curves are also written to a CSV file by save_err().
train_loss, train_err, test_err = optimizer_performance_comparison(model, batch_size, weight_decay, epochs, N)

------------------------------ Mode SGDS starts
Epoch: 1


KeyboardInterrupt: 