<a href="https://www.kaggle.com/code/nithyasrikumaravelu/recipe-genieee?scriptVersionId=233091175" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
pip install pandas numpy torch transformers datasets accelerate sentencepiece

Note: you may need to restart the kernel to use updated packages.


*DATA PIPELINE* - Randomly generated recipe ingredients and steps, input - output pairs, saves(training dataset)

In [3]:
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer
from datasets import Dataset, DatasetDict
import os
import json
import random

def create_recipe_dataset(sample_size=10000, seed=42):
    """
    Create a compatible dataset for the flax-community/t5-recipe-generation model.
    Uses sample data formatted to match the expected input/output structure.
    """
    print("Creating dataset compatible with flax-community/t5-recipe-generation...")
    
    # Set seeds for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    
    # Load the tokenizer to understand the special tokens used by the model
    tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-recipe-generation", use_fast=True)
    
    # Sample ingredients for synthetic data
    common_ingredients = [
        "chicken", "beef", "pork", "salmon", "tuna", "shrimp", "lamb",
        "pasta", "rice", "potatoes", "bread", "flour", "oats", "quinoa",
        "onion", "garlic", "tomatoes", "bell peppers", "carrots", "broccoli", 
        "spinach", "lettuce", "mushrooms", "zucchini", "eggplant",
        "butter", "olive oil", "vegetable oil", "coconut oil", 
        "salt", "pepper", "oregano", "basil", "thyme", "rosemary", "cumin",
        "milk", "cream", "cheese", "yogurt", "eggs", "mayonnaise",
        "sugar", "brown sugar", "honey", "maple syrup",
        "beans", "lentils", "chickpeas", "tofu", "tempeh",
        "chocolate", "vanilla extract", "cinnamon", "nutmeg"
    ]
    
    # Create directory for generated data
    os.makedirs("synthetic_recipe_data", exist_ok=True)
    
    # Generate synthetic recipes
    recipes = []
    
    print(f"Generating {sample_size} synthetic recipes...")
    for i in range(sample_size):
        # Generate random recipe properties
        num_ingredients = random.randint(4, 10)
        ingredients = random.sample(common_ingredients, num_ingredients)
        
        # Create recipe title
        main_ingredient = random.choice(ingredients)
        cooking_methods = ["Roasted", "Grilled", "Baked", "Fried", "Steamed", "Sautéed", "Slow-cooked"]
        dish_types = ["Casserole", "Soup", "Stew", "Salad", "Pasta", "Curry", "Stir-fry", "Sandwich"]
        title = f"{random.choice(cooking_methods)} {main_ingredient.capitalize()} {random.choice(dish_types)}"
        
        # Create recipe directions
        directions = [
            f"Prepare all the ingredients",
            f"Heat {random.choice(['oven', 'pan', 'pot', 'skillet'])} to {random.randint(300, 450)} degrees",
            f"Mix {ingredients[0]} and {ingredients[1]} together",
            f"Add {', '.join(ingredients[2:4])} and cook for {random.randint(5, 30)} minutes",
            f"Stir in {', '.join(ingredients[4:])}",
            f"Season with salt and pepper to taste",
            f"Serve hot"
        ]
        
        # Create formatted ingredient list for model input
        input_ingredients = ", ".join(ingredients)
        
        # Format the output text with special tokens
        output_text = (
            f"title: {title} <section> "
            f"ingredients: {' <sep> '.join(ingredients)} <section> "
            f"directions: {' <sep> '.join(directions)}"
        )
        
        recipes.append({
            "input_text": f"items: {input_ingredients}",
            "output_text": output_text
        })
    
    # Save raw synthetic data
    with open("synthetic_recipe_data/synthetic_recipes.json", 'w') as f:
        json.dump(recipes, f)
    
    # Create dataset splits
    df = pd.DataFrame(recipes)
    
    # Split into train/validation/test (80/10/10)
    df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
    train_size = int(0.8 * len(df))
    val_size = int(0.1 * len(df))
    
    train_df = df[:train_size]
    val_df = df[train_size:train_size+val_size]
    test_df = df[train_size+val_size:]
    
    # Convert to HuggingFace datasets
    dataset_dict = DatasetDict({
        'train': Dataset.from_pandas(train_df),
        'validation': Dataset.from_pandas(val_df),
        'test': Dataset.from_pandas(test_df)
    })
    
    # Save dataset
    dataset_dict.save_to_disk("recipe_dataset")
    print(f"Dataset created with {len(df)} recipes")
    print(f"Train: {len(train_df)}, Validation: {len(val_df)}, Test: {len(test_df)}")
    
    # Display a few examples
    print("\nExample recipes:")
    for i in range(3):
        idx = random.randint(0, len(df) - 1)
        print(f"\nInput: {df.iloc[idx]['input_text']}")
        print(f"Output: {df.iloc[idx]['output_text']}")
    
    return dataset_dict

if __name__ == "__main__":
    create_recipe_dataset()

Creating dataset compatible with flax-community/t5-recipe-generation...


tokenizer_config.json:   0%|          | 0.00/1.92k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

Generating 10000 synthetic recipes...


Saving the dataset (0/1 shards):   0%|          | 0/8000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]

Dataset created with 10000 recipes
Train: 8000, Validation: 1000, Test: 1000

Example recipes:

Input: items: salmon, nutmeg, milk, flour, cumin, bread, pasta, chicken
Output: title: Slow-cooked Flour Sandwich <section> ingredients: salmon <sep> nutmeg <sep> milk <sep> flour <sep> cumin <sep> bread <sep> pasta <sep> chicken <section> directions: Prepare all the ingredients <sep> Heat oven to 385 degrees <sep> Mix salmon and nutmeg together <sep> Add milk, flour and cook for 8 minutes <sep> Stir in cumin, bread, pasta, chicken <sep> Season with salt and pepper to taste <sep> Serve hot

Input: items: flour, lamb, oregano, potatoes, pork, carrots, rosemary
Output: title: Slow-cooked Pork Stir-fry <section> ingredients: flour <sep> lamb <sep> oregano <sep> potatoes <sep> pork <sep> carrots <sep> rosemary <section> directions: Prepare all the ingredients <sep> Heat oven to 376 degrees <sep> Mix flour and lamb together <sep> Add oregano, potatoes and cook for 8 minutes <sep> Stir in pork, ca

In [4]:
from transformers import AutoTokenizer, T5TokenizerFast
from datasets import load_from_disk
import torch
from torch.utils.data import DataLoader

def preprocess_recipe_data(model_name="t5-base", max_input_length=256, max_output_length=512):
    """
    Preprocess recipe dataset for sequence-to-sequence training with T5.
    """
    print("Preprocessing recipe dataset...")
    
    # Load dataset
    dataset = load_from_disk("recipe_dataset")
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # Add special tokens for recipe formatting
    special_tokens = {"additional_special_tokens": ["<sep>", "<section>"]}
    tokenizer.add_special_tokens(special_tokens)
    print(f"Added special tokens to the tokenizer")
    
    def preprocess_function(examples):
        # Tokenize inputs
        model_inputs = tokenizer(
            examples["input_text"],
            max_length=max_input_length,
            padding="max_length",
            truncation=True
        )
        
        # Tokenize outputs
        labels = tokenizer(
            examples["output_text"],
            max_length=max_output_length,
            padding="max_length",
            truncation=True
        )
        
        model_inputs["labels"] = labels["input_ids"]
        
        # Replace padding token id with -100 for loss calculation
        for i in range(len(model_inputs["labels"])):
            model_inputs["labels"][i] = [
                -100 if token == tokenizer.pad_token_id else token 
                for token in model_inputs["labels"][i]
            ]
        
        return model_inputs
    
    # Apply preprocessing to all splits
    processed_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset["train"].column_names
    )
    
    # Save processed dataset
    processed_dataset.save_to_disk("preprocessed_recipe_dataset")
    print(f"Preprocessed dataset saved to 'preprocessed_recipe_dataset'")
    
    # Create PyTorch dataloaders
    def collate_fn(batch):
        input_ids = torch.tensor([item["input_ids"] for item in batch])
        attention_mask = torch.tensor([item["attention_mask"] for item in batch])
        labels = torch.tensor([item["labels"] for item in batch])
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }
    
    train_dataloader = DataLoader(
        processed_dataset["train"], 
        batch_size=8, 
        shuffle=True, 
        collate_fn=collate_fn
    )
    
    eval_dataloader = DataLoader(
        processed_dataset["validation"], 
        batch_size=8, 
        collate_fn=collate_fn
    )
    
    print(f"Created dataloaders with batch size 8")
    print(f"Training batches: {len(train_dataloader)}, Validation batches: {len(eval_dataloader)}")
    
    return processed_dataset, tokenizer, train_dataloader, eval_dataloader

if __name__ == "__main__":
    processed_dataset, tokenizer, _, _ = preprocess_recipe_data()
    print(f"Processed dataset created with {len(processed_dataset['train'])} training examples")
    print(f"Tokenizer vocabulary size: {len(tokenizer)}")
    
    # Sample preprocessing results
    sample = processed_dataset["train"][0]
    print("\nSample preprocessed example:")
    print(f"Input IDs (first 10): {sample['input_ids'][:10]}")
    print(f"Attention mask (first 10): {sample['attention_mask'][:10]}")
    print(f"Labels (first 10): {sample['labels'][:10]}")

Preprocessing recipe dataset...


config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

Added special tokens to the tokenizer


Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/8000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]

Preprocessed dataset saved to 'preprocessed_recipe_dataset'
Created dataloaders with batch size 8
Training batches: 1000, Validation batches: 125
Processed dataset created with 8000 training examples
Tokenizer vocabulary size: 32102

