In [None]:
import os, collections, sys, re, gzip
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib
import seaborn as sns
import anndata
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import gzip

# Record the interpreter and key package versions for reproducibility.
print("Python:", sys.version.split()[0])
print("")
for _pkg_label, _pkg_module in [
    ("anndata", anndata),
    ("scanpy", sc),
    ("seaborn", sns),
    ("numpy", np),
    ("matplotlib", matplotlib),
    ("pandas", pd),
]:
    print(f"{_pkg_label}:", _pkg_module.__version__)

In [None]:
# Load the clustered gene-level and isoform-level objects.
# NOTE(review): the files carry an .h5mu suffix but are read with read_h5ad;
# presumably they were written as plain AnnData — confirm.
output_dir = 'Intermediate_Files/Clustering'

from scanpy import read_h5ad

adata_g_filtered = read_h5ad(
    os.path.join(output_dir, "PBMC_gene_AutoZI_clustered_celltypes_reannotated_AutoZILatent_08132025.h5mu")
)
adata_i_filtered = read_h5ad(
    os.path.join(output_dir, "PBMC_iso_AutoZI_clustered_celltypes_reannotated_AutoZILatent_08132025.h5mu")
)

In [None]:
import pandas as pd
import re

# Bambu-assembled isoforms are the isoform features whose name contains "Bambu".
bambu_isos = list(filter(lambda feature: "Bambu" in feature, adata_i_filtered.var_names))

# Extended annotation GTF from the preprocessing pipeline: GRCh38 reference
# annotations plus the novel transcripts.
gtf_file = "extended_annotations.gtf"
gtf_cols = ["chrom", "source", "feature", "start", "end", "score", "strand", "frame", "attribute"]
gtf = pd.read_csv(
    gtf_file,
    sep="\t",
    comment="#",
    names=gtf_cols,
    dtype={"chrom": str},
)

# Only exon rows are needed downstream.
exons = gtf.loc[gtf["feature"].eq("exon")].copy()

# Function to extract fields from GTF attributes
def get_attr(attr_string, key):
    """Return the value of one attribute from a GTF attribute column.

    Parameters
    ----------
    attr_string : str
        The 9th GTF column, e.g. ``'gene_id "ENSG0001"; transcript_id "T1";'``.
        Non-string values (e.g. NaN from a malformed row) yield ``None``.
    key : str
        Attribute name to look up (``"gene_id"``, ``"transcript_id"``, ...).

    Returns
    -------
    str or None
        The first value recorded for ``key``, or ``None`` if it is absent.
    """
    if not isinstance(attr_string, str):
        return None
    # Escape the key so it is matched literally, and reject matches that are
    # only the tail of a longer attribute name (e.g. "ref_gene_id" when
    # looking up "gene_id").
    pattern = rf'(?<!\w){re.escape(key)} ["=]?([^";]+)'
    match = re.search(pattern, attr_string)
    return match.group(1) if match else None

# Pull the identifiers needed for joining out of the attribute column.
for _attr_name in ("gene_id", "transcript_id", "exon_number"):
    exons[_attr_name] = exons["attribute"].apply(get_attr, args=(_attr_name,))

# "gene_id:transcript_id" is the form the Bambu isoforms use in var_names.
exons["combined_id"] = exons["gene_id"] + ":" + exons["transcript_id"]

# Exons of the Bambu isoforms found in the dataset, ordered within each transcript.
bambu_exons = (
    exons.loc[exons["combined_id"].isin(bambu_isos)]
    .copy()
    .astype({"exon_number": int})
    .sort_values(["combined_id", "exon_number"])
)

# Per-exon length (GTF coordinates are inclusive) and per-transcript total.
bambu_exons["exon_length"] = bambu_exons["end"] - bambu_exons["start"] + 1
transcript_lengths = (
    bambu_exons.groupby("combined_id")["exon_length"]
    .sum()
    .reset_index()
    .rename(columns={"exon_length": "transcript_length"})
)

# Attach the transcript total to every exon row and use a clearer column name.
bambu_exons = (
    bambu_exons
    .merge(transcript_lengths, on="combined_id", how="left")
    .rename(columns={"chrom": "chromosome"})
)

# Final per-exon table for the Bambu transcripts found in the dataset.
_bambu_exon_cols = [
    "chromosome", "start", "end", "strand", "gene_id",
    "transcript_id", "combined_id", "exon_number", "exon_length", "transcript_length",
]
print(bambu_exons[_bambu_exon_cols])

In [None]:
# OPTIONAL: write the per-exon Bambu table as a tab-separated file (no index column).
bambu_exons.to_csv(path_or_buf="bambu_exons.tsv", sep="\t", index=False)

In [None]:
### Histogram: Exon Counts per Transcript (Overall)

import numpy as np
import matplotlib.pyplot as plt

# Distinct exon numbers per transcript, over every transcript in the GTF.
exon_counts_all = (
    exons.groupby("transcript_id")["exon_number"]
    .nunique()
    .reset_index()
    .rename(columns={"exon_number": "n_exons"})
)

# Transcripts with 10 or more exons all land in the last ("10+") bar.
exon_counts_all["n_exons_binned"] = exon_counts_all["n_exons"].clip(upper=10)

