In [None]:
import seaborn as sns
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

In [None]:
# Input locations, passed as the first three positional arguments to
# assemble_variable_database in the next cell.
# NOTE(review): these are machine-specific absolute paths — adjust before running elsewhere.
ehr_data_path = "/Users/jk1/stroke_datasets/stroke_unit_dataset/per_value/Extraction_20220815"
stroke_registry_data_path = "/Users/jk1/Library/CloudStorage/OneDrive-unige.ch/stroke_research/geneva_stroke_unit_dataset/data/stroke_registry/post_hoc_modified/stroke_registry_post_hoc_modified.xlsx"
patient_selection_path = "/Users/jk1/temp/opsum_end/gsu_extraction_09052025_204357/high_frequency_data_patient_selection_with_details.csv"

In [None]:
from preprocessing.geneva_stroke_unit_preprocessing.variable_assembly.relative_timestamps import \
    transform_to_relative_timestamps
from preprocessing.geneva_stroke_unit_preprocessing.variable_assembly.variable_database_assembly import \
    assemble_variable_database

# Upper bound of the observation window handed to transform_to_relative_timestamps
# (presumably hours — confirm against that function).
desired_time_range = 72

# Assemble the variable table without imaging data and without logging.
feature_df = assemble_variable_database(
    ehr_data_path,
    stroke_registry_data_path,
    patient_selection_path,
    imaging_data_path='',
    restrict_to_patients_with_imaging_data_available=False,
    log_dir='',
    verbose=False,
)

# Convert to relative timestamps, keeping the original columns and restricting
# to the desired time range.
restricted_feature_df = transform_to_relative_timestamps(
    feature_df,
    drop_old_columns=False,
    restrict_to_time_range=True,
    desired_time_range=desired_time_range,
    enforce_min_time_range=True,
    min_time_range=12,
    log_dir='',
)

In [None]:
# Preview the first rows of the time-restricted feature table.
restricted_feature_df.head()

In [None]:
# Keep only rows labelled 'NIHSS' whose source is 'EHR'.
nihss_df = restricted_feature_df.loc[
    restricted_feature_df['sample_label'].eq('NIHSS')
    & restricted_feature_df['source'].eq('EHR')
]

In [None]:
def detect_end_events(temp, require_min_repeats=False, min_delta=4):
    """
    Flag the first NIHSS deterioration event in one patient's measurements.

    An event is a measurement whose value exceeds the running minimum of all
    measurements up to and including it by at least ``min_delta``. Only the
    first such measurement keeps ``end == True``.

    Note: a later cell in this notebook redefines ``detect_end_events`` with an
    additional ``keep_multiple_events`` parameter and shadows this version.

    Parameters:
    - temp: DataFrame with a ``sample_date`` column (strings in
      ``'%d.%m.%Y %H:%M'`` format) and a ``value`` column castable to float.
      The input is not modified.
    - require_min_repeats: If True, the running minimum only considers values
      that are equal to the immediately preceding measurement.
    - min_delta: Minimum increase over the running minimum that counts as an event.

    Returns:
    - A new DataFrame sorted by ``sample_date`` with parsed dates, float values
      and a boolean ``end`` column.
    """
    # Work on a copy: the caller usually passes a filtered slice or a groupby
    # chunk, and writing columns into it would mutate (or warn about) the parent frame.
    temp = temp.copy()
    temp['sample_date'] = pd.to_datetime(temp['sample_date'], format='%d.%m.%Y %H:%M')
    temp['value'] = temp['value'].astype(float)
    temp = temp.sort_values('sample_date')

    if require_min_repeats:
        # Only values confirmed by an identical preceding measurement may lower the baseline.
        repeated = temp['value'].shift(1) == temp['value']
        baseline = temp['value'].where(repeated).expanding().min()
    else:
        baseline = temp['value'].expanding().min()

    # NaN baselines (no confirmed value yet) compare as False, i.e. no event.
    is_event = (temp['value'] - baseline) >= min_delta

    # Only retain the first event: it is the one where the running event count hits 1.
    temp['end'] = is_event & (is_event.cumsum() == 1)
    return temp