Sample preprocessed example:
Input IDs (first 10): [1173, 10, 5240, 9, 6, 18684, 6, 21659, 6, 9177]
Attention mask (first 10): [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Labels (first 10): [2233, 10, 7745, 17, 721, 26, 10792, 1836, 26335, 32101]


In [5]:
class RecipePostProcessor:
    def __init__(self, tokenizer):
        """
        Initialize post-processor with tokenizer for handling special tokens.
        Designed to work with the synthetic dataset format.
        """
        self.tokenizer = tokenizer
        self.special_tokens = tokenizer.all_special_tokens
        self.tokens_map = {
            "<sep>": "--",
            "<section>": "\n"
        }
    
    def postprocess_text(self, generated_texts):
        """
        Post-process generated recipe texts:
        1. Remove special tokens except our custom <sep> and <section>
        2. Replace mapped tokens with their human-readable versions
        3. Format into structured recipes
        """
        processed_recipes = []
        
        for text in generated_texts:
            # Remove special tokens except our custom ones
            for token in self.special_tokens:
                if token not in self.tokens_map:
                    text = text.replace(token, "")
            
            # Replace mapped tokens with readable versions
            for k, v in self.tokens_map.items():
                text = text.replace(k, v)
            
            # Format structured recipe
            formatted_recipe = self._format_recipe(text)
            processed_recipes.append(formatted_recipe)
        
        return processed_recipes
    
    def _format_recipe(self, text):
        """Format recipe text into structured dictionary"""
        recipe_dict = {"title": "", "ingredients": [], "directions": []}
        
        sections = text.split("\n")
        for section in sections:
            section = section.strip()
            
            if section.startswith("title:"):
                recipe_dict["title"] = section.replace("title:", "").strip().capitalize()
            
            elif section.startswith("ingredients:"):
                ingr_text = section.replace("ingredients:", "").strip()
                recipe_dict["ingredients"] = [
                    item.strip().capitalize() for item in ingr_text.split("--") if item.strip()
                ]
            
            elif section.startswith("directions:"):
                dir_text = section.replace("directions:", "").strip()
                recipe_dict["directions"] = [
                    item.strip().capitalize() for item in dir_text.split("--") if item.strip()
                ]
        
        return recipe_dict
    
    def format_for_display(self, recipe_dict):
        """Format recipe dictionary for display"""
        display_text = f"[TITLE]: {recipe_dict['title']}\n\n"
        
        display_text += "[INGREDIENTS]:\n"
        for i, ingredient in enumerate(recipe_dict['ingredients']):
            display_text += f"  - {i+1}: {ingredient}\n"
        
        display_text += "\n[DIRECTIONS]:\n"
        for i, step in enumerate(recipe_dict['directions']):
            display_text += f"  - {i+1}: {step}\n"
        
        return display_text
    
    def evaluate_recipe_quality(self, recipe_dict):
        """Evaluate recipe quality based on simple heuristics"""
        score = 0
        max_score = 100
        
        # Check title
        if recipe_dict["title"] and len(recipe_dict["title"]) > 3:
            score += 10
        
        # Check ingredients
        if recipe_dict["ingredients"]:
            score += min(len(recipe_dict["ingredients"]) * 5, 30)  # Up to 30 points for ingredients
            
            # Check for ingredient variety (rough estimate)
            unique_words = set()
            for ingredient in recipe_dict["ingredients"]:
                unique_words.update(ingredient.lower().split())
            score += min(len(unique_words) * 2, 20)  # Up to 20 points for variety
        
        # Check directions
        if recipe_dict["directions"]:
            score += min(len(recipe_dict["directions"]) * 5, 30)  # Up to 30 points for directions
            
            # Check for detailed instructions (rough estimate by length)
            total_length = sum(len(step) for step in recipe_dict["directions"])
            score += min(total_length // 50, 10)  # Up to 10 points for detail
        
        # Normalize to 100
        normalized_score = min(score, max_score)
        
        return normalized_score

if __name__ == "__main__":
    # Example usage
    from transformers import AutoTokenizer
    
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    special_tokens = {"additional_special_tokens": ["<sep>", "<section>"]}
    tokenizer.add_special_tokens(special_tokens)
    
    processor = RecipePostProcessor(tokenizer)
    
    sample_text = "title: roasted chicken pasta <section> ingredients: chicken <sep> pasta <sep> garlic <sep> olive oil <sep> salt <sep> pepper <section> directions: prepare all the ingredients <sep> cook pasta <sep> season chicken <sep> roast chicken <sep> combine with pasta <sep> serve hot"
    
    # Process a single text
    processed = processor.postprocess_text([sample_text])[0]
    display_text = processor.format_for_display(processed)
    
    print("Original text:")
    print(sample_text)
    print("\nProcessed recipe:")
    print(display_text)
    
    # Evaluate quality
    quality_score = processor.evaluate_recipe_quality(processed)
    print(f"\nRecipe quality score: {quality_score}/100")

Original text:
title: roasted chicken pasta <section> ingredients: chicken <sep> pasta <sep> garlic <sep> olive oil <sep> salt <sep> pepper <section> directions: prepare all the ingredients <sep> cook pasta <sep> season chicken <sep> roast chicken <sep> combine with pasta <sep> serve hot

Processed recipe:
[TITLE]: Roasted chicken pasta

[INGREDIENTS]:
  - 1: Chicken
  - 2: Pasta
  - 3: Garlic
  - 4: Olive oil
  - 5: Salt
  - 6: Pepper

[DIRECTIONS]:
  - 1: Prepare all the ingredients
  - 2: Cook pasta
  - 3: Season chicken
  - 4: Roast chicken
  - 5: Combine with pasta
  - 6: Serve hot


Recipe quality score: 85/100


In [6]:
class RecipePostProcessor:
    def __init__(self, tokenizer):
        """
        Initialize post-processor with tokenizer for handling special tokens.
        Includes allergen detection and additional recipe processing.
        """
        self.tokenizer = tokenizer
        self.special_tokens = tokenizer.all_special_tokens
        self.tokens_map = {
            "<sep>": "--",
            "<section>": "\n"
        }
        
        # Allergen combinations to avoid
        self.allergen_combinations = [
            # Format: (ingredient1, ingredient2, reason)
            ("fish", "dairy", "Fish and dairy combinations can cause digestive issues for many people"),
            ("shellfish", "dairy", "Shellfish and dairy combinations can trigger allergic reactions"),
            ("fish", "yogurt", "Fish and yogurt may cause digestive issues"),
            ("fish", "milk", "Fish and milk can cause adverse reactions in some individuals"),
            ("fish", "curd", "Fish and curd combinations may cause allergic reactions"),
            ("fish", "cheese", "Fish and cheese can trigger food sensitivities"),
            ("peanuts", "gluten", "Peanuts and gluten can cause severe reactions in some people"),
            ("shellfish", "mango", "Shellfish and mango can cause histamine reactions"),
            ("strawberry", "chocolate", "Strawberry and chocolate may trigger migraine in sensitive individuals")
        ]
    
    def postprocess_text(self, generated_texts):
        """
        Post-process generated recipe texts:
        1. Remove special tokens except our custom ones
        2. Replace mapped tokens with their human-readable versions
        3. Format into structured recipes
        4. Check for allergen combinations
        """
        processed_recipes = []
        
        for text in generated_texts:
            # Remove special tokens except our custom ones
            for token in self.special_tokens:
                if token not in self.tokens_map:
                    text = text.replace(token, "")
            
            # Replace mapped tokens with readable versions
            for k, v in self.tokens_map.items():
                text = text.replace(k, v)
            
            # Format structured recipe
            formatted_recipe = self._format_recipe(text)
            
            # Check for allergen combinations
            allergen_warnings = self.check_allergens(formatted_recipe)
            formatted_recipe["allergen_warnings"] = allergen_warnings
            
            processed_recipes.append(formatted_recipe)
        
        return processed_recipes
    
    def _format_recipe(self, text):
        """Format recipe text into structured dictionary"""
        recipe_dict = {"title": "", "ingredients": [], "directions": [], "allergen_warnings": []}
        
        sections = text.split("\n")
        for section in sections:
            section = section.strip()
            
            if section.startswith("title:"):
                recipe_dict["title"] = section.replace("title:", "").strip().capitalize()
            
            elif section.startswith("ingredients:"):
                ingr_text = section.replace("ingredients:", "").strip()
                recipe_dict["ingredients"] = [
                    item.strip().capitalize() for item in ingr_text.split("--") if item.strip()
                ]
            
            elif section.startswith("directions:"):
                dir_text = section.replace("directions:", "").strip()
                recipe_dict["directions"] = [
                    item.strip().capitalize() for item in dir_text.split("--") if item.strip()
                ]
        
        return recipe_dict
    
    def check_allergens(self, recipe_dict):
        """Check for potentially problematic ingredient combinations"""
        warnings = []
        ingredients_lower = [ing.lower() for ing in recipe_dict["ingredients"]]
        
        for ing1, ing2, reason in self.allergen_combinations:
            # Check if both ingredients are present
            has_ing1 = any(ing1 in ingredient for ingredient in ingredients_lower)
            has_ing2 = any(ing2 in ingredient for ingredient in ingredients_lower)
            
            if has_ing1 and has_ing2:
                warnings.append({
                    "ingredients": (ing1, ing2),
                    "reason": reason
                })
        
        return warnings
    
    def filter_allergenic_recipes(self, recipes):
        """Filter out recipes with allergen combinations"""
        safe_recipes = []
        filtered_count = 0
        
        for recipe in recipes:
            if not recipe.get("allergen_warnings", []):
                safe_recipes.append(recipe)
            else:
                filtered_count += 1
        
        print(f"Filtered {filtered_count} recipes with allergen combinations")
        return safe_recipes
    
    def format_for_display(self, recipe_dict):
        """Format recipe dictionary for display"""
        display_text = f"[TITLE]: {recipe_dict['title']}\n\n"
        
        display_text += "[INGREDIENTS]:\n"
        for i, ingredient in enumerate(recipe_dict['ingredients']):
            display_text += f"  - {i+1}: {ingredient}\n"
        
        display_text += "\n[DIRECTIONS]:\n"
        for i, step in enumerate(recipe_dict['directions']):
            display_text += f"  - {i+1}: {step}\n"
        
        # Add allergen warnings if present
        if recipe_dict.get("allergen_warnings", []):
            display_text += "\n[⚠️ ALLERGEN WARNINGS]:\n"
            for i, warning in enumerate(recipe_dict["allergen_warnings"]):
                display_text += f"  - Warning {i+1}: {warning['ingredients'][0]} and {warning['ingredients'][1]} - {warning['reason']}\n"
        
        return display_text
    
    def evaluate_recipe_quality(self, recipe_dict):
        """Evaluate recipe quality based on simple heuristics"""
        score = 0
        max_score = 100
        
        # Check title
        if recipe_dict["title"] and len(recipe_dict["title"]) > 3:
            score += 10
        
        # Check ingredients
        if recipe_dict["ingredients"]:
            score += min(len(recipe_dict["ingredients"]) * 5, 30)  # Up to 30 points for ingredients
            
            # Check for ingredient variety (rough estimate)
            unique_words = set()
            for ingredient in recipe_dict["ingredients"]:
                unique_words.update(ingredient.lower().split())
            score += min(len(unique_words) * 2, 20)  # Up to 20 points for variety
        
        # Check directions
        if recipe_dict["directions"]:
            score += min(len(recipe_dict["directions"]) * 5, 30)  # Up to 30 points for directions
            
            # Check for detailed instructions (rough estimate by length)
            total_length = sum(len(step) for step in recipe_dict["directions"])
            score += min(total_length // 50, 10)  # Up to 10 points for detail
        
        # Deduct points for allergen warnings
        allergen_warnings = recipe_dict.get("allergen_warnings", [])
        score -= len(allergen_warnings) * 15  # Deduct 15 points per allergen warning
        
        # Normalize to 100 (but don't go below 0)
        normalized_score = max(min(score, max_score), 0)
        
        return normalized_score

if __name__ == "__main__":
    # Example usage
    from transformers import AutoTokenizer
    
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    special_tokens = {"additional_special_tokens": ["<sep>", "<section>"]}
    tokenizer.add_special_tokens(special_tokens)
    
    processor = RecipePostProcessor(tokenizer)
    
    # Example 1: Basic recipe
    sample_text1 = "title: roasted chicken pasta <section> ingredients: chicken <sep> pasta <sep> garlic <sep> olive oil <sep> salt <sep> pepper <section> directions: prepare all the ingredients <sep> cook pasta <sep> season chicken <sep> roast chicken <sep> combine with pasta <sep> serve hot"
    
    # Example 2: Recipe with potential allergens (fish + cheese)
    sample_text2 = "title: tuna melt sandwich <section> ingredients: tuna <sep> mayonnaise <sep> cheese <sep> bread <sep> tomato <sep> onion <section> directions: mix tuna with mayonnaise <sep> place on bread <sep> top with cheese <sep> toast until cheese melts <sep> add tomato and onion <sep> serve warm"
    
    # Example 3: Another recipe with different allergens (shellfish + dairy)
    sample_text3 = "title: creamy shrimp pasta <section> ingredients: shrimp <sep> pasta <sep> heavy cream <sep> parmesan cheese <sep> garlic <sep> butter <sep> parsley <section> directions: cook pasta <sep> sauté garlic in butter <sep> add shrimp and cook <sep> add heavy cream <sep> add cooked pasta <sep> sprinkle with parmesan <sep> garnish with parsley"
    
    # Process all sample texts
    print("Processing multiple recipes:\n")
    all_sample_texts = [sample_text1, sample_text2, sample_text3]
    processed_recipes = processor.postprocess_text(all_sample_texts)
    
    for i, recipe in enumerate(processed_recipes):
        print(f"\n--- RECIPE {i+1} ---")
        display_text = processor.format_for_display(recipe)
        print(display_text)
        
        # Evaluate quality
        quality_score = processor.evaluate_recipe_quality(recipe)
        print(f"Recipe quality score: {quality_score}/100")
    
    # Demonstrate filtering allergenic recipes
    print("\nFiltering allergenic recipes:")
    safe_recipes = processor.filter_allergenic_recipes(processed_recipes)
    print(f"Original recipes: {len(processed_recipes)}, Safe recipes: {len(safe_recipes)}")
    
    # Display the safe recipes
    print("\nSafe recipes after filtering:")
    for i, recipe in enumerate(safe_recipes):
        print(f"\n--- SAFE RECIPE {i+1} ---")
        print(processor.format_for_display(recipe))

Processing multiple recipes:


--- RECIPE 1 ---
[TITLE]: Roasted chicken pasta

[INGREDIENTS]:
  - 1: Chicken
  - 2: Pasta
  - 3: Garlic
  - 4: Olive oil
  - 5: Salt
  - 6: Pepper

[DIRECTIONS]:
  - 1: Prepare all the ingredients
  - 2: Cook pasta
  - 3: Season chicken
  - 4: Roast chicken
  - 5: Combine with pasta
  - 6: Serve hot

Recipe quality score: 85/100

--- RECIPE 2 ---
[TITLE]: Tuna melt sandwich

[INGREDIENTS]:
  - 1: Tuna
  - 2: Mayonnaise
  - 3: Cheese
  - 4: Bread
  - 5: Tomato
  - 6: Onion

[DIRECTIONS]:
  - 1: Mix tuna with mayonnaise
  - 2: Place on bread
  - 3: Top with cheese
  - 4: Toast until cheese melts
  - 5: Add tomato and onion
  - 6: Serve warm

Recipe quality score: 84/100

--- RECIPE 3 ---
[TITLE]: Creamy shrimp pasta

[INGREDIENTS]:
  - 1: Shrimp
  - 2: Pasta
  - 3: Heavy cream
  - 4: Parmesan cheese
  - 5: Garlic
  - 6: Butter
  - 7: Parsley

[DIRECTIONS]:
  - 1: Cook pasta
  - 2: Sauté garlic in butter
  - 3: Add shrimp and cook
  - 4: Add heavy cream
  

In [7]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import time

# Allergen test class that uses our RecipePostProcessor
class AllergenDetectionTester:
    def __init__(self, model_path="./recipe_model"):
        """Initialize the tester with model and post-processor."""
        print(f"Loading model and tokenizer from {model_path}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        except:
            print("Model not found, using t5-base for demonstration")
            self.tokenizer = AutoTokenizer.from_pretrained("t5-base")
            self.model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
            
        # Add special tokens for recipe formatting if needed
        special_tokens = {"additional_special_tokens": ["<sep>", "<section>"]}
        self.tokenizer.add_special_tokens(special_tokens)
        
        # Initialize the post-processor
        self.processor = RecipePostProcessor(self.tokenizer)
        
        # Optimize model for inference
        self.model = self.model.eval()
        if torch.cuda.is_available():
            print("Using GPU for inference")
            self.model = self.model.cuda()
        else:
            print("Using CPU for inference")
            
        # Default generation parameters
        self.generation_kwargs = {
            "max_length": 512,
            "min_length": 64,
            "no_repeat_ngram_size": 3,
            "num_beams": 4,
            "temperature": 0.8,
            "do_sample": True,
            "top_p": 0.95
        }
        
        print("Tester initialized")
        
    def generate_recipe(self, ingredients):
        """Generate a recipe from ingredients."""
        # Format input
        if isinstance(ingredients, list):
            ingredients = ", ".join(ingredients)
            
        input_text = f"items: {ingredients}"
        
        # For demonstration, we'll also create a mock recipe if model is t5-base
        if "t5-base" in self.tokenizer.name_or_path:
            print("Using mock recipe generation for demonstration")
            # Create a mock recipe that includes all ingredients
            ingredients_list = [ing.strip() for ing in ingredients.split(",")]
            mock_recipe = (
                f"title: {ingredients_list[0].capitalize()} Recipe <section> "
                f"ingredients: {' <sep> '.join(ingredients_list)} <section> "
                f"directions: mix ingredients <sep> cook <sep> serve"
            )
            return mock_recipe
            
        # Tokenize
        inputs = self.tokenizer(
            input_text, 
            max_length=256, 
            padding="max_length", 
            truncation=True, 
            return_tensors="pt"
        )
        
        # Move to GPU if available
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
        
        # Generate
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                **self.generation_kwargs
            )
        
        # Decode
        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=False)
        return generated_text
        
    def test_allergen_combinations(self):
        """Test various allergen combinations."""
        test_combinations = [
            "fish, cheese, pasta, garlic, olive oil",
            "shrimp, milk, pasta, garlic, butter",
            "tuna, yogurt, onions, celery, mayonnaise",
            "salmon, curd, lemon, dill, rice",
            "chicken, rice, vegetables, soy sauce",  # Non-allergenic control
            "shellfish, mango, coconut, lime, ginger",
            "peanuts, wheat flour, sugar, eggs, butter",
            "strawberry, chocolate, cream, sugar, vanilla"
        ]
        
        results = []
        
        print("\n======= TESTING ALLERGEN COMBINATIONS =======\n")
        
        for ingredients in test_combinations:
            print(f"Testing ingredients: {ingredients}")
            
            # Generate a recipe
            recipe_text = self.generate_recipe(ingredients)
            
            # Post-process the recipe
            processed_recipes = self.processor.postprocess_text([recipe_text])
            recipe = processed_recipes[0]
            
            # Calculate quality score
            quality_score = self.processor.evaluate_recipe_quality(recipe)
            
            # Format the recipe for display
            formatted_recipe = self.processor.format_for_display(recipe)
            
            results.append({
                "ingredients": ingredients,
                "recipe": recipe,
                "formatted_text": formatted_recipe,
                "quality_score": quality_score,
                "allergen_warnings": recipe.get("allergen_warnings", [])
            })
            
            # Print the result
            print(f"\n{formatted_recipe}")
            print(f"Quality Score: {quality_score}/100")
            print("-" * 50)
            
        # Summary of allergen detections
        print("\n======= ALLERGEN DETECTION SUMMARY =======\n")
        for result in results:
            warnings = result.get("allergen_warnings", [])
            print(f"Ingredients: {result['ingredients']}")
            print(f"Allergen Warnings: {len(warnings)}")
            for warning in warnings:
                print(f"  - {warning['ingredients'][0]} + {warning['ingredients'][1]}: {warning['reason']}")
            print()
            
        # Count safe vs. unsafe recipes
        safe_recipes = [r for r in results if not r.get("allergen_warnings")]
        print(f"Total Recipes: {len(results)}")
        print(f"Safe Recipes: {len(safe_recipes)}")
        print(f"Unsafe Recipes: {len(results) - len(safe_recipes)}")
        
        return results

# Recipe Post-processor definition
class RecipePostProcessor:
    def __init__(self, tokenizer):
        """
        Initialize post-processor with tokenizer for handling special tokens.
        Includes allergen detection and additional recipe processing.
        """
        self.tokenizer = tokenizer
        self.special_tokens = tokenizer.all_special_tokens
        self.tokens_map = {
            "<sep>": "--",
            "<section>": "\n"
        }
        
        # Allergen combinations to avoid
        self.allergen_combinations = [
            # Format: (ingredient1, ingredient2, reason)
            ("fish", "dairy", "Fish and dairy combinations can cause digestive issues for many people"),
            ("fish", "yogurt", "Fish and yogurt may cause digestive issues"),
            ("fish", "milk", "Fish and milk can cause adverse reactions in some individuals"),
            ("fish", "curd", "Fish and curd combinations may cause allergic reactions"),
            ("fish", "cheese", "Fish and cheese can trigger food sensitivities"),
            ("tuna", "cheese", "Tuna and cheese can trigger food sensitivities"),
            ("salmon", "cheese", "Salmon and cheese can trigger food sensitivities"),
            ("shellfish", "dairy", "Shellfish and dairy combinations can trigger allergic reactions"),
            ("shellfish", "milk", "Shellfish and milk can trigger allergic reactions"),
            ("shrimp", "milk", "Shrimp and milk can trigger allergic reactions"),
            ("peanuts", "gluten", "Peanuts and gluten can cause severe reactions in some people"),
            ("peanuts", "wheat", "Peanuts and wheat can cause severe reactions in some people"),
            ("shellfish", "mango", "Shellfish and mango can cause histamine reactions"),
            ("strawberry", "chocolate", "Strawberry and chocolate may trigger migraine in sensitive individuals")
        ]
    
    def postprocess_text(self, generated_texts):
        """
        Post-process generated recipe texts:
        1. Remove special tokens except our custom ones
        2. Replace mapped tokens with their human-readable versions
        3. Format into structured recipes
        4. Check for allergen combinations
        """
        processed_recipes = []
        
        for text in generated_texts:
            # Remove special tokens except our custom ones
            for token in self.special_tokens:
                if token not in self.tokens_map:
                    text = text.replace(token, "")
            
            # Replace mapped tokens with readable versions
            for k, v in self.tokens_map.items():
                text = text.replace(k, v)
            
            # Format structured recipe
            formatted_recipe = self._format_recipe(text)
            
            # Check for allergen combinations
            allergen_warnings = self.check_allergens(formatted_recipe)
            formatted_recipe["allergen_warnings"] = allergen_warnings
            
            processed_recipes.append(formatted_recipe)
        
        return processed_recipes
    
    def _format_recipe(self, text):
        """Format recipe text into structured dictionary"""
        recipe_dict = {"title": "", "ingredients": [], "directions": [], "allergen_warnings": []}
        
        sections = text.split("\n")
        for section in sections:
            section = section.strip()
            
            if section.startswith("title:"):
                recipe_dict["title"] = section.replace("title:", "").strip().capitalize()
            
            elif section.startswith("ingredients:"):
                ingr_text = section.replace("ingredients:", "").strip()
                recipe_dict["ingredients"] = [
                    item.strip().capitalize() for item in ingr_text.split("--") if item.strip()
                ]
            
            elif section.startswith("directions:"):
                dir_text = section.replace("directions:", "").strip()
                recipe_dict["directions"] = [
                    item.strip().capitalize() for item in dir_text.split("--") if item.strip()
                ]
        
        return recipe_dict
    
    def check_allergens(self, recipe_dict):
        """Check for potentially problematic ingredient combinations"""
        warnings = []
        ingredients_lower = [ing.lower() for ing in recipe_dict["ingredients"]]
        
        for ing1, ing2, reason in self.allergen_combinations:
            # Check if both ingredients are present
            has_ing1 = any(ing1 in ingredient for ingredient in ingredients_lower)
            has_ing2 = any(ing2 in ingredient for ingredient in ingredients_lower)
            
            if has_ing1 and has_ing2:
                warnings.append({
                    "ingredients": (ing1, ing2),
                    "reason": reason
                })
        
        return warnings
    
    def format_for_display(self, recipe_dict):
        """Format recipe dictionary for display"""
        display_text = f"[TITLE]: {recipe_dict['title']}\n\n"
        
        display_text += "[INGREDIENTS]:\n"
        for i, ingredient in enumerate(recipe_dict['ingredients']):
            display_text += f"  - {i+1}: {ingredient}\n"
        
        display_text += "\n[DIRECTIONS]:\n"
        for i, step in enumerate(recipe_dict['directions']):
            display_text += f"  - {i+1}: {step}\n"
        
        # Add allergen warnings if present
        if recipe_dict.get("allergen_warnings", []):
            display_text += "\n[⚠️ ALLERGEN WARNINGS]:\n"
            for i, warning in enumerate(recipe_dict["allergen_warnings"]):
                display_text += f"  - Warning {i+1}: {warning['ingredients'][0]} and {warning['ingredients'][1]} - {warning['reason']}\n"
        
        return display_text
    
    def evaluate_recipe_quality(self, recipe_dict):
        """Evaluate recipe quality based on simple heuristics"""
        score = 0
        max_score = 100
        
        # Check title
        if recipe_dict["title"] and len(recipe_dict["title"]) > 3:
            score += 10
        
        # Check ingredients
        if recipe_dict["ingredients"]:
            score += min(len(recipe_dict["ingredients"]) * 5, 30)  # Up to 30 points for ingredients
            
            # Check for ingredient variety (rough estimate)
            unique_words = set()
            for ingredient in recipe_dict["ingredients"]:
                unique_words.update(ingredient.lower().split())
            score += min(len(unique_words) * 2, 20)  # Up to 20 points for variety
        
        # Check directions
        if recipe_dict["directions"]:
            score += min(len(recipe_dict["directions"]) * 5, 30)  # Up to 30 points for directions
            
            # Check for detailed instructions (rough estimate by length)
            total_length = sum(len(step) for step in recipe_dict["directions"])
            score += min(total_length // 50, 10)  # Up to 10 points for detail
        
        # Deduct points for allergen warnings
        allergen_warnings = recipe_dict.get("allergen_warnings", [])
        score -= len(allergen_warnings) * 15  # Deduct 15 points per allergen warning
        
        # Normalize to 100 (but don't go below 0)
        normalized_score = max(min(score, max_score), 0)
        
        return normalized_score

# Run the test
if __name__ == "__main__":
    tester = AllergenDetectionTester()
    results = tester.test_allergen_combinations()

Loading model and tokenizer from ./recipe_model...
Model not found, using t5-base for demonstration


model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Using GPU for inference
Tester initialized


Testing ingredients: fish, cheese, pasta, garlic, olive oil
Using mock recipe generation for demonstration

[TITLE]: Fish recipe

[INGREDIENTS]:
  - 1: Fish
  - 2: Cheese
  - 3: Pasta
  - 4: Garlic
  - 5: Olive oil

[DIRECTIONS]:
  - 1: Mix ingredients
  - 2: Cook
  - 3: Serve


Quality Score: 47/100
--------------------------------------------------
Testing ingredients: shrimp, milk, pasta, garlic, butter
Using mock recipe generation for demonstration

[TITLE]: Shrimp recipe

[INGREDIENTS]:
  - 1: Shrimp
  - 2: Milk
  - 3: Pasta
  - 4: Garlic
  - 5: Butter

[DIRECTIONS]:
  - 1: Mix ingredients
  - 2: Cook
  - 3: Serve


Quality Score: 45/100
--------------------------------------------------
Testing ingredients: tuna, yogurt, onions, celery, mayonnaise
Using mock recipe generation for demonstration

[TITLE]: Tuna recipe

[INGREDIENTS]:
  - 1: Tuna
  - 2: Yogurt
  - 3: Onions
  - 4: Celery
  - 5: Mayonnaise

[DIRECTIONS]:
  - 1: Mix ingredie

In [8]:
import os
import numpy as np
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    EarlyStoppingCallback,
    DataCollatorForSeq2Seq
)
from datasets import load_from_disk, DatasetDict # Import DatasetDict for type hint clarity
import time
import logging # Import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def fine_tune_recipe_model(
    base_model: str = "t5-base",
    output_dir: str = "./recipe_model_finetuned", # Default to safer output dir
    preprocessed_dataset_dir: str = "preprocessed_recipe_dataset",
    num_train_epochs: int = 1,
    learning_rate: float = 5e-5,
    weight_decay: float = 0.01,
    warmup_ratio: float = 0.1,
    fp16: bool = True,
    batch_size: int = 4
):
    """
    Fine-tune a T5 model for recipe generation using a preprocessed dataset.
    Includes fix for compute_metrics error.
    """
    logging.info(f"Starting fine-tuning process for {base_model}...")
    logging.info(f"--- Configuration ---")
    logging.info(f"Output Directory: {output_dir}")
    logging.info(f"Dataset Directory: {preprocessed_dataset_dir}")
    logging.info(f"Epochs: {num_train_epochs}")
    logging.info(f"Learning Rate: {learning_rate}")
    logging.info(f"Batch Size: {batch_size}")
    logging.info(f"Eval/Save Steps: 500") # Hardcoded based on previous request
    logging.info(f"FP16 Enabled: {fp16}")
    logging.info(f"---------------------")

    # --- 1. Load Preprocessed Dataset ---
    try:
        # Ensure the path exists
        if not os.path.isdir(preprocessed_dataset_dir):
             raise FileNotFoundError(f"Dataset directory not found: {preprocessed_dataset_dir}")

        processed_dataset = load_from_disk(preprocessed_dataset_dir)
        logging.info(f"Loaded preprocessed dataset: {processed_dataset}")

        # Validate dataset structure
        if not isinstance(processed_dataset, DatasetDict):
            raise TypeError(f"Expected loaded object to be a DatasetDict, but got {type(processed_dataset)}")
        if "train" not in processed_dataset:
             raise ValueError("Dataset missing 'train' split.")
        if "validation" not in processed_dataset:
             raise ValueError("Dataset missing 'validation' split.")
             # Test split is optional but recommended

        logging.info(f"Training examples: {len(processed_dataset['train'])}")
        logging.info(f"Validation examples: {len(processed_dataset['validation'])}")
        if "test" in processed_dataset:
            logging.info(f"Test examples: {len(processed_dataset['test'])}")

        # Example calculation for steps
        steps_per_epoch = len(processed_dataset['train']) // batch_size
        total_steps = steps_per_epoch * num_train_epochs
        logging.info(f"Estimated steps per epoch: {steps_per_epoch}")
        logging.info(f"Estimated total training steps: {total_steps}")

        # Verify necessary columns exist in a sample (important!)
        sample = processed_dataset["train"][0]
        required_columns = ['input_ids', 'attention_mask', 'labels']
        missing_columns = [col for col in required_columns if col not in sample]
        if missing_columns:
            raise ValueError(f"Dataset samples missing required columns: {missing_columns}. Found keys: {list(sample.keys())}")
        logging.info("Dataset structure validated successfully.")


    except Exception as e:
        logging.error(f"Error loading or validating preprocessed dataset from '{preprocessed_dataset_dir}': {e}", exc_info=True)
        return None, None

    # --- 2. Load Tokenizer and Model ---
    try:
        tokenizer = AutoTokenizer.from_pretrained(base_model)
        logging.info(f"Loaded tokenizer: {base_model}")

        # Add special tokens
        special_tokens = {"additional_special_tokens": ["<sep>", "<section>"]}
        num_added_tokens = tokenizer.add_special_tokens(special_tokens)
        if num_added_tokens > 0:
            logging.info(f"Added {num_added_tokens} special tokens: {special_tokens['additional_special_tokens']}")

        model = AutoModelForSeq2SeqLM.from_pretrained(base_model)
        logging.info(f"Loaded base model: {base_model}")

        # Resize embeddings for new tokens
        model.resize_token_embeddings(len(tokenizer))
        logging.info(f"Resized model token embeddings to size {len(tokenizer)}")

        # Check trainable parameters
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        if trainable_params == 0:
            logging.warning("Model has no trainable parameters!")
        else:
            logging.info(f"Model trainable parameters: {trainable_params:,}")

    except Exception as e:
        logging.error(f"Error loading model or tokenizer '{base_model}': {e}", exc_info=True)
        return None, None

    # --- 3. Training Arguments ---
    run_name = f"recipe-{base_model.split('/')[-1]}-{int(time.time())}" # More specific run name
    effective_eval_save_steps = 500 # Using the previously requested value

    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        # Strategies
        eval_strategy="steps",
        save_strategy="steps",
        logging_strategy="steps",
        # Steps
        eval_steps=effective_eval_save_steps,
        save_steps=effective_eval_save_steps,
        logging_steps=50, # Log loss more frequently
        # Hyperparameters
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2, # Can often use larger batch size for eval
        weight_decay=weight_decay,
        num_train_epochs=num_train_epochs,
        warmup_ratio=warmup_ratio,
        # Checkpointing and Best Model
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="loss", # Use validation loss to find the best model
        greater_is_better=False,
        # Performance
        fp16=fp16, # Use mixed precision if available/enabled
        optim="adamw_torch", # Recommended optimizer
        dataloader_num_workers=os.cpu_count() // 2 if os.cpu_count() else 2, # Adjust based on your system
        # Generation settings (for predict_with_generate)
        predict_with_generate=True,
        generation_max_length=512,
        generation_num_beams=4,
        # Reporting
        report_to=["tensorboard"],
        run_name=run_name,
    )

    # --- 4. Data Collator ---
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        padding="max_length", # Ensure all sequences in a batch have the same length
        max_length=512, # Max length for tokenization
        label_pad_token_id=tokenizer.pad_token_id # Ensure labels are padded correctly
    )

    # --- 5. Compute Metrics (CORRECTED) ---
    def compute_metrics(eval_preds):
        # The Trainer calculates and logs validation loss automatically.
        # Since metric_for_best_model='loss', this loss is used for
        # checkpointing the best model and for early stopping.
        # We don't need to calculate any additional metrics here,
        # so we return an empty dictionary.
        # If you wanted ROUGE/BLEU, you would decode preds/labels and compute them here.
        # predictions, labels = eval_preds
        # decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        # decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        # Add metric calculation here... e.g., rouge = compute_rouge(decoded_preds, decoded_labels)
        return {} # Return empty dict as loss is handled internally

    # --- 6. Initialize Trainer ---
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=processed_dataset["train"],
        eval_dataset=processed_dataset["validation"],
        data_collator=data_collator,
        tokenizer=tokenizer, # Pass tokenizer for saving correctly
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01)] # Stop if loss doesn't improve significantly
    )

    # --- 7. Train ---
    logging.info("Starting fine-tuning...")
    train_result = None
    try:
        train_result = trainer.train()
        logging.info("Training completed successfully!")

        # Log training metrics
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)

    except Exception as e:
        logging.error(f"An error occurred during training: {e}", exc_info=True)
        # The debug print from before is less useful now with logging/validation
        return None, None # Return None if training failed

    # --- 8. Save Final Model ---
    # save_model saves the tokenizer too if passed to Trainer constructor
    try:
        logging.info(f"Saving the best model to {output_dir}...")
        trainer.save_model(output_dir)
        # tokenizer.save_pretrained(output_dir) # Trainer should handle this
        logging.info(f"Best model saved successfully.")
    except Exception as e:
        logging.error(f"Could not save the final model: {e}", exc_info=True)


    # --- 9. Evaluate on Test Set ---
    if "test" in processed_dataset:
        logging.info("Evaluating the best model on the test set...")
        try:
            test_results = trainer.evaluate(eval_dataset=processed_dataset["test"])
            logging.info(f"Test set evaluation results: {test_results}")
            trainer.log_metrics("test", test_results)
            trainer.save_metrics("test", test_results)
        except Exception as e:
            logging.error(f"Could not evaluate on the test set: {e}", exc_info=True)
    else:
        logging.warning("No 'test' split found in the dataset. Skipping final test evaluation.")

    logging.info("Fine-tuning script finished.")
    # Return the trained model and tokenizer in memory (may use significant RAM)
    # It's often better to load the saved model later using AutoModel...from_pretrained(output_dir)
    return model, tokenizer

