In [12]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [56]:
from fastai.basics import *

import jkbc.model as m
import jkbc.utils.constants as constants
import jkbc.utils.files as f
import jkbc.utils.general as g
import jkbc.utils.metrics as metric
import jkbc.utils.preprocessing as prep

## Constants

### Data

In [57]:
ALPHABET       = constants.ALPHABET
ALPHABET_SIZE  = len(ALPHABET)
WINDOW_SIZE    = 300
STRIDE         = 300
DIMENSIONS_OUT = 70

# Train against saved teacher outputs (knowledge distillation) instead of plain CTC.
KNOWLEDGE_DISTILLATION = True
# Misspelled alias kept so cells that still read the old name keep working.
KNOWLEGDE_DISTILLATION = KNOWLEDGE_DISTILLATION
TEACHER_OUTPUT = 'bonito-pretrained-Valid[3.625368118286133]-CTC[90.3227304562121]' # Set to name of y_teacher output
if KNOWLEDGE_DISTILLATION and not TEACHER_OUTPUT:
    print('WARNING! Must provide name of teacher output when doing knowledge distillation')

BASE_DIR = Path("../../..")
PATH_DATA = 'data/feather-files'
# The dataset folder name encodes the fixed label length, so changing
# DIMENSIONS_OUT selects a different feather folder.
DATA_SET = f'Range0-50-FixLabelLen{DIMENSIONS_OUT}'
FEATHER_FOLDER = BASE_DIR/PATH_DATA/DATA_SET

### Train/Predict

In [58]:
LR = 1e-3  # default learning rate
BS = 128  # batch size
EPOCHS = 1000
# Prefer the first GPU, but fall back to CPU so the notebook still runs on GPU-less machines.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Model

In [59]:
# The architecture is whichever module is imported as `model_file`; its
# `model(...)` factory returns the network together with a name used to
# build the weights directory.
import bonito_basic as model_file

DIMENSIONS_PREDICTION_OUT = 100
DROP_LAST = False # SET TO TRUE IF IT FAILS ON LAST BATCH
SPECIFIC_MODEL_WEIGHTS = None #Set to specific name of model ('None' uses the newest)

model, MODEL_NAME = model_file.model(DEVICE, WINDOW_SIZE, DIMENSIONS_PREDICTION_OUT)
MODEL_DIR = f'weights/{MODEL_NAME}'

### Loss, metrics and callback

In [60]:
# CTC loss built from prediction length, batch size and alphabet size
# (argument meanings per jkbc.utils.metrics — TODO confirm). The KD loss wraps
# it as its label term.
# NOTE(review): BS is fixed into the loss; with DROP_LAST = False a smaller
# final batch may not match (cf. the DROP_LAST comment in the model cell).
_ctc_loss = metric.CtcLoss(DIMENSIONS_PREDICTION_OUT, BS, ALPHABET_SIZE)
_kd_loss = metric.KdLoss(alpha=.3, temperature=5, label_loss=_ctc_loss)

if KNOWLEGDE_DISTILLATION:
    LOSS_FUNC = _kd_loss.loss()
else:
    LOSS_FUNC = _ctc_loss.loss()

METRICS = [metric.ErrorRate(metric.ctc_error(ALPHABET, 5))]
# Partial: still needs the learner, which is bound when training starts.
SAVE_CALLBACK = partial(metric.SaveModelCallback, every='epoch', monitor='valid_loss')

## Load data

In [63]:
# Read data from feather. The KD reader pulls in TEACHER_OUTPUT; use it only
# when the loss (chosen by the same flag) is the KD loss, otherwise use the
# plain reader.
if KNOWLEGDE_DISTILLATION:
    data = f.read_kd_data_from_feather_file(FEATHER_FOLDER, TEACHER_OUTPUT, DIMENSIONS_PREDICTION_OUT, ALPHABET_SIZE)
else:
    data = f.read_data_from_feather_file(FEATHER_FOLDER)

