#### Generating plots where the only cluster of interest is Central Memory CD8 T cells

In [3]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ot
from matplotlib.backends.backend_pdf import PdfPages

# ---------------------------
# Configuration Parameters
# ---------------------------

# Inputs are looked up under <base_input_dir>/<subpopulation>/<cohort>/Timepoint_<tp>;
# every PDF produced by this cell is written below base_output_dir.
base_input_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/final_nomenclature/coi"
base_output_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/final_nomenclature/coi_central_memory_cd8"

# Sub-folders of base_input_dir that are processed, in this order.
subpopulations = [
    "Effector_CD8", "Memory_Precursor_Effector_CD8", "Exhausted_T",
    "Central_Memory_CD8", "Stem_Like_CD8", "Effector_Memory_CD8",
    "Proliferating_Effector", "All_CD8",
]

# Cohort sub-folders, and the order of the bars in the stacked-bar panels.
cohorts = ["control", "short_term", "long_term"]

# Canonical timepoint order; a cohort only uses the ones it has folders for.
all_timepoints = ["Pre", "C1", "C2", "C4", "C6", "C9", "C18", "C36"]

# Scatter colour of the source cells of each cohort.
cohort_colors = dict(control='yellow', short_term='blue', long_term='red')

# --------------------------------------------------
# Source clusters ("clusters of interest", COI) whose
# arrows and target distributions are reported
# --------------------------------------------------
clusters_of_interest = [12]

# CSV with one row per cluster; the 'Cluster', 'Color' and 'Celltype' columns
# are turned into two lookup tables keyed by cluster.
color_mapping_file = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/Publication_Material/T_Cell_cluster_colors.csv"
color_mapping_df = pd.read_csv(color_mapping_file)
cluster_colors_map = {
    cluster: color
    for cluster, color in zip(color_mapping_df['Cluster'], color_mapping_df['Color'])
}
cluster_celltype_map = {
    cluster: celltype
    for cluster, celltype in zip(color_mapping_df['Cluster'], color_mapping_df['Celltype'])
}

# Seurat sub-clusters that are lumped into a parent cluster (16, 17 -> 14 and
# 9, 18 -> 6) before any cluster-level statistic is computed.
_CLUSTER_MERGES = {16: 14, 17: 14, 9: 6, 18: 6}


def _merge_cluster_labels(cells):
    """Apply ``_CLUSTER_MERGES`` to ``cells['seurat_clusters']`` in place, if the column exists."""
    if 'seurat_clusters' in cells.columns:
        cells['seurat_clusters'] = cells['seurat_clusters'].replace(_CLUSTER_MERGES)
    return cells


def _load_timepoints(input_folder, timepoints):
    """Read the source/gray CSVs of every usable timepoint, keyed by timepoint in input order."""
    full_data = {}
    for tp in timepoints:
        tp_folder = os.path.join(input_folder, f"Timepoint_{tp}")
        source_cells_file = os.path.join(tp_folder, 'source_cells_new.csv')
        gray_cells_file = os.path.join(tp_folder, 'gray_cells_new.csv')

        if not (os.path.exists(source_cells_file) and os.path.exists(gray_cells_file)):
            print(f"Missing files for {tp}. Skipping this timepoint.")
            continue

        source_cells = pd.read_csv(source_cells_file)
        gray_cells = pd.read_csv(gray_cells_file)

        if source_cells.empty:
            print(f"No source cells for {tp}. Skipping.")
            continue

        # Merge labels once, up front, so every later consumer (sampling,
        # arrow filtering, centroids, distributions) sees the same labels.
        full_data[tp] = {
            'source': _merge_cluster_labels(source_cells),
            'gray': _merge_cluster_labels(gray_cells),
        }
    return full_data


def _empty_transition():
    """Return the per-timepoint result fields for a timepoint without a usable OT mapping."""
    return {
        'single_cell_displacements': np.array([]),
        'single_cell_arrow_lengths': np.array([]),
        'cluster_arrows': [],
        'aggregated_arrow': None,
    }


