# Tool recommendation 
## Gated recurrent unit (GRU) neural network with weighted cross-entropy loss

In [143]:
import numpy as np
import json
import warnings
import operator

import h5py
from keras.models import model_from_json
from keras import backend as K

warnings.filterwarnings("ignore")


def read_file(file_path):
    """Parse the JSON document stored at ``file_path``.

    Args:
        file_path: Path of a text file containing a single JSON document.

    Returns:
        The decoded Python object (dict, list, str, number, bool or None).
    """
    with open(file_path, 'r') as json_handle:
        # json.load reads the whole handle and decodes it in one step.
        return json.load(json_handle)


def create_model(model_path):
    """Load the stored layer weights into the already-built Keras model.

    This function works on notebook-level globals rather than on its argument:
    it reads ``dictionary`` (tool name -> tool id), ``trained_model`` (the open
    h5py file) and ``loaded_model`` (the model rebuilt from its JSON config).

    Args:
        model_path: Unused; kept so existing calls keep working.

    Returns:
        A tuple ``(loaded_model, dictionary, reverse_dictionary)`` where
        ``reverse_dictionary`` maps ``str(tool id)`` back to the tool name.
    """
    reverse_dictionary = {str(tool_id): tool_name for tool_name, tool_id in dictionary.items()}
    # Weights are stored as datasets named "weight_0", "weight_1", ...; the
    # number of keys containing "weight_" tells how many there are, and they
    # must be handed to Keras in that numeric order.
    weight_count = sum(1 for key in trained_model.keys() if "weight_" in key)
    # `dataset[()]` reads the full dataset into memory. It replaces the
    # `Dataset.value` accessor, which was removed in h5py 3.0.
    model_weights = [trained_model["weight_" + str(weight_idx)][()] for weight_idx in range(weight_count)]
    # set the model weights
    loaded_model.set_weights(model_weights)
    return loaded_model, dictionary, reverse_dictionary


def get_published_pred(standard_connections, predictions, last_t_name):
    """Split predicted tools into those listed for the last tool and the rest.

    Prints the last tool name, the predictions, the tools listed for the last
    tool in ``standard_connections`` and both resulting groups.

    Args:
        standard_connections: Mapping from a tool name to the list of tools
            recorded as following it.
        predictions: Iterable of predicted tool names.
        last_t_name: Name of the last tool of the current sequence.

    Returns:
        ``(pub_pred, unpub_pred)``: predicted tools that are / are not listed
        for ``last_t_name``. Both are lists built from sets, so their order
        is not meaningful.
    """
    print(last_t_name)
    print()
    print("Predictions:::")
    print(predictions)
    print()
    # A tool without an entry simply has no published successors.
    known_successors = standard_connections[last_t_name] if last_t_name in standard_connections else list()
    predicted_set = set(predictions)
    published_set = set(known_successors)
    overlap = list(predicted_set.intersection(published_set))
    remainder = list(predicted_set.difference(published_set))
    print("Published tools:::")
    print(known_successors)
    print()
    print("Published and pred:::")
    print(overlap)
    print()
    print("Unpublished and pred:::")
    print(remainder)
    return overlap, remainder
    
def sort_by_usage(t_list, class_weights, d_dict):
    """Order tool names by their class weight, highest first.

    Args:
        t_list: Iterable of tool names to rank.
        class_weights: Mapping from ``str(tool id)`` to a numeric weight.
        d_dict: Mapping from tool name to tool id.

    Returns:
        List of the distinct tool names from ``t_list`` in decreasing weight
        order; tools with equal weight keep their first-seen order. The
        name -> weight mapping is printed as a side effect.
    """
    weight_by_tool = {name: class_weights[str(d_dict[name])] for name in t_list}
    ranked = sorted(weight_by_tool.items(), key=operator.itemgetter(1), reverse=True)
    print(dict(ranked))
    return [name for name, _weight in ranked]

