In [12]:
import numpy as np
import json
import warnings
import operator

import h5py
from keras.models import model_from_json
from keras import backend as K
from keras.utils import get_custom_objects

warnings.filterwarnings("ignore")

size_title = 18
size_label = 14
n_pred = 2


def read_file(file_path):
    with open(file_path, 'r') as data_file:
        data = json.loads(data_file.read())
    return data

def create_model(model_path):
    
    reverse_dictionary = dict((str(v), k) for k, v in dictionary.items())
    model_weights = list()
    weight_ctr = 0
    while True:
        try:
            d_key = "weight_" + str(weight_ctr)
            weights = trained_model.get(d_key).value
            model_weights.append(weights)
            weight_ctr += 1
        except Exception as exception:
            break
    # set the model weights
    loaded_model.set_weights(model_weights)
    return loaded_model, dictionary, reverse_dictionary, compatibile_tools


def verify_model(model, tool_sequence, labels, dictionary, reverse_dictionary, compatible_tools, class_weights, topk=20, max_seq_len=25):
    tl_seq = tool_sequence.split(",")
    last_tool_name = reverse_dictionary[str(tl_seq[-1])]
    last_compatible_tools = compatible_tools[last_tool_name]
    sample = np.zeros(max_seq_len)
    for idx, tool_id in enumerate(tl_seq):
        sample[idx] = int(tool_id)
    sample_reshaped = np.reshape(sample, (1, max_seq_len))

    tool_sequence_names = [reverse_dictionary[str(tool_pos)] for tool_pos in tool_sequence.split(",")]
    
    # predict next tools for a test path
    prediction = model.predict(sample_reshaped, verbose=0)
    
    weight_val = list(class_weights.values())
    weight_val = np.reshape(weight_val, (len(weight_val),))
    
    prediction = np.reshape(prediction, (prediction.shape[1],))
    
    #prediction = prediction * weight_val
    
    prediction = prediction / float(np.max(prediction))
    
    prediction_pos = np.argsort(prediction, axis=-1)

    # get topk prediction
    topk_prediction_pos = prediction_pos[-topk:]
    topk_prediction_val = [int(prediction[pos] * 100) for pos in topk_prediction_pos]
    
    # read tool names using reverse dictionary
    pred_tool_ids = [reverse_dictionary[str(tool_pos)] for tool_pos in topk_prediction_pos if tool_pos > 0]
    actual_next_tool_ids = list(set(pred_tool_ids).intersection(set(last_compatible_tools.split(","))))

    pred_tool_ids_sorted = dict()
    for (tool_pos, tool_pred_val) in zip(topk_prediction_pos, topk_prediction_val):
        try:
            tool_name = reverse_dictionary[str(tool_pos)]
            if tool_name not in last_tool_name and tool_name in actual_next_tool_ids:
                pred_tool_ids_sorted[tool_name] = tool_pred_val
        except:
            continue
    pred_tool_ids_sorted = dict(sorted(pred_tool_ids_sorted.items(), key=lambda kv: kv[1], reverse=True))
    
    cls_wt = dict()
    usg_wt = dict()
    inv_wt = dict()
    ids_tools = dict()
    keys = list(pred_tool_ids_sorted.keys())
    for k in keys:
        try:
            cls_wt[k] = np.round(class_weights[str(data_dict[k])], 2)
            usg_wt[k] = np.round(usage_weights[k], 2)
            inv_wt[k] = np.round(inverted_weights[str(data_dict[k])], 2)
        except:
            continue
    print("Predicted tools: \n")
    print(pred_tool_ids_sorted)
    print()
    print("Class weights: \n")
    cls_wt = dict(sorted(cls_wt.items(), key=lambda kv: kv[1], reverse=True))
    print(cls_wt)
    print()
    print("Usage weights: \n")
    usg_wt = dict(sorted(usg_wt.items(), key=lambda kv: kv[1], reverse=True))
    print(usg_wt)
    print()
    total_usage_wt = np.mean(list(usg_wt.values()))
    print("Mean usage wt: %0.4f" % (total_usage_wt))
    print()
    print("Inverted weights: \n")
    inv_wt = dict(sorted(inv_wt.items(), key=lambda kv: kv[1], reverse=True))
    print(inv_wt)
    for key in pred_tool_ids_sorted:
        ids_tools[key] = dictionary[key]
    print()
    print("Tool ids")
    print(ids_tools)
    print("======================================")
    return cls_wt, usg_wt, inv_wt, pred_tool_ids_sorted

#base_path = "data/rnn_custom_loss/complete_training/"
base_path = "data/models/"

#model_path = base_path + "trained_model_19_09_1.h5"
model_path = base_path + "model_rnn_custom_loss_19_03.hdf5"

trained_model = h5py.File(model_path, 'r')
model_config = json.loads(trained_model.get('model_config').value)
class_weights = json.loads(trained_model.get('class_weights').value)
    
loaded_model = model_from_json(model_config)
dictionary = json.loads(trained_model.get('data_dictionary').value)
compatibile_tools = json.loads(trained_model.get('compatible_tools').value)
best_params = json.loads(trained_model.get('best_parameters').value)

model, dictionary, reverse_dictionary, compatibile_tools = create_model(model_path)

print(reverse_dictionary)

