In [None]:
import numpy as np
import pandas as pd
import os, re
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import multiply
from tensorflow.keras import backend as K
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import Sequence

tf.random.set_seed(123)

In [None]:
def _parse_chr_from_base(base: str) -> str:
    # ex) GM12878_chr10_25kb -> chr10
    m = re.search(r'_(chr[^_]+)_', base)
    if not m:
        raise ValueError(f"cannot parse chromosome from: {base}")
    return m.group(1)

In [None]:
# Generate submatrices

def get_ontad_all_bins(tad_file: str) -> np.ndarray:
    """Collect every TAD start/end bin listed in a whitespace-separated TAD file.

    The first data row is dropped (the code treats it as the "level 0" entry),
    and the first two columns of the remaining rows are read as 1-based start
    and end bins, which are converted to 0-based indices.

    Args:
        tad_file: Path to the TAD file. Lines starting with ``#`` are ignored.

    Returns:
        Sorted, de-duplicated 0-based bin indices of all starts and ends.
        An empty integer array is returned when the file holds no data rows
        (including a completely empty or comment-only file) or only the
        level-0 row.
    """
    empty = np.array([], dtype=int)

    try:
        df = pd.read_csv(tad_file, header=None, sep=r'\s+', comment='#')
    except pd.errors.EmptyDataError:
        # read_csv raises instead of returning an empty frame when the file
        # has nothing to parse, so the length check below would never run.
        return empty

    if len(df) == 0:
        return empty

    df = df.iloc[1:, :]  # skip level 0
    if df.empty:
        return empty

    # File coordinates are 1-based; shift to 0-based bin indices.
    starts = df.iloc[:, 0].astype(int).to_numpy() - 1
    ends   = df.iloc[:, 1].astype(int).to_numpy() - 1
    return np.unique(np.concatenate([starts, ends]))

