In [None]:
import os
import sys

In [None]:
# Quiet TensorFlow's C++ logging: level "2" filters out INFO and WARNING messages.
# This is set before tensorflow is imported further down so the setting is picked up.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

In [None]:
# Make project-local packages importable for the cells below (e.g. `common`,
# `Attention`, `tf_utils`). Paths are relative to the notebook's working directory.
# Presumably `common` lives under deep-learning-dna — confirm.
sys.path.append("../../../deep-learning-dna")
sys.path.append("../")

In [None]:
import wandb

In [None]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import display
import math
import string

from Attention import Set_Transformer 
from common.models import dnabert
from common import dna
from lmdbm import Lmdb
from common.data import DnaSequenceGenerator, DnaLabelType, DnaSampleGenerator, find_dbs
import wandb

import tf_utils as tfu

In [None]:
# Select GPU 0 through the project's tf_utils helper. `use_dynamic_memory=True`
# presumably enables on-demand GPU memory growth — confirm in tf_utils.
# NOTE(review): `strategy` is not referenced again in this notebook. The model in
# "Create Model" is built outside any `strategy.scope()`; confirm that is intended.
strategy = tfu.devices.select_gpu(0, use_dynamic_memory=True)

---
# Generate Data (synthetic max-of-set task)

In [None]:
def gen_data(batch_size, length=5, low=1, high=100, rng=None):
    """Generate a batch of random integer sets and their maxima.

    Each row of ``x`` is a set of ``length`` integers drawn uniformly from
    ``[low, high)``. The regression target for a row is its maximum.

    Args:
        batch_size: Number of sets (rows) to generate.
        length: Number of elements per set; must be at least 1.
        low: Inclusive lower bound of the element values.
        high: Exclusive upper bound of the element values.
        rng: Optional ``np.random.Generator`` for reproducible draws. When
            ``None``, the legacy global ``np.random`` state is used, as before.

    Returns:
        Tuple ``(x, y)``. ``x`` has shape ``(batch_size, length)``;
        ``y = x.max(axis=1)`` has shape ``(batch_size,)``.

    Raises:
        ValueError: If ``length`` < 1. An empty set has no maximum.
    """
    if length < 1:
        raise ValueError(f"length must be >= 1 to take a maximum, got {length}")
    if rng is None:
        x = np.random.randint(low, high, (batch_size, length))
    else:
        x = rng.integers(low, high, size=(batch_size, length))
    y = np.max(x, axis=1)
    return x, y

In [None]:
# A tiny batch (3 sets of 5 elements) to eyeball the data format.
x, y = gen_data(batch_size=3, length=5)

In [None]:
# Expect x: (batch_size, length) and y: (batch_size,), i.e. one target per set.
print(x.shape, y.shape)

(3, 5) (3,)


In [None]:
# Raw toy inputs: integers drawn by gen_data, one set per row.
x

array([[47, 66, 31, 73,  6],
       [95, 76, 68, 60, 31],
       [12, 14, 22, 55, 43]])

In [None]:
# Targets: the maximum of each row of x.
# NOTE(review): the saved output below, [3, 0, 3], does not match gen_data. For
# the x shown above, the row maxima are 73, 95 and 55; [3, 0, 3] are their argmax
# positions. The output looks stale from an earlier version — re-run this cell.
y

array([3, 0, 3])

---
# Set Transformer Class

