In [None]:
from google.colab import drive
import os

# Mount Google Drive so the data paths under /content/drive resolve.
drive.mount(mountpoint='/content/drive')


Mounted at /content/drive


In [None]:
# ================================================================
# Colab GPU-ready, fast nBins version (paste-and-run)
# - Uses GPU for the model
# - Uses bigWig nBins stats for FAST per-bin ATAC/CTCF sampling
# - Has MAX_SAMPLES guard for quick smoke tests
# - Saves per-bin DNA & Epi embeddings to Parquet shards
# ================================================================

# ---------- Colab installs ----------
!pip -q install pyBigWig pyarrow scikit-image

# ---------- CONFIG ----------
from pathlib import Path

# Folder layout expected:
#   {CELLTYPE_ROOT}/genomic_features/atac.bw
#   {CELLTYPE_ROOT}/genomic_features/ctcf_log2fc.bw
#   {CELLTYPE_ROOT}/hic_matrix/chr1.npz, chr2.npz, ...
#   {CELLTYPE_ROOT}/../dna_sequence/chr1.fa.gz, chr2.fa.gz, ...
#   {CELLTYPE_ROOT}/../centrotelo.bed   (optional)
CELLTYPE_ROOT = "/content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90"      # <-- EDIT
ASSEMBLY      = "hg38"
CELLTYPE      = "IMR90"

ATAC_BW       = f"{CELLTYPE_ROOT}/genomic_features/atac.bw"
CTCF_BW       = f"{CELLTYPE_ROOT}/genomic_features/ctcf_log2fc.bw"
DNA_DIR       = f"{CELLTYPE_ROOT}/../dna_sequence"
HIC_DIR       = f"{CELLTYPE_ROOT}/hic_matrix"
CENTROTELO_BED= f"{CELLTYPE_ROOT}/../centrotelo.bed"   # set to None if you don't have it

OUT_DIR       = "/content/drive/MyDrive/ML4GEN DATA/data - IMR90/embeddings_out"              # Parquet outputs

