# GPA for transportation of gene expression datasets

This notebook aims to apply GPA (our baseline model) to recover the trajectory of an empirical (EMT) gene expression dataset with unknown dynamics.
We first reduce the dimensionality of the original data (175) to an accessible dimension through PCA, then apply GPA in the latent space, and reconstruct the result to the original high-dimensional space.

## Dataset

* 6 snapshots at different timepoints (day 0, 1, 2, 3, 4, 8)
* sample size: differs from day to day
* dimension: 175
* The trajectory will be visualized in 2D principal component axes

## Functionality

* Given source and target days, GPA transports the source dataset toward the target dataset using deterministic particle dynamics for the gradient flow of the Lipschitz-regularized (i.e. limited transportation speed) KL divergence.
$$D_{KL}^L(P\|Q) = \sup_{\| \nabla \phi \| \leq L} \left \{\mathbb{E}_P[\phi] - \log \mathbb{E}_Q [\exp(\phi)] \right \}$$
$$\partial_t P + \nabla \cdot \left(P v\right) = \partial_t P - \nabla \cdot \left(P \nabla \phi \right) = 0$$
$$\dot{X} = - \nabla \phi_t(X)$$

* Snapshots at all the intermediate days are highlighted by default, but the set of intermediate days can be specified explicitly.
* $W_2$ distance will be calculated between the particle trajectory and the snapshots.


In [None]:
import tensorflow as tf
# Output folder for figures and legends; `main_dir` must already be defined by an earlier cell.
result_dir = f'{main_dir}/assets/Transport_genes/'

def load_W(filename):
    """Load pickled network parameters and wrap them as float32 TensorFlow variables.

    The pickle file is expected to contain a 3-item sequence ``(W, b, p)``.
    Every entry of ``W`` and of ``b`` is converted to a ``tf.Variable`` with
    dtype ``tf.float32``; ``p`` is handed back untouched (its meaning is not
    visible from this function).

    Returns a tuple ``(W, b, p)`` where ``W`` and ``b`` are lists of variables.
    Relies on ``pk`` (pickle) and ``tf`` being available in the notebook.
    """
    def _as_variable(array):
        return tf.Variable(array, dtype=tf.float32)

    with open(filename, "rb") as fr:
        raw_weights, raw_biases, p = pk.load(fr)

    weights = list(map(_as_variable, raw_weights))
    biases = list(map(_as_variable, raw_biases))
    return weights, biases, p

def v(x, t, W, b):   # neural network for the time-dependent vector field
    """Evaluate the fully connected tanh network at positions ``x`` and time ``t``.

    ``t`` is broadcast into one extra column (one value per row of ``x``) and
    appended to ``x`` before the first layer.  All layers except the last apply
    ``tanh``; the last layer is affine only.

    ``W`` and ``b`` are indexable sequences of weight matrices and biases with
    one entry per layer.  Returns the output of the final layer.
    """
    time_column = t * tf.ones([x.shape[0], 1], dtype=tf.float32)
    hidden = tf.concat([x, time_column], axis=1)

    # hidden layers: every layer but the last one
    for layer, weight in enumerate(W[:-1]):
        pre_activation = tf.add(tf.matmul(hidden, weight), b[layer])
        hidden = tf.nn.tanh(pre_activation)

    # output layer has no activation
    return tf.add(tf.matmul(hidden, W[-1]), b[-1])

def time_integration(x0, T, dt):
    """Integrate the learned vector field with explicit Euler steps.

    Starting from ``x0``, repeatedly applies ``x <- x + dt * v(x, t, W, b)``
    with ``t = dt * i`` for step index ``i``.

    NOTE: the network parameters ``W`` and ``b`` (and the function ``v``) are
    read from the notebook's global scope; they are not arguments.

    Parameters
    ----------
    x0 : array-like
        Initial particle positions; converted to a float32 tensor.
    T : float
        Total integration time.
    dt : float
        Step size.

    Returns
    -------
    list
        ``[x0, x_1, ..., x_n]`` where the first entry is ``x0`` exactly as
        passed in and the remaining entries are NumPy arrays, one per step.
    """
    import math

    # int(T / dt) alone truncates ratios that are integers up to rounding
    # error (e.g. 0.3 / 0.1 == 2.999...), which silently drops the last step.
    # Snap to the nearest integer when the ratio is within float tolerance of
    # it; otherwise keep the original truncating behaviour.
    ratio = float(T / dt)
    nearest = round(ratio)
    if math.isclose(ratio, nearest, rel_tol=1e-9, abs_tol=1e-12):
        num_steps = int(nearest)
    else:
        num_steps = int(ratio)

    x = tf.constant(x0, dtype=tf.float32)
    xs = [x0]
    for i in range(num_steps):
        vv = v(x, dt*i, W, b)
        x = x + dt * vv
        xs.append(x.numpy())
    return xs


def generate_animation(days, intermediate_days, X1_trpts, dt, physical_dt, img_src, d_red = 2, vs = None):
    """Animate the transported particles frame by frame and save the result as a GIF.

    Every array in ``X1_trpts`` becomes one frame, drawn from its first two
    columns.  Snapshots of the target and intermediate days are projected
    with the pickled PCA object and drawn as a static backdrop.

    Relies on names defined elsewhere in the notebook: ``dim_red_method``,
    ``data_dir``, ``mats``, ``pk``, ``np``, ``plt``, ``animation``,
    ``display`` and ``Image``.

    Parameters
    ----------
    days : sequence
        ``days[0]`` is the source day (it colours the moving particles) and
        ``days[1:]`` are drawn as target snapshots.  Each of them must be a
        key of the colour table below (0, 1, 2, 3, 4 or 8).
    intermediate_days : iterable
        Days whose snapshots are drawn in light gray.
    X1_trpts : sequence of arrays
        Particle positions per frame.  Frame generation stops at the first
        array that contains a NaN.
    dt : float
        Step used to turn consecutive frames into arrows,
        ``(next - current) / dt``.
    physical_dt : float
        Time per frame, used for the "t = ..." caption.
    img_src : str
        Path of the GIF to write; the same file is displayed afterwards.
    d_red : int
        PCA dimension, used only to select the pickle file name.
    vs : optional
        Only compared against ``None``.  When given, every frame except the
        last is drawn as quiver arrows computed from finite differences of
        ``X1_trpts``; the contents of ``vs`` are not read.

    Returns
    -------
    None
        Also returns early (without writing a file) when ``dim_red_method``
        is unsupported or when no frame could be drawn.
    """
    # load PCA
    if dim_red_method == 'EMT_PCA':
        pca_filename = "emt_pca_%d.pkl" % d_red
    elif dim_red_method == 'PCA':
        pca_filename = "pca_%d.pkl" % d_red
    else:
        print("PCA mapping for the reduction method and dimension is not available")
        # Without this return the code below fails on the unbound pca_filename.
        return

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    fig, ax = plt.subplots()
    ims = []

    contrast_colors = [
        '#1f77b4',  # blue
        '#2ca02c',  # green
        '#ff7f0e',  # orange
        '#8c564b',  # brown
        '#d62728',  # red
        '#9467bd',  # purple (used for day 8)
    ]

    # Colour assigned to each supported day.
    colors = dict(zip((0, 1, 2, 3, 4, 8), contrast_colors))

    # The PCA projection of a snapshot does not depend on the frame, so
    # compute it once per day instead of once per frame.
    projected = {}

    def _project(day):
        if day not in projected:
            projected[day] = pca.transform(mats[day])
        return projected[day]

    last_index = len(X1_trpts) - 1
    for i, X1_trpt in enumerate(X1_trpts):  # trajectories
        if np.isnan(X1_trpt).any():
            break
        X1_trpt_vis = X1_trpt

        if vs is not None and i < last_index:
            # arrows from the current frame to the next one
            vs_vis = (X1_trpts[i+1] - X1_trpt_vis) / dt
            im = ax.quiver(X1_trpt_vis[:, 0], X1_trpt_vis[:, 1], vs_vis[:, 0], vs_vis[:, 1],
                           width=0.003, headwidth=7, headlength=15, headaxislength=7, zorder=15)
        else:
            im = ax.scatter(X1_trpt_vis[:,0], X1_trpt_vis[:,1], color=colors[days[0]],
                            alpha=1.0, s=0.7, zorder=10, label=f'day {days[0]}') # transported source
            # NOTE(review): the backdrop below is re-added on every frame and is
            # not part of the per-frame artist list; kept as in the original so
            # the rendered output does not change.
            for t in days[1:]:
                X2_vis = _project(t)
                ax.scatter(X2_vis[:,0], X2_vis[:,1], color=colors[t],
                   alpha=1.0, s=0.7, zorder=5, label=f'day {t}') # target

            for t in intermediate_days:
                X1_intermediate_vis = _project(t)
                ax.scatter(X1_intermediate_vis[:,0], X1_intermediate_vis[:,1], color='lightgray',
                        alpha=0.3, s=0.7, zorder=1, label=f'day {t}')

        ttl = ax.text(0.5,1.05, "t = %.3f" % (physical_dt*i),
                      bbox={'facecolor':'w', 'alpha':0.5, 'pad':5},
                      transform=ax.transAxes, ha="center")
        ims.append([im, ttl])

    if not ims:
        # Empty trajectory, or NaN already in the first frame: nothing to save.
        print("No valid frames to animate")
        plt.close(fig)
        return

    ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=200)
    writergif = animation.PillowWriter(fps=3)
    ani.save(img_src, writer=writergif)
    # Close the figure (instead of only clearing it) so it is released.
    plt.close(fig)
    display(Image(filename = img_src))



    

In [None]:
## Static plot function for piecewise GPA (sample 3 - stem cell and sample 5 - synthetic)


## Static plot function for piecewise GPA

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import matplotlib.colors as mcolors

def generate_static_trajectory_plots_three_timepoints(days, intermediate_days, X1_trpts, mats, d_red=26, output_file_with_snapshots=None, output_file_without_snapshots=None):
    """
    Draw the piecewise-GPA result as static PC 1 / PC 2 scatter figures.

    Figures produced, in this order:
    1. The three training days, the held-out days and every frame of
       ``X1_trpts`` (viridis gradient by frame index, with a colour bar).
    2. The same training and held-out days without the transported frames.
    3. A legend-only figure built from the labels of figure 1, always saved
       as ``legend_only.png`` inside ``result_dir``.

    ``days[0]``, ``days[1]`` and ``days[-1]`` are plotted as training data.
    Only ``intermediate_days[0]`` and ``intermediate_days[1]`` are given a
    colour, so at least two intermediate days are required.
    ``mats[t]`` is projected with the pickled PCA object; the frames in
    ``X1_trpts`` are plotted as-is from their first two columns, and frames
    containing NaN are skipped.

    Figures 1 and 2 are saved to the given output paths (dpi 300) and closed,
    or shown with ``plt.show()`` when the path is falsy.

    Uses notebook globals: ``dim_red_method``, ``data_dir``, ``pk``,
    ``physical_dt``, ``result_dir`` and ``os``.  Returns ``None``; returns
    early when ``dim_red_method`` is neither 'EMT_PCA' nor 'PCA'.
    """
    # --- PCA object -------------------------------------------------------
    if dim_red_method not in ('EMT_PCA', 'PCA'):
        print("PCA mapping for the reduction method and dimension is not available")
        return
    pca_filename = f"emt_pca_{d_red}.pkl" if dim_red_method == 'EMT_PCA' else f"pca_{d_red}.pkl"

    with open(data_dir + pca_filename, "rb") as fr:
        (pca,) = pk.load(fr)

    # --- colour scale for the transported frames ---------------------------
    n_frames = len(X1_trpts)
    cmap = cm.viridis  # Can change to "plasma", "inferno", etc.
    frame_times = np.linspace(0, physical_dt * n_frames, n_frames)
    scalar_map = cm.ScalarMappable(
        cmap=cmap,
        norm=mcolors.Normalize(vmin=frame_times.min(), vmax=frame_times.max()),
    )
    scalar_map.set_array([])  # required before the mappable can feed a colour bar

    # --- colours of the measured days --------------------------------------
    source_t, middle_t, target_t = days[0], days[1], days[-1]
    color_map = dict(zip(
        (source_t, intermediate_days[0], middle_t, intermediate_days[1], target_t),
        ('#1f77b4',   # blue
         '#2ca02c',   # green
         '#ff7f0e',   # orange
         '#8c564b',   # brown
         '#d62728'),  # red
    ))

    def _draw_training_days(ax, projections):
        # filled markers for the days labelled as training data
        for t, pts in projections:
            ax.scatter(pts[:, 0], pts[:, 1], color=color_map[t], alpha=1.0, s=12, zorder=10, label=f'Time {t} (Training Data)')

    def _draw_test_days(ax):
        # hollow markers for the held-out days
        for t in intermediate_days:
            pts = pca.transform(mats[t])
            ax.scatter(pts[:, 0], pts[:, 1], color=color_map[t], facecolors='none', edgecolors=color_map[t], linewidths=1.2, alpha=1.0, s=15, zorder=20, label=f'Time {t} (Test Data)')

    def _style_axes(ax):
        ax.set_xlabel("PC 1", fontsize=24)
        ax.set_ylabel("PC 2", fontsize=24)
        ax.tick_params(axis='both', which='major', labelsize=24)
        ax.set_title("")

    # --- figure 1: measured days plus transported frames -------------------
    fig1, ax1 = plt.subplots(figsize=(8, 6))

    training = [(t, pca.transform(mats[t])) for t in (source_t, middle_t, target_t)]
    _draw_training_days(ax1, training)
    _draw_test_days(ax1)

    for idx, frame in enumerate(X1_trpts):
        if np.isnan(frame).any():
            continue
        ax1.scatter(frame[:, 0], frame[:, 1], color=cmap(idx / n_frames), alpha=0.75, s=5, zorder=1)

    # colour bar in an inset axes just right of the plot: [x, y, width, height]
    cbar = plt.colorbar(scalar_map, cax=ax1.inset_axes([1.02, 0.2, 0.03, 0.6]))
    cbar.set_ticks(np.linspace(0, 4, 5))  # ticks fixed at 0, 1, 2, 3, 4
    cbar.set_ticklabels([0, 1, 2, 3, 4])
    cbar.set_label("Time", fontsize=24)
    cbar.ax.tick_params(labelsize=24)

    _style_axes(ax1)

    if not output_file_with_snapshots:
        plt.show()
    else:
        plt.savefig(output_file_with_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITH snapshots saved to {output_file_with_snapshots}")
        plt.close(fig1)

    # --- figure 2: measured days only ---------------------------------------
    fig2, ax2 = plt.subplots(figsize=(8, 6))
    _draw_training_days(ax2, training)
    _draw_test_days(ax2)
    _style_axes(ax2)

    if not output_file_without_snapshots:
        plt.show()
    else:
        plt.savefig(output_file_without_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITHOUT snapshots saved to {output_file_without_snapshots}")
        plt.close(fig2)

    # --- figure 3: standalone legend taken from figure 1 --------------------
    handles, labels = ax1.get_legend_handles_labels()

    def _time_of(label):
        # second whitespace-separated token of "Time X (...)"; non-integers go last
        try:
            return int(label.split(" ")[1])
        except ValueError:
            return float('inf')

    sorted_labels = sorted(labels, key=_time_of)
    sorted_handles = [handles[labels.index(lbl)] for lbl in sorted_labels]

    # enlarge markers of line handles, if any are present
    for entry in sorted_handles:
        if isinstance(entry, plt.Line2D):
            entry.set_markersize(30)

    fig_legend, ax_legend = plt.subplots(figsize=(10, 2))
    ax_legend.axis("off")

    legend = ax_legend.legend(
        sorted_handles,
        sorted_labels,
        fontsize=20,
        loc='center',
        ncol=len(sorted_labels),   # all entries on a single row
        markerscale=2,             # legend markers at twice the plotted size
        handlelength=1.5,
        handletextpad=0.2,
    )

    legend_path = os.path.join(result_dir, "legend_only.png")
    fig_legend.savefig(legend_path, bbox_inches="tight")
    plt.close(fig_legend)

    print(f"Legend saved separately at: {legend_path}")



In [None]:
## Static plot function for simple GPA (Sample 1 - EMT data)


import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import matplotlib.colors as mcolors

def generate_static_trajectory_plots_two_timepoints(days, intermediate_days, X1_trpts, mats, d_red=26, output_file_with_snapshots=None, output_file_without_snapshots=None, output_file_snapshots_only=None):
    """
    Draw the simple (source -> target) GPA result as static PC 1 / PC 2 figures.

    Figures produced, in this order:
    1. The held-out intermediate days (hollow markers) together with every
       frame of ``X1_trpts`` (viridis gradient by frame index, colour bar).
    2. The source day ``days[0]`` and the target day ``days[-1]`` only.
    3. A legend-only figure built from the labelled artists of figure 1,
       always saved as ``legend_only.png`` inside ``result_dir``.

    Only ``intermediate_days[0]`` is given a colour of its own, so any other
    intermediate day must coincide with an already coloured day.
    ``days`` needs at least two entries because ``days[1]`` is read, although
    it is not plotted.  ``output_file_snapshots_only`` is accepted but unused.
    Frames of ``X1_trpts`` containing NaN are skipped.

    Figures 1 and 2 are saved to the given output paths (dpi 300) and closed,
    or shown with ``plt.show()`` when the path is falsy.

    Uses notebook globals: ``dim_red_method``, ``data_dir``, ``pk``,
    ``physical_dt``, ``result_dir`` and ``os``.  Returns ``None``; returns
    early when ``dim_red_method`` is neither 'EMT_PCA' nor 'PCA'.
    """
    # --- PCA object -------------------------------------------------------
    if dim_red_method not in ('EMT_PCA', 'PCA'):
        print("PCA mapping for the reduction method and dimension is not available")
        return
    pca_filename = f"emt_pca_{d_red}.pkl" if dim_red_method == 'EMT_PCA' else f"pca_{d_red}.pkl"

    with open(data_dir + pca_filename, "rb") as fr:
        (pca,) = pk.load(fr)

    # --- colour scale for the transported frames ---------------------------
    n_frames = len(X1_trpts)
    cmap = cm.viridis  # Can change to "plasma", "inferno", etc.
    frame_times = np.linspace(0, physical_dt * n_frames, n_frames)
    scalar_map = cm.ScalarMappable(
        cmap=cmap,
        norm=mcolors.Normalize(vmin=frame_times.min(), vmax=frame_times.max()),
    )
    scalar_map.set_array([])  # required before the mappable can feed a colour bar

    # --- colours of the measured days --------------------------------------
    source_t, _middle_t, target_t = days[0], days[1], days[-1]
    color_map = dict(zip(
        (source_t, intermediate_days[0], target_t),
        ('#1f77b4',   # blue
         '#ff7f0e',   # orange
         '#d62728'),  # red
    ))

    # --- figure 1: held-out days plus transported frames -------------------
    fig1, ax1 = plt.subplots(figsize=(8, 6))

    # projected here, drawn only in figure 2
    source_pts = pca.transform(mats[source_t])
    target_pts = pca.transform(mats[target_t])

    for t in intermediate_days:
        held_out = pca.transform(mats[t])
        ax1.scatter(held_out[:, 0], held_out[:, 1], color=color_map[t], facecolors='none', edgecolors=color_map[t], linewidths=1.0, alpha=0.75, s=10, zorder=20, label=f'Day {t} (Test Data)')

    for idx, frame in enumerate(X1_trpts):
        if np.isnan(frame).any():
            continue
        ax1.scatter(frame[:, 0], frame[:, 1], color=cmap(idx / n_frames), alpha=0.75, s=2, zorder=1)

    # colour bar in an inset axes just right of the plot: [x, y, width, height]
    cbar = plt.colorbar(scalar_map, cax=ax1.inset_axes([1.02, 0.2, 0.03, 0.6]))
    cbar.set_ticks(np.linspace(0, 4, 5))  # ticks fixed at 0, 1, 2, 3, 4
    cbar.set_ticklabels([0, 1, 2, 3, 4])
    cbar.set_label("Time", fontsize=27)
    cbar.ax.tick_params(labelsize=27)

    ax1.set_xlabel("PC 1", fontsize=27)
    ax1.set_ylabel("PC 2", fontsize=27)
    ax1.tick_params(axis='both', which='major', labelsize=30)
    ax1.set_title("")

    if not output_file_with_snapshots:
        plt.show()
    else:
        plt.savefig(output_file_with_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITH snapshots saved to {output_file_with_snapshots}")
        plt.close(fig1)

    # --- figure 2: source and target days only ------------------------------
    fig2, ax2 = plt.subplots(figsize=(8, 6))

    ax2.scatter(source_pts[:, 0], source_pts[:, 1], color=color_map[source_t], alpha=1.0, s=8, zorder=15, label=f'Time {source_t} (Training Data)')
    ax2.scatter(target_pts[:, 0], target_pts[:, 1], color=color_map[target_t], alpha=1.0, s=8, zorder=10, label=f'Time {target_t} (Training Data)')

    ax2.set_xlabel("PC 1", fontsize=27)
    ax2.set_ylabel("PC 2", fontsize=27)
    ax2.tick_params(axis='both', which='major', labelsize=27)
    ax2.set_title("")

    if not output_file_without_snapshots:
        plt.show()
    else:
        plt.savefig(output_file_without_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITHOUT snapshots saved to {output_file_without_snapshots}")
        plt.close(fig2)

    # --- figure 3: standalone legend taken from figure 1 --------------------
    handles, labels = ax1.get_legend_handles_labels()

    def _time_of(label):
        # second whitespace-separated token of the label; non-integers go last
        try:
            return int(label.split(" ")[1])
        except ValueError:
            return float('inf')

    sorted_labels = sorted(labels, key=_time_of)
    sorted_handles = [handles[labels.index(lbl)] for lbl in sorted_labels]

    # enlarge markers of line handles, if any are present
    for entry in sorted_handles:
        if isinstance(entry, plt.Line2D):
            entry.set_markersize(30)

    fig_legend, ax_legend = plt.subplots(figsize=(10, 2))
    ax_legend.axis("off")

    # single-row legend, markers drawn at twice the plotted size
    legend = ax_legend.legend(
        sorted_handles,
        sorted_labels,
        fontsize=20,
        loc='center',
        ncol=len(sorted_labels),
        markerscale=2,
    )

    legend_path = os.path.join(result_dir, "legend_only.png")
    fig_legend.savefig(legend_path, bbox_inches="tight")
    plt.close(fig_legend)

    print(f"Legend saved separately at: {legend_path}")





    






In [None]:
## Static plot function for simple GPA (NDPR and Clinical data)


import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import matplotlib.colors as mcolors

def generate_static_trajectory_plots_two_timepoints_no_middle(days, intermediate_days, X1_trpts, mats, d_red=26, output_file_with_snapshots=None, output_file_without_snapshots=None, output_file_snapshots_only=None):
    """
    Draw the simple GPA result without any held-out days.

    Figures produced, in this order:
    1. Every frame of ``X1_trpts`` (viridis gradient by frame index) with a
       colour bar; no measured data is drawn on this figure.
    2. The source day ``days[0]`` and the target day ``days[-1]``.
    3. A legend-only figure saved as ``legend_only.png`` inside
       ``result_dir``, but only if figure 1 has labelled artists.  As written,
       nothing on figure 1 carries a label, so this step reports that no
       legend elements were found and is skipped.

    ``intermediate_days`` and ``output_file_snapshots_only`` are accepted but
    unused.  ``days`` needs at least two entries because ``days[1]`` is read,
    although it is not plotted.  Frames of ``X1_trpts`` containing NaN are
    skipped.

    Figures 1 and 2 are saved to the given output paths (dpi 300) and closed,
    or shown with ``plt.show()`` when the path is falsy.

    Uses notebook globals: ``dim_red_method``, ``data_dir``, ``pk``,
    ``physical_dt``, ``result_dir`` and ``os``.  Returns ``None``; returns
    early when ``dim_red_method`` is neither 'EMT_PCA' nor 'PCA'.
    """
    # --- PCA object -------------------------------------------------------
    if dim_red_method not in ('EMT_PCA', 'PCA'):
        print("PCA mapping for the reduction method and dimension is not available")
        return
    pca_filename = f"emt_pca_{d_red}.pkl" if dim_red_method == 'EMT_PCA' else f"pca_{d_red}.pkl"

    with open(data_dir + pca_filename, "rb") as fr:
        (pca,) = pk.load(fr)

    # --- colour scale for the transported frames ---------------------------
    n_frames = len(X1_trpts)
    cmap = cm.viridis  # Can change to "plasma", "inferno", etc.
    frame_times = np.linspace(0, physical_dt * n_frames, n_frames)
    scalar_map = cm.ScalarMappable(
        cmap=cmap,
        norm=mcolors.Normalize(vmin=frame_times.min(), vmax=frame_times.max()),
    )
    scalar_map.set_array([])  # required before the mappable can feed a colour bar

    # --- colours of the measured days --------------------------------------
    source_t, _middle_t, target_t = days[0], days[1], days[-1]
    color_map = dict(zip(
        (source_t, target_t),
        ('#1f77b4',   # blue
         '#d62728'),  # red
    ))

    # --- figure 1: transported frames only ----------------------------------
    fig1, ax1 = plt.subplots(figsize=(8, 6))

    # projected here, drawn only in figure 2
    source_pts = pca.transform(mats[source_t])
    target_pts = pca.transform(mats[target_t])

    for idx, frame in enumerate(X1_trpts):
        if np.isnan(frame).any():
            continue
        ax1.scatter(frame[:, 0], frame[:, 1], color=cmap(idx / n_frames), alpha=0.75, s=2, zorder=1)

    # colour bar in an inset axes just right of the plot: [x, y, width, height]
    cbar = plt.colorbar(scalar_map, cax=ax1.inset_axes([1.02, 0.2, 0.03, 0.6]))
    cbar.set_ticks(np.linspace(0, 4, 5))  # ticks fixed at 0, 1, 2, 3, 4
    cbar.set_ticklabels([0, 1, 2, 3, 4])
    cbar.set_label("Time", fontsize=20)
    cbar.ax.tick_params(labelsize=20)

    def _style_axes(ax):
        ax.set_xlabel("PC 1", fontsize=20)
        ax.set_ylabel("PC 2", fontsize=20)
        ax.tick_params(axis='both', which='major', labelsize=20)
        ax.set_title("")

    _style_axes(ax1)

    if not output_file_with_snapshots:
        plt.show()
    else:
        plt.savefig(output_file_with_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITH snapshots saved to {output_file_with_snapshots}")
        plt.close(fig1)

    # --- figure 2: source and target days only ------------------------------
    fig2, ax2 = plt.subplots(figsize=(8, 6))

    ax2.scatter(source_pts[:, 0], source_pts[:, 1], color=color_map[source_t], alpha=1.0, s=8, zorder=15, label=f'Time {source_t} (Training Data)')
    ax2.scatter(target_pts[:, 0], target_pts[:, 1], color=color_map[target_t], alpha=1.0, s=8, zorder=10, label=f'Time {target_t} (Training Data)')

    _style_axes(ax2)

    if not output_file_without_snapshots:
        plt.show()
    else:
        plt.savefig(output_file_without_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITHOUT snapshots saved to {output_file_without_snapshots}")
        plt.close(fig2)

    # --- figure 3: standalone legend taken from figure 1 --------------------
    handles, labels = ax1.get_legend_handles_labels()

    if not (handles and labels):
        print("No legend elements found — skipping separate legend plot.")
        return

    def _time_of(label):
        # second whitespace-separated token of the label; non-integers go last
        try:
            return int(label.split(" ")[1])
        except ValueError:
            return float('inf')

    sorted_labels = sorted(labels, key=_time_of)
    sorted_handles = [handles[labels.index(lbl)] for lbl in sorted_labels]

    # enlarge markers of line handles, if any are present
    for entry in sorted_handles:
        if isinstance(entry, plt.Line2D):
            entry.set_markersize(30)

    fig_legend, ax_legend = plt.subplots(figsize=(10, 2))
    ax_legend.axis("off")

    # single-row legend with strongly enlarged scatter markers
    legend = ax_legend.legend(
        sorted_handles,
        sorted_labels,
        fontsize=20,
        loc='center',
        ncol=len(sorted_labels),
        markerscale=6,
    )

    legend_path = os.path.join(result_dir, "legend_only.png")
    fig_legend.savefig(legend_path, bbox_inches="tight")
    plt.close(fig_legend)

    print(f"Legend saved separately at: {legend_path}")
    





    






In [None]:
## Static plot function for simple GPA (NDPR and Clinical data) - separate legend


import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import numpy as np
import os
import pickle as pk

def generate_static_trajectory_plots_two_timepoints_no_middle_legend(
    days, intermediate_days, X1_trpts, mats, d_red=26,
    output_file_with_snapshots=None,
    output_file_without_snapshots=None,
    output_file_snapshots_only=None
):
    """
    Draw static trajectory figures for a source/target (two time point) run.

    Produces up to four figures:
      1. every non-NaN snapshot in ``X1_trpts`` scattered on its first two
         columns, colored along the viridis colormap by snapshot index;
      2. a standalone horizontal colorbar, always saved to
         ``result_dir/trajectory_colorbar_only.png``;
      3. the source and target data (``mats[days[0]]`` and ``mats[days[-1]]``)
         projected with the pickled PCA;
      4. a standalone legend for figure 3, saved to
         ``result_dir/legend_only.png``.

    Relies on the notebook globals ``dim_red_method``, ``data_dir``,
    ``result_dir`` and ``physical_dt``.

    Parameters
    ----------
    days : sequence
        Keys into ``mats``; the first entry is the source, the last the target.
    intermediate_days : sequence
        Unused; accepted so the signature matches the sibling plot functions.
    X1_trpts : sequence of arrays
        Trajectory snapshots; columns 0 and 1 are plotted as PC 1 / PC 2.
    mats : mapping or sequence
        Data indexed by day, passed to ``pca.transform``.
    d_red : int
        Reduced dimension, used only to pick the PCA pickle file name.
    output_file_with_snapshots, output_file_without_snapshots : str or None
        Where to save figures 1 and 3; when falsy the figure is shown instead.
    output_file_snapshots_only : str or None
        Unused; accepted for interface compatibility.

    Returns
    -------
    None. Returns early (after printing a message) when ``dim_red_method``
    is not one of ``'EMT_PCA'`` / ``'PCA'``.
    """
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        print("PCA mapping for the reduction method and dimension is not available")
        return

    # NOTE(review): pickle executes arbitrary code on load; only use trusted files.
    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    num_snapshots = len(X1_trpts)
    colormap = cm.viridis
    snapshot_colors = [colormap(i / num_snapshots) for i in range(num_snapshots)]

    # Rescale time values for the color bar
    time_values = np.linspace(0, physical_dt * num_snapshots, num_snapshots)
    norm = mcolors.Normalize(vmin=time_values.min(), vmax=time_values.max())
    sm = cm.ScalarMappable(cmap=colormap, norm=norm)
    sm.set_array([])

    # Only the endpoints are used, so do not require an intermediate entry.
    source_t, target_t = days[0], days[-1]

    color_map = {
        source_t: 'magenta',  # source (pre-treatment) points
        target_t: '#008080'   # teal, target (post-treatment) points
    }

    X1_vis = pca.transform(mats[source_t])
    X2_vis = pca.transform(mats[target_t])

    # --- Plot 1: WITH snapshots (No inline colorbar) ---
    fig1, ax1 = plt.subplots(figsize=(8, 6))
    for i, X1_trpt in enumerate(X1_trpts):
        if np.isnan(X1_trpt).any():
            # skip snapshots containing NaN values
            continue
        ax1.scatter(X1_trpt[:, 0], X1_trpt[:, 1], color=snapshot_colors[i], alpha=0.75, s=2, zorder=1)

    ax1.set_xlabel("PC 1", fontsize=32)
    ax1.set_ylabel("PC 2", fontsize=32)
    ax1.tick_params(axis='both', labelsize=32)
    ax1.set_title("")

    if output_file_with_snapshots:
        # Save through the figure handle rather than the implicit current figure.
        fig1.savefig(output_file_with_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITH snapshots saved to {output_file_with_snapshots}")
        plt.close(fig1)
    else:
        plt.show()

    # --- Save colorbar separately ---
    fig_cb, ax_cb = plt.subplots(figsize=(10, 1))
    cb = fig_cb.colorbar(sm, cax=ax_cb, orientation='horizontal')
    # Pin the two labels to the actual ends of the color scale instead of the
    # hard-coded [0, 4], which is only right when the time axis spans 0..4.
    cb.set_ticks([norm.vmin, norm.vmax])
    cb.set_ticklabels(["Pre-treatment", "Post-treatment"])
    cb.ax.tick_params(labelsize=24)
    cb.set_label("Time", fontsize=24)
    cb_path = os.path.join(result_dir, "trajectory_colorbar_only.png")
    fig_cb.savefig(cb_path, dpi=300, bbox_inches='tight')
    plt.close(fig_cb)
    print(f"Standalone colorbar saved to {cb_path}")

    # --- Plot 2: WITHOUT snapshots ---
    fig2, ax2 = plt.subplots(figsize=(8, 6))
    ax2.scatter(X1_vis[:, 0], X1_vis[:, 1], color=color_map[source_t], alpha=1.0, s=8, zorder=15, label='Pre-treatment')
    ax2.scatter(X2_vis[:, 0], X2_vis[:, 1], color=color_map[target_t], alpha=1.0, s=8, zorder=10, label='Post-treatment')

    ax2.set_xlabel("PC 1", fontsize=32)
    ax2.set_ylabel("PC 2", fontsize=32)
    ax2.tick_params(axis='both', labelsize=32)
    ax2.set_title("")

    if output_file_without_snapshots:
        fig2.savefig(output_file_without_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITHOUT snapshots saved to {output_file_without_snapshots}")
        plt.close(fig2)
    else:
        plt.show()

    # --- Legend (just for pre/post-treatment) ---
    handles, labels = ax2.get_legend_handles_labels()
    if handles:
        fig_legend, ax_legend = plt.subplots(figsize=(6, 2))
        ax_legend.axis("off")
        ax_legend.legend(
            handles, labels, fontsize=16, loc='center',
            ncol=len(labels), markerscale=3, handletextpad=0.5
        )
        legend_path = os.path.join(result_dir, "legend_only.png")
        fig_legend.savefig(legend_path, bbox_inches="tight", dpi=300)
        plt.close(fig_legend)
        print(f"Legend saved separately at: {legend_path}")



In [None]:
## Two pieces GPA (Stem cell data)
# Load the saved network for this experiment, integrate pca.transform(mats[0])
# forward with time_integration, then display the saved movies or generate them.

exp_name = 'EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# load PCA
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Fail fast: continuing would either hit a NameError or silently reuse a
    # pca_filename left over from a previously executed cell.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# step size: 1/200 of the final numerical time
dt = p['numerical_ts'][-1]/200
X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)

# dt rescaled by the ratio of the final entries of p['ts'] and p['numerical_ts']
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

img_src = result_dir + exp_name + '-movie-original.gif'
img_src1 = result_dir + exp_name + '-movie-original-with-arrows.gif'
if os.path.exists(img_src): # try loading saved movie
    display(Image(filename = img_src))
else:
    generate_animation([0,2,4], [1,3], X1_trpts, dt, physical_dt, img_src, d_red = d_red)
if os.path.exists(img_src1): # try loading saved movie
    display(Image(filename = img_src1))
else:
    # NOTE(review): the spelling "vecotorfield" is kept as-is; confirm it matches
    # the value generate_animation checks for.
    generate_animation([0,2,4], [1,3], X1_trpts, dt, physical_dt, img_src1, d_red = d_red, vs = "vecotorfield")



In [None]:
## Two pieces GPA (Synthetic data)
# Load the saved network for this experiment, integrate pca.transform(mats[0])
# forward with time_integration, then display the saved movies or generate them.

exp_name = 'f_Lip=5e-2-t_size=50-network=64_64_64_26d' #'times_10_particles_200_3'
d_red = 26
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# load PCA
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Fail fast: continuing would either hit a NameError or silently reuse a
    # pca_filename left over from a previously executed cell.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# step size: 1/200 of the final numerical time
dt = p['numerical_ts'][-1]/200
X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)

# dt rescaled by the ratio of the final entries of p['ts'] and p['numerical_ts']
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

img_src = result_dir + exp_name + '-movie-original.gif'
img_src1 = result_dir + exp_name + '-movie-original-with-arrows.gif'
if os.path.exists(img_src): # try loading saved movie
    display(Image(filename = img_src))
else:
    generate_animation([0,2,4], [1,3], X1_trpts, dt, physical_dt, img_src, d_red = d_red)
if os.path.exists(img_src1): # try loading saved movie
    display(Image(filename = img_src1))
else:
    # NOTE(review): the spelling "vecotorfield" is kept as-is; confirm it matches
    # the value generate_animation checks for.
    generate_animation([0,2,4], [1,3], X1_trpts, dt, physical_dt, img_src1, d_red = d_red, vs = "vecotorfield")



In [None]:
## One piece of GPA (EMT data)
# Load the saved network for this experiment, integrate pca.transform(mats[0])
# forward with time_integration, then display the saved movies or generate them.

exp_name = '72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 8
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# load PCA
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Fail fast: continuing would either hit a NameError or silently reuse a
    # pca_filename left over from a previously executed cell.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# step size: 1/200 of the final numerical time
dt = p['numerical_ts'][-1]/200
X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)

# dt rescaled by the ratio of the final entries of p['ts'] and p['numerical_ts']
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

img_src = result_dir + exp_name + '-movie-original.gif'
img_src1 = result_dir + exp_name + '-movie-original-with-arrows.gif'
if os.path.exists(img_src): # try loading saved movie
    display(Image(filename = img_src))
else:
    generate_animation([0, 2, 4], [], X1_trpts, dt, physical_dt, img_src, d_red = d_red)
if os.path.exists(img_src1): # try loading saved movie
    display(Image(filename = img_src1))
else:
    # NOTE(review): the spelling "vecotorfield" is kept as-is; confirm it matches
    # the value generate_animation checks for.
    generate_animation([0, 2, 4], [], X1_trpts, dt, physical_dt, img_src1, d_red = d_red, vs = "vecotorfield")


In [None]:
## One piece of GPA (NDPR data)
# Load the saved network for this experiment, integrate pca.transform(mats[0])
# forward with time_integration, then display the saved movies or generate them.

exp_name = 'Palbo_NDPR_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# load PCA
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Fail fast: continuing would either hit a NameError or silently reuse a
    # pca_filename left over from a previously executed cell.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# step size: 1/200 of the final numerical time
dt = p['numerical_ts'][-1]/200
X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)

# dt rescaled by the ratio of the final entries of p['ts'] and p['numerical_ts']
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

img_src = result_dir + exp_name + '-movie-original.gif'
img_src1 = result_dir + exp_name + '-movie-original-with-arrows.gif'
if os.path.exists(img_src): # try loading saved movie
    display(Image(filename = img_src))
else:
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src, d_red = d_red)
if os.path.exists(img_src1): # try loading saved movie
    display(Image(filename = img_src1))
else:
    # NOTE(review): the spelling "vecotorfield" is kept as-is; confirm it matches
    # the value generate_animation checks for.
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src1, d_red = d_red, vs = "vecotorfield")

In [None]:
## One piece of GPA (PA3)
# Load the saved network for this experiment, integrate pca.transform(mats[0])
# forward with time_integration, then display the saved movies or generate them.

exp_name = 'Palbo_BMC_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# load PCA
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Fail fast: continuing would either hit a NameError or silently reuse a
    # pca_filename left over from a previously executed cell.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# step size: 1/200 of the final numerical time
dt = p['numerical_ts'][-1]/200
X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)

# dt rescaled by the ratio of the final entries of p['ts'] and p['numerical_ts']
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

img_src = result_dir + exp_name + '-movie-original.gif'
img_src1 = result_dir + exp_name + '-movie-original-with-arrows.gif'
if os.path.exists(img_src): # try loading saved movie
    display(Image(filename = img_src))
else:
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src, d_red = d_red)
if os.path.exists(img_src1): # try loading saved movie
    display(Image(filename = img_src1))
