In [1]:
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from torch import nn
import numpy as np
import torch
import os
import random
from tqdm import tqdm
from IPython import display
from models import *
import torchvision
import torchvision.transforms as transforms

In [2]:
# Dataset / loader settings used by both the training and the test split.
works_num = 4
download = True
batch_size = 128
dataset_path = '/data/'

# Per-channel mean and std handed to Normalize in both pipelines.
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: augmentation first, then tensor conversion and normalisation.
train_steps = [
    transforms.RandomCrop(32, padding=4),  # pad every side with zeros, then take a random 32x32 crop
    transforms.RandomHorizontalFlip(),  # flip horizontally half of the time, leave unchanged otherwise
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
transform_train = transforms.Compose(train_steps)

# Test pipeline: no augmentation, only tensor conversion and normalisation.
test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
transform_test = transforms.Compose(test_steps)

# Options common to both loaders; only `shuffle` differs between the splits.
loader_options = dict(batch_size=batch_size, num_workers=works_num, pin_memory=True)

trainset = torchvision.datasets.CIFAR10(
    root=dataset_path, train=True, download=download, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_options)

testset = torchvision.datasets.CIFAR10(
    root=dataset_path, train=False, download=download, transform=transform_test
)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_options)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

num_classes = 10

cuda


In [4]:
def get_accuracy(model, val_loader):
    """Print and return the top-1 accuracy of ``model`` over ``val_loader``.

    The model is switched to evaluation mode for the duration of the call and
    its previous training/eval mode is restored afterwards. Inputs are moved
    to the device of the model's first parameter, so the function does not
    depend on any module-level device setting.

    Args:
        model: ``torch.nn.Module`` mapping a batch of inputs to per-class scores
            along dimension 1.
        val_loader: iterable yielding ``(inputs, labels)`` batches.

    Returns:
        float: accuracy in percent (0-100), or ``nan`` when ``val_loader``
        yields no samples.
    """
    # Evaluate where the model lives; a parameter-less module leaves batches untouched.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    was_training = model.training
    model.eval()  # dropout / batch-norm must use inference behaviour while scoring
    size = 0
    correct = 0
    try:
        with torch.no_grad():
            for x, y in val_loader:
                if target_device is not None:
                    x = x.to(target_device)
                    y = y.to(target_device)
                y_pred = model(x).argmax(dim=1)
                size += y.size(0)
                correct += (y_pred == y).sum().item()
    finally:
        # Leave the model in whichever mode the caller had it in.
        model.train(was_training)

    if size == 0:
        # Avoid a ZeroDivisionError on an empty loader.
        print('Accuracy: n/a (no samples)')
        return float('nan')

    accuracy = 100 * correct / size
    print('Accuracy: %.2f %% ' % accuracy)
    return accuracy

def train(model, optimizer, criterion, train_loader, val_loader, scheduler=None, epochs_n=100):
    """Train ``model`` for ``epochs_n`` epochs and report validation accuracy once at the end.

    The model and every batch are moved to the module-level ``device``. After
    the last epoch the model is put in eval mode and ``get_accuracy`` is called
    on ``val_loader`` (which prints the accuracy).

    Args:
        model: network to optimise; moved to ``device`` in place.
        optimizer: optimizer built over ``model``'s parameters.
        criterion: callable ``criterion(prediction, target)`` returning the loss.
        train_loader: iterable yielding ``(inputs, labels)`` training batches.
        val_loader: iterable of batches passed to ``get_accuracy`` after training.
        scheduler: optional learning-rate scheduler whose ``step()`` takes no
            arguments; stepped once per epoch.
        epochs_n: number of passes over ``train_loader``.

    Returns:
        tuple: ``(losses_list, learning_curve)`` where ``losses_list`` holds one
        detached CPU loss tensor per optimisation step (across all epochs) and
        ``learning_curve`` holds the training accuracy (fraction in [0, 1]) of
        each epoch, or ``nan`` for an epoch in which no sample was seen.
    """
    model.to(device)
    learning_curve = [np.nan] * epochs_n
    losses_list = []

    for epoch in tqdm(range(epochs_n), unit='epoch'):
        model.train()

        correct = 0  # correctly classified training samples in this epoch
        seen = 0     # training samples processed in this epoch

        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            prediction = model(x)
            loss = criterion(prediction, y)
            # Detach before copying so the stored value carries no autograd history.
            losses_list.append(loss.detach().cpu())

            y_pred = prediction.detach().argmax(dim=1)
            correct += (y_pred == y).sum().item()
            seen += y.size(0)

            loss.backward()
            optimizer.step()

        # An empty loader leaves the epoch's entry as nan instead of dividing by zero.
        if seen:
            learning_curve[epoch] = correct / seen

        # Step the schedule only after this epoch's optimizer updates; stepping
        # before them skips the first value of the learning-rate schedule.
        if scheduler is not None:
            scheduler.step()

    model.eval()
    get_accuracy(model, val_loader)

    return losses_list, learning_curve