# Window/bin params (aligned to your dataset).
# RES_BP must equal the bin size of the Hi-C .npz files: ChromosomeDataset
# asserts that len(sequence) / RES_BP is within 2 of the Hi-C bin count.
RES_BP        = 10000               # 10 kb per bin
WINDOW_BP     = 2_097_152            # ~2.097 Mb (2^21)
L_BINS        = int(WINDOW_BP // RES_BP)  # 209 whole bins; the trailing WINDOW_BP % RES_BP bp are not embedded
SAMPLE_BINS   = 500                  # 5 Mb proposed frame (500 * 10 kb)
STRIDE_BP     = 500_000              # 500 kb step between proposed frames
STRIDE_BINS   = max(1, round(STRIDE_BP / RES_BP))  # 50 bins at 10 kb (61 only if RES_BP were 8,192)
IMAGE_SCALE   = 256                  # Hi-C resized for fixed-size models (kept for completeness)

# Limit how many samples to process (set None to run all)
MAX_SAMPLES   = 300                  # <-- change to None for full run

# Choose chromosomes (None -> default split; for quick test use ["chr10"])
CHROMS        = ["chr10"]            # <-- edit or set to None

# Output sharding
ROWS_PER_SHARD= 10_000

# Model dimensions
DNA_EMB_DIM   = 128
EPI_EMB_DIM   = 128

# Optional checkpoint (for this simple model). Leave None for random weights.
CKPT_PATH     = None

# ================================================================
# Imports & GPU setup
# ================================================================
import os, io, gzip, math, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pyBigWig as pbw
from skimage.transform import resize

from typing import List, Tuple

# Ensure GPU
# Fail fast: this notebook expects a CUDA GPU.
if not torch.cuda.is_available():
    raise AssertionError("No GPU detected. In Colab: Runtime > Change runtime type > Hardware accelerator: GPU")
device = torch.device("cuda")
# Windows have a fixed size, so letting cuDNN autotune conv algorithms pays off.
torch.backends.cudnn.benchmark = True
print("Using device:", device)

# ================================================================
# Feature I/O helpers (self-contained)
# ================================================================
class Feature:
    """Abstract base for per-chromosome data sources.

    The constructor forwards every keyword argument to ``load``; subclasses
    implement ``load``, ``get`` and ``__len__`` (all raise NotImplementedError here).
    """

    def __init__(self, **kwargs):
        self.load(**kwargs)

    def load(self, **kwargs):
        raise NotImplementedError

    def get(self, *args, **kwargs):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

class HiCFeature(Feature):
    """Hi-C contacts for one chromosome, stored as diagonals in an .npz archive.

    The archive maps str(k) -> 1-D array for diagonal offset k (k = col - row),
    where element m of diagonal k is the contact between bins m and m + |k|.
    """
    def load(self, path=None):
        self.hic = self._load_npz(path)

    def get(self, start, window=WINDOW_BP, res=RES_BP):
        """Dense float32 [n, n] block with n = int(window / res), first bin int(start / res)."""
        start_bin = int(start / res)
        range_bin = int(window / res)
        end_bin   = start_bin + range_bin
        return self._diag_to_mat(self.hic, start_bin, end_bin)

    def _load_npz(self, path):
        print(f"Reading Hi-C: {path}")
        # Materialize every array, then close the archive (np.load keeps the file open otherwise).
        with np.load(path) as npz:
            return dict(npz)

    def _diag_to_mat(self, ori_load, start, end):
        """Rebuild the dense square block for bins [start, end) from stored diagonals.

        Entry (i, j) is ori_load[str(j - i)][start + min(i, j)]. Diagonals missing
        from ``ori_load`` are treated as all-zero.

        Raises:
            IndexError: if a stored diagonal is too short to cover the block
                (e.g. the block runs past the end of the chromosome).
        """
        square_len = end - start
        out = np.zeros((square_len, square_len), dtype=np.float32)
        for offset in range(-(square_len - 1), square_len):
            diag = ori_load.get(str(offset))
            if diag is None:
                continue  # missing diagonal -> leave zeros
            n = square_len - abs(offset)
            seg = diag[start : start + n]
            if seg.shape[0] < n:
                raise IndexError(
                    f"Hi-C diagonal {offset} has {seg.shape[0]} values from bin {start}, need {n}")
            pos = np.arange(n)
            if offset >= 0:
                out[pos, pos + offset] = seg   # upper triangle: row = pos
            else:
                out[pos - offset, pos] = seg   # lower triangle: col = pos
        return out

    def __len__(self):
        return len(self.hic['0'])

class GenomicFeature(Feature):
    """Handle on a bigWig track, used only for chromosome-length checks.

    Values are not read through this class: per-bin means are fetched in the
    main loop with pyBigWig ``stats(..., nBins=...)``, so ``get`` raises.
    """

    def __init__(self, path, norm):
        # Feature.__init__ is bypassed on purpose: nothing is loaded up front.
        self.path, self.norm = path, norm
        print(f"Feature path: {path}\nNormalization status: {norm}")

    def load(self, **kwargs):
        pass

    def get(self, *args, **kwargs):
        raise NotImplementedError("Use bigWig nBins path in main loop")

    def length(self, chr_name):
        """Length of ``chr_name`` as reported by the bigWig header."""
        with pbw.open(self.path) as handle:
            chrom_len = handle.chroms(chr_name)
        return chrom_len

class SequenceFeature(Feature):
    """One chromosome's DNA from a gzipped FASTA, served as one-hot windows.

    The whole sequence is held in memory as a lowercase string.
    """

    # Byte value -> one-hot column in [a, t, c, g, n] order; every other byte -> n (4).
    _ONEHOT_LUT = np.full(256, 4, dtype=np.intp)
    _ONEHOT_LUT[np.frombuffer(b"atcg", dtype=np.uint8)] = np.arange(4)

    def load(self, path=None):
        self.seq = self._read_seq(path)

    def get(self, start, end):
        """One-hot encoding of seq[start:end] as float32 [n, 5], n = len(seq[start:end])."""
        seq = self._slice(self.seq, start, end)
        return self._onehot(seq)

    def __len__(self):
        return len(self.seq)

    def _read_seq(self, dna_path):
        """Read a single-record gzipped FASTA and return its sequence, lowercased.

        Everything up to the first newline (the '>' header) is dropped and line
        breaks are removed. NOTE(review): a multi-record FASTA would keep later
        '>' header lines inside the sequence.
        """
        print(f"Reading sequence: {dna_path}")
        with gzip.open(dna_path, "r") as f:
            raw = f.read().decode("utf-8")
        raw = raw[raw.find('\n')+1:]
        # Strip '\r' as well: CRLF files would otherwise insert one 'n'-coded base per line.
        return raw.replace('\r', '').replace('\n', '').lower()

    def _slice(self, seq, start, end):
        return seq[start:end]

    def _onehot(self, seq):
        """One-hot encode ``seq`` as float32 [len(seq), 5] in [a, t, c, g, n] column order.

        Any character other than lowercase a/t/c/g goes to column 4.
        """
        # 'replace' keeps exactly one byte per character (non-ASCII -> '?', mapped to n),
        # so row i still corresponds to seq[i]. Vectorized: no per-character Python loop.
        codes = np.frombuffer(seq.encode("ascii", "replace"), dtype=np.uint8)
        out = np.zeros((len(seq), 5), dtype=np.float32)
        out[np.arange(len(seq)), self._ONEHOT_LUT[codes]] = 1.0
        return out

# ================================================================
# Datasets
# ================================================================
class ChromosomeDataset(torch.utils.data.Dataset):
    """
    Provides (sequence, features[], Hi-C) windows across a chromosome.
    Features[] are placeholders here; we'll fetch ATAC/CTCF via nBins in main.

    Candidate frames of SAMPLE_BINS bins start every STRIDE_BINS bins; frames
    overlapping any omit region are dropped. Each item cuts a window of
    int((WINDOW_BP / RES_BP) * RES_BP) bp (~WINDOW_BP) out of a frame: at the
    frame start when use_aug is False, at a random offset when it is True.

    __getitem__ returns (seq, [], mat, start, end):
      seq        -- float32 one-hot [window_bp, 5], columns a, t, c, g, n
      mat        -- log1p of the Hi-C block resized to [IMAGE_SCALE, IMAGE_SCALE]
      start, end -- window coordinates in bp
    """
    def __init__(self, celltype_root, chr_name, omit_regions, feature_list, use_aug=True):
        self.use_aug      = use_aug
        self.res          = RES_BP
        self.bins_float   = WINDOW_BP / self.res  # window length in (fractional) bins
        self.image_scale  = IMAGE_SCALE
        self.sample_bins  = SAMPLE_BINS
        self.stride       = STRIDE_BINS
        self.chr_name     = chr_name

        print(f"Loading chromosome {chr_name}…")
        dna_path = f"{celltype_root}/../dna_sequence/{chr_name}.fa.gz"
        hic_path = f"{celltype_root}/hic_matrix/{chr_name}.npz"
        self.seq = SequenceFeature(path=dna_path)  # reads the whole chromosome sequence into memory
        self.mat = HiCFeature(path=hic_path)       # reads all stored Hi-C diagonals into memory
        self.genomic_features = feature_list  # not used for values here; just for length checks

        # No omit regions -> empty (0, 2) array.
        self.omit_regions = omit_regions if omit_regions is not None else np.zeros((0,2), dtype=int)
        self._check_lengths()
        self.all_intervals = self._get_active_intervals()
        self.intervals     = self._filter(self.all_intervals, self.omit_regions)

    def __len__(self): return len(self.intervals)

    def __getitem__(self, idx):
        start, end = self.intervals[idx]
        # ~WINDOW_BP; computed via the float bin count, so it may differ from WINDOW_BP by float rounding.
        target_size = int(self.bins_float * self.res)

        if self.use_aug:
            start, end = self._shift_aug(target_size, start, end)
        else:
            start, end = self._shift_fix(target_size, start, end)

        # DNA one-hot [L_bp,5]
        seq = self.seq.get(start, end)

        # Hi-C region (optional for training workflows): int(WINDOW_BP / res) bins square,
        # then resized to image_scale x image_scale and log1p-transformed.
        mat = self.mat.get(start, window=WINDOW_BP, res=self.res)
        mat = resize(mat, (self.image_scale, self.image_scale), anti_aliasing=True).astype(np.float32)
        mat = np.log1p(mat)

        # Optional augmentations on seq & mat (not enabling for val/test speed).
        # The order of np.random draws below matters for reproducibility under a fixed seed.
        if self.use_aug:
            seq = self._gaussian(seq, 0.1)
            # reverse/complement + flip mat w.p. 0.5
            # NOTE(review): complement is only applied when the sequence was also
            # reversed (nested draw); confirm this coupling is intended.
            if np.random.rand() < 0.5:
                seq = np.flip(seq, 0).copy()
                if np.random.rand() < 0.5:
                    # Swap a<->t and c<->g columns (base complement); n stays put.
                    a,t,c,g,n = seq[:,0:1], seq[:,1:2], seq[:,2:3], seq[:,3:4], seq[:,4:5]
                    seq = np.concatenate([t,a,g,c,n], axis=1)
                mat = np.flip(mat, (0,1)).copy()

        # features[] placeholder kept for API compatibility
        features_placeholder = []
        return seq, features_placeholder, mat, start, end

    # helpers
    def _gaussian(self, arr, std=0.1):
        # Adds float32 Gaussian noise with standard deviation ``std`` elementwise.
        return arr + np.random.randn(*arr.shape).astype(np.float32) * std
    def _get_active_intervals(self):
        # Frame starts at 0, stride, 2*stride, ... bins. int() truncation means a final
        # frame that would still fit may be left out. Returns int bp (start, end) pairs [n, 2].
        chr_bins = len(self.seq) / self.res
        data_size = int((chr_bins - self.sample_bins) / self.stride)
        starts = (np.arange(data_size).reshape(-1,1) * self.stride)
        intervals_bin = np.concatenate([starts, starts + self.sample_bins], axis=1)
        return (intervals_bin * self.res).astype(int)
    def _filter(self, intervals, omit_regions):
        # Keep frames that overlap no omit region (endpoints compared inclusively, so a
        # frame that merely touches a region is dropped). Returns a list of [start, end] ints.
        if omit_regions is None or len(omit_regions)==0:
            return intervals.tolist()
        valid = []
        for start, end in intervals:
            start_cond = start <= omit_regions[:,1]
            end_cond   = omit_regions[:,0] <= end
            if int(np.sum(start_cond * end_cond)) == 0:
                valid.append([int(start), int(end)])
        return valid
    def _shift_aug(self, target_size, start, end):
        # Random window start inside the frame; offset from Python's `random` in [0, max(1, slack)).
        max_off = max(1, (end - start - target_size))
        offset = random.randrange(max_off)
        return start + offset, start + offset + target_size
    def _shift_fix(self, target_size, start, end):
        # Deterministic: window starts at the frame start (``end`` is ignored).
        return start, start + target_size
    def _check_lengths(self):
        # Sequence length must equal the first feature's chromosome length from the bigWig
        # header, and sequence bins must be within 2 of the Hi-C main-diagonal length.
        if len(self.genomic_features) > 0:
            f0_len = self.genomic_features[0].length(self.chr_name)
            assert len(self.seq) == f0_len, f"Sequence {len(self.seq)} vs first feature {f0_len} mismatch."
        dna_bins = len(self.seq) / self.res
        hic_bins = len(self.mat)
        assert abs(dna_bins - hic_bins) < 2, f"DNA bins {dna_bins} vs Hi-C bins {hic_bins} mismatch."

class GenomeDataset(torch.utils.data.Dataset):
    """
    Windows from several chromosomes, concatenated into one index space.

    Default split:
      train -> autosomes except chr10/chr15 and excluding chrX
      val   -> chr10
      test  -> chr15
    A non-None module-level CHROMS list replaces the split's chromosomes.

    __getitem__ returns (seq, features_placeholder, mat, start, end, chr_name, local_idx).
    """
    def __init__(self, celltype_root, assembly, feat_dicts, mode="val", use_aug=False):
        self.data_root = celltype_root
        # Augmentation is only honoured for the training split.
        self.use_aug   = use_aug if mode=="train" else False

        self.chr_names = self._enumerate_chrs(assembly)
        if mode == "train":
            for drop in ["chr10","chr15","chrX"]:
                if drop in self.chr_names:
                    self.chr_names.remove(drop)
        elif mode == "val":
            self.chr_names = ["chr10"]
        elif mode == "test":
            self.chr_names = ["chr15"]
        else:
            raise ValueError(f"Unknown mode: {mode}")

        # Override with manual CHROMS if provided. The global is only read, and it is
        # copied so this dataset never aliases the config list.
        if CHROMS is not None:
            self.chr_names = list(CHROMS)

        # Feature objects (for length checks; values fetched via nBins later)
        self.genomic_features = []
        for d in feat_dicts.values():
            self.genomic_features.append(GenomicFeature(f"{celltype_root}/genomic_features/{Path(d['file_name']).name}", d['norm']))

        # Omit regions from the BED file: {chrom: int array [n, 2] of (start, end)}
        if CENTROTELO_BED and os.path.exists(CENTROTELO_BED):
            omit_dict = self._proc_bed(CENTROTELO_BED)
        else:
            print("No centrotelo bed provided; proceeding without region masking.")
            omit_dict = {name: np.zeros((0,2), dtype=int) for name in self.chr_names}

        print("Loading chromosome datasets…")
        self.chr_data, self.lengths = {}, []
        for chr_name in self.chr_names:
            ds = ChromosomeDataset(self.data_root, chr_name, omit_dict.get(chr_name, None),
                                   self.genomic_features, use_aug=self.use_aug)
            self.chr_data[chr_name] = ds
            self.lengths.append(len(ds))
        print("Chromosome datasets loaded.")
        # ranges[i] = inclusive [first, last] global index owned by chr_names[i]
        self.ranges = self._ranges(self.lengths)

    def __len__(self): return sum(self.lengths)

    def __getitem__(self, idx):
        chr_name, local_idx = self._locate(idx)
        seq, features_placeholder, mat, start, end = self.chr_data[chr_name][local_idx]
        return seq, features_placeholder, mat, start, end, chr_name, local_idx

    # helpers
    def _enumerate_chrs(self, assembly):
        """chr1..chr22 (human) or chr1..chr19 (mouse), followed by chrX."""
        print(f"Using assembly: {assembly}")
        if assembly in ["hg38","hg19"]:
            nums = list(range(1,23))
        elif assembly in ["mm10","mm9"]:
            nums = list(range(1,20))
        else:
            raise ValueError(f"Assembly {assembly} unknown.")
        return [f"chr{n}" for n in nums] + ["chrX"]

    def _ranges(self, lengths):
        """Inclusive [first, last] global index per chromosome.

        A chromosome with 0 windows gets the empty range [cur, cur - 1].
        """
        cur, out = 0, []
        for L in lengths:
            out.append([cur, cur + L - 1])
            cur += L
        return out

    def _locate(self, idx):
        """Map a global index to (chr_name, local index); IndexError if out of range."""
        for i, (s,e) in enumerate(self.ranges):
            if s <= idx <= e:
                return self.chr_names[i], idx - s
        raise IndexError(idx)

    def _proc_bed(self, bed_path):
        """Parse a tab-separated BED file into {chrom: int array [n, 2] of (start, end)}.

        Only the first three columns are read. Passing just three ``names`` to
        a file with extra columns (name, score, ...) would make pandas use the
        leading columns as the index and mislabel the rest.
        """
        df = pd.read_csv(bed_path, sep="\t", header=None, usecols=[0, 1, 2])
        df.columns = ["chr", "start", "end"]
        return {chrom: grp[["start", "end"]].to_numpy(dtype=int)
                for chrom, grp in df.groupby("chr")}

# ================================================================
# Model (dual encoders) + hooks
# ================================================================
class SeqEncoder(nn.Module):
    """DNA encoder: base-resolution convolutions, then mean-pooling into bins.

    forward(x_base [B, in_channels, L_bp]) -> [B, l_bins, emb_dim]; bin l is the
    mean over positions [l * res_bp, (l + 1) * res_bp) of the first
    l_bins * res_bp bases.
    """
    def __init__(self, in_channels=4, emb_dim=DNA_EMB_DIM):
        super().__init__()
        # padding = kernel_size // 2 with stride 1 keeps the sequence length unchanged.
        self.conv1 = nn.Conv1d(in_channels, 64, kernel_size=15, padding=7)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=7, padding=3)
        self.proj  = nn.Conv1d(128, emb_dim, kernel_size=1)

    def forward(self, x_base, res_bp=RES_BP, l_bins=L_BINS):
        eff  = l_bins * res_bp
        L_bp = x_base.shape[-1]
        if L_bp < eff:
            raise ValueError(
                f"SeqEncoder needs at least l_bins * res_bp = {eff} bp of input, got {L_bp}")
        # Bases beyond the last whole bin are ignored.
        x = x_base[..., :eff]
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.proj(x)                                                     # [B, D, eff]
        x = x.reshape(x.shape[0], x.shape[1], l_bins, res_bp).mean(dim=-1)  # [B, D, l_bins]
        return x.transpose(1, 2).contiguous()                                # [B, l_bins, D]

class EpiEncoder(nn.Module):
    """Per-bin MLP over epigenomic tracks: [..., in_dim] -> [..., emb_dim].

    In this notebook the input is [B, L_bins, in_dim] (ATAC, CTCF per bin).
    """
    def __init__(self, in_dim=2, emb_dim=EPI_EMB_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, emb_dim)
        )

    def forward(self, epi_bins):
        # nn.Linear maps the last dimension for any leading shape, so no reshape is
        # needed; this also works for non-contiguous inputs.
        return self.net(epi_bins)

