In [2]:
from google.colab import drive
import os

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install pytorch_lightning
!pip install pyBigWig

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.6.0-py3-none-any.whl.metadata (21 kB)
Collecting torchmetrics>0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading pytorch_lightning-2.6.0-py3-none-any.whl (849 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m849.5/849.5 kB[0m [31m55.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m73.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning-utilities, torchmetrics, pytorch_lightning
Successfully installed lightning-utilities-0.15.2 pytorch_lightning-2.6.0 torchmetrics-1.8.2
Collecting pyBigWig
  D

In [None]:
# Run this in a separate cell BEFORE your main code
import torch
import gc

# Clear GPU cache
torch.cuda.empty_cache()
gc.collect()

# Check GPU memory
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------

In [None]:
"""
End-to-End Hi-C Prediction Training with C.Origami Architecture
Trains DNA/Epigenomic encoders jointly with Hi-C prediction decoder
MEMORY OPTIMIZED VERSION
"""

import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import pytorch_lightning.callbacks as callbacks

# ================================================================
# Model Components (C.Origami-style architecture)
# ================================================================

class ConvBlock(nn.Module):
    def __init__(self, size, stride=2, hidden_in=64, hidden=64):
        super().__init__()
        pad_len = int(size / 2)
        self.scale = nn.Sequential(
            nn.Conv1d(hidden_in, hidden, size, stride, pad_len),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
        )
        self.res = nn.Sequential(
            nn.Conv1d(hidden, hidden, size, padding=pad_len),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Conv1d(hidden, hidden, size, padding=pad_len),
            nn.BatchNorm1d(hidden),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        scaled = self.scale(x)
        identity = scaled
        res_out = self.res(scaled)
        return self.relu(res_out + identity)


class EncoderSplit(nn.Module):
    """Separate encoders for DNA sequence and epigenomic features"""
    def __init__(self, num_epi, output_size=256, filter_size=5, num_blocks=12):
        super().__init__()
        self.filter_size = filter_size

        # DNA sequence encoder (5 channels: A, T, C, G, N)
        self.conv_start_seq = nn.Sequential(
            nn.Conv1d(5, 16, 3, 2, 1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
        )

        # Epigenomic features encoder
        self.conv_start_epi = nn.Sequential(
            nn.Conv1d(num_epi, 16, 3, 2, 1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
        )

        hiddens = [32, 32, 32, 32, 64, 64, 128, 128, 128, 128, 256, 256]
        hidden_ins = [32, 32, 32, 32, 32, 64, 64, 128, 128, 128, 128, 256]
        hiddens_half = (np.array(hiddens) / 2).astype(int)
        hidden_ins_half = (np.array(hidden_ins) / 2).astype(int)

        self.res_blocks_seq = self._get_res_blocks(num_blocks, hidden_ins_half, hiddens_half)
        self.res_blocks_epi = self._get_res_blocks(num_blocks, hidden_ins_half, hiddens_half)
        self.conv_end = nn.Conv1d(256, output_size, 1)

    def forward(self, x):
        # x: [B, 5+num_epi, L]
        seq = x[:, :5, :]
        epi = x[:, 5:, :]

        seq = self.res_blocks_seq(self.conv_start_seq(seq))
        epi = self.res_blocks_epi(self.conv_start_epi(epi))

        x = torch.cat([seq, epi], dim=1)
        return self.conv_end(x)

    def _get_res_blocks(self, n, his, hs):
        blocks = []
        for hi, h in zip(his, hs):
            blocks.append(ConvBlock(self.filter_size, hidden_in=hi, hidden=h))
        return nn.Sequential(*blocks)


class PositionalEncoding(nn.Module):
    def __init__(self, hidden, dropout=0.1, max_len=2048):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden, 2) * (-np.log(10000.0) / hidden))
        pe = torch.zeros(max_len, 1, hidden)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            position = torch.arange(seq_len, device=x.device).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, x.size(2), 2, device=x.device) * (-np.log(10000.0) / x.size(2)))
            pe = torch.zeros(seq_len, 1, x.size(2), device=x.device)
            pe[:, 0, 0::2] = torch.sin(position * div_term)
            pe[:, 0, 1::2] = torch.cos(position * div_term)
            x = x + pe.squeeze(1)
        else:
            x = x + self.pe[:seq_len].squeeze(1)
        return self.dropout(x)


class TransformerModule(nn.Module):
    def __init__(self, hidden=256, layers=8):
        super().__init__()
        self.pos_encoder = PositionalEncoding(hidden, dropout=0.1, max_len=2048)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden,
            nhead=8,
            dropout=0.1,
            dim_feedforward=512,
            batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, layers)

    def forward(self, x):
        x = self.pos_encoder(x)
        return self.transformer(x)


class ResBlockDilated(nn.Module):
    def __init__(self, size, hidden=64, dil=2):
        super().__init__()
        pad_len = dil
        self.res = nn.Sequential(
            nn.Conv2d(hidden, hidden, size, padding=pad_len, dilation=dil),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.Conv2d(hidden, hidden, size, padding=pad_len, dilation=dil),
            nn.BatchNorm2d(hidden),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x
        res_out = self.res(x)
        return self.relu(res_out + identity)


class Decoder(nn.Module):
    def __init__(self, in_channel, hidden=256, filter_size=3, num_blocks=5, output_size=256):
        super().__init__()
        self.output_size = output_size
        self.conv_start = nn.Sequential(
            nn.Conv2d(in_channel, hidden, 3, 1, 1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
        )
        self.res_blocks = self._get_res_blocks(num_blocks, hidden)
        self.conv_end = nn.Conv2d(hidden, 1, 1)

    def forward(self, x):
        x = self.conv_start(x)
        x = self.res_blocks(x)
        x = self.conv_end(x).squeeze(1)

        if x.size(-1) != self.output_size or x.size(-2) != self.output_size:
            x = nn.functional.interpolate(
                x.unsqueeze(1),
                size=(self.output_size, self.output_size),
                mode='bilinear',
                align_corners=False
            ).squeeze(1)

        return x

    def _get_res_blocks(self, n, hidden):
        blocks = []
        for i in range(n):
            dilation = 2 ** (i + 1)
            blocks.append(ResBlockDilated(3, hidden=hidden, dil=dilation))
        return nn.Sequential(*blocks)


class ConvTransHiCModel(nn.Module):
    """Full C.Origami-style model: Encoder + Transformer + Decoder"""
    def __init__(self, num_genomic_features=2, mid_hidden=256, output_size=256):
        super().__init__()
        self.encoder = EncoderSplit(num_genomic_features, output_size=mid_hidden, num_blocks=12)
        self.transformer = TransformerModule(hidden=mid_hidden, layers=8)

        # MEMORY FIX: First downsample the sequence, THEN diagonalize
        self.pre_decoder_conv = nn.Sequential(
            nn.Conv1d(mid_hidden, mid_hidden, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(mid_hidden),
            nn.ReLU(),
            nn.Conv1d(mid_hidden, mid_hidden, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(mid_hidden),
            nn.ReLU(),
        )

        self.decoder = Decoder(mid_hidden * 2, hidden=256, output_size=output_size)

    def forward(self, x):
        # x: [B, features, length]
        x = x.float()
        x = self.encoder(x)  # [B, D, L] - L ~ 611
        x = x.transpose(1, 2)  # [B, L, D]
        x = self.transformer(x)  # [B, L, D]
        x = x.transpose(1, 2)  # [B, D, L]

        # MEMORY FIX: Downsample before diagonalizing to reduce memory
        x = self.pre_decoder_conv(x)  # [B, D, L//4] - L//4 ~ 153

        x = self._diagonalize(x)  # [B, 2D, L//4, L//4] - Much smaller!
        return self.decoder(x)  # [B, output_size, output_size]

    def _diagonalize(self, x):
        # x: [B, D, L]
        x_i = x.unsqueeze(3).repeat(1, 1, 1, x.size(2))  # [B, D, L, L]
        x_j = x.unsqueeze(2).repeat(1, 1, x.size(2), 1)  # [B, D, L, L]
        return torch.cat([x_i, x_j], dim=1)  # [B, 2D, L, L]


# ================================================================
# Dataset
# ================================================================

import gzip
import pyBigWig as pbw
from skimage.transform import resize

class SequenceFeature:
    def __init__(self, path):
        print(f'Reading sequence: {path}')
        with gzip.open(path, 'r') as f:
            seq = f.read().decode("utf-8")
        seq = seq[seq.find('\n')+1:].replace('\n', '').lower()
        self.seq = seq

    def get(self, start, end):
        seq = self.seq[start:end]
        en_dict = {'a': 0, 't': 1, 'c': 2, 'g': 3, 'n': 4}
        idx = np.array([en_dict.get(ch, 4) for ch in seq], dtype=int)
        onehot = np.zeros((len(seq), 5), dtype=np.float32)
        if len(seq) > 0:
            onehot[np.arange(len(seq)), idx] = 1.0
        return onehot

    def __len__(self):
        return len(self.seq)


class HiCFeature:
    def __init__(self, path):
        print(f'Reading Hi-C: {path}')
        self.hic = dict(np.load(path))

    def get(self, start, window=2097152, res=10000):
        start_bin = int(start / res)
        end_bin = start_bin + int(window / res)
        return self._diag_to_mat(self.hic, start_bin, end_bin)

    def _diag_to_mat(self, ori_load, start, end):
        square_len = end - start
        diag_load = {}
        for d in range(square_len):
            diag_load[str(d)] = ori_load[str(d)][start:start + (square_len - d)]
            diag_load[str(-d)] = ori_load[str(-d)][start:start + (square_len - d)]

        rows = []
        for d in range(square_len):
            row = []
            for line in range(-d, -d + square_len):
                if line < 0:
                    row.append(diag_load[str(line)][line + d])
                else:
                    row.append(diag_load[str(line)][d])
            rows.append(row)
        return np.array(rows, dtype=np.float32)

    def __len__(self):
        return len(self.hic['0'])


class GenomicFeature:
    def __init__(self, path, norm):
        self.path = path
        self.norm = norm
        print(f'Feature: {path}, Norm: {norm}')

    def get(self, chr_name, start, end):
        with pbw.open(self.path) as bw:
            signals = np.array(bw.values(chr_name, int(start), int(end)))
        signals = np.nan_to_num(signals, 0.0)
        if self.norm == 'log':
            signals = np.log(signals + 1)
        return signals.astype(np.float32)

    def length(self, chr_name):
        with pbw.open(self.path) as bw:
            return bw.chroms(chr_name)


class ChromosomeDataset(torch.utils.data.Dataset):
    def __init__(self, celltype_root, chr_name, omit_regions, feature_list, use_aug=True):
        self.res = 10000
        self.sample_bins = 500
        self.stride = 50
        self.image_scale = 256
        self.chr_name = chr_name
        self.use_aug = use_aug

        print(f'Loading {chr_name}...')
        self.seq = SequenceFeature(f'{celltype_root}/../dna_sequence/{chr_name}.fa.gz')
        self.genomic_features = feature_list
        self.mat = HiCFeature(f'{celltype_root}/hic_matrix/{chr_name}.npz')
        self.omit_regions = omit_regions

        self._check_length()
        all_intervals = self._get_intervals()
        self.intervals = self._filter(all_intervals, omit_regions)

    def __len__(self):
        return len(self.intervals)

    def __getitem__(self, idx):
        start, end = self.intervals[idx]
        target_size = int(self.sample_bins * self.res)

        if self.use_aug:
            offset = np.random.randint(0, end - start - target_size + 1)
        else:
            offset = 0
        start, end = start + offset, start + offset + target_size

        seq = self.seq.get(start, end)
        features = [f.get(self.chr_name, start, end) for f in self.genomic_features]
        mat = self.mat.get(start)
        mat = resize(mat, (self.image_scale, self.image_scale), anti_aliasing=True)
        mat = np.log(mat + 1).astype(np.float32)

        if self.use_aug:
            seq = self._gaussian_noise(seq, 0.1)
            features = [self._gaussian_noise(f, 0.1) for f in features]
            seq, features, mat = self._reverse_complement(seq, features, mat)

        return seq, features, mat

    def _gaussian_noise(self, x, std=0.1):
        return x + np.random.randn(*x.shape).astype(np.float32) * std

    def _reverse_complement(self, seq, features, mat, chance=0.5):
        if np.random.rand() < chance:
            seq = np.flip(seq, 0).copy()
            seq = np.concatenate([seq[:, 1:2], seq[:, 0:1], seq[:, 3:4], seq[:, 2:3], seq[:, 4:5]], axis=1)
            features = [np.flip(f, 0).copy() for f in features]
            mat = np.flip(mat, [0, 1]).copy()
        return seq, features, mat

    def _get_intervals(self):
        chr_bins = len(self.seq) / self.res
        n = max(0, int((chr_bins - self.sample_bins) / self.stride))
        starts = np.arange(n).reshape(-1, 1) * self.stride
        bins = np.concatenate([starts, starts + self.sample_bins], axis=1) if n > 0 else np.zeros((0, 2))
        return (bins * self.res).astype(int)

    def _filter(self, intervals, omit_regions):
        if omit_regions is None or len(omit_regions) == 0:
            return intervals.tolist()
        valid = []
        for s, e in intervals:
            if np.sum((s <= omit_regions[:, 1]) & (omit_regions[:, 0] <= e)) == 0:
                valid.append([int(s), int(e)])
        return valid

    def _check_length(self):
        if self.genomic_features:
            assert len(self.seq) == self.genomic_features[0].length(self.chr_name)
        assert abs(len(self.seq) / self.res - len(self.mat)) < 2


# ================================================================
# Evaluation Metrics
# ================================================================

from scipy.stats import pearsonr, spearmanr
from scipy.ndimage import uniform_filter1d

def upper_triangle_flatten(mat):
    iu = np.triu_indices_from(mat, k=1)
    return mat[iu]

def compute_global_corr(pred, true):
    p, t = upper_triangle_flatten(pred), upper_triangle_flatten(true)
    mask = np.isfinite(p) & np.isfinite(t)
    if mask.sum() < 3:
        return np.nan, np.nan
    return pearsonr(p[mask], t[mask])[0], spearmanr(p[mask], t[mask])[0]

def compute_distance_stratified_corr(pred, true, bin_size=10_000, step=50_000, max_dist=None):
    W = pred.shape[0]
    if max_dist is None:
        max_dist = (W - 1) * bin_size

    dists = np.abs(np.subtract.outer(np.arange(W), np.arange(W))) * bin_size
    out = {}

    for start in range(0, max_dist, step):
        end = start + step
        mask = (dists >= start) & (dists < end)
        mask &= ~np.eye(W, dtype=bool)

        if mask.sum() < 100:
            continue

        try:
            r, _ = pearsonr(pred[mask], true[mask])
            out[(start, end)] = r
        except:
            continue

    return out

def compute_insulation_score(mat, window=40):
    W = mat.shape[0]
    insul = np.zeros(W, dtype=np.float32)

    for i in range(window, W - window):
        sub = mat[i-window:i, i:i+window]
        insul[i] = np.mean(sub)

    insul = -np.log(insul + 1e-6)
    insul = uniform_filter1d(insul, size=5)
    return insul

def compute_insulation_corr(pred, true, window=40):
    ip = compute_insulation_score(pred, window=window)
    it = compute_insulation_score(true, window=window)
    mask = np.isfinite(ip) & np.isfinite(it)
    if mask.sum() < 3:
        return np.nan
    return pearsonr(ip[mask], it[mask])[0]


# ================================================================
# Custom Collate Function
# ================================================================

def collate_fn(batch):
    seqs, features_list, mats = zip(*batch)

    seqs = torch.from_numpy(np.stack(seqs))

    num_features = len(features_list[0])
    features = [np.stack([f[i] for f in features_list]) for i in range(num_features)]
    features = [torch.from_numpy(f) for f in features]

    mats = torch.from_numpy(np.stack(mats))

    return seqs, features, mats


# ================================================================
# PyTorch Lightning Module
# ================================================================

class HiCTrainingModule(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters(ignore=['args'])
        self.args = args
        self.model = ConvTransHiCModel(num_genomic_features=2, mid_hidden=256, output_size=256)

        self.val_preds = []
        self.val_trues = []

    def forward(self, x):
        return self.model(x)

    def _proc_batch(self, batch):
        seq, features, mat = batch
        features = torch.stack(features, dim=2)
        inputs = torch.cat([seq, features], dim=2)
        inputs = inputs.transpose(1, 2)
        return inputs, mat

    def training_step(self, batch, batch_idx):
        inputs, mat = self._proc_batch(batch)
        outputs = self(inputs)
        loss = nn.functional.mse_loss(outputs, mat)
        self.log('train_loss', loss, prog_bar=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, mat = self._proc_batch(batch)
        outputs = self(inputs)
        loss = nn.functional.mse_loss(outputs, mat)
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)

        self.val_preds.append(outputs.detach().cpu().numpy())
        self.val_trues.append(mat.detach().cpu().numpy())

        return loss

    def on_validation_epoch_end(self):
        if len(self.val_preds) == 0:
            return

        preds = np.concatenate(self.val_preds, axis=0)
        trues = np.concatenate(self.val_trues, axis=0)

        global_pears, global_spears, ins_corrs = [], [], []
        dist_bin_stats = {}

        for i in range(len(preds)):
            pred, true = preds[i], trues[i]

            gp, gs = compute_global_corr(pred, true)
            global_pears.append(gp)
            global_spears.append(gs)

            ic = compute_insulation_corr(pred, true, window=40)
            ins_corrs.append(ic)

            d = compute_distance_stratified_corr(pred, true, bin_size=10_000,
                                                 step=50_000, max_dist=2_560_000)
            for k, r in d.items():
                dist_bin_stats.setdefault(k, []).append(r)

        def mean_finite(x):
            x = [v for v in x if np.isfinite(v)]
            return float(np.mean(x)) if len(x) else float('nan')

        metrics = {
            'val_pearson': mean_finite(global_pears),
            'val_spearman': mean_finite(global_spears),
            'val_insulation': mean_finite(ins_corrs),
        }

        bins_sorted = sorted(dist_bin_stats.keys(), key=lambda x: x[0])
        for idx, (s, e) in enumerate(bins_sorted[:5]):
            vals = dist_bin_stats[(s, e)]
            metrics[f'val_dist_{s//1000}kb'] = mean_finite(vals)

        self.log_dict(metrics, prog_bar=False, sync_dist=True)

        self.val_preds.clear()
        self.val_trues.clear()

    def test_step(self, batch, batch_idx):
        inputs, mat = self._proc_batch(batch)
        outputs = self(inputs)
        loss = nn.functional.mse_loss(outputs, mat)

        if not hasattr(self, 'test_preds'):
            self.test_preds = []
            self.test_trues = []

        self.test_preds.append(outputs.detach().cpu().numpy())
        self.test_trues.append(mat.detach().cpu().numpy())

        return loss

    def on_test_epoch_end(self):
        if not hasattr(self, 'test_preds') or len(self.test_preds) == 0:
            return

        preds = np.concatenate(self.test_preds, axis=0)
        trues = np.concatenate(self.test_trues, axis=0)

        global_pears, global_spears, ins_corrs = [], [], []
        dist_bin_stats = {}

        for i in range(len(preds)):
            pred, true = preds[i], trues[i]

            gp, gs = compute_global_corr(pred, true)
            global_pears.append(gp)
            global_spears.append(gs)

            ic = compute_insulation_corr(pred, true, window=40)
            ins_corrs.append(ic)

            d = compute_distance_stratified_corr(pred, true, bin_size=10_000,
                                                 step=50_000, max_dist=2_560_000)
            for k, r in d.items():
                dist_bin_stats.setdefault(k, []).append(r)

        def mean_finite(x):
            x = [v for v in x if np.isfinite(v)]
            return float(np.mean(x)) if len(x) else float('nan')

        if self.trainer.is_global_zero:
            results_path = f"{self.args.save_path}/test_results.txt"
            bins_sorted = sorted(dist_bin_stats.keys(), key=lambda x: x[0])

            with open(results_path, 'w') as f:
                f.write(f"FINAL TEST EVALUATION\n")
                f.write(f"{'='*70}\n")
                f.write(f"Test samples: {len(preds)}\n\n")
                f.write(f"Global Pearson:  {mean_finite(global_pears):.4f}\n")
                f.write(f"Global Spearman: {mean_finite(global_spears):.4f}\n")
                f.write(f"Insulation corr: {mean_finite(ins_corrs):.4f}\n\n")
                f.write(f"Distance-stratified Pearson:\n")
                for (s, e) in bins_sorted:
                    vals = dist_bin_stats[(s, e)]
                    f.write(f"  {s//1000:4d}-{e//1000:4d} kb: {mean_finite(vals):.4f}\n")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=2e-4, weight_decay=0)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.args.max_epochs, eta_min=1e-6
        )
        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'epoch'}}


# ================================================================
# Main Training Script
# ================================================================

def load_centrotelo(bed_path):
    if not os.path.exists(bed_path):
        return {}
    import pandas as pd
    df = pd.read_csv(bed_path, sep='\t', names=['chr', 'start', 'end'])
    result = {}
    for chr_name, group in df.groupby('chr'):
        result[chr_name] = group[['start', 'end']].to_numpy(dtype=int)
    return result


def get_available_chromosomes(celltype_root):
    hic_dir = f'{celltype_root}/hic_matrix'
    available = []

    if os.path.exists(hic_dir):
        for f in os.listdir(hic_dir):
            if f.endswith('.npz'):
                chr_name = f.replace('.npz', '')
                available.append(chr_name)

    return sorted(available)


def get_dataset(args, mode):
    celltype_root = f'{args.data_root}/{args.assembly}/{args.celltype}'

    features = [
        GenomicFeature(f'{celltype_root}/genomic_features/ctcf_log2fc.bw', norm=None),
        GenomicFeature(f'{celltype_root}/genomic_features/atac.bw', norm='log'),
    ]

    centrotelo = load_centrotelo(f'{celltype_root}/../centrotelo.bed')

    available_chrs = get_available_chromosomes(celltype_root)
    print(f"\nAvailable chromosomes: {available_chrs}")

    if mode == 'train':
        desired = [f'chr{i}' for i in range(1, 23) if i not in [10, 15]]
    elif mode == 'val':
        desired = ['chr10']
    elif mode == 'test':
        desired = ['chr15']
    else:
        raise ValueError(f'Unknown mode: {mode}')

    chr_names = [c for c in desired if c in available_chrs]

    if len(chr_names) == 0:
        raise ValueError(f"No valid chromosomes found for mode '{mode}'")

    print(f"Using chromosomes for {mode}: {chr_names}\n")

    datasets = []
    for chr_name in chr_names:
        omit = centrotelo.get(chr_name, np.zeros((0, 2), dtype=int))
        use_aug = (mode == 'train')

        try:
            ds = ChromosomeDataset(celltype_root, chr_name, omit, features, use_aug=use_aug)
            datasets.append(ds)
            print(f"✓ Loaded {chr_name}: {len(ds)} samples")
        except FileNotFoundError as e:
            print(f"⚠️  Skipping {chr_name}: File not found")
            continue
        except Exception as e:
            print(f"⚠️  Skipping {chr_name}: {str(e)}")
            continue

    if len(datasets) == 0:
        raise ValueError(f"No datasets loaded for mode '{mode}'")

    print(f"Total {mode} samples: {sum(len(d) for d in datasets)}\n")

    return torch.utils.data.ConcatDataset(datasets)


def main():
    class Args:
        data_root = "/content/drive/MyDrive/ML4GEN DATA/data - IMR90"
        assembly = "hg38"
        celltype = "imr90"
        batch_size = 4  # REDUCED from 8 to 4 for memory
        num_workers = 2  # REDUCED from 4 to 2 for memory
        max_epochs = 80
        patience = 20
        num_gpus = 1
        save_path = "checkpoints"
        save_top_k = 5

    args = Args()

    pl.seed_everything(2077, workers=True)

    os.makedirs(f'{args.save_path}/models', exist_ok=True)
    os.makedirs(f'{args.save_path}/logs', exist_ok=True)

    print("="*70)
    print("LOADING DATASETS")
    print("="*70)

    train_ds = get_dataset(args, 'train')
    val_ds = get_dataset(args, 'val')
    test_ds = get_dataset(args, 'test')

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        persistent_workers=True,
        collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        persistent_workers=True,
        collate_fn=collate_fn
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        persistent_workers=True,
        collate_fn=collate_fn
    )

    model = HiCTrainingModule(args)

    checkpoint_cb = callbacks.ModelCheckpoint(
        dirpath=f'{args.save_path}/models',
        filename='hic-{epoch:02d}-{val_loss:.4f}',
        monitor='val_loss',
        save_top_k=args.save_top_k,
        mode='min'
    )

    early_stop_cb = callbacks.EarlyStopping(
        monitor='val_loss',
        patience=args.patience,
        mode='min'
    )

    lr_monitor = callbacks.LearningRateMonitor(logging_interval='epoch')

    csv_logger = pl.loggers.CSVLogger(save_dir=f'{args.save_path}/logs')

    trainer = pl.Trainer(
        accelerator='gpu',
        devices=args.num_gpus,
        strategy='ddp' if args.num_gpus > 1 else 'auto',
        max_epochs=args.max_epochs,
        callbacks=[checkpoint_cb, early_stop_cb, lr_monitor],
        logger=csv_logger,
        gradient_clip_val=1.0,
        precision='16-mixed',
        enable_progress_bar=True,
        enable_model_summary=True,
        accumulate_grad_batches=2,  # ADDED: Gradient accumulation to simulate batch_size=8
    )

    print("\n" + "="*70)
    print("STARTING TRAINING")
    print("="*70)
    trainer.fit(model, train_loader, val_loader)

    print("\n" + "="*70)
    print("TESTING BEST MODEL")
    print("="*70)
    trainer.test(model, test_loader, ckpt_path='best')

    print(f"\n✅ Training complete! Best model saved to {args.save_path}/models")
    print(f"✅ Test results saved to {args.save_path}/test_results.txt")


if __name__ == '__main__':
    main()

INFO:lightning_fabric.utilities.seed:Seed set to 2077


LOADING DATASETS
Feature: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/ctcf_log2fc.bw, Norm: None
Feature: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/atac.bw, Norm: log

Available chromosomes: ['chr1', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr2', 'chr20', 'chr21', 'chr22', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chrX']
Using chromosomes for train: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr11', 'chr12', 'chr13', 'chr14', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22']

