# Tool recommendation 
## (Gated recurrent units neural network with weighted cross-entropy loss)

In [66]:
import numpy as np
import json
import warnings
import operator

import h5py
from keras.models import model_from_json
from keras import backend as K

warnings.filterwarnings("ignore")


def read_file(file_path):
    """Parse the JSON document stored at ``file_path`` and return it."""
    with open(file_path, 'r') as handle:
        return json.load(handle)


def create_model(model_path):
    """
    Load the stored layer weights into the already-built model.

    Reads the module-level names ``trained_model`` (the open weight store),
    ``dictionary`` and ``loaded_model``; ``model_path`` itself is not used
    by the body. Entries are fetched as ``weight_0``, ``weight_1``, ... --
    one per key of ``trained_model`` whose name contains ``"weight_"``.

    Returns the tuple ``(loaded_model, dictionary, reverse_dictionary)``,
    where ``reverse_dictionary`` maps ``str(value)`` back to the key of
    ``dictionary``.
    """
    reverse_dictionary = {str(tool_id): tool_name for tool_name, tool_id in dictionary.items()}
    weight_keys = [name for name in trained_model.keys() if "weight_" in name]
    model_weights = []
    layer_means = []
    for position in range(len(weight_keys)):
        # entries are addressed by running counter, not by the iterated key name
        layer = trained_model["weight_" + str(position)][()]
        print(layer)
        layer_means.append(np.mean(layer))
        model_weights.append(layer)
    print("Overall mean of model weights: %.6f" % np.mean(layer_means))
    # push the collected arrays into the model
    loaded_model.set_weights(model_weights)
    return loaded_model, dictionary, reverse_dictionary


def get_predicted_tools(base_tools, predictions, topk):
    """
    Keep the predicted tools that also occur in ``base_tools``.

    The result preserves the order of ``predictions`` (first occurrence
    wins for duplicates) so that the ``topk`` truncation and any later
    tie-breaking are deterministic across runs.

    base_tools: iterable of hashable tool names allowed as successors.
    predictions: iterable of hashable predicted tool names.
    topk: maximum number of tools to return.

    Returns a list with at most ``topk`` tool names.
    """
    allowed = set(base_tools)
    # dict.fromkeys drops repeated predictions while keeping first-seen order
    matched = [tool for tool in dict.fromkeys(predictions) if tool in allowed]
    return matched[:topk]


def sort_by_usage(t_list, class_weights, d_dict):
    """
    Order tools by descending usage/class weight.

    Each tool in ``t_list`` is mapped to its id through ``d_dict`` and the
    weight is read from ``class_weights`` under ``str(id)``. Returns two
    parallel lists: tool names and their weights, highest weight first.
    """
    usage = {name: class_weights[str(d_dict[name])] for name in t_list}
    ranked = sorted(usage.items(), key=operator.itemgetter(1), reverse=True)
    names = [name for name, _ in ranked]
    weights = [weight for _, weight in ranked]
    return names, weights


def separate_predictions(base_tools, predictions, last_tool_name, weight_values, topk):
    """
    Get predictions from published and normal workflows

    The scores in ``predictions`` are multiplied by ``weight_values``, the
    ``topk`` highest-scoring positions are mapped to tool names, and only
    names listed under ``last_tool_name`` in ``base_tools`` are kept.

    Besides its arguments, this function reads the module-level names
    ``reverse_dictionary``, ``class_weights`` and ``dictionary``.

    base_tools: mapping from a tool name to its allowed successors, given
        either as a comma-separated string or as a sequence of names.
    predictions: prediction scores (presumably a 1-D numpy array whose
        length matches ``weight_values`` -- confirm against the caller).
    last_tool_name: name of the last tool in the current sequence.
    weight_values: per-position weights multiplied into ``predictions``.
    topk: number of highest-scoring positions to consider.

    Returns two parallel lists (tool names, weights) as produced by
    ``sort_by_usage``. Both are empty when ``last_tool_name`` has no entry
    in ``base_tools``.
    """
    weighted = predictions * weight_values
    # argsort is ascending, so the last ``topk`` positions hold the highest scores
    topk_prediction_pos = np.argsort(weighted, axis=-1)[-topk:]
    # get tool ids
    pred_tool_ids = [reverse_dictionary[str(tool_pos)] for tool_pos in topk_prediction_pos]
    last_base_tools = list()
    if last_tool_name in base_tools:
        # get published or compatible tools for the last tool in a sequence of tools
        last_base_tools = base_tools[last_tool_name]
        if isinstance(last_base_tools, str):
            last_base_tools = last_base_tools.split(",")
    # get predicted tools
    p_tools = get_predicted_tools(last_base_tools, pred_tool_ids, topk)
    return sort_by_usage(p_tools, class_weights, dictionary)


