In [1]:
import lightning as L
from lightning.pytorch.callbacks import RichProgressBar
from lightning.pytorch.loggers import WandbLogger
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb

from dbtk.data.datasets import SequenceDataset
from dbtk.nn.layers import TransformerEncoder, TransformerEncoderBlock, RelativeMultiHeadAttention
from dnabert import DnaBertModel, DnaBertPretrainingModel

In [2]:
# --- Batching ---------------------------------------------------------------
global_batch_size = 512               # samples per optimizer step (see accumulate_grad_batches in the Trainer cell)
batch_size = global_batch_size // 4   # samples per DataLoader batch; gradients accumulate over 4 batches

# --- Data: cropping and masking --------------------------------------------
min_length = 65    # shortest random crop taken from a raw sequence
max_length = 250   # longest crop; token tensors are right-padded to max_length + 1
mask_ratio = 0.15  # masked span length is at most int(num_tokens * mask_ratio)

# --- Model ------------------------------------------------------------------
embed_dim = 768
num_heads = 12
feedforward_dim = 2048
stack = 8          # num_layers of the TransformerEncoder
kmer = 6           # k-mer size passed to DnaBertModel
kmer_stride = 1    # k-mer stride passed to DnaBertModel

In [3]:
# Assemble the model bottom-up:
# attention -> encoder block -> encoder stack -> DnaBERT backbone -> pretraining wrapper.
attention = RelativeMultiHeadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    max_length=max_length,
)
encoder_block = TransformerEncoderBlock(attention, feedforward_dim=feedforward_dim)
encoder = TransformerEncoder(encoder_block, num_layers=stack)
backbone = DnaBertModel(encoder, kmer=kmer, kmer_stride=kmer_stride)
bert = DnaBertPretrainingModel(backbone)

# Drop the intermediate names so only `bert` is left in the notebook namespace.
del attention, encoder_block, encoder, backbone

In [4]:
num_workers = 8  # DataLoader worker processes; also sizes the sampler's epoch below

def transform(entry):
    """Turn one dataset entry into a (masked token ids, masked-out targets) pair.

    Steps:
    1. Crop a random contiguous window of ``entry.sequence``, with length in
       [min_length, max_length], clamped to the sequence length.
    2. Tokenize it with the model's tokenizer and vocabulary.
    3. Replace one random contiguous span of tokens with ``[MASK]``.
    4. Right-pad with ``[PAD]``.

    Args:
        entry: a dataset record exposing a ``.sequence`` attribute that can be
            sliced and passed to ``bert.base.tokenizer``. It is presumably a
            nucleotide string; TODO confirm against SequenceDataset.

    Returns:
        (sequence, masked_tokens):
        - ``sequence`` is a 1-D tensor of token ids with the masked span
          overwritten by ``[MASK]``, padded to ``max_length + 1``.
        - ``masked_tokens`` holds the original ids of the masked span.
    """
    # Crop. Clamping both bounds to the sequence length means sequences
    # shorter than min_length are used whole.
    sequence = entry.sequence
    minlen = min(min_length, len(sequence))
    maxlen = min(max_length, len(sequence))
    # torch.randint's `high` is exclusive. The +1 makes maxlen reachable and
    # keeps the range non-empty when minlen == maxlen (short sequences).
    length = torch.randint(minlen, maxlen + 1, size=(1,)).item()
    offset = torch.randint(0, len(sequence) - length + 1, size=(1,)).item()
    token_ids = bert.base.vocabulary(bert.base.tokenizer(sequence[offset:offset + length]))
    sequence = torch.tensor(list(token_ids))

    # Mask one contiguous span of at most mask_ratio of the tokens. It is always
    # at least one token, so very short inputs do not produce an empty randint range.
    max_mask_length = max(1, int(len(sequence) * mask_ratio))
    mask_length = torch.randint(1, max_mask_length + 1, size=(1,)).item()
    mask_offset = torch.randint(0, len(sequence) - mask_length + 1, size=(1,)).item()
    # Clone before overwriting so the targets keep the original token ids.
    masked_tokens = sequence[mask_offset:mask_offset + mask_length].clone()
    sequence[mask_offset:mask_offset + mask_length] = bert.base.vocabulary["[MASK]"]

    # Right-pad to a fixed length of max_length + 1 so samples can be stacked in collate.
    sequence = F.pad(sequence, (0, max_length - len(sequence) + 1), value=bert.base.vocabulary["[PAD]"])
    return sequence, masked_tokens

