In [1]:
from typing import Dict
import os
import sys
import pandas as pd
import numpy as np
from epigenomic_dataset import load_all_tasks
from tqdm.auto import tqdm
from ucsc_genomes_downloader import Genome
from keras_bed_sequence import BedSequence
from keras_mixed_sequence import MixedSequence, VectorSequence
from sklearn.model_selection import StratifiedShuffleSplit
from meta_models.meta_models import CNN1DMetaModel, MetaModel
from extra_keras_metrics import get_standard_binary_metrics
import ray
from ray import tune
from ray.tune.suggest.bayesopt import BayesOptSearch
from ray.tune.schedulers import ASHAScheduler
from ray.tune.integration.keras import TuneReportCallback
from tensorflow.keras.callbacks import EarlyStopping
from multiprocessing import cpu_count
import silence_tensorflow.auto

In [2]:
def get_cnn_training_sequence(
    y: pd.DataFrame,
    genome=None,
    batch_size: int = None,
    seed: int = None,
) -> MixedSequence:
    """Return the Keras sequence pairing genomic windows with their labels.

    The inputs are read by ``BedSequence`` from ``genome`` at the regions
    described by the index of ``y``; the targets are ``y.values``. Both
    halves receive the same ``batch_size`` and ``seed``, presumably so that
    any shuffling they do stays aligned between inputs and labels.

    Parameters
    --------------------
    y: pd.DataFrame
        Labels indexed by region coordinates. Every index level becomes a
        column of the BED frame handed to ``BedSequence``.
    genome = None
        Genome to read the sequences from. When None, the notebook-level
        ``genome`` is used (looked up at call time).
    batch_size: int = None
        Batch size of both sequences. When None, the notebook-level
        ``batch_size`` is used.
    seed: int = None
        Seed of both sequences. When None, the notebook-level ``seed`` is
        used.

    Returns
    --------------------
    MixedSequence combining the ``BedSequence`` inputs and the
    ``VectorSequence`` labels.
    """
    # Explicit arguments win; otherwise fall back to the notebook globals that
    # this function has always read. The lookup happens at call time, so the
    # configuration cell may be executed after this definition.
    notebook_globals = globals()
    if genome is None:
        genome = notebook_globals["genome"]
    if batch_size is None:
        batch_size = notebook_globals["batch_size"]
    if seed is None:
        seed = notebook_globals["seed"]

    # Move the coordinate index levels back into columns and keep only them.
    bed = y.reset_index()[y.index.names]
    return MixedSequence(
        BedSequence(
            genome,
            bed=bed,
            batch_size=batch_size,
            seed=seed
        ),
        VectorSequence(
            y.values,
            batch_size=batch_size,
            seed=seed,
        )
    )

In [3]:
def build_cnn_meta_model(
    window_size: int,
    blocks: int = 2,
    batch_normalization: bool = False,
    min_kernel_size: int = 2,
    max_kernel_size: int = 10,
) -> "CNN1DMetaModel":
    """Return the 1D-CNN meta-model whose hyper-parameter space is searched.

    The defaults reproduce the configuration that used to be hard-coded.

    Parameters
    --------------------
    window_size: int
        Length of the input sequences; the input shape is ``(window_size, 4)``.
        The 4 channels are presumably the one-hot encoded nucleotides produced
        by ``BedSequence`` — confirm.
    blocks: int = 2
        Passed to ``CNN1DMetaModel`` as ``blocks``.
    batch_normalization: bool = False
        Forwarded to the meta-layers through ``meta_layer_kwargs``.
    min_kernel_size: int = 2
        Forwarded to the meta-layers as ``min_kernel_size``.
    max_kernel_size: int = 10
        Forwarded to the meta-layers as ``max_kernel_size``.

    Raises
    --------------------
    ValueError
        If ``min_kernel_size`` is greater than ``max_kernel_size``.
    """
    if min_kernel_size > max_kernel_size:
        raise ValueError(
            "min_kernel_size ({}) must not exceed max_kernel_size ({})".format(
                min_kernel_size, max_kernel_size
            )
        )
    return CNN1DMetaModel(
        blocks=blocks,
        input_shape=(window_size, 4),
        meta_layer_kwargs=dict(
            batch_normalization=batch_normalization,
            max_kernel_size=max_kernel_size,
            min_kernel_size=min_kernel_size
        )
    )

