In [16]:
import numpy as np
import json
import warnings
import operator

import h5py
from keras.models import model_from_json
from keras import backend as K

from matplotlib import pyplot as plt

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Presumably font sizes for plot titles/labels and a prediction count;
# none of the three is used by the cells visible here -- TODO confirm.
size_title, size_label = 18, 14
n_pred = 2

# Directory that holds the trained model and its companion JSON files.
base_path = "data/remote_sqrt/"

# JSON side files, all located directly under base_path.
path_data_dict = f"{base_path}data_dict.txt"
path_inverted_wt = f"{base_path}inverted_weights.txt"
path_usage_wt = f"{base_path}usage_prediction.txt"
path_class_wt = f"{base_path}class_weights.txt"
path_test_data = f"{base_path}test_data.txt"

# HDF5 file read by create_model().
model_path = f"{base_path}trained_model.hdf5"

def read_file(file_path):
    """Parse the text file at ``file_path`` as JSON and return the decoded object."""
    with open(file_path, 'r') as handle:
        return json.load(handle)

# Load the JSON side files used by verify_model(): class weights, usage
# predictions, inverted weights and the data dictionary (read in that order).
class_weights, usage_weights, inverted_weights, data_dict = (
    read_file(json_path)
    for json_path in (path_class_wt, path_usage_wt, path_inverted_wt, path_data_dict)
)

def create_model(model_path):
    """Rebuild the trained Keras model and its lookup tables from an HDF5 file.

    The file is read for the datasets ``model_config``, ``data_dictionary``
    and ``compatible_tools`` (each holding JSON text) and for the layer
    weights, stored as consecutively numbered datasets ``weight_0``,
    ``weight_1``, ... Reading stops at the first missing index.

    Parameters
    ----------
    model_path : str
        Path to the HDF5 file.

    Returns
    -------
    tuple
        ``(loaded_model, dictionary, reverse_dictionary, compatible_tools)``.
        ``reverse_dictionary`` maps ``str(value)`` of every ``dictionary``
        entry back to its key.

    Raises
    ------
    KeyError
        If one of the three required datasets is missing from the file.
    """
    # The context manager guarantees the HDF5 handle is closed, even on error.
    with h5py.File(model_path, 'r') as trained_model:
        # ``dataset[()]`` reads the full dataset; it replaces ``Dataset.value``,
        # which was deprecated and then removed in h5py 3.0.
        model_config = json.loads(trained_model['model_config'][()])
        dictionary = json.loads(trained_model['data_dictionary'][()])
        compatible_tools = json.loads(trained_model['compatible_tools'][()])

        # Collect weight_0, weight_1, ... until the next index is absent.
        # An explicit membership test keeps genuine read errors visible
        # instead of treating every exception as "end of weights".
        model_weights = []
        weight_ctr = 0
        while "weight_" + str(weight_ctr) in trained_model:
            model_weights.append(trained_model["weight_" + str(weight_ctr)][()])
            weight_ctr += 1

    loaded_model = model_from_json(model_config)
    # set the model weights
    loaded_model.set_weights(model_weights)
    reverse_dictionary = dict((str(v), k) for k, v in dictionary.items())
    return loaded_model, dictionary, reverse_dictionary, compatible_tools

# Load the model and lookup tables once for the cells below. The spelling
# ``compatibile_tools`` is kept because a later cell refers to this name.
model, dictionary, reverse_dictionary, compatibile_tools = create_model(model_path)

In [17]:
# Display the mapping from tool index (as a string) to tool name.
reverse_dictionary

