# Kaplan-Meier Survival Curve Widget

In [1]:
# Load libraries
import pandas as pd
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter
import ipywidgets as widgets
from IPython.display import display, clear_output
import numpy as np
import os

In [2]:
# Load data
def _resolve_data_path(filename):
    """Return the path of *filename* under ``data/`` or ``../data/``.

    The first candidate that exists wins. When neither exists the ``data/``
    candidate is returned, so the subsequent read raises ``FileNotFoundError``
    naming a concrete path.
    """
    candidates = [os.path.join(folder, filename) for folder in ('data', '../data')]
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return candidates[0]


def load_survival_data():
    """Build the one-row-per-item survival table used by the widget.

    Steps:
      1. Read the item status history and keep rows with ItemType 'Planting'.
      2. Keep only items that have at least one record whose location code
         does not start with '8' (the code treats '8…' as nursery locations).
      3. Collapse each item to StartDate (earliest status date), EndDate
         (latest status end date), final status / accession / provenance /
         location, Duration in years and a 0/1 Event flag.
      4. Add ProvenanceGroup, LifeForm / LifeFormGroup and Component columns.

    All three CSV files are looked up in ``data/`` first and ``../data/``
    second, so the function works from either working directory.

    Returns:
        pandas.DataFrame: one row per ItemAccNoFull with a positive Duration.

    Raises:
        FileNotFoundError: if a required CSV is in neither directory.
    """
    items = pd.read_csv(
        _resolve_data_path('accession-items-history_1970to2025.csv'),
        low_memory=False,
    )
    items = items[items['ItemType'] == 'Planting'].copy()

    for col in ('ItemStatusDate', 'ItemStatusDateTo'):
        items[col] = pd.to_datetime(items[col], errors='coerce')
    # Year 9999 is the "still open" sentinel. Detect it by year rather than by
    # building pd.Timestamp('9999-12-31'), which is out of bounds (and raises)
    # on pandas versions limited to nanosecond resolution.
    open_ended = items['ItemStatusDateTo'].dt.year == 9999
    items.loc[open_ended, 'ItemStatusDateTo'] = pd.NaT

    # Chronological order within each item makes the 'last' aggregations below
    # pick the most recent record.
    items.sort_values(by=['ItemAccNoFull', 'ItemStatusDate'], inplace=True)
    outside_nursery = ~items['ItemLocationCode'].astype(str).str.startswith('8')
    planted_ids = items.loc[outside_nursery, 'ItemAccNoFull'].unique()
    items = items[items['ItemAccNoFull'].isin(planted_ids)]

    per_item = items.groupby('ItemAccNoFull')
    # NOTE(review): StartDate is the earliest record of the item, including
    # any '8…' (nursery) records — confirm that is the intended time origin.
    initial = per_item['ItemStatusDate'].min().reset_index(name='StartDate')
    # NOTE(review): open-ended records carry no end date, so an item with only
    # open-ended records gets no Duration and is dropped below, and other live
    # items end at their last closed record — confirm this is intended rather
    # than censoring at the data extract date.
    final = per_item.agg(
        EndDate=('ItemStatusDateTo', 'max'),
        FinalStatus=('ItemStatus', 'last'),
        AccNoFull=('AccNoFull', 'last'),
        ProvenanceCode=('ProvenanceCode', 'last'),
        ItemLocationCode=('ItemLocationCode', 'last'),
    ).reset_index()

    survival_df = initial.merge(final, on='ItemAccNoFull')
    survival_df['Duration'] = (
        survival_df['EndDate'] - survival_df['StartDate']
    ).dt.days / 365.25
    terminal = ['Dead', 'Dead - Natural cause', 'Removed', 'Removed/Discarded', 'Not found', 'Stolen']
    survival_df['Event'] = survival_df['FinalStatus'].isin(terminal).astype(int)
    # A missing Duration compares False, so this also drops items with no dates.
    survival_df = survival_df[survival_df['Duration'] > 0]

    provenance_groups = {'W': 'W+Z', 'Z': 'W+Z', 'G': 'G'}
    survival_df['ProvenanceGroup'] = (
        survival_df['ProvenanceCode'].map(provenance_groups).fillna('Unknown')
    )

    # De-duplicate the lookup keys so the left merges can never multiply item
    # rows (which would double-count individuals in the survival estimate).
    life_forms = pd.read_csv(
        _resolve_data_path('Accession_1970to2025.csv'),
        usecols=['AccNoFull', 'LifeForm'],
    ).drop_duplicates(subset='AccNoFull')
    survival_df = survival_df.merge(life_forms, on='AccNoFull', how='left')

    life_form_groups = {
        'Tree': 'Woody',
        'Shrub': 'Woody',
        'Climber_Liana_Vine': 'Woody',
        'Herbaceous Perennial': 'Herbaceous',
        'Annual': 'Herbaceous',
    }
    survival_df['LifeFormGroup'] = (
        survival_df['LifeForm'].map(life_form_groups).fillna('Unknown')
    )

    bed_map = (
        pd.read_csv(_resolve_data_path('unique_beds.csv'))
        .rename(columns={"UniqueBed": "ItemLocationCode"})
        .drop_duplicates(subset='ItemLocationCode')
    )
    survival_df = survival_df.merge(
        bed_map[['ItemLocationCode', 'Component']],
        on='ItemLocationCode',
        how='left',
    )

    return survival_df