In [5]:
def plot_loss_landscape(iteration, VGG_A_max_curve, VGG_A_min_curve, VGG_A_BN_max_curve, VGG_A_BN_min_curve,
                        savepath="Loss_landscape_VGG_Cmp_BN.png"):
    """Draw the min/max loss bands of VGG with and without BatchNorm and save them as a PNG.

    Args:
        iteration: sequence of x coordinates (training steps); must support
            ``len()`` and ``[-1]``.
        VGG_A_max_curve: upper edge of the "Standard VGG" band, one value per entry of ``iteration``.
        VGG_A_min_curve: lower edge of the "Standard VGG" band.
        VGG_A_BN_max_curve: upper edge of the "Standard VGG + BatchNorm" band.
        VGG_A_BN_min_curve: lower edge of the "Standard VGG + BatchNorm" band.
        savepath: destination image file. Defaults to a file in the current
            working directory rather than the filesystem root.

    Returns:
        None. The figure is written to ``savepath`` and then closed.
    """
    # NOTE(review): a later cell defines a function with this same name; that
    # later definition replaces this one once its cell has run.
    bands = (
        (VGG_A_max_curve, VGG_A_min_curve, "green", "lightgreen", "Standard VGG"),
        (VGG_A_BN_max_curve, VGG_A_BN_min_curve, "firebrick", "lightcoral", "Standard VGG + BatchNorm"),
    )

    # Scope the ggplot style to this figure so the global rcParams are not
    # altered, and apply it before the figure exists so it affects every artist.
    with plt.style.context("ggplot"):
        fig, ax = plt.subplots()
        try:
            for upper, lower, edge_color, fill_color, label in bands:
                ax.plot(iteration, upper, c=edge_color)
                ax.plot(iteration, lower, c=edge_color)
                ax.fill_between(iteration, upper, lower, color=fill_color, label=label)

            # One tick every 1000 steps; skipped for empty input, where
            # ``iteration[-1]`` would raise IndexError.
            if len(iteration) > 0:
                ax.set_xticks(np.arange(0, iteration[-1], 1000))
            ax.set_xlabel("Steps")
            ax.set_ylabel("Loss")
            ax.set_title("Loss Landscape")
            ax.legend(loc="upper right", fontsize="x-large")
            fig.savefig(savepath, dpi=300)
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)


def plot_acc_curve(iteration, VGG_A_acc, VGG_A_norm_acc, savepath="Train_Acc_VGG_Cmp_BN.png"):
    """Plot training accuracy per epoch for VGG with and without BatchNorm and save it as a PNG.

    Args:
        iteration: sequence of x coordinates (epoch numbers); must support ``len()``.
        VGG_A_acc: accuracy values of the "Standard VGG" curve, one per entry of ``iteration``.
        VGG_A_norm_acc: accuracy values of the "Standard VGG + BatchNorm" curve.
        savepath: destination image file. Defaults to a file in the current
            working directory rather than the filesystem root.

    Returns:
        None. The figure is written to ``savepath`` and then closed.
    """
    # NOTE(review): a later cell defines a function with this same name; that
    # later definition replaces this one once its cell has run.

    # Scope the ggplot style to this figure so the global rcParams are not altered.
    with plt.style.context("ggplot"):
        fig, ax = plt.subplots()
        try:
            ax.plot(iteration, VGG_A_acc, c="green", label="Standard VGG")
            ax.plot(iteration, VGG_A_norm_acc, c="firebrick", label="Standard VGG + BatchNorm")

            # One tick per epoch, from 0 to one past the last plotted epoch
            # (0..21 for epochs 1..20) instead of a fixed 22-tick axis.
            if len(iteration) > 0:
                ax.set_xticks(range(0, int(max(iteration)) + 2))
            ax.set_xlabel("Epoch")
            ax.set_ylabel("Train Accuracy")
            ax.set_title("Accuracy Curve")
            ax.legend(loc="best", fontsize="x-large")
            fig.savefig(savepath, dpi=300)
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)

