## Transcriptome coverage analysis

In this notebook we plot transcript coverage split four ways: good vs. bad priming, crossed with full-splice-match (FSM) vs. incomplete-splice-match (ISM) isoform classifications. We start from the annotated BAM files created by the internal priming analysis.

In [None]:
import itertools
import pickle

from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

import numpy as np

import matplotlib.pyplot as plt

import pysam


import mdl.sc_isoform_paper.coverage as cov
from mdl.sc_isoform_paper.constants import MASSEQ_KEYS, SAMPLE_COLORS
from mdl.sc_isoform_paper.plots import plot_all_covs
from mdl.sc_isoform_paper.priming import Priming, PrimingClassifier

In [None]:
pysam.set_verbosity(0)  # suppress htslib log messages while reading BAMs

root_dir = Path.home()
data_path = root_dir / "data" / "masseq"
annotated_path = data_path / "20250124_annotated"
figure_path = root_dir / "202501_figures"
# Figures are saved into this directory later in the notebook; create it up front
# so savefig does not fail with FileNotFoundError on a fresh machine.
figure_path.mkdir(parents=True, exist_ok=True)

reference_path = root_dir / "reference"

# generated with paftools.js
gencode_basic_bed = reference_path / "GRCh38.gencode.v39.annotation.basic.bed"
tx_data = cov.TranscriptData(gencode_basic_bed)

In [None]:
# computing coverage for PIPseq 0.8x SPRI, 10x 3', 10x 5'
# Single source of truth for the sample selection: both `keys` and the BAM glob are
# derived from these MASSEQ_KEYS indices, so the two can no longer drift apart.
masseq_indices = (1, 3, 4)
# A glob character class matches exactly one character, so only single-digit indices work.
assert all(0 <= idx <= 9 for idx in masseq_indices), "BAM glob only supports single-digit sample indices"
keys = [MASSEQ_KEYS[idx] for idx in masseq_indices]
bam_pattern = f"*.[{''.join(map(str, masseq_indices))}].*annotated.bam"  # -> "*.[134].*annotated.bam"
annotated_bams = sorted(annotated_path.glob(bam_pattern))
len(annotated_bams)


In [None]:
# Names of every priming class, plus the two SQANTI splice-match categories we split on.
priming_tags = tuple(member.name for member in Priming)
splice_matches = ("full_splice_match", "incomplete_splice_match")

# Partition the priming tags: the ones listed in PrimingClassifier.GOOD_PRIMING_TAGS
# are "good"; every remaining tag is "bad".
good_tag_set = {tag.name for tag in PrimingClassifier.GOOD_PRIMING_TAGS}
bad_tag_set = set(priming_tags).difference(good_tag_set)

### Computing transcript coverage based on priming and splicing tags

We need to go through the BAM files and partition the read coverage based on the priming and SQANTI classification tags. We parallelize over (BAM, transcript) pairs, which adds up to about 1.5M jobs: 24 BAMs × 61,314 transcripts. This takes a very long time (several hours) using `pysam`'s `count_coverage` method. The aggregated results (binned by transcript length) are available as `coverage_stats.pickle`. If this file is available, the next cell loads it and we can skip to **Results and Plotting**. Note that the computation cells below do not check for this file: running them recomputes everything, and the new results are not saved if the file already exists.

In [None]:
# Cached, length-binned coverage results. When present they are loaded here, so the
# expensive per-BAM computation can be skipped (jump to "Results and Plotting").
# The file is written by this notebook itself; only ever unpickle a file from a trusted source.
coverage_stats_file = data_path / "coverage_stats.pickle"

if coverage_stats_file.exists():
    with open(coverage_stats_file, "rb") as stats_fh:
        cached_stats = pickle.load(stats_fh)
    tx_depth_bins, binned_tx = cached_stats

