In [2]:
!pip install pytesseract pdf2image pillow pdfplumber pandas





[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [5]:
import os
import re
import json
import pdfplumber
from pdf2image import convert_from_path
import pytesseract
from typing import Dict, Any, List


# ---- YOUR PDF PATH ----
PDF_PATH = r"C:\AINutriCare\Data\Raw\Reports\REPORT.pdf"
pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'


PARAMETERS = [
    "Glucose",
    "Insulin",
    "Creatinine",
    "Urea (BUN)",
    "Sodium",
    "Potassium",
    "Hemoglobin",
    "WBC",
    "Lactate",
    "pH",
    "Age",
    "Gender",
    "Cholestrol",
    "HbA1c",
]


# Updated patterns to match investigationlabreports.pdf layout[file:45]
# Updated patterns to handle OCR typos (e.g., 'l' vs '1', 'g/dl' vs 'g/d1')
PATTERN_MAP = {
    # Glucose: Matches "78 mg/dl"
    "Glucose": r"GLUCOSE[^\n]*?([0-9.]+)\s*mg/[dl1I]+",

    # Insulin: Not present in this report (Expected: None)
    "Insulin": r"(Insulin)\s+([0-9.]+)",

    # Creatinine: Matches "0.86 # mg/dl"
    "Creatinine": r"CREATININE[^\n]*?([0-9.]+)\s*#?\s*mg/[dl1I]+",

    # Urea: Matches "9.81 mg/dl"
    "Urea (BUN)": r"BUN[^\n]*?([0-9.]+)\s*mg/[dl1I]+",

    # Sodium: Matches "139.0 mmol/l" (Handles mmol/1, mmol/I)
    "Sodium": r"SODIUM[^\n]*?([0-9.]+)\s*mmol/[l1I|]+",

    # Potassium: Matches "4.01 mmol/l"
    "Potassium": r"POTASSIUM[^\n]*?([0-9.]+)\s*mmol/[l1I|]+",

    # Hemoglobin: Matches "15.1 g/dl" (Handles g/d1, g/dl)
    "Hemoglobin": r"[Hh]a?emoglobin[^\n]*?([0-9.]+)\s*g/[dl1I]+",

    # WBC: Matches "8800 /cu.mm" or "8800 /cmm"
    "WBC": r"WBC\s+Count.*?([0-9]+)\s*/(?:cu\.mm|cmm)",

    # Lactate: Not present in this report
    "Lactate": r"(Lactate)\s+([0-9.]+)",
    
    # pH: Not present in this report
    "pH": r"\bpH\b\s*([0-9.]+)",

    # Age: Captures number after "Age :"
    "Age": r"Age\s*:\s*([0-9]{1,3})",
    
    # Gender: Captures text after "Sex :" or "Gender :"
    "Gender": r"(?:Sex|Gender)\s*:\s*([A-Za-z]+)",

    # Cholesterol: Not present in this report
    "Cholestrol": r"(Cholesterol)\s+([0-9.]+)",
    
    # HbA1c: Not present in this report
    "HbA1c": r"HbA1c\s*.*?([0-9.]+)",
}



UNIT_MAP = {
    "Glucose": "mg/dL",
    "Insulin": "¬µIU/mL",
    "Creatinine": "mg/dL",
    "Urea (BUN)": "mg/dL",
    "Sodium": "mmol/L",
    "Potassium": "mmol/L",
    "Hemoglobin": "g/dL",
    "WBC": "/cmm",
    "Lactate": "mmol/L",
    "pH": "",
    "Age": "years",
    "Gender": "",
    "Cholestrol": "mg/dL",
    "HbA1c": "%",
}


def ocr_pdf_to_text(pdf_path: str) -> str:
    print(f"[INFO] Extracting text from: {pdf_path}")
    text_pages: List[str] = []

    # 1) Try direct text
    try:
        with pdfplumber.open(pdf_path) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text() or ""
                text_pages.append(page_text)
    except Exception as e:
        print(f"[WARN] pdfplumber failed: {e}")

    full_text = "\n".join(text_pages).strip()

    # 2) Fallback to OCR if too little text
    if len(full_text) < 200:
        print("[INFO] Direct text small; switching to OCR...")
        images = convert_from_path(pdf_path, dpi=300)
        ocr_pages = []
        for i, img in enumerate(images):
            page_text = pytesseract.image_to_string(img)
            ocr_pages.append(page_text)
            print(f"[INFO] OCR page {i+1}/{len(images)} done.")
        full_text = "\n".join(ocr_pages)

    return full_text


def extract_parameter(text: str, parameter: str) -> Dict[str, Any]:
    pattern = PATTERN_MAP.get(parameter)
    result = {
        "name": parameter,
        "value": None,
        "unit": UNIT_MAP.get(parameter, ""),
        "raw_match": "",
    }
    if not pattern:
        return result

    match = re.search(pattern, text, flags=re.IGNORECASE)
    if not match:
        return result

    value_str = match.groups()[-1]
    result["value"] = value_str.strip()
    result["raw_match"] = match.group(0).strip()
    return result

    # Always use the last captured group as the numeric/string value
    value_str = match.groups()[-1]
    result["value"] = value_str.strip()
    result["raw_match"] = match.group(0).strip()
    return result


def extract_all_parameters(text: str) -> Dict[str, Dict[str, Any]]:
    data = {}
    for param in PARAMETERS:
        data[param] = extract_parameter(text, param)
    return data


def save_to_json(data: Dict[str, Dict[str, Any]], out_json: str):
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
    print(f"[INFO] Saved vitals JSON to: {out_json}")


def run_extraction():
    out_json = "patient_vitals.json"
    full_text = ocr_pdf_to_text(PDF_PATH)

    extracted = extract_all_parameters(full_text)

    print("\n=== Extracted Parameters ===")
    for k, v in extracted.items():
        print(f"{k:12s}: {v.get('value')} {v.get('unit')}")

    save_to_json(extracted, out_json)
    return extracted


# In Jupyter, just run this cell:
if __name__ == "__main__":
    extracted_params = run_extraction()


[INFO] Extracting text from: C:\AINutriCare\Data\Raw\Reports\REPORT.pdf

=== Extracted Parameters ===
Glucose     : 157.07 mg/dL
Insulin     : None ¬µIU/mL
Creatinine  : 0.83 mg/dL
Urea (BUN)  : None mg/dL
Sodium      : 143.00 mmol/L
Potassium   : 4.90 mmol/L
Hemoglobin  : 14.5 g/dL
WBC         : 10570 /cmm
Lactate     : None mmol/L
pH          : 6.0 
Age         : None years
Gender      : Tube 
Cholestrol  : 189.0 mg/dL
HbA1c       : 7.10 %
[INFO] Saved vitals JSON to: patient_vitals.json


In [6]:
import numpy as np
import json
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import Layer
import os

# ==========================================
# 1. Configuration & Paths
# ==========================================
# Ensure these paths match your local setup
JSON_INPUT = "patient_vitals.json"
MODEL_PATH = r"C:\AINutriCare\Notebooks\Milestone_2\LSTM\attention_lstm.h5"
SCALER_PATH = r"C:\AINutriCare\Data\Transformed\X_final.npy"

# The 17 Features the Model was trained on (Must be in this exact order)
MODEL_FEATURES = [
    "Heart Rate", "MAP", "Respiratory Rate", "Temperature", 
    "Glucose", "Creatinine", "BUN", "Sodium", "Potassium", "Hemoglobin", "WBC", "Lactate",
    "Fluid Balance", "Vasopressors", "Sedatives", "Antibiotics", "Insulin"
]

# Defaults for values NOT in the PDF (Assumes resting/stable state for missing vitals)
DEFAULTS = {
    'Heart Rate': 75, 'MAP': 90, 'Respiratory Rate': 16, 'Temperature': 98.4,
    'Lactate': 1.0, 'Fluid Balance': 0, 'Vasopressors': 0, 'Sedatives': 0, 
    'Antibiotics': 0, 'Insulin': 0
}

# ==========================================
# 2. Define Custom Layer (Required to load model)
# ==========================================
@tf.keras.utils.register_keras_serializable()
class SimpleAttention(Layer):
    def __init__(self, units=64, **kwargs):
        super(SimpleAttention, self).__init__(**kwargs)
        self.units = units
    def get_config(self):
        config = super(SimpleAttention, self).get_config()
        config.update({"units": self.units})
        return config
    def build(self, input_shape):
        self.W1 = self.add_weight(name='att_w1', shape=(input_shape[-1], self.units), initializer='glorot_uniform')
        self.W2 = self.add_weight(name='att_w2', shape=(self.units, 1), initializer='glorot_uniform')
        self.b1 = self.add_weight(name='att_b1', shape=(self.units,), initializer='zeros')
        super(SimpleAttention, self).build(input_shape)
    def call(self, x):
        h = tf.nn.tanh(tf.matmul(x, self.W1) + self.b1)
        e = tf.squeeze(tf.matmul(h, self.W2), -1)
        alpha = tf.nn.softmax(e)
        context = x * tf.expand_dims(alpha, -1)
        context = tf.reduce_sum(context, axis=1)
        return context, alpha

# ==========================================
# 3. Load Resources
# ==========================================
def load_ai_resources():
    print("Loading AI Model & Scaler...")
    try:
        # 1. Load Model
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
        model = load_model(MODEL_PATH, custom_objects={'SimpleAttention': SimpleAttention})
        
        # 2. Load Scaler Statistics (Mean/Std from Training)
        if os.path.exists(SCALER_PATH):
            X_ref = np.load(SCALER_PATH)
            X_flat = X_ref.reshape(-1, X_ref.shape[2])
            means = np.mean(X_flat, axis=0)
            stds = np.std(X_flat, axis=0)
            stds[stds == 0] = 1.0 # Prevent divide by zero
        else:
            print("‚ö†Ô∏è Scaler file (X_final.npy) not found. Using raw unscaled values (Results may be inaccurate).")
            means = np.zeros(17)
            stds = np.ones(17)
            
        print("‚úÖ AI Resources Loaded Successfully.")
        return model, means, stds
        
    except Exception as e:
        print(f"‚ùå Error Loading AI Resources: {e}")
        return None, None, None

# ==========================================
# 4. Data Processing (JSON -> Tensor)
# ==========================================
def preprocess_patient_data(json_file, means, stds):
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    # helper to safely get float values
    def get_val(key):
        item = data.get(key)
        if isinstance(item, dict):
            val = item.get('value')
        else:
            val = item
            
        if val in [None, "N/A", "Not Found"]: return None
        try:
            # Clean string like "H 141.0" -> 141.0
            clean = str(val).replace('H', '').replace('L', '').strip()
            return float(clean)
        except:
            return None

    # 1. Build Feature Vector (17,)
    vector = []
    extracted_values = {} # For reporting
    
    for feature in MODEL_FEATURES:
        # Map Model Feature Names to JSON Keys
        json_key = feature
        if feature == "Cholesterol": json_key = "Cholestrol" # Handle typo in json if present
        
        val = get_val(json_key)
        
        # Fallback Logic
        if val is None:
            val = DEFAULTS.get(feature, 0)
        
        # Scaling fixes (e.g. WBC 10570 -> 10.57)
        if feature == "WBC" and val > 1000:
            val = val / 1000.0
            
        vector.append(val)
        extracted_values[feature] = val

    # 2. Create Time Series (Repeat static data for 24 hours)
    # Shape: (1, 24, 17)
    patient_matrix = np.tile(vector, (24, 1))
    
    # 3. Normalize
    normalized_matrix = (patient_matrix - means) / stds
    input_tensor = normalized_matrix.reshape(1, 24, 17)
    
    return input_tensor, extracted_values, data

# ==========================================
# 5. Diet Logic (Post-Prediction)
# ==========================================
def generate_clinical_report(risk_score, vitals, raw_json):
    print(f"\n{'='*60}")
    print(f" AI CLINICAL ANALYSIS REPORT")
    print(f"{'='*60}")
    
    # --- 1. Risk Interpretation ---
    if risk_score > 0.60:
        status = "HIGH RISK (Critical)"
        action = "Immediate Metabolic Intervention"
    elif risk_score > 0.30:
        status = "MODERATE RISK"
        action = "Dietary Management & Monitoring"
    else:
        status = "STABLE"
        action = "Routine Maintenance"
        
    print(f"\n[1] MODEL PREDICTION")
    print(f"    Mortality/ICU Risk: {risk_score:.2%}")
    print(f"    Clinical Status:    {status}")
    print(f"    Recommended Action: {action}")

    # --- 2. Key Drivers ---
    print(f"\n[2] BIOMARKER ANALYSIS")
    
    # Check Diabetes
    gluc = vitals['Glucose']
    hba1c_val = raw_json.get('HbA1c', {}).get('value', 'N/A')
    print(f"    - Glucose: {gluc} mg/dL", end="")
    if gluc > 140: print(" (HIGH - Driver for Risk)")
    else: print(" (Normal)")
    
    print(f"    - HbA1c:   {hba1c_val} %", end="")
    try:
        if float(str(hba1c_val).replace('H','')) > 6.5: print(" (DIABETIC RANGE)")
        else: print("")
    except: print("")

    # Check Renal
    creat = vitals['Creatinine']
    print(f"    - Creatinine: {creat} mg/dL", end="")
    if creat > 1.2: print(" (RENAL STRESS)")
    else: print(" (Normal)")

    # --- 3. Diet Plan ---
    print(f"\n[3] AI-GENERATED NUTRITION PLAN")
    
    if gluc > 126 or (hba1c_val != 'N/A' and float(str(hba1c_val).replace('H','')) > 6.5):
        print("    Protocol: DIABETIC / LOW-GLYCEMIC INDEX")
        print("    - Carbohydrates: Restricted to 40% of total calories.")
        print("    - Focus: Complex carbs only (Fiber > 30g/day).")
        print("    - Avoid: Fruit juices, white bread, processed sugars.")
    elif risk_score > 0.5:
        print("    Protocol: CRITICAL CARE SUPPORT (High Protein)")
        print("    - Focus: Preventing muscle wasting (Catabolism).")
    else:
        print("    Protocol: STANDARD BALANCED DIET")
        print("    - Maintain current nutritional intake.")

    print("-" * 60)

def generate_clinical_report(risk_score, vitals, raw_json):
    """
    Analyzes prediction & vitals to create a JSON for the LLM.
    """
    print(f"\nProcessing AI Clinical Analysis...")

    # --- 1. Initialize Structure for LLM ---
    llm_context = {
        "patient_metrics": {
            "mortality_risk": float(risk_score),  # JSON needs native float, not numpy
            "glucose": float(vitals['Glucose']),
            "creatinine": float(vitals['Creatinine'])
        },
        "conditions": [],
        "avoid": [],
        "recommend": [],
        "summary": ""
    }

    # --- 2. Risk Interpretation ---
    if risk_score > 0.60:
        llm_context['conditions'].append("Critical Stability Risk")
        llm_context['summary'] = "Patient is at HIGH RISK. Immediate metabolic intervention required."
    elif risk_score > 0.30:
        llm_context['conditions'].append("Moderate Clinical Risk")
        llm_context['summary'] = "Patient requires dietary management and monitoring."
    else:
        llm_context['summary'] = "Patient is stable. Routine maintenance diet recommended."

    # --- 3. Biomarker Analysis (Logic -> Rules) ---
    
    # Check Diabetes / Glucose
    glucose = vitals['Glucose']
    hba1c_val = raw_json.get('HbA1c', {}).get('value', 'N/A')
    
    is_diabetic = False
    if glucose > 126:
        is_diabetic = True
    # Handle H141 type strings if present in raw json
    if hba1c_val != 'N/A':
        try:
            val = float(str(hba1c_val).replace('H','').replace('L',''))
            if val > 6.5: is_diabetic = True
        except: pass

    if is_diabetic:
        llm_context['conditions'].append("Diabetes (Type 2 / Hyperglycemia)")
        llm_context['avoid'].extend(["Fruit juices", "White bread", "Processed sugars", "High-GI foods"])
        llm_context['recommend'].extend(["Complex carbohydrates", "High fiber foods (>30g/day)", "Leafy greens"])

    # Check Renal (Kidneys)
    creatinine = vitals['Creatinine']
    if creatinine > 1.2:
        llm_context['conditions'].append("Renal Stress / Kidney Strain")
        llm_context['avoid'].extend(["High sodium foods", "Excessive red meat", "Processed deli meats"])
        llm_context['recommend'].extend(["Low-potassium vegetables", "Cauliflower", "Berries"])

    # Check Hypertension (using MAP as proxy if BP not split)
    # MAP > 100 often correlates with high BP
    if vitals['MAP'] > 100: 
        llm_context['conditions'].append("Hypertension Risk")
        llm_context['avoid'].append("Salt/Sodium")
        llm_context['recommend'].append("DASH diet principles")

    # If no specific conditions found, add general healthy advice
    if not llm_context['conditions']:
        llm_context['conditions'].append("General Health Maintenance")
        llm_context['recommend'].append("Balanced diet with lean proteins and vegetables")

    return llm_context

# ==========================================
# 6. Main Execution Loop
# ==========================================
if __name__ == "__main__":
    # 1. Load Model
    model, means, stds = load_ai_resources()
    
    if model:
        # 2. Process Data
        if os.path.exists(JSON_INPUT):
            input_tensor, vitals_dict, raw_data = preprocess_patient_data(JSON_INPUT, means, stds)
            
            # 3. Predict
            print("Running LSTM Prediction...")
            prediction = model.predict(input_tensor, verbose=0)[0][0]
            
            # 4. Generate & Save Report
            ai_output = generate_clinical_report(prediction, vitals_dict, raw_data)
            
            # Output Filename
            OUTPUT_FILE = "clinical_output.json"
            
            with open(OUTPUT_FILE, 'w') as f:
                json.dump(ai_output, f, indent=4)
                
            print(f"‚úÖ Success! Analysis saved to: {OUTPUT_FILE}")
            print(json.dumps(ai_output, indent=2)) # Print preview
            
        else:
            print(f"‚ùå Error: {JSON_INPUT} not found. Run the extraction step first.")

  if not hasattr(np, "object"):


Loading AI Model & Scaler...





‚úÖ AI Resources Loaded Successfully.
Running LSTM Prediction...

Processing AI Clinical Analysis...
‚úÖ Success! Analysis saved to: clinical_output.json
{
  "patient_metrics": {
    "mortality_risk": 0.40944451093673706,
    "glucose": 157.07,
    "creatinine": 0.83
  },
  "conditions": [
    "Moderate Clinical Risk",
    "Diabetes (Type 2 / Hyperglycemia)"
  ],
  "avoid": [
    "Fruit juices",
    "White bread",
    "Processed sugars",
    "High-GI foods"
  ],
  "recommend": [
    "Complex carbohydrates",
    "High fiber foods (>30g/day)",
    "Leafy greens"
  ],
  "summary": "Patient requires dietary management and monitoring."
}


In [10]:
!pip install google-genai

Collecting google-genai
  Downloading google_genai-1.58.0-py3-none-any.whl.metadata (53 kB)
Collecting tenacity<9.2.0,>=8.2.3 (from google-genai)
  Downloading tenacity-9.1.2-py3-none-any.whl.metadata (1.2 kB)
Collecting websockets<15.1.0,>=13.0.0 (from google-genai)
  Using cached websockets-15.0.1-cp313-cp313-win_amd64.whl.metadata (7.0 kB)
Collecting distro<2,>=1.7.0 (from google-genai)
  Using cached distro-1.9.0-py3-none-any.whl.metadata (6.8 kB)
Collecting sniffio (from google-genai)
  Using cached sniffio-1.3.1-py3-none-any.whl.metadata (3.9 kB)
Downloading google_genai-1.58.0-py3-none-any.whl (718 kB)
   ---------------------------------------- 0.0/718.4 kB ? eta -:--:--
   ----------------------------- ---------- 524.3/718.4 kB 3.0 MB/s eta 0:00:01
   ---------------------------------------- 718.4/718.4 kB 3.0 MB/s  0:00:00
Using cached distro-1.9.0-py3-none-any.whl (20 kB)
Downloading tenacity-9.1.2-py3-none-any.whl (28 kB)
Using cached websockets-15.0.1-cp313-cp313-win_amd64


[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


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
AIzaSyD_lIMzf-pDJhSOZoEo8o5PtZUik6Yit6c

In [16]:
import json
import pandas as pd

TEST_CSV = "Diet_AI_Model_Test_Cases.csv"
AGG_OUTPUT = "diet_plans_by_testcase.json"

def run_tests_and_collect():
    df_cases = pd.read_csv(TEST_CSV)
    food_df = load_and_tag_data(FOOD_KB_FILE)

    results = []
    all_plans = {}  # <- aggregated per-test-case

    for _, row in df_cases.iterrows():
        tc_id = row["Test Case ID"]

        clinical_insights = build_clinical_from_row(row)
        candidates = get_smart_candidates(food_df, clinical_insights)
        patient = {"name": tc_id, "age": 50}

        plan = generate_structured_plan(patient, clinical_insights, candidates)
        if "error" in plan:
            results.append((tc_id, "FAIL", f"Generator error: {plan['error']}"))
            continue

        # store only the foods (you can store full plan if you prefer)
        all_plans[tc_id] = {
            "clinical_scenario": row["Clinical Scenario"],
            "input_labs": row["Input Labs"],
            "doctor_notes": row["Doctor Notes"],
            "must_not_include": row["Must NOT Include"],
            "day_plan": plan.get("day_plan", {}),
            "total_nutrition": plan.get("total_nutrition", {}),
        }

        ok, reason = check_forbidden(plan, row["Must NOT Include"])
        if not ok:
            results.append((tc_id, "FAIL", reason))
        else:
            results.append((tc_id, "PASS", row["Expected Outcome"]))

    # print summary
    for tc_id, status, msg in results:
        print(f"{tc_id}: {status} - {msg}")

    # save aggregated JSON
    with open(AGG_OUTPUT, "w") as f:
        json.dump(all_plans, f, indent=2)
    print(f"\nSaved per-test-case diet plans to {AGG_OUTPUT}")


if __name__ == "__main__":
    run_tests_and_collect()



ü§ñ AI is structuring the JSON plan...
ü§ñ AI is structuring the JSON plan...
ü§ñ AI is structuring the JSON plan...
ü§ñ AI is structuring the JSON plan...
ü§ñ AI is structuring the JSON plan...
ü§ñ AI is structuring the JSON plan...
ü§ñ AI is structuring the JSON plan...
ü§ñ AI is structuring the JSON plan...
ü§ñ AI is structuring the JSON plan...
ü§ñ AI is structuring the JSON plan...
TC-01: PASS - Glucose reduces, risk lowered
TC-02: PASS - Creatinine stabilizes
TC-03: PASS - Sodium decreases
TC-04: PASS - Overall risk reduced
TC-05: PASS - BP improves
TC-06: PASS - Hydration maintained
TC-07: PASS - Stable health
TC-08: PASS - Safe diet generated
TC-09: PASS - Consistent logic
TC-10: PASS - Pipeline stable

Saved per-test-case diet plans to diet_plans_by_testcase.json


In [15]:
import json
import os
import pandas as pd

# ==========================================
# 1. CONFIGURATION (GOOGLE GEMINI)
# ==========================================
from google import genai
from google.genai import types

GEMINI_API_KEY = "AIzaSyD_lIMzf-pDJhSOZoEo8o5PtZUik6Yit6c"  # or set env var GEMINI_API_KEY
MODEL_ID = "gemini-2.5-flash"  # choose your Gemini model

# File Paths
CLINICAL_INPUT = "clinical_output.json"
FOOD_KB_FILE = "diet_kb.json"

client = genai.Client(api_key=GEMINI_API_KEY)


# ==========================================
# 2. AUTO-TAGGING & FILTERING SYSTEM
# ==========================================
def load_and_tag_data(filepath):
    """
    Loads diet_kb.json and adds medical tags based on macros.
    """
    if not os.path.exists(filepath):
        return pd.DataFrame()  # Return empty if missing

    with open(filepath, "r") as f:
        data = json.load(f)

    df = pd.DataFrame(data)

    # --- AUTO-TAGGING LOGIC ---
    def get_tags(row):
        tags = []
        # Diabetic Friendly: Low Carb (<30g) OR Low Sugar (ingredients check)
        if row["Carbohydrate (g)"] < 30 and "sugar" not in str(row["ingredients"]).lower():
            tags.append("diabetic_friendly")
            tags.append("low_sugar")

        # High Protein: > 10g
        if row["Protein (g)"] > 10:
            tags.append("high_protein")

        # Low Fat: < 8g
        if row["Total Fat (g)"] < 8:
            tags.append("low_fat")

        # Renal Safe (Simplified): Moderate Protein (5‚Äì15g)
        if 5 < row["Protein (g)"] < 15:
            tags.append("renal_safe")

        return tags

    df["medical_tags"] = df.apply(get_tags, axis=1)
    return df


def get_smart_candidates(df, clinical_insights):
    """
    Filters the tagged dataframe to find the best candidates for the patient.
    """
    if df.empty:
        return {}

    # 1. Parse Constraints
    conditions = " ".join(clinical_insights.get("conditions", [])).lower()
    avoid = " ".join(clinical_insights.get("avoid", [])).lower()

    # 2. Apply Filters
    candidates = df.copy()

    # DIABETES FILTER
    if "diabetes" in conditions or "sugar" in avoid:
        candidates = candidates[candidates["medical_tags"].apply(lambda x: "diabetic_friendly" in x)]

    # RENAL FILTER
    if "renal" in conditions or "kidney" in conditions:
        candidates = candidates[candidates["medical_tags"].apply(lambda x: "renal_safe" in x)]

    # 3. Categorize
    breakfast_keywords = "idli|dosa|upma|poha|paratha|oats|porridge"
    breakfast_df = candidates[candidates["name"].str.contains(breakfast_keywords, case=False, na=False)]
    if len(breakfast_df) < 2:
        breakfast_df = candidates.sample(n=min(5, len(candidates)))

    # Lunch/Dinner: high calorie mains
    mains_df = candidates[candidates["Energy (kcal)"] > 150]

    # Snacks: low calorie
    snacks_df = candidates[candidates["Energy (kcal)"] < 150]

    def serialize(sub_df, count=5):
        return sub_df.sample(n=min(count, len(sub_df))).to_dict(orient="records")

    return {
        "breakfast": serialize(breakfast_df),
        "lunch": serialize(mains_df),
        "dinner": serialize(mains_df),
        "snacks": serialize(snacks_df),
    }


# ==========================================
# 3. LLM GENERATOR (Structure Enforcer ‚Äì GEMINI)
# ==========================================
def generate_structured_plan(patient_profile, clinical_data, food_candidates):
    summary = clinical_data.get("summary", "Healthy Diet")
    options_preview = json.dumps(food_candidates, indent=2)

    prompt = f"""
    You are an AI Clinical Dietitian.

    PATIENT: {patient_profile['name']}
    CONDITION: {summary}

    TASK:
    Select items from the PROVIDED CANDIDATE LIST below to create a 1-day meal plan.
    You must output the result in STRICT JSON format matching the user's required schema.

    CANDIDATE FOODS (Pick from these):
    {options_preview}

    REQUIRED OUTPUT FORMAT (JSON):
    {{
      "day_plan": {{
          "breakfast": [
              {{ "item": "Name", "calories": 100, "protein": 5, "fat": 2, "carbs": 10, "tags": ["tag1", "tag2"] }}
          ],
          "lunch": [],
          "dinner": [],
          "snacks": []
      }},
      "total_nutrition": {{
          "calories": 0,
          "protein": 0,
          "carbs": 0,
          "fat": 0
      }},
      "medical_reasoning": "Brief explanation..."
    }}

    RULES:
    1. Use the EXACT nutritional values from the candidate list. Do not invent numbers.
    2. Include the 'medical_tags' provided in the candidate list.
    3. Calculate the 'total_nutrition' sum correctly.
    4. Output JSON ONLY. No text before or after.
    """

    # JSON schema ‚Äì fully specified for Gemini
    schema = {
        "type": "object",
        "properties": {
            "day_plan": {
                "type": "object",
                "properties": {
                    "breakfast": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "item": {"type": "string"},
                                "calories": {"type": "number"},
                                "protein": {"type": "number"},
                                "fat": {"type": "number"},
                                "carbs": {"type": "number"},
                                "tags": {
                                    "type": "array",
                                    "items": {"type": "string"},
                                },
                            },
                            "required": [
                                "item",
                                "calories",
                                "protein",
                                "fat",
                                "carbs",
                                "tags",
                            ],
                        },
                    },
                    "lunch": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "item": {"type": "string"},
                                "calories": {"type": "number"},
                                "protein": {"type": "number"},
                                "fat": {"type": "number"},
                                "carbs": {"type": "number"},
                                "tags": {
                                    "type": "array",
                                    "items": {"type": "string"},
                                },
                            },
                            "required": [
                                "item",
                                "calories",
                                "protein",
                                "fat",
                                "carbs",
                                "tags",
                            ],
                        },
                    },
                    "dinner": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "item": {"type": "string"},
                                "calories": {"type": "number"},
                                "protein": {"type": "number"},
                                "fat": {"type": "number"},
                                "carbs": {"type": "number"},
                                "tags": {
                                    "type": "array",
                                    "items": {"type": "string"},
                                },
                            },
                            "required": [
                                "item",
                                "calories",
                                "protein",
                                "fat",
                                "carbs",
                                "tags",
                            ],
                        },
                    },
                    "snacks": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "item": {"type": "string"},
                                "calories": {"type": "number"},
                                "protein": {"type": "number"},
                                "fat": {"type": "number"},
                                "carbs": {"type": "number"},
                                "tags": {
                                    "type": "array",
                                    "items": {"type": "string"},
                                },
                            },
                            "required": [
                                "item",
                                "calories",
                                "protein",
                                "fat",
                                "carbs",
                                "tags",
                            ],
                        },
                    },
                },
                "required": ["breakfast", "lunch", "dinner", "snacks"],
            },
            "total_nutrition": {
                "type": "object",
                "properties": {
                    "calories": {"type": "number"},
                    "protein": {"type": "number"},
                    "fat": {"type": "number"},
                    "carbs": {"type": "number"},
                },
                "required": ["calories", "protein", "fat", "carbs"],
            },
            "medical_reasoning": {"type": "string"},
        },
        "required": ["day_plan", "total_nutrition", "medical_reasoning"],
    }

    try:
        print("ü§ñ AI is structuring the JSON plan...")
        resp = client.models.generate_content(
            model=MODEL_ID,
            contents=prompt,
            config=types.GenerateContentConfig(
                response_mime_type="application/json",
                response_schema=schema,
                temperature=0.1,
            ),
        )

        # In JSON mode, resp.parsed is already a Python dict
        if hasattr(resp, "parsed") and resp.parsed is not None:
            return resp.parsed

        raw_text = resp.text.strip()
        return json.loads(raw_text)

    except Exception as e:
        raw_text = ""
        try:
            raw_text = resp.text  # may exist if request reached the model
        except:
            pass
        return {"error": str(e), "raw_output": raw_text}


