### Collecting stats from IsoQuant outputs

Collecting the SQANTI-style classification calls from the IsoQuant outputs.

### imports

In [None]:
import gzip
import pickle
from collections import defaultdict, Counter
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

import numpy as np
import pysam

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick

import mdl.sc_isoform_paper.plots  # noqa
from mdl.sc_isoform_paper.constants import MASSEQ_KEYS, SAMPLE_COLORS, SHORTREAD_KEYS
from mdl.sc_isoform_paper.pipseq_barcodes import barcode_to_sequence

from mdl.isoscelles.leiden import cluster_leaf_nodes, cluster_labels

### Setup

In [None]:
# every input and output path is rooted at the user's home directory
root_dir = Path.home()
data_path = root_dir.joinpath("data")

# MAS-seq data, including the directory of annotated BAM files
masseq_data_path = data_path.joinpath("masseq")
annotated_path = masseq_data_path.joinpath("20250124_annotated")

# pickle used to cache the classification counts between runs
sqanti_stats_file = data_path.joinpath("isoquant_sqanti3_classes.pickle")

# output directory for the figures
figure_path = root_dir.joinpath("202501_figures")

In [None]:
# the samples used in this notebook, picked out of MASSEQ_KEYS by position
_sample_indices = (1, 3, 4)
sample_order = [MASSEQ_KEYS[ix] for ix in _sample_indices]
sample_order

In [None]:
# shortcut: reuse the counts cached by an earlier run when the pickle is on disk
if sqanti_stats_file.exists():
    with open(sqanti_stats_file, "rb") as fh:
        c_counters = pickle.load(fh)

## Read SQANTI classifications

If you have the pickle file, you can skip to **Plotting and tables**.

Otherwise we need to read the clustering from notebook 01 and then count the SQANTI3 classifications (from the IsoQuant output) broken down by cluster assignment.

In [None]:
# annotated BAMs whose name contains a `.1.`, `.3.` or `.4.` field, in sorted path order
annotated_bams = list(annotated_path.glob("*.[134].*annotated.bam"))
annotated_bams.sort()
len(annotated_bams)

In [None]:
with open(data_path / "shortread_clustering_100k.pickle", "rb") as fh:
    sr_clustering = pickle.load(fh)

# Both `ix_dict` and `numis` live in the same stats pickle, so unpickle the
# file once and pull out the two entries instead of reading it twice.
with open(data_path / "shortread_stats_100k.pickle", "rb") as fh:
    _sr_stats = pickle.load(fh)

ix_dict, sr_numis = _sr_stats["ix_dict"], _sr_stats["numis"]
# drop the container so any other entries in the pickle can be freed
del _sr_stats

In [None]:
# For every short-read sample, turn the per-cell cluster labels into
# consecutive integer ids (position of the label among the sorted leaf nodes).
c_arrays = {}
for key in SHORTREAD_KEYS:
    _tree = sr_clustering[key][0]
    _leaf_nodes = cluster_leaf_nodes(_tree)
    _label_array = cluster_labels(_tree, _leaf_nodes)
    _k2i = {k: i for i, k in enumerate(sorted(_leaf_nodes))}
    # labels that are not one of the leaf nodes are encoded as -1
    _cluster_ids = [_k2i.get(k, -1) for k in _label_array]
    c_arrays[key] = np.array(_cluster_ids)

In [None]:
# barcode lists for each short-read run, keyed by the run directory name
bc_dict = {}

for barcode_file in data_path.glob("10x*/outs/raw_feature_bc_matrix/barcodes.tsv.gz"):
    matrix_dir = barcode_file.parent
    print(matrix_dir)
    # the run directory is two levels above raw_feature_bc_matrix
    run_name = matrix_dir.parent.parent.name
    with gzip.open(barcode_file, "rt") as fh:
        # strip off the -1 suffix from barcodes
        bc_dict[run_name] = [line.strip()[:-2] for line in fh]

fp = data_path / "pipseq_pbmc"
print(fp)
with gzip.open(fp / "raw_matrix" / "barcodes.tsv.gz", "rt") as fh:
    # convert from encoded barcode to the real sequence
    bc_dict[fp.name] = list(map(barcode_to_sequence, (line.strip() for line in fh)))


In [None]:
def read_class_counts_w_bc(key, bam_file, bc_dict):
    """Tally the ``XS`` tag values of a BAM file, grouped via the ``CB`` tag.

    Args:
        key: Identifier for this input; returned unchanged so the caller can
            match results to inputs.
        bam_file: Path of the BAM file to read.
        bc_dict: Mapping from a ``CB`` tag value (cell barcode) to a group id
            (the caller passes a barcode -> cluster mapping). Reads whose
            barcode is not a key of this mapping are not counted.

    Returns:
        A tuple ``(key, counts)`` where ``counts`` is a plain dict mapping each
        group id to a ``Counter`` of ``XS`` tag values.

    Note:
        Both tags are read for every record, so a record missing either tag
        makes ``get_tag`` raise.
    """
    counts_by_group = defaultdict(Counter)

    with pysam.AlignmentFile(bam_file, "rb", threads=2) as bam:
        for read in bam:
            barcode = read.get_tag("CB")
            # presumably the SQANTI-style class assigned by IsoQuant
            sq_class = read.get_tag("XS")
            if barcode not in bc_dict:
                continue
            counts_by_group[bc_dict[barcode]][sq_class] += 1

    return key, dict(counts_by_group)

In [None]:
# For each MAS-seq sample, map barcode -> cluster id using the matching
# short-read run. Only barcodes whose entry in ix_dict is truthy are kept,
# and they are paired positionally with that run's cluster array.
_shortread_runs = ["pipseq_pbmc", "10x_3p_pbmc", "10x_5p_pbmc"]

