In [2]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ot
from matplotlib.backends.backend_pdf import PdfPages

# ---------------------------
# Configuration Parameters
# ---------------------------

# Base directory where the CSV files are stored
base_input_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/new/temp"

# Base directory to save the plots
base_output_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/new/temp"

# Subpopulations (should match the names used in R)
subpopulations = [
    "Activated_CD4",
    "Effector_CD8",
    "Effector_Memory_CD8",
    "Exhausted_T",
    "Gamma_Delta_T",
    "Active_CD4",
    "Naive_CD4",
    "Memory_CD4",
    "Memory_CD8",
    "Anergic_CD8",
    "Naive_CD8",
    "Hyperactivated_CD8",
    "Proliferating_Effector",
    "CD8"
]

# Cohorts
cohorts = ["control", "short_term", "long_term"]

# Ordered timepoints
all_timepoints = ["Pre", "C1", "C2", "C4", "C6", "C9", "C18", "C36"]

# Color mapping for cohorts
cohort_colors = {
    'control': 'yellow',
    'short_term': 'blue',
    'long_term': 'red'
}

# Desired arrow length for unit vectors
arrow_length = 0.5  # Adjust as needed

# ---------------------------
# Visualization Function
# ---------------------------

def optimal_transport_visualization(subpop_name, cohort_name):
    input_folder = os.path.join(base_input_dir, subpop_name, cohort_name)
    output_folder = os.path.join(base_output_dir, subpop_name)
    os.makedirs(output_folder, exist_ok=True)
    
    # Get list of available timepoint folders
    if not os.path.exists(input_folder):
        print(f"Input folder does not exist: {input_folder}")
        return
    
    timepoint_folders = sorted([f for f in os.listdir(input_folder) if f.startswith("Timepoint_")])
    available_timepoints = [tp.split("_")[1] for tp in timepoint_folders]
    
    # Ensure timepoints are in the correct order
    cohort_timepoints = [tp for tp in all_timepoints if tp in available_timepoints]
    
    # Check if there are any timepoints to process
    if not cohort_timepoints:
        print(f"No timepoints available for {subpop_name} in {cohort_name} cohort.")
        return
    
    # Initialize lists to store axes limits
    all_x_coords = []
    all_y_coords = []
    
    # First pass to collect coordinates for consistent axis limits
    for tp in cohort_timepoints:
        source_folder = os.path.join(input_folder, f"Timepoint_{tp}")
        source_cells_file = os.path.join(source_folder, 'source_cells.csv')
        gray_cells_file = os.path.join(source_folder, 'gray_cells.csv')
        
        # Check if source files exist
        if not (os.path.exists(source_cells_file) and os.path.exists(gray_cells_file)):
            print(f"Missing files for timepoint {tp} in {cohort_name} cohort. Skipping.")
            continue
        
        source_cells = pd.read_csv(source_cells_file)
        gray_cells = pd.read_csv(gray_cells_file)
        
        if source_cells.empty:
            print(f"No source cells for timepoint {tp} in {cohort_name} cohort.")
            continue
        
        # Collect coordinates (flipped)
        all_x_coords.extend(gray_cells['UMAP_2'])
        all_x_coords.extend(source_cells['UMAP_2'])
        all_y_coords.extend(gray_cells['UMAP_1'])
        all_y_coords.extend(source_cells['UMAP_1'])
    
    # Check if any coordinates were collected
    if not all_x_coords or not all_y_coords:
        print(f"No coordinates found for {subpop_name} in {cohort_name} cohort.")
        return
    
    # Determine consistent axis limits
    x_min, x_max = min(all_x_coords), max(all_x_coords)
    y_min, y_max = min(all_y_coords), max(all_y_coords)
    
    # Start a PDF to save all plots in one file
    output_file = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots.pdf")
    with PdfPages(output_file) as pdf:
        # Create a figure with subplots arranged in one row
        num_timepoints = len(cohort_timepoints)
        fig, axes = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        
        # Ensure axes is iterable
        if num_timepoints == 1:
            axes = [axes]
        
        for idx, ax in enumerate(axes):
            source_tp = cohort_timepoints[idx]
            source_folder = os.path.join(input_folder, f"Timepoint_{source_tp}")
            
            source_cells_file = os.path.join(source_folder, 'source_cells.csv')
            gray_cells_file = os.path.join(source_folder, 'gray_cells.csv')
            
            # Check if source files exist
            if not (os.path.exists(source_cells_file) and os.path.exists(gray_cells_file)):
                print(f"Missing files for timepoint {source_tp}. Skipping subplot.")
                ax.axis('off')
                continue  # Skip if any source file is missing
            
            source_cells = pd.read_csv(source_cells_file)
            gray_cells = pd.read_csv(gray_cells_file)
            
            if source_cells.empty:
                print(f"No source cells for timepoint {source_tp}. Skipping subplot.")
                ax.axis('off')
                continue
            
            # Flip UMAP coordinates by swapping UMAP_1 and UMAP_2
            gray_x = gray_cells['UMAP_2']
            gray_y = gray_cells['UMAP_1']
            source_x = source_cells['UMAP_2']
            source_y = source_cells['UMAP_1']
            
            # Plot gray background cells
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5, label='Background')
            # Plot source cells
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1, label='Highlighted')
            
            # For all but the last timepoint, compute optimal transport and plot arrows
            if idx < len(cohort_timepoints) - 1:
                target_tp = cohort_timepoints[idx + 1]
                target_folder = os.path.join(input_folder, f"Timepoint_{target_tp}")
                
                target_cells_file = os.path.join(target_folder, 'source_cells.csv')  # Target highlighted cells
                
                # Check if target_cells_file exists
                if not os.path.exists(target_cells_file):
                    print(f"Missing target cells for timepoint {target_tp}. No arrows will be plotted for {source_tp}.")
                    ax.set_title(f"Timepoint: {source_tp}")
                    continue  # Skip if target file is missing
                
                target_cells = pd.read_csv(target_cells_file)
                
                if target_cells.empty:
                    print(f"No target cells for timepoint {target_tp}. No arrows will be plotted for {source_tp}.")
                    ax.set_title(f"Timepoint: {source_tp}")
                    continue
                
                # Extract UMAP coordinates (flipped)
                source_coords = source_cells[['UMAP_2', 'UMAP_1']].values  # [x, y]
                target_coords = target_cells[['UMAP_2', 'UMAP_1']].values  # [x, y]
                
                # Compute cost matrix (Euclidean distance)
                cost_matrix = ot.dist(source_coords, target_coords, metric='euclidean')
                
                # Uniform weights for source and target
                a = np.ones((source_coords.shape[0],)) / source_coords.shape[0]
                b = np.ones((target_coords.shape[0],)) / target_coords.shape[0]
                
                # Compute optimal transport plan
                transport_plan = ot.emd(a, b, cost_matrix)
                
                # Assign each source cell to the target cell with the highest transport probability
                target_indices = np.argmax(transport_plan, axis=1)
                
                # Compute displacement vectors from source to assigned target
                displacement_vectors = target_coords[target_indices] - source_coords
                
                # Normalize vectors to unit length
                norms = np.linalg.norm(displacement_vectors, axis=1, keepdims=True)
                norms[norms == 0] = 1  # Avoid division by zero
                unit_vectors = displacement_vectors / norms
                
                # Scale unit vectors to desired length
                scaled_vectors = unit_vectors * arrow_length
                
                # Draw unit length arrows
                for j in range(len(source_coords)):
                    ax.arrow(source_coords[j, 0], source_coords[j, 1],
                             scaled_vectors[j, 0], scaled_vectors[j, 1],
                             color='black', alpha=0.7, head_width=0.05, head_length=0.05, 
                             length_includes_head=True, linewidth=0.5)
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min-1, x_max+1)
            ax.set_ylim(y_min-1, y_max+1)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        # Add a global title
        fig.suptitle(f"{subpop_name} - {cohort_name}", fontsize=16)
        
        # Adjust layout to accommodate the global title
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        
        # Save the figure to PDF
        pdf.savefig(fig)
        plt.close(fig)

# ---------------------------
# Main Execution Loop
# ---------------------------

# Loop over subpopulations and cohorts
for subpop_name in subpopulations:
    for cohort_name in cohorts:
        print(f"Processing {subpop_name} - {cohort_name}")
        input_subfolder = os.path.join(base_input_dir, subpop_name, cohort_name)
        if not os.path.exists(input_subfolder):
            print(f"Input subfolder does not exist: {input_subfolder}. Skipping.")
            continue  # Skip if input subfolder does not exist
        optimal_transport_visualization(subpop_name, cohort_name)

print("Visualization generation completed.")


Processing Activated_CD4 - control
Processing Activated_CD4 - short_term
Processing Activated_CD4 - long_term
Processing Effector_CD8 - control
Processing Effector_CD8 - short_term
Processing Effector_CD8 - long_term
Processing Effector_Memory_CD8 - control
Processing Effector_Memory_CD8 - short_term
Processing Effector_Memory_CD8 - long_term
Processing Exhausted_T - control
Processing Exhausted_T - short_term
Processing Exhausted_T - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Gamma_Delta_T - control
Processing Gamma_Delta_T - short_term
Processing Gamma_Delta_T - long_term
Processing Active_CD4 - control
Processing Active_CD4 - short_term
Processing Active_CD4 - long_term
Processing Naive_CD4 - control
Processing Naive_CD4 - short_term
Processing Naive_CD4 - long_term


  result_code_string = check_result(result_code)


Processing Memory_CD4 - control
Processing Memory_CD4 - short_term
Processing Memory_CD4 - long_term


  result_code_string = check_result(result_code)


Processing Memory_CD8 - control
Processing Memory_CD8 - short_term
Processing Memory_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Anergic_CD8 - control
Processing Anergic_CD8 - short_term
Processing Anergic_CD8 - long_term
Processing Naive_CD8 - control
Processing Naive_CD8 - short_term
Processing Naive_CD8 - long_term
Processing Hyperactivated_CD8 - control
Processing Hyperactivated_CD8 - short_term
Processing Hyperactivated_CD8 - long_term
Processing Proliferating_Effector - control
Processing Proliferating_Effector - short_term
Processing Proliferating_Effector - long_term
Processing CD8 - control
Processing CD8 - short_term
Processing CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Visualization generation completed.


In [3]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ot
from matplotlib.backends.backend_pdf import PdfPages

# ---------------------------
# Configuration Parameters
# ---------------------------

# Base directory where the CSV files are stored
base_input_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/new/temp"  # Replace with your actual path

# Base directory to save the plots
base_output_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/new/temp"  # Replace with your desired output path

# Subpopulations (should match the names used in R)
subpopulations = [
    "Activated_CD4",
    "Effector_CD8",
    "Effector_Memory_CD8",
    "Exhausted_T",
    "Gamma_Delta_T",
    "Active_CD4",
    "Naive_CD4",
    "Memory_CD4",
    "Memory_CD8",
    "Anergic_CD8",
    "Naive_CD8",
    "Hyperactivated_CD8",
    "Proliferating_Effector",
    "CD8"
]

# Cohorts
cohorts = ["control", "short_term", "long_term"]

# Ordered timepoints
all_timepoints = ["Pre", "C1", "C2", "C4", "C6", "C9", "C18", "C36"]

# Color mapping for cohorts
cohort_colors = {
    'control': 'yellow',
    'short_term': 'blue',
    'long_term': 'red'
}

# Desired minimum and maximum arrow lengths for visualization
min_arrow_length = 0.3  # Adjust as needed
max_arrow_length = 2  # Adjust as needed

# ---------------------------
# Visualization Function
# ---------------------------

def optimal_transport_visualization(subpop_name, cohort_name):
    input_folder = os.path.join(base_input_dir, subpop_name, cohort_name)
    output_folder = os.path.join(base_output_dir, subpop_name)
    os.makedirs(output_folder, exist_ok=True)
    
    # Get list of available timepoint folders
    if not os.path.exists(input_folder):
        print(f"Input folder does not exist: {input_folder}")
        return
    
    timepoint_folders = sorted([f for f in os.listdir(input_folder) if f.startswith("Timepoint_")])
    available_timepoints = [tp.split("_")[1] for tp in timepoint_folders]
    
    # Ensure timepoints are in the correct order
    cohort_timepoints = [tp for tp in all_timepoints if tp in available_timepoints]
    
    # Check if there are any timepoints to process
    if not cohort_timepoints:
        print(f"No timepoints available for {subpop_name} in {cohort_name} cohort.")
        return
    
    # Initialize lists to store axes limits
    all_x_coords = []
    all_y_coords = []
    
    # First pass to collect coordinates for consistent axis limits
    for tp in cohort_timepoints:
        source_folder = os.path.join(input_folder, f"Timepoint_{tp}")
        source_cells_file = os.path.join(source_folder, 'source_cells.csv')
        gray_cells_file = os.path.join(source_folder, 'gray_cells.csv')
        
        # Check if source files exist
        if not (os.path.exists(source_cells_file) and os.path.exists(gray_cells_file)):
            print(f"Missing files for timepoint {tp} in {cohort_name} cohort. Skipping.")
            continue
        
        source_cells = pd.read_csv(source_cells_file)
        gray_cells = pd.read_csv(gray_cells_file)
        
        if source_cells.empty:
            print(f"No source cells for timepoint {tp} in {cohort_name} cohort.")
            continue
        
        # Collect coordinates (flipped)
        all_x_coords.extend(gray_cells['UMAP_2'])
        all_x_coords.extend(source_cells['UMAP_2'])
        all_y_coords.extend(gray_cells['UMAP_1'])
        all_y_coords.extend(source_cells['UMAP_1'])
    
    # Check if any coordinates were collected
    if not all_x_coords or not all_y_coords:
        print(f"No coordinates found for {subpop_name} in {cohort_name} cohort.")
        return
    
    # Determine consistent axis limits
    x_min, x_max = min(all_x_coords), max(all_x_coords)
    y_min, y_max = min(all_y_coords), max(all_y_coords)
    
    # Start a PDF to save all plots in one file
    output_file = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots_differential_arrow_lengths.pdf")
    with PdfPages(output_file) as pdf:
        # Create a figure with subplots arranged in one row
        num_timepoints = len(cohort_timepoints)
        fig, axes = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        
        # Ensure axes is iterable
        if num_timepoints == 1:
            axes = [axes]
        
        for idx, ax in enumerate(axes):
            source_tp = cohort_timepoints[idx]
            source_folder = os.path.join(input_folder, f"Timepoint_{source_tp}")
            
            source_cells_file = os.path.join(source_folder, 'source_cells.csv')
            gray_cells_file = os.path.join(source_folder, 'gray_cells.csv')
            
            # Check if source files exist
            if not (os.path.exists(source_cells_file) and os.path.exists(gray_cells_file)):
                print(f"Missing files for timepoint {source_tp}. Skipping subplot.")
                ax.axis('off')
                continue  # Skip if any source file is missing
            
            source_cells = pd.read_csv(source_cells_file)
            gray_cells = pd.read_csv(gray_cells_file)
            
            if source_cells.empty:
                print(f"No source cells for timepoint {source_tp}. Skipping subplot.")
                ax.axis('off')
                continue
            
            # Flip UMAP coordinates by swapping UMAP_1 and UMAP_2
            gray_x = gray_cells['UMAP_2']
            gray_y = gray_cells['UMAP_1']
            source_x = source_cells['UMAP_2']
            source_y = source_cells['UMAP_1']
            
            # Plot gray background cells with reduced size
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5, label='Background')  # Reduced size
            
            # Plot source cells with reduced size
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1, label='Highlighted')  # Reduced size
            
            # For all but the last timepoint, compute optimal transport and plot arrows
            if idx < len(cohort_timepoints) - 1:
                target_tp = cohort_timepoints[idx + 1]
                target_folder = os.path.join(input_folder, f"Timepoint_{target_tp}")
                
                target_cells_file = os.path.join(target_folder, 'source_cells.csv')  # Target highlighted cells
                
                # Check if target_cells_file exists
                if not os.path.exists(target_cells_file):
                    print(f"Missing target cells for timepoint {target_tp}. No arrows will be plotted for {source_tp}.")
                    ax.set_title(f"Timepoint: {source_tp}")
                    continue  # Skip if target file is missing
                
                target_cells = pd.read_csv(target_cells_file)
                
                if target_cells.empty:
                    print(f"No target cells for timepoint {target_tp}. No arrows will be plotted for {source_tp}.")
                    ax.set_title(f"Timepoint: {source_tp}")
                    continue
                
                # Extract UMAP coordinates (flipped)
                source_coords = source_cells[['UMAP_2', 'UMAP_1']].values  # [x, y]
                target_coords = target_cells[['UMAP_2', 'UMAP_1']].values  # [x, y]
                
                # Compute cost matrix (Euclidean distance)
                cost_matrix = ot.dist(source_coords, target_coords, metric='euclidean')
                
                # Uniform weights for source and target
                a = np.ones((source_coords.shape[0],)) / source_coords.shape[0]
                b = np.ones((target_coords.shape[0],)) / target_coords.shape[0]
                
                # Compute optimal transport plan with increased numItermax
                try:
                    transport_plan = ot.emd(a, b, cost_matrix, numItermax=100000)
                except Exception as e:
                    print(f"Optimal transport computation failed for timepoint {source_tp} to {target_tp}: {e}")
                    ax.set_title(f"Timepoint: {source_tp}")
                    continue
                
                # Assign each source cell to the target cell with the highest transport probability
                target_indices = np.argmax(transport_plan, axis=1)
                
                # Compute displacement vectors from source to assigned target
                displacement_vectors = target_coords[target_indices] - source_coords
                
                # Compute norms (lengths) of displacement vectors
                norms = np.linalg.norm(displacement_vectors, axis=1)
                
                # Map norms to desired arrow length range [min_arrow_length, max_arrow_length]
                min_norm = np.min(norms)
                max_norm = np.max(norms)
                
                if max_norm - min_norm > 0:
                    arrow_lengths = ((norms - min_norm) / (max_norm - min_norm)) * (max_arrow_length - min_arrow_length) + min_arrow_length
                else:
                    arrow_lengths = np.full_like(norms, min_arrow_length)
                
                # Compute unit vectors
                with np.errstate(divide='ignore', invalid='ignore'):
                    unit_vectors = displacement_vectors / norms[:, np.newaxis]
                    unit_vectors[~np.isfinite(unit_vectors)] = 0  # Handle divisions by zero
                
                # Compute scaled vectors
                scaled_vectors = unit_vectors * arrow_lengths[:, np.newaxis]
                
                # Draw arrows with lengths representing distances
                for j in range(len(source_coords)):
                    ax.arrow(source_coords[j, 0], source_coords[j, 1],
                             scaled_vectors[j, 0], scaled_vectors[j, 1],
                             color='black', alpha=0.7, head_width=0.05, head_length=0.05, 
                             length_includes_head=True, linewidth=0.5)
                
            # Set titles and limits
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        # Add a global title
        fig.suptitle(f"{subpop_name} - {cohort_name}", fontsize=16)
        
        # Adjust layout to accommodate the global title
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        
        # Save the figure to PDF
        pdf.savefig(fig)
        plt.close(fig)

# ---------------------------
# Main Execution Loop
# ---------------------------

# Loop over subpopulations and cohorts
for subpop_name in subpopulations:
    for cohort_name in cohorts:
        print(f"Processing {subpop_name} - {cohort_name}")
        input_subfolder = os.path.join(base_input_dir, subpop_name, cohort_name)
        if not os.path.exists(input_subfolder):
            print(f"Input subfolder does not exist: {input_subfolder}. Skipping.")
            continue  # Skip if input subfolder does not exist
        optimal_transport_visualization(subpop_name, cohort_name)

print("Visualization generation completed.")


Processing Activated_CD4 - control
Processing Activated_CD4 - short_term
Processing Activated_CD4 - long_term
Processing Effector_CD8 - control
Processing Effector_CD8 - short_term
Processing Effector_CD8 - long_term
Processing Effector_Memory_CD8 - control
Processing Effector_Memory_CD8 - short_term
Processing Effector_Memory_CD8 - long_term
Processing Exhausted_T - control
Processing Exhausted_T - short_term
Processing Exhausted_T - long_term


  result_code_string = check_result(result_code)


Processing Gamma_Delta_T - control
Processing Gamma_Delta_T - short_term
Processing Gamma_Delta_T - long_term
Processing Active_CD4 - control
Processing Active_CD4 - short_term
Processing Active_CD4 - long_term
Processing Naive_CD4 - control
Processing Naive_CD4 - short_term
Processing Naive_CD4 - long_term


  result_code_string = check_result(result_code)


Processing Memory_CD4 - control
Processing Memory_CD4 - short_term
Processing Memory_CD4 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Memory_CD8 - control
Processing Memory_CD8 - short_term
Processing Memory_CD8 - long_term
Processing Anergic_CD8 - control
Processing Anergic_CD8 - short_term
Processing Anergic_CD8 - long_term
Processing Naive_CD8 - control
Processing Naive_CD8 - short_term
Processing Naive_CD8 - long_term
Processing Hyperactivated_CD8 - control
Processing Hyperactivated_CD8 - short_term
Processing Hyperactivated_CD8 - long_term
Processing Proliferating_Effector - control
Processing Proliferating_Effector - short_term
Processing Proliferating_Effector - long_term
Processing CD8 - control
Processing CD8 - short_term
Processing CD8 - long_term


  result_code_string = check_result(result_code)


Visualization generation completed.


In [2]:
# equal number of cells at each timepoint

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ot
from matplotlib.backends.backend_pdf import PdfPages

# ---------------------------
# Configuration Parameters
# ---------------------------

# Base directory where the CSV files are stored
base_input_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/new/temp"  # Replace with your actual path

# Base directory to save the plots
base_output_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/new/temp"  # Replace with your desired output path

# Subpopulations (should match the names used in R)
subpopulations = [
    "Activated_CD4",
    "Effector_CD8",
    "Effector_Memory_CD8",
    "Exhausted_T",
    "Gamma_Delta_T",
    "Active_CD4",
    "Naive_CD4",
    "Memory_CD4",
    "Memory_CD8",
    "Anergic_CD8",
    "Naive_CD8",
    "Hyperactivated_CD8",
    "Proliferating_Effector",
    "CD8"
]

# Cohorts
cohorts = ["control", "short_term", "long_term"]

# Ordered timepoints
all_timepoints = ["Pre", "C1", "C2", "C4", "C6", "C9", "C18", "C36"]

# Color mapping for cohorts
cohort_colors = {
    'control': 'yellow',
    'short_term': 'blue',
    'long_term': 'red'
}

# Desired minimum and maximum arrow lengths for visualization
min_arrow_length = 0.3  # Adjust as needed
max_arrow_length = 2  # Adjust as needed

# ---------------------------
# Visualization Function
# ---------------------------

def optimal_transport_visualization(subpop_name, cohort_name):
    input_folder = os.path.join(base_input_dir, subpop_name, cohort_name)
    output_folder = os.path.join(base_output_dir, subpop_name)
    os.makedirs(output_folder, exist_ok=True)
    
    # Get list of available timepoint folders
    if not os.path.exists(input_folder):
        print(f"Input folder does not exist: {input_folder}")
        return
    
    timepoint_folders = sorted([f for f in os.listdir(input_folder) if f.startswith("Timepoint_")])
    available_timepoints = [tp.split("_")[1] for tp in timepoint_folders]
    
    # Ensure timepoints are in the correct order
    cohort_timepoints = [tp for tp in all_timepoints if tp in available_timepoints]
    
    # Check if there are any timepoints to process
    if not cohort_timepoints:
        print(f"No timepoints available for {subpop_name} in {cohort_name} cohort.")
        return
    
    # Initialize lists to store axes limits
    all_x_coords = []
    all_y_coords = []
    
    # Initialize a dictionary to store cell counts for each timepoint
    cell_counts = {'source': {}, 'gray': {}}
    
    # First pass to collect coordinates for consistent axis limits and cell counts
    for tp in cohort_timepoints:
        source_folder = os.path.join(input_folder, f"Timepoint_{tp}")
        source_cells_file = os.path.join(source_folder, 'source_cells.csv')
        gray_cells_file = os.path.join(source_folder, 'gray_cells.csv')
        
        # Check if source files exist
        if not (os.path.exists(source_cells_file) and os.path.exists(gray_cells_file)):
            print(f"Missing files for timepoint {tp} in {cohort_name} cohort. Skipping.")
            continue
        
        source_cells = pd.read_csv(source_cells_file)
        gray_cells = pd.read_csv(gray_cells_file)
        
        if source_cells.empty:
            print(f"No source cells for timepoint {tp} in {cohort_name} cohort.")
            continue
        
        # Collect coordinates (flipped)
        all_x_coords.extend(gray_cells['UMAP_2'])
        all_x_coords.extend(source_cells['UMAP_2'])
        all_y_coords.extend(gray_cells['UMAP_1'])
        all_y_coords.extend(source_cells['UMAP_1'])
        
        # Store cell counts
        cell_counts['source'][tp] = len(source_cells)
        cell_counts['gray'][tp] = len(gray_cells)
    
    # Determine the minimum number of source and gray cells across all timepoints
    if cell_counts['source']:
        min_source_cells = min(cell_counts['source'].values())
    else:
        min_source_cells = 0
    
    if cell_counts['gray']:
        min_gray_cells = min(cell_counts['gray'].values())
    else:
        min_gray_cells = 0
    
    # Handle cases where no cells are found
    if min_source_cells == 0 or min_gray_cells == 0:
        print(f"Insufficient cells for uniform sampling in {subpop_name} - {cohort_name}. Skipping visualization.")
        return
    
    print(f"Minimum number of source cells across timepoints: {min_source_cells}")
    print(f"Minimum number of gray cells across timepoints: {min_gray_cells}")
    
    # Determine consistent axis limits
    x_min, x_max = min(all_x_coords), max(all_x_coords)
    y_min, y_max = min(all_y_coords), max(all_y_coords)
    
    # Start a PDF to save all plots in one file
    output_file = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots_differential_arrow_lengths_equal_cells.pdf")
    with PdfPages(output_file) as pdf:
        # Create a figure with subplots arranged in one row
        num_timepoints = len(cohort_timepoints)
        fig, axes = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        
        # Ensure axes is iterable
        if num_timepoints == 1:
            axes = [axes]
        
        for idx, ax in enumerate(axes):
            source_tp = cohort_timepoints[idx]
            source_folder = os.path.join(input_folder, f"Timepoint_{source_tp}")
            
            source_cells_file = os.path.join(source_folder, 'source_cells.csv')
            gray_cells_file = os.path.join(source_folder, 'gray_cells.csv')
            
            # Check if source files exist
            if not (os.path.exists(source_cells_file) and os.path.exists(gray_cells_file)):
                print(f"Missing files for timepoint {source_tp}. Skipping subplot.")
                ax.axis('off')
                continue  # Skip if any source file is missing
            
            source_cells = pd.read_csv(source_cells_file)
            gray_cells = pd.read_csv(gray_cells_file)
            
            if source_cells.empty:
                print(f"No source cells for timepoint {source_tp}. Skipping subplot.")
                ax.axis('off')
                continue
            
            # Sample source_cells and gray_cells to have uniform counts
            if len(source_cells) >= min_source_cells:
                sampled_source_cells = source_cells.sample(n=min_source_cells, random_state=42)
            else:
                print(f"Not enough source cells in {source_tp} for uniform sampling. Skipping subplot.")
                ax.axis('off')
                continue
            
            if len(gray_cells) >= min_gray_cells:
                sampled_gray_cells = gray_cells.sample(n=min_gray_cells, random_state=42)
            else:
                print(f"Not enough gray cells in {source_tp} for uniform sampling. Skipping subplot.")
                ax.axis('off')
                continue
            
            # Flip UMAP coordinates by swapping UMAP_1 and UMAP_2
            gray_x = sampled_gray_cells['UMAP_2']
            gray_y = sampled_gray_cells['UMAP_1']
            source_x = sampled_source_cells['UMAP_2']
            source_y = sampled_source_cells['UMAP_1']
            
            # Plot gray background cells with reduced size
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5, label='Background')  # Reduced size
            
            # Plot source cells with reduced size
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1, label='Highlighted')  # Reduced size
            
            # For all but the last timepoint, compute optimal transport and plot arrows
            if idx < len(cohort_timepoints) - 1:
                target_tp = cohort_timepoints[idx + 1]
                target_folder = os.path.join(input_folder, f"Timepoint_{target_tp}")
                
                target_cells_file = os.path.join(target_folder, 'source_cells.csv')  # Target highlighted cells
                
                # Check if target_cells_file exists
                if not os.path.exists(target_cells_file):
                    print(f"Missing target cells for timepoint {target_tp}. No arrows will be plotted for {source_tp}.")
                    ax.set_title(f"Timepoint: {source_tp}")
                    continue  # Skip if target file is missing
                
                target_cells = pd.read_csv(target_cells_file)
                
                if target_cells.empty:
                    print(f"No target cells for timepoint {target_tp}. No arrows will be plotted for {source_tp}.")
                    ax.set_title(f"Timepoint: {source_tp}")
                    continue
                
                # Extract UMAP coordinates (flipped)
                source_coords = source_cells[['UMAP_2', 'UMAP_1']].values  # [x, y]
                target_coords = target_cells[['UMAP_2', 'UMAP_1']].values  # [x, y]
                
                # Compute cost matrix (Euclidean distance)
                cost_matrix = ot.dist(source_coords, target_coords, metric='euclidean')
                
                # Uniform weights for source and target
                a = np.ones((source_coords.shape[0],)) / source_coords.shape[0]
                b = np.ones((target_coords.shape[0],)) / target_coords.shape[0]
                
                # Compute optimal transport plan with increased numItermax
                try:
                    transport_plan = ot.emd(a, b, cost_matrix, numItermax=100000)
                except Exception as e:
                    print(f"Optimal transport computation failed for timepoint {source_tp} to {target_tp}: {e}")
                    ax.set_title(f"Timepoint: {source_tp}")
                    continue
                
                # Assign each source cell to the target cell with the highest transport probability
                target_indices = np.argmax(transport_plan, axis=1)
                
                # Compute displacement vectors from source to assigned target
                displacement_vectors = target_coords[target_indices] - source_coords
                
                # Compute norms (lengths) of displacement vectors
                norms = np.linalg.norm(displacement_vectors, axis=1)
                
                # Map norms to desired arrow length range [min_arrow_length, max_arrow_length]
                min_norm = np.min(norms)
                max_norm = np.max(norms)
                
                if max_norm - min_norm > 0:
                    arrow_lengths = ((norms - min_norm) / (max_norm - min_norm)) * (max_arrow_length - min_arrow_length) + min_arrow_length
                else:
                    arrow_lengths = np.full_like(norms, min_arrow_length)
                
                # Compute unit vectors
                with np.errstate(divide='ignore', invalid='ignore'):
                    unit_vectors = displacement_vectors / norms[:, np.newaxis]
                    unit_vectors[~np.isfinite(unit_vectors)] = 0  # Handle divisions by zero
                
                # Compute scaled vectors
                scaled_vectors = unit_vectors * arrow_lengths[:, np.newaxis]
                
                # Draw arrows with lengths representing distances
                for j in range(len(source_coords)):
                    ax.arrow(source_coords[j, 0], source_coords[j, 1],
                             scaled_vectors[j, 0], scaled_vectors[j, 1],
                             color='black', alpha=0.7, head_width=0.05, head_length=0.05, 
                             length_includes_head=True, linewidth=0.5)
                
            # Set titles and limits
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        # Add a global title
        fig.suptitle(f"{subpop_name} - {cohort_name}", fontsize=16)
        
        # Adjust layout to accommodate the global title
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        
        # Save the figure to PDF
        pdf.savefig(fig)
        plt.close(fig)