# --- Main Execution Block ---
if __name__ == "__main__":
    print("============================================")
    print(" Recipe Model Fine-Tuning Script ")
    print("============================================")
    print("Ensure 'accelerate' library is installed for optimized training and FP16: pip install accelerate -U")
    print("Ensure your preprocessed dataset exists at './preprocessed_recipe_dataset/'")
    print("--------------------------------------------")

    # Determine FP16 availability
    use_fp16 = torch.cuda.is_available()
    if use_fp16:
        print("CUDA detected. FP16 training will be enabled.")
    else:
        print("CUDA not detected. FP16 training will be disabled (runs on CPU or MPS).")

    # Define dataset and output paths
    dataset_path = "preprocessed_recipe_dataset"
    model_output_path = "./recipe_model_finetuned" # Separate output dir

    # --- Call the Fine-Tuning Function ---
    model, tokenizer = fine_tune_recipe_model(
         base_model="t5-base",
         preprocessed_dataset_dir=dataset_path,
         output_dir=model_output_path,
         num_train_epochs=1,
         batch_size=4, # Keep batch size small if memory is limited
         fp16=use_fp16
     )

    # --- Report Outcome ---
    print("--------------------------------------------")
    if model and tokenizer:
        print(f"✅ Fine-tuning process finished successfully.")
        print(f"   Model and tokenizer returned (in memory).")
        print(f"   Best model saved to: {model_output_path}")
    else:
        print(f"❌ Fine-tuning process did not complete successfully.")
        print(f"   Please check the log messages above for errors.")
    print("============================================")

 Recipe Model Fine-Tuning Script 
