In [None]:
# Setup and Imports
import sys
import os
import scanpy as sc
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from PIL import Image

# Add project root to path.
# This assumes the notebook is run from 'notebooks/' so that 'src' lives in '../src'.
# Guarded so that re-running this cell does not append the same directory
# to sys.path again on every execution.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from src import config
from src import plotting as pl
from src import transcendental as tran

# Configure plotting: silence scanpy's logging and embed fonts as Type 42
# (TrueType) in PDF/PS output so text stays editable in vector editors.
sc.settings.verbosity = 0
plt.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.size': 8,
})

## 1. Load Data
First, we load the final processed AnnData objects generated by the `02_run_mutrans.py` script. We need both the original single-cell object (with the MuTrans metrics mapped back to individual cells) and the SEACell (metacell) summary object.

In [None]:
# Load Data
# Both objects are produced by 02_run_mutrans.py; their paths come from src.config.
print("Loading data...")
adata_org, adata_seacell = (
    sc.read(path) for path in (config.MUTRANS_AD_ORG, config.MUTRANS_AD_SEACELL)
)

for label, loaded in (("adata_org", adata_org), ("adata_seacell", adata_seacell)):
    print(f"Loaded {label}: {loaded.shape}")

## 2. Generate Transition Plots
Next, we pre-generate the MuTrans transition path plots (MPFT and MPPT). They are saved as PNG files in `04_comprehensive_figures/transition_plots/` under the figures directory and are later embedded as panels F–H of the main figure.

In [None]:
# Generate Transition Images
# The returned mapping is looked up later by panel key to embed these images.
print("Generating transition plots...")
comprehensive_dir = config.FIGURES_DIR / '04_comprehensive_figures'
transition_plot_dir = comprehensive_dir / 'transition_plots'
transition_files = pl.generate_mutrans_transition_plots(
    adata_seacell,
    transition_plot_dir,
)
print(f"Plots saved in: {transition_plot_dir}")

## 3. Helper Function for Figure Assembly
These are the helper functions from our `03_generate_all_figures.py` script (`plot_image_panel` and `create_comprehensive_figure`). We redefine them here so the notebook can assemble the figure panels itself.

In [None]:
# Define Figure Assembly Functions
def plot_image_panel(ax, file_path, label_char):
    """Draw a pre-rendered image file onto ``ax`` as a lettered figure panel.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axis; its axis lines/ticks are switched off.
    file_path : pathlib.Path, str or None
        Image to display. A falsy value or a path that does not exist produces
        a centred "Image not found" placeholder instead of raising.
    label_char : str
        Panel label, drawn as a bold, left-aligned title (e.g. ``'F'``).
    """
    from pathlib import Path

    # Path() lets callers pass plain strings as well as Path objects.
    if file_path and Path(file_path).exists():
        # Use a context manager so the file handle is released; copy() loads the
        # pixels into an in-memory image that remains valid after the file closes.
        with Image.open(file_path) as img:
            ax.imshow(img.copy())
    else:
        ax.text(0.5, 0.5, f'Image not found\n{label_char}', ha='center', va='center', transform=ax.transAxes)
    ax.axis('off')
    ax.set_title(label_char, fontsize=14, fontweight='bold', loc='left', pad=5)

