# Food-101 Computer Vision Project
### ResNet-50
### Model Training (80% train / 10% val / 10% test)


In [None]:
import os
import random
import pandas as pd
from PIL import Image
import torch
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader, random_split # Import random_split
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from IPython.display import Image as DisplayImage, display
import gradio as gr

# --- 0. Environment Setup and Data Download/Extraction ---
# Notebook cell (IPython/Colab): the `!` lines are shell commands executed by the
# notebook, not Python. Each step checks for existing output first, so re-running
# this cell skips the download/extraction work that is already done.
print("--- 0. Environment Setup and Data Download/Extraction ---")
# Remember to run this line in Colab to mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
# Mount success is inferred only from whether the MyDrive folder is visible afterwards.
print("Google Drive mounted successfully." if os.path.exists('/content/drive/MyDrive') else "Google Drive mounting failed.")


# Set Google Drive storage path
# The archive (and later the calorie table and checkpoints) live under drive_path;
# presumably so they survive Colab runtime resets — the images themselves are
# extracted to local /content instead.
drive_path = "/content/drive/MyDrive/food101_project"
tar_path = f"{drive_path}/food-101.tar.gz"
extract_path = "/content" # Colab's temporary space for faster extraction

# Check and create the project folder on Google Drive
os.makedirs(drive_path, exist_ok=True)

# Check if Food-101 compressed file is downloaded, if not, download it
if not os.path.exists(tar_path):
    print(f"Downloading Food-101 dataset to {drive_path}...")
    !wget http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz -P {drive_path}
    print("Food-101 dataset downloaded.")
else:
    print("food-101.tar.gz already downloaded, skipping download.")

# Check if Food-101 folder is extracted, if not, extract it
# NOTE(review): only the existence of the top-level "food-101" folder is checked,
# so a partially extracted folder (e.g. an interrupted tar) is treated as complete.
if not os.path.exists(os.path.join(extract_path, "food-101")):
    print(f"Extracting food-101.tar.gz to {extract_path}...")
    !tar -xzf {tar_path} -C {extract_path}
    print("Extraction complete.")
else:
    print("Food-101 folder already extracted, skipping extraction.")

# Confirm data availability
# This only prints a status message; execution continues even if the data is missing.
if os.path.exists("/content/food-101/images") and os.path.exists("/content/food-101/meta"):
    print("Food-101 dataset is ready.")
else:
    print("Food-101 dataset preparation failed. Please check download and extraction steps.")


# --- 1. Define Food-101 Food Class Names ---
print("\n--- 1. Define Food-101 Food Class Names ---")
# The position of each name is its class index (model output unit and dataset
# label), so the order below must stay fixed once a model has been trained.
food101_classes = """
    apple_pie baby_back_ribs baklava beef_carpaccio beef_tartare
    beet_salad beignets bibimbap bread_pudding breakfast_burrito
    bruschetta caesar_salad cannoli caprese_salad carrot_cake
    ceviche cheesecake cheese_plate chicken_curry chicken_quesadilla
    chicken_wings chocolate_cake chocolate_mousse churros clam_chowder
    club_sandwich crab_cakes creme_brulee croque_madame cup_cakes
    deviled_eggs donuts dumplings edamame eggs_benedict
    escargots falafel filet_mignon fish_and_chips foie_gras
    french_fries french_onion_soup french_toast fried_calamari fried_rice
    frozen_yogurt garlic_bread gnocchi greek_salad grilled_cheese_sandwich
    grilled_salmon guacamole gyoza hamburger hot_and_sour_soup
    hot_dog huevos_rancheros hummus ice_cream lasagna
    lobster_bisque lobster_roll_sandwich macaroni_and_cheese macarons miso_soup
    mussels nachos omelette onion_rings oysters
    pad_thai paella pancakes panna_cotta peking_duck
    pho pizza pork_chop poutine prime_rib
    pulled_pork_sandwich ramen ravioli red_velvet_cake risotto
    samosa sashimi scallops seaweed_salad shrimp_and_grits
    spaghetti_bolognese spaghetti_carbonara spring_rolls steak strawberry_shortcake
    sushi tacos takoyaki tiramisu tuna_tartare
    waffles
""".split()
# Alias used by the prediction code; both names refer to the same list object.
class_names = food101_classes


