# Taxonomic Classification Demo
## Featuring: DNABERT and SetBERT
Source repository: https://github.com/Phillips-Lab-MTSU/DL-Workshop.git

### Preliminary Package Installation

You will probably need to restart the notebook kernel after running the next cell for the first time in a session. It doesn't hurt to run the cell again after that.

In [None]:
! git clone -b taxonomy-demo https://github.com/jlphillipsphd/deep-dna.git
! cd deep-dna && git checkout taxonomy-demo && pip install -e .

### Let's grab some FASTQ soil samples...

In [None]:
! curl --remote-name https://data.phillips-lab.org/fastq/Wesley010-FC-042318_S10_L001_R1_001.fastq
! curl --remote-name https://data.phillips-lab.org/fastq/Wesley011-FC-071818_S11_L001_R1_001.fastq
! curl --remote-name https://data.phillips-lab.org/fastq/Wesley012-FC-100818_S12_L001_R1_001.fastq

Take a look at one of the files here...

In [None]:
# Peek at the start of one download (head prints the first 10 lines by default).
! head Wesley010-FC-042318_S10_L001_R1_001.fastq

### Prep Tools and Data

In [None]:
from dnadb import fasta, taxonomy
import deepctx as dcs
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import wandb

from deepdna.nn import data_generators as dg
from deepdna.nn.models import load_model

In [None]:
# Import forward reads from FASTQs.
# Every file matching Wesley*.fastq in the working directory is passed to
# `dnadb fasta import-multiplexed`; the two --output-* flags name the sequence
# database and the mapping database that the next cell opens.
!dnadb fasta import-multiplexed \
    --output-sequences-path ./nachusa.fasta.db \
    --output-mapping-path ./nachusa.fasta.mapping.db \
    Wesley*.fastq

In [None]:
# Load samples: open the sequence database written by the import step, then
# attach the mapping database through it.
# NOTE(review): presumably one mapping per imported FASTQ file — confirm.
nachusa_sequences = fasta.FastaDb("./nachusa.fasta.db")
nachusa_samples = nachusa_sequences.mappings("./nachusa.fasta.mapping.db")

In [None]:
# Report len() of the sequence database, formatted with thousands separators.
print(f"Total sequences: {len(nachusa_sequences):,}")

In [None]:
# Number of samples loaded (len() of the object returned by mappings()).
len(nachusa_samples)

### Load the DNABERT and SetBERT Models

In [None]:
# Authenticate with Weights & Biases; the models below are fetched as W&B artifacts.
wandb.login()

In [None]:
# Only if needed (e.g. to switch accounts or replace a stored API key): force a new login prompt.
# wandb.login(relogin=True)

In [None]:
# Client for the W&B public API, used below to look up model artifacts.
api = wandb.Api()

In [None]:
# Download version v0 of the DNABERT taxonomy artifact and load the model from
# the local path that download() returns.
# NOTE(review): this artifact lives under the `dnabert-taxonomy` project while
# the SetBERT artifact is pulled from `model-registry` — confirm both are intended.
path = api.artifact("sirdavidludwig/dnabert-taxonomy/dnabert-taxonomy-topdown-64d-150l:v0").download()
dnabert_tax_model = load_model(path)

In [None]:
# Download version v0 of the SetBERT taxonomy artifact and load it from the
# local path that download() returns.
# `path` is rebound here, so it no longer refers to the DNABERT download.
path = api.artifact("sirdavidludwig/model-registry/setbert-taxonomy-topdown-64d-150l:v0").download()
setbert_tax_model = load_model(path)
setbert_tax_model.base.chunk_size = 256 # sequence encoding chunk size

### Prep Data for Model Analysis

In [None]:
# Build a generator that yields one batch per epoch with batch_size=3; each
# pipeline stage below feeds the next.
# NOTE(review): no random seed is set in this notebook, so the random draws
# presumably differ between runs — confirm whether dg offers a seed/rng option.
nachusa = dg.BatchGenerator(batch_size=3, batches_per_epoch=1, pipeline=[
    dg.random_samples(nachusa_samples), # The samples to choose from (uniformly)
    dg.random_sequence_entries(1000),   # Sample random FASTA entries from chosen samples
    dg.sequences(150),                  # Get the sequences from the FASTA entries and trim to length
    dg.augment_ambiguous_bases(),       # Augment any ambiguous bases present in the sequence
    dg.encode_sequences(),              # Encode to integers
    dg.encode_kmers(3),                 # Encode kmer integers
    # Final stage: pick what each batch returns (the chosen samples and their encoded k-mer sequences)
    lambda samples, encoded_kmer_sequences: (samples, encoded_kmer_sequences)
])

