In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai.basics import *
import json

import jkbc.model as m
import jkbc.model.factory as factory
import jkbc.utils.constants as constants
import jkbc.utils.torch_files as f
import jkbc.utils.preprocessing as prep
import jkbc.utils.postprocessing as pop
import jkbc.utils.fasta as fasta

import jkbc.utils.kd as kd

## Constants

### Data

In [5]:
BASE_DIR = Path("../..")
PATH_DATA = 'data/feather-files'
DATA_SET = 'Range0-50-FixLabelLen400-winsize4096'
FEATHER_FOLDER = BASE_DIR/PATH_DATA/DATA_SET

with open(FEATHER_FOLDER/'config.json', 'r') as fp:
    config = json.load(fp)

ALPHABET       = constants.ALPHABET
ALPHABET_VAL   = list(ALPHABET.values())
ALPHABET_STR   = ''.join(ALPHABET_VAL)
ALPHABET_SIZE  = len(ALPHABET.keys())
WINDOW_SIZE    = int(config['maxw']) #maxw = max windowsize
DIMENSIONS_OUT = int(config['maxl']) # maxl = max label length
MIN_LABEL_LEN  = int(config['minl']) # maxl = max label length
STRIDE         = WINDOW_SIZE

### Train/Predict

In [6]:
LR = 1e-3  # default learning rate
BS = 2**3  # batch size
EPOCHS = 400
DEVICE = torch.device("cuda:0") #torch.device("cpu")

### Model

In [9]:
DIMENSIONS_PREDICTION_OUT = WINDOW_SIZE//3+1
DROP_LAST = False # SET TO TRUE IF IT FAILS ON LAST BATCH
model, MODEL_NAME = factory.bonito(DEVICE, BASE_DIR)
MODEL_NAME = f'{MODEL_NAME}'
MODEL_DIR = f'{MODEL_NAME}/weights'
model_weights = 'bestmodel_8'
predicter = m.get_predicter(model, MODEL_DIR, DEVICE)

In [10]:
if not model_weights:
    model_weights = m.get_newest_model(MODEL_DIR)

predicter.load(model_weights);

## Save teacher output

In [None]:
# Read data from feather
data = f.load_training_data(FEATHER_FOLDER) 
kd.generate_and_save_y_teacher(FEATHER_FOLDER, MODEL_NAME, m.signal_to_input_tensor(data.x, DEVICE), predicter.model, bs=BS)

## Get accuracies

In [13]:
sc = prep.SignalCollection(BASE_DIR/constants.MAPPED_READS, labels_per_window=(MIN_LABEL_LEN,DIMENSIONS_OUT), 
                           stride=STRIDE, window_size=(WINDOW_SIZE-1, WINDOW_SIZE), blank_id=constants.BLANK_ID)

In [14]:
# predict range
predict_objects = m.predict_range(predicter.model, sc, ALPHABET_STR, WINDOW_SIZE, DEVICE, indexes=np.random.randint(0, len(sc), 50))

  2%|▏         | 1/50 [00:00<00:41,  1.17it/s]

Error: Skipping index 33119 (probably because it was empty)


 44%|████▍     | 22/50 [00:15<00:12,  2.23it/s]

Error: Skipping index 420565 (probably because it was empty)


 96%|█████████▌| 48/50 [00:38<00:01,  1.09it/s]

Error: Skipping index 65986 (probably because it was empty)


100%|██████████| 50/50 [00:39<00:00,  1.28it/s]


In [15]:
# convert outputs to dictionaries
references = []
predictions = []
accuracies = []
for po in predict_objects:
    [accuracies.append(a) for a in m.get_accuracies(po.references, po.predictions, ALPHABET_VAL)]
    ref, pred = fasta.map_decoded(po, ALPHABET_VAL, False)
    references.append(ref)
    predictions.append(pred)

# Sanity tjek
np.mean(accuracies)

12it [00:00, 1771.49it/s]
100%|██████████| 12/12 [00:00<00:00, 5426.01it/s]
11it [00:00, 1419.44it/s]
100%|██████████| 11/11 [00:00<00:00, 3819.00it/s]
12it [00:00, 1719.15it/s]
100%|██████████| 12/12 [00:00<00:00, 4136.05it/s]
8it [00:00, 1445.25it/s]
100%|██████████| 8/8 [00:00<00:00, 3484.73it/s]
12it [00:00, 1893.09it/s]
100%|██████████| 12/12 [00:00<00:00, 4151.41it/s]
9it [00:00, 1670.15it/s]
100%|██████████| 9/9 [00:00<00:00, 3647.93it/s]
1it [00:00, 625.64it/s]
100%|██████████| 1/1 [00:00<00:00, 2392.64it/s]
13it [00:00, 1795.80it/s]
100%|██████████| 13/13 [00:00<00:00, 5419.00it/s]
13it [00:00, 2032.58it/s]
100%|██████████| 13/13 [00:00<00:00, 4257.51it/s]
14it [00:00, 2061.73it/s]
100%|██████████| 14/14 [00:00<00:00, 5235.87it/s]
12it [00:00, 1814.54it/s]
100%|██████████| 12/12 [00:00<00:00, 3821.11it/s]
14it [00:00, 2005.95it/s]
100%|██████████| 14/14 [00:00<00:00, 5437.56it/s]
1it [00:00, 571.35it/s]
100%|██████████| 1/1 [00:00<00:00, 1277.58it/s]
7it [00:00, 1673.04it/s]
1

0.8892272517231213

In [16]:
# make dicts ready to be saved
ref_dict = fasta.merge(references)
pred_dict = fasta.merge(predictions)

In [19]:
# save dicts
fasta.save_dicts(pred_dict, ref_dict, f'{MODEL_NAME}/predictions/{DATA_SET}')