## Adding KS4 data
1. Export spike times:
    https://codeocean.allenneuraldynamics.org/capsule/2210357/tree
    - attach the new KS4-sorted asset
    - run `dump_timing_info.py` with `npc_sessions` and put the resulting .json file in `/code`
    - run the capsule to create per-session parquet files in `/results`, plus a consolidated units file
      (without spike times) at `s3://aind-scratch-data/dynamic-routing/unit-rasters-ks4/units.parquet`
    - create an asset from the parquet files, placed in a `/ks4` folder
2. Compute R² of spike count vs. time: https://codeocean.allenneuraldynamics.org/capsule/6421158/tree
    - attach the new asset containing the KS4 per-session unit parquet files
    - run the capsule to compute R² values for spike count vs. time
    - download `corr_values_ks4.parquet` to `\\allen\programs\mindscope\workgroups\dynamicrouting\ben`
3. Run `run_unit_drift_lda.py`
    - it uses the R² (corr) values and the consolidated units table (for `presence_ratio`)
4. Restart this notebook's kernel so the updated tables are re-imported.

In [1]:
import pathlib
import sys
sys.path.append(str(pathlib.Path.cwd().absolute().parent))

import polars as pl
import altair as alt

from app_charts import lda_all_df, get_roc_curve_df, lda_all_ks4_df
import db_utils
from db_utils import get_df

---
## Annotated units (KS2.5)

In [None]:
# Share of each drift rating among manually annotated (KS2.5) units, shown as a pie chart.
# The annotated units are queried once and reused. This guarantees that the unit count in
# the chart title matches the data being plotted, and avoids repeating the get_df() call.
annotated_units_df = get_df().drop_nulls('drift_rating')
df = (
    annotated_units_df
    .with_columns(
        # Map rating enum values to readable labels; any other value becomes null
        # (a when/then chain without .otherwise yields null).
        drift_rating=pl.when(pl.col("drift_rating") == db_utils.UnitDriftRating.YES)
        .then(pl.lit('drift'))
        .when(pl.col("drift_rating") == db_utils.UnitDriftRating.NO)
        .then(pl.lit('no drift'))
        .when(pl.col("drift_rating") == db_utils.UnitDriftRating.UNSURE)
        .then(pl.lit('unsure'))
    )
    .get_column('drift_rating')
    .value_counts(normalize=True)
    .sort('proportion')
)

display(df)
base = (
    alt.Chart(df)
    .encode(
        theta=alt.Theta("proportion:Q").stack(True),
        color=alt.Color("drift_rating").scale(scheme="blues", reverse=True).legend(None),
        tooltip=["drift_rating", "proportion"],
    )
    .properties(title=f"{len(annotated_units_df)} annotated units (KS2.5)")
)
base.mark_arc(outerRadius=120) + base.mark_text(radius=160, size=16).encode(text="drift_rating:N")

drift_rating,proportion
str,f64
"""drift""",0.15475
"""unsure""",0.35144
"""no drift""",0.49381


In [101]:
# Per-session median LDA score over the manually annotated units.
# The left join keeps annotated units that lack an LDA score. median() ignores their nulls,
# and sessions where no annotated unit has a score end up with a null median and are dropped.
df = (
    get_df()
    .drop_nulls('drift_rating')
    .join(lda_all_df, on='unit_id', how='left')
    .group_by('session_id')
    .agg(lda_median=pl.median('lda'))
    .sort('lda_median', descending=True)
    .drop_nulls('lda_median')
)
display(df, df['lda_median'].describe(percentiles=(0.1, 0.9)))

session_id,lda_median
str,f64
"""715710_2024-07-17""",-0.204278
"""741148_2024-10-18""",-0.220592
"""699847_2024-04-15""",-0.244285
"""649943_2023-02-14""",-0.29585
"""741148_2024-10-16""",-0.31744
…,…
"""708016_2024-04-29""",-0.533534
"""681532_2023-10-17""",-0.53596
"""666986_2023-08-15""",-0.536073
"""737403_2024-09-26""",-0.53659


statistic,value
str,f64
"""count""",138.0
"""null_count""",0.0
"""mean""",-0.447178
"""std""",0.069481
"""min""",-0.54674
"""10%""",-0.523033
"""90%""",-0.352652
"""max""",-0.204278


