In [1]:
import pathlib
import sys
sys.path.append(str(pathlib.Path.cwd().absolute().parent))

import polars as pl
import altair as alt

from app_charts import lda_all_df, get_roc_curve_df, lda_all_ks4_df
import db_utils
from db_utils import get_df, KS4_UNITS_DF_PATH

from scripts.run_unit_drift_lda import load_annotation_and_metrics_df, get_x, get_x_y, z_score, get_annotations_df

In [None]:
import pathlib
import sys
sys.path.append(str(pathlib.Path.cwd().absolute().parent))

import polars as pl
import altair as alt
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis, LinearDiscriminantAnalysis
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
import tqdm

from scripts.run_unit_drift_lda import load_annotation_and_metrics_df, get_x, get_x_y, z_score, get_annotations_df

df = load_annotation_and_metrics_df()
METRICS_TO_KEEP = {'presence_ratio', 'vis_response_r2', 'aud_response_r2'} #, 'vis_baseline_r2', 'aud_baseline_r2'} # , 'ancova_t_time', 'ancova_coef_time'
COLUMNS_TO_DROP = set(df.columns) - METRICS_TO_KEEP - {'unit_id', 'drift_rating'}

leave_one_out = True

model_to_results = {}
for model_class in (RandomForestClassifier, GradientBoostingClassifier, QuadraticDiscriminantAnalysis, LinearDiscriminantAnalysis):

    model = model_class()
    annotated = get_annotations_df().drop(COLUMNS_TO_DROP).pipe(z_score).drop_nans()
    unit_id_to_value = {}
    if leave_one_out:
        for unit_id in tqdm.tqdm(annotated['unit_id'], total=len(annotated), unit='units', ncols=100, desc=f"Cross-validating {model_class.__name__}"):
            train = annotated.filter(pl.col('unit_id') != unit_id)
            test = annotated.filter(pl.col('unit_id') == unit_id)
            model.fit(*get_x_y(train))
            unit_id_to_value[unit_id] = model.predict(get_x(test)).item()
    else:
        model.fit(*get_x_y(annotated))
        for unit_id in tqdm.tqdm(annotated['unit_id'], total=len(annotated), unit='units', ncols=100, desc=f"Predicting {model_class.__name__}"):
            test = annotated.filter(pl.col('unit_id') == unit_id)
            unit_id_to_value[unit_id] = model.predict(get_x(test)).item()
                
    annotated.insert_column(-1,  pl.Series('prediction', unit_id_to_value.values()))
    results = pl.DataFrame(
        dict(
            tpr = annotated.filter(pl.col('drift_rating') == 1, pl.col('prediction') == 1).height / annotated.filter(pl.col('drift_rating') == 1).height,
            fpr = annotated.filter(pl.col('drift_rating') == 0, pl.col('prediction') == 1).height / annotated.filter(pl.col('drift_rating') == 0).height,
            tnr = annotated.filter(pl.col('drift_rating') == 0, pl.col('prediction') == 0).height / annotated.filter(pl.col('drift_rating') == 0).height,
            fnr = annotated.filter(pl.col('drift_rating') == 1, pl.col('prediction') == 0).height / annotated.filter(pl.col('drift_rating') == 1).height,
        )
    )
    display(results)
    model_to_results[model_class] = results

INFO:run_unit_drift_lda:Loading data from parquet files
INFO:botocore.credentials:Found credentials in environment variables.
INFO:botocore.credentials:Found credentials in environment variables.
INFO:run_unit_drift_lda:Loaded data from parquet files in 0.68 seconds


shape: (9, 9)
┌───────────┬───────────┬───────────┬───────────┬───┬───────────┬───────────┬───────────┬──────────┐
│ statistic ┆ unit_id   ┆ drift_rat ┆ session_i ┆ … ┆ aud_basel ┆ vis_respo ┆ aud_respo ┆ presence │
│ ---       ┆ ---       ┆ ing       ┆ d         ┆   ┆ ine_r2    ┆ nse_r2    ┆ nse_r2    ┆ _ratio   │
│ str       ┆ str       ┆ ---       ┆ ---       ┆   ┆ ---       ┆ ---       ┆ ---       ┆ ---      │
│           ┆           ┆ f64       ┆ str       ┆   ┆ f64       ┆ f64       ┆ f64       ┆ f64      │
╞═══════════╪═══════════╪═══════════╪═══════════╪═══╪═══════════╪═══════════╪═══════════╪══════════╡
│ count     ┆ 411218    ┆ 7905.0    ┆ 411218    ┆ … ┆ 397359.0  ┆ 368483.0  ┆ 397359.0  ┆ 411218.0 │
│ null_coun ┆ 0         ┆ 403313.0  ┆ 0         ┆ … ┆ 13859.0   ┆ 42735.0   ┆ 13859.0   ┆ 0.0      │
│ t         ┆           ┆           ┆           ┆   ┆           ┆           ┆           ┆          │
│ mean      ┆ null      ┆ 1.914611  ┆ null      ┆ … ┆ NaN       ┆ NaN       ┆

Predicting RandomForestClassifier: 100%|████████████████████| 4719/4719 [00:44<00:00, 104.99units/s]


tpr,fpr,tnr,fnr
f64,f64,f64,f64
1.0,0.0,1.0,0.0


Predicting GradientBoostingClassifier: 100%|████████████████| 4719/4719 [00:06<00:00, 786.24units/s]


tpr,fpr,tnr,fnr
f64,f64,f64,f64
0.75808,0.038399,0.961601,0.24192


Predicting QuadraticDiscriminantAnalysis: 100%|████████████| 4719/4719 [00:04<00:00, 1155.77units/s]


tpr,fpr,tnr,fnr
f64,f64,f64,f64
0.718903,0.072742,0.927258,0.281097


Predicting LinearDiscriminantAnalysis: 100%|███████████████| 4719/4719 [00:04<00:00, 1161.46units/s]


tpr,fpr,tnr,fnr
f64,f64,f64,f64
0.622919,0.034884,0.965116,0.377081
