In [1]:
import os
# Must be set before TensorFlow is imported (next cell); level "1" presumably hides TF C++ INFO logs.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="1")

In [2]:
import abc
from collections import defaultdict
import json
from itertools import chain
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from pathlib import Path
import tensorflow as tf
import tf_utilities as tfu
from tqdm.auto import tqdm
from typing import Iterable, Generator, Optional
import time
import wandb

from dnadb.datasets import Greengenes, Silva
from dnadb import dna, fasta, sample, taxonomy

from deepdna.data.dataset import Dataset
from deepdna.data.tokenizers import AbstractTaxonomyTokenizer, NaiveTaxonomyTokenizer
from deepdna.nn.models import custom_model, dnabert, load_model, taxonomy as tax_models
from deepdna.nn.utils import encapsulate_model
from deepdna.nn import layers, functional, utils

In [3]:
# Select GPU 0 for TensorFlow; the returned (CPU devices, GPU devices) lists are displayed below.
tfu.devices.select_gpu(0)

([PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')],
 [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')])

## Dataset

In [4]:
# Root of the dataset split used for training and evaluation. Set DEEPDNA_DATASET_PATH to
# point at a local copy instead of editing this machine-specific default.
DATASET_PATH = os.environ.get("DEEPDNA_DATASET_PATH", "/home/dwl2x/work/Datasets/Silva2/0")
dataset = Dataset(DATASET_PATH)


def load_split(dataset, split):
    """Return the FASTA samples and taxonomy DBs of one dataset split as parallel tuples."""
    fastas = tuple(map(sample.load_fasta, dataset.fasta_dbs(split)))
    tax_dbs = tuple(map(taxonomy.TaxonomyDb, dataset.taxonomy_dbs(split)))
    return fastas, tax_dbs


train_fastas, train_tax = load_split(dataset, Dataset.Split.Train)
test_fastas, test_tax = load_split(dataset, Dataset.Split.Test)

In [5]:
# The tokenizer is fitted on the training taxonomy DBs only, so every label in every test DB
# should also occur in at least one training DB (unseen labels presumably cannot be tokenized).
for test_db in test_tax:
    for label in test_db:
        assert any(train_db.contains_label(label) for train_db in train_tax), \
            f"Test label not found in training taxonomy: {label}"

In [6]:
# Fit the taxonomy tokenizer on all training labels; each tokenized label has `depth` (6) IDs.
tokenizer = NaiveTaxonomyTokenizer(depth=6)
for train_db in train_tax:
    tokenizer.add_labels(train_db)
# Finalize the vocabulary once every training DB has been registered.
tokenizer.build()

In [7]:
from deepdna.data.samplers import SampleSampler, SequenceSampler
from deepdna.nn.data_generators import _encode_sequences, BatchGenerator
from typing import Any, cast

class SequenceTaxonomyGenerator(BatchGenerator):
    """
    Batch generator yielding encoded DNA sequences paired with tokenized taxonomy labels.

    Each batch draws `batch_size` FASTA samples via `SampleSampler`, then `subsample_size`
    sequences from each sample (a single sequence when `subsample_size` is None) via
    `SequenceSampler`. Each sequence's FASTA ID is resolved to a label through the taxonomy DB
    paired with its sample and tokenized with `taxonomy_tokenizer`.

    Args:
        fasta_taxonomy_pairs: (FASTA sample, taxonomy DB) pairs; each taxonomy DB must resolve
            the FASTA IDs of the sample it is paired with.
        sequence_length: Length of the sequences produced by the sequence sampler.
        taxonomy_tokenizer: Tokenizer mapping a label to `taxonomy_tokenizer.depth` token IDs.
        kmer: When > 1, encoded sequences are converted to k-mer tokens.
        subsample_size: Sequences drawn per sample. None draws one sequence per sample and
            removes the subsample axis from the outputs.
        batch_size, batches_per_epoch, shuffle: Forwarded to `BatchGenerator`.
        augment_slide: Forwarded to `SequenceSampler`.
        augment_ambiguous_bases: Forwarded to the sequence encoder (and, negated, to the
            k-mer encoder).
        balance: Forwarded to `SampleSampler.sample_with_ids`.
        rng: Random generator forwarded to `BatchGenerator`. Defaults to a new generator
            created per instance.

    Raises:
        ValueError: If `fasta_taxonomy_pairs` is empty.
    """

    def __init__(
        self,
        fasta_taxonomy_pairs: Iterable[tuple[sample.FastaSample, taxonomy.TaxonomyDb]],
        sequence_length: int,
        taxonomy_tokenizer: AbstractTaxonomyTokenizer,
        kmer: int = 1,
        subsample_size: int|None = None,
        batch_size: int = 32,
        batches_per_epoch: int = 100,
        augment_slide: bool = True,
        augment_ambiguous_bases: bool = True,
        balance: bool = False,
        shuffle: bool = True,
        rng: np.random.Generator|None = None
    ):
        # A `np.random.default_rng()` default would be evaluated once, at class definition,
        # and shared by every instance (e.g. the train and test generators).
        if rng is None:
            rng = np.random.default_rng()
        super().__init__(
            batch_size=batch_size,
            batches_per_epoch=batches_per_epoch,
            shuffle=shuffle,
            rng=rng
        )
        pairs = tuple(fasta_taxonomy_pairs)
        if not pairs:
            raise ValueError(
                "fasta_taxonomy_pairs must contain at least one (FASTA sample, taxonomy DB) pair")
        fasta_samples, taxonomy_dbs = zip(*pairs)
        self.sample_sampler = SampleSampler(cast(tuple[sample.FastaSample, ...], fasta_samples))
        self.sequence_sampler = SequenceSampler(sequence_length, augment_slide)
        # Indexed by the sample IDs returned from `sample_sampler`, parallel to `fasta_samples`.
        self.taxonomy_dbs: tuple[taxonomy.TaxonomyDb, ...] = cast(Any, taxonomy_dbs)
        self.kmer = kmer
        self.taxonomy_tokenizer = taxonomy_tokenizer
        self.subsample_size = subsample_size
        self.augment_ambiguous_bases = augment_ambiguous_bases
        self.balance = balance

    @property
    def sequence_length(self) -> int:
        """Length of the sequences drawn by the sequence sampler."""
        return self.sequence_sampler.sequence_length

    def generate_batch(
        self,
        rng: np.random.Generator
    ) -> tuple[npt.NDArray[np.int32], list[Any], npt.NDArray[np.int32], tuple[npt.NDArray[np.int32], ...]]:
        """
        Draw one batch.

        Returns:
            (sample_ids, sequence_ids, sequences, label_ids) where
            - sample_ids: (batch,) int32 index of the sample each row was drawn from;
            - sequence_ids: list of length batch, each a tuple of the drawn FASTA IDs;
            - sequences: int32 encoded sequences, (batch[, subsample], ...);
            - label_ids: tuple with one int32 array per taxonomy rank, each shaped
              (batch[, subsample]). The subsample axis is omitted when `subsample_size` is None.
        """
        subsample_size = self.subsample_size or 1
        depth = self.taxonomy_tokenizer.depth
        sequences = np.empty((self.batch_size, subsample_size), dtype=f"<U{self.sequence_length}")
        sample_ids = np.empty(self.batch_size, dtype=np.int32)
        sequence_ids: list[Any] = [None] * self.batch_size
        label_ids = np.empty((self.batch_size, subsample_size, depth), dtype=np.int32)
        samples = self.sample_sampler.sample_with_ids(self.batch_size, self.balance, rng)
        for i, (sample_id, fasta_sample) in enumerate(samples):
            tax_db = self.taxonomy_dbs[sample_id]
            sequence_info = tuple(self.sequence_sampler.sample_with_ids(fasta_sample, subsample_size, rng))
            sequence_ids[i], sequences[i] = zip(*sequence_info)
            sample_ids[i] = sample_id
            label_ids[i] = [
                self.taxonomy_tokenizer.tokenize_label(tax_db.fasta_id_to_label(fasta_id))
                for fasta_id in sequence_ids[i]
            ]
        # Use the per-batch generator (as the samplers above do) rather than `self.rng`, so all
        # randomness in a batch comes from the generator the caller supplied.
        sequences = _encode_sequences(sequences, self.augment_ambiguous_bases, rng)
        if self.subsample_size is None:
            sequences = np.squeeze(sequences, axis=1)
            label_ids = np.squeeze(label_ids, axis=1)
        sequences = sequences.astype(np.int32)
        if self.kmer > 1:
            sequences = dna.encode_kmers(sequences, self.kmer, not self.augment_ambiguous_bases).astype(np.int32) # type: ignore
        # Split labels into one array per rank. `moveaxis` moves only the rank axis to the front,
        # preserving (batch[, subsample]) order; `.T` would also swap batch and subsample axes.
        return sample_ids, sequence_ids, sequences, tuple(np.moveaxis(label_ids, -1, 0))

    def reduce_batch(self, batch):
        """Drop the sample IDs and sequence IDs, leaving `(sequences, label_ids)`."""
        return batch[2:]

## Model

In [8]:
# Public W&B API client; used below to download the pretrained artifact without starting a run.
api = wandb.Api()
# NOTE(review): `run` is used by `run.log_artifact` / `run.finish` at the end of the notebook,
# but this line is commented out, so `run` is undefined on a fresh kernel.
# run = wandb.init(project="dnabert-taxonomy-topdown", name="64d-150l")

In [9]:
# Fetch the pretrained DNABERT checkpoint (v3) and keep only the `.base` model of the
# loaded pretraining model.
path = (
    api
    .artifact("sirdavidludwig/dnabert-pretrain/dnabert-pretrain-silva-64:v3")
    .download()
)
# Alternative when a run is active:
# path = run.use_artifact("sirdavidludwig/dnabert-pretrain/dnabert-pretrain-silva-64:v3").download()
pretrain_model = load_model(path, dnabert.DnaBertPretrainModel)
dnabert_model = pretrain_model.base
del pretrain_model

[34m[1mwandb[0m:   4 of 4 files downloaded.  


In [10]:
# Classifier = DNABERT encoder over the pretrained base + BERTax taxonomy model using `tokenizer`.
# NOTE(review): the meaning of 256 is not visible here — confirm against DnaBertEncoderModel.
encoder = dnabert.DnaBertEncoderModel(dnabert_model, 256)
model = tax_models.BertaxTaxonomyClassificationModel(encoder, tokenizer)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4))

In [9]:
# Generator settings shared by the training and validation streams.
common_args = dict(
    sequence_length=150,
    kmer=3,
    taxonomy_tokenizer=model.taxonomy_tokenizer,
    subsample_size=None,
    batch_size=256,
)

# Give each stream its own RNG. Otherwise both fall back to the generator class's default
# `rng`, which (when written as a `default_rng()` default argument) may be one object shared
# by both, making validation draws depend on how many training batches were generated.
train_data = SequenceTaxonomyGenerator(
    zip(train_fastas, train_tax),
    batches_per_epoch=100,
    rng=np.random.default_rng(),
    **common_args)
test_data = SequenceTaxonomyGenerator(
    zip(test_fastas, test_tax),
    batches_per_epoch=20,
    rng=np.random.default_rng(),
    **common_args)

NameError: name 'model' is not defined

In [12]:
# wandb_callback = wandb.keras.WandbCallback(save_model=False)
# wandb_callback.save_model_as_artifact = False
# Save the full model after every epoch, regardless of the monitored metric.
# (The keyword is `save_best_only`; tf.keras silently ignores unknown kwargs such as
# `save_best`, and Keras 3 rejects them.)
checkpoint = tf.keras.callbacks.ModelCheckpoint("logs/models/dnabert_taxonomy_bertax", save_best_only=False)

In [13]:
# Train for one epoch, validating on `test_data` and saving through the `checkpoint` callback.
# NOTE(review): the W&B summary at the end reports epoch 2999, so it likely comes from a
# different run than `epochs=1` here — confirm the intended epoch count.
model.fit(train_data, validation_data=test_data, epochs=1, initial_epoch=0, callbacks=[checkpoint])



2023-09-08 20:21:54.261820: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 374475792 exceeds 10% of free system memory.
2023-09-08 20:21:54.410859: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 374475792 exceeds 10% of free system memory.
2023-09-08 20:21:54.551404: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 374475792 exceeds 10% of free system memory.


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_bertax/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_bertax/assets




  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


<keras.callbacks.History at 0x7f09ec395b10>

In [14]:
# Save the trained classifier to the same directory the checkpoint callback writes to.
model.save("logs/models/dnabert_taxonomy_bertax")

2023-09-08 20:22:12.860141: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 374475792 exceeds 10% of free system memory.
2023-09-08 20:22:13.009454: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 374475792 exceeds 10% of free system memory.


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_bertax/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_bertax/assets
  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


In [11]:
# Reload the saved classifier from disk (path used by `model.save` / the checkpoint callback).
model = load_model("logs/models/dnabert_taxonomy_bertax")

2023-09-08 20:23:26.786024: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 374475792 exceeds 10% of free system memory.
2023-09-08 20:23:26.786059: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 374475792 exceeds 10% of free system memory.


In [414]:
# Per-rank predicted class IDs: argmax over the class axis of each prediction array.
# [np.argmax(row, axis=1) for row in model.predict(test_data)]

In [16]:
# Save the classifier again to the same directory.
# NOTE(review): the captured output below mentions `dnabert_taxonomy_naive_bak`, so it is
# stale relative to this path — re-run to refresh it.
model.save("logs/models/dnabert_taxonomy_bertax")



INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive_bak/assets


INFO:tensorflow:Assets written to: logs/models/dnabert_taxonomy_naive_bak/assets
  layer_config = serialize_layer_fn(layer)
  return generic_utils.serialize_keras_object(obj)


In [18]:
# W&B model artifact that will hold the saved classifier directory.
a = wandb.Artifact(name="dnabert-taxonomy-bertax-64d-150l", type="model")

In [19]:
# Attach the saved model directory to the artifact.
# NOTE(review): the captured output shows `dnabert_taxonomy_naive`, i.e. it is stale.
a.add_dir("logs/models/dnabert_taxonomy_bertax")

[34m[1mwandb[0m: Adding directory to artifact (./logs/models/dnabert_taxonomy_naive)... Done. 0.1s


In [20]:
# `run` is never defined on a fresh kernel (the `wandb.init` line near the top is commented
# out). Reuse the active run if there is one, otherwise start one, then upload the artifact.
run = wandb.run if wandb.run is not None else wandb.init(project="dnabert-taxonomy-topdown", name="64d-150l")
run.log_artifact(a)

<Artifact dnabert-taxonomy-naive-64d-150l>

In [21]:
# Close the W&B run; the run history/summary printed below is produced by this call.
run.finish()

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,█▆▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sparse_categorical_accuracy,▁▃▄▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇█████████████████
val_loss,█▆▅▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_sparse_categorical_accuracy,▁▃▄▄▅▅▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇████████████

0,1
best_epoch,2693.0
best_val_loss,0.84462
epoch,2999.0
loss,1.0782
sparse_categorical_accuracy,0.75941
val_loss,0.89255
val_sparse_categorical_accuracy,0.79707