def compute_recommendations(model, tool_sequence, labels, dictionary, reverse_dictionary, class_weights, topk=10, max_seq_len=25):
    """
    Print the recommended next tools for a comma-separated tool sequence.

    The sequence is encoded as a zero-padded vector of tool ids, passed to
    ``model.predict`` and the resulting score vector is split in half: the
    first half is matched against the module-level ``standard_connections``
    and the second half against the module-level ``compatible_tools``.
    Recommendations from the first half are listed before the others.

    model: object exposing ``predict(sample, verbose=0)``.
    tool_sequence: comma-separated tool names; every name must be a key
        of ``dictionary``.
    labels: not used by this function; kept so existing calls still work.
    dictionary: mapping from tool name to tool id.
    reverse_dictionary: mapping from ``str(tool id)`` to tool name.
    class_weights: mapping from ``str(tool id)`` to the tool's weight.
    topk: maximum number of recommendations taken from each half.
    max_seq_len: length of the padded input vector.

    Returns None; all results are printed.

    Raises KeyError if a tool in ``tool_sequence`` is not in ``dictionary``
    and IndexError if the sequence has more than ``max_seq_len`` tools.
    """
    tl_seq = tool_sequence.split(",")
    unknown_tools = [t for t in tl_seq if t not in dictionary]
    if unknown_tools:
        # same exception type as the plain lookup, but names every offender
        raise KeyError(
            "Unknown tool(s) in sequence: %s. Use one of the names stored in "
            "the data dictionary." % ", ".join(unknown_tools)
        )
    if len(tl_seq) > max_seq_len:
        # same exception type the out-of-range array write would produce
        raise IndexError(
            "Tool sequence has %d tools but at most %d are supported"
            % (len(tl_seq), max_seq_len)
        )
    tl_seq_ids = [str(dictionary[t]) for t in tl_seq]
    last_tool_name = tl_seq[-1]
    # zero-padded input vector: tool ids first, zeros after
    sample = np.zeros(max_seq_len)
    for idx, tool_id in enumerate(tl_seq_ids):
        sample[idx] = int(tool_id)
    sample_reshaped = np.reshape(sample, (1, max_seq_len))
    weight_val = list(class_weights.values())
    weight_val = np.reshape(weight_val, (len(weight_val),))
    tool_sequence_names = [reverse_dictionary[tool_id] for tool_id in tl_seq_ids]
    # predict next tools for a test path
    prediction = model.predict(sample_reshaped, verbose=0)
    nw_dimension = prediction.shape[1]
    prediction = np.reshape(prediction, (nw_dimension,))
    half_len = nw_dimension // 2
    # get recommended tools from published workflows
    pub_t, _ = separate_predictions(standard_connections, prediction[:half_len], last_tool_name, weight_val, topk)
    # get recommended tools from normal workflows
    c_t, _ = separate_predictions(compatible_tools, prediction[half_len:], last_tool_name, weight_val, topk)
    # promote recommended tools coming from published workflows to the top,
    # then show the other recommendations; dict.fromkeys removes duplicates
    # while keeping that order
    recommended = list(dict.fromkeys(pub_t + c_t))
    tool_ids = {name: dictionary[name] for name in recommended}
    print()
    print("Current tool sequence: ")
    print()
    print(",".join(tool_sequence_names))
    print()
    print("Overall recommendations: ")
    print()
    print(recommended)
    print()
    print("Recommended tool ids:")
    print()
    for name, tool_id in tool_ids.items():
        wt = class_weights[str(tool_id)]
        print(f"{name}({tool_id!s})({wt!s})")

## Unpack trained model for prediction