def _cluster_arrows(source_cells, gray_cells, displacements):
    """Return ``(cluster, x, y, dx, dy)`` for each cluster of interest present in ``source_cells``."""
    # Arrow origin: median UMAP position of all cells (source + gray) of the cluster.
    centroids = (
        pd.concat([source_cells, gray_cells])
        .groupby('seurat_clusters')[['UMAP_2', 'UMAP_1']]
        .median()
    )
    # Arrow vector: mean UMAP displacement of the cluster's source cells.
    mean_disp = pd.DataFrame({
        'cluster': source_cells['seurat_clusters'].values,
        'dx': displacements[:, 0],
        'dy': displacements[:, 1],
    }).groupby('cluster')[['dx', 'dy']].mean()

    arrows = []
    for clust in centroids.index.intersection(mean_disp.index):
        if clust not in clusters_of_interest:
            continue
        cx, cy = centroids.loc[clust, ['UMAP_2', 'UMAP_1']]
        cdx, cdy = mean_disp.loc[clust, ['dx', 'dy']]
        arrows.append((clust, cx, cy, cdx, cdy))
    return arrows


def _coi_target_distributions(source_cells, target_cells, target_indices):
    """Return ``{coi: {target_cluster: fraction}}`` for the cells of each cluster of interest."""
    source_clusters = source_cells['seurat_clusters'].values
    target_clusters = target_cells['seurat_clusters'].values
    distributions = {}
    for coi in clusters_of_interest:
        mapped_clusters = target_clusters[target_indices[source_clusters == coi]]
        unique_tclusters, counts = np.unique(mapped_clusters, return_counts=True)
        total_count = counts.sum()
        # No COI cells -> both arrays are empty and the result is {}.
        distributions[coi] = {
            int(tc): (ct / total_count) for tc, ct in zip(unique_tclusters, counts)
        }
    return distributions


def _draw_single_cell_arrows(ax, res):
    """Draw one arrow per sampled source cell that belongs to a cluster of interest."""
    displacements = res['single_cell_displacements']
    if len(displacements) == 0:
        return
    sampled_source_cells = res['sampled_source']
    starts = sampled_source_cells[['UMAP_2', 'UMAP_1']].values
    if 'seurat_clusters' in sampled_source_cells.columns:
        keep = sampled_source_cells['seurat_clusters'].isin(clusters_of_interest).values
    else:
        # Without cluster labels there is nothing to filter on: draw every arrow.
        keep = np.ones(len(starts), dtype=bool)
    for (sx, sy), (dx, dy) in zip(starts[keep], displacements[keep]):
        ax.arrow(
            sx, sy,
            dx, dy,
            color='black',
            alpha=0.7,
            head_width=0.5,
            head_length=0.5,
            length_includes_head=True,
            linewidth=0.5
        )


def _draw_cluster_arrows(ax, res, source_cells):
    """Draw the per-cluster mean-displacement arrows, sized by the cluster's share of source cells."""
    cluster_arrows = res.get('cluster_arrows', [])
    if 'seurat_clusters' in source_cells.columns:
        source_cluster_counts = source_cells['seurat_clusters'].value_counts()
        total_source_cells_current = len(source_cells)
    else:
        source_cluster_counts = pd.Series(dtype=float)
        total_source_cells_current = 1

    for clust, cx, cy, cdx, cdy in cluster_arrows:
        proportion = source_cluster_counts.get(clust, 0) / total_source_cells_current
        line_width = 1.0 + 8.0 * proportion
        shared = dict(
            alpha=1,
            head_width=0.2 + 0.5 * proportion,
            head_length=0.1 + 0.1 * proportion,
            length_includes_head=True,
        )
        # Wider black arrow underneath acts as an outline for the coloured one.
        ax.arrow(cx, cy, cdx, cdy, color='black', linewidth=line_width + 2, **shared)
        ax.arrow(
            cx, cy, cdx, cdy,
            color=cluster_colors_map.get(clust, 'black'),
            linewidth=line_width,
            **shared
        )


def _draw_aggregated_arrow(ax, res):
    """Draw the single summed arrow of a timepoint, if one was computed."""
    aggregated_arrow = res.get('aggregated_arrow', None)
    if aggregated_arrow is None:
        return
    global_cx, global_cy, total_dx, total_dy = aggregated_arrow
    ax.arrow(
        global_cx, global_cy,
        total_dx, total_dy,
        color='black',
        alpha=0.9,
        head_width=0.5,
        head_length=0.5,
        length_includes_head=True,
        linewidth=2.0
    )


