# This notebook uses all the scripts to train a CAE for DNase

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import json
import numpy as np
import os
import sys

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
testing = False
    
# The base directory is one level up
base = '..'
settings_filepath = '../settings-grch38-chip-12kb.json'
search_filepath = '../cnn-optimize-2.json'

with open(settings_filepath, "r") as f:
    settings = json.load(f)

with open(search_filepath, "r") as f:
    search = json.load(f)

## Create training jobs

In [None]:
from jobs import jobs

# Paths are handed to the job generator relative to the project base directory.
search_relpath = os.path.relpath(search_filepath, base)
settings_relpath = os.path.relpath(settings_filepath, base)

# Options for the generated training jobs of the `cnn-optimize-2` search.
job_options = dict(
    dataset="cnn-search",
    name="cnn-optimize-2",
    cluster="seas",
    epochs=25,
    batch_size=256,
    base=base,
    clear=True,
    verbose=False,
)

jobs(search_relpath, settings_relpath, **job_options)

## Test training run of the most complex CAE

In [None]:
from train import train_on_single_dataset

# Definition of the most complex model in the search, used as a smoke test.
# NOTE(review): the name appears to encode the hyperparameters (e.g. `cf-` /
# `ck-` look like conv filter counts / kernel sizes) — confirm against the
# code that generates model names.
model_name = 'cf-128-256-512-1024--ck-3-7-11-15--du---do-0-0-0-0--e-10--rl-0--o-adadelta--lr-1.0--lrd-0.001--l-smse-10--m---bn-0-0-0-0--bni-0'

# Resolve the definition relative to `base`, like every other project path.
model_definition_filepath = os.path.join(base, 'models', '{}.json'.format(model_name))

with open(model_definition_filepath, 'r') as f:
    definition = json.load(f)

# Short 2-epoch run to check that training works end to end.
train_on_single_dataset(
    settings,
    'cnn-search',
    definition,
    epochs=2,
    batch_size=256,
    peak_weight=2,
    signal_weighting='logn',
    signal_weighting_zero_point_percentage=0.02,
    base=base,
    clear=True,
)

## Create evaluation jobs

In [None]:
from evaluate import create_jobs

# The search and the generated evaluation job share the same name.
# NOTE(review): `incl_dtw=False` presumably skips DTW-based metrics — confirm in evaluate.py.
search_name = 'cnn-optimize-2'

create_jobs(
    search_name,
    name=search_name,
    dataset='cnn-search',
    cluster='seas',
    incl_dtw=False,
    clear=True,
    base=base,
)

The data will be downloaded to `../data`.

In a terminal, run: `sbatch evaluate-cnn-optimize-2.slurm`

## Compare

In [None]:
from compare import compare

# Options for comparing all evaluated models of the `cnn-optimize-2` search.
compare_options = {
    'dataset_name': 'cnn-search',
    'base': base,
    'clear': False,
    'verbose': False,
    'silent': False,
    'remove_common_prefix_from_df': True,
}

performance = compare('definitions-cnn-optimize-2.json', **compare_options)

In [None]:
import qgrid

# Interactive, sortable view of the comparison results. Rows selected in this
# grid are read back via `qgw.get_selected_df()` in the next cell.
qgw = qgrid.show_grid(data_frame=performance)
qgw

In [None]:
# `IPython.display` is the public import path; importing `display` from
# `IPython.core.display` is deprecated since IPython 7.14.
from IPython.display import Image, display

# Show the prediction plot of every model selected in the grid above.
# NOTE(review): the 'cf-' re-prefix assumes the common prefix stripped by
# `remove_common_prefix_from_df=True` is exactly 'cf-' — confirm.
for selected_model in qgw.get_selected_df().index:
    print(selected_model)
    prediction_plot_filepath = os.path.join(
        base, 'models', 'cf-{}---predictions-{}.png'.format(selected_model, 'cnn-search')
    )
    display(Image(prediction_plot_filepath))