<a href="https://colab.research.google.com/github/arielfikru/nekoclassification_trainer/blob/main/NekoClassification_Trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title ## Login HF
#@markdown For faster execution, set your HF_TOKEN in Colab Secrets and make sure "Notebook access" is turned on.
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login. Later cells download the dataset from the
# Hub and (optionally) push the trained model, which need these credentials.
notebook_login()

In [None]:
#@title ## Install Dependency
#@markdown You may need to restart the runtime after running this cell.

# The huggingface_hub and transformers pins are commented out, so whatever
# versions are already installed in the runtime are used for those two.
# !pip install -q huggingface_hub==0.22.0
!pip install -q datasets==2.19.0
# !pip install -q transformers==4.40.1
!pip install -q accelerate==0.29.3
# git-lfs plus a stored git credential helper; presumably so pushes to the Hub
# can authenticate through git.
!sudo apt -qq install git-lfs
!git config --global credential.helper store
# Terminates the Python process; presumably intended to force a runtime restart
# so the freshly installed package versions are the ones imported afterwards.
exit()

In [None]:
#@title ## Config

import os

# Model
#@markdown ## Model
#@markdown You can use any model supported by [AutoModelForImageClassification](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForImageClassification)
model_checkpoint = "microsoft/beit-base-patch16-224-pt22k-ft22k" #@param {type:"string"}
# "integer" (not "number") so the form can never hand TrainingArguments a float batch size.
batch_size = 16 #@param {type:"integer"}

# Dataset
#@markdown ## Dataset
#@markdown Download the dataset file from `repo_id`
main_dir = "/content/" #@param {type:"string"}
dataset_to_download = "experiment.zip" #@param {type:"string"}
# os.path.join works whether or not main_dir ends with a slash.
dataset_to_load = os.path.join(main_dir, dataset_to_download)

repo_id = "NekoFi/class_experiment" #@param {type:"string"}
repo_type = "dataset" #@param ["dataset", "model"]

# Training
#@markdown ## Training
project_name = "protrait_classification" #@param {type:"string"}
learning_rate = 5e-5 #@param {type:"number"}
epochs = 4 #@param {type:"number"}

# Named to match the Training cell, which checks `log_metrics_at_end` before
# calling trainer.log_metrics(...).
log_metrics_at_end = False #@param {type:"boolean"}
push_to_hub = True #@param {type:"boolean"}

In [None]:
#@title ## Loading Dataset

from huggingface_hub import hf_hub_download
from datasets import load_dataset
from datasets import load_metric

# hf_hub_download returns the local path it wrote the file to; load from that
# path instead of re-deriving it from main_dir, which silently breaks when
# main_dir has no trailing slash.
downloaded_dataset_path = hf_hub_download(
    repo_id=repo_id,
    filename=dataset_to_download,
    repo_type=repo_type,
    local_dir=main_dir,
)
dataset = load_dataset("imagefolder", data_files=downloaded_dataset_path)

# Class names in ClassLabel order; ids are their positions in that list.
labels = dataset["train"].features["label"].names
label2id = {label: i for i, label in enumerate(labels)}
id2label = dict(enumerate(labels))

In [None]:
#@title ## Preprocessing Dataset

from transformers import AutoImageProcessor

# Image processor of the base checkpoint; its normalization stats and `size`
# are used below to build matching torchvision transforms.
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)

from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Derive the target resolution from the processor's `size` dict, which is given
# either as an exact height/width or as a shortest edge (with an optional
# longest-edge cap).
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    # NOTE(review): computed but not currently passed to Resize below.
    max_size = image_processor.size.get("longest_edge")
else:
    # Without this, `size`/`crop_size` would be undefined (or silently stale
    # from a previous run of this cell with another checkpoint).
    raise ValueError(
        f"Unsupported image_processor.size format: {image_processor.size!r}; "
        "expected 'height'/'width' or 'shortest_edge' keys."
    )

# Training pipeline: random resized crop + horizontal flip as augmentation.
train_transforms = Compose(
    [
        RandomResizedCrop(crop_size),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)

# Evaluation pipeline: deterministic resize followed by a center crop.
val_transforms = Compose(
    [
        Resize(size),
        CenterCrop(crop_size),
        ToTensor(),
        normalize,
    ]
)

def preprocess_train(example_batch):
    """Add augmented ``pixel_values`` to a training batch.

    Each entry of ``example_batch["image"]`` is converted to RGB and run
    through ``train_transforms``. The batch dict is updated in place and
    returned.
    """
    rgb_images = (img.convert("RGB") for img in example_batch["image"])
    example_batch["pixel_values"] = [train_transforms(rgb) for rgb in rgb_images]
    return example_batch

def preprocess_val(example_batch):
    """Add deterministic ``pixel_values`` to a validation batch.

    Mirrors ``preprocess_train`` but uses ``val_transforms`` (no augmentation).
    The batch dict is updated in place and returned.
    """
    source_images = example_batch["image"]
    example_batch["pixel_values"] = list(
        map(val_transforms, (picture.convert("RGB") for picture in source_images))
    )
    return example_batch

# All images are read from the "train" split; hold out 10% of it for validation.
# A fixed seed keeps the split identical across re-runs and runtime restarts,
# so validation metrics from different runs are comparable.
splits = dataset["train"].train_test_split(test_size=0.1, seed=42)
train_ds = splits['train']
val_ds = splits['test']

# Transforms are applied on the fly whenever rows are accessed.
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)