def _save_movement_pdf(output_file, timepoint_results, cohort_name, axis_limits,
                       suptitle, draw_overlay, title_suffix=""):
    """Write a one-page PDF with one UMAP panel per timepoint plus a per-panel overlay."""
    plot_tps = list(timepoint_results)
    num_timepoints = len(plot_tps)
    x_min, x_max, y_min, y_max = axis_limits
    cohort_color = cohort_colors.get(cohort_name, 'red')

    with PdfPages(output_file) as pdf:
        fig, axes = plt.subplots(
            1,
            num_timepoints,
            figsize=(4 * num_timepoints, 4),
            sharex=True,
            sharey=True
        )
        if num_timepoints == 1:
            axes = [axes]
        try:
            for ax, source_tp in zip(axes, plot_tps):
                res = timepoint_results[source_tp]
                # UMAP_2 is plotted on x and UMAP_1 on y, matching the arrow coordinates.
                ax.scatter(
                    res['sampled_gray']['UMAP_2'].values,
                    res['sampled_gray']['UMAP_1'].values,
                    color='gray', s=5, alpha=0.5
                )
                ax.scatter(
                    res['sampled_source']['UMAP_2'].values,
                    res['sampled_source']['UMAP_1'].values,
                    color=cohort_color, s=5, alpha=1
                )
                draw_overlay(ax, source_tp, res)

                ax.set_title(f"Timepoint: {source_tp}{title_suffix}")
                ax.set_xlim(x_min, x_max)
                ax.set_ylim(y_min, y_max)
                ax.set_aspect('equal')
                ax.set_xticks([])
                ax.set_yticks([])

            fig.suptitle(suptitle, fontsize=16)
            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
            pdf.savefig(fig)
        finally:
            plt.close(fig)


