1. Download dependencies and hg38 genome

In [None]:
!pip install pyfaidx transformers datasets tqdm ipywidgets
!pip install transformers accelerate --quiet
!wget -c http://hgdownload.cse.ucsc.edu/goldenpath/hg38/bigZips/hg38.fa.gz
!gunzip -k hg38.fa.gz

Defaulting to user installation because normal site-packages is not writeable


2025-04-27 19:24:04.129534: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745781844.151231    8251 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745781844.158411    8251 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


2. Generate input sequences

In [7]:
import pandas as pd
import requests
from tqdm import tqdm
from pyfaidx import Fasta
from transformers import BertTokenizer, BertForSequenceClassification
import torch

# Variant table; only the first 100 rows are kept for this run.
df = pd.read_csv("all.csv").head(100)

# hg38 reference opened through pyfaidx for sequence lookups.
genome = Fasta("hg38.fa")

# Bases taken on each side of a variant: 3000 + 1 + 3000 = 6001 bp per sequence.
window = 3000

def get_sequence(row, flank_size=256, reference=None):
    """Return the upper-cased reference window centred on a variant.

    The window covers ``flank_size`` bases on each side of the 1-based
    ``POS``, so a successful result has exactly ``2 * flank_size + 1``
    characters with the variant base at index ``flank_size``.

    Args:
        row: Mapping-like record providing ``"#CHROM"`` and ``"POS"``.
        flank_size: Number of bases to include on each side of the variant.
        reference: Optional genome object to read from. It must support
            ``name in reference`` and ``reference[name][start:end].seq``.
            When omitted, the module-level ``genome`` is used.

    Returns:
        The window as an upper-case string, or ``None`` when the chromosome
        is absent from the reference, the window would be clipped by either
        end of the chromosome, or the lookup raises (the error is printed).
    """
    try:
        fasta = genome if reference is None else reference

        raw_chrom = row["#CHROM"]
        # An integer-valued float such as 1.0 would otherwise become "chr1.0"
        # and never match a reference name.
        if isinstance(raw_chrom, float) and raw_chrom.is_integer():
            raw_chrom = int(raw_chrom)
        chrom = str(raw_chrom)
        if not chrom.startswith("chr"):
            chrom = "chr" + chrom

        pos = int(row["POS"])
        # POS is 1-based; the variant base sits at 0-based index pos - 1.
        start = max(0, pos - flank_size - 1)
        end = pos + flank_size

        if chrom not in fasta:
            # "MT" is the Ensembl/NCBI name for the mitochondrial genome;
            # UCSC-style references such as hg38 call it "chrM".
            if chrom == "chrMT" and "chrM" in fasta:
                chrom = "chrM"
            else:
                return None

        seq = fasta[chrom][start:end].seq.upper()

        # A short slice means the window ran off one end of the chromosome.
        if len(seq) != (2 * flank_size + 1):
            return None

        return seq
    except Exception as e:
        print(f"[⚠️ get_sequence] Error: {e}")
        return None


def generate_mutant_sequence(row, flank_size=256):
    """Return the context sequence with its central base replaced by ALT.

    Only single-nucleotide substitutions are produced. REF and ALT are
    compared case-insensitively and ALT is written in upper case.

    Args:
        row: Mapping-like record providing ``"Context_Sequence"``, ``"REF"``
            and ``"ALT"``.
        flank_size: Index of the variant base inside ``Context_Sequence``
            (the number of flanking bases on its left).

    Returns:
        The mutated sequence, or ``None`` when REF/ALT are missing or not
        exactly one character, ALT is not one of A/C/G/T/N, ``flank_size``
        lies outside the sequence, REF disagrees with the base found at the
        variant position, or an unexpected error occurs (it is printed).
    """
    try:
        context = row["Context_Sequence"]
        ref = row["REF"]
        alt = row["ALT"]

        # Missing values (None / NaN) are not strings and cannot be alleles.
        if not isinstance(ref, str) or not isinstance(alt, str):
            return None

        # Ensure REF/ALT are single characters
        if len(ref) != 1 or len(alt) != 1:
            return None

        ref = ref.upper()
        alt = alt.upper()

        # A one-character ALT can still be a VCF placeholder such as "*" or ".";
        # only real bases may be written into the sequence.
        if alt not in "ACGTN":
            return None

        mut_pos = flank_size  # The mutation position is in the center
        if not 0 <= mut_pos < len(context):
            return None

        # Ensure REF matches the reference sequence
        if context[mut_pos].upper() != ref:
            return None

        return context[:mut_pos] + alt + context[mut_pos + 1:]
    except Exception as e:
        print(f"[⚠️ generate_mutant_sequence] Error: {e}")
        return None

