In [None]:
# Colab only: mount Google Drive so the dataset CSVs under /content/drive/MyDrive (see main()) can be read.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install torch_geometric



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GraphSAGE, GCNConv, global_mean_pool
from torch_geometric.data import Data, Batch
import networkx as nx
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional
import json
from dataclasses import dataclass
from enum import Enum
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Hugging Face transformers for pretrained models
from transformers import (
    AutoTokenizer, AutoModel, AutoConfig,
    T5ForConditionalGeneration, T5Tokenizer,
    GPT2LMHeadModel, GPT2Tokenizer,
    BartForConditionalGeneration, BartTokenizer,
    Trainer, TrainingArguments
)
from torch.utils.data import Dataset, DataLoader
from huggingface_hub import login as hf_login
from transformers import EvalPrediction

# Data structures for our domain
@dataclass
class Food:
    """A food item with nutrition values and Ayurvedic properties.

    Built row-by-row by ``load_foods_csv`` and turned into a ``food_<id>``
    node (plus links to dosha/rasa/guna/category/condition nodes) by
    ``AyurvedaKnowledgeGraph.add_food_node``.
    """
    id: str
    name: str
    category: str  # compared lower-cased against the fixed list in AyurvedaKnowledgeGraph._encode_category
    calories: float
    protein: float
    carbs: float
    fats: float
    fiber: float
    vitamins: Dict[str, float]  # NOTE(review): load_foods_csv fills this via _parse_dict, which yields str values — confirm intended type
    minerals: Dict[str, float]  # NOTE(review): same str-vs-float caveat as vitamins
    dosha_effects: Dict[str, str]  # vata, pitta, kapha effects
    rasa: str  # taste
    guna: List[str]  # qualities
    virya: str  # potency
    vipaka: str  # post-digestive effect
    health_tags: List[str]  # each becomes a condition_<tag> node linked with relation "beneficial_for"
    contraindications: List[str]

@dataclass
class Patient:
    """A patient profile used for graph construction and prompt formatting."""
    id: str
    age: int
    gender: str  # _encode_gender only distinguishes 'male' (case-insensitive) from everything else
    weight: float
    height: float  # load_patients_csv treats values > 3 as centimetres when it derives BMI
    bmi: float  # load_patients_csv derives this from weight/height when the CSV value is missing or 0
    lifestyle: str  # _encode_lifestyle one-hot encodes sedentary/moderate/active/very_active; other values encode as all zeros
    prakriti: str  # constitutional type
    health_conditions: List[str]
    allergies: List[str]
    preferred_cuisine: List[str]

@dataclass
class MealPlan:
    """One day of a doctor-authored meal plan, used as a fine-tuning example."""
    patient_id: str  # matched against Patient.id by AyurvedaMealPlanDataset
    day: int
    breakfast: List[str]  # food names for each meal slot
    lunch: List[str]
    dinner: List[str]
    snacks: List[str]
    restrictions: List[str]  # not included in the dataset's training target text
    doctor_notes: str  # not included in the dataset's training target text

class NodeType(Enum):
    """Kinds of nodes stored in AyurvedaKnowledgeGraph.

    Declaration order is significant: to_pytorch_geometric one-hot encodes a
    node's type by its position in ``list(NodeType)``.
    """
    FOOD = "food"
    PATIENT = "patient"
    DOSHA = "dosha"
    RASA = "rasa"
    GUNA = "guna"
    CONDITION = "condition"
    CATEGORY = "category"

