In [3]:
"""
Create transformers
"""
import os
import time
import numpy as np
import subprocess
import h5py
import matplotlib.pyplot as plt
import sys
import random
import json

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GRU, Dropout, Embedding, SpatialDropout1D
from tensorflow.keras.layers import MultiHeadAttention, LayerNormalization, Dropout, Layer
from tensorflow.keras import backend as K

from tensorflow.keras.layers import Embedding, Input, GlobalAveragePooling1D, Dense
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import EarlyStopping

import numpy as np
from sklearn.manifold import TSNE
import pandas as pd
import seaborn as sns
from matplotlib.pyplot import figure


# Shared plotting defaults: a large square figure size and a small serif
# font that is registered with matplotlib for every subsequent plot.
fig_size = (15, 15)
font = dict(family='serif', size=8)
plt.rc('font', **font)

# --- Evaluation settings -------------------------------------------------
batch_size = 100
test_batches = 1
n_topk = 1
max_seq_len = 25

# --- Network hyper-parameters ---------------------------------------------
embed_dim = 128  # embedding size for each token (d_model)
num_heads = 4  # number of attention heads
ff_dim = 128  # hidden size of the feed-forward network in the transformer (dff)
dropout = 0.1
seq_len = 25

# Switch between the RNN baseline and the transformer model.
predict_rnn = False

# Earlier run directories used during development:
#   rnn:         log_19_09_22_GPU_RNN_full_data/, log_22_08_22_rnn/, log_08_08_22_rnn/
#   transformer: log_19_09_22_GPU_transformer_full_data/, log_12_09_22_GPU/
base_path = "../models/rnn/" if predict_rnn is True else "../models/transformer/"

# Checkpoint (training step) whose saved artefacts are loaded.
model_number = 40000
model_path = f"{base_path}saved_model/{model_number}/tf_model/"
model_path_h5 = f"{base_path}saved_model/{model_number}/tf_model_h5/"


import tensorflow as tf
from tensorflow.keras.layers import MultiHeadAttention, LayerNormalization, Dropout, Layer
from tensorflow.keras.layers import Dense, Embedding
from tensorflow.keras.models import Sequential


