In [None]:
import pandas as pd 
import numpy as np
from io import StringIO
from matplotlib import pyplot as plt
import random
import pod5
import pysam

In [None]:
SAMTOOLS_CMD = 'samtools'

# Per-read identifiers; the first column of the table becomes the index.
reads = pd.read_csv('../demux/rep1/identifiers-reporter.txt', sep='\t', index_col=0)

# Collect the poly(A) length from the `pt` tag of every BAM record.
# The tag is looked up by name: optional SAM tags have no guaranteed column
# position, so extracting a fixed text column silently yields the wrong field
# whenever the tag order differs between records. Records without a `pt` tag
# are skipped.
with pysam.AlignmentFile('../polya/rep1/reporter.bam', 'rb') as bam:
    polya_records = [
        (aln.query_name, aln.get_tag('pt'))
        for aln in bam
        if aln.has_tag('pt')
    ]

# Several records of one read carrying the same value collapse to a single row.
polya_len = pd.DataFrame(polya_records, columns=['read_id', 'polya_len']).drop_duplicates()
polya_len = polya_len[polya_len['polya_len'] >= 1].set_index('read_id') # Minimum poly(A) length

# Inner join on the read id: keeps only reads that have a poly(A) length.
reads_wpa = pd.merge(reads, polya_len, left_index=True, right_index=True)
reads_wpa.head()

In [None]:
# Collect, for every primary record, the poly(A) boundaries stored in the
# `pa` tag. Elements 1 and 2 of the tag are used as start and end of the
# poly(A) region (later used to slice the raw signal, i.e. presumably in
# signal samples -- confirm against the basecaller documentation).
polya_regions = []

# The context manager closes the BAM handle even if the loop raises.
with pysam.AlignmentFile('../polya/rep1/reporter.bam', 'rb') as bam:
    for aln in bam:
        if aln.is_secondary or aln.is_supplementary:
            continue

        # get_tag() raises KeyError on records lacking the tag; skip those.
        if not aln.has_tag('pa'):
            continue

        patag = aln.get_tag('pa')
        if patag[1] < 0:
            # A negative start marks a read without a usable poly(A) region.
            continue

        polya_regions.append([aln.query_name, patag[1], patag[2], patag[2] - patag[1]])

polya_regions = pd.DataFrame(polya_regions, columns=['read_id', 'polya_start', 'polya_end', 'polya_siglen']).set_index('read_id')

# Drop region columns left behind by a previous execution of this cell so that
# re-running it does not produce duplicated `_x` / `_y` columns.
reads_wpa = pd.merge(
    reads_wpa.drop(columns=polya_regions.columns, errors='ignore'),
    polya_regions,
    left_index=True,
    right_index=True,
)
reads_wpa

In [None]:
LOAD_SIGNAL_LENGTH = 9600   # number of samples stored per read
UTRSIDE_LENGTH = 1600       # samples kept beyond polya_end; poly(A) starts at this column after reversal
POLYA_MAD_THRESHOLD = 10
ADAPTER_SCALE_POINT = 60    # value the adapter median is mapped to
POLYA_SCALE_POINT = 90      # value the poly(A) median is mapped to

with pod5.Reader('../subset/rep1/reporter.pod5') as reader:
    signals = {}
    for read in reader.reads(reads_wpa.index):
        read_id = str(read.read_id)
        try:
            readinfo = reads_wpa.loc[read_id]
        except KeyError:
            continue

        signal = read.signal_pa[:readinfo.polya_end + UTRSIDE_LENGTH]
        # Samples missing on the UTR side when the read ends less than
        # UTRSIDE_LENGTH samples after the poly(A) region.
        missing_utr = max(readinfo.polya_end + UTRSIDE_LENGTH - len(signal), 0)

        adapter_med = np.median(signal[:readinfo.polya_start])
        polya_med = np.median(signal[readinfo.polya_start:readinfo.polya_end])

        # Linear rescaling that maps adapter_med -> ADAPTER_SCALE_POINT and
        # polya_med -> POLYA_SCALE_POINT. If the adapter slice is empty or the
        # two medians coincide, the scale is NaN/inf and propagates to the row.
        norm_scale = (ADAPTER_SCALE_POINT - POLYA_SCALE_POINT) / (adapter_med - polya_med)
        norm_shift = POLYA_SCALE_POINT - polya_med * norm_scale

        # Reverse so that the UTR side comes first, followed by poly(A) and adapter.
        signal = signal[::-1] * norm_scale + norm_shift
        if missing_utr > 0:
            # Left-pad short reads so the poly(A) boundary always sits at
            # column UTRSIDE_LENGTH; without this the whole trace is shifted
            # relative to the other reads.
            signal = np.pad(signal, (missing_utr, 0), mode='constant', constant_values=0)

        signal = signal[:LOAD_SIGNAL_LENGTH]
        if len(signal) < LOAD_SIGNAL_LENGTH:
            signal = np.pad(signal, (0, LOAD_SIGNAL_LENGTH - len(signal)), mode='constant', constant_values=0)

        signals[read_id] = signal

