# Settings

In [None]:
# Identifier of the dataset handed to load_dataset().
DATASET_NAME = "Vampyrian/all-images"
# Used both as the training output_dir and as the model id the test pipeline loads.
OUTPUT_MODEL_NAME = "Vampyrian/all-images-model"

In [None]:
# Pretrained checkpoint from which both the image processor and the model are loaded.
CHECKPOINT = 'google/vit-base-patch16-224-in21k'
# Alternative checkpoint; uncomment to use it instead of the one above.
# CHECKPOINT = 'google/vit-large-patch32-224-in21k'

# Log in to Hugging Face

In [None]:
from dotenv import load_dotenv
import os

# Populate the process environment from a .env file (if one is found), then
# read the Hugging Face access token; hf_token is None when HF_TOKEN is unset.
load_dotenv()
hf_token = os.environ.get("HF_TOKEN")

In [None]:
from huggingface_hub import login
# Authenticate with the Hugging Face Hub using the token read from the environment.
login(token=hf_token)

# Loading dataset

In [None]:
from datasets import load_dataset
# Load the dataset named by DATASET_NAME; no split is requested, so the result
# is indexed by split name below (e.g. dataset["train"]).
dataset = load_dataset(DATASET_NAME)

In [None]:
dataset

# Check that all images are valid

In [None]:
from PIL import Image, ImageFile, UnidentifiedImageError

In [None]:
# Ask Pillow to tolerate truncated image files when loading instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

def is_valid_image(example):
    """Return True when ``example["image"]`` can be read as an image.

    Args:
        example: Mapping with an ``"image"`` entry that is either a
            ``PIL.Image.Image`` or something ``Image.open`` accepts
            (for example a file path or a file-like object).

    Returns:
        bool: ``True`` if the image passes the check, ``False`` if anything
        goes wrong while checking it (missing key, unreadable file,
        corrupt data, ...).
    """
    try:
        image = example["image"]
        if isinstance(image, Image.Image):
            # Already opened: force the pixel data to be decoded so a corrupt
            # payload is caught here instead of later, during training.
            image.load()
        else:
            # Otherwise open it and let Pillow run its integrity check; the
            # context manager closes the underlying file afterwards.
            with Image.open(image) as opened:
                opened.verify()
    except Exception:
        # Deliberately broad: this is a best-effort filter, and any failure
        # simply means "drop this row".
        return False
    return True

In [None]:
len(dataset["train"])

In [None]:
# Keep only the rows whose image passes is_valid_image.
dataset["train"] = dataset["train"].filter(is_valid_image)

In [None]:
# Hold out 10% of the rows for evaluation. The fixed seed makes the split
# identical across notebook runs, so metrics from different runs are comparable.
train_test_split = dataset["train"].train_test_split(test_size=0.1, seed=42)

In [None]:
train_test_split

In [None]:
train_test_split["train"][0]

In [None]:
# Class names in label-id order, taken from the "label" feature of the train split.
labels = train_test_split["train"].features["label"].names

# Two-way lookup between class names and their ids (ids stored as strings).
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}

In [None]:
id2label[str(2)]

In [None]:
from transformers import AutoImageProcessor

# Load the image processor that belongs to CHECKPOINT, requesting the fast implementation.
image_processor = AutoImageProcessor.from_pretrained(CHECKPOINT, use_fast=True)

In [None]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
# Normalise with the mean/std stored on the checkpoint's image processor.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Crop target: the "shortest_edge" value when the processor defines one,
# otherwise an explicit (height, width) pair.
if "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
else:
    size = (image_processor.size["height"], image_processor.size["width"])

# Per-image pipeline: random resized crop -> tensor -> normalisation.
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

In [None]:
def transforms(examples):
    """Replace the ``"image"`` column of a batch with ``"pixel_values"``.

    Every image in ``examples["image"]`` is converted to RGB and passed
    through ``_transforms``; the results are stored under
    ``"pixel_values"`` and the ``"image"`` entry is removed. The batch is
    modified in place and also returned.
    """
    pixel_values = []
    for picture in examples["image"]:
        rgb_picture = picture.convert("RGB")
        pixel_values.append(_transforms(rgb_picture))
    examples["pixel_values"] = pixel_values
    del examples["image"]
    return examples

