In [None]:
# Reload edited imported modules (e.g. the jkbc package) automatically so
# changes are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
from fastai.basics import *
import json

import jkbc.model as m
import jkbc.model.factory as factory
import jkbc.constants as constants
import jkbc.files.fasta as fasta
import jkbc.files.torch_files as f
import jkbc.utils.kd as kd
import jkbc.utils.preprocessing as prep
import jkbc.utils.postprocessing as pop

## Constants

### Data

In [None]:
# Location of the preprocessed feather dataset. Its config.json supplies the
# window-size and label-length settings read below.
BASE_DIR = Path("../..")
PATH_DATA = 'data/feather-files'
DATA_SET = 'Range0-50-FixLabelLen400-winsize4096'
FEATHER_FOLDER = BASE_DIR/PATH_DATA/DATA_SET

with open(FEATHER_FOLDER/'config.json', 'r') as fp:
    config = json.load(fp)

ALPHABET       = constants.ALPHABET
ALPHABET_VAL   = list(ALPHABET.values())
ALPHABET_STR   = ''.join(ALPHABET_VAL)
ALPHABET_SIZE  = len(ALPHABET)
WINDOW_SIZE    = int(config['maxw'])  # maxw = max window size
DIMENSIONS_OUT = int(config['maxl'])  # maxl = max label length
MIN_LABEL_LEN  = int(config['minl'])  # minl = min label length
STRIDE         = WINDOW_SIZE          # stride equals window size (presumably non-overlapping windows)

### Train/Predict

In [4]:
# Batch size used for the teacher-output pass (2**3).
BS = 8
# Device returned by the project helper; swap in torch.device("cpu") to force CPU.
DEVICE = m.get_available_gpu()
DEVICE

device(type='cuda', index=4)

### Model

In [5]:
# Build the model via factory.bonito; the factory also hands back the model's name.
model, (MODEL_NAME, _) = factory.bonito(WINDOW_SIZE, DEVICE, BASE_DIR)
MODEL_NAME = str(MODEL_NAME)  # normalise to a plain string for path building
MODEL_DIR = MODEL_NAME + '/weights'
# Checkpoint to load; a falsy value makes the next cell use the newest checkpoint in MODEL_DIR.
model_weights = 'bestmodel_8'
# Wrap the model for inference and convert it with .to_fp16().
predicter = m.get_predicter(model, MODEL_DIR, DEVICE)
predicter = predicter.to_fp16()

In [6]:
# Use the explicitly configured checkpoint, or fall back to the newest one in MODEL_DIR.
model_weights = model_weights or m.get_newest_model(MODEL_DIR)

predicter.load(model_weights)
model_weights

'bestmodel_8'

## Save teacher output

In [None]:
# Load the training data from the feather folder, then run the teacher model
# (predicter.model) over its signal in batches of BS. FEATHER_FOLDER and
# MODEL_NAME are passed to kd, presumably to decide where the outputs are saved.
data = f.load_training_data(FEATHER_FOLDER)
kd.generate_and_save_y_teacher(
    FEATHER_FOLDER,
    MODEL_NAME,
    m.signal_to_input_tensor(data.x, DEVICE),
    predicter.model,
    bs=BS,
)

## Get accuracies

In [9]:
# Collection of mapped reads to evaluate on. The tuple arguments look like
# (min, max) bounds for labels and window size. NOTE(review): confirm in prep.SignalCollection.
sc = prep.SignalCollection(
    BASE_DIR/constants.MAPPED_READS,
    BASE_DIR/constants.BACTERIA_DICT_PATH,
    labels_per_window=(MIN_LABEL_LEN, DIMENSIONS_OUT),
    window_size=(WINDOW_SIZE - 1, WINDOW_SIZE),
    stride=STRIDE,
    blank_id=constants.BLANK_ID,
)

In [14]:
# Predict on a random sample of reads. Indices are drawn without replacement
# (no read is scored twice) from a seeded generator, so the sample and the
# resulting accuracies are reproducible across runs.
N_PREDICT = 1000    # number of reads to sample
PREDICT_SEED = 42   # seed for the read sample
predict_indexes = np.random.RandomState(PREDICT_SEED).choice(
    len(sc), size=min(N_PREDICT, len(sc)), replace=False)
