# Food-101 Image Classification with EfficientNetV2-S and PyTorch Lightning

This repository contains the code for an end-to-end deep learning project to classify 101 food categories from the challenging Food-101 dataset. The project demonstrates a systematic approach to model selection, fine-tuning, and hyperparameter optimization, achieving a final validation accuracy of **85.4%** on the full dataset.

The entire training and evaluation pipeline is built using modern, reproducible practices with PyTorch Lightning.

## 1. Imports

In [None]:
import torch
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import pytorch_lightning as pl
import torch.optim.lr_scheduler as lr_scheduler
import torchvision

from torchmetrics.functional import accuracy
from torchvision import transforms, datasets
from torchinfo import summary
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from torch import nn
from pathlib import Path
from torch.utils.data import DataLoader, TensorDataset, Dataset, random_split

## 2. Quick inspection of the top model

In [None]:
# Inspect the pretrained model's classifier layer; its output size has to be
# changed later to match the number of classes in Food101.

# Pretrained weights enum; also carries the matching preprocessing preset.
weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
model = torchvision.models.efficientnet_v2_s(weights=weights)
# Preprocessing pipeline bundled with the weights.
effnet_v2_s_transforms = weights.transforms()

print(model.classifier)

In [None]:
# Inspect the model: per-layer input/output sizes, parameter counts and
# trainable flags.
# NOTE(review): the probe batch is 1x3x224x224; confirm this matches the
# resolution produced by weights.transforms() before relying on the shapes.

summary(model=model,
        input_size=(1, 3, 224, 224),
        col_names=['input_size', 'output_size', 'num_params', 'trainable'],
        col_width=20,
        row_settings=['var_names'])

In [None]:
# Base transforms for training: TrivialAugmentWide followed by the
# preprocessing preset that ships with the pretrained weights.

effnet_v2_s_transforms = weights.transforms()
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.TrivialAugmentWide(),
    effnet_v2_s_transforms])

# Bare expression so the notebook displays the composed pipeline.
train_transforms

## 3. Dataset and PyTorch Lightning DataModule Classes

In [None]:
from torchvision import datasets
from pathlib import Path
import os
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Subset
from torchvision import datasets
from torchvision import transforms as T
import numpy as np
import torchvision
from torchvision.datasets import Food101
from torch.utils.data import DataLoader, Dataset
from typing import Dict, Tuple, Any
import random


