In [1]:
# Reload imported modules (e.g. the jkbc package) before each cell executes,
# so library edits are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [2]:
from datetime import datetime
import os
from IPython.display import clear_output
import itertools

from fastai.basics import *
import json
from tqdm import tqdm

import jkbc.model as m
import jkbc.model.factory as factory
import jkbc.model.schedulers as schedulers
import jkbc.constants as constants
import jkbc.files.torch_files as f
import jkbc.model.metrics as metric
import jkbc.utils.preprocessing as prep
import jkbc.utils.postprocessing as pop
import jkbc.files.fasta as fasta

## Constants

### Data

In [None]:
BASE_DIR = Path("../..")
PATH_DATA = 'data/feather-files'
DATA_SET = 'Range0-50-FixLabelLen400-winsize4096'
FEATHER_FOLDER = BASE_DIR/PATH_DATA/DATA_SET

# The data set's config.json supplies the window size ('maxw') and the
# maximum label length ('maxl') used below.
with open(FEATHER_FOLDER/'config.json', 'r') as fp:
    config = json.load(fp)

ALPHABET       = constants.ALPHABET
ALPHABET_VAL   = list(ALPHABET.values())
ALPHABET_STR   = ''.join(ALPHABET_VAL)
ALPHABET_SIZE  = len(ALPHABET)
WINDOW_SIZE    = int(config['maxw'])  # maxw = max window size
DIMENSIONS_OUT = int(config['maxl'])  # maxl = max label length
STRIDE         = WINDOW_SIZE          # stride == window size (presumably non-overlapping windows)

# NOTE: the constant name is misspelled but kept as-is because later cells reference it.
KNOWLEGDE_DISTILLATION = True
TEACHER_OUTPUT = 'bonito-csv'  # name of the y_teacher output to distil from
# Fail fast: loading teacher data without a teacher name cannot succeed, and a
# printed warning is easy to miss before the expensive load/training cells run.
if KNOWLEGDE_DISTILLATION and not TEACHER_OUTPUT:
    raise ValueError('Must provide name of teacher output (TEACHER_OUTPUT) when doing knowledge distillation')

In [None]:
# Metrics that fastai evaluates on the validation set after every epoch.
METRICS = [
    # NOTE(review): the meaning of the second argument (5) is not visible here — confirm in jkbc.model.metrics.
    metric.ctc_accuracy(ALPHABET, 5),
]

### Train/Predict

In [None]:
LR = 1e-3      # default learning rate
BS = 2**6      # batch size (64)
EPOCHS = 1
# To force CPU training, use torch.device("cpu") instead.
DEVICE = m.get_available_gpu()
DEVICE

### Data loading

In [None]:
# Set to True if training fails on the final (smaller) batch.
# NOTE(review): presumably because the CTC loss is constructed with the fixed batch size BS — confirm.
DROP_LAST: bool = False

## Load data

In [None]:
# Read the training windows from the feather folder (plus the teacher's
# outputs when doing knowledge distillation) and split them into
# train/validation DataLoaders.
TRAIN_SPLIT = .8  # presumably the fraction of windows used for training — confirm in jkbc.utils.preprocessing
loader_kwargs = dict(split=TRAIN_SPLIT, batch_size=BS, drop_last=DROP_LAST)
if not KNOWLEGDE_DISTILLATION:
    data = f.load_training_data(FEATHER_FOLDER)
else:
    data, teacher = f.load_training_data_with_teacher(FEATHER_FOLDER, TEACHER_OUTPUT)
    loader_kwargs['teacher'] = teacher
train_dl, valid_dl = prep.convert_to_dataloaders(data, **loader_kwargs)

# Wrap both loaders in a fastai DataBunch on DEVICE.
databunch = DataBunch(train_dl, valid_dl, device=DEVICE)

## Model and hyperparameter search (losses, optimizers, schedulers)

In [None]:
def _ctc_loss(dim_pred_out):
    """Build the CTC label loss for a model with prediction output dimension `dim_pred_out`."""
    return metric.CtcLoss(WINDOW_SIZE, dim_pred_out, BS, ALPHABET_SIZE)