# Unbinned median, kept for later reference.
median_exons = exon_counts_all["n_exons"].median()

# 4x4 histogram; half-integer bin edges centre the bars on 1..10.
fig, ax = plt.subplots(figsize=(4, 4))
counts, bins, bars = ax.hist(
    exon_counts_all["n_exons_binned"],
    bins=np.arange(0.5, 11.5, 1),
    color="steelblue",
    edgecolor="black",
    align="mid",
)

# Transcript count printed vertically above each bar.
for count, bar in zip(counts, bars):
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 0.5,
        str(int(count)),
        ha='center', va='bottom', fontsize=8, rotation=90,
    )

# One tick per bar; the last one is the "10+" bucket.
xtick_positions = list(range(1, 11))
xtick_labels = [str(i) for i in range(1, 10)] + ['10+']
ax.set_xticks(xtick_positions)
ax.set_xticklabels(xtick_labels)

ax.set_xlabel("Number of Exons")
ax.set_ylabel("Number of Transcripts")
ax.set_title("Distribution of Exon Counts per Transcript (Grouped ≥10)")

# Headroom above the tallest bar for the rotated labels.
ax.set_ylim(0, counts.max() * 1.15)

fig.tight_layout()
fig.savefig(
    "Intermediate_Files/Paper_Figs/all_isoform_exon_count_distribution_grouped_square.pdf",
    dpi=600, transparent=True, bbox_inches="tight",
)
plt.show()

In [None]:
### Transcript length histogram for ALL dataset transcripts (Figure S2)
# Histogram for lengths < 8000 bp, plus a single overflow bar for ≥8000.

import re
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Build the set of dataset transcripts in the GTF's "gene_id:transcript_id" form.
# Names containing "Bambu" are used as-is; for the others the ENSG/ENST
# accessions are extracted with regexes.
combined_ids_from_adata = set()
rx_ensg = re.compile(r"(ENSG\d+)")
rx_enst = re.compile(r"(ENST\d+)")
for name in adata_i_filtered.var_names:
    s = str(name)
    if "Bambu" in s:
        combined_ids_from_adata.add(s)
    else:
        mg, mt = rx_ensg.search(s), rx_enst.search(s)
        if mg and mt:
            combined_ids_from_adata.add(f"{mg.group(1)}:{mt.group(1)}")

# Make sure exons have combined_id and exon_length (exon_length is added here
# whenever the exon table does not already carry it).
exons = exons.copy()
if "combined_id" not in exons.columns:
    exons["combined_id"] = exons["gene_id"].astype(str) + ":" + exons["transcript_id"].astype(str)
if "exon_length" not in exons.columns:
    exons["start"] = pd.to_numeric(exons["start"], errors="coerce")
    exons["end"]   = pd.to_numeric(exons["end"],   errors="coerce")
    exons["exon_length"] = exons["end"] - exons["start"] + 1
exons = exons.dropna(subset=["combined_id", "exon_length"])

# Restrict to transcripts present in the dataset; then sum exon lengths per transcript
exons_in_data = exons[exons["combined_id"].isin(combined_ids_from_adata)].copy()
transcript_lengths = (
    exons_in_data
    .groupby("combined_id", as_index=False)["exon_length"]
    .sum()
    .rename(columns={"exon_length": "transcript_length"})
)
if transcript_lengths.empty:
    raise ValueError("No transcript lengths matched between GTF and adata_i_filtered.")

# Quick stats for sanity
min_len = int(transcript_lengths["transcript_length"].min())
max_len = int(transcript_lengths["transcript_length"].max())
median_length = float(transcript_lengths["transcript_length"].median())
print(f"Transcript length range: {min_len} – {max_len} bp")
print(f"Median transcript length: {int(median_length)} bp")

# Histogram of lengths < THRESH plus a single overflow bar for lengths >= THRESH
THRESH = 8000
mask_main   = transcript_lengths["transcript_length"] < THRESH   # strictly less than THRESH
mask_over   = transcript_lengths["transcript_length"] >= THRESH  # THRESH and above
vals_main   = transcript_lengths.loc[mask_main, "transcript_length"]
overflow_ct = int(mask_over.sum())

n_bins     = 40
bin_edges  = np.linspace(0, THRESH, n_bins + 1)
bin_width  = bin_edges[1] - bin_edges[0]
# The overflow bar occupies the slot [THRESH, THRESH + bin_width].
overflow_x = THRESH + bin_width * 0.5

# Plot histogram for transcript length
plt.figure(figsize=(5, 4))
ax = sns.histplot(
    x=vals_main,
    bins=bin_edges,
    color="#377eb8",
    alpha=1,
    edgecolor="black"
)

ax.set_xlabel("Transcript Length (nt)")
ax.set_ylabel("Number of Transcripts")
# Show the full overflow slot. A view limited to THRESH hides the overflow bar;
# the tick placed at overflow_x would only widen the view to the bar's centre,
# leaving it half-clipped.
ax.set_xlim(0, THRESH + bin_width)

