<a href="https://colab.research.google.com/github/Aradhyakapil/Food-label-analyzer/blob/main/food_label_analyzer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so later cells can read /content/drive/MyDrive/ingredients.csv
# and save the fine-tuned adapter under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Install required libraries
# NOTE(review): no versions are pinned, so behaviour can change as packages are
# updated; pin versions (e.g. transformers==X.Y.Z) for reproducible runs.
!pip install transformers
!pip install peft
!pip install datasets
!pip install pandas
!pip install torch
!pip install easyocr
!pip install streamlit
!pip install opencv-python
!pip install matplotlib
!pip install Pillow
!pip install numpy
!pip install accelerate
!pip install bitsandbytes
!pip install sentencepiece protobuf scipy

# pyngrok is not a Streamlit dependency; presumably it is used to tunnel the
# Streamlit server out of Colab (not shown in this section).
!pip install pyngrok

In [None]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    default_data_collator
)
from peft import get_peft_model, LoraConfig, TaskType
from datasets import Dataset
import pandas as pd
import torch
import numpy as np
import os
import gc
import logging
import json
import matplotlib.pyplot as plt
from transformers.trainer_callback import TrainerCallback

# -----------------------------------------
# Setup logging & clear GPU cache
# -----------------------------------------
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Collect Python garbage and release cached CUDA blocks, presumably left over from
# earlier runs of this notebook, before the large model is loaded.
gc.collect()
torch.cuda.empty_cache()

# Log total / allocated / reserved memory of GPU 0 (skipped when CUDA is unavailable).
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    logger.info(
        f"GPU memory before loading: {props.total_memory/1e9:.2f} GB total, "
        f"{torch.cuda.memory_allocated()/1e9:.2f} GB allocated, "
        f"{torch.cuda.memory_reserved()/1e9:.2f} GB reserved"
    )

# -----------------------------------------
# 1) Load CSV data
# -----------------------------------------
# The training CSV is read from the mounted Google Drive.
PATH = "/content/drive/MyDrive/ingredients.csv"
if not os.path.exists(PATH):
    raise FileNotFoundError(f"{PATH} not found")
df = pd.read_csv(PATH)
# Strip stray whitespace from header names so lookups such as row['category']
# in create_prompt_response() find their columns.
df.columns = df.columns.str.strip()
logger.info(f"Loaded {len(df)} rows from CSV")

# -----------------------------------------
# 2) Build prompt/response pairs
# -----------------------------------------
def create_prompt_response(row):
    """Turn one CSV row into an ``{"input": prompt, "output": response}`` training pair.

    The response starts with a JSON object holding the general risk level, effects
    and a fixed recommendation. Plain-text risk lines for the diabetic,
    hypertension and pregnancy profiles follow it.

    Args:
        row: mapping (e.g. a pandas row) providing ingredient_name, category,
            gi_value, carbs_per_100g, sodium_per_100g and the general/diabetic/
            hypertension/pregnancy ``*_risk_level`` / ``*_risk_reason`` columns.

    Returns:
        dict with "input" (prompt text) and "output" (target text).
    """
    def _json_str(value):
        # Emit a proper JSON string literal so quotes, backslashes or newlines in
        # the CSV text cannot corrupt the JSON the model is trained to produce.
        # For text without such characters this equals wrapping it in "...".
        return json.dumps(str(value), ensure_ascii=False)

    prompt = (
        f"Ingredient: {row['ingredient_name']}, Category: {row['category']}, "
        f"Glycemic Index: {row['gi_value']}, Carbs per 100g: {row['carbs_per_100g']}, "
        f"Sodium per 100g: {row['sodium_per_100g']}, Health Concerns: Diabetic Risk, Hypertension Risk, Pregnancy Risk\n"
        "Analyze this ingredient for health profiles with potential concerns."
    )
    response = (
        "{\n"
        f'  "risk_level": {_json_str(row["general_risk_level"])},\n'
        f'  "effects": [{_json_str(row["general_risk_reason"])}],\n'
        '  "recommendations": ["Monitor consumption based on your health profile."]\n'
        "}\n\n"
        f"For diabetics - Risk Level: {row['diabetic_risk_level']}\n"
        f"Effects: {row['diabetic_risk_reason']}\n"
        f"For hypertension - Risk Level: {row['hypertension_risk_level']}\n"
        f"Effects: {row['hypertension_risk_reason']}\n"
        f"For pregnancy - Risk Level: {row['pregnancy_risk_level']}\n"
        f"Effects: {row['pregnancy_risk_reason']}"
    )
    return {'input': prompt, 'output': response}

# One {"input", "output"} dict per CSV row -> a Hugging Face Dataset with those two columns.
dataset = Dataset.from_list(df.apply(create_prompt_response, axis=1).tolist())
logger.info(f"Prepared {len(dataset)} examples")

# -----------------------------------------
# 3) Load & quantize model
# -----------------------------------------
MODEL_NAME = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# The tokenizer has no pad token of its own, so reuse EOS for fixed-length padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 weight quantization with double quantization; computation in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# device_map="auto" lets accelerate place the layers on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
)

# -----------------------------------------
# 4) Apply LoRA + enable grads on inputs
# -----------------------------------------
# LoRA adapters (rank 16, alpha 32, dropout 0.1) on the listed modules; bias terms
# are not trained (bias="none").
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    # NOTE(review): assumes these module names exist in the Phi-2 implementation
    # being loaded — confirm against model.named_modules().
    target_modules=["q_proj", "k_proj", "v_proj", "dense"],
    bias="none",
)

# NOTE(review): PEFT's usual recipe for quantized (k-bit) models calls
# peft.prepare_model_for_kbit_training(model) before get_peft_model; it is not
# used here — consider adding it.
model = get_peft_model(model, lora_config)

# Disable the KV cache: it only helps generation and is incompatible with
# gradient checkpointing during training.
model.config.use_cache = False

# Make embedding outputs require grad so that, with gradient checkpointing,
# gradients can flow through the frozen (quantized) layers to the LoRA weights.
model.enable_input_require_grads()

# Enable gradient checkpointing (recompute activations to save memory).
model.gradient_checkpointing_enable()

# NOTE(review): Trainer takes label names from TrainingArguments(label_names=...),
# not from model.config, so this assignment most likely has no effect.
model.config.label_names = ["labels"]

# Sanity check: log trainable vs. total parameter counts.
model.print_trainable_parameters()

