# Fine-tuning Vision Transformer (ViT) on a Pokémon Dataset

This notebook demonstrates how to fine-tune the `ViT-base-patch16-224` model on a custom Pokémon dataset in order to classify each Pokémon by its primary type (the dataset's `Type 1` column).

This notebook uses the Hugging Face `datasets` and `transformers` libraries together with `PyTorch`.

**Dataset:** [JJMack/pokemon-classification-gen1-9](https://huggingface.co/datasets/JJMack/pokemon-classification-gen1-9)

**Model:** [ViT-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224)


The code in this notebook is inspired by the article: <br>
"[Fine-tuning a Vision Transformer (ViT) Model With a Custom Dataset](https://medium.com/@imabhi1216/fine-tuning-a-vision-transformer-vit-model-with-a-custom-dataset-37840e4e9268)".
All image rights are reserved by Nintendo.

In [None]:
# Report the CUDA devices visible to this runtime
import torch
import warnings
warnings.filterwarnings('ignore')  # silence Python warnings from here on

if not torch.cuda.is_available():
    print("CUDA is not available.")
else:
    print("CUDA is available.")
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs: {num_gpus}")

    # One short report per device: name, allocated memory, reserved memory
    for i in range(num_gpus):
        gpu_report = [
            f"\nGPU {i}: {torch.cuda.get_device_name(i)}",
            f"Memory Allocated: {torch.cuda.memory_allocated(i)} bytes",
            f"Memory Cached: {torch.cuda.memory_reserved(i)} bytes",
        ]
        print("\n".join(gpu_report))

## Load Pokemon dataset

In [None]:
pip install datasets==2.20.0

In [None]:
from datasets import load_dataset
# Fetch the Pokémon dataset from the Hugging Face Hub
pokemon_dataset = load_dataset(path="JJMack/pokemon-classification-gen1-9")

In [None]:
# Display the loaded dataset object to inspect what it contains
pokemon_dataset

In [None]:
# Bind each split of the loaded dataset to its own name
pokemon_train_dataset, pokemon_validation_dataset, pokemon_test_dataset = (
    pokemon_dataset[split_name] for split_name in ('train', 'validation', 'test')
)

In [None]:
from collections import Counter

# Tally how often each 'Type 1' value occurs in the training split
type_one_counter = Counter(pokemon_train_dataset['Type 1'])
# A Counter's length is its number of distinct keys
print(f'Type 1: {len(type_one_counter)}')
print(type_one_counter)

In [None]:
# See Some examples
import matplotlib.pyplot as plt
import numpy as np

def show_samples(ds,rows,cols):
    """Plot a grid of randomly chosen samples from ``ds``.

    Args:
        ds: Dataset supporting ``shuffle()``, ``select()`` and ``len()``, whose
            rows provide the keys 'image_data', 'name', 'generation',
            'Type 1', 'Type 2' and 'shiny'.
        rows: Number of grid rows.
        cols: Number of grid columns.

    At most ``rows * cols`` images are drawn. If ``ds`` holds fewer rows than
    that, only the available ones are shown and the remaining cells stay empty.
    """
    # Clamp to the dataset size: selecting more indices than rows would raise
    n_samples = min(rows * cols, len(ds))
    samples = ds.shuffle().select(range(n_samples))  # selecting random images
    fig = plt.figure(figsize=(cols*4,rows*4))
    # plotting
    for i in range(n_samples):
        sample = samples[i]  # look the row up once and reuse it for every field
        ax = fig.add_subplot(rows, cols, i+1)
        ax.imshow(sample['image_data'])
        ax.set_title(
            f"{sample['name']} ({sample['generation']}): "
            f"{sample['Type 1']}, {sample['Type 2']}, {sample['shiny']}"
        )
        ax.axis('off')


## Inspecting Dataset

In [None]:
from datasets import concatenate_datasets

# Every distinct 'Type 1' value counted in the training split
unique_types = [*type_one_counter]

# Show one row of five validation samples per type
for pokemon_type in unique_types:
    type_subset = pokemon_validation_dataset.filter(
        lambda example: example['Type 1'] == pokemon_type
    )
    show_samples(type_subset, rows=1, cols=5)


## Classify Pokemon by types

### Change Labels from Name to Type

In [None]:
from collections import Counter

# Alphabetically sorted list of the distinct 'Type 1' values in the training split
type_one_labels = sorted(Counter(pokemon_train_dataset['Type 1']))

print(str(len(type_one_labels)) + ": " + str(type_one_labels))

# label2id maps a type name to its *integer* class id, so that the 'label'
# column built from it holds real integers rather than digit strings.
label2id = {label: i for i, label in enumerate(type_one_labels)}

# id2label keeps string keys on purpose: other cells of this notebook look
# predictions up with id2label[str(index)].
id2label = {str(i): label for i, label in enumerate(type_one_labels)}

print('label2id')
print(label2id)

print('id2label')
print(id2label)

In [None]:
# Batched mapper: replace the 'label' column with the ID of each row's 'Type 1'
def map_type1_to_id(examples):
    """Set ``examples["label"]`` from the batch's 'Type 1' column.

    Each 'Type 1' value is looked up in the module-level ``label2id`` mapping.
    When the batch has no 'Type 1' key, or its value is None, every row
    (counted via 'image_data') receives the placeholder label -1 instead.

    Returns the same ``examples`` object, modified in place.
    """
    has_type_column = 'Type 1' in examples and examples['Type 1'] is not None

    if not has_type_column:
        # No usable type information: one -1 placeholder per row of the batch
        examples["label"] = [-1 for _ in range(len(examples['image_data']))]
        return examples

    examples["label"] = [label2id[type_name] for type_name in examples["Type 1"]]
    return examples

# Run the batched label mapper over the train, validation and test splits (in that order)
pokemon_train_dataset_mapped, pokemon_validation_dataset_mapped, pokemon_test_dataset_mapped = (
    split.map(map_type1_to_id, batched=True)
    for split in (pokemon_train_dataset, pokemon_validation_dataset, pokemon_test_dataset)
)

# Sanity check on the first training row: type string, its ID, old label, new label
first_type = pokemon_train_dataset[0]['Type 1']
print("\nExample original labels:", first_type)
print(label2id[first_type])

print("\nExample original labels:", pokemon_train_dataset[0]['label'])
print("Example mapped labels:", pokemon_train_dataset_mapped[0]['label'])

# From here on the split names refer to the re-labelled datasets
pokemon_train_dataset = pokemon_train_dataset_mapped
pokemon_validation_dataset = pokemon_validation_dataset_mapped
pokemon_test_dataset = pokemon_test_dataset_mapped

### Init Transformer

In [None]:
from transformers import ViTImageProcessor

# Checkpoint id shared by the image processor and the model
model_name = "google/vit-base-patch16-224"

# Load the processor configuration that belongs to the checkpoint and display it
processor = ViTImageProcessor.from_pretrained(pretrained_model_name_or_path=model_name)
processor

In [None]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    ToTensor,
    Resize,
)

# Read normalisation statistics and target image size from the ViT processor
image_mean = processor.image_mean
image_std = processor.image_std
size = processor.size["height"]

# Per-channel normalisation with the mean/std taken from the processor
normalize = Normalize(mean=image_mean, std=image_std)

# Training pipeline: random crop + random horizontal flip as augmentation
train_transforms = Compose([RandomResizedCrop(size), RandomHorizontalFlip(), ToTensor(), normalize])

# Validation pipeline: resize followed by a centre crop, no randomness
val_transforms = Compose([Resize(size), CenterCrop(size), ToTensor(), normalize])

# Test pipeline: same steps as validation
test_transforms = Compose([Resize(size), CenterCrop(size), ToTensor(), normalize])

In [None]:
def apply_train_transforms(examples):
    """Add 'pixel_values' to a batch by running ``train_transforms`` on each image.

    Every entry of ``examples["image_data"]`` is converted to RGB first.
    Returns the same ``examples`` object with the new key set.
    """
    rgb_images = (picture.convert("RGB") for picture in examples["image_data"])
    examples["pixel_values"] = list(map(train_transforms, rgb_images))
    return examples


def apply_val_transforms(examples):
    """Add 'pixel_values' to a batch by running ``val_transforms`` on each image.

    Every entry of ``examples["image_data"]`` is converted to RGB first.
    Returns the same ``examples`` object with the new key set.
    """
    rgb_images = (picture.convert("RGB") for picture in examples["image_data"])
    examples["pixel_values"] = list(map(val_transforms, rgb_images))
    return examples


def apply_test_transforms(examples):
    """Add 'pixel_values' to a batch by running ``test_transforms`` on each image.

    Every entry of ``examples["image_data"]`` is converted to RGB first.
    Returns the same ``examples`` object with the new key set.
    """
    # Use the dedicated test pipeline (previously val_transforms was used here
    # by mistake), so later changes to test_transforms actually take effect.
    examples["pixel_values"] = [test_transforms(image.convert("RGB")) for image in examples["image_data"]]
    return examples

# Attach the matching transform function to each split
for split, transform_fn in (
    (pokemon_train_dataset, apply_train_transforms),
    (pokemon_validation_dataset, apply_val_transforms),
    (pokemon_test_dataset, apply_test_transforms),
):
    split.set_transform(transform_fn)

In [None]:
import torch
from torch.utils.data import DataLoader


def collate_fn(examples):
    """Merge a list of samples into one batch dictionary.

    Args:
        examples: Sequence of mappings, each with a 'pixel_values' tensor and
            a 'label' value.

    Returns:
        Dict with 'pixel_values' (the tensors stacked along a new first
        dimension) and 'labels' (a tensor built from the collected labels).
    """
    images, targets = [], []
    for item in examples:
        images.append(item["pixel_values"])
        targets.append(item["label"])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}