def get_model_components(
    model_name: str,
    return_classifier: bool = False,
    augmentation_level: str = "default"
) -> Dict[str, Any]:
    """
    Build the preprocessing pipelines (and optionally the head) for a model.

    The validation pipeline is the preset attached to the model's default
    pretrained weights. The training pipeline wraps that preset with random
    augmentation whose strength is chosen by ``augmentation_level``.

    Args:
        model_name (str): Which architecture to look up. Supported values are
            "EfficientNet_V2_S" and "EfficientNet_B2".
        return_classifier (bool, optional): When True, the pretrained model is
            instantiated and its ``classifier`` module is added to the result.
            Defaults to False.
        augmentation_level (str, optional): "default" (TrivialAugmentWide) or
            "strong" (random resized crop, horizontal flip, RandAugment and
            random erasing). Defaults to "default".

    Returns:
        Dict[str, Any]: A dictionary with 'train_transforms' and
            'val_transforms', plus 'classifier' when return_classifier is True.

    Raises:
        ValueError: If model_name or augmentation_level is not supported.
    """
    tv_models = torchvision.models
    supported_models = {
        "EfficientNet_V2_S": (
            tv_models.efficientnet_v2_s,
            tv_models.EfficientNet_V2_S_Weights.DEFAULT,
        ),
        "EfficientNet_B2": (
            tv_models.efficientnet_b2,
            tv_models.EfficientNet_B2_Weights.DEFAULT,
        ),
    }

    if model_name not in supported_models:
        raise ValueError(f"Model '{model_name}' is not supported. "
                         f"Supported models are: {list(supported_models.keys())}")

    build_model, pretrained_weights = supported_models[model_name]

    # The weights' own preset is used as-is for validation/test data.
    eval_pipeline = pretrained_weights.transforms()

    if augmentation_level == "strong":
        # Geometric and colour augmentation first, then the preset, and finally
        # RandomErasing placed after the preset.
        strong_steps = [
            T.RandomResizedCrop(size=eval_pipeline.crop_size, scale=(0.7, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandAugment(num_ops=2, magnitude=9),
            eval_pipeline,
            T.RandomErasing(p=0.25, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random'),
        ]
        train_pipeline = T.Compose(strong_steps)
    elif augmentation_level == "default":
        train_pipeline = T.Compose([T.TrivialAugmentWide(), eval_pipeline])
    else:
        raise ValueError(f"Augmentation level '{augmentation_level}' is not supported. "
                         f"Choose from 'default' or 'strong'.")

    components = {
        "train_transforms": train_pipeline,
        "val_transforms": eval_pipeline,
    }

    # Only pay for building the pretrained model when the head is requested.
    if return_classifier:
        components["classifier"] = build_model(weights=pretrained_weights).classifier

    return components
        
class CustomFood101(Dataset):
    """A PyTorch Dataset for Food101 with conditional downloading and subset support.

    This class wraps the torchvision Food101 dataset. It only requests a
    download when ``<data_dir>/food-101`` does not exist yet. It can also expose
    a reproducible, shuffled subset of the data for faster experimentation.

    Args:
        split (str): The dataset split, either "train" or "test".
        transform (callable, optional): A function/transform to apply to the images.
        data_dir (str, optional): The directory to store the data. Defaults to "data".
        subset_fraction (float, optional): The fraction of the dataset to use.
            Defaults to 0.1. Values of 1.0 or more use the full dataset in its
            original order; 0 yields an empty dataset.

    Raises:
        ValueError: If subset_fraction is negative.
    """

    def __init__(self, split, transform=None, data_dir="data", subset_fraction: float = 0.1):
        # Reject negative fractions up front: they would otherwise be turned
        # into a negative slice bound and silently keep most of the dataset.
        if subset_fraction < 0:
            raise ValueError(
                f"subset_fraction must be non-negative, got {subset_fraction}."
            )

        # Check if the dataset already exists before setting the download flag.
        dataset_path = os.path.join(data_dir, "food-101")
        should_download = not os.path.isdir(dataset_path)

        # 1. Load the full dataset with the conditional download flag
        self.full_dataset = Food101(root=data_dir, split=split, transform=transform, download=should_download)
        self.classes = self.full_dataset.classes

        # 2. Create a reproducible subset of indices
        self.indices = self._select_indices(len(self.full_dataset), subset_fraction)

    @staticmethod
    def _select_indices(total, subset_fraction, seed=42):
        """Return the indices exposed for a dataset of ``total`` samples.

        For fractions below 1.0 the indices are shuffled with a fixed seed and
        truncated to ``int(total * subset_fraction)`` entries; otherwise all
        indices are returned in order.
        """
        indices = list(range(total))
        if subset_fraction < 1.0:
            # Fixed seed so every run sees the same subset.
            random.Random(seed).shuffle(indices)
            indices = indices[:int(total * subset_fraction)]
        return indices

    def __len__(self):
        """Returns the total number of samples in the subset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """
        Fetches the sample for the given subset index.

        The transform (if any) is applied by the wrapped dataset.
        """
        # Map the subset index to the actual index in the full dataset
        original_idx = self.indices[idx]
        image, label = self.full_dataset[original_idx]
        return image, label

class Food101DataModule(pl.LightningDataModule):
    """LightningDataModule that serves Food101 through CustomFood101.

    Handles downloading and builds the train, validation and test DataLoaders.
    Training data comes from the "train" split. Both the validation and the
    test datasets are read from the "test" split: validation uses
    ``subset_fraction`` of it, the test dataset always uses all of it.

    Args:
        data_dir (str, optional): Root directory for the data. Defaults to "data".
        batch_size (int, optional): The batch size for DataLoaders. Defaults to 32.
        num_workers (int, optional): Number of workers for data loading. Defaults to 2.
        train_transforms (callable, optional): Transformations for the training set.
        val_transforms (callable, optional): Transformations for the validation/test set.
        subset_fraction (float, optional): The fraction of data to use for training
            and validation. Defaults to 0.5.
    """
    def __init__(self, data_dir="data", batch_size=32, num_workers=2,
                 train_transforms=None, val_transforms=None, subset_fraction: float = 0.5):
        super().__init__()
        # Filled in by setup() once a dataset has been built.
        self.classes = []

        self.data_dir = data_dir
        self.subset_fraction = subset_fraction
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        """Instantiate both splits once so missing data gets downloaded."""
        for split_name in ('train', 'test'):
            CustomFood101(split=split_name, data_dir=self.data_dir)

    def setup(self, stage=None):
        """Create the datasets required for the given stage (all if None)."""
        build_fit = stage is None or stage == 'fit'
        build_test = stage is None or stage == 'test'

        if build_fit:
            fit_kwargs = dict(data_dir=self.data_dir, subset_fraction=self.subset_fraction)
            self.train_dataset = CustomFood101(split='train', transform=self.train_transforms, **fit_kwargs)
            self.val_dataset = CustomFood101(split='test', transform=self.val_transforms, **fit_kwargs)
            self.classes = self.train_dataset.classes

        if build_test:
            # The test dataset always covers the complete "test" split.
            self.test_dataset = CustomFood101(split='test', transform=self.val_transforms,
                                              data_dir=self.data_dir, subset_fraction=1.0)
            if not self.classes:
                self.classes = self.test_dataset.classes

    def _make_loader(self, dataset, shuffle):
        """Wrap a dataset in a DataLoader using the module's batch settings."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=self.num_workers)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, False)


In [None]:
# Smoke test for the data pipeline: build the DataModule on a small subset and
# pull a single training batch to check shapes.

# Define configuration for the script
DATA_DIR = "data"
MODEL_NAME = "EfficientNet_V2_S"
BATCH_SIZE = 32

print(f"Running data preparation script for model: {MODEL_NAME}")

# 1. Get model-specific transforms
components = get_model_components(MODEL_NAME)
train_transforms = components["train_transforms"]
val_transforms = components["val_transforms"]

# 2. Instantiate the DataModule
datamodule = Food101DataModule(
    data_dir=DATA_DIR,
    batch_size=BATCH_SIZE,
    train_transforms=train_transforms,
    val_transforms=val_transforms,
    subset_fraction=0.1  # Use a small subset for quick verification
)

# 3. Trigger download and setup
datamodule.prepare_data()
# Only the train/validation datasets are needed for this check.
datamodule.setup(stage='fit')

# 4. (Optional) Verification Step
print("\n--- Verifying Dataloader ---")
# Get one batch from the training dataloader
train_dl = datamodule.train_dataloader()
images, labels = next(iter(train_dl))

print(f"Number of classes: {len(datamodule.classes)}")
print(f"Image batch shape: {images.shape}")
print(f"Label batch shape: {labels.shape}")
print("--- Verification Complete ---")

## 4. Model Classes

In [None]:
import torch
import torchvision
import pytorch_lightning as pl
from torch import nn
from torchmetrics.classification import Accuracy, F1Score, ConfusionMatrix
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

class EffNetV2_S(pl.LightningModule):
    """A PyTorch Lightning Module for fine-tuning EfficientNetV2-S.

    This module encapsulates the EfficientNetV2-S model and provides a flexible
    fine-tuning strategy. It can be configured for Stage 1 (training only the
    classifier and later feature blocks) or Stage 2 (training the entire model).

    Args:
        lr (float, optional): The learning rate. Defaults to 1e-3.
        weight_decay (float, optional): Weight decay for the optimizer. Defaults to 1e-4.
        num_classes (int, optional): The number of output classes. Defaults to 101.
        class_names (list, optional): A list of class names for logging. When
            omitted or empty, the stringified class indices are used.
        freeze_features (bool, optional): If True, freezes the backbone and unfreezes
            only the later blocks (Stage 1). If False, all features are trainable
            (Stage 2). Defaults to True.
        unfreeze_from_block (int, optional): Which feature block to start unfreezing
            from. Used only if freeze_features is True. Defaults to -3 (last 3 blocks).
    """

    def __init__(
        self,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        num_classes: int = 101,
        class_names: list = None,
        freeze_features: bool = True,         # True = Stage 1, False = Stage 2
        unfreeze_from_block: int = -3          # Only used if freeze_features=True
    ):
        super().__init__()
        # Stores the constructor arguments in self.hparams (and in checkpoints).
        self.save_hyperparameters()
        self.class_names = class_names if class_names else [str(i) for i in range(num_classes)]

        # Load pretrained weights
        weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
        self.model = torchvision.models.efficientnet_v2_s(weights=weights)

        # ---- Freezing strategy ----
        if freeze_features:
            # Freeze all first
            for param in self.model.parameters():
                param.requires_grad = False
            # Unfreeze from a specific block (default: last 3 blocks)
            for param in self.model.features[unfreeze_from_block:].parameters():
                param.requires_grad = True
        else:
            # Stage 2: unfreeze everything
            for param in self.model.parameters():
                param.requires_grad = True

        # Classifier head. It is created after the freezing step, so its new
        # parameters are trainable in both stages.
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=1280, out_features=self.hparams.num_classes, bias=True)
        )

        # Loss & metrics
        self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.train_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
        self.train_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
        self.val_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
        self.val_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)
        self.test_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)

    def forward(self, x):
        """Return the class logits produced by the wrapped model."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and log epoch-level train metrics."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.train_accuracy(logits, y)
        self.train_f1(logits, y)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy/F1 and accumulate the confusion matrix."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.val_accuracy(logits, y)
        self.val_f1(logits, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_accuracy, prog_bar=True)
        self.log('val_f1', self.val_f1, prog_bar=True)
        self.val_conf_matrix.update(logits, y)

    def on_validation_epoch_end(self):
        """Log and print per-class accuracy derived from the confusion matrix."""
        cm = self.val_conf_matrix.compute()
        # Diagonal over row sums; the epsilon keeps empty rows from dividing by zero.
        # NOTE(review): assumes rows index the true class (torchmetrics convention).
        per_class_acc = cm.diag() / (cm.sum(dim=1) + 1e-6)
        print("\n--- Per-Class Validation Accuracy ---")
        for i, acc in enumerate(per_class_acc):
            self.log(f'val_acc/{self.class_names[i]}', acc.item(), on_epoch=True)
            print(f"{self.class_names[i]:<20}: {acc.item():.4f}")
        print("------------------------------------")
        self.val_conf_matrix.reset()

    def test_step(self, batch, batch_idx):
        """Accumulate the test confusion matrix for one batch."""
        x, y = batch
        logits = self(x)
        self.test_conf_matrix.update(logits, y)

    def on_test_end(self):
        """Plot the confusion matrix accumulated over the test run, then reset it."""
        cm = self.test_conf_matrix.compute()
        print("\nGenerating final confusion matrix plot...")
        # plot_confusion_matrix is the module-level plotting helper of this
        # notebook; the name is looked up when the hook runs, so it only has to
        # be defined before trainer.test() is called.
        plot_confusion_matrix(cm.cpu().numpy(), class_names=self.class_names)
        self.test_conf_matrix.reset()

    def configure_optimizers(self):
        """Adam with cosine annealing over the trainer's max_epochs."""
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,
            eta_min=1e-6
        )
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"}}
    