Ensure 'accelerate' library is installed for optimized training and FP16: pip install accelerate -U
Ensure your preprocessed dataset exists at './preprocessed_recipe_dataset/'
--------------------------------------------
CUDA detected. FP16 training will be enabled.


  trainer = Seq2SeqTrainer(
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss,Validation Loss
500,0.495,0.410472
1000,0.364,0.290418
1500,0.3161,0.256346
2000,0.2966,0.248394


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


***** train metrics *****
  epoch                    =        1.0
  total_flos               =  4537089GF
  train_loss               =     0.8938
  train_runtime            = 1:01:52.17
  train_samples_per_second =      2.155
  train_steps_per_second   =      0.539


***** test metrics *****
  epoch                   =        1.0
  eval_loss               =     0.2514
  eval_runtime            = 0:09:25.73
  eval_samples_per_second =      1.768
  eval_steps_per_second   =      0.221
--------------------------------------------
✅ Fine-tuning process finished successfully.
   Model and tokenizer returned (in memory).
   Best model saved to: ./recipe_model_finetuned


In [9]:
# First, make sure your Google Drive is mounted
from kaggle_web_client import KaggleWebClient
from kaggle_datasets import KaggleDatasets
import os
import zipfile
import shutil

# Path to your fine-tuned model in Kaggle
model_path = "./recipe_model_finetuned"  # Update this to your model path

# Create a zip file of your model
zip_path = "./recipe_model_finetuned.zip"
print(f"Creating zip file at {zip_path}...")

with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    for root, dirs, files in os.walk(model_path):
        for file in files:
            file_path = os.path.join(root, file)
            arcname = os.path.relpath(file_path, start=model_path)
            print(f"Adding {arcname} to zip...")
            zipf.write(file_path, arcname=arcname)

print(f"Zip file created successfully at {zip_path}")

# Path in Google Drive where you want to save
drive_path = "/kaggle/drive/MyDrive/RecipeGenie"

# Create the directory if it doesn't exist
os.makedirs(drive_path, exist_ok=True)

# Copy the zip file to Google Drive
shutil.copy(zip_path, os.path.join(drive_path, "recipe_model_finetuned.zip"))

print(f"Model zip file has been saved to Google Drive at {drive_path}")

Creating zip file at ./recipe_model_finetuned.zip...
Adding generation_config.json to zip...
Adding training_args.bin to zip...
Adding train_results.json to zip...
Adding special_tokens_map.json to zip...
Adding config.json to zip...
Adding tokenizer.json to zip...
Adding spiece.model to zip...
Adding model.safetensors to zip...
Adding all_results.json to zip...
Adding test_results.json to zip...
Adding added_tokens.json to zip...
Adding tokenizer_config.json to zip...
Adding runs/Apr10_10-33-49_ae2120e9f195/events.out.tfevents.1744285510.ae2120e9f195.31.1 to zip...
Adding runs/Apr10_10-33-49_ae2120e9f195/events.out.tfevents.1744281230.ae2120e9f195.31.0 to zip...
Adding checkpoint-1500/generation_config.json to zip...
Adding checkpoint-1500/training_args.bin to zip...
Adding checkpoint-1500/scheduler.pt to zip...
Adding checkpoint-1500/special_tokens_map.json to zip...
Adding checkpoint-1500/config.json to zip...
Adding checkpoint-1500/tokenizer.json to zip...
Adding checkpoint-1500/sp

In [10]:
from IPython.display import FileLink, display

# Zip your model
import zipfile
import os

model_path = "./recipe_model_finetuned"
zip_path = "./recipe_model_finetuned.zip"

with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    for root, dirs, files in os.walk(model_path):
        for file in files:
            file_path = os.path.join(root, file)
            arcname = os.path.relpath(file_path, start=model_path)
            zipf.write(file_path, arcname=arcname)

print(f"Created zip file at {zip_path}")

# Generate download link
print("Click the link below to download the model zip file:")
display(FileLink(zip_path))

Created zip file at ./recipe_model_finetuned.zip
Click the link below to download the model zip file:


In [10]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import time
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
import os

class RecipePostProcessor:
    def __init__(self, tokenizer):
        """
        Initialize post-processor with tokenizer for handling special tokens.
        Includes allergen detection and additional recipe processing.
        """
        self.tokenizer = tokenizer
        self.special_tokens = tokenizer.all_special_tokens
        self.tokens_map = {
            "<sep>": "--",
            "<section>": "\n"
        }
        
        # Allergen combinations to avoid
        self.allergen_combinations = [
            # Format: (ingredient1, ingredient2, reason)
            ("fish", "dairy", "Fish and dairy combinations can cause digestive issues for many people"),
            ("fish", "yogurt", "Fish and yogurt may cause digestive issues"),
            ("fish", "milk", "Fish and milk can cause adverse reactions in some individuals"),
            ("fish", "curd", "Fish and curd combinations may cause allergic reactions"),
            ("fish", "cheese", "Fish and cheese can trigger food sensitivities"),
            ("tuna", "cheese", "Tuna and cheese can trigger food sensitivities"),
            ("salmon", "cheese", "Salmon and cheese can trigger food sensitivities"),
            ("shellfish", "dairy", "Shellfish and dairy combinations can trigger allergic reactions"),
            ("shellfish", "milk", "Shellfish and milk can trigger allergic reactions"),
            ("shrimp", "milk", "Shrimp and milk can trigger allergic reactions"),
            ("peanuts", "gluten", "Peanuts and gluten can cause severe reactions in some people"),
            ("peanuts", "wheat", "Peanuts and wheat can cause severe reactions in some people"),
            ("shellfish", "mango", "Shellfish and mango can cause histamine reactions"),
            ("strawberry", "chocolate", "Strawberry and chocolate may trigger migraine in sensitive individuals")
        ]
    
    def postprocess_text(self, text):
        """
        Post-process generated recipe text:
        1. Remove special tokens except our custom ones
        2. Replace mapped tokens with their human-readable versions
        3. Format into structured recipe
        4. Check for allergen combinations
        """
        # Remove special tokens except our custom ones
        for token in self.special_tokens:
            if token not in self.tokens_map:
                text = text.replace(token, "")
        
        # Replace mapped tokens with readable versions
        for k, v in self.tokens_map.items():
            text = text.replace(k, v)
        
        # Format structured recipe
        recipe_dict = self._format_recipe(text)
        
        # Check for allergen combinations
        allergen_warnings = self.check_allergens(recipe_dict)
        recipe_dict["allergen_warnings"] = allergen_warnings
        
        return recipe_dict
    
    def _format_recipe(self, text):
        """Format recipe text into structured dictionary"""
        recipe_dict = {"title": "", "ingredients": [], "directions": [], "allergen_warnings": []}
        
        sections = text.split("\n")
        for section in sections:
            section = section.strip()
            
            if section.startswith("title:"):
                recipe_dict["title"] = section.replace("title:", "").strip().capitalize()
            
            elif section.startswith("ingredients:"):
                ingr_text = section.replace("ingredients:", "").strip()
                recipe_dict["ingredients"] = [
                    item.strip().capitalize() for item in ingr_text.split("--") if item.strip()
                ]
            
            elif section.startswith("directions:"):
                dir_text = section.replace("directions:", "").strip()
                recipe_dict["directions"] = [
                    item.strip().capitalize() for item in dir_text.split("--") if item.strip()
                ]
        
        return recipe_dict
    
    def check_allergens(self, recipe_dict):
        """Check for potentially problematic ingredient combinations"""
        warnings = []
        ingredients_lower = [ing.lower() for ing in recipe_dict["ingredients"]]
        
        for ing1, ing2, reason in self.allergen_combinations:
            # Check if both ingredients are present
            has_ing1 = any(ing1 in ingredient for ingredient in ingredients_lower)
            has_ing2 = any(ing2 in ingredient for ingredient in ingredients_lower)
            
            if has_ing1 and has_ing2:
                warnings.append({
                    "ingredients": (ing1, ing2),
                    "reason": reason
                })
        
        return warnings
    
    def format_for_display(self, recipe_dict):
        """Format recipe dictionary for display"""
        display_text = f"[TITLE]: {recipe_dict['title']}\n\n"
        
        display_text += "[INGREDIENTS]:\n"
        for i, ingredient in enumerate(recipe_dict['ingredients']):
            display_text += f"  - {i+1}: {ingredient}\n"
        
        display_text += "\n[DIRECTIONS]:\n"
        for i, step in enumerate(recipe_dict['directions']):
            display_text += f"  - {i+1}: {step}\n"
        
        # Add allergen warnings if present
        if recipe_dict.get("allergen_warnings", []):
            display_text += "\n[⚠️ ALLERGEN WARNINGS]:\n"
            for i, warning in enumerate(recipe_dict["allergen_warnings"]):
                display_text += f"  - Warning {i+1}: {warning['ingredients'][0]} and {warning['ingredients'][1]} - {warning['reason']}\n"
        
        return display_text
    
    def evaluate_recipe_quality(self, recipe_dict):
        """Evaluate recipe quality based on simple heuristics"""
        score = 0
        max_score = 100
        
        # Check title
        if recipe_dict["title"] and len(recipe_dict["title"]) > 3:
            score += 10
        
        # Check ingredients
        if recipe_dict["ingredients"]:
            score += min(len(recipe_dict["ingredients"]) * 5, 30)  # Up to 30 points for ingredients
            
            # Check for ingredient variety (rough estimate)
            unique_words = set()
            for ingredient in recipe_dict["ingredients"]:
                unique_words.update(ingredient.lower().split())
            score += min(len(unique_words) * 2, 20)  # Up to 20 points for variety
        
        # Check directions
        if recipe_dict["directions"]:
            score += min(len(recipe_dict["directions"]) * 5, 30)  # Up to 30 points for directions
            
            # Check for detailed instructions (rough estimate by length)
            total_length = sum(len(step) for step in recipe_dict["directions"])
            score += min(total_length // 50, 10)  # Up to 10 points for detail
        
        # Deduct points for allergen warnings
        allergen_warnings = recipe_dict.get("allergen_warnings", [])
        score -= len(allergen_warnings) * 15  # Deduct 15 points per allergen warning
        
        # Normalize to 100 (but don't go below 0)
        normalized_score = max(min(score, max_score), 0)
        
        return normalized_score


class RecipeModelTester:
    def __init__(self, model_path="./recipe_model_finetuned"):
        """Initialize the model tester with the path to the fine-tuned model."""
        print(f"Loading model and tokenizer from {model_path}...")
        
        # Try to load the fine-tuned model, or fall back to the pre-trained model
        try:
            if os.path.exists(model_path):
                self.tokenizer = AutoTokenizer.from_pretrained(model_path)
                self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
                print(f"Successfully loaded fine-tuned model from {model_path}")
            else:
                # Fall back to original pre-trained model
                print(f"Fine-tuned model not found. Falling back to pre-trained model.")
                self.tokenizer = AutoTokenizer.from_pretrained("flax-community/t5-recipe-generation")
                self.model = AutoModelForSeq2SeqLM.from_pretrained("flax-community/t5-recipe-generation")
                print("Successfully loaded pre-trained 'flax-community/t5-recipe-generation' model")
        except Exception as e:
            print(f"Error loading model: {e}")
            print("Falling back to t5-base model as a last resort")
            self.tokenizer = AutoTokenizer.from_pretrained("t5-base")
            self.model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
            print("Loaded t5-base model")
        
        # Add special tokens if needed
        if "<sep>" not in self.tokenizer.get_vocab():
            special_tokens = {"additional_special_tokens": ["<sep>", "<section>"]}
            self.tokenizer.add_special_tokens(special_tokens)
            self.model.resize_token_embeddings(len(self.tokenizer))
            print("Added special tokens to tokenizer")
        
        # Optimize model for inference
        self.model = self.model.eval()
        if torch.cuda.is_available():
            print("Using GPU for inference")
            self.model = self.model.cuda()
            try:
                self.model = self.model.half()  # Use FP16 for faster inference
                print("Using half precision for faster inference")
            except:
                print("Half precision not supported, using full precision instead")
        else:
            print("Using CPU for inference")
        
        # Default generation parameters
        self.generation_kwargs = {
            "max_length": 512,
            "min_length": 64,
            "no_repeat_ngram_size": 3,
            "num_beams": 4,
            "early_stopping": True,
            "length_penalty": 1.2,
            "do_sample": True,
            "temperature": 0.8,
            "top_k": 50,
            "top_p": 0.95
        }
        
        # Initialize post-processor
        self.postprocessor = RecipePostProcessor(self.tokenizer)
        
        print("Model initialized and ready for testing")

    def generate_recipe(self, ingredients, **generation_params):
        """Generate a recipe from a list of ingredients with optional parameters."""
        # Format input
        if isinstance(ingredients, list):
            ingredients = ", ".join(ingredients)
        
        input_text = f"items: {ingredients}"
        
        # Tokenize
        inputs = self.tokenizer(
            input_text, 
            max_length=256, 
            padding="max_length", 
            truncation=True, 
            return_tensors="pt"
        )
        
        # Move to GPU if available
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
        
        # Update generation parameters if provided
        gen_kwargs = self.generation_kwargs.copy()
        gen_kwargs.update(generation_params)
        
        # Generate
        start_time = time.time()
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                **gen_kwargs
            )
        generation_time = time.time() - start_time
        
        # Decode and post-process
        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=False)
        recipe_dict = self.postprocessor.postprocess_text(generated_text)
        formatted_text = self.postprocessor.format_for_display(recipe_dict)
        
        # Evaluate quality
        quality_score = self.postprocessor.evaluate_recipe_quality(recipe_dict)
        
        return {
            "raw_text": generated_text,
            "recipe_dict": recipe_dict,
            "formatted_text": formatted_text,
            "generation_time": generation_time,
            "quality_score": quality_score,
            "allergen_warnings": recipe_dict.get("allergen_warnings", [])
        }
    
    def test_allergen_combinations(self):
        """Test various allergen combinations."""
        test_combinations = [
            "fish, cheese, pasta, garlic, olive oil",
            "shrimp, milk, pasta, garlic, butter",
            "tuna, yogurt, onions, celery, mayonnaise",
            "salmon, curd, lemon, dill, rice",
            "chicken, rice, vegetables, soy sauce",  # Non-allergenic control
            "shellfish, mango, coconut, lime, ginger",
            "peanuts, wheat flour, sugar, eggs, butter",
            "strawberry, chocolate, cream, sugar, vanilla"
        ]
        
        results = []
        
        print("\n======= TESTING ALLERGEN COMBINATIONS =======\n")
        
        for ingredients in test_combinations:
            print(f"Testing ingredients: {ingredients}")
            
            # Generate a recipe
            result = self.generate_recipe(ingredients)
            results.append(result)
            
            # Print the result
            print(f"\n{result['formatted_text']}")
            print(f"Quality Score: {result['quality_score']}/100")
            print("-" * 50)
            
        # Summary of allergen detections
        print("\n======= ALLERGEN DETECTION SUMMARY =======\n")
        for i, result in enumerate(results):
            warnings = result.get("allergen_warnings", [])
            print(f"Recipe {i+1} - Ingredients: {test_combinations[i]}")
            print(f"Allergen Warnings: {len(warnings)}")
            for warning in warnings:
                print(f"  - {warning['ingredients'][0]} + {warning['ingredients'][1]}: {warning['reason']}")
            print()
            
        # Count safe vs. unsafe recipes
        safe_recipes = [r for r in results if not r.get("allergen_warnings")]
        print(f"Total Recipes: {len(results)}")
        print(f"Safe Recipes: {len(safe_recipes)}")
        print(f"Unsafe Recipes: {len(results) - len(safe_recipes)}")
        
        return results
    
    def run_benchmark_test(self, test_samples, batch_size=4):
        """Run a benchmark test on multiple ingredient lists."""
        print(f"Running benchmark with {len(test_samples)} test samples...")
        results = []
        quality_scores = []
        generation_times = []
        allergen_counts = []
        
        # Process test samples
        for i in tqdm(range(0, len(test_samples), batch_size)):
            batch = test_samples[i:i+batch_size]
            batch_results = []
            
            for ingredients in batch:
                result = self.generate_recipe(ingredients)
                batch_results.append(result)
                
                quality_scores.append(result["quality_score"])
                generation_times.append(result["generation_time"])
                allergen_counts.append(len(result.get("allergen_warnings", [])))
            
            results.extend(batch_results)
        
        # Calculate statistics
        avg_quality = np.mean(quality_scores)
        avg_time = np.mean(generation_times)
        avg_allergens = np.mean(allergen_counts)
        
        print(f"Benchmark complete: {len(results)} recipes generated")
        print(f"Average generation time: {avg_time:.3f} seconds")
        print(f"Average quality score: {avg_quality:.1f}/100")
        print(f"Average allergen warnings per recipe: {avg_allergens:.2f}")
        
        # Plot quality distribution
        plt.figure(figsize=(10, 5))
        plt.hist(quality_scores, bins=10, alpha=0.7)
        plt.title('Recipe Quality Score Distribution')
        plt.xlabel('Quality Score')
        plt.ylabel('Count')
        plt.savefig('quality_distribution.png')
        print(f"Quality distribution plot saved to 'quality_distribution.png'")
        
        # Plot allergen warnings
        plt.figure(figsize=(10, 5))
        plt.hist(allergen_counts, bins=max(allergen_counts)+1, alpha=0.7)
        plt.title('Allergen Warnings Distribution')
        plt.xlabel('Number of Warnings')
        plt.ylabel('Count')
        plt.savefig('allergen_distribution.png')
        print(f"Allergen distribution plot saved to 'allergen_distribution.png'")
        
        return results, quality_scores, generation_times, allergen_counts

if __name__ == "__main__":
    # Test the model
    tester = RecipeModelTester()
    
    # Test with individual ingredients including potentially allergenic combinations
    test_samples = [
        "chicken, rice, garlic, onion, bell pepper",
        "beef, potatoes, carrots, onion, garlic",
        "fish, cheese, lemon, pasta, herbs",  # Potentially allergenic
        "shrimp, milk, garlic, butter, pasta",  # Potentially allergenic
        "salmon, curd, dill, rice, lemon",  # Potentially allergenic
        "tofu, vegetables, soy sauce, ginger, garlic",
        "chocolate, strawberry, cream, vanilla, sugar"  # Potentially allergenic
    ]
    
    # Generate a single recipe
    print("\nGenerating a sample recipe:")
    result = tester.generate_recipe(test_samples[0])
    print(result["formatted_text"])
    print(f"Generation time: {result['generation_time']:.3f} seconds")
    print(f"Quality score: {result['quality_score']}/100\n")
    
    # Run allergen tests
    print("\nTesting allergen detection:")
    tester.test_allergen_combinations()

Loading model and tokenizer from ./recipe_model_finetuned...
Successfully loaded fine-tuned model from ./recipe_model_finetuned
Using GPU for inference
Using half precision for faster inference
Model initialized and ready for testing

Generating a sample recipe:
[TITLE]: Grilled onion stir-fry

[INGREDIENTS]:
  - 1: Chicken
  - 2: Rice
  - 3: Garlic
  - 4: Onion
  - 5: Bell pepper

[DIRECTIONS]:
  - 1: Prepare all the ingredients
  - 2: Heat pan to 388 degrees
  - 3: Mix chicken and rice together
  - 4: Add garlic, onion and cook for 26 minutes
  - 5: Stir in bell pepper
  - 6: Season with salt and pepper to taste
  - 7: Serve hot

Generation time: 2.053 seconds
Quality score: 80/100


Testing allergen detection:


Testing ingredients: fish, cheese, pasta, garlic, olive oil

[TITLE]: Grilled olive oil pasta

[INGREDIENTS]:
  - 1: Fish
  - 2: Cheese
  - 3: Pasta
  - 4: Garlic
  - 5: Olive oil