{'1': 'gatk_realigner_target_creator',
 '2': 'seq_filter_by_id',
 '3': 'w4mclstrpeakpics',
 '4': 'tp_awk_tool',
 '5': 'mothur_classify_seqs',
 '6': 'gatk2_print_reads',
 '7': 'circgraph',
 '8': 'bowtie2',
 '9': 'mimodd_varreport',
 '10': 'NSPDK_candidateClust',
 '11': 'gatk2_variant_filtration',
 '12': 'rmcontamination',
 '13': 'metaspades',
 '14': 'Draw_phylogram',
 '15': 'sequence_content_trimmer',
 '16': 'vcfallelicprimitives',
 '17': 'CONVERTER_Bam_Bai_0',
 '18': 'mothur_parsimony',
 '19': 'sklearn_feature_selection',
 '20': 'align_families',
 '21': 'cast',
 '22': 'velvet',
 '23': 'Show tail1',
 '24': 'fragmenter',
 '25': 'samtools_stats',
 '26': 'idr-embl',
 '27': 'rbc_mafft',
 '28': 'vcf2pgSnp',
 '29': 'qiime_core_diversity',
 '30': 'vcf_filter',
 '31': 'edger',
 '32': 't2t_report',
 '33': 'bcftools_view',
 '34': 'bamFilter',
 '35': 'ncbi_makeblastdb',
 '36': 'gtf_filter_by_attribute_values_list',
 '37': 'fastq_paired_end_deinterlacer',
 '38': 'EMBOSS: shuffleseq87',
 '39': 'bgch

In [20]:
def verify_model(model, tl_seq, labels, dictionary, reverse_dictionary, compatible_tools, topk=5, max_seq_len=25):
    """Print the top-k predictions of ``model`` for each growing prefix of a tool sequence.

    For every tool in ``tl_seq`` the tool index is written into the input
    vector, the model is queried, and the ``topk`` highest-scoring positions
    are printed together with their class, usage and inverted weights.

    Besides its arguments, this function reads the module-level mappings
    ``class_weights``, ``usage_weights``, ``inverted_weights`` and
    ``data_dict``.

    Parameters
    ----------
    model : object
        Object exposing ``predict(array, verbose=0)``; the result is reshaped
        to a 1-D vector of scores.
    tl_seq : str
        Comma-separated tool indices, e.g. ``"8,250"``.
    labels, dictionary
        Accepted for signature compatibility; not used.
    reverse_dictionary : dict
        Maps ``str(index)`` to tool name.
    compatible_tools : dict
        Maps a tool name to a comma-separated string of tool names.
    topk : int
        Number of highest-scoring positions to report.
    max_seq_len : int
        Length of the input vector passed to the model.

    Returns
    -------
    None
        All results are printed.
    """
    # Use the argument itself; the previous code read an undefined global
    # (``tool_sequence``) and therefore only worked with leftover kernel state.
    tool_ids = tl_seq.split(",")
    last_tool_name = reverse_dictionary[str(tool_ids[-1])]
    print(last_tool_name)
    # Loop-invariant: tools listed as compatible with the last tool of the sequence.
    last_compatible_tools = set(compatible_tools[last_tool_name].split(","))

    sample = np.zeros(max_seq_len)
    for idx, tool_id in enumerate(tool_ids):
        print(tool_id)
        # NOTE(review): the sequence is written from the last slot backwards,
        # i.e. in reverse order at the end of the vector -- confirm this matches
        # the padding/ordering used when the model was trained.
        sample[-(idx + 1)] = int(tool_id)
        sample_reshaped = np.reshape(sample, (1, max_seq_len))

        # predict next tools for the current prefix
        prediction = model.predict(sample_reshaped, verbose=0)
        prediction = np.reshape(prediction, (prediction.shape[1],))
        prediction_pos = np.argsort(prediction, axis=-1)

        # top-k positions (ascending score) and their scores as percentages
        topk_prediction_pos = prediction_pos[-topk:]
        topk_prediction_val = [np.round(prediction[pos] * 100, 2) for pos in topk_prediction_pos]

        # read tool names using reverse dictionary
        pred_tool_ids = [reverse_dictionary[str(tool_pos)] for tool_pos in topk_prediction_pos]
        # predicted tools that are also listed as compatible with the last tool
        actual_next_tool_ids = list(set(pred_tool_ids).intersection(last_compatible_tools))

        print("Actual tools: %s" % ",".join(actual_next_tool_ids))
        print()

        pred_tool_ids_sorted = _sort_by_value_desc(dict(zip(pred_tool_ids, topk_prediction_val)))

        cls_wt = dict()
        usg_wt = dict()
        inv_wt = dict()
        for k in pred_tool_ids_sorted:
            cls_wt[k] = np.round(class_weights[str(data_dict[k])], 2)
            usg_wt[k] = np.round(usage_weights[k], 2)
            inv_wt[k] = np.round(inverted_weights[str(data_dict[k])], 2)

        print("Predicted tools: \n")
        print(pred_tool_ids_sorted)
        print()
        print("Class weights: \n")
        print(_sort_by_value_desc(cls_wt))
        print()
        print("Usage weights: \n")
        print(_sort_by_value_desc(usg_wt))
        print()
        print("Inverted weights: \n")
        print(_sort_by_value_desc(inv_wt))
        print("======================================")


def _sort_by_value_desc(mapping):
    """Return a new dict holding the items of ``mapping`` ordered by value, largest first."""
    return dict(sorted(mapping.items(), key=lambda kv: kv[1], reverse=True))

def get_predictions(model, dictionary, reverse_dictionary, compatibile_tools, max_paths=399):
    """Run verify_model() on entries of the test-data JSON file.

    The file at ``path_test_data`` is read as a mapping; each key is passed to
    verify_model() as the tool sequence and each value as ``labels``.
    Processing stops after ``max_paths`` entries.

    Parameters
    ----------
    model, dictionary, reverse_dictionary, compatibile_tools
        Forwarded unchanged to verify_model().
    max_paths : int
        Maximum number of entries to process.
    """
    t_data = read_file(path_test_data)
    for n_done, (ph, cl) in enumerate(t_data.items(), start=1):
        verify_model(model, ph, cl, dictionary, reverse_dictionary, compatibile_tools)
        if n_done >= max_paths:
            break


# Inspect the predictions for one hand-picked sequence of comma-separated tool indices.
tool_seq = "8,250"
verify_model(model, tool_seq, "", dictionary, reverse_dictionary, compatibile_tools)
# To run over the stored test paths instead, call
# get_predictions(model, dictionary, reverse_dictionary, compatibile_tools).

samtools_flagstat
8
Actual tools: multiqc

Predicted tools: 

{'samtools_flagstat': 85.43, 'multiqc': 77.63, 'picard_MarkDuplicates': 75.86, 'macs2_callpeak': 75.37, 'samtools_rmdup': 70.4}

Class weights: 

{'macs2_callpeak': 1422.38, 'picard_MarkDuplicates': 1129.38, 'samtools_rmdup': 641.38, 'multiqc': 596.61, 'samtools_flagstat': 468.53}

Usage weights: 

{'macs2_callpeak': 576.96, 'multiqc': 434.46, 'picard_MarkDuplicates': 392.84, 'samtools_flagstat': 295.48, 'samtools_rmdup': 253.4}

Inverted weights: 

{'macs2_callpeak': 3506.6, 'picard_MarkDuplicates': 3246.85, 'samtools_rmdup': 1623.43, 'multiqc': 819.3, 'samtools_flagstat': 742.92}
250
Actual tools: multiqc

Predicted tools: 

{'sam_to_bam': 20.55, 'tp_cat': 19.1, 'multiqc': 17.68, 'samtools_rmdup': 15.43, 'Cut1': 12.31}

Class weights: 

{'samtools_rmdup': 641.38, 'multiqc': 596.61, 'tp_cat': 293.59, 'sam_to_bam': 236.4, 'Cut1': 93.84}

Usage weights: 

{'Cut1': 1425.46, 'multiqc': 434.46, 'samtools_rmdup': 253.4, 'tp_cat':