---
# Import Libraries

In [1]:
import os
import sys

In [2]:
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

In [3]:
sys.path.append("../../../deep-learning-dna")
sys.path.append("../")
sys.path.append("../../../deep-learning-dna/common")

In [4]:
import wandb

In [5]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display
import math
import string

from Attention import Set_Transformer
from common.models import dnabert
from common import dna
from lmdbm import Lmdb
from common.data import DnaSequenceGenerator, DnaLabelType, DnaSampleGenerator, find_dbs
import wandb

import tf_utilities as tfu

In [6]:
strategy = tfu.devices.select_gpu(0, use_dynamic_memory=True)

---
# Load Data

In [7]:
#Import pretrained model
api = wandb.Api()
model_path = api.artifact("sirdavidludwig/dnabert-pretrain/dnabert-pretrain-64dim:latest").download()
pretrained_model = dnabert.DnaBertModel.load(model_path)
pretrained_model.load_weights(model_path + "/model.h5")
pretrained_model

[34m[1mwandb[0m:   5 of 5 files downloaded.  


<common.models.dnabert.DnaBertPretrainModel at 0x7f146bf721d0>

In [8]:
#Load datafiles
dataset_path = api.artifact("sirdavidludwig/nachusa-dna/dnasamples-complete:latest").download()
samples = find_dbs(dataset_path + '/train')
samples[13]

[34m[1mwandb[0m: Downloading large artifact dnasamples-complete:latest, 4079.09MB. 420 files... 
[34m[1mwandb[0m:   420 of 420 files downloaded.  
Done. 0:0:3.4


'./artifacts/dnasamples-complete:v0/train/WS-CCW-Jul2015_S82_L001_R1_001.db'

---
# Create Dataset

In [9]:
split_ratios = [0.8, 0.2]
set_len = 1000
sequence_len = 150
kmer = 3
batch_size = [1,1]
batches_per_epoch = 20
augument = True
labels = DnaLabelType.SampleIds
seed = 0
rng = np.random.default_rng(seed)
random_samples = samples.copy()

In [10]:
rng.shuffle(random_samples)

In [11]:
trimmed_samples, (train_dataset, val_dataset) = DnaSampleGenerator.split(samples=random_samples[0:5], split_ratios=split_ratios, subsample_length=set_len, sequence_length=sequence_len, kmer=kmer,batch_size=batch_size,batches_per_epoch=batches_per_epoch,augment=augument,labels=labels, rng=rng)

--- 
# Batch Parameters

In [12]:
block_size = 250
max_set_len = set_len
max_files = len(train_dataset.samples)
max_files

5

In [13]:
if block_size-2 > set_len:
    raise ValueError("Block size should not be bigger than sequence length")

In [14]:
print(max_set_len)

1000


---
# Create Embeddings

In [15]:
#Create 8 dimensional embeddings
pretrained_encoder = dnabert.DnaBertEncoderModel(pretrained_model.base)
pretrained_encoder.trainable = False

In [133]:
class Create_Embeddings(keras.layers.Layer):
    def __init__(self, encoder):
        super(Create_Embeddings, self).__init__()
        self.encoder = encoder
        
    def subbatch_predict(self, model, batch, subbatch_size, concat=lambda old, new: tf.concat((old, new), axis=0)):
        def predict(i, result=None):
            n = i + subbatch_size
            pred = tf.stop_gradient(model(batch[i:n]))
            if result is None:
                return [n, pred]
            return [n, concat(result, pred)]
        i, result = predict(0)
        batch_size = tf.shape(batch)[0]
        i, result = tf.while_loop(
            cond=lambda i, _: i < batch_size,
            body=predict,
            loop_vars=[i, result],
            parallel_iterations=1)
        return result
    
    def modify_data_for_input(self, data):
        batch_size = tf.shape(data)[0]
        subsample_size = tf.shape(data)[1]
        flat_data = tf.reshape(data, (batch_size*subsample_size, -1))
        encoded = self.subbatch_predict(self.encoder, flat_data, 128)
        result = tf.reshape(encoded, (batch_size, subsample_size, -1))
        return result
    
    def call(self, data):
        return  self.modify_data_for_input(data)

# Cache Memory

In [134]:
def Cache_Memory(current_state, previous_state, memory_length):
    if memory_length is None or memory_length == 0:
        return None, None
    else:
        if previous_state is None:
            new_mem = current_state[:, -memory_length:, :]
            excess = current_state[:, :-memory_length, :]
        else:
            concatanted =  tf.concat([previous_state, current_state], 1)
            new_mem = concatanted[:, -memory_length:, :]
            excess = concatanted[:,:-memory_length,:]
            
    return tf.stop_gradient(new_mem), tf.stop_gradient(excess)

---
# Attention

In [135]:
class Attention(keras.layers.Layer):
    def __init__(self, num_induce, embed_dim, num_heads, use_layernorm, pre_layernorm, use_keras_mha):
        super(Attention, self).__init__()
        
        self.num_induce = num_induce
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.use_layernorm = use_layernorm
        self.pre_layernorm = pre_layernorm
        self.use_keras_mha = use_keras_mha
        
        if self.num_induce == 0:       
            self.attention = (Set_Transformer.SetAttentionBlock(embed_dim=self.embed_dim, num_heads=self.num_heads, use_layernorm=self.use_layernorm,pre_layernorm=self.pre_layernorm,use_keras_mha=self.use_keras_mha))
        else:
            self.attention = Set_Transformer.InducedSetAttentionBlock(embed_dim=self.embed_dim, num_heads=self.num_heads, num_induce=self.num_induce, use_layernorm=self.use_layernorm, pre_layernorm=self.pre_layernorm, use_keras_mha=self.use_keras_mha)

    def call(self, data, mems=None):
            attention = self.attention([data, mems])
                
            return attention

---
# XL Block

In [136]:
class TransformerXLBlock(tf.keras.layers.Layer):
    def __init__(self,
                 num_compressed_seeds,
                 num_induce, 
                 embed_dim,
                 num_heads,
                 use_layernorm,
                 pre_layernorm,
                 use_keras_mha,):

        super(TransformerXLBlock, self).__init__()
        
        self.num_compressed_seeds = num_compressed_seeds
        self.num_induce = num_induce
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.use_layernorm = use_layernorm
        self.pre_layernorm = pre_layernorm
        self.use_keras_mha = use_keras_mha
        
        self.attention = Attention
        
        self.compress = Set_Transformer.CompressedPoolingByMultiHeadAttention(
            num_seeds=self.num_compressed_seeds,
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            use_layernorm=self.use_layernorm,
            pre_layernorm=self.pre_layernorm, 
            use_keras_mha=self.use_keras_mha, 
            is_final_block=True)
        
        self.attention_layer = self.attention(self.num_induce, self.embed_dim, self.num_heads, self.use_layernorm, self.pre_layernorm, self.use_keras_mha)
    
    # def compress(self, x):
    #         return 1 + x[:,self.num_compressed_seeds:,:]
    
    def call(self,
             content_stream,
             state=None,
             compressed=None):
        
        memories = tf.concat((state, compressed), axis=1)
        
        attention_output = self.attention_layer(content_stream, memories)

        return attention_output       

---
# Transformer XL

In [137]:
class TransformerXL(tf.keras.layers.Layer):
    def __init__(self,
                 mem_switched, 
                 num_compressed_seeds,
                 compressed_len,
                 mem_len,
                 num_layers,
                 num_induce,
                 embed_dim,
                 num_heads,
                 dropout_rate,
                 use_layernorm=True,
                 pre_layernorm=True, 
                 use_keras_mha=True):
        
        super(TransformerXL, self).__init__()

        self.mem_switched = mem_switched
        self.num_compressed_seeds = num_compressed_seeds
        self.compressed_len = compressed_len
        self.mem_len = mem_len
        self.num_layers = num_layers
        self.num_induce = num_induce
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.use_layernorm = use_layernorm
        self.pre_layernorm = pre_layernorm
        self.use_keras_mha = use_keras_mha
        
        self.transformer_xl_layers = []
        
        for i in range(self.num_layers):
            self.transformer_xl_layers.append(
                    TransformerXLBlock(self.num_compressed_seeds,
                                        self.num_induce,
                                        self.embed_dim,
                                        self.num_heads,
                                        self.use_layernorm,
                                        self.pre_layernorm, 
                                        self.use_keras_mha))

        self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)

    def call(self,
             content_stream,
             state=None,
             compressed=None):
        
        new_mems = []
        new_compressed = []

        if state is None:
            state = [None] * self.num_layers
            
        if new_compressed is None:
            new_compressed = [None] * self.num_layers
            
        for i, transformer_xl_layer in enumerate(self.transformer_xl_layers):
            if self.mem_switched == False:
                mems_append, mems_excess = Cache_Memory(content_stream, state[i], self.mem_len)
                new_mems.append(mems_append)
                
                #Perform attention between current segment and uncompressed trimmed memory
                #uncompressed_attention = transformer_xl_layer.attention_layer(tf.stop_gradient(content_stream), tf.stop_gradient(mems_excess))
                
                compressed_mems = transformer_xl_layer.compress(mems_excess)
                
                compressed_append, _ = Cache_Memory(compressed_mems, compressed[i], self.compressed_len)
                new_compressed.append(compressed_append)
            
                #Perform attention between current segment and compressed trimmed memory
                #compressed_attention = transformer_xl_layer.attention_layer( tf.stop_gradient(content_stream), tf.stop_gradient(compressed_mems))
                
                #loss = tf.linalg.norm(uncompressed_attention-compressed_attention)
                
            transformer_xl_output = transformer_xl_layer(content_stream=content_stream,
                                                        state=state[i], compressed=compressed[i])
            
            content_stream = self.output_dropout(transformer_xl_output)

        loss = 0
        output_stream = content_stream
        return output_stream, new_mems, new_compressed, loss

---
# Xl Model Class

In [139]:
class XlModel(keras.Model):
    def __init__(self, mem_switched, num_compressed_seeds, compressed_len, mem_len, max_files, encoder, block_size, max_set_len, num_induce, embed_dim, num_layers, num_heads, dropout_rate, num_seeds_output, use_layernorm, pre_layernorm, use_keras_mha):
        super(XlModel, self).__init__()
        
        self.mem_switched = mem_switched
        self.num_compressed_seeds = num_compressed_seeds
        self.compressed_len = compressed_len
        self.mem_len = mem_len
        self.max_files = max_files
        self.encoder = encoder
        self.block_size = block_size
        self.max_set_len = max_set_len
        self.num_induce = num_induce
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.num_seeds_output = num_seeds_output
        self.use_layernorm = use_layernorm
        self.pre_layernorm = pre_layernorm
        self.use_keras_mha = use_keras_mha
        
        self.embedding_layer = Create_Embeddings(self.encoder)

        self.linear_layer = keras.layers.Dense(self.embed_dim)
        
        self.transformer_xl = TransformerXL(self.mem_switched,
                                            self.num_compressed_seeds,
                                            self.compressed_len,
                                            self.mem_len,
                                            self.num_layers,
                                             self.num_induce,
                                             self.embed_dim,
                                             self.num_heads,
                                             self.dropout_rate,
                                             self.use_layernorm,
                                             self.pre_layernorm,
                                             self.use_keras_mha)
        

        self.pooling_layer = Set_Transformer.PoolingByMultiHeadAttention(num_seeds=self.num_seeds_output,embed_dim=self.embed_dim,num_heads=self.num_heads,use_layernorm=self.use_layernorm,pre_layernorm=self.pre_layernorm, use_keras_mha=self.use_keras_mha, is_final_block=True)
    
        self.reshape_layer = keras.layers.Reshape((self.embed_dim,))
   
        self.output_layer = keras.layers.Dense(self.max_files, activation=keras.activations.softmax)
    
    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            y_pred, loss_compressed = self(x, return_loss=True, training=True) 
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

            trainable_vars = self.trainable_variables
            gradients = tape.gradient(loss+loss_compressed, trainable_vars)
            self.optimizer.apply_gradients(zip(gradients, trainable_vars))
            self.compiled_metrics.update_state(y, y_pred)

            return {m.name: m.result() for m in self.metrics}
        
    def test_step(self, data):
        x, y = data

        y_pred, loss_compressed = self(x, return_loss=True, training=False)

        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}

    def call(self, x, return_loss=False, training=None):  
        
        mems = tf.zeros((self.num_layers, tf.shape(x)[0], self.mem_len, self.embed_dim))
        compressed = tf.zeros((self.num_layers, tf.shape(x)[0], self.compressed_len, self.embed_dim))

        embeddings = self.embedding_layer(x)

        linear_transform = self.linear_layer(embeddings)

        losses = 0
        
        for i in range(0, self.max_set_len, self.block_size):
            block = linear_transform[:,i:i+self.block_size]
            
            output, mems, compressed, loss = self.transformer_xl(content_stream=block, state=mems, compressed=compressed)
            losses = losses + loss
            
        pooling = self.pooling_layer(output)

        reshape = self.reshape_layer(pooling)

        output = self.output_layer(reshape)          
        
        if return_loss:
            return output, losses
        
        return output

---
# Xl Parameters

In [140]:
#Xl Parameters
encoder = pretrained_encoder
mem_switched = False
num_seeds_compressed = 200
compressed_len = 250
mem_len = 250
num_induce = 0
embed_dim = 64
num_layers = 4
num_heads = 4
dropout_rate = 0.01
num_seeds_output = 1
use_layernorm = True
pre_layernorm = True
use_keras_mha = True

---
# Create Models

In [144]:
model = XlModel(mem_switched, num_seeds_compressed, compressed_len, mem_len, max_files, encoder, block_size, max_set_len, num_induce, embed_dim, num_layers, num_heads, dropout_rate, num_seeds_output, use_layernorm, pre_layernorm, use_keras_mha)
model.compile(loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False), optimizer = keras.optimizers.Adam(1e-3), metrics = keras.metrics.SparseCategoricalAccuracy())

In [145]:
model(train_dataset[0][0])

<tf.Tensor: shape=(1, 5), dtype=float32, numpy=
array([[0.02401159, 0.067814  , 0.8353462 , 0.0686931 , 0.00413507]],
      dtype=float32)>

In [146]:
epochs=1

In [147]:
history = model.fit(x=train_dataset, validation_data=val_dataset, epochs=epochs, verbose=1)



In [129]:
model.transformer_xl.transformer_xl_layers[0].compress.trainable_weights

[<tf.Variable 'CSeeds:0' shape=(1, 200, 64) dtype=float32, numpy=
 array([[[-0.01044076,  0.01633035,  0.03292162, ..., -0.02150482,
           0.00901913,  0.0025207 ],
         [-0.093173  ,  0.03271655, -0.13451068, ..., -0.07813554,
          -0.09983643,  0.0357213 ],
         [ 0.01803841, -0.00130155,  0.03728583, ..., -0.11020125,
           0.00694254,  0.04405535],
         ...,
         [-0.04518931, -0.04866203,  0.08861157, ..., -0.0316733 ,
           0.00328898, -0.0001425 ],
         [-0.04779753,  0.03958122,  0.01364814, ..., -0.03591426,
           0.04815786,  0.03601424],
         [ 0.00156752, -0.00309683,  0.03725267, ..., -0.10404078,
          -0.06174248, -0.00969997]]], dtype=float32)>,
 <tf.Variable 'xl_model_12/transformer_xl_12/compressed_pooling_by_multi_head_attention_36/multi_head_attention_block_48/multi_head_attention_84/query/kernel:0' shape=(64, 4, 64) dtype=float32, numpy=
 array([[[ 0.02207458,  0.01448111, -0.03452677, ..., -0.0108335 ,
         

In [122]:
model.transformer_xl.transformer_xl_layers[0].compress.trainable_weights

[<tf.Variable 'CSeeds:0' shape=(1, 200, 64) dtype=float32, numpy=
 array([[[-0.01044076,  0.01633035,  0.03292162, ..., -0.02150482,
           0.00901913,  0.0025207 ],
         [-0.093173  ,  0.03271655, -0.13451068, ..., -0.07813554,
          -0.09983643,  0.0357213 ],
         [ 0.01803841, -0.00130155,  0.03728583, ..., -0.11020125,
           0.00694254,  0.04405535],
         ...,
         [-0.04518931, -0.04866203,  0.08861157, ..., -0.0316733 ,
           0.00328898, -0.0001425 ],
         [-0.04779753,  0.03958122,  0.01364814, ..., -0.03591426,
           0.04815786,  0.03601424],
         [ 0.00156752, -0.00309683,  0.03725267, ..., -0.10404078,
          -0.06174248, -0.00969997]]], dtype=float32)>,
 <tf.Variable 'xl_model_12/transformer_xl_12/compressed_pooling_by_multi_head_attention_36/multi_head_attention_block_48/multi_head_attention_84/query/kernel:0' shape=(64, 4, 64) dtype=float32, numpy=
 array([[[ 0.02207458,  0.01448111, -0.03452677, ..., -0.0108335 ,
         

In [115]:
model.save_weights("Test.h5")

ValueError: Unable to create dataset (name already exists)