In [1]:
import bootstrap

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import os
import time
import tensorflow as tf
import tensorflow.keras as keras
import tf_utils as tfu
# from settransformer import ConditionedSetAttentionBlock, spectral_dense, InducedSetEncoder, InducedSetAttentionBlock
import wandb

from common.models.dnabert import DnaBertPretrainModel, DnaBertEncoderModel
from common.models.gan import Gan, ConditionalGan
from common.models.gast import GastGenerator, GastDiscriminator
from common.data import find_dbs, DnaSequenceGenerator, DnaSampleGenerator
from common.models import gast
from common.layers import SampleSet

In [3]:
# Distribution strategy from the project's tf_utils helper; presumably a
# single-device strategy pinned to GPU 0 (TODO confirm). Models below are
# built inside `strategy.scope()` so their variables live under it.
strategy = tfu.strategy.gpu(0)

2022-05-16 17:46:27.985634: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-05-16 17:46:27.985855: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-05-16 17:46:27.991326: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-05-16 17:46:27.991534: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-05-16 17:46:27.991701: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from S

In [4]:
# Weights & Biases public API client; used below to download the pretrained
# DNABERT artifact.
api = wandb.Api()

In [5]:
# Every file in the training split, as full paths in alphabetical order so the
# sample ordering (and any per-sample index derived from it) is stable.
prefix = "./artifacts/dnasamples:v2/train"
samples = sorted(os.path.join(prefix, db_name) for db_name in os.listdir(prefix))
samples

['./artifacts/dnasamples:v2/train/fall_2016-10-07.db',
 './artifacts/dnasamples:v2/train/fall_2017-10-13.db',
 './artifacts/dnasamples:v2/train/spring_2016-04-22.db',
 './artifacts/dnasamples:v2/train/spring_2017-05-02.db',
 './artifacts/dnasamples:v2/train/spring_2018-04-23.db',
 './artifacts/dnasamples:v2/train/spring_2019-05-14.db',
 './artifacts/dnasamples:v2/train/spring_2020-05-11.db']

In [13]:
# Training data: batches drawn from the sample DBs above, 16 sample sets per
# batch, 100 batches per epoch, k-mer size 3, no labels.
# NOTE(review): the positional 100 and 150 are presumably subsample size and
# sequence length — confirm against DnaSampleGenerator. DnaSampleGan below is
# constructed with batch_size=16 and subsample_size=100 and reshapes with those
# static values, so they must stay in sync with this call.
dataset = DnaSampleGenerator(samples, 100, 150, kmer=3, batch_size=16, batches_per_epoch=100, include_labels=False, use_batch_as_labels=False)

In [14]:
# Sanity check: dtype of the second element of the first batch.
dataset[0][1].dtype

dtype('int32')

In [15]:
# Download the pretrained DNABERT artifact from W&B to a local path, load it,
# and wrap its `.base` model in an encoder, built under the strategy scope.
# NOTE(review): ":latest" makes this cell non-reproducible — pin an explicit
# artifact version once the experiment is finalized.
path = api.artifact("deep-learning-dna/dnabert-pretrain:latest").download()
with strategy.scope():
	encoder = DnaBertEncoderModel(DnaBertPretrainModel.load(path).base)

In [20]:
class DnaSampleGan(Gan):
	"""GAN over sets of DNA sequences embedded by a frozen encoder.

	Raw input batches are run through `encoder` in chunks of
	`encoder_batch_size` sequences (see `modify_data_for_input`) before being
	used by the GAN. The encoder is frozen: its weights are marked
	non-trainable and gradients are stopped at its output.

	Args:
		generator: Generator model, forwarded to `Gan.__init__`.
		discriminator: Discriminator model, forwarded to `Gan.__init__`.
		encoder: Keras model applied to batches of flattened sequences.
		batch_size: Number of sample sets per batch. Must equal the batch size
			of the incoming data, because it is used in a static reshape.
		subsample_size: Number of sequences per sample set; also used in the
			static reshape and must match the data.
		encoder_batch_size: Maximum number of sequences passed to the encoder
			in one call; bounds the peak memory of the encoding pass.

	Raises:
		ValueError: If `encoder_batch_size` is not positive.
	"""

	def __init__(self, generator, discriminator, encoder, batch_size, subsample_size, encoder_batch_size=512):
		super().__init__(generator, discriminator)
		# A non-positive step would make the chunking loop fail (0) or produce
		# no chunks at all (negative); reject it up front with a clear message.
		if encoder_batch_size <= 0:
			raise ValueError(f"encoder_batch_size must be positive, got {encoder_batch_size}")
		self.encoder = encoder
		self.encoder.trainable = False
		self.encoder_batch_size = encoder_batch_size

		# The chunking loop in `modify_data_for_input` is unrolled in Python at
		# trace time, so the total number of sequences must be a static int
		# rather than a dynamic tf.shape value.
		self.batch_size = batch_size
		self.subsample_size = subsample_size

	def modify_data_for_input(self, data):
		"""Encode every sequence in `data` with the frozen encoder.

		Args:
			data: Tensor that reshapes to (batch_size * subsample_size, -1);
				presumably (batch_size, subsample_size, sequence_length) token
				ids — TODO confirm against the dataset.

		Returns:
			Tensor reshaped to (batch_size, subsample_size, -1) holding the
			concatenated encoder outputs, with gradients stopped.
		"""
		total = self.batch_size * self.subsample_size
		flat_data = tf.reshape(data, (total, -1))
		encoded = [
			# training=False keeps the frozen encoder in inference mode (e.g.
			# dropout off) regardless of the surrounding training phase.
			tf.stop_gradient(self.encoder(flat_data[start:start + self.encoder_batch_size], training=False))
			for start in range(0, total, self.encoder_batch_size)
		]
		return tf.reshape(tf.concat(encoded, axis=0), (self.batch_size, self.subsample_size, -1))

	def get_config(self):
		"""Extend the parent config with this model's constructor arguments."""
		config = super().get_config()
		config.update({
			# NOTE(review): this stores the model instance itself, which is not
			# JSON-serializable; if the config is ever serialized, consider
			# keras.utils.serialize_keras_object — confirm how Gan.get_config
			# handles generator/discriminator first.
			"encoder": self.encoder,
			"batch_size": self.batch_size,
			"subsample_size": self.subsample_size,
			"encoder_batch_size": self.encoder_batch_size
		})
		return config

