In [1]:
import os
# Set TF_CPP_MIN_LOG_LEVEL before the model libraries are imported
# (presumably to quiet TensorFlow's C++ logging — confirm the intended level).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"

In [44]:
from dataclasses import replace
from dnadb.datasets import Silva
from dnadb import dna, fasta, fastq, sample, taxonomy
import numpy as np
from pathlib import Path
import wandb

from deepdna.nn.models import load_model
from deepdna.nn.models.taxonomy import \
    AbstractTaxonomyClassificationModel, \
    NaiveTaxonomyClassificationModel, \
    BertaxTaxonomyClassificationModel, \
    TopDownTaxonomyClassificationModel

In [14]:
# Base directory for the Silva databases and for the mapping/subsample files written below.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
root = Path("/home/dwl2x/work/Datasets/Synthetic")

## SILVA Sequences and Taxonomies

In [15]:
# silva = Silva()
# with fasta.FastaDbFactory(root / f"Silva_{silva.version}.fasta.db") as db:
#     db.write_entries(silva.sequences())
# with taxonomy.TaxonomyDbFactory(root / f"Silva_{silva.version}.tax.tsv.db") as db:
#     for entry in silva.taxonomies():
#         db.write_entry(replace(entry, label=entry.label.replace("uncultured", "")))

In [16]:
# with open(root / f"Silva_{silva.version}.fasta", 'w') as f:
#     for entry in silva.sequences():
#         fasta.write(f, [entry])

In [17]:
# with open(root / f"Silva_{silva.version}.tax.tsv", 'w') as f:
#     for entry in silva.taxonomies():
#         taxonomy.write(f, [replace(entry, label=entry.label.replace("uncultured", ""))])

In [19]:
# with fasta.FastaDb(root / f"Silva_{silva.version}.fasta.db") as fasta_db:
#     with fasta.FastaIndexDbFactory(root / f"Silva_{silva.version}.fasta.index.db") as index:
#         index.write_entries(fasta_db)

In [57]:
# `silva` is only assigned inside the commented-out database-building cells above,
# so create it here; otherwise this cell raises NameError on a fresh kernel.
silva = Silva()

# Open the pre-built Silva databases: sequences, sequence index and taxonomy.
silva_fasta = fasta.FastaDb(root / f"Silva_{silva.version}.fasta.db")
silva_index = fasta.FastaIndexDb(root / f"Silva_{silva.version}.fasta.index.db")
tax_db = taxonomy.TaxonomyDb(root / f"Silva_{silva.version}.tax.tsv.db")

In [58]:
# For every label in the taxonomy database, collect the FASTA ids that carry it.
tax_to_fasta_ids = {
    tax_label: list(tax_db.fasta_ids_with_label(tax_db.label_to_index(tax_label)))
    for tax_label in tax_db
}

---

## Creating Synthetic Datasets

In [39]:
# Fixed seed so the random trimming offsets and reference-sequence choices below
# are reproducible across Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)

In [40]:
def trim_and_encode(entry: fasta.FastaEntry|fastq.FastqEntry, length: int = 150):
    """
    Cut a random window of `length` characters out of the entry's sequence and encode it.

    The window start is drawn with the notebook-level `rng` from every valid
    offset, and the trimmed string is passed to `dna.encode_sequence`.

    Args:
        entry: The FASTA/FASTQ entry whose `sequence` is trimmed.
        length: Number of characters to keep.

    Returns:
        Whatever `dna.encode_sequence` returns for the trimmed sequence.

    Raises:
        ValueError: If the sequence is shorter than `length`.
    """
    # Measure the same string that is sliced below.
    n_bases = len(entry.sequence)
    if n_bases < length:
        # Without this guard numpy raises an opaque "low >= high" ValueError.
        raise ValueError(f"Cannot trim a sequence of {n_bases} bases to length {length}.")
    # `high` is exclusive, so the +1 keeps the last valid offset reachable.
    offset = rng.integers(0, n_bases - length + 1)
    return dna.encode_sequence(entry.sequence[offset:offset + length])

In [42]:
def chunk(iterable, chunk_size):
    """
    Yield consecutive lists of `chunk_size` items taken from `iterable`.

    The final list holds whatever items remain and may be shorter; nothing is
    yielded for an empty iterable.
    """
    batch = []
    for element in iterable:
        batch.append(element)
        if len(batch) != chunk_size:
            continue
        # A full batch: hand it over and start collecting the next one.
        yield batch
        batch = []
    if batch:
        yield batch

