In [45]:
from collections import defaultdict
from dataclasses import replace
from dnadb import dna, fasta, fastq, sample, taxonomy
from dnadb.datasets import Silva
from functools import cache
import numpy as np
from pathlib import Path
import tensorflow as tf
import tf_utilities as tfu
import time
from tqdm.auto import tqdm
import wandb

from deepdna.nn.models import load_model
from deepdna.nn.models.taxonomy import TopDownTaxonomyClassificationModel, TopDownConcatTaxonomyClassificationModel

In [2]:
tfu.devices.select_gpu(1)

2023-06-21 08:22:44.662716: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-06-21 08:22:44.662930: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-06-21 08:22:44.668929: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-06-21 08:22:44.669149: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-06-21 08:22:44.669323: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from S

([PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')],
 [PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')])

In [3]:
api = wandb.Api()

In [4]:
rng = np.random.default_rng()

## Taxonomy Classification Model

In [5]:
path = api.artifact("sirdavidludwig/dnabert-taxonomy-complete/dnabert-taxonomy-topdown-concat.silva.64d.150l:latest").download()
tax_model = load_model(path, TopDownConcatTaxonomyClassificationModel)
tax_model

[34m[1mwandb[0m: Downloading large artifact dnabert-taxonomy-topdown-concat.silva.64d.150l:latest, 438.34MB. 4 files... 
[34m[1mwandb[0m:   4 of 4 files downloaded.  
Done. 0:0:0.6
2023-06-21 08:22:45.788510: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-06-21 08:22:45.790462: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-06-21 08:22:45.790705: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2023-06-21 08

<deepdna.nn.models.taxonomy.TopDownConcatTaxonomyClassificationModel at 0x7f1ea65015a0>

In [6]:
tax_model.input_shape

(None, 148)

## SILVA Dataset

In [7]:
dataset_root = Path("~/work/Datasets/synthetic").expanduser()
dataset_root.mkdir(exist_ok=True)
dataset_root

PosixPath('/home/dwl2x/work/Datasets/synthetic')

In [8]:
silva = Silva()
silva.path

PosixPath('/tmp/Silva')

Write the SILVA sequences to a FASTA database.

In [9]:
# db = fasta.FastaDbFactory(dataset_root / "silva.fasta.db")
# db.write_entries(tqdm(silva.sequences()))
# db.close()

Create a index-to-sequenceID mapping.

In [10]:
# db = fasta.FastaIndexDbFactory(dataset_root / "silva.fasta.index.db")
# db.write_entries(silva.sequences())
# db.close()

Write the SILVA taxonomies to a taxonomy TSV database.

In [11]:
# db = taxonomy.TaxonomyDbFactory(dataset_root / "silva.tax.tsv.db")
# for entry in tqdm(silva.taxonomies()):
#     entry = replace(entry, label=entry.label + "; s__")
#     db.write_entry(entry)
# db.close()

In [90]:
uncultured_families = set()
cultured_families = set()
for entry in silva.taxonomies():
    taxons = taxonomy.split_taxonomy(entry.label)
    if taxons[-1] == "uncultured":
        uncultured_families.add(taxons[-2])
    else:
        cultured_families.add(taxons[-2])
    taxonomy.join_taxonomy(taxons, depth=7)
    # if "uncultured" in
    # if "__uncultured" in entry.label:
    #     print(entry)

In [95]:
uncultured_families.intersection(cultured_families)

{'Acetobacteraceae',
 'Acidaminococcaceae',
 'Acidiferrobacteraceae',
 'Acidimicrobiaceae',
 'Acidobacteriaceae_(Subgroup_1)',
 'Actinomycetaceae',
 'Aerococcaceae',
 'Alcaligenaceae',
 'Alteromonadaceae',
 'Ammonificaceae',
 'Amoebophilaceae',
 'Anaerolineaceae',
 'Anaerovoracaceae',
 'Anaplasmataceae',
 'Aphelidea',
 'Aquificaceae',
 'Archaeoglobaceae',
 'Arcobacteraceae',
 'Ardenticatenaceae',
 'Arenicellaceae',
 'Arthracanthida',
 'Ascobolaceae',
 'Atopobiaceae',
 'Azospirillaceae',
 'Bacillaceae',
 'Bacillariophyceae',
 'Bacteriovoracaceae',
 'Balneolaceae',
 'Barnesiellaceae',
 'Bdellovibrionaceae',
 'Beggiatoaceae',
 'Beijerinckiaceae',
 'Bernardetiaceae',
 'Bifidobacteriaceae',
 'Blastocatellaceae',
 'Blattabacteriaceae',
 'Brocadiaceae',
 'Burkholderiaceae',
 'Butyricicoccaceae',
 'Caedibacteraceae',
 'Caldilineaceae',
 'Caloramatoraceae',
 'Cardiobacteriaceae',
 'Carnobacteriaceae',
 'Caulobacteraceae',
 'Cellvibrionaceae',
 'Cercomonadidae',
 'Chitinophagaceae',
 'Chlamydiac

In [89]:
(1, 2, 3).find(0)

AttributeError: 'tuple' object has no attribute 'find'

In [12]:
silva_index = fasta.FastaIndexDb(dataset_root / "silva.fasta.index.db")

Create a taxonomy-to-sequenceID mapping.

In [13]:
tax_to_fasta_id = {}
for entry in silva.taxonomies():
    label = entry.label + '; s__'
    if label not in tax_to_fasta_id:
        tax_to_fasta_id[label] = []
    tax_to_fasta_id[label].append(entry.identifier)
len(tax_to_fasta_id)

11070

## Taxonomy

In [67]:
def genus_id_to_tax_label(taxon_id, depth=7):
    taxon = tax_model.hierarchy.id_to_taxon_map[5][taxon_id]
    taxons = []
    while taxon.rank > -1:
        taxons.append(taxon.name)
        taxon = taxon.parent
    return taxonomy.join_taxonomy(taxons[::-1], depth=depth)
genus_to_tax_label_map = [genus_id_to_tax_label(t) for t in range(len(tax_model.hierarchy.taxons[5]))]

In [64]:
def trim_and_encode_sequence(sequence: str, length=150, rng=np.random.default_rng()):
    # Trim the sequence
    offset = rng.integers(0, len(sequence) - length + 1)
    sequence = sequence[offset:offset+length]
    assert len(sequence) == length
    return dna.encode_sequence(sequence)

In [65]:
taxon_counts_by_level = []
for i, taxons in enumerate(tax_model.hierarchy.taxons[:-1]):
    taxon_counts_by_level.append([1])
    for taxon in taxons:
        taxon_counts_by_level[i].append(len(taxon.children))
taxon_counts_by_level[0]

@tf.function()
def sequences_to_taxons(sequences, top_k=1):
    pred = tax_model(sequences)
    probabilities = tf.cast(pred[0], tf.float64)
    for i, taxon_counts in enumerate(taxon_counts_by_level, start=1):
        gate_indices = [j for j, count in enumerate(taxon_counts) for _ in range(count)]
        gate = tf.gather(probabilities, gate_indices, axis=-1)
        probabilities = gate*tf.cast(pred[i], tf.float64)
    probabilities = probabilities[:,1:]
    indices = tf.math.top_k(probabilities, k=1).indices
    return indices

## Nachusa Sequences

In [14]:
sample_folders = [
    "nachusa-2015-soil16S-sequences",
    "nachusa-2016-soil16S-sequences",
    "nachusa-2017-soil16S-sequences",
    "nachusa-2018-soil16S-sequences",
    # "nachusa-2019-soil16S-sequences", # missing
    "nachusa-2020-soil16S-sequences",
]

In [15]:
nachusa_fastqs = []
for sample_folder in sample_folders:
    path = Path("/home/shared/prism-data/Nachusa Sequences") / sample_folder
    nachusa_fastqs += list(path.iterdir())
len(nachusa_fastqs)

210

In [17]:
def trim_and_encode_sequence(sequence: str, length=150, rng=np.random.default_rng()):
    # Trim the sequence
    offset = rng.integers(0, len(sequence) - length + 1)
    sequence = sequence[offset:offset+length]
    assert len(sequence) == length
    return dna.encode_sequence(sequence)

In [33]:
taxon_counts_by_level = []
for i, taxons in enumerate(tax_model.hierarchy.taxons[:-1]):
    taxon_counts_by_level.append([1])
    for taxon in taxons:
        taxon_counts_by_level[i].append(len(taxon.children))
taxon_counts_by_level[0]

@tf.function()
def sequences_to_genus_taxons(sequences, top_k=1):
    pred = tax_model(sequences)
    probabilities = tf.cast(pred[0], tf.float64)
    outputs = []
    for i, taxon_counts in enumerate(taxon_counts_by_level, start=1):
        gate_indices = [j for j, count in enumerate(taxon_counts) for _ in range(count)]
        gate = tf.gather(probabilities, gate_indices, axis=-1)
        probabilities = gate*tf.cast(pred[i], tf.float64)
        normalized_probabilities = probabilities / tf.reduce_sum(probabilities)
        outputs.append(tf.math.top_k(normalized_probabilities, k=top_k).indices)
    # probabilities = probabilities[:,1:]
    indices = tf.math.top_k(probabilities, k=1).indices
    return indices

In [60]:
def genus_id_to_tax_label(taxon_id, depth=7):
    taxon = tax_model.hierarchy.id_to_taxon_map[5][taxon_id]
    taxons = []
    while taxon.rank > -1:
        taxons.append(taxon.name)
        taxon = taxon.parent
    return taxonomy.join_taxonomy(taxons[::-1], depth=depth)
genus_to_tax_label_map = [genus_id_to_tax_label(t) for t in range(len(tax_model.hierarchy.taxons[5]))]

In [61]:
batch_size = 256

t = time.time()
for file in nachusa_fastqs:
    # Fetch and encode the sequences
    sequences = np.array(list(map(trim_and_encode_sequence, map(lambda e: e.sequence, fastq.entries(file)))))
    sequences = dna.replace_ambiguous_encoded_bases(sequences)
    sequences = dna.encode_kmers(sequences, tax_model.base.base.kmer)

    # Classify the sequences
    abundances = defaultdict(int)
    for i in range(0, len(sequences), batch_size):
        batch = sequences[i:i+batch_size]
        pred = sequences_to_genus_taxons(batch)
        for row in pred:
            abundances[genus_to_tax_label_map[row[0]]] += 1
    break
time.time() - t

32.78842353820801

In [63]:
len(abundances)

1262

In [21]:
# result_logits = tax_model.model.predict(sequences, batch_size=256)

In [22]:
result_f64 = [r.astype(np.float64) for r in result]

In [24]:
taxon_counts_by_level = []
for i, taxons in enumerate(tax_model.hierarchy.taxons[:-1]):
    taxon_counts_by_level.append([1])
    for taxon in taxons:
        taxon_counts_by_level[i].append(len(taxon.children))
taxon_counts_by_level[0]

[1, 15, 89, 110]

In [25]:
probabilities = result_f64[0]
for i, taxon_counts in enumerate(taxon_counts_by_level, start=1):
    gate_indices = np.array([j for j, count in enumerate(taxon_counts) for _ in range(count)])
    gate = tf.gather(probabilities, gate_indices, axis=-1)
    probabilities = gate*result_f64[i]
probabilities = probabilities[:,1:]

In [30]:
indices = tf.math.top_k(probabilities, 5).indices
indices

<tf.Tensor: shape=(256, 5), dtype=int32, numpy=
array([[ 284,  286,  282,  280,  292],
       [ 977,  981,  980,  982,  976],
       [1482, 1491, 1432, 1517, 1434],
       ...,
       [ 965,  968,  940,  964,  952],
       [ 606,  605,  617,  596,  620],
       [2772, 2622, 2588, 1482, 1491]], dtype=int32)>

In [None]:
class TaxonomyBayesianClassifier(tf.keras.models.Model):
    def __init__(self, taxonomy_classifier, **kwargs):
        super().__init__(**kwargs)
        self.taxonomy_classifier = taxonomy_classifier

    def call(self, inputs, top_k=1):


In [32]:
for row in indices:
    assert genus_id_to_tax_label(row[0]) in tax_to_fasta_id

In [24]:
probabilities = probabilities.numpy()

In [25]:
batch = probabilities[:256,1:]

In [26]:
np.argpartition(batch, -5, axis=1)

array([[ 4404,  5281, 10561, ...,   282,   286,   284],
       [ 3481, 10562, 10559, ...,   982,   981,   980],
       [ 4634,  5281,     2, ...,  1481,  1482,  1491],
       ...,
       [ 4489, 10562, 10561, ...,   940,   968,   965],
       [ 1524,     0, 10546, ...,   606,   605,   617],
       [ 2936,  5281,     2, ...,  2588,  2622,  2772]])

In [None]:
batch[np.argpartition(batch, -5, axis=0)]

In [None]:
4

In [177]:
batch.shape

(256, 10564)

In [128]:
genus_id_to_tax_label(1)

'k__Archaea; p__Aenigmarchaeota; c__Aenigmarchaeia; o__Aenigmarchaeales; f__Aenigmarchaeales; g__Candidatus_Aenigmarchaeum; s__'

In [174]:
probabilities[:256,1:].numpy()[np.argpartition(probabilities[:256,1:], -5, axis=1)[:,-5:]]

AttributeError: 'numpy.ndarray' object has no attribute 'numpy'

In [36]:
a = np.random.normal(size=(2, 5))
a

array([[-0.40246809, -1.31822226,  1.12585387,  0.53307294,  0.52990556],
       [ 0.81157848,  1.08686566, -1.43370008,  0.11677023, -1.24866921]])

In [39]:
np.take_along_axis(a, np.argpartition(a, 2, axis=1), 1)

array([[-1.31822226, -0.40246809,  0.52990556,  0.53307294,  1.12585387],
       [-1.43370008, -1.24866921,  0.11677023,  0.81157848,  1.08686566]])

In [154]:
np.argsort(probabilities[:256,1:], axis=1)

array([[ 7271,  7273,  7233, ...,   282,   286,   284],
       [ 8795,  8802,  8806, ...,   980,   977,   981],
       [  168,  5610,  5559, ...,  1481,  1491,  1482],
       ...,
       [ 4655,  8803,  8504, ...,   940,   968,   965],
       [ 7572, 10297,   237, ...,   605,   617,   606],
       [ 8249,  8258,   118, ...,  2588,  2622,  2772]])

In [130]:
np.argmax(probabilities[:256,1:], axis=1)

array([ 284,  981, 1482,  336, 4012, 3801,  968, 2446, 2730, 2446, 2962,
       3008, 1482,  322,  533, 2966,  946, 3732, 2730, 2025,  292,  322,
       2469,  777, 1432, 3492, 2966,  322, 1482, 2002,  286,  617, 1345,
       2610, 2730,  784, 2772, 2730, 2806, 3941, 2795, 1540, 3297,  216,
       4391, 3507,  957,  981, 3471,  292, 3492, 4385, 2806, 3763,  286,
       3284, 3002, 2025, 3307,  850,  322, 1404, 3822, 2815, 3541, 4581,
       1285, 3822, 2815, 1734,  292, 3878, 3008, 3149, 1345,  322, 4355,
       2914, 3881, 4433, 3212, 2837, 2806, 2806,  249,   27,  940, 1308,
       4355,   99, 3284, 1482, 2490, 4581, 3192, 1418,  377,  313, 3161,
       3065, 4298, 2986, 2469,  946, 4382, 2730,  956, 3782,  313,  292,
       1054, 1060, 4546, 3833,  322, 1279, 2806, 3492, 3256,  322,  825,
       2987, 2986, 3603, 3522, 7196, 1060, 1824, 1412, 3941, 3492, 4581,
       2730, 3291, 3627, 2535,  322, 2806, 3070,   27, 3941, 2765, 2824,
        313,  944, 2733,  779,  977, 3665, 1377,  9

In [91]:
result_f64[1].shape

(71553, 215)

In [50]:
for i in range(len(result[0])):
    break

71553

In [None]:
[0][]

In [46]:
len(abundances)

13855

In [30]:
taxon_ids = np.array([np.argmax(r, axis=1) for r in result]).T
taxon_ids, taxon_ids.shape

(array([[   1,   18,   40,   80,  126,  284],
        [   1,   27,   77,  158,  302, 1266],
        [   1,   35,  100,  236,  461, 1490],
        ...,
        [   1,   83,  244,  554,  958, 3066],
        [   1,   19,   64,  140,  234,  777],
        [   1,   72,  201,  421,  791, 2772]]),
 (71553, 6))

In [22]:
result[0]

NameError: name 'result' is not defined

In [32]:
entry_factory = sample.SampleMappingEntryFactory(file.stem, silva_index)


In [33]:
for label, abundance in abundances.items():
    fasta_id = rng.choice(tax_to_fasta_id[label])
    entry_factory.add_fasta_id(fasta_id, abundance)
    print(label, abundance)
# tax_model.hierarchy.detokenize(taxon_ids)

k__Bacteria; p__Acidobacteriota; c__Blastocatellia; o__Blastocatellales; f__Blastocatellaceae; g__JGI_0001001-H03; s__ 393


KeyError: 'k__Bacteria; p__Bacteroidota; c__Bacteroidia; o__Chitinophagales; f__Saprospiraceae; g__Lentimicrobium; s__'

In [38]:
tax_to_fasta_id["k__Bacteria; p__Bacteroidota; c__Bacteroidia; o__Chitinophagales; f__Saprospiraceae; g__; s__"]

KeyError: 'k__Bacteria; p__Bacteroidota; c__Bacteroidia; o__Chitinophagales; f__Saprospiraceae; g__; s__'

In [36]:
tax_model.hierarchy.reduce_taxonomy("k__Bacteria; p__Bacteroidota; c__Bacteroidia; o__Chitinophagales; f__Saprospiraceae; g__Lentimicrobium; s__")

'k__Bacteria; p__Bacteroidota; c__Bacteroidia; o__Chitinophagales; f__Saprospiraceae; g__; s__'

In [43]:
[len(t.children) for t in tax_model.hierarchy.id_to_taxon_map[0].values()]

[15, 89, 110]

In [21]:
taxonomy.join_taxonomy(tax_model.hierarchy.detokenize_taxons(np.array([0, 1, 1, 1, 1, 1])))

'k__Archaea; p__Altiarchaeota; c__Aenigmarchaeota; o__Aenigmarchaeota; f__Aenigmarchaeota; g__Candidatus_Aenigmarchaeum; s__'

In [40]:
np.cumsum([len(t.children) for t in tax_model.hierarchy.id_to_taxon_map[0].values()])

array([ 15, 104, 214])

In [42]:
result[1].shape

(71553, 215)

In [33]:
np.argmax(result[0], axis=1)

array([1, 1, 1, ..., 1, 1, 1])

In [27]:
# Reduce the tax label to a valid hierarchy
tax_model.hierarchy.reduce_taxonomy("k__Bacteria; p__Bacteroidota; c__Bacteroidia; o__Chitinophagales; f__Sphingobacteriaceae; g__Sphingobacterium; s__")

'k__Bacteria; p__Bacteroidota; c__Bacteroidia; o__Chitinophagales; f__; g__; s__'

## Hopland Sequences

In [31]:
hopland_fastqs = [
    p for p in Path("/home/shared/hopland/fastq").iterdir()
    if not p.name.startswith("Blank")]
len(hopland_fastqs)

256