# Tool recommendation
## Gated recurrent unit (GRU) neural network with weighted cross-entropy loss

In [112]:
import numpy as np
import json
import warnings
import operator

import h5py
from keras.models import model_from_json
from keras import backend as K

warnings.filterwarnings("ignore")


def read_file(file_path):
    """Parse the JSON document stored at ``file_path`` and return the result."""
    with open(file_path, 'r') as handle:
        return json.load(handle)


def create_model(model_path):
    """Load the stored weights into the module-level Keras model.

    Relies on module-level globals that must be set before the call:
    ``trained_model`` (the open HDF5 file), ``loaded_model`` (the model rebuilt
    from its JSON config) and ``dictionary`` (tool name -> tool id).

    :param model_path: path of the HDF5 model file. Currently unused; the
        already-open ``trained_model`` global is read instead.
    :returns: ``(loaded_model, dictionary, reverse_dictionary)`` where
        ``reverse_dictionary`` maps ``str(tool id)`` -> tool name.
    """
    reverse_dictionary = {str(v): k for k, v in dictionary.items()}
    # Weights are stored as datasets "weight_0", "weight_1", ...; fetch them by
    # explicit numeric index because HDF5 key iteration order is not numeric
    # (it would put "weight_10" before "weight_2").
    n_weights = sum(1 for key in trained_model.keys() if "weight_" in key)
    # ``dataset[()]`` reads the whole dataset; ``Dataset.value`` was removed in h5py 3.0.
    model_weights = [trained_model["weight_" + str(i)][()] for i in range(n_weights)]
    # set the model weights
    loaded_model.set_weights(model_weights)
    return loaded_model, dictionary, reverse_dictionary


def get_published_pred(standard_connections, predictions, last_t_name):
    """Intersect the predicted tools with the published successors of a tool.

    Prints the tool name, its published successor list and the overlap.

    :param standard_connections: mapping of tool name -> tool names published
        as following it.
    :param predictions: candidate tool names predicted by the model.
    :param last_t_name: name of the last tool in the current sequence.
    :returns: list of tools present both in ``predictions`` and in the
        published successors (empty when ``last_t_name`` has no entry).
    """
    published = standard_connections[last_t_name] if last_t_name in standard_connections else list()
    overlap = list(set(predictions).intersection(set(published)))

    print(last_t_name)
    print()
    print("Published tools:::")
    print(published)
    print()
    print("Published and pred:::")
    print()
    print(overlap)
    return overlap

def get_normal_pred(compatible_tools, predictions, last_t_name):
    """Keep the predicted tools that are listed as compatible successors.

    Prints the compatible list of ``last_t_name`` and the overlap.

    :param compatible_tools: mapping of tool name -> comma-separated string of
        tool names that may follow it.
    :param predictions: candidate tool names predicted by the model.
    :param last_t_name: name of the last tool in the current sequence.
    :returns: list of tools present both in ``predictions`` and in the
        compatible list of ``last_t_name`` (empty when it has no entry).
    """
    print()
    # Default to no compatible tools so an unknown last tool yields [] instead
    # of an UnboundLocalError on ``compatible_t``.
    compatible_t = list()
    if last_t_name in compatible_tools:
        compatible_t = compatible_tools[last_t_name].split(",")
    compatible_pred = list(set(predictions).intersection(set(compatible_t)))
    print("Compatible tools:::")
    print(compatible_t)
    print()
    print("Compatible and pred:::")
    print(compatible_pred)
    print()
    return compatible_pred
    
def sort_by_usage(t_list, class_weights, d_dict):
    """Order tool names by their class weight, highest first.

    Duplicate names are collapsed to their first occurrence, and tools with
    equal weight keep their input order (``sorted`` is stable).

    :param t_list: tool names to order.
    :param class_weights: mapping of ``str(tool id)`` -> weight.
    :param d_dict: mapping of tool name -> tool id.
    :returns: list of unique tool names, heaviest weight first.
    """
    def usage(name):
        return class_weights[str(d_dict[name])]

    return sorted(dict.fromkeys(t_list), key=usage, reverse=True)

def _top_k_positions(scores, k):
    """Return the indices of the ``k`` largest values in ``scores``, smallest first.

    ``np.argsort(scores)[-k:]`` returns *every* index when ``k == 0``; clamping
    ``k`` to ``[0, len(scores)]`` makes a non-positive ``k`` yield no indices.
    """
    order = np.argsort(scores, axis=-1)
    k = max(0, min(k, order.size))
    return order[order.size - k:]


