In [215]:
import numpy as np
import json
import warnings
import operator

import h5py
from keras.models import model_from_json
from keras import backend as K
from keras.utils import get_custom_objects

warnings.filterwarnings("ignore")

size_title = 18
size_label = 14
n_pred = 2


def read_file(file_path):
    with open(file_path, 'r') as data_file:
        data = json.loads(data_file.read())
    return data

def create_model(model_path):
    
    reverse_dictionary = dict((str(v), k) for k, v in dictionary.items())
    model_weights = list()
    weight_ctr = 0
    while True:
        try:
            d_key = "weight_" + str(weight_ctr)
            weights = trained_model.get(d_key).value
            model_weights.append(weights)
            weight_ctr += 1
        except Exception as exception:
            break
    # set the model weights
    loaded_model.set_weights(model_weights)
    return loaded_model, dictionary, reverse_dictionary, compatibile_tools


def verify_model(model, tool_sequence, labels, dictionary, reverse_dictionary, compatible_tools, class_weights, topk=20, max_seq_len=25):
    tl_seq = tool_sequence.split(",")
    last_tool_name = reverse_dictionary[str(tl_seq[-1])]
    last_compatible_tools = compatible_tools[last_tool_name]
    sample = np.zeros(max_seq_len)
    for idx, tool_id in enumerate(tl_seq):
        sample[idx] = int(tool_id)
    sample_reshaped = np.reshape(sample, (1, max_seq_len))

    tool_sequence_names = [reverse_dictionary[str(tool_pos)] for tool_pos in tool_sequence.split(",")]
    
    # predict next tools for a test path
    prediction = model.predict(sample_reshaped, verbose=0)
    
    weight_val = list(class_weights.values())
    weight_val = np.reshape(weight_val, (len(weight_val),))
    
    prediction = np.reshape(prediction, (prediction.shape[1],))
    
    prediction_pos = np.argsort(prediction, axis=-1)

    # get topk prediction
    topk_prediction_pos = prediction_pos[-topk:]
    
    topk_prediction_val = [int(prediction[pos] * 100) for pos in topk_prediction_pos]
    
    topk_prediction_val = [(val * 100) / np.max(topk_prediction_val) for val in topk_prediction_val]
    
    # read tool names using reverse dictionary
    pred_tool_ids = [reverse_dictionary[str(tool_pos)] for tool_pos in topk_prediction_pos if tool_pos > 0]
    actual_next_tool_ids = list(set(pred_tool_ids).intersection(set(last_compatible_tools.split(","))))

    pred_tool_ids_sorted = dict()
    for (tool_pos, tool_pred_val) in zip(topk_prediction_pos, topk_prediction_val):
        try:
            tool_name = reverse_dictionary[str(tool_pos)]
            if tool_name not in last_tool_name and tool_name in actual_next_tool_ids: #tool_name in actual_next_tool_ids and 
                pred_tool_ids_sorted[tool_name] = tool_pred_val
        except:
            continue
    pred_tool_ids_sorted = dict(sorted(pred_tool_ids_sorted.items(), key=lambda kv: kv[1], reverse=True))
    
    cls_wt = dict()
    usg_wt = dict()
    inv_wt = dict()
    ids_tools = dict()
    keys = list(pred_tool_ids_sorted.keys())
    for k in keys:
        try:
            cls_wt[k] = np.round(class_weights[str(data_dict[k])], 2)
            usg_wt[k] = np.round(usage_weights[k], 2)
            inv_wt[k] = np.round(inverted_weights[str(data_dict[k])], 2)
        except:
            continue
    print("Predicted tools: \n")
    print(pred_tool_ids_sorted)
    print()
    print("Class weights: \n")
    cls_wt = dict(sorted(cls_wt.items(), key=lambda kv: kv[1], reverse=True))
    print(cls_wt)
    print()
    print("Usage weights: \n")
    usg_wt = dict(sorted(usg_wt.items(), key=lambda kv: kv[1], reverse=True))
    print(usg_wt)
    print()
    total_usage_wt = np.mean(list(usg_wt.values()))
    print("Mean usage wt: %0.4f" % (total_usage_wt))
    print()
    print("Inverted weights: \n")
    inv_wt = dict(sorted(inv_wt.items(), key=lambda kv: kv[1], reverse=True))
    print(inv_wt)
    for key in pred_tool_ids_sorted:
        ids_tools[key] = dictionary[key]
    print()
    print("Tool ids")
    print(ids_tools)
    print("======================================")
    return cls_wt, usg_wt, inv_wt, pred_tool_ids_sorted

base_path = "data/models/"

model_path = base_path + "model_rnn_custom_loss.hdf5"

trained_model = h5py.File(model_path, 'r')
model_config = json.loads(trained_model.get('model_config').value)
class_weights = json.loads(trained_model.get('class_weights').value)
    
