    This notebook contains the code to perform LLM-driven categorization of food entries in the Shanghai T2DM data.

    Running the notebook assumes you have already:
    - installed all requirements in `./requirements.txt`
    - downloaded the data into the `./data` folder
    - set up your LLM API key in the `.env` file.
    
    See README.md for instructions.

In [None]:
# Where the downloaded data lives, and the parquet file name for the processed entries.
local_base_path, output_file_name = "./data", "processed_food_diary_entries.parquet"

To use LLMs, we first need to check that the API keys are properly set in the .env file. For this analysis, we use an OpenAI LLM (gpt-4o-mini).

In [None]:
from dotenv import load_dotenv
import os

# Load variables from a local .env file into the process environment.
# load_dotenv() returns False when no variable was loaded (e.g. the .env file is
# missing or empty). Surface that here instead of only failing at the key check below.
if not load_dotenv():
    print("Warning: no variables were loaded from a .env file. Please create ./.env with your OPENAI_API_KEY.")

In [None]:
# Fail fast when the OpenAI key is unavailable. This cell is expected NOT to raise;
# if it does, fix your .env file before continuing. An empty value
# (e.g. `OPENAI_API_KEY=` in .env) counts as missing, because every API call would fail.
if not os.getenv("OPENAI_API_KEY"):
    raise ValueError("OPENAI_API_KEY not found")

We will also need a cache folder to store the output files of the LLM.

In [None]:
# Folder layout for LLM artefacts: everything goes under ./LLM_outputs/, with a
# nested cache_folder/ for cached outputs. exist_ok=True makes re-running the cell harmless.
out_folder = "./LLM_outputs/"
cache_folder = out_folder + "cache_folder/"
os.makedirs(cache_folder, exist_ok=True)

Let's define structured Pydantic models to represent, validate, and analyze raw food diary entries, including meal classification, portion assessment, and data quality flags. These classes support both LLM input formatting and downstream consistency checks for nutrition-related data processing.

In [None]:
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator

class MealType(str, Enum):
    """
    Time-of-day category of a logged eating event.

    Member values are the labels accepted during validation,
    e.g. ``MealType("Lunch") is MealType.LUNCH``.
    """
    BREAKFAST = "Breakfast"  # first meal of the day
    LUNCH = "Lunch"  # midday meal
    DINNER = "Dinner"  # evening meal

class PortionSize(str, Enum):
    """
    Per-item portion assessment, relative to the standard recommended range.
    """
    LOW = "low"  # smaller than the recommended range
    AVERAGE = "average"  # inside the recommended range
    HIGH = "high"  # larger than recommended, yet still plausible
    TYPING_ERROR = "typing_error"  # amount most likely mistyped at entry time

class CaloricDensity(str, Enum):
    """
    Whole-meal caloric assessment.
    """
    LOW = "low"  # too few calories for this kind of meal
    AVERAGE = "average"  # caloric content is appropriate
    HIGH = "high"  # above recommendations, yet still plausible
    UNREALISTIC = "unrealistic"  # caloric content cannot be right

class DataQualityIssue(str, Enum):
    """
    Data quality problems that can be flagged on a diary entry.

    ``VALID`` is the "nothing wrong" marker; every other member names a defect.
    """
    MISSING_QUANTITY = "missing_quantity"  # no amount given
    MISSING_UNIT = "missing_unit"  # amount given without a unit
    FORMATTING_ERROR = "formatting_error"  # malformed text (e.g. missing spaces)
    AMBIGUOUS_DESCRIPTION = "ambiguous_description"  # unclear what was eaten
    INCOMPLETE_ENTRY = "incomplete_entry"  # major meal components absent
    UNREALISTIC_VALUE = "unrealistic_value"  # amounts that cannot be right
    TYPING_ERROR = "typing_error"  # probable data-entry slip
    VALID = "valid"  # no problem found

class UnprocessedFoodDiaryEntry(BaseModel):
    """
    Raw food diary entry containing the original description and patient context.
    This is used as input for LLM processing.
    """
    event_description: str = Field(
        ...,
        description="Raw text description of the meal, typically containing food items and their quantities in various formats",
        examples=["Rice 185 g\nTomato and egg soup 200 g\nBraised eggplant 50 g"]
    )
    event_type: MealType = Field(
        ...,
        description="The type of meal or eating event"
    )
    age: int = Field(
        ...,
        ge=0,
        le=120,
        description="Patient's age in years"
    )
    height: float = Field(
        ...,
        ge=0.5,
        le=2.5,
        description="Patient's height in meters"
    )
    weight: float = Field(
        ...,
        ge=20.0,
        le=300.0,
        description="Patient's weight in kilograms"
    )
    bmi: float = Field(
        ...,
        ge=10.0,
        le=100.0,
        description="Patient's Body Mass Index (weight/height²)"
    )
    is_female: bool = Field(
        ...,
        description="Patient's biological sex (True for female, False for male)"
    )

    def __str__(self) -> str:
        """
        Create a markdown formatted string representation suitable for LLM input.

        The meal type is rendered via its ``.value`` (e.g. "Lunch"). On Python 3.11+,
        formatting a ``(str, Enum)`` member directly in an f-string yields
        "MealType.LUNCH", which would leak the Python class name into the prompt.
        """
        # MealType(...) accepts either a member or its raw string value.
        meal_label = MealType(self.event_type).value
        lines = [
            f"## Food Diary Entry - {meal_label}",
            "",
            "### Patient Information",
            f"- Age: {self.age} years",
            f"- Gender: {'Female' if self.is_female else 'Male'}",
            f"- BMI: {self.bmi:.1f} kg/m²",
            "",
            "### Raw Food Description",
            f"{self.event_description}"
        ]
        return "\n".join(lines)

class FoodItemWithAnalysis(BaseModel):
    """
    Combined food item and portion analysis.
    """
    name: str = Field(
        ...,
        description="The name of the food item, including cooking method when known (e.g., 'Steamed rice', 'Fried chicken')",
        min_length=1
    )
    amount: float = Field(
        ...,
        description="The numerical quantity of the food item",
        gt=0
    )
    unit: str = Field(
        ...,
        description="The unit of measurement (e.g., 'g' for grams, 'ml' for milliliters, 'piece')",
        min_length=1
    )
    portion_size: PortionSize = Field(
        ...,
        description="Assessment of this item's portion size compared to standard recommendations"
    )
    standard_range: str = Field(
        ...,
        description="Reference range for this food type (e.g., '150-300g for rice')"
    )

    def __str__(self) -> str:
        """
        String representation with portion analysis, e.g. "Rice (185.0 g, average)".

        The portion size is rendered via ``.value``. Formatting the ``(str, Enum)``
        member directly would give "PortionSize.AVERAGE" on Python 3.11+.
        """
        portion_label = PortionSize(self.portion_size).value
        return f"{self.name} ({self.amount} {self.unit}, {portion_label})"