else:
    # NOTE(review): the spelling "vecotorfield" is kept as-is; confirm it matches
    # the value generate_animation checks for.
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src1, d_red = d_red, vs = "vecotorfield")

In [None]:
## One piece of GPA (Patient 862)
# Load the saved network for this experiment, integrate pca.transform(mats[0])
# forward with time_integration, then display the saved movies or generate them.

exp_name = 'Palbo_862_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# load PCA
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Fail fast: continuing would either hit a NameError or silently reuse a
    # pca_filename left over from a previously executed cell.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# step size: 1/200 of the final numerical time
dt = p['numerical_ts'][-1]/200
X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)

# dt rescaled by the ratio of the final entries of p['ts'] and p['numerical_ts']
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

img_src = result_dir + exp_name + '-movie-original.gif'
img_src1 = result_dir + exp_name + '-movie-original-with-arrows.gif'
if os.path.exists(img_src): # try loading saved movie
    display(Image(filename = img_src))
else:
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src, d_red = d_red)
if os.path.exists(img_src1): # try loading saved movie
    display(Image(filename = img_src1))
else:
    # NOTE(review): the spelling "vecotorfield" is kept as-is; confirm it matches
    # the value generate_animation checks for.
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src1, d_red = d_red, vs = "vecotorfield")

In [None]:
## One piece of GPA (Patient 887)
# Load the saved network for this experiment, integrate pca.transform(mats[0])
# forward with time_integration, then display the saved movies or generate them.

exp_name = 'Palbo_887_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# load PCA
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Fail fast: continuing would either hit a NameError or silently reuse a
    # pca_filename left over from a previously executed cell.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# step size: 1/200 of the final numerical time
dt = p['numerical_ts'][-1]/200
X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)

# dt rescaled by the ratio of the final entries of p['ts'] and p['numerical_ts']
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

img_src = result_dir + exp_name + '-movie-original.gif'
img_src1 = result_dir + exp_name + '-movie-original-with-arrows.gif'
if os.path.exists(img_src): # try loading saved movie
    display(Image(filename = img_src))
else:
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src, d_red = d_red)
if os.path.exists(img_src1): # try loading saved movie
    display(Image(filename = img_src1))
else:
    # NOTE(review): the spelling "vecotorfield" is kept as-is; confirm it matches
    # the value generate_animation checks for.
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src1, d_red = d_red, vs = "vecotorfield")

In [None]:

## Full trajectories on static plot (samples 1 - EMT data)
# Load the saved network, integrate pca.transform(mats[0]) forward with
# time_integration, and write the static trajectory figures.

exp_name = '72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 8
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# load PCA
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Fail fast: continuing would either hit a NameError or silently reuse a
    # pca_filename left over from a previously executed cell.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# step size: 1/200 of the final numerical time
dt = p['numerical_ts'][-1]/200
X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)

# dt rescaled by the ratio of the final entries of p['ts'] and p['numerical_ts']
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames
output_file_with_snapshots = f"{result_dir}{exp_name}_static_trajectory_with_snapshots_circle.png"
output_file_without_snapshots = f"{result_dir}{exp_name}_static_trajectory_without_snapshots.png"
# NOTE(review): defined but not passed to the plotting call below.
output_file_snapshots_only = f"{result_dir}{exp_name}_static_trajectory_snapshots_only.png"

# Generate both plots
generate_static_trajectory_plots_two_timepoints(
    days=[0, 4],
    intermediate_days=[2],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_with_snapshots=output_file_with_snapshots,
    output_file_without_snapshots=output_file_without_snapshots
)

In [None]:

## Full trajectories on static plot (Sample 5 - synthetic data)
# Load the saved network, integrate pca.transform(mats[0]) forward with
# time_integration, and write the static trajectory figures.

exp_name = 'f_Lip=5e-2-t_size=50-network=64_64_64_26d' #'times_10_particles_200_3'
d_red = 26
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# load PCA
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Fail fast: continuing would either hit a NameError or silently reuse a
    # pca_filename left over from a previously executed cell.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# step size: 1/200 of the final numerical time
dt = p['numerical_ts'][-1]/200
X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)

# dt rescaled by the ratio of the final entries of p['ts'] and p['numerical_ts']
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames
output_file_with_snapshots = f"{result_dir}{exp_name}_static_trajectory_with_snapshots_circle.png"
output_file_without_snapshots = f"{result_dir}{exp_name}_static_trajectory_without_snapshots.png"
# NOTE(review): defined but not passed to the plotting call below.
output_file_snapshots_only = f"{result_dir}{exp_name}_static_trajectory_snapshots_only.png"

# Generate both plots
generate_static_trajectory_plots_three_timepoints(
    days=[0, 2, 4],
    intermediate_days=[1, 3],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_with_snapshots=output_file_with_snapshots,
    output_file_without_snapshots=output_file_without_snapshots
)

In [None]:
## Full trajectories on static plot (samples 3 - stem cell data)
# Load the saved network, integrate pca.transform(mats[0]) forward with
# time_integration, and write the static trajectory figures.

exp_name = 'EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# load PCA
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Fail fast: continuing would either hit a NameError or silently reuse a
    # pca_filename left over from a previously executed cell.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# step size: 1/200 of the final numerical time
dt = p['numerical_ts'][-1]/200
X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)

# dt rescaled by the ratio of the final entries of p['ts'] and p['numerical_ts']
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames
output_file_with_snapshots = f"{result_dir}{exp_name}_static_trajectory_with_snapshots_circle.png"
output_file_without_snapshots = f"{result_dir}{exp_name}_static_trajectory_without_snapshots.png"
# NOTE(review): defined but not passed to the plotting call below.
output_file_snapshots_only = f"{result_dir}{exp_name}_static_trajectory_snapshots_only.png"

# Generate both plots
generate_static_trajectory_plots_three_timepoints(
    days=[0, 2, 4],
    intermediate_days=[1, 3],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_with_snapshots=output_file_with_snapshots,
    output_file_without_snapshots=output_file_without_snapshots
)

In [None]:
## Full trajectories on static plot (NDPR data)
# Load the saved network, integrate pca.transform(mats[0]) forward with
# time_integration, and write the static trajectory figures.

exp_name = 'Palbo_NDPR_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# load PCA
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Fail fast: continuing would either hit a NameError or silently reuse a
    # pca_filename left over from a previously executed cell.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# step size: 1/200 of the final numerical time
dt = p['numerical_ts'][-1]/200
X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)

# dt rescaled by the ratio of the final entries of p['ts'] and p['numerical_ts'];
# also read as a global by the plotting function called below
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames
output_file_with_snapshots = f"{result_dir}{exp_name}_static_trajectory_with_snapshots_circle.png"
output_file_without_snapshots = f"{result_dir}{exp_name}_static_trajectory_without_snapshots.png"
# NOTE(review): defined but not passed to the plotting call below.
output_file_snapshots_only = f"{result_dir}{exp_name}_static_trajectory_snapshots_only.png"

# Generate both plots
generate_static_trajectory_plots_two_timepoints_no_middle_legend(
    days=[0, 4],
    intermediate_days=[ ],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_with_snapshots=output_file_with_snapshots,
    output_file_without_snapshots=output_file_without_snapshots
)

In [None]:
## Full trajectories on static plot (PA3)
# Load the saved network, integrate pca.transform(mats[0]) forward with
# time_integration, and write the static trajectory figures.

exp_name = 'Palbo_BMC_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# load PCA
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Fail fast: continuing would either hit a NameError or silently reuse a
    # pca_filename left over from a previously executed cell.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# step size: 1/200 of the final numerical time
dt = p['numerical_ts'][-1]/200
X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)

# dt rescaled by the ratio of the final entries of p['ts'] and p['numerical_ts'];
# also read as a global by the plotting function called below
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames
output_file_with_snapshots = f"{result_dir}{exp_name}_static_trajectory_with_snapshots_circle.png"
output_file_without_snapshots = f"{result_dir}{exp_name}_static_trajectory_without_snapshots.png"
# NOTE(review): defined but not passed to the plotting call below.
output_file_snapshots_only = f"{result_dir}{exp_name}_static_trajectory_snapshots_only.png"

# Generate both plots
generate_static_trajectory_plots_two_timepoints_no_middle_legend(
    days=[0, 4],
    intermediate_days=[ ],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_with_snapshots=output_file_with_snapshots,
    output_file_without_snapshots=output_file_without_snapshots
)

In [None]:
## Full trajectories on static plot (Patient 862)
# Load the saved network, integrate pca.transform(mats[0]) forward with
# time_integration, and write the static trajectory figures.

exp_name = 'Palbo_862_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# load PCA
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Fail fast: continuing would either hit a NameError or silently reuse a
    # pca_filename left over from a previously executed cell.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# step size: 1/200 of the final numerical time
dt = p['numerical_ts'][-1]/200
X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)

# dt rescaled by the ratio of the final entries of p['ts'] and p['numerical_ts'];
# also read as a global by the plotting function called below
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames
output_file_with_snapshots = f"{result_dir}{exp_name}_static_trajectory_with_snapshots_circle.png"
output_file_without_snapshots = f"{result_dir}{exp_name}_static_trajectory_without_snapshots.png"
# NOTE(review): defined but not passed to the plotting call below.
output_file_snapshots_only = f"{result_dir}{exp_name}_static_trajectory_snapshots_only.png"

# Generate both plots
generate_static_trajectory_plots_two_timepoints_no_middle_legend(
    days=[0, 4],
    intermediate_days=[ ],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_with_snapshots=output_file_with_snapshots,
    output_file_without_snapshots=output_file_without_snapshots
)

In [None]:
## Full trajectories on static plot (Patient 887)
# Load the saved network, integrate pca.transform(mats[0]) forward with
# time_integration, and write the static trajectory figures.

exp_name = 'Palbo_887_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# load PCA
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Fail fast: continuing would either hit a NameError or silently reuse a
    # pca_filename left over from a previously executed cell.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# step size: 1/200 of the final numerical time
dt = p['numerical_ts'][-1]/200
X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)

# dt rescaled by the ratio of the final entries of p['ts'] and p['numerical_ts'];
# also read as a global by the plotting function called below
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames
output_file_with_snapshots = f"{result_dir}{exp_name}_static_trajectory_with_snapshots_circle.png"
output_file_without_snapshots = f"{result_dir}{exp_name}_static_trajectory_without_snapshots.png"
# NOTE(review): defined but not passed to the plotting call below.
output_file_snapshots_only = f"{result_dir}{exp_name}_static_trajectory_snapshots_only.png"

# Generate both plots
generate_static_trajectory_plots_two_timepoints_no_middle_legend(
    days=[0, 4],
    intermediate_days=[ ],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_with_snapshots=output_file_with_snapshots,
    output_file_without_snapshots=output_file_without_snapshots
)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import matplotlib.colors as mcolors

def generate_static_trajectory_plots_cell_types(days, intermediate_days, X1_trpts, mats, d_red=26, output_file_cell_type_source=None, output_file_cell_type_target=None, output_file_cell_type_legend=None):
    """
    Plot source and target samples in the first two PCA axes, colored by cell type.

    Produces three figures:
      1. source samples (``mats[days[0]]``) colored by cell type over the
         target samples in light gray;
      2. target samples (``mats[days[-1]]``) colored by cell type over the
         source samples in light gray;
      3. a standalone legend mapping colors to cell types.

    Relies on the notebook globals ``dim_red_method``, ``data_dir`` and
    ``cell_types_by_day`` (indexed with the same day keys as ``mats``).

    Parameters
    ----------
    days : sequence
        Keys into ``mats``; the first entry is the source, the last the target.
    intermediate_days, X1_trpts
        Unused; accepted so the signature matches the sibling plot functions.
    mats : mapping or sequence
        Data indexed by day, passed to ``pca.transform``.
    d_red : int
        Reduced dimension, used only to pick the PCA pickle file name.
    output_file_cell_type_source, output_file_cell_type_target : str or None
        Where to save figures 1 and 2; when falsy the figure is shown instead.
    output_file_cell_type_legend : str or None
        Where to save the legend figure; when falsy it is neither saved nor closed.

    Returns
    -------
    None. Returns early (after printing a message) when ``dim_red_method``
    is not one of ``'EMT_PCA'`` / ``'PCA'``.
    """
    # Local imports: neither name is imported in the cell defining this function.
    import matplotlib.lines as mlines
    import seaborn as sns

    # Load PCA
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        print("PCA mapping for the reduction method and dimension is not available")
        return

    # NOTE(review): pickle executes arbitrary code on load; only use trusted files.
    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    # Only the endpoints are used, so do not require an intermediate entry.
    source_t, target_t = days[0], days[-1]

    X1_vis = pca.transform(mats[source_t])
    X2_vis = pca.transform(mats[target_t])
    cell_types_X1 = cell_types_by_day[source_t]
    cell_types_X2 = cell_types_by_day[target_t]
    unique_cell_types = np.unique(np.concatenate([cell_types_X1, cell_types_X2]))
    cell_type_palette = dict(zip(unique_cell_types, sns.color_palette("tab20", len(unique_cell_types))))

    def _scatter_by_cell_type(colored_xy, colored_types, background_xy, title, out_path):
        """Scatter one group colored by cell type over the other group in gray."""
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.scatter(background_xy[:, 0], background_xy[:, 1], color='lightgray', alpha=0.5, s=8)
        for cell_type in unique_cell_types:
            idx = colored_types == cell_type
            ax.scatter(colored_xy[idx, 0], colored_xy[idx, 1],
                       color=cell_type_palette[cell_type], s=8, alpha=1.0)
        ax.set_xlabel("PC 1", fontsize=20)
        ax.set_ylabel("PC 2", fontsize=20)
        ax.tick_params(axis='both', which='major', labelsize=18)
        ax.set_title(title, fontsize=18)
        fig.tight_layout()
        if out_path:
            fig.savefig(out_path, dpi=300, bbox_inches='tight')
            plt.close(fig)
        else:
            plt.show()

    # Plot 1: X1 colored, X2 gray (no legend)
    _scatter_by_cell_type(X1_vis, cell_types_X1, X2_vis,
                          "Untreated Samples colored by Cell Type",
                          output_file_cell_type_source)

    # Plot 2: X2 colored, X1 gray (no legend)
    _scatter_by_cell_type(X2_vis, cell_types_X2, X1_vis,
                          "Treated Samples colored by Cell Type",
                          output_file_cell_type_target)

    # Use circle markers instead of patches for legend
    legend_elements = [
        mlines.Line2D(
            [], [], marker='o', color='w',
            markerfacecolor=cell_type_palette[cell_type],
            markersize=8, label=cell_type
        )
        for cell_type in unique_cell_types
    ]

    # Create figure and axis (just for the legend); size is adjusted below
    fig_leg, ax_leg = plt.subplots(figsize=(8, 6))
    ax_leg.axis('off')

    # Add legend to axis (not directly to plt)
    legend = ax_leg.legend(
        handles=legend_elements,
        loc='center',
        frameon=True,
        fontsize=14,
        ncol=1,
        title='Cell Types',
        title_fontsize=14,
        borderpad=1
    )

    # Resize the figure to tightly fit the legend (a draw is needed first so
    # the legend has a window extent to measure)
    fig_leg.canvas.draw()
    bbox = legend.get_window_extent().transformed(fig_leg.dpi_scale_trans.inverted())
    fig_leg.set_size_inches(bbox.width + 0.5, bbox.height + 0.5)  # Add a little padding

    # Save only, no display
    if output_file_cell_type_legend:
        fig_leg.savefig(output_file_cell_type_legend, dpi=300, bbox_inches='tight')
        plt.close(fig_leg)
    


    






In [None]:
## Cell types plots
# Load the saved network, integrate pca.transform(mats[0]) forward with
# time_integration, and write the cell-type scatter plots and legend.

exp_name = 'Palbo_BMC_nofibroblast_malignant_dim20-f_Lip=5e-3-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 20
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# load PCA
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Fail fast: continuing would either hit a NameError or silently reuse a
    # pca_filename left over from a previously executed cell.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# step size: 1/200 of the final numerical time
dt = p['numerical_ts'][-1]/200
X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)

# dt rescaled by the ratio of the final entries of p['ts'] and p['numerical_ts']
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames
output_file_cell_type_source = f"{result_dir}{exp_name}_cell_type_source.png"
output_file_cell_type_target = f"{result_dir}{exp_name}_cell_type_target.png"
output_file_cell_type_legend = f"{result_dir}{exp_name}_cell_type_legend.png"


# Generate both plots
generate_static_trajectory_plots_cell_types(
    days=[0, 4],
    intermediate_days=[ ],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_cell_type_source=output_file_cell_type_source,
    output_file_cell_type_target=output_file_cell_type_target,
    output_file_cell_type_legend=output_file_cell_type_legend
)

## Downstream Analysis

In [None]:
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.decomposition import PCA

from sklearn.preprocessing import MinMaxScaler
import seaborn as sns
from scipy import stats
from matplotlib import animation

In [None]:
## For clinical data to identify cells with most change
import numpy as np
import matplotlib.pyplot as plt
import pickle as pk
import os
from scipy.stats import entropy

def plot_X1_hat_displacement_distribution(d_red=2, random_state=42, exp_memo='2'):
    """
    Summarize how far each particle moves between the first and last snapshot.

    Loads the network saved as ``result_dir + exp_memo + ".pickle"``, integrates
    ``pca.transform(mats[0])`` with ``time_integration``, and computes the
    Euclidean distance between each row of the first and last snapshot. It then
    writes three files into the global ``output_dir``:

      * ``<exp_memo>_X1_hat_displacement_stats.csv`` - mean, std, IQR,
        coefficient of variation and histogram entropy;
      * ``<exp_memo>_X1_hat_displacement_distribution.png`` - 40-bin histogram;
      * ``<exp_memo>_X1_hat_displacement_histogram.csv`` - bin centers and counts.

    Relies on the notebook globals ``result_dir``, ``data_dir``, ``output_dir``,
    ``dim_red_method`` and ``mats``.

    Parameters
    ----------
    d_red : int
        Reduced dimension, used only to pick the PCA pickle file name.
    random_state : int
        Unused; accepted for interface compatibility.
    exp_memo : str
        Experiment name; prefix of the input pickle and of every output file.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``dim_red_method`` is not ``'EMT_PCA'`` or ``'PCA'``.
    """
    # Local import: pandas is only imported in a later cell of the notebook.
    import pandas as pd

    filename = result_dir + exp_memo + ".pickle"
    W, b, p = load_W(filename)

    # Load PCA
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        raise ValueError("PCA mapping for the reduction method and dimension is not available")

    # NOTE(review): pickle executes arbitrary code on load; only use trusted files.
    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    # Compute trajectory
    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    # Extract first and last time point
    X1_hat_first = X1_trpts[0].astype(np.float32)
    X1_hat_last = X1_trpts[-1].astype(np.float32)

    # Compute Euclidean distances for each cell
    displacements = np.linalg.norm(X1_hat_last - X1_hat_first, axis=1)

    # Compute summary statistics
    mean_disp = np.mean(displacements)
    std_disp = np.std(displacements)
    iqr_disp = np.percentile(displacements, 75) - np.percentile(displacements, 25)
    cv_disp = std_disp / mean_disp if mean_disp > 0 else np.nan

    # One histogram serves both the entropy estimate and the CSV export.
    hist_counts, bin_edges = np.histogram(displacements, bins=40)
    hist_probs = hist_counts / hist_counts.sum()
    entropy_disp = entropy(hist_probs)

    # Make sure the output directory exists before writing anything into it.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save stats to CSV
    stats_df = pd.DataFrame([{
        "exp_memo": exp_memo,
        "mean_displacement": mean_disp,
        "std_displacement": std_disp,
        "iqr_displacement": iqr_disp,
        "cv_displacement": cv_disp,
        "entropy": entropy_disp
    }])

    stats_output_path = f"{output_dir}{exp_memo}_X1_hat_displacement_stats.csv"
    stats_df.to_csv(stats_output_path, index=False)

    # Plot the distribution
    fig = plt.figure(figsize=(10, 6))
    plt.hist(displacements, bins=40, color="skyblue", edgecolor="black")
    plt.xlabel("Displacement (Euclidean Distance)", fontsize=16)
    plt.ylabel("Number of Cells", fontsize=16)
    plt.title("Distribution of Cell Displacements\nX1_hat First vs Last Time Point", fontsize=18)
    plt.grid(alpha=0.3)
    plt.tight_layout()

    # Save the plot
    output_path = f"{output_dir}{exp_memo}_X1_hat_displacement_distribution.png"
    fig.savefig(output_path, dpi=300)
    plt.close(fig)

    print(f"📊 Displacement plot saved at: {output_path}")
    print(f"📄 Stats CSV saved at: {stats_output_path}")

    # Save histogram data
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    hist_df = pd.DataFrame({
        "bin_center": bin_centers,
        "count": hist_counts
    })

    hist_output_path = f"{output_dir}{exp_memo}_X1_hat_displacement_histogram.csv"
    hist_df.to_csv(hist_output_path, index=False)



In [None]:
## finding the distribution plot of the cell transitions distances
# Run the displacement analysis for the NDPR experiment; the function writes
# its statistics, histogram CSV and plot to disk.

exp_memo = 'Palbo_NDPR_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64'
source_t = 0
target_t = 4
X1_hat_labels = plot_X1_hat_displacement_distribution(
    d_red=2,
    random_state=40,
    exp_memo=exp_memo,
)

In [None]:
## Comparison of distributions across datasets (Histogram)

import pandas as pd
import matplotlib.pyplot as plt
import os

# Directory holding the per-dataset histogram CSVs; the figure is written here too
output_dir = result_dir + 'output/'


# Histogram CSVs (columns "bin_center" and "count"), one per dataset.
# The dict order is the top-to-bottom order of the stacked panels.
histogram_files = {
    "In vitro": "Palbo_NDPR_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64_X1_hat_displacement_histogram.csv",
    "862": "Palbo_862_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64_X1_hat_displacement_histogram.csv",
    "887": "Palbo_887_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64_X1_hat_displacement_histogram.csv",
    "PA3": "Palbo_BMC_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64_X1_hat_displacement_histogram.csv",
}

# Bar colour per dataset (currently gray for every dataset)
colors = {
    "In vitro": "gray",
    "PA3": "gray",
    "862": "gray",
    "887": "gray",
}

# Reference cut-offs per dataset. They are NOT drawn in this figure (hence the
# "without_vlines" output name); the dict is kept so the values stay available.
vlines_dict = {
    "862": (1.3, 2.0),
    "887": (0.8, 1.2),
    "PA3": (1.2, 2.4),
    "In vitro": (3.2, 3.8),
}

# Load histogram data, skipping (with a warning) any file that is missing
hist_data = {}
for label, file in histogram_files.items():
    fpath = os.path.join(output_dir, file)
    if os.path.exists(fpath):
        hist_data[label] = pd.read_csv(fpath)
    else:
        print(f"[Warning] File not found: {fpath}")

# plt.subplots rejects nrows=0, so fail early with a message that names the cause
n = len(hist_data)
if n == 0:
    raise FileNotFoundError(f"No displacement histogram CSV found in {output_dir}")

# squeeze=False always returns a 2-D array of axes, so n == 1 needs no special case
fig, axes = plt.subplots(nrows=n, figsize=(8, 4 * n), sharex=True,
                         constrained_layout=True, squeeze=False)
axes = axes[:, 0]

for ax, (label, df) in zip(axes, hist_data.items()):
    color = colors.get(label, "gray")
    bin_centers = df["bin_center"]
    # Positional access: the spacing of the first two centres is the bar width
    bin_width = bin_centers.iloc[1] - bin_centers.iloc[0]

    ax.bar(bin_centers, df["count"], width=bin_width,
           color=color, edgecolor="black", alpha=0.8)
    ax.set_ylabel("Count", fontsize=20)
    ax.grid(alpha=0.3)
    ax.tick_params(axis="y", labelsize=20)
    ax.set_xlim(left=0)

axes[-1].set_xlabel("Displacement (Euclidean Distance)", fontsize=20)

# With sharex=True only the bottom panel is labelled by default; show x tick
# labels on every panel
for ax in axes:
    ax.tick_params(axis="x", labelsize=20, which='both', labelbottom=True)

# Save figure
output_path = os.path.join(output_dir, "side_by_side_displacement_histograms_without_vlines.pdf")
fig.savefig(output_path, dpi=300)
plt.show()

print(f"[✓] Plot saved to: {output_path}")


In [None]:
## Comparison of distributions across datasets (KDE)


import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
from scipy.signal import find_peaks
import os


#result_dir = '%s/assets/Transport_genes/' % main_dir
#output_dir = result_dir + 'output/'

# === Files ===
# Histogram CSVs (columns "bin_center" and "count"), one per dataset; `output_dir`
# must already be defined by an earlier cell.
histogram_files = {
    "In vitro": "Palbo_NDPR_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64_X1_hat_displacement_histogram.csv",
    "PA3": "Palbo_BMC_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64_X1_hat_displacement_histogram.csv",
    "862": "Palbo_862_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64_X1_hat_displacement_histogram.csv",
    "887": "Palbo_887_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64_X1_hat_displacement_histogram.csv",
}

colors = {key: "gray" for key in histogram_files}
hist_data_kde = {}

# === Load and smooth ===
# The raw displacements are not stored, so each bin centre is repeated `count`
# times and a Gaussian KDE is fitted to that reconstructed sample.
y_max_global = 0
for label, file in histogram_files.items():
    fpath = os.path.join(output_dir, file)
    if not os.path.exists(fpath):
        print(f"[Warning] File not found: {fpath}")
        continue
    df = pd.read_csv(fpath)
    data = np.repeat(df["bin_center"].values, df["count"].astype(int))
    kde = gaussian_kde(data, bw_method='scott')
    x_eval = np.linspace(data.min(), data.max(), 1000)
    density = kde(x_eval)
    hist_data_kde[label] = (x_eval, density)
    y_max_global = max(y_max_global, density.max())

# === Plot ===
# plt.subplots rejects nrows=0, so fail early with a message that names the cause
n = len(hist_data_kde)
if n == 0:
    raise FileNotFoundError(f"No displacement histogram CSV found in {output_dir}")

# squeeze=False always returns a 2-D array of axes, so n == 1 needs no special case
fig, axes = plt.subplots(nrows=n, figsize=(8, 4 * n), sharex=True,
                         constrained_layout=True, squeeze=False)
axes = axes[:, 0]

for ax, (label, (x_eval, density)) in zip(axes, hist_data_kde.items()):
    color = colors.get(label, "gray")
    ax.plot(x_eval, density, color=color, lw=2)
    ax.fill_between(x_eval, density, alpha=0.3, color=color)

    # === Find valleys (local minima) ===
    # Peaks of the negated density are the local minima of the density
    valleys, _ = find_peaks(-density)

    # find_peaks returns ascending indices and x_eval is increasing, so the
    # first two entries are the two minima with the smallest x-value
    first_two = valleys[:2]
    x_vals_top2 = x_eval[first_two]
    y_vals_top2 = density[first_two]

    # === Plot only local minima ===
    # Marker-only style with a single colour definition (the former "go" format
    # string clashed with color='blue' and triggered a matplotlib warning)
    ax.plot(x_vals_top2, y_vals_top2, marker="o", linestyle="None",
            label="Local Minima", color='blue', markersize=10)
    for xv in x_vals_top2:
        ax.axvline(x=xv, color="black", linestyle="--", linewidth=1.5)

    # Print x-axis values of local minima
    print(f"{label} local minima x-values: {np.round(x_vals_top2, 3)}")

    ax.set_ylabel("Density", fontsize=23)
    ax.grid(alpha=0.3)
    ax.tick_params(axis="y", labelsize=23)
    ax.set_xlim(left=0)
    ax.set_ylim(top=y_max_global * 1.05)  # Ensure consistent y-limits across plots
    ax.set_title("", fontsize=23)

axes[-1].set_xlabel("Displacement (Euclidean Distance)", fontsize=24)
# With sharex=True only the bottom panel is labelled by default; show x tick
# labels on every panel
for ax in axes:
    ax.tick_params(axis="x", labelsize=22, which='both', labelbottom=True)

# === Save KDE figure ===
output_path = os.path.join(output_dir, "KDE_displacement_local_minima.pdf")
fig.savefig(output_path, dpi=300)
plt.show()
print(f"[✓] KDE plot saved to: {output_path}")

# === Save legend separately ===
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

legend_elements = [
    Line2D([0], [0], marker='o', color='blue', linestyle='None', label='Local Minimum', markersize=5),
]

fig_legend, ax_legend = plt.subplots(figsize=(4.5, 2))
ax_legend.axis("off")
legend = ax_legend.legend(handles=legend_elements, loc="center", fontsize=14)
legend_path = os.path.join(output_dir, "KDE_legend_only.pdf")
fig_legend.savefig(legend_path, bbox_inches='tight')
print(f"[✓] Legend saved to: {legend_path}")

In [None]:


import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as mpatches
from matplotlib import colormaps


def _draw_subgroup_trajectory(ax, X1_trpts, idx, cmap, color_range, start_i, index,
                              scatter_alpha, scatter_zorder, line_zorder):
    """Draw the selected cells of every plotted snapshot on ``ax``.

    Snapshots ``start_i, start_i + index, ...`` of ``X1_trpts`` are scattered
    (first two columns) for the rows selected by the boolean mask ``idx``. The
    colour moves along ``cmap`` from ``color_range[0]`` (first plotted snapshot)
    to ``color_range[-1]`` (last plotted snapshot). Consecutive plotted
    snapshots are joined by one line segment per cell. Snapshots containing NaN
    are skipped.
    """
    steps = range(start_i, len(X1_trpts), index)
    # Count exactly the steps that are drawn, so the colour index never exceeds
    # the colour range, even when start_i is not a multiple of index.
    last_step = max(len(steps) - 1, 1)
    n_colors = len(color_range)

    for step_idx, i in enumerate(steps):
        snapshot = X1_trpts[i]
        if np.isnan(snapshot).any():
            continue
        norm_val = step_idx / last_step
        color = cmap(color_range[int(norm_val * (n_colors - 1))])
        ax.scatter(snapshot[idx, 0], snapshot[idx, 1],
                   color=color, alpha=scatter_alpha, s=3, zorder=scatter_zorder)

        if i > start_i:
            previous = X1_trpts[i - index]
            # 2 x n_cells arrays: matplotlib draws one segment per column (cell)
            ax.plot([previous[idx, 0], snapshot[idx, 0]],
                    [previous[idx, 1], snapshot[idx, 1]],
                    color=color, alpha=0.6, linewidth=1.2, zorder=line_zorder)


def generate_static_cluster_plot_deviation_colormap(
    source_t, target_t, optimal_k, start_i, index, reverse=False, intermediate_t=(1, 2, 3),
    d_red=2, random_state=42, exp_memo='experiment', thresholds=(0.699, 1.149)
):
    """
    Plot the transported trajectory coloured by displacement subgroup.

    The source snapshot ``mats[source_t]`` is projected with the saved PCA and
    integrated with the vector field stored in ``{result_dir}{exp_memo}.pickle``.
    Each cell is labelled 'low', 'medium' or 'high' from the Euclidean distance
    between its first and last trajectory point. Three kinds of PNG files are
    written next to the pickle: the combined plot, a colour-bar legend, and one
    plot per label present in the data.

    Relies on notebook globals: ``result_dir``, ``data_dir``, ``dim_red_method``,
    ``mats``, ``pk``, ``tf``, ``load_W`` and ``v``.

    Parameters
    ----------
    source_t, target_t : keys/indices into ``mats`` for the source and target snapshots.
    optimal_k, reverse, random_state : accepted for call compatibility; not used.
    start_i : index of the first trajectory snapshot that is drawn.
    index : stride between drawn trajectory snapshots.
    intermediate_t : snapshots of ``mats`` drawn in light gray on the combined plot.
    d_red : PCA dimension, used to select the PCA pickle file name.
    exp_memo : experiment tag used for the weight file and all output names.
    thresholds : ``(low_cut, high_cut)`` displacement cut-offs. A displacement
        ``<= low_cut`` is 'low', ``(low_cut, high_cut]`` is 'medium' and
        ``> high_cut`` is 'high'. The default is the pair that was previously
        hard-coded in this function.

    Returns
    -------
    numpy.ndarray of object
        The 'low' / 'medium' / 'high' label of every source cell.

    Raises
    ------
    ValueError
        If ``dim_red_method`` is neither 'EMT_PCA' nor 'PCA'.
    """
    from matplotlib.colors import ListedColormap

    # Load model parameters
    filename = f"{result_dir}{exp_memo}.pickle"
    W, b, p = load_W(filename)

    # Load PCA
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        raise ValueError("PCA mapping for the reduction method and dimension is not available")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    X1_vis = pca.transform(mats[source_t])
    X2_vis = pca.transform(mats[target_t])

    # Forward-Euler integration with the weights loaded above. The shared
    # time_integration helper reads the module-level W and b, which would
    # ignore the weights of this experiment, so the loop is done here.
    T = p['numerical_ts'][-1]
    dt = T / 200
    x = tf.constant(X1_vis, dtype=tf.float32)
    X1_trpts = [X1_vis]
    for step in range(int(T / dt)):
        x = x + dt * v(x, dt * step, W, b)
        X1_trpts.append(x.numpy())

    # Displacement of each cell between the first and the last trajectory point
    X1_hat_last = X1_trpts[-1].astype(np.float32)
    X1_hat_first = X1_trpts[0].astype(np.float32)
    displacements = np.linalg.norm(X1_hat_last - X1_hat_first, axis=1)

    # Cut-offs recorded by the authors for the individual datasets:
    #   ER genes: 862: 3 / 5, 887: 0.8 / 1.3, BMC: 1.0 / 3.0, Rinath: 2.89 / 2.90
    #   R genes : 862: 1.3 / 2, 887: 0.8 / 1.2, BMC: 1.2 / 2.4, Rinath: 3.2 / 3.8
    #   R genes, KDE local minima:
    #     862 [1.288 2.365], 887 [0.699 1.149], BMC [1.145 1.937], In vitro [3.017 3.853]
    low_cut, high_cut = thresholds

    # Assign labels based on displacement
    X1_hat_labels = np.full(displacements.shape, 'low', dtype=object)
    X1_hat_labels[(displacements > low_cut) & (displacements <= high_cut)] = 'medium'
    X1_hat_labels[displacements > high_cut] = 'high'
    present_labels = np.unique(X1_hat_labels)

    # Define colormaps per subgroup (avoid lightest tones by clipping range)
    label_to_cmap = {
        'low': colormaps['Oranges'],
        'medium': colormaps['Purples'],
        'high': colormaps['Greens']
    }
    cmap_clip = slice(75, 256)
    color_range = np.linspace(0, 1, 256)[cmap_clip]

    # --- Combined plot: all subgroups on one axes ---
    output_file = f"{result_dir}{exp_memo}_static_celltypes_plot_deviation_colormap.png"
    fig, ax = plt.subplots(figsize=(8, 6))

    for label in present_labels:
        _draw_subgroup_trajectory(
            ax, X1_trpts, X1_hat_labels == label, label_to_cmap[label], color_range,
            start_i, index, scatter_alpha=0.9, scatter_zorder=1, line_zorder=0)

    for t in intermediate_t:
        X_intermediate_vis = pca.transform(mats[t])
        ax.scatter(X_intermediate_vis[:, 0], X_intermediate_vis[:, 1],
                   color='lightgray', alpha=0.7, s=10, zorder=10)

    ax.set_xlabel("PC 1", fontsize=32)
    ax.set_ylabel("PC 2", fontsize=32)
    ax.tick_params(axis='both', which='major', labelsize=32)
    ax.set_title("")

    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Static cluster plot saved to {output_file}")

    # --- Save colormap legend with thinner bars side by side ---
    fig_legend, axs = plt.subplots(1, len(label_to_cmap), figsize=(20, 1.5), squeeze=False)

    for legend_ax, (label, base_cmap) in zip(axs[0], label_to_cmap.items()):
        # Colormap restricted to the clipped range used in the plots
        custom_cmap = ListedColormap(base_cmap(color_range))
        gradient = np.linspace(0, 1, len(color_range)).reshape(1, -1)
        legend_ax.imshow(gradient, aspect='auto', cmap=custom_cmap, extent=[0, 1, 0, 0.03])
        legend_ax.set_title(f"{label.capitalize()} phenotypic shift \nPre-treatment → Post-treatment", fontsize=20)
        legend_ax.axis('off')

    legend_path = output_file.replace('.png', '_legend.png')
    fig_legend.tight_layout()
    fig_legend.savefig(legend_path, dpi=300, bbox_inches='tight')
    plt.close(fig_legend)
    print(f"Legend saved to {legend_path}")

    # --- Additional per-label plots ---
    for current_label in present_labels:
        fig_lbl, ax_lbl = plt.subplots(figsize=(8, 6))
        # line_zorder=2 is matplotlib's default for lines
        _draw_subgroup_trajectory(
            ax_lbl, X1_trpts, X1_hat_labels == current_label, label_to_cmap[current_label],
            color_range, start_i, index, scatter_alpha=0.8, scatter_zorder=2, line_zorder=2)

        # Observed source and target snapshots as gray background
        ax_lbl.scatter(X1_vis[:, 0], X1_vis[:, 1], color='lightgray', alpha=0.7, s=10, zorder=1)
        ax_lbl.scatter(X2_vis[:, 0], X2_vis[:, 1], color='lightgray', alpha=0.7, s=10, zorder=1)

        ax_lbl.set_xlabel("PC 1", fontsize=32)
        ax_lbl.set_ylabel("PC 2", fontsize=32)
        ax_lbl.tick_params(axis='both', labelsize=32)
        ax_lbl.set_title("", fontsize=32)
        fig_lbl.tight_layout()
        fig_lbl.savefig(f"{output_file.replace('.png', f'_label_{current_label}.png')}", dpi=300, bbox_inches='tight')
        plt.close(fig_lbl)

    return X1_hat_labels