class EffNetb2(pl.LightningModule):
    """A PyTorch Lightning Module for fine-tuning EfficientNet-B2.

    This module encapsulates the EfficientNet-B2 model and provides a flexible
    fine-tuning strategy. It can be configured for Stage 1 (training only the
    classifier and later feature blocks) or Stage 2 (training the entire model).

    Args:
        lr (float, optional): The learning rate. Defaults to 1e-3.
        weight_decay (float, optional): Weight decay for the optimizer. Defaults to 1e-4.
        num_classes (int, optional): The number of output classes. Defaults to 101.
        class_names (list, optional): A list of class names for logging. When
            omitted or empty, the stringified class indices are used.
        freeze_features (bool, optional): If True, freezes the backbone and unfreezes
            only the later blocks (Stage 1). If False, all features are trainable
            (Stage 2). Defaults to True.
        unfreeze_from_block (int, optional): Which feature block to start unfreezing
            from. Used only if freeze_features is True. Defaults to -3 (last 3 blocks).
    """

    def __init__(
        self,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        num_classes: int = 101,
        class_names: list = None,
        freeze_features: bool = True,
        unfreeze_from_block: int = -3
    ):
        super().__init__()
        # Stores the constructor arguments in self.hparams (and in checkpoints).
        self.save_hyperparameters()
        # An empty list falls back to numeric names as well; otherwise the
        # per-class logging in on_validation_epoch_end would index past its end.
        self.class_names = class_names if class_names else [str(i) for i in range(num_classes)]

        # Model setup
        weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
        self.model = torchvision.models.efficientnet_b2(weights=weights)

        # --- Flexible Freezing Strategy ---
        if self.hparams.freeze_features:
            # Stage 1: Freeze all first
            for param in self.model.parameters():
                param.requires_grad = False
            # Unfreeze from a specific block (default: last 3 blocks)
            for param in self.model.features[self.hparams.unfreeze_from_block:].parameters():
                param.requires_grad = True
        else:
            # Stage 2: unfreeze everything
            for param in self.model.parameters():
                param.requires_grad = True

        # Classifier head. It is created after the freezing step, so its new
        # parameters are trainable in both stages.
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.3, inplace=True),
            nn.Linear(in_features=1408, out_features=self.hparams.num_classes)
        )

        # Loss & metrics
        self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.train_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
        self.train_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
        self.val_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
        self.val_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)
        self.test_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)

    def forward(self, x):
        """Return the class logits produced by the wrapped model."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and log epoch-level train metrics."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.train_accuracy(logits, y)
        self.train_f1(logits, y)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy/F1 and accumulate the confusion matrix."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.val_accuracy(logits, y)
        self.val_f1(logits, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_accuracy, prog_bar=True)
        self.log('val_f1', self.val_f1, prog_bar=True)
        self.val_conf_matrix.update(logits, y)

    def on_validation_epoch_end(self):
        """Log and print per-class accuracy derived from the confusion matrix."""
        cm = self.val_conf_matrix.compute()

        # Add a small epsilon (1e-6) to the denominator for numerical stability.
        # NOTE(review): assumes rows index the true class (torchmetrics convention).
        per_class_acc = cm.diag() / (cm.sum(dim=1) + 1e-6)

        print("\n--- Per-Class Validation Accuracy ---")
        for i, acc in enumerate(per_class_acc):
            class_name = self.class_names[i]
            self.log(f'val_acc/{class_name}', acc.item(), on_epoch=True)
            print(f"{class_name:<20}: {acc.item():.4f}")
        print("------------------------------------")

        self.val_conf_matrix.reset()

    def test_step(self, batch, batch_idx):
        """Accumulate the test confusion matrix for one batch."""
        x, y = batch
        logits = self(x)
        self.test_conf_matrix.update(logits, y)

    def on_test_end(self):
        """Plot the confusion matrix accumulated over the test run, then reset it."""
        cm = self.test_conf_matrix.compute()
        print("\nGenerating final confusion matrix plot...")
        # plot_confusion_matrix is the module-level plotting helper of this
        # notebook; the name is looked up when the hook runs, so it only has to
        # be defined before trainer.test() is called.
        plot_confusion_matrix(cm.cpu().numpy(), class_names=self.class_names)
        self.test_conf_matrix.reset()

    def configure_optimizers(self):
        """Adam with cosine annealing over the trainer's max_epochs."""
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,
            eta_min=1e-6
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
            },
        }