def optimal_transport_visualization(subpop_name, cohort_name):
    """
    Run the optimal-transport (OT) pipeline for one subpopulation and cohort.

    Steps:
    1) Load ``source_cells_new.csv`` / ``gray_cells_new.csv`` for every timepoint
       folder of the cohort; timepoints with missing files or no source cells
       are skipped.
    2) For each loaded timepoint, map its source cells onto the source cells of
       the next *loaded* timepoint with ``ot.emd`` on the ``PC_*`` columns; each
       source cell is assigned the target cell that receives most of its mass.
    3) Derive single-cell displacements (in UMAP space), mean displacement
       arrows for the clusters of interest, and their sum as one aggregated arrow.
    4) Write three PDFs (single-cell, cluster-level, aggregated) into
       ``<base_output_dir>/<subpop_name>``.

    Args:
        subpop_name: Sub-folder of ``base_input_dir`` to process.
        cohort_name: Cohort sub-folder inside the subpopulation folder.

    Returns:
        ``{timepoint: {cohort_name: {coi: {target_cluster: fraction}}}}`` with
        one entry per timepoint folder found for the cohort (empty inner dicts
        where no mapping exists), or ``None`` when the input folder, the
        timepoints or enough cells are missing.
    """
    input_folder = os.path.join(base_input_dir, subpop_name, cohort_name)
    output_folder = os.path.join(base_output_dir, subpop_name)
    os.makedirs(output_folder, exist_ok=True)

    if not os.path.exists(input_folder):
        print(f"Input folder does not exist: {input_folder}")
        return None

    available_timepoints = {
        f.split("_")[1] for f in os.listdir(input_folder) if f.startswith("Timepoint_")
    }
    cohort_timepoints = [tp for tp in all_timepoints if tp in available_timepoints]

    if not cohort_timepoints:
        print(f"No timepoints available for {subpop_name} in {cohort_name} cohort.")
        return None

    full_data = _load_timepoints(input_folder, cohort_timepoints)
    if not full_data:
        print("No valid timepoints with data.")
        return None

    # Only timepoints that actually loaded take part in OT and plotting;
    # indexing full_data with a skipped timepoint would raise KeyError.
    valid_timepoints = list(full_data)

    # Shared axis limits across all panels.
    all_x_coords = []
    all_y_coords = []
    for data in full_data.values():
        for cells in (data['gray'], data['source']):
            all_x_coords.extend(cells['UMAP_2'])
            all_y_coords.extend(cells['UMAP_1'])
    axis_limits = (
        min(all_x_coords), max(all_x_coords),
        min(all_y_coords), max(all_y_coords),
    )

    min_source_cells = min(len(data['source']) for data in full_data.values())
    min_gray_cells = min(len(data['gray']) for data in full_data.values())
    if min_source_cells == 0 or min_gray_cells == 0:
        print("Insufficient cells. Skipping visualization.")
        return None

    timepoint_results = {}
    distributions_coi = {
        tp: {cohort_name: {c: {} for c in clusters_of_interest}}
        for tp in cohort_timepoints
    }

    for i, source_tp in enumerate(valid_timepoints):
        source_cells = full_data[source_tp]['source']
        gray_cells = full_data[source_tp]['gray']

        # Downsample so every panel shows the same number of source / gray cells.
        sampled_source_cells = source_cells.sample(n=min_source_cells, random_state=42)
        sampled_gray_cells = gray_cells.sample(n=min_gray_cells, random_state=42)

        result = _empty_transition()
        result['sampled_source'] = sampled_source_cells
        result['sampled_gray'] = sampled_gray_cells
        timepoint_results[source_tp] = result

        if i == len(valid_timepoints) - 1:
            # Last timepoint has no "next" timepoint to map onto.
            continue

        target_tp = valid_timepoints[i + 1]
        target_cells = full_data[target_tp]['source']

        # OT is solved in PC space; displacements are reported in UMAP space.
        pc_cols = [c for c in source_cells.columns if c.startswith('PC_')]
        full_source_coords_pc = source_cells[pc_cols].values
        full_target_coords_pc = target_cells[pc_cols].values
        full_source_coords_umap = source_cells[['UMAP_2', 'UMAP_1']].values
        full_target_coords_umap = target_cells[['UMAP_2', 'UMAP_1']].values

        # Uniform mass over source and target cells.
        a = np.ones((full_source_coords_pc.shape[0],)) / full_source_coords_pc.shape[0]
        b = np.ones((full_target_coords_pc.shape[0],)) / full_target_coords_pc.shape[0]

        cost_matrix = ot.dist(full_source_coords_pc, full_target_coords_pc, metric='euclidean')

        try:
            transport_plan = ot.emd(a, b, cost_matrix, numItermax=100000)
        except Exception as e:
            # Best effort: keep the panel (without arrows) and move on.
            print(f"OT computation failed for {source_tp} to {target_tp}: {e}")
            continue

        # Each source cell is matched to the target cell receiving most of its mass.
        full_target_indices = np.argmax(transport_plan, axis=1)
        displacement_vectors_full = full_target_coords_umap[full_target_indices] - full_source_coords_umap
        norms_full = np.linalg.norm(displacement_vectors_full, axis=1)

        # Positions of the downsampled cells within the full source table.
        sampled_source_indices = source_cells.index.get_indexer_for(sampled_source_cells.index)
        result['single_cell_displacements'] = displacement_vectors_full[sampled_source_indices, :]
        result['single_cell_arrow_lengths'] = norms_full[sampled_source_indices]

        if not ('seurat_clusters' in source_cells.columns
                and 'seurat_clusters' in target_cells.columns):
            # Without cluster labels there are no cluster arrows or distributions.
            continue

        cluster_arrows = _cluster_arrows(source_cells, gray_cells, displacement_vectors_full)
        result['cluster_arrows'] = cluster_arrows

        if cluster_arrows:
            # "Aggregated" arrow: sum of the cluster arrows, anchored at the
            # median UMAP position of all source cells.
            result['aggregated_arrow'] = (
                source_cells['UMAP_2'].median(),
                source_cells['UMAP_1'].median(),
                sum(arrow[3] for arrow in cluster_arrows),
                sum(arrow[4] for arrow in cluster_arrows),
            )

        distributions_coi[source_tp][cohort_name].update(
            _coi_target_distributions(source_cells, target_cells, full_target_indices)
        )

    # ---------------------------
    # Generate PDFs of movement
    # ---------------------------

    # 1) Single-cell arrows
    _save_movement_pdf(
        os.path.join(
            output_folder,
            f"{subpop_name}_{cohort_name}_movement_plots_differential_arrow_lengths_equal_cells_using_PCs.pdf"
        ),
        timepoint_results, cohort_name, axis_limits,
        suptitle=f"{subpop_name} - {cohort_name}",
        draw_overlay=lambda ax, tp, res: _draw_single_cell_arrows(ax, res),
    )

    # 2) Cluster-level arrows
    _save_movement_pdf(
        os.path.join(
            output_folder,
            f"{subpop_name}_{cohort_name}_movement_plots_aggregated_cluster_arrows_equal_cells_using_PCs.pdf"
        ),
        timepoint_results, cohort_name, axis_limits,
        suptitle=f"{subpop_name} - {cohort_name} (Cluster-Level Arrows)",
        draw_overlay=lambda ax, tp, res: _draw_cluster_arrows(ax, res, full_data[tp]['source']),
    )

    # 3) Single aggregated arrow
    _save_movement_pdf(
        os.path.join(
            output_folder,
            f"{subpop_name}_{cohort_name}_movement_plots_aggregated_single_arrow_equal_cells_using_PCs.pdf"
        ),
        timepoint_results, cohort_name, axis_limits,
        suptitle=f"{subpop_name} - {cohort_name} (Single Aggregated Arrow)",
        draw_overlay=lambda ax, tp, res: _draw_aggregated_arrow(ax, res),
        title_suffix=" (Single Aggregated Arrow)",
    )

    # Return the distributions for stacked-bar plotting
    return distributions_coi


