# In [None]:
import math
import numbers
import os
import random

import gradio as gr
import pandas as pd
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from openai import OpenAI
from PIL import Image

# --- 0. Environment Setup and Folder Mounting ---
# Colab-only: google.colab is not available outside a Colab runtime.
from google.colab import drive

print("--- 0. Mounting Google Drive ---")
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

# Project folder on Drive; the calorie table and model weights below are read
# from here. Created if it does not exist yet.
drive_path = drive_mount_point + "/MyDrive/food101_project"
# NOTE(review): image_root is not referenced by the code visible in this file.
image_root = "/content/food-101/images"
os.makedirs(drive_path, exist_ok=True)

# --- 1. Define Food-101 Class Names ---
print("\n--- 1. Define Food-101 Classes ---")
# The 101 Food-101 category names. The model's classifier head is sized with
# len(food101_classes) and predictions are decoded via class_names[class_idx],
# so position i here must be the label the weights were trained with for
# output index i.
# NOTE(review): Python's sorted() (and torchvision's ImageFolder, which sorts
# folder names) would place "cheese_plate" before "cheesecake" because '_'
# sorts before 'c'; this list has them the other way round. Confirm against the
# label mapping used in training, otherwise those two classes are swapped.
food101_classes = ["apple_pie", "baby_back_ribs", "baklava", "beef_carpaccio", "beef_tartare", "beet_salad", "beignets", "bibimbap", "bread_pudding", "breakfast_burrito",
                   "bruschetta", "caesar_salad", "cannoli", "caprese_salad", "carrot_cake", "ceviche", "cheesecake", "cheese_plate", "chicken_curry", "chicken_quesadilla",
                   "chicken_wings", "chocolate_cake", "chocolate_mousse", "churros", "clam_chowder", "club_sandwich", "crab_cakes", "creme_brulee", "croque_madame", "cup_cakes",
                   "deviled_eggs", "donuts", "dumplings", "edamame", "eggs_benedict", "escargots", "falafel", "filet_mignon", "fish_and_chips", "foie_gras", "french_fries", "french_onion_soup",
                   "french_toast", "fried_calamari", "fried_rice", "frozen_yogurt", "garlic_bread", "gnocchi", "greek_salad", "grilled_cheese_sandwich", "grilled_salmon", "guacamole", "gyoza",
                   "hamburger", "hot_and_sour_soup", "hot_dog", "huevos_rancheros", "hummus", "ice_cream", "lasagna", "lobster_bisque", "lobster_roll_sandwich", "macaroni_and_cheese", "macarons",
                   "miso_soup", "mussels", "nachos", "omelette", "onion_rings", "oysters", "pad_thai", "paella", "pancakes", "panna_cotta", "peking_duck", "pho", "pizza", "pork_chop", "poutine",
                   "prime_rib", "pulled_pork_sandwich", "ramen", "ravioli", "red_velvet_cake", "risotto", "samosa", "sashimi", "scallops", "seaweed_salad", "shrimp_and_grits", "spaghetti_bolognese",
                   "spaghetti_carbonara", "spring_rolls", "steak", "strawberry_shortcake", "sushi", "tacos", "takoyaki", "tiramisu", "tuna_tartare", "waffles"]
# Alias passed to predict_image by the UI callback; same list object.
class_names = food101_classes

# --- 2. Load Calorie Table ---
print("\n--- 2. Load Calorie Table ---")
calorie_excel_path = os.path.join(drive_path, "calorie_lookup_table.xlsx")


def _placeholder_calories(classes, seed=0):
    """Return made-up calorie values (150-600 kcal) for every class name.

    Used only when no usable lookup table is available. A private, seeded
    ``random.Random`` yields the same numbers on every run and leaves the
    global ``random`` state untouched.

    Args:
        classes: Iterable of class names to fill in.
        seed: Seed for the private generator.

    Returns:
        dict mapping each class name to an int in [150, 600].
    """
    rng = random.Random(seed)
    return {food: rng.randint(150, 600) for food in classes}


