In [115]:
import csv
import random
import numpy as np
import json
import warnings
import operator

import h5py
from keras.models import model_from_json
from keras import backend as K

warnings.filterwarnings("ignore")


def format_tool_id(tool_link):
    """
    Return the short tool id contained in a tool link.

    The second-to-last "/"-separated segment of the link is returned.
    A value that contains no "/" is handed back unchanged.
    """
    # Splitting from the right at most twice is enough: the piece before the
    # last one is the same segment a full split would give at index -2.
    pieces = tool_link.rsplit("/", 2)
    if len(pieces) == 1:
        return tool_link
    return pieces[-2]

def read_workflow(wf_id, workflow_rows):
    """
    Group the connections of one workflow by their second element.

    Every item of workflow_rows is read by index: item[1] becomes a key of
    the result and item[0] is added to that key's list, once, in order of
    first appearance. wf_id is accepted but not used.
    """
    parents_by_tool = {}
    for link in workflow_rows:
        parent, child = link[0], link[1]
        known_parents = parents_by_tool.setdefault(child, [])
        if parent not in known_parents:
            known_parents.append(parent)
    return parents_by_tool

def get_roots_leaves(graph):
    """
    Split the nodes of a key -> values mapping into roots and leaves.

    Roots are nodes that occur in some value list but never as a key;
    leaves are keys that never occur in any value list. Both are returned
    as lists whose order is not specified (they come from sets).
    """
    key_nodes = set(graph.keys())
    value_nodes = set()
    for linked in graph.values():
        value_nodes.update(linked)
    roots = list(value_nodes - key_nodes)
    leaves = list(key_nodes - value_nodes)
    return roots, leaves

def find_tool_paths_workflow(graph, start, end, path=None):
    """
    Enumerate every cycle-free path that links two nodes of a graph.

    The walk begins at `end` and repeatedly follows the entries of
    graph[node] until `start` is reached, so each returned path is listed
    from `end` to `start`.

    graph: mapping of node -> iterable of nodes reachable from it
    start: node at which a path is complete
    end: node the walk begins from
    path: nodes already visited (used by the recursion); None means
          "nothing visited yet"

    Returns a list of paths, each a list of nodes. The list is empty when
    `start` cannot be reached from `end`.
    """
    # None replaces the former mutable [] default; a fresh list is built on
    # every call either way, so the caller's list is never modified.
    if path is None:
        path = []
    path = path + [end]
    if start == end:
        return [path]
    path_list = list()
    for node in graph.get(end, ()) if hasattr(graph, "get") else (graph[end] if end in graph else ()):
        # skip nodes already on the current path so cycles cannot loop forever
        if node not in path:
            path_list.extend(find_tool_paths_workflow(graph, start, node, path))
    return path_list

def read_tabular_file(raw_file_path):
    """
    Read the tab-separated workflow connections file and count tools and paths.

    For every row, column 0 is taken as the workflow id and columns 3 and 6
    as tool links, which are shortened with format_tool_id. Rows are grouped
    per workflow, each workflow is turned into a graph with read_workflow,
    and all root/leaf paths of that graph are collected with
    find_tool_paths_workflow.

    Progress messages are printed while the file is processed.

    Returns a tuple (tool_frequency, paths_freq):
        tool_frequency: dict, tool id -> number of times it was seen in
                        column 3 or column 6
        paths_freq:     dict, comma-joined tool path -> number of times that
                        path was found across all workflows
    """
    print("Reading workflows...")
    workflows = dict()
    tool_frequency = dict()
    # newline="" is the mode the csv module asks for, so that line endings
    # are interpreted by the csv reader itself
    with open(raw_file_path, 'rt', newline="") as workflow_connections_file:
        workflow_connections = csv.reader(workflow_connections_file, delimiter='\t')
        for row in workflow_connections:
            wf_id = str(row[0])
            in_tool = format_tool_id(row[3])
            out_tool = format_tool_id(row[6])
            if wf_id not in workflows:
                workflows[wf_id] = list()
            # a connection is kept only when both ends are known and differ
            if out_tool and in_tool and out_tool != in_tool:
                workflows[wf_id].append((out_tool, in_tool))
            # every non-empty tool id counts once per row, out_tool first
            for tool in (out_tool, in_tool):
                if tool != "":
                    tool_frequency[tool] = tool_frequency.get(tool, 0) + 1
    print("Reading workflows finished")
    print("Processing workflows...")
    workflow_parents = dict()
    for wf_id, rows in workflows.items():
        workflow_parents[wf_id] = read_workflow(wf_id, rows)

    workflow_paths = list()
    for parents_graph in workflow_parents.values():
        roots, leaves = get_roots_leaves(parents_graph)
        for root in roots:
            for leaf in leaves:
                # paths are kept in the order they are returned (each one
                # starts at the leaf argument); nothing is reversed here
                workflow_paths.extend(find_tool_paths_workflow(parents_graph, root, leaf))
    print("Workflows processed: %d" % len(workflows))

    # format_tool_id is applied once more to every id of every path, then
    # each path becomes one comma-joined string; empty strings are dropped
    path_strings = [",".join([format_tool_id(tool_id) for tool_id in path]) for path in workflow_paths]
    path_strings = [joined for joined in path_strings if joined]
    # the shuffle only affects the insertion order of paths_freq
    random.shuffle(path_strings)
    paths_freq = dict()
    for joined in path_strings:
        paths_freq[joined] = paths_freq.get(joined, 0) + 1

    return tool_frequency, paths_freq