def load_hic_bins_from_bed(boundary_bed: str, chrom: str, res: int = 25000, use_center: bool = True) -> np.ndarray:
    """Convert the boundary intervals of one chromosome in a BED file to bin indices.

    Args:
        boundary_bed: Tab-separated BED file; columns 0-2 are read as
            chromosome, start and end. Lines starting with ``#`` are ignored.
        chrom: Chromosome name to keep (exact match against column 0).
        res: Bin size in base pairs.
        use_center: When True the bin is the interval midpoint divided by
            ``res`` and rounded with ``np.rint``; otherwise it is
            ``start // res``.

    Returns:
        Sorted unique bin indices. Empty when the file does not exist or has
        no row for ``chrom``.

    Raises:
        ValueError: If the file has fewer than three columns.
    """
    if not os.path.exists(boundary_bed):
        return np.array([], dtype=int)

    bed = pd.read_csv(boundary_bed, sep='\t', header=None, comment='#')
    if bed.shape[1] < 3:
        raise ValueError(f"unexpected columns in {boundary_bed}")

    on_chrom = bed[bed.iloc[:, 0] == chrom].copy()
    if on_chrom.empty:
        return np.array([], dtype=int)

    starts = on_chrom.iloc[:, 1].astype(np.int64).to_numpy()
    ends = on_chrom.iloc[:, 2].astype(np.int64).to_numpy()

    if not use_center:
        return np.unique((starts // res).astype(int))

    midpoints = (starts + ends) // 2
    return np.unique(np.rint(midpoints / res).astype(int))

def _intersect_with_tol(a: np.ndarray, b: np.ndarray, tol: int = 0) -> np.ndarray:
    a = np.unique(a); b = np.unique(b)
    if tol <= 0:
        return np.intersect1d(a, b)
    bset = set(b)
    hit = []
    for x in a:
        ok = False
        for t in range(-tol, tol+1):
            if (x + t) in bset:
                ok = True; break
        if ok:
            hit.append(x)
    return np.array(sorted(set(hit)), dtype=int)


def generate_samples(matrix, boundary_bins, output_path):
    """Cut a 15x15 window around every diagonal bin and save it with its label.

    The matrix is zero-padded by 7 on every side so each diagonal position
    gets a full window centred on it. Each output row holds the 225 flattened
    window values followed by one label: 1 if the bin index is in
    ``boundary_bins``, else 0.

    Args:
        matrix: 2-D contact matrix; one sample is produced per row index.
        boundary_bins: Container of 0-based bin indices labelled positive.
        output_path: Destination passed to ``np.save``; rows are stored as float32.
    """
    half = 7
    padded = np.pad(matrix, ((half, half), (half, half)), mode='constant')

    rows = []
    for center in range(len(matrix)):
        # In padded coordinates the window for bin `center` starts at `center`.
        stop = center + 2 * half + 1
        window = padded[center:stop, center:stop]
        is_boundary = 1 if center in boundary_bins else 0
        rows.append(np.append(window.flatten(), is_boundary))

    np.save(output_path, np.array(rows).astype('float32'))

def create_samples_intersection(
    matrix_dir: str,
    tad_dir: str,
    hic_out_dir: str,          # hicFindTADs --outPrefix directory (where *_boundaries.bed is)
    output_dir: str,
    tol: int = 0,
    res: int = 25000,
    use_center: bool = True
):
    """Build one labelled sample file per Hi-C matrix in ``matrix_dir``.

    For every matrix file ``<base>.<ext>`` the function looks for
    ``<tad_dir>/<base>.tad`` and ``<hic_out_dir>/<base>_boundaries.bed``,
    labels a bin positive when it appears in both boundary sets (within
    ``tol`` bins), and writes the samples to ``<output_dir>/<base>.npy``.

    Entries that cannot be processed (not a regular file, no chromosome token
    in the name, missing companion files) are reported with a ``[WARN]`` line
    and skipped, so one stray file does not abort the whole batch.

    Args:
        matrix_dir: Directory of whitespace-delimited text matrices.
        tad_dir: Directory holding the ``.tad`` files.
        hic_out_dir: Directory holding the ``*_boundaries.bed`` files.
        output_dir: Destination directory; created if absent.
        tol: Bin tolerance used when intersecting the two boundary sets.
        res: Bin size in base pairs used to convert BED coordinates to bins.
        use_center: Forwarded to ``load_hic_bins_from_bed``.
    """
    os.makedirs(output_dir, exist_ok=True)

    for file_name in sorted(os.listdir(matrix_dir)):

        matrix_path = os.path.join(matrix_dir, file_name)
        if not os.path.exists(matrix_path):
            print(f"[WARN] missing matrix: {matrix_path}"); continue
        if not os.path.isfile(matrix_path):
            # Sub-directories cannot be loaded as a matrix.
            print(f"[WARN] not a regular file, skipping: {matrix_path}"); continue

        base = os.path.splitext(file_name)[0]  # ex) GM12878_chr10_25kb
        try:
            chrom = _parse_chr_from_base(base)
        except ValueError as err:
            # e.g. hidden files or notes living next to the matrices
            print(f"[WARN] skipping {file_name}: {err}"); continue

        tad_path    = os.path.join(tad_dir, base + '.tad')
        hic_bed     = os.path.join(hic_out_dir, base + '_boundaries.bed')

        if not os.path.exists(tad_path):
            print(f"[WARN] missing tad: {tad_path}"); continue
        if not os.path.exists(hic_bed):
            print(f"[WARN] missing hic boundaries: {hic_bed}"); continue

        matrix = np.loadtxt(matrix_path)

        # OnTAD (excluding level 0)
        ontad_bins_all = get_ontad_all_bins(tad_path)

        # HiCExplorer: boundaries.bed -> bin index at the requested resolution
        hic_bins = load_hic_bins_from_bed(hic_bed, chrom=chrom, res=res, use_center=use_center)

        # Intersection
        inter_bins = _intersect_with_tol(ontad_bins_all, hic_bins, tol=tol)
        pos_set = set(inter_bins.tolist())

        # Derive the name from `base` so the logged path is exactly the file
        # np.save writes, whatever extension the input matrix had.
        output_path = os.path.join(output_dir, base + '.npy')

        generate_samples(matrix, pos_set, output_path)
        print(f"[OK] {base}: |OnTAD(all)|={len(ontad_bins_all)} |HiC|={len(hic_bins)} |∩|={len(pos_set)} -> {output_path}")

In [None]:
# Training dataset: write one labelled sample .npy per matrix file found in matrix_dir.
# All four paths are placeholders and must be replaced with real locations before running.
create_samples_intersection(
    matrix_dir = '/hic/matrix/txt/file/path',
    tad_dir    = '/ontad/output//tad/file/path',
    hic_out_dir= '/HiCexplorer/output/bed/file/path',
    output_dir = '/training/dataset/output/directory',
    tol        = 0  # 0 = a bin must appear in both boundary sets exactly
)

In [None]:
# Test dataset: same procedure as for training, written to a separate output directory.
# NOTE(review): the three input placeholders are identical to the training cell;
# presumably they should point at held-out matrices — confirm before running.
create_samples_intersection(
    matrix_dir = '/hic/matrix/txt/file/path',
    tad_dir    = '/ontad/output//tad/file/path',
    hic_out_dir= '/HiCexplorer/output/bed/file/path',
    output_dir = '/test/dataset/output/directory',
    tol        = 0  # 0 = a bin must appear in both boundary sets exactly
)

## Postprocessing

In [None]:
# only training data
def training_postprocess(submatrix_data, output_path, neg_ratio=4, seed=None):
    """Augment positives, subsample negatives, shuffle, and save a training set.

    Every positive row (label 1 in the last column) is duplicated with its
    15x15 window rotated by 90 degrees clockwise. Negatives (label 0) are
    then randomly subsampled without replacement to at most ``neg_ratio``
    times the number of positives (originals plus rotations). The combined
    rows are shuffled and written with ``np.save``.

    Args:
        submatrix_data: 2-D array with 225 window values followed by the
            label in each row.
        output_path: Destination passed to ``np.save``.
        neg_ratio: Maximum number of negatives kept per positive sample.
        seed: Optional integer seed. When given, sampling and shuffling use a
            private ``np.random.RandomState(seed)`` and are reproducible;
            when None the global ``np.random`` state is used.

    Notes:
        The saved array keeps the dtype of ``submatrix_data`` (building the
        rotated rows with ``np.append(..., 1)`` would promote float32 input
        to float64). An input without any positive row produces an empty
        ``(0, n_columns)`` file rather than failing while stacking.
    """
    patch = 15
    labels = submatrix_data[:, -1]
    pos_samples = submatrix_data[labels == 1]
    neg_samples = submatrix_data[labels == 0]
    n_pos = len(pos_samples)

    # Rotate all positive windows at once; explicit sizes keep the reshapes
    # valid when there are zero positives.
    windows = pos_samples[:, :-1].reshape(n_pos, patch, patch)
    rotated = np.rot90(windows, k=-1, axes=(1, 2)).reshape(n_pos, patch * patch)
    rotated_samples = np.hstack(
        [rotated, np.ones((n_pos, 1), dtype=submatrix_data.dtype)]
    )

    rng = np.random if seed is None else np.random.RandomState(seed)

    num_pos = n_pos + len(rotated_samples)
    neg_limit = min(len(neg_samples), int(num_pos * neg_ratio))
    if neg_limit > 0:
        keep = rng.choice(len(neg_samples), size=neg_limit, replace=False)
        sampled_neg = neg_samples[keep]
    else:
        sampled_neg = neg_samples[:0]

    final_data = np.vstack([pos_samples, rotated_samples, sampled_neg])
    rng.shuffle(final_data)
    np.save(output_path, final_data)

    print(f"Processed and saved to {output_path}")
    print(f"Positive samples (including augmented): {len(pos_samples) + len(rotated_samples)}")
    print(f"Negative samples (limited): {len(sampled_neg)}")

def process_entire_dataset(training_root_folder):
    """Run ``training_postprocess`` on every raw ``.npy`` file in a folder.

    Each ``<name>.npy`` directly inside ``training_root_folder`` is loaded and
    its post-processed version is written next to it as
    ``<name>_processed.npy``. Files that already end in ``_processed.npy``
    are skipped, so running the function again does not post-process its own
    outputs.

    Args:
        training_root_folder: Folder containing the raw sample files.
    """
    print(f"\nProcessing (flat): {training_root_folder} ...")

    raw_suffix = '.npy'
    done_suffix = '_processed.npy'

    for file_name in sorted(os.listdir(training_root_folder)):
        if not file_name.endswith(raw_suffix):
            continue
        if file_name.endswith(done_suffix):
            # Output of an earlier run; re-processing it would create
            # *_processed_processed.npy and duplicate the training data.
            continue

        file_path = os.path.join(training_root_folder, file_name)
        # Swap only the trailing extension, not every '.npy' in the name.
        output_file = file_name[:-len(raw_suffix)] + done_suffix
        output_path = os.path.join(training_root_folder, output_file)

        submatrix_data = np.load(file_path)
        training_postprocess(submatrix_data, output_path)

In [None]:
# Placeholder path: folder holding the raw training .npy files; the *_processed.npy
# outputs are written into this same folder by process_entire_dataset.
training_root = 'postprocessed/training/dataset/output/directory'
process_entire_dataset(training_root)

## Generate the model weight file (.h5)

In [None]:
def channel_attention(input_feature, ratio=8):
    """Re-weight the channels of a 4-D feature map with a learned per-channel gate.

    Global average and global max pooling each produce one descriptor per
    channel. Both descriptors pass through the same two Dense layers
    (a ReLU bottleneck of ``max(1, channel // ratio)`` units, then a linear
    layer back to ``channel`` units), are summed, and squashed with a sigmoid.
    The resulting ``(1, 1, channel)`` gate multiplies ``input_feature``.

    Args:
        input_feature: Keras tensor whose last axis is the channel axis.
        ratio: Reduction factor of the bottleneck Dense layer.

    Returns:
        Keras tensor: ``input_feature`` multiplied by the channel gate.
    """
    channel = input_feature.shape[-1]
    # Never let the bottleneck shrink to zero units.
    filters = max(1, int(channel//ratio))
    # The two Dense layers are created once and applied to both pooled
    # branches, so the branches share weights.
    shared_layer_one = tf.keras.layers.Dense(filters,
                             activation='relu',
                             kernel_initializer='he_normal',
                             use_bias=True,
                             bias_initializer='zeros')
    shared_layer_two = tf.keras.layers.Dense(channel,
                             kernel_initializer='he_normal',
                             use_bias=True,
                             bias_initializer='zeros')

    # Average-pooled branch
    avg_pool = tf.keras.layers.GlobalAveragePooling2D()(input_feature)    
    avg_pool = tf.keras.layers.Reshape((1,1,channel))(avg_pool)
    avg_pool = shared_layer_one(avg_pool)
    avg_pool = shared_layer_two(avg_pool)

    # Max-pooled branch
    max_pool = tf.keras.layers.GlobalMaxPooling2D()(input_feature)
    max_pool = tf.keras.layers.Reshape((1,1,channel))(max_pool)
    max_pool = shared_layer_one(max_pool)
    max_pool = shared_layer_two(max_pool)
   

    # Combine both branches and turn them into a gate in (0, 1)
    cbam_feature = tf.keras.layers.Add()([avg_pool,max_pool])
    cbam_feature = tf.keras.layers.Activation('sigmoid')(cbam_feature)


    return multiply([input_feature, cbam_feature])

In [None]:
class ACF(layers.Layer):
    """Layer that scores a feature map from its row-wise and column-wise differences.

    For every offset ``d`` in ``1..D`` the input is shifted by ``d`` along
    each spatial direction (zero-filled at the border) and the signed and
    absolute differences to the unshifted input are collected. Vertical
    differences feed a "row" branch, horizontal differences a "column"
    branch; each branch is a separable convolution followed by a 1x1
    convolution and is scaled by its own single-channel sigmoid gate. The two
    gated branches are concatenated and reduced to one channel by a 1x1
    convolution head.

    Args:
        D: Largest shift distance used for the differences.
        mid_ch: Number of channels inside each branch and the head.
        k_row: Kernel width of the row branch (kernel size ``(3, k_row)``).
        k_col: Kernel height of the column branch (kernel size ``(k_col, 3)``).
        return_center: If True, ``call`` returns only the value at the spatial
            centre, shaped ``(batch, 1)``; otherwise the full single-channel map.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, D=1, mid_ch=32, k_row=5, k_col=5, return_center=True, **kwargs):
        super().__init__(**kwargs)
        self.D = D
        self.mid_ch = mid_ch
        self.k_row = k_row
        self.k_col = k_col
        self.return_center = return_center

        self.row_conv = keras.Sequential([
            layers.SeparableConv2D(self.mid_ch, kernel_size=(3, self.k_row),
                                   padding='same', activation='relu'),
            layers.Conv2D(self.mid_ch, kernel_size=1, activation='relu')
        ])

        self.col_conv = keras.Sequential([
            layers.SeparableConv2D(self.mid_ch, kernel_size=(self.k_col, 3),
                                   padding='same', activation='relu'),
            layers.Conv2D(self.mid_ch, kernel_size=1, activation='relu')
        ])

        # One sigmoid value per spatial position for each branch.
        self.gate_row = keras.Sequential([layers.Conv2D(1, 1, activation='sigmoid')])
        self.gate_col = keras.Sequential([layers.Conv2D(1, 1, activation='sigmoid')])

        self.head = keras.Sequential([
            layers.Conv2D(self.mid_ch, 1, activation='relu'),
            layers.Conv2D(1, 1, activation=None)
        ])

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super().get_config()
        config.update({
            'D': self.D,
            'mid_ch': self.mid_ch,
            'k_row': self.k_row,
            'k_col': self.k_col,
            'return_center': self.return_center,
        })
        return config

    def _shift(self, x, dy=0, dx=0):
        """Shift ``x`` by ``dy`` rows and ``dx`` columns, filling with zeros.

        Positive ``dy``/``dx`` move the content down/right. The result has the
        same shape as ``x``.
        """
        dims = tf.shape(x)
        H, W = dims[1], dims[2]
        # Pad on the side the content moves away from ...
        pad_top  = tf.maximum( dy, 0)
        pad_bot  = tf.maximum(-dy, 0)
        pad_left = tf.maximum( dx, 0)
        pad_right= tf.maximum(-dx, 0)
        xpad = tf.pad(x, [[0,0],[pad_top,pad_bot],[pad_left,pad_right],[0,0]])
        # ... then crop back to the original H x W window.
        y0 = tf.maximum(-dy, 0)
        y1 = y0 + H
        x0 = tf.maximum(-dx, 0)
        x1 = x0 + W
        return xpad[:, y0:y1, x0:x1, :]

    def call(self, x, training=None):
        """Compute the score map (or its centre value) for a 4-D input ``x``."""
        diffs_row = []
        diffs_col = []
        for d in range(1, self.D+1):
            down = self._shift(x, dy=+d)
            up   = self._shift(x, dy=-d)
            right= self._shift(x, dx=+d)
            left = self._shift(x, dx=-d)

            dr = down - x
            ur = up   - x
            rc = right- x
            lc = left - x

            # Signed and absolute differences, four tensors per offset.
            diffs_row += [dr, ur, tf.abs(dr), tf.abs(ur)]
            diffs_col += [rc, lc, tf.abs(rc), tf.abs(lc)]

        Fr = tf.concat(diffs_row, axis=-1)
        Fc = tf.concat(diffs_col, axis=-1)

        R = self.row_conv(Fr)
        C = self.col_conv(Fc)

        gr = self.gate_row(R)
        gc = self.gate_col(C)

        # Spatial gating of each branch by its own sigmoid map.
        R = R * gr
        C = C * gc

        fused = tf.concat([R, C], axis=-1)
        P = self.head(fused)

        if self.return_center:
            H = tf.shape(P)[1]
            W = tf.shape(P)[2]
            i = H // 2
            j = W // 2
            # Keep only the value at the spatial centre, one per batch item.
            center = P[:, i:i+1, j:j+1, :]
            center = tf.reshape(center, [tf.shape(P)[0], 1])
            return center
        else:
            return P

In [None]:
def cbam_channel_only(x, ratio=8):
    """Apply channel attention to ``x``; thin alias for ``channel_attention``.

    Args:
        x: Feature map handed to ``channel_attention``.
        ratio: Bottleneck reduction factor handed to ``channel_attention``.

    Returns:
        The channel-reweighted feature map.
    """
    attended = channel_attention(x, ratio)
    return attended

def init_model():
    """Build and compile the boundary classifier for 15x15x1 input windows.

    Architecture: Conv2D(128) -> 3x3 max-pool with stride 3 -> Conv2D(64) ->
    channel attention -> ACF (centre value only) -> Dense(128)/PReLU/Dropout ->
    Dense(64)/PReLU/Dropout -> sigmoid output.

    Returns:
        A compiled ``tf.keras.Model`` using binary cross-entropy with label
        smoothing 0.01, Adam at learning rate 3e-4, and confusion-matrix,
        accuracy, precision and recall metrics.
    """
    inputs = layers.Input(shape=(15,15,1))

    # Convolutional feature extractor; 'same' padding keeps 15x15 and the
    # 3x3 / stride-3 pooling reduces it to 5x5.
    x = layers.Conv2D(128, 3, padding='same', activation='relu', kernel_initializer='he_normal')(inputs)
    x = layers.MaxPooling2D((3,3), strides=(3,3))(x)
    x = layers.Conv2D(64, 3, padding='same', activation='relu', kernel_initializer='he_normal')(x)
    x = cbam_channel_only(x, ratio=8)

    # With return_center=True the ACF layer yields one value per sample, shape (batch, 1).
    acf_logit = ACF(D=2, mid_ch=32, k_row=5, k_col=5, return_center=True)(x)

    # Dense classification head applied to the ACF output.
    x = layers.Dense(128, activation=None, kernel_initializer='he_normal')(acf_logit)
    x = layers.PReLU()(x)
    x = layers.Dropout(0.4)(x)
    x = layers.Dense(64, activation=None, kernel_initializer='he_normal')(x)
    x = layers.PReLU()(x)
    x = layers.Dropout(0.4)(x)


    outputs = layers.Dense(1, activation='sigmoid')(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(label_smoothing=0.01),
        optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
        metrics=[keras.metrics.TruePositives(name='tp'),
                 keras.metrics.FalsePositives(name='fp'),
                 keras.metrics.TrueNegatives(name='tn'),
                 keras.metrics.FalseNegatives(name='fn'),
                 keras.metrics.BinaryAccuracy(name='accuracy'),
                 keras.metrics.Precision(name='precision'),
                 keras.metrics.Recall(name='recall')]
    )
    return model



# Build one model instance only to print the architecture summary;
# train_and_save_model() constructs its own fresh model for training.
model = init_model()
model.summary()

In [None]:
def fun(a):
    """Standardise ``a`` in consecutive groups of 15 values.

    The array is viewed as rows of 15 elements; each row is shifted to zero
    mean and divided by its standard deviation (plus 1e-10 to avoid division
    by zero). The result is float32 and has the original shape.

    Args:
        a: NumPy array whose total size is a multiple of 15.

    Returns:
        The row-standardised array, reshaped to ``a.shape``.
    """
    original_shape = a.shape
    rows = a.reshape(-1, 15).astype('float32')
    centered = rows - np.mean(rows, axis=1).reshape(-1, 1)
    row_std = np.sqrt(centered.var(axis=1)) + 1e-10
    return (centered / row_std.reshape(-1, 1)).reshape(original_shape)
    
class load_testdata(Sequence):
    """Keras ``Sequence`` that serves an array of flattened windows in batches.

    Each batch is reshaped to ``(-1, 15, 15, 1)``; no labels are returned, so
    the sequence is suited to ``model.predict``.

    Args:
        x_y_set: 2-D array; every column is used as input data.
        batch_size: Number of rows per batch.
        **kwargs: Forwarded to the ``Sequence`` base class.
    """

    def __init__(self, x_y_set, batch_size, **kwargs):
        # Initialise the base class so its own bookkeeping is set up.
        super().__init__(**kwargs)
        self.x_y_set = x_y_set
        self.x = self.x_y_set[:,:]
        self.batch_size = batch_size

    def __len__(self):
        """Number of full batches.

        Integer floor division gives the same result the previous
        ``math.floor(len / batch_size)`` intended, without relying on the
        ``math`` module, which this notebook never imports.
        NOTE(review): a trailing partial batch is dropped, so the last
        ``len(x) % batch_size`` rows are never served — confirm this is intended.
        """
        return len(self.x) // self.batch_size

    def __getitem__(self, idx):
        """Return batch ``idx`` as an array of shape ``(batch, 15, 15, 1)``."""
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]

        batch_x = batch_x.reshape(-1,15,15,1)

        return np.array(batch_x)

In [None]:
def load_all_processed_npy(folder_path):
    """Load and stack every ``*_processed.npy`` file found in a folder.

    Files are read in sorted name order so the row order of the stacked
    array does not depend on the order in which the filesystem lists them.

    Args:
        folder_path: Folder to scan (not recursive).

    Returns:
        The row-wise stack of all matching arrays, or None when the folder
        contains no ``*_processed.npy`` file.
    """
    all_data = [
        np.load(os.path.join(folder_path, file))
        for file in sorted(os.listdir(folder_path))
        if file.endswith('_processed.npy')
    ]
    if not all_data:
        return None
    return np.vstack(all_data)

In [None]:
def train_and_save_model(train_data, output_h5, epochs=30, batch_size=128):
    """Train a fresh model on the stacked training rows and save its weights.

    Args:
        train_data: 2-D array; each row holds 225 window values followed by
            the 0/1 label.
        output_h5: Path handed to ``model.save_weights``.
        epochs: Number of training epochs.
        batch_size: Mini-batch size used by ``model.fit``.
    """
    # Standardise the features with fun() (groups of 15 values), then
    # restore the 15x15x1 image layout the model expects.
    features = fun(train_data[:, :-1].astype('float32'))
    X = features.reshape(-1, 15, 15, 1)
    y = train_data[:, -1].astype('float32')

    model = init_model()
    model.fit(X, y, epochs=epochs, batch_size=batch_size, verbose=2)
    model.save_weights(output_h5)

In [None]:
# Placeholder locations; replace with real paths before running.
base_dir = "/training/dataset/directory"  # folder scanned for *_processed.npy training files
output_dir = "/weight/output/directory"  # folder that receives the trained weight file

# Make sure the weight output folder exists.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Train: stack every *_processed.npy under base_dir and fit one model on it.
level_path = base_dir
data = load_all_processed_npy(level_path)
# load_all_processed_npy returns None when no matching file exists;
# in that case this cell does nothing and no weight file is written.
if data is not None:
    h5_path = os.path.join(output_dir, "model.weights.h5")
    train_and_save_model(data, h5_path)