In [1]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"

In [2]:
from dataclasses import replace
from dnadb.datasets import Silva
from dnadb import dna, fasta, fastq, sample, taxonomy
import numpy as np
from pathlib import Path
import tf_utilities as tfu
import wandb

from deepdna.nn.models import load_model
from deepdna.nn.models.taxonomy import \
    AbstractTaxonomyClassificationModel, \
    NaiveTaxonomyClassificationModel, \
    BertaxTaxonomyClassificationModel, \
    TopDownTaxonomyClassificationModel

In [3]:
tfu.devices.select_gpu(0)

([PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')],
 [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')])

In [4]:
root = Path("/home/dwl2x/work/Datasets/Synthetic")

## SILVA Sequences and Taxonomies

In [5]:
silva = Silva()
# with fasta.FastaDbFactory(root / f"Silva_{silva.version}.fasta.db") as db:
#     db.write_entries(silva.sequences())
# with taxonomy.TaxonomyDbFactory(root / f"Silva_{silva.version}.tax.tsv.db") as db:
#     for entry in silva.taxonomies():
#         db.write_entry(replace(entry, label=entry.label.replace("uncultured", "")))

In [6]:
# with open(root / f"Silva_{silva.version}.fasta", 'w') as f:
#     for entry in silva.sequences():
#         fasta.write(f, [entry])

In [7]:
# with open(root / f"Silva_{silva.version}.tax.tsv", 'w') as f:
#     for entry in silva.taxonomies():
#         taxonomy.write(f, [replace(entry, label=entry.label.replace("uncultured", ""))])

In [8]:
# with fasta.FastaDb(root / f"Silva_{silva.version}.fasta.db") as fasta_db:
#     with fasta.FastaIndexDbFactory(root / f"Silva_{silva.version}.fasta.index.db") as index:
#         index.write_entries(fasta_db)

In [9]:
silva_fasta = fasta.FastaDb(root / f"Silva_{silva.version}.fasta.db")
silva_index = fasta.FastaIndexDb(root / f"Silva_{silva.version}.fasta.index.db")
silva_tax = taxonomy.TaxonomyDb(root / f"Silva_{silva.version}.tax.tsv.db")

In [10]:
tax_to_fasta_ids = {}
for label in silva_tax:
    tax_to_fasta_ids[label] = list(silva_tax.fasta_ids_with_label(silva_tax.label_to_index(label)))

---

## Creating Synthetic Datasets

In [34]:
uniform_distribution = True

In [11]:
rng = np.random.default_rng()

In [12]:
class FastaEntryWriter:
    def __init__(self, sequence_length: int, prefix: str = ""):
        if len(prefix) > 0:
            prefix += "."
        self.prefix = prefix
        self.sequence_length = sequence_length
        self.count = 0

    def __call__(self, entry: fasta.FastaEntry):
        offset = rng.integers(len(entry.sequence) - self.sequence_length)
        entry = replace(
            entry,
            identifier=f"{self.prefix}{self.count:08d}",
            sequence=entry.sequence[offset:offset + self.sequence_length],
            extra="")
        self.count += 1
        return entry

In [13]:
def create_test_samples(
    samples,
    output_path: str|Path,
    sequence_length: int,
    subsample_size: int,
    n_subsamples: int,
):
    output_path = Path(output_path)
    output_path.mkdir(exist_ok=True)
    for index, s in enumerate(samples):
        print(f"\r{index+1}/{len(samples)}", end="")
        base_name = s.name.replace(".fastq.gz", "").replace(".fastq", "").replace(" ", "_")
        for i in range(n_subsamples):
            name = f"{base_name}.{i+1:03d}"
            fasta_writer = FastaEntryWriter(sequence_length, name)
            out_fasta = open(output_path / (name + ".fasta"), 'w')
            out_tax = open(output_path / (name + ".tax.tsv"), 'w')
            for entry in s.sample(subsample_size):
                label = silva_tax.fasta_id_to_label(entry.identifier)
                entry = fasta_writer(entry)
                fasta.write(out_fasta, (entry,))
                taxonomy.write(out_tax, (taxonomy.TaxonomyEntry(identifier=entry.identifier, label=label),))
            out_fasta.close()
            out_tax.close()

In [14]:
def trim_and_encode(entry: fasta.FastaEntry|fastq.FastqEntry, length: int = 150):
    # Trim the sequence
    offset = rng.integers(0, len(entry) - length + 1)
    sequence = entry.sequence[offset:offset+length]
    assert len(sequence) == length
    return dna.encode_sequence(sequence)

In [15]:
def chunk(iterable, chunk_size):
    result = []
    for item in iterable:
        result.append(item)
        if len(result) == chunk_size:
            yield result
            result = []
    if len(result) > 0:
        yield result

In [16]:
from dataclasses import replace
import re

def clean_entry(entry: fasta.FastaEntry):
    """
    Clean the sequence in the given entry by removing all non-nucleotide-base characters.
    """
    sequence = re.sub(r"[^" + dna.ALL_BASES + r"]", "", entry.sequence)
    return replace(entry, sequence=sequence)

### Classification Model

In [17]:
api = wandb.Api()

In [18]:
model_type = "Topdown"

if model_type == "Naive":
    path = api.artifact("sirdavidludwig/dnabert-taxonomy-naive/dnabert-taxonomy-naive-64d-150l:latest").download()
    model = load_model(path, NaiveTaxonomyClassificationModel)
elif model_type == "Bertax":
    path = api.artifact("sirdavidludwig/dnabert-taxonomy/dnabert-taxonomy-bertax-64d-150l:latest").download()
    model = load_model(path, BertaxTaxonomyClassificationModel)
elif model_type == "Topdown":
    path = api.artifact("sirdavidludwig/dnabert-taxonomy/dnabert-taxonomy-topdown-64d-150l:latest").download()
    model = load_model(path, TopDownTaxonomyClassificationModel)
model

[34m[1mwandb[0m:   4 of 4 files downloaded.  


<deepdna.nn.models.taxonomy.TopDownTaxonomyClassificationModel at 0x7f1f5c2d6c80>

In [19]:
batch_size = 256
kmer = model.base.base.kmer
subsample_size = 1000
n_subsamples = 10
sequence_length = 150

---

### Nachusa (Soil)

In [94]:
dataset_name = "Nachusa"

In [95]:
dataset_root = root / f"{dataset_name}/{model_type}"
dataset_root.mkdir(exist_ok=True, parents=True)

In [96]:
sample_folders = [
    "/home/shared/prism-data/Nachusa Sequences/nachusa-2015-soil16S-sequences",
    "/home/shared/prism-data/Nachusa Sequences/nachusa-2016-soil16S-sequences",
    "/home/shared/prism-data/Nachusa Sequences/nachusa-2017-soil16S-sequences",
    "/home/shared/prism-data/Nachusa Sequences/nachusa-2018-soil16S-sequences",
    # "/home/shared/prism-data/Nachusa Sequences/nachusa-2019-soil16S-sequences", # missing
    "/home/shared/prism-data/Nachusa Sequences/nachusa-2020-soil16S-sequences",
]

In [None]:
fasta_mapping_factory = sample.SampleMappingDbFactory(dataset_root / f"{dataset_name}.fasta.mapping.db")
for folder in sample_folders:
    for f in Path(folder).iterdir():
        print(f"{f.name}\r\n", end="")
        name = f.name
        entries = fastq.entries(f)
        mapping_entry = sample.SampleMappingEntryFactory(name, silva_index)
        for sequences in chunk(map(trim_and_encode, entries), batch_size):
            sequences = dna.encode_kmers(np.array(sequences), kmer)
            for label in model.classify(sequences, batch_size=batch_size, verbose=0):
                fasta_id = rng.choice(tax_to_fasta_ids[label])
                mapping_entry.add_fasta_id(fasta_id)
        fasta_mapping_factory.write_entry(mapping_entry.build())
fasta_mapping_factory.close()

#### Test Subsamples

In [97]:
samples = sample.load_multiplexed_fasta(
    silva_fasta,
    dataset_root / f"{dataset_name}.fasta.mapping.db",
    silva_index
)
len(samples)

210

In [98]:
create_test_samples(
    samples,
    dataset_root / "test_uniform",
    sequence_length,
    subsample_size,
    n_subsamples,
    uniform
)

210/210

---

### Hopland (Soil)

In [20]:
dataset_name = "Hopland"

In [21]:
dataset_root = root / f"{dataset_name}/{model_type}"
dataset_root.mkdir(exist_ok=True, parents=True)

In [22]:
sample_folder = "/home/shared/hopland/fastq"

In [23]:
fasta_mapping_factory = sample.SampleMappingDbFactory(dataset_root / f"{dataset_name}.fasta.mapping.db")
for f in Path(sample_folder).iterdir():
    if f.name.startswith("Blank"):
        continue
    print(f"{f.name}\r\n", end="")
    name = f.name
    entries = fastq.entries(f)
    mapping_entry = sample.SampleMappingEntryFactory(name, silva_index)
    for sequences in chunk(map(trim_and_encode, entries), batch_size):
        sequences = dna.encode_kmers(np.array(sequences), kmer)
        for label in model.classify(sequences, batch_size=batch_size, verbose=0):
            fasta_id = rng.choice(tax_to_fasta_ids[label])
            mapping_entry.add_fasta_id(fasta_id)
    fasta_mapping_factory.write_entry(mapping_entry.build())
fasta_mapping_factory.close()

Ur32-B-16S_S260_L001_R2_001.fastq
Ur48-R-16S_S194_L001_R2_001.fastq
Ur22-R-16S_S175_L001_R1_001.fastq
Ur43-B-16S_S220_L001_R2_001.fastq
Ur15-R-16S_S182_L001_R1_001.fastq
Ur35-B-16S_S219_L001_R2_001.fastq
Ur14-R-16S_S174_L001_R2_001.fastq
Ur25-R-16S_S136_L001_R2_001.fastq
Ur40-R-16S_S193_L001_R2_001.fastq
Ur1-R-16S_S133_L001_R1_001.fastq
Ur35-B-16S_S219_L001_R1_001.fastq
Ur39-B-16S_S253_L001_R2_001.fastq
Ur24-B-16S_S259_L001_R2_001.fastq
Ur8-B-16S_S257_L001_R2_001.fastq
Ur62-B-16S_S248_L001_R2_001.fastq
Ur57-R-16S_S140_L001_R1_001.fastq
Ur55-R-16S_S187_L001_R2_001.fastq
Ur38-B-16S_S245_L001_R2_001.fastq
Ur24-R-16S_S191_L001_R2_001.fastq
Ur4-R-16S_S157_L001_R1_001.fastq
Ur38-B-16S_S245_L001_R1_001.fastq
Ur3-B-16S_S215_L001_R1_001.fastq
Ur13-R-16S_S166_L001_R1_001.fastq
Ur12-B-16S_S225_L001_R1_001.fastq
Ur26-R-16S_S144_L001_R1_001.fastq
Ur21-R-16S_S167_L001_R2_001.fastq
Ur31-R-16S_S184_L001_R1_001.fastq
Ur49-R-16S_S139_L001_R1_001.fastq
Ur19-B-16S_S217_L001_R1_001.fastq
Ur61-R-16S_S172_L0


KeyboardInterrupt



#### Test Subsamples

In [25]:
samples = sample.load_multiplexed_fasta(
    silva_fasta,
    dataset_root / f"{dataset_name}.fasta.mapping.db",
    silva_index
)
len(samples)

128

In [31]:
s = [s for s in samples if s.name == "Ur10-B-16S_S207_L001_R1_001.fastq"]

In [None]:
Ur10-B-16S_S207_L001_R1_001.007.tax.tsv

In [33]:
create_test_samples(
    s,
    dataset_root / "test",
    sequence_length,
    subsample_size,
    n_subsamples
)

1/1

---

### Wetland (Soil)

In [56]:
dataset_name = "Wetland"

In [57]:
dataset_root = root / f"{dataset_name}/{model_type}"
dataset_root.mkdir(exist_ok=True, parents=True)

In [22]:
wetland_fasta_file = "/home/shared/walker_lab/reed/P_A_201201_wet_libs1_8.trim.contigs.pcr.good.unique.good.filter.unique.precluster.pick.pick.agc.0.03.pick.0.03.abund.0.03.pick.fasta.new.fasta"
otu_list_path = "/home/shared/walker_lab/digitalocean/Reed_NRCS/shared_list/201201_wet_libs1_8.trim.contigs.pcr.good.unique.good.filter.unique.precluster.pick.pick.asv.list"
otu_shared_path = "/home/shared/walker_lab/digitalocean/Reed_NRCS/shared_list/201201_wet_libs1_8.trim.contigs.pcr.good.unique.good.filter.unique.precluster.pick.pick.asv.shared"

In [23]:
# FASTA ID to OTU Index
with open(otu_list_path) as f:
    f.readline()
    line = f.readline().strip().split('\t')[2:]
    fasta_id_to_index = {fasta_id: i for i, fasta_id in enumerate(line)}

In [24]:
with taxonomy.TaxonomyDbFactory(f"/tmp/wetland_{model_type}.tax.db") as factory:
    progress = 0
    for entries in chunk(map(clean_entry, fasta.entries(wetland_fasta_file)), batch_size):
        progress += len(entries)
        print(f"\r{progress}", end="")
        sequences = np.array(list(map(trim_and_encode, entries)))
        sequences = dna.encode_kmers(sequences, kmer)
        labels = model.predictions_to_labels(model(sequences))
        for entry, label in zip(entries, labels):
            factory.write_entry(taxonomy.TaxonomyEntry(str(fasta_id_to_index[entry.identifier]), label))
tax_db = taxonomy.TaxonomyDb(f"/tmp/wetland_{model_type}.tax.db")

3351738

In [25]:
fasta_mapping_factory = sample.SampleMappingDbFactory(dataset_root / f"{dataset_name}.fasta.mapping.db")
with open(otu_shared_path) as f:
    f.readline()
    for index, line in enumerate(f):
        _, name, _, *abundances = line.strip().split('\t')
        mapping_entry = sample.SampleMappingEntryFactory(name, silva_index)
        print(f"\r{index+1}: {name:<50}", end="")
        for i, abundance in enumerate(abundances):
            if abundance == '0' or not tax_db.contains_fasta_id(str(i)):
                continue
            label = tax_db.fasta_id_to_label(str(i))
            for fasta_id in rng.choice(tax_to_fasta_ids[label], int(abundance), replace=True):
                mapping_entry.add_fasta_id(fasta_id)
        fasta_mapping_factory.write_entry(mapping_entry.build())
fasta_mapping_factory.close()

768: wetcc4                                            

#### Test Subsamples

In [58]:
samples = sample.load_multiplexed_fasta(
    silva_fasta,
    dataset_root / f"{dataset_name}.fasta.mapping.db",
    silva_index
)
len(samples)

768

In [59]:
create_test_samples(
    samples,
    dataset_root / "test",
    sequence_length,
    subsample_size,
    n_subsamples
)

768/768

---

### Snake Fungal Disease (Gut)

In [60]:
dataset_name = "SFD"

In [61]:
dataset_root = root / f"{dataset_name}/{model_type}"
dataset_root.mkdir(exist_ok=True, parents=True)

In [24]:
sfd_fasta_file = "/home/shared/walker_lab/alex/P_A_221205_cmfp.trim.contigs.pcr.good.unique.good.filter.unique.precluster.denovo.vsearch.pick.opti_mcc.0.03.pick.0.03.abund.0.03.pick.fasta"
otu_list_path = "/home/shared/walker_lab/digitalocean/Alex_SFD/shared_list/221205_cmfp.trim.contigs.pcr.good.unique.good.filter.unique.precluster.denovo.vsearch.asv.list"
otu_shared_path = "/home/shared/walker_lab/digitalocean/Alex_SFD/shared_list/221205_cmfp.trim.contigs.pcr.good.unique.good.filter.unique.precluster.denovo.vsearch.asv.shared"

In [25]:
# FASTA ID to OTU Index
with open(otu_list_path) as f:
    f.readline()
    line = f.readline().strip().split('\t')[2:]
    fasta_id_to_index = {fasta_id: i for i, fasta_id in enumerate(line)}

In [26]:
with taxonomy.TaxonomyDbFactory(f"/tmp/sfd_{model_type}.tax.db") as factory:
    for entries in chunk(map(clean_entry, fasta.entries(sfd_fasta_file)), batch_size):
        sequences = np.array(list(map(trim_and_encode, entries)))
        sequences = dna.encode_kmers(sequences, kmer)
        labels = model.predictions_to_labels(model(sequences))
        for entry, label in zip(entries, labels):
            factory.write_entry(taxonomy.TaxonomyEntry(str(fasta_id_to_index[entry.identifier]), label))
tax_db = taxonomy.TaxonomyDb(f"/tmp/sfd_{model_type}.tax.db")

In [27]:
fasta_mapping_factory = sample.SampleMappingDbFactory(dataset_root / f"{dataset_name}.fasta.mapping.db")
with open(otu_shared_path) as f:
    f.readline()
    for index, line in enumerate(f):
        _, name, _, *abundances = line.strip().split('\t')
        mapping_entry = sample.SampleMappingEntryFactory(name, silva_index)
        print(f"\r{index+1}: {name:<50}", end="")
        for i, abundance in enumerate(abundances):
            if abundance == '0' or not tax_db.contains_fasta_id(str(i)):
                continue
            label = tax_db.fasta_id_to_label(str(i))
            for fasta_id in rng.choice(tax_to_fasta_ids[label], int(abundance), replace=True):
                mapping_entry.add_fasta_id(fasta_id)
        fasta_mapping_factory.write_entry(mapping_entry.build())
fasta_mapping_factory.close()

887: TST1                                              

#### Test Subsamples

In [62]:
samples = sample.load_multiplexed_fasta(
    silva_fasta,
    dataset_root / f"{dataset_name}.fasta.mapping.db",
    silva_index
)
len(samples)

887

In [63]:
create_test_samples(
    samples,
    dataset_root / "test",
    sequence_length,
    subsample_size,
    n_subsamples
)

887/887