In [None]:
## Clinical data (deviated cells vs non-deviated cells)

# Run configuration for this dataset
source_t = 0
target_t = 4
optimal_k = 2
start_i = 0
index = 1
exp_memo = 'Palbo_887_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64'

# Draw the trajectory plots and collect one displacement label per source cell
X1_hat_labels = generate_static_cluster_plot_deviation_colormap(
    source_t,
    target_t,
    optimal_k,
    start_i,
    index,
    reverse=False,
    intermediate_t=[],
    d_red=2,
    random_state=42,
    exp_memo=exp_memo,
)

# One row per cell: its running position and its displacement label
label_columns = {
    "Cell_Index": np.arange(len(X1_hat_labels)),
    "Cluster_Label": X1_hat_labels,
}
df_clusters = pd.DataFrame(label_columns)

# Write the table next to the other results of this experiment
cluster_save_path = f"{result_dir}{exp_memo}_X1_hat_deviation.csv"
df_clusters.to_csv(cluster_save_path, index=False)
print(f"Cluster labels saved to {cluster_save_path}")

# Leave the path as the cell output; later cells read the file back from it
cluster_save_path


In [None]:
## Collect the cell IDs belonging to each time point

# Same table, indexed by the "id" column
df_cls_indexed = df_cls.set_index("id")

# Maps every entry of `cls` to the list of cell IDs whose "day" equals that entry
cell_ids_by_day = {}
for c in cls:
    idx = df_cls['day'].eq(c)
    ids_for_day = df_cls_indexed.index[idx]
    cell_ids_by_day[c] = ids_for_day.tolist()


In [None]:
# New code without cell-type annotation

## Attach the real cell IDs to the saved labels (this example uses time point 0)

# Step 1: read the label table written by the previous cell
df_loaded = pd.read_csv(cluster_save_path)

# Step 2: the labels and the day-0 IDs must line up one-to-one
day0_cell_ids = cell_ids_by_day[0]
if len(df_loaded) != len(day0_cell_ids):
    raise ValueError("❌ Length mismatch: Cannot assign new cell IDs.")

# Step 3: add the IDs, move them to the front, and drop the positional index
df_loaded["Cell_ID"] = day0_cell_ids
remaining_columns = [col for col in df_loaded.columns if col != "Cell_ID"]
df_loaded = df_loaded[["Cell_ID"] + remaining_columns]
df_loaded = df_loaded.drop(columns=["Cell_Index"])

# Step 4: write the result under a new file name
new_path = cluster_save_path.replace(".csv", "_with_cell_ids.csv")
df_loaded.to_csv(new_path, index=False)
print(f"✅ Updated file saved to: {new_path}")

In [None]:
## Old code
# Step 1: read the label table back from disk
df_loaded = pd.read_csv(cluster_save_path)

# Step 2: refuse to continue unless there is exactly one label per day-0 cell
ids_at_day0 = cell_ids_by_day[0]
if len(df_loaded) != len(ids_at_day0):
    raise ValueError("❌ Length mismatch: Cannot assign new cell IDs.")

# Step 3: store the IDs as the first column and remove the positional index
df_loaded["Cell_ID"] = ids_at_day0
other_columns = [col for col in df_loaded.columns if col != "Cell_ID"]
df_loaded = df_loaded[["Cell_ID"] + other_columns]
df_loaded = df_loaded.drop(columns=["Cell_Index"])

# Step 4: save under a new name so the original CSV is left in place
new_path = cluster_save_path.replace(".csv", "_with_cell_ids.csv")
df_loaded.to_csv(new_path, index=False)
print(f"✅ Updated file saved to: {new_path}")


## Downstream analysis for single-gene dynamics

In [None]:
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.decomposition import PCA

from sklearn.preprocessing import MinMaxScaler
import seaborn as sns
from scipy import stats
from matplotlib import animation

In [None]:
def load_W(filename):
    """Load network parameters from a pickle file.

    The file must unpickle to a 3-item sequence ``(W, b, p)``. Every entry of
    ``W`` and of ``b`` is wrapped in a float32 ``tf.Variable``; ``p`` is
    returned unchanged.

    Returns the tuple ``(W, b, p)`` with ``W`` and ``b`` as lists of variables.
    """
    def as_variables(arrays):
        return [tf.Variable(array,  dtype=tf.float32) for array in arrays]

    with open(filename, "rb") as fr:
        raw_W, raw_b, p = pk.load(fr)

    return as_variables(raw_W), as_variables(raw_b), p

def v(x, t, W, b):
    """Evaluate the neural network for the time-dependent vector field.

    ``t`` is appended to every row of ``x`` as one extra column. All layers but
    the last apply ``tanh`` to ``h @ W[l] + b[l]``; the last layer is affine.

    Returns the output of the last layer.
    """
    time_column = t*tf.ones([x.shape[0], 1], dtype=tf.float32)
    h = tf.concat([x, time_column], axis=1)

    # Hidden layers: everything except the final weight matrix
    for layer, weight in enumerate(W[:-1]):
        h = tf.nn.tanh(tf.add(tf.matmul(h, weight), b[layer]))

    # Output layer without activation
    return tf.add(tf.matmul(h, W[-1]), b[-1])

def time_integration(x0, T, dt):
    """Integrate ``x' = v(x, t)`` from ``x0`` with explicit Euler steps.

    ``int(T/dt)`` steps of size ``dt`` are taken. The weights are NOT
    parameters: the names ``W`` and ``b`` are looked up at module level when
    the function runs.

    Returns a list whose first entry is ``x0`` itself, followed by one numpy
    array per step.
    """
    n_steps = int(T/dt)
    state = tf.constant(x0, dtype=tf.float32)
    trajectory = [x0]
    for step in range(n_steps):
        velocity = v(state, dt*step, W, b)
        state = state + dt * velocity
        trajectory.append(state.numpy())
    return trajectory


def _mean_and_ci95(values):
    """Return the mean of ``values`` and the half-width of its 95% t-interval."""
    half_width = stats.sem(values) * stats.t.ppf((1 + 0.95) / 2., len(values) - 1)
    return np.mean(values), half_width


def gene_dynamics_whole_saveonly(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                              intermediate_t = (1,), 
                              d_red=2, random_state=42, exp_memo = '2'):
    """Save an animation and two summary plots of one gene along the transported trajectory.

    The source snapshot ``mats[source_t]`` is projected with the saved PCA and
    integrated with the vector field stored in ``{result_dir}{exp_memo}.pickle``.
    Every ``index``-th snapshot up to ``max_i`` is mapped back with
    ``pca.inverse_transform`` and the column of ``gene_of_interest`` is read.
    Three files are written and nothing is returned:

    * a GIF of the snapshots (embedded with the global ``reducer``) coloured by expression;
    * the mean expression with a 95% interval over time, plus the observed days;
    * the same curve split by subgroup, where subgroups are the KMeans clusters
      (fitted on the target day) predicted for the last trajectory point.

    Relies on notebook globals: ``result_dir``, ``output_dir``, ``data_dir``,
    ``dim_red_method``, ``mats``, ``df_reduced_emt``, ``reducer``,
    ``vis_all_days``, ``pk``, ``tf``, ``load_W``, ``v``, ``KMeans``, ``stats``
    and ``animation``.

    Parameters
    ----------
    source_t, target_t : keys/indices into ``mats`` for the source and target days.
    optimal_k : number of KMeans clusters.
    gene_of_interest : column name in ``df_reduced_emt``.
    index : stride between the trajectory snapshots that are used.
    max_i : last trajectory snapshot index that is used (inclusive).
    intermediate_t : observed days shown between source and target; an empty
        sequence selects every day strictly between ``source_t`` and ``target_t``.
    d_red : PCA dimension, used to select the PCA pickle file name.
    random_state : accepted for call compatibility; not used (KMeans is seeded with 40).
    exp_memo : experiment tag used for the weight file and the GIF name.

    Raises
    ------
    ValueError
        If ``dim_red_method`` is neither 'EMT_PCA' nor 'PCA'.
    """
    filename = result_dir + exp_memo + ".pickle"
    W, b, p = load_W(filename)

    # load PCA
    if dim_red_method == 'EMT_PCA':
        pca_filename = "emt_pca_%d.pkl" % d_red
    elif dim_red_method == 'PCA':
        pca_filename = "pca_%d.pkl" % d_red
    else:
        # Previously only printed, then failed on the undefined pca_filename
        raise ValueError("PCA mapping for the reduction method and dimension is not available")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    day1, day2 = source_t, target_t

    # Forward-Euler integration with the weights loaded above. The shared
    # time_integration helper reads the module-level W and b, which would
    # ignore the weights of this experiment, so the loop is done here.
    T = p['numerical_ts'][-1]
    dt = T / 200
    X1_latent = pca.transform(mats[day1])
    x = tf.constant(X1_latent, dtype=tf.float32)
    X1_trpts = [X1_latent]
    for step in range(int(T / dt)):
        x = x + dt * v(x, dt * step, W, b)
        X1_trpts.append(x.numpy())

    # Physical time covered by one integration step (used for the frame titles)
    physical_dt = dt * p['ts'][-1] / T

    # An array in both cases, so .tolist() below always works
    intermediate_t = np.array(intermediate_t)
    if len(intermediate_t) == 0:
        intermediate_t = np.arange(source_t + 1, target_t)
    intermediate_days = intermediate_t.tolist()

    # Step 1: cluster the target day in PCA space and label the transported cells
    X2_latent = pca.transform(mats[day2])
    last_day_reduced = X2_latent.astype(np.float32)

    kmeans = KMeans(n_clusters=optimal_k, random_state=40)
    kmeans.fit(last_day_reduced)
    last_day_labels = kmeans.labels_

    # Subgroups are defined by the LAST trajectory point, independent of max_i
    X1_hat_labels = kmeans.predict(X1_trpts[-1].astype(np.float32))

    unique_labels = np.unique(last_day_labels)
    print(f"Number of unique labels in last_day_labels: {len(unique_labels)}")
    print(f"Unique labels: {unique_labels}")

    # Column of the gene in the reconstructed matrix.
    # NOTE(review): the "- 1" presumably skips one leading non-gene column of
    # df_reduced_emt - confirm.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    def expression_of(latent):
        """Expression of the gene after mapping latent points back with the PCA."""
        return pca.inverse_transform(latent)[:, gene_index]

    # Observed data is passed through PCA and back, like the trajectory
    gene_expression_X1 = expression_of(X1_latent)
    gene_expression_X2 = expression_of(X2_latent)
    gene_expression_intermediates = [expression_of(pca.transform(mats[t])) for t in intermediate_days]

    # Trajectory snapshots that are used: every index-th one, up to max_i
    frame_ids = [i for i in range(0, len(X1_trpts), index) if i <= max_i]
    frame_expressions = [expression_of(X1_trpts[i]) for i in frame_ids]

    def valid_frames():
        """Yield (snapshot index, expression values) until the first snapshot with NaN."""
        for i, values in zip(frame_ids, frame_expressions):
            if np.isnan(X1_trpts[i]).any():
                return
            yield i, values

    # Common colour scale over observed and transported values
    all_gene_expression_values = np.concatenate(
        [gene_expression_X1, *gene_expression_intermediates, gene_expression_X2, *frame_expressions])
    vmin = all_gene_expression_values.min()
    vmax = all_gene_expression_values.max()

    # Animation in the space of the global `reducer`
    img_src = f"{output_dir}{exp_memo}-movie-{d_red}D-{gene_of_interest}-dynamics-trajectory-source-only-nonormalized-mix-40000-034.gif"

    fig, ax = plt.subplots()
    ims = []
    for i, gene_expression_values in valid_frames():
        X1_hat = pca.inverse_transform(X1_trpts[i])
        X1_hat_vis = reducer.transform(X1_hat)

        im = ax.scatter(X1_hat_vis[:, 0], X1_hat_vis[:, 1], c=gene_expression_values, cmap='viridis', alpha=1.0, s=0.5, zorder=10, vmin=vmin, vmax=vmax)
        # Gray background, drawn once per frame as before; these artists are
        # not added to `ims`
        ax.scatter(vis_all_days[:, 0], vis_all_days[:, 1], color='lightgray', alpha=0.3, s=0.5, zorder=1)

        ttl = ax.text(0.5, 1.05, "t = %.3f" % (physical_dt*i), bbox={'facecolor': 'w', 'alpha': 0.5, 'pad': 5}, transform=ax.transAxes, ha="center")
        ims.append([im, ttl])

    ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=200)
    writergif = animation.PillowWriter(fps=3)
    ani.save(img_src, writer=writergif)
    plt.close(fig)

    # (1) Average expression over all cells at each used snapshot, with 95% intervals
    avg_gene_expressions = []
    ci_gene_expressions = []
    for _, gene_expression_values in valid_frames():
        mean, ci = _mean_and_ci95(gene_expression_values)
        avg_gene_expressions.append(mean)
        ci_gene_expressions.append(ci)

    # Observed days with their values, sorted TOGETHER so every mean stays
    # attached to its own day even if intermediate_t is not ascending
    observed = [(day1, gene_expression_X1)]
    observed.extend(zip(intermediate_days, gene_expression_intermediates))
    observed.append((day2, gene_expression_X2))
    observed.sort(key=lambda pair: pair[0])

    combined_indices = [day for day, _ in observed]
    print(combined_indices)

    observed_stats = [_mean_and_ci95(values) for _, values in observed]
    all_avg_expressions = [mean for mean, _ in observed_stats]
    all_ci_expressions = [ci for _, ci in observed_stats]

    # x-position (snapshot index) of every averaged trajectory point
    extended_indices = np.arange(len(avg_gene_expressions)) * index
    combined_indices = np.array(combined_indices)

    # Map the observed days linearly onto the snapshot-index axis
    rescaled_indices = np.interp(
        combined_indices,
        [combined_indices[0], combined_indices[-1]],
        [extended_indices[0], extended_indices[-1]]
    )

    output_file = f"{output_dir}/_average_gene_expression_{gene_of_interest}.png"

    plt.figure(figsize=(10, 6))
    plt.plot(
        extended_indices,
        avg_gene_expressions,
        label='Average Gene Expression',
        color='green',
        linestyle='-',
    )
    plt.fill_between(
        extended_indices,
        np.array(avg_gene_expressions) - np.array(ci_gene_expressions),
        np.array(avg_gene_expressions) + np.array(ci_gene_expressions),
        alpha=0.2,
        color='lightgreen',
        label='95% CI'
    )
    # Observed source, intermediate and target days
    plt.errorbar(
        rescaled_indices,
        all_avg_expressions,
        yerr=all_ci_expressions,
        fmt='o',
        color='blue',
        label='Discrete Points'
    )
    # Ticks sit at the rescaled positions but are labelled with the days
    plt.xticks(
        ticks=rescaled_indices,
        labels=combined_indices,
        rotation=0,
        fontsize=10
    )
    plt.xlabel('Time Point (Day)')
    plt.ylabel('Gene Expression')
    plt.title(f'Average {gene_of_interest} Expression Over Time')
    plt.legend()
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.close()

    # (2) The same averages, split by the subgroup label of each transported cell
    subgroup_labels = np.unique(X1_hat_labels)
    subgroup_avg_gene_expressions = {label: [] for label in subgroup_labels}
    subgroup_ci_gene_expressions = {label: [] for label in subgroup_labels}

    for _, gene_expression_values in valid_frames():
        for label in subgroup_labels:
            mean, ci = _mean_and_ci95(gene_expression_values[X1_hat_labels == label])
            subgroup_avg_gene_expressions[label].append(mean)
            subgroup_ci_gene_expressions[label].append(ci)

    subgroup_output_file = f"{output_dir}/_average_gene_expression_{gene_of_interest}_subgroups.png"

    plt.figure(figsize=(10, 6))
    for label in subgroup_labels:
        subgroup_mean = np.array(subgroup_avg_gene_expressions[label])
        subgroup_ci = np.array(subgroup_ci_gene_expressions[label])
        plt.plot(
            extended_indices,
            subgroup_avg_gene_expressions[label],
            label=f'Subgroup {label} Average',
            linestyle='-'
        )
        plt.fill_between(
            extended_indices,
            subgroup_mean - subgroup_ci,
            subgroup_mean + subgroup_ci,
            alpha=0.2,
            label=f'Subgroup {label} 95% CI'
        )

    plt.xticks(
        ticks=rescaled_indices,
        labels=combined_indices,
        rotation=0,
        fontsize=10
    )
    plt.xlabel('Time Point (Day)')
    plt.ylabel('Gene Expression')
    plt.title(f'Average {gene_of_interest} Expression Over Time by Subgroup')
    plt.legend()
    plt.savefig(subgroup_output_file, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Subgroup plot saved at: {subgroup_output_file}")

    


## Downstream Analysis of Average Gene Dynamics for each Gene with confidence intervals

In [None]:

def load_W(filename):
    """Load a pickled ``(weights, biases, params)`` triple from ``filename``.

    Every entry of the first two pickled sequences is wrapped in a float32
    ``tf.Variable``; the third object is returned exactly as unpickled.

    Returns
    -------
    tuple
        ``(weights, biases, params)`` in that order.
    """
    with open(filename, "rb") as handle:
        raw_weights, raw_biases, params = pk.load(handle)

    weights = [tf.Variable(matrix, dtype=tf.float32) for matrix in raw_weights]
    biases = [tf.Variable(vector, dtype=tf.float32) for vector in raw_biases]

    return weights, biases, params

def v(x, t, W, b):   # neural network for the time-dependent vector field
    """Evaluate the tanh network at states ``x`` and scalar time ``t``.

    ``t`` is expanded to one column (one row per row of ``x``) and appended
    to ``x``. All layers except the last apply ``tanh(h @ W[l] + b[l])``;
    the last layer is affine with no activation.
    """
    time_column = t * tf.ones([x.shape[0], 1], dtype=tf.float32)
    hidden = tf.concat([x, time_column], axis=1)

    # Hidden layers: every (W, b) pair but the last one.
    for layer in range(len(W) - 1):
        pre_activation = tf.add(tf.matmul(hidden, W[layer]), b[layer])
        hidden = tf.nn.tanh(pre_activation)

    # Output layer: affine map only.
    return tf.add(tf.matmul(hidden, W[-1]), b[-1])

def time_integration(x0, T, dt, W=None, b=None):
    """Integrate ``dx/dt = v(x, t)`` from ``x0`` with explicit Euler steps.

    Parameters
    ----------
    x0 : array-like
        Initial states; converted to a float32 tensor for the integration.
    T : float
        Final integration time.
    dt : float
        Step size. ``T / dt`` steps are taken (truncated to an integer).
    W, b : list, optional
        Network weights and biases forwarded to ``v``. When omitted, the
        module-level ``W`` / ``b`` are used, which is what this function
        always did before these parameters existed.

    Returns
    -------
    list
        ``x0`` itself (unconverted) followed by one NumPy array per step.
    """
    # Fall back to the notebook-level weights only when none are supplied, so
    # existing three-argument calls keep their behaviour.
    if W is None:
        W = globals()["W"]
    if b is None:
        b = globals()["b"]

    # T / dt can land a rounding error below the intended integer (e.g.
    # 199.99999999999997 for dt = T / 200); the tolerance keeps the last step.
    n_steps = int(T / dt + 1e-9)

    x = tf.constant(x0, dtype=tf.float32)
    xs = [x0]
    for i in range(n_steps):
        x = x + dt * v(x, dt * i, W, b)
        xs.append(x.numpy())
    return xs


def Average_gene_dynamics_whole_saveonly(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                                         intermediate_t=[1],
                                         d_red=2, random_state=42, exp_memo='2'):
    """Plot the population-average dynamics of one gene and save it as PNG.

    The particles ``pca.transform(mats[0])`` are transported with
    ``time_integration``. At every ``index``-th step (up to step ``max_i``,
    stopping at the first step containing NaN) the states are mapped back
    with ``pca.inverse_transform`` and the mean and 95% Student-t confidence
    interval of the selected gene column are computed. The same statistics
    of the PCA-reconstructed snapshots at the source, intermediate and target
    days are overlaid as error-bar markers.

    Two files are written: ``{output_dir}/_average_gene_expression_{gene}.png``
    and a matching ``..._legend.png``.

    Reads the module-level names ``result_dir``, ``data_dir``, ``output_dir``,
    ``dim_red_method``, ``mats`` and ``df_reduced_emt``.

    Parameters
    ----------
    source_t, target_t : int
        Keys into ``mats`` for the first and last snapshot.
    optimal_k : int
        Unused; kept for signature compatibility.
    gene_of_interest : str
        Column name looked up in ``df_reduced_emt``.
    index : int
        Stride over the integration steps.
    max_i : int
        Largest integration step included.
    intermediate_t : sequence of int
        Snapshot days between source and target. An empty sequence means
        every day strictly between ``source_t`` and ``target_t``. The default
        list is never mutated.
    d_red : int
        Latent dimension, used to select the pickled PCA file.
    random_state : int
        Unused; kept for signature compatibility.
    exp_memo : str
        Experiment name; ``result_dir + exp_memo + ".pickle"`` is loaded.

    Raises
    ------
    ValueError
        If ``dim_red_method`` has no PCA file, or if no finite integration
        step is available to plot.
    """
    # Imported locally so the function does not depend on which notebook
    # cells (and therefore which import statements) were executed before it.
    import matplotlib.lines as mlines
    import matplotlib.patches as mpatches

    def _ci95(values):
        """Half-width of the two-sided 95% Student-t interval of the mean."""
        return stats.sem(values) * stats.t.ppf((1 + 0.95) / 2., len(values) - 1)

    def _reconstructed_expression(day):
        """Selected gene column of ``mats[day]`` after a PCA round trip."""
        latent = pca.transform(mats[day])
        return pca.inverse_transform(latent)[:, gene_index]

    # NOTE(review): time_integration() below is called without weights, so it
    # reads the module-level W and b rather than the ones unpickled here;
    # confirm that both come from the same experiment.
    _W, _b, p = load_W(result_dir + exp_memo + ".pickle")

    # Select and load the fitted PCA mapping.
    if dim_red_method == 'EMT_PCA':
        pca_filename = "emt_pca_%d.pkl" % d_red
    elif dim_red_method == 'PCA':
        pca_filename = "pca_%d.pkl" % d_red
    else:
        # Previously this only printed and then failed on an undefined name.
        raise ValueError(
            "PCA mapping for the reduction method and dimension is not available "
            f"(dim_red_method={dim_red_method!r}, d_red={d_red})"
        )

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    # Transport the particles in latent space.
    # NOTE(review): the start snapshot is hard-coded to mats[0], not
    # mats[source_t] -- presumably the model was trained from day 0; confirm.
    T_final = p['numerical_ts'][-1]
    dt = T_final / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=T_final, dt=dt)

    # Normalise the intermediate days. The fallback must be an ndarray: the
    # code below calls .tolist() on it, which a plain range does not support.
    intermediate_t = np.array(intermediate_t)
    if len(intermediate_t) == 0:
        intermediate_t = np.arange(source_t + 1, target_t)
    # Ascending order keeps the snapshot statistics aligned with the sorted
    # day labels used on the x-axis.
    intermediate_t = np.sort(intermediate_t)

    # Column of the gene in the reconstructed expression matrix.
    # NOTE(review): the "- 1" presumably skips a leading non-gene column of
    # df_reduced_emt -- verify against how mats was built.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # Statistics of the observed (PCA-reconstructed) snapshots.
    snapshot_days = [source_t] + intermediate_t.tolist() + [target_t]
    snapshot_values = [_reconstructed_expression(day) for day in snapshot_days]
    all_avg_expressions = [np.mean(values) for values in snapshot_values]
    all_ci_expressions = [_ci95(values) for values in snapshot_values]

    # Statistics along the transported trajectory.
    avg_gene_expressions = []
    ci_gene_expressions = []
    for i in range(0, len(X1_trpts), index):
        if i > max_i:  # Apply truncation based on max_i
            break
        X1_trpt = X1_trpts[i]
        if np.isnan(X1_trpt).any():  # integration diverged from here on
            break
        step_values = pca.inverse_transform(X1_trpt)[:, gene_index]
        avg_gene_expressions.append(np.mean(step_values))
        ci_gene_expressions.append(_ci95(step_values))

    if not avg_gene_expressions:
        raise ValueError("No finite integration step available to plot.")

    combined_indices = sorted(snapshot_days)
    print(combined_indices)
    combined_indices = np.array(combined_indices)

    # x-coordinates of the trajectory (integration step numbers) and the
    # snapshot days mapped linearly onto that same step axis.
    extended_indices = np.arange(len(avg_gene_expressions)) * index
    rescaled_indices = np.interp(
        combined_indices,
        [combined_indices[0], combined_indices[-1]],
        [extended_indices[0], extended_indices[-1]]
    )

    output_file = f"{output_dir}/_average_gene_expression_{gene_of_interest}.png"

    avg_curve = np.array(avg_gene_expressions)
    ci_curve = np.array(ci_gene_expressions)

    # Main figure: mean curve, CI band, and observed snapshots.
    fig = plt.figure(figsize=(10, 6))
    plt.plot(
        extended_indices,
        avg_gene_expressions,
        label='Average Gene Expression',
        color='orange',
        linestyle='-',
    )
    plt.fill_between(
        extended_indices,
        avg_curve - ci_curve,
        avg_curve + ci_curve,
        alpha=0.2,
        color='lightsalmon',
        label='95% CI'
    )
    plt.errorbar(
        rescaled_indices,
        all_avg_expressions,
        yerr=all_ci_expressions,
        fmt='o',
        color='blue',
        label='Discrete Points'
    )

    # Label the step axis with the snapshot days.
    plt.xticks(
        ticks=rescaled_indices,
        labels=combined_indices,
        rotation=0,
        fontsize=32
    )
    plt.yticks(fontsize=32)

    plt.xlabel('Time', fontsize=32)
    plt.ylabel('Gene Expression', fontsize=32)
    plt.title(f'Average {gene_of_interest} Dynamics', fontsize=32)

    # Save through the explicit handle instead of pyplot's current figure.
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.close(fig)

    # Legend is saved as its own image so the main plot stays uncluttered.
    fig_legend, ax_legend = plt.subplots(figsize=(10, 2))
    ax_legend.axis("off")

    legend_elements = [
        mlines.Line2D([], [], color='orange', linestyle='-', linewidth=3, label='Average Gene Dynamics'),
        mpatches.Patch(color='lightsalmon', alpha=0.7, label='95% Confidence Interval'),
        mlines.Line2D([], [], color='blue', marker='o', linestyle='-', markersize=8, linewidth=2, label='Data Points')
    ]

    ax_legend.legend(handles=legend_elements, loc="center", fontsize=32, title="", title_fontsize=32,
                     frameon=True, ncol=len(legend_elements), handletextpad=2, columnspacing=2)

    legend_output_file = output_file.replace(".png", "_legend.png")
    fig_legend.savefig(legend_output_file, dpi=300, bbox_inches='tight')
    plt.close(fig_legend)

    # The message used to say "Subgroup trajectory plot", copied from the
    # subgroup variant; this function saves the whole-population average.
    print(f"Average gene dynamics plot saved at: {output_file}")
    print(f"Legend plot saved separately at: {legend_output_file}")

In [None]:
## Average gene dynamics for each single gene plot (Stem Cell data)

import matplotlib.patches as mpatches

# Genes to process
genes_of_interest = gene_names # NANOG, SOX2

# Snapshot days: source, target, and the days in between
source_t, target_t = 0, 4
intermediate_t = [1,2,3]
#intermediate_t = [2]

# Trajectory sampling: stride over integration steps and last step used
index = 1
max_i = 200

# Remaining arguments forwarded to the plotting function
optimal_k = 2
d_red= 2
random_state = 40

# Experiment name and the folders derived from it
exp_memo = 'EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64'
result_dir = '%s/assets/Transport_genes/' % main_dir
output_dir = os.path.join(result_dir, 'output', exp_memo)

# Make sure the output folder is there before any figure is written
os.makedirs(output_dir, exist_ok=True)

# One figure per gene; a failure is reported and the loop moves on
for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        Average_gene_dynamics_whole_saveonly(
            source_t, target_t, optimal_k, gene, index, max_i,
            intermediate_t=intermediate_t, d_red=d_red,
            random_state=random_state, exp_memo=exp_memo
        )
    except Exception as err:
        print(f"Error processing gene {gene}: {err}")

In [None]:
## save the gene expression dynamics png as pdf (Stem Cell data)


import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from math import ceil
import matplotlib.lines as mlines


def create_pdf_from_gene_images(output_dir, exp_memo, gene_list, pdf_path, images_per_page=25, grid_size=(5, 5)):
    """
    Collect the per-gene PNG plots into a multi-page PDF, laid out on a grid.

    For each gene the file ``{output_dir}/_average_gene_expression_{gene}.png``
    is looked up; files that do not exist are reported and skipped. Each image
    is drawn with ``aspect='auto'``, i.e. stretched to fill its grid cell.

    Parameters:
        output_dir (str): Directory containing the PNG files.
        exp_memo (str): Unused; kept for signature compatibility.
        gene_list (list): Genes whose PNG files should be included.
        pdf_path (str): Path of the PDF file to write.
        images_per_page (int): Maximum number of images per page (default: 25).
            Capped at the number of grid cells so no image is dropped.
        grid_size (tuple): Grid size (rows, cols) for each page (default: 5x5).

    Raises:
        ValueError: If ``images_per_page`` or the grid leaves no room for an image.

    Returns:
        None. When no PNG file is found, nothing is written.
    """
    rows, cols = grid_size
    # Never place more images on a page than the grid has cells; previously
    # the surplus images were silently dropped.
    per_page = min(images_per_page, rows * cols)
    if per_page < 1:
        raise ValueError(
            f"images_per_page ({images_per_page}) and grid_size ({grid_size}) must allow at least one image per page"
        )

    # Split the expected files into existing and missing ones.
    expected_files = [
        f"{output_dir}/_average_gene_expression_{gene}.png" for gene in gene_list
    ]
    png_files = [file for file in expected_files if os.path.exists(file)]
    missing_files = [file for file in expected_files if not os.path.exists(file)]
    if missing_files:
        print(f"Warning: The following files are missing and will be skipped:\n{missing_files}")

    # Nothing to show: do not leave an empty PDF behind and report it as saved.
    if not png_files:
        print(f"No gene images found in {output_dir}; no PDF written.")
        return

    total_pages = ceil(len(png_files) / per_page)

    with PdfPages(pdf_path) as pdf:
        for page in range(total_pages):
            # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
            fig, axes = plt.subplots(rows, cols, figsize=(15, 15), squeeze=False)

            # Only this page's slice is drawn, so pages never repeat images
            # when images_per_page is smaller than the grid.
            page_files = png_files[page * per_page:(page + 1) * per_page]

            for slot, ax in enumerate(axes.flatten()):
                if slot < len(page_files):
                    img = plt.imread(page_files[slot])
                    ax.imshow(img, aspect='auto')  # stretch image to fill the cell
                    ax.axis('off')  # Remove axes
                    ax.set_title('', fontsize=8)  # titles intentionally left blank
                else:
                    ax.axis('off')  # Hide empty axes

            pdf.savefig(fig, dpi=300, bbox_inches='tight')
            plt.close(fig)  # Close the figure to free memory

    print(f"✅ PDF saved to {pdf_path} with original image resolution.")

# Example usage: bundle the per-gene average-expression plots into one PDF

exp_memo = "EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64"
gene_list = gene_names  # genes whose plots are collected
pdf_path = f"{output_dir}/_average_gene_expression.pdf"  # where the PDF is written

# Six images per page on a 3 x 2 grid
create_pdf_from_gene_images(
    output_dir=output_dir,
    exp_memo=exp_memo,
    gene_list=gene_list,
    pdf_path=pdf_path,
    images_per_page=6,
    grid_size=(3, 2),
)


## Downstream Analysis for Average gene dynamics on each subtrajectory

In [None]:
## This is for Stem cell data (five time points: Time [0, 1, 2, 3, 4])

## Subtrajectories with violin plots