# -----------------------------------------
# 5) Tokenize
# -----------------------------------------
def preprocess(examples, max_length=512, tok=None):
    """Tokenize a batch of prompt/response pairs for causal-LM fine-tuning.

    Each pair is rendered as ``"User: <input>\\n\\nAssistant: <output><eos>"`` and
    padded/truncated to ``max_length`` tokens. ``labels`` copies ``input_ids``
    except that padding positions are set to -100, so the loss ignores them.

    Args:
        examples: batch dict with parallel "input" and "output" lists, as passed
            by ``Dataset.map(..., batched=True)``.
        max_length: sequence length to pad/truncate to.
        tok: tokenizer to use; defaults to the module-level ``tokenizer``.

    Returns:
        The tokenizer's encoding with an added "labels" entry.
    """
    tok = tokenizer if tok is None else tok
    # Append EOS so the model learns where a reply ends. Padding is masked out of
    # the loss below, so this appended EOS is the stop signal it is trained on.
    eos = getattr(tok, "eos_token", None) or ""
    prompts = [
        f"User: {inp}\n\nAssistant: {out}{eos}"
        for inp, out in zip(examples["input"], examples["output"])
    ]
    enc = tok(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    # pad_token is aliased to eos_token in this notebook, so padding is identified
    # by the attention mask rather than by token id (which would also hit the EOS).
    enc["labels"] = [
        [token_id if attended else -100 for token_id, attended in zip(ids, mask)]
        for ids, mask in zip(enc["input_ids"], enc["attention_mask"])
    ]
    return enc

# Tokenize in batches and drop the raw "input"/"output" text columns, so only the
# fields returned by preprocess() remain for the Trainer.
tokenized = dataset.map(
    preprocess,
    batched=True,
    remove_columns=dataset.column_names
)

# -----------------------------------------
# 6) Define Metrics Callback
# -----------------------------------------
class MetricsCallback(TrainerCallback):
    """Record step-level training metrics and aggregate them per epoch.

    Attributes:
        train_losses: ``(global_step, loss)`` for every log event carrying "loss".
        train_perplexities: ``(global_step, exp(loss))`` for the same events.
        learning_rates: ``(global_step, learning_rate)`` for every log event carrying it.
        epoch_metrics: ``{epoch: {"loss", "perplexity", "step"}}`` keyed by 1-based
            epoch number; "loss" is the mean of the losses logged during that epoch.
        current_epoch: 1-based number of the epoch in progress (0 before training).
    """

    def __init__(self):
        self.train_losses = []
        self.train_perplexities = []
        self.learning_rates = []
        self.epoch_metrics = {}
        self.current_epoch = 0
        # Index into train_losses where the running epoch's entries begin.
        self._epoch_start = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Store loss, perplexity and learning rate from a Trainer log event."""
        if logs is None:
            return
        step = state.global_step
        if "loss" in logs:
            loss = logs["loss"]
            self.train_losses.append((step, loss))
            self.train_perplexities.append((step, np.exp(loss)))
        if "learning_rate" in logs:
            self.learning_rates.append((step, logs["learning_rate"]))

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Open a new epoch window."""
        # state.epoch holds the (fractional) number of epochs completed so far,
        # about 0.0 before the first epoch and about 1.0 before the second, so the
        # epoch starting now is round(state.epoch) + 1.
        self.current_epoch = int(round(state.epoch or 0)) + 1
        # Attribute losses by position rather than by step arithmetic, so the
        # window holds exactly the logs emitted while this epoch ran.
        self._epoch_start = len(self.train_losses)
        logger.info(f"Epoch {self.current_epoch} start")

    def on_epoch_end(self, args, state, control, **kwargs):
        """Average the losses logged during the epoch that just finished."""
        losses = [loss for _step, loss in self.train_losses[self._epoch_start:]]
        if not losses:
            # No log step fell inside this epoch (e.g. fewer steps than logging_steps).
            return
        avg = float(np.mean(losses))
        perplexity = float(np.exp(avg))
        self.epoch_metrics[self.current_epoch] = {
            "loss": avg,
            "perplexity": perplexity,
            "step": state.global_step,
        }
        logger.info(
            f"Epoch {self.current_epoch} end — loss: {avg:.4f}, "
            f"perplexity: {perplexity:.4f}"
        )

# -----------------------------------------
# 7) Trainer Setup
# -----------------------------------------
output_dir = "/content/drive/MyDrive/fine_tuned_phi2_peft"
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=5,
    per_device_train_batch_size=2,
    # 2 x 8 = effective batch of 16 examples per device per optimizer step.
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    weight_decay=0.01,
    save_strategy="epoch",
    save_total_limit=2,
    logging_dir=f"{output_dir}/logs",
    logging_steps=10,
    fp16=True,
    gradient_checkpointing=True,
    remove_unused_columns=False,
    report_to="tensorboard",
    # NOTE(review): with a small CSV, 100 warmup steps can exceed the total number
    # of optimizer steps; consider warmup_ratio instead.
    warmup_steps=100,
    lr_scheduler_type="cosine",
    optim="adamw_torch",
    # Name the label field explicitly: Trainer reads label names from its
    # arguments and may not infer them through the PeftModel wrapper.
    label_names=["labels"],
)

# Keep a reference to the callback so the metrics it collects can be handed to
# the plotting/export helpers after training.
metrics_callback = MetricsCallback()

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    data_collator=default_data_collator,
    callbacks=[metrics_callback],
)

# -----------------------------------------
# 8) Train & Save
# -----------------------------------------
logger.info("Starting training...")
train_result = trainer.train()

logger.info("Training complete; saving model & tokenizer")
# On a PEFT model, save_pretrained writes the LoRA adapter (weights + config),
# not the full base model. The Streamlit app reloads it from "<output_dir>/model"
# on top of the base model.
model.save_pretrained(f"{output_dir}/model")
tokenizer.save_pretrained(f"{output_dir}/tokenizer")
# NOTE(review): train_result is not used afterwards, and the plotting/export
# helpers defined below are not invoked anywhere in the code shown.


# -----------------------------------------
# (Optional) 9) Visualization & Evaluation Helpers
# -----------------------------------------
def plot_training_curves(metrics_callback, save_dir=None):
    """Save step-wise training curves to ``<save_dir>/training_curves.png``.

    Args:
        metrics_callback: object exposing ``train_losses``, ``train_perplexities``
            and ``learning_rates`` as lists of ``(step, value)`` pairs
            (e.g. a MetricsCallback).
        save_dir: output directory; defaults to the module-level ``output_dir``.
    """
    save_dir = output_dir if save_dir is None else save_dir
    panels = [
        (metrics_callback.train_losses, 'Loss', 'Training Loss'),
        (metrics_callback.train_perplexities, 'Perplexity', 'Training Perplexity'),
        (metrics_callback.learning_rates, 'LR', 'Learning Rate'),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    for ax, (points, ylabel, title) in zip(axes.flat, panels):
        if points:
            steps, values = zip(*points)
            ax.plot(steps, values)
        ax.set_xlabel('Step')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True)
    # The LR panel is shown only when learning rates were logged; the fourth cell
    # of the 2x2 grid is unused.
    if not metrics_callback.learning_rates:
        axes[1, 0].axis('off')
    axes[1, 1].axis('off')
    fig.tight_layout()
    fig.savefig(f"{save_dir}/training_curves.png")
    plt.close(fig)