loaded_model = model_from_json(model_config)
dictionary = json.loads(trained_model.get('data_dictionary').value)
compatibile_tools = json.loads(trained_model.get('compatible_tools').value)
best_params = json.loads(trained_model.get('best_parameters').value)

model, dictionary, reverse_dictionary, compatibile_tools = create_model(model_path)

print(reverse_dictionary)

{'1': 'qual_stats_boxplot', '2': 'mimodd_convert', '3': 'fastq_join', '4': 'scanpy_find_markers', '5': 'glimmer_knowlegde-based', '6': 'flexbar', '7': 'mtbls520_09_species_venn', '8': 'rseqc_geneBody_coverage', '9': 'rsem_calculate_expression', '10': 'mycrobiota-qc-report', '11': 'tmhmm2', '12': 'mothur_count_seqs', '13': 'bedtools_unionbedgraph', '14': 'ConsensusID', '15': 'vt_normalize', '16': 'infernal_cmbuild', '17': 'bcftools_mpileup', '18': 'mothur_classify_rf', '19': 'bg_statistical_hypothesis_testing', '20': 'deeptools_bamCoverage', '21': 'align_back_trans', '22': 'MultiplexResolver', '23': 'transpose', '24': 'fasta_merge_files_and_filter_unique_sequences', '25': 'mz_to_sqlite', '26': 'CONVERTER_fasta_to_tabular', '27': 'ConsensusMapNormalizer', '28': 'velvetg_jgi', '29': 'seurat_read10x', '30': 'mycrobiota-krona-mothur', '31': 'mothur_classify_seqs', '32': 'prokaryotic_ncbi_submission', '33': 'ip_scale_image', '34': 'sklear_numeric_clustering', '35': 'gemini_pathways', '36': '

In [219]:
topk = 30
tool_seq = "823,833,1222,1025"
verify_model(model, tool_seq, "", dictionary, reverse_dictionary, compatibile_tools, class_weights, topk)

Predicted tools: 

{'unipept': 100.0, 'addValue': 100.0, 'collection_column_join': 100.0, 'Summary_Statistics1': 100.0, 'Grep1': 100.0, 'htseq_count': 100.0, 'ggplot2_point': 100.0, 'tp_replace_in_line': 100.0, 'datamash_ops': 100.0, 'Cut1': 100.0, 'join_files_on_column_fuzzy': 100.0, 'wc_gnu': 100.0, 'Paste1': 100.0, 'cardinal_classification': 100.0, 'cshl_awk_tool': 100.0, 'tp_easyjoin_tool': 100.0, 'mass_spectrometry_imaging_filtering': 100.0, 'tp_head_tool': 100.0, 'cat1': 100.0, 'cor2': 100.0, 'tab2fasta': 100.0, 'eggnog_mapper': 100.0, 'tp_sort_header_tool': 100.0, 'Show beginning1': 100.0, 'Grouping1': 100.0, 'trim_galore': 100.0, 'get_sequences': 100.0, 'Filter1': 100.0, 'tp_replace_in_column': 100.0, 'edger': 100.0}

Class weights: 

{}

Usage weights: 

{}

Mean usage wt: nan

Inverted weights: 

{}

Tool ids
{'unipept': 666, 'addValue': 44, 'collection_column_join': 831, 'Summary_Statistics1': 708, 'Grep1': 641, 'htseq_count': 717, 'ggplot2_point': 662, 'tp_replace_in_line':

({},
 {},
 {},
 {'unipept': 100.0,
  'addValue': 100.0,
  'collection_column_join': 100.0,
  'Summary_Statistics1': 100.0,
  'Grep1': 100.0,
  'htseq_count': 100.0,
  'ggplot2_point': 100.0,
  'tp_replace_in_line': 100.0,
  'datamash_ops': 100.0,
  'Cut1': 100.0,
  'join_files_on_column_fuzzy': 100.0,
  'wc_gnu': 100.0,
  'Paste1': 100.0,
  'cardinal_classification': 100.0,
  'cshl_awk_tool': 100.0,
  'tp_easyjoin_tool': 100.0,
  'mass_spectrometry_imaging_filtering': 100.0,
  'tp_head_tool': 100.0,
  'cat1': 100.0,
  'cor2': 100.0,
  'tab2fasta': 100.0,
  'eggnog_mapper': 100.0,
  'tp_sort_header_tool': 100.0,
  'Show beginning1': 100.0,
  'Grouping1': 100.0,
  'trim_galore': 100.0,
  'get_sequences': 100.0,
  'Filter1': 100.0,
  'tp_replace_in_column': 100.0,
  'edger': 100.0})

In [220]:
class_weights["666"]

2.927592911295817