class FusionHead(nn.Module):
    """Concatenate DNA and epigenomic embeddings per bin and project them.

    Both inputs are truncated to their common (smaller) last dimension before
    concatenation, so ``fuse`` expects emb_dim to equal that common width.
    Returns (fused [..., emb_dim], score [..., 1]).
    """
    def __init__(self, emb_dim=min(DNA_EMB_DIM, EPI_EMB_DIM)):
        super().__init__()
        self.fuse = nn.Linear(emb_dim*2, emb_dim)
        self.out  = nn.Linear(emb_dim, 1)

    def forward(self, dna_emb, epi_emb):
        width = min(dna_emb.shape[-1], epi_emb.shape[-1])
        dna_part = dna_emb if dna_emb.shape[-1] == width else dna_emb[..., :width]
        epi_part = epi_emb if epi_emb.shape[-1] == width else epi_emb[..., :width]
        hidden = F.relu(self.fuse(torch.cat([dna_part, epi_part], dim=-1)))
        score = self.out(hidden)
        return hidden, score

class ConvTransModel(nn.Module):
    """Dual-encoder model: SeqEncoder (DNA) and EpiEncoder (tracks) fused per bin by FusionHead."""
    def __init__(self):
        super().__init__()
        self.seq_encoder = SeqEncoder()
        self.epi_encoder = EpiEncoder()
        self.fusion_head = FusionHead()

    def forward(self, x_base4, x_epi, res_bp=RES_BP, l_bins=L_BINS):
        """Return {"dna_emb", "epi_emb", "fused"}; the fusion head's score output is discarded."""
        dna = self.seq_encoder(x_base4, res_bp, l_bins)  # [B,L,D1]
        epi = self.epi_encoder(x_epi)                    # [B,L,D2]
        fused = self.fusion_head(dna, epi)[0]
        return dict(dna_emb=dna, epi_emb=epi, fused=fused)

def register_hooks(model: nn.Module):
    """Capture encoder outputs on every forward pass.

    Installs forward hooks on ``model.seq_encoder`` and ``model.epi_encoder``.
    The returned dict is updated in place under "dna_emb" / "epi_emb" with the
    detached output of the most recent call. Hook handles are not kept, so the
    hooks cannot be removed later.
    """
    caches = {}

    def make_hook(key):
        def hook(_module, _inputs, output):
            caches[key] = output.detach()
        return hook

    for key, submodule in (("dna_emb", model.seq_encoder), ("epi_emb", model.epi_encoder)):
        submodule.register_forward_hook(make_hook(key))
    return caches

# Utility: convert [L,5] (a,t,c,g,n) to [L,4] (A,C,G,T)
def onehot5_to_base4(seq_onehot5: np.ndarray) -> np.ndarray:
    """Reorder [L, 5] one-hot columns (a, t, c, g, n) into float32 [L, 4] (A, C, G, T).

    The 'n' column is dropped, so N positions become all-zero rows.
    """
    acgt_columns = [0, 2, 3, 1]  # positions of a, c, g, t in the a,t,c,g,n layout
    return seq_onehot5[:, acgt_columns].astype(np.float32)

# ================================================================
# Main: build dataset, bigWig handles, run model on GPU, save shards
# ================================================================
def _bigwig_bin_means(bw, chrom, start, end, n_bins):
    """Mean bigWig signal in ``n_bins`` equal-width bins over [start, end) as float32 [n_bins].

    Bins without coverage (None/NaN from pyBigWig) become 0.0.
    """
    means = bw.stats(chrom, int(start), int(end), nBins=n_bins, type="mean")
    # float32 conversion turns None into NaN; nan_to_num then maps NaN -> 0.0.
    return np.nan_to_num(np.array(means, dtype=np.float32), nan=0.0)


def main():
    """Embed dataset windows on the GPU and write per-bin rows to Parquet shards.

    For each of the first MAX_SAMPLES windows (all if None): the one-hot DNA goes
    through the DNA encoder, per-bin ATAC (log1p) and CTCF means go through the
    epigenomic encoder, and one row per bin is buffered with both embeddings.
    The buffer is written to OUT_DIR/embeddings_XXXXX.parquet whenever it holds at
    least ROWS_PER_SHARD rows, and once more at the end.

    Bin l of a window covers [win_start + l*RES_BP, win_start + (l+1)*RES_BP).
    The L_BINS bins span L_BINS*RES_BP bp, which can be shorter than the window;
    win_end is still recorded as the full window end.
    """
    import shutil

    Path(OUT_DIR).mkdir(parents=True, exist_ok=True)
    assert os.path.isdir(DNA_DIR), f"DNA_DIR not found: {DNA_DIR}"
    assert os.path.isdir(HIC_DIR), f"HIC_DIR not found: {HIC_DIR}"

    # GenomeDataset looks for bigWigs under {CELLTYPE_ROOT}/genomic_features; copy them there if needed.
    Path(f"{CELLTYPE_ROOT}/genomic_features").mkdir(parents=True, exist_ok=True)
    for src in [ATAC_BW, CTCF_BW]:
        expected = f"{CELLTYPE_ROOT}/genomic_features/{Path(src).name}"
        if os.path.abspath(src) != os.path.abspath(expected) and not os.path.exists(expected):
            shutil.copy2(src, expected)

    # Feature dicts just to construct GenomicFeature for length checks
    feat_dicts = {
        "ctcf_log2fc": {"file_name": Path(CTCF_BW).name, "norm": None},
        "atac":        {"file_name": Path(ATAC_BW).name,  "norm": "log"},
    }

    # Use 'val' (chr10) by default; override by CHROMS above
    mode = "val" if CHROMS is None else "train"
    ds = GenomeDataset(CELLTYPE_ROOT, ASSEMBLY, feat_dicts, mode=mode, use_aug=False)

    # Build model on GPU
    model = ConvTransModel().to(device).eval()
    if CKPT_PATH and os.path.exists(CKPT_PATH):
        # NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
        sd = torch.load(CKPT_PATH, map_location="cpu")
        state = sd.get("state_dict", sd)
        model.load_state_dict(state, strict=False)
        print("Loaded checkpoint:", CKPT_PATH)
    else:
        print("No checkpoint provided; embeddings will be untrained/random.")

    caches = register_hooks(model)

    # SeqEncoder embeds exactly L_BINS * RES_BP bp (it drops the WINDOW_BP remainder), so
    # the bigWig bins must cover that same span. Querying the full window would stretch
    # every epi bin past RES_BP and misalign it with the matching DNA bin.
    embedded_bp = L_BINS * RES_BP

    shard_rows = []
    shard_id = 0
    def flush_shard():
        nonlocal shard_rows, shard_id
        if not shard_rows: return
        df = pd.DataFrame(shard_rows)
        out_path = os.path.join(OUT_DIR, f"embeddings_{shard_id:05d}.parquet")
        df.to_parquet(out_path, index=False)
        print(f"Wrote {out_path} ({len(df)} rows)")
        shard_id += 1
        shard_rows = []

    # Iterate samples (windows)
    total = len(ds)
    print(f"Total windows in dataset: {total}")
    max_n = total if (MAX_SAMPLES is None) else min(MAX_SAMPLES, total)

    # Open bigWigs ONCE; the finally-block closes them even on error or interrupt.
    bw_atac = bw_ctcf = None
    try:
        bw_atac = pbw.open(ATAC_BW)
        bw_ctcf = pbw.open(CTCF_BW)

        for idx in range(max_n):
            seq5, _features_placeholder, _mat256, win_start, win_end, chr_name, _local = ds[idx]
            win_start, win_end = int(win_start), int(win_end)

            # ----- Build inputs -----
            base4 = onehot5_to_base4(seq5)                                                  # [L_bp,4]
            x_base = torch.from_numpy(np.ascontiguousarray(base4.T)).unsqueeze(0).to(device)  # [1,4,L_bp]

            epi_end   = win_start + embedded_bp
            atac_bins = np.log1p(_bigwig_bin_means(bw_atac, chr_name, win_start, epi_end, L_BINS))
            ctcf_bins = _bigwig_bin_means(bw_ctcf, chr_name, win_start, epi_end, L_BINS)
            epi_bins  = np.stack([atac_bins, ctcf_bins], axis=-1)       # [L_BINS,2]
            x_epi = torch.from_numpy(epi_bins).unsqueeze(0).to(device)  # [1,L_BINS,2]

            # ----- Forward on GPU (hooks capture the encoder outputs) -----
            with torch.no_grad():
                model(x_base, x_epi, RES_BP, L_BINS)

            dna_emb = caches["dna_emb"].cpu().numpy()[0]  # [L_BINS, D1]
            epi_emb = caches["epi_emb"].cpu().numpy()[0]  # [L_BINS, D2]

            # Save rows
            for l in range(L_BINS):
                bin_start = win_start + l * RES_BP
                shard_rows.append({
                    "assembly": ASSEMBLY,
                    "celltype": CELLTYPE,
                    "chr": chr_name,
                    "win_start": win_start,
                    "win_end": win_end,
                    "bin_idx": int(l),
                    "bin_start": bin_start,
                    "bin_end": bin_start + RES_BP,
                    "dna_emb": dna_emb[l].astype(np.float32).tolist(),
                    "epi_emb": epi_emb[l].astype(np.float32).tolist()
                })

            if len(shard_rows) >= ROWS_PER_SHARD:
                flush_shard()

            if (idx+1) % 50 == 0 or (idx+1) == max_n:
                print(f"Processed {idx+1}/{max_n} windows")

        flush_shard()
    finally:
        for bw in (bw_atac, bw_ctcf):
            if bw is not None:
                bw.close()
        torch.cuda.empty_cache()
    print("Done.")

# Run the pipeline: build the dataset, embed windows on the GPU, write Parquet shards.
main()


Using device: cuda
Using assembly: hg38
Feature path: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/ctcf_log2fc.bw
Normalization status: None
Feature path: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/atac.bw
Normalization status: log
Loading chromosome datasets…
Loading chromosome chr10…
Reading sequence: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/../dna_sequence/chr10.fa.gz
Reading Hi-C: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/hic_matrix/chr10.npz
Chromosome datasets loaded.
No checkpoint provided; embeddings will be untrained/random.
Total windows in dataset: 198


KeyboardInterrupt: 

In [None]:
# ================================================================
# Colab GPU-ready, fast nBins version (C.Origami-style windows)
# - RES_BP-sized bins over 2,097,152 bp windows: with the 10 kb Hi-C data this
#   gives L_BINS = 209 (256 bins would need 8,192 bp bins AND Hi-C binned at 8,192 bp)
# - Uses GPU for the model
# - Robust bigWig nBins stats (handles None, chr mismatches)
# - Has MAX_SAMPLES guard
# - Saves per-bin DNA & Epi embeddings to Parquet shards
# ================================================================

# ---------- Colab installs ----------
!pip -q install pyBigWig pyarrow

# ---------- CONFIG ----------
from pathlib import Path

# Folder layout expected:
#   {CELLTYPE_ROOT}/genomic_features/atac.bw
#   {CELLTYPE_ROOT}/genomic_features/ctcf_log2fc.bw
#   {CELLTYPE_ROOT}/hic_matrix/chr1.npz, chr2.npz, ...
#   {CELLTYPE_ROOT}/../dna_sequence/chr1.fa.gz, chr2.fa.gz, ...
#   {CELLTYPE_ROOT}/../centrotelo.bed   (optional)
CELLTYPE_ROOT = "/content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90"      # <-- EDIT
ASSEMBLY      = "hg38"
CELLTYPE      = "IMR90"