bc_to_cluster = {}
for _ms_key, _sr_key in zip(sample_order, _shortread_runs):
    _kept_barcodes = (
        bc for bc, keep in zip(bc_dict[_sr_key], ix_dict[_sr_key]) if keep
    )
    bc_to_cluster[_ms_key] = dict(zip(_kept_barcodes, c_arrays[_sr_key]))

In [None]:
# the third dot-separated field of each BAM name is the index into MASSEQ_KEYS
key_list = [MASSEQ_KEYS[int(bam.name.split(".")[2])] for bam in annotated_bams]

with ProcessPoolExecutor(8) as exc:
    # sample key -> cluster id -> Counter of classes, summed over all BAMs
    c_counters = defaultdict(lambda: defaultdict(Counter))
    per_bam_results = exc.map(
        read_class_counts_w_bc,
        key_list,
        annotated_bams,
        [bc_to_cluster[k] for k in key_list],
    )
    for key, cc in per_bam_results:
        # several BAMs can share a sample key, so accumulate rather than assign
        for cluster, counts in cc.items():
            c_counters[key][cluster] += counts

In [None]:
# convert the nested defaultdicts to plain dicts (a defaultdict with a lambda
# factory cannot be pickled)
c_counters = {sample: dict(per_cluster) for sample, per_cluster in c_counters.items()}

# write the cache only when there is none yet; an existing file is left alone
if not sqanti_stats_file.exists():
    with open(sqanti_stats_file, "wb") as out:
        pickle.dump(c_counters, out)

## Plotting and tables

In [None]:
# order categories by overall counts (summed over all clusters of all samples),
# with the "-" category always placed last
_per_sample_totals = (
    sum(c_counters[k].values(), start=Counter()) for k in c_counters
)
_overall_totals = sum(_per_sample_totals, start=Counter())

c_order = [c for c, _ in _overall_totals.most_common() if c != "-"]
c_order.append("-")

In [None]:
# same as the labels from notebook 01
# (per sample: cluster id -> cell-type name, ids are the list positions)
cluster_names = {
    ("PIPseq", "0.8x"): dict(enumerate([
        'CD4 T cells 1',
        'CD4 T cells 2',
        'Naïve CD4',
        'Cytotoxic T cells',
        'Innate Lymphoid',
        'CD16 Monocytes',
        'CD14 Monocytes',
        'B cells',
    ])),
    ("10x 3'",): dict(enumerate([
        'CD4 T cells 1',
        'CD4 T cells 2',
        'Naïve CD4',
        'Cytotoxic T cells',
        'B cells',
        'CD14 Monocytes',
        'CD16 Monocytes',
        'DC',
    ])),
    ("10x 5'",): dict(enumerate([
        'CD4 T cells 1',
        'CD4 T cells 2',
        'Naïve CD4',
        'Cytotoxic T cells',
        'Innate Lymphoid',
        'B cells',
        'CD14 Monocytes',
        'CD16 Monocytes',
    ])),
}

# inverse lookup per sample: cell-type name -> cluster id
cluster_reverse_labels = {
    sample: {name: ix for ix, name in names.items()}
    for sample, names in cluster_names.items()
}


In [None]:
# one panel per distinct cell-type name across all samples
all_cell_labels = sorted(
    {name for names in cluster_names.values() for name in names.values()}
)

x = np.arange(len(c_order))
fig, axs = plt.subplots(3, 3, figsize=(18, 9), sharex=True)

for lbl, ax in zip(all_cell_labels, axs.flatten()):
    # `i` is the sample's position, so each sample keeps the same bar slot
    # even in panels where another sample has no such cluster
    for i, k in enumerate(c_counters):
        j = cluster_reverse_labels[k].get(lbl)
        if j is None:
            # this sample has no cluster with this name
            continue

        counts = c_counters[k][j]
        total = counts.total()
        # fraction of this cluster's reads in each category
        y = [counts[c] / total for c in c_order]
        ax.bar(x + 0.05 + i * 0.3, width=0.3, height=y, color=SAMPLE_COLORS[k[0]], align="edge")

    ax.set_title(lbl)
    ax.set_xticks(x + 0.5, c_order, rotation=75, ha="right")
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))

plt.savefig(figure_path / "supp_fig7.svg")
plt.show()


In [None]:
# pool the per-cluster counters of every sample into one counter per sample
combined_c_counters = {
    k: sum(per_cluster.values(), start=Counter())
    for k, per_cluster in c_counters.items()
}

# supplemental table 4
header = [f"{' '.join(k):>16}" for k in c_counters]
print(f"{'class':23}", *header, sep="\t")
for c in c_order:
    # share of each sample's reads that fall in category `c`
    row = [
        f"{totals[c] / totals.total():16.2%}"
        for totals in combined_c_counters.values()
    ]
    print(f"{c:23}", *row, sep="\t")


In [None]:
# note: the cluster orders are not the same. We reordered the columns for supplementary table 5
for k, per_cluster in c_counters.items():
    cluster_ids = range(8)
    # header row: sample name followed by the cluster names (spaces -> underscores)
    column_names = [cluster_names[k][i].replace(' ', '_') for i in cluster_ids]
    print(f"{' '.join(k):23}", *(f"{name:>16}" for name in column_names), sep="\t")
    for c in c_order:
        # share of each cluster's reads that fall in category `c`
        fractions = [per_cluster[i][c] / per_cluster[i].total() for i in cluster_ids]
        print(f"{c:23}", *(f"{frac:16.2%}" for frac in fractions), sep="\t")
    print()