class ProcessedFoodDiaryEntry(UnprocessedFoodDiaryEntry):
    """
    Processed food diary entry with detailed analysis fields.

    Extends the unprocessed entry (whose patient/meal fields are echoed back) with
    item-level and meal-level analysis. Two validators run on construction:

    * ``validate_unchanged_fields``: when a validation context containing
      ``"original_entry"`` is supplied, checks that the echoed patient/meal fields
      match it. For example:
      ``ProcessedFoodDiaryEntry.model_validate(data, context={"original_entry": entry.model_dump()})``.
    * ``validate_meal``: recomputes ``has_realistic_portions`` and
      ``should_reject_sample`` from the other fields, overriding the provided values,
      and normalises ``data_quality_issues``.
    """
    items: List[FoodItemWithAnalysis] = Field(
        ...,
        description="List of food items with portion analysis",
        min_length=1  # pydantic v2 name of the deprecated v1 `min_items`
    )

    meal_caloric_density: CaloricDensity = Field(
        ...,
        description="Overall assessment of the meal's caloric content"
    )

    data_quality_issues: List[DataQualityIssue] = Field(
        ...,
        description="List of identified data quality issues in the entry",
        min_length=1  # pydantic v2 name of the deprecated v1 `min_items`
    )

    is_meal_complete: bool = Field(
        ...,
        description="Whether the meal appears to be a complete record (vs. partial logging)"
    )

    has_realistic_portions: bool = Field(
        ...,
        description="Whether all portions are within normal ranges"
    )

    needs_nutrition_review: bool = Field(
        ...,
        description="Flag for entries that might need dietary intervention"
    )

    should_reject_sample: bool = Field(
        ...,
        description="""Final decision on sample rejection. Only reject if:
        - Clearly impossible amounts (e.g., 1500g rice)
        - Missing critical information (quantities/units)
        - Completely ambiguous descriptions and incomplete
        - Formatting errors making interpretation impossible and incomplete
        Do NOT reject for:
        - Small but plausible portions
        - Single item snacks
        - Diet-appropriate portions
        - Cultural portion variations"""
    )

    extra_info: str = Field(
        ...,
        description="Additional context, explanations of issues, or analyst notes"
    )

    @field_validator("age", "height", "weight", "bmi", "is_female", "event_type", mode='after')
    @classmethod
    def validate_unchanged_fields(cls, value: Any, info: ValidationInfo) -> Any:
        """
        Validates that fields from the original entry remain unchanged.

        The check runs only when the validation context is provided and contains an
        ``"original_entry"`` mapping. Fields absent from that mapping (or ``None``)
        are not compared.

        Raises:
            ValueError: if the value differs from the original entry's value.
        """
        if info.context is None:
            return value

        original_entry: Optional[Dict] = info.context.get('original_entry')
        if original_entry is None:
            return value

        field_name = info.field_name
        original_value = original_entry.get(field_name)

        if original_value is not None and value != original_value:
            raise ValueError(
                f"Field '{field_name}' has been modified. "
                f"Original value: {original_value}, "
                f"New value: {value}. "
                "Patient information fields must remain unchanged during processing."
            )

        return value

    @model_validator(mode='after')
    def validate_meal(self) -> 'ProcessedFoodDiaryEntry':
        """
        Validate overall meal consistency, overriding the derived flags.

        - ``has_realistic_portions`` is False iff any item is a ``typing_error``.
        - ``should_reject_sample`` is recomputed from typing errors, missing
          quantities/units, and formatting/ambiguity issues on incomplete meals.
        - ``unrealistic_value`` is added when a typing-error portion exists, and
          ``valid`` is dropped whenever a real issue is also listed.
        """
        unrealistic_portions = any(
            item.portion_size == PortionSize.TYPING_ERROR
            for item in self.items
        )
        self.has_realistic_portions = not unrealistic_portions

        issues = self.data_quality_issues
        incomplete = not self.is_meal_complete
        self.should_reject_sample = any([
            # Data entry errors
            unrealistic_portions,
            # Missing critical information
            DataQualityIssue.MISSING_QUANTITY in issues,
            DataQualityIssue.MISSING_UNIT in issues,
            # Formatting issues: only reject if the meal is also incomplete
            DataQualityIssue.FORMATTING_ERROR in issues and incomplete,
            # Ambiguous entries: only reject if the meal is also incomplete
            DataQualityIssue.AMBIGUOUS_DESCRIPTION in issues and incomplete,
        ])

        # Ensure UNREALISTIC_VALUE in issues if unrealistic portions
        if unrealistic_portions and DataQualityIssue.UNREALISTIC_VALUE not in issues:
            issues.append(DataQualityIssue.UNREALISTIC_VALUE)

        # "valid" means "no issues detected", which contradicts any other listed issue
        # (e.g. the unrealistic_value appended above), so drop it in that case.
        if DataQualityIssue.VALID in issues and len(set(issues)) > 1:
            self.data_quality_issues = [
                issue for issue in issues if issue != DataQualityIssue.VALID
            ]

        return self

    def __str__(self) -> str:
        """
        Create a markdown formatted string representation suitable for review.

        Enum fields are rendered via ``.value`` ("Lunch", "low", "valid") rather
        than ``str()``/f-string formatting. Those yield e.g. "DataQualityIssue.VALID"
        for ``(str, Enum)`` members.
        """
        meal_label = MealType(self.event_type).value
        density_label = CaloricDensity(self.meal_caloric_density).value
        issue_labels = ', '.join(DataQualityIssue(issue).value for issue in self.data_quality_issues)

        lines = [
            f"## Food Diary Entry - {meal_label}",
            "",
            "### Patient Information",
            f"- Age: {self.age} years",
            f"- Gender: {'Female' if self.is_female else 'Male'}",
            f"- BMI: {self.bmi:.1f} kg/m²",
            "",
            "### Original Description",
            self.event_description,
            "",
            "### Structured Food Items"
        ]

        # Add each food item with its analysis
        for item in self.items:
            lines.append(f"- {str(item)}")
            lines.append(f"  * Standard Range: {item.standard_range}")

        # Add assessment section
        lines.extend([
            "",
            "### Meal Assessment",
            f"- Caloric Density: {density_label}",
            f"- Data Quality Issues: {issue_labels}",
            f"- Complete Meal: {'Yes' if self.is_meal_complete else 'No'}",
            f"- Realistic Portions: {'Yes' if self.has_realistic_portions else 'No'}",
            f"- Needs Nutrition Review: {'Yes' if self.needs_nutrition_review else 'No'}",
            f"- Should Reject Sample: {'Yes' if self.should_reject_sample else 'No'}",
            "",
            "### Additional Information",
            self.extra_info
        ])

        return "\n".join(lines)

Below are the system and user prompts for the LLM to process the food diary entries.