In [77]:
from dataclasses import replace
import re

def clean_entry(entry: fasta.FastaEntry):
    """
    Return a copy of the entry whose sequence keeps only characters found in `dna.ALL_BASES`.

    Args:
        entry: The FASTA entry to clean; it is not modified.

    Returns:
        A new entry (via `dataclasses.replace`) with the filtered sequence.
    """
    # Escape the alphabet so every character is matched literally inside the
    # character class, whatever symbols `dna.ALL_BASES` contains.
    non_base = re.compile(f"[^{re.escape(dna.ALL_BASES)}]")
    return replace(entry, sequence=non_base.sub("", entry.sequence))

### Classification Model

In [26]:
# Weights & Biases API client; used below to download the model artifact.
api = wandb.Api()

In [27]:
# Which pretrained taxonomy classifier to load: "naive", "bertax" or "topdown".
# Kept separate from `model` so that name always refers to the loaded model object.
model_name = "topdown"

# Artifact path and model class for each supported classifier.
model_sources = {
    "naive": (
        "sirdavidludwig/dnabert-taxonomy-naive/dnabert-taxonomy-naive-64d-150l:latest",
        NaiveTaxonomyClassificationModel,
    ),
    "bertax": (
        "sirdavidludwig/dnabert-taxonomy/dnabert-taxonomy-bertax-64d-150l:latest",
        BertaxTaxonomyClassificationModel,
    ),
    "topdown": (
        "sirdavidludwig/dnabert-taxonomy/dnabert-taxonomy-topdown-64d-150l:latest",
        TopDownTaxonomyClassificationModel,
    ),
}

# Fail loudly on a typo instead of leaving `model` bound to a plain string.
if model_name not in model_sources:
    raise ValueError(f"Unknown model {model_name!r}; expected one of {sorted(model_sources)}")

artifact_name, model_class = model_sources[model_name]
path = api.artifact(artifact_name).download()
model = load_model(path, model_class)
model

