In [1]:
# Colab-only import: `drive` attaches Google Drive to the runtime filesystem.
from google.colab import drive
import os

# Mount Drive at /content/drive. The dataset and checkpoint paths configured
# later in this notebook point below /content/drive/MyDrive.
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
# This variable is set before the first `import torch` in the notebook (torch is
# imported in a later cell), so it is visible when torch initialises.
# NOTE(review): PYTORCH_JIT=0 is understood to disable TorchScript compilation —
# confirm this is still required for the model defined below.
os.environ['PYTORCH_JIT'] = '0'

In [3]:
!pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.6.0-py3-none-any.whl.metadata (21 kB)
Collecting torchmetrics>0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading pytorch_lightning-2.6.0-py3-none-any.whl (849 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m849.5/849.5 kB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m37.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning-utilities, torchmetrics, pytorch_lightning
Successfully installed lightning-utilities-0.15.2 pytorch_lightning-2.6.0 torchmetrics-1.8.2


In [4]:
import gc
import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import pytorch_lightning.callbacks as callbacks


In [5]:
!pip install pyBigWig

Collecting pyBigWig
  Downloading pyBigWig-0.3.24-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (19 kB)
Downloading pyBigWig-0.3.24-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (187 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/187.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m187.1/187.1 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pyBigWig
Successfully installed pyBigWig-0.3.24


In [None]:
"""
End-to-End Hi-C Prediction Training with C.Origami Architecture
SPEED OPTIMIZED VERSION FOR A100
WITH DNA METHYLATION DATA
"""
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import pytorch_lightning.callbacks as callbacks
import gc

# ================================================================
# Model Components (C.Origami-style architecture)
# ================================================================

class ConvBlock(nn.Module):
    """Strided 1-D convolution stage followed by a residual refinement.

    ``scale`` changes the channel count from ``hidden_in`` to ``hidden`` and
    applies ``stride``; ``res`` is a two-convolution residual branch that keeps
    the shape produced by ``scale``.

    Args:
        size: Kernel size used by all three convolutions.
        stride: Stride of the first (scaling) convolution.
        hidden_in: Number of input channels.
        hidden: Number of output channels.
    """

    def __init__(self, size, stride=2, hidden_in=64, hidden=64):
        super().__init__()
        half_kernel = int(size / 2)

        # Scaling stem: channel change + striding.
        stem_layers = [
            nn.Conv1d(hidden_in, hidden, size, stride, half_kernel),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
        ]
        self.scale = nn.Sequential(*stem_layers)

        # Residual branch: the final ReLU is applied after the skip addition.
        branch_layers = [
            nn.Conv1d(hidden, hidden, size, padding=half_kernel),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Conv1d(hidden, hidden, size, padding=half_kernel),
            nn.BatchNorm1d(hidden),
        ]
        self.res = nn.Sequential(*branch_layers)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(res(scale(x)) + scale(x))``."""
        shortcut = self.scale(x)
        return self.relu(self.res(shortcut) + shortcut)


class EncoderSplit(nn.Module):
    """Separate encoders for DNA sequence and epigenomic features.

    The input tensor is split along the channel axis: the first 5 channels go
    through the sequence branch and the remaining ``num_epi`` channels through
    the epigenomic branch. Each branch is a strided start convolution followed
    by ``num_blocks`` ``ConvBlock`` stages. The two branch outputs are
    concatenated on the channel axis and projected to ``output_size`` channels
    by a 1x1 convolution.

    Args:
        num_epi: Number of epigenomic input channels (everything after the
            first 5 channels of the input).
        output_size: Number of output channels of the final 1x1 convolution.
        filter_size: Kernel size passed to every ``ConvBlock``.
        num_blocks: Number of ``ConvBlock`` stages per branch (1..12). The
            default of 12 reproduces the full channel schedule.

    Raises:
        ValueError: If ``num_blocks`` is outside the supported range.
    """
    def __init__(self, num_epi, output_size=256, filter_size=5, num_blocks=12):
        super().__init__()
        self.filter_size = filter_size

        # Channel schedule of the *concatenated* encoder; each branch uses half.
        hiddens = [32, 32, 32, 32, 64, 64, 128, 128, 128, 128, 256, 256]
        hidden_ins = [32, 32, 32, 32, 32, 64, 64, 128, 128, 128, 128, 256]
        if not 1 <= num_blocks <= len(hiddens):
            raise ValueError(
                f'num_blocks must be between 1 and {len(hiddens)}, got {num_blocks}'
            )

        self.conv_start_seq = nn.Sequential(
            nn.Conv1d(5, 16, 3, 2, 1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
        )

        self.conv_start_epi = nn.Sequential(
            nn.Conv1d(num_epi, 16, 3, 2, 1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
        )

        hiddens_half = (np.array(hiddens) / 2).astype(int)
        hidden_ins_half = (np.array(hidden_ins) / 2).astype(int)

        self.res_blocks_seq = self._get_res_blocks(num_blocks, hidden_ins_half, hiddens_half)
        self.res_blocks_epi = self._get_res_blocks(num_blocks, hidden_ins_half, hiddens_half)

        # After concatenation the channel count is twice the last per-branch
        # width, i.e. the un-halved schedule entry (256 when num_blocks == 12).
        self.conv_end = nn.Conv1d(int(hiddens[num_blocks - 1]), output_size, 1)

    def forward(self, x):
        """Encode ``x`` whose channel axis (dim 1) is ``5 + num_epi`` wide."""
        seq = x[:, :5, :]
        epi = x[:, 5:, :]

        seq = self.res_blocks_seq(self.conv_start_seq(seq))
        epi = self.res_blocks_epi(self.conv_start_epi(epi))

        x = torch.cat([seq, epi], dim=1)
        return self.conv_end(x)

    def _get_res_blocks(self, n, his, hs):
        """Build the first ``n`` ConvBlocks of the (input, output) channel schedule."""
        blocks = []
        for hi, h in zip(his[:n], hs[:n]):
            blocks.append(ConvBlock(self.filter_size, hidden_in=int(hi), hidden=int(h)))
        return nn.Sequential(*blocks)


class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding followed by dropout.

    A table for ``max_len`` positions is precomputed and stored in the
    non-trainable buffer ``pe`` with shape ``(max_len, 1, hidden)``. Inputs are
    indexed as ``(batch, position, channel)``; sequences longer than the table
    get an encoding computed on the fly on the input's device.
    """

    def __init__(self, hidden, dropout=0.1, max_len=2048):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.register_buffer('pe', self._sinusoid_table(max_len, hidden))

    @staticmethod
    def _sinusoid_table(length, width, device=None):
        """Return a ``(length, 1, width)`` table: sin on even channels, cos on odd."""
        steps = torch.arange(length, device=device).unsqueeze(1)
        freqs = torch.exp(
            torch.arange(0, width, 2, device=device) * (-np.log(10000.0) / width)
        )
        table = torch.zeros(length, 1, width, device=device)
        table[:, 0, 0::2] = torch.sin(steps * freqs)
        table[:, 0, 1::2] = torch.cos(steps * freqs)
        return table

    def forward(self, x):
        """Add the positional table to ``x`` and apply dropout."""
        n_positions = x.size(1)
        if n_positions <= self.pe.size(0):
            table = self.pe[:n_positions]
        else:
            # Input is longer than the cached table: build a matching one ad hoc.
            table = self._sinusoid_table(n_positions, x.size(2), x.device)
        return self.dropout(x + table.squeeze(1))


class TransformerModule(nn.Module):
    """Positional encoding followed by a stack of pre-norm Transformer encoder layers.

    Args:
        hidden: Model width (``d_model``) of every encoder layer.
        layers: Number of stacked encoder layers.
    """

    def __init__(self, hidden=256, layers=8):
        super().__init__()
        self.pos_encoder = PositionalEncoding(hidden, dropout=0.1, max_len=2048)

        # Batch-first, pre-norm layer configuration shared by the whole stack.
        layer_config = dict(
            d_model=hidden,
            nhead=8,
            dropout=0.1,
            dim_feedforward=512,
            batch_first=True,
            norm_first=True,
        )
        template_layer = nn.TransformerEncoderLayer(**layer_config)
        self.transformer = nn.TransformerEncoder(template_layer, layers)

    def forward(self, x):
        """Apply positional encoding, then the encoder stack."""
        encoded = self.pos_encoder(x)
        return self.transformer(encoded)


class ResBlockDilated(nn.Module):
    """2-D residual block made of two dilated convolutions.

    The padding is derived from both the kernel size and the dilation so each
    convolution preserves the spatial shape, which the residual addition
    requires. For ``size == 3`` this equals ``dil``.

    Args:
        size: Kernel size of both convolutions.
        hidden: Number of input and output channels.
        dil: Dilation of both convolutions.

    Raises:
        ValueError: If ``dil * (size - 1)`` is odd, in which case no symmetric
            padding can keep the spatial shape unchanged.
    """
    def __init__(self, size, hidden=64, dil=2):
        super().__init__()
        # A dilated kernel spans dil * (size - 1) + 1 positions; symmetric
        # "same" padding is half of the extra span.
        extra_span = dil * (size - 1)
        if extra_span % 2 != 0:
            raise ValueError(
                f'Cannot preserve shape with size={size} and dil={dil}: '
                'dil * (size - 1) must be even'
            )
        pad_len = extra_span // 2
        self.res = nn.Sequential(
            nn.Conv2d(hidden, hidden, size, padding=pad_len, dilation=dil),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.Conv2d(hidden, hidden, size, padding=pad_len, dilation=dil),
            nn.BatchNorm2d(hidden),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(res(x) + x)``."""
        identity = x
        res_out = self.res(x)
        return self.relu(res_out + identity)


class Decoder(nn.Module):
    """2-D convolutional decoder that turns a pairwise feature map into one matrix per sample.

    Pipeline: 3x3 start convolution -> ``num_blocks`` dilated residual blocks
    (dilation 2, 4, 8, ...) -> 1x1 convolution down to a single channel. If the
    spatial size differs from ``output_size`` the result is bilinearly resized
    to ``(output_size, output_size)``.

    Args:
        in_channel: Channels of the incoming feature map.
        hidden: Channel width used inside the decoder.
        filter_size: Accepted for interface compatibility; not read by this class.
        num_blocks: Number of dilated residual blocks.
        output_size: Side length of the returned square matrix.
    """

    def __init__(self, in_channel, hidden=256, filter_size=3, num_blocks=5, output_size=256):
        super().__init__()
        self.output_size = output_size

        stem = [
            nn.Conv2d(in_channel, hidden, 3, 1, 1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
        ]
        self.conv_start = nn.Sequential(*stem)

        self.res_blocks = self._get_res_blocks(num_blocks, hidden)
        self.conv_end = nn.Conv2d(hidden, 1, 1)

    def forward(self, x):
        """Decode ``x`` into a ``(batch, output_size, output_size)`` tensor."""
        heatmap = self.conv_end(self.res_blocks(self.conv_start(x))).squeeze(1)

        side = self.output_size
        if heatmap.size(-1) == side and heatmap.size(-2) == side:
            return heatmap

        # Spatial size differs from the requested output: resize bilinearly.
        resized = nn.functional.interpolate(
            heatmap.unsqueeze(1),
            size=(side, side),
            mode='bilinear',
            align_corners=False,
        )
        return resized.squeeze(1)

    def _get_res_blocks(self, n, hidden):
        """Stack ``n`` dilated residual blocks with dilation doubling per block."""
        stages = [
            ResBlockDilated(3, hidden=hidden, dil=2 ** (level + 1))
            for level in range(n)
        ]
        return nn.Sequential(*stages)


class ConvTransHiCModel(nn.Module):
    """Full C.Origami-style model: Encoder + Transformer + Decoder.

    The forward pass runs the split 1-D encoder, a Transformer over the encoded
    positions, two further stride-2 convolutions, a pairwise ("diagonalize")
    expansion to 2-D, and finally the 2-D decoder.

    Args:
        num_genomic_features: Number of epigenomic channels following the
            5 sequence channels in the input.
        mid_hidden: Channel width of the encoder output / Transformer.
        output_size: Side length of the predicted square matrix.
    """

    def __init__(self, num_genomic_features=4, mid_hidden=256, output_size=256):
        super().__init__()
        self.encoder = EncoderSplit(num_genomic_features, output_size=mid_hidden, num_blocks=12)
        self.transformer = TransformerModule(hidden=mid_hidden, layers=8)

        # Two extra stride-2 stages shorten the sequence axis before the
        # quadratic pairwise expansion.
        reducers = []
        for _ in range(2):
            reducers.append(nn.Conv1d(mid_hidden, mid_hidden, kernel_size=3, stride=2, padding=1))
            reducers.append(nn.BatchNorm1d(mid_hidden))
            reducers.append(nn.ReLU())
        self.pre_decoder_conv = nn.Sequential(*reducers)

        self.decoder = Decoder(mid_hidden * 2, hidden=256, output_size=output_size)

    def forward(self, x):
        """Predict a matrix from a channels-first input ``(batch, channels, length)``."""
        encoded = self.encoder(x.float())
        # The Transformer is batch-first over positions, so swap to
        # (batch, length, channels) and back again afterwards.
        attended = self.transformer(encoded.transpose(1, 2)).transpose(1, 2)
        reduced = self.pre_decoder_conv(attended)
        return self.decoder(self._diagonalize(reduced))

    def _diagonalize(self, x):
        """Expand ``(B, C, L)`` to ``(B, 2C, L, L)`` by pairing every position i with every j.

        The first ``C`` channels hold the features of row position ``i`` and the
        last ``C`` channels those of column position ``j``.
        """
        length = x.size(2)
        by_row = x.unsqueeze(3).expand(-1, -1, -1, length)
        by_col = x.unsqueeze(2).expand(-1, -1, length, -1)
        return torch.cat([by_row, by_col], dim=1)


# ================================================================
# Dataset
# ================================================================

import gzip
import pyBigWig as pbw
from skimage.transform import resize


class SequenceFeature:
    """One-hot access to a chromosome sequence loaded from a gzipped FASTA file.

    The file is read once; everything up to the first newline is dropped as the
    header, remaining newlines are removed and the sequence is lower-cased.
    """

    # Byte value -> one-hot column. Anything that is not a/t/c/g (including
    # 'n') falls into column 4.
    _BASE_LUT = np.full(256, 4, dtype=np.intp)
    _BASE_LUT[np.frombuffer(b'atcg', dtype=np.uint8)] = [0, 1, 2, 3]

    def __init__(self, path):
        print(f'Reading sequence: {path}')
        with gzip.open(path, 'r') as f:
            seq = f.read().decode("utf-8")
            seq = seq[seq.find('\n')+1:].replace('\n', '').lower()
        self.seq = seq

    def get(self, start, end):
        """Return the one-hot encoding of ``seq[start:end]``.

        Returns:
            ``float32`` array of shape ``(len(slice), 5)`` with columns
            a, t, c, g, other.
        """
        seq = self.seq[start:end]
        # Encode to one byte per character so the table lookup is a single
        # vectorised indexing operation instead of a per-base Python loop.
        # Characters outside latin-1 become '?' and therefore map to column 4.
        codes = np.frombuffer(seq.encode('latin-1', errors='replace'), dtype=np.uint8)
        idx = self._BASE_LUT[codes]
        onehot = np.zeros((len(seq), 5), dtype=np.float32)
        if len(seq) > 0:
            onehot[np.arange(len(seq)), idx] = 1.0
        return onehot

    def __len__(self):
        """Length of the stored sequence in bases."""
        return len(self.seq)


class HiCFeature:
    """Hi-C contact matrix stored as per-diagonal arrays in an ``.npz`` file.

    The archive is read as a dict keyed by the diagonal offset as a string
    (``'0'``, ``'1'``, ``'-1'``, ...). Entry ``i`` of diagonal ``k >= 0`` is
    used as matrix element ``(i, i + k)`` and entry ``i`` of diagonal ``-k`` as
    element ``(i + k, i)``.
    """
    def __init__(self, path):
        print(f'Reading Hi-C: {path}')
        self.hic = dict(np.load(path))

    def get(self, start, window=2097152, res=10000):
        """Return the square sub-matrix covering ``window`` bp from ``start``.

        Args:
            start: Start coordinate in bp.
            window: Window length in bp.
            res: Bin size in bp.

        Returns:
            ``float32`` array of shape ``(int(window / res), int(window / res))``.
        """
        start_bin = int(start / res)
        end_bin = start_bin + int(window / res)
        return self._diag_to_mat(self.hic, start_bin, end_bin)

    def _diag_to_mat(self, ori_load, start, end):
        """Assemble bins ``[start, end)`` into a dense matrix, one diagonal at a time."""
        square_len = end - start
        mat = np.zeros((square_len, square_len), dtype=np.float32)
        for d in range(square_len):
            span = square_len - d
            idx = np.arange(span)
            # Diagonal d contributes `span` entries starting at bin `start`;
            # write the upper (+d) and lower (-d) diagonal in one shot each.
            mat[idx, idx + d] = ori_load[str(d)][start:start + span]
            mat[idx + d, idx] = ori_load[str(-d)][start:start + span]
        return mat

    def __len__(self):
        """Number of bins on the main diagonal."""
        return len(self.hic['0'])


class GenomicFeature:
    """Reader for a single bigWig signal track.

    The file is opened on every call rather than kept open on the instance.

    Args:
        path: Path to the bigWig file.
        norm: ``'log'`` to apply ``log(x + 1)`` to the signal; any other value
            leaves the signal unchanged.
    """
    def __init__(self, path, norm):
        self.path = path
        self.norm = norm
        print(f'Feature: {path}, Norm: {norm}')

    def get(self, chr_name, start, end):
        """Return the per-base signal for ``chr_name[start:end]`` as ``float32``.

        NaN values are replaced by 0 before the optional log normalisation.
        """
        with pbw.open(self.path) as bw:
            signals = np.array(bw.values(chr_name, int(start), int(end)))
            # `nan` must be passed by keyword: the second positional parameter
            # of np.nan_to_num is `copy`, not the replacement value.
            signals = np.nan_to_num(signals, nan=0.0)
            if self.norm == 'log':
                signals = np.log(signals + 1)
            return signals.astype(np.float32)

    def length(self, chr_name):
        """Return the length of ``chr_name`` as reported by the bigWig header."""
        with pbw.open(self.path) as bw:
            return bw.chroms(chr_name)


class ChromosomeDataset(torch.utils.data.Dataset):
    """Sliding-window samples (sequence, feature tracks, Hi-C matrix) for one chromosome.

    Windows are ``sample_bins`` bins of ``res`` bp, spaced ``stride`` bins
    apart; windows overlapping any row of ``omit_regions`` are dropped.

    Args:
        celltype_root: Directory holding ``hic_matrix/``; the DNA sequence is
            read from ``<celltype_root>/../dna_sequence/``.
        chr_name: Chromosome name, e.g. ``'chr1'``.
        omit_regions: ``(n, 2)`` array of ``[start, end]`` regions to exclude,
            or ``None`` / empty for no exclusion.
        feature_list: Objects providing ``get(chr_name, start, end)`` and
            ``length(chr_name)``.
        use_aug: Enable Gaussian-noise and reverse-complement augmentation.
    """
    def __init__(self, celltype_root, chr_name, omit_regions, feature_list, use_aug=True):
        self.res = 10000
        self.sample_bins = 500
        self.stride = 50
        self.image_scale = 256
        self.chr_name = chr_name
        self.use_aug = use_aug

        print(f'Loading {chr_name}...')
        self.seq = SequenceFeature(f'{celltype_root}/../dna_sequence/{chr_name}.fa.gz')
        self.genomic_features = feature_list
        self.mat = HiCFeature(f'{celltype_root}/hic_matrix/{chr_name}.npz')
        self.omit_regions = omit_regions

        self._check_length()
        all_intervals = self._get_intervals()
        self.intervals = self._filter(all_intervals, omit_regions)

    def __len__(self):
        return len(self.intervals)

    def __getitem__(self, idx):
        """Return ``(seq, features, mat)`` for window ``idx``.

        ``seq`` is the one-hot sequence, ``features`` a list with one signal
        array per track and ``mat`` the ``log(x + 1)`` Hi-C matrix resized to
        ``(image_scale, image_scale)``.
        """
        start, end = self.intervals[idx]
        target_size = int(self.sample_bins * self.res)

        if self.use_aug:
            # Random shift inside the interval; with interval length equal to
            # target_size the only possible offset is 0.
            offset = np.random.randint(0, end - start - target_size + 1)
        else:
            offset = 0

        start, end = start + offset, start + offset + target_size

        seq = self.seq.get(start, end)
        features = [f.get(self.chr_name, start, end) for f in self.genomic_features]
        # The target must span exactly the same genomic window as the inputs.
        # Relying on HiCFeature.get's default window (2,097,152 bp) would pair a
        # sample_bins * res input with a shorter target and would also make the
        # reverse-complement flip below spatially inconsistent.
        mat = self.mat.get(start, window=target_size, res=self.res)
        mat = resize(mat, (self.image_scale, self.image_scale), anti_aliasing=True)
        mat = np.log(mat + 1).astype(np.float32)

        if self.use_aug:
            seq = self._gaussian_noise(seq, 0.1)
            features = [self._gaussian_noise(f, 0.1) for f in features]
            seq, features, mat = self._reverse_complement(seq, features, mat)

        return seq, features, mat

    def _gaussian_noise(self, x, std=0.1):
        """Add zero-mean Gaussian noise with standard deviation ``std``."""
        return x + np.random.randn(*x.shape).astype(np.float32) * std

    def _reverse_complement(self, seq, features, mat, chance=0.5):
        """With probability ``chance`` reverse the window and complement the bases."""
        if np.random.rand() < chance:
            seq = np.flip(seq, 0).copy()
            # Swap the a<->t and c<->g columns; column 4 stays in place.
            seq = np.concatenate([seq[:, 1:2], seq[:, 0:1], seq[:, 3:4], seq[:, 2:3], seq[:, 4:5]], axis=1)
            features = [np.flip(f, 0).copy() for f in features]
            mat = np.flip(mat, [0, 1]).copy()
        return seq, features, mat

    def _get_intervals(self):
        """Return an ``(n, 2)`` int array of ``[start, end]`` window coordinates in bp."""
        chr_bins = len(self.seq) / self.res
        n = max(0, int((chr_bins - self.sample_bins) / self.stride))
        starts = np.arange(n).reshape(-1, 1) * self.stride
        bins = np.concatenate([starts, starts + self.sample_bins], axis=1) if n > 0 else np.zeros((0, 2))
        return (bins * self.res).astype(int)

    def _filter(self, intervals, omit_regions):
        """Drop every interval that overlaps (inclusive bounds) an omitted region."""
        if omit_regions is None or len(omit_regions) == 0:
            return intervals.tolist()

        valid = []
        for s, e in intervals:
            if np.sum((s <= omit_regions[:, 1]) & (omit_regions[:, 0] <= e)) == 0:
                valid.append([int(s), int(e)])
        return valid

    def _check_length(self):
        """Sanity-check that sequence, first feature track and Hi-C lengths agree."""
        if self.genomic_features:
            assert len(self.seq) == self.genomic_features[0].length(self.chr_name)
            assert abs(len(self.seq) / self.res - len(self.mat)) < 2


# ================================================================
# Evaluation Metrics (SIMPLIFIED FOR SPEED)
# ================================================================

from scipy.stats import pearsonr, spearmanr
from scipy.ndimage import uniform_filter1d


def upper_triangle_flatten(mat):
    """Return the entries strictly above the main diagonal of ``mat`` as a 1-D array."""
    rows, cols = np.triu_indices_from(mat, k=1)
    return mat[rows, cols]


def compute_global_corr(pred, true):
    """Pearson and Spearman correlation over the strict upper triangles of two matrices.

    Pairs where either value is non-finite are ignored. Returns ``(nan, nan)``
    when fewer than three finite pairs remain.
    """
    pred_vals = upper_triangle_flatten(pred)
    true_vals = upper_triangle_flatten(true)

    finite = np.isfinite(pred_vals) & np.isfinite(true_vals)
    if finite.sum() < 3:
        return np.nan, np.nan

    pearson = pearsonr(pred_vals[finite], true_vals[finite])[0]
    spearman = spearmanr(pred_vals[finite], true_vals[finite])[0]
    return pearson, spearman


# ================================================================
# Custom Collate Function
# ================================================================

def collate_fn(batch):
    """Collate ``(seq, features, mat)`` samples into batched tensors.

    Returns:
        ``(seqs, features, mats)`` where ``seqs`` and ``mats`` are tensors
        stacked along a new leading batch axis and ``features`` is a list with
        one stacked tensor per feature track.
    """
    seqs, features_list, mats = zip(*batch)
    seq_batch = torch.from_numpy(np.stack(seqs))

    # Regroup from per-sample lists of tracks to per-track batches.
    track_batches = []
    for track in range(len(features_list[0])):
        stacked = np.stack([sample_tracks[track] for sample_tracks in features_list])
        track_batches.append(torch.from_numpy(stacked))

    mat_batch = torch.from_numpy(np.stack(mats))
    return seq_batch, track_batches, mat_batch


# ================================================================
# PyTorch Lightning Module (SIMPLIFIED FOR SPEED)
# ================================================================

class HiCTrainingModule(pl.LightningModule):
    """Lightning wrapper that trains ``ConvTransHiCModel`` with an MSE loss.

    Logged metrics: ``train_loss``, ``val_loss``, ``val_pearson`` (first sample
    of the first validation batch only) and ``test_loss``.

    Args:
        args: Run configuration object; stored on the module but excluded from
            the saved hyperparameters.
    """
    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters(ignore=['args'])
        self.args = args
        # UPDATED: num_genomic_features=4 (CTCF, ATAC, DNAmeth minus, DNAmeth plus)
        self.model = ConvTransHiCModel(num_genomic_features=4, mid_hidden=256, output_size=256)

    def forward(self, x):
        return self.model(x)

    def _proc_batch(self, batch):
        """Turn a collated batch into ``(inputs, target)``.

        Feature tracks are stacked as extra channels after the sequence
        channels and the result is transposed to channels-first.
        """
        seq, features, mat = batch
        features = torch.stack(features, dim=2)
        inputs = torch.cat([seq, features], dim=2)
        inputs = inputs.transpose(1, 2)
        return inputs, mat

    def training_step(self, batch, batch_idx):
        inputs, mat = self._proc_batch(batch)
        outputs = self(inputs)
        loss = nn.functional.mse_loss(outputs, mat)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, mat = self._proc_batch(batch)
        outputs = self(inputs)
        loss = nn.functional.mse_loss(outputs, mat)

        # SPEED: Only compute correlation on subset of validation
        if batch_idx == 0:  # Only first batch
            # Cast to float32 first: under mixed precision the tensors may be
            # half precision, which degrades the correlation estimate.
            pred = outputs[0].detach().float().cpu().numpy()
            true = mat[0].detach().float().cpu().numpy()
            gp, gs = compute_global_corr(pred, true)
            self.log('val_pearson', gp, prog_bar=True)

        self.log('val_loss', loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        inputs, mat = self._proc_batch(batch)
        outputs = self(inputs)
        loss = nn.functional.mse_loss(outputs, mat)
        # Without logging, trainer.test() has no metric to aggregate or report.
        self.log('test_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """AdamW with a one-cycle cosine schedule stepped once per optimizer step."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=3e-4, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=3e-4,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.1,
            anneal_strategy='cos'
        )
        return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}


# ================================================================
# Main Training Script
# ================================================================

def load_centrotelo(bed_path):
    """Load regions to exclude from a tab-separated BED file.

    Only the first three columns (chromosome, start, end) are read, so files
    with additional BED columns (name, score, ...) are parsed correctly.

    Args:
        bed_path: Path to the BED file.

    Returns:
        Dict mapping chromosome name to an ``(n, 2)`` int array of
        ``[start, end]`` rows. Empty dict if the file does not exist.
    """
    if not os.path.exists(bed_path):
        return {}

    import pandas as pd
    # usecols pins the three named fields to the first three file columns;
    # without it pandas turns surplus leading columns into the index.
    df = pd.read_csv(
        bed_path,
        sep='\t',
        header=None,
        names=['chr', 'start', 'end'],
        usecols=[0, 1, 2],
    )
    result = {}
    for chr_name, group in df.groupby('chr'):
        result[chr_name] = group[['start', 'end']].to_numpy(dtype=int)
    return result


def get_available_chromosomes(celltype_root):
    """List chromosomes that have a Hi-C archive under ``<celltype_root>/hic_matrix``.

    Returns:
        Sorted list of the ``.npz`` file names with ``'.npz'`` removed; empty
        if the directory does not exist.
    """
    hic_dir = f'{celltype_root}/hic_matrix'
    if not os.path.exists(hic_dir):
        return []

    names = [
        entry.replace('.npz', '')
        for entry in os.listdir(hic_dir)
        if entry.endswith('.npz')
    ]
    names.sort()
    return names


def get_dataset(args, mode):
    """Build the concatenated per-chromosome dataset for one split.

    Split definition: ``'val'`` uses chr10, ``'test'`` uses chr15 and
    ``'train'`` uses chr1-chr22 without those two. Only chromosomes that have a
    Hi-C archive on disk are used, and a chromosome whose dataset fails to
    load is reported and skipped.

    Args:
        args: Object with ``data_root``, ``assembly`` and ``celltype`` attributes.
        mode: ``'train'``, ``'val'`` or ``'test'``.

    Returns:
        ``torch.utils.data.ConcatDataset`` over the loaded chromosomes.

    Raises:
        ValueError: For an unknown ``mode``, when none of the split's
            chromosomes is available, or when no dataset could be loaded.
    """
    celltype_root = f'{args.data_root}/{args.assembly}/{args.celltype}'
    feature_dir = f'{celltype_root}/genomic_features'

    # UPDATED: Added DNA methylation features (minus and plus strand)
    # (file name, normalisation) in model channel order.
    track_specs = (
        ('ctcf_log2fc.bw', None),
        ('atac.bw', 'log'),
        ('dnameth-minusstrand.bigWig', None),
        ('dnameth-plusstrand.bigWig', None),
    )
    features = [
        GenomicFeature(f'{feature_dir}/{file_name}', norm=norm)
        for file_name, norm in track_specs
    ]

    centrotelo = load_centrotelo(f'{celltype_root}/../centrotelo.bed')
    available_chrs = get_available_chromosomes(celltype_root)

    print(f"\nAvailable chromosomes: {available_chrs}")

    held_out = {'val': ['chr10'], 'test': ['chr15']}
    if mode == 'train':
        desired = [f'chr{i}' for i in range(1, 23) if i not in [10, 15]]
    elif mode in held_out:
        desired = held_out[mode]
    else:
        raise ValueError(f'Unknown mode: {mode}')

    chr_names = [name for name in desired if name in available_chrs]
    if not chr_names:
        raise ValueError(f"No valid chromosomes found for mode '{mode}'")

    print(f"Using chromosomes for {mode}: {chr_names}\n")

    # Augmentation is only enabled for the training split.
    use_aug = (mode == 'train')
    datasets = []
    for chr_name in chr_names:
        omit = centrotelo.get(chr_name, np.zeros((0, 2), dtype=int))
        try:
            ds = ChromosomeDataset(celltype_root, chr_name, omit, features, use_aug=use_aug)
            datasets.append(ds)
            print(f"✓ Loaded {chr_name}: {len(ds)} samples")
        except FileNotFoundError:
            print(f"⚠️ Skipping {chr_name}: File not found")
        except Exception as e:
            # Best effort: a single broken chromosome must not abort the run.
            print(f"⚠️ Skipping {chr_name}: {str(e)}")

    if not datasets:
        raise ValueError(f"No datasets loaded for mode '{mode}'")

    print(f"Total {mode} samples: {sum(len(d) for d in datasets)}\n")
    return torch.utils.data.ConcatDataset(datasets)


def main():
    """Train, validate and test the Hi-C prediction model end to end.

    Steps: report the hardware, build the train/val/test datasets and loaders,
    fit with checkpointing / early stopping on ``val_loss``, then evaluate the
    best checkpoint on the test split.
    """
    # Clear memory first
    torch.cuda.empty_cache()
    gc.collect()

    # Check GPU availability
    print("="*70)
    print("SYSTEM CHECK")
    print("="*70)
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
        print(f"CUDA version: {torch.version.cuda}")
    else:
        print("⚠️ WARNING: No GPU detected! Training will be VERY slow on CPU.")
        print("   To enable GPU in Colab: Runtime > Change runtime type > Hardware accelerator > GPU")
    print("="*70)

    # Enable optimizations for A100
    torch.backends.cudnn.benchmark = True
    torch.set_float32_matmul_precision('high')

    class Args:
        data_root = "/content/drive/MyDrive/ML4GEN DATA/data - IMR90"
        assembly = "hg38"
        celltype = "imr90"
        batch_size = 12
        num_workers = 2  # Reduced from 4 to 2 as per warning
        max_epochs = 10
        patience = 5
        num_gpus = 1 if torch.cuda.is_available() else 0  # Auto-detect GPU
        save_path = "/content/drive/MyDrive/ML4GEN DATA/data - IMR90/checkpoints_dna_meth"
        save_top_k = 3

    args = Args()

    pl.seed_everything(2077, workers=True)

    os.makedirs(f'{args.save_path}/models', exist_ok=True)
    os.makedirs(f'{args.save_path}/logs', exist_ok=True)

    print("="*70)
    print("LOADING DATASETS")
    print("="*70)

    train_ds = get_dataset(args, 'train')
    val_ds = get_dataset(args, 'val')
    test_ds = get_dataset(args, 'test')

    def _make_loader(dataset, shuffle, persistent):
        """Build a DataLoader whose worker settings all come from ``args``."""
        loader_kwargs = dict(
            batch_size=args.batch_size,
            shuffle=shuffle,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=collate_fn,
        )
        # persistent_workers / prefetch_factor are rejected by DataLoader when
        # num_workers == 0, so only pass them when workers are actually used.
        if args.num_workers > 0:
            loader_kwargs['prefetch_factor'] = 2
            loader_kwargs['persistent_workers'] = persistent
        return DataLoader(dataset, **loader_kwargs)

    train_loader = _make_loader(train_ds, shuffle=True, persistent=True)
    val_loader = _make_loader(val_ds, shuffle=False, persistent=True)
    # The test loader is used once, so its workers need not persist.
    test_loader = _make_loader(test_ds, shuffle=False, persistent=False)

    model = HiCTrainingModule(args)

    checkpoint_cb = callbacks.ModelCheckpoint(
        dirpath=f'{args.save_path}/models',
        filename='hic-{epoch:02d}-{val_loss:.4f}',
        monitor='val_loss',
        save_top_k=args.save_top_k,
        mode='min',
        save_last=True,
    )

    # NOTE(review): validation runs four times per epoch (val_check_interval
    # below), so `patience` presumably counts validation runs rather than
    # epochs — confirm this is the intended stopping horizon.
    early_stop_cb = callbacks.EarlyStopping(
        monitor='val_loss',
        patience=args.patience,
        mode='min'
    )

    lr_monitor = callbacks.LearningRateMonitor(logging_interval='step')
    csv_logger = pl.loggers.CSVLogger(save_dir=f'{args.save_path}/logs')

    trainer = pl.Trainer(
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=args.num_gpus if args.num_gpus > 0 else 'auto',
        max_epochs=args.max_epochs,
        callbacks=[checkpoint_cb, early_stop_cb, lr_monitor],
        logger=csv_logger,
        gradient_clip_val=1.0,
        precision='16-mixed' if torch.cuda.is_available() else '32',
        enable_progress_bar=True,
        enable_model_summary=True,
        log_every_n_steps=10,
        val_check_interval=0.25,
    )

    print("\n" + "="*70)
    print("STARTING TRAINING")
    print(f"Batch size: {args.batch_size}")
    print(f"Max epochs: {args.max_epochs}")
    print(f"Training samples: {len(train_ds)}")
    print("="*70)

    trainer.fit(model, train_loader, val_loader)

    print("\n" + "="*70)
    print("TESTING BEST MODEL")
    print("="*70)

    trainer.test(model, test_loader, ckpt_path='best')

    print(f"\n✅ Training complete! Best model saved to {args.save_path}/models")


# Run the full training/evaluation pipeline when this code is executed as the
# main module (which includes running it as a notebook cell).
if __name__ == '__main__':
    main()

  _C._set_float32_matmul_precision(precision)
INFO:lightning_fabric.utilities.seed:Seed set to 2077


SYSTEM CHECK
PyTorch version: 2.9.0+cu126
CUDA available: True
CUDA device: NVIDIA A100-SXM4-40GB
CUDA version: 12.6
LOADING DATASETS
Feature: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/ctcf_log2fc.bw, Norm: None
Feature: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/atac.bw, Norm: log
Feature: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/dnameth-minusstrand.bigWig, Norm: None
Feature: /content/drive/MyDrive/ML4GEN DATA/data - IMR90/hg38/imr90/genomic_features/dnameth-plusstrand.bigWig, Norm: None

Available chromosomes: ['chr1', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr2', 'chr20', 'chr21', 'chr22', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chrX']
Using chromosomes for train: ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr11', 'chr12', 'chr13', 'chr14', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20'

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores



STARTING TRAINING
Batch size: 12
Max epochs: 10
Training samples: 4737


/usr/local/lib/python3.12/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:881: Checkpoint directory /content/drive/.shortcut-targets-by-id/126vLgWy4wFKfcpk6pUYC6UiOP4G1DHi2/ML4GEN DATA/data - IMR90/checkpoints_dna_meth/models exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loading `train_dataloader` to estimate number of stepping batches.
/usr/local/lib/python3.12/dist-packages/pytorch_lightning/utilities/model_summary/model_summary.py:242: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.


Output()