## 5. Training and plotting the Confusion Matrix

In [None]:
import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import EarlyStopping ,ModelCheckpoint
from typing import Optional
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from typing import List

# Shared configuration for the training / evaluation cells below.
DATA_DIR = "data"
MODEL_NAME = "EfficientNet_V2_S"
BATCH_SIZE = 32
SUBSET_FRACTION = 0.2 # Using a smaller subset for quick testing
CHECKPOINT_PATH = "checkpoints/best-model-epoch=22-val_acc=0.8541.ckpt"  # Path to your trained model checkpoint

def plot_confusion_matrix(cm: np.ndarray, class_names: List[str], figsize: tuple = (25, 25),
                          save_path: str = 'confusion_matrix.png'):
    """
    Creates and saves a multi-class confusion matrix plot.

    This function row-normalizes the confusion matrix so each row shows the
    distribution of predictions for one true class, visualizes it as a heatmap,
    saves the figure to ``save_path`` and displays it.

    Args:
        cm (np.ndarray): The square confusion matrix (anything accepted by
            ``np.asarray``), e.g. from torchmetrics or scikit-learn.
        class_names (List[str]): A list of class names for the labels.
        figsize (tuple, optional): The size of the figure. Defaults to (25, 25).
        save_path (str, optional): Where the image is written. Defaults to
            'confusion_matrix.png'.

    Raises:
        ValueError: If cm is not a square 2-D matrix with one row per class name.
    """
    counts = np.asarray(cm, dtype=float)
    if counts.ndim != 2 or counts.shape[0] != counts.shape[1]:
        raise ValueError(f"cm must be a square 2-D matrix, got shape {counts.shape}.")
    if counts.shape[0] != len(class_names):
        raise ValueError(
            f"cm has {counts.shape[0]} classes but {len(class_names)} class names were given."
        )

    # 1. Normalize each row to fractions. Rows without any sample stay at zero
    #    instead of dividing by zero.
    row_totals = counts.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals != 0)

    # 2. Create a DataFrame so the heatmap axes carry the class labels
    df_cm = pd.DataFrame(cm_normalized, index=class_names, columns=class_names)

    # 3. Create the plot
    fig = plt.figure(figsize=figsize)
    heatmap = sns.heatmap(df_cm, annot=False, cmap='Blues') # Annotations off for 101 classes

    # 4. Format the plot
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=8)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=8)

    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title('Normalized Confusion Matrix')
    plt.tight_layout()

    # 5. Save the figure, show it, then release it so repeated calls do not
    #    accumulate large open figures.
    plt.savefig(save_path, dpi=300)
    print(f"Confusion matrix plot saved to {save_path}")
    plt.show()
    plt.close(fig)