# --- 2. Load Calorie Lookup Table ---
print("\n--- 2. Load Calorie Lookup Table ---")

# Your Excel file path (expected columns: 'food_category', 'calories_per_serving')
calorie_excel_path = os.path.join(drive_path, "calorie_lookup_table.xlsx")
# The random fallback table is persisted here so later runs reuse the same values
random_calorie_csv_path = os.path.join(drive_path, "food101_random_calorie_table.csv")


def _load_or_create_random_calorie_table(csv_path):
    """Return ``(calorie_dict, df_calories)`` for the random fallback table.

    If a table was saved at ``csv_path`` by an earlier run, it is reused so the
    "estimated" calories stay stable between runs. Otherwise a new table with a
    random integer in [150, 600] kcal for every Food-101 class is generated and
    saved to ``csv_path``.

    Args:
        csv_path (str): Location of the persisted random table.

    Returns:
        tuple: (dict mapping class name -> kcal, DataFrame with columns
        'food' and 'calories (kcal)').
    """
    if os.path.exists(csv_path):
        try:
            df = pd.read_csv(csv_path)
            table = dict(zip(df['food'], df['calories (kcal)']))
        except (OSError, KeyError, ValueError) as err:
            # pandas parser errors subclass ValueError; fall through and regenerate.
            print(f"Could not reuse {csv_path} ({err}); generating a new random calorie table.")
        else:
            print(f"Reusing previously generated random calorie table from {csv_path}.")
            return table, df

    table = {food: random.randint(150, 600) for food in food101_classes}
    df = pd.DataFrame(list(table.items()), columns=['food', 'calories (kcal)'])
    # Persist so the fallback values are reused instead of re-randomized next run.
    df.to_csv(csv_path, index=False)
    print("Generated and saved a random calorie table.")
    return table, df


# Initialize calorie_dict as an empty dictionary in case loading fails
calorie_dict = {}

if not os.path.exists(calorie_excel_path):
    print(f"Calorie lookup table not found at {calorie_excel_path}.")
    print("Will create a random calorie table instead.")
    calorie_dict, df_calories = _load_or_create_random_calorie_table(random_calorie_csv_path)
else:
    print(f"Loading calorie table from {calorie_excel_path}...")
    try:
        df_calories = pd.read_excel(calorie_excel_path)
    except Exception as e:
        # Broad on purpose: any failure to read the workbook falls back to the random table.
        print(f"Error loading Excel file: {e}")
        print("Will create a random calorie table instead.")
        calorie_dict, df_calories = _load_or_create_random_calorie_table(random_calorie_csv_path)
    else:
        if 'food_category' in df_calories.columns and 'calories_per_serving' in df_calories.columns:
            # 'food_category' values must match the Food-101 class names to be found at prediction time.
            calorie_dict = pd.Series(df_calories['calories_per_serving'].values, index=df_calories['food_category']).to_dict()
            print("Calorie table loaded successfully.")
        else:
            print(f"Excel file {calorie_excel_path} is missing 'food_category' or 'calories_per_serving' columns.")
            print("Please ensure your Excel file has the correct column names.")
            calorie_dict, df_calories = _load_or_create_random_calorie_table(random_calorie_csv_path)


# --- 3. Custom Food101 Dataset Class ---
print("\n--- 3. Custom Food101 Dataset Class ---")


