In [4]:
# Reload the autoreload extension and set it to mode 2 so edited project
# modules are re-imported automatically instead of requiring a kernel restart.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [5]:
from fastai.basics import *
import json
from tqdm import tqdm

import jkbc.model as m
import jkbc.utils.constants as constants
import jkbc.utils.torch_files as f
import jkbc.utils.general as g
import jkbc.utils.metrics as metric
import jkbc.utils.preprocessing as prep

In [None]:
# Seed every random number generator used here (Python, NumPy, PyTorch) with
# the same value, then pin cuDNN to deterministic, non-autotuned kernels.
random_seed = 25 # MAGIC!!
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(random_seed)

for _flag, _value in (('enabled', True), ('deterministic', True), ('benchmark', False)):
    setattr(torch.backends.cudnn, _flag, _value)

## Constants

### Data

In [6]:
# Location of the preprocessed feather files for the selected data set
BASE_DIR = Path("../../..")
PATH_DATA = 'data/feather-files'
DATA_SET = 'Range0-10000-FixLabelLen400-winsize4096'
FEATHER_FOLDER = BASE_DIR / PATH_DATA / DATA_SET

# The data set ships a config.json holding the parameters read below
with open(FEATHER_FOLDER / 'config.json', 'r') as config_file:
    config = json.load(config_file)

ALPHABET = constants.ALPHABET
ALPHABET_SIZE = len(ALPHABET)
# 'maxw' is the max window size, 'maxl' the max label length
WINDOW_SIZE, DIMENSIONS_OUT = (int(config[key]) for key in ('maxw', 'maxl'))
STRIDE = WINDOW_SIZE

# Knowledge distillation switch and the name of the teacher (y_teacher) output
KNOWLEGDE_DISTILLATION = False
TEACHER_OUTPUT = 'bonito-pretrained-Valid[3.625368118286133]-CTC[90.3227304562121]'
if KNOWLEGDE_DISTILLATION:
    if not TEACHER_OUTPUT:
        print('WARNING! Must provide name of teacher output when doing knowledge distillation')

### Train/Predict

In [7]:
# Training hyper-parameters and the device everything runs on.
LR = 1e-3   # default learning rate
# Batch size. `**` is exponentiation; the previous `2^11` was bitwise XOR and
# evaluated to 9. Lower the exponent if 2048 windows do not fit in GPU memory.
BS = 2**11
EPOCHS = 1000
DEVICE = torch.device("cuda:0") #torch.device("cpu")

### Model

In [8]:
import bonito_basic as model_file

# Length of the model's prediction output, derived from the window size
DIMENSIONS_PREDICTION_OUT = WINDOW_SIZE//3+1
DROP_LAST = False # SET TO TRUE IF IT FAILS ON LAST BATCH

# The model file returns both the network and its base name; the window size
# is appended to the name, which also determines the weights directory.
model, MODEL_NAME = model_file.model(DEVICE, WINDOW_SIZE, DIMENSIONS_PREDICTION_OUT)
MODEL_NAME = '{}-windowsize={}'.format(MODEL_NAME, WINDOW_SIZE)
MODEL_DIR = 'weights/{}'.format(MODEL_NAME)

# Set to a specific weights name such as
# 'bonito-pretrained-Valid[1.545677900314331]-CTC[91.66666666666667]';
# 'None' uses the newest.
SPECIFIC_MODEL_WEIGHTS = None

### Loss, metrics and callback

In [10]:
# Both loss objects are always constructed; the distillation loss wraps the
# CTC loss as its label loss. Which one supplies LOSS_FUNC depends on the
# KNOWLEGDE_DISTILLATION switch.
_ctc_loss = metric.CtcLoss(WINDOW_SIZE, DIMENSIONS_PREDICTION_OUT, BS, ALPHABET_SIZE)
_kd_loss = metric.KdLoss(alpha=.3, temperature=5, label_loss=_ctc_loss)
if KNOWLEGDE_DISTILLATION:
    LOSS_FUNC = _kd_loss.loss()
else:
    LOSS_FUNC = _ctc_loss.loss()

# No metrics are tracked by default; metric.ctc_accuracy(ALPHABET, 5) can be added here
METRICS = []

# Factory for the save-model callback; it is instantiated with the learner at fit time
SAVE_CALLBACK = partial(
    metric.SaveModelCallback,
    every='epoch',
    monitor='valid_loss',
)

