In [1]:
import os
import sys

In [2]:
# Lower TensorFlow's native (C++) log verbosity. This cell runs before the
# cell that imports tensorflow, so the setting is in place when the library
# is first loaded.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

In [3]:
# Extend the import search path so the project-local imports in the import
# cell below can resolve.
# NOTE(review): both paths are relative to the kernel's working directory —
# the notebook must be launched from its own folder for these to be valid.
sys.path.append("../../../deep-learning-dna")
sys.path.append("../")

In [4]:
import wandb

In [5]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import display
import math
import string

from Attention import BigBird, Set_Transformer
from common.models import dnabert
from common import dna
from lmdbm import Lmdb
from common.data import DnaSequenceGenerator, DnaLabelType, DnaSampleGenerator, find_dbs
import wandb

import tf_utils as tfu

In [6]:
# Pick GPU 0 through the project's tf_utils helper, requesting dynamic memory.
# NOTE(review): `strategy` is not referenced again in the visible cells (no
# `strategy.scope()` around model construction) — confirm that is intended.
strategy = tfu.devices.select_gpu(0, use_dynamic_memory=True)

---
# Load Data

In [7]:
# Download the pretrained DNABERT artifact from Weights & Biases.
api = wandb.Api()
model_path = api.artifact(
    "sirdavidludwig/dnabert-pretrain/dnabert-pretrain-8dim:latest"
).download()

# Rebuild the model from the artifact directory, then restore its weights
# from the HDF5 file stored alongside it.
pretrained_model = dnabert.DnaBertModel.load(model_path)
pretrained_model.load_weights(model_path + "/model.h5")

# Show the loaded object as the cell output.
pretrained_model

<common.models.dnabert.DnaBertPretrainModel at 0x7f8671173160>

In [8]:
# Download the DNA sample dataset artifact into a fixed local directory.
dataset_path = api.artifact(
    "sirdavidludwig/nachusa-dna/dnasamples-complete:latest"
).download('/data/dna_samples:v1')

# Collect the sample databases found under the training split.
samples = find_dbs(dataset_path + '/train')

# Spot-check one entry to confirm the expected path layout.
samples[13]

