In [1]:
import os
import copy
import math
import tqdm
import torch
import pickle
import random
import datetime

import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.distributions.beta as beta
import torchvision.transforms.functional as TF

from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter
from torch.nn.functional import relu, avg_pool2d

In [2]:
seed = 0

In [3]:
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [5]:
'''Data Preparation Arguments'''
data_prep_from_scratch = True
dataset = 'cifar10'
path_dataset = './Data/CIFAR10/'

'''Data Augmentation Method'''
method = 'vhmixup'

'''Optimization Arguments'''
batch_size = 128
train_epochs = 225

lr = 0.01
momentum = 0.9
weight_decay = 5e-4
# Increase LR from 0.01 to 1.
lr_scheduler_1_gamma = 10.0; milestones_1 = [400]
# Lower LR from 1. to 0.01
lr_scheduler_2_gamma = 0.1; milestones_2 = [32000, 48000, 70000]

# Preprocess CIFAR

In [6]:
def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

In [7]:
if data_prep_from_scratch and dataset == 'cifar100':
    cifar100_train = unpickle(os.path.join(path_dataset, 'train'))
    cifar100_test = unpickle(os.path.join(path_dataset, 'test'))

    x_tr = torch.from_numpy(cifar100_train[b'data'].reshape((-1,32,32,3), order='F')).permute(0,2,1,3)
    y_tr = torch.LongTensor(cifar100_train[b'fine_labels'])
    x_te = torch.from_numpy(cifar100_test[b'data'].reshape((-1,32,32,3), order='F')).permute(0,2,1,3)
    y_te = torch.LongTensor(cifar100_test[b'fine_labels'])

    torch.save((x_tr, y_tr, x_te, y_te), os.path.join(path_dataset, '{}.pt'.format(dataset)))

In [8]:
if data_prep_from_scratch and dataset == 'cifar10':
    x_tr, y_tr = None, None
    for b in range(5):
        cifar10_train = unpickle(os.path.join(path_dataset, 'data_batch_{}'.format(b+1)))
        
        batch_img = torch.from_numpy(cifar10_train[b'data'].reshape((-1,32,32,3), order='F')).permute(0,2,1,3)
        batch_label = torch.LongTensor(cifar10_train[b'labels'])
        
        if x_tr is None:
            x_tr = batch_img
            y_tr = batch_label
        else:
            x_tr = torch.cat((x_tr, batch_img), dim=0)
            y_tr = torch.cat((y_tr, batch_label), dim=0)
    
    cifar10_test = unpickle(os.path.join(path_dataset, 'test_batch'))
    x_te = torch.from_numpy(cifar10_test[b'data'].reshape((-1,32,32,3), order='F')).permute(0,2,1,3)
    y_te = torch.LongTensor(cifar10_test[b'labels'])
    
    torch.save((x_tr, y_tr, x_te, y_te), os.path.join(path_dataset, '{}.pt'.format(dataset)))

# Load Data

In [9]:
def load_datasets(path):
    d = torch.load(path)
    d_tr = (d[0], d[1])
    d_te = (d[2], d[3])
    if dataset == 'cifar100':
        n_outputs = 100
    else:
        n_outputs = 10        
    return d_tr, d_te, n_outputs

In [10]:
d_tr, d_te, n_outputs = load_datasets(os.path.join(path_dataset, '{}.pt'.format(dataset)))

# Dataloader

In [11]:
class CIFAR(torch.utils.data.Dataset):
    def __init__(self, pack, method, train=False):
        self.x = pack[0]
        self.y = pack[1]
        self.img_size = (3,32,32)
        
        self.method = method
        self.train = train
    
    def __len__(self):
        return len(self.x)
    
    def transform(self, img):
        top = torch.randint(0,8,(1,))
        left = torch.randint(0,8,(1,))
        img = TF.crop(img, top=top, left=left, height=self.img_size[1], width=self.img_size[2])
        
        if torch.rand(1) > 0.5:
            img = TF.vflip(img)
            
        return img
    
    def __getitem__(self, item):
        x = self.x[item].float()
        
        if 'bcplus' not in self.method:
            mean_image = torch.from_numpy(np.array([0.4914, 0.4822, 0.4465])).float()
            std_image = torch.from_numpy(np.array([0.2023, 0.1994, 0.2010])).float()
        else:
            x = x - torch.mean(x)
            mean_image = torch.from_numpy(np.array([0.21921569, 0.21058824, 0.22156863])).float()
            std_image = torch.from_numpy(np.array([0.2023, 0.1994, 0.2010])).float()
            
        x = x / 255.0 - mean_image
        x = x / std_image
        
        x = x.permute(2,0,1)
        
        if self.train:
            x = TF.pad(x, padding=4)
            x = self.transform(x)
            
        return x, self.y[item]

