In [None]:
import torch
from torch import nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import cv2
import matplotlib.pyplot as plt
import train_and_test_model
from train_and_test_model import trainModel,testModel,attention_map
from torchvision.transforms import TrivialAugmentWide,AutoAugment, AutoAugmentPolicy
from ResNet_Classifier import classifierTrain

## Load Dataset

In [None]:
# Standard ImageNet per-channel mean / std, used by Normalize in every pipeline below.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Seed the global torch RNG so loader shuffling and random augmentation are repeatable
# under Restart Kernel -> Run All (DataLoader derives its worker seeds from this RNG).
SEED = 42
torch.manual_seed(SEED)

# Deterministic preprocessing: used for all evaluation splits and the non-augmented training run.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Training-time augmentation: TrivialAugmentWide runs on the resized PIL image,
# i.e. before ToTensor / Normalize, which must stay last.
transform_aug = transforms.Compose([
    transforms.Resize((224, 224)),
    TrivialAugmentWide(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# NOTE: alternative augmentation pipelines that used to be drafted here (removed as dead code):
#   - AutoAugment(policy=AutoAugmentPolicy.IMAGENET) before Resize((224, 224))
#   - Resize(256) + RandomResizedCrop(224) (optionally RandomHorizontalFlip / ColorJitter)
# Any such pipeline must include transforms.ToTensor() before transforms.Normalize(mean, std);
# the RandomResizedCrop draft omitted it and would have failed on PIL input.

In [None]:
# Loader settings shared by every split.
BATCH_SIZE = 32
NUM_WORKERS = 4


def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader with the notebook-wide batch/worker settings.

    Only training loaders should pass ``shuffle=True``; evaluation loaders keep a fixed order.
    """
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle,
                      num_workers=NUM_WORKERS, pin_memory=True)


# Clean (un-augmented) train / test / val splits.
train_dataset_origin = datasets.ImageFolder(root='./datasets/train', transform=transform)
train_loader_origin = _make_loader(train_dataset_origin, shuffle=True)

test_dataset_origin = datasets.ImageFolder(root='./datasets/test', transform=transform)
test_loader_origin = _make_loader(test_dataset_origin, shuffle=False)

val_dataset_origin = datasets.ImageFolder(root='./datasets/val', transform=transform)
val_loader_origin = _make_loader(val_dataset_origin, shuffle=False)

# Augmented training split: the same images as train_dataset_origin, passed through transform_aug.
train_dataset_aug = datasets.ImageFolder(root='./datasets/train', transform=transform_aug)
train_loader_aug = _make_loader(train_dataset_aug, shuffle=True)

# Validation stays clean and in fixed order even for the augmented run. Augmenting it (as before)
# made validation accuracy random from epoch to epoch, and inconsistent with the clean test set the
# 'aug' models are scored on. Whatever trainModel does with it (e.g. checkpoint selection, not visible
# here) now sees a stable signal. Kept as aliases so the existing training cells work unchanged.
val_dataset_aug = val_dataset_origin
val_loader_aug = val_loader_origin

# Perturbed test sets (levels 1-3), evaluated with the deterministic transform only.
perturbation_l1 = datasets.ImageFolder(root='./datasets_l1', transform=transform)
test_loader_l1 = _make_loader(perturbation_l1, shuffle=False)

perturbation_l2 = datasets.ImageFolder(root='./datasets_l2', transform=transform)
test_loader_l2 = _make_loader(perturbation_l2, shuffle=False)

perturbation_l3 = datasets.ImageFolder(root='./datasets_l3', transform=transform)
test_loader_l3 = _make_loader(perturbation_l3, shuffle=False)

In [None]:
def show_augmented_image(ori, aug, index=0,
                         mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Plot sample ``index`` of two datasets side by side: original (left) vs. augmented (right).

    Parameters
    ----------
    ori, aug : indexable datasets whose items are ``(image, label)`` with a normalized CHW
        image tensor (e.g. ``ImageFolder`` built with ``transform`` / ``transform_aug``).
    index : int
        Position of the sample to show from both datasets.
    mean, std : sequences of 3 floats
        Per-channel statistics that were passed to ``transforms.Normalize``; the defaults are
        the same ImageNet values used by the transform pipelines in this notebook.

    If ``aug`` applies random transforms, each call displays a freshly sampled augmentation.
    """
    mean_t = torch.tensor(mean)
    std_t = torch.tensor(std)

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    for ax, dataset, title in zip(axes, (ori, aug), ("Original", "Augmentation")):
        img, _label = dataset[index]
        # CHW -> HWC for imshow, then undo Normalize (x * std + mean, broadcast over channels).
        img = img.permute(1, 2, 0) * std_t + mean_t
        ax.imshow(img.clip(0, 1))
        ax.set_title(title)
        ax.axis('off')
    plt.show()


show_augmented_image(train_dataset_origin, train_dataset_aug, 10)

## Fine-tuning Test

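### Summary
Accuracy on the clean test set after fine-tuning on the original training data (`origin_acc`) vs. the augmented training data (`aug_acc`):
- ResNet18 — origin_acc: 91.06%, aug_acc: 92.22%
- ResNet50 — origin_acc: 93.17%, aug_acc: 93.78%
- ResNet101 — origin_acc: 92.44%, aug_acc: 93.33%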

### ResNet18 + Long-Tail Experiment / Data Augmentation

In [None]:
# ResNet18 on the original (un-augmented) training data: train, then score on the clean test set.
trainModel(18, train_loader_origin, val_loader_origin, 'origin')
testModel(18, test_loader_origin, 'origin')

In [None]:
# ResNet18 on the augmented training data; evaluation still uses the clean test set.
trainModel(18, train_loader_aug, val_loader_aug, 'aug')
testModel(18, test_loader_origin, 'aug')

### ResNet50 + Long-Tail Experiment / Data Augmentation

In [None]:
# ResNet50 on the original (un-augmented) training data: train, then score on the clean test set.
trainModel(50, train_loader_origin, val_loader_origin, 'origin')
testModel(50, test_loader_origin, 'origin')

In [None]:
# ResNet50 on the augmented training data; evaluation still uses the clean test set.
trainModel(50, train_loader_aug, val_loader_aug, 'aug')
testModel(50, test_loader_origin, 'aug')

In [None]:
# attention_map for ResNet50 on the clean test loader; the 'aug' tag presumably selects the run trained above — confirm in train_and_test_model.
attention_map(50, test_loader_origin, 'aug')

### ResNet101 + Long-Tail Experiment / Data Augmentation

In [None]:
# ResNet101 on the original (un-augmented) training data: train, then score on the clean test set.
trainModel(101, train_loader_origin, val_loader_origin, 'origin')
testModel(101, test_loader_origin, 'origin')

In [None]:
# ResNet101 on the augmented training data; evaluation still uses the clean test set.
trainModel(101, train_loader_aug, val_loader_aug, 'aug')
testModel(101, test_loader_origin, 'aug')

## SVM Classifier

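### Summary
SVM accuracy on the clean test set, trained on the original (un-augmented) training data (`origin_acc`):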
- ResNet18 origin_acc: 94.16%
- ResNet50 origin_acc: 96.11%
- ResNet101 origin_acc: 94.66%

ResNet50 + SVM accuracy under data perturbation (trained on original data; L1–L3 = perturbed test sets `datasets_l1`–`datasets_l3`):
- L1: 93.44%
- L2: 83.33%
- L3: 59.61%

Different classifiers on ResNet50 (trained on original data, evaluated on the clean test set):
- SVM: 96.11%
- KNN: 88.33%
- MLP: 94.77%
- Proto: 81.2%

In [None]:
# SVM on ResNet18 — trained on original data, evaluated on the clean test set.
classifierTrain(18, train_loader_origin, test_loader_origin, 'SVM')

In [None]:
# SVM on ResNet50 — trained on original data, evaluated on the clean test set.
classifierTrain(50, train_loader_origin, test_loader_origin, 'SVM')

In [None]:
# Data perturbation level 1: SVM on ResNet50, trained on original data, tested on ./datasets_l1.
classifierTrain(50, train_loader_origin, test_loader_l1, 'SVM')

In [None]:
# Data perturbation level 2: SVM on ResNet50, trained on original data, tested on ./datasets_l2.
classifierTrain(50, train_loader_origin, test_loader_l2, 'SVM')

In [None]:
# Data perturbation level 3: SVM on ResNet50, trained on original data, tested on ./datasets_l3.
classifierTrain(50, train_loader_origin, test_loader_l3, 'SVM')

In [None]:
# KNN on ResNet50 — trained on original data, evaluated on the clean test set.
classifierTrain(50, train_loader_origin, test_loader_origin, 'KNN')

In [None]:
# MLP on ResNet50 — trained on original data, evaluated on the clean test set.
classifierTrain(50, train_loader_origin, test_loader_origin, 'MLP')

In [None]:
# 'Proto' classifier on ResNet50 — trained on original data, evaluated on the clean test set.
classifierTrain(50, train_loader_origin, test_loader_origin, 'Proto')

In [None]:
# SVM on ResNet50 — trained on the augmented training loader, evaluated on the clean test set.
classifierTrain(50, train_loader_aug, test_loader_origin, 'SVM')

In [None]:
# SVM on ResNet101 — trained on original data, evaluated on the clean test set.
classifierTrain(101, train_loader_origin, test_loader_origin, 'SVM')

In [None]:
# SVM on ResNet101 — trained on the augmented training loader, evaluated on the clean test set.
classifierTrain(101, train_loader_aug, test_loader_origin, 'SVM')