# ==========================================
# 4. ROUNDING HELPER
# ==========================================
def round_plan(plan, ndigits=0):
    for meal in ["breakfast", "lunch", "dinner", "snacks"]:
        for dish in plan.get("day_plan", {}).get(meal, []):
            for key in ["calories", "protein", "fat", "carbs"]:
                val = dish.get(key)
                if isinstance(val, (int, float)):
                    dish[key] = round(val, ndigits)

    tn = plan.get("total_nutrition", {})
    for key in ["calories", "protein", "fat", "carbs"]:
        val = tn.get(key)
        if isinstance(val, (int, float)):
            tn[key] = round(val, ndigits)
    return plan


# ==========================================
# 5. MAIN EXECUTION
# ==========================================
if __name__ == "__main__":

    print("Loading Food DB...")
    df = load_and_tag_data(FOOD_KB_FILE)

    # Clinical insights
    if os.path.exists(CLINICAL_INPUT):
        with open(CLINICAL_INPUT, "r") as f:
            clinical_insights = json.load(f)
    else:
        clinical_insights = {
            "conditions": ["Diabetes"],
            "summary": "Patient requires low-sugar, low-carb diet.",
        }

    candidates = get_smart_candidates(df, clinical_insights)

    patient = {"name": "Rajesh Kumar", "age": 45}
    final_json = generate_structured_plan(patient, clinical_insights, candidates)
    final_json = round_plan(final_json, ndigits=0)

    if "error" not in final_json:
        print("\n‚úÖ GENERATED JSON:")
        print(json.dumps(final_json, indent=2))

        with open("final_structured_diet.json", "w") as f:
            json.dump(final_json, f, indent=2)
            print("\nSaved to 'final_structured_diet.json'")
    else:
        print("‚ùå Error:", final_json["error"])


