 ### Imports:

In [None]:
# standard imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")
# torch
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Subset
from torch.utils.data import DataLoader
# from torchvision import datasets as torchV_datasets
# I end up using the word datasets as a variable all the time
import torchvision
from torchvision.io import read_image
from torchvision import datasets as torchV_datasets
from torchvision import transforms, utils
from torchvision.utils import make_grid
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import Subset
from torch import optim
# from torchvision import datasets, transforms
from tqdm.auto import tqdm
# sklearn
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score
# others
import os
import shutil
from distutils.dir_util import copy_tree
import glob as glob
# import splitfolders #to split image data into train, val and test sets
import random
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import collections
# seeding
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

### Utility Class:
This class copies anime image folders from the dataset called anime-images-dataset to your local Kaggle working directory (the originals are left untouched).  
Future Work:
- [ ] Add anime list based selection

In [None]:
class UtilityWorker():
    """
    Utility class that copies a random selection of anime image folders
    from the dataset directory into the local kaggle working directory.

        params:
            num_anime: number of anime (one folder each) to select
            seed: seed applied to the global `random` module before sampling
        notes:
            - The anime are chosen at random from the *sorted* folder listing,
              so the same seed always selects the same folders
            - Chosen anime are copied into the kaggle local working directory
        raises:
            ValueError (from random.sample) if num_anime is negative or larger
            than the number of available folders
    """
    def __init__(self, num_anime, seed=42):
        """
        Randomly choosing anime to copy to local working directory
        """
        self.ANIME_IMAGES = "../input/anime-images-dataset/data/anime_images/"
        self.INPUT = "/kaggle/working/input/"
        # seed needs to be initialized before a call to random since threads can interfere with it
        # https://towardsdatascience.com/random-seeds-and-reproducibility-933da79446e3
        random.seed(seed)
        # os.listdir returns entries in arbitrary order, so the listing is
        # sorted first; otherwise the same seed could pick different anime
        # on a different filesystem or session.
        available = sorted(os.listdir(self.ANIME_IMAGES))
        animelist = random.sample(available, k=num_anime)
        # the copying operation may take up to 6-8 mins if you choose all the anime
        print(f"The anime chosen are: \n{animelist}")
        # copying chosen anime from dataset to local working dir, in a folder called input
        self.copy_folder_helper(animelist)
        print(f"These anime can be found within folders of their own name in: {self.INPUT}")

    def copy_folder_helper(self, animelist):
        """
        Helper function that does the actual copying.

        Each name in `animelist` is copied from ANIME_IMAGES to a folder of
        the same name under INPUT. Existing destination folders are merged
        into (files are overwritten), so re-running the cell is safe.
        """
        for anime in animelist:
            # shutil.copytree replaces distutils' copy_tree, which is
            # deprecated and no longer ships with Python 3.12+.
            shutil.copytree(
                os.path.join(self.ANIME_IMAGES, anime),
                os.path.join(self.INPUT, anime),
                dirs_exist_ok=True,
            )

In [None]:
# Select 10 anime at random (seeded) and copy their image folders into the working directory.
utility = UtilityWorker(10, seed=42)

### Data Class:
We define a data class that will create train and test split from the images in the kaggle local working directory.  
We don't need a validation split here since I am not bothered to tune hyperparameters right now.  
This class will also count the class instances in the splits.

