In [None]:
import inspect
import os
from typing import Any

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from torchvision.transforms import Compose, Normalize, RandomResizedCrop, ToTensor
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
    pipeline,
)

from .utils._logger import logger
from .utils._validation import config_args

In [None]:
# Custom class
class FoodDataset:
    """
    A custom dataset class for the food image classification task.

    Holds the DatasetDict loaded from ``config_args.dataset_path``, the image
    processor for ``config_args.base_model``, and label <-> id maps built from
    the "train" split's "label" feature. ``__len__`` and ``__getitem__`` work on
    a single split (``self.split``, "train" by default), because indexing the
    DatasetDict itself with an integer is not possible.
    """

    def __init__(self, split: str = "train"):
        """
        Initializes the dataset.

        Args:
            split (str): Split used by ``__len__`` and ``__getitem__``.
                Defaults to "train".

        Raises:
            Exception: If the label names or label maps cannot be built. The
                error is logged and re-raised so the instance is never left
                without ``labels``, ``label2id`` and ``id2label``.
        """
        self.split = split
        self.dataset = load_dataset(config_args.dataset_path)
        self.image_processor = AutoImageProcessor.from_pretrained(
            config_args.base_model, use_fast=True
        )

        try:
            self.labels = self.dataset["train"].features["label"].names
            # label2id maps name -> int id; id2label maps int id -> name.
            self.label2id = {label: i for i, label in enumerate(self.labels)}
            self.id2label = {i: label for i, label in enumerate(self.labels)}
        except Exception as e:
            logger.error(f"Error processing labels: {e}")
            raise

    def __len__(self):
        """Returns the number of examples in the selected split."""
        return len(self.dataset[self.split])

    def __getitem__(self, idx):
        """
        Retrieves a single item from the selected split.

        Args:
            idx (int): The index of the item to retrieve.

        Returns:
            dict | None: ``{"pixel_values": tensor, "label": label}`` with the
            batch axis removed from ``pixel_values``, or None if processing
            fails (the error is logged).
        """
        try:
            example = self.dataset[self.split][idx]
            image = example["image"]
            label = example["label"]

            # return_tensors="pt" yields a batch of one; drop the leading axis.
            pixel_values = self.image_processor(
                image, return_tensors="pt"
            ).pixel_values[0]

            return {"pixel_values": pixel_values, "label": label}
        except Exception as e:
            logger.error(f"Error processing item at index {idx}: {e}")
            return None

In [None]:
# Single shared dataset instance; later cells read its dataset, image processor and label maps.
food: FoodDataset = FoodDataset()

In [None]:
# Data loaders
from torchvision.transforms.transforms import Normalize


