In [None]:
import polars as pl

# Model-comparison table written to model_results.csv by the cross-validation
# cell below (run that cell first on a fresh checkout); lowest false-positive
# rate first.
(
    pl.read_csv('model_results.csv')
    .sort(by='fpr')
)

tpr,fpr,tnr,fnr,model,cross_validation
f64,f64,f64,f64,str,bool
0.62096,0.034884,0.965116,0.37904,"""LinearDiscriminantAnalysis""",True
0.715965,0.050297,0.949703,0.284035,"""GradientBoostingClassifier""",True
0.693438,0.060303,0.939697,0.306562,"""RandomForestClassifier""",True
0.718903,0.073553,0.926447,0.281097,"""QuadraticDiscriminantAnalysis""",True


In [None]:
import pathlib
import sys

# Make the project's `scripts` package (one level up) importable; guarded so
# re-running this cell does not keep appending the same sys.path entry.
PROJECT_ROOT = str(pathlib.Path.cwd().absolute().parent)
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

import polars as pl
import altair as alt
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis, LinearDiscriminantAnalysis
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
import tqdm

from scripts.run_unit_drift_lda import load_annotation_and_metrics_df, get_x, get_x_y, z_score, get_annotations_df

# --- configuration -----------------------------------------------------------
df = load_annotation_and_metrics_df()
# Other candidates tried: 'vis_baseline_r2', 'aud_baseline_r2',
# 'ancova_t_time', 'ancova_coef_time'.
METRICS_TO_KEEP = {'presence_ratio', 'vis_response_r2', 'aud_response_r2'}
# Everything except the kept metrics, the id and the label is dropped.
COLUMNS_TO_DROP = set(df.columns) - METRICS_TO_KEEP - {'unit_id', 'drift_rating'}
MODEL_CLASSES = (RandomForestClassifier, GradientBoostingClassifier, QuadraticDiscriminantAnalysis, LinearDiscriminantAnalysis)
SEED = 42  # seeds the stochastic tree ensembles so reruns give identical results
RESULTS_PATH = 'model_results.csv'

# True: leave-one-out, i.e. one refit per unit (slow: the random forest alone
# was estimated at ~1 h for ~4.7k units). False: fit on all units and score
# in-sample, which gives optimistic rates.
leave_one_out = True


def _make_model(model_class):
    """Instantiate `model_class`, passing `random_state=SEED` if it accepts one."""
    if 'random_state' in model_class().get_params():
        return model_class(random_state=SEED)
    return model_class()


def _confusion_rates(scored: pl.DataFrame) -> dict[str, float]:
    """Return tpr/fpr/tnr/fnr of `prediction` against `drift_rating`.

    A rating of 1 is the positive class and 0 the negative class. A rate whose
    class has no rows in `scored` is NaN instead of raising ZeroDivisionError.
    """
    is_pos = pl.col('drift_rating') == 1
    is_neg = pl.col('drift_rating') == 0
    pred_pos = pl.col('prediction') == 1
    pred_neg = pl.col('prediction') == 0
    counts = scored.select(
        pos=is_pos.sum(),
        neg=is_neg.sum(),
        tp=(is_pos & pred_pos).sum(),
        fp=(is_neg & pred_pos).sum(),
        tn=(is_neg & pred_neg).sum(),
        fn=(is_pos & pred_neg).sum(),
    ).row(0, named=True)

    def rate(numerator: int, denominator: int) -> float:
        return numerator / denominator if denominator else float('nan')

    return dict(
        tpr=rate(counts['tp'], counts['pos']),
        fpr=rate(counts['fp'], counts['neg']),
        tnr=rate(counts['tn'], counts['neg']),
        fnr=rate(counts['fn'], counts['pos']),
    )


# --- data (prepared once, shared by every model) -----------------------------
# NOTE(review): z_score runs on the full set before the leave-one-out split, so
# each held-out unit presumably contributes to the normalisation statistics;
# confirm whether that leakage matters here.
features = get_annotations_df().drop(COLUMNS_TO_DROP).pipe(z_score).drop_nans()
# Predictions are attached by row position and LOO holds a unit out by its id,
# so every id must identify exactly one row.
assert features['unit_id'].n_unique() == features.height, 'duplicate unit_id rows'

# --- fit / evaluate each model ------------------------------------------------
all_results = []
for model_class in MODEL_CLASSES:
    name = model_class.__name__
    model = _make_model(model_class)
    if leave_one_out:
        predictions = []
        for unit_id in tqdm.tqdm(features['unit_id'], total=features.height, unit='units', ncols=100, desc=f"Cross-validating {name}"):
            train = features.filter(pl.col('unit_id') != unit_id)
            test = features.filter(pl.col('unit_id') == unit_id)
            model.fit(*get_x_y(train))
            predictions.append(model.predict(get_x(test)).item())
    else:
        # One batched predict call instead of filtering and predicting row by row.
        model.fit(*get_x_y(features))
        predictions = model.predict(get_x(features))
    # Build a new frame instead of mutating `features`, so later models never
    # see a stale `prediction` column among their inputs. with_columns raises
    # if the number of predictions does not match the number of rows.
    annotated = features.with_columns(pl.Series('prediction', predictions))
    results = _confusion_rates(annotated)
    print(pl.DataFrame(results))
    all_results.append(results | {'model': name, 'cross_validation': leave_one_out})
pl.DataFrame(all_results).write_csv(RESULTS_PATH)

Cross-validating RandomForestClassifier:   0%|               | 15/4719 [00:11<1:00:16,  1.30units/s]


KeyboardInterrupt: 