In [None]:
#@title ## Training

from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import json
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch

# Pretrained backbone with a classification head sized for our labels.
# ignore_mismatched_sizes=True re-initializes checkpoint weights whose shapes
# don't fit (the original head, when its class count differs) instead of raising.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# Last path segment of the checkpoint id (currently unused in this notebook).
model_name = model_checkpoint.rsplit("/", 1)[-1]

import inspect

# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
# releases, and the old spelling was later removed. transformers is not pinned
# in the install cell, so pass whichever keyword this version accepts.
_eval_strategy_arg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

args = TrainingArguments(
    project_name,  # output_dir
    remove_unused_columns=False,  # keep the "image" column the set_transform preprocessors read
    save_strategy="epoch",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,  # optimizer step every 4 batches
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epochs,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,  # eval and save both run per epoch, as this requires
    metric_for_best_model="accuracy",  # key returned by compute_metrics
    push_to_hub=push_to_hub,
    **{_eval_strategy_arg: "epoch"},
)

def compute_metrics(eval_pred):
    """Compute classification metrics for the Trainer and save report artifacts.

    Args:
        eval_pred: the Trainer's evaluation output; ``predictions`` holds
            per-class scores (argmax over axis 1) and ``label_ids`` the true ids.

    Returns:
        dict with ``accuracy``, weighted ``precision``/``recall``/``f1`` and the
        ``confusion_matrix`` as a nested list (rows: true, cols: predicted),
        ordered by the ids in ``id2label``.

    Side effects: overwrites ``metrics.json`` and ``confusion_matrix.png`` in the
    current working directory and prints the scalar metrics.
    """
    predictions = np.argmax(eval_pred.predictions, axis=1)
    labels = eval_pred.label_ids

    # Pin the class order to every id in id2label so the matrix is always
    # len(id2label) x len(id2label) and lines up with the tick labels, even when
    # some class is missing from this evaluation split.
    class_ids = sorted(id2label)
    class_names = [id2label[i] for i in class_ids]

    accuracy = accuracy_score(labels, predictions)
    # zero_division=0 keeps sklearn's default score for empty classes (0) but
    # avoids an UndefinedMetricWarning on every evaluation.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='weighted', zero_division=0
    )
    conf_matrix = confusion_matrix(labels, predictions, labels=class_ids)

    metrics = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "confusion_matrix": conf_matrix.tolist()
    }

    # Save metrics to JSON (overwritten on every evaluation)
    with open('metrics.json', 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent=4)

    # Plot and save confusion matrix
    plt.figure(figsize=(10, 7))
    sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.png')

    print(f"Accuracy: {accuracy}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"F1-Score: {f1}")

    return metrics

def collate_fn(examples):
    """Batch preprocessed examples for the model.

    Stacks each example's ``pixel_values`` tensor along a new leading batch
    dimension and gathers the integer ``label`` values into a ``labels`` tensor.
    """
    images = [item["pixel_values"] for item in examples]
    targets = [item["label"] for item in examples]
    batch_pixels = torch.stack(images)
    batch_targets = torch.tensor(targets)
    return {"pixel_values": batch_pixels, "labels": batch_targets}

import inspect

# Newer transformers releases renamed Trainer's `tokenizer` argument to
# `processing_class`. transformers is not pinned in the install cell, so pass
# the image processor under whichever name this version accepts.
_processor_arg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = Trainer(
    model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
    **{_processor_arg: image_processor},
)

train_results = trainer.train()
trainer.save_model()

if log_metrics_at_end:
    trainer.log_metrics("train", train_results.metrics)

trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

if push_to_hub:
    trainer.push_to_hub()


In [None]:
# @title ## Inference Demo

from PIL import Image
import requests
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
import os
from tqdm import tqdm
from huggingface_hub import whoami
from google.colab import userdata

# When True, the model repo is "<HF username>/<project_name>"; otherwise
# `custom_repo_id` is used.
use_project_repo = True  # @param {type:"boolean"}
custom_repo_id = "Write your Repo ID here if use_project_repo False"  # @param {type:"string"}
# A single image path, an http(s) URL, or a directory (walked recursively).
# Local files are renamed in place to "<predicted label>_<original name>".
input_path = "/content/test"  # @param {type:"string"}
# NOTE(review): `verbose` is not referenced anywhere in this notebook.
verbose = False

def get_username_from_hf_token():
    """Return the Hugging Face username for the current credentials.

    Uses the ``HF_TOKEN`` Colab secret when it exists and the notebook may
    read it. Otherwise it falls back to the token saved by ``notebook_login()``,
    since the login cell allows logging in without the secret.

    Raises:
        whatever ``huggingface_hub.whoami`` raises when no usable token exists.
    """
    try:
        hf_token = userdata.get('HF_TOKEN')
    except (userdata.SecretNotFoundError, userdata.NotebookAccessError):
        hf_token = None

    # token=None makes whoami() fall back to the locally stored login token.
    user_info = whoami(token=hf_token or None)
    return user_info['name']

def load_image(image_path, timeout=30):
    """Open an image from a local path or an http(s) URL.

    Args:
        image_path: filesystem path, or a URL starting with ``http://`` /
            ``https://``.
        timeout: seconds to wait for the HTTP server (URLs only).

    Returns:
        A ``PIL.Image.Image`` (opened lazily by ``Image.open``).

    Raises:
        requests.HTTPError: if the URL responds with an error status.
    """
    # Match the full scheme so a local file like "http_cat.png" is not
    # mistaken for a URL.
    if image_path.startswith(("http://", "https://")):
        response = requests.get(image_path, stream=True, timeout=timeout)
        response.raise_for_status()
        return Image.open(response.raw)
    return Image.open(image_path)

def classify_image(image, model, processor):
    """Return the predicted class name for a single image.

    The image is converted to RGB, encoded by ``processor`` into PyTorch
    tensors, and scored by ``model`` without gradient tracking. The arg-max
    logit index is mapped to a name through ``model.config.id2label``.
    """
    inputs = processor(image.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        scores = model(**inputs).logits
    best_index = scores.argmax(-1).item()
    return model.config.id2label[best_index]

def rename_image(image_path, class_label):
    """Rename a file in place to ``<class_label>_<original name>``.

    The file stays in the same directory and keeps its extension.
    """
    folder, filename = os.path.split(image_path)
    stem, suffix = os.path.splitext(filename)
    target = os.path.join(folder, f"{class_label}_{stem}{suffix}")
    os.rename(image_path, target)

def process_image_file(image_path, model, processor):
    """Classify one image and tag it with the predicted label.

    Local files are renamed to ``<label>_<original name>``. URL inputs have
    no local file to rename, so their prediction is printed instead.
    """
    # Close the image's file handle before renaming the file it came from.
    with load_image(image_path) as image:
        class_label = classify_image(image, model, processor)

    if image_path.startswith(("http://", "https://")):
        print(f"Predicted class for {image_path}: {class_label}")
        return

    rename_image(image_path, class_label)

def count_files(image_path):
    """Return how many files will be processed for ``image_path``.

    For a directory, this counts every file in it and all subdirectories.
    Any other path (file, URL, or nonexistent path) counts as exactly one.
    """
    if not os.path.isdir(image_path):
        return 1
    return sum(len(names) for _root, _dirs, names in os.walk(image_path))

def main(image_path, inference_repo_name):
    """Classify and rename every image under ``image_path``.

    Args:
        image_path: a single file, a URL, or a directory walked recursively.
        inference_repo_name: Hub repo id (or local dir) of the fine-tuned
            model and its image processor.
    """
    processor = AutoImageProcessor.from_pretrained(inference_repo_name)
    model = AutoModelForImageClassification.from_pretrained(inference_repo_name)

    file_count = count_files(image_path)
    # The context manager closes the progress bar even if a file fails.
    with tqdm(total=file_count, desc="Processing files", unit="file") as progress_bar:
        if os.path.isdir(image_path):
            for root, _, files in os.walk(image_path):
                for file in files:
                    process_image_file(os.path.join(root, file), model, processor)
                    progress_bar.update(1)
        else:
            process_image_file(image_path, model, processor)
            progress_bar.update(1)

# Pick the model repo to run. The HF username is only needed for the
# "<username>/<project_name>" repo (presumably where the Training cell pushed),
# so a custom repo works without looking up any token.
if use_project_repo:
    username = get_username_from_hf_token()
    inference_repo_name = f"{username}/{project_name}"
else:
    inference_repo_name = custom_repo_id

image_path = input_path
main(image_path, inference_repo_name)