def compute_recommendations(model, tool_sequence, labels, dictionary, reverse_dictionary, class_weights, optimize=True, topk=20, max_seq_len=25):
    """Print the top-k next-tool recommendations for a sequence of tool ids.

    The sequence is written into a zero-initialised vector of length
    ``max_seq_len``, scored with ``model.predict``, re-weighted per output
    position with ``class_weights`` and scaled so the best score is 100%.
    The predicted tools are also split into "published" / "unpublished"
    groups using the notebook-level global ``standard_connections``.

    Args:
        model: Object exposing ``predict(array, verbose=0)``; the second axis
            of its output is treated as one score per tool id.
        tool_sequence: Comma-separated tool ids, e.g. ``"1118,332,1117"``.
        labels: Unused; kept so existing calls keep working.
        dictionary: Mapping from tool name to tool id.
        reverse_dictionary: Mapping from ``str(tool id)`` to tool name.
        class_weights: Mapping from ``str(tool id)`` to a numeric weight.
        optimize: Unused; kept so existing calls keep working.
        topk: Maximum number of recommendations to report.
        max_seq_len: Length of the model input vector.

    Returns:
        None. All results are printed.

    Raises:
        IndexError: If the sequence holds more than ``max_seq_len`` tool ids.
        KeyError: If a tool id of the sequence is missing from
            ``reverse_dictionary`` or an output position has no entry in
            ``class_weights``.
    """
    tl_seq = [tool_id.strip() for tool_id in tool_sequence.split(",")]
    if len(tl_seq) > max_seq_len:
        raise IndexError(
            "tool sequence has %d tools but the model input holds at most %d"
            % (len(tl_seq), max_seq_len)
        )
    tool_sequence_names = [reverse_dictionary[tool_id] for tool_id in tl_seq]
    last_tool_name = tool_sequence_names[-1]

    # Tool ids fill the front of the vector; the remaining slots stay 0.
    sample = np.zeros(max_seq_len)
    sample[:len(tl_seq)] = [int(tool_id) for tool_id in tl_seq]
    sample_reshaped = np.reshape(sample, (1, max_seq_len))

    # predict next tools for a test path
    prediction = model.predict(sample_reshaped, verbose=0)
    prediction = np.reshape(prediction, (prediction.shape[1],))
    # Look each weight up by its output position instead of trusting the
    # dict's insertion order to match the model's output order.
    weight_val = np.array([class_weights[str(pos)] for pos in range(prediction.shape[0])], dtype=float)
    prediction = prediction * weight_val
    # Scale so the best score is 1.0; skip when every score is 0, which would
    # otherwise produce NaNs and fail in the int() conversion below.
    max_score = float(np.max(prediction))
    if max_score != 0:
        prediction = prediction / max_score
    prediction_pos = np.argsort(prediction, axis=-1)

    # get topk prediction (ascending score order); slicing from an explicit
    # start index avoids `[-0:]` returning every position when topk is 0.
    first_kept = max(len(prediction_pos) - topk, 0)
    # Positions without a tool name (e.g. an id absent from
    # reverse_dictionary) cannot be reported, so they are skipped.
    topk_prediction_pos = [pos for pos in prediction_pos[first_kept:] if str(pos) in reverse_dictionary]
    topk_prediction_val = [int(prediction[pos] * 100) for pos in topk_prediction_pos]
    # read tool names using reverse dictionary
    pred_tool_ids = [reverse_dictionary[str(tool_pos)] for tool_pos in topk_prediction_pos]

    pub_p, unpub_p = get_published_pred(standard_connections, pred_tool_ids, last_tool_name)
    sorted_pub_p = sort_by_usage(pub_p, class_weights, dictionary)
    sorted_unpub_p = sort_by_usage(unpub_p, class_weights, dictionary)

    # Tool name -> percentage score, highest score first.
    pred_tool_ids_sorted = dict(
        sorted(zip(pred_tool_ids, topk_prediction_val), key=lambda kv: kv[1], reverse=True)
    )
    ids_tools = {tool_name: dictionary[tool_name] for tool_name in pred_tool_ids_sorted}

    tool_seq_name = ",".join(tool_sequence_names)
    print("Current tool sequence: ")
    print()
    print(tool_seq_name)
    print()
    print("Recommended tools for the tool sequence '%s' with their scores in decreasing order:" % tool_seq_name)
    print()
    for tool_name, score in pred_tool_ids_sorted.items():
        print(tool_name + "(" + str(score) + "%)")
    print()
    print("Predictions published:")
    for t in sorted_pub_p:
        print(t)
    print()
    print("Predictions unpublished:")
    for t in sorted_unpub_p:
        print(t)
    print()
    print("Tool ids:")
    print()
    for tool_name, tool_id in ids_tools.items():
        print(tool_name + "(" + str(tool_id) + ")")

## Unpack trained model for prediction

In [144]:
model_path = "data/tool_recommendation_model.hdf5"
# The HDF5 file bundles the model config, the tool dictionary, the class
# weights and the standard connections as JSON strings, next to the weights.
# The handle is left open because create_model() reads the weight datasets
# from the global `trained_model`.
trained_model = h5py.File(model_path, 'r')
# `dataset[()]` reads the stored value; it replaces `Dataset.value`, which
# was removed in h5py 3.0. json.loads accepts both str and bytes.
model_config = json.loads(trained_model['model_config'][()])
dictionary = json.loads(trained_model['data_dictionary'][()])
class_weights = json.loads(trained_model['class_weights'][()])
standard_connections = json.loads(trained_model['standard_connections'][()])
loaded_model = model_from_json(model_config)
model, dictionary, reverse_dictionary = create_model(model_path)

## Indices of tools

In [145]:
#print(reverse_dictionary)

## Recommended tools

