In [58]:
from resnet import ResNet_  as Net
from mixup import *
from cutout  import Cutout
import numpy as np
import argparse
import os, sys
import time
import datetime
from copy import deepcopy
# Import pytorch dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR

def get_num_correct(pred,labels):
    return pred.argmax(dim=1).eq(labels).sum().item()
def init_weights(m):
    if type(m)==nn.Linear or type(m)==nn.Conv2d:
        torch.nn.init.xavier_uniform_(m.weight)
        
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hpyerparameters

In [28]:
# normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
#                                      std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

# train_transform = transforms.Compose([])
# if flag_augmetation:
#     train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
#     train_transform.transforms.append(transforms.RandomHorizontalFlip())
# train_transform.transforms.append(transforms.ToTensor())
# train_transform.transforms.append(normalize)
# if flag_cutout:
#     train_transform.transforms.append(Cutout(n_holes = n_holes, length = length))

# train_set=torchvision.datasets.CIFAR10(
#     root='./data/cifar10',
#     train=True,
#     download=True,
#     transform=train_transform
# )

# test_set=torchvision.datasets.CIFAR10(
#     root='./data/cifar10',
#     train=False,
#     download=True,
#     transform=transforms.Compose([transforms.ToTensor(), normalize])
# )

# GPU check                
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# if device =='cuda':
#     print("Run on GPU...")
# else:
#     print("Run on CPU...")

# # Model Definition  
# net = Net(18)
# net = net.to(device)

# # Test forward pass
# data = torch.randn(5,3,32,32)
# data = data.to(device)
# # Forward pass "data" through "net" to get output "out" 
# out =  net(data) #Your code here
# # Check output shape
# assert(out.detach().cpu().numpy().shape == (5,10))
# print("Forward pass successful")

In [77]:
def train_(train_set,test_set,lr, depth, mixup_enbale, alpha, model_checkpoint,epochs):

    torch.manual_seed(1)
    train_loader=torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=False, pin_memory=True,num_workers=2)
    test_loader=torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False, pin_memory=True,num_workers=2)
    network= Net(depth).to(device)
    optimizer = optim.SGD(network.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss().to(device)
#     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 80], gamma=0.2)
    
    acc_train=[]
    acc_test=[]
    acc = 0
    best_acc = 0
    for epoch in range(epochs):
        total_loss = 0
        total_correct = 0
        network.train()
        count_in = 0
        for batch in train_loader: #Get batch
            images,labels = batch
            images, labels = images.to(device), labels.to(device)
            
            if mixup_enbale:
                images, targets_a, targets_b, lam = mixup_data(images, labels, alpha)
                images, targets_a, targets_b = map(Variable, (images,
                                                          targets_a, targets_b))
                preds = network(images)
                loss = mixup_criterion(criterion, preds, targets_a, targets_b, lam)

                _, predicted = torch.max(preds.data, 1)
                correct = (lam * predicted.eq(targets_a.data).cpu().sum().float()
                + (1 - lam) * predicted.eq(targets_b.data).cpu().sum().float())
                total_correct += correct
            
            if not mixup_enbale:
                preds=network(images) #pass batch to network
                correct = get_num_correct(preds, labels)
                loss = criterion(preds,labels) #Calculate loss
                total_correct+=correct
            
            optimizer.zero_grad()
            loss.backward() #Calculate gradients
            optimizer.step() #Update weights
            
            
        print("epoch: ", epoch,  "total_correct: ", total_correct.item() )
        print("training accuracy: ", total_correct.item() /len(train_set))
        acc_train.append(deepcopy(float(total_correct)/len(train_set)))

        with torch.no_grad():
            correct_test=0
            for batch_test in test_loader: #Get batch
                
                images_test,labels_test = batch_test
                images_test, labels_test = images_test.to(device), labels_test.to(device)
                preds_test=network(images_test) #pass batch to network
                correct_test += get_num_correct(preds_test, labels_test)
                
            print("testing accuracy: ", correct_test / len(test_set))
            if epoch == epochs - 1:
                print(correct_test / len(test_set))
                acc = correct_test / len(test_set) 
            acc_test.append(deepcopy(float(correct_test)/len(test_set)))
        scheduler.step()
        if best_acc < acc:
            best_acc = acc
            torch.save(network.state_dict(), model_checkpoint)

    return (acc_train,acc_test)

In [78]:
def do_test(flag_augmetation = False, 
            flag_cutout = False, 
            n_holes = 1, 
            length = 16, 
            depth = 18,
            epochs = 100,
            lr = 0.1,
            mixup_enbale = True,
            alpha = 0.1
           ):
    model_checkpoint = "resnet" + str(depth) 
    if flag_augmetation:
        model_checkpoint += '+'
    if flag_cutout:
        model_checkpoint += "cutout"
    model_checkpoint += ".pt"
    
    normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

    train_transform = transforms.Compose([])
    if flag_augmetation:
        train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
        train_transform.transforms.append(transforms.RandomHorizontalFlip())
    train_transform.transforms.append(transforms.ToTensor())
    train_transform.transforms.append(normalize)
    if flag_cutout:
        train_transform.transforms.append(Cutout(n_holes = n_holes, length = length))


    train_set=torchvision.datasets.CIFAR10(
        root='./data/cifar10',
        train=True,
        download=True,
        transform=train_transform)

    test_set=torchvision.datasets.CIFAR10(
        root='./data/cifar10',
        train=False,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), normalize]))
    
    acc_train,acc_test = train_(train_set,test_set,lr, depth,mixup_enbale,alpha,  model_checkpoint, epochs = epochs)
    return (acc_train,acc_test)