# Location of the tab-separated workflow connections file.
# NOTE(review): "worflow" looks like a misspelling of "workflow" - check the
# file name on disk before changing it.
wf_path = "../data/worflow-connection-20-04.tsv"
# t_freq: tool id -> occurrence count; p_freq: comma-joined path -> occurrence count
t_freq, p_freq = read_tabular_file(wf_path)

Reading workflows...
Reading workflows finished
Processing workflows...
Workflows processed: 18659


In [129]:
# Count, over the distinct paths in p_freq, how often each tool is the final
# element of a path (path occurrence counts are not taken into account)
last_tool_freq = {}
for path in p_freq:
    last_t = path.rsplit(",", 1)[-1]
    last_tool_freq[last_t] = last_tool_freq.get(last_t, 0) + 1

In [130]:
# Rebuild the three frequency tables in ascending order of their counts
# (item 1 of every (name, count) pair); ties keep their previous order
t_freq = dict(sorted(t_freq.items(), key=operator.itemgetter(1)))
p_freq = dict(sorted(p_freq.items(), key=operator.itemgetter(1)))
last_tool_freq = dict(sorted(last_tool_freq.items(), key=operator.itemgetter(1)))

In [119]:
def read_file(file_path):
    """
    Load the JSON document stored at file_path and return the parsed value.
    """
    with open(file_path, 'r') as handle:
        return json.load(handle)


def create_model(model_path):
    """
    Copy the stored weights into the loaded network and build the reverse
    tool lookup.

    This function works on notebook-level objects rather than on its
    argument: `trained_model` (the open HDF5 file), `loaded_model` (the
    model that receives the weights) and `dictionary` (tool name -> index).
    model_path is accepted but not read.

    Returns (loaded_model, dictionary, reverse_dictionary), where
    reverse_dictionary maps str(index) -> tool name.
    """
    reverse_dictionary = dict((str(v), k) for k, v in dictionary.items())
    # The weight arrays are stored as weight_0, weight_1, ...: count the
    # matching keys, then fetch them by number so they are read in numeric
    # order regardless of the order in which the file lists its keys.
    weight_count = sum(1 for item in trained_model.keys() if "weight_" in item)
    # Dataset.value was removed in h5py 3.0; indexing with an empty tuple
    # reads the whole dataset and is also supported by older releases.
    model_weights = [trained_model["weight_" + str(weight_idx)][()] for weight_idx in range(weight_count)]
    # set the model weights
    loaded_model.set_weights(model_weights)
    return loaded_model, dictionary, reverse_dictionary


def get_predicted_tools(base_tools, predictions, topk):
    """
    Keep the predicted tools that also appear in base_tools and score the overlap.

    base_tools:  reference tool names to compare against
    predictions: predicted tool names
    topk:        maximum number of matching tools to return

    Returns (matches, precision):
        matches:   predicted tools found in base_tools, without duplicates,
                   in the order they appear in predictions, cut to topk
        precision: number of distinct matches divided by len(predictions);
                   nan when base_tools or predictions is empty
    """
    # Without reference tools or without predictions the ratio is undefined;
    # nan lets callers skip the entry via np.nanmean instead of crashing on
    # a division by zero.
    if len(base_tools) == 0 or len(predictions) == 0:
        return list(), np.nan
    reference = set(base_tools)
    # dict.fromkeys drops repeated predictions while keeping their order, so
    # the result no longer depends on set iteration order
    matches = [tool for tool in dict.fromkeys(predictions) if tool in reference]
    precision = len(matches) / float(len(predictions))
    return matches[:topk], precision