In [None]:
def create_system_prompt(unprocessed_schema: dict, processed_schema: dict) -> str:
    """
    Creates the system prompt with proper schema references and mixed items handling.

    Args:
        unprocessed_schema: JSON schema of the input model
            (e.g. ``UnprocessedFoodDiaryEntry.model_json_schema()``).
        processed_schema: JSON schema of the output model
            (e.g. ``ProcessedFoodDiaryEntry.model_json_schema()``).

    Returns:
        The full system prompt.
        - The schemas are embedded as indented JSON, not as Python dict ``repr``
          with single quotes and True/False.
        - The data-quality labels are the exact lowercase enum values the output
          schema accepts.
        - The rejection rules mirror ``ProcessedFoodDiaryEntry.validate_meal``.
    """
    import json  # local import keeps this notebook cell self-contained

    unprocessed_schema_json = json.dumps(unprocessed_schema, indent=2, ensure_ascii=False)
    processed_schema_json = json.dumps(processed_schema, indent=2, ensure_ascii=False)

    return f'''You are a specialized AI trained to process food diary entries by extracting structured food items from text descriptions and performing detailed nutritional analysis. Your task is to transform unprocessed food diary entries into processed entries with validated food items.

INPUT SCHEMA:
{unprocessed_schema_json}

OUTPUT SCHEMA:
{processed_schema_json}

Your primary responsibilities are:

1. EXTRACTION AND STRUCTURING:
   - Extract food items with their quantities
   - Keep mixed items together (do not split amounts)
   - Include cooking methods when known
   - Maintain complete item descriptions

   Mixed Items Handling:
   - Keep mixed ingredients as single items with their total weight
   - Use complete, descriptive names
   Example: "vegetables and meat stir-fry 150g" becomes:
   {{
     "name": "Stir-fried meat with vegetables",
     "amount": 150,
     "unit": "g"
   }}
   NOT:
   [
     {{"name": "Meat", "amount": 100, "unit": "g"}},
     {{"name": "Vegetables", "amount": 50, "unit": "g"}}  // Don't split like this
   ]

2. PORTION SIZE ANALYSIS:
   Each food item must be classified as:
   - "low": Below standard range but plausible (e.g., elderly, diet control)
   - "average": Within standard range
   - "high": Above standard range but plausible
   - "typing_error": Clearly impossible amounts

   Standard Ranges (for complete dishes):
   * Mixed Rice/Noodle Dishes: 200-400g
   * Mixed Meat/Vegetable Dishes: 150-300g
   * Soups/Porridges: 200-400ml
   * Single Component Items:
     - Plain Rice/Pasta/Noodles: 150-300g
     - Plain Meat/Fish: 100-200g
     - Plain Vegetables: 100-300g
     - Fruits: 100-200g
   
   Adjust ranges for:
   - Elderly patients: 50-70% of standard range
   - Diet control: 50-70% of standard range
   - High BMI patients: lower end of ranges
   - Low BMI patients: higher end of ranges

3. DATA QUALITY ASSESSMENT:
   Use these exact values in data_quality_issues:
   - "missing_quantity": No amount specified
   - "missing_unit": No unit specified
   - "formatting_error": Text formatting issues (e.g., missing spaces)
   - "ambiguous_description": Unclear descriptions
   - "incomplete_entry": Missing major components
   - "unrealistic_value": Clearly impossible amounts
   - "typing_error": Data entry errors (e.g., missing decimals)
   - "valid": No issues detected (only when no other issue applies)

4. REJECTION CRITERIA:
   Set should_reject_sample = true ONLY for:
   - Clearly impossible amounts (any item with portion_size "typing_error", e.g., 1500g rice)
   - Missing quantities or units ("missing_quantity" or "missing_unit")
   - Formatting errors ("formatting_error") AND incomplete meal
   - Completely ambiguous descriptions ("ambiguous_description") AND incomplete meal

   Do NOT reject for:
   - Small portions (especially for elderly)
   - Single item entries that are complete
   - High but plausible portions
   - Mixed items with shared weights

Here are examples showing proper analysis:

EXAMPLE 1 - Mixed Items (Valid Entry):
Input:
{{
    "event_description": "Pork and vegetable stir-fry 250g, Rice 150g",
    "event_type": "Lunch",
    "age": 45,
    "height": 1.75,
    "weight": 70.0,
    "bmi": 22.9,
    "is_female": false
}}

Output:
{{
    "event_description": "Pork and vegetable stir-fry 250g, Rice 150g",
    "event_type": "Lunch",
    "age": 45,
    "height": 1.75,
    "weight": 70.0,
    "bmi": 22.9,
    "is_female": false,
    "items": [
        {{
            "name": "Stir-fried pork with vegetables",
            "amount": 250,
            "unit": "g",
            "portion_size": "average",
            "standard_range": "150-300g for mixed meat and vegetable dishes"
        }},
        {{
            "name": "Steamed white rice",
            "amount": 150,
            "unit": "g",
            "portion_size": "average",
            "standard_range": "150-300g for rice"
        }}
    ],
    "meal_caloric_density": "average",
    "data_quality_issues": ["valid"],
    "is_meal_complete": true,
    "has_realistic_portions": true,
    "needs_nutrition_review": false,
    "should_reject_sample": false,
    "extra_info": "Well-balanced meal with appropriate portions. Mixed dish kept as single item with total weight."
}}

EXAMPLE 2 - Elderly Patient (Valid with Low Portions):
Input:
{{
    "event_description": "Mixed vegetable soup with meat 150ml, Small rice 75g",
    "event_type": "Dinner",
    "age": 82,
    "height": 1.55,
    "weight": 52.0,
    "bmi": 21.6,
    "is_female": true
}}

Output:
{{
    "event_description": "Mixed vegetable soup with meat 150ml, Small rice 75g",
    "event_type": "Dinner",
    "age": 82,
    "height": 1.55,
    "weight": 52.0,
    "bmi": 21.6,
    "is_female": true,
    "items": [
        {{
            "name": "Mixed vegetable and meat soup",
            "amount": 150,
            "unit": "ml",
            "portion_size": "low",
            "standard_range": "200-400ml for soups, 100-200ml appropriate for elderly"
        }},
        {{
            "name": "Steamed white rice",
            "amount": 75,
            "unit": "g",
            "portion_size": "low",
            "standard_range": "150-300g for rice, 50-150g appropriate for elderly"
        }}
    ],
    "meal_caloric_density": "low",
    "data_quality_issues": ["valid"],
    "is_meal_complete": true,
    "has_realistic_portions": true,
    "needs_nutrition_review": false,
    "should_reject_sample": false,
    "extra_info": "Portions appropriately sized for elderly patient. Mixed soup kept as single item. While portions are below standard ranges, they are suitable for demographics."
}}

EXAMPLE 3 - Data Entry Error:
Input:
{{
    "event_description": "Rice1500g, Mixed vegetables and pork 50g",
    "event_type": "Breakfast",
    "age": 35,
    "height": 1.65,
    "weight": 65.0,
    "bmi": 23.9,
    "is_female": true
}}

Output:
{{
    "event_description": "Rice1500g, Mixed vegetables and pork 50g",
    "event_type": "Breakfast",
    "age": 35,
    "height": 1.65,
    "weight": 65.0,
    "bmi": 23.9,
    "is_female": true,
    "items": [
        {{
            "name": "Rice",
            "amount": 1500,
            "unit": "g",
            "portion_size": "typing_error",
            "standard_range": "150-300g for rice"
        }},
        {{
            "name": "Mixed vegetables and pork stir-fry",
            "amount": 50,
            "unit": "g",
            "portion_size": "low",
            "standard_range": "150-300g for mixed dishes"
        }}
    ],
    "meal_caloric_density": "unrealistic",
    "data_quality_issues": ["formatting_error", "typing_error", "unrealistic_value"],
    "is_meal_complete": false,
    "has_realistic_portions": false,
    "needs_nutrition_review": true,
    "should_reject_sample": true,
    "extra_info": "Clear data entry error in rice portion (1500g, likely missing decimal). Mixed dish kept as single item but portion is low. Formatting error in rice entry (no space)."
}}

EXAMPLE 4 - Complex Mixed Items:
Input:
{{
    "event_description": "Stir-fried rice with shrimp, pork and vegetables 300g\\nCabbage and mushroom soup 200ml",
    "event_type": "Lunch",
    "age": 55,
    "height": 1.70,
    "weight": 68.0,
    "bmi": 23.5,
    "is_female": false
}}

Output:
{{
    "event_description": "Stir-fried rice with shrimp, pork and vegetables 300g\\nCabbage and mushroom soup 200ml",
    "event_type": "Lunch",
    "age": 55,
    "height": 1.70,
    "weight": 68.0,
    "bmi": 23.5,
    "is_female": false,
    "items": [
        {{
            "name": "Mixed stir-fried rice with shrimp, pork and vegetables",
            "amount": 300,
            "unit": "g",
            "portion_size": "average",
            "standard_range": "200-400g for mixed rice dishes"
        }},
        {{
            "name": "Cabbage and mushroom soup",
            "amount": 200,
            "unit": "ml",
            "portion_size": "average",
            "standard_range": "200-400ml for soups"
        }}
    ],
    "meal_caloric_density": "average",
    "data_quality_issues": ["valid"],
    "is_meal_complete": true,
    "has_realistic_portions": true,
    "needs_nutrition_review": false,
    "should_reject_sample": false,
    "extra_info": "Well-balanced meal with appropriate portions. Complex mixed dishes kept as single items with clear descriptions. Portions suitable for patient demographics."
}}

Remember:
1. Always keep mixed items together with their total weight
2. Consider patient demographics for portion assessment
3. Use descriptive names that include all components
4. Document any assumptions in extra_info
5. Only reject samples that are truly unusable

Your goal is to create accurate, structured data while properly handling mixed items and maintaining consistency in portion analysis.
'''

def create_user_prompt(unprocessed_entry: UnprocessedFoodDiaryEntry) -> str:
    """
    Build the user-turn prompt for one unprocessed food diary entry.

    The prompt restates the patient's age, BMI (1 decimal) and gender, then embeds
    the entry twice: as its markdown ``str()`` form and as indented JSON from
    ``model_dump_json(indent=2)``. It closes with the per-item and per-meal
    analysis instructions.
    """
    gender_label = "Female" if unprocessed_entry.is_female else "Male"
    human_readable = str(unprocessed_entry)
    as_json = unprocessed_entry.model_dump_json(indent=2)

    return (
        "Please analyze this food diary entry and convert it to a structured format with detailed nutritional analysis. Consider the patient's context:\n"
        f"- Age: {unprocessed_entry.age} years\n"
        f"- BMI: {unprocessed_entry.bmi:.1f} kg/m²\n"
        f"- Gender: {gender_label}\n"
        "\n"
        "Here is the entry in both human-readable and JSON format:\n"
        "\n"
        "HUMAN READABLE FORMAT:\n"
        f"{human_readable}\n"
        "\n"
        "JSON FORMAT:\n"
        f"{as_json}\n"
        "\n"
        "For each food item provide:\n"
        "1. Standardized name (including cooking method when known)\n"
        "2. Amount and unit\n"
        '3. Portion size assessment ("low", "average", "high", or "typing_error")\n'
        "4. Standard portion range reference\n"
        "\n"
        "Analyze the overall meal for:\n"
        "1. Completeness and balance\n"
        "2. Caloric density\n"
        "3. Data quality issues\n"
        "4. Need for nutritional review\n"
        "5. Potential rejection criteria\n"
        "\n"
        "Consider the patient's age and BMI when assessing portions. Format your response as a valid JSON object matching the output schema provided. Include only the JSON object, no explanatory text."
    )

# System prompt shared by all requests, built from the JSON schemas of the input and output models.
system_prompt = create_system_prompt(
    unprocessed_schema=UnprocessedFoodDiaryEntry.model_json_schema(),
    processed_schema=ProcessedFoodDiaryEntry.model_json_schema(),
)