def compute_recommendations(model, tool_sequence, labels, dictionary, reverse_dictionary, class_weights, optimize=True, topk=20, max_seq_len=25):
    """Print next-tool recommendations for a comma-separated sequence of tool ids.

    The model output is split in two halves. The first half is filtered through
    the published connections (module-level ``standard_connections``) of the
    last tool. The second half is filtered through its compatible tools
    (module-level ``compatible_tools``). Both halves are scaled by the class
    weights, and the ``topk`` best positions of each half are kept, filtered,
    sorted by class weight and merged (published first, duplicates removed).

    :param model: Keras model; ``predict`` is expected to return a
        ``(1, 2 * n_tools)`` array (presumably n_tools == len(class_weights)).
    :param tool_sequence: comma-separated tool ids, e.g. ``"980,1300,465"``.
    :param labels: unused; kept for backward compatibility.
    :param dictionary: tool name -> tool id.
    :param reverse_dictionary: ``str(tool id)`` -> tool name.
    :param class_weights: ``str(tool id)`` -> weight; keys are expected to be
        the integer ids ``"0"`` .. ``str(n_tools - 1)``.
    :param optimize: unused; kept for backward compatibility.
    :param topk: number of highest-scoring positions considered per half.
    :param max_seq_len: length of the zero-padded model input.
    :raises ValueError: if the sequence has more than ``max_seq_len`` tools.
    """
    # Tolerate spaces around ids, e.g. "980, 1300".
    tl_seq = [tool_id.strip() for tool_id in tool_sequence.split(",")]
    if len(tl_seq) > max_seq_len:
        raise ValueError("tool sequence has %d tools; at most %d are supported" % (len(tl_seq), max_seq_len))
    last_tool_name = reverse_dictionary[tl_seq[-1]]
    tool_sequence_names = [reverse_dictionary[tool_id] for tool_id in tl_seq]

    # Left-aligned, zero-padded model input of shape (1, max_seq_len).
    sample = np.zeros(max_seq_len)
    for idx, tool_id in enumerate(tl_seq):
        sample[idx] = int(tool_id)
    sample_reshaped = np.reshape(sample, (1, max_seq_len))

    # Order the weights by numeric tool id so position i of each prediction half
    # is scaled by the weight of tool i, independent of the dict's key order.
    weight_val = np.array([class_weights[key] for key in sorted(class_weights, key=int)])

    # predict next tools for a test path
    prediction = model.predict(sample_reshaped, verbose=0)
    nw_dimension = prediction.shape[1]
    prediction = np.reshape(prediction, (nw_dimension,))

    # First half: scored against published connections; second half: against
    # compatible tools.
    half_len = nw_dimension // 2
    standard_pred = prediction[:half_len] * weight_val
    normal_pred = prediction[half_len:] * weight_val

    standard_topk_prediction_pos = _top_k_positions(standard_pred, topk)
    normal_topk_prediction_pos = _top_k_positions(normal_pred, topk)

    # Skip positions without a tool name instead of raising KeyError. The
    # printed mapping starts at id 1, so index 0 is presumably padding.
    standard_pred_tool_ids = [reverse_dictionary[str(pos)] for pos in standard_topk_prediction_pos
                              if str(pos) in reverse_dictionary]
    pub_p = get_published_pred(standard_connections, standard_pred_tool_ids, last_tool_name)
    sorted_pub_p = sort_by_usage(pub_p, class_weights, dictionary)
    print(sorted_pub_p)

    normal_pred_tool_ids = [reverse_dictionary[str(pos)] for pos in normal_topk_prediction_pos
                            if str(pos) in reverse_dictionary]
    c_pred = get_normal_pred(compatible_tools, normal_pred_tool_ids, last_tool_name)
    sorted_c_pred = sort_by_usage(c_pred, class_weights, dictionary)
    print(sorted_c_pred)

    print()
    print("Current tool sequence: ")
    print()
    print(",".join(tool_sequence_names))
    print()
    print("Overall recommendations: ")
    print()
    # Published recommendations first, then compatible ones; dict.fromkeys
    # drops duplicates while preserving that order.
    recommendations = list(dict.fromkeys(sorted_pub_p + sorted_c_pred))
    print(recommendations)

    print()
    print("Recommended tool ids:")
    print()
    for tool_name in recommendations:
        print(tool_name + "(" + str(dictionary[tool_name]) + ")")

## Unpack trained model for prediction

In [113]:
model_path = "data/tool_recommendation_model.hdf5"
trained_model = h5py.File(model_path, 'r')
# `dataset[()]` reads a whole dataset; `Dataset.value` was removed in h5py 3.0.
# json.loads accepts both str and bytes, so this works whichever type h5py returns.
model_config = json.loads(trained_model['model_config'][()])
dictionary = json.loads(trained_model['data_dictionary'][()])
class_weights = json.loads(trained_model['class_weights'][()])
standard_connections = json.loads(trained_model['standard_connections'][()])
compatible_tools = json.loads(trained_model['compatible_tools'][()])
loaded_model = model_from_json(model_config)
model, dictionary, reverse_dictionary = create_model(model_path)

## Indices of tools