## Load data

In [11]:
# Load the training data stored as feather files
data = f.load_training_data(FEATHER_FOLDER)

# Build the train/validation dataloaders and wrap them in a fastai DataBunch
train_dl, valid_dl = prep.convert_to_dataloaders(
    data,
    split=1,
    batch_size=BS,
    drop_last=DROP_LAST,
    windows=10000,
)
databunch = DataBunch(train_dl, valid_dl, device=DEVICE)

## Model

In [None]:
# fastai builds the optimizer itself by calling `opt_func` with the model's
# parameter groups, so `opt_func` must be a factory rather than an already
# constructed optimizer (an instance is not callable and fails when training
# starts). The learning rate is supplied later through `fit(..., lr=...)`.
learner = Learner(
    databunch,
    model,
    loss_func=LOSS_FUNC,
    model_dir=MODEL_DIR,
    metrics=METRICS,
    opt_func=partial(AdamW, amsgrad=True),
)

In [None]:
# Load weights into the learner via the project helper. SPECIFIC_MODEL_WEIGHTS
# names the weights to use; per the config cell, None selects the newest.
m.load_model_weights(learner, SPECIFIC_MODEL_WEIGHTS)

## Train

In [None]:
# Run the learning-rate finder, then plot its recorded results with the
# suggested learning rate marked (suggestion=True).
learner.lr_find()
learner.recorder.plot(suggestion=True)

In [None]:
# Choose the learning rate for training.
# Set USE_LR_FIND_SUGGESTION to True to train with the value suggested by
# lr_find(); if lr_find() has not been run (no recorder / no suggestion yet)
# the configured LR is used instead. The default (False) always trains with
# LR, which is what the previous unconditional `lr = LR` override did.
USE_LR_FIND_SUGGESTION = False
if USE_LR_FIND_SUGGESTION:
    lr = getattr(getattr(learner, 'recorder', None), 'min_grad_lr', LR)
else:
    lr = LR

In [None]:
# Train for 10 epochs at the chosen learning rate, with a save-model callback
# instantiated for this learner.
# NOTE(review): EPOCHS is defined in the config cell but the epoch count here
# is hard-coded to 10 — confirm which one is intended.
learner.fit(10, lr=lr, callbacks=[SAVE_CALLBACK(learner)])

In [None]:
# Plot the losses recorded during training.
learner.recorder.plot_losses()

## Predict

In [None]:
# Number of signals to validate on
validate_signal_count = 1

# Collection of mapped reads used for prediction (not training data),
# windowed with stride 1
sc = prep.SignalCollection(
    BASE_DIR/constants.MAPPED_READS,
    training_data=False,
    stride=1,
    window_size=(WINDOW_SIZE-1, WINDOW_SIZE),
    blank_id=constants.BLANK_ID,
)

In [None]:
#for index in range(validate_signal_count):
# Get read object (signal and reference)
read_object = sc[0]

# Number of windows pushed through the model per forward pass
PREDICT_BATCH_SIZE = 100

# Predict signals
x_size = len(read_object.x)
x = torch.tensor(read_object.x, dtype=torch.float32)
outputs = []

# Inference only: switch the model to evaluation mode and skip gradient
# tracking so no autograd graph is kept alive for each batch.
learner.model.eval()
with torch.no_grad():
    for from_ in tqdm(range(0, x_size, PREDICT_BATCH_SIZE)):
        # The slice end is exclusive, so clamp to x_size (not x_size-1);
        # otherwise the final window is dropped and the last batch can be empty.
        to = min(x_size, from_ + PREDICT_BATCH_SIZE)
        x_ = m.signal_to_input_tensor(x[from_:to], DEVICE)
        # Extend with one entry per window (iterates the first dimension)
        outputs.extend(learner.model(x_).detach().cpu())

In [None]:
# Decode the collected model outputs against the read's reference using the
# project's predict helper. It returns the assembled result together with an
# (accuracy, alignment) pair.
# NOTE(review): the positional `1` presumably matches the stride=1 used when
# building the SignalCollection — confirm against m.predict's signature.
assembled, (accuracy, alignment) = m.predict(learner, outputs, ALPHABET, WINDOW_SIZE, 1, read_object.reference, beam_size=500, beam_threshold=0.1)

In [None]:
# Display the alignment returned by m.predict
alignment