# In [6]:
import os
from typing import Optional

import numpy as np
import pandas as pd
def load_datasets(
        data_dir: str='../data/interim',
        skip_files: tuple[str, ...]=('meal_annotation_plus_2hr_meal.csv',)
        ) -> dict[str, list[pd.DataFrame]]:
    '''
    Loads every CSV file in data_dir (default: the interim folder) and groups the dataframes by patient id.
    ---
    Parameters:
        data_dir: str
            The directory to scan. Only files whose names end in '.csv' are loaded (os.listdir is not recursive).
        skip_files: tuple[str, ...]
            Exact file names to ignore. Defaults to the meal annotation file, which is skipped for now.
    ---
    Returns:
        dataframes_by_patient: dict[str, list[pd.DataFrame]]
            A dictionary with patient ids as keys and lists of dataframes as values.
            The patient id is the second '_'-separated token of the file name
            (e.g. 'prefix_123_suffix.csv' -> '123'). Files are processed in sorted
            name order, so each patient's list order is reproducible across runs.
            Eg:
                {
                    'patientId1': [df1, df2, df3],
                    'patientId2': [df4, df5],
                    ...
                }
    ---
    Raises:
        ValueError
            If a CSV file name has no second '_'-separated token to use as the patient id.
    '''
    # os.listdir returns entries in arbitrary order; sort for deterministic output.
    csv_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.csv'))
    dataframes_by_patient: dict[str, list[pd.DataFrame]] = {}
    for file in csv_files:
        if file in skip_files:
            continue

        # Parse the id before reading so a badly named file fails fast with a clear message.
        name_parts = file.split('_')
        if len(name_parts) < 2:
            raise ValueError(
                f"Cannot extract a patient id from file name {file!r}; "
                "expected '<prefix>_<patient_id>_...csv'"
            )
        patient_id = name_parts[1]

        df = pd.read_csv(os.path.join(data_dir, file))
        dataframes_by_patient.setdefault(patient_id, []).append(df)

    return dataframes_by_patient


def create_segmentation_labels(datasets_by_patient: dict[str, list[pd.DataFrame]]) -> dict[str, list[pd.DataFrame]]:
    '''
    Creates segmentation labels for the datasets (in place).
    Specifically, makes a new column 'segmentation_label' that assigns unique integers for each meal and non-meal period.
    ---
    Implementation details:
        - Identifies meal periods based on ANNOUNCE_MEAL messages (msg_type column) and food_g values.
            - The start of the meal period is the row of the ANNOUNCE_MEAL message.
            - The end of the meal period is the row just before the first NaN in food_g that occurs
              after the ANNOUNCE_MEAL row. A meal therefore always contains at least its announcement row.
            - If there is no NaN in food_g after the ANNOUNCE_MEAL row, the meal runs to the last row.
            - If a meal runs into the next announcement, the next meal's label takes over from its start row.
        - Rows before the first meal get label 0; each later meal and gap gets a new increasing integer.
          A gap label is reserved between every pair of consecutive meals even when the gap is empty,
          so labels are unique but not necessarily contiguous.
        - A dataset with no ANNOUNCE_MEAL rows is labelled 0 throughout.
        - Labels are assigned by row position, so any index type is supported.
    ---
    Parameters:
        datasets_by_patient: dict[str, list[pd.DataFrame]]
            A dictionary with patient ids as keys and lists of dataframes as values.
            Each dataframe must have 'msg_type' and 'food_g' columns.
    ---
    Returns:
        datasets_by_patient: dict[str, list[pd.DataFrame]]
            The same dictionary; every dataframe has gained a 'segmentation_label' column.
    '''
    for datasets in datasets_by_patient.values():
        for dataset in datasets:
            dataset['segmentation_label'] = _segmentation_labels_for(
                dataset['msg_type'], dataset['food_g']
            )

    return datasets_by_patient


