In [10]:
import numpy as np
import pod5 as p5
import pysam
import logging
from sklearn.preprocessing import normalize
import pandas as pd
from statsmodels import robust

# Notebook-wide logger: records of level DEBUG and above are appended to a
# UTF-8 encoded log file.
logger = logging.getLogger("test_logger")
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter(
    "%(asctime)s - %(filename)s - line:%(lineno)d - %(levelname)s - %(message)s"
)
# logging.getLogger() hands back the same logger object for the same name, so
# re-running this cell would otherwise stack another FileHandler on it and
# every record would be written several times. Drop handlers left over from
# earlier runs before attaching a fresh one.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
    stale_handler.close()
test_log = logging.FileHandler(
    "/home/xiaoyf/methylation/deepsignal/log/normalize.log", "a", encoding="utf-8"
)
test_log.setFormatter(formatter)
# Attach the file handler to the logger object.
logger.addHandler(test_log)

In [4]:
def extract_signal_from_pod5(pod5_path) -> np.ndarray:
    """Collect the raw signal and calibration values of every read in a POD5 file.

    Args:
        pod5_path: Path of the POD5 file; passed straight to ``p5.Reader``.

    Returns:
        Object array with one row per read, laid out as
        0: read_id (str), 1: signal (int16 array), 2: calibration offset
        (int16), 3: calibration scale (float16).
        Reads whose signal is ``None`` are logged at CRITICAL level and left
        out of the result.
    """
    signals = []
    with p5.Reader(pod5_path) as reader:
        for read_record in reader.reads():
            if read_record.signal is None:
                logger.critical(
                    "Signal is None for read id {}".format(read_record.read_id)
                )
                # Nothing to convert for this read; carrying on would fail on
                # the .astype() call below.
                continue
            signals.append(
                [
                    # str() is needed here: without it the id stays a UUID object.
                    str(read_record.read_id),
                    read_record.signal.astype(np.int16),
                    # NOTE(review): offset/scale are narrowed to int16/float16,
                    # which discards precision -- confirm this is acceptable.
                    np.int16(read_record.calibration.offset),
                    np.float16(read_record.calibration.scale),
                ]
            )
    # An object ndarray is used instead of a plain list to keep the result compact.
    return np.array(signals, dtype=object)


def _move_record_from_read(read) -> list:
    """Build one output row from a BAM record carrying the mv/ts/sm/sd tags."""
    tags = dict(read.tags)
    mv_tag = tags["mv"]
    ts_tag = tags["ts"]
    sm_tag = tags["sm"]
    sd_tag = tags["sd"]
    return [
        read.query_name,
        read.query_sequence,
        # First entry of the mv tag is the stride, the rest is the move table.
        np.int16(mv_tag[0]),
        np.array(mv_tag[1:], dtype=np.int16),
        np.int16(ts_tag),
        np.float16(sm_tag),
        np.float16(sd_tag),
    ]


def extract_move_from_bam(bam_path) -> np.ndarray:
    """Read sequence and move-table information for every record of a BAM file.

    Args:
        bam_path: Path of the BAM file; opened with ``check_sq=False``.

    Returns:
        Object array with one row per record, laid out as
        0: read_id, 1: sequence, 2: stride, 3: mv_table, 4: num_trimmed,
        5: to_norm_shift, 6: to_norm_scale.

    Raises:
        KeyError: If a record lacks one of the ``mv``, ``ts``, ``sm`` or
            ``sd`` tags.
    """
    seq_move = []
    # The context manager closes the BAM file even when a record is malformed;
    # the handle was previously left open.
    with pysam.AlignmentFile(bam_path, "rb", check_sq=False) as bamfile:
        try:
            # The index is not used for now (using it returned nothing), so
            # the file is simply read through to EOF.
            for read in bamfile.fetch(until_eof=True):
                seq_move.append(_move_record_from_read(read))
        except ValueError:
            print("bam don't has index")
            for read in bamfile.fetch(until_eof=True, multiple_iterators=False):
                seq_move.append(_move_record_from_read(read))
    return np.array(seq_move, dtype=object)


