# Project description

В данном задании вам предстоит осуществить путешествие в мир Спрингфилда,
где вы сможете познакомиться со всеми любимыми персонажами Симпсонов.

Основным заданием будет обучить классификатор на основе сверточных сетей,
чтобы научиться отличать всех жителей Спрингфилда.
# Dataset description
Обучающая и тестовая выборка состоят из отрывков из мультсериала Симпсоны.
Каждая картинка представлена в формате jpg с необходимой меткой — названием
персонажа, изображенного на ней. Тест был поделен на приватную и публичную
часть в соотношении 95/5.

В тренировочном датасете примерно по 1000 картинок на каждый класс,
но они отличаются размером.

Метки классов представлены в виде названий папок, в которых лежат картинки.

# Table of contents:
1. [__Data preparation__](#data_preparation)
2. [__Training models__](#training_models)
    * [__Data augmentation__](#data_augmentation)
    * [__Models__](#models)
        * [_AlexNet_](#alexnet)
        * [_VGG19_](#vgg19_bn)
        * [_ResNet152_](#resnet152)

# <a name='data_preparation'>1. Data preparation</a>

In [None]:
# download data from here
# https://www.kaggle.com/c/journey-springfield/data
import os.path
import sys
if 'google' in sys.modules:
    from google.colab import drive
    drive.mount('/content/drive')
    !mkdir Data
    if not os.path.exists('Data/train'):
        !cp drive/My\ Drive/Colab/Stepik/Kaggle/journey-springfield.zip Data
        !unzip -q -n Data/journey-springfield.zip -d Data
        !rm Data/journey-springfield.zip

In [None]:
# Fetch the model definition and the training helpers from the author's
# other repository (downloaded once, then reused from the working directory):
# https://github.com/AllexFrolov/MobileNet_v3-PyTorch
# NOTE(review): the files come from the moving `master` branch and are
# executed on import - consider pinning a commit for reproducibility.
if not os.path.isfile('MobileNet_v3.py'):
    !wget -q https://raw.githubusercontent.com/AllexFrolov/MobileNet_v3-PyTorch/master/MobileNet_v3.py
if not os.path.isfile('functions.py'):
    !wget -q https://raw.githubusercontent.com/AllexFrolov/MobileNet_v3-PyTorch/master/functions.py

# -------- for debugging ----------
# Re-import both modules so that edits made to the local .py files are
# picked up without restarting the kernel.
import functions
import MobileNet_v3
from importlib import reload
functions = reload(functions)
MobileNet_v3 = reload(MobileNet_v3)
# --------------------------------
# Must come after the reload so the freshly loaded objects are bound.
from functions import train, accuracy

In [None]:
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms, models
from torchvision.datasets import ImageFolder
from sklearn.model_selection import train_test_split
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Labelled training images; per the dataset description above, each class
# is a sub-folder of this directory named after the character.
dataset = ImageFolder('Data/train/simpsons_dataset')

In [None]:
# Look at the images: a 3x3 grid of randomly chosen samples with their labels
np.random.seed(42)  # fixed seed so the same samples are drawn on every run

fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(8, 8),
                       sharey=True, sharex=True)

for fig_x in ax.flatten():
    random_characters = np.random.choice(len(dataset), 1)[0]
    im, label = dataset[random_characters]
    # Turn the folder-style class name into a title, e.g. a class called
    # 'some_name' is shown as 'Some Name'.
    img_label = " ".join(word.capitalize()
                         for word in dataset.classes[label].split('_'))
    # Square preview; the previous (224, 244) was a typo that slightly
    # distorted every thumbnail.
    im = im.resize((224, 224))
    fig_x.imshow(im)
    # The label is always a string here, so the title is set unconditionally.
    fig_x.set_title(img_label)
    fig_x.grid(False)

In [None]:
# Create custom DataLoader
class MyDataLoader:
    """Minimal batch loader over an indexable dataset of ``(sample, label)`` pairs.

    Only the samples whose positions are listed in ``indices`` are served,
    grouped into consecutive batches of ``batch_size`` (the last batch may be
    smaller).

    Three access styles are supported:

    * ``loader[i]`` - the i-th batch in the current index order;
    * ``for X, y in loader`` - one full pass; the indices are reshuffled at
      the start of every pass when ``shuffle`` is true;
    * ``next(loader)`` - endless stream of batches that walks through the
      data pass after pass (reshuffling at the start of each pass when
      ``shuffle`` is true). Mixing this with ``for`` loops on a shuffling
      loader is not advised because both reorder the same index list.

    Args:
        data: indexable dataset; ``data[i]`` must return ``(sample, label)``.
        indices: positions in ``data`` served by this loader. A private copy
            is kept, so shuffling never reorders the caller's sequence.
        batch_size: number of samples per batch, must be a positive ``int``.
        transformer: callable applied to every sample; it has to return a
            tensor so the batch can be stacked. Defaults to
            ``transforms.ToTensor()``.
        shuffle: whether to reshuffle the indices before every pass.

    Raises:
        AssertionError: if ``shuffle`` is not ``bool`` or ``batch_size`` is
            not ``int``.
        ValueError: if ``batch_size`` is smaller than 1.
    """

    def __init__(self, data, indices: list, batch_size: int, transformer=None, shuffle=False):
        assert type(shuffle) is bool, \
            f'shuffle should be bool type, not {type(shuffle)}'
        assert type(batch_size) is int, \
            f'batch_size should be type int, not {type(batch_size)}'
        if batch_size < 1:
            raise ValueError(f'batch_size should be positive, not {batch_size}')

        self.shuffle = shuffle
        self.batch_size = batch_size
        # Keep a private copy (same container type when possible): the
        # in-place shuffle must not reorder the sequence owned by the caller.
        self.indices = indices.copy() if hasattr(indices, 'copy') else list(indices)
        self.data = data
        self.data_len = len(self.indices)
        self.len_ = int(np.ceil(self.data_len / batch_size))
        # Position of the batch that the next ``next(loader)`` call returns.
        self._next_batch = 0

        self.transformer = transformer
        if transformer is None:
            self.transformer = transforms.ToTensor()

    def __len__(self):
        """Return the number of batches in one pass."""
        return self.len_

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(X_batch, y_batch)``.

        ``X_batch`` stacks the transformed samples along a new first
        dimension; ``y_batch`` is built with ``torch.Tensor`` and is
        therefore a float tensor (callers that need integer class indices
        have to cast it).

        Raises:
            IndexError: if ``index`` is outside ``0 .. len(self) - 1``.
        """
        # Explicit bounds check: this is also what terminates the legacy
        # ``__getitem__``-based iteration protocol.
        if not 0 <= index < self.len_:
            raise IndexError(
                f'batch index {index} is out of range for {self.len_} batches')
        start_index = index * self.batch_size
        end_index = min(self.data_len, start_index + self.batch_size)
        X_batch = []
        y_batch = []
        for batch_index in self.indices[start_index:end_index]:
            X, y = self.data[batch_index]
            X_batch.append(self.transformer(X))
            y_batch.append(y)
        # torch.stack also covers the single-sample batch (adds the batch dim).
        return torch.stack(X_batch), torch.Tensor(y_batch)

    def __iter__(self):
        """Yield every batch once, reshuffling first when ``shuffle`` is set."""
        if self.shuffle:
            np.random.shuffle(self.indices)
        for n_batch in range(self.len_):
            yield self[n_batch]

    def __next__(self):
        """Return the next batch, wrapping around after the last one.

        The indices are reshuffled whenever a new pass starts (and
        ``shuffle`` is set).

        Raises:
            StopIteration: if the loader holds no data at all.
        """
        if self.len_ == 0:
            raise StopIteration
        if self._next_batch == 0 and self.shuffle:
            np.random.shuffle(self.indices)
        batch = self[self._next_batch]
        self._next_batch = (self._next_batch + 1) % self.len_
        return batch

In [None]:
# Split the data: 75% train+val / 25% test, then the train+val part again
# 75% train / 25% val.
# An explicit random_state makes the split independent of the global NumPy
# RNG state (and therefore of which cells were run before this one), so the
# held-out sets stay the same across sessions.
train_val_indices, test_indices = train_test_split(np.arange(len(dataset)),
                                                   train_size=0.75,
                                                   random_state=42)

train_indices, val_indices = train_test_split(train_val_indices,
                                              train_size=0.75,
                                              random_state=42)

# <a name='training_models'>2. Training models</a>

## <a name='data_augmentation'>Data augmentation</a>


In [None]:
IM_SIZE = (224, 224)
batch_size = 64

# Training pipeline: resize to a common size, apply random rotation and
# colour jitter as augmentation, then convert to a normalised tensor.
train_transformer = transforms.Compose([transforms.Resize(IM_SIZE),
                                        transforms.RandomRotation(15),
                                        transforms.ColorJitter(0.5, 0.5),
                                        transforms.ToTensor(),
                                        transforms.Normalize([0.485, 0.456, 0.406],
                                                             [0.229, 0.224, 0.225])
                                        ])

# Evaluation pipeline: same resize and normalisation, no augmentation.
val_transformer = transforms.Compose([transforms.Resize(IM_SIZE),
                                        transforms.ToTensor(),
                                        transforms.Normalize([0.485, 0.456, 0.406],
                                                             [0.229, 0.224, 0.225])
                                        ])

# Create data loaders
train_loader = MyDataLoader(dataset, train_indices, batch_size,
                            train_transformer, True)
val_loader = MyDataLoader(dataset, val_indices, batch_size, val_transformer)
# The test loader must use the evaluation pipeline as well: without a
# transformer MyDataLoader falls back to a bare ToTensor(), so the
# differently sized images could not be stacked into a batch and the inputs
# would not be normalised like the training data.
test_loader = MyDataLoader(dataset, test_indices, batch_size, val_transformer)

In [None]:
# Number of target classes, used as the output size of the new last layers.
num_classes = len(dataset.classes)

## <a name='models'>Models</a>

In [None]:
# Training history of every model, keyed by model name.
models_history = {}

### <a name='alexnet'>AlexNet</a>

In [None]:
# Download pretrained model
alexnet = models.alexnet(pretrained=True)

# freeze parameters: the pretrained weights are not updated during training
for param in alexnet.parameters():
    param.requires_grad = False

# replace last layer with a fresh linear layer sized for our classes;
# it is created after the freeze, so it is the only trainable part
in_dim = alexnet.classifier[-1].in_features
classifier = nn.Linear(in_dim, num_classes)
alexnet.classifier[-1] = classifier
alexnet = alexnet.to(DEVICE)

In [None]:
# Optimizer settings for the AlexNet fine-tuning run.
lr = 1e-2
WEIGHT_DECAY = 1e-5
optimizer = torch.optim.Adam(alexnet.parameters(),
                             lr=lr,
                             weight_decay=WEIGHT_DECAY)

# Scheduler settings: multiply the learning rate by FACTOR once the monitored
# value has not improved by more than THRESHOLD for PATIENCE epochs.
FACTOR = 0.5
THRESHOLD = 0.01
PATIENCE = 1

loss_func = nn.CrossEntropyLoss().to(DEVICE)
# Arguments are passed by keyword: the positional order of `verbose` and
# `threshold` differs between PyTorch releases, so positional values can
# silently end up in the wrong parameter.
# mode='max' means "higher is better" - presumably the train helper steps the
# scheduler on the validation metric (accuracy); verify against functions.py.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=FACTOR, patience=PATIENCE,
    verbose=True, threshold=THRESHOLD
)

In [None]:
%%time
# Train AlexNet with the helper imported from functions.py; the cell magic
# above reports how long the whole run took.
# NOTE(review): `history` is stored per model below; the exact contents of
# `history` and `best_param` are defined by functions.train - confirm there.
epoch_count=10
history, best_param = \
        train(alexnet, train_loader, loss_func, optimizer, epoch_count,
              accuracy, val_loader, scheduler)

In [None]:
# Keep the AlexNet training history under its model name.
models_history['alexnet'] = history

### <a name='vgg19_bn'>VGG19</a>

In [None]:
# Download pretrained model
vgg19_bn = models.vgg19_bn(pretrained=True)

# freeze parameters: the pretrained weights are not updated during training
for param in vgg19_bn.parameters():
    param.requires_grad = False

# replace last layer with a fresh linear layer sized for our classes;
# it is created after the freeze, so it is the only trainable part
in_dim = vgg19_bn.classifier[-1].in_features
classifier = nn.Linear(in_dim, num_classes)
vgg19_bn.classifier[-1] = classifier
vgg19_bn = vgg19_bn.to(DEVICE)

In [None]:
# Optimizer settings for the VGG19-BN fine-tuning run.
lr = 1e-2
WEIGHT_DECAY = 1e-5
optimizer = torch.optim.Adam(vgg19_bn.parameters(),
                             lr=lr,
                             weight_decay=WEIGHT_DECAY)

# Scheduler settings: multiply the learning rate by FACTOR once the monitored
# value has not improved by more than THRESHOLD for PATIENCE epochs.
FACTOR = 0.5
THRESHOLD = 0.01
PATIENCE = 1

loss_func = nn.CrossEntropyLoss().to(DEVICE)
# Arguments are passed by keyword: the positional order of `verbose` and
# `threshold` differs between PyTorch releases, so positional values can
# silently end up in the wrong parameter.
# mode='max' means "higher is better" - presumably the train helper steps the
# scheduler on the validation metric (accuracy); verify against functions.py.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=FACTOR, patience=PATIENCE,
    verbose=True, threshold=THRESHOLD
)

In [None]:
%%time
# Train VGG19-BN with the helper imported from functions.py; the cell magic
# above reports how long the whole run took.
# NOTE(review): `history` is stored per model below; the exact contents of
# `history` and `best_param` are defined by functions.train - confirm there.
epoch_count=10
history, best_param = \
        train(vgg19_bn, train_loader, loss_func, optimizer, epoch_count,
              accuracy, val_loader, scheduler)


In [None]:
# Keep the VGG19-BN training history under its model name.
models_history['vgg19_bn'] = history