---
## All units (KS2.5)

In [25]:
# Per-session median LDA score computed directly from the KS2.5 LDA table.
# NOTE(review): this section is titled "All units", but drop_nulls('drift_rating') keeps only
# rows that carry a drift rating, which may restrict it to annotated units; confirm intent.
df = (
    lda_all_df
    .drop_nulls('drift_rating')
    # unit ids look like "<subject>_<date>_..."; the first two '_' fields form the session id
    .with_columns(
        session_id=pl.col('unit_id').str.split('_').list.head(2).list.join('_'),
    )
    .group_by('session_id')
    .agg(lda_median=pl.median('lda'))
    .sort('lda_median', descending=True)
)
display(df, df['lda_median'].describe(percentiles=(0.1, 0.9)))

session_id,lda_median
str,f64
"""741148_2024-10-18""",-0.193454
"""715710_2024-07-17""",-0.206137
"""699847_2024-04-15""",-0.291405
"""644867_2023-02-20""",-0.298015
"""713655_2024-08-08""",-0.310214
…,…
"""681532_2023-10-17""",-0.540436
"""666986_2023-08-15""",-0.542345
"""702136_2024-03-05""",-0.543165
"""667252_2023-09-25""",-0.551671


statistic,value
str,f64
"""count""",138.0
"""null_count""",0.0
"""mean""",-0.450507
"""std""",0.070944
"""min""",-0.554267
"""10%""",-0.52845
"""90%""",-0.349397
"""max""",-0.193454


In [2]:

from altair import XOffset


def fpr_to_lda_threshold(fpr_threshold: float) -> float:
    """Return the LDA threshold whose ROC false-positive rate is closest to `fpr_threshold`.

    Reads the ROC curve from ``get_roc_curve_df(metric='lda')`` (columns ``fpr`` and
    ``value``) and returns ``value`` from the row minimising ``|fpr - fpr_threshold|``.

    A ROC table can contain several rows with the same FPR, so the closest row is not
    always unique. ``arg_min`` picks the first minimal row in table order, which keeps
    the result deterministic. A ``sort_by(...).first()`` does not guarantee the order of
    tied rows unless ``maintain_order=True`` is set. Rows with a null FPR are ignored;
    they are never selected.

    Args:
        fpr_threshold: target false-positive rate.

    Returns:
        The matching LDA threshold value.

    Raises:
        ValueError: if the ROC table has no rows with a non-null FPR.
    """
    roc_df = get_roc_curve_df(metric='lda')
    fpr_distance = (roc_df.get_column('fpr') - fpr_threshold).abs()
    closest_idx = fpr_distance.arg_min()
    if closest_idx is None:
        raise ValueError(f"ROC curve has no rows with an FPR to match {fpr_threshold=}")
    return roc_df.get_column('value')[closest_idx]
    
def get_above_lda_threshold_df(df: pl.DataFrame, lda_threshold: float, drop_empty: bool = True) -> pl.DataFrame:
    """Summarise, per session, the fraction of units whose LDA score exceeds `lda_threshold`.

    Session and subject ids are parsed from ``unit_id``. The session id is made of the
    first two ``_``-separated fields and the subject id of the first field. Units with
    a null ``lda`` count toward neither the numerator nor the denominator.

    Returns one row per session with these columns:
    ``session_id``, ``subject_id``, ``above_threshold`` (fraction of scored units above
    the threshold), ``n_units_with_lda_score`` and ``subject_total_above_threshold``
    (sum of ``above_threshold`` over all of that subject's sessions).

    If `drop_empty` is true, sessions without any LDA-scored unit are removed. Their
    fraction would be 0/0.
    """
    split_id = pl.col('unit_id').str.split('_')
    n_scored_units = pl.col('lda').count()
    n_units_above = pl.col('lda').gt(lda_threshold).arg_true().count()
    session_filter = (pl.col('n_units_with_lda_score') > 0) if drop_empty else pl.lit(True)

    per_session = (
        df
        .with_columns(
            session_id=split_id.list.head(2).list.join('_'),
            subject_id=split_id.list.get(0),
        )
        .group_by('session_id')
        .agg(
            pl.col('subject_id').first(),
            above_threshold=n_units_above.truediv(n_scored_units),
            n_units_with_lda_score=n_scored_units,
        )
        .filter(session_filter)
    )
    return per_session.with_columns(
        subject_total_above_threshold=pl.col('above_threshold').sum().over('subject_id'),
    )
    