This is an example showing both prompt creation and entry processing on a few hand-written test cases, using simulated LLM outputs.

In [None]:
# Sample test cases: each raw entry is validated, turned into prompts, paired with a
# hand-written (simulated) LLM output, and validated again as a processed entry.
test_cases = [
    # Case 1: Elderly patient with small portions
    {
        'event_description': 'Rice 185 g\nTomato and egg soup 200 g\nBraised eggplant 50 g\nBraised meat 25 g\nVegetable salad 90 g\nHairtail 30 g',
        'event_type': 'Lunch',
        'age': 77,
        'height': 1.5,
        'weight': 56.0,
        'bmi': 24.89,
        'is_female': True
    },
    # Case 2: Data entry error example
    {
        'event_description': 'Rice1500g\nVegetables 50g',
        'event_type': 'Breakfast',
        'age': 35,
        'height': 1.65,
        'weight': 65.0,
        'bmi': 23.9,
        'is_female': True
    },
    # Case 3: Mixed item example
    {
        'event_description': 'Vegetable and pork stir-fry 150g',
        'event_type': 'Dinner',
        'age': 45,
        'height': 1.70,
        'weight': 68.0,
        'bmi': 23.5,
        'is_female': False
    }
]

# Process each test case
for i, sample_data in enumerate(test_cases, 1):
    print(f"\n{'='*80}\nTEST CASE {i}\n{'='*80}")
    
    try:
        # 1. Create unprocessed entry
        unprocessed_entry = UnprocessedFoodDiaryEntry(**sample_data)
        print("\nUNPROCESSED ENTRY:")
        print(unprocessed_entry)
        
        # 2. Create prompts (what you'd send to LLM)
        unprocessed_schema = UnprocessedFoodDiaryEntry.model_json_schema()
        processed_schema = ProcessedFoodDiaryEntry.model_json_schema()
        
        system_prompt = create_system_prompt(unprocessed_schema, processed_schema)
        user_prompt = create_user_prompt(unprocessed_entry)
        
        print("\nPROMPTS CREATED SUCCESSFULLY")
        
        # 3. Example processed data (simulating LLM output)
        if i == 1:  # Elderly patient case
            processed_data = {
                **sample_data,
                'items': [
                    {
                        "name": "Steamed white rice",
                        "amount": 185,
                        "unit": "g",
                        "portion_size": "average",
                        "standard_range": "150-300g for rice, reduced range appropriate for elderly"
                    },
                    {
                        "name": "Tomato and egg soup",
                        "amount": 200,
                        "unit": "g",
                        "portion_size": "average",
                        "standard_range": "200-400ml for soup, 100-200ml appropriate for elderly"
                    },
                    {
                        "name": "Braised eggplant",
                        "amount": 50,
                        "unit": "g",
                        "portion_size": "low",
                        "standard_range": "100-300g for vegetables, 50-100g appropriate for elderly"
                    },
                    {
                        "name": "Braised meat",
                        "amount": 25,
                        "unit": "g",
                        "portion_size": "low",
                        "standard_range": "100-200g for meat, 25-100g appropriate for elderly"
                    },
                    {
                        "name": "Fresh vegetable salad",
                        "amount": 90,
                        "unit": "g",
                        "portion_size": "low",
                        "standard_range": "100-300g for vegetables, 50-100g appropriate for elderly"
                    },
                    {
                        "name": "Steamed hairtail fish",
                        "amount": 30,
                        "unit": "g",
                        "portion_size": "low",
                        "standard_range": "100-200g for fish, 25-100g appropriate for elderly"
                    }
                ],
                'meal_caloric_density': "low",
                'data_quality_issues': ["valid"],
                'is_meal_complete': True,
                'has_realistic_portions': True,
                'needs_nutrition_review': True,
                'should_reject_sample': False,
                'extra_info': (
                    "Multiple small portions appropriate for elderly patient (age 77). "
                    "While portions are below standard ranges, they are within acceptable ranges "
                    "for elderly individuals with potentially reduced appetite. "
                    "Meal is complete with good variety (carbs, proteins, vegetables) but "
                    "nutritional review recommended to ensure caloric needs are being met. "
                    "Consider patient's BMI (24.89) which is healthy, suggesting these portion "
                    "sizes may be appropriate for maintaining weight."
                )
            }
        elif i == 2:  # Data entry error case
            processed_data = {
                **sample_data,
                'items': [
                    {
                        "name": "Rice",
                        "amount": 1500,
                        "unit": "g",
                        "portion_size": "typing_error",
                        "standard_range": "150-300g for rice"
                    },
                    {
                        "name": "Mixed vegetables",
                        "amount": 50,
                        "unit": "g",
                        "portion_size": "low",
                        "standard_range": "100-300g for vegetables"
                    }
                ],
                'meal_caloric_density': "unrealistic",
                'data_quality_issues': ["formatting_error", "unrealistic_value"],
                'is_meal_complete': False,
                'has_realistic_portions': False,
                'needs_nutrition_review': True,
                'should_reject_sample': True,
                'extra_info': "Clear data entry error: 1500g of rice is ~7-8 meals worth. Also formatting error (no space between 'Rice' and '1500g')."
            }
        else:  # Mixed item case
            # NOTE(review): this simulated output splits the mixed dish into two items,
            # which the system prompt tells the LLM not to do ("Keep mixed items
            # together"). Kept unchanged for the demo; confirm whether this is intentional.
            processed_data = {
                **sample_data,
                'items': [
                    {
                        "name": "Stir-fried mixed vegetables",
                        "amount": 100,
                        "unit": "g",
                        "portion_size": "average",
                        "standard_range": "100-300g for vegetables"
                    },
                    {
                        "name": "Stir-fried pork",
                        "amount": 50,
                        "unit": "g",
                        "portion_size": "low",
                        "standard_range": "100-200g for meat"
                    }
                ],
                'meal_caloric_density': "low",
                'data_quality_issues': ["ambiguous_description"],
                'is_meal_complete': True,
                'has_realistic_portions': True,
                'needs_nutrition_review': False,
                'should_reject_sample': False,
                'extra_info': "Split 150g total between vegetables (100g) and meat (50g) based on typical proportions. While individual portions are low, the combination forms a reasonable meal."
            }

        # 4. Create and validate processed entry
        processed_entry = ProcessedFoodDiaryEntry(**processed_data)
        print("\nPROCESSED ENTRY:")
        print(processed_entry)
        
        # 5. Validate with context. validate_unchanged_fields reads the original entry
        # from context["original_entry"]. Passing the flat dump (as before) made the
        # unchanged-fields check a silent no-op.
        processed_with_context = ProcessedFoodDiaryEntry.model_validate(
            processed_data,
            context={"original_entry": unprocessed_entry.model_dump()}
        )
        
        print("\nValidation with context successful!")
        
    except Exception as e:
        # Demo boundary: report the failure and continue with the next case.
        print(f"\nError processing test case {i}:")
        print(f"{'='*40}")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {str(e)}")
        print(f"{'='*40}")

In [None]:
# Show the system prompt followed by the user prompt of the last demo entry, one per line.
print(system_prompt, create_user_prompt(unprocessed_entry), sep="\n")


Now let's apply this approach to the Shanghai T2DM data.

In [None]:
from scripts.data_preprocess.cgm_data_class import ChineseCGMData
import polars as pl
import os

# Load the Shanghai T2DM data from `local_base_path` by instantiating the ChineseCGMData class.
# NOTE(review): per the original author, the dtype of some columns is hard to infer
# when loading and falls back to strings (presumably inside the loader's polars
# parsing; confirm). Downstream filters on such columns may need explicit casts.
chinese_data = ChineseCGMData(local_base_path)

In [None]:
# Display the loader object (its repr) as the cell output.
chinese_data

In [None]:
# Events with their CGM data. The offsets are passed straight to ChineseCGMData
# (presumably a window from 1 h before to 3 h after each event; confirm).
data_df = chinese_data.get_all_events_cgm_data(
    before_offset="-1h",
    after_offset="3h",
)
data_df  # cell output

In [None]:
# Same query with a wider window (presumably 12 h on each side of the event; confirm).
# Only the event columns of this frame are used below to build the diary entries.
events_df = chinese_data.get_all_events_cgm_data(
    before_offset="-12h",
    after_offset="12h",
)
events_df  # cell output

