1. Download dependencies and hg38 genome

In [None]:
!pip install pyfaidx transformers datasets tqdm ipywidgets
!wget -c http://hgdownload.cse.ucsc.edu/goldenpath/hg38/bigZips/hg38.fa.gz
!gunzip -k hg38.fa.gz

Defaulting to user installation because normal site-packages is not writeable


2025-04-28 03:32:38.565804: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745811158.588499  150437 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745811158.595557  150437 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


2. Generate input sequences

In [None]:
import pandas as pd
import requests
from tqdm import tqdm
from pyfaidx import Fasta
from transformers import BertTokenizer, BertForSequenceClassification
import torch

df = pd.read_csv("all.csv")

# Load reference genome
genome = Fasta("hg38.fa")

# Set window size
window = 512

# Extract 512bp flanking sequences (total 1024+1bp)

def get_sequence(row, flank_size=256):
    try:
        chrom = str(row["#CHROM"])
        if not chrom.startswith("chr"):
            chrom = "chr" + chrom

        pos = int(row["POS"])
        start = max(0, pos - flank_size - 1)
        end = pos + flank_size

        # Check if chromosome is present in the genome
        if chrom not in genome:
            return None

        seq = genome[chrom][start:end].seq.upper()

        # Validate length
        if len(seq) != (2 * flank_size + 1):
            return None

        return seq
    except Exception as e:
        print(f"[⚠️ get_sequence] Error: {e}")
        return None


def generate_mutant_sequence(row, flank_size=256):
    try:
        seq = list(row["Context_Sequence"])
        mut_pos = flank_size  # The mutation position is in the center

        # Ensure REF/ALT are single nucleotides
        if len(row["REF"]) != 1 or len(row["ALT"]) != 1:
            return None

        # Ensure REF matches the reference sequence
        if seq[mut_pos] != row["REF"]:
            return None

        seq[mut_pos] = row["ALT"]
        return "".join(seq)
    except Exception as e:
        print(f"[⚠️ generate_mutant_sequence] Error: {e}")
        return None

# Extract context sequences (reference)
tqdm.pandas()
df["Context_Sequence"] = df.progress_apply(lambda row: get_sequence(row, flank_size=window), axis=1)
df.dropna(subset=["Context_Sequence"], inplace=True)

# Generate mutant sequences
df["Mutant_Sequence"] = df.progress_apply(lambda row: generate_mutant_sequence(row, flank_size=window), axis=1)
df.dropna(subset=["Mutant_Sequence"], inplace=True)

print(f"Successfully generated context and mutant sequences. Total valid records: {len(df)}")

3. Compute the DNABERT2 score

In [18]:
import torch
import numpy as np
import random
from transformers import BertModel, AutoTokenizer
from torch.utils.data import DataLoader
from tqdm import tqdm

df = df.head(100)

# Set random seed for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(42)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load DNABERT2 model and tokenizer
model = BertModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)

# Enable multi-GPU if available
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs for inference (DataParallel)")
    model = torch.nn.DataParallel(model)

model = model.to(device)
model.eval()

# Prediction function with batching & multi-GPU support
@torch.no_grad()
def predict_batch(sequences, batch_size=32):
    embeddings = []
    dataloader = DataLoader(sequences, batch_size=batch_size)
    for batch in tqdm(dataloader, desc="Predicting Batch", leave=True):
        inputs = tokenizer(list(batch), return_tensors="pt", padding=True, truncation=True, max_length=4096)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = model(**inputs)
        cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu()
        embeddings.append(cls_embeddings)
    return torch.cat(embeddings, dim=0).numpy()

# Extract wild-type and mutant sequences
context_sequences = df["Context_Sequence"].tolist()
mutant_sequences = df["Mutant_Sequence"].tolist()

# Run inference
print("Predicting wild-type sequences...")
wt_embeddings = predict_batch(context_sequences, batch_size=512)

print("Predicting mutant sequences...")
mut_embeddings = predict_batch(mutant_sequences, batch_size=512)

# Compute L2 distance between wild-type and mutant embeddings
l2_scores = np.linalg.norm(wt_embeddings - mut_embeddings, axis=1)

# Save L2 scores into the original dataframe
df["DB2"] = l2_scores

# (Optional) Save result if needed
df.to_csv("db2.csv", index=False)

print("DNABERT2 inference complete. L2 distance scores have been added to the dataset as 'DB2'.")

Some weights of BertModel were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly initialized: ['embeddings.position_embeddings.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.intermediate.dense.b

Predicting wild-type sequences...


Predicting Batch: 100%|██████████| 1/1 [00:00<00:00,  2.50it/s]


Predicting mutant sequences...


Predicting Batch: 100%|██████████| 1/1 [00:00<00:00,  2.64it/s]

DNABERT2 inference complete. L2 distance scores have been added to the dataset as 'DB2'.



