In [1]:
# Set TensorFlow's native (C++) log level before `tensorflow` is imported in the next cell.
# NOTE(review): "1" presumably filters out INFO-level native messages — confirm against the TensorFlow docs.
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"

In [2]:
import abc
from collections import defaultdict
import json
from itertools import chain, repeat
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from pathlib import Path
import tensorflow as tf
import tf_utilities as tfu
from tqdm.auto import tqdm
from typing import Iterable, Generator, Optional
import time
import wandb

from dnadb.datasets import Greengenes, Silva
from dnadb import dna, fasta, sample, taxonomy

from deepdna.data.dataset import Dataset
from deepdna.nn.data_generators import SequenceGenerator
from deepdna.nn.models import custom_model, dnabert, load_model, setbert, taxonomy as tax_models
from deepdna.nn.utils import encapsulate_model
from deepdna.nn import layers, functional, utils


In [3]:
# Select the GPU at index 0 through the tf_utilities helper; the cell output below lists the CPU and GPU devices.
tfu.devices.select_gpu(0)

([PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')],
 [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')])

## Dataset

In [4]:
# Open the synthetic dataset and wrap the first taxonomy DB of its training split.
# NOTE(review): absolute, user-specific path — consider reading it from a configurable data directory.
dataset = Dataset("/home/dwl2x/work/Datasets/synthetic")
tax_db_path = list(dataset.taxonomy_dbs(Dataset.Split.Train))[0]
tax_db = taxonomy.TaxonomyDb(tax_db_path)

# First FASTA DB of the training split; its mapping file is the same path with a ".mapping.db" suffix.
fasta_path = next(dataset.fasta_dbs(Dataset.Split.Train))
fasta_mapping_path = fasta_path.with_suffix(".mapping.db")
# Materialize every sample of the multiplexed FASTA into a tuple so they can be counted and reused.
samples = tuple(sample.load_multiplexed_fasta(fasta_path, fasta_mapping_path))
len(samples)

210

In [5]:
# Echo the taxonomy DB object; only its default repr is displayed.
tax_db

<dnadb.taxonomy.TaxonomyDb at 0x7f7ff504e6c0>

In [6]:
# Pair every loaded sample with the single shared taxonomy DB.
sample_tax_pairs = [(fasta_sample, tax_db) for fasta_sample in samples]

In [7]:
from deepdna.data.samplers import SampleSampler, SequenceSampler
from deepdna.nn.data_generators import _encode_sequences, BatchGenerator
from typing import Any, cast

class SequenceTaxonomyGenerator(BatchGenerator):
    """
    Batch generator that draws DNA sequences from FASTA samples and pairs each
    sequence with an integer taxonomy label.

    Each FASTA sample is associated with the taxonomy DB at the same position in
    `fasta_taxonomy_pairs`; labels are looked up with `fasta_id_to_label` and then
    translated to integers through `taxonomy_id_map`.
    """

    def __init__(
        self,
        fasta_taxonomy_pairs: Iterable[tuple[sample.FastaSample, taxonomy.TaxonomyDb]],
        sequence_length: int,
        taxonomy_id_map: dict[str, int],
        kmer: int = 1,
        subsample_size: int|None = None,
        batch_size: int = 32,
        batches_per_epoch: int = 100,
        augment_slide: bool = True,
        augment_ambiguous_bases: bool = True,
        balance: bool = False,
        shuffle: bool = True,
        rng: np.random.Generator|None = None
    ):
        """
        Args:
            fasta_taxonomy_pairs: (FASTA sample, taxonomy DB) pairs to draw from.
            sequence_length: Length of the sequences handed to the sequence sampler.
            taxonomy_id_map: Maps a taxonomy label string to its integer class ID.
            kmer: When greater than 1, sequences are additionally k-mer encoded.
            subsample_size: Sequences drawn per sample. `None` draws one sequence
                and drops the subsample axis from the returned sequences.
            batch_size: Number of samples per batch.
            batches_per_epoch: Forwarded to `BatchGenerator`.
            augment_slide: Forwarded to the `SequenceSampler`.
            augment_ambiguous_bases: Forwarded to `_encode_sequences`; its negation
                is passed to `dna.encode_kmers`.
            balance: Forwarded to `SampleSampler.sample_with_ids`.
            shuffle: Forwarded to `BatchGenerator`.
            rng: Random generator forwarded to `BatchGenerator`. When omitted, a
                fresh generator is created for this instance.

        Raises:
            ValueError: If `fasta_taxonomy_pairs` is empty.
        """
        super().__init__(
            batch_size=batch_size,
            batches_per_epoch=batches_per_epoch,
            shuffle=shuffle,
            # A default of `np.random.default_rng()` in the signature would be evaluated once
            # and shared by every instance; create the fallback per instance instead.
            rng=rng if rng is not None else np.random.default_rng()
        )
        pairs = tuple(fasta_taxonomy_pairs)
        if len(pairs) == 0:
            raise ValueError("fasta_taxonomy_pairs must contain at least one (sample, taxonomy DB) pair.")
        fasta_samples, taxonomy_dbs = zip(*pairs)
        self.sample_sampler = SampleSampler(cast(tuple[sample.FastaSample, ...], fasta_samples))
        self.sequence_sampler = SequenceSampler(sequence_length, augment_slide)
        self.taxonomy_dbs: tuple[taxonomy.TaxonomyDb, ...] = cast(Any, taxonomy_dbs)
        self.kmer = kmer
        self.taxonomy_id_map = taxonomy_id_map
        self.subsample_size = subsample_size
        self.augment_ambiguous_bases = augment_ambiguous_bases
        self.balance = balance

    @property
    def sequence_length(self) -> int:
        """Sequence length configured on the underlying sequence sampler."""
        return self.sequence_sampler.sequence_length

    def generate_batch(
        self,
        rng: np.random.Generator
    ) -> tuple[npt.NDArray[np.int32], list[Any], npt.NDArray[np.int32], npt.NDArray[np.int32]]:
        """
        Build one batch using `rng` for every random draw.

        Returns:
            A 4-tuple `(sample_ids, sequence_ids, sequences, label_ids)`:
            `sample_ids` has one entry per batch element, `sequence_ids` is a list
            holding the tuple of FASTA IDs drawn for each element, `sequences` is
            the int32-encoded sequence array, and `label_ids` has shape
            `(batch_size, subsample_size)` with the integer taxonomy labels.
        """
        subsample_size = self.subsample_size or 1
        sequences = np.empty((self.batch_size, subsample_size), dtype=f"<U{self.sequence_length}")
        sample_ids = np.empty(self.batch_size, dtype=np.int32)
        sequence_ids: list[Any] = [None] * self.batch_size
        label_ids = np.empty((self.batch_size, subsample_size), dtype=np.int32)
        samples = self.sample_sampler.sample_with_ids(self.batch_size, self.balance, rng)
        # `fasta_sample` rather than `sample`, so the imported `sample` module is not shadowed.
        for i, (sample_id, fasta_sample) in enumerate(samples):
            # The sample ID doubles as the index of the taxonomy DB paired with that sample.
            tax_db = self.taxonomy_dbs[sample_id]
            sequence_info = tuple(self.sequence_sampler.sample_with_ids(fasta_sample, subsample_size, rng))
            sequence_ids[i], sequences[i] = zip(*sequence_info)
            sample_ids[i] = sample_id
            label_ids[i] = [self.taxonomy_id_map[tax_db.fasta_id_to_label(fasta_id)] for fasta_id in sequence_ids[i]]
        # Use the generator passed in for this batch (not `self.rng`) so all randomness in the
        # batch comes from the same source as the sampling above.
        sequences = _encode_sequences(sequences, self.augment_ambiguous_bases, rng)
        if self.subsample_size is None:
            # No subsampling requested: drop the length-1 subsample axis from the sequences.
            # NOTE(review): `label_ids` keeps its (batch_size, 1) shape in this case — confirm intended.
            sequences = np.squeeze(sequences, axis=1)
        sequences = sequences.astype(np.int32)
        if self.kmer > 1:
            sequences = dna.encode_kmers(sequences, self.kmer, not self.augment_ambiguous_bases).astype(np.int32) # type: ignore
        return sample_ids, sequence_ids, sequences, label_ids

    def reduce_batch(self, batch):
        """Drop the sample IDs and sequence IDs, keeping `(sequences, label_ids)`."""
        return batch[2:]

## Model

In [8]:
# api = wandb.Api()
# Start a Weights & Biases run for this experiment; `run` is the handle returned by `wandb.init`.
run = wandb.init(project="setbert-taxonomy-naive", name="64d-150l")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msirdavidludwig[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
# path = api.artifact("sirdavidludwig/dnabert-pretrain/dnabert-pretrain-silva-64:v3").download()
# Load a saved SetBERT pretraining model from a local wandb run directory and keep only its `.base` model.
# NOTE(review): absolute, machine-specific path; the commented-out download above names a DNABERT
# artifact, so confirm which checkpoint is actually intended here.
path = "/home/dwl2x/.cache/wandb/wandb/run-20230812_235347-en0l0i7m/files/model"
setbert_model = load_model(path, setbert.SetBertPretrainModel).base



In [10]:
# class NaiveTaxonomyClassificationModel(custom_model.ModelWrapper, custom_model.CustomModel[tf.Tensor, tuple[tf.Tensor, ...]]):
#     def __init__(
#         self,
#         base: tf.keras.Model,
#         taxonomies: Iterable[str],
#         input_shape: Optional[tuple[int, ...]] = None,
#         **kwargs
#     ):
#         super().__init__(**kwargs)
#         self.base = base
#         self.taxonomy_id_map = {}
#         self._model_input_shape = input_shape
#         for tax in taxonomies:
#             if tax not in self.taxonomy_id_map:
#                 assert isinstance(tax, str), "Taxonomy label must be a string."
#                 self.taxonomy_id_map[tax] = len(self.taxonomy_id_map)
#         self.model = self.build_model()

#     def default_loss(self):
#         return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

#     def default_metrics(self):
#         return [
#             tf.keras.metrics.SparseCategoricalAccuracy()
#         ]

#     def build_model(self):
#         if self._model_input_shape is not None:
#             x = tf.keras.layers.Input(self._model_input_shape)
#             y = self.base(x)
#         else:
#             x, y = encapsulate_model(self.base)
#         y = tf.keras.layers.Dense(len(self.taxonomy_id_map))(y)
#         model = tf.keras.Model(x, y)
#         return model

#     def get_config(self):
#         return super().get_config() | {
#             "base": self.base,
#             "taxonomies": list(self.taxonomy_id_map.keys()),
#             "input_shape": self._model_input_shape
#         }

In [11]:
# Wrap the pretrained SetBERT base so it emits per-sequence outputs instead of the class token.
encoder = setbert.SetBertEncoderModel(
    setbert_model,
    compute_sequence_embeddings=True,
    output_class=False,
    output_sequences=True,
)

# Attach the naive taxonomy classification head, with labels taken from the taxonomy DB.
model = tax_models.NaiveTaxonomyClassificationModel(encoder, tax_db)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4))

In [12]:
# Print the layer-by-layer summary of the classification model.
model.summary()

Model: "model_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_7 (InputLayer)        [(None, 1000, 148)]       0         
                                                                 
 set_bert_encoder_model (Set  ((None, 1000, 64),       2340672   
 BertEncoderModel)            [[(None, 8, 1001, 1001)            
                             , (None, 8, 1001, 1001),            
                              (None, 8, 1001, 1001),             
                              (None, 8, 1001, 1001),             
                              (None, 8, 1001, 1001),             
                              (None, 8, 1001, 1001),             
                              (None, 8, 1001, 1001),             
                              (None, 8, 1001, 1001)]]            
                             )                                   
                                                           

In [13]:
# Settings shared by the training and validation generators.
common_args = {
    "sequence_length": 150,
    "kmer": 3,
    "taxonomy_id_map": model.taxonomy_id_map,
    "batch_size": 16,
    "subsample_size": 1000,
}

# NOTE(review): both generators draw from the same `sample_tax_pairs` (training split),
# so the validation metrics are not computed on held-out data.
train_data = SequenceTaxonomyGenerator(sample_tax_pairs, batches_per_epoch=100, **common_args)
test_data = SequenceTaxonomyGenerator(sample_tax_pairs, batches_per_epoch=20, shuffle=False, **common_args)

In [14]:
# Fetch the first training batch — presumably a smoke test that the generator produces data.
batch = train_data[0]

In [15]:
# Log metrics to Weights & Biases without uploading the model, either as a file or as an artifact.
wandb_callback = wandb.keras.WandbCallback(save_model=False)
wandb_callback.save_model_as_artifact = False

# Keep only the best checkpoint. The Keras option is `save_best_only`; a keyword named
# `save_best` is not a ModelCheckpoint parameter, so with it the checkpoint was rewritten
# after every epoch regardless of the monitored metric.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "logs/models/dnabert_taxonomy_naive",
    save_best_only=True,
)

In [None]:
# Continue training from epoch 200 up to epoch 300.
# NOTE(review): `initial_epoch=200` only offsets the epoch counter; on a fresh kernel the
# model built above carries no weights from those earlier 200 epochs.
model.fit(
    train_data,
    validation_data=test_data,
    epochs=300,
    initial_epoch=200,
    callbacks=[wandb_callback, checkpoint],
)

Epoch 201/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 202/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 203/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 204/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 205/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 206/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 207/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 208/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 209/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 210/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 211/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 212/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 213/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 214/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 215/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 216/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 217/300



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


Epoch 218/300

In [23]:
# Save the full model to the same directory that the ModelCheckpoint callback writes to.
model.save("logs/models/dnabert_taxonomy_naive")



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets
  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


In [17]:
# Scratch cell: a bare literal that simply echoes 1.
1

1