## Load model and compute recommended tools

In [121]:
import os
import numpy as np
import json
import h5py

def load_model(model_path):
    """Load the JSON-encoded lookup tables stored in an HDF5 model file.

    Args:
        model_path: Path of the HDF5 file to read.

    Returns:
        A 6-tuple ``(paths, dictionary, rev_dict, c_tools, class_weights,
        standard_connections)``. Each element except ``rev_dict`` is the
        decoded JSON content of the dataset named ``multilabels_paths``,
        ``data_dictionary``, ``compatible_next_tools``, ``class_weights`` and
        ``standard_connections`` respectively. ``rev_dict`` is the inverse of
        ``dictionary`` with the values converted to strings and used as keys.

    Raises:
        KeyError: If one of the expected datasets is missing from the file.
    """
    # The context manager closes the file handle even if decoding fails;
    # everything needed is copied into plain Python objects before it exits.
    with h5py.File(model_path, 'r') as model:

        def read_json(name):
            # ``Dataset.value`` was removed in h5py 3.0; indexing with an
            # empty tuple reads the whole dataset on old and new releases.
            # ``json.loads`` accepts both ``str`` and ``bytes`` results.
            return json.loads(model[name][()])

        dictionary = read_json('data_dictionary')
        paths = read_json('multilabels_paths')
        c_tools = read_json('compatible_next_tools')
        class_weights = read_json('class_weights')
        standard_connections = read_json('standard_connections')

    # Reverse lookup: stringified value of ``dictionary`` -> its key.
    rev_dict = {str(v): k for k, v in dictionary.items()}
    return paths, dictionary, rev_dict, c_tools, class_weights, standard_connections

def predict_tools(dict_paths, d_dict, c_tools, class_weights, test_path="bowtie2", rev_dict=None):
    """Look up the tools recorded as successors of a tool sequence.

    Args:
        dict_paths: Mapping from a comma-joined string of tool ids to a
            comma-joined string of following tool ids.
        d_dict: Mapping from tool name to tool id.
        c_tools: Unused; kept so existing callers keep working.
        class_weights: Unused; kept so existing callers keep working.
        test_path: Comma-separated tool names forming the query sequence.
        rev_dict: Optional mapping from stringified tool id to tool name.
            When omitted it is derived from ``d_dict``, so the function no
            longer depends on a module-level ``rev_dict`` being defined.

    Returns:
        A tuple ``(predicted_tools, pred_names)``: the list of predicted tool
        ids (as strings) and the list of the corresponding tool names. Both
        lists are empty when the sequence is not a key of ``dict_paths``.

    Raises:
        KeyError: If a name in ``test_path`` is missing from ``d_dict`` or a
            predicted id is missing from ``rev_dict``.
    """
    if rev_dict is None:
        rev_dict = {str(v): k for k, v in d_dict.items()}

    # Encode the query the same way the keys of ``dict_paths`` are written.
    p_num = ",".join(str(d_dict[t]) for t in test_path.split(","))

    # Direct dictionary lookup instead of scanning every key.
    if p_num in dict_paths:
        predicted_tools = dict_paths[p_num].split(",")
    else:
        predicted_tools = []

    pred_names = [rev_dict[tool] for tool in predicted_tools]
    return predicted_tools, pred_names

In [122]:
# Location of the HDF5 model file read by load_model.
model_path = "data/tool_recommendation_model_statistical_model.hdf5"
# Query sequence: one tool name, or several names separated by commas
# (predict_tools splits this string on ",").
test_path = "umi_tools_extract"
# Unpack the six lookup tables returned by load_model; the names bound here
# are module-level globals that the following cells read.
dict_paths, d_dict, rev_dict, c_tools, class_weights, standard_connections = load_model(model_path)
# pred_ids holds the predicted tool ids, pred_names the matching tool names.
pred_ids, pred_names = predict_tools(dict_paths, d_dict, c_tools, class_weights, test_path)

print(pred_names)

['bwa_mem', 'umi_tools_group', 'fastqc', 'je_markdupes', 'bowtie2', 'rna_star', 'je_demultiplex', 'bwa', 'hisat2']


## Fetch top recommended tools (sorted in descending order based on their usage)

In [123]:
# Upper bound on the length of each list returned by sort_by_wt.
topk = 10