def Average_gene_dynamics_whole_saveonly_with_violin_plot_sample_3(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                                                                   intermediate_t=[1],
                                                                   d_red=2, random_state=42, exp_memo='2'):
    """Plot per-cluster gene dynamics over violin plots of the snapshots.

    The particles ``pca.transform(mats[0])`` are transported with
    ``time_integration``. Cluster labels for the particles are read from
    ``{result_dir}{exp_memo}_X1_hat_clusters.csv`` (column ``Cluster_Label``).
    At every ``index``-th step (up to step ``max_i``, stopping at the first
    step containing NaN) the mean and 95% Student-t confidence interval of
    the selected gene are computed per cluster and drawn as one curve per
    cluster. Violin plots of the PCA-reconstructed snapshots (source,
    intermediate, target) are drawn underneath.

    The time axis is hard-coded for five snapshots at x = 0..4: trajectories
    are spread over ``linspace(0, 4, n_steps)`` and violins are placed at
    ``[0, 1, 2, 3, 4]``. With a different number of snapshots the surplus
    violins are not drawn.

    Two files are written:
    ``{output_dir}/subtrajectories_violin_plots_{gene}.png`` and a matching
    ``..._legend.png``.

    Reads the module-level names ``result_dir``, ``data_dir``, ``output_dir``,
    ``dim_red_method``, ``mats`` and ``df_reduced_emt``.

    Parameters
    ----------
    source_t, target_t : int
        Keys into ``mats`` for the first and last snapshot.
    optimal_k : int
        Unused (it only fed a KMeans fit whose labels were never read); kept
        for signature compatibility.
    gene_of_interest : str
        Column name looked up in ``df_reduced_emt``.
    index : int
        Stride over the integration steps.
    max_i : int
        Largest integration step included.
    intermediate_t : sequence of int
        Snapshot days between source and target. An empty sequence means
        every day strictly between ``source_t`` and ``target_t``. The default
        list is never mutated.
    d_red : int
        Latent dimension, used to select the pickled PCA file.
    random_state : int
        Unused; kept for signature compatibility.
    exp_memo : str
        Experiment name used for the weight pickle and the cluster CSV.

    Raises
    ------
    FileNotFoundError
        If the cluster label CSV does not exist.
    ValueError
        If ``dim_red_method`` has no PCA file, if the number of cluster labels
        differs from the number of particles, or if no finite integration
        step is available to plot.
    """
    # Imported locally so the function does not depend on which notebook
    # cells (and therefore which import statements) were executed before it.
    import matplotlib.patches as mpatches

    def _ci95(values):
        """Half-width of the two-sided 95% Student-t interval of the mean."""
        return stats.sem(values) * stats.t.ppf((1 + 0.95) / 2., len(values) - 1)

    def _reconstructed_expression(day):
        """Selected gene column of ``mats[day]`` after a PCA round trip."""
        latent = pca.transform(mats[day])
        return pca.inverse_transform(latent)[:, gene_index]

    # NOTE(review): time_integration() below is called without weights, so it
    # reads the module-level W and b rather than the ones unpickled here;
    # confirm that both come from the same experiment.
    _W, _b, p = load_W(result_dir + exp_memo + ".pickle")

    # Select and load the fitted PCA mapping.
    if dim_red_method == 'EMT_PCA':
        pca_filename = "emt_pca_%d.pkl" % d_red
    elif dim_red_method == 'PCA':
        pca_filename = "pca_%d.pkl" % d_red
    else:
        # Previously this only printed and then failed on an undefined name.
        raise ValueError(
            "PCA mapping for the reduction method and dimension is not available "
            f"(dim_red_method={dim_red_method!r}, d_red={d_red})"
        )

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    # Transport the particles in latent space.
    # NOTE(review): the start snapshot is hard-coded to mats[0], not
    # mats[source_t] -- presumably the model was trained from day 0; confirm.
    T_final = p['numerical_ts'][-1]
    dt = T_final / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=T_final, dt=dt)

    # Normalise the intermediate days. The fallback must be an ndarray: the
    # code below calls .tolist() on it, which a plain range does not support.
    intermediate_t = np.array(intermediate_t)
    if len(intermediate_t) == 0:
        intermediate_t = np.arange(source_t + 1, target_t)
    # Ascending order keeps the violins in chronological order.
    intermediate_t = np.sort(intermediate_t)

    # Cluster assignment of each transported particle, saved by an earlier step.
    cluster_save_path = f"{result_dir}{exp_memo}_X1_hat_clusters.csv"
    if not os.path.exists(cluster_save_path):
        raise FileNotFoundError(f"Cluster labels file not found: {cluster_save_path}")

    X1_hat_labels = pd.read_csv(cluster_save_path)["Cluster_Label"].values
    unique_labels = np.unique(X1_hat_labels)
    print(f"Number of unique labels in X1_hat_labels: {len(unique_labels)}")
    print(f"Unique labels: {unique_labels}")

    # The labels are used as a boolean mask over the particles, so the counts
    # must agree; fail with a clear message instead of an indexing error.
    n_particles = len(X1_trpts[0])
    if len(X1_hat_labels) != n_particles:
        raise ValueError(
            f"Cluster label count ({len(X1_hat_labels)}) does not match the number of particles ({n_particles})"
        )

    # Column of the gene in the reconstructed expression matrix.
    # NOTE(review): the "- 1" presumably skips a leading non-gene column of
    # df_reduced_emt -- verify against how mats was built.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # Observed (PCA-reconstructed) snapshots, in chronological order.
    snapshot_days = [source_t] + intermediate_t.tolist() + [target_t]
    violin_data = [_reconstructed_expression(day) for day in snapshot_days]

    print(sorted(snapshot_days))

    # One colour per cluster, cycling if there are more clusters than colours.
    subgroup_colors = ['red', 'blue', '#ffe119', '#f58231', '#3cb44b']
    subgroup_color_map = {
        label: subgroup_colors[i % len(subgroup_colors)]
        for i, label in enumerate(unique_labels)
    }

    subgroup_output_file = f"{output_dir}/subtrajectories_violin_plots_{gene_of_interest}.png"

    # Per-cluster mean and CI along the transported trajectory.
    subgroup_avg_gene_expressions = {label: [] for label in unique_labels}
    subgroup_ci_gene_expressions = {label: [] for label in unique_labels}
    label_masks = {label: (X1_hat_labels == label) for label in unique_labels}

    num_points = 0
    for time_idx in range(0, len(X1_trpts), index):
        if time_idx > max_i:  # Apply truncation
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():  # integration diverged from here on
            break

        step_values = pca.inverse_transform(X1_trpt)[:, gene_index]

        for label in unique_labels:
            subgroup_values = step_values[label_masks[label]]
            if len(subgroup_values) > 0:
                subgroup_avg_gene_expressions[label].append(np.mean(subgroup_values))
                subgroup_ci_gene_expressions[label].append(_ci95(subgroup_values))
            else:
                subgroup_avg_gene_expressions[label].append(np.nan)
                subgroup_ci_gene_expressions[label].append(np.nan)
        num_points += 1

    if num_points == 0:
        raise ValueError("No finite integration step available to plot.")

    fig, ax1 = plt.subplots(figsize=(12, 7))

    # Trajectory x-positions: the integration steps spread evenly over [0, 4].
    x_positions = np.linspace(0, 4, num_points)

    subgroup_legend_handles = []  # collected for the separate legend figure

    for i, label in enumerate(unique_labels):
        mean_curve = np.array(subgroup_avg_gene_expressions[label])
        ci_curve = np.array(subgroup_ci_gene_expressions[label])

        # Mean trajectory of this cluster
        line, = ax1.plot(
            x_positions, subgroup_avg_gene_expressions[label], zorder=10,
            linestyle='-', color=subgroup_color_map[label], linewidth=2,
            label=f'Predicted Trajectory {i+1}'
        )

        # Shaded 95% confidence band
        ax1.fill_between(
            x_positions,
            mean_curve - ci_curve,
            mean_curve + ci_curve,
            alpha=0.2, zorder=5, color=subgroup_color_map[label],
            label=f'95% CI of Trajectory {i+1}'
        )

        ci_patch = mpatches.Patch(
            color=subgroup_color_map[label], alpha=0.2, label=f'95% CI of Trajectory {i+1}'
        )
        subgroup_legend_handles.append(ci_patch)
        subgroup_legend_handles.append(line)

    # Violins sit at x = 0, 1, 2, 3, 4; zip() stops at the shortest input, so
    # snapshots beyond the fifth are not drawn.
    violin_x_positions = np.array([0, 1, 2, 3, 4])
    violin_colors = ["black", "gray", "black", "gray", "black"]

    # Each violin is drawn on its own and then shifted, because seaborn places
    # a single-element dataset at its own default position.
    for x_pos, data, color in zip(violin_x_positions, violin_data, violin_colors):
        # NOTE(review): `scale` was renamed `density_norm` in seaborn 0.13;
        # kept as-is because the installed version is not visible here.
        sns.violinplot(
            data=[data],  # wrapped in a list so it is treated as one violin
            ax=ax1,
            inner=None,
            linewidth=1.2,
            width=0.7,
            cut=0,
            scale="width",
            color=color,
            alpha=0.8,
            zorder=3
        )

        # Move the violin that was just added so it is centred on x_pos.
        for violin in ax1.collections[-1:]:
            for path in violin.get_paths():
                path.vertices[:, 0] += x_pos - path.vertices[:, 0].mean()

    # Leave half a unit of margin so the outer violins are not clipped.
    ax1.set_xlim(-0.5, 4.5)

    ax1.set_xticks([0, 1, 2, 3, 4])
    ax1.set_xticklabels([0, 1, 2, 3, 4], fontsize=35)
    ax1.tick_params(axis='y', labelsize=35)

    ax1.set_xlabel('Time', fontsize=35)
    ax1.set_ylabel('Gene Expression', fontsize=35)
    ax1.set_title(f'Subtrajectory {gene_of_interest} Expression', fontsize=34)

    # Save the main figure (without a legend) through its own handle rather
    # than relying on which figure pyplot currently considers active.
    fig.savefig(subgroup_output_file, dpi=300, bbox_inches='tight')
    plt.close(fig)

    # Separate, vertically stacked legend figure.
    violin_legend_patches = [
        mpatches.Patch(color="black", label="Input Data"),
        mpatches.Patch(color="gray", label="Test Data")
    ]

    fig_legend, ax_legend = plt.subplots(figsize=(4, 8))
    ax_legend.axis("off")

    ax_legend.legend(
        handles=subgroup_legend_handles + violin_legend_patches,
        loc="center", fontsize=18, title="Trajectories & Violin Plots",
        title_fontsize=18, ncol=1, frameon=True, handletextpad=1.5, columnspacing=2
    )

    legend_output_file = subgroup_output_file.replace(".png", "_legend.png")
    fig_legend.savefig(legend_output_file, dpi=300, bbox_inches='tight')
    plt.close(fig_legend)

    print(f"Subgroup trajectory plot saved at: {subgroup_output_file}")
    print(f"Legend plot saved separately at: {legend_output_file}")


In [None]:
## Plot the results for Stem cell data

import matplotlib.patches as mpatches


# Genes to process: every entry of `gene_names` (defined in an earlier cell).
genes_of_interest = gene_names

# Time points: transport from `source_t` to `target_t`, with the listed
# intermediate snapshots passed through to the plotting function.
source_t = 0
target_t = 4
intermediate_t = [1, 2, 3]

# Remaining arguments forwarded unchanged to the plotting function.
optimal_k = 2
index = 1
max_i = 200
d_red = 2
random_state = 40

# Experiment tag: selects the saved model and names the output folder.
exp_memo = 'EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64'

result_dir = f"{main_dir}/assets/Transport_genes/"
output_dir = os.path.join(result_dir, 'output', exp_memo)

# Make sure the output folder exists before any figure is written.
os.makedirs(output_dir, exist_ok=True)

# One figure per gene; a failure on one gene is reported and the loop moves on.
for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        Average_gene_dynamics_whole_saveonly_with_violin_plot_sample_3(
            source_t,
            target_t,
            optimal_k,
            gene,
            index,
            max_i,
            intermediate_t=intermediate_t,
            d_red=d_red,
            random_state=random_state,
            exp_memo=exp_memo,
        )
    except Exception as exc:
        print(f"Error processing gene {gene}: {exc}")


In [None]:
## save the gene expression dynamics png as pdf for Stem Cell data


import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from math import ceil
import matplotlib.lines as mlines


def create_pdf_from_gene_images(output_dir, exp_memo, gene_list, pdf_path, images_per_page=25, grid_size=(5, 5)):
    """
    Collect per-gene PNG plots into a multi-page PDF, arranged on a grid.

    For every gene in ``gene_list`` the file
    ``{output_dir}/subtrajectories_violin_plots_{gene}.png`` is looked up.
    Files that do not exist are reported once and skipped. If no file exists
    at all, a message is printed and no PDF is written.

    Parameters:
        output_dir (str): Directory containing the PNG files.
        exp_memo (str): Experiment tag. Not used by this function; kept so
            existing calls keep working.
        gene_list (list): Genes whose PNG files should be included.
        pdf_path (str): Path of the PDF file to write.
        images_per_page (int): Maximum number of images per page (default: 25).
            Values larger than rows * cols are capped to the grid capacity.
        grid_size (tuple): Grid size (rows, cols) for each page (default: 5x5).

    Raises:
        ValueError: If ``images_per_page`` or the grid capacity is below 1.
    """
    rows, cols = grid_size
    # A page can never show more images than it has grid cells, and must not
    # show more than requested; otherwise images get dropped or repeated.
    per_page = min(images_per_page, rows * cols)
    if per_page < 1:
        raise ValueError("images_per_page and grid_size must allow at least one image per page")

    candidate_files = [
        f"{output_dir}/subtrajectories_violin_plots_{gene}.png" for gene in gene_list
    ]

    # Split into existing / missing with a single filesystem check per file.
    png_files = []
    missing_files = []
    for file in candidate_files:
        (png_files if os.path.exists(file) else missing_files).append(file)

    if missing_files:
        print(f"Warning: The following files are missing and will be skipped:\n{missing_files}")

    if not png_files:
        print(f"No PNG files found in {output_dir}; PDF was not created.")
        return

    total_pages = ceil(len(png_files) / per_page)

    with PdfPages(pdf_path) as pdf:
        for page in range(total_pages):
            page_files = png_files[page * per_page:(page + 1) * per_page]

            # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
            fig, axes = plt.subplots(rows, cols, figsize=(15, 15), squeeze=False)

            for slot, ax in enumerate(axes.flat):
                if slot < len(page_files):
                    img = plt.imread(page_files[slot])
                    # aspect='auto' stretches the image to fill its grid cell.
                    ax.imshow(img, aspect='auto')
                # Hide the frame and ticks for filled and empty cells alike.
                ax.axis('off')

            pdf.savefig(fig, dpi=300, bbox_inches='tight')
            plt.close(fig)  # Close the figure to free memory

    print(f"✅ PDF saved to {pdf_path} with original image resolution.")

# Build the PDF for the run identified by `exp_memo`.
# `output_dir` and `gene_names` are taken from earlier cells.

exp_memo = "EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64"
gene_list = gene_names
pdf_path = f"{output_dir}/subtrajectories_violin_plots.pdf"

# Six images per page on a 3-row by 2-column grid.
create_pdf_from_gene_images(
    output_dir,
    exp_memo,
    gene_list,
    pdf_path,
    images_per_page=6,
    grid_size=(3, 2),
)


In [None]:
## This is for EMT data (Three time points: Time [0, 2, 4])

## Subtrajectories with violin plots

def Average_gene_dynamics_whole_saveonly_with_violin_plot_sample1(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                              intermediate_t = [1],
                              d_red=2, random_state=42, exp_memo = '2'):
    """
    Plot per-cluster mean trajectories of one gene with snapshot violins; save as PNG.

    Steps:
      1. Load the saved parameters ``p`` from ``result_dir + exp_memo + ".pickle"``
         and the PCA object from ``data_dir``.
      2. Load cluster labels (column ``Cluster_Label``) from
         ``{result_dir}{exp_memo}_X2_hat_clusters.csv``.
      3. Integrate the particles ``pca.transform(mats[0])`` with
         ``time_integration`` up to ``p['numerical_ts'][-1]`` in 200 steps.
      4. For every ``index``-th step up to ``max_i`` (stopping at the first step
         that contains NaN), map the particles back through
         ``pca.inverse_transform``, take the column of ``gene_of_interest`` and
         compute, per cluster, the mean and a 95% Student-t interval.
      5. Draw the cluster curves, and one violin per snapshot day
         (``source_t``, each intermediate day, ``target_t``) at that day's
         x-position, using the PCA-reconstructed snapshot data.
      6. Save the main figure to
         ``{output_dir}/subtrajectories_violin_plots_{gene_of_interest}.png`` and
         the legend to the same name with a ``_legend.png`` suffix.

    Reads the module-level names ``result_dir``, ``data_dir``, ``output_dir``,
    ``dim_red_method``, ``mats`` and ``df_reduced_emt``.

    Parameters:
        source_t, target_t: Snapshot keys into ``mats`` for the first and last violin.
        optimal_k: Not used by this function; kept for call compatibility.
        gene_of_interest: Column name in ``df_reduced_emt``.
        index: Step size over the integrated trajectory.
        max_i: Largest trajectory step index to include.
        intermediate_t: Snapshot keys between source and target. If empty, every
            integer strictly between ``source_t`` and ``target_t`` is used.
        d_red: PCA dimension, used to pick the PCA pickle file.
        random_state: Not used by this function; kept for call compatibility.
        exp_memo: Experiment tag selecting the model pickle and cluster CSV.

    Raises:
        ValueError: If ``dim_red_method`` is neither ``'EMT_PCA'`` nor ``'PCA'``.
        FileNotFoundError: If the cluster label CSV does not exist.
    """
    # ---- cheap checks first, so a bad setup fails before the integration ----
    if dim_red_method == 'EMT_PCA':
        pca_filename = "emt_pca_%d.pkl" % d_red
    elif dim_red_method == 'PCA':
        pca_filename = "pca_%d.pkl" % d_red
    else:
        raise ValueError("PCA mapping for the reduction method and dimension is not available")

    # NOTE(review): labels are read from the "_X2_hat_" file although they are
    # applied to the transported source particles - confirm this is the intended file.
    cluster_save_path = f"{result_dir}{exp_memo}_X2_hat_clusters.csv"
    if not os.path.exists(cluster_save_path):
        raise FileNotFoundError(f"Cluster labels file not found: {cluster_save_path}")

    df_clusters = pd.read_csv(cluster_save_path)
    X1_hat_labels = df_clusters["Cluster_Label"].values
    unique_labels = np.unique(X1_hat_labels)
    print(f"Number of unique labels in X1_hat_labels: {len(unique_labels)}")
    print(f"Unique labels: {unique_labels}")

    # ---- model, PCA and trajectory ----
    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): W and b are only bound locally here, while time_integration
    # takes no weight arguments - presumably it reads module-level W/b. Confirm
    # those correspond to `exp_memo`.
    W, b, p = load_W(filename)

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    # ---- snapshot days ----
    intermediate_days = np.asarray(intermediate_t).ravel().tolist()
    if len(intermediate_days) == 0:
        intermediate_days = list(range(source_t + 1, target_t))
    snapshot_days = [source_t, *intermediate_days, target_t]

    # NOTE(review): the "- 1" presumably skips one leading non-gene column of
    # df_reduced_emt - confirm against how `mats` was built.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    def snapshot_expression(day):
        # Project the snapshot to PCA space and back, then read the gene column.
        reconstructed = pca.inverse_transform(pca.transform(mats[day]))
        return reconstructed[:, gene_index]

    violin_data = [snapshot_expression(day) for day in snapshot_days]

    # ---- per-cluster mean and 95% interval along the trajectory ----
    cluster_masks = {label: (X1_hat_labels == label) for label in unique_labels}
    subgroup_avg_gene_expressions = {label: [] for label in unique_labels}
    subgroup_ci_gene_expressions = {label: [] for label in unique_labels}
    num_points = 0

    for time_idx in range(0, len(X1_trpts), index):
        if time_idx > max_i:
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break

        gene_expression_values = np.asarray(pca.inverse_transform(X1_trpt))[:, gene_index]
        num_points += 1

        for label in unique_labels:
            subgroup_values = gene_expression_values[cluster_masks[label]]
            if len(subgroup_values) > 0:
                subgroup_avg_gene_expressions[label].append(np.mean(subgroup_values))
                ci = stats.sem(subgroup_values) * stats.t.ppf((1 + 0.95) / 2., len(subgroup_values) - 1)
                subgroup_ci_gene_expressions[label].append(ci)
            else:
                subgroup_avg_gene_expressions[label].append(np.nan)
                subgroup_ci_gene_expressions[label].append(np.nan)

    # ---- plotting ----
    subgroup_colors = ['red', 'blue', 'green', 'orange', 'purple', 'brown']
    subgroup_color_map = {
        label: subgroup_colors[i % len(subgroup_colors)] for i, label in enumerate(unique_labels)
    }
    violin_colors = ["black", "gray", "black", "gray", "black"]

    subgroup_output_file = f"{output_dir}/subtrajectories_violin_plots_{gene_of_interest}.png"
    legend_output_file = f"{output_dir}/subtrajectories_violin_plots_{gene_of_interest}_legend.png"

    # The retained trajectory steps are spread evenly over this x-range.
    # NOTE(review): fixed at [0, 4]; presumably the physical time span of the
    # integration - confirm against p['ts'].
    traj_x_start, traj_x_end = 0, 4
    x_positions = np.linspace(traj_x_start, traj_x_end, num_points)

    fig, ax1 = plt.subplots(figsize=(12, 7))
    fig_legend = None
    try:
        subgroup_legend_handles = []
        for i, label in enumerate(unique_labels):
            color = subgroup_color_map[label]
            mean_curve = np.array(subgroup_avg_gene_expressions[label])
            half_width = np.array(subgroup_ci_gene_expressions[label])

            line, = ax1.plot(
                x_positions, mean_curve, zorder=10,
                linestyle='-', color=color, linewidth=2,
                label=f'Predicted Trajectory {i+1}'
            )
            ax1.fill_between(
                x_positions, mean_curve - half_width, mean_curve + half_width,
                alpha=0.2, zorder=5, color=color,
                label=f'95% CI of Trajectory {i+1}'
            )
            ci_patch = mpatches.Patch(
                color=color, alpha=0.2, label=f'95% CI of Trajectory {i+1}'
            )
            subgroup_legend_handles.extend([line, ci_patch])

        # One violin per snapshot, drawn on its own and then shifted to the
        # x-position of its day (seaborn places a single violin at x=0).
        for i, (x_pos, data) in enumerate(zip(snapshot_days, violin_data)):
            sns.violinplot(
                data=[data],  # wrapped in a list so it is treated as one group
                ax=ax1,
                inner=None,
                linewidth=1.2,
                width=0.7,
                cut=0,
                scale="width",
                color=violin_colors[i % len(violin_colors)],
                alpha=0.8,
                zorder=3
            )
            for violin in ax1.collections[-1:]:  # only the violin just added
                for path in violin.get_paths():
                    path.vertices[:, 0] += x_pos - path.vertices[:, 0].mean()

        # Leave half a unit of margin so the outermost violins are not clipped.
        ax1.set_xlim(min(traj_x_start, *snapshot_days) - 0.5, max(traj_x_end, *snapshot_days) + 0.5)
        ax1.set_xticks(snapshot_days)
        ax1.set_xticklabels(snapshot_days, fontsize=35)
        ax1.tick_params(axis='y', labelsize=35)

        ax1.set_xlabel('Time', fontsize=35)
        ax1.set_ylabel('Gene Expression', fontsize=35)
        ax1.set_title(f'Subtrajectory {gene_of_interest} Expression', fontsize=35)

        violin_legend_patches = [
            mpatches.Patch(color="black", label="Input Data"),
            mpatches.Patch(color="gray", label="Test Data")
        ]

        # The legend goes into its own wide figure (6 columns).
        fig_legend, ax_legend = plt.subplots(figsize=(8, 2))
        ax_legend.axis("off")
        ax_legend.legend(
            handles=subgroup_legend_handles + violin_legend_patches,
            loc="center", fontsize=18, title="",
            title_fontsize=18, ncol=6, frameon=True, handletextpad=1.5, columnspacing=1.5
        )

        # Save through the figure objects rather than pyplot's "current figure".
        fig_legend.savefig(legend_output_file, dpi=300, bbox_inches='tight')
        fig.savefig(subgroup_output_file, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figures, even when plotting or saving fails.
        plt.close(fig)
        if fig_legend is not None:
            plt.close(fig_legend)

    print(f"Subgroup trajectory plot saved at: {subgroup_output_file}")
    print(f"Legend plot saved separately at: {legend_output_file}")


In [None]:
## Plot the results for EMT data

import matplotlib.patches as mpatches


# Genes to process: every entry of `gene_names` (defined in an earlier cell).
genes_of_interest = gene_names

# Time points: transport from `source_t` to `target_t`, with one intermediate
# snapshot passed through to the plotting function.
source_t = 0
target_t = 4
intermediate_t = [2]

# Remaining arguments forwarded unchanged to the plotting function.
optimal_k = 2
index = 1
max_i = 200
d_red = 8
random_state = 40

# Experiment tag: selects the saved model and names the output folder.
exp_memo = '72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64'

result_dir = f"{main_dir}/assets/Transport_genes/"
output_dir = os.path.join(result_dir, 'output', exp_memo)

# Make sure the output folder exists before any figure is written.
os.makedirs(output_dir, exist_ok=True)

# One figure per gene; a failure on one gene is reported and the loop moves on.
for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        Average_gene_dynamics_whole_saveonly_with_violin_plot_sample1(
            source_t,
            target_t,
            optimal_k,
            gene,
            index,
            max_i,
            intermediate_t=intermediate_t,
            d_red=d_red,
            random_state=random_state,
            exp_memo=exp_memo,
        )
    except Exception as exc:
        print(f"Error processing gene {gene}: {exc}")

In [None]:
## save the gene expression dynamics png as pdf for EMT data


import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from math import ceil
import matplotlib.lines as mlines


def create_pdf_from_gene_images(output_dir, exp_memo, gene_list, pdf_path, images_per_page=25, grid_size=(5, 5)):
    """
    Collect per-gene PNG plots into a multi-page PDF, arranged on a grid.

    For every gene in ``gene_list`` the file
    ``{output_dir}/subtrajectories_violin_plots_{gene}.png`` is looked up.
    Files that do not exist are reported once and skipped. If no file exists
    at all, a message is printed and no PDF is written.

    Parameters:
        output_dir (str): Directory containing the PNG files.
        exp_memo (str): Experiment tag. Not used by this function; kept so
            existing calls keep working.
        gene_list (list): Genes whose PNG files should be included.
        pdf_path (str): Path of the PDF file to write.
        images_per_page (int): Maximum number of images per page (default: 25).
            Values larger than rows * cols are capped to the grid capacity.
        grid_size (tuple): Grid size (rows, cols) for each page (default: 5x5).

    Raises:
        ValueError: If ``images_per_page`` or the grid capacity is below 1.
    """
    rows, cols = grid_size
    # A page can never show more images than it has grid cells, and must not
    # show more than requested; otherwise images get dropped or repeated.
    per_page = min(images_per_page, rows * cols)
    if per_page < 1:
        raise ValueError("images_per_page and grid_size must allow at least one image per page")

    candidate_files = [
        f"{output_dir}/subtrajectories_violin_plots_{gene}.png" for gene in gene_list
    ]

    # Split into existing / missing with a single filesystem check per file.
    png_files = []
    missing_files = []
    for file in candidate_files:
        (png_files if os.path.exists(file) else missing_files).append(file)

    if missing_files:
        print(f"Warning: The following files are missing and will be skipped:\n{missing_files}")

    if not png_files:
        print(f"No PNG files found in {output_dir}; PDF was not created.")
        return

    total_pages = ceil(len(png_files) / per_page)

    with PdfPages(pdf_path) as pdf:
        for page in range(total_pages):
            page_files = png_files[page * per_page:(page + 1) * per_page]

            # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
            fig, axes = plt.subplots(rows, cols, figsize=(15, 15), squeeze=False)

            for slot, ax in enumerate(axes.flat):
                if slot < len(page_files):
                    img = plt.imread(page_files[slot])
                    # aspect='auto' stretches the image to fill its grid cell.
                    ax.imshow(img, aspect='auto')
                # Hide the frame and ticks for filled and empty cells alike.
                ax.axis('off')

            pdf.savefig(fig, dpi=300, bbox_inches='tight')
            plt.close(fig)  # Close the figure to free memory

    print(f"✅ PDF saved to {pdf_path} with original image resolution.")

# Build the PDF for the run identified by `exp_memo`.
# `output_dir` and `gene_names` are taken from earlier cells.

exp_memo = "72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64"
gene_list = gene_names
pdf_path = f"{output_dir}/subtrajectories_violin_plots.pdf"

# Six images per page on a 3-row by 2-column grid.
create_pdf_from_gene_images(
    output_dir,
    exp_memo,
    gene_list,
    pdf_path,
    images_per_page=6,
    grid_size=(3, 2),
)


## Analysis of divergence or convergence of subtrajectories

In [None]:
## Dynamics of p-values and fold change across subtrajectories (with CSV files and visualization results)

import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
import pandas as pd
import os

def _propagate_nans_for_negative_fc(df, column="Log2 Fold Change"):
    """Once NaN appears right after a negative `column` value, blank out everything after it.

    Only the first NaN in `column` is inspected. Returns a modified copy; the
    input frame is left untouched.
    """
    df = df.copy()
    # Work on an explicit copy: the array behind a column may be read-only.
    values = df[column].to_numpy(dtype=float, copy=True)

    nan_positions = np.flatnonzero(np.isnan(values))
    if nan_positions.size:
        first_nan_idx = nan_positions[0]
        if first_nan_idx > 0 and values[first_nan_idx - 1] < 0:
            values[first_nan_idx:] = np.nan

    df[column] = values
    return df


def _compare_cluster_expression(expr_values1, expr_values2):
    """Return (mean difference, log2 fold change, Welch t-test p-value) for two samples.

    All three are NaN when either sample has fewer than two values. The fold
    change is NaN when either mean is non-positive or NaN.
    """
    if len(expr_values1) <= 1 or len(expr_values2) <= 1:
        return np.nan, np.nan, np.nan

    mean1, mean2 = np.mean(expr_values1), np.mean(expr_values2)
    mean_diff = mean1 - mean2

    if mean1 <= 0 or mean2 <= 0 or np.isnan(mean1) or np.isnan(mean2):
        log2_fc = np.nan
    else:
        log2_fc = np.log2(mean1 / mean2)

    _, p_val = stats.ttest_ind(expr_values1, expr_values2, equal_var=False, nan_policy='omit')
    return mean_diff, log2_fc, p_val


def Compute_and_Plot_FoldChange_MeanDiff_PValues(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                                                 intermediate_t=[1], d_red=2, random_state=42, exp_memo='2'):
    """
    Compare one gene's expression between clusters along the integrated trajectory.

    The particles ``pca.transform(mats[0])`` are integrated with
    ``time_integration``. For every ``index``-th step up to ``max_i`` (steps
    containing NaN are skipped), they are mapped back through
    ``pca.inverse_transform``. For each pair of cluster labels, the mean
    difference, log2 fold change and Welch t-test p-value of the gene are computed.

    Outputs, all written to the module-level ``subtraj_dir``:
      * ``fold_change_{gene}.csv``, ``mean_difference_{gene}.csv``,
        ``p_values_{gene}.csv`` and ``fold_change_nan_propagated_{gene}.csv``
      * ``mean_difference_{gene}_cluster_{a}_vs_{b}.png`` for each cluster pair
        (the gene is part of the name so genes do not overwrite each other)
      * ``fold_change_all_clusters_{gene}.png`` and
        ``p_values_all_clusters_{gene}.png``

    Reads the module-level names ``result_dir``, ``data_dir``, ``subtraj_dir``,
    ``dim_red_method``, ``mats`` and ``df_reduced_emt``.

    Parameters:
        source_t, target_t, optimal_k, intermediate_t, random_state:
            Not used by this function; kept for call compatibility.
        gene_of_interest: Column name in ``df_reduced_emt``.
        index: Step size over the integrated trajectory.
        max_i: Largest trajectory step index to include.
        d_red: PCA dimension, used to pick the PCA pickle file.
        exp_memo: Experiment tag selecting the model pickle and cluster CSV.

    Raises:
        FileNotFoundError: If the cluster label CSV does not exist.
    """
    # ---- cluster labels (checked first so a missing file fails fast) ----
    cluster_save_path = f"{result_dir}{exp_memo}_X1_hat_clusters.csv"
    if not os.path.exists(cluster_save_path):
        raise FileNotFoundError(f"Cluster labels file not found: {cluster_save_path}")

    df_clusters = pd.read_csv(cluster_save_path)
    X1_hat_labels = df_clusters["Cluster_Label"].values
    unique_labels = np.unique(X1_hat_labels)
    print(f"Number of unique labels in X1_hat_labels: {len(unique_labels)}")
    print(f"Unique labels: {unique_labels}")

    # ---- model, PCA and trajectory ----
    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): W and b are only bound locally here, while time_integration
    # takes no weight arguments - presumably it reads module-level W/b. Confirm
    # those correspond to `exp_memo`.
    W, b, p = load_W(filename)

    # Both branches must be formatted with d_red.
    if dim_red_method == 'EMT_PCA':
        pca_filename = "emt_pca_%d.pkl" % d_red
    else:
        pca_filename = "pca_%d.pkl" % d_red
    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    # NOTE(review): the "- 1" presumably skips one leading non-gene column of
    # df_reduced_emt - confirm against how `mats` was built.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # ---- pairwise statistics per retained time step ----
    cluster_masks = {label: (X1_hat_labels == label) for label in unique_labels}
    key_columns = ["Time", "Cluster 1", "Cluster 2"]
    value_columns = ["Log2 Fold Change", "Mean Difference", "p-value"]
    records = []

    for time_idx in range(0, len(X1_trpts), index):
        if time_idx > max_i:
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            continue

        # Values are read from this step directly, so skipping a NaN step can
        # not shift later steps onto the wrong data.
        gene_expression_values = np.asarray(pca.inverse_transform(X1_trpt))[:, gene_index]

        # np.unique returns sorted labels, so this visits each pair label1 < label2 once.
        for pos, label1 in enumerate(unique_labels):
            expr_values1 = gene_expression_values[cluster_masks[label1]]
            for label2 in unique_labels[pos + 1:]:
                expr_values2 = gene_expression_values[cluster_masks[label2]]
                mean_diff, log2_fc, p_val = _compare_cluster_expression(expr_values1, expr_values2)
                records.append({
                    "Time": time_idx,
                    "Cluster 1": label1,
                    "Cluster 2": label2,
                    "Log2 Fold Change": log2_fc,
                    "Mean Difference": mean_diff,
                    "p-value": p_val,
                })

    # Explicit columns keep the frames well-formed even when there are no records.
    df_all = pd.DataFrame(records, columns=key_columns + value_columns)
    df_fc = df_all[key_columns + ["Log2 Fold Change"]]
    df_md = df_all[key_columns + ["Mean Difference"]]
    df_pval = df_all[key_columns + ["p-value"]]

    # ---- CSV output ----
    os.makedirs(subtraj_dir, exist_ok=True)

    # Apply the NaN propagation per cluster pair. Groups are handled explicitly
    # so the grouping columns always stay in the written table.
    fc_groups = [
        _propagate_nans_for_negative_fc(group)
        for _, group in df_fc.groupby(["Cluster 1", "Cluster 2"])
    ]
    df_fc_nan_propagated = pd.concat(fc_groups) if fc_groups else df_fc.copy()
    df_fc_nan_propagated.to_csv(os.path.join(subtraj_dir, f"fold_change_nan_propagated_{gene_of_interest}.csv"), index=False)

    print(f"Saved fold change CSV with NaN propagation for negative values: fold_change_nan_propagated_{gene_of_interest}.csv")

    df_fc.to_csv(os.path.join(subtraj_dir, f"fold_change_{gene_of_interest}.csv"), index=False)
    df_md.to_csv(os.path.join(subtraj_dir, f"mean_difference_{gene_of_interest}.csv"), index=False)
    df_pval.to_csv(os.path.join(subtraj_dir, f"p_values_{gene_of_interest}.csv"), index=False)

    print(f"Results saved to {subtraj_dir}")

    # -------- Visualization --------

    # (1) Mean difference: one figure per cluster pair.
    for (cluster1, cluster2), group in df_md.groupby(["Cluster 1", "Cluster 2"]):
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.plot(group["Time"], group["Mean Difference"], label=f"Clusters {cluster1} vs {cluster2}", marker='o', linestyle='-')
        ax.axhline(y=0, color="black", linestyle="--")
        ax.set_xlabel("Time")
        ax.set_ylabel("Mean Difference")
        ax.set_title(f"Mean Difference Over Time: {gene_of_interest}\nCluster {cluster1} vs {cluster2}")
        ax.legend()
        ax.grid()
        fig.tight_layout()
        fig.savefig(
            os.path.join(subtraj_dir, f"mean_difference_{gene_of_interest}_cluster_{cluster1}_vs_{cluster2}.png"),
            dpi=300,
        )
        plt.close(fig)

    # (2) Log2 fold change: all cluster pairs in one figure.
    fig, ax = plt.subplots(figsize=(10, 6))
    for (cluster1, cluster2), group in df_fc.groupby(["Cluster 1", "Cluster 2"]):
        ax.plot(group["Time"], group["Log2 Fold Change"], label=f"Clusters {cluster1} vs {cluster2}", marker='o')
    ax.axhline(y=0, color="black", linestyle="--")
    ax.set_xlabel("Time")
    ax.set_ylabel("Log2 Fold Change")
    ax.set_title(f"Log2 Fold Change Over Time - {gene_of_interest}")
    ax.legend()
    ax.grid()
    fig.tight_layout()
    fig.savefig(os.path.join(subtraj_dir, f"fold_change_all_clusters_{gene_of_interest}.png"), dpi=300)
    plt.close(fig)

    # (3) p-values: all cluster pairs in one figure, log-scaled y-axis.
    fig, ax = plt.subplots(figsize=(10, 6))
    for (cluster1, cluster2), group in df_pval.groupby(["Cluster 1", "Cluster 2"]):
        ax.plot(group["Time"], group["p-value"], label=f"Clusters {cluster1} vs {cluster2}", marker='o')
    ax.axhline(y=0.05, color="r", linestyle="--", label="Significance Threshold (p=0.05)")
    ax.set_yscale("log")
    ax.set_ylim(1e-6, 1)
    ax.set_xlabel("Time")
    ax.set_ylabel("p-value (log scale)")
    ax.set_title(f"P-values for Differential Expression - {gene_of_interest}")
    ax.legend()
    ax.grid()
    fig.tight_layout()
    fig.savefig(os.path.join(subtraj_dir, f"p_values_all_clusters_{gene_of_interest}.png"), dpi=300)
    plt.close(fig)



In [None]:
## Generate the statistical results for the comparison of subtrajectories (Stem Cell data)

# Genes to process: every entry of `gene_names` (defined in an earlier cell).
genes_of_interest = gene_names
source_t, target_t = 0, 4
optimal_k = 3
index = 1
max_i = 200
intermediate_t = [2]

d_red = 8
random_state = 40
exp_memo = '72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64'

result_dir = '%s/assets/Transport_genes/' % main_dir

# Build the folder from this cell's own `exp_memo` rather than from whatever
# `output_dir` an earlier cell left behind, so results cannot land in another
# experiment's folder when cells are run out of order.
subtraj_dir = os.path.join(result_dir, 'output', exp_memo, 'subtraj')

# Create the directory if it doesn't exist
os.makedirs(subtraj_dir, exist_ok=True)

# One set of CSVs/plots per gene; a failure on one gene is reported and the
# loop moves on to the next.
for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        Compute_and_Plot_FoldChange_MeanDiff_PValues(
            source_t, target_t, optimal_k, gene, index, max_i,
            intermediate_t=intermediate_t, d_red=d_red,
            random_state=random_state, exp_memo=exp_memo
        )
    except Exception as e:
        print(f"Error processing gene {gene}: {e}")


In [None]:
## Generate the statistical results for the comparison of subtrajectories (EMT data)
#
# Driver cell for the run named by `exp_memo` below (an 'EMT_...' experiment):
# assigns the run parameters as notebook-level names, creates the `subtraj`
# output directory, then calls Compute_and_Plot_FoldChange_MeanDiff_PValues
# once per gene. A failure for one gene is printed and the loop continues.

genes_of_interest = gene_names # NANOG, SOX2
source_t, target_t = 0, 4
optimal_k = 3  # presumably the number of subtrajectory clusters -- TODO confirm
index = 1
max_i = 200
intermediate_t = [1,2,3]

#intermediate_t = []

d_red= 2
random_state = 40
exp_memo = 'EMT_dim2-f_Lip=6e-2-t_size=50-network=64_64_64'

result_dir = '%s/assets/Transport_genes/' % main_dir

# Per-experiment folder: <result_dir>/output/<exp_memo>/subtraj
subtraj_dir = os.path.join(result_dir, 'output', exp_memo, 'subtraj')

# Create the directory if it doesn't exist
os.makedirs(subtraj_dir, exist_ok=True)

# Iterate over each gene in the list
for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        # Call the function with the current gene
        Compute_and_Plot_FoldChange_MeanDiff_PValues(
            source_t, target_t, optimal_k, gene, index, max_i,
            intermediate_t=intermediate_t, d_red=d_red,
            random_state=random_state, exp_memo=exp_memo
        )
    except Exception as e:
        # Best-effort batch run: only the exception message is printed (no
        # traceback) and processing continues with the next gene.
        print(f"Error processing gene {gene}: {e}")


In [None]:
## Plot the difference of means across subtrajectories (EMT data)

import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np

# Location of the per-gene CSV tables for the current experiment
genes_of_interest = gene_names  # List of all genes
subtraj_dir = os.path.join(result_dir, "output", exp_memo, "subtraj")

# One accumulator per table kind: fold change, mean difference, p-values
df_fc_list, df_md_list, df_pval_list = [], [], []

for gene in genes_of_interest:
    # Pair every expected CSV with the accumulator it feeds
    expected_tables = (
        (os.path.join(subtraj_dir, f"fold_change_nan_propagated_{gene}.csv"), df_fc_list),
        (os.path.join(subtraj_dir, f"mean_difference_{gene}.csv"), df_md_list),
        (os.path.join(subtraj_dir, f"p_values_{gene}.csv"), df_pval_list),
    )
    for csv_path, collected in expected_tables:
        if not os.path.exists(csv_path):
            continue  # a gene without this table is skipped silently
        table = pd.read_csv(csv_path)
        table["Gene"] = gene
        collected.append(table)

# Stack the per-gene tables into one long DataFrame per table kind
df_fc_all = pd.concat(df_fc_list, ignore_index=True)
df_md_all = pd.concat(df_md_list, ignore_index=True)
df_pval_all = pd.concat(df_pval_list, ignore_index=True)

# Attach the p-values to the fold-change and mean-difference rows
merge_keys = ["Time", "Cluster 1", "Cluster 2", "Gene"]
df_fc_all = df_fc_all.merge(df_pval_all, on=merge_keys, how="left")
df_md_all = df_md_all.merge(df_pval_all, on=merge_keys, how="left")

# **Rescale Time by dividing by 50**
for merged in (df_fc_all, df_md_all):
    merged["Time"] = merged["Time"] / 50

# Define thresholds
pos_threshold_fc, neg_threshold_fc = 5.0, -5.0
pos_threshold_md, neg_threshold_md = 0.11, -0.15
p_value_threshold = 1e-3  # Significance threshold


