In [1]:
import cooler
import cooltools
from cooltools.api.saddle import _make_cis_obsexp_fetcher
from cooltools.lib import numutils

import bioframe as bf
import numpy as np

import multiprocessing as mp

from sklearn.preprocessing import QuantileTransformer

import warnings

warnings.filterwarnings("ignore", category=RuntimeWarning)

  @numba.jit  # (nopython=True)
  @numba.jit  # (nopython=True)


In [2]:
# Run configuration: everything a reader might want to tune lives in this cell.
nProc = 8                    # worker count for expected_cis and the multiprocessing pool
binSize = 10_000             # resolution (bp) selected inside each .mcool and used in the output name
balance_column = "weight"    # balancing-weight column handed to expected_cis

# Prepend "4DNFIBM9QCFG" to this list when the aim is to generate the graph dataset for training.
samples = ["t0_q30", "t12_q30"]

# Cooler URIs (one per sample) and the destination of the saved matrices.
mcool_paths = [f"matrices/{name}.mcool::resolutions/{binSize}" for name in samples]
npz_path = f"matrix/HeLa_{binSize}.obs_exp_qt.npz"

# Diagonal limits: min_diag is used as ignore_diags / minD, max_diag as maxD
# (a negative max_diag disables the upper cut-off in _matrixPairs).
min_diag = 2
max_diag = -1

In [7]:
# Open one cooler handle per entry of mcool_paths (same order as `samples`).
clr_ = list(map(cooler.Cooler, mcool_paths))

In [8]:
# Build the hg38 chromosome-arm table that serves as the view for the expected calculation.
hg38_chromsizes = bf.fetch_chromsizes("hg38")
hg38_cens = bf.fetch_centromeres("hg38")
hg38_arms = bf.make_chromarms(hg38_chromsizes, hg38_cens)

# Restrict the arms to chromosomes present in the first cooler and renumber the rows.
hg38_arms = (
    hg38_arms
    .loc[hg38_arms["chrom"].isin(clr_[0].chromnames)]
    .reset_index(drop=True)
)

# chrY arms are left out of the view.
view_df = hg38_arms.loc[hg38_arms["chrom"] != "chrY"]

In [9]:
# Per-sample cis expected over the arm view: unsmoothed, computed from the
# `balance_column` weights, with the first `min_diag` diagonals ignored.
expected_ = [
    cooltools.expected_cis(
        clr=cool,
        view_df=view_df,
        clr_weight_name=balance_column,
        ignore_diags=min_diag,
        smooth=False,
        aggregate_smoothed=False,
        nproc=nProc,
    )
    for cool in clr_
]

In [10]:
# One region-fetcher function per sample, pairing each cooler with its own expected table.
# NOTE(review): `balance_column` is not forwarded here, so the fetcher presumably falls
# back to its own default weight column - confirm it matches `balance_column`.
getmatrix_ = [
    _make_cis_obsexp_fetcher(cool, expected, view_df)
    for cool, expected in zip(clr_, expected_)
]

In [11]:
def _matrixPairs(reg, minD, maxD, scaler=None, returnNan=False):
    """Fetch, mask and optionally rescale one region's matrix for every sample.

    Relies on the module-level ``samples`` list and the ``getmatrix_`` fetchers
    (built with ``_make_cis_obsexp_fetcher``, i.e. presumably cis
    observed/expected matrices - confirm against cooltools).

    Parameters
    ----------
    reg : str
        Region name, passed to each fetcher as both the row and column region.
    minD : int
        Diagonals with offset ``|d| < minD`` are set to NaN.
    maxD : int
        When ``>= 0``, diagonals with offset ``|d| > maxD`` are set to NaN.
        A negative value disables this upper cut-off.
    scaler : None, 'QT', or object exposing ``fit_transform``
        ``None`` leaves the values untouched. ``'QT'`` fits a fresh uniform
        ``QuantileTransformer`` per matrix. Any other object is used as-is
        through its ``fit_transform`` method.
    returnNan : bool
        If False, masked (NaN) entries are replaced by 0 via ``np.nan_to_num``.
        If True, the NaNs are kept.

    Returns
    -------
    dict
        ``{reg: [matrix_for_sample_0, matrix_for_sample_1, ...]}``.

    Notes
    -----
    When a scaler is given and a matrix has no positive, non-NaN entry, an
    all-zero matrix is returned for that sample regardless of ``returnNan``.
    """
    matrix_ = [getmatrix_[i](reg, reg) for i, _ in enumerate(samples)]

    for i, matrix in enumerate(matrix_):
        # Exact zeros are treated as missing.
        matrix[matrix == 0] = np.nan

        # Mask the diagonals closest to the main diagonal.
        for d in np.arange(-minD + 1, minD):
            numutils.set_diag(matrix, np.nan, d)

        # Optionally mask everything farther than maxD from the main diagonal.
        if maxD >= 0:
            far_offsets = np.append(
                np.arange(-matrix.shape[0], -maxD),
                np.arange(maxD + 1, matrix.shape[0]),
            )
            for d in far_offsets:
                numutils.set_diag(matrix, np.nan, d)

        if scaler is None:
            matrix_[i] = matrix if returnNan else np.nan_to_num(matrix, nan=0)
            continue

        # A single mask selects the values that are scaled AND the positions the
        # scaled values are written back to, so the two can never differ in size.
        finite = ~np.isnan(matrix)
        usable = np.zeros(matrix.shape, dtype=bool)
        usable[finite] = matrix[finite] > 0

        if not usable.any():
            # Nothing to scale for this sample.
            matrix_[i] = np.zeros_like(matrix)
            continue

        values = matrix[usable].reshape(-1, 1)
        if isinstance(scaler, str) and scaler == 'QT':
            # Cap n_quantiles at the sample count; sklearn would clip it to the
            # same value itself, but only after emitting a warning.
            transformer = QuantileTransformer(
                n_quantiles=min(1000, values.shape[0]),
                output_distribution='uniform',
                random_state=42,
            )
        else:
            transformer = scaler
        scaledData = transformer.fit_transform(values)

        # Non-positive leftovers are excluded from scaling, so mark them missing.
        matrix[finite & ~usable] = np.nan
        matrix[usable] = scaledData.flatten()

        matrix_[i] = matrix if returnNan else np.nan_to_num(matrix, nan=0)

    return {reg: matrix_}

