# SetBERT Demo

This notebook demonstrates how to pull a pretrained SetBERT model from W&B and fine-tune it on a specific downstream task.

## Dependencies

In [None]:
!pip3 install dnadb tqdm tf-settransformer tf-utilities wandb

## Setup

In [1]:
# Disable Tensorflow info logging
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'

# Add the deepdna source code to the path
import sys
sys.path.append("../src")

In [2]:
from dnadb import dna, fasta
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import tensorflow as tf
import tf_utilities as tfu
import time
from tqdm.auto import tqdm
import wandb

## Configuration

In [3]:
tfu.devices.select_gpu(0)
# tfu.devices.select_cpu()

([PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')],
 [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')])

In [8]:
# User's home directory
HOME_PATH = Path('~').expanduser()

# Path to the dataset
DATASET_PATH = HOME_PATH / "work/Datasets/Walker_Reed_OTU"

## Loading the Dataset

In [9]:
from deepdna.data.dataset import Dataset
from deepdna.data.otu import OtuSampleDb

In [22]:
dataset = Dataset(DATASET_PATH)

fasta_dbs = list(map(fasta.FastaDb, dataset.fasta_dbs(Dataset.Split.Train)))
otu_dbs = list(map(OtuSampleDb, dataset.otu_dbs(Dataset.Split.Train)))

In [23]:
fasta_dbs[0][0]

FastaEntry(identifier='M03064_59_000000000-JCVRG_1_2104_6980_18698', sequence='TACGTATGGGGCGAGCGTTGTTCGGAGTTATTGGGCGTAAAGCGCGTGTAGGCGGTTTTTTAAGTCTGATGTGAAAGCCCCGGGCTCAACCTGGGAAGTGCATTGGATACTGGAAGACTTGAGTACGGGAGAGGGTAGTGGAATTCCTAGTGTAGGAGTGAAATCCGTAGATATTAGGAGGAACACCGGTGGCGAAGGCGGCTGCCTGGACCGATACTGACGCTGAGACGCGAAAGCGTGGGGAGCAAACAGG', extra='wet710t\tOtu0000093\tNumRep=4')
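The `extra` field of the entry above appears to hold tab-separated metadata. The following is a minimal sketch of splitting it, assuming the three fields are a sample name, an OTU identifier, and a `key=value` replicate count. These meanings are inferred from the printed value rather than confirmed by `dnadb`, and `parse_extra` is a hypothetical helper, not part of the library.

```python
def parse_extra(extra: str) -> dict:
    """Split a tab-separated FASTA `extra` string into named fields.

    The field meanings are assumptions inferred from the example entry.
    """
    sample, otu, *rest = extra.split("\t")
    parsed = {"sample": sample, "otu": otu}
    for item in rest:
        # Any remaining fields are assumed to be `key=value` pairs
        key, _, value = item.partition("=")
        parsed[key] = int(value) if value.isdigit() else value
    return parsed


info = parse_extra("wet710t\tOtu0000093\tNumRep=4")
print(info)
```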

## Loading a Pretrained SetBERT Model

Pretrained models are stored as Weights & Biases artifacts. Here we pull one using the W&B API.

The pretrained SetBERT artifacts are listed [here](https://wandb.ai/sirdavidludwig/setbert-pretrain/artifacts/model/setbert-pretrain-reed-128d-250l).

In [24]:
import wandb

In [25]:
api = wandb.Api()

### Downloading the Artifact

Here we use a pretrained SetBERT model trained using sequences of length 250. The artifact page is available [here](https://wandb.ai/sirdavidludwig/setbert-pretrain/artifacts/model/setbert-pretrain-reed-128d-250l/v0/overview).

In [26]:
SETBERT_ARTIFACT = "sirdavidludwig/setbert-pretrain/setbert-pretrain-reed-abund-128d-250l:latest"
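The artifact string above follows the `entity/project/name:alias` pattern. As a small illustration, the sketch below splits such a string into its parts, which can help when swapping `latest` for a pinned version alias. `split_artifact_name` is a hypothetical helper written for this notebook, not a W&B function.

```python
def split_artifact_name(full_name: str) -> dict:
    """Split an `entity/project/name:alias` string into its components."""
    path, _, alias = full_name.partition(":")
    entity, project, name = path.split("/")
    return {"entity": entity, "project": project, "name": name, "alias": alias}


parts = split_artifact_name(
    "sirdavidludwig/setbert-pretrain/setbert-pretrain-reed-abund-128d-250l:latest"
)
print(parts["name"], parts["alias"])
```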

In [27]:
pretrained_setbert_path = api.artifact(SETBERT_ARTIFACT).download()

wandb: Downloading large artifact setbert-pretrain-reed-abund-128d-250l:latest, 91.00MB. 4 files...
wandb:   4 of 4 files downloaded.
Done. 0:0:0.1

### Loading the Model

In [28]:
# Import our own DNABERT and SetBERT modules
from deepdna.nn.models import load_model, setbert

In [29]:
pretrained_setbert_model = load_model(
    pretrained_setbert_path,
    setbert.SetBertPretrainModel # optionally specify the model type
)

In [30]:
pretrained_setbert_model

<deepdna.nn.models.setbert.SetBertPretrainModel at 0x7f35bf3640a0>

The DNABERT encoder is embedded in the SetBERT model and can be extracted as follows:

In [31]:
dnabert_encoder = pretrained_setbert_model.base.dnabert_encoder
dnabert_encoder

<deepdna.nn.models.dnabert.DnaBertEncoderModel at 0x7f35bf3642e0>

---

## Constructing a Downstream SetBERT Task

As a simple downstream task, we will predict which sample each subsample was drawn from.

### Data Generator

First we need to create the data generator that can provide the model with the appropriate information. The available data generators are located in `src/deepdna/nn/data_generators.py` for reference.

In [32]:
from deepdna.nn.data_generators import OtuSequenceGenerator

class OtuSampleGenerator(OtuSequenceGenerator):
    def __init__(
        self,
        sample: tuple[fasta.FastaDb, OtuSampleDb],
        sequence_length: int,
        kmer: int,
        use_presence_absence: bool = False,
        batch_size: int = 16,
        batches_per_epoch: int = 100,
        subsample_size: int = 0,
        augment_slide: bool = True,
        augment_ambiguous_bases: bool = True,
        encoder_batch_size: int = 1,
        rng: np.random.Generator = np.random.default_rng()
    ):
        super().__init__(
            samples=[sample],
            sequence_length=sequence_length,
            kmer=kmer,
            use_presence_absence=use_presence_absence,
            batch_size=batch_size,
            batches_per_epoch=batches_per_epoch,
            subsample_size=subsample_size,
            augment_slide=augment_slide,
            augment_ambiguous_bases=augment_ambiguous_bases,
            rng=rng
        )
        self.encoder_batch_size = encoder_batch_size
        self.encoder = dnabert_encoder

    def generate_batch(self, rng: np.random.Generator):
        (_, sample_indices), entries = self.sampler.random_entries(
            self.batch_size, max(1, self.subsample_size), rng)
        sequences = self.sampler.sequences(entries, rng)
        if self.kmer > 1:
            sequences = dna.encode_kmers(
                sequences, # type: ignore
                self.kmer,
                not self.sampler.augment_ambiguous_bases) # type: ignore
        if self.subsample_size == 0:
            sequences = np.squeeze(sequences, axis=1)
        return sequences, sample_indices
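To make the k-mer step in `generate_batch` more concrete, here is a rough NumPy sketch of overlapping k-mer encoding. It is not the `dnadb` implementation: the base-4 token scheme (A=0, C=1, G=2, T=3), the lack of ambiguous-base handling, and the name `encode_kmers_sketch` are all assumptions made for illustration. A sequence of length L yields L - k + 1 tokens, so a 250-base sequence gives 248 tokens when k = 3 (the value of k here is also an assumption).

```python
import numpy as np

BASES = "ACGT"


def encode_kmers_sketch(sequence: str, kmer: int) -> np.ndarray:
    """Encode overlapping k-mers as base-4 integers (A=0, C=1, G=2, T=3)."""
    digits = np.array([BASES.index(base) for base in sequence], dtype=np.int64)
    # Sliding windows of length `kmer`: shape (len(sequence) - kmer + 1, kmer)
    windows = np.lib.stride_tricks.sliding_window_view(digits, kmer)
    # Positional weights for base-4 encoding, most significant base first
    weights = 4 ** np.arange(kmer - 1, -1, -1)
    return windows @ weights


tokens = encode_kmers_sketch("ACGT", 3)  # tokens for "ACG" and "CGT"
print(tokens)
print(len(encode_kmers_sketch("A" * 250, 3)))
```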

### Create the Data Generator Instances for Training and Testing

In [47]:
train_data = OtuSampleGenerator(
    (fasta_dbs[0], otu_dbs[0]),
    sequence_length=dnabert_encoder.base.sequence_length,
    kmer=dnabert_encoder.base.kmer,
    use_presence_absence=False,
    subsample_size=1000,
    batch_size=8,
    batches_per_epoch=100
)

# test_data = OtuSampleGenerator(
#     (fasta_dbs[0], otu_dbs[0]),
#     sequence_length=dnabert_encoder.base.sequence_length,
#     kmer=dnabert_encoder.base.kmer,
#     use_presence_absence=False,
#     subsample_size=1000,
#     shuffle=False,
#     batch_size=8,
#     batches_per_epoch=100
# )

### Model Architecture

In [34]:
num_samples_to_predict = len(otu_dbs[0])
num_samples_to_predict

768

Since we want to use the set embeddings that come from the SetBERT model, we'll employ a SetBERT encoder model. This model simply extracts and returns the class token representing the embedded set.
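As a rough picture of what "extracting the class token" means, the sketch below uses a random array as a stand-in for the transformer output. It assumes the class token sits at index 0 of the set axis, which is a common convention but is not confirmed here for SetBERT; the array and variable names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, set_len, embed_dim = 8, 1000, 128

# Stand-in for the transformer output: one class token followed by the set elements
transformer_output = rng.normal(size=(batch_size, 1 + set_len, embed_dim))

# Keep only the class token (assumed to be at index 0) as the set embedding
set_embeddings = transformer_output[:, 0, :]
print(set_embeddings.shape)
```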

In [35]:
setbert_encoder = setbert.SetBertEncoderModel(pretrained_setbert_model.base)

In [36]:
setbert_encoder.input_shape

(None, 1000, 128)

In [37]:
# Outputs eight 128-dimensional embeddings, one per subsample in the batch
setbert_encoder(train_data[0][0])

<tf.Tensor: shape=(8, 128), dtype=float32, numpy=
array([[ 0.5711318 ,  4.1046515 , -1.8908697 , ..., -6.3806534 ,
         8.200887  ,  0.8826097 ],
       [-0.82003796,  4.9176493 , -1.9321549 , ..., -6.8010664 ,
         7.000064  ,  1.519652  ],
       [-0.14095283,  4.4001226 , -2.5025759 , ..., -6.810854  ,
         7.4469137 ,  1.351651  ],
       ...,
       [ 0.1294086 ,  4.2874637 , -2.6301215 , ..., -6.940845  ,
         7.9330893 ,  2.1062567 ],
       [-0.6797569 ,  4.781932  , -2.0586112 , ..., -6.6996155 ,
         7.3001323 ,  1.1681182 ],
       [ 0.4470997 ,  3.7609084 , -2.4623172 , ..., -7.1823206 ,
         8.035939  ,  1.4849747 ]], dtype=float32)>

Then we create the full model by encoding the input and passing the set embeddings through a single dense layer with a softmax activation, which predicts the sample from which the input originated.

In [42]:
subsample_length = setbert_encoder.base.max_set_len
kmer_sequence_length = dnabert_encoder.input_shape[1]

subsample_length, kmer_sequence_length

(1000, 248)

In [44]:
inputs = tf.keras.layers.Input((subsample_length, kmer_sequence_length))
class_tokens = setbert_encoder(inputs)
sample_predictions = tf.keras.layers.Dense(num_samples_to_predict, activation="softmax")(class_tokens)
model = tf.keras.Model(inputs, sample_predictions)
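For intuition about the classification head, the following NumPy sketch reproduces the arithmetic of a dense layer followed by softmax on a batch of eight 128-dimensional embeddings and 768 output classes. The random weights and inputs are placeholders, not values from the trained model.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
class_tokens = rng.normal(size=(8, 128))          # placeholder set embeddings
kernel = rng.normal(scale=0.05, size=(128, 768))  # placeholder dense weights
bias = np.zeros(768)

# Dense layer followed by softmax: one probability distribution per subsample
probabilities = softmax(class_tokens @ kernel + bias)
print(probabilities.shape, probabilities.sum(axis=-1))
```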

Compile the model with an appropriate loss function, metric, and optimizer.

In [45]:
# Compile the model
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    optimizer=tf.keras.optimizers.Adam(1e-4)
)
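To show what the sparse categorical cross-entropy loss computes, here is a small NumPy version. The loss for each example is the negative log of the probability assigned to the integer target class. The uniform predictions below are illustrative: with 768 equally likely classes the loss equals ln(768), roughly 6.64, which is a useful reference point for an untrained classifier.

```python
import numpy as np


def sparse_categorical_crossentropy(probabilities: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """Per-example loss: negative log-probability of the integer target class."""
    picked = probabilities[np.arange(len(targets)), targets]
    return -np.log(picked)


num_classes = 768
targets = np.array([272, 276, 596])

# A classifier that assigns equal probability to every class
uniform = np.full((len(targets), num_classes), 1.0 / num_classes)

losses = sparse_categorical_crossentropy(uniform, targets)
print(losses.mean())
```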

In [46]:
# Grab a batch from the training data
batch = train_data[0]

# Get the inputs and targets from the batch
inputs = batch[0]
targets = batch[1]

predictions = model(inputs)

# Test predictions of untrained model
print("Predictions", tf.argmax(predictions, axis=1))
print("Targets", targets)

Predictions tf.Tensor([619 619 619 619 619 619 619 619], shape=(8,), dtype=int64)
Targets [272 276 596 356 232 352 577 186]
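The untrained model assigns every subsample to the same class, so none of the predictions above match their targets. The short check below computes the accuracy for this batch from the printed values and compares it with the chance level for 768 classes.

```python
import numpy as np

# Values copied from the output above
predicted = np.full(8, 619)
targets = np.array([272, 276, 596, 356, 232, 352, 577, 186])

accuracy = float(np.mean(predicted == targets))
chance_level = 1.0 / 768  # expected accuracy of uniform random guessing

print(accuracy, chance_level)
```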


Train the model for one epoch.

In [None]:
model.fit(
    train_data, 
    # validation_data=test_data, 
    epochs=1
)