In [6]:
# Experiment: train VGG-A with and without BatchNorm at several learning
# rates, then plot the training-accuracy curves and the min/max loss band
# spanned by the different learning rates.
lrs = [2e-3, 1e-4, 5e-4]
epoch_num = 20
loss_plot_stride = 30  # keep every 30th training step when drawing the loss band

VGG_A_losses = []
VGG_A_BN_losses = []
VGG_A_acc = []
VGG_A_bn_acc = []

# VGG_A / VGG_A_BatchNorm are not defined in this notebook; presumably they
# come from `from models import *`.
for lr in lrs:
    for model_cls, loss_runs, acc_runs in (
        (VGG_A, VGG_A_losses, VGG_A_acc),
        (VGG_A_BatchNorm, VGG_A_BN_losses, VGG_A_bn_acc),
    ):
        model = model_cls()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        step_losses, epoch_acc = train(model, optimizer, criterion, trainloader, testloader, epochs_n=epoch_num)
        # Store plain floats so np.array() below builds a numeric
        # (runs x steps) array instead of converting tensors one by one.
        loss_runs.append([float(step_loss) for step_loss in step_losses])
        acc_runs.append(epoch_acc)

# Row i of each array belongs to lrs[i].
VGG_A_losses = np.array(VGG_A_losses)
VGG_A_BN_losses = np.array(VGG_A_BN_losses)
VGG_A_acc = np.array(VGG_A_acc)
VGG_A_bn_acc = np.array(VGG_A_bn_acc)

# Per-step envelope across learning rates.
VGG_A_min = VGG_A_losses.min(axis=0).astype(float)
VGG_A_max = VGG_A_losses.max(axis=0).astype(float)
VGG_A_BN_min = VGG_A_BN_losses.min(axis=0).astype(float)
VGG_A_BN_max = VGG_A_BN_losses.max(axis=0).astype(float)

# Sub-sample the envelope (steps 0, 30, 60, ...) to keep the figure readable.
iteration = list(range(0, len(VGG_A_min), loss_plot_stride))
VGG_A_min_curve = VGG_A_min[::loss_plot_stride].tolist()
VGG_A_max_curve = VGG_A_max[::loss_plot_stride].tolist()
VGG_A_BN_min_curve = VGG_A_BN_min[::loss_plot_stride].tolist()
VGG_A_BN_max_curve = VGG_A_BN_max[::loss_plot_stride].tolist()

# Only the runs at the first learning rate (lrs[0]) are drawn in the accuracy
# figure; the epoch axis follows epoch_num instead of a hard-coded 1..20.
plot_acc_curve(range(1, epoch_num + 1), VGG_A_acc[0], VGG_A_bn_acc[0])

plot_loss_landscape(iteration, VGG_A_max_curve,
                    VGG_A_min_curve, VGG_A_BN_max_curve,
                    VGG_A_BN_min_curve)

100%|██████████| 20/20 [03:48<00:00, 11.40s/epoch]


Accuracy: 75.98 % 


100%|██████████| 20/20 [02:42<00:00,  8.12s/epoch]


Accuracy: 86.40 % 


100%|██████████| 20/20 [02:46<00:00,  8.33s/epoch]


Accuracy: 81.71 % 


100%|██████████| 20/20 [02:36<00:00,  7.85s/epoch]


Accuracy: 81.77 % 


100%|██████████| 20/20 [02:34<00:00,  7.74s/epoch]


Accuracy: 83.90 % 


100%|██████████| 20/20 [02:23<00:00,  7.16s/epoch]


Accuracy: 87.14 % 