In [None]:
class Set_Transformer_Model(keras.Model):
    """Set Transformer regressor that maps each input set to scalar outputs.

    Pipeline (see ``call``):
        per-element Dense projection to ``embed_dim``
        -> ``stack`` attention blocks
        -> PoolingByMultiHeadAttention
        -> Dense(1)

    Args:
        num_induce: 0 selects ``SetAttentionBlock``. Any other value selects
            ``InducedSetAttentionBlock`` and is forwarded as its ``num_induce``.
        embed_dim: Width of the input projection and of every attention and
            pooling block.
        attention_num_heads: ``num_heads`` for each stacked attention block.
        stack: Number of stacked attention blocks.
        use_layernorm, pre_layernorm, use_keras_mha: Forwarded unchanged to
            every attention block and to the pooling layer.
        num_seeds: ``num_seeds`` for the pooling layer.
        pooling_num_heads: ``num_heads`` for the pooling layer.
    """

    def __init__(self, num_induce, embed_dim, attention_num_heads, stack, use_layernorm, pre_layernorm, use_keras_mha, num_seeds, pooling_num_heads):
        super().__init__()

        self.num_induce = num_induce
        self.embed_dim = embed_dim
        self.attention_num_heads = attention_num_heads
        self.stack = stack
        self.use_layernorm = use_layernorm
        self.pre_layernorm = pre_layernorm
        self.use_keras_mha = use_keras_mha
        self.num_seeds = num_seeds
        self.pooling_num_heads = pooling_num_heads

        # Layers are created in the same order as before (projection, blocks,
        # pooling, head) so Keras' auto-generated layer names stay the same.
        self.linear_layer = keras.layers.Dense(self.embed_dim)

        # Keyword arguments shared by both attention-block variants.
        block_kwargs = dict(
            embed_dim=self.embed_dim,
            num_heads=self.attention_num_heads,
            use_layernorm=self.use_layernorm,
            pre_layernorm=self.pre_layernorm,
            use_keras_mha=self.use_keras_mha,
        )
        if self.num_induce == 0:
            block_cls = Set_Transformer.SetAttentionBlock
        else:
            block_cls = Set_Transformer.InducedSetAttentionBlock
            block_kwargs["num_induce"] = self.num_induce
        # Each block is a fresh instance; weights are not shared across the stack.
        self.attention_blocks = [block_cls(**block_kwargs) for _ in range(self.stack)]

        self.pooling_layer = Set_Transformer.PoolingByMultiHeadAttention(
            num_seeds=self.num_seeds,
            embed_dim=self.embed_dim,
            num_heads=self.pooling_num_heads,
            use_layernorm=self.use_layernorm,
            pre_layernorm=self.pre_layernorm,
            use_keras_mha=self.use_keras_mha,
            is_final_block=True,
        )

        self.dense_layer = keras.layers.Dense(1)

    def call(self, data):
        """Run the model on a batch of sets.

        Args:
            data: Batch of sets of scalar elements. Assumed to have shape
                ``(batch, set_size)``, as produced by ``gen_data``; integer
                inputs are cast to float32.

        Returns:
            The Dense(1) output reshaped to its first two dimensions. With
            num_seeds=1 this gave ``(batch, 1)`` in the saved predict output.
        """
        # Make the int -> float boundary explicit, then treat each scalar as a
        # one-feature element: (batch, set_size) -> (batch, set_size, 1).
        data = tf.expand_dims(tf.cast(data, tf.float32), axis=2)

        attention = self.linear_layer(data)
        for attention_block in self.attention_blocks:
            # The second list entry is passed as None. Presumably this is an
            # optional mask or key input — confirm against Set_Transformer.
            attention = attention_block([attention, None])

        pooling = self.pooling_layer(attention)
        dense = self.dense_layer(pooling)

        # Keep only the first two dimensions, dropping the trailing size-1 axis
        # from Dense(1).
        return tf.reshape(dense, tf.shape(dense)[:2])

---
# Create Model

In [None]:
# Training set: 2**10 = 1024 sets of 20 elements each.
# NOTE: this rebinds the toy `x`, `y` from the cells above.
x, y = gen_data(batch_size=2 ** 10, length=20)

In [None]:
# Hyperparameters, passed to Set_Transformer_Model in the next cell.
num_induce = 0  # 0 -> SetAttentionBlock; otherwise InducedSetAttentionBlock with this num_induce
embed_dim = 64  # width of the input projection and of every attention/pooling block
attention_num_heads = 8  # num_heads of each stacked attention block
stack = 4  # number of stacked attention blocks
use_layernorm = True  # forwarded to every attention block and the pooling layer
pre_layernorm = True  # forwarded to every attention block and the pooling layer
use_keras_mha = True  # forwarded to every attention block and the pooling layer
num_seeds = 1  # num_seeds of the pooling layer
pooling_num_heads = 1  # num_heads of the pooling layer

