In [1]:
import time
import os
import subprocess
import sys
import re
import argparse
import collections
import gzip
import math
import shutil
import matplotlib.pyplot as plt
import wandb
import numpy as np
import time
from datetime import datetime
import random

import seaborn as sns
%matplotlib inline
import logging
from silence_tensorflow import silence_tensorflow
#silence_tensorflow()
os.environ['TPU_LOAD_LIBRARY']='0'
os.environ['TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE']='False'
import tensorflow as tf


import tensorflow.experimental.numpy as tnp
import tensorflow_addons as tfa
from tensorflow import strings as tfs
from tensorflow.keras import mixed_precision
from scipy.stats.stats import pearsonr  
from scipy.stats.stats import spearmanr  
## custom modules
import src.aformer_TF_gc_separated as aformer
#import src.aformer_TF as aformer
from src.layers.layers import *
import src.metrics as metrics
from src.optimizers import *
import src.schedulers as schedulers
import src.utils as utils

import training_utils_aformer_TF_genecentered_separated as training_utils


from scipy import stats



In [None]:
# Connect to the TPU named 'node-15', initialise it, and build the distribution
# strategy used by every later cell. The three calls must run in this order.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='node-15')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    # tf.data options: shard by file and allow non-deterministic element order.
    # NOTE(review): `options` is not referenced again in the cells visible here;
    # confirm whether it is meant to be applied to the dataset.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.FILE
    options.deterministic=False
    #options.experimental_threading.max_intra_op_parallelism = 1
    # Global (process-wide) settings: bfloat16 mixed precision and the JIT
    # optimizer flag.
    mixed_precision.set_global_policy('mixed_bfloat16')
    tf.config.optimizer.set_jit(True)
    #options.num_devices = 64

    # One example per replica; the global batch is one example per replica in sync.
    BATCH_SIZE_PER_REPLICA = 1
    NUM_REPLICAS = strategy.num_replicas_in_sync
    GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * NUM_REPLICAS

In [None]:
with strategy.scope():
    # Assign each organism key a consecutive integer (its position in `orgs`).
    orgs = ["hg", "mm"]
    heads_dict = {org: int(idx) for idx, org in enumerate(orgs)}

    # Build the network under the distribution strategy, then restore the
    # trained weights from the checkpoint in GCS.
    model = aformer.aformer(
        kernel_transformation="relu_kernel_transformation",
        dropout_rate=0.35,
        input_length=16384,
        num_heads=4,
        numerical_stabilizer=0.0000001,
        nb_random_features=128,
        hidden_size=128,
        d_model=128,
        norm=True,
        dim=32,
        max_seq_length=128,
        rel_pos_bins=512,
        widening=2,  # ratio between first and second dense layer units in transformer block
        conv_filter_size_1_seq=15,
        conv_filter_size_2_seq=5,
        conv_filter_size_1_atac=15,
        conv_filter_size_2_atac=5,
        positional_dropout_rate=0.1,
        transformer_depth=1,
        momentum=0.90,
        channels_list=[48, 48, 56, 56, 64, 64],
        kernel_regularizer=0.0000001,
        bottleneck_units=32,
        bottleneck_units_tf=32,
        use_mask_pos=False,
        use_rot_emb=True,
        heads_dict=heads_dict,
    )
    model.load_weights("gs://picard-testing-176520/16k_genecentered_blacklist0.50_atacnormalized/models/aformer_TF_gene_centered_test/final/saved_model")

In [None]:
with strategy.scope():
    # TFRecord file holding the example(s) to interpret.
    gcs_path = "gs://picard-testing-176520/test_gene/test_gene.tfr"

    # The trailing positional arguments are passed through unchanged; their
    # meaning is defined by training_utils.return_dataset_interpret, which is
    # not visible in this notebook (16384 matches the model's input_length).
    val_data = training_utils.return_dataset_interpret(
        gcs_path,
        strategy,
        1,
        16384,
        "logTPM",
        4,
        10,
        1637,
    )

In [None]:
with strategy.scope():

    # `model` is deliberately NOT rebuilt here. The model-construction cell
    # earlier in this notebook already creates it with exactly these
    # hyper-parameters and restores the same checkpoint, and this cell cannot
    # run without that cell anyway (it needs `heads_dict` from it). Building
    # and loading an identical second copy only costs TPU memory and start-up
    # time, so the helpers below operate on the existing `model`.


    def predict_on_batch(model, inputs):
        """Forward `inputs` to ``model.predict_on_batch`` and return its result unchanged."""
        outputs = model.predict_on_batch(inputs)
        return outputs

    @tf.function
    def contribution_input_grad(model, model_inputs, output_head='hg'):
        """Compute input-gradient attributions for one output head.

        The gradient of the selected head's prediction is taken with respect to
        each of the three model inputs. For the sequence and ATAC inputs it is
        multiplied element-wise by the input itself (gradient x input).

        The products use the tensors passed in ``model_inputs`` rather than any
        notebook-level variable, so the result always corresponds to the batch
        that was actually passed in.

        Args:
            model: object exposing ``predict_on_batch(inputs)`` whose first
                return element is a mapping from head name to prediction.
            model_inputs: ``(seq, atac, tf_acc)`` tensors. Only batch element 0
                is used for the returned attributions.
            output_head: key of the prediction head to differentiate.

        Returns:
            Tuple ``(seq_grads, tss_grads, atac_grads, tf_acc_grads, prediction)``:
              * seq_grads: gradient x input of ``seq`` channels 1.., summed over
                the channel axis.
              * tss_grads: raw gradient of ``seq`` channel 0 (presumably a TSS
                indicator track, judging by the name -- confirm).
              * atac_grads: gradient x input of ``atac``.
              * tf_acc_grads: raw gradient of ``tf_acc``.
              * prediction: the differentiated model output.
        """
        seq, atac, tf_acc = model_inputs

        with tf.GradientTape() as input_grad_tape:
            # The inputs are plain tensors, not variables, so they must be
            # watched explicitly for the tape to record gradients.
            input_grad_tape.watch(seq)
            input_grad_tape.watch(atac)
            input_grad_tape.watch(tf_acc)
            inputs = seq, atac, tf_acc
            prediction = model.predict_on_batch(inputs)[0][output_head]

        # One gradient per input, returned in the same order as `inputs`.
        input_grads_seq, input_grads_atac, input_grads_tf_acc = (
            input_grad_tape.gradient(prediction, inputs)
        )

        seq_grads = tf.reduce_sum(input_grads_seq[0, :, 1:] * seq[0, :, 1:],
                                  axis=1)

        tss_grads = input_grads_seq[0, :, 0]

        atac_grads = input_grads_atac[0, :] * atac[0, :]

        tf_acc_grads = input_grads_tf_acc[0, :]

        return seq_grads, tss_grads, atac_grads, tf_acc_grads, prediction