# Knowledge Graph Foundation
class AyurvedaKnowledgeGraph:
    """Undirected graph linking foods and patients to Ayurvedic concepts.

    Node ids are strings prefixed by their kind (``food_``, ``patient_``,
    ``dosha_``, ``rasa_``, ``guna_``, ``condition_``, ``category_``).
    ``to_pytorch_geometric`` converts the graph into a PyTorch Geometric
    ``Data`` object with a fixed-width feature vector per node.
    """

    def __init__(self):
        self.graph = nx.Graph()
        self.node_to_idx = {}  # node id -> row in Data.x; rebuilt by to_pytorch_geometric
        self.idx_to_node = {}  # inverse of node_to_idx
        self.node_types = {}  # node id -> NodeType
        self.node_features = {}  # node id -> {feature name: float or list of floats}
        self.food_names_by_category = {}  # lower-cased category -> food names (used for recommendations)

    def add_food_node(self, food: Food):
        """Add a food item and its relationships to the graph.

        The food is linked to its dosha effects, rasa, gunas, category and
        health tags. Its name is also recorded under the lower-cased category
        so lookups with lower-case names (e.g. ``'grains'``) match categories
        written as ``'Grains'`` in the source data.
        """
        food_id = f"food_{food.id}"
        self.graph.add_node(food_id)
        self.node_types[food_id] = NodeType.FOOD

        # Key by the lower-cased category, mirroring _encode_category's comparison.
        self.food_names_by_category.setdefault(food.category.lower(), []).append(food.name)

        # Store food features
        self.node_features[food_id] = {
            'calories': food.calories,
            'protein': food.protein,
            'carbs': food.carbs,
            'fats': food.fats,
            'fiber': food.fiber,
            'category_embedding': self._encode_category(food.category)
        }

        # Add relationships
        for dosha, effect in food.dosha_effects.items():
            self._link(food_id, f"dosha_{dosha}", NodeType.DOSHA, f"affects_{effect}")
        if food.rasa:
            self._link(food_id, f"rasa_{food.rasa}", NodeType.RASA, "has_taste")
        for guna in food.guna:
            self._link(food_id, f"guna_{guna}", NodeType.GUNA, "has_quality")
        self._link(food_id, f"category_{food.category}", NodeType.CATEGORY, "belongs_to")
        for tag in food.health_tags:
            self._link(food_id, f"condition_{tag}", NodeType.CONDITION, "beneficial_for")

    def add_patient_node(self, patient: Patient):
        """Add a patient, linked to their prakriti dosha and health conditions."""
        patient_id = f"patient_{patient.id}"
        self.graph.add_node(patient_id)
        self.node_types[patient_id] = NodeType.PATIENT

        self.node_features[patient_id] = {
            'age': patient.age,
            'bmi': patient.bmi,
            'gender_embedding': self._encode_gender(patient.gender),
            'lifestyle_embedding': self._encode_lifestyle(patient.lifestyle)
        }

        self._link(patient_id, f"dosha_{patient.prakriti}", NodeType.DOSHA, "has_prakriti")
        for condition in patient.health_conditions:
            self._link(patient_id, f"condition_{condition}", NodeType.CONDITION, "has_condition")

    def _link(self, source: str, target: str, target_type: NodeType, relation: str):
        """Create ``target`` if needed, then connect it to ``source`` with ``relation``."""
        self._ensure_node_exists(target, target_type)
        self.graph.add_edge(source, target, relation=relation)

    def _ensure_node_exists(self, node_id: str, node_type: NodeType):
        """Add a featureless concept node the first time it is referenced."""
        if node_id not in self.graph:
            self.graph.add_node(node_id)
            self.node_types[node_id] = node_type
            self.node_features[node_id] = {}

    def _encode_category(self, category: str) -> List[float]:
        """One-hot over 8 known categories plus a trailing 'unknown' slot (9 floats)."""
        categories = ['grains', 'vegetables', 'fruits', 'dairy', 'spices', 'legumes', 'nuts', 'oils']
        encoding = [1.0 if category.lower() == cat else 0.0 for cat in categories]
        if sum(encoding) == 0:  # Unknown category
            encoding.append(1.0)
        else:
            encoding.append(0.0)
        return encoding

    def _encode_gender(self, gender: str) -> List[float]:
        """[1, 0] for 'male' (case-insensitive), otherwise [0, 1]."""
        return [1.0, 0.0] if gender.lower() == 'male' else [0.0, 1.0]

    def _encode_lifestyle(self, lifestyle: str) -> List[float]:
        """One-hot over 4 lifestyles; unrecognised values encode as all zeros."""
        lifestyles = ['sedentary', 'moderate', 'active', 'very_active']
        encoding = [1.0 if lifestyle.lower() == ls else 0.0 for ls in lifestyles]
        return encoding

    def to_pytorch_geometric(self, feature_dim: int = 32) -> Data:
        """Convert the NetworkX graph to PyTorch Geometric format.

        Each node's features are its one-hot NodeType followed by its stored
        features (lists are flattened in), truncated or zero-padded to
        ``feature_dim``.

        Args:
            feature_dim: width of every node feature vector. The default of 32
                matches the input size of T5MealPlanner's graph encoder.

        Returns:
            ``Data`` whose ``x`` has shape ``(num_nodes, feature_dim)``.
            ``edge_index`` has shape ``(2, 2 * num_edges)``, because each
            undirected edge is emitted in both directions. It is
            ``(2, 0)`` when the graph has no edges.
        """
        nodes = list(self.graph.nodes())
        self.node_to_idx = {node: idx for idx, node in enumerate(nodes)}
        self.idx_to_node = {idx: node for node, idx in self.node_to_idx.items()}

        pairs = [[self.node_to_idx[u], self.node_to_idx[v]] for u, v in self.graph.edges()]
        if pairs:
            forward = torch.tensor(pairs, dtype=torch.long).t()
            # nx.Graph reports each undirected edge once, but PyG treats edge_index
            # as directed (source -> target); add the reverse so messages flow both ways.
            edge_index = torch.cat([forward, forward.flip(0)], dim=1).contiguous()
        else:
            # torch.tensor([]).t() would be 1-D; PyG expects shape (2, 0).
            edge_index = torch.empty((2, 0), dtype=torch.long)

        type_order = list(NodeType)
        node_features = []
        for node in nodes:
            node_type = self.node_types.get(node, NodeType.FOOD)
            features = [0.0] * len(type_order)
            features[type_order.index(node_type)] = 1.0

            for value in self.node_features.get(node, {}).values():
                if isinstance(value, list):
                    features.extend(value)
                else:
                    features.append(float(value))

            features = features[:feature_dim]
            features.extend([0.0] * (feature_dim - len(features)))
            node_features.append(features)

        # view() keeps the 2-D shape even when the graph has no nodes.
        x = torch.tensor(node_features, dtype=torch.float).view(len(nodes), feature_dim)
        return Data(x=x, edge_index=edge_index)

# Graph Neural Network for Food Embeddings
class GraphNeuralNetwork(nn.Module):
    """Three-layer GCN that maps node features to node embeddings.

    Args:
        input_dim: size of each input node feature vector.
        hidden_dim: width of the two hidden GCN layers.
        output_dim: size of each returned node embedding.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        # Attribute names conv1..conv3 determine the state_dict keys; keep them stable.
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, edge_index, batch=None):
        """Return one embedding per node; ``batch`` is accepted but not used."""
        hidden = x
        # Hidden layers: convolution -> ReLU -> dropout.
        for conv in (self.conv1, self.conv2):
            hidden = self.dropout(F.relu(conv(hidden, edge_index)))
        # Output layer has no activation.
        return self.conv3(hidden, edge_index)

# Pretrained Transformer Models with Better Generation
class T5MealPlanner(nn.Module):
    """Using T5 for text-to-text meal planning.

    Holds a pretrained T5 tokenizer/model pair whose vocabulary is extended
    with Ayurveda marker tokens, a graph encoder for 32-dim node features,
    and a linear projection from the graph embedding size to T5's d_model.
    """

    def __init__(self, model_name: str = "t5-small", graph_embedding_dim: int = 256):
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)

        # Add special tokens for Ayurveda concepts. The order here determines
        # the ids the tokenizer assigns to the new tokens.
        section_markers = (
            "<patient>", "</patient>", "<day>", "</day>",
            "<breakfast>", "</breakfast>", "<lunch>", "</lunch>",
            "<dinner>", "</dinner>", "<snacks>", "</snacks>",
        )
        concept_markers = (
            "<vata>", "<pitta>", "<kapha>",
            "<diabetes>", "<hypertension>", "<obesity>", "<digestion>",
        )
        self.tokenizer.add_tokens(list(section_markers + concept_markers))
        # The embedding matrix needs a row for every newly added token id.
        self.model.resize_token_embeddings(len(self.tokenizer))

        # Graph embeddings integration
        self.graph_encoder = GraphNeuralNetwork(32, 128, graph_embedding_dim)
        self.graph_projection = nn.Linear(graph_embedding_dim, self.model.config.d_model)

    def format_patient_input(self, patient: Patient, day: int) -> str:
        """Render a patient profile and day number as the model's prompt text.

        Conditions and allergies are included only when non-empty.
        """
        parts = [
            "generate meal plan:",
            f"patient age {patient.age} gender {patient.gender}",
            f"bmi {patient.bmi:.1f} lifestyle {patient.lifestyle}",
            f"prakriti {patient.prakriti}",
        ]
        if patient.health_conditions:
            parts.append(f"conditions {' '.join(patient.health_conditions)}")
        if patient.allergies:
            parts.append(f"allergies {' '.join(patient.allergies)}")
        parts.append(f"day {day}")
        return " ".join(parts)

    def format_meal_plan_output(self, meal_plan: MealPlan) -> str:
        """Render a meal plan as 'breakfast: a, b lunch: ...' text, skipping empty meals."""
        meals = (
            ("breakfast", meal_plan.breakfast),
            ("lunch", meal_plan.lunch),
            ("dinner", meal_plan.dinner),
            ("snacks", meal_plan.snacks),
        )
        rendered = [f"{label}: {', '.join(items)}" for label, items in meals if items]
        return " ".join(rendered).strip()

# Dataset class for training
class AyurvedaMealPlanDataset(Dataset):
    """Seq2seq training examples built from doctor meal plans.

    Each item pairs a prompt describing the plan's patient and day with a
    target string listing the plan's meals. Both are tokenized and padded or
    truncated to ``max_length``, and padding positions in the labels are set
    to -100 so the loss ignores them.

    Args:
        patients: patient records, looked up by ``MealPlan.patient_id``.
        meal_plans: one training example per plan.
        tokenizer: Hugging Face tokenizer exposing ``pad_token_id``.
        max_length: pad/truncate length for both prompt and target.
        model_type: stored as an attribute; not used when building items.
    """

    def __init__(self, patients: List[Patient], meal_plans: List[MealPlan],
                 tokenizer, max_length: int = 512, model_type: str = "t5"):
        self.patients = {p.id: p for p in patients}
        self.meal_plans = meal_plans
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.model_type = model_type

    def __len__(self):
        return len(self.meal_plans)

    @staticmethod
    def _format_input(patient: Patient, day: int) -> str:
        """Build the prompt text for one example.

        This must stay identical to T5MealPlanner.format_patient_input, which
        HybridNeuralEngine.generate_meal_plan uses at inference time. Training
        on a different prompt layout would leave the model unfamiliar with the
        prompts it receives during generation.
        """
        parts = [
            "generate meal plan:",
            f"patient age {patient.age} gender {patient.gender}",
            f"bmi {patient.bmi:.1f} lifestyle {patient.lifestyle}",
            f"prakriti {patient.prakriti}",
        ]
        if patient.health_conditions:
            parts.append(f"conditions {' '.join(patient.health_conditions)}")
        if patient.allergies:
            parts.append(f"allergies {' '.join(patient.allergies)}")
        parts.append(f"day {day}")
        return " ".join(parts)

    @staticmethod
    def _format_target(meal_plan: MealPlan) -> str:
        """Render the plan as 'breakfast: a, b lunch: ...', skipping empty meals."""
        meals = (
            ("breakfast", meal_plan.breakfast),
            ("lunch", meal_plan.lunch),
            ("dinner", meal_plan.dinner),
            ("snacks", meal_plan.snacks),
        )
        return " ".join(f"{label}: {', '.join(items)}" for label, items in meals if items).strip()

    def _tokenize(self, text: str):
        """Tokenize ``text`` to a (1, max_length) batch with padding and truncation."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        meal_plan = self.meal_plans[idx]
        patient = self.patients.get(meal_plan.patient_id)

        if patient is None:
            # Create a dummy patient if not found
            patient = Patient(
                id=meal_plan.patient_id,
                age=30, gender="unknown", weight=70, height=170, bmi=24,
                lifestyle="moderate", prakriti="vata",
                health_conditions=[], allergies=[], preferred_cuisine=[]
            )

        inputs = self._tokenize(self._format_input(patient, meal_plan.day))
        targets = self._tokenize(self._format_target(meal_plan))

        # Replace padding token id with -100 for loss calculation
        labels = targets['input_ids']
        labels[labels == self.tokenizer.pad_token_id] = -100

        # squeeze(0) drops only the batch axis; a bare squeeze() would also
        # collapse a length-1 sequence axis.
        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'labels': labels.squeeze(0)
        }