[DIRECTIONS]:
  - 1: Prepare all the ingredients
  - 2: Heat pot to 337 degrees
  - 3: Mix fis

In [11]:
def test_allergen_detection(self):
    """Conduct explicit tests for allergen detection functionality"""
    print("\n======= EXPLICIT ALLERGEN DETECTION TESTS =======\n")
    
    # Test cases with expected outcomes
    test_cases = [
        {
            "ingredients": "fish, cheese, pasta, olive oil",
            "expect_warning": True,
            "expected_allergens": [("fish", "cheese")]
        },
        {
            "ingredients": "shrimp, milk, butter, pasta",
            "expect_warning": True,
            "expected_allergens": [("shrimp", "milk"), ("shellfish", "milk"), ("shellfish", "dairy")]
        },
        {
            "ingredients": "chicken, rice, vegetables, soy sauce",
            "expect_warning": False,
            "expected_allergens": []
        },
        {
            "ingredients": "salmon, curd, lemon, rice",
            "expect_warning": True,
            "expected_allergens": [("fish", "curd"), ("salmon", "curd")]
        },
        {
            "ingredients": "tuna, yogurt, celery, onion",
            "expect_warning": True,
            "expected_allergens": [("fish", "yogurt"), ("tuna", "yogurt")]
        },
        {
            "ingredients": "strawberry, chocolate, sugar, flour",
            "expect_warning": True,
            "expected_allergens": [("strawberry", "chocolate")]
        },
        {
            "ingredients": "peanuts, wheat flour, eggs, sugar",
            "expect_warning": True,
            "expected_allergens": [("peanuts", "wheat")]
        }
    ]
    
    # Results tracking
    test_results = []
    passed = 0
    failed = 0
    
    # Run each test case
    for i, test_case in enumerate(test_cases):
        print(f"\nTest {i+1}: {test_case['ingredients']}")
        print(f"Expected warnings: {'Yes' if test_case['expect_warning'] else 'No'}")
        
        # Generate recipe
        result = self.generate_recipe(test_case['ingredients'])
        warnings = result.get("allergen_warnings", [])
        
        # Extract detected allergen pairs
        detected_pairs = [(w['ingredients'][0], w['ingredients'][1]) for w in warnings]
        
        # Check if warnings match expectations
        has_warnings = len(warnings) > 0
        warning_status = "PASS" if has_warnings == test_case['expect_warning'] else "FAIL"
        
        # Check for specific allergen pairs - at least one should match
        allergen_match = False
        if test_case['expect_warning']:
            for expected_pair in test_case['expected_allergens']:
                if any(all(item in detected for item in expected_pair) or
                       all(detected[i] == expected_pair[i] for i in range(2))
                       for detected in detected_pairs):
                    allergen_match = True
                    break
        else:
            allergen_match = len(detected_pairs) == 0
            
        allergen_status = "PASS" if allergen_match else "FAIL"
        
        # Overall test result
        test_passed = warning_status == "PASS" and allergen_status == "PASS"
        if test_passed:
            passed += 1
        else:
            failed += 1
            
        test_results.append({
            "test_case": test_case,
            "warnings": warnings,
            "warning_status": warning_status,
            "allergen_status": allergen_status,
            "overall": "PASS" if test_passed else "FAIL"
        })
        
        # Print results for this test
        print(f"Warning detection: {warning_status}")
        print(f"Allergen match: {allergen_status}")
        print(f"Overall test result: {'✅ PASS' if test_passed else '❌ FAIL'}")
        
        # Print the detected allergens
        if warnings:
            print("Detected allergens:")
            for j, warning in enumerate(warnings):
                print(f"  - {j+1}: {warning['ingredients'][0]} + {warning['ingredients'][1]}")
                print(f"    Reason: {warning['reason']}")
        else:
            print("No allergens detected")
            
        # Show recipe excerpt
        print("\nRecipe excerpt:")
        title = result["recipe_dict"]["title"]
        ing_count = len(result["recipe_dict"]["ingredients"])
        print(f"[TITLE]: {title}")
        print(f"[INGREDIENTS COUNT]: {ing_count}")
        print("-" * 50)
    
    # Print summary
    print("\n======= ALLERGEN DETECTION TEST SUMMARY =======")
    print(f"Total tests: {len(test_cases)}")
    print(f"Passed: {passed} ({passed/len(test_cases)*100:.1f}%)")
    print(f"Failed: {failed} ({failed/len(test_cases)*100:.1f}%)")
    
    # Print failure details if any
    if failed > 0:
        print("\nFailed tests:")
        for i, result in enumerate(test_results):
            if result["overall"] == "FAIL":
                test_case = result["test_case"]
                print(f"- Test {i+1}: {test_case['ingredients']}")
                print(f"  Expected warnings: {'Yes' if test_case['expect_warning'] else 'No'}")
                print(f"  Detected warnings: {'Yes' if result['warnings'] else 'No'}")
                print(f"  Warning status: {result['warning_status']}")
                print(f"  Allergen status: {result['allergen_status']}")
    
    return test_results

In [12]:
import os
import numpy as np
from transformers import FlaxAutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from pathlib import Path

def convert_flax_to_pytorch():
    print("Loading Flax model...")
    MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
    flax_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
    
    print("Converting to PyTorch...")
    # Convert to PyTorch
    pt_model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_NAME_OR_PATH, 
        from_flax=True,
    )
    
    # Save the PyTorch model
    output_dir = Path("./optimized_recipe_model_pt")
    output_dir.mkdir(exist_ok=True)
    pt_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"PyTorch model saved to {output_dir}")
    
    return pt_model, tokenizer

def optimize_model(model):
    print("Optimizing model...")
    # Move to half precision to reduce memory usage and increase speed
    model = model.half()
    
    # Enable torch inference optimizations
    torch.set_grad_enabled(False)
    if torch.cuda.is_available():
        model = model.cuda()
    
    return model

def create_optimized_pipeline():
    # Load or create the model
    try:
        model_path = "./optimized_recipe_model_pt"
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    except:
        model, tokenizer = convert_flax_to_pytorch()
    
    # Optimize the model
    model = optimize_model(model)
    
    # Optimized generation parameters
    generation_kwargs = {
        "max_length": 512,
        "min_length": 64,
        "no_repeat_ngram_size": 3,
        "num_beams": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
        "do_sample": True,  # Enable sampling
        "temperature": 0.8,  # Now properly used with do_sample=True
        "top_k": 50,        # Add top_k sampling
        "top_p": 0.95       # Add nucleus sampling
    }
    
    # Define tokens_map for post-processing
    special_tokens = tokenizer.all_special_tokens
    tokens_map = {
        "<sep>": "--",
        "<section>": "\n"
    }
    
    def generate_recipes(ingredient_lists, batch_size=4):
        """
        Generate recipes from multiple ingredient lists efficiently
        """
        all_recipes = []
        
        # Process in batches for efficiency
        for i in range(0, len(ingredient_lists), batch_size):
            batch = ingredient_lists[i:i+batch_size]
            
            # Prepare inputs
            prefix = "items: "
            inputs = [prefix + inp for inp in batch]
            encoded_inputs = tokenizer(
                inputs,
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )
            
            # Move inputs to same device as model
            if torch.cuda.is_available():
                encoded_inputs = {k: v.cuda() for k, v in encoded_inputs.items()}
            
            # Generate with optimized parameters
            with torch.no_grad():
                output_ids = model.generate(
                    **encoded_inputs,
                    **generation_kwargs
                )
            
            # Decode and post-process
            generated_recipes = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
            
            # Post-process recipes
            for text in generated_recipes:
                # Skip special tokens
                for token in special_tokens:
                    text = text.replace(token, "")
                
                # Replace mapped tokens
                for k, v in tokens_map.items():
                    text = text.replace(k, v)
                
                all_recipes.append(text)
        
        return all_recipes

    return generate_recipes

def optimize_recipe_model():
    """Main function to optimize the T5 recipe model"""
    print("Starting model optimization process...")
    
    # Create optimized pipeline
    generate_recipes = create_optimized_pipeline()
    
    # Test the optimized model
    test_ingredients = [
        "macaroni, butter, salt, bacon, milk, flour, pepper, cream corn",
        "provolone cheese, bacon, bread, ginger"
    ]
    
    print("\nTesting optimized model with sample ingredients...")
    generated_recipes = generate_recipes(test_ingredients)
    
    for i, text in enumerate(generated_recipes):
        print(f"\nRecipe {i+1} from ingredients: {test_ingredients[i]}")
        sections = text.split("\n")
        for section in sections:
            section = section.strip()
            if section.startswith("title:"):
                section = section.replace("title:", "")
                headline = "TITLE"
            elif section.startswith("ingredients:"):
                section = section.replace("ingredients:", "")
                headline = "INGREDIENTS"
            elif section.startswith("directions:"):
                section = section.replace("directions:", "")
                headline = "DIRECTIONS"

            if headline == "TITLE":
                print(f"[{headline}]: {section.strip().capitalize()}")
            else:
                section_info = [f"  - {i+1}: {info.strip().capitalize()}" for i, info in enumerate(section.split("--"))]
                print(f"[{headline}]:")
                print("\n".join(section_info))

        print("-" * 50)
    
    print("\nModel optimization complete!")
    print("The optimized model is ready for integration with the Flask/FastAPI backend.")
    
    return generate_recipes

if __name__ == "__main__":
    optimize_recipe_model()

Starting model optimization process...
Loading Flax model...


flax_model.msgpack:   0%|          | 0.00/892M [00:00<?, ?B/s]

Converting to PyTorch...


  pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
All Flax model weights were used when initializing T5ForConditionalGeneration.

Some weights of T5ForConditionalGeneration were not initialized from the Flax model and are newly initialized: ['encoder.embed_tokens.weight', 'lm_head.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


PyTorch model saved to optimized_recipe_model_pt
Optimizing model...

Testing optimized model with sample ingredients...

Recipe 1 from ingredients: macaroni, butter, salt, bacon, milk, flour, pepper, cream corn
[TITLE]: Macaroni and corn casserole
[INGREDIENTS]:
  - 1: 1 lb. box elbow or elbow pasta
  - 2: 1/4 c. butter
  - 3: 1 tsp. salt
  - 4: 6 slices bacon, cooked and crumbled
  - 5: 2 1/2 c milk
  - 6: 2 tbsp flour
  - 7: 1/8 t. pepper
  - 8: 1 can cream corn
[DIRECTIONS]:
  - 1: Cook pasta according to package directions.
  - 2: Drain.
  - 3: In a saucepan, melt butter.
  - 4: Stir in flour, salt and pepper.
  - 5: Gradually stir in milk.
  - 6: Cook and stir until thickened and bubbly.
  - 7: Add corn and bacon.
  - 8: Pour into a greased 2 quart casserole.
  - 9: Bake at 350 degrees for 30 minutes.
--------------------------------------------------

Recipe 2 from ingredients: provolone cheese, bacon, bread, ginger
[TITLE]: Grilled provolone and bacon sandwich
[INGREDIENTS]:
  

In [13]:
import os
import numpy as np
from transformers import FlaxAutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from pathlib import Path

def convert_flax_to_pytorch():
    print("Loading Flax model...")
    MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
    flax_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
    
    print("Converting to PyTorch...")
    # Convert to PyTorch
    pt_model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_NAME_OR_PATH, 
        from_flax=True,
    )
    
    # Save the PyTorch model
    output_dir = Path("./optimized_recipe_model_pt")
    output_dir.mkdir(exist_ok=True)
    pt_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"PyTorch model saved to {output_dir}")
    
    return pt_model, tokenizer

def optimize_model(model):
    print("Optimizing model...")
    # Move to half precision to reduce memory usage and increase speed
    model = model.half()
    
    # Enable torch inference optimizations
    torch.set_grad_enabled(False)
    if torch.cuda.is_available():
        model = model.cuda()
    
    return model

def create_optimized_pipeline():
    # Load or create the model
    try:
        model_path = "./optimized_recipe_model_pt"
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    except:
        model, tokenizer = convert_flax_to_pytorch()
    
    # Optimize the model
    model = optimize_model(model)
    
    # Optimized generation parameters
    generation_kwargs = {
        "max_length": 512,
        "min_length": 64,
        "no_repeat_ngram_size": 3,
        "num_beams": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
        "do_sample": True,  # Enable sampling
        "temperature": 0.8,  # Now properly used with do_sample=True
        "top_k": 50,        # Add top_k sampling
        "top_p": 0.95       # Add nucleus sampling
    }
    
    # Define tokens_map for post-processing
    special_tokens = tokenizer.all_special_tokens
    tokens_map = {
        "<sep>": "--",
        "<section>": "\n"
    }
    
    def generate_recipes(ingredient_lists, batch_size=4):
        """
        Generate recipes from multiple ingredient lists efficiently
        """
        all_recipes = []
        
        # Process in batches for efficiency
        for i in range(0, len(ingredient_lists), batch_size):
            batch = ingredient_lists[i:i+batch_size]
            
            # Prepare inputs
            prefix = "items: "
            inputs = [prefix + inp for inp in batch]
            encoded_inputs = tokenizer(
                inputs,
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )
            
            # Move inputs to same device as model
            if torch.cuda.is_available():
                encoded_inputs = {k: v.cuda() for k, v in encoded_inputs.items()}
            
            # Generate with optimized parameters
            with torch.no_grad():
                output_ids = model.generate(
                    **encoded_inputs,
                    **generation_kwargs
                )
            
            # Decode and post-process
            generated_recipes = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
            
            # Post-process recipes
            for text in generated_recipes:
                # Skip special tokens
                for token in special_tokens:
                    text = text.replace(token, "")
                
                # Replace mapped tokens
                for k, v in tokens_map.items():
                    text = text.replace(k, v)
                
                all_recipes.append(text)
        
        return all_recipes

    return generate_recipes

def optimize_recipe_model():
    """Main function to optimize the T5 recipe model"""
    print("Starting model optimization process...")
    
    # Create optimized pipeline
    generate_recipes = create_optimized_pipeline()
    
    # Test the optimized model
    test_ingredients = [
        "paneer, butter, peas, fresh cream",
        "curd, fish"
    ]
    
    print("\nTesting optimized model with sample ingredients...")
    generated_recipes = generate_recipes(test_ingredients)
    
    for i, text in enumerate(generated_recipes):
        print(f"\nRecipe {i+1} from ingredients: {test_ingredients[i]}")
        sections = text.split("\n")
        for section in sections:
            section = section.strip()
            if section.startswith("title:"):
                section = section.replace("title:", "")
                headline = "TITLE"
            elif section.startswith("ingredients:"):
                section = section.replace("ingredients:", "")
                headline = "INGREDIENTS"
            elif section.startswith("directions:"):
                section = section.replace("directions:", "")
                headline = "DIRECTIONS"

            if headline == "TITLE":
                print(f"[{headline}]: {section.strip().capitalize()}")
            else:
                section_info = [f"  - {i+1}: {info.strip().capitalize()}" for i, info in enumerate(section.split("--"))]
                print(f"[{headline}]:")
                print("\n".join(section_info))

        print("-" * 50)
    
    print("\nModel optimization complete!")
    print("The optimized model is ready for integration with the Flask/FastAPI backend.")
    
    return generate_recipes

if __name__ == "__main__":
    optimize_recipe_model()

Starting model optimization process...
Optimizing model...

Testing optimized model with sample ingredients...

Recipe 1 from ingredients: paneer, butter, peas, fresh cream
[TITLE]: Paneer with peas and cream
[INGREDIENTS]:
  - 1: 1 lb. pkg. samosas or sourdough bread
  - 2: 8 oz. drained, cubed panereer
  - 3: 2 tbsp. butter
  - 4: 1 1/2 c. cooked, shelled, peeled and mashed boiled potatoes
  - 5: 1 pt. fresh cream
[DIRECTIONS]:
  - 1: Preheat oven to 350 .
  - 2: Spread bread cubes in a single layer on a baking sheet.
  - 3: Bake until lightly browned, about 10 minutes.
  - 4: Remove from oven and set aside.
--------------------------------------------------

Recipe 2 from ingredients: curd, fish
[TITLE]: Chinese style bean curd and fish
[INGREDIENTS]:
  - 1: 1 block beancurd
  - 2: 1 tbsp fish stock
[DIRECTIONS]:
  - 1: Cut the bean curde into bite sized pieces.
  - 2: Put the beancurde and fish stock in a frying pan and bring to a boil.
  - 3: When it starts to boil, turn the heat 

In [14]:
import os
import numpy as np
from transformers import FlaxAutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from pathlib import Path

def convert_flax_to_pytorch():
    print("Loading Flax model...")
    MODEL_NAME_OR_PATH = "flax-community/t5-recipe-generation"
    flax_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
    
    print("Converting to PyTorch...")
    # Convert to PyTorch
    pt_model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_NAME_OR_PATH, 
        from_flax=True,
    )
    
    # Save the PyTorch model
    output_dir = Path("./optimized_recipe_model_pt")
    output_dir.mkdir(exist_ok=True)
    pt_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"PyTorch model saved to {output_dir}")
    
    return pt_model, tokenizer

