# DNABERT-like Promoter Classifier (PyTorch) - 5 Component Prediction

A minimal, self-contained notebook that trains a DNABERT-style classifier on promoter sequences:
- Tokenises DNA into overlapping k-mers (k=6 by default)
- [CLS] + tokens + [SEP]
- Token + positional embeddings
- Transformer encoder stack
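- Classification head on the [CLS] token that outputs a probability distribution over the 5 components (trained with KL divergence against the soft per-sequence targets)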


In [1]:
# Imports
import math
from itertools import product
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

# Report library versions and which accelerators PyTorch can see.
# NOTE(review): the recorded output of this cell reports "A module that was compiled
# using NumPy 1.x cannot be run in NumPy 2.2.4" and ends with
# "ImportError: numpy.core.multiarray failed to import". The message itself suggests
# downgrading to 'numpy<2' or upgrading the affected compiled module; fix the
# environment before trusting any later cell.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
print(f"MPS: {getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available()}")



A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.4 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/anaconda3/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/opt/anaconda3/lib/python3.11/site-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/opt/anaconda3/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 701, in start
    self.io_loop.start()
  File "/opt/anaconda3/lib/python3.11/site-p

AttributeError: _ARRAY_API not found


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.4 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/anaconda3/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/opt/anaconda3/lib/python3.11/site-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/opt/anaconda3/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 701, in start
    self.io_loop.start()
  File "/opt/anaconda3/lib/python3.11/site-p

AttributeError: _ARRAY_API not found


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.4 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/anaconda3/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/opt/anaconda3/lib/python3.11/site-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/opt/anaconda3/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 701, in start
    self.io_loop.start()
  File "/opt/anaconda3/lib/python3.11/site-p

AttributeError: _ARRAY_API not found


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.4 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/anaconda3/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/opt/anaconda3/lib/python3.11/site-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/opt/anaconda3/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 701, in start
    self.io_loop.start()
  File "/opt/anaconda3/lib/python3.11/site-p

AttributeError: _ARRAY_API not found


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.4 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/anaconda3/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/opt/anaconda3/lib/python3.11/site-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/opt/anaconda3/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 701, in start
    self.io_loop.start()
  File "/opt/anaconda3/lib/python3.11/site-p

AttributeError: _ARRAY_API not found

ImportError: numpy.core.multiarray failed to import

In [8]:
# Data utilities (inline to keep notebook self-contained)
class PromoterDataset(Dataset):
    """Map-style dataset pairing promoter strings with normalised target rows.

    Each item is a dict:
      * ``"sequence"``: float tensor of shape ``(5, max_length)``; channel order
        is A, T, G, C, N. Sequences longer than ``max_length`` are cut, shorter
        ones are right-padded with ``"N"``.
      * ``"target"``: the target row rescaled to sum to 1, or a uniform vector
        when the row sums to zero or less.
    """

    def __init__(self, sequences: list, targets: np.ndarray, max_length: int = 600):
        self.sequences = sequences
        self.targets = targets
        self.max_length = max_length
        # Channel per base; any other character (after upper-casing) uses channel 4.
        self.dna_dict = {"A": 0, "T": 1, "G": 2, "C": 3, "N": 4}

    def __len__(self):
        return len(self.sequences)

    def encode_sequence(self, sequence: str) -> np.ndarray:
        """Return the ``(5, max_length)`` float32 one-hot encoding of *sequence*."""
        width = self.max_length
        shortfall = width - len(sequence)
        fitted = sequence + "N" * shortfall if shortfall >= 0 else sequence[:width]
        channels = np.array([self.dna_dict.get(ch.upper(), 4) for ch in fitted])
        # Row i of the identity matrix is the one-hot vector for channel i.
        return np.eye(5, dtype=np.float32)[channels].T

    def __getitem__(self, idx: int):
        encoded = self.encode_sequence(self.sequences[idx])
        row = self.targets[idx].astype(np.float32)
        mass = float(np.sum(row))
        # `mass <= 0` (not `not mass > 0`) so a NaN sum still takes the divide branch.
        row = np.ones_like(row, dtype=np.float32) / row.shape[0] if mass <= 0 else row / mass
        return {"sequence": torch.FloatTensor(encoded), "target": torch.FloatTensor(row)}

def load_and_prepare_data(file_path: str, prob_cols: Optional[List[str]] = None):
    """Load promoter sequences and their component-probability targets from a CSV.

    Args:
        file_path: CSV with a ``ProSeq`` column and one
            ``Component_<i>_Probability`` column per component.
        prob_cols: Explicit target columns, in output order. When ``None``
            (default), every column named exactly ``Component_<i>_Probability``
            is used, ordered by ``i``. The target width therefore follows the
            file (e.g. 5 for the 5-component analysis) instead of being fixed.

    Returns:
        ``(sequences, targets)``: a list of non-empty sequence strings and a
        float32 array of shape ``(len(sequences), len(prob_cols))``.

    Raises:
        ValueError: if no probability columns are found or given.
    """
    import re

    df = pd.read_csv(file_path)
    if prob_cols is None:
        pattern = re.compile(r"^Component_(\d+)_Probability$")
        numbered = []
        for col in df.columns:
            match = pattern.match(str(col))
            if match:
                numbered.append((int(match.group(1)), col))
        # Numeric (not lexicographic) order so Component_10 sorts after Component_9.
        prob_cols = [col for _, col in sorted(numbered, key=lambda item: item[0])]
    prob_cols = list(prob_cols)
    if not prob_cols:
        raise ValueError(f"No Component_<i>_Probability columns found in {file_path}")

    # Drop rows missing the sequence or any target value.
    df = df.dropna(subset=["ProSeq", *prob_cols])
    # Keep only non-empty string sequences; .astype(bool) keeps the mask boolean
    # even when the frame is empty.
    is_valid = df["ProSeq"].map(lambda s: isinstance(s, str) and len(s) > 0).astype(bool)
    df = df.loc[is_valid]

    sequences = df["ProSeq"].tolist()
    # reshape keeps a 2-D (0, n_cols) array when nothing survives filtering.
    targets = df[prob_cols].to_numpy(dtype=np.float32).reshape(len(sequences), len(prob_cols))
    return sequences, targets


In [9]:
# DNABERT-like model and tokenisation
SPECIALS = ["[PAD]", "[UNK]", "[CLS]", "[SEP]"]

def build_kmer_vocab(k: int):
    """Build the token vocabulary for k-mer tokenisation.

    Returns ``(vocab, stoi, itos)``:
      * ``vocab``: the special tokens (ids 0-3, in ``SPECIALS`` order) followed
        by all ``4**k`` k-mers over ``ACGT`` in lexicographic order;
      * ``stoi``: token -> id;
      * ``itos``: id -> token.
    """
    all_kmers = ["".join(combo) for combo in product("ACGT", repeat=k)]
    vocab = [*SPECIALS, *all_kmers]
    stoi = dict(zip(vocab, range(len(vocab))))
    itos = dict(enumerate(vocab))
    return vocab, stoi, itos

def seq_to_kmers(seq: str, k: int) -> List[str]:
    """Split *seq* into overlapping k-mers wrapped in ``[CLS]`` ... ``[SEP]``.

    The sequence is upper-cased first. Any window containing a character
    outside ``ACGT`` becomes ``[UNK]``. Sequences shorter than ``k`` yield
    just ``["[CLS]", "[SEP]"]``.
    """
    upper = seq.upper()
    windows = (upper[start:start + k] for start in range(len(upper) - k + 1))
    body = [w if all(ch in "ACGT" for ch in w) else "[UNK]" for w in windows]
    return ["[CLS]", *body, "[SEP]"]

def encode_batch(seqs: List[str], k: int, stoi: dict, max_len: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Tokenise *seqs* with ``seq_to_kmers`` and pack them into padded tensors.

    Each token list is cut to ``max_len`` tokens (default: the longest list in
    the batch) and right-padded with ``[PAD]``. Tokens missing from *stoi* map
    to ``[UNK]``.

    Returns:
        ``(input_ids, attention_mask)``: a ``torch.long`` tensor and a
        ``torch.bool`` tensor, both ``(len(seqs), max_len)``. The mask is True
        on real tokens and False on padding.
    """
    token_lists = [seq_to_kmers(s, k) for s in seqs]
    width = max(map(len, token_lists)) if max_len is None else max_len
    pad_id, unk_id = stoi["[PAD]"], stoi["[UNK]"]
    id_rows, mask_rows = [], []
    for toks in token_lists:
        ids = [stoi.get(t, unk_id) for t in toks[:width]]
        fill = max(width - len(ids), 0)
        id_rows.append(ids + [pad_id] * fill)
        mask_rows.append([1] * len(ids) + [0] * fill)
    return torch.tensor(id_rows, dtype=torch.long), torch.tensor(mask_rows, dtype=torch.bool)

def onehot5_to_strings(x: torch.Tensor) -> List[str]:
    """Decode a batch of 5-channel one-hot DNA tensors back into strings.

    Channel ``i`` maps to ``"ATGCN"[i]``, the same order as the dataset's
    ``dna_dict``. Each position takes its argmax channel; ties resolve to the
    first maximal channel.

    Args:
        x: tensor of shape ``(B, 5, L)`` on any device.

    Returns:
        ``B`` strings, each of length ``L``.

    Raises:
        ValueError: if *x* is not shaped ``(B, 5, L)``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if x.ndim != 3 or x.size(1) != 5:
        raise ValueError(f"expected (B,5,L), got {tuple(x.shape)}")
    # One device->host copy for the whole batch. Iterating a GPU/MPS tensor
    # element by element forces a synchronising transfer per base.
    idx = x.argmax(dim=1).cpu().numpy()
    lut = np.frombuffer(b"ATGCN", dtype=np.uint8)
    return [row.tobytes().decode("ascii") for row in lut[idx]]

class DNABertClassifier(nn.Module):
    """DNABERT-style classifier over overlapping k-mer token ids.

    Pipeline:
      1. Token embedding plus learned absolute position embedding.
      2. LayerNorm, then dropout.
      3. ``num_layers`` post-norm ``nn.TransformerEncoderLayer`` blocks (GELU).
      4. A two-layer MLP head applied to position 0, the ``[CLS]`` token that
         ``seq_to_kmers`` prepends.

    ``forward`` returns unnormalised logits of shape ``(batch, num_labels)``.

    Args:
        k: k-mer size; used for the default vocabulary and read by the
            training helpers as ``model.k``.
        num_labels: number of output logits.
        hidden_size: embedding / encoder width.
        num_layers: number of encoder layers.
        num_heads: attention heads per encoder layer.
        ffn_size: encoder feed-forward width; defaults to ``4 * hidden_size``.
        dropout: dropout rate for embeddings, encoder and head.
        max_position_embeddings: longest token sequence (including
            ``[CLS]``/``[SEP]``) the model accepts.
        vocab, stoi: token list and token->id map; both are rebuilt with
            ``build_kmer_vocab(k)`` unless both are supplied.
    """

    def __init__(
        self,
        k: int = 6,
        num_labels: int = 4,
        hidden_size: int = 256,
        num_layers: int = 6,
        num_heads: int = 8,
        ffn_size: Optional[int] = None,
        dropout: float = 0.1,
        max_position_embeddings: int = 1024,
        vocab: Optional[List[str]] = None,
        stoi: Optional[dict] = None,
    ):
        super().__init__()
        self.k = k
        self.num_labels = num_labels
        if vocab is None or stoi is None:
            vocab, stoi, _ = build_kmer_vocab(k)
        self.stoi = stoi
        self.vocab_size = len(vocab)
        if ffn_size is None:
            ffn_size = 4 * hidden_size
        self.token_embeddings = nn.Embedding(self.vocab_size, hidden_size)
        self.pos_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.emb_layer_norm = nn.LayerNorm(hidden_size)
        self.emb_dropout = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=ffn_size,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=False,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels),
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """BERT-style init: truncated normal(0.02) weights, zero biases, unit LayerNorm."""
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute classification logits.

        Args:
            input_ids: ``(batch, seq_len)`` tensor of token ids.
            attention_mask: ``(batch, seq_len)`` mask, truthy on real tokens.
                Derived from ``[PAD]`` ids when omitted.

        Returns:
            ``(batch, num_labels)`` logits read from the ``[CLS]`` position.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_position_embeddings``.
        """
        bsz, seqlen = input_ids.shape
        max_positions = self.pos_embeddings.num_embeddings
        if seqlen > max_positions:
            # Fail with a clear message instead of an opaque embedding IndexError.
            raise ValueError(
                f"sequence length {seqlen} exceeds max_position_embeddings={max_positions}"
            )
        device = input_ids.device
        if attention_mask is None:
            attention_mask = input_ids.ne(self.stoi["[PAD]"])
        pos_ids = torch.arange(seqlen, device=device).unsqueeze(0).expand(bsz, seqlen)
        x = self.token_embeddings(input_ids) + self.pos_embeddings(pos_ids)
        # BERT order: normalise the summed embeddings first, then apply dropout.
        x = self.emb_dropout(self.emb_layer_norm(x))
        # The encoder expects True at positions to IGNORE, hence the inversion.
        x = self.encoder(x, src_key_padding_mask=~attention_mask.bool())
        cls = x[:, 0]
        return self.classifier(cls)

def prepare_inputs_from_onehot(
    onehot_batch: torch.Tensor,
    k: int,
    stoi: dict,
    max_len: Optional[int] = None,
    strip_padding: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Turn a ``(B, 5, L)`` one-hot batch into k-mer token ids and an attention mask.

    Args:
        onehot_batch: one-hot DNA batch (channel order A, T, G, C, N).
        k: k-mer size.
        stoi: token -> id map.
        max_len: token width to pad/truncate to (``None`` = longest in batch).
        strip_padding: drop trailing ``'N'`` characters before tokenising.
            ``PromoterDataset`` right-pads short sequences with ``'N'``.
            Without stripping, that padding becomes attended ``[UNK]`` k-mers
            rather than masked ``[PAD]`` tokens, and differs from what
            single-sequence inference on the raw string produces. Genuine
            trailing ``N`` bases are dropped too.

    Returns:
        ``(input_ids, attention_mask)`` on ``onehot_batch.device``.
    """
    seqs = onehot5_to_strings(onehot_batch)
    if strip_padding:
        seqs = [s.rstrip("N") for s in seqs]
    input_ids, attn = encode_batch(seqs, k, stoi, max_len=max_len)
    return input_ids.to(onehot_batch.device), attn.to(onehot_batch.device)


In [10]:
# Training helpers

def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over *train_loader*; return the mean per-sample loss.

    Each batch dict must provide ``'sequence'`` (one-hot ``(B, 5, L)``) and
    ``'target'`` (``(B, C)`` probabilities). Logits go through
    ``log_softmax`` before *criterion*. Gradients are clipped to norm 1.0.

    The epoch loss weights every batch by its size, so a short final batch does
    not skew the average. This assumes *criterion* returns a per-sample mean,
    as ``KLDivLoss(reduction='batchmean')`` does in this notebook.
    """
    model.train()
    loss_sum = 0.0
    n_seen = 0
    for batch in train_loader:
        onehot = batch['sequence'].to(device)
        targets = batch['target'].to(device)
        # Full-width token count: (L - k + 1) k-mers plus [CLS] and [SEP].
        max_len = (onehot.shape[-1] - model.k + 1) + 2
        input_ids, attn = prepare_inputs_from_onehot(onehot, k=model.k, stoi=model.stoi, max_len=max_len)
        optimizer.zero_grad()
        logits = model(input_ids, attn)
        log_probs = F.log_softmax(logits, dim=1)
        loss = criterion(log_probs, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        batch_size = targets.shape[0]
        loss_sum += loss.item() * batch_size
        n_seen += batch_size
    return loss_sum / max(1, n_seen)

def validate_epoch(model, val_loader, criterion, device):
    """Evaluate *criterion* over *val_loader* without gradients; return the mean per-sample loss.

    Batches are weighted by size so the result is the true dataset mean,
    assuming *criterion* returns a per-sample mean (as
    ``KLDivLoss(reduction='batchmean')`` does here). This value drives the LR
    scheduler and early stopping, so the short final batch must not be
    over-weighted.
    """
    model.eval()
    loss_sum = 0.0
    n_seen = 0
    with torch.no_grad():
        for batch in val_loader:
            onehot = batch['sequence'].to(device)
            targets = batch['target'].to(device)
            # Full-width token count: (L - k + 1) k-mers plus [CLS] and [SEP].
            max_len = (onehot.shape[-1] - model.k + 1) + 2
            input_ids, attn = prepare_inputs_from_onehot(onehot, k=model.k, stoi=model.stoi, max_len=max_len)
            logits = model(input_ids, attn)
            log_probs = F.log_softmax(logits, dim=1)
            loss = criterion(log_probs, targets)
            batch_size = targets.shape[0]
            loss_sum += loss.item() * batch_size
            n_seen += batch_size
    return loss_sum / max(1, n_seen)

def evaluate_model(model, test_loader, device):
    """Predict softmax probabilities for every batch in *test_loader*.

    Returns:
        ``(predictions, targets)``: two numpy arrays of shape ``(N, C)``, with
        rows in loader order.

    Raises:
        ValueError: if *test_loader* yields no batches.
    """
    model.eval()
    preds, targs = [], []
    with torch.no_grad():
        for batch in test_loader:
            onehot = batch['sequence'].to(device)
            # Full-width token count: (L - k + 1) k-mers plus [CLS] and [SEP].
            max_len = (onehot.shape[-1] - model.k + 1) + 2
            input_ids, attn = prepare_inputs_from_onehot(onehot, k=model.k, stoi=model.stoi, max_len=max_len)
            logits = model(input_ids, attn)
            preds.append(torch.softmax(logits, dim=1).cpu().numpy())
            # Targets are only collected, never used on the device, so keep them
            # on the host rather than making a round trip to the accelerator.
            targs.append(batch['target'].cpu().numpy())
    if not preds:
        raise ValueError("test_loader yielded no batches")
    return np.vstack(preds), np.vstack(targs)


In [None]:
# Load data and build loaders
# NOTE: train_test_split is already imported in the first cell; this re-import is redundant but harmless.
from sklearn.model_selection import train_test_split

# Relative path: resolved against the notebook's current working directory.
csv_path = "../../data/processed/ProSeq_with_5component_analysis.csv"
sequences, targets = load_and_prepare_data(csv_path)

# Both splits stratify on the dominant component (argmax of each soft-target row).
# NOTE(review): stratification needs enough rows per dominant component; confirm
# that no component is dominant in only one or two rows.
labels = np.argmax(targets, axis=1)
# Hold out 20% of all rows for testing.
train_seq, test_seq, train_targets, test_targets = train_test_split(
    sequences, targets, test_size=0.2, random_state=42, stratify=labels
)
# Take validation from the remaining 80%: 20% of it, i.e. 16% of all rows
# (64% remain for training).
train_labels = np.argmax(train_targets, axis=1)
train_seq, val_seq, train_targets, val_targets = train_test_split(
    train_seq, train_targets, test_size=0.2, random_state=42, stratify=train_labels
)

# Each dataset one-hot encodes to the default max_length (600) and row-normalises targets.
train_ds = PromoterDataset(train_seq, train_targets)
val_ds = PromoterDataset(val_seq, val_targets)
test_ds = PromoterDataset(test_seq, test_targets)

batch_size = 32
# Only the training loader shuffles; val/test keep a fixed order.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)

print(len(train_ds), len(val_ds), len(test_ds))


5590 1398 1747


In [None]:
# Train DNABERT-like model
# Prefer Apple MPS, then CUDA, then CPU.
if getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Size the head from the loaded targets so logits and targets always have the
# same width (5 for the 5-component CSV).
model = DNABertClassifier(
    k=6,
    num_labels=train_targets.shape[1],
    hidden_size=256,
    num_layers=6,
    num_heads=8,
    max_position_embeddings=1024,
)
model.to(device)

# Targets are probability vectors: KL(target || prediction) on log-softmax outputs.
criterion = nn.KLDivLoss(reduction='batchmean')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=8, factor=0.5)

num_epochs = 30
train_losses, val_losses = [], []
best_val = float('inf')
bad_epochs = 0
max_bad_epochs = 10
print(f"Training on {device}...")
for epoch in range(num_epochs):
    tr = train_epoch(model, train_loader, criterion, optimizer, device)
    va = validate_epoch(model, val_loader, criterion, device)
    scheduler.step(va)
    train_losses.append(tr)
    val_losses.append(va)
    # Log before the early-stopping check so the final epoch is reported too.
    print(f"Epoch {epoch+1:03d}/{num_epochs} - train {tr:.6f} - val {va:.6f} - lr {optimizer.param_groups[0]['lr']:.2e}")
    if va < best_val - 1e-6:
        best_val = va
        bad_epochs = 0
        torch.save(model.state_dict(), 'best_dnabert_like.pth')
    else:
        bad_epochs += 1
        if bad_epochs >= max_bad_epochs:
            print(f"Early stop at epoch {epoch+1}")
            break


Training on mps...


KeyboardInterrupt: 

In [None]:
# Evaluation and quick plots
import matplotlib.pyplot as plt

# Reload the best checkpoint written by the training loop.
model.load_state_dict(torch.load('best_dnabert_like.pth', map_location=device))

predictions, true_targets = evaluate_model(model, test_loader, device)

# One name per model output, so per-component metrics follow the actual target
# width instead of a hard-coded list that can index past the last column.
component_names = [f'Component_{i + 1}' for i in range(predictions.shape[1])]
metrics = {}
for i, name in enumerate(component_names):
    mse = mean_squared_error(true_targets[:, i], predictions[:, i])
    r2 = r2_score(true_targets[:, i], predictions[:, i])
    metrics[name] = {'MSE': float(mse), 'R2': float(r2)}

# Overall MSE averages the per-column errors; overall R2 pools all values.
overall_mse = mean_squared_error(true_targets, predictions)
overall_r2 = r2_score(true_targets.flatten(), predictions.flatten())
metrics['Overall'] = {'MSE': float(overall_mse), 'R2': float(overall_r2)}

print(metrics)

# basic training curve
plt.figure(figsize=(6,4))
plt.plot(train_losses, label='train')
plt.plot(val_losses, label='val')
plt.legend(); plt.xlabel('epoch'); plt.ylabel('loss'); plt.title('DNABERT-like training'); plt.grid(True, alpha=0.3)
plt.show()


: 

: 

In [None]:
# Single-sequence prediction helper
def predict_component_probabilities_dnabert(
    model: DNABertClassifier,
    sequence: str,
    device,
    max_length: Optional[int] = 600,
):
    """Predict component probabilities for one raw DNA string.

    Args:
        model: a trained ``DNABertClassifier``.
        sequence: DNA string (any case; non-ACGT windows become ``[UNK]``).
        device: device the model lives on.
        max_length: truncate *sequence* to this many characters before
            tokenising. The default 600 mirrors ``PromoterDataset``'s default
            ``max_length``, so inference sees the same positions as training.
            ``None`` disables truncation.

    Returns:
        A dict with:
          * ``component_<i>_prob`` for every model output (``i`` starts at 1);
          * ``predicted_component``: the 1-based argmax;
          * ``confidence``: the maximum probability.
    """
    model.eval()
    if max_length is not None:
        sequence = sequence[:max_length]
    # max_len=None sizes the batch to this sequence's exact token count.
    input_ids, attn = encode_batch([sequence], model.k, model.stoi, max_len=None)
    input_ids = input_ids.to(device)
    attn = attn.to(device)
    with torch.no_grad():
        logits = model(input_ids, attn)
        probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
    # Report every component the model predicts (not a fixed subset).
    result = {f'component_{i + 1}_prob': float(p) for i, p in enumerate(probs)}
    result['predicted_component'] = int(np.argmax(probs) + 1)
    result['confidence'] = float(np.max(probs))
    return result

# Example: score the first loaded sequence with `model` and print the probability dict.
sample = sequences[0]
print(predict_component_probabilities_dnabert(model, sample, device))


: 

: 