predict_objects = m.predict_range(predicter.model, sc, ALPHABET_STR, WINDOW_SIZE, DEVICE,
                                  indexes=predict_indexes)


  0%|          | 0/1000 [00:00<?, ?it/s][A
  0%|          | 1/1000 [00:01<16:58,  1.02s/it][A
  0%|          | 2/1000 [00:01<16:27,  1.01it/s][A
  0%|          | 3/1000 [00:02<12:23,  1.34it/s][A
  0%|          | 4/1000 [00:03<13:16,  1.25it/s][A
  0%|          | 5/1000 [00:03<13:42,  1.21it/s][A
  1%|          | 6/1000 [00:04<13:28,  1.23it/s][A
  1%|          | 7/1000 [00:05<11:18,  1.46it/s][A
  1%|          | 8/1000 [00:05<12:14,  1.35it/s][A
  1%|          | 9/1000 [00:06<13:32,  1.22it/s][A
  1%|          | 10/1000 [00:07<10:32,  1.57it/s][A
  1%|          | 11/1000 [00:08<12:33,  1.31it/s][A
  1%|          | 12/1000 [00:09<13:41,  1.20it/s][A
  1%|▏         | 13/1000 [00:10<14:31,  1.13it/s][A
  1%|▏         | 14/1000 [00:11<16:14,  1.01it/s][A
  2%|▏         | 15/1000 [00:12<14:13,  1.15it/s][A
  2%|▏         | 16/1000 [00:13<17:48,  1.09s/it][A
  2%|▏         | 17/1000 [00:15<19:08,  1.17s/it][A
  2%|▏         | 18/1000 [00:15<15:26,  1.06it/s][A
  2%|▏    

Error: Skipping index 350713. No windows in signal



  6%|▌         | 61/1000 [00:50<08:24,  1.86it/s][A
  6%|▌         | 62/1000 [00:51<09:58,  1.57it/s][A
  6%|▋         | 63/1000 [00:52<11:02,  1.41it/s][A
  6%|▋         | 64/1000 [00:53<10:57,  1.42it/s][A

Error: Skipping index 407922. No windows in signal



  7%|▋         | 66/1000 [00:54<10:10,  1.53it/s][A
  7%|▋         | 67/1000 [00:54<09:02,  1.72it/s][A

Error: Skipping index 152898. No windows in signal



  7%|▋         | 69/1000 [00:55<08:10,  1.90it/s][A
  7%|▋         | 70/1000 [00:55<06:58,  2.22it/s][A
  7%|▋         | 71/1000 [00:56<09:10,  1.69it/s][A
  7%|▋         | 72/1000 [00:56<07:02,  2.19it/s][A
  7%|▋         | 73/1000 [00:57<06:41,  2.31it/s][A
  7%|▋         | 74/1000 [00:58<08:10,  1.89it/s][A
  8%|▊         | 75/1000 [00:59<10:48,  1.43it/s][A
  8%|▊         | 76/1000 [00:59<08:33,  1.80it/s][A

KeyboardInterrupt: 

In [None]:
# Gather the accuracies from m.get_accuracies and the decoded
# reference/prediction mappings from fasta.map_decoded for every prediction object.
references = []
predictions = []
accuracies = []
for po in predict_objects:
    accuracies.extend(m.get_accuracies(po.references, po.predictions, ALPHABET_VAL))
    # NOTE(review): meaning of the positional False flag is not visible here; confirm in fasta.map_decoded
    ref, pred = fasta.map_decoded(po, ALPHABET_VAL, False)
    references.append(ref)
    predictions.append(pred)

# Sanity check: mean accuracy over everything collected above
np.mean(accuracies)

In [None]:
# Merge the per-object results into single dicts, references first, then predictions.
ref_dict, pred_dict = (fasta.merge(parts) for parts in (references, predictions))

In [None]:
# Save predictions and references via fasta.save_dicts under the
# <MODEL_NAME>/predictions/<DATA_SET> path.
fasta.save_dicts(
    pred_dict,
    ref_dict,
    f'{MODEL_NAME}/predictions/{DATA_SET}',
)