In [None]:
def detect_end_events(temp, require_min_repeats=False, min_delta=4, keep_multiple_events=True):
    """
    Detect neurological deterioration events based on NIHSS increases.

    Parameters:
    - temp: DataFrame with NIHSS measurements. Needs a ``sample_date`` column
      (strings in ``'%d.%m.%Y %H:%M'`` format) and a ``value`` column castable
      to float. The input is not modified.
    - require_min_repeats: If True, only consider NIHSS values confirmed by at least 2 consecutive measurements
    - min_delta: Minimum increase in NIHSS score to be considered deterioration
    - keep_multiple_events: If True, detect multiple deterioration events by resetting baseline after each event.
      If False, use the running minimum over all measurements and keep only the first event.

    Returns:
    - A new DataFrame sorted by ``sample_date`` with three added columns:
      ``min_nihss`` (baseline in effect at each row), ``delta_to_min``
      (value minus baseline) and ``end`` (bool, True on event rows).
      No other helper columns are added.
    """
    temp = temp.copy()
    temp['sample_date'] = pd.to_datetime(temp['sample_date'], format='%d.%m.%Y %H:%M')
    temp['value'] = temp['value'].astype(float)
    temp = temp.sort_values('sample_date')

    # True where a measurement equals the one immediately before it.
    same_as_previous = temp['value'].shift(1) == temp['value']

    if not keep_multiple_events:
        # Single-event mode: expanding minimum, first event only.
        if require_min_repeats:
            # Only values that repeat the previous measurement may lower the baseline.
            baseline_candidates = temp['value'].where(same_as_previous)
        else:
            baseline_candidates = temp['value']

        temp['min_nihss'] = baseline_candidates.expanding().min()
        temp['delta_to_min'] = temp['value'] - temp['min_nihss']
        is_event = temp['delta_to_min'] >= min_delta
        # Only retain the first event (the row where the running event count reaches 1).
        temp['end'] = is_event & (is_event.cumsum() == 1)
        return temp

    # Multi-event mode: walk the measurements in time order and reset the
    # baseline to the event value after every detected event.
    if require_min_repeats:
        # A score counts as confirmed when it equals its predecessor or its successor.
        confirmed = (same_as_previous | same_as_previous.shift(-1, fill_value=False)).to_numpy()
    else:
        confirmed = np.ones(len(temp), dtype=bool)

    scores = temp['value'].to_numpy()
    n_rows = len(scores)
    # Results are collected positionally so duplicate index labels cannot
    # cause one assignment to hit several rows.
    min_nihss = np.full(n_rows, np.nan)
    delta_to_min = np.full(n_rows, np.nan)
    is_event = np.zeros(n_rows, dtype=bool)

    current_min = np.inf
    for pos in range(n_rows):
        score = scores[pos]
        if np.isnan(score):
            continue

        if not confirmed[pos]:
            # Unconfirmed measurement: record the baseline in effect (if any)
            # but neither update it nor test for an event.
            if np.isfinite(current_min):
                min_nihss[pos] = current_min
            continue

        # Update minimum if this is a new minimum since the last event.
        current_min = min(current_min, score)
        min_nihss[pos] = current_min
        delta_to_min[pos] = score - current_min

        if delta_to_min[pos] >= min_delta:
            is_event[pos] = True
            # Reset the baseline to this new value for subsequent measurements.
            current_min = score

    temp['min_nihss'] = min_nihss
    temp['delta_to_min'] = delta_to_min
    temp['end'] = is_event
    return temp

In [None]:
# (rows, columns) of the NIHSS subset.
nihss_df.shape

In [None]:
# Spot-check the detector on a single admission (non-strict, multi-event, min_delta of 4).
temp = nihss_df[nihss_df['case_admission_id'] == '1011794_0030']
temp = detect_end_events(temp, require_min_repeats=False, min_delta=4, keep_multiple_events=True)

In [None]:
# Detection settings passed to detect_end_events in the cells that follow.
min_delta = 4  # minimum NIHSS increase over the baseline that counts as an event
keep_multiple_events = True  # keep every event per admission instead of only the first

In [None]:
# Strict variant: only NIHSS values confirmed by a repeated measurement are considered.
end_strict_df = (
    nihss_df
    .groupby('case_admission_id')
    .apply(
        detect_end_events,
        require_min_repeats=True,
        min_delta=min_delta,
        keep_multiple_events=keep_multiple_events,
    )
    .reset_index(drop=True)
)

# number of admissions with at least one detected event
end_strict_df.loc[end_strict_df['end'], 'case_admission_id'].nunique()