# Final ranking; filled by the standard-connection branch and then extended
# with the remaining predictions.
s_pred_sorted = []

def sort_by_wt(t_list):
    """Return the tool names of ``t_list`` ordered by descending class weight.

    Reads the module-level ``d_dict`` (tool name -> id), ``class_weights``
    (stringified id -> weight) and ``topk``. Repeated names are kept once, at
    the position of their first occurrence; names with equal weights keep
    their relative input order. At most ``topk`` names are returned.

    Raises:
        KeyError: If a name is missing from ``d_dict`` or its id is missing
            from ``class_weights``.
    """
    def weight_of(tool_name):
        return class_weights[str(d_dict[tool_name])]

    # dict.fromkeys removes repeats while preserving first-seen order.
    distinct_names = dict.fromkeys(t_list)
    ranked = sorted(distinct_names, key=weight_of, reverse=True)
    return ranked[:topk]

# Rank the predicted tools. When the query has an entry in
# standard_connections, predictions that are also listed there are ranked
# first, followed by the remaining predictions; each group is sorted by
# sort_by_wt. Otherwise all predictions are ranked together.
if test_path in standard_connections:
    s_conn = standard_connections[test_path]

    # Partition in the order of pred_names rather than through list(set(...)):
    # the iteration order of a set of strings changes between interpreter
    # sessions (hash randomization), which made the order of equal-weight
    # tools differ from run to run.
    s_conn_lookup = set(s_conn)
    unique_pred_names = list(dict.fromkeys(pred_names))
    s_pred = [name for name in unique_pred_names if name in s_conn_lookup]
    n_pred = [name for name in unique_pred_names if name not in s_conn_lookup]

    s_pred_sorted = sort_by_wt(s_pred)
    n_pred_sorted = sort_by_wt(n_pred)
    print(s_pred_sorted)
    print()
    print(n_pred_sorted)
    print()
else:
    n_pred_sorted = sort_by_wt(pred_names)
# Append the non-standard predictions after the standard ones (or to the
# empty list when the query has no standard connections).
s_pred_sorted.extend(n_pred_sorted)

['rna_star', 'je_demultiplex']

['fastqc', 'bowtie2', 'hisat2', 'bwa_mem', 'bwa', 'je_markdupes', 'umi_tools_group']



## Top recommended tools

In [124]:
# Show the combined ranking: standard-connection tools first, then the rest.
print(s_pred_sorted)

['rna_star', 'je_demultiplex', 'fastqc', 'bowtie2', 'hisat2', 'bwa_mem', 'bwa', 'je_markdupes', 'umi_tools_group']


In [125]:
# Dump the tool-name -> id mapping returned by load_model.
print(d_dict)

{'antismash': 1, 'megahit': 2, 'smf_utils_extract-boxed-sequences': 3, 'IDMapper': 4, 'deeptools_compute_matrix_operations': 5, 'vcfcheck': 6, 'extract_bcs.py': 7, 'deeptools_bam_compare': 8, 'samtools_markdup': 9, 'xcms_export_samplemetadata': 10, 'bedtools_genomecoveragebed': 11, 'ctb_osra': 12, 'sam_bw_filter': 13, 'rmcontamination': 14, 'idba_hybrid': 15, 'fastq_quality_trimmer': 16, 'gd_calc_freq': 17, 'tp_uniq_tool': 18, 'w4mclstrpeakpics': 19, 'rbc_rnacode': 20, 'qualimap_rnaseq': 21, 'ssake': 22, 'chira_collapse': 23, 'bedtools_unionbedgraph': 24, 'vcffixup': 25, 'Show beginning1': 26, 'ctb_filter': 27, 'mothur_summary_tax': 28, 'regionalgam_ab_index': 29, 'cp_mask_image': 30, 'deeptools_bamCorrelate': 31, 'mothur_pairwise_seqs': 32, 'charts': 33, 'column_remove_by_header': 34, 'gemini_comp_hets': 35, 'trinity_gene_to_trans_map': 36, 'cshl_fasta_nucleotides_changer': 37, 'ip_2d_feature_extraction': 38, 'bam_to_sam': 39, 'cshl_fastx_clipper': 40, 'picard_MarkDuplicatesWithMateCi