In [8]:
import torch
import torchvision
from torchvision import transforms, models
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, ConcatDataset, random_split, Subset
from copy import deepcopy
from timm.data.loader import create_loader
from matplotlib import pyplot as plt
import numpy as np
from timm.data.mixup import Mixup
import warnings
warnings.filterwarnings("ignore")

In [9]:
# Evaluation-time preprocessing: convert each PIL image to a tensor, then standardize per channel.
# NOTE(review): these are the standard ImageNet mean/std values, not CIFAR-100's own statistics;
# presumably they match what the saved models were trained with — confirm before changing them.
image_transform_origin = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# CIFAR-100 test split read from ./data; with download=False the files must already exist there.
test_data = torchvision.datasets.CIFAR100(
    root='data',
    train=False,
    download=False,
    transform=image_transform_origin,
)
# Total number of test samples, used later to turn summed metrics into per-sample means.
test_size = len(test_data)

In [10]:
# Test-set loader; shuffle is left at its default (off), which does not affect aggregate metrics.
test_loader = DataLoader(dataset=test_data, batch_size=128, num_workers=10)
# Fall back to CPU so the notebook still runs on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
loss_function = nn.CrossEntropyLoss().to(device)


def _load_full_model(path):
    """Load a whole pickled model from `path` and place it on `device`.

    These checkpoints appear to be entire pickled nn.Module objects rather than state_dicts
    (later cells call `.eval()` on them and apply them to inputs). PyTorch >= 2.6 defaults to
    `weights_only=True`, which refuses such files, so `weights_only=False` is passed explicitly.
    PyTorch < 1.13 has no `weights_only` argument, so that case falls back to the plain call.

    `map_location=device` remaps every stored tensor onto `device`, so checkpoints saved on a
    different GPU (or loaded on a CPU-only host) still load.

    Security: this unpickles arbitrary objects — only use it on trusted, locally produced files.
    """
    try:
        model = torch.load(path, map_location=device, weights_only=False)
    except TypeError as err:
        # Only swallow the "unknown keyword" error from old torch; re-raise anything else.
        if 'weights_only' not in str(err):
            raise
        model = torch.load(path, map_location=device)
    return model.to(device)


model_no_aug = _load_full_model('models_save/model.pt')
model_cutout = _load_full_model('models_save/model_cutout.pt')
model_mixup = _load_full_model('models_save/model_mixup.pt')
model_cutmix = _load_full_model('models_save/model_cutmix.pt')


In [11]:
# Evaluate the model trained without augmentation on the full test set.
model_no_aug.eval()  # inference behavior (e.g. dropout/BatchNorm); set once, not per batch
test_loss = .0
test_correct = 0
with torch.no_grad():  # pure inference: don't build an autograd graph
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model_no_aug(inputs)
        # loss_function uses the default 'mean' reduction, so re-weight by batch size;
        # dividing the sum by test_size then yields a per-sample mean even for a short last batch.
        test_loss += loss_function(outputs, targets).item() * inputs.size(0)
        # Count exact hits on-device instead of copying a float mask to the CPU each batch.
        test_correct += (outputs.argmax(dim=1) == targets).sum().item()
test_loss /= test_size
test_acc = test_correct / test_size
print('No augument')
print('test loss:%-7.2f test accuracy:%-7.2f' % (test_loss, test_acc))

No augument
test loss:4.51    test accuracy:0.37   


In [12]:
# Evaluate the model trained with Cutout augmentation on the full test set.
model_cutout.eval()  # inference behavior (e.g. dropout/BatchNorm); set once, not per batch
test_loss = .0
test_correct = 0
with torch.no_grad():  # pure inference: don't build an autograd graph
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model_cutout(inputs)
        # loss_function uses the default 'mean' reduction, so re-weight by batch size;
        # dividing the sum by test_size then yields a per-sample mean even for a short last batch.
        test_loss += loss_function(outputs, targets).item() * inputs.size(0)
        # Count exact hits on-device instead of copying a float mask to the CPU each batch.
        test_correct += (outputs.argmax(dim=1) == targets).sum().item()
test_loss /= test_size
test_acc = test_correct / test_size
print('Cutout')
print('test loss:%-7.2f test accuracy:%-7.2f' % (test_loss, test_acc))

Cutout
test loss:3.92    test accuracy:0.40   


In [13]:
# Evaluate the model trained with Mixup augmentation on the full test set.
model_mixup.eval()  # inference behavior (e.g. dropout/BatchNorm); set once, not per batch
test_loss = .0
test_correct = 0
with torch.no_grad():  # pure inference: don't build an autograd graph
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model_mixup(inputs)
        # loss_function uses the default 'mean' reduction, so re-weight by batch size;
        # dividing the sum by test_size then yields a per-sample mean even for a short last batch.
        test_loss += loss_function(outputs, targets).item() * inputs.size(0)
        # Count exact hits on-device instead of copying a float mask to the CPU each batch.
        test_correct += (outputs.argmax(dim=1) == targets).sum().item()
test_loss /= test_size
test_acc = test_correct / test_size
print('Mixup')
print('test loss:%-7.2f test accuracy:%-7.2f' % (test_loss, test_acc))

Mixup
test loss:3.38    test accuracy:0.35   


In [14]:
# Evaluate the model trained with CutMix augmentation on the full test set.
model_cutmix.eval()  # inference behavior (e.g. dropout/BatchNorm); set once, not per batch
test_loss = .0
test_correct = 0
with torch.no_grad():  # pure inference: don't build an autograd graph
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model_cutmix(inputs)
        # loss_function uses the default 'mean' reduction, so re-weight by batch size;
        # dividing the sum by test_size then yields a per-sample mean even for a short last batch.
        test_loss += loss_function(outputs, targets).item() * inputs.size(0)
        # Count exact hits on-device instead of copying a float mask to the CPU each batch.
        test_correct += (outputs.argmax(dim=1) == targets).sum().item()
test_loss /= test_size
test_acc = test_correct / test_size
print('Cutmix')
print('test loss:%-7.2f test accuracy:%-7.2f' % (test_loss, test_acc))

Cutmix
test loss:2.87    test accuracy:0.37   