In [None]:
require_min_repeats = False

# Non-strict variant: every measurement is considered, confirmed or not.
non_strict_end_nihss_df = (
    nihss_df
    .groupby('case_admission_id')
    .apply(
        detect_end_events,
        require_min_repeats=require_min_repeats,
        min_delta=min_delta,
        keep_multiple_events=keep_multiple_events,
    )
    .reset_index(drop=True)  # undo groupby
)

In [None]:
# number of admissions with at least one detected event
non_strict_end_nihss_df.loc[non_strict_end_nihss_df['end'], 'case_admission_id'].nunique()

In [None]:
# total number of rows flagged as an event
non_strict_end_nihss_df['end'].sum()

In [None]:
# number of end events without duplicates (distinct sample_date values among event rows)
non_strict_end_nihss_df.loc[non_strict_end_nihss_df['end'], 'sample_date'].nunique()


In [None]:
# non_strict_end_nihss_df = non_strict_end_nihss_df[non_strict_end_nihss_df.end]
# Bin every measurement by the floor of its relative sample date.
non_strict_end_nihss_df = non_strict_end_nihss_df.assign(
    relative_sample_date_hourly_cat=lambda frame: np.floor(frame['relative_sample_date'])
)

In [None]:
# all rows flagged as an event
non_strict_end_nihss_df.loc[non_strict_end_nihss_df['end']]

In [None]:
# histogram of relative sample date hourly cat, stacked by event flag
sns.histplot(
    data=non_strict_end_nihss_df,
    x='relative_sample_date_hourly_cat',
    hue='end',
    multiple='stack',
    bins=desired_time_range,
)
plt.show()

In [None]:
def plot_end(nihss_df, cid, x_max=73):
    """
    Scatter-plot one admission's NIHSS values over time and mark its events.

    Parameters:
    - nihss_df: DataFrame with ``case_admission_id``, ``relative_sample_date``,
      ``value`` (castable to float) and a boolean ``end`` column. Not modified.
    - cid: value of ``case_admission_id`` to plot; also used as the figure title.
    - x_max: upper x-axis limit (defaults to the previously hard-coded 73).

    Returns:
    - The matplotlib Figure containing the plot. The caller is responsible for
      showing, saving and closing it.
    """
    # Copy the filtered rows so the dtype cast below never writes into a view
    # of the caller's frame (avoids SettingWithCopyWarning).
    patient_df = nihss_df[nihss_df['case_admission_id'] == cid].copy()
    patient_df['value'] = patient_df['value'].astype(float)

    # Draw on an explicit, fresh axes instead of whatever pyplot considers current.
    fig, ax = plt.subplots()
    sns.scatterplot(x='relative_sample_date', y='value', data=patient_df,
                    hue='value', legend=False, ax=ax)
    # tilt x label
    ax.tick_params(axis='x', labelrotation=45)

    # plot red bar on end events
    for event_time in patient_df.loc[patient_df['end'], 'relative_sample_date']:
        ax.axvline(event_time, color='red', linestyle='--', alpha=0.5)

    ax.set_xlim(0, x_max)
    fig.suptitle(f'{cid}')
    return fig
    

In [None]:
# NOTE(review): self-assignment has no effect; this cell can likely be removed.
nihss_df = nihss_df 

In [None]:
# Plot one fixed admission. To inspect a random admission instead, uncomment:
# cid = nihss_df['case_admission_id'].sample(1).values[0]
cid = '10338_5096'
fig = plot_end(non_strict_end_nihss_df, cid)
plt.show()

In [None]:
# create pdf with all plots: one page per admission that has at least one event
from matplotlib.backends.backend_pdf import PdfPages

pdf_path = '/Users/jk1/temp/nihss_non_strict_multi_end_event_plots.pdf'
cids_with_event = non_strict_end_nihss_df.loc[non_strict_end_nihss_df['end'], 'case_admission_id'].unique()

with PdfPages(pdf_path) as pdf:
    for cid in cids_with_event:
        fig = plot_end(non_strict_end_nihss_df, cid)
        pdf.savefig(fig)
        # release the figure so memory does not grow with the number of admissions
        plt.close(fig)

In [None]:
# raw NIHSS rows of one admission, for manual inspection
nihss_df.loc[nihss_df['case_admission_id'].eq('400976_8257')]