In [67]:
# Location of the HDF5 file that bundles the model configuration, the layer
# weights and the JSON-encoded lookup tables read below.
model_path = "../data/feb_22/tool_recommendation_model.hdf5"
with h5py.File(model_path, 'r') as trained_model:
    # Each of these entries is stored as a JSON string.
    model_config = json.loads(trained_model["model_config"][()])
    dictionary = json.loads(trained_model["data_dictionary"][()])
    class_weights = json.loads(trained_model["class_weights"][()])
    standard_connections = json.loads(trained_model["standard_connections"][()])
    compatible_tools = json.loads(trained_model["compatible_tools"][()])
    # Rebuild the network architecture; the weights are attached by create_model.
    loaded_model = model_from_json(model_config)
    # create_model reads `trained_model`, `dictionary` and `loaded_model` as
    # module-level names, so it has to be called here, while the file is open
    # and after those names are bound.
    model, dictionary, reverse_dictionary = create_model(model_path)

[[ 2.5616992e-02 -9.5330365e-03  2.5758710e-02 ...  4.8993718e-02
  -2.2291685e-02  3.0981805e-02]
 [ 3.1450633e-02  2.6231024e-02  1.9229297e-02 ... -2.6189530e-02
  -4.6460070e-02  1.7058577e-02]
 [-6.9594055e-01 -1.3811614e-01 -5.5850995e-01 ... -2.8869215e-01
   7.7565575e-01 -2.0640020e-01]
 ...
 [-9.5944655e-01 -2.8117830e-01 -6.4869680e-02 ...  5.9618433e-03
  -3.1834251e-01  1.6874691e+00]
 [-7.2706378e-01 -9.6543670e-02 -2.5350434e-01 ...  3.1502196e-01
   2.4037234e-01  4.2095359e-02]
 [-3.4828998e-02  6.8374500e-03 -2.4526609e-02 ...  3.1785283e-02
  -5.0432459e-03  5.1588938e-04]]
