In [None]:
import pathlib
import sys
sys.path.append(str(pathlib.Path.cwd().absolute().parent))

import polars as pl
import altair as alt
import panel as pn

from app_charts import lda_all_df, filtered_df, unfiltered_df, get_roc_curve_df
# from db_utils import get_units_df

: 

---
Annotated sessions

In [2]:
lda_all_df.columns

['unit_id',
 'drift_rating',
 'vis_response_r2',
 'aud_response_r2',
 'presence_ratio',
 'lda']

In [None]:

def get_above_lda_threshold_df(df: pl.DataFrame, lda_threshold: float) -> pl.DataFrame:
    return (
        df
        .with_columns(
            session_id=pl.col('unit_id').str.split('_').list.slice(0, 2).list.join('_'),
            subject_id=pl.col('unit_id').str.split('_').list.get(0),
        )
        .group_by('session_id')
        .agg(
            pl.col('subject_id').first(),
            above_threshold=(pl.col('lda').gt(lda_threshold)).sum().truediv(pl.col('lda').count()),
            n_units_with_lda_score=pl.col('lda').count(),
        )
        .with_columns(
            subject_total_above_threshold=pl.col('above_threshold').sum().over('subject_id'),
        )
    )
    
def get_sessions_above_threshold_chart(df: pl.DataFrame, fpr_threshold = 0.1) -> alt.Chart:

    lda_threshold: float = (
        get_roc_curve_df(metric='lda')
        .select(
            pl.col('value').sort_by((pl.col('fpr') - fpr_threshold).abs())
        )
        .get_column('value').first()
    )

    return (
        alt.Chart(df.pipe(get_above_lda_threshold_df, lda_threshold=lda_threshold)).mark_bar().encode(
            color=alt.Color('above_threshold:Q'),#.scale(domain=[0, 1]),
            x=alt.X('subject_id:N', sort='-y'),
            y=alt.Y('above_threshold:Q'),
            tooltip=['session_id', 'above_threshold'],
        ).properties(
            title=[f'Fraction of units above {lda_threshold=:.2f}', '(some sessions missing LDA metric for all units)'],
            width=1600,
            height=200,
        )
    )

fpr_threshold_slider = pn.widgets.FloatSlider(name='FPR threshold', start=0, end=1, value=0.1, step=0.01)
pn.extension()

@pn.depends(fpr_threshold_slider.param.value)
def sessions_above_threshold_panel(fpr_threshold):
    return pn.pane.Vega(get_sessions_above_threshold_chart(lda_all_df, fpr_threshold))

pn.Column(
    fpr_threshold_slider,
    sessions_above_threshold_panel,
    # width=1600,
    # height=200,
).servable()


TypeError: unhashable type: 'Series'

Column
    [0] FloatSlider(name='FPR threshold', step=0.01, value=0.1)
    [1] ParamFunction(function, _pane=Vega, defer_load=False)

In [None]:
df.sort('n_units_with_lda_score')

session_id,subject_id,above_threshold,n_units_with_lda_score
str,str,f64,u32
"""668759_2023-07-13""","""668759""",,0
"""620263_2022-07-26""","""620263""",,0
"""646318_2023-01-17""","""646318""",,0
"""636397_2022-09-27""","""636397""",,0
"""662983_2023-05-15""","""662983""",,0
…,…,…,…
"""626791_2022-08-17""","""626791""",0.4,20
"""733780_2024-09-04""","""733780""",0.2,20
"""681532_2023-10-16""","""681532""",0.25,20
"""666986_2023-08-17""","""666986""",0.318182,22


In [None]:
lda_all_df