In [None]:
# Patient context attached to every diary entry: keep only the demographic columns
# of the metadata and drop duplicate rows.
demographics_columns = [
    "Patient Number",
    "Gender (Female=1, Male=2)",
    "Age (years)",
    "Height (m)",
    "Weight (kg)",
    "BMI (kg/m2)",
]
patients_demographics = (
    chinese_data.df_metadata
    .select(pl.col(demographics_columns))
    .unique()
)
patients_demographics  # cell output

In [None]:
# Keep only events that actually carry a food description (drop nulls and the
# "data not available" placeholder).
labelled_df = events_df.filter(pl.col("event_description")!="data not available",pl.col("event_description").is_not_null())
# Collapse to one row per distinct (description, patient, event type, event id).
# events_df presumably repeats each event once per CGM reading in its window (confirm).
labelled_df = labelled_df.select(pl.struct([pl.col("event_description"),pl.col("Patient Number"),pl.col("event_type"),pl.col("event_id")]).alias("event_struct")).unique().unnest("event_struct")
# Attach patient demographics; a left join keeps events whose patient has no metadata row.
food_df = labelled_df.join(patients_demographics,on="Patient Number",how="left")
# Rename demographics to the field names expected by UnprocessedFoodDiaryEntry.
simpler_names = {"Gender (Female=1, Male=2)":"gender","Age (years)":"age","Height (m)":"height","Weight (kg)":"weight","BMI (kg/m2)":"bmi"}
food_df = food_df.rename(simpler_names)
# Add a boolean is_female column: True when gender == 1 (female), False otherwise
# (including 2 = male and missing values). The original `gender` column is kept.
# NOTE(review): if `gender` was loaded as a string column, comparing it to the
# integer 1 may fail or never match; confirm its dtype.
food_df = food_df.with_columns(pl.when(pl.col("gender")==1).then(True).otherwise(False).alias("is_female"))
print(food_df.columns)
# One dict per row holding exactly the UnprocessedFoodDiaryEntry fields
# (event_description, event_type, age, height, weight, bmi, is_female).
food_dicts = food_df.select(pl.struct(pl.all().exclude("Patient Number","gender","event_id")).alias("event_struct"))["event_struct"].to_list()
# Same row order as food_dicts, so zip() pairs each event_id with its entry.
events_id = food_df["event_id"].to_list()

events_dictionary = dict(zip(events_id,food_dicts))
# NOTE(review): this prints the whole mapping, which can be very long.
print(events_dictionary)
# If these lengths differ, some event_id appears more than once and later rows
# overwrote earlier ones in events_dictionary.
print(len(events_dictionary),len(food_dicts))

First, we reformat the data with the UnprocessedFoodDiaryEntry class.

In [None]:
# Validate every raw food dict, collecting failures instead of stopping at the first one.
entries = []
errors = []
for raw_entry in food_dicts:
    try:
        entries.append(UnprocessedFoodDiaryEntry(**raw_entry))
    except ValueError as e:
        errors.append(e)
        print(f"Validation error: {e}")
        print(f"Original event_description: {raw_entry['event_description']}")

print(f"Total number of successful entries: {len(entries)} over {len(food_dicts)} with {len(errors)} errors")


In [None]:
# event_id -> validated input entry; raises on the first entry that fails validation.
validated_events_dictionary: Dict[str, UnprocessedFoodDiaryEntry] = {
    event_id: UnprocessedFoodDiaryEntry.model_validate(dict_entry)
    for event_id, dict_entry in events_dictionary.items()
}


Next, we run the LLM on the validated entries.

In [None]:
# Loading functions from the LLM inference module
from scripts.inference.message_models import LLMPromptContext, LLMConfig, StructuredTool, LLMOutput
from scripts.inference.parallel_inference import ParallelAIUtilities, RequestLimits

# Rate limits for the parallel OpenAI runner; LLM outputs are stored under cache_folder.
oai_request_limits = RequestLimits(
    max_requests_per_minute=5000,
    max_tokens_per_minute=20_000_000,
)
ai_utils = ParallelAIUtilities(
    oai_request_limits=oai_request_limits,
    cache_folder=cache_folder,
)

# gpt-4o-mini with tool-style structured output matching ProcessedFoodDiaryEntry.
oai_config = LLMConfig(
    client="openai",
    model="gpt-4o-mini",
    response_format="tool",
    max_tokens=2000,
)
structured_tool = StructuredTool(
    json_schema=ProcessedFoodDiaryEntry.model_json_schema(),
    schema_name="process_food_diary_entry",
    schema_description="Process a food diary entry into a structured format",
)


In [None]:
# Build one prompt context per validated event, using the event_id as the prompt id.
prompts = [
    LLMPromptContext(
        id=event_id,
        system_string=system_prompt,
        new_message=create_user_prompt(entry),
        llm_config=oai_config,
    )
    for event_id, entry in validated_events_dictionary.items()
]

print("Here's an example input for the LLM:")
for message in prompts[0].messages:
    print(message)

In [None]:
# Let's try with just one prompt
single_prompt = prompts[0]
result = await ai_utils.run_parallel_ai_completion([single_prompt])
print(result)

In [None]:
overwrite = False # This avoids reprocessing the same events if the notebook is run multiple times


if not os.path.exists(f"{out_folder}{output_file_name}") and not overwrite:
    completion_results = await ai_utils.run_parallel_ai_completion(prompts)
else:
    completion_results = None


In [None]:
# Display the raw LLM completion results (None when a saved output file was reused).
completion_results

In [None]:
# Validate each LLM output against ProcessedFoodDiaryEntry, using the matching
# unprocessed entry as validation context, and collect entries the model flagged
# for rejection.
if completion_results is not None:
    processed_dict: Dict[str, ProcessedFoodDiaryEntry] = {}
    invalid_food_items: Dict[str, ProcessedFoodDiaryEntry] = {}
    # Errors from this step only. They are kept apart from the input-validation
    # `errors` list so the two stages are not mixed and re-running this cell does
    # not inflate the count.
    llm_errors: List = []
    for result in completion_results:
        if result.json_object is None:
            llm_errors.append(f"No JSON object found for event {result.source_id}")
            continue
        unprocessed_input = validated_events_dictionary[result.source_id]
        try:
            processed_entry = ProcessedFoodDiaryEntry.model_validate(
                result.json_object.object,
                context=unprocessed_input.model_dump(),
            )
        except ValueError as e:  # pydantic's ValidationError subclasses ValueError
            llm_errors.append(e)
            print(f"Validation error: {e}")
            print(f"Original event_description: {unprocessed_input.event_description}")
            print(f"object_pre_validation: {result.json_object.object}")
            continue
        processed_dict[result.source_id] = processed_entry
        if processed_entry.should_reject_sample:
            invalid_food_items[result.source_id] = processed_entry

    for source_id, item in invalid_food_items.items():
        print(f"Event ID: {source_id}")
        print(f"Event Description: {item.event_description}")
        print(f"Extra Info: {item.extra_info}")
        print("\n")
    print(f"Total number of successful entries: {len(processed_dict)} over {len(validated_events_dictionary)} with {len(llm_errors)} errors and {len(invalid_food_items)} invalid food items")



In [None]:
# Save the validated LLM outputs: one row per event_id with the
# ProcessedFoodDiaryEntry fields unnested into columns.
if completion_results is not None:
    if not processed_dict:
        # An empty frame would give a Null-typed column that cannot be unnested.
        print("No successfully processed entries; nothing was saved.")
    else:
        model_dump_dict: Dict[str, Dict] = {
            source_id: entry.model_dump() for source_id, entry in processed_dict.items()
        }
        input_dict = {
            "event_id": list(model_dump_dict.keys()),
            "event_analysis_struct": list(model_dump_dict.values()),
        }
        output_df = pl.DataFrame(input_dict).unnest("event_analysis_struct")
        output_df.write_parquet(f"{out_folder}{output_file_name}")


Next, we analyze the LLM output as in `main_analyses.ipynb`.

In [None]:
# Reload the saved annotations (whether or not the LLM was re-run in this session)
# and count how many entries were flagged for rejection.
loaded_df = pl.read_parquet(f"{out_folder}{output_file_name}")
loaded_df.get_column("should_reject_sample").value_counts()

In [None]:
# Caloric density of accepted (non-rejected) annotations, attached to each event.
# Only "low" and "average" densities are kept for the comparisons below.
calories_df = (
    loaded_df
    .filter(pl.col("should_reject_sample").not_())
    .select("event_id", "meal_caloric_density")
)
events_with_calories = (
    events_df
    .join(calories_df, on="event_id", how="left")
    .filter(pl.col("meal_caloric_density").is_in(["low", "average"]))
)
print(events_with_calories.shape)
print(events_with_calories["meal_caloric_density"].value_counts())
events_with_calories


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import polars as pl
from typing import Optional
from matplotlib.gridspec import GridSpec

