In [1]:
from fastai.basics import *
import json
import wandb

import jkbc.constants as constants
import jkbc.files.torch_files as f
import jkbc.utils.preprocessing as prep
import jkbc.utils.postprocessing as pop
import jkbc.utils as utils
from tqdm import tqdm 


# Base directory two levels up from the notebook; data paths below are built from it.
BASE_DIR = Path('..', '..')
# Folder (relative to BASE_DIR) holding the feather data sets.
PATH_DATA = 'data/feather-files'
# Weights & Biases team/project whose runs are evaluated.
PROJECT = 'jk-basecalling-v2'
TEAM = 'jkbc'
# Device handed to the model factory when rebuilding models.
DEVICE = torch.device('cuda')

In [2]:
def predict(model, data_loader, alphabet: str, beam_size = 25, beam_threshold=0.1, device='cuda'):
    """Run ``model`` over every batch of ``data_loader`` and beam-decode the outputs.

    Args:
        model: network to evaluate. It is switched to eval mode and left there.
        data_loader: iterable yielding ``(inputs, (target, _, _))`` batches.
        alphabet: alphabet string forwarded to ``pop.decode``.
        beam_size: beam width forwarded to ``pop.decode``.
        beam_threshold: forwarded to ``pop.decode`` as ``threshold``.
        device: device the input batches are moved to before the forward pass.
            Defaults to ``'cuda'``, matching the previous hard-coded behaviour.

    Returns:
        ``(predictions, labels)``: two lists with one entry per batch. Each entry
        holds the decoded output of ``pop.decode`` and the batch's target,
        respectively.
    """
    predictions = []
    labels = []
    model.eval()
    # Pure inference: disable autograd so no graph or activations are kept per batch.
    with torch.no_grad():
        for inputs, (target, _, _) in data_loader:
            pred = model(inputs.to(device)).cpu()
            decoded = pop.decode(pred, alphabet, beam_size=beam_size, threshold=beam_threshold)

            predictions.append(decoded)
            labels.append(target)

    return predictions, labels

In [3]:
def map_decoded(x, y, alphabet_values):
    """Pair every decoded prediction with its reference under a shared key.

    ``x[batch][index]`` is taken as the prediction. ``y[batch][index]`` is the
    matching label, which is turned into a base string with
    ``pop.convert_idx_to_base_sequence(label, alphabet_values)``.

    Returns:
        ``(references, predictions)``: two dicts keyed by ``'#<batch>#<index>#'``.
    """
    references, predictions = {}, {}
    for batch_no, decoded_batch in enumerate(x):
        for item_no, decoded in enumerate(decoded_batch):
            key = f'#{batch_no}#{item_no}#'
            references[key] = pop.convert_idx_to_base_sequence(y[batch_no][item_no], alphabet_values)
            predictions[key] = decoded
    return references, predictions

In [4]:
def save(model_id, data_name, labels, predictions, out_dir='predictions'):
    """Write predictions and their references to FASTA via ``jkbc.files.fasta``.

    Args:
        model_id: W&B run id; it becomes part of the output name.
        data_name: data set name; it becomes part of the output name.
        labels: per-batch label tensors, as returned by ``predict``.
        predictions: per-batch decoded outputs, as returned by ``predict``.
        out_dir: directory the output is written under. It is created if it
            does not exist. Defaults to ``'predictions'``, the previous
            hard-coded location.
    """
    import os
    import jkbc.files.fasta as fasta

    reference_dict, prediction_dict = map_decoded(predictions, labels, list(constants.ALPHABET.values()))

    # Ensure the target directory exists so the write does not fail on a fresh checkout.
    os.makedirs(out_dir, exist_ok=True)
    fasta.save_dicts(prediction_dict, reference_dict, f'{out_dir}/{model_id}-{data_name}')

In [5]:
def get_config(run_path, root):
    """Fetch ``config.yaml`` for a W&B run and return it parsed.

    Args:
        run_path: ``"<team>/<project>/<run id>"`` path passed to ``wandb.restore``.
        root: local directory ``wandb.restore`` places the file in.

    Returns:
        The YAML document, parsed with ``yaml.FullLoader``.

    Raises:
        FileNotFoundError: if ``wandb.restore`` returns ``None``, i.e. the run
            has no ``config.yaml``.
    """
    # Explicit import rather than relying on the notebook's star import to provide yaml.
    import yaml

    restored = wandb.restore('config.yaml', run_path=run_path, replace=True, root=root)
    if restored is None:
        raise FileNotFoundError(f'config.yaml not found for run {run_path}')
    # wandb.restore returns an open handle; only its path is used, so close it to avoid a leak.
    restored.close()
    with open(restored.name, 'r') as config_file:
        return yaml.load(config_file, Loader=yaml.FullLoader)