def optimize_model(model):
    print("Optimizing model...")
    # Move to half precision to reduce memory usage and increase speed
    model = model.half()
    
    # Enable torch inference optimizations
    torch.set_grad_enabled(False)
    if torch.cuda.is_available():
        model = model.cuda()
    
    return model

def create_optimized_pipeline():
    # Load or create the model
    try:
        model_path = "./optimized_recipe_model_pt"
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    except:
        model, tokenizer = convert_flax_to_pytorch()
    
    # Optimize the model
    model = optimize_model(model)
    
    # Optimized generation parameters
    generation_kwargs = {
        "max_length": 512,
        "min_length": 64,
        "no_repeat_ngram_size": 3,
        "num_beams": 4,
        "early_stopping": True,
        "length_penalty": 1.2,
        "do_sample": True,  # Enable sampling
        "temperature": 0.8,  # Now properly used with do_sample=True
        "top_k": 50,        # Add top_k sampling
        "top_p": 0.95       # Add nucleus sampling
    }
    
    # Define tokens_map for post-processing
    special_tokens = tokenizer.all_special_tokens
    tokens_map = {
        "<sep>": "--",
        "<section>": "\n"
    }
    
    def generate_recipes(ingredient_lists, batch_size=4):
        """
        Generate recipes from multiple ingredient lists efficiently
        """
        all_recipes = []
        
        # Process in batches for efficiency
        for i in range(0, len(ingredient_lists), batch_size):
            batch = ingredient_lists[i:i+batch_size]
            
            # Prepare inputs
            prefix = "items: "
            inputs = [prefix + inp for inp in batch]
            encoded_inputs = tokenizer(
                inputs,
                max_length=256,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            )
            
            # Move inputs to same device as model
            if torch.cuda.is_available():
                encoded_inputs = {k: v.cuda() for k, v in encoded_inputs.items()}
            
            # Generate with optimized parameters
            with torch.no_grad():
                output_ids = model.generate(
                    **encoded_inputs,
                    **generation_kwargs
                )
            
            # Decode and post-process
            generated_recipes = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
            
            # Post-process recipes
            for text in generated_recipes:
                # Skip special tokens
                for token in special_tokens:
                    text = text.replace(token, "")
                
                # Replace mapped tokens
                for k, v in tokens_map.items():
                    text = text.replace(k, v)
                
                all_recipes.append(text)
        
        return all_recipes

    return generate_recipes

def optimize_recipe_model():
    """Main function to optimize the T5 recipe model"""
    print("Starting model optimization process...")
    
    # Create optimized pipeline
    generate_recipes = create_optimized_pipeline()
    
    # Test the optimized model
    test_ingredients = [
        "flour, sugar, vanilla extract, butter, cream cheese, chocolate chips",
        "brinjal, channa, drumstick,tomato, onion"
    ]
    
    print("\nTesting optimized model with sample ingredients...")
    generated_recipes = generate_recipes(test_ingredients)
    
    for i, text in enumerate(generated_recipes):
        print(f"\nRecipe {i+1} from ingredients: {test_ingredients[i]}")
        sections = text.split("\n")
        for section in sections:
            section = section.strip()
            if section.startswith("title:"):
                section = section.replace("title:", "")
                headline = "TITLE"
            elif section.startswith("ingredients:"):
                section = section.replace("ingredients:", "")
                headline = "INGREDIENTS"
            elif section.startswith("directions:"):
                section = section.replace("directions:", "")
                headline = "DIRECTIONS"

            if headline == "TITLE":
                print(f"[{headline}]: {section.strip().capitalize()}")
            else:
                section_info = [f"  - {i+1}: {info.strip().capitalize()}" for i, info in enumerate(section.split("--"))]
                print(f"[{headline}]:")
                print("\n".join(section_info))

        print("-" * 50)
    
    print("\nModel optimization complete!")
    print("The optimized model is ready for integration with the Flask/FastAPI backend.")
    
    return generate_recipes

if __name__ == "__main__":
    optimize_recipe_model()

Starting model optimization process...
Optimizing model...

Testing optimized model with sample ingredients...

Recipe 1 from ingredients: flour, sugar, vanilla extract, butter, cream cheese, chocolate chips
[TITLE]: Chocolate chip cream cheese cookies
[INGREDIENTS]:
  - 1: 2 c. flour
  - 2: 1 c sugar
  - 3: 2 tsp. vanilla extract
  - 4: 2 sticks butter, softened
  - 5: 1 8 oz. pkg. cream cheese
  - 6: 1 6 ox. bag chocolate chips
[DIRECTIONS]:
  - 1: Mix flour, sugar, vanilla, butter and cream cheese.
  - 2: Stir in chocolate chips.
  - 3: Drop by teaspoonfuls onto ungreased cookie sheet.
  - 4: Bake at 350 for 10 to 12 minutes.
--------------------------------------------------

Recipe 2 from ingredients: brinjal, channa, drumstick,tomato, onion
[TITLE]: Brinjal, channa, and drumstick stew
[INGREDIENTS]:
  - 1: 1 medium size idaho or a smoky kashmiri brinjala
  - 2: 1 medium sized iranian kishma or chana dal
  - 3: 1/2 medium size drumstick
  - 4: 1 large ripe tomato
  - 5: 1 small on

> 06-04-2025

In [2]:
import pandas as pd
import numpy as np
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    EarlyStoppingCallback,
    DataCollatorForSeq2Seq,
    T5TokenizerFast # Can be useful but AutoTokenizer usually suffices
)
from datasets import Dataset, DatasetDict, load_from_disk
import os
import json
import random
import time

# --- Configuration ---
BASE_MODEL = "t5-base"
SYNTHETIC_DATA_DIR = "synthetic_recipe_data"
RAW_DATASET_DIR = "recipe_dataset"
PREPROCESSED_DATA_DIR = "preprocessed_recipe_dataset"
OUTPUT_MODEL_DIR = "./recipe_model_finetuned" # Fine-tuning output / Trained model location
LOGGING_DIR = "./recipe_logs" # Fine-tuning logs

# Fine-tuning Parameters
SAMPLE_SIZE = 2000 # Reduced for faster demo; increase for better results (e.g., 10000)
NUM_TRAIN_EPOCHS = 1 # Reduced for faster demo; increase for better results (e.g., 3)
BATCH_SIZE = 4 # Adjust based on GPU memory
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.1
FP16 = torch.cuda.is_available() # Enable FP16 if GPU is available

# Set seed for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# --- Stage 1: Data Generation ---

def create_recipe_dataset(sample_size=1000, seed=42):
    """
    Create a compatible dataset for T5 recipe generation.
    Uses sample data formatted to match the expected input/output structure.
    """
    print("\n--- Stage 1: Creating Synthetic Recipe Dataset ---")
    print(f"Generating {sample_size} synthetic recipes...")

    # Set seeds for reproducibility within function scope if needed again
    random.seed(seed)
    np.random.seed(seed)

    # Sample ingredients for synthetic data
    common_ingredients = [
        "chicken", "beef", "pork", "salmon", "tuna", "shrimp", "lamb", "cod", "haddock", "mackerel", # Fish/Meat
        "pasta", "rice", "potatoes", "bread", "flour", "oats", "quinoa", "couscous", "noodles", # Grains/Starch
        "onion", "garlic", "tomatoes", "bell peppers", "carrots", "broccoli", "cauliflower",
        "spinach", "lettuce", "mushrooms", "zucchini", "eggplant", "corn", "peas", "beans", # Vegetables
        "apple", "banana", "orange", "lemon", "lime", "berries", "mango", "pineapple", # Fruits
        "butter", "olive oil", "vegetable oil", "coconut oil", "sesame oil", # Fats/Oils
        "salt", "pepper", "oregano", "basil", "thyme", "rosemary", "cumin", "coriander", "paprika", "chili powder", # Spices
        "milk", "cream", "cheese", "cheddar", "parmesan", "mozzarella", "yogurt", "curd", "eggs", "mayonnaise", # Dairy/Binders
        "sugar", "brown sugar", "honey", "maple syrup", # Sweeteners
        "lentils", "chickpeas", "tofu", "tempeh", # Legumes/Plant-protein
        "chocolate", "vanilla extract", "cinnamon", "nutmeg", "soy sauce", "vinegar", "mustard" # Flavorings
    ]

    # Create directory for generated data
    os.makedirs(SYNTHETIC_DATA_DIR, exist_ok=True)

    # Generate synthetic recipes
    recipes = []
    for i in range(sample_size):
        num_ingredients = random.randint(4, 12)
        ingredients = random.sample(common_ingredients, num_ingredients)

        main_ingredient = random.choice(ingredients)
        cooking_methods = ["Roasted", "Grilled", "Baked", "Fried", "Steamed", "Sautéed", "Slow-cooked", "Spicy", "Creamy", "Simple", "Quick"]
        dish_types = ["Casserole", "Soup", "Stew", "Salad", "Pasta", "Curry", "Stir-fry", "Sandwich", "Bowl", "Tacos", "Pizza"]
        title = f"{random.choice(cooking_methods)} {main_ingredient.capitalize()} {random.choice(dish_types)}"

        directions = [
            f"Prepare all ingredients: Chop vegetables, measure spices.",
            f"Preheat your {random.choice(['oven', 'skillet', 'grill', 'pot'])} to {random.randint(150, 220)}°C ({random.randint(300, 425)}°F).",
            f"In a bowl, combine {ingredients[0]} and {ingredients[1]}.",
            f"Add {random.choice(['olive oil', 'butter'])} to a {random.choice(['pan', 'pot'])} over medium heat.",
            f"Sauté {random.choice(['onion', 'garlic'])} until fragrant, about {random.randint(2, 5)} minutes.",
            f"Add {', '.join(ingredients[2:min(5, len(ingredients))])} and cook for {random.randint(5, 15)} minutes, stirring occasionally.",
            f"Stir in the remaining ingredients ({', '.join(ingredients[min(5, len(ingredients)):])}) and {random.choice(['broth', 'water', 'sauce'])}.",
            f"Bring to a simmer, then reduce heat and cook for {random.randint(10, 30)} minutes until {main_ingredient} is cooked through.",
            f"Season with salt, pepper, and other desired spices to taste.",
            f"Garnish with {random.choice(['parsley', 'cilantro', 'cheese', 'nuts'])} and serve hot."
        ]
        random.shuffle(directions) # Make directions less predictable
        directions = directions[:random.randint(5, 8)] # Use a subset of directions

        input_ingredients = ", ".join(ingredients)

        output_text = (
            f"title: {title} <section> "
            f"ingredients: {' <sep> '.join(ingredients)} <section> "
            f"directions: {' <sep> '.join(directions)}"
        )

        recipes.append({
            "input_text": f"items: {input_ingredients}",
            "output_text": output_text
        })

    # Save raw synthetic data (optional, but good for reference)
    output_file = os.path.join(SYNTHETIC_DATA_DIR, "synthetic_recipes.json")
    with open(output_file, 'w') as f:
        json.dump(recipes, f, indent=2)
    print(f"Saved raw synthetic data to {output_file}")

    # Create dataset splits
    df = pd.DataFrame(recipes)
    df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
    train_size = int(0.8 * len(df))
    val_size = int(0.1 * len(df))

    train_df = df[:train_size]
    val_df = df[train_size:train_size+val_size]
    test_df = df[train_size+val_size:]

    dataset_dict = DatasetDict({
        'train': Dataset.from_pandas(train_df),
        'validation': Dataset.from_pandas(val_df),
        'test': Dataset.from_pandas(test_df)
    })

    # Save dataset using the defined path constant
    dataset_dict.save_to_disk(RAW_DATASET_DIR)
    print(f"Dataset created with {len(df)} recipes and saved to '{RAW_DATASET_DIR}'")
    print(f"Split sizes -> Train: {len(train_df)}, Validation: {len(val_df)}, Test: {len(test_df)}")

    # Display a few examples
    print("\nExample recipes (first 3):")
    for i in range(min(3, len(df))):
        print(f"\nInput: {df.iloc[i]['input_text']}")
        print(f"Output: {df.iloc[i]['output_text']}")

    return dataset_dict

# --- Stage 2: Data Preprocessing ---

def preprocess_recipe_data(
    raw_dataset_dir=RAW_DATASET_DIR,
    preprocessed_dataset_dir=PREPROCESSED_DATA_DIR,
    model_name=BASE_MODEL,
    max_input_length=256,
    max_output_length=512
    ):
    """
    Preprocess recipe dataset for sequence-to-sequence training with T5.
    """
    print("\n--- Stage 2: Preprocessing Recipe Dataset ---")

    # Load dataset
    try:
        dataset = load_from_disk(raw_dataset_dir)
        print(f"Loaded raw dataset from '{raw_dataset_dir}'")
    except FileNotFoundError:
        print(f"Error: Raw dataset directory '{raw_dataset_dir}' not found.")
        print("Please ensure Stage 1 (create_recipe_dataset) ran successfully.")
        return None, None

    # Load tokenizer
    print(f"Loading tokenizer '{model_name}'...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    # Add special tokens for recipe formatting
    special_tokens = {"additional_special_tokens": ["<sep>", "<section>"]}
    num_added_toks = tokenizer.add_special_tokens(special_tokens)
    if num_added_toks > 0:
        print(f"Added {num_added_toks} special tokens ('<sep>', '<section>') to the tokenizer.")
    else:
        print("Special tokens already present in the tokenizer.")

    def preprocess_function(examples):
        # Ensure examples contain the expected keys
        if "input_text" not in examples or "output_text" not in examples:
             print("Warning: Missing 'input_text' or 'output_text' in examples batch.")
             # Handle potential list of dicts vs dict of lists
             input_texts = examples.get("input_text", [])
             output_texts = examples.get("output_text", [])
        else:
             input_texts = examples["input_text"]
             output_texts = examples["output_text"]

        # Tokenize inputs
        model_inputs = tokenizer(
            input_texts,
            max_length=max_input_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt" # Request PyTorch tensors
        )

        # Tokenize outputs (labels)
        # Important: T5 uses input_ids as labels during training
        labels = tokenizer(
            text_target=output_texts, # Use text_target for T5 labels
            max_length=max_output_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt" # Request PyTorch tensors
        )

        # Replace padding token id in labels with -100 for loss calculation
        label_input_ids = labels["input_ids"]
        label_input_ids[label_input_ids == tokenizer.pad_token_id] = -100

        model_inputs["labels"] = label_input_ids

        # Detach tensors before returning from map function if needed (usually not necessary)
        # model_inputs = {k: v.squeeze().tolist() for k,v in model_inputs.items()}
        # Convert tensors back to lists for dataset saving if not using set_format
        model_inputs = {k: v.squeeze().tolist() for k, v in model_inputs.items()}

        return model_inputs

    # Apply preprocessing to all splits
    print("Applying preprocessing function...")
    processed_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset["train"].column_names # Remove original text columns
    )

    # Set format to PyTorch tensors for DataLoader
    processed_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

    # Save processed dataset
    processed_dataset.save_to_disk(preprocessed_dataset_dir)
    print(f"Preprocessed dataset saved to '{preprocessed_dataset_dir}'")

    # Optional: Check a sample
    print("\nSample preprocessed example (first 10 tokens):")
    sample = processed_dataset["train"][0]
    print(f"Input IDs: {sample['input_ids'][:10].tolist()}") # Convert tensor to list for printing
    print(f"Attention mask: {sample['attention_mask'][:10].tolist()}")
    print(f"Labels: {sample['labels'][:10].tolist()}")


    return processed_dataset, tokenizer

# --- Post-Processor Definition --- (Single definition)

