In [1]:
import os
# Set in its own cell, before the cell that imports TensorFlow, so the variable is
# already in the environment when TensorFlow is loaded.
# NOTE(review): level "1" presumably filters INFO-level native TensorFlow logs — confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"

In [2]:
import abc
from collections import defaultdict
import json
from itertools import chain
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from pathlib import Path
import tensorflow as tf
import tf_utilities as tfu
from tqdm.auto import tqdm
from typing import Iterable, Generator, Optional
import time
import wandb

from dnadb.datasets import Greengenes, Silva
from dnadb import dna, fasta, sample, taxonomy

from deepdna.data.dataset import Dataset
from deepdna.nn.models import custom_model, dnabert, load_model
from deepdna.nn.models.utils import encapsulate_model
from deepdna.nn import layers, functional, utils


In [3]:
# Select GPU index 0 through tf_utilities. The returned value (displayed as the cell
# output) is a pair of device lists: the CPU devices and the GPU devices.
# NOTE(review): presumably this restricts TensorFlow to that single GPU — confirm.
tfu.devices.select_gpu(0)

([PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')],
 [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')])

## Dataset

In [4]:
# Root directory of the dataset. It defaults to the original author's location but can
# be overridden with the DEEPDNA_DATASET_PATH environment variable so the notebook can
# run on other machines without editing this cell.
dataset_path = os.environ.get("DEEPDNA_DATASET_PATH", "/home/dwl2x/work/Datasets/Silva2/0")
dataset = Dataset(dataset_path)

# One FASTA sample and one taxonomy DB per entry of each split; the two tuples of a
# split are index-aligned because they come from the same dataset split listing.
train_fastas = tuple(map(sample.load_fasta, dataset.fasta_dbs(Dataset.Split.Train)))
train_tax = tuple(map(taxonomy.TaxonomyDb, dataset.taxonomy_dbs(Dataset.Split.Train)))
test_fastas = tuple(map(sample.load_fasta, dataset.fasta_dbs(Dataset.Split.Test)))
test_tax = tuple(map(taxonomy.TaxonomyDb, dataset.taxonomy_dbs(Dataset.Split.Test)))

In [20]:
# Total number of FASTA IDs in the first training taxonomy DB, obtained by summing
# the number of IDs returned by `fasta_ids_with_label` for every label.
n = sum(
    len(list(train_tax[0].fasta_ids_with_label(train_tax[0].label_to_index(label))))
    for label in train_tax[0]
)
n

2113503

In [21]:
# Total number of FASTA IDs in the first test taxonomy DB, obtained by summing
# the number of IDs returned by `fasta_ids_with_label` for every label.
n = sum(
    len(list(test_tax[0].fasta_ids_with_label(test_tax[0].label_to_index(label))))
    for label in test_tax[0]
)
n

111237

In [5]:
# Sanity check: every label of the first test taxonomy DB must also be present in
# the first training taxonomy DB, otherwise the classifier has no output for it.
assert all(train_tax[0].contains_label(label) for label in test_tax[0])

In [6]:
from deepdna.data.samplers import SampleSampler, SequenceSampler
from deepdna.nn.data_generators import _encode_sequences, BatchGenerator
from typing import Any, cast

class SequenceTaxonomyGenerator(BatchGenerator):
    """
    Batch generator that pairs sampled DNA sequences with integer taxonomy label IDs.

    Each batch draws `batch_size` FASTA samples, draws sequences from each of them,
    and looks up the taxonomy label of every drawn sequence in the taxonomy DB that
    belongs to the same FASTA sample.
    """
    def __init__(
        self,
        fasta_taxonomy_pairs: Iterable[tuple[sample.FastaSample, taxonomy.TaxonomyDb]],
        sequence_length: int,
        taxonomy_id_map: dict[str, int],
        kmer: int = 1,
        subsample_size: int|None = None,
        batch_size: int = 32,
        batches_per_epoch: int = 100,
        augment_slide: bool = True,
        augment_ambiguous_bases: bool = True,
        balance: bool = False,
        shuffle: bool = True,
        rng: np.random.Generator|None = None
    ):
        """
        Args:
            fasta_taxonomy_pairs: (FASTA sample, taxonomy DB) pairs; the taxonomy DB of
                a pair is used to label sequences drawn from its FASTA sample.
            sequence_length: Length passed to the sequence sampler; also the width of
                the string buffer that holds the sampled sequences.
            taxonomy_id_map: Maps a taxonomy label string to its integer class ID.
            kmer: When greater than 1, the encoded sequences are k-mer encoded.
            subsample_size: Number of sequences drawn per sample. When None, a single
                sequence is drawn and the subsample axis is squeezed out of the
                sequences (the labels keep their trailing axis of size 1).
            batch_size: Number of samples per batch.
            batches_per_epoch: Forwarded to `BatchGenerator`.
            augment_slide: Forwarded to `SequenceSampler`.
            augment_ambiguous_bases: Forwarded to `_encode_sequences`; its negation is
                forwarded to `dna.encode_kmers`.
            balance: Forwarded to `SampleSampler.sample_with_ids`.
            shuffle: Forwarded to `BatchGenerator`.
            rng: Random generator forwarded to `BatchGenerator`. When None, a fresh
                generator is created for this instance.
        """
        # A default of `np.random.default_rng()` in the signature would be evaluated
        # once at definition time and silently shared by every instance, so the
        # default generator is created here instead.
        if rng is None:
            rng = np.random.default_rng()
        super().__init__(
            batch_size=batch_size,
            batches_per_epoch=batches_per_epoch,
            shuffle=shuffle,
            rng=rng
        )
        fasta_samples, taxonomy_dbs = zip(*fasta_taxonomy_pairs)
        self.sample_sampler = SampleSampler(cast(tuple[sample.FastaSample, ...], fasta_samples))
        self.sequence_sampler = SequenceSampler(sequence_length, augment_slide)
        self.taxonomy_dbs: tuple[taxonomy.TaxonomyDb, ...] = cast(Any, taxonomy_dbs)
        self.kmer = kmer
        self.taxonomy_id_map = taxonomy_id_map
        self.subsample_size = subsample_size
        self.augment_ambiguous_bases = augment_ambiguous_bases
        self.balance = balance

    @property
    def sequence_length(self) -> int:
        """Sequence length configured on the underlying sequence sampler."""
        return self.sequence_sampler.sequence_length

    def generate_batch(
        self,
        rng: np.random.Generator
    ) -> tuple[npt.NDArray[np.int32], list[Any], npt.NDArray[np.int32], npt.NDArray[np.int32]]:
        """
        Generate one full batch using the supplied random generator.

        Args:
            rng: Random generator used for every random draw in this batch.

        Returns:
            A 4-tuple `(sample_ids, sequence_ids, sequences, label_ids)`:
            `sample_ids` has one int32 entry per batch element, `sequence_ids` is a
            list holding the tuple of sequence IDs drawn for each batch element,
            `sequences` are the int32-encoded sequences, and `label_ids` is an int32
            array of shape (batch_size, subsample_size) with the taxonomy class IDs.

        Raises:
            KeyError: If a sampled sequence's label is missing from `taxonomy_id_map`.
        """
        subsample_size = self.subsample_size or 1
        sequences = np.empty((self.batch_size, subsample_size), dtype=f"<U{self.sequence_length}")
        sample_ids = np.empty(self.batch_size, dtype=np.int32)
        sequence_ids: list[Any] = [None] * self.batch_size
        label_ids = np.empty((self.batch_size, subsample_size), dtype=np.int32)
        samples = self.sample_sampler.sample_with_ids(self.batch_size, self.balance, rng)
        # `fasta_sample` rather than `sample` to avoid shadowing the `sample` module.
        for i, (sample_id, fasta_sample) in enumerate(samples):
            # The sample ID indexes the taxonomy DB that was paired with this sample.
            tax_db = self.taxonomy_dbs[sample_id]
            sequence_info = tuple(self.sequence_sampler.sample_with_ids(fasta_sample, subsample_size, rng))
            sequence_ids[i], sequences[i] = zip(*sequence_info)
            sample_ids[i] = sample_id
            label_ids[i] = [self.taxonomy_id_map[tax_db.fasta_id_to_label(fasta_id)] for fasta_id in sequence_ids[i]]
        # Use the per-batch `rng` (not `self.rng`) so that all randomness of a batch
        # comes from the generator handed to this call, like the sampling above.
        sequences = _encode_sequences(sequences, self.augment_ambiguous_bases, rng)
        if self.subsample_size is None:
            # No subsampling requested: drop the singleton subsample axis.
            sequences = np.squeeze(sequences, axis=1)
        sequences = sequences.astype(np.int32)
        if self.kmer > 1:
            sequences = dna.encode_kmers(sequences, self.kmer, not self.augment_ambiguous_bases).astype(np.int32) # type: ignore
        return sample_ids, sequence_ids, sequences, label_ids

    def reduce_batch(self, batch):
        """Drop the sample IDs and sequence IDs, keeping `(sequences, label_ids)`."""
        return batch[2:]

## Model

In [7]:
# Start a W&B run for this experiment. The artifact could alternatively be fetched
# without creating a run:
# api = wandb.Api()
run = wandb.init(
    project="dnabert-taxonomy-naive",
    name="64d-150l",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msirdavidludwig[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
# Download the pretrained DNABERT artifact through the active run, which also records
# it as an input of the run. Run-less alternative:
# path = api.artifact("sirdavidludwig/dnabert-pretrain/dnabert-pretrain-silva-64:v3").download()
artifact = run.use_artifact("sirdavidludwig/dnabert-pretrain/dnabert-pretrain-silva-64:v3")
path = artifact.download()

# Only the `base` model of the pretraining wrapper is kept for fine-tuning.
dnabert_model = load_model(path, dnabert.DnaBertPretrainModel).base

[34m[1mwandb[0m:   4 of 4 files downloaded.  


In [9]:
class NaiveTaxonomyClassificationModel(custom_model.ModelWrapper, custom_model.CustomModel[tf.Tensor, tuple[tf.Tensor, ...]]):
    """
    Flat taxonomy classifier: a base model followed by a single dense layer that
    emits one logit per distinct taxonomy label.
    """
    def __init__(
        self,
        base: tf.keras.Model,
        taxonomies: Iterable[str],
        **kwargs
    ):
        """
        Args:
            base: Model whose output feeds the classification layer.
            taxonomies: Taxonomy label strings. Duplicates are ignored; each distinct
                label receives the next free integer ID in first-seen order.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(**kwargs)
        self.base = base
        self.taxonomy_id_map = {}
        for label in taxonomies:
            if label in self.taxonomy_id_map:
                continue  # already has an ID
            assert isinstance(label, str), "Taxonomy label must be a string."
            self.taxonomy_id_map[label] = len(self.taxonomy_id_map)
        # Built last because `build_model` needs the number of labels.
        self.model = self.build_model()

    def default_loss(self):
        """Sparse categorical cross-entropy over the raw logits of the dense layer."""
        return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    def default_metrics(self):
        """Single-element list holding a sparse categorical accuracy metric."""
        accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        return [accuracy]

    def build_model(self):
        """Attach an activation-free dense layer, one unit per label, to the base."""
        inputs, features = encapsulate_model(self.base)
        num_labels = len(self.taxonomy_id_map)
        logits = tf.keras.layers.Dense(num_labels)(features)
        return tf.keras.Model(inputs, logits)

    def get_config(self):
        """Parent config extended with the base model and the ordered label list."""
        config = super().get_config()
        return config | {
            "base": self.base,
            "taxonomies": list(self.taxonomy_id_map)
        }

In [10]:
# Wrap the pretrained DNABERT base in an encoder model, then build the classifier
# over every label found in the training taxonomy DBs (duplicates are dropped by
# the model's constructor).
encoder = dnabert.DnaBertEncoderModel(dnabert_model, 256)
model = NaiveTaxonomyClassificationModel(
    encoder,
    chain.from_iterable(train_tax),
)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4))

In [11]:
# Print the layer-by-layer summary (output shapes and parameter counts) of the classifier.
model.summary()

Model: "model_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_5 (InputLayer)        [(None, 148)]             0         
                                                                 
 dna_bert_encoder_model (Dna  (None, 64)               1210496   
 BertEncoderModel)                                               
                                                                 
 dense_18 (Dense)            (None, 11070)             719550    
                                                                 
Total params: 1,930,046
Trainable params: 1,930,046
Non-trainable params: 0
_________________________________________________________________


In [12]:
# Settings shared by the training and validation generators. The label-to-ID map is
# taken from the model so that generated label IDs line up with its output units.
common_args = {
    "sequence_length": 150,
    "kmer": 3,
    "taxonomy_id_map": model.taxonomy_id_map,
    "batch_size": 256,
}

# The two generators differ only in their data split and in batches per epoch.
train_data = SequenceTaxonomyGenerator(zip(train_fastas, train_tax), batches_per_epoch=100, **common_args)
test_data = SequenceTaxonomyGenerator(zip(test_fastas, test_tax), batches_per_epoch=20, **common_args)

In [13]:
# Inspect the second element of the first training batch. `reduce_batch` keeps
# (sequences, label_ids), so this is presumably the label-ID array of that batch.
train_data[0][1]

array([[   30],
       [  226],
       [  258],
       [ 2033],
       [ 2105],
       [ 1636],
       [ 3122],
       [ 1258],
       [   69],
       [   98],
       [ 2827],
       [ 3489],
       [  224],
       [  643],
       [  360],
       [   50],
       [ 5017],
       [  554],
       [ 2286],
       [   13],
       [  491],
       [  374],
       [ 6266],
       [  623],
       [  239],
       [   50],
       [ 6610],
       [ 6126],
       [ 3802],
       [    5],
       [  984],
       [  325],
       [  258],
       [  239],
       [  623],
       [  716],
       [  987],
       [  223],
       [  647],
       [ 3650],
       [  223],
       [ 1788],
       [ 4724],
       [  775],
       [   48],
       [  239],
       [ 3663],
       [  239],
       [ 2001],
       [   46],
       [ 6102],
       [ 1753],
       [  623],
       [   48],
       [  760],
       [  239],
       [  258],
       [ 4619],
       [ 1666],
       [ 7781],
       [ 5735],
       [  239],
       [

In [14]:
# Inspect the second element of the first validation batch. `reduce_batch` keeps
# (sequences, label_ids), so this is presumably the label-ID array of that batch.
test_data[0][1]

array([[ 239],
       [1753],
       [  11],
       [ 239],
       [1133],
       [ 215],
       [   8],
       [  43],
       [8298],
       [ 787],
       [2944],
       [ 623],
       [ 698],
       [ 699],
       [1130],
       [ 169],
       [ 239],
       [4663],
       [  13],
       [ 340],
       [ 282],
       [1842],
       [4064],
       [2894],
       [4221],
       [ 716],
       [ 159],
       [ 141],
       [ 159],
       [ 686],
       [  46],
       [ 488],
       [  56],
       [  48],
       [ 423],
       [ 196],
       [  48],
       [   8],
       [  50],
       [ 225],
       [8863],
       [5767],
       [  17],
       [3489],
       [ 239],
       [7562],
       [5329],
       [ 196],
       [ 615],
       [6661],
       [ 471],
       [  82],
       [   8],
       [ 170],
       [3309],
       [ 226],
       [  75],
       [1065],
       [ 202],
       [   4],
       [ 683],
       [ 615],
       [  48],
       [ 239],
       [ 239],
       [  79],
       [ 9

In [15]:
# Smoke test: run the first element of the first training batch (the encoded
# sequences) through the model; the output below has one row per batch element and
# one column per taxonomy label.
model(train_data[0][0])

<tf.Tensor: shape=(256, 11070), dtype=float32, numpy=
array([[ 0.00583833, -0.08855611,  0.18546222, ..., -0.016127  ,
        -0.2977479 , -0.18200122],
       [-0.3335538 , -0.1786217 ,  0.09431156, ..., -0.13177276,
        -0.19796431,  0.03931923],
       [ 0.32255176, -0.0689109 ,  0.10260673, ..., -0.05669966,
        -0.14712423, -0.1045833 ],
       ...,
       [-0.13978443, -0.44564387, -0.08463553, ..., -0.13566524,
         0.00341579,  0.09858178],
       [-0.23548141, -0.1444789 , -0.02358536, ...,  0.2709653 ,
        -0.11341795, -0.24591446],
       [-0.08897561, -0.0622659 ,  0.06874194, ..., -0.00876656,
        -0.19912681, -0.20852236]], dtype=float32)>

In [16]:
# Log training metrics to W&B without uploading model files from the callback.
wandb_callback = wandb.keras.WandbCallback(save_model=False)
wandb_callback.save_model_as_artifact = False

# `save_best_only=True` is the ModelCheckpoint option that keeps only the checkpoint
# with the best monitored value (`val_loss` by default). `save_best` is not a
# ModelCheckpoint argument, so passing it does not enable best-only saving.
checkpoint = tf.keras.callbacks.ModelCheckpoint("logs/models/dnabert_taxonomy_naive", save_best_only=True)

In [None]:
# Continue training from epoch 2500 up to epoch 3000, validating on `test_data`.
# NOTE(review): `initial_epoch` only offsets the epoch counter. On a fresh kernel the
# classifier built above has not been trained for 2500 epochs, so confirm that the
# earlier weights are restored before relying on this cell after Restart & Run All.
model.fit(
    train_data,
    validation_data=test_data,
    epochs=3000,
    initial_epoch=2500,
    callbacks=[wandb_callback, checkpoint],
)

In [23]:
# Save the model in its current (final) state. This is the same path that the
# ModelCheckpoint callback writes to, so this presumably replaces the checkpoint
# stored there.
model.save("logs/models/dnabert_taxonomy_naive")



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive/assets
  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)
