---
# Transformer XL
---

---
# Import Libraries

In [1]:
import os
import sys

In [2]:
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

In [3]:
sys.path.append("../../../deep-learning-dna")
sys.path.append("../")
sys.path.append("../../../deep-learning-dna/common")

In [4]:
import wandb

In [5]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display
import math
import string

from Attention import Set_Transformer
from common.models import dnabert
from common import dna
from lmdbm import Lmdb
from common.data import DnaSequenceGenerator, DnaLabelType, DnaSampleGenerator, find_dbs
import wandb

from core.custom_objects import CustomObject

import tf_utils as tfu

In [6]:
strategy = tfu.devices.select_gpu(0, use_dynamic_memory=True)

---
# Load Data

In [7]:
#Import pretrained model
api = wandb.Api()
model_path = api.artifact("sirdavidludwig/dnabert-pretrain/dnabert-pretrain-64dim:latest").download()
pretrained_model = dnabert.DnaBertModel.load(model_path)
pretrained_model.load_weights(model_path + "/model.h5")
pretrained_model

<common.models.dnabert.DnaBertPretrainModel at 0x7f84c384f700>

In [8]:
#Load datafiles
dataset_path = api.artifact("sirdavidludwig/nachusa-dna/nachusa-dna:latest").download('/data/dna_samples:v1')
samples = find_dbs(dataset_path + '/train')
samples[13]