class RecipePostProcessor:
    def __init__(self, tokenizer):
        """
        Initialize post-processor with tokenizer for handling special tokens.
        Includes allergen detection and additional recipe processing.
        """
        self.tokenizer = tokenizer
        # Get all special tokens EXCEPT pad, bos, eos if needed for cleaning,
        # but focus on removing the generation control tokens and keeping our custom ones.
        all_special_ids = tokenizer.all_special_ids
        self.special_tokens_to_remove = [t for t_id in all_special_ids for t in tokenizer.convert_ids_to_tokens(t_id, skip_special_tokens=False) if t not in ['<sep>', '<section>', tokenizer.eos_token, tokenizer.bos_token, tokenizer.pad_token]] # keep structure tokens + EOS/PAD
        # Explicitly add common control tokens if not in special_ids list
        self.special_tokens_to_remove.extend(['<s>', '</s>', '<pad>'])
        self.special_tokens_to_remove = list(set(self.special_tokens_to_remove)) # Unique

        self.tokens_map = {
            "<sep>": "--",      # Use for separating items within a section
            "<section>": "\n",    # Use for separating sections (title, ingredients, directions)
            tokenizer.eos_token: "" # Remove End Of Sequence token
        }

        # Enhanced Allergen combinations to avoid - Be mindful of keywords
        self.allergen_combinations = [
            ("fish", "dairy", "Potential digestive issues or sensitivities when combining fish and dairy."),
            ("fish", "yogurt", "Potential digestive issues or sensitivities."),
            ("fish", "milk", "Potential digestive issues or sensitivities."),
            ("fish", "curd", "Potential digestive issues or sensitivities."),
            ("fish", "cheese", "Potential digestive issues or sensitivities (especially certain types)."),
            ("tuna", "cheese", "Often considered an allergenic or sensitive combination."),
            ("salmon", "cheese", "Often considered an allergenic or sensitive combination."),
            ("shellfish", "dairy", "Common allergen combination trigger."),
            ("shellfish", "milk", "Common allergen combination trigger."),
            ("shrimp", "milk", "Common allergen combination trigger."),
            ("shellfish", "cheese", "Common allergen combination trigger."),
            ("peanuts", "gluten", "Co-occurrence risk for individuals with multiple allergies (e.g., celiac). Needs context."),
            ("peanuts", "wheat", "Co-occurrence risk. Needs context."),
            ("shellfish", "mango", "Potential histamine reaction trigger in sensitive individuals."),
            ("strawberry", "chocolate", "Potential migraine trigger in sensitive individuals.")
            # Add more specific terms if needed: e.g., ("haddock", "cheddar") etc.
        ]
        # Keywords for broad categories to improve matching
        self.allergen_keywords = {
             "fish": ["fish", "tuna", "salmon", "cod", "haddock", "mackerel", "sardine"],
             "shellfish": ["shellfish", "shrimp", "prawn", "crab", "lobster", "clam", "oyster", "mussel", "scallop"],
             "dairy": ["dairy", "milk", "cream", "cheese", "cheddar", "parmesan", "mozzarella", "yogurt", "curd", "butter"], # Note: butter has low lactose
             "gluten": ["gluten", "wheat", "barley", "rye", "flour"], # Be careful with just "flour"
             "wheat": ["wheat", "flour"],
             # Add others like nuts, soy, eggs if desired
        }


    def postprocess_text(self, generated_texts):
        """
        Post-process generated recipe texts:
        1. Decode and clean special tokens.
        2. Replace mapped tokens (<sep>, <section>) with human-readable versions.
        3. Format into structured recipes.
        4. Check for allergen combinations.
        """
        processed_recipes = []

        for text in generated_texts:
            # Decode might handle some special tokens, but we clean residual ones
            # Remove other special tokens like <pad>, potentially <s>, etc.
            for token in self.special_tokens_to_remove:
                 if token in text: # Check before replacing
                    text = text.replace(token, "")

            # Replace mapped tokens with readable versions AFTER cleaning others
            for k, v in self.tokens_map.items():
                 if k in text: # Check before replacing
                    text = text.replace(k, v)

            text = text.strip() # Remove leading/trailing whitespace

            # Format structured recipe
            formatted_recipe = self._format_recipe(text)

            # Check for allergen combinations
            allergen_warnings = self.check_allergens(formatted_recipe)
            formatted_recipe["allergen_warnings"] = allergen_warnings

            processed_recipes.append(formatted_recipe)

        return processed_recipes

    def _format_recipe(self, text):
        """Format recipe text into structured dictionary"""
        recipe_dict = {"title": "Untitled Recipe", "ingredients": [], "directions": [], "allergen_warnings": []}
        current_section = None

        # Split by the section separator we introduced (\n)
        lines = text.split("\n")

        for line in lines:
            line = line.strip()
            if not line: continue # Skip empty lines

            if line.lower().startswith("title:"):
                recipe_dict["title"] = line[len("title:"):].strip().capitalize()
                current_section = "title"
            elif line.lower().startswith("ingredients:"):
                # Handle ingredients possibly spread across the same line after "ingredients:" marker
                items_text = line[len("ingredients:"):].strip()
                if items_text:
                     recipe_dict["ingredients"].extend([
                         item.strip().capitalize() for item in items_text.split("--") if item.strip()
                     ])
                current_section = "ingredients"
            elif line.lower().startswith("directions:"):
                # Handle directions possibly spread across the same line
                items_text = line[len("directions:"):].strip()
                if items_text:
                    recipe_dict["directions"].extend([
                        item.strip().capitalize() for item in items_text.split("--") if item.strip()
                    ])
                current_section = "directions"
            elif current_section == "ingredients" and line: # If already in ingredients section, treat line as ingredients
                 recipe_dict["ingredients"].extend([
                    item.strip().capitalize() for item in line.split("--") if item.strip()
                 ])
            elif current_section == "directions" and line: # If already in directions section, treat line as directions
                 recipe_dict["directions"].extend([
                    item.strip().capitalize() for item in line.split("--") if item.strip()
                 ])
            # else: Handle lines that don't fit expected structure (optional: log warning)

        # Clean up potentially empty items from splitting issues
        recipe_dict["ingredients"] = [ing for ing in recipe_dict["ingredients"] if ing]
        recipe_dict["directions"] = [step for step in recipe_dict["directions"] if step]

        # Capitalize first letter of steps
        recipe_dict["directions"] = [step[0].upper() + step[1:] if step else "" for step in recipe_dict["directions"]]


        return recipe_dict

    def check_allergens(self, recipe_dict):
        """Check for potentially problematic ingredient combinations using keywords."""
        warnings = []
        ingredients_lower = [ing.lower() for ing in recipe_dict["ingredients"]]
        present_keywords = set()

        # Identify all keywords present in the ingredients list
        for ing_lower in ingredients_lower:
            for category, keywords in self.allergen_keywords.items():
                for keyword in keywords:
                    if keyword in ing_lower:
                        present_keywords.add(category)
                        break # Go to next category once a match is found for this ingredient

        # Check defined combinations based on categories found
        checked_pairs = set()
        for cat1, cat2, reason in self.allergen_combinations:
             # Use the broader categories defined in allergen_keywords keys
             cat1_present = cat1 in present_keywords
             cat2_present = cat2 in present_keywords

             # Avoid duplicate warnings (e.g., fish-dairy vs dairy-fish)
             pair = tuple(sorted((cat1, cat2)))
             if cat1_present and cat2_present and pair not in checked_pairs:
                 warnings.append({
                     "categories": (cat1, cat2),
                     "reason": reason
                 })
                 checked_pairs.add(pair)

        return warnings

    def format_for_display(self, recipe_dict):
        """Format recipe dictionary for readable display"""
        if not isinstance(recipe_dict, dict):
            return "Invalid recipe format (not a dictionary)"

        display_text = f"[TITLE]: {recipe_dict.get('title', 'N/A')}\n\n"

        display_text += "[INGREDIENTS]:\n"
        ingredients = recipe_dict.get('ingredients', [])
        if ingredients:
            for i, ingredient in enumerate(ingredients):
                display_text += f"  - {ingredient}\n" # No numbering needed, just list
        else:
            display_text += "  (No ingredients listed)\n"

        display_text += "\n[DIRECTIONS]:\n"
        directions = recipe_dict.get('directions', [])
        if directions:
            for i, step in enumerate(directions):
                display_text += f"  {i+1}. {step}\n" # Number steps
        else:
             display_text += "  (No directions listed)\n"

        # Add allergen warnings if present
        allergen_warnings = recipe_dict.get("allergen_warnings", [])
        if allergen_warnings:
            display_text += "\n[⚠️ POTENTIAL ALLERGEN/SENSITIVITY WARNINGS]:\n"
            for i, warning in enumerate(allergen_warnings):
                 cat1, cat2 = warning['categories']
                 display_text += f"  - Combination ({cat1.capitalize()} + {cat2.capitalize()}): {warning['reason']}\n"

        return display_text

    def evaluate_recipe_quality(self, recipe_dict):
        """Evaluate recipe quality based on simple heuristics (0-100)"""
        if not isinstance(recipe_dict, dict): return 0
        score = 0
        max_score = 100

        # Check title presence and basic length
        if recipe_dict.get("title", "") and len(recipe_dict["title"]) > 3 and recipe_dict["title"] != "Untitled Recipe":
            score += 15 # More points for a seemingly valid title
        else:
             score += 5 # Minimal points if title is missing/default

        # Check ingredients presence and quantity
        ingredients = recipe_dict.get("ingredients", [])
        if ingredients:
            score += min(len(ingredients) * 4, 30)  # Points per ingredient, capped

            # Check ingredient variety (crude check)
            unique_words = set()
            for ingredient in ingredients:
                unique_words.update(ingredient.lower().split())
            score += min(len(unique_words) * 1, 15) # Points for unique words, capped
        else:
            score -= 10 # Penalize missing ingredients

        # Check directions presence and quantity/detail
        directions = recipe_dict.get("directions", [])
        if directions:
            score += min(len(directions) * 5, 30)  # Points per step, capped

            # Check step length as proxy for detail
            total_length = sum(len(step) for step in directions)
            score += min(total_length // 30, 10) # Points for total length, capped
        else:
             score -= 10 # Penalize missing directions

        # Deduct significant points for allergen warnings
        allergen_warnings = recipe_dict.get("allergen_warnings", [])
        score -= len(allergen_warnings) * 20  # Heavy penalty per warning

        # Normalize score to be between 0 and 100
        normalized_score = max(0, min(score, max_score))

        return int(normalized_score) # Return integer score


# --- Stage 3: Fine-Tuning ---

def fine_tune_recipe_model(
    base_model=BASE_MODEL,
    output_dir=OUTPUT_MODEL_DIR,
    preprocessed_dataset_dir=PREPROCESSED_DATA_DIR,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=WARMUP_RATIO,
    fp16=FP16,
    batch_size=BATCH_SIZE,
    logging_dir=LOGGING_DIR
):
    """
    Fine-tune a T5 model for recipe generation.
    """
    print(f"\n--- Stage 3: Fine-tuning {base_model} ---")
    print(f"Configuration: Epochs={num_train_epochs}, BatchSize={batch_size}, LR={learning_rate}, FP16={fp16}")
    print(f"Output directory: {output_dir}")
    print(f"Logs directory: {logging_dir}")
    print(f"Ensure 'accelerate' is installed if FP16 is True.")

    # Load preprocessed dataset
    try:
        processed_dataset = load_from_disk(preprocessed_dataset_dir)
        print(f"Loaded preprocessed dataset from '{preprocessed_dataset_dir}'")
        # Ensure format is PyTorch tensors (might be redundant if set in preprocessing)
        processed_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
        print(f"Dataset contains {len(processed_dataset['train'])} training examples.")
        # Verify structure
        print(f"Train dataset columns: {processed_dataset['train'].column_names}")
        if not all(col in processed_dataset['train'].column_names for col in ['input_ids', 'attention_mask', 'labels']):
             print("Error: Dataset missing required columns ('input_ids', 'attention_mask', 'labels'). Check preprocessing.")
             return None, None
    except FileNotFoundError:
        print(f"Error: Preprocessed dataset directory '{preprocessed_dataset_dir}' not found.")
        print("Please ensure Stage 2 (preprocess_recipe_data) ran successfully.")
        return None, None
    except Exception as e:
        print(f"Error loading or processing dataset: {e}")
        # Attempt to print details of the first element for debugging
        try:
            print("First element structure:", processed_dataset["train"][0])
        except Exception as debug_e:
            print(f"Could not examine first element: {debug_e}")
        return None, None


    # Load tokenizer (ensure it matches the one used for preprocessing)
    tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
    special_tokens = {"additional_special_tokens": ["<sep>", "<section>"]}
    num_added_tokens = tokenizer.add_special_tokens(special_tokens)
    if num_added_tokens > 0:
        print(f"Re-added {num_added_tokens} special tokens to tokenizer for safety.")

    # Load model
    print(f"Loading base model '{base_model}'...")
    model = AutoModelForSeq2SeqLM.from_pretrained(base_model)
    # Resize embeddings if new tokens were added
    model.resize_token_embeddings(len(tokenizer))
    print(f"Model tokenizer vocabulary size: {len(tokenizer)}")

    # Verify model has trainable parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if trainable_params == 0:
         print("Warning: Model loaded has no trainable parameters!")
    else:
         print(f"Model has {trainable_params:,} trainable parameters.")


    # Set up training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        # Evaluation and Saving Strategy
        evaluation_strategy="epoch", # Evaluate at the end of each epoch
        save_strategy="epoch",       # Save at the end of each epoch
        # Learning Rate and Optimization
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        warmup_ratio=warmup_ratio,
        # Batch Size
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2, # Can often use larger batch size for eval
        # Training Duration
        num_train_epochs=num_train_epochs,
        # Technical Details
        fp16=fp16 and torch.cuda.is_available(), # Use FP16 only if flag is True AND cuda is available
        gradient_accumulation_steps=2, # Accumulate gradients to simulate larger batch size if needed
        # Checkpointing and Logging
        save_total_limit=2, # Keep only the best and the latest checkpoints
        load_best_model_at_end=True,
        metric_for_best_model="loss", # Use validation loss to determine the best model
        greater_is_better=False,    # Lower loss is better
        logging_dir=logging_dir,
        logging_strategy="steps",
        logging_steps=50, # Log training loss every 50 steps
        report_to=["tensorboard"], # Report metrics to TensorBoard
        # Generation arguments (used if predict_with_generate=True)
        predict_with_generate=True,
        generation_max_length=512,
        generation_num_beams=4,
    )

    # Create data collator
    # This pads batches dynamically - requires tokenizer to have pad_token set.
    if tokenizer.pad_token is None:
        print("Warning: Tokenizer does not have a pad token. Adding EOS as pad token.")
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.eos_token_id # Ensure model config knows pad token id

    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100, # Make sure padding in labels is ignored for loss
        pad_to_multiple_of=8 if fp16 else None # Pad to multiple of 8 for efficiency with FP16
    )

    # Dummy compute_metrics - focusing on loss for now
    # Replace with ROUGE, BLEU etc. later if needed
    def compute_metrics(eval_preds):
        # loss is calculated by trainer automatically
        # Can add generation metrics later
        return {"placeholder_metric": 0.0}

    # Initialize Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=processed_dataset["train"],
        eval_dataset=processed_dataset["validation"],
        tokenizer=tokenizer, # Pass tokenizer for padding and generation
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01)] # Stop if eval loss doesn't improve significantly
    )

    # --- Start Training ---
    print("\nStarting fine-tuning training...")
    try:
        train_result = trainer.train()
        print("Training completed successfully!")

        # Save final model and tokenizer
        trainer.save_model() # Saves the best model according to load_best_model_at_end=True
        tokenizer.save_pretrained(output_dir)
        print(f"Best model and tokenizer saved to '{output_dir}'")

        # Log metrics
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    except Exception as e:
        print(f"\n!!! Training Error !!!")
        print(f"Error type: {type(e)}")
        print(f"Error details: {e}")
        # Simple debugging - check shapes of first batch
        try:
            print("\nAttempting to inspect first training batch:")
            first_batch = next(iter(trainer.get_train_dataloader()))
            print(f"Keys in batch: {first_batch.keys()}")
            for key, tensor in first_batch.items():
                print(f"  {key}: shape={tensor.shape}, dtype={tensor.dtype}")
        except Exception as debug_e:
            print(f"Could not inspect batch: {debug_e}")
        print("Training failed. Please check error message and dataset integrity.")
        return None, None # Indicate failure

    # --- Evaluate on Test Set ---
    print("\nEvaluating final model on the test set...")
    try:
        test_results = trainer.evaluate(eval_dataset=processed_dataset["test"])
        print("\nTest Set Evaluation Results:")
        trainer.log_metrics("test", test_results)
        trainer.save_metrics("test", test_results)
        print(json.dumps(test_results, indent=2))
    except Exception as e:
        print(f"Error during final evaluation on test set: {e}")

    print("\n--- Fine-tuning Stage Completed ---")
    # Return the loaded best model and tokenizer
    return model, tokenizer

# --- Stage 4: Testing & Generation (Optional - Example) ---