ATAC_BW       = f"{CELLTYPE_ROOT}/genomic_features/atac.bw"
CTCF_BW       = f"{CELLTYPE_ROOT}/genomic_features/ctcf_log2fc.bw"
DNA_DIR       = f"{CELLTYPE_ROOT}/../dna_sequence"
HIC_DIR       = f"{CELLTYPE_ROOT}/hic_matrix"
CENTROTELO_BED= f"{CELLTYPE_ROOT}/../centrotelo.bed"   # set to None if you don't have it

OUT_DIR       = "/content/drive/MyDrive/ML4GEN DATA/data - IMR90/embeddings2_out"  # Parquet outputs

# Window/bin params.
# RES_BP must equal the bin size of the Hi-C .npz files (10 kb here): ChromosomeDataset
# asserts len(sequence) / RES_BP is within 2 of the Hi-C bin count, so switching to
# 8,192 bp bins would also require Hi-C matrices binned at 8,192 bp.
RES_BP        = 10000                # 10 kb per bin
WINDOW_BP     = 2_097_152            # 2^21 bp window
L_BINS        = int(WINDOW_BP // RES_BP)  # 209 bins at 10 kb (256 only when RES_BP = 8,192)
SAMPLE_BINS   = 500                  # candidate frame: 500 bins (5 Mb at 10 kb)
STRIDE_BP     = 500_000              # ~500 kb step between candidate frames
STRIDE_BINS   = max(1, round(STRIDE_BP / RES_BP))  # 50 bins at 10 kb (61 at 8,192 bp)

# Limit how many samples to process (set None to run all)
MAX_SAMPLES   = 300                  # <-- change to None for full run

# Choose chromosomes (None -> default split; for quick test use ["chr10"])
CHROMS        = ["chr10"]            # <-- edit or set to None

# Output sharding
ROWS_PER_SHARD= 10_000

# Model dimensions
DNA_EMB_DIM   = 128
EPI_EMB_DIM   = 128

# Optional checkpoint (for this simple model). Leave None for random weights.
CKPT_PATH     = None

# ================================================================
# Imports & GPU setup
# ================================================================
import os, io, gzip, math, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pyBigWig as pbw

from typing import List, Tuple

# Ensure GPU
# Fail fast: this notebook expects a CUDA GPU.
if not torch.cuda.is_available():
    raise AssertionError("No GPU detected. In Colab: Runtime > Change runtime type > Hardware accelerator: GPU")
device = torch.device("cuda")
# Windows have a fixed size, so letting cuDNN autotune conv algorithms pays off.
torch.backends.cudnn.benchmark = True
print("Using device:", device)

# ================================================================
# Feature I/O helpers (self-contained)
# ================================================================
class Feature:
    """Abstract base for per-chromosome data sources.

    The constructor forwards every keyword argument to ``load``; subclasses
    implement ``load``, ``get`` and ``__len__`` (all raise NotImplementedError here).
    """

    def __init__(self, **kwargs):
        self.load(**kwargs)

    def load(self, **kwargs):
        raise NotImplementedError

    def get(self, *args, **kwargs):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

class HiCFeature(Feature):
    """Hi-C contacts for one chromosome, stored as diagonals in an .npz archive.

    The archive maps str(k) -> 1-D array for diagonal offset k (k = col - row),
    where element m of diagonal k is the contact between bins m and m + |k|.
    """
    def load(self, path=None):
        self.hic = self._load_npz(path)

    def get(self, start, window=WINDOW_BP, res=RES_BP):
        """Dense float32 [n, n] block with n = int(window / res), first bin int(start / res)."""
        start_bin = int(start / res)
        range_bin = int(window / res)
        end_bin   = start_bin + range_bin
        return self._diag_to_mat(self.hic, start_bin, end_bin)

    def _load_npz(self, path):
        print(f"Reading Hi-C: {path}")
        # Materialize every array, then close the archive (np.load keeps the file open otherwise).
        with np.load(path) as npz:
            return dict(npz)

    def _diag_to_mat(self, ori_load, start, end):
        """Rebuild the dense square block for bins [start, end) from stored diagonals.

        Entry (i, j) is ori_load[str(j - i)][start + min(i, j)]. Diagonals missing
        from ``ori_load`` are treated as all-zero.

        Raises:
            IndexError: if a stored diagonal is too short to cover the block
                (e.g. the block runs past the end of the chromosome).
        """
        square_len = end - start
        out = np.zeros((square_len, square_len), dtype=np.float32)
        for offset in range(-(square_len - 1), square_len):
            diag = ori_load.get(str(offset))
            if diag is None:
                continue  # missing diagonal -> leave zeros
            n = square_len - abs(offset)
            seg = diag[start : start + n]
            if seg.shape[0] < n:
                raise IndexError(
                    f"Hi-C diagonal {offset} has {seg.shape[0]} values from bin {start}, need {n}")
            pos = np.arange(n)
            if offset >= 0:
                out[pos, pos + offset] = seg   # upper triangle: row = pos
            else:
                out[pos - offset, pos] = seg   # lower triangle: col = pos
        return out

    def __len__(self):
        return len(self.hic['0'])

class GenomicFeature(Feature):
    """Handle on a bigWig track, used only for chromosome-length checks.

    Values are not read through this class: per-bin means are fetched in the
    main loop with pyBigWig ``stats(..., nBins=...)``, so ``get`` raises.
    """

    def __init__(self, path, norm):
        # Feature.__init__ is bypassed on purpose: nothing is loaded up front.
        self.path, self.norm = path, norm
        print(f"Feature path: {path}\nNormalization status: {norm}")

    def load(self, **kwargs):
        pass

    def get(self, *args, **kwargs):
        raise NotImplementedError("Use bigWig nBins path in main loop")

    def length(self, chr_name):
        """Length of ``chr_name`` as reported by the bigWig header."""
        with pbw.open(self.path) as handle:
            chrom_len = handle.chroms(chr_name)
        return chrom_len

class SequenceFeature(Feature):
    """One chromosome's DNA from a gzipped FASTA, served as one-hot windows.

    The whole sequence is held in memory as a lowercase string.
    """

    # Byte value -> one-hot column in [a, t, c, g, n] order; every other byte -> n (4).
    _ONEHOT_LUT = np.full(256, 4, dtype=np.intp)
    _ONEHOT_LUT[np.frombuffer(b"atcg", dtype=np.uint8)] = np.arange(4)

    def load(self, path=None):
        self.seq = self._read_seq(path)

    def get(self, start, end):
        """One-hot encoding of seq[start:end] as float32 [n, 5], n = len(seq[start:end])."""
        seq = self._slice(self.seq, start, end)
        return self._onehot(seq)

    def __len__(self):
        return len(self.seq)

    def _read_seq(self, dna_path):
        """Read a single-record gzipped FASTA and return its sequence, lowercased.

        Everything up to the first newline (the '>' header) is dropped and line
        breaks are removed. NOTE(review): a multi-record FASTA would keep later
        '>' header lines inside the sequence.
        """
        print(f"Reading sequence: {dna_path}")
        with gzip.open(dna_path, "r") as f:
            raw = f.read().decode("utf-8")
        raw = raw[raw.find('\n')+1:]
        # Strip '\r' as well: CRLF files would otherwise insert one 'n'-coded base per line.
        return raw.replace('\r', '').replace('\n', '').lower()

    def _slice(self, seq, start, end):
        return seq[start:end]

    def _onehot(self, seq):
        """One-hot encode ``seq`` as float32 [len(seq), 5] in [a, t, c, g, n] column order.

        Any character other than lowercase a/t/c/g goes to column 4.
        """
        # 'replace' keeps exactly one byte per character (non-ASCII -> '?', mapped to n),
        # so row i still corresponds to seq[i]. Vectorized: no per-character Python loop.
        codes = np.frombuffer(seq.encode("ascii", "replace"), dtype=np.uint8)
        out = np.zeros((len(seq), 5), dtype=np.float32)
        out[np.arange(len(seq)), self._ONEHOT_LUT[codes]] = 1.0
        return out

# ================================================================
# Datasets
# ================================================================
class ChromosomeDataset(torch.utils.data.Dataset):
    """
    Provides (sequence, features[], Hi-C) windows across a chromosome.
    Features[] are placeholders here; we'll fetch ATAC/CTCF via nBins in main.

    Candidate frames of SAMPLE_BINS bins start every STRIDE_BINS bins; frames
    overlapping any omit region are dropped. Each item cuts a window of
    int((WINDOW_BP / RES_BP) * RES_BP) bp (~WINDOW_BP) out of a frame: at the
    frame start when use_aug is False, at a random offset when it is True.

    __getitem__ returns (seq, [], mat, start, end):
      seq        -- float32 one-hot [window_bp, 5], columns a, t, c, g, n
      mat        -- log1p Hi-C block [n, n], n = int(WINDOW_BP / RES_BP) (not resized)
      start, end -- window coordinates in bp
    """
    def __init__(self, celltype_root, chr_name, omit_regions, feature_list, use_aug=True):
        self.use_aug      = use_aug
        self.res          = RES_BP
        self.bins_float   = WINDOW_BP / self.res   # 209.7152 at RES_BP=10000 (256.0 only if RES_BP=8192)
        self.sample_bins  = SAMPLE_BINS
        self.stride       = STRIDE_BINS
        self.chr_name     = chr_name

        print(f"Loading chromosome {chr_name}…")
        dna_path = f"{celltype_root}/../dna_sequence/{chr_name}.fa.gz"
        hic_path = f"{celltype_root}/hic_matrix/{chr_name}.npz"
        self.seq = SequenceFeature(path=dna_path)  # reads the whole chromosome sequence into memory
        self.mat = HiCFeature(path=hic_path)       # reads all stored Hi-C diagonals into memory
        self.genomic_features = feature_list  # used for length checks

        # No omit regions -> empty (0, 2) array.
        self.omit_regions = omit_regions if omit_regions is not None else np.zeros((0,2), dtype=int)
        self._check_lengths()
        self.all_intervals = self._get_active_intervals()
        self.intervals     = self._filter(self.all_intervals, self.omit_regions)

    def __len__(self): return len(self.intervals)

    def __getitem__(self, idx):
        start, end = self.intervals[idx]
        target_size = int(self.bins_float * self.res)  # ~WINDOW_BP (2,097,152 bp), up to float rounding

        if self.use_aug:
            start, end = self._shift_aug(target_size, start, end)
        else:
            start, end = self._shift_fix(target_size, start, end)

        # DNA one-hot [L_bp,5]
        seq = self.seq.get(start, end)

        # Hi-C region: int(WINDOW_BP / RES_BP) bins square (209x209 at RES_BP=10000)
        mat = self.mat.get(start, window=WINDOW_BP, res=self.res).astype(np.float32)
        mat = np.log1p(mat)  # not resized: stays [n, n] with n = int(WINDOW_BP / RES_BP)

        # Optional light aug on sequence + mat.
        # The order of np.random draws below matters for reproducibility under a fixed seed.
        if self.use_aug:
            seq = seq + np.random.randn(*seq.shape).astype(np.float32) * 0.1
            # NOTE(review): complement is only applied when the sequence was also
            # reversed (nested draw); confirm this coupling is intended.
            if np.random.rand() < 0.5:
                seq = np.flip(seq, 0).copy()
                if np.random.rand() < 0.5:
                    # Swap a<->t and c<->g columns (base complement); n stays put.
                    a,t,c,g,n = seq[:,0:1], seq[:,1:2], seq[:,2:3], seq[:,3:4], seq[:,4:5]
                    seq = np.concatenate([t,a,g,c,n], axis=1)
                mat = np.flip(mat, (0,1)).copy()

        # features[] placeholder kept for API compatibility
        features_placeholder = []
        return seq, features_placeholder, mat, start, end

    # helpers
    def _get_active_intervals(self):
        # Frame starts at 0, stride, 2*stride, ... bins. int() truncation means a final
        # frame that would still fit may be left out; a chromosome shorter than one frame
        # yields no frames. Returns int bp (start, end) pairs [n, 2].
        chr_bins = len(self.seq) / self.res
        data_size = int((chr_bins - self.sample_bins) / self.stride)
        data_size = max(0, data_size)
        starts = (np.arange(data_size).reshape(-1,1) * self.stride)
        intervals_bin = np.concatenate([starts, starts + self.sample_bins], axis=1) if data_size>0 else np.zeros((0,2), dtype=int)
        return (intervals_bin * self.res).astype(int)
    def _filter(self, intervals, omit_regions):
        # Keep frames that overlap no omit region (endpoints compared inclusively, so a
        # frame that merely touches a region is dropped). Returns a list of [start, end] ints.
        if omit_regions is None or len(omit_regions)==0:
            return intervals.tolist()
        valid = []
        for start, end in intervals:
            start_cond = start <= omit_regions[:,1]
            end_cond   = omit_regions[:,0] <= end
            if int(np.sum(start_cond * end_cond)) == 0:
                valid.append([int(start), int(end)])
        return valid
    def _shift_aug(self, target_size, start, end):
        # Random window start inside the frame; offset from Python's `random` in [0, max(1, slack)).
        max_off = max(1, (end - start - target_size))
        offset = random.randrange(max_off)
        return start + offset, start + offset + target_size
    def _shift_fix(self, target_size, start, end):
        # Deterministic: window starts at the frame start (``end`` is ignored).
        return start, start + target_size
    def _check_lengths(self):
        # Sequence length must equal the first feature's chromosome length from the bigWig
        # header, and sequence bins must be within 2 of the Hi-C main-diagonal length.
        if len(self.genomic_features) > 0:
            f0_len = self.genomic_features[0].length(self.chr_name)
            assert len(self.seq) == f0_len, f"Sequence {len(self.seq)} vs first feature {f0_len} mismatch."
        dna_bins = len(self.seq) / self.res
        hic_bins = len(self.mat)
        assert abs(dna_bins - hic_bins) < 2, f"DNA bins {dna_bins} vs Hi-C bins {hic_bins} mismatch."

class GenomeDataset(torch.utils.data.Dataset):
    """Flat index over the windows of several per-chromosome datasets.

    Default split (by ``mode``):
      train -> autosomes except chr10/chr15 and excluding chrX
      val   -> chr10
      test  -> chr15
    A non-None module-level ``CHROMS`` replaces the split's chromosome list
    (``mode`` is still validated first).

    Items are ``(seq, features_placeholder, mat, start, end, chr_name, local_idx)``.
    """
    def __init__(self, celltype_root, assembly, feat_dicts, mode="val", use_aug=False):
        self.data_root = celltype_root
        # Augmentation is only ever enabled for the training split.
        self.use_aug   = use_aug if mode=="train" else False

        self.chr_names = self._enumerate_chrs(assembly)
        if mode == "train":
            for drop in ["chr10","chr15","chrX"]:
                if drop in self.chr_names:
                    self.chr_names.remove(drop)
        elif mode == "val":
            self.chr_names = ["chr10"]
        elif mode == "test":
            self.chr_names = ["chr15"]
        else:
            raise ValueError(f"Unknown mode: {mode}")

        # Override with manual CHROMS if provided. Copy it so this dataset never
        # shares (and can never mutate) the module-level config list.
        if CHROMS is not None:
            self.chr_names = list(CHROMS)

        # Feature objects (for length checks; values fetched via nBins later)
        self.genomic_features = []
        for d in feat_dicts.values():
            self.genomic_features.append(GenomicFeature(f"{celltype_root}/genomic_features/{Path(d['file_name']).name}", d['norm']))

        # Omit regions
        if CENTROTELO_BED and os.path.exists(CENTROTELO_BED):
            omit_dict = self._proc_bed(CENTROTELO_BED)
        else:
            print("No centrotelo bed provided; proceeding without region masking.")
            omit_dict = {name: np.zeros((0,2), dtype=int) for name in self.chr_names}

        print("Loading chromosome datasets…")
        self.chr_data, self.lengths = {}, []
        for chr_name in self.chr_names:
            ds = ChromosomeDataset(self.data_root, chr_name, omit_dict.get(chr_name, None),
                                   self.genomic_features, use_aug=self.use_aug)
            self.chr_data[chr_name] = ds
            self.lengths.append(len(ds))
        print("Chromosome datasets loaded.")
        self.ranges = self._ranges(self.lengths)

    def __len__(self): return sum(self.lengths)
    def __getitem__(self, idx):
        chr_name, local_idx = self._locate(idx)
        seq, features_placeholder, mat, start, end = self.chr_data[chr_name][local_idx]
        return seq, features_placeholder, mat, start, end, chr_name, local_idx

    # helpers
    def _enumerate_chrs(self, assembly):
        """Return chr1..chrN autosomes plus chrX for a supported assembly."""
        print(f"Using assembly: {assembly}")
        if assembly in ["hg38","hg19"]:
            nums = list(range(1,23))
        elif assembly in ["mm10","mm9"]:
            nums = list(range(1,20))
        else:
            raise ValueError(f"Assembly {assembly} unknown.")
        return [f"chr{n}" for n in nums] + ["chrX"]
    def _ranges(self, lengths):
        """Inclusive [first, last] global index for each chromosome (empty ones get last < first)."""
        cur, out = 0, []
        for L in lengths:
            out.append([cur, cur + L - 1])
            cur += L
        return out
    def _locate(self, idx):
        """Map a global index to (chr_name, local index); IndexError if out of range."""
        for i, (s,e) in enumerate(self.ranges):
            if s <= idx <= e:
                return self.chr_names[i], idx - s
        raise IndexError(idx)
    def _proc_bed(self, bed_path):
        """Read a BED file into {chrom: int array [n, 2] of (start, end)}.

        Only the first three columns are used. Passing three ``names`` for a file
        with extra columns (e.g. a region label) would make pandas use the leading
        columns as the index and shift "chr"/"start"/"end" onto the wrong data.
        """
        df = pd.read_csv(bed_path, sep="\t", header=None, usecols=[0, 1, 2])
        df.columns = ["chr", "start", "end"]
        return {k: v[["start","end"]].to_numpy(dtype=int) for k,v in df.groupby("chr")}

# ================================================================
# Model (dual encoders) + hooks
# ================================================================
class SeqEncoder(nn.Module):
    """DNA encoder producing one embedding per genomic bin.

    Two ReLU conv layers and a 1x1 projection run at base-pair resolution; the
    result is mean-pooled over non-overlapping `res_bp` chunks.
    Input  x_base: [B, in_channels, L_bp] with L_bp >= l_bins * res_bp (extra bp are trimmed).
    Output       : [B, l_bins, emb_dim].
    """
    def __init__(self, in_channels=4, emb_dim=DNA_EMB_DIM):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, 64, kernel_size=15, padding=7)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=7, padding=3)
        self.proj  = nn.Conv1d(128, emb_dim, kernel_size=1)
    def forward(self, x_base, res_bp=RES_BP, l_bins=L_BINS):
        eff = l_bins * res_bp
        if x_base.shape[-1] < eff:
            # The view() below would otherwise fail with an opaque size error.
            raise ValueError(
                f"SeqEncoder needs >= {eff} bp ({l_bins} bins x {res_bp} bp), got {x_base.shape[-1]}."
            )
        x = x_base[..., :eff]               # trim to a whole number of bins
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.proj(x)                    # [B,D,eff]
        B, D, _ = x.shape
        x = x.view(B, D, l_bins, res_bp).mean(dim=-1)  # [B,D,L_bins]
        return x.transpose(1,2).contiguous()           # [B,L_bins,D]

