In [None]:
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer
import evaluate
import numpy as np
import torch
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

In [None]:
# 1. Load dataset
# Location handed to `load_dataset`; kept in one named constant so it is easy to spot and change.
DATASET_PATH = "Eramia_dataset/Eramia_classification"
dataset = load_dataset(path=DATASET_PATH)

In [None]:
# 2. Pick a ResNet model from Hugging Face
# The same checkpoint id is reused later when the classification model is loaded.
model_name = "microsoft/resnet-50"

# The processor is used below as the source of the normalisation statistics
# (`image_mean` / `image_std`) and is also handed to the Trainer.
processor = AutoImageProcessor.from_pretrained(
    pretrained_model_name_or_path=model_name,
)

In [None]:
# 3. Define preprocessing (resize + normalization according to ResNet training)
IMAGE_SIZE = (224, 224)  # fixed target size passed to Resize

# Steps run in list order: resize the image, convert it to a tensor, then
# normalise with the mean/std published by the checkpoint's image processor.
preprocessing_steps = [
    Resize(IMAGE_SIZE),
    ToTensor(),
    Normalize(mean=processor.image_mean, std=processor.image_std),
]
transform = Compose(preprocessing_steps)

def transform_examples(examples):
    """Add a ``"pixel_values"`` entry holding the preprocessed images of a batch.

    Args:
        examples: mapping with an ``"image"`` entry that is iterated over; each
            item must provide ``.convert("RGB")`` (presumably PIL images --
            confirm against the dataset schema).

    Returns:
        The same ``examples`` object, mutated in place, with
        ``examples["pixel_values"]`` set to a list containing the result of
        ``transform`` for every image, in the original order.
    """
    pixel_values = []
    for img in examples["image"]:
        # Convert to RGB before running the resize/normalise pipeline.
        rgb_image = img.convert("RGB")
        pixel_values.append(transform(rgb_image))
    examples["pixel_values"] = pixel_values
    return examples

# Attach the preprocessing function to the dataset; the name is rebound to the
# object returned by `with_transform`.
dataset = dataset.with_transform(transform=transform_examples)

In [None]:
# 4. Load model (adjust label mapping)
# Class names come from the "label" feature of the training split.
labels = dataset["train"].features["label"].names

# Seed torch *before* building the model. When the number of labels differs
# from the checkpoint's, the classification head is freshly initialised inside
# `from_pretrained`, i.e. before the Trainer is created and gets a chance to
# seed anything. Without this line the starting weights differ between runs.
# 42 mirrors the documented default of `TrainingArguments.seed`.
torch.manual_seed(42)

model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    # Allow loading although the head shape may not match the checkpoint's.
    ignore_mismatched_sizes=True,
    # Index <-> class-name mappings stored in the model config.
    id2label=dict(enumerate(labels)),
    label2id={label: i for i, label in enumerate(labels)},
)

In [None]:
# 5. Metrics (Accuracy, Precision, Recall, F1)
# Load each metric once at module level; the names are unpacked in the same
# order as the identifiers listed in the tuple.
accuracy_metric, precision_metric, recall_metric, f1_metric = (
    evaluate.load(metric_id)
    for metric_id in ("accuracy", "precision", "recall", "f1")
)

def compute_metrics(eval_pred):
    """Compute accuracy plus macro- and weighted-averaged precision, recall and F1.

    Args:
        eval_pred: a pair that unpacks into ``(logits, label_ids)``. ``logits``
            is reduced with ``argmax`` over its last axis to obtain the
            predicted class ids. If ``logits`` is a tuple (several model
            outputs), its first element is used -- this assumes the logits
            come first; confirm for models that return extra outputs.

    Returns:
        dict with the keys ``accuracy``, ``precision``, ``recall``, ``f1``
        (macro-averaged) and ``precision_weighted``, ``recall_weighted``,
        ``f1_weighted`` (weighted-averaged).
    """
    logits, labels_ids = eval_pred
    # Predictions may arrive as a tuple of arrays; argmax over a ragged tuple
    # would fail, so pick the first entry explicitly.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)

    results = {
        "accuracy": accuracy_metric.compute(
            predictions=preds, references=labels_ids
        )["accuracy"],
    }

    # Macro scores keep the bare key names; weighted scores get a suffix.
    for average, suffix in (("macro", ""), ("weighted", "_weighted")):
        # zero_division=0 yields the same value as the default ("warn") but
        # without a warning whenever a class is never predicted or absent.
        results[f"precision{suffix}"] = precision_metric.compute(
            predictions=preds,
            references=labels_ids,
            average=average,
            zero_division=0,
        )["precision"]
        results[f"recall{suffix}"] = recall_metric.compute(
            predictions=preds,
            references=labels_ids,
            average=average,
            zero_division=0,
        )["recall"]
        results[f"f1{suffix}"] = f1_metric.compute(
            predictions=preds, references=labels_ids, average=average
        )["f1"]

    return results

In [None]:
# 6. Data collator (batch preparation)
def collate_fn(batch):
    """Merge a list of per-example dicts into a single batch dict.

    Args:
        batch: iterable of mappings, each with a ``"pixel_values"`` tensor and
            a ``"label"`` value.

    Returns:
        dict with ``"pixel_values"`` (the example tensors stacked along a new
        leading dimension) and ``"labels"`` (a tensor built from the labels,
        in the same order).
    """
    images = []
    targets = []
    for example in batch:
        images.append(example["pixel_values"])
        targets.append(example["label"])
    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
    }

In [None]:
# 7. Training arguments
BATCH_SIZE = 100  # per-device batch size, shared by training and evaluation

training_args = TrainingArguments(
    # --- output locations / reporting ---
    output_dir="Eramia_classification/model_checkpoint",
    logging_dir="./runs/Eramia_classification/5",
    report_to="tensorboard",
    push_to_hub=False,
    # --- optimisation ---
    num_train_epochs=15,
    learning_rate=5e-5,
    weight_decay=0.01,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # --- evaluation and checkpointing: both happen once per epoch ---
    do_eval=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=10,
    # Reload the checkpoint with the highest "eval_f1" when training finishes.
    load_best_model_at_end=True,
    metric_for_best_model="eval_f1",
    greater_is_better=True,
    # --- logging cadence ---
    logging_first_step=True,
    logging_steps=100,
    # --- data loading ---
    # Keep every dataset column: the preprocessing transform reads "image",
    # which the model itself does not accept as an input.
    remove_unused_columns=False,
    dataloader_num_workers=8,
    dataloader_prefetch_factor=4,
    dataloader_persistent_workers=True,
)

In [None]:
# 8. Trainer
# Name the two splits up front so the constructor call reads at a glance.
train_split = dataset["train"]
validation_split = dataset["validation"]

trainer = Trainer(
    # what to train and how
    model=model,
    args=training_args,
    # data
    train_dataset=train_split,
    eval_dataset=validation_split,
    data_collator=collate_fn,
    # evaluation
    compute_metrics=compute_metrics,
    # image processor passed along with the model
    processing_class=processor,
)

In [None]:
# 9. Train
# Start training with the `trainer` object. The call is left as the cell's
# final bare expression, so the notebook displays whatever `train()` returns.
trainer.train()