In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split

from utils import *
from models import *
from train import predict_model
from dataset import CIFAR_100_Dataset

In [2]:
# Seed the RNGs before anything random happens (seed_all is provided by the project's utils).
seed_all(0)
data_path = './dataset/cifar100'

# Fixed: the first std entry was 0.0203, a digit transposition of 0.2023; Normalize divides
# by std, so the typo would scale the red channel roughly 10x differently from the others.
# NOTE(review): this transform is only attached to the torchvision dataset objects. The
# splits below are built from their raw `.data` arrays, so it is not applied to them here.
# NOTE(review): these are the commonly quoted CIFAR-10 channel statistics - confirm they
# are the ones the checkpoints were trained with.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
])
train_valid_dataset = datasets.cifar.CIFAR100(data_path, train=True, transform=transform)
raw_test_dataset = datasets.cifar.CIFAR100(data_path, train=False, transform=transform)

# Stratified 80/20 split of the raw image array and its labels. No random_state is passed,
# so the split draws from the global NumPy RNG - presumably seeded by seed_all; verify.
train_images, valid_images, train_label, valid_label = train_test_split(
    train_valid_dataset.data,
    train_valid_dataset.targets,
    test_size=0.2,
    stratify=train_valid_dataset.targets,
)

# NOTE(review): the augmented dataset receives the unscaled array while every other split
# is divided by 255 - confirm CIFAR_100_Dataset expects raw pixel values when augmenting.
aug_train_dataset = CIFAR_100_Dataset(train_images, train_label, shuffle=True, prob=0.2, augment='cutmix', beta=1)

# Intermediate arrays keep their own names so the *_dataset globals only ever refer to
# CIFAR_100_Dataset instances.
train_dataset = CIFAR_100_Dataset(train_images / 255, train_label)
valid_dataset = CIFAR_100_Dataset(valid_images / 255, valid_label)
test_dataset = CIFAR_100_Dataset(raw_test_dataset.data / 255, raw_test_dataset.targets)

In [3]:
# Loader batch size and the channel counts passed to the ResNet18 constructor.
BATCH_SIZE = 1024
INPUT_CHANNEL = 3
OUTPUT_CHANNEL = 100

# Saved parameters for the four training runs being compared (no augmentation, CutMix,
# Cutout, Mixup). map_location='cpu' remaps any CUDA-saved tensors onto the CPU, so loading
# works on a machine without a GPU; every evaluation in this notebook runs on 'cpu'.
origin_params = torch.load('./log/lr_0.001_weight_decay_0.0_aug_None/model_param.pth', map_location='cpu')
cutmix_params = torch.load('./log/lr_0.001_weight_decay_0.0_aug_cutmix_prob_0.5_beta_1.0/model_param.pth', map_location='cpu')
cutout_params = torch.load('./log/lr_0.001_weight_decay_0.0_aug_cutout_prob_0.5_beta_1.0/model_param.pth', map_location='cpu')
mixup_params = torch.load('./log/lr_0.001_weight_decay_0.0_aug_mixup_prob_0.5_beta_1.0/model_param.pth', map_location='cpu')

In [4]:
# Only the augmented training loader shuffles; the plain splits are iterated in stored order.
aug_train_loader = DataLoader(dataset=aug_train_dataset, shuffle=True, batch_size=BATCH_SIZE)
train_loader, valid_loader, test_loader = [
    DataLoader(dataset=split, batch_size=BATCH_SIZE)
    for split in (train_dataset, valid_dataset, test_dataset)
]

# Loss callable handed to predict_model.
loss_func = F.cross_entropy

# A single network instance is shared by all evaluation cells; each one overwrites its
# weights with load_state_dict before predicting.
resnet18_model = ResNet18(output_channel=OUTPUT_CHANNEL, input_channel=INPUT_CHANNEL)

# Metric names requested from the project's Metrics helper.
metric_names = ['accuracy', 'precision', 'recall', 'f1_score']
metrics = Metrics(metric_names)

In [5]:
# Evaluate the checkpoint from the run logged under aug_None on the test loader.
print("Origin Model")
resnet18_model.load_state_dict(origin_params)
# NOTE(review): confirm predict_model puts the network in eval mode before inference.
origin_eval = predict_model(resnet18_model, test_loader, loss_func, 'cpu', metrics=metrics)
# Bare last expression so the notebook renders the returned tuple
# (recorded output: a float, a tensor and a metrics dict).
origin_eval

Origin Model


(1.9954503675460815,
 tensor([72., 33., 72.,  ..., 37., 42., 70.]),
 {'accuracy': 0.5207,
  'precision': 0.5616343477782396,
  'recall': 0.5207,
  'micro_f1': 0.5207,
  'macro_f1': 0.5259116230695718})

In [6]:
# Evaluate the checkpoint from the CutMix run on the test loader.
print("CutMix Model")
resnet18_model.load_state_dict(cutmix_params)
# NOTE(review): confirm predict_model puts the network in eval mode before inference.
cutmix_eval = predict_model(resnet18_model, test_loader, loss_func, 'cpu', metrics=metrics)
# Bare last expression so the notebook renders the returned tuple
# (recorded output: a float, a tensor and a metrics dict).
cutmix_eval

CutMix Model


(1.8469962120056151,
 tensor([ 2., 33., 72.,  ..., 51., 42., 70.]),
 {'accuracy': 0.5416,
  'precision': 0.5816391250267754,
  'recall': 0.5416000000000001,
  'micro_f1': 0.5416,
  'macro_f1': 0.5381288713646828})

In [7]:
# Evaluate the checkpoint from the Cutout run on the test loader.
print("Cutout Model")
resnet18_model.load_state_dict(cutout_params)
# NOTE(review): confirm predict_model puts the network in eval mode before inference.
cutout_eval = predict_model(resnet18_model, test_loader, loss_func, 'cpu', metrics=metrics)
# Bare last expression so the notebook renders the returned tuple
# (recorded output: a float, a tensor and a metrics dict).
cutout_eval

Cutout Model


(2.126535963821411,
 tensor([49., 33., 55.,  ..., 51., 42., 70.]),
 {'accuracy': 0.5241,
  'precision': 0.5556018649555718,
  'recall': 0.5241,
  'micro_f1': 0.5241,
  'macro_f1': 0.5213198210847135})

In [9]:
# Evaluate the checkpoint from the Mixup run on the test loader.
print("Mixup Model")
resnet18_model.load_state_dict(mixup_params)
# NOTE(review): confirm predict_model puts the network in eval mode before inference.
mixup_eval = predict_model(resnet18_model, test_loader, loss_func, 'cpu', metrics=metrics)
# Bare last expression so the notebook renders the returned tuple
# (recorded output: a float, a tensor and a metrics dict).
mixup_eval

Mixup Model


(1.962523399734497,
 tensor([68., 33., 30.,  ..., 51., 97., 70.]),
 {'accuracy': 0.5335,
  'precision': 0.5539419456471145,
  'recall': 0.5335,
  'micro_f1': 0.5335,
  'macro_f1': 0.5285037689820197})