def sort_by_usage(t_list, class_weights, d_dict):
    """
    Look up the class weight of every tool in t_list.

    Each tool name is translated to its id through d_dict, and the weight is
    read from class_weights under str(id). Despite the function's name the
    result is not sorted: tools keep the order of their first appearance in
    t_list, and a repeated tool is reported once.

    Returns (tools, weights) as two lists of equal length.
    """
    weight_by_tool = {name: class_weights[str(d_dict[name])] for name in t_list}
    return list(weight_by_tool), list(weight_by_tool.values())


def separate_predictions(base_tools, predictions, last_tool_name, weight_values, topk):
    """
    Weight the prediction scores, take the topk best positions and compare
    the corresponding tools with those listed for last_tool_name.

    base_tools:     mapping of tool name -> tools (a list, or one
                    comma-separated string)
    predictions:    array of scores, multiplied element-wise by weight_values
    last_tool_name: key looked up in base_tools
    topk:           number of highest-scoring positions to keep

    Besides its arguments the function reads the notebook-level
    `reverse_dictionary`, `class_weights` and `dictionary`.

    Returns (tools, weights, precision) as produced by get_predicted_tools
    and sort_by_usage.
    """
    weighted_scores = predictions * weight_values
    # argsort is ascending, so the topk largest scores sit at the tail
    best_positions = np.argsort(weighted_scores, axis=-1)[-topk:]
    # translate positions into tool ids
    pred_tool_ids = [reverse_dictionary[str(position)] for position in best_positions]
    # tools recorded for the last tool of the sequence; none if it is unknown
    last_base_tools = base_tools[last_tool_name] if last_tool_name in base_tools else list()
    if type(last_base_tools) is str:
        # a single string holds the tools as a comma-separated list
        last_base_tools = last_base_tools.split(",")
    matched_tools, precision = get_predicted_tools(last_base_tools, pred_tool_ids, topk)
    tool_names, tool_weights = sort_by_usage(matched_tools, class_weights, dictionary)
    return tool_names, tool_weights, precision


def compute_recommendations(model, tool_sequence, labels, dictionary, reverse_dictionary, class_weights, topk=10, max_seq_len=25):
    """
    Run the model on one tool sequence and score both halves of its output.

    model:              object whose predict(sample, verbose=0) returns an
                        array with the scores on axis 1
    tool_sequence:      comma-separated tool names; every name must be a key
                        of dictionary (KeyError otherwise)
    labels:             not used
    dictionary:         tool name -> index
    reverse_dictionary: not used here; separate_predictions reads the
                        notebook-level object of the same name
    class_weights:      mapping whose values are multiplied with the scores
    topk:               number of top predictions that are evaluated
    max_seq_len:        length of the input vector; a sequence with more
                        tools than this raises IndexError

    The first half of the output is compared with the notebook-level
    `standard_connections`, the second half with `compatible_tools`.

    Returns (pub_prec, c_prec), the precision of each half (nan when no
    reference tools exist for the last tool of the sequence).
    """
    tl_seq = tool_sequence.split(",")
    tl_seq_ids = [str(dictionary[t]) for t in tl_seq]
    last_tool_name = tl_seq[-1]
    # NOTE(review): the weights are taken in the iteration order of
    # class_weights and must line up with one half of the model output -
    # confirm that this matches how the model was trained.
    weight_val = list(class_weights.values())
    weight_val = np.reshape(weight_val, (len(weight_val),))
    # tool indices fill the vector from the left; the remaining slots stay 0
    sample = np.zeros(max_seq_len)
    for idx, tool_id in enumerate(tl_seq_ids):
        sample[idx] = int(tool_id)
    sample_reshaped = np.reshape(sample, (1, max_seq_len))
    # predict next tools for the sequence
    prediction = model.predict(sample_reshaped, verbose=0)
    nw_dimension = prediction.shape[1]
    prediction = np.reshape(prediction, (nw_dimension,))

    half_len = nw_dimension // 2

    # only the precision is needed from each half; tools and weights are dropped
    _, _, pub_prec = separate_predictions(standard_connections, prediction[:half_len], last_tool_name, weight_val, topk)
    _, _, c_prec = separate_predictions(compatible_tools, prediction[half_len:], last_tool_name, weight_val, topk)

    return pub_prec, c_prec

In [120]:
# The HDF5 file holds the model architecture, the tool dictionary, the class
# weights and the two tool-connection tables as JSON strings.
model_path = "../data/tool_recommendation_model_20_04.hdf5"
# The file is left open on purpose: create_model reads the weight datasets
# from trained_model.
trained_model = h5py.File(model_path, 'r')
# Dataset.value was removed in h5py 3.0; dataset[()] reads the full dataset
# and is supported by older releases as well.
model_config = json.loads(trained_model.get('model_config')[()])
dictionary = json.loads(trained_model.get('data_dictionary')[()])
class_weights = json.loads(trained_model.get('class_weights')[()])
standard_connections = json.loads(trained_model.get('standard_connections')[()])
compatible_tools = json.loads(trained_model.get('compatible_tools')[()])
loaded_model = model_from_json(model_config)
model, dictionary, reverse_dictionary = create_model(model_path)


