In [1]:
import wandb

# gReLU downloads the pretrained Enformer weights as a W&B artifact (the
# model-loading cell prints "Downloading large artifact 'human_state_dict:latest'"),
# so a W&B login is needed before the model is constructed.
# NOTE(review): this prompts interactively for an API key; for unattended
# re-runs, configure the key via the environment instead — confirm in W&B docs.
wandb.login()

[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

  2


[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

  ········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /net/dali/home/mscbio/ahk112/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mahk112[0m ([33mahk112-university-of-pittsburgh[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
# Standard library
import os
import warnings

# Third-party
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

# gReLU (installed with: pip3 install grelu)
from grelu.model.models import EnformerPretrainedModel
from grelu.sequence.format import convert_input_type

# NOTE(review): this silences *every* warning for the rest of the kernel
# session, including ones that may flag real problems (dtype/shape issues,
# deprecations). Consider narrowing it to the specific noisy warnings.
warnings.filterwarnings('ignore')

In [3]:
# Every input and output of this notebook lives under one processed-data folder.
base_dir = "data_processed"

# Inputs: promoter FASTA (headers are gene names) and the TF-gene pair table.
FASTA_PATH = os.path.join(base_dir, "promoters_flank1000.fasta")
pairs_path = os.path.join(base_dir, "tf_gene_pairs_with_promoters_flank1000.csv")

# Outputs: the embedding matrix and the gene names labelling its rows.
EMB_OUT_PATH = os.path.join(base_dir, "enformer_promoter_embeddings_flank1000.npy")
GENE_OUT_PATH = os.path.join(base_dir, "enformer_promoter_genes_flank1000.npy")

# Prefer the GPU whenever torch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Using device:", device)

Using device: cuda


In [4]:
# Fixed per-sequence length fed to Enformer (the promoter window is ±1000 bp).
TARGET_LENGTH = 2_000
# Sequences per forward pass; raise or lower to fit GPU memory.
BATCH_SIZE = 32

In [5]:
def read_fasta(path):
    """Parse a FASTA file into parallel lists of record names and sequences.

    Each line starting with ``>`` opens a new record whose name is the rest of
    that line, stripped of surrounding whitespace (here: the gene name). Every
    following non-blank line is stripped and concatenated onto that record's
    sequence. Blank lines are skipped, and any sequence lines that appear
    before the first header are ignored.

    Args:
        path: path to the FASTA file.

    Returns:
        (gene_names, sequences): two equal-length lists in file order.
    """
    records = []  # list of (name, [sequence chunks])
    with open(path, "r") as handle:
        for raw in handle:
            text = raw.strip()
            if not text:
                continue
            if text.startswith(">"):
                records.append((text[1:].strip(), []))
            elif records:
                # only accumulate once a header has been seen
                records[-1][1].append(text)

    names = [name for name, _ in records]
    seqs = ["".join(chunks) for _, chunks in records]
    return names, seqs

    

In [6]:
# Load promoter sequences; FASTA headers are the gene names.
gene_names, sequences = read_fasta(FASTA_PATH)
assert gene_names, f"No FASTA records found in {FASTA_PATH}"
assert len(gene_names) == len(sequences)

print(f"#genes: {len(gene_names)}")
print("First 5 gene names:", gene_names[:5])
print("First seq length:", len(sequences[0]))

# The gene -> row-index dict built later keeps only the last row for a repeated
# name, so duplicates here would be silently collapsed — surface them now.
n_duplicate_names = len(gene_names) - len(set(gene_names))
if n_duplicate_names:
    print(f"WARNING: {n_duplicate_names} duplicate gene names in the FASTA")

#genes: 4828
First 5 gene names: ['A1BG', 'AAK1', 'AAMDC', 'AAMP', 'AARSD1']
First seq length: 2000


The cell above should report roughly 5k genes, each with a sequence length of 2,000 bp.

In [7]:
# Enformer batches need equal-length inputs, so force every sequence to exactly
# target_len bases (defensive: the promoters are expected to already be 2 kb).
def pad_or_trim(seq, target_len):
    """Return `seq` forced to exactly `target_len` characters.

    Longer sequences keep their first `target_len` characters (the tail of the
    string is cut off); shorter ones are right-padded with 'N', the IUPAC code
    for an unknown base. Sequences already of the right length are returned
    unchanged.
    """
    shortfall = target_len - len(seq)
    if shortfall > 0:
        return seq + "N" * shortfall
    return seq[:target_len]

padded_seqs = [pad_or_trim(s, TARGET_LENGTH) for s in sequences]

# Trimming discards bases, so report how many sequences actually changed length.
n_resized = sum(len(s) != TARGET_LENGTH for s in sequences)
print("Sequences padded/trimmed:", n_resized)

lengths = [len(s) for s in padded_seqs]
print("Min len:", min(lengths), "Max len:", max(lengths))
# Enforce (rather than just eyeball) that every sequence is exactly TARGET_LENGTH.
assert min(lengths) == max(lengths) == TARGET_LENGTH, "pad_or_trim did not equalise lengths"

Min len: 2000 Max len: 2000


In [8]:
# One-hot encode the padded sequences for Enformer. gReLU returns each item
# channels-first, i.e. shape (4, TARGET_LENGTH); the assertions below enforce it.
# NOTE(review): there is no 5th channel for 'N'; presumably gReLU encodes N as
# an all-zero column — TODO confirm in the gReLU docs.
ohes = convert_input_type(
    inputs=padded_seqs,
    output_type="one_hot",
    genome="hg38",           # NOTE(review): likely unused when raw sequence strings are passed — confirm
    add_batch_axis=False,    # one (4, L) item per sequence; the DataLoader adds the batch axis
)
print("Number of one-hot items:", len(ohes))
print("Shape of first one-hot:", np.array(ohes[0]).shape)

assert len(ohes) == len(padded_seqs)
assert np.array(ohes[0]).shape == (4, TARGET_LENGTH), "expected channels-first (4, L) one-hot"

Number of one-hot items: 4828
Shape of first one-hot: (4, 2000)


In [9]:
# Stack the per-sequence one-hot arrays into a single (N, 4, L) array. A NumPy
# array supports len() and indexing, so DataLoader can batch it directly.
ohes_np = np.stack(ohes, axis=0)
print("ohes_np shape:", ohes_np.shape)
assert ohes_np.shape == (len(padded_seqs), 4, TARGET_LENGTH)

# shuffle=False keeps batch order identical to gene_names order; the saved
# embedding rows are matched to genes purely by position.
test_loader = DataLoader(
    dataset=ohes_np,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

ohes_np shape: (4828, 4, 2000)


In [10]:
# Load Enformer with pretrained weights via gReLU (the weights are downloaded
# from W&B, which is why the first cell logs in).
# NOTE(review): n_tasks=32 cannot match the pretrained human output head
# (5,313 tracks), so the 32-unit head is presumably freshly initialised; the
# 32-d outputs used as "embeddings" below would then be an untrained projection
# of the pretrained trunk features. Seeding makes that projection reproducible
# across kernel restarts.
# TODO: consider taking features from the pretrained trunk instead of the head.
SEED = 0
torch.manual_seed(SEED)

feature_extractor = EnformerPretrainedModel(
    n_tasks=32,
    device=device,
).to(device)

feature_extractor.eval()  # inference mode for dropout/normalisation layers

total_params = sum(p.numel() for p in feature_extractor.parameters())
print("Enformer params:", f"{total_params:,}")

[34m[1mwandb[0m: Downloading large artifact 'human_state_dict:latest', 939.29MB. 1 files...
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 00:00:15.0 (62.5MB/s)


Enformer params: 229,943,840


In [20]:
# Run Enformer over every promoter; each sequence yields one output vector.
all_embeddings = []

with torch.no_grad():
    for batch in test_loader:
        batch = batch.float().to(device)  # (B, 4, L) one-hot, channels-first
        out = feature_extractor(batch)
        # NOTE(review): output appears to be (B, n_tasks, 1); squeeze(-1) is a
        # no-op if the length axis is ever > 1, which the ndim check below catches.
        out = out.squeeze(-1)
        all_embeddings.append(out.cpu().numpy())

embeddings = np.concatenate(all_embeddings, axis=0)  # (N, D)

print("Embeddings shape:", embeddings.shape)
print("Should match #genes:", embeddings.shape[0], "vs", len(gene_names))
assert embeddings.ndim == 2, f"expected (N, D) embeddings, got shape {embeddings.shape}"
assert embeddings.shape[0] == len(gene_names), "embedding rows do not line up with gene_names"

Embeddings shape: (4828, 32)
Should match #genes: 4828 vs 4828


In [21]:
# Persist the (N, D) embedding matrix and, row-aligned with it, the gene names.
for out_path, array in ((EMB_OUT_PATH, embeddings), (GENE_OUT_PATH, np.array(gene_names))):
    np.save(out_path, array)

print("Saved embeddings to:", EMB_OUT_PATH)
print("Saved gene names to:", GENE_OUT_PATH)

Saved embeddings to: data_processed/enformer_promoter_embeddings_flank1000.npy
Saved gene names to: data_processed/enformer_promoter_genes_flank1000.npy


In [22]:
# Attach embeddings to the TF-gene pairs lazily. Writing a copy of each
# embedding into every pair row would make that file massive, so instead keep a
# gene -> row-index map and look embeddings up on demand.
pairs = pd.read_csv(pairs_path)

emb = np.load(EMB_OUT_PATH)
genes_emb = np.load(GENE_OUT_PATH)
# Rows of `emb` are matched to genes purely by position, so the two must align.
assert emb.shape[0] == len(genes_emb), "embedding rows and gene names are misaligned"

# gene -> row index; on duplicate names the last occurrence wins.
gene_to_idx = {g: i for i, g in enumerate(genes_emb)}

def get_emb(gene):
    """Look up the promoter embedding row for `gene`.

    Returns the matching row of the module-level `emb` matrix, or None when
    `gene` has no entry in `gene_to_idx` (no promoter embedding was computed).
    """
    if gene not in gene_to_idx:
        return None
    return emb[gene_to_idx[gene]]

# get_emb returns None for genes without a promoter embedding; report how many
# pairs are affected so they can be filtered or handled downstream.
has_emb = pairs["gene"].isin(list(gene_to_idx))
print(f"Pairs with a promoter embedding: {has_emb.sum():,} / {len(pairs):,} ({has_emb.mean():.1%})")

# Spot-check the lookup on 5 random pairs (fixed random_state for reproducibility).
pairs_sample = pairs.sample(5, random_state=0).copy()
pairs_sample["emb"] = pairs_sample["gene"].apply(get_emb)

print(pairs_sample[["TF", "gene", "expr", "emb"]])

             TF     gene      expr  \
606847      BID    RPL7A -0.050565   
695173    SALL3     SGCB  0.216774   
87447    NFATC1  SMARCE1  0.050626   
58264   NEUROD1   MRPS24 -0.227060   
441685     PAX5   ZNF428  0.017421   

                                                      emb  
606847  [-0.08855927, -0.019950302, 0.15034984, -0.002...  
695173  [-0.005865831, 0.099577, -0.008437029, 0.01797...  
87447   [0.059702277, 0.06322225, 0.0258956, -0.084927...  
58264   [-0.2606159, 0.023753757, 0.14891303, -0.12770...  
441685  [0.060610272, -0.22239193, -0.25339863, 0.0353...  