print(len(signals))

In [None]:
from matplotlib import colormaps
from matplotlib.colors import LinearSegmentedColormap

# Colormap for the main signal heatmap, taken from matplotlib's registry.
mycmap = colormaps['viridis']

# Two-colour ramp (blue -> magenta) for the outlier overlay.
# Note: the name `colors` is rebound to the `matplotlib.colors` module by a
# later cell of this notebook.
colors = ['#3030FF', '#FF00FF']
outlier_cmap = LinearSegmentedColormap.from_list(
    'blue_to_magenta',
    colors,
    N=256,
)

In [None]:
SAMPLE_COUNT = 30             # default number of reads returned by prepare_signals
MASK_LOWER_THRESHOLD = 80     # normalized signal below this is flagged
MASK_UPPER_THRESHOLD = 100    # normalized signal above this is flagged
OUTLIER_SURPLUS_COUNT = 10    # extra reads drawn on top of the sample count
OUTLIER_THRESHOLD = 0.4       # only used by the per-read outlier filter, which is currently disabled

def prepare_signals(reads, signals, sample_count=None, surplus_count=None):
    """Select reads and build the signal matrix plus its outlier overlay.

    Parameters
    ----------
    reads : pandas.DataFrame
        Indexed by read id; must provide a ``polya_siglen`` column.
    signals : dict
        Maps ``str(read id)`` to a 1-D array of ``LOAD_SIGNAL_LENGTH`` samples.
    sample_count : int, optional
        Maximum number of reads to return. Defaults to the module-level
        ``SAMPLE_COUNT`` as it is bound at call time.
    surplus_count : int, optional
        Number of additional reads drawn in the first sampling step. Defaults
        to the module-level ``OUTLIER_SURPLUS_COUNT``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(signals, outlier_colors, outlier_alphas)``; one row per selected
        read, ordered by decreasing ``polya_siglen``. ``outlier_colors`` is 1.0
        where the signal exceeds ``MASK_UPPER_THRESHOLD``. ``outlier_alphas``
        is 1.0 where the signal lies outside the
        ``[MASK_LOWER_THRESHOLD, MASK_UPPER_THRESHOLD]`` band *and* inside the
        poly(A) window, which starts at column ``UTRSIDE_LENGTH``.

    Raises
    ------
    KeyError
        If a selected read id has no entry in ``signals``.
    """
    if sample_count is None:
        sample_count = SAMPLE_COUNT
    if surplus_count is None:
        surplus_count = OUTLIER_SURPLUS_COUNT

    # Oversample so that a later per-read outlier filter has spare reads to
    # discard (that filter is disabled at the moment: every read is kept).
    if len(reads) >= sample_count + surplus_count:
        reads = reads.sample(sample_count + surplus_count)
    reads = reads.sort_values(by='polya_siglen', ascending=False)

    alnsignals = []
    alnsignal_outlier_colors = []
    alnsignal_outlier_alphas = []
    for read_id, polya_siglen in reads['polya_siglen'].items():
        sig = signals[str(read_id)]

        # Boolean window covering the poly(A) part of the aligned signal.
        # Slice assignment clips at LOAD_SIGNAL_LENGTH; max() keeps a negative
        # length from turning into a wrap-around slice.
        polyamask = np.zeros(LOAD_SIGNAL_LENGTH, dtype=bool)
        polya_stop = UTRSIDE_LENGTH + max(int(polya_siglen), 0)
        polyamask[UTRSIDE_LENGTH:polya_stop] = True

        is_outlier = (sig < MASK_LOWER_THRESHOLD) | (sig > MASK_UPPER_THRESHOLD)

        alnsignals.append(sig)
        alnsignal_outlier_colors.append((sig > MASK_UPPER_THRESHOLD).astype(float))
        alnsignal_outlier_alphas.append((polyamask & is_outlier).astype(float))

    if len(alnsignals) > sample_count:
        # Draw from NumPy's global generator (the one DataFrame.sample uses by
        # default) so a single np.random.seed() call makes the whole selection
        # repeatable. Sorting the kept positions preserves the length ordering.
        selected = np.sort(np.random.permutation(len(alnsignals))[:sample_count])
        alnsignals = [alnsignals[i] for i in selected]
        alnsignal_outlier_colors = [alnsignal_outlier_colors[i] for i in selected]
        alnsignal_outlier_alphas = [alnsignal_outlier_alphas[i] for i in selected]

    return np.array(alnsignals), np.array(alnsignal_outlier_colors), np.array(alnsignal_outlier_alphas)