# ---------------------------
# Main Execution Loop
# ---------------------------
def get_celltype_name(clust_id):
    """Return the display name of a target cluster; the 'Others' bin keeps its name."""
    if clust_id == 'Others':
        return "Others"
    return cluster_celltype_map.get(clust_id, f"Cluster {clust_id}")


def get_cluster_color(clust_id):
    """Return the bar colour of a target cluster; 'Others' and unmapped clusters are gray."""
    if clust_id == 'Others':
        return 'gray'
    return cluster_colors_map.get(clust_id, 'gray')


for subpop_name in subpopulations:
    # One (initially empty) fraction dict per timepoint / cohort / COI, so the
    # plotting code below can index every combination unconditionally.
    distributions_coi_all = {
        tp: {
            c: {coi: {} for coi in clusters_of_interest}
            for c in cohorts
        }
        for tp in all_timepoints
    }

    # Run for each cohort
    for cohort_name in cohorts:
        print(f"Processing {subpop_name} - {cohort_name}")
        input_subfolder = os.path.join(base_input_dir, subpop_name, cohort_name)
        if not os.path.exists(input_subfolder):
            print(f"Input subfolder does not exist: {input_subfolder}. Skipping.")
            continue

        dist_coi = optimal_transport_visualization(subpop_name, cohort_name)
        if dist_coi is None:
            continue
        for tp in dist_coi:
            for c in dist_coi[tp]:
                for coi in dist_coi[tp][c]:
                    distributions_coi_all[tp][c][coi] = dist_coi[tp][c][coi]

    # --------------------------------------------------
    # 1) Group non-COI target clusters into "Others"
    # --------------------------------------------------
    for tp in distributions_coi_all:
        for c in cohorts:
            for coi in clusters_of_interest:
                fraction_dict = distributions_coi_all[tp][c][coi]
                if not fraction_dict:
                    continue
                new_dict = {}
                others_sum = 0.0
                for tclust, frac in fraction_dict.items():
                    if tclust in clusters_of_interest:
                        new_dict[tclust] = frac
                    else:
                        # Target cluster is not one of the source COIs.
                        others_sum += frac
                if others_sum > 0:
                    new_dict['Others'] = others_sum
                distributions_coi_all[tp][c][coi] = new_dict

    # Gather all target clusters (including "Others")
    all_target_clusters = set()
    for tp in distributions_coi_all:
        for c in cohorts:
            for coi in clusters_of_interest:
                all_target_clusters.update(distributions_coi_all[tp][c][coi])

    # "Others" is a string and cannot be sorted together with the numeric
    # cluster ids, so it is taken out, and re-appended last.
    others_present = ("Others" in all_target_clusters)
    all_target_clusters.discard("Others")
    all_target_clusters = sorted(all_target_clusters)
    if others_present:
        all_target_clusters.append("Others")

    # ---------------------------
    # Generate the stacked bar chart PDF
    # ---------------------------
    # base_output_dir is otherwise only created as a side effect of
    # optimal_transport_visualization(); make sure it exists even when every
    # cohort of this subpopulation was skipped.
    os.makedirs(base_output_dir, exist_ok=True)
    output_file_stacked = os.path.join(base_output_dir, f"{subpop_name}_target_cluster_distribution_coi_with_stack_labels.pdf")
    with PdfPages(output_file_stacked) as pdf_dist:
        # squeeze=False always yields a 2-D axes array, whatever the grid shape.
        fig, axes = plt.subplots(
            len(clusters_of_interest),
            len(all_timepoints),
            figsize=(4 * len(all_timepoints), 3 * len(clusters_of_interest)),
            sharex=False,
            sharey=True,
            squeeze=False
        )

        # Prepare legend patches
        legend_patches = [
            plt.Rectangle((0, 0), 1, 1,
                          facecolor=get_cluster_color(tc),
                          edgecolor='black',
                          label=get_celltype_name(tc))
            for tc in all_target_clusters
        ]

        # Plot bars: one row per COI, one column per timepoint, one bar per cohort.
        for row_i, coi in enumerate(clusters_of_interest):
            for col_i, tp in enumerate(all_timepoints):
                ax = axes[row_i, col_i]
                bar_positions = np.arange(len(cohorts))
                bottoms = np.zeros(len(cohorts))

                # Retrieve distribution data for each cohort at (tp, coi)
                cohort_distributions = [distributions_coi_all[tp][c][coi] for c in cohorts]

                # (target_cluster, [fraction for each cohort])
                stack_data = [
                    (tc, [dist.get(tc, 0.0) for dist in cohort_distributions])
                    for tc in all_target_clusters
                ]

                # Sort so biggest fraction (summed across cohorts) is at the bottom
                stack_data.sort(key=lambda x: sum(x[1]), reverse=True)

                for (tc, heights) in stack_data:
                    ax.bar(
                        bar_positions,
                        heights,
                        bottom=bottoms,
                        color=get_cluster_color(tc),
                        edgecolor='black'
                    )
                    # Label each segment if fraction is big enough
                    for idx, val in enumerate(heights):
                        if val > 0.05:
                            ax.text(
                                bar_positions[idx],
                                bottoms[idx] + val / 2,
                                get_celltype_name(tc),
                                ha='center',
                                va='center',
                                fontsize=6,
                                color='white'
                            )
                    bottoms += heights

                if row_i == 0:
                    ax.set_title(f"{tp}", fontsize=10)

                ax.set_xticks(bar_positions)
                ax.set_xticklabels(cohorts, rotation=45, ha='right', fontsize=8)
                ax.set_ylim(0, 1)

        fig.suptitle(f"Distribution of Target Clusters per COI - {subpop_name}", fontsize=16)
        plt.tight_layout()

        # Adjust to make room for row labels & legend
        fig.subplots_adjust(left=0.15, top=0.86, right=0.82)  # extra space on the right for legend

        # Label rows with the celltype of the COI
        n_rows = len(clusters_of_interest)
        for row_i, coi in enumerate(clusters_of_interest):
            y_pos = 0.86 - (row_i + 0.5)*(0.86-0.1)/n_rows
            fig.text(0.05, y_pos, get_celltype_name(coi), va='center', ha='right', fontsize=10, color='black')

        # Create a single-column legend on the right
        fig.legend(
            handles=legend_patches,
            loc='upper left',
            bbox_to_anchor=(0.84, 0.95),
            ncol=1,
            fontsize=8,
            title="Target Celltypes"
        )

        pdf_dist.savefig(fig)
        plt.close(fig)

