In [46]:
import os
import numpy as np
import json
import h5py

def load_model(model_path):
    """Load the statistical tool-recommendation model from an HDF5 file.

    The file must contain four datasets whose contents are JSON text:
    ``data_dictionary``, ``multilabels_paths``, ``compatible_next_tools``
    and ``class_weights``.

    :param model_path: path to the ``.hdf5`` model file.
    :returns: tuple ``(paths, dictionary, rev_dict, c_tools, class_weights)``
        where ``rev_dict`` maps ``str(value)`` of each ``dictionary`` entry
        back to its key (ids are compared as strings downstream).
    :raises KeyError: if one of the expected datasets is missing.
    """
    # Use a context manager so the HDF5 handle is always closed; the
    # previous version leaked it.
    with h5py.File(model_path, 'r') as model:
        # ``Dataset.value`` was removed in h5py 3.0; ``ds[()]`` reads the whole
        # (scalar) dataset on both h5py 2.x and 3.x. h5py 3 may return bytes
        # for string data, which ``json.loads`` accepts as well as ``str``.
        dictionary = json.loads(model['data_dictionary'][()])
        paths = json.loads(model['multilabels_paths'][()])
        c_tools = json.loads(model['compatible_next_tools'][()])
        class_weights = json.loads(model['class_weights'][()])
    # Keys are stringified because predicted ids are produced by splitting
    # comma-separated strings.
    rev_dict = {str(v): k for k, v in dictionary.items()}
    return paths, dictionary, rev_dict, c_tools, class_weights

def predict_tools(dict_paths, d_dict, c_tools, class_weights, test_path="bowtie2", rev_dict=None):
    """Return the tools recorded as following a given tool sequence.

    :param dict_paths: mapping from a comma-joined sequence of tool ids
        (as strings) to a comma-joined string of predicted next-tool ids.
    :param d_dict: mapping from tool name to tool id.
    :param c_tools: not used by this function; kept for interface compatibility.
    :param class_weights: not used by this function; kept for interface
        compatibility.
    :param test_path: comma-separated tool names forming the input sequence.
    :param rev_dict: optional mapping from ``str(tool id)`` to tool name. When
        omitted it is derived from ``d_dict`` rather than read from a notebook
        global (the previous version silently depended on a global
        ``rev_dict``).
    :returns: ``(predicted_ids, predicted_names)``; both are empty lists when
        the sequence has no entry in ``dict_paths``.
    :raises KeyError: if a tool name in ``test_path`` is not in ``d_dict``.
    """
    if rev_dict is None:
        rev_dict = {str(v): k for k, v in d_dict.items()}
    # Encode the input sequence the same way the keys of dict_paths are encoded.
    path_key = ",".join(str(d_dict[name]) for name in test_path.split(","))
    # Direct lookup instead of scanning every key.
    predicted = dict_paths.get(path_key, "")
    # Skip empty tokens so "" or a trailing comma does not yield a bogus id.
    predicted_tools = [tool_id for tool_id in predicted.split(",") if tool_id]
    pred_names = [rev_dict[tool_id] for tool_id in predicted_tools]
    return predicted_tools, pred_names

In [47]:
# Load the statistical model and query it with a two-tool input sequence.
model_path = "data/tool_recommendation_model_statistical_model.hdf5"
test_path = "trimmomatic,rna_star"
(dict_paths, d_dict, rev_dict,
 c_tools, class_weights) = load_model(model_path)
pred_ids, pred_names = predict_tools(
    dict_paths, d_dict, c_tools, class_weights, test_path=test_path
)

In [50]:
# Number of top-ranked recommendations to display.
topk = 5
# Pair each predicted tool's name with its class weight. Assumes class_weights
# is keyed by the same string ids as pred_ids; confirm against the model file.
c_wt_names = {rev_dict[t_id]: class_weights[t_id] for t_id in pred_ids}
# Highest weight first; sorted() is stable, so ties keep prediction order.
sorted_pred_tools = sorted(c_wt_names.items(), key=lambda item: item[1], reverse=True)
sorted_names = [name for name, _weight in sorted_pred_tools]

In [51]:
# Show the `topk` highest-weighted recommended tools.
top_recommendations = sorted_names[:topk]
print(top_recommendations)

['featurecounts', 'deeptools_bam_coverage', 'htseq_count', 'multiqc', 'deseq2']