# Training DataLoader: batches of 4 samples, assembled by the custom collate_fn
train_dl = DataLoader(
    pokemon_train_dataset,
    batch_size=4,
    collate_fn=collate_fn,
)

In [None]:
# Pull the first batch from the loader and print the shape of every tensor in it
batch = next(iter(train_dl))
for key, value in batch.items():
    if not isinstance(value, torch.Tensor):
        continue
    print(key, value.shape)

# Output
# pixel_values torch.Size([4, 3, 224, 224])
# labels torch.Size([4])

### Feed Transformer

In [None]:
from transformers import ViTForImageClassification

# Load the pretrained ViT and size its classification head for the Pokémon types.
# ignore_mismatched_sizes is passed because the number of labels is set explicitly here.
model = ViTForImageClassification.from_pretrained(
    model_name,
    ignore_mismatched_sizes=True,
    num_labels=len(type_one_labels),
    label2id=label2id,
    id2label=id2label,
)

In [None]:
from transformers import TrainingArguments, Trainer
import numpy as np
import torch

# Training configuration: evaluate, log and checkpoint every 10 steps, keep at
# most two checkpoints, and restore the best one once training has finished.
train_args = TrainingArguments(
    output_dir="output-models",
    per_device_train_batch_size=16,
    eval_strategy="steps",
    num_train_epochs=2,
    # Only request fp16 mixed precision when a CUDA device is present, so the
    # notebook also runs on CPU-only runtimes instead of failing here.
    fp16=torch.cuda.is_available(),
    save_steps=10,
    eval_steps=10,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # The dataset transforms read 'image_data', so unused columns must be kept
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)