# set_ylim turns y-autoscaling off, so the overflow bar drawn afterwards must
# already fit: size the axis for whichever is taller.
ymax = ax.get_ylim()[1]
if overflow_ct > 0:
    ymax = max(ymax, overflow_ct * 1.05)
ax.set_ylim(0, ymax * 1.15)

# Median line (dashed) and label
ax.plot([median_length, median_length], [0, ymax * 1.03],
        color="#D62728", linewidth=1.5, linestyle="dashed")
ax.text(x=425, y=ymax * 1.07, s=f"Median = {int(median_length)}",
        color="#D62728", fontsize=10, fontweight="bold")

# Overflow bar (>= THRESH) with its count printed above it
if overflow_ct > 0:
    bar = ax.bar(overflow_x, overflow_ct, width=bin_width,
                 color="#377eb8", edgecolor="black")
    h = bar.patches[0].get_height()
    ax.text(overflow_x, h + max(0.5, 0.02*h), str(int(overflow_ct)),
            ha="center", va="bottom", fontsize=10)

# Ticks every 2000 nt below THRESH, plus one labelled tick under the overflow bar
xticks      = list(range(0, THRESH, 2000)) + [overflow_x]
xticklabels = [f"{t:,}" for t in xticks[:-1]] + [f"≥{THRESH}"]
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels)

plt.tight_layout()
plt.savefig("Intermediate_Files/Paper_Figs/all_transcript_length_distribution_with_overflow.pdf",
            dpi=600, transparent=True, bbox_inches="tight")
plt.show()

In [None]:
## SUBSET: New transcripts from known genes (ENSG + BambuTx)

import matplotlib.pyplot as plt

# Novel (BambuTx*) transcripts that were assigned to annotated (ENSG*) genes.
is_known_gene = bambu_exons["gene_id"].str.startswith("ENSG")
is_novel_tx = bambu_exons["transcript_id"].str.startswith("BambuTx")
filtered = bambu_exons.loc[is_known_gene & is_novel_tx].copy()

# Distinct exons per novel isoform, tagged with the gene that owns it.
exon_counts_per_isoform = (
    filtered.groupby("combined_id")["exon_number"]
    .nunique()
    .reset_index()
    .rename(columns={"exon_number": "n_exons"})
)
exon_counts_per_isoform["gene_id"] = exon_counts_per_isoform["combined_id"].str.split(":").str[0]

# For each exon count: how many distinct genes, and how many distinct transcripts.
genes_per_exon_count = exon_counts_per_isoform.groupby("n_exons")["gene_id"].nunique()
tx_per_exon_count = exon_counts_per_isoform.groupby("n_exons")["combined_id"].nunique()

In [None]:
### Bar chart: Number of novel transcripts (from known genes) per # of exons

fig, ax = plt.subplots(figsize=(4, 4))
bars = ax.bar(
    tx_per_exon_count.index,
    tx_per_exon_count.values,
    width=0.8, color="steelblue", edgecolor="black",
)

# Transcript count above each bar.
for rect in bars:
    top = rect.get_height()
    ax.text(rect.get_x() + rect.get_width() / 2, top + 0.5, str(int(top)),
            ha='center', va='bottom', fontsize=10)

ax.set_xlabel("Number of Exons")
ax.set_ylabel("Number of New Transcripts from Known Genes")
ax.set_xticks(tx_per_exon_count.index)
ax.set_ylim(0, tx_per_exon_count.values.max() * 1.15)
fig.tight_layout()
fig.savefig("Intermediate_Files/Paper_Figs/bambu_new_isoforms_known_genes_exon_distribution.pdf")
plt.show()

In [None]:
### Bar chart: Number of novel transcripts (from known genes) per # of exons (Figure 2c)

import pandas as pd
import matplotlib.pyplot as plt

# Same data as the previous chart, but exon counts grouped into {1,2,3,4,5+}
# for a compact figure. Count transcripts per exon count first.
tx_per_exon = (
    exon_counts_per_isoform
    .groupby("n_exons")["combined_id"]
    .nunique()
    .reset_index(name="tx_count")
)

# Bin n_exons (>=5 => "5+") and re-total each bin
tx_per_exon["n_exons"] = tx_per_exon["n_exons"].astype(int)
tx_per_exon["n_exons"] = tx_per_exon["n_exons"].apply(lambda x: str(x) if x < 5 else "5+")
tx_per_exon = (
    tx_per_exon
    .groupby("n_exons", as_index=False)["tx_count"]
    .sum()
)

# Fix the category order so the bars run 1, 2, 3, 4, 5+
tx_per_exon["n_exons"] = pd.Categorical(tx_per_exon["n_exons"],
                                        categories=["1", "2", "3", "4", "5+"],
                                        ordered=True)
tx_per_exon = tx_per_exon.sort_values("n_exons")

# Sanity check: binning only regroups transcripts, so the total must not change.
total_before = int(exon_counts_per_isoform["combined_id"].nunique())
total_after  = int(tx_per_exon["tx_count"].sum())
assert total_before == total_after, (
    f"Sanity — transcripts: before binning={total_before}, after binning={total_after}"
)