def get_window_size(data_set):
    """Return the maximum window size recorded in a data set's ``config.json``.

    Args:
        data_set: directory (str or path-like) that contains ``config.json``.

    Returns:
        int: the ``'maxw'`` (max window size) entry, converted with ``int``.
    """
    with open(f'{data_set}/config.json', 'r') as fp:
        data_config = json.load(fp)
    # Other entries such as 'maxl', 'minl' and 's' were read here previously but are not needed.
    return int(data_config['maxw'])

In [6]:
def get_model(config, window_size, device, run_path, root):
    """Rebuild a run's Bonito-style model and load its ``bestmodel.pth`` weights.

    Args:
        config: parsed run config, as returned by ``get_config``. It must
            contain ``model_params``.
        window_size: input window size passed to ``factory.bonito``.
        device: device the model is built on and loaded with.
        run_path: ``"<team>/<project>/<run id>"`` path passed to ``wandb.restore``.
        root: local directory ``wandb.restore`` places the weights in.

    Returns:
        The predicter's ``model`` with the restored weights loaded.

    Raises:
        FileNotFoundError: if ``wandb.restore`` returns ``None`` for ``bestmodel.pth``.
    """
    import jkbc.model as m
    import jkbc.model.factory as factory
    import jkbc.utils.bonito.tune as bonito

    # Architecture from the run's stored hyper-parameters.
    model_params = utils.get_nested_dict(config, 'model_params')['value']
    model_config = bonito.get_bonito_config(model_params, double_kernel_sizes=False)

    model, _ = factory.bonito(window_size, device, model_config)
    predicter = m.get_predicter(model, device, '')

    weights = wandb.restore('bestmodel.pth', run_path=run_path, replace=True, root=root)
    if weights is None:
        raise FileNotFoundError(f'bestmodel.pth not found for run {run_path}')
    # Only the file path is needed; close the handle wandb.restore opened.
    weights.close()
    # fastai requires the name without .pth
    model_weights = '.'.join(weights.name.split('.')[:-1])
    predicter.load(model_weights)

    return predicter.model

In [7]:
# Runs evaluated in earlier sessions, kept for reference.
previous_ids = [
    '2j9fzbx4', #swept-durian-82 # bonito
    '5916pnqr', #sleek-serenity-251 - stupid

    '2eiadj4y', #eternal-deluge-448
    '1ywu3vo9', #breezy-cosmos-408
    '2d84exku', #scarlet-sound-417
    'j6f2sn3v', #vibrant-puddle-433
    '1c2vr2my'  #playful-oath-434
]

# Runs to evaluate in this session.
ids = ['117mrxzu']
# (name, directory) pairs; every data set is evaluated against every run in `ids`.
data_set = [('all', BASE_DIR/PATH_DATA/'all-other'), ('Bacillus', BASE_DIR/PATH_DATA/'bacillus')]
batch_size = 64
BEAM_SIZE = 25
BEAM_THRESHOLD = 0.1
# If True, record (run_id, data set, error) in `errors` and carry on; otherwise stop at the first failure.
SKIP_FAILED_RUNS = False
alphabet = ''.join(constants.ALPHABET.values())

errors = []
for name, path in data_set:
    window_size = get_window_size(path)
    data = f.load_training_data(path)
    # NOTE(review): drop_last=True presumably drops the final partial batch, so up to
    # batch_size-1 samples would go unevaluated. Confirm this is intended for a test set.
    test_dl, _ = prep.convert_to_dataloaders(data, split=1, batch_size=batch_size, drop_last=True)
    for run_id in tqdm(ids, desc=name):
        try:
            run_path = f"{TEAM}/{PROJECT}/{run_id}"
            root = f'wandb/{run_id}'

            config = get_config(run_path, root)
            model = get_model(config, window_size, DEVICE, run_path, root)

            predictions, labels = predict(model, test_dl, alphabet,
                                          beam_size=BEAM_SIZE, beam_threshold=BEAM_THRESHOLD)

            save(run_id, name, labels, predictions)
        except Exception as exc:
            if not SKIP_FAILED_RUNS:
                raise
            errors.append((run_id, name, repr(exc)))
print(errors)

100%|██████████| 1/1 [23:22<00:00, 1402.59s/it]
100%|██████████| 1/1 [22:39<00:00, 1359.95s/it]

[]





In [8]:
# Shell out to nvidia-smi to inspect GPU state after the evaluation loop.
!nvidia-smi

Thu Jun  4 12:02:42 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.116.00   Driver Version: 418.116.00   CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla V100-SXM3...  On   | 00000000:B7:00.0 Off |                    0 |
| N/A   34C    P0    66W / 350W |  29439MiB / 32480MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|    0  