In [None]:
# Wire model, arguments, data splits and collator together, then start fine-tuning
trainer = Trainer(
    model=model,
    args=train_args,
    data_collator=collate_fn,
    train_dataset=pokemon_train_dataset,
    eval_dataset=pokemon_validation_dataset,
    tokenizer=processor,
)
trainer.train()

In [None]:
# Run the trained model over the test split and print the metrics it reports
outputs = trainer.predict(test_dataset=pokemon_test_dataset)
print(outputs.metrics)

# Output
# {'test_loss': 0.25027137994766235,
# 'test_runtime': 1.3596,
# 'test_samples_per_second': 58.842,
# 'test_steps_per_second': 7.355}

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Ground-truth ids and the arg-max class of each prediction row
y_true = outputs.label_ids
y_pred = outputs.predictions.argmax(1)

# Fix the class order to 0..N-1. Without `labels`, classes that appear in
# neither y_true nor y_pred are dropped from the matrix, which then no longer
# lines up with the N entries of display_labels.
class_ids = list(range(len(type_one_labels)))
cm = confusion_matrix(y_true, y_pred, labels=class_ids)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=type_one_labels)
disp.plot(xticks_rotation=45)

## Load Model

In [None]:
from transformers import ViTForImageClassification, ViTImageProcessor

# NOTE(review): absolute, environment-specific path; confirm it exists in your runtime
checkpoint_path = "/content/MyDrive/MyDrive/checkpoint-1650"
model_name = "google/vit-base-patch16-224"

# The processor comes from the base checkpoint id, the weights from the saved checkpoint
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(checkpoint_path)

## Testing Model