[[ 0.22047408 -0.68096167 -0.24587005 ...  0.24888049 -0.6327199
   0.6910393 ]
 [-0.05191123  0.73034334  0.0808339  ... -0.20938997 -0.33471814
   0.07782038]
 [-0.03099243  0.14367282  0.18357219 ... -0.5525934   0.23153993
   0.3233636 ]
 ...
 [ 0.58498645 -0.01972618  0.22199841 ...  0.42214638 -0.3336556
   0.13377893]
 [-0.3953597  -0.30302262 -0.10988069 ... -0.5793919   0.66111535
  -1.07

## Example tools

In [68]:
# Assembly: 
# (https://training.galaxyproject.org/training-material/topics/assembly/tutorials/debruijn-graph-assembly/tutorial.html)
# spades -> 'bandage_info', 'fasta-stats', 'bandage_image', 'fasta_filter_by_length', 'abricate', 'quast', 'mlst' ... 
# velveth -> velvetg
# (https://training.galaxyproject.org/training-material/topics/assembly/tutorials/unicycler-assembly/tutorial.html)
# unicycler -> 'bandage_info', 'glimmer_build-icm', 'glimmer_knowlegde-based', 'bandage_image', 'transdecoder', 'minimap2', 'antismash', 'fasta_filter_by_length' ...


## Computational chemistry
# ctb_remDuplicates -> ctb_remIons 
# ctb_remDuplicates,ctb_remIons -> 'ctb_chemfp_mol2fps', 'ctb_compound_convert'
# ctb_remDuplicates,ctb_remIons,ctb_chemfp_mol2fps -> 'ctb_chemfp_butina_clustering', 'ctb_simsearch', 'ctb_chemfp_nxn_clustering', 'comp1'
# (https://training.galaxyproject.org/training-material/topics/computational-chemistry/tutorials/cheminformatics/tutorial.html)


## RAD-seq
# (https://training.galaxyproject.org/training-material/topics/ecology/tutorials/ref-based-rad-seq/tutorial.html)
# stacks_procrad -> 'bwa', 'bwa_wrapper', 'Grep1', 'stacks_denovomap', 'fastqc', 'fastq_filter'


## Epigenetics
# (https://training.galaxyproject.org/training-material/topics/epigenetics/tutorials/atac-seq/tutorial.html)
# cutadapt,bowtie2 -> 'samtools_flagstat', 'picard_MarkDuplicates', 'picard_AddOrReplaceReadGroups', 'macs2_callpeak', 'bg_sortmerna', 'multiqc', 'hisat2', 'trim_galore', 'bowtie2' ...
# cutadapt,bowtie2,picard_MarkDuplicates -> 'picard_ReorderSam', 'gatk4_mutect2', 'samtools_rmdup'
# cutadapt,bowtie2,picard_MarkDuplicates,genrich -> 'pygenomeTracks'
#
# (https://training.galaxyproject.org/training-material/topics/epigenetics/tutorials/methylation-seq/tutorial.html)
# bwameth -> 'samtools_rmdup', 'samtools_sort', 'bam_to_sam', 'pileometh' ..
# bwameth,pileometh -> 'deeptools_compute_matrix', 'tp_sed_tool', 'Filter1', 'metilene', 'Remove beginning1', 'wig_to_bigWig'
#
# bowtie2 -> 'samtools_flagstat', 'picard_MarkDuplicates', 'picard_AddOrReplaceReadGroups' ...
# bowtie2,deeptools_multi_bam_summary -> 'deeptools_plot_pca', 'r_correlation_matrix' ...
# (https://training.galaxyproject.org/training-material/topics/epigenetics/tutorials/formation_of_super-structures_on_xi/tutorial.html)

# bowtie2,hicexplorer_hicbuildmatrix -> 'hicexplorer_hicsummatrices', 'hicexplorer_hicplotviewpoint', 'tp_sed_tool', 'hicexplorer_hiccorrectmatrix', 'hicexplorer_hicmergematrixbins', 'hicexplorer_hicpca' ..
# bowtie2,hicexplorer_hicbuildmatrix,hicexplorer_hicmergematrixbins -> 'hicexplorer_hiccorrectmatrix', 'hicexplorer_hicplottads', 'hicexplorer_hicplotmatrix'
# (https://training.galaxyproject.org/training-material/topics/epigenetics/tutorials/hicexplorer/tutorial.html)

# minfi_read450k -> 'minfi_getbeta'


## Genome annotation
# (https://training.galaxyproject.org/training-material/topics/genome-annotation/tutorials/annotation-with-maker/tutorial.html)
# maker -> 'gffread', 'maker_map_ids', 'jcvi_gff_stats'
# (https://training.galaxyproject.org/training-material/topics/genome-annotation/tutorials/annotation-with-prokka/tutorial.html)
# prokka -> 'mlst', 'jbrowse', 'taxonomy_krona_chart' ...


## Imaging
# (https://training.galaxyproject.org/training-material/topics/imaging/tutorials/hela-screen-analysis/tutorial.html)
# ip_filter_standard -> 'ip_histogram_equalization', 'ip_threshold', 'ip_count_objects'
# ip_filter_standard,ip_threshold -> 'ip_binary_to_labelimage', 'ip_2d_split_binaryimage_by_watershed', 'ip_count_objects', 'ip_convertimage'
# ip_filter_standard,ip_threshold,ip_2d_split_binaryimage_by_watershed -> 'ip_2d_filter_segmentation_by_features', 'ip_2d_feature_extraction'


## Mass spectrometry
# mass_spectrometry_imaging_preprocessing -> 'mass_spectrometry_imaging_combine', 'mass_spectrometry_imaging_preprocessing'
# mass_spectrometry_imaging_preprocessing,mass_spectrometry_imaging_combine -> 'maldi_quant_preprocessing', 'mass_spectrometry_imaging_preprocessing', 'mass_spectrometry_imaging_qc' ...
# search_gui -> peptide_shaker
# search_gui,peptide_shaker -> 'mz_to_sqlite', 'Remove beginning1', 'tp_replace_in_column', 'unipept', 'proteomics_moff' ...


## Single cell
# raceid_main -> 'seurat','raceid_trajectory'
# raceid_main,raceid_trajectory -> 'raceid_inspecttrajectory'
# raceid_inspectclusters,__BUILD_LIST__ -> 'picard_MarkDuplicates', 'hisat2', 'stringtie_merge', 'cutadapt', 'tp_cat'
# raceid_filtnormconf -> 'raceid_clustering', '__BUILD_LIST__'
# raceid_filtnormconf,raceid_clustering -> 'raceid_trajectory', 'raceid_inspectclusters'
# scanpy_regress_variable -> scanpy_scale_data 
# scanpy_regress_variable,scanpy_scale_data -> 'scanpy_run_pca', 'scanpy_find_variable_genes'
# scanpy_regress_variable,scanpy_scale_data,scanpy_run_pca -> 'scanpy_compute_graph', 'scanpy_plot', 'scanpy_run_tsne', 'scanpy_plot_embed'

# (https://training.galaxyproject.org/training-material/topics/transcriptomics/tutorials/scrna-preprocessing-tenx/tutorial.html)
# rna_starsolo -> 'dropletutils', 'multiqc'
# rna_starsolo,dropletutils -> 'scanpy_read_10x', 'scanpy_cluster_reduce_dimension', 'seurat_read10x', 'anndata_import', 'scanpy_plot', 'raceid_filtnormconf'


## Variant calling
# (https://training.galaxyproject.org/training-material/topics/variant-analysis/tutorials/microbial-variants/tutorial.html)
# snippy -> 'bedtools_intersectbed', 'vcfvcfintersect', 'Remove beginning1', 'snippy_core', 'qualimap_bamqc', 'freebayes', 'jbrowse', 'vcfcombine'

# (https://training.galaxyproject.org/training-material/topics/variant-analysis/tutorials/somatic-variants/tutorial.html)
# trimmomatic,bwa_mem,samtools_rmdup,bamleftalign -> 'samtool_filter2', 'ivar_variants', 'freebayes', 'varscan_somatic', 'deeptools_bam_coverage', 'fastqc', 'ngsutils_bam_filter', 'bamFilter', 'rgPicFixMate', 'samtools_calmd'
# trimmomatic,bwa_mem,samtools_rmdup,bamleftalign,varscan_somatic -> 'bcftools_norm', 'gemini_annotate', 'vt_normalize', 'vcffilter2', 'vcfallelicprimitives', 'snpEff'

# (https://training.galaxyproject.org/training-material/topics/variant-analysis/tutorials/dip/tutorial.html)
# freebayes -> 'vcfallelicprimitives', 'bcftools_norm', 'custom_pro_db', 'vcfvcfintersect'
# freebayes,vcfallelicprimitives -> 'vt_normalize', 'snpSift_filter', 'snpSift_annotate', 'snpEff'
# freebayes,vcfallelicprimitives,snpEff -> 'gemini_load', 'mimodd_varreport', 'snpSift_extractFields'


# Transcriptomics
# (https://training.galaxyproject.org/training-material/topics/transcriptomics/tutorials/small_ncrna_clustering/tutorial.html)
# samtools_sort,blockclust,sort1 -> 'cshl_awk_tool', 'Show beginning1', 'tp_awk_tool', 'blockbuster'

# (https://training.galaxyproject.org/training-material/topics/transcriptomics/tutorials/ref-based/tutorial.html)
# cutadapt -> 'umi_tools_extract', 'fastq_paired_end_interlacer', 'chira_collapse', 'rna_star',
# cutadapt,rna_star -> 'featurecounts', 'multiqc', 'htseq_count', 'rseqc_infer_experiment', 'samtools_stats', 'bamFilter',
# cutadapt,rna_star,featurecounts -> 'multiqc', 'deseq2', 'tp_sort_header_tool', 'bamFilter', 'collection_column_join' ...


# Single-cell HiC
# schicexplorer_schicqualitycontrol -> 'schicexplorer_schicnormalize'
# schicexplorer_schicnormalize -> 'schicexplorer_schicclustersvl', 'schicexplorer_schicconsensusmatrices', 'schicexplorer_schicplotclusterprofiles'
# schicexplorer_schicnormalize,schicexplorer_schicclustersvl -> 'schicexplorer_schicplotclusterprofiles', 'schicexplorer_schicconsensusmatrices'


# Animal detection on acoustic recording
# vigiechiro_idvalid -> 'vigiechiro_bilanenrichipf', 'vigiechiro_bilanenrichirp'

## Recommended tools

In [69]:
topk = 10 # set the maximum number of recommendations
# Give the tool names as a comma-separated sequence to see the recommendations.
# Every name must be a key of 'dictionary' (compute_recommendations looks each
# one up there); print the variable 'reverse_dictionary' to list the known tools.
tool_seq = "trimmomatic"
compute_recommendations(model, tool_seq, "", dictionary, reverse_dictionary, class_weights, topk)

KeyError: 'trimmomatic'