In [42]:
from repairing_genomic_gaps import cae_200, cae_500, cae_1000
from repairing_genomic_gaps.utils import get_model_history_path, get_model_weights_path
from repairing_genomic_gaps import build_multivariate_dataset_cae, build_synthetic_dataset_cae
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from typing import Dict, List, Callable
from pprint import pprint
from tensorflow.keras import Model
from deflate_dict import deflate, inflate
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, balanced_accuracy_score
from holdouts_generator.utils.metrics import binary_classifications_metrics
from multiprocessing import Pool, cpu_count, current_process
from tensorflow.keras.utils import Sequence

In [43]:
# Window length handed to the synthetic dataset builder in the next cell.
# NOTE(review): presumably measured in nucleotides, and presumably meant to
# match the cae_1000 model evaluated further down - confirm.
window_size = 1000

In [44]:
# Build the synthetic CAE dataset for the configured window size, requesting
# batches of 512; the builder's result is unpacked into train and test objects.
train, test = build_synthetic_dataset_cae(window_size, batch_size=512)

HBox(children=(IntProgress(value=0, description='Loading chromosomes for genome hg19', layout=Layout(flex='2')…

HBox(children=(IntProgress(value=0, description='Rendering gaps in hg19', layout=Layout(flex='2'), max=22, sty…

HBox(children=(IntProgress(value=0, description='Tessellating windows', layout=Layout(flex='2'), max=357, styl…

HBox(children=(IntProgress(value=0, description='Rendering sequences in hg19', layout=Layout(flex='2'), max=28…

HBox(children=(IntProgress(value=0, description='Converting nucleotides to numeric classes', layout=Layout(fle…

HBox(children=(IntProgress(value=0, description='Rendering gaps in hg19', layout=Layout(flex='2'), max=2, styl…

HBox(children=(IntProgress(value=0, description='Tessellating windows', layout=Layout(flex='2'), max=27, style…

HBox(children=(IntProgress(value=0, description='Rendering sequences in hg19', layout=Layout(flex='2'), max=16…

HBox(children=(IntProgress(value=0, description='Converting nucleotides to numeric classes', layout=Layout(fle…

In [45]:
def get_central_nucleotides(predictions:np.ndarray)->np.ndarray:
    """Extract the entry at the middle index of axis 1 for every sample.

    The middle index is ``predictions.shape[1] // 2`` (for an even length
    this is the upper of the two central positions).

    Parameters
    ----------
    predictions: np.ndarray
        Array indexed as (sample, position, ...). The final reshape assumes
        4 values per position.

    Returns
    -------
    np.ndarray
        The selected entries reshaped to ``(-1, 1, 4)``.
    """
    window_length = predictions.shape[1]
    middle = window_length // 2
    central = predictions[:, middle]
    return central.reshape(-1, 1, 4)

In [46]:
def categorical_report(y_true:np.ndarray, y_pred:np.ndarray, true_class, pred_class):
    """Compute four scikit-learn scores and return them keyed by metric name.

    Parameters
    ----------
    y_true: np.ndarray
        Ground truth passed to the score-based metrics
        (``roc_auc_score`` and ``average_precision_score``).
    y_pred: np.ndarray
        Predicted scores passed to the same two metrics.
    true_class
        Ground-truth labels passed to ``accuracy_score`` and
        ``balanced_accuracy_score``.
    pred_class
        Predicted labels passed to the same two metrics.

    Returns
    -------
    dict
        Keys, in order: ``roc_auc_score``, ``average_precision_score``,
        ``accuracy_score``, ``balanced_accuracy_score``.
    """
    # Metrics evaluated on the raw scores.
    score_metrics = (
        ("roc_auc_score", roc_auc_score),
        ("average_precision_score", average_precision_score),
    )
    # Metrics evaluated on the discrete class labels.
    label_metrics = (
        ("accuracy_score", accuracy_score),
        ("balanced_accuracy_score", balanced_accuracy_score),
    )

    scores = {}
    for metric_name, metric in score_metrics:
        scores[metric_name] = metric(y_true, y_pred)
    for metric_name, metric in label_metrics:
        scores[metric_name] = metric(true_class, pred_class)
    return scores

In [47]:
def categorical_nucleotides_report(y_true:np.ndarray, y_pred:np.ndarray)->Dict:
    """Build an overall report plus one binary report per nucleotide channel.

    Parameters
    ----------
    y_true: np.ndarray
        Ground truth, indexed as ``[:, :, channel]`` (three axes).
    y_pred: np.ndarray
        Predictions with the same layout as ``y_true``.

    Returns
    -------
    Dict
        ``"all_nucleotides"`` maps to the ``categorical_report`` of the
        flattened arrays and of their argmax over the last axis; each of
        ``"adenine"``, ``"cytosine"``, ``"thymine"``, ``"guanine"`` maps to
        the result of ``binary_classifications_metrics`` on the matching
        flattened channel.

    NOTE(review): channel i is labelled with the i-th name below; the
    adenine/cytosine/thymine/guanine order is assumed to match the dataset
    encoding - confirm.
    """
    # Class index per position, taken over the last (channel) axis.
    expected_labels = np.argmax(y_true, axis=-1)
    predicted_labels = np.argmax(y_pred, axis=-1)

    channel_names = ("adenine", "cytosine", "thymine", "guanine")

    summary = {
        "all_nucleotides": categorical_report(
            y_true.flatten(),
            y_pred.flatten(),
            expected_labels.flatten(),
            predicted_labels.flatten(),
        )
    }
    # One binary evaluation per channel, in channel order.
    for channel, name in enumerate(channel_names):
        summary[name] = binary_classifications_metrics(
            y_true[:, :, channel].flatten(),
            y_pred[:, :, channel].flatten(),
        )
    return summary

In [48]:
def cae_report(y_true:np.ndarray, y_pred:np.ndarray)->Dict:
    """Evaluate predictions on the whole window and on its central position.

    Parameters
    ----------
    y_true: np.ndarray
        Ground truth array.
    y_pred: np.ndarray
        Predicted array with the same layout.

    Returns
    -------
    Dict
        ``"reconstruction"``: ``categorical_nucleotides_report`` on the full
        arrays; ``"gap_filling"``: the same report restricted to the output
        of ``get_central_nucleotides`` for both arrays.
    """
    whole_window = categorical_nucleotides_report(y_true, y_pred)
    centre_only = categorical_nucleotides_report(
        get_central_nucleotides(y_true),
        get_central_nucleotides(y_pred),
    )
    return {"reconstruction": whole_window, "gap_filling": centre_only}

In [49]:
def flat_report(report:List[Dict], model:Model):
    """Flatten nested per-batch reports into a list of flat row dictionaries.

    Parameters
    ----------
    report: List[Dict]
        One dictionary per batch, shaped ``{task: {target: {metric: value}}}``.
    model: Model
        Object whose ``name`` attribute fills the ``"model"`` field of every row.

    Returns
    -------
    list
        One dictionary per (batch, task, target) combination, holding the
        ``"model"``, ``"task"`` and ``"target"`` fields followed by that
        target's metrics.
    """
    rows = []
    for batch_report in report:
        for task, targets in batch_report.items():
            for target, metrics in targets.items():
                rows.append({
                    "model": model.name,
                    "task": task,
                    "target": target,
                    **metrics,
                })
    return rows

In [None]:
# Instantiate the cae_1000 model and restore the weights stored at the path
# that get_model_weights_path resolves for it.
model = cae_1000(verbose=False)
model.load_weights(get_model_weights_path(model))

# Fetch every batch of the test object up front, one item per step.
batches = [test[step] for step in range(test.steps_per_epoch)]

# Predict on each batch and score the predictions; one nested report per batch.
report = [
    cae_report(expected, model.predict(inputs))
    for inputs, expected in tqdm(batches)
]

HBox(children=(IntProgress(value=0, max=295), HTML(value='')))

In [None]:
# Average the per-batch rows for every (model, task, target) combination and
# write the result to a CSV file named after the model.
(
    pd.DataFrame(flat_report(report, model))
    .groupby(["model", "task", "target"])
    .mean()
    .to_csv("{}.csv".format(model.name))
)