In [None]:
class Data():
    """
    Builds the image dataset from the kaggle local working directory,
    splits it into train and test subsets and reports label counts.
    """
    def __init__(self):
        self.INPUT = "/kaggle/working/input/"
        self.separator = "#"*80

    def get_full_dataset(self):
        """
        Load every image under INPUT as an ImageFolder dataset.

        Images are resized to 224x224 and converted to tensors. The dataset
        is also stored on the instance as `full_dataset`.
        """
        self.base_transforms = transforms.Compose([transforms.Resize((224, 224)),
                                                   transforms.ToTensor(),])
        self.full_dataset = torchV_datasets.ImageFolder(self.INPUT, transform=self.base_transforms)
        print(f"Full dataset details:")
        print(self.full_dataset)
        print(self.separator)
        return self.full_dataset

    def get_train_test_splits(self, dataset, test_split=0.20, random_state=None, stratify=False):
        """
        Randomly split `dataset` into train and test Subsets.

            params:
                dataset: dataset to split (indexed 0..len(dataset)-1)
                test_split: fraction of samples that goes to the test subset
                random_state: forwarded to sklearn's train_test_split; None
                    keeps the original behaviour (global numpy RNG)
                stratify: when True, keep class proportions equal in both
                    splits; requires labels readable from `dataset.targets`
            returns:
                (train_dataset, test_dataset) as torch Subsets of `dataset`
            raises:
                ValueError if stratify=True but labels cannot be read
                without loading the images
        """
        labels = None
        if stratify:
            labels = self._labels_without_loading(dataset)
            if labels is None:
                raise ValueError("stratify=True needs a dataset that exposes `targets`")
        train_idx, test_idx = train_test_split(
            list(range(len(dataset))),
            test_size=test_split,
            random_state=random_state,
            stratify=labels,
        )
        train_dataset = Subset(dataset, train_idx)
        test_dataset = Subset(dataset, test_idx)
        print(f"Size of the train dataset:{len(train_dataset)}")
        print(f"Size of the test dataset:{len(test_dataset)}")
        print(self.separator)
        return train_dataset, test_dataset

    def count_labels_in_dataset(self, dataset, dataset_name="optional"):
        """
        Print and return how many samples each label has in `dataset`.

        Labels are read from `targets` (following Subset indices) whenever
        possible, so no image has to be decoded. Only datasets that do not
        expose `targets` are iterated sample by sample.

            returns:
                collections.Counter mapping label -> number of samples
        """
        if dataset_name != "optional":
            print(f"Counting labels in {dataset_name} dataset....")
        else:
            print(f"Counting labels....")
        labels = self._labels_without_loading(dataset)
        if labels is None:
            # slow path: every sample (image included) has to be loaded
            labels = [label for _, label in tqdm(dataset)]
        counts = collections.Counter(labels)
        print(f"Label counts are:\n{counts}")
        print(self.separator)
        return counts

    @staticmethod
    def _labels_without_loading(dataset):
        """
        Return the labels of `dataset` in index order without touching the
        samples, or None when they cannot be obtained that way.
        """
        if isinstance(dataset, Subset):
            parent_labels = Data._labels_without_loading(dataset.dataset)
            if parent_labels is None:
                return None
            return [parent_labels[i] for i in dataset.indices]
        targets = getattr(dataset, "targets", None)
        if targets is None:
            return None
        try:
            # int() also unwraps tensor / numpy scalars so they hash by value
            return [int(t) for t in targets]
        except (TypeError, ValueError):
            return None

In [None]:
# Build the full ImageFolder dataset from the working directory.
data_obj = Data()
full_dataset = data_obj.get_full_dataset()
# Split with the default test_split=0.20, i.e. 80% train / 20% test.
train_dataset, test_dataset = data_obj.get_train_test_splits(full_dataset)
# Report how many samples of each label ended up in each split.
data_obj.count_labels_in_dataset(train_dataset, "train")
data_obj.count_labels_in_dataset(test_dataset, "test")