# Plot
plt.figure(figsize=(4, 4))
bars = plt.bar(tx_per_exon["n_exons"], tx_per_exon["tx_count"],
               width=0.8, color="steelblue", edgecolor="black")

# Count label above each bar, offset by at least 0.5 or 3% of the bar height
for bar in bars:
    h = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2, h + max(0.5, h*0.03), f"{int(h)}",
             ha='center', va='bottom', fontsize=10)

plt.xlabel("Number of Exons")
plt.ylabel("Number of New Transcripts from Known Genes")
plt.ylim(0, max(tx_per_exon["tx_count"]) * 1.15)
plt.tight_layout()
plt.savefig("Intermediate_Files/Paper_Figs/bambu_new_isoforms_known_genes_exon_distribution_grouped.pdf")
plt.show()

In [None]:
### SUBSET: New transcripts from new genes (BambuGene + BambuTx)

import pandas as pd
import matplotlib.pyplot as plt

# Novel genes (Bambu*) and novel transcripts (BambuTx*).
is_novel_gene = bambu_exons["gene_id"].str.startswith("Bambu")
is_novel_tx = bambu_exons["transcript_id"].str.startswith("BambuTx")

# Exon numbers coerced to int; rows whose exon number cannot be parsed are dropped.
filtered = (
    bambu_exons.loc[is_novel_gene & is_novel_tx]
    .copy()
    .assign(exon_number=lambda frame: pd.to_numeric(frame["exon_number"], errors="coerce"))
    .dropna(subset=["exon_number"])
    .astype({"exon_number": int})
)

# Distinct exons per transcript.
exon_counts_per_isoform = (
    filtered.groupby("combined_id")["exon_number"].nunique().reset_index(name="n_exons")
)

# Number of transcripts at each exon count, ascending by exon count.
tx_per_exon_count = (
    exon_counts_per_isoform.groupby("n_exons")["combined_id"].nunique()
    .reset_index(name="tx_count")
    .sort_values("n_exons")
)

In [None]:
### Bar chart: Number of novel transcripts (from new genes) per # of exons (Figure 2f)
fig, ax = plt.subplots(figsize=(4, 4))
bars = ax.bar(
    tx_per_exon_count["n_exons"],
    tx_per_exon_count["tx_count"],
    width=0.8, color="steelblue", edgecolor="black",
)

# Count label above each bar, offset by at least 0.5 or 3% of the bar height.
for rect in bars:
    top = rect.get_height()
    ax.text(
        rect.get_x() + rect.get_width() / 2, top + max(0.5, 0.03 * top), f"{int(top)}",
        ha="center", va="bottom", fontsize=10,
    )

ax.set_xlabel("Number of Exons")
ax.set_ylabel("Number of Transcripts from New Genes")
ax.set_xticks(tx_per_exon_count["n_exons"])
ax.set_ylim(0, tx_per_exon_count["tx_count"].max() * 1.15)
fig.tight_layout()
fig.savefig("Intermediate_Files/Paper_Figs/bambu_new_isoforms_new_genes_exon_distribution.pdf")
plt.show()

In [None]:
### Histogram: transcript lengths for new isoforms from known genes (Figure 2b)

import seaborn as sns
import matplotlib.pyplot as plt

# New transcripts (BambuTx*) assigned to known genes (ENSG*)
filtered = bambu_exons[
    bambu_exons["gene_id"].str.startswith("ENSG") &
    bambu_exons["transcript_id"].str.startswith("BambuTx")
].copy()

# Transcript length per isoform = sum of its exon lengths
transcript_lengths = (
    filtered.groupby("combined_id")["exon_length"].sum()
    .reset_index()
    .rename(columns={"exon_length": "transcript_length"})
)
if transcript_lengths.empty:
    raise ValueError("No new transcripts from known genes found in bambu_exons.")

# Print range for transcript length
min_len = transcript_lengths["transcript_length"].min()
max_len = transcript_lengths["transcript_length"].max()
print(f"Transcript length range: {min_len} - {max_len} bp")

median_length = transcript_lengths["transcript_length"].median()

# Hand-tuned paper layout: y-axis top, median-line top, median-label height.
# It is kept exactly while the tallest bin fits under the median line;
# otherwise all three are scaled up together so no bar is clipped.
y_top, line_top, label_y = 7, 6.35, 6.53

plt.figure(figsize=(4, 4))
ax = sns.histplot(x=transcript_lengths["transcript_length"], color="#377eb8", alpha=1, bins=30)
tallest = max((patch.get_height() for patch in ax.patches), default=0)
scale = max(1.0, tallest * 1.05 / line_top)

plt.xlabel("Transcript Length (nt)")
plt.ylabel("# of New Transcripts from Known Genes")
plt.ylim(0, y_top * scale)

# Median line and label
plt.plot([median_length, median_length], [0, line_top * scale],
         color="#D62728", linewidth=1.5, alpha=1, linestyle='dashed')
plt.text(x=200, y=label_y * scale, s=f"Median = {int(median_length)}",
         color="#D62728", fontsize=10, fontweight="bold")

plt.tight_layout()