# **Identify genes passing p-value threshold**
def _has_significant_pvalue(pvalues):
    """True when the smallest (non-NaN) p-value of a gene is below the threshold."""
    return pvalues.min(skipna=True) < p_value_threshold


passes_pvalue = df_pval_all.groupby("Gene")["p-value"].apply(_has_significant_pvalue)
significant_genes = passes_pvalue[passes_pvalue].index  # Only genes passing p-value

# **Filter fold-change and mean-difference genes, but retain all passing p-value**
def filter_significant_genes(df, metric_col, threshold_high, threshold_low):
    """Pick the genes to draw and the subset to draw in color.

    A gene is "highlighted" when any of its `metric_col` values lies above
    `threshold_high` or below `threshold_low`. The "retained" genes are the
    genes of `df` that also appear in the notebook-level `significant_genes`
    index (the genes passing the p-value cut).

    Returns:
        (retained_genes, highlighted_genes), both as pandas Index objects.
    """
    def _crosses_threshold(values):
        too_high = values > threshold_high
        too_low = values < threshold_low
        return (too_high | too_low).any()

    crossed = df.groupby("Gene")[metric_col].apply(_crosses_threshold)

    colored = crossed[crossed].index  # genes with at least one out-of-band value
    drawable = significant_genes.intersection(crossed.index)  # genes passing the p-value cut

    return drawable, colored

# **Apply filtering**
all_genes_fc, highlighted_genes_fc = filter_significant_genes(
    df_fc_all, "Log2 Fold Change", pos_threshold_fc, neg_threshold_fc
)
all_genes_md, highlighted_genes_md = filter_significant_genes(
    df_md_all, "Mean Difference", pos_threshold_md, neg_threshold_md
)


def _scale_to_unit_peak(values):
    """Divide one gene's series by its largest absolute value."""
    return values / values.abs().max()


def _value_range(values):
    """Return max minus min of one gene's series."""
    return values.max() - values.min()


# Per-gene normalization of the mean difference
df_md_all["Normalized Mean Difference"] = (
    df_md_all.groupby("Gene")["Mean Difference"].transform(_scale_to_unit_peak)
)

# Genes whose normalized mean difference spans more than 0.45 get highlighted
gene_change_norm = df_md_all.groupby("Gene")["Normalized Mean Difference"].apply(_value_range)
all_genes_norm = significant_genes.intersection(gene_change_norm.index)
highlighted_genes_norm = gene_change_norm[gene_change_norm > 0.45].index

# Qualitative colormaps used to build the palette
cmap1, cmap2, cmap3, cmap4, cmap5 = (
    plt.colormaps[cmap_name] for cmap_name in ("tab20b", "tab20c", "Set1", "Set3", "Paired")
)

# Palette order: Set1, Set3, tab20b, tab20c, Paired, then 15 evenly spaced hsv colors
color_list = []
for palette, n_entries in (
    (cmap3, min(9, len(cmap3.colors))),
    (cmap4, min(12, len(cmap4.colors))),
    (cmap1, 20),
    (cmap2, 20),
    (cmap5, min(12, len(cmap5.colors))),
):
    color_list.extend(palette(i) for i in range(n_entries))
color_list.extend(plt.cm.hsv(np.linspace(0, 1, 15)))

# One color per highlighted gene (alphabetical order), cycling if the palette runs out
all_highlighted_genes = sorted(
    set(highlighted_genes_fc) | set(highlighted_genes_md) | set(highlighted_genes_norm)
)
gene_colors = {}
for position, gene in enumerate(all_highlighted_genes):
    gene_colors[gene] = color_list[position % len(color_list)]

# This cell's import list does not include matplotlib.patches, which the
# legend patches below need.
import matplotlib.patches as mpatches


def get_bright_colormap(cmap_name, num_colors):
    """Sample `num_colors` colors from the upper half (0.5 to 1.0) of a colormap.

    A single color is taken at 0.75; otherwise the samples are evenly spaced
    from 0.5 to 1.0 inclusive.
    """
    cmap = plt.get_cmap(cmap_name)
    if num_colors == 1:
        return [cmap(0.75)]
    return [cmap(0.5 + (i / (2 * (num_colors - 1)))) for i in range(num_colors)]


def _plot_gene_lines(ax, frame, value_col, genes, highlighted, color_for):
    """Draw one time series per gene on `ax`; return the handles of the highlighted lines.

    Genes in `highlighted` are drawn solid and opaque in `color_for(gene)`;
    every other gene is drawn as a faint dashed gray line.
    """
    handles = []
    for gene in genes:
        rows = frame[frame["Gene"] == gene]
        is_highlighted = gene in highlighted
        if is_highlighted:
            color, alpha_value, linestyle = color_for(gene), 1.0, "-"
        else:
            color, alpha_value, linestyle = "gray", 0.2, "--"

        line, = ax.plot(rows["Time"], rows[value_col], marker="o", markersize=4,
                        linestyle=linestyle, color=color, alpha=alpha_value, label=gene)
        if is_highlighted:
            handles.append(line)
    return handles


# Group by cluster pairs and plot
for (cluster1, cluster2), group_fc in df_fc_all.groupby(["Cluster 1", "Cluster 2"]):

    # Mean-difference rows of the same cluster pair (shared by plots 2 and 3)
    group_md = df_md_all[(df_md_all["Cluster 1"] == cluster1) & (df_md_all["Cluster 2"] == cluster2)]

    # **(1) Fold Change Plot**
    fig, ax = plt.subplots(figsize=(10, 6))
    legend_handles = _plot_gene_lines(ax, group_fc, "Log2 Fold Change",
                                      all_genes_fc, highlighted_genes_fc, gene_colors.get)

    ax.axhline(y=0, color="black", linestyle="--", alpha=0.7)
    ax.set_xlabel("Time")
    ax.set_ylabel("Log2 Fold Change")
    ax.set_title(f"Fold Change Over Time (Cluster {cluster1} vs {cluster2})")

    if legend_handles:
        ax.legend(handles=legend_handles, title="Significant Genes", bbox_to_anchor=(1.05, 1), loc="upper left")

    plt.tight_layout()
    plt.savefig(os.path.join(subtraj_dir, f"fold_change_comparison_cluster_{cluster1}_vs_{cluster2}.png"), dpi=300)
    plt.close()

    # **(2) Mean Difference Plot**
    fig, ax = plt.subplots(figsize=(10, 6))
    legend_handles = _plot_gene_lines(ax, group_md, "Mean Difference",
                                      all_genes_md, highlighted_genes_md, gene_colors.get)

    ax.axhline(y=0, color="black", linestyle="--", alpha=0.7)
    ax.set_xlabel("Time")
    ax.set_ylabel("Mean Difference")
    ax.set_title(f"Mean Difference Over Time (Cluster {cluster1} vs {cluster2})")

    if legend_handles:
        ax.legend(handles=legend_handles, title="Significant Genes", bbox_to_anchor=(1.05, 1), loc="upper left")

    plt.tight_layout()
    plt.savefig(os.path.join(subtraj_dir, f"mean_difference_comparison_cluster_{cluster1}_vs_{cluster2}.png"), dpi=300)
    plt.close()

    # **(3) Normalized Mean Difference Plot**
    fig, ax = plt.subplots(figsize=(10, 8))

    # Only genes passing the p-value threshold are drawn
    genes_to_plot = significant_genes.intersection(group_md["Gene"].unique())

    # Normalized value of each gene at the last time point of this cluster pair
    last_time_values = group_md[group_md["Time"] == group_md["Time"].max()].set_index("Gene")["Normalized Mean Difference"]
    final_value = {gene: last_time_values[gene]
                   for gene in highlighted_genes_norm if gene in last_time_values}

    # Bin the highlighted genes by their final value:
    #   green (0.8, 1.0], red (-0.3, 0.8], blue (-0.8, -0.3], orange [-1.0, -0.8]
    group_green = [gene for gene, value in final_value.items() if 1.0 >= value > 0.8]
    group_red = [gene for gene, value in final_value.items() if 0.8 >= value > -0.3]
    group_blue = [gene for gene, value in final_value.items() if -0.3 >= value > -0.8]
    group_orange = [gene for gene, value in final_value.items() if -0.8 >= value >= -1.0]

    # One shade per gene within each bin (blue bin uses the "Purples" colormap)
    colors_green = get_bright_colormap("Greens", max(len(group_green), 1))
    colors_blue = get_bright_colormap("Purples", max(len(group_blue), 1))
    colors_red = get_bright_colormap("Reds", max(len(group_red), 1))
    colors_orange = get_bright_colormap("Oranges", max(len(group_orange), 1))

    color_mapped_genes = {}
    color_mapped_genes.update(zip(group_green, colors_green))
    color_mapped_genes.update(zip(group_blue, colors_blue))
    color_mapped_genes.update(zip(group_red, colors_red))
    color_mapped_genes.update(zip(group_orange, colors_orange))

    # Genes in the green and orange bins are intentionally left off this plot
    hidden_genes = set(group_green) | set(group_orange)
    visible_genes = [gene for gene in genes_to_plot if gene not in hidden_genes]

    # Highlighted genes without a bin (no value at the last time point) fall back to black
    _plot_gene_lines(ax, group_md, "Normalized Mean Difference",
                     visible_genes, highlighted_genes_norm,
                     lambda gene: color_mapped_genes.get(gene, "black"))
    drawn_highlighted = {gene for gene in visible_genes if gene in highlighted_genes_norm}

    # **Labeling and Formatting**
    ax.set_xlabel("Time", fontsize=30)
    ax.set_ylabel("Normalized Mean Difference", fontsize=30)
    ax.set_title(f"Trajectory {cluster1+1} vs {cluster2+1}", fontsize=30)
    ax.tick_params(axis="both", labelsize=30)

    # **Save Main Plot Without Legend**
    plot_output_path = os.path.join(subtraj_dir, f"normalized_mean_difference_cluster_{cluster1}_vs_{cluster2}.png")
    plt.tight_layout()
    plt.savefig(plot_output_path, dpi=300, bbox_inches="tight")
    plt.close()

    print(f"✅ Normalized Mean Difference plot saved: {plot_output_path}")

    # **Create and Save Separate Legend**
    # The legend lists the red bin followed by the blue bin, restricted to the
    # genes that were actually drawn above, so it never shows a gene that is
    # absent from the plot and is skipped entirely when it would be empty.
    legend_genes = [gene for gene in group_red + group_blue if gene in drawn_highlighted]

    if legend_genes:
        fig_legend, ax_legend = plt.subplots(figsize=(10, 10))  # Wider figure to fit grouped layout
        ax_legend.axis("off")

        flattened_handles = [mpatches.Patch(color=color_mapped_genes[gene], label=gene) for gene in legend_genes]

        ax_legend.legend(
            handles=flattened_handles,
            title="Significant Genes",
            loc="center",
            fontsize=26,
            title_fontsize=26,
            frameon=True,
            ncol=min(len(flattened_handles), 2)  # at most two columns
        )

        # **Save Legend as Separate PNG**
        legend_output_path = os.path.join(subtraj_dir, f"legend_normalized_mean_difference_cluster_{cluster1}_vs_{cluster2}.png")
        plt.savefig(legend_output_path, dpi=300, bbox_inches="tight")
        plt.close()

        print(f"✅ Legend saved separately: {legend_output_path}")


In [None]:
## Plot the difference of means across subtrajectories (Stem cell data)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import matplotlib.patches as mpatches

# Define function to get evenly spaced colors from a colormap
def get_colormap_colors(cmap_name, num_colors):
    """Return `num_colors` colors spread evenly over the whole colormap.

    A single color is taken from the middle (0.5) of the colormap; otherwise
    the samples run from 0.0 to 1.0 inclusive.
    """
    palette = plt.get_cmap(cmap_name)
    if num_colors == 1:
        positions = [0.5]
    else:
        last_index = num_colors - 1
        positions = [k / last_index for k in range(num_colors)]
    return [palette(position) for position in positions]

# Define paths
genes_of_interest = gene_names
subtraj_dir = os.path.join(result_dir, "output", exp_memo, "subtraj")

# Accumulators for the fold-change, mean-difference and p-value tables
df_fc_list, df_md_list, df_pval_list = [], [], []


def _append_gene_table(csv_path, gene, collected):
    """Read `csv_path` if it exists, tag its rows with `gene`, and append it to `collected`."""
    if not os.path.exists(csv_path):
        return
    table = pd.read_csv(csv_path)
    table["Gene"] = gene
    collected.append(table)


for gene in genes_of_interest:
    _append_gene_table(os.path.join(subtraj_dir, f"fold_change_nan_propagated_{gene}.csv"), gene, df_fc_list)
    _append_gene_table(os.path.join(subtraj_dir, f"mean_difference_{gene}.csv"), gene, df_md_list)
    _append_gene_table(os.path.join(subtraj_dir, f"p_values_{gene}.csv"), gene, df_pval_list)

# One long DataFrame per table kind
df_fc_all = pd.concat(df_fc_list, ignore_index=True)
df_md_all = pd.concat(df_md_list, ignore_index=True)
df_pval_all = pd.concat(df_pval_list, ignore_index=True)

# Bring the p-values alongside the fold-change and mean-difference rows
join_columns = ["Time", "Cluster 1", "Cluster 2", "Gene"]
df_fc_all = df_fc_all.merge(df_pval_all, on=join_columns, how="left")
df_md_all = df_md_all.merge(df_pval_all, on=join_columns, how="left")

# **Rescale Time** (stored values are divided by 50)
for merged in (df_fc_all, df_md_all):
    merged["Time"] = merged["Time"] / 50

# Define thresholds
pos_threshold_fc, neg_threshold_fc = 5.0, -5.0
pos_threshold_md, neg_threshold_md = 0.11, -0.15
p_value_threshold = 1e-4

# **Identify significant genes based on p-value threshold**
lowest_pvalue_passes = df_pval_all.groupby("Gene")["p-value"].apply(
    lambda pvalues: pvalues.min(skipna=True) < p_value_threshold
)
significant_genes = lowest_pvalue_passes[lowest_pvalue_passes].index

# Per-gene normalization: divide each gene's mean difference by its largest absolute value
df_md_all["Normalized Mean Difference"] = (
    df_md_all.groupby("Gene")["Mean Difference"].transform(lambda values: values / values.abs().max())
)

# **Filter fold-change and mean-difference genes, but retain all passing p-value**
def filter_significant_genes(df, metric_col, threshold_high, threshold_low):
    """Return (genes to draw, genes to color) for one metric.

    The first item holds the genes of `df` that are also present in the
    notebook-level `significant_genes` index (p-value cut). The second item
    holds every gene with at least one `metric_col` value above
    `threshold_high` or below `threshold_low`. Both are pandas Index objects.
    """
    per_gene_metric = df.groupby("Gene")[metric_col]
    outside_band = per_gene_metric.apply(
        lambda series: (series.gt(threshold_high) | series.lt(threshold_low)).any()
    )

    passing_pvalue = significant_genes.intersection(outside_band.index)
    exceeding = outside_band[outside_band].index

    return passing_pvalue, exceeding

# **Apply filtering**
all_genes_fc, highlighted_genes_fc = filter_significant_genes(
    df_fc_all, "Log2 Fold Change", pos_threshold_fc, neg_threshold_fc
)
all_genes_md, highlighted_genes_md = filter_significant_genes(
    df_md_all, "Mean Difference", pos_threshold_md, neg_threshold_md
)

# Per-gene spread (max minus min) of the normalized mean difference
gene_max_diff = df_md_all.groupby("Gene")["Normalized Mean Difference"].apply(
    lambda values: values.max() - values.min()
)

# **Determine which genes to highlight**
# Highlighted: passes the p-value cut AND has a spread above 0.62; the rest stay gray
highlighted_genes_norm = significant_genes.intersection(gene_max_diff[gene_max_diff > 0.62].index)
gray_genes = significant_genes.difference(highlighted_genes_norm)

# **Split Highlighted Genes into Positive and Negative Groups**
# The sign is read from the first row stored for each gene in df_md_all.
first_normalized = (
    df_md_all.drop_duplicates(subset="Gene").set_index("Gene")["Normalized Mean Difference"]
)

positive_genes, negative_genes = [], []
for gene in highlighted_genes_norm:
    bucket = positive_genes if first_normalized.loc[gene] >= 0 else negative_genes
    bucket.append(gene)



# **Assign Colors Using Custom Colormaps (Avoiding White)**
# At least one shade is generated per group, even when the group is empty.
num_positive = max(len(positive_genes), 1)
num_negative = max(len(negative_genes), 1)

# Load colormaps
cmap_reds = plt.get_cmap("Reds")
cmap_blues = plt.get_cmap("Blues_r")

# Positive genes sample "Reds" over [0.3, 1.0]; negative genes sample
# "Blues_r" over [0.1, 0.7]. Both ranges stay away from the near-white end.
# The step denominator is clamped to 1 so that a group with zero or one gene
# does not divide by zero (the single shade is then taken at the range start).
colors_positive = [cmap_reds(0.3 + 0.7 * (i / max(1, num_positive - 1))) for i in range(num_positive)]
colors_negative = [cmap_blues(0.1 + 0.6 * (i / max(1, num_negative - 1))) for i in range(num_negative)]

# Gene -> color lookup used by the normalized mean difference plot and the legends
gene_colors_scaled = {}

for i, gene in enumerate(positive_genes):
    gene_colors_scaled[gene] = colors_positive[i % len(colors_positive)]

for i, gene in enumerate(negative_genes):
    gene_colors_scaled[gene] = colors_negative[i % len(colors_negative)]




# **Separate Legends for Positive & Negative Genes**
legend_handles_pos = [mpatches.Patch(color=gene_colors_scaled[gene], label=gene) for gene in positive_genes]
legend_handles_neg = [mpatches.Patch(color=gene_colors_scaled[gene], label=gene) for gene in negative_genes]

