    This notebook contains the code for the Shanghai T2DM data analysis, as described in Sections 4 and 5 of M5 (WP3) "EXplainable ML Model development" by HK3Lab:

    - Observing and explaining time-dependent relationships between clinical variables and daily CGM time series (Section 4)
    - LLM-driven food categorization for CGM event analysis (Section 5)

    Running the analyses assumes you have already installed all requirements listed in `./requirements.txt` and downloaded the data into the `./data` folder (see README.md for instructions).

### OBSERVING AND EXPLAINING TIME-DEPENDENT RELATIONSHIPS BETWEEN CLINICAL VARIABLES AND DAILY CGM TIME SERIES (Section 4)

In [None]:
from scripts.data_preprocess.cgm_data_class import ChineseCGMData
from typing import Optional
import polars as pl


In [None]:
local_base_path = "./data"

Load the data by instantiating the `ChineseCGMData` class.
The dtype of some columns is difficult to infer when the class loads the data, so expect those columns to fall back to strings.

In [None]:
chinese_data = ChineseCGMData(local_base_path)

In [None]:
chinese_data

In [None]:
# Subjects are string IDs, e.g.,
random_id = chinese_data.get_random_subject_id()
random_id

In [None]:
# Each subject has these columns
chinese_data.df_metadata.columns

In [None]:
# Subjects are stored in a dataframe
chinese_data.df_metadata

In [None]:
demographics_columns = ["Patient Number","Gender (Female=1, Male=2)","Age (years)","Height (m)","Weight (kg)","BMI (kg/m2)"]
patients_demographics = chinese_data.df_metadata.select(pl.col(demographics_columns)).unique()
patients_demographics


In [None]:
# Imports for visualization and statistical testing
from matplotlib.gridspec import GridSpec
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

The function below serves to
- visualize the distribution of numeric metadata variables from the Shanghai T2DM dataset using histograms with kernel density estimates;
- compute and display a correlation heatmap to highlight relationships between variables.
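
Before reading the full function, it may help to see the kernel density step in isolation. The sketch below uses synthetic values (not the dataset) and checks that the estimated density integrates to roughly 1, which is why the histograms are drawn with `density=True`: both curves then share the same vertical scale.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
sample = rng.normal(loc=150.0, scale=30.0, size=500)  # synthetic values

# gaussian_kde picks its bandwidth with Scott's rule by default
kde = stats.gaussian_kde(sample)

# Evaluate on a grid padded beyond the data so the tails are included
pad = 3 * sample.std()
grid = np.linspace(sample.min() - pad, sample.max() + pad, 400)
density = kde(grid)

# Riemann-sum approximation of the area under the estimated density
area = float(density.sum() * (grid[1] - grid[0]))
print(f"area under the KDE: {area:.3f}")
```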

In [None]:
def visualize_numeric_distributions(chinese_data: ChineseCGMData) -> None:
    """
    Create a subplot grid of histograms with kernel density estimates 
    for all numeric columns in the metadata.
    """
    
    # Get numeric columns
    numeric_cols = [
        "Fasting Plasma Glucose (mg/dl)",
        "2-hour Postprandial Plasma Glucose (mg/dl)",
        "Fasting C-peptide (nmol/L)",
        "2-hour Postprandial C-peptide (nmol/L)",
        "Fasting Insulin (pmol/L)",
        "2-hour Postprandial insulin (pmol/L)",
        "HbA1c (mmol/mol)",
        "Glycated Albumin (%)",
        "Total Cholesterol (mmol/L)",
        "Triglyceride (mmol/L)",
        "High-Density Lipoprotein Cholesterol (mmol/L)",
        "Low-Density Lipoprotein Cholesterol (mmol/L)",
        "Creatinine (umol/L)",
        "Estimated Glomerular Filtration Rate (ml/min/1.73m2)",
        "Uric Acid (mmol/L)",
        "Blood Urea Nitrogen (mmol/L)",
        "Age (years)",
        "Height (m)",
        "Weight (kg)",
        "BMI (kg/m2)",
    ]
    
    # Filter existing columns
    numeric_cols = [col for col in numeric_cols if col in chinese_data.df_metadata.columns]
    
    # Calculate grid dimensions
    n_cols = 3
    n_rows = (len(numeric_cols) + n_cols - 1) // n_cols
    
    # Create figure
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5*n_rows))
    fig.suptitle('Distribution of Numeric Variables in Metadata', y=1.02, fontsize=16)
    
    # Flatten axes for easier iteration
    axes_flat = axes.flatten()
    
    for idx, (ax, col) in enumerate(zip(axes_flat, numeric_cols)):
        # Get data without nulls
        data = chinese_data.df_metadata[col].drop_nulls().to_numpy()
        
        if len(data) > 0:
            # Calculate basic statistics
            mean_val = np.mean(data)
            median_val = np.median(data)
            std_val = np.std(data)
            
            # Create histogram
            counts, bins, _ = ax.hist(data, bins='auto', density=True, 
                                    alpha=0.6, color='skyblue')
            
            # Kernel density estimation
            kde_x = np.linspace(min(data), max(data), 200)
            kde = stats.gaussian_kde(data)
            ax.plot(kde_x, kde(kde_x), 'r-', lw=2, label='KDE')
            
            # Add vertical lines for mean and median
            ax.axvline(mean_val, color='red', linestyle='--', alpha=0.5, 
                      label=f'Mean: {mean_val:.1f}')
            ax.axvline(median_val, color='green', linestyle='--', alpha=0.5,
                      label=f'Median: {median_val:.1f}')
            
            # Calculate missing value percentage
            missing_pct = (chinese_data.df_metadata[col].null_count() / 
                         len(chinese_data.df_metadata)) * 100
            
            # Add title with stats
            ax.set_title(f'{col}\n'
                        f'(n={len(data)}, {missing_pct:.1f}% missing)\n'
                        f'std: {std_val:.1f}', 
                        fontsize=10)
            
            # Add legend
            ax.legend(fontsize='small')
            
            # Rotate x-axis labels if needed
            ax.tick_params(axis='x', rotation=45)
        else:
            ax.text(0.5, 0.5, f'No valid data for {col}', 
                   ha='center', va='center')
            ax.set_xticks([])
            ax.set_yticks([])
    
    # Remove empty subplots if any
    for idx in range(len(numeric_cols), len(axes_flat)):
        fig.delaxes(axes_flat[idx])
    
    # Adjust layout
    plt.tight_layout()
    
    # Print summary statistics
    print("\nSummary Statistics:")
    stats_df = (chinese_data.df_metadata
                .select(numeric_cols)
                .describe()
                .with_columns(pl.col("*").cast(pl.Float64, strict=False)))
    
    print(stats_df)
    

    
    # Calculate pairwise Pearson correlations with Polars
    corr_values = np.zeros((len(numeric_cols), len(numeric_cols)))
    for i, col1 in enumerate(numeric_cols):
        for j, col2 in enumerate(numeric_cols):
            # Calculate correlation using pl.corr correctly
            corr = (
                chinese_data.df_metadata
                .select(pl.corr(col1, col2))
                .item()
            )
            corr_values[i, j] = round(float(corr), 2) if corr is not None else np.nan
    
    # Create correlation heatmap
    plt.figure(figsize=(12, 10))
    mask = np.triu(np.ones_like(corr_values), k=1)  # Mask upper triangle
    sns.heatmap(corr_values,
                annot=True,
                cmap='RdBu_r',
                center=0,
                fmt='.2f',
                xticklabels=numeric_cols,
                yticklabels=numeric_cols,
                mask=mask)  # Apply mask to show only lower triangle
    plt.title('Correlation Matrix of Numeric Variables')
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

In [None]:
visualize_numeric_distributions(chinese_data)

In [None]:
chinese_data.df_metadata["2-hour Postprandial Plasma Glucose (mg/dl)"]

In [None]:
daily_df = chinese_data.get_cgm_data_at_resolution(resolution="3h")
daily_df


In [None]:
def calculate_significance_metrics(hourly_stats: list[dict]) -> tuple[int, float]:
    """
    Calculate number of significant hours and their ratio.
    
    Returns:
        tuple: (number of significant hours, ratio of significant hours)
    """
    significant_hours = sum(1 for stat in hourly_stats if stat['p_value'] < 0.05)
    ratio = significant_hours / 24
    return significant_hours, ratio


def visualize_cgm_patterns_by_group(
    chinese_data: ChineseCGMData,
    group1_filter: pl.Expr,
    group2_filter: Optional[pl.Expr] = None,
    show_individual: bool = False,
    group_names: tuple[str, str] = ("Group 1", "Group 2"),
    display: bool = True,
    title: str = 'Daily CGM Patterns by Group with Hourly Statistical Comparison'
) -> tuple[plt.Figure, list[plt.Axes], int]:
    """
    Visualize daily CGM patterns comparing two groups defined by metadata filters.
    Includes hourly statistical comparisons between groups.
    
    Args:
        chinese_data: ChineseCGMData object containing the data
        group1_filter: Polars expression to filter metadata for group 1
        group2_filter: Polars expression to filter metadata for group 2 (if None, uses NOT group1_filter)
        show_individual: If True, shows individual subject curves
        group_names: Names for the two groups in the legend
        display: Whether to display the plot immediately
        title: Title for the plot
        
    Returns:
        tuple: (figure, list of axes, number of significant hours)
    """
    
    # Get daily data
    cgm_resolution_df = chinese_data.get_cgm_data_at_resolution("1d")
    
    # Filter patients based on metadata
    group1_patients = chinese_data.df_metadata.filter(group1_filter)["Patient Number"].to_list()
    if group2_filter is None:
        group2_patients = chinese_data.df_metadata.filter(~group1_filter)["Patient Number"].to_list()
    else:
        group2_patients = chinese_data.df_metadata.filter(group2_filter)["Patient Number"].to_list()
    
    def process_patient_data(patient_df):
        """Process individual patient data into hourly bins"""
        daily_data = {hour: [] for hour in range(24)}
        
        for row in patient_df.iter_rows(named=True):
            for time, value in zip(row["cgm_time_stamp"], row["CGM"]):
                hour = time.hour
                daily_data[hour].append(value)
        
        # Calculate mean for each hour
        hourly_means = []
        for hour in range(24):
            if daily_data[hour]:
                hourly_means.append(np.mean(daily_data[hour]))
            else:
                hourly_means.append(np.nan)
                
        return hourly_means
    
    # Process each group
    group1_curves = {}
    group2_curves = {}
    
    for patient in group1_patients:
        patient_df = cgm_resolution_df.filter(pl.col("Patient Number") == patient)
        group1_curves[patient] = process_patient_data(patient_df)
    
    for patient in group2_patients:
        patient_df = cgm_resolution_df.filter(pl.col("Patient Number") == patient)
        group2_curves[patient] = process_patient_data(patient_df)
    
    # Convert to arrays for easier statistical analysis
    hours = np.arange(24)
    group1_array = np.array(list(group1_curves.values()))
    group2_array = np.array(list(group2_curves.values()))
    
    # Calculate hourly statistics and t-tests
    hourly_stats = []
    for hour in range(24):
        g1_hour = group1_array[:, hour]
        g2_hour = group2_array[:, hour]
        
        # Remove NaN values for t-test
        g1_clean = g1_hour[~np.isnan(g1_hour)]
        g2_clean = g2_hour[~np.isnan(g2_hour)]
        
        if len(g1_clean) > 0 and len(g2_clean) > 0:
            t_stat, p_val = stats.ttest_ind(g1_clean, g2_clean)
        else:
            t_stat, p_val = np.nan, np.nan
            
        hourly_stats.append({
            'hour': hour,
            'g1_mean': np.nanmean(g1_hour),
            'g1_std': np.nanstd(g1_hour),
            'g2_mean': np.nanmean(g2_hour),
            'g2_std': np.nanstd(g2_hour),
            'p_value': p_val
        })
    
    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 12), height_ratios=[3, 1])
    
    # Plot CGM patterns
    if show_individual:
        # Plot individual curves for both groups
        for curves, color in [(group1_curves, 'blue'), (group2_curves, 'red')]:
            for patient_data in curves.values():
                ax1.plot(hours, patient_data, alpha=0.3, color=color, linewidth=1)
        
        # Add representative lines for legend
        ax1.plot([], [], color='blue', alpha=0.8, label=f'{group_names[0]} (n={len(group1_curves)})')
        ax1.plot([], [], color='red', alpha=0.8, label=f'{group_names[1]} (n={len(group2_curves)})')
    
    # Always show means
    g1_mean = np.nanmean(group1_array, axis=0)
    g1_sem = np.nanstd(group1_array, axis=0) / np.sqrt(len(group1_curves))
    g2_mean = np.nanmean(group2_array, axis=0)
    g2_sem = np.nanstd(group2_array, axis=0) / np.sqrt(len(group2_curves))
    
    ax1.plot(hours, g1_mean, color='blue', linewidth=2, 
             label=f'{group_names[0]} Mean' if show_individual else f'{group_names[0]} (n={len(group1_curves)})')
    ax1.fill_between(hours, g1_mean - 1.96*g1_sem, g1_mean + 1.96*g1_sem, color='blue', alpha=0.2)
    
    ax1.plot(hours, g2_mean, color='red', linewidth=2,
             label=f'{group_names[1]} Mean' if show_individual else f'{group_names[1]} (n={len(group2_curves)})')
    ax1.fill_between(hours, g2_mean - 1.96*g2_sem, g2_mean + 1.96*g2_sem, color='red', alpha=0.2)
    
    # Plot p-values (1 - p_value for better visualization)
    p_values = [stat['p_value'] for stat in hourly_stats]
    significance = [1 - p for p in p_values]  # Transform to show significance
    ax2.plot(hours, significance, 'k-', linewidth=2)
    ax2.axhline(y=0.95, color='r', linestyle='--', alpha=0.5, label='p=0.05')
    ax2.fill_between(hours, 0, significance, alpha=0.2)
    
    # Set p-value plot styling
    ax2.set_xlabel('Hour of Day')
    ax2.set_ylabel('Statistical Significance (1-p)')
    ax2.set_ylim(0.8, 1)
    ax2.grid(True, alpha=0.3)
    ax2.legend()
    
    # Customize plots
    ax1.set_xlabel('Hour of Day')
    ax1.set_ylabel('CGM Value (mg/dL)')
    ax1.grid(True, alpha=0.3)
    ax1.legend()
    
    # Add meal time indicators
    meal_times = [(6, 'Breakfast'), (12, 'Lunch'), (18, 'Dinner')]
    for hour, meal in meal_times:
        ax1.axvline(x=hour, color='gray', linestyle='--', alpha=0.3)
        # Place text at bottom of plot instead of top
        ax1.text(hour, ax1.get_ylim()[0], meal, 
                rotation=90, va='top', ha='right')
    
    # Calculate significance metrics
    significant_hours, ratio = calculate_significance_metrics(hourly_stats)
    
    # Update title to include significance ratio
    plt.suptitle(f"{title}\nSignificant hours: {significant_hours}/24 ({ratio:.2%})")
    plt.tight_layout()
    
    if display:
        plt.show()
        # Print key statistics
        print("\nGroup Statistics:")
        print(f"{group_names[0]}: {len(group1_curves)} subjects")
        print(f"{group_names[1]}: {len(group2_curves)} subjects")
        
        print("\nSignificant Differences (p < 0.05):")
        for stat in hourly_stats:
            if stat['p_value'] < 0.05:
                print(f"Hour {stat['hour']:02d}:00 - "
                      f"{group_names[0]}: {stat['g1_mean']:.1f}±{stat['g1_std']:.1f}, "
                      f"{group_names[1]}: {stat['g2_mean']:.1f}±{stat['g2_std']:.1f}, "
                      f"p={stat['p_value']:.4f}")
    
    
    
    
    return fig, [ax1, ax2], significant_hours



In [None]:
def calculate_cgm_statistics(
    chinese_data: ChineseCGMData,
    group1_filter: pl.Expr,
    group2_filter: Optional[pl.Expr] = None,
) -> tuple[dict[str, list[dict]], int]:
    """
    Calculate hourly statistics between two groups.
    Returns hourly statistics and number of significant hours.
    """
    # Get daily data
    cgm_resolution_df = chinese_data.get_cgm_data_at_resolution("1d")
    
    # Filter patients based on metadata
    group1_patients = chinese_data.df_metadata.filter(group1_filter)["Patient Number"].to_list()
    if group2_filter is None:
        group2_patients = chinese_data.df_metadata.filter(~group1_filter)["Patient Number"].to_list()
    else:
        group2_patients = chinese_data.df_metadata.filter(group2_filter)["Patient Number"].to_list()
    
    def process_patient_data(patient_df):
        daily_data = {hour: [] for hour in range(24)}
        for row in patient_df.iter_rows(named=True):
            for time, value in zip(row["cgm_time_stamp"], row["CGM"]):
                hour = time.hour
                daily_data[hour].append(value)
                
        hourly_means = []
        for hour in range(24):
            if daily_data[hour]:
                hourly_means.append(np.mean(daily_data[hour]))
            else:
                hourly_means.append(np.nan)
        return hourly_means
    
    # Process each group
    group1_curves = {}
    group2_curves = {}
    
    for patient in group1_patients:
        patient_df = cgm_resolution_df.filter(pl.col("Patient Number") == patient)
        group1_curves[patient] = process_patient_data(patient_df)
    
    for patient in group2_patients:
        patient_df = cgm_resolution_df.filter(pl.col("Patient Number") == patient)
        group2_curves[patient] = process_patient_data(patient_df)
    
    # Convert to arrays for statistics
    group1_array = np.array(list(group1_curves.values()))
    group2_array = np.array(list(group2_curves.values()))
    
    # Calculate hourly statistics and t-tests
    hourly_stats = []
    for hour in range(24):
        g1_hour = group1_array[:, hour]
        g2_hour = group2_array[:, hour]
        
        # Remove NaN values for t-test
        g1_clean = g1_hour[~np.isnan(g1_hour)]
        g2_clean = g2_hour[~np.isnan(g2_hour)]
        
        if len(g1_clean) > 0 and len(g2_clean) > 0:
            t_stat, p_val = stats.ttest_ind(g1_clean, g2_clean)
        else:
            t_stat, p_val = np.nan, np.nan
            
        hourly_stats.append({
            'hour': hour,
            'g1_mean': np.nanmean(g1_hour),
            'g1_std': np.nanstd(g1_hour),
            'g2_mean': np.nanmean(g2_hour),
            'g2_std': np.nanstd(g2_hour),
            'p_value': p_val
        })
    
    significant_hours = sum(1 for stat in hourly_stats if stat['p_value'] < 0.05)
    
    return {
        'hourly_stats': hourly_stats,
        'group1_curves': group1_curves,
        'group2_curves': group2_curves,
        'group_sizes': (len(group1_curves), len(group2_curves))
    }, significant_hours

In [None]:
def visualize_cgm_patterns_by_medians(
    chinese_data: ChineseCGMData,
    show_individual: bool = False,
    excluded_vars: Optional[list[str]] = None
) -> tuple[list[str], dict[str, int]]:
    """
    Create subplots of CGM patterns for each numeric variable, splitting subjects 
    by the median value of each variable. Plots are sorted by statistical significance.
    
    Args:
        chinese_data: ChineseCGMData object containing the data
        show_individual: If True, shows individual subject curves
        excluded_vars: List of numeric variables to exclude from analysis
    
    Returns:
        tuple: (list of variables ordered by significance, dictionary of significance counts)
    """
    
    # Define numeric columns
    numeric_cols = [
        "Fasting Plasma Glucose (mg/dl)",
        "2-hour Postprandial Plasma Glucose (mg/dl)",
        "Fasting C-peptide (nmol/L)",
        "2-hour Postprandial C-peptide (nmol/L)",
        "Fasting Insulin (pmol/L)",
        "2-hour Postprandial insulin (pmol/L)",
        "HbA1c (mmol/mol)",
        "Glycated Albumin (%)",
        "Total Cholesterol (mmol/L)",
        "Triglyceride (mmol/L)",
        "High-Density Lipoprotein Cholesterol (mmol/L)",
        "Low-Density Lipoprotein Cholesterol (mmol/L)",
        "Creatinine (umol/L)",
        "Estimated Glomerular Filtration Rate (ml/min/1.73m2)",
        "Uric Acid (mmol/L)",
        "Blood Urea Nitrogen (mmol/L)",
        "Age (years)",
        "Height (m)",
        "Weight (kg)",
        "BMI (kg/m2)",
    ]
    
    # Remove excluded variables
    if excluded_vars:
        numeric_cols = [col for col in numeric_cols if col not in excluded_vars]
    
    # Filter to only existing columns with non-null values
    valid_cols = []
    for col in numeric_cols:
        if col in chinese_data.df_metadata.columns:
            if chinese_data.df_metadata[col].null_count() < len(chinese_data.df_metadata):
                valid_cols.append(col)
    
    # Calculate statistics for all variables
    var_significance = {}
    var_stats = {}
    for var in valid_cols:
        median_val = chinese_data.df_metadata[var].median()
        stats, significant_hours = calculate_cgm_statistics(
            chinese_data,
            group1_filter=pl.col(var) > median_val
        )
        var_significance[var] = significant_hours
        var_stats[var] = stats
    
    # Sort variables by significance
    sorted_vars = sorted(var_significance.items(), key=lambda x: x[1], reverse=True)
    ordered_vars = [v[0] for v in sorted_vars]
    
    # Create figure
    n_vars = len(ordered_vars)
    fig = plt.figure(figsize=(20, 6*n_vars))
    gs = GridSpec(n_vars * 2, 1, figure=fig, height_ratios=[3, 1] * n_vars)
    
    # Plot variables in order of significance
    for idx, var in enumerate(ordered_vars):
        median_val = chinese_data.df_metadata[var].median()
        stats = var_stats[var]
        significant_hours = var_significance[var]
        hours = np.arange(24)
        
        # Setup subplots
        ax1 = fig.add_subplot(gs[2*idx])
        ax2 = fig.add_subplot(gs[2*idx + 1])
        
        # Plot time series
        if show_individual:
            for curves, color in [(stats['group1_curves'], 'blue'), 
                                (stats['group2_curves'], 'red')]:
                for patient_data in curves.values():
                    ax1.plot(hours, patient_data, alpha=0.3, 
                            color=color, linewidth=1)
        
        # Calculate and plot means
        for i, (curves, color, group_name) in enumerate([
            (stats['group1_curves'], 'blue', f"High {var}"),
            (stats['group2_curves'], 'red', f"Low {var}")
        ]):
            values = np.array(list(curves.values()))
            mean = np.nanmean(values, axis=0)
            sem = np.nanstd(values, axis=0) / np.sqrt(len(curves))
            
            ax1.plot(hours, mean, color=color, linewidth=2,
                    label=f"{group_name} (n={len(curves)})")
            ax1.fill_between(hours, mean - 1.96*sem, mean + 1.96*sem,
                           color=color, alpha=0.2)
        
        # Plot significance
        p_values = [stat['p_value'] for stat in stats['hourly_stats']]
        significance = [1 - p for p in p_values]
        ax2.plot(hours, significance, 'k-', linewidth=2)
        ax2.axhline(y=0.95, color='r', linestyle='--', alpha=0.5, label='p=0.05')
        ax2.fill_between(hours, 0.8, significance, alpha=0.2)
        
        # Style time series plot
        ax1.set_xlim(0, 23)
        ax1.set_ylim(100, 210)
        ax1.set_xlabel('Hour of Day', labelpad=15)  # Added labelpad for more spacing
        ax1.set_ylabel('CGM Value (mg/dL)')
        ax1.grid(True, alpha=0.3)
        ax1.legend()
        
        # Add meal indicators
        meal_times = [(6, 'Breakfast'), (12, 'Lunch'), (18, 'Dinner')]
        for hour, meal in meal_times:
            ax1.axvline(x=hour, color='gray', linestyle='--', alpha=0.3)
            ax1.text(hour, ax1.get_ylim()[0], meal,
                    rotation=90, va='top', ha='right')
        
        # Style p-value plot
        ax2.set_xlim(0, 23)
        ax2.set_ylim(0.8, 1)
        ax2.set_xlabel('Hour of Day')
        ax2.set_ylabel('Statistical Significance (1-p)')
        ax2.grid(True, alpha=0.3)
        ax2.legend()
        
        # Set title with significance info
        ratio = significant_hours / 24
        title = f"{var} (median split: {median_val:.2f})\nSignificant hours: {significant_hours}/24 ({ratio:.1%})"
        ax1.set_title(title)
    
    plt.tight_layout()
    plt.show()
    
    # Print sorted summary
    print("\nVariables ordered by significance:")
    for var in ordered_vars:
        significant_hours = var_significance[var]
        ratio = significant_hours / 24
        print(f"{var}: {significant_hours}/24 hours significant ({ratio:.1%})")
    
    return ordered_vars, var_significance

We can now observe temporal significance patterns when comparing high versus low
clinical variable groups within the T2D population. 
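
The core of this comparison can be sketched on synthetic data. The example below is an illustration under stated assumptions: the clinical values and hourly CGM profiles are randomly generated, and a difference is injected on purpose between hours 6 and 11. It mirrors the approach used above: split subjects at the median, run one two-sample t-test per hour, and count the hours with p < 0.05.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
n_subjects = 40

# Synthetic stand-ins: one clinical value per subject and a 24-point hourly CGM profile
clinical_value = rng.normal(7.0, 1.5, size=n_subjects)
hourly_cgm = 140 + rng.normal(0, 10, size=(n_subjects, 24))

# Median split: subjects strictly above the median form the "high" group
is_high = clinical_value > np.median(clinical_value)

# Inject a morning difference (hours 6 to 11) so the test has something to find
hourly_cgm[is_high, 6:12] += 30

# One independent two-sample t-test per hour of the day
p_values = np.array([
    stats.ttest_ind(hourly_cgm[is_high, h], hourly_cgm[~is_high, h]).pvalue
    for h in range(24)
])
n_significant = int((p_values < 0.05).sum())
print(f"{n_significant}/24 hours with p < 0.05")
```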

In [None]:
ordered_vars, var_significance = visualize_cgm_patterns_by_medians(chinese_data)


Let's now see how metadata variables relate to each other by plotting a correlation heatmap in which variables are sorted by their CGM pattern significance. The diagonal displays the number of hours for which each variable shows statistically significant CGM differences.

In [None]:
def visualize_sorted_correlations(
    chinese_data: ChineseCGMData, 
    ordered_vars: list[str],  
    significance_dict: dict[str, int]  
) -> None:
    """
    Create a correlation heatmap with variables sorted by their statistical significance.
    Diagonal shows number of significant hours instead of self-correlation.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    import polars as pl
    
    n_vars = len(ordered_vars)
    
    # Create two matrices: one for correlations, one for significant hours
    corr_values = np.zeros((n_vars, n_vars))
    sig_hours = np.zeros((n_vars, n_vars))
    mask_sig = np.zeros((n_vars, n_vars), dtype=bool)
    
    # Fill matrices
    for i, col1 in enumerate(ordered_vars):
        for j, col2 in enumerate(ordered_vars):
            if i == j:
                sig_hours[i, j] = significance_dict[col1]
                mask_sig[i, j] = True
                corr_values[i, j] = 0  # This will be masked anyway
            else:
                corr = chinese_data.df_metadata.select(pl.corr(col1, col2)).item()
                corr_values[i, j] = round(float(corr), 2) if corr is not None else np.nan
    
    # Create figure with extra space for labels
    plt.figure(figsize=(15, 12))
    
    # Create simplified labels for x-axis (just variable names)
    x_labels = [var.split('(')[0].strip() for var in ordered_vars]
    # Keep full labels for y-axis
    y_labels = [f"{var} -" for var in ordered_vars]
    
    # Plot correlation heatmap
    ax = sns.heatmap(corr_values,
                     mask=mask_sig,  # Mask the diagonal
                     annot=True,
                     cmap='RdBu_r',
                     center=0,
                     vmin=-1,
                     vmax=1,
                     fmt='.2f',
                     xticklabels=x_labels,
                     yticklabels=y_labels,
                     cbar_kws={"shrink": .8, "label": "Correlation"})
    
    # Plot significant hours on diagonal
    # Use a different colormap and normalization for significance
    sig_norm = plt.Normalize(0, 24)
    sig_cmap = plt.cm.Reds
    
    for i in range(n_vars):
        color = sig_cmap(sig_norm(sig_hours[i, i]))
        ax.add_patch(plt.Rectangle((i, i), 1, 1, fill=True, color=color))
        plt.text(i + 0.5, i + 0.5, f'{int(sig_hours[i, i])}',
                ha='center', va='center', color='white' if sig_hours[i, i] > 12 else 'black')
    
    plt.title('Correlation Matrix (Variables Sorted by CGM Pattern Significance)\nDiagonal shows significant hours', pad=20)
    plt.xticks(rotation=90)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

In [None]:
visualize_sorted_correlations(chinese_data, ordered_vars, var_significance)

The function below will show how food events occur throughout the day by plotting their time distribution using both a density curve and raw counts. The visualization will highlight peak meal times and overall event frequency patterns.
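
The time handling behind this plot can be illustrated with a few hypothetical timestamps (invented here, not taken from the food diary). Each event time is converted to a fractional hour, then counted in 30-minute bins. This sketch fixes the histogram range to 0 to 24 hours so that bin edges fall on clock times; that is an assumption of the example, whereas the function below lets the bin edges follow the range of the data.

```python
from datetime import datetime

import numpy as np

# Hypothetical event timestamps
event_datetimes = [
    datetime(2021, 1, 1, 7, 15),
    datetime(2021, 1, 1, 7, 40),
    datetime(2021, 1, 2, 12, 5),
    datetime(2021, 1, 2, 18, 30),
]

# Fractional hour of day, e.g. 07:15 -> 7.25
event_hours = np.array([dt.hour + dt.minute / 60 for dt in event_datetimes])

bin_width_minutes = 30
n_bins = 24 * 60 // bin_width_minutes  # 48 half-hour bins

# Fixed range so that bin i covers [i * 0.5, (i + 1) * 0.5) hours
counts, edges = np.histogram(event_hours, bins=n_bins, range=(0, 24))
busiest_bin_start = float(edges[np.argmax(counts)])
```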

In [None]:
def plot_event_time_distribution(events_df: pl.DataFrame, 
                               bin_width_minutes: int = 30,
                               kde_bandwidth: float = 0.5,
                               figsize: tuple = (12, 6)) -> None:
    """
    Plot histogram and KDE of event times across the day with count on secondary axis.
    
    Args:
        events_df: DataFrame from get_all_events_cgm_data
        bin_width_minutes: Width of histogram bins in minutes
        kde_bandwidth: Bandwidth parameter for kernel density estimation
        figsize: Figure size as (width, height)
    """
    # Extract hours as float (e.g., 14:30 -> 14.5)
    event_times = []
    for dt in events_df['event_date']:
        hours = dt.hour
        minutes = dt.minute
        event_times.append(hours + minutes/60)
    
    total_events = len(event_times)
    
    # Create figure with two y-axes
    fig, ax1 = plt.subplots(figsize=figsize)
    ax2 = ax1.twinx()
    
    # Calculate number of bins
    bins = int(24 * 60 / bin_width_minutes)
    
    # Plot histogram with density on left axis
    counts, bins, patches = ax1.hist(event_times, bins=bins, density=True, 
                                   alpha=0.6, color='skyblue', 
                                   label='Density')
    
    # Calculate and plot raw counts on right axis
    counts_raw, _, patches_raw = ax2.hist(event_times, bins=bins, 
                                        alpha=0.0,  # Make invisible
                                        label='Count')
    
    # Calculate KDE
    kde = stats.gaussian_kde(event_times, bw_method=kde_bandwidth)
    x_range = np.linspace(0, 24, 200)
    ax1.plot(x_range, kde(x_range), 'r-', lw=2, 
            label='Kernel Density Estimate')
    
    # Customize plots
    ax1.set_xlabel('Time of Day (hours)')
    ax1.set_ylabel('Density', color='b')
    ax2.set_ylabel('Count', color='g')
    
    plt.title(f'Distribution of Food Event Times (Total Events: {total_events:,})')
    ax1.grid(True, alpha=0.3)
    
    # Set x-axis ticks to show hours
    ax1.set_xticks(np.arange(0, 25, 2))
    
    # Add hour labels
    hours_labels = [f'{int(h):02d}:00' for h in np.arange(0, 25, 2)]
    ax1.set_xticklabels(hours_labels)
    
    # Color the tick labels
    ax1.tick_params(axis='y', labelcolor='b')
    ax2.tick_params(axis='y', labelcolor='g')
    
    # Add legends for both axes
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + ['Count'], 
              loc='upper right', bbox_to_anchor=(0.98, 0.98))
    
    plt.tight_layout()
    plt.show()

In [None]:
random_id = chinese_data.get_random_subject_id()

In [None]:

data_df = chinese_data.get_all_events_cgm_data(before_offset="-1h",after_offset="3h")
data_df

In [None]:
plot_event_time_distribution(data_df)
single_subject_events_cgm_data = chinese_data.get_single_subject_events_cgm_data(random_id,before_offset="-1h",after_offset="3h")
plot_event_time_distribution(single_subject_events_cgm_data)

In [None]:
events_df = chinese_data.get_all_events_cgm_data(before_offset="-12h",after_offset="12h")
events_df

### LLM-DRIVEN FOOD CATEGORIZATION FOR CGM EVENT ANALYSIS (Section 5)

Note: here we read the results of the LLM food categorization from a parquet file, stored in `./LLM_outpts/processed_food_diary_entries.parquet`. To run the LLM food categorization yourself, use the notebook `LLM_pipeline.ipynb`.

In [None]:
# Load the files saved with the LLM categorization
loaded_df = pl.read_parquet(f"{local_base_path}/processed_food_diary_entries.parquet")
loaded_df["should_reject_sample"].value_counts()

In [None]:
calories_df = loaded_df.filter(pl.col("should_reject_sample") == False).select(pl.col("event_id"),pl.col("meal_caloric_density"))
events_with_calories = events_df.join(calories_df,on="event_id",how="left").filter(pl.col("meal_caloric_density").is_in(["low","average"]))
print(events_with_calories.shape)
print(events_with_calories["meal_caloric_density"].value_counts())
events_with_calories

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import polars as pl
from typing import Optional
from matplotlib.gridspec import GridSpec

def calculate_hourly_p_values(group1_data, group2_data, standard_times):
    """
    Calculate hourly p-values for two groups of CGM data aligned to event time.
    
    Args:
        group1_data: List of (times, values) tuples for group 1
        group2_data: List of (times, values) tuples for group 2
        standard_times: Array of standardized time points
    
    Returns:
        Array of p-values for each hour relative to the event
    """
    hours = np.arange(int(standard_times[0] / 60), int(standard_times[-1] / 60) + 1)
    p_values = []
    
    for hour in hours:
        start_time = hour * 60
        end_time = (hour + 1) * 60
        
        group1_values = []
        group2_values = []
        
        for times, values in group1_data:
            times = np.array(times)
            values = np.array(values)
            mask = (times >= start_time) & (times < end_time)
            group1_values.extend(values[mask])
        
        for times, values in group2_data:
            times = np.array(times)
            values = np.array(values)
            mask = (times >= start_time) & (times < end_time)
            group2_values.extend(values[mask])
        
        if len(group1_values) > 0 and len(group2_values) > 0:
            _, p_value = stats.ttest_ind(group1_values, group2_values)
            p_values.append(p_value)
        else:
            p_values.append(np.nan)
    
    return np.array(p_values)
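
The function above pools every CGM reading that falls in a given hour, across all events of a group, and runs one two-sample t-test per hour. The sketch below reproduces that idea on synthetic data. The helper `pooled_hour` and all values are hypothetical, and `equal_var=False` (Welch's test) is a choice made for this sketch, not the setting used in the notebook. Note also that pooling treats readings from the same event as independent samples, which is worth keeping in mind when interpreting the p-values.

```python
import numpy as np
from scipy import stats

# Synthetic events: readings every 15 minutes from -60 to +105 minutes
times = np.arange(-60, 120, 15)

# Group A rises by 2 units after the event onset, group B stays flat
group_a = [(times, 6.0 + 0.1 * k + np.where(times >= 0, 2.0, 0.0)) for k in range(3)]
group_b = [(times, 6.0 + 0.1 * k + np.zeros(len(times))) for k in range(3)]


def pooled_hour(series, start, end):
    """Collect all readings with start <= time < end across the events of one group."""
    return np.concatenate([
        np.asarray(v)[(np.asarray(t) >= start) & (np.asarray(t) < end)]
        for t, v in series
    ])


p_by_hour = {}
for hour in range(-1, 2):
    a = pooled_hour(group_a, hour * 60, (hour + 1) * 60)
    b = pooled_hour(group_b, hour * 60, (hour + 1) * 60)
    # Welch's t-test (assumption of this sketch; the notebook uses the default)
    p_by_hour[hour] = stats.ttest_ind(a, b, equal_var=False).pvalue

print(p_by_hour)
```

In this toy setup the hour before the event gives a p-value of 1 (the two groups are identical there), while the hours after the onset give very small p-values.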


We can now visualize CGM pattern comparisons between two patient groups split by clinical variables.
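
Before plotting, the function in the next cell aligns every event to its onset (the first reading where `is_before_food_event` is false), converts reading indices to minutes assuming one reading every 15 minutes, and interpolates each event onto a common time grid. A minimal sketch of these steps, using a hypothetical helper `align_to_onset` and made-up readings:

```python
import numpy as np


def align_to_onset(is_before, step_minutes=15):
    """Return minutes relative to the first reading that is not before the event."""
    is_before = np.asarray(is_before, dtype=bool)
    onset_idx = int(np.argmax(~is_before))  # index of the first False
    return (np.arange(len(is_before)) - onset_idx) * step_minutes


# Two toy events with different window lengths (values are illustrative)
flags_a, cgm_a = [True, True, False, False, False], [5.0, 5.2, 6.0, 7.5, 7.0]
flags_b, cgm_b = [True, False, False], [6.0, 6.5, 8.0]

t_a = align_to_onset(flags_a)  # -30, -15, 0, 15, 30
t_b = align_to_onset(flags_b)  # -15, 0, 15

# Common grid covering both events
grid = np.arange(min(t_a.min(), t_b.min()), max(t_a.max(), t_b.max()) + 15, 15)

# left/right=np.nan marks grid points outside an event's own window as missing
matrix = np.vstack([
    np.interp(grid, t_a, cgm_a, left=np.nan, right=np.nan),
    np.interp(grid, t_b, cgm_b, left=np.nan, right=np.nan),
])
mean_curve = np.nanmean(matrix, axis=0)
print(grid, mean_curve)
```

One design point worth noting: without the `left` and `right` arguments, `np.interp` repeats the first or last observed value outside an event's window, so shorter events would contribute flat segments to the mean rather than being ignored. Whether that matters depends on how uniform the event windows in `events_df` are.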

In [None]:
def plot_aligned_cgm_events_by_group(
    chinese_data: ChineseCGMData,
    events_df: pl.DataFrame,
    group1_filter: pl.Expr,
    group2_filter: Optional[pl.Expr] = None,
    group_names: tuple[str, str] = ("Group 1", "Group 2"),
    show_individual: bool = False,
    split_meals: bool = False,
    split_by_calories: bool = True,
    alpha_individual: float = 0.1,
    figsize: tuple = (15, 10),
    display: bool = True,
    variable_name: str = "Variable"
) -> plt.Figure:
    """
    Plot aligned CGM curves for events, split by clinical groups and caloric density.
    Includes p-value plots for both between-group and within-group comparisons.
    """
    # Get patient groups
    group1_patients = chinese_data.df_metadata.filter(group1_filter)["Patient Number"].to_list()
    if group2_filter is None:
        group2_patients = chinese_data.df_metadata.filter(~group1_filter)["Patient Number"].to_list()
    else:
        group2_patients = chinese_data.df_metadata.filter(group2_filter)["Patient Number"].to_list()
    
    # Adjust figure size for split view
    if split_meals and split_by_calories:
        figsize = (15, 25)  # Larger figure for all comparisons
    
    if split_meals:
        fig = plt.figure(figsize=figsize)
        if split_by_calories:
            # The figure created above already uses the enlarged (15, 25) size set for this layout
            
            # 10 rows: main title, legend space, between-group title, 2 CGM rows (low/average calories),
            # 2 between-group p-value rows, within-group title, 2 within-group p-value rows
            gs = GridSpec(10, 3, figure=fig, 
                        height_ratios=[0.2, 0.1, 0.2, 3, 3, 1, 1, 0.2, 1, 1])
            
            # Create title and section headers
            title_ax = fig.add_subplot(gs[0, :])
            between_groups_title_ax = fig.add_subplot(gs[2, :])  # Moved up before its plots
            within_groups_title_ax = fig.add_subplot(gs[7, :])   # Moved up before its plots
            title_ax.axis('off')
            between_groups_title_ax.axis('off')
            within_groups_title_ax.axis('off')
            
            # Create main plots
            cgm_axes_low = [fig.add_subplot(gs[3, i]) for i in range(3)]
            cgm_axes_avg = [fig.add_subplot(gs[4, i]) for i in range(3)]
            
            # Between-group p-value plots
            p_axes_low = [fig.add_subplot(gs[5, i]) for i in range(3)]
            p_axes_avg = [fig.add_subplot(gs[6, i]) for i in range(3)]
            
            # Within-group p-value plots
            p_axes_within_group1 = [fig.add_subplot(gs[8, i]) for i in range(3)]
            p_axes_within_group2 = [fig.add_subplot(gs[9, i]) for i in range(3)]
            
            cgm_axes = {'low': cgm_axes_low, 'average': cgm_axes_avg}
            p_axes = {'low': p_axes_low, 'average': p_axes_avg}
            p_axes_within = {group_names[0]: p_axes_within_group1, 
                            group_names[1]: p_axes_within_group2}

   
            within_groups_title_ax.text(0.5, 0.5, 
                "Within-Group Comparisons (Low vs Average calories within each group and meal type)", 
                ha='center', va='center', fontsize=10)
        else:
            gs = GridSpec(3, 3, figure=fig, height_ratios=[0.2, 3, 1])
            title_ax = fig.add_subplot(gs[0, :])
            title_ax.axis('off')
            cgm_axes = [fig.add_subplot(gs[1, i]) for i in range(3)]
            p_axes = [fig.add_subplot(gs[2, i]) for i in range(3)]
            
        meal_categories = ['Breakfast', 'Lunch', 'Dinner']
    else:
        fig = plt.figure(figsize=figsize)
        gs = GridSpec(3, 1, figure=fig, height_ratios=[0.2, 3, 1])
        title_ax = fig.add_subplot(gs[0])
        title_ax.axis('off')
        ax_cgm = fig.add_subplot(gs[1])
        ax_p = fig.add_subplot(gs[2])
        cgm_axes = [ax_cgm]
        p_axes = [ax_p]
        meal_categories = [None]

    def wrap_title(text, width=40):
        """Wrap title text to multiple lines if too long"""
        import textwrap
        return '\n'.join(textwrap.wrap(text, width=width))

    # Create title with cutoff value
    cutoff_value = chinese_data.df_metadata[variable_name].median()
    n_group1 = len(group1_patients)
    n_group2 = len(group2_patients)
    title_text = wrap_title(
        f'CGM Patterns by {variable_name}: High {variable_name} (>{cutoff_value:.1f}, n={n_group1}) vs '
        f'Low {variable_name} (≤{cutoff_value:.1f}, n={n_group2})'
    )
    
    title_ax.text(0.5, 0.5, title_text, ha='center', va='center', fontsize=12)


    # Prepare data structures
    caloric_densities = ['low', 'average'] if split_by_calories else [None]
    groups_data = {
        cal_dens: {
            group_name: {meal: [] for meal in meal_categories}
            for group_name in group_names
        }
        for cal_dens in caloric_densities
    }
    
    groups_statistics = {
        cal_dens: {
            group: {meal: {'means': [], 'medians': [], 'sds': []} 
                   for meal in meal_categories}
            for group in group_names
        }
        for cal_dens in caloric_densities
    }
    # Track global ranges
    global_time_min = float('inf')
    global_time_max = float('-inf')
    
    # First pass: determine global time range
    for row in events_df.iter_rows(named=True):
        cgm_values = row['CGM']
        is_before = row['is_before_food_event']
        onset_idx = np.where(np.array(is_before) == False)[0][0]
        timestamps = [(i - onset_idx) * 15 for i in range(len(cgm_values))]
        
        global_time_min = min(global_time_min, min(timestamps))
        global_time_max = max(global_time_max, max(timestamps))
    
    # Create standard time points
    standard_times = np.arange(global_time_min, global_time_max + 15, 15)
    
    # Second pass: sort events into groups
    for row in events_df.iter_rows(named=True):
        patient = row['Patient Number']
        current_category = row['event_type'] if split_meals else None
        caloric_density = row.get('meal_caloric_density', None)
        
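        if split_by_calories and caloric_density not in ['low', 'average']:
            continue
        
        # Skip event types that have no panel (anything outside meal_categories)
        if current_category not in meal_categories:
            continue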
            
        if patient in group1_patients:
            group = group_names[0]
        elif patient in group2_patients:
            group = group_names[1]
        else:
            continue
        
        cgm_values = row['CGM']
        is_before = row['is_before_food_event']
        onset_idx = np.where(np.array(is_before) == False)[0][0]
        timestamps = [(i - onset_idx) * 15 for i in range(len(cgm_values))]
        
        cal_dens = caloric_density if split_by_calories else None
        groups_data[cal_dens][group][current_category].append((timestamps, cgm_values))
    
    # Calculate statistics and determine global CGM range
    global_cgm_min = float('inf')
    global_cgm_max = float('-inf')
    
    for cal_dens in caloric_densities:
        for group in group_names:
            for meal_category in meal_categories:
                all_series = groups_data[cal_dens][group][meal_category]
                if not all_series:
                    continue
                
                values_matrix = np.full((len(all_series), len(standard_times)), np.nan)
                
                for idx, (times, values) in enumerate(all_series):
                    interp_values = np.interp(standard_times, times, values)
                    values_matrix[idx] = interp_values
                
                mean_curve = np.nanmean(values_matrix, axis=0)
                median_curve = np.nanmedian(values_matrix, axis=0)
                std_curve = np.nanstd(values_matrix, axis=0)
                
                groups_statistics[cal_dens][group][meal_category]['means'] = mean_curve
                groups_statistics[cal_dens][group][meal_category]['medians'] = median_curve
                groups_statistics[cal_dens][group][meal_category]['sds'] = std_curve
                
                if show_individual:
                    global_cgm_min = min(global_cgm_min, np.nanmin(values_matrix))
                    global_cgm_max = max(global_cgm_max, np.nanmax(values_matrix))
                else:
                    global_cgm_min = min(global_cgm_min, np.nanmin(mean_curve - std_curve))
                    global_cgm_max = max(global_cgm_max, np.nanmax(mean_curve + std_curve))
    
    # Add padding to CGM range
    cgm_range = global_cgm_max - global_cgm_min
    global_cgm_min -= cgm_range * 0.05
    global_cgm_max += cgm_range * 0.05
    
    # Colors for groups
    colors = {group_names[0]: 'blue', group_names[1]: 'red'}
    
    # Plot each caloric density category
    for cal_dens in caloric_densities:
        current_cgm_axes = cgm_axes[cal_dens] if split_by_calories and split_meals else cgm_axes
        current_p_axes = p_axes[cal_dens] if split_by_calories and split_meals else p_axes
        
        for ax_idx, (ax_cgm, ax_p, meal_category) in enumerate(zip(current_cgm_axes, current_p_axes, meal_categories)):
            group1_data = groups_data[cal_dens][group_names[0]][meal_category]
            group2_data = groups_data[cal_dens][group_names[1]][meal_category]
            
            # Get event counts for title
            group1_count = len(group1_data)
            group2_count = len(group2_data)
            total_events = group1_count + group2_count
            
            for group in group_names:
                all_series = groups_data[cal_dens][group][meal_category]
                if not all_series:
                    continue
                
                color = colors[group]
                
                if show_individual:
                    for times, values in all_series:
                        interp_values = np.interp(standard_times, times, values)
                        ax_cgm.plot(standard_times, interp_values, '-', 
                                  color=color, alpha=alpha_individual)
                
                # Named group_stats so the local variable does not shadow scipy's `stats` module
                group_stats = groups_statistics[cal_dens][group][meal_category]
                mean_curve = group_stats['means']
                std_curve = group_stats['sds']
                
                ax_cgm.plot(standard_times, mean_curve, '-', 
                           color=color, linewidth=2, 
                           label=f'{group} (n={len(all_series)})')
                ax_cgm.fill_between(standard_times, 
                                  mean_curve - std_curve, 
                                  mean_curve + std_curve,
                                  color=color, alpha=0.2)
            
            # Plot between-group p-values
            if group1_data and group2_data:
                p_values = calculate_hourly_p_values(group1_data, group2_data, standard_times)
                hours = np.arange(int(standard_times[0] / 60), int(standard_times[-1] / 60) + 1)
                
                significance = 1 - p_values
                
                ax_p.plot(hours * 60, significance, 'k-', linewidth=2, label='Significance')
                ax_p.axhline(y=0.95, color='r', linestyle='--', alpha=0.5, label='p=0.05')
                ax_p.fill_between(hours * 60, 0, significance, alpha=0.2)
                
                ax_p.set_ylim(0.8, 1)
                ax_p.grid(True, alpha=0.3)
            
            # Common plot elements
            ax_cgm.axvline(x=0, color='k', linestyle='--', label='Event Onset')
            ax_cgm.set_ylim(global_cgm_min, global_cgm_max)
            ax_cgm.set_xlim(standard_times[0], standard_times[-1])
            
            for ax in [ax_cgm, ax_p]:
                ax.set_xlabel('Minutes from Event')
                xticks = np.arange(standard_times[0], standard_times[-1] + 1, 120)
                ax.set_xticks(xticks)
                ax.set_xticklabels([str(int(x)) for x in xticks])
            
            if ax_idx == 0:
                ax_cgm.set_ylabel('CGM Value')
                ax_p.set_ylabel('Statistical Significance (1-p)')
            
            # Create title with event counts
            if meal_category:
                cal_dens_text = f" ({cal_dens.capitalize()} Calories)" if cal_dens else ""
                title = wrap_title(
                    f'{meal_category} Events{cal_dens_text} High: {group1_count}, Low: {group2_count}'
                )
            else:
                title = wrap_title(
                    f'All Events High: {group1_count}, Low: {group2_count}'
                )
            
            ax_cgm.set_title(title)
            ax_cgm.grid(True, alpha=0.3)
            ax_p.grid(True, alpha=0.3)

    # Add within-group p-value plots if using split view
    if split_by_calories and split_meals:
        for group in group_names:
            for ax_idx, (ax_p, meal_category) in enumerate(zip(p_axes_within[group], meal_categories)):
                low_cal_data = groups_data['low'][group][meal_category]
                avg_cal_data = groups_data['average'][group][meal_category]
                
                if low_cal_data and avg_cal_data:
                    p_values = calculate_hourly_p_values(low_cal_data, avg_cal_data, standard_times)
                    hours = np.arange(int(standard_times[0] / 60), int(standard_times[-1] / 60) + 1)
                    
                    significance = 1 - p_values
                    
                    ax_p.plot(hours * 60, significance, 'k-', linewidth=2)
                    ax_p.axhline(y=0.95, color='r', linestyle='--', alpha=0.5)
                    ax_p.fill_between(hours * 60, 0, significance, alpha=0.2)
                    
                    ax_p.set_ylim(0.8, 1)
                    ax_p.grid(True, alpha=0.3)
                    
                    if ax_idx == 0:
                        ax_p.set_ylabel(f'{group}\nSignificance (1-p)')
                    ax_p.set_xlabel('Minutes from Event')
                    
                    # Add title with sample sizes
                    n_low = len(low_cal_data)
                    n_avg = len(avg_cal_data)
                    ax_p.set_title(f'{meal_category}\n(Low: {n_low}, Avg: {n_avg})')
                else:
                    print(f"Skipping within-group p-value calculation for {group}, {meal_category}")

    # Create a single legend
    handles, labels = (cgm_axes['low'] if split_by_calories and split_meals else cgm_axes)[0].get_legend_handles_labels()
    p_handles, p_labels = (p_axes['low'] if split_by_calories and split_meals else p_axes)[0].get_legend_handles_labels()
    
    unique_labels = []
    unique_handles = []
    for handle, label in zip(handles + p_handles, labels + p_labels):
        if label not in unique_labels:
            unique_labels.append(label)
            unique_handles.append(handle)
    
    # Move legend between title and plots
    if split_by_calories and split_meals:
        legend_y = 0.88
    else:
        legend_y = 0.85
    
    fig.legend(unique_handles, unique_labels, 
              loc='upper center', bbox_to_anchor=(0.5, legend_y), ncol=5)

    # Adjust layout
    plt.tight_layout()
    if split_by_calories and split_meals:
        plt.subplots_adjust(top=0.92, bottom=0.05, hspace=0.6)
    else:
        plt.subplots_adjust(top=0.92)

    if display:
        plt.show()
    return fig

In [None]:
# Plot with groups based on some clinical variable
fig = plot_aligned_cgm_events_by_group(
    chinese_data,
    events_with_calories,
    group1_filter=pl.col("HbA1c (mmol/mol)") > chinese_data.df_metadata["HbA1c (mmol/mol)"].median(),
    group_names=("High HbA1c", "Low HbA1c"),
    split_meals=True,
    variable_name="HbA1c (mmol/mol)",
    
)

In [None]:
for column in ordered_vars:
    plot_aligned_cgm_events_by_group(
        chinese_data,
        events_with_calories,
        group1_filter=pl.col(column) > chinese_data.df_metadata[column].median(),
        group_names=(f"High {column}", f"Low {column}"),
        split_meals=True,
        variable_name=column,
        split_by_calories=True,
    )