# Save transparent PDF
plt.savefig("Intermediate_Files/Paper_Figs/bambu_new_isoforms_known_genes_transcript_length_histogram.pdf", dpi=600, transparent=True, bbox_inches="tight")
plt.show()

In [None]:
### Histogram: transcript lengths for new isoforms from new genes (Figure 2e)
import seaborn as sns
import matplotlib.pyplot as plt

# New transcripts (BambuTx*) from new genes (Bambu*)
filtered = bambu_exons[
    bambu_exons["gene_id"].str.startswith("Bambu") &
    bambu_exons["transcript_id"].str.startswith("BambuTx")
].copy()

# Transcript length per isoform = sum of its exon lengths
transcript_lengths = (
    filtered.groupby("combined_id")["exon_length"].sum()
    .reset_index()
    .rename(columns={"exon_length": "transcript_length"})
)
if transcript_lengths.empty:
    raise ValueError("No new transcripts from new genes found in bambu_exons.")

# Print range
min_len = transcript_lengths["transcript_length"].min()
max_len = transcript_lengths["transcript_length"].max()
print(f"Transcript length range: {min_len} - {max_len} bp")

median_length = transcript_lengths["transcript_length"].median()

# Hand-tuned paper layout: y-axis top, median-line top, median-label height.
# It is kept exactly while the tallest bin fits under the median line;
# otherwise all three are scaled up together so no bar is clipped.
y_top, line_top, label_y = 8.5, 7.7, 7.8

plt.figure(figsize=(4, 4))
ax = sns.histplot(x=transcript_lengths["transcript_length"], color="#377eb8", alpha=1, bins=30)
tallest = max((patch.get_height() for patch in ax.patches), default=0)
scale = max(1.0, tallest * 1.05 / line_top)

plt.xlabel("Transcript Length (nt)")
plt.ylabel("# of Transcripts from Novel Genes")
plt.ylim(0, y_top * scale)

# Median line and label
plt.plot([median_length, median_length], [0, line_top * scale],
         color="#D62728", linewidth=1.5, alpha=1, linestyle='dashed')
plt.text(x=425, y=label_y * scale, s=f"Median = {int(median_length)}",
         color="#D62728", fontsize=10, fontweight="bold")

plt.tight_layout()

# Save transparent PDF
plt.savefig("Intermediate_Files/Paper_Figs/bambu_new_isoforms_new_genes_transcript_length_histogram.pdf", 
            dpi=600, transparent=True, bbox_inches="tight")
plt.show()

In [None]:
### Bar Summary: BambuTx by locus/context (Figure 2d)

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# All novel (BambuTx*) transcripts, one row per transcript
filtered = bambu_exons[
    bambu_exons["transcript_id"].str.startswith("BambuTx")
].copy()
filtered_unique = filtered[["combined_id", "gene_id", "strand"]].drop_duplicates()

# Reference strand of each known gene, taken from its annotated (non-Bambu)
# exons in the extended GTF. "Opposite strand" has to be judged against this:
# a transcript lying on "-" says nothing by itself about its gene's strand.
ref_exons = exons[
    exons["gene_id"].isin(filtered_unique["gene_id"])
    & ~exons["transcript_id"].astype(str).str.startswith("Bambu")
]
gene_strand = (
    ref_exons.groupby("gene_id")["strand"]
    .agg(lambda s: s.value_counts().idxmax())
    .to_dict()
)

def categorize(row, gene_strand=gene_strand):
    """Assign a novel transcript to a locus/context category.

    Parameters
    ----------
    row : pandas.Series
        Needs "gene_id" and "strand".
    gene_strand : dict
        gene_id -> strand of the annotated gene (defaults to the map built above).

    Returns
    -------
    str
        "New locus" for Bambu-defined genes; "Opposite strand" when the
        transcript's strand differs from its gene's annotated strand;
        otherwise "Same strand, no exon overlap".
        NOTE(review): exon overlap is not tested here — that label assumes it.
    """
    if row["gene_id"].startswith("Bambu"):
        return "New locus"
    ref = gene_strand.get(row["gene_id"])
    if ref is not None and row["strand"] != ref:
        return "Opposite strand"
    return "Same strand, no exon overlap"

# Known-gene transcripts without an annotated reference strand fall back to
# "same strand"; report how many so the fallback is visible.
is_known = ~filtered_unique["gene_id"].str.startswith("Bambu")
n_without_ref = int((is_known & ~filtered_unique["gene_id"].isin(list(gene_strand))).sum())
if n_without_ref:
    print(f"{n_without_ref} novel transcripts of known genes have no annotated reference strand; "
          "counted as same strand.")

filtered_unique["Category"] = filtered_unique.apply(categorize, axis=1)

# Count unique transcripts (combined_id) per category, not genes
counts = filtered_unique.groupby("Category")["combined_id"].nunique().reset_index()
counts.rename(columns={"combined_id": "Count"}, inplace=True)

# Assign colors and labels
custom_palette = ["#377eb8", "#4daf4a", "#984ea3"]
new_labels = ["New locus", "Opposite strand", "Same strand, no exon overlap"]
color_map = dict(zip(new_labels, custom_palette))