class EpiEncoder(nn.Module):
    """Per-bin MLP over epigenomic tracks: [..., in_dim] -> [..., emb_dim]."""
    def __init__(self, in_dim=2, emb_dim=EPI_EMB_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, emb_dim)
        )
    def forward(self, epi_bins):
        # nn.Linear and ReLU act on the last dimension, so [B, L, in_dim] maps
        # directly to [B, L, emb_dim]. No flatten via .view() is needed; such a
        # flatten would reject non-contiguous inputs.
        return self.net(epi_bins)

class FusionHead(nn.Module):
    """Concatenate DNA and epi embeddings (both cut to a shared width) and fuse them.

    forward returns (fused [..., emb_dim], score [..., 1]).
    """
    def __init__(self, emb_dim=min(DNA_EMB_DIM, EPI_EMB_DIM)):
        super().__init__()
        self.fuse = nn.Linear(2 * emb_dim, emb_dim)
        self.out  = nn.Linear(emb_dim, 1)
    def forward(self, dna_emb, epi_emb):
        width = min(dna_emb.shape[-1], epi_emb.shape[-1])
        # Slicing to the full width is a no-op, so both inputs can be cut unconditionally.
        joined = torch.cat((dna_emb[..., :width], epi_emb[..., :width]), dim=-1)
        hidden = F.relu(self.fuse(joined))
        score = self.out(hidden)
        return hidden, score

class ConvTransModel(nn.Module):
    """Dual-encoder model: DNA conv encoder + epigenomic MLP, fused per bin.

    forward returns {"dna_emb": [B,L,D1], "epi_emb": [B,L,D2], "fused": [B,L,D]}.
    """
    def __init__(self):
        super().__init__()
        self.seq_encoder = SeqEncoder()
        self.epi_encoder = EpiEncoder()
        self.fusion_head = FusionHead()
    def forward(self, x_base4, x_epi, res_bp=RES_BP, l_bins=L_BINS):
        outputs = {
            "dna_emb": self.seq_encoder(x_base4, res_bp, l_bins),
            "epi_emb": self.epi_encoder(x_epi),
        }
        outputs["fused"], _score = self.fusion_head(outputs["dna_emb"], outputs["epi_emb"])
        return outputs

def register_hooks(model: nn.Module):
    """Attach forward hooks that record encoder outputs.

    Returns a dict updated on every forward pass with detached tensors:
    "dna_emb" (output of model.seq_encoder) and "epi_emb" (output of
    model.epi_encoder). Hook handles are not kept, so the hooks cannot be removed.
    """
    captured = {}
    def make_hook(key):
        def hook(module, inputs, output):
            captured[key] = output.detach()
        return hook
    for key, module in (("dna_emb", model.seq_encoder), ("epi_emb", model.epi_encoder)):
        module.register_forward_hook(make_hook(key))
    return captured