calorie_dict = None
if os.path.exists(calorie_excel_path):
    df_calories = pd.read_excel(calorie_excel_path)
    if 'food_category' in df_calories.columns and 'calories_per_serving' in df_calories.columns:
        # Drop rows with a missing name or value: a NaN calorie entry would be
        # added to the running total downstream and turn it into NaN.
        valid_rows = df_calories.dropna(subset=['food_category', 'calories_per_serving'])
        # .tolist() yields plain Python int/float rather than NumPy scalars
        # (np.int64 is not an int subclass, so isinstance(x, int) rejects it).
        # Duplicate categories keep the last row, as before.
        calorie_dict = dict(zip(valid_rows['food_category'],
                                valid_rows['calories_per_serving'].tolist()))
    else:
        print(f" '{calorie_excel_path}' lacks 'food_category'/'calories_per_serving' "
              "columns; using PLACEHOLDER calorie values.")
else:
    print(f" Calorie table not found at '{calorie_excel_path}'; "
          "using PLACEHOLDER calorie values.")

if calorie_dict is None:
    # These numbers are invented, not nutrition data.
    calorie_dict = _placeholder_calories(food101_classes)

# --- 3. Load Trained Model ---
print("\n--- 3. Load Model ---")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-50 with no pretrained weights; the final layer is replaced so it has
# one output per Food-101 class, matching the saved fine-tuned weights.
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, len(food101_classes))

model_save_path = os.path.join(drive_path, "food101_resnet50.pth")
# map_location places the loaded tensors on the selected device (e.g. CPU when
# no GPU is present).
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(model_save_path, map_location=device)

unrecognized_format_msg = " Unrecognized model format. Please check your .pth file."
if not isinstance(checkpoint, dict):
    raise ValueError(unrecognized_format_msg)

if 'model_state_dict' in checkpoint:
    # Training-style checkpoint: the weights are nested under this key.
    model.load_state_dict(checkpoint['model_state_dict'])
    print(" Loaded model from checkpoint with 'model_state_dict'.")
elif all(isinstance(param_name, str) and '.' in param_name for param_name in checkpoint):
    # Bare state_dict: every key is a dotted parameter path such as "fc.weight".
    model.load_state_dict(checkpoint)
    print(" Loaded model from plain state_dict.")
else:
    raise ValueError(unrecognized_format_msg)

model.to(device)
model.eval()  # inference mode for layers like BatchNorm

# --- 4. Prediction Function ---
print("\n--- 4. Define Prediction Function ---")