def get_sessions_above_threshold_chart(df: pl.DataFrame, fpr_threshold = 0.1) -> alt.Chart:
    """Bar chart of the per-session fraction of units above the LDA threshold, by subject.

    The LDA threshold is the one whose ROC false-positive rate is closest to
    `fpr_threshold` (via ``fpr_to_lda_threshold``). Sessions without any LDA-scored
    unit are excluded, which is the default of ``get_above_lda_threshold_df``.
    """
    lda_threshold = fpr_to_lda_threshold(fpr_threshold)
    per_session_df = get_above_lda_threshold_df(df, lda_threshold=lda_threshold)
    title_lines = [
        f'Fraction of units above {lda_threshold=:.2f} (FPR={fpr_threshold})',
        '(some sessions missing LDA metric for all units)',
    ]
    bars = alt.Chart(per_session_df).mark_bar().encode(
        color=alt.Color('above_threshold:Q'),
        x=alt.X('subject_id:N', sort='x'),
        y=alt.Y('above_threshold:Q'),
        tooltip=['session_id', 'above_threshold'],
    )
    return bars.properties(title=title_lines, width=1600, height=200).interactive()
    
def get_combined_above_threshold_chart(
    annotated_units_df: pl.DataFrame,
    all_units_df: pl.DataFrame,
    fpr_threshold = 0.1,
    ks4_units_df: "pl.DataFrame | None" = None,
) -> alt.Chart:
    """Compare, per subject, the fraction of units above the LDA threshold for three unit sets.

    Each subject gets up to three side-by-side bar groups (``xOffset``), in this order:

    - ``annotated_ks2.5``: `annotated_units_df`
    - ``predicted_ks2.5``: `all_units_df`
    - ``predicted_ks4``: `ks4_units_df`

    Sessions of the same subject share a bar position. The ``order`` channel sorts them
    by ``session_id``. The LDA threshold is the one whose ROC false-positive rate is
    closest to `fpr_threshold` (via ``fpr_to_lda_threshold``).

    Args:
        annotated_units_df: units with ``unit_id`` and ``lda`` columns (annotated KS2.5 units).
        all_units_df: units with ``unit_id`` and ``lda`` columns (KS2.5 predictions).
        fpr_threshold: target false-positive rate used to pick the LDA threshold.
        ks4_units_df: units with ``unit_id`` and ``lda`` columns (KS4 predictions).
            Defaults to the module-level ``lda_all_ks4_df``, looked up at call time,
            so the function no longer depends on that global implicitly.

    Raises:
        AssertionError: if any of the three unit sets contributes no session. This
            happens when its frame is empty or none of its sessions has an LDA score.
    """
    if ks4_units_df is None:
        ks4_units_df = lda_all_ks4_df
    lda_threshold = fpr_to_lda_threshold(fpr_threshold)

    labelled_units = [
        ('predicted_ks2.5', all_units_df),
        ('annotated_ks2.5', annotated_units_df),
        ('predicted_ks4', ks4_units_df),
    ]
    source = pl.concat([
        get_above_lda_threshold_df(units_df, lda_threshold=lda_threshold)
        .with_columns(units_group=pl.lit(label))
        for label, units_df in labelled_units
    ])

    # display order of the bar groups within each subject; also used to build the title
    order = ['annotated_ks2.5', 'predicted_ks2.5', 'predicted_ks4']
    present_groups = set(source.get_column('units_group').to_list())
    missing_groups = [name for name in order if name not in present_groups]
    assert not missing_groups, f"no sessions with LDA scores for units group(s): {missing_groups}"

    return (
        alt.Chart(source).mark_bar().encode(
            color=alt.Color('above_threshold:Q'),
            x=alt.X('subject_id:N', sort='x'),
            y=alt.Y('above_threshold:Q'),
            order=alt.Order('session_id:N'),
            xOffset=alt.XOffset('units_group:N', sort=order),
            tooltip=['units_group', 'session_id', 'n_units_with_lda_score', 'above_threshold'],
        ).properties(
            title=[
                f'Fraction of units above {lda_threshold=:.2f} (FPR={fpr_threshold}) (some sessions missing LDA metric for units)',
                '',
                '{} units    |    {} (all units)    |    {} (all units)'.format(*order),
            ],
            width=1800,
            height=300,
        )
        .interactive()
    )