def run_training_session(
    model_name: str = "EfficientNet_V2_S",
    batch_size: int = 32,
    data_dir: str = 'data',
    subset_fraction: float = 1.0,
    checkpoint_path: str = "checkpoints/",
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    freeze_features: bool = True,
    early_stopping_patience: int = 5,
    max_epochs: int = 100,
    accelerator: str = 'auto',
    resume_from_checkpoint: Optional[str] = None
) -> Trainer:
    """
    Sets up and runs a complete training session for a specified model.

    This function handles the entire pipeline: data preparation, model
    instantiation, logger and callback setup, and trainer execution.

    Args:
        model_name (str): The name of the model architecture to train.
        batch_size (int): The number of samples per batch.
        data_dir (str): The root directory for the dataset.
        subset_fraction (float): The fraction of the dataset to use for training
            and validation.
        checkpoint_path (str): Directory to save model checkpoints.
        lr (float): The learning rate for the optimizer.
        weight_decay (float): The weight decay for the optimizer.
        freeze_features (bool): Flag to control the fine-tuning strategy
            (e.g., for two-stage training).
        early_stopping_patience (int): Number of epochs with no improvement
            in ``val_loss`` after which training will be stopped.
        max_epochs (int): The maximum number of epochs to train for.
        accelerator (str): The hardware accelerator to use ('auto', 'cpu', 'gpu').
        resume_from_checkpoint (Optional[str]): Path to a checkpoint file to
            resume training from. Defaults to None.

    Returns:
        Trainer: The PyTorch Lightning Trainer object after fitting is complete.

    Raises:
        ValueError: If model_name has no registered LightningModule class.
    """
    # A registry to map model names to their actual classes
    model_class_registry = {
        "EfficientNet_V2_S": EffNetV2_S,
        "EfficientNet_B2": EffNetb2,
    }
    if model_name not in model_class_registry:
        raise ValueError(f"Model '{model_name}' is not a recognized class.")

    # Get model-specific transforms
    components = get_model_components(model_name)
    train_transforms = components["train_transforms"]
    val_transforms = components["val_transforms"]

    # Set up the DataModule
    food_datamodule = Food101DataModule(
        data_dir=data_dir,
        batch_size=batch_size,
        train_transforms=train_transforms,
        val_transforms=val_transforms,
        subset_fraction=subset_fraction
    )
    food_datamodule.prepare_data()
    # The class list is needed before the model can be built, so the fit
    # datasets are set up here. Only the 'fit' stage is requested: fitting never
    # touches the test dataset, so there is no reason to build it.
    food_datamodule.setup(stage='fit')

    # Instantiate the model dynamically
    model_class = model_class_registry[model_name]
    model = model_class(
        num_classes=len(food_datamodule.classes),
        class_names=food_datamodule.classes,
        lr=lr,
        weight_decay=weight_decay,
        freeze_features=freeze_features
    )

    # Set up logger and callbacks
    logger = CSVLogger(save_dir="logs/", name=model_name)

    # Stop on stagnating validation loss, but keep the checkpoint with the
    # best validation accuracy.
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        patience=early_stopping_patience,
        mode="min"
    )
    best_model_checkpoint = ModelCheckpoint(
        dirpath=checkpoint_path,
        filename="best-model-{epoch:02d}-{val_acc:.4f}",
        save_top_k=1,
        monitor="val_acc",
        mode="max"
    )

    callbacks = [early_stop_callback, best_model_checkpoint]

    # Instantiate the Trainer
    trainer = Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        callbacks=callbacks,
        logger=logger,
    )

    # Start training
    trainer.fit(
        model,
        datamodule=food_datamodule,
        ckpt_path=resume_from_checkpoint
    )

    return trainer