# Evaluation preprocessing: scale the shorter side to 256, centre-crop to
# 224x224, convert to a tensor and normalise with the ImageNet mean/std.
# Built once here rather than on every prediction call.
_PREDICT_TRANSFORM = transforms.Compose([
    transforms.Resize(256),  # int => match the shorter edge (was written "(256)", which is not a tuple)
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict_image(image_path, model, class_names, calorie_dict):
    """Classify one food image and look up its calories.

    Args:
        image_path: Path to an image file (``str`` or ``os.PathLike``), or an
            already-opened image object with a ``.convert()`` method (e.g. the
            PIL image Gradio passes with ``type="pil"``).
        model: Classifier producing one logit per entry of ``class_names``;
            expected to already live on the module-level ``device``.
        class_names: Sequence mapping output index -> class name.
        calorie_dict: Mapping class name -> calories.

    Returns:
        ``(class_name, confidence_percent, calories)``, where ``calories`` is
        ``"Unknown"`` when the class is missing from ``calorie_dict``.
    """
    if isinstance(image_path, (str, os.PathLike)):
        # The context manager closes the file deterministically; convert()
        # returns a new, fully loaded image that remains valid afterwards.
        with Image.open(image_path) as opened:
            image = opened.convert("RGB")
    else:
        image = image_path.convert("RGB")

    input_tensor = _PREDICT_TRANSFORM(image).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        logits = model(input_tensor)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
    # Pick the top class from the logits, then report its softmax probability.
    class_idx = int(torch.argmax(logits, dim=1).item())
    class_name = class_names[class_idx]
    calories = calorie_dict.get(class_name, "Unknown")
    confidence = probabilities[0, class_idx].item() * 100
    return class_name, confidence, calories

# --- 5. OpenAI API Client Setup ---
# Read the key from the environment rather than hard-coding a secret in the
# notebook (set OPENAI_API_KEY before running this cell). The placeholder
# fallback keeps client construction from failing when the variable is unset;
# API calls will then fail, and get_gpt_summary turns that into a message.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "sk-XXXXXXXXXXXX"))

# --- 6. Accumulated Predictions ---
# Running log updated by the UI callback below.
# NOTE(review): this is module-level state, so every visitor of the launched
# app (share=True) reads and appends to the same history and total. Per-user
# tracking would need gr.State; confirm whether that matters for this app.
prediction_history = []   # one formatted entry per analysed image
total_calories = 0        # sum of the numeric calorie estimates so far
gpt_response = ""         # latest text returned by get_gpt_summary


def gradio_predict_accumulate(image_file, goal):
    """Gradio callback: classify one upload, log it, and refresh the GPT advice.

    Args:
        image_file: Image from the upload widget, or ``None`` when nothing was
            uploaded (the history is then returned unchanged).
        goal: Goal string selected in the dropdown.

    Returns:
        ``(history_text, status_text)`` for the two output textboxes.
    """
    global total_calories, gpt_response
    if image_file is None:
        return "\n".join(prediction_history), f"Total intake so far: {total_calories} kcal\nPlease upload an image."

    class_name, confidence, estimated_calories = predict_image(image_file, model, class_names, calorie_dict)
    # Only real numbers count toward the total. numbers.Real also accepts NumPy
    # scalars (np.int64 is not an int subclass, so an (int, float) check would
    # skip it); bools and NaN are rejected so they cannot corrupt the sum.
    # Non-numeric lookups such as "Unknown" contribute nothing.
    if (
        isinstance(estimated_calories, numbers.Real)
        and not isinstance(estimated_calories, bool)
        and not math.isnan(estimated_calories)
    ):
        total_calories += estimated_calories

    entry = f" **{class_name}** ({confidence:.2f}%) — `{estimated_calories} kcal`\n🎯 Goal: `{goal}`"
    prediction_history.append(entry)

    # Regenerate the advice after every prediction so it covers the full history.
    gpt_response = get_gpt_summary(goal)
    full_output = "\n".join(prediction_history) + "\n\n GPT Suggestion:\n" + gpt_response
    return full_output, f"Total intake so far: {total_calories} kcal"

def get_gpt_summary(goal):
    """Ask the chat model for one combined recommendation over all logged items.

    Reads the module-level ``prediction_history``, ``total_calories`` and
    ``client``. This function never raises: any failure (network, auth, quota,
    missing globals, ...) is returned as a readable message, so the UI can show
    it in place of the advice.

    Args:
        goal: The user's selected diet goal.

    Returns:
        The model's suggestion text, or a notice/error message.
    """
    try:
        items_summary = "\n".join(prediction_history)
        prompt = f"""You are a friendly nutritionist. Based on the following:
{items_summary}

Goal: {goal}\nTotal calories: {total_calories} kcal
Please provide one combined dietary recommendation."""

        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are a helpful nutrition coach."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.7,
            max_tokens=300,
        )
        content = response.choices[0].message.content
        # message.content is Optional in the API (e.g. on refusals). Calling
        # .strip() on None used to surface as a misleading "check your quota" error.
        if not content:
            return "( GPT returned no text.)\n\nPlease try again later."
        return content.strip()
    except Exception as e:
        # Deliberate UI boundary: report the problem instead of crashing the callback.
        return f"( GPT error: {e})\n\nPlease try again later or check your quota."

# --- 7. Launch Gradio Interface ---
# Inputs: the diet goal and the food photo (delivered to the callback as a PIL image).
goal_dropdown = gr.Dropdown(label="Select Your Goal", choices=["fat loss", "muscle gain", "maintain"], value="fat loss")
food_image_input = gr.Image(type="pil", label="Upload Food Image")
# Outputs receive, in order, the (history, status) pair the callback returns.
cumulative_results_box = gr.Textbox(label="Cumulative Results")
status_box = gr.Textbox(label="Status")

interface_options = dict(
    title="Food-101 + GPT: Multi-Image Tracker",
    description="Upload multiple food images to accumulate predictions and automatically get smart GPT advice at the bottom.",
    # NOTE(review): newer Gradio releases renamed this option to
    # flagging_mode="never"; confirm against the installed version.
    allow_flagging="never",
    css="footer {visibility: hidden}",
)
iface = gr.Interface(
    fn=gradio_predict_accumulate,
    inputs=[food_image_input, goal_dropdown],
    outputs=[cumulative_results_box, status_box],
    **interface_options,
)
# debug=True keeps the cell running and prints callback errors in Colab;
# share=True also creates a public link to the app.
iface.launch(debug=True, share=True)
