In [None]:
import os, json, argparse, pickle
import numpy as np
import pandas as pd

import pyarrow as pa
import pyarrow.parquet as pq

import polars as pl 

import sklearn
from   sklearn import metrics
from   sklearn.metrics import RocCurveDisplay
from   sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve
from   sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

from   sklearn.neighbors import KNeighborsClassifier

import imblearn
from   imblearn.ensemble import BalancedBaggingClassifier

from   ax.service.ax_client import AxClient

import matplotlib.pyplot as plt

In [None]:
def mk_np_from_pq(data):
    """Split a polars DataFrame into a label vector and a probability matrix.

    Parameters
    ----------
    data : polars.DataFrame
        Must contain no nulls and must include the columns 'Header',
        'Label', 'BasePair', 'SeqLength' and 'Contig'. Every remaining
        column is treated as a logit feature (assumes they are numeric --
        TODO confirm against the parquet schema).

    Returns
    -------
    tuple
        ``(ys, xs)`` where ``ys`` is the flattened 'Label' column and ``xs``
        is the array of the remaining columns after the logistic (sigmoid)
        transform, i.e. values in [0, 1].

    Raises
    ------
    AssertionError
        If any column of ``data`` contains a null.
    """
    # confirm no nulls
    assert not data.null_count().pipe(sum).item() > 0, \
        "mk_np_from_pq expects a frame without nulls"
    ys = data.select(pl.col('Label')).to_numpy().reshape(-1)

    # Everything that is not an identifier/metadata column is a feature.
    xs = data.drop(
        pl.col('Header'),
        pl.col('Label'), 
        pl.col('BasePair'), 
        pl.col('SeqLength'),
        pl.col('Contig')
        ).to_numpy()
    
    # convert logits to probability
    # sigmoid(x) = 1 / (1 + exp(-x)) = exp(-log(1 + exp(-x))).
    # The naive exp(x) / (1 + exp(x)) overflows to inf/inf = NaN for large
    # positive logits; logaddexp evaluates log(1 + exp(-x)) without overflow.
    xs = np.exp(-np.logaddexp(0, -xs))
    return((ys, xs))


In [None]:
def mk_y_onehot(ys, target_label):
    """Encode labels as a binary indicator vector for ``target_label``.

    Despite the name, the result is a single 0/1 vector (one-vs-rest), not a
    one-hot matrix.

    Parameters
    ----------
    ys : array-like
        Labels; anything ``np.asarray`` accepts (object arrays, fixed-width
        string arrays, lists).
    target_label
        The label to mark as the positive class.

    Returns
    -------
    numpy.ndarray
        Integer array shaped like ``ys``: 1 where ``ys == target_label``,
        0 elsewhere.
    """
    # Build the indicator from the comparison itself instead of writing into
    # zeros_like(ys): zeros_like inherits the dtype of ys, so for fixed-width
    # string arrays the buffer holds '' / '1' and astype(int) fails on ''.
    return (np.asarray(ys) == target_label).astype(int)

In [None]:
def append_metric(metrics_path, metric_dict):
    """Merge ``metric_dict`` into the JSON object stored at ``metrics_path``.

    If the file exists its contents are loaded and updated (keys in
    ``metric_dict`` overwrite existing keys); otherwise a new file is
    created. The result is written with 4-space indentation and sorted keys.

    Parameters
    ----------
    metrics_path : str or path-like
        Location of the JSON metrics file.
    metric_dict : dict
        Metrics to add; must be JSON-serializable.

    Raises
    ------
    TypeError
        If the merged metrics cannot be serialized to JSON. The existing
        file is left untouched in that case.
    """
    # load and append to stats json if it exists
    if os.path.exists(metrics_path):
        with open(metrics_path, 'r') as f:
            stats = json.load(f)
    else:
        stats = {}

    stats |= metric_dict

    # Serialize BEFORE opening for writing: open(..., 'w') truncates the file
    # immediately, so a serialization error raised afterwards would wipe
    # every metric recorded so far.
    payload = json.dumps(stats, indent=4, sort_keys=True)

    with open(metrics_path, 'w') as f:
        f.write(payload)

In [None]:
# Run configuration ----
# Cross-validation options
cv_mode = 'tuning'   # 'tuning' evaluates on the validation split, 'training' on the test split
cv_fold = 0          # index used to build the 'fold_{cv_fold}_val' / 'fold_{cv_fold}_tst' names

# Data options
BasePair = '3000'    # largest sequence-length bin to load (smaller bins are included too)
kmer = 1             # selects ./data/kmer{kmer}.parquet

# Label treated as the positive class (encoded as 1; all others become 0)
target_label = 'Glycine_max'

In [None]:
# Data prep ----
## Prepare filters for the train/test (or train/validation) splits
fold_df = pl.from_arrow(pq.read_table('./data/cv_folds.parquet'))

match cv_mode:
    case 'tuning':
        mask_trn = ((pl.col('Fold') != f"fold_{cv_fold}_val") & 
                    (pl.col('Fold') != f"fold_{cv_fold}_tst"))
        
        mask_tst = ((pl.col('Fold') == f"fold_{cv_fold}_val"))

    case 'training':
        mask_trn = ((pl.col('Fold') != f"fold_{cv_fold}_tst"))
        
        mask_tst = ((pl.col('Fold') == f"fold_{cv_fold}_tst"))

## Load in data for target kmer
# Basically this is a convoluted way to access parquet data as a np array
# converting directly fails because we can't take a column as a pa.array and convert that. 
# get basepair lengths up to target bp
bp = ['500', '1000', '3000', '5000', '10000', 'genome']
bp = bp[0:(1 + [i for i in range(len(bp)) if bp[i] == BasePair][0])] # add one because we want to include the end of the slice

df = pl.from_arrow(pq.read_table(f'./data/kmer{kmer}.parquet', 
                                 filters = [[('BasePair', 'in', bp)]])
                                 ) 

# use inner join so that nans are not introduced for cases where a record is not present at all bp
data_trn = fold_df.filter(mask_trn).drop(pl.col('Fold')).join(df, on=['Header', 'Label'], how="inner")
data_tst = fold_df.filter(mask_tst).drop(pl.col('Fold')).join(df, on=['Header', 'Label'], how="inner") 

ys_trn, xs_trn = mk_np_from_pq(data = data_trn)
ys_tst, xs_tst = mk_np_from_pq(data = data_tst)

ys_trn = mk_y_onehot(ys = ys_trn, target_label = target_label)
ys_tst = mk_y_onehot(ys = ys_tst, target_label = target_label)

In [None]:
# Placeholder for exploring new models.