# Generate Bar Plot
plt.figure(figsize=(5, 6))
ax = sns.barplot(
    data=counts, y="Category", x="Count",
    dodge=False, palette=color_map, saturation=1, order=new_labels
)

# Add bar labels
for container in ax.containers:
    ax.bar_label(container, fontsize=10)

plt.xlim(0, counts["Count"].max() + 5)
sns.despine(ax=ax, top=False, right=True, left=False, bottom=False)
ax.set_yticks([])
ax.set_ylabel("")
plt.xlabel("Number of Transcripts")

# Custom legend (categories are identified by color; y tick labels are hidden)
patches = [mpatches.Patch(color=color, label=label) for color, label in zip(custom_palette, new_labels)]
plt.legend(handles=patches, fontsize=10, loc='lower center', bbox_to_anchor=(0.5, 1.05), frameon=True)

plt.tight_layout()
plt.savefig("Intermediate_Files/Paper_Figs/bambu_combined_category_distribution.pdf", 
            dpi=600, transparent=True, bbox_inches="tight")
plt.show()

In [None]:
### Histogram: Number of transcripts expressed per gene across dataset (multi-isoform genes only) (Figure 2a)

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os

output_dir = "Intermediate_Files/Paper_Figs"
os.makedirs(output_dir, exist_ok=True)

# Parse feature names. For 3-part names the 3-way split yields
# gene_name / gene_id / transcript_id, but 2-part names (the Bambu
# "gene_id:transcript_id" form) put the gene in the first column and the
# transcript in the second. Take the gene from the first column whenever the
# third field is missing, so novel isoforms of known genes are counted with
# their gene instead of being dropped.
combined_ids = pd.Series(adata_i_filtered.var_names)
split_ids = combined_ids.str.split(":", expand=True)
split_ids.columns = ["gene_name", "gene_id", "transcript_id"]
split_ids["gene_id_final"] = split_ids["gene_id"].where(
    split_ids["transcript_id"].notna(), split_ids["gene_name"]
)

# Filter to only known genes
known_genes = split_ids[split_ids["gene_id_final"].str.startswith("ENSG")].copy()

# Count isoforms per gene, then count genes per isoform count
isoform_counts_per_gene = known_genes.groupby("gene_id_final").size().reset_index(name="n_isoforms")

# Keep only genes with >=2 isoforms
isoform_counts_per_gene = isoform_counts_per_gene[isoform_counts_per_gene["n_isoforms"] >= 2]

# Count number of genes for each number of isoforms
summary = isoform_counts_per_gene["n_isoforms"].value_counts().reset_index()
summary.columns = ["n_isoforms", "n_genes"]
summary = summary.sort_values("n_isoforms")

# Group >=10 isoforms together to compress tail
summary["n_isoforms"] = summary["n_isoforms"].apply(lambda x: "10+" if x >= 10 else str(x))
summary = summary.groupby("n_isoforms")["n_genes"].sum().reset_index()

ordered_categories = [str(i) for i in range(2, 10)] + ["10+"]
summary["n_isoforms"] = pd.Categorical(summary["n_isoforms"], categories=ordered_categories, ordered=True)
summary = summary.sort_values("n_isoforms")

# Plot histogram
plt.figure(figsize=(4, 4))

ax = sns.barplot(data=summary, y="n_genes", x="n_isoforms",
                 color="#377eb8", saturation=1, edgecolor="black")

# Bar labels
ax.bar_label(ax.containers[0], fontsize=9, padding=1)
# 2600 suits the published figure; grow it if the tallest bar would be clipped.
plt.ylim(0, max(2600, summary["n_genes"].max() * 1.08))
plt.xlabel("# of transcripts expressed")
plt.ylabel("# of gene bodies")

plt.tight_layout()

# Save
plt.savefig(os.path.join(output_dir, "number_of_genes_with_multiple_transcripts_known_only.pdf"),
            dpi=600, transparent=True, bbox_inches="tight")

plt.show()

In [None]:
### Histogram: Number of transcripts expressed per gene across dataset

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os

# Output directory
output_dir = "Intermediate_Files/Paper_Figs"
os.makedirs(output_dir, exist_ok=True)

# Parse feature names. For 3-part names the 3-way split yields
# gene_name / gene_id / transcript_id, but 2-part names (the Bambu
# "gene_id:transcript_id" form) put the gene in the first column and the
# transcript in the second. Take the gene from the first column whenever the
# third field is missing, so novel isoforms of known genes are counted with
# their gene instead of being dropped.
combined_ids = pd.Series(adata_i_filtered.var_names)

split_ids = combined_ids.str.split(":", expand=True)
split_ids.columns = ["gene_name", "gene_id", "transcript_id"]
split_ids["gene_id_final"] = split_ids["gene_id"].where(
    split_ids["transcript_id"].notna(), split_ids["gene_name"]
)

# Filter to only known genes
known_genes = split_ids[split_ids["gene_id_final"].str.startswith("ENSG")].copy()

# Count isoforms per gene
isoform_counts_per_gene = known_genes.groupby("gene_id_final").size().reset_index(name="n_isoforms")

# Count number of genes for each number of isoforms
summary = isoform_counts_per_gene["n_isoforms"].value_counts().reset_index()
summary.columns = ["n_isoforms", "n_genes"]
summary = summary.sort_values("n_isoforms")