In [4]:
def train_cnn(
    train: MixedSequence,
    validation: MixedSequence,
    meta_model: MetaModel,
    max_epochs: int,
    patience: int,
    min_delta: float,
    monitor: str = "loss",
    **kwargs
):
    """Build, compile and fit one model sampled from ``meta_model``.

    Meant to run inside a Ray Tune trial: the Keras epoch logs are reported
    to Tune through ``TuneReportCallback``.

    Parameters
    --------------------
    train: MixedSequence
        Training data.
    validation: MixedSequence
        Validation data, passed to ``fit`` as ``validation_data``.
    meta_model: MetaModel
        Meta-model whose ``build`` receives ``kwargs``.
    max_epochs: int
        Maximum number of training epochs.
    patience: int
        Epochs without improvement of ``monitor`` before early stopping.
    min_delta: float
        Minimum change of ``monitor`` counted as an improvement.
    monitor: str = "loss"
        Quantity watched by ``EarlyStopping``. It defaults to the training
        loss, as before. NOTE(review): the search itself optimises
        ``val_loss``, so ``"val_loss"`` may be the more appropriate choice.
    **kwargs
        Forwarded unchanged to ``meta_model.build`` (the remaining entries of
        the Tune trial configuration).

    Returns
    --------------------
    The ``History`` object returned by ``model.fit``.
    """
    model = meta_model.build(**kwargs)
    model.compile(
        optimizer="nadam",
        loss="binary_crossentropy",
        metrics=get_standard_binary_metrics()
    )
    return model.fit(
        train,
        validation_data=validation,
        epochs=max_epochs,
        verbose=False,
        callbacks=[
            # With metrics=None Tune receives every entry of the Keras epoch
            # logs (loss, val_loss, ...). The previous list was derived from
            # `model.metrics_names`, which Keras only populates after the model
            # has been trained/evaluated on data. It was therefore empty at
            # this point and only worked because an empty list falls back to
            # the same report-everything behaviour.
            TuneReportCallback(),
            EarlyStopping(
                monitor=monitor,
                min_delta=min_delta,
                patience=patience
            )
        ]
    )

In [5]:
def cnn_ray_loss_wrapper(config: Dict, **data):
    """Ray Tune trainable that runs a single CNN training trial.

    Parameters
    --------------------
    config: Dict
        Trial configuration supplied by Ray Tune; unpacked as keyword
        arguments of ``train_cnn``.
    **data
        Objects bound with ``tune.with_parameters`` (in this notebook the
        training and validation sequences and the meta-model); also forwarded
        to ``train_cnn``. A key present in both ``config`` and ``data`` makes
        the call fail with ``TypeError`` (duplicate keyword argument).
    """
    # Re-imported inside the trial, presumably because the trial may execute
    # in a separate Ray worker process where the notebook-level import has
    # not run.
    import silence_tensorflow.auto  # noqa: F401
    train_cnn(**config, **data)

In [6]:
# ---- Sequence data ----------------------------------------------------------
# Reference genome the input regions are read from.
genome = Genome("hg38")
# Length of each input window; also fixes the CNN input shape.
window_size = 256

# ---- Holdouts ---------------------------------------------------------------
# Shared by the splitters, the data sequences and the Bayesian optimiser.
seed = 42
holdouts = 1          # outer train/test splits per task
inner_holdouts = 1    # inner train/validation splits per outer split

# ---- Bayesian optimisation (Ray Tune) ---------------------------------------
num_samples = 200           # trials per tuning run
random_search_steps = 50    # initial random configurations for BayesOptSearch

# ---- Training ---------------------------------------------------------------
max_epochs = 1000
batch_size = 256

