In [1]:
from repairing_genomic_gaps import cae_200, cae_500, cae_1000, cnn_200, cnn_500, cnn_1000
from repairing_genomic_gaps.utils import get_model_history_path, get_model_weights_path
from repairing_genomic_gaps import build_multivariate_dataset_cae, build_synthetic_dataset_cae
from repairing_genomic_gaps import build_multivariate_dataset_cnn, build_synthetic_dataset_cnn
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from typing import Dict, List, Callable
from tensorflow.keras.utils import Sequence
from tensorflow.keras import Model
from repairing_genomic_gaps.reports.report_utils import cae_report, cnn_report, flat_report
import warnings

# Install a blanket "ignore" filter: warnings raised from this point on are
# not shown (applies to every warning category, process-wide).
warnings.simplefilter("ignore")

In [2]:
def build_biological_dataset_cae(*args, **kwargs):
    """Placeholder builder for the biological validation dataset (CAE models).

    Accepts any positional and keyword arguments, ignores them, and
    returns ``None``. It only exists so the dataset table for the CAE
    models has an entry in the biological slot.
    """
    return None

def build_biological_dataset_cnn(*args, **kwargs):
    """Placeholder builder for the biological validation dataset (CNN models).

    Accepts any positional and keyword arguments, ignores them, and
    returns ``None``. It only exists so the dataset table for the CNN
    models has an entry in the biological slot.
    """
    return None

In [3]:
# Model builders: architecture name -> {window size -> builder callable}.
models = {
    "cae": {200: cae_200, 500: cae_500, 1000: cae_1000},
    "cnn": {200: cnn_200, 500: cnn_500, 1000: cnn_1000},
}

# Dataset builders per architecture. The position in each list matters: the
# evaluation loop unpacks them as (training, validation, biological).
datasets = dict(
    cae=[
        build_multivariate_dataset_cae,
        build_synthetic_dataset_cae,
        build_biological_dataset_cae,
    ],
    cnn=[
        build_multivariate_dataset_cnn,
        build_synthetic_dataset_cnn,
        build_biological_dataset_cnn,
    ],
)

# Report function to apply to each batch, keyed by architecture name.
report_types = dict(cae=cae_report, cnn=cnn_report)

# Keyword arguments forwarded to the training and validation dataset builders.
# Both the training and the testing split are restricted to "chrM".
dataset_kwargs = {
    "batch_size": 512,
    "training_chromosomes": ["chrM"],
    "testing_chromosomes": ["chrM"],
}

In [4]:
def build_report(model:Model, report:Callable, sequence:Sequence):
    """Lazily produce one report per batch of ``sequence``.

    Parameters
    ----------
    model:
        Object whose ``predict`` method is called on each batch input.
    report:
        Callable invoked as ``report(y, predictions)`` for every batch.
    sequence:
        Batch provider; must expose ``steps_per_epoch`` and return an
        ``(X, y)`` pair when indexed with a batch number.

    Yields
    ------
    The value returned by ``report`` for each batch, in batch order.
    """
    # Nothing runs until the generator is iterated, so the progress bar is
    # only created when the first report is requested.
    batch_indices = tqdm(range(sequence.steps_per_epoch), desc="Batches", leave=False)
    for index in batch_indices:
        inputs, expected = sequence[index]
        predicted = model.predict(inputs)
        yield report(expected, predicted)

In [8]:
# Evaluate every (architecture, window size) model on each data split and
# gather whatever flat_report returns for each split into a single list.
reports = []
for model_type in tqdm(models, desc="Model types", leave=False):
    report = report_types[model_type]
    for window_size, build_model in tqdm(models[model_type].items(), desc="Models", leave=False):
        training, validation, biological = datasets[model_type]
        train, test = training(window_size, **dataset_kwargs)
        _, valid = validation(window_size, **dataset_kwargs)
        bio = biological(window_size)
        model = build_model(verbose=False)
        model.load_weights(get_model_weights_path(model))
        # (sequence to evaluate, dataset builder passed to flat_report, run label)
        evaluations = (
            (test, training, "test"),
            (train, training, "train"),
            (valid, validation, "synthetic validation"),
            ###################################
            # TODO: UPDATE THE DATASET AS SOON
            # AS IT BECOMES AVAILABLE!!!
            # Until then the synthetic validation
            # sequence is reused for this row.
            ###################################
            (valid, biological, "biological validation"),
        )
        for sequence, dataset_builder, run_type in evaluations:
            reports.extend(flat_report(
                build_report(model, report, sequence),
                model,
                dataset_builder,
                run_type
            ))

HBox(children=(IntProgress(value=0, description='Model types', max=2, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Models', max=3, style=ProgressStyle(description_width='initia…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Models', max=3, style=ProgressStyle(description_width='initia…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Batches', max=1, style=ProgressStyle(description_width='initi…

In [None]:
# Average the collected reports within each combination of the key columns
# below (the columns are expected to be present in the entries of `reports`).
group_keys = ["model", "dataset", "task", "target", "run_type"]
pd.DataFrame(reports).groupby(group_keys).mean()