In [1]:
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import time
import pyracular
import random
from biobeaker.utils import get_angles, positional_encoding
from biobeaker import BEAKER

from tensorflow.data import Dataset

from umap import UMAP
import plotly.express as px
from tensorflow.keras.layers import (
    Dense,
    Embedding,
    Flatten,
    Lambda,
    Subtract,
    Input,
    Concatenate,
    AveragePooling1D,
    Reshape,
    GRU,
    Bidirectional,
    Dropout,
    LSTM,
    Conv1D,
    Conv2D,
    LocallyConnected1D,
    Conv1DTranspose,
)
from tensorflow.keras.models import Model, Sequential

from lib.useful_windows import (
    calc_kmer_numeric_tuple,
    convert_tuple_to_string,
    calc_distance,
    convert_tuple_to_np,
    cos_sim,
    convert_string_to_nparray,
    convert_string_to_nparray_tuple,
)
from lib.bert_inspired import (
    get_angles,
    positional_encoding,
    fasta_generator,
    discriminator_layer,
    point_wise_feed_forward_network,
    gff3_type_classification_layer,
)
from TransformerModelPosConcat_12Oct import Transformer, CustomSchedule, ffn
from collections import defaultdict

tfk = tf.keras
tfkl = tf.keras.layers

In [2]:
# List the sequence headers stored in the .sfasta file that validation_gen reads below.
pyracular.get_headers_from_sfasta("Vvulg_chr2.sfasta")

['Chr2']

In [3]:
# Hyper parameters
k = 21  # k-mer length; each k-mer is encoded as k * 5 features (5 per base, see Magic below)
window_size = 32  # k-mers per window; model input has window_size + 1 rows (a CLS row is prepended)
num_layers = 8  # layer count passed to BEAKER
embedding_dims = 32  # width of the frozen Magic k-mer embedding that feeds BEAKER
output_dims = 128 # Output dims are also internal dims!
intermediate_dims = 256
num_heads = 8
dropout_rate = 0.15  # BEAKER dropout and attention_dropout
max_positions = 512
batch_size = 64  # tf.data batch size for training / validation pipelines

In [4]:
# Pretrained BEAKER encoder; its weights are restored from "beaker_medium_tripleloss" below.
transformer = BEAKER(num_layers, embedding_dims, output_dims, num_heads, intermediate_dims, max_positions,
                          dropout=dropout_rate, attention_dropout=dropout_rate, activation=tfa.activations.mish)

# Magic embeddings 
# 
# Kmer -> DNA Embedding
# Where kmer1 (k1) and kmer2 (k2)
# manhattan_distance(k1, k2) =~ alignment_distance(k1, k2)
#
# Frozen (trainable=False), bias-free Dense projection of each k*5-wide k-mer vector to
# embedding_dims, followed by swish. Its single kernel is loaded from a .npy file below.

magic = Dense(embedding_dims, 
                activation=tf.nn.swish, 
                name="Magic", 
                use_bias=False, 
                trainable=False,
                dtype=tf.float32)

# Build explicitly so set_weights() below has a kernel to overwrite.
magic.build((window_size+1,k*5))

#Load up the weights
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python objects, which can
# execute code embedded in the file -- only load weight files from a trusted source.
weights = np.load("weights/weights_wide_singlelayer_k21_3Aug2020model_21_dims_32_epochs256.npy", allow_pickle=True)
# Only weights[0][0] (one kernel matrix) is used, matching use_bias=False above.
magic.set_weights([weights[0][0]])

transformer.load_weights("beaker_medium_tripleloss")

# NOTE(review): redefined identically in the next cell; 105 is k * 5 hard-coded.
cls = np.asarray([[1] * 105])

In [5]:
# All-ones "CLS" row prepended to every window of k-mers by the generators (105 = k * 5).
cls = np.asarray([[1] * 105])
# Alternative 105-wide marker rows. gene_token is used by an exploratory decoder cell further
# down; end_of_query_token only appears in commented-out code.
gene_token = np.asarray([[1,0,1] * 35])
end_of_query_token = np.asarray([[0, 1, 1] * 35])

# Label ids produced by gen()/validation_gen(); the 6-wide vectors are the old one-hot layout:
# 0 is query start (for decoder) [1, 0, 0, 0, 0, 0]
# 1 is not a gene [0, 1, 0, 0, 0, 0]
# 2 is gene start [0, 0, 1, 0, 0, 0]
# 3 is gene [0, 0, 0, 1, 0, 0]
# 4 is end of gene [0, 0, 0, 0, 1, 0]
# 5 is end of decoder output [0, 0, 0, 0, 0, 1]
#
# Targets are one-hot encoded with depth n_classes = 5, so id 5 would become an all-zero row.
# The generators slice off the first (id 0) and last (id 5) entries, so training targets
# only contain ids 1-4.
n_classes = 5