In [None]:
# --- 1. DEFINE YOUR TRAINING CONFIGURATION HERE ---
config = {
    "model_name": "EfficientNet_V2_S",
    "batch_size": 32,
    "lr": 1e-4,
    "epochs": 50,
    "subset_fraction": 1.0,  # Use 1.0 for the full dataset
    "freeze_features": True,
    "early_stopping_patience": 10
}

# --- 2. PRINT CONFIGURATION AND START TRAINING ---
print("--- Starting Training Session ---")
for key, value in config.items():
    print(f"  {key}: {value}")
print("---------------------------------")

# Keep the returned Trainer: it knows where the best checkpoint of this
# session was written.
fit_trainer = run_training_session(
    model_name=config["model_name"],
    batch_size=config["batch_size"],
    lr=config["lr"],
    max_epochs=config["epochs"],
    subset_fraction=config["subset_fraction"],
    freeze_features=config["freeze_features"],
    early_stopping_patience=config["early_stopping_patience"]
)

print("\n--- Training Session Complete ---")

print("\n--- Starting Evaluation on Test Set ---")

# Evaluate the best checkpoint produced by the session that just ran. Fall back
# to the configured CHECKPOINT_PATH when no checkpoint was recorded.
checkpoint_callback = fit_trainer.checkpoint_callback
session_best_path = checkpoint_callback.best_model_path if checkpoint_callback else ""
eval_checkpoint_path = session_best_path or CHECKPOINT_PATH

