In [88]:
# Automatically reload edited Python modules (e.g. the jkbc package) before each cell executes
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output
%matplotlib inline

In [89]:
from fastai.basics import *
import json

import jkbc.model as m
import jkbc.utils.constants as constants
import jkbc.utils.torch_files as f
import jkbc.utils.general as g
import jkbc.utils.metrics as metric
import jkbc.utils.preprocessing as prep

## Constants

### Data

In [122]:
# Location of the feather-file dataset used for training
BASE_DIR = Path("../../..")
PATH_DATA = 'data/feather-files'
DATA_SET = 'Range0-50-FixLabelLen400-winsize4096'
FEATHER_FOLDER = BASE_DIR / PATH_DATA / DATA_SET

# Dataset-level settings stored next to the feather files
config = json.loads((FEATHER_FOLDER / 'config.json').read_text())

ALPHABET = constants.ALPHABET
ALPHABET_SIZE = len(ALPHABET)              # number of entries in the alphabet mapping
WINDOW_SIZE = int(config['maxw'])          # maxw: max window size
DIMENSIONS_OUT = int(config['maxl'])       # maxl: max label length
STRIDE = WINDOW_SIZE                       # stride equals the window size

# Knowledge distillation toggle (the misspelled name is used by later cells, so it is kept)
KNOWLEGDE_DISTILLATION = False
# Name of the stored y_teacher output to distil from
TEACHER_OUTPUT = 'bonito-pretrained-Valid[3.625368118286133]-CTC[90.3227304562121]'
if KNOWLEGDE_DISTILLATION and not TEACHER_OUTPUT:
    print('WARNING! Must provide name of teacher output when doing knowledge distillation')



### Train/Predict

In [123]:
LR = 1e-3  # default learning rate (fallback when no lr_find suggestion is available)
BS = 2**6  # batch size = 64 (note: `2^6` is bitwise XOR in Python and evaluates to 4)
# NOTE(review): EPOCHS is not referenced by the fit cell, which trains for 1 epoch — confirm intent
EPOCHS = 1000
# Prefer the first GPU, but fall back to CPU so the notebook still runs without CUDA
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Model

In [124]:
import bonito_basic as model_file

# NOTE(review): `// 3` presumably matches the model's output downsampling factor — confirm in bonito_basic
DIMENSIONS_PREDICTION_OUT = WINDOW_SIZE // 3
# Set to True if training fails on the last (incomplete) batch
DROP_LAST = False

# Build the model; the returned base name is suffixed with the window size
model, _base_model_name = model_file.model(DEVICE, WINDOW_SIZE, DIMENSIONS_PREDICTION_OUT)
MODEL_NAME = f'{_base_model_name}-windowsize={WINDOW_SIZE}'
MODEL_DIR = f'weights/{MODEL_NAME}'

# Weights to load: None uses the newest; otherwise the name of a specific checkpoint,
# e.g. 'bonito-pretrained-Valid[1.545677900314331]-CTC[91.66666666666667]'
SPECIFIC_MODEL_WEIGHTS = None

### Loss, metrics and callback

In [136]:
# CTC loss; it is also passed as the label loss of the knowledge-distillation loss
_ctc_loss = metric.CtcLoss(WINDOW_SIZE, DIMENSIONS_PREDICTION_OUT, BS, ALPHABET_SIZE)
_kd_loss = metric.KdLoss(alpha=.3, temperature=5, label_loss=_ctc_loss)

# Pick the training loss according to the distillation toggle
if KNOWLEGDE_DISTILLATION:
    LOSS_FUNC = _kd_loss.loss()
else:
    LOSS_FUNC = _ctc_loss.loss()

# NOTE(review): the `5` passed to ctc_accuracy is unexplained here (beam size?) — confirm
METRICS = [metric.ErrorRate(metric.ctc_accuracy(ALPHABET, 5))]

# Callback factory: bound to a learner later via SAVE_CALLBACK(learner)
SAVE_CALLBACK = partial(metric.SaveModelCallback, every='epoch', monitor='valid_loss')

## Load data

In [137]:
# Load the training data stored in the feather folder
data = f.load_training_data(FEATHER_FOLDER)

# Split into train/validation dataloaders (split=.8, presumably the training fraction)
# and wrap them in a fastai DataBunch placed on DEVICE
train_dl, valid_dl = prep.convert_to_dataloaders(
    data,
    split=.8,
    batch_size=BS,
    drop_last=DROP_LAST,
)
databunch = DataBunch(train_dl, valid_dl, device=DEVICE)

## Model

In [138]:
# NOTE(review): in the saved run this cell (In [138]) executed after the weight-loading
# cell (In [133]); re-run weight loading after recreating the learner.
learner = Learner(
    databunch,
    model,
    loss_func=LOSS_FUNC,
    metrics=METRICS,
    model_dir=MODEL_DIR,
)

In [133]:
# Load saved weights into the learner; SPECIFIC_MODEL_WEIGHTS=None selects the newest checkpoint
m.load_model_weights(learner, SPECIFIC_MODEL_WEIGHTS)

Model weights loaded tmp


## Train

In [134]:
# Learning-rate range test. In fastai v1, plot(suggestion=True) also stores the
# suggested LR on learner.recorder.min_grad_lr, which the next cell reads.
learner.lr_find()
learner.recorder.plot(suggestion=True)

epoch,train_loss,valid_loss,ctc_accuracy,time


LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.


KeyboardInterrupt: 

In [119]:
# Default to LR if lr_find() has not been run
try: lr = learner.recorder.min_grad_lr
except: lr = LR
lr

1.5848931924611132e-06

In [139]:
# Train with the one-cycle schedule peaking at `lr`; SAVE_CALLBACK(learner) handles checkpointing.
# NOTE(review): trains for 1 epoch; the EPOCHS constant is not used here — confirm intent
learner.fit_one_cycle(1, max_lr=lr, callbacks=[SAVE_CALLBACK(learner)])

epoch,train_loss,valid_loss,ctc_accuracy,time


2
2
2


IndexError: list index out of range

In [None]:
# Plot the training and validation losses recorded during fitting
learner.recorder.plot_losses()

## Predict

In [None]:
# Reads (with references) used to sanity-check predictions.
# NOTE(review): stride=100 differs from the STRIDE constant — confirm intent
sc = prep.SignalCollection(
    BASE_DIR/constants.MAPPED_READS,
    training_data=True,
    stride=100,
    window_size=WINDOW_SIZE,
)
validate_signal_count = 1  # number of reads the prediction cell iterates over

In [None]:
# Run the trained model on the first signal window of each validation read.
# Inference mode: disables dropout and uses running batch-norm statistics
# (fastai's fit() switches the model back to train mode).
learner.model.eval()
for index in range(validate_signal_count):
    # Get read object (signal and reference)
    read_object = sc[index]
    # Convert the read's first signal window into a model input tensor on DEVICE
    x = m.signal_to_input_tensor(read_object.x[0], DEVICE)
    with torch.no_grad():  # no autograd graph is needed for prediction
        a = learner.model(x)
    # TODO: restore full-read assembly; m.predict also defines `alignment`, displayed in the last cell
    #assembled, (accuracy, alignment) = m.predict(learner, x, ALPHABET, WINDOW_SIZE, STRIDE, read_object.reference, beam_size=1, beam_threshold=0.1)

# Raw model output for the last read processed
a

In [None]:
# NOTE(review): `alignment` is only defined by the commented-out m.predict(...) call in the
# prediction cell above, so this raises NameError on a fresh kernel until that call is restored.
alignment