To run this notebook you can use the flash2_pedro conda environment. Please do not install any package here. The best is if you can clone this environment ot a new one, or install the required packages in a new environment.

## Dataloader example to fetch a batch of sequences together with their alelle frequencies per position and nucleotide and a mask which masks the positions where the allele frequencies are not high enough quality (based on the number of individuals with data at those positions)

```python

In [None]:
#This codes takes a data frame with varinat information and fetches the sequence around the variant from a fasta file. It fetches a window with the same length as the context length of the model.
#In some rare cases that window will surpass the boundaries of a chromosome (when the variant is close to the end of a chromosome). In that case, the variant is discarded (for now).
import pandas as pd
import pyranges as pr
import math 
from pathlib import Path
import numpy as np


<torch._C.Generator at 0x7f9504816c50>

In [14]:
#this is the starting dataset:  contains the sequences for the upstream regions anchored at the start codon for each human gene. The sequence length is equal to the SpeciesLM input context length=2003.
#This dataset discards the sex chromosmomes and the mitochondrial chromosome.
seqs_df = pd.read_csv('/s/project/ml4rg_students/2025/project07/data/gtf_start_extended_ints_df_2003_seq.csv')
print (seqs_df)

      Chromosome     Start       End Strand  \
0           chr1     63564     65567      +   
1           chr1    922431    924434      +   
2           chr1    923941    925944      +   
3           chr1    958693    960696      +   
4           chr1    964531    966534      +   
...          ...       ...       ...    ...   
30055      chr22  50577912  50579915      -   
30056      chr22  50582778  50584781      -   
30057      chr22  50627776  50629779      -   
30058      chr22  50627369  50629372      -   
30059      chr22  50782291  50784294      -   

                                                     seq  seq_len  
0      TATCGATGGGCACCTTCTTTTTCTTAATTGTATCATACATTTTTAT...     2003  
1      AGAAGACACAGACTTCAGGAGAGGAAGGCACAGGAACTCACTGGCA...     2003  
2      TCCCCGCCGGGCGGGCGCGCGCCAGTGGACGCGGGTGCACGACTGA...     2003  
3      TCGGGAAGAGATTTTTGCACAACTCACCAACATACGCTCCCTGCCT...     2003  
4      TCCGCAGTGGGGCTGCGGGGAGGGGGGCGCGGGTCCGCAGTGGGGC...     2003  
...                        

In [None]:
import pandas as pd
from gnomad_db.database import gnomAD_DB

database_location = '/s/project/benchmark-lm/ssd-cache'
db = gnomAD_DB(database_location, gnomad_version="v4")

In [4]:
nuc_to_int_dict = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
revcomp_dict = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A'}

def one_hot_seq(seq):
    
    seq = seq.upper()
    one_hot = np.zeros((len(seq), 4), dtype=int)
    for i, nucleotide in enumerate(seq):
        if nucleotide == 'A':
            one_hot[i, 0] = 1
        elif nucleotide == 'C':
            one_hot[i, 1] = 1
        elif nucleotide == 'G':
            one_hot[i, 2] = 1
        elif nucleotide == 'T':
            one_hot[i, 3] = 1
        else:
            raise ValueError(f"Unknown nucleotide {nucleotide} at position {i} in sequence {seq}")
    return one_hot



In [5]:
def get_af_from_interval(interval):

    df_region = db.get_info_for_interval(chrom=interval['Chromosome'].values[0].strip('chr'), 
        interval_start=interval['Start'].values[0] + 1, #IMPORTANT 1 BASED COORDINATES IN GnomAD
        interval_end=interval['End'].values[0], query="*") #END IS INCLUDED

    #only consider SNPs that Pass the gnomad filtering criteria
    df_region = df_region[df_region['filter'] == 'PASS'].copy()
    df_region['len_ref'] = df_region['ref'].apply(len)
    df_region['len_alt'] = df_region['alt'].apply(len)
    df_region = df_region[(df_region['len_ref'] == 1) & (df_region['len_alt'] == 1)].copy().reset_index(drop=True)

    if interval['Strand'].values[0] == '+':
        df_region['relative_pos'] = (df_region['pos'] - (interval['Start'].values[0] + 1)).astype(int)  # this is zero based

        df_region['ref_int'] = df_region['ref'].map(nuc_to_int_dict).astype(int)
        df_region['alt_int'] = df_region['alt'].map(nuc_to_int_dict).astype(int)

        interval_length = interval['End'].values[0] - interval['Start'].values[0]
        afs_arr = np.zeros((interval_length, 4))
        afs_arr[df_region['relative_pos'].values, df_region['alt_int'].values] = df_region['AF'].values

        one_hot_arr = one_hot_seq(interval['seq'].values[0])

        afs_arr[one_hot_arr==1] = 1 - afs_arr.sum(axis=-1)

    elif interval['Strand'].values[0] == '-':
        df_region['relative_pos'] = (interval['End'].values[0] - df_region['pos']).astype(int)

        df_region['ref_int'] = df_region['ref'].map(revcomp_dict).map(nuc_to_int_dict).astype(int)
        df_region['alt_int'] = df_region['alt'].map(revcomp_dict).map(nuc_to_int_dict).astype(int)

        interval_length = interval['End'].values[0] - interval['Start'].values[0]
        afs_arr = np.zeros((interval_length, 4))
        afs_arr[df_region['relative_pos'].values, df_region['alt_int'].values] = df_region['AF'].values

        one_hot_arr = one_hot_seq(interval['seq'].values[0])
        afs_arr[one_hot_arr==1] = 1 - afs_arr.sum(axis=-1)

    return afs_arr, df_region 

### Fetch total allele number per position

In [6]:
import polars as pl

scores_lazy = pl.scan_parquet(
    "/s/project/benchmark-lm/data/gnomad4_1_allele_number/"
)
scores_lazy

PermissionError: Permission denied (os error 13): /s/project/benchmark-lm/data/gnomad4_1_allele_number/

<LazyFrame at 0x7F9514833DF0>

## Tokenizer - SpeciesLM

In [7]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
import torch.nn as nn
from torch.amp import autocast
import tqdm

# Load the model
model_name = "johahi/specieslm-metazoa-upstream-k6"
model = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)

# Load the corresponding tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

You are using a model of type rotarybert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.


In [8]:
proxy_species = 'homo_sapiens'
assert proxy_species in tokenizer.get_vocab()

def kmers(seq, k=6): #for codons, k = 6
    # splits a sequence into non-overlappnig k-mers
    return [seq[i:i + k] for i in range(0, len(seq), k) if i + k <= len(seq)]

def kmers_stride1(seq, k=6):
    # splits a sequence into overlapping k-mers
    return [seq[i:i + k] for i in range(0, len(seq)-k+1)]  

def tok_func_species(seq, proxy_species):
    res = tokenizer(proxy_species + " " +  " ".join(kmers_stride1(seq)))
    return res


In [9]:
# create a pytorch dataset which inputs the seqs_df and ilocs it to get the interval and the afs_arr and an_arr

import torch
from torch.utils.data import Dataset

class GnomADIntervalSpeciesLMDataset(Dataset):
    def __init__(self, seqs_df, minimum_total_allele_number=150788):
        self.seqs_df = seqs_df
        self.minimum_total_allele_number = minimum_total_allele_number # 99% of positions in the human genome are above this. Meaning for 99% of positions we have data on minimum_total_allele_number/2 or more individuals (/2 because each human has 2 sets of chromosome from the father and mother)
    
    def __len__(self):
        return len(self.seqs_df)
    
    def __getitem__(self, idx):
        interval = self.seqs_df.iloc[idx].to_frame().T
        afs_arr, _ = get_af_from_interval(interval)
        an_df = scores_lazy.filter((pl.col("chrom") == interval['Chromosome'].values[0]) & (pl.col("pos") >= interval['Start'].values[0]+1) & (pl.col("pos") <= interval['End'].values[0])).collect().to_pandas()
        an_arr = an_df['AN'].values if interval['Strand'].values[0] == '+' else an_df['AN'].values[::-1]
        an_mask = an_arr < self.minimum_total_allele_number
        input_ids = tok_func_species(interval['seq'].values[0], proxy_species=proxy_species)['input_ids']
        
        
        return {
            'labels': torch.tensor(afs_arr.copy(), dtype=torch.float32),
            'an_arr': torch.tensor(an_arr.copy(), dtype=torch.long),
            'an_mask': torch.tensor(an_mask.copy(), dtype=torch.bool), #True if the position does not pass the quality criteria of the minimum_total_allele_number
            'input_ids': torch.tensor(input_ids, dtype=torch.long)
        }

In [10]:
from torch.utils.data import DataLoader

gnomad_dataset = GnomADIntervalSpeciesLMDataset(seqs_df)

loader  = DataLoader(
    gnomad_dataset,
    batch_size = 8,
    shuffle    = True,
    num_workers= 0)

In [11]:
for batch in loader:
    
    break

NameError: name 'db' is not defined

In [12]:
batch

NameError: name 'batch' is not defined

## Future steps:
- You have now seen how to run each DNA LM: SpeciesLM and GPN-MSA, and how to fetch the allele frequencies and counts for a specific region. 
- Your next goal will be to fine-tune each model on the allele frequency per position in the genome.
- Most positions do not have variants, there the allele frequency for variant nucleotides is 0 and the reference nucleotide is 1
- Start by creating a train/validation/test split for the dataset. 
- Allele frequencies are shaped by mutation biases, unrelated to fitness effects, you can decide whether to correct for this when training the model or afterwards
  - using the mutation probabilities estimated by the neutral mutation rate model as another feature or in the loss function. This way you provide what to expect under no selection pressure and the model can focus on the selection/fitness effects
  - correcting the predicted allele frequencies by the estimated mutation rate