# Utility: convert [L,5] (a,t,c,g,n) to [L,4] (A,C,G,T)
def onehot5_to_base4(seq_onehot5: np.ndarray) -> np.ndarray:
    """Reorder a [L,5] one-hot in (a,t,c,g,n) column order into [L,4] (A,C,G,T) float32.

    The 'n' column is dropped, so N positions become all-zero rows.
    """
    acgt_columns = [0, 2, 3, 1]  # where a, c, g, t live in the a,t,c,g,n layout
    return seq_onehot5[:, acgt_columns].astype(np.float32)  # [L,4]

# ================================================================
# Main: build dataset, bigWig handles, run model on GPU, save shards
# ================================================================
def _bigwig_bin_means(bw, chrom, start, n_bins, res_bp):
    """Mean bigWig signal in `n_bins` consecutive `res_bp`-bp bins starting at `start`.

    The query spans exactly [start, start + n_bins * res_bp), so bin i covers
    [start + i*res_bp, start + (i+1)*res_bp). Bins that pyBigWig reports as
    None (no data) or NaN become 0.0.

    Returns:
        float32 array of shape [n_bins].
    Raises:
        ValueError: if pyBigWig returns a different number of bins.
        Errors raised by ``bw.stats`` propagate unchanged.
    """
    start = int(start)
    end = start + n_bins * res_bp
    vals = bw.stats(chrom, start, end, nBins=n_bins, type="mean")
    arr = np.array([np.nan if v is None else v for v in vals], dtype=np.float32)
    arr = np.nan_to_num(arr, nan=0.0)
    if arr.size != n_bins:
        raise ValueError(f"Expected {n_bins} bins but got {arr.size}. win: {chrom}:{start}-{end}")
    return arr


def main(seed: int = 0):
    """Embed the configured windows and write one row per RES_BP bin to Parquet shards.

    For each window:
      - the DNA one-hot goes through model.seq_encoder;
      - per-bin [log1p(ATAC), CTCF] means go through model.epi_encoder.
    Both outputs are captured by forward hooks. Rows are buffered and written to
    OUT_DIR/embeddings2_NNNNN.parquet whenever at least ROWS_PER_SHARD rows have
    accumulated, and once more at the end.

    Args:
        seed: torch RNG seed set right before the model is built. Without a
            checkpoint the weights are the random init, so this makes the written
            embeddings reproducible across runs.
    """
    Path(OUT_DIR).mkdir(parents=True, exist_ok=True)
    assert os.path.isdir(DNA_DIR), f"DNA_DIR not found: {DNA_DIR}"
    assert os.path.isdir(HIC_DIR), f"HIC_DIR not found: {HIC_DIR}"

    # Ensure bigWigs are under the expected folder (copy if needed); the
    # GenomicFeature length checks read them from there.
    Path(f"{CELLTYPE_ROOT}/genomic_features").mkdir(parents=True, exist_ok=True)
    for src in [ATAC_BW, CTCF_BW]:
        expected = f"{CELLTYPE_ROOT}/genomic_features/{Path(src).name}"
        if os.path.abspath(src) != os.path.abspath(expected) and not os.path.exists(expected):
            import shutil
            shutil.copy2(src, expected)

    # Feature dicts just to construct GenomicFeature for length checks
    feat_dicts = {
        "ctcf_log2fc": {"file_name": Path(CTCF_BW).name, "norm": None},
        "atac":        {"file_name": Path(ATAC_BW).name,  "norm": "log"},
    }

    # Use 'val' (chr10) by default; override by CHROMS above
    mode = "val" if CHROMS is None else "train"
    ds = GenomeDataset(CELLTYPE_ROOT, ASSEMBLY, feat_dicts, mode=mode, use_aug=False)

    # Build model on GPU (seeded: see docstring)
    torch.manual_seed(seed)
    model = ConvTransModel().to(device).eval()
    if CKPT_PATH and os.path.exists(CKPT_PATH):
        # NOTE: torch.load unpickles the file; only point CKPT_PATH at trusted checkpoints.
        sd = torch.load(CKPT_PATH, map_location="cpu")
        state = sd.get("state_dict", sd)
        load_result = model.load_state_dict(state, strict=False)
        print("Loaded checkpoint:", CKPT_PATH)
        # strict=False silently skips mismatched keys (e.g. a "model." prefix); surface them.
        if load_result.missing_keys or load_result.unexpected_keys:
            print(f"WARNING: {len(load_result.missing_keys)} missing and "
                  f"{len(load_result.unexpected_keys)} unexpected checkpoint keys were ignored.")
    else:
        print("No checkpoint provided; embeddings will be untrained/random.")

    caches = register_hooks(model)

    shard_rows = []
    shard_id = 0
    def flush_shard():
        nonlocal shard_rows, shard_id
        if not shard_rows: return
        df = pd.DataFrame(shard_rows)
        out_path = os.path.join(OUT_DIR, f"embeddings2_{shard_id:05d}.parquet")
        df.to_parquet(out_path, index=False)
        print(f"Wrote {out_path} ({len(df)} rows)")
        shard_id += 1
        shard_rows = []

    # Iterate samples (windows)
    total = len(ds)
    print(f"Total windows in dataset: {total}")
    max_n = total if (MAX_SAMPLES is None) else min(MAX_SAMPLES, total)

    # Open bigWigs ONCE; the finally block closes them even if a window fails.
    bw_atac = pbw.open(ATAC_BW)
    bw_ctcf = pbw.open(CTCF_BW)
    try:
        for idx in range(max_n):
            seq5, _features_placeholder, _mat256, win_start, win_end, chr_name, _local = ds[idx]

            # DNA to base4 [1,4,L_bp] on GPU
            base4 = onehot5_to_base4(seq5)                              # [L_bp,4]
            x_base = torch.from_numpy(base4.T).unsqueeze(0).to(device)  # [1,4,L_bp]

            # Epi: query exactly the L_BINS*RES_BP bp the DNA encoder embeds, so each
            # epi bin is RES_BP wide and matches the bin_start/bin_end labels below.
            try:
                atac_bins = _bigwig_bin_means(bw_atac, chr_name, win_start, L_BINS, RES_BP)
                ctcf_bins = _bigwig_bin_means(bw_ctcf, chr_name, win_start, L_BINS, RES_BP)
            except RuntimeError as e:
                avail_atac = list(bw_atac.chroms().keys())
                avail_ctcf = list(bw_ctcf.chroms().keys())
                raise RuntimeError(
                    f"bigWig stats failed for chromosome {chr_name}. "
                    f"ATAC first chroms: {avail_atac[:10]} ... "
                    f"CTCF first chroms: {avail_ctcf[:10]} ... "
                    f"Original error: {e}"
                ) from e

            atac_bins = np.log1p(atac_bins)  # paper-consistent
            epi_bins  = np.stack([atac_bins, ctcf_bins], axis=-1).astype(np.float32)  # [L_BINS, 2]
            x_epi     = torch.from_numpy(epi_bins).unsqueeze(0).to(device)            # [1,L_BINS,2]

            # Forward on GPU (hooks capture the encoder outputs)
            with torch.no_grad():
                _ = model(x_base, x_epi, RES_BP, L_BINS)

            dna_emb = caches["dna_emb"].cpu().numpy()[0]  # [L_BINS, D1]
            epi_emb = caches["epi_emb"].cpu().numpy()[0]  # [L_BINS, D2]

            for l in range(L_BINS):
                bin_start = int(win_start + l*RES_BP)
                bin_end   = bin_start + RES_BP
                shard_rows.append({
                    "assembly": ASSEMBLY,
                    "celltype": CELLTYPE,
                    "chr": chr_name,
                    "win_start": int(win_start),
                    "win_end": int(win_end),
                    "bin_idx": int(l),
                    "bin_start": bin_start,
                    "bin_end": bin_end,
                    "dna_emb": dna_emb[l].astype(np.float32).tolist(),
                    "epi_emb": epi_emb[l].astype(np.float32).tolist()
                })

            if len(shard_rows) >= ROWS_PER_SHARD:
                flush_shard()

            if (idx+1) % 50 == 0 or (idx+1) == max_n:
                print(f"Processed {idx+1}/{max_n} windows")

        flush_shard()
    finally:
        bw_atac.close()
        bw_ctcf.close()

    torch.cuda.empty_cache()
    print("Done.")

# Run
main()


Using device: cuda
Using assembly: hg38
Feature path: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/ctcf_log2fc.bw
Normalization status: None
Feature path: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/atac.bw
Normalization status: log
Loading chromosome datasets…
Loading chromosome chr10…
Reading sequence: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/../dna_sequence/chr10.fa.gz
Reading Hi-C: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/hic_matrix/chr10.npz
Chromosome datasets loaded.
No checkpoint provided; embeddings will be untrained/random.
Total windows in dataset: 198
Wrote /content/drive/MyDrive/ML4GEN DATA/data - IMR90/embeddings2_out/embeddings2_00000.parquet (10032 rows)
Processed 50/198 windows
Wrote /content/drive/MyDrive/ML4GEN DATA/data - IMR90/embeddings2_out/embeddings2_00001.parquet (10032 rows)
Processed 100/198 windows
Wrote /content/drive/MyDrive/ML4GEN DATA/data - IMR90/embeddin

In [None]:
# ================================================================
# Colab GPU-ready, nBins version (per-chrom processing, chr in filenames)
# - 10 kb bins, 2,097,152 bp windows (L_BINS=209)
# - Iterates chr1..chr22, chrX, chrY
# - Writes shards like embeddings_chr10_00003.parquet
# ================================================================

# ---------- Colab installs ----------
!pip -q install pyBigWig pyarrow scikit-image

# ---------- CONFIG ----------
from pathlib import Path

CELLTYPE_ROOT = "/content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90"
ASSEMBLY      = "hg38"
CELLTYPE      = "IMR90"

ATAC_BW       = f"{CELLTYPE_ROOT}/genomic_features/atac.bw"
CTCF_BW       = f"{CELLTYPE_ROOT}/genomic_features/ctcf_log2fc.bw"
DNA_DIR       = f"{CELLTYPE_ROOT}/../dna_sequence"
HIC_DIR       = f"{CELLTYPE_ROOT}/hic_matrix"
CENTROTELO_BED= f"{CELLTYPE_ROOT}/../centrotelo.bed"   # or None

OUT_DIR       = "/content/drive/MyDrive/ML4GEN DATA/data - IMR90/embeddings2_out"

