In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai.basics import *
import json
from tqdm import tqdm

import jkbc.model as m
import jkbc.utils.constants as constants
import jkbc.utils.torch_files as f
import jkbc.utils.general as g
import jkbc.utils.metrics as metric
import jkbc.utils.preprocessing as prep

In [3]:
# Reproducibility setup: pin one seed for every RNG used in this notebook
# (PyTorch, NumPy and Python's `random`), then configure cuDNN.
random_seed = 25  # fixed, arbitrary seed value
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

# cuDNN flags: benchmark mode off, deterministic mode on, backend left enabled.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = True

## Constants

### Data

In [32]:
# --- Dataset location --------------------------------------------------------
BASE_DIR = Path("../../..")
PATH_DATA = 'data/feather-files'
DATA_SET = 'Range0-10000-FixLabelLen400-winsize4096'
FEATHER_FOLDER = BASE_DIR / PATH_DATA / DATA_SET

# The dataset folder ships a config.json; its 'maxw' / 'maxl' entries are read below.
with open(FEATHER_FOLDER / 'config.json', 'r') as fp:
    config = json.load(fp)

# --- Alphabet ----------------------------------------------------------------
ALPHABET = constants.ALPHABET
ALPHABET_VAL = list(ALPHABET.values())
ALPHABET_STR = ''.join(ALPHABET_VAL)
ALPHABET_SIZE = len(ALPHABET)

# --- Window / label dimensions (taken from the dataset config) ---------------
WINDOW_SIZE = int(config['maxw'])     # maxw = max window size
DIMENSIONS_OUT = int(config['maxl'])  # maxl = max label length
STRIDE = WINDOW_SIZE

# --- Knowledge distillation --------------------------------------------------
# (sic) The misspelled name is kept: the loss cell reads this exact identifier.
KNOWLEGDE_DISTILLATION = False
# Set to name of y_teacher output
TEACHER_OUTPUT = 'bonito-pretrained-Valid[3.625368118286133]-CTC[90.3227304562121]'
if not TEACHER_OUTPUT and KNOWLEGDE_DISTILLATION:
    print('WARNING! Must provide name of teacher output when doing knowledge distillation')

### Train/Predict

In [5]:
LR = 1e-3     # default learning rate
BS = 2**6     # batch size
EPOCHS = 400  # number of epochs passed to learner.fit
# Use the first GPU when CUDA is usable; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without a GPU instead of failing
# as soon as the model or data is moved to the device.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Model

In [6]:
import bonito_basic as model_file

DROP_LAST = False  # SET TO TRUE IF IT FAILS ON LAST BATCH

# NOTE(review): presumably the model shortens the time axis by a factor of 3,
# hence window // 3 + 1 output steps — TODO confirm against bonito_basic.
DIMENSIONS_PREDICTION_OUT = 1 + WINDOW_SIZE // 3

# The model factory returns the network together with its base name; the base
# name is then extended with the window size and used for the weights folder.
model, MODEL_NAME = model_file.model(DEVICE, WINDOW_SIZE, DIMENSIONS_PREDICTION_OUT)
MODEL_NAME = '{}-windowsize={}'.format(MODEL_NAME, WINDOW_SIZE)
MODEL_DIR = 'weights/' + MODEL_NAME

# Set to specific name of model ('None' uses the newest),
# e.g. 'bonito-pretrained-Valid[1.545677900314331]-CTC[91.66666666666667]'
SPECIFIC_MODEL_WEIGHTS = None

### Loss, metrics and callback

In [7]:
# Both loss objects are always built; the distillation loss wraps the CTC loss
# as its label loss. Only the selected one is turned into LOSS_FUNC.
_ctc_loss = metric.CtcLoss(WINDOW_SIZE, DIMENSIONS_PREDICTION_OUT, BS, ALPHABET_SIZE)
_kd_loss = metric.KdLoss(label_loss=_ctc_loss, alpha=.3, temperature=5)

if KNOWLEGDE_DISTILLATION:
    LOSS_FUNC = _kd_loss.loss()
else:
    LOSS_FUNC = _ctc_loss.loss()
# Alternative kept for reference: LOSS_FUNC = nn.CTCLoss()

# No metrics are tracked at the moment; candidate: [metric.ctc_accuracy(ALPHABET, 5)]
METRICS = []

# Factory for the model-saving callback; it is bound to the learner at fit time.
SAVE_CALLBACK = partial(metric.SaveModelCallback, monitor='valid_loss', every='epoch')

## Load data

In [8]:
# Load the training data stored in the feather folder
data = f.load_training_data(FEATHER_FOLDER)

# Build train/validation dataloaders and wrap them in a fastai DataBunch.
# NOTE(review): split=.8 is presumably the training fraction — confirm in
# prep.convert_to_dataloaders.
train_dl, valid_dl = prep.convert_to_dataloaders(
    data,
    split=.8,
    batch_size=BS,
    drop_last=DROP_LAST,
)
databunch = DataBunch(train_dl, valid_dl, device=DEVICE)

## Model

In [9]:
# Optimizer factory handed to the Learner: AdamW with AMSGrad switched on,
# pre-configured with the default learning rate.
optimizer = partial(
    torch.optim.AdamW,
    lr=LR,
    amsgrad=True,
)

In [10]:
# Tie data, model, loss, optimizer factory and metrics together;
# weights are read from / written to MODEL_DIR.
learner = Learner(
    databunch,
    model,
    opt_func=optimizer,
    loss_func=LOSS_FUNC,
    metrics=METRICS,
    model_dir=MODEL_DIR,
)