# Target false-positive rate for the LDA threshold; later cells also read this value.
fpr_threshold = 0.2
get_combined_above_threshold_chart(
    annotated_units_df=get_df().drop_nulls('drift_rating').join(lda_all_df, on='unit_id', how='left'),
    all_units_df=lda_all_df,
    fpr_threshold=fpr_threshold,
)

In [4]:
# Inspect the KS4 LDA table imported from app_charts: one row per unit_id with response R² values,
# presence_ratio and the lda score.
lda_all_ks4_df

unit_id,vis_response_r2,aud_response_r2,presence_ratio,lda
str,f64,f64,f64,f64
"""742903_2024-10-22_ks4_A-477_ks…",0.051758,0.033912,0.8984375,-0.491873
"""742903_2024-10-22_ks4_E-980_ks…",0.05456,0.006882,1.0,-0.57164
"""742903_2024-10-22_ks4_E-701_ks…",0.048463,0.048755,1.0,-0.544778
"""742903_2024-10-22_ks4_E-955_ks…",0.142847,0.07737,1.0,-0.469389
"""742903_2024-10-22_ks4_C-136_ks…",0.106584,0.008023,1.0,-0.540705
…,…,…,…,…
"""742903_2024-10-22_ks4_E-397_ks…",0.082113,0.056958,1.0,-0.519351
"""742903_2024-10-22_ks4_C-84_ks4""",0.111905,0.050335,0.9921875,-0.502165
"""742903_2024-10-22_ks4_B-653_ks…",0.58447,0.343926,1.0,-0.02036
"""742903_2024-10-22_ks4_A-524_ks…",0.026501,0.028518,1.0,-0.572175


In [5]:
# Spot-check the KS4 R² table downloaded in step 2 of "Adding KS4 data" (output of capsule 6421158).
# NOTE(review): this file includes units from 742903_2024-10-21, but the per-session KS4 threshold
# summary at the end of the notebook lists only 742903_2024-10-22. That summary drops sessions
# without any LDA score, so confirm whether the 10-21 units are expected to be missing LDA scores.
pl.read_parquet("//allen/programs/mindscope/workgroups/dynamicrouting/ben/corr_values_ks4.parquet")

unit_id,vis_baseline_r2,aud_baseline_r2,vis_response_r2,aud_response_r2
str,f64,f64,f64,f64
"""742903_2024-10-22_ks4_A-477_ks…",0.020005,0.005845,0.051758,0.033912
"""742903_2024-10-21_ks4_B-196_ks…",0.023002,0.020432,0.000049,0.002397
"""742903_2024-10-22_ks4_E-980_ks…",0.000512,0.00021,0.05456,0.006882
"""742903_2024-10-22_ks4_E-701_ks…",0.003312,0.045995,0.048463,0.048755
"""742903_2024-10-22_ks4_E-955_ks…",0.019101,0.014411,0.142847,0.07737
…,…,…,…,…
"""742903_2024-10-22_ks4_E-397_ks…",0.028598,0.018105,0.082113,0.056958
"""742903_2024-10-22_ks4_C-84_ks4""",0.009422,0.006674,0.111905,0.050335
"""742903_2024-10-22_ks4_B-653_ks…",0.554296,0.195985,0.58447,0.343926
"""742903_2024-10-22_ks4_A-524_ks…",0.047194,0.007466,0.026501,0.028518


In [3]:
# Per-session fraction of KS4 units above the LDA threshold that corresponds to `fpr_threshold`
# (set in the combined-chart cell).
get_above_lda_threshold_df(
    lda_all_ks4_df,
    lda_threshold=fpr_to_lda_threshold(fpr_threshold),
)

session_id,subject_id,above_threshold,n_units_with_lda_score,subject_total_above_threshold
str,str,f64,u32,f64
"""742903_2024-10-22""","""742903""",0.327775,4018,0.327775