def plot_epoch_metrics(metrics_callback, save_dir=None):
    """Save per-epoch loss and perplexity plots to ``<save_dir>/epoch_metrics.png``.

    Does nothing when ``metrics_callback.epoch_metrics`` is empty.

    Args:
        metrics_callback: object whose ``epoch_metrics`` maps epoch -> dict with
            "loss" and "perplexity" entries (e.g. a MetricsCallback).
        save_dir: output directory; defaults to the module-level ``output_dir``.
    """
    if not metrics_callback.epoch_metrics:
        return
    save_dir = output_dir if save_dir is None else save_dir
    epochs = sorted(metrics_callback.epoch_metrics)
    panels = [
        ("loss", 'Loss', 'Loss per Epoch'),
        ("perplexity", 'Perplexity', 'PPL per Epoch'),
    ]
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    for ax, (key, ylabel, title) in zip(axes, panels):
        values = [metrics_callback.epoch_metrics[e][key] for e in epochs]
        ax.plot(epochs, values, marker='o')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True)
    fig.tight_layout()
    fig.savefig(f"{save_dir}/epoch_metrics.png")
    plt.close(fig)

def export_epoch_metrics(metrics_callback, save_dir=None):
    """Write per-epoch metrics to ``epoch_metrics.csv`` and ``epoch_metrics.json``.

    Does nothing when ``metrics_callback.epoch_metrics`` is empty.

    Args:
        metrics_callback: object whose ``epoch_metrics`` maps epoch -> dict with
            "loss", "perplexity" and "step" entries (e.g. a MetricsCallback).
        save_dir: output directory; defaults to the module-level ``output_dir``.
    """
    if not metrics_callback.epoch_metrics:
        return
    save_dir = output_dir if save_dir is None else save_dir
    rows = [
        {'epoch': epoch, 'loss': m['loss'], 'perplexity': m['perplexity'], 'step': m['step']}
        for epoch, m in metrics_callback.epoch_metrics.items()
    ]
    pd.DataFrame(rows).to_csv(f"{save_dir}/epoch_metrics.csv", index=False)
    # JSON object keys must be strings, so json.dump stringifies the epoch keys.
    with open(f"{save_dir}/epoch_metrics.json", "w", encoding="utf-8") as f:
        json.dump(metrics_callback.epoch_metrics, f, indent=2)
    logger.info("Exported epoch metrics to CSV and JSON in %s", save_dir)


In [None]:
%%writefile app.py

import streamlit as st
import os
import re
import json
import time
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import easyocr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import gc
import peft
from peft import PeftModel
from peft import PeftModelForCausalLM, PeftConfig

# Set page configuration
st.set_page_config(
    page_title="AR Food Label Health Analyzer",
    page_icon="🍞",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Title and introduction
st.title("AR Food Label Health Analyzer")
st.markdown("""
This app analyzes food label ingredients based on your health profile to assess potential risks.
Upload an image of a food label to get started. The AR overlay will highlight ingredients based on risk levels.
""")

# ----------------------- Configuration -----------------------
# The sidebar widgets below define module-level settings (DEBUG_MODE,
# MODEL_TEMPERATURE, OVERLAY_OPACITY, HIGHLIGHT_MODE, DEVICE, ...) that the
# functions further down read as globals.
# Sidebar for profile selection
st.sidebar.header("User Profile")
USER_PROFILE = st.sidebar.selectbox(
    "Select your health profile:",
    ["general", "diabetic", "hypertension", "celiac", "keto", "pregnancy"]
)

# API keys input (for future expansion)
# NOTE(review): GOOGLE_API_KEY and GOOGLE_CX are not used by any code visible here.
st.sidebar.header("API Configuration")
GOOGLE_API_KEY = st.sidebar.text_input("Google API Key", value="", type="password")
GOOGLE_CX = st.sidebar.text_input("Google Custom Search Engine ID", value="")

# Model configuration
st.sidebar.header("Model Configuration")
# Directory holding the "model" (PEFT adapter) and "tokenizer" subfolders that
# the training notebook saves.
MODEL_PATH = st.sidebar.text_input("Fine-tuned Model Path", value="/content/drive/MyDrive/fine_tuned_phi2_peft")
# Sampling temperature passed to generation in analyze_ingredient_risk().
MODEL_TEMPERATURE = st.sidebar.slider("Model temperature", min_value=0.1, max_value=1.0, value=0.3, step=0.1)
# NOTE(review): MAX_TOKEN_LENGTH is not referenced by the code visible here.
MAX_TOKEN_LENGTH = 512

# Add debug mode toggle
DEBUG_MODE = st.sidebar.checkbox("Enable Debug Mode", value=False)

# AR Configuration
st.sidebar.header("AR Overlay Settings")
SHOW_AR_OVERLAY = st.sidebar.checkbox("Enable AR Overlay", value=True)
# Alpha (0.1-1.0) applied to the overlay's fill colours.
OVERLAY_OPACITY = st.sidebar.slider("Overlay Opacity", min_value=0.1, max_value=1.0, value=0.5, step=0.1)
HIGHLIGHT_MODE = st.sidebar.selectbox(
    "Highlight Method",
    ["Box Highlight", "Text Highlight", "Connected Labels"]
)

# Device detection
# Device index flag: 0 = first CUDA GPU, -1 = CPU.
DEVICE = 0 if torch.cuda.is_available() else -1
st.sidebar.write(f"Using device: {'GPU' if DEVICE == 0 else 'CPU'}")

# ----------------------- Health Profile Definitions -----------------------
# Per-profile keyword lists that create_risk_assessment_prompt() interpolates into
# the model prompt:
#   concerns - health aspects the analysis should focus on ("with concerns: ...")
#   avoid    - terms listed under "Items to avoid"
#   monitor  - terms listed under "Items to monitor"
# "general" is the fallback for profiles not found here.
# NOTE(review): the fine-tuning data only carries diabetic, hypertension and
# pregnancy columns, so "celiac" and "keto" rely on the prompt text alone.
HEALTH_PROFILES = {
    "diabetic": {
        "concerns": ["blood sugar impact", "glycemic index", "carbohydrate content", "sugar content"],
        "avoid": ["sugars", "high fructose corn syrup", "white flour", "dextrose"],
        "monitor": ["carbohydrates", "starches", "maltodextrin"]
    },
    "hypertension": {
        "concerns": ["sodium content", "blood pressure impact", "vasoactive compounds"],
        "avoid": ["sodium", "salt", "MSG", "sodium phosphate", "sodium benzoate"],
        "monitor": ["potassium sorbate", "preservatives", "nitrates", "nitrites"]
    },
    "celiac": {
        "concerns": ["gluten content", "cross-contamination", "grain derivatives"],
        "avoid": ["wheat", "barley", "rye", "malt", "brewer's yeast", "triticale"],
        "monitor": ["oats", "modified food starch", "dextrin", "maltodextrin"]
    },
    "keto": {
        "concerns": ["carbohydrate content", "sugar alcohols", "net carbs"],
        "avoid": ["sugars", "starches", "flours", "corn syrup"],
        "monitor": ["sugar alcohols", "fiber", "artificial sweeteners"]
    },
    "general": {
        "concerns": ["additives", "preservatives", "artificial colors", "ultra-processed ingredients"],
        "avoid": [],
        "monitor": ["artificial colors", "artificial flavors", "preservatives", "high fructose corn syrup"]
    },
    "pregnancy": {
        "concerns": ["fetal development", "toxins and contaminants", "nutrient absorption", "blood pressure impact", "blood sugar impact"],
        "avoid": ["alcohol", "caffeine", "nitrates", "nitrites", "licorice root", "saccharin", "cyclamate"],
        "monitor": ["sodium", "sugar", "artificial sweeteners", "preservatives", "food colorings", "MSG"]
    }
}

# ----------------------- Model Loading -----------------------
@st.cache_resource
def _load_model_and_tokenizer(model_path):
    """Load the PEFT adapter in ``<model_path>/model`` and its tokenizer.

    Cached by Streamlit per ``model_path``. Errors propagate instead of being
    returned, so a failed load is not cached and is retried on the next rerun.
    Because of the cache, the debug messages below only appear on an uncached load.
    """
    default_base_model = "microsoft/phi-2"

    adapter_path = os.path.join(model_path, "model")
    # Check before PeftConfig reads the directory so a wrong path gives a clear message.
    if not os.path.exists(adapter_path):
        raise FileNotFoundError(f"PEFT adapter not found at {adapter_path}")

    peft_config = PeftConfig.from_pretrained(adapter_path)
    if DEBUG_MODE:
        st.sidebar.write(f"✅ PEFT config loaded: {peft_config.base_model_name_or_path}")
    # An adapter only fits the base model it was trained on; use the one it records.
    base_model_name = peft_config.base_model_name_or_path or default_base_model

    tokenizer_path = os.path.join(model_path, "tokenizer")
    if os.path.exists(tokenizer_path):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        if DEBUG_MODE:
            st.sidebar.write("✅ Tokenizer loaded successfully")
    else:
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        if DEBUG_MODE:
            st.sidebar.write("⚠️ Using base model tokenizer")

    # Ensure a pad token is set (the training notebook also aliases it to EOS).
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # float16 is meant for GPU execution; use float32 when running on CPU.
    dtype = torch.float16 if DEVICE == 0 else torch.float32
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=dtype,
        device_map="auto"
    )

    model = PeftModel.from_pretrained(base_model, adapter_path)
    model.eval()  # Set to evaluation mode

    if DEBUG_MODE:
        st.sidebar.write("✅ Model loaded successfully")
    return model, tokenizer


