---
# Import Libraries

In [1]:
import os
import sys

In [2]:
# Quiet TensorFlow's C++ logging (level "2"); this must run before TensorFlow is imported.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="2")

In [3]:
# Make the sibling project directories importable; the paths are relative to
# the notebook's working directory and are appended in this order.
sys.path.extend([
    "../../../deep-learning-dna",
    "../",
    "../../../deep-learning-dna/common",
])

In [4]:
import wandb

In [5]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display
import math
import string

from Attention import Set_Transformer
from common.models import dnabert
from common import dna
from lmdbm import Lmdb
from common.data import DnaSequenceGenerator, DnaLabelType, DnaSampleGenerator, find_dbs
import wandb

import tf_utils as tfu

In [6]:
# Select GPU 0 through the project's tf_utils helper; `use_dynamic_memory=True`
# presumably enables on-demand GPU memory growth -- confirm in tf_utils.
# NOTE(review): `strategy` is not referenced in the cells shown here (the model
# below is built without `strategy.scope()`).
strategy = tfu.devices.select_gpu(0, use_dynamic_memory=True)

---
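# Generate Data

Synthetic task: each sample is a sequence of random integers in [1, 99], and the target is the sequence's maximum.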

In [7]:
def gen_data(batch_size, length=5, rng=None):
    """Generate a synthetic "maximum of a sequence" regression batch.

    Each sample is a row of ``length`` integers drawn uniformly from [1, 99];
    its target is the largest value in that row.

    Args:
        batch_size: Number of samples (rows) to generate.
        length: Number of integers per sample.
        rng: Optional ``numpy.random.Generator`` for reproducible batches.
            When ``None`` (the default) NumPy's global random state is used,
            exactly as before.

    Returns:
        Tuple ``(x, y)``: ``x`` has shape ``(batch_size, length)`` and ``y``
        has shape ``(batch_size,)`` with ``y[i] == x[i].max()``.
    """
    if rng is None:
        x = np.random.randint(1, 100, (batch_size, length))
    else:
        # Like randint, Generator.integers excludes the upper bound by default.
        x = rng.integers(1, 100, size=(batch_size, length))
    y = np.max(x, axis=1)
    return x, y

In [8]:
# A tiny demo batch (3 samples of length 5) to check the output shapes.
x, y = gen_data(batch_size=3, length=5)

In [9]:
# Inputs are (batch, length); targets are (batch,).
print(f"{x.shape} {y.shape}")

(3, 5) (3,)


In [10]:
# Demo inputs: one row of random integers per sample.
x

array([[99, 85, 49, 44,  2],
       [81, 40, 24, 44, 92],
       [13, 74, 34, 36, 65]])

In [11]:
# Demo targets: the maximum of each row of `x`.
y

array([99, 92, 74])

--- 
# Batch Parameters

---
# Compress Memory

In [17]:
class Compress_Memory(keras.layers.Layer):
    """Compress a block, plus any previous memory, into ``num_seeds`` memory vectors.

    The previous memory (when present) and the current block are concatenated
    along axis 1 and pooled by a Set-Transformer ``PoolingByMultiHeadAttention``
    block; the pooled result becomes the memory for the next block.

    Args:
        num_seeds: Number of seed vectors the memory is compressed into.
        embed_dim, num_heads, use_layernorm, pre_layernorm, use_keras_mha:
            Forwarded unchanged to ``PoolingByMultiHeadAttention``.
        stop_gradient: When True (default, the original behavior) the returned
            memory is wrapped in ``tf.stop_gradient``. NOTE: the pooled memory is
            this layer's only output, so stopping its gradient means the pooling
            weights receive no gradient through this output and stay at their
            initial values. Pass False to train the compressor end-to-end
            through the memory path.
    """

    def __init__(self, num_seeds, embed_dim, num_heads, use_layernorm, pre_layernorm, use_keras_mha, stop_gradient=True):
        super(Compress_Memory, self).__init__()

        self.num_seeds = num_seeds
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.use_layernorm = use_layernorm
        self.pre_layernorm = pre_layernorm
        self.use_keras_mha = use_keras_mha
        self.stop_memory_gradient = stop_gradient

        self.compresser = Set_Transformer.PoolingByMultiHeadAttention(
            num_seeds=self.num_seeds,
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            use_layernorm=self.use_layernorm,
            pre_layernorm=self.pre_layernorm,
            use_keras_mha=self.use_keras_mha,
            is_final_block=True,
        )

    def call(self, current_state, previous_state):
        """Return the compressed memory for the next block.

        Args:
            current_state: The current block; placed after ``previous_state``
                along axis 1.
            previous_state: The previous memory, or None when there is none.

        Returns:
            The pooling block's output, gradient-stopped when
            ``stop_memory_gradient`` is set.
        """
        if previous_state is None:
            pooled_input = current_state
        else:
            # Older memory first, then the current block, along the set axis.
            pooled_input = tf.concat([previous_state, current_state], 1)

        new_mem = self.compresser(pooled_input)
        if self.stop_memory_gradient:
            new_mem = tf.stop_gradient(new_mem)
        return new_mem

---
# Attention

In [18]:
class Attention(keras.layers.Layer):
    """Set-Transformer attention over a block, optionally paired with a memory tensor.

    With ``num_induce == 0`` the wrapped block is a plain ``SetAttentionBlock``;
    any other value selects an ``InducedSetAttentionBlock`` with that many
    inducing points. All other options are forwarded to the wrapped block.
    """

    def __init__(self, num_induce, embed_dim, num_heads, use_layernorm, pre_layernorm, use_keras_mha):
        super(Attention, self).__init__()

        self.num_induce = num_induce
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.use_layernorm = use_layernorm
        self.pre_layernorm = pre_layernorm
        self.use_keras_mha = use_keras_mha

        shared_options = dict(
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            use_layernorm=self.use_layernorm,
            pre_layernorm=self.pre_layernorm,
            use_keras_mha=self.use_keras_mha,
        )
        if self.num_induce != 0:
            self.attention = Set_Transformer.InducedSetAttentionBlock(num_induce=self.num_induce, **shared_options)
        else:
            self.attention = Set_Transformer.SetAttentionBlock(**shared_options)

    def call(self, data, mems=None):
        """Apply the wrapped block to the pair ``[data, mems]``.

        How ``mems`` (possibly None) is consumed is up to the Set_Transformer
        block -- confirm there.
        """
        return self.attention([data, mems])

---
# XL Block

In [19]:
class TransformerXLBlock(tf.keras.layers.Layer):
    """One layer of the XL stack: an ``Attention`` layer over the current block and its memory."""

    def __init__(self, num_induce, embed_dim, num_heads, use_layernorm, pre_layernorm, use_keras_mha):
        super(TransformerXLBlock, self).__init__()

        self.num_induce, self.embed_dim, self.num_heads = num_induce, embed_dim, num_heads
        self.use_layernorm, self.pre_layernorm, self.use_keras_mha = use_layernorm, pre_layernorm, use_keras_mha

        # The attention class is kept on the instance, then instantiated once.
        self.attention = Attention
        self.attention_layer = self.attention(
            self.num_induce,
            self.embed_dim,
            self.num_heads,
            self.use_layernorm,
            self.pre_layernorm,
            self.use_keras_mha,
        )

    def call(self, content_stream, state=None):
        """Attend over ``content_stream`` with ``state`` passed as the memory argument."""
        return self.attention_layer(content_stream, state)

---
# Transformer XL

In [20]:
class TransformerXL(tf.keras.layers.Layer):
    """Stack of ``TransformerXLBlock`` layers that carries compressed memory between blocks.

    For each layer ``i`` the current block is attended with ``state[i]`` as
    memory, followed by dropout. One shared ``Compress_Memory`` layer produces
    that layer's next memory from ``state[i]`` and either the layer's input
    (``mem_switched`` False) or its dropout-ed output (``mem_switched`` True).

    Args:
        mem_switched: Where the next memory is taken from (see above); coerced to bool.
        num_seeds: Number of vectors each compressed memory holds.
        num_layers: Number of stacked ``TransformerXLBlock`` layers.
        num_induce: 0 for plain set attention, otherwise the number of inducing points.
        embed_dim, num_heads: Attention width and head count.
        dropout_rate: Dropout applied after every layer.
        use_layernorm, pre_layernorm, use_keras_mha: Forwarded to the sub-layers.
    """

    def __init__(self,
                 mem_switched,
                 num_seeds,
                 num_layers,
                 num_induce,
                 embed_dim,
                 num_heads,
                 dropout_rate,
                 use_layernorm=True,
                 pre_layernorm=True,
                 use_keras_mha=True):

        super(TransformerXL, self).__init__()

        # Coerced to bool so exactly one memory branch in `call` fires. With the
        # previous `== False` / `== True` checks, a value such as None matched
        # neither branch and silently produced no memories at all.
        self.mem_switched = bool(mem_switched)
        self.num_seeds = num_seeds
        self.num_layers = num_layers
        self.num_induce = num_induce
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.use_layernorm = use_layernorm
        self.pre_layernorm = pre_layernorm
        self.use_keras_mha = use_keras_mha

        self.compresser = Compress_Memory

        # A single compression layer shared by all XL layers.
        self.compress_mems = self.compresser(self.num_seeds,
                                             self.embed_dim,
                                             self.num_heads,
                                             self.use_layernorm,
                                             self.pre_layernorm,
                                             self.use_keras_mha)

        self.transformer_xl_layers = [
            TransformerXLBlock(self.num_induce,
                               self.embed_dim,
                               self.num_heads,
                               self.use_layernorm,
                               self.pre_layernorm,
                               self.use_keras_mha)
            for _ in range(self.num_layers)
        ]

        self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)

    def call(self,
             content_stream,
             state=None):
        """Run the block through every layer.

        Args:
            content_stream: The current block.
            state: Per-layer memories indexable by layer (``state[i]``), or
                None for no memory.

        Returns:
            ``(output_stream, new_mems)``: the last layer's dropout-ed output
            and a list with one new memory per layer.
        """
        if state is None:
            state = [None] * self.num_layers

        new_mems = []
        for i, xl_layer in enumerate(self.transformer_xl_layers):
            layer_state = state[i]

            if not self.mem_switched:
                # Memory built from this layer's input.
                new_mems.append(self.compress_mems(content_stream, layer_state))

            layer_output = xl_layer(content_stream=content_stream, state=layer_state)
            content_stream = self.output_dropout(layer_output)

            if self.mem_switched:
                # Memory built from this layer's (dropout-ed) output.
                new_mems.append(self.compress_mems(content_stream, layer_state))

        return content_stream, new_mems

---
# XL Model Class

In [39]:
class XlModel(keras.Model):
    """Set-Transformer-XL regressor that reads its input block by block.

    Each scalar input position is lifted to ``embed_dim`` by a Dense layer.
    The sequence is cut into chunks of ``block_size`` along axis 1 and fed
    through a ``TransformerXL`` stack that carries compressed memory from chunk
    to chunk. The last chunk's output is pooled to ``num_seeds_output``
    vectors, passed through dropout and a ``Dense(1)`` head, and the trailing
    size-1 axis is dropped.

    Args:
        mem_switched: Forwarded to ``TransformerXL`` (memory from layer input
            vs. layer output).
        num_seeds_mems: Number of memory vectors kept per layer.
        max_files, encoder: Stored on the model; not used by it.
        block_size: Chunk length along the sequence axis; must be positive.
        max_set_len: Chunks start at offsets 0, block_size, ... below
            ``min(max_set_len, input length)``.
        num_induce, embed_dim, num_layers, num_heads, dropout_rate,
        use_layernorm, pre_layernorm, use_keras_mha: Forwarded to the
            transformer / pooling layers.
        num_seeds_output: Number of pooled output vectors (output width).
        output_dropout_rate: Dropout before the Dense(1) head. Defaults to 0.5,
            the previously hard-coded value.
    """

    def __init__(self, mem_switched, num_seeds_mems, max_files, encoder, block_size, max_set_len, num_induce, embed_dim, num_layers, num_heads, dropout_rate, num_seeds_output, use_layernorm, pre_layernorm, use_keras_mha, output_dropout_rate=0.5):
        super(XlModel, self).__init__()

        self.mem_switched = mem_switched
        self.num_seeds_mems = num_seeds_mems
        self.max_files = max_files
        self.encoder = encoder
        self.block_size = block_size
        self.max_set_len = max_set_len
        self.num_induce = num_induce
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.num_seeds_output = num_seeds_output
        self.use_layernorm = use_layernorm
        self.pre_layernorm = pre_layernorm
        self.use_keras_mha = use_keras_mha
        self.output_dropout_rate = output_dropout_rate

        self.linear_layer = keras.layers.Dense(self.embed_dim)

        self.transformer_xl = TransformerXL(self.mem_switched,
                                            self.num_seeds_mems,
                                            self.num_layers,
                                            self.num_induce,
                                            self.embed_dim,
                                            self.num_heads,
                                            self.dropout_rate,
                                            self.use_layernorm,
                                            self.pre_layernorm,
                                            self.use_keras_mha)

        self.pooling_layer = Set_Transformer.PoolingByMultiHeadAttention(num_seeds=self.num_seeds_output, embed_dim=self.embed_dim, num_heads=self.num_heads, use_layernorm=self.use_layernorm, pre_layernorm=self.pre_layernorm, use_keras_mha=self.use_keras_mha, is_final_block=True)

        self.dropout_layer = keras.layers.Dropout(self.output_dropout_rate)

        self.dense_layer = keras.layers.Dense(1)

    def call(self, x, training=None):
        """Predict from a batch ``x`` of shape ``(batch, seq_len)``.

        Raises:
            ValueError: if ``block_size`` is not positive or there is nothing to
                read (``max_set_len`` or the input length is 0).
        """
        if self.block_size <= 0:
            raise ValueError(f"block_size must be positive, got {self.block_size}")

        # Read only as far as the input actually goes. Iterating to max_set_len
        # alone produced empty trailing blocks for shorter inputs, and the last
        # (empty) block's output was the one that got pooled. The static length
        # is used because the Python loop below needs a concrete bound.
        seq_len = x.shape[1]
        read_len = self.max_set_len if seq_len is None else min(seq_len, self.max_set_len)
        if read_len <= 0:
            raise ValueError(f"nothing to read: max_set_len={self.max_set_len}, input length={seq_len}")

        # Initial all-zero memory, indexed per layer by TransformerXL.
        mems = tf.zeros((self.num_layers, tf.shape(x)[0], self.num_seeds_mems, self.embed_dim))

        # (batch, seq_len) -> (batch, seq_len, 1) -> (batch, seq_len, embed_dim)
        linear_transform = self.linear_layer(tf.expand_dims(x, axis=2))

        output = None
        for start in range(0, read_len, self.block_size):
            block = linear_transform[:, start:start + self.block_size]
            output, mems = self.transformer_xl(content_stream=block, state=mems)

        pooling = self.pooling_layer(output)
        dropout = self.dropout_layer(pooling)
        dense = self.dense_layer(dropout)

        # Drop the trailing size-1 axis produced by the Dense(1) head.
        return tf.reshape(dense, tf.shape(dense)[:2])

---
# XL Parameters

In [40]:
# XL parameters
# Stored on XlModel but not used by it.
encoder, max_files = 0, 0

# Memory: compress each layer's input (False) or output (True); vectors kept per layer.
mem_switched = False
num_seeds_mem = 200

# Attention stack (num_induce == 0 selects SetAttentionBlock, otherwise InducedSetAttentionBlock).
num_induce = 0
embed_dim, num_layers, num_heads = 64, 4, 4
dropout_rate = 0.01

# Output head.
num_seeds_output = 1

# Layer-norm placement and attention implementation switches.
use_layernorm = pre_layernorm = use_keras_mha = True

---
# Create Model

In [41]:
# Block size equals the training sequence length, so each sample fits in one block.
block_size = length = 20

In [42]:
# Independent training and validation batches from the same distribution.
x, y = gen_data(length=length, batch_size=10)
vx, vy = gen_data(length=length, batch_size=10)

In [43]:
# The model reads up to the full training length. NOTE(review): `set_len` is not used in the cells shown.
max_set_len = set_len = length

In [44]:
model = XlModel(
    mem_switched=mem_switched,
    num_seeds_mems=num_seeds_mem,
    max_files=max_files,
    encoder=encoder,
    block_size=block_size,
    max_set_len=max_set_len,
    num_induce=num_induce,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    num_seeds_output=num_seeds_output,
    use_layernorm=use_layernorm,
    pre_layernorm=pre_layernorm,
    use_keras_mha=use_keras_mha,
)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.MeanAbsoluteError(),
)

In [45]:
# NOTE(review): superseded by the later `epochs = 1500` cell before `model.fit` runs.
epochs=1000

In [46]:
# First five training targets.
y[:5]

array([95, 99, 87, 97, 99])

In [61]:
# Predictions for the first five samples, shape (5, num_seeds_output).
# NOTE(review): the saved output came from a later run (In [61]) and matches the
# length-10 data generated further down, not the training data shown here.
model.predict(x[:5])

array([[90.72172],
       [90.75025],
       [90.72575],
       [90.72146],
       [90.72817]], dtype=float32)

In [48]:
# Number of training epochs used by `model.fit` below.
epochs = 1500

In [49]:
history = model.fit(
    x=x,
    y=y,
    epochs=epochs,
    validation_data=(vx, vy),
    verbose=1,
)

Epoch 1/1500
Epoch 2/1500
Epoch 3/1500
Epoch 4/1500
Epoch 5/1500
Epoch 6/1500
Epoch 7/1500
Epoch 8/1500
Epoch 9/1500
Epoch 10/1500
Epoch 11/1500
Epoch 12/1500
Epoch 13/1500
Epoch 14/1500
Epoch 15/1500
Epoch 16/1500
Epoch 17/1500
Epoch 18/1500
Epoch 19/1500
Epoch 20/1500
Epoch 21/1500
Epoch 22/1500
Epoch 23/1500
Epoch 24/1500
Epoch 25/1500
Epoch 26/1500
Epoch 27/1500
Epoch 28/1500
Epoch 29/1500
Epoch 30/1500
Epoch 31/1500
Epoch 32/1500
Epoch 33/1500
Epoch 34/1500
Epoch 35/1500
Epoch 36/1500
Epoch 37/1500
Epoch 38/1500
Epoch 39/1500
Epoch 40/1500
Epoch 41/1500
Epoch 42/1500
Epoch 43/1500
Epoch 44/1500
Epoch 45/1500
Epoch 46/1500
Epoch 47/1500
Epoch 48/1500
Epoch 49/1500
Epoch 50/1500
Epoch 51/1500
Epoch 52/1500
Epoch 53/1500
Epoch 54/1500
Epoch 55/1500
Epoch 56/1500
Epoch 57/1500
Epoch 58/1500
Epoch 59/1500
Epoch 60/1500
Epoch 61/1500
Epoch 62/1500
Epoch 63/1500
Epoch 64/1500
Epoch 65/1500
Epoch 66/1500
Epoch 67/1500
Epoch 68/1500
Epoch 69/1500
Epoch 70/1500
Epoch 71/1500
Epoch 72/1500
E

In [50]:
# Training set: compare the first targets with the model's predictions

In [51]:
# First five training inputs (one row per sample).
x[:5]

array([[20, 66, 42, 29, 48, 48, 32, 56, 95, 16, 36, 26, 41, 90, 87,  1,
        12, 57, 37, 88],
       [63, 23, 22, 24, 89, 67, 76, 99, 61, 66, 55, 93, 77, 88, 99, 74,
        35, 40, 72, 81],
       [45,  4, 46, 33, 74, 66, 87, 85, 42, 60, 77, 70, 63, 63, 51,  9,
        63, 15, 41, 41],
       [68, 55, 11,  8, 77, 73, 30, 62, 80, 51,  8,  7, 29, 88, 54,  7,
        97, 46, 37, 66],
       [13, 16, 27, 92, 56, 65, 42, 99, 81, 84, 83, 41, 17, 52, 33, 43,
        83, 68, 94, 23]])

In [52]:
# Targets for those five inputs (row maxima).
y[:5]

array([95, 99, 87, 97, 99])

In [56]:
# Model predictions for the same five training samples.
model.predict(x[:5])

array([[90.72583],
       [90.73942],
       [90.72425],
       [90.72194],
       [90.73026]], dtype=float32)

In [57]:
# Generalization check: fresh data with shorter sequences (length 10) than used in training

In [58]:
# NOTE: this rebinds `x` / `y` (the training arrays above are replaced) and uses
# length-10 sequences, shorter than the length-20 data the model was trained on.
x, y = gen_data(length=10, batch_size=30)

In [59]:
# First five targets of the length-10 batch.
y[:5]

array([82, 97, 88, 88, 96])

In [60]:
# Model predictions on the first five length-10 samples.
model.predict(x[:5])

array([[90.72172],
       [90.75025],
       [90.72575],
       [90.72146],
       [90.72817]], dtype=float32)