# Each non-empty group gets its own legend-only figure, saved as a PNG
for handles, legend_title, file_name in (
    (legend_handles_pos, "Positive  Difference", "legend_positive_genes.png"),
    (legend_handles_neg, "Negative Difference", "legend_negative_genes.png"),
):
    if not handles:
        continue

    # Figure height grows with the number of entries (one unit per two genes)
    legend_fig, legend_ax = plt.subplots(figsize=(6, max(1, len(handles) // 2)))
    legend_ax.axis("off")
    legend_ax.legend(handles=handles, loc="center", title=legend_title,
                     title_fontsize=36, fontsize=36, frameon=True, ncol=1)
    plt.savefig(os.path.join(subtraj_dir, file_name), dpi=300, bbox_inches="tight")
    plt.close()

# Colors for the fold-change and mean-difference plots.
# This cell previously read `gene_colors` without defining it, so it only ran
# when another cell had left that name behind (with colors keyed by that
# cell's genes). The mapping is now built here from this cell's own
# highlighted genes, using the palette order Set1, Set3, tab20b, tab20c,
# Paired, then 15 evenly spaced hsv colors.
_fc_md_palette = [
    color
    for cmap_name in ("Set1", "Set3", "tab20b", "tab20c", "Paired")
    for color in plt.get_cmap(cmap_name).colors
] + list(plt.cm.hsv(np.linspace(0, 1, 15)))
gene_colors = {
    gene: _fc_md_palette[i % len(_fc_md_palette)]
    for i, gene in enumerate(sorted(set(highlighted_genes_fc) | set(highlighted_genes_md)))
}


def _plot_gene_lines(ax, frame, value_col, genes, highlighted, color_for):
    """Draw one time series per gene on `ax`; return the handles of the highlighted lines.

    Genes in `highlighted` are drawn solid and opaque in `color_for(gene)`;
    every other gene is drawn as a faint dashed gray line.
    """
    handles = []
    for gene in genes:
        rows = frame[frame["Gene"] == gene]
        is_highlighted = gene in highlighted
        if is_highlighted:
            color, alpha_value, linestyle = color_for(gene), 1.0, "-"
        else:
            color, alpha_value, linestyle = "gray", 0.2, "--"

        line, = ax.plot(rows["Time"], rows[value_col], marker="o", markersize=4,
                        linestyle=linestyle, color=color, alpha=alpha_value, label=gene)
        if is_highlighted:
            handles.append(line)
    return handles


# **Generate Main Plot (Without Legend)**
for (cluster1, cluster2), group_md in df_md_all.groupby(["Cluster 1", "Cluster 2"]):

    # **Normalized Mean Difference Plot** (legends are saved separately)
    fig, ax = plt.subplots(figsize=(10, 9))

    for gene in significant_genes:  # Only genes passing p-value threshold are plotted
        gene_data = group_md[group_md["Gene"] == gene]

        if gene in highlighted_genes_norm:
            color = gene_colors_scaled.get(gene, "gray")
            alpha_value, linestyle = 1.0, "-"
        else:  # Genes passing p-value but NOT max diff
            color = "gray"
            alpha_value, linestyle = 0.2, "--"

        ax.plot(
            gene_data["Time"], gene_data["Normalized Mean Difference"],
            marker="o", markersize=3, linestyle=linestyle, color=color, alpha=alpha_value
        )

    # Labels and Title
    ax.set_xlabel("Time", fontsize=30)
    ax.set_ylabel("Normalized Mean Difference", fontsize=30)
    ax.set_title(f"Trajectory {cluster1 + 1} vs {cluster2 + 1}", fontsize=30)
    ax.tick_params(axis="both", labelsize=30)

    # Save Plot
    plt.tight_layout()
    plt.savefig(os.path.join(subtraj_dir, f"normalized_mean_difference_cluster_{cluster1}_vs_{cluster2}.png"), dpi=300)
    plt.close()

    # **(1) Fold Change Plot**
    group_fc = df_fc_all[(df_fc_all["Cluster 1"] == cluster1) & (df_fc_all["Cluster 2"] == cluster2)]
    fig, ax = plt.subplots(figsize=(10, 6))
    legend_handles = _plot_gene_lines(ax, group_fc, "Log2 Fold Change",
                                      all_genes_fc, highlighted_genes_fc, gene_colors.get)

    ax.axhline(y=0, color="black", linestyle="--", alpha=0.7)
    ax.set_xlabel("Time", fontsize = 30)
    ax.set_ylabel("Log2 Fold Change", fontsize = 30)
    ax.set_title(f"Fold Change Over Time (Cluster {cluster1} vs {cluster2})", fontsize = 30)

    if legend_handles:
        ax.legend(handles=legend_handles, title="Significant Genes", bbox_to_anchor=(1.05, 1), loc="upper left")

    plt.tight_layout()
    plt.savefig(os.path.join(subtraj_dir, f"fold_change_comparison_cluster_{cluster1}_vs_{cluster2}.png"), dpi=300)
    plt.close()

    # **(2) Mean Difference Plot**
    fig, ax = plt.subplots(figsize=(10, 6))
    legend_handles = _plot_gene_lines(ax, group_md, "Mean Difference",
                                      all_genes_md, highlighted_genes_md, gene_colors.get)

    ax.axhline(y=0, color="black", linestyle="--", alpha=0.7)
    ax.set_xlabel("Time", fontsize = 30)
    ax.set_ylabel("Mean Difference", fontsize = 30)
    ax.set_title(f"Mean Difference Over Time (Cluster {cluster1} vs {cluster2})", fontsize = 30)

    if legend_handles:
        ax.legend(handles=legend_handles, title="Significant Genes", bbox_to_anchor=(1.05, 1), loc="upper left")

    plt.tight_layout()
    plt.savefig(os.path.join(subtraj_dir, f"mean_difference_comparison_cluster_{cluster1}_vs_{cluster2}.png"), dpi=300)
    plt.close()

print("✅ Main plots saved.")
print("✅ Separate legend for positive genes saved as 'legend_positive_genes.png'.")
print("✅ Separate legend for negative genes saved as 'legend_negative_genes.png'.")




## Plot single gene dynamics for every single cell

In [None]:
## This is for EMT data, Time [0, 2, 4]
## Plot gene dynamics for each trajectory

import seaborn as sns  # Required for violin plots
import numpy as np
import matplotlib.patches as mpatches



## Subtrajectories defined by source
def Average_gene_dynamics_whole_saveonly_single_trajectory_EMT(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                              intermediate_t=(1,),
                              d_red=2, random_state=42, exp_memo='2'):
    """
    Save a plot of one gene's expression along every transported cell trajectory (EMT layout).

    The cells of ``mats[0]`` are projected with the stored PCA, integrated with
    ``time_integration`` and mapped back to gene space with
    ``pca.inverse_transform``. The reconstructed expression of
    ``gene_of_interest`` is drawn as one thin line per cell, and violin plots of
    the PCA-reconstructed snapshots at ``source_t``, ``intermediate_t`` and
    ``target_t`` are overlaid. Two files are written into the module-level
    ``output_dir``: ``Individual_trajectories_violin_plot_<gene>.png`` and a
    separate ``..._legend.png``.

    Parameters:
        source_t, target_t: Index into ``mats`` of the first / last snapshot shown as violins.
        optimal_k: Not used by this function; kept so existing calls keep working.
        gene_of_interest (str): Column name looked up in ``df_reduced_emt``.
        index (int): Step between the trajectory snapshots that are used.
        max_i (int): Largest trajectory snapshot index that is used.
        intermediate_t: Snapshots between source and target shown as violins.
            Empty (or None) means every integer strictly between ``source_t``
            and ``target_t``.
        d_red (int): PCA dimension; selects the PCA pickle file.
        random_state: Not used by this function; kept for call compatibility.
        exp_memo (str): Experiment tag; selects the weight pickle and the
            ``<exp_memo>_X1_hat_clusters.csv`` label file under ``result_dir``.

    Raises:
        ValueError: ``dim_red_method`` is neither 'EMT_PCA' nor 'PCA'.
        FileNotFoundError: The cluster-label CSV does not exist.

    NOTE(review): the x axis (0..4) and the violin positions (0, 2, 4) are
    hard-coded; surplus violin data beyond three entries is silently ignored
    by the ``zip`` below. Confirm before reusing with other time points.
    """
    # Local import: this cell does not import matplotlib.lines itself, and the
    # legend figure at the end needs Line2D.
    import matplotlib.lines as mlines

    # NOTE(review): W and b are not passed to time_integration below;
    # presumably it reads module-level weights -- confirm they match exp_memo.
    W, b, p = load_W(result_dir + exp_memo + ".pickle")

    # Pick the PCA pickle; fail with a clear error instead of continuing with
    # an unbound file name.
    if dim_red_method == 'EMT_PCA':
        pca_filename = "emt_pca_%d.pkl" % d_red
    elif dim_red_method == 'PCA':
        pca_filename = "pca_%d.pkl" % d_red
    else:
        raise ValueError("PCA mapping for the reduction method and dimension is not available")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    # Integrate the particles over the full numerical horizon in 200 steps.
    final_time = p['numerical_ts'][-1]
    X1_trpts = time_integration(pca.transform(mats[0]), T=final_time, dt=final_time / 200)

    # Normalise the intermediate time points to an array; an empty selection
    # falls back to all integer days strictly between source and target.
    intermediate_t = np.array(() if intermediate_t is None else intermediate_t)
    if len(intermediate_t) == 0:
        intermediate_t = np.arange(source_t + 1, target_t)

    # Load previously saved cluster labels (one label per transported cell).
    cluster_save_path = f"{result_dir}{exp_memo}_X1_hat_clusters.csv"
    if not os.path.exists(cluster_save_path):
        raise FileNotFoundError(f"Cluster labels file not found: {cluster_save_path}")

    X1_hat_labels = pd.read_csv(cluster_save_path)["Cluster_Label"].values
    unique_labels = np.unique(X1_hat_labels)
    print(f"Number of unique labels in X1_hat_labels: {len(unique_labels)}")
    print(f"Unique labels: {unique_labels}")

    # Column position of the gene, shifted by one.
    # NOTE(review): the -1 presumably skips a leading non-gene column of
    # df_reduced_emt -- confirm against how `mats` was built.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    def reconstructed_expression(snapshot):
        # Project to PCA space and back, then keep only the gene of interest,
        # so the violins are comparable with the reconstructed trajectories.
        return pca.inverse_transform(pca.transform(snapshot))[:, gene_index]

    violin_data = [reconstructed_expression(mats[t]) for t in (source_t, *intermediate_t, target_t)]

    print(sorted([source_t] + intermediate_t.tolist() + [target_t]))

    # Colours: one colour per trajectory subgroup, one per violin.
    subtrajectory_colors = ['green']
    violin_colors = ["black", "gray", "black"]
    subgroup_color_map = {
        label: subtrajectory_colors[i % len(subtrajectory_colors)]
        for i, label in enumerate(unique_labels)
    }

    subgroup_output_file = f"{output_dir}/Individual_trajectories_violin_plot_{gene_of_interest}.png"

    fig, ax1 = plt.subplots(figsize=(12, 7))

    # Trajectory snapshots that are plotted, spread evenly over x in [0, 4].
    indices = range(0, len(X1_trpts), index)
    x_positions = np.linspace(0, 4, len(indices))

    # Reconstruct the gene's expression once per used snapshot; stop at the
    # max_i cut-off or at the first snapshot containing NaNs.
    n_cells = X1_trpts[0].shape[0]
    expression_steps = []
    for time_idx in indices:
        if time_idx > max_i:
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break
        expression_steps.append(pca.inverse_transform(X1_trpt)[:, gene_index])

    # Draw one line per cell, grouped by subgroup label. Lines are only drawn
    # when every x position has a value (i.e. the loop above was not cut short).
    if len(expression_steps) == len(x_positions):
        trajectories = np.asarray(expression_steps)  # rows: snapshots, columns: cells
        for label in unique_labels:
            for cell_idx in range(n_cells):
                if X1_hat_labels[cell_idx] == label:
                    ax1.plot(
                        x_positions, trajectories[:, cell_idx],
                        color=subgroup_color_map[label],
                        alpha=0.1, linewidth=0.8
                    )
    else:
        print(
            f"Warning: only {len(expression_steps)} of {len(x_positions)} trajectory steps are usable "
            "(max_i or NaN cut-off); individual trajectories are not drawn."
        )

    # Violin plots at x = 0, 2, 4, drawn in front of the trajectories.
    violin_x_positions = np.array([0, 2, 4])

    for i, (x_pos, data) in enumerate(zip(violin_x_positions, violin_data)):
        sns.violinplot(
            data=[data],
            ax=ax1,
            inner=None,
            linewidth=1.2,
            width=0.7,
            cut=0,
            scale="width",
            color=violin_colors[i],
            alpha=0.8,
            zorder=3
        )

        # seaborn draws every violin at x = 0; shift the one just added so
        # that it is centred on its time point.
        for violin in ax1.collections[-1:]:
            for path in violin.get_paths():
                path.vertices[:, 0] += x_pos - path.vertices[:, 0].mean()

    # Leave room so the first and last violins are not clipped.
    ax1.set_xlim(-0.5, 4.5)

    ax1.set_xticks([0, 2, 4])
    ax1.set_xticklabels([0, 2, 4], fontsize=32)
    ax1.tick_params(axis='y', labelsize=32)

    ax1.set_xlabel('Time', fontsize=32)
    ax1.set_ylabel('Gene Expression', fontsize=32)
    ax1.set_title(f'Single Cell {gene_of_interest} Expression Dynamics', fontsize=32)

    # Save the main figure without a legend.
    plt.savefig(subgroup_output_file, dpi=300, bbox_inches='tight')
    plt.close()

    # Separate, horizontally laid out legend figure.
    legend_patches = [
        mlines.Line2D([], [], color="green", linestyle="-", linewidth=3,
                      label="Gene dynamics of each single cell")
    ]
    violin_legend_patches = [
        mpatches.Patch(color="black", label="Input Data"),
        mpatches.Patch(color="gray", label="Test Data")
    ]
    combined_legend = legend_patches + violin_legend_patches

    fig_legend, ax_legend = plt.subplots(figsize=(10, 2))
    ax_legend.axis("off")

    ax_legend.legend(
        handles=combined_legend,
        loc="center", fontsize=24, title="",
        title_fontsize=24, ncol=len(combined_legend),
        frameon=True, handletextpad=2, columnspacing=2
    )

    legend_output_file = subgroup_output_file.replace(".png", "_legend.png")
    plt.savefig(legend_output_file, dpi=300, bbox_inches='tight')
    plt.close()
    
    


In [None]:
# Plot for EMT data: one trajectory/violin figure per gene.

genes_of_interest = gene_names # NANOG, SOX2
source_t, target_t = 0, 4
optimal_k = 2
index = 1
max_i = 200
intermediate_t = [2]
#intermediate_t = [4]

d_red= 8
random_state = 40
exp_memo = '72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64'

result_dir = '%s/assets/Transport_genes/' % main_dir

# Build the output path for this experiment first, then create it. Creating
# the directory before the assignment would use an undefined (or stale)
# output_dir.
output_dir = os.path.join(result_dir, 'output', exp_memo)
os.makedirs(output_dir, exist_ok=True)


# Iterate over each gene in the list
for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        # Call the function with the current gene
        Average_gene_dynamics_whole_saveonly_single_trajectory_EMT(
            source_t, target_t, optimal_k, gene, index, max_i,
            intermediate_t=intermediate_t, d_red=d_red,
            random_state=random_state, exp_memo=exp_memo
        )
    except Exception as e:
        # Best-effort batch run: report the failure and continue with the next gene.
        print(f"Error processing gene {gene}: {e}")


In [None]:
## save the gene expression dynamics png as pdf for EMT data


import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from math import ceil
import matplotlib.lines as mlines


def create_pdf_from_gene_images(output_dir, exp_memo, gene_list, pdf_path, images_per_page=25, grid_size=(5, 5)):
    """
    Collect the per-gene violin-plot PNGs into a multi-page PDF laid out on a grid.

    For every gene in ``gene_list`` the file
    ``<output_dir>/Individual_trajectories_violin_plot_<gene>.png`` is looked
    up; missing files are reported and skipped. Each image is shown exactly
    once, in list order.

    Parameters:
        output_dir (str): Directory containing the PNG files.
        exp_memo (str): Not used by this function; kept for call compatibility.
        gene_list (list): Genes whose PNG files should be included.
        pdf_path (str): Path of the PDF file to write.
        images_per_page (int): Maximum number of images per page (default: 25).
            Capped at the number of grid cells.
        grid_size (tuple): Grid size (rows, cols) for each page (default: 5x5).

    Raises:
        ValueError: If the settings leave no room for at least one image per page.
    """
    rows, cols = grid_size
    # A page can never hold more images than it has cells; without the cap,
    # images beyond the grid capacity would be dropped.
    per_page = min(images_per_page, rows * cols)
    if per_page < 1:
        raise ValueError("images_per_page and grid_size must allow at least one image per page.")

    # Split the expected PNG paths into existing and missing ones.
    png_files, missing_files = [], []
    for gene in gene_list:
        path = f"{output_dir}/Individual_trajectories_violin_plot_{gene}.png"
        (png_files if os.path.exists(path) else missing_files).append(path)

    if missing_files:
        print(f"Warning: The following files are missing and will be skipped:\n{missing_files}")

    if not png_files:
        # Nothing to show: do not leave an empty PDF behind.
        print("No PNG files found; no PDF was written.")
        return

    total_pages = ceil(len(png_files) / per_page)

    with PdfPages(pdf_path) as pdf:
        for page in range(total_pages):
            # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
            fig, axes = plt.subplots(rows, cols, figsize=(15, 15), squeeze=False)

            # Only this page's slice, so no image leaks onto two pages.
            page_files = png_files[page * per_page:(page + 1) * per_page]

            for slot, ax in enumerate(axes.flatten()):
                if slot < len(page_files):
                    # aspect='auto' stretches the image to fill its grid cell.
                    ax.imshow(plt.imread(page_files[slot]), aspect='auto')
                ax.axis('off')  # Hide axes for both filled and empty cells

            # Save the page to the PDF with high resolution
            pdf.savefig(fig, dpi=300, bbox_inches='tight')
            plt.close(fig)  # Close the figure to free memory

    print(f"✅ PDF saved to {pdf_path} with original image resolution.")

# Bundle the per-gene EMT figures found in `output_dir` into a single PDF,
# six images per page on a 3 x 2 grid.

exp_memo = "72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64"
gene_list = gene_names  # genes whose PNG files are collected
pdf_path = f"{output_dir}/Individual_trajectories_violin_plot.pdf"  # destination PDF

create_pdf_from_gene_images(
    output_dir,
    exp_memo,
    gene_list,
    pdf_path,
    images_per_page=6,
    grid_size=(3, 2),
)

In [None]:
## This is for stem cell data, time points [0, 1, 2, 3, 4]
## Plot gene dynamics for each trajectory

import seaborn as sns  # Required for violin plots
import numpy as np
import matplotlib.patches as mpatches



## Subtrajectories defined by source
def Average_gene_dynamics_whole_saveonly_single_trajectory_mESC(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                              intermediate_t=(1,),
                              d_red=2, random_state=42, exp_memo='2'):
    """
    Save a plot of one gene's expression along every transported cell trajectory (stem-cell layout).

    The cells of ``mats[0]`` are projected with the stored PCA, integrated with
    ``time_integration`` and mapped back to gene space with
    ``pca.inverse_transform``. The reconstructed expression of
    ``gene_of_interest`` is drawn as one thin line per cell, and violin plots of
    the PCA-reconstructed snapshots at ``source_t``, ``intermediate_t`` and
    ``target_t`` are overlaid. Two files are written into the module-level
    ``output_dir``: ``Individual_trajectories_violin_plot_<gene>.png`` and a
    separate ``..._legend.png``.

    Parameters:
        source_t, target_t: Index into ``mats`` of the first / last snapshot shown as violins.
        optimal_k: Not used any more (the KMeans fit whose labels were never
            read has been dropped); kept so existing calls keep working.
        gene_of_interest (str): Column name looked up in ``df_reduced_emt``.
        index (int): Step between the trajectory snapshots that are used.
        max_i (int): Largest trajectory snapshot index that is used.
        intermediate_t: Snapshots between source and target shown as violins.
            Empty (or None) means every integer strictly between ``source_t``
            and ``target_t``.
        d_red (int): PCA dimension; selects the PCA pickle file.
        random_state: Not used by this function; kept for call compatibility.
        exp_memo (str): Experiment tag; selects the weight pickle and the
            ``<exp_memo>_X1_hat_clusters.csv`` label file under ``result_dir``.

    Raises:
        ValueError: ``dim_red_method`` is neither 'EMT_PCA' nor 'PCA'.
        FileNotFoundError: The cluster-label CSV does not exist.

    NOTE(review): the x axis (0..4) and the violin positions (0, 1, 2, 3, 4)
    are hard-coded; surplus violin data beyond five entries is silently
    ignored by the ``zip`` below. Confirm before reusing with other time points.
    """
    # Local import: this cell does not import matplotlib.lines itself, and the
    # legend figure at the end needs Line2D.
    import matplotlib.lines as mlines

    # NOTE(review): W and b are not passed to time_integration below;
    # presumably it reads module-level weights -- confirm they match exp_memo.
    W, b, p = load_W(result_dir + exp_memo + ".pickle")

    # Pick the PCA pickle; fail with a clear error instead of continuing with
    # an unbound file name.
    if dim_red_method == 'EMT_PCA':
        pca_filename = "emt_pca_%d.pkl" % d_red
    elif dim_red_method == 'PCA':
        pca_filename = "pca_%d.pkl" % d_red
    else:
        raise ValueError("PCA mapping for the reduction method and dimension is not available")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    # Integrate the particles over the full numerical horizon in 200 steps.
    final_time = p['numerical_ts'][-1]
    X1_trpts = time_integration(pca.transform(mats[0]), T=final_time, dt=final_time / 200)

    # Normalise the intermediate time points to an array; an empty selection
    # falls back to all integer days strictly between source and target.
    intermediate_t = np.array(() if intermediate_t is None else intermediate_t)
    if len(intermediate_t) == 0:
        intermediate_t = np.arange(source_t + 1, target_t)

    # Load previously saved cluster labels (one label per transported cell).
    cluster_save_path = f"{result_dir}{exp_memo}_X1_hat_clusters.csv"
    if not os.path.exists(cluster_save_path):
        raise FileNotFoundError(f"Cluster labels file not found: {cluster_save_path}")

    X1_hat_labels = pd.read_csv(cluster_save_path)["Cluster_Label"].values
    unique_labels = np.unique(X1_hat_labels)
    print(f"Number of unique labels in X1_hat_labels: {len(unique_labels)}")
    print(f"Unique labels: {unique_labels}")

    # Column position of the gene, shifted by one.
    # NOTE(review): the -1 presumably skips a leading non-gene column, and the
    # lookup uses df_reduced_emt even for the stem-cell run -- confirm both.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    def reconstructed_expression(snapshot):
        # Project to PCA space and back, then keep only the gene of interest,
        # so the violins are comparable with the reconstructed trajectories.
        return pca.inverse_transform(pca.transform(snapshot))[:, gene_index]

    violin_data = [reconstructed_expression(mats[t]) for t in (source_t, *intermediate_t, target_t)]

    print(sorted([source_t] + intermediate_t.tolist() + [target_t]))

    # Colours: one colour per trajectory subgroup, one per violin.
    subtrajectory_colors = ['violet']
    violin_colors = ["black", "gray", "black", "gray", "black"]
    subgroup_color_map = {
        label: subtrajectory_colors[i % len(subtrajectory_colors)]
        for i, label in enumerate(unique_labels)
    }

    subgroup_output_file = f"{output_dir}/Individual_trajectories_violin_plot_{gene_of_interest}.png"

    fig, ax1 = plt.subplots(figsize=(12, 7))

    # Trajectory snapshots that are plotted, spread evenly over x in [0, 4].
    indices = range(0, len(X1_trpts), index)
    x_positions = np.linspace(0, 4, len(indices))

    # Reconstruct the gene's expression once per used snapshot; stop at the
    # max_i cut-off or at the first snapshot containing NaNs.
    n_cells = X1_trpts[0].shape[0]
    expression_steps = []
    for time_idx in indices:
        if time_idx > max_i:
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break
        expression_steps.append(pca.inverse_transform(X1_trpt)[:, gene_index])

    # Draw one line per cell, grouped by subgroup label. Lines are only drawn
    # when every x position has a value (i.e. the loop above was not cut short).
    if len(expression_steps) == len(x_positions):
        trajectories = np.asarray(expression_steps)  # rows: snapshots, columns: cells
        for label in unique_labels:
            for cell_idx in range(n_cells):
                if X1_hat_labels[cell_idx] == label:
                    ax1.plot(
                        x_positions, trajectories[:, cell_idx],
                        color=subgroup_color_map[label],
                        alpha=0.7, linewidth=1.0
                    )
    else:
        print(
            f"Warning: only {len(expression_steps)} of {len(x_positions)} trajectory steps are usable "
            "(max_i or NaN cut-off); individual trajectories are not drawn."
        )

    # Violin plots at x = 0..4, drawn in front of the trajectories.
    violin_x_positions = np.array([0, 1, 2, 3, 4])

    for i, (x_pos, data) in enumerate(zip(violin_x_positions, violin_data)):
        sns.violinplot(
            data=[data],
            ax=ax1,
            inner=None,
            linewidth=1.2,
            width=0.7,
            cut=0,
            scale="width",
            color=violin_colors[i],
            alpha=0.8,
            zorder=3
        )

        # seaborn draws every violin at x = 0; shift the one just added so
        # that it is centred on its time point.
        for violin in ax1.collections[-1:]:
            for path in violin.get_paths():
                path.vertices[:, 0] += x_pos - path.vertices[:, 0].mean()

    # Leave room so the first and last violins are not clipped.
    ax1.set_xlim(-0.5, 4.5)

    ax1.set_xticks([0, 1, 2, 3, 4])
    ax1.set_xticklabels([0, 1, 2, 3, 4], fontsize=35)
    ax1.tick_params(axis='y', labelsize=35)

    ax1.set_xlabel('Time', fontsize=35)
    ax1.set_ylabel('Gene Expression', fontsize=35)
    ax1.set_title(f'Single Cell {gene_of_interest} Expression Dynamics', fontsize=35)

    # Save the main figure without a legend.
    plt.savefig(subgroup_output_file, dpi=300, bbox_inches='tight')
    plt.close()

    # Separate, horizontally laid out legend figure.
    legend_patches = [
        mlines.Line2D([], [], color="violet", linestyle="-", linewidth=3,
                      label="Gene dynamics of each single cell")
    ]
    violin_legend_patches = [
        mpatches.Patch(color="black", label="Input Data"),
        mpatches.Patch(color="gray", label="Test Data")
    ]
    combined_legend = legend_patches + violin_legend_patches

    fig_legend, ax_legend = plt.subplots(figsize=(10, 2))
    ax_legend.axis("off")

    ax_legend.legend(
        handles=combined_legend,
        loc="center", fontsize=24, title="",
        title_fontsize=24, ncol=len(combined_legend),
        frameon=True, handletextpad=2, columnspacing=2
    )

    legend_output_file = subgroup_output_file.replace(".png", "_legend.png")
    plt.savefig(legend_output_file, dpi=300, bbox_inches='tight')
    plt.close()
    



    
    

In [None]:
# Plot for Stem cell data: one trajectory/violin figure per gene.
# NOTE(review): exp_memo below starts with 'EMT_' although this cell is
# labelled as the stem-cell run -- confirm the intended experiment tag.

genes_of_interest = gene_names # NANOG, SOX2
source_t, target_t = 0, 4
optimal_k = 2
index = 1
max_i = 200
intermediate_t = [1,2,3]
#intermediate_t = [4]

d_red= 2
random_state = 40
exp_memo = 'EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64'

result_dir = '%s/assets/Transport_genes/' % main_dir

# Build the output path for this experiment first, then create it. Creating
# the directory before the assignment would use an undefined (or stale)
# output_dir.
output_dir = os.path.join(result_dir, 'output', exp_memo)
os.makedirs(output_dir, exist_ok=True)


# Iterate over each gene in the list
for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        # Call the function with the current gene
        Average_gene_dynamics_whole_saveonly_single_trajectory_mESC(
            source_t, target_t, optimal_k, gene, index, max_i,
            intermediate_t=intermediate_t, d_red=d_red,
            random_state=random_state, exp_memo=exp_memo
        )
    except Exception as e:
        # Best-effort batch run: report the failure and continue with the next gene.
        print(f"Error processing gene {gene}: {e}")


In [None]:
## save the gene expression dynamics png as pdf for stem cell data


import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from math import ceil
import matplotlib.lines as mlines


def create_pdf_from_gene_images(output_dir, exp_memo, gene_list, pdf_path, images_per_page=25, grid_size=(5, 5),
                                filename_prefix="Individual_trajectories_violin_plot"):
    """
    Collect per-gene PNG plots into a multi-page PDF, arranged on a grid.

    For every gene in `gene_list` the file
    `{output_dir}/{filename_prefix}_{gene}.png` is looked up. Files that do not
    exist are reported on stdout and skipped; the remaining images are placed
    on the pages in `gene_list` order.

    Parameters:
        output_dir (str): Directory containing the PNG files.
        exp_memo (str): Not used by this function; kept so existing calls keep working.
        gene_list (list): Genes whose PNG files should be included.
        pdf_path (str): Path to save the output PDF file.
        images_per_page (int): Maximum number of images per page (default: 25).
            Capped at rows * cols so that no image is dropped or repeated when
            it does not match the grid capacity.
        grid_size (tuple): Grid size (rows, cols) for each page (default: 5x5).
        filename_prefix (str): Leading part of each PNG file name, without the
            trailing underscore and gene name.

    Raises:
        ValueError: If `images_per_page` or either grid dimension is below 1.
    """
    rows, cols = grid_size
    if rows < 1 or cols < 1 or images_per_page < 1:
        raise ValueError(
            f"images_per_page and both grid dimensions must be >= 1, "
            f"got images_per_page={images_per_page}, grid_size={grid_size}"
        )

    # Each page advances by exactly the number of images it can actually show.
    per_page = min(images_per_page, rows * cols)

    # Split the expected files into present and missing in a single pass.
    png_files = []
    missing_files = []
    for gene in gene_list:
        candidate = f"{output_dir}/{filename_prefix}_{gene}.png"
        if os.path.exists(candidate):
            png_files.append(candidate)
        else:
            missing_files.append(candidate)

    if missing_files:
        print(f"Warning: The following files are missing and will be skipped:\n{missing_files}")

    total_pages = ceil(len(png_files) / per_page)

    with PdfPages(pdf_path) as pdf:
        for page in range(total_pages):
            # squeeze=False always yields a 2-D array of axes, so a 1x1 grid works too.
            fig, axes = plt.subplots(rows, cols, figsize=(15, 15), squeeze=False)
            page_files = png_files[page * per_page:(page + 1) * per_page]

            for slot, ax in enumerate(axes.flat):
                if slot < len(page_files):
                    # aspect='auto' stretches the image to fill its grid cell.
                    ax.imshow(plt.imread(page_files[slot]), aspect='auto')
                ax.axis('off')  # No axes on image cells or on empty cells

            pdf.savefig(fig, dpi=300, bbox_inches='tight')
            plt.close(fig)  # Close the figure to free memory

    print(f"✅ PDF saved to {pdf_path} with original image resolution.")

# Bundle the stem cell per-gene plots into one PDF, six images per page on a 3x2 grid.

exp_memo = "EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64"
gene_list = gene_names
pdf_path = f"{output_dir}/Individual_trajectories_violin_plot.pdf"

create_pdf_from_gene_images(
    output_dir=output_dir,
    exp_memo=exp_memo,
    gene_list=gene_list,
    pdf_path=pdf_path,
    images_per_page=6,
    grid_size=(3, 2),
)

In [None]:
## Breast cancer cell line data, time range [0, 4]
# Plot the gene dynamics along each single-cell trajectory

import seaborn as sns  # Required for violin plots
import numpy as np
import matplotlib.patches as mpatches




## Subtrajectories defined by source
def Average_gene_dynamics_whole_saveonly_single_trajectory_NDPR(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                              intermediate_t = (1,),
                              d_red=2, random_state=42, exp_memo = '2'):
    """
    Save a plot of one gene's expression along the transported single-cell trajectories.

    The cells in `mats[0]` are projected with the stored PCA, integrated with
    `time_integration`, and mapped back to gene space. The expression of
    `gene_of_interest` is drawn as one thin line per cell on an x-axis running
    from 0 ("Pre-treatment") to 4 ("Post-treatment"). Violin plots of the
    PCA-reconstructed data of `mats[source_t]` and `mats[target_t]` are drawn
    at x=0 and x=4.

    Two files are written into the module-level `output_dir`:
        Celltypes_deviated_trajectories_violin_plot_{gene_of_interest}.png
        Celltypes_deviated_trajectories_violin_plot_{gene_of_interest}_legend.png

    Reads module-level state: result_dir, data_dir, output_dir, dim_red_method,
    mats, df_reduced_emt, plus the helpers load_W and time_integration.

    Parameters:
        source_t: Index into `mats` of the data shown in the left violin.
        target_t: Index into `mats` of the data shown in the right violin.
        optimal_k: Not used; kept for call compatibility.
        gene_of_interest (str): Column name of the gene in `df_reduced_emt`.
        index (int): Stride over the integrated steps (every `index`-th step is used).
        max_i (int): Largest step index that is used.
        intermediate_t: Intermediate time points. They are only echoed in the
            printed list of time points; no violin is drawn for them.
        d_red (int): PCA dimension; selects which stored PCA file is loaded.
        random_state: Not used; kept for call compatibility.
        exp_memo (str): Base name of the weights pickle and of the cluster-label
            CSV under `result_dir`.

    Raises:
        ValueError: If `dim_red_method` has no stored PCA mapping, or if the
            cluster-label file has fewer rows than there are trajectories.
        FileNotFoundError: If the cluster-label CSV does not exist.
    """
    # The x-axis is fixed to [0, 4]; trajectories and violins share this range.
    x_start, x_end = 0, 4
    trajectory_color = "violet"
    violin_color = "black"

    # ---- Load network parameters and the PCA mapping ----
    filename = result_dir + exp_memo + ".pickle"
    W, b, p = load_W(filename)
    # NOTE(review): W and b are not passed to time_integration below; it
    # presumably reads module-level weights -- confirm they belong to exp_memo.

    if dim_red_method == 'EMT_PCA':
        pca_filename = "emt_pca_%d.pkl" % d_red
    elif dim_red_method == 'PCA':
        pca_filename = "pca_%d.pkl" % d_red
    else:
        # Fail here with a clear message instead of on an unbound pca_filename.
        raise ValueError("PCA mapping for the reduction method and dimension is not available")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    # ---- Integrate the learned dynamics in PCA space ----
    # NOTE(review): integration always starts from mats[0], not mats[source_t] -- confirm.
    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    intermediate_t = np.array(intermediate_t)
    if len(intermediate_t) == 0:
        # np.arange (not range) so the array operations below keep working.
        intermediate_t = np.arange(source_t + 1, target_t)

    # ---- Load the previously saved cluster label of every trajectory ----
    cluster_save_path = f"{result_dir}{exp_memo}_X1_hat_deviation.csv"
    if not os.path.exists(cluster_save_path):
        raise FileNotFoundError(f"Cluster labels file not found: {cluster_save_path}")

    X1_hat_labels = pd.read_csv(cluster_save_path)["Cluster_Label"].values
    unique_labels = np.unique(X1_hat_labels)
    print(f"Number of unique labels in X1_hat_labels: {len(unique_labels)}")
    print(f"Unique labels: {unique_labels}")

    print(sorted([source_t] + intermediate_t.tolist() + [target_t]))

    # Column of the gene in the reconstructed expression matrix.
    # The -1 presumably skips a leading non-gene column of df_reduced_emt -- TODO confirm.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    def reconstructed_expression(cells):
        # Project to PCA space and back so the data matches the trajectories.
        return pca.inverse_transform(pca.transform(cells))[:, gene_index]

    gene_expression_X1 = reconstructed_expression(mats[source_t])
    gene_expression_X2 = reconstructed_expression(mats[target_t])

    # ---- Gene expression of every cell at every used integration step ----
    indices = range(0, len(X1_trpts), index)
    x_positions = np.linspace(x_start, x_end, len(indices))

    per_step_expression = []
    for time_idx in indices:
        if time_idx > max_i:
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break
        per_step_expression.append(pca.inverse_transform(X1_trpt)[:, gene_index])

    trajectories = None
    if per_step_expression and len(per_step_expression) == len(x_positions):
        trajectories = np.array(per_step_expression)  # rows: steps, columns: cells
        if len(X1_hat_labels) < trajectories.shape[1]:
            raise ValueError(
                f"Cluster labels file has {len(X1_hat_labels)} rows but there are "
                f"{trajectories.shape[1]} trajectories: {cluster_save_path}"
            )
    else:
        # Trajectories cut short by max_i or NaN no longer match x_positions;
        # they are left out of the plot, and that is now reported.
        print(f"Warning: trajectories for {gene_of_interest} were truncated and are not drawn")

    plot_stem = f"{output_dir}/Celltypes_deviated_trajectories_violin_plot_{gene_of_interest}"
    subgroup_output_file = f"{plot_stem}.png"
    legend_output_file = f"{plot_stem}_legend.png"

    # ---- Main figure ----
    fig, ax1 = plt.subplots(figsize=(12, 7))
    try:
        if trajectories is not None:
            cell_labels = X1_hat_labels[:trajectories.shape[1]]
            for label in unique_labels:
                members = cell_labels == label
                if members.any():
                    # A 2-D y array draws one line per column, i.e. per cell.
                    ax1.plot(
                        x_positions, trajectories[:, members],
                        color=trajectory_color,
                        alpha=0.1, linewidth=0.8
                    )

        # Violins of the source data at x_start and the target data at x_end.
        for x_pos, data in ((x_start, gene_expression_X1), (x_end, gene_expression_X2)):
            sns.violinplot(
                data=[data],
                ax=ax1,
                inner=None,
                linewidth=1.2,
                width=0.7,
                cut=0,
                scale="width",
                color=violin_color,
                alpha=0.8,
                zorder=3
            )

            # seaborn draws the violin at x=0; shift the newest one to x_pos.
            for violin in ax1.collections[-1:]:
                for path in violin.get_paths():
                    path.vertices[:, 0] += x_pos - path.vertices[:, 0].mean()

        # Leave room so the violin at x_end is not cut off.
        ax1.set_xlim(-0.5, 4.5)

        ax1.set_xticks([x_start, x_end])
        ax1.set_xticklabels(["Pre-treatment", "Post-treatment"], fontsize=46)
        ax1.tick_params(axis='y', labelsize=46)

        ax1.set_xlabel('Time', fontsize=46)
        ax1.set_ylabel('Gene Expression', fontsize=46)
        ax1.set_title(f'{gene_of_interest}', fontsize=46)

        # The main figure is saved without a legend.
        fig.savefig(subgroup_output_file, dpi=300, bbox_inches='tight')
    finally:
        # Always close, so a failure for one gene does not leak a figure.
        plt.close(fig)

    # ---- Separate, horizontally laid out legend figure ----
    combined_legend = [
        mlines.Line2D([], [], color="violet", linestyle="-", linewidth=3,
                      label="Hallmark dynamics of each single cell"),
        mpatches.Patch(color="black", label="Input Data"),
    ]

    fig_legend, ax_legend = plt.subplots(figsize=(10, 2))
    try:
        ax_legend.axis("off")
        ax_legend.legend(
            handles=combined_legend,
            loc="center", fontsize=24, title="",
            title_fontsize=24, ncol=len(combined_legend),
            frameon=True, handletextpad=2, columnspacing=2
        )
        fig_legend.savefig(legend_output_file, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig_legend)


In [None]:
## Plot the results for breast cancer cell line data: one figure per gene.

genes_of_interest = gene_names  # NANOG, SOX2
source_t, target_t = 0, 4
optimal_k = 2
index = 1
max_i = 200
intermediate_t = [4]

d_red = 2
random_state = 40
exp_memo = 'Palbo_NDPR_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64'

result_dir = '%s/assets/Transport_genes/' % main_dir

# The output directory must be derived from the current exp_memo *before* it
# is created; otherwise the directory of a previously executed experiment is
# (re)created and the figures for this one have nowhere to be saved.
output_dir = os.path.join(result_dir, 'output', exp_memo)
os.makedirs(output_dir, exist_ok=True)

# Iterate over each gene in the list
for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        Average_gene_dynamics_whole_saveonly_single_trajectory_NDPR(
            source_t, target_t, optimal_k, gene, index, max_i,
            intermediate_t=intermediate_t, d_red=d_red,
            random_state=random_state, exp_memo=exp_memo
        )
    except Exception as e:
        # Best effort: a failure for one gene must not stop the remaining genes.
        print(f"Error processing gene {gene}: {e}")

In [None]:
## save the gene expression dynamics png as pdf - Breast cancer cell line data


import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from math import ceil
import matplotlib.lines as mlines


def create_pdf_from_gene_images(output_dir, exp_memo, gene_list, pdf_path, images_per_page=25, grid_size=(5, 5),
                                filename_prefix="Celltypes_deviated_trajectories_violin_plot"):
    """
    Collect per-gene PNG plots into a multi-page PDF, arranged on a grid.

    For every gene in `gene_list` the file
    `{output_dir}/{filename_prefix}_{gene}.png` is looked up. Files that do not
    exist are reported on stdout and skipped; the remaining images are placed
    on the pages in `gene_list` order.

    Parameters:
        output_dir (str): Directory containing the PNG files.
        exp_memo (str): Not used by this function; kept so existing calls keep working.
        gene_list (list): Genes whose PNG files should be included.
        pdf_path (str): Path to save the output PDF file.
        images_per_page (int): Maximum number of images per page (default: 25).
            Capped at rows * cols so that no image is dropped or repeated when
            it does not match the grid capacity.
        grid_size (tuple): Grid size (rows, cols) for each page (default: 5x5).
        filename_prefix (str): Leading part of each PNG file name, without the
            trailing underscore and gene name.

    Raises:
        ValueError: If `images_per_page` or either grid dimension is below 1.
    """
    rows, cols = grid_size
    if rows < 1 or cols < 1 or images_per_page < 1:
        raise ValueError(
            f"images_per_page and both grid dimensions must be >= 1, "
            f"got images_per_page={images_per_page}, grid_size={grid_size}"
        )

    # Each page advances by exactly the number of images it can actually show.
    per_page = min(images_per_page, rows * cols)

    # Split the expected files into present and missing in a single pass.
    png_files = []
    missing_files = []
    for gene in gene_list:
        candidate = f"{output_dir}/{filename_prefix}_{gene}.png"
        if os.path.exists(candidate):
            png_files.append(candidate)
        else:
            missing_files.append(candidate)

    if missing_files:
        print(f"Warning: The following files are missing and will be skipped:\n{missing_files}")

    total_pages = ceil(len(png_files) / per_page)

    with PdfPages(pdf_path) as pdf:
        for page in range(total_pages):
            # squeeze=False always yields a 2-D array of axes, so a 1x1 grid works too.
            fig, axes = plt.subplots(rows, cols, figsize=(15, 15), squeeze=False)
            page_files = png_files[page * per_page:(page + 1) * per_page]

            for slot, ax in enumerate(axes.flat):
                if slot < len(page_files):
                    # aspect='auto' stretches the image to fill its grid cell.
                    ax.imshow(plt.imread(page_files[slot]), aspect='auto')
                ax.axis('off')  # No axes on image cells or on empty cells

            pdf.savefig(fig, dpi=300, bbox_inches='tight')
            plt.close(fig)  # Close the figure to free memory

    print(f"✅ PDF saved to {pdf_path} with original image resolution.")

# Bundle the breast cancer cell line per-gene plots into one PDF, six images per page on a 3x2 grid.

exp_memo = "Palbo_NDPR_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64"
gene_list = gene_names
pdf_path = f"{output_dir}/Celltypes_deviated_trajectories_violin_plot.pdf"

create_pdf_from_gene_images(
    output_dir=output_dir,
    exp_memo=exp_memo,
    gene_list=gene_list,
    pdf_path=pdf_path,
    images_per_page=6,
    grid_size=(3, 2),
)

In [None]:
## Clinical data, time range [0, 4]
## Plot the gene dynamics along each trajectory (deviation cell types)

import seaborn as sns  # Required for violin plots
import numpy as np
import matplotlib.patches as mpatches



## Subtrajectories defined by source
def Average_gene_dynamics_whole_saveonly_single_trajectory_clinical(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                              intermediate_t = [1], 
                              d_red=2, random_state=42, exp_memo = '2'):

    filename = result_dir + exp_memo + ".pickle"
    W, b, p = load_W(filename)
    
    # load PCA
    if dim_red_method == 'EMT_PCA':
        pca_filename = "emt_pca_%d.pkl" % d_red
    elif dim_red_method == 'PCA':
        pca_filename = "pca_%d.pkl" % d_red
    else:
        print("PCA mapping for the reduction method and dimension is not available")
    
    with open(data_dir + pca_filename,"rb") as fr:
        [pca] = pk.load(fr)
    
    dt = p['numerical_ts'][-1]/200
    X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)
    
    physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]
    
    intermediate_t = np.array(intermediate_t)
    
    if len(intermediate_t) == 0:
        intermediate_t = range(source_t+1, target_t)
        
    # data parameters
    day1, day2 = source_t, target_t


    # --------
    N_source = N_samples_cls[day1]
    N_target = N_samples_cls[day2]
        

    X1_trpt = X1_trpts[-1]
    
    
    contrast_colors = [
    '#1f77b4',  # blue
    '#2ca02c',  # green
    '#ff7f0e',  # orange
    '#8c564b',  # brown
    '#d62728',  # red 
    '#9467bd'  # purple (to be used for index 8)
    ]

    # Create a color mapping for the specific indices
    colors = {0: contrast_colors[0], 1: contrast_colors[1], 2: contrast_colors[2], 3: contrast_colors[3], 4: contrast_colors[4], 8: contrast_colors[5]}

    
    # Step 1: Perform clustering analysis on the last day's cell states from mats
    
    # Load previously saved cluster labels
    cluster_save_path = f"{result_dir}{exp_memo}_X1_hat_deviation.csv"
    if not os.path.exists(cluster_save_path):
        raise FileNotFoundError(f"Cluster labels file not found: {cluster_save_path}")
    
    df_clusters = pd.read_csv(cluster_save_path)
    X1_hat_labels = df_clusters["Cluster_Label"].values  # Load saved labels

    # Print the number of unique labels in last_day_labels
    unique_labels = np.unique(X1_hat_labels)
    print(f"Number of unique labels in X1_hat_labels: {len(unique_labels)}")
    print(f"Unique labels: {unique_labels}")
    
    # Define a function to create colors for the subgroups using a predefined set of colors
    def get_subgroup_colors(labels, colors):
        unique_labels = np.unique(labels)
        if len(colors) < len(unique_labels):
            raise ValueError("Not enough colors for the number of unique labels.")
        subgroup_colors = {label: colors[i] for i, label in enumerate(unique_labels)}
        return subgroup_colors

    # Define specific sets of colors for the blue and red subgroups
    blue_colors = ['#1f77b4', '#878ceb', '#104E8B', '#87CEEB', '#4682B4', '#6495ED', '#5F9EA0']  # Add more shades of blue as needed
    red_colors = ['#d62728',  '#eb8787', '#FF4500', '#DC143C', '#FF6347', '#B22222', '#8B0000']  # Add more shades of red as needed
    light_red_colors = ['#f99fa1', '#ffb1b1', '#ffaf86', '#f48585', '#ffb5a5', '#ff9c9c', '#ff5f5f']
    
    # Get the subgroup colors based on the labels
    subgroup_colors_blue = get_subgroup_colors(X1_hat_labels, blue_colors)
    subgroup_colors_red = get_subgroup_colors(X1_hat_labels, red_colors)
    
    
    # Extract the gene index for the gene of interest
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1
    
    # Extract gene expression values from mats[day1], intermediate time points, and mats[day2]
    X1_vis_pca = pca.transform(mats[source_t])
    X1_vis_i_pca = pca.inverse_transform(X1_vis_pca)
    X2_vis_pca = pca.transform(mats[target_t])
    X2_vis_i_pca = pca.inverse_transform(X2_vis_pca)

    gene_expression_X1 = X1_vis_i_pca[:, gene_index]
    gene_expression_X2 = X2_vis_i_pca[:, gene_index]

    gene_expression_intermediates = []
    for t in intermediate_t:
        X1_intermediate_vis_pca = pca.transform(mats[t])
        X1_intermediate_vis_i_pca = pca.inverse_transform(X1_intermediate_vis_pca)
        gene_expression_intermediates.append(X1_intermediate_vis_i_pca[:, gene_index])

    # Extract gene expression values from X1_trpts based on the given condition
    
    gene_expression_X1_trpts = np.concatenate([pca.inverse_transform(X1_trpt)[:, gene_index] for i, X1_trpt in enumerate(X1_trpts) if i % index == 0 and i <= max_i])
    
    # Combine all gene expression values
    all_gene_expression_values = np.concatenate([gene_expression_X1, *gene_expression_intermediates, gene_expression_X2, gene_expression_X1_trpts])

    gene_expression_X1_normalized = gene_expression_X1
    gene_expression_intermediates_normalized = gene_expression_intermediates
    gene_expression_X2_normalized = gene_expression_X2
    gene_expression_X1_trpts_normalized = gene_expression_X1_trpts
    
    vmin = all_gene_expression_values.min()
    vmax = all_gene_expression_values.max()
    
    # Plot dynamics for X1_trpts with subgroup colors
    indices = range(len(X1_trpts))

    all_gene_expression_values_normalized_X1 = gene_expression_X1_trpts_normalized
    

    
    # (1) Plot the averaged gene expressions across X1_trpt at each time point with confidence intervals
    
    # Compute the average gene expression and confidence intervals
    avg_gene_expressions = []
    ci_gene_expressions = []
    
    # Reset normalized gene expression values for X1_trpts
    all_gene_expression_values_normalized_X1 = gene_expression_X1_trpts_normalized
    
    # Use indices with the specified step size defined by `index`
    indices = range(0, len(X1_trpts), index)

    
    # Iterate through indices to compute averages and confidence intervals
    for i in indices:
        if i > max_i:  # Apply truncation based on max_i
            break
        X1_trpt = X1_trpts[i]
        if np.isnan(X1_trpt).any():
            break
    
        # Inverse transform the current trajectory
        X1_hat = pca.inverse_transform(X1_trpt)
    
        # Extract gene expression values for the current step
        gene_expression_values = all_gene_expression_values_normalized_X1[:len(X1_hat)]
        all_gene_expression_values_normalized_X1 = all_gene_expression_values_normalized_X1[len(X1_hat):]  # Update to exclude used values
    
        # Compute average and confidence interval
        avg_gene_expressions.append(np.mean(gene_expression_values))
        ci = stats.sem(gene_expression_values) * stats.t.ppf((1 + 0.95) / 2., len(gene_expression_values) - 1)
        ci_gene_expressions.append(ci)
    
    # Process intermediate time points
    intermediate_avg_expressions = []
    intermediate_ci_expressions = []
    intermediate_indices = []


    for idx, t in enumerate(intermediate_t):
        gene_expression_intermediate = gene_expression_intermediates_normalized[idx]
        intermediate_avg_expressions.append(np.mean(gene_expression_intermediate))
        ci = stats.sem(gene_expression_intermediate) * stats.t.ppf((1 + 0.95) / 2., len(gene_expression_intermediate) - 1)
        intermediate_ci_expressions.append(ci)
    
        # Rescale the intermediate time points to align with `index`
        shifted_value_1 = intermediate_t - 1
        shifted_value_2 = intermediate_t[0] - 1
        shifted_t_1 = t - shifted_value_1
        shifted_t_2 = t - shifted_value_2
        time_index = int((float(shifted_t_2) / (float(max(shifted_t_1)) + 1)) * len(indices))
        intermediate_indices.append(time_index)

    
    # Include first and last time points
    all_avg_expressions = [np.mean(gene_expression_X1_normalized)] + intermediate_avg_expressions + [np.mean(gene_expression_X2_normalized)]
    all_ci_expressions = [
        stats.sem(gene_expression_X1_normalized) * stats.t.ppf((1 + 0.95) / 2., len(gene_expression_X1_normalized) - 1)
    ] + intermediate_ci_expressions + [
        stats.sem(gene_expression_X2_normalized) * stats.t.ppf((1 + 0.95) / 2., len(gene_expression_X2_normalized) - 1)
    ]

        
    all_indices = [0] + intermediate_indices + [len(indices)]
    combined_indices = sorted([day1] + intermediate_t.tolist() + [day2])

    print(combined_indices)

    
    # Ensure extended_indices align with avg_gene_expressions
    extended_indices = np.array([x * index for x in range(len(avg_gene_expressions))])
    
    # Ensure all_indices and extended_indices are NumPy arrays
    combined_indices = np.array(combined_indices)
    extended_indices = np.array(extended_indices)
    
    # Linearly rescale all_indices to be equally distributed in extended_indices
    rescaled_indices = np.interp(
        combined_indices,  # Original indices
        [combined_indices[0], combined_indices[-1]],  # Range of all_indices
        [extended_indices[0], extended_indices[-1]]  # Range of extended_indices
    )

    # Define the filename for saving the plot




 
    # (1) **Assign Labels for Subgroups Based on Step 1**

    
    # Define **subtrajectory colors** (for cell trajectories)
    real_cell_types = np.array(cell_ids_by_day[day2])
    #unique_cell_types = np.unique(real_cell_types)
    unique_cell_types = unique_labels
    #subtrajectory_colors = list(sns.color_palette("tab20", len(unique_cell_types)))
    subtrajectory_colors = ['green', 'orange', 'purple', 'blue', 'red', 'brown']
    #subtrajectory_colors = ['violet']
    
    # Define **violin plot colors** for the three time points
    violin_colors = ["black", "black"]  # Green, Orange, Purple
    
    # Map each subgroup label to a **trajectory color** and shift labels from 0,1 → 1,2
    unique_labels = np.unique(X1_hat_labels)
    subgroup_color_map = {label: subtrajectory_colors[i % len(subtrajectory_colors)] for i, label in enumerate(unique_labels)}
    label_mapping = {old_label: new_label + 1 for new_label, old_label in enumerate(unique_labels)}
    
    # Define filename for saving
    subgroup_output_file = f"{output_dir}/Celltypes_deviated_trajectories_violin_plot_{gene_of_interest}.png"
    
    # (2) **Create Figure**
    fig, ax1 = plt.subplots(figsize=(12, 7))
    
    # (3) **Ensure Proper x-axis Scaling**
    num_points = len(indices)
    x_positions = np.linspace(0, 4, num_points)  # Scale to match `[0, 2, 4]`
    
    # (4) **Extract Cell Trajectories for Each Gene**
    cell_trajectories = {cell_idx: [] for cell_idx in range(X1_trpts[0].shape[0])}
    
    for i, time_idx in enumerate(indices):
        if time_idx > max_i:
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break
    
        # Extract **expression values of the gene of interest** from each cell at this time point
        gene_expression_values = pca.inverse_transform(X1_trpt)[:, gene_index]
    
        # Append the expression value at this time to each cell’s trajectory
        for cell_idx, expr_value in enumerate(gene_expression_values):
            cell_trajectories[cell_idx].append(expr_value)
    
    # (5) **Plot Individual Trajectories per Subgroup**
    legend_patches = []  # Store legend handles
    for label in unique_labels:
        first_plotted = False  # Track if we added a legend entry for this subgroup
        
        for cell_idx, traj in cell_trajectories.items():
            if len(traj) != len(x_positions):
                continue  # Ensure trajectories align with time points
    
            if X1_hat_labels[cell_idx] == label:  # Match subgroup label from step 1
                ax1.plot(
                    x_positions, traj,  
                    color=subgroup_color_map[label],  # ✅ Use the **subtrajectory colors**
                    alpha=0.1, linewidth=0.8  
                )
                
                # Add a single legend entry for each subgroup (renaming from 0,1 → 1,2)
                if not first_plotted:
                    legend_patches.append(mpatches.Patch(color=subgroup_color_map[label], label=f'Trajectory of {label_mapping[label]} phenotypic shift'))
                    first_plotted = True
    
    # (6) **Ensure Violin Plots are at `[0, 2, 4]` & Appear in Front**
    violin_data = [
        gene_expression_X1_normalized,
        *gene_expression_intermediates_normalized,
        gene_expression_X2_normalized
    ]
    
    violin_x_positions = np.array([0, 4])  # Ensure correct positions
    
    # 🎻 **Plot Violin Plots with Correct Colors and Transparency**
    for i, (x_pos, data) in enumerate(zip(violin_x_positions, violin_data)):
        violin_parts = sns.violinplot(
            data=[data],  
            ax=ax1,
            inner=None,  # ✅ REMOVE QUARTILE LINES
            linewidth=1.2,
            width=0.7,
            cut=0,
            scale="width",
            color=violin_colors[i],  # ✅ Assign correct color
            alpha=0.8,  # ✅ MAKE TRANSPARENT
            zorder=3  # ✅ BRINGS VIOLINS TO THE FRONT
        )
        
        # **Manually Adjust X-Position of Each Violin**
        for violin in ax1.collections[-1:]:  # Only adjust the last added violin
            for path in violin.get_paths():
                path.vertices[:, 0] += x_pos - path.vertices[:, 0].mean()  
    
    # **Expand x-axis limits to prevent cutting off last violin plot**
    ax1.set_xlim(-0.5, 4.5)  
    
    # 🛠 **Fix x-axis labels and ensure proper alignment**
    ax1.set_xticks([0, 4])  
    #ax1.set_xticklabels([0, 4], fontsize=32)
    ax1.set_xticklabels(["Pre-treatment", "Post-treatment"], fontsize=46)
    ax1.tick_params(axis='y', labelsize=46)
    
    ax1.set_xlabel('Time', fontsize=46)
    ax1.set_ylabel('Gene Expression', fontsize=46)
    ax1.set_title(f'{gene_of_interest}', fontsize=46)


    # 🎨 **Save the main figure without a legend**
    plt.savefig(subgroup_output_file, dpi=300, bbox_inches='tight')
    plt.close()
    

    # 🎨 **Redefine `legend_patches` to Include a Green Bar**
    
    #legend_patches = [
    #    mlines.Line2D([], [], color="violet", linestyle="-", linewidth=3, 
    #                  label="Hallmark dynamics of each single cell")
    #]

    label_descriptions = {
    "low": "Trajectory of low phenotypic shift",
    "medium": "Trajectory of medium phenotypic shift",
    "high": "Trajectory of high phenotypic shift"}


    # Thicker lines using `linewidth`
    legend_patches = [
        mlines.Line2D(
            [], [], color=color, linestyle='-', linewidth=3,  # ← thicker line here
            markersize=10,
            label=f"{label_descriptions.get(ctype, '')}"
        )
        for ctype, color in zip(unique_cell_types, subtrajectory_colors)
    ]


    # 🎨 **Violin Plot Legend**
    violin_legend_patches = [
        mpatches.Patch(color="black", label="Input Data")
    ]
    
    # 🎨 **Create Separate Legend Figure (HORIZONTAL LAYOUT)**
    fig_legend, ax_legend = plt.subplots(figsize=(10, 2))  # Wider aspect ratio for horizontal layout
    ax_legend.axis("off")  # Hide axes
    
    # **Combine both legends**
    combined_legend = legend_patches + violin_legend_patches
    
    ax_legend.legend(
        handles=combined_legend,
        loc="center", fontsize=24, title="",
        title_fontsize=24, ncol=len(combined_legend),  # Horizontal layout
        frameon=True, handletextpad=2, columnspacing=2
    )
    
    # Save the separate legend
    legend_output_file = subgroup_output_file.replace(".png", "_legend.png")
    plt.savefig(legend_output_file, dpi=300, bbox_inches='tight')
    plt.close()

In [None]:
## Plot the results for clinical data


# Driver cell: render the single-trajectory violin figure for every gene.
genes_of_interest = gene_names # NANOG, SOX2
source_t, target_t = 0, 4
optimal_k = 2
index = 1
max_i = 200
#intermediate_t = [1,2,3]
intermediate_t = [4]

d_red= 2
random_state = 40
exp_memo = 'Palbo_887_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64'

result_dir = '%s/assets/Transport_genes/' % main_dir

# Build the output path BEFORE creating it: the directory used to be created
# from whatever `output_dir` an earlier cell had left behind (or raised a
# NameError on a fresh kernel) because the assignment came after makedirs.
output_dir = os.path.join(result_dir, 'output', exp_memo)
os.makedirs(output_dir, exist_ok=True)

# Iterate over each gene in the list
for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        # Call the function with the current gene
        Average_gene_dynamics_whole_saveonly_single_trajectory_clinical(
            source_t, target_t, optimal_k, gene, index, max_i,
            intermediate_t=intermediate_t, d_red=d_red,
            random_state=random_state, exp_memo=exp_memo
        )
    except Exception as e:
        # Best-effort batch run: report the failing gene and carry on with the rest
        print(f"Error processing gene {gene}: {e}")


In [None]:
## Save the gene expression dynamics PNGs as a PDF - clinical data


import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from math import ceil
import matplotlib.lines as mlines


def create_pdf_from_gene_images(output_dir, exp_memo, gene_list, pdf_path, images_per_page=25, grid_size=(5, 5)):
    """
    Collect the per-gene violin-plot PNGs into a multi-page PDF, one grid of images per page.

    Parameters:
        output_dir (str): Directory containing the PNG files.
        exp_memo (str): Experiment name; accepted for call compatibility but not used here.
        gene_list (list): Genes whose PNG
            ``{output_dir}/Celltypes_deviated_trajectories_violin_plot_{gene}.png`` should be included.
        pdf_path (str): Path to save the output PDF file.
        images_per_page (int): Requested number of images per page (default: 25). It is capped
            at the number of grid cells so that no image is dropped.
        grid_size (tuple): Grid size (rows, cols) for each page (default: 5x5).

    Raises:
        ValueError: If the arguments leave room for less than one image per page.
    """
    n_rows, n_cols = grid_size
    # A page can never hold more images than it has cells. Capping here keeps
    # page boundaries and the images actually drawn in sync, so an image is
    # neither skipped nor repeated on the following page.
    per_page = min(images_per_page, n_rows * n_cols)
    if per_page < 1:
        raise ValueError("images_per_page and grid_size must allow at least one image per page.")

    # Expected PNG path for every requested gene
    expected_files = [
        f"{output_dir}/Celltypes_deviated_trajectories_violin_plot_{gene}.png" for gene in gene_list
    ]

    # Split into files that exist and files that will be skipped
    png_files, missing_files = [], []
    for file in expected_files:
        (png_files if os.path.exists(file) else missing_files).append(file)
    if missing_files:
        print(f"Warning: The following files are missing and will be skipped:\n{missing_files}")

    total_pages = ceil(len(png_files) / per_page)

    with PdfPages(pdf_path) as pdf:
        for page in range(total_pages):
            # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)
            page_files = png_files[page * per_page:(page + 1) * per_page]

            for slot, ax in enumerate(axes.flatten()):
                if slot < len(page_files):
                    # aspect='auto' stretches the image to fill its cell
                    ax.imshow(plt.imread(page_files[slot]), aspect='auto')
                # Axes are hidden for filled and empty cells alike; no title is drawn
                ax.axis('off')

            # Save the page to the PDF with high resolution
            pdf.savefig(fig, dpi=300, bbox_inches='tight')
            plt.close(fig)  # Close the figure to free memory

    print(f"✅ PDF saved to {pdf_path} with original image resolution.")

# Example usage: bundle this experiment's per-gene violin plots into one PDF,
# six images per page on a 3x2 grid.

exp_memo = "Palbo_887_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64"
gene_list = gene_names  # every gene that may have a PNG
pdf_path = f"{output_dir}/Celltypes_deviated_trajectories_violin_plot.pdf"  # destination PDF

create_pdf_from_gene_images(
    output_dir=output_dir,
    exp_memo=exp_memo,
    gene_list=gene_list,
    pdf_path=pdf_path,
    images_per_page=6,
    grid_size=(3, 2),
)

In [None]:
## Quantify the errors by using W2 distance (permutation test)

import numpy as np
from scipy.stats import wasserstein_distance

def Compare_Distribution_Permutation_Test(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                                          intermediate_t=None, d_red=2, random_state=42, exp_memo='2',
                                          num_permutations=1000):
    """
    Permutation-test whether predicted and observed expression of one gene differ.

    For every entry of `intermediate_t` the observed expression (snapshot `mats[t]`
    projected through the PCA and back) is compared with the expression predicted
    by the transported particles, using `scipy.stats.wasserstein_distance`.

    NOTE: `scipy.stats.wasserstein_distance` is the 1-Wasserstein (earth mover's)
    distance. The "W2" wording in the result keys and printed summary is kept
    only so that existing consumers of the output keep working.

    Parameters:
    - source_t, target_t, optimal_k: accepted for signature compatibility; not used here.
    - gene_of_interest: Column name of the gene in `df_reduced_emt`.
    - index: Step between the trajectory snapshots that are considered.
    - max_i: Largest trajectory index that may be used.
    - intermediate_t: List of intermediate time points (defaults to [1] if not provided).
    - d_red: Reduced dimension; selects which pickled PCA mapping is loaded.
    - random_state: Seed for the permutation shuffles (None gives an unseeded generator).
    - exp_memo: Experiment identifier; `result_dir + exp_memo + ".pickle"` is loaded.
    - num_permutations: Number of permutations for significance testing.

    Returns:
    - Dict keyed by time point; each value holds "Observed W2", "Permutation Mean W2",
      "Permutation Std W2" and "p-value".

    Raises:
    - ValueError: If the global `dim_red_method` is neither 'EMT_PCA' nor 'PCA'.
    """

    # Ensure intermediate_t has a default value
    if intermediate_t is None:
        intermediate_t = [1]

    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): time_integration below only receives (x0, T, dt), so the W and b
    # loaded here are not handed to it — presumably it reads module-level weights;
    # verify that those belong to `exp_memo`.
    W, b, p = load_W(filename)

    # Pick the PCA mapping. Fail immediately for an unknown method instead of
    # printing and then crashing on an unbound `pca_filename`.
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        raise ValueError("PCA mapping for the reduction method and dimension is not available")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    # Define intermediate time points
    intermediate_only_points = intermediate_t

    # Column position of the gene, shifted by one
    # (presumably df_reduced_emt has one leading non-gene column — TODO confirm)
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # Observed expression at each intermediate time point
    kde_test_data = [
        pca.inverse_transform(pca.transform(mats[t]))[:, gene_index] for t in intermediate_only_points
    ]

    # Predicted expression at each intermediate time point.
    # NOTE(review): the k-th snapshot kept by `index` is paired with the k-th entry
    # of intermediate_t; physical time is not converted into a trajectory index
    # (Compare_Distribution_Sinkhorn rescales it). Confirm this pairing is intended.
    predicted_distributions = {t: [] for t in intermediate_only_points}
    indices = range(0, len(X1_trpts), index)

    # zip stops after the last time point, so later snapshots are no longer
    # back-projected only to be thrown away.
    for t, time_idx in zip(intermediate_only_points, indices):
        if time_idx > max_i:
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break

        gene_expression_values = pca.inverse_transform(X1_trpt)[:, gene_index]
        predicted_distributions[t].extend(gene_expression_values)

    # Convert trajectory distributions into a list format
    kde_predicted_data = [np.array(predicted_distributions[t]) for t in intermediate_only_points]

    # Dedicated generator so the test is reproducible through `random_state`
    rng = np.random.default_rng(random_state)

    # Store results
    permutation_results = {}

    for i, time in enumerate(intermediate_only_points):
        test_vals = kde_test_data[i]
        predicted_vals = kde_predicted_data[i]

        # Observed distance between the two samples
        observed_w2 = wasserstein_distance(test_vals, predicted_vals)

        # Null distribution: pool both samples, reshuffle, split at the original size
        combined_vals = np.concatenate([test_vals, predicted_vals])
        n_test = len(test_vals)
        permuted_w2_distances = []

        for _ in range(num_permutations):
            rng.shuffle(combined_vals)
            permuted_w2_distances.append(
                wasserstein_distance(combined_vals[:n_test], combined_vals[n_test:])
            )

        # Compute p-value (proportion of permuted distances ≥ observed distance)
        p_value = np.mean(np.array(permuted_w2_distances) >= observed_w2)

        # Store results
        permutation_results[time] = {
            "Observed W2": observed_w2,
            "Permutation Mean W2": np.mean(permuted_w2_distances),
            "Permutation Std W2": np.std(permuted_w2_distances),
            "p-value": p_value
        }

    # Print summary
    print("\n--- Permutation Test Summary ---")
    for time, results in permutation_results.items():
        print(f"Time {time}:")
        print(f"  Observed W2: {results['Observed W2']:.4f}")
        print(f"  Mean Permutation W2: {results['Permutation Mean W2']:.4f} ± {results['Permutation Std W2']:.4f}")
        print(f"  p-value: {results['p-value']:.4f}")
        if results["p-value"] < 0.05:
            print("  ** Significant Difference (Reject Null Hypothesis) **")
        else:
            print("  No Significant Difference (Cannot Reject Null Hypothesis)")

    return permutation_results