In [160]:
optimize = True
topk = 20  # maximum number of recommendations to show
# Comma-separated tool ids, in sequence order. To see all tool ids, print the
# variable 'reverse_dictionary'.
tool_seq = "1118,332,1117,1257,1281"
# Pass optimize/topk by keyword: the 7th positional parameter of
# compute_recommendations is `optimize`, so a positional `topk` would land in
# the wrong slot and the value set above would never be used.
compute_recommendations(model, tool_seq, "", dictionary, reverse_dictionary, class_weights, optimize=optimize, topk=topk)

IDMapper

Predictions:::
['mothur_unifrac_weighted', 'flexbardsc', 'mimodd_header', 'mimodd_varcall', 'cshl_fastq_to_fasta', 'melt', 'rbc_mirdeep2', 'cshl_cut_tool', 'ip_coordinates_of_roi', 'hicexplorer_hicsummatrices', 'IDMapper', 'QCCalculator', 'FeatureLinkerUnlabeledQT', 'FileMerger', 'MetaProSIP', 'MzTabExporter', 'IDConflictResolver', 'FileFilter', 'FileInfo', 'ProteinQuantifier']

Published tools:::
['ProteinQuantifier', 'IDConflictResolver']

Published and pred:::
['IDConflictResolver', 'ProteinQuantifier']

Unpublished and pred:::
['FeatureLinkerUnlabeledQT', 'mothur_unifrac_weighted', 'cshl_fastq_to_fasta', 'IDMapper', 'hicexplorer_hicsummatrices', 'melt', 'MetaProSIP', 'mimodd_header', 'FileFilter', 'flexbardsc', 'mimodd_varcall', 'cshl_cut_tool', 'rbc_mirdeep2', 'FileInfo', 'QCCalculator', 'FileMerger', 'ip_coordinates_of_roi', 'MzTabExporter']
{'ProteinQuantifier': 5.374815302782456, 'IDConflictResolver': 4.071363882392451}
{'FileInfo': 5.2753944625929305, 'FileFilter': 5

In [151]:
# Dump the str(tool id) -> tool name mapping so ids can be looked up when
# composing a `tool_seq` string.
print(reverse_dictionary)

{'1': 'velveth', '2': 'Grouping1', '3': 'scanpy_find_cluster', '4': 'EMBOSS: textsearch98', '5': 'cds_essential_variability', '6': 'seurat_read10x', '7': 'sklearn_searchcv', '8': 'mothur_get_oturep', '9': 'mycrobiota-make-multi-otutable', '10': 'ctb_opsin', '11': 'fasttree', '12': 'Psortb', '13': 'ip_2d_split_binaryimage_by_watershed', '14': 'methtools_plot', '15': 'cshl_fastx_reverse_complement', '16': 'CometAdapter', '17': 'camera-find-isotopes', '18': 'bedtools_mergebed', '19': 'mycrobiota-qc-report', '20': 'vcfallelicprimitives', '21': 'param_value_from_file', '22': 'mothur_chimera_vsearch', '23': 'gmx_merge_topology_files', '24': 'feelnc', '25': 'datamash_reverse', '26': 'varscan_copynumber', '27': 'venn_list', '28': 'mothur_remove_groups', '29': 'mothur_make_group', '30': 'iuc_pear', '31': 'pieplots_macs', '32': 'fastq_paired_end_interlacer', '33': 'deg_annotate', '34': 'Show beginning1', '35': 'miranda', '36': 'passatutto', '37': 'vcfflatten2', '38': 'flexbar_split_RR_bcs', '39'

{'0': 0.0,
 '1': 4.8693318950069004,
 '2': 4.710430790975013,
 '3': 0.1823215567939546,
 '4': 0.07246826928972329,
 '5': 0.26236426446749106,
 '6': 0.1505728584793744,
 '7': 3.2210969531355067,
 '8': 0.5395959443637816,
 '9': 0.0,
 '10': 0.08971682978290577,
 '11': 2.607037830397792,
 '12': 0.09048175077058784,
 '13': 0.0825580615169461,
 '14': 0.09473615144520271,
 '15': 3.884504883758919,
 '16': 0.03820023110008294,
 '17': 0.0,
 '18': 4.387828804608451,
 '19': 0.0,
 '20': 5.913932166319434,
 '21': 0.051575139665438495,
 '22': 3.7864597824528,
 '23': 0.09531017980432493,
 '24': 1.1518698432994303,
 '25': 0.09531017980432493,
 '26': 0.0,
 '27': 3.7820731095235285,
 '28': 3.5786714472209065,
 '29': 3.095577608523707,
 '30': 0.5097738248193585,
 '31': 0.0,
 '32': 4.250239751222284,
 '33': 4.07916355023045,
 '34': 4.566142639960839,
 '35': 0.8034952362362733,
 '36': 0.0,
 '37': 0.0,
 '38': 0.0,
 '39': 5.755483261770645,
 '40': 0.09531017980432493,
 '41': 0.0,
 '42': 3.2531696608703236,
 '