In [6]:
import pathlib
import sys
sys.path.append(str(pathlib.Path.cwd().absolute().parent))

import polars as pl
import altair as alt
# import panel as pn

from app_charts import lda_all_df, get_roc_curve_df
from db_utils import get_df

---
Annotated units

In [26]:
df = (
    get_df()
    .drop_nulls('drift_rating')
    .join(lda_all_df, on='unit_id', how='left')
    .group_by('session_id')
    .agg(
        pl.col('lda').median().alias('lda_median'),
    )
    .sort('lda_median', descending=True)
)
display(df)
display(df['lda_median'].describe(percentiles=(0.1, 0.9)))

session_id,lda_median
str,f64
"""649944_2023-02-28""",
"""620263_2022-07-26""",
"""620263_2022-07-27""",
"""644547_2022-12-06""",
"""628801_2022-09-19""",
…,…
"""708016_2024-04-29""",-0.533534
"""681532_2023-10-17""",-0.53596
"""666986_2023-08-15""",-0.536073
"""737403_2024-09-26""",-0.53659


statistic,value
str,f64
"""count""",138.0
"""null_count""",21.0
"""mean""",-0.447178
"""std""",0.069481
"""min""",-0.54674
"""10%""",-0.523033
"""90%""",-0.352652
"""max""",-0.204278


---
All units

In [25]:
df = (
    lda_all_df
    .drop_nulls('drift_rating')
    .with_columns(
            session_id=pl.col('unit_id').str.split('_').list.slice(0, 2).list.join('_'),
    )
    .group_by('session_id')
    .agg(
        pl.col('lda').median().alias('lda_median'),
    )
    .sort('lda_median', descending=True)
)
display(df)
display(df['lda_median'].describe(percentiles=(0.1, 0.9)))

session_id,lda_median
str,f64
"""741148_2024-10-18""",-0.193454
"""715710_2024-07-17""",-0.206137
"""699847_2024-04-15""",-0.291405
"""644867_2023-02-20""",-0.298015
"""713655_2024-08-08""",-0.310214
…,…
"""681532_2023-10-17""",-0.540436
"""666986_2023-08-15""",-0.542345
"""702136_2024-03-05""",-0.543165
"""667252_2023-09-25""",-0.551671


statistic,value
str,f64
"""count""",138.0
"""null_count""",0.0
"""mean""",-0.450507
"""std""",0.070944
"""min""",-0.554267
"""10%""",-0.52845
"""90%""",-0.349397
"""max""",-0.193454


In [38]:

def fpr_to_lda_threshold(fpr_threshold: float) -> float:
    return (
        get_roc_curve_df(metric='lda')
        .select(
            pl.col('value').sort_by((pl.col('fpr') - fpr_threshold).abs())
        )
        .get_column('value')
        .first()
    )
    
def get_above_lda_threshold_df(df: pl.DataFrame, lda_threshold: float) -> pl.DataFrame:
    return (
        df
        .with_columns(
            session_id=pl.col('unit_id').str.split('_').list.slice(0, 2).list.join('_'),
            subject_id=pl.col('unit_id').str.split('_').list.get(0),
        )
        .group_by('session_id')
        .agg(
            pl.col('subject_id').first(),
            above_threshold=(pl.col('lda').gt(lda_threshold)).arg_true().count().truediv(pl.col('lda').count()),
            n_units_with_lda_score=pl.col('lda').count(),
        )
        .with_columns(
            subject_total_above_threshold=pl.col('above_threshold').sum().over('subject_id'),
        )
    )

    
def get_sessions_above_threshold_chart(df: pl.DataFrame, fpr_threshold = 0.1) -> alt.Chart:
    lda_threshold = fpr_to_lda_threshold(fpr_threshold)
    return (
        alt.Chart(df.pipe(get_above_lda_threshold_df, lda_threshold=lda_threshold)).mark_bar().encode(
            color=alt.Color('above_threshold:Q'),#.scale(domain=[0, 1]),
            x=alt.X('subject_id:N', sort='x'),
            y=alt.Y('above_threshold:Q'),
            tooltip=['session_id', 'above_threshold'],
        ).properties(
            title=[f'Fraction of units above {lda_threshold=:.2f} (FPR={fpr_threshold})', '(some sessions missing LDA metric for all units)'],
            width=1600,
            height=200,
        )
        .interactive()
    )
    
def get_combined_above_threshold_chart(annotated_units_df: pl.DataFrame, all_units_df: pl.DataFrame, fpr_threshold = 0.1) -> alt.Chart:
    lda_threshold = fpr_to_lda_threshold(fpr_threshold)
    source = pl.concat([
        (
            all_units_df.pipe(get_above_lda_threshold_df, lda_threshold=lda_threshold)
            .with_columns(
                units_group=pl.lit('all'),
            )
        ),
        (
            annotated_units_df.pipe(get_above_lda_threshold_df, lda_threshold=lda_threshold)
            .with_columns(
                units_group=pl.lit('annotated'),
            )
        ),
    ])
    return (
        alt.Chart(source).mark_bar().encode(
            color=alt.Color('above_threshold:Q'),#.scale(domain=[0, 1]),
            x=alt.X('subject_id:N', sort='x'),
            y=alt.Y('above_threshold:Q'),
            xOffset='units_group:N',
            tooltip=['units_group', 'session_id', 'above_threshold'],
        ).properties(
            title=[f'Fraction of units above {lda_threshold=:.2f} (FPR={fpr_threshold}) (some sessions missing LDA metric for units)', 'left: all units | right: annotated units'],
            width=1800,
            height=300,
        )
        .interactive()
    )

fpr_threshold = 0.2
get_combined_above_threshold_chart(get_df().drop_nulls('drift_rating').join(lda_all_df, on='unit_id', how='left'), lda_all_df, fpr_threshold=fpr_threshold)