# Window/bin params (10kb data)
RES_BP        = 10_000
WINDOW_BP     = 2_097_152
L_BINS        = int(WINDOW_BP // RES_BP)   # 209
SAMPLE_BINS   = 500
STRIDE_BINS   = 61
IMAGE_SCALE   = 256

MAX_WINDOWS_PER_CHR = None      # set to small int for quick smoke tests
ROWS_PER_SHARD      = 10_000

DNA_EMB_DIM   = 128
EPI_EMB_DIM   = 128
CKPT_PATH     = None

# ================================================================
# Imports & GPU setup
# ================================================================
import os, io, gzip, math, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pyBigWig as pbw
from skimage.transform import resize

# Explicit raise (not assert) so the environment check survives `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("Enable GPU: Runtime > Change runtime type > GPU")
device = torch.device("cuda")
# cuDNN autotuning: the conv input shape repeats for every window here.
torch.backends.cudnn.benchmark = True
print("Using device:", device)

# ================================================================
# Feature I/O helpers
# ================================================================
class Feature():
    """Abstract data source; subclasses implement load, get and __len__.

    The constructor forwards every keyword argument to `load`.
    """
    def __init__(self, **kwargs):
        self.load(**kwargs)

    def load(self, **kwargs):
        raise NotImplementedError

    def get(self, *args, **kwargs):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

class HiCFeature(Feature):
    """Hi-C contacts stored per diagonal in an .npz (keys "0", "1", "-1", ...)."""
    def load(self, path=None):
        print(f"Reading Hi-C: {path}")
        # NpzFile keeps the archive open; materialize every diagonal, then close it.
        with np.load(path) as npz:
            self.hic = dict(npz)
    def get(self, start, window=WINDOW_BP, res=RES_BP):
        """Dense float32 contact block for the `window`-bp region starting at bp `start`."""
        start_bin = int(start / res)
        end_bin   = start_bin + int(window / res)
        return self._diag_to_mat(self.hic, start_bin, end_bin)
    def _diag_to_mat(self, ori_load, start, end):
        """Assemble the dense [n, n] block (n = end - start) from diagonal arrays.

        Entry (r, c) is ori_load[str(c - r)][start + min(r, c)]. Each diagonal is
        filled with one numpy assignment instead of n^2 Python-level appends.
        """
        n = end - start
        mat = np.zeros((n, n), dtype=np.float32)
        for offset in range(-(n - 1), n):
            length = n - abs(offset)
            vals = ori_load[str(offset)][start : start + length]
            pos = np.arange(length)
            if offset >= 0:
                mat[pos, pos + offset] = vals   # upper diagonal: row index is min(r, c)
            else:
                mat[pos - offset, pos] = vals   # lower diagonal: column index is min(r, c)
        return mat
    def __len__(self): return len(self.hic['0'])

class GenomicFeature(Feature):
    """bigWig track descriptor holding a path and a normalization label.

    Unlike other Feature subclasses, construction does not call load(); `get`
    is not implemented and only `length` opens the file.
    """
    def __init__(self, path, norm):
        self.path, self.norm = path, norm
        print(f"Feature path: {path}\nNormalization status: {norm}")

    def load(self, **kwargs):
        return None

    def get(self, *args, **kwargs):
        raise NotImplementedError

    def length(self, chr_name):
        """Chromosome length as reported by the bigWig's chroms(chr_name)."""
        with pbw.open(self.path) as handle:
            chrom_len = handle.chroms(chr_name)
        return chrom_len

class SequenceFeature(Feature):
    """Chromosome sequence from a gzipped single-record FASTA, held as a lowercase string."""

    # Byte value -> column of the (a, t, c, g, n) one-hot; every other byte maps to 'n' (4).
    _ONEHOT_COLUMN = np.full(256, 4, dtype=np.intp)
    _ONEHOT_COLUMN[[ord("a"), ord("t"), ord("c"), ord("g")]] = [0, 1, 2, 3]

    def load(self, path=None):
        print(f"Reading sequence: {path}")
        with gzip.open(path, "r") as f:
            raw = f.read().decode("utf-8")
        # Drop the '>' header line only when present, then every line break
        # (LF and CRLF) so that len(self.seq) is the chromosome length in bp.
        if raw.startswith(">"):
            newline = raw.find('\n')
            raw = raw[newline+1:] if newline >= 0 else ""
        self.seq = raw.replace('\r', '').replace('\n', '').lower()
    def get(self, start, end):
        """One-hot encode self.seq[start:end] as float32 [n, 5] in (a, t, c, g, n) column order."""
        seq = self.seq[start:end]
        out = np.zeros((len(seq), 5), dtype=np.float32)
        if len(seq) > 0:
            # Vectorized table lookup instead of a per-character Python loop (windows are Mb-long).
            # errors="replace" turns each non-ASCII character into one '?' byte, which maps to 'n'.
            codes = np.frombuffer(seq.encode("ascii", errors="replace"), dtype=np.uint8)
            out[np.arange(len(seq)), self._ONEHOT_COLUMN[codes]] = 1.0
        return out
    def __len__(self): return len(self.seq)

# ================================================================
# Datasets (now we'll build per chromosome explicitly)
# ================================================================
class ChromosomeDataset(torch.utils.data.Dataset):
    """Fixed-stride windows over one chromosome (DNA one-hot + matching Hi-C block).

    Candidate frames are SAMPLE_BINS bins long and start every STRIDE_BINS bins.
    Frames touching an omitted region are dropped. Each item is the WINDOW_BP
    sub-window at its frame's start, which is the span the Hi-C matrix covers.
    Assumes WINDOW_BP <= SAMPLE_BINS * res_bp so that sub-window stays inside the frame.
    """
    def __init__(self, celltype_root, chr_name, omit_regions, feature_list, res_bp=RES_BP):
        self.res        = res_bp
        self.sample_bins= SAMPLE_BINS
        self.stride     = STRIDE_BINS
        self.chr_name   = chr_name

        dna_path = f"{celltype_root}/../dna_sequence/{chr_name}.fa.gz"
        hic_path = f"{celltype_root}/hic_matrix/{chr_name}.npz"
        self.seq = SequenceFeature(path=dna_path)
        self.mat = HiCFeature(path=hic_path)
        self.genomic_features = feature_list

        self.omit_regions = omit_regions if omit_regions is not None else np.zeros((0,2), dtype=int)
        self._check_lengths()
        self.all_intervals = self._windows()
        self.intervals     = self._mask(self.all_intervals, self.omit_regions)

    def __len__(self): return len(self.intervals)
    def __getitem__(self, idx):
        """Return ``(seq, [], mat, start, end)`` for window ``idx``.

        ``end = start + WINDOW_BP``, so all three describe the same span:
        ``seq`` is the one-hot of [start, end) with shape [WINDOW_BP, 5], and
        ``mat`` is the Hi-C block for that window, resized to
        IMAGE_SCALE x IMAGE_SCALE and then log1p'd.
        """
        start, _frame_end = self.intervals[idx]
        # Fixed (non-augmented) placement of the WINDOW_BP window at the frame start;
        # the SAMPLE_BINS frame end is only used for tiling and masking.
        end = start + WINDOW_BP
        seq = self.seq.get(start, end)
        mat = self.mat.get(start, window=WINDOW_BP, res=self.res)
        mat = np.log1p(resize(mat, (IMAGE_SCALE, IMAGE_SCALE), anti_aliasing=True).astype(np.float32))
        return seq, [], mat, start, end

    def _windows(self):
        """Candidate frames as an int array [n, 2] of bp (start, end); see class docstring."""
        chr_bins = len(self.seq) / self.res
        n = max(0, int((chr_bins - self.sample_bins) / self.stride))
        starts = (np.arange(n).reshape(-1,1) * self.stride)
        bins = np.concatenate([starts, starts + self.sample_bins], axis=1) if n>0 else np.zeros((0,2), dtype=int)
        return (bins * self.res).astype(int)

    def _mask(self, intervals, omit_regions):
        """Drop frames overlapping (boundary-inclusive) any omitted region; returns a list of [s, e]."""
        if omit_regions is None or len(omit_regions)==0: return intervals.tolist()
        valid = []
        for s,e in intervals:
            if int(np.sum((s <= omit_regions[:,1]) & (omit_regions[:,0] <= e))) == 0:
                valid.append([int(s), int(e)])
        return valid

    def _check_lengths(self):
        """Assert DNA length == first feature's chromosome length and DNA/Hi-C bins differ by < 2."""
        if len(self.genomic_features)>0:
            f0 = self.genomic_features[0].length(self.chr_name)
            assert len(self.seq) == f0, f"Sequence {len(self.seq)} vs first feature {f0} mismatch."
        dna_bins = len(self.seq) / self.res
        hic_bins = len(self.mat)
        assert abs(dna_bins - hic_bins) < 2, f"DNA bins {dna_bins} vs Hi-C bins {hic_bins} mismatch."

# ================================================================
# Model + hooks
# ================================================================
class SeqEncoder(nn.Module):
    """DNA encoder producing one embedding per genomic bin.

    Two ReLU conv layers and a 1x1 projection run at base-pair resolution; the
    result is mean-pooled over non-overlapping `res_bp` chunks.
    Input  x: [B, in_channels, L_bp] with L_bp >= l_bins * res_bp (extra bp are trimmed).
    Output  : [B, l_bins, emb_dim].
    """
    def __init__(self, in_channels=4, emb_dim=DNA_EMB_DIM):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, 64, kernel_size=15, padding=7)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=7, padding=3)
        self.proj  = nn.Conv1d(128, emb_dim, kernel_size=1)
    def forward(self, x, res_bp=RES_BP, l_bins=L_BINS):
        eff = l_bins * res_bp
        if x.shape[-1] < eff:
            # The view() below would otherwise fail with an opaque size error.
            raise ValueError(
                f"SeqEncoder needs >= {eff} bp ({l_bins} bins x {res_bp} bp), got {x.shape[-1]}."
            )
        x = x[..., :eff]                    # trim to a whole number of bins
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.proj(x)                    # [B,D,eff]
        B, D, _ = x.shape
        x = x.view(B, D, l_bins, res_bp).mean(dim=-1)   # [B,D,L_bins]
        return x.transpose(1,2).contiguous()            # [B,L_bins,D]

class EpiEncoder(nn.Module):
    """Per-bin MLP over epigenomic tracks: [..., in_dim] -> [..., emb_dim]."""
    def __init__(self, in_dim=2, emb_dim=EPI_EMB_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, emb_dim)
        )
    def forward(self, epi_bins):
        # nn.Linear and ReLU act on the last dimension, so [B, L, in_dim] maps
        # directly to [B, L, emb_dim]. No flatten via .view() is needed; such a
        # flatten would reject non-contiguous inputs.
        return self.net(epi_bins)

class FusionHead(nn.Module):
    """Concatenate DNA and epi embeddings (both cut to a shared width) and fuse them.

    forward returns (fused [..., emb_dim], score [..., 1]).
    """
    def __init__(self, emb_dim=min(DNA_EMB_DIM, EPI_EMB_DIM)):
        super().__init__()
        self.fuse = nn.Linear(2 * emb_dim, emb_dim)
        self.out  = nn.Linear(emb_dim, 1)
    def forward(self, dna_emb, epi_emb):
        width = min(dna_emb.shape[-1], epi_emb.shape[-1])
        # Slicing to the full width is a no-op, so both inputs can be cut unconditionally.
        joined = torch.cat((dna_emb[..., :width], epi_emb[..., :width]), dim=-1)
        hidden = F.relu(self.fuse(joined))
        score = self.out(hidden)
        return hidden, score

class ConvTransModel(nn.Module):
    """Dual-encoder model: DNA conv encoder + epigenomic MLP, fused per bin.

    forward returns {"dna_emb": [B,L,D1], "epi_emb": [B,L,D2], "fused": [B,L,D]}.
    """
    def __init__(self):
        super().__init__()
        self.seq_encoder = SeqEncoder()
        self.epi_encoder = EpiEncoder()
        self.fusion_head = FusionHead()
    def forward(self, x_base4, x_epi, res_bp=RES_BP, l_bins=L_BINS):
        outputs = {
            "dna_emb": self.seq_encoder(x_base4, res_bp, l_bins),
            "epi_emb": self.epi_encoder(x_epi),
        }
        outputs["fused"], _score = self.fusion_head(outputs["dna_emb"], outputs["epi_emb"])
        return outputs