In [8]:
def plot_loss_landscape(iteration, VGG_A_max_curve, VGG_A_min_curve, VGG_A_BN_max_curve, VGG_A_BN_min_curve,
                        savepath="Loss_landscape_VGG_Cmp_BN.png"):
    """Draw the min/max loss bands of VGG with and without BatchNorm and save them as a PNG.

    Args:
        iteration: sequence of x coordinates (training steps); must support
            ``len()`` and ``[-1]``.
        VGG_A_max_curve: upper edge of the "Standard VGG" band, one value per entry of ``iteration``.
        VGG_A_min_curve: lower edge of the "Standard VGG" band.
        VGG_A_BN_max_curve: upper edge of the "Standard VGG + BatchNorm" band.
        VGG_A_BN_min_curve: lower edge of the "Standard VGG + BatchNorm" band.
        savepath: destination image file; defaults to a file in the current
            working directory.

    Returns:
        None. The figure is written to ``savepath`` and then closed.
    """
    # NOTE(review): an earlier cell defines a function with this same name;
    # this definition replaces it for the calls that follow.
    bands = (
        (VGG_A_max_curve, VGG_A_min_curve, "green", "lightgreen", "Standard VGG"),
        (VGG_A_BN_max_curve, VGG_A_BN_min_curve, "firebrick", "lightcoral", "Standard VGG + BatchNorm"),
    )

    # Scope the ggplot style to this figure so the global rcParams are not
    # altered, and apply it before the figure exists so it affects every artist.
    with plt.style.context("ggplot"):
        fig, ax = plt.subplots()
        try:
            for upper, lower, edge_color, fill_color, label in bands:
                ax.plot(iteration, upper, c=edge_color)
                ax.plot(iteration, lower, c=edge_color)
                ax.fill_between(iteration, upper, lower, color=fill_color, label=label)

            # One tick every 1000 steps; skipped for empty input, where
            # ``iteration[-1]`` would raise IndexError.
            if len(iteration) > 0:
                ax.set_xticks(np.arange(0, iteration[-1], 1000))
            ax.set_xlabel("Steps")
            ax.set_ylabel("Loss")
            ax.set_title("Loss Landscape")
            ax.legend(loc="upper right", fontsize="x-large")
            fig.savefig(savepath, dpi=300)
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)


def plot_acc_curve(iteration, VGG_A_acc, VGG_A_norm_acc, savepath="Train_Acc_VGG_Cmp_BN.png"):
    """Plot training accuracy per epoch for VGG with and without BatchNorm and save it as a PNG.

    Args:
        iteration: sequence of x coordinates (epoch numbers); must support ``len()``.
        VGG_A_acc: accuracy values of the "Standard VGG" curve, one per entry of ``iteration``.
        VGG_A_norm_acc: accuracy values of the "Standard VGG + BatchNorm" curve.
        savepath: destination image file; defaults to a file in the current
            working directory.

    Returns:
        None. The figure is written to ``savepath`` and then closed.
    """
    # NOTE(review): an earlier cell defines a function with this same name;
    # this definition replaces it for the calls that follow.

    # Scope the ggplot style to this figure so the global rcParams are not altered.
    with plt.style.context("ggplot"):
        fig, ax = plt.subplots()
        try:
            ax.plot(iteration, VGG_A_acc, c="green", label="Standard VGG")
            ax.plot(iteration, VGG_A_norm_acc, c="firebrick", label="Standard VGG + BatchNorm")

            # One tick per epoch, from 0 to one past the last plotted epoch
            # (0..21 for epochs 1..20) instead of a fixed 22-tick axis.
            if len(iteration) > 0:
                ax.set_xticks(range(0, int(max(iteration)) + 2))
            ax.set_xlabel("Epoch")
            ax.set_ylabel("Train Accuracy")
            ax.set_title("Accuracy Curve")
            ax.legend(loc="best", fontsize="x-large")
            fig.savefig(savepath, dpi=300)
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)

# Redraw both figures from the results computed by the training cell.
# The epoch axis is derived from the plotted curve itself (one point per
# epoch) instead of a hard-coded 1..20 range; only the runs at the first
# learning rate (index 0) are shown.
acc_epochs = range(1, len(VGG_A_acc[0]) + 1)
plot_acc_curve(acc_epochs, VGG_A_acc[0], VGG_A_bn_acc[0])

plot_loss_landscape(
    iteration,
    VGG_A_max_curve,
    VGG_A_min_curve,
    VGG_A_BN_max_curve,
    VGG_A_BN_min_curve,
)