# Try the selection on one condition: RNA 'A2' from the 'untreated' sample.
reads = reads_wpa.loc[
    reads_wpa['rnaname'].eq('A2') & reads_wpa['sample'].eq('untreated')
]
sig, outlier_colors, outlier_alphas = prepare_signals(reads, signals)

In [None]:
# Median ratio of poly(A) signal length to estimated poly(A) length for the
# A2 / untreated reads (presumably signal samples per nucleotide -- confirm
# the units of the two columns).
reads = reads_wpa.loc[
    reads_wpa['rnaname'].eq('A2') & reads_wpa['sample'].eq('untreated')
]
median_polya_speed = np.median(reads['polya_siglen'].div(reads['polya_len']))
print('Median poly(A) speed:', median_polya_speed)

In [None]:
from matplotlib import colors

# imshow() keyword sets. The value range (50..110) is carried by the norm, so
# no separate vmin/vmax entries are given. Each dict gets its own norm object.
heatmap_args_general = dict(
    cmap=mycmap,
    interpolation='none',
    norm=colors.PowerNorm(gamma=1.1, vmin=50, vmax=110),
    zorder=-1,
)
heatmap_args_faded = dict(
    cmap='Greys',
    interpolation='none',
    norm=colors.PowerNorm(gamma=1.1, vmin=50, vmax=110),
    alpha=0.3,
    zorder=-1,
)

SAMPLING_RATE = 4000
TICK_INTERVAL = 0.4

# Major ticks every TICK_INTERVAL seconds, counted outwards in both directions
# from column UTRSIDE_LENGTH; duplicates (the shared origin) are removed.
TICKS = np.unique(np.concatenate([
    UTRSIDE_LENGTH - SAMPLING_RATE * np.arange(
        0, UTRSIDE_LENGTH / SAMPLING_RATE + 0.01, TICK_INTERVAL),
    UTRSIDE_LENGTH + SAMPLING_RATE * np.arange(
        0, (LOAD_SIGNAL_LENGTH - UTRSIDE_LENGTH) / SAMPLING_RATE + 0.01, TICK_INTERVAL),
]))
# Labels are seconds relative to column UTRSIDE_LENGTH (positive to its left).
TICKLABELS = [f'{(UTRSIDE_LENGTH-t) / SAMPLING_RATE:.1f}' for t in TICKS]
MINORTICKS = np.arange(0, LOAD_SIGNAL_LENGTH, 400)
POSITION_MARKER_POLYA_POSITION = 60

# Rebinds SAMPLE_COUNT (set to 30 in an earlier cell) for the figure below.
SAMPLE_COUNT = 50
common_seed = 2