Loading chr1...
Reading sequence: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/../dna_sequence/chr1.fa.gz
Reading Hi-C: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/hic_matrix/chr1.npz


KeyboardInterrupt: 

In [None]:
"""
End-to-End Hi-C Prediction Training with C.Origami Architecture
SPEED OPTIMIZED VERSION FOR A100
"""

import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import pytorch_lightning.callbacks as callbacks
import gc

# ================================================================
# Model Components (C.Origami-style architecture)
# ================================================================

class ConvBlock(nn.Module):
    def __init__(self, size, stride=2, hidden_in=64, hidden=64):
        super().__init__()
        pad_len = int(size / 2)
        self.scale = nn.Sequential(
            nn.Conv1d(hidden_in, hidden, size, stride, pad_len),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
        )
        self.res = nn.Sequential(
            nn.Conv1d(hidden, hidden, size, padding=pad_len),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Conv1d(hidden, hidden, size, padding=pad_len),
            nn.BatchNorm1d(hidden),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        scaled = self.scale(x)
        identity = scaled
        res_out = self.res(scaled)
        return self.relu(res_out + identity)


class EncoderSplit(nn.Module):
    """Separate encoders for DNA sequence and epigenomic features"""
    def __init__(self, num_epi, output_size=256, filter_size=5, num_blocks=12):
        super().__init__()
        self.filter_size = filter_size

        self.conv_start_seq = nn.Sequential(
            nn.Conv1d(5, 16, 3, 2, 1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
        )

        self.conv_start_epi = nn.Sequential(
            nn.Conv1d(num_epi, 16, 3, 2, 1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
        )

        hiddens = [32, 32, 32, 32, 64, 64, 128, 128, 128, 128, 256, 256]
        hidden_ins = [32, 32, 32, 32, 32, 64, 64, 128, 128, 128, 128, 256]
        hiddens_half = (np.array(hiddens) / 2).astype(int)
        hidden_ins_half = (np.array(hidden_ins) / 2).astype(int)

        self.res_blocks_seq = self._get_res_blocks(num_blocks, hidden_ins_half, hiddens_half)
        self.res_blocks_epi = self._get_res_blocks(num_blocks, hidden_ins_half, hiddens_half)
        self.conv_end = nn.Conv1d(256, output_size, 1)

    def forward(self, x):
        seq = x[:, :5, :]
        epi = x[:, 5:, :]

        seq = self.res_blocks_seq(self.conv_start_seq(seq))
        epi = self.res_blocks_epi(self.conv_start_epi(epi))

        x = torch.cat([seq, epi], dim=1)
        return self.conv_end(x)

    def _get_res_blocks(self, n, his, hs):
        blocks = []
        for hi, h in zip(his, hs):
            blocks.append(ConvBlock(self.filter_size, hidden_in=hi, hidden=h))
        return nn.Sequential(*blocks)


class PositionalEncoding(nn.Module):
    def __init__(self, hidden, dropout=0.1, max_len=2048):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden, 2) * (-np.log(10000.0) / hidden))
        pe = torch.zeros(max_len, 1, hidden)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            position = torch.arange(seq_len, device=x.device).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, x.size(2), 2, device=x.device) * (-np.log(10000.0) / x.size(2)))
            pe = torch.zeros(seq_len, 1, x.size(2), device=x.device)
            pe[:, 0, 0::2] = torch.sin(position * div_term)
            pe[:, 0, 1::2] = torch.cos(position * div_term)
            x = x + pe.squeeze(1)
        else:
            x = x + self.pe[:seq_len].squeeze(1)
        return self.dropout(x)


