In [7]:
from dash import Dash, dcc, html, Input, Output
import plotly.express as px
import pandas as pd
import numpy as np

app = Dash(__name__)

# Load data
df = pd.read_csv("sps.csv")

# Filter diagnoses
df = df[df["diagnose_big_class"].isin(["H5 Parenchymal ICH", "I1 Ischemic/TIA", "S1 Systemic", "H6 ICH Unspecified"])]

# Prepare drug categories
df['drug_cat_list'] = df['drug_cat_list'].fillna('').astype(str)
all_codes = df['drug_cat_list'].str.split(', ').explode().unique()
unique_categories = sorted([x for x in all_codes if x])

# Create jitter
df['diagnose_code'] = pd.Categorical(df['diagnose_big_class']).codes
df['jitter'] = df['diagnose_code'] + np.random.uniform(-0.2, 0.2, len(df))

app.layout = html.Div([
    dcc.Graph(id='scatter-plot'),

    html.Div([
    html.Label("Drug", style={'fontWeight': 'bold', 'fontSize': '16px', 'marginBottom': '10px'})]),
    
    html.Div([
        dcc.Checklist(
            id='category-selector',
            options=[{'label': cat, 'value': cat} for cat in unique_categories],
            value=[],  # enter "unique_categories" to start with all of the boxes checked
            inline=True,
            style={'padding-right': '50px'},
            #title="Drug"
        )
    ], style={
        'gap': '8px 16px',
        'padding': '20px',
        'maxWidth': '100%'
    })
])

@app.callback(
    Output('scatter-plot', 'figure'),
    Input('category-selector', 'value')
)
def update_plot(selected_categories):
    if not selected_categories:
        
        # place an empty plot if no drug categories are selected
        return px.scatter(title='Select at least one drug category to see data')
    
    # calculate the number of matches for each row to determine the circle size
    df['match_count'] = df['drug_cat_list'].apply(
        lambda s: sum(c in selected_categories for c in s.split(', ')))
    
    # Filter to only rows with matches
    filtered_df = df[df['match_count'] > 0]
    
    # Create scatter plot with size based on match count
    fig = px.scatter(
        filtered_df,
        x='jitter',
        y='survival_days',
        color='diagnose_big_class',
        size='match_count',  # Size based on number of matches
        size_max=20,  # Maximum point size
        title='Survival After Discharge by Diagnosis Group with Drug Filter',
        custom_data=['diagnose_big_class', 'drug_cat_list', 'match_count']
    )

    fig.update_yaxes(title='Survival Time (Days)')
    fig.update_layout(showlegend=False)
    
    # Configure x-axis
    fig.update_xaxes(
        tickvals=sorted(df['diagnose_code'].unique()),
        ticktext=sorted(df['diagnose_big_class'].unique()),
        title='Diagnosis Group'
    )
    
    # Enhanced hover information
    fig.update_traces(
        hovertemplate="<br>".join([
            "Diagnosis: %{customdata[0]}",
            "Drug Categories: %{customdata[1]}",
            "Matching Categories: %{customdata[2]}",
            "Survival Days: %{y}"
        ])
    )
    
    return fig

app.run(debug=True)

In [15]:
from sklearn.neighbors import NearestNeighbors

sps = df[["survival_days", "diagnose_big_class"]]
nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(sps)
distances, indices = nbrs.kneighbors(sps)
indices

ValueError: could not convert string to float: 'H5 Parenchymal ICH'