def calculate_hourly_p_values(group1_data, group2_data, standard_times):
    """
    Run an independent two-sample t-test per one-hour bin between two groups
    of event-aligned CGM series.

    Hour bin ``h`` pools every reading whose aligned time ``t`` (minutes from
    the event) satisfies ``60*h <= t < 60*(h+1)``, across all series of a
    group. Bins run from ``int(standard_times[0] / 60)`` to
    ``int(standard_times[-1] / 60)`` inclusive.

    Args:
        group1_data: list of ``(times, values)`` pairs for group 1, times in minutes
        group2_data: list of ``(times, values)`` pairs for group 2
        standard_times: aligned time points; only the first and last entries are
            used, to define the hour range

    Returns:
        numpy array with one p-value per hour bin; ``nan`` where either group
        has no readings in that bin.
    """
    # Convert every series to arrays once, rather than once per hour bin.
    arrays_a = [(np.array(t), np.array(v)) for t, v in group1_data]
    arrays_b = [(np.array(t), np.array(v)) for t, v in group2_data]

    def _pooled(series, lo, hi):
        """Return the readings of all ``series`` with lo <= time < hi, in series order."""
        pooled = []
        for times_arr, values_arr in series:
            pooled.extend(values_arr[(times_arr >= lo) & (times_arr < hi)])
        return pooled

    first_hour = int(standard_times[0] / 60)
    last_hour = int(standard_times[-1] / 60)

    p_values = []
    for hour in np.arange(first_hour, last_hour + 1):
        lo, hi = hour * 60, (hour + 1) * 60
        sample_a = _pooled(arrays_a, lo, hi)
        sample_b = _pooled(arrays_b, lo, hi)
        if sample_a and sample_b:
            p_values.append(stats.ttest_ind(sample_a, sample_b).pvalue)
        else:
            p_values.append(np.nan)
    return np.array(p_values)