# Group 10+ isoforms together
summary["n_isoforms"] = summary["n_isoforms"].apply(lambda x: "10+" if x >= 10 else str(x))
summary = summary.groupby("n_isoforms")["n_genes"].sum().reset_index()

# Set the order for plotting
ordered_categories = [str(i) for i in range(1, 10)] + ["10+"]
summary["n_isoforms"] = pd.Categorical(summary["n_isoforms"], categories=ordered_categories, ordered=True)
summary = summary.sort_values("n_isoforms")

# Plot
plt.figure(figsize=(4, 4))

ax = sns.barplot(data=summary, y="n_genes", x="n_isoforms",
                 color="#377eb8", saturation=1, edgecolor="black")

# Bar labels
ax.bar_label(ax.containers[0], fontsize=9, padding=1)
# 4100 suits the published figure; grow it if the tallest bar would be clipped.
plt.ylim(0, max(4100, summary["n_genes"].max() * 1.08))
plt.xlabel("# of transcripts expressed")
plt.ylabel("# of gene bodies")

plt.tight_layout()

# Save
plt.savefig(os.path.join(output_dir, "number_of_genes_by_transcript_number_known_only.pdf"),
            dpi=600, transparent=True, bbox_inches="tight")

plt.show()

In [None]:
import gzip

def list_biotypes_in_gtf(gtf_path, level="gene"):
    """
    Print every distinct biotype value present in a GTF file.

    Parameters
    ----------
    gtf_path : str
        Path to a plain (.gtf) or gzip-compressed (.gtf.gz) GTF file.
    level : str
        Attribute prefix to search for: 'gene' or 'transcript'. With 'gene',
        only rows whose feature column is "gene" are inspected; any other
        value inspects every data row.

    Prints a header with the number of biotypes found, then one indented line
    per biotype in sorted order. Returns None.
    """
    candidate_keys = (f"{level}_biotype", f"{level}_type", "biotype")
    open_fn = gzip.open if gtf_path.endswith(".gz") else open
    found = set()

    with open_fn(gtf_path, "rt") as handle:
        for raw in handle:
            if raw.startswith("#"):
                continue
            cols = raw.strip().split("\t")
            if len(cols) < 9:
                continue
            if level == "gene" and cols[2] != "gene":
                continue
            entries = [chunk.strip() for chunk in cols[8].split(";")]
            for key in candidate_keys:
                for entry in entries:
                    if not entry.startswith(key):
                        continue
                    # Value is everything after the first space, quotes removed.
                    name_and_value = entry.split(" ", 1)
                    if len(name_and_value) > 1:
                        found.add(name_and_value[1].replace('"', '').strip())

    print(f"Found {len(found)} distinct biotypes at {level} level:")
    for biotype in sorted(found):
        print("  ", biotype)


In [None]:
# Survey the transcript-level biotypes in the Ensembl reference GTF.
list_biotypes_in_gtf(gtf_path="Homo_sapiens.GRCh38.113.gtf", level="transcript")

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os

## Parse combined IDs from AnnData var_names into gene / transcript columns
combined_ids = pd.Series(adata_i_filtered.var_names)

split_ids = combined_ids.str.split(":", expand=True)
split_ids.columns = ["gene_name", "gene_id", "transcript_id"]

transcript_df = pd.concat(
    [combined_ids.rename("combined_id"), split_ids["gene_id"], split_ids["transcript_id"]],
    axis=1,
)


def _read_transcript_biotypes(path):
    """Return {transcript_id: transcript_biotype} from the 'transcript' rows of a GTF."""
    biotype_by_tx = {}
    with open(path, 'r') as handle:
        for raw in handle:
            if raw.startswith("#"):
                continue
            cols = raw.strip().split('\t')
            # One biotype per transcript: the transcript feature rows are enough.
            if cols[2] != "transcript":
                continue
            attrs = {}
            for chunk in cols[8].strip().split(';'):
                chunk = chunk.strip()
                if chunk:
                    attr_name, attr_value = chunk.split(' ', 1)
                    attrs[attr_name] = attr_value.strip('"')
            tx_id = attrs.get("transcript_id")
            tx_biotype = attrs.get("transcript_biotype")
            if tx_id and tx_biotype:
                biotype_by_tx[tx_id] = tx_biotype
    return biotype_by_tx


## The Ensembl GTF is used because biotypes are not listed in the extended GTF
gtf_file = "Homo_sapiens.GRCh38.113.gtf"
transcript_biotype_dict = _read_transcript_biotypes(gtf_file)

# Biotype (or NaN) for every dataset transcript
transcript_biotype_df = pd.DataFrame({
    'transcript_id': transcript_df["transcript_id"],
    'biotype': transcript_df["transcript_id"].map(transcript_biotype_dict),
})

## Keep only genes with >=2 transcripts, then attach each transcript's biotype
gene_transcript_counts = transcript_df.groupby("gene_id").size().reset_index(name="n_transcripts")
genes_multi = gene_transcript_counts.query("n_transcripts >= 2")
transcript_df = (
    transcript_df
    .merge(genes_multi, on="gene_id", how="inner")
    .merge(transcript_biotype_df, on="transcript_id", how="left")
)