def collate(entries):
    """Batch the ``(sequence, masked_tokens)`` pairs produced by ``transform``.

    Returns ``((stacked_sequences, concatenated_masked_tokens), None, None)``:
    - The fixed-length token rows are stacked into a 2-D tensor.
    - The variable-length masked targets are concatenated into one 1-D tensor.
    """
    token_rows, targets = map(list, zip(*entries))
    batch_inputs = (torch.stack(token_rows), torch.cat(targets))
    return batch_inputs, None, None


# Location of the training sequences (kept in one variable instead of inline).
silva_path = "/home/data2/deepdna/datasets/silva_nr99_filtered_515f_806r/sequences.fasta.db/"

train_dataset = torch.utils.data.ConcatDataset([
    SequenceDataset(silva_path, transform=transform),
])

# Each "epoch" draws num_workers * batch_size samples with replacement,
# i.e. only num_workers batches per epoch.
train_sampler = torch.utils.data.RandomSampler(
    train_dataset, num_samples=num_workers * batch_size, replacement=True)

# Epochs are tiny, so keep the worker processes alive between epochs instead of
# tearing down and re-spawning num_workers processes every few optimizer steps.
# persistent_workers is only valid with num_workers > 0.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=train_sampler,
    num_workers=num_workers,
    collate_fn=collate,
    persistent_workers=num_workers > 0,
)

In [5]:
import time

# Rough data-pipeline check: how long does fetching one batch take?
# The iterator (and any worker start-up it triggers) is created before the
# timer starts, so this measures only the next() call.
it = iter(train_loader)

t = time.perf_counter()  # monotonic, high-resolution clock intended for durations
batch = next(it)
time.perf_counter() - t

0.03195476531982422

In [6]:
batch  # collate output: ((stacked padded token ids, concatenated masked-token targets), None, None)

((tensor([[   4,    4,    4,  ...,    0,    0,    0],
          [2716, 2657, 2421,  ...,    0,    0,    0],
          [3232,  625, 2488,  ...,    0,    0,    0],
          ...,
          [3207,  528, 2098,  ...,    0,    0,    0],
          [2779, 2909, 3431,  ...,    0,    0,    0],
          [   1,    1,    1,  ...,    0,    0,    0]]),
  tensor([2720, 2673, 2485,  ..., 2722, 2681, 2517])),
 None,
 None)

In [8]:
# Weights & Biases tracking. The project name is derived from embed_dim (as the
# model artifact name is) so differently sized models don't share a project by
# accident. `run` is kept for artifact upload and run.finish() later.
wandb_project = f"dnabert-{embed_dim}"
run = wandb.init(project=wandb_project)
logger = WandbLogger(project=wandb_project, log_model=False)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msirdavidludwig[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# The DataLoader yields batches of `batch_size`. Accumulating gradients over
# global_batch_size // batch_size batches gives global_batch_size samples per update.
accumulation_steps = global_batch_size // batch_size
progress_bar = RichProgressBar(refresh_rate=10)

trainer = L.Trainer(
    logger=logger,
    callbacks=[progress_bar],
    max_steps=100000,
    log_every_n_steps=50,
    accumulate_grad_batches=accumulation_steps,
)
trainer.fit(model=bert, train_dataloaders=train_loader)

In [None]:
# NOTE(review): this pickles the whole module object rather than a state_dict,
# so loading it back presumably needs the same dbtk/dnabert classes importable.
torch.save(obj=bert, f="./model.pt")

In [None]:
# Upload the checkpoint written by torch.save above as a W&B "model" artifact.
model_artifact = wandb.Artifact(f"dnabert.{embed_dim}d.silva", type="model")
model_artifact.add_file("./model.pt")
run.log_artifact(model_artifact)

In [None]:
run.finish()  # close the W&B run opened by wandb.init earlier in the notebook