In [None]:
def plot_aligned_cgm_events_by_group(
    chinese_data: ChineseCGMData,
    events_df: pl.DataFrame,
    group1_filter: pl.Expr,
    group2_filter: Optional[pl.Expr] = None,
    group_names: tuple[str, str] = ("Group 1", "Group 2"),
    show_individual: bool = False,
    split_meals: bool = False,
    split_by_calories: bool = True,
    alpha_individual: float = 0.1,
    figsize: tuple = (15, 10),
    display: bool = True,
    variable_name: str = "Variable"
) -> plt.Figure:
    """
    Plot event-aligned CGM curves for two patient groups.

    Curves can be split by meal type (``event_type``: Breakfast / Lunch / Dinner)
    and by LLM-assigned ``meal_caloric_density`` ("low" / "average"). Below the
    CGM panels, hourly t-test significance (1 - p) is plotted between the two
    groups. In the full split view it is also plotted between low- and
    average-calorie meals within each group.

    Args:
        chinese_data: dataset wrapper. ``df_metadata`` is filtered with the group
            expressions, and its ``variable_name`` column supplies the median
            shown in the title.
        events_df: one row per event. The function reads:
            - ``Patient Number``;
            - ``CGM``, a list of readings treated as 15 minutes apart;
            - ``is_before_food_event``, a list of booleans aligned with ``CGM``,
              whose first False marks event onset;
            - ``event_type`` when ``split_meals`` is True;
            - ``meal_caloric_density`` when ``split_by_calories`` is True.
        group1_filter: polars expression selecting group-1 rows of ``df_metadata``.
        group2_filter: expression for group 2; defaults to ``~group1_filter``.
        group_names: labels for (group 1, group 2).
        show_individual: also draw every interpolated event curve.
        split_meals: one column of panels per meal type.
        split_by_calories: keep only "low" / "average" events and plot them separately.
        alpha_individual: transparency of individual curves.
        figsize: figure size; replaced by (15, 25) when both splits are enabled.
        display: call ``plt.show()`` before returning.
        variable_name: metadata column named in the title. The title assumes a
            median split (group 1 "> median", group 2 "≤ median").

    Returns:
        The matplotlib Figure.
    """
    # Resolve patient numbers for each group; group 2 defaults to the complement of group 1.
    group1_patients = chinese_data.df_metadata.filter(group1_filter)["Patient Number"].to_list()
    if group2_filter is None:
        group2_patients = chinese_data.df_metadata.filter(~group1_filter)["Patient Number"].to_list()
    else:
        group2_patients = chinese_data.df_metadata.filter(group2_filter)["Patient Number"].to_list()
    # Sets make the per-event membership test O(1). They also count each patient
    # once, even if df_metadata holds several rows for the same patient.
    group1_set = set(group1_patients)
    group2_set = set(group2_patients)

    full_split = split_meals and split_by_calories
    if full_split:
        figsize = (15, 25)  # taller canvas for the 10-row grid used below

    # Create the figure exactly once. The full split view used to call
    # plt.figure() twice, which left an orphan empty figure behind.
    fig = plt.figure(figsize=figsize)

    if split_meals:
        if split_by_calories:
            # 10 grid rows, each with one column per meal type:
            #   0 main title
            #   1 legend space
            #   2 between-group title
            #   3 CGM (low calories)
            #   4 CGM (average calories)
            #   5 between-group p-values (low)
            #   6 between-group p-values (average)
            #   7 within-group title
            #   8 within-group p-values (group 1)
            #   9 within-group p-values (group 2)
            gs = GridSpec(10, 3, figure=fig,
                          height_ratios=[0.2, 0.1, 0.2, 3, 3, 1, 1, 0.2, 1, 1])

            title_ax = fig.add_subplot(gs[0, :])
            between_groups_title_ax = fig.add_subplot(gs[2, :])
            within_groups_title_ax = fig.add_subplot(gs[7, :])
            title_ax.axis('off')
            between_groups_title_ax.axis('off')
            within_groups_title_ax.axis('off')

            cgm_axes_low = [fig.add_subplot(gs[3, i]) for i in range(3)]
            cgm_axes_avg = [fig.add_subplot(gs[4, i]) for i in range(3)]
            p_axes_low = [fig.add_subplot(gs[5, i]) for i in range(3)]
            p_axes_avg = [fig.add_subplot(gs[6, i]) for i in range(3)]
            p_axes_within_group1 = [fig.add_subplot(gs[8, i]) for i in range(3)]
            p_axes_within_group2 = [fig.add_subplot(gs[9, i]) for i in range(3)]

            cgm_axes = {'low': cgm_axes_low, 'average': cgm_axes_avg}
            p_axes = {'low': p_axes_low, 'average': p_axes_avg}
            p_axes_within = {group_names[0]: p_axes_within_group1,
                             group_names[1]: p_axes_within_group2}

            within_groups_title_ax.text(0.5, 0.5,
                "Within-Group Comparisons (Low vs Average calories within each group and meal type)",
                ha='center', va='center', fontsize=10)
        else:
            gs = GridSpec(3, 3, figure=fig, height_ratios=[0.2, 3, 1])
            title_ax = fig.add_subplot(gs[0, :])
            title_ax.axis('off')
            cgm_axes = [fig.add_subplot(gs[1, i]) for i in range(3)]
            p_axes = [fig.add_subplot(gs[2, i]) for i in range(3)]

        meal_categories = ['Breakfast', 'Lunch', 'Dinner']
    else:
        gs = GridSpec(3, 1, figure=fig, height_ratios=[0.2, 3, 1])
        title_ax = fig.add_subplot(gs[0])
        title_ax.axis('off')
        ax_cgm = fig.add_subplot(gs[1])
        ax_p = fig.add_subplot(gs[2])
        cgm_axes = [ax_cgm]
        p_axes = [ax_p]
        meal_categories = [None]

    def wrap_title(text, width=40):
        """Wrap title text to multiple lines if too long"""
        import textwrap
        return '\n'.join(textwrap.wrap(text, width=width))

    # The title describes a median split of `variable_name`; n counts distinct patients.
    cutoff_value = chinese_data.df_metadata[variable_name].median()
    n_group1 = len(group1_set)
    n_group2 = len(group2_set)
    title_text = wrap_title(
        f'CGM Patterns by {variable_name}: High {variable_name} (>{cutoff_value:.1f}, n={n_group1}) vs '
        f'Low {variable_name} (≤{cutoff_value:.1f}, n={n_group2})'
    )
    title_ax.text(0.5, 0.5, title_text, ha='center', va='center', fontsize=12)

    # Buckets: caloric density -> group -> meal category -> list of (times, values).
    caloric_densities = ['low', 'average'] if split_by_calories else [None]
    groups_data = {
        cal_dens: {
            group_name: {meal: [] for meal in meal_categories}
            for group_name in group_names
        }
        for cal_dens in caloric_densities
    }
    groups_statistics = {
        cal_dens: {
            group: {meal: {'means': [], 'medians': [], 'sds': []}
                    for meal in meal_categories}
            for group in group_names
        }
        for cal_dens in caloric_densities
    }

    def onset_relative_minutes(row):
        """Minutes from event onset for each CGM reading of ``row`` (15-minute spacing).

        Onset is the first index where ``is_before_food_event`` is False.
        """
        onset_idx = np.where(np.array(row['is_before_food_event']) == False)[0][0]
        return [(i - onset_idx) * 15 for i in range(len(row['CGM']))]

    # First pass: global time range over all events, used for the common time grid.
    global_time_min = float('inf')
    global_time_max = float('-inf')
    for row in events_df.iter_rows(named=True):
        timestamps = onset_relative_minutes(row)
        global_time_min = min(global_time_min, min(timestamps))
        global_time_max = max(global_time_max, max(timestamps))

    standard_times = np.arange(global_time_min, global_time_max + 15, 15)

    # Second pass: sort events into (caloric density, group, meal) buckets.
    for row in events_df.iter_rows(named=True):
        patient = row['Patient Number']
        current_category = row['event_type'] if split_meals else None
        caloric_density = row.get('meal_caloric_density', None)

        if split_by_calories and caloric_density not in ['low', 'average']:
            continue

        if patient in group1_set:
            group = group_names[0]
        elif patient in group2_set:
            group = group_names[1]
        else:
            continue

        timestamps = onset_relative_minutes(row)
        cal_dens = caloric_density if split_by_calories else None
        groups_data[cal_dens][group][current_category].append((timestamps, row['CGM']))

    # Interpolate every bucket onto the common grid; compute summary curves and the global y-range.
    global_cgm_min = float('inf')
    global_cgm_max = float('-inf')

    for cal_dens in caloric_densities:
        for group in group_names:
            for meal_category in meal_categories:
                all_series = groups_data[cal_dens][group][meal_category]
                if not all_series:
                    continue

                values_matrix = np.full((len(all_series), len(standard_times)), np.nan)
                for idx, (times, values) in enumerate(all_series):
                    values_matrix[idx] = np.interp(standard_times, times, values)

                mean_curve = np.nanmean(values_matrix, axis=0)
                median_curve = np.nanmedian(values_matrix, axis=0)
                std_curve = np.nanstd(values_matrix, axis=0)

                groups_statistics[cal_dens][group][meal_category]['means'] = mean_curve
                groups_statistics[cal_dens][group][meal_category]['medians'] = median_curve
                groups_statistics[cal_dens][group][meal_category]['sds'] = std_curve

                if show_individual:
                    global_cgm_min = min(global_cgm_min, np.nanmin(values_matrix))
                    global_cgm_max = max(global_cgm_max, np.nanmax(values_matrix))
                else:
                    global_cgm_min = min(global_cgm_min, np.nanmin(mean_curve - std_curve))
                    global_cgm_max = max(global_cgm_max, np.nanmax(mean_curve + std_curve))

    # Add 5% padding to the CGM range.
    cgm_range = global_cgm_max - global_cgm_min
    global_cgm_min -= cgm_range * 0.05
    global_cgm_max += cgm_range * 0.05

    colors = {group_names[0]: 'blue', group_names[1]: 'red'}

    for cal_dens in caloric_densities:
        current_cgm_axes = cgm_axes[cal_dens] if full_split else cgm_axes
        current_p_axes = p_axes[cal_dens] if full_split else p_axes

        for ax_idx, (ax_cgm, ax_p, meal_category) in enumerate(zip(current_cgm_axes, current_p_axes, meal_categories)):
            group1_data = groups_data[cal_dens][group_names[0]][meal_category]
            group2_data = groups_data[cal_dens][group_names[1]][meal_category]
            group1_count = len(group1_data)
            group2_count = len(group2_data)

            for group in group_names:
                all_series = groups_data[cal_dens][group][meal_category]
                if not all_series:
                    continue

                color = colors[group]

                if show_individual:
                    for times, values in all_series:
                        interp_values = np.interp(standard_times, times, values)
                        ax_cgm.plot(standard_times, interp_values, '-',
                                    color=color, alpha=alpha_individual)

                # Named curve_stats so it does not shadow the scipy `stats` module.
                curve_stats = groups_statistics[cal_dens][group][meal_category]
                mean_curve = curve_stats['means']
                std_curve = curve_stats['sds']

                ax_cgm.plot(standard_times, mean_curve, '-',
                            color=color, linewidth=2,
                            label=f'{group} (n={len(all_series)})')
                ax_cgm.fill_between(standard_times,
                                    mean_curve - std_curve,
                                    mean_curve + std_curve,
                                    color=color, alpha=0.2)

            # Between-group hourly significance.
            if group1_data and group2_data:
                p_values = calculate_hourly_p_values(group1_data, group2_data, standard_times)
                hours = np.arange(int(standard_times[0] / 60), int(standard_times[-1] / 60) + 1)
                significance = 1 - p_values

                ax_p.plot(hours * 60, significance, 'k-', linewidth=2, label='Significance')
                ax_p.axhline(y=0.95, color='r', linestyle='--', alpha=0.5, label='p=0.05')
                ax_p.fill_between(hours * 60, 0, significance, alpha=0.2)

                ax_p.set_ylim(0.8, 1)
                ax_p.grid(True, alpha=0.3)

            ax_cgm.axvline(x=0, color='k', linestyle='--', label='Event Onset')
            ax_cgm.set_ylim(global_cgm_min, global_cgm_max)
            ax_cgm.set_xlim(standard_times[0], standard_times[-1])

            for ax in [ax_cgm, ax_p]:
                ax.set_xlabel('Minutes from Event')
                xticks = np.arange(standard_times[0], standard_times[-1] + 1, 120)
                ax.set_xticks(xticks)
                ax.set_xticklabels([str(int(x)) for x in xticks])

            if ax_idx == 0:
                ax_cgm.set_ylabel('CGM Value')
                ax_p.set_ylabel('Statistical Significance (1-p)')

            if meal_category:
                cal_dens_text = f" ({cal_dens.capitalize()} Calories)" if cal_dens else ""
                title = wrap_title(
                    f'{meal_category} Events{cal_dens_text} High: {group1_count}, Low: {group2_count}'
                )
            else:
                title = wrap_title(
                    f'All Events High: {group1_count}, Low: {group2_count}'
                )

            ax_cgm.set_title(title)
            ax_cgm.grid(True, alpha=0.3)
            ax_p.grid(True, alpha=0.3)

    # Within-group (low vs average calories) significance, full split view only.
    if full_split:
        for group in group_names:
            for ax_idx, (ax_p, meal_category) in enumerate(zip(p_axes_within[group], meal_categories)):
                low_cal_data = groups_data['low'][group][meal_category]
                avg_cal_data = groups_data['average'][group][meal_category]

                if low_cal_data and avg_cal_data:
                    p_values = calculate_hourly_p_values(low_cal_data, avg_cal_data, standard_times)
                    hours = np.arange(int(standard_times[0] / 60), int(standard_times[-1] / 60) + 1)
                    significance = 1 - p_values

                    ax_p.plot(hours * 60, significance, 'k-', linewidth=2)
                    ax_p.axhline(y=0.95, color='r', linestyle='--', alpha=0.5)
                    ax_p.fill_between(hours * 60, 0, significance, alpha=0.2)

                    ax_p.set_ylim(0.8, 1)
                    ax_p.grid(True, alpha=0.3)

                    if ax_idx == 0:
                        ax_p.set_ylabel(f'{group}\nSignificance (1-p)')
                    ax_p.set_xlabel('Minutes from Event')

                    n_low = len(low_cal_data)
                    n_avg = len(avg_cal_data)
                    ax_p.set_title(f'{meal_category}\n(Low: {n_low}, Avg: {n_avg})')
                else:
                    print(f"Skipping within-group p-value calculation for {group}, {meal_category}")

    # Build one de-duplicated legend from the first CGM and p-value panels.
    handles, labels = (cgm_axes['low'] if full_split else cgm_axes)[0].get_legend_handles_labels()
    p_handles, p_labels = (p_axes['low'] if full_split else p_axes)[0].get_legend_handles_labels()

    unique_labels = []
    unique_handles = []
    for handle, label in zip(handles + p_handles, labels + p_labels):
        if label not in unique_labels:
            unique_labels.append(label)
            unique_handles.append(handle)

    # Place the legend between the title and the plots.
    legend_y = 0.88 if full_split else 0.85
    fig.legend(unique_handles, unique_labels,
               loc='upper center', bbox_to_anchor=(0.5, legend_y), ncol=5)

    plt.tight_layout()
    if full_split:
        plt.subplots_adjust(top=0.92, bottom=0.05, hspace=0.6)
    else:
        plt.subplots_adjust(top=0.92)

    if display:
        plt.show()
    return fig