class AllergenDetectionTester:
    def __init__(self, model_path=OUTPUT_MODEL_DIR, base_model_fallback=BASE_MODEL):
        """Initialize the tester with model and post-processor."""
        print(f"\n--- Stage 4: Initializing AllergenDetectionTester ---")
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.use_mock_generation = False

        # Try loading fine-tuned model first
        if model_path and os.path.isdir(model_path):
             try:
                 print(f"Attempting to load fine-tuned model and tokenizer from '{model_path}'...")
                 self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                 self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
                 print("Successfully loaded fine-tuned model.")
                 self.model.to(self.device)
             except Exception as e:
                 print(f"Warning: Failed to load model from '{model_path}'. Error: {e}")
                 self.model = None
                 self.tokenizer = None
        else:
            print(f"Fine-tuned model path '{model_path}' not found or invalid.")

        # Fallback to base model or mock generation
        if self.model is None:
             print(f"Falling back to base model '{base_model_fallback}' for generation OR mock generation.")
             try:
                 self.tokenizer = AutoTokenizer.from_pretrained(base_model_fallback, use_fast=True)
                 # Optionally load base model for actual generation - can be slow on CPU
                 # self.model = AutoModelForSeq2SeqLM.from_pretrained(base_model_fallback)
                 # self.model.to(self.device)
                 # print("Loaded base model for generation.")
                 print("Will use MOCK generation as base model loading is commented out / not requested.")
                 self.use_mock_generation = True # Set to use mock if base model isn't loaded
             except Exception as e:
                 print(f"FATAL: Could not load base tokenizer '{base_model_fallback}'. Error: {e}")
                 raise # Can't proceed without a tokenizer

        # Ensure special tokens are added (important for post-processor)
        special_tokens = {"additional_special_tokens": ["<sep>", "<section>"]}
        self.tokenizer.add_special_tokens(special_tokens)
        if self.model and not self.use_mock_generation: # Resize embeddings if using loaded model
              self.model.resize_token_embeddings(len(self.tokenizer))

        # Initialize the post-processor WITH the loaded/fallback tokenizer
        self.processor = RecipePostProcessor(self.tokenizer)

        if self.model and not self.use_mock_generation:
            self.model.eval() # Set model to evaluation mode

        print(f"Using device: {self.device}")
        print("Tester initialized.")

    def generate_recipe(self, ingredients):
        """Generate a recipe from ingredients using loaded model or mock."""
        if isinstance(ingredients, list):
            ingredients_str = ", ".join(ingredients)
        else:
            ingredients_str = ingredients # Assume it's already a string

        input_text = f"items: {ingredients_str}"

        # Use mock generation if specified or if no model loaded
        if self.use_mock_generation or not self.model:
            print(f"(Mock Generation for: {ingredients_str})")
            ingredients_list = [ing.strip() for ing in ingredients_str.split(",")]
            # Make mock recipe a bit more realistic & include potential allergens from input
            mock_title = f"Mock {ingredients_list[0].capitalize()} Dish"
            mock_ingredients_str = " <sep> ".join(ingredients_list)
            mock_directions = f"prepare the ingredients ({', '.join(ingredients_list[:2])}) <sep> mix everything together <sep> cook using {random.choice(['pan', 'oven', 'pot'])} <sep> season well <sep> serve"
            mock_recipe_text = (
                f"title: {mock_title} <section> "
                f"ingredients: {mock_ingredients_str} <section> "
                f"directions: {mock_directions}"
            )
            return mock_recipe_text

        # --- Actual Model Generation ---
        print(f"(Generating recipe with {self.model.name_or_path} for: {ingredients_str})")
        inputs = self.tokenizer(
            input_text,
            max_length=256,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).to(self.device)

        # Generation parameters
        generation_kwargs = {
            "max_length": 512,
            "min_length": 50,
            "num_beams": 4,
            "early_stopping": True,
            # For more creative/varied output:
            # "do_sample": True,
            # "temperature": 0.7,
            # "top_k": 50,
            # "top_p": 0.95,
            # "no_repeat_ngram_size": 3
        }

        with torch.no_grad():
            output_ids = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                **generation_kwargs
            )

        # Decode - skip special tokens handled by postprocessor, clean others
        # Set clean_up_tokenization_spaces=True for potentially better spacing.
        generated_text = self.tokenizer.decode(
            output_ids[0],
            skip_special_tokens=False, # Let postprocessor handle <sep>/<section>
            clean_up_tokenization_spaces=True
            )

        return generated_text


    def test_allergen_combinations(self, use_actual_generation=True):
        """Test specific ingredient combinations for allergen detection."""
        test_combinations = [
            # Potential Allergen Triggers
            "fish, parmesan cheese, pasta, garlic, olive oil", # Fish + Cheese (Dairy)
            "shrimp, milk, noodles, ginger, soy sauce",         # Shellfish + Milk (Dairy)
            "tuna, yogurt, bread, cucumber, dill",           # Fish + Yogurt (Dairy)
            "salmon, curd, rice, lemon, spinach",             # Fish + Curd (Dairy)
            "lobster, cream, butter, garlic, bread",           # Shellfish + Cream (Dairy)
            "crab cakes (crab, breadcrumbs), mango salsa (mango, lime)", # Shellfish + Mango
            "peanut butter cookies (peanuts, flour, egg)",      # Peanuts + Flour (Gluten/Wheat)
            "strawberry chocolate tart (strawberry, chocolate, cream)",# Strawberry + Chocolate

            # Likely Non-Allergenic Controls (according to our list)
            "chicken breast, rice, broccoli, soy sauce, sesame oil",
            "beef steak, potatoes, carrots, onion gravy",
            "lentil soup (lentils, carrots, celery, onion, broth)",
            "tofu stir-fry (tofu, bell peppers, onion, soy sauce)"
        ]

        results = []
        self.use_mock_generation = not use_actual_generation # Override based on arg

        print("\n======= TESTING ALLERGEN COMBINATIONS =======\n")
        if not use_actual_generation and not self.model:
             print("NOTE: Using MOCK generation because 'use_actual_generation' is False or no model was loaded.")
        elif not self.model:
            print("NOTE: Using MOCK generation because no model could be loaded.")

        for ingredients in test_combinations:
            print(f"\n--- Testing ingredients: {ingredients} ---")

            start_time = time.time()
            # Generate a recipe (mock or real)
            raw_recipe_text = self.generate_recipe(ingredients)
            generation_time = time.time() - start_time
            #print(f"Raw output: {raw_recipe_text[:150]}...") # Debug raw output

            # Post-process the recipe
            start_time = time.time()
            processed_recipes = self.processor.postprocess_text([raw_recipe_text])
            recipe_dict = processed_recipes[0] # Get the first (only) processed recipe
            processing_time = time.time() - start_time

            # Calculate quality score
            quality_score = self.processor.evaluate_recipe_quality(recipe_dict)

            # Format the recipe for display
            formatted_recipe = self.processor.format_for_display(recipe_dict)

            results.append({
                "input_ingredients": ingredients,
                "generated_text_raw": raw_recipe_text,
                "processed_recipe": recipe_dict,
                "formatted_text": formatted_recipe,
                "quality_score": quality_score,
                "allergen_warnings": recipe_dict.get("allergen_warnings", []),
                "generation_time_sec": round(generation_time, 2),
                "processing_time_sec": round(processing_time, 2)
            })

            # Print the result for this combination
            print(formatted_recipe)
            print(f"Quality Score: {quality_score}/100")
            print(f"(Generation: {generation_time:.2f}s, Processing: {processing_time:.2f}s)")
            print("-" * 60)


        # --- Summary of Allergen Detections ---
        print("\n======= ALLERGEN DETECTION SUMMARY =======\n")
        detected_count = 0
        for i, result in enumerate(results):
            warnings = result.get("allergen_warnings", [])
            print(f"{i+1}. Ingredients: {result['input_ingredients']}")
            print(f"   Detected Warnings: {len(warnings)}")
            if warnings:
                detected_count += 1
                for warning in warnings:
                    print(f"     - {warning['categories'][0].capitalize()} + {warning['categories'][1].capitalize()}: {warning['reason']}")
            print() # Newline for readability

        safe_recipes_count = len(results) - detected_count
        print(f"Summary: Total Tested = {len(results)}, Recipes with Warnings = {detected_count}, Recipes without Warnings = {safe_recipes_count}")

        return results


# --- Main Execution Block ---

if __name__ == "__main__":
    print("Starting Recipe Generation Pipeline...")
    print(f"Using Base Model: {BASE_MODEL}")
    print(f"Output Model Dir: {OUTPUT_MODEL_DIR}")

    # --- Step 1: Create Data (if needed) ---
    # Check if raw data exists, otherwise create it
    if not os.path.exists(RAW_DATASET_DIR):
        print(f"Raw dataset not found at '{RAW_DATASET_DIR}'. Running Stage 1...")
        create_recipe_dataset(sample_size=SAMPLE_SIZE, seed=SEED)
    else:
        print(f"Raw dataset found at '{RAW_DATASET_DIR}'. Skipping Stage 1.")

    # --- Step 2: Preprocess Data (if needed) ---
    # Check if preprocessed data exists, otherwise create it
    if not os.path.exists(PREPROCESSED_DATA_DIR):
         print(f"Preprocessed dataset not found at '{PREPROCESSED_DATA_DIR}'. Running Stage 2...")
         processed_data, pp_tokenizer = preprocess_recipe_data()
         if processed_data is None:
             print("Preprocessing failed. Exiting.")
             exit()
    else:
        print(f"Preprocessed dataset found at '{PREPROCESSED_DATA_DIR}'. Skipping Stage 2.")
        # We still need a tokenizer for later stages, load it from base
        print(f"Loading tokenizer '{BASE_MODEL}' for subsequent stages...")
        pp_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
        special_tokens = {"additional_special_tokens": ["<sep>", "<section>"]}
        pp_tokenizer.add_special_tokens(special_tokens)


    # --- Step 3: Fine-tune Model ---
    # Set train=True to run fine-tuning, False to skip and proceed to testing
    train_model = True
    fine_tuned_model = None
    fine_tuned_tokenizer = None

    if train_model:
        print("\nProceeding to Stage 3: Fine-tuning...")
        # Pass directories explicitly
        fine_tuned_model, fine_tuned_tokenizer = fine_tune_recipe_model(
            base_model=BASE_MODEL,
            output_dir=OUTPUT_MODEL_DIR,
            preprocessed_dataset_dir=PREPROCESSED_DATA_DIR,
            num_train_epochs=NUM_TRAIN_EPOCHS,
            learning_rate=LEARNING_RATE,
            weight_decay=WEIGHT_DECAY,
            warmup_ratio=WARMUP_RATIO,
            fp16=FP16,
            batch_size=BATCH_SIZE,
            logging_dir=LOGGING_DIR
        )

        if fine_tuned_model is None:
            print("Fine-tuning failed or was skipped due to error. Exiting or proceeding without fine-tuned model...")
            # Decide whether to exit or allow testing with base/mock
            # exit() # Uncomment to stop if training fails
        else:
             print("Fine-tuning seems successful.")
    else:
        print("\nSkipping Stage 3: Fine-tuning based on 'train_model' flag.")
        # Try to load if model exists from previous run
        if os.path.exists(OUTPUT_MODEL_DIR):
            print(f"Attempting to load previously fine-tuned model from {OUTPUT_MODEL_DIR} for testing...")
            try:
                fine_tuned_tokenizer = AutoTokenizer.from_pretrained(OUTPUT_MODEL_DIR, use_fast=True)
                # No need to load the full model here unless tester needs it directly passed
                print("Loaded tokenizer for testing.")
            except Exception as e:
                print(f"Could not load tokenizer from {OUTPUT_MODEL_DIR}: {e}. Using base tokenizer.")
                fine_tuned_tokenizer = pp_tokenizer # Fallback to preprocessor tokenizer
        else:
             print(f"No previously fine-tuned model found at {OUTPUT_MODEL_DIR}.")
             fine_tuned_tokenizer = pp_tokenizer # Fallback to preprocessor tokenizer


    # --- Step 4: Test Generation and Allergen Detection ---
    print("\nProceeding to Stage 4: Testing Allergen Detection...")
    # Initialize tester - it will try to load from OUTPUT_MODEL_DIR
    tester = AllergenDetectionTester(model_path=OUTPUT_MODEL_DIR, base_model_fallback=BASE_MODEL)

    # Run the tests - specify whether to force mock generation or attempt actual generation
    # Set use_actual_generation=True if you trained a model AND have enough compute/time
    # Set use_actual_generation=False to always use MOCK generation for speed/debugging post-processing
    use_actual_generation_flag = (fine_tuned_model is not None or os.path.exists(OUTPUT_MODEL_DIR)) and torch.cuda.is_available() # Example logic: Use real generation if trained/exists and GPU available
    if not use_actual_generation_flag:
         print("\nNOTE: Will use MOCK recipe generation for allergen tests (either model not trained/found, or no GPU).")

    test_results = tester.test_allergen_combinations(use_actual_generation=use_actual_generation_flag)

    # Optional: Save test results
    results_file = os.path.join(OUTPUT_MODEL_DIR, "allergen_test_results.json")
    try:
        # Need to handle potential non-serializable items (like tensors if any slip through)
        def default_serializer(obj):
             if isinstance(obj, torch.Tensor):
                 return obj.tolist() # Convert tensors to lists
             # Add other types if needed
             return f"Object of type {type(obj).__name__} is not JSON serializable"

        with open(results_file, 'w') as f:
            json.dump(test_results, f, indent=2, default=default_serializer)
        print(f"\nSaved allergen test results to {results_file}")
    except Exception as e:
        print(f"\nError saving test results to JSON: {e}")


    print("\nRecipe Generation Pipeline Finished.")

Starting Recipe Generation Pipeline...
Using Base Model: t5-base
Output Model Dir: ./recipe_model_finetuned
Raw dataset not found at 'recipe_dataset'. Running Stage 1...

--- Stage 1: Creating Synthetic Recipe Dataset ---
Generating 2000 synthetic recipes...
Saved raw synthetic data to synthetic_recipe_data/synthetic_recipes.json


Saving the dataset (0/1 shards):   0%|          | 0/1600 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/200 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/200 [00:00<?, ? examples/s]

Dataset created with 2000 recipes and saved to 'recipe_dataset'
Split sizes -> Train: 1600, Validation: 200, Test: 200

Example recipes (first 3):

Input: items: curd, eggs, lentils, tempeh
Output: title: Simple Lentils Stew <section> ingredients: curd <sep> eggs <sep> lentils <sep> tempeh <section> directions: Bring to a simmer, then reduce heat and cook for 26 minutes until lentils is cooked through. <sep> Preheat your oven to 219°C (341°F). <sep> In a bowl, combine curd and eggs. <sep> Add olive oil to a pot over medium heat. <sep> Add lentils, tempeh and cook for 7 minutes, stirring occasionally. <sep> Garnish with nuts and serve hot. <sep> Sauté onion until fragrant, about 3 minutes.

Input: items: beef, cream, shrimp, lettuce, orange, chocolate, oats, lamb
Output: title: Steamed Lamb Casserole <section> ingredients: beef <sep> cream <sep> shrimp <sep> lettuce <sep> orange <sep> chocolate <sep> oats <sep> lamb <section> directions: Bring to a simmer, then reduce heat and cook for 

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

Added 2 special tokens ('<sep>', '<section>') to the tokenizer.
Applying preprocessing function...


Map:   0%|          | 0/1600 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1600 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/200 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/200 [00:00<?, ? examples/s]

Preprocessed dataset saved to 'preprocessed_recipe_dataset'

Sample preprocessed example (first 10 tokens):
Input IDs: [1173, 10, 5495, 26, 6, 5875, 6, 24026, 7, 6]
Attention mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Labels: [2233, 10, 9415, 301, 295, 1558, 3557, 210, 32101, 3018]

Proceeding to Stage 3: Fine-tuning...

--- Stage 3: Fine-tuning t5-base ---
Configuration: Epochs=1, BatchSize=4, LR=5e-05, FP16=True
Output directory: ./recipe_model_finetuned
Logs directory: ./recipe_logs
Ensure 'accelerate' is installed if FP16 is True.
Loaded preprocessed dataset from 'preprocessed_recipe_dataset'
Dataset contains 1600 training examples.
Train dataset columns: ['input_ids', 'attention_mask', 'labels']
Re-added 2 special tokens to tokenizer for safety.
Loading base model 't5-base'...


model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Model tokenizer vocabulary size: 32102
Model has 222,883,584 trainable parameters.


  trainer = Seq2SeqTrainer(



Starting fine-tuning training...


  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss,Placeholder Metric
1,1.2169,0.961998,0.0


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


Training completed successfully!
Best model and tokenizer saved to './recipe_model_finetuned'
***** train metrics *****
  epoch                    =        1.0
  total_flos               =   453708GF
  train_loss               =      3.191
  train_runtime            = 0:08:15.46
  train_samples_per_second =      3.229
  train_steps_per_second   =      0.404

Evaluating final model on the test set...



Test Set Evaluation Results:
***** test metrics *****
  epoch                   =        1.0
  eval_loss               =     0.9741
  eval_placeholder_metric =        0.0
  eval_runtime            = 0:04:33.44
  eval_samples_per_second =      0.731
  eval_steps_per_second   =      0.091
{
  "eval_loss": 0.9741456508636475,
  "eval_placeholder_metric": 0.0,
  "eval_runtime": 273.4498,
  "eval_samples_per_second": 0.731,
  "eval_steps_per_second": 0.091,
  "epoch": 1.0
}

--- Fine-tuning Stage Completed ---
Fine-tuning seems successful.

Proceeding to Stage 4: Testing Allergen Detection...

--- Stage 4: Initializing AllergenDetectionTester ---
Attempting to load fine-tuned model and tokenizer from './recipe_model_finetuned'...
Successfully loaded fine-tuned model.
Using device: cuda
Tester initialized.



--- Testing ingredients: fish, parmesan cheese, pasta, garlic, olive oil ---
(Generating recipe with ./recipe_model_finetuned for: fish, parmesan cheese, pasta, garlic, olive oil)
[TIT