In [None]:
import os,re,json
import polars as pl
import plotly.express as px

In [None]:
# List everything in the production model directory once; later cells filter this list
# for the per-model metric JSON files and the per-fold prediction parquet files.
# Sorted because os.listdir returns entries in arbitrary, filesystem-dependent order,
# which would otherwise make the row order of the aggregated outputs irreproducible.
models = sorted(os.listdir('./models/prod'))

In [None]:
# Per-model metric summaries are the `.json` files in ./models/prod.
# re.fullmatch anchors the pattern at both ends, so names such as `x.json.bak` or
# `x.jsonl` are not picked up by accident (re.match only anchors at the start).
metrics = [e for e in models if re.fullmatch(r'.+\.json', e)]

# Load each JSON file in the same order as `metrics`; the next cell pairs
# `metrics[i]` with `out[i]` by position.
out = []
for j in metrics:
    with open(f'./models/prod/{j}', 'r', encoding='utf-8') as f:
        out.append(json.load(f))

In [None]:
# Parse the run parameters out of each metrics filename and attach them to the metric values.
# The part before '__' is split on '-' into model, target, kmer, contig and fold.
# NOTE: this cell rebinds `metrics` from a list of filenames to a DataFrame, so it cannot be
# re-run on its own -- re-run the previous cell first.
tmp = {
    'file': [],
    'model': [],
    'target': [],
    'kmer': [],
    'contig': [],
    'fold': [],
}

for e in metrics:
    parts = e.split('__')[0].split('-')
    tmp['file'].append(e)
    tmp['model'].append(parts[0])
    tmp['target'].append(parts[1])
    tmp['kmer'].append(parts[2])
    tmp['contig'].append(parts[3])
    tmp['fold'].append(parts[4])

# One frame per JSON file. 'diagonal' tolerates files whose metric keys differ or are in a
# different order (missing metrics become null); with identical keys it equals a vertical concat.
metric_values = pl.concat([pl.DataFrame(e) for e in out], how='diagonal')

# The horizontal concat below pairs rows purely by position, so every JSON file must
# contribute exactly one row; otherwise filenames and metrics would silently misalign.
assert metric_values.height == len(tmp['file']), (
    f"expected one metrics row per file ({len(tmp['file'])}), got {metric_values.height}"
)

metrics = pl.concat([pl.DataFrame(tmp), metric_values], how='horizontal')

In [None]:
# Persist the combined per-file metrics table next to the models. The output name is not a
# `.json` file and has no `fold<N>__`, so the filename filters do not re-ingest it on a re-run.
metrics.write_parquet('./models/prod/all__metrics.parquet')

In [None]:
# Per-fold prediction files are named `...fold<N>__<suffix>.parquet`. re.fullmatch also
# anchors the end of the name, so stray files such as `*.parquet.tmp` are ignored. The
# aggregated `all_*.parquet` outputs written by this notebook have no `fold<N>__` and are skipped.
predictions = [e for e in models if re.fullmatch(r'.+fold\d+__.+\.parquet', e)]

In [None]:
def _parse_filename(fn):
    _ = fn.split('__')[0].split('-')
    tmp = {'file': [fn],
    'model': [_[0]],
    'target': [_[1]],
    'kmer': [_[2]],
    'contig': [_[3]],
    'fold': [_[4]]}
    return(tmp)



In [None]:
# For every per-fold prediction file, attach the run parameters parsed from its name
# (a one-row frame) to every prediction row via a cross join. The frames are collected
# in `_` and concatenated in the next cell.
_ = [
    pl.DataFrame(_parse_filename(fn=p)).join(
        pl.read_parquet(f"./models/prod/{p}"),
        how='cross',
    )
    for p in predictions
]

In [None]:
# Stack the per-file frames built in the previous cell (held in `_`) into one table.
# NOTE: rebinds `predictions` from the filename list to a DataFrame.
predictions = pl.concat(_, how='vertical')

In [None]:
# Persist every per-fold prediction row together with its parsed run parameters.
predictions.write_parquet('./models/prod/all__predictions.parquet')

In [None]:
# (rows, columns) of the combined predictions table.
(predictions.height, predictions.width)

In [None]:
# Keep the test-set rows only, then collapse the rows that differ only in columns outside
# the grouping keys (e.g. `fold`). `Yhat` is summed and `ProbPos` is averaged.
_ = (
    predictions
    .filter(pl.col('TestSet').eq(True))
    .drop('file', 'TestSet')
    .group_by('model', 'target', 'kmer', 'contig', 'Header', 'Label')
    .agg(
        Yhat=pl.col('Yhat').sum(),
        ProbPos=pl.col('ProbPos').mean(),
    )
)

In [None]:
# Wide view: one row per (model, kmer, contig, Header, Label), with separate Yhat and
# ProbPos columns for each target.
_ = _.pivot(
    'target',
    index=['model', 'kmer', 'contig', 'Header', 'Label'],
    values=['Yhat', 'ProbPos'],
)
_


In [None]:
# Persist the wide (pivoted by target) test-set aggregation from the previous cell.
_.write_parquet('./models/prod/all_agg_pivot__predictions.parquet')

In [None]:
# What are the best models for each species?
# Use the metrics to pick, per (model, target), the kmer/contig configuration with the
# highest test accuracy averaged over folds.
config_acc = (
    metrics
    .group_by(['model', 'target', 'kmer', 'contig'])
    .agg(pl.col('accuracy_tst').mean())
)