def load_model(model_path):
    """Return ``(model, tokenizer)`` for the fine-tuned adapter, or ``(None, None)``.

    Load failures are reported in the UI rather than raised. Successful loads are
    cached per ``model_path``; failures are not, so fixing the path (or mounting
    Drive) and rerunning retries the load.
    """
    try:
        return _load_model_and_tokenizer(model_path)
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        if DEBUG_MODE:
            st.sidebar.write(f"❌ Model loading error: {str(e)}")
        return None, None

# ----------------------- Functions -----------------------
def generate_with_model(model, tokenizer, prompt, max_new_tokens=150, temperature=0.3):
    """Sample a reply from the fine-tuned model for a single user prompt.

    The prompt is wrapped in the same "User: ... Assistant:" template used for the
    fine-tuning data. Only the newly generated tokens are decoded, so the prompt
    text is never echoed back to the caller.

    Args:
        model: causal LM (here the PEFT-wrapped Phi-2) exposing ``generate``/``device``.
        tokenizer: matching tokenizer.
        prompt: user message text.
        max_new_tokens: generation length cap.
        temperature: sampling temperature (top-p 0.9 sampling is always used).

    Returns:
        The assistant reply, or "Error generating response" if anything fails.
    """
    try:
        # Same template as training ("User: {inp}\n\nAssistant: {out}"). The text
        # stops right after "Assistant:" so the model emits the leading space token
        # itself, as it saw during training.
        input_text = f"User: {prompt}\n\nAssistant:"
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )

        # For decoder-only models, generate() returns prompt + continuation. Decoding
        # the whole sequence would pass the prompt (which itself lists "SAFE, LOW,
        # MODERATE, HIGH") to the risk parser.
        assistant_response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        # Keep only the first assistant turn if the model starts a new "User:" turn.
        assistant_response = assistant_response.split("\nUser:", 1)[0].strip()

        if DEBUG_MODE:
            st.sidebar.write(f"✅ Generated response length: {len(assistant_response)}")

        return assistant_response
    except Exception as e:
        if DEBUG_MODE:
            st.sidebar.write(f"❌ Generation error: {str(e)}")
        return "Error generating response"

def create_risk_assessment_prompt(ingredient, user_profile, profiles=None):
    """Build the analysis prompt for one ingredient and health profile.

    The first line mirrors the single-line header of the fine-tuning inputs. The
    following lines add the profile's concerns / avoid / monitor lists and the
    requested output fields. Lines carry no leading indentation.

    Args:
        ingredient: ingredient name as parsed from the label.
        user_profile: profile name; looked up case-insensitively, falling back to
            the "general" entry when unknown.
        profiles: profile table with "concerns"/"avoid"/"monitor" lists per
            profile; defaults to ``HEALTH_PROFILES``.

    Returns:
        The prompt string.
    """
    table = HEALTH_PROFILES if profiles is None else profiles
    profile_data = table.get(user_profile.lower(), table["general"])

    def _listing(items):
        # An empty list (e.g. "general" has nothing to avoid) would otherwise
        # leave a dangling "Items to avoid: ".
        return ", ".join(items) if items else "none"

    lines = [
        f"Ingredient: {ingredient}, Category: FOOD_INGREDIENT, "
        "Glycemic Index: UNKNOWN, Carbs per 100g: UNKNOWN, "
        "Sodium per 100g: UNKNOWN, Health Concerns: Diabetic Risk, Hypertension Risk, Pregnancy Risk",
        f"Analyze this ingredient for {user_profile} health profile with concerns: {_listing(profile_data['concerns'])}",
        f"Items to avoid: {_listing(profile_data['avoid'])}",
        f"Items to monitor: {_listing(profile_data['monitor'])}",
        "Provide a risk assessment with risk_level (SAFE, LOW, MODERATE, HIGH, or UNKNOWN), "
        "effects, and recommendations.",
    ]
    return "\n".join(lines)

_RISK_LEVEL_ALTS = r"(SAFE|LOW|MODERATE|HIGH|UNKNOWN)"
_VALID_RISK_LEVELS = {"SAFE", "LOW", "MODERATE", "HIGH", "UNKNOWN"}


