In [None]:
!pip install transformers torch pillow requests

import os

import requests
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import ViTImageProcessor

In [None]:
# Load the pretrained food image classification model from the Hugging Face Hub.
model_name = "nateraw/food"
# ViTImageProcessor is the supported replacement for the deprecated
# ViTFeatureExtractor (which only wraps it and emits a deprecation warning).
# The variable keeps the name `extractor` because the prediction cell uses it.
extractor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)

In [4]:
# USDA / Open Food Facts API setup.
# The FoodData Central key is read from the USDA_API_KEY environment variable
# instead of being hard-coded, so it is not leaked whenever the notebook is
# shared or committed. Without it we fall back to api.data.gov's public
# DEMO_KEY, which works but is heavily rate-limited.
# NOTE(review): the key that used to be hard-coded here has been exposed and
# should be regenerated.
API_KEY = os.environ.get("USDA_API_KEY", "").strip() or "DEMO_KEY"
if API_KEY == "DEMO_KEY":
    print("USDA_API_KEY not set; using the rate-limited DEMO_KEY.")
USDA_URL = "https://api.nal.usda.gov/fdc/v1/foods/search"
OPENFOOD_URL = "https://world.openfoodfacts.org/api/v2/search"

In [5]:
# Rewrite the classifier's raw label (e.g. "apple_pie") into a search string
# worded the way USDA food descriptions tend to be (e.g. "pie, apple").
def clean_query(q):
    """Make model label more USDA-friendly.

    Lower-cases the label, turns underscores into spaces and trims surrounding
    whitespace, then swaps a few known labels for USDA-style wording.
    Labels without a special mapping are returned in their normalized form.
    """
    usda_wording = {
        "spaghetti bolognese": "spaghetti with meat sauce",
        "grilled cheese sandwich": "grilled cheese",
        "chicken curry": "chicken, curry",
        "apple pie": "pie, apple",
        "fried rice": "rice, fried",
        "hamburger": "beef patty",
        "ice cream": "ice cream, vanilla",
    }
    normalized = q.lower().replace("_", " ").strip()
    if normalized in usda_wording:
        return usda_wording[normalized]
    return normalized

# Look up the (cleaned) food name in USDA FoodData Central and keep the top hit.
def search_usda(query, timeout=10):
    """Search the USDA FoodData Central database.

    Args:
        query: free-text food name to search for.
        timeout: seconds to wait for the HTTP response before giving up
            (without one, a stalled connection hangs the notebook forever).

    Returns:
        ``{"source": "USDA", "name": ..., "nutrients": [...]}`` for the first
        hit, or ``None`` when there is no hit or the request fails (network
        error, HTTP error status, non-JSON body), so the caller can fall back
        to another source.
    """
    params = {
        "api_key": API_KEY,
        "query": query,
        "dataType": ["Foundation", "Survey (FNDDS)"],
        "pageSize": 1,
    }
    try:
        r = requests.get(USDA_URL, params=params, timeout=timeout)
    except requests.RequestException as exc:
        print(f"USDA request failed: {exc}")
        return None
    # requests does not raise on 4xx/5xx; report them (e.g. a rejected or
    # rate-limited API key) instead of silently treating them as "no match".
    if not r.ok:
        print(f"USDA returned HTTP {r.status_code}")
        return None
    try:
        data = r.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        print("USDA returned a non-JSON response")
        return None

    foods = data.get("foods") if isinstance(data, dict) else None
    if not foods:
        return None
    food = foods[0]
    return {
        "source": "USDA",
        "name": food.get("description") or query,
        "nutrients": food.get("foodNutrients", []),
    }