def read_from_pod5_bam(pod5_path, bam_path, read_id=None) -> np.ndarray:
    """Join POD5 signals with BAM move-table records on the read id.

    Args:
        pod5_path: Path handed to ``extract_signal_from_pod5``.
        bam_path: Path handed to ``extract_move_from_bam``.
        read_id: When given, only BAM records with this read id are kept;
            ``None`` keeps every record.

    Returns:
        Object array with one row per matched (BAM record, signal) pair, in
        BAM record order, laid out as
        0: read_id, 1: signal, 2: to_pA_shift, 3: to_pA_scale, 4: sequence,
        5: stride, 6: mv_table, 7: num_trimmed, 8: to_norm_shift,
        9: to_norm_scale.
        BAM records without a sequence or without a matching signal are dropped.
    """
    signal = extract_signal_from_pod5(pod5_path)
    seq_move = extract_move_from_bam(bam_path)

    # Index the signal rows by read id once, instead of rescanning the whole
    # signal table for every BAM record. A list per id keeps the original
    # behaviour if an id appears more than once in the POD5 file.
    signals_by_id = {}
    for signal_row in signal:
        signals_by_id.setdefault(signal_row[0], []).append(signal_row)

    read = []
    for move_row in seq_move:
        if read_id is not None and move_row[0] != read_id:
            continue
        if move_row[1] is None:
            # No basecalled sequence stored on this record.
            continue
        for signal_row in signals_by_id.get(move_row[0], ()):
            # Columns 0-3 come from the POD5 row, columns 4-9 from the BAM row
            # (its own read id at index 0 is skipped).
            read.append([*signal_row[:4], *move_row[1:7]])

    return np.array(read, dtype=object)

In [5]:
# Row layout: 0:read_id,1:signal,2:to_pA_shift,3:to_pA_scale,4:sequence,5:stride,6:mv_table,7:num_trimmed,8:to_norm_shift,9:to_norm_scale
def norm_signal_read_id(signal) -> np.ndarray:
    """Normalise one read's raw signal with the shift/scale stored alongside it.

    The normalisation shift and scale (indices 8 and 9) are divided by the
    to_pA scale (index 3), the to_pA shift (index 2) is subtracted from the
    resulting shift, and the trimmed signal is mapped to
    ``(signal - shift) / scale``.

    Args:
        signal: One row in the layout given above this function.

    Returns:
        float16 array of the normalised samples. For ``num_trimmed >= 0`` the
        first ``num_trimmed`` samples are dropped; for a negative value the
        last ``abs(num_trimmed)`` samples are dropped.
    """
    read_id = signal[0]
    to_pa_scale = signal[3]
    if to_pa_scale == 0:
        # Only reported: the divisions below still run and will not yield
        # finite values for this read.
        logger.critical("to_pA_scale of read {} is 0".format(read_id))

    shift = (signal[8] / to_pa_scale) - np.float16(signal[2])
    scale = signal[9] / to_pa_scale
    if scale == 0:
        logger.critical("scale of read {} is 0".format(read_id))

    num_trimmed = signal[7]
    if num_trimmed >= 0:
        kept = signal[1][num_trimmed:]
    else:
        kept = signal[1][:num_trimmed]

    # NOTE(review): the arithmetic is done in float16, as before; confirm that
    # half precision is sufficient for the signal range in use.
    return (kept.astype(np.float16) - shift) / scale

In [6]:
# Input files for this run: one POD5 signal file and one BAM file.
pod5_path = "/homeb/xiaoyf/data/HG002/example/pod5/output.pod5"
bam_path = "/homeb/xiaoyf/data/HG002/example/bam/has_moves.bam"
# Join signals and BAM records for every read (no read_id filter is passed).
reads = read_from_pod5_bam(pod5_path, bam_path)

[E::idx_find_and_load] Could not retrieve index file for '/homeb/xiaoyf/data/HG002/example/bam/has_moves.bam'


In [18]:
def _normalize_signals(signals, normalize_method="mad"):
    sig = signals.astype(np.float64)
    if normalize_method == "zscore":
        sshift, sscale = np.mean(sig), np.std(sig)
    elif normalize_method == "mad":
        sshift, sscale = np.median(sig), robust.mad(sig)
    else:
        raise ValueError("")
    norm_signals = (sig - sshift) / sscale
    return np.around(norm_signals, decimals=6)

In [20]:
# Row layout of `reads`: 0:read_id,1:signal,2:to_pA_shift,3:to_pA_scale,4:sequence,5:stride,6:mv_table,7:num_trimmed,8:to_norm_shift,9:to_norm_scale
# Column order of the summary table built below.
SUMMARY_COLUMNS = ['read id', 'raw max', 'raw mean', 'raw min', 'raw median',
                   'linalg.norm max', 'linalg.norm mean', 'linalg.norm min', 'linalg.norm median',
                   'remora max', 'remora mean', 'remora min', 'remora median', 'remora std',
                   'mad max', 'mad mean', 'mad min', 'mad median', 'mad std',
                   'raw mad max', 'raw mad mean', 'raw mad min', 'raw mad median', 'raw mad std']