In [None]:
with strategy.scope():
    # Pull one batch and, for each model input, take element 0 of its `.values`.
    test_input = next(val_data)
    inputs = tuple(
        test_input[key].values[0] for key in ('inputs', 'atac', 'TF_acc')
    )
    # scores = (seq_grads, tss_grads, atac_grads, tf_acc_grads, prediction)
    scores = contribution_input_grad(model, inputs)

In [None]:
def plot_track_seq(tss_tokens, input_arr, start, stop, height=1.5):
    """Plot ``input_arr[start:stop]`` as a filled area around a zero baseline.

    Args:
        tss_tokens: currently unused; kept so existing call sites keep working.
        input_arr: 1-D sliceable sequence of values to plot.
        start: first index to plot (inclusive).
        stop: index at which to stop (exclusive).
        height: currently unused; kept so existing call sites keep working.
    """
    # One x coordinate per plotted sample, so sample i is drawn at x == i.
    # (np.linspace(start, stop, num=stop - start) includes the endpoint and
    # therefore spaces the points slightly more than one unit apart.)
    x_vals = np.arange(start, stop)
    baseline = np.zeros(len(x_vals))

    # Fill the region between the zero baseline and the track values.
    plt.fill_between(x_vals, baseline, input_arr[start:stop], alpha=0.9)

    plt.show()
    

In [None]:
# Element 0 of the batch's 'tss_tokens' values; reused by later plotting cells.
tss_tokens = test_input['tss_tokens'].values[0]
# scores[0] is the channel-summed sequence gradient x input track.
plot_track_seq(tss_tokens, scores[0], start=0, stop=16384)

In [None]:
# Magnitude of the sequence gradient x input track over the full window.
abs_seq_scores = tf.abs(scores[0])
plot_track_seq(tss_tokens, abs_seq_scores, start=0, stop=16384)

In [None]:
# scores[1] is the raw gradient of sequence channel 0; add a leading axis so it
# renders as a single-row heat map.
tss_grad_row = tf.abs(scores[1][tf.newaxis])
plt.imshow(tss_grad_row, aspect="auto", cmap="viridis")
current_axes = plt.gca()
current_axes.set_yticks([])
plt.show()

In [None]:



#plot_track_seq(tss_tokens,test_input['atac'].values[0][0,:,0],0,16384)


# scores[2] is the ATAC gradient x input; show its magnitude, transposed.
atac_grad_image = tf.transpose(tf.abs(scores[2]))
plt.imshow(atac_grad_image, aspect="auto", cmap="viridis")
current_axes = plt.gca()
current_axes.set_yticks([])
plt.show()

In [None]:
#plot_track_seq(tss_tokens,test_input['atac'].values[0][0,:,0],0,16384)


# The ATAC input itself (batch element 0), transposed, for visual comparison.
atac_input_image = tf.transpose(test_input['atac'].values[0][0, :])
plt.imshow(atac_input_image, aspect="auto", cmap="viridis")
current_axes = plt.gca()
current_axes.set_yticks([])
plt.show()

In [None]:
# `tss_tokens` with a leading axis added, drawn as a heat map.
tss_token_row = tss_tokens[np.newaxis]
plt.imshow(tss_token_row, aspect="auto", cmap="viridis")
current_axes = plt.gca()
current_axes.set_yticks([])
plt.show()

In [None]:
# Re-assemble the (seq, atac, TF_acc) inputs from the current batch and take the
# second element returned by the model, which is indexed by layer name below.
inputs = test_input['inputs'].values[0],test_input['atac'].values[0], test_input['TF_acc'].values[0]
att_matrices = model.predict_on_batch(inputs)[1]
# The 'layer_0' entry unpacks into two tensors, named k_1 and q_1 here.
k_1,q_1 = att_matrices['layer_0']
# NOTE(review): this is an element-wise product, not a matmul, and tf.transpose
# without `perm` reverses ALL axes -- confirm against the tensor shapes that this
# is the intended attention score. 128.0 presumably mirrors d_model/hidden_size
# used at model construction; verify.
mat = tf.nn.softmax((k_1[:,0,:,:] * tf.transpose(q_1[:,0,:,:])) / tf.math.sqrt(128.0))

# Average four slices along the third axis of `mat` (4 presumably = num_heads;
# TODO confirm that axis is the head axis).
mat_ave = sum([mat[:,0,k,:] for k in range(4)]) / 4.0

# Show the averaged matrix with a colour scale.
plt.matshow(mat_ave)
plt.colorbar()
plt.show()