class TransformerModule(nn.Module):
    def __init__(self, hidden=256, layers=8):
        super().__init__()
        self.pos_encoder = PositionalEncoding(hidden, dropout=0.1, max_len=2048)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden,
            nhead=8,
            dropout=0.1,
            dim_feedforward=512,
            batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, layers)

    def forward(self, x):
        x = self.pos_encoder(x)
        return self.transformer(x)


class ResBlockDilated(nn.Module):
    def __init__(self, size, hidden=64, dil=2):
        super().__init__()
        pad_len = dil
        self.res = nn.Sequential(
            nn.Conv2d(hidden, hidden, size, padding=pad_len, dilation=dil),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.Conv2d(hidden, hidden, size, padding=pad_len, dilation=dil),
            nn.BatchNorm2d(hidden),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x
        res_out = self.res(x)
        return self.relu(res_out + identity)


class Decoder(nn.Module):
    def __init__(self, in_channel, hidden=256, filter_size=3, num_blocks=5, output_size=256):
        super().__init__()
        self.output_size = output_size
        self.conv_start = nn.Sequential(
            nn.Conv2d(in_channel, hidden, 3, 1, 1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
        )
        self.res_blocks = self._get_res_blocks(num_blocks, hidden)
        self.conv_end = nn.Conv2d(hidden, 1, 1)

    def forward(self, x):
        x = self.conv_start(x)
        x = self.res_blocks(x)
        x = self.conv_end(x).squeeze(1)

        if x.size(-1) != self.output_size or x.size(-2) != self.output_size:
            x = nn.functional.interpolate(
                x.unsqueeze(1),
                size=(self.output_size, self.output_size),
                mode='bilinear',
                align_corners=False
            ).squeeze(1)

        return x

    def _get_res_blocks(self, n, hidden):
        blocks = []
        for i in range(n):
            dilation = 2 ** (i + 1)
            blocks.append(ResBlockDilated(3, hidden=hidden, dil=dilation))
        return nn.Sequential(*blocks)


class ConvTransHiCModel(nn.Module):
    """Full C.Origami-style model: Encoder + Transformer + Decoder"""
    def __init__(self, num_genomic_features=2, mid_hidden=256, output_size=256):
        super().__init__()
        self.encoder = EncoderSplit(num_genomic_features, output_size=mid_hidden, num_blocks=12)
        self.transformer = TransformerModule(hidden=mid_hidden, layers=8)

        self.pre_decoder_conv = nn.Sequential(
            nn.Conv1d(mid_hidden, mid_hidden, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(mid_hidden),
            nn.ReLU(),
            nn.Conv1d(mid_hidden, mid_hidden, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(mid_hidden),
            nn.ReLU(),
        )

        self.decoder = Decoder(mid_hidden * 2, hidden=256, output_size=output_size)

    def forward(self, x):
        x = x.float()
        x = self.encoder(x)
        x = x.transpose(1, 2)
        x = self.transformer(x)
        x = x.transpose(1, 2)
        x = self.pre_decoder_conv(x)
        x = self._diagonalize(x)
        return self.decoder(x)

    def _diagonalize(self, x):
        x_i = x.unsqueeze(3).repeat(1, 1, 1, x.size(2))
        x_j = x.unsqueeze(2).repeat(1, 1, x.size(2), 1)
        return torch.cat([x_i, x_j], dim=1)


# ================================================================
# Dataset
# ================================================================

import gzip
import pyBigWig as pbw
from skimage.transform import resize

class SequenceFeature:
    def __init__(self, path):
        print(f'Reading sequence: {path}')
        with gzip.open(path, 'r') as f:
            seq = f.read().decode("utf-8")
        seq = seq[seq.find('\n')+1:].replace('\n', '').lower()
        self.seq = seq

    def get(self, start, end):
        seq = self.seq[start:end]
        en_dict = {'a': 0, 't': 1, 'c': 2, 'g': 3, 'n': 4}
        idx = np.array([en_dict.get(ch, 4) for ch in seq], dtype=int)
        onehot = np.zeros((len(seq), 5), dtype=np.float32)
        if len(seq) > 0:
            onehot[np.arange(len(seq)), idx] = 1.0
        return onehot

    def __len__(self):
        return len(self.seq)


class HiCFeature:
    def __init__(self, path):
        print(f'Reading Hi-C: {path}')
        self.hic = dict(np.load(path))

    def get(self, start, window=2097152, res=10000):
        start_bin = int(start / res)
        end_bin = start_bin + int(window / res)
        return self._diag_to_mat(self.hic, start_bin, end_bin)

    def _diag_to_mat(self, ori_load, start, end):
        square_len = end - start
        diag_load = {}
        for d in range(square_len):
            diag_load[str(d)] = ori_load[str(d)][start:start + (square_len - d)]
            diag_load[str(-d)] = ori_load[str(-d)][start:start + (square_len - d)]

        rows = []
        for d in range(square_len):
            row = []
            for line in range(-d, -d + square_len):
                if line < 0:
                    row.append(diag_load[str(line)][line + d])
                else:
                    row.append(diag_load[str(line)][d])
            rows.append(row)
        return np.array(rows, dtype=np.float32)

    def __len__(self):
        return len(self.hic['0'])


class GenomicFeature:
    def __init__(self, path, norm):
        self.path = path
        self.norm = norm
        print(f'Feature: {path}, Norm: {norm}')

    def get(self, chr_name, start, end):
        with pbw.open(self.path) as bw:
            signals = np.array(bw.values(chr_name, int(start), int(end)))
        signals = np.nan_to_num(signals, 0.0)
        if self.norm == 'log':
            signals = np.log(signals + 1)
        return signals.astype(np.float32)

    def length(self, chr_name):
        with pbw.open(self.path) as bw:
            return bw.chroms(chr_name)


class ChromosomeDataset(torch.utils.data.Dataset):
    def __init__(self, celltype_root, chr_name, omit_regions, feature_list, use_aug=True):
        self.res = 10000
        self.sample_bins = 500
        self.stride = 50
        self.image_scale = 256
        self.chr_name = chr_name
        self.use_aug = use_aug

        print(f'Loading {chr_name}...')
        self.seq = SequenceFeature(f'{celltype_root}/../dna_sequence/{chr_name}.fa.gz')
        self.genomic_features = feature_list
        self.mat = HiCFeature(f'{celltype_root}/hic_matrix/{chr_name}.npz')
        self.omit_regions = omit_regions

        self._check_length()
        all_intervals = self._get_intervals()
        self.intervals = self._filter(all_intervals, omit_regions)

    def __len__(self):
        return len(self.intervals)

    def __getitem__(self, idx):
        start, end = self.intervals[idx]
        target_size = int(self.sample_bins * self.res)

        if self.use_aug:
            offset = np.random.randint(0, end - start - target_size + 1)
        else:
            offset = 0
        start, end = start + offset, start + offset + target_size

        seq = self.seq.get(start, end)
        features = [f.get(self.chr_name, start, end) for f in self.genomic_features]
        mat = self.mat.get(start)
        mat = resize(mat, (self.image_scale, self.image_scale), anti_aliasing=True)
        mat = np.log(mat + 1).astype(np.float32)

        if self.use_aug:
            seq = self._gaussian_noise(seq, 0.1)
            features = [self._gaussian_noise(f, 0.1) for f in features]
            seq, features, mat = self._reverse_complement(seq, features, mat)

        return seq, features, mat

    def _gaussian_noise(self, x, std=0.1):
        return x + np.random.randn(*x.shape).astype(np.float32) * std

    def _reverse_complement(self, seq, features, mat, chance=0.5):
        if np.random.rand() < chance:
            seq = np.flip(seq, 0).copy()
            seq = np.concatenate([seq[:, 1:2], seq[:, 0:1], seq[:, 3:4], seq[:, 2:3], seq[:, 4:5]], axis=1)
            features = [np.flip(f, 0).copy() for f in features]
            mat = np.flip(mat, [0, 1]).copy()
        return seq, features, mat

    def _get_intervals(self):
        chr_bins = len(self.seq) / self.res
        n = max(0, int((chr_bins - self.sample_bins) / self.stride))
        starts = np.arange(n).reshape(-1, 1) * self.stride
        bins = np.concatenate([starts, starts + self.sample_bins], axis=1) if n > 0 else np.zeros((0, 2))
        return (bins * self.res).astype(int)

    def _filter(self, intervals, omit_regions):
        if omit_regions is None or len(omit_regions) == 0:
            return intervals.tolist()
        valid = []
        for s, e in intervals:
            if np.sum((s <= omit_regions[:, 1]) & (omit_regions[:, 0] <= e)) == 0:
                valid.append([int(s), int(e)])
        return valid

    def _check_length(self):
        if self.genomic_features:
            assert len(self.seq) == self.genomic_features[0].length(self.chr_name)
        assert abs(len(self.seq) / self.res - len(self.mat)) < 2


# ================================================================
# Evaluation Metrics (SIMPLIFIED FOR SPEED)
# ================================================================

from scipy.stats import pearsonr, spearmanr
from scipy.ndimage import uniform_filter1d

def upper_triangle_flatten(mat):
    iu = np.triu_indices_from(mat, k=1)
    return mat[iu]

def compute_global_corr(pred, true):
    p, t = upper_triangle_flatten(pred), upper_triangle_flatten(true)
    mask = np.isfinite(p) & np.isfinite(t)
    if mask.sum() < 3:
        return np.nan, np.nan
    return pearsonr(p[mask], t[mask])[0], spearmanr(p[mask], t[mask])[0]


# ================================================================
# Custom Collate Function
# ================================================================

def collate_fn(batch):
    seqs, features_list, mats = zip(*batch)

    seqs = torch.from_numpy(np.stack(seqs))

    num_features = len(features_list[0])
    features = [np.stack([f[i] for f in features_list]) for i in range(num_features)]
    features = [torch.from_numpy(f) for f in features]

    mats = torch.from_numpy(np.stack(mats))

    return seqs, features, mats


# ================================================================
# PyTorch Lightning Module (SIMPLIFIED FOR SPEED)
# ================================================================

class HiCTrainingModule(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters(ignore=['args'])
        self.args = args
        self.model = ConvTransHiCModel(num_genomic_features=2, mid_hidden=256, output_size=256)

    def forward(self, x):
        return self.model(x)

    def _proc_batch(self, batch):
        seq, features, mat = batch
        features = torch.stack(features, dim=2)
        inputs = torch.cat([seq, features], dim=2)
        inputs = inputs.transpose(1, 2)
        return inputs, mat

    def training_step(self, batch, batch_idx):
        inputs, mat = self._proc_batch(batch)
        outputs = self(inputs)
        loss = nn.functional.mse_loss(outputs, mat)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, mat = self._proc_batch(batch)
        outputs = self(inputs)
        loss = nn.functional.mse_loss(outputs, mat)

        # SPEED: Only compute correlation on subset of validation
        if batch_idx == 0:  # Only first batch
            pred = outputs[0].detach().cpu().numpy()
            true = mat[0].detach().cpu().numpy()
            gp, gs = compute_global_corr(pred, true)
            self.log('val_pearson', gp, prog_bar=True)

        self.log('val_loss', loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        inputs, mat = self._proc_batch(batch)
        outputs = self(inputs)
        loss = nn.functional.mse_loss(outputs, mat)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=3e-4,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.1,
            anneal_strategy='cos'
        )
        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}


# ================================================================
# Main Training Script
# ================================================================

def load_centrotelo(bed_path):
    if not os.path.exists(bed_path):
        return {}
    import pandas as pd
    df = pd.read_csv(bed_path, sep='\t', names=['chr', 'start', 'end'])
    result = {}
    for chr_name, group in df.groupby('chr'):
        result[chr_name] = group[['start', 'end']].to_numpy(dtype=int)
    return result


def get_available_chromosomes(celltype_root):
    hic_dir = f'{celltype_root}/hic_matrix'
    available = []

    if os.path.exists(hic_dir):
        for f in os.listdir(hic_dir):
            if f.endswith('.npz'):
                chr_name = f.replace('.npz', '')
                available.append(chr_name)

    return sorted(available)


def get_dataset(args, mode):
    celltype_root = f'{args.data_root}/{args.assembly}/{args.celltype}'

    features = [
        GenomicFeature(f'{celltype_root}/genomic_features/ctcf_log2fc.bw', norm=None),
        GenomicFeature(f'{celltype_root}/genomic_features/atac.bw', norm='log'),
    ]

    centrotelo = load_centrotelo(f'{celltype_root}/../centrotelo.bed')

    available_chrs = get_available_chromosomes(celltype_root)
    print(f"\nAvailable chromosomes: {available_chrs}")

    if mode == 'train':
        desired = [f'chr{i}' for i in range(1, 23) if i not in [10, 15]]
    elif mode == 'val':
        desired = ['chr10']
    elif mode == 'test':
        desired = ['chr15']
    else:
        raise ValueError(f'Unknown mode: {mode}')

    chr_names = [c for c in desired if c in available_chrs]

    if len(chr_names) == 0:
        raise ValueError(f"No valid chromosomes found for mode '{mode}'")

    print(f"Using chromosomes for {mode}: {chr_names}\n")

    datasets = []
    for chr_name in chr_names:
        omit = centrotelo.get(chr_name, np.zeros((0, 2), dtype=int))
        use_aug = (mode == 'train')

        try:
            ds = ChromosomeDataset(celltype_root, chr_name, omit, features, use_aug=use_aug)
            datasets.append(ds)
            print(f"✓ Loaded {chr_name}: {len(ds)} samples")
        except FileNotFoundError as e:
            print(f"⚠️  Skipping {chr_name}: File not found")
            continue
        except Exception as e:
            print(f"⚠️  Skipping {chr_name}: {str(e)}")
            continue

    if len(datasets) == 0:
        raise ValueError(f"No datasets loaded for mode '{mode}'")

    print(f"Total {mode} samples: {sum(len(d) for d in datasets)}\n")

    return torch.utils.data.ConcatDataset(datasets)


def main():
    # Clear memory first
    torch.cuda.empty_cache()
    gc.collect()

    # Enable optimizations for A100
    torch.backends.cudnn.benchmark = True
    torch.set_float32_matmul_precision('high')

    class Args:
        data_root = "/content/drive/MyDrive/ML4GEN DATA/data - IMR90"
        assembly = "hg38"
        celltype = "imr90"
        batch_size = 12
        num_workers = 4
        max_epochs = 10  # CHANGED TO 10 EPOCHS
        patience = 5  # Adjusted patience for 10 epochs
        num_gpus = 1
        save_path = "checkpoints"
        save_top_k = 3

    args = Args()

    pl.seed_everything(2077, workers=True)

    os.makedirs(f'{args.save_path}/models', exist_ok=True)
    os.makedirs(f'{args.save_path}/logs', exist_ok=True)

    print("="*70)
    print("LOADING DATASETS")
    print("="*70)

    train_ds = get_dataset(args, 'train')
    val_ds = get_dataset(args, 'val')
    test_ds = get_dataset(args, 'test')

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        persistent_workers=True,
        collate_fn=collate_fn,
        prefetch_factor=2,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        persistent_workers=True,
        collate_fn=collate_fn,
        prefetch_factor=2,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        collate_fn=collate_fn
    )

    model = HiCTrainingModule(args)

    checkpoint_cb = callbacks.ModelCheckpoint(
        dirpath=f'{args.save_path}/models',
        filename='hic-{epoch:02d}-{val_loss:.4f}',
        monitor='val_loss',
        save_top_k=args.save_top_k,
        mode='min',
        save_last=True,
    )

    early_stop_cb = callbacks.EarlyStopping(
        monitor='val_loss',
        patience=args.patience,
        mode='min'
    )

    lr_monitor = callbacks.LearningRateMonitor(logging_interval='step')

    csv_logger = pl.loggers.CSVLogger(save_dir=f'{args.save_path}/logs')

    trainer = pl.Trainer(
        accelerator='gpu',
        devices=args.num_gpus,
        max_epochs=args.max_epochs,
        callbacks=[checkpoint_cb, early_stop_cb, lr_monitor],
        logger=csv_logger,
        gradient_clip_val=1.0,
        precision='16-mixed',
        enable_progress_bar=True,
        enable_model_summary=True,
        log_every_n_steps=10,
        val_check_interval=0.25,
    )

    print("\n" + "="*70)
    print("STARTING TRAINING")
    print(f"Batch size: {args.batch_size}")
    print(f"Max epochs: {args.max_epochs}")
    print(f"Training samples: {len(train_ds)}")
    print("="*70)

    trainer.fit(model, train_loader, val_loader)

    print("\n" + "="*70)
    print("TESTING BEST MODEL")
    print("="*70)
    trainer.test(model, test_loader, ckpt_path='best')

    print(f"\n✅ Training complete! Best model saved to {args.save_path}/models")


if __name__ == '__main__':
    main()

  _C._set_float32_matmul_precision(precision)
INFO:lightning_fabric.utilities.seed:Seed set to 2077


LOADING DATASETS
Feature: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/ctcf_log2fc.bw, Norm: None
Feature: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/atac.bw, Norm: log

Available chromosomes: ['chr1', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr2', 'chr20', 'chr21', 'chr22', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chrX']
Using chromosomes for train: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr11', 'chr12', 'chr13', 'chr14', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22']

Loading chr1...
Reading sequence: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/../dna_sequence/chr1.fa.gz
Reading Hi-C: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/hic_matrix/chr1.npz
✓ Loaded chr1: 471 samples
Loading chr2...
Reading sequence: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/../dna_seq

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)


✓ Loaded chr15: 177 samples
Total test samples: 177



INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores



STARTING TRAINING
Batch size: 12
Max epochs: 10
Training samples: 4737


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loading `train_dataloader` to estimate number of stepping batches.
/usr/local/lib/python3.12/dist-packages/pytorch_lightning/utilities/model_summary/model_summary.py:242: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.


Output()

INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/checkpoints/models/hic-epoch=02-val_loss=0.3567.ckpt



TESTING BEST MODEL


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at /content/checkpoints/models/hic-epoch=02-val_loss=0.3567.ckpt


Output()


✅ Training complete! Best model saved to checkpoints/models


In [3]:
# =========================
# ONE-CELL READY-TO-PASTE COLAB SCRIPT
# End-to-End Hi-C Prediction Training with C.Origami Architecture
# SAVES BEST MODEL + LOGS DIRECTLY TO YOUR DRIVE IMR90 FOLDER
# =========================

# ---- MUST BE FIRST (helps fragmentation) ----
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# ---- Mount Drive (so nothing is lost) ----
from google.colab import drive
drive.mount("/content/drive")

# ---- (Optional) installs; comment out if already installed ----
!pip -q install pytorch-lightning==2.3.3 pyBigWig scikit-image scipy pandas

import gc
import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import pytorch_lightning.callbacks as callbacks

# ================================================================
# Model Components (C.Origami-style architecture)
# ================================================================

class ConvBlock(nn.Module):
    def __init__(self, size, stride=2, hidden_in=64, hidden=64):
        super().__init__()
        pad_len = int(size / 2)
        self.scale = nn.Sequential(
            nn.Conv1d(hidden_in, hidden, size, stride, pad_len),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
        )
        self.res = nn.Sequential(
            nn.Conv1d(hidden, hidden, size, padding=pad_len),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Conv1d(hidden, hidden, size, padding=pad_len),
            nn.BatchNorm1d(hidden),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        scaled = self.scale(x)
        identity = scaled
        res_out = self.res(scaled)
        return self.relu(res_out + identity)


class EncoderSplit(nn.Module):
    """Separate encoders for DNA sequence and epigenomic features"""
    def __init__(self, num_epi, output_size=256, filter_size=5, num_blocks=12):
        super().__init__()
        self.filter_size = filter_size

        self.conv_start_seq = nn.Sequential(
            nn.Conv1d(5, 16, 3, 2, 1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
        )

        self.conv_start_epi = nn.Sequential(
            nn.Conv1d(num_epi, 16, 3, 2, 1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
        )

        hiddens = [32, 32, 32, 32, 64, 64, 128, 128, 128, 128, 256, 256]
        hidden_ins = [32, 32, 32, 32, 32, 64, 64, 128, 128, 128, 128, 256]
        hiddens_half = (np.array(hiddens) / 2).astype(int)
        hidden_ins_half = (np.array(hidden_ins) / 2).astype(int)

        self.res_blocks_seq = self._get_res_blocks(num_blocks, hidden_ins_half, hiddens_half)
        self.res_blocks_epi = self._get_res_blocks(num_blocks, hidden_ins_half, hiddens_half)
        self.conv_end = nn.Conv1d(256, output_size, 1)

    def forward(self, x):
        seq = x[:, :5, :]
        epi = x[:, 5:, :]

        seq = self.res_blocks_seq(self.conv_start_seq(seq))
        epi = self.res_blocks_epi(self.conv_start_epi(epi))

        x = torch.cat([seq, epi], dim=1)
        return self.conv_end(x)

    def _get_res_blocks(self, n, his, hs):
        blocks = []
        for hi, h in zip(his, hs):
            blocks.append(ConvBlock(self.filter_size, hidden_in=hi, hidden=h))
        return nn.Sequential(*blocks)


class PositionalEncoding(nn.Module):
    def __init__(self, hidden, dropout=0.1, max_len=2048):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden, 2) * (-np.log(10000.0) / hidden))
        pe = torch.zeros(max_len, 1, hidden)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            position = torch.arange(seq_len, device=x.device).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, x.size(2), 2, device=x.device) * (-np.log(10000.0) / x.size(2)))
            pe = torch.zeros(seq_len, 1, x.size(2), device=x.device)
            pe[:, 0, 0::2] = torch.sin(position * div_term)
            pe[:, 0, 1::2] = torch.cos(position * div_term)
            x = x + pe.squeeze(1)
        else:
            x = x + self.pe[:seq_len].squeeze(1)
        return self.dropout(x)