def augment_images(example, image_processor=None):
    """
    Augments images with defined values.

    Intended for the dataset's ``.with_transform()`` method. It replaces the
    "image" entry of ``example`` with "pixel_values": each image is converted
    to RGB, then randomly resized/cropped, turned into a tensor and normalized.

    Args:
        example: Batch dict whose "image" entry is an iterable of images that
            support ``.convert("RGB")``.
        image_processor: Processor providing ``image_mean``, ``image_std`` and
            ``size``. Defaults to ``food.image_processor``.

    Returns:
        The same dict with "image" removed and "pixel_values" set to a list of
        transformed tensors.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        processor = food.image_processor if image_processor is None else image_processor

        normalize = Normalize(mean=processor.image_mean, std=processor.image_std)

        # Processors describe their input size either as {"shortest_edge": n}
        # or as {"height": h, "width": w}.
        processor_size = processor.size
        size = (
            processor_size["shortest_edge"]
            if "shortest_edge" in processor_size
            else (processor_size["height"], processor_size["width"])
        )

        # Data augmentation
        _transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

        example["pixel_values"] = [
            _transforms(img.convert("RGB")) for img in example["image"]
        ]  # Similar to input_ids in NLP
        del example["image"]

        return example
    except Exception as e:
        logger.error(f"Transform images failed: {e}")
        raise

In [None]:
food_augment = food.dataset.with_transform(augment_images)

In [None]:
# Evaluation metrics
def compute_metrics(eval_pred) -> dict:
    """
    Computes evaluation metrics for the image classification task.

    Args:
        eval_pred: ``(predictions, references)`` pair from the Trainer.
            Predictions are per-class scores with classes on axis 1; a tuple of
            arrays is accepted and its first element is used. References are
            the integer labels.

    Returns:
        Dict: ``{"accuracy": float}``, the fraction of argmax predictions that
        equal the references.

    Raises:
        Exception: Errors are logged and re-raised. An empty dict would lack the
            "accuracy" key that the training arguments select as
            ``metric_for_best_model``.
    """
    try:
        predictions, references = eval_pred
        # Models returning several outputs come back as a tuple; assume the logits are first.
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        predicted_labels = np.argmax(predictions, axis=1)

        # Same value as the `evaluate` "accuracy" metric, without reloading the
        # metric module on every evaluation pass.
        accuracy = float(np.mean(predicted_labels == np.asarray(references)))
        return {"accuracy": accuracy}
    except Exception as e:
        logger.error(f"Error computing metrics: {e}")
        raise

In [None]:
# Model
def load_model(model_name: str) -> AutoModelForImageClassification:
    """
    Builds an image-classification model from a pretrained checkpoint.

    The classification head is sized and labelled from ``food`` (label count,
    ``id2label``, ``label2id``). The model is then moved to CUDA when it is
    available, otherwise to the CPU.

    Args:
        model_name (str): Name or path of the pretrained checkpoint.

    Returns:
        AutoModelForImageClassification: The model on the selected device.

    Raises:
        Exception: Any loading error is logged and re-raised.
    """
    try:
        label_config = dict(
            num_labels=len(food.labels),
            id2label=food.id2label,
            label2id=food.label2id,
        )
        pretrained = AutoModelForImageClassification.from_pretrained(
            model_name, **label_config
        )
        target_device = "cuda" if torch.cuda.is_available() else "cpu"
        return pretrained.to(target_device)
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise

In [None]:
def set_training_arguments(
    output_dir: str,
    *,
    learning_rate: float = 5e-5,
    batch_size: int = 16,
    gradient_accumulation_steps: int = 4,
    num_train_epochs: float = 3.0,
    max_steps: int = 20,
) -> TrainingArguments:
    """
    Sets the TrainingArguments object.

    Evaluation and checkpointing both run once per epoch. The best checkpoint
    by "accuracy" is reloaded at the end of training.

    Args:
        output_dir (str): The directory to save the trained model.
        learning_rate (float): Learning rate passed to the Trainer.
        batch_size (int): Per-device batch size for both training and evaluation.
        gradient_accumulation_steps (int): Steps accumulated per optimizer update.
        num_train_epochs (float): Number of training epochs. Only used when
            ``max_steps`` is not positive.
        max_steps (int): Total training steps. A positive value overrides
            ``num_train_epochs``. The default of 20 keeps runs short; pass -1
            to train by epochs instead.

    Returns:
        TrainingArguments: The TrainingArguments object.

    Raises:
        Exception: Any construction error is logged and re-raised.
    """
    try:
        args = TrainingArguments(
            output_dir=output_dir,
            # Keep the raw "image" column so the with_transform augmentation still receives it.
            remove_unused_columns=False,
            eval_strategy="epoch",
            save_strategy="epoch",
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            num_train_epochs=num_train_epochs,
            max_steps=max_steps,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            report_to="none",
        )
        return args
    except Exception as e:
        logger.error(f"Error setting training arguments: {e}")
        raise

In [None]:
# Trainer
def create_model(
    model: AutoModelForImageClassification, args: TrainingArguments
) -> Trainer:
    """
    Builds the Trainer for the image classification model (it does not train).

    The Trainer trains on ``food_augment["train"]`` and evaluates on
    ``food_augment["test"]``, scoring with ``compute_metrics``. It is given
    ``food.image_processor`` as its processing class.

    Args:
        model (AutoModelForImageClassification): The pre-trained model.
        args (TrainingArguments): The TrainingArguments object.

    Returns:
        Trainer: A configured, not yet trained, Trainer object.

    Raises:
        Exception: Any construction error is logged and re-raised.
    """
    try:
        # Newer transformers releases replace the deprecated `tokenizer=` argument
        # with `processing_class=`; pass whichever this installed Trainer accepts.
        trainer_params = inspect.signature(Trainer.__init__).parameters
        processor_kwarg = (
            "processing_class" if "processing_class" in trainer_params else "tokenizer"
        )

        trainer = Trainer(
            model=model,
            args=args,
            train_dataset=food_augment["train"],
            eval_dataset=food_augment["test"],
            compute_metrics=compute_metrics,
            **{processor_kwarg: food.image_processor},
        )
        return trainer
    except Exception as e:
        logger.error(f"Error creating Trainer: {e}")
        raise

In [None]:
def train_model(save_name: str):
    """
    Orchestrates the food classification training process.

    Loads ``config_args.base_model``, builds the training arguments and the
    Trainer, and trains. The final model is then saved to
    ``config_args.output_path/save_name``.

    Args:
        save_name (str): Name of the sub-directory of
            ``config_args.output_path`` to save the model into.

    Any exception is logged and not re-raised.
    """
    try:
        model = load_model(config_args.base_model)
        args = set_training_arguments(config_args.output_path)

        trainer = create_model(model, args)
        trainer.train()

        # os.path.join instead of a hard-coded "\\" so the path is also valid on POSIX.
        save_path = os.path.join(config_args.output_path, save_name)
        trainer.save_model(save_path)
        logger.info(f"Training completed. Model saved to {save_path}")

    except Exception as e:
        logger.error(f"An error occurred: {e}")

In [None]:
def predict(image_path: str, load_name: str) -> list[dict[str, Any]]:
    """
    Returns image predictions.

    Loads the saved model from ``config_args.output_path/load_name`` into an
    image-classification pipeline on CUDA when available (CPU otherwise). The
    pipeline is then run on the image at ``image_path``.

    Args:
        image_path: Path of custom image.
        load_name: Name of the saved model.

    Returns:
        The pipeline's predictions for the image.
    """
    # os.path.join instead of a hard-coded "\\" so the path is also valid on POSIX.
    model_path = os.path.join(config_args.output_path, load_name)
    pipe = pipeline(
        "image-classification",
        model_path,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    # The context manager closes the file handle. convert("RGB") loads the pixels
    # before the file closes and matches the RGB conversion used in training.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    return pipe(image)