print(f"Loading model from checkpoint: {eval_checkpoint_path}")

# Step 1: Set up the DataModule for the test set
components = get_model_components(MODEL_NAME)
val_transforms = components["val_transforms"]

datamodule = Food101DataModule(
    data_dir=DATA_DIR,
    batch_size=BATCH_SIZE,
    val_transforms=val_transforms
)
# This prepares the test dataloader specifically
datamodule.setup(stage='test')

# Step 2: Load the trained model from the checkpoint file
model = EffNetV2_S.load_from_checkpoint(eval_checkpoint_path)
model.class_names = datamodule.classes
model.eval() # Set the model to evaluation mode

# Step 3: Create a Trainer and run the test
trainer = pl.Trainer(accelerator='auto')

# This call will run the test_step and automatically trigger the
# on_test_end hook in your model, which generates the plot.
trainer.test(model, datamodule=datamodule)

print("\nEvaluation complete. The confusion matrix plot has been saved.")

## 6. Local Gradio Demo

In [None]:
# Class names in label-index order. The position of a name must equal the
# integer label the model was trained with. torchvision's Food101 dataset
# builds its class list with Python's sorted(), which places 'cheese_plate'
# before 'cheesecake' ('_' sorts before lowercase letters), so this list is
# kept in exactly that sorted order.
FOOD101_CLASSES = [
    'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',
    'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito',
    'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake',
    'ceviche', 'cheese_plate', 'cheesecake', 'chicken_curry', 'chicken_quesadilla',
    'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros',
    'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame',
    'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict',
    'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras',
    'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari',
    'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad',
    'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger',
    'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream',
    'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese',
    'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings',
    'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck',
    'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich',
    'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi',
    'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese',
    'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake',
    'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles'
]

