# Environment Configuration

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from huggingface_hub import notebook_login
from datasets import load_dataset, load_metric
from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer

from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

In [None]:

# Authenticate this notebook session with the Hugging Face Hub (interactive).
# NOTE(review): push_to_hub=False in the training arguments, so login presumably
# only matters for gated/private hub assets — confirm it is still needed.
notebook_login()

# Data Collection

In [None]:
# Load a local "imagefolder" dataset. Splits and class names are inferred from
# the directory layout under dataset_path (presumably one subfolder per split,
# each holding one subfolder per class — verify against ../dataset).
dataset_path = "../dataset"
datasets = load_dataset(path="imagefolder", data_dir=dataset_path)

datasets

In [None]:
# Peek at the first raw training record, then render its image in the notebook.
first_train_example = datasets["train"][0]
print(first_train_example)

first_train_example["image"]

# Preprocessing

Preprocessing images typically comes down to (1) resizing them to a particular size and (2) normalizing the color channels (R, G, B) using a per-channel mean and standard deviation. Together, these steps are referred to as image transformations.

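In addition, one typically applies data augmentation during training (such as random cropping and flipping) to make the model more robust and achieve higher accuracy. The augmentations are applied on the fly each time an example is loaded. As a result, the model sees a slightly different version of each image in every epoch, which increases the variety of the training data without storing any extra copies.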

In [None]:
# Map each class name stored on the training split's "label" feature to its
# integer id, and back; the model config below needs both directions.
# Built without a loop variable named `id`, which previously shadowed the
# builtin `id()` at module scope for the rest of the notebook.
id2label = dict(enumerate(datasets["train"].features["label"].names))
label2id = {label: idx for idx, label in id2label.items()}

id2label

In [None]:
# Checkpoint to fine-tune, plus its matching image processor. The processor
# carries the input size and normalization statistics read in the next cell.
model_id = "microsoft/swin-tiny-patch4-window7-224"
image_processor = AutoImageProcessor.from_pretrained(
    pretrained_model_name_or_path=model_id
)
image_processor

In [None]:
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# image_processor.size names the target resolution either as an explicit
# height/width pair or as a "shortest_edge" length. Derive the Resize target
# (`size`) and the crop size from whichever form is present.
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = image_processor.size.get("longest_edge")
else:
    # Fail loudly. Otherwise `size`/`crop_size` would be undefined (NameError
    # below) or, in a notebook, silently reused from an earlier run.
    raise ValueError(
        f"Unsupported image_processor.size format {image_processor.size!r}: "
        "expected 'height'/'width' or 'shortest_edge' keys."
    )
# NOTE(review): max_size is computed but never passed to Resize below, so a
# "longest_edge" limit from the processor config is currently ignored.

# Training pipeline: random resized crop + horizontal flip for augmentation,
# then tensor conversion and normalization.
train_transforms = Compose(
    [
        RandomResizedCrop(crop_size),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)

# Evaluation pipeline: deterministic resize + center crop, no augmentation.
eval_transforms = Compose(
    [
        Resize(size),
        CenterCrop(crop_size),
        ToTensor(),
        normalize,
    ]
)

def preprocess_train(example_batch):
    """Apply training transformations across a batch.

    Used with ``Dataset.set_transform``. Each image in
    ``example_batch["image"]`` is converted to RGB (so grayscale or RGBA
    inputs end up with three channels) and passed through ``train_transforms``.

    Args:
        example_batch: Batch dict with an ``"image"`` list. Each item must
            support ``.convert("RGB")`` (presumably PIL images).

    Returns:
        The same dict, mutated in place to add a ``"pixel_values"`` list of
        transformed images.
    """
    example_batch["pixel_values"] = [
        train_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch

def preprocess_eval(example_batch):
    """Run the deterministic evaluation pipeline over every image in a batch.

    Each entry of ``example_batch["image"]`` is converted to RGB, then passed
    through ``eval_transforms``. The results are stored under
    ``"pixel_values"`` on the same dict, which is returned.
    """
    transformed = []
    for img in example_batch["image"]:
        transformed.append(eval_transforms(img.convert("RGB")))
    example_batch["pixel_values"] = transformed
    return example_batch

In [None]:
# Attach the preprocessing lazily: set_transform applies the function when
# batches are accessed instead of materializing transformed copies.
# Training gets the augmenting pipeline; validation and test get the
# deterministic one.
for split_name, split_transform in (
    ("train", preprocess_train),
    ("validation", preprocess_eval),
    ("test", preprocess_eval),
):
    datasets[split_name].set_transform(split_transform)

datasets["test"]

# Training

In [None]:
# Load the pretrained backbone with a classification head labelled for our
# classes. ignore_mismatched_sizes=True lets from_pretrained re-initialize
# weights whose shapes differ from the checkpoint instead of raising. This is
# needed when fine-tuning an already fine-tuned checkpoint with a different
# number of classes.
# NOTE(review): trust_remote_code=True permits executing Python code shipped in
# the hub repo. Swin is a built-in transformers architecture, so this is
# presumably unnecessary — consider dropping it.
model = AutoModelForImageClassification.from_pretrained(
    model_id,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
    trust_remote_code=True,
)

In [None]:
# NOTE(review): datasets.load_metric is deprecated (removed in datasets 3.0) in
# favour of the separate `evaluate` library, and trust_remote_code=True lets
# the metric's hub script run locally — confirm against the installed version.
accuracy_metric = load_metric("accuracy", trust_remote_code=True)


def compute_metrics(eval_pred):
    """Score the predictions collected during an evaluation pass.

    Args:
        eval_pred: Prediction output handed over by ``Trainer``. Its
            ``predictions`` hold per-class scores with classes on axis 1, and
            ``label_ids`` hold the reference labels.

    Returns:
        The dict produced by ``accuracy_metric.compute`` (keyed ``"accuracy"``).
    """
    class_scores = eval_pred.predictions
    predicted_ids = np.argmax(class_scores, axis=1)
    return accuracy_metric.compute(
        predictions=predicted_ids,
        references=eval_pred.label_ids,
    )

def data_collator(examples):
    """Collate preprocessed examples into a single model-ready batch.

    Each example's ``"pixel_values"`` tensor is stacked along a new leading
    batch dimension. The ``"label"`` values are gathered into one tensor
    stored under ``"labels"``, the key Hugging Face models read for the loss.
    """
    image_tensors = [item["pixel_values"] for item in examples]
    pixel_batch = torch.stack(image_tensors)
    class_ids = [item["label"] for item in examples]
    label_batch = torch.tensor(class_ids)
    return {"pixel_values": pixel_batch, "labels": label_batch}

# Short run name taken from the last component of the hub id.
model_name = model_id.split("/")[-1]

# Hyper-parameters. Gradients are accumulated over several steps, so each
# optimizer update sees batch_size * gradient_accumulation_steps examples per
# device (32 * 4 = 128 here).
batch_size = 32
gradient_accumulation_steps = 4
num_train_epochs = 5
learning_rate = 5e-5
warmup_ratio = 0.1
# NOTE(review): logging_steps is defined but not passed to TrainingArguments
# (the argument was commented out), so the library's default logging interval
# applies.
logging_steps = 10
metric_for_best_model = "accuracy"

training_arguments = TrainingArguments(
    output_dir=f"../checkpoints/{model_name}",
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    warmup_ratio=warmup_ratio,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    # Evaluate and checkpoint once per epoch; load_best_model_at_end requires
    # the two strategies to match. NOTE(review): newer transformers releases
    # rename evaluation_strategy to eval_strategy — confirm installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model=metric_for_best_model,
    push_to_hub=False,
    # Keep the raw "image" column: the set_transform preprocessors read it at
    # access time, and Trainer would otherwise drop columns the model's
    # forward() does not accept.
    remove_unused_columns=False,
)

# Wire model, data, metrics and collator together. The image processor goes in
# as `tokenizer` (NOTE(review): newer releases call this argument
# processing_class).
trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=datasets["train"],
    eval_dataset=datasets["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=image_processor,
)

In [None]:
# Fine-tune without resuming from a saved checkpoint (None is the default).
# The returned training summary is kept but not used by the cells shown here.
train_results = trainer.train(resume_from_checkpoint=None)

# Evaluation

In [None]:
# Score the model on the validation and test splits. Both calls use
# Trainer.evaluate's default "eval" key prefix, which is why the test metrics
# are also read from "eval_accuracy".
# NOTE(review): the *_evalation_* / evaluatation_* spellings are kept because
# later cells may reference these names.
metric_names = ["eval_accuracy"]
validation_evalation_results = trainer.evaluate(datasets["validation"])
test_evalation_results = trainer.evaluate(datasets["test"])

# Express each selected metric as a percentage rounded to two decimals.
evaluatation_results = {
    split_name: {name: round(split_metrics[name] * 100, 2) for name in metric_names}
    for split_name, split_metrics in (
        ("validation", validation_evalation_results),
        ("test", test_evalation_results),
    )
}

evaluatation_results

# Saving

In [None]:
# Export the trainer's current model (the best checkpoint when
# load_best_model_at_end is enabled) to ../models/<model_name>. Since the
# image processor was handed to the Trainer, its config is presumably saved
# alongside — verify the output folder.
model_path = f"../models/{model_name}"
trainer.save_model(output_dir=model_path)