# SetBERT Demo

This notebook demonstrates how to pull a pretrained SetBERT model from W&B and fine-tune it on a specific downstream task.

## Dependencies

In [None]:
!pip3 install dnadb tqdm tf-settransformer tf-utilities wandb

## Configuration

In [1]:
# Disable TensorFlow info and warning logging
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'

# Add the deepdna source code to the path
import sys
sys.path.append("../src")

In [2]:
from dnadb import dna, fasta
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import tensorflow as tf
import tf_utilities as tfu
import time
from tqdm.auto import tqdm
import wandb

## Device and Path Configuration

In [3]:
tfu.devices.select_gpu(0)

([PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')],
 [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')])

In [4]:
# User's home directory
HOME_PATH = Path('~').expanduser()

## Building and Loading the Dataset

In [5]:
# Our Dataset class will make it easier to load the FASTA/OTU maps
from deepdna.data.dataset import Dataset
from deepdna.data.otu import OtuSampleDb

In [6]:
# Where to store the processed dataset
DATASET_PATH = HOME_PATH / "Datasets/Walker_Reed"
DATASET_PATH.mkdir(parents=True, exist_ok=True)

### Converting the FASTA files

We first need to construct the dataset by converting the raw data files into formats compatible with our models. This usually involves converting each file to its `.db` equivalent (e.g. `sequences.fasta` -> `sequences.fasta.db`).

In [7]:
# The path to the FASTA file to use
# Note: the sequences in the FASTA must be clean and only 
#       contain valid base characters
FASTA_PATH = Path("/home/shared/walker_lab/reed/reed_clean.fasta")
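Because the sequences must be clean, it can help to check the FASTA file before converting it. The following is a minimal, standalone sketch that assumes "valid" means only the unambiguous uppercase bases `A`, `C`, `G`, `T`; the helper names `iter_fasta_records` and `find_invalid_records` are hypothetical and not part of `dnadb`.

```python
# Assumption: only unambiguous uppercase bases are valid; extend this set
# (e.g. with IUPAC ambiguity codes) if your pipeline accepts them.
VALID_BASES = set("ACGT")

def iter_fasta_records(fasta_text: str):
    """Yield (header, sequence) pairs, joining multi-line sequences."""
    header, chunks = None, []
    for line in fasta_text.splitlines():
        line = line.strip()
        if line.startswith(">"):
            if header is not None:
                yield header, "".join(chunks)
            header, chunks = line[1:], []
        elif line:
            chunks.append(line)
    if header is not None:
        yield header, "".join(chunks)

def find_invalid_records(fasta_text: str) -> list[str]:
    """Return the identifiers of records containing characters outside VALID_BASES."""
    return [
        header.split()[0]
        for header, sequence in iter_fasta_records(fasta_text)
        if not set(sequence) <= VALID_BASES
    ]
```

For a file of this size, something like `find_invalid_records(FASTA_PATH.read_text())` should return an empty list if the sequences are clean.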

Now we convert the FASTA file to a `.fasta.db` database. The script also creates training and testing splits of the data.

Note: run `prepare_local_dataset.py` with the `-h` flag to display its argument information.

In [None]:
%%bash -s "$DATASET_PATH" "$FASTA_PATH"
python3 ../scripts/prepare_local_dataset.py \
    $1 \
    --seed 0 \
    --output-db \
    --num-splits 1 \
    --min-sequence-length 250 \
    --test-split 0.2 \
    $2
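Conceptually, the flags above describe a seeded random split with a length filter. The sketch below is a hypothetical illustration of that idea in NumPy, assuming sequences shorter than `min_length` are dropped and `test_split` is the fraction held out for testing; it is not the script's actual implementation, which may differ (for example in how it rounds or names the `0` split directory created by `--num-splits 1`).

```python
import numpy as np

def split_sequences(sequences, test_split=0.2, min_length=250, seed=0):
    """Hypothetical seeded train/test split with a minimum-length filter."""
    # Drop sequences that are too short for the model's input length
    kept = [s for s in sequences if len(s) >= min_length]
    # A fixed seed makes the permutation, and therefore the split, reproducible
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(kept))
    n_test = int(round(len(kept) * test_split))
    test = [kept[i] for i in order[:n_test]]
    train = [kept[i] for i in order[n_test:]]
    return train, test
```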

In [34]:
%%bash -s "$DATASET_PATH"
ls $1

0


### Building the OTU Map

Next we need to build the OTU map. There is currently no convenient pipeline for this step, so for now it must be done by hand.

In [35]:
# Paths to the OTU list and shared files
OTU_LIST_PATH = Path("/home/shared/walker_lab/digitalocean/Reed_NRCS/shared_list/201201_wet_libs1_8.trim.contigs.pcr.good.unique.good.filter.unique.precluster.pick.pick.asv.list")
OTU_SHARED_PATH = Path("/home/shared/walker_lab/digitalocean/Reed_NRCS/shared_list/201201_wet_libs1_8.trim.contigs.pcr.good.unique.good.filter.unique.precluster.pick.pick.asv.shared")

In [36]:
with open(OTU_LIST_PATH) as f:
    f.readline() # discard first line
    otu_to_sequence_id_map = f.readline().rstrip().split('\t')[2:]
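The `[2:]` slice above assumes the list file follows the mothur-style `.list` layout (inferred from the file name): a header row, then a row whose first two columns are the label and the OTU count, followed by one entry per OTU. The toy, tab-separated example below (made-up identifiers) illustrates that layout with the same parsing logic. One caveat: if an OTU groups several sequences, mothur joins their names with commas, and the whole comma-joined string would then become the identifier.

```python
import io

def read_otu_sequence_ids(f) -> list[str]:
    f.readline()  # header row: label, numOtus, then OTU names
    # Second row: label, numOtus, then one sequence identifier per OTU
    return f.readline().rstrip().split("\t")[2:]

# Toy stand-in for a .list file; the identifiers are invented for illustration
toy_list = io.StringIO(
    "label\tnumOtus\tOtu001\tOtu002\tOtu003\n"
    "ASV\t3\tseqA\tseqB\tseqC\n"
)
print(read_otu_sequence_ids(toy_list))  # ['seqA', 'seqB', 'seqC']
```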

In [37]:
# Sequence identifier associated with OTU index 0
otu_to_sequence_id_map[0]

'M03064_50_000000000-CY83Y_1_1109_29403_15697'

In [38]:
def create_otu_db(
    otu_db_path: str|Path,
    fasta_db_path: str|Path,
    otu_shared_path: str|Path,
    otu_to_sequence_id_map: list[str]
):
    from deepdna.data.otu import OtuSampleDbFactory, OtuSampleEntry
    fasta_db = fasta.FastaDb(fasta_db_path)
    # Keep only the OTUs whose sequence identifier is present in this split's FASTA DB
    otus = np.array([i for i, otu in enumerate(otu_to_sequence_id_map) if otu in fasta_db])
    factory = OtuSampleDbFactory(otu_db_path)
    # Re-index the kept OTUs as 0..len(otus)-1 and record their sequence identifiers
    factory.write_identifiers(
        (i, otu_to_sequence_id_map[otu_index])
        for i, otu_index in enumerate(otus))
    with open(otu_shared_path) as f:
        f.readline() # header
        for line in tqdm(f, desc="Writing OTU Entries"):
            # Columns: label, sample name (Group), numOtus, then one count per OTU
            parts = line.rstrip().split('\t')
            sample_name = parts[1]
            counts_by_otu = list(map(int, parts[3:]))
            # Keep only the counts for the OTUs retained above
            counts_by_otu = [counts_by_otu[i] for i in otus]
            factory.write_entry(OtuSampleEntry.from_counts(sample_name, counts_by_otu))
    factory.close()
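The core of `create_otu_db` is a column filter: OTUs whose sequences are missing from a split are dropped, and each row of the shared file keeps only the matching counts. The sketch below isolates that logic on toy data, using a plain Python `set` as a stand-in for `FastaDb` membership; `filter_shared_row` is a hypothetical helper for illustration only.

```python
def filter_shared_row(line, otu_to_sequence_id_map, available_ids):
    """Return (sample_name, counts) restricted to OTUs whose IDs are available."""
    # Indices of OTUs present in this split (FastaDb membership in the notebook)
    kept = [i for i, seq_id in enumerate(otu_to_sequence_id_map) if seq_id in available_ids]
    parts = line.rstrip().split("\t")
    sample_name = parts[1]               # column 1: sample name (Group)
    counts = list(map(int, parts[3:]))   # columns 3+: one count per OTU
    return sample_name, [counts[i] for i in kept]

print(filter_shared_row("ASV\twet335t\t3\t5\t0\t2\n", ["seqA", "seqB", "seqC"], {"seqA", "seqC"}))
# ('wet335t', [5, 2])
```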

In [39]:
# Training OTU map
create_otu_db(
    otu_db_path=(DATASET_PATH / "0/train" / FASTA_PATH.name).with_suffix(".otu.db"),
    fasta_db_path=(DATASET_PATH / "0/train" / FASTA_PATH.name).with_suffix(".fasta.db"),
    otu_shared_path=OTU_SHARED_PATH,
    otu_to_sequence_id_map=otu_to_sequence_id_map
)


In [40]:
# Testing OTU map
create_otu_db(
    otu_db_path=(DATASET_PATH / "0/test" / FASTA_PATH.name).with_suffix(".otu.db"),
    fasta_db_path=(DATASET_PATH / "0/test" / FASTA_PATH.name).with_suffix(".fasta.db"),
    otu_shared_path=OTU_SHARED_PATH,
    otu_to_sequence_id_map=otu_to_sequence_id_map
)
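The database paths above rely on `Path.with_suffix`, which replaces only the last suffix of the file name (`.fasta`). The short example below shows the resulting names for this dataset, assuming the preparation script names its outputs after the input FASTA file as the cells above expect.

```python
from pathlib import PurePosixPath

fasta_name = PurePosixPath("/home/shared/walker_lab/reed/reed_clean.fasta").name
split_dir = PurePosixPath("Datasets/Walker_Reed/0/train")

# with_suffix swaps ".fasta" for the new (multi-dot) suffix
otu_db_path = (split_dir / fasta_name).with_suffix(".otu.db")
fasta_db_path = (split_dir / fasta_name).with_suffix(".fasta.db")
print(otu_db_path)    # Datasets/Walker_Reed/0/train/reed_clean.otu.db
print(fasta_db_path)  # Datasets/Walker_Reed/0/train/reed_clean.fasta.db
```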


### Loading the Dataset

In [8]:
dataset = Dataset(DATASET_PATH / "0")

train_fasta_dbs = list(map(fasta.FastaDb, dataset.fasta_dbs(Dataset.Split.Train)))
train_otu_dbs = list(map(OtuSampleDb, dataset.otu_dbs(Dataset.Split.Train)))

test_fasta_dbs = list(map(fasta.FastaDb, dataset.fasta_dbs(Dataset.Split.Test)))
test_otu_dbs = list(map(OtuSampleDb, dataset.otu_dbs(Dataset.Split.Test)))

In [9]:
train_fasta_dbs[0][0]

>M03064_47_000000000-CRMYC_1_1102_22182_6612 wet335t	Otu0000849	NumRep=1
TACGGAGGTTGCGAGCGTTATCCGGAGTTACTGGGCGTAAAGGGCGGGCAGGCGGAGGCGTAAGATGGGTGTGAAATCTCTCGGCTCAACCGGGAGGGGCCACTCGTGACTGCGCATCTGGAGGGCAGCAGAGGAGCGTGGAATTCCGGGTGGAGTGGTGAAATGCGTAGAGATCCGGAGGAACACCAGAGGCGAAGGCGGCGCTCTGGGCTGCGACTGACGCTGAACCGCGAAAGCCAGGGGAGCAAACGGG

## Loading a Pretrained DNABERT Model

Pretrained models are stored as Weights & Biases artifacts, which we pull using the W&B API.

The pretrained DNABERT artifacts are listed [here](https://wandb.ai/sirdavidludwig/dnabert-pretrain/artifacts/model/dnabert-pretrain-128d-250l-silva).

In [10]:
import wandb

In [11]:
api = wandb.Api()



### Downloading the Artifact

Here we use a pretrained DNABERT model trained using sequences of length 250. The artifact page is available [here](https://wandb.ai/sirdavidludwig/dnabert-pretrain/artifacts/model/dnabert-pretrain-128d-250l-silva/v0/overview).

In [12]:
DNABERT_ARTIFACT = "sirdavidludwig/dnabert-pretrain/dnabert-pretrain-128d-250l-silva:v0"
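The artifact string follows the W&B `entity/project/artifact_name:version` convention, which is why the DNABERT and SetBERT artifacts share an entity but live in different projects. The helper below is a hypothetical illustration that simply splits such a string into its parts; it does not contact W&B.

```python
def parse_artifact_path(path: str) -> dict:
    """Split an 'entity/project/name:version' artifact path into its components."""
    full_name, _, version = path.partition(":")
    entity, project, name = full_name.split("/")  # raises ValueError if not 3 parts
    return {"entity": entity, "project": project, "name": name, "version": version or None}

print(parse_artifact_path("sirdavidludwig/dnabert-pretrain/dnabert-pretrain-128d-250l-silva:v0"))
```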

In [13]:
pretrained_dnabert_path = api.artifact(DNABERT_ARTIFACT).download()

[34m[1mwandb[0m: Downloading large artifact dnabert-pretrain-128d-250l-silva:v0, 61.32MB. 4 files... 
[34m[1mwandb[0m:   4 of 4 files downloaded.  
Done. 0:0:0.1


### Loading the Model

In [14]:
from deepdna.nn.models import dnabert, load_model

In [15]:
pretrained_dnabert_model = load_model(
    pretrained_dnabert_path, 
    dnabert.DnaBertPretrainModel # optionally specify the model type
)

In [16]:
pretrained_dnabert_model

<deepdna.nn.models.dnabert.DnaBertPretrainModel at 0x7f2545bc9bd0>

## Loading a Pretrained SetBERT Model

In [17]:
# Import our own SetBERT module (DNABERT was imported above)
from deepdna.nn.models import setbert

In [18]:
SETBERT_ARTIFACT = "sirdavidludwig/setbert-pretrain/setbert-pretrain-reed-128d-250l:v0"

In [19]:
pretrained_setbert_path = api.artifact(SETBERT_ARTIFACT).download()

[34m[1mwandb[0m: Downloading large artifact setbert-pretrain-reed-128d-250l:v0, 58.04MB. 4 files... 
[34m[1mwandb[0m:   4 of 4 files downloaded.  
Done. 0:0:0.1


In [33]:
pretrained_setbert_model = load_model(
    pretrained_setbert_path, 
    setbert.SetBertPretrainModel # optionally specify the model type
)

In [34]:
pretrained_setbert_model

<deepdna.nn.models.setbert.SetBertPretrainModel at 0x7f22a041fb80>

---

## Constructing a Downstream SetBERT Task

As a simple downstream task, we will predict which sample a given subsample was drawn from.
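To make the task concrete, each training example pairs a random subsample of sequences with the index of the sample it came from. The NumPy sketch below illustrates this on a toy abundance matrix, assuming sequences are drawn with replacement in proportion to their OTU counts; the real `OtuSequenceGenerator` sampler may draw differently (for example when `use_presence_absence=True`), and `make_subsample_batch` is a hypothetical name. Every sample is assumed to have a nonzero total count.

```python
import numpy as np

def make_subsample_batch(counts, batch_size, subsample_size, rng):
    """counts: (num_samples, num_otus) abundances. Returns OTU indices and sample labels."""
    num_samples, num_otus = counts.shape
    # The label of each example is the sample it is drawn from
    labels = rng.integers(0, num_samples, size=batch_size)
    otu_indices = np.empty((batch_size, subsample_size), dtype=np.int64)
    for row, sample in enumerate(labels):
        # Abundance-weighted draw with replacement (assumption)
        probs = counts[sample] / counts[sample].sum()
        otu_indices[row] = rng.choice(num_otus, size=subsample_size, p=probs)
    return otu_indices, labels

rng = np.random.default_rng(0)
counts = np.array([[5, 0, 0], [0, 0, 3]])
otu_indices, labels = make_subsample_batch(counts, batch_size=6, subsample_size=10, rng=rng)
```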

### Data Generator

First we need a data generator that supplies the model with appropriately formatted inputs and targets. For reference, the available data generators are located in `src/deepdna/nn/data_generators.py`.

In [35]:
from deepdna.nn.data_generators import OtuSequenceGenerator

class OtuSampleEmbeddingGenerator(OtuSequenceGenerator):
    """
    Generates batches of DNABERT-embedded subsamples paired with the index of the
    sample each subsample was drawn from.
    """
    def __init__(
        self,
        sample: tuple[fasta.FastaDb, OtuSampleDb],
        dnabert_encoder: dnabert.DnaBertEncoderModel,
        use_presence_absence: bool = False,
        batch_size: int = 16,
        batches_per_epoch: int = 100,
        subsample_size: int = 0,
        augment_slide: bool = True,
        augment_ambiguous_bases: bool = True,
        encoder_batch_size: int = 1,
        rng: np.random.Generator | None = None
    ):
        super().__init__(
            samples=[sample],
            sequence_length=dnabert_encoder.base.sequence_length,
            kmer=dnabert_encoder.base.kmer,
            use_presence_absence=use_presence_absence,
            batch_size=batch_size,
            batches_per_epoch=batches_per_epoch,
            subsample_size=subsample_size,
            augment_slide=augment_slide,
            augment_ambiguous_bases=augment_ambiguous_bases,
            # Create a fresh generator per instance instead of sharing one default
            rng=rng if rng is not None else np.random.default_rng()
        )
        self.encoder_batch_size = encoder_batch_size # stored, but not used in generate_batch
        self.encoder = dnabert_encoder

    def generate_batch(self, rng: np.random.Generator):
        # Draw `batch_size` subsamples (at least one sequence each); the returned
        # sample indices serve as the classification targets
        (_, sample_indices), entries = self.sampler.random_entries(
            self.batch_size, max(1, self.subsample_size), rng)
        sequences = self.sampler.sequences(entries, rng)
        if self.kmer > 1:
            # Tokenize each sequence into k-mers for DNABERT
            sequences = dna.encode_kmers(
                sequences, # type: ignore
                self.kmer,
                not self.sampler.augment_ambiguous_bases) # type: ignore
        if self.subsample_size == 0:
            # Drop the singleton subsample axis
            sequences = np.squeeze(sequences, axis=1)
        # Embed every sequence with the pretrained DNABERT encoder
        sequences = self.encoder.encode(sequences) # type: ignore
        return sequences, sample_indices
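Before embedding, each sequence is turned into overlapping k-mer tokens. The sketch below shows the general idea with a base-4 encoding, where a sequence of length L yields L - k + 1 tokens. `kmer_tokens` is a hypothetical standalone illustration and not `dnadb`'s `dna.encode_kmers`, whose token layout and handling of ambiguous bases may differ; it assumes the input contains only `A`, `C`, `G`, `T`.

```python
import numpy as np

BASE_TO_INT = {"A": 0, "C": 1, "G": 2, "T": 3}

def kmer_tokens(sequence: str, k: int) -> np.ndarray:
    """Map each overlapping k-mer to an integer in [0, 4**k) via base-4 digits."""
    digits = np.array([BASE_TO_INT[b] for b in sequence], dtype=np.int64)
    windows = np.lib.stride_tricks.sliding_window_view(digits, k)  # (L - k + 1, k)
    powers = 4 ** np.arange(k - 1, -1, -1)  # most significant base first
    return windows @ powers

print(kmer_tokens("ACGT", 2))  # [ 1  6 11]  (AC, CG, GT)
```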

### Create the Data Generator Instances for Training and Testing

In [36]:
train_data = OtuSampleEmbeddingGenerator(
    (train_fasta_dbs[0], train_otu_dbs[0]),
    dnabert.DnaBertEncoderModel(pretrained_dnabert_model.base, chunk_size=256),
    use_presence_absence=False,
    subsample_size=1000,
    batch_size=8,
    batches_per_epoch=100
)

test_data = OtuSampleEmbeddingGenerator(
    (test_fasta_dbs[0], test_otu_dbs[0]),
    dnabert.DnaBertEncoderModel(pretrained_dnabert_model.base, chunk_size=256),
    use_presence_absence=False,
    subsample_size=1000,
    batch_size=8,
    batches_per_epoch=100
)

### Model Architecture

In [37]:
num_samples_to_predict = len(train_otu_dbs[0])
num_samples_to_predict

768

Since we want the set embeddings produced by SetBERT, we wrap the pretrained base in a SetBERT encoder model. This model simply extracts and returns the class token embedding, which represents the embedded set as a whole.
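Why a class token yields a set embedding can be seen in a stripped-down case: a single class-token query attending over the element embeddings. The NumPy sketch below is a simplified, single-head illustration (no residuals, feed-forward layers, or element-to-element attention), not SetBERT's actual implementation. It shows that, without positional encodings (an assumption here), the pooled output does not depend on the order of the sequences in the subsample.

```python
import numpy as np

def class_token_pool(cls_token, elements, w_q, w_k, w_v):
    """One attention step in which a class token attends over a set of embeddings."""
    q = cls_token @ w_q                    # (d,)
    k = elements @ w_k                     # (n, d)
    v = elements @ w_v                     # (n, d)
    scores = k @ q / np.sqrt(q.shape[-1])  # (n,) scaled dot-product scores
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()               # softmax over set elements
    return weights @ v                     # (d,) set embedding

rng = np.random.default_rng(0)
d, n = 8, 5
cls_token = rng.normal(size=d)
elements = rng.normal(size=(n, d))
w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))

pooled = class_token_pool(cls_token, elements, w_q, w_k, w_v)
pooled_shuffled = class_token_pool(cls_token, elements[rng.permutation(n)], w_q, w_k, w_v)
print(np.allclose(pooled, pooled_shuffled))  # True
```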

In [38]:
setbert_encoder = setbert.SetBertEncoderModel(pretrained_setbert_model.base)

Then we build the full model: SetBERT encodes the input, and the resulting set embedding is passed through a single dense layer with a softmax activation to predict which sample the input originated from.

In [42]:
# Slice off the batch dimension using [1:]
input_shape = setbert_encoder.input_shape[1:]

y = x = tf.keras.layers.Input(input_shape)
y = setbert_encoder(y)
y = tf.keras.layers.Dense(num_samples_to_predict, activation="softmax")(y)
model = tf.keras.Model(x, y)

Compile the model with a sparse categorical cross-entropy loss, a sparse categorical accuracy metric, and the Adam optimizer.

In [48]:
# Compile the model
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    optimizer=tf.keras.optimizers.Adam(1e-4)
)
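Sparse categorical cross-entropy is the mean negative log-probability assigned to the true sample index. A useful reference point is a model that guesses uniformly over all 768 samples, which scores ln 768 ≈ 6.644. The NumPy sketch below computes this baseline, clipping probabilities similarly to Keras. A loss above this value, such as the ≈ 8.4 seen early in training below, suggests predictions that are confidently wrong, consistent with the untrained model predicting the same sample for every input.

```python
import numpy as np

def sparse_categorical_crossentropy(probs, labels, eps=1e-7):
    """Mean negative log-probability of the true class; probs are softmax outputs."""
    picked = probs[np.arange(len(labels)), labels]
    return float(np.mean(-np.log(np.clip(picked, eps, 1.0))))

num_classes = 768
uniform = np.full((8, num_classes), 1.0 / num_classes)
labels = np.arange(8)
print(sparse_categorical_crossentropy(uniform, labels))  # ≈ 6.644 (= ln 768)
```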

In [49]:
# Test predictions of untrained model
tf.argmax(model(train_data[0][0]), axis=1)

<tf.Tensor: shape=(8,), dtype=int64, numpy=array([220, 220, 220, 220, 220, 220, 220, 220])>

Train the model

In [50]:
model.fit(train_data, validation_data=test_data, epochs=1)

 16/100 [===>..........................] - ETA: 9:54 - loss: 8.4164 - sparse_categorical_accuracy: 0.0000e+00 

KeyboardInterrupt: 