# Import Libraries

In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision
import numpy as np
import matplotlib.pyplot as plt

plt.rcParams.update({'font.size': 12})
torch.random.manual_seed(1904)

%matplotlib inline
%load_ext autoreload
%autoreload 2

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE

## Initialize SummaryWriter for Tensorboard

In [41]:
from torch.utils.tensorboard import SummaryWriter

# Shared TensorBoard writer; every cell below logs through it into 'logs'.
writer = SummaryWriter(log_dir='logs')

## Import ViTImageProcessor from HuggingFace

In [None]:
from transformers import ViTImageProcessor

# Image processor loaded from the same checkpoint name the model uses later.
processor = ViTImageProcessor.from_pretrained(
    pretrained_model_name_or_path='google/vit-base-patch16-224-in21k',
)
processor

# CIFAR100 Data Preparation

## Transform and Format the data

In [18]:
# Transform the data using the ViTImageProcessor
# and format it to be suitable for the Hugging Face framework

def transform_function(image):
    """Run `image` through the module-level `processor` and return its pixel values.

    The processor is asked for PyTorch tensors (`return_tensors="pt"`) and
    the first entry of the resulting `pixel_values` is returned
    (presumably the only one, since a single image is passed in).
    """
    encoded = processor(image, return_tensors="pt")
    return encoded.pixel_values[0]

class CIFAR100Dataset(torchvision.datasets.CIFAR100):
    """CIFAR-100 dataset whose samples are dicts rather than tuples.

    Each sample is ``{'pixel_values': ..., 'labels': ...}``: the image from
    the parent dataset passed through `transform_function`, plus the
    unchanged label.
    """

    def __getitem__(self, index):
        """Return the sample at `index` as a `pixel_values` / `labels` dict."""
        img, target = super().__getitem__(index)
        return {'pixel_values': transform_function(img), 'labels': target}

## Get and split the data

In [None]:
from  torch.utils.data import random_split

# Batch size of the `dl_train` loader defined below.
batch_size = 4

# Training portion of CIFAR-100, downloaded under ./data when missing.
train_set = CIFAR100Dataset(root='./data', train=True, download=True)

# 80/20 train/validation split, driven by a generator seeded with 42.
ds_train, ds_valid = random_split(
    train_set,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(42),
)

# Shuffled loader over the training subset, loading in the main process.
dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=0)

# Test portion of CIFAR-100 from the same root directory.
ds_test = CIFAR100Dataset(root='./data', train=False, download=True)

## Get unique labels

In [None]:
import pickle

# Read the fine-grained class names from the dataset's pickled meta file.
# NOTE(review): pickle.load executes arbitrary code from the file; this is
# only acceptable because the file comes from the downloaded dataset archive.
labels_path = "data/cifar-100-python/meta"
with open(labels_path, 'rb') as f:
    classes = pickle.load(f)['fine_label_names']

# Number of distinct classes, shown as the cell output.
n_classes = len(classes)

n_classes

In [21]:
# Two-way lookup between integer class ids and class names.
id2label = dict(enumerate(classes))
label2id = {name: idx for idx, name in id2label.items()}

## Plot random train examples

In [None]:
def imshow(images):
    """Log an image tensor to TensorBoard and display it with matplotlib.

    Args:
        images: Image tensor laid out channels-first, as implied by the
            `permute(1, 2, 0)` applied before plotting (for example the
            output of `torchvision.utils.make_grid`).
    """
    # Add the images to the TensorBoard log. Flush instead of closing: the
    # module-level `writer` is shared with later cells, and closing it here
    # would make the next logging call reopen it with a new event file.
    writer.add_image('batch_images', images)
    writer.flush()

    # matplotlib wants a channels-last NumPy array on the host; detach/cpu
    # make this work for CUDA or grad-tracking tensors as well.
    image_np = images.detach().cpu().permute(1, 2, 0).numpy()
    # Values outside [0, 1] are clipped before display.
    image_np = np.clip(image_np, 0, 1)

    plt.figure(figsize=(6, 6))
    plt.imshow(image_np)
    plt.axis('off')
    plt.show()

# Take one batch from the training loader, show it as a single image grid,
# and print the class name of every image in it.
example_batch = next(iter(dl_train))
imshow(torchvision.utils.make_grid(example_batch['pixel_values']))
print([id2label[label_id] for label_id in example_batch['labels'].tolist()])

# Get the ViTForImageClassification model from HuggingFace

In [None]:
from transformers import ViTForImageClassification

# Load the pretrained ViT checkpoint with a classification head sized for
# the dataset's classes and wired to the id <-> label mappings.
model_name = 'google/vit-base-patch16-224-in21k'
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=n_classes,
    id2label=id2label,
    label2id=label2id,
)

model.to(DEVICE)

# Add model architecture to the tensorboard, traced with the example batch.
# Flush instead of closing: the shared writer is still used by later cells,
# and closing it would make the next logging call reopen a new event file.
writer.add_graph(model, example_batch['pixel_values'].to(DEVICE), use_strict_trace=False)
writer.flush()