Loading Food DB...
ü§ñ AI is structuring the JSON plan...

‚úÖ GENERATED JSON:
{
  "day_plan": {
    "breakfast": [
      {
        "item": "Khandvi",
        "calories": 55.0,
        "protein": 13.0,
        "fat": 17.0,
        "carbs": 30.0,
        "tags": [
          "diabetic_friendly",
          "low_sugar",
          "high_protein",
          "renal_safe"
        ]
      }
    ],
    "lunch": [
      {
        "item": "Galho",
        "calories": 155.0,
        "protein": 13.0,
        "fat": 11.0,
        "carbs": 22.0,
        "tags": [
          "diabetic_friendly",
          "low_sugar",
          "high_protein",
          "renal_safe"
        ]
      }
    ],
    "dinner": [
      {
        "item": "Kolim Jawla",
        "calories": 215.0,
        "protein": 11.0,
        "fat": 29.0,
        "carbs": 12.0,
        "tags": [
          "diabetic_friendly",
          "low_sugar",
          "high_protein",
          "renal_safe"
        ]
      }
    ],
    "snacks": [
    

In [None]:
# main.py
import os
import io
import json
from typing import Dict, Any

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

import pdfplumber
from pdf2image import convert_from_path
import pytesseract
import re

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import Layer

import pandas as pd
from google import genai
from google.genai import types


# =========================
# 0. CONFIG
# =========================
# Paths (update to match your setup)
MODEL_PATH = r"C:\AINutriCare\Notebooks\Milestone_2\LSTM\attention_lstm.h5"
SCALER_PATH = r"C:\AINutriCare\Data\Transformed\X_final.npy"
FOOD_KB_FILE = "diet_kb.json"

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "AIzaSyD_lIMzf-pDJhSOZoEo8o5PtZUik6Yit6c")
GEMINI_MODEL_ID = "gemini-2.5-flash"

# Tesseract path
pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"

client = genai.Client(api_key=GEMINI_API_KEY)

# ========= FastAPI app =========
app = FastAPI(title="AI-NutriCare API", version="0.1.0")

# CORS for React dev (adjust origin as needed)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# =========================
# 1. OCR + PARAMETER EXTRACTION
# =========================
PARAMETERS = [
    "Glucose", "Insulin", "Creatinine", "Urea (BUN)",
    "Sodium", "Potassium", "Hemoglobin", "WBC",
    "Lactate", "pH", "Age", "Gender", "Cholestrol", "HbA1c",
]

PATTERN_MAP = {
    "Glucose": r"GLUCOSE[^\n]*?([0-9.]+)\s*mg/[dl1I]+",
    "Insulin": r"(Insulin)\s+([0-9.]+)",
    "Creatinine": r"CREATININE[^\n]*?([0-9.]+)\s*#?\s*mg/[dl1I]+",
    "Urea (BUN)": r"BUN[^\n]*?([0-9.]+)\s*mg/[dl1I]+",
    "Sodium": r"SODIUM[^\n]*?([0-9.]+)\s*mmol/[l1I|]+",
    "Potassium": r"POTASSIUM[^\n]*?([0-9.]+)\s*mmol/[l1I|]+",
    "Hemoglobin": r"[Hh]a?emoglobin[^\n]*?([0-9.]+)\s*g/[dl1I]+",
    "WBC": r"WBC\s+Count.*?([0-9]+)\s*/(?:cu\.mm|cmm)",
    "Lactate": r"(Lactate)\s+([0-9.]+)",
    "pH": r"\bpH\b\s*([0-9.]+)",
    "Age": r"Age\s*:\s*([0-9]{1,3})",
    "Gender": r"(?:Sex|Gender)\s*:\s*([A-Za-z]+)",
    "Cholestrol": r"(Cholesterol)\s+([0-9.]+)",
    "HbA1c": r"HbA1c\s*.*?([0-9.]+)",
}

UNIT_MAP = {
    "Glucose": "mg/dL",
    "Insulin": "¬µIU/mL",
    "Creatinine": "mg/dL",
    "Urea (BUN)": "mg/dL",
    "Sodium": "mmol/L",
    "Potassium": "mmol/L",
    "Hemoglobin": "g/dL",
    "WBC": "/cmm",
    "Lactate": "mmol/L",
    "pH": "",
    "Age": "years",
    "Gender": "",
    "Cholestrol": "mg/dL",
    "HbA1c": "%",
}


def ocr_pdf_bytes_to_text(pdf_bytes: bytes) -> str:
    # Use pdfplumber on bytes
    text_pages = []
    try:
        with pdfplumber.open(io.BytesIO(pdf_bytes)) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text() or ""
                text_pages.append(page_text)
    except Exception:
        pass

    full_text = "\n".join(text_pages).strip()

    # If direct text is too small, fallback to OCR
    if len(full_text) < 200:
        images = convert_from_path(io.BytesIO(pdf_bytes), dpi=300)
        ocr_pages = []
        for img in images:
            page_text = pytesseract.image_to_string(img)
            ocr_pages.append(page_text)
        full_text = "\n".join(ocr_pages)

    return full_text


def extract_parameter(text: str, parameter: str) -> Dict[str, Any]:
    pattern = PATTERN_MAP.get(parameter)
    result = {
        "name": parameter,
        "value": None,
        "unit": UNIT_MAP.get(parameter, ""),
        "raw_match": "",
    }
    if not pattern:
        return result

    match = re.search(pattern, text, flags=re.IGNORECASE)
    if not match:
        return result

    value_str = match.groups()[-1]
    result["value"] = value_str.strip()
    result["raw_match"] = match.group(0).strip()
    return result


def extract_all_parameters(text: str) -> Dict[str, Dict[str, Any]]:
    data = {}
    for param in PARAMETERS:
        data[param] = extract_parameter(text, param)
    return data


# =========================
# 2. LSTM MODEL + CLINICAL JSON
# =========================
MODEL_FEATURES = [
    "Heart Rate", "MAP", "Respiratory Rate", "Temperature",
    "Glucose", "Creatinine", "BUN", "Sodium", "Potassium",
    "Hemoglobin", "WBC", "Lactate",
    "Fluid Balance", "Vasopressors", "Sedatives", "Antibiotics", "Insulin",
]

DEFAULTS = {
    "Heart Rate": 75,
    "MAP": 90,
    "Respiratory Rate": 16,
    "Temperature": 98.4,
    "Lactate": 1.0,
    "Fluid Balance": 0,
    "Vasopressors": 0,
    "Sedatives": 0,
    "Antibiotics": 0,
    "Insulin": 0,
}


@tf.keras.utils.register_keras_serializable()
class SimpleAttention(Layer):
    def __init__(self, units=64, **kwargs):
        super(SimpleAttention, self).__init__(**kwargs)
        self.units = units

    def get_config(self):
        config = super(SimpleAttention, self).get_config()
        config.update({"units": self.units})
        return config

    def build(self, input_shape):
        self.W1 = self.add_weight(
            name="att_w1",
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
        )
        self.W2 = self.add_weight(
            name="att_w2",
            shape=(self.units, 1),
            initializer="glorot_uniform",
        )
        self.b1 = self.add_weight(
            name="att_b1",
            shape=(self.units,),
            initializer="zeros",
        )
        super(SimpleAttention, self).build(input_shape)

    def call(self, x):
        h = tf.nn.tanh(tf.matmul(x, self.W1) + self.b1)
        e = tf.squeeze(tf.matmul(h, self.W2), -1)
        alpha = tf.nn.softmax(e)
        context = x * tf.expand_dims(alpha, -1)
        context = tf.reduce_sum(context, axis=1)
        return context, alpha


def load_ai_resources():
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
    model = load_model(MODEL_PATH, custom_objects={"SimpleAttention": SimpleAttention})

    if os.path.exists(SCALER_PATH):
        X_ref = np.load(SCALER_PATH)
        X_flat = X_ref.reshape(-1, X_ref.shape[2])
        means = np.mean(X_flat, axis=0)
        stds = np.std(X_flat, axis=0)
        stds[stds == 0] = 1.0
    else:
        means = np.zeros(len(MODEL_FEATURES))
        stds = np.ones(len(MODEL_FEATURES))
    return model, means, stds


MODEL, MEANS, STDS = load_ai_resources()


def preprocess_patient_data(extracted_params: Dict[str, Any]):
    # extracted_params is output of extract_all_parameters
    def get_val_from_params(key: str):
        item = extracted_params.get(key, {})
        val = item.get("value")
        if val in [None, "N/A", "Not Found"]:
            return None
        try:
            clean = str(val).replace("H", "").replace("L", "").strip()
            return float(clean)
        except:
            return None

    vector = []
    vitals_for_report = {}

    for feature in MODEL_FEATURES:
        json_key = feature
        if feature == "BUN":
            json_key = "Urea (BUN)"
        if feature == "Cholesterol":
            json_key = "Cholestrol"

        val = get_val_from_params(json_key)
        if val is None:
            val = DEFAULTS.get(feature, 0)

        if feature == "WBC" and val > 1000:
            val = val / 1000.0

        vector.append(val)
        vitals_for_report[feature] = val

    patient_matrix = np.tile(vector, (24, 1))
    normalized_matrix = (patient_matrix - MEANS) / STDS
    input_tensor = normalized_matrix.reshape(1, 24, len(MODEL_FEATURES))

    return input_tensor, vitals_for_report


def build_clinical_json(risk_score: float, vitals: Dict[str, float], raw_params: Dict[str, Any]):
    llm_context = {
        "patient_metrics": {
            "mortality_risk": float(risk_score),
            "glucose": float(vitals.get("Glucose", 0.0)),
            "creatinine": float(vitals.get("Creatinine", 0.0)),
        },
        "conditions": [],
        "avoid": [],
        "recommend": [],
        "summary": "",
    }

    if risk_score > 0.60:
        llm_context["conditions"].append("Critical Stability Risk")
        llm_context["summary"] = "Patient is at HIGH RISK. Immediate metabolic intervention required."
    elif risk_score > 0.30:
        llm_context["conditions"].append("Moderate Clinical Risk")
        llm_context["summary"] = "Patient requires dietary management and monitoring."
    else:
        llm_context["summary"] = "Patient is stable. Routine maintenance diet recommended."

    glucose = vitals.get("Glucose", 0.0)
    hba1c_val = raw_params.get("HbA1c", {}).get("value", "N/A")

    is_diabetic = False
    if glucose > 126:
        is_diabetic = True
    if hba1c_val != "N/A":
        try:
            val = float(str(hba1c_val).replace("H", "").replace("L", ""))
            if val > 6.5:
                is_diabetic = True
        except:
            pass

    if is_diabetic:
        llm_context["conditions"].append("Diabetes (Type 2 / Hyperglycemia)")
        llm_context["avoid"].extend(["Fruit juices", "White bread", "Processed sugars", "High-GI foods"])
        llm_context["recommend"].extend(["Complex carbohydrates", "High fiber foods (>30g/day)", "Leafy greens"])

    creat = vitals.get("Creatinine", 0.0)
    if creat > 1.2:
        llm_context["conditions"].append("Renal Stress / Kidney Strain")
        llm_context["avoid"].extend(["High sodium foods", "Excessive red meat", "Processed deli meats"])
        llm_context["recommend"].extend(["Low-potassium vegetables", "Cauliflower", "Berries"])

    if vitals.get("MAP", 90) > 100:
        llm_context["conditions"].append("Hypertension Risk")
        llm_context["avoid"].append("Salt/Sodium")
        llm_context["recommend"].append("DASH diet principles")

    if not llm_context["conditions"]:
        llm_context["conditions"].append("General Health Maintenance")
        llm_context["recommend"].append("Balanced diet with lean proteins and vegetables")

    return llm_context


# =========================
# 3. FOOD KB + GEMINI DIET PLAN
# =========================
def load_and_tag_data(filepath):
    if not os.path.exists(filepath):
        return pd.DataFrame()

    with open(filepath, "r") as f:
        data = json.load(f)
    df = pd.DataFrame(data)

    def get_tags(row):
        tags = []
        if row["Carbohydrate (g)"] < 30 and "sugar" not in str(row["ingredients"]).lower():
            tags.append("diabetic_friendly")
            tags.append("low_sugar")
        if row["Protein (g)"] > 10:
            tags.append("high_protein")
        if row["Total Fat (g)"] < 8:
            tags.append("low_fat")
        if 5 < row["Protein (g)"] < 15:
            tags.append("renal_safe")
        return tags

    df["medical_tags"] = df.apply(get_tags, axis=1)
    return df


def get_smart_candidates(df, clinical_insights):
    if df.empty:
        return {}

    conditions = " ".join(clinical_insights.get("conditions", [])).lower()
    avoid = " ".join(clinical_insights.get("avoid", [])).lower()

    candidates = df.copy()

    if "diabetes" in conditions or "sugar" in avoid:
        candidates = candidates[candidates["medical_tags"].apply(lambda x: "diabetic_friendly" in x)]

    if "renal" in conditions or "kidney" in conditions:
        candidates = candidates[candidates["medical_tags"].apply(lambda x: "renal_safe" in x)]

    breakfast_keywords = "idli|dosa|upma|poha|paratha|oats|porridge"
    breakfast_df = candidates[candidates["name"].str.contains(breakfast_keywords, case=False, na=False)]
    if len(breakfast_df) < 2:
        breakfast_df = candidates.sample(n=min(5, len(candidates)))

    mains_df = candidates[candidates["Energy (kcal)"] > 150]
    snacks_df = candidates[candidates["Energy (kcal)"] < 150]

    def serialize(sub_df, count=5):
        return sub_df.sample(n=min(count, len(sub_df))).to_dict(orient="records")

    return {
        "breakfast": serialize(breakfast_df),
        "lunch": serialize(mains_df),
        "dinner": serialize(mains_df),
        "snacks": serialize(snacks_df),
    }


def generate_structured_plan(patient_profile, clinical_data, food_candidates):
    summary = clinical_data.get("summary", "Healthy Diet")
    options_preview = json.dumps(food_candidates, indent=2)

    prompt = f"""
    You are an AI Clinical Dietitian.

    PATIENT: {patient_profile['name']}
    CONDITION: {summary}

    TASK:
    Select items from the PROVIDED CANDIDATE LIST below to create a 1-day meal plan.
    You must output the result in STRICT JSON format matching the user's required schema.

    CANDIDATE FOODS (Pick from these):
    {options_preview}

    REQUIRED OUTPUT FORMAT (JSON):
    {{
      "day_plan": {{
          "breakfast": [
              {{ "item": "Name", "calories": 100, "protein": 5, "fat": 2, "carbs": 10, "tags": ["tag1", "tag2"] }}
          ],
          "lunch": [],
          "dinner": [],
          "snacks": []
      }},
      "total_nutrition": {{
          "calories": 0,
          "protein": 0,
          "carbs": 0,
          "fat": 0
      }},
      "medical_reasoning": "Brief explanation..."
    }}

    RULES:
    1. Use the EXACT nutritional values from the candidate list. Do not invent numbers.
    2. Include the 'medical_tags' provided in the candidate list.
    3. Calculate the 'total_nutrition' sum correctly.
    4. Output JSON ONLY. No text before or after.
    """

    schema = {
        "type": "object",
        "properties": {
            "day_plan": {
                "type": "object",
                "properties": {
                    "breakfast": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "item": {"type": "string"},
                                "calories": {"type": "number"},
                                "protein": {"type": "number"},
                                "fat": {"type": "number"},
                                "carbs": {"type": "number"},
                                "tags": {
                                    "type": "array",
                                    "items": {"type": "string"},
                                },
                            },
                            "required": ["item", "calories", "protein", "fat", "carbs", "tags"],
                        },
                    },
                    "lunch": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "item": {"type": "string"},
                                "calories": {"type": "number"},
                                "protein": {"type": "number"},
                                "fat": {"type": "number"},
                                "carbs": {"type": "number"},
                                "tags": {
                                    "type": "array",
                                    "items": {"type": "string"},
                                },
                            },
                            "required": ["item", "calories", "protein", "fat", "carbs", "tags"],
                        },
                    },
                    "dinner": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "item": {"type": "string"},
                                "calories": {"type": "number"},
                                "protein": {"type": "number"},
                                "fat": {"type": "number"},
                                "carbs": {"type": "number"},
                                "tags": {
                                    "type": "array",
                                    "items": {"type": "string"},
                                },
                            },
                            "required": ["item", "calories", "protein", "fat", "carbs", "tags"],
                        },
                    },
                    "snacks": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "item": {"type": "string"},
                                "calories": {"type": "number"},
                                "protein": {"type": "number"},
                                "fat": {"type": "number"},
                                "carbs": {"type": "number"},
                                "tags": {
                                    "type": "array",
                                    "items": {"type": "string"},
                                },
                            },
                            "required": ["item", "calories", "protein", "fat", "carbs", "tags"],
                        },
                    },
                },
                "required": ["breakfast", "lunch", "dinner", "snacks"],
            },
            "total_nutrition": {
                "type": "object",
                "properties": {
                    "calories": {"type": "number"},
                    "protein": {"type": "number"},
                    "fat": {"type": "number"},
                    "carbs": {"type": "number"},
                },
                "required": ["calories", "protein", "fat", "carbs"],
            },
            "medical_reasoning": {"type": "string"},
        },
        "required": ["day_plan", "total_nutrition", "medical_reasoning"],
    }

    resp = client.models.generate_content(
        model=GEMINI_MODEL_ID,
        contents=prompt,
        config=types.GenerateContentConfig(
            response_mime_type="application/json",
            response_schema=schema,
            temperature=0.1,
        ),
    )

    if hasattr(resp, "parsed") and resp.parsed is not None:
        return resp.parsed

    return json.loads(resp.text.strip())