### Plotting Images:
The following two functions were borrowed from this nice [tutorial](https://www.kaggle.com/code/shtrausslearning/pytorch-multiclass-image-classification#2-|-GET-THE-TRAINING-DATA).
Future Work:
- [ ] Write our own image display function. Maybe one that's easier to read for me.
- [ ] Include this in a class of its own??


In [None]:
# These functions have been borrowed as they are from:
# https://www.kaggle.com/code/shtrausslearning/pytorch-multiclass-image-classification#2-|-GET-THE-TRAINING-DATA
def plot_img(img,y=None,color=True):
    """
    Draw a 3-D image tensor on the current matplotlib figure.

    Axis 0 of `img` is moved to the last position before plotting
    (presumably channels-first -> channels-last, which is what imshow
    expects). `y` and `color` are accepted but not used.
    """
    # reorder axes (0, 1, 2) -> (1, 2, 0) for imshow
    channels_last = img.numpy().transpose((1, 2, 0))
    plt.imshow(channels_last)
    plt.title('Image samples from each of the 10 classes')
    plt.axis('on')

def plot_tensor(tensor,random_id=False,class_id=None,num_classes=10,per_class=20,max_scan=1000):
    """
    Plot a grid of sample images taken from a dataset of (image, label) pairs.

    Three modes:
        - random_id=True: 100 samples drawn at random (with replacement)
        - class_id=k: up to `per_class` samples whose label equals k
        - default: up to `per_class` samples for each label in
          0..num_classes-1, grouped by label in ascending order

    params:
        tensor: indexable dataset; tensor[i] must return (image, label)
        random_id: select the random mode (only the literal True enables it)
        class_id: label to show on its own; None shows every class
        num_classes: number of labels covered in the default mode
        per_class: maximum samples collected per label, also the grid width
        max_scan: only the first `max_scan` samples are searched
    raises:
        ValueError if no sample matches the request
    """
    if(random_id is True):
        rnd_inds = np.random.randint(0,len(tensor),100)
        X_show = [tensor[i][0] for i in rnd_inds]
        title = 'Random image samples'
    else:
        if(class_id is not None):
            print(f'Showing samples from {len(tensor)} tensors:')
            wanted = [class_id]
            title = f'Image samples from class {class_id}'
        else:
            wanted = list(range(num_classes))
            title = f'Image samples from each of the {num_classes} classes'

        # One pass over the data: every sample is loaded at most once, and
        # the scan never runs past the end of a dataset shorter than max_scan.
        buckets = {label: [] for label in wanted}
        unfilled = len(buckets)
        for i in range(min(max_scan, len(tensor))):
            if unfilled == 0:
                break
            image, label = tensor[i]
            # dict lookup compares labels by value (==), not by identity
            bucket = buckets.get(label)
            if bucket is None or len(bucket) >= per_class:
                continue
            bucket.append(image)
            if len(bucket) == per_class:
                unfilled -= 1
        # keep the original layout: all samples of one class, then the next
        X_show = [image for label in wanted for image in buckets[label]]

    if not X_show:
        raise ValueError("No matching samples found to plot.")

    X_grid = utils.make_grid(X_show,nrow=per_class,padding=1)
    plt.figure(figsize=(30,10))
    plot_img(X_grid,y=None,color=True)
    # plot_img always writes the all-classes title; set the one for this mode
    plt.title(title)

In [None]:
# Default mode: show up to 20 training samples for each of the class ids 0-9.
plot_tensor(train_dataset)

In [None]:
# Single-class mode: show up to 20 training samples labelled with class id 4.
plot_tensor(train_dataset,class_id=4)

### Dataloaders:
Creating the train and test dataloaders in preparation for training.

In [None]:
# Augmentation pipeline for training images.
train_transforms = transforms.Compose([transforms.RandomRotation(15),
                                      transforms.Resize((232, 232)),
                                      transforms.RandomVerticalFlip(p=0.5),
                                      transforms.RandomHorizontalFlip(p=0.5),
                                      transforms.ToTensor(),])
# Deterministic pipeline for test images (resize only).
test_transforms = transforms.Compose([transforms.Resize((232, 232)),
                                     transforms.ToTensor(),])
# Setting `.transforms` on a Subset has no effect: a Subset only forwards
# indices to the dataset it wraps, and both splits wrap the same ImageFolder
# with its 224x224 base transform. Each split is therefore rebuilt on top of
# its own ImageFolder carrying the right pipeline. ImageFolder lists classes
# and files in sorted order, so the existing split indices stay valid.
# NOTE(review): no ImageNet mean/std normalisation is applied here although
# the model starts from ImageNet weights - consider adding transforms.Normalize.
train_dataset = Subset(
    torchV_datasets.ImageFolder(train_dataset.dataset.root, transform=train_transforms),
    train_dataset.indices,
)
test_dataset = Subset(
    torchV_datasets.ImageFolder(test_dataset.dataset.root, transform=test_transforms),
    test_dataset.indices,
)

In [None]:
# Batches of 32; only the training loader reshuffles its samples.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True) 
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# Pull one batch from each loader as a sanity check on the tensor shapes.
x,y = next(iter(train_dataloader))
print(f"Train batch data shapes:\n{x.shape,y.shape}")
x,y = next(iter(test_dataloader))
print(f"Test batch data shapes:\n{x.shape,y.shape}")

### Training:
I will use a pretrained resnet50.  
Future Work: 
- [ ] I think the dataset could use some manual / automated cleaning and the results will probably be terrible, but this is meant to be a learning experience, so the results can be improved later.  
- [ ] Manual cleaning would be a pain and automated cleaning would require a data pipeline since I mean to do the scraping periodically.

In [None]:
# list of models
# https://pytorch.org/vision/stable/models.html
# ResNet-50 initialised with the IMAGENET1K_V2 pretrained weights.
model = resnet50(weights="IMAGENET1K_V2")
# we need to preprocess our batch in (B, C, H, W) form according to the way ImageNet was pre-processed
# we can use the preprocess defined below on a batch now as a function: preprocess(batch)
# NOTE(review): `preprocess` is defined here but not referenced in the other visible cells.
preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms()
# we need to fix the final out_features / number of classes in the fully connected layer at the end of the resnet
print(f"ImageNet Defaults: {model.fc}")
fc_inputs = model.fc.in_features
# Replace the ImageNet classifier head with a small MLP that outputs
# log-probabilities for our classes.
model.fc = nn.Sequential(
    nn.Linear(fc_inputs, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, 10), #10 is num_classes
    nn.LogSoftmax(dim=1) # For using NLLLoss()
)
# Move all parameters to the selected device (GPU when available).
model = model.to(device)
print(f"After update: {model.fc}")

In [None]:
def eval_run(model=model, valid_loader=test_dataloader, device=device):
    """
    Evaluate `model` on every batch of `valid_loader` without gradients.

    params:
        model: network returning per-class scores of shape (batch, classes)
        valid_loader: iterable of (inputs, labels) batches
        device: device the batches are moved to before the forward pass
    returns:
        (valid_loss, valid_accuracy): the loss averaged over all samples and
        the percentage of correctly classified samples, both computed over
        the whole loader. (0.0, 0.0) is returned for an empty loader.
    notes:
        - uses the module-level `criterion`
        - the model is left in eval mode
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in tqdm(valid_loader, desc="Valid"):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs) #log probs
            predicted_class = outputs.argmax(dim=1)
            loss = criterion(outputs, labels)
            batch_size = labels.size(0)
            # criterion uses reduction="mean", so loss is a per-sample mean for
            # this batch; weighting by the batch size keeps a smaller last
            # batch from being over-represented in the overall average.
            loss_sum += loss.item() * batch_size
            n_correct += (predicted_class == labels).sum().item()
            n_seen += batch_size
    if n_seen == 0:
        return 0.0, 0.0
    valid_loss = loss_sum / n_seen
    valid_accuracy = 100 * n_correct / n_seen
    return valid_loss, valid_accuracy

In [None]:
# NLLLoss expects log-probabilities, matching the LogSoftmax at the end of model.fc;
# reduction="mean" averages the loss over the samples of each batch.
criterion = nn.NLLLoss(reduction="mean")
# Adam with its default hyperparameters over all model parameters.
optimizer = optim.Adam(model.parameters())
# Number of passes over the training data.
epochs = 25

In [None]:
def train(epochs=epochs, model=model, train_dataloader=train_dataloader, device=device):
    """
    Train `model` for `epochs` epochs and evaluate it after every epoch.

    params:
        epochs: number of passes over `train_dataloader`
        model: network to optimise (updated in place)
        train_dataloader: iterable of (inputs, labels) training batches
        device: device the batches are moved to
    returns:
        (train_losses, train_accuracies, valid_losses, valid_accuracies),
        four lists with one entry per epoch. Training loss is the average
        per-sample loss of the epoch and training accuracy the percentage of
        correct predictions over all training samples seen in that epoch.
    notes:
        - uses the module-level `criterion` and `optimizer`
        - evaluation is delegated to eval_run with this function's `model`
          and `device`
    """
    train_losses = []
    train_accuracies = []
    valid_losses = []
    valid_accuracies = []
    for epoch in tqdm(range(epochs), desc="Epochs"):
        # eval_run switches to eval mode, so training mode is restored each epoch
        model.train()
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0
        for inputs, labels in tqdm(train_dataloader, desc="Train"):
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs) #log probs
            predicted_class = outputs.argmax(dim=1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            batch_size = labels.size(0)
            # criterion uses reduction="mean": multiply by the batch size to get
            # the batch total, then divide by the sample count after the epoch
            loss_sum += loss.item() * batch_size
            n_correct += (predicted_class == labels).sum().item()
            n_seen += batch_size
        # metrics cover the whole epoch, not just the last batch
        train_loss = loss_sum / n_seen if n_seen else 0.0
        train_accuracy = 100 * n_correct / n_seen if n_seen else 0.0
        print(f"Epoch {epoch}, Training Loss: {train_loss}, Training Accuracy: {train_accuracy}")
        # evaluate the model being trained, not eval_run's default one
        valid_loss, valid_accuracy = eval_run(model=model, device=device)
        print(f"Epoch {epoch}, Test Loss: {valid_loss}, Test Accuracy: {valid_accuracy}")
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        valid_losses.append(valid_loss)
        valid_accuracies.append(valid_accuracy)
    return train_losses, train_accuracies, valid_losses, valid_accuracies

In [None]:
# Run training with train()'s default arguments and keep the per-epoch metric histories.
train_losses, train_accuracies, valid_losses, valid_accuracies = train()

In [None]:
# Persist the trained weights (state_dict only) in the working directory.
save_dir = "/kaggle/working/models/"
# exist_ok avoids the separate existence check and makes re-running the cell safe
os.makedirs(save_dir, exist_ok=True)
# the epoch count in the file name follows `epochs` (25 by default) so it
# cannot go stale when the setting is changed
torch.save(model.state_dict(), os.path.join(save_dir, f"trained_resnet50_{epochs}epochs.pt"))

### Plotting curves:

In [None]:
# Side-by-side loss and accuracy curves for the train and test splits.
fig, (ax1, ax2) = plt.subplots(1, 2)
fig.set_size_inches(15, 7.5)
fig.suptitle("Train-Test Losses and Accuracies")
# one x position per recorded epoch; derived from the history itself so the
# axis always matches the data even if `epochs` was changed after training
x = list(range(len(train_losses)))
ax1.set_title("Loss vs. Epochs")
ax1.plot(x, train_losses, label="Train Loss")
ax1.plot(x, valid_losses, label="Valid Loss")
ax1.legend()
ax1.set_xticks(x)
ax2.set_title("Accuracy vs. Epochs")
ax2.plot(x, train_accuracies, label="Train Accuracy")
ax2.plot(x, valid_accuracies, label="Valid Accuracy")
ax2.set_xticks(x)
ax2.legend()

There seems to be some overfitting for sure based on the loss curves.  
But despite the dataset problems and the lack of a learning rate scheduler and hyperparameter search, we still get around 70% accuracy on the validation set, which is cool.

### TO DO LATER:
- Make everything object oriented
- Train a large resnet from scratch
- Try different models
- Plot curves during training so we can monitor progress easily and avoid useless runs
- Other applications of this dataset.

### References:
- https://www.kaggle.com/code/leifuer/intro-to-pytorch-loading-image-data
- https://discuss.pytorch.org/t/how-to-split-dataset-into-test-and-validation-sets/33987/5
- https://www.kaggle.com/code/shtrausslearning/pytorch-multiclass-image-classification#2-|-GET-THE-TRAINING-DATA
- https://learnopencv.com/image-classification-using-transfer-learning-in-pytorch/