# ---------------------------
# Main Execution Loop
# ---------------------------

# Loop over subpopulations and cohorts
for subpop_name in subpopulations:
    for cohort_name in cohorts:
        print(f"Processing {subpop_name} - {cohort_name}")
        input_subfolder = os.path.join(base_input_dir, subpop_name, cohort_name)
        if not os.path.exists(input_subfolder):
            print(f"Input subfolder does not exist: {input_subfolder}. Skipping.")
            continue  # Skip if input subfolder does not exist
        optimal_transport_visualization(subpop_name, cohort_name)

print("Visualization generation completed.")


Processing Activated_CD4 - control
Minimum number of source cells across timepoints: 143
Minimum number of gray cells across timepoints: 2165
Processing Activated_CD4 - short_term
Minimum number of source cells across timepoints: 445
Minimum number of gray cells across timepoints: 1315
Processing Activated_CD4 - long_term
Minimum number of source cells across timepoints: 206
Minimum number of gray cells across timepoints: 1474
Processing Effector_CD8 - control
Minimum number of source cells across timepoints: 449
Minimum number of gray cells across timepoints: 1939
Processing Effector_CD8 - short_term
Minimum number of source cells across timepoints: 285
Minimum number of gray cells across timepoints: 1479
Processing Effector_CD8 - long_term
Minimum number of source cells across timepoints: 571
Minimum number of gray cells across timepoints: 1428
Processing Effector_Memory_CD8 - control
Minimum number of source cells across timepoints: 339
Minimum number of gray cells across timepoints

  result_code_string = check_result(result_code)


Processing Gamma_Delta_T - control
Minimum number of source cells across timepoints: 248
Minimum number of gray cells across timepoints: 2140
Processing Gamma_Delta_T - short_term
Minimum number of source cells across timepoints: 102
Minimum number of gray cells across timepoints: 1639
Processing Gamma_Delta_T - long_term
Minimum number of source cells across timepoints: 438
Minimum number of gray cells across timepoints: 1561
Processing Active_CD4 - control
Minimum number of source cells across timepoints: 327
Minimum number of gray cells across timepoints: 2061
Processing Active_CD4 - short_term
Minimum number of source cells across timepoints: 310
Minimum number of gray cells across timepoints: 1530
Processing Active_CD4 - long_term
Minimum number of source cells across timepoints: 357
Minimum number of gray cells across timepoints: 1591
Processing Naive_CD4 - control
Minimum number of source cells across timepoints: 454
Minimum number of gray cells across timepoints: 1934
Processin

  result_code_string = check_result(result_code)


Processing Memory_CD4 - control
Minimum number of source cells across timepoints: 7
Minimum number of gray cells across timepoints: 2377
Processing Memory_CD4 - short_term
Minimum number of source cells across timepoints: 10
Minimum number of gray cells across timepoints: 1760
Processing Memory_CD4 - long_term
Minimum number of source cells across timepoints: 171
Minimum number of gray cells across timepoints: 1768


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Memory_CD8 - control
Minimum number of source cells across timepoints: 404
Minimum number of gray cells across timepoints: 1982
Processing Memory_CD8 - short_term
Minimum number of source cells across timepoints: 308
Minimum number of gray cells across timepoints: 1370
Processing Memory_CD8 - long_term
Minimum number of source cells across timepoints: 536
Minimum number of gray cells across timepoints: 1463
Processing Anergic_CD8 - control
Minimum number of source cells across timepoints: 300
Minimum number of gray cells across timepoints: 2076
Processing Anergic_CD8 - short_term
Minimum number of source cells across timepoints: 224
Minimum number of gray cells across timepoints: 1475
Processing Anergic_CD8 - long_term
Minimum number of source cells across timepoints: 504
Minimum number of gray cells across timepoints: 1487
Processing Naive_CD8 - control
Minimum number of source cells across timepoints: 161
Minimum number of gray cells across timepoints: 2227
Processing Naiv

  result_code_string = check_result(result_code)


Visualization generation completed.


In [None]:
# Number of cells to sample for single-cell plotting (to reduce overcrowding)
# This can be the minimum number found across timepoints, or a fixed number.
# Here we will determine a uniform sampling size based on the smallest dataset
# found across timepoints, same as before, but only for plotting.

In [4]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ot
from matplotlib.backends.backend_pdf import PdfPages

# ---------------------------
# Configuration Parameters
# ---------------------------

# Base directory where the CSV files are stored
base_input_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/new/temp"  # Replace with your actual path

# Base directory to save the plots
base_output_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/new/temp"  # Replace with your desired output path

# Subpopulations (should match the names used in R)
subpopulations = [
    "Activated_CD4",
    "Effector_CD8",
    "Effector_Memory_CD8",
    "Exhausted_T",
    "Gamma_Delta_T",
    "Active_CD4",
    "Naive_CD4",
    "Memory_CD4",
    "Memory_CD8",
    "Anergic_CD8",
    "Naive_CD8",
    "Hyperactivated_CD8",
    "Proliferating_Effector",
    "CD8"
]

# Cohorts
cohorts = ["control", "short_term", "long_term"]

# Ordered timepoints
all_timepoints = ["Pre", "C1", "C2", "C4", "C6", "C9", "C18", "C36"]

# Color mapping for cohorts
cohort_colors = {
    'control': 'yellow',
    'short_term': 'blue',
    'long_term': 'red'
}

# Desired minimum and maximum arrow lengths for visualization
min_arrow_length = 0.3  # Adjust as needed
max_arrow_length = 2    # Adjust as needed

# Number of cells to sample for single-cell plotting (to reduce overcrowding)
# This can be the minimum number found across timepoints, or a fixed number.
# Here we will determine a uniform sampling size based on the smallest dataset
# found across timepoints, same as before, but only for plotting.
max_plot_cells = None  # will determine at runtime