In [11]:
#m.load_model_weights(learner, SPECIFIC_MODEL_WEIGHTS)

## Train

In [12]:
#learner.lr_find()
#learner.recorder.plot(suggestion=True)

In [13]:
# Use the learning rate suggested by the LR finder when one is available;
# default to LR if lr_find() has not been run.
try:
    lr = learner.recorder.min_grad_lr
except AttributeError:
    # Only the "attribute is missing" case is expected here (no recorder yet, or
    # no suggestion stored on it); any other error should surface instead of
    # being silently swallowed.
    lr = LR

In [None]:
# Train for EPOCHS epochs with the selected learning rate; the save callback is
# instantiated for this learner right before training starts.
learner.fit(
    EPOCHS,
    lr=lr,
    callbacks=[SAVE_CALLBACK(learner)],
)

In [None]:
# Plot the losses recorded by the learner's recorder during the fit above.
learner.recorder.plot_losses()

## Predict

In [25]:
# Build the collection of mapped reads and pick a single read to predict on.
sc = prep.SignalCollection(
    BASE_DIR/constants.MAPPED_READS,
    labels_per_window=(200,400),
    stride=500,
    window_size=(WINDOW_SIZE-1, WINDOW_SIZE),
    blank_id=constants.BLANK_ID,
)
read_object = sc[5]  # read at index 5 of the collection

In [35]:
# Predict signals
x = m.signal_to_input_tensor(read_object.x, DEVICE)
assembled, decoded = m.predict_and_assemble(learner.model, x, ALPHABET_STR, WINDOW_SIZE, 1, beam_size=500, beam_threshold=0.1)

# Score the assembled read against its reference inside this cell. The value
# printed below must not depend on an `accuracy` variable left behind by some
# other cell, which raises NameError on a fresh kernel or reports a stale number.
total_accuracy, total_alignment = m.get_accuracy(read_object.reference, assembled, ALPHABET_VAL)
print(f'Total accuracy: {total_accuracy}')

Total accuracy: 0.5642201834862385


In [50]:
# Score every decoded window against its own label. `alignment` keeps the
# alignment of the most recently scored window so it can be shown afterwards.
accuracies = []
alignment = ""
for index, decoded_window in enumerate(tqdm(decoded)):
    accuracy, alignment = m.get_accuracy(read_object.y[index], decoded_window, ALPHABET_VAL)
    accuracies.append(accuracy)

print(f'Average: {sum(accuracies)/len(accuracies)}')
print('Last window:')
# The first three alignment entries are printed one per line.
print(alignment[0], alignment[1], alignment[2], sep='\n')

100%|██████████| 111/111 [00:00<00:00, 2201.99it/s]

Average: 0.5259627993830994
Last window:
ATCGTGGTGGGAGACAGTGTCAGGCG-GGCAGTTTGACTG-GGGCG-GTCGCCTCCTAAAAG-GTAACGGAGGCGCTCAAAGGTTCCCTCAGAATGGTTG-GAAATCATTCATAGAGTGT-AAAG-GCAT-AAGGGAGCTTGACTGCGAGACCTACA--AGTCGAGCAGGGTCGAAAGACGGACT-TAGTG-ATCCGGTG-GTTCCGCATG-GA-AGGG
||.|||.|  ||||.||||  |.|.| |..|||.||.||| |.|.| || | |||..|.|.| |||  .|||.||.| |.||.|....|..| |||  || ||.| .|  | .||.|||| |.|| |..| .||.||| |.||....||.|..||.|  ||| |||.| |.|||  .||..||.| |..|| ||.|.||| |.|  |.||| || ||.|
ATAGTGAT--GAGATAGTG--ATGTGTGTGAGTGTGTCTGTGTGTGCGT-G-CTCTGAGATGAGTA--TGAGACGAT-AGAGATGATGTGTG-ATG--TGAGAGA-GA--C-GAGTGTGTGATAGTGTGTGTAGAGAG-TAGATATAGATATATAGATGAGT-GAGTA-GATCG--TGAGTGAGTGTGATGTATGCTGTGTGAT--GTATGAGATAGTG
wassaa?





In [40]:
# Accuracy of the fully assembled read versus the reference; the cell output is
# a preview of the first 150 characters of each of the three alignment rows.
accuracy, alignment = m.get_accuracy(read_object.reference, assembled, ALPHABET_VAL)
tuple(row[:150] for row in (alignment[0], alignment[1], alignment[2]))

('GGCGCAGGGCGGGACGGGGCGCGG-GGGGAG-GAGG-GGGGGCCGGCGCGGGCAGGGCACCACAGGCAGCGGCGAGGCAGCGCGCGGGCGGGGAACAAACCGCG-CCGGGGGCAGGCGGAGCGGACAGCGCGGACCAAGAACGGGACACG',
 '||.|.||.|.|.||.||||.|.|| |||||| |||| |||||  ||.|.|||..|||      ||...|.|| ||||.|| |.|.|||.||||...|....|.| ..|||||..||.||.|.|||.||.|.|    |.||  ||||.|..',
 'GGGGGAGAGAGAGAGGGGGGGGGGAGGGGAGTGAGGTGGGGG--GGTGGGGGTGGGG------AGTGTGTGG-GAGGGAG-GGGGGGGAGGGGGGTAGTGTGTGTGAGGGGGGGGGGGGGGGGGAGAGTGAG----AGGA--GGGAGAGT')