[34m[1mwandb[0m: Downloading large artifact dnasamples-complete:latest, 4079.09MB. 420 files... Done. 0:0:0.1


'/data/dna_samples:v1/train/WS-CCW-Jul2015_S82_L001_R1_001.db'

---
# Create Dataset

In [9]:
# Dataset-generation settings; all of these are passed to
# DnaSampleGenerator.split in a later cell.
split_ratios = [0.8, 0.2]
subsample_length = 1000
sequence_length = 150
kmer = 3
# NOTE(review): two entries, like split_ratios — presumably one batch size per
# split (train, validation); confirm against DnaSampleGenerator.split.
batch_size = [20,5]
batches_per_epoch = 20
# Spelled "augument" [sic]; the split cell passes it as the `augment=` keyword,
# so any rename has to be applied in both cells together.
augument = True
labels = DnaLabelType.SampleIds
# Seeded NumPy generator: used for the shuffle below and also handed to the
# split call, so both draw from the same random stream.
seed = 0
rng = np.random.default_rng(seed)
# Work on a copy so the in-place shuffle leaves `samples` in its original order.
random_samples = samples.copy()

In [10]:
# Shuffle the copied sample list in place with the seeded generator.
# NOTE(review): this cell is not idempotent — re-running it without first
# re-running the configuration cell shuffles an already-shuffled list with an
# advanced generator state, producing a different order.
rng.shuffle(random_samples)

In [11]:
# Build the train/validation generators from the first ten shuffled samples.
trimmed_samples, (train_dataset, val_dataset) = DnaSampleGenerator.split(
    samples=random_samples[0:10],
    split_ratios=split_ratios,
    subsample_length=subsample_length,
    sequence_length=sequence_length,
    kmer=kmer,
    batch_size=batch_size,
    batches_per_epoch=batches_per_epoch,
    augment=augument,
    labels=labels,
    rng=rng,
)

In [12]:
# Preview the shuffled order. Only the first ten entries were passed to the
# split call above; the rest are shown for reference.
random_samples[0:50]

['/data/dna_samples:v1/train/WS-CCE-Apr2016_S6_L001_R1_001.db',
 '/data/dna_samples:v1/train/Wes52-10-TC_S53_L001_R1_001.db',
 '/data/dna_samples:v1/train/WS-WH-Jul2016_S46_L001_R1_001.db',
 '/data/dna_samples:v1/train/Wes41-10-HN_S42_L001_R1_001.db',
 '/data/dna_samples:v1/train/Wesley026-Ag-072820_S165_L001_R1_001.db',
 '/data/dna_samples:v1/train/WS-MU-Apr2016_S84_L001_R1_001.db',
 '/data/dna_samples:v1/train/Wes5-5-CCE_S6_L001_R1_001.db',
 '/data/dna_samples:v1/train/WS-MR-Apr2016_S13_L001_R1_001.db',
 '/data/dna_samples:v1/train/WS-MU-Sep2015_S43_L001_R1_001.db',
 '/data/dna_samples:v1/train/Wesley012-HN-051120_S151_L001_R1_001.db',
 '/data/dna_samples:v1/train/WS-HPN-Sep2015_S91_L001_R1_001.db',
 '/data/dna_samples:v1/train/WS-TCR-Sep2015_S52_L001_R1_001.db',
 '/data/dna_samples:v1/train/Wes26-8-AG_S27_L001_R1_001.db',
 '/data/dna_samples:v1/train/WS-SB-Jul2016_S22_L001_R1_001.db',
 '/data/dna_samples:v1/train/Wes25-8-MU_S26_L001_R1_001.db',
 '/data/dna_samples:v1/train/WS-SB-Oct

In [13]:
# Number of samples held by the training generator. Later used as the width
# of the model's final Dense layer (one output per sample).
max_files = len(train_dataset.samples)
max_files

10

---
# Create Embeddings

In [14]:
# Wrap the pretrained model's base in an encoder model that produces the
# sequence embeddings (8-dimensional, going by the artifact name), and freeze
# it so its weights are not updated during training.
pretrained_encoder = dnabert.DnaBertEncoderModel(pretrained_model.base)
pretrained_encoder.trainable = False

In [15]:
class Create_Embeddings(keras.layers.Layer):
    """Embed every sequence of a batch with a wrapped encoder model.

    The input is flattened over its first two axes, pushed through the
    encoder in fixed-size sub-batches, and reshaped back so the first two
    axes of the output match those of the input.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def subbatch_predict(self, model, batch, subbatch_size, concat=lambda old, new: tf.concat((old, new), axis=0)):
        """Run `model` over `batch` in slices of `subbatch_size` rows.

        Args:
            model: Callable applied to each slice of `batch`.
            batch: Tensor sliced along its first axis.
            subbatch_size: Number of rows per slice.
            concat: Function merging the accumulated result with a new
                slice's output; defaults to concatenation along axis 0.

        Returns:
            The merged outputs of all slices. Gradients are blocked on every
            slice output via `tf.stop_gradient`.
        """
        def step(start, accumulated=None):
            stop = start + subbatch_size
            chunk = tf.stop_gradient(model(batch[start:stop]))
            if accumulated is None:
                # First call: nothing to merge with yet.
                return [stop, chunk]
            return [stop, concat(accumulated, chunk)]

        # The first slice is evaluated outside the loop to seed the accumulator.
        cursor, outputs = step(0)
        total_rows = tf.shape(batch)[0]
        cursor, outputs = tf.while_loop(
            cond=lambda position, _: position < total_rows,
            body=step,
            loop_vars=[cursor, outputs],
            parallel_iterations=1)

        return outputs

    def modify_data_for_input(self, data):
        """Flatten the two leading axes, encode, and restore those axes."""
        n_outer = tf.shape(data)[0]
        n_inner = tf.shape(data)[1]
        flattened = tf.reshape(data, (n_outer * n_inner, -1))
        # Encode 128 rows at a time.
        embedded = self.subbatch_predict(self.encoder, flattened, 128)
        return tf.reshape(embedded, (n_outer, n_inner, -1))

    def call(self, data):
        """Return the encoder embeddings for `data`."""
        return self.modify_data_for_input(data)

---
# Create Big Bird Masks

In [16]:
class Create_BigBird_Masks(keras.layers.Layer):
    """Produce BigBird attention masks for a batch, treating every position as valid."""

    def __init__(self, attention_block_size):
        super().__init__()
        self.mask_layer = BigBird.BigBirdMasks(block_size=attention_block_size)

    def call(self, one_batch):
        """Return the masks computed by `BigBird.BigBirdMasks` for `one_batch`.

        An all-ones mask covering every axis of `one_batch` except the last
        is passed as the input mask, i.e. nothing is masked out.
        """
        leading_dims = tf.shape(one_batch)[:-1]
        all_valid = tf.ones(leading_dims)
        return self.mask_layer(one_batch, all_valid)

---
# Big Bird Attention

In [17]:
class Big_Bird_Attention(keras.layers.Layer):
    """One BigBird self-attention block followed by a feed-forward sub-layer.

    Both sub-layers use dropout, a residual connection and layer
    normalisation applied after the residual sum.

    NOTE(review): the feed-forward output projection is `inner_size` wide and
    is added to the normalised attention stream, so `inner_size` has to be
    compatible with the width of the incoming features — confirm callers
    always pass matching sizes.
    """

    def __init__(self, dropout, inner_size, num_heads, key_dim, num_rand_blocks,from_block_size,to_block_size,max_rand_mask_length):
        super().__init__()

        # Keep the constructor arguments available on the instance.
        self.dropout = dropout
        self.inner_size = inner_size
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.num_rand_blocks = num_rand_blocks
        self.from_block_size = from_block_size
        self.to_block_size = to_block_size
        self.max_rand_mask_length = max_rand_mask_length

        # Attention sub-layer.
        self.attention_layer = BigBird.BigBirdAttention(
            num_heads=self.num_heads,
            key_dim=self.key_dim,
            num_rand_blocks=self.num_rand_blocks,
            from_block_size=self.from_block_size,
            to_block_size=self.to_block_size,
            max_rand_mask_length=self.max_rand_mask_length,
        )
        self.attention_dropout = tf.keras.layers.Dropout(rate=self.dropout)
        self.attention_layer_norm = tf.keras.layers.LayerNormalization(
            axis=-1, epsilon=1e-12, dtype=tf.float32
        )

        # Feed-forward sub-layer: dense -> ReLU -> dropout -> dense -> dropout.
        self.inner_dense = tf.keras.layers.experimental.EinsumDense(
            "abc,cd->abd",
            output_shape=(None, self.inner_size),
            bias_axes="d",
            kernel_initializer=keras.initializers.RandomNormal(stddev=0.1),
        )
        self.inner_activation_layer = tf.keras.layers.Activation("relu")
        self.inner_dropout_layer = tf.keras.layers.Dropout(rate=self.dropout)
        self.output_dense = tf.keras.layers.experimental.EinsumDense(
            "abc,cd->abd",
            output_shape=(None, inner_size),
            bias_axes="d",
            kernel_initializer=keras.initializers.RandomNormal(stddev=0.1),
        )
        self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout)
        self.output_layer_norm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-12)

    def call(self, content_stream, mask):
        """Apply self-attention and the feed-forward sub-layer to `content_stream`.

        Args:
            content_stream: Input features; used as the first three arguments
                of the attention layer.
            mask: Mask(s) forwarded unchanged to the attention layer.

        Returns:
            The layer-normalised output of the feed-forward sub-layer.
        """
        # Self-attention with residual connection and normalisation.
        attended = self.attention_layer(content_stream, content_stream, content_stream, mask)
        attended = self.attention_dropout(attended)
        normed_attention = self.attention_layer_norm(attended + content_stream)

        # Position-wise feed-forward with residual connection and normalisation.
        hidden = self.inner_dense(normed_attention)
        hidden = self.inner_activation_layer(hidden)
        hidden = self.inner_dropout_layer(hidden)
        projected = self.output_dense(hidden)
        projected = self.output_dropout(projected)

        return self.output_layer_norm(projected + normed_attention)

---
# Set BigBird Transformer Class

In [18]:
class Set_Big_Bird_Model(keras.Model):
    """Set classifier: frozen encoder -> linear projection -> stacked BigBird
    attention blocks -> attention pooling -> dense logits.

    Args:
        embed_dim: Width of the linear projection applied to the encoder
            embeddings; also used by the pooling and reshape layers.
        stack: Number of BigBird attention blocks applied in sequence. Each
            block is an independent layer with its own weights.
        encoder: Model used by `Create_Embeddings` to embed the sequences.
        max_files: Number of output units (one logit per class).
        num_seeds, pooling_num_heads, use_layernorm, pre_layernorm,
            use_keras_mha: Forwarded to
            `Set_Transformer.PoolingByMultiHeadAttention`.
        attention_block_size: Block size used when building BigBird masks.
        dropout, inner_size, attention_num_heads, key_dim, num_rand_blocks,
            from_block_size, to_block_size, max_rand_mask_length: Forwarded to
            every `Big_Bird_Attention` block.

    NOTE(review): the pooled output is reshaped to `(embed_dim,)` per example,
    which presumably only works when `num_seeds == 1` — confirm against
    `PoolingByMultiHeadAttention`'s output shape.
    """

    def __init__(self, embed_dim, stack, encoder, max_files, num_seeds, pooling_num_heads, attention_block_size, dropout, inner_size, attention_num_heads, key_dim, num_rand_blocks, from_block_size, to_block_size, max_rand_mask_length, use_layernorm, pre_layernorm, use_keras_mha):
        super(Set_Big_Bird_Model, self).__init__()

        self.embed_dim = embed_dim
        self.stack = stack
        self.encoder = encoder
        self.max_files = max_files
        self.num_seeds = num_seeds
        self.pooling_num_heads = pooling_num_heads

        self.attention_block_size = attention_block_size
        self.dropout = dropout
        self.inner_size = inner_size
        self.attention_num_heads = attention_num_heads
        self.key_dim = key_dim
        self.num_rand_blocks = num_rand_blocks
        self.from_block_size = from_block_size
        self.to_block_size = to_block_size
        self.max_rand_mask_length = max_rand_mask_length

        self.use_layernorm = use_layernorm
        self.pre_layernorm = pre_layernorm
        self.use_keras_mha = use_keras_mha

        self.embedding_layer = Create_Embeddings(self.encoder)
        self.linear_layer = keras.layers.Dense(self.embed_dim)

        self.mask_layer = Create_BigBird_Masks(attention_block_size)

        # Build one *independent* attention block per stack level. Appending a
        # single layer instance `stack` times would make every level share the
        # same weights, i.e. the same block applied repeatedly.
        blocks = [
            Big_Bird_Attention(
                self.dropout,
                self.inner_size,
                self.attention_num_heads,
                self.key_dim,
                self.num_rand_blocks,
                self.from_block_size,
                self.to_block_size,
                self.max_rand_mask_length,
            )
            for _ in range(self.stack)
        ]
        self.attention_blocks = blocks
        # Backward-compatible alias for code that reads `attention_layer`;
        # it now refers to the first block of the stack (None if stack == 0).
        self.attention_layer = blocks[0] if blocks else None

        self.pooling_layer = Set_Transformer.PoolingByMultiHeadAttention(
            num_seeds=self.num_seeds,
            embed_dim=self.embed_dim,
            num_heads=self.pooling_num_heads,
            use_layernorm=self.use_layernorm,
            pre_layernorm=self.pre_layernorm,
            use_keras_mha=self.use_keras_mha,
            is_final_block=True,
        )

        self.reshape_layer = keras.layers.Reshape((self.embed_dim,))

        # No activation: the model emits raw logits.
        self.output_layer = keras.layers.Dense(self.max_files)

    def call(self, data):
        """Compute class logits for a batch.

        Args:
            data: Batch passed to the embedding layer.

        Returns:
            Tensor of logits with `max_files` entries per example.
        """
        embeddings = self.embedding_layer(data)
        linear_transform = self.linear_layer(embeddings)

        # The masks are computed once and reused by every attention block.
        mask = self.mask_layer(linear_transform)

        attention = linear_transform
        for attention_block in self.attention_blocks:
            attention = attention_block(attention, mask)

        pooling = self.pooling_layer(attention)
        reshape = self.reshape_layer(pooling)
        output = self.output_layer(reshape)

        return output

---
# Create Model

In [50]:
#Hyperparameters
embed_dim = 64
stack = 4
encoder = pretrained_encoder
num_seeds = 1
pooling_num_heads = 1
attention_block_size = 10
dropout = 0.01
inner_size = 64
attention_num_heads = 8
key_dim = 32
num_rand_blocks = 6
from_block_size = 10
to_block_size = 10
max_rand_mask_length = subsample_length
use_layernorm = True
pre_layernorm = True
use_keras_mha = True

In [56]:
# Run configuration logged to Weights & Biases.
# Every entry is taken from the hyperparameter variables that the model is
# actually constructed with, so the logged config cannot drift from the
# trained model (hard-coding literals here previously recorded different
# values than the ones in use).
Parameters = dict(
    embed_dim=embed_dim,
    stack=stack,
    num_seeds=num_seeds,
    pooling_num_heads=pooling_num_heads,
    attention_block_size=attention_block_size,
    dropout=dropout,
    inner_size=inner_size,
    attention_num_heads=attention_num_heads,
    key_dim=key_dim,
    num_rand_blocks=num_rand_blocks,
    from_block_size=from_block_size,
    to_block_size=to_block_size,
    max_rand_mask_length=max_rand_mask_length,
    use_layernorm=use_layernorm,
    pre_layernorm=pre_layernorm,
    use_keras_mha=use_keras_mha,
)

In [57]:
# Start a Weights & Biases run in the "Set_Big_Bird" project and attach the
# Parameters dict as its config. The run is closed by `run.finish()` in the
# last cell.
run = wandb.init(project="Set_Big_Bird", config=Parameters)

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

In [58]:
# Build the model from the hyperparameter variables defined above.
model = Set_Big_Bird_Model(
    embed_dim=embed_dim,
    stack=stack,
    encoder=encoder,
    max_files=max_files,
    num_seeds=num_seeds,
    pooling_num_heads=pooling_num_heads,
    attention_block_size=attention_block_size,
    dropout=dropout,
    inner_size=inner_size,
    attention_num_heads=attention_num_heads,
    key_dim=key_dim,
    num_rand_blocks=num_rand_blocks,
    from_block_size=from_block_size,
    to_block_size=to_block_size,
    max_rand_mask_length=max_rand_mask_length,
    use_layernorm=use_layernorm,
    pre_layernorm=pre_layernorm,
    use_keras_mha=use_keras_mha,
)

# The model outputs raw logits, hence from_logits=True on the loss.
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.sparse_categorical_accuracy],
)

In [59]:
# Upper bound on training epochs passed to model.fit. The recorded run below
# was stopped manually (KeyboardInterrupt) rather than running to this limit.
epochs = 10000

In [60]:
# Train on the generated batches, validating each epoch and streaming
# metrics/weights to Weights & Biases through its Keras callback.
history = model.fit(
    x=train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    verbose=1,
    callbacks=[wandb.keras.WandbCallback(save_weights_only=True)],
)



Epoch 1/10000


KeyboardInterrupt: 

In [None]:
# Close the Weights & Biases run opened by wandb.init above.
run.finish()