# Build the survival table once when this cell runs; draw_km_curve reads it as
# a module-level global on every redraw.
df_all = load_survival_data()

In [3]:
# Widget

# Component definitions
primary_components = ['Alpine', 'Asian', 'North America', 'Main Garden']
secondary_components = ['Food', 'Winter', 'Contemporary', 'Physic', 'Front Entrance']
all_components = [*primary_components, *secondary_components]

# Widget build: one unchecked checkbox per component, keyed by component name
component_checkboxes = {
    name: widgets.Checkbox(description=name, value=False)
    for name in all_components
}
select_all = widgets.Checkbox(description='(All)', value=False, indent=False)

# Options controlling how the selected data is split into curves
segment_lifeform = widgets.Checkbox(description='Segment by LifeForm', value=False)
segment_provenance = widgets.Checkbox(description='Segment by Provenance', value=False)
include_unknowns = widgets.Checkbox(description='Include Unknowns', value=False)

# Output areas: the plot (or a message) and the statistics table
output = widgets.Output()
stats_output = widgets.Output()

# Operational logic 
def update_all_checkbox(change):
    """Copy the "(All)" checkbox state onto every component checkbox, then redraw once.

    Args:
        change: change notification supplied by ``observe``; not used.
    """
    boxes = list(component_checkboxes.values())

    # Detach the redraw observer so the bulk assignment below does not fire
    # one redraw per checkbox.
    for box in boxes:
        box.unobserve(draw_km_curve, names='value')

    for box in boxes:
        box.value = select_all.value

    # Re-attach the observer, then draw a single time for the whole change.
    for box in boxes:
        box.observe(draw_km_curve, names='value')

    draw_km_curve()

# Kaplan-Meier statistics
def compute_km_stats(df):
    """Summarise each 'Label' group of *df* with Kaplan-Meier statistics.

    Args:
        df: DataFrame with 'Label', 'Duration' and 'Event' columns.

    Returns:
        pandas.DataFrame with one row per label: group name, row count,
        median survival time (rounded to 2 decimals; a falsy median is
        reported as 0) and the percentage of rows whose Event is 0.
    """
    rows = []
    for group_label, members in df.groupby('Label'):
        events = members['Event']
        fitter = KaplanMeierFitter()
        fitter.fit(members['Duration'], events)
        median_years = fitter.median_survival_time_ or 0
        pct_censored = 100 * (1 - events.mean())
        rows.append({
            'Group': group_label,
            'Count': len(members),
            'Median Survival (yrs)': round(median_years, 2),
            'Censoring Rate (%)': round(pct_censored, 1),
        })
    return pd.DataFrame(rows)