In [114]:
#print(reverse_dictionary)

## Recommended tools

In [118]:
optimize = True
topk = 20  # set the maximum number of recommendations
# Give tool ids in a sequence and see the recommendations. To know all the
# tool ids, please print the variable 'reverse_dictionary'.
tool_seq = "980,1300,465,937,977"
# Pass optimize/topk by keyword. Positionally, `topk` lands in the `optimize`
# parameter and the requested top-k is silently ignored.
compute_recommendations(model, tool_seq, "", dictionary, reverse_dictionary, class_weights, optimize=optimize, topk=topk)

snpSift_filter

Published tools:::
['snpSift_annotate']

Published and pred:::

['snpSift_annotate']
['snpSift_annotate']

Compatible tools:::
['snpSift_geneSets', 'vcfcombine', 'snpEff', 'gemini_load', 'cloudmap_variant_discovery_mapping', 'snpSift_extractFields', 'cshl_word_list_grep', 'table_annovar', 'snpSift_annotate']

Compatible and pred:::
['gemini_load', 'vcfcombine', 'snpEff', 'snpSift_extractFields', 'snpSift_annotate', 'cshl_word_list_grep']

['snpEff', 'gemini_load', 'snpSift_extractFields', 'vcfcombine', 'snpSift_annotate', 'cshl_word_list_grep']

Current tool sequence: 

bowtie2,freebayes,vcfallelicprimitives,vt_normalize,snpSift_filter

Overall recommendations: 

['snpSift_annotate', 'snpEff', 'gemini_load', 'snpSift_extractFields', 'vcfcombine', 'cshl_word_list_grep']

Recommended tool ids:

snpSift_annotate(524)
snpEff(82)
gemini_load(634)
snpSift_extractFields(1512)
vcfcombine(459)
cshl_word_list_grep(90)


In [116]:
# Show the tool id -> tool name mapping (ids are string keys); pick ids from here for `tool_seq`.
print(reverse_dictionary)

{'1': 'vigiechiro_idvalid', '2': 'rawtools', '3': 'qiime_validate_mapping_file', '4': 'varscan_mpileup', '5': 'bed2gff1', '6': 'bedtools_genomecoveragebed_bedgraph', '7': 'metaphlan2', '8': 'enasearch_search_data', '9': 'CONVERTER_gz_to_uncompressed', '10': 'goslimmer', '11': 'ambertools_acpype', '12': 'gdal_gdalinfo', '13': 'sccaf_regress_out', '14': 'CONVERTER_bed_gff_or_vcf_to_bigwig_0', '15': 'cp_cellprofiler', '16': 'seqtk_trimfq', '17': 'velvet', '18': 'prinseq', '19': 'kraken', '20': 'lofreq_viterbi', '21': 'scater_plot_pca', '22': 'CONVERTER_fasta_to_2bit', '23': 'sm_mirdeep2core_without_Randfold', '24': 'ip_landmark_registration', '25': 'samtools_bedcov', '26': 'mothur_homova', '27': 'pyprophet_export', '28': 'gatk2_variant_annotator', '29': 'ip_projective_transformation', '30': 'histogram_rpy', '31': 'edger', '32': 'xpath', '33': 'addName', '34': 'velvetoptimiser', '35': 'EMBOSS: geecee41', '36': 'gops_subtract_1', '37': 'ExtractFASTAfromFASTQ', '38': 'chira_map', '39': 'musc

In [117]:
# Show the per-tool class weights keyed by str(tool id).
# compute_recommendations scales the predictions by them, and sort_by_usage orders recommendations by them.
print(class_weights)

{'0': 0.0, '1': 3.6181228618462042, '2': 2.7308467871257793, '3': 1.5752098302117428, '4': 5.294696963139428, '5': 2.1416233915635003, '6': 3.8975097341640623, '7': 5.026759126183539, '8': 2.365559892155434, '9': 7.994103785051117, '10': 1.9778600169107303, '11': 0.8440503199744372, '12': 1.2878362605579323, '13': 0.057805432811355945, '14': 4.287563318416876, '15': 0.1926821830024244, '16': 1.0603356007161215, '17': 0.0, '18': 3.820552503337477, '19': 0.038443978880996196, '20': 0.19268434382950114, '21': 1.3229497669677621, '22': 2.1853963971827404, '23': 0.0, '24': 2.4543968576905786, '25': 3.1724006260575903, '26': 1.3420891029370807, '27': 0.1926822324870661, '28': 0.0953441246827365, '29': 0.2131345317788946, '30': 3.0909001436873202, '31': 5.7062793548547885, '32': 2.725978764941211, '33': 4.178600501143271, '34': 4.3118212938000475, '35': 2.9590110665908615, '36': 2.727012059364866, '37': 0.0, '38': 0.19268434382950114, '39': 3.162852968767856, '40': 4.917161529222698, '41': 0.