print("Visualization generation completed.")

Processing Effector_CD8 - control
Processing Effector_CD8 - short_term
Processing Effector_CD8 - long_term
Processing Memory_Precursor_Effector_CD8 - control
Processing Memory_Precursor_Effector_CD8 - short_term
Processing Memory_Precursor_Effector_CD8 - long_term
Processing Exhausted_T - control
Processing Exhausted_T - short_term
Processing Exhausted_T - long_term


  result_code_string = check_result(result_code)


Processing Central_Memory_CD8 - control
Processing Central_Memory_CD8 - short_term
Processing Central_Memory_CD8 - long_term
Processing Stem_Like_CD8 - control
Processing Stem_Like_CD8 - short_term
Processing Stem_Like_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Effector_Memory_CD8 - control
Processing Effector_Memory_CD8 - short_term
Processing Effector_Memory_CD8 - long_term
Processing Proliferating_Effector - control
Processing Proliferating_Effector - short_term
Processing Proliferating_Effector - long_term
Processing All_CD8 - control
Processing All_CD8 - short_term
Processing All_CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Visualization generation completed.


In [2]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ot  # for optimal transport
from matplotlib.backends.backend_pdf import PdfPages
from scipy.stats import chi2_contingency

# ---------------------------
# Configuration Parameters
# ---------------------------

# Root directories. Inputs are read from <base_input_dir>/<subpop>/<cohort>/Timepoint_<tp>/.
# NOTE(review): base_output_dir is defined here but not referenced by the visible code of this cell.
base_input_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/final_nomenclature/coi"
base_output_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/final_nomenclature/coi_central_memory_cd8"

# Single subpopulation analysed in this cell (loop over more if needed).
subpop_name = "All_CD8"

# Cohorts compared in this cell: control vs. long_term.
cohorts = ["control", "long_term"]

# Transition of interest: source timepoint -> target timepoint.
# Adapt these two values to study other timepoints.
source_tp, target_tp = "C1", "C2"

# Cluster IDs used as the "source" (COI).
# Per the original note, cluster 12 = central memory CD8.
clusters_of_interest = [12]

# CSV providing the 'Cluster', 'Color' and 'Celltype' columns used to build
# the cluster -> color and cluster -> celltype lookups below.
color_mapping_file = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/Publication_Material/T_Cell_cluster_colors.csv"
color_mapping_df = pd.read_csv(color_mapping_file)
cluster_colors_map = {
    cluster: color
    for cluster, color in zip(color_mapping_df['Cluster'], color_mapping_df['Color'])
}
cluster_celltype_map = {
    cluster: celltype
    for cluster, celltype in zip(color_mapping_df['Cluster'], color_mapping_df['Celltype'])
}