def _kd_loss(dim_pred_out, alpha, temperature):
    """Build the knowledge-distillation loss for one (alpha, temperature) setting.

    Wraps the CTC label loss in `metric.KdLoss` and returns the result of its
    `.loss()` call, which the search cell passes to the Learner as `loss_func`.
    """
    kd = metric.KdLoss(alpha=alpha, temperature=temperature, label_loss=_ctc_loss(dim_pred_out))
    return kd.loss()


# Knowledge-distillation settings to sweep over.
TEMPERATURES = [1, 2, 4, 8, 16, 32]
# alpha = 0.0, 0.1, ..., 1.0; rounded so keys read 'a=0.3' rather than 'a=0.30000000000000004'.
ALPHAS = np.round(np.linspace(0.0, 1.0, 11), 1)

# Bind alpha/temperature eagerly with partial. A lambda in the loop would capture
# the loop variables by reference, so every entry would use the last values (t=32, a=1.0).
loss_funcs = {
    f't={t},a={a}': partial(_kd_loss, alpha=a, temperature=t)
    for t in TEMPERATURES
    for a in ALPHAS
}

In [None]:
# Optimizers to sweep over, keyed by the name used in run ids:
# AdamW with the AMSGrad variant at the default learning rate.
optimizers = dict(
    AdamW=partial(torch.optim.AdamW, amsgrad=True, lr=LR),
)

In [None]:
# Scheduler settings to sweep over; get_named_schedulers turns them into a
# name -> scheduler mapping (consumed via .items() in the search cell).
moms = [0.8, 0.9]
cycles = [3, 5, 7, 9]
# `schedulers` is rebound below from the jkbc module to a mapping, so call the
# module through its own alias; otherwise re-running this cell raises AttributeError.
import jkbc.model.schedulers as scheduler_lib
schedulers = scheduler_lib.get_named_schedulers(EPOCHS, LR, moms, cycles)

In [None]:
# Grid search over model x loss x optimizer x scheduler. One CSV row per run:
# Model_name, loss_function, optimizer, scheduler, ctc_accuracy, run_id
from fastai.callbacks.tracker import EarlyStoppingCallback

models = [partial(factory.bonito, WINDOW_SIZE, DEVICE, BASE_DIR)]

# Timestamped output folder, formatted without spaces or colons so it is a valid path on every OS.
folder_id = f'hyper/{datetime.now():%Y-%m-%d_%H-%M-%S}'
os.makedirs(folder_id, exist_ok=True)

hyperparameters = list(itertools.product(models, loss_funcs.items(), optimizers.items(), schedulers.items()))

# Use `results_file`/`net` rather than `f`/`m`: those names are the jkbc
# torch_files and model modules, and shadowing them breaks re-running earlier cells.
with open(f'{folder_id}/hyper-output.csv', 'w') as results_file:
    results_file.write('Model_name, loss_function, optimizer, scheduler, ctc_accuracy, run_id\n')
    for model, (l_key, loss), (o_key, optim), (s_key, scheduler) in tqdm(hyperparameters, 'Hyperparameter progress'):
        # Build a fresh model; the factory also returns its name and prediction output dimension.
        net, (MODEL_NAME, dim_pred_out) = model()
        run_id = f'{MODEL_NAME}-{l_key}-{o_key}-{s_key}'
        # The loss depends on the model's output dimension, so instantiate it per model.
        loss_ = loss(dim_pred_out)
        learner = Learner(databunch, net, loss_func=loss_, metrics=METRICS, opt_func=optim).to_fp16()
        scheduler(learner)
        # Stop if ctc_accuracy has not improved for `patience` epochs.
        early_stop_callback = EarlyStoppingCallback(learner, monitor='ctc_accuracy', mode='max', patience=10)
        # Save per-epoch output to a CSV named after the run.
        csv_log_callback = metric.CSVLogger(learner, filename=f'{folder_id}/{run_id}', append=True)
        learner.fit(EPOCHS, lr=LR, callbacks=[early_stop_callback, csv_log_callback])
        # validate() returns [valid_loss, *metrics]; index 1 is the first metric (ctc_accuracy).
        score = learner.validate()[1]
        results_file.write(f'{MODEL_NAME}, {l_key}, {o_key}, {s_key}, {score}, {run_id}\n')
        # Flush so completed runs survive a crash or interrupt later in the sweep.
        results_file.flush()
        clear_output(wait=True)
        print('Previous score:', score)