In [12]:
train_datasets = CIFAR(d_tr, method, train=True)
train_dataloaders = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size, shuffle=True, drop_last=True)

In [13]:
train_iterator = iter(train_dataloaders)

In [14]:
test_datasets = CIFAR(d_te, method)
test_dataloaders = torch.utils.data.DataLoader(test_datasets, batch_size=batch_size, shuffle=False)

In [15]:
print(len(train_datasets), len(test_datasets))

50000 10000


In [16]:
print(len(train_dataloaders), len(test_dataloaders))

390 79


# Augmentations

In [17]:
def verticalConcat(pair_1, pair_2):
    lambda_vertical_beta = beta.Beta(torch.tensor([1.]), torch.tensor([1.]))
    lambda_vertical = lambda_vertical_beta.sample()
    
    img_1, label_1 = pair_1[0], pair_1[1]
    img_2, label_2 = pair_2[0], pair_2[1]
    
    h, w = img_1.shape[1], img_1.shape[2]
    
    if len(label_1.shape) == 0:
        label_1 = torch.zeros(n_outputs)
        label_1[pair_1[1]] = 1
        
    if len(label_2.shape) == 0:
        label_2 = torch.zeros(n_outputs)
        label_2[pair_2[1]] = 1
    
    vertical_concat = torch.zeros(img_1.shape)
    vertical_label = torch.zeros(n_outputs)

    vertical_concat[:,:int(lambda_vertical*h),:] = img_1[:,:int(lambda_vertical*h),:]
    vertical_concat[:,int(lambda_vertical*h):,:] = img_2[:,int(lambda_vertical*h):,:]
    
    vertical_label = int(lambda_vertical*h) / h * label_1 + (h - int(lambda_vertical*h)) / h * label_2
    
    return vertical_concat, vertical_label

In [18]:
def horizontalConcat(pair_1, pair_2):
    lambda_horizontal_beta = beta.Beta(torch.tensor([1.]), torch.tensor([1.]))
    lambda_horizontal = lambda_horizontal_beta.sample()
    
    img_1, label_1 = pair_1[0], pair_1[1]
    img_2, label_2 = pair_2[0], pair_2[1]
    
    h, w = img_1.shape[1], img_1.shape[2]
    
    if len(label_1.shape) == 0:
        label_1 = torch.zeros(n_outputs)
        label_1[pair_1[1]] = 1
        
    if len(label_2.shape) == 0:
        label_2 = torch.zeros(n_outputs)
        label_2[pair_2[1]] = 1
    
    horizontal_concat = torch.zeros(img_1.shape)
    horizontal_label = torch.zeros(n_outputs)
    
    horizontal_concat[:,:,:int(lambda_horizontal*w)] = img_1[:,:,:int(lambda_horizontal*w)]
    horizontal_concat[:,:,int(lambda_horizontal*w):] = img_2[:,:,int(lambda_horizontal*w):]
    
    horizontal_label = int(lambda_horizontal*w) / w * label_1 + (w - int(lambda_horizontal*w)) / w * label_2
    
    return horizontal_concat, horizontal_label

In [19]:
def VHMixup(pair_1, pair_2):
    lambda_mixup_beta = beta.Beta(torch.tensor([1.]), torch.tensor([1.]))
    lambda_mixup = lambda_mixup_beta.sample()
    
    vertical_concat, vertical_label = verticalConcat(pair_1, pair_2)
    horizontal_concat, horizontal_label = horizontalConcat(pair_1, pair_2)
    
    mixed_img = lambda_mixup * vertical_concat + (1 - lambda_mixup) * horizontal_concat
    mixed_label = lambda_mixup * vertical_label + (1 - lambda_mixup) * horizontal_label
    
    return mixed_img, mixed_label

In [20]:
def VHBCplus(pair_1, pair_2):
    vertical_concat, vertical_label = verticalConcat(pair_1, pair_2)
    horizontal_concat, horizontal_label = horizontalConcat(pair_1, pair_2)
    
    lambda_uni = torch.rand(1)
    lambda_factor = (1 - lambda_uni) / lambda_uni
    
    vertical_std = torch.std(vertical_concat)
    horizontal_std = torch.std(horizontal_concat)
    std_factor = vertical_std / horizontal_std

    p = 1 / (1 + std_factor * lambda_factor)
    
    denom = torch.sqrt(p**2 + (1-p)**2)
    
    bcplus_img = (p * vertical_concat + (1 - p) * horizontal_concat) / denom
    bcplus_label = lambda_uni * vertical_label + (1 - lambda_uni) * horizontal_label
    
    return bcplus_img, bcplus_label

# Model