In [None]:
import gradio as gr
import torch
from gradio.themes.base import Base
from torchvision.datasets import Food101

# --- 1. Configuration ---
MODEL_PATH = "checkpoints/best-model-epoch=22-val_acc=0.8541.ckpt"
MODEL_NAME = "EfficientNet_V2_S"

theme = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="blue",
).set(
    body_background_fill="#f2f2f2"
)

# --- 2. Load Model and Assets ---
print("Loading model and assets...")
model = EffNetV2_S.load_from_checkpoint(MODEL_PATH)
model.eval()

components = get_model_components(MODEL_NAME)
transforms = components["val_transforms"]

# Prefer the class names stored with the checkpoint: they are in exactly the
# label order the model was trained with. The hard-coded list is only a
# fallback for checkpoints saved without class names.
checkpoint_class_names = model.hparams.get("class_names")
class_names = list(checkpoint_class_names) if checkpoint_class_names else FOOD101_CLASSES

print("Model and assets loaded successfully.")
# --- 3. Prediction Function ---
def predict(image):
    """
    Classify a PIL image and return a confidence for every class.

    Args:
        image: The uploaded PIL image, or None when nothing was uploaded.

    Returns:
        dict | None: A mapping of class name to softmax probability for all
            classes (the Gradio Label component picks the top entries to
            display), or None when no image was provided.
    """
    # Gradio calls the function with None when the user submits without an image.
    if image is None:
        return None

    # 1. Preprocess the image and add a batch dimension
    input_tensor = transforms(image).unsqueeze(0)

    # 2. Move the input tensor to whichever device the model lives on
    device = next(model.parameters()).device
    input_tensor = input_tensor.to(device)

    # 3. Make a prediction
    with torch.no_grad():
        output = model(input_tensor)

    # 4. Post-process the output: softmax over the single sample's logits
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    return dict(zip(class_names, probabilities.tolist()))
    

# Build the Gradio UI: one image input wired to predict(), with a label output
# that shows the three highest-confidence classes.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a Food Image"),
    outputs=gr.Label(num_top_classes=3, label="Top Predictions"),
    theme=theme,
    
    # UI Enhancements
    title="🍔 Food-101 Image Classifier 🍟",
    description=(
        "What's on your plate? Upload an image or try one of the examples below to classify it. "
        "This demo uses an EfficientNetV2-S model fine-tuned on the Food-101 dataset."
    ),
    article=(
        "<p style='text-align: center;'>A project by Daniel Kiani. "
        "<a href='https://github.com/Deathshot78/Food101-Classification' target='_blank'>Check out the code on GitHub!</a></p>"
    ),
    # Example images offered in the UI; paths are relative to the working directory.
    examples=[
        ["assets/ramen.jpg"],
        ["assets/pizza.jpg"],
        ["assets/oysters.jpg"],
        ["assets/onion_rings.jpg"]
    ]
)

In [None]:
# Launch the Gradio app locally (uses the `demo` interface built above in this notebook)
demo.launch()