In [None]:
## Full metrics but No Sinkhorn (permutation test)

import numpy as np
from scipy.stats import wasserstein_distance
from sklearn.metrics.pairwise import rbf_kernel

def maximum_mean_discrepancy(X, Y, gamma=1.0):
    """
    Return the RBF-kernel Maximum Mean Discrepancy estimate between samples X and Y.

    Each input is indexed as `X[:, None]` to give the kernel one feature column.
    The result is mean(K_xx) + mean(K_yy) - 2 * mean(K_xy), with the means taken
    over every kernel entry.
    """
    x_col = X[:, None]
    y_col = Y[:, None]

    within_x = rbf_kernel(x_col, x_col, gamma=gamma).mean()
    within_y = rbf_kernel(y_col, y_col, gamma=gamma).mean()
    between = rbf_kernel(x_col, y_col, gamma=gamma).mean()

    return within_x + within_y - 2 * between

def Compare_Distribution_Permutation_Test(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                                          intermediate_t=None, d_red=2, random_state=42, exp_memo='2',
                                          num_permutations=1000, mmd_gamma=1.0):
    """
    Permutation-test predicted against observed expression of one gene with two metrics.

    For every entry of `intermediate_t` the observed expression (snapshot `mats[t]`
    projected through the PCA and back) is compared with the expression predicted
    by the transported particles, using `scipy.stats.wasserstein_distance` and the
    RBF-kernel MMD from `maximum_mean_discrepancy`.

    NOTE: `scipy.stats.wasserstein_distance` is the 1-Wasserstein (earth mover's)
    distance. The "W2" result keys are kept only for backward compatibility.

    Parameters:
    - source_t, target_t, optimal_k: accepted for signature compatibility; not used here.
    - gene_of_interest: Column name of the gene in `df_reduced_emt`.
    - index: Step between the trajectory snapshots that are considered.
    - max_i: Largest trajectory index that may be used.
    - intermediate_t: List of intermediate time points (defaults to [1] if not provided).
    - d_red: Reduced dimension; selects which pickled PCA mapping is loaded.
    - random_state: Seed for the permutation shuffles (None gives an unseeded generator).
    - exp_memo: Experiment identifier; `result_dir + exp_memo + ".pickle"` is loaded.
    - num_permutations: Number of permutations for significance testing.
    - mmd_gamma: `gamma` forwarded to the RBF kernel of the MMD.

    Returns:
    - Dict keyed by time point; each value holds "Observed W2", "p_W2",
      "Observed MMD" and "p_MMD".

    Raises:
    - ValueError: If the global `dim_red_method` is neither 'EMT_PCA' nor 'PCA'.
    """

    if intermediate_t is None:
        intermediate_t = [1]

    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): time_integration below only receives (x0, T, dt), so the W and b
    # loaded here are not handed to it — presumably it reads module-level weights;
    # verify that those belong to `exp_memo`.
    W, b, p = load_W(filename)

    # Pick the PCA mapping. Fail immediately for an unknown method instead of
    # printing and then crashing on an unbound `pca_filename`.
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        raise ValueError("PCA mapping for the reduction method and dimension is not available")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    # Define intermediate time points
    intermediate_only_points = intermediate_t

    # Column position of the gene, shifted by one
    # (presumably df_reduced_emt has one leading non-gene column — TODO confirm)
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # Observed expression at each intermediate time point
    kde_test_data = [
        pca.inverse_transform(pca.transform(mats[t]))[:, gene_index] for t in intermediate_only_points
    ]

    # Predicted expression at each intermediate time point.
    # NOTE(review): the k-th snapshot kept by `index` is paired with the k-th entry
    # of intermediate_t; physical time is not converted into a trajectory index
    # (Compare_Distribution_Sinkhorn rescales it). Confirm this pairing is intended.
    predicted_distributions = {t: [] for t in intermediate_only_points}
    indices = range(0, len(X1_trpts), index)

    # zip stops after the last time point, so later snapshots are no longer
    # back-projected only to be thrown away.
    for t, time_idx in zip(intermediate_only_points, indices):
        if time_idx > max_i:
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break

        gene_expression_values = pca.inverse_transform(X1_trpt)[:, gene_index]
        predicted_distributions[t].extend(gene_expression_values)

    # Convert trajectory distributions into a list format
    kde_predicted_data = [np.array(predicted_distributions[t]) for t in intermediate_only_points]

    # Dedicated generator so the test is reproducible through `random_state`
    rng = np.random.default_rng(random_state)

    # Store results
    permutation_results = {}

    for i, time in enumerate(intermediate_only_points):
        test_vals = kde_test_data[i]
        predicted_vals = kde_predicted_data[i]

        # Compute observed metrics
        observed_w2 = wasserstein_distance(test_vals, predicted_vals)
        observed_mmd = maximum_mean_discrepancy(test_vals, predicted_vals, gamma=mmd_gamma)

        # Null distribution: pool both samples, reshuffle, split at the original size
        combined_vals = np.concatenate([test_vals, predicted_vals])
        n_test = len(test_vals)
        permuted_w2, permuted_mmd = [], []

        for _ in range(num_permutations):
            rng.shuffle(combined_vals)
            perm_test_sample = combined_vals[:n_test]
            perm_pred_sample = combined_vals[n_test:]

            permuted_w2.append(wasserstein_distance(perm_test_sample, perm_pred_sample))
            permuted_mmd.append(maximum_mean_discrepancy(perm_test_sample, perm_pred_sample, gamma=mmd_gamma))

        # p-value = share of permuted statistics at least as large as the observed one
        p_w2 = np.mean(np.array(permuted_w2) >= observed_w2)
        p_mmd = np.mean(np.array(permuted_mmd) >= observed_mmd)

        # Store results
        permutation_results[time] = {
            "Observed W2": observed_w2, "p_W2": p_w2,
            "Observed MMD": observed_mmd, "p_MMD": p_mmd
        }

    return permutation_results


In [None]:
## Full metrics With Sinkhorn (permutation test)


import numpy as np
import torch
import os
from geomloss import SamplesLoss  # Sinkhorn divergence
from sklearn.utils import resample


def sinkhorn_divergence(X, Y, epsilon=0.5):
    """
    Return loss(X, Y) - 0.5 * loss(X, X) - 0.5 * loss(Y, Y) for geomloss' Sinkhorn loss.

    X and Y are NumPy arrays; each is reshaped to a single float32 column. Only the
    leading min(len(X), len(Y)) rows of each sample are used. `epsilon` is passed
    to `SamplesLoss` as its `blur` argument.

    NOTE(review): SamplesLoss(loss="sinkhorn") may already return a debiased
    divergence by default — confirm the explicit self-term correction is wanted.
    """
    x_col, y_col = (torch.from_numpy(arr.reshape(-1, 1)).float() for arr in (X, Y))

    # Truncate both samples to the length they have in common
    shared_rows = min(x_col.shape[0], y_col.shape[0])
    x_col = x_col[:shared_rows]
    y_col = y_col[:shared_rows]

    loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=epsilon)

    cross_term = loss_fn(x_col, y_col).item()
    self_term_x = loss_fn(x_col, x_col).item()
    self_term_y = loss_fn(y_col, y_col).item()

    return cross_term - 0.5 * self_term_x - 0.5 * self_term_y


def permutation_test_sinkhorn(test_vals, pred_vals, num_permutations=1000, epsilon=0.5):
    """
    Permutation test built on `sinkhorn_divergence`.

    Returns `(observed_stat, p_value)`, where `p_value` is the share of permuted
    statistics that are at least as large as the observed one. Shuffling uses the
    global NumPy random state.
    """
    observed_stat = sinkhorn_divergence(test_vals, pred_vals, epsilon)

    pooled = np.concatenate([test_vals, pred_vals])
    n_test = len(test_vals)

    def _null_draw():
        # Reshuffle the pooled sample in place, then split at the original test size
        np.random.shuffle(pooled)
        return sinkhorn_divergence(pooled[:n_test], pooled[n_test:], epsilon)

    null_stats = np.array([_null_draw() for _ in range(num_permutations)])
    p_value = np.mean(null_stats >= observed_stat)

    return observed_stat, p_value


def Compare_Distribution_Sinkhorn(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                                  intermediate_t=None, d_red=2, random_state=42, exp_memo='2',
                                  num_permutations=1000, sinkhorn_epsilon=0.5):
    """
    Compare predicted and observed expression of one gene with a Sinkhorn permutation test.

    Each time in `intermediate_t` is rescaled to a trajectory snapshot index. The
    expression predicted at that snapshot is compared with the observed snapshot
    `mats[t]` (projected through the PCA and back), and the rows are written to
    `<result_dir>/output/<exp_memo>/Sinkhorn_metrics_<gene>.csv`.

    Parameters:
    - source_t, target_t: Time span that is mapped onto the trajectory snapshots.
    - optimal_k, index: accepted for signature compatibility; not used here.
    - gene_of_interest: Column name of the gene in `df_reduced_emt`.
    - max_i: Largest snapshot index that may be used.
    - intermediate_t: Time points to evaluate (defaults to [1]).
    - d_red: Reduced dimension; selects which pickled PCA mapping is loaded.
    - random_state: Seed for the size-matching resampling (default 42, the value
      that used to be hard-coded).
    - exp_memo: Experiment identifier; `result_dir + exp_memo + ".pickle"` is loaded.
    - num_permutations: Number of permutations for the test.
    - sinkhorn_epsilon: Forwarded to `permutation_test_sinkhorn`.

    Returns:
    - List of dicts with keys "Time", "Gene", "Sinkhorn Divergence", "Sinkhorn p-value".
      Time points whose snapshot is unusable are left out.

    Raises:
    - ValueError: If `dim_red_method` is unsupported or `target_t == source_t`.
    """

    if intermediate_t is None:
        intermediate_t = [1]
    if target_t == source_t:
        raise ValueError("target_t must differ from source_t to rescale time to snapshot indices.")

    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): time_integration below only receives (x0, T, dt), so the W and b
    # loaded here are not handed to it — presumably it reads module-level weights;
    # verify that those belong to `exp_memo`.
    W, b, p = load_W(filename)

    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        raise ValueError("PCA mapping method not available.")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    # Map time onto snapshot indices 0 .. len-1. Scaling by len (not len-1) sent
    # t == target_t one past the last snapshot.
    # NOTE(review): `t * scaling_factor` treats time 0 as the first snapshot, i.e.
    # it presumes source_t == 0 — confirm for other source times.
    last_snapshot = len(X1_trpts) - 1
    scaling_factor = last_snapshot / (target_t - source_t)

    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # Build (time, observed, predicted) triples together. Collecting predictions in
    # a separate list and zipping afterwards shifted every later prediction onto
    # the wrong time point whenever one snapshot was skipped.
    comparisons = []
    for t in intermediate_t:
        snapshot_idx = int(t * scaling_factor)
        if snapshot_idx > max_i or not 0 <= snapshot_idx <= last_snapshot:
            print(f"Skipping time {t}: snapshot index {snapshot_idx} is outside the usable trajectory.")
            continue
        X1_trpt = X1_trpts[snapshot_idx]
        if np.isnan(X1_trpt).any():
            print(f"Skipping time {t}: snapshot {snapshot_idx} contains NaN values.")
            continue

        test_vals = pca.inverse_transform(pca.transform(mats[t]))[:, gene_index]
        pred_vals = pca.inverse_transform(X1_trpt)[:, gene_index]
        comparisons.append((t, test_vals, pred_vals))

    results = []
    for time, test_vals, pred_vals in comparisons:

        # Match sample sizes (resample's default draws with replacement)
        min_size = min(len(test_vals), len(pred_vals))
        test_vals = resample(test_vals, n_samples=min_size, random_state=random_state)
        pred_vals = resample(pred_vals, n_samples=min_size, random_state=random_state)

        # Compute Sinkhorn and permutation test
        sinkhorn_stat, p_sinkhorn = permutation_test_sinkhorn(
            test_vals, pred_vals, num_permutations, sinkhorn_epsilon
        )

        results.append({
            "Time": time,
            "Gene": gene_of_interest,
            "Sinkhorn Divergence": sinkhorn_stat,
            "Sinkhorn p-value": p_sinkhorn
        })

    output_dir = os.path.join(result_dir, 'output', exp_memo)
    os.makedirs(output_dir, exist_ok=True)

    df_results = pd.DataFrame(results)
    csv_path = os.path.join(output_dir, f"Sinkhorn_metrics_{gene_of_interest}.csv")
    df_results.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")

    return results

In [None]:
## Sample 1

## Permutation by other statistical tests (permutation test)

import numpy as np
import torch
import pandas as pd
import warnings
from scipy.stats import ks_2samp, wasserstein_distance
from scipy.spatial.distance import jensenshannon
from scipy.special import rel_entr
from sklearn.utils import resample  # Resampling for matching sizes

# Installs an "ignore" filter for every warning category; it stays in effect
# for code that runs after this cell, not only for the functions defined below.
warnings.simplefilter("ignore")  # Suppress warnings


def total_variation_distance(p, q):
    """Return half the summed absolute element-wise difference between p and q."""
    absolute_gaps = np.abs(p - q)
    return 0.5 * absolute_gaps.sum()


def kl_divergence(p, q):
    """Return sum(rel_entr(p, q)) after raising every entry of p and q to at least 1e-10."""
    floor = 1e-10  # lower bound that keeps zero entries out of the ratio and the log
    safe_p, safe_q = (np.clip(values, floor, None) for values in (p, q))
    return rel_entr(safe_p, safe_q).sum()


def match_sample_sizes(X, Y):
    """
    Return subsamples of X and Y that both have min(len(X), len(Y)) elements.

    Both inputs go through `resample` without replacement and with the same fixed
    seed (42), so the shorter input is passed through `resample` as well.
    """
    common_size = min(len(X), len(Y))
    X_resampled, Y_resampled = (
        resample(sample, n_samples=common_size, replace=False, random_state=42)
        for sample in (X, Y)
    )
    return X_resampled, Y_resampled


def permutation_test(stat_func, test_vals, pred_vals, num_permutations=1000):
    """
    Run a permutation test for `stat_func(test_sample, pred_sample)`.

    Returns `(observed_stat, mean_perm, std_perm, p_value)`: the statistic on the
    original split, the mean and standard deviation of the permuted statistics,
    and the share of permuted statistics at least as large as the observed one.
    Shuffling uses the global NumPy random state.
    """
    observed_stat = stat_func(test_vals, pred_vals)

    pooled = np.concatenate([test_vals, pred_vals])
    n_test = len(test_vals)

    def _null_draw():
        # Reshuffle the pooled sample in place, then split at the original test size
        np.random.shuffle(pooled)
        return stat_func(pooled[:n_test], pooled[n_test:])

    null_stats = [_null_draw() for _ in range(num_permutations)]

    p_value = np.mean(np.array(null_stats) >= observed_stat)
    return observed_stat, np.mean(null_stats), np.std(null_stats), p_value


def _binned_statistic(stat_func, bin_edges):
    """Lift `stat_func` from probability vectors to raw samples via histograms on `bin_edges`."""
    def _statistic(sample_a, sample_b):
        # Per-bin probability mass (counts / sample size), so each vector sums to 1
        mass_a = np.histogram(sample_a, bins=bin_edges)[0] / len(sample_a)
        mass_b = np.histogram(sample_b, bins=bin_edges)[0] / len(sample_b)
        return stat_func(mass_a, mass_b)
    return _statistic


def Compare_Distribution_Statistics(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                                    intermediate_t=None, d_red=2, random_state=42, exp_memo='2',
                                    num_permutations=1000, save_csv=True):
    """
    Compare predicted and observed expression of one gene with KS, TV, KL and JS statistics.

    For every entry of `intermediate_t` the observed and predicted samples are cut
    to equal size and compared with a two-sample KS test. Both samples are also
    binned on ONE shared set of 50 bin edges and turned into probability masses,
    from which total variation, KL and Jensen-Shannon are computed. Permutation
    tests use those same binned statistics, so each p-value refers to the value
    reported next to it.

    (The two histograms used to be built with independent bin edges and density
    scaling, and the permutation tests applied the statistics to raw sample values.)

    NOTE: `scipy.spatial.distance.jensenshannon` returns the Jensen-Shannon
    distance (square root of the divergence); the "Jensen-Shannon Divergence"
    key is kept for backward compatibility.

    Parameters:
    - source_t, target_t, optimal_k, random_state: accepted for signature
      compatibility; not used here.
    - gene_of_interest: Column name of the gene in `df_reduced_emt`.
    - index: Step between the trajectory snapshots that are considered.
    - max_i: Largest trajectory index that may be used.
    - intermediate_t: List of intermediate time points (defaults to [1]).
    - d_red: Reduced dimension; selects which pickled PCA mapping is loaded.
    - exp_memo: Experiment identifier; `result_dir + exp_memo + ".pickle"` is loaded.
    - num_permutations: Number of permutations per statistic.
    - save_csv: When True, write the rows to
      `<result_dir>/output/<exp_memo>/statistical_metrics_<gene>.csv`.

    Returns:
    - List with one dict of statistics and permutation p-values per time point.

    Raises:
    - ValueError: If the global `dim_red_method` is neither 'EMT_PCA' nor 'PCA'.
    """
    n_bins = 50

    if intermediate_t is None:
        intermediate_t = [1]

    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): time_integration below only receives (x0, T, dt), so the W and b
    # loaded here are not handed to it — presumably it reads module-level weights;
    # verify that those belong to `exp_memo`.
    W, b, p = load_W(filename)

    # Pick the PCA mapping. Fail immediately for an unknown method instead of
    # printing and then crashing on an unbound `pca_filename`.
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        raise ValueError("PCA mapping for the reduction method is not available")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    intermediate_only_points = intermediate_t
    # Column position of the gene, shifted by one
    # (presumably df_reduced_emt has one leading non-gene column — TODO confirm)
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # Observed expression at each intermediate time point
    kde_test_data = [
        pca.inverse_transform(pca.transform(mats[t]))[:, gene_index] for t in intermediate_only_points
    ]

    # Predicted expression at each intermediate time point.
    # NOTE(review): the k-th snapshot kept by `index` is paired with the k-th entry
    # of intermediate_t; physical time is not converted into a trajectory index.
    # Confirm this pairing is intended.
    predicted_distributions = {t: [] for t in intermediate_only_points}
    indices = range(0, len(X1_trpts), index)

    # zip stops after the last time point, so later snapshots are no longer
    # back-projected only to be thrown away.
    for t, time_idx in zip(intermediate_only_points, indices):
        if time_idx > max_i:
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break
        gene_expression_values = pca.inverse_transform(X1_trpt)[:, gene_index]
        predicted_distributions[t].extend(gene_expression_values)

    # Convert trajectory distributions into a list format
    kde_predicted_data = [np.array(predicted_distributions[t]) for t in intermediate_only_points]

    # Store results
    metric_results = []

    for i, time in enumerate(intermediate_only_points):
        test_vals = kde_test_data[i]
        predicted_vals = kde_predicted_data[i]

        # Ensure test and predicted values have the same size
        test_vals, predicted_vals = match_sample_sizes(test_vals, predicted_vals)

        # Kolmogorov-Smirnov test on the raw samples
        ks_stat, ks_pval = ks_2samp(test_vals, predicted_vals)

        # One set of bin edges spanning both samples. A permutation only reorders
        # the pooled values, so the same edges are valid for every permuted split.
        bin_edges = np.histogram_bin_edges(np.concatenate([test_vals, predicted_vals]), bins=n_bins)

        # Observed statistic and its permutation null come from the same binned function
        tv_distance, mean_perm_tv, std_perm_tv, p_tv = permutation_test(
            _binned_statistic(total_variation_distance, bin_edges), test_vals, predicted_vals, num_permutations)
        kl_div, mean_perm_kl, std_perm_kl, p_kl = permutation_test(
            _binned_statistic(kl_divergence, bin_edges), test_vals, predicted_vals, num_permutations)
        js_div, mean_perm_js, std_perm_js, p_js = permutation_test(
            _binned_statistic(jensenshannon, bin_edges), test_vals, predicted_vals, num_permutations)

        # Store results
        metric_results.append({
            "Time": time,
            "Gene": gene_of_interest,
            "KS Test Statistic": ks_stat,
            "KS p-value": ks_pval,
            "Total Variation Distance": tv_distance,
            "Permutation TV ± Std": f"{mean_perm_tv:.4f} ± {std_perm_tv:.4f}",
            "p-value TV": p_tv,
            "KL Divergence": kl_div,
            "Permutation KL ± Std": f"{mean_perm_kl:.4f} ± {std_perm_kl:.4f}",
            "p-value KL": p_kl,
            "Jensen-Shannon Divergence": js_div,
            "Permutation JS ± Std": f"{mean_perm_js:.4f} ± {std_perm_js:.4f}",
            "p-value JS": p_js
        })

    # Convert results to DataFrame and save as CSV
    if save_csv:
        output_dir = os.path.join(result_dir, 'output', exp_memo)
        os.makedirs(output_dir, exist_ok=True)  # Ensure directory exists

        df_results = pd.DataFrame(metric_results)
        output_csv_path = os.path.join(output_dir, f"statistical_metrics_{gene_of_interest}.csv")
        df_results.to_csv(output_csv_path, index=False)
        print(f"Results saved to {output_csv_path}")

    return metric_results



In [None]:
## Sample 3

## Permutation by other statistical tests (permutation test)

import numpy as np
import torch
import pandas as pd
import warnings
from scipy.stats import ks_2samp, wasserstein_distance, ttest_ind
from scipy.spatial.distance import jensenshannon
from scipy.special import rel_entr
from sklearn.utils import resample  # Resampling for matching sizes

# Installs an "ignore" filter for every warning category; it stays in effect
# for code that runs after this cell, not only for the functions defined below.
warnings.simplefilter("ignore")  # Suppress warnings


def total_variation_distance(p, q):
    """Return half of the summed absolute element-wise difference of `p` and `q`.

    When `p` and `q` are probability vectors over the same bins this is the
    total variation distance between them.
    """
    pointwise_gap = np.abs(p - q)
    return 0.5 * pointwise_gap.sum()


def kl_divergence(p, q):
    """Return the sum of `rel_entr(p, q)` after flooring both inputs at 1e-10.

    The floor keeps zero entries from producing infinite or undefined terms.
    """
    floor = 1e-10
    p_safe, q_safe = (np.clip(arr, floor, None) for arr in (p, q))
    return np.sum(rel_entr(p_safe, q_safe))


def match_sample_sizes(X, Y):
    """Draw equally sized subsamples of `X` and `Y` without replacement.

    Both arrays are passed through `resample` with the size of the shorter
    one and a fixed seed (42), so the result is reproducible.
    """
    target_n = min(len(X), len(Y))
    X_resampled, Y_resampled = (
        resample(values, n_samples=target_n, replace=False, random_state=42)
        for values in (X, Y)
    )
    return X_resampled, Y_resampled


def permutation_test(stat_func, test_vals, pred_vals, num_permutations=1000):
    """Run a permutation test of `stat_func` on two samples.

    The two samples are pooled, the pool is shuffled `num_permutations` times
    with NumPy's global random generator, and each shuffle is split back into
    groups of the original sizes on which `stat_func` is evaluated.

    Returns:
        (observed statistic, mean of permuted statistics, std of permuted
        statistics, fraction of permuted statistics >= observed).
    """
    n_test = len(test_vals)
    baseline = stat_func(test_vals, pred_vals)

    # np.concatenate returns a new array, so the in-place shuffles below never touch the inputs.
    pooled = np.concatenate([test_vals, pred_vals])

    def _one_draw():
        np.random.shuffle(pooled)
        return stat_func(pooled[:n_test], pooled[n_test:])

    null_stats = [_one_draw() for _ in range(num_permutations)]

    return (
        baseline,
        np.mean(null_stats),
        np.std(null_stats),
        np.mean(np.array(null_stats) >= baseline),
    )


def _bin_probabilities(sample, bin_edges):
    """Return the fraction of `sample` that falls into each bin of `bin_edges`."""
    counts, _ = np.histogram(sample, bins=bin_edges)
    return counts / counts.sum()


def _binned_statistic(stat_func, bin_edges):
    """Adapt a statistic defined on probability vectors so it accepts two raw samples."""
    def statistic(sample_a, sample_b):
        return stat_func(_bin_probabilities(sample_a, bin_edges),
                         _bin_probabilities(sample_b, bin_edges))
    return statistic


def Compare_Distribution_Statistics(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                                    intermediate_t=None, d_red=2, random_state=42, exp_memo='2',
                                    num_permutations=1000, save_csv=True):
    """
    Compare observed and predicted expression distributions of one gene.

    For every time point in `intermediate_t` the observed values (PCA
    round-trip of `mats[t]`) and the values predicted by the transported
    trajectory are reduced to equal size and compared with a KS test,
    TV / KL / Jensen-Shannon statistics (each with a permutation test),
    mean and standard deviation differences and a Welch t-test.

    The TV / KL / JS statistics are evaluated on probability vectors built
    over bin edges shared by both samples, and the permutation tests use
    exactly the same statistic, so each reported p-value refers to the
    reported distance.

    Parameters:
    - source_t, target_t: time range used to map time points to trajectory indices.
    - gene_of_interest: column name looked up in `df_reduced_emt`.
    - max_i: largest trajectory index that is still evaluated.
    - intermediate_t: time points to evaluate (defaults to [1]).
    - d_red: latent dimension used in the PCA file name.
    - exp_memo: experiment name used for the weight file and output folder.
    - num_permutations: number of shuffles per permutation test.
    - save_csv: when True, write the results to a per-gene CSV file.
    - optimal_k, index, random_state: not used in the body; kept so existing
      calls keep working.

    Returns:
    - List with one dictionary of statistics per evaluated time point.

    Raises:
    - ValueError if `dim_red_method` is neither 'EMT_PCA' nor 'PCA'.
    """

    if intermediate_t is None:
        intermediate_t = [1]

    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): W and b are not referenced again in this function and are not
    # passed to time_integration — confirm how time_integration obtains its weights.
    W, b, p = load_W(filename)

    # Select the stored PCA mapping that matches the reduction method.
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        # Fail with a clear error instead of continuing with pca_filename unset.
        raise ValueError(
            f"PCA mapping for the reduction method {dim_red_method!r} is not available"
        )

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    # Map each requested time point to an index into the trajectory snapshots.
    num_snapshots = len(X1_trpts)
    scaling_factor = num_snapshots / (target_t - source_t)
    scaled_intermediate_indices = [int(t * scaling_factor) for t in intermediate_t]

    # NOTE(review): the "- 1" presumably skips a leading non-gene column of df_reduced_emt — confirm.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # Observed values: project to the latent space and back, then pick the gene.
    kde_test_data = [
        pca.inverse_transform(pca.transform(mats[t]))[:, gene_index] for t in intermediate_t
    ]

    # Predicted values: reconstruct the gene from the trajectory snapshot of each time point.
    predicted_distributions = {t: [] for t in intermediate_t}
    for t, time_idx in zip(intermediate_t, scaled_intermediate_indices):
        if time_idx > max_i:
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break
        gene_expression_values = pca.inverse_transform(X1_trpt)[:, gene_index]
        predicted_distributions[t].extend(gene_expression_values)

    kde_predicted_data = [np.array(predicted_distributions[t]) for t in intermediate_t]

    metric_results = []

    for time_point, test_vals, predicted_vals in zip(intermediate_t, kde_test_data, kde_predicted_data):
        # A time point skipped above has no predictions; none of the tests is defined for it.
        if len(test_vals) == 0 or len(predicted_vals) == 0:
            print(f"Skipping time {time_point} for {gene_of_interest}: no values to compare")
            continue

        test_vals, predicted_vals = match_sample_sizes(test_vals, predicted_vals)

        ks_stat, ks_pval = ks_2samp(test_vals, predicted_vals)  # Kolmogorov-Smirnov test

        # Bin both samples on common edges; binning each sample on its own range
        # would make the distances insensitive to shifts and rescalings.
        bin_edges = np.histogram_bin_edges(np.concatenate([test_vals, predicted_vals]), bins=50)

        # permutation_test returns the observed statistic first, so the reported
        # distance and its permutation p-value come from the same statistic.
        tv_distance, mean_perm_tv, std_perm_tv, p_tv = permutation_test(
            _binned_statistic(total_variation_distance, bin_edges),
            test_vals, predicted_vals, num_permutations)
        kl_div, mean_perm_kl, std_perm_kl, p_kl = permutation_test(
            _binned_statistic(kl_divergence, bin_edges),
            test_vals, predicted_vals, num_permutations)
        # scipy's jensenshannon returns the Jensen-Shannon distance
        # (square root of the divergence); the column name is kept unchanged.
        js_div, mean_perm_js, std_perm_js, p_js = permutation_test(
            _binned_statistic(jensenshannon, bin_edges),
            test_vals, predicted_vals, num_permutations)

        # Mean and standard deviation comparison
        mean_test = np.mean(test_vals)
        mean_pred = np.mean(predicted_vals)
        std_test = np.std(test_vals)
        std_pred = np.std(predicted_vals)
        mean_diff = mean_test - mean_pred
        std_diff = std_test - std_pred

        # Welch t-test (unequal variances) on the means
        t_stat, t_pval = ttest_ind(test_vals, predicted_vals, equal_var=False)

        metric_results.append({
            "Time": time_point,
            "Gene": gene_of_interest,
            "KS Test Statistic": ks_stat,
            "KS p-value": ks_pval,
            "Total Variation Distance": tv_distance,
            "Permutation TV ± Std": f"{mean_perm_tv:.4f} ± {std_perm_tv:.4f}",
            "p-value TV": p_tv,
            "KL Divergence": kl_div,
            "Permutation KL ± Std": f"{mean_perm_kl:.4f} ± {std_perm_kl:.4f}",
            "p-value KL": p_kl,
            "Jensen-Shannon Divergence": js_div,
            "Permutation JS ± Std": f"{mean_perm_js:.4f} ± {std_perm_js:.4f}",
            "p-value JS": p_js,
            "Mean Test": mean_test,
            "Mean Predicted": mean_pred,
            "Mean Difference": mean_diff,
            "Standard Deviation Test": std_test,
            "Standard Deviation Predicted": std_pred,
            "Standard Deviation Difference": std_diff,
            "T-test Statistic": t_stat,
            "T-test p-value": t_pval
        })

    # Convert results to DataFrame and save as CSV
    if save_csv:
        output_dir = os.path.join(result_dir, 'output', exp_memo)
        os.makedirs(output_dir, exist_ok=True)

        df_results = pd.DataFrame(metric_results)
        output_csv_path = os.path.join(output_dir, f"statistical_metrics_{gene_of_interest}.csv")
        df_results.to_csv(output_csv_path, index=False)
        print(f"Results saved to {output_csv_path}")

    return metric_results