def _segmentation_labels_for(msg_type: pd.Series, food_g: pd.Series) -> np.ndarray:
    '''Return one integer segment label per row (by position) for a single dataset.'''
    n_rows = len(food_g)
    labels = np.zeros(n_rows, dtype=np.int64)  # 0 = non-meal period before the first meal

    meal_starts = np.flatnonzero((msg_type == 'ANNOUNCE_MEAL').to_numpy())
    if meal_starts.size == 0:
        # No meals: the whole dataset is a single non-meal period.
        return labels

    # Positions of NaN food_g values, ascending (flatnonzero output is sorted).
    nan_rows = np.flatnonzero(food_g.isna().to_numpy())

    current_label = 1
    for i, meal_start in enumerate(meal_starts):
        # First NaN strictly after the announcement row; the meal ends just before it.
        k = np.searchsorted(nan_rows, meal_start, side='right')
        meal_end = nan_rows[k] - 1 if k < nan_rows.size else n_rows - 1

        labels[meal_start:meal_end + 1] = current_label
        current_label += 1

        # Gap between this meal and the next one (empty slice if the meal overlaps it).
        if i < meal_starts.size - 1:
            labels[meal_end + 1:meal_starts[i + 1]] = current_label
            current_label += 1

    # Trailing non-meal period after the last meal, if any rows remain.
    if meal_end < n_rows - 1:
        labels[meal_end + 1:] = current_label

    return labels

def preprocess_datasets(
        datasets_by_patient: dict[str, list[pd.DataFrame]],
        drop_columns: Optional[list[str]] = None
        ) -> dict[str, list[pd.DataFrame]]:
    '''
    Preprocesses the datasets, returning new dataframes (the inputs are copied, not modified).
    Operations, in order:
    1. Replace +/-inf with NaN
    2. Fill NaN values: forward fill, then backward fill, then 0 for anything still missing
    3. Delete the msg_type column and the columns in drop_columns (absent columns are ignored)
    4. Parse the date column as UTC, drop rows with invalid dates, sort by date (stable) and set it as the index
    5. Change affects_fob and affects_iob (if still present) to 1 and 0
    ---
    Parameters:
        datasets_by_patient: dict[str, list[pd.DataFrame]]
            A dictionary with patient ids as keys and lists of dataframes as values.
            Each dataframe must have a 'date' column formatted like '%Y-%m-%d %H:%M:%S%z'.
        drop_columns: list[str] | None
            Columns to remove. When None, defaults to
            ['day_start_shift', 'food_g_keep', 'affects_fob', 'affects_iob'].
    ---
    Returns:
        processed_datasets: dict[str, list[pd.DataFrame]]
            Same keys and list order as the input, with the preprocessed dataframes.
    '''
    if drop_columns is None:
        drop_columns = ['day_start_shift', 'food_g_keep', 'affects_fob', 'affects_iob']
    # msg_type is always removed; errors='ignore' makes every drop tolerant of missing columns.
    columns_to_drop = ['msg_type', *drop_columns]

    processed_datasets = {}

    for patient_id, datasets in datasets_by_patient.items():
        processed_datasets[patient_id] = []

        for i, dataset in enumerate(datasets):
            df = dataset.copy()

            # Replace infinite values with NaN so they are filled like missing values
            df.replace([np.inf, -np.inf], np.nan, inplace=True)

            # fillna(method=...) is deprecated since pandas 2.1; ffill()/bfill() are the equivalents.
            df = df.ffill().bfill().fillna(0)

            df.drop(columns=columns_to_drop, errors='ignore', inplace=True)

            # Handle date column: unparseable dates become NaT and are dropped
            df['date'] = pd.to_datetime(df['date'], format='%Y-%m-%d %H:%M:%S%z', errors='coerce', utc=True)
            df.dropna(subset=['date'], inplace=True)
            # Stable sort keeps the original row order for identical timestamps
            df.sort_values('date', kind='stable', inplace=True)
            df.set_index('date', inplace=True)

            # Change affects_fob and affects_iob to 1 and 0. Values may arrive as the strings
            # 'true'/'false' (any case) or as parsed booleans; anything else becomes 0.
            for flag_column in ('affects_fob', 'affects_iob'):
                if flag_column in df.columns:
                    normalized = df[flag_column].astype(str).str.strip().str.lower()
                    df[flag_column] = normalized.map({'true': 1, 'false': 0}).fillna(0)

            # Ensure the index is a DatetimeIndex
            if not isinstance(df.index, pd.DatetimeIndex):
                df.index = pd.to_datetime(df.index)

            # Final NaN check
            final_nans = df.isna().sum()
            if final_nans.sum() > 0:
                print(f"\nWarning: Patient {patient_id}, Dataset {i} still has NaN values after preprocessing:")

            processed_datasets[patient_id].append(df)

    return processed_datasets