class TransformerBlock(Layer):
    """Self-attention followed by a two-layer feed-forward network.

    The output of each of the two sub-layers goes through dropout, is added
    back to that sub-layer's input and is then layer-normalised.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        """Create the sub-layers.

        embed_dim: used as `key_dim` of the attention layer and as the output
            width of the feed-forward network.
        num_heads: number of attention heads.
        ff_dim: width of the hidden (ReLU) feed-forward layer.
        rate: dropout rate used inside the attention layer and after both
            sub-layers.
        """
        super().__init__()
        # NOTE: the sub-layers are created in a fixed order on purpose; keep
        # it stable so previously saved weights still line up.
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, dropout=rate)
        self.ffn = Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training):
        """Run the block and return `(block_output, attention_scores)`."""
        # Self-attention: the same tensor is passed as query, value and key.
        attended, attention_scores = self.att(
            inputs,
            inputs,
            inputs,
            return_attention_scores=True,
            training=training,
        )
        attended = self.dropout1(attended, training=training)
        residual = self.layernorm1(inputs + attended)

        transformed = self.ffn(residual)
        transformed = self.dropout2(transformed, training=training)
        return self.layernorm2(residual + transformed), attention_scores


class TokenAndPositionEmbedding(Layer):
    """Sum of a learned token embedding and a learned position embedding."""

    def __init__(self, maxlen, vocab_size, embed_dim):
        """Create the two embedding tables.

        maxlen: number of rows of the position table.
        vocab_size: number of rows of the token table.
        embed_dim: width of both embeddings.
        """
        super().__init__()
        self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim, mask_zero=True)
        self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim, mask_zero=True)

    def call(self, x):
        """Embed `x` and add the embedding of positions 0..len-1 along its last axis."""
        last_axis_len = tf.shape(x)[-1]
        position_vectors = self.pos_emb(tf.range(start=0, limit=last_axis_len, delta=1))
        token_vectors = self.token_emb(x)
        return token_vectors + position_vectors


def read_file(file_path):
    """
    Parse the JSON document stored at `file_path` and return the decoded object.
    """
    with open(file_path, "r") as handle:
        return json.load(handle)


def write_file(file_path, content):
    """
    Serialise `content` as JSON and write it to `file_path`.

    Any existing file at that path is overwritten: opening in "w" mode
    truncates it, so no separate removal step is needed. (The previous
    version called an undefined `remove_file` helper and therefore raised
    NameError on every call.)
    """
    with open(file_path, "w") as json_file:
        json_file.write(json.dumps(content))

def create_rnn_model(seq_len, vocab_size):
    """
    Build the GRU baseline: embedding -> GRU -> GRU -> sigmoid dense layer,
    with dropout after the embedding and after each GRU.

    Uses the module-level `embed_dim`, `ff_dim` and `dropout` settings.
    Returns a Keras `Model` mapping a `(batch, seq_len)` input to a
    `vocab_size`-wide sigmoid output.
    """
    token_ids = tf.keras.Input(batch_shape=(None, seq_len))

    # Weighted layers are instantiated up front, in a fixed order.
    embedding_layer = tf.keras.layers.Embedding(vocab_size, embed_dim, mask_zero=True)
    first_gru = tf.keras.layers.GRU(ff_dim, return_sequences=True, return_state=False)
    second_gru = tf.keras.layers.GRU(ff_dim, return_sequences=False, return_state=True)
    output_layer = tf.keras.layers.Dense(vocab_size, activation='sigmoid', kernel_regularizer="l2")

    hidden = embedding_layer(token_ids)
    hidden = tf.keras.layers.Dropout(dropout)(hidden)

    hidden = first_gru(hidden)
    hidden = tf.keras.layers.Dropout(dropout)(hidden)

    # The second GRU also returns its final state, which is not used.
    hidden, _final_state = second_gru(hidden)
    hidden = tf.keras.layers.Dropout(dropout)(hidden)

    tool_scores = output_layer(hidden)
    return Model(inputs=[token_ids], outputs=[tool_scores])

def create_transformer_model(maxlen, vocab_size):
    """
    Build the transformer classifier.

    Pipeline: token+position embedding -> one TransformerBlock -> global
    average pooling -> dropout -> ReLU dense -> dropout -> sigmoid dense.
    Uses the module-level `embed_dim`, `num_heads`, `ff_dim` and `dropout`.

    Returns a Keras `Model` with three outputs, in this order: the
    dropped-out hidden dense features, the `vocab_size`-wide sigmoid
    scores, and the attention scores of the transformer block.
    """
    token_ids = Input(shape=(maxlen,))

    embedded = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)(token_ids)
    encoded, attention = TransformerBlock(embed_dim, num_heads, ff_dim)(embedded)

    pooled = GlobalAveragePooling1D()(encoded)
    pooled = Dropout(dropout)(pooled)
    features = Dense(ff_dim, activation="relu")(pooled)
    features = Dropout(dropout)(features)
    tool_scores = Dense(vocab_size, activation="sigmoid")(features)

    return Model(inputs=token_ids, outputs=[features, tool_scores, attention])


def sample_balanced(x_seqs, y_labels, ulabels_tr_dict):
    """
    Draw one batch containing one random sequence for each of up to
    `batch_size` randomly chosen keys of `ulabels_tr_dict`.

    x_seqs, y_labels: arrays indexable by a list of row indices.
    ulabels_tr_dict: maps a tool/label to the list of row indices that
        belong to it. The lists are shuffled in place.

    Returns `(x_batch, y_batch)` as int64 tensors.
    """
    batch_tools = list(ulabels_tr_dict.keys())
    random.shuffle(batch_tools)
    rand_batch_indices = list()
    for l_tool in batch_tools[:batch_size]:
        # Fixed: the lookup previously used a garbled, undefined name
        # instead of the `ulabels_tr_dict` parameter and raised NameError.
        seq_indices = ulabels_tr_dict[l_tool]
        random.shuffle(seq_indices)
        rand_batch_indices.append(seq_indices[0])

    x_batch_train = x_seqs[rand_batch_indices]
    y_batch_train = y_labels[rand_batch_indices]
    unrolled_x = tf.convert_to_tensor(x_batch_train, dtype=tf.int64)
    unrolled_y = tf.convert_to_tensor(y_batch_train, dtype=tf.int64)
    return unrolled_x, unrolled_y


def get_u_tr_labels(y_tr):
    """
    Collect the positive label positions of every row of `y_tr`.

    Returns a pair `(u_labels, labels_pos_dict)`: the distinct positions
    whose value is > 0 in at least one row, and a dict mapping each such
    position to the de-duplicated list of row indices where it is > 0.
    """
    seen_labels = []
    rows_by_label = {}
    for row_idx, target_row in enumerate(y_tr):
        active = np.nonzero(target_row > 0)[0]
        seen_labels.extend(active)
        for label in active:
            rows_by_label.setdefault(label, []).append(row_idx)

    u_labels = list(set(seen_labels))
    labels_pos_dict = {
        label: list(set(row_ids)) for label, row_ids in rows_by_label.items()
    }
    return u_labels, labels_pos_dict


def sample_balanced_tr_y(x_seqs, y_labels, ulabels_tr_y_dict):
    """
    Draw a label-balanced batch of at most `batch_size` distinct rows.

    The labels (keys of `ulabels_tr_y_dict`) are visited in random order;
    for each one its row-index list is shuffled in place and its first
    entry is taken unless that row was already picked.

    Returns `(x_batch, y_batch, labels, row_indices)`, where the first two
    are int64 tensors and the last two list, in pick order, the label each
    row was drawn for and the row index itself.
    """
    shuffled_labels = list(ulabels_tr_y_dict.keys())
    random.shuffle(shuffled_labels)

    chosen_labels = []
    chosen_rows = []
    for label in shuffled_labels:
        candidates = ulabels_tr_y_dict[label]
        random.shuffle(candidates)
        pick = candidates[0]

        if pick not in chosen_rows:
            chosen_rows.append(pick)
            chosen_labels.append(label)
        if len(chosen_rows) == batch_size:
            break

    x_rows = x_seqs[chosen_rows]
    y_rows = y_labels[chosen_rows]
    x_batch = tf.convert_to_tensor(x_rows, dtype=tf.int64)
    y_batch = tf.convert_to_tensor(y_rows, dtype=tf.int64)
    return x_batch, y_batch, chosen_labels, chosen_rows


def verify_tool_in_tr(r_dict):
    """
    Count how often each selected tool occurs and store the counts.

    Reads `data/all_sel_tool_ids.txt` under `base_path` (a JSON string of
    comma-separated tool ids), counts the occurrences per tool id and per
    tool name (ids are translated through `r_dict`, keyed by the id as a
    string), and writes both tables, sorted by descending count, to
    `data/s_freq.txt` and `data/s_freq_names.txt`.

    Returns the id -> count dict sorted by descending count.
    """
    all_sel_tool_ids = read_file(base_path + "data/all_sel_tool_ids.txt").split(",")

    freq_dict = dict()
    freq_dict_names = dict()

    for tool_id in all_sel_tool_ids:
        tool_name = r_dict[str(int(tool_id))]
        freq_dict[tool_id] = freq_dict.get(tool_id, 0) + 1
        # Fixed: the name counter used to be guarded by a membership test on
        # the tool *id*, which reset the count to 0 on every occurrence, so
        # every tool name ended up with a frequency of 1.
        freq_dict_names[tool_name] = freq_dict_names.get(tool_name, 0) + 1

    s_freq = dict(sorted(freq_dict.items(), key=lambda kv: kv[1], reverse=True))
    s_freq_names = dict(sorted(freq_dict_names.items(), key=lambda kv: kv[1], reverse=True))

    write_file(base_path + "data/s_freq_names.txt", s_freq_names)
    write_file(base_path + "data/s_freq.txt", s_freq)

    return s_freq


def read_h5_model():
    """
    Rebuild the transformer model from the h5 checkpoint and load its metadata.

    Opens `model.h5` under `model_path_h5`, decodes the JSON datasets
    `reverse_dict`, `class_weights`, `compatible_tools` and
    `standard_connections`, and creates a transformer model of vocabulary
    size `len(reverse_dict) + 1` into which the weights are loaded.

    Returns `(model, f_dict, r_dict, class_weights, compatible_tools,
    standard_connections, model_loading_time)`, where `f_dict` is the
    inverse mapping of `r_dict` and the loading time (seconds) covers model
    creation plus weight loading.
    """
    print(model_path_h5)
    h5_path = model_path_h5 + "model.h5"

    # The context manager closes the file even if a dataset is missing or the
    # weight loading fails (the handle used to leak in those cases).
    with h5py.File(h5_path, 'r') as model_h5:
        r_dict = json.loads(model_h5["reverse_dict"][()].decode("utf-8"))

        m_load_s_time = time.time()
        tf_loaded_model = create_transformer_model(seq_len, len(r_dict) + 1)
        tf_loaded_model.load_weights(h5_path)
        m_load_e_time = time.time()
        model_loading_time = m_load_e_time - m_load_s_time

        f_dict = dict((v, k) for k, v in r_dict.items())
        c_weights = json.loads(model_h5["class_weights"][()].decode("utf-8"))
        c_tools = json.loads(model_h5["compatible_tools"][()].decode("utf-8"))
        s_conn = json.loads(model_h5["standard_connections"][()].decode("utf-8"))

    return tf_loaded_model, f_dict, r_dict, c_weights, c_tools, s_conn, model_loading_time


def read_model():
    """
    Load the TensorFlow SavedModel at `model_path` plus its JSON side files.

    Returns `(model, f_dict, r_dict, class_weights, compatible_tools,
    published_connections, load_time)`; `load_time` is the time in seconds
    spent in `tf.saved_model.load` only.
    """
    print(model_path)

    load_started = time.time()
    tf_loaded_model = tf.saved_model.load(model_path)
    load_finished = time.time()
    m_l_time = load_finished - load_started

    data_dir = base_path + "data/"
    r_dict = read_file(data_dir + "rev_dict.txt")
    f_dict = read_file(data_dir + "f_dict.txt")
    c_weights = read_file(data_dir + "class_weights.txt")
    c_tools = read_file(data_dir + "compatible_tools.txt")
    s_conn = read_file(data_dir + "published_connections.txt")

    return tf_loaded_model, f_dict, r_dict, c_weights, c_tools, s_conn, m_l_time
    
    
def plot_TSNE(embed, labels):
    """
    Project `embed` to two dimensions with t-SNE and show a scatter plot
    whose points are coloured by `labels`.
    """
    print("Plotting embedding...")
    print(labels)

    figure(figsize=(8, 8), dpi=150)

    projected = TSNE(n_components=2).fit_transform(embed)
    df = pd.DataFrame({"comp-1": projected[:, 0], "comp-2": projected[:, 1]})

    scatter = sns.scatterplot(x="comp-1", y="comp-2", hue=labels, data=df)
    scatter.set(title="T-SNE projection")
    plt.show()
    

def predict_seq():
    """
    Load the test set and the trained model and print prediction diagnostics.

    Reads `saved_data/test.h5` under `base_path` and loads either the saved
    RNN model (`predict_rnn` is True) or the transformer rebuilt from the h5
    checkpoint. For each of `test_batches` label-balanced batches it prints,
    per sequence, the input tools, the true and predicted tools, the
    precision and the precision against published connections, and plots the
    attention of that sequence. It then plots a t-SNE projection of the
    collected embeddings and terminates the interpreter with `sys.exit()`.

    NOTE(review): everything after the `sys.exit()` call (evaluation of the
    lowest-frequency test ids, the batch summary and the hand-written example
    sequence) is unreachable as the code stands.
    """

    #visualize_loss_acc()  

    #sys.exit()

    #plot_model_usage_time()

    #tool_tr_freq = read_file(base_path + "data/all_sel_tool_ids.txt")
    #verify_training_sampling(tool_tr_freq, r_dict)  

    path_test_data = base_path + "saved_data/test.h5"
    
    print(path_test_data)

    # NOTE(review): this file handle is never closed inside the function.
    file_obj = h5py.File(path_test_data, 'r')

    #test_target = tf.convert_to_tensor(np.array(file_obj["target"]), dtype=tf.int64)
    test_input = np.array(file_obj["input"])
    test_target = np.array(file_obj["target"])

    print(test_input.shape, test_target.shape)

    if predict_rnn is True:
        print(model_path)
        m_load_s_time = time.time()
        tf_loaded_model = tf.saved_model.load(model_path)
        m_load_e_time = time.time()
        model_loading_time = m_load_e_time - m_load_s_time
        r_dict = read_file(base_path + "data/rev_dict.txt")
        f_dict = read_file(base_path + "data/f_dict.txt")
        class_weights = read_file(base_path + "data/class_weights.txt")
        compatible_tools = read_file(base_path + "data/compatible_tools.txt")
        published_connections = read_file(base_path + "data/published_connections.txt")
    else:
        #tf_loaded_model, f_dict, r_dict, class_weights, compatible_tools, published_connections, model_loading_time = read_model()
        tf_loaded_model, f_dict, r_dict, class_weights, compatible_tools, published_connections, model_loading_time = read_h5_model()

    # Class weights as a float tensor, in the iteration order of the dict's values.
    c_weights = list(class_weights.values())

    c_weights = tf.convert_to_tensor(c_weights, dtype=tf.float32)

    # Map every positive target position to the test rows that contain it.
    u_te_y_labels, u_te_y_labels_dict = get_u_tr_labels(test_target)

    precision = list()
    pub_prec_list = list()
    error_label_tools = list()
    batch_pred_time = list()
    for j in range(test_batches):

        # Draw a label-balanced batch of test rows.
        te_x_batch, y_train_batch, selected_label_tools, bat_ind = sample_balanced_tr_y(test_input, test_target, u_te_y_labels_dict)

        #print(j * batch_size, j * batch_size + batch_size)
        #te_x_batch = test_input[j * batch_size : j * batch_size + batch_size, :]
        #y_train_batch = test_target[j * batch_size : j * batch_size + batch_size, :]

        te_x_batch = tf.cast(te_x_batch, dtype=tf.float32, name="input_2")
        
        pred_s_time = time.time()
        
        if predict_rnn is True:
            te_prediction = tf_loaded_model(te_x_batch, training=False)
        else:
            embed, te_prediction, att_weights = tf_loaded_model(te_x_batch, training=False)
            print(embed.shape, te_prediction.shape, att_weights.shape)
           
        pred_e_time = time.time()
        # Average time per sequence; divides by the configured batch_size rather
        # than the number of rows actually returned by the sampler.
        diff_time = (pred_e_time - pred_s_time) / float(batch_size)
        batch_pred_time.append(diff_time)
        # Re-created for every batch, so after the loop these hold the last batch only.
        filter_embed = list()
        filter_embed_label = list()
        filter_embed_label_names = list()
        for i, (inp, tar) in enumerate(zip(te_x_batch, y_train_batch)):

            t_ip = te_x_batch[i]
            tar = y_train_batch[i]
            prediction = te_prediction[i]
            if len(np.where(inp > 0)[0]) <= max_seq_len:
            #if len(np.where(tar > 0)[0]) == 1:
                print(tar, len(np.where(tar > 0)[0]))
                # Positions of the positive entries of the target vector.
                real_prediction = np.where(tar > 0)[0]
                target_pos = real_prediction #list(set(all_tr_label_tools_ids).intersection(set(real_prediction)))

                prediction_wts = tf.math.multiply(c_weights, prediction)
                # NOTE(review): `embed` is only assigned in the transformer branch above.
                filter_embed.append(embed[i])
                # Assigning n_topk here makes it a local variable of this function;
                # k is set to the number of true targets of this sequence.
                n_topk = len(target_pos)
                top_k = tf.math.top_k(prediction, k=n_topk, sorted=True)
                print(i, top_k.indices.numpy())
                top_k_wts = tf.math.top_k(prediction_wts, k=n_topk, sorted=True)

                t_ip = t_ip.numpy()
                label_pos = np.where(t_ip > 0)[0]
                
                # One randomly chosen true target is used as the t-SNE label of this sequence.
                one_target_pos = target_pos[np.random.randint(len(target_pos))]
                filter_embed_label_names.append(r_dict[str(one_target_pos)])
                filter_embed_label.append(str(one_target_pos))
                
                i_names = ",".join([r_dict[str(int(item))] for item in t_ip[label_pos]  if item not in [0, "0"]])
                t_names = ",".join([r_dict[str(int(item))] for item in target_pos  if item not in [0, "0"]])

                last_i_tool = [r_dict[str(int(item))] for item in t_ip[label_pos]][-1]

                true_tools = [r_dict[str(int(item))] for item in target_pos]

                pred_tools = [r_dict[str(int(item))] for item in top_k.indices.numpy()  if item not in [0, "0"]]
                pred_tools_wts = [r_dict[str(int(item))] for item in top_k_wts.indices.numpy()  if item not in [0, "0"]]

                intersection = list(set(true_tools).intersection(set(pred_tools)))

                pub_prec = 0.0
                pub_prec_wt = 0.0

                # Precision against the published connections of the last input tool, if any.
                if last_i_tool in published_connections:
                    true_pub_conn = published_connections[last_i_tool]

                    if len(pred_tools) > 0:
                        intersection_pub = list(set(true_pub_conn).intersection(set(pred_tools)))
                        intersection_pub_wt = list(set(true_pub_conn).intersection(set(pred_tools_wts)))
                        pub_prec = float(len(intersection_pub)) / len(pred_tools)
                        pub_prec_list.append(pub_prec)
                        pub_prec_wt = float(len(intersection_pub_wt)) / len(pred_tools)
                    else:
                        pub_prec = False
                        pub_prec_wt = False

                if len(pred_tools) > 0:
                    pred_precision = float(len(intersection)) / len(pred_tools)
                    precision.append(pred_precision)

                # NOTE(review): pred_precision is only assigned when pred_tools is
                # non-empty; otherwise the value of an earlier iteration is reused or
                # the name is unbound. As a ratio it never exceeds 1, so the `< 2.0`
                # test passes for every computed value.
                if pred_precision < 2.0:
            
                    print("Test batch {}, Tool sequence: {}".format(j+1, [r_dict[str(int(item))] for item in t_ip[label_pos]]))
                    print()
                    print("Test batch {}, True tools: {}".format(j+1, true_tools))
                    print()
                    print("Test batch {}, Predicted top {} tools: {}".format(j+1, n_topk, pred_tools))
                    print()
                    print("Test batch {}, Predicted top {} tools with weights: {}".format(j+1, n_topk, pred_tools_wts))
                    print()
                    print("Test batch {}, Precision: {}".format(j+1, pred_precision)) 
                    print()
                    print("Test batch {}, Published precision: {}".format(j+1, pub_prec))
                    print()
                    print("Test batch {}, Published precision with weights: {}".format(j+1, pub_prec_wt))
                    print()
                    print("Time taken to predict tools: {} seconds".format(diff_time))
                    print("=========================")
                print("--------------------------")
                # NOTE(review): `att_weights` is only assigned in the transformer branch above.
                generated_attention(att_weights[i], i_names, f_dict, r_dict)
                print("Batch {} prediction finished ...".format(j+1))

    plot_TSNE(np.array(filter_embed), filter_embed_label_names)
    # NOTE(review): the interpreter exits here; the code below never runs.
    sys.exit()
    
    te_lowest_t_ids = read_file(base_path + "data/te_lowest_t_ids.txt")
    lowest_t_ids = [int(item) for item in te_lowest_t_ids.split(",")]
    print(lowest_t_ids)
    # Only the first id of the list is evaluated.
    lowest_t_ids = lowest_t_ids[:1]
    
    low_te_data = test_input[lowest_t_ids]
    low_te_labels = test_target[lowest_t_ids]
    low_te_data = tf.cast(low_te_data, dtype=tf.float32)
    low_topk = 20
    low_te_precision = list()
    low_te_pred_time = list()

    pred_s_time = time.time()
    if predict_rnn is True:
        bat_low_prediction = tf_loaded_model(low_te_data, training=False)
    else:
        bat_embed_low, bat_low_prediction, att_weights = tf_loaded_model(low_te_data, training=False)
        print(bat_embed_low.shape, bat_low_prediction.shape, att_weights.shape)
    pred_e_time = time.time()
    low_diff_pred_t = (pred_e_time - pred_s_time) / float(len(lowest_t_ids))
    low_te_pred_time.append(low_diff_pred_t)
    print("Time taken to predict tools: {} seconds".format(low_diff_pred_t))

    for i, (low_inp, low_tar) in enumerate(zip(low_te_data, low_te_labels)):

        low_prediction = bat_low_prediction[i]
        low_tar = low_te_labels[i]
        low_label_pos = np.where(low_tar > 0)[0]

        # k is overridden with the number of true labels of this sequence.
        low_topk = len(low_label_pos)
        low_topk_pred = tf.math.top_k(low_prediction, k=low_topk, sorted=True)
        low_topk_pred = low_topk_pred.indices.numpy()
        
        low_label_pos_tools = [r_dict[str(int(item))] for item in low_label_pos if item not in [0, "0"]]
        low_pred_label_pos_tools = [r_dict[str(int(item))] for item in low_topk_pred if item not in [0, "0"]]

        low_intersection = list(set(low_label_pos_tools).intersection(set(low_pred_label_pos_tools)))
        low_pred_precision = float(len(low_intersection)) / len(low_label_pos)
        low_te_precision.append(low_pred_precision)

        low_inp_pos = np.where(low_inp > 0)[0]
        low_inp = low_inp.numpy()
        print(low_inp, low_inp_pos)
        print("{}, Low: test tool sequence: {}".format(i, [r_dict[str(int(item))] for item in low_inp[low_inp_pos]]))
        print()
        print("{},Low: True labels: {}".format(i, low_label_pos_tools))
        print()
        print("{},Low: Predicted labels: {}, Precision: {}".format(i, low_pred_label_pos_tools, low_pred_precision))
       
        print("-----------------")
        print()

    if test_batches > 0:
        print("Batch Precision@{}: {}".format(n_topk, np.mean(precision)))
        print("Batch Published Precision@{}: {}".format(n_topk, np.mean(pub_prec_list)))
        print("Batch Trained model loading time: {} seconds".format(model_loading_time))
        print("Batch average seq pred time: {} seconds".format(np.mean(batch_pred_time)))
        print("Batch total model loading and pred time: {} seconds".format(model_loading_time + np.mean(batch_pred_time)))
        print()
        
    print("----------------------------")
    print()
    print("Predicting for individual sequences...")
    print()
    print("Predicting for individual tools or sequences")
    n_topk_ind = 20
    # Hand-written example sequence: a zero-filled row of length 25 whose first
    # five slots are set to the ids of the tools named below.
    t_ip = np.zeros((1, 25))

    t_ip[0, 0] = int(f_dict["bowtie2"])
    t_ip[0, 1] = int(f_dict["hicexplorer_hicbuildmatrix"])
    t_ip[0, 2] = int(f_dict["hicexplorer_chicqualitycontrol"])
    t_ip[0, 3] = int(f_dict["hicexplorer_chicviewpointbackgroundmodel"])
    t_ip[0, 4] = int(f_dict["hicexplorer_chicviewpoint"])
    
    last_tool_name = "hicexplorer_chicviewpoint"
    
    t_ip = tf.convert_to_tensor(t_ip, dtype=tf.int64)
    t_ip = tf.cast(t_ip, dtype=tf.float32)
    
    pred_s_time = time.time()
    if predict_rnn is True:
        prediction = tf_loaded_model(t_ip, training=False)
    else:
        indi_embed, prediction, att_weights = tf_loaded_model(t_ip, training=False)
        print(indi_embed.shape, prediction.shape, att_weights.shape)
    pred_e_time = time.time()
    print("Time taken to predict tools: {} seconds".format(pred_e_time - pred_s_time))
    prediction_cwts = tf.math.multiply(c_weights, prediction)

    top_k = tf.math.top_k(prediction, k=n_topk_ind, sorted=True)
    top_k_wts = tf.math.top_k(prediction_cwts, k=n_topk_ind, sorted=True)

    t_ip = t_ip.numpy()[0]
    label_pos = np.where(t_ip > 0)[0]
    print(t_ip)
    print(t_ip.shape, t_ip[label_pos])
    i_names = ",".join([r_dict[str(int(item))] for item in t_ip[label_pos] if item not in [0, "0"]])

    pred_tools = [r_dict[str(int(item))] for item in top_k.indices.numpy()[0] if item not in [0, "0"]]
    pred_tools_wts = [r_dict[str(int(item))] for item in top_k_wts.indices.numpy()[0] if item not in [0, "0"]]

    # Tools listed as compatible with the last tool of the sequence, if any.
    c_tools = []
    if str(f_dict[last_tool_name]) in compatible_tools:
        c_tools = [r_dict[str(item)] for item in compatible_tools[str(f_dict[last_tool_name])]]

    pred_intersection = list(set(pred_tools).intersection(set(c_tools)))
    prd_te_prec = len(pred_intersection) / float(n_topk_ind)

    print("Tool sequence: {}".format([r_dict[str(int(item))] for item in t_ip[label_pos]]))
    print()
    print("Compatible true tools: {}, size: {}".format(c_tools, len(c_tools)))
    print()
    print("Predicted top {} tools: {}".format(n_topk_ind, pred_tools))
    print()
    print("Predicted precision: {}".format(prd_te_prec))
    print()
    print("Correctly predicted tools: {}".format(pred_intersection))
    print()
    print("Predicted top {} tools with weights: {}".format(n_topk_ind, pred_tools_wts))
    print()
    if predict_rnn is False:
        generated_attention(att_weights, i_names, f_dict, r_dict)


def generated_attention(attention_weights, i_names, f_dict, r_dict):
    """
    Plot the first attention head for one tool sequence.

    attention_weights: attention scores; a leading axis of size 1 is
        squeezed away when present, otherwise the tensor is used as is.
    i_names: comma-separated tool names of the input sequence; they label
        both axes of the plot.
    f_dict, r_dict: unused; kept so existing callers keep working.
    """
    try:
        attention_heads = tf.squeeze(attention_weights, 0)
    except (tf.errors.InvalidArgumentError, ValueError):
        # The leading axis is not of size 1 (e.g. the weights of a single
        # sample were passed), so there is nothing to squeeze. Only the
        # errors tf.squeeze raises for that case are caught, instead of a
        # bare `except:` that would also hide unrelated failures.
        attention_heads = attention_weights

    tokens = i_names.split(",")

    # Only the first head is plotted.
    for head in attention_heads:
        plot_attention_head(tokens, tokens, head)
        break


def plot_attention_head(in_tokens, out_tokens, attention):
    """
    Show one attention matrix as a heat map with a colour bar.

    The matrix is cropped to `len(in_tokens)` rows and `len(out_tokens)`
    columns; `in_tokens` label the x ticks (rotated by 90 degrees) and
    `out_tokens` label the y ticks.
    """
    n_in = len(in_tokens)
    n_out = len(out_tokens)

    fig = plt.figure(figsize=(16, 8))
    axes = plt.gca()
    heat_map = axes.matshow(attention[:n_in, :n_out], interpolation='nearest')
    axes.set_xlabel('Head')

    axes.set_xticks(range(n_in))
    axes.set_xticklabels(in_tokens, rotation=90)

    axes.set_yticks(range(n_out))
    axes.set_yticklabels(out_tokens)

    fig.colorbar(heat_map)
    plt.tight_layout()
    plt.show()


def plot_attention_head_axes(att_weights):
    """
    Debug helper: squeeze the leading axis of `att_weights`, flatten the
    result and display two reshaped slices of it with `matshow`.

    The first slice is the first 25 * 4 values shaped (25, 4); the second
    is everything from offset shape[0] * shape[1] onwards, shaped
    (shape[0], shape[1]) of the squeezed array.
    """
    n_positions, head_count = 25, 4

    heads = tf.squeeze(att_weights, 0).numpy()
    print(heads.shape)
    #print(heads[0:, 0:])

    flat = heads.flatten()
    print(flat.shape)

    first_block = flat[: n_positions * head_count].reshape((n_positions, head_count))
    print(first_block.shape)

    dim0, dim1 = heads.shape[0], heads.shape[1]
    second_block = flat[dim0 * dim1:].reshape((dim0, dim1))
    print(second_block.shape)

    plt.matshow(first_block)
    plt.matshow(second_block)
    plt.show()



In [None]:
predict_seq()

../models/transformer/saved_data/test.h5


NameError: name 'f_dict' is not defined

In [None]:
'''

Tool seqs for good attention plots:
'schicexplorer_schicqualitycontrol', 'schicexplorer_schicnormalize', 'schicexplorer_schicclustersvl'


# Tested tools: porechop, schicexplorer_schicqualitycontrol, schicexplorer_schicclustersvl, snpeff_sars_cov_2
    # sarscov2genomes, ivar_covid_aries_consensus, remove_nucleotide_deletions, pangolin
    # bowtie2,lofreq_call
    # dropletutils_read_10x
    # 'bowtie2', 'hicexplorer_hicbuildmatrix'
    # 'mtbls520_04_preparations', 'mtbls520_05a_import_maf', 'mtbls520_06_import_traits', 'mtbls520_07_species_diversity'
    # ctsm_fates: 'xarray_metadata_info', 'interactive_tool_panoply', 'xarray_select', '__EXTRACT_DATASET__'
    # msnbase_readmsdata: 'abims_xcms_xcmsSet', 'xcms_export_samplemetadata', 'xcms_plot_chromatogram'
    # ncbi_eutils_esearch: ncbi_eutils_elink
    # 1_create_conf: '5_calc_stat', '4_filter_sam', '2_map', 'conf4circos', '3_filter_single_pair'
    # pdaug_peptide_data_access: pdaug_tsvtofasta
    # 'pdaug_peptide_data_access', 'pdaug_tsvtofasta': 'pdaug_peptide_sequence_analysis', 'pdaug_fishers_plot', 'pdaug_sequence_property_based_descriptors'
    # 'rankprodthree', 'Remove beginning1', 'cat1', 'Cut1', 'interactions': 'biotranslator', 'awkscript'
    # rpExtractSink: rpCompletion', 'retropath2'
    # 'EMBOSS: transeq101', 'ncbi_makeblastdb', 'ncbi_blastp_wrapper', 'blast_parser', 'hcluster_sg'
    # 'Remove beginning1', 'Cut1', 'param_value_from_file', 'kc-align', 'sarscov2formatter', 'hyphy_fel'
    # abims_CAMERA_annotateDiffreport
    # cooler_csort_pairix
    # mycrobiota-split-multi-otutable_ensembl_gtf2gene_list
    # XY_Plot_1
    # mycrobiota-qc-report
    # 1_create_conf
    # RNAlien
    # ont_fast5_api_multi_to_single_fast5 
    # ctb_remIons

    Incorrect predictions
    # scpipe, 
    # 'delly_call', 'delly_merge'ivar_covid_aries_consensus
    # 'gmap_build', 'gsnap', 'sam_to_bam', 'filter', 'assign', 'polyA'
    # 'bioext_bealign', 'tn93_filter', 'hyphy_cfel'
    # sklearn_build_pipeline
    # split_file_to_collection', 'rdock_rbdock', 'xchem_pose_scoring', 'sucos_max_score'
    # 'rmcontamination', 'scaffold2fasta'  
    # 'rmcontamination', 'scaffold2fasta'
    # cat1', 'fastq_filter', 'cshl_fastq_to_fasta', 'filter_16s_wrapper_script 1'
    # 'TrimPrimer', 'Flash', 'Btrim64', 'uparse'
    # 'cshl_fastq_to_fasta', 'cshl_fastx_trimmer', 'fasta_tabular_converter
    # CryptoGenotyper
    # cooler_makebins
    # 'PeakPickerHiRes', 'FileFilter', 'xcms-find-peaks', 'xcms-collect-peaks'
    # 'TrimPrimer', 'Flash', 'Btrim64'
    # cryptotyperanndata_import
    # ip_spot_detection_2d
    # 'picard_FastqToSam', 'TagBamWithReadSequenceExtended', 'FilterBAM', 'BAMTagHistogram'
    # 'basic_illumination', 'ashlar'
    # 'cghub_genetorrent', 'gatk_indel'
    # 'FeatureFinderMultiplex', 'HighResPrecursorMassCorrector', 'MSGFPlusAdapter', 'PeptideIndexer', 'IDMerger', 'ConsensusID'
    # 'PeakPickerHiRes', 'FileFilter', 'xcms-find-peaks', 'xcms-collect-peaks'
    # 'PeakPickerHiRes', 'FileFilter', 'xcms-find-peaks', 'xcms-collect-peaks', 'xcms-group-peaks', 'xcms-blankfilter', 'xcms-dilutionfilter', 'camera-annotate-peaks', 'camera-group-fwhm', 'camera-find-adducts', 'camera-find-isotopes'
    # 'minfi_read450k', 'minfi_mset'
    # 'msnbase_readmsdata', 'abims_xcms_xcmsSet', 'abims_xcms_refine'
    # # 'snpEff_build_gb', 'bwa_mem', 'samtools_view',

'''

'''
def predict_seq():


    #sys.exit()
    # read test sequences
    r_dict = read_file(base_path + "data/rev_dict.txt")
    f_dict = read_file(base_path + "data/f_dict.txt")
    
    tf_loaded_model = tf.saved_model.load(model_path)
    #predictor = predict_sequences.PredictSequence(tf_loaded_model)

    #predictor(test_input, test_target, f_dict, r_dict)

    #tool_name = "cutadapt"
    #print("Prediction for {}...".format(tool_name))
    bowtie_output = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
    bowtie_output = bowtie_output.write(0, [tf.constant(index_start_token, dtype=tf.int64)])
    #bowtie_output = bowtie_output.write(1, [tf.constant(295, dtype=tf.int64)])
    bowtie_o = tf.transpose(bowtie_output.stack())
    #tool_id = f_dict[tool_name]
    #print(tool_name, tool_id)
    tool_list = ["ctb_filter"]
    bowtie_input = np.zeros([1, 25])
    bowtie_input[:, 0] = index_start_token
    bowtie_input[:, 1] = f_dict[tool_list[0]]
    #bowtie_input[:, 2] = f_dict[tool_list[1]]
    #bowtie_input[:, 3] = f_dict["featurecounts"]
    #bowtie_input[:, 4] = f_dict["deseq2"]
    bowtie_input = tf.constant(bowtie_input, dtype=tf.int64)
    print(bowtie_input, bowtie_output, bowtie_o)
    bowtie_pred, _ = tf_loaded_model([bowtie_input, bowtie_o], training=False)
    print(bowtie_pred.shape)
    top_k = tf.math.top_k(bowtie_pred, k=10)
    print("Top k: ", bowtie_pred.shape, top_k, top_k.indices)
    print(np.all(top_k.indices.numpy(), axis=-1))
    print("Predicted tools for {}: {}".format( ",".join(tool_list), [r_dict[str(item)] for item in top_k.indices.numpy()[0][0]]))
    print()
    #print("Generating predictions...")
    #generated_attention(tf_loaded_model, f_dict, r_dict)


def generated_attention(trained_model, f_dict, r_dict):

    np_output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
    np_output_array = np_output_array.write(0, [tf.constant(index_start_token, dtype=tf.int64)])

    n_target_items = 5
    n_input = np.zeros([1, 25])
    n_input[:, 0] = index_start_token
    n_input[:, 1] = f_dict["hicexplorer_hicadjustmatrix"]
    #n_input[:, 2] = f_dict["hicexplorer_hicbuildmatrix"]
    #n_input[:, 3] = f_dict["hicexplorer_hicfindtads"]
    #n_input[:, 4] = f_dict["deseq2"]
    #n_input[:, 5] = f_dict["Add_a_column1"]
    #n_input[:, 6] = f_dict["table_compute"]
    a_input = n_input
    n_input = tf.constant(n_input, dtype=tf.int64)
   
    for i in range(n_target_items):
        #print(i, index)
        output = tf.transpose(np_output_array.stack())
        print("decoder input: ", n_input, output, output.shape)
        orig_predictions, _ = trained_model([n_input, output], training=False)
        #print(orig_predictions.shape)
        #print("true target seq real: ", te_tar_real)
        #print("Pred seq argmax: ", tf.argmax(orig_predictions, axis=-1))
        predictions = orig_predictions[:, -1:, :]
        predicted_id = tf.argmax(predictions, axis=-1)
        np_output_array = np_output_array.write(i+1, predicted_id[0])
    print(output, np_output_array.stack(), output.numpy())
    print("----------")
    last_decoder_layer = "decoder_layer4_block2"
    _, attention_weights = trained_model([n_input, output[:,:-1]], training=False)
    pred_attention = attention_weights[last_decoder_layer]

    print(attention_weights[last_decoder_layer].shape)
    head = 0
    attention_heads = tf.squeeze(attention_weights[last_decoder_layer], 0)
    pred_attention = attention_heads[head]
    print(pred_attention)

    #print(attention_weights)
    in_tokens = [r_dict[str(int(item))] for item in a_input[0] if item > 0]
    out_tokens = [r_dict[str(int(item))] for item in output.numpy()[0]]
    out_tokens = out_tokens[1:]
    print(in_tokens)
    print(out_tokens)
    pred_attention = pred_attention[:,:len(in_tokens)]
    print(pred_attention)
    plot_attention_head(in_tokens, out_tokens, pred_attention)

scanpy_read_10x
def plot_attention_head(in_tokens, out_tokens, attention):
  # The plot is of the attention when a token was generated.
  # The model didn't generate `<START>` in the output. Skip it.

  fig = plt.figure()
  ax = fig.add_subplot(111)
  cax = ax.matshow(attention, interpolation='nearest')
  fig.colorbar(cax)

  #ax = plt.gca()
  #ax.matshow(attention)

  ax.set_xticks(range(len(in_tokens)))
  ax.set_yticks(range(len(out_tokens)))

  ax.set_xticklabels(in_tokens, rotation=90)
  ax.set_yticklabels(out_tokens)

  plt.show()

'''