# Data loading utilities
def _split_list(val: Optional[str]) -> List[str]:
    if val is None or (isinstance(val, float) and np.isnan(val)):
        return []
    if isinstance(val, list):
        return val
    s = str(val).strip()
    if not s:
        return []
    # try JSON array
    try:
        parsed = json.loads(s)
        if isinstance(parsed, list):
            return [str(x).strip() for x in parsed if str(x).strip()]
    except Exception:
        pass
    # fallback: split by | or ,
    sep = '|' if '|' in s else ','
    return [p.strip() for p in s.split(sep) if p.strip()]

def _parse_float(val, default: float = 0.0) -> float:
    try:
        if val is None or (isinstance(val, float) and np.isnan(val)):
            return default
        return float(val)
    except Exception:
        return default

def _parse_dict(val) -> Dict[str, str]:
    if val is None or (isinstance(val, float) and np.isnan(val)):
        return {}
    if isinstance(val, dict):
        return {str(k): str(v) for k, v in val.items()}
    s = str(val).strip()
    if not s:
        return {}
    # try json
    try:
        parsed = json.loads(s)
        if isinstance(parsed, dict):
            return {str(k): str(v) for k, v in parsed.items()}
    except Exception:
        pass
    # fallback: key:value pairs
    pairs = s.split('|') if '|' in s else s.split(',')
    out = {}
    for p in pairs:
        if ':' in p:
            k, v = p.split(':', 1)
            out[k.strip()] = v.strip()
    return out

def load_foods_csv(path: str) -> List[Food]:
    """Load food records from a CSV file.

    Columns are matched by name with aliases (``food_id``/``id``,
    ``food_name``/``name``, ``carbohydrates``/``carbs``, ``fat``/``fats``,
    ``qualities``/``guna``, ``tags``/``health_tags``). A cell that is absent
    or empty (NaN after ``pd.read_csv``) falls through to the next alias and
    then to the default, so it never becomes the literal string ``"nan"``.

    Args:
        path: CSV file path.

    Returns:
        One ``Food`` per CSV row, in file order.
    """
    def first_present(row, keys, default):
        # First value among ``keys`` that exists and is not NaN; else ``default``.
        for key in keys:
            val = row.get(key)
            if val is not None and not (isinstance(val, float) and np.isnan(val)):
                return val
        return default

    def as_id(val) -> str:
        # A column with gaps is read as float64, which would turn id 3 into "3.0".
        if isinstance(val, float) and val.is_integer():
            return str(int(val))
        return str(val)

    df = pd.read_csv(path)
    foods = []
    for idx, row in df.iterrows():
        foods.append(
            Food(
                id=as_id(first_present(row, ('id', 'food_id'), idx)),
                name=str(first_present(row, ('name', 'food_name'), 'Unknown')),
                category=str(first_present(row, ('category',), 'unknown')),
                calories=_parse_float(first_present(row, ('calories',), 0)),
                protein=_parse_float(first_present(row, ('protein',), 0)),
                carbs=_parse_float(first_present(row, ('carbs', 'carbohydrates'), 0)),
                fats=_parse_float(first_present(row, ('fats', 'fat'), 0)),
                fiber=_parse_float(first_present(row, ('fiber',), 0)),
                vitamins=_parse_dict(row.get('vitamins')),
                minerals=_parse_dict(row.get('minerals')),
                dosha_effects=_parse_dict(row.get('dosha_effects')),
                rasa=str(first_present(row, ('rasa',), 'sweet')),
                guna=_split_list(first_present(row, ('guna', 'qualities'), '')),
                virya=str(first_present(row, ('virya',), 'neutral')),
                vipaka=str(first_present(row, ('vipaka',), 'sweet')),
                health_tags=_split_list(first_present(row, ('health_tags', 'tags'), '')),
                contraindications=_split_list(first_present(row, ('contraindications',), '')),
            )
        )
    return foods