model

# Training

## Define custom metrics for model evaluation

In [10]:
from sklearn.metrics import accuracy_score

def compute_metrics(eval_pred):
    """Compute accuracy from a `(logits, labels)` pair.

    The predicted class of each sample is the argmax of the logits over the
    last axis.

    Returns:
        A dict with the single key ``"accuracy"``.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=-1)
    return {"accuracy": accuracy_score(references, predicted_ids)}

In [11]:
from copy import deepcopy
from transformers import TrainerCallback, Trainer, TrainingArguments

class CustomCallback(TrainerCallback):
    """
    Trainer callback that also evaluates on the training dataset.

    Whenever an epoch ends with an evaluation scheduled, the wrapped
    trainer is additionally evaluated on its own training dataset and the
    results are reported under the "train" metric prefix.
    """
    def __init__(self, trainer: Trainer) -> None:
        super().__init__()
        # Kept so the hook below can call back into the trainer.
        self._trainer = trainer

    def on_epoch_end(self, args, state, control, **kwargs):
        """Evaluate on the training data when an evaluation is due."""
        if not control.should_evaluate:
            return None
        # Snapshot the control flags before evaluating and hand the snapshot
        # back (presumably because evaluate() may modify them).
        snapshot = deepcopy(control)
        self._trainer.evaluate(
            eval_dataset=self._trainer.train_dataset,
            metric_key_prefix="train",
        )
        return snapshot

## Define training arguments

In [12]:
# The hyper-parameters below were adopted from online examples.

train_args = TrainingArguments(
    # Output locations
    output_dir="output-models",
    logging_dir="logs",
    # Evaluate and checkpoint once per epoch; keep at most two checkpoints
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    # Reload the checkpoint with the best accuracy once training finishes
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    # Optimisation settings
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    # Leave the dataset columns untouched
    remove_unused_columns=False,
)

## Run training to fine tune the model

In [14]:
# Set this flag to True to also see results on the training set at the
# end of each epoch (it enables the CustomCallback when the Trainer is
# built below).
# Note: doing so will result in a longer training time.

evaluate_train = False

In [None]:
# Build the Trainer from the model, the training arguments and the
# train/validation subsets; accuracy is computed by `compute_metrics`.
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=ds_train,
    eval_dataset=ds_valid,
    tokenizer=processor,
    compute_metrics=compute_metrics,
)

# Optionally also report metrics on the training set after each epoch.
if evaluate_train:
    trainer.add_callback(CustomCallback(trainer))

trainer.train()

# Model Evaluation

In [None]:
# Run the fine-tuned model on the test split. The result is reused below
# for the precision-recall curves (`predictions`, `label_ids`); its
# `metrics` are displayed as the cell output.
preds = trainer.predict(ds_test)
preds.metrics

## Add Precision-Recall curves to tensorboard

In [None]:
import torch.nn.functional as F

# Convert the raw prediction scores to probabilities (softmax along axis 1).
preds_probs = torch.Tensor(preds.predictions).softmax(dim=1)

def add_pr_curve_tensorboard(class_index, test_probs, test_label, global_step=0):
    """
    Log the precision-recall curve of a single class to TensorBoard.

    The curve treats `class_index` as the positive class and every other
    class as negative. It is tagged with the class name from `id2label`.

    Args:
        class_index: Integer class id; must be a key of `id2label` and a
            valid column index into `test_probs`.
        test_probs: Two-dimensional array/tensor of class probabilities,
            indexed as `[sample, class]`.
        test_label: True class ids, compared element-wise with
            `class_index`.
        global_step: Step value recorded with the curve.
    """
    tensorboard_truth = test_label == class_index
    tensorboard_probs = test_probs[:, class_index]

    writer.add_pr_curve(
        tag=id2label[class_index],
        labels=tensorboard_truth,
        predictions=tensorboard_probs,
        global_step=global_step,
    )
    # Flush instead of closing: this function is called once per class, and
    # closing the shared writer each time would make every following call
    # reopen it with a new event file.
    writer.flush()

# Log one precision-recall curve per class id.
for class_index in id2label:
    add_pr_curve_tensorboard(class_index, preds_probs, preds.label_ids)

## Save metrics

In [None]:
# Evaluate on the test split, then log and save the resulting metrics.
# NOTE(review): `trainer.predict(ds_test)` already ran on the same dataset
# earlier in this notebook, so this looks like a second full pass over the
# test set — confirm it is intended.
# NOTE(review): the metrics are logged and saved under the "eval" split
# name although the dataset is `ds_test` — confirm this naming is wanted.
metrics = trainer.evaluate(ds_test)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

## Save best model and training logs

In [43]:
# Persist the model and the trainer state. No path is passed, so the
# destination is whatever the Trainer defaults to (presumably the
# `output_dir` given in the training arguments — TODO confirm).
trainer.save_model()
trainer.save_state()