In [None]:
# Take batch 0, the only one per epoch since batches_per_epoch=1, and unpack
# the two values returned by the pipeline's final stage.
samples, encoded_kmer_sequences = nachusa[0]

In [None]:
# NOTE(review): presumably (samples, sequences per sample, k-mers per sequence) — confirm against the output.
encoded_kmer_sequences.shape

In [None]:
# Number of samples in the batch; presumably equal to batch_size=3 — confirm.
len(samples)

In [None]:
# Print the name of each sample that was drawn into this batch, one per line.
for sample in samples:
    print(sample.name)

### DNABERT Predictions (Single-Sequence)

In [None]:
# Collapse every leading axis into one while keeping the last axis, then show
# the resulting shape (this is the layout handed to DNABERT in the next cell).
last_axis = encoded_kmer_sequences.shape[-1]
encoded_kmer_sequences.reshape((-1, last_axis)).shape

In [None]:
# Flatten the batch to a 2-D array (all leading axes merged, last axis kept)
# and classify it with the DNABERT taxonomy model.
flat_kmer_sequences = encoded_kmer_sequences.reshape((-1, encoded_kmer_sequences.shape[-1]))
dnabert_predictions = dnabert_tax_model.classify(flat_kmer_sequences)

In [None]:
# Preview the first ten entries of the DNABERT output rather than dumping all of it.
dnabert_predictions[:10]

In [None]:
# SetBERT receives the batch unflattened, unlike the DNABERT call above.
setbert_predictions = setbert_tax_model.classify(encoded_kmer_sequences)

In [None]:
# First ten predictions for index 0 along the leading axis (presumably the first sample — confirm).
setbert_predictions[0,:10]

### DNABERT (Top-Down Hierarchy) Architecture

In [None]:
# Diagram of the top-level DNABERT taxonomy model, with nested sub-models
# expanded and tensor shapes annotated.
tf.keras.utils.plot_model(
    dnabert_tax_model,
    show_shapes=True,
    expand_nested=True,
)

In [None]:
# One level down: the model object held in the taxonomy model's `.model` attribute.
tf.keras.utils.plot_model(
    dnabert_tax_model.model,
    show_shapes=True,
    expand_nested=True,
)

In [None]:
# Two levels down: the `.model` wrapped by layer index 1 of the inner model.
tf.keras.utils.plot_model(
    dnabert_tax_model.model.layers[1].model,
    show_shapes=True,
    expand_nested=True,
)

In [None]:
# Three levels down: the `.model` wrapped by layer index 2 of the previous diagram's model.
tf.keras.utils.plot_model(
    dnabert_tax_model.model.layers[1].model.layers[2].model,
    show_shapes=True,
    expand_nested=True,
)

### SetBERT Architecture

In [None]:
# Diagram of the top-level SetBERT taxonomy model, with nested sub-models
# expanded and tensor shapes annotated.
tf.keras.utils.plot_model(
    setbert_tax_model,
    show_shapes=True,
    expand_nested=True,
)

In [None]:
# One level down: the model object held in the taxonomy model's `.model` attribute.
tf.keras.utils.plot_model(
    setbert_tax_model.model,
    show_shapes=True,
    expand_nested=True,
)

In [None]:
# Two levels down: the `.model` wrapped by layer index 1 of the inner model.
tf.keras.utils.plot_model(
    setbert_tax_model.model.layers[1].model,
    show_shapes=True,
    expand_nested=True,
)

In [None]:
# Three levels down: the `.model` wrapped by layer index 2 of the previous diagram's model.
tf.keras.utils.plot_model(
    setbert_tax_model.model.layers[1].model.layers[2].model,
    show_shapes=True,
    expand_nested=True,
)

In [None]:
# Four levels down: the `.model` wrapped by the last layer of the previous diagram's model.
tf.keras.utils.plot_model(
    setbert_tax_model.model.layers[1].model.layers[2].model.layers[-1].model,
    show_shapes=True,
    expand_nested=True,
)