def _summarize(label, values, log, with_std=False):
    """Log and return max/mean/min/median (optionally std) of ``values``.

    Each statistic is written through ``log`` as "<label> <stat>: <value>" and
    returned under the key "<label> <stat>", which matches a summary column.
    """
    stats = {
        'max': np.max(values),
        'mean': np.mean(values),
        'min': np.min(values),
        'median': np.median(values),
    }
    if with_std:
        # The standard deviation is taken on a float64 copy of the values.
        stats['std'] = np.std(values.astype(np.float64))
    for stat_name, stat_value in stats.items():
        log('{} {}: {}'.format(label, stat_name, stat_value))
    return {'{} {}'.format(label, stat_name): stat_value
            for stat_name, stat_value in stats.items()}


# One dict per read; the DataFrame is built once after the loop instead of
# being re-created with pd.concat on every iteration.
summary_rows = []
for read in reads:
    raw_signal = read[1]
    logger.info('read id: {}'.format(read[0]))
    row = {'read id': read[0]}
    row.update(_summarize('raw', raw_signal, logger.info))

    # np.bincount needs non-negative integers, so the signal is shifted by
    # |min|. The addition happens in the signal's own integer dtype and may
    # wrap around; in that case the offending values are logged instead.
    # NOTE(review): the logged value is the most frequent *shifted* sample,
    # i.e. it still includes the |min| offset -- confirm that is intended.
    shifted = raw_signal + abs(np.min(raw_signal))
    if np.any(shifted < 0):
        logger.critical(abs(np.min(raw_signal)))
        logger.critical(raw_signal)
        logger.critical(shifted)
    else:
        logger.info('raw bincount: {}'.format(np.argmax(np.bincount(shifted))))

    # Signal divided by its vector norm.
    row.update(_summarize('linalg.norm', raw_signal / np.linalg.norm(raw_signal), logger.debug))

    # Normalisation with the shift/scale stored with the read.
    remora_norm = norm_signal_read_id(read)
    row.update(_summarize('remora', remora_norm, logger.info, with_std=True))

    # Median/MAD normalisation applied on top of the previous result ...
    row.update(_summarize('mad', _normalize_signals(remora_norm), logger.debug, with_std=True))
    # ... and applied directly to the raw signal.
    row.update(_summarize('raw mad', _normalize_signals(raw_signal), logger.debug, with_std=True))

    summary_rows.append(row)
    logger.info('-------------------------------------------------------')

df = pd.DataFrame(summary_rows, columns=SUMMARY_COLUMNS)
print(df)

                                   read id raw max     raw mean raw min  \
0     3239b4d9-0a7e-471c-86fe-e156b9f279a0    1318   899.075692     599   
1     732722e8-65c8-4ad1-b4d5-f9ec85e830d3    1466  1007.847152     654   
2     77a9f606-58bd-42a9-b7f8-a3bb26d97e7f    1465   936.392630     626   
3     7cb07bb1-5ea7-4eb0-a974-fce4e4381a9b    1284   894.430262     583   
4     7d5dd8ce-045d-49ac-964d-a13c9d119380    1373   924.243811     593   
...                                    ...     ...          ...     ...   
3995  fed81f0c-0944-4edf-b39f-d735ef2e68bd    1280   841.526710     480   
3996  fedb5e31-1b54-44cc-a0d0-2fa37b810582    1442   952.077405     504   
3997  fef04f59-0dad-4460-b3ec-f6dad1e8941f    1583   908.010156     449   
3998  fe987708-476a-422c-b4ea-86766b8a2ccf    1256   843.874054     470   
3999  fe48f4ae-94c4-4b9a-b0af-2d21b26901f5    1387   914.702922     303   

      raw median  linalg.norm max  linalg.norm mean  linalg.norm min  \
0          899.0         0.

In [14]:
# Check how abs() behaves on fixed-width numpy integers.
# The recorded output for int8 is -256, i.e. abs(np.int8(-128)) came back as
# -128 rather than 128 (+128 lies outside the int8 range).
# NOTE(review): this result may depend on the numpy version in use -- confirm.
print(abs(np.int8(-128))-128)
# For int16 the recorded output is 0: abs() returned 128 as expected.
print(abs(np.int16(-128))-128)

-256
0