In [None]:
# Build the regressor from the hyperparameters above. Keyword arguments make it
# impossible to silently swap two of the nine same-typed settings.
model = Set_Transformer_Model(
    num_induce=num_induce,
    embed_dim=embed_dim,
    attention_num_heads=attention_num_heads,
    stack=stack,
    use_layernorm=use_layernorm,
    pre_layernorm=pre_layernorm,
    use_keras_mha=use_keras_mha,
    num_seeds=num_seeds,
    pooling_num_heads=pooling_num_heads,
)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.MeanAbsoluteError(),
)

In [None]:
# Preview only the rows the next cell predicts on. Displaying the full 1024x20
# training array would flood the output.
x[:10]

array([[40, 28, 69, 80, 75, 16, 94, 77, 10,  4, 65, 15, 52,  4, 63, 63,
        53, 67, 21, 61],
       [32, 60,  6, 31, 24, 81, 30, 48, 62, 33, 34, 87,  3, 33,  2, 96,
        85, 13, 47, 57],
       [35, 39, 86, 98, 41, 43, 43, 50, 64, 83, 30, 15, 11, 74, 72, 67,
        37,  3, 41, 96],
       [60, 43, 41, 31, 81, 85, 79, 81, 84,  4, 41, 81, 69, 46, 97, 76,
        26,  4, 65, 46],
       [93, 49, 11, 26,  2, 91, 67, 59, 85, 23, 27, 52,  9, 63, 73, 40,
        12, 11, 52, 44],
       [81, 15, 19, 90, 36,  1, 49, 42, 21, 90, 79, 67, 93, 67, 58, 97,
         5, 78, 31, 45],
       [34, 92,  8,  3, 15, 75, 55, 73, 47, 47, 75, 77, 19,  8, 82, 81,
        29, 63, 23, 13],
       [31, 79, 33, 17, 59, 74, 67, 60, 74, 53, 53,  5, 48, 15, 94, 57,
        96, 98, 30, 19],
       [69, 85, 96, 53, 45, 13, 61, 22,  9, 27, 18, 65, 71, 92,  8, 94,
        64, 77, 78, 83],
       [70, 31, 49, 33, 84, 89, 64, 32, 28, 42, 11, 15, 36, 45, 14, 16,
        19, 74, 29, 83]])

In [None]:
# Predictions on the first 10 training rows before `fit` is called (baseline).
model.predict(x[:10])

array([[1.157297 ],
       [1.1036431],
       [1.2188632],
       [1.2527688],
       [1.1129992],
       [1.1849427],
       [1.1257389],
       [1.2237768],
       [1.21368  ],
       [1.1609969]], dtype=float32)

In [None]:
# Maximum number of training epochs passed to model.fit below.
epochs = 500

In [None]:
# `epochs` is only an upper bound. Training stops once the training loss stops
# improving, so the cell finishes on its own under Restart & Run All. The saved
# run ended in a manual KeyboardInterrupt. If training stops early, the best
# weights seen are restored.
early_stopping = keras.callbacks.EarlyStopping(
    monitor="loss", patience=25, restore_best_weights=True
)
history = model.fit(x, y, epochs=epochs, verbose=1, callbacks=[early_stopping])

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500
Epoch 56/500
Epoch 57/500
Epoch 58/500
Epoch 59/500
Epoch 60/500
Epoch 61/500
Epoch 62/500
Epoch 63/500
Epoch 64/500
Epoch 65/500
Epoch 66/500
Epoch 67/500
Epoch 68/500
Epoch 69/500
Epoch 70/500
Epoch 71/500
Epoch 72/500
Epoch 73/500
Epoch 74/500
Epoch 75/500
Epoch 76/500
Epoch 77/500
Epoch 78

KeyboardInterrupt: 

In [None]:
# Inspect the first few training inputs after (possibly interrupted) training.
x[:3]

In [None]:
# Predictions after training on the same 10 rows as the baseline cell.
# Compare against the targets y[:10].
model.predict(x[:10])