In [7]:
%pip install --upgrade --force-reinstall --no-cache-dir numpy==1.26.4 pandas==2.2.3

Collecting numpy==1.26.4
  Downloading numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl.metadata (61 kB)
Collecting pandas==2.2.3
  Downloading pandas-2.2.3-cp312-cp312-macosx_11_0_arm64.whl.metadata (89 kB)
Collecting python-dateutil>=2.8.2 (from pandas==2.2.3)
  Downloading python_dateutil-2.9.0.post0-py2.py3-none-any.whl.metadata (8.4 kB)
Collecting pytz>=2020.1 (from pandas==2.2.3)
  Downloading pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)
Collecting tzdata>=2022.7 (from pandas==2.2.3)
  Downloading tzdata-2025.3-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting six>=1.5 (from python-dateutil>=2.8.2->pandas==2.2.3)
  Downloading six-1.17.0-py2.py3-none-any.whl.metadata (1.7 kB)
Downloading numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl (13.7 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.7/13.7 MB[0m [31m1.4 MB/s[0m  [33m0:00:09[0ma [36m0:00:01[0m[36m0:00:01[0m:01[0mm
[?25hDownloading pandas-2.2.3-cp312-cp312-macosx_11_0_arm64.whl (11

In [None]:
import os

import numpy as np
import pandas as pd
import pyBigWig
import torch
from Bio import SeqIO
from torch.utils.data import Dataset, DataLoader, random_split


In [3]:
# Parse the FASTA file and keep every record in a dict keyed by record id.
genome_file = "GRCh38.primary_assembly.genome.fa"
fasta_records = SeqIO.parse(genome_file, "fasta")
genome = SeqIO.to_dict(fasta_records)

# Print the first five record ids as a quick sanity check.
print("Chromosomes loaded:", list(genome)[:5])

Chromosomes loaded: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5']


In [4]:

gtf_file = "gencode.v49.primary_assembly.basic.annotation.gtf"
genes = []

with open(gtf_file) as gtf_handle:
    for raw_line in gtf_handle:
        # Lines starting with "#" are skipped before any column is read.
        if raw_line.startswith("#"):
            continue

        columns = raw_line.strip().split("\t")
        # Only rows whose third column is exactly "gene" are kept.
        if columns[2] != "gene":
            continue

        chrom = columns[0]
        start = int(columns[3])
        end = int(columns[4])
        strand = columns[6]
        # The ninth column holds the attributes; take the text between
        # 'gene_id "' and the next double quote.
        attributes = columns[8]
        gene_id = attributes.split('gene_id "')[1].split('"')[0]

        genes.append([gene_id, chrom, start, end, strand])

# One row per kept record, columns in the order the rows were assembled.
genes_df = pd.DataFrame(genes, columns=["gene_id", "chrom", "start", "end", "strand"])
print("Number of genes:", len(genes_df))
genes_df.head()

Number of genes: 78899


Unnamed: 0,gene_id,chrom,start,end,strand
0,ENSG00000290825.2,chr1,11121,24894,+
1,ENSG00000223972.6,chr1,12010,13670,+
2,ENSG00000310526.1,chr1,14356,30744,-
3,ENSG00000227232.6,chr1,14696,24886,-
4,ENSG00000278267.1,chr1,17369,17436,-


In [None]:
# Open the two bigWig files (plus / minus, going by their filenames).
# Both handles are module-level and are read again by later cells.
plus_bw = pyBigWig.open("K562_plus_unique.bigWig")
minus_bw = pyBigWig.open("K562_minus_unique.bigWig")
# Print whatever header() returns for each file: file-level header information,
# not the first signal records.
print("Plus bigwig header:", plus_bw.header())
print("Minus bigwig header:", minus_bw.header())

In [6]:
# Compare sequence naming between the FASTA dict and the plus bigWig by
# printing the first five names of each.
print(list(genome)[:5])
print(list(plus_bw.chroms())[:5])

['chr1', 'chr2', 'chr3', 'chr4', 'chr5']
['chr1', 'chr2', 'chr3', 'chr4', 'chr5']


In [7]:

def one_hot(seq):
    """One-hot encode a DNA sequence over the channels A, C, G, T.

    Args:
        seq: Nucleotide sequence as a ``str`` (or any object whose ``str()``
            is the sequence text). Case is ignored.

    Returns:
        ``np.float32`` array of shape ``(len(seq), 4)`` with columns ordered
        A, C, G, T. Rows for every other symbol (for example ``N``) are all
        zeros, so ambiguous bases are no longer encoded as ``A``.
    """
    seq = str(seq).upper()

    # Byte value -> channel index; -1 marks symbols that have no channel.
    channel_of = np.full(256, -1, dtype=np.int64)
    for channel, base in enumerate("ACGT"):
        channel_of[ord(base)] = channel

    # errors="replace" turns each non-ASCII character into a single "?", so the
    # byte array stays aligned one-to-one with the characters of `seq`.
    codes = np.frombuffer(seq.encode("ascii", errors="replace"), dtype=np.uint8)
    channels = channel_of[codes]

    arr = np.zeros((len(seq), 4), dtype=np.float32)
    # Only positions holding A/C/G/T receive a 1.
    rows = np.flatnonzero(channels >= 0)
    arr[rows, channels[rows]] = 1
    return arr

In [10]:
window = 5000          # bases taken on each side of the TSS
seq_len = 2 * window   # length of every encoded sequence
batch_size = 10000  # genes per saved .npz file; adjust if needed

# Keep only genes whose chromosome is available in every source the extraction
# loop reads: both bigWig files and the FASTA dict. Checking the plus-strand
# file alone would let a missing name fail halfway through the loop.
shared_chroms = set(plus_bw.chroms()) & set(minus_bw.chroms()) & set(genome)
primary_genes_df = genes_df[genes_df['chrom'].isin(shared_chroms)].reset_index(drop=True)
# NOTE: `genes` is rebound here from the raw GTF rows to a list of dict records.
genes = primary_genes_df.to_dict('records')

# Integer encoding function
def int_encode(seq):
    """Encode a DNA sequence as integers: A=0, C=1, G=2, T=3, N/other=4.

    Args:
        seq: Nucleotide sequence as a ``str`` (or any object whose ``str()``
            is the sequence text). Case is ignored.

    Returns:
        1-D ``np.int64`` array with one code per character of ``seq``.
    """
    # Byte value -> integer code; everything that is not A/C/G/T maps to 4.
    code_of = np.full(256, 4, dtype=np.int64)
    for code, base in enumerate("ACGT"):
        code_of[ord(base)] = code

    # One table lookup for the whole sequence instead of a Python-level loop.
    # errors="replace" maps each non-ASCII character to a single "?" (code 4).
    raw = np.frombuffer(str(seq).upper().encode("ascii", errors="replace"), dtype=np.uint8)
    return code_of[raw]

# Create the output directory up front so the first save cannot fail on a
# fresh checkout.
os.makedirs("data_batches", exist_ok=True)

# Chromosome sizes do not change between genes; fetch the mapping once.
chrom_sizes = plus_bw.chroms()

for i in range(0, len(genes), batch_size):
    batch_genes = genes[i:i+batch_size]
    X_batch = []
    y_batch = []

    for row in batch_genes:
        chrom = row['chrom']
        strand = row['strand']
        # TSS = gene start on the plus strand, gene end on the minus strand.
        # NOTE(review): start/end are used directly as 0-based slice positions;
        # if the annotation is 1-based the window sits one base off - confirm.
        tss = row['start'] if strand == '+' else row['end']

        chrom_len = chrom_sizes[chrom]
        seq_start = max(0, tss - window)
        seq_end = min(chrom_len, tss + window)

        seq = genome[chrom].seq[seq_start:seq_end]
        # Bases lost on the low-coordinate side when the window was clipped at 0.
        lead_pad = seq_start - (tss - window)
        if strand == '-':
            seq = seq.reverse_complement()
            # After reverse-complementing, the high-coordinate side comes first,
            # so the bases clipped at the chromosome end belong at the front.
            lead_pad = (tss + window) - seq_end

        # Integer encoding, padded with 4 (the N/other code) on whichever side
        # was clipped so the TSS stays at index `window`.
        seq_array = int_encode(str(seq))
        if len(seq_array) < seq_len:
            missing = seq_len - len(seq_array)
            lead_pad = min(lead_pad, missing)
            seq_array = np.pad(seq_array, (lead_pad, missing - lead_pad),
                               mode='constant', constant_values=4)
        X_batch.append(seq_array)

        # CAGE signal: NaNs become 0, both strands are added, and the label is
        # the total over the in-bounds window. Zero padding would not change a
        # sum, so none is applied.
        plus_signal = np.nan_to_num(plus_bw.values(chrom, seq_start, seq_end, numpy=True))
        minus_signal = np.nan_to_num(minus_bw.values(chrom, seq_start, seq_end, numpy=True))
        y_batch.append((plus_signal + minus_signal).sum())

    # Convert to numpy arrays
    X_batch = np.array(X_batch, dtype=np.int64)  # integers
    y_batch = np.array(y_batch, dtype=np.float32)  # labels

    # Save batch compressed to save space
    np.savez_compressed(f"data_batches/data_batch_{i//batch_size}.npz", X=X_batch, y=y_batch)
    print(f"Saved batch {i//batch_size} with {len(batch_genes)} genes")

Saved batch 0 with 10000 genes
Saved batch 1 with 10000 genes
Saved batch 2 with 10000 genes
Saved batch 3 with 10000 genes
Saved batch 4 with 10000 genes
Saved batch 5 with 10000 genes
Saved batch 6 with 10000 genes
Saved batch 7 with 8691 genes


In [11]:
# Ceiling division: the last shard is usually only partly filled, and floor
# division would leave that file (and all of its genes) out of the dataset.
# NOTE: this relies on `batch_size` still holding the genes-per-file value that
# was in effect when the .npz files were written.
num_batches = (len(primary_genes_df) + batch_size - 1) // batch_size
class GeneDataset(Dataset):
    """Map-style dataset over samples stored in several ``.npz`` files.

    Each file must contain an array ``X`` (indexed by sample) and a vector
    ``y`` (one raw label per sample). Items are returned as
    ``(torch.tensor(X[i]), torch.tensor(log1p(y[i])))``.

    A file is read and decompressed at most once per dataset instance; its
    arrays are then kept in memory. Non-negative integer ``X`` arrays whose
    values fit in ``uint8`` are held as ``uint8`` and cast back to their
    original dtype when an item is returned.

    Args:
        batch_files: Sequence of paths to the ``.npz`` files.
    """

    def __init__(self, batch_files):
        self.files = batch_files
        self.index_map = []
        # file index -> (X, original X dtype, y); filled lazily on first access.
        self._cache = {}
        # build index mapping to know which batch/file contains which sample
        for b, f in enumerate(batch_files):
            # The context manager closes the archive once the length is known.
            with np.load(f) as data:
                n_samples = len(data['y'])
            self.index_map.extend((b, i) for i in range(n_samples))

    def _load_shard(self, batch_idx):
        """Return the cached ``(X, x_dtype, y)`` for one file, loading it if needed."""
        shard = self._cache.get(batch_idx)
        if shard is None:
            with np.load(self.files[batch_idx]) as data:
                X = data['X']
                y = data['y']
            x_dtype = X.dtype
            # Compact small non-negative integer codes to one byte per value.
            if (np.issubdtype(x_dtype, np.integer) and X.size
                    and X.min() >= 0 and X.max() <= 255):
                X = X.astype(np.uint8)
            shard = (X, x_dtype, y)
            self._cache[batch_idx] = shard
        return shard

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        batch_idx, i = self.index_map[idx]
        X, x_dtype, y = self._load_shard(batch_idx)
        sample = X[i].astype(x_dtype)  # restore the dtype stored in the file
        label = np.log1p(y[i])  # log1p here
        return torch.tensor(sample), torch.tensor(label)

dataset = GeneDataset([f"data_batches/data_batch_{i}.npz" for i in range(num_batches)])

# 80 / 10 / 10 split; the test set takes whatever remains after rounding.
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Seed the split so the same samples land in the same subset after every
# kernel restart; otherwise membership changes on each run.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)


In [13]:
# NOTE: this rebinds `batch_size`, which the extraction cells use as the number
# of genes per saved file.
batch_size = 32

# Settings shared by all three loaders; only the training loader shuffles.
loader_options = dict(batch_size=batch_size, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

# Draw a single mini-batch to inspect tensor shapes.
first_batch = next(iter(train_loader))
X_batch, y_batch = first_batch
print("X batch shape:", X_batch.shape)  # [batch_size, seq_len]
print("y batch shape:", y_batch.shape)  # [batch_size]

KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionPool(nn.Module):
    """Softmax-weighted pooling over non-overlapping windows of the last axis.

    A 1x1 convolution produces one logit per channel and position. Within each
    window of ``pool_size`` positions the logits are soft-maxed and used to
    take a weighted sum of the inputs.

    Args:
        pool_size: Number of consecutive positions merged into one output.
        channels: Number of input (and output) channels.
    """

    def __init__(self, pool_size, channels):
        super().__init__()
        self.pool_size = pool_size
        self.weight_proj = nn.Conv1d(channels, channels, kernel_size=1)

    def forward(self, x):
        """Pool ``x`` of shape (B, C, L) to (B, C, ceil(L / pool_size)).

        When L is not a multiple of ``pool_size`` the final window is completed
        with padding that receives zero attention weight, so only real
        positions contribute to it.
        """
        # x: (B, C, L)
        B, C, L = x.shape
        logits = self.weight_proj(x)  # (B, C, L)

        remainder = L % self.pool_size
        if remainder:
            pad = self.pool_size - remainder
            x = F.pad(x, (0, pad), value=0.0)
            # Most negative finite value: soft-maxes to exactly 0 without
            # introducing inf/NaN.
            logits = F.pad(logits, (0, pad), value=-torch.finfo(logits.dtype).max)

        num_windows = x.shape[-1] // self.pool_size
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(B, C, num_windows, self.pool_size)
        logits = logits.reshape(B, C, num_windows, self.pool_size)

        weights = torch.softmax(logits, dim=-1)
        return (x * weights).sum(dim=-1)

class ResidualConvBlock(nn.Module):
    """BatchNorm1d -> GELU -> 1x1 Conv1d branch that is added back onto its input.

    Args:
        channels: Channel count of both the input and the output.
    """

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.BatchNorm1d(channels),
            nn.GELU(),
            nn.Conv1d(channels, channels, kernel_size=1),
        ]
        # The attribute name and layer order determine the parameter names
        # (conv.0.*, conv.2.*).
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the branch output; the shape is unchanged."""
        branch = self.conv(x)
        return x + branch

class ConvBlock(nn.Module):
    """BatchNorm1d -> GELU -> Conv1d, padded by ``kernel_size // 2`` on each side.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels produced by the convolution.
        kernel_size: Width of the convolution kernel (default 5).
    """

    def __init__(self, in_channels, out_channels, kernel_size=5):
        super().__init__()
        half_width = kernel_size // 2
        normalise = nn.BatchNorm1d(in_channels)
        activate = nn.GELU()
        convolve = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=half_width,
        )
        self.conv = nn.Sequential(normalise, activate, convolve)

    def forward(self, x):
        """Apply the three layers in order to ``x`` of shape (B, in_channels, L)."""
        out = self.conv(x)
        return out

class ConvTowerBlock(nn.Module):
    """One tower stage: width-5 ConvBlock, ResidualConvBlock, then AttentionPool(2).

    Args:
        in_channels: Channels entering the stage.
        out_channels: Channels produced by the ConvBlock and kept by the rest.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_block = ConvBlock(in_channels, out_channels, kernel_size=5)
        self.res_block = ResidualConvBlock(out_channels)
        self.pool = AttentionPool(pool_size=2, channels=out_channels)

    def forward(self, x):
        """Run the three sub-modules in order on ``x`` of shape (B, in_channels, L)."""
        for stage in (self.conv_block, self.res_block, self.pool):
            x = stage(x)
        return x

class RelativePosition(nn.Module):
    """Learned embedding for every clipped signed offset between two positions.

    Args:
        num_units: Size of each embedding vector.
        max_relative_position: Offsets are clamped to
            ``[-max_relative_position, max_relative_position]``, so the table
            holds ``2 * max_relative_position + 1`` rows.
    """

    def __init__(self, num_units, max_relative_position):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        table_rows = max_relative_position * 2 + 1
        self.embeddings_table = nn.Parameter(torch.Tensor(table_rows, num_units))
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length, device):
        """Return a ``(length, length, num_units)`` tensor of offset embeddings.

        Entry ``[i, j]`` is the table row for the offset ``j - i`` after clamping.
        """
        positions = torch.arange(length, device=device)
        # offsets[i, j] = j - i
        offsets = positions.unsqueeze(0) - positions.unsqueeze(1)
        limit = self.max_relative_position
        # Shift the clamped offsets so they start at 0 and can index the table.
        row_index = (torch.clamp(offsets, -limit, limit) + limit).long()
        return self.embeddings_table[row_index]

class MultiHeadAttentionRel(nn.Module):
    """Multi-head self-attention with an added relative-position score term.

    The attention logits are the sum of the query-key scores and the scores of
    each query against the relative-position embeddings, both divided by
    ``sqrt(key_size)``.

    Args:
        d_model: Feature size of the input and of the output.
        num_heads: Number of attention heads.
        key_size: Per-head size used for queries, keys and values (default 64).
    """

    def __init__(self, d_model, num_heads, key_size=64):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.key_size = key_size
        self.inner_dim = key_size * num_heads

        self.q_proj = nn.Linear(d_model, self.inner_dim)
        self.k_proj = nn.Linear(d_model, self.inner_dim)
        self.v_proj = nn.Linear(d_model, self.inner_dim)
        self.out_proj = nn.Linear(self.inner_dim, d_model)

        # One embedding table per attention module, shared by all of its heads.
        self.rel_pos = RelativePosition(key_size, max_relative_position=256)

    def forward(self, x):
        """Attend over ``x`` of shape (B, L, d_model); the output has the same shape."""
        B, L, C = x.shape
        head_shape = (B, L, self.num_heads, self.key_size)

        # (B, L, H*K) -> (B, H, L, K) for each projection.
        q = self.q_proj(x).view(*head_shape).transpose(1, 2)
        k = self.k_proj(x).view(*head_shape).transpose(1, 2)
        v = self.v_proj(x).view(*head_shape).transpose(1, 2)

        scale = self.key_size ** 0.5

        # Content term: q @ k.T -> (B, H, L, L)
        content_scores = torch.matmul(q, k.transpose(-2, -1)) / scale

        # Position term: every query against the (L, L, K) offset embeddings.
        rel_pos_embed = self.rel_pos(L, x.device)
        position_scores = torch.einsum('bhlk,lmk->bhlm', q, rel_pos_embed) / scale

        attn = torch.softmax(content_scores + position_scores, dim=-1)

        # (B, H, L, K) -> (B, L, H*K) before the output projection.
        context = torch.matmul(attn, v)
        merged = context.transpose(1, 2).contiguous().view(B, L, self.inner_dim)
        return self.out_proj(merged)

class TransformerBlock(nn.Module):
    """Residual attention sub-layer followed by a residual feed-forward sub-layer.

    LayerNorm is applied to the input of each sub-layer, and the sub-layer
    output is added to the un-normalised input.

    Args:
        d_model: Feature size of the block input and output.
        num_heads: Attention heads passed to ``MultiHeadAttentionRel``.
        key_size: Per-head size passed to ``MultiHeadAttentionRel``.
        ff_hidden: Width of the hidden feed-forward layer.
    """

    def __init__(self, d_model=1536, num_heads=8, key_size=64, ff_hidden=6144):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.mha = MultiHeadAttentionRel(d_model, num_heads, key_size)
        self.norm2 = nn.LayerNorm(d_model)
        expand = nn.Linear(d_model, ff_hidden)
        project = nn.Linear(ff_hidden, d_model)
        self.ffn = nn.Sequential(expand, nn.GELU(), project)

    def forward(self, x):
        """Transform ``x`` of shape (B, L, d_model) without changing its shape."""
        attended = self.mha(self.norm1(x))
        x = x + attended
        fed_forward = self.ffn(self.norm2(x))
        return x + fed_forward

class Enformer(nn.Module):
    """Convolutional stem and tower, transformer stack, crop, pointwise block and head.

    The stem and the six tower blocks each contain one ``AttentionPool`` with
    ``pool_size=2``, so the sequence axis is halved seven times before the
    transformer stack.

    Args:
        crop_bins: Number of positions removed from each end of the transformer
            output before the pointwise block. The default of 320 reproduces
            the fixed ``1536 -> 896`` crop used for 196608-long inputs.

    Raises:
        ValueError: If ``crop_bins`` is negative (at construction), or if the
            sequence reaching the crop is not longer than ``2 * crop_bins``
            (in ``forward``).
    """

    def __init__(self, crop_bins=320):
        super().__init__()
        if crop_bins < 0:
            raise ValueError(f"crop_bins must be non-negative, got {crop_bins}")
        self.crop_bins = crop_bins

        # 1. Initial feature extraction and downsampling
        self.stem = nn.Sequential(
            nn.Conv1d(4, 768, kernel_size=15, padding=7),
            ResidualConvBlock(768),
            AttentionPool(pool_size=2, channels=768)
        )

        # 2. Hierarchical feature extraction + downsampling (Conv Tower)
        # 6 blocks, one per consecutive pair of channel widths
        tower_channels = [768, 896, 1024, 1152, 1280, 1408, 1536]
        self.conv_tower = nn.Sequential(*[
            ConvTowerBlock(tower_channels[i], tower_channels[i+1])
            for i in range(len(tower_channels)-1)
        ])

        # 3. Transformer Block (x11 layers)
        self.transformers = nn.Sequential(*[
            TransformerBlock(d_model=1536, num_heads=8, key_size=64, ff_hidden=6144)
            for _ in range(11)
        ])

        # 4. Pointwise Block
        self.pointwise = nn.Sequential(
            nn.Conv1d(1536, 3072, kernel_size=1),
            nn.Dropout(0.05),
            nn.GELU()
        )

        # 5. Output Heads (Human head)
        self.human_head = nn.Sequential(
            nn.Conv1d(3072, 5313, kernel_size=1),
            nn.Softplus()
        )

    def forward(self, x):
        """Map one-hot input (B, 4, L) to per-position outputs (B, positions, 5313).

        The shape comments below are for the reference input length of 196608.
        """
        # Stem
        x = self.stem(x) # Output: (B, 768, 98304)

        # Conv Tower
        x = self.conv_tower(x) # Output: (B, 1536, 1536)

        # Transformer expects (B, L, C)
        x = x.transpose(1, 2) # Output: (B, 1536, 1536)
        x = self.transformers(x) # Output: (B, 1536, 1536)
        x = x.transpose(1, 2) # Output: (B, 1536, 1536)

        # Crop `crop_bins` positions from each end (default: 1536 -> 896).
        bins = x.shape[-1]
        if bins <= 2 * self.crop_bins:
            raise ValueError(
                f"Sequence has {bins} positions after downsampling, which is not "
                f"more than 2 * crop_bins = {2 * self.crop_bins}; "
                "use a longer input or a smaller crop_bins."
            )
        # An explicit end index keeps crop_bins=0 valid (a `:-0` slice would be empty).
        x = x[:, :, self.crop_bins:bins - self.crop_bins] # Output: (B, 1536, 896)

        # Pointwise Block
        x = self.pointwise(x) # Output: (B, 3072, 896)

        # Human Head
        x = self.human_head(x) # Output: (B, 5313, 896)

        # Final Output shape: (B, 896, 5313)
        return x.transpose(1, 2)

# Example usage:
# model = Enformer()
# x = torch.randn(1, 4, 196608)
# y = model(x)
# print("Output shape:", y.shape) # Expected: torch.Size([1, 896, 5313])

In [None]:
# To initialize weights from the pre-trained Enformer model, we can use the `enformer-pytorch` library.
# First, make sure it's installed: %pip install enformer-pytorch

from enformer_pytorch import Enformer as PretrainedEnformer

def _copy_tensor(dst, src, name):
    """Copy ``src`` into ``dst`` in place, refusing incompatible shapes.

    Shapes that differ only by size-1 axes (for example a Linear weight going
    into a kernel-size-1 convolution) are reshaped; anything else raises.
    """
    if dst.shape != src.shape:
        dst_core = [d for d in dst.shape if d != 1]
        src_core = [d for d in src.shape if d != 1]
        if dst_core != src_core:
            raise ValueError(
                f"Cannot transfer {name}: custom shape {tuple(dst.shape)} "
                f"does not match pretrained shape {tuple(src.shape)}"
            )
        src = src.reshape(dst.shape)
    with torch.no_grad():
        dst.copy_(src)


def _copy_weight_and_bias(dst, src, name):
    """Copy ``weight`` and, when present, ``bias`` from ``src`` into ``dst``.

    If ``dst`` has a bias but ``src`` has none, the bias is zeroed so the layer
    computes the same function as the bias-free source layer.
    """
    _copy_tensor(dst.weight, src.weight, f"{name}.weight")
    dst_bias = getattr(dst, "bias", None)
    if dst_bias is None:
        return
    src_bias = getattr(src, "bias", None)
    if src_bias is None:
        with torch.no_grad():
            dst_bias.zero_()
    else:
        _copy_tensor(dst_bias, src_bias, f"{name}.bias")


def _copy_batchnorm(dst, src, name):
    """Copy affine parameters and running statistics of a batch-norm layer."""
    _copy_weight_and_bias(dst, src, name)
    for buffer_name in ("running_mean", "running_var", "num_batches_tracked"):
        dst_buffer = getattr(dst, buffer_name, None)
        src_buffer = getattr(src, buffer_name, None)
        if dst_buffer is not None and src_buffer is not None:
            _copy_tensor(dst_buffer, src_buffer, f"{name}.{buffer_name}")


def load_pretrained_weights(custom_model):
    """Copy weights from the ``enformer-pytorch`` checkpoint into ``custom_model``.

    Compared with assigning ``.data`` directly, every transfer is shape-checked,
    batch-norm running statistics are copied along with the affine parameters,
    and projection biases are copied (or zeroed when the source layer has no
    bias) instead of keeping their random initialisation.

    NOTE(review): the attribute paths into the pretrained model (``.fn``,
    ``.pool_fn``, ``.norm``, ``.net``, ``.rel_pos.rel_pos``,
    ``crop_and_pointwise``, ``_heads``) are carried over unchanged and could
    not be checked here; verify them, and the layer sizes of the custom model,
    against the installed enformer-pytorch release.

    Args:
        custom_model: Instance of the custom ``Enformer`` defined in this notebook.

    Returns:
        The same ``custom_model`` object, updated in place.

    Raises:
        ValueError: If a source tensor's shape is incompatible with its target.
    """
    print("Loading pre-trained Enformer model...")
    # Load the official pre-trained weights
    pretrained_model = PretrainedEnformer.from_pretrained('EleutherAI/enformer-official-rough')

    print("Transferring weights to custom model...")

    # 1. Stem: first convolution, residual conv block, attention pool
    _copy_weight_and_bias(custom_model.stem[0], pretrained_model.stem[0], "stem[0]")
    _copy_batchnorm(custom_model.stem[1].conv[0], pretrained_model.stem[1].fn[0], "stem[1].conv[0]")
    _copy_weight_and_bias(custom_model.stem[1].conv[2], pretrained_model.stem[1].fn[2], "stem[1].conv[2]")
    _copy_weight_and_bias(custom_model.stem[2].weight_proj, pretrained_model.stem[2].pool_fn, "stem[2].weight_proj")

    # 2. Conv tower
    for i in range(len(custom_model.conv_tower)):
        dst = custom_model.conv_tower[i]
        src = pretrained_model.conv_tower[i]
        tag = f"conv_tower[{i}]"

        # ConvBlock
        _copy_batchnorm(dst.conv_block.conv[0], src[0].fn[0], f"{tag}.conv_block.conv[0]")
        _copy_weight_and_bias(dst.conv_block.conv[2], src[0].fn[2], f"{tag}.conv_block.conv[2]")

        # ResidualConvBlock
        _copy_batchnorm(dst.res_block.conv[0], src[1].fn[0], f"{tag}.res_block.conv[0]")
        _copy_weight_and_bias(dst.res_block.conv[2], src[1].fn[2], f"{tag}.res_block.conv[2]")

        # AttentionPool
        _copy_weight_and_bias(dst.pool.weight_proj, src[2].pool_fn, f"{tag}.pool.weight_proj")

    # 3. Transformer layers
    for i in range(len(custom_model.transformers)):
        dst = custom_model.transformers[i]
        src = pretrained_model.transformer[i]
        tag = f"transformers[{i}]"

        # LayerNorms
        _copy_weight_and_bias(dst.norm1, src[0].norm, f"{tag}.norm1")
        _copy_weight_and_bias(dst.norm2, src[1].norm, f"{tag}.norm2")

        # Attention projections (biases are copied or zeroed as well)
        attention = src[0].fn
        _copy_weight_and_bias(dst.mha.q_proj, attention.to_q, f"{tag}.mha.q_proj")
        _copy_weight_and_bias(dst.mha.k_proj, attention.to_k, f"{tag}.mha.k_proj")
        _copy_weight_and_bias(dst.mha.v_proj, attention.to_v, f"{tag}.mha.v_proj")
        _copy_weight_and_bias(dst.mha.out_proj, attention.to_out, f"{tag}.mha.out_proj")

        # Relative positional encoding
        _copy_tensor(dst.mha.rel_pos.embeddings_table, attention.rel_pos.rel_pos,
                     f"{tag}.mha.rel_pos.embeddings_table")

        # Feed-forward network
        feed_forward = src[1].fn.net
        _copy_weight_and_bias(dst.ffn[0], feed_forward[0], f"{tag}.ffn[0]")
        _copy_weight_and_bias(dst.ffn[2], feed_forward[3], f"{tag}.ffn[2]")

    # 4. Pointwise block
    _copy_weight_and_bias(custom_model.pointwise[0], pretrained_model.crop_and_pointwise[1], "pointwise[0]")

    # 5. Human head
    # Note: the original code comments that the official head is structured
    # slightly differently; only its first layer is transferred here.
    _copy_weight_and_bias(custom_model.human_head[0], pretrained_model._heads['human'][0], "human_head[0]")

    print("Weights transferred successfully!")
    return custom_model

# Example usage:
# model = Enformer()
# model = load_pretrained_weights(model)