def _first_json_object(text):
    """Return the first JSON object (dict) embedded in *text*, or None."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(text, start)
        except ValueError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)
    return None


def _as_text_list(value):
    """Normalise a JSON "effects"/"recommendations" value to non-empty strings."""
    items = value if isinstance(value, list) else [value]
    return [str(item).strip() for item in items if item is not None and str(item).strip()]


def _first_capture(patterns, text, flags):
    """Return group(1) of the first pattern whose capture is non-blank, else None.

    Leftover JSON punctuation ("[", "]", quotes, commas) around the capture is trimmed.
    """
    for pattern in patterns:
        match = re.search(pattern, text, flags)
        if match:
            value = match.group(1).strip(' \t"[],')
            if value:
                return value
    return None


def parse_model_response(response_text):
    """Extract a structured risk assessment from free-form model output.

    Sources are tried in order:
      1. the JSON object the model was fine-tuned to emit first;
      2. labelled "risk level" text;
      3. keyword heuristics.
    All regexes anchor on word boundaries, so e.g. "unsafe", "allow" or "highly"
    do not count as SAFE / LOW / HIGH.

    Returns:
        dict with "risk_level" (SAFE, LOW, MODERATE, HIGH or UNKNOWN), "effects"
        (list of str) and "recommendations" (list of str). Unparseable input
        yields UNKNOWN with placeholder texts.
    """
    try:
        if DEBUG_MODE:
            st.sidebar.write(f"Parsing response: {response_text[:100]}...")

        data = _first_json_object(response_text) or {}

        risk_level = str(data.get("risk_level", "")).strip().upper()
        if risk_level not in _VALID_RISK_LEVELS:
            risk_level = "UNKNOWN"

        if risk_level == "UNKNOWN":
            risk_patterns = [
                # "Risk Level: HIGH", "risk_level": "HIGH", "Risk: LOW"
                rf'\brisk(?:[_ ]level)?"?\s*:\s*"?{_RISK_LEVEL_ALTS}\b',
                # "risk ... HIGH" / "risky ... LOW" on the same line
                rf'\brisk\w*(?: level)?\b.*?\b{_RISK_LEVEL_ALTS}\b',
                # "HIGH risk"
                rf'\b{_RISK_LEVEL_ALTS}\s*risk\b',
            ]
            for pattern in risk_patterns:
                risk_match = re.search(pattern, response_text, re.IGNORECASE)
                if risk_match:
                    risk_level = risk_match.group(1).upper()
                    break

        # Fallback risk assessment based on keywords if no explicit risk level found
        if risk_level == "UNKNOWN":
            keyword_rules = (
                ("HIGH", r'\b(?:high risk|very dangerous|avoid|harmful|toxic)'),
                ("MODERATE", r'\b(?:moderate risk|some concern|caution|careful)'),
                ("LOW", r'\b(?:low risk|minimal concern|generally safe)'),
                ("SAFE", r'\b(?:safe|no concern|healthy|beneficial)'),
            )
            for level, pattern in keyword_rules:
                if re.search(pattern, response_text, re.IGNORECASE):
                    risk_level = level
                    break

        text_flags = re.IGNORECASE | re.DOTALL

        effects = _as_text_list(data.get("effects"))
        if not effects:
            effect_text = _first_capture(
                [
                    # "For"/"Recommend" end the capture only when capitalised
                    # (section starts), not for every "for" inside a sentence.
                    r'\beffects"?\s*:\s*(.+?)(?:\.|\n|$|(?-i:For|Recommend))',
                    r'\b(?:impact|effects of)\b.*?(?:include|are|is)?\s*(.+?)(?:\.|\n|$)',
                ],
                response_text,
                text_flags,
            )
            effects = [effect_text] if effect_text else ["Effects not specified in response"]

        recommendations = _as_text_list(data.get("recommendations"))
        if not recommendations:
            rec_text = _first_capture(
                [
                    r'\brecommendations"?\s*:\s*(.+?)(?:\.|\n|$)',
                    r'\b(?:recommend\w*|advised|should)\b\s*(.+?)(?:\.|\n|$)',
                ],
                response_text,
                text_flags,
            )
            recommendations = [rec_text] if rec_text else ["Consult with healthcare provider"]

        if DEBUG_MODE:
            st.sidebar.write(f"Parsed risk level: {risk_level}")

        return {
            "risk_level": risk_level,
            "effects": effects,
            "recommendations": recommendations
        }
    except Exception as e:
        if DEBUG_MODE:
            st.sidebar.write(f"❌ Parsing error: {str(e)}")
        return {
            "risk_level": "UNKNOWN",
            "effects": ["Unable to determine effects"],
            "recommendations": ["Consult with healthcare provider"]
        }

# Fallback ratings, used only when the model output yields no risk level. Keys are
# tested in insertion order as substrings of the lower-cased ingredient name, so
# more specific keys ("wheat flour") must come before general ones ("wheat").
_COMMON_INGREDIENT_RISKS = {
    "wheat flour": {"diabetic": "MODERATE", "celiac": "HIGH", "general": "LOW"},
    "sugar": {"diabetic": "HIGH", "keto": "HIGH", "general": "MODERATE"},
    "salt": {"hypertension": "HIGH", "general": "LOW"},
    "wheat": {"celiac": "HIGH", "general": "LOW"},
    "yeast": {"general": "SAFE"},
    "vegetable oil": {"general": "LOW", "keto": "MODERATE"},
    "palm oil": {"general": "MODERATE"},
    "preservative": {"general": "MODERATE", "pregnancy": "MODERATE"},
    "emulsifier": {"general": "LOW"},
}


def analyze_ingredient_risk(model, tokenizer, ingredient, user_profile):
    """Assess one ingredient for a health profile.

    Strategy:
      1. prompt the fine-tuned model with the structured prompt;
      2. if no risk level can be parsed, retry with a plainer question;
      3. if it is still unknown, fall back to ``_COMMON_INGREDIENT_RISKS``.

    Returns:
        ``(risk_level, analysis)``. ``analysis`` always has "risk_level",
        "effects" and "recommendations" keys, also on the error path.
    """
    try:
        prompt = create_risk_assessment_prompt(ingredient, user_profile)

        if DEBUG_MODE:
            st.sidebar.write(f"Analyzing ingredient: {ingredient}")

        generated_text = generate_with_model(
            model,
            tokenizer,
            prompt,
            max_new_tokens=150,
            temperature=MODEL_TEMPERATURE
        )
        analysis = parse_model_response(generated_text)

        if analysis["risk_level"] == "UNKNOWN":
            fallback_prompt = f"Is {ingredient} safe or risky for someone with {user_profile} health concerns? Respond with a risk level (SAFE, LOW, MODERATE, HIGH) and explain why."
            generated_text = generate_with_model(
                model,
                tokenizer,
                fallback_prompt,
                max_new_tokens=150,
                temperature=MODEL_TEMPERATURE
            )
            analysis = parse_model_response(generated_text)

        if analysis["risk_level"] == "UNKNOWN":
            ingredient_lower = ingredient.lower()
            profile = user_profile.lower()
            for common_ing, risk_map in _COMMON_INGREDIENT_RISKS.items():
                if common_ing in ingredient_lower:
                    analysis["risk_level"] = risk_map.get(profile, risk_map.get("general", "UNKNOWN"))
                    analysis["effects"] = [f"Common {common_ing} effects for {user_profile} profile"]
                    analysis["recommendations"] = [f"Standard recommendations for {common_ing}"]
                    break

        return analysis["risk_level"], analysis
    except Exception as e:
        if DEBUG_MODE:
            st.sidebar.write(f"❌ Analysis error: {str(e)}")
        # Same dict shape as the success path so callers can rely on every key.
        return "UNKNOWN", {
            "risk_level": "UNKNOWN",
            "effects": ["Error analyzing ingredient"],
            "recommendations": ["Consult healthcare provider"],
        }

def evaluate_overall_product_risk(assessments, user_profile):
    """Combine per-ingredient risks into an overall product rating.

    Each ingredient's risk is weighted (HIGH 4, MODERATE 2, LOW 1, UNKNOWN 1,
    SAFE 0). The mean weight is normalised to [0, 1] by the HIGH weight and
    compared against profile-specific thresholds (falling back to "general").

    Args:
        assessments: ``{ingredient: (risk_level, details)}``. Risk labels are
            matched case-insensitively; unrecognised labels count as UNKNOWN.
        user_profile: health profile name (case-insensitive).

    Returns:
        ``(overall_risk, [recommendation])``.
    """
    if not assessments:
        return "UNKNOWN", ["Unable to assess without ingredients"]

    risk_weights = {
        "HIGH": 4,
        "MODERATE": 2,
        "LOW": 1,
        "SAFE": 0,
        "UNKNOWN": 1
    }
    # [LOW, MODERATE, HIGH] lower bounds on the normalised score, per profile.
    thresholds = {
        "diabetic": [0.4, 0.6, 0.8],
        "hypertension": [0.3, 0.5, 0.7],
        "celiac": [0.2, 0.4, 0.6],
        "keto": [0.5, 0.7, 0.85],
        "general": [0.25, 0.5, 0.75],
        "pregnancy": [0.3, 0.5, 0.7]
    }

    # Unrecognised labels are weighted like UNKNOWN instead of raising KeyError.
    total_weight = sum(
        risk_weights.get(str(risk).upper(), risk_weights["UNKNOWN"])
        for risk, _details in assessments.values()
    )
    normalized = total_weight / (len(assessments) * risk_weights["HIGH"])

    low_thr, moderate_thr, high_thr = thresholds.get(user_profile.lower(), thresholds["general"])

    if normalized >= high_thr:
        overall = "HIGH"
        recs = ["This product poses significant risks based on your health profile. Consider alternatives."]
    elif normalized >= moderate_thr:
        overall = "MODERATE"
        recs = ["This product contains some concerning ingredients. Consume in moderation and monitor your response."]
    elif normalized >= low_thr:
        overall = "LOW"
        recs = ["This product appears relatively safe but contains ingredients to monitor."]
    else:
        overall = "SAFE"
        recs = ["This product appears safe based on your health profile."]

    return overall, recs

# -------------- OCR and AR OVERLAY FUNCTIONS ------------------
@st.cache_resource
def _get_ocr_reader(use_gpu):
    """Create the English EasyOCR reader once and reuse it across reruns.

    Constructing a Reader loads its detection/recognition models, which is too
    slow to repeat for every uploaded image.
    """
    return easyocr.Reader(['en'], gpu=use_gpu)


def extract_text_from_image(image):
    """Run EasyOCR on an image and return its raw results.

    Args:
        image: a file path, a PIL image, or an array accepted by ``np.array``.

    Returns:
        EasyOCR ``readtext`` output: a list of ``(bbox, text, confidence)`` entries
        (full results, bounding boxes included).
    """
    reader = _get_ocr_reader(DEVICE == 0)
    if isinstance(image, str):  # If image is a path
        return reader.readtext(image)
    if isinstance(image, Image.Image) and image.mode != "RGB":
        # Palette ("P") uploads would otherwise become 2-D arrays of palette
        # indices, presumably read by EasyOCR as grey levels; normalise to RGB.
        image = image.convert("RGB")
    return reader.readtext(np.array(image))

# Section headings that typically follow the ingredient list on a label.
_INGREDIENT_CUTOFF_RE = re.compile(
    r'\b(?:allergen|nutrition|warning|storage|manufactured|may contain)',
    re.IGNORECASE,
)


def _split_outside_brackets(text, sep):
    """Split *text* on *sep*, ignoring separators nested inside (...) or [...].

    Returns the stripped, non-empty pieces.
    """
    parts, current, depth = [], [], 0
    for ch in text:
        if ch in "([":
            depth += 1
        elif ch in ")]" and depth:
            depth -= 1
        if ch == sep and depth == 0:
            parts.append("".join(current))
            current = []
        else:
            current.append(ch)
    parts.append("".join(current))
    return [part.strip() for part in parts if part.strip()]


def parse_ingredients(ocr_result):
    """Extract the list of ingredient names from EasyOCR results.

    All detected text is joined with spaces. The part after "Ingredients:" is
    taken (or everything, if there is no such heading) and cut at the earliest
    following section heading (allergen, nutrition, warning, ...). It is then split
    on the first of ',', '.', ';', '•' that yields several items, ignoring
    separators inside parentheses/brackets so sub-ingredient lists stay intact.

    Args:
        ocr_result: iterable of ``(bbox, text, confidence)`` tuples.

    Returns:
        List of cleaned ingredient strings (possibly empty).
    """
    full_text = " ".join(detection[1] for detection in ocr_result)

    match = re.search(r'ingredients?:?\s*(.*)', full_text, re.IGNORECASE | re.DOTALL)
    ingredients_str = match.group(1) if match else full_text

    # Cut at whichever heading appears first in the text.
    cutoff = _INGREDIENT_CUTOFF_RE.search(ingredients_str)
    if cutoff:
        ingredients_str = ingredients_str[:cutoff.start()]
    ingredients_str = ingredients_str.strip()

    items = []
    for sep in (',', '.', ';', '•'):
        if sep in ingredients_str:
            pieces = _split_outside_brackets(ingredients_str, sep)
            if len(pieces) > 1:
                items = pieces
                break

    # If no separators found, try to use the full string
    if not items and ingredients_str:
        items = [ingredients_str]

    cleaned_results = []
    for item in items:
        # Drop list numbering ("1. ", "2) ") and punctuation left at the end by the split.
        name = re.sub(r'^\d+[\.\)]\s*', '', item).strip().rstrip('.;:,').strip()
        if name:
            cleaned_results.append(name)
    return cleaned_results

def _ocr_text_matches(text_lower, ingredient_lower):
    """Return True if an OCR fragment and an ingredient name look like the same text.

    Both arguments are expected lower-cased. Blank strings never match: an empty
    fragment is a substring of every ingredient and would otherwise match all of them.
    """
    if not text_lower.strip() or not ingredient_lower.strip():
        return False
    if ingredient_lower in text_lower or text_lower in ingredient_lower:
        return True
    # Word overlap above 70%; both word sets are non-empty here, so no division by zero.
    ingredient_words = set(ingredient_lower.split())
    text_words = set(text_lower.split())
    overlap = len(ingredient_words & text_words)
    return overlap / max(len(ingredient_words), len(text_words)) > 0.7


def create_ar_overlay(img, ocr_result, ingredient_assessments):
    """Highlight OCR boxes that match assessed ingredients, coloured by risk level.

    Args:
        img: source image (PIL image or anything accepted by ``np.array``).
        ocr_result: EasyOCR ``readtext`` output of ``(bbox, text, confidence)``
            tuples, where ``bbox`` is a list of corner points.
        ingredient_assessments: ``{ingredient: (risk_level, details)}``.

    Returns:
        RGBA PIL image: the original with the overlay alpha-composited on top.
        The drawing style follows the global ``HIGHLIGHT_MODE`` and the fill
        alpha follows ``OVERLAY_OPACITY``.
    """
    img_pil = Image.fromarray(np.array(img))
    overlay = Image.new('RGBA', img_pil.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)

    alpha = int(255 * OVERLAY_OPACITY)
    risk_colors = {
        "HIGH": (255, 0, 0, alpha),         # Red
        "MODERATE": (255, 165, 0, alpha),   # Orange
        "LOW": (255, 255, 0, alpha),        # Yellow
        "SAFE": (0, 255, 0, alpha),         # Green
        "UNKNOWN": (128, 128, 128, alpha),  # Gray
    }

    try:
        font = ImageFont.truetype("arial.ttf", 15)
    except OSError:
        # arial.ttf is not installed (common on Linux); use PIL's built-in font.
        font = ImageFont.load_default()

    for bbox, text, _confidence in ocr_result:
        text_lower = text.lower()
        # Risk of the first assessed ingredient that matches this OCR box, if any.
        risk = next(
            (
                risk_level
                for ingredient, (risk_level, _details) in ingredient_assessments.items()
                if _ocr_text_matches(text_lower, ingredient.lower())
            ),
            None,
        )
        if risk is None:
            continue

        # Unrecognised risk labels are drawn in the UNKNOWN colour instead of raising.
        fill = risk_colors.get(risk, risk_colors["UNKNOWN"])
        solid = fill[:3] + (255,)

        # Axis-aligned rectangle around the (possibly rotated) OCR quadrilateral.
        box_points = np.array(bbox).astype(np.int32)
        x0, y0 = int(box_points[:, 0].min()), int(box_points[:, 1].min())
        x1, y1 = int(box_points[:, 0].max()), int(box_points[:, 1].max())

        if HIGHLIGHT_MODE == "Box Highlight":
            # Filled, semi-transparent box over the text.
            draw.rectangle([x0, y0, x1, y1], fill=fill)

        elif HIGHLIGHT_MODE == "Text Highlight":
            # Opaque outline with the risk label just above the box.
            draw.rectangle([x0, y0, x1, y1], outline=solid, width=2)
            draw.text((x0, y0 - 20), f"{risk}", fill=solid, font=font)

        elif HIGHLIGHT_MODE == "Connected Labels":
            # Leader line from the box's right edge to a label box further right.
            margin = 50
            label_x = x1 + 10
            label_y = (y0 + y1) // 2
            draw.line([(x1, label_y), (label_x + margin, label_y)], fill=solid, width=2)

            label_width = 80
            label_height = 40
            draw.rectangle(
                [label_x + margin, label_y - label_height // 2,
                 label_x + margin + label_width, label_y + label_height // 2],
                fill=fill)
            draw.text((label_x + margin + 5, label_y - 7), f"{risk}",
                      fill=(255, 255, 255, 255), font=font)

    return Image.alpha_composite(img_pil.convert('RGBA'), overlay)

def generate_risk_summary(ingredient_assessments, overall_risk, recommendations):
    """Generate a pandas DataFrame summarizing risk assessments.

    Args:
        ingredient_assessments: Mapping of ingredient name to a
            ``(risk_level, details)`` tuple. ``details`` is a dict whose
            optional ``"effects"`` and ``"recommendations"`` entries are
            iterables of strings.
        overall_risk: Not used here; accepted for call-site compatibility.
        recommendations: Not used here; accepted for call-site compatibility.

    Returns:
        A DataFrame with columns "Ingredient", "Risk Level", "Effects" and
        "Recommendations". It has one row per ingredient, in mapping order,
        with the effects and recommendations joined by "; ". An empty mapping
        yields an empty frame that still has these columns.
    """
    columns = ["Ingredient", "Risk Level", "Effects", "Recommendations"]

    # Collect every row first and build the frame once. Calling pd.concat
    # inside the loop copies the accumulated frame on each iteration
    # (quadratic), and concatenating onto an empty frame emits a pandas
    # FutureWarning.
    rows = [
        (
            ingredient,
            risk_level,
            "; ".join(details.get("effects", [])),
            "; ".join(details.get("recommendations", [])),
        )
        for ingredient, (risk_level, details) in ingredient_assessments.items()
    ]

    return pd.DataFrame(rows, columns=columns)


# ----------------------- Main Application Logic -----------------------
def main():
    """Render the food-label analyzer Streamlit UI.

    Loads the model and tokenizer via ``load_model(MODEL_PATH)``. If either is
    missing, it shows an error and returns. Otherwise it builds three tabs:

    * "Upload & Analyze": OCRs the uploaded image, parses the ingredient
      list, scores each ingredient with ``analyze_ingredient_risk`` and the
      whole product with ``evaluate_overall_product_risk``, then stores the
      results in ``st.session_state`` for the dashboard.
    * "Results Dashboard": renders the stored results. These are an
      overall-risk banner, recommendations, an optional AR overlay, a
      per-ingredient table, a risk-distribution pie chart and a CSV export.
    * "About": static description of the application.
    """
    # Load the model and tokenizer
    model, tokenizer = load_model(MODEL_PATH)

    if model is None or tokenizer is None:
        st.error("Failed to load the fine-tuned model. Please check the model path and try again.")
        return

    tabs = st.tabs(["Upload & Analyze", "Results Dashboard", "About"])

    with tabs[0]:
        st.header("Upload Food Label Image")
        uploaded_file = st.file_uploader("Choose an image of a food label...", type=["jpg", "jpeg", "png"])

        if uploaded_file is not None:
            # Process the uploaded image
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_column_width=True)

            with st.spinner("Processing image..."):
                # Get OCR results
                ocr_result = extract_text_from_image(image)

                # Extract ingredients list
                ingredients = parse_ingredients(ocr_result)

                # Display extracted ingredients
                st.subheader("Extracted Ingredients")
                for i, ingredient in enumerate(ingredients):
                    st.write(f"{i+1}. {ingredient}")

                # Maps ingredient -> (risk_level, details)
                ingredient_assessments = {}
                progress_bar = st.progress(0)

                for i, ingredient in enumerate(ingredients):
                    # Get risk assessment
                    risk_level, details = analyze_ingredient_risk(model, tokenizer, ingredient, USER_PROFILE)
                    ingredient_assessments[ingredient] = (risk_level, details)

                    # Update progress
                    progress_bar.progress((i + 1) / len(ingredients))
                    time.sleep(0.1)  # Small delay to prevent rate limiting

                # Evaluate overall product risk
                overall_risk, recommendations = evaluate_overall_product_risk(ingredient_assessments, USER_PROFILE)

                # Store results in session state for the dashboard
                st.session_state['ingredient_assessments'] = ingredient_assessments
                st.session_state['overall_risk'] = overall_risk
                st.session_state['recommendations'] = recommendations
                st.session_state['original_image'] = image
                st.session_state['ocr_result'] = ocr_result

                # Create AR overlay
                if SHOW_AR_OVERLAY:
                    ar_image = create_ar_overlay(image, ocr_result, ingredient_assessments)
                    st.session_state['ar_image'] = ar_image
                elif 'ar_image' in st.session_state:
                    # Drop an overlay produced for a previously analyzed image.
                    # Otherwise the dashboard would show it next to these
                    # unrelated results.
                    del st.session_state['ar_image']

                # Show success message
                st.success("Analysis complete! Check the Results Dashboard for details.")

    with tabs[1]:
        st.header("Analysis Results")

        if 'ingredient_assessments' in st.session_state:
            # Display overall risk rating
            overall_risk = st.session_state['overall_risk']
            recommendations = st.session_state['recommendations']

            # Display risk with appropriate styling
            risk_colors = {
                "HIGH": "#FF0000",
                "MODERATE": "#FFA500",
                "LOW": "#FFFF00",
                "SAFE": "#00FF00",
                "UNKNOWN": "#808080"
            }

            st.markdown(f"""
            <div style="padding: 20px; border-radius: 10px; background-color: {risk_colors.get(overall_risk, '#808080')};">
                <h2 style="color: white; text-align: center;">Overall Risk: {overall_risk}</h2>
            </div>
            """, unsafe_allow_html=True)

            # Display recommendations
            st.subheader("Recommendations")
            for rec in recommendations:
                st.markdown(f"- {rec}")

            # Display AR overlay
            if 'ar_image' in st.session_state:
                st.subheader("AR Health Analysis Overlay")
                st.image(st.session_state['ar_image'], caption="AR Overlay", use_column_width=True)

            # Display detailed ingredient analysis
            st.subheader("Detailed Ingredient Analysis")
            summary_df = generate_risk_summary(st.session_state['ingredient_assessments'], overall_risk, recommendations)
            st.dataframe(summary_df)

            # Add visualization
            st.subheader("Risk Distribution")
            risk_counts = {
                "HIGH": 0,
                "MODERATE": 0,
                "LOW": 0,
                "SAFE": 0,
                "UNKNOWN": 0
            }

            for risk, _ in st.session_state['ingredient_assessments'].values():
                # The risk label comes from the model and may fall outside the
                # known set. Count it as UNKNOWN instead of raising KeyError.
                risk_counts[risk if risk in risk_counts else "UNKNOWN"] += 1

            # Pie-chart inputs: only the risk levels that actually occur
            labels = [f"{k} ({v})" for k, v in risk_counts.items() if v > 0]
            sizes = [v for v in risk_counts.values() if v > 0]
            colors = [risk_colors[k] for k, v in risk_counts.items() if v > 0]

            if sizes:  # Only create pie chart if there are values
                fig, ax = plt.subplots()
                wedges, texts = ax.pie(sizes, colors=colors, startangle=90, wedgeprops={'alpha': 0.7})
                ax.legend(wedges, labels, title="Risk Levels", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1))
                st.pyplot(fig)
                # pyplot keeps every open figure alive. Streamlit re-executes
                # this function on each interaction, so close the figure once
                # it has been rendered.
                plt.close(fig)
            else:
                st.write("No risk data available to display.")

            # Export functionality
            st.download_button(
                label="Download Analysis as CSV",
                data=summary_df.to_csv(index=False),
                file_name="food_label_analysis.csv",
                mime="text/csv",
            )
        else:
            st.info("No analysis results yet. Please upload an image in the Upload & Analyze tab.")

    with tabs[2]:
        st.header("About AR Food Label Health Analyzer")
        st.markdown("""
        ### How It Works

        This application uses computer vision and AI to analyze food labels and provide health recommendations based on your specific health profile:

        1. **Upload**: Take a photo of any food product ingredient list
        2. **OCR Processing**: The system extracts text from the image
        3. **AI Analysis**: Each ingredient is evaluated against health criteria using our fine-tuned AI model
        4. **AR Overlay**: Visual indicators show risk levels directly on the image
        5. **Custom Recommendations**: Get personalized advice for your health profile

        ### Health Profiles

        The analyzer supports multiple health profiles:
        - **General**: Focus on additives and ultra-processed ingredients
        - **Diabetic**: Analyze glycemic impact and carbohydrate content
        - **Hypertension**: Evaluate sodium content and blood pressure impact
        - **Celiac**: Identify gluten risks and cross-contamination concerns
        - **Keto**: Assess carbohydrate content for ketogenic diet compatibility
        - **Pregnancy**: Identify ingredients with potential risks during pregnancy

        ### Privacy & Data

        Your images and analysis results are not stored after you close the session. All processing occurs locally within the application.
        """)

if __name__ == "__main__":
    # Script entry point: render the app only when this module is executed
    # directly, not when it is imported.
    main()


In [None]:
# SECURITY: never hard-code the ngrok auth token in the notebook. The token
# that used to be written literally in this cell is exposed and should be
# revoked/rotated in the ngrok dashboard. Supply it through the
# NGROK_AUTHTOKEN environment variable, or type it at the hidden prompt.
import os
from getpass import getpass

from pyngrok import ngrok

_ngrok_token = os.environ.get("NGROK_AUTHTOKEN") or getpass("Enter your ngrok auth token: ")
# Persist the token in the ngrok config, like `ngrok config add-authtoken` does.
ngrok.set_auth_token(_ngrok_token.strip())

In [None]:
from pyngrok import ngrok
import os
import threading

# Start the Streamlit app as a background process and expose it through an
# ngrok tunnel.
import socket
import subprocess
import time

# Kill previous tunnels
ngrok.kill()

# Set port for Streamlit
port = 8501

# Stop a Streamlit server left running by an earlier execution of this cell.
# Otherwise the new server would collide with it on the same port.
_previous_process = globals().get("streamlit_process")
if _previous_process is not None and _previous_process.poll() is None:
    _previous_process.terminate()
    try:
        _previous_process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        _previous_process.kill()

# Start Streamlit as a child process from an argument list (no shell) and
# keep the handle so the server can be polled and stopped later.
streamlit_process = subprocess.Popen(
    ["streamlit", "run", "app.py", "--server.port", str(port)]
)

# Wait until Streamlit actually accepts connections, instead of sleeping a
# fixed amount of time. Stop early if the process dies.
_deadline = time.monotonic() + 60
_ready = False
while time.monotonic() < _deadline and streamlit_process.poll() is None:
    try:
        with socket.create_connection(("127.0.0.1", port), timeout=1):
            _ready = True
            break
    except OSError:
        time.sleep(0.5)

if streamlit_process.poll() is not None:
    raise RuntimeError(
        f"Streamlit exited early with return code {streamlit_process.returncode}"
    )
if not _ready:
    print(f"⚠️ Streamlit is not listening on port {port} yet; the tunnel may fail until it starts.")

# Create the public tunnel
public_url = ngrok.connect(port)
# connect() returns an NgrokTunnel; print its URL rather than its repr.
print(f"✅ Streamlit app URL: {public_url.public_url}")