{'1': 'gemini_comp_hets', '2': 'nanoplot', '3': 'mothur_screen_seqs', '4': 'rbc_mirdeep2_quantifier', '5': 'deeptools_bamCoverage', '6': 'abims_xcms_fillPeaks', '7': 'fasta2tab', '8': 'cshl_grep_tool', '9': 'wc_gnu', '10': 'Show tail1', '11': 'sklearn_build_pipeline', '12': 'bamCompare_deepTools', '13': 'fastq_to_fasta_python', '14': 'tp_sort_header_tool', '15': 'plotly_regression_performance_plots', '16': 'IDScoreSwitcher', '17': 'varscan_mpileup', '18': 'bedtools_multiintersectbed', '19': 'gemini_recessive_and_dominant', '20': 'ggplot2_heatmap2', '21': 'varscan_somatic', '22': 'snpsift_vartype', '23': 'blastxml_to_tabular_selectable', '24': 'ctb_remSmall', '25': 'egsea', '26': 'tp_split_on_column', '27': 'Heatmap', '28': 'gafa', '29': 'enhanced_bowtie_wrapper', '30': 'deeptools_correct_gc_bias', '31': 'samtools_slice_bam', '32': 'glimmer_build-icm', '33': 'eden_vectorizer', '34': 'proteomics_search_protein_prophet_1', '35': 'Cut1', '36': 'sam_bw_filter', '37': 'rseqc_junction_saturat

In [19]:
topk = 30
tool_seq = "587"
class_wt, usage_wt, inverse_wt, pred_tools = verify_model(model, tool_seq, "", dictionary, reverse_dictionary, compatibile_tools, class_weights, topk)

Predicted tools: 

{'Count1': 100, 'bedtools_mergebed': 100, 'datamash_ops': 100, 'comp1': 100, 'tp_replace_in_line': 100, 'addValue': 100, 'collapse_dataset': 100, 'deeptools_compute_matrix': 100, 'dexseq_annotate': 100, 'samtools_flagstat': 100, 'trimmer': 100, 'venn_list': 100, 'cat1': 100, 'Paste1': 100, 'tabular_to_fastq': 100, 'smooth_running_window': 100, 'vegan_rarefaction': 100, 'wig_to_bigWig': 100, 'snpEff': 100, 'sort1': 100, 'gtf_filter_by_attribute_values_list': 100, 'Cut1': 100, 'random_lines1': 100, 'bedtools_sortbed': 100, 'bedtools_coveragebed': 100, 'cardinal_mz_images': 100, 'gops_intersect_1': 100, 'mergeCols1': 100, 'tp_replace_in_column': 100, 'Fetch Taxonomic Ranks': 100}

Class weights: 

{}

Usage weights: 

{}

Mean usage wt: nan

Inverted weights: 

{}

Tool ids
{'Count1': 258, 'bedtools_mergebed': 898, 'datamash_ops': 50, 'comp1': 528, 'tp_replace_in_line': 525, 'addValue': 475, 'collapse_dataset': 520, 'deeptools_compute_matrix': 865, 'dexseq_annotate': 27

In [20]:
class_weights

{'0': 0.0,
 '1': 0.034741137125159235,
 '2': 3.763522662770899,
 '3': 3.71265533419441,
 '4': 0.07102052374333781,
 '5': 0.09390644047766762,
 '6': 0.3189445052146787,
 '7': 2.692372388608714,
 '8': 0.10751963108964303,
 '9': 0.0,
 '10': 3.7062259531342217,
 '11': 0.9405350918191971,
 '12': 0.0,
 '13': 4.357141792561724,
 '14': 6.625259907325613,
 '15': 0.4526208770636268,
 '16': 0.756977440001906,
 '17': 2.7494567255515547,
 '18': 2.419090246729231,
 '19': 0.24595253342506312,
 '20': 4.756222165541598,
 '21': 0.7869009543379938,
 '22': 0.3643105652442822,
 '23': 0.0,
 '24': 0.0,
 '25': 0.8080419692437555,
 '26': 2.4292177439274116,
 '27': 0.6717404732149909,
 '28': 0.09523834995189549,
 '29': 0.0,
 '30': 0.09210774839937601,
 '31': 2.7305837659262564,
 '32': 0.09530040702929865,
 '33': 0.0,
 '34': 0.09528700680856035,
 '35': 7.805980004273164,
 '36': 0.09528385218502887,
 '37': 0.5145711823951903,
 '38': 0.0,
 '39': 0.0,
 '40': 0.09399472406258753,
 '41': 0.0,
 '42': 0.095310179804324

In [11]:
'fastqc': 100, 'Cut1': 87, 'Filter1': 83, 'tp_sort_header_tool': 74, 'rna_star': 72, 'addValue': 69, 'join1': 68, 'tp_cut_tool': 63, 'bedtools_sortbed': 49, 'mergeCols1': 47, 'Paste1': 45, 'fastq_filter': 44, 'Convert characters1': 43, 'CONVERTER_bed_gff_or_vcf_to_bigwig_0': 43, 'sam_to_bam': 42, 'filter_by_fasta_ids': 37, 'datamash_transpose': 32, 'fastq_to_tabular': 30, 'fasta2tab': 30, 'ggplot2_histogram': 29, 'mass_spectrometry_imaging_combine': 21, 'trim_galore': 16, 'proteomics_moff': 8, 'checkFormat': 7, 'column_order_header_sort': 6, 'w4mclassfilter': 4, 'cardinal_combine': 3, 'barchart_gnuplot': 2, 'hgv_linkToGProfile': 1, 'mothur_phylotype': 1}

SyntaxError: invalid syntax (<ipython-input-11-d41b346ade61>, line 1)