In [20]:
from Bio import SeqIO
import pandas as pd
import pyBigWig
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
import torch


In [11]:
genome_file = "GRCh38.primary_assembly.genome.fa"
genome = SeqIO.to_dict(SeqIO.parse(genome_file, "fasta"))
print("Chromosomes loaded:", list(genome.keys())[:5])

Chromosomes loaded: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5']


In [12]:

gtf_file = "gencode.v49.primary_assembly.basic.annotation.gtf"
genes = []

with open(gtf_file) as f:
    for line in f:
        if line.startswith("#"): continue
        fields = line.strip().split("\t")
        if fields[2] != "gene": continue
        chrom = fields[0]
        start = int(fields[3])
        end = int(fields[4])
        strand = fields[6]
        # Parse gene_id from attributes
        attr = fields[8]
        gene_id = attr.split('gene_id "')[1].split('"')[0]
        genes.append([gene_id, chrom, start, end, strand])

genes_df = pd.DataFrame(genes, columns=["gene_id","chrom","start","end","strand"])
print("Number of genes:", len(genes_df))
genes_df.head()

Number of genes: 78899


Unnamed: 0,gene_id,chrom,start,end,strand
0,ENSG00000290825.2,chr1,11121,24894,+
1,ENSG00000223972.6,chr1,12010,13670,+
2,ENSG00000310526.1,chr1,14356,30744,-
3,ENSG00000227232.6,chr1,14696,24886,-
4,ENSG00000278267.1,chr1,17369,17436,-


In [13]:
plus_bw = pyBigWig.open("K562_plus_unique.bigWig")
minus_bw = pyBigWig.open("K562_minus_unique.bigWig")

In [14]:
print(list(genome.keys())[:5])
print(list(plus_bw.chroms().keys())[:5])

['chr1', 'chr2', 'chr3', 'chr4', 'chr5']
['chr1', 'chr2', 'chr3', 'chr4', 'chr5']


In [15]:

def one_hot(seq):
    mapping = np.zeros(128, dtype=np.int8)  # ASCII table
    mapping[ord('A')] = 0
    mapping[ord('C')] = 1
    mapping[ord('G')] = 2
    mapping[ord('T')] = 3
    
    seq = seq.upper()
    arr = np.zeros((len(seq),4), dtype=np.float32)
    # only convert A/C/G/T
    valid_idx = np.array([ord(b) for b in seq])
    arr[np.arange(len(seq)), mapping[valid_idx]] = 1
    return arr

In [None]:
window = 5000
seq_len = 2 * window
batch_size = 10000  # adjust if needed

# Keep only genes whose chromosomes exist in the bigWigs
primary_genes_df = genes_df[genes_df['chrom'].isin(plus_bw.chroms().keys())].reset_index(drop=True)
genes = primary_genes_df.to_dict('records')

# Integer encoding function
def int_encode(seq):
    """Encode DNA sequence as integers: A=0, C=1, G=2, T=3, N/other=4"""
    seq = seq.upper()
    mapping = {'A':0, 'C':1, 'G':2, 'T':3}
    return np.array([mapping.get(b, 4) for b in seq], dtype=np.int64)

for i in range(0, len(genes), batch_size):
    batch_genes = genes[i:i+batch_size]
    X_batch = []
    y_batch = []
    
    for row in batch_genes:
        chrom = row['chrom']
        strand = row['strand']
        tss = row['start'] if strand == '+' else row['end']
        
        chrom_len = plus_bw.chroms()[chrom]
        seq_start = max(0, tss - window)
        seq_end = min(chrom_len, tss + window)
        
        seq = genome[chrom].seq[seq_start:seq_end]
        if strand == '-':
            seq = seq.reverse_complement()
        
        # Integer encoding with padding
        seq_array = int_encode(str(seq))
        if len(seq_array) < seq_len:
            seq_array = np.pad(seq_array, (0, seq_len - len(seq_array)), mode='constant', constant_values=4)
        X_batch.append(seq_array)
        
        # CAGE signal
        plus_signal = np.nan_to_num(plus_bw.values(chrom, seq_start, seq_end, numpy=True))
        minus_signal = np.nan_to_num(minus_bw.values(chrom, seq_start, seq_end, numpy=True))
        total_signal = plus_signal + minus_signal
        if len(total_signal) < seq_len:
            total_signal = np.pad(total_signal, (0, seq_len - len(total_signal)), mode='constant')
        y_batch.append(total_signal.sum())
    
    # Convert to numpy arrays
    X_batch = np.array(X_batch, dtype=np.int64)  # integers
    y_batch = np.array(y_batch, dtype=np.float32)  # labels

    # Save batch compressed to save space
    np.savez_compressed(f"data_batches/data_batch_{i//batch_size}.npz", X=X_batch, y=y_batch)
    print(f"Saved batch {i//batch_size} with {len(batch_genes)} genes")

Saved batch 0 with 10000 genes
Saved batch 1 with 10000 genes
Saved batch 2 with 10000 genes
Saved batch 3 with 10000 genes
Saved batch 4 with 10000 genes


In [None]:
num_batches = len(primary_genes_df) // batch_size
class GeneDataset(Dataset):
    def __init__(self, batch_files):
        self.files = batch_files
        self.index_map = []
        # build index mapping to know which batch/file contains which sample
        for b, f in enumerate(batch_files):
            data = np.load(f)
            for i in range(len(data['y'])):
                self.index_map.append((b, i))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        batch_idx, i = self.index_map[idx]
        data = np.load(self.files[batch_idx])
        X = data['X'][i]
        y = np.log1p(data['y'][i])  # log1p here
        return torch.tensor(X), torch.tensor(y)

dataset = GeneDataset([f"data_batches/data_batch_{i}.npz" for i in range(num_batches)])
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])


In [24]:
batch_size = 32

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

X_batch, y_batch = next(iter(train_loader))
print("X batch shape:", X_batch.shape)  # [batch_size, seq_len, 4]
print("y batch shape:", y_batch.shape)  # [batch_size]

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/homebrew/Caskroom/miniconda/base/envs/py10/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/homebrew/Caskroom/miniconda/base/envs/py10/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'GeneDataset' on <module '__main__' (built-in)>


KeyboardInterrupt: 