In [None]:
# Compare event-aligned CGM curves of patients above vs at/below the median HbA1c,
# split by meal type and caloric density.
hba1c_column = "HbA1c (mmol/mol)"
fig = plot_aligned_cgm_events_by_group(
    chinese_data,
    events_with_calories,
    group1_filter=pl.col(hba1c_column) > chinese_data.df_metadata[hba1c_column].median(),
    group_names=("High HbA1c", "Low HbA1c"),
    split_meals=True,
    variable_name=hba1c_column,
)

In [None]:
def calculate_cgm_statistics(
    chinese_data: ChineseCGMData,
    group1_filter: pl.Expr,
    group2_filter: Optional[pl.Expr] = None,
) -> tuple[dict, int]:
    """
    Compare hour-of-day CGM profiles between two patient groups.

    For every patient, readings from ``get_cgm_data_at_resolution("1d")`` are
    pooled by clock hour (0-23) of ``cgm_time_stamp`` and averaged. This gives
    one 24-value curve per patient. Each hour is then compared between groups
    with an independent two-sample t-test.

    Args:
        chinese_data: dataset wrapper providing ``df_metadata`` and
            ``get_cgm_data_at_resolution``.
        group1_filter: polars expression selecting group-1 rows of ``df_metadata``.
        group2_filter: expression for group 2; defaults to ``~group1_filter``.

    Returns:
        ``(results, significant_hours)``. ``results`` has these keys:
            - ``hourly_stats``: list of 24 dicts with ``hour``, ``g1_mean``,
              ``g1_std``, ``g2_mean``, ``g2_std`` and ``p_value``;
            - ``group1_curves`` / ``group2_curves``: patient -> list of 24
              hourly means, ``nan`` where the patient has no reading;
            - ``group_sizes``.
        ``significant_hours`` counts hours with ``p_value < 0.05``.
    """
    cgm_resolution_df = chinese_data.get_cgm_data_at_resolution("1d")

    group1_patients = chinese_data.df_metadata.filter(group1_filter)["Patient Number"].to_list()
    if group2_filter is None:
        group2_patients = chinese_data.df_metadata.filter(~group1_filter)["Patient Number"].to_list()
    else:
        group2_patients = chinese_data.df_metadata.filter(group2_filter)["Patient Number"].to_list()

    def process_patient_data(patient_df):
        """Mean CGM value per clock hour (0-23) for one patient; nan for hours with no data."""
        daily_data = {hour: [] for hour in range(24)}
        for row in patient_df.iter_rows(named=True):
            for time, value in zip(row["cgm_time_stamp"], row["CGM"]):
                daily_data[time.hour].append(value)
        return [np.mean(daily_data[hour]) if daily_data[hour] else np.nan for hour in range(24)]

    # Duplicate patient numbers collapse into one dict entry.
    group1_curves = {}
    group2_curves = {}
    for patient in group1_patients:
        patient_df = cgm_resolution_df.filter(pl.col("Patient Number") == patient)
        group1_curves[patient] = process_patient_data(patient_df)
    for patient in group2_patients:
        patient_df = cgm_resolution_df.filter(pl.col("Patient Number") == patient)
        group2_curves[patient] = process_patient_data(patient_df)

    # reshape(-1, 24) keeps a 2-D (n_patients, 24) array even for an empty group.
    # Without it, np.array([]) is 1-D and the [:, hour] indexing below raises IndexError.
    group1_array = np.array(list(group1_curves.values()), dtype=float).reshape(-1, 24)
    group2_array = np.array(list(group2_curves.values()), dtype=float).reshape(-1, 24)

    hourly_stats = []
    for hour in range(24):
        g1_hour = group1_array[:, hour]
        g2_hour = group2_array[:, hour]

        # Drop patients with no readings at this hour before testing.
        g1_clean = g1_hour[~np.isnan(g1_hour)]
        g2_clean = g2_hour[~np.isnan(g2_hour)]

        if len(g1_clean) > 0 and len(g2_clean) > 0:
            _, p_val = stats.ttest_ind(g1_clean, g2_clean)
        else:
            p_val = np.nan

        hourly_stats.append({
            'hour': hour,
            'g1_mean': np.nanmean(g1_hour),
            'g1_std': np.nanstd(g1_hour),
            'g2_mean': np.nanmean(g2_hour),
            'g2_std': np.nanstd(g2_hour),
            'p_value': p_val
        })

    # nan p-values compare False and are therefore never counted as significant.
    significant_hours = sum(1 for stat in hourly_stats if stat['p_value'] < 0.05)

    return {
        'hourly_stats': hourly_stats,
        'group1_curves': group1_curves,
        'group2_curves': group2_curves,
        'group_sizes': (len(group1_curves), len(group2_curves))
    }, significant_hours

In [None]:
def get_significant_vars(
    chinese_data: ChineseCGMData,
    excluded_vars: Optional[list[str]] = None
) -> list[str]:
    """
    Rank numeric clinical variables by how strongly a median split separates
    hour-of-day CGM profiles.

    Each candidate column present in ``df_metadata`` with at least one non-null
    value is split at its median (group 1: values > median; group 2: the rest).
    The split is scored by the number of clock hours with p < 0.05 reported by
    ``calculate_cgm_statistics``.

    Args:
        chinese_data: ChineseCGMData object containing the data.
        excluded_vars: numeric variables to leave out of the analysis.

    Returns:
        Variable names ordered from most to fewest significant hours. Ties keep
        the order of the candidate list.
    """
    numeric_cols = [
        "Fasting Plasma Glucose (mg/dl)",
        "2-hour Postprandial Plasma Glucose (mg/dl)",
        "Fasting C-peptide (nmol/L)",
        "2-hour Postprandial C-peptide (nmol/L)",
        "Fasting Insulin (pmol/L)",
        "2-hour Postprandial insulin (pmol/L)",
        "HbA1c (mmol/mol)",
        "Glycated Albumin (%)",
        "Total Cholesterol (mmol/L)",
        "Triglyceride (mmol/L)",
        "High-Density Lipoprotein Cholesterol (mmol/L)",
        "Low-Density Lipoprotein Cholesterol (mmol/L)",
        "Creatinine (umol/L)",
        "Estimated Glomerular Filtration Rate (ml/min/1.73m2)",
        "Uric Acid (mmol/L)",
        "Blood Urea Nitrogen (mmol/L)",
        "Age (years)",
        "Height (m)",
        "Weight (kg)",
        "BMI (kg/m2)",
    ]

    if excluded_vars:
        numeric_cols = [col for col in numeric_cols if col not in excluded_vars]

    # Keep columns that exist and are not entirely null (a median is needed for the split).
    metadata = chinese_data.df_metadata
    valid_cols = [
        col for col in numeric_cols
        if col in metadata.columns and metadata[col].null_count() < len(metadata)
    ]

    var_significance = {}
    for var in valid_cols:
        median_val = metadata[var].median()
        _, significant_hours = calculate_cgm_statistics(
            chinese_data,
            group1_filter=pl.col(var) > median_val
        )
        var_significance[var] = significant_hours

    # sorted() is stable, so ties keep their original candidate order.
    return sorted(var_significance, key=var_significance.get, reverse=True)


In [None]:

# One aligned-CGM figure per candidate variable, most significant variable first.
ordered_vars = get_significant_vars(chinese_data)
for column in ordered_vars:
    column_median = chinese_data.df_metadata[column].median()
    plot_aligned_cgm_events_by_group(
        chinese_data,
        events_with_calories,
        group1_filter=pl.col(column) > column_median,
        group_names=(f"High {column}", f"Low {column}"),
        split_meals=True,
        variable_name=column,
        split_by_calories=True,
    )