# ---- Early stopping ---------------------------------------------------------
patience = 5
min_delta = 0.0001

HBox(children=(HTML(value='Loading chromosomes for genome hg38'), FloatProgress(value=0.0, layout=Layout(flex=…

In [7]:
# Meta-model whose hyper-parameter space the Bayesian search explores below.
meta_model = build_cnn_meta_model(window_size=window_size)

In [None]:
# Start Ray; `ignore_reinit_error` makes re-running this cell harmless.
ray.init(ignore_reinit_error=True)

# The loop used to stop unconditionally after the first task, outer holdout
# and inner holdout (a prototyping shortcut). That remains the default; set
# this to False to run every task and every holdout.
only_first_run = True

# Integer division gives 0 on machines with fewer than 4 cores, so always
# request at least one CPU per trial.
cpus_per_trial = max(1, cpu_count() // 4)

# Every Tune analysis, keyed by its run name. `validation` still holds the
# most recent one, which the next cell displays.
analyses = {}

# Main loop: task -> outer (train / test) holdouts -> inner (train / validation) holdouts.
# NOTE(review): total=5 assumes load_all_tasks yields five tasks — confirm.
for (X, y), task in tqdm(load_all_tasks(window_size=window_size), desc="Tasks", total=5):
    for outer, (train_idx, test_idx) in tqdm(
        enumerate(StratifiedShuffleSplit(
            n_splits=holdouts,
            train_size=0.8,
            random_state=seed
        ).split(X, y)),
        desc="Outer holdouts",
        total=holdouts
    ):
        train_x, test_x = X.iloc[train_idx], X.iloc[test_idx]
        train_y, test_y = y.iloc[train_idx], y.iloc[test_idx]
        for inner, (inner_train_idx, valid_idx) in tqdm(
            enumerate(StratifiedShuffleSplit(
                n_splits=inner_holdouts,
                train_size=0.8,
                random_state=seed
            ).split(train_x, train_y)),
            desc="Inner holdouts",
            total=inner_holdouts
        ):
            inner_train_x, valid_x = train_x.iloc[inner_train_idx], train_x.iloc[valid_idx]
            inner_train_y, valid_y = train_y.iloc[inner_train_idx], train_y.iloc[valid_idx]
            # The CNN inputs are read from `genome` at the regions in the
            # index of the labels, so only the label frames are needed here.
            inner_train_sequence = get_cnn_training_sequence(inner_train_y)
            valid_sequence = get_cnn_training_sequence(valid_y)
            bayesopt = BayesOptSearch(
                meta_model.space(),
                metric="val_loss",
                mode="min",
                random_search_steps=random_search_steps,
                random_state=seed,
            )
            # NOTE(review): ASHA stops every trial after max_t=100 reported
            # training iterations; with one Tune report per epoch this caps
            # training well below `max_epochs`.
            asha_scheduler = ASHAScheduler(
                time_attr='training_iteration',
                metric='val_loss',
                mode='min',
                max_t=100,
                grace_period=10,
                reduction_factor=3,
                brackets=1
            )
            run_name = "{}-{}-{}".format(
                task,
                outer,
                inner
            )
            validation = tune.run(
                tune.with_parameters(
                    cnn_ray_loss_wrapper,
                    train=inner_train_sequence,
                    validation=valid_sequence,
                    meta_model=meta_model
                ),
                name=run_name,
                search_alg=bayesopt,
                scheduler=asha_scheduler,
                resources_per_trial={
                    "cpu": cpus_per_trial,
                    "gpu": 1
                },
                num_samples=num_samples,
                fail_fast=True,
                verbose=1,
                config={
                    "window_size": window_size,
                    "max_epochs": max_epochs,
                    "patience": patience,
                    "min_delta": min_delta
                }
            )
            analyses[run_name] = validation
            if only_first_run:
                break
        if only_first_run:
            break
    if only_first_run:
        break

In [None]:
# Summarise the trials of the most recent `tune.run` call as a DataFrame.
trials_df = validation.dataframe()
trials_df