def round_plan(plan: Dict[str, Any], ndigits: int = 0) -> Dict[str, Any]:
    for meal in ["breakfast", "lunch", "dinner", "snacks"]:
        for dish in plan.get("day_plan", {}).get(meal, []):
            for key in ["calories", "protein", "fat", "carbs"]:
                val = dish.get(key)
                if isinstance(val, (int, float)):
                    dish[key] = round(val, ndigits)

    tn = plan.get("total_nutrition", {})
    for key in ["calories", "protein", "fat", "carbs"]:
        val = tn.get(key)
        if isinstance(val, (int, float)):
            tn[key] = round(val, ndigits)
    return plan


# =========================
# 4. FastAPI endpoint
# =========================
@app.post("/plan-diet")
async def plan_diet(report: UploadFile = File(...)):
    if report.content_type not in ["application/pdf"]:
        raise HTTPException(status_code=400, detail="Only PDF files are supported")

    pdf_bytes = await report.read()

    # 1) OCR + parameter extraction
    text = ocr_pdf_bytes_to_text(pdf_bytes)
    extracted_params = extract_all_parameters(text)

    # 2) LSTM risk prediction
    input_tensor, vitals = preprocess_patient_data(extracted_params)
    prediction = MODEL.predict(input_tensor, verbose=0)[0][0]
    clinical_json = build_clinical_json(float(prediction), vitals, extracted_params)

    # 3) Diet plan
    food_df = load_and_tag_data(FOOD_KB_FILE)
    candidates = get_smart_candidates(food_df, clinical_json)
    patient = {"name": "From PDF", "age": extracted_params.get("Age", {}).get("value", None) or 45}
    diet_plan = generate_structured_plan(patient, clinical_json, candidates)
    diet_plan = round_plan(diet_plan, ndigits=0)

    # Combine clinical + diet if you like
    result = {
        "clinical": clinical_json,
        "diet": diet_plan,
    }
    return result
