In [1]:
import os
import sys

In [2]:
# Raise TensorFlow's native (C++) log threshold; this runs before the
# `import tensorflow` cell further down so the setting is in place at import time.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="2")

In [3]:
# Extend the module search path with the project checkout and the parent
# directory (presumably where the local `common` / `Attention` packages
# imported below live — verify against the repository layout).
sys.path.extend(["../../../deep-learning-dna", "../"])

In [4]:
import wandb

In [5]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import display
import math
import string

from Attention import Set_Transformer 
from common.models import dnabert
from common import dna
from lmdbm import Lmdb
from common.data import DnaSequenceGenerator, DnaLabelType, DnaSampleGenerator, find_dbs
import wandb

import tf_utils as tfu

In [6]:
# Select GPU 0 through the project's `tf_utils` helper, requesting dynamic memory use.
# NOTE(review): `strategy` is not referenced again in the visible cells — if it is a
# tf.distribute strategy, confirm whether the model should be built under its scope.
strategy = tfu.devices.select_gpu(0, use_dynamic_memory=True)

---
# Load Data

In [7]:
# Download the pretrained DNABERT artifact from Weights & Biases, rebuild the
# model from the downloaded directory and restore its weights.
api = wandb.Api()
model_path = api.artifact("sirdavidludwig/dnabert-pretrain/dnabert-pretrain-8dim:latest").download()
pretrained_model = dnabert.DnaBertModel.load(model_path)
# Build the weights path with os.path.join instead of string concatenation so it
# stays well-formed even if the download directory ends with a separator.
pretrained_model.load_weights(os.path.join(model_path, "model.h5"))
# Bare expression: shows the loaded model's repr as the cell output.
pretrained_model

<common.models.dnabert.DnaBertPretrainModel at 0x7f0b1e946f80>

In [8]:
# Download the DNA sample artifact into a fixed local directory and collect the
# sample databases of its `train` split.
# NOTE(review): the download target is a hard-coded absolute path — consider
# making it configurable.
dataset_path = api.artifact("sirdavidludwig/nachusa-dna/dnasamples-complete:latest").download('/data/dna_samples:v1')
# os.path.join replaces manual "/" concatenation when building the split directory.
samples = find_dbs(os.path.join(dataset_path, 'train'))
# Show one entry as a sanity check (raises IndexError if fewer than 14 files are found).
samples[13]