In [None]:
# Attach `transforms` so it is applied on the fly when rows are accessed.
# NOTE(review): this is set on the whole split container, so the "test" split
# presumably gets the same random cropping as "train" - confirm that is intended.
train_test_split = train_test_split.with_transform(transforms)

In [None]:
from transformers import DefaultDataCollator

# Collator later handed to the Trainer; created with default settings.
data_collator = DefaultDataCollator()

In [None]:
import evaluate
# Accuracy metric from the `evaluate` library; used by compute_metrics.
accuracy = evaluate.load("accuracy")

In [None]:
import numpy as np

def compute_metrics(eval_pred):
    """Compute accuracy for one evaluation pass.

    Args:
        eval_pred: Pair of ``(predictions, labels)``; the predictions are
            reduced to class indices with an argmax over axis 1.

    Returns:
        The result of ``accuracy.compute`` for the predicted indices
        against the reference labels.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    result = accuracy.compute(predictions=predicted_ids, references=references)
    return result

In [None]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Load the pretrained checkpoint with a classification head sized to this
# dataset's classes, and store the id <-> name mappings in the model config.
model = AutoModelForImageClassification.from_pretrained(
    CHECKPOINT,
    id2label=id2label,
    label2id=label2id,
    num_labels=len(labels),
)

In [None]:
training_args = TrainingArguments(
    # Where checkpoints are written; results are also pushed to the Hub.
    output_dir=OUTPUT_MODEL_NAME,
    push_to_hub=True,
    # Keep columns the model does not consume: the on-the-fly transform
    # still needs the raw "image" column to build "pixel_values".
    remove_unused_columns=False,
    # Optimisation schedule.
    learning_rate=5e-5,
    warmup_ratio=0.1,
    num_train_epochs=10,
    # 16 samples per device x 4 accumulation steps per optimizer update.
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    # Evaluate and checkpoint once per epoch, then restore the checkpoint
    # with the best accuracy when training ends.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_steps=10,
)

In [None]:
# Wire the model, data splits, batching and metric into a single Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    # Data: transformed train/test splits, batched by the collator.
    train_dataset=train_test_split["train"],
    eval_dataset=train_test_split["test"],
    data_collator=data_collator,
    # The image processor is passed along as the processing class.
    processing_class=image_processor,
    # Accuracy is reported at every evaluation.
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning; keep the result so the cell can still display it.
train_output = trainer.train()

# With load_best_model_at_end=True the trainer now holds the best checkpoint,
# but nothing has written it to the root of output_dir or uploaded it yet.
# Save and push it explicitly so that whatever loads OUTPUT_MODEL_NAME next
# gets the best weights rather than the last checkpoint that was saved.
trainer.push_to_hub()

train_output

# Test on my own image

In [None]:
OUTPUT_MODEL_NAME

In [None]:
from transformers import pipeline

# Build an image-classification pipeline from OUTPUT_MODEL_NAME, the same
# identifier that training used as its output_dir.
classifier = pipeline("image-classification", model=OUTPUT_MODEL_NAME)

In [None]:
from PIL import Image
import requests
from io import BytesIO

In [None]:
image_url = "https://kainoteka-public.s3.eu-central-1.amazonaws.com/products/b43526e7-0ad8-4b2b-a643-90f9216d7986/89ab8846-4eb6-434f-a18b-a3e11a69b4fb-md.webp"

# Download the sample image. Without a timeout, requests waits indefinitely
# when the server never answers, which would hang the cell.
response = requests.get(image_url, timeout=30)
response.raise_for_status()  # Ensure the request was successful

# Open the image with PIL directly from the downloaded bytes
image = Image.open(BytesIO(response.content))


In [None]:
classifier(image)