def optimal_transport_visualization(subpop_name, cohort_name):
    input_folder = os.path.join(base_input_dir, subpop_name, cohort_name)
    output_folder = os.path.join(base_output_dir, subpop_name)
    os.makedirs(output_folder, exist_ok=True)
    
    if not os.path.exists(input_folder):
        print(f"Input folder does not exist: {input_folder}")
        return
    
    timepoint_folders = sorted([f for f in os.listdir(input_folder) if f.startswith("Timepoint_")])
    available_timepoints = [tp.split("_")[1] for tp in timepoint_folders]
    cohort_timepoints = [tp for tp in all_timepoints if tp in available_timepoints]
    
    if not cohort_timepoints:
        print(f"No timepoints available for {subpop_name} in {cohort_name} cohort.")
        return

    # We'll store coordinates and also find global axis limits
    all_x_coords = []
    all_y_coords = []
    
    # Load all source and gray cells (full sets) for axis limits
    # We also determine the smallest number of source and gray cells across timepoints
    # for single-cell plotting (downsampling).
    cell_counts_source = []
    cell_counts_gray = []
    
    full_data = {}  # store full data for each timepoint
    
    for tp in cohort_timepoints:
        source_folder = os.path.join(input_folder, f"Timepoint_{tp}")
        source_cells_file = os.path.join(source_folder, 'source_cells.csv')
        gray_cells_file = os.path.join(source_folder, 'gray_cells.csv')
        
        if not (os.path.exists(source_cells_file) and os.path.exists(gray_cells_file)):
            print(f"Missing files for {tp}. Skipping this timepoint.")
            continue
        
        source_cells = pd.read_csv(source_cells_file)
        gray_cells = pd.read_csv(gray_cells_file)
        
        if source_cells.empty:
            print(f"No source cells for {tp}. Skipping.")
            continue
        
        # Store full data
        full_data[tp] = {
            'source': source_cells,
            'gray': gray_cells
        }
        
        # For axis limits
        all_x_coords.extend(gray_cells['UMAP_2'])
        all_x_coords.extend(source_cells['UMAP_2'])
        all_y_coords.extend(gray_cells['UMAP_1'])
        all_y_coords.extend(source_cells['UMAP_1'])
        
        cell_counts_source.append(len(source_cells))
        cell_counts_gray.append(len(gray_cells))
    
    if not full_data:
        print("No valid timepoints with data.")
        return
    
    x_min, x_max = min(all_x_coords), max(all_x_coords)
    y_min, y_max = min(all_y_coords), max(all_y_coords)
    
    # Determine number of cells to plot for single-cell arrows (downsample for plotting only)
    # We use the minimum across all timepoints to keep it consistent.
    min_source_cells = min(cell_counts_source) if cell_counts_source else 0
    min_gray_cells = min(cell_counts_gray) if cell_counts_gray else 0
    if min_source_cells == 0 or min_gray_cells == 0:
        print("Insufficient cells. Skipping visualization.")
        return
    
    # We'll do OT and cluster arrow computations using the full sets,
    # but only plot a subsample of source & gray cells (and their arrows) in the single-cell arrow PDF.
    
    # We'll store results for both single-cell (subsampled) plotting and cluster-level arrows (full).
    timepoint_results = {}
    
    # We need to compute OT from each timepoint to the next (except the last one)
    for i, source_tp in enumerate(cohort_timepoints):
        source_data = full_data[source_tp]
        source_cells = source_data['source']
        
        # Store initial data for plotting
        # Downsample for plotting single-cell arrows only
        sampled_source_cells = source_cells.sample(n=min_source_cells, random_state=42)
        sampled_gray_cells = full_data[source_tp]['gray'].sample(n=min_gray_cells, random_state=42)
        
        timepoint_results[source_tp] = {
            'sampled_source': sampled_source_cells,
            'sampled_gray': sampled_gray_cells
        }
        
        if i < len(cohort_timepoints) - 1:
            target_tp = cohort_timepoints[i+1]
            target_data = full_data[target_tp]
            target_cells = target_data['source']
            
            if target_cells.empty:
                # No target cells, no OT
                continue
            
            # Compute OT on full sets, not sampled
            full_source_coords = source_cells[['UMAP_2', 'UMAP_1']].values
            full_target_coords = target_cells[['UMAP_2', 'UMAP_1']].values
            
            # Uniform distributions
            a = np.ones((full_source_coords.shape[0],)) / full_source_coords.shape[0]
            b = np.ones((full_target_coords.shape[0],)) / full_target_coords.shape[0]
            
            # Compute cost matrix
            cost_matrix = ot.dist(full_source_coords, full_target_coords, metric='euclidean')
            
            try:
                transport_plan = ot.emd(a, b, cost_matrix, numItermax=100000)
            except Exception as e:
                print(f"OT computation failed for {source_tp} to {target_tp}: {e}")
                continue
            
            # Assign each full source cell to a target cell
            full_target_indices = np.argmax(transport_plan, axis=1)
            displacement_vectors_full = full_target_coords[full_target_indices] - full_source_coords
            
            # Compute norms
            norms_full = np.linalg.norm(displacement_vectors_full, axis=1)
            
            # Scale arrow lengths for cluster-level arrows
            min_norm_full = np.min(norms_full)
            max_norm_full = np.max(norms_full)
            
            if max_norm_full - min_norm_full > 0:
                arrow_lengths_full = ((norms_full - min_norm_full) / (max_norm_full - min_norm_full)) * (max_arrow_length - min_arrow_length) + min_arrow_length
            else:
                arrow_lengths_full = np.full_like(norms_full, min_arrow_length)
            
            # For single-cell plotting, we also need displacement vectors of sampled sets
            # Assign each sampled source cell to target cell based on the full transport plan
            # We must find their indices in the full set
            # Matching sampled source cells to full source:
            sampled_source_indices = source_cells.index.get_indexer_for(sampled_source_cells.index)
            sampled_displacements = displacement_vectors_full[sampled_source_indices, :]
            sampled_norms = norms_full[sampled_source_indices]
            
            # Scale for sampled set plotting
            min_norm_sampled = np.min(sampled_norms)
            max_norm_sampled = np.max(sampled_norms)
            if max_norm_sampled - min_norm_sampled > 0:
                arrow_lengths_sampled = ((sampled_norms - min_norm_sampled) / (max_norm_sampled - min_norm_sampled)) * (max_arrow_length - min_arrow_length) + min_arrow_length
            else:
                arrow_lengths_sampled = np.full_like(sampled_norms, min_arrow_length)
            
            # Store single-cell arrow info (for plotting)
            timepoint_results[source_tp]['single_cell_displacements'] = sampled_displacements
            timepoint_results[source_tp]['single_cell_arrow_lengths'] = arrow_lengths_sampled
            
            clusters_of_interest = {1, 2, 3, 8, 10, 12, 14}

            # Compute cluster-level arrows using full data (not sampled)
            if 'seurat_clusters' in source_cells.columns:
                # Replace 16 and 17 with 14 before computing cluster-level arrows
                source_cells['seurat_clusters'] = source_cells['seurat_clusters'].replace({16: 14, 17: 14})
                source_clusters_full = source_cells['seurat_clusters'].values
                df_cluster_full = pd.DataFrame({
                    'cluster': source_clusters_full,
                    'sx': full_source_coords[:,0],
                    'sy': full_source_coords[:,1],
                    'dx': displacement_vectors_full[:,0],
                    'dy': displacement_vectors_full[:,1],
                    'norm': norms_full
                })
                
                # Compute cluster aggregates
                group = df_cluster_full.groupby('cluster')
                centroids = group[['sx','sy']].mean()
                mean_disp = group[['dx','dy']].mean()
            
                cluster_norms = np.sqrt(mean_disp['dx']**2 + mean_disp['dy']**2)
                cn_min = cluster_norms.min()
                cn_max = cluster_norms.max()
                
                # Ensure cluster_arrow_lengths is a Series
                if cn_max - cn_min > 0:
                    cluster_arrow_lengths = ((cluster_norms - cn_min) / (cn_max - cn_min)) * (max_arrow_length - min_arrow_length) + min_arrow_length
                else:
                    cluster_arrow_lengths = pd.Series(np.full_like(cluster_norms, min_arrow_length), index=cluster_norms.index)
                
                cluster_arrows = []
                for clust in centroids.index:
                    # Only show aggregated arrow if cluster is in clusters_of_interest
                    if clust in clusters_of_interest:
                        cx, cy = centroids.loc[clust, ['sx','sy']]
                        cdx, cdy = mean_disp.loc[clust, ['dx','dy']]
                        cnorm = np.sqrt(cdx**2 + cdy**2)
                        if cnorm > 0:
                            cdx = cdx / cnorm
                            cdy = cdy / cnorm
                        else:
                            cdx, cdy = 0, 0
            
                        # Safely access cluster_arrow_lengths
                        try:
                            length = cluster_arrow_lengths.loc[clust]
                        except AttributeError:
                            # If this happens, convert cluster_arrow_lengths to Series and retry
                            print("cluster_arrow_lengths is not a Series. Attempting to convert...")
                            if isinstance(cluster_arrow_lengths, np.ndarray):
                                cluster_arrow_lengths = pd.Series(cluster_arrow_lengths, index=cluster_norms.index)
                            try:
                                length = cluster_arrow_lengths.loc[clust]
                            except Exception as e:
                                print(f"Failed to access cluster_arrow_lengths for cluster {clust}: {e}")
                                length = min_arrow_length
            
                        cdx *= length
                        cdy *= length
                        cluster_arrows.append((cx, cy, cdx, cdy))
                
                timepoint_results[source_tp]['cluster_arrows'] = cluster_arrows
            else:
                timepoint_results[source_tp]['cluster_arrows'] = []
                print(f"No 'seurat_clusters' column in {source_tp} source cells. Cannot compute cluster arrows.")

            
        else:
            # Last timepoint has no next timepoint
            timepoint_results[source_tp]['single_cell_displacements'] = np.array([]) # no arrows
            timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
            timepoint_results[source_tp]['cluster_arrows'] = []
            if 'seurat_clusters' in source_cells.columns:
                # Still record that cluster info was present
                # but no arrows since there's no next timepoint
                timepoint_results[source_tp]['cluster_arrows'] = []
            else:
                timepoint_results[source_tp]['cluster_arrows'] = []
    
    # Now plot the original single-cell arrows PDF
    output_file_original = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots_differential_arrow_lengths_equal_cells.pdf")
    with PdfPages(output_file_original) as pdf_original:
        num_timepoints = len(cohort_timepoints)
        fig_orig, axes_orig = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        if num_timepoints == 1:
            axes_orig = [axes_orig]
        
        for i, source_tp in enumerate(cohort_timepoints):
            ax = axes_orig[i]
            res = timepoint_results[source_tp]
            sampled_source_cells = res['sampled_source']
            sampled_gray_cells = res['sampled_gray']
            
            gray_x = sampled_gray_cells['UMAP_2'].values
            gray_y = sampled_gray_cells['UMAP_1'].values
            source_x = sampled_source_cells['UMAP_2'].values
            source_y = sampled_source_cells['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)
            
            single_cell_displacements = res['single_cell_displacements']
            single_cell_arrow_lengths = res['single_cell_arrow_lengths']
            
            if len(single_cell_displacements) > 0:
                source_coords_sampled = sampled_source_cells[['UMAP_2','UMAP_1']].values
                # Compute unit vectors for sampled
                sampled_norms = np.linalg.norm(single_cell_displacements, axis=1)
                with np.errstate(divide='ignore', invalid='ignore'):
                    unit_vectors = single_cell_displacements / sampled_norms[:, np.newaxis]
                    unit_vectors[~np.isfinite(unit_vectors)] = 0
                scaled_vectors = unit_vectors * single_cell_arrow_lengths[:, np.newaxis]
                
                for j in range(len(source_coords_sampled)):
                    sx, sy = source_coords_sampled[j]
                    dx, dy = scaled_vectors[j]
                    ax.arrow(sx, sy, dx, dy,
                             color='black', alpha=0.7, head_width=0.05, head_length=0.05,
                             length_includes_head=True, linewidth=0.5)
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_orig.suptitle(f"{subpop_name} - {cohort_name}", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_original.savefig(fig_orig)
        plt.close(fig_orig)
    
    # Now plot aggregated cluster arrows in a new PDF
    output_file_cluster = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots_aggregated_cluster_arrows_equal_cells.pdf")
    with PdfPages(output_file_cluster) as pdf_cluster:
        num_timepoints = len(cohort_timepoints)
        fig_clust, axes_clust = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        if num_timepoints == 1:
            axes_clust = [axes_clust]
        
        for i, source_tp in enumerate(cohort_timepoints):
            ax = axes_clust[i]
            res = timepoint_results[source_tp]
            
            # For cluster-level plot, we still show the same sampled source/gray points as background
            gray_x = res['sampled_gray']['UMAP_2'].values
            gray_y = res['sampled_gray']['UMAP_1'].values
            source_x = res['sampled_source']['UMAP_2'].values
            source_y = res['sampled_source']['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)
            
            cluster_arrows = res.get('cluster_arrows', [])
            for (cx, cy, cdx, cdy) in cluster_arrows:
                ax.arrow(cx, cy, cdx, cdy,
                         color='black', alpha=0.7, head_width=0.1, head_length=0.1,
                         length_includes_head=True, linewidth=1.0)
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_clust.suptitle(f"{subpop_name} - {cohort_name} (Cluster-Level Arrows)", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_cluster.savefig(fig_clust)
        plt.close(fig_clust)

# ---------------------------
# Main Execution Loop
# ---------------------------

for subpop_name in subpopulations:
    for cohort_name in cohorts:
        print(f"Processing {subpop_name} - {cohort_name}")
        input_subfolder = os.path.join(base_input_dir, subpop_name, cohort_name)
        if not os.path.exists(input_subfolder):
            print(f"Input subfolder does not exist: {input_subfolder}. Skipping.")
            continue
        optimal_transport_visualization(subpop_name, cohort_name)

print("Visualization generation completed.")


Processing Activated_CD4 - control
Processing Activated_CD4 - short_term
Processing Activated_CD4 - long_term
Processing Effector_CD8 - control
Processing Effector_CD8 - short_term
Processing Effector_CD8 - long_term
Processing Effector_Memory_CD8 - control
Processing Effector_Memory_CD8 - short_term
Processing Effector_Memory_CD8 - long_term
Processing Exhausted_T - control
Processing Exhausted_T - short_term
Processing Exhausted_T - long_term


  result_code_string = check_result(result_code)


Processing Gamma_Delta_T - control
Processing Gamma_Delta_T - short_term
Processing Gamma_Delta_T - long_term
Processing Active_CD4 - control
Processing Active_CD4 - short_term
Processing Active_CD4 - long_term
Processing Naive_CD4 - control
Processing Naive_CD4 - short_term
Processing Naive_CD4 - long_term


  result_code_string = check_result(result_code)


Processing Memory_CD4 - control
Processing Memory_CD4 - short_term
Processing Memory_CD4 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Memory_CD8 - control
Processing Memory_CD8 - short_term
Processing Memory_CD8 - long_term
Processing Anergic_CD8 - control
Processing Anergic_CD8 - short_term
Processing Anergic_CD8 - long_term
Processing Naive_CD8 - control
Processing Naive_CD8 - short_term
Processing Naive_CD8 - long_term
Processing Hyperactivated_CD8 - control
Processing Hyperactivated_CD8 - short_term
Processing Hyperactivated_CD8 - long_term
Processing Proliferating_Effector - control
Processing Proliferating_Effector - short_term
Processing Proliferating_Effector - long_term
Processing CD8 - control
Processing CD8 - short_term
Processing CD8 - long_term


  result_code_string = check_result(result_code)


Visualization generation completed.


In [1]:
# In the following version we use PCs to calculate the distance instead of UMAP

In [7]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ot
from matplotlib.backends.backend_pdf import PdfPages

# ---------------------------
# Configuration Parameters
# ---------------------------

# Base directory where the CSV files are stored
base_input_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/new/temp"  # Replace with your actual path

# Base directory to save the plots
base_output_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/new/temp"  # Replace with your desired output path

# Subpopulations (should match the names used in R)
subpopulations = [
    "Activated_CD4",
    "Effector_CD8",
    "Effector_Memory_CD8",
    "Exhausted_T",
    "Gamma_Delta_T",
    "Active_CD4",
    "Naive_CD4",
    "Memory_CD4",
    "Memory_CD8",
    "Anergic_CD8",
    "Naive_CD8",
    "Hyperactivated_CD8",
    "Proliferating_Effector",
    "CD8"
]

# Cohorts
cohorts = ["control", "short_term", "long_term"]

# Ordered timepoints
all_timepoints = ["Pre", "C1", "C2", "C4", "C6", "C9", "C18", "C36"]

# Color mapping for cohorts
cohort_colors = {
    'control': 'yellow',
    'short_term': 'blue',
    'long_term': 'red'
}

# Desired minimum and maximum arrow lengths for visualization
min_arrow_length = 0.3  # Adjust as needed
max_arrow_length = 2    # Adjust as needed

# Number of cells to sample for single-cell plotting (to reduce overcrowding)
# This can be the minimum number found across timepoints, or a fixed number.
# Here we will determine a uniform sampling size based on the smallest dataset
# found across timepoints, same as before, but only for plotting.
max_plot_cells = None  # will determine at runtime

def optimal_transport_visualization(subpop_name, cohort_name):
    input_folder = os.path.join(base_input_dir, subpop_name, cohort_name)
    output_folder = os.path.join(base_output_dir, subpop_name)
    os.makedirs(output_folder, exist_ok=True)
    
    if not os.path.exists(input_folder):
        print(f"Input folder does not exist: {input_folder}")
        return
    
    timepoint_folders = sorted([f for f in os.listdir(input_folder) if f.startswith("Timepoint_")])
    available_timepoints = [tp.split("_")[1] for tp in timepoint_folders]
    cohort_timepoints = [tp for tp in all_timepoints if tp in available_timepoints]
    
    if not cohort_timepoints:
        print(f"No timepoints available for {subpop_name} in {cohort_name} cohort.")
        return

    # We'll store coordinates and also find global axis limits
    all_x_coords = []
    all_y_coords = []
    
    # Load all source and gray cells (full sets) for axis limits
    # We also determine the smallest number of source and gray cells across timepoints
    # for single-cell plotting (downsampling).
    cell_counts_source = []
    cell_counts_gray = []
    
    full_data = {}  # store full data for each timepoint
    
    for tp in cohort_timepoints:
        source_folder = os.path.join(input_folder, f"Timepoint_{tp}")
        source_cells_file = os.path.join(source_folder, 'source_cells_new.csv')
        gray_cells_file = os.path.join(source_folder, 'gray_cells_new.csv')
        
        if not (os.path.exists(source_cells_file) and os.path.exists(gray_cells_file)):
            print(f"Missing files for {tp}. Skipping this timepoint.")
            continue
        
        source_cells = pd.read_csv(source_cells_file)
        gray_cells = pd.read_csv(gray_cells_file)
        
        if source_cells.empty:
            print(f"No source cells for {tp}. Skipping.")
            continue
        
        # Store full data
        full_data[tp] = {
            'source': source_cells,
            'gray': gray_cells
        }
        
        # For axis limits
        all_x_coords.extend(gray_cells['UMAP_2'])
        all_x_coords.extend(source_cells['UMAP_2'])
        all_y_coords.extend(gray_cells['UMAP_1'])
        all_y_coords.extend(source_cells['UMAP_1'])
        
        cell_counts_source.append(len(source_cells))
        cell_counts_gray.append(len(gray_cells))
    
    if not full_data:
        print("No valid timepoints with data.")
        return
    
    x_min, x_max = min(all_x_coords), max(all_x_coords)
    y_min, y_max = min(all_y_coords), max(all_y_coords)
    
    # Determine number of cells to plot for single-cell arrows (downsample for plotting only)
    # We use the minimum across all timepoints to keep it consistent.
    min_source_cells = min(cell_counts_source) if cell_counts_source else 0
    min_gray_cells = min(cell_counts_gray) if cell_counts_gray else 0
    if min_source_cells == 0 or min_gray_cells == 0:
        print("Insufficient cells. Skipping visualization.")
        return
    
    # We'll do OT and cluster arrow computations using the full sets,
    # but only plot a subsample of source & gray cells (and their arrows) in the single-cell arrow PDF.
    
    # We'll store results for both single-cell (subsampled) plotting and cluster-level arrows (full).
    timepoint_results = {}
    
    # We need to compute OT from each timepoint to the next (except the last one)
    for i, source_tp in enumerate(cohort_timepoints):
        source_data = full_data[source_tp]
        source_cells = source_data['source']
        
        # Store initial data for plotting
        # Downsample for plotting single-cell arrows only
        sampled_source_cells = source_cells.sample(n=min_source_cells, random_state=42)
        sampled_gray_cells = full_data[source_tp]['gray'].sample(n=min_gray_cells, random_state=42)
        
        timepoint_results[source_tp] = {
            'sampled_source': sampled_source_cells,
            'sampled_gray': sampled_gray_cells
        }
        
        if i < len(cohort_timepoints) - 1:
            target_tp = cohort_timepoints[i+1]
            target_data = full_data[target_tp]
            target_cells = target_data['source']
            
            if target_cells.empty:
                # No target cells, no OT
                continue
            
            pc_cols = [c for c in source_cells.columns if c.startswith('PC_')]
            full_source_coords = source_cells[pc_cols].values
            full_target_coords = target_cells[pc_cols].values
            
            full_source_coords_pc = source_cells[pc_cols].values
            full_target_coords_pc = target_cells[pc_cols].values
            
            # Also store UMAP coordinates for plotting and displacement vectors
            full_source_coords_umap = source_cells[['UMAP_2', 'UMAP_1']].values
            full_target_coords_umap = target_cells[['UMAP_2', 'UMAP_1']].values
            
            # Uniform distributions
            a = np.ones((full_source_coords.shape[0],)) / full_source_coords.shape[0]
            b = np.ones((full_target_coords.shape[0],)) / full_target_coords.shape[0]
            
            # Compute cost matrix
            cost_matrix = ot.dist(full_source_coords_pc, full_target_coords_pc, metric='euclidean')
            
            try:
                transport_plan = ot.emd(a, b, cost_matrix, numItermax=100000)
            except Exception as e:
                print(f"OT computation failed for {source_tp} to {target_tp}: {e}")
                continue
            
            # Assign each full source cell to a target cell
            full_target_indices = np.argmax(transport_plan, axis=1)
            displacement_vectors_full = full_target_coords_umap[full_target_indices] - full_source_coords_umap
            
            # Compute norms
            norms_full = np.linalg.norm(displacement_vectors_full, axis=1)
            
            # Scale arrow lengths for cluster-level arrows
            min_norm_full = np.min(norms_full)
            max_norm_full = np.max(norms_full)
            
            if max_norm_full - min_norm_full > 0:
                arrow_lengths_full = ((norms_full - min_norm_full) / (max_norm_full - min_norm_full)) * (max_arrow_length - min_arrow_length) + min_arrow_length
            else:
                arrow_lengths_full = np.full_like(norms_full, min_arrow_length)
            
            # For single-cell plotting, we also need displacement vectors of sampled sets
            # Assign each sampled source cell to target cell based on the full transport plan
            # We must find their indices in the full set
            # Matching sampled source cells to full source:
            sampled_source_indices = source_cells.index.get_indexer_for(sampled_source_cells.index)
            sampled_displacements = displacement_vectors_full[sampled_source_indices, :]
            sampled_norms = norms_full[sampled_source_indices]
            
            # Scale for sampled set plotting
            min_norm_sampled = np.min(sampled_norms)
            max_norm_sampled = np.max(sampled_norms)
            if max_norm_sampled - min_norm_sampled > 0:
                arrow_lengths_sampled = ((sampled_norms - min_norm_sampled) / (max_norm_sampled - min_norm_sampled)) * (max_arrow_length - min_arrow_length) + min_arrow_length
            else:
                arrow_lengths_sampled = np.full_like(sampled_norms, min_arrow_length)
            
            # Store single-cell arrow info (for plotting)
            timepoint_results[source_tp]['single_cell_displacements'] = sampled_displacements
            timepoint_results[source_tp]['single_cell_arrow_lengths'] = arrow_lengths_sampled
            
            clusters_of_interest = {1, 2, 3, 8, 10, 12, 14}

            # Compute cluster-level arrows using full data (not sampled)
            if 'seurat_clusters' in source_cells.columns:
                # Replace 16 and 17 with 14 before computing cluster-level arrows
                source_cells['seurat_clusters'] = source_cells['seurat_clusters'].replace({16: 14, 17: 14})
                source_clusters_full = source_cells['seurat_clusters'].values
                df_cluster_full = pd.DataFrame({
                    'cluster': source_clusters_full,
                    'sx': full_source_coords_umap[:,0],
                    'sy': full_source_coords_umap[:,1],
                    'dx': displacement_vectors_full[:,0],
                    'dy': displacement_vectors_full[:,1],
                    'norm': norms_full
                })
                
                # Combine source and gray cells
                all_cells_combined = pd.concat([source_cells, gray_cells])
                
                # Replace 16 and 17 with 14 in the combined set as well
                all_cells_combined['seurat_clusters'] = all_cells_combined['seurat_clusters'].replace({16:14,17:14})
                
                # Extract UMAP coordinates for all cells
                all_clusters = all_cells_combined['seurat_clusters'].values
                all_coords_umap = all_cells_combined[['UMAP_2', 'UMAP_1']].values
                
                # Create a DataFrame for centroid calculation
                df_all = pd.DataFrame({
                    'cluster': all_clusters,
                    'sx': all_coords_umap[:,0],
                    'sy': all_coords_umap[:,1]
                })
                
                # Compute centroids from all cells (source + gray)
                all_group = df_all.groupby('cluster')
                # centroids = all_group[['sx','sy']].mean()
                centroids = all_group[['sx','sy']].median()

                
                # Mean displacement vectors still come from source cells only
                group = df_cluster_full.groupby('cluster')
                mean_disp = group[['dx','dy']].mean()
            
                cluster_norms = np.sqrt(mean_disp['dx']**2 + mean_disp['dy']**2)
                cn_min = cluster_norms.min()
                cn_max = cluster_norms.max()
                
                # Ensure cluster_arrow_lengths is a Series
                if cn_max - cn_min > 0:
                    cluster_arrow_lengths = ((cluster_norms - cn_min) / (cn_max - cn_min)) * (max_arrow_length - min_arrow_length) + min_arrow_length
                else:
                    cluster_arrow_lengths = pd.Series(np.full_like(cluster_norms, min_arrow_length), index=cluster_norms.index)
                
                cluster_arrows = []

                # Only consider clusters present in both centroids and mean_disp
                common_clusters = centroids.index.intersection(mean_disp.index)
                for clust in common_clusters:
                    # Only show aggregated arrow if cluster is in clusters_of_interest
                    if clust in clusters_of_interest:
                        cx, cy = centroids.loc[clust, ['sx','sy']]
                        cdx, cdy = mean_disp.loc[clust, ['dx','dy']]
                        cnorm = np.sqrt(cdx**2 + cdy**2)
                        if cnorm > 0:
                            cdx = cdx / cnorm
                            cdy = cdy / cnorm
                        else:
                            cdx, cdy = 0, 0
            
                        # Safely access cluster_arrow_lengths
                        try:
                            length = cluster_arrow_lengths.loc[clust]
                        except AttributeError:
                            # If this happens, convert cluster_arrow_lengths to Series and retry
                            print("cluster_arrow_lengths is not a Series. Attempting to convert...")
                            if isinstance(cluster_arrow_lengths, np.ndarray):
                                cluster_arrow_lengths = pd.Series(cluster_arrow_lengths, index=cluster_norms.index)
                            try:
                                length = cluster_arrow_lengths.loc[clust]
                            except Exception as e:
                                print(f"Failed to access cluster_arrow_lengths for cluster {clust}: {e}")
                                length = min_arrow_length
            
                        cdx *= length
                        cdy *= length
                        cluster_arrows.append((cx, cy, cdx, cdy))
                
                timepoint_results[source_tp]['cluster_arrows'] = cluster_arrows
            else:
                timepoint_results[source_tp]['cluster_arrows'] = []
                print(f"No 'seurat_clusters' column in {source_tp} source cells. Cannot compute cluster arrows.")

            
        else:
            # Last timepoint has no next timepoint
            timepoint_results[source_tp]['single_cell_displacements'] = np.array([]) # no arrows
            timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
            timepoint_results[source_tp]['cluster_arrows'] = []
            if 'seurat_clusters' in source_cells.columns:
                # Still record that cluster info was present
                # but no arrows since there's no next timepoint
                timepoint_results[source_tp]['cluster_arrows'] = []
            else:
                timepoint_results[source_tp]['cluster_arrows'] = []
    
    # Now plot the original single-cell arrows PDF
    output_file_original = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots_differential_arrow_lengths_equal_cells_using_PCs.pdf")
    with PdfPages(output_file_original) as pdf_original:
        num_timepoints = len(cohort_timepoints)
        fig_orig, axes_orig = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        if num_timepoints == 1:
            axes_orig = [axes_orig]
        
        for i, source_tp in enumerate(cohort_timepoints):
            ax = axes_orig[i]
            res = timepoint_results[source_tp]
            sampled_source_cells = res['sampled_source']
            sampled_gray_cells = res['sampled_gray']
            
            gray_x = sampled_gray_cells['UMAP_2'].values
            gray_y = sampled_gray_cells['UMAP_1'].values
            source_x = sampled_source_cells['UMAP_2'].values
            source_y = sampled_source_cells['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)
            
            single_cell_displacements = res['single_cell_displacements']
            single_cell_arrow_lengths = res['single_cell_arrow_lengths']
            
            if len(single_cell_displacements) > 0:
                source_coords_sampled = sampled_source_cells[['UMAP_2','UMAP_1']].values
                # Compute unit vectors for sampled
                sampled_norms = np.linalg.norm(single_cell_displacements, axis=1)
                with np.errstate(divide='ignore', invalid='ignore'):
                    unit_vectors = single_cell_displacements / sampled_norms[:, np.newaxis]
                    unit_vectors[~np.isfinite(unit_vectors)] = 0
                scaled_vectors = unit_vectors * single_cell_arrow_lengths[:, np.newaxis]
                
                for j in range(len(source_coords_sampled)):
                    sx, sy = source_coords_sampled[j]
                    dx, dy = scaled_vectors[j]
                    ax.arrow(sx, sy, dx, dy,
                             color='black', alpha=0.7, head_width=0.05, head_length=0.05,
                             length_includes_head=True, linewidth=0.5)
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_orig.suptitle(f"{subpop_name} - {cohort_name}", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_original.savefig(fig_orig)
        plt.close(fig_orig)
    
    # Now plot aggregated cluster arrows in a new PDF
    output_file_cluster = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots_aggregated_cluster_arrows_equal_cells_using_PCs.pdf")
    with PdfPages(output_file_cluster) as pdf_cluster:
        num_timepoints = len(cohort_timepoints)
        fig_clust, axes_clust = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        if num_timepoints == 1:
            axes_clust = [axes_clust]
        
        for i, source_tp in enumerate(cohort_timepoints):
            ax = axes_clust[i]
            res = timepoint_results[source_tp]
            
            # For cluster-level plot, we still show the same sampled source/gray points as background
            gray_x = res['sampled_gray']['UMAP_2'].values
            gray_y = res['sampled_gray']['UMAP_1'].values
            source_x = res['sampled_source']['UMAP_2'].values
            source_y = res['sampled_source']['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)

            cluster_arrows = res.get('cluster_arrows', [])
            for (cx, cy, cdx, cdy) in cluster_arrows:
                ax.arrow(cx, cy, cdx, cdy,
                         color='black', alpha=0.7, head_width=0.1, head_length=0.1,
                         length_includes_head=True, linewidth=1.0)
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_clust.suptitle(f"{subpop_name} - {cohort_name} (Cluster-Level Arrows)", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_cluster.savefig(fig_clust)
        plt.close(fig_clust)

# ---------------------------
# Main Execution Loop
# ---------------------------

for subpop_name in subpopulations:
    for cohort_name in cohorts:
        print(f"Processing {subpop_name} - {cohort_name}")
        input_subfolder = os.path.join(base_input_dir, subpop_name, cohort_name)
        if not os.path.exists(input_subfolder):
            print(f"Input subfolder does not exist: {input_subfolder}. Skipping.")
            continue
        optimal_transport_visualization(subpop_name, cohort_name)

print("Visualization generation completed.")


Processing Activated_CD4 - control
Processing Activated_CD4 - short_term
Processing Activated_CD4 - long_term
Processing Effector_CD8 - control
Processing Effector_CD8 - short_term
Processing Effector_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Effector_Memory_CD8 - control
Processing Effector_Memory_CD8 - short_term
Processing Effector_Memory_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Exhausted_T - control
Processing Exhausted_T - short_term
Processing Exhausted_T - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Gamma_Delta_T - control
Processing Gamma_Delta_T - short_term
Processing Gamma_Delta_T - long_term
Processing Active_CD4 - control
Processing Active_CD4 - short_term
Processing Active_CD4 - long_term
Processing Naive_CD4 - control
Processing Naive_CD4 - short_term
Processing Naive_CD4 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Memory_CD4 - control
Processing Memory_CD4 - short_term
Processing Memory_CD4 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Memory_CD8 - control
Processing Memory_CD8 - short_term
Processing Memory_CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Anergic_CD8 - control
Processing Anergic_CD8 - short_term
Processing Anergic_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Naive_CD8 - control
Processing Naive_CD8 - short_term
Processing Naive_CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Hyperactivated_CD8 - control
Missing files for Pre. Skipping this timepoint.
Missing files for C1. Skipping this timepoint.
Missing files for C2. Skipping this timepoint.
No valid timepoints with data.
Processing Hyperactivated_CD8 - short_term
Missing files for Pre. Skipping this timepoint.
Missing files for C1. Skipping this timepoint.
Missing files for C2. Skipping this timepoint.
Missing files for C4. Skipping this timepoint.
No valid timepoints with data.
Processing Hyperactivated_CD8 - long_term
Missing files for Pre. Skipping this timepoint.
Missing files for C1. Skipping this timepoint.
Missing files for C2. Skipping this timepoint.
Missing files for C4. Skipping this timepoint.
Missing files for C6. Skipping this timepoint.
Missing files for C9. Skipping this timepoint.
Missing files for C18. Skipping this timepoint.
Missing files for C36. Skipping this timepoint.
No valid timepoints with data.
Processing Proliferating_Effector - control
Processing Proliferating_Eff

  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Visualization generation completed.


In [4]:
# remove previous, if this works

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ot
from matplotlib.backends.backend_pdf import PdfPages

# ---------------------------
# Configuration Parameters
# ---------------------------

# Base directory where the CSV files are stored
base_input_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/new/temp"  # Replace with your actual path

# Base directory to save the plots
base_output_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/new/temp"  # Replace with your desired output path

# Subpopulations (should match the names used in R)
subpopulations = [
    "Activated_CD4",
    "Effector_CD8",
    "Effector_Memory_CD8",
    "Exhausted_T",
    "Gamma_Delta_T",
    "Active_CD4",
    "Naive_CD4",
    "Memory_CD4",
    "Memory_CD8",
    "Anergic_CD8",
    "Naive_CD8",
    "Hyperactivated_CD8",
    "Proliferating_Effector",
    "CD8"
]

# Cohorts
cohorts = ["control", "short_term", "long_term"]

# Ordered timepoints
all_timepoints = ["Pre", "C1", "C2", "C4", "C6", "C9", "C18", "C36"]

# Color mapping for cohorts
cohort_colors = {
    'control': 'yellow',
    'short_term': 'blue',
    'long_term': 'red'
}

# Desired minimum and maximum arrow lengths for visualization
min_arrow_length = 0.3  # Adjust as needed
max_arrow_length = 2    # Adjust as needed

# Number of cells to sample for single-cell plotting (to reduce overcrowding)
max_plot_cells = None  # will determine at runtime

def optimal_transport_visualization(subpop_name, cohort_name):
    input_folder = os.path.join(base_input_dir, subpop_name, cohort_name)
    output_folder = os.path.join(base_output_dir, subpop_name)
    os.makedirs(output_folder, exist_ok=True)
    
    if not os.path.exists(input_folder):
        print(f"Input folder does not exist: {input_folder}")
        return
    
    timepoint_folders = sorted([f for f in os.listdir(input_folder) if f.startswith("Timepoint_")])
    available_timepoints = [tp.split("_")[1] for tp in timepoint_folders]
    cohort_timepoints = [tp for tp in all_timepoints if tp in available_timepoints]
    
    if not cohort_timepoints:
        print(f"No timepoints available for {subpop_name} in {cohort_name} cohort.")
        return

    # We'll store coordinates and also find global axis limits
    all_x_coords = []
    all_y_coords = []
    
    cell_counts_source = []
    cell_counts_gray = []
    full_data = {}  # store full data for each timepoint
    
    for tp in cohort_timepoints:
        source_folder = os.path.join(input_folder, f"Timepoint_{tp}")
        source_cells_file = os.path.join(source_folder, 'source_cells_new.csv')
        gray_cells_file = os.path.join(source_folder, 'gray_cells_new.csv')
        
        if not (os.path.exists(source_cells_file) and os.path.exists(gray_cells_file)):
            print(f"Missing files for {tp}. Skipping this timepoint.")
            continue
        
        source_cells = pd.read_csv(source_cells_file)
        gray_cells = pd.read_csv(gray_cells_file)
        
        if source_cells.empty:
            print(f"No source cells for {tp}. Skipping.")
            continue
        
        # Store full data
        full_data[tp] = {
            'source': source_cells,
            'gray': gray_cells
        }
        
        # For axis limits
        all_x_coords.extend(gray_cells['UMAP_2'])
        all_x_coords.extend(source_cells['UMAP_2'])
        all_y_coords.extend(gray_cells['UMAP_1'])
        all_y_coords.extend(source_cells['UMAP_1'])
        
        cell_counts_source.append(len(source_cells))
        cell_counts_gray.append(len(gray_cells))
    
    if not full_data:
        print("No valid timepoints with data.")
        return
    
    x_min, x_max = min(all_x_coords), max(all_x_coords)
    y_min, y_max = min(all_y_coords), max(all_y_coords)
    
    # Determine number of cells to plot for single-cell arrows (downsample for plotting only)
    min_source_cells = min(cell_counts_source) if cell_counts_source else 0
    min_gray_cells = min(cell_counts_gray) if cell_counts_gray else 0
    if min_source_cells == 0 or min_gray_cells == 0:
        print("Insufficient cells. Skipping visualization.")
        return
    
    timepoint_results = {}
    
    # Compute OT from each timepoint to the next (except the last one)
    for i, source_tp in enumerate(cohort_timepoints):
        source_data = full_data[source_tp]
        source_cells = source_data['source']
        
        # Downsample for plotting single-cell arrows only
        sampled_source_cells = source_cells.sample(n=min_source_cells, random_state=42)
        sampled_gray_cells = full_data[source_tp]['gray'].sample(n=min_gray_cells, random_state=42)
        
        timepoint_results[source_tp] = {
            'sampled_source': sampled_source_cells,
            'sampled_gray': sampled_gray_cells
        }
        
        if i < len(cohort_timepoints) - 1:
            target_tp = cohort_timepoints[i+1]
            target_data = full_data[target_tp]
            target_cells = target_data['source']
            
            if target_cells.empty:
                continue
            
            pc_cols = [c for c in source_cells.columns if c.startswith('PC_')]
            full_source_coords_pc = source_cells[pc_cols].values
            full_target_coords_pc = target_cells[pc_cols].values
            
            # Also store UMAP coordinates
            full_source_coords_umap = source_cells[['UMAP_2', 'UMAP_1']].values
            full_target_coords_umap = target_cells[['UMAP_2', 'UMAP_1']].values
            
            # Uniform distributions
            a = np.ones((full_source_coords_pc.shape[0],)) / full_source_coords_pc.shape[0]
            b = np.ones((full_target_coords_pc.shape[0],)) / full_target_coords_pc.shape[0]
            
            # Compute cost matrix
            cost_matrix = ot.dist(full_source_coords_pc, full_target_coords_pc, metric='euclidean')
            
            try:
                transport_plan = ot.emd(a, b, cost_matrix, numItermax=100000)
            except Exception as e:
                print(f"OT computation failed for {source_tp} to {target_tp}: {e}")
                continue
            
            full_target_indices = np.argmax(transport_plan, axis=1)
            displacement_vectors_full = full_target_coords_umap[full_target_indices] - full_source_coords_umap
            
            # Compute norms
            norms_full = np.linalg.norm(displacement_vectors_full, axis=1)
            min_norm_full = np.min(norms_full)
            max_norm_full = np.max(norms_full)
            
            if max_norm_full - min_norm_full > 0:
                arrow_lengths_full = ((norms_full - min_norm_full) / (max_norm_full - min_norm_full)) * (max_arrow_length - min_arrow_length) + min_arrow_length
            else:
                arrow_lengths_full = np.full_like(norms_full, min_arrow_length)
            
            # For single-cell plotting
            sampled_source_indices = source_cells.index.get_indexer_for(sampled_source_cells.index)
            sampled_displacements = displacement_vectors_full[sampled_source_indices, :]
            sampled_norms = norms_full[sampled_source_indices]
            
            min_norm_sampled = np.min(sampled_norms)
            max_norm_sampled = np.max(sampled_norms)
            if max_norm_sampled - min_norm_sampled > 0:
                arrow_lengths_sampled = ((sampled_norms - min_norm_sampled) / (max_norm_sampled - min_norm_sampled)) * (max_arrow_length - min_arrow_length) + min_arrow_length
            else:
                arrow_lengths_sampled = np.full_like(sampled_norms, min_arrow_length)
            
            timepoint_results[source_tp]['single_cell_displacements'] = sampled_displacements
            timepoint_results[source_tp]['single_cell_arrow_lengths'] = arrow_lengths_sampled
            
            clusters_of_interest = {1, 2, 3, 8, 10, 12, 14}

            # Compute cluster-level arrows using full data
            if 'seurat_clusters' in source_cells.columns:
                source_cells['seurat_clusters'] = source_cells['seurat_clusters'].replace({16: 14, 17: 14})
                source_clusters_full = source_cells['seurat_clusters'].values
                df_cluster_full = pd.DataFrame({
                    'cluster': source_clusters_full,
                    'sx': full_source_coords_umap[:,0],
                    'sy': full_source_coords_umap[:,1],
                    'dx': displacement_vectors_full[:,0],
                    'dy': displacement_vectors_full[:,1],
                    'norm': norms_full
                })
                
                gray_cells = full_data[source_tp]['gray']
                all_cells_combined = pd.concat([source_cells, gray_cells])
                all_cells_combined['seurat_clusters'] = all_cells_combined['seurat_clusters'].replace({16:14,17:14})
                
                all_clusters = all_cells_combined['seurat_clusters'].values
                all_coords_umap = all_cells_combined[['UMAP_2', 'UMAP_1']].values
                
                df_all = pd.DataFrame({
                    'cluster': all_clusters,
                    'sx': all_coords_umap[:,0],
                    'sy': all_coords_umap[:,1]
                })
                
                # Compute centroids from all cells (source + gray)
                all_group = df_all.groupby('cluster')
                centroids = all_group[['sx','sy']].median()
                
                # Mean displacement from source cells only
                group = df_cluster_full.groupby('cluster')
                mean_disp = group[['dx','dy']].mean()
            
                cluster_norms = np.sqrt(mean_disp['dx']**2 + mean_disp['dy']**2)
                cn_min = cluster_norms.min()
                cn_max = cluster_norms.max()
                
                if cn_max - cn_min > 0:
                    cluster_arrow_lengths = ((cluster_norms - cn_min) / (cn_max - cn_min)) * (max_arrow_length - min_arrow_length) + min_arrow_length
                else:
                    cluster_arrow_lengths = pd.Series(np.full_like(cluster_norms, min_arrow_length), index=cluster_norms.index)
                
                cluster_arrows = []
                common_clusters = centroids.index.intersection(mean_disp.index)
                for clust in common_clusters:
                    if clust in clusters_of_interest:
                        cx, cy = centroids.loc[clust, ['sx','sy']]
                        cdx, cdy = mean_disp.loc[clust, ['dx','dy']]
                        cnorm = np.sqrt(cdx**2 + cdy**2)
                        if cnorm > 0:
                            cdx = cdx / cnorm
                            cdy = cdy / cnorm
                        else:
                            cdx, cdy = 0, 0
                        
                        try:
                            length = cluster_arrow_lengths.loc[clust]
                        except AttributeError:
                            if isinstance(cluster_arrow_lengths, np.ndarray):
                                cluster_arrow_lengths = pd.Series(cluster_arrow_lengths, index=cluster_norms.index)
                            try:
                                length = cluster_arrow_lengths.loc[clust]
                            except Exception as e:
                                print(f"Failed to access cluster_arrow_lengths for cluster {clust}: {e}")
                                length = min_arrow_length
            
                        cdx *= length
                        cdy *= length
                        # Store cluster ID alongside arrow coordinates for line width calculation later
                        cluster_arrows.append((clust, cx, cy, cdx, cdy))
                
                timepoint_results[source_tp]['cluster_arrows'] = cluster_arrows

                if len(cluster_arrows) > 0:
                    # Instead of global centroid of all cells, use the median of source cells only:
                    source_median_x = source_cells['UMAP_2'].median()
                    source_median_y = source_cells['UMAP_1'].median()

                    # Sum displacement vectors of all cluster arrows
                    total_dx = sum([arrow[3] for arrow in cluster_arrows])
                    total_dy = sum([arrow[4] for arrow in cluster_arrows])

                    # Store the aggregated arrow
                    timepoint_results[source_tp]['aggregated_arrow'] = (source_median_x, source_median_y, total_dx, total_dy)
                else:
                    # If no cluster arrows, no aggregated arrow
                    timepoint_results[source_tp]['aggregated_arrow'] = None
            else:
                timepoint_results[source_tp]['cluster_arrows'] = []
                timepoint_results[source_tp]['aggregated_arrow'] = None
                print(f"No 'seurat_clusters' column in {source_tp} source cells. Cannot compute cluster arrows.")

            
        else:
            # Last timepoint has no next timepoint
            timepoint_results[source_tp]['single_cell_displacements'] = np.array([])
            timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
            timepoint_results[source_tp]['cluster_arrows'] = []
            timepoint_results[source_tp]['aggregated_arrow'] = None
            if 'seurat_clusters' in source_cells.columns:
                timepoint_results[source_tp]['cluster_arrows'] = []
                timepoint_results[source_tp]['aggregated_arrow'] = None
            else:
                timepoint_results[source_tp]['cluster_arrows'] = []
                timepoint_results[source_tp]['aggregated_arrow'] = None
    
    # Plot single-cell arrows PDF
    output_file_original = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots_differential_arrow_lengths_equal_cells_using_PCs.pdf")
    with PdfPages(output_file_original) as pdf_original:
        num_timepoints = len(cohort_timepoints)
        fig_orig, axes_orig = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        if num_timepoints == 1:
            axes_orig = [axes_orig]
        
        for i, source_tp in enumerate(cohort_timepoints):
            ax = axes_orig[i]
            res = timepoint_results[source_tp]
            sampled_source_cells = res['sampled_source']
            sampled_gray_cells = res['sampled_gray']
            
            gray_x = sampled_gray_cells['UMAP_2'].values
            gray_y = sampled_gray_cells['UMAP_1'].values
            source_x = sampled_source_cells['UMAP_2'].values
            source_y = sampled_source_cells['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)
            
            single_cell_displacements = res['single_cell_displacements']
            single_cell_arrow_lengths = res['single_cell_arrow_lengths']
            
            if len(single_cell_displacements) > 0:
                source_coords_sampled = sampled_source_cells[['UMAP_2','UMAP_1']].values
                sampled_norms = np.linalg.norm(single_cell_displacements, axis=1)
                with np.errstate(divide='ignore', invalid='ignore'):
                    unit_vectors = single_cell_displacements / sampled_norms[:, np.newaxis]
                    unit_vectors[~np.isfinite(unit_vectors)] = 0
                scaled_vectors = unit_vectors * single_cell_arrow_lengths[:, np.newaxis]
                
                for j in range(len(source_coords_sampled)):
                    sx, sy = source_coords_sampled[j]
                    dx, dy = scaled_vectors[j]
                    ax.arrow(sx, sy, dx, dy,
                             color='black', alpha=0.7, head_width=0.05, head_length=0.05,
                             length_includes_head=True, linewidth=0.5)
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_orig.suptitle(f"{subpop_name} - {cohort_name}", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_original.savefig(fig_orig)
        plt.close(fig_orig)
    
    # Plot aggregated cluster arrows in a new PDF
    output_file_cluster = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots_aggregated_cluster_arrows_equal_cells_using_PCs.pdf")
    color_mapping_file = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/Publication_Material/T_Cell_cluster_colors.csv"  # Update with actual path
    color_mapping_df = pd.read_csv(color_mapping_file)
    cluster_colors_map = dict(zip(color_mapping_df['Cluster'], color_mapping_df['Color']))
    with PdfPages(output_file_cluster) as pdf_cluster:
        num_timepoints = len(cohort_timepoints)
        fig_clust, axes_clust = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        if num_timepoints == 1:
            axes_clust = [axes_clust]
        
        for i, source_tp in enumerate(cohort_timepoints):
            ax = axes_clust[i]
            res = timepoint_results[source_tp]
            
            gray_x = res['sampled_gray']['UMAP_2'].values
            gray_y = res['sampled_gray']['UMAP_1'].values
            source_x = res['sampled_source']['UMAP_2'].values
            source_y = res['sampled_source']['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)
            
            cluster_arrows = res.get('cluster_arrows', [])
            
            # Compute source cluster proportions for line widths
            source_cells = full_data[source_tp]['source']
            source_cluster_counts = source_cells['seurat_clusters'].value_counts()
            total_source_cells_current = len(source_cells)
            
            # Now each entry in cluster_arrows is (clust, cx, cy, cdx, cdy)
            for (clust, cx, cy, cdx, cdy) in cluster_arrows:
                # Get the color for this cluster, default to black if not found
                arrow_color = cluster_colors_map.get(clust, 'black')
                proportion = source_cluster_counts.get(clust, 0) / total_source_cells_current
                line_width = 1.0 + 8.0 * proportion  # adjust scaling as needed

                # Draw boundary arrow
                boundary_color = 'black'
                ax.arrow(cx, cy, cdx, cdy,
                         color=boundary_color,
                         alpha=1,
                         head_width=0.2 + 0.5 * proportion,
                         head_length=0.1 + 0.1 * proportion,
                         length_includes_head=True,
                         linewidth=line_width + 2)  # Thicker for boundary

                # Draw main arrow on top
                ax.arrow(cx, cy, cdx, cdy,
                         color=arrow_color, alpha=1,
                         head_width=0.2+0.5*proportion, head_length=0.1+0.1*proportion,   # bigger arrowhead
                         length_includes_head=True, linewidth=line_width)
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_clust.suptitle(f"{subpop_name} - {cohort_name} (Cluster-Level Arrows)", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_cluster.savefig(fig_clust)
        plt.close(fig_clust)

    # Plot single aggregated arrow in a new PDF
    output_file_single_agg = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots_aggregated_single_arrow_equal_cells_using_PCs.pdf")
    with PdfPages(output_file_single_agg) as pdf_agg:
        num_timepoints = len(cohort_timepoints)
        fig_agg, axes_agg = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        if num_timepoints == 1:
            axes_agg = [axes_agg]

        for i, source_tp in enumerate(cohort_timepoints):
            ax = axes_agg[i]
            res = timepoint_results[source_tp]

            gray_x = res['sampled_gray']['UMAP_2'].values
            gray_y = res['sampled_gray']['UMAP_1'].values
            source_x = res['sampled_source']['UMAP_2'].values
            source_y = res['sampled_source']['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)

            aggregated_arrow = res.get('aggregated_arrow', None)
            if aggregated_arrow is not None:
                global_cx, global_cy, total_dx, total_dy = aggregated_arrow
                # Just draw one large arrow
                # Use a fixed linewidth and arrowhead for clarity
                ax.arrow(global_cx, global_cy, total_dx, total_dy,
                         color='black', alpha=0.9,
                         head_width=0.3, head_length=0.3,
                         length_includes_head=True, linewidth=2.0)

            ax.set_title(f"Timepoint: {source_tp} (Single Aggregated Arrow)")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])

        fig_agg.suptitle(f"{subpop_name} - {cohort_name} (Single Aggregated Arrow)", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_agg.savefig(fig_agg)
        plt.close(fig_agg)

# ---------------------------
# Main Execution Loop
# ---------------------------

for subpop_name in subpopulations:
    for cohort_name in cohorts:
        print(f"Processing {subpop_name} - {cohort_name}")
        input_subfolder = os.path.join(base_input_dir, subpop_name, cohort_name)
        if not os.path.exists(input_subfolder):
            print(f"Input subfolder does not exist: {input_subfolder}. Skipping.")
            continue
        optimal_transport_visualization(subpop_name, cohort_name)

print("Visualization generation completed.")


Processing Activated_CD4 - control
Processing Activated_CD4 - short_term
Processing Activated_CD4 - long_term
Processing Effector_CD8 - control
Processing Effector_CD8 - short_term
Processing Effector_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Effector_Memory_CD8 - control
Processing Effector_Memory_CD8 - short_term
Processing Effector_Memory_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Exhausted_T - control
Processing Exhausted_T - short_term
Processing Exhausted_T - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Gamma_Delta_T - control
Processing Gamma_Delta_T - short_term
Processing Gamma_Delta_T - long_term
Processing Active_CD4 - control
Processing Active_CD4 - short_term
Processing Active_CD4 - long_term
Processing Naive_CD4 - control
Processing Naive_CD4 - short_term
Processing Naive_CD4 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Memory_CD4 - control
Processing Memory_CD4 - short_term
Processing Memory_CD4 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Memory_CD8 - control
Processing Memory_CD8 - short_term
Processing Memory_CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Anergic_CD8 - control
Processing Anergic_CD8 - short_term
Processing Anergic_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Naive_CD8 - control
Processing Naive_CD8 - short_term
Processing Naive_CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Hyperactivated_CD8 - control
Missing files for Pre. Skipping this timepoint.
Missing files for C1. Skipping this timepoint.
Missing files for C2. Skipping this timepoint.
No valid timepoints with data.
Processing Hyperactivated_CD8 - short_term
Missing files for Pre. Skipping this timepoint.
Missing files for C1. Skipping this timepoint.
Missing files for C2. Skipping this timepoint.
Missing files for C4. Skipping this timepoint.
No valid timepoints with data.
Processing Hyperactivated_CD8 - long_term
Missing files for Pre. Skipping this timepoint.
Missing files for C1. Skipping this timepoint.
Missing files for C2. Skipping this timepoint.
Missing files for C4. Skipping this timepoint.
Missing files for C6. Skipping this timepoint.
Missing files for C9. Skipping this timepoint.
Missing files for C18. Skipping this timepoint.
Missing files for C36. Skipping this timepoint.
No valid timepoints with data.
Processing Proliferating_Effector - control
Processing Proliferating_Eff

  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Visualization generation completed.


In [1]:
# increasing the arrow lengths to actual displacement

In [2]:
# remove previous, if this works

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ot
from matplotlib.backends.backend_pdf import PdfPages

# ---------------------------
# Configuration Parameters
# ---------------------------

# Base directory where the CSV files are stored
base_input_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/new/temp"  # Replace with your actual path

# Base directory to save the plots
base_output_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/new/temp"  # Replace with your desired output path

# Subpopulations (should match the names used in R)
subpopulations = [
    "Activated_CD4",
    "Effector_CD8",
    "Effector_Memory_CD8",
    "Exhausted_T",
    "Gamma_Delta_T",
    "Active_CD4",
    "Naive_CD4",
    "Memory_CD4",
    "Memory_CD8",
    "Anergic_CD8",
    "Naive_CD8",
    "Hyperactivated_CD8",
    "Proliferating_Effector",
    "CD8"
]

# Cohorts
cohorts = ["control", "short_term", "long_term"]

# Ordered timepoints
all_timepoints = ["Pre", "C1", "C2", "C4", "C6", "C9", "C18", "C36"]

# Color mapping for cohorts
cohort_colors = {
    'control': 'yellow',
    'short_term': 'blue',
    'long_term': 'red'
}

# Desired minimum and maximum arrow lengths for visualization
# min_arrow_length = 0.3  # Adjust as needed
# max_arrow_length = 2    # Adjust as needed

# Number of cells to sample for single-cell plotting (to reduce overcrowding)
max_plot_cells = None  # will determine at runtime

def optimal_transport_visualization(subpop_name, cohort_name):
    input_folder = os.path.join(base_input_dir, subpop_name, cohort_name)
    output_folder = os.path.join(base_output_dir, subpop_name)
    os.makedirs(output_folder, exist_ok=True)
    
    if not os.path.exists(input_folder):
        print(f"Input folder does not exist: {input_folder}")
        return
    
    timepoint_folders = sorted([f for f in os.listdir(input_folder) if f.startswith("Timepoint_")])
    available_timepoints = [tp.split("_")[1] for tp in timepoint_folders]
    cohort_timepoints = [tp for tp in all_timepoints if tp in available_timepoints]
    
    if not cohort_timepoints:
        print(f"No timepoints available for {subpop_name} in {cohort_name} cohort.")
        return

    # We'll store coordinates and also find global axis limits
    all_x_coords = []
    all_y_coords = []
    
    cell_counts_source = []
    cell_counts_gray = []
    full_data = {}  # store full data for each timepoint
    
    for tp in cohort_timepoints:
        source_folder = os.path.join(input_folder, f"Timepoint_{tp}")
        source_cells_file = os.path.join(source_folder, 'source_cells_new.csv')
        gray_cells_file = os.path.join(source_folder, 'gray_cells_new.csv')
        
        if not (os.path.exists(source_cells_file) and os.path.exists(gray_cells_file)):
            print(f"Missing files for {tp}. Skipping this timepoint.")
            continue
        
        source_cells = pd.read_csv(source_cells_file)
        gray_cells = pd.read_csv(gray_cells_file)
        
        if source_cells.empty:
            print(f"No source cells for {tp}. Skipping.")
            continue
        
        # Store full data
        full_data[tp] = {
            'source': source_cells,
            'gray': gray_cells
        }
        
        # For axis limits
        all_x_coords.extend(gray_cells['UMAP_2'])
        all_x_coords.extend(source_cells['UMAP_2'])
        all_y_coords.extend(gray_cells['UMAP_1'])
        all_y_coords.extend(source_cells['UMAP_1'])
        
        cell_counts_source.append(len(source_cells))
        cell_counts_gray.append(len(gray_cells))
    
    if not full_data:
        print("No valid timepoints with data.")
        return
    
    x_min, x_max = min(all_x_coords), max(all_x_coords)
    y_min, y_max = min(all_y_coords), max(all_y_coords)
    
    # Determine number of cells to plot for single-cell arrows (downsample for plotting only)
    min_source_cells = min(cell_counts_source) if cell_counts_source else 0
    min_gray_cells = min(cell_counts_gray) if cell_counts_gray else 0
    if min_source_cells == 0 or min_gray_cells == 0:
        print("Insufficient cells. Skipping visualization.")
        return
    
    timepoint_results = {}
    
    # Compute OT from each timepoint to the next (except the last one)
    for i, source_tp in enumerate(cohort_timepoints):
        source_data = full_data[source_tp]
        source_cells = source_data['source']
        
        # Downsample for plotting single-cell arrows only
        sampled_source_cells = source_cells.sample(n=min_source_cells, random_state=42)
        sampled_gray_cells = full_data[source_tp]['gray'].sample(n=min_gray_cells, random_state=42)
        
        timepoint_results[source_tp] = {
            'sampled_source': sampled_source_cells,
            'sampled_gray': sampled_gray_cells
        }
        
        if i < len(cohort_timepoints) - 1:
            target_tp = cohort_timepoints[i+1]
            target_data = full_data[target_tp]
            target_cells = target_data['source']
            
            if target_cells.empty:
                continue
            
            pc_cols = [c for c in source_cells.columns if c.startswith('PC_')]
            full_source_coords_pc = source_cells[pc_cols].values
            full_target_coords_pc = target_cells[pc_cols].values
            
            # Also store UMAP coordinates
            full_source_coords_umap = source_cells[['UMAP_2', 'UMAP_1']].values
            full_target_coords_umap = target_cells[['UMAP_2', 'UMAP_1']].values
            
            # Uniform distributions
            a = np.ones((full_source_coords_pc.shape[0],)) / full_source_coords_pc.shape[0]
            b = np.ones((full_target_coords_pc.shape[0],)) / full_target_coords_pc.shape[0]
            
            # Compute cost matrix
            cost_matrix = ot.dist(full_source_coords_pc, full_target_coords_pc, metric='euclidean')
            
            try:
                transport_plan = ot.emd(a, b, cost_matrix, numItermax=100000)
            except Exception as e:
                print(f"OT computation failed for {source_tp} to {target_tp}: {e}")
                continue
            
            full_target_indices = np.argmax(transport_plan, axis=1)
            displacement_vectors_full = full_target_coords_umap[full_target_indices] - full_source_coords_umap
            
            # Compute norms
            norms_full = np.linalg.norm(displacement_vectors_full, axis=1)
            arrow_lengths_full = norms_full
            # min_norm_full = np.min(norms_full)
            # max_norm_full = np.max(norms_full)
            
            # if max_norm_full - min_norm_full > 0:
            #     arrow_lengths_full = ((norms_full - min_norm_full) / (max_norm_full - min_norm_full)) * (max_arrow_length - min_arrow_length) + min_arrow_length
            # else:
            #     arrow_lengths_full = np.full_like(norms_full, min_arrow_length)
            
            # For single-cell plotting
            sampled_source_indices = source_cells.index.get_indexer_for(sampled_source_cells.index)
            sampled_displacements = displacement_vectors_full[sampled_source_indices, :]
            sampled_norms = norms_full[sampled_source_indices]
            arrow_lengths_sampled = sampled_norms
            
            # min_norm_sampled = np.min(sampled_norms)
            # max_norm_sampled = np.max(sampled_norms)
            # if max_norm_sampled - min_norm_sampled > 0:
            #     arrow_lengths_sampled = ((sampled_norms - min_norm_sampled) / (max_norm_sampled - min_norm_sampled)) * (max_arrow_length - min_arrow_length) + min_arrow_length
            # else:
            #     arrow_lengths_sampled = np.full_like(sampled_norms, min_arrow_length)
            
            timepoint_results[source_tp]['single_cell_displacements'] = sampled_displacements
            timepoint_results[source_tp]['single_cell_arrow_lengths'] = arrow_lengths_sampled
            
            clusters_of_interest = {1, 2, 3, 8, 10, 12, 14}

            # Compute cluster-level arrows using full data
            if 'seurat_clusters' in source_cells.columns:
                source_cells['seurat_clusters'] = source_cells['seurat_clusters'].replace({16: 14, 17: 14})
                source_clusters_full = source_cells['seurat_clusters'].values
                df_cluster_full = pd.DataFrame({
                    'cluster': source_clusters_full,
                    'sx': full_source_coords_umap[:,0],
                    'sy': full_source_coords_umap[:,1],
                    'dx': displacement_vectors_full[:,0],
                    'dy': displacement_vectors_full[:,1],
                    'norm': norms_full
                })
                
                gray_cells = full_data[source_tp]['gray']
                all_cells_combined = pd.concat([source_cells, gray_cells])
                all_cells_combined['seurat_clusters'] = all_cells_combined['seurat_clusters'].replace({16:14,17:14})
                
                all_clusters = all_cells_combined['seurat_clusters'].values
                all_coords_umap = all_cells_combined[['UMAP_2', 'UMAP_1']].values
                
                df_all = pd.DataFrame({
                    'cluster': all_clusters,
                    'sx': all_coords_umap[:,0],
                    'sy': all_coords_umap[:,1]
                })
                
                # Compute centroids from all cells (source + gray)
                all_group = df_all.groupby('cluster')
                centroids = all_group[['sx','sy']].median()
                
                # Mean displacement from source cells only
                group = df_cluster_full.groupby('cluster')
                mean_disp = group[['dx','dy']].mean()
            
                cluster_norms = np.sqrt(mean_disp['dx']**2 + mean_disp['dy']**2)
                cluster_arrow_lengths = cluster_norms
                # cn_min = cluster_norms.min()
                # cn_max = cluster_norms.max()
                
                # if cn_max - cn_min > 0:
                #     cluster_arrow_lengths = ((cluster_norms - cn_min) / (cn_max - cn_min)) * (max_arrow_length - min_arrow_length) + min_arrow_length
                # else:
                #     cluster_arrow_lengths = pd.Series(np.full_like(cluster_norms, min_arrow_length), index=cluster_norms.index)
                
                cluster_arrows = []
                common_clusters = centroids.index.intersection(mean_disp.index)
                for clust in common_clusters:
                    if clust in clusters_of_interest:
                        cx, cy = centroids.loc[clust, ['sx','sy']]
                        cdx, cdy = mean_disp.loc[clust, ['dx','dy']]
                        cnorm = np.sqrt(cdx**2 + cdy**2)
                        if cnorm > 0:
                            cdx = cdx / cnorm
                            cdy = cdy / cnorm
                        else:
                            cdx, cdy = 0, 0
                        
                        try:
                            length = cluster_arrow_lengths.loc[clust]
                        except AttributeError:
                            if isinstance(cluster_arrow_lengths, np.ndarray):
                                cluster_arrow_lengths = pd.Series(cluster_arrow_lengths, index=cluster_norms.index)
                            try:
                                length = cluster_arrow_lengths.loc[clust]
                            except Exception as e:
                                print(f"Failed to access cluster_arrow_lengths for cluster {clust}: {e}")
                                length = min_arrow_length
            
                        cdx *= length
                        cdy *= length
                        # Store cluster ID alongside arrow coordinates for line width calculation later
                        cluster_arrows.append((clust, cx, cy, cdx, cdy))
                
                timepoint_results[source_tp]['cluster_arrows'] = cluster_arrows

                if len(cluster_arrows) > 0:
                    # Instead of global centroid of all cells, use the median of source cells only:
                    source_median_x = source_cells['UMAP_2'].median()
                    source_median_y = source_cells['UMAP_1'].median()

                    # Sum displacement vectors of all cluster arrows
                    total_dx = sum([arrow[3] for arrow in cluster_arrows])
                    total_dy = sum([arrow[4] for arrow in cluster_arrows])

                    # Store the aggregated arrow
                    timepoint_results[source_tp]['aggregated_arrow'] = (source_median_x, source_median_y, total_dx, total_dy)
                else:
                    # If no cluster arrows, no aggregated arrow
                    timepoint_results[source_tp]['aggregated_arrow'] = None
            else:
                timepoint_results[source_tp]['cluster_arrows'] = []
                timepoint_results[source_tp]['aggregated_arrow'] = None
                print(f"No 'seurat_clusters' column in {source_tp} source cells. Cannot compute cluster arrows.")

            
        else:
            # Last timepoint has no next timepoint
            timepoint_results[source_tp]['single_cell_displacements'] = np.array([])
            timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
            timepoint_results[source_tp]['cluster_arrows'] = []
            timepoint_results[source_tp]['aggregated_arrow'] = None
            if 'seurat_clusters' in source_cells.columns:
                timepoint_results[source_tp]['cluster_arrows'] = []
                timepoint_results[source_tp]['aggregated_arrow'] = None
            else:
                timepoint_results[source_tp]['cluster_arrows'] = []
                timepoint_results[source_tp]['aggregated_arrow'] = None
    
    # Plot single-cell arrows PDF
    output_file_original = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots_differential_arrow_lengths_equal_cells_using_PCs.pdf")
    with PdfPages(output_file_original) as pdf_original:
        num_timepoints = len(cohort_timepoints)
        fig_orig, axes_orig = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        if num_timepoints == 1:
            axes_orig = [axes_orig]
        
        for i, source_tp in enumerate(cohort_timepoints):
            ax = axes_orig[i]
            res = timepoint_results[source_tp]
            sampled_source_cells = res['sampled_source']
            sampled_gray_cells = res['sampled_gray']
            
            gray_x = sampled_gray_cells['UMAP_2'].values
            gray_y = sampled_gray_cells['UMAP_1'].values
            source_x = sampled_source_cells['UMAP_2'].values
            source_y = sampled_source_cells['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)
            
            single_cell_displacements = res['single_cell_displacements']
            single_cell_arrow_lengths = res['single_cell_arrow_lengths']
            
            if len(single_cell_displacements) > 0:
                source_coords_sampled = sampled_source_cells[['UMAP_2','UMAP_1']].values
                sampled_norms = np.linalg.norm(single_cell_displacements, axis=1)
                with np.errstate(divide='ignore', invalid='ignore'):
                    unit_vectors = single_cell_displacements / sampled_norms[:, np.newaxis]
                    unit_vectors[~np.isfinite(unit_vectors)] = 0
                scaled_vectors = unit_vectors * single_cell_arrow_lengths[:, np.newaxis]
                
                for j in range(len(source_coords_sampled)):
                    sx, sy = source_coords_sampled[j]
                    dx, dy = scaled_vectors[j]
                    ax.arrow(sx, sy, dx, dy,
                             color='black', alpha=0.7, head_width=0.05, head_length=0.05,
                             length_includes_head=True, linewidth=0.5)
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_orig.suptitle(f"{subpop_name} - {cohort_name}", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_original.savefig(fig_orig)
        plt.close(fig_orig)
    
    # Plot aggregated cluster arrows in a new PDF
    output_file_cluster = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots_aggregated_cluster_arrows_equal_cells_using_PCs.pdf")
    color_mapping_file = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/Publication_Material/T_Cell_cluster_colors.csv"  # Update with actual path
    color_mapping_df = pd.read_csv(color_mapping_file)
    cluster_colors_map = dict(zip(color_mapping_df['Cluster'], color_mapping_df['Color']))
    with PdfPages(output_file_cluster) as pdf_cluster:
        num_timepoints = len(cohort_timepoints)
        fig_clust, axes_clust = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        if num_timepoints == 1:
            axes_clust = [axes_clust]
        
        for i, source_tp in enumerate(cohort_timepoints):
            ax = axes_clust[i]
            res = timepoint_results[source_tp]
            
            gray_x = res['sampled_gray']['UMAP_2'].values
            gray_y = res['sampled_gray']['UMAP_1'].values
            source_x = res['sampled_source']['UMAP_2'].values
            source_y = res['sampled_source']['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)
            
            cluster_arrows = res.get('cluster_arrows', [])
            
            # Compute source cluster proportions for line widths
            source_cells = full_data[source_tp]['source']
            source_cluster_counts = source_cells['seurat_clusters'].value_counts()
            total_source_cells_current = len(source_cells)
            
            # Now each entry in cluster_arrows is (clust, cx, cy, cdx, cdy)
            for (clust, cx, cy, cdx, cdy) in cluster_arrows:
                # Get the color for this cluster, default to black if not found
                arrow_color = cluster_colors_map.get(clust, 'black')
                proportion = source_cluster_counts.get(clust, 0) / total_source_cells_current
                line_width = 1.0 + 8.0 * proportion  # adjust scaling as needed

                # Draw boundary arrow
                boundary_color = 'black'
                ax.arrow(cx, cy, cdx, cdy,
                         color=boundary_color,
                         alpha=1,
                         head_width=0.2 + 0.5 * proportion,
                         head_length=0.1 + 0.1 * proportion,
                         length_includes_head=True,
                         linewidth=line_width + 2)  # Thicker for boundary

                # Draw main arrow on top
                ax.arrow(cx, cy, cdx, cdy,
                         color=arrow_color, alpha=1,
                         head_width=0.2+0.5*proportion, head_length=0.1+0.1*proportion,   # bigger arrowhead
                         length_includes_head=True, linewidth=line_width)
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_clust.suptitle(f"{subpop_name} - {cohort_name} (Cluster-Level Arrows)", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_cluster.savefig(fig_clust)
        plt.close(fig_clust)

    # Plot single aggregated arrow in a new PDF
    output_file_single_agg = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots_aggregated_single_arrow_equal_cells_using_PCs.pdf")
    with PdfPages(output_file_single_agg) as pdf_agg:
        num_timepoints = len(cohort_timepoints)
        fig_agg, axes_agg = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        if num_timepoints == 1:
            axes_agg = [axes_agg]

        for i, source_tp in enumerate(cohort_timepoints):
            ax = axes_agg[i]
            res = timepoint_results[source_tp]

            gray_x = res['sampled_gray']['UMAP_2'].values
            gray_y = res['sampled_gray']['UMAP_1'].values
            source_x = res['sampled_source']['UMAP_2'].values
            source_y = res['sampled_source']['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)

            aggregated_arrow = res.get('aggregated_arrow', None)
            if aggregated_arrow is not None:
                global_cx, global_cy, total_dx, total_dy = aggregated_arrow
                # Just draw one large arrow
                # Use a fixed linewidth and arrowhead for clarity
                ax.arrow(global_cx, global_cy, total_dx, total_dy,
                         color='black', alpha=0.9,
                         head_width=0.3, head_length=0.3,
                         length_includes_head=True, linewidth=2.0)

            ax.set_title(f"Timepoint: {source_tp} (Single Aggregated Arrow)")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])

        fig_agg.suptitle(f"{subpop_name} - {cohort_name} (Single Aggregated Arrow)", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_agg.savefig(fig_agg)
        plt.close(fig_agg)

# ---------------------------
# Main Execution Loop
# ---------------------------

for subpop_name in subpopulations:
    for cohort_name in cohorts:
        print(f"Processing {subpop_name} - {cohort_name}")
        input_subfolder = os.path.join(base_input_dir, subpop_name, cohort_name)
        if not os.path.exists(input_subfolder):
            print(f"Input subfolder does not exist: {input_subfolder}. Skipping.")
            continue
        optimal_transport_visualization(subpop_name, cohort_name)

print("Visualization generation completed.")


Processing Activated_CD4 - control
Processing Activated_CD4 - short_term
Processing Activated_CD4 - long_term
Processing Effector_CD8 - control
Processing Effector_CD8 - short_term
Processing Effector_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Effector_Memory_CD8 - control
Processing Effector_Memory_CD8 - short_term
Processing Effector_Memory_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Exhausted_T - control
Processing Exhausted_T - short_term
Processing Exhausted_T - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Gamma_Delta_T - control
Processing Gamma_Delta_T - short_term
Processing Gamma_Delta_T - long_term
Processing Active_CD4 - control
Processing Active_CD4 - short_term
Processing Active_CD4 - long_term
Processing Naive_CD4 - control
Processing Naive_CD4 - short_term
Processing Naive_CD4 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Memory_CD4 - control
Processing Memory_CD4 - short_term
Processing Memory_CD4 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Memory_CD8 - control
Processing Memory_CD8 - short_term
Processing Memory_CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Anergic_CD8 - control
Processing Anergic_CD8 - short_term
Processing Anergic_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Naive_CD8 - control
Processing Naive_CD8 - short_term
Processing Naive_CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Hyperactivated_CD8 - control
Missing files for Pre. Skipping this timepoint.
Missing files for C1. Skipping this timepoint.
Missing files for C2. Skipping this timepoint.
No valid timepoints with data.
Processing Hyperactivated_CD8 - short_term
Missing files for Pre. Skipping this timepoint.
Missing files for C1. Skipping this timepoint.
Missing files for C2. Skipping this timepoint.
Missing files for C4. Skipping this timepoint.
No valid timepoints with data.
Processing Hyperactivated_CD8 - long_term
Missing files for Pre. Skipping this timepoint.
Missing files for C1. Skipping this timepoint.
Missing files for C2. Skipping this timepoint.
Missing files for C4. Skipping this timepoint.
Missing files for C6. Skipping this timepoint.
Missing files for C9. Skipping this timepoint.
Missing files for C18. Skipping this timepoint.
Missing files for C36. Skipping this timepoint.
No valid timepoints with data.
Processing Proliferating_Effector - control
Processing Proliferating_Eff

  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Visualization generation completed.


In [3]:
# including new visualization (target distribution)

In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ot
from matplotlib.backends.backend_pdf import PdfPages

# ---------------------------
# Configuration Parameters
# ---------------------------

base_input_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/final_nomenclature"
base_output_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/final_nomenclature"

subpopulations = [
    "Activated_CD4",
    "Effector_CD8",
    "Memory_Precursor_Effector_CD8",
    "Exhausted_T",
    "Gamma_Delta_T",
    "Active_CD4",
    "Naive_CD4",
    "Memory_CD4",
    "Central_Memory_CD8",
    "Stem_Like_CD8",
    "Effector_Memory_CD8",
    "Proliferating_Effector",
    "CD8"
]

cohorts = ["control", "short_term", "long_term"]

all_timepoints = ["Pre", "C1", "C2", "C4", "C6", "C9", "C18", "C36"]

cohort_colors = {
    'control': 'yellow',
    'short_term': 'blue',
    'long_term': 'red'
}

clusters_of_interest = [1, 2, 3, 8, 10, 12, 14]

color_mapping_file = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/Publication_Material/T_Cell_cluster_colors.csv"
color_mapping_df = pd.read_csv(color_mapping_file)
cluster_colors_map = dict(zip(color_mapping_df['Cluster'], color_mapping_df['Color']))
cluster_celltype_map = dict(zip(color_mapping_df['Cluster'], color_mapping_df['Celltype']))

def optimal_transport_visualization(subpop_name, cohort_name):
    input_folder = os.path.join(base_input_dir, subpop_name, cohort_name)
    output_folder = os.path.join(base_output_dir, subpop_name)
    os.makedirs(output_folder, exist_ok=True)
    
    if not os.path.exists(input_folder):
        print(f"Input folder does not exist: {input_folder}")
        return None
    
    timepoint_folders = sorted([f for f in os.listdir(input_folder) if f.startswith("Timepoint_")])
    available_timepoints = [tp.split("_")[1] for tp in timepoint_folders]
    cohort_timepoints = [tp for tp in all_timepoints if tp in available_timepoints]
    
    if not cohort_timepoints:
        print(f"No timepoints available for {subpop_name} in {cohort_name} cohort.")
        return None

    all_x_coords = []
    all_y_coords = []
    
    cell_counts_source = []
    cell_counts_gray = []
    full_data = {}
    
    for tp in cohort_timepoints:
        source_folder = os.path.join(input_folder, f"Timepoint_{tp}")
        source_cells_file = os.path.join(source_folder, 'source_cells_new.csv')
        gray_cells_file = os.path.join(source_folder, 'gray_cells_new.csv')
        
        if not (os.path.exists(source_cells_file) and os.path.exists(gray_cells_file)):
            print(f"Missing files for {tp}. Skipping this timepoint.")
            continue
        
        source_cells = pd.read_csv(source_cells_file)
        gray_cells = pd.read_csv(gray_cells_file)
        
        if source_cells.empty:
            print(f"No source cells for {tp}. Skipping.")
            continue
        
        full_data[tp] = {
            'source': source_cells,
            'gray': gray_cells
        }
        
        all_x_coords.extend(gray_cells['UMAP_2'])
        all_x_coords.extend(source_cells['UMAP_2'])
        all_y_coords.extend(gray_cells['UMAP_1'])
        all_y_coords.extend(source_cells['UMAP_1'])
        
        cell_counts_source.append(len(source_cells))
        cell_counts_gray.append(len(gray_cells))
    
    if not full_data:
        print("No valid timepoints with data.")
        return None
    
    x_min, x_max = min(all_x_coords), max(all_x_coords)
    y_min, y_max = min(all_y_coords), max(all_y_coords)
    
    min_source_cells = min(cell_counts_source) if cell_counts_source else 0
    min_gray_cells = min(cell_counts_gray) if cell_counts_gray else 0
    if min_source_cells == 0 or min_gray_cells == 0:
        print("Insufficient cells. Skipping visualization.")
        return None
    
    timepoint_results = {}
    distributions_coi = {tp: {cohort_name: {c: {} for c in clusters_of_interest}} for tp in cohort_timepoints}
    
    for i, source_tp in enumerate(cohort_timepoints):
        source_data = full_data[source_tp]
        source_cells = source_data['source']
        
        sampled_source_cells = source_cells.sample(n=min_source_cells, random_state=42)
        sampled_gray_cells = full_data[source_tp]['gray'].sample(n=min_gray_cells, random_state=42)
        
        timepoint_results[source_tp] = {
            'sampled_source': sampled_source_cells,
            'sampled_gray': sampled_gray_cells
        }
        
        if i < len(cohort_timepoints) - 1:
            target_tp = cohort_timepoints[i+1]
            target_data = full_data[target_tp]
            target_cells = target_data['source']
            
            if target_cells.empty:
                timepoint_results[source_tp]['single_cell_displacements'] = np.array([])
                timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
                timepoint_results[source_tp]['cluster_arrows'] = []
                timepoint_results[source_tp]['aggregated_arrow'] = None
            else:
                pc_cols = [c for c in source_cells.columns if c.startswith('PC_')]
                full_source_coords_pc = source_cells[pc_cols].values
                full_target_coords_pc = target_cells[pc_cols].values
                
                full_source_coords_umap = source_cells[['UMAP_2', 'UMAP_1']].values
                full_target_coords_umap = target_cells[['UMAP_2', 'UMAP_1']].values
                
                a = np.ones((full_source_coords_pc.shape[0],)) / full_source_coords_pc.shape[0]
                b = np.ones((full_target_coords_pc.shape[0],)) / full_target_coords_pc.shape[0]
                
                cost_matrix = ot.dist(full_source_coords_pc, full_target_coords_pc, metric='euclidean')
                
                try:
                    transport_plan = ot.emd(a, b, cost_matrix, numItermax=100000)
                except Exception as e:
                    print(f"OT computation failed for {source_tp} to {target_tp}: {e}")
                    timepoint_results[source_tp]['single_cell_displacements'] = np.array([])
                    timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
                    timepoint_results[source_tp]['cluster_arrows'] = []
                    timepoint_results[source_tp]['aggregated_arrow'] = None
                    # no distributions_coi update in case of failure
                    continue
                
                full_target_indices = np.argmax(transport_plan, axis=1)
                displacement_vectors_full = full_target_coords_umap[full_target_indices] - full_source_coords_umap
                
                norms_full = np.linalg.norm(displacement_vectors_full, axis=1)
                arrow_lengths_full = norms_full
                
                sampled_source_indices = source_cells.index.get_indexer_for(sampled_source_cells.index)
                sampled_displacements = displacement_vectors_full[sampled_source_indices, :]
                sampled_norms = norms_full[sampled_source_indices]
                
                arrow_lengths_sampled = sampled_norms
                
                timepoint_results[source_tp]['single_cell_displacements'] = sampled_displacements
                timepoint_results[source_tp]['single_cell_arrow_lengths'] = arrow_lengths_sampled
                
                if 'seurat_clusters' in source_cells.columns and 'seurat_clusters' in target_cells.columns:
                    source_cells['seurat_clusters'] = source_cells['seurat_clusters'].replace({16:14,17:14})
                    target_cells['seurat_clusters'] = target_cells['seurat_clusters'].replace({16:14,17:14})
                    source_cells['seurat_clusters'] = source_cells['seurat_clusters'].replace({9:6,18:6})
                    target_cells['seurat_clusters'] = target_cells['seurat_clusters'].replace({9:6,18:6})
                    
                    # Cluster-level arrows
                    source_clusters_full = source_cells['seurat_clusters'].values
                    df_cluster_full = pd.DataFrame({
                        'cluster': source_clusters_full,
                        'sx': full_source_coords_umap[:,0],
                        'sy': full_source_coords_umap[:,1],
                        'dx': displacement_vectors_full[:,0],
                        'dy': displacement_vectors_full[:,1],
                        'norm': norms_full
                    })
                    
                    gray_cells = full_data[source_tp]['gray']
                    all_cells_combined = pd.concat([source_cells, gray_cells])
                    all_cells_combined['seurat_clusters'] = all_cells_combined['seurat_clusters'].replace({16:14,17:14})
                    all_cells_combined['seurat_clusters'] = all_cells_combined['seurat_clusters'].replace({9:6,18:6})
                    
                    all_clusters = all_cells_combined['seurat_clusters'].values
                    all_coords_umap = all_cells_combined[['UMAP_2', 'UMAP_1']].values
                    
                    df_all = pd.DataFrame({
                        'cluster': all_clusters,
                        'sx': all_coords_umap[:,0],
                        'sy': all_coords_umap[:,1]
                    })
                    
                    all_group = df_all.groupby('cluster')
                    centroids = all_group[['sx','sy']].median()
                    
                    group = df_cluster_full.groupby('cluster')
                    mean_disp = group[['dx','dy']].mean()
                
                    cluster_norms = np.sqrt(mean_disp['dx']**2 + mean_disp['dy']**2)
                    cluster_arrow_lengths = cluster_norms
                    
                    cluster_arrows = []
                    common_clusters = centroids.index.intersection(mean_disp.index)
                    for clust in common_clusters:
                        if clust in clusters_of_interest:
                            cx, cy = centroids.loc[clust, ['sx','sy']]
                            cdx, cdy = mean_disp.loc[clust, ['dx','dy']]
                            cnorm = np.sqrt(cdx**2 + cdy**2)
                            if cnorm > 0:
                                cdx /= cnorm
                                cdy /= cnorm
                            
                            length = cluster_arrow_lengths.loc[clust]
                            cdx *= length
                            cdy *= length
                            cluster_arrows.append((clust, cx, cy, cdx, cdy))
                    
                    timepoint_results[source_tp]['cluster_arrows'] = cluster_arrows
    
                    if len(cluster_arrows) > 0:
                        source_median_x = source_cells['UMAP_2'].median()
                        source_median_y = source_cells['UMAP_1'].median()
                        total_dx = sum([arrow[3] for arrow in cluster_arrows])
                        total_dy = sum([arrow[4] for arrow in cluster_arrows])
                        timepoint_results[source_tp]['aggregated_arrow'] = (source_median_x, source_median_y, total_dx, total_dy)
                    else:
                        timepoint_results[source_tp]['aggregated_arrow'] = None
    
                    # Compute distributions for each cluster_of_interest
                    for coi in clusters_of_interest:
                        coi_mask = (source_cells['seurat_clusters'] == coi)
                        if np.any(coi_mask):
                            selected_target_indices = full_target_indices[coi_mask]
                            selected_target_clusters = target_cells['seurat_clusters'].iloc[selected_target_indices].values
                            unique_tclusters, counts = np.unique(selected_target_clusters, return_counts=True)
                            total_count = counts.sum()
                            if total_count > 0:
                                fraction_dict = {int(tc): (ct / total_count) for tc, ct in zip(unique_tclusters, counts)}
                            else:
                                fraction_dict = {}
                            distributions_coi[source_tp][cohort_name][coi] = fraction_dict
                        else:
                            distributions_coi[source_tp][cohort_name][coi] = {}
                else:
                    timepoint_results[source_tp]['cluster_arrows'] = []
                    timepoint_results[source_tp]['aggregated_arrow'] = None
                    for coi in clusters_of_interest:
                        distributions_coi[source_tp][cohort_name][coi] = {}
        else:
            # Last timepoint has no next timepoint
            timepoint_results[source_tp]['single_cell_displacements'] = np.array([])
            timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
            timepoint_results[source_tp]['cluster_arrows'] = []
            timepoint_results[source_tp]['aggregated_arrow'] = None
            for coi in clusters_of_interest:
                distributions_coi[source_tp][cohort_name][coi] = {}

    # ---------------------------
    # Generate Original PDFs
    # ---------------------------
    # Single-cell arrows PDF
    output_file_original = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots_differential_arrow_lengths_equal_cells_using_PCs.pdf")
    with PdfPages(output_file_original) as pdf_original:
        plot_tps = list(timepoint_results.keys())
        num_timepoints = len(plot_tps)
        fig_orig, axes_orig = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        if num_timepoints == 1:
            axes_orig = [axes_orig]
        
        for i, source_tp in enumerate(plot_tps):
            ax = axes_orig[i]
            res = timepoint_results[source_tp]
            sampled_source_cells = res['sampled_source']
            sampled_gray_cells = res['sampled_gray']
            
            gray_x = sampled_gray_cells['UMAP_2'].values
            gray_y = sampled_gray_cells['UMAP_1'].values
            source_x = sampled_source_cells['UMAP_2'].values
            source_y = sampled_source_cells['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)
            
            single_cell_displacements = res['single_cell_displacements']
            single_cell_arrow_lengths = res['single_cell_arrow_lengths']
            
            if len(single_cell_displacements) > 0:
                source_coords_sampled = sampled_source_cells[['UMAP_2','UMAP_1']].values
                sampled_norms = np.linalg.norm(single_cell_displacements, axis=1)
                with np.errstate(divide='ignore', invalid='ignore'):
                    unit_vectors = single_cell_displacements / sampled_norms[:, np.newaxis]
                    unit_vectors[~np.isfinite(unit_vectors)] = 0
                scaled_vectors = unit_vectors * single_cell_arrow_lengths[:, np.newaxis]
                
                for j in range(len(source_coords_sampled)):
                    sx, sy = source_coords_sampled[j]
                    dx, dy = scaled_vectors[j]
                    ax.arrow(sx, sy, dx, dy,
                             color='black', alpha=0.7, head_width=0.05, head_length=0.05,
                             length_includes_head=True, linewidth=0.5)
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_orig.suptitle(f"{subpop_name} - {cohort_name}", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_original.savefig(fig_orig)
        plt.close(fig_orig)
    
    # Cluster-level arrows PDF
    output_file_cluster = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots_aggregated_cluster_arrows_equal_cells_using_PCs.pdf")
    with PdfPages(output_file_cluster) as pdf_cluster:
        plot_tps = list(timepoint_results.keys())
        num_timepoints = len(plot_tps)
        fig_clust, axes_clust = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        if num_timepoints == 1:
            axes_clust = [axes_clust]
        
        for i, source_tp in enumerate(plot_tps):
            ax = axes_clust[i]
            res = timepoint_results[source_tp]
            
            gray_x = res['sampled_gray']['UMAP_2'].values
            gray_y = res['sampled_gray']['UMAP_1'].values
            source_x = res['sampled_source']['UMAP_2'].values
            source_y = res['sampled_source']['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)
            
            cluster_arrows = res.get('cluster_arrows', [])
            
            source_cells = full_data[source_tp]['source']
            if 'seurat_clusters' in source_cells.columns:
                source_cluster_counts = source_cells['seurat_clusters'].value_counts()
                total_source_cells_current = len(source_cells)
            else:
                source_cluster_counts = pd.Series()
                total_source_cells_current = 1
            
            for (clust, cx, cy, cdx, cdy) in cluster_arrows:
                arrow_color = cluster_colors_map.get(clust, 'black')
                proportion = source_cluster_counts.get(clust, 0) / total_source_cells_current
                line_width = 1.0 + 8.0 * proportion

                # Draw boundary arrow
                ax.arrow(cx, cy, cdx, cdy,
                         color='black',
                         alpha=1,
                         head_width=0.2 + 0.5 * proportion,
                         head_length=0.1 + 0.1 * proportion,
                         length_includes_head=True,
                         linewidth=line_width + 2)
                
                # Draw main arrow on top
                ax.arrow(cx, cy, cdx, cdy,
                         color=arrow_color, alpha=1,
                         head_width=0.2+0.5*proportion, head_length=0.1+0.1*proportion,
                         length_includes_head=True, linewidth=line_width)
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_clust.suptitle(f"{subpop_name} - {cohort_name} (Cluster-Level Arrows)", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_cluster.savefig(fig_clust)
        plt.close(fig_clust)

    # Single aggregated arrow PDF
    output_file_single_agg = os.path.join(output_folder, f"{subpop_name}_{cohort_name}_movement_plots_aggregated_single_arrow_equal_cells_using_PCs.pdf")
    with PdfPages(output_file_single_agg) as pdf_agg:
        plot_tps = list(timepoint_results.keys())
        num_timepoints = len(plot_tps)
        fig_agg, axes_agg = plt.subplots(1, num_timepoints, figsize=(4 * num_timepoints, 4), sharex=True, sharey=True)
        if num_timepoints == 1:
            axes_agg = [axes_agg]

        for i, source_tp in enumerate(plot_tps):
            ax = axes_agg[i]
            res = timepoint_results[source_tp]

            gray_x = res['sampled_gray']['UMAP_2'].values
            gray_y = res['sampled_gray']['UMAP_1'].values
            source_x = res['sampled_source']['UMAP_2'].values
            source_y = res['sampled_source']['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)

            aggregated_arrow = res.get('aggregated_arrow', None)
            if aggregated_arrow is not None:
                global_cx, global_cy, total_dx, total_dy = aggregated_arrow
                ax.arrow(global_cx, global_cy, total_dx, total_dy,
                         color='black', alpha=0.9,
                         head_width=0.3, head_length=0.3,
                         length_includes_head=True, linewidth=2.0)

            ax.set_title(f"Timepoint: {source_tp} (Single Aggregated Arrow)")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])

        fig_agg.suptitle(f"{subpop_name} - {cohort_name} (Single Aggregated Arrow)", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_agg.savefig(fig_agg)
        plt.close(fig_agg)
    
    # Return the distributions for stacked bar chart plotting done outside
    return distributions_coi

# ---------------------------
# Main Execution Loop
# ---------------------------

for subpop_name in subpopulations:
    distributions_coi_all = {tp: {c: {coi: {} for coi in clusters_of_interest} for c in cohorts} for tp in all_timepoints}
    
    for cohort_name in cohorts:
        print(f"Processing {subpop_name} - {cohort_name}")
        input_subfolder = os.path.join(base_input_dir, subpop_name, cohort_name)
        if not os.path.exists(input_subfolder):
            print(f"Input subfolder does not exist: {input_subfolder}. Skipping.")
            continue
        
        dist_coi = optimal_transport_visualization(subpop_name, cohort_name)
        if dist_coi is not None:
            for tp in dist_coi:
                for c in dist_coi[tp]:
                    for coi in dist_coi[tp][c]:
                        distributions_coi_all[tp][c][coi] = dist_coi[tp][c][coi]

    # Determine all target clusters that appear
    all_target_clusters = set()
    for tp in distributions_coi_all:
        for c in cohorts:
            for coi in clusters_of_interest:
                all_target_clusters.update(distributions_coi_all[tp][c][coi].keys())
    all_target_clusters = sorted(all_target_clusters)

    def get_celltype_name(clust_id):
        return cluster_celltype_map.get(clust_id, f"Cluster {clust_id}")

    def get_cluster_color(clust_id):
        return cluster_colors_map.get(clust_id, 'gray')

    # Generate the stacked bar chart PDF
    output_file_stacked = os.path.join(base_output_dir, f"{subpop_name}_target_cluster_distribution_coi.pdf")
    with PdfPages(output_file_stacked) as pdf_dist:
        fig, axes = plt.subplots(len(clusters_of_interest), len(all_timepoints), 
                                 figsize=(4*len(all_timepoints), 3*len(clusters_of_interest)), 
                                 sharex=False, sharey=True)
        
        if len(clusters_of_interest) == 1 and len(all_timepoints) == 1:
            axes = np.array([[axes]])
        elif len(clusters_of_interest) == 1:
            axes = axes[np.newaxis, :]
        elif len(all_timepoints) == 1:
            axes = axes[:, np.newaxis]

        # Plot bars
        for row_i, coi in enumerate(clusters_of_interest):
            for col_i, tp in enumerate(all_timepoints):
                ax = axes[row_i, col_i]
                bar_positions = np.arange(len(cohorts))
                cohort_distributions = [distributions_coi_all[tp][c][coi] for c in cohorts]
                bottoms = np.zeros(len(cohorts))

                stack_data = []
                for tc in all_target_clusters:
                    heights = [d.get(tc,0) for d in cohort_distributions]
                    stack_data.append((tc, heights))
        
                # Sort by sum of fractions
                stack_data.sort(key=lambda x: sum(x[1]), reverse=True)
                
                for (tc, h) in stack_data:
                    color = get_cluster_color(tc)
                    ax.bar(bar_positions, h, bottom=bottoms, color=color, edgecolor='black')
                    # Add labels
                    for idx, val in enumerate(h):
                        if val > 0.05:
                            mid_y = bottoms[idx] + val/2
                            ax.text(bar_positions[idx], mid_y, get_celltype_name(tc),
                                    ha='center', va='center', fontsize=6, color='white')
                    bottoms += h

                if row_i == 0:
                    ax.set_title(f"{tp}", fontsize=10)

                ax.set_xticks(bar_positions)
                ax.set_xticklabels(cohorts, rotation=45, ha='right', fontsize=8)
                ax.set_ylim(0,1)

        fig.suptitle(f"Distribution of Target Clusters per COI - {subpop_name}", fontsize=16)

        # Apply tight layout first
        plt.tight_layout()

        # Add space for legend and row labels
        fig.subplots_adjust(left=0.15, top=0.88)  
        
        # Label rows with celltype of the COI
        n_rows = len(clusters_of_interest)
        for row_i, coi in enumerate(clusters_of_interest):
            y_pos = 0.88 - (row_i + 0.5)*(0.88-0.1)/n_rows
            fig.text(0.05, y_pos, get_celltype_name(coi), va='center', ha='right', fontsize=10, color='black')

        # Create legend patches with celltype names for target clusters
        legend_patches = []
        for tc in all_target_clusters:
            legend_patches.append(plt.Rectangle((0,0),1,1,
                                                facecolor=get_cluster_color(tc),
                                                edgecolor='black',
                                                label=get_celltype_name(tc)))
        
        # Place legend at top center
        fig.legend(handles=legend_patches, loc='upper center', bbox_to_anchor=(0.5, 0.98), 
                   ncol=len(all_target_clusters), fontsize=8, title="Target Celltypes")

        pdf_dist.savefig(fig)
        plt.close(fig)

print("Visualization generation completed.")


Processing Activated_CD4 - control
Processing Activated_CD4 - short_term
Processing Activated_CD4 - long_term
Processing Effector_CD8 - control
Processing Effector_CD8 - short_term
Processing Effector_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Memory_Precursor_Effector_CD8 - control
Processing Memory_Precursor_Effector_CD8 - short_term
Processing Memory_Precursor_Effector_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Exhausted_T - control
Processing Exhausted_T - short_term
Processing Exhausted_T - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Gamma_Delta_T - control
Processing Gamma_Delta_T - short_term
Processing Gamma_Delta_T - long_term
Processing Active_CD4 - control
Processing Active_CD4 - short_term
Processing Active_CD4 - long_term
Processing Naive_CD4 - control
Processing Naive_CD4 - short_term
Processing Naive_CD4 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Memory_CD4 - control
Processing Memory_CD4 - short_term
Processing Memory_CD4 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Central_Memory_CD8 - control
Processing Central_Memory_CD8 - short_term
Processing Central_Memory_CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Stem_Like_CD8 - control
Processing Stem_Like_CD8 - short_term
Processing Stem_Like_CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Effector_Memory_CD8 - control
Processing Effector_Memory_CD8 - short_term
Processing Effector_Memory_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Proliferating_Effector - control
Processing Proliferating_Effector - short_term
Processing Proliferating_Effector - long_term
Processing CD8 - control
Processing CD8 - short_term
Processing CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Visualization generation completed.


In [1]:
# December 23

In [4]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ot
from matplotlib.backends.backend_pdf import PdfPages

# ---------------------------
# Configuration Parameters
# ---------------------------

base_input_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/final_nomenclature"
base_output_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/final_nomenclature"

subpopulations = [
    "Activated_CD4",
    "Effector_CD8",
    "Memory_Precursor_Effector_CD8",
    "Exhausted_T",
    "Gamma_Delta_T",
    "Active_CD4",
    "Naive_CD4",
    "Memory_CD4",
    "Central_Memory_CD8",
    "Stem_Like_CD8",
    "Effector_Memory_CD8",
    "Proliferating_Effector",
    "CD8"
]

cohorts = ["control", "short_term", "long_term"]

all_timepoints = ["Pre", "C1", "C2", "C4", "C6", "C9", "C18", "C36"]

cohort_colors = {
    'control': 'yellow',
    'short_term': 'blue',
    'long_term': 'red'
}

# --------------------------------------------------
# Clusters we are interested in as the "source" (COI)
# --------------------------------------------------
clusters_of_interest = [1, 2, 3, 8, 10, 12, 14]

# Provide the path to your CSV that maps each cluster to a color and celltype
color_mapping_file = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/Publication_Material/T_Cell_cluster_colors.csv"
color_mapping_df = pd.read_csv(color_mapping_file)
cluster_colors_map = dict(zip(color_mapping_df['Cluster'], color_mapping_df['Color']))
cluster_celltype_map = dict(zip(color_mapping_df['Cluster'], color_mapping_df['Celltype']))

def optimal_transport_visualization(subpop_name, cohort_name):
    """
    For a single subpopulation and cohort, run the full OT pipeline:
    1) Load data at each timepoint.
    2) Perform OT to map from source to next timepoint.
    3) Collect single-cell and cluster-level arrows, plus aggregated arrows.
    4) Generate PDFs of the movement plots (single-cell, cluster-level, aggregated).
    5) Return distributions_coi for stacked-bar plotting.
    """
    input_folder = os.path.join(base_input_dir, subpop_name, cohort_name)
    output_folder = os.path.join(base_output_dir, subpop_name)
    os.makedirs(output_folder, exist_ok=True)
    
    if not os.path.exists(input_folder):
        print(f"Input folder does not exist: {input_folder}")
        return None
    
    timepoint_folders = sorted([f for f in os.listdir(input_folder) if f.startswith("Timepoint_")])
    available_timepoints = [tp.split("_")[1] for tp in timepoint_folders]
    cohort_timepoints = [tp for tp in all_timepoints if tp in available_timepoints]
    
    if not cohort_timepoints:
        print(f"No timepoints available for {subpop_name} in {cohort_name} cohort.")
        return None

    all_x_coords = []
    all_y_coords = []
    
    cell_counts_source = []
    cell_counts_gray = []
    full_data = {}
    
    # Gather data across all valid timepoints
    for tp in cohort_timepoints:
        source_folder = os.path.join(input_folder, f"Timepoint_{tp}")
        source_cells_file = os.path.join(source_folder, 'source_cells_new.csv')
        gray_cells_file = os.path.join(source_folder, 'gray_cells_new.csv')
        
        if not (os.path.exists(source_cells_file) and os.path.exists(gray_cells_file)):
            print(f"Missing files for {tp}. Skipping this timepoint.")
            continue
        
        source_cells = pd.read_csv(source_cells_file)
        gray_cells = pd.read_csv(gray_cells_file)
        
        if source_cells.empty:
            print(f"No source cells for {tp}. Skipping.")
            continue
        
        full_data[tp] = {
            'source': source_cells,
            'gray': gray_cells
        }
        
        all_x_coords.extend(gray_cells['UMAP_2'])
        all_x_coords.extend(source_cells['UMAP_2'])
        all_y_coords.extend(gray_cells['UMAP_1'])
        all_y_coords.extend(source_cells['UMAP_1'])
        
        cell_counts_source.append(len(source_cells))
        cell_counts_gray.append(len(gray_cells))
    
    if not full_data:
        print("No valid timepoints with data.")
        return None
    
    x_min, x_max = min(all_x_coords), max(all_x_coords)
    y_min, y_max = min(all_y_coords), max(all_y_coords)
    
    min_source_cells = min(cell_counts_source) if cell_counts_source else 0
    min_gray_cells = min(cell_counts_gray) if cell_counts_gray else 0
    if min_source_cells == 0 or min_gray_cells == 0:
        print("Insufficient cells. Skipping visualization.")
        return None
    
    # Prepare data structures for results
    timepoint_results = {}
    distributions_coi = {tp: {cohort_name: {c: {} for c in clusters_of_interest}} for tp in cohort_timepoints}
    
    for i, source_tp in enumerate(cohort_timepoints):
        source_data = full_data[source_tp]
        source_cells = source_data['source']
        
        # Downsample so each timepoint has the same # of source & gray cells for plotting
        sampled_source_cells = source_cells.sample(n=min_source_cells, random_state=42)
        sampled_gray_cells = full_data[source_tp]['gray'].sample(n=min_gray_cells, random_state=42)
        
        timepoint_results[source_tp] = {
            'sampled_source': sampled_source_cells,
            'sampled_gray': sampled_gray_cells
        }
        
        # If not the last timepoint, compute OT from source_tp -> target_tp
        if i < len(cohort_timepoints) - 1:
            target_tp = cohort_timepoints[i+1]
            target_data = full_data[target_tp]
            target_cells = target_data['source']
            
            if target_cells.empty:
                # If there's no data in the next timepoint, skip
                timepoint_results[source_tp]['single_cell_displacements'] = np.array([])
                timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
                timepoint_results[source_tp]['cluster_arrows'] = []
                timepoint_results[source_tp]['aggregated_arrow'] = None
            else:
                # Identify columns for PCs
                pc_cols = [c for c in source_cells.columns if c.startswith('PC_')]
                full_source_coords_pc = source_cells[pc_cols].values
                full_target_coords_pc = target_cells[pc_cols].values
                
                # UMAP coords
                full_source_coords_umap = source_cells[['UMAP_2', 'UMAP_1']].values
                full_target_coords_umap = target_cells[['UMAP_2', 'UMAP_1']].values
                
                # Uniform distribution over source / target
                a = np.ones((full_source_coords_pc.shape[0],)) / full_source_coords_pc.shape[0]
                b = np.ones((full_target_coords_pc.shape[0],)) / full_target_coords_pc.shape[0]
                
                cost_matrix = ot.dist(full_source_coords_pc, full_target_coords_pc, metric='euclidean')
                
                try:
                    transport_plan = ot.emd(a, b, cost_matrix, numItermax=100000)
                except Exception as e:
                    print(f"OT computation failed for {source_tp} to {target_tp}: {e}")
                    timepoint_results[source_tp]['single_cell_displacements'] = np.array([])
                    timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
                    timepoint_results[source_tp]['cluster_arrows'] = []
                    timepoint_results[source_tp]['aggregated_arrow'] = None
                    continue
                
                full_target_indices = np.argmax(transport_plan, axis=1)
                displacement_vectors_full = full_target_coords_umap[full_target_indices] - full_source_coords_umap
                norms_full = np.linalg.norm(displacement_vectors_full, axis=1)
                
                # Displacements for the downsampled subset
                sampled_source_indices = source_cells.index.get_indexer_for(sampled_source_cells.index)
                sampled_displacements = displacement_vectors_full[sampled_source_indices, :]
                sampled_norms = norms_full[sampled_source_indices]
                
                timepoint_results[source_tp]['single_cell_displacements'] = sampled_displacements
                timepoint_results[source_tp]['single_cell_arrow_lengths'] = sampled_norms
                
                if 'seurat_clusters' in source_cells.columns and 'seurat_clusters' in target_cells.columns:
                    # Lump some clusters if needed
                    source_cells['seurat_clusters'] = source_cells['seurat_clusters'].replace({16:14,17:14})
                    target_cells['seurat_clusters'] = target_cells['seurat_clusters'].replace({16:14,17:14})
                    source_cells['seurat_clusters'] = source_cells['seurat_clusters'].replace({9:6,18:6})
                    target_cells['seurat_clusters'] = target_cells['seurat_clusters'].replace({9:6,18:6})
                    
                    # For cluster-level arrows, compute average displacement per cluster
                    source_clusters_full = source_cells['seurat_clusters'].values
                    df_cluster_full = pd.DataFrame({
                        'cluster': source_clusters_full,
                        'sx': full_source_coords_umap[:,0],
                        'sy': full_source_coords_umap[:,1],
                        'dx': displacement_vectors_full[:,0],
                        'dy': displacement_vectors_full[:,1],
                        'norm': norms_full
                    })
                    
                    gray_cells_tp = full_data[source_tp]['gray']
                    all_cells_combined = pd.concat([source_cells, gray_cells_tp])
                    all_cells_combined['seurat_clusters'] = all_cells_combined['seurat_clusters'].replace({16:14,17:14})
                    all_cells_combined['seurat_clusters'] = all_cells_combined['seurat_clusters'].replace({9:6,18:6})
                    
                    all_clusters = all_cells_combined['seurat_clusters'].values
                    all_coords_umap = all_cells_combined[['UMAP_2', 'UMAP_1']].values
                    
                    df_all = pd.DataFrame({
                        'cluster': all_clusters,
                        'sx': all_coords_umap[:,0],
                        'sy': all_coords_umap[:,1]
                    })
                    
                    centroids = df_all.groupby('cluster')[['sx','sy']].median()
                    mean_disp = df_cluster_full.groupby('cluster')[['dx','dy']].mean()
                    
                    cluster_norms = np.sqrt(mean_disp['dx']**2 + mean_disp['dy']**2)
                    
                    # Build arrow info for each cluster of interest
                    cluster_arrows = []
                    common_clusters = centroids.index.intersection(mean_disp.index)
                    for clust in common_clusters:
                        if clust in clusters_of_interest:
                            cx, cy = centroids.loc[clust, ['sx','sy']]
                            cdx, cdy = mean_disp.loc[clust, ['dx','dy']]
                            cnorm = np.sqrt(cdx**2 + cdy**2)
                            if cnorm > 0:
                                cdx /= cnorm
                                cdy /= cnorm
                            length = cluster_norms.loc[clust]
                            cdx *= length
                            cdy *= length
                            cluster_arrows.append((clust, cx, cy, cdx, cdy))
                    
                    timepoint_results[source_tp]['cluster_arrows'] = cluster_arrows
    
                    if len(cluster_arrows) > 0:
                        # "Aggregated" arrow by summation
                        source_median_x = source_cells['UMAP_2'].median()
                        source_median_y = source_cells['UMAP_1'].median()
                        total_dx = sum([arrow[3] for arrow in cluster_arrows])
                        total_dy = sum([arrow[4] for arrow in cluster_arrows])
                        timepoint_results[source_tp]['aggregated_arrow'] = (
                            source_median_x, 
                            source_median_y, 
                            total_dx, 
                            total_dy
                        )
                    else:
                        timepoint_results[source_tp]['aggregated_arrow'] = None
    
                    # Compute target cluster distributions for each COI (source cluster)
                    for coi in clusters_of_interest:
                        coi_mask = (source_cells['seurat_clusters'] == coi)
                        if np.any(coi_mask):
                            # Which target clusters do these COI cells map to?
                            selected_target_indices = full_target_indices[coi_mask]
                            selected_target_clusters = target_cells['seurat_clusters'].iloc[selected_target_indices].values
                            unique_tclusters, counts = np.unique(selected_target_clusters, return_counts=True)
                            total_count = counts.sum()
                            if total_count > 0:
                                fraction_dict = {int(tc): (ct / total_count) for tc, ct in zip(unique_tclusters, counts)}
                            else:
                                fraction_dict = {}
                            distributions_coi[source_tp][cohort_name][coi] = fraction_dict
                        else:
                            distributions_coi[source_tp][cohort_name][coi] = {}
                else:
                    # If cluster info is missing
                    timepoint_results[source_tp]['cluster_arrows'] = []
                    timepoint_results[source_tp]['aggregated_arrow'] = None
                    for coi in clusters_of_interest:
                        distributions_coi[source_tp][cohort_name][coi] = {}
        else:
            # Last timepoint has no "next" timepoint
            timepoint_results[source_tp]['single_cell_displacements'] = np.array([])
            timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
            timepoint_results[source_tp]['cluster_arrows'] = []
            timepoint_results[source_tp]['aggregated_arrow'] = None
            for coi in clusters_of_interest:
                distributions_coi[source_tp][cohort_name][coi] = {}

    # ---------------------------
    # Generate PDFs of movement
    # ---------------------------

    # 1) Single-cell arrows
    output_file_original = os.path.join(
        output_folder, 
        f"{subpop_name}_{cohort_name}_movement_plots_differential_arrow_lengths_equal_cells_using_PCs.pdf"
    )
    with PdfPages(output_file_original) as pdf_original:
        plot_tps = list(timepoint_results.keys())
        num_timepoints = len(plot_tps)
        fig_orig, axes_orig = plt.subplots(
            1, 
            num_timepoints, 
            figsize=(4 * num_timepoints, 4), 
            sharex=True, 
            sharey=True
        )
        if num_timepoints == 1:
            axes_orig = [axes_orig]
        
        for i, source_tp in enumerate(plot_tps):
            ax = axes_orig[i]
            res = timepoint_results[source_tp]
            sampled_source_cells = res['sampled_source']
            sampled_gray_cells = res['sampled_gray']
            
            gray_x = sampled_gray_cells['UMAP_2'].values
            gray_y = sampled_gray_cells['UMAP_1'].values
            source_x = sampled_source_cells['UMAP_2'].values
            source_y = sampled_source_cells['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)
            
            single_cell_displacements = res['single_cell_displacements']
            single_cell_arrow_lengths = res['single_cell_arrow_lengths']
            
            if len(single_cell_displacements) > 0:
                source_coords_sampled = sampled_source_cells[['UMAP_2','UMAP_1']].values
                sampled_norms = np.linalg.norm(single_cell_displacements, axis=1)
                with np.errstate(divide='ignore', invalid='ignore'):
                    unit_vectors = single_cell_displacements / sampled_norms[:, np.newaxis]
                    unit_vectors[~np.isfinite(unit_vectors)] = 0
                scaled_vectors = unit_vectors * single_cell_arrow_lengths[:, np.newaxis]
                
                for j in range(len(source_coords_sampled)):
                    sx, sy = source_coords_sampled[j]
                    dx, dy = scaled_vectors[j]
                    ax.arrow(
                        sx, sy, 
                        dx, dy,
                        color='black', 
                        alpha=0.7, 
                        head_width=0.05, 
                        head_length=0.05,
                        length_includes_head=True, 
                        linewidth=0.5
                    )
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_orig.suptitle(f"{subpop_name} - {cohort_name}", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_original.savefig(fig_orig)
        plt.close(fig_orig)
    
    # 2) Cluster-level arrows
    output_file_cluster = os.path.join(
        output_folder, 
        f"{subpop_name}_{cohort_name}_movement_plots_aggregated_cluster_arrows_equal_cells_using_PCs.pdf"
    )
    with PdfPages(output_file_cluster) as pdf_cluster:
        plot_tps = list(timepoint_results.keys())
        num_timepoints = len(plot_tps)
        fig_clust, axes_clust = plt.subplots(
            1, 
            num_timepoints, 
            figsize=(4 * num_timepoints, 4), 
            sharex=True, 
            sharey=True
        )
        if num_timepoints == 1:
            axes_clust = [axes_clust]
        
        for i, source_tp in enumerate(plot_tps):
            ax = axes_clust[i]
            res = timepoint_results[source_tp]
            
            gray_x = res['sampled_gray']['UMAP_2'].values
            gray_y = res['sampled_gray']['UMAP_1'].values
            source_x = res['sampled_source']['UMAP_2'].values
            source_y = res['sampled_source']['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)
            
            cluster_arrows = res.get('cluster_arrows', [])
            
            source_cells = full_data[source_tp]['source']
            if 'seurat_clusters' in source_cells.columns:
                source_cluster_counts = source_cells['seurat_clusters'].value_counts()
                total_source_cells_current = len(source_cells)
            else:
                source_cluster_counts = pd.Series()
                total_source_cells_current = 1
            
            for (clust, cx, cy, cdx, cdy) in cluster_arrows:
                arrow_color = cluster_colors_map.get(clust, 'black')
                proportion = source_cluster_counts.get(clust, 0) / total_source_cells_current
                line_width = 1.0 + 8.0 * proportion

                # Outline in black for clarity
                ax.arrow(
                    cx, cy,
                    cdx, cdy,
                    color='black',
                    alpha=1,
                    head_width=0.2 + 0.5 * proportion,
                    head_length=0.1 + 0.1 * proportion,
                    length_includes_head=True,
                    linewidth=line_width + 2
                )
                
                # Main arrow in cluster color
                ax.arrow(
                    cx, cy,
                    cdx, cdy,
                    color=arrow_color,
                    alpha=1,
                    head_width=0.2 + 0.5 * proportion,
                    head_length=0.1 + 0.1 * proportion,
                    length_includes_head=True,
                    linewidth=line_width
                )
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_clust.suptitle(f"{subpop_name} - {cohort_name} (Cluster-Level Arrows)", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_cluster.savefig(fig_clust)
        plt.close(fig_clust)

    # 3) Single aggregated arrow
    output_file_single_agg = os.path.join(
        output_folder, 
        f"{subpop_name}_{cohort_name}_movement_plots_aggregated_single_arrow_equal_cells_using_PCs.pdf"
    )
    with PdfPages(output_file_single_agg) as pdf_agg:
        plot_tps = list(timepoint_results.keys())
        num_timepoints = len(plot_tps)
        fig_agg, axes_agg = plt.subplots(
            1,
            num_timepoints,
            figsize=(4 * num_timepoints, 4),
            sharex=True,
            sharey=True
        )
        if num_timepoints == 1:
            axes_agg = [axes_agg]

        for i, source_tp in enumerate(plot_tps):
            ax = axes_agg[i]
            res = timepoint_results[source_tp]

            gray_x = res['sampled_gray']['UMAP_2'].values
            gray_y = res['sampled_gray']['UMAP_1'].values
            source_x = res['sampled_source']['UMAP_2'].values
            source_y = res['sampled_source']['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)

            aggregated_arrow = res.get('aggregated_arrow', None)
            if aggregated_arrow is not None:
                global_cx, global_cy, total_dx, total_dy = aggregated_arrow
                ax.arrow(
                    global_cx, global_cy,
                    total_dx, total_dy,
                    color='black', 
                    alpha=0.9,
                    head_width=0.3, 
                    head_length=0.3,
                    length_includes_head=True, 
                    linewidth=2.0
                )

            ax.set_title(f"Timepoint: {source_tp} (Single Aggregated Arrow)")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])

        fig_agg.suptitle(f"{subpop_name} - {cohort_name} (Single Aggregated Arrow)", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_agg.savefig(fig_agg)
        plt.close(fig_agg)
    
    # Return the distributions for stacked-bar plotting
    return distributions_coi


# ---------------------------
# Main Execution Loop
# ---------------------------
for subpop_name in subpopulations:
    # Prepare a nested dict so we can collect distributions
    distributions_coi_all = {
        tp: {
            c: {coi: {} for coi in clusters_of_interest} 
            for c in cohorts
        }
        for tp in all_timepoints
    }
    
    # Run for each cohort
    for cohort_name in cohorts:
        print(f"Processing {subpop_name} - {cohort_name}")
        input_subfolder = os.path.join(base_input_dir, subpop_name, cohort_name)
        if not os.path.exists(input_subfolder):
            print(f"Input subfolder does not exist: {input_subfolder}. Skipping.")
            continue
        
        dist_coi = optimal_transport_visualization(subpop_name, cohort_name)
        if dist_coi is not None:
            for tp in dist_coi:
                for c in dist_coi[tp]:
                    for coi in dist_coi[tp][c]:
                        distributions_coi_all[tp][c][coi] = dist_coi[tp][c][coi]

    # --------------------------------------------------
    # 1) Group non-COI target clusters into "Others"
    # --------------------------------------------------
    for tp in distributions_coi_all:
        for c in cohorts:
            for coi in clusters_of_interest:
                fraction_dict = distributions_coi_all[tp][c][coi]
                if fraction_dict:
                    new_dict = {}
                    others_sum = 0.0
                    for tclust, frac in fraction_dict.items():
                        # If the target cluster is not in our "source" list, 
                        # put it in the "Others" bin
                        if tclust not in clusters_of_interest:
                            others_sum += frac
                        else:
                            new_dict[tclust] = frac
                    if others_sum > 0:
                        new_dict['Others'] = others_sum
                    distributions_coi_all[tp][c][coi] = new_dict

    # Gather all target clusters (including "Others")
    all_target_clusters = []
    for tp in distributions_coi_all:
        for c in cohorts:
            for coi in clusters_of_interest:
                all_target_clusters.extend(distributions_coi_all[tp][c][coi].keys())
    all_target_clusters = set(all_target_clusters)

    # Handle "Others" separately for sorting
    others_present = ("Others" in all_target_clusters)
    if others_present:
        all_target_clusters.remove("Others")

    # Now only numeric cluster IDs remain
    # Sort numeric clusters in ascending order
    all_target_clusters = sorted(all_target_clusters)

    # Put "Others" at the end if present
    if others_present:
        all_target_clusters.append("Others")

    def get_celltype_name(clust_id):
        if clust_id == 'Others':
            return "Others"
        return cluster_celltype_map.get(clust_id, f"Cluster {clust_id}")

    def get_cluster_color(clust_id):
        if clust_id == 'Others':
            return 'gray'
        return cluster_colors_map.get(clust_id, 'gray')

    # ---------------------------
    # Generate the stacked bar chart PDF
    # ---------------------------
    output_file_stacked = os.path.join(base_output_dir, f"{subpop_name}_target_cluster_distribution_coi_without_stack_labels.pdf")
    with PdfPages(output_file_stacked) as pdf_dist:
        fig, axes = plt.subplots(
            len(clusters_of_interest), 
            len(all_timepoints), 
            figsize=(4 * len(all_timepoints), 3 * len(clusters_of_interest)), 
            sharex=False, 
            sharey=True
        )
        
        # Handle shape for single row/col
        if len(clusters_of_interest) == 1 and len(all_timepoints) == 1:
            axes = np.array([[axes]])
        elif len(clusters_of_interest) == 1:
            axes = axes[np.newaxis, :]
        elif len(all_timepoints) == 1:
            axes = axes[:, np.newaxis]

        # Prepare legend patches
        legend_patches = []
        for tc in all_target_clusters:
            legend_patches.append(
                plt.Rectangle((0,0),1,1,
                              facecolor=get_cluster_color(tc),
                              edgecolor='black',
                              label=get_celltype_name(tc))
            )
        
        # Plot bars
        for row_i, coi in enumerate(clusters_of_interest):
            for col_i, tp in enumerate(all_timepoints):
                ax = axes[row_i, col_i]
                bar_positions = np.arange(len(cohorts))
                bottoms = np.zeros(len(cohorts))

                # Retrieve distribution data for each cohort at (tp, coi)
                cohort_distributions = [distributions_coi_all[tp][c][coi] for c in cohorts]

                # Build up data for stacked bars: (target_cluster, [fractions for each cohort])
                stack_data = []
                for tc in all_target_clusters:
                    h = [dist.get(tc, 0.0) for dist in cohort_distributions]
                    stack_data.append((tc, h))

                # Sort so biggest fraction (summed across cohorts) is at the bottom
                stack_data.sort(key=lambda x: sum(x[1]), reverse=True)

                # Plot
                for (tc, heights) in stack_data:
                    color = get_cluster_color(tc)
                    ax.bar(
                        bar_positions, 
                        heights, 
                        bottom=bottoms, 
                        color=color, 
                        edgecolor='black'
                    )
                    # Label each segment if fraction is big enough
                    # for idx, val in enumerate(heights):
                    #     if val > 0.05:
                    #         mid_y = bottoms[idx] + val / 2
                    #         ax.text(
                    #             bar_positions[idx], 
                    #             mid_y, 
                    #             get_celltype_name(tc),
                    #             ha='center', 
                    #             va='center',
                    #             fontsize=6, 
                    #             color='white'
                    #         )
                    bottoms += heights

                if row_i == 0:
                    ax.set_title(f"{tp}", fontsize=10)

                ax.set_xticks(bar_positions)
                ax.set_xticklabels(cohorts, rotation=45, ha='right', fontsize=8)
                ax.set_ylim(0, 1)

        fig.suptitle(f"Distribution of Target Clusters per COI - {subpop_name}", fontsize=16)
        plt.tight_layout()
        
        # Adjust to make room for row labels & legend
        fig.subplots_adjust(left=0.15, top=0.86, right=0.82)  # extra space on the right for legend

        # Label rows with the celltype of the COI
        n_rows = len(clusters_of_interest)
        for row_i, coi in enumerate(clusters_of_interest):
            y_pos = 0.86 - (row_i + 0.5)*(0.86-0.1)/n_rows
            fig.text(0.05, y_pos, get_celltype_name(coi), va='center', ha='right', fontsize=10, color='black')

        # Create a single-column legend on the right
        fig.legend(
            handles=legend_patches,
            loc='upper left',
            bbox_to_anchor=(0.84, 0.95),
            ncol=1,
            fontsize=8,
            title="Target Celltypes"
        )

        pdf_dist.savefig(fig)
        plt.close(fig)

print("Visualization generation completed.")


Processing Activated_CD4 - control
Processing Activated_CD4 - short_term
Processing Activated_CD4 - long_term
Processing Effector_CD8 - control
Processing Effector_CD8 - short_term
Processing Effector_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Memory_Precursor_Effector_CD8 - control
Processing Memory_Precursor_Effector_CD8 - short_term
Processing Memory_Precursor_Effector_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Exhausted_T - control
Processing Exhausted_T - short_term
Processing Exhausted_T - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Gamma_Delta_T - control
Processing Gamma_Delta_T - short_term
Processing Gamma_Delta_T - long_term
Processing Active_CD4 - control
Processing Active_CD4 - short_term
Processing Active_CD4 - long_term
Processing Naive_CD4 - control
Processing Naive_CD4 - short_term
Processing Naive_CD4 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Memory_CD4 - control
Processing Memory_CD4 - short_term
Processing Memory_CD4 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Central_Memory_CD8 - control
Processing Central_Memory_CD8 - short_term
Processing Central_Memory_CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Stem_Like_CD8 - control
Processing Stem_Like_CD8 - short_term
Processing Stem_Like_CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Processing Effector_Memory_CD8 - control
Processing Effector_Memory_CD8 - short_term
Processing Effector_Memory_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Proliferating_Effector - control
Processing Proliferating_Effector - short_term
Processing Proliferating_Effector - long_term
Processing CD8 - control
Processing CD8 - short_term
Processing CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Visualization generation completed.


In [None]:
# optimal transport only for clusters of interest

In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ot
from matplotlib.backends.backend_pdf import PdfPages

# ---------------------------
# Configuration Parameters
# ---------------------------

base_input_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/final_nomenclature/coi"
base_output_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/final_nomenclature/coi"

subpopulations = [
    "Effector_CD8",
    "Memory_Precursor_Effector_CD8",
    "Exhausted_T",
    "Central_Memory_CD8",
    "Stem_Like_CD8",
    "Effector_Memory_CD8",
    "Proliferating_Effector",
    "All_CD8"
]

cohorts = ["control", "short_term", "long_term"]

all_timepoints = ["Pre", "C1", "C2", "C4", "C6", "C9", "C18", "C36"]

cohort_colors = {
    'control': 'yellow',
    'short_term': 'blue',
    'long_term': 'red'
}

# --------------------------------------------------
# Clusters we are interested in as the "source" (COI)
# --------------------------------------------------
clusters_of_interest = [1, 2, 3, 8, 10, 12, 14]

# Provide the path to your CSV that maps each cluster to a color and celltype
color_mapping_file = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/Publication_Material/T_Cell_cluster_colors.csv"
color_mapping_df = pd.read_csv(color_mapping_file)
cluster_colors_map = dict(zip(color_mapping_df['Cluster'], color_mapping_df['Color']))
cluster_celltype_map = dict(zip(color_mapping_df['Cluster'], color_mapping_df['Celltype']))

def optimal_transport_visualization(subpop_name, cohort_name):
    """
    For a single subpopulation and cohort, run the full OT pipeline:
    1) Load data at each timepoint.
    2) Perform OT to map from source to next timepoint.
    3) Collect single-cell and cluster-level arrows, plus aggregated arrows.
    4) Generate PDFs of the movement plots (single-cell, cluster-level, aggregated).
    5) Return distributions_coi for stacked-bar plotting.
    """
    input_folder = os.path.join(base_input_dir, subpop_name, cohort_name)
    output_folder = os.path.join(base_output_dir, subpop_name)
    os.makedirs(output_folder, exist_ok=True)
    
    if not os.path.exists(input_folder):
        print(f"Input folder does not exist: {input_folder}")
        return None
    
    timepoint_folders = sorted([f for f in os.listdir(input_folder) if f.startswith("Timepoint_")])
    available_timepoints = [tp.split("_")[1] for tp in timepoint_folders]
    cohort_timepoints = [tp for tp in all_timepoints if tp in available_timepoints]
    
    if not cohort_timepoints:
        print(f"No timepoints available for {subpop_name} in {cohort_name} cohort.")
        return None

    all_x_coords = []
    all_y_coords = []
    
    cell_counts_source = []
    cell_counts_gray = []
    full_data = {}
    
    # Gather data across all valid timepoints
    for tp in cohort_timepoints:
        source_folder = os.path.join(input_folder, f"Timepoint_{tp}")
        source_cells_file = os.path.join(source_folder, 'source_cells_new.csv')
        gray_cells_file = os.path.join(source_folder, 'gray_cells_new.csv')
        
        if not (os.path.exists(source_cells_file) and os.path.exists(gray_cells_file)):
            print(f"Missing files for {tp}. Skipping this timepoint.")
            continue
        
        source_cells = pd.read_csv(source_cells_file)
        gray_cells = pd.read_csv(gray_cells_file)
        
        if source_cells.empty:
            print(f"No source cells for {tp}. Skipping.")
            continue
        
        full_data[tp] = {
            'source': source_cells,
            'gray': gray_cells
        }
        
        all_x_coords.extend(gray_cells['UMAP_2'])
        all_x_coords.extend(source_cells['UMAP_2'])
        all_y_coords.extend(gray_cells['UMAP_1'])
        all_y_coords.extend(source_cells['UMAP_1'])
        
        cell_counts_source.append(len(source_cells))
        cell_counts_gray.append(len(gray_cells))
    
    if not full_data:
        print("No valid timepoints with data.")
        return None
    
    x_min, x_max = min(all_x_coords), max(all_x_coords)
    y_min, y_max = min(all_y_coords), max(all_y_coords)
    
    min_source_cells = min(cell_counts_source) if cell_counts_source else 0
    min_gray_cells = min(cell_counts_gray) if cell_counts_gray else 0
    if min_source_cells == 0 or min_gray_cells == 0:
        print("Insufficient cells. Skipping visualization.")
        return None
    
    # Prepare data structures for results
    timepoint_results = {}
    distributions_coi = {tp: {cohort_name: {c: {} for c in clusters_of_interest}} for tp in cohort_timepoints}
    
    for i, source_tp in enumerate(cohort_timepoints):
        source_data = full_data[source_tp]
        source_cells = source_data['source']
        
        # Downsample so each timepoint has the same # of source & gray cells for plotting
        sampled_source_cells = source_cells.sample(n=min_source_cells, random_state=42)
        sampled_gray_cells = full_data[source_tp]['gray'].sample(n=min_gray_cells, random_state=42)
        
        timepoint_results[source_tp] = {
            'sampled_source': sampled_source_cells,
            'sampled_gray': sampled_gray_cells
        }
        
        # If not the last timepoint, compute OT from source_tp -> target_tp
        if i < len(cohort_timepoints) - 1:
            target_tp = cohort_timepoints[i+1]
            target_data = full_data[target_tp]
            target_cells = target_data['source']
            
            if target_cells.empty:
                # If there's no data in the next timepoint, skip
                timepoint_results[source_tp]['single_cell_displacements'] = np.array([])
                timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
                timepoint_results[source_tp]['cluster_arrows'] = []
                timepoint_results[source_tp]['aggregated_arrow'] = None
            else:
                # Identify columns for PCs
                pc_cols = [c for c in source_cells.columns if c.startswith('PC_')]
                full_source_coords_pc = source_cells[pc_cols].values
                full_target_coords_pc = target_cells[pc_cols].values
                
                # UMAP coords
                full_source_coords_umap = source_cells[['UMAP_2', 'UMAP_1']].values
                full_target_coords_umap = target_cells[['UMAP_2', 'UMAP_1']].values
                
                # Uniform distribution over source / target
                a = np.ones((full_source_coords_pc.shape[0],)) / full_source_coords_pc.shape[0]
                b = np.ones((full_target_coords_pc.shape[0],)) / full_target_coords_pc.shape[0]
                
                cost_matrix = ot.dist(full_source_coords_pc, full_target_coords_pc, metric='euclidean')
                
                try:
                    transport_plan = ot.emd(a, b, cost_matrix, numItermax=100000)
                except Exception as e:
                    print(f"OT computation failed for {source_tp} to {target_tp}: {e}")
                    timepoint_results[source_tp]['single_cell_displacements'] = np.array([])
                    timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
                    timepoint_results[source_tp]['cluster_arrows'] = []
                    timepoint_results[source_tp]['aggregated_arrow'] = None
                    continue
                
                full_target_indices = np.argmax(transport_plan, axis=1)
                displacement_vectors_full = full_target_coords_umap[full_target_indices] - full_source_coords_umap
                norms_full = np.linalg.norm(displacement_vectors_full, axis=1)
                
                # Displacements for the downsampled subset
                sampled_source_indices = source_cells.index.get_indexer_for(sampled_source_cells.index)
                sampled_displacements = displacement_vectors_full[sampled_source_indices, :]
                sampled_norms = norms_full[sampled_source_indices]
                
                timepoint_results[source_tp]['single_cell_displacements'] = sampled_displacements
                timepoint_results[source_tp]['single_cell_arrow_lengths'] = sampled_norms
                
                if 'seurat_clusters' in source_cells.columns and 'seurat_clusters' in target_cells.columns:
                    # Lump some clusters if needed
                    source_cells['seurat_clusters'] = source_cells['seurat_clusters'].replace({16:14,17:14})
                    target_cells['seurat_clusters'] = target_cells['seurat_clusters'].replace({16:14,17:14})
                    source_cells['seurat_clusters'] = source_cells['seurat_clusters'].replace({9:6,18:6})
                    target_cells['seurat_clusters'] = target_cells['seurat_clusters'].replace({9:6,18:6})
                    
                    # For cluster-level arrows, compute average displacement per cluster
                    source_clusters_full = source_cells['seurat_clusters'].values
                    df_cluster_full = pd.DataFrame({
                        'cluster': source_clusters_full,
                        'sx': full_source_coords_umap[:,0],
                        'sy': full_source_coords_umap[:,1],
                        'dx': displacement_vectors_full[:,0],
                        'dy': displacement_vectors_full[:,1],
                        'norm': norms_full
                    })
                    
                    gray_cells_tp = full_data[source_tp]['gray']
                    all_cells_combined = pd.concat([source_cells, gray_cells_tp])
                    all_cells_combined['seurat_clusters'] = all_cells_combined['seurat_clusters'].replace({16:14,17:14})
                    all_cells_combined['seurat_clusters'] = all_cells_combined['seurat_clusters'].replace({9:6,18:6})
                    
                    all_clusters = all_cells_combined['seurat_clusters'].values
                    all_coords_umap = all_cells_combined[['UMAP_2', 'UMAP_1']].values
                    
                    df_all = pd.DataFrame({
                        'cluster': all_clusters,
                        'sx': all_coords_umap[:,0],
                        'sy': all_coords_umap[:,1]
                    })
                    
                    centroids = df_all.groupby('cluster')[['sx','sy']].median()
                    mean_disp = df_cluster_full.groupby('cluster')[['dx','dy']].mean()
                    
                    cluster_norms = np.sqrt(mean_disp['dx']**2 + mean_disp['dy']**2)
                    
                    # Build arrow info for each cluster of interest
                    cluster_arrows = []
                    common_clusters = centroids.index.intersection(mean_disp.index)
                    for clust in common_clusters:
                        if clust in clusters_of_interest:
                            cx, cy = centroids.loc[clust, ['sx','sy']]
                            cdx, cdy = mean_disp.loc[clust, ['dx','dy']]
                            cnorm = np.sqrt(cdx**2 + cdy**2)
                            if cnorm > 0:
                                cdx /= cnorm
                                cdy /= cnorm
                            length = cluster_norms.loc[clust]
                            cdx *= length
                            cdy *= length
                            cluster_arrows.append((clust, cx, cy, cdx, cdy))
                    
                    timepoint_results[source_tp]['cluster_arrows'] = cluster_arrows
    
                    if len(cluster_arrows) > 0:
                        # "Aggregated" arrow by summation
                        source_median_x = source_cells['UMAP_2'].median()
                        source_median_y = source_cells['UMAP_1'].median()
                        total_dx = sum([arrow[3] for arrow in cluster_arrows])
                        total_dy = sum([arrow[4] for arrow in cluster_arrows])
                        timepoint_results[source_tp]['aggregated_arrow'] = (
                            source_median_x, 
                            source_median_y, 
                            total_dx, 
                            total_dy
                        )
                    else:
                        timepoint_results[source_tp]['aggregated_arrow'] = None
    
                    # Compute target cluster distributions for each COI (source cluster)
                    for coi in clusters_of_interest:
                        coi_mask = (source_cells['seurat_clusters'] == coi)
                        if np.any(coi_mask):
                            # Which target clusters do these COI cells map to?
                            selected_target_indices = full_target_indices[coi_mask]
                            selected_target_clusters = target_cells['seurat_clusters'].iloc[selected_target_indices].values
                            unique_tclusters, counts = np.unique(selected_target_clusters, return_counts=True)
                            total_count = counts.sum()
                            if total_count > 0:
                                fraction_dict = {int(tc): (ct / total_count) for tc, ct in zip(unique_tclusters, counts)}
                            else:
                                fraction_dict = {}
                            distributions_coi[source_tp][cohort_name][coi] = fraction_dict
                        else:
                            distributions_coi[source_tp][cohort_name][coi] = {}
                else:
                    # If cluster info is missing
                    timepoint_results[source_tp]['cluster_arrows'] = []
                    timepoint_results[source_tp]['aggregated_arrow'] = None
                    for coi in clusters_of_interest:
                        distributions_coi[source_tp][cohort_name][coi] = {}
        else:
            # Last timepoint has no "next" timepoint
            timepoint_results[source_tp]['single_cell_displacements'] = np.array([])
            timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
            timepoint_results[source_tp]['cluster_arrows'] = []
            timepoint_results[source_tp]['aggregated_arrow'] = None
            for coi in clusters_of_interest:
                distributions_coi[source_tp][cohort_name][coi] = {}

    # ---------------------------
    # Generate PDFs of movement
    # ---------------------------

    # 1) Single-cell arrows
    output_file_original = os.path.join(
        output_folder, 
        f"{subpop_name}_{cohort_name}_movement_plots_differential_arrow_lengths_equal_cells_using_PCs.pdf"
    )
    with PdfPages(output_file_original) as pdf_original:
        plot_tps = list(timepoint_results.keys())
        num_timepoints = len(plot_tps)
        fig_orig, axes_orig = plt.subplots(
            1, 
            num_timepoints, 
            figsize=(4 * num_timepoints, 4), 
            sharex=True, 
            sharey=True
        )
        if num_timepoints == 1:
            axes_orig = [axes_orig]
        
        for i, source_tp in enumerate(plot_tps):
            ax = axes_orig[i]
            res = timepoint_results[source_tp]
            sampled_source_cells = res['sampled_source']
            sampled_gray_cells = res['sampled_gray']
            
            gray_x = sampled_gray_cells['UMAP_2'].values
            gray_y = sampled_gray_cells['UMAP_1'].values
            source_x = sampled_source_cells['UMAP_2'].values
            source_y = sampled_source_cells['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)
            
            single_cell_displacements = res['single_cell_displacements']
            single_cell_arrow_lengths = res['single_cell_arrow_lengths']
            
            if len(single_cell_displacements) > 0:
                source_coords_sampled = sampled_source_cells[['UMAP_2','UMAP_1']].values
                sampled_norms = np.linalg.norm(single_cell_displacements, axis=1)
                with np.errstate(divide='ignore', invalid='ignore'):
                    unit_vectors = single_cell_displacements / sampled_norms[:, np.newaxis]
                    unit_vectors[~np.isfinite(unit_vectors)] = 0
                scaled_vectors = unit_vectors * single_cell_arrow_lengths[:, np.newaxis]
                
                for j in range(len(source_coords_sampled)):
                    sx, sy = source_coords_sampled[j]
                    dx, dy = scaled_vectors[j]
                    ax.arrow(
                        sx, sy, 
                        dx, dy,
                        color='black', 
                        alpha=0.7, 
                        head_width=0.05, 
                        head_length=0.05,
                        length_includes_head=True, 
                        linewidth=0.5
                    )
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_orig.suptitle(f"{subpop_name} - {cohort_name}", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_original.savefig(fig_orig)
        plt.close(fig_orig)
    
    # 2) Cluster-level arrows
    output_file_cluster = os.path.join(
        output_folder, 
        f"{subpop_name}_{cohort_name}_movement_plots_aggregated_cluster_arrows_equal_cells_using_PCs.pdf"
    )
    with PdfPages(output_file_cluster) as pdf_cluster:
        plot_tps = list(timepoint_results.keys())
        num_timepoints = len(plot_tps)
        fig_clust, axes_clust = plt.subplots(
            1, 
            num_timepoints, 
            figsize=(4 * num_timepoints, 4), 
            sharex=True, 
            sharey=True
        )
        if num_timepoints == 1:
            axes_clust = [axes_clust]
        
        for i, source_tp in enumerate(plot_tps):
            ax = axes_clust[i]
            res = timepoint_results[source_tp]
            
            gray_x = res['sampled_gray']['UMAP_2'].values
            gray_y = res['sampled_gray']['UMAP_1'].values
            source_x = res['sampled_source']['UMAP_2'].values
            source_y = res['sampled_source']['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)
            
            cluster_arrows = res.get('cluster_arrows', [])
            
            source_cells = full_data[source_tp]['source']
            if 'seurat_clusters' in source_cells.columns:
                source_cluster_counts = source_cells['seurat_clusters'].value_counts()
                total_source_cells_current = len(source_cells)
            else:
                source_cluster_counts = pd.Series()
                total_source_cells_current = 1
            
            for (clust, cx, cy, cdx, cdy) in cluster_arrows:
                arrow_color = cluster_colors_map.get(clust, 'black')
                proportion = source_cluster_counts.get(clust, 0) / total_source_cells_current
                line_width = 1.0 + 8.0 * proportion

                # Outline in black for clarity
                ax.arrow(
                    cx, cy,
                    cdx, cdy,
                    color='black',
                    alpha=1,
                    head_width=0.2 + 0.5 * proportion,
                    head_length=0.1 + 0.1 * proportion,
                    length_includes_head=True,
                    linewidth=line_width + 2
                )
                
                # Main arrow in cluster color
                ax.arrow(
                    cx, cy,
                    cdx, cdy,
                    # color=arrow_color,
                    color='black',
                    alpha=1,
                    head_width=0.2 + 0.5 * proportion,
                    head_length=0.1 + 0.1 * proportion,
                    length_includes_head=True,
                    linewidth=line_width
                )
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_clust.suptitle(f"{subpop_name} - {cohort_name} (Cluster-Level Arrows)", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_cluster.savefig(fig_clust)
        plt.close(fig_clust)

    # 3) Single aggregated arrow
    output_file_single_agg = os.path.join(
        output_folder, 
        f"{subpop_name}_{cohort_name}_movement_plots_aggregated_single_arrow_equal_cells_using_PCs.pdf"
    )
    with PdfPages(output_file_single_agg) as pdf_agg:
        plot_tps = list(timepoint_results.keys())
        num_timepoints = len(plot_tps)
        fig_agg, axes_agg = plt.subplots(
            1,
            num_timepoints,
            figsize=(4 * num_timepoints, 4),
            sharex=True,
            sharey=True
        )
        if num_timepoints == 1:
            axes_agg = [axes_agg]

        for i, source_tp in enumerate(plot_tps):
            ax = axes_agg[i]
            res = timepoint_results[source_tp]

            gray_x = res['sampled_gray']['UMAP_2'].values
            gray_y = res['sampled_gray']['UMAP_1'].values
            source_x = res['sampled_source']['UMAP_2'].values
            source_y = res['sampled_source']['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)

            aggregated_arrow = res.get('aggregated_arrow', None)
            if aggregated_arrow is not None:
                global_cx, global_cy, total_dx, total_dy = aggregated_arrow
                ax.arrow(
                    global_cx, global_cy,
                    total_dx, total_dy,
                    color='black', 
                    alpha=0.9,
                    head_width=0.3, 
                    head_length=0.3,
                    length_includes_head=True, 
                    linewidth=2.0
                )

            ax.set_title(f"Timepoint: {source_tp} (Single Aggregated Arrow)")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])

        fig_agg.suptitle(f"{subpop_name} - {cohort_name} (Single Aggregated Arrow)", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_agg.savefig(fig_agg)
        plt.close(fig_agg)
    
    # Return the distributions for stacked-bar plotting
    return distributions_coi


# ---------------------------
# Main Execution Loop
# ---------------------------
for subpop_name in subpopulations:
    # Prepare a nested dict so we can collect distributions
    distributions_coi_all = {
        tp: {
            c: {coi: {} for coi in clusters_of_interest} 
            for c in cohorts
        }
        for tp in all_timepoints
    }
    
    # Run for each cohort
    for cohort_name in cohorts:
        print(f"Processing {subpop_name} - {cohort_name}")
        input_subfolder = os.path.join(base_input_dir, subpop_name, cohort_name)
        if not os.path.exists(input_subfolder):
            print(f"Input subfolder does not exist: {input_subfolder}. Skipping.")
            continue
        
        dist_coi = optimal_transport_visualization(subpop_name, cohort_name)
        if dist_coi is not None:
            for tp in dist_coi:
                for c in dist_coi[tp]:
                    for coi in dist_coi[tp][c]:
                        distributions_coi_all[tp][c][coi] = dist_coi[tp][c][coi]

    # --------------------------------------------------
    # 1) Group non-COI target clusters into "Others"
    # --------------------------------------------------
    for tp in distributions_coi_all:
        for c in cohorts:
            for coi in clusters_of_interest:
                fraction_dict = distributions_coi_all[tp][c][coi]
                if fraction_dict:
                    new_dict = {}
                    others_sum = 0.0
                    for tclust, frac in fraction_dict.items():
                        # If the target cluster is not in our "source" list, 
                        # put it in the "Others" bin
                        if tclust not in clusters_of_interest:
                            others_sum += frac
                        else:
                            new_dict[tclust] = frac
                    if others_sum > 0:
                        new_dict['Others'] = others_sum
                    distributions_coi_all[tp][c][coi] = new_dict

    # Gather all target clusters (including "Others")
    all_target_clusters = []
    for tp in distributions_coi_all:
        for c in cohorts:
            for coi in clusters_of_interest:
                all_target_clusters.extend(distributions_coi_all[tp][c][coi].keys())
    all_target_clusters = set(all_target_clusters)

    # Handle "Others" separately for sorting
    others_present = ("Others" in all_target_clusters)
    if others_present:
        all_target_clusters.remove("Others")

    # Now only numeric cluster IDs remain
    # Sort numeric clusters in ascending order
    all_target_clusters = sorted(all_target_clusters)

    # Put "Others" at the end if present
    if others_present:
        all_target_clusters.append("Others")

    def get_celltype_name(clust_id):
        if clust_id == 'Others':
            return "Others"
        return cluster_celltype_map.get(clust_id, f"Cluster {clust_id}")

    def get_cluster_color(clust_id):
        if clust_id == 'Others':
            return 'gray'
        return cluster_colors_map.get(clust_id, 'gray')

    # ---------------------------
    # Generate the stacked bar chart PDF
    # ---------------------------
    output_file_stacked = os.path.join(base_output_dir, f"{subpop_name}_target_cluster_distribution_coi_with_stack_labels.pdf")
    with PdfPages(output_file_stacked) as pdf_dist:
        fig, axes = plt.subplots(
            len(clusters_of_interest), 
            len(all_timepoints), 
            figsize=(4 * len(all_timepoints), 3 * len(clusters_of_interest)), 
            sharex=False, 
            sharey=True
        )
        
        # Handle shape for single row/col
        if len(clusters_of_interest) == 1 and len(all_timepoints) == 1:
            axes = np.array([[axes]])
        elif len(clusters_of_interest) == 1:
            axes = axes[np.newaxis, :]
        elif len(all_timepoints) == 1:
            axes = axes[:, np.newaxis]

        # Prepare legend patches
        legend_patches = []
        for tc in all_target_clusters:
            legend_patches.append(
                plt.Rectangle((0,0),1,1,
                              facecolor=get_cluster_color(tc),
                              edgecolor='black',
                              label=get_celltype_name(tc))
            )
        
        # Plot bars
        for row_i, coi in enumerate(clusters_of_interest):
            for col_i, tp in enumerate(all_timepoints):
                ax = axes[row_i, col_i]
                bar_positions = np.arange(len(cohorts))
                bottoms = np.zeros(len(cohorts))

                # Retrieve distribution data for each cohort at (tp, coi)
                cohort_distributions = [distributions_coi_all[tp][c][coi] for c in cohorts]

                # Build up data for stacked bars: (target_cluster, [fractions for each cohort])
                stack_data = []
                for tc in all_target_clusters:
                    h = [dist.get(tc, 0.0) for dist in cohort_distributions]
                    stack_data.append((tc, h))

                # Sort so biggest fraction (summed across cohorts) is at the bottom
                stack_data.sort(key=lambda x: sum(x[1]), reverse=True)

                # Plot
                for (tc, heights) in stack_data:
                    color = get_cluster_color(tc)
                    ax.bar(
                        bar_positions, 
                        heights, 
                        bottom=bottoms, 
                        color=color, 
                        edgecolor='black'
                    )
                    # Label each segment if fraction is big enough
                    for idx, val in enumerate(heights):
                        if val > 0.05:
                            mid_y = bottoms[idx] + val / 2
                            ax.text(
                                bar_positions[idx], 
                                mid_y, 
                                get_celltype_name(tc),
                                ha='center', 
                                va='center',
                                fontsize=6, 
                                color='white'
                            )
                    bottoms += heights

                if row_i == 0:
                    ax.set_title(f"{tp}", fontsize=10)

                ax.set_xticks(bar_positions)
                ax.set_xticklabels(cohorts, rotation=45, ha='right', fontsize=8)
                ax.set_ylim(0, 1)

        fig.suptitle(f"Distribution of Target Clusters per COI - {subpop_name}", fontsize=16)
        plt.tight_layout()
        
        # Adjust to make room for row labels & legend
        fig.subplots_adjust(left=0.15, top=0.86, right=0.82)  # extra space on the right for legend

        # Label rows with the celltype of the COI
        n_rows = len(clusters_of_interest)
        for row_i, coi in enumerate(clusters_of_interest):
            y_pos = 0.86 - (row_i + 0.5)*(0.86-0.1)/n_rows
            fig.text(0.05, y_pos, get_celltype_name(coi), va='center', ha='right', fontsize=10, color='black')

        # Create a single-column legend on the right
        fig.legend(
            handles=legend_patches,
            loc='upper left',
            bbox_to_anchor=(0.84, 0.95),
            ncol=1,
            fontsize=8,
            title="Target Celltypes"
        )

        pdf_dist.savefig(fig)
        plt.close(fig)

print("Visualization generation completed.")


Processing Effector_CD8 - control
Processing Effector_CD8 - short_term
Processing Effector_CD8 - long_term
Processing Memory_Precursor_Effector_CD8 - control
Processing Memory_Precursor_Effector_CD8 - short_term
Processing Memory_Precursor_Effector_CD8 - long_term
Processing Exhausted_T - control
Processing Exhausted_T - short_term
Processing Exhausted_T - long_term


  result_code_string = check_result(result_code)


Processing Central_Memory_CD8 - control
Processing Central_Memory_CD8 - short_term
Processing Central_Memory_CD8 - long_term
Processing Stem_Like_CD8 - control
Processing Stem_Like_CD8 - short_term
Processing Stem_Like_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Effector_Memory_CD8 - control
Processing Effector_Memory_CD8 - short_term
Processing Effector_Memory_CD8 - long_term
Processing Proliferating_Effector - control
Processing Proliferating_Effector - short_term
Processing Proliferating_Effector - long_term
Processing All_CD8 - control
Processing All_CD8 - short_term
Processing All_CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Visualization generation completed.


In [11]:
import os
import numpy as np
import pandas as pd
import ot  # Make sure you have installed POT or similar library for optimal transport

def compute_shannon_diversity(counts_dict):
    """
    Compute the Shannon diversity index of a distribution given as a dictionary of counts.
    H = -∑ (p_i * log(p_i)).
    """
    import math
    total = sum(counts_dict.values())
    if total == 0:
        return 0.0
    # Convert counts to proportions
    proportions = [count / total for count in counts_dict.values()]
    # Calculate Shannon entropy
    shannon_entropy = -sum([p * math.log(p + 1e-12) for p in proportions])
    return shannon_entropy

def compute_simpson_diversity(counts_dict):
    """
    Compute Simpson’s diversity index of a distribution given as a dictionary of counts.
    Simpson’s index = 1 - ∑(p_i²), where p_i = count_i / total_counts.
    """
    total = sum(counts_dict.values())
    if total == 0:
        return 0.0
    
    # Convert counts to proportions
    proportions = [count / total for count in counts_dict.values()]
    # Calculate sum of squares of proportions
    sum_of_squares = sum(p ** 2 for p in proportions)
    # Simpson’s index
    simpson_index = 1 - sum_of_squares
    return simpson_index


def compute_target_diversity_for_group_ot(
    base_directory,
    subpop_name,
    group_name,
    timepoint,
    mapping,
    source_subpop
):
    """
    For a single group (e.g., "short_term"), compute target diversity index *per patient*
    using optimal transport.

    Steps (per patient):
    1) Read source_cells_new.csv (source) and target_cells_new.csv (target).
    2) Filter the source cells to include only the cluster IDs belonging to `source_subpop`.
    3) Run optimal transport (OT) from those filtered source cells to all target cells 
       (for that patient). 
       - We'll use PC columns to compute the cost matrix, as in your snippet.
    4) For each source cell, pick the target cell that has the highest transport weight.
    5) Look at the distribution of the matched target cells’ clusters -> compute diversity.
    6) Return {patient_id -> diversity_index}.
    """
    source_path = os.path.join(
        base_directory, subpop_name, group_name, f"Timepoint_{timepoint}", "source_cells_new.csv"
    )
    target_path = os.path.join(
        base_directory, subpop_name, group_name, f"Timepoint_{timepoint}", "target_cells_new.csv"
    )
    
    if not os.path.exists(source_path):
        raise FileNotFoundError(f"File not found: {source_path}")
    if not os.path.exists(target_path):
        raise FileNotFoundError(f"File not found: {target_path}")
    
    source_df = pd.read_csv(source_path)
    target_df = pd.read_csv(target_path)

    # Check that the CSVs contain the necessary columns.
    # Adapt these column names to your actual data.
    required_cols = {'cell_id', 'Patient', 'seurat_clusters'}
    for colset, df in zip(['source', 'target'], [source_df, target_df]):
        if not required_cols.issubset(df.columns):
            raise ValueError(f"Dataframe {colset} is missing required columns: {required_cols}")
    
    # We also need PC columns (or whatever you plan to use for OT).
    # For example, let's look for 'PC_' columns.
    pc_cols = [c for c in source_df.columns if c.startswith("PC_")]
    if not pc_cols:
        raise ValueError("No PC_ columns found for computing cost matrix. Check your data.")
    
    # Identify which clusters represent `source_subpop`
    source_cluster_ids = mapping[source_subpop]
    
    # We'll compute a dictionary of {patient_id: diversity}
    patient_diversity = {}
    all_patients = source_df['Patient'].unique()

    for patient_id in all_patients:
        # Filter source cells for this patient
        patient_source = source_df[
            (source_df['Patient'] == patient_id) &
            (source_df['seurat_clusters'].isin(source_cluster_ids))
        ].copy()
        
        # If there are no source cells of interest, skip or store 0
        if patient_source.empty:
            patient_diversity[patient_id] = 0.0
            continue
        
        # Filter target cells for this patient
        patient_target = target_df[target_df['Patient'] == patient_id].copy()
        
        # If no target cells, skip
        if patient_target.empty:
            patient_diversity[patient_id] = 0.0
            continue

        # Prepare data for OT
        # We'll use PC columns for cost
        source_coords = patient_source[pc_cols].values
        target_coords = patient_target[pc_cols].values
        
        # Probability distributions (uniform over source and target)
        a = np.ones((source_coords.shape[0],)) / source_coords.shape[0]
        b = np.ones((target_coords.shape[0],)) / target_coords.shape[0]
        
        # Euclidean cost matrix
        cost_matrix = ot.dist(source_coords, target_coords, metric='euclidean')
        
        try:
            transport_plan = ot.emd(a, b, cost_matrix, numItermax=100000)
        except Exception as e:
            print(f"OT computation failed for patient {patient_id}: {e}")
            patient_diversity[patient_id] = 0.0
            continue
        
        # transport_plan has shape (num_source_cells, num_target_cells).
        # For each source cell (row), pick the target cell (col) that has max weight.
        matched_target_indices = np.argmax(transport_plan, axis=1)
        
        # Now find the target cell clusters for these matched indices.
        # matched_target_indices[i] is the column index in patient_target
        target_clusters_matched = patient_target.iloc[matched_target_indices]['seurat_clusters'].values
        
        # Count how many times each cluster appears
        cluster_counts = pd.Series(target_clusters_matched).value_counts().to_dict()
        
        # Compute diversity
        diversity_val = compute_shannon_diversity(cluster_counts)
        # diversity_val = compute_simpson_diversity(cluster_counts)
        patient_diversity[patient_id] = diversity_val
    
    return patient_diversity


def compute_p_value(diversity_group1, diversity_group2, n_boot=100000):
    """
    Given two dictionaries { patient_id: diversity_val } for group1 and group2,
    compute the p-value by permutation test (two-tailed).
    """
    group1_vals = np.array(list(diversity_group1.values()))
    group2_vals = np.array(list(diversity_group2.values()))
    
    # A small constant to prevent division by zero
    EPS = 1e-20
    true_diff = group2_vals.mean() / (group1_vals.mean() + EPS)
    # true_diff = group2_vals.mean() - group1_vals.mean()
    # true_diff = group2_vals.mean() / group1_vals.mean()
    
    combined = np.concatenate([group1_vals, group2_vals])
    n1 = len(group1_vals)
    n2 = len(group2_vals)
    
    count_extreme = 0
    for _ in range(n_boot):
        np.random.shuffle(combined)
        perm_g1 = combined[:n1]
        perm_g2 = combined[n1:n1 + n2]
        # perm_diff = perm_g2.mean() - perm_g1.mean()
        perm_diff = perm_g2.mean() / (perm_g1.mean() + EPS)
        # perm_diff = perm_g2.mean() / perm_g1.mean()
        if abs(perm_diff) >= abs(true_diff):
            count_extreme += 1
    
    pval = count_extreme / n_boot
    return pval


def compare_groups_diversity_ot(
    base_directory,
    subpop_name,
    mapping,
    timepoint,
    group1_names,
    group2_names,
    source_subpop,
    n_boot=100000
):
    """
    Main function to:
    1) Compute target diversity distributions for two sets of groups (e.g., short_term vs long_term),
       but using OT-based mapping from source to target cells.
    2) Perform a permutation test on the difference of means.
    """
    # For group1 (can have multiple group names combined)
    diversity_group1 = {}
    for g_name in group1_names:
        group_div = compute_target_diversity_for_group_ot(
            base_directory=base_directory,
            subpop_name=subpop_name,
            group_name=g_name,
            timepoint=timepoint,
            mapping=mapping,
            source_subpop=source_subpop
        )
        # Merge them (disambiguate with the group name + patient ID, or just patient ID if unique)
        for pid, val in group_div.items():
            # If a patient can appear in multiple subgroups, you might need special logic.
            # For now, we assume each patient is unique to that group.
            diversity_group1[(g_name, pid)] = val
    
    # For group2
    diversity_group2 = {}
    for g_name in group2_names:
        group_div = compute_target_diversity_for_group_ot(
            base_directory=base_directory,
            subpop_name=subpop_name,
            group_name=g_name,
            timepoint=timepoint,
            mapping=mapping,
            source_subpop=source_subpop
        )
        for pid, val in group_div.items():
            diversity_group2[(g_name, pid)] = val
    
    # Compute p-value
    pval = compute_p_value(diversity_group1, diversity_group2, n_boot=n_boot)

    # Print results
    mean1 = np.mean(list(diversity_group1.values()))
    mean2 = np.mean(list(diversity_group2.values()))
    print(f"Mean diversity (Group1): {mean1:.4f}")
    print(f"Mean diversity (Group2): {mean2:.4f}")
    print(f"p-value: {pval:.6f}")

    return pval, diversity_group1, diversity_group2


if __name__ == "__main__":
    # Example usage
    celltype_to_cluster = {
        "Effector_CD8": [1],
        "Memory_Precursor_Effector_CD8": [2],
        "Exhausted_T": [3],
        "Stem_Like_CD8": [8],
        "Effector_Memory_CD8": [10],
        "Central_Memory_CD8": [12],
        "Proliferating_Effector": [14, 16, 17]
    }
    
    base_directory = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/final_nomenclature/coi"
    subpop_name = "Central_Memory_CD8"
    mapping = celltype_to_cluster
    timepoint = "C1"
    group1_names = ["short_term"]
    group2_names = ["long_term"]
    source_subpop = "Central_Memory_CD8"
    n_boot = 100000

    p_val, div_g1, div_g2 = compare_groups_diversity_ot(
        base_directory=base_directory,
        subpop_name=subpop_name,
        mapping=mapping,
        timepoint=timepoint,
        group1_names=group1_names,
        group2_names=group2_names,
        source_subpop=source_subpop,
        n_boot=n_boot
    )


Mean diversity (Group1): 0.2626
Mean diversity (Group2): 0.5767
p-value: 0.177660


#### Generating plots where cluster of interest is only Central Memory CD8 T Cells

In [13]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ot
from matplotlib.backends.backend_pdf import PdfPages

# ---------------------------
# Configuration Parameters
# ---------------------------

base_input_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/final_nomenclature/coi"
base_output_dir = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/new_sequencing_data/ITT/Clonal_Tracking_Results/final_nomenclature/coi_central_memory_cd8"

subpopulations = [
    "Effector_CD8",
    "Memory_Precursor_Effector_CD8",
    "Exhausted_T",
    "Central_Memory_CD8",
    "Stem_Like_CD8",
    "Effector_Memory_CD8",
    "Proliferating_Effector",
    "All_CD8"
]

cohorts = ["control", "short_term", "long_term"]

all_timepoints = ["Pre", "C1", "C2", "C4", "C6", "C9", "C18", "C36"]

cohort_colors = {
    'control': 'yellow',
    'short_term': 'blue',
    'long_term': 'red'
}

# --------------------------------------------------
# Clusters we are interested in as the "source" (COI)
# --------------------------------------------------
clusters_of_interest = [12]

# Provide the path to your CSV that maps each cluster to a color and celltype
color_mapping_file = "/project/dtran642_927/Data/Collaborator/DJ/scRNAseq/latest_big_sequencing_projects/MK/Publication_Material/T_Cell_cluster_colors.csv"
color_mapping_df = pd.read_csv(color_mapping_file)
cluster_colors_map = dict(zip(color_mapping_df['Cluster'], color_mapping_df['Color']))
cluster_celltype_map = dict(zip(color_mapping_df['Cluster'], color_mapping_df['Celltype']))

def optimal_transport_visualization(subpop_name, cohort_name):
    """
    For a single subpopulation and cohort, run the full OT pipeline:
    1) Load data at each timepoint.
    2) Perform OT to map from source to next timepoint.
    3) Collect single-cell and cluster-level arrows, plus aggregated arrows.
    4) Generate PDFs of the movement plots (single-cell, cluster-level, aggregated).
    5) Return distributions_coi for stacked-bar plotting.
    """
    input_folder = os.path.join(base_input_dir, subpop_name, cohort_name)
    output_folder = os.path.join(base_output_dir, subpop_name)
    os.makedirs(output_folder, exist_ok=True)
    
    if not os.path.exists(input_folder):
        print(f"Input folder does not exist: {input_folder}")
        return None
    
    timepoint_folders = sorted([f for f in os.listdir(input_folder) if f.startswith("Timepoint_")])
    available_timepoints = [tp.split("_")[1] for tp in timepoint_folders]
    cohort_timepoints = [tp for tp in all_timepoints if tp in available_timepoints]
    
    if not cohort_timepoints:
        print(f"No timepoints available for {subpop_name} in {cohort_name} cohort.")
        return None

    all_x_coords = []
    all_y_coords = []
    
    cell_counts_source = []
    cell_counts_gray = []
    full_data = {}
    
    # Gather data across all valid timepoints
    for tp in cohort_timepoints:
        source_folder = os.path.join(input_folder, f"Timepoint_{tp}")
        source_cells_file = os.path.join(source_folder, 'source_cells_new.csv')
        gray_cells_file = os.path.join(source_folder, 'gray_cells_new.csv')
        
        if not (os.path.exists(source_cells_file) and os.path.exists(gray_cells_file)):
            print(f"Missing files for {tp}. Skipping this timepoint.")
            continue
        
        source_cells = pd.read_csv(source_cells_file)
        gray_cells = pd.read_csv(gray_cells_file)
        
        if source_cells.empty:
            print(f"No source cells for {tp}. Skipping.")
            continue
        
        full_data[tp] = {
            'source': source_cells,
            'gray': gray_cells
        }
        
        all_x_coords.extend(gray_cells['UMAP_2'])
        all_x_coords.extend(source_cells['UMAP_2'])
        all_y_coords.extend(gray_cells['UMAP_1'])
        all_y_coords.extend(source_cells['UMAP_1'])
        
        cell_counts_source.append(len(source_cells))
        cell_counts_gray.append(len(gray_cells))
    
    if not full_data:
        print("No valid timepoints with data.")
        return None
    
    x_min, x_max = min(all_x_coords), max(all_x_coords)
    y_min, y_max = min(all_y_coords), max(all_y_coords)
    
    min_source_cells = min(cell_counts_source) if cell_counts_source else 0
    min_gray_cells = min(cell_counts_gray) if cell_counts_gray else 0
    if min_source_cells == 0 or min_gray_cells == 0:
        print("Insufficient cells. Skipping visualization.")
        return None
    
    # Prepare data structures for results
    timepoint_results = {}
    distributions_coi = {tp: {cohort_name: {c: {} for c in clusters_of_interest}} for tp in cohort_timepoints}
    
    for i, source_tp in enumerate(cohort_timepoints):
        source_data = full_data[source_tp]
        source_cells = source_data['source']
        
        # Downsample so each timepoint has the same # of source & gray cells for plotting
        sampled_source_cells = source_cells.sample(n=min_source_cells, random_state=42)
        sampled_gray_cells = full_data[source_tp]['gray'].sample(n=min_gray_cells, random_state=42)
        
        timepoint_results[source_tp] = {
            'sampled_source': sampled_source_cells,
            'sampled_gray': sampled_gray_cells
        }
        
        # If not the last timepoint, compute OT from source_tp -> target_tp
        if i < len(cohort_timepoints) - 1:
            target_tp = cohort_timepoints[i+1]
            target_data = full_data[target_tp]
            target_cells = target_data['source']
            
            if target_cells.empty:
                # If there's no data in the next timepoint, skip
                timepoint_results[source_tp]['single_cell_displacements'] = np.array([])
                timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
                timepoint_results[source_tp]['cluster_arrows'] = []
                timepoint_results[source_tp]['aggregated_arrow'] = None
            else:
                # Identify columns for PCs
                pc_cols = [c for c in source_cells.columns if c.startswith('PC_')]
                full_source_coords_pc = source_cells[pc_cols].values
                full_target_coords_pc = target_cells[pc_cols].values
                
                # UMAP coords
                full_source_coords_umap = source_cells[['UMAP_2', 'UMAP_1']].values
                full_target_coords_umap = target_cells[['UMAP_2', 'UMAP_1']].values
                
                # Uniform distribution over source / target
                a = np.ones((full_source_coords_pc.shape[0],)) / full_source_coords_pc.shape[0]
                b = np.ones((full_target_coords_pc.shape[0],)) / full_target_coords_pc.shape[0]
                
                cost_matrix = ot.dist(full_source_coords_pc, full_target_coords_pc, metric='euclidean')
                
                try:
                    transport_plan = ot.emd(a, b, cost_matrix, numItermax=100000)
                except Exception as e:
                    print(f"OT computation failed for {source_tp} to {target_tp}: {e}")
                    timepoint_results[source_tp]['single_cell_displacements'] = np.array([])
                    timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
                    timepoint_results[source_tp]['cluster_arrows'] = []
                    timepoint_results[source_tp]['aggregated_arrow'] = None
                    continue
                
                full_target_indices = np.argmax(transport_plan, axis=1)
                displacement_vectors_full = full_target_coords_umap[full_target_indices] - full_source_coords_umap
                norms_full = np.linalg.norm(displacement_vectors_full, axis=1)
                
                # Displacements for the downsampled subset
                sampled_source_indices = source_cells.index.get_indexer_for(sampled_source_cells.index)
                sampled_displacements = displacement_vectors_full[sampled_source_indices, :]
                sampled_norms = norms_full[sampled_source_indices]
                
                timepoint_results[source_tp]['single_cell_displacements'] = sampled_displacements
                timepoint_results[source_tp]['single_cell_arrow_lengths'] = sampled_norms
                
                if 'seurat_clusters' in source_cells.columns and 'seurat_clusters' in target_cells.columns:
                    # Lump some clusters if needed
                    source_cells['seurat_clusters'] = source_cells['seurat_clusters'].replace({16:14,17:14})
                    target_cells['seurat_clusters'] = target_cells['seurat_clusters'].replace({16:14,17:14})
                    source_cells['seurat_clusters'] = source_cells['seurat_clusters'].replace({9:6,18:6})
                    target_cells['seurat_clusters'] = target_cells['seurat_clusters'].replace({9:6,18:6})
                    
                    # For cluster-level arrows, compute average displacement per cluster
                    source_clusters_full = source_cells['seurat_clusters'].values
                    df_cluster_full = pd.DataFrame({
                        'cluster': source_clusters_full,
                        'sx': full_source_coords_umap[:,0],
                        'sy': full_source_coords_umap[:,1],
                        'dx': displacement_vectors_full[:,0],
                        'dy': displacement_vectors_full[:,1],
                        'norm': norms_full
                    })
                    
                    gray_cells_tp = full_data[source_tp]['gray']
                    all_cells_combined = pd.concat([source_cells, gray_cells_tp])
                    all_cells_combined['seurat_clusters'] = all_cells_combined['seurat_clusters'].replace({16:14,17:14})
                    all_cells_combined['seurat_clusters'] = all_cells_combined['seurat_clusters'].replace({9:6,18:6})
                    
                    all_clusters = all_cells_combined['seurat_clusters'].values
                    all_coords_umap = all_cells_combined[['UMAP_2', 'UMAP_1']].values
                    
                    df_all = pd.DataFrame({
                        'cluster': all_clusters,
                        'sx': all_coords_umap[:,0],
                        'sy': all_coords_umap[:,1]
                    })
                    
                    centroids = df_all.groupby('cluster')[['sx','sy']].median()
                    mean_disp = df_cluster_full.groupby('cluster')[['dx','dy']].mean()
                    
                    cluster_norms = np.sqrt(mean_disp['dx']**2 + mean_disp['dy']**2)
                    
                    # Build arrow info for each cluster of interest
                    cluster_arrows = []
                    common_clusters = centroids.index.intersection(mean_disp.index)
                    for clust in common_clusters:
                        if clust in clusters_of_interest:
                            cx, cy = centroids.loc[clust, ['sx','sy']]
                            cdx, cdy = mean_disp.loc[clust, ['dx','dy']]
                            cnorm = np.sqrt(cdx**2 + cdy**2)
                            if cnorm > 0:
                                cdx /= cnorm
                                cdy /= cnorm
                            length = cluster_norms.loc[clust]
                            cdx *= length
                            cdy *= length
                            cluster_arrows.append((clust, cx, cy, cdx, cdy))
                    
                    timepoint_results[source_tp]['cluster_arrows'] = cluster_arrows
    
                    if len(cluster_arrows) > 0:
                        # "Aggregated" arrow by summation
                        source_median_x = source_cells['UMAP_2'].median()
                        source_median_y = source_cells['UMAP_1'].median()
                        total_dx = sum([arrow[3] for arrow in cluster_arrows])
                        total_dy = sum([arrow[4] for arrow in cluster_arrows])
                        timepoint_results[source_tp]['aggregated_arrow'] = (
                            source_median_x, 
                            source_median_y, 
                            total_dx, 
                            total_dy
                        )
                    else:
                        timepoint_results[source_tp]['aggregated_arrow'] = None
    
                    # Compute target cluster distributions for each COI (source cluster)
                    for coi in clusters_of_interest:
                        coi_mask = (source_cells['seurat_clusters'] == coi)
                        if np.any(coi_mask):
                            # Which target clusters do these COI cells map to?
                            selected_target_indices = full_target_indices[coi_mask]
                            selected_target_clusters = target_cells['seurat_clusters'].iloc[selected_target_indices].values
                            unique_tclusters, counts = np.unique(selected_target_clusters, return_counts=True)
                            total_count = counts.sum()
                            if total_count > 0:
                                fraction_dict = {int(tc): (ct / total_count) for tc, ct in zip(unique_tclusters, counts)}
                            else:
                                fraction_dict = {}
                            distributions_coi[source_tp][cohort_name][coi] = fraction_dict
                        else:
                            distributions_coi[source_tp][cohort_name][coi] = {}
                else:
                    # If cluster info is missing
                    timepoint_results[source_tp]['cluster_arrows'] = []
                    timepoint_results[source_tp]['aggregated_arrow'] = None
                    for coi in clusters_of_interest:
                        distributions_coi[source_tp][cohort_name][coi] = {}
        else:
            # Last timepoint has no "next" timepoint
            timepoint_results[source_tp]['single_cell_displacements'] = np.array([])
            timepoint_results[source_tp]['single_cell_arrow_lengths'] = np.array([])
            timepoint_results[source_tp]['cluster_arrows'] = []
            timepoint_results[source_tp]['aggregated_arrow'] = None
            for coi in clusters_of_interest:
                distributions_coi[source_tp][cohort_name][coi] = {}

    # ---------------------------
    # Generate PDFs of movement
    # ---------------------------

    # 1) Single-cell arrows
    output_file_original = os.path.join(
        output_folder, 
        f"{subpop_name}_{cohort_name}_movement_plots_differential_arrow_lengths_equal_cells_using_PCs.pdf"
    )
    with PdfPages(output_file_original) as pdf_original:
        plot_tps = list(timepoint_results.keys())
        num_timepoints = len(plot_tps)
        fig_orig, axes_orig = plt.subplots(
            1, 
            num_timepoints, 
            figsize=(4 * num_timepoints, 4), 
            sharex=True, 
            sharey=True
        )
        if num_timepoints == 1:
            axes_orig = [axes_orig]
        
        for i, source_tp in enumerate(plot_tps):
            ax = axes_orig[i]
            res = timepoint_results[source_tp]
            sampled_source_cells = res['sampled_source']
            sampled_gray_cells = res['sampled_gray']
            
            gray_x = sampled_gray_cells['UMAP_2'].values
            gray_y = sampled_gray_cells['UMAP_1'].values
            source_x = sampled_source_cells['UMAP_2'].values
            source_y = sampled_source_cells['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)
            
            single_cell_displacements = res['single_cell_displacements']
            single_cell_arrow_lengths = res['single_cell_arrow_lengths']
            
            # ---------------------------
            # Only draw arrows for COI
            # ---------------------------
            if len(single_cell_displacements) > 0:
                # Coordinates for all sampled source cells
                source_coords_sampled = sampled_source_cells[['UMAP_2','UMAP_1']].values
                
                # Compute the unit vectors for each displacement
                sampled_norms = np.linalg.norm(single_cell_displacements, axis=1)
                with np.errstate(divide='ignore', invalid='ignore'):
                    unit_vectors = single_cell_displacements / sampled_norms[:, np.newaxis]
                    unit_vectors[~np.isfinite(unit_vectors)] = 0
                scaled_vectors = unit_vectors * single_cell_arrow_lengths[:, np.newaxis]
                
                # If you have cluster info, filter by clusters_of_interest
                if 'seurat_clusters' in sampled_source_cells.columns:
                    for j in range(len(source_coords_sampled)):
                        # Check if this cell is in one of your clusters of interest
                        current_cluster = sampled_source_cells.iloc[j]['seurat_clusters']
                        if current_cluster not in clusters_of_interest:
                            # Skip drawing arrows if not in COI
                            continue
                        
                        sx, sy = source_coords_sampled[j]
                        dx, dy = scaled_vectors[j]
                        ax.arrow(
                            sx, sy, 
                            dx, dy,
                            color='black', 
                            alpha=0.7, 
                            head_width=0.2, 
                            head_length=0.2,
                            length_includes_head=True, 
                            linewidth=0.5
                        )
                else:
                    # If no cluster column, you can either skip or draw all
                    for j in range(len(source_coords_sampled)):
                        sx, sy = source_coords_sampled[j]
                        dx, dy = scaled_vectors[j]
                        ax.arrow(
                            sx, sy, 
                            dx, dy,
                            color='black', 
                            alpha=0.7, 
                            head_width=0.2, 
                            head_length=0.2,
                            length_includes_head=True, 
                            linewidth=0.5
                        )
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_orig.suptitle(f"{subpop_name} - {cohort_name}", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_original.savefig(fig_orig)
        plt.close(fig_orig)
    
    # 2) Cluster-level arrows
    output_file_cluster = os.path.join(
        output_folder, 
        f"{subpop_name}_{cohort_name}_movement_plots_aggregated_cluster_arrows_equal_cells_using_PCs.pdf"
    )
    with PdfPages(output_file_cluster) as pdf_cluster:
        plot_tps = list(timepoint_results.keys())
        num_timepoints = len(plot_tps)
        fig_clust, axes_clust = plt.subplots(
            1, 
            num_timepoints, 
            figsize=(4 * num_timepoints, 4), 
            sharex=True, 
            sharey=True
        )
        if num_timepoints == 1:
            axes_clust = [axes_clust]
        
        for i, source_tp in enumerate(plot_tps):
            ax = axes_clust[i]
            res = timepoint_results[source_tp]
            
            gray_x = res['sampled_gray']['UMAP_2'].values
            gray_y = res['sampled_gray']['UMAP_1'].values
            source_x = res['sampled_source']['UMAP_2'].values
            source_y = res['sampled_source']['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)
            
            cluster_arrows = res.get('cluster_arrows', [])
            
            source_cells = full_data[source_tp]['source']
            if 'seurat_clusters' in source_cells.columns:
                source_cluster_counts = source_cells['seurat_clusters'].value_counts()
                total_source_cells_current = len(source_cells)
            else:
                source_cluster_counts = pd.Series()
                total_source_cells_current = 1
            
            for (clust, cx, cy, cdx, cdy) in cluster_arrows:
                arrow_color = cluster_colors_map.get(clust, 'black')
                proportion = source_cluster_counts.get(clust, 0) / total_source_cells_current
                line_width = 1.0 + 8.0 * proportion

                # Outline in black for clarity
                ax.arrow(
                    cx, cy,
                    cdx, cdy,
                    color='black',
                    alpha=1,
                    head_width=0.2 + 0.5 * proportion,
                    head_length=0.1 + 0.1 * proportion,
                    length_includes_head=True,
                    linewidth=line_width + 2
                )
                
                # Main arrow in cluster color
                ax.arrow(
                    cx, cy,
                    cdx, cdy,
                    color=arrow_color,
                    alpha=1,
                    head_width=0.2 + 0.5 * proportion,
                    head_length=0.1 + 0.1 * proportion,
                    length_includes_head=True,
                    linewidth=line_width
                )
            
            ax.set_title(f"Timepoint: {source_tp}")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])
        
        fig_clust.suptitle(f"{subpop_name} - {cohort_name} (Cluster-Level Arrows)", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_cluster.savefig(fig_clust)
        plt.close(fig_clust)

    # 3) Single aggregated arrow
    output_file_single_agg = os.path.join(
        output_folder, 
        f"{subpop_name}_{cohort_name}_movement_plots_aggregated_single_arrow_equal_cells_using_PCs.pdf"
    )
    with PdfPages(output_file_single_agg) as pdf_agg:
        plot_tps = list(timepoint_results.keys())
        num_timepoints = len(plot_tps)
        fig_agg, axes_agg = plt.subplots(
            1,
            num_timepoints,
            figsize=(4 * num_timepoints, 4),
            sharex=True,
            sharey=True
        )
        if num_timepoints == 1:
            axes_agg = [axes_agg]

        for i, source_tp in enumerate(plot_tps):
            ax = axes_agg[i]
            res = timepoint_results[source_tp]

            gray_x = res['sampled_gray']['UMAP_2'].values
            gray_y = res['sampled_gray']['UMAP_1'].values
            source_x = res['sampled_source']['UMAP_2'].values
            source_y = res['sampled_source']['UMAP_1'].values
            
            ax.scatter(gray_x, gray_y, color='gray', s=5, alpha=0.5)
            cohort_color = cohort_colors.get(cohort_name, 'red')
            ax.scatter(source_x, source_y, color=cohort_color, s=5, alpha=1)

            aggregated_arrow = res.get('aggregated_arrow', None)
            if aggregated_arrow is not None:
                global_cx, global_cy, total_dx, total_dy = aggregated_arrow
                ax.arrow(
                    global_cx, global_cy,
                    total_dx, total_dy,
                    color='black', 
                    alpha=0.9,
                    head_width=0.3, 
                    head_length=0.3,
                    length_includes_head=True, 
                    linewidth=2.0
                )

            ax.set_title(f"Timepoint: {source_tp} (Single Aggregated Arrow)")
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])

        fig_agg.suptitle(f"{subpop_name} - {cohort_name} (Single Aggregated Arrow)", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        pdf_agg.savefig(fig_agg)
        plt.close(fig_agg)
    
    # Return the distributions for stacked-bar plotting
    return distributions_coi


# ---------------------------
# Main Execution Loop
# ---------------------------
for subpop_name in subpopulations:
    # Prepare a nested dict so we can collect distributions
    distributions_coi_all = {
        tp: {
            c: {coi: {} for coi in clusters_of_interest} 
            for c in cohorts
        }
        for tp in all_timepoints
    }
    
    # Run for each cohort
    for cohort_name in cohorts:
        print(f"Processing {subpop_name} - {cohort_name}")
        input_subfolder = os.path.join(base_input_dir, subpop_name, cohort_name)
        if not os.path.exists(input_subfolder):
            print(f"Input subfolder does not exist: {input_subfolder}. Skipping.")
            continue
        
        dist_coi = optimal_transport_visualization(subpop_name, cohort_name)
        if dist_coi is not None:
            for tp in dist_coi:
                for c in dist_coi[tp]:
                    for coi in dist_coi[tp][c]:
                        distributions_coi_all[tp][c][coi] = dist_coi[tp][c][coi]

    # --------------------------------------------------
    # 1) Group non-COI target clusters into "Others"
    # --------------------------------------------------
    for tp in distributions_coi_all:
        for c in cohorts:
            for coi in clusters_of_interest:
                fraction_dict = distributions_coi_all[tp][c][coi]
                if fraction_dict:
                    new_dict = {}
                    others_sum = 0.0
                    for tclust, frac in fraction_dict.items():
                        # If the target cluster is not in our "source" list, 
                        # put it in the "Others" bin
                        if tclust not in clusters_of_interest:
                            others_sum += frac
                        else:
                            new_dict[tclust] = frac
                    if others_sum > 0:
                        new_dict['Others'] = others_sum
                    distributions_coi_all[tp][c][coi] = new_dict

    # Gather all target clusters (including "Others")
    all_target_clusters = []
    for tp in distributions_coi_all:
        for c in cohorts:
            for coi in clusters_of_interest:
                all_target_clusters.extend(distributions_coi_all[tp][c][coi].keys())
    all_target_clusters = set(all_target_clusters)

    # Handle "Others" separately for sorting
    others_present = ("Others" in all_target_clusters)
    if others_present:
        all_target_clusters.remove("Others")

    # Now only numeric cluster IDs remain
    # Sort numeric clusters in ascending order
    all_target_clusters = sorted(all_target_clusters)

    # Put "Others" at the end if present
    if others_present:
        all_target_clusters.append("Others")

    def get_celltype_name(clust_id):
        if clust_id == 'Others':
            return "Others"
        return cluster_celltype_map.get(clust_id, f"Cluster {clust_id}")

    def get_cluster_color(clust_id):
        if clust_id == 'Others':
            return 'gray'
        return cluster_colors_map.get(clust_id, 'gray')

    # ---------------------------
    # Generate the stacked bar chart PDF
    # ---------------------------
    output_file_stacked = os.path.join(base_output_dir, f"{subpop_name}_target_cluster_distribution_coi_with_stack_labels.pdf")
    with PdfPages(output_file_stacked) as pdf_dist:
        fig, axes = plt.subplots(
            len(clusters_of_interest), 
            len(all_timepoints), 
            figsize=(4 * len(all_timepoints), 3 * len(clusters_of_interest)), 
            sharex=False, 
            sharey=True
        )
        
        # Handle shape for single row/col
        if len(clusters_of_interest) == 1 and len(all_timepoints) == 1:
            axes = np.array([[axes]])
        elif len(clusters_of_interest) == 1:
            axes = axes[np.newaxis, :]
        elif len(all_timepoints) == 1:
            axes = axes[:, np.newaxis]

        # Prepare legend patches
        legend_patches = []
        for tc in all_target_clusters:
            legend_patches.append(
                plt.Rectangle((0,0),1,1,
                              facecolor=get_cluster_color(tc),
                              edgecolor='black',
                              label=get_celltype_name(tc))
            )
        
        # Plot bars
        for row_i, coi in enumerate(clusters_of_interest):
            for col_i, tp in enumerate(all_timepoints):
                ax = axes[row_i, col_i]
                bar_positions = np.arange(len(cohorts))
                bottoms = np.zeros(len(cohorts))

                # Retrieve distribution data for each cohort at (tp, coi)
                cohort_distributions = [distributions_coi_all[tp][c][coi] for c in cohorts]

                # Build up data for stacked bars: (target_cluster, [fractions for each cohort])
                stack_data = []
                for tc in all_target_clusters:
                    h = [dist.get(tc, 0.0) for dist in cohort_distributions]
                    stack_data.append((tc, h))

                # Sort so biggest fraction (summed across cohorts) is at the bottom
                stack_data.sort(key=lambda x: sum(x[1]), reverse=True)

                # Plot
                for (tc, heights) in stack_data:
                    color = get_cluster_color(tc)
                    ax.bar(
                        bar_positions, 
                        heights, 
                        bottom=bottoms, 
                        color=color, 
                        edgecolor='black'
                    )
                    # Label each segment if fraction is big enough
                    for idx, val in enumerate(heights):
                        if val > 0.05:
                            mid_y = bottoms[idx] + val / 2
                            ax.text(
                                bar_positions[idx], 
                                mid_y, 
                                get_celltype_name(tc),
                                ha='center', 
                                va='center',
                                fontsize=6, 
                                color='white'
                            )
                    bottoms += heights

                if row_i == 0:
                    ax.set_title(f"{tp}", fontsize=10)

                ax.set_xticks(bar_positions)
                ax.set_xticklabels(cohorts, rotation=45, ha='right', fontsize=8)
                ax.set_ylim(0, 1)

        fig.suptitle(f"Distribution of Target Clusters per COI - {subpop_name}", fontsize=16)
        plt.tight_layout()
        
        # Adjust to make room for row labels & legend
        fig.subplots_adjust(left=0.15, top=0.86, right=0.82)  # extra space on the right for legend

        # Label rows with the celltype of the COI
        n_rows = len(clusters_of_interest)
        for row_i, coi in enumerate(clusters_of_interest):
            y_pos = 0.86 - (row_i + 0.5)*(0.86-0.1)/n_rows
            fig.text(0.05, y_pos, get_celltype_name(coi), va='center', ha='right', fontsize=10, color='black')

        # Create a single-column legend on the right
        fig.legend(
            handles=legend_patches,
            loc='upper left',
            bbox_to_anchor=(0.84, 0.95),
            ncol=1,
            fontsize=8,
            title="Target Celltypes"
        )

        pdf_dist.savefig(fig)
        plt.close(fig)

print("Visualization generation completed.")

Processing Effector_CD8 - control
Processing Effector_CD8 - short_term
Processing Effector_CD8 - long_term
Processing Memory_Precursor_Effector_CD8 - control
Processing Memory_Precursor_Effector_CD8 - short_term
Processing Memory_Precursor_Effector_CD8 - long_term
Processing Exhausted_T - control
Processing Exhausted_T - short_term
Processing Exhausted_T - long_term


  result_code_string = check_result(result_code)


Processing Central_Memory_CD8 - control
Processing Central_Memory_CD8 - short_term
Processing Central_Memory_CD8 - long_term
Processing Stem_Like_CD8 - control
Processing Stem_Like_CD8 - short_term
Processing Stem_Like_CD8 - long_term


  result_code_string = check_result(result_code)


Processing Effector_Memory_CD8 - control
Processing Effector_Memory_CD8 - short_term
Processing Effector_Memory_CD8 - long_term
Processing Proliferating_Effector - control
Processing Proliferating_Effector - short_term
Processing Proliferating_Effector - long_term
Processing All_CD8 - control
Processing All_CD8 - short_term
Processing All_CD8 - long_term


  result_code_string = check_result(result_code)
  result_code_string = check_result(result_code)


Visualization generation completed.