In [79]:
do_test(flag_augmetation = True, 
            flag_cutout = False, 
            n_holes = 1, 
            length = 16, 
            depth = 18,
            epochs = 100,
            lr = 0.1,
            mixup_enbale = True,
            alpha = 0.1
       )

Files already downloaded and verified
Files already downloaded and verified
epoch:  0 total_correct:  12626.9833984375
training accuracy:  0.25253966796875
testing accuracy:  0.4028
epoch:  1 total_correct:  21658.01953125
training accuracy:  0.433160390625
testing accuracy:  0.5135
epoch:  2 total_correct:  26711.2421875
training accuracy:  0.53422484375
testing accuracy:  0.5787
epoch:  3 total_correct:  30958.033203125
training accuracy:  0.6191606640625
testing accuracy:  0.6586
epoch:  4 total_correct:  33946.84375
training accuracy:  0.678936875
testing accuracy:  0.7042
epoch:  5 total_correct:  36029.87109375
training accuracy:  0.720597421875
testing accuracy:  0.7277
epoch:  6 total_correct:  36715.703125
training accuracy:  0.7343140625
testing accuracy:  0.7622
epoch:  7 total_correct:  37169.296875
training accuracy:  0.7433859375
testing accuracy:  0.7639
epoch:  8 total_correct:  38402.42578125
training accuracy:  0.768048515625
testing accuracy:  0.7822
epoch:  9 total_

([0.25253966796875,
  0.433160390625,
  0.53422484375,
  0.6191606640625,
  0.678936875,
  0.720597421875,
  0.7343140625,
  0.7433859375,
  0.768048515625,
  0.77224765625,
  0.79122046875,
  0.782665,
  0.7869884375,
  0.79330890625,
  0.7971996875,
  0.79143421875,
  0.8075765625,
  0.81745625,
  0.82345359375,
  0.819845078125,
  0.820336484375,
  0.817158984375,
  0.82881046875,
  0.824468046875,
  0.819223046875,
  0.82267203125,
  0.823955,
  0.8253634375,
  0.825641796875,
  0.825870078125,
  0.87665640625,
  0.891415078125,
  0.9072096875,
  0.9130646875,
  0.89807046875,
  0.903165078125,
  0.9064625,
  0.917678515625,
  0.91836125,
  0.91023015625,
  0.901338046875,
  0.895678046875,
  0.910756796875,
  0.898656875,
  0.910734140625,
  0.913415859375,
  0.9131134375,
  0.904106484375,
  0.896805859375,
  0.906594609375,
  0.909820703125,
  0.909519765625,
  0.90398390625,
  0.90583328125,
  0.914863203125,
  0.913410078125,
  0.905073046875,
  0.90004765625,
  0.914468828125

In [55]:
list_acc = []
# for depth in [18,34]:
#     for flag_cutout in [False, True]:
for depth in [18,34,50,101]:
    acc_train,acc_test  =do_test(flag_augmetation = True, 
    flag_cutout = True, 
    n_holes = 1, 
    length = 16, 
    depth = 18,
    epochs = 100,
    lr = 0.1,
    mixup_enbale = False,
    alpha = 0.1)
    list_acc.append(acc_train)
    list_acc.append(acc_test)
        
        list_acc.append(acc_train)
        list_acc.append(acc_test)
print(list_acc)

Files already downloaded and verified
Files already downloaded and verified
epoch:  0 total_correct:  18156
training accuracy:  0.36312
testing accuracy:  0.4859
epoch:  1 total_correct:  28414
training accuracy:  0.56828
testing accuracy:  0.632
0.632
Files already downloaded and verified
Files already downloaded and verified
epoch:  0 total_correct:  14137
training accuracy:  0.28274
testing accuracy:  0.3772
epoch:  1 total_correct:  20721
training accuracy:  0.41442
testing accuracy:  0.4675
0.4675
Files already downloaded and verified
Files already downloaded and verified
epoch:  0 total_correct:  14776
training accuracy:  0.29552
testing accuracy:  0.424
epoch:  1 total_correct:  24451
training accuracy:  0.48902
testing accuracy:  0.5479
0.5479
Files already downloaded and verified
Files already downloaded and verified
epoch:  0 total_correct:  11519
training accuracy:  0.23038
testing accuracy:  0.3247
epoch:  1 total_correct:  17293
training accuracy:  0.34586
testing accuracy

In [57]:
import pandas as pd
list_acc = pd.DataFrame(list_acc)
list_acc.to_csv("acc.csv",index = False)

In [None]:
train_(train_set,test_set,18,100)

epoch:  0 total_correct:  19394
training accuracy:  0.38788
testing accuracy:  0.5088
epoch:  1 total_correct:  29893
training accuracy:  0.59786
testing accuracy:  0.6038
epoch:  2 total_correct:  34235
training accuracy:  0.6847
testing accuracy:  0.6631
epoch:  3 total_correct:  37032
training accuracy:  0.74064
testing accuracy:  0.7209
epoch:  4 total_correct:  39145
training accuracy:  0.7829
testing accuracy:  0.7548
epoch:  5 total_correct:  40576
training accuracy:  0.81152
testing accuracy:  0.7847
epoch:  6 total_correct:  41752
training accuracy:  0.83504
testing accuracy:  0.7948
epoch:  7 total_correct:  42569
training accuracy:  0.85138
testing accuracy:  0.8134
epoch:  8 total_correct:  43238
training accuracy:  0.86476
testing accuracy:  0.8255
epoch:  9 total_correct:  43809
training accuracy:  0.87618
testing accuracy:  0.8327
epoch:  10 total_correct:  44323
training accuracy:  0.88646
testing accuracy:  0.8356
epoch:  11 total_correct:  44845
training accuracy:  0.