In [None]:
def predict_random_image(dataset, model, transforms, id2label):
    """Show one random sample from ``dataset`` and print the model's predicted class.

    Args:
        dataset: Indexable dataset supporting ``len()``, whose rows provide
            'image_data' (an image with a ``convert`` method), 'Type 1' and 'name'.
        model: Model that is called on the image batch and returns an object
            with a ``logits`` attribute.
        transforms: Callable turning an RGB image into a tensor.
        id2label: Mapping from the class index, as a string, to the class name.

    Returns:
        None. The image is plotted with its index, name and true type as the
        title, and the predicted class name is printed.
    """
    # Imported locally so the function does not depend on an earlier cell for `random`
    import random

    random_index = random.randint(0, len(dataset) - 1)

    # Look the row up once and reuse it for every field
    sample = dataset[random_index]
    image = sample['image_data']
    type_one = sample['Type 1']
    name = sample['name']

    plt.imshow(image)
    plt.axis('off')
    plt.title(f"{random_index} {name}: {type_one}")
    plt.show()

    processed_image = transforms(image.convert("RGB")).unsqueeze(0)  # add a batch dimension

    model.eval()

    # Model and input must live on the same device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    processed_image = processed_image.to(device)
    model.to(device)

    with torch.no_grad():
        outputs = model(processed_image)
        logits = outputs.logits

    predicted_class_idx = logits.argmax(-1).item()

    predicted_class_name = id2label[str(predicted_class_idx)]

    print(f"Die vorhergesagte Klasse ist: {predicted_class_name}")

# Example usage:
# predict_random_image(pokemon_dataset['train'], model, val_transforms, id2label)

In [None]:
# Spot-check the model on ten randomly drawn images (one loop instead of ten copied cells).
# NOTE(review): the samples come from the 'train' split, so this is a sanity
# check rather than an evaluation on unseen data.
for attempt in range(10):
    predict_random_image(pokemon_dataset['train'], model, val_transforms, id2label)

In [None]:
from PIL import Image
import torch
# Fixed sample: row 10 of the 'train' split
image = pokemon_dataset['train'][10]['image_data']

plt.imshow(image)
plt.axis('off')
plt.show()

# Use the GPU when one is present, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
model.to(device)

# Validation preprocessing plus a leading batch dimension, moved to the model's device
processed_image = val_transforms(image.convert("RGB")).unsqueeze(0)
processed_image = processed_image.to(device)

with torch.no_grad():
    outputs = model(processed_image)
logits = outputs.logits

# Highest-scoring class index, translated to its name
predicted_class_idx = logits.argmax(-1).item()
predicted_class_name = id2label[str(predicted_class_idx)]

print(f"Die vorhergesagte Klasse ist: {predicted_class_name}")

In [None]:
from PIL import Image
import torch
# Fixed sample: row 999 of the 'train' split
train_split = pokemon_dataset['train']
image = train_split[999]['image_data']

plt.imshow(image)
plt.axis('off')
plt.show()

# Preprocess like a validation image and add the batch dimension
rgb_image = image.convert("RGB")
processed_image = val_transforms(rgb_image).unsqueeze(0)

# Inference mode for the model; model and input on the same device
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processed_image = processed_image.to(device)
model.to(device)

with torch.no_grad():
    outputs = model(processed_image)
    logits = outputs.logits

# Arg-max over the class dimension gives the predicted class index
predicted_class_idx = logits.argmax(-1).item()
predicted_class_name = id2label[str(predicted_class_idx)]

print(f"Die vorhergesagte Klasse ist: {predicted_class_name}")

In [None]:
# Two more random spot-checks (one loop instead of two copied cells)
for attempt in range(2):
    predict_random_image(pokemon_dataset['train'], model, val_transforms, id2label)

In [None]:
from PIL import Image
import torch
from matplotlib import pyplot as plt

# NOTE(review): absolute, environment-specific path; adjust it to your runtime
image_path = '/content/drive/MyDrive/proxy-image.png'

# Load the file *before* plotting; otherwise the plot shows whatever `image`
# an earlier cell left behind instead of the picture being classified.
image = Image.open(image_path)

plt.imshow(image)
plt.axis('off')
plt.show()

# Validation preprocessing plus a leading batch dimension
processed_image = val_transforms(image.convert("RGB")).unsqueeze(0)

model.eval()

# Model and input must live on the same device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processed_image = processed_image.to(device)
model.to(device)
with torch.no_grad():
    outputs = model(processed_image)
    logits = outputs.logits

predicted_class_idx = logits.argmax(-1).item()

predicted_class_name = id2label[str(predicted_class_idx)]

print(f"Die vorhergesagte Klasse ist: {predicted_class_name}")