def register_hooks(model):
    """Attach forward hooks that record encoder outputs.

    Returns a dict updated on every forward pass with detached tensors:
    "dna_emb" (output of model.seq_encoder) and "epi_emb" (output of
    model.epi_encoder). Hook handles are not kept, so the hooks cannot be removed.
    """
    captured = {}
    def make_hook(key):
        def hook(module, inputs, output):
            captured[key] = output.detach()
        return hook
    for key, module in (("dna_emb", model.seq_encoder), ("epi_emb", model.epi_encoder)):
        module.register_forward_hook(make_hook(key))
    return captured

def onehot5_to_base4(seq_onehot5: np.ndarray) -> np.ndarray:
    """Reorder a [L,5] one-hot in (a,t,c,g,n) column order into [L,4] (A,C,G,T) float32.

    The 'n' column is dropped, so N positions become all-zero rows.
    """
    acgt_columns = [0, 2, 3, 1]  # where a, c, g, t live in the a,t,c,g,n layout
    return seq_onehot5[:, acgt_columns].astype(np.float32)  # [L,4]

# ================================================================
# Main (per-chrom loop + chr in filenames)
# ================================================================
def _bigwig_bin_means(bw, chrom, start, n_bins, res_bp):
    """Mean bigWig signal in `n_bins` consecutive `res_bp`-bp bins starting at `start`.

    The query spans exactly [start, start + n_bins * res_bp), so bin i covers
    [start + i*res_bp, start + (i+1)*res_bp). Bins that pyBigWig reports as
    None (no data) or NaN become 0.0.

    Returns:
        float32 array of shape [n_bins].
    Raises:
        ValueError: if pyBigWig returns a different number of bins.
        Errors raised by ``bw.stats`` propagate unchanged.
    """
    start = int(start)
    end = start + n_bins * res_bp
    vals = bw.stats(chrom, start, end, nBins=n_bins, type="mean")
    arr = np.array([np.nan if v is None else v for v in vals], dtype=np.float32)
    arr = np.nan_to_num(arr, nan=0.0)
    if arr.size != n_bins:
        raise ValueError(f"{chrom}:{start}-{end} expected {n_bins} bins, got {arr.size}")
    return arr


def _write_shard(rows, chr_name, shard_id):
    """Write `rows` (list of dicts) to OUT_DIR/embeddings_{chr_name}_{shard_id:05d}.parquet."""
    df = pd.DataFrame(rows)
    out_path = os.path.join(OUT_DIR, f"embeddings_{chr_name}_{shard_id:05d}.parquet")
    df.to_parquet(out_path, index=False)
    print(f"Wrote {out_path} ({len(df)} rows)")


def main(seed: int = 0):
    """Embed chr1..chr22, chrX, chrY and write one row per RES_BP bin to per-chromosome shards.

    For each chromosome, a ChromosomeDataset is built and up to MAX_WINDOWS_PER_CHR
    windows are embedded. Rows go to OUT_DIR/embeddings_{chr}_{NNNNN}.parquet
    whenever at least ROWS_PER_SHARD rows are buffered, plus a final shard.

    A chromosome is skipped when any of these hold:
      - its FASTA is missing;
      - its Hi-C file is missing;
      - it is absent from either bigWig;
      - it has no windows.

    Args:
        seed: torch RNG seed set right before the model is built. Without a
            checkpoint the weights are the random init, so this makes the written
            embeddings reproducible across runs.
    """
    Path(OUT_DIR).mkdir(parents=True, exist_ok=True)
    assert os.path.isdir(DNA_DIR), f"DNA_DIR not found: {DNA_DIR}"
    assert os.path.isdir(HIC_DIR), f"HIC_DIR not found: {HIC_DIR}"

    # Ensure genomic_features live under CELLTYPE_ROOT (length checks read them there)
    Path(f"{CELLTYPE_ROOT}/genomic_features").mkdir(parents=True, exist_ok=True)
    for src in [ATAC_BW, CTCF_BW]:
        expected = f"{CELLTYPE_ROOT}/genomic_features/{Path(src).name}"
        if os.path.abspath(src) != os.path.abspath(expected) and not os.path.exists(expected):
            import shutil
            shutil.copy2(src, expected)

    feat_dicts = {
        "ctcf_log2fc": {"file_name": Path(CTCF_BW).name, "norm": None},
        "atac":        {"file_name": Path(ATAC_BW).name,  "norm": "log"},
    }
    feature_list = [
        GenomicFeature(f"{CELLTYPE_ROOT}/genomic_features/{Path(d['file_name']).name}", d['norm'])
        for d in feat_dicts.values()
    ]

    # Omit regions (if present). Only the first three BED columns are read: giving
    # three `names` for a file with extra columns would make pandas use the leading
    # columns as the index and mislabel chr/start/end.
    omit_dict = {}
    if CENTROTELO_BED and os.path.exists(CENTROTELO_BED):
        bed = pd.read_csv(CENTROTELO_BED, sep="\t", header=None, usecols=[0, 1, 2])
        bed.columns = ["chr", "start", "end"]
        for k, v in bed.groupby("chr"):
            omit_dict[k] = v[["start","end"]].to_numpy(dtype=int)

    # Model (seeded: see docstring)
    torch.manual_seed(seed)
    model = ConvTransModel().to(device).eval()
    if CKPT_PATH and os.path.exists(CKPT_PATH):
        # NOTE: torch.load unpickles the file; only point CKPT_PATH at trusted checkpoints.
        state = torch.load(CKPT_PATH, map_location="cpu")
        load_result = model.load_state_dict(state.get("state_dict", state), strict=False)
        print("Loaded checkpoint:", CKPT_PATH)
        # strict=False silently skips mismatched keys (e.g. a "model." prefix); surface them.
        if load_result.missing_keys or load_result.unexpected_keys:
            print(f"WARNING: {len(load_result.missing_keys)} missing and "
                  f"{len(load_result.unexpected_keys)} unexpected checkpoint keys were ignored.")
    else:
        print("No checkpoint provided; embeddings will be untrained/random.")
    caches = register_hooks(model)

    # Chromosome order: 1..22, X, Y
    chr_list = [f"chr{i}" for i in range(1,23)] + ["chrX","chrY"]

    # Open bigWigs once; the finally block closes them even if a chromosome fails.
    bw_atac = pbw.open(ATAC_BW)
    bw_ctcf = pbw.open(CTCF_BW)
    try:
        for chr_name in chr_list:
            print(f"\n===== Processing {chr_name} =====")
            if not os.path.exists(f"{DNA_DIR}/{chr_name}.fa.gz"):
                print(f"Skipping {chr_name}: missing {DNA_DIR}/{chr_name}.fa.gz")
                continue
            if not os.path.exists(f"{HIC_DIR}/{chr_name}.npz"):
                print(f"Skipping {chr_name}: missing {HIC_DIR}/{chr_name}.npz")
                continue
            if chr_name not in bw_atac.chroms() or chr_name not in bw_ctcf.chroms():
                print(f"Skipping {chr_name}: not present in the ATAC/CTCF bigWig")
                continue

            ds = ChromosomeDataset(
                CELLTYPE_ROOT,
                chr_name,
                omit_dict.get(chr_name, None),
                feature_list,
                res_bp=RES_BP
            )
            total = len(ds)
            if total == 0:
                print(f"{chr_name}: 0 windows (likely short chromosome or masked regions).")
                continue
            max_n = total if (MAX_WINDOWS_PER_CHR is None) else min(MAX_WINDOWS_PER_CHR, total)
            print(f"{chr_name}: {max_n}/{total} windows")

            shard_rows, shard_id = [], 0
            for idx in range(max_n):
                seq5, _feat_placeholder, _mat256, win_start, win_end = ds[idx]
                # DNA -> base4 -> GPU
                base4 = onehot5_to_base4(seq5)                       # [L_bp,4]
                x_base = torch.from_numpy(base4.T).unsqueeze(0).to(device)

                # Epi: exactly L_BINS bins of RES_BP starting at win_start, i.e. the
                # same bins the DNA encoder pools and the rows below are labelled with.
                atac = np.log1p(_bigwig_bin_means(bw_atac, chr_name, win_start, L_BINS, RES_BP))
                ctcf = _bigwig_bin_means(bw_ctcf, chr_name, win_start, L_BINS, RES_BP)
                epi = np.stack([atac, ctcf], axis=-1).astype(np.float32)  # [L_BINS,2]
                x_epi = torch.from_numpy(epi).unsqueeze(0).to(device)

                # Forward (hooks capture the encoder outputs)
                with torch.no_grad():
                    _ = model(x_base, x_epi, RES_BP, L_BINS)
                dna_emb = caches["dna_emb"].cpu().numpy()[0]   # [L_BINS, D1]
                epi_emb = caches["epi_emb"].cpu().numpy()[0]   # [L_BINS, D2]

                for l in range(L_BINS):
                    bstart = int(win_start + l*RES_BP)
                    bend   = bstart + RES_BP
                    shard_rows.append({
                        "assembly": ASSEMBLY,
                        "celltype": CELLTYPE,
                        "chr": chr_name,
                        "win_start": int(win_start),
                        "win_end": int(win_end),
                        "bin_idx": int(l),
                        "bin_start": bstart,
                        "bin_end": bend,
                        "dna_emb": dna_emb[l].astype(np.float32).tolist(),
                        "epi_emb": epi_emb[l].astype(np.float32).tolist()
                    })
                if len(shard_rows) >= ROWS_PER_SHARD:
                    _write_shard(shard_rows, chr_name, shard_id)
                    shard_id += 1
                    shard_rows = []
                if (idx+1) % 50 == 0 or (idx+1) == max_n:
                    print(f"{chr_name}: processed {idx+1}/{max_n} windows")

            # last shard for this chromosome
            if shard_rows:
                _write_shard(shard_rows, chr_name, shard_id)
    finally:
        bw_atac.close()
        bw_ctcf.close()

    print("\nAll chromosomes done.")
    torch.cuda.empty_cache()

# Run
main()


Using device: cuda
Feature path: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/ctcf_log2fc.bw
Normalization status: None
Feature path: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/atac.bw
Normalization status: log
No checkpoint provided; embeddings will be untrained/random.

===== Processing chr1 =====
Reading sequence: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/../dna_sequence/chr1.fa.gz
Reading Hi-C: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/hic_matrix/chr1.npz
chr1: 385/385 windows
Wrote /content/drive/MyDrive/ML4GEN DATA/data - IMR90/embeddings2_out/embeddings_chr1_00000.parquet (10032 rows)
chr1: processed 50/385 windows
Wrote /content/drive/MyDrive/ML4GEN DATA/data - IMR90/embeddings2_out/embeddings_chr1_00001.parquet (10032 rows)
chr1: processed 100/385 windows
Wrote /content/drive/MyDrive/ML4GEN DATA/data - IMR90/embeddings2_out/embeddings_chr1_00002.parquet (10032 rows)
chr1: processe