In [None]:
# One job per (BAM, transcript) pair: (sample key, BAM path, transcript, tags, splice classes).
# Built as a list rather than a generator so the coverage cell below can be re-run:
# a generator is exhausted after the first run, and a second run would silently
# submit zero jobs and produce an empty result.
per_tx_args = [
    (MASSEQ_KEYS[int(anno_bam.name.split(".")[2])], anno_bam, tx, priming_tags, splice_matches)
    for anno_bam in annotated_bams
    for tx in tx_data
]
len(per_tx_args)

In [None]:
%%time

# Accumulate coverage per (sample key, priming tag, splice match) and then per transcript.
# Every worker process is initialised with the transcript data via cov.share_tx_data.
per_tag_tx_depth = defaultdict(lambda: defaultdict(int))

with ProcessPoolExecutor(16, initializer=cov.share_tx_data, initargs=(tx_data,)) as executor:
    # zip(*...) transposes the per-job argument tuples into one iterable per parameter,
    # which is the calling convention Executor.map expects.
    arg_columns = zip(*per_tx_args)
    job_results = executor.map(cov.calc_cov_from_bam, *arg_columns, chunksize=len(tx_data.loc))
    for tx, depth_by_class in job_results:
        for class_key, depth in depth_by_class.items():
            per_tag_tx_depth[class_key][tx] += depth

per_tag_tx_depth = dict(per_tag_tx_depth)

In [None]:
# aggregate into good (expected) and bad (likely internally primed) categories.
# Each individual priming tag is folded into one of the two categories, while the
# splice-match class (FSM vs ISM) is kept separate. The result is keyed by
# (sample key, "good"/"bad", splice match) and then by transcript.
priming_categories = {"good": good_tag_set, "bad": bad_tag_set}

per_tx_depth = defaultdict(lambda: defaultdict(int))

for k in keys:
    for category, tag_set in priming_categories.items():
        for t in tag_set:
            for s in splice_matches:
                # not every (sample, tag, splice) combination is present in per_tag_tx_depth
                for tx, depth in per_tag_tx_depth.get((k, t, s), {}).items():
                    per_tx_depth[k, category, s][tx] += depth

per_tx_depth = {k: dict(v) for k, v in per_tx_depth.items()}


In [None]:
# Aggregate over three transcript-length bins, given by their upper edges in bp:
# 0 - 2kb, 2kb - 4kb, and 4kb+. Results are keyed by these upper edges
# (see the tx_depth_bins[b2] lookups in the plotting cells).
length_bin_edges = [2000, 4000, 500000]
tx_depth_bins, binned_tx = cov.overall_depth(tx_data, keys, per_tx_depth, length_bin_edges)

In [None]:
# Cache the binned results so future runs can load them instead of recomputing.
# An existing cache file is never overwritten. The pickle is written to a temporary
# file and renamed into place, so an interrupted dump cannot leave a truncated cache
# behind (which would make the cache-loading cell fail on every later run).
if not coverage_stats_file.exists():
    tmp_stats_file = coverage_stats_file.with_name(coverage_stats_file.name + ".tmp")
    with tmp_stats_file.open("wb") as out:
        pickle.dump((tx_depth_bins, binned_tx), out)
    tmp_stats_file.replace(coverage_stats_file)

## Results and Plotting

First we print summary statistics for the coverage of incomplete splice matches (ISMs): for each sample and transcript-length bin, the fraction of coverage falling into each of four equal ranges of the coverage profile. In "good priming" cases we see a strong enrichment for coverage at the 3' end of the transcript, while internal priming tends to lead to 5' enrichment for ISMs.

Then, we plot the overall results.

In [None]:

# Good-priming ISMs: fraction of total coverage in each of four equal index ranges
# (0-250, 250-500, 500-750, 750-1000), per sample and transcript-length bin.
# NOTE(review): assumes each profile in tx_depth_bins has 1000 positional bins — TODO confirm.
quarter_edges = np.linspace(0, 1000, 5, dtype=int)