In [None]:
## Permutation tests for distribution statistics - stem cell data

from scipy.stats import ks_2samp, wasserstein_distance, ttest_ind
from scipy.spatial.distance import jensenshannon
from scipy.special import rel_entr
from sklearn.utils import resample
import numpy as np
import pandas as pd
import torch
from geomloss import SamplesLoss
import os
import warnings

# Install a global "ignore" filter so no warning is displayed from here on.
# NOTE(review): this also hides any numerical warnings raised by the statistics below.
warnings.simplefilter("ignore")

def total_variation_distance(p, q):
    """Return half the L1 distance between the arrays `p` and `q`."""
    l1_gap = np.abs(p - q).sum()
    return 0.5 * l1_gap

def kl_divergence(p, q):
    """Sum the element-wise relative entropy of `p` with respect to `q`.

    Both inputs are clipped from below at 1e-10 first so that zero entries
    do not yield infinite terms.
    """
    eps = 1e-10
    return np.sum(rel_entr(np.clip(p, eps, None), np.clip(q, eps, None)))

def match_sample_sizes(X, Y):
    """Subsample `X` and `Y` to the size of the smaller one, without replacement.

    `resample` draws with replacement unless told otherwise, which would
    duplicate some observations and drop others even from the smaller
    sample. `replace=False` keeps every drawn observation distinct. The seed
    is fixed (42) so repeated calls give the same subsamples.

    Returns:
        Tuple (X_subsample, Y_subsample), both of length min(len(X), len(Y)).
    """
    min_size = min(len(X), len(Y))
    return (
        resample(X, n_samples=min_size, replace=False, random_state=42),
        resample(Y, n_samples=min_size, replace=False, random_state=42),
    )

def sinkhorn_divergence(X, Y, epsilon=0.5):
    """Compute a Sinkhorn divergence between two 1-D samples with geomloss.

    `X` and `Y` are reshaped to float32 column tensors, truncated to a common
    length, and combined as loss(X, Y) - 0.5 * loss(X, X) - 0.5 * loss(Y, Y).

    `epsilon` is passed to geomloss as the `blur` argument.
    NOTE(review): geomloss's SamplesLoss may already debias the sinkhorn loss
    by default, which would make the two self-terms redundant — confirm.
    """
    def _as_column(values):
        return torch.from_numpy(values.reshape(-1, 1)).float()

    cloud_x, cloud_y = _as_column(X), _as_column(Y)

    # Keep only the leading rows so both point clouds have the same size.
    n_common = min(cloud_x.shape[0], cloud_y.shape[0])
    cloud_x = cloud_x[:n_common]
    cloud_y = cloud_y[:n_common]

    loss = SamplesLoss(loss="sinkhorn", p=2, blur=epsilon)
    cross_term = loss(cloud_x, cloud_y).item()
    self_x = loss(cloud_x, cloud_x).item()
    self_y = loss(cloud_y, cloud_y).item()
    return cross_term - 0.5 * self_x - 0.5 * self_y

def permutation_test(stat_func, test_vals, pred_vals, num_permutations=1000):
    """Permutation test of `stat_func` between `test_vals` and `pred_vals`.

    Shuffles the pooled values `num_permutations` times (NumPy global RNG),
    re-splits them at the original group boundary and evaluates `stat_func`
    on every split.

    Returns:
        observed statistic, mean and std of the permuted statistics, and the
        fraction of permuted statistics that are >= the observed one.
    """
    split_at = len(test_vals)
    observed = stat_func(test_vals, pred_vals)

    # Fresh array: shuffling it in place leaves the caller's inputs untouched.
    pool = np.concatenate([test_vals, pred_vals])

    null_distribution = []
    for _ in range(num_permutations):
        np.random.shuffle(pool)
        first_group, second_group = pool[:split_at], pool[split_at:]
        null_distribution.append(stat_func(first_group, second_group))

    exceed_fraction = np.mean(np.array(null_distribution) >= observed)
    return observed, np.mean(null_distribution), np.std(null_distribution), exceed_fraction

def _bin_probabilities(sample, bin_edges):
    """Return the fraction of `sample` that falls into each bin of `bin_edges`."""
    counts, _ = np.histogram(sample, bins=bin_edges)
    return counts / counts.sum()


def _binned_statistic(stat_func, bin_edges):
    """Adapt a statistic defined on probability vectors so it accepts two raw samples."""
    def statistic(sample_a, sample_b):
        return stat_func(_bin_probabilities(sample_a, bin_edges),
                         _bin_probabilities(sample_b, bin_edges))
    return statistic


def Compare_Distribution_Statistics(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                                    intermediate_t=None, d_red=2, random_state=42, exp_memo='2',
                                    num_permutations=1000, save_csv=True):
    """
    Compare observed and predicted expression distributions of one gene.

    For every time point in `intermediate_t` the observed values (PCA
    round-trip of `mats[t]`) and the values predicted by the transported
    trajectory are reduced to equal size and compared with a KS test,
    Wasserstein, Sinkhorn, TV, KL and Jensen-Shannon statistics (each with a
    permutation test), mean / standard deviation summaries and a Welch t-test.

    TV / KL / JS are evaluated on probability vectors built over bin edges
    shared by both samples, and their permutation tests use exactly that
    statistic, so each reported p-value refers to the reported distance.

    Parameters:
    - source_t, target_t: time range used to map time points to trajectory indices.
    - gene_of_interest: column name looked up in `df_reduced_emt`.
    - max_i: largest trajectory index that is still evaluated.
    - intermediate_t: time points to evaluate (defaults to [1]).
    - d_red: latent dimension used in the PCA file name.
    - exp_memo: experiment name used for the weight file and output folder.
    - num_permutations: number of shuffles per permutation test.
    - save_csv: when True, write the results to a per-gene CSV file.
    - optimal_k, index, random_state: not used in the body; kept so existing
      calls keep working.

    Returns:
    - List with one dictionary of statistics per evaluated time point, or
      None when `dim_red_method` is neither 'EMT_PCA' nor 'PCA'.
    """
    if intermediate_t is None:
        intermediate_t = [1]

    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): W and b are not referenced again in this function and are not
    # passed to time_integration — confirm how time_integration obtains its weights.
    W, b, p = load_W(filename)

    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        print("❌ PCA method not available")
        return

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    # Map each requested time point to an index into the trajectory snapshots.
    scaling_factor = len(X1_trpts) / (target_t - source_t)
    scaled_intermediate_indices = [int(t * scaling_factor) for t in intermediate_t]

    # NOTE(review): the "- 1" presumably skips a leading non-gene column of df_reduced_emt — confirm.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1
    kde_test_data = [pca.inverse_transform(pca.transform(mats[t]))[:, gene_index] for t in intermediate_t]

    predicted_distributions = {t: [] for t in intermediate_t}
    for t, time_idx in zip(intermediate_t, scaled_intermediate_indices):
        if time_idx > max_i:
            continue
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            continue
        gene_expression_values = pca.inverse_transform(X1_trpt)[:, gene_index]
        predicted_distributions[t].extend(gene_expression_values)

    kde_predicted_data = [np.array(predicted_distributions[t]) for t in intermediate_t]

    metric_results = []
    for time_point, test_vals, predicted_vals in zip(intermediate_t, kde_test_data, kde_predicted_data):
        # A time point skipped above has no predictions; none of the tests is defined for it.
        if len(test_vals) == 0 or len(predicted_vals) == 0:
            print(f"Skipping time {time_point} for {gene_of_interest}: no values to compare")
            continue

        test_vals, predicted_vals = match_sample_sizes(test_vals, predicted_vals)

        ks_stat, ks_pval = ks_2samp(test_vals, predicted_vals)

        # Bin both samples on common edges; binning each sample on its own range
        # would make the distances insensitive to shifts and rescalings.
        bin_edges = np.histogram_bin_edges(np.concatenate([test_vals, predicted_vals]), bins=50)

        # permutation_test returns the observed statistic first, so every reported
        # distance and its permutation p-value come from the same statistic.
        tv, mean_perm_tv, std_perm_tv, p_tv = permutation_test(
            _binned_statistic(total_variation_distance, bin_edges),
            test_vals, predicted_vals, num_permutations)
        kl, mean_perm_kl, std_perm_kl, p_kl = permutation_test(
            _binned_statistic(kl_divergence, bin_edges),
            test_vals, predicted_vals, num_permutations)
        # scipy's jensenshannon returns the Jensen-Shannon distance
        # (square root of the divergence); the column name is kept unchanged.
        js, mean_perm_js, std_perm_js, p_js = permutation_test(
            _binned_statistic(jensenshannon, bin_edges),
            test_vals, predicted_vals, num_permutations)
        # NOTE(review): scipy.stats.wasserstein_distance is the first Wasserstein
        # distance (W1); the "W2" column names are kept because other cells read them.
        w2, mean_perm_w2, std_perm_w2, p_w2 = permutation_test(
            wasserstein_distance, test_vals, predicted_vals, num_permutations)
        sink, mean_perm_sink, std_perm_sink, p_sink = permutation_test(
            sinkhorn_divergence, test_vals, predicted_vals, num_permutations)

        mean_test = np.mean(test_vals)
        mean_pred = np.mean(predicted_vals)
        std_test = np.std(test_vals)
        std_pred = np.std(predicted_vals)
        mean_diff = mean_test - mean_pred
        std_diff = std_test - std_pred
        t_stat, t_pval = ttest_ind(test_vals, predicted_vals, equal_var=False)

        metric_results.append({
            "Time": time_point,
            "Gene": gene_of_interest,
            "KS Test Statistic": ks_stat,
            "KS p-value": ks_pval,
            "W2 Distance": w2,
            "Permutation W2 ± Std": f"{mean_perm_w2:.4f} ± {std_perm_w2:.4f}",
            "p-value W2": p_w2,
            "Sinkhorn Distance": sink,
            "Permutation Sinkhorn ± Std": f"{mean_perm_sink:.4f} ± {std_perm_sink:.4f}",
            "p-value Sinkhorn": p_sink,
            "Total Variation Distance": tv,
            "Permutation TV ± Std": f"{mean_perm_tv:.4f} ± {std_perm_tv:.4f}",
            "p-value TV": p_tv,
            "KL Divergence": kl,
            "Permutation KL ± Std": f"{mean_perm_kl:.4f} ± {std_perm_kl:.4f}",
            "p-value KL": p_kl,
            "Jensen-Shannon Divergence": js,
            "Permutation JS ± Std": f"{mean_perm_js:.4f} ± {std_perm_js:.4f}",
            "p-value JS": p_js,
            # The per-sample summaries are exported because the T-test metric
            # export in the driver cell selects the "Mean Test" column.
            "Mean Test": mean_test,
            "Mean Predicted": mean_pred,
            "Mean Difference": mean_diff,
            "Standard Deviation Test": std_test,
            "Standard Deviation Predicted": std_pred,
            "Standard Deviation Difference": std_diff,
            "T-test Statistic": t_stat,
            "T-test p-value": t_pval
        })

    if save_csv:
        output_dir = os.path.join(result_dir, 'output', exp_memo)
        os.makedirs(output_dir, exist_ok=True)
        df_results = pd.DataFrame(metric_results)
        csv_path = os.path.join(output_dir, f"statistical_metrics_{gene_of_interest}.csv")
        df_results.to_csv(csv_path, index=False)
        print(f"✅ Results saved to: {csv_path}")

    return metric_results




In [None]:
## CSV files for all the genes - Stem Cell data

import os
import numpy as np
import pandas as pd

# Run configuration
genes_of_interest = gene_names  # genes to evaluate
source_t, target_t = 0, 4
optimal_k = 2
index = 1
max_i = 200
intermediate_t = [1, 3]  # intermediate time points
d_red = 2
random_state = 40
exp_memo = 'EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64'
result_dir = '%s/assets/Transport_genes/' % main_dir
output_dir = os.path.join(result_dir, 'output', exp_memo)

# Make sure the output folder is there before anything is written
os.makedirs(output_dir, exist_ok=True)

# Decision column -> p-value column it is derived from (order = column order in the CSV)
reject_columns = {
    "Reject Null Hypothesis (TV)": "p-value TV",
    "Reject Null Hypothesis (KL)": "p-value KL",
    "Reject Null Hypothesis (T-test)": "T-test p-value",
    "Reject Null Hypothesis (W2)": "p-value W2",
    "Reject Null Hypothesis (Sinkhorn)": "p-value Sinkhorn",
}

# One DataFrame per successfully processed gene
all_gene_results = []

for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        # Per-gene CSV output is disabled; everything is written once at the end
        results = Compare_Distribution_Statistics(
            source_t, target_t, optimal_k, gene, index, max_i,
            intermediate_t=intermediate_t, d_red=d_red,
            random_state=random_state, exp_memo=exp_memo, num_permutations=100,
            save_csv=False
        )

        df_results = pd.DataFrame(results)
        df_results["Gene"] = gene

        # "No" when the p-value is above 0.05, otherwise "Yes"
        for reject_col, pval_col in reject_columns.items():
            df_results[reject_col] = df_results[pval_col].apply(lambda p: "No" if p > 0.05 else "Yes")

        all_gene_results.append(df_results)

    except Exception as e:
        # A failing gene is reported and skipped so the remaining genes still run
        print(f"Error processing gene {gene}: {e}")


def _genes_not_rejecting(mask):
    """Return the unique genes of combined_df whose rows satisfy the boolean mask."""
    return combined_df[mask]["Gene"].unique()


if not all_gene_results:
    print("No valid results generated.")
else:
    combined_df = pd.concat(all_gene_results, ignore_index=True)

    # Full table for all genes
    combined_csv_path = os.path.join(output_dir, "all_genes_statistical_metrics.csv")
    combined_df.to_csv(combined_csv_path, index=False)
    print(f"All genes' results saved to {combined_csv_path}")

    # Column subsets written to one CSV per metric
    metrics = {
        "TV": ["Gene", "Time", "Total Variation Distance", "Permutation TV ± Std", "p-value TV", "Reject Null Hypothesis (TV)"],
        "KL": ["Gene", "Time", "KL Divergence", "Permutation KL ± Std", "p-value KL", "Reject Null Hypothesis (KL)"],
        "W2": ["Gene", "Time", "W2 Distance", "Permutation W2 ± Std", "p-value W2", "Reject Null Hypothesis (W2)"],
        "Sinkhorn": ["Gene", "Time", "Sinkhorn Distance", "Permutation Sinkhorn ± Std", "p-value Sinkhorn", "Reject Null Hypothesis (Sinkhorn)"],
        "T-test": ["Gene", "Time", "Mean Test", "T-test Statistic", "T-test p-value", "Reject Null Hypothesis (T-test)"],
    }

    for metric, cols in metrics.items():
        # A metric is exported only when every one of its columns is present
        if any(col not in combined_df.columns for col in cols):
            print(f"Warning: Some columns missing for {metric} metric.")
            continue
        metric_df = combined_df[cols]
        metric_csv_path = os.path.join(output_dir, f"{metric}_metrics.csv")
        metric_df.to_csv(metric_csv_path, index=False)
        print(f"{metric} results saved to {metric_csv_path}")

    # Report the genes for which each test did not reject the null hypothesis

    # Total variation
    genes_no_TV = _genes_not_rejecting(combined_df["Reject Null Hypothesis (TV)"] == "No")
    print("\nGenes that did NOT reject null hypothesis in TV:", genes_no_TV)

    # KL divergence
    genes_no_KL = _genes_not_rejecting(combined_df["Reject Null Hypothesis (KL)"] == "No")
    print("\nGenes that did NOT reject null hypothesis in KL:", genes_no_KL)

    # W2 distance
    genes_no_W2 = _genes_not_rejecting(combined_df["Reject Null Hypothesis (W2)"] == "No")
    print("\nGenes that did NOT reject null hypothesis in W2:", genes_no_W2)

    # Sinkhorn distance
    genes_no_Sinkhorn = _genes_not_rejecting(combined_df["Reject Null Hypothesis (Sinkhorn)"] == "No")
    print("\nGenes that did NOT reject null hypothesis in Sinkhorn:", genes_no_Sinkhorn)

    # Rows where TV and KL both fail to reject
    genes_no_TV_KL = _genes_not_rejecting(
        (combined_df["Reject Null Hypothesis (TV)"] == "No")
        & (combined_df["Reject Null Hypothesis (KL)"] == "No")
    )
    print("\nGenes that did NOT reject null hypothesis in BOTH TV and KL:", genes_no_TV_KL)



In [None]:
## Venn Diagram of null hypothesis results for Stem cell data
import os
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib_venn import venn3

exp_memo = "EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64"

# Folder that holds the per-metric CSV files and receives the summaries
output_dir = os.path.join(result_dir, 'output', exp_memo)

# Per-metric tables written earlier
tv_df = pd.read_csv(os.path.join(output_dir, "TV_metrics.csv"))
kl_df = pd.read_csv(os.path.join(output_dir, "KL_metrics.csv"))
sinkhorn_df = pd.read_csv(os.path.join(output_dir, "Sinkhorn_metrics.csv"))

# Decision column of each table
tv_col = "Reject Null Hypothesis (TV)"
kl_col = "Reject Null Hypothesis (KL)"
sinkhorn_col = "Reject Null Hypothesis (Sinkhorn)"

# Time points are taken from the TV table
time_points = sorted(tv_df["Time"].unique())


def _non_rejecting_genes(frame, decision_col):
    """Return the set of genes in `frame` whose decision column equals "No"."""
    return set(frame.loc[frame[decision_col] == "No", "Gene"].unique())


for time_point in time_points:
    # Rows of each table that belong to this time point
    tv_time_df = tv_df[tv_df["Time"] == time_point]
    kl_time_df = kl_df[kl_df["Time"] == time_point]
    sinkhorn_time_df = sinkhorn_df[sinkhorn_df["Time"] == time_point]

    # Genes for which the null hypothesis was not rejected, per metric
    genes_no_tv = _non_rejecting_genes(tv_time_df, tv_col)
    genes_no_kl = _non_rejecting_genes(kl_time_df, kl_col)
    genes_no_sinkhorn = _non_rejecting_genes(sinkhorn_time_df, sinkhorn_col)

    # Not rejected by at least one / at least two / all three metrics
    genes_at_least_one_no = genes_no_tv | genes_no_kl | genes_no_sinkhorn
    genes_at_least_two_no = (
        (genes_no_tv & genes_no_kl)
        | (genes_no_tv & genes_no_sinkhorn)
        | (genes_no_kl & genes_no_sinkhorn)
    )
    genes_all_three_no = genes_no_tv & genes_no_kl & genes_no_sinkhorn

    # Two-column summary: criterion and comma-separated, sorted gene names
    summary_rows = [
        ("At least one metric (TV, KL, or Sinkhorn) showing No", genes_at_least_one_no),
        ("At least two metrics (TV, KL, or Sinkhorn) showing No", genes_at_least_two_no),
        ("All three metrics (TV, KL, and Sinkhorn) showing No", genes_all_three_no),
    ]
    summary_df = pd.DataFrame({
        "Criteria": [criterion for criterion, _ in summary_rows],
        "Genes": [", ".join(sorted(gene_set)) for _, gene_set in summary_rows],
    })

    print(f"\n🔹 Summary of Genes by Null Hypothesis Rejection (Time {time_point}):")
    print(summary_df.to_string(index=False))

    summary_csv_path = os.path.join(output_dir, f"genes_null_hypothesis_summary_time_{time_point}.csv")
    summary_df.to_csv(summary_csv_path, index=False)
    print(f"✅ Summary for Time {time_point} saved to {summary_csv_path}")

    # Venn diagram of the three gene sets
    plt.figure(figsize=(10, 8))
    venn = venn3(
        [genes_no_tv, genes_no_kl, genes_no_sinkhorn],
        set_labels=('TV', 'KL', 'Sinkhorn')
    )

    plt.title(f"Genes NOT Rejecting Null Hypothesis (Time {time_point})", fontsize=16)

    # Count annotations, stacked downwards from (x_pos, y_pos)
    x_pos = 0.6
    y_pos = 0.6
    step = 0.07
    captions = [
        f"Genes in ≥1 metric: {len(genes_at_least_one_no)}",
        f"Genes in ≥2 metrics: {len(genes_at_least_two_no)}",
        f"Genes in all 3 metrics: {len(genes_all_three_no)}",
    ]
    for row, caption in enumerate(captions):
        plt.text(x_pos, y_pos - row * step, caption, fontsize=12)

    plt.tight_layout()

    # The figure is written to disk before it is displayed
    venn_path = os.path.join(output_dir, f"genes_venn_diagram_time_{time_point}.png")
    plt.savefig(venn_path, dpi=300)
    plt.show()

    print(f"✅ Venn diagram for Time {time_point} saved to {venn_path}")


In [None]:
## CSV files for all the genes - Stem Cell data

import os
import numpy as np
import pandas as pd

# Run configuration
genes_of_interest = ['DSP']  # genes to evaluate
source_t, target_t = 0, 4
optimal_k = 2
index = 1
max_i = 200
intermediate_t = [2]  # intermediate time points
d_red = 8
random_state = 40
exp_memo = '72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64'
result_dir = '%s/assets/Transport_genes/' % main_dir
output_dir = os.path.join(result_dir, 'output', exp_memo)

# Make sure the output folder is there before anything is written
os.makedirs(output_dir, exist_ok=True)

# Decision column -> p-value column it is derived from (order = column order in the CSV)
reject_columns = {
    "Reject Null Hypothesis (TV)": "p-value TV",
    "Reject Null Hypothesis (KL)": "p-value KL",
    "Reject Null Hypothesis (T-test)": "T-test p-value",
    "Reject Null Hypothesis (W2)": "p-value W2",
    "Reject Null Hypothesis (Sinkhorn)": "p-value Sinkhorn",
}

# One DataFrame per successfully processed gene
all_gene_results = []

for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        # Per-gene CSV output is disabled; everything is written once at the end
        results = Compare_Distribution_Statistics(
            source_t, target_t, optimal_k, gene, index, max_i,
            intermediate_t=intermediate_t, d_red=d_red,
            random_state=random_state, exp_memo=exp_memo, num_permutations=100,
            save_csv=False
        )

        df_results = pd.DataFrame(results)
        df_results["Gene"] = gene

        # "No" when the p-value is above 0.05, otherwise "Yes"
        for reject_col, pval_col in reject_columns.items():
            df_results[reject_col] = df_results[pval_col].apply(lambda p: "No" if p > 0.05 else "Yes")

        all_gene_results.append(df_results)

    except Exception as e:
        # A failing gene is reported and skipped so the remaining genes still run
        print(f"Error processing gene {gene}: {e}")


def _genes_not_rejecting(mask):
    """Return the unique genes of combined_df whose rows satisfy the boolean mask."""
    return combined_df[mask]["Gene"].unique()


if not all_gene_results:
    print("No valid results generated.")
else:
    combined_df = pd.concat(all_gene_results, ignore_index=True)

    # Full table for all genes
    combined_csv_path = os.path.join(output_dir, "all_genes_statistical_metrics.csv")
    combined_df.to_csv(combined_csv_path, index=False)
    print(f"All genes' results saved to {combined_csv_path}")

    # Column subsets written to one CSV per metric
    metrics = {
        "TV": ["Gene", "Time", "Total Variation Distance", "Permutation TV ± Std", "p-value TV", "Reject Null Hypothesis (TV)"],
        "KL": ["Gene", "Time", "KL Divergence", "Permutation KL ± Std", "p-value KL", "Reject Null Hypothesis (KL)"],
        "W2": ["Gene", "Time", "W2 Distance", "Permutation W2 ± Std", "p-value W2", "Reject Null Hypothesis (W2)"],
        "Sinkhorn": ["Gene", "Time", "Sinkhorn Distance", "Permutation Sinkhorn ± Std", "p-value Sinkhorn", "Reject Null Hypothesis (Sinkhorn)"],
        "T-test": ["Gene", "Time", "Mean Test", "T-test Statistic", "T-test p-value", "Reject Null Hypothesis (T-test)"],
    }

    for metric, cols in metrics.items():
        # A metric is exported only when every one of its columns is present
        if any(col not in combined_df.columns for col in cols):
            print(f"Warning: Some columns missing for {metric} metric.")
            continue
        metric_df = combined_df[cols]
        metric_csv_path = os.path.join(output_dir, f"{metric}_metrics.csv")
        metric_df.to_csv(metric_csv_path, index=False)
        print(f"{metric} results saved to {metric_csv_path}")

    # Report the genes for which each test did not reject the null hypothesis

    # Total variation
    genes_no_TV = _genes_not_rejecting(combined_df["Reject Null Hypothesis (TV)"] == "No")
    print("\nGenes that did NOT reject null hypothesis in TV:", genes_no_TV)

    # KL divergence
    genes_no_KL = _genes_not_rejecting(combined_df["Reject Null Hypothesis (KL)"] == "No")
    print("\nGenes that did NOT reject null hypothesis in KL:", genes_no_KL)

    # W2 distance
    genes_no_W2 = _genes_not_rejecting(combined_df["Reject Null Hypothesis (W2)"] == "No")
    print("\nGenes that did NOT reject null hypothesis in W2:", genes_no_W2)

    # Sinkhorn distance
    genes_no_Sinkhorn = _genes_not_rejecting(combined_df["Reject Null Hypothesis (Sinkhorn)"] == "No")
    print("\nGenes that did NOT reject null hypothesis in Sinkhorn:", genes_no_Sinkhorn)

    # Rows where TV and KL both fail to reject
    genes_no_TV_KL = _genes_not_rejecting(
        (combined_df["Reject Null Hypothesis (TV)"] == "No")
        & (combined_df["Reject Null Hypothesis (KL)"] == "No")
    )
    print("\nGenes that did NOT reject null hypothesis in BOTH TV and KL:", genes_no_TV_KL)



In [None]:
## Combining genes (preprocess)
import os
import pandas as pd
import glob

result_dir = '%s/assets/Transport_genes/' % main_dir
csv_dir = os.path.join(result_dir, 'Sample 3')


# File-name prefixes whose CSV files are merged
prefixes = ['sinkhorn']

# Mapping prefixes to their corresponding p-value column names
pval_columns = {
    'sinkhorn': 'p_Sinkhorn_1',
}


# Merge all CSV files of each prefix into one file with a decision column
for prefix in prefixes:
    output_filename = os.path.join(csv_dir, f"{prefix}_metric.csv")

    # The merged file is written into csv_dir under a name that matches the same
    # glob pattern, so it is excluded here; otherwise re-running this cell would
    # read the previous output back in and duplicate every row.
    # Sorting makes the row order independent of the file-system listing order.
    matching_files = sorted(
        path for path in glob.glob(os.path.join(csv_dir, f"{prefix}*.csv"))
        if os.path.abspath(path) != os.path.abspath(output_filename)
    )

    if not matching_files:
        print(f"No files found for prefix '{prefix}'")
        continue

    # Read and concatenate files
    combined_df = pd.concat([pd.read_csv(f) for f in matching_files], ignore_index=True)

    # Add the Yes/No decision when the p-value column is present
    p_col = pval_columns[prefix]
    if p_col in combined_df.columns:
        # NOTE(review): prefix[:-1] drops the last character, so 'sinkhorn' yields the
        # label "SINKHOR". A later cell reads the column under exactly that name, so it
        # is left unchanged here; rename both places together if this is corrected.
        combined_df[f'Reject Null Hypothesis ({prefix[:-1].upper()})'] = combined_df[p_col].apply(
            lambda p: 'No' if p > 0.05 else 'Yes'
        )
    else:
        print(f"⚠️ Warning: Column '{p_col}' not found in files for prefix '{prefix}'.")

    # Save the combined dataframe to CSV
    combined_df.to_csv(output_filename, index=False)

    print(f"✅ Combined {len(matching_files)} files into '{output_filename}' with hypothesis test results.")


In [None]:
## Sample 1 (only one time point)

import os
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib_venn import venn3

# Folder that holds the per-metric CSV files and receives the summary
output_dir = os.path.join(result_dir, 'output', exp_memo)

# Per-metric tables
tv_df = pd.read_csv(os.path.join(output_dir, "TV_metrics.csv"))
kl_df = pd.read_csv(os.path.join(output_dir, "KL_metrics.csv"))
sinkhorn_df = pd.read_csv(os.path.join(output_dir, "sinkhorn_metrics.csv"))


def _non_rejecting_genes(frame, decision_col):
    """Return the set of genes in `frame` whose decision column equals "No"."""
    return set(frame.loc[frame[decision_col] == "No", "Gene"].unique())


# Genes for which the null hypothesis was not rejected, per metric
genes_no_tv = _non_rejecting_genes(tv_df, "Reject Null Hypothesis (TV)")
genes_no_kl = _non_rejecting_genes(kl_df, "Reject Null Hypothesis (KL)")
genes_no_sinkhorn = _non_rejecting_genes(sinkhorn_df, "Reject Null Hypothesis (SINKHOR)")

# Not rejected by at least one / at least two / all three metrics
genes_at_least_one_no = genes_no_tv | genes_no_kl | genes_no_sinkhorn
genes_at_least_two_no = (
    (genes_no_tv & genes_no_kl)
    | (genes_no_tv & genes_no_sinkhorn)
    | (genes_no_kl & genes_no_sinkhorn)
)
genes_all_three_no = genes_no_tv & genes_no_kl & genes_no_sinkhorn

# Two-column summary: criterion and comma-separated, sorted gene names
summary_rows = [
    ("At least one metric (TV, KL, or Sinkhorn) showing No", genes_at_least_one_no),
    ("At least two metrics (TV, KL, or Sinkhorn) showing No", genes_at_least_two_no),
    ("All three metrics (TV, KL, and Sinkhorn) showing No", genes_all_three_no),
]
summary_df = pd.DataFrame({
    "Criteria": [criterion for criterion, _ in summary_rows],
    "Genes": [", ".join(sorted(gene_set)) for _, gene_set in summary_rows],
})

print("\nSummary of Genes by Null Hypothesis Rejection:")
print(summary_df.to_string(index=False))

summary_csv_path = os.path.join(output_dir, "genes_null_hypothesis_summary.csv")
summary_df.to_csv(summary_csv_path, index=False)
print(f"\nSummary saved to {summary_csv_path}")


# Venn diagram of the three gene sets
plt.figure(figsize=(12, 8))
venn = venn3(
    [genes_no_tv, genes_no_kl, genes_no_sinkhorn],
    set_labels=('TV', 'KL', 'SINKHORN')
)

plt.title("Genes NOT Rejecting Null Hypothesis", fontsize=16)

# Count annotations, stacked downwards from (x_pos, y_pos)
x_pos = 0.6
y_pos = 0.6
step = 0.07
captions = [
    f"Genes in at least 1 metric: {len(genes_at_least_one_no)}",
    f"Genes in at least 2 metrics: {len(genes_at_least_two_no)}",
    f"Genes in all 3 metrics: {len(genes_all_three_no)}",
]
for row, caption in enumerate(captions):
    plt.text(x_pos, y_pos - row * step, caption, fontsize=12)

plt.tight_layout()

# The figure is written to disk before it is displayed
venn_path = os.path.join(output_dir, "genes_venn_diagram.png")
plt.savefig(venn_path, dpi=300)
plt.show()

print(f"✅ Venn diagram saved to {venn_path}")



In [None]:
## Showing distribution plots based on the gene sets which have similar distributions

## PDF combination for comparison distributions 
## save the gene expression dynamics png as pdf


import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from math import ceil
import matplotlib.lines as mlines


def create_pdf_from_gene_images(output_dir, exp_memo, gene_list, pdf_path, images_per_page=25, grid_size=(5, 5)):
    """
    Create a PDF with gene expression PNG images arranged in a grid layout.

    For every gene in ``gene_list`` the file
    ``{output_dir}/KDE_Intermediate_Only_{gene}.png`` is looked up. Files that
    do not exist are reported with a warning and skipped; the remaining images
    are placed, in ``gene_list`` order, on consecutive pages. Every image
    appears exactly once.

    Parameters:
        output_dir (str): Directory containing the PNG files.
        exp_memo (str): Experiment label. Accepted for backward compatibility
            only; it is not part of the PNG filenames and is not used.
        gene_list (list): List of genes corresponding to the PNG files.
        pdf_path (str): Path to save the output PDF file.
        images_per_page (int): Number of images per page (default: 25). Capped
            at ``rows * cols`` so that no image is dropped when the grid is
            smaller than this value.
        grid_size (tuple): Grid size (rows, cols) for each page (default: 5x5).

    Raises:
        ValueError: If ``images_per_page`` or either grid dimension is < 1.
    """
    rows, cols = grid_size
    if images_per_page < 1 or rows < 1 or cols < 1:
        raise ValueError(
            "images_per_page and both grid_size dimensions must be >= 1, "
            f"got images_per_page={images_per_page}, grid_size={grid_size}"
        )

    # A page can never hold more images than it has grid cells.
    per_page = min(images_per_page, rows * cols)

    # Split the expected PNG paths into existing and missing files.
    png_files = []
    missing_files = []
    for gene in gene_list:
        png_path = f"{output_dir}/KDE_Intermediate_Only_{gene}.png"
        if os.path.exists(png_path):
            png_files.append(png_path)
        else:
            missing_files.append(png_path)

    if missing_files:
        print(f"Warning: The following files are missing and will be skipped:\n{missing_files}")

    # Create the PDF, one figure per page.
    with PdfPages(pdf_path) as pdf:
        for start_idx in range(0, len(png_files), per_page):
            # Only this slice belongs on the current page; stepping and
            # slicing by the same amount prevents repeats and gaps.
            page_files = png_files[start_idx:start_idx + per_page]

            # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
            fig, axes = plt.subplots(rows, cols, figsize=(15, 15), squeeze=False)
            try:
                for cell_idx, ax in enumerate(axes.flat):
                    if cell_idx < len(page_files):
                        img = plt.imread(page_files[cell_idx])
                        # aspect='auto' stretches the image to fill its cell
                        # (the aspect ratio is NOT preserved).
                        ax.imshow(img, aspect='auto')
                    # Hide ticks/frame for image cells and unused cells alike.
                    ax.axis('off')

                # Save the page to the PDF at 300 dpi.
                pdf.savefig(fig, dpi=300, bbox_inches='tight')
            finally:
                # Always release the figure, even if reading/saving failed.
                plt.close(fig)

    print(f"✅ PDF saved to {pdf_path} with original image resolution.")

# Example usage

exp_memo = "72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64"

# Genes whose plots go into the PDF
gene_list = [
    'DSP', 'ENPP5', 'EPB41L5', 'KRTCAP3',
    'MMP2', 'RAB25', 'SERINC2', 'TMEM45B',
]
## Selected genes (no difference by TV)  ['AXL', 'HNMT', 'TMEM45B', 'SSH3', 'SHROOM3', 'PRSS22', 'SERINC2', 'EVPL', 'GALNT3', 'DSP', 'ELMO3', 'KRTCAP3', 'KRT19', 'C1orf116', 'CDS1', 'INADL']
## Selected genes (no difference by TV and KL) ['HNMT', 'TMEM45B', 'SHROOM3', 'PRSS22', 'SERINC2', 'KRTCAP3', 'C1orf116', 'CDS1']

# Output PDF path
pdf_path = f"{output_dir}/selected_gene_expression_KDE_Intermediate_Only(two metrics).pdf"

# Six images per page, laid out on a 3x2 grid
create_pdf_from_gene_images(
    output_dir=output_dir,
    exp_memo=exp_memo,
    gene_list=gene_list,
    pdf_path=pdf_path,
    images_per_page=6,
    grid_size=(3, 2),
)
