In [11]:
from repairing_genomic_gaps import cae_200, cae_500, cae_1000, cnn_200, cnn_500, cnn_1000
from repairing_genomic_gaps.utils import get_model_history_path, get_model_weights_path
from repairing_genomic_gaps import build_multivariate_dataset_cae, build_synthetic_dataset_cae
from repairing_genomic_gaps import build_multivariate_dataset_cnn, build_synthetic_dataset_cnn
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from typing import Dict, List, Callable
from tensorflow.keras.utils import Sequence
from tensorflow.keras import Model
from repairing_genomic_gaps.reports.report_utils import cae_report, cnn_report, flat_report
import warnings

# Install a blanket "ignore" filter so that warnings raised by the libraries
# used below do not clutter the cell outputs.
# NOTE(review): this hides every warning category, deprecation notices
# included — consider narrowing it to specific categories or modules.
warnings.simplefilter("ignore")

In [12]:
# Model builders, grouped first by architecture family and then by the
# integer window size each builder is associated with.
models = dict(
    cae={200: cae_200, 500: cae_500, 1000: cae_1000},
    cnn={200: cnn_200, 500: cnn_500, 1000: cnn_1000},
)

# Dataset builders to evaluate for each architecture family.
datasets = dict(
    cae=[build_multivariate_dataset_cae, build_synthetic_dataset_cae],
    cnn=[build_multivariate_dataset_cnn, build_synthetic_dataset_cnn],
)

# Report function used to score the predictions of each architecture family.
report_types = dict(cae=cae_report, cnn=cnn_report)

# Keyword arguments forwarded unchanged to every dataset builder call.
dataset_kwargs = {
    "batch_size": 512,
    "training_chromosomes": ["chrM"],
    "testing_chromosomes": ["chrM"],
}

In [13]:
def build_report(model:Model, report:Callable, sequence:Sequence):
    """Lazily score ``model`` on every batch served by ``sequence``.

    Parameters
    ----------
    model:
        Object whose ``predict`` method is called on the inputs of each batch.
    report:
        Callable invoked as ``report(y, predictions)`` for each batch.
    sequence:
        Batch provider exposing ``steps_per_epoch`` and integer indexing;
        every item must unpack into an ``(X, y)`` pair.

    Yields
    ------
    The value returned by ``report`` for each batch, in batch order.
    """
    batch_indices = range(sequence.steps_per_epoch)
    # leave=False clears the per-batch progress bar once the loop completes.
    for index in tqdm(batch_indices, leave=False):
        inputs, targets = sequence[index]
        predictions = model.predict(inputs)
        yield report(targets, predictions)

In [14]:
# Accumulate the flattened test reports of every
# (model type, dataset builder, window size) combination into one list.
reports = []
for model_type in tqdm(models, desc="Model types", leave=False):
    # Scoring function matching the current architecture family.
    report = report_types[model_type]
    for dataset in tqdm(datasets[model_type], desc="Datasets", leave=False):
        for window_size, build_model in tqdm(models[model_type].items(), desc="Models", leave=False):
            # Only the second element returned by the dataset builder is used.
            _, test = dataset(window_size, **dataset_kwargs)
            # Build the model and restore the weights stored for it.
            model = build_model(verbose=False)
            weights_path = get_model_weights_path(model)
            model.load_weights(weights_path)
            # build_report is a generator: batches are predicted and scored
            # only while flat_report consumes it.
            batch_reports = build_report(model, report, test)
            reports.extend(flat_report(batch_reports, model, dataset, "test"))

