1. Install dependencies

In [2]:
!pip install pyfaidx 
!pip install ipywidgets
!pip uninstall -y numpy
!pip install "numpy<2.0.0" --upgrade
!pip install git+https://github.com/songlab-cal/gpn.git

Defaulting to user installation because normal site-packages is not writeable
Collecting pyfaidx
  Downloading pyfaidx-0.8.1.4-py3-none-any.whl (28 kB)
Installing collected packages: pyfaidx
Successfully installed pyfaidx-0.8.1.4
Defaulting to user installation because normal site-packages is not writeable
Collecting ipywidgets
  Downloading ipywidgets-8.1.7-py3-none-any.whl (139 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.8/139.8 KB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
Collecting jupyterlab_widgets~=3.0.15
  Downloading jupyterlab_widgets-3.0.15-py3-none-any.whl (216 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m216.6/216.6 KB[0m [31m21.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting comm>=0.1.3
  Downloading comm-0.2.3-py3-none-any.whl (7.3 kB)
Collecting widgetsnbextension~=4.0.14
  Downloading widgetsnbextension-4.0.14-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31

2. Download hg38 genome

In [3]:
!wget -c http://hgdownload.cse.ucsc.edu/goldenpath/hg38/bigZips/hg38.fa.gz
!gunzip -k hg38.fa.gz

--2025-08-02 14:11:17--  http://hgdownload.cse.ucsc.edu/goldenpath/hg38/bigZips/hg38.fa.gz
Resolving hgdownload.cse.ucsc.edu (hgdownload.cse.ucsc.edu)... 128.114.119.163
Connecting to hgdownload.cse.ucsc.edu (hgdownload.cse.ucsc.edu)|128.114.119.163|:80... connected.
HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable

    The file is already fully retrieved; nothing to do.

gzip: hg38.fa already exists; do you wish to overwrite (y or n)? ^C


3. Generate input sequences

In [None]:
from pyfaidx import Fasta
import pandas as pd
from tqdm import tqdm

# Input files: the variant table and the reference genome FASTA
VARIANTS_CSV = "all.csv"
GENOME_FASTA = "hg38.fa"

df = pd.read_csv(VARIANTS_CSV)
genome = Fasta(GENOME_FASTA)

# Number of flanking bases taken on EACH side of the variant,
# so every extracted sequence is 2 * window + 1 bases long.
window = 240

# Function to extract reference sequence centered at variant position
def get_sequence(row, flank_size=256):
    """Return the upper-cased reference window around the variant in ``row``.

    Reads ``row["#CHROM"]`` (a "chr" prefix is added when missing) and
    ``row["POS"]``, then slices the module-level ``genome`` from
    ``POS - flank_size - 1`` (clamped at 0) up to ``POS + flank_size``.

    Returns:
        The sequence string, or None when the chromosome is not in ``genome``,
        when the slice is not exactly ``2 * flank_size + 1`` characters long,
        or when any exception is raised (the exception is printed).
    """
    try:
        contig = str(row["#CHROM"])
        contig = contig if contig.startswith("chr") else "chr" + contig

        position = int(row["POS"])
        lo = max(0, position - flank_size - 1)
        hi = position + flank_size

        if contig in genome:
            window_seq = genome[contig][lo:hi].seq.upper()
            # A short slice means the window ran off an end of the chromosome.
            if len(window_seq) == 2 * flank_size + 1:
                return window_seq
        return None
    except Exception as e:
        print(f"[⚠️ get_sequence] Error: {e}")
        return None

# Function to generate mutant sequence by replacing the reference base with ALT at the center
def generate_mutant_sequence(row, flank_size=256):
    """Return ``row["Context_Sequence"]`` with the base at index ``flank_size`` set to ALT.

    Only single-base substitutions are handled: both ``row["REF"]`` and
    ``row["ALT"]`` must have length 1, and the base at index ``flank_size``
    must equal ``row["REF"]``.

    Returns:
        The mutated sequence string, or None when a check above fails or when
        any exception is raised (the exception is printed).
    """
    try:
        bases = list(row["Context_Sequence"])
        centre = flank_size

        is_single_base = len(row["REF"]) == 1 and len(row["ALT"]) == 1
        if not is_single_base:
            return None

        # Substitute only when the context really carries REF at the centre.
        if bases[centre] == row["REF"]:
            bases[centre] = row["ALT"]
            return "".join(bases)
        return None
    except Exception as e:
        print(f"[⚠️ generate_mutant_sequence] Error: {e}")
        return None

# Generate sequences
# Register DataFrame.progress_apply so both passes below show a progress bar
tqdm.pandas()

# Pass 1: reference context for every variant; rows for which get_sequence
# returned None are dropped.
df["Context_Sequence"] = df.progress_apply(
    lambda variant: get_sequence(variant, flank_size=window),
    axis=1,
)
df.dropna(subset=["Context_Sequence"], inplace=True)

# Pass 2: ALT-substituted copy of each context; rows for which
# generate_mutant_sequence returned None are dropped.
df["Mutant_Sequence"] = df.progress_apply(
    lambda variant: generate_mutant_sequence(variant, flank_size=window),
    axis=1,
)
df.dropna(subset=["Mutant_Sequence"], inplace=True)

print(f"Successfully generated context and mutant sequences. Total valid records: {len(df)}")

100%|██████████| 259601/259601 [00:34<00:00, 7499.57it/s] 
100%|██████████| 259601/259601 [00:09<00:00, 28772.56it/s]

✅ Successfully generated context and mutant sequences. Total valid records: 259600





4. Compute the PhyloGPN scores

In [None]:
import torch
import pandas as pd
from tqdm import tqdm
from gpn.phylogpn import model, tokenizer

# Set device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Put the model in evaluation mode for inference (a no-op if it already is).
model.eval()
print(f"Using device: {device}")

# Parameter settings
BATCH_SIZE = 256
N_SAMPLES = 100  # number of leading variants to score; set to None to score all of them
pad_token = tokenizer.pad_token
pad_size = 240  # Centered on REF base

# Select the variants to score. .copy() detaches the selection from the full
# frame so result columns can be added to it later without touching a slice.
df = (df if N_SAMPLES is None else df.head(N_SAMPLES)).copy()
seqs = df["Context_Sequence"].tolist()
refs = df["REF"].tolist()
alts = df["ALT"].tolist()


# Center-aligned padding function
def pad_sequence(seq):
    """Surround ``seq`` with ``pad_size`` pad tokens on each side."""
    return pad_token * pad_size + seq + pad_token * pad_size


padded_seqs = [pad_sequence(seq) for seq in seqs]

# Batch generator
def batch_list(lst, batch_size):
    """Yield ``(indices, items)`` pairs covering ``lst`` in consecutive chunks.

    ``indices`` is the list of positions in ``lst`` that the chunk covers and
    ``items`` is the matching slice; each chunk holds at most ``batch_size``
    elements (the last one may be shorter).
    """
    n_items = len(lst)
    for start in range(0, n_items, batch_size):
        stop = min(start + batch_size, n_items)
        positions = list(range(start, stop))
        yield positions, lst[start:stop]

# Inference and scoring
scores = []
match_flags = []
success_count = 0

# Index of the variant base inside each context sequence and, because every
# sequence is padded by pad_size on both sides, inside the model output too.
# Derived from pad_size so the two cannot drift apart.
center = pad_size

n_batches = (len(padded_seqs) + BATCH_SIZE - 1) // BATCH_SIZE
for idxs, batch_seqs in tqdm(batch_list(padded_seqs, BATCH_SIZE),
                             total=n_batches,
                             desc="Running inference"):

    input_tensor = tokenizer(batch_seqs, return_tensors="pt", padding=True)["input_ids"].to(device)

    with torch.no_grad():
        output_logits = model(input_tensor)  # dictionary: A, C, G, T

    for i, orig_idx in enumerate(idxs):
        ref_base = refs[orig_idx]
        alt_base = alts[orig_idx]
        seq = seqs[orig_idx]

        try:
            # Record whether the context really carries REF at the center
            actual_base = seq[center]
            is_match = (actual_base == ref_base)

            # Score = ALT logit minus REF logit at the variant position
            score = output_logits[alt_base][i, center].item() - output_logits[ref_base][i, center].item()
            success_count += 1
        except (KeyError, IndexError):
            # Allele missing from the model output, or sequence too short:
            # keep the row but mark it as unscored.
            score = float("nan")
            is_match = float("nan")

        scores.append(score)
        match_flags.append(is_match)

# Add result columns & save. assign() builds a new frame instead of writing
# into the existing one, which may be a slice produced by df.head().
df = df.assign(**{"PhyloGPN": scores, "Match_REF_at_240": match_flags})
df.to_csv("phylogpn.csv", index=False)

# Sample output
print(df[["Context_Sequence", "REF", "ALT", "PhyloGPN", "Match_REF_at_240"]].head())
print(f"Successfully scored sequences: {success_count} / {len(seqs)}")

Using device: cuda


Running inference: 100%|██████████| 1/1 [00:01<00:00,  1.41s/it]

                                    Context_Sequence REF ALT  PhyloGPN  \
0  ACTGCCCTGGGCCGTGGTAGTTCTCTGTCCTTCATCAGGCTTTGTC...   C   T  2.370520   
1  AACATCACCATGAAGTGGCTGAAGGATAAGCAGCCAATGGATGCCA...   G   A  0.072245   
2  AGAGCCAGGAGCTGAGAAAATCTATTGGGGGTTGAGAGGAGTGCCT...   G   T  4.983103   
3  TTTGGTAGTTTTCCCTTTAAAATAATCAGAACTGCATATTGACAGA...   T   C  6.111948   
4  AAATGCCTAGGAAGTACCTTTCAGAGAAAGTAGAGAATATTTAATA...   T   C  4.164602   

   Match_REF_at_240  
0              True  
1              True  
2              True  
3              True  
4              True  
✅ Successfully scored sequences: 100 / 100