[34m[1mwandb[0m: Downloading large artifact dnasamples-complete:latest, 4079.09MB. 420 files... Done. 0:0:0.1


'/data/dna_samples:v1/train/WS-CCW-Jul2015_S82_L001_R1_001.db'

---
# Create Dataset

In [9]:
# Dataset configuration consumed by the DnaSampleGenerator.split call below.

# Reproducibility: a single seeded generator is used for shuffling and is also
# handed to the split call.
seed = 0
rng = np.random.default_rng(seed)

# Shuffle a copy so the original `samples` list keeps its order.
random_samples = samples.copy()

# Sequence and sub-sampling settings.
sequence_length = 150
subsample_length = 1000
kmer = 3
labels = DnaLabelType.SampleIds
augument = True  # (sic) name kept as-is; it is forwarded as `augment=` in the split call

# Split and batching settings; presumably one entry per split
# (train, validation) — TODO confirm against DnaSampleGenerator.split.
split_ratios = [0.8, 0.2]
batch_size = [20, 5]
batches_per_epoch = 20

In [10]:
# Shuffle the copied sample list in place using the seeded generator.
# This cell is not idempotent: re-running it reshuffles the already-shuffled list
# and advances `rng`, so re-run the configuration cell first to reproduce the
# same order.
rng.shuffle(random_samples)

In [11]:
# Build the training and validation generators from the first five shuffled
# sample files, using the configuration defined above.
trimmed_samples, (train_dataset, val_dataset) = DnaSampleGenerator.split(
    samples=random_samples[:5],
    split_ratios=split_ratios,
    subsample_length=subsample_length,
    sequence_length=sequence_length,
    kmer=kmer,
    batch_size=batch_size,
    batches_per_epoch=batches_per_epoch,
    augment=augument,
    labels=labels,
    rng=rng,
)

In [12]:
# Number of sample files held by the training generator; later passed to the
# model, where it sets the width of the final Dense (output) layer.
max_files = len(train_dataset.samples)
# Bare expression: displays the count as the cell output.
max_files

5

---
# Create Embeddings

In [13]:
# Wrap the pretrained model's base in an encoder model; this is the model later
# used to turn sequences into embeddings (8-dimensional per the original note
# and the "8dim" artifact name — not verifiable from this notebook alone).
pretrained_encoder= dnabert.DnaBertEncoderModel(pretrained_model.base)
# Freeze the encoder so its weights are excluded from training.
pretrained_encoder.trainable = False

In [14]:
class Create_Embeddings(keras.layers.Layer):
    """Encode every sequence of every sample with a fixed encoder.

    The input is flattened so that all sequences form one batch, pushed through
    `encoder` in sub-batches, and reshaped back to one group of embeddings per
    sample. Gradients are stopped at the encoder output.

    Args:
        encoder: Callable model mapping a batch of sequences to embeddings. It
            must expose `base.embed_dim`, which is used as the embedding width
            in the loop's shape invariant.
        subbatch_size: Number of flattened sequences sent through the encoder
            per step. Defaults to 128, the value that used to be hard-coded.
        **kwargs: Forwarded to `keras.layers.Layer` (e.g. `name`).

    Raises:
        ValueError: If `subbatch_size` is smaller than 1 (the sub-batch loop
            would never advance).
    """

    def __init__(self, encoder, subbatch_size=128, **kwargs):
        super(Create_Embeddings, self).__init__(**kwargs)
        if subbatch_size < 1:
            raise ValueError("subbatch_size must be a positive integer")
        self.encoder = encoder
        self.subbatch_size = subbatch_size

    def subbatch_predict(self, model, batch, subbatch_size, concat=lambda old, new: tf.concat((old, new), axis=0)):
        """Run `model` over `batch` in slices of `subbatch_size` rows.

        The first slice is evaluated outside the loop to seed the accumulator;
        the remaining slices are handled by `tf.while_loop`, one iteration at a
        time, and merged with `concat`.

        Args:
            model: Callable applied to each slice `batch[i:i + subbatch_size]`.
            batch: Tensor sliced along its first axis.
            subbatch_size: Number of rows per slice.
            concat: Function merging the accumulated result with a new
                prediction; by default concatenation along axis 0.

        Returns:
            The merged predictions for the whole batch.
        """
        # NOTE(review): the encoder is called without an explicit `training`
        # argument — confirm this is intended for a frozen encoder.
        def predict(i, result=None):
            n = i + subbatch_size
            pred = tf.stop_gradient(model(batch[i:n]))
            if result is None:
                return [n, pred]
            return [n, concat(result, pred)]

        # Seed the loop with the first sub-batch so the accumulator has a value.
        i, result = predict(0)
        batch_size = tf.shape(batch)[0]
        i, result = tf.while_loop(
            cond=lambda i, _: i < batch_size,
            body=predict,
            loop_vars=[i, result],
            # The accumulator grows along axis 0 on every iteration, so that
            # dimension is left unspecified; the width is taken from
            # `self.encoder` (not from the `model` argument).
            shape_invariants=[tf.TensorShape([]), tf.TensorShape([None, self.encoder.base.embed_dim])],
            parallel_iterations=1)

        return result

    def modify_data_for_input(self, data):
        """Flatten `data` to 2-D, encode it in sub-batches and restore the grouping.

        Args:
            data: Tensor whose first two axes are (batch, subsample); the
                remaining elements of each entry are flattened into one axis.

        Returns:
            Tensor of shape (batch, subsample, -1) holding the encoder outputs.
        """
        batch_size = tf.shape(data)[0]
        subsample_size = tf.shape(data)[1]
        flat_data = tf.reshape(data, (batch_size*subsample_size, -1))
        encoded = self.subbatch_predict(self.encoder, flat_data, self.subbatch_size)
        return tf.reshape(encoded, (batch_size, subsample_size, -1))

    def call(self, data):
        """Return the encoder embeddings for `data`, grouped per sample."""
        embeddings = self.modify_data_for_input(data)
        return embeddings

---
# Set Transformer Class

In [15]:
class Set_Transformer_Model(keras.Model):
    """Set Transformer classifier over sets of encoded sequences.

    Pipeline: encoder embeddings -> Dense projection to `embed_dim` -> `stack`
    attention blocks -> pooling by multi-head attention -> flatten -> Dense
    layer with `max_files` outputs. The output Dense layer has no activation,
    so the model returns unnormalised scores.

    Args:
        num_induce: 0 selects `SetAttentionBlock`; any other value selects
            `InducedSetAttentionBlock` and is passed to it as `num_induce`.
        embed_dim: Width of the projection and of the attention blocks.
        attention_num_heads: Number of heads in each attention block.
        stack: Number of attention blocks.
        use_layernorm: Forwarded to the attention and pooling blocks.
        pre_layernorm: Forwarded to the attention and pooling blocks.
        use_keras_mha: Forwarded to the attention and pooling blocks.
        encoder: Encoder model handed to `Create_Embeddings`.
        max_files: Number of units of the output layer.
        num_seeds: Number of seed vectors of the pooling layer.
        pooling_num_heads: Number of heads of the pooling layer.
        **kwargs: Forwarded to `keras.Model` (e.g. `name`).
    """

    def __init__(self, num_induce, embed_dim, attention_num_heads, stack, use_layernorm, pre_layernorm, use_keras_mha, encoder, max_files, num_seeds, pooling_num_heads, **kwargs):
        super(Set_Transformer_Model, self).__init__(**kwargs)

        self.num_induce = num_induce
        self.embed_dim = embed_dim
        self.attention_num_heads = attention_num_heads
        self.stack = stack
        self.use_layernorm = use_layernorm
        self.pre_layernorm = pre_layernorm
        self.use_keras_mha = use_keras_mha
        self.encoder = encoder
        self.max_files = max_files
        self.num_seeds = num_seeds
        self.pooling_num_heads = pooling_num_heads

        self.embedding_layer = Create_Embeddings(self.encoder)
        self.linear_layer = keras.layers.Dense(self.embed_dim)

        # Arguments shared by both attention-block variants; kept in a local
        # variable so the two variants are built by one loop.
        block_kwargs = dict(
            embed_dim=self.embed_dim,
            num_heads=self.attention_num_heads,
            use_layernorm=self.use_layernorm,
            pre_layernorm=self.pre_layernorm,
            use_keras_mha=self.use_keras_mha,
        )
        self.attention_blocks = []
        for _ in range(self.stack):
            if self.num_induce == 0:
                block = Set_Transformer.SetAttentionBlock(**block_kwargs)
            else:
                block = Set_Transformer.InducedSetAttentionBlock(num_induce=self.num_induce, **block_kwargs)
            self.attention_blocks.append(block)

        self.pooling_layer = Set_Transformer.PoolingByMultiHeadAttention(
            num_seeds=self.num_seeds,
            embed_dim=self.embed_dim,
            num_heads=self.pooling_num_heads,
            use_layernorm=self.use_layernorm,
            pre_layernorm=self.pre_layernorm,
            use_keras_mha=self.use_keras_mha,
            is_final_block=True,
        )

        # Flatten the pooled output to one vector per sample. The target used
        # to be (embed_dim,), which only fits num_seeds == 1; for that case the
        # shape below is identical.
        # NOTE(review): assumes the pooling layer returns `num_seeds` vectors
        # of width `embed_dim` per sample — confirm against Set_Transformer.
        self.reshape_layer = keras.layers.Reshape((self.num_seeds * self.embed_dim,))

        self.output_layer = keras.layers.Dense(self.max_files)

    def call(self, data):
        """Compute the output scores for a batch of sequence sets.

        Args:
            data: Batch passed to `Create_Embeddings`; its first two axes are
                treated as (batch, subsample).

        Returns:
            Tensor with `max_files` values per sample.
        """
        embeddings = self.embedding_layer(data)

        linear_transform = self.linear_layer(embeddings)

        attention = linear_transform

        for attention_block in self.attention_blocks:
            # Each block takes a two-element list; the second entry is always
            # None here (presumably an optional mask — TODO confirm).
            attention = attention_block([attention, None])

        pooling = self.pooling_layer(attention)

        reshape = self.reshape_layer(pooling)

        output = self.output_layer(reshape)

        return output

---
# Create Model

In [16]:
# Hyperparameters for Set_Transformer_Model

# Attention stack
embed_dim = 64
stack = 4
attention_num_heads = 8
num_induce = 0  # 0 -> SetAttentionBlock; non-zero -> InducedSetAttentionBlock

# Switches forwarded to every attention / pooling block
use_layernorm = True
pre_layernorm = True
use_keras_mha = True

# Pooling by multi-head attention
num_seeds = 1
pooling_num_heads = 1

# Frozen pretrained encoder created earlier in the notebook
encoder = pretrained_encoder

In [17]:
# Instantiate the classifier from the hyperparameters above.
model = Set_Transformer_Model(
    num_induce=num_induce,
    embed_dim=embed_dim,
    attention_num_heads=attention_num_heads,
    stack=stack,
    use_layernorm=use_layernorm,
    pre_layernorm=pre_layernorm,
    use_keras_mha=use_keras_mha,
    encoder=encoder,
    max_files=max_files,
    num_seeds=num_seeds,
    pooling_num_heads=pooling_num_heads,
)

# The loss is configured with from_logits=True, i.e. it expects the model's raw
# (un-normalised) outputs together with integer class labels.
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.sparse_categorical_accuracy],
)

In [18]:
# Upper bound on the number of training epochs. The recorded run below was
# stopped manually (KeyboardInterrupt) after 13 epochs.
epochs = 10_000

In [19]:
# Train on the generated batches and evaluate on the validation generator after
# each epoch; the returned History object is kept for later inspection.
history = model.fit(
    x=train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    verbose=1,
)

Epoch 1/10000
Epoch 2/10000
Epoch 3/10000
Epoch 4/10000
Epoch 5/10000
Epoch 6/10000
Epoch 7/10000
Epoch 8/10000
Epoch 9/10000
Epoch 10/10000
Epoch 11/10000
Epoch 12/10000
Epoch 13/10000

KeyboardInterrupt: 