[34m[1mwandb[0m:   4 of 4 files downloaded.  


<deepdna.nn.models.taxonomy.TopDownTaxonomyClassificationModel at 0x7fb5585d4b20>

In [52]:
# Number of sequences grouped into each batch handed to the classifier.
batch_size = 256
# k-mer length read from the loaded model; passed to dna.encode_kmers below.
kmer = model.base.base.kmer

---

### Nachusa (Soil)

In [66]:
# Folders holding the Nachusa soil 16S sequence files, one folder per year.
sample_folders = [
    "/home/shared/prism-data/Nachusa Sequences/nachusa-2015-soil16S-sequences",
    "/home/shared/prism-data/Nachusa Sequences/nachusa-2016-soil16S-sequences",
    "/home/shared/prism-data/Nachusa Sequences/nachusa-2017-soil16S-sequences",
    "/home/shared/prism-data/Nachusa Sequences/nachusa-2018-soil16S-sequences",
    # "/home/shared/prism-data/Nachusa Sequences/nachusa-2019-soil16S-sequences", # missing
    "/home/shared/prism-data/Nachusa Sequences/nachusa-2020-soil16S-sequences",
]

In [68]:
# Classify every read of every Nachusa file and record, per file, one randomly
# chosen Silva FASTA id for each predicted taxonomy label.
fasta_mapping_factory = sample.SampleMappingDbFactory(root / "Nachusa.fasta.mapping.db")
try:
    for folder in sample_folders:
        for f in Path(folder).iterdir():
            print(f"{f.name}\r\n", end="")
            name = f.name
            entries = fastq.entries(f)
            # `silva_index` is the FASTA index opened together with the Silva databases;
            # the previously referenced `fasta_index` is never defined in this notebook.
            mapping_entry = sample.SampleMappingEntryFactory(name, silva_index)
            for sequences in chunk(map(trim_and_encode, entries), batch_size):
                sequences = dna.encode_kmers(np.array(sequences), kmer)
                for label in model.classify(sequences, batch_size=batch_size, verbose=0):
                    fasta_id = rng.choice(tax_to_fasta_ids[label])
                    mapping_entry.add_fasta_id(fasta_id)
            fasta_mapping_factory.write_entry(mapping_entry.build())
finally:
    # Close the database even when the run is interrupted or fails part-way.
    fasta_mapping_factory.close()

WS-AG-May2015_S65_L001_R1_001.fastq


KeyboardInterrupt: 

---

### Hopland (Soil)

In [75]:
# Folder whose files are iterated as the Hopland samples below.
sample_folder = "/home/shared/hopland/fastq"

In [74]:
# Classify every read of every Hopland file (skipping files whose name starts
# with "Blank") and record one randomly chosen Silva FASTA id per predicted label.
fasta_mapping_factory = sample.SampleMappingDbFactory(root / "Hopland.fasta.mapping.db")
try:
    for f in Path(sample_folder).iterdir():
        if f.name.startswith("Blank"):
            continue
        print(f"{f.name}\r\n", end="")
        name = f.name
        entries = fastq.entries(f)
        # `silva_index` is the FASTA index opened together with the Silva databases;
        # the previously referenced `fasta_index` is never defined in this notebook.
        mapping_entry = sample.SampleMappingEntryFactory(name, silva_index)
        for sequences in chunk(map(trim_and_encode, entries), batch_size):
            sequences = dna.encode_kmers(np.array(sequences), kmer)
            for label in model.classify(sequences, batch_size=batch_size, verbose=0):
                fasta_id = rng.choice(tax_to_fasta_ids[label])
                mapping_entry.add_fasta_id(fasta_id)
        fasta_mapping_factory.write_entry(mapping_entry.build())
finally:
    # Close the database even when the run is interrupted or fails part-way.
    fasta_mapping_factory.close()

Ur32-B-16S_S260_L001_R2_001.fastq


KeyboardInterrupt: 

### Wetland (Soil)

In [149]:
# Input files for the wetland dataset: the FASTA file that is classified below,
# plus the ".list" and ".shared" tables read in the next cells.
wetland_fasta_file = "/home/shared/walker_lab/reed/P_A_201201_wet_libs1_8.trim.contigs.pcr.good.unique.good.filter.unique.precluster.pick.pick.agc.0.03.pick.0.03.abund.0.03.pick.fasta.new.fasta"
otu_list_path = "/home/shared/walker_lab/digitalocean/Reed_NRCS/shared_list/201201_wet_libs1_8.trim.contigs.pcr.good.unique.good.filter.unique.precluster.pick.pick.asv.list"
otu_shared_path = "/home/shared/walker_lab/digitalocean/Reed_NRCS/shared_list/201201_wet_libs1_8.trim.contigs.pcr.good.unique.good.filter.unique.precluster.pick.pick.asv.shared"

In [150]:
# ASV abundance table: the first line is the header, each later line is one sample.
with open(otu_shared_path) as shared_file:
    header = shared_file.readline().strip().split('\t')
    rows = []
    for line in shared_file:
        rows.append(line.strip().split('\t'))
len(rows) # number of samples

768

In [142]:
# Column position of each header name in the abundance table (the first three columns are skipped).
otu_to_index = {otu: column for column, otu in enumerate(header) if column >= 3}

In [143]:
# FASTA ID -> abundance-table column. The first line of the list file supplies the
# names looked up in `otu_to_index`; the second line supplies the FASTA ids.
with open(otu_list_path) as list_file:
    otu_names = list_file.readline().strip().split('\t')
    fasta_ids = list_file.readline().strip().split('\t')
otu_index = {
    fasta_id: otu_to_index[otu]
    for fasta_id, otu in zip(fasta_ids[2:], otu_names[2:])
}

In [144]:
# Sanity check: show one (FASTA id, table column) pair from the mapping.
next(iter(otu_index.items()))

('M03064_50_000000000-CY83Y_1_1109_29403_15697', 3)

In [151]:
# Classify every wetland sequence, pick a random Silva FASTA id for its predicted
# label, and add that id to each sample with the abundance from the shared table.
fasta_mapping_factory = sample.SampleMappingDbFactory(root / "Wetland.fasta.mapping.db")
try:
    # One entry factory per table row, named by the row's second column.
    # `silva_index` is the FASTA index opened together with the Silva databases;
    # the previously referenced `fasta_index` is never defined in this notebook.
    mapping_entries = [sample.SampleMappingEntryFactory(row[1], silva_index) for row in rows]
    progress = 0
    for entries in chunk(map(clean_entry, fasta.entries(wetland_fasta_file)), batch_size):
        progress += len(entries)
        print(f"\r{progress}", end="")
        sequences = np.array(list(map(trim_and_encode, entries)))
        sequences = dna.encode_kmers(sequences, kmer)
        labels = model.predictions_to_labels(model(sequences))
        for entry, label in zip(entries, labels):
            fasta_id = rng.choice(tax_to_fasta_ids[label])
            # The table column depends only on the sequence, so look it up once per entry.
            column = otu_index[entry.identifier]
            for row, mapping_entry in zip(rows, mapping_entries):
                abundance = int(row[column])
                if abundance > 0:
                    mapping_entry.add_fasta_id(fasta_id, abundance=abundance)
    # Build the entries before writing, matching the per-sample loops that call
    # `write_entry(mapping_entry.build())`.
    fasta_mapping_factory.write_entries(mapping_entry.build() for mapping_entry in mapping_entries)
finally:
    # Close the database even when the run is interrupted or fails part-way.
    fasta_mapping_factory.close()

9728

KeyboardInterrupt: 

### Snake Fungal Disease (Gut)

In [76]:
# Input files for the snake fungal disease dataset: the FASTA file that is classified
# below, plus the ".list" and ".shared" tables read in the next cells.
sfd_fasta_file = "/home/shared/walker_lab/alex/P_A_221205_cmfp.trim.contigs.pcr.good.unique.good.filter.unique.precluster.denovo.vsearch.pick.opti_mcc.0.03.pick.0.03.abund.0.03.pick.fasta"
otu_list_path = "/home/shared/walker_lab/digitalocean/Alex_SFD/shared_list/221205_cmfp.trim.contigs.pcr.good.unique.good.filter.unique.precluster.denovo.vsearch.asv.list"
otu_shared_path = "/home/shared/walker_lab/digitalocean/Alex_SFD/shared_list/221205_cmfp.trim.contigs.pcr.good.unique.good.filter.unique.precluster.denovo.vsearch.asv.shared"

In [107]:
# ASV abundance table: the first line is the header, each later line is one sample.
with open(otu_shared_path) as shared_file:
    header = shared_file.readline().strip().split('\t')
    rows = []
    for line in shared_file:
        rows.append(line.strip().split('\t'))
len(rows) # number of samples

887

In [111]:
# Column position of each header name in the abundance table (the first three columns are skipped).
otu_to_index = {otu: column for column, otu in enumerate(header) if column >= 3}

In [112]:
# FASTA ID -> abundance-table column. The first line of the list file supplies the
# names looked up in `otu_to_index`; the second line supplies the FASTA ids.
with open(otu_list_path) as list_file:
    otu_names = list_file.readline().strip().split('\t')
    fasta_ids = list_file.readline().strip().split('\t')
otu_index = {
    fasta_id: otu_to_index[otu]
    for fasta_id, otu in zip(fasta_ids[2:], otu_names[2:])
}

In [114]:
# Sanity check: show one (FASTA id, table column) pair from the mapping.
next(iter(otu_index.items()))

('M03064_63_000000000-JHTY5_1_1110_16981_26249', 3)

In [139]:
# Classify every SFD sequence, pick a random Silva FASTA id for its predicted
# label, and add that id to each sample with the abundance from the shared table.
fasta_mapping_factory = sample.SampleMappingDbFactory(root / "SFD.fasta.mapping.db")
try:
    # One entry factory per table row, named by the row's second column.
    # `silva_index` is the FASTA index opened together with the Silva databases;
    # the previously referenced `fasta_index` is never defined in this notebook.
    mapping_entries = [sample.SampleMappingEntryFactory(row[1], silva_index) for row in rows]
    progress = 0
    for entries in chunk(map(clean_entry, fasta.entries(sfd_fasta_file)), batch_size):
        progress += len(entries)
        print(f"\r{progress}", end="")
        sequences = np.array(list(map(trim_and_encode, entries)))
        sequences = dna.encode_kmers(sequences, kmer)
        labels = model.predictions_to_labels(model(sequences))
        for entry, label in zip(entries, labels):
            fasta_id = rng.choice(tax_to_fasta_ids[label])
            # The table column depends only on the sequence, so look it up once per entry.
            column = otu_index[entry.identifier]
            for row, mapping_entry in zip(rows, mapping_entries):
                abundance = int(row[column])
                if abundance > 0:
                    mapping_entry.add_fasta_id(fasta_id, abundance=abundance)
    # Build the entries before writing, matching the per-sample loops that call
    # `write_entry(mapping_entry.build())`.
    fasta_mapping_factory.write_entries(mapping_entry.build() for mapping_entry in mapping_entries)
finally:
    # Close the database even when the run is interrupted or fails part-way.
    fasta_mapping_factory.close()

2304

KeyboardInterrupt: 

---

## Generating Synthetic Subsamples for Testing

In [33]:
# Number of entries drawn from a sample for each subsample.
subsample_size = 1000
# Number of subsamples written per sample.
n_subsamples = 10
# Length every written sequence is trimmed to by FastaEntryWriter.
sequence_length = 150

In [34]:
# Fixed seed so the trimming offsets used when writing the test subsamples
# are reproducible across Restart & Run All.
SUBSAMPLE_SEED = 42
rng = np.random.default_rng(SUBSAMPLE_SEED)

In [119]:
class FastaEntryWriter:
    """
    Callable that turns a FASTA entry into a trimmed, renumbered copy.

    Each call cuts a random window of `sequence_length` characters (notebook-level
    constant) out of the entry's sequence, assigns the identifier
    "<prefix>.<count>" with an 8-digit zero-padded counter (just the counter
    when no prefix is given), clears `extra`, and increments the counter.
    """

    def __init__(self, prefix: str = ""):
        """
        Args:
            prefix: Text placed before the counter, separated from it by a dot.
        """
        if len(prefix) > 0:
            prefix += "."
        self.prefix = prefix
        self.count = 0

    def __call__(self, entry: fasta.FastaEntry):
        """
        Return a trimmed copy of `entry` carrying the next sequential identifier.

        Raises:
            ValueError: From `rng.integers` if the sequence is shorter than
                `sequence_length`.
        """
        # `rng.integers` excludes its upper bound, so add 1: this keeps the last
        # valid window reachable and allows sequences of exactly `sequence_length`.
        offset = rng.integers(len(entry.sequence) - sequence_length + 1)
        entry = replace(
            entry,
            identifier=f"{self.prefix}{self.count:08d}",
            sequence=entry.sequence[offset:offset + sequence_length],
            extra="")
        self.count += 1
        return entry

### Silva

In [131]:
# Silva sequence database passed to sample.load_multiplexed_fasta below.
# NOTE(review): `silva` is only assigned in commented-out cells earlier in the notebook,
# and this path adds a "Silva/" subfolder that the earlier Silva paths do not use — confirm both.
sequences_db = fasta.FastaDb(root / f"Silva/Silva_{silva.version}.fasta.db")

In [132]:
# Silva taxonomy database used below to look up the label of each sampled FASTA id.
# NOTE(review): `silva` is only assigned in commented-out cells earlier in the notebook — confirm it is defined.
taxonomy_db = taxonomy.TaxonomyDb(root / f"Silva/Silva_{silva.version}.tax.tsv.db")

### Nachusa

In [133]:
# Load the samples from the mapping and index databases on top of the Silva sequence database.
mapping_db_path = root / "synthetic/synthetic.fasta.mapping.db"
index_db_path = root / "synthetic/synthetic.fasta.index.db"
samples = sample.load_multiplexed_fasta(sequences_db, mapping_db_path, index_db_path)

In [134]:
# Number of samples loaded.
len(samples)

210

In [153]:
# For every sample, write `n_subsamples` pairs of files: a FASTA file of trimmed,
# renumbered sequences and a taxonomy TSV giving the label of each written sequence.
output_dir = root / "Synthetic/Nachusa/test"
# Make sure the destination exists so the first open() cannot fail on a fresh machine.
output_dir.mkdir(parents=True, exist_ok=True)

for index, s in enumerate(samples):
    print(f"\r{index+1}/{len(samples)}", end="")
    base_name = s.name.replace(".fastq.gz", "").replace(".fastq", "").replace(" ", "_")
    for i in range(n_subsamples):
        name = f"{base_name}.{i+1:03d}"
        fasta_writer = FastaEntryWriter(name)
        # `with` closes both files even if sampling or a lookup raises part-way.
        with open(output_dir / (name + ".fasta"), 'w') as out_fasta, \
             open(output_dir / (name + ".tax.tsv"), 'w') as out_tax:
            for entry in s.sample(subsample_size):
                # Look up the label with the original id before the writer renames the entry.
                label = taxonomy_db.fasta_id_to_label(entry.identifier)
                entry = fasta_writer(entry)
                fasta.write(out_fasta, (entry,))
                taxonomy.write(out_tax, (taxonomy.TaxonomyEntry(identifier=entry.identifier, label=label),))

210/210

In [70]:
# Spot check: look up the taxonomy label stored for one FASTA id.
taxonomy_db.fasta_id_to_label("HQ119724.1.1499")

'k__Bacteria; p__Gemmatimonadota; c__Gemmatimonadetes; o__Gemmatimonadales; f__Gemmatimonadaceae; g__'