# Convert to databunch (TRAIN_SPLIT = fraction of data used for training)
# NOTE(review): a previous run of this cell raised
# "TypeError: 'int' object is not callable"; the cause is not visible here.
TRAIN_SPLIT = .8
train_dl, valid_dl = prep.convert_to_dataloaders(data, split=TRAIN_SPLIT, batch_size=BS, drop_last=DROP_LAST)
databunch = DataBunch(train_dl, valid_dl, device=DEVICE)

TypeError: 'int' object is not callable

## Model

In [50]:
# Bundle data, network, loss and metrics; checkpoints live under MODEL_DIR.
learner = Learner(
    databunch,
    model,
    loss_func=LOSS_FUNC,
    model_dir=MODEL_DIR,
    metrics=METRICS,
)

In [51]:
# Restore saved weights into the learner. With SPECIFIC_MODEL_WEIGHTS = None
# the newest weights are used (per the config comment); presumably looked up
# in the learner's model_dir — TODO confirm.
m.load_model_weights(learner, SPECIFIC_MODEL_WEIGHTS)

Model weights loaded tmp


## Train

In [52]:
# Learning-rate range test, then plot loss vs. LR with a suggested value.
# NOTE(review): presumably `plot(suggestion=True)` is what sets
# `recorder.min_grad_lr`, read by the next cell — confirm for this fastai version.
# NOTE(review): the saved output shows a debug tensor dump and
# "TypeError: unsupported operand type(s) for /: 'NoneType' and 'int'" from
# this cell; the cause is not visible in this notebook.
learner.lr_find()
learner.recorder.plot(suggestion=True)

epoch,train_loss,valid_loss,ctc_error,time


tensor([[31],
        [32],
        [16],
        [18],
        [26],
        [21],
        [26],
        [30],
        [25],
        [21],
        [31],
        [23],
        [34],
        [33],
        [28],
        [23],
        [24],
        [37],
        [26],
        [22],
        [33],
        [26],
        [28],
        [25],
        [25],
        [30],
        [28],
        [29],
        [24],
        [26],
        [25],
        [31],
        [25],
        [38],
        [30],
        [27],
        [23],
        [11],
        [20],
        [22],
        [29],
        [28],
        [24],
        [21],
        [24],
        [23],
        [19],
        [30],
        [29],
        [22],
        [26],
        [19],
        [21],
        [35],
        [27],
        [22],
        [25],
        [23],
        [31],
        [28],
        [27],
        [26],
        [31],
        [23],
        [29],
        [23],
        [26],
        [23],
        [25],
        [32],
        [27],
      

TypeError: unsupported operand type(s) for /: 'NoneType' and 'int'

In [None]:
# Use the LR suggested by the lr_find() plot when it exists; otherwise fall
# back to the configured default LR. Only a missing attribute triggers the
# fallback — other errors are no longer silently swallowed.
lr = getattr(learner.recorder, 'min_grad_lr', LR)
lr

In [None]:
# One-cycle schedule peaking at `lr`. SAVE_CALLBACK is a partial that still
# needs the learner, so it is bound here.
learner.fit_one_cycle(
    EPOCHS,
    max_lr=lr,
    callbacks=[SAVE_CALLBACK(learner)],
)

In [None]:
# Plot the train/validation losses recorded during training.
learner.recorder.plot_losses()

## Predict

In [None]:
validate_signal_count = 5  # number of reads to run prediction on below
# Mapped reads loaded for prediction rather than training (training_data=False).
sc = prep.SignalCollection(BASE_DIR/constants.MAPPED_READS, training_data=False, stride=1)

In [None]:
# Beam-search decoding settings, named so they can be tuned in one place.
BEAM_SIZE = 15
BEAM_THRESHOLD = 0.1

for index in range(validate_signal_count):
    # Get read object (signal and reference)
    read_object = sc[index]
    # Model input built from the read's signal, placed on DEVICE
    x = read_object.x_for_prediction(DEVICE)
    # Predict, assemble the windows (WINDOW_SIZE / STRIDE) and score against the reference
    assembled, (accuracy, alignment) = m.predict(
        learner, x, ALPHABET, WINDOW_SIZE, STRIDE, read_object.reference,
        beam_size=BEAM_SIZE, beam_threshold=BEAM_THRESHOLD,
    )

    print(f'read {index}: accuracy {accuracy}')