In [None]:
def plot_tail_signals(ax, sig, outlier_colors, outlier_alphas, heatmap_args):
    """Draw one signal heatmap, optionally with the outlier overlay.

    ``sig`` is rendered with ``heatmap_args``. When ``outlier_colors`` is not
    ``None`` it is drawn on top using ``outlier_cmap`` with ``outlier_alphas``
    as per-pixel alpha, and the guide lines are black; otherwise they are
    white. Guide lines mark column ``UTRSIDE_LENGTH`` (dashed) and
    ``POSITION_MARKER_POLYA_POSITION * median_polya_speed`` columns further
    right (dotted).

    Returns the image object created for ``sig`` (usable for a colorbar).
    """
    cb = ax.imshow(sig, aspect='auto', **heatmap_args)

    with_overlay = outlier_colors is not None
    if with_overlay:
        ax.imshow(
            outlier_colors,
            aspect='auto',
            cmap=outlier_cmap,
            alpha=outlier_alphas,
            vmin=0,
            vmax=1,
            interpolation='none',
            zorder=-1,
        )

    guide_color = 'black' if with_overlay else 'white'
    marker_column = UTRSIDE_LENGTH + POSITION_MARKER_POLYA_POSITION * median_polya_speed
    ax.axvline(UTRSIDE_LENGTH, color=guide_color, linestyle='--')
    ax.axvline(marker_column, color=guide_color, linestyle=':', alpha=.5)

    # Shared x axis layout; the y axis carries no ticks or labels.
    ax.set_xticks(TICKS)
    ax.set_xticklabels(TICKLABELS)
    ax.set_xticks(MINORTICKS, minor=True)
    for y_artists in (ax.get_yticklines(), ax.get_yticklabels()):
        plt.setp(y_artists, visible=False)

    # Artists with zorder below 0 (the images above) are rasterized on export.
    ax.set_rasterization_zorder(0)

    return cb

def plot_pane(ax, rna, sample):
    """Draw one figure row for the reads of ``rna`` in ``sample``.

    ``ax`` is a pair of axes: ``ax[0]`` receives the plain signal heatmap,
    ``ax[1]`` the faded heatmap with the outlier overlay. Both panels show the
    same reads, selected by ``prepare_signals``.

    Returns the two image objects ``(plain, faded)`` for use with colorbars.
    """
    # Seed both the NumPy and the standard-library generators so that every
    # random draw made while selecting reads is repeatable; seeding only one
    # of them leaves the figure different from run to run.
    np.random.seed(common_seed)
    random.seed(common_seed)

    reads = reads_wpa[(reads_wpa['rnaname'] == rna) & (reads_wpa['sample'] == sample)]
    sig, outlier_colors, outlier_alphas = prepare_signals(reads, signals)

    cbax1 = plot_tail_signals(ax[0], sig, None, None, heatmap_args_general)
    ax[0].set_ylabel(f'{rna} / {sample}', fontsize=12)

    cbax2 = plot_tail_signals(ax[1], sig, outlier_colors, outlier_alphas, heatmap_args_faded)
    ax[1].set_ylabel('')

    return cbax1, cbax2

# One row per (RNA, sample) condition, plus a last row holding the colorbars.
pane_specs = [
    ('control', 'untreated'),
    ('control', 'rg7834'),
    ('A2', 'untreated'),
    ('A2', 'rg7834'),
    ('A7', 'untreated'),
    ('A7', 'rg7834'),
]

fig, axes = plt.subplots(7, 2, figsize=(8, 12))

pane_images = [
    plot_pane(axes[row], rna, sample)
    for row, (rna, sample) in enumerate(pane_specs)
]
# The colorbars are built from the images of the first row.
cbax1, cbax2 = pane_images[0]

# Only the lowest data row is labelled on the x axis.
for bottom_ax in axes[5]:
    bottom_ax.set_xlabel('Time from 3\' UTR end (s)')

# The last row is blanked and used as the anchor for the colorbars.
for spare_ax in axes[6]:
    spare_ax.axis('off')
cb1, cb2 = (
    plt.colorbar(image, ax=spare_ax, orientation='horizontal')
    for image, spare_ax in zip((cbax1, cbax2), axes[6])
)
for colorbar in (cb1, cb2):
    colorbar.set_ticks(np.arange(50, 111, 10))

plt.tight_layout()
plt.savefig('rep1-mixed-tailing-signal-dual.pdf', dpi=300)