def draw_km_curve(change=None):
    """Redraw the Kaplan–Meier plot and statistics table from the widget state.

    Reads the component, segmentation and "Include Unknowns" checkboxes,
    filters ``df_all`` accordingly, plots one survival curve per label into
    ``output`` and shows the per-label statistics in ``stats_output``. When
    nothing is selected or nothing is left to plot, a message is shown and the
    statistics table is cleared.

    Args:
        change: change notification passed by ``observe``; unused. It defaults
            to None so the function can also be called directly.
    """
    def _report(message):
        # Clear the statistics immediately: a deferred clear (wait=True) only
        # takes effect when new output arrives, which would leave the previous
        # table on screen next to the message.
        stats_output.clear_output()
        with output:
            print(message)

    # Deferred clear: the old plot stays until its replacement is emitted.
    with output:
        clear_output(wait=True)

    selected_components = [name for name, cb in component_checkboxes.items() if cb.value]
    all_checked = select_all.value

    if not (selected_components or all_checked):
        _report("Please select at least one component.")
        return

    seg_keys = []
    if segment_lifeform.value:
        seg_keys.append("LifeFormGroup")
    if segment_provenance.value:
        seg_keys.append("ProvenanceGroup")

    # With "(All)" checked the components are pooled, so the label carries only
    # the segmentation keys; otherwise each component gets its own curve(s).
    if all_checked:
        components, label_keys = all_components, seg_keys
    else:
        components, label_keys = selected_components, ["Component"] + seg_keys

    subset = df_all[df_all["Component"].isin(components)].copy()

    # Only label non-empty data: the row-wise join on an empty frame does not
    # yield a string column, which would break the .str filter below.
    if not subset.empty:
        if label_keys:
            subset['Label'] = subset[label_keys].fillna("Unknown").agg(" - ".join, axis=1)
        else:
            subset['Label'] = '(All Components)'

        # Drop 'Unknown' if checkbox is unchecked
        if not include_unknowns.value:
            subset = subset[~subset['Label'].str.contains("Unknown", regex=False)]

    if subset.empty:
        _report("No data to display.")
        return

    fig, ax = plt.subplots(figsize=(10, 6))
    for label, group in subset.groupby('Label'):
        kmf = KaplanMeierFitter()
        kmf.fit(group['Duration'], group['Event'], label=label)
        kmf.plot_survival_function(ax=ax)

    if seg_keys:
        title_seg = " + ".join(seg_keys)
    elif all_checked:
        title_seg = "(All Components)"
    else:
        title_seg = "Component"
    ax.set_title(f"Kaplan–Meier Survival Curves by {title_seg}")
    ax.set_xlabel("Years")
    ax.set_ylabel("Survival Probability")
    ax.set_xlim(0, 50)
    ax.grid(True)
    ax.legend(title='Group')
    fig.tight_layout()
    with output:
        plt.show()

    # Show stats table
    stats_df = compute_km_stats(subset)
    with stats_output:
        clear_output(wait=True)
        display(stats_df)

# Link interactions: "(All)" fans its state out to the component checkboxes;
# every other control triggers a redraw directly.
select_all.observe(update_all_checkbox, names='value')
for cb in component_checkboxes.values():
    cb.observe(draw_km_curve, names='value')
for toggle in (segment_lifeform, segment_provenance, include_unknowns):
    toggle.observe(draw_km_curve, names='value')

# Display: component checkboxes on the left, segmentation options on the right
component_box = widgets.VBox([select_all, *component_checkboxes.values()])
controls = widgets.HBox([
    component_box,
    widgets.VBox([segment_lifeform, segment_provenance, include_unknowns]),
])

display(controls, output, stats_output)
# Initial render so the output area is populated before any interaction
draw_km_curve()

HBox(children=(VBox(children=(Checkbox(value=False, description='(All)', indent=False), Checkbox(value=False, …

Output()

Output()