print("length      \ttechnology")
for k in sorted(keys):
    for lo_bp, hi_bp in itertools.pairwise([0, 2000, 4000, 500000]):
        print(f"{lo_bp // 1000}kb - {hi_bp // 1000}kb", f"{' '.join(k):10}", sep="\t", end="\t")
        profile = tx_depth_bins[hi_bp][k, 'good', 'incomplete_splice_match']
        tot = profile.sum()
        fractions = (profile[a:b].sum() / tot for a, b in itertools.pairwise(quarter_edges))
        print(*(f"{frac:.1%}" for frac in fractions), sep="\t")
    print()

In [None]:

# Bad-priming (likely internally primed) ISMs: fraction of total coverage in each of
# four equal index ranges (0-250, 250-500, 500-750, 750-1000), per sample and length bin.
# NOTE(review): assumes each profile in tx_depth_bins has 1000 positional bins — TODO confirm.
quarter_edges = np.linspace(0, 1000, 5, dtype=int)

print("length      \ttechnology")
for k in sorted(keys):
    for lo_bp, hi_bp in itertools.pairwise([0, 2000, 4000, 500000]):
        print(f"{lo_bp // 1000}kb - {hi_bp // 1000}kb", f"{' '.join(k):10}", sep="\t", end="\t")
        profile = tx_depth_bins[hi_bp][k, 'bad', 'incomplete_splice_match']
        tot = profile.sum()
        fractions = (profile[a:b].sum() / tot for a, b in itertools.pairwise(quarter_edges))
        print(*(f"{frac:.1%}" for frac in fractions), sep="\t")
    print()

In [None]:
# Supplementary figure: fraction of coverage in each of four equal index ranges of the
# coverage profile, for every priming category x splice-match class (rows) and
# transcript-length bin (columns), with one bar per sample.
# NOTE(review): assumes each profile in tx_depth_bins has 1000 positional bins — TODO confirm.
length_edges = [0, 2000, 4000, 500000]
x = np.arange(4)
bins = list(itertools.pairwise(np.linspace(0, 1000, x.shape[0] + 1, dtype=int)))
p_cl_list = list(itertools.product(("good", "bad"), cov.SPLICE_MATCHES))
n_samples = max(len(keys), 1)
bar_width = 0.9 / n_samples  # all samples' bars share 90% of each slot (0.3 for 3 samples)

# layout="constrained" is required for the "outside ..." legend location below to
# reserve space; without it the legend is drawn on top of the bottom row of axes.
fig, axs = plt.subplots(
    len(p_cl_list),
    len(length_edges) - 1,
    figsize=(10, 8),
    sharey="row",
    sharex=True,
    squeeze=False,
    layout="constrained",
)

for col, (b1, b2) in enumerate(itertools.pairwise(length_edges)):
    axs[0, col].set_title(f"{b1 // 1000}kb - {b2 // 1000}kb")

for row, (p, cl) in enumerate(p_cl_list):
    axs[row, 0].set_ylabel(f"{p}\n{cl}")

for ik, k in enumerate(sorted(keys)):
    for col, (b1, b2) in enumerate(itertools.pairwise(length_edges)):
        for row, (p, cl) in enumerate(p_cl_list):
            d = tx_depth_bins[b2][k, p, cl]
            axs[row, col].bar(
                x + 0.05 + ik * bar_width,
                [d[lo:hi].sum() / d.sum() for lo, hi in bins],
                width=bar_width,
                color=SAMPLE_COLORS[k[0]],
                align="edge",
                label=" ".join(k),
            )

# Dotted reference line at the fraction expected if coverage were spread evenly over the bins.
for ax in axs.flat:
    ax.axhline(1 / x.shape[0], color="k", linestyle=":")
    ax.set_xticks(x + 0.5)

handles, labels = axs[0, 0].get_legend_handles_labels()
fig.legend(
    handles,
    labels,
    loc="outside lower center",
    ncol=n_samples,
)

fig.savefig(figure_path / "supp_fig11_coverage_bins.svg")
plt.show()

In [None]:
# Main-text coverage figure (Fig. 2d), saved as SVG into figure_path.
fig2d_path = figure_path / "fig2d_coverage.svg"
plot_all_covs(
    keys,
    tx_depth_bins,
    binned_tx,
    tx_data.last_exon_r,
    output_file=fig2d_path,
)