In [21]:
def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes, nf):
        super(ResNet, self).__init__()
        self.in_planes = nf

        self.conv1 = conv3x3(3, nf * 1)
        self.bn1 = nn.BatchNorm2d(nf * 1)
        self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2)
        self.linear = nn.Linear(nf * 8 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        bsz = x.size(0)
        out = relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [22]:
def ResNet18(nclasses, nf=20):
    return ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf)

# Train

### Log Files

In [26]:
ROOT_DIR = './Results/'
now =  '{}_ResNet_{}_{}'.format(dataset, method, seed)

if not os.path.exists(ROOT_DIR):
    os.makedirs(ROOT_DIR)

if not os.path.exists(ROOT_DIR + now):
    os.makedirs(ROOT_DIR + now)

LOG_DIR = ROOT_DIR + now + '/logs/'
if not os.path.exists(LOG_DIR):
    os.makedirs(LOG_DIR)
else:
    import shutil
    shutil.rmtree(LOG_DIR)
    os.makedirs(LOG_DIR)
    
MODEL_DIR = ROOT_DIR + now + '/models/'
if not os.path.exists(MODEL_DIR):
    os.makedirs(MODEL_DIR)

summary_writer = SummaryWriter(LOG_DIR)

In [27]:
def augmentBatch(pair_1, pair_2, augmentor):
    img_1, label_1 = pair_1[0], pair_1[1]
    img_2, label_2 = pair_2[0], pair_2[1]
    
    augment_batch = torch.zeros(img_1.shape)
    augment_labels = torch.zeros(label_1.shape[0],n_outputs)
    
    for b in range(img_1.shape[0]):
        p1 = img_1[b], label_1[b]
        p2 = img_2[b], label_2[b]
        
        augment_batch[b], augment_labels[b] = augmentor(p1, p2)
        
    return augment_batch, augment_labels

In [28]:
model = ResNet18(n_outputs).to(device)

In [29]:
if method == 'vhmixup':
    augmentor = VHMixup
else:
    augmentor = VHBCplus

### Optimizer and Schedulers

In [30]:
opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
scheduler_1 = torch.optim.lr_scheduler.MultiStepLR(opt, gamma=lr_scheduler_1_gamma, milestones=milestones_1)
scheduler_2 = torch.optim.lr_scheduler.MultiStepLR(opt, gamma=lr_scheduler_2_gamma, milestones=milestones_2)

In [31]:
criterion = nn.CrossEntropyLoss().to(device)

In [32]:
def softXEnt(output, target):
    logprobs = torch.nn.functional.log_softmax(output, dim = 1)
    return  -(target * logprobs).sum() / output.shape[0]

### Training

In [None]:
for epoch in range(train_epochs+1):
    model.train()
    for i, d in enumerate(tqdm.tqdm(train_dataloaders)):
        try:
            d1 = next(train_iterator)
        except StopIteration:
            train_iterator = iter(train_dataloaders)
            d1 = next(train_iterator)
        
        try:
            d2 = next(train_iterator)
        except StopIteration:
            train_iterator = iter(train_dataloaders)
            d2 = next(train_iterator)
            
        x1, y1 = d1[0], d1[1]
        x2, y2 = d2[0], d2[1]
        
        x, y = augmentBatch((x1,y1), (x2,y2), augmentor)
        x = x.float().to(device)
        y = y.float().to(device)
        
        opt.zero_grad()
        
        out = model(x)
        
        loss = softXEnt(out, y)
        loss.backward()
        
        opt.step()
        
        summary_writer.add_scalar('Loss', loss.item())
        
        # Scheduler is defined based on total number of iterations
        scheduler_1.step()
        scheduler_2.step()
        
    model.eval()
    total_acc = 0
    for i, d in enumerate(test_dataloaders):
        x, y = d[0], d[1]
        x = x.float().to(device)
        y = y.long().to(device)
        
        with torch.no_grad():
            out_prob = model(x)
        
        pred = torch.argmax(out_prob, dim=1)
        prediction = pred.cpu().numpy()
        truth = y.cpu().numpy()
        acc = np.count_nonzero(prediction == truth)

        total_acc += acc
        
    print('After the accuracy after {} epochs is {}'.format(epoch, total_acc / len(test_datasets)))
    print()

    summary_writer.add_scalar('Eval ACC', total_acc / len(test_datasets))
    
    torch.save(model.state_dict(), MODEL_DIR+'epoch_{}.pth'.format(epoch))