In [12]:
# One argument tuple per view region: quantile-transform scaling ('QT'), NaNs replaced by 0.
args_list = [
    (region, min_diag, max_diag, 'QT', False)
    for region in view_df["name"]
]

# Process the regions in parallel; each result is a {region: [matrix per sample]} dict.
# NOTE(review): _matrixPairs reads the module-level getmatrix_/samples, so the workers
# presumably rely on a fork-style process start - confirm on platforms that spawn.
with mp.Pool(processes=nProc) as mp_pool:
    mp_ = mp_pool.starmap(_matrixPairs, args_list)

In [47]:
# Flatten the per-region results into a single {"<sample>-<region>": matrix} mapping.
matrix_dict = dict()
# The loop variable is deliberately NOT called `mp`: that name is the
# `multiprocessing` alias, and rebinding it to a dict makes `mp.Pool(...)` fail
# when the pool cell is re-run afterwards.
for region_result, region in zip(mp_, view_df["name"]):
    # Each result dict has exactly one key, which must be the region it was computed for.
    assert list(region_result.keys())[0] == region, f"result order mismatch for {region}"
    for i, sample in enumerate(samples):
        matrix_dict[f"{sample}-{region}"] = region_result[region][i]

In [48]:
# Write every "<sample>-<region>" matrix into one .npz archive; each dict key
# becomes the name of an array inside the file.
# NOTE(review): the parent directory of npz_path presumably has to exist already - confirm.
np.savez(npz_path, **matrix_dict)

In [49]:
# test loading saved npz data
# Re-open the archive that was just written and list the array names it contains.
# NOTE(review): the listing recorded under this cell includes t30_q30/t60_q30 keys,
# which are not in `samples` as defined above - it looks like output from an
# earlier run with more samples; re-run to refresh.

npz = np.load(npz_path)
npz.files

['t0_q30-chr1_p',
 't12_q30-chr1_p',
 't30_q30-chr1_p',
 't60_q30-chr1_p',
 't0_q30-chr1_q',
 't12_q30-chr1_q',
 't30_q30-chr1_q',
 't60_q30-chr1_q',
 't0_q30-chr2_p',
 't12_q30-chr2_p',
 't30_q30-chr2_p',
 't60_q30-chr2_p',
 't0_q30-chr2_q',
 't12_q30-chr2_q',
 't30_q30-chr2_q',
 't60_q30-chr2_q',
 't0_q30-chr3_p',
 't12_q30-chr3_p',
 't30_q30-chr3_p',
 't60_q30-chr3_p',
 't0_q30-chr3_q',
 't12_q30-chr3_q',
 't30_q30-chr3_q',
 't60_q30-chr3_q',
 't0_q30-chr4_p',
 't12_q30-chr4_p',
 't30_q30-chr4_p',
 't60_q30-chr4_p',
 't0_q30-chr4_q',
 't12_q30-chr4_q',
 't30_q30-chr4_q',
 't60_q30-chr4_q',
 't0_q30-chr5_p',
 't12_q30-chr5_p',
 't30_q30-chr5_p',
 't60_q30-chr5_p',
 't0_q30-chr5_q',
 't12_q30-chr5_q',
 't30_q30-chr5_q',
 't60_q30-chr5_q',
 't0_q30-chr6_p',
 't12_q30-chr6_p',
 't30_q30-chr6_p',
 't60_q30-chr6_p',
 't0_q30-chr6_q',
 't12_q30-chr6_q',
 't30_q30-chr6_q',
 't60_q30-chr6_q',
 't0_q30-chr7_p',
 't12_q30-chr7_p',
 't30_q30-chr7_p',
 't60_q30-chr7_p',
 't0_q30-chr7_q',
 't12_q