# In [17]:
import pandas as pd
import os

def load_datasets(
        data_dir: str='../data/interim',
        *,
        skip_files: tuple[str, ...]=('meal_annotation_plus_2hr_meal.csv',)
        ) -> dict[str, list[pd.DataFrame]]:
    '''
    Loads all the csv files from the provided (default interim) folder and returns a dictionary with patient ids as keys and lists of dataframes as values.
    The patient id of a file is the second '_'-separated token of its file name (i.e. `file_name.split('_')[1]`).
    ---
    Parameters:
        data_dir: str
            The directory where the data is stored. Finds and loads all the csv files in the directory (non-recursive).
        skip_files: tuple[str, ...]
            Exact file names (not paths) that must not be loaded.
            Defaults to the meal annotation file, which is skipped for now.
    ---
    Returns:
        dataframes_by_patient: dict[str, list[pd.DataFrame]]
            A dictionary with patient ids as keys and lists of dataframes as values.
            Files are processed in sorted file-name order, so the order of the dataframes
            in each list is reproducible across platforms.
            Eg:
                {
                    'patientId1': [df1, df2, df3],
                    'patientId2': [df4, df5],
                    ...
                }
    ---
    Raises:
        IndexError
            If a csv file name contains no '_' and therefore no patient id can be extracted.
    '''
    excluded = set(skip_files)

    # os.listdir returns entries in arbitrary order; sort so results are deterministic
    csv_files = sorted(
        name for name in os.listdir(data_dir)
        if name.endswith('.csv') and name not in excluded
    )

    dataframes_by_patient: dict[str, list[pd.DataFrame]] = {}
    for file_name in csv_files:
        name_parts = file_name.split('_')
        if len(name_parts) < 2:
            # Fail before reading the file, and say which file is the problem
            raise IndexError(
                f"Cannot extract a patient id from {file_name!r}: "
                "expected the id as the second '_'-separated token of the file name"
            )
        patient_id = name_parts[1]

        df = pd.read_csv(os.path.join(data_dir, file_name))
        dataframes_by_patient.setdefault(patient_id, []).append(df)

    return dataframes_by_patient


def _meal_period_labels(dataset: pd.DataFrame) -> pd.Series:
    '''
    Computes the 0/1 meal labels for a single dataframe (one value per row, in row order).
    Works on row positions, so it does not depend on the dataframe's index labels.
    '''
    n_rows = len(dataset)
    food = dataset['food_g']

    is_missing = food.isna().to_numpy()
    # NaN != 0 evaluates to True, so NaN must be excluded explicitly:
    # only real, non-zero values count as food.
    has_food = (food.notna() & (food != 0)).to_numpy(dtype=bool)
    is_announcement = (dataset['msg_type'] == 'ANNOUNCE_MEAL').to_numpy(dtype=bool, na_value=False)

    labels = pd.Series(0, index=range(n_rows), dtype='int64')
    for start in is_announcement.nonzero()[0]:
        start = int(start)

        # Anchor on the first real food value at or after the announcement.
        # If there is none, fall back to the announcement row itself.
        food_tail = has_food[start:]
        anchor = start + int(food_tail.argmax()) if food_tail.any() else start

        # The period ends just before the first NaN strictly after the anchor,
        # or at the last row when no such NaN exists.
        missing_tail = is_missing[anchor + 1:]
        if missing_tail.any():
            end = anchor + 1 + int(missing_tail.argmax())  # exclusive; the NaN row is not labeled
        else:
            end = n_rows

        labels.iloc[start:end] = 1

    return labels


def create_segmentation_labels(datasets_by_patient: dict[str, list[pd.DataFrame]]) -> dict[str, list[pd.DataFrame]]:
    '''
    Creates segmentation labels for the datasets.
    Specifically, makes a new column 'segmentation_label' that is 1 for meal periods and 0 for non-meal periods.
    The dataframes are modified in place (an existing 'segmentation_label' column is overwritten).
    ---
    Implementation details:
        - Identifies meal periods based on ANNOUNCE_MEAL messages (column 'msg_type') and 'food_g' values.
            - The start of the meal period is the row of the ANNOUNCE_MEAL message.
            - The first row at or after the announcement with a real (non-NaN), non-zero food_g value is used as anchor.
            - The end of the meal period is the last row before the first NaN in food_g after the anchor.
            - If there is no NaN (in the food_g column) after the anchor, the end of the meal period is the last row of the dataframe.
            - If no real, non-zero food_g value follows the announcement, the announcement row itself is the anchor,
              so the announcement row is always labeled.
        - NaN is never treated as a food value. A NaN food_g on the announcement row therefore no longer
          labels everything up to the end of the dataframe.
        - Row positions are used instead of index labels, so any index works.
    ---
    Parameters:
        datasets_by_patient: dict[str, list[pd.DataFrame]]
            A dictionary with patient ids as keys and lists of dataframes as values.
            Every dataframe must have the columns 'msg_type' and 'food_g'.
    ---
    Returns:
        datasets_by_patient: dict[str, list[pd.DataFrame]]
            The same dictionary (same dataframe objects), each dataframe now with a 'segmentation_label' column.
    ---
    Raises:
        KeyError
            If a dataframe lacks the 'msg_type' or 'food_g' column.
    '''
    for datasets in datasets_by_patient.values():
        for dataset in datasets:
            # Assign by position (plain array) so the dataframe's index is irrelevant
            dataset['segmentation_label'] = _meal_period_labels(dataset).to_numpy()

    return datasets_by_patient