class TransformerModule(nn.Module):
    def __init__(self, hidden=256, layers=8):
        super().__init__()
        self.pos_encoder = PositionalEncoding(hidden, dropout=0.1, max_len=2048)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden,
            nhead=8,
            dropout=0.1,
            dim_feedforward=512,
            batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, layers)

    def forward(self, x):
        x = self.pos_encoder(x)
        return self.transformer(x)


class ResBlockDilated(nn.Module):
    def __init__(self, size, hidden=64, dil=2):
        super().__init__()
        pad_len = dil
        self.res = nn.Sequential(
            nn.Conv2d(hidden, hidden, size, padding=pad_len, dilation=dil),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.Conv2d(hidden, hidden, size, padding=pad_len, dilation=dil),
            nn.BatchNorm2d(hidden),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x
        res_out = self.res(x)
        return self.relu(res_out + identity)


class Decoder(nn.Module):
    def __init__(self, in_channel, hidden=256, filter_size=3, num_blocks=5, output_size=256):
        super().__init__()
        self.output_size = output_size
        self.conv_start = nn.Sequential(
            nn.Conv2d(in_channel, hidden, 3, 1, 1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
        )
        self.res_blocks = self._get_res_blocks(num_blocks, hidden)
        self.conv_end = nn.Conv2d(hidden, 1, 1)

    def forward(self, x):
        x = self.conv_start(x)
        x = self.res_blocks(x)
        x = self.conv_end(x).squeeze(1)

        if x.size(-1) != self.output_size or x.size(-2) != self.output_size:
            x = nn.functional.interpolate(
                x.unsqueeze(1),
                size=(self.output_size, self.output_size),
                mode='bilinear',
                align_corners=False
            ).squeeze(1)

        return x

    def _get_res_blocks(self, n, hidden):
        blocks = []
        for i in range(n):
            dilation = 2 ** (i + 1)
            blocks.append(ResBlockDilated(3, hidden=hidden, dil=dilation))
        return nn.Sequential(*blocks)


class ConvTransHiCModel(nn.Module):
    """Full C.Origami-style model: Encoder + Transformer + Decoder"""
    def __init__(self, num_genomic_features=2, mid_hidden=256, output_size=256):
        super().__init__()
        self.encoder = EncoderSplit(num_genomic_features, output_size=mid_hidden, num_blocks=12)
        self.transformer = TransformerModule(hidden=mid_hidden, layers=8)

        self.pre_decoder_conv = nn.Sequential(
            nn.Conv1d(mid_hidden, mid_hidden, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(mid_hidden),
            nn.ReLU(),
            nn.Conv1d(mid_hidden, mid_hidden, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm1d(mid_hidden),
            nn.ReLU(),
        )

        self.decoder = Decoder(mid_hidden * 2, hidden=256, output_size=output_size)

    def forward(self, x):
        x = x.float()
        x = self.encoder(x)
        x = x.transpose(1, 2)
        x = self.transformer(x)
        x = x.transpose(1, 2)
        x = self.pre_decoder_conv(x)
        x = self._diagonalize(x)
        return self.decoder(x)

    def _diagonalize(self, x):
        # IMPORTANT: expand avoids extra VRAM vs repeat, but gives same view behavior
        B, C, L = x.shape
        x_i = x.unsqueeze(3).expand(B, C, L, L)
        x_j = x.unsqueeze(2).expand(B, C, L, L)
        return torch.cat([x_i, x_j], dim=1)


# ================================================================
# Dataset
# ================================================================

import gzip
import pyBigWig as pbw
from skimage.transform import resize

class SequenceFeature:
    def __init__(self, path):
        print(f'Reading sequence: {path}')
        with gzip.open(path, 'r') as f:
            seq = f.read().decode("utf-8")
        seq = seq[seq.find('\n')+1:].replace('\n', '').lower()
        self.seq = seq

    def get(self, start, end):
        seq = self.seq[start:end]
        en_dict = {'a': 0, 't': 1, 'c': 2, 'g': 3, 'n': 4}
        idx = np.array([en_dict.get(ch, 4) for ch in seq], dtype=int)
        onehot = np.zeros((len(seq), 5), dtype=np.float32)
        if len(seq) > 0:
            onehot[np.arange(len(seq)), idx] = 1.0
        return onehot

    def __len__(self):
        return len(self.seq)


class HiCFeature:
    def __init__(self, path):
        print(f'Reading Hi-C: {path}')
        self.hic = dict(np.load(path))

        # robust max diag key (many datasets only have 0..255)
        pos = [int(k) for k in self.hic.keys() if k.lstrip('-').isdigit() and int(k) >= 0]
        self.max_pos_diag = max(pos) if pos else 0
        self.has_neg = any((k.lstrip('-').isdigit() and int(k) < 0) for k in self.hic.keys())

    def get(self, start, window=2097152, res=10000):
        start_bin = int(start / res)
        end_bin = start_bin + int(window / res)
        return self._diag_to_mat(self.hic, start_bin, end_bin)

    def _diag_to_mat(self, ori_load, start, end):
        square_len = end - start
        M = np.zeros((square_len, square_len), dtype=np.float32)

        # fill only diagonals that exist; missing diagonals stay 0
        max_d = min(square_len - 1, self.max_pos_diag)
        for d in range(max_d + 1):
            key = str(d)
            if key not in ori_load:
                continue
            v = ori_load[key][start:start + (square_len - d)]
            n = v.shape[0]
            if n <= 0:
                continue
            i = np.arange(n, dtype=np.int32)
            M[i, i + d] = v
            M[i + d, i] = v

        return M

    def __len__(self):
        # length in bins
        return len(self.hic['0'])


class GenomicFeature:
    """
    Fixes NaNs from log1p on signed tracks:
    - norm=None: signed track (ctcf_log2fc), clip outliers only
    - norm='log': counts-like (ATAC), clamp negatives to 0, then log1p
    """
    def __init__(self, path, norm):
        self.path = path
        self.norm = norm
        print(f'Feature: {path}, Norm: {norm}')

    def get(self, chr_name, start, end):
        with pbw.open(self.path) as bw:
            signals = np.array(bw.values(chr_name, int(start), int(end)), dtype=np.float32)

        signals = np.nan_to_num(signals, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)

        if self.norm == 'log':
            signals = np.clip(signals, 0.0, None)
            signals = np.log1p(signals)
        else:
            # signed track like log2fc: keep sign, clip extremes
            signals = np.clip(signals, -10.0, 10.0)

        return signals.astype(np.float32)

    def length(self, chr_name):
        with pbw.open(self.path) as bw:
            return bw.chroms(chr_name)


class ChromosomeDataset(torch.utils.data.Dataset):
    def __init__(self, celltype_root, chr_name, omit_regions, feature_list, use_aug=True):
        self.res = 10000
        self.sample_bins = 500
        self.stride = 50
        self.image_scale = 256
        self.chr_name = chr_name
        self.use_aug = use_aug

        print(f'Loading {chr_name}...')
        self.seq = SequenceFeature(f'{celltype_root}/../dna_sequence/{chr_name}.fa.gz')
        self.genomic_features = feature_list
        self.mat = HiCFeature(f'{celltype_root}/hic_matrix/{chr_name}.npz')
        self.omit_regions = omit_regions

        self._check_length()
        all_intervals = self._get_intervals()
        self.intervals = self._filter(all_intervals, omit_regions)

    def __len__(self):
        return len(self.intervals)

    def __getitem__(self, idx):
        start, end = self.intervals[idx]
        target_size = int(self.sample_bins * self.res)

        if self.use_aug:
            offset = np.random.randint(0, end - start - target_size + 1)
        else:
            offset = 0
        start, end = start + offset, start + offset + target_size

        seq = self.seq.get(start, end)
        features = [f.get(self.chr_name, start, end) for f in self.genomic_features]
        mat = self.mat.get(start)

        mat = resize(mat, (self.image_scale, self.image_scale), anti_aliasing=True)
        mat = np.log1p(np.clip(mat, 0.0, None)).astype(np.float32)

        if self.use_aug:
            seq = self._gaussian_noise(seq, 0.1)
            features = [self._gaussian_noise(f, 0.1) for f in features]
            seq, features, mat = self._reverse_complement(seq, features, mat)

        return seq, features, mat

    def _gaussian_noise(self, x, std=0.1):
        return x + np.random.randn(*x.shape).astype(np.float32) * std

    def _reverse_complement(self, seq, features, mat, chance=0.5):
        if np.random.rand() < chance:
            seq = np.flip(seq, 0).copy()
            seq = np.concatenate([seq[:, 1:2], seq[:, 0:1], seq[:, 3:4], seq[:, 2:3], seq[:, 4:5]], axis=1)
            features = [np.flip(f, 0).copy() for f in features]
            mat = np.flip(mat, [0, 1]).copy()
        return seq, features, mat

    def _get_intervals(self):
        chr_bins = len(self.seq) / self.res
        n = max(0, int((chr_bins - self.sample_bins) / self.stride))
        starts = np.arange(n).reshape(-1, 1) * self.stride
        bins = np.concatenate([starts, starts + self.sample_bins], axis=1) if n > 0 else np.zeros((0, 2))
        return (bins * self.res).astype(int)

    def _filter(self, intervals, omit_regions):
        if omit_regions is None or len(omit_regions) == 0:
            return intervals.tolist()
        valid = []
        for s, e in intervals:
            if np.sum((s <= omit_regions[:, 1]) & (omit_regions[:, 0] <= e)) == 0:
                valid.append([int(s), int(e)])
        return valid

    def _check_length(self):
        if self.genomic_features:
            assert len(self.seq) == self.genomic_features[0].length(self.chr_name)
        assert abs(len(self.seq) / self.res - len(self.mat)) < 2


# ================================================================
# Evaluation Metrics (SIMPLIFIED FOR SPEED)
# ================================================================

from scipy.stats import pearsonr, spearmanr

def upper_triangle_flatten(mat):
    iu = np.triu_indices_from(mat, k=1)
    return mat[iu]

def compute_global_corr(pred, true):
    p, t = upper_triangle_flatten(pred), upper_triangle_flatten(true)
    mask = np.isfinite(p) & np.isfinite(t)
    if mask.sum() < 3:
        return np.nan, np.nan
    return pearsonr(p[mask], t[mask])[0], spearmanr(p[mask], t[mask])[0]


# ================================================================
# Custom Collate Function
# ================================================================

def collate_fn(batch):
    seqs, features_list, mats = zip(*batch)

    seqs = torch.from_numpy(np.stack(seqs))
    num_features = len(features_list[0])
    features = [np.stack([f[i] for f in features_list]) for i in range(num_features)]
    features = [torch.from_numpy(f) for f in features]
    mats = torch.from_numpy(np.stack(mats))

    return seqs, features, mats


# ================================================================
# PyTorch Lightning Module (SIMPLIFIED FOR SPEED)
# ================================================================

class HiCTrainingModule(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters(ignore=['args'])
        self.args = args
        self.model = ConvTransHiCModel(num_genomic_features=2, mid_hidden=256, output_size=256)

    def forward(self, x):
        return self.model(x)

    def _proc_batch(self, batch):
        seq, features, mat = batch
        features = torch.stack(features, dim=2)
        inputs = torch.cat([seq, features], dim=2)
        inputs = inputs.transpose(1, 2)

        # NaN/Inf guard
        if not torch.isfinite(inputs).all() or not torch.isfinite(mat).all():
            raise RuntimeError("Non-finite values in inputs or targets (NaN/Inf). Check feature transforms.")

        return inputs, mat

    def training_step(self, batch, batch_idx):
        inputs, mat = self._proc_batch(batch)
        outputs = self(inputs)
        loss = nn.functional.mse_loss(outputs, mat)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, mat = self._proc_batch(batch)
        outputs = self(inputs)
        loss = nn.functional.mse_loss(outputs, mat)

        if batch_idx == 0:
            pred = outputs[0].detach().cpu().numpy()
            true = mat[0].detach().cpu().numpy()
            gp, _ = compute_global_corr(pred, true)
            self.log('val_pearson', gp, prog_bar=True)

        self.log('val_loss', loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        inputs, mat = self._proc_batch(batch)
        outputs = self(inputs)
        loss = nn.functional.mse_loss(outputs, mat)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=3e-4,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.1,
            anneal_strategy='cos'
        )
        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}


# ================================================================
# Main Training Script (SAVES DIRECTLY TO DRIVE)
# ================================================================

def load_centrotelo(bed_path):
    if not os.path.exists(bed_path):
        return {}
    import pandas as pd
    df = pd.read_csv(bed_path, sep='\t', names=['chr', 'start', 'end'])
    result = {}
    for chr_name, group in df.groupby('chr'):
        result[chr_name] = group[['start', 'end']].to_numpy(dtype=int)
    return result


def get_available_chromosomes(celltype_root):
    hic_dir = f'{celltype_root}/hic_matrix'
    available = []
    if os.path.exists(hic_dir):
        for f in os.listdir(hic_dir):
            if f.endswith('.npz'):
                available.append(f.replace('.npz', ''))
    return sorted(available)


def get_dataset(args, mode):
    celltype_root = f'{args.data_root}/{args.assembly}/{args.celltype}'

    features = [
        GenomicFeature(f'{celltype_root}/genomic_features/ctcf_log2fc.bw', norm=None),  # signed
        GenomicFeature(f'{celltype_root}/genomic_features/atac.bw', norm='log'),        # nonneg log1p
    ]

    centrotelo = load_centrotelo(f'{celltype_root}/../centrotelo.bed')

    available_chrs = get_available_chromosomes(celltype_root)
    print(f"\nAvailable chromosomes: {available_chrs}")

    if mode == 'train':
        desired = [f'chr{i}' for i in range(1, 23) if i not in [10, 15]]
    elif mode == 'val':
        desired = ['chr10']
    elif mode == 'test':
        desired = ['chr15']
    else:
        raise ValueError(f'Unknown mode: {mode}')

    chr_names = [c for c in desired if c in available_chrs]
    if len(chr_names) == 0:
        raise ValueError(f"No valid chromosomes found for mode '{mode}'")

    print(f"Using chromosomes for {mode}: {chr_names}\n")

    datasets = []
    for chr_name in chr_names:
        omit = centrotelo.get(chr_name, np.zeros((0, 2), dtype=int))
        use_aug = (mode == 'train')
        try:
            ds = ChromosomeDataset(celltype_root, chr_name, omit, features, use_aug=use_aug)
            datasets.append(ds)
            print(f"✓ Loaded {chr_name}: {len(ds)} samples")
        except FileNotFoundError:
            print(f"⚠️  Skipping {chr_name}: File not found")
        except Exception as e:
            print(f"⚠️  Skipping {chr_name}: {str(e)}")

    if len(datasets) == 0:
        raise ValueError(f"No datasets loaded for mode '{mode}'")

    print(f"Total {mode} samples: {sum(len(d) for d in datasets)}\n")
    return torch.utils.data.ConcatDataset(datasets)


def main():
    torch.cuda.empty_cache()
    gc.collect()

    torch.backends.cudnn.benchmark = True
    torch.set_float32_matmul_precision('high')

    class Args:
        data_root = "/content/drive/MyDrive/ML4GEN DATA/data - IMR90"
        assembly = "hg38"
        celltype = "imr90"

        batch_size = 12
        num_workers = 4
        max_epochs = 10
        patience = 5
        num_gpus = 1
        save_top_k = 3

        # SAVE TO DRIVE INSIDE YOUR IMR90 FOLDER
        save_root = "/content/drive/MyDrive/ML4GEN DATA/data - IMR90/checkpoints"

    args = Args()
    pl.seed_everything(2077, workers=True)

    model_dir = f"{args.save_root}/models"
    log_dir = f"{args.save_root}/logs"
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    print("="*70)
    print("LOADING DATASETS")
    print("="*70)

    train_ds = get_dataset(args, 'train')
    val_ds = get_dataset(args, 'val')
    test_ds = get_dataset(args, 'test')

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        persistent_workers=True,
        collate_fn=collate_fn,
        prefetch_factor=2,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        persistent_workers=True,
        collate_fn=collate_fn,
        prefetch_factor=2,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    model = HiCTrainingModule(args)

    checkpoint_cb = callbacks.ModelCheckpoint(
        dirpath=model_dir,
        filename='hic-{epoch:02d}-{val_loss:.4f}',
        monitor='val_loss',
        save_top_k=args.save_top_k,
        mode='min',
        save_last=True,
    )

    early_stop_cb = callbacks.EarlyStopping(
        monitor='val_loss',
        patience=args.patience,
        mode='min'
    )

    lr_monitor = callbacks.LearningRateMonitor(logging_interval='step')

    csv_logger = pl.loggers.CSVLogger(save_dir=log_dir, name="hic_training")

    trainer = pl.Trainer(
        accelerator='gpu',
        devices=args.num_gpus,
        max_epochs=args.max_epochs,
        callbacks=[checkpoint_cb, early_stop_cb, lr_monitor],
        logger=csv_logger,
        gradient_clip_val=1.0,
        precision='16-mixed',
        enable_progress_bar=True,
        enable_model_summary=True,
        log_every_n_steps=10,
        val_check_interval=0.25,
    )

    print("\n" + "="*70)
    print("STARTING TRAINING")
    print(f"Batch size: {args.batch_size}")
    print(f"Max epochs: {args.max_epochs}")
    print(f"Training samples: {len(train_ds)}")
    print("Saving best checkpoints to:")
    print(model_dir)
    print("="*70)

    trainer.fit(model, train_loader, val_loader)

    print("\n" + "="*70)
    print("TESTING BEST MODEL")
    print("="*70)
    trainer.test(model, test_loader, ckpt_path='best')

    print(f"\n✅ Training complete! Best model saved to:\n{model_dir}")
    print(f"✅ Logs saved to:\n{log_dir}")


if __name__ == '__main__':
    main()


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m812.3/812.3 kB[0m [31m53.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m187.1/187.1 kB[0m [31m18.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m69.7 MB/s[0m eta [36m0:00:00[0m
[?25h

  _C._set_float32_matmul_precision(precision)
INFO:lightning_fabric.utilities.seed:Seed set to 2077


LOADING DATASETS
Feature: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/ctcf_log2fc.bw, Norm: None
Feature: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/atac.bw, Norm: log

Available chromosomes: ['chr1', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr2', 'chr20', 'chr21', 'chr22', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chrX']
Using chromosomes for train: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr11', 'chr12', 'chr13', 'chr14', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22']

Loading chr1...
Reading sequence: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/../dna_sequence/chr1.fa.gz
Reading Hi-C: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/hic_matrix/chr1.npz
✓ Loaded chr1: 471 samples
Loading chr2...
Reading sequence: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/../dna_seq

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
/usr/local/lib/python3.12/dist-packages/pytorch_lightning/plugins/precision/amp.py:52: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs



STARTING TRAINING
Batch size: 12
Max epochs: 10
Training samples: 4737
Saving best checkpoints to:
/content/drive/MyDrive/ML4GEN DATA/data - IMR90/checkpoints/models


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loading `train_dataloader` to estimate number of stepping batches.
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type              | Params | Mode 
----------------------------------------------------
0 | model | ConvTransHiCModel | 13.2 M | train
----------------------------------------------------
13.2 M    Trainable params
0         Non-trainable params
13.2 M    Total params
52.940    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/drive/.shortcut-targets-by-id/126vLgWy4wFKfcpk6pUYC6UiOP4G1DHi2/ML4GEN DATA/data - IMR90/checkpoints/models/hic-epoch=04-val_loss=0.3308.ckpt



TESTING BEST MODEL


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at /content/drive/.shortcut-targets-by-id/126vLgWy4wFKfcpk6pUYC6UiOP4G1DHi2/ML4GEN DATA/data - IMR90/checkpoints/models/hic-epoch=04-val_loss=0.3308.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]


✅ Training complete! Best model saved to:
/content/drive/MyDrive/ML4GEN DATA/data - IMR90/checkpoints/models
✅ Logs saved to:
/content/drive/MyDrive/ML4GEN DATA/data - IMR90/checkpoints/logs


In [None]:
import os, sys, io, contextlib, datetime

# ==== CONFIG: where to save the log in your Drive ====
LOG_DIR = "/content/drive/MyDrive/ML4GEN DATA/hic_training_logs"
os.makedirs(LOG_DIR, exist_ok=True)

run_name = datetime.datetime.now().strftime("imr90_run_%Y%m%d_%H%M%S")
log_path = os.path.join(LOG_DIR, f"{run_name}.log")

# ==== Tee class: writes to both notebook + file ====
class Tee(io.TextIOBase):
    def __init__(self, *streams):
        self.streams = streams

    def write(self, s):
        for st in self.streams:
            st.write(s)
            st.flush()
        return len(s)

    def flush(self):
        for st in self.streams:
            st.flush()

print(f"👉 Logging this run to:\n{log_path}\n")

# ==== Run training and capture output ====
with open(log_path, "w") as f:
    tee = Tee(sys.stdout, f)
    with contextlib.redirect_stdout(tee), contextlib.redirect_stderr(tee):
        main()  # calls your big training script

print(f"\n✅ Done! Full console output saved to:\n{log_path}")
