In [1]:
import os
import time
import copy
import argparse
import numpy as np
import torch
from fvcore.nn import FlopCountAnalysis
import torch.nn as nn
from torchvision.utils import save_image
from utils import get_loops, get_dataset, get_network, get_eval_pool, evaluate_synset, get_daparam, match_loss, get_time, TensorDataset, epoch, DiffAugment, ParamDiffAug
# Compute device used throughout the notebook: the GPU when CUDA is available,
# otherwise the CPU. Read as a module-level global by train().
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
def train(model, train_loader, test_loader, optimizer, scheduler, criterion, epochs):
    """Train ``model`` for ``epochs`` epochs, then evaluate it once on ``test_loader``.

    The model and criterion are moved to the module-level ``device``. Each batch
    yielded by a loader is indexed as ``datum[0]`` (inputs, cast to float) and
    ``datum[1]`` (labels, cast to long). ``scheduler.step()`` is called once per
    epoch, after all batches of that epoch.

    Args:
        model: network to optimise; left in eval mode on return.
        train_loader: iterable of training batches.
        test_loader: iterable of evaluation batches.
        optimizer: optimizer bound to ``model``'s parameters.
        scheduler: learning-rate scheduler stepped once per epoch.
        criterion: loss module called as ``criterion(output, labels)``.
        epochs: number of passes over ``train_loader``.

    Returns:
        ``(loss_avg, acc_avg)``: the sample-weighted mean loss and the fraction
        of correct top-1 predictions over ``test_loader``.

    Raises:
        ZeroDivisionError: if ``test_loader`` yields no samples.
    """
    model = model.to(device)
    criterion = criterion.to(device)

    # Training phase. No running statistics are kept here: they were never
    # reported, and computing them forced a device-to-host copy on every batch.
    model.train()
    for _ in range(epochs):
        for datum in train_loader:
            img = datum[0].float().to(device)
            lab = datum[1].long().to(device)
            loss = criterion(model(img), lab)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()

    # Evaluation phase. Gradients are disabled so no autograd graph is built
    # for the test batches.
    loss_sum, num_correct, num_exp = 0.0, 0, 0
    model.eval()
    with torch.no_grad():
        for datum in test_loader:
            img = datum[0].float().to(device)
            lab = datum[1].long().to(device)
            n_b = lab.shape[0]
            output = model(img)
            # Weight the batch loss by batch size so the final mean is per-sample.
            loss_sum += criterion(output, lab).item() * n_b
            num_correct += (output.argmax(dim=-1) == lab).sum().item()
            num_exp += n_b

    loss_avg = loss_sum / num_exp
    acc_avg = num_correct / num_exp

    return loss_avg, acc_avg

# Baseline benchmark: train each architecture from a fresh random
# initialisation on the original CIFAR10 training set, several times, and
# report per-run and average wall-clock time and test accuracy.
CIFAR10_dataset = 'CIFAR10'
CIFAR10_data_path = './CIFAR10data'
# get_dataset comes from the project-local utils module; the meaning of each
# returned value is presumed from the names it is unpacked into here.
(CIFAR10_channel, CIFAR10_im_size, CIFAR10_num_classes, CIFAR10_class_names,
 CIFAR10_mean, CIFAR10_std, CIFAR10_dst_train, CIFAR10_dst_test,
 CIFAR10_testloader) = get_dataset(CIFAR10_dataset, CIFAR10_data_path)
CIFAR10_trainloader = torch.utils.data.DataLoader(CIFAR10_dst_train, batch_size=8, shuffle=True, num_workers=0)
criterion = nn.CrossEntropyLoss()
epochs = 20
model_set = ['LeNet', 'AlexNet', 'VGG11', 'ConvNetD4']
iteration_set = ['first', 'second', 'third']

for model_architecture in model_set:
    print('========================================')
    accs_each_model = []
    training_times_each_model = []
    for itr in iteration_set:
        print('Start training on the '+itr+' '+model_architecture+' architecture on the original CIFAR10 dataset')
        # A new, randomly initialised model per repetition.
        model = get_network(model_architecture, CIFAR10_channel, CIFAR10_num_classes, CIFAR10_im_size).to(device)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        # Tie the cosine schedule length to the epoch count so the two cannot
        # drift apart when `epochs` is changed.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)
        # perf_counter is monotonic and intended for measuring elapsed intervals.
        # The measured span also covers the evaluation pass that train()
        # performs before returning.
        start = time.perf_counter()
        _, acc_avg = train(model, CIFAR10_trainloader, CIFAR10_testloader, optimizer, scheduler, criterion, epochs)
        elapsed_time = time.perf_counter() - start
        training_times_each_model.append(elapsed_time)
        accs_each_model.append(acc_avg*100)
        print('Training takes {} seconds'.format(training_times_each_model[-1]))
        print('Test accuracy is {}%'.format(accs_each_model[-1]))
    print('----------------------------------------')
    print('Average training time is {} seconds'.format(sum(training_times_each_model)/len(training_times_each_model)))
    print('Average test accuracy is {}%'.format(sum(accs_each_model)/len(accs_each_model)))


Start training on the first LeNet architecture on the original CIFAR10 dataset
Training takes 577.0727677345276 seconds
Test accuracy is 64.53%
Start training on the second LeNet architecture on the original CIFAR10 dataset
Training takes 552.0183479006176 seconds
Test accuracy is 62.23%
Start training on the third LeNet architecture on the original CIFAR10 dataset
Training takes 581.0977817340001 seconds
Test accuracy is 65.77%
----------------------------------------
Average training time is 570.0629657897151 seconds
Average test accuracy is 64.17666666666666%
Start training on the first AlexNet architecture on the original CIFAR10 dataset
Training takes 759.9593343734741 seconds
Test accuracy is 81.8%
Start training on the second AlexNet architecture on the original CIFAR10 dataset
Training takes 751.0998243099123 seconds
Test accuracy is 80.10000000000001%
Start training on the third AlexNet architecture on the original CIFAR10 dataset
Training takes 767.0002123732212 seconds
Test 