# ---------------
# Main function
# ---------------
def run_optimal_transport_and_get_target_clusters(
    subpop_name,
    cohort_name,
    source_tp="C1",
    target_tp="C2",
    base_input_dir=base_input_dir,
    source_clusters=None,
):
    """
    Run OT from timepoint `source_tp` to `target_tp` for one subpopulation and cohort.

    Source cells are restricted to the clusters of interest, then an exact
    earth-mover transport plan (uniform weights, Euclidean cost on the `PC_*`
    columns) is computed against all target cells. Each source cell is
    assigned the cluster of the target cell that receives most of its mass.

    Parameters
    ----------
    subpop_name, cohort_name : str
        Folder names under `base_input_dir` (<base>/<subpop>/<cohort>).
    source_tp, target_tp : str
        Timepoint labels; data are read from `Timepoint_<tp>/source_cells_new.csv`.
    base_input_dir : str
        Root input directory.
    source_clusters : iterable of cluster IDs, optional
        Clusters kept as OT sources. Defaults to the notebook-level
        `clusters_of_interest`. Matching is done on the raw (unmerged)
        `seurat_clusters` values.

    Returns
    -------
    pandas.Series
        Indexed by the source cells' row index; values are the target cluster
        each source cell maps to (after merging 16/17 -> 14 and 9/18 -> 6).
        An empty Series is returned, with a printed reason, when inputs are
        missing, incomplete, or empty, or when the OT solve fails.
    """
    if source_clusters is None:
        source_clusters = clusters_of_interest

    input_folder = os.path.join(base_input_dir, subpop_name, cohort_name)
    if not os.path.exists(input_folder):
        print(f"Input folder does not exist: {input_folder}")
        return pd.Series(dtype=int)

    # Path for timepoint folders
    source_folder = os.path.join(input_folder, f"Timepoint_{source_tp}")
    target_folder = os.path.join(input_folder, f"Timepoint_{target_tp}")

    if not (os.path.exists(source_folder) and os.path.exists(target_folder)):
        print(f"Either {source_folder} or {target_folder} not found. Skipping.")
        return pd.Series(dtype=int)

    source_cells_file = os.path.join(source_folder, 'source_cells_new.csv')
    target_cells_file = os.path.join(target_folder, 'source_cells_new.csv')

    # The gray-cell files are not needed for the computation, but a timepoint
    # without them is still skipped, as before, so the set of processed
    # timepoints does not change.
    required_files = [
        source_cells_file,
        os.path.join(source_folder, 'gray_cells_new.csv'),
        target_cells_file,
        os.path.join(target_folder, 'gray_cells_new.csv'),
    ]
    if not all(os.path.exists(path) for path in required_files):
        print("Missing CSV files in the source/target timepoints.")
        return pd.Series(dtype=int)

    # Read only the tables that are actually used.
    source_cells = pd.read_csv(source_cells_file)
    target_cells = pd.read_csv(target_cells_file)

    if 'seurat_clusters' not in source_cells.columns:
        print("No 'seurat_clusters' column in source file.")
        return pd.Series(dtype=int)
    if 'seurat_clusters' not in target_cells.columns:
        print("No 'seurat_clusters' column in target file.")
        return pd.Series(dtype=int)

    # Keep only the source cells in the cluster(s) of interest.
    source_cells = source_cells[source_cells['seurat_clusters'].isin(source_clusters)]

    if source_cells.empty:
        print("No source cells in clusters_of_interest. Skipping.")
        return pd.Series(dtype=int)

    if target_cells.empty:
        print("No target cells at timepoint. Skipping.")
        return pd.Series(dtype=int)

    # Merge cluster IDs on the target side (16, 17 -> 14 and 9, 18 -> 6).
    # Only target labels are reported, so the source labels need no merging.
    # Working on a new Series avoids writing into a filtered slice.
    target_labels = target_cells['seurat_clusters'].replace({16: 14, 17: 14, 9: 6, 18: 6})

    # Identify columns for PCs; the target must provide the same ones.
    pc_cols = [c for c in source_cells.columns if c.startswith('PC_')]
    if not pc_cols:
        print("No PC_ columns found. Are you sure your data has them?")
        return pd.Series(dtype=int)
    missing_pc_cols = [c for c in pc_cols if c not in target_cells.columns]
    if missing_pc_cols:
        print(f"Target file is missing PC columns: {missing_pc_cols}")
        return pd.Series(dtype=int)

    source_coords_pc = source_cells[pc_cols].values
    target_coords_pc = target_cells[pc_cols].values

    # Uniform marginals: every cell carries the same mass.
    n_source = source_coords_pc.shape[0]
    n_target = target_coords_pc.shape[0]
    a = np.ones((n_source,)) / n_source
    b = np.ones((n_target,)) / n_target

    cost_matrix = ot.dist(source_coords_pc, target_coords_pc, metric='euclidean')

    try:
        transport_plan = ot.emd(a, b, cost_matrix, numItermax=100000)
    except Exception as e:
        # Best-effort: report the failure and let the caller see "no data".
        print(f"OT computation failed: {e}")
        return pd.Series(dtype=int)

    # For each source cell, take the single target receiving the most mass
    # (row-wise argmax of the transport plan).
    best_target_positions = np.argmax(transport_plan, axis=1)

    # Positional lookup of the (merged) cluster label of each chosen target.
    target_clusters = target_labels.iloc[best_target_positions].values

    # Key the result by the source cells' own row index.
    return pd.Series(data=target_clusters, index=source_cells.index)