def load_patients_csv(path: str) -> List[Patient]:
    """Load patient records from a CSV file.

    Missing or empty cells fall back to defaults: height 170, weight 70,
    age 30, lifestyle ``'moderate'``, prakriti ``'vata'``. Aliases are
    ``patient_id``, ``constitution``, ``conditions`` and ``cuisine``.
    Empty text cells never become the literal string ``"nan"``. When ``bmi``
    is missing or 0 it is derived from weight/height; heights above 3 are
    treated as centimetres.

    Args:
        path: CSV file path.

    Returns:
        One ``Patient`` per CSV row, in file order.
    """
    def first_present(row, keys, default):
        # First value among ``keys`` that exists and is not NaN; else ``default``.
        for key in keys:
            val = row.get(key)
            if val is not None and not (isinstance(val, float) and np.isnan(val)):
                return val
        return default

    def as_id(val) -> str:
        # A column with gaps is read as float64, which would turn id 3 into "3.0"
        # and break matching against MealPlan.patient_id.
        if isinstance(val, float) and val.is_integer():
            return str(int(val))
        return str(val)

    df = pd.read_csv(path)
    patients = []
    for idx, row in df.iterrows():
        # Pass the fallback as _parse_float's default so empty/unparsable cells get it too.
        height = _parse_float(row.get('height'), 170.0)
        weight = _parse_float(row.get('weight'), 70.0)
        bmi = _parse_float(row.get('bmi'), 0.0)
        if not bmi and height and weight:
            h_m = height / 100.0 if height > 3 else height
            denom = h_m * h_m
            # denom can only be 0 through float underflow for absurdly small heights.
            bmi = weight / denom if denom else 24.0
        patients.append(
            Patient(
                id=as_id(first_present(row, ('id', 'patient_id'), idx)),
                age=int(_parse_float(row.get('age'), 30.0)),
                gender=str(first_present(row, ('gender',), 'unknown')),
                weight=weight,
                height=height,
                bmi=bmi,
                lifestyle=str(first_present(row, ('lifestyle',), 'moderate')),
                prakriti=str(first_present(row, ('prakriti', 'constitution'), 'vata')),
                health_conditions=_split_list(first_present(row, ('health_conditions', 'conditions'), '')),
                allergies=_split_list(first_present(row, ('allergies',), '')),
                preferred_cuisine=_split_list(first_present(row, ('preferred_cuisine', 'cuisine'), '')),
            )
        )
    return patients

def load_doctor_plans_csv(path: str) -> List[MealPlan]:
    """Load doctor-authored meal plans from a CSV file.

    Missing or empty cells fall back to defaults: day 1, empty meal lists,
    and empty notes rather than the literal string ``"nan"``. ``patient_id``
    falls back to ``id`` and then to the row index; integral float ids (from
    columns with gaps) are rendered without a ``.0`` suffix so they match
    patient ids.

    Args:
        path: CSV file path.

    Returns:
        One ``MealPlan`` per CSV row, in file order.
    """
    def first_present(row, keys, default):
        # First value among ``keys`` that exists and is not NaN; else ``default``.
        for key in keys:
            val = row.get(key)
            if val is not None and not (isinstance(val, float) and np.isnan(val)):
                return val
        return default

    def as_id(val) -> str:
        # A column with gaps is read as float64, which would turn id 3 into "3.0".
        if isinstance(val, float) and val.is_integer():
            return str(int(val))
        return str(val)

    df = pd.read_csv(path)
    plans = []
    for idx, row in df.iterrows():
        plans.append(
            MealPlan(
                patient_id=as_id(first_present(row, ('patient_id', 'id'), idx)),
                day=int(_parse_float(row.get('day'), 1.0)),
                breakfast=_split_list(row.get('breakfast')),
                lunch=_split_list(row.get('lunch')),
                dinner=_split_list(row.get('dinner')),
                snacks=_split_list(row.get('snacks')),
                restrictions=_split_list(row.get('restrictions')),
                doctor_notes=str(first_present(row, ('doctor_notes', 'notes'), '')),
            )
        )
    return plans

def huggingface_login(token: Optional[str] = None):
    """Log in to the Hugging Face Hub when a token is available.

    An explicit ``token`` takes precedence. Otherwise the
    ``HUGGINGFACE_TOKEN`` and then ``HF_TOKEN`` environment variables are
    used. The outcome is printed; login errors are reported, not raised.
    """
    resolved = token or os.environ.get('HUGGINGFACE_TOKEN') or os.environ.get('HF_TOKEN')
    if not resolved:
        print('ℹ No Hugging Face token provided; proceeding without login')
        return

    try:
        hf_login(token=resolved)
        print('✓ Logged in to Hugging Face Hub')
    except Exception as e:
        print(f"⚠ Hugging Face login failed: {e}")

# Define compute_metrics function for evaluation
def compute_metrics(p: "EvalPrediction") -> Dict[str, float]:
    """Compute token-level accuracy for evaluation.

    ``EvalPrediction`` carries only ``predictions`` and ``label_ids``; it has
    no ``metrics`` attribute. The Trainer reports ``eval_loss`` by itself, so
    this function adds an accuracy metric instead.

    ``predictions`` may be a tuple whose first element holds the logits
    (seq2seq models also return encoder states), logits with one more axis
    than ``label_ids``, or already-reduced token ids of the same shape as
    ``label_ids``. Positions whose label is -100 (padding) are ignored.
    Could be extended with ROUGE/BLEU on decoded text.

    Returns:
        ``{"token_accuracy": correct / counted}`` over non-padding label
        positions, or 0.0 when there are none. The Trainer prefixes the key
        with ``eval_``.
    """
    predictions = p.predictions
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    labels = np.asarray(p.label_ids)

    if predictions.ndim == labels.ndim + 1:
        # Full-vocabulary logits: take the most likely token at each position.
        predictions = predictions.argmax(axis=-1)

    mask = labels != -100
    counted = int(mask.sum())
    if counted == 0:
        return {"token_accuracy": 0.0}
    correct = int((predictions[mask] == labels[mask]).sum())
    return {"token_accuracy": correct / counted}