[34m[1mwandb[0m: Downloading large artifact nachusa-dna:latest, 4079.09MB. 420 files... Done. 0:0:0.1


'/data/dna_samples:v1/train/WS-CCW-Jul2015_S82_L001_R1_001.db'

---
# Create Dataset

In [9]:
split_ratios = [0.8, 0.2]
set_len = 1000
sequence_len = 150
kmer = 3
batch_size = [20,5]
batches_per_epoch = 20
augument = True
labels = DnaLabelType.SampleIds
seed = 0
rng = np.random.default_rng(seed)
random_samples = samples.copy()

In [10]:
rng.shuffle(random_samples)

In [11]:
trimmed_samples, (train_dataset, val_dataset) = DnaSampleGenerator.split(samples=random_samples[0:5], split_ratios=split_ratios, subsample_length=set_len, sequence_length=sequence_len, kmer=kmer,batch_size=batch_size,batches_per_epoch=batches_per_epoch,augment=augument,labels=labels, rng=rng)

In [12]:
random_samples[0:5]

['/data/dna_samples:v1/train/WS-CCE-Apr2016_S6_L001_R1_001.db',
 '/data/dna_samples:v1/train/Wes52-10-TC_S53_L001_R1_001.db',
 '/data/dna_samples:v1/train/WS-WH-Jul2016_S46_L001_R1_001.db',
 '/data/dna_samples:v1/train/Wes41-10-HN_S42_L001_R1_001.db',
 '/data/dna_samples:v1/train/Wesley026-Ag-072820_S165_L001_R1_001.db']

--- 
# Batch Parameters

In [13]:
seg_size = 200
max_set_len = set_len
max_files = len(train_dataset.samples)
max_files

5

In [14]:
if seg_size-2 > set_len:
    raise ValueError("Segment size should not be bigger than sequence length")

In [15]:
print(max_set_len)

1000


---
# Create Embeddings

In [16]:
#Create 8 dimensional embeddings
pretrained_encoder = dnabert.DnaBertEncoderModel(pretrained_model.base)
pretrained_encoder.trainable = False

In [17]:
from common.utils import subbatch_predict

In [18]:
class Create_Embeddings():
    def __init__(self, encoder):
        super(Create_Embeddings, self).__init__()
        self.encoder = encoder
        
    def subbatch_predict(self, model, batch, subbatch_size, concat=lambda old, new: tf.concat((old, new), axis=0)):
        batch_size = tf.shape(batch)[0]
        
        result = tf.zeros((batch_size, 64))
        
        for i in range(0, tf.shape(batch)[0], subbatch_size):
            subbatch = batch[i:i+subbatch_size]
            clamp = tf.minimum(subbatch_size, batch_size-i)
            encoded = self.encoder(subbatch)
            result = tf.tensor_scatter_nd_update(result, tf.expand_dims(tf.range(i, i+clamp), 1), encoded)
        return result
    
    def modify_data_for_input(self, data):
        batch_size = tf.shape(data)[0]
        subsample_size = tf.shape(data)[1]
        flat_data = tf.reshape(data, (batch_size*subsample_size, -1))
        encoded = self.subbatch_predict(self.encoder, flat_data, 128)
        result = tf.reshape(encoded, (batch_size, subsample_size, -1))
        return result
    
    def __call__(self, data):
        embeddings = self.modify_data_for_input(data)
        return embeddings

---
# Cache Memory

In [19]:
def Cache_Memory(current_state, previous_state, memory_length):
    if memory_length is None or memory_length == 0:
        return None
    else:
        if previous_state is None:
            new_mem = current_state[:, -memory_length:, :]
        else:
            new_mem = tf.concat(
                    [previous_state, current_state], 1)[:, -memory_length:, :]

    return tf.stop_gradient(new_mem)

---
# Attention

In [20]:
class Attention(keras.layers.Layer):
    def __init__(self, num_induce, embed_dim, num_heads, use_layernorm, pre_layernorm, use_keras_mha):
        super(Attention, self).__init__()
        
        self.num_induce = num_induce
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.use_layernorm = use_layernorm
        self.pre_layernorm = pre_layernorm
        self.use_keras_mha = use_keras_mha
        
        if self.num_induce == 0:       
            self.attention = (Set_Transformer.SetAttentionBlock(embed_dim=self.embed_dim, num_heads=self.num_heads, use_layernorm=self.use_layernorm,pre_layernorm=self.pre_layernorm,use_keras_mha=self.use_keras_mha))
        else:
            self.attention = Set_Transformer.InducedSetAttentionBlock(embed_dim=self.embed_dim, num_heads=self.num_heads, num_induce=self.num_induce, use_layernorm=self.use_layernorm, pre_layernorm=self.pre_layernorm, use_keras_mha=self.use_keras_mha)
    
    
    def call(self, data, mems):
                
            attention = self.attention([data, mems])
                
            return attention

---
# XL Block

In [21]:
class TransformerXLBlock(tf.keras.layers.Layer):
    def __init__(self,
                 num_induce, 
                 embed_dim,
                 num_heads,
                 use_layernorm,
                 pre_layernorm,
                 use_keras_mha,):

        super(TransformerXLBlock, self).__init__()
        
        self.num_induce = num_induce
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.use_layernorm = use_layernorm
        self.pre_layernorm = pre_layernorm
        self.use_keras_mha = use_keras_mha
        
        self.attention = Attention
        
        self.attention_layer = self.attention(self.num_induce, self.embed_dim, self.num_heads, self.use_layernorm, self.pre_layernorm, self.use_keras_mha)

   
    def call(self,
             content_stream,
             state=None):
        
        attention_output = self.attention_layer(content_stream, state)

        return attention_output

---
# Transformer XL

In [22]:
class TransformerXL(keras.layers.Layer):
    def __init__(self,
                 mem_switched, 
                 num_layers,
                 num_induce,
                 embed_dim,
                 num_heads,
                 dropout_rate,
                 mem_len=None,
                 use_layernorm=True,
                 pre_layernorm=True, 
                 use_keras_mha=True):
        
        super(TransformerXL, self).__init__()

        self.mem_switched = mem_switched
        self.num_layers = num_layers
        self.num_induce = num_induce
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.mem_len = mem_len
        self.use_layernorm = use_layernorm
        self.pre_layernorm = pre_layernorm
        self.use_keras_mha = use_keras_mha

        self.transformer_xl_layers = []
        
        for i in range(self.num_layers):
            self.transformer_xl_layers.append(
                    TransformerXLBlock(self.num_induce,
                                        self.embed_dim,
                                        self.num_heads,
                                        self.use_layernorm,
                                        self.pre_layernorm, 
                                        self.use_keras_mha))

        self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)

    def call(self,
             content_stream,
             state=None):
        
        new_mems = []

        if state is None:
            state = [None] * self.num_layers
            
        for i in range(self.num_layers):
            if self.mem_switched == False:
                new_mems.append(Cache_Memory(content_stream, state[i], self.mem_len))
            
            transformer_xl_layer = self.transformer_xl_layers[i]
            
            transformer_xl_output = transformer_xl_layer(content_stream=content_stream,
                                                        state=state[i])
            
            content_stream = self.output_dropout(transformer_xl_output)
            
            if self.mem_switched == True:
                new_mems.append(Cache_Memory(content_stream, state[i], self.mem_len))
                
        output_stream = content_stream
        return output_stream, new_mems

---
# Xl Model Class

In [23]:
class XlModel(keras.Model):
    def __init__(self, mem_switched, max_files, seg_size, max_set_len, num_induce, embed_dim, num_layers, num_heads, mem_len, dropout_rate, num_seeds, use_layernorm, pre_layernorm, use_keras_mha):
        super(XlModel, self).__init__()
        
        self.mem_switched = mem_switched
        self.max_files = max_files
        self.seg_size = seg_size
        self.max_set_len = max_set_len
        self.num_induce = num_induce
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mem_len = mem_len
        self.dropout_rate = dropout_rate
        self.num_seeds = num_seeds
        self.use_layernorm = use_layernorm
        self.pre_layernorm = pre_layernorm
        self.use_keras_mha = use_keras_mha

        self.linear_layer = keras.layers.Dense(self.embed_dim)
        
        self.transformer_xl = TransformerXL(self.mem_switched,
                                            self.num_layers,
                                             self.num_induce,
                                             self.embed_dim,
                                             self.num_heads,
                                             self.dropout_rate,
                                             self.mem_len,
                                             self.use_layernorm,
                                             self.pre_layernorm,
                                             self.use_keras_mha)
        

        self.pooling_layer = Set_Transformer.PoolingByMultiHeadAttention(num_seeds=self.num_seeds,embed_dim=self.embed_dim,num_heads=self.num_heads,use_layernorm=self.use_layernorm,pre_layernorm=self.pre_layernorm, use_keras_mha=self.use_keras_mha, is_final_block=True)
    
        self.reshape_layer = keras.layers.Reshape((self.embed_dim,))
   
        self.output_layer = keras.layers.Dense(self.max_files, activation=keras.activations.softmax)
        
    
    def call(self, embeddings, mems, index, training=None):        
        
        linear_transform = self.linear_layer(embeddings)
        
        segment = linear_transform[:, index:index+self.seg_size]
        
        output, mems = self.transformer_xl(content_stream=segment, state=mems)
                
        pooling = self.pooling_layer(output)

        reshape = self.reshape_layer(pooling)

        output = self.output_layer(reshape)          
        
        return output, mems

---
# Xl Parameters

In [24]:
#Xl Parameters#
mem_switched=False
num_induce = 0
embed_dim = 64
num_layers = 8
num_heads = 8
mem_len = 200
dropout_rate = 0.01
num_seeds = 1
use_layernorm = True
pre_layernorm = True
use_keras_mha = True

---
# Create Models

In [25]:
model = XlModel(mem_switched, max_files, seg_size, max_set_len, num_induce, embed_dim, num_layers, num_heads, mem_len, dropout_rate, num_seeds, use_layernorm, pre_layernorm, use_keras_mha)
model.compile(loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False), optimizer = keras.optimizers.Adam(1e-3), metrics = keras.metrics.SparseCategoricalAccuracy())

In [26]:
epochs = 50

In [27]:
loss_function = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
accuracy_function = keras.metrics.SparseCategoricalAccuracy()

In [28]:
subbatch_size = 1

In [29]:
embedder = Create_Embeddings(pretrained_encoder)

In [30]:
e = embedder(train_dataset[0][0])
mems = tf.zeros((num_layers, tf.shape(e)[0], mem_len, embed_dim))
e = model(e, mems, 0)

In [31]:
@tf.function()
def train_step(inputs):
    batch, max_set_len, seg_size = inputs
    
    #Iterate through subbatches
    #Pull out one set at a time
    for i in range (batch_size[0]):
        n = i + subbatch_size
        one_set = (batch[0][i:n], batch[1][i:n]) 
        x, y = one_set
        i += 1
        
        #Initialize mems
        mems = tf.zeros((num_layers, tf.shape(x)[0], mem_len, embed_dim))
        
        #Initialize embeddings
        embeddings = embedder(x)
        
        #Initialize gradients
        accum_grads = [tf.zeros_like(w) for w in model.trainable_weights]
    
        total_loss = 0.0
        total_accuracy = 0.0
    
        #Split set into segments
        for index in range(0, max_set_len, seg_size):
            
            #Pass entire set (for embeddings) and memories into model
            with tf.GradientTape() as tape:
                
                segment_output, mems = model(embeddings, mems, index, True)

                loss = loss_function(y, segment_output)
                
                #Set loss
                total_loss += loss

            #Compute segment level gradients
            grads = tape.gradient(loss, model.trainable_weights)   

            accum_grads = [(gs + ags) for gs, ags in zip(grads, accum_grads)]    

        total_accuracy += accuracy_function(y, segment_output)
        
        #Apply gradients
        model.optimizer.apply_gradients(zip(accum_grads, model.trainable_weights))

    return total_loss, total_accuracy

In [32]:
@tf.function()
def test_step(inputs):
    batch, max_set_len, seg_size = inputs
    
    #Iterate through subbatches
    #Pull out one set at a time
    for i in range (batch_size[1]):
        n = i + subbatch_size
        one_set = (batch[0][i:n], batch[1][i:n]) 
        x, y = one_set
        i += 1

        #Initialize mems
        mems = tf.zeros((num_layers, tf.shape(x)[0], mem_len, embed_dim))

        #Initialize embeddings
        embeddings = embedder(x)

        total_loss = 0.0
        total_accuracy = 0.0

        #Split set into segments
        for index in range(0, max_set_len, seg_size):

            #Pass entire set (for embeddings) and memories into model
            segment_output, mems = model(embeddings, mems, index, True)

            loss = loss_function(y, segment_output)

            #Set loss
            total_loss += loss

        total_accuracy += accuracy_function(y, segment_output)

    return total_loss, total_accuracy

In [33]:
for epoch in range(epochs):    
    loss = 0.0
    accuracy = 0.0
    i = 0
    
    #Iterate through batches
    for batch in train_dataset:
        
        i += 1
        #Pass one batch intro train_step
        loss, accuracy = train_step([batch, max_set_len, seg_size])
        
        print(f"\r{epoch+1}/{epochs} training batch: {i}/{len(train_dataset)} Train Loss: {loss} Train Accuracy = {accuracy}", end="")
    
    loss = 0.0
    accuracy = 0.0
    i = 0

    #Iterate through batches
    for batch in val_dataset:
        
        i += 1
        #Pass one batch intro train_step
        loss, accuracy = test_step([batch, max_set_len, seg_size])
        
        print(f"\r{epoch+1}/{epochs} testing batch: {i}/{len(val_dataset)} Val Loss: {loss} Val Accuracy = {accuracy}", end="")

31/50 testing batch: 10/20 Val Loss: 0.43860530853271484 Val Accuracy = 0.5901618003845215956328

KeyboardInterrupt: 

In [34]:
def model_predict(inputs):
    batch, max_set_len, seg_size = inputs
    for i in range (batch_size[0]):
        n = i + subbatch_size
        one_set = (batch[0][i:n], batch[1][i:n]) 
        x, y = one_set
        i += 1

        #Initialize mems
        mems = tf.zeros((num_layers, tf.shape(x)[0], mem_len, embed_dim))

        #Initialize embeddings
        embeddings = embedder(x)

        total_loss = 0.0
        total_accuracy = 0.0

        #Split set into segments
        for index in range(0, max_set_len, seg_size):

            #Pass entire set (for embeddings) and memories into model
            segment_output, mems = model(embeddings, mems, index, True)
            print("Index:", index, "Predicted:", tf.argmax(segment_output, -1), "Correct:", y)
            
            loss = loss_function(y, segment_output)

            #Set loss
            total_loss += loss

        total_accuracy += accuracy_function(y, segment_output)

    return total_loss, total_accuracy

In [35]:
def model_predict_val(inputs):
    batch, max_set_len, seg_size = inputs
    for i in range (batch_size[1]):
        n = i + subbatch_size
        one_set = (batch[0][i:n], batch[1][i:n]) 
        x, y = one_set
        i += 1

        #Initialize mems
        mems = tf.zeros((num_layers, tf.shape(x)[0], mem_len, embed_dim))

        #Initialize embeddings
        embeddings = embedder(x)

        total_loss = 0.0
        total_accuracy = 0.0

        #Split set into segments
        for index in range(0, max_set_len, seg_size):

            #Pass entire set (for embeddings) and memories into model
            segment_output, mems = model(embeddings, mems, index, True)
            print("Index:", index, "Predicted:", tf.argmax(segment_output, -1), "Correct:", y)
            
            loss = loss_function(y, segment_output)

            #Set loss
            total_loss += loss

        total_accuracy += accuracy_function(y, segment_output)

    return total_loss, total_accuracy

In [36]:
model_predict([train_dataset[1], max_set_len, seg_size])

Index: 0 Predicted: tf.Tensor([0], shape=(1,), dtype=int64) Correct: [2]
Index: 200 Predicted: tf.Tensor([2], shape=(1,), dtype=int64) Correct: [2]
Index: 400 Predicted: tf.Tensor([2], shape=(1,), dtype=int64) Correct: [2]
Index: 600 Predicted: tf.Tensor([2], shape=(1,), dtype=int64) Correct: [2]
Index: 800 Predicted: tf.Tensor([2], shape=(1,), dtype=int64) Correct: [2]
Index: 0 Predicted: tf.Tensor([1], shape=(1,), dtype=int64) Correct: [4]
Index: 200 Predicted: tf.Tensor([4], shape=(1,), dtype=int64) Correct: [4]
Index: 400 Predicted: tf.Tensor([4], shape=(1,), dtype=int64) Correct: [4]
Index: 600 Predicted: tf.Tensor([4], shape=(1,), dtype=int64) Correct: [4]
Index: 800 Predicted: tf.Tensor([4], shape=(1,), dtype=int64) Correct: [4]
Index: 0 Predicted: tf.Tensor([0], shape=(1,), dtype=int64) Correct: [0]
Index: 200 Predicted: tf.Tensor([2], shape=(1,), dtype=int64) Correct: [0]
Index: 400 Predicted: tf.Tensor([2], shape=(1,), dtype=int64) Correct: [0]
Index: 600 Predicted: tf.Tensor

(<tf.Tensor: shape=(), dtype=float32, numpy=2.512153>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.59030694>)

In [37]:
model_predict_val([val_dataset[2], max_set_len, seg_size])

Index: 0 Predicted: tf.Tensor([4], shape=(1,), dtype=int64) Correct: [4]
Index: 200 Predicted: tf.Tensor([4], shape=(1,), dtype=int64) Correct: [4]
Index: 400 Predicted: tf.Tensor([4], shape=(1,), dtype=int64) Correct: [4]
Index: 600 Predicted: tf.Tensor([4], shape=(1,), dtype=int64) Correct: [4]
Index: 800 Predicted: tf.Tensor([4], shape=(1,), dtype=int64) Correct: [4]
Index: 0 Predicted: tf.Tensor([1], shape=(1,), dtype=int64) Correct: [4]
Index: 200 Predicted: tf.Tensor([4], shape=(1,), dtype=int64) Correct: [4]
Index: 400 Predicted: tf.Tensor([4], shape=(1,), dtype=int64) Correct: [4]
Index: 600 Predicted: tf.Tensor([4], shape=(1,), dtype=int64) Correct: [4]
Index: 800 Predicted: tf.Tensor([4], shape=(1,), dtype=int64) Correct: [4]
Index: 0 Predicted: tf.Tensor([1], shape=(1,), dtype=int64) Correct: [3]
Index: 200 Predicted: tf.Tensor([1], shape=(1,), dtype=int64) Correct: [3]
Index: 400 Predicted: tf.Tensor([1], shape=(1,), dtype=int64) Correct: [3]
Index: 600 Predicted: tf.Tensor

(<tf.Tensor: shape=(), dtype=float32, numpy=0.9745901>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.59037465>)