def create_comprehensive_figure(adata_org, adata_seacell, transition_files, save_path, cell_level_heatmaps=False):
    """
    Assemble and save the multi-panel MuTrans summary figure.

    Mirrors the assembly function in scripts/03_generate_all_figures.py.

    Layout (3 rows on a 12-column grid):
      Row 1 (A-E): UMAPs and violins of 'entropy' and 'land', plus the flow
                   matrix -- always drawn from ``adata_org``.
      Row 2 (F-H): pre-rendered transition images looked up in
                   ``transition_files`` under 'mpft', '3to4' and '10to4'.
      Row 3 (I-J): transcendental heatmaps for A3 -> A4 and A10 -> A4.

    Parameters
    ----------
    adata_org : AnnData
        Single-cell object with the MuTrans metrics mapped back to cells.
    adata_seacell : AnnData
        SEACell (metacell) summary object.
    transition_files : Mapping
        Panel key -> image path. Missing keys render a placeholder panel.
    save_path : pathlib.Path or str
        Output path (typically ``.pdf``). A ``.png`` copy is also written next
        to it unless ``save_path`` is itself a ``.png``. Missing parent
        directories are created.
    cell_level_heatmaps : bool, default False
        If True, row 3 is drawn from single cells (SEACell memberships are
        mapped onto ``adata_org`` first when ``'rho_class'`` is absent from
        ``obsm``); otherwise row 3 is drawn from the metacells.

    Returns
    -------
    pathlib.Path
        The path the primary figure was written to.
    """
    from pathlib import Path

    save_path = Path(save_path)
    # savefig does not create missing parent directories.
    save_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Creating comprehensive figure: {save_path.name}")
    fig = plt.figure(figsize=(30, 12))
    try:
        gs = gridspec.GridSpec(3, 12, figure=fig, hspace=0.35, wspace=0.4,
                               left=0.04, right=0.98, top=0.96, bottom=0.04,
                               height_ratios=[1, 1.2, 1])

        # --- ROW 1: Cell-level UMAPs and Violins (always from adata_org) ---
        pl.plot_umap(adata_org, fig.add_subplot(gs[0, 0:2]), 'entropy', 'A')
        pl.plot_umap(adata_org, fig.add_subplot(gs[0, 2:4]), 'land', 'B')
        pl.plot_violin(adata_org, fig.add_subplot(gs[0, 4:6]), 'entropy', 'C')
        pl.plot_violin(adata_org, fig.add_subplot(gs[0, 6:8]), 'land', 'D')
        pl.plot_flow_matrix(adata_org, fig.add_subplot(gs[0, 8:12]), 'E')

        # --- ROW 2: Transition Plots (pre-rendered from adata_seacell) ---
        plot_image_panel(fig.add_subplot(gs[1, 0:4]), transition_files.get('mpft'), 'F')
        plot_image_panel(fig.add_subplot(gs[1, 4:8]), transition_files.get('3to4'), 'G')
        plot_image_panel(fig.add_subplot(gs[1, 8:12]), transition_files.get('10to4'), 'H')

        # --- ROW 3: Transcendental Heatmaps ---
        ax_i = fig.add_subplot(gs[2, 0:6])
        ax_j = fig.add_subplot(gs[2, 6:12])

        if cell_level_heatmaps:
            # Per-cell heatmaps need per-cell memberships; derive them from the
            # SEACells only when they are not already present.
            if 'rho_class' not in adata_org.obsm:
                adata_org = tran.map_seacell_memberships_to_cells(adata_org, adata_seacell)
            heatmap_adata = adata_org
            heatmap_unit = 'Cells'
            heatmap_kwargs = {'max_cells': 5000, 'region_method': 'adaptive'}
        else:
            heatmap_adata = adata_seacell
            heatmap_unit = 'Metacells'
            heatmap_kwargs = {'region_method': 'logistic'}

        for ax, label, source, target in ((ax_i, 'I', '3', '4'), (ax_j, 'J', '10', '4')):
            pl.plot_transcendental_heatmap(
                heatmap_adata, source, target, ax,
                title=f'{label}: Transition A{source} → A{target} ({heatmap_unit})',
                **heatmap_kwargs,
            )

        # --- Save Figure ---
        # Save through the figure handle rather than plt.savefig(): the panel
        # helpers may change pyplot's "current figure", and `fig` is what we built.
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        png_path = save_path.with_suffix('.png')
        if png_path != save_path:  # don't render the same file twice
            fig.savefig(png_path, dpi=300, bbox_inches='tight')
    finally:
        # Release this specific figure even if a panel fails, so re-running the
        # cell does not accumulate open figures.
        plt.close(fig)

    print(f"  ...Saved {save_path.name}")
    return save_path

## 4. Generate Figure (Metacell Version)
In this version, the transcendental heatmaps (panels I–J) are drawn from the SEACell (metacell) data, which is faster to compute and highlights broader trends.

In [None]:
# Generate Metacell Figure
# Row 3 (panels I-J) is drawn from the SEACell object here.
save_path_seacells = (
    config.FIGURES_DIR / '04_comprehensive_figures' / 'Fig_Comprehensive_Metacells.pdf'
)
created_path = create_comprehensive_figure(
    adata_org=adata_org,
    adata_seacell=adata_seacell,
    transition_files=transition_files,
    save_path=save_path_seacells,
    cell_level_heatmaps=False,
)
print(f"Metacell figure saved to: {created_path}")

## 5. Generate Figure (Single-Cell Version)
In this version, the transcendental heatmaps (panels I–J) are drawn from the single-cell data. If needed, SEACell memberships are first mapped onto the cells. This gives a higher-resolution view of the transitions.

In [None]:
# Generate Single-Cell Figure
# Row 3 (panels I-J) is drawn from individual cells here.
save_path_cells = (
    config.FIGURES_DIR / '04_comprehensive_figures' / 'Fig_Comprehensive_Cells.pdf'
)
created_path_cells = create_comprehensive_figure(
    adata_org=adata_org,
    adata_seacell=adata_seacell,
    transition_files=transition_files,
    save_path=save_path_cells,
    cell_level_heatmaps=True,
)
print(f"Single-cell figure saved to: {created_path_cells}")