# ------------------------------------------------------------
# Collect target-cluster assignments for control vs. long-term
# ------------------------------------------------------------
control_targets = run_optimal_transport_and_get_target_clusters(
    subpop_name,
    "control",
    source_tp,
    target_tp
)
# Legacy alias: this name used to hold the control-cohort result even though
# it says "short_term"; kept so later cells that reference it still work.
short_term_targets = control_targets

long_term_targets = run_optimal_transport_and_get_target_clusters(
    subpop_name,
    "long_term",
    source_tp,
    target_tp
)

# ------------------------------------------------------------
# Build contingency table of target clusters
# ------------------------------------------------------------
# Count how many source cells map to each target cluster, per cohort.
# (`short_counts` holds the control counts; the name is kept for compatibility.)
short_counts = control_targets.value_counts().sort_index()
long_counts = long_term_targets.value_counts().sort_index()

# Union of all target clusters that appear in either cohort
all_clusters = sorted(set(short_counts.index).union(long_counts.index))

# Rows = target cluster, columns = [Control, LongTerm]. Reindexing onto the
# union fills clusters absent from one cohort with 0 and also yields a valid
# (0 x 2) table when both cohorts returned no assignments.
contingency_df = pd.DataFrame(
    {
        "Control": short_counts.reindex(all_clusters, fill_value=0),
        "LongTerm": long_counts.reindex(all_clusters, fill_value=0),
    },
    index=all_clusters,
).astype(int)
contingency = contingency_df.to_numpy()

print("\n=== Contingency Table of Target Clusters ===")
print(contingency_df)

# ------------------------------------------------------------
# Run Chi-square test
# ------------------------------------------------------------
if contingency_df.shape[0] <= 1:
    print("\nNot enough distinct clusters to do a Chi-square test.")
elif (contingency_df.sum(axis=0) == 0).any():
    # An all-zero column gives zero expected frequencies, for which
    # chi2_contingency raises ValueError; report it instead of crashing.
    print("\nAt least one cohort has no mapped cells; cannot do a Chi-square test.")
else:
    chi2_stat, p_value, dof, expected = chi2_contingency(contingency)

    print("\n=== Chi-square Test Results ===")
    print(f"Chi-square statistic = {chi2_stat:.3f}")
    print(f"Degrees of freedom   = {dof}")
    print(f"p-value             = {p_value:.6g}")

    # The chi-square approximation is commonly considered reliable only when
    # expected counts are at least 5 in every cell.
    n_small_expected = int((expected < 5).sum())
    if n_small_expected:
        print(
            f"Note: {n_small_expected} cell(s) have an expected count below 5; "
            "the chi-square approximation may be unreliable."
        )

    alpha = 0.05
    if p_value < alpha:
        print("Reject H₀ → The distribution of target clusters differs between Control and LT.")
    else:
        print("Fail to reject H₀ → No evidence of a difference in target cluster distribution.")

print("\nDone.")



=== Contingency Table of Target Clusters ===
    Control  LongTerm
1        17       357
2         2       250
3        23       445
8         9       163
10        9       160
12        2        20
14        5        71

=== Chi-square Test Results ===
Chi-square statistic = 10.807
Degrees of freedom   = 6
p-value             = 0.0945246
Fail to reject H₀ → No evidence of a difference in target cluster distribution.

Done.