# One row per distinct (gene_id, biotype, n_transcripts) combination
gene_biotype_summary = transcript_df[["gene_id", "biotype", "n_transcripts"]].drop_duplicates()

# Gene bodies per (isoform count, biotype)
summary = (
    gene_biotype_summary
    .groupby(["n_transcripts", "biotype"])
    .size()
    .reset_index(name="n_gene_bodies")
)

In [None]:
## Group biotypes into largest categories represented

def simplify_biotype(biotype):
    """Collapse an Ensembl transcript biotype into one of five plot categories.

    Four named biotypes get readable labels; every other value, including
    missing values, maps to "Other".
    """
    labelled = (
        ("protein_coding", "Protein coding"),
        ("nonsense_mediated_decay", "Nonsense-mediated decay"),
        ("retained_intron", "Retained intron"),
        ("protein_coding_CDS_not_defined", "Coding sequence not defined"),
    )
    for raw_name, label in labelled:
        if biotype == raw_name:
            return label
    return "Other"

# Attach the simplified category to every (gene, biotype) row.
gene_biotype_summary["biotype_category"] = gene_biotype_summary["biotype"].map(simplify_biotype)


def _gene_bodies_by(frame, column):
    """Number of unique gene_ids per value of `column`, largest first."""
    return frame.groupby(column)["gene_id"].nunique().sort_values(ascending=False)


# Unique gene bodies per simplified category
counts_by_category = _gene_bodies_by(gene_biotype_summary, "biotype_category")
print("## Gene bodies per mapped biotype category")
for cat, n in counts_by_category.items():
    print(f"{cat}\t{n}")

# Which raw biotypes ended up in "Other"
other_breakdown = _gene_bodies_by(
    gene_biotype_summary[gene_biotype_summary["biotype_category"] == "Other"], "biotype"
)
print("\n## Raw biotypes grouped into 'Other' (unique gene bodies)")
if other_breakdown.empty:
    print("None")
else:
    for bt, n in other_breakdown.items():
        print(f"{bt}\t{n}")

# Every raw biotype present
all_raw = _gene_bodies_by(gene_biotype_summary, "biotype")
print("\n## All raw biotypes in dataset (unique gene bodies)")
for bt, n in all_raw.items():
    print(f"{bt}\t{n}")

In [None]:
# Gene bodies per (isoform count, raw biotype)
summary = (
    gene_biotype_summary
    .groupby(["n_transcripts", "biotype"])
    .size()
    .reset_index(name="n_gene_bodies")
    .copy()
)

# Simplified biotype labels, and isoform counts >= 10 collapsed into "10+"
summary["biotype_category"] = summary["biotype"].map(simplify_biotype)
counts_as_text = summary["n_transcripts"].astype(str)
summary["n_transcripts"] = counts_as_text.where(summary["n_transcripts"] < 10, "10+")

# Re-total per (isoform bucket, category) and rename the bucket column
summary = (
    summary
    .groupby(["n_transcripts", "biotype_category"])["n_gene_bodies"]
    .sum()
    .reset_index()
    .rename(columns={"n_transcripts": "n_isoforms"})
)

# Ordered categorical so the x-axis runs 2..9, 10+
ordered_categories = [str(i) for i in range(2, 10)] + ["10+"]
summary["n_isoforms"] = pd.Categorical(summary["n_isoforms"], categories=ordered_categories, ordered=True)
summary = summary.sort_values("n_isoforms")

In [None]:
### Bar plot: Number of genes by isoform count and transcript biotype category

biotype_colors = {
    "Protein coding": "#1f77b4",
    "Nonsense-mediated decay": "#ff7f0e",
    "Retained intron": "#2ca02c",
    "Coding sequence not defined": "#9467bd",
    "Other": "#7f7f7f"
}

# Legend / hue order
biotype_order = [
    "Protein coding",
    "Retained intron",
    "Coding sequence not defined",
    "Nonsense-mediated decay",
    "Other"
]

plt.figure(figsize=(12, 4))

ax = sns.barplot(data=summary, x="n_isoforms", y="n_gene_bodies", hue="biotype_category",
                 palette=biotype_colors, hue_order=biotype_order, saturation=1, width = 0.9)

# Print counts on bars for quick reading
for container in ax.containers:
    ax.bar_label(container, fontsize=8, padding=1)

plt.xlabel("# of Isoforms Expressed")
plt.ylabel("# of Gene Bodies")
# 2400 suits the published figure; grow it if the tallest bar would be clipped.
plt.ylim(0, max(2400, summary["n_gene_bodies"].max() * 1.08))
plt.legend(title="Transcript Biotype", fontsize=10, ncol = 2)
plt.tight_layout()

output_dir = "Intermediate_Files/Paper_Figs"
os.makedirs(output_dir, exist_ok=True)

plt.savefig(os.path.join(output_dir, "gene_bodies_by_biotype_grouped_FINAL.pdf"),
            dpi=600, transparent=True, bbox_inches="tight")

plt.show()