# Label ids shared by the generators below (see the class table in the previous cell).
LABEL_QUERY_START = 0
LABEL_NOT_GENE = 1
LABEL_GENE_START = 2
LABEL_GENE = 3
LABEL_GENE_END = 4
LABEL_DECODER_END = 5


def mark_gene_boundaries(labels):
    """Return a copy of ``labels`` with gene start / end positions relabelled.

    ``labels`` is a list of label ids framed by a leading query-start token and a
    trailing end token. Interior entries are LABEL_GENE or LABEL_NOT_GENE.

    * A gene position directly preceded by a not-gene position becomes LABEL_GENE_START.
    * Otherwise, a gene position directly followed by a not-gene position becomes
      LABEL_GENE_END.

    Only gene positions are ever relabelled. The previous inline version tested
    ``prev == gene and next == not-gene`` without checking the current position. For genes
    spanning two positions, or genes touching the window start, it therefore tagged the
    first *non-gene* position after the gene as "end of gene".

    The framing tokens are never relabelled, so a gene running into either window edge
    gets no start / end marker on that side. The input list is not modified.
    """
    out = list(labels)
    for n in range(1, len(labels) - 1):
        if labels[n] != LABEL_GENE:
            continue
        if labels[n - 1] == LABEL_NOT_GENE:
            out[n] = LABEL_GENE_START
        elif labels[n + 1] == LABEL_NOT_GENE:
            out[n] = LABEL_GENE_END
    return out


def _gene_window_examples(sfasta_path, gff3_path, skip_uniform):
    """Yield ``(kmers, targets)`` pairs for every k-mer window of an .sfasta / .gff3 pair.

    kmers   -- the CLS row stacked on top of the window's k-mer rows.
    targets -- one-hot labels (depth ``n_classes``), one row per k-mer in the window;
               the query-start and end tokens are sliced off.

    If ``skip_uniform`` is true, windows that are entirely gene or entirely not-gene
    are dropped (a crude class-balancing step).
    """
    fasta = pyracular.Gff3KmerGenerator(k, sfasta_path, window_size, gff3_path, True)
    types = fasta.types()
    plus_idx = types.index("gene_Plus")
    minus_idx = types.index("gene_Minus")

    for window in fasta:
        # Windows flagged "rc" are labelled from the gene_Minus track, others from gene_Plus.
        gi = minus_idx if window["rc"] else plus_idx
        gene_flags = [x[gi] for x in window["classifications"]]

        # TODO: Bad attempt at balancing
        total = np.sum(gene_flags)
        if skip_uniform and (total == 0 or total == window_size):
            continue

        labels = [LABEL_QUERY_START]
        labels += [LABEL_GENE if flag == 1 else LABEL_NOT_GENE for flag in gene_flags]
        labels.append(LABEL_DECODER_END)
        labels = mark_gene_boundaries(labels)

        kmers = np.concatenate([cls, window["kmers"]])
        # With depth n_classes (5), LABEL_DECODER_END encodes as an all-zero row.
        # It is sliced off together with the query-start row.
        targets = tf.one_hot(labels, n_classes)
        yield kmers, targets[1:-1]


def gen():
    """Training examples from Dmel.sfasta / Dmel.gff3, skipping all-gene / no-gene windows."""
    yield from _gene_window_examples("Dmel.sfasta", "Dmel.gff3", skip_uniform=True)


def validation_gen():
    """Validation examples from Vvulg_chr2.sfasta / Vvulg_chr2.gff3; every window is kept."""
    yield from _gene_window_examples("Vvulg_chr2.sfasta", "Vvulg_chr2.gff3", skip_uniform=False)


In [6]:
# Sanity check: shape of the first validation example's k-mer input (CLS row + k-mer rows).
g = validation_gen()
next(g)[0].shape

(33, 105)

In [81]:
# Classification head on top of the pretrained BEAKER encoder: one softmax over the
# label classes for every k-mer position (the CLS position is dropped).
head_dropout = 0.15

bn0 = tfkl.BatchNormalization()
bn1 = tfkl.BatchNormalization()
bn2 = tfkl.BatchNormalization()

drop0 = tfkl.Dropout(head_dropout)
drop1 = tfkl.Dropout(head_dropout)
drop2 = tfkl.Dropout(head_dropout)

mha0 = tfkl.MultiHeadAttention(8, 256, 256, dropout=head_dropout)

d0 = tfkl.Dense(128, activation="relu")
d1 = tfkl.Dense(n_classes, activation="softmax")

