# Enformer Feature Extraction for Genomic Sequences

Use the pre-trained Enformer model to extract features (embeddings) from genomic sequences.

## 1. Setup and Configuration

### Environment Setup

In [None]:
import os

# Configure environment
# NOTE(review): this unconditionally overwrites any WANDB_API_KEY already
# exported in the shell. Leaving the '...' placeholder in place replaces a
# valid key with an invalid one. Never commit a real key to this notebook.
os.environ['WANDB_API_KEY'] = '...' # replace with your key
# Restrict this process to a single physical GPU (index 5). This only takes
# effect if it is set before CUDA is initialized in this process, which is why
# it lives in the very first cell (before torch is imported).
os.environ['CUDA_VISIBLE_DEVICES'] = '5'  # Set available GPU device

# Login to Weights & Biases for experiment tracking.
# `!` is an IPython shell escape, so this cell only runs inside Jupyter/IPython.
!wandb login

print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")

### Import Required Libraries

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader
from sklearn.metrics import average_precision_score, roc_auc_score
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Grelu library for genomic deep learning
from grelu.model.models import EnformerPretrainedModel
from grelu.io.fasta import read_fasta
from grelu.sequence.format import convert_input_type

import warnings
warnings.filterwarnings('ignore')

### Configuration Parameters

In [None]:
# Run-wide settings
SAVE = True              # write the final embedding matrix to disk at the end
SEED = 1182024           # single seed shared by torch and numpy
TARGET_LENGTH = 8192     # every sequence is padded/truncated to this many bp
BATCH_SIZE = 64          # sequences per forward pass

# Seed the RNGs (torch first, then numpy) so repeated runs match
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the GPU whenever torch can see one, otherwise fall back to CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

## 2. Data Loading and Preprocessing

### Load and Process Genomic Sequences

In [None]:
def pad_sequence(seq, target_len):
    """
    Force a sequence to exactly ``target_len`` characters.

    Sequences longer than ``target_len`` keep only their leading
    ``target_len`` characters; shorter ones are right-filled with ``"N"``.

    Args:
        seq: DNA sequence string
        target_len: Desired length

    Returns:
        The truncated or N-padded sequence (length ``target_len`` for any
        non-negative ``target_len``).
    """
    shortfall = target_len - len(seq)
    if shortfall < 0:
        # Too long: keep the prefix, drop the tail.
        return seq[:target_len]
    # Exactly the right length (shortfall == 0) or too short: append 'N' filler.
    return seq + "N" * shortfall

# GENCODE v49 `pc_transcripts` FASTA; relative path, resolved against the
# current working directory of the notebook kernel.
input_fasta = "../../genomic_sequences/gencode.v49.pc_transcripts.gene_names.fa"

# Read every record, then force each one to exactly TARGET_LENGTH bp
sequences = read_fasta(input_fasta)
padded_sequences = [pad_sequence(raw_seq, TARGET_LENGTH) for raw_seq in sequences]

n_padded = len(padded_sequences)
print(f"Processed {n_padded} sequences padded to length {TARGET_LENGTH} bp.")

## 3. Model Initialization

### Convert Sequences to One-Hot Encoding

In [None]:
# Options for grelu's format converter: one-hot output, hg38 as the genome,
# and a leading batch axis.
# NOTE(review): every padded sequence is encoded in a single call, so the full
# one-hot array for the whole FASTA is held in memory at once; for very large
# inputs consider encoding per batch instead.
onehot_options = {
    "output_type": "one_hot",
    "genome": "hg38",
    "add_batch_axis": True,
}
ohes = convert_input_type(inputs=padded_sequences, **onehot_options)

# Shape of one encoded sequence (first entry along axis 0)
print("One-hot encoded shape: {}".format(ohes[0].shape))

### Load the Pre-trained Enformer Model

In [None]:
# Instantiate grelu's EnformerPretrainedModel with a 32-output task head
# (n_tasks=32) on the device chosen in the configuration cell, then make sure
# all of its parameters live on that device.
feature_extractor = EnformerPretrainedModel(n_tasks=32, device=device)
feature_extractor = feature_extractor.to(device)

# Count every parameter element across all model tensors
total_params = sum(map(torch.Tensor.numel, feature_extractor.parameters()))
print("Model loaded successfully")
print(f"Total parameters: {total_params:,}")

## 4. Feature Extraction

### Generate Embeddings with Enformer

In [None]:
# Create DataLoader for batch processing. No shuffling, so row i of
# `embedding` corresponds to sequence i of the input FASTA.
test_loader = DataLoader(
    dataset=ohes,
    batch_size=BATCH_SIZE,
    shuffle=False,
    # Pinned host memory only speeds up host->GPU copies; skip it on CPU runs.
    pin_memory=(device == 'cuda'),
)


def _drop_singleton_feature_dims(batch_out):
    """Remove size-1 axes from a batched model output, always keeping axis 0.

    A bare ``.squeeze()`` also drops the batch axis when a batch holds a
    single sequence (e.g. a final partial batch of size 1). ``embeddings.extend``
    would then add individual scalars instead of one feature row, and the final
    ``np.array`` call would fail or produce a ragged result.
    """
    kept_dims = [dim for dim in batch_out.shape[1:] if dim != 1]
    return batch_out.reshape(batch_out.shape[0], *kept_dims)


# Extract features from all sequences
embeddings = []  # one numpy feature row per input sequence
feature_extractor.eval()

with torch.no_grad():
    for data in tqdm(test_loader, desc="Extracting features"):
        # The DataLoader already collates batches into tensors; cast and move
        # in place instead of re-wrapping with torch.tensor(), which copies
        # and emits a (suppressed) UserWarning.
        curr_data = data.to(device=device, dtype=torch.float32, non_blocking=True)

        # NOTE(review): this is the full model output, i.e. after the
        # n_tasks=32 head. If that head is not pre-trained (confirm in the
        # grelu docs), these features are a random projection of the trunk
        # output, and `feature_extractor.embedding(curr_data)` may be what
        # is actually intended.
        batch_out = feature_extractor(curr_data)
        batch_embeddings = _drop_singleton_feature_dims(batch_out).cpu().numpy()
        embeddings.extend(batch_embeddings)

# Stack the per-sequence rows into one float32 matrix. np.array already stacks
# equal-shaped rows, so no separate np.stack is needed. It also tolerates an
# empty input, where np.stack would raise.
embedding = np.array(embeddings, dtype=np.float32)

print(f"\nEmbedding shape: {embedding.shape}")
print(f"Total sequences processed: {len(embedding)}")

### Save Embeddings

In [None]:
if SAVE:
    output_file = "../../embeddings/embeddings_enformer_gencode.v49.pc_transcripts.npy"
    # np.save does not create missing directories. Ensure the target folder
    # exists so a long extraction run is not lost at the very last step.
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    np.save(output_file, embedding)
    print(f"Embeddings saved to: {output_file}")
else:
    print("SAVE is set to False - embeddings not saved")