# Register tqdm's progress_apply on pandas objects.
tqdm.pandas()

# Reference window for every variant; rows for which no window was returned are removed.
df["Context_Sequence"] = df.progress_apply(
    lambda record: get_sequence(record, flank_size=window),
    axis=1,
)
df.dropna(subset=["Context_Sequence"], inplace=True)

# Same window carrying the ALT allele; rows for which no mutant was returned are removed.
df["Mutant_Sequence"] = df.progress_apply(
    lambda record: generate_mutant_sequence(record, flank_size=window),
    axis=1,
)
df.dropna(subset=["Mutant_Sequence"], inplace=True)

print(f"Successfully generated context and mutant sequences. Total valid records: {len(df)}")

100%|██████████| 100/100 [00:00<00:00, 1225.59it/s]
100%|██████████| 100/100 [00:00<00:00, 6500.68it/s]

Successfully generated context and mutant sequences. Total valid records: 100





3. Compute the Nucleotide Transformer (NT) scores

In [4]:
import torch
import numpy as np
import random
from transformers import AutoModelForMaskedLM, AutoTokenizer
from torch.utils.data import DataLoader
from tqdm import tqdm

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch random generators for reproducibility.

    When CUDA is available the CUDA generator is seeded as well and cuDNN is
    switched to deterministic mode with benchmarking turned off.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all random generators before any model code runs.
set_seed(42)

# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Nucleotide Transformer 2.5B (multi-species): model and tokenizer share one checkpoint id.
NT_CHECKPOINT = "InstaDeepAI/nucleotide-transformer-2.5b-multi-species"
model = AutoModelForMaskedLM.from_pretrained(NT_CHECKPOINT, trust_remote_code=True)

# Wrap in DataParallel when more than one GPU is present.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print(f"Using {gpu_count} GPUs for inference!")
    model = torch.nn.DataParallel(model)

tokenizer = AutoTokenizer.from_pretrained(NT_CHECKPOINT, trust_remote_code=True)

# Move the weights to the chosen device and switch to evaluation mode.
model.to(device)
model.eval()

# Batch inference function (returns CLS embeddings)
@torch.no_grad()
def predict_batch(sequences, batch_size=2):
    """Embed sequences with the loaded model and return one vector per input.

    Each batch is tokenized with padding and truncation, run through the
    module-level ``model`` on ``device``, and the position-0 (CLS) vector of
    the last hidden layer is collected on the CPU.

    Args:
        sequences: Sized, indexable collection of sequence strings.
        batch_size: Number of sequences tokenized and run per forward pass.

    Returns:
        A 2-D NumPy array with one row per input sequence, in input order.
        Empty input yields an array of shape ``(0, 0)`` instead of raising.
    """
    # torch.cat rejects an empty list, so an empty input (e.g. every variant
    # was filtered out upstream) is answered without touching the model.
    if len(sequences) == 0:
        return np.empty((0, 0), dtype=np.float32)

    embeddings = []
    dataloader = DataLoader(sequences, batch_size=batch_size)
    for batch in tqdm(dataloader, desc="Predicting Batch", leave=True):
        inputs = tokenizer(list(batch), return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = model(**inputs, output_hidden_states=True)
        last_hidden_state = outputs.hidden_states[-1]
        # Keep only the first token's vector and move it off the device right
        # away so per-batch activations do not accumulate there.
        cls_embeddings = last_hidden_state[:, 0, :].cpu()
        embeddings.append(cls_embeddings)
    return torch.cat(embeddings, dim=0).numpy()

# Pull the paired reference and mutant sequences out of the dataframe.
context_sequences = list(df["Context_Sequence"])
mutant_sequences = list(df["Mutant_Sequence"])

# Embed the reference (wild-type) sequences, one sequence per forward pass.
print("Predicting Wild-Type Sequences...")
wt_embeddings = predict_batch(context_sequences, batch_size=1)

# Embed the mutant sequences the same way.
print("Predicting Mutant Sequences...")
mut_embeddings = predict_batch(mutant_sequences, batch_size=1)

# Score each variant as the Euclidean distance between its two embeddings.
l2_scores = np.linalg.norm(wt_embeddings - mut_embeddings, axis=1)

# Attach the scores to the dataframe and write it out.
df["NT"] = l2_scores
df.to_csv("nt.csv")
print("Nucleotide Transformer scoring complete. L2 distances saved to column 'NT'.")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Predicting Wild-Type Sequences...


Predicting Batch: 100%|██████████| 100/100 [00:50<00:00,  2.00it/s]


Predicting Mutant Sequences...


Predicting Batch: 100%|██████████| 100/100 [00:51<00:00,  1.95it/s]

Nucleotide Transformer scoring complete. L2 distances saved to column 'NT'.