batch_input = Input(shape=(window_size+1, k * 5), dtype="float32", name="BatchInput")
contexts = magic(batch_input)
# Run the pretrained encoder in inference mode. Passing True here baked training mode into
# the graph, so encoder dropout also stayed on during evaluate()/predict(). Keras'
# transfer-learning guide recommends inference mode for a pretrained base, even while
# fine-tuning it.
enc_outputs, _, _ = transformer(contexts, False)

x = drop1(d0(bn0(drop0(enc_outputs))))
x = bn1(drop2(mha0(x, x, x)))

# Drop position 0 (the CLS row) so there is exactly one prediction per k-mer.
y = d1(bn2(x[:, 1:]))

model = Model(inputs=[batch_input], outputs=[y])

# Train only the head at first; the encoder is unfrozen in a later cell.
transformer.trainable = False
optimizer = tfk.optimizers.Adam(1e-3)
model.summary()  # summary() prints itself and returns None
model.compile(
    metrics=["mae", "mse", "binary_accuracy", "categorical_accuracy", tfk.metrics.TruePositives(), tfk.metrics.FalsePositives()],
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
    optimizer=optimizer,
)

Model: "model_38"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
BatchInput (InputLayer)         [(None, 33, 105)]    0                                            
__________________________________________________________________________________________________
Magic (Dense)                   (None, 33, 32)       3360        BatchInput[0][0]                 
__________________________________________________________________________________________________
beaker (BEAKER)                 ((None, 33, 256), {' 17848048    Magic[14][0]                     
__________________________________________________________________________________________________
dropout_21 (Dropout)            (None, 33, 256)      0           beaker[14][0]                    
___________________________________________________________________________________________

In [82]:
# Element spec shared by both pipelines:
# (CLS + window_size k-mer rows, one one-hot target row per k-mer).
# output_signature replaces the deprecated positional output_types and gives static shapes.
example_signature = (
    tf.TensorSpec(shape=(window_size + 1, k * 5), dtype=tf.float32),
    tf.TensorSpec(shape=(window_size, n_classes), dtype=tf.float32),
)

# NOTE: .cache(<filename>) writes the generated examples to disk and reads them back on later
# runs instead of calling the generator. Delete the cache files after changing
# gen()/validation_gen(), otherwise the old labels keep being used.
ds = (
    Dataset.from_generator(gen, output_signature=example_signature)
    .cache("dmel_genes_ds")
    .repeat(8192*64)
    .shuffle(256)
    .batch(batch_size)
    .prefetch(8)
)

vds = (
    Dataset.from_generator(validation_gen, output_signature=example_signature)
    .cache("vvulg_chr2_genes_ds")
    .shuffle(256)
    .batch(batch_size)
    .prefetch(8)
)

In [83]:
import tensorflow.keras.backend as K

def categorical_focal_loss(alpha=0.25, gamma=2.):
    """
    Softmax version of focal loss.

    When there is a skew between different categories/labels in your data set, you can try
    to apply this function as a loss:

        FL = sum_{c=1}^{m} -alpha_c * (1 - p_{o,c})^gamma * y_{o,c} * log(p_{o,c})

    where m = number of classes, c = class and o = observation.

    Parameters:
      alpha -- the same as the weighing factor in balanced cross entropy. Either a scalar
               (applied to every class) or an array that broadcasts against the last
               (class) axis of y_pred, i.e. one weight per class. Defaults to 0.25, as
               in the paper.
      gamma -- focusing parameter for the modulating factor (1 - p). Defaults to 2.0, as
               in the paper.

    Returns:
      A Keras loss function ``fn(y_true, y_pred)``. It returns a scalar: the focal loss
      summed over the class axis, then averaged over all remaining axes.

    References:
        Official paper: https://arxiv.org/pdf/1708.02002.pdf
        https://www.tensorflow.org/api_docs/python/tf/keras/backend/categorical_crossentropy

    Usage:
     model.compile(loss=[categorical_focal_loss(alpha=[[.25, .25, .25]], gamma=2)], metrics=["accuracy"], optimizer=adam)
    """

    alpha = np.array(alpha, dtype=np.float32)

    def categorical_focal_loss_fixed(y_true, y_pred):
        """
        :param y_true: A tensor of the same shape as `y_pred`
        :param y_pred: A tensor resulting from a softmax
        :return: Scalar loss: focal loss summed over the class axis, averaged over every
                 other axis.
        """

        # Clip the prediction value to prevent NaN's and Inf's
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)

        # Calculate Cross Entropy
        cross_entropy = -y_true * K.log(y_pred)

        # Calculate Focal Loss
        loss = alpha * K.pow(1 - y_pred, gamma) * cross_entropy

        # Compute mean loss in mini_batch
        return K.mean(K.sum(loss, axis=-1))

    return categorical_focal_loss_fixed

In [84]:
# Per-epoch training log. The file is named .tsv, so write tab-separated values; the
# previous separator=',' produced comma-separated data under a .tsv name.
csvlog = tf.keras.callbacks.CSVLogger(
    "TransformerGenePredictionLSTM.tsv", separator="\t", append=False
)

In [None]:
# Recompile the existing `model` with a small learning rate and train.
# The model is reused rather than rebuilt: rebuilding from the global `y` breaks as soon as
# a later exploratory cell reassigns `y`.
optimizer = tfk.optimizers.Adam(1e-5) # 1e-4
model.summary()  # summary() prints itself and returns None
model.compile(
    metrics=["mae", "mse", "binary_accuracy", "categorical_accuracy", tfk.metrics.TruePositives(), tfk.metrics.FalsePositives()],
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
    #loss=[categorical_focal_loss(alpha=[[.01, .25, .25, .25, .25]], gamma=2)],
    optimizer=optimizer,
)

# ds is a tf.data.Dataset, so fit() would ignore use_multiprocessing and shuffle (shuffling
# happens inside the pipeline). steps_per_epoch bounds an epoch of the heavily repeated ds.
model.fit(
    ds,
    steps_per_epoch=128,
    epochs=64,
    # callbacks=[csvlog],
)

Model: "model_41"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
BatchInput (InputLayer)         [(None, 33, 105)]    0                                            
__________________________________________________________________________________________________
Magic (Dense)                   (None, 33, 32)       3360        BatchInput[0][0]                 
__________________________________________________________________________________________________
beaker (BEAKER)                 ((None, 33, 256), {' 17848048    Magic[14][0]                     
__________________________________________________________________________________________________
dropout_21 (Dropout)            (None, 33, 256)      0           beaker[14][0]                    
___________________________________________________________________________________________

In [None]:
# Recompile with a fresh Adam optimizer (lr 2e-5) and continue training.
optimizer = tfk.optimizers.Adam(2e-5)
model.compile(
    metrics=["mae", "mse", "binary_accuracy", "categorical_accuracy", tfk.metrics.TruePositives(), tfk.metrics.FalsePositives()],
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
    optimizer=optimizer,
)
# ds is a tf.data.Dataset: fit() ignores use_multiprocessing / shuffle for it.
model.fit(
    ds,
    steps_per_epoch=128,
    epochs=128,
)

In [87]:
# Unshuffled validation pipeline for evaluation; replaces the shuffled `vds` defined earlier.
# NOTE(review): uses the same cache file prefix "vvulg_chr2_genes_ds" as that earlier
# pipeline. The cached elements are identical (shuffling happens after .cache()), but an
# interrupted first pass may leave a partial cache / lock file that has to be deleted before
# the cache can be rebuilt -- TODO confirm.
vds = (
    Dataset.from_generator(validation_gen, (tf.float32, tf.float32)) #, tf.float32))
    .cache("vvulg_chr2_genes_ds")
    .batch(batch_size)
    .prefetch(8)
)

# validation_gen keeps every window (its balancing skip is commented out), so this walks
# the whole chromosome.
model.evaluate(vds)

     21/Unknown - 24s 1s/step - loss: 2.4802 - mae: 0.1841 - mse: 0.1440 - binary_accuracy: 0.8256 - categorical_accuracy: 0.5641 - true_positives_28: 23554.0000 - false_positives_28: 18047.0000

KeyboardInterrupt: 

In [88]:
# Predicted label index (argmax over the softmax outputs) for every k-mer position of the
# first example in the first validation batch.
tf.argmax(model.predict(next(iter(vds))[0]), axis=-1)[0]

<tf.Tensor: shape=(32,), dtype=int64, numpy=
array([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1])>

In [27]:
# Fine-tuning: unfreeze the BEAKER encoder. Changes to `trainable` only take effect once the
# model is compiled again, hence the recompile here.
transformer.trainable = True
optimizer = tfk.optimizers.Adam(2e-5)
model.summary()  # summary() prints itself and returns None
model.compile(
    metrics=["mae", "mse", "binary_accuracy", "categorical_accuracy", tfk.metrics.TruePositives(), tfk.metrics.FalsePositives()],
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
    optimizer=optimizer,
)

Model: "model_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
BatchInput (InputLayer)      [(None, 33, 105)]         0         
_________________________________________________________________
Magic (Dense)                (None, 33, 32)            3360      
_________________________________________________________________
beaker (BEAKER)              ((None, 33, 256), {'encod 17848048  
_________________________________________________________________
tf.__operators__.getitem_2 ( (None, 32, 256)           0         
_________________________________________________________________
dense_21 (Dense)             (None, 32, 256)           65792     
_________________________________________________________________
lstm_2 (LSTM)                (None, 32, 256)           525312    
_________________________________________________________________
dense_22 (Dense)             (None, 32, 5)             1285

In [86]:
# Train with the current compile settings. ds is a tf.data.Dataset, so fit() would ignore
# use_multiprocessing and shuffle for it.
model.fit(
    ds,
    steps_per_epoch=128,
    epochs=16,
    # callbacks=[csvlog],
)

Epoch 1/16
  2/128 [..............................] - ETA: 2:23 - loss: 0.2452 - mae: 0.0490 - mse: 0.0250 - binary_accuracy: 0.9665 - categorical_accuracy: 0.9141 - true_positives_28: 3695.0000 - false_positives_28: 286.0000

KeyboardInterrupt: 

In [None]:
# Restore the newest checkpoint found under `checkpoint_path`, if any exists.
# expect_partial() silences warnings about checkpoint values this model does not restore.
checkpoint_path = "bertish_16_Feb_2021_nt_genedecoder"  # /model_{epoch:04d}.ckpt"
latest = tf.train.latest_checkpoint(checkpoint_path)
if latest:
    print("Loading checkpoint", latest, sep="\n")
    model.load_weights(latest).expect_partial()
    print("Checkpoint loaded")


In [None]:
#fasta = pyracular.FastaKmersGenerator(k, "Dmel.sfasta", 512, True, False)

In [None]:
#for _ in range(10):
#    x = next(fasta)
#x = next(fasta)
#np.shape(np.asarray(x['kmers']))
#[x['coords'][0][0], x['coords'][-1][1]]

In [None]:
#x['ids']

In [None]:
# Exploratory: run one window through Magic, BEAKER and a decoder.
# NOTE(review): this cell fails on a fresh kernel. `x` is not assigned by any earlier active
# cell (the cells that assigned it are commented out; later `for x in ...` loops bind it),
# and `create_look_ahead_mask` / `the_decoder` are not defined or imported in this notebook.
# NOTE(review): this rebinds the module-level `y`, which the model-building cell uses as the
# model output tensor.
#kmers = np.concatenate([gene_token, x['kmers']])
kmers = np.asarray(x['kmers'])
ctx = magic(kmers)
enc_outputs, _, _ = transformer(np.asarray([ctx]), False)
# NOTE(review): overwritten below before it is used.
lam = create_look_ahead_mask(window_size+1)

#query_start = np.concatenate([[[1.0,0.0,0.0,0.0,0.0,0.0]]])
#query_start = tf.expand_dims(query_start, 0)
#np.shape(query_start) # (65, 4)

#raw_out = list()
#dec_output = query_start
#dec_output = tf.cast(dec_output, tf.float32)
#for i in range(len(x['kmers'])+1):
lam = create_look_ahead_mask(tf.shape([ctx])[1])
y, dec_out, dec_attn = the_decoder(np.asarray([ctx]), enc_outputs, lam, training=False)
    #dec_output = dec_output.numpy()
    #print(y[0][-1])
z = tf.argmax(y[0])
    #dec_output = tf.stack([dec_output, [[tf.one_hot(z, 4)]]], axis=2)
    #dec_output = tf.stack([dec_output, [[tf.one_hot(z, 4)]]], axis=1)
    
    #dec_output = np.asarray([np.vstack([dec_output[0], y[0][-1]])])
#dec_output[0] = np.append(dec_output, y[0][-1], axis=0)

#vals = dec_output[:, :, 3][0]
#np.shape(vals)
#px.scatter(y=vals)

In [None]:
# Raw decoder output from the previous cell.
y

In [None]:
# Predicted class index along axis 2 of the decoder output.
tf.argmax(y, axis=2)

In [None]:
# NOTE(review): rebinds `y` (the model-output tensor in the model-building cell) to its first
# element; re-running this cell indexes one level deeper each time.
y = y[0]

In [None]:
# One scatter trace per output column 0-4, plotted across positions.
px.scatter(y=[np.asarray(y)[:, 0], np.asarray(y)[:, 1], np.asarray(y)[:, 2], np.asarray(y)[:, 3], np.asarray(y)[:, 4]])

In [None]:
# 0 is query start (for decoder) [1, 0, 0, 0, 0, 0]
# 1 is not a gene [0, 1, 0, 0, 0, 0]
# 2 is gene start [0, 0, 1, 0, 0, 0]
# 3 is gene [0, 0, 0, 1, 0, 0]
# 4 is end of gene [0, 0, 0, 0, 1, 0]
# 5 is end of decoder output [0, 0, 0, 0, 0, 1]

In [None]:
# Exploratory: first greedy step of autoregressive decoding, starting from the 6-wide
# query-start token (label id 0).
# NOTE(review): `x`, `create_look_ahead_mask` and `the_decoder` are not defined by any
# earlier active cell in this notebook.
kmers = np.concatenate([gene_token, x['kmers']])
ctx = magic(kmers)
enc_outputs, _, _ = transformer(np.asarray([ctx]), False)
query_start = np.concatenate([[[1.0,0.0,0.0,0.0,0.0,0.0]]])
query_start = tf.expand_dims(query_start, 0)
dec_output = tf.cast(query_start, tf.float32)
lam = create_look_ahead_mask(tf.shape(dec_output)[1])
# Decode from the prefix built so far. The mask above is sized for dec_output; the old code
# passed a stale `y` from an earlier cell here.
y, dec_out, dec_attn = the_decoder(dec_output, enc_outputs, lam, training=False)
z = tf.argmax(y[0][-1])
# One-hot at dec_output's own width (6 here) so the concat along the time axis lines up;
# n_classes (5) did not match the 6-wide query-start token.
dec_output = tf.concat([dec_output, [[tf.one_hot(z, dec_output.shape[-1])]]], axis=1)

In [None]:
# Next greedy decoding step: re-run the decoder on the prefix decoded so far and append the
# argmax prediction for its last position.
lam = create_look_ahead_mask(tf.shape(dec_output)[1])
y, dec_out, dec_attn = the_decoder(dec_output, enc_outputs, lam, training=False)
z = tf.argmax(y[0][-1])
# One-hot at dec_output's own width so the concat lines up (the query-start token is 6 wide,
# while n_classes is 5).
dec_output = tf.concat([dec_output, [[tf.one_hot(z, dec_output.shape[-1])]]], axis=1)

In [None]:
# Decoded prefix so far.
dec_output

In [None]:
# NOTE(review): UMAP is already imported at the top (`from umap import UMAP`); this re-import
# only provides the `umap.UMAP` spelling used below.
import umap
reducer = umap.UMAP()
# 2-D UMAP projection of dec_out[-1][0] (presumably the last decoder layer's output for the
# first example -- TODO confirm against the_decoder's return value).
reduced = reducer.fit_transform(dec_out[-1][0].numpy())

In [None]:
px.scatter(x=reduced[:, 0], y=reduced[:, 1])

In [None]:
# NOTE(review): `raw_out` is only assigned in commented-out code, so this raises NameError on
# a fresh kernel.
px.scatter(y=[np.asarray(raw_out)[:, 0], np.asarray(raw_out)[:, 1], np.asarray(raw_out)[:, 2], np.asarray(raw_out)[:, 3], np.asarray(raw_out)[:, 4], np.asarray(raw_out)[:, 5]])

In [None]:
# Index of the (127449, 127469) entry in x['coords'].
x['coords'].index((127449, 127469))

In [None]:
# 0 is query start (for decoder) [1, 0, 0, 0, 0, 0]
# 1 is not a gene [0, 1, 0, 0, 0, 0]
# 2 is gene start [0, 0, 1, 0, 0, 0]
# 3 is gene [0, 0, 0, 1, 0, 0]
# 4 is end of gene [0, 0, 0, 0, 1, 0]
# 5 is end of decoder output [0, 0, 0, 0, 0, 1]

In [None]:
# Exploratory: encoder pass over x['kmers'] (no CLS row prepended), then one decoder call on
# `query_start` from an earlier cell.
# NOTE(review): `create_look_ahead_mask` and `the_decoder` are not defined in this notebook.
# The hard-coded mask size 65 does not obviously match query_start's length or
# window_size + 1 -- TODO confirm what the decoder expects.
ctx = magic(np.asarray(x['kmers']))
enc_outputs, enc_attn, _ = transformer(np.asarray([ctx]), False)
lam = create_look_ahead_mask(65)
#lam = tf.expand_dims(lam, 0)
y, dec_out, dec_attn = the_decoder(query_start, enc_outputs, lam, training=False)

In [None]:
# Fresh validation generator, consumed by the batch-collection cell below.
mygen = validation_gen()
# mygen = gen()
# types

In [None]:
# Collect one batch (batch_size examples) from `mygen` whose targets contain at least
# MIN_GENE_POSITIONS gene-labelled positions (label ids 2-4: gene start / gene / end of gene).
# The old test, np.sum(answers) < 10, summed one-hot rows. That sum is always
# batch_size * window_size, so the very first batch was accepted regardless of its content.
# If `mygen` runs out first, `testing` / `answers` hold the final, partial batch.
MIN_GENE_POSITIONS = 10
GENE_LABEL_COLUMNS = slice(2, 5)

testing = list()
answers = list()
for x in mygen:
    testing.append(x[0])
    answers.append(x[1])
    if len(testing) < batch_size:
        continue
    gene_positions = int(np.sum(np.asarray(answers)[..., GENE_LABEL_COLUMNS]))
    print(gene_positions)
    if gene_positions >= MIN_GENE_POSITIONS:
        break
    testing = list()
    answers = list()

In [None]:
# Stack the collected examples into one (batch, window_size + 1, k * 5) array. Keras treats
# a plain Python list of arrays as a list of separate model inputs, not as one batch.
predictions = model.predict(np.stack(testing))
# model.predict([tb[0]])

In [None]:
# Binary gene / not-gene confusion matrix over every position of the batch above, keyed
# (actual, predicted), with 1 = gene (label ids 2-4) and 0 = not gene.
# The previous loop assumed an older single-output label layout. With one-hot
# (window_size, n_classes) targets, answers[i][0][z] indexed past the class axis and used
# tensors as dict keys; the loop also overwrote the global `y`.
GENE_LABEL_IDS = [2, 3, 4]

predicted_gene = np.isin(np.argmax(predictions, axis=-1), GENE_LABEL_IDS)
actual_gene = np.isin(np.argmax(np.asarray(answers), axis=-1), GENE_LABEL_IDS)

confusion_matrix = {
    (actual, predicted): int(np.sum((actual_gene == actual) & (predicted_gene == predicted)))
    for actual in (0, 1)
    for predicted in (0, 1)
}

In [None]:
# Display the confusion-matrix dict computed above.
confusion_matrix

In [None]:
# Exploratory: push the first collected example through Magic + BEAKER and an ORF/decoder stack.
# NOTE(review): `orf_layer` ... `orf_layer5` and `decoder_layer` are not defined anywhere in
# this notebook, so this cell fails on a fresh kernel. It also rebinds `contexts` and
# `enc_outputs`, the names used by the model-building cell.
test_input = np.asarray([testing[0]])
contexts = magic(test_input)
enc_outputs, attn, _ = transformer(contexts)

oout = orf_layer5(orf_layer4(orf_layer3(orf_layer2(orf_layer1(orf_layer(test_input))))))

decoded, dattn = decoder_layer(tf.concat([contexts, oout], axis=-1), enc_outputs)

# , attn = decoder_layer(contexts, enc_outputs)
# decoder, attn = decoder_layer(contexts, enc_outputs)

In [None]:
# NOTE(review): positional_encoding is imported from both biobeaker.utils and
# lib.bert_inspired; the later import (lib.bert_inspired) is the binding used here.
pe = positional_encoding(256, 16)

In [None]:
pe[:, :32, :]

In [None]:
# Names of the attention-weight entries that plot_attention_weights can select.
dec_attn.keys()

In [None]:
# NOTE(review): plot_attention_weights / convert_all_kmers are defined in a later cell, and
# x[0][0] depends on whichever cell last bound `x`.
plot_attention_weights(dec_attn, convert_all_kmers(x[0][0]), "decoder_layer1_block1_")

In [None]:
plot_attention_weights(dec_attn, convert_all_kmers(x[0][0]), "decoder_layer1_block2_")

In [None]:
plot_attention_weights(dec_attn, convert_all_kmers(x[0][0]), "decoder_layer6_block1_")

In [None]:
plot_attention_weights(dec_attn, convert_all_kmers(x[0][0]), "decoder_layer6_block2_")

In [None]:
dattn.keys()

In [None]:
plot_attention_weights(dattn, convert_all_kmers(testing[0]), "decoder_layer1_block1_")

In [None]:
plot_attention_weights(dattn, convert_all_kmers(testing[0]), "decoder_layer4_block1_")

In [None]:
plot_attention_weights(dattn, convert_all_kmers(testing[0]), "decoder_layer4_block2_")

In [None]:
attn.keys()

In [None]:
import matplotlib.pyplot as plt


def plot_attention_weights(attention, sentence, layer):
    """Draw one attention heat map per head (at most 8, in a 2 x 4 grid) and show the figure.

    attention -- mapping from layer name to an attention tensor with a leading size-1 axis
                 followed by a head axis (presumably (1, heads, query, key) -- TODO confirm).
    sentence  -- tick labels, one per token position, used on both axes.
    layer     -- key selecting which entry of ``attention`` to plot.

    The last query row of every head is dropped before plotting.
    """
    figure = plt.figure(figsize=(22, 12))
    per_head = tf.squeeze(attention[layer], axis=0)

    n_tokens = len(sentence)
    positions = range(n_tokens)
    labels = [sentence[idx] for idx in positions]
    tick_font = {"fontsize": 10}

    for head_idx in range(min(per_head.shape[0], 8)):
        axis = figure.add_subplot(2, 4, head_idx + 1)
        axis.matshow(per_head[head_idx][:-1, :], cmap="viridis")

        axis.set_xticks(positions)
        axis.set_yticks(positions)
        axis.set_ylim(n_tokens - 1.5, -0.5)

        axis.set_xticklabels(labels, fontdict=tick_font, rotation=90)
        axis.set_yticklabels(labels, fontdict=tick_font)
        axis.set_xlabel("Head {}".format(head_idx + 1))

    plt.tight_layout()
    plt.show()


def convert_all_kmers(kmers, kmer_len=None):
    """Decode each encoded k-mer row in ``kmers`` back into a nucleotide string.

    Each row is split into ``kmer_len`` equal chunks with ``np.array_split`` (one chunk per
    base), and every chunk is decoded with ``convert_letter_to_string``.

    kmers    -- iterable of 1-D arrays, each kmer_len * 5 long.
    kmer_len -- bases per k-mer; defaults to the notebook-level ``k``.

    Returns a list with one string per row.
    """
    if kmer_len is None:
        kmer_len = k
    return [
        "".join(convert_letter_to_string(chunk) for chunk in np.array_split(row, kmer_len))
        for row in kmers
    ]


def convert_letter_to_string(x):
    """Map one 5-wide base encoding to its letter, using column order A, T, N, C, G.

    The first non-zero column decides the letter, so an all-ones chunk (from the CLS row)
    decodes to "A". An all-zero chunk returns "?" instead of raising IndexError, which keeps
    attention-plot tick labels usable.
    """
    hot = np.flatnonzero(x)
    if hot.size == 0:
        return "?"
    return "ATNCG"[hot[0]]

In [None]:
# First training example whose targets contain a gene-labelled position (label ids 2-4).
# The old test, np.sum(x[1]) > 0, summed one-hot rows and therefore matched every example.
generator = gen()
for x in generator:
    if np.sum(np.asarray(x[1])[:, 2:5]) > 0:
        plottable = x
        break
else:
    # Fail loudly instead of leaving a stale `plottable` from an earlier run.
    raise RuntimeError("gen() produced no example with gene-labelled positions")

In [None]:
# for p in plottable:
#    np.dot(p, weights[0][0])

# Apply the Magic kernel to every row of the example by hand. Unlike magic(...), this skips
# the swish activation, so these are pre-activation embeddings.
calculated = np.dot(plottable[0], weights[0][0])

In [None]:
# One row per input row of plottable[0] (CLS + k-mers), embedding_dims columns.
calculated.shape

In [None]:
reducer = UMAP()

In [None]:
# 2-D UMAP embedding of the rows of `calculated`.
reduced = reducer.fit_transform(calculated)

In [None]:
# Colour each k-mer by its label index. Row 0 of `reduced` is the CLS row, which has no
# label, so it is skipped. plottable[1] is a one-hot matrix (one row per k-mer), so it is
# reduced to one label per point first; it could not be used directly as a colour.
px.scatter(
    x=reduced[1:, 0],
    y=reduced[1:, 1],
    color=np.argmax(plottable[1], axis=-1).astype(str),
)

In [None]:
from sklearn.decomposition import PCA

In [None]:
# PCA of the same pre-activation embeddings; this overwrites the UMAP `reduced` from above.
pca = PCA()
reduced = pca.fit_transform(calculated)

In [None]:
reduced.shape

In [None]:
reduced

In [None]:
# NOTE(review): set_floatx("float64") changes the default dtype of every Keras layer created
# afterwards in this kernel, including layers built by re-running any earlier cell.
tf.keras.backend.set_floatx("float64")
input_layer = Input(shape=k * 5, dtype="float64")  # shape=(k*5),
# input_reshape = Reshape((k,5))(input_layer)
# input_flat = Flatten(dtype="float64")(input_layer)
# layer1 = Dense(2048, activation="sigmoid", dtype="float64")
# layer2 = Dense(2048, activation=tf.nn.swish, dtype="float64")
# layer3 = Dense(1024, activation="relu", dtype="float64")
layer1 = Bidirectional(GRU(1024, dtype="float64"))
classifier = Dense(window_size, activation="softmax", dtype="float64")

# NOTE(review): magic(input_layer) yields a 2-D (batch, embedding_dims) tensor, but GRU (and
# hence Bidirectional(GRU)) expects 3-D (batch, time, features) input, so this line is
# expected to fail -- TODO confirm the intended input shape.
output = classifier(layer1(magic(input_layer)))