class Food101Dataset(Dataset):
    """Map-style dataset over Food-101 images listed by relative path.

    Each entry of ``image_paths_list`` is a path relative to ``root_dir`` without
    the ``.jpg`` extension (e.g. ``"apple_pie/12345"``). The segment before the
    first ``/`` is the class name, mapped to an index via the module-level
    ``food101_classes`` list so labels agree with the model head.

    Items are ``(image, label)``: the RGB PIL image (passed through ``transform``
    when one is set) and the integer class index.
    """

    def __init__(self, image_paths_list, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = image_paths_list
        # Shared, externally defined class order keeps indices consistent everywhere.
        self.classes = food101_classes
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        rel_path = self.image_paths[idx]
        class_name, _, _ = rel_path.partition('/')
        label = self.class_to_idx[class_name]  # KeyError for an unknown class folder
        image = Image.open(os.path.join(self.root_dir, f"{rel_path}.jpg")).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label


print("Food101Dataset class defined.")


# --- 4. Data Loader Setup (80% Train, 10% Validation, 10% Test) ---
print("\n--- 4. Data Loader Setup (80% Train, 10% Validation, 10% Test) ---")
# Locations of the extracted Food-101 images and split metadata.
image_root = "/content/food-101/images"
meta_root = "/content/food-101/meta"

# ImageNet channel statistics, shared by both pipelines to match the pre-trained weights.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop/scale, horizontal flip and mild colour jitter for augmentation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Validation/test pipeline: deterministic Resize(256) followed by a 224x224 centre crop.
val_test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# The official train/test lists are pooled here and re-split randomly below.
all_image_paths = []
for list_name in ("train.txt", "test.txt"):
    with open(os.path.join(meta_root, list_name), 'r') as fh:
        all_image_paths.extend(fh.read().splitlines())

# One dataset over every image, without a transform: each split gets its own
# transform afterwards through a wrapper around the Subset from random_split.
full_food101_dataset = Food101Dataset(
    image_paths_list=all_image_paths,
    root_dir=image_root,
    transform=None,
)

total_size = len(full_food101_dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
# The test split takes the remainder so the three sizes always sum to total_size.
test_size = total_size - train_size - val_size

# Seeded generator so the split is identical on every run (and across resumes).
g = torch.Generator().manual_seed(42)
split_sizes = [train_size, val_size, test_size]
train_subset_raw, val_subset_raw, test_subset_raw = random_split(full_food101_dataset, split_sizes, generator=g)

# random_split returns Subsets that share the parent dataset (built with transform=None),
# so this wrapper attaches a split-specific transform on top of each Subset.
class DatasetWithTransform(Dataset):
    """Apply a per-split transform on top of a ``torch.utils.data.Subset``.

    Args:
        subset: A ``Subset`` whose parent dataset (``subset.dataset``) exposes
            ``class_to_idx``, ``classes`` and ``root_dir`` and yields
            ``(image, label)`` pairs. The parent should be created with
            ``transform=None`` so that ``transform`` here is the only one applied.
        transform: Optional callable applied to each sample's image.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform
        # Re-expose the parent dataset's metadata for convenience.
        parent = subset.dataset
        self.class_to_idx = parent.class_to_idx
        self.classes = parent.classes
        self.root_dir = parent.root_dir

    def __getitem__(self, idx):
        # Delegate index mapping, image loading and label parsing to the Subset/parent
        # dataset instead of duplicating Food101Dataset.__getitem__ here.
        image, label = self.subset[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.subset)

# Wrap each raw split: augmentation for training, deterministic preprocessing for val/test.
train_dataset = DatasetWithTransform(train_subset_raw, transform=train_transform)
val_dataset = DatasetWithTransform(val_subset_raw, transform=val_test_transform)
test_dataset = DatasetWithTransform(test_subset_raw, transform=val_test_transform)

# Shared loader settings; only the training loader shuffles.
_loader_kwargs = dict(batch_size=32, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

_split_summary = (
    f"Total data loaded: {total_size} samples",
    f"Training data after split: {len(train_dataset)} samples (with enhanced data augmentation)",
    f"Validation data: {len(val_dataset)} samples (with standard preprocessing)",
    f"Test data: {len(test_dataset)} samples (with standard preprocessing)",
)
for _summary_line in _split_summary:
    print(_summary_line)


# --- 5. Model Initialization and Training (Including Checkpoint and GPU Usage Confirmation) ---
print("\n--- 5. Model Initialization and Training ---")
# Set training device (GPU preferred, otherwise CPU)
# `device` is a module-level global read by the training/evaluation code that follows.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device for training: {device}")
if device.type == 'cuda':
    print(f"CUDA is available! GPU Name: {torch.cuda.get_device_name(0)}")
    print("--- Initial GPU Status ---")
    # IPython shell magic: prints an nvidia-smi snapshot (not valid in a plain .py run).
    !nvidia-smi # First check of GPU status
    print("--------------------------")
else:
    # Only a warning: execution continues and training runs on the CPU.
    print("CUDA is NOT available. Training will use CPU. Please check your Colab runtime type (Runtime -> Change runtime type -> GPU).")


# Initialize ResNet model, using pre-trained weights
# ----------------------------------------------------
# **IMPORTANT: Choose your model here (uncomment the line you want to use)**

# Use ResNet-18 (Accuracy was not good, started stagnating after 10th epoch)
# model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# model_name = "food101_resnet18.pth"

# Or, use ResNet-50 (Higher accuracy, ultimately decided to use this model)
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model_name = "food101_resnet50.pth"  # Fixed filename of the best-model checkpoint
# ----------------------------------------------------

# Replace the ImageNet classification head with one output per Food-101 class.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(food101_classes))

# Move the model to the specified device (GPU/CPU)
model = model.to(device)

# Cross-entropy for single-label classification; Adam over all (unfrozen) parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)  # You can adjust the learning rate (lr)

# Learning-rate scheduler: multiply the LR by 0.1 once validation loss has failed to
# improve for more than `patience` (3) consecutive epochs.
# `verbose` is deprecated for PyTorch LR schedulers, so it is not passed; inspect
# optimizer.param_groups[0]['lr'] to observe reductions.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)


# Set model save path
# Fixed-name "best model" checkpoint on Drive; it is also the file training resumes from.
model_save_path = os.path.join(drive_path, model_name)

# Fresh-run defaults; replaced below when a usable checkpoint is found.
start_epoch = 0
best_val_accuracy = 0.0
best_val_loss = float('inf')

if not os.path.exists(model_save_path):
    print(f"Model file not found: {model_save_path}. Training model from scratch.")
else:
    try:
        # The checkpoint bundles model, optimizer and (optionally) scheduler state plus best metrics.
        checkpoint = torch.load(model_save_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        best_val_accuracy = checkpoint.get('best_val_accuracy', 0.0)
        best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        # The stored epoch is the last completed one, so training resumes with the next.
        start_epoch = checkpoint.get('epoch', 0) + 1
        print(f"Pre-trained model state loaded from {model_save_path}. Resuming training from Epoch {start_epoch}.")
    except Exception as e:
        print(f"Error loading model {model_save_path}: {e}. Training model from scratch.")
        # NOTE(review): weights loaded before the failure stay in `model`; only the
        # epoch counter and best metrics are reset here.
        start_epoch, best_val_accuracy, best_val_loss = 0, 0.0, float('inf')

# Path of the best checkpoint written so far (the training loop overwrites model_save_path).
best_model_ever_saved_path = model_save_path


# --- Model Training (Including Validation Process and More Detailed Saving Logic) ---
print(f"--- Starting Model Training (or Resuming Training) ---")
epochs_to_train_now = 15 # Number of epochs to train for now, you can adjust this

for epoch in range(start_epoch, epochs_to_train_now):
    if device.type == 'cuda':
        print(f"\n--- GPU Usage (start of Epoch {epoch+1}) ---")
        !nvidia-smi
        print("-------------------------------------------\n")

    # --- Training Phase ---
    model.train() # Set model to training mode
    running_loss = 0.0
    progress_bar_train = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs_to_train_now} [Train]")

    for images, labels in progress_bar_train:
        images, labels = images.to(device), labels.to(device) # Move data to GPU

        optimizer.zero_grad() # Zero gradients
        outputs = model(images) # Forward pass
        loss = criterion(outputs, labels) # Calculate loss
        loss.backward() # Backward pass
        optimizer.step() # Update weights

        running_loss += loss.item()
        avg_loss = running_loss / (progress_bar_train.n + 1) # Calculate average loss

        progress_bar_train.set_postfix(loss=avg_loss) # Display current loss on the right of the progress bar

    avg_train_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{epochs_to_train_now}] Training complete - Average Training Loss: {avg_train_loss:.4f}")

    # --- Validation Phase ---
    model.eval() # Set model to evaluation mode
    val_loss = 0.0
    correct = 0
    total = 0
    progress_bar_val = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs_to_train_now} [Validation]")

    with torch.no_grad(): # Disable gradient calculation
        for images, labels in progress_bar_val:
            images, labels = images.to(device), labels.to(device) # Move data to GPU
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            progress_bar_val.set_postfix(val_loss=val_loss / (progress_bar_val.n + 1), acc=(100 * correct / total))

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = 100 * correct / total
    print(f"Epoch [{epoch+1}/{epochs_to_train_now}] Validation complete - Average Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

    # --- Learning Rate Scheduler Update ---
    scheduler.step(avg_val_loss) # Update learning rate based on validation loss

    # --- Model Saving Logic (Checkpoint) ---
    # 1. Save the current epoch's model to a file with version info
    # Filename includes epoch, validation accuracy, and validation loss
    current_epoch_model_filename = f"food101_resnet50_epoch_{epoch+1}_acc_{val_accuracy:.2f}_loss_{avg_val_loss:.4f}.pth"
    current_epoch_model_filepath = os.path.join(drive_path, current_epoch_model_filename)
    # Save full checkpoint state
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_val_accuracy': best_val_accuracy,
        'best_val_loss': best_val_loss,
    }, current_epoch_model_filepath)
    print(f"Epoch {epoch+1} model snapshot saved to: {current_epoch_model_filepath}")


    # 2. Determine whether to save as the best model (overwrite fixed-name file)
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        # Save full checkpoint state for the "best" model too
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_accuracy': best_val_accuracy,
            'best_val_loss': avg_val_loss, # Update best_val_loss here as well
        }, model_save_path) # Overwrite the main food101_resnet50.pth
        best_model_ever_saved_path = model_save_path # Update the actual path of the best model
        print(f"Validation accuracy improved! Best model saved to: {model_save_path} (Accuracy: {best_val_accuracy:.2f}%)")
    else:
        print(f"Keeping existing best model (Accuracy: {best_val_accuracy:.2f}%)")

print(f"Model training complete. Final best model state should be located at: {best_model_ever_saved_path if best_model_ever_saved_path else model_save_path}")


# --- 6. Define Prediction Function ---
print("\n--- 6. Define Prediction Function ---")
def predict_image(image_path, model, class_names, calorie_dict):
    """
    Predicts the food category in an image and estimates calories.

    Args:
        image_path (str): Path to the image file.
        model (torch.nn.Module): The trained PyTorch model.
        class_names (list): List of food class names, indexed like the model outputs.
        calorie_dict (dict): Dictionary mapping food names to calorie estimates.

    Returns:
        tuple: (Predicted class name, Confidence percentage, Estimated calorie value).
        The confidence is the softmax probability of the predicted class times 100;
        the calorie value is "Unknown" when the class is missing from calorie_dict.
        Returns (None, None, None) if the image cannot be opened.
    """
    # Same deterministic preprocessing as the validation/test sets (no augmentation).
    transform_predict = transforms.Compose([
        transforms.Resize((256)),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    try:
        # `with` closes the underlying file; convert() returns an independent image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
    except FileNotFoundError:
        print(f"Error: Image file not found: {image_path}")
        return None, None, None
    except Exception as e:
        print(f"Error: Could not open image file: {e}")
        return None, None, None

    # Run on whichever device holds the model's weights rather than a global setting.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    input_tensor = transform_predict(image).unsqueeze(0).to(model_device)

    model.eval()
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.softmax(output, dim=1)
        class_idx = output.argmax(dim=1).item()

    class_name = class_names[class_idx]
    calories = calorie_dict.get(class_name, "Unknown")
    confidence = probabilities[0, class_idx].item() * 100

    return class_name, confidence, calories

print("Prediction function predict_image defined.")


# --- 7. Final Test Set Evaluation (in Colab Notebook) ---
print("\n--- 7. Final Test Set Evaluation ---")

# Load the best model (the checkpoint with the highest validation accuracy).
if os.path.exists(model_save_path):
    try:
        checkpoint = torch.load(model_save_path, map_location=device)
        # Only the weights are needed for evaluation; optimizer/scheduler state is ignored.
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Best model loaded: {model_save_path}, for final testing.")
    except Exception as e:
        print(f"Error loading best model {model_save_path}: {e}. Will use current in-memory model for testing.")
else:
    print(f"Best model not found: {model_save_path}. Please ensure the model was trained and saved successfully. Will use current in-memory model.")

model.eval()
test_correct = 0
test_total = 0
test_loss = 0.0
progress_bar_test = tqdm(test_loader, desc="Final Test Set Evaluation")

with torch.no_grad():
    # Explicit batch counter: tqdm's `.n` lags between display refreshes.
    for batch_idx, (images, labels) in enumerate(progress_bar_test, start=1):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        test_total += labels.size(0)
        test_correct += (predicted == labels).sum().item()
        progress_bar_test.set_postfix(test_loss=test_loss / batch_idx, test_acc=(100 * test_correct / test_total))

# Guard against an empty test split instead of raising ZeroDivisionError.
avg_test_loss = test_loss / max(len(test_loader), 1)
test_accuracy = 100 * test_correct / test_total if test_total else 0.0
print("\nFinal Test Set Evaluation Results:")
print(f"Average Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
print("Reminder: This test set is randomly split from all Food-101 images, results are not directly comparable to official benchmarks.")


# --- 8. Create Demo UI (using Gradio) ---
print("\n--- 8. Create Demo UI (using Gradio) ---")

# Pick a few random test-split images to offer as examples in the UI.
# test_dataset.subset.indices maps positions in the test split to positions in
# full_food101_dataset, whose image_paths hold relative paths like "apple_pie/12345".
num_examples = 5  # Number of examples to display
example_paths = []
n_test_items = len(test_dataset)
if n_test_items <= 0:
    print("Warning: Test dataset is empty, cannot generate example image paths.")
else:
    picked_positions = random.sample(range(n_test_items), min(num_examples, n_test_items))
    split_to_full = test_dataset.subset.indices
    relative_paths = test_dataset.subset.dataset.image_paths
    example_paths = [
        os.path.join(image_root, relative_paths[split_to_full[pos]] + ".jpg")
        for pos in picked_positions
    ]


# Wrap predict_image into a function required by Gradio interface
def gradio_predict_wrapper(image_file):
    """Gradio callback: classify an uploaded image and format the result as Markdown.

    Args:
        image_file: Value produced by the ``gr.File`` input. With
            ``type="filepath"`` this is a path string; older Gradio versions pass
            a tempfile-like object exposing ``.name``. ``None`` if nothing was
            uploaded.

    Returns:
        str: Markdown with the predicted class, confidence and calorie estimate,
        or a short message when no file was given or prediction failed.
    """
    if image_file is None:
        return "Please upload an image."

    # gr.File(type="filepath") delivers a plain path, on which `.name` raised
    # AttributeError; file-like objects are still accepted for older Gradio.
    if isinstance(image_file, (str, os.PathLike)):
        image_path = os.fspath(image_file)
    else:
        image_path = image_file.name

    predicted_class, confidence, estimated_calories = predict_image(image_path, model, class_names, calorie_dict)

    if predicted_class is None:
        return "Prediction failed, please check the image file."

    return (
        f"Predicted Class: **{predicted_class}**\n"
        f"Confidence: {confidence:.2f}%\n"
        f"Estimated Calories: approx. **{estimated_calories}** kcal"
    )

# Set up Gradio interface: one file upload in, Markdown text out.
upload_input = gr.File(type="filepath", label="Upload Food Image")
result_output = gr.Markdown("")

iface = gr.Interface(
    fn=gradio_predict_wrapper,
    inputs=upload_input,
    outputs=result_output,
    title="Food-101 Food Recognition and Calorie Estimation Demo",
    description="Recognize your uploaded food images with an AI model and provide calorie estimation.",
    examples=example_paths,  # Randomly chosen test-split images
    allow_flagging="never",
    css="footer {visibility: hidden}",
)

# Launch with debug output and a share link enabled.
for launch_message in ("\nLaunching Gradio interface...", "Please click on the Public URL or Local URL below to access:"):
    print(launch_message)
iface.launch(debug=True, share=True)

print("\nAll code execution complete.")