# Main Hybrid Neural Engine
class HybridNeuralEngine:
    """Knowledge-graph + T5 meal planning engine.

    Owns an ``AyurvedaKnowledgeGraph`` and a ``T5MealPlanner`` whose model is
    placed on CUDA when available. Generated plans are text of the form
    ``"breakfast: a, b lunch: ... dinner: ... snacks: ..."``. When model
    output contains no meal keywords, graph-derived recommendations replace it.
    """

    # Meal slots, in emit/parse order.
    _MEAL_TYPES = ('breakfast', 'lunch', 'dinner', 'snacks')

    def __init__(self, model_type: str = "t5", model_name: Optional[str] = None):
        """Create the planner and move its model to the available device.

        Raises:
            ValueError: if ``model_type`` is not ``"t5"``.
        """
        self.model_type = model_type
        self.knowledge_graph = AyurvedaKnowledgeGraph()

        # Use smaller models for faster inference
        if model_type == "t5":
            model_name = model_name or "t5-small"
            self.planner = T5MealPlanner(model_name)
            self.tokenizer = self.planner.tokenizer
        else:
            raise ValueError(f"Currently only T5 is fully implemented. Got: {model_type}")

        # Determine device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.planner.model.to(self.device)
        print(f"Using device: {self.device}")

    def build_knowledge_graph(self, foods: List[Food], patients: List[Patient]):
        """Add all foods and patients to the graph and return it as PyG ``Data`` on ``self.device``."""
        print(f"Building knowledge graph with {len(foods)} foods and {len(patients)} patients...")
        for food in foods:
            self.knowledge_graph.add_food_node(food)
        for patient in patients:
            self.knowledge_graph.add_patient_node(patient)
        graph_data = self.knowledge_graph.to_pytorch_geometric()
        return graph_data.to(self.device)  # Move graph data to device

    def get_food_recommendations(self, patient: Patient) -> Dict[str, List[str]]:
        """Pick foods for each meal from the knowledge graph's category lists.

        Categories are matched case-insensitively. Foods whose name contains
        one of ``patient.allergies`` (case-insensitive substring) are skipped.
        This is a name-based filter only: an allergen not spelled out in the
        food name is not caught.

        Returns:
            ``{meal: [food names]}`` with at most 4 breakfast, 5 lunch,
            5 dinner and 2 snack items.
        """
        # Merge category lists under lower-cased keys so 'Grains' and 'grains' both match.
        by_category: Dict[str, List[str]] = {}
        for category, names in self.knowledge_graph.food_names_by_category.items():
            by_category.setdefault(category.lower(), []).extend(names)

        allergies = [str(a).strip().lower() for a in patient.allergies if str(a).strip()]

        def foods_in(category: str) -> List[str]:
            # Category's foods in insertion order, minus allergy name matches.
            return [name for name in by_category.get(category, [])
                    if not any(allergy in name.lower() for allergy in allergies)]

        recommendations = {meal: [] for meal in self._MEAL_TYPES}

        for cat in ['grains', 'fruits', 'dairy', 'nuts']:
            recommendations['breakfast'].extend(foods_in(cat)[:2])  # Take first 2 from each category

        for cat in ['grains', 'vegetables', 'legumes', 'dairy']:
            foods = foods_in(cat)
            recommendations['lunch'].extend(foods[:2])
            # Prefer different items for dinner when the category has more than two.
            recommendations['dinner'].extend(foods[2:4] if len(foods) > 2 else foods[:2])

        for cat in ['fruits', 'nuts', 'dairy']:
            recommendations['snacks'].extend(foods_in(cat)[:1])

        limits = {'breakfast': 4, 'lunch': 5, 'dinner': 5, 'snacks': 2}
        return {meal: items[:limits[meal]] for meal, items in recommendations.items()}

    @staticmethod
    def _logits_to_token_ids(logits, labels):
        """Reduce logits to predicted token ids before the Trainer accumulates them.

        Gathering full-vocabulary logits for every eval sample is very memory
        hungry. Seq2seq models may return a tuple whose first element holds the
        logits.
        """
        if isinstance(logits, tuple):
            logits = logits[0]
        return logits.argmax(dim=-1)

    def fine_tune(self, train_dataset: AyurvedaMealPlanDataset,
                  val_dataset: Optional[AyurvedaMealPlanDataset] = None,
                  output_dir: str = "./ayurveda_meal_planner",
                  num_epochs: int = 3,
                  batch_size: int = 4,
                  learning_rate: float = 5e-5):
        """Fine-tune the pretrained model with the Hugging Face Trainer.

        With ``val_dataset``:
        - evaluation runs every 100 steps;
        - the checkpoint with the lowest ``eval_loss`` is restored at the end;
        - ``compute_metrics`` receives predicted token ids.

        The model and tokenizer are saved to ``output_dir``.
        """
        print(f"Starting fine-tuning for {num_epochs} epochs...")

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            learning_rate=learning_rate,
            warmup_steps=50,
            logging_dir=f"{output_dir}/logs",
            logging_steps=10,
            save_steps=100,
            eval_steps=100 if val_dataset else None,
            eval_strategy="steps" if val_dataset else "no",
            save_total_limit=2,
            load_best_model_at_end=True if val_dataset else False,
            metric_for_best_model="eval_loss" if val_dataset else None,
            greater_is_better=False,
            fp16=torch.cuda.is_available(),  # Use mixed precision if GPU available
            dataloader_pin_memory=False,
            report_to="none",  # Disable wandb/tensorboard
            prediction_loss_only=False,  # Return predictions so compute_metrics can run
        )

        trainer = Trainer(
            model=self.planner.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            tokenizer=self.tokenizer,
            compute_metrics=compute_metrics if val_dataset else None,
            preprocess_logits_for_metrics=self._logits_to_token_ids if val_dataset else None,
        )

        # Start training
        trainer.train()

        # Save the fine-tuned model
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)

        print(f"✓ Model fine-tuned and saved to {output_dir}")

    def generate_meal_plan(self, patient: Patient, day: int,
                          graph_data: Data = None,
                          max_length: int = 256,
                          temperature: float = 0.9,
                          use_knowledge_graph: bool = True) -> str:
        """Generate a meal plan text for ``patient`` on ``day``.

        If ``use_knowledge_graph`` is set and the graph has no foods, a
        prakriti-based default plan is returned without running the model.
        If the generated text lacks any meal keyword, graph-based
        recommendations are returned instead. ``graph_data`` is only moved to
        ``self.device``; it is not fed to the model.
        """
        # Ensure graph_data is on the correct device
        if graph_data is not None and graph_data.x.device != self.device:
            graph_data = graph_data.to(self.device)

        if use_knowledge_graph and not self.knowledge_graph.food_names_by_category:
            # If no foods in graph, create some sample recommendations
            print("⚠ No foods in knowledge graph, using default recommendations")
            return self._generate_default_plan(patient, day)

        # Format input
        input_text = self.planner.format_patient_input(patient, day)

        # Tokenize and move to device
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).to(self.device)

        # The model may still be in training mode (dropout active) after
        # fine_tune; generation should always run in eval mode.
        self.planner.model.eval()

        # Generate
        with torch.no_grad():
            outputs = self.planner.model.generate(
                **inputs,
                max_length=max_length,
                min_length=20,
                temperature=temperature,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                num_beams=3,
                early_stopping=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # If generation fails, use knowledge graph recommendations
        if not self._has_valid_content(generated_text) and use_knowledge_graph:
            recommendations = self.get_food_recommendations(patient)
            generated_text = self._format_recommendations(recommendations)

        return generated_text

    def _generate_default_plan(self, patient: Patient, day: int) -> str:
        """Return a fixed plan chosen by which dosha name appears in ``patient.prakriti``.

        ``day`` is accepted for interface symmetry but not used.
        """
        # Basic meal suggestions based on prakriti
        vata_foods = {
            'breakfast': ['oatmeal', 'warm milk', 'almonds', 'dates'],
            'lunch': ['rice', 'moong dal', 'ghee', 'cooked vegetables'],
            'dinner': ['khichdi', 'soup', 'bread', 'cooked spinach'],
            'snacks': ['banana', 'soaked almonds']
        }

        pitta_foods = {
            'breakfast': ['coconut water', 'sweet fruits', 'milk', 'cereal'],
            'lunch': ['basmati rice', 'green vegetables', 'cucumber', 'yogurt'],
            'dinner': ['quinoa', 'salad', 'sweet potato', 'green beans'],
            'snacks': ['watermelon', 'coconut']
        }

        kapha_foods = {
            'breakfast': ['honey water', 'light breakfast', 'berries', 'green tea'],
            'lunch': ['millet', 'bitter vegetables', 'spices', 'legumes'],
            'dinner': ['barley soup', 'steamed vegetables', 'ginger tea'],
            'snacks': ['apple', 'pear']
        }

        # Select based on prakriti (first match wins for dual types like 'vata-pitta')
        prakriti = patient.prakriti.lower()
        if 'vata' in prakriti:
            foods = vata_foods
        elif 'pitta' in prakriti:
            foods = pitta_foods
        elif 'kapha' in prakriti:
            foods = kapha_foods
        else:
            # Mix of all
            foods = {
                'breakfast': vata_foods['breakfast'][:2] + pitta_foods['breakfast'][:2],
                'lunch': vata_foods['lunch'][:2] + pitta_foods['lunch'][:2],
                'dinner': kapha_foods['dinner'][:2] + vata_foods['dinner'][:2],
                'snacks': pitta_foods['snacks']
            }

        return self._format_recommendations(foods)

    def _format_recommendations(self, recommendations: Dict[str, List[str]]) -> str:
        """Render ``{meal: items}`` as 'breakfast: a, b lunch: ...', skipping empty meals."""
        text = ""
        if recommendations.get('breakfast'):
            text += f"breakfast: {', '.join(recommendations['breakfast'])} "
        if recommendations.get('lunch'):
            text += f"lunch: {', '.join(recommendations['lunch'])} "
        if recommendations.get('dinner'):
            text += f"dinner: {', '.join(recommendations['dinner'])} "
        if recommendations.get('snacks'):
            text += f"snacks: {', '.join(recommendations['snacks'])}"
        return text.strip()

    def _has_valid_content(self, text: str) -> bool:
        """True if any meal keyword appears in ``text`` (case-insensitive)."""
        return any(meal in text.lower() for meal in ['breakfast', 'lunch', 'dinner', 'snacks'])

    def parse_generated_plan(self, generated_text: str) -> Dict[str, List[str]]:
        """Parse generated text into ``{meal: [items]}``.

        The text is lower-cased first. Strategies are tried in order:
        1. ``"<meal>:"`` keywords. Each section runs to the nearest following
           keyword of another meal (or the end of the text) and is split on
           commas.
        2. If nothing was found, ``<meal>...</meal>`` tag pairs.
        3. If still nothing, a fixed default plan.

        Each meal keeps at most 5 items.
        """
        plan = {meal: [] for meal in self._MEAL_TYPES}
        text = generated_text.lower()

        # Strategy 1: meal keywords followed by a colon.
        for meal_type in self._MEAL_TYPES:
            marker = f'{meal_type}:'
            start = text.find(marker)
            if start == -1:
                continue
            start += len(marker)
            # The section ends at the closest other-meal keyword after it.
            following = [text.find(f'{other}:', start) for other in self._MEAL_TYPES if other != meal_type]
            end = min((pos for pos in following if pos != -1), default=len(text))

            section = text[start:end].strip()
            # Remove any tags
            section = section.replace('</', ' ').replace('<', ' ')
            cleaned_items = []
            for item in section.split(','):
                # Collapse whitespace, drop trailing punctuation.
                item = ' '.join(item.split()).rstrip('.,;')
                if item and len(item) > 1 and not item.startswith('/'):
                    cleaned_items.append(item)
            plan[meal_type] = cleaned_items[:5]  # Limit to 5 items per meal

        # Strategy 2: if no meals found with colon, look for tags
        if not any(plan.values()):
            for meal_type in self._MEAL_TYPES:
                open_tag, close_tag = f'<{meal_type}>', f'</{meal_type}>'
                if open_tag in text and close_tag in text:
                    section = text.split(open_tag)[1].split(close_tag)[0]
                    plan[meal_type] = [item.strip() for item in section.split(',') if item.strip()][:5]

        # If still no meals found, use default plan
        if not any(plan.values()):
            plan = {
                'breakfast': ['oatmeal', 'fruits', 'milk'],
                'lunch': ['rice', 'dal', 'vegetables', 'yogurt'],
                'dinner': ['chapati', 'vegetables', 'soup'],
                'snacks': ['nuts', 'fruits']
            }

        return plan

# Main execution function
def _ensure_data_files(data_files):
    """Report which input CSVs exist and create only the missing ones from sample data.

    Args:
        data_files: sequence of ``(file_path, display_name)`` pairs, in the
            order foods, patients, doctor plans (the same order
            ``create_sample_data`` takes its arguments).

    Existing files are never touched.
    """
    missing = set()
    for file_path, file_name in data_files:
        if Path(file_path).exists():
            print(f"  ✓ {file_name} CSV found")
        else:
            print(f"  ✗ {file_name} CSV not found at {file_path}")
            missing.add(file_path)

    if not missing:
        return

    import shutil
    import tempfile

    print("\n⚠ Creating sample data for demonstration...")
    # create_sample_data writes all three CSVs, so generate the full sample
    # set in a scratch directory and copy over only the files that are
    # missing. This way real user data that does exist is never clobbered.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_paths = [str(Path(tmp_dir) / f"sample_{i}.csv") for i in range(len(data_files))]
        create_sample_data(*tmp_paths)
        for (file_path, _), tmp_path in zip(data_files, tmp_paths):
            if file_path in missing:
                Path(file_path).parent.mkdir(parents=True, exist_ok=True)
                shutil.copyfile(tmp_path, file_path)


def _load_data(foods_csv, patients_csv, plans_csv):
    """Load foods, patients and doctor plans, falling back to sample data on failure.

    Returns:
        ``(foods, patients, plans)`` as returned by the ``load_*_csv`` helpers.

    If any loader raises, sample data is written to a temporary directory and
    loaded from there. The user's (possibly malformed) files are left as-is
    rather than being overwritten.
    """
    try:
        return (
            load_foods_csv(foods_csv),
            load_patients_csv(patients_csv),
            load_doctor_plans_csv(plans_csv),
        )
    except Exception as e:  # demo boundary: any loader failure triggers the fallback
        print(f"  ✗ Error loading data: {e}")
        print("\n⚠ Creating sample data for demonstration...")

    import tempfile

    # NOTE(review): assumes the loaders read their files eagerly, because the
    # scratch directory is deleted as soon as this block exits. TODO confirm.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp = Path(tmp_dir)
        foods_tmp = str(tmp / "foods.csv")
        patients_tmp = str(tmp / "patients.csv")
        plans_tmp = str(tmp / "doctor_plans.csv")
        create_sample_data(foods_tmp, patients_tmp, plans_tmp)
        return (
            load_foods_csv(foods_tmp),
            load_patients_csv(patients_tmp),
            load_doctor_plans_csv(plans_tmp),
        )


def _print_meal_plan(parsed_plan):
    """Print the four meal slots of a parsed plan; a missing or empty slot prints 'None'."""
    meal_rows = (
        ("🌅 Breakfast", "breakfast"),
        ("☀️  Lunch", "lunch"),
        ("🌙 Dinner", "dinner"),
        ("🍎 Snacks", "snacks"),
    )
    for label, key in meal_rows:
        items = parsed_plan.get(key) or []
        print(f"   {label}: {', '.join(items) if items else 'None'}")


def main():
    """Run the end-to-end Ayurveda meal-planning demo.

    Steps:
      1. Log in to Hugging Face if a token is set in the environment.
      2. Make sure the three input CSVs exist. Only missing files are filled
         from sample data; existing files are never overwritten.
      3. Load the data. On a load error, fall back to sample data in a
         temporary directory.
      4. Build the knowledge graph.
      5. Optionally fine-tune the model on the doctor plans.
      6. Generate and print a day-1 plan for up to five patients.
    """
    # Configuration
    HF_TOKEN = os.environ.get('HUGGINGFACE_TOKEN') or os.environ.get('HF_TOKEN')
    FOODS_CSV = "/content/drive/MyDrive/AyruAhaar-datasets/foods.csv"
    PATIENTS_CSV = "/content/drive/MyDrive/AyruAhaar-datasets/patients.csv"
    PLANS_CSV = "/content/drive/MyDrive/AyruAhaar-datasets/doctor_plans.csv"

    MODEL_TYPE = "t5"  # Currently only T5 is fully implemented
    MODEL_NAME = "t5-small"  # Using smaller model for faster inference
    DO_TRAIN = True  # Set to False to skip fine-tuning
    OUTPUT_DIR = "./ayurveda_meal_planner"
    EPOCHS = 2  # Reduced for faster training
    BATCH_SIZE = 1  # Smaller batch size
    LEARNING_RATE = 3e-4  # Higher learning rate for small dataset
    MAX_PATIENTS_SHOWN = 5

    print("=" * 60)
    print("🌿 Ayurveda Meal Planning System")
    print("=" * 60)

    # Login to Hugging Face if token available
    if HF_TOKEN:
        # NOTE(review): HEAD imports `hf_login`; `huggingface_login` must be
        # defined elsewhere in this file. Confirm, otherwise this is a NameError.
        huggingface_login(HF_TOKEN)

    # Order matters: foods, patients, plans (matches create_sample_data's signature).
    data_files = [(FOODS_CSV, "foods"), (PATIENTS_CSV, "patients"), (PLANS_CSV, "doctor_plans")]

    print("\n📁 Checking data files...")
    _ensure_data_files(data_files)

    print("\n📊 Loading data...")
    foods, patients, plans = _load_data(FOODS_CSV, PATIENTS_CSV, PLANS_CSV)
    print(f"  ✓ Loaded {len(foods)} foods")
    print(f"  ✓ Loaded {len(patients)} patients")
    print(f"  ✓ Loaded {len(plans)} meal plans")

    # Initialize engine
    print("\n🤖 Initializing AI engine...")
    engine = HybridNeuralEngine(model_type=MODEL_TYPE, model_name=MODEL_NAME)
    print(f"  ✓ Model loaded: {MODEL_NAME}")

    # Build knowledge graph
    print("\n🔗 Building knowledge graph...")
    graph_data = engine.build_knowledge_graph(foods, patients)
    print(f"  ✓ Graph built with {graph_data.x.shape[0]} nodes")

    # Prepare dataset and optionally fine-tune
    if DO_TRAIN and len(plans) > 0:
        print("\n📚 Preparing training dataset...")
        dataset = AyurvedaMealPlanDataset(patients, plans, engine.tokenizer, model_type=MODEL_TYPE)

        # 90/10 train/validation split
        train_size = int(0.9 * len(dataset))
        val_size = len(dataset) - train_size

        if train_size > 0:
            train_dataset, val_dataset = torch.utils.data.random_split(
                dataset, [train_size, val_size]
            )

            print(f"  ✓ Train set: {len(train_dataset)} samples")
            print(f"  ✓ Validation set: {len(val_dataset)} samples")

            print("\n🎯 Fine-tuning model...")
            engine.fine_tune(
                train_dataset,
                val_dataset if val_size > 0 else None,
                output_dir=OUTPUT_DIR,
                num_epochs=EPOCHS,
                batch_size=BATCH_SIZE,
                learning_rate=LEARNING_RATE,
            )
        else:
            # int(0.9 * n) is 0 for a single sample; there is nothing to train on.
            print("  ⚠ Not enough samples to fine-tune; skipping")

    # Generate meal plans
    print("\n🍽️ Generating meal plans...")
    print("-" * 60)

    for patient in patients[:MAX_PATIENTS_SHOWN]:
        print(f"\n👤 Patient {patient.id}:")
        print(f"   Age: {patient.age}, Gender: {patient.gender}")
        print(f"   BMI: {patient.bmi:.1f}, Prakriti: {patient.prakriti}")
        if patient.health_conditions:
            print(f"   Conditions: {', '.join(patient.health_conditions)}")

        generated_text = engine.generate_meal_plan(patient, day=1, graph_data=graph_data)
        parsed_plan = engine.parse_generated_plan(generated_text)

        print("\n   📅 Day 1 Meal Plan:")
        _print_meal_plan(parsed_plan)

    print("\n" + "=" * 60)
    print("✅ Meal plan generation complete!")
    print("=" * 60)

def create_sample_data(foods_path: str, patients_path: str, plans_path: str, *, overwrite: bool = True):
    """Write a small, self-consistent sample dataset to three CSV files.

    The sample contains 5 foods, 3 patients (P001-P003) and 6 doctor plans
    (two days per patient). Parent directories are created as needed.

    Args:
        foods_path: destination for the foods CSV.
        patients_path: destination for the patients CSV.
        plans_path: destination for the doctor-plans CSV.
        overwrite: when True (the default, matching historical behaviour),
            existing files are replaced. When False, files that already
            exist are left untouched and only missing ones are written.

    List-valued fields (guna, health_tags, health_conditions, plan items)
    are encoded as '|'-separated strings; dosha_effects uses
    'dosha:+/-' pairs joined by commas.
    """
    # Sample foods data
    foods_data = {
        'id': ['F001', 'F002', 'F003', 'F004', 'F005'],
        'name': ['Rice', 'Moong Dal', 'Ghee', 'Turmeric Milk', 'Almonds'],
        'category': ['grains', 'legumes', 'oils', 'dairy', 'nuts'],
        'calories': [130, 347, 900, 80, 579],
        'protein': [2.7, 24, 0, 3.5, 21],
        'carbs': [28, 63, 0, 12, 22],
        'fats': [0.3, 1.2, 100, 3, 50],
        'fiber': [0.4, 16, 0, 0, 12],
        'rasa': ['sweet', 'sweet', 'sweet', 'sweet', 'sweet'],
        'virya': ['cooling', 'cooling', 'cooling', 'heating', 'heating'],
        'vipaka': ['sweet', 'sweet', 'sweet', 'sweet', 'sweet'],
        'dosha_effects': ['vata:-,pitta:-,kapha:+', 'vata:-,pitta:-,kapha:-',
                          'vata:-,pitta:-,kapha:+', 'vata:-,pitta:-,kapha:-',
                          'vata:-,pitta:+,kapha:+'],
        'health_tags': ['digestion', 'protein|digestion', 'immunity', 'sleep|immunity', 'brain|heart'],
        'guna': ['light|soft', 'light|dry', 'heavy|oily', 'light|oily', 'heavy|oily'],
    }

    # Sample patients data
    patients_data = {
        'id': ['P001', 'P002', 'P003'],
        'age': [35, 28, 45],
        'gender': ['male', 'female', 'male'],
        'weight': [70, 60, 80],
        'height': [175, 165, 180],
        'bmi': [22.9, 22.0, 24.7],
        'lifestyle': ['moderate', 'active', 'sedentary'],
        'prakriti': ['vata', 'pitta', 'kapha'],
        'health_conditions': ['', 'acidity', 'diabetes|obesity'],
        'allergies': ['', 'nuts', ''],
        'preferred_cuisine': ['indian', 'indian', 'indian'],
    }

    # Sample meal plans data (two days per patient)
    plans_data = {
        'patient_id': ['P001', 'P001', 'P002', 'P002', 'P003', 'P003'],
        'day': [1, 2, 1, 2, 1, 2],
        'breakfast': ['Rice porridge|Almonds|Milk', 'Oatmeal|Dates|Ghee',
                      'Fruit salad|Yogurt', 'Cereal|Coconut water',
                      'Green tea|Apple', 'Herbal tea|Berries'],
        'lunch': ['Rice|Moong dal|Ghee|Vegetables', 'Khichdi|Yogurt|Salad',
                  'Quinoa|Green vegetables|Cucumber', 'Rice|Dal|Steamed vegetables',
                  'Millet|Bitter gourd|Spiced buttermilk', 'Barley|Mixed vegetables|Green salad'],
        'dinner': ['Soup|Bread|Cooked vegetables', 'Rice|Dal|Spinach',
                   'Sweet potato|Green beans|Salad', 'Pasta|Vegetables|Soup',
                   'Vegetable soup|Brown rice', 'Clear soup|Steamed vegetables'],
        'snacks': ['Banana|Soaked almonds', 'Dates|Warm milk',
                   'Watermelon|Coconut water', 'Apple|Herbal tea',
                   'Pear|Green tea', 'Cucumber|Carrot sticks'],
        'restrictions': ['', '', 'no nuts', 'no nuts', 'low sugar', 'low sugar'],
        'doctor_notes': ['Increase warm foods', 'Focus on grounding foods',
                         'Cooling foods recommended', 'Avoid heating foods',
                         'Light foods, avoid heavy meals', 'Increase metabolism'],
    }

    tables = (
        (foods_path, foods_data),
        (patients_path, patients_data),
        (plans_path, plans_data),
    )

    written = 0
    for path, data in tables:
        target = Path(path)
        if target.exists() and not overwrite:
            continue  # preserve the caller's existing file
        target.parent.mkdir(parents=True, exist_ok=True)
        pd.DataFrame(data).to_csv(target, index=False)
        written += 1

    if written:
        print("  ✓ Created sample data files")
    else:
        print("  ✓ Sample data files already exist; nothing created")

# Entry point: run the demo when this module is executed as the main module
# (this is also the case when the cell is run in a notebook).
if __name__ == "__main__":
    main()

🌿 Ayurveda Meal Planning System

📁 Checking data files...
  ✓ foods CSV found
  ✓ patients CSV found
  ✓ doctor_plans CSV found

📊 Loading data...
  ✓ Loaded 20 foods
  ✓ Loaded 5 patients
  ✓ Loaded 10 meal plans

🤖 Initializing AI engine...
Using device: cuda
  ✓ Model loaded: t5-small

🔗 Building knowledge graph...
Building knowledge graph with 20 foods and 5 patients...
  ✓ Graph built with 44 nodes

📚 Preparing training dataset...
  ✓ Train set: 9 samples
  ✓ Validation set: 1 samples

🎯 Fine-tuning model...
Starting fine-tuning for 2 epochs...


Step,Training Loss,Validation Loss


✓ Model fine-tuned and saved to ./ayurveda_meal_planner

🍽️ Generating meal plans...
------------------------------------------------------------

👤 Patient 1:
   Age: 25, Gender: Female
   BMI: 24.2, Prakriti: Vata

   📅 Day 1 Meal Plan:
   🌅 Breakfast: oatmeal, fruits, milk
   ☀️  Lunch: rice, dal, vegetables, yogurt
   🌙 Dinner: chapati, vegetables, soup
   🍎 Snacks: nuts, fruits

👤 Patient 2:
   Age: 35, Gender: Male
   BMI: 24.2, Prakriti: Pitta
   Conditions: ['hypertension']

   📅 Day 1 Meal Plan:
   🌅 Breakfast: oatmeal, fruits, milk
   ☀️  Lunch: rice, dal, vegetables, yogurt
   🌙 Dinner: chapati, vegetables, soup
   🍎 Snacks: nuts, fruits

👤 Patient 3:
   Age: 45, Gender: Female
   BMI: 24.2, Prakriti: Kapha
   Conditions: ['diabetes']

   📅 Day 1 Meal Plan:
   🌅 Breakfast: oatmeal, fruits, milk
   ☀️  Lunch: rice, dal, vegetables, yogurt
   🌙 Dinner: chapati, vegetables, soup
   🍎 Snacks: nuts, fruits

👤 Patient 4:
   Age: 30, Gender: Male
   BMI: 24.2, Prakriti: Vata-Pitta