# Keep the rows that reach the per-(model, target) maximum, then break ties
# deterministically (smallest kmer, then contig, compared as strings). This selects exactly
# one configuration per (model, target). A plain max-then-join would return every tied
# configuration, and the prediction join below would then mix several configurations.
best_models = (
    config_acc
    .filter(pl.col('accuracy_tst') == pl.col('accuracy_tst').max().over(['model', 'target']))
    .sort(['model', 'target', 'kmer', 'contig'])
    .unique(subset=['model', 'target'], keep='first', maintain_order=True)
    .select(['model', 'target', 'kmer', 'contig', 'accuracy_tst'])
)
best_models

In [None]:
# Save the selected best configuration(s) per (model, target) with their mean test accuracy.
best_models.write_csv('./models/prod/all_best_models_summary.csv')

In [None]:
# Attach the test-set predictions of only the selected best configurations.
# The left join keeps every selected configuration, even one with no predictions (nulls).
_ = best_models.drop('accuracy_tst').join(
    predictions.filter(pl.col('TestSet').eq(True)).drop('file', 'TestSet'),
    on=['model', 'target', 'kmer', 'contig'],
    how='left',
)
_

In [None]:
# Only `target` is needed to identify the chosen kmer/contig parameters, so keep a few
# columns and average the remaining rows per (model, target, Header, Label).
# If Yhat is 0/1, the mean Yhat is the fraction of positive calls.
_ = (
    _
    .select('model', 'target', 'Header', 'Label', 'Yhat', 'ProbPos')
    .group_by(['model', 'target', 'Header', 'Label'])
    .agg(
        pl.col('Yhat').mean(),
        pl.col('ProbPos').mean(),
    )
)
_


In [None]:
# Based on the shapes of the data frames, slightly more than 13 rows are collapsed into each
# (model, target, Header, Label) group on average. This could come from slight imbalances
# in the cross-validation splits.
# NOTE(review): 442994 and 33890 look like row counts copied by hand from the frames
# before/after the aggregation above. Confirm this, and prefer computing them from the
# frames so the ratio stays correct when the data change. Also check that `best_models`
# holds exactly one configuration per (model, target).
# _.select('Yhat').unique()

442994/33890

In [None]:
# For each target, the model's positive calls should match the Label.
# Example for one target: count the sequences of each Label that the Vitis_vinifera model
# calls positive (mean Yhat > 0.5).
# GroupBy.len() replaces the deprecated GroupBy.count(), as in the summary cell below.
(
    _
    .filter(pl.col('Yhat') > 0.5)
    .filter(pl.col('target') == 'Vitis_vinifera')
    .group_by('Label')
    .len()
)

In [None]:
# For every target, count the sequences called positive (mean Yhat > 0.5), broken down by
# their true Label.
best_models_pospred = (
    _.filter(pl.col('Yhat').gt(0.5))
    .group_by(['target', 'Label'])
    .len()
    .sort(['target', 'Label'])
)
best_models_pospred

In [None]:
# Save the per-target positive-call counts by Label.
best_models_pospred.write_csv('./models/prod/all_best_models_pospred_summary.csv')

In [None]:
# Which calls are surprising? These are positive calls (mean Yhat > 0.5) whose true Label
# differs from the model's target, and they are the ones to follow up.
best_models_falsepos = (
    _
    .filter((pl.col('Yhat') > 0.5) & (pl.col('target') != pl.col('Label')))
    .sort(['target', 'Label'])
)

best_models_falsepos.write_csv('./models/prod/all_best_models_falsepos.csv')

In [None]:
from   ax.service.ax_client import AxClient

In [None]:
# List the hyperparameter-tuning files. Each one is loaded with
# AxClient.load_from_json_file below. The list is sorted so the concatenated
# hyperparameter table has a reproducible row order (os.listdir order is arbitrary).
axs = sorted(os.listdir('./models/tune'))


In [None]:
def _parse_filename(fn):
    _ = fn.split('__')[0].split('-')
    tmp = {'file': [fn],
    'model': [_[0]],
    'target': [_[1]],
    'kmer': [_[2]],
    'contig': [_[3]],
    'fold': [_[4]]}
    return(tmp)



In [None]:
# Sanity-check that the tuning filenames follow the same naming scheme, using the first
# tuning file (not a loop variable left over from another section).
_parse_filename(fn=axs[0])

In [None]:
# Load every Ax tuning file and convert its trials data frame to polars. Each frame is
# prefixed with the run parameters parsed from the filename, via a cross join with a
# one-row frame. The per-run frames are collected in `_`.
_ = []
for j in axs:
    ax_client = AxClient.load_from_json_file(filepath=f"./models/tune/{j}")
    trials = pl.DataFrame(ax_client.get_trials_data_frame())
    hyps = pl.DataFrame(_parse_filename(fn=j)).join(trials, how='cross')
    _.append(hyps)

In [None]:
# Stack the per-run tables. Runs may expose different hyperparameter columns (e.g.
# different model types), so use a diagonal concat: columns missing from a run are filled
# with nulls. When every run shares the same schema this equals a vertical concat.
_ = pl.concat(_, how='diagonal')


In [None]:
# Persist the combined tuning-trial table alongside the other aggregated outputs.
_.write_parquet('./models/prod/all__hyperparameters.parquet')