HBox(children=(IntProgress(value=0, description='Model types', max=2, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Datasets', max=2, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Models', max=3, style=ProgressStyle(description_width='initia…

Loading cache at ./cache/build_multivariate_synthetic_dataset_build_multivariate_dataset_cae/e866a710d53326625e2f2ee289583ffc1e60c4dac704cde2d2b2a2f9d69b7e31.pkl


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

Loading cache at ./cache/build_multivariate_synthetic_dataset_build_multivariate_dataset_cae/c06cbdb570a75555b6e18b8dcf23b038d39bd91bf9785c0d479b605f52e5d1c6.pkl


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

Loading cache at ./cache/build_multivariate_synthetic_dataset_build_multivariate_dataset_cae/efe458eb1065830b5c52aef6fb7e36ceb438ca29fdd2c552e174ab5c21df9895.pkl


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Models', max=3, style=ProgressStyle(description_width='initia…

Loading cache at ./cache/build_synthetic_dataset_build_synthetic_dataset_cae/af894691b758516d947ef52e1264b1caa1e4365ba98dbcfafc2dd54c24474d7b.pkl


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

Loading cache at ./cache/build_synthetic_dataset_build_synthetic_dataset_cae/862384637e6397a92ef716087398790e941db953fd25f1b6ee28b53baac8f8bb.pkl


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

Loading cache at ./cache/build_synthetic_dataset_build_synthetic_dataset_cae/cfafc2a48848e21a9f6f0cc12d36125979c6376e9fa25f3d8235be8f36d82076.pkl


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Datasets', max=2, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Models', max=3, style=ProgressStyle(description_width='initia…

Loading cache at ./cache/build_multivariate_synthetic_dataset_build_multivariate_dataset_cnn/5d38c7bd7a815adeac969f412b5e68d7cfbc45540d733ea2dbc8a4a84fdfc650.pkl


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

Loading cache at ./cache/build_multivariate_synthetic_dataset_build_multivariate_dataset_cnn/0db3a628354e80334b3376175729a7ef492a59d6db101a2d68703a0224aa7300.pkl


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

Loading cache at ./cache/build_multivariate_synthetic_dataset_build_multivariate_dataset_cnn/0eb383e182d95b9c6970c33b4c7ab8e2326154ad476a818cb515780a96a2cdfa.pkl


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Models', max=3, style=ProgressStyle(description_width='initia…

Loading cache at ./cache/build_synthetic_dataset_build_synthetic_dataset_cnn/6cff14cf17366b2476cd717e212c06f2fe3cec2dc06a386d8412202dc514c9ab.pkl


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

Loading cache at ./cache/build_synthetic_dataset_build_synthetic_dataset_cnn/b2d8442ebf8f162864396152530f364ac20c2a679a8ff4802891e36045537749.pkl


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

Loading cache at ./cache/build_synthetic_dataset_build_synthetic_dataset_cnn/d4d77c918b1b9b1a42db24940fa0c76b1040b9384dcef56507274a003bbc5c32.pkl


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

In [16]:
# Average every numeric metric over the collected batch reports, one row per
# (model, dataset, task, target) combination.
# numeric_only=True is passed explicitly: pandas >= 2.0 no longer drops
# non-numeric columns silently in GroupBy.mean and raises TypeError instead.
# NOTE(review): presumably the records carry a string column for the "test"
# label passed to flat_report — confirm against flat_report.
(
    pd.DataFrame(reports)
    .groupby(["model", "dataset", "task", "target"])
    .mean(numeric_only=True)
)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,roc_auc_score,average_precision_score,accuracy_score,balanced_accuracy_score,f1_score,precision_score,recall_score,fall_out,true_negatives,true_positives,false_negatives,false_positives
model,dataset,task,target,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
cae_1000,build_multivariate_dataset_cae_cached,gap_filling,adenine,0.897436,0.722222,0.812500,0.500000,0.0,0.0,0.0,0.0,13.0,0.0,3.0,0.0
cae_1000,build_multivariate_dataset_cae_cached,gap_filling,all_nucleotides,0.638021,0.382085,0.437500,0.458333,,,,,,,,
cae_1000,build_multivariate_dataset_cae_cached,gap_filling,cytosine,0.666667,0.553571,0.750000,0.500000,0.0,0.0,0.0,0.0,12.0,0.0,4.0,0.0
cae_1000,build_multivariate_dataset_cae_cached,gap_filling,guanine,0.435897,0.455556,0.812500,0.500000,0.0,0.0,0.0,0.0,13.0,0.0,3.0,0.0
cae_1000,build_multivariate_dataset_cae_cached,gap_filling,thymine,0.383333,0.343359,0.625000,0.500000,0.0,0.0,0.0,0.0,10.0,0.0,6.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
cnn_500,build_synthetic_dataset_cnn_cached,gap_filling,adenine,0.515000,0.256955,0.757576,0.500000,0.0,0.0,0.0,0.0,25.0,0.0,8.0,0.0
cnn_500,build_synthetic_dataset_cnn_cached,gap_filling,all_nucleotides,0.590144,0.296846,0.393939,0.322115,,,,,,,,
cnn_500,build_synthetic_dataset_cnn_cached,gap_filling,cytosine,0.573077,0.477833,0.606061,0.500000,0.0,0.0,0.0,0.0,20.0,0.0,13.0,0.0
cnn_500,build_synthetic_dataset_cnn_cached,gap_filling,guanine,0.715517,0.223214,0.878788,0.500000,0.0,0.0,0.0,0.0,29.0,0.0,4.0,0.0