In [21]:
# Scratch check of tf.repeat: each value of range(7) repeated 10 times in a
# row (see output). 7 equals len(samples); presumably exploring per-sample
# labels. The result is not assigned or used by the visible cells.
tf.repeat(tf.range(7), 10)

<tf.Tensor: shape=(70,), dtype=int32, numpy=
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6,
       6, 6, 6, 6], dtype=int32)>

In [22]:
# Build generator, discriminator and the encoder-backed GAN under the
# distribution strategy so their variables are created in its scope.
# NOTE(review): the positional hyperparameters of GastGenerator and
# GastDiscriminator are opaque here — consider passing them by keyword.
with strategy.scope():
	g = GastGenerator(100, 64, 128, 128, 4, 4, 32, num_classes=1)
	d = GastDiscriminator(128, 128, 4, 4, 32, num_classes=1)
	# batch_size=16 / subsample_size=100 are used in static reshapes inside
	# DnaSampleGan, so they must match the batch size (16) and, presumably, the
	# subsample size (100) given to DnaSampleGenerator above.
	gan = DnaSampleGan(g, d, encoder, 16, 100, encoder_batch_size=512)
	# Argument order is presumably (loss, generator optimizer, discriminator
	# optimizer, generator metrics, discriminator metrics) — TODO confirm
	# against Gan.compile. reduction="sum" sums per-example losses instead of
	# averaging them, so reported losses scale with the batch size.
	gan.compile(
		keras.losses.BinaryCrossentropy(from_logits=True, reduction="sum"),
		keras.optimizers.Adam(1e-4),
		keras.optimizers.Adam(1e-4),
		[keras.metrics.BinaryAccuracy()],
		[keras.metrics.BinaryAccuracy()]
	)
	gan.summary()

Model: "dna_sample_gan_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 gast_generator_2 (GastGener  multiple                 5080704   
 ator)                                                           
                                                                 
 gast_discriminator_2 (GastD  multiple                 1125633   
 iscriminator)                                                   
                                                                 
 dna_bert_encoder_model_1 (D  multiple                 2547584   
 naBertEncoderModel)                                             
                                                                 
Total params: 8,753,929
Trainable params: 6,179,073
Non-trainable params: 2,574,856
_________________________________________________________________


In [23]:
# Train the GAN for a single epoch over `dataset`.
# NOTE(review): in the recorded run this cell was stopped manually
# (KeyboardInterrupt after 15/100 steps), so `history` was not assigned.
with strategy.scope():
	history = gan.fit(dataset, initial_epoch=0, epochs=1)

2022-05-16 17:47:35.413985: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:766] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Did not find a shardable source, walked to a node which is not a dataset: name: "FlatMapDataset/_2"
op: "FlatMapDataset"
input: "TensorDataset/_1"
attr {
  key: "Targuments"
  value {
    list {
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: -2
  }
}
attr {
  key: "f"
  value {
    func {
      name: "__inference_Dataset_flat_map_flat_map_fn_124396"
    }
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\021FlatMapDataset:27"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: -1
        }
        dim {
          size: -1
        }
        dim {
          size: -1
        }
      }
    }
  }
}
attr {
  key: "output_types"
  value {
    list {
      type: DT_INT32
    }
  }
}
. Consider either turning off auto-

 15/100 [===>..........................] - ETA: 39s - generator_loss: 3.5661 - discriminator_loss: 31.0797 - binary_accuracy: 0.4875

KeyboardInterrupt: 

In [165]:
# Persist the whole GAN as a TensorFlow SavedModel directory. Keras documents
# save_format as "tf" (SavedModel) or "h5"; other strings are rejected by
# newer Keras versions.
gan.save("./testmodel", save_format="tf")

INFO:tensorflow:Assets written to: ./testmodel/assets
