In [3]:
# Load (or reload) the autoreload extension so edited project modules are picked up without a kernel restart
%reload_ext autoreload
# Mode 2: reload all imported modules before each cell execution
%autoreload 2
# Render matplotlib figures inline in the notebook output
%matplotlib inline

In [4]:
from fastai.basics import *

import jkbc.model as m
import jkbc.utils.constants as constants
import jkbc.utils.files as f
import jkbc.utils.general as g
import jkbc.utils.metrics as metric
import jkbc.utils.preprocessing as prep

## Constants

### Data

In [5]:
# Symbol table from the project constants (a mapping) and its number of entries
ALPHABET       = constants.ALPHABET
ALPHABET_SIZE  = len(ALPHABET)

# Windowing parameters; passed to the model factory and to m.predict further down
WINDOW_SIZE    = 300
STRIDE         = 300

# Appears in the dataset folder name as the "FixLabelLen" value
DIMENSIONS_OUT = 70

# Dataset location: <BASE_DIR>/<PATH_DATA>/<DATA_SET>
BASE_DIR       = Path("../../..")
PATH_DATA      = 'data/feather-files'
DATA_SET       = f'Range0-5000-FixLabelLen{DIMENSIONS_OUT}'
FEATHER_FOLDER = BASE_DIR.joinpath(PATH_DATA, DATA_SET)

### Train/Predict

In [6]:
LR = 1e-3  # default learning rate (used when the LR finder gives no suggestion)
BS = 128  # batch size
EPOCHS = 1000  # number of epochs passed to fit_one_cycle
# Use the first CUDA device when one is available; otherwise fall back to the CPU
# so the notebook still runs on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Model

In [7]:
# Module that provides the model factory used below
import bonito_basic as model_file

# Output size handed to both the CTC loss and the model factory
DIMENSIONS_PREDICTION_OUT = 100
DROP_LAST = False # SET TO TRUE IF IT FAILS ON LAST BATCH

# Build the network first; the factory also returns the name used for the weights folder
model, MODEL_NAME = model_file.model(DEVICE, WINDOW_SIZE, DIMENSIONS_PREDICTION_OUT)
MODEL_DIR = f'weights/{MODEL_NAME}'
SPECIFIC_MODEL_WEIGHTS = None #'bonito-5000' #Set to specific name of model ('None' uses the newest)

# Loss, metric and checkpoint callback factory consumed by the Learner / training cells
LOSS_FUNC = metric.CtcLoss(DIMENSIONS_PREDICTION_OUT, BS, ALPHABET_SIZE).loss()
METRICS = [metric.ErrorRate(metric.ctc_error(ALPHABET, 5))]
SAVE_CALLBACK = partial(
    metric.SaveModelCallback,
    every='epoch',
    monitor='valid_loss',
)

## Load data

In [8]:
# Load the prepared dataset from the feather folder configured above
data = f.read_data_from_feather_file(FEATHER_FOLDER)

# Build the train/validation loaders (split=.8 is presumably the training fraction — confirm in prep)
train_dl, valid_dl = prep.convert_to_dataloaders(
    data,
    split=.8,
    batch_size=BS,
    drop_last=DROP_LAST,
)

# Bundle both loaders into a fastai DataBunch placed on the selected device
databunch = DataBunch(train_dl, valid_dl, device=DEVICE)

## Model

In [9]:
# Tie data, model, loss and metrics together; weights are read from / written to MODEL_DIR
learner = Learner(
    databunch,
    model,
    loss_func=LOSS_FUNC,
    metrics=METRICS,
    model_dir=MODEL_DIR,
)

In [10]:
# Load saved weights into the learner via the project helper.
# SPECIFIC_MODEL_WEIGHTS names the weights to load; per its definition, None uses the newest.
m.load_model_weights(learner, SPECIFIC_MODEL_WEIGHTS)

Model weights loaded tmp


## Train

In [11]:
# Run the learning-rate finder, then plot its results.
# suggestion=True asks the recorder for a suggested rate; the following cell reads it
# back from learner.recorder.min_grad_lr.
learner.lr_find()
learner.recorder.plot(suggestion=True)

epoch,train_loss,valid_loss,ctc_error,time


LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.


KeyboardInterrupt: 

In [None]:
# Use the learning rate suggested by the LR finder; default to LR if lr_find()
# has not been run (the recorder / min_grad_lr attribute does not exist yet).
# Only AttributeError is handled, so a KeyboardInterrupt or an unrelated bug is
# no longer silently swallowed the way it was by a bare `except:`.
try:
    lr = learner.recorder.min_grad_lr
except AttributeError:
    lr = LR
lr

In [None]:
# Instantiate the save callback for this learner (configured with every='epoch', monitor='valid_loss')
save_callback = SAVE_CALLBACK(learner)

# Train for EPOCHS epochs with the one-cycle schedule, peaking at the chosen learning rate
learner.fit_one_cycle(EPOCHS, max_lr=lr, callbacks=[save_callback])

In [None]:
# Plot the losses the recorder collected during training
learner.recorder.plot_losses()

## Predict

In [None]:
# Number of reads to run through prediction below
validate_signal_count = 5

# Collection of mapped reads opened in non-training mode with a stride of 1
sc = prep.SignalCollection(
    BASE_DIR/constants.MAPPED_READS,
    training_data=False,
    stride=1,
)

In [None]:
# Predict the first `validate_signal_count` reads and report the accuracy of each
for read_idx in range(validate_signal_count):
    # Read object carries both the signal and its reference
    read_object = sc[read_idx]

    # Prepare the signal for the model on the selected device
    signal_input = read_object.x_for_prediction(DEVICE)

    # Decode with beam search and compare against the reference
    assembled, (accuracy, alignment) = m.predict(
        learner,
        signal_input,
        ALPHABET,
        WINDOW_SIZE,
        STRIDE,
        read_object.reference,
        beam_size=15,
        beam_threshold=0.1,
    )

    print(accuracy)