100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:24<00:00,  4.63it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 0 epochs is 0.3967



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:24<00:00,  4.62it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 1 epochs is 0.4963



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:30<00:00,  4.33it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 2 epochs is 0.5347



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:33<00:00,  4.17it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 3 epochs is 0.5925



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:39<00:00,  3.91it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 4 epochs is 0.6145



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:45<00:00,  3.70it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 5 epochs is 0.5724



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:47<00:00,  3.64it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 6 epochs is 0.6345



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:38<00:00,  3.95it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 7 epochs is 0.6836



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:45<00:00,  3.68it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 8 epochs is 0.6194



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:44<00:00,  3.74it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 9 epochs is 0.6365



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:44<00:00,  3.74it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 10 epochs is 0.6566



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:48<00:00,  3.58it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 11 epochs is 0.7052



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:40<00:00,  3.89it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 12 epochs is 0.6431



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:39<00:00,  3.91it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 13 epochs is 0.6249



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:36<00:00,  4.05it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 14 epochs is 0.5986



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:35<00:00,  4.08it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 15 epochs is 0.7063



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:35<00:00,  4.07it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 16 epochs is 0.5696



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:40<00:00,  3.87it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 17 epochs is 0.6623



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:35<00:00,  4.08it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 18 epochs is 0.7208



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:41<00:00,  3.85it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 19 epochs is 0.6738



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:41<00:00,  3.83it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 20 epochs is 0.6217



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:42<00:00,  3.79it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 21 epochs is 0.7066



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:42<00:00,  3.79it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 22 epochs is 0.6773



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:44<00:00,  3.73it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 23 epochs is 0.6395



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:43<00:00,  3.78it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 24 epochs is 0.6919



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:41<00:00,  3.84it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 25 epochs is 0.6896



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:46<00:00,  3.67it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 26 epochs is 0.7284



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:43<00:00,  3.76it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 27 epochs is 0.6737



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:41<00:00,  3.85it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 28 epochs is 0.73



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:45<00:00,  3.69it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 29 epochs is 0.6972



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:41<00:00,  3.85it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 30 epochs is 0.6859



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:43<00:00,  3.79it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 31 epochs is 0.6966



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:42<00:00,  3.80it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 32 epochs is 0.7462



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:37<00:00,  3.99it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 33 epochs is 0.5582



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:44<00:00,  3.74it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 34 epochs is 0.6914



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:40<00:00,  3.88it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 35 epochs is 0.7133



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:43<00:00,  3.76it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 36 epochs is 0.7253



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:35<00:00,  4.09it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 37 epochs is 0.6977



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:35<00:00,  4.08it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 38 epochs is 0.7565



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:35<00:00,  4.08it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 39 epochs is 0.6915



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:38<00:00,  3.97it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 40 epochs is 0.7317



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:35<00:00,  4.07it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 41 epochs is 0.676



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:35<00:00,  4.08it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 42 epochs is 0.7142



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:35<00:00,  4.07it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 43 epochs is 0.6481



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:35<00:00,  4.07it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 44 epochs is 0.6953



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:35<00:00,  4.07it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 45 epochs is 0.6938



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:35<00:00,  4.07it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 46 epochs is 0.7473



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:35<00:00,  4.07it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 47 epochs is 0.7148



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:34<00:00,  4.12it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 48 epochs is 0.7016



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:33<00:00,  4.17it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 49 epochs is 0.6911



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:33<00:00,  4.17it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 50 epochs is 0.6227



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:33<00:00,  4.18it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 51 epochs is 0.7452



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:40<00:00,  3.87it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 52 epochs is 0.7137



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:35<00:00,  4.08it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 53 epochs is 0.6259



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:38<00:00,  3.97it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 54 epochs is 0.7108



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:37<00:00,  4.00it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 55 epochs is 0.6649



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:36<00:00,  4.02it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 56 epochs is 0.6724



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:36<00:00,  4.02it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 57 epochs is 0.6527



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:36<00:00,  4.03it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 58 epochs is 0.7462



100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [01:37<00:00,  3.99it/s]
  0%|                                                                                          | 0/390 [00:00<?, ?it/s]

After the accuracy after 59 epochs is 0.7053



 45%|████████████████████████████████████                                            | 176/390 [00:45<00:54,  3.93it/s]

In [None]:
total_acc / len(test_datasets)

In [None]:
torch.save(model.state_dict(), MODEL_DIR+'final.pth')

# Evaluate

In [None]:
model.eval()
total_acc = 0
for i, d in enumerate(test_dataloaders):
    x, y = d[0], d[1]
    x = x.float().to(device)
    y = y.long().to(device)

    with torch.no_grad():
        out_prob = model(x)

    pred = torch.argmax(out_prob, dim=1)
    prediction = pred.cpu().numpy()
    truth = y.cpu().numpy()
    acc = np.count_nonzero(prediction == truth)

    total_acc += acc

print('After the accuracy after {} epochs is {}'.format(epoch, total_acc / len(test_datasets)))