In [162]:
# Keep only the tools whose count in last_tool_freq is below 10
t_lowest_freq = {
    tool: count for tool, count in last_tool_freq.items() if count < 10
}

In [163]:
# Display the tool -> count table (one entry per tool found in last position of a path)
last_tool_freq

{'mtbls520_23_seasons_rda': 1,
 'mtbls520_08c_species_variability': 1,
 'keras_batch_models': 1,
 'metagene_annotator': 1,
 'rseqc_RNA_fragment_size': 1,
 'mtbls520_19b_seasons_unique': 1,
 'get_sequences': 1,
 'minfi_getM': 1,
 'mummer_nucmer': 1,
 'fpocket': 1,
 'ambertools_antechamber': 1,
 'tag_stat2': 1,
 'ip_scale_image': 1,
 'mlocarna': 1,
 'cp_image_math': 1,
 'aggregate': 1,
 'collection_element_identifiers': 1,
 'bed_to_bigBed': 1,
 'ggplot2_pca': 1,
 'mtbls520_08d_concentration': 1,
 'ip_histogram_equalization': 1,
 'rseqc_clipping_profile': 1,
 'cp_relate_objects': 1,
 'gops_basecoverage_1': 1,
 'Nucleosome': 1,
 'gemini_de_novo': 1,
 'vcfaddinfo': 1,
 'mtbls520_09_species_venn': 1,
 'pileup_interval': 1,
 'mothur_phylotype': 1,
 'vsearch_sorting': 1,
 'tp_uniq_tool': 1,
 'interactive_tool_phinch': 1,
 'p_clip_peaks': 1,
 'mtbls520_08b_species_unique': 1,
 'scater_plot_tsne': 1,
 'deeptools_alignmentsieve': 1,
 'ctb_filter': 1,
 'ip_binary_to_labelimage': 1,
 'picard_SortSa

In [161]:
# Evaluate the top-1 prediction for every sequence that ends in one of the
# tools collected in t_lowest_freq.
topk = 1
complete_p_prec, complete_c_prec, freq = list(), list(), list()
ctr = 0
print("Number of tool sequences: %d" % len(p_freq))

# First the full paths whose last tool is in t_lowest_freq (paired with the
# path's own count), then each of those tools as a one-element sequence
# (paired with the tool's count).
rare_paths = [(t_seq, p_freq[t_seq]) for t_seq in p_freq if t_seq.split(",")[-1] in t_lowest_freq]
single_tools = list(t_lowest_freq.items())

for t_seq, seq_freq in rare_paths + single_tools:
    p_prec, c_prec = compute_recommendations(model, t_seq, "", dictionary, reverse_dictionary, class_weights, topk)
    complete_p_prec.append(p_prec)
    complete_c_prec.append(c_prec)
    freq.append(seq_freq)
    ctr += 1

# nan entries (sequences without reference tools) are ignored by nanmean
print("Published precision: %s" % str(np.nanmean(complete_p_prec)))
print("Normal precision: %s" % str(np.nanmean(complete_c_prec)))
print("Number of paths used : %s" % ctr)
print("Mean frequency: %s" % str(np.mean(freq)))

Number of tool sequences: 198459


KeyboardInterrupt: 

In [149]:
# Per-sequence precision against standard_connections; nan marks a sequence
# whose last tool has no tools listed there
complete_p_prec

[nan,
 nan,
 nan,
 1.0,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 1.0,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 1.0,
 nan,
 nan,
 nan,
 1.0,
 nan,
 nan,
 0.0,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 1.0,
 nan,
 nan,
 0.0,
 nan,
 nan,
 nan,
 1.0,
 nan,
 1.0,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 1.0,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 1.0,
 nan,
 nan,
 nan,
 1.0,
 nan,
 1.0,
 nan,
 nan,
 1.0,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 1.0,
 nan,
 nan,
 nan,
 nan,
 nan,
 1.0,
 1.0,
 nan,
 1.0,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 1.0,
 nan,
 nan,
 nan,
 1.0,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 1.0,
 nan,
 nan,
 nan,
 nan,
 nan,
 1.0,
 nan,
 nan,
 nan,
 1.0,
 nan,
 nan,
 0.0,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan,
 nan