# Fallback source: when USDA has no match for the predicted food, search Open Food Facts.
def search_openfoodfacts(query, timeout=10):
    """Fallback: use Open Food Facts if USDA has no match.

    Uses the v1 full-text search endpoint (``/cgi/search.pl``). The v2
    ``/api/v2/search`` endpoint does not honour ``search_terms``, so
    querying it effectively ignores ``query`` and returns an unrelated
    product.

    Args:
        query: free-text food name to search for.
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        ``{"source": "OpenFoodFacts", "name": ..., "nutrients": {...}}`` for
        the first product found, or ``None`` when nothing matches or the
        request fails (network error, HTTP error status, non-JSON body).
    """
    url = "https://world.openfoodfacts.org/cgi/search.pl"
    params = {
        "search_terms": query,
        "search_simple": 1,
        "action": "process",
        "json": 1,  # without this the v1 endpoint answers with an HTML page
        "fields": "product_name,nutriments",
        "page_size": 1,
    }
    # Open Food Facts asks API clients to identify themselves with a custom User-Agent.
    headers = {"User-Agent": "food-nutrition-notebook/1.0"}
    try:
        r = requests.get(url, params=params, headers=headers, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Open Food Facts request failed: {exc}")
        return None
    if not r.ok:
        print(f"Open Food Facts returned HTTP {r.status_code}")
        return None
    try:
        data = r.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        print("Open Food Facts returned a non-JSON response")
        return None

    products = data.get("products") if isinstance(data, dict) else None
    if not products:
        return None
    prod = products[0]
    return {
        "source": "OpenFoodFacts",
        # `or` also covers products whose name is present but empty
        "name": prod.get("product_name") or "Unknown product",
        "nutrients": prod.get("nutriments") or {},
    }

# Resolve a predicted class label to nutrient data: USDA with the full query,
# then USDA with just its first word, then Open Food Facts.
def get_food_data(pred_class):
    """Unified lookup: USDA → fallback to Open Food Facts.

    Args:
        pred_class: raw label from the image classifier (e.g. "apple_pie").

    Returns:
        The first successful result dict from ``search_usda`` or
        ``search_openfoodfacts``, or ``None`` if no source matched or the
        label is empty.
    """
    query = clean_query(pred_class)
    words = query.split()
    if not words:
        # Guard: split()[0] on an empty/blank label would raise IndexError.
        print("Empty food label; nothing to search.")
        return None
    print(f"\nSearching for: {query}")

    result = search_usda(query)
    if not result:
        # Broaden to the first word. Strip the comma left by "noun, modifier"
        # mappings ("pie, apple" -> "pie", not "pie,").
        short_query = words[0].strip(",")
        if short_query and short_query != query:
            print(f"No USDA match for '{query}'. Retrying with '{short_query}'...")
            result = search_usda(short_query)
        else:
            # Single-word query: retrying would repeat the identical request.
            print(f"No USDA match for '{query}'.")

    if not result:
        print("Trying Open Food Facts...")
        result = search_openfoodfacts(query)

    if result:
        print(f"\n {result['source']} match: {result['name']}")
        return result
    print("No food found in any database.")
    return None

In [None]:
# Upload an image from your computer (Colab only).
from google.colab import files
uploaded = files.upload()

# Guard against an empty upload (no file chosen); indexing [0] would
# otherwise fail with a cryptic IndexError.
if not uploaded:
    raise ValueError("No image uploaded - re-run this cell and choose a food photo.")
# If several files are uploaded, only the first one is classified.
image_path = next(iter(uploaded))
# Open via a context manager so the file handle is closed; convert() returns
# a new image that stays usable afterwards.
with Image.open(image_path) as img:
    image = img.convert("RGB")

# Predict food label
inputs = extractor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
preds = torch.nn.functional.softmax(outputs.logits, dim=-1)

# Single image -> row 0; report the probability so a shaky guess is visible.
confidence, class_idx = preds[0].max(dim=-1)
pred_class = model.config.id2label[class_idx.item()]
print(f"\nDetected food: {pred_class} ({confidence.item():.1%} confidence)")

In [None]:
# Turn the predicted label into nutrient data (USDA first, Open Food Facts as fallback).
food_info = get_food_data(pred_class=pred_class)

In [None]:
import pandas as pd

# Show the headline macronutrients for whichever database matched, using the
# same Nutrient / Amount / Unit columns for both sources.
if food_info and food_info["source"] == "USDA":
    # USDA nutrients: a list of dicts with "nutrientName", "value", "unitName".
    key_nutrients = ["Energy", "Protein", "Total lipid (fat)", "Carbohydrate, by difference"]  # listed nutrients
    rows = [
        # .get(): an entry missing a value/unit should not crash the whole table
        {"Nutrient": n.get("nutrientName"), "Amount": n.get("value"), "Unit": n.get("unitName")}
        for n in food_info["nutrients"]
        if n.get("nutrientName") in key_nutrients
    ]
    if rows:
        display(pd.DataFrame(rows))
    else:
        print("No key nutrients found.")
elif food_info and food_info["source"] == "OpenFoodFacts":
    # Open Food Facts nutriments: one flat dict with many keys (per 100 g,
    # per serving, units, ...). Pick the same four macros as the USDA branch,
    # per 100 g, instead of dumping every key.
    nutr = food_info["nutrients"] or {}
    off_key_nutrients = [
        ("energy-kcal_100g", "Energy", "kcal/100 g"),
        ("proteins_100g", "Protein", "g/100 g"),
        ("fat_100g", "Total lipid (fat)", "g/100 g"),
        ("carbohydrates_100g", "Carbohydrate", "g/100 g"),
    ]
    rows = [
        {"Nutrient": label, "Amount": nutr[key], "Unit": unit}
        for key, label, unit in off_key_nutrients
        if key in nutr
    ]
    if rows:
        display(pd.DataFrame(rows))
    else:
        print("No key nutrients found.")
else:
    print("No nutrient data to display.")