# GPA for transportation of gene expression datasets

This notebook aims to apply GPA (our baseline model) to recover the trajectory of an empirical (EMT) gene expression dataset with unknown dynamics.
We first reduce the dimensionality of the original data (175) to an accessible dimension through PCA, then apply GPA in the latent space, and reconstruct the result to the original high-dimensional space.

## Dataset

* 6 snapshots at different timepoints (day 0, 1, 2, 3, 4, 8)
* Sample size: differs by day
* dimension: 175
* The trajectory will be visualized in 2D principal component axes

## Functionality

* Given source and target days, GPA transports the source dataset toward the target dataset via deterministic particle dynamics that follow the gradient flow of the Lipschitz-regularized (i.e., limited transport speed) KL divergence.
$$D_{KL}^L(P\|Q) = \sup_{\| \nabla \phi \| \leq L} \left \{\mathbb{E}_P[\phi] - \log \mathbb{E}_Q [\exp(\phi)] \right \}$$
$$\partial_t P + \nabla \cdot \left(P v\right) = \partial_t P - \nabla \cdot \left(P \nabla \phi \right) = 0$$
$$\dot{X} = - \nabla \phi_t(X)$$

* By default, the snapshots at all intermediate time points are highlighted, but the days to highlight can be specified.
* $W_2$ distance will be calculated between the particle trajectory and the snapshots.


In [None]:
import tensorflow as tf
# Directory (under main_dir) where this notebook writes its output, e.g. the legend image.
result_dir = f'{main_dir}/assets/Transport_genes/'

def load_W(filename):
    """Read pickled network parameters and wrap them as TensorFlow variables.

    The pickle at ``filename`` must unpack into three items ``(W, b, p)``.
    Every entry of ``W`` and of ``b`` is turned into a float32
    ``tf.Variable``; ``p`` is handed back exactly as it was stored.

    Returns:
        A tuple ``(W, b, p)`` where ``W`` and ``b`` are lists of variables.

    NOTE(review): unpickling can execute arbitrary code -- only load files
    from a trusted source.
    """
    def _as_float32_variables(arrays):
        # One trainable float32 variable per stored array, order preserved.
        return [tf.Variable(arr, dtype=tf.float32) for arr in arrays]

    with open(filename, "rb") as handle:
        raw_weights, raw_biases, p = pk.load(handle)

    return _as_float32_variables(raw_weights), _as_float32_variables(raw_biases), p

def v(x, t, W, b):
    """Evaluate the neural network for the time-dependent vector field.

    The scalar time ``t`` is broadcast into one extra column and appended to
    ``x`` (so the first weight matrix must accept one more input feature than
    ``x`` has columns). All layers except the last apply ``tanh``; the last
    layer is affine only.

    Args:
        x: 2-D float32 tensor, one row per particle.
        t: time value multiplied into a column of ones.
        W: list of weight matrices, one per layer.
        b: list of bias vectors, one per layer.

    Returns:
        The output of the final affine layer.
    """
    n_rows = x.shape[0]
    time_column = t * tf.ones([n_rows, 1], dtype=tf.float32)
    hidden = tf.concat([x, time_column], axis=1)

    # Hidden layers: affine map followed by tanh.
    for layer in range(len(W) - 1):
        pre_activation = tf.matmul(hidden, W[layer]) + b[layer]
        hidden = tf.nn.tanh(pre_activation)

    # Output layer: affine map only.
    return tf.matmul(hidden, W[-1]) + b[-1]

def time_integration(x0, T, dt):
    """Integrate particles along the learned vector field with explicit Euler steps.

    Uses the module-level network parameters ``W`` and ``b`` together with
    ``v`` to advance ``x0`` from time 0 to (approximately) ``T``.

    Args:
        x0: initial particle positions; converted to a float32 tensor.
        T: total integration time.
        dt: Euler step size.

    Returns:
        A list whose first entry is ``x0`` exactly as passed in, followed by
        one NumPy array per Euler step.
    """
    # int(T / dt) alone silently drops the last step when the quotient lands
    # just below an integer because of floating-point rounding
    # (e.g. 0.3 / 0.1 == 2.9999999999999996). Snap to the nearest integer when
    # the quotient is within a tiny relative tolerance, otherwise truncate as
    # before.
    ratio = T / dt
    nearest = round(ratio)
    if abs(ratio - nearest) <= 1e-9 * max(1.0, abs(ratio)):
        n_steps = int(nearest)
    else:
        n_steps = int(ratio)

    x = tf.constant(x0, dtype=tf.float32)
    xs = [x0]
    for i in range(n_steps):
        # The field is evaluated at the start of the step (time dt * i).
        x = x + dt * v(x, dt * i, W, b)
        xs.append(x.numpy())
    return xs


def generate_animation(days, intermediate_days, X1_trpts, dt, physical_dt, img_src, d_red = 2, vs = None):
    """Animate the transported particles in two latent axes and save the result as a GIF.

    Every element of ``X1_trpts`` becomes one frame; the loop stops at the
    first frame that contains a NaN. The finished animation is written to
    ``img_src`` and then displayed.

    Reads the module-level names ``dim_red_method``, ``data_dir`` and ``mats``.

    Args:
        days: day keys; ``days[0]`` is the source (used for the particle colour
            and label) and ``days[1:]`` are drawn as target clouds. Each key
            must be one of 0, 1, 2, 3, 4, 8 (the days that have a colour).
        intermediate_days: day keys drawn as light-gray background clouds.
        X1_trpts: sequence of 2-D arrays; columns 0 and 1 are plotted.
        dt: step between consecutive frames, used for the finite-difference
            velocities of the quiver frames.
        physical_dt: time per frame shown in the frame title.
        img_src: output path of the GIF.
        d_red: PCA dimension, used only to pick the pickled PCA file.
        vs: only compared against ``None``. When it is not ``None``, every
            frame except the last is drawn as a quiver plot of
            ``(next frame - current frame) / dt``; the contents of ``vs`` are
            not read.

    Returns:
        None. If ``dim_red_method`` is not supported, a message is printed and
        the function returns before doing any work.
    """
    # Resolve the pickled PCA mapping for the active reduction method.
    if dim_red_method == 'EMT_PCA':
        pca_filename = "emt_pca_%d.pkl" % d_red
    elif dim_red_method == 'PCA':
        pca_filename = "pca_%d.pkl" % d_red
    else:
        print("PCA mapping for the reduction method and dimension is not available")
        # Without this return the code below would fail on an undefined pca_filename.
        return

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    contrast_colors = [
        '#1f77b4',  # blue
        '#2ca02c',  # green
        '#ff7f0e',  # orange
        '#8c564b',  # brown
        '#d62728',  # red
        '#9467bd',  # purple (used for day 8)
    ]
    # Fixed colour per measured day.
    colors = dict(zip((0, 1, 2, 3, 4, 8), contrast_colors))

    # The reference clouds do not change between frames, so project them once.
    target_clouds = [(t, pca.transform(mats[t])) for t in days[1:]]
    intermediate_clouds = [(t, pca.transform(mats[t])) for t in intermediate_days]

    fig, ax = plt.subplots()
    ims = []

    draw_arrows = vs is not None
    last_index = len(X1_trpts) - 1

    for i, X1_trpt in enumerate(X1_trpts):  # one frame per saved step
        if np.isnan(X1_trpt).any():
            break

        if draw_arrows and i < last_index:
            # Finite-difference velocity towards the next saved step.
            vs_vis = (X1_trpts[i + 1] - X1_trpt) / dt
            im = ax.quiver(X1_trpt[:, 0], X1_trpt[:, 1], vs_vis[:, 0], vs_vis[:, 1],
                           width=0.003, headwidth=7, headlength=15, headaxislength=7, zorder=15)
        else:
            im = ax.scatter(X1_trpt[:, 0], X1_trpt[:, 1], color=colors[days[0]],
                            alpha=1.0, s=0.7, zorder=10, label=f'day {days[0]}')  # transported source
            # The reference clouds are added to the axes but not to the frame's
            # artist list (only `im` and the title are appended below).
            for t, X2_vis in target_clouds:
                ax.scatter(X2_vis[:, 0], X2_vis[:, 1], color=colors[t],
                           alpha=1.0, s=0.7, zorder=5, label=f'day {t}')  # target
            for t, X1_intermediate_vis in intermediate_clouds:
                ax.scatter(X1_intermediate_vis[:, 0], X1_intermediate_vis[:, 1], color='lightgray',
                           alpha=0.3, s=0.7, zorder=1, label=f'day {t}')

        ttl = ax.text(0.5, 1.05, "t = %.3f" % (physical_dt*i),
                      bbox={'facecolor':'w', 'alpha':0.5, 'pad':5},
                      transform=ax.transAxes, ha="center")
        ims.append([im, ttl])

    ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=200)
    writergif = animation.PillowWriter(fps=3)
    ani.save(img_src, writer=writergif)
    plt.clf()
    display(Image(filename = img_src))



    

In [None]:
## Static plot function for piecewise GPA (sample 3 - stem cell and sample 5 - synthetic)


## Static plot function for piecewise GPA

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import matplotlib.colors as mcolors

def generate_static_trajectory_plots_three_timepoints(days, intermediate_days, X1_trpts, mats, d_red=26, output_file_with_snapshots=None, output_file_without_snapshots=None):
    """
    Draw the static figures for a run with three training time points.

    Three outputs are produced:
    1. Training and test data plus every snapshot in X1_trpts, coloured by a
       viridis gradient, with a colour bar.
    2. Training and test data only.
    3. A legend-only image, always written to <result_dir>/legend_only.png.

    Reads the module-level names dim_red_method, data_dir, physical_dt and
    result_dir.

    Args:
        days: day keys into mats; days[0], days[1] and days[-1] are the source,
            middle and target training days.
        intermediate_days: test-day keys into mats. The first two get a colour;
            any further distinct key raises KeyError when it is plotted.
        X1_trpts: sequence of 2-D arrays; columns 0 and 1 are plotted as given.
            Snapshots containing NaN are skipped.
        mats: mapping from day key to the data that is passed to pca.transform.
        d_red: PCA dimension, used only to pick the pickled PCA file.
        output_file_with_snapshots: path for figure 1; when falsy the figure is
            shown instead of saved.
        output_file_without_snapshots: path for figure 2; when falsy the figure
            is shown instead of saved.

    Returns:
        None. If dim_red_method is not supported, a message is printed and
        nothing is drawn.
    """
    # Function-scope import: the legend path below needs os, which this cell
    # does not import on its own.
    import os

    # Load PCA
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        print("PCA mapping for the reduction method and dimension is not available")
        return

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    # Colour gradient over the snapshots and a matching mappable for the colour bar
    num_snapshots = len(X1_trpts)
    colormap = cm.viridis  # Can change to "plasma", "inferno", etc.
    snapshot_colors = [colormap(i / num_snapshots) for i in range(num_snapshots)]

    time_values = np.linspace(0, physical_dt * num_snapshots, num_snapshots)
    norm = mcolors.Normalize(vmin=time_values.min(), vmax=time_values.max())
    sm = cm.ScalarMappable(cmap=colormap, norm=norm)
    sm.set_array([])  # Needed for color bar

    source_t, middle_t, target_t = days[0], days[1], days[-1]

    # Colours for the individual time points
    color_map = {
        source_t: '#1f77b4',  # Blue
        intermediate_days[0]: '#2ca02c',  # Green
        middle_t: '#ff7f0e',  # Orange
        intermediate_days[1]: '#8c564b',  # Brown
        target_t: '#d62728'  # Red
    }

    # Both figures show the same reference data, so project it only once.
    training_clouds = [(t, pca.transform(mats[t])) for t in (source_t, middle_t, target_t)]
    test_clouds = [(t, pca.transform(mats[t])) for t in intermediate_days]

    def _draw_reference_points(ax):
        # Filled markers for training days, hollow markers for test days.
        for t, pts in training_clouds:
            ax.scatter(pts[:, 0], pts[:, 1], color=color_map[t], alpha=1.0, s=12, zorder=10,
                       label=f'Time {t} (Training Data)')
        for t, pts in test_clouds:
            ax.scatter(pts[:, 0], pts[:, 1], color=color_map[t], facecolors='none',
                       edgecolors=color_map[t], linewidths=1.2, alpha=1.0, s=15, zorder=20,
                       label=f'Time {t} (Test Data)')

    def _style_axes(ax):
        ax.set_xlabel("PC 1", fontsize=24)
        ax.set_ylabel("PC 2", fontsize=24)
        ax.tick_params(axis='both', which='major', labelsize=24)
        ax.set_title("")

    def _label_time(label):
        # "Time 3 (Test Data)" -> 3; labels without a parsable number sort last.
        try:
            return int(label.split(" ")[1])
        except (ValueError, IndexError):
            return float('inf')

    # **Plot 1: With Snapshots**
    fig1, ax1 = plt.subplots(figsize=(8, 6))
    _draw_reference_points(ax1)

    for snapshot_color, X1_trpt in zip(snapshot_colors, X1_trpts):
        if np.isnan(X1_trpt).any():
            continue
        ax1.scatter(X1_trpt[:, 0], X1_trpt[:, 1], color=snapshot_color, alpha=0.75, s=5, zorder=1)

    # Narrow colour bar placed just outside the right edge of the axes
    cax = ax1.inset_axes([1.02, 0.2, 0.03, 0.6])  # [x, y, width, height] (relative position)
    cbar = fig1.colorbar(sm, cax=cax)
    cbar.set_ticks(np.linspace(0, 4, 5))  # ticks at 0, 1, 2, 3, 4
    cbar.set_ticklabels([0, 1, 2, 3, 4])
    cbar.set_label("Time", fontsize=24)
    cbar.ax.tick_params(labelsize=24)

    _style_axes(ax1)

    # Save through the figure object so the right figure is written regardless
    # of which figure pyplot currently considers active.
    if output_file_with_snapshots:
        fig1.savefig(output_file_with_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITH snapshots saved to {output_file_with_snapshots}")
        plt.close(fig1)
    else:
        plt.show()

    # **Plot 2: Without Snapshots**
    fig2, ax2 = plt.subplots(figsize=(8, 6))
    _draw_reference_points(ax2)
    _style_axes(ax2)

    if output_file_without_snapshots:
        fig2.savefig(output_file_without_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITHOUT snapshots saved to {output_file_without_snapshots}")
        plt.close(fig2)
    else:
        plt.show()

    # **Legend-only figure**, entries ordered by the time value in each label.
    handles, labels = ax1.get_legend_handles_labels()
    order = sorted(range(len(labels)), key=lambda k: _label_time(labels[k]))
    sorted_labels = [labels[k] for k in order]
    sorted_handles = [handles[k] for k in order]

    fig_legend, ax_legend = plt.subplots(figsize=(10, 2))  # Adjust size as needed
    ax_legend.axis("off")  # Remove axes

    ax_legend.legend(
        sorted_handles,
        sorted_labels,
        fontsize=20,         # Font size of text
        loc='center',
        ncol=len(sorted_labels),
        markerscale=2,       # Legend markers are drawn at twice the plotted size
        handlelength=1.5,    # Length of the marker line
        handletextpad=0.2    # Padding between marker and label text
    )

    # Save the legend separately
    legend_path = os.path.join(result_dir, "legend_only.png")
    fig_legend.savefig(legend_path, bbox_inches="tight")
    plt.close(fig_legend)  # Close the legend figure

    print(f"Legend saved separately at: {legend_path}")



In [None]:
## Static plot function for simple GPA (Sample 1 - EMT data)


import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import matplotlib.colors as mcolors

def generate_static_trajectory_plots_two_timepoints(days, intermediate_days, X1_trpts, mats, d_red=26, output_file_with_snapshots=None, output_file_without_snapshots=None, output_file_snapshots_only=None):
    """
    Draw the static figures for a run with two training time points.

    Three outputs are produced:
    1. Test (intermediate) data plus every snapshot in X1_trpts, coloured by a
       viridis gradient, with a colour bar.
    2. Source and target training data only.
    3. A legend-only image built from the labels of figure 1, always written to
       <result_dir>/legend_only.png.

    Reads the module-level names dim_red_method, data_dir, physical_dt and
    result_dir.

    Args:
        days: day keys into mats; days[0] is the source and days[-1] the target.
        intermediate_days: test-day keys into mats. Only the first one gets a
            colour; any other distinct key (apart from the source or target
            day) raises KeyError when it is plotted.
        X1_trpts: sequence of 2-D arrays; columns 0 and 1 are plotted as given.
            Snapshots containing NaN are skipped.
        mats: mapping from day key to the data that is passed to pca.transform.
        d_red: PCA dimension, used only to pick the pickled PCA file.
        output_file_with_snapshots: path for figure 1; when falsy the figure is
            shown instead of saved.
        output_file_without_snapshots: path for figure 2; when falsy the figure
            is shown instead of saved.
        output_file_snapshots_only: accepted but not used by this function.

    Returns:
        None. If dim_red_method is not supported, a message is printed and
        nothing is drawn.
    """
    # Function-scope import: the legend path below needs os, which this cell
    # does not import on its own.
    import os

    # Load PCA
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        print("PCA mapping for the reduction method and dimension is not available")
        return

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    # Colour gradient over the snapshots and a matching mappable for the colour bar
    num_snapshots = len(X1_trpts)
    colormap = cm.viridis  # Can change to "plasma", "inferno", etc.
    snapshot_colors = [colormap(i / num_snapshots) for i in range(num_snapshots)]

    time_values = np.linspace(0, physical_dt * num_snapshots, num_snapshots)
    norm = mcolors.Normalize(vmin=time_values.min(), vmax=time_values.max())
    sm = cm.ScalarMappable(cmap=colormap, norm=norm)
    sm.set_array([])  # Needed for color bar

    source_t, target_t = days[0], days[-1]

    # Colours for the individual time points
    color_map = {
        source_t: '#1f77b4',  # Blue
        intermediate_days[0]: '#ff7f0e',  # Orange
        target_t: '#d62728'  # Red
    }

    def _label_time(label):
        # "Day 3 (Test Data)" -> 3; labels without a parsable number sort last.
        try:
            return int(label.split(" ")[1])
        except (ValueError, IndexError):
            return float('inf')

    # **Plot 1: With Snapshots**
    fig1, ax1 = plt.subplots(figsize=(8, 6))

    X1_vis = pca.transform(mats[source_t])
    X2_vis = pca.transform(mats[target_t])

    # Test days as hollow markers on top of the trajectory
    for t in intermediate_days:
        X_intermediate_vis = pca.transform(mats[t])
        ax1.scatter(X_intermediate_vis[:, 0], X_intermediate_vis[:, 1], color=color_map[t],
                    facecolors='none', edgecolors=color_map[t], linewidths=1.0, alpha=0.75,
                    s=10, zorder=20, label=f'Day {t} (Test Data)')

    for snapshot_color, X1_trpt in zip(snapshot_colors, X1_trpts):
        if np.isnan(X1_trpt).any():
            continue
        ax1.scatter(X1_trpt[:, 0], X1_trpt[:, 1], color=snapshot_color, alpha=0.75, s=2, zorder=1)

    # Narrow colour bar placed just outside the right edge of the axes
    cax = ax1.inset_axes([1.02, 0.2, 0.03, 0.6])  # [x, y, width, height] (relative position)
    cbar = fig1.colorbar(sm, cax=cax)
    cbar.set_ticks(np.linspace(0, 4, 5))  # ticks at 0, 1, 2, 3, 4
    cbar.set_ticklabels([0, 1, 2, 3, 4])
    cbar.set_label("Time", fontsize=27)
    cbar.ax.tick_params(labelsize=27)

    ax1.set_xlabel("PC 1", fontsize=27)
    ax1.set_ylabel("PC 2", fontsize=27)
    ax1.tick_params(axis='both', which='major', labelsize=30)
    ax1.set_title("")

    # Save through the figure object so the right figure is written regardless
    # of which figure pyplot currently considers active.
    if output_file_with_snapshots:
        fig1.savefig(output_file_with_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITH snapshots saved to {output_file_with_snapshots}")
        plt.close(fig1)
    else:
        plt.show()

    # **Plot 2: Without Snapshots** (source and target only)
    fig2, ax2 = plt.subplots(figsize=(8, 6))

    ax2.scatter(X1_vis[:, 0], X1_vis[:, 1], color=color_map[source_t], alpha=1.0, s=8, zorder=15,
                label=f'Time {source_t} (Training Data)')
    ax2.scatter(X2_vis[:, 0], X2_vis[:, 1], color=color_map[target_t], alpha=1.0, s=8, zorder=10,
                label=f'Time {target_t} (Training Data)')

    ax2.set_xlabel("PC 1", fontsize=27)
    ax2.set_ylabel("PC 2", fontsize=27)
    ax2.tick_params(axis='both', which='major', labelsize=27)
    ax2.set_title("")

    if output_file_without_snapshots:
        fig2.savefig(output_file_without_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITHOUT snapshots saved to {output_file_without_snapshots}")
        plt.close(fig2)
    else:
        plt.show()

    # **Legend-only figure**, entries ordered by the time value in each label.
    handles, labels = ax1.get_legend_handles_labels()
    order = sorted(range(len(labels)), key=lambda k: _label_time(labels[k]))
    sorted_labels = [labels[k] for k in order]
    sorted_handles = [handles[k] for k in order]

    fig_legend, ax_legend = plt.subplots(figsize=(10, 2))  # Adjust size as needed
    ax_legend.axis("off")  # Remove axes

    # markerscale draws the legend markers at twice the plotted size
    ax_legend.legend(
        sorted_handles, sorted_labels, fontsize=20, loc='center',
        ncol=len(sorted_labels), markerscale=2)

    # Save the legend separately
    legend_path = os.path.join(result_dir, "legend_only.png")
    fig_legend.savefig(legend_path, bbox_inches="tight")
    plt.close(fig_legend)  # Close the legend figure

    print(f"Legend saved separately at: {legend_path}")





    






In [None]:
## Static plot function for simple GPA (NDPR and Clinical data)


import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import matplotlib.colors as mcolors

def generate_static_trajectory_plots_two_timepoints_no_middle(days, intermediate_days, X1_trpts, mats, d_red=26, output_file_with_snapshots=None, output_file_without_snapshots=None, output_file_snapshots_only=None):
    """
    Draw the static figures for a run with two training time points and no test data.

    Two figures are produced:
    1. Every snapshot in X1_trpts, coloured by a viridis gradient, with a
       colour bar.
    2. Source and target training data only.

    A legend-only image is written to <result_dir>/legend_only.png only when
    figure 1 carries labelled artists. None of the artists this function adds
    to figure 1 has a label, so a notice is printed instead.

    Reads the module-level names dim_red_method, data_dir, physical_dt and
    result_dir.

    Args:
        days: day keys into mats; days[0] is the source and days[-1] the target.
        intermediate_days: accepted but not used by this function.
        X1_trpts: sequence of 2-D arrays; columns 0 and 1 are plotted as given.
            Snapshots containing NaN are skipped.
        mats: mapping from day key to the data that is passed to pca.transform.
        d_red: PCA dimension, used only to pick the pickled PCA file.
        output_file_with_snapshots: path for figure 1; when falsy the figure is
            shown instead of saved.
        output_file_without_snapshots: path for figure 2; when falsy the figure
            is shown instead of saved.
        output_file_snapshots_only: accepted but not used by this function.

    Returns:
        None. If dim_red_method is not supported, a message is printed and
        nothing is drawn.
    """
    # Function-scope import: the legend path below needs os, which this cell
    # does not import on its own.
    import os

    # Load PCA
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        print("PCA mapping for the reduction method and dimension is not available")
        return

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    # Colour gradient over the snapshots and a matching mappable for the colour bar
    num_snapshots = len(X1_trpts)
    colormap = cm.viridis  # Can change to "plasma", "inferno", etc.
    snapshot_colors = [colormap(i / num_snapshots) for i in range(num_snapshots)]

    time_values = np.linspace(0, physical_dt * num_snapshots, num_snapshots)
    norm = mcolors.Normalize(vmin=time_values.min(), vmax=time_values.max())
    sm = cm.ScalarMappable(cmap=colormap, norm=norm)
    sm.set_array([])  # Needed for color bar

    source_t, target_t = days[0], days[-1]

    # Colours for the two training time points
    color_map = {
        source_t: '#1f77b4',  # Blue
        target_t: '#d62728'  # Red
    }

    def _label_time(label):
        # "Time 3 (...)" -> 3; labels without a parsable number sort last.
        try:
            return int(label.split(" ")[1])
        except (ValueError, IndexError):
            return float('inf')

    # **Plot 1: With Snapshots**
    fig1, ax1 = plt.subplots(figsize=(8, 6))

    X1_vis = pca.transform(mats[source_t])
    X2_vis = pca.transform(mats[target_t])

    for snapshot_color, X1_trpt in zip(snapshot_colors, X1_trpts):
        if np.isnan(X1_trpt).any():
            continue
        ax1.scatter(X1_trpt[:, 0], X1_trpt[:, 1], color=snapshot_color, alpha=0.75, s=2, zorder=1)

    # Narrow colour bar placed just outside the right edge of the axes
    cax = ax1.inset_axes([1.02, 0.2, 0.03, 0.6])  # [x, y, width, height] (relative position)
    cbar = fig1.colorbar(sm, cax=cax)
    cbar.set_ticks(np.linspace(0, 4, 5))  # ticks at 0, 1, 2, 3, 4
    cbar.set_ticklabels([0, 1, 2, 3, 4])
    cbar.set_label("Time", fontsize=20)
    cbar.ax.tick_params(labelsize=20)

    ax1.set_xlabel("PC 1", fontsize=20)
    ax1.set_ylabel("PC 2", fontsize=20)
    ax1.tick_params(axis='both', which='major', labelsize=20)
    ax1.set_title("")

    # Save through the figure object so the right figure is written regardless
    # of which figure pyplot currently considers active.
    if output_file_with_snapshots:
        fig1.savefig(output_file_with_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITH snapshots saved to {output_file_with_snapshots}")
        plt.close(fig1)
    else:
        plt.show()

    # **Plot 2: Without Snapshots** (source and target only)
    fig2, ax2 = plt.subplots(figsize=(8, 6))

    ax2.scatter(X1_vis[:, 0], X1_vis[:, 1], color=color_map[source_t], alpha=1.0, s=8, zorder=15,
                label=f'Time {source_t} (Training Data)')
    ax2.scatter(X2_vis[:, 0], X2_vis[:, 1], color=color_map[target_t], alpha=1.0, s=8, zorder=10,
                label=f'Time {target_t} (Training Data)')

    ax2.set_xlabel("PC 1", fontsize=20)
    ax2.set_ylabel("PC 2", fontsize=20)
    ax2.tick_params(axis='both', which='major', labelsize=20)
    ax2.set_title("")

    if output_file_without_snapshots:
        fig2.savefig(output_file_without_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITHOUT snapshots saved to {output_file_without_snapshots}")
        plt.close(fig2)
    else:
        plt.show()

    # **Legend-only figure**, built from the labelled artists of figure 1.
    handles, labels = ax1.get_legend_handles_labels()

    if not (handles and labels):
        print("No legend elements found — skipping separate legend plot.")
        return

    # Entries ordered by the time value in each label.
    order = sorted(range(len(labels)), key=lambda k: _label_time(labels[k]))
    sorted_labels = [labels[k] for k in order]
    sorted_handles = [handles[k] for k in order]

    fig_legend, ax_legend = plt.subplots(figsize=(10, 2))  # Adjust size as needed
    ax_legend.axis("off")  # Remove axes

    ax_legend.legend(
        sorted_handles, sorted_labels, fontsize=20, loc='center',
        ncol=len(sorted_labels), markerscale=6  # Increase scatter marker size
    )

    # Save the legend separately
    legend_path = os.path.join(result_dir, "legend_only.png")
    fig_legend.savefig(legend_path, bbox_inches="tight")
    plt.close(fig_legend)  # Close the legend figure

    print(f"Legend saved separately at: {legend_path}")
    





    






In [None]:
## Static plot function for simple GPA (NDPR and Clinical data) - separate legend


import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import numpy as np
import os
import pickle as pk

def generate_static_trajectory_plots_two_timepoints_no_middle_legend(
    days, intermediate_days, X1_trpts, mats, d_red=26,
    output_file_with_snapshots=None,
    output_file_without_snapshots=None,
    output_file_snapshots_only=None
):
    """Draw static 2D plots of a transported trajectory and of its two endpoint datasets.

    Produces, in order:
      1. a scatter of every snapshot in ``X1_trpts`` (first two coordinates),
         coloured along the viridis colormap by snapshot index;
      2. a standalone horizontal colorbar, always written to
         ``result_dir/trajectory_colorbar_only.png``;
      3. a scatter of the PCA-projected source and target datasets;
      4. a standalone legend for plot 3, always written to
         ``result_dir/legend_only.png``.

    The colorbar and legend paths are fixed, so repeated calls overwrite them.
    Reads the notebook globals ``dim_red_method``, ``data_dir``, ``result_dir``
    and ``physical_dt``.

    Args:
        days: sequence of keys into ``mats``; ``days[0]`` is the source and
            ``days[-1]`` the target.
        intermediate_days: not used by this function; accepted so it can be
            called with the same keywords as the other static-plot functions.
        X1_trpts: sequence of 2D arrays, one per snapshot. Snapshots containing
            NaN are skipped.
        mats: mapping from day key to a matrix accepted by ``pca.transform``.
        d_red: PCA dimension, used only to select the PCA pickle file.
        output_file_with_snapshots: path for plot 1; when falsy the plot is
            shown with ``plt.show()`` instead.
        output_file_without_snapshots: path for plot 3; when falsy the plot is
            shown with ``plt.show()`` instead.
        output_file_snapshots_only: not used by this function.

    Returns:
        None. Prints a message and returns early when ``dim_red_method`` is
        neither ``'EMT_PCA'`` nor ``'PCA'``.
    """
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        print("PCA mapping for the reduction method and dimension is not available")
        return

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    num_snapshots = len(X1_trpts)
    colormap = cm.viridis
    snapshot_colors = [colormap(i / num_snapshots) for i in range(num_snapshots)]

    # Time axis of the colorbar: one value per snapshot, starting at 0.
    time_values = np.linspace(0, physical_dt * num_snapshots, num_snapshots)
    norm = mcolors.Normalize(vmin=time_values.min(), vmax=time_values.max())
    sm = cm.ScalarMappable(cmap=colormap, norm=norm)
    sm.set_array([])

    # Only the endpoints are needed; no middle day is read, so a `days`
    # sequence with a single entry no longer raises IndexError.
    source_t, target_t = days[0], days[-1]

    color_map = {
        source_t: 'magenta',  # source (pre-treatment) points
        target_t: '#008080'   # target (post-treatment) points, teal
    }

    X1_vis = pca.transform(mats[source_t])
    X2_vis = pca.transform(mats[target_t])

    # --- Plot 1: WITH snapshots (no inline colorbar) ---
    fig1, ax1 = plt.subplots(figsize=(8, 6))
    for colour, snapshot in zip(snapshot_colors, X1_trpts):
        if np.isnan(snapshot).any():
            continue
        ax1.scatter(snapshot[:, 0], snapshot[:, 1], color=colour, alpha=0.75, s=2, zorder=1)

    ax1.set_xlabel("PC 1", fontsize=32)
    ax1.set_ylabel("PC 2", fontsize=32)
    ax1.tick_params(axis='both', labelsize=32)
    ax1.set_title("")

    if output_file_with_snapshots:
        # Save through the figure object rather than the implicit current figure.
        fig1.savefig(output_file_with_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITH snapshots saved to {output_file_with_snapshots}")
        plt.close(fig1)
    else:
        plt.show()

    # --- Save colorbar separately ---
    fig_cb, ax_cb = plt.subplots(figsize=(10, 1))
    cb = fig_cb.colorbar(sm, cax=ax_cb, orientation='horizontal')
    # Anchor the two labels to the ends of the actual time range instead of a
    # hard-coded [0, 4], which falls outside the bar whenever the range differs.
    cb.set_ticks([norm.vmin, norm.vmax])
    cb.set_ticklabels(["Pre-treatment", "Post-treatment"])
    cb.ax.tick_params(labelsize=24)
    cb.set_label("Time", fontsize=24)
    cb_path = os.path.join(result_dir, "trajectory_colorbar_only.png")
    fig_cb.savefig(cb_path, dpi=300, bbox_inches='tight')
    plt.close(fig_cb)
    print(f"Standalone colorbar saved to {cb_path}")

    # --- Plot 2: WITHOUT snapshots ---
    fig2, ax2 = plt.subplots(figsize=(8, 6))
    ax2.scatter(X1_vis[:, 0], X1_vis[:, 1], color=color_map[source_t], alpha=1.0, s=8, zorder=15, label='Pre-treatment')
    ax2.scatter(X2_vis[:, 0], X2_vis[:, 1], color=color_map[target_t], alpha=1.0, s=8, zorder=10, label='Post-treatment')

    ax2.set_xlabel("PC 1", fontsize=32)
    ax2.set_ylabel("PC 2", fontsize=32)
    ax2.tick_params(axis='both', labelsize=32)
    ax2.set_title("")

    if output_file_without_snapshots:
        fig2.savefig(output_file_without_snapshots, dpi=300, bbox_inches='tight')
        print(f"Static trajectory plot WITHOUT snapshots saved to {output_file_without_snapshots}")
        plt.close(fig2)
    else:
        plt.show()

    # --- Legend (just for pre/post-treatment) ---
    handles, labels = ax2.get_legend_handles_labels()
    if handles:
        fig_legend, ax_legend = plt.subplots(figsize=(6, 2))
        ax_legend.axis("off")
        ax_legend.legend(
            handles, labels, fontsize=16, loc='center',
            ncol=len(labels), markerscale=3, handletextpad=0.5
        )
        legend_path = os.path.join(result_dir, "legend_only.png")
        fig_legend.savefig(legend_path, bbox_inches="tight", dpi=300)
        plt.close(fig_legend)
        print(f"Legend saved separately at: {legend_path}")



In [None]:
## Two pieces GPA (Stem cell data)
# Loads the saved weights (W, b) and parameters p for `exp_name`, transports
# mats[0] in PCA space, then displays each saved movie or, if the file is
# missing, calls generate_animation with that path.

exp_name = 'EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# Select the PCA pickle matching the reduction method and latent dimension.
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Stop here: continuing would open a stale `pca_filename` left behind by
    # an earlier cell, or fail with a NameError on a fresh kernel.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# Step size: 1/200 of the final numerical time stored with the model.
dt = p['numerical_ts'][-1] / 200
X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

# The same step expressed on the p['ts'] time axis instead of the numerical one.
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

img_src = result_dir + exp_name + '-movie-original.gif'
img_src1 = result_dir + exp_name + '-movie-original-with-arrows.gif'
if os.path.exists(img_src):  # reuse the saved movie when present
    display(Image(filename=img_src))
else:
    generate_animation([0,2,4], [1,3], X1_trpts, dt, physical_dt, img_src, d_red=d_red)
# NOTE(review): "vecotorfield" is passed verbatim; confirm it matches the
# spelling generate_animation checks for before correcting it.
if os.path.exists(img_src1):  # reuse the saved movie when present
    display(Image(filename=img_src1))
else:
    generate_animation([0,2,4], [1,3], X1_trpts, dt, physical_dt, img_src1, d_red=d_red, vs="vecotorfield")



In [None]:
## Two pieces GPA (Synthetic data)
# Loads the saved weights (W, b) and parameters p for `exp_name`, transports
# mats[0] in PCA space, then displays each saved movie or, if the file is
# missing, calls generate_animation with that path.

exp_name = 'f_Lip=5e-2-t_size=50-network=64_64_64_26d' #'times_10_particles_200_3'
d_red = 26
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# Select the PCA pickle matching the reduction method and latent dimension.
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Stop here: continuing would open a stale `pca_filename` left behind by
    # an earlier cell, or fail with a NameError on a fresh kernel.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# Step size: 1/200 of the final numerical time stored with the model.
dt = p['numerical_ts'][-1] / 200
X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

# The same step expressed on the p['ts'] time axis instead of the numerical one.
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

img_src = result_dir + exp_name + '-movie-original.gif'
img_src1 = result_dir + exp_name + '-movie-original-with-arrows.gif'
if os.path.exists(img_src):  # reuse the saved movie when present
    display(Image(filename=img_src))
else:
    generate_animation([0,2,4], [1,3], X1_trpts, dt, physical_dt, img_src, d_red=d_red)
# NOTE(review): "vecotorfield" is passed verbatim; confirm it matches the
# spelling generate_animation checks for before correcting it.
if os.path.exists(img_src1):  # reuse the saved movie when present
    display(Image(filename=img_src1))
else:
    generate_animation([0,2,4], [1,3], X1_trpts, dt, physical_dt, img_src1, d_red=d_red, vs="vecotorfield")



In [None]:
## One piece of GPA (EMT data)
# Loads the saved weights (W, b) and parameters p for `exp_name`, transports
# mats[0] in PCA space, then displays each saved movie or, if the file is
# missing, calls generate_animation with that path.

exp_name = '72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 8
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# Select the PCA pickle matching the reduction method and latent dimension.
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Stop here: continuing would open a stale `pca_filename` left behind by
    # an earlier cell, or fail with a NameError on a fresh kernel.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# Step size: 1/200 of the final numerical time stored with the model.
dt = p['numerical_ts'][-1] / 200
X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

# The same step expressed on the p['ts'] time axis instead of the numerical one.
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

img_src = result_dir + exp_name + '-movie-original.gif'
img_src1 = result_dir + exp_name + '-movie-original-with-arrows.gif'
if os.path.exists(img_src):  # reuse the saved movie when present
    display(Image(filename=img_src))
else:
    generate_animation([0, 2, 4], [], X1_trpts, dt, physical_dt, img_src, d_red=d_red)
# NOTE(review): "vecotorfield" is passed verbatim; confirm it matches the
# spelling generate_animation checks for before correcting it.
if os.path.exists(img_src1):  # reuse the saved movie when present
    display(Image(filename=img_src1))
else:
    generate_animation([0, 2, 4], [], X1_trpts, dt, physical_dt, img_src1, d_red=d_red, vs="vecotorfield")


In [None]:
## One piece of GPA (NDPR data)
# Loads the saved weights (W, b) and parameters p for `exp_name`, transports
# mats[0] in PCA space, then displays each saved movie or, if the file is
# missing, calls generate_animation with that path.

exp_name = 'Palbo_NDPR_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# Select the PCA pickle matching the reduction method and latent dimension.
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Stop here: continuing would open a stale `pca_filename` left behind by
    # an earlier cell, or fail with a NameError on a fresh kernel.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# Step size: 1/200 of the final numerical time stored with the model.
dt = p['numerical_ts'][-1] / 200
X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

# The same step expressed on the p['ts'] time axis instead of the numerical one.
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

img_src = result_dir + exp_name + '-movie-original.gif'
img_src1 = result_dir + exp_name + '-movie-original-with-arrows.gif'
if os.path.exists(img_src):  # reuse the saved movie when present
    display(Image(filename=img_src))
else:
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src, d_red=d_red)
# NOTE(review): "vecotorfield" is passed verbatim; confirm it matches the
# spelling generate_animation checks for before correcting it.
if os.path.exists(img_src1):  # reuse the saved movie when present
    display(Image(filename=img_src1))
else:
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src1, d_red=d_red, vs="vecotorfield")

In [None]:
## One piece of GPA (PA3)
# Loads the saved weights (W, b) and parameters p for `exp_name`, transports
# mats[0] in PCA space, then displays each saved movie or, if the file is
# missing, calls generate_animation with that path.

exp_name = 'Palbo_BMC_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# Select the PCA pickle matching the reduction method and latent dimension.
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Stop here: continuing would open a stale `pca_filename` left behind by
    # an earlier cell, or fail with a NameError on a fresh kernel.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# Step size: 1/200 of the final numerical time stored with the model.
dt = p['numerical_ts'][-1] / 200
X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

# The same step expressed on the p['ts'] time axis instead of the numerical one.
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

img_src = result_dir + exp_name + '-movie-original.gif'
img_src1 = result_dir + exp_name + '-movie-original-with-arrows.gif'
if os.path.exists(img_src):  # reuse the saved movie when present
    display(Image(filename=img_src))
else:
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src, d_red=d_red)
# NOTE(review): "vecotorfield" is passed verbatim; confirm it matches the
# spelling generate_animation checks for before correcting it.
if os.path.exists(img_src1):  # reuse the saved movie when present
    display(Image(filename=img_src1))
else:
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src1, d_red=d_red, vs="vecotorfield")

In [None]:
## One piece of GPA (Patient 862)
# Loads the saved weights (W, b) and parameters p for `exp_name`, transports
# mats[0] in PCA space, then displays each saved movie or, if the file is
# missing, calls generate_animation with that path.

exp_name = 'Palbo_862_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# Select the PCA pickle matching the reduction method and latent dimension.
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Stop here: continuing would open a stale `pca_filename` left behind by
    # an earlier cell, or fail with a NameError on a fresh kernel.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# Step size: 1/200 of the final numerical time stored with the model.
dt = p['numerical_ts'][-1] / 200
X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

# The same step expressed on the p['ts'] time axis instead of the numerical one.
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

img_src = result_dir + exp_name + '-movie-original.gif'
img_src1 = result_dir + exp_name + '-movie-original-with-arrows.gif'
if os.path.exists(img_src):  # reuse the saved movie when present
    display(Image(filename=img_src))
else:
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src, d_red=d_red)
# NOTE(review): "vecotorfield" is passed verbatim; confirm it matches the
# spelling generate_animation checks for before correcting it.
if os.path.exists(img_src1):  # reuse the saved movie when present
    display(Image(filename=img_src1))
else:
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src1, d_red=d_red, vs="vecotorfield")

In [None]:
## One piece of GPA (Patient 887)
# Loads the saved weights (W, b) and parameters p for `exp_name`, transports
# mats[0] in PCA space, then displays each saved movie or, if the file is
# missing, calls generate_animation with that path.

exp_name = 'Palbo_887_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# Select the PCA pickle matching the reduction method and latent dimension.
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Stop here: continuing would open a stale `pca_filename` left behind by
    # an earlier cell, or fail with a NameError on a fresh kernel.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# Step size: 1/200 of the final numerical time stored with the model.
dt = p['numerical_ts'][-1] / 200
X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

# The same step expressed on the p['ts'] time axis instead of the numerical one.
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

img_src = result_dir + exp_name + '-movie-original.gif'
img_src1 = result_dir + exp_name + '-movie-original-with-arrows.gif'
if os.path.exists(img_src):  # reuse the saved movie when present
    display(Image(filename=img_src))
else:
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src, d_red=d_red)
# NOTE(review): "vecotorfield" is passed verbatim; confirm it matches the
# spelling generate_animation checks for before correcting it.
if os.path.exists(img_src1):  # reuse the saved movie when present
    display(Image(filename=img_src1))
else:
    generate_animation([0, 4], [], X1_trpts, dt, physical_dt, img_src1, d_red=d_red, vs="vecotorfield")

In [None]:

## Full trajectories on static plot (samples 1 - EMT data)
# Loads the saved weights (W, b) and parameters p for `exp_name`, transports
# mats[0] in PCA space, and passes the resulting snapshots to the static
# two-timepoint plot function.

exp_name = '72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 8
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# Select the PCA pickle matching the reduction method and latent dimension.
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Stop here: continuing would open a stale `pca_filename` left behind by
    # an earlier cell, or fail with a NameError on a fresh kernel.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# Step size: 1/200 of the final numerical time stored with the model.
dt = p['numerical_ts'][-1] / 200
X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

# The same step expressed on the p['ts'] time axis instead of the numerical one.
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames (output_file_snapshots_only is defined but not
# passed to the call below).
output_file_with_snapshots = f"{result_dir}{exp_name}_static_trajectory_with_snapshots_circle.png"
output_file_without_snapshots = f"{result_dir}{exp_name}_static_trajectory_without_snapshots.png"
output_file_snapshots_only = f"{result_dir}{exp_name}_static_trajectory_snapshots_only.png"

# Generate both plots
generate_static_trajectory_plots_two_timepoints(
    days=[0, 4],
    intermediate_days=[2],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_with_snapshots=output_file_with_snapshots,
    output_file_without_snapshots=output_file_without_snapshots
)

In [None]:

## Full trajectories on static plot (Sample 5 - synthetic data)
# Loads the saved weights (W, b) and parameters p for `exp_name`, transports
# mats[0] in PCA space, and passes the resulting snapshots to the static
# three-timepoint plot function.

exp_name = 'f_Lip=5e-2-t_size=50-network=64_64_64_26d' #'times_10_particles_200_3'
d_red = 26
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# Select the PCA pickle matching the reduction method and latent dimension.
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Stop here: continuing would open a stale `pca_filename` left behind by
    # an earlier cell, or fail with a NameError on a fresh kernel.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# Step size: 1/200 of the final numerical time stored with the model.
dt = p['numerical_ts'][-1] / 200
X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

# The same step expressed on the p['ts'] time axis instead of the numerical one.
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames (output_file_snapshots_only is defined but not
# passed to the call below).
output_file_with_snapshots = f"{result_dir}{exp_name}_static_trajectory_with_snapshots_circle.png"
output_file_without_snapshots = f"{result_dir}{exp_name}_static_trajectory_without_snapshots.png"
output_file_snapshots_only = f"{result_dir}{exp_name}_static_trajectory_snapshots_only.png"

# Generate both plots
generate_static_trajectory_plots_three_timepoints(
    days=[0, 2, 4],
    intermediate_days=[1, 3],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_with_snapshots=output_file_with_snapshots,
    output_file_without_snapshots=output_file_without_snapshots
)

In [None]:
## Full trajectories on static plot (samples 3 - stem cell data)
# Loads the saved weights (W, b) and parameters p for `exp_name`, transports
# mats[0] in PCA space, and passes the resulting snapshots to the static
# three-timepoint plot function.

exp_name = 'EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# Select the PCA pickle matching the reduction method and latent dimension.
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Stop here: continuing would open a stale `pca_filename` left behind by
    # an earlier cell, or fail with a NameError on a fresh kernel.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# Step size: 1/200 of the final numerical time stored with the model.
dt = p['numerical_ts'][-1] / 200
X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

# The same step expressed on the p['ts'] time axis instead of the numerical one.
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames (output_file_snapshots_only is defined but not
# passed to the call below).
output_file_with_snapshots = f"{result_dir}{exp_name}_static_trajectory_with_snapshots_circle.png"
output_file_without_snapshots = f"{result_dir}{exp_name}_static_trajectory_without_snapshots.png"
output_file_snapshots_only = f"{result_dir}{exp_name}_static_trajectory_snapshots_only.png"

# Generate both plots
generate_static_trajectory_plots_three_timepoints(
    days=[0, 2, 4],
    intermediate_days=[1, 3],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_with_snapshots=output_file_with_snapshots,
    output_file_without_snapshots=output_file_without_snapshots
)

In [None]:
## Full trajectories on static plot (NDPR data)
# Loads the saved weights (W, b) and parameters p for `exp_name`, transports
# mats[0] in PCA space, and passes the resulting snapshots to the static
# two-timepoint plot function that writes a separate legend.

exp_name = 'Palbo_NDPR_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# Select the PCA pickle matching the reduction method and latent dimension.
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Stop here: continuing would open a stale `pca_filename` left behind by
    # an earlier cell, or fail with a NameError on a fresh kernel.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# Step size: 1/200 of the final numerical time stored with the model.
dt = p['numerical_ts'][-1] / 200
X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

# The same step expressed on the p['ts'] time axis instead of the numerical
# one; the plot function below reads this global for its colorbar range.
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames (output_file_snapshots_only is defined but not
# passed to the call below).
output_file_with_snapshots = f"{result_dir}{exp_name}_static_trajectory_with_snapshots_circle.png"
output_file_without_snapshots = f"{result_dir}{exp_name}_static_trajectory_without_snapshots.png"
output_file_snapshots_only = f"{result_dir}{exp_name}_static_trajectory_snapshots_only.png"

# Generate both plots
generate_static_trajectory_plots_two_timepoints_no_middle_legend(
    days=[0, 4],
    intermediate_days=[ ],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_with_snapshots=output_file_with_snapshots,
    output_file_without_snapshots=output_file_without_snapshots
)

In [None]:
## Full trajectories on static plot (PA3)
# Loads the saved weights (W, b) and parameters p for `exp_name`, transports
# mats[0] in PCA space, and passes the resulting snapshots to the static
# two-timepoint plot function that writes a separate legend.

exp_name = 'Palbo_BMC_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# Select the PCA pickle matching the reduction method and latent dimension.
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Stop here: continuing would open a stale `pca_filename` left behind by
    # an earlier cell, or fail with a NameError on a fresh kernel.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# Step size: 1/200 of the final numerical time stored with the model.
dt = p['numerical_ts'][-1] / 200
X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

# The same step expressed on the p['ts'] time axis instead of the numerical
# one; the plot function below reads this global for its colorbar range.
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames (output_file_snapshots_only is defined but not
# passed to the call below).
output_file_with_snapshots = f"{result_dir}{exp_name}_static_trajectory_with_snapshots_circle.png"
output_file_without_snapshots = f"{result_dir}{exp_name}_static_trajectory_without_snapshots.png"
output_file_snapshots_only = f"{result_dir}{exp_name}_static_trajectory_snapshots_only.png"

# Generate both plots
generate_static_trajectory_plots_two_timepoints_no_middle_legend(
    days=[0, 4],
    intermediate_days=[ ],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_with_snapshots=output_file_with_snapshots,
    output_file_without_snapshots=output_file_without_snapshots
)

In [None]:
## Full trajectories on static plot (Patient 862)
# Loads the saved weights (W, b) and parameters p for `exp_name`, transports
# mats[0] in PCA space, and passes the resulting snapshots to the static
# two-timepoint plot function that writes a separate legend.

exp_name = 'Palbo_862_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# Select the PCA pickle matching the reduction method and latent dimension.
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Stop here: continuing would open a stale `pca_filename` left behind by
    # an earlier cell, or fail with a NameError on a fresh kernel.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# Step size: 1/200 of the final numerical time stored with the model.
dt = p['numerical_ts'][-1] / 200
X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

# The same step expressed on the p['ts'] time axis instead of the numerical
# one; the plot function below reads this global for its colorbar range.
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames (output_file_snapshots_only is defined but not
# passed to the call below).
output_file_with_snapshots = f"{result_dir}{exp_name}_static_trajectory_with_snapshots_circle.png"
output_file_without_snapshots = f"{result_dir}{exp_name}_static_trajectory_without_snapshots.png"
output_file_snapshots_only = f"{result_dir}{exp_name}_static_trajectory_snapshots_only.png"

# Generate both plots
generate_static_trajectory_plots_two_timepoints_no_middle_legend(
    days=[0, 4],
    intermediate_days=[ ],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_with_snapshots=output_file_with_snapshots,
    output_file_without_snapshots=output_file_without_snapshots
)

In [None]:
## Full trajectories on static plot (Patient 887)
# Loads the saved weights (W, b) and parameters p for `exp_name`, transports
# mats[0] in PCA space, and passes the resulting snapshots to the static
# two-timepoint plot function that writes a separate legend.

exp_name = 'Palbo_887_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 2
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# Select the PCA pickle matching the reduction method and latent dimension.
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Stop here: continuing would open a stale `pca_filename` left behind by
    # an earlier cell, or fail with a NameError on a fresh kernel.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# Step size: 1/200 of the final numerical time stored with the model.
dt = p['numerical_ts'][-1] / 200
X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

# The same step expressed on the p['ts'] time axis instead of the numerical
# one; the plot function below reads this global for its colorbar range.
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames (output_file_snapshots_only is defined but not
# passed to the call below).
output_file_with_snapshots = f"{result_dir}{exp_name}_static_trajectory_with_snapshots_circle.png"
output_file_without_snapshots = f"{result_dir}{exp_name}_static_trajectory_without_snapshots.png"
output_file_snapshots_only = f"{result_dir}{exp_name}_static_trajectory_snapshots_only.png"

# Generate both plots
generate_static_trajectory_plots_two_timepoints_no_middle_legend(
    days=[0, 4],
    intermediate_days=[ ],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_with_snapshots=output_file_with_snapshots,
    output_file_without_snapshots=output_file_without_snapshots
)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import matplotlib.colors as mcolors

def generate_static_trajectory_plots_cell_types(days, intermediate_days, X1_trpts, mats, d_red=26, output_file_cell_type_source=None, output_file_cell_type_target=None, output_file_cell_type_legend=None):
    """
    Plot the source and target datasets in PCA space, coloured by cell type.

    Produces three figures:
    1. Source points coloured by cell type over the target points in light gray.
    2. Target points coloured by cell type over the source points in light gray.
    3. A standalone legend mapping each cell type to its colour.

    Reads the notebook globals ``dim_red_method``, ``data_dir`` and
    ``cell_types_by_day``.

    Args:
        days: sequence of keys into ``mats`` and ``cell_types_by_day``;
            ``days[0]`` is the source and ``days[-1]`` the target.
        intermediate_days: not used by this function.
        X1_trpts: not used by this function.
        mats: mapping from day key to a matrix accepted by ``pca.transform``.
        d_red: PCA dimension, used only to select the PCA pickle file.
        output_file_cell_type_source: path for figure 1; when falsy the figure
            is shown with ``plt.show()`` instead.
        output_file_cell_type_target: path for figure 2; when falsy the figure
            is shown with ``plt.show()`` instead.
        output_file_cell_type_legend: path for figure 3; when falsy the legend
            is built but neither saved nor shown.

    Returns:
        None. Prints a message and returns early when ``dim_red_method`` is
        neither ``'EMT_PCA'`` nor ``'PCA'``.
    """
    # Imported locally so the function does not depend on which notebook
    # cells happened to run before it.
    import matplotlib.lines as mlines
    import seaborn as sns

    # Load PCA
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        print("PCA mapping for the reduction method and dimension is not available")
        return

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    # Only the endpoints are used; no middle day is read.
    source_t, target_t = days[0], days[-1]

    X1_vis = pca.transform(mats[source_t])
    X2_vis = pca.transform(mats[target_t])
    # np.asarray makes the `== cell_type` comparisons below element-wise even
    # when the labels are stored as plain Python lists.
    cell_types_X1 = np.asarray(cell_types_by_day[source_t])
    cell_types_X2 = np.asarray(cell_types_by_day[target_t])
    unique_cell_types = np.unique(np.concatenate([cell_types_X1, cell_types_X2]))
    cell_type_palette = dict(zip(unique_cell_types, sns.color_palette("tab20", len(unique_cell_types))))

    def _scatter_by_cell_type(highlight_vis, highlight_types, background_vis, title, output_file):
        # One figure: background dataset in gray, highlighted dataset by cell type.
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.scatter(background_vis[:, 0], background_vis[:, 1], color='lightgray', alpha=0.5, s=8)
        for cell_type in unique_cell_types:
            idx = highlight_types == cell_type
            ax.scatter(highlight_vis[idx, 0], highlight_vis[idx, 1],
                       color=cell_type_palette[cell_type], s=8, alpha=1.0)
        ax.set_xlabel("PC 1", fontsize=20)
        ax.set_ylabel("PC 2", fontsize=20)
        ax.tick_params(axis='both', which='major', labelsize=18)
        ax.set_title(title, fontsize=18)
        fig.tight_layout()
        if output_file:
            fig.savefig(output_file, dpi=300, bbox_inches='tight')
            plt.close(fig)
        else:
            plt.show()

    # Plot 1: X1 colored, X2 gray (no legend)
    _scatter_by_cell_type(X1_vis, cell_types_X1, X2_vis,
                          "Untreated Samples colored by Cell Type",
                          output_file_cell_type_source)

    # Plot 2: X2 colored, X1 gray (no legend)
    _scatter_by_cell_type(X2_vis, cell_types_X2, X1_vis,
                          "Treated Samples colored by Cell Type",
                          output_file_cell_type_target)

    # Circle markers (not patches) as legend entries, one per cell type.
    legend_elements = [
        mlines.Line2D(
            [], [], marker='o', color='w',
            markerfacecolor=cell_type_palette[cell_type],
            markersize=8, label=cell_type
        )
        for cell_type in unique_cell_types
    ]

    # Figure that holds only the legend; its size is adjusted after drawing.
    fig_leg, ax_leg = plt.subplots(figsize=(8, 6))
    ax_leg.axis('off')

    # Add legend to the axis (not directly to plt)
    legend = ax_leg.legend(
        handles=legend_elements,
        loc='center',
        frameon=True,
        fontsize=14,
        ncol=1,
        title='Cell Types',
        title_fontsize=14,
        borderpad=1
    )

    # Resize the figure to tightly fit the legend; the draw call is needed so
    # the legend has a window extent to measure.
    fig_leg.canvas.draw()
    bbox = legend.get_window_extent().transformed(fig_leg.dpi_scale_trans.inverted())
    fig_leg.set_size_inches(bbox.width + 0.5, bbox.height + 0.5)  # Add a little padding

    # Save only, no display. The figure is closed either way so it is not
    # left open when no output path is given.
    if output_file_cell_type_legend:
        fig_leg.savefig(output_file_cell_type_legend, dpi=300, bbox_inches='tight')
    plt.close(fig_leg)
    


    






In [None]:
## Cell types plots
# Loads the saved weights (W, b) and parameters p for `exp_name`, transports
# mats[0] in PCA space, and draws the source/target scatter plots coloured by
# cell type together with a standalone legend.

exp_name = 'Palbo_BMC_nofibroblast_malignant_dim20-f_Lip=5e-3-t_size=50-network=64_64_64' #'times_10_particles_200_3'
d_red = 20
filename = result_dir + exp_name + ".pickle"
W, b, p = load_W(filename)

# Select the PCA pickle matching the reduction method and latent dimension.
if dim_red_method == 'EMT_PCA':
    pca_filename = "emt_pca_%d.pkl" % d_red
elif dim_red_method == 'PCA':
    pca_filename = "pca_%d.pkl" % d_red
else:
    # Stop here: continuing would open a stale `pca_filename` left behind by
    # an earlier cell, or fail with a NameError on a fresh kernel.
    raise ValueError("PCA mapping for the reduction method and dimension is not available")

with open(data_dir + pca_filename, "rb") as fr:
    [pca] = pk.load(fr)

# Step size: 1/200 of the final numerical time stored with the model.
dt = p['numerical_ts'][-1] / 200
X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

# The same step expressed on the p['ts'] time axis instead of the numerical one.
physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]

# Define output filenames
output_file_cell_type_source = f"{result_dir}{exp_name}_cell_type_source.png"
output_file_cell_type_target = f"{result_dir}{exp_name}_cell_type_target.png"
output_file_cell_type_legend = f"{result_dir}{exp_name}_cell_type_legend.png"


# Generate the two cell-type plots and the legend
generate_static_trajectory_plots_cell_types(
    days=[0, 4],
    intermediate_days=[ ],
    X1_trpts=X1_trpts,
    mats=mats,
    d_red=d_red,
    output_file_cell_type_source=output_file_cell_type_source,
    output_file_cell_type_target=output_file_cell_type_target,
    output_file_cell_type_legend=output_file_cell_type_legend
)

In [None]:
## ## This is for EMT data, Time [0 , 2, 4]
## Plot gene dynamics for each trajectory

import seaborn as sns  # Required for violin plots
import numpy as np
import matplotlib.patches as mpatches



## Subtrajectories defined by source
def Average_gene_dynamics_whole_saveonly_single_trajectory_EMT(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                              intermediate_t = [1], 
                              d_red=2, random_state=42, exp_memo = '2'):
    """Save a per-cell expression-trajectory plot with snapshot violins for one gene.

    The cells of ``mats[0]`` are projected with the stored PCA, transported with
    ``time_integration`` and mapped back with ``pca.inverse_transform``; the
    column of ``gene_of_interest`` is drawn as one line per cell.  Violin plots
    of the PCA-reconstructed snapshots (source, intermediates, target) are
    overlaid at the fixed x positions ``[0, 2, 4]``.  Two PNG files are written
    to the notebook-level ``output_dir``: the main figure and a separate legend.

    Reads the notebook-level names ``result_dir``, ``data_dir``,
    ``dim_red_method``, ``mats``, ``df_reduced_emt`` and ``output_dir``.

    Parameters:
        source_t: Key/index into ``mats`` of the first snapshot.
        target_t: Key/index into ``mats`` of the last snapshot.
        optimal_k: Unused; kept so existing calls keep working.
        gene_of_interest: Column label looked up in ``df_reduced_emt``.
        index: Stride over the stored trajectory steps.
        max_i: Largest trajectory step index that is plotted.
        intermediate_t: Snapshots shown between source and target; when empty,
            every integer strictly between ``source_t`` and ``target_t`` is used.
        d_red: PCA dimension, used to pick the PCA pickle file.
        random_state: Unused; kept so existing calls keep working.
        exp_memo: Experiment name; selects the weight pickle and the cluster CSV.

    Raises:
        ValueError: ``dim_red_method`` has no PCA file, or the cluster file
            holds fewer labels than there are transported cells.
        FileNotFoundError: The cluster-label CSV does not exist.
    """
    # Imported here because the notebook only imports this module in a later
    # cell, so a fresh run would otherwise fail when building the legend.
    import matplotlib.lines as mlines

    # Fixed figure layout of this EMT variant.
    violin_x_positions = np.array([0, 2, 4])
    violin_colors = ["black", "gray", "black"]
    subtrajectory_colors = ['green']

    # --- Cheap lookups first, so a bad input fails before the integration ---
    pca_filename_templates = {'EMT_PCA': "emt_pca_%d.pkl", 'PCA': "pca_%d.pkl"}
    if dim_red_method not in pca_filename_templates:
        raise ValueError("PCA mapping for the reduction method and dimension is not available")
    pca_filename = pca_filename_templates[dim_red_method] % d_red

    cluster_save_path = f"{result_dir}{exp_memo}_X1_hat_clusters.csv"
    if not os.path.exists(cluster_save_path):
        raise FileNotFoundError(f"Cluster labels file not found: {cluster_save_path}")

    # NOTE(review): the "- 1" presumably skips a leading non-gene column of
    # df_reduced_emt -- confirm against how `mats` was built.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # An empty selection falls back to every integer time strictly in between.
    # Kept as an ndarray so `.tolist()` below works in both cases.
    intermediate_t = np.atleast_1d(np.asarray(intermediate_t))
    if intermediate_t.size == 0:
        intermediate_t = np.arange(source_t + 1, target_t)

    # --- Load model, PCA and cluster labels ---
    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): W and b are locals here, while time_integration is called
    # below without weights; if it reads notebook-level W/b it will not see the
    # ones loaded from this pickle -- confirm.
    W, b, p = load_W(filename)

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    df_clusters = pd.read_csv(cluster_save_path)
    X1_hat_labels = df_clusters["Cluster_Label"].values

    # --- Transport the cells in PCA space over 200 equal steps ---
    # NOTE(review): integration always starts from mats[0], not mats[source_t].
    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    unique_labels = np.unique(X1_hat_labels)
    print(f"Number of unique labels in X1_hat_labels: {len(unique_labels)}")
    print(f"Unique labels: {unique_labels}")
    print(sorted([source_t] + intermediate_t.tolist() + [target_t]))

    def snapshot_expression(day):
        # Expression of the gene in one snapshot after a PCA round trip, so the
        # violins live in the same reconstructed space as the trajectories.
        return pca.inverse_transform(pca.transform(mats[day]))[:, gene_index]

    violin_data = [
        snapshot_expression(source_t),
        *[snapshot_expression(t) for t in intermediate_t],
        snapshot_expression(target_t),
    ]

    # --- Expression of the gene along each transported cell's path ---
    step_indices = range(0, len(X1_trpts), index)
    x_positions = np.linspace(0, 4, len(step_indices))  # spans the [0, 4] time axis

    expression_per_step = []
    for time_idx in step_indices:
        if time_idx > max_i:
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break  # the integration diverged; keep only the finite prefix
        expression_per_step.append(pca.inverse_transform(X1_trpt)[:, gene_index])

    subgroup_color_map = {
        label: subtrajectory_colors[i % len(subtrajectory_colors)]
        for i, label in enumerate(unique_labels)
    }
    subgroup_output_file = f"{output_dir}/Individual_trajectories_violin_plot_{gene_of_interest}.png"

    # --- Main figure ---
    fig, ax1 = plt.subplots(figsize=(12, 7))
    try:
        if expression_per_step:
            trajectories = np.stack(expression_per_step)  # (kept steps, cells)
            n_cells = trajectories.shape[1]
            if len(X1_hat_labels) < n_cells:
                raise ValueError(
                    f"{cluster_save_path} holds {len(X1_hat_labels)} labels "
                    f"for {n_cells} transported cells"
                )
            cell_labels = X1_hat_labels[:n_cells]
            # A truncated trajectory (max_i or NaN) is drawn over the matching
            # prefix of the time axis instead of being dropped altogether.
            trajectory_x = x_positions[:len(trajectories)]
            for label in unique_labels:
                members = cell_labels == label
                if members.any():
                    # One line per selected cell (column).
                    ax1.plot(
                        trajectory_x, trajectories[:, members],
                        color=subgroup_color_map[label],
                        alpha=0.1, linewidth=0.8
                    )

        # NOTE(review): zip stops at the three hard-coded positions, so any
        # snapshot beyond the third is not drawn.
        for x_pos, data, color in zip(violin_x_positions, violin_data, violin_colors):
            sns.violinplot(
                data=[data],
                ax=ax1,
                inner=None,
                linewidth=1.2,
                width=0.7,
                cut=0,
                scale="width",
                color=color,
                alpha=0.8,
                zorder=3
            )
            # The violin is drawn at a default position; recentre the most
            # recently added collection on the intended time.
            for violin in ax1.collections[-1:]:
                for path in violin.get_paths():
                    path.vertices[:, 0] += x_pos - path.vertices[:, 0].mean()

        # Leave room so the outer violins are not clipped.
        ax1.set_xlim(-0.5, 4.5)
        ax1.set_xticks([0, 2, 4])
        ax1.set_xticklabels([0, 2, 4], fontsize=32)
        ax1.tick_params(axis='y', labelsize=32)

        ax1.set_xlabel('Time', fontsize=32)
        ax1.set_ylabel('Gene Expression', fontsize=32)
        ax1.set_title(f'Single Cell {gene_of_interest} Expression Dynamics', fontsize=32)

        # The legend goes into its own file, so the main figure has none.
        fig.savefig(subgroup_output_file, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure: the caller loops over many genes and
        # swallows exceptions.
        plt.close(fig)

    # --- Stand-alone horizontal legend ---
    combined_legend = [
        mlines.Line2D([], [], color="green", linestyle="-", linewidth=3,
                      label="Gene dynamics of each single cell"),
        mpatches.Patch(color="black", label="Input Data"),
        mpatches.Patch(color="gray", label="Test Data"),
    ]

    fig_legend, ax_legend = plt.subplots(figsize=(10, 2))
    try:
        ax_legend.axis("off")
        ax_legend.legend(
            handles=combined_legend,
            loc="center", fontsize=24, title="",
            title_fontsize=24, ncol=len(combined_legend),  # single row
            frameon=True, handletextpad=2, columnspacing=2
        )
        legend_output_file = subgroup_output_file.replace(".png", "_legend.png")
        fig_legend.savefig(legend_output_file, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig_legend)
    
    


In [None]:
# Plot for EMT data
# Saves one trajectory/violin figure (plus legend) per gene into output_dir.
# NOTE(review): the cell is labelled EMT while exp_memo names a '72GS' run -- confirm.

genes_of_interest = gene_names # NANOG, SOX2
source_t, target_t = 0, 4
optimal_k = 2
index = 1
max_i = 200
intermediate_t = [2]
#intermediate_t = [4]

d_red= 8
random_state = 40
exp_memo = '72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64'

result_dir = '%s/assets/Transport_genes/' % main_dir

# Assign output_dir BEFORE creating it: the reverse order raises NameError on
# a fresh kernel, or creates the directory of a previously run experiment.
output_dir = os.path.join(result_dir, 'output', exp_memo)

# Create the directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)


# Iterate over each gene in the list
for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        # Call the function with the current gene
        Average_gene_dynamics_whole_saveonly_single_trajectory_EMT(
            source_t, target_t, optimal_k, gene, index, max_i,
            intermediate_t=intermediate_t, d_red=d_red,
            random_state=random_state, exp_memo=exp_memo
        )
    except Exception as e:
        # Best effort: report the failure and continue with the next gene
        print(f"Error processing gene {gene}: {e}")


In [None]:
## save the gene expression dynamics png as pdf for EMT data


import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from math import ceil
import matplotlib.lines as mlines


def create_pdf_from_gene_images(output_dir, exp_memo, gene_list, pdf_path, images_per_page=25, grid_size=(5, 5)):
    """
    Collect the per-gene violin-plot PNG files into a multi-page PDF, laid out on a grid.

    Each page shows at most ``min(images_per_page, rows * cols)`` images, so no
    image is repeated or skipped when the two settings disagree.  Files that do
    not exist are reported and left out.

    Parameters:
        output_dir (str): Directory containing the PNG files.
        exp_memo (str): Unused; kept so existing calls keep working.
        gene_list (list): Genes whose PNG files should be included.
        pdf_path (str): Path to save the output PDF file.
        images_per_page (int): Number of images per page (default: 25).
        grid_size (tuple): Grid size (rows, cols) for each page (default: 5x5).

    Raises:
        ValueError: If fewer than one image would fit on a page.
    """
    candidate_files = [
        f"{output_dir}/Individual_trajectories_violin_plot_{gene}.png" for gene in gene_list
    ]

    # Split once into existing and missing files.
    png_files, missing_files = [], []
    for file in candidate_files:
        (png_files if os.path.exists(file) else missing_files).append(file)

    if missing_files:
        print(f"Warning: The following files are missing and will be skipped:\n{missing_files}")

    rows, cols = grid_size
    per_page = min(images_per_page, rows * cols)
    if per_page < 1:
        raise ValueError("images_per_page and grid_size must allow at least one image per page")

    total_pages = ceil(len(png_files) / per_page)

    with PdfPages(pdf_path) as pdf:
        for page in range(total_pages):
            # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
            fig, axes = plt.subplots(rows, cols, figsize=(15, 15), squeeze=False)
            try:
                page_files = png_files[page * per_page:(page + 1) * per_page]
                for slot, ax in enumerate(axes.flatten()):
                    if slot < len(page_files):
                        # aspect='auto' stretches the image to fill its grid cell.
                        ax.imshow(plt.imread(page_files[slot]), aspect='auto')
                    ax.axis('off')  # no frame or ticks, also for unused slots

                pdf.savefig(fig, dpi=300, bbox_inches='tight')
            finally:
                plt.close(fig)  # free the figure even if a page fails

    print(f"✅ PDF saved to {pdf_path} with original image resolution.")

# Bundle the per-gene PNG files found in output_dir into a single PDF,
# six images per page on a 3x2 grid.

exp_memo = "72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64"
gene_list = gene_names  # every gene that was plotted
pdf_path = f"{output_dir}/Individual_trajectories_violin_plot.pdf"  # written next to the PNG files

create_pdf_from_gene_images(
    output_dir=output_dir,
    exp_memo=exp_memo,
    gene_list=gene_list,
    pdf_path=pdf_path,
    images_per_page=6,
    grid_size=(3, 2),
)

In [None]:
## This is for Stem Cell data, Time [0, 1, 2, 3, 4]
## Plot gene dynamics for each trajectory

import seaborn as sns  # Required for violin plots
import numpy as np
import matplotlib.patches as mpatches



## Subtrajectories defined by source
def Average_gene_dynamics_whole_saveonly_single_trajectory_mESC(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                              intermediate_t = [1], 
                              d_red=2, random_state=42, exp_memo = '2'):
    """Save a per-cell expression-trajectory plot with snapshot violins for one gene.

    The cells of ``mats[0]`` are projected with the stored PCA, transported with
    ``time_integration`` and mapped back with ``pca.inverse_transform``; the
    column of ``gene_of_interest`` is drawn as one line per cell.  Violin plots
    of the PCA-reconstructed snapshots (source, intermediates, target) are
    overlaid at the fixed x positions ``[0, 1, 2, 3, 4]``.  Two PNG files are
    written to the notebook-level ``output_dir``: the main figure and a
    separate legend.

    Reads the notebook-level names ``result_dir``, ``data_dir``,
    ``dim_red_method``, ``mats``, ``df_reduced_emt`` and ``output_dir``.

    Parameters:
        source_t: Key/index into ``mats`` of the first snapshot.
        target_t: Key/index into ``mats`` of the last snapshot.
        optimal_k: Unused; it only fed a KMeans fit whose labels were never
            read (cluster labels come from the saved CSV).  Kept so existing
            calls keep working.
        gene_of_interest: Column label looked up in ``df_reduced_emt``.
        index: Stride over the stored trajectory steps.
        max_i: Largest trajectory step index that is plotted.
        intermediate_t: Snapshots shown between source and target; when empty,
            every integer strictly between ``source_t`` and ``target_t`` is used.
        d_red: PCA dimension, used to pick the PCA pickle file.
        random_state: Unused; kept so existing calls keep working.
        exp_memo: Experiment name; selects the weight pickle and the cluster CSV.

    Raises:
        ValueError: ``dim_red_method`` has no PCA file, or the cluster file
            holds fewer labels than there are transported cells.
        FileNotFoundError: The cluster-label CSV does not exist.
    """
    # Imported here because the notebook only imports this module in a later
    # cell, so a fresh run would otherwise fail when building the legend.
    import matplotlib.lines as mlines

    # Fixed figure layout of this five-snapshot variant.
    violin_x_positions = np.array([0, 1, 2, 3, 4])
    violin_colors = ["black", "gray", "black", "gray", "black"]
    subtrajectory_colors = ['violet']

    # --- Cheap lookups first, so a bad input fails before the integration ---
    pca_filename_templates = {'EMT_PCA': "emt_pca_%d.pkl", 'PCA': "pca_%d.pkl"}
    if dim_red_method not in pca_filename_templates:
        raise ValueError("PCA mapping for the reduction method and dimension is not available")
    pca_filename = pca_filename_templates[dim_red_method] % d_red

    cluster_save_path = f"{result_dir}{exp_memo}_X1_hat_clusters.csv"
    if not os.path.exists(cluster_save_path):
        raise FileNotFoundError(f"Cluster labels file not found: {cluster_save_path}")

    # NOTE(review): the "- 1" presumably skips a leading non-gene column of
    # df_reduced_emt -- confirm against how `mats` was built.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # An empty selection falls back to every integer time strictly in between.
    # Kept as an ndarray so `.tolist()` below works in both cases.
    intermediate_t = np.atleast_1d(np.asarray(intermediate_t))
    if intermediate_t.size == 0:
        intermediate_t = np.arange(source_t + 1, target_t)

    # --- Load model, PCA and cluster labels ---
    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): W and b are locals here, while time_integration is called
    # below without weights; if it reads notebook-level W/b it will not see the
    # ones loaded from this pickle -- confirm.
    W, b, p = load_W(filename)

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    df_clusters = pd.read_csv(cluster_save_path)
    X1_hat_labels = df_clusters["Cluster_Label"].values

    # --- Transport the cells in PCA space over 200 equal steps ---
    # NOTE(review): integration always starts from mats[0], not mats[source_t].
    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    unique_labels = np.unique(X1_hat_labels)
    print(f"Number of unique labels in X1_hat_labels: {len(unique_labels)}")
    print(f"Unique labels: {unique_labels}")
    print(sorted([source_t] + intermediate_t.tolist() + [target_t]))

    def snapshot_expression(day):
        # Expression of the gene in one snapshot after a PCA round trip, so the
        # violins live in the same reconstructed space as the trajectories.
        return pca.inverse_transform(pca.transform(mats[day]))[:, gene_index]

    violin_data = [
        snapshot_expression(source_t),
        *[snapshot_expression(t) for t in intermediate_t],
        snapshot_expression(target_t),
    ]

    # --- Expression of the gene along each transported cell's path ---
    step_indices = range(0, len(X1_trpts), index)
    x_positions = np.linspace(0, 4, len(step_indices))  # spans the [0, 4] time axis

    expression_per_step = []
    for time_idx in step_indices:
        if time_idx > max_i:
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break  # the integration diverged; keep only the finite prefix
        expression_per_step.append(pca.inverse_transform(X1_trpt)[:, gene_index])

    subgroup_color_map = {
        label: subtrajectory_colors[i % len(subtrajectory_colors)]
        for i, label in enumerate(unique_labels)
    }
    subgroup_output_file = f"{output_dir}/Individual_trajectories_violin_plot_{gene_of_interest}.png"

    # --- Main figure ---
    fig, ax1 = plt.subplots(figsize=(12, 7))
    try:
        if expression_per_step:
            trajectories = np.stack(expression_per_step)  # (kept steps, cells)
            n_cells = trajectories.shape[1]
            if len(X1_hat_labels) < n_cells:
                raise ValueError(
                    f"{cluster_save_path} holds {len(X1_hat_labels)} labels "
                    f"for {n_cells} transported cells"
                )
            cell_labels = X1_hat_labels[:n_cells]
            # A truncated trajectory (max_i or NaN) is drawn over the matching
            # prefix of the time axis instead of being dropped altogether.
            trajectory_x = x_positions[:len(trajectories)]
            for label in unique_labels:
                members = cell_labels == label
                if members.any():
                    # One line per selected cell (column).
                    ax1.plot(
                        trajectory_x, trajectories[:, members],
                        color=subgroup_color_map[label],
                        alpha=0.7, linewidth=1.0
                    )

        # NOTE(review): zip stops at the five hard-coded positions, so any
        # snapshot beyond the fifth is not drawn.
        for x_pos, data, color in zip(violin_x_positions, violin_data, violin_colors):
            sns.violinplot(
                data=[data],
                ax=ax1,
                inner=None,
                linewidth=1.2,
                width=0.7,
                cut=0,
                scale="width",
                color=color,
                alpha=0.8,
                zorder=3
            )
            # The violin is drawn at a default position; recentre the most
            # recently added collection on the intended time.
            for violin in ax1.collections[-1:]:
                for path in violin.get_paths():
                    path.vertices[:, 0] += x_pos - path.vertices[:, 0].mean()

        # Leave room so the outer violins are not clipped.
        ax1.set_xlim(-0.5, 4.5)
        ax1.set_xticks([0, 1, 2, 3, 4])
        ax1.set_xticklabels([0, 1, 2, 3, 4], fontsize=35)
        ax1.tick_params(axis='y', labelsize=35)

        ax1.set_xlabel('Time', fontsize=35)
        ax1.set_ylabel('Gene Expression', fontsize=35)
        ax1.set_title(f'Single Cell {gene_of_interest} Expression Dynamics', fontsize=35)

        # The legend goes into its own file, so the main figure has none.
        fig.savefig(subgroup_output_file, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure: the caller loops over many genes and
        # swallows exceptions.
        plt.close(fig)

    # --- Stand-alone horizontal legend ---
    combined_legend = [
        mlines.Line2D([], [], color="violet", linestyle="-", linewidth=3,
                      label="Gene dynamics of each single cell"),
        mpatches.Patch(color="black", label="Input Data"),
        mpatches.Patch(color="gray", label="Test Data"),
    ]

    fig_legend, ax_legend = plt.subplots(figsize=(10, 2))
    try:
        ax_legend.axis("off")
        ax_legend.legend(
            handles=combined_legend,
            loc="center", fontsize=24, title="",
            title_fontsize=24, ncol=len(combined_legend),  # single row
            frameon=True, handletextpad=2, columnspacing=2
        )
        legend_output_file = subgroup_output_file.replace(".png", "_legend.png")
        fig_legend.savefig(legend_output_file, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig_legend)
    



    
    

In [None]:
# Plot for Stem cell data
# Saves one trajectory/violin figure (plus legend) per gene into output_dir.
# NOTE(review): the cell is labelled stem-cell data while exp_memo names an 'EMT' run -- confirm.

genes_of_interest = gene_names # NANOG, SOX2
source_t, target_t = 0, 4
optimal_k = 2
index = 1
max_i = 200
intermediate_t = [1,2,3]
#intermediate_t = [4]

d_red= 2
random_state = 40
exp_memo = 'EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64'

result_dir = '%s/assets/Transport_genes/' % main_dir

# Assign output_dir BEFORE creating it: the reverse order raises NameError on
# a fresh kernel, or creates the directory of a previously run experiment.
output_dir = os.path.join(result_dir, 'output', exp_memo)

# Create the directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)


# Iterate over each gene in the list
for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        # Call the function with the current gene
        Average_gene_dynamics_whole_saveonly_single_trajectory_mESC(
            source_t, target_t, optimal_k, gene, index, max_i,
            intermediate_t=intermediate_t, d_red=d_red,
            random_state=random_state, exp_memo=exp_memo
        )
    except Exception as e:
        # Best effort: report the failure and continue with the next gene
        print(f"Error processing gene {gene}: {e}")


In [None]:
## save the gene expression dynamics png as pdf for stem cell data


import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from math import ceil
import matplotlib.lines as mlines


def create_pdf_from_gene_images(output_dir, exp_memo, gene_list, pdf_path, images_per_page=25, grid_size=(5, 5),
                                image_prefix="Individual_trajectories_violin_plot"):
    """
    Collect per-gene PNG plots into a multi-page PDF, laid out on a subplot grid.

    Parameters:
        output_dir (str): Directory containing the PNG files.
        exp_memo (str): Not used by this function; kept so existing calls keep working.
        gene_list (list): Gene names. For each gene the file
            ``{output_dir}/{image_prefix}_{gene}.png`` is looked up.
        pdf_path (str): Path to save the output PDF file.
        images_per_page (int): Maximum number of images per page (default: 25).
            Capped at ``rows * cols`` so no image is skipped or repeated.
        grid_size (tuple): Grid size (rows, cols) for each page (default: 5x5).
        image_prefix (str): Filename stem that precedes ``_{gene}.png``.

    Returns:
        None. Missing PNGs are reported and skipped; if none exist, no PDF is written.

    Raises:
        ValueError: If ``images_per_page`` / ``grid_size`` leave no room for an image.
    """
    n_rows, n_cols = grid_size
    # A page can never hold more images than it has grid cells.
    per_page = min(images_per_page, n_rows * n_cols)
    if per_page < 1:
        raise ValueError("images_per_page and grid_size must allow at least one image per page.")

    # Split the expected PNG paths into those present on disk and those missing.
    existing_files, missing_files = [], []
    for gene in gene_list:
        png_path = f"{output_dir}/{image_prefix}_{gene}.png"
        (existing_files if os.path.exists(png_path) else missing_files).append(png_path)

    if missing_files:
        print(f"Warning: The following files are missing and will be skipped:\n{missing_files}")

    if not existing_files:
        # Nothing to draw: do not report a PDF as saved when it would be empty.
        print(f"No images found in {output_dir}; no PDF written.")
        return

    total_pages = ceil(len(existing_files) / per_page)

    with PdfPages(pdf_path) as pdf:
        for page in range(total_pages):
            # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)
            page_files = existing_files[page * per_page:(page + 1) * per_page]

            for slot, ax in enumerate(axes.flat):
                if slot < len(page_files):
                    # aspect='auto' stretches the image to fill its grid cell.
                    ax.imshow(plt.imread(page_files[slot]), aspect='auto')
                ax.axis('off')  # hide axes for both filled and empty cells

            pdf.savefig(fig, dpi=300, bbox_inches='tight')
            plt.close(fig)  # Close the figure to free memory

    print(f"✅ PDF saved to {pdf_path} with original image resolution.")

# Bundle the stem-cell per-gene PNGs into a single PDF, six per page (3 rows x 2 columns).

gene_list = gene_names  # List of genes
exp_memo = "EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64"
pdf_path = f"{output_dir}/Individual_trajectories_violin_plot.pdf"  # Output PDF path

create_pdf_from_gene_images(
    output_dir=output_dir,
    exp_memo=exp_memo,
    gene_list=gene_list,
    pdf_path=pdf_path,
    images_per_page=6,
    grid_size=(3, 2),
)

In [None]:
## This is for the breast cancer cell line data, Time [0, 4]
# Plot gene dynamics for each trajectory

import seaborn as sns  # Required for violin plots
import numpy as np
import matplotlib.patches as mpatches




## Subtrajectories defined by source
def Average_gene_dynamics_whole_saveonly_single_trajectory_NDPR(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                              intermediate_t = (1,),
                              d_red=2, random_state=42, exp_memo = '2'):
    """
    Plot the per-cell dynamics of one gene along the transported trajectories and save the figures.

    The trajectories are integrated in the PCA latent space, mapped back with
    ``pca.inverse_transform`` and drawn as thin lines over an x-axis fixed to
    [0, 4]. Violins of the (PCA-reconstructed) source and target snapshots are
    drawn at x=0 ("Pre-treatment") and x=4 ("Post-treatment"). Two PNGs are
    written into the module-level ``output_dir``: the main figure and a
    separate horizontal legend.

    Parameters:
        source_t: Key/index into ``mats`` for the source snapshot (left violin).
        target_t: Key/index into ``mats`` for the target snapshot (right violin).
        optimal_k: Not used by this function; kept for call compatibility.
        gene_of_interest (str): Column name in ``df_reduced_emt``; also used in the title and filename.
        index (int): Stride over the integrated time steps.
        max_i (int): Largest time-step index that is plotted.
        intermediate_t: Intermediate time points. Only used for the printed list of
            time points; this two-violin variant does not plot them. If empty,
            ``source_t+1 .. target_t-1`` is used.
        d_red (int): PCA dimension, used to pick the pickled PCA file.
        random_state: Not used by this function; kept for call compatibility.
        exp_memo (str): Experiment name; selects the weights pickle and the cluster-label CSV.

    Returns:
        None.

    Raises:
        ValueError: If ``dim_red_method`` has no PCA mapping, or there are fewer
            cluster labels than transported cells.
        FileNotFoundError: If the cluster-label CSV does not exist.
    """
    # Imported locally because the cell defining this function does not import it.
    import matplotlib.lines as mlines

    # --- Load network parameters and the PCA mapping -------------------------
    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): `time_integration(x0, T, dt)` takes no weights argument, so it
    # presumably reads module-level W/b rather than the ones loaded here —
    # confirm that they correspond to `exp_memo`.
    _, _, p = load_W(filename)

    if dim_red_method == 'EMT_PCA':
        pca_filename = "emt_pca_%d.pkl" % d_red
    elif dim_red_method == 'PCA':
        pca_filename = "pca_%d.pkl" % d_red
    else:
        # Fail here with a clear message instead of an unbound-name error below.
        raise ValueError("PCA mapping for the reduction method and dimension is not available")

    with open(data_dir + pca_filename,"rb") as fr:
        [pca] = pk.load(fr)

    # --- Integrate particle trajectories in latent space ---------------------
    final_numerical_t = p['numerical_ts'][-1]
    dt = final_numerical_t/200
    # NOTE(review): integration always starts from mats[0], independent of source_t.
    X1_trpts = time_integration(pca.transform(mats[0]), T = final_numerical_t, dt = dt)

    intermediate_t = np.array(intermediate_t)
    if len(intermediate_t) == 0:
        # An ndarray (not `range`) so `.tolist()` below keeps working.
        intermediate_t = np.arange(source_t+1, target_t)

    day1, day2 = source_t, target_t

    # --- Load previously saved cluster labels --------------------------------
    cluster_save_path = f"{result_dir}{exp_memo}_X1_hat_deviation.csv"
    if not os.path.exists(cluster_save_path):
        raise FileNotFoundError(f"Cluster labels file not found: {cluster_save_path}")

    X1_hat_labels = np.asarray(pd.read_csv(cluster_save_path)["Cluster_Label"].values)
    unique_labels = np.unique(X1_hat_labels)
    print(f"Number of unique labels in X1_hat_labels: {len(unique_labels)}")
    print(f"Unique labels: {unique_labels}")

    # --- Gene expression of the observed snapshots ---------------------------
    # The -1 presumably skips a leading non-gene column of df_reduced_emt — TODO confirm.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # Snapshots are projected to the latent space and back so that they are
    # comparable with the reconstructed trajectories.
    gene_expression_X1 = pca.inverse_transform(pca.transform(mats[source_t]))[:, gene_index]
    gene_expression_X2 = pca.inverse_transform(pca.transform(mats[target_t]))[:, gene_index]

    print(sorted([day1] + intermediate_t.tolist() + [day2]))

    # --- Gene expression along every cell trajectory -------------------------
    indices = range(0, len(X1_trpts), index)
    x_positions = np.linspace(0, 4, len(indices))  # time steps spread over the [0, 4] axis

    per_step_values = []
    for time_idx in indices:
        if time_idx > max_i:
            break
        state = X1_trpts[time_idx]
        if np.isnan(state).any():
            break  # stop at the first step containing NaN
        per_step_values.append(pca.inverse_transform(state)[:, gene_index])

    trajectories = None
    if per_step_values:
        trajectories = np.column_stack(per_step_values)  # rows: cells, columns: time steps
        if len(X1_hat_labels) < trajectories.shape[0]:
            raise ValueError(
                f"Cluster labels file has {len(X1_hat_labels)} labels "
                f"but {trajectories.shape[0]} cells were transported."
            )

    # Colours for trajectories (cycled over subgroup labels) and for the two violins.
    subtrajectory_colors = ['violet']
    violin_colors = ["black", "black"]
    subgroup_color_map = {label: subtrajectory_colors[i % len(subtrajectory_colors)] for i, label in enumerate(unique_labels)}

    subgroup_output_file = f"{output_dir}/Celltypes_deviated_trajectories_violin_plot_{gene_of_interest}.png"
    legend_output_file = f"{output_dir}/Celltypes_deviated_trajectories_violin_plot_{gene_of_interest}_legend.png"

    # --- Main figure ----------------------------------------------------------
    fig, ax1 = plt.subplots(figsize=(12, 7))
    try:
        if trajectories is not None:
            n_cells, n_steps = trajectories.shape
            # A truncated run (max_i / NaN) is drawn over the matching prefix of
            # the time axis instead of being dropped altogether.
            xs = x_positions[:n_steps]
            cell_labels = X1_hat_labels[:n_cells]
            for label in unique_labels:
                members = np.asarray(cell_labels == label, dtype=bool)
                if members.any():
                    # One line per cell: each column of the 2D array is a data set.
                    ax1.plot(
                        xs, trajectories[members].T,
                        color=subgroup_color_map[label],
                        alpha=0.1, linewidth=0.8
                    )

        # Violins: source snapshot at x=0, target snapshot at x=4.
        violin_specs = zip((0, 4), (gene_expression_X1, gene_expression_X2), violin_colors)
        for x_pos, data, color in violin_specs:
            sns.violinplot(
                data=[data],
                ax=ax1,
                inner=None,
                linewidth=1.2,
                width=0.7,
                cut=0,
                scale="width",
                color=color,
                alpha=0.8,
                zorder=3
            )

            # seaborn draws the violin at its own position; shift the collection
            # it just added so that it is centred on x_pos.
            for violin in ax1.collections[-1:]:
                for path in violin.get_paths():
                    path.vertices[:, 0] += x_pos - path.vertices[:, 0].mean()

        # Widen the x-limits so the violins at both ends are not clipped.
        ax1.set_xlim(-0.5, 4.5)

        ax1.set_xticks([0, 4])
        ax1.set_xticklabels(["Pre-treatment", "Post-treatment"], fontsize=46)
        ax1.tick_params(axis='y', labelsize=46)

        ax1.set_xlabel('Time', fontsize=46)
        ax1.set_ylabel('Gene Expression', fontsize=46)
        ax1.set_title(f'{gene_of_interest}', fontsize=46)

        # The main figure is saved without a legend.
        fig.savefig(subgroup_output_file, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)  # release the figure even if plotting failed

    # --- Separate legend figure (horizontal layout) --------------------------
    combined_legend = [
        mlines.Line2D([], [], color="violet", linestyle="-", linewidth=3,
                      label="Hallmark dynamics of each single cell"),
        mpatches.Patch(color="black", label="Input Data"),
    ]

    fig_legend, ax_legend = plt.subplots(figsize=(10, 2))
    try:
        ax_legend.axis("off")
        ax_legend.legend(
            handles=combined_legend,
            loc="center", fontsize=24, title="",
            title_fontsize=24, ncol=len(combined_legend),
            frameon=True, handletextpad=2, columnspacing=2
        )
        fig_legend.savefig(legend_output_file, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig_legend)


In [None]:
## Plot the results for breast cancer cell line data: one figure per gene.

genes_of_interest = gene_names  # every gene in the dataset
source_t, target_t = 0, 4
optimal_k = 2
index = 1
max_i = 200
#intermediate_t = [1,2,3]
intermediate_t = [4]

d_red= 2
random_state = 40
exp_memo = 'Palbo_NDPR_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64'

result_dir = '%s/assets/Transport_genes/' % main_dir

# Resolve the output directory first, then create it. Creating it before the
# assignment either raises NameError on a fresh kernel or silently creates the
# directory of a previously executed experiment.
output_dir = os.path.join(result_dir, 'output', exp_memo)
os.makedirs(output_dir, exist_ok=True)


# Iterate over each gene in the list
for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        # Call the function with the current gene
        Average_gene_dynamics_whole_saveonly_single_trajectory_NDPR(
            source_t, target_t, optimal_k, gene, index, max_i,
            intermediate_t=intermediate_t, d_red=d_red,
            random_state=random_state, exp_memo=exp_memo
        )
    except Exception as e:
        # Best-effort batch run: report the failing gene and keep going.
        print(f"Error processing gene {gene}: {e}")

In [None]:
## save the gene expression dynamics png as pdf - Breast cancer cell line data


import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from math import ceil
import matplotlib.lines as mlines


def create_pdf_from_gene_images(output_dir, exp_memo, gene_list, pdf_path, images_per_page=25, grid_size=(5, 5),
                                image_prefix="Celltypes_deviated_trajectories_violin_plot"):
    """
    Collect per-gene PNG plots into a multi-page PDF, laid out on a subplot grid.

    Parameters:
        output_dir (str): Directory containing the PNG files.
        exp_memo (str): Not used by this function; kept so existing calls keep working.
        gene_list (list): Gene names. For each gene the file
            ``{output_dir}/{image_prefix}_{gene}.png`` is looked up.
        pdf_path (str): Path to save the output PDF file.
        images_per_page (int): Maximum number of images per page (default: 25).
            Capped at ``rows * cols`` so no image is skipped or repeated.
        grid_size (tuple): Grid size (rows, cols) for each page (default: 5x5).
        image_prefix (str): Filename stem that precedes ``_{gene}.png``.

    Returns:
        None. Missing PNGs are reported and skipped; if none exist, no PDF is written.

    Raises:
        ValueError: If ``images_per_page`` / ``grid_size`` leave no room for an image.
    """
    n_rows, n_cols = grid_size
    # A page can never hold more images than it has grid cells.
    per_page = min(images_per_page, n_rows * n_cols)
    if per_page < 1:
        raise ValueError("images_per_page and grid_size must allow at least one image per page.")

    # Split the expected PNG paths into those present on disk and those missing.
    existing_files, missing_files = [], []
    for gene in gene_list:
        png_path = f"{output_dir}/{image_prefix}_{gene}.png"
        (existing_files if os.path.exists(png_path) else missing_files).append(png_path)

    if missing_files:
        print(f"Warning: The following files are missing and will be skipped:\n{missing_files}")

    if not existing_files:
        # Nothing to draw: do not report a PDF as saved when it would be empty.
        print(f"No images found in {output_dir}; no PDF written.")
        return

    total_pages = ceil(len(existing_files) / per_page)

    with PdfPages(pdf_path) as pdf:
        for page in range(total_pages):
            # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)
            page_files = existing_files[page * per_page:(page + 1) * per_page]

            for slot, ax in enumerate(axes.flat):
                if slot < len(page_files):
                    # aspect='auto' stretches the image to fill its grid cell.
                    ax.imshow(plt.imread(page_files[slot]), aspect='auto')
                ax.axis('off')  # hide axes for both filled and empty cells

            pdf.savefig(fig, dpi=300, bbox_inches='tight')
            plt.close(fig)  # Close the figure to free memory

    print(f"✅ PDF saved to {pdf_path} with original image resolution.")

# Bundle the breast cancer cell line per-gene PNGs into a single PDF, six per page (3 rows x 2 columns).

gene_list = gene_names  # List of genes
exp_memo = "Palbo_NDPR_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64"
pdf_path = f"{output_dir}/Celltypes_deviated_trajectories_violin_plot.pdf"  # Output PDF path

create_pdf_from_gene_images(
    output_dir=output_dir,
    exp_memo=exp_memo,
    gene_list=gene_list,
    pdf_path=pdf_path,
    images_per_page=6,
    grid_size=(3, 2),
)

In [None]:
## This is for the clinical data, Time [0, 4]
## Plot gene dynamics for each trajectory (deviation cell types)

import seaborn as sns  # Required for violin plots
import numpy as np
import matplotlib.patches as mpatches



## Subtrajectories defined by source
def Average_gene_dynamics_whole_saveonly_single_trajectory_clinical(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                              intermediate_t = [1], 
                              d_red=2, random_state=42, exp_memo = '2'):

    filename = result_dir + exp_memo + ".pickle"
    W, b, p = load_W(filename)
    
    # load PCA
    if dim_red_method == 'EMT_PCA':
        pca_filename = "emt_pca_%d.pkl" % d_red
    elif dim_red_method == 'PCA':
        pca_filename = "pca_%d.pkl" % d_red
    else:
        print("PCA mapping for the reduction method and dimension is not available")
    
    with open(data_dir + pca_filename,"rb") as fr:
        [pca] = pk.load(fr)
    
    dt = p['numerical_ts'][-1]/200
    X1_trpts = time_integration(pca.transform(mats[0]), T = p['numerical_ts'][-1], dt = dt)
    
    physical_dt = dt * p['ts'][-1] / p['numerical_ts'][-1]
    
    intermediate_t = np.array(intermediate_t)
    
    if len(intermediate_t) == 0:
        intermediate_t = range(source_t+1, target_t)
        
    # data parameters
    day1, day2 = source_t, target_t


    # --------
    N_source = N_samples_cls[day1]
    N_target = N_samples_cls[day2]
        

    X1_trpt = X1_trpts[-1]
    
    
    contrast_colors = [
    '#1f77b4',  # blue
    '#2ca02c',  # green
    '#ff7f0e',  # orange
    '#8c564b',  # brown
    '#d62728',  # red 
    '#9467bd'  # purple (to be used for index 8)
    ]

    # Create a color mapping for the specific indices
    colors = {0: contrast_colors[0], 1: contrast_colors[1], 2: contrast_colors[2], 3: contrast_colors[3], 4: contrast_colors[4], 8: contrast_colors[5]}

    
    # Step 1: Perform clustering analysis on the last day's cell states from mats
    
    # Load previously saved cluster labels
    cluster_save_path = f"{result_dir}{exp_memo}_X1_hat_deviation.csv"
    if not os.path.exists(cluster_save_path):
        raise FileNotFoundError(f"Cluster labels file not found: {cluster_save_path}")
    
    df_clusters = pd.read_csv(cluster_save_path)
    X1_hat_labels = df_clusters["Cluster_Label"].values  # Load saved labels

    # Print the number of unique labels in last_day_labels
    unique_labels = np.unique(X1_hat_labels)
    print(f"Number of unique labels in X1_hat_labels: {len(unique_labels)}")
    print(f"Unique labels: {unique_labels}")
    
    # Define a function to create colors for the subgroups using a predefined set of colors
    def get_subgroup_colors(labels, colors):
        unique_labels = np.unique(labels)
        if len(colors) < len(unique_labels):
            raise ValueError("Not enough colors for the number of unique labels.")
        subgroup_colors = {label: colors[i] for i, label in enumerate(unique_labels)}
        return subgroup_colors

    # Define specific sets of colors for the blue and red subgroups
    blue_colors = ['#1f77b4', '#878ceb', '#104E8B', '#87CEEB', '#4682B4', '#6495ED', '#5F9EA0']  # Add more shades of blue as needed
    red_colors = ['#d62728',  '#eb8787', '#FF4500', '#DC143C', '#FF6347', '#B22222', '#8B0000']  # Add more shades of red as needed
    light_red_colors = ['#f99fa1', '#ffb1b1', '#ffaf86', '#f48585', '#ffb5a5', '#ff9c9c', '#ff5f5f']
    
    # Get the subgroup colors based on the labels
    subgroup_colors_blue = get_subgroup_colors(X1_hat_labels, blue_colors)
    subgroup_colors_red = get_subgroup_colors(X1_hat_labels, red_colors)
    
    
    # Extract the gene index for the gene of interest
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1
    
    # Extract gene expression values from mats[day1], intermediate time points, and mats[day2]
    X1_vis_pca = pca.transform(mats[source_t])
    X1_vis_i_pca = pca.inverse_transform(X1_vis_pca)
    X2_vis_pca = pca.transform(mats[target_t])
    X2_vis_i_pca = pca.inverse_transform(X2_vis_pca)

    gene_expression_X1 = X1_vis_i_pca[:, gene_index]
    gene_expression_X2 = X2_vis_i_pca[:, gene_index]

    gene_expression_intermediates = []
    for t in intermediate_t:
        X1_intermediate_vis_pca = pca.transform(mats[t])
        X1_intermediate_vis_i_pca = pca.inverse_transform(X1_intermediate_vis_pca)
        gene_expression_intermediates.append(X1_intermediate_vis_i_pca[:, gene_index])

    # Extract gene expression values from X1_trpts based on the given condition
    
    gene_expression_X1_trpts = np.concatenate([pca.inverse_transform(X1_trpt)[:, gene_index] for i, X1_trpt in enumerate(X1_trpts) if i % index == 0 and i <= max_i])
    
    # Combine all gene expression values
    all_gene_expression_values = np.concatenate([gene_expression_X1, *gene_expression_intermediates, gene_expression_X2, gene_expression_X1_trpts])

    gene_expression_X1_normalized = gene_expression_X1
    gene_expression_intermediates_normalized = gene_expression_intermediates
    gene_expression_X2_normalized = gene_expression_X2
    gene_expression_X1_trpts_normalized = gene_expression_X1_trpts
    
    vmin = all_gene_expression_values.min()
    vmax = all_gene_expression_values.max()
    
    # Plot dynamics for X1_trpts with subgroup colors
    indices = range(len(X1_trpts))

    all_gene_expression_values_normalized_X1 = gene_expression_X1_trpts_normalized
    

    
    # (1) Plot the averaged gene expressions across X1_trpt at each time point with confidence intervals
    
    # Compute the average gene expression and confidence intervals
    avg_gene_expressions = []
    ci_gene_expressions = []
    
    # Reset normalized gene expression values for X1_trpts
    all_gene_expression_values_normalized_X1 = gene_expression_X1_trpts_normalized
    
    # Use indices with the specified step size defined by `index`
    indices = range(0, len(X1_trpts), index)

    
    # Iterate through indices to compute averages and confidence intervals
    for i in indices:
        if i > max_i:  # Apply truncation based on max_i
            break
        X1_trpt = X1_trpts[i]
        if np.isnan(X1_trpt).any():
            break
    
        # Inverse transform the current trajectory
        X1_hat = pca.inverse_transform(X1_trpt)
    
        # Extract gene expression values for the current step
        gene_expression_values = all_gene_expression_values_normalized_X1[:len(X1_hat)]
        all_gene_expression_values_normalized_X1 = all_gene_expression_values_normalized_X1[len(X1_hat):]  # Update to exclude used values
    
        # Compute average and confidence interval
        avg_gene_expressions.append(np.mean(gene_expression_values))
        ci = stats.sem(gene_expression_values) * stats.t.ppf((1 + 0.95) / 2., len(gene_expression_values) - 1)
        ci_gene_expressions.append(ci)
    
    # Process intermediate time points
    intermediate_avg_expressions = []
    intermediate_ci_expressions = []
    intermediate_indices = []


    for idx, t in enumerate(intermediate_t):
        gene_expression_intermediate = gene_expression_intermediates_normalized[idx]
        intermediate_avg_expressions.append(np.mean(gene_expression_intermediate))
        ci = stats.sem(gene_expression_intermediate) * stats.t.ppf((1 + 0.95) / 2., len(gene_expression_intermediate) - 1)
        intermediate_ci_expressions.append(ci)
    
        # Rescale the intermediate time points to align with `index`
        shifted_value_1 = intermediate_t - 1
        shifted_value_2 = intermediate_t[0] - 1
        shifted_t_1 = t - shifted_value_1
        shifted_t_2 = t - shifted_value_2
        time_index = int((float(shifted_t_2) / (float(max(shifted_t_1)) + 1)) * len(indices))
        intermediate_indices.append(time_index)

    
    # Include first and last time points
    all_avg_expressions = [np.mean(gene_expression_X1_normalized)] + intermediate_avg_expressions + [np.mean(gene_expression_X2_normalized)]
    all_ci_expressions = [
        stats.sem(gene_expression_X1_normalized) * stats.t.ppf((1 + 0.95) / 2., len(gene_expression_X1_normalized) - 1)
    ] + intermediate_ci_expressions + [
        stats.sem(gene_expression_X2_normalized) * stats.t.ppf((1 + 0.95) / 2., len(gene_expression_X2_normalized) - 1)
    ]

        
    all_indices = [0] + intermediate_indices + [len(indices)]
    combined_indices = sorted([day1] + intermediate_t.tolist() + [day2])

    print(combined_indices)

    
    # Ensure extended_indices align with avg_gene_expressions
    extended_indices = np.array([x * index for x in range(len(avg_gene_expressions))])
    
    # Ensure all_indices and extended_indices are NumPy arrays
    combined_indices = np.array(combined_indices)
    extended_indices = np.array(extended_indices)
    
    # Linearly rescale all_indices to be equally distributed in extended_indices
    rescaled_indices = np.interp(
        combined_indices,  # Original indices
        [combined_indices[0], combined_indices[-1]],  # Range of all_indices
        [extended_indices[0], extended_indices[-1]]  # Range of extended_indices
    )

    # Define the filename for saving the plot




 
    # (1) **Assign Labels for Subgroups Based on Step 1**

    
    # Define **subtrajectory colors** (for cell trajectories)
    real_cell_types = np.array(cell_ids_by_day[day2])
    #unique_cell_types = np.unique(real_cell_types)
    unique_cell_types = unique_labels
    #subtrajectory_colors = list(sns.color_palette("tab20", len(unique_cell_types)))
    subtrajectory_colors = ['green', 'orange', 'purple', 'blue', 'red', 'brown']
    #subtrajectory_colors = ['violet']
    
    # Define **violin plot colors** for the three time points
    violin_colors = ["black", "black"]  # Green, Orange, Purple
    
    # Map each subgroup label to a **trajectory color** and shift labels from 0,1 → 1,2
    unique_labels = np.unique(X1_hat_labels)
    subgroup_color_map = {label: subtrajectory_colors[i % len(subtrajectory_colors)] for i, label in enumerate(unique_labels)}
    label_mapping = {old_label: new_label + 1 for new_label, old_label in enumerate(unique_labels)}
    
    # Define filename for saving
    subgroup_output_file = f"{output_dir}/Celltypes_deviated_trajectories_violin_plot_{gene_of_interest}.png"
    
    # (2) **Create Figure**
    fig, ax1 = plt.subplots(figsize=(12, 7))
    
    # (3) **Ensure Proper x-axis Scaling**
    num_points = len(indices)
    x_positions = np.linspace(0, 4, num_points)  # Scale to match `[0, 2, 4]`
    
    # (4) **Extract Cell Trajectories for Each Gene**
    cell_trajectories = {cell_idx: [] for cell_idx in range(X1_trpts[0].shape[0])}
    
    for i, time_idx in enumerate(indices):
        if time_idx > max_i:
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break
    
        # Extract **expression values of the gene of interest** from each cell at this time point
        gene_expression_values = pca.inverse_transform(X1_trpt)[:, gene_index]
    
        # Append the expression value at this time to each cell’s trajectory
        for cell_idx, expr_value in enumerate(gene_expression_values):
            cell_trajectories[cell_idx].append(expr_value)
    
    # (5) **Plot Individual Trajectories per Subgroup**
    legend_patches = []  # Store legend handles
    for label in unique_labels:
        first_plotted = False  # Track if we added a legend entry for this subgroup
        
        for cell_idx, traj in cell_trajectories.items():
            if len(traj) != len(x_positions):
                continue  # Ensure trajectories align with time points
    
            if X1_hat_labels[cell_idx] == label:  # Match subgroup label from step 1
                ax1.plot(
                    x_positions, traj,  
                    color=subgroup_color_map[label],  # ✅ Use the **subtrajectory colors**
                    alpha=0.1, linewidth=0.8  
                )
                
                # Add a single legend entry for each subgroup (renaming from 0,1 → 1,2)
                if not first_plotted:
                    legend_patches.append(mpatches.Patch(color=subgroup_color_map[label], label=f'Trajectory of {label_mapping[label]} phenotypic shift'))
                    first_plotted = True
    
    # (6) **Ensure Violin Plots are at `[0, 2, 4]` & Appear in Front**
    violin_data = [
        gene_expression_X1_normalized,
        *gene_expression_intermediates_normalized,
        gene_expression_X2_normalized
    ]
    
    violin_x_positions = np.array([0, 4])  # Ensure correct positions
    
    # 🎻 **Plot Violin Plots with Correct Colors and Transparency**
    for i, (x_pos, data) in enumerate(zip(violin_x_positions, violin_data)):
        violin_parts = sns.violinplot(
            data=[data],  
            ax=ax1,
            inner=None,  # ✅ REMOVE QUARTILE LINES
            linewidth=1.2,
            width=0.7,
            cut=0,
            scale="width",
            color=violin_colors[i],  # ✅ Assign correct color
            alpha=0.8,  # ✅ MAKE TRANSPARENT
            zorder=3  # ✅ BRINGS VIOLINS TO THE FRONT
        )
        
        # **Manually Adjust X-Position of Each Violin**
        for violin in ax1.collections[-1:]:  # Only adjust the last added violin
            for path in violin.get_paths():
                path.vertices[:, 0] += x_pos - path.vertices[:, 0].mean()  
    
    # **Expand x-axis limits to prevent cutting off last violin plot**
    ax1.set_xlim(-0.5, 4.5)  
    
    # 🛠 **Fix x-axis labels and ensure proper alignment**
    ax1.set_xticks([0, 4])  
    #ax1.set_xticklabels([0, 4], fontsize=32)
    ax1.set_xticklabels(["Pre-treatment", "Post-treatment"], fontsize=46)
    ax1.tick_params(axis='y', labelsize=46)
    
    ax1.set_xlabel('Time', fontsize=46)
    ax1.set_ylabel('Gene Expression', fontsize=46)
    ax1.set_title(f'{gene_of_interest}', fontsize=46)


    # 🎨 **Save the main figure without a legend**
    plt.savefig(subgroup_output_file, dpi=300, bbox_inches='tight')
    plt.close()
    

    # 🎨 **Redefine `legend_patches` to Include a Green Bar**
    
    #legend_patches = [
    #    mlines.Line2D([], [], color="violet", linestyle="-", linewidth=3, 
    #                  label="Hallmark dynamics of each single cell")
    #]

    label_descriptions = {
    "low": "Trajectory of low phenotypic shift",
    "medium": "Trajectory of medium phenotypic shift",
    "high": "Trajectory of high phenotypic shift"}


    # Thicker lines using `linewidth`
    legend_patches = [
        mlines.Line2D(
            [], [], color=color, linestyle='-', linewidth=3,  # ← thicker line here
            markersize=10,
            label=f"{label_descriptions.get(ctype, '')}"
        )
        for ctype, color in zip(unique_cell_types, subtrajectory_colors)
    ]


    # 🎨 **Violin Plot Legend**
    violin_legend_patches = [
        mpatches.Patch(color="black", label="Input Data")
    ]
    
    # 🎨 **Create Separate Legend Figure (HORIZONTAL LAYOUT)**
    fig_legend, ax_legend = plt.subplots(figsize=(10, 2))  # Wider aspect ratio for horizontal layout
    ax_legend.axis("off")  # Hide axes
    
    # **Combine both legends**
    combined_legend = legend_patches + violin_legend_patches
    
    ax_legend.legend(
        handles=combined_legend,
        loc="center", fontsize=24, title="",
        title_fontsize=24, ncol=len(combined_legend),  # Horizontal layout
        frameon=True, handletextpad=2, columnspacing=2
    )
    
    # Save the separate legend
    legend_output_file = subgroup_output_file.replace(".png", "_legend.png")
    plt.savefig(legend_output_file, dpi=300, bbox_inches='tight')
    plt.close()

In [None]:
## Plot the results for clinical data


genes_of_interest = gene_names # NANOG, SOX2
source_t, target_t = 0, 4
optimal_k = 2
index = 1
max_i = 200
#intermediate_t = [1,2,3]
intermediate_t = [4]

d_red= 2
random_state = 40
exp_memo = 'Palbo_887_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64'

result_dir = '%s/assets/Transport_genes/' % main_dir

# Resolve this experiment's output folder BEFORE creating it. Calling
# os.makedirs first would either raise NameError (fresh kernel) or create the
# folder left over from a previously executed cell instead of this one.
output_dir = os.path.join(result_dir, 'output', exp_memo)
os.makedirs(output_dir, exist_ok=True)


# Iterate over each gene in the list
for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        # Call the function with the current gene
        Average_gene_dynamics_whole_saveonly_single_trajectory_clinical(
            source_t, target_t, optimal_k, gene, index, max_i,
            intermediate_t=intermediate_t, d_red=d_red,
            random_state=random_state, exp_memo=exp_memo
        )
    except Exception as e:
        # Best-effort batch run: report the failing gene and continue with the next one
        print(f"Error processing gene {gene}: {e}")


In [None]:
## Save the gene expression dynamics PNGs as a PDF - clinical data


import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from math import ceil
import matplotlib.lines as mlines


def create_pdf_from_gene_images(output_dir, exp_memo, gene_list, pdf_path, images_per_page=25, grid_size=(5, 5)):
    """
    Tile the per-gene violin-plot PNGs onto grid pages of one multi-page PDF.

    For each gene the file
    ``{output_dir}/Celltypes_deviated_trajectories_violin_plot_{gene}.png`` is looked up.
    Missing files are reported and skipped. If none of the files exists, no PDF is written.

    Parameters:
        output_dir (str): Directory containing the PNG files.
        exp_memo (str): Not used by this function; kept so existing calls keep working.
        gene_list (list): Genes whose PNG files should be collected, in page order.
        pdf_path (str): Path to save the output PDF file.
        images_per_page (int): Maximum number of images per page (default: 25).
            Values larger than the grid capacity are capped at rows * cols.
        grid_size (tuple): Grid size (rows, cols) for each page (default: 5x5).

    Raises:
        ValueError: If ``images_per_page`` or ``grid_size`` leaves no room for an image.
    """
    n_rows, n_cols = grid_size

    # A page can never show more images than it has axes, and must not show more
    # than requested; stepping by this one number keeps pages from overlapping
    # (duplicated images) or skipping ahead (dropped images).
    per_page = min(images_per_page, n_rows * n_cols)
    if per_page < 1:
        raise ValueError(
            f"images_per_page={images_per_page} and grid_size={grid_size} leave no room for images"
        )

    # Split the expected PNG paths into existing and missing ones
    png_files = []
    missing_files = []
    for gene in gene_list:
        path = f"{output_dir}/Celltypes_deviated_trajectories_violin_plot_{gene}.png"
        (png_files if os.path.exists(path) else missing_files).append(path)

    if missing_files:
        print(f"Warning: The following files are missing and will be skipped:\n{missing_files}")

    if not png_files:
        # Avoid writing (and announcing) a PDF that would contain zero pages
        print(f"No PNG files found in {output_dir}; no PDF was written.")
        return

    total_pages = ceil(len(png_files) / per_page)

    with PdfPages(pdf_path) as pdf:
        for page in range(total_pages):
            page_files = png_files[page * per_page:(page + 1) * per_page]

            # squeeze=False always returns a 2-D array of axes, even for a 1x1 grid
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)

            for slot, ax in enumerate(axes.flat):
                if slot < len(page_files):
                    # aspect='auto' stretches the image to fill its grid cell
                    ax.imshow(plt.imread(page_files[slot]), aspect='auto')
                # Panels are untitled; hide the frame on used and unused cells alike
                ax.axis('off')

            pdf.savefig(fig, dpi=300, bbox_inches='tight')
            plt.close(fig)  # Close the figure to free memory

    print(f"✅ PDF saved to {pdf_path} with original image resolution.")

# Collect the clinical violin plots written to `output_dir` into a single PDF

exp_memo = "Palbo_887_nofibroblast_malignant_Rgene_dim2-f_Lip=5e-2-t_size=50-network=64_64_64"
gene_list = gene_names  # every gene that was plotted
pdf_path = f"{output_dir}/Celltypes_deviated_trajectories_violin_plot.pdf"  # destination file

# Six plots per page, laid out as 3 rows x 2 columns
create_pdf_from_gene_images(
    output_dir=output_dir,
    exp_memo=exp_memo,
    gene_list=gene_list,
    pdf_path=pdf_path,
    images_per_page=6,
    grid_size=(3, 2),
)

## Comparison of Predicted and Test distributions

In [None]:
## Comparison of single-gene distributions (intermediate time points only) - Stem Cell data
from scipy.stats import gaussian_kde
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def Compare_Distribution_Trajectories_Intermediate_mESC(
    source_t, target_t, optimal_k, gene_of_interest, index, max_i,
    intermediate_t=None, d_red=2, random_state=42, exp_memo='2'):
    """
    Plot test vs. predicted KDEs of one gene at the intermediate time points.

    ``mats[0]`` is projected with the stored PCA and transported with
    ``time_integration``. For every time in ``intermediate_t`` the matching
    trajectory snapshot is mapped back to gene space and its density for
    ``gene_of_interest`` is drawn next to the density of the PCA-reconstructed
    test data ``mats[t]``. The KDE figure and a separate legend figure are saved
    as PNG files in ``output_dir``.

    Reads the notebook globals ``result_dir``, ``data_dir``, ``dim_red_method``,
    ``mats``, ``df_reduced_emt``, ``output_dir``, ``pk``, ``plt`` and ``mlines``.

    Parameters:
        source_t, target_t: End points of the transported interval; only their
            difference is used, to convert a time into a snapshot index.
        optimal_k, index, max_i, random_state: Not used by this function; kept so
            existing calls keep working.
        gene_of_interest (str): Column name in ``df_reduced_emt``.
        intermediate_t (list): Time points to compare; each value is used both as
            an index into ``mats`` and as a time. Defaults to ``[1]``.
        d_red (int): PCA dimension; selects which PCA pickle is loaded.
        exp_memo (str): Base name of the trained-model pickle in ``result_dir``.

    Raises:
        ValueError: If ``dim_red_method`` is neither ``'EMT_PCA'`` nor ``'PCA'``.
    """
    if intermediate_t is None:
        intermediate_t = [1]

    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): only `p` is used below. `time_integration` is called without
    # the weights loaded here, so it presumably reads module-level W and b --
    # confirm those belong to `exp_memo`.
    _, _, p = load_W(filename)

    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        raise ValueError("PCA mapping not available")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    total_time = p['numerical_ts'][-1]
    dt = total_time / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=total_time, dt=dt)

    # Extract gene index
    # (the -1 presumably skips a leading non-gene column of df_reduced_emt -- TODO confirm)
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # Test data distributions (passed through the PCA round trip, like the predictions)
    kde_test_data = [
        pca.inverse_transform(pca.transform(mats[t]))[:, gene_index] for t in intermediate_t
    ]

    # Predicted distributions: pick the trajectory snapshot for each time point,
    # clamping to the last available snapshot
    snapshots_per_day = len(X1_trpts) / (target_t - source_t)
    last_snapshot = len(X1_trpts) - 1
    kde_predicted_data = [
        pca.inverse_transform(
            X1_trpts[min(int(day * snapshots_per_day), last_snapshot)]
        )[:, gene_index]
        for day in intermediate_t
    ]

    # Visualization setup
    num_plots = len(intermediate_t)
    fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 5), sharey=True)

    if num_plots == 1:
        axes = [axes]

    #test_data_colors = ["#2ca02c", "#8c564b"] #Sample 3
    #predicted_colors = ["#1b6420", "#5c3930"] #Sample 3

    test_data_colors = ["#2ca02c", "#8c564b", "#3cb44b"]  #Sample 1
    predicted_colors = ["#1b6420", "#5c3930", "#228B22"]  #Sample 1

    legend_handles = []
    peak_density = 0.0

    # Generate KDE plots
    for i, (ax, t, test_vals, pred_vals) in enumerate(zip(axes, intermediate_t, kde_test_data, kde_predicted_data)):
        test_color = test_data_colors[i % len(test_data_colors)]
        pred_color = predicted_colors[i % len(predicted_colors)]

        # Common evaluation grid with a 20% margin on both sides
        all_vals = np.concatenate([test_vals, pred_vals])
        x_min, x_max = np.min(all_vals), np.max(all_vals)
        x_margin = (x_max - x_min) * 0.2
        x_range = np.linspace(x_min - x_margin, x_max + x_margin, 300)

        test_density = gaussian_kde(test_vals)(x_range)
        pred_density = gaussian_kde(pred_vals)(x_range)
        peak_density = max(peak_density, test_density.max(), pred_density.max())

        ax.fill_between(x_range, test_density, color=test_color, alpha=0.5)
        ax.plot(x_range, test_density, color=test_color, linewidth=2)

        ax.fill_between(x_range, pred_density, color=pred_color, alpha=0.5)
        ax.plot(x_range, pred_density, color=pred_color, linestyle="dashed", linewidth=2)

        ax.set_title(f"Time {t}", fontsize=26)
        ax.set_xlabel("Gene Expression", fontsize=26)
        ax.set_ylabel("Density", fontsize=26)
        ax.tick_params(axis='both', which='major', labelsize=26)

        # Legend entries for this time point: test data solid, prediction dashed
        legend_handles.append(
            mlines.Line2D([], [], color=test_color, linestyle="solid", linewidth=3,
                          label=f"Test Data time {t}"))
        legend_handles.append(
            mlines.Line2D([], [], color=pred_color, linestyle="dashed", linewidth=3,
                          label=f"Predicted time {t}"))

    # The panels share one y-axis, so a per-panel limit would be overwritten by
    # the last panel and could clip taller curves in earlier panels. Use the
    # overall peak (with 2x headroom) for every panel instead.
    for ax in axes:
        ax.set_ylim(0, peak_density * 2)

    fig.suptitle(f"KDE for {gene_of_interest}", fontsize=26)

    # **Save KDE plot (without legend)**
    output_file = f"{output_dir}/KDE_Intermediate_Only_updated_{gene_of_interest}.png"
    fig.tight_layout()
    fig.savefig(output_file, dpi=300, bbox_inches="tight")
    plt.close(fig)

    print(f"KDE plot saved: {output_file}")

    # **(2) Create a Separate Figure for the Legend**
    fig_legend, ax_legend = plt.subplots(figsize=(len(intermediate_t) * 3, 2))  # Adjust width dynamically
    ax_legend.axis("off")  # Hide axes

    ax_legend.legend(
        handles=legend_handles,
        loc="center",
        fontsize=22,
        ncol=2,  # two legend columns
        frameon=True,
        handletextpad=1.5,
        columnspacing=2
    )

    # **Save the separate legend**
    legend_output_file = output_file.replace(".png", "_legend.png")
    fig_legend.savefig(legend_output_file, dpi=300, bbox_inches="tight")
    plt.close(fig_legend)

    print(f"Legend plot saved separately at: {legend_output_file}")




In [None]:
## Plot results for Stem cell data

# --- Experiment configuration ---
genes_of_interest = gene_names # NANOG, SOX2
source_t = 0
target_t = 4
optimal_k = 2
index = 1
max_i = 200
intermediate_t = [1,3]  # alternative setting: [2]

d_red= 2
random_state = 40
exp_memo = 'EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64'

# --- Output location (created on demand) ---
result_dir = '%s/assets/Transport_genes/' % main_dir
output_dir = os.path.join(result_dir, 'output', exp_memo)
os.makedirs(output_dir, exist_ok=True)

# --- One KDE comparison figure per gene; a failing gene is reported and skipped ---
for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        Compare_Distribution_Trajectories_Intermediate_mESC(
            source_t=source_t,
            target_t=target_t,
            optimal_k=optimal_k,
            gene_of_interest=gene,
            index=index,
            max_i=max_i,
            intermediate_t=intermediate_t,
            d_red=d_red,
            random_state=random_state,
            exp_memo=exp_memo,
        )
    except Exception as e:
        print(f"Error processing gene {gene}: {e}")

In [None]:
## PDF combination for comparison distributions  (Stem cell data)
## save the gene expression dynamics png as pdf


import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from math import ceil
import matplotlib.lines as mlines


def create_pdf_from_gene_images(output_dir, exp_memo, gene_list, pdf_path, images_per_page=25, grid_size=(5, 5)):
    """
    Tile the per-gene KDE comparison PNGs onto grid pages of one multi-page PDF.

    For each gene the file ``{output_dir}/KDE_Intermediate_Only_updated_{gene}.png``
    is looked up. Missing files are reported and skipped. If none of the files
    exists, no PDF is written.

    Parameters:
        output_dir (str): Directory containing the PNG files.
        exp_memo (str): Not used by this function; kept so existing calls keep working.
        gene_list (list): Genes whose PNG files should be collected, in page order.
        pdf_path (str): Path to save the output PDF file.
        images_per_page (int): Maximum number of images per page (default: 25).
            Values larger than the grid capacity are capped at rows * cols.
        grid_size (tuple): Grid size (rows, cols) for each page (default: 5x5).

    Raises:
        ValueError: If ``images_per_page`` or ``grid_size`` leaves no room for an image.
    """
    n_rows, n_cols = grid_size

    # A page can never show more images than it has axes, and must not show more
    # than requested; stepping by this one number keeps pages from overlapping
    # (duplicated images) or skipping ahead (dropped images).
    per_page = min(images_per_page, n_rows * n_cols)
    if per_page < 1:
        raise ValueError(
            f"images_per_page={images_per_page} and grid_size={grid_size} leave no room for images"
        )

    # Split the expected PNG paths into existing and missing ones
    png_files = []
    missing_files = []
    for gene in gene_list:
        path = f"{output_dir}/KDE_Intermediate_Only_updated_{gene}.png"
        (png_files if os.path.exists(path) else missing_files).append(path)

    if missing_files:
        print(f"Warning: The following files are missing and will be skipped:\n{missing_files}")

    if not png_files:
        # Avoid writing (and announcing) a PDF that would contain zero pages
        print(f"No PNG files found in {output_dir}; no PDF was written.")
        return

    total_pages = ceil(len(png_files) / per_page)

    with PdfPages(pdf_path) as pdf:
        for page in range(total_pages):
            page_files = png_files[page * per_page:(page + 1) * per_page]

            # squeeze=False always returns a 2-D array of axes, even for a 1x1 grid
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)

            for slot, ax in enumerate(axes.flat):
                if slot < len(page_files):
                    # aspect='auto' stretches the image to fill its grid cell
                    ax.imshow(plt.imread(page_files[slot]), aspect='auto')
                # Panels are untitled; hide the frame on used and unused cells alike
                ax.axis('off')

            pdf.savefig(fig, dpi=300, bbox_inches='tight')
            plt.close(fig)  # Close the figure to free memory

    print(f"✅ PDF saved to {pdf_path} with original image resolution.")

# Collect the stem-cell KDE comparison plots in `output_dir` into a single PDF

exp_memo = "EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64"
gene_list = gene_names  # every gene that was plotted
## Selected genes (no difference by TV)  ['AXL', 'HNMT', 'TMEM45B', 'SSH3', 'SHROOM3', 'PRSS22', 'SERINC2', 'EVPL', 'GALNT3', 'DSP', 'ELMO3', 'KRTCAP3', 'KRT19', 'C1orf116', 'CDS1', 'INADL']
## Selected genes (no difference by TV and KL) ['HNMT', 'TMEM45B', 'SHROOM3', 'PRSS22', 'SERINC2', 'KRTCAP3', 'C1orf116', 'CDS1']
pdf_path = f"{output_dir}/combined_gene_expression_KDE_Intermediate_Only_updated.pdf"  # destination file

# Six plots per page, laid out as 3 rows x 2 columns
create_pdf_from_gene_images(
    output_dir=output_dir,
    exp_memo=exp_memo,
    gene_list=gene_list,
    pdf_path=pdf_path,
    images_per_page=6,
    grid_size=(3, 2),
)


In [None]:
## Comparison of single-gene distributions (intermediate time points only) - EMT data
from scipy.stats import gaussian_kde
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def Compare_Distribution_Trajectories_Intermediate_EMT(
    source_t, target_t, optimal_k, gene_of_interest, index, max_i,
    intermediate_t=None, d_red=2, random_state=42, exp_memo='2'):
    """
    Plot test vs. predicted KDEs of one gene at the intermediate days.

    ``mats[0]`` is projected with the stored PCA and transported with
    ``time_integration``. For every day in ``intermediate_t`` the matching
    trajectory snapshot is mapped back to gene space and its density for
    ``gene_of_interest`` is drawn next to the density of the PCA-reconstructed
    test data ``mats[t]``. The KDE figure and a separate legend figure are saved
    as PNG files in ``output_dir``.

    Reads the notebook globals ``result_dir``, ``data_dir``, ``dim_red_method``,
    ``mats``, ``df_reduced_emt``, ``output_dir``, ``pk``, ``plt`` and ``mlines``.

    Parameters:
        source_t, target_t: End points of the transported interval; only their
            difference is used, to convert a day into a snapshot index.
        optimal_k, index, max_i, random_state: Not used by this function; kept so
            existing calls keep working.
        gene_of_interest (str): Column name in ``df_reduced_emt``.
        intermediate_t (list): Days to compare; each value is used both as an
            index into ``mats`` and as a time. Defaults to ``[1]``.
        d_red (int): PCA dimension; selects which PCA pickle is loaded.
        exp_memo (str): Base name of the trained-model pickle in ``result_dir``.

    Raises:
        ValueError: If ``dim_red_method`` is neither ``'EMT_PCA'`` nor ``'PCA'``.
    """
    if intermediate_t is None:
        intermediate_t = [1]

    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): only `p` is used below. `time_integration` is called without
    # the weights loaded here, so it presumably reads module-level W and b --
    # confirm those belong to `exp_memo`.
    _, _, p = load_W(filename)

    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        raise ValueError("PCA mapping not available")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    total_time = p['numerical_ts'][-1]
    dt = total_time / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=total_time, dt=dt)

    # Extract gene index
    # (the -1 presumably skips a leading non-gene column of df_reduced_emt -- TODO confirm)
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # Test data distributions (passed through the PCA round trip, like the predictions)
    kde_test_data = [
        pca.inverse_transform(pca.transform(mats[t]))[:, gene_index] for t in intermediate_t
    ]

    # Predicted distributions: pick the trajectory snapshot for each day,
    # clamping to the last available snapshot
    snapshots_per_day = len(X1_trpts) / (target_t - source_t)
    last_snapshot = len(X1_trpts) - 1
    kde_predicted_data = [
        pca.inverse_transform(
            X1_trpts[min(int(day * snapshots_per_day), last_snapshot)]
        )[:, gene_index]
        for day in intermediate_t
    ]

    # Visualization setup
    num_plots = len(intermediate_t)
    fig, axes = plt.subplots(1, num_plots, figsize=(6 * num_plots, 5), sharey=True)

    if num_plots == 1:
        axes = [axes]

    #test_data_colors = ["#2ca02c", "#8c564b"] #Sample 3
    #predicted_colors = ["#1b6420", "#5c3930"] #Sample 3

    test_data_colors = ["#f58231", "#911eb4", "#3cb44b"]  #Sample 1
    predicted_colors = ["#D2691E", "#800080", "#228B22"]  #Sample 1

    legend_handles = []
    peak_density = 0.0

    # Generate KDE plots
    for i, (ax, t, test_vals, pred_vals) in enumerate(zip(axes, intermediate_t, kde_test_data, kde_predicted_data)):
        test_color = test_data_colors[i % len(test_data_colors)]
        pred_color = predicted_colors[i % len(predicted_colors)]

        # Common evaluation grid with a 20% margin on both sides
        all_vals = np.concatenate([test_vals, pred_vals])
        x_min, x_max = np.min(all_vals), np.max(all_vals)
        x_margin = (x_max - x_min) * 0.2
        x_range = np.linspace(x_min - x_margin, x_max + x_margin, 300)

        test_density = gaussian_kde(test_vals)(x_range)
        pred_density = gaussian_kde(pred_vals)(x_range)
        peak_density = max(peak_density, test_density.max(), pred_density.max())

        ax.fill_between(x_range, test_density, color=test_color, alpha=0.5)
        ax.plot(x_range, test_density, color=test_color, linewidth=2)

        ax.fill_between(x_range, pred_density, color=pred_color, alpha=0.5)
        ax.plot(x_range, pred_density, color=pred_color, linestyle="dashed", linewidth=2)

        ax.set_title(f"Day {t}", fontsize=26)
        ax.set_xlabel("Gene Expression", fontsize=26)
        ax.set_ylabel("Density", fontsize=26)
        ax.tick_params(axis='both', which='major', labelsize=26)

        # Legend entries for this day: test data solid, prediction dashed
        legend_handles.append(
            mlines.Line2D([], [], color=test_color, linestyle="solid", linewidth=3,
                          label=f"Test Data day {t}"))
        legend_handles.append(
            mlines.Line2D([], [], color=pred_color, linestyle="dashed", linewidth=3,
                          label=f"Predicted day {t}"))

    # The panels share one y-axis, so a per-panel limit would be overwritten by
    # the last panel and could clip taller curves in earlier panels. Use the
    # overall peak (with 2x headroom) for every panel instead.
    for ax in axes:
        ax.set_ylim(0, peak_density * 2)

    fig.suptitle(f"KDE for {gene_of_interest}", fontsize=26)

    # **Save KDE plot (without legend)**
    output_file = f"{output_dir}/KDE_Intermediate_Only_updated_{gene_of_interest}.png"
    fig.tight_layout()
    fig.savefig(output_file, dpi=300, bbox_inches="tight")
    plt.close(fig)

    print(f"KDE plot saved: {output_file}")

    # **(2) Create a Separate Figure for the Legend**
    fig_legend, ax_legend = plt.subplots(figsize=(len(intermediate_t) * 3, 2))  # Adjust width dynamically
    ax_legend.axis("off")  # Hide axes

    ax_legend.legend(
        handles=legend_handles,
        loc="center",
        fontsize=22,
        ncol=2,  # two legend columns
        frameon=True,
        handletextpad=1.5,
        columnspacing=2
    )

    # **Save the separate legend**
    legend_output_file = output_file.replace(".png", "_legend.png")
    fig_legend.savefig(legend_output_file, dpi=300, bbox_inches="tight")
    plt.close(fig_legend)

    print(f"Legend plot saved separately at: {legend_output_file}")


In [None]:
## Plot results for EMT data

# --- Experiment configuration ---
genes_of_interest = gene_names # NANOG, SOX2
source_t = 0
target_t = 4
optimal_k = 2
index = 1
max_i = 200
intermediate_t = [2]  # alternative setting: [1,3]

d_red= 8
random_state = 40
exp_memo = '72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64'

# --- Output location (created on demand) ---
result_dir = '%s/assets/Transport_genes/' % main_dir
output_dir = os.path.join(result_dir, 'output', exp_memo)
os.makedirs(output_dir, exist_ok=True)

# --- One KDE comparison figure per gene; a failing gene is reported and skipped ---
for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        Compare_Distribution_Trajectories_Intermediate_EMT(
            source_t=source_t,
            target_t=target_t,
            optimal_k=optimal_k,
            gene_of_interest=gene,
            index=index,
            max_i=max_i,
            intermediate_t=intermediate_t,
            d_red=d_red,
            random_state=random_state,
            exp_memo=exp_memo,
        )
    except Exception as e:
        print(f"Error processing gene {gene}: {e}")


In [None]:
## PDF combination for comparison distributions  (EMT data)
## save the gene expression dynamics png as pdf


import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from math import ceil
import matplotlib.lines as mlines


def create_pdf_from_gene_images(output_dir, exp_memo, gene_list, pdf_path, images_per_page=25, grid_size=(5, 5)):
    """
    Tile the per-gene KDE comparison PNGs onto grid pages of one multi-page PDF.

    For each gene the file ``{output_dir}/KDE_Intermediate_Only_updated_{gene}.png``
    is looked up. Missing files are reported and skipped. If none of the files
    exists, no PDF is written.

    Parameters:
        output_dir (str): Directory containing the PNG files.
        exp_memo (str): Not used by this function; kept so existing calls keep working.
        gene_list (list): Genes whose PNG files should be collected, in page order.
        pdf_path (str): Path to save the output PDF file.
        images_per_page (int): Maximum number of images per page (default: 25).
            Values larger than the grid capacity are capped at rows * cols.
        grid_size (tuple): Grid size (rows, cols) for each page (default: 5x5).

    Raises:
        ValueError: If ``images_per_page`` or ``grid_size`` leaves no room for an image.
    """
    n_rows, n_cols = grid_size

    # A page can never show more images than it has axes, and must not show more
    # than requested; stepping by this one number keeps pages from overlapping
    # (duplicated images) or skipping ahead (dropped images).
    per_page = min(images_per_page, n_rows * n_cols)
    if per_page < 1:
        raise ValueError(
            f"images_per_page={images_per_page} and grid_size={grid_size} leave no room for images"
        )

    # Split the expected PNG paths into existing and missing ones
    png_files = []
    missing_files = []
    for gene in gene_list:
        path = f"{output_dir}/KDE_Intermediate_Only_updated_{gene}.png"
        (png_files if os.path.exists(path) else missing_files).append(path)

    if missing_files:
        print(f"Warning: The following files are missing and will be skipped:\n{missing_files}")

    if not png_files:
        # Avoid writing (and announcing) a PDF that would contain zero pages
        print(f"No PNG files found in {output_dir}; no PDF was written.")
        return

    total_pages = ceil(len(png_files) / per_page)

    with PdfPages(pdf_path) as pdf:
        for page in range(total_pages):
            page_files = png_files[page * per_page:(page + 1) * per_page]

            # squeeze=False always returns a 2-D array of axes, even for a 1x1 grid
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)

            for slot, ax in enumerate(axes.flat):
                if slot < len(page_files):
                    # aspect='auto' stretches the image to fill its grid cell
                    ax.imshow(plt.imread(page_files[slot]), aspect='auto')
                # Panels are untitled; hide the frame on used and unused cells alike
                ax.axis('off')

            pdf.savefig(fig, dpi=300, bbox_inches='tight')
            plt.close(fig)  # Close the figure to free memory

    print(f"✅ PDF saved to {pdf_path} with original image resolution.")

# Collect the EMT KDE comparison plots in `output_dir` into a single PDF

exp_memo = "72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64"
gene_list = gene_names  # every gene that was plotted
## Selected genes (no difference by TV)  ['AXL', 'HNMT', 'TMEM45B', 'SSH3', 'SHROOM3', 'PRSS22', 'SERINC2', 'EVPL', 'GALNT3', 'DSP', 'ELMO3', 'KRTCAP3', 'KRT19', 'C1orf116', 'CDS1', 'INADL']
## Selected genes (no difference by TV and KL) ['HNMT', 'TMEM45B', 'SHROOM3', 'PRSS22', 'SERINC2', 'KRTCAP3', 'C1orf116', 'CDS1']
pdf_path = f"{output_dir}/combined_gene_expression_KDE_Intermediate_Only_updated.pdf"  # destination file

# Six plots per page, laid out as 3 rows x 2 columns
create_pdf_from_gene_images(
    output_dir=output_dir,
    exp_memo=exp_memo,
    gene_list=gene_list,
    pdf_path=pdf_path,
    images_per_page=6,
    grid_size=(3, 2),
)



## Permutation test for predicted distributions to test data distributions

In [None]:
## Quantify the errors by using W2 distance (permutation test)

import numpy as np
from scipy.stats import wasserstein_distance

def Compare_Distribution_Permutation_Test(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                                          intermediate_t=None, d_red=2, random_state=42, exp_memo='2', 
                                          num_permutations=1000):
    """
    Permutation-test predicted vs. observed expression of a single gene.

    A trajectory is integrated from ``mats[0]`` in PCA space, every
    ``index``-th snapshot (up to ``max_i``) is mapped back to gene space, and
    the i-th such snapshot is compared with the PCA-reconstructed data of the
    i-th entry of ``intermediate_t``.  The distance is
    ``scipy.stats.wasserstein_distance`` (the 1-D Wasserstein-1 distance); the
    result keys keep the historical "W2" label.

    Parameters:
    - source_t, target_t, optimal_k: Not used by this function; kept so existing calls keep working.
    - gene_of_interest: Column label in ``df_reduced_emt`` of the gene to analyse.
    - index: Stride between the trajectory snapshots that are used.
    - max_i: Largest trajectory index that is used.
    - intermediate_t: Indices into ``mats`` to compare against (defaults to [1]).
    - d_red: Latent dimension; selects which PCA pickle is loaded.
    - random_state: Seed of the generator that draws the permutations.
    - exp_memo: Experiment identifier (basename of the weights pickle in ``result_dir``).
    - num_permutations: Number of permutations per time point.

    Returns:
    - Dict mapping each time point to its observed distance, the mean and std
      of the permuted distances, and the permutation p-value.

    Raises:
    - ValueError: If ``dim_red_method`` is neither 'EMT_PCA' nor 'PCA'.
    """

    # Ensure intermediate_t has a default value
    if intermediate_t is None:
        intermediate_t = [1]

    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): W and b are loaded but not passed on; time_integration
    # presumably reads module-level weights -- confirm they match exp_memo.
    W, b, p = load_W(filename)

    # Load PCA
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        # Previously this branch only printed and then failed on an unbound name.
        raise ValueError("PCA mapping for the reduction method and dimension is not available")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    # Define intermediate time points
    intermediate_only_points = intermediate_t

    # Column position of the gene; the "- 1" presumably skips a leading
    # non-gene column of df_reduced_emt -- TODO confirm.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # Observed values: data projected to PCA space and reconstructed
    kde_test_data = [
        pca.inverse_transform(pca.transform(mats[t]))[:, gene_index] for t in intermediate_only_points
    ]

    # Predicted values: the i-th strided snapshot feeds the i-th time point.
    # NOTE(review): the first snapshot used is trajectory index 0 -- confirm
    # that this is the intended pairing with intermediate_t[0].
    predicted_distributions = {t: [] for t in intermediate_only_points}
    for i, time_idx in enumerate(range(0, len(X1_trpts), index)):
        if time_idx > max_i or i >= len(intermediate_only_points):
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break

        gene_expression_values = pca.inverse_transform(X1_trpt)[:, gene_index]
        predicted_distributions[intermediate_only_points[i]].extend(gene_expression_values)

    kde_predicted_data = [np.array(predicted_distributions[t]) for t in intermediate_only_points]

    # One seeded generator for all time points makes the test reproducible.
    rng = np.random.default_rng(random_state)

    permutation_results = {}

    for i, time in enumerate(intermediate_only_points):
        test_vals = kde_test_data[i]
        predicted_vals = kde_predicted_data[i]

        observed_w2 = wasserstein_distance(test_vals, predicted_vals)

        # Null distribution: re-split the pooled values at random.
        pooled_vals = np.concatenate([test_vals, predicted_vals])
        n_test = len(test_vals)
        permuted_w2_distances = np.empty(num_permutations)

        for k in range(num_permutations):
            shuffled = rng.permutation(pooled_vals)
            permuted_w2_distances[k] = wasserstein_distance(shuffled[:n_test], shuffled[n_test:])

        # p-value: proportion of permuted distances >= observed distance
        p_value = np.mean(permuted_w2_distances >= observed_w2)

        permutation_results[time] = {
            "Observed W2": observed_w2,
            "Permutation Mean W2": np.mean(permuted_w2_distances),
            "Permutation Std W2": np.std(permuted_w2_distances),
            "p-value": p_value
        }

    # Print summary
    print("\n--- Permutation Test Summary ---")
    for time, results in permutation_results.items():
        print(f"Time {time}:")
        print(f"  Observed W2: {results['Observed W2']:.4f}")
        print(f"  Mean Permutation W2: {results['Permutation Mean W2']:.4f} ± {results['Permutation Std W2']:.4f}")
        print(f"  p-value: {results['p-value']:.4f}")
        if results["p-value"] < 0.05:
            print("  ** Significant Difference (Reject Null Hypothesis) **")
        else:
            print("  No Significant Difference (Cannot Reject Null Hypothesis)")

    return permutation_results



In [None]:
## Full metrics but No Sinkhorn (permutation test)

import numpy as np
from scipy.stats import wasserstein_distance
from sklearn.metrics.pairwise import rbf_kernel

def maximum_mean_discrepancy(X, Y, gamma=1.0):
    """
    Return the RBF-kernel MMD statistic between two 1-D samples.

    Each sample is viewed as a single-feature column matrix.  The result is
    mean(K(X, X)) + mean(K(Y, Y)) - 2 * mean(K(X, Y)), where every mean runs
    over all entries of the kernel matrix (diagonal included).
    """
    x_col = X[:, None]
    y_col = Y[:, None]

    within_x = rbf_kernel(x_col, x_col, gamma=gamma).mean()
    within_y = rbf_kernel(y_col, y_col, gamma=gamma).mean()
    across = rbf_kernel(x_col, y_col, gamma=gamma).mean()

    return within_x + within_y - 2 * across

def Compare_Distribution_Permutation_Test(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                                          intermediate_t=None, d_red=2, random_state=42, exp_memo='2', 
                                          num_permutations=1000, mmd_gamma=1.0):
    """
    Permutation-test predicted vs. observed expression of a single gene
    using the Wasserstein distance and the RBF-kernel MMD.

    The i-th strided trajectory snapshot (stride ``index``, capped at
    ``max_i``) is compared with the PCA-reconstructed data of the i-th entry
    of ``intermediate_t``.  "W2" in the result keys is the value of
    ``scipy.stats.wasserstein_distance`` (1-D Wasserstein-1).

    Parameters:
    - source_t, target_t, optimal_k: Not used by this function; kept so existing calls keep working.
    - gene_of_interest: Column label in ``df_reduced_emt`` of the gene to analyse.
    - index, max_i: Stride and upper bound of the trajectory indices used.
    - intermediate_t: Indices into ``mats`` to compare against (defaults to [1]).
    - d_red: Latent dimension; selects which PCA pickle is loaded.
    - random_state: Seed of the generator that draws the permutations.
    - exp_memo: Experiment identifier (basename of the weights pickle in ``result_dir``).
    - num_permutations: Number of permutations per time point.
    - mmd_gamma: ``gamma`` forwarded to ``maximum_mean_discrepancy``.

    Returns:
    - Dict mapping each time point to the observed W2/MMD values and their
      permutation p-values.

    Raises:
    - ValueError: If ``dim_red_method`` is neither 'EMT_PCA' nor 'PCA'.
    """

    if intermediate_t is None:
        intermediate_t = [1]

    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): W and b are loaded but not passed on; time_integration
    # presumably reads module-level weights -- confirm they match exp_memo.
    W, b, p = load_W(filename)

    # Load PCA
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        # Previously this branch only printed and then failed on an unbound name.
        raise ValueError("PCA mapping for the reduction method and dimension is not available")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    intermediate_only_points = intermediate_t

    # The "- 1" presumably skips a leading non-gene column -- TODO confirm.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # Observed values: data projected to PCA space and reconstructed
    kde_test_data = [
        pca.inverse_transform(pca.transform(mats[t]))[:, gene_index] for t in intermediate_only_points
    ]

    # Predicted values: the i-th strided snapshot feeds the i-th time point
    predicted_distributions = {t: [] for t in intermediate_only_points}
    for i, time_idx in enumerate(range(0, len(X1_trpts), index)):
        if time_idx > max_i or i >= len(intermediate_only_points):
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break

        gene_expression_values = pca.inverse_transform(X1_trpt)[:, gene_index]
        predicted_distributions[intermediate_only_points[i]].extend(gene_expression_values)

    kde_predicted_data = [np.array(predicted_distributions[t]) for t in intermediate_only_points]

    # One seeded generator for all time points makes the test reproducible.
    rng = np.random.default_rng(random_state)

    permutation_results = {}

    for i, time in enumerate(intermediate_only_points):
        test_vals = kde_test_data[i]
        predicted_vals = kde_predicted_data[i]

        observed_w2 = wasserstein_distance(test_vals, predicted_vals)
        observed_mmd = maximum_mean_discrepancy(test_vals, predicted_vals, gamma=mmd_gamma)

        # Null distributions: re-split the pooled values at random.
        pooled_vals = np.concatenate([test_vals, predicted_vals])
        n_test = len(test_vals)
        permuted_w2 = np.empty(num_permutations)
        permuted_mmd = np.empty(num_permutations)

        for k in range(num_permutations):
            shuffled = rng.permutation(pooled_vals)
            perm_test_sample = shuffled[:n_test]
            perm_pred_sample = shuffled[n_test:]

            permuted_w2[k] = wasserstein_distance(perm_test_sample, perm_pred_sample)
            permuted_mmd[k] = maximum_mean_discrepancy(perm_test_sample, perm_pred_sample, gamma=mmd_gamma)

        # p-values: share of permuted statistics at least as large as observed
        p_w2 = np.mean(permuted_w2 >= observed_w2)
        p_mmd = np.mean(permuted_mmd >= observed_mmd)

        permutation_results[time] = {
            "Observed W2": observed_w2, "p_W2": p_w2,
            "Observed MMD": observed_mmd, "p_MMD": p_mmd
        }

    return permutation_results


In [None]:
## Full metrics With Sinkhorn (permutation test)


import numpy as np
import torch
import os
from geomloss import SamplesLoss  # Sinkhorn divergence
from sklearn.utils import resample


def sinkhorn_divergence(X, Y, epsilon=0.5):
    """
    Return loss(X, Y) - 0.5 * loss(X, X) - 0.5 * loss(Y, Y) for the geomloss
    "sinkhorn" SamplesLoss with p=2 and blur=epsilon.

    Both inputs are reshaped to float column tensors and cut to the length of
    the shorter one before the loss is evaluated.
    """
    def _as_column(values):
        return torch.from_numpy(values.reshape(-1, 1)).float()

    x_col, y_col = _as_column(X), _as_column(Y)

    n_common = min(x_col.shape[0], y_col.shape[0])
    x_col = x_col[:n_common]
    y_col = y_col[:n_common]

    # NOTE(review): geomloss may already debias the "sinkhorn" loss by
    # default, which would make the self-terms close to zero -- confirm.
    loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=epsilon)

    cross_term = loss_fn(x_col, y_col).item()
    self_x = loss_fn(x_col, x_col).item()
    self_y = loss_fn(y_col, y_col).item()

    return cross_term - 0.5 * self_x - 0.5 * self_y


def permutation_test_sinkhorn(test_vals, pred_vals, num_permutations=1000, epsilon=0.5):
    """
    Permutation test built on ``sinkhorn_divergence``.

    Returns ``(observed_stat, p_value)`` where ``p_value`` is the share of
    permuted statistics that are >= the observed one.  Shuffling uses NumPy's
    global random state.
    """
    n_test = len(test_vals)
    observed_stat = sinkhorn_divergence(test_vals, pred_vals, epsilon)
    pooled = np.concatenate([test_vals, pred_vals])

    def _shuffled_stat():
        # In-place shuffle: each draw starts from the previous ordering.
        np.random.shuffle(pooled)
        return sinkhorn_divergence(pooled[:n_test], pooled[n_test:], epsilon)

    null_stats = np.array([_shuffled_stat() for _ in range(num_permutations)])
    p_value = np.mean(null_stats >= observed_stat)

    return observed_stat, p_value


def Compare_Distribution_Sinkhorn(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                                  intermediate_t=None, d_red=2, random_state=42, exp_memo='2',
                                  num_permutations=1000, sinkhorn_epsilon=0.5):
    """
    Compare predicted and observed expression of one gene with the Sinkhorn
    divergence and a permutation test, and save the results as CSV.

    Each time ``t`` in ``intermediate_t`` is mapped to the trajectory index
    ``int(t * len(trajectory) / (target_t - source_t))`` (capped at the last
    snapshot).  Every prediction stays paired with its own time point, even
    when a snapshot containing NaN is skipped.

    Parameters:
    - source_t, target_t: Time span used to scale ``intermediate_t`` to trajectory indices.
    - optimal_k, index: Not used by this function; kept so existing calls keep working.
    - gene_of_interest: Column label in ``df_reduced_emt`` of the gene to analyse.
    - max_i: Largest trajectory index that is used; processing stops at the first larger one.
    - intermediate_t: Indices into ``mats`` to compare against (defaults to [1]).
    - d_red: Latent dimension; selects which PCA pickle is loaded.
    - random_state: Seed for the size-matching subsample.
    - exp_memo: Experiment identifier (weights pickle basename and output sub-directory).
    - num_permutations, sinkhorn_epsilon: Forwarded to ``permutation_test_sinkhorn``.

    Returns:
    - List with one result dict per processed time point.

    Raises:
    - ValueError: If ``target_t == source_t`` or ``dim_red_method`` is unknown.
    """
    import pandas as pd  # local import: this cell does not import pandas itself

    if intermediate_t is None:
        intermediate_t = [1]
    if target_t == source_t:
        raise ValueError("target_t must differ from source_t.")

    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): W and b are loaded but not passed on; time_integration
    # presumably reads module-level weights -- confirm they match exp_memo.
    W, b, p = load_W(filename)

    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        raise ValueError("PCA mapping method not available.")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    # Scale time points to trajectory indices
    num_snapshots = len(X1_trpts)
    scaling_factor = num_snapshots / (target_t - source_t)

    # The "- 1" presumably skips a leading non-gene column -- TODO confirm.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    results = []
    for time in intermediate_t:
        # Cap at the last snapshot so t == target_t - source_t stays in range.
        snapshot_idx = min(int(time * scaling_factor), num_snapshots - 1)
        if snapshot_idx > max_i:
            break
        X1_trpt = X1_trpts[snapshot_idx]
        if np.isnan(X1_trpt).any():
            # Skip only this time point; later ones keep their own pairing.
            continue

        pred_vals = pca.inverse_transform(X1_trpt)[:, gene_index]
        test_vals = pca.inverse_transform(pca.transform(mats[time]))[:, gene_index]

        # Match sample sizes by subsampling without replacement, the same
        # convention as match_sample_sizes elsewhere in this notebook.
        min_size = min(len(test_vals), len(pred_vals))
        test_vals = resample(test_vals, n_samples=min_size, replace=False, random_state=random_state)
        pred_vals = resample(pred_vals, n_samples=min_size, replace=False, random_state=random_state)

        # Compute Sinkhorn and permutation test
        sinkhorn_stat, p_sinkhorn = permutation_test_sinkhorn(
            test_vals, pred_vals, num_permutations, sinkhorn_epsilon
        )

        results.append({
            "Time": time,
            "Gene": gene_of_interest,
            "Sinkhorn Divergence": sinkhorn_stat,
            "Sinkhorn p-value": p_sinkhorn
        })

    output_dir = os.path.join(result_dir, 'output', exp_memo)
    os.makedirs(output_dir, exist_ok=True)

    df_results = pd.DataFrame(results)
    csv_path = os.path.join(output_dir, f"Sinkhorn_metrics_{gene_of_interest}.csv")
    df_results.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")

    return results

In [None]:
## Sample 1

## Permutation by other statistical tests (permutation test)

import numpy as np
import torch
import pandas as pd
import warnings
from scipy.stats import ks_2samp, wasserstein_distance
from scipy.spatial.distance import jensenshannon
from scipy.special import rel_entr
from sklearn.utils import resample  # Resampling for matching sizes

# Install a filter that makes Python ignore every warning from here on.
warnings.simplefilter("ignore")  # Suppress warnings


def total_variation_distance(p, q):
    """Return half the sum of absolute element-wise differences between `p` and `q`."""
    abs_gap = np.abs(p - q)
    return 0.5 * abs_gap.sum()


def kl_divergence(p, q):
    """Return sum(rel_entr(p, q)) after flooring both inputs at 1e-10."""
    floor = 1e-10  # keeps zero entries away from the log / division
    p_safe, q_safe = (np.clip(values, floor, None) for values in (p, q))
    return rel_entr(p_safe, q_safe).sum()


def match_sample_sizes(X, Y):
    """Subsample `X` and `Y` without replacement to the length of the shorter one."""
    common_size = min(len(X), len(Y))
    draw_options = dict(n_samples=common_size, replace=False, random_state=42)
    return resample(X, **draw_options), resample(Y, **draw_options)


def permutation_test(stat_func, test_vals, pred_vals, num_permutations=1000):
    """
    Permutation test for an arbitrary two-sample statistic.

    Returns ``(observed, mean_perm, std_perm, p_value)``: the statistic on the
    given samples, the mean and std of the statistic over the permutations,
    and the share of permuted statistics that are >= the observed one.
    Shuffling uses NumPy's global random state.
    """
    n_test = len(test_vals)
    observed = stat_func(test_vals, pred_vals)
    pool = np.concatenate([test_vals, pred_vals])

    null_stats = []
    for _ in range(num_permutations):
        # In-place shuffle: each draw starts from the previous ordering.
        np.random.shuffle(pool)
        null_stats.append(stat_func(pool[:n_test], pool[n_test:]))

    exceed_share = np.mean(np.array(null_stats) >= observed)
    return observed, np.mean(null_stats), np.std(null_stats), exceed_share


def Compare_Distribution_Statistics(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                                    intermediate_t=None, d_red=2, random_state=42, exp_memo='2',
                                    num_permutations=1000, save_csv=True):
    """
    Compare predicted and observed expression of one gene with the KS test and
    with permutation-tested TV, KL and Jensen-Shannon statistics.

    TV/KL/JS are evaluated on per-bin probabilities of 50 bins whose edges are
    shared by both samples (taken from the pooled values), and the permutation
    tests use exactly the same binned statistic, so each reported value and
    its p-value refer to the same quantity.  The "Jensen-Shannon Divergence"
    column holds the output of ``scipy.spatial.distance.jensenshannon``.

    Parameters:
    - source_t, target_t, optimal_k, random_state: Not used by this function; kept for call compatibility.
    - gene_of_interest: Column label in ``df_reduced_emt`` of the gene to analyse.
    - index, max_i: Stride and upper bound of the trajectory indices used; the
      i-th strided snapshot is paired with the i-th entry of ``intermediate_t``.
    - intermediate_t: Indices into ``mats`` to compare against (defaults to [1]).
    - d_red: Latent dimension; selects which PCA pickle is loaded.
    - exp_memo: Experiment identifier (weights pickle basename and output sub-directory).
    - num_permutations: Number of permutations per statistic.
    - save_csv: Write the results to ``statistical_metrics_<gene>.csv`` when true.

    Returns:
    - List with one result dict per time point that has a predicted snapshot.

    Raises:
    - ValueError: If ``dim_red_method`` is neither 'EMT_PCA' nor 'PCA'.
    """
    import os  # local import: this cell does not import os itself

    def _on_shared_bins(stat_func, edges):
        """Adapt a histogram statistic so it accepts raw samples."""
        def _stat(sample_a, sample_b):
            prob_a = np.histogram(sample_a, bins=edges)[0] / len(sample_a)
            prob_b = np.histogram(sample_b, bins=edges)[0] / len(sample_b)
            return stat_func(prob_a, prob_b)
        return _stat

    if intermediate_t is None:
        intermediate_t = [1]

    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): W and b are loaded but not passed on; time_integration
    # presumably reads module-level weights -- confirm they match exp_memo.
    W, b, p = load_W(filename)

    # Load PCA
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        # Previously this branch only printed and then failed on an unbound name.
        raise ValueError("PCA mapping for the reduction method is not available")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    intermediate_only_points = intermediate_t
    # The "- 1" presumably skips a leading non-gene column -- TODO confirm.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # Observed values: data projected to PCA space and reconstructed
    kde_test_data = [
        pca.inverse_transform(pca.transform(mats[t]))[:, gene_index] for t in intermediate_only_points
    ]

    # Predicted values: the i-th strided snapshot feeds the i-th time point
    predicted_distributions = {t: [] for t in intermediate_only_points}
    for i, time_idx in enumerate(range(0, len(X1_trpts), index)):
        if time_idx > max_i or i >= len(intermediate_only_points):
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break
        gene_expression_values = pca.inverse_transform(X1_trpt)[:, gene_index]
        predicted_distributions[intermediate_only_points[i]].extend(gene_expression_values)

    kde_predicted_data = [np.array(predicted_distributions[t]) for t in intermediate_only_points]

    metric_results = []

    for time, test_vals, predicted_vals in zip(intermediate_only_points, kde_test_data, kde_predicted_data):
        if predicted_vals.size == 0:
            # The trajectory loop stopped before reaching this time point.
            print(f"No predicted snapshot available for time {time}; skipping.")
            continue

        # Ensure test and predicted values have the same size
        test_vals, predicted_vals = match_sample_sizes(test_vals, predicted_vals)

        ks_stat, ks_pval = ks_2samp(test_vals, predicted_vals)  # Kolmogorov-Smirnov test

        # Shared edges from the pooled values; permutations only re-split the
        # pool, so the same edges stay valid for every permuted sample.
        edges = np.histogram_bin_edges(np.concatenate([test_vals, predicted_vals]), bins=50)

        # Observed statistic and permutation summary come from the same call
        tv_distance, mean_perm_tv, std_perm_tv, p_tv = permutation_test(
            _on_shared_bins(total_variation_distance, edges), test_vals, predicted_vals, num_permutations)
        kl_div, mean_perm_kl, std_perm_kl, p_kl = permutation_test(
            _on_shared_bins(kl_divergence, edges), test_vals, predicted_vals, num_permutations)
        js_div, mean_perm_js, std_perm_js, p_js = permutation_test(
            _on_shared_bins(jensenshannon, edges), test_vals, predicted_vals, num_permutations)

        metric_results.append({
            "Time": time,
            "Gene": gene_of_interest,
            "KS Test Statistic": ks_stat,
            "KS p-value": ks_pval,
            "Total Variation Distance": tv_distance,
            "Permutation TV ± Std": f"{mean_perm_tv:.4f} ± {std_perm_tv:.4f}",
            "p-value TV": p_tv,
            "KL Divergence": kl_div,
            "Permutation KL ± Std": f"{mean_perm_kl:.4f} ± {std_perm_kl:.4f}",
            "p-value KL": p_kl,
            "Jensen-Shannon Divergence": js_div,
            "Permutation JS ± Std": f"{mean_perm_js:.4f} ± {std_perm_js:.4f}",
            "p-value JS": p_js
        })

    # Convert results to DataFrame and save as CSV
    if save_csv:
        output_dir = os.path.join(result_dir, 'output', exp_memo)
        os.makedirs(output_dir, exist_ok=True)  # Ensure directory exists

        df_results = pd.DataFrame(metric_results)
        output_csv_path = os.path.join(output_dir, f"statistical_metrics_{gene_of_interest}.csv")
        df_results.to_csv(output_csv_path, index=False)
        print(f"Results saved to {output_csv_path}")

    return metric_results



In [None]:
## Sample 3

## Permutation by other statistical tests (permutation test)

import numpy as np
import torch
import pandas as pd
import warnings
from scipy.stats import ks_2samp, wasserstein_distance, ttest_ind
from scipy.spatial.distance import jensenshannon
from scipy.special import rel_entr
from sklearn.utils import resample  # Resampling for matching sizes

# Install a filter that makes Python ignore every warning from here on.
warnings.simplefilter("ignore")  # Suppress warnings


def total_variation_distance(p, q):
    """Half of the summed absolute element-wise difference of `p` and `q`."""
    difference = p - q
    return 0.5 * np.abs(difference).sum()


def kl_divergence(p, q):
    """Sum of ``rel_entr(p, q)`` with both arguments floored at 1e-10."""
    lower_bound = 1e-10  # avoids zeros inside the log / division
    per_entry = rel_entr(np.clip(p, lower_bound, None), np.clip(q, lower_bound, None))
    return np.sum(per_entry)


def match_sample_sizes(X, Y):
    """Cut both samples to the shorter length by subsampling without replacement."""
    n_keep = min(len(X), len(Y))

    def _subsample(values):
        return resample(values, n_samples=n_keep, replace=False, random_state=42)

    return _subsample(X), _subsample(Y)


def permutation_test(stat_func, test_vals, pred_vals, num_permutations=1000):
    """
    Permutation test for a two-sample statistic.

    Returns the observed statistic, the mean and std of the permuted
    statistics, and the fraction of permuted statistics >= the observed one.
    Shuffling uses NumPy's global random state.
    """
    split_at = len(test_vals)
    observed = stat_func(test_vals, pred_vals)
    pooled = np.concatenate([test_vals, pred_vals])

    def _draw():
        # In-place shuffle: each draw starts from the previous ordering.
        np.random.shuffle(pooled)
        return stat_func(pooled[:split_at], pooled[split_at:])

    draws = [_draw() for _ in range(num_permutations)]

    perm_mean = np.mean(draws)
    perm_std = np.std(draws)
    p_value = np.mean(np.array(draws) >= observed)

    return observed, perm_mean, perm_std, p_value


def Compare_Distribution_Statistics(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                                    intermediate_t=None, d_red=2, random_state=42, exp_memo='2',
                                    num_permutations=1000, save_csv=True):
    """
    Compare predicted and observed expression of one gene with the KS test,
    permutation-tested TV / KL / Jensen-Shannon statistics, and a Welch t-test
    plus mean and standard-deviation differences.

    Each time ``t`` in ``intermediate_t`` is mapped to the trajectory index
    ``int(t * len(trajectory) / (target_t - source_t))`` (capped at the last
    snapshot).  TV/KL/JS are evaluated on per-bin probabilities of 50 bins
    whose edges are shared by both samples, and the permutation tests use the
    same binned statistic, so each reported value and its p-value refer to the
    same quantity.  The "Jensen-Shannon Divergence" column holds the output of
    ``scipy.spatial.distance.jensenshannon``.

    Parameters:
    - source_t, target_t: Time span used to scale ``intermediate_t`` to trajectory indices.
    - optimal_k, index, random_state: Not used by this function; kept for call compatibility.
    - gene_of_interest: Column label in ``df_reduced_emt`` of the gene to analyse.
    - max_i: Largest trajectory index that is used; extraction stops at the first larger one.
    - intermediate_t: Indices into ``mats`` to compare against (defaults to [1]).
    - d_red: Latent dimension; selects which PCA pickle is loaded.
    - exp_memo: Experiment identifier (weights pickle basename and output sub-directory).
    - num_permutations: Number of permutations per statistic.
    - save_csv: Write the results to ``statistical_metrics_<gene>.csv`` when true.

    Returns:
    - List with one result dict per time point that has a predicted snapshot.

    Raises:
    - ValueError: If ``target_t == source_t`` or ``dim_red_method`` is unknown.
    """
    import os  # local import: this cell does not import os itself

    def _on_shared_bins(stat_func, edges):
        """Adapt a histogram statistic so it accepts raw samples."""
        def _stat(sample_a, sample_b):
            prob_a = np.histogram(sample_a, bins=edges)[0] / len(sample_a)
            prob_b = np.histogram(sample_b, bins=edges)[0] / len(sample_b)
            return stat_func(prob_a, prob_b)
        return _stat

    if intermediate_t is None:
        intermediate_t = [1]
    if target_t == source_t:
        raise ValueError("target_t must differ from source_t.")

    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): W and b are loaded but not passed on; time_integration
    # presumably reads module-level weights -- confirm they match exp_memo.
    W, b, p = load_W(filename)

    # Load PCA
    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        # Previously this branch only printed and then failed on an unbound name.
        raise ValueError("PCA mapping for the reduction method is not available")

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    # Scale time points to trajectory indices; cap at the last snapshot so
    # t == target_t - source_t stays in range.
    num_snapshots = len(X1_trpts)
    scaling_factor = num_snapshots / (target_t - source_t)
    scaled_intermediate_indices = [
        min(int(t * scaling_factor), num_snapshots - 1) for t in intermediate_t
    ]

    # The "- 1" presumably skips a leading non-gene column -- TODO confirm.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1

    # Observed values: data projected to PCA space and reconstructed
    kde_test_data = [
        pca.inverse_transform(pca.transform(mats[t]))[:, gene_index] for t in intermediate_t
    ]

    # Predicted values at the scaled trajectory indices
    predicted_distributions = {t: [] for t in intermediate_t}
    for t, time_idx in zip(intermediate_t, scaled_intermediate_indices):
        if time_idx > max_i:
            break
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            break
        gene_expression_values = pca.inverse_transform(X1_trpt)[:, gene_index]
        predicted_distributions[t].extend(gene_expression_values)

    kde_predicted_data = [np.array(predicted_distributions[t]) for t in intermediate_t]

    metric_results = []

    for time, test_vals, predicted_vals in zip(intermediate_t, kde_test_data, kde_predicted_data):
        if predicted_vals.size == 0:
            # The extraction loop stopped before reaching this time point.
            print(f"No predicted snapshot available for time {time}; skipping.")
            continue

        # Ensure test and predicted values have the same size
        test_vals, predicted_vals = match_sample_sizes(test_vals, predicted_vals)

        ks_stat, ks_pval = ks_2samp(test_vals, predicted_vals)  # Kolmogorov-Smirnov test

        # Shared edges from the pooled values; permutations only re-split the
        # pool, so the same edges stay valid for every permuted sample.
        edges = np.histogram_bin_edges(np.concatenate([test_vals, predicted_vals]), bins=50)

        # Observed statistic and permutation summary come from the same call
        tv_distance, mean_perm_tv, std_perm_tv, p_tv = permutation_test(
            _on_shared_bins(total_variation_distance, edges), test_vals, predicted_vals, num_permutations)
        kl_div, mean_perm_kl, std_perm_kl, p_kl = permutation_test(
            _on_shared_bins(kl_divergence, edges), test_vals, predicted_vals, num_permutations)
        js_div, mean_perm_js, std_perm_js, p_js = permutation_test(
            _on_shared_bins(jensenshannon, edges), test_vals, predicted_vals, num_permutations)

        # Mean and standard deviation comparison
        mean_test = np.mean(test_vals)
        mean_pred = np.mean(predicted_vals)
        std_test = np.std(test_vals)
        std_pred = np.std(predicted_vals)
        mean_diff = mean_test - mean_pred
        std_diff = std_test - std_pred

        # Welch t-test (unequal variances) on the means
        t_stat, t_pval = ttest_ind(test_vals, predicted_vals, equal_var=False)

        metric_results.append({
            "Time": time,
            "Gene": gene_of_interest,
            "KS Test Statistic": ks_stat,
            "KS p-value": ks_pval,
            "Total Variation Distance": tv_distance,
            "Permutation TV ± Std": f"{mean_perm_tv:.4f} ± {std_perm_tv:.4f}",
            "p-value TV": p_tv,
            "KL Divergence": kl_div,
            "Permutation KL ± Std": f"{mean_perm_kl:.4f} ± {std_perm_kl:.4f}",
            "p-value KL": p_kl,
            "Jensen-Shannon Divergence": js_div,
            "Permutation JS ± Std": f"{mean_perm_js:.4f} ± {std_perm_js:.4f}",
            "p-value JS": p_js,
            "Mean Test": mean_test,
            "Mean Predicted": mean_pred,
            "Mean Difference": mean_diff,
            "Standard Deviation Test": std_test,
            "Standard Deviation Predicted": std_pred,
            "Standard Deviation Difference": std_diff,
            "T-test Statistic": t_stat,
            "T-test p-value": t_pval
        })

    # Convert results to DataFrame and save as CSV
    if save_csv:
        output_dir = os.path.join(result_dir, 'output', exp_memo)
        os.makedirs(output_dir, exist_ok=True)

        df_results = pd.DataFrame(metric_results)
        output_csv_path = os.path.join(output_dir, f"statistical_metrics_{gene_of_interest}.csv")
        df_results.to_csv(output_csv_path, index=False)
        print(f"Results saved to {output_csv_path}")

    return metric_results


In [None]:
## Permutation by statistical tests (permutation test) - stem cell data

from scipy.stats import ks_2samp, wasserstein_distance, ttest_ind
from scipy.spatial.distance import jensenshannon
from scipy.special import rel_entr
from sklearn.utils import resample
import numpy as np
import pandas as pd
import torch
from geomloss import SamplesLoss
import os
import warnings

# Install a filter that makes Python ignore every warning from here on.
warnings.simplefilter("ignore")

def total_variation_distance(p, q):
    """Return 0.5 * sum(|p - q|) over all elements."""
    l1_gap = np.abs(p - q).sum()
    return 0.5 * l1_gap

def kl_divergence(p, q):
    """Return the summed ``rel_entr(p, q)``; both inputs are floored at 1e-10 first."""
    tiny = 1e-10
    p_floored = np.clip(p, tiny, None)
    q_floored = np.clip(q, tiny, None)
    return rel_entr(p_floored, q_floored).sum()

def match_sample_sizes(X, Y):
    """
    Subsample `X` and `Y` to the length of the shorter one.

    Sampling is done without replacement, like the other
    ``match_sample_sizes`` definitions in this notebook, so the shorter sample
    keeps all of its values and the longer one contributes distinct
    observations (no bootstrap duplicates).  The seed is fixed at 42.
    """
    min_size = min(len(X), len(Y))
    X_resampled = resample(X, n_samples=min_size, replace=False, random_state=42)
    Y_resampled = resample(Y, n_samples=min_size, replace=False, random_state=42)
    return X_resampled, Y_resampled

def sinkhorn_divergence(X, Y, epsilon=0.5):
    """
    Evaluate loss(X, Y) - 0.5 * loss(X, X) - 0.5 * loss(Y, Y) with the
    geomloss "sinkhorn" SamplesLoss (p=2, blur=epsilon).

    Inputs are converted to float column tensors and truncated to the length
    of the shorter one.
    """
    x_tensor = torch.from_numpy(X.reshape(-1, 1)).float()
    y_tensor = torch.from_numpy(Y.reshape(-1, 1)).float()

    keep = min(x_tensor.shape[0], y_tensor.shape[0])
    x_tensor = x_tensor[:keep]
    y_tensor = y_tensor[:keep]

    criterion = SamplesLoss(loss="sinkhorn", p=2, blur=epsilon)

    between = criterion(x_tensor, y_tensor).item()
    within_x = criterion(x_tensor, x_tensor).item()
    within_y = criterion(y_tensor, y_tensor).item()
    return between - 0.5 * within_x - 0.5 * within_y

def permutation_test(stat_func, test_vals, pred_vals, num_permutations=1000):
    """
    Permutation test for a two-sample statistic.

    Returns ``(observed_stat, mean_perm, std_perm, p_value)``; ``p_value`` is
    the fraction of permuted statistics that are >= the observed one.
    Shuffling uses NumPy's global random state.
    """
    first_size = len(test_vals)
    observed_stat = stat_func(test_vals, pred_vals)
    mixed = np.concatenate([test_vals, pred_vals])

    permuted_stats = []
    remaining = num_permutations
    while remaining > 0:
        # In-place shuffle: each draw starts from the previous ordering.
        np.random.shuffle(mixed)
        left, right = mixed[:first_size], mixed[first_size:]
        permuted_stats.append(stat_func(left, right))
        remaining -= 1

    at_least_observed = np.array(permuted_stats) >= observed_stat
    return observed_stat, np.mean(permuted_stats), np.std(permuted_stats), np.mean(at_least_observed)

def _shared_bin_masses(sample_a, sample_b, bins=50):
    """Histogram two 1-D samples on common bin edges and return probability masses."""
    sample_a = np.asarray(sample_a, dtype=float)
    sample_b = np.asarray(sample_b, dtype=float)
    # Edges span the pooled range so that bin k covers the same interval for both samples.
    edges = np.histogram_bin_edges(np.concatenate([sample_a, sample_b]), bins=bins)
    counts_a = np.histogram(sample_a, bins=edges)[0]
    counts_b = np.histogram(sample_b, bins=edges)[0]
    return counts_a / counts_a.sum(), counts_b / counts_b.sum()


def _binned_statistic(dist_func, bins=50):
    """Lift a distance between probability vectors to a statistic on raw samples."""
    def statistic(sample_a, sample_b):
        return dist_func(*_shared_bin_masses(sample_a, sample_b, bins))
    return statistic


def Compare_Distribution_Statistics(source_t, target_t, optimal_k, gene_of_interest, index, max_i,
                                    intermediate_t=None, d_red=2, random_state=42, exp_memo='2',
                                    num_permutations=1000, save_csv=True):
    """Compare predicted and observed expression of one gene at intermediate times.

    The day-0 snapshot (`mats[0]`) is projected with the stored PCA, transported
    with `time_integration`, mapped back to gene space and compared with the
    PCA-reconstructed observed snapshot at every time in `intermediate_t`.

    Histogram-based distances (TV, KL, Jensen-Shannon) are evaluated on
    probability masses over bin edges shared by both samples, and the *same*
    statistic is used for the observed value and for its permutation null, so
    each reported p-value refers to the value printed next to it.

    Relies on notebook globals: `result_dir`, `data_dir`, `dim_red_method`,
    `mats`, `df_reduced_emt`, `pk`, `load_W`, `time_integration`.

    Parameters:
        source_t, target_t: Only used to convert times into trajectory indices.
        optimal_k, index, random_state: Accepted for call compatibility; not
            used by this function.
        gene_of_interest: Column label looked up in `df_reduced_emt`.
        max_i: Trajectory indices above this value are skipped.
        intermediate_t: Times to evaluate (defaults to [1]).
        d_red: Latent dimension; selects which PCA pickle is loaded.
        exp_memo: Base name of the weights pickle and of the output sub-folder.
        num_permutations: Number of shuffles per permutation test.
        save_csv: When True, write the metrics table for this gene to disk.

    Returns:
        A list with one dict of metrics per intermediate time, or None when
        `dim_red_method` is neither 'EMT_PCA' nor 'PCA'.
    """
    if intermediate_t is None:
        intermediate_t = [1]

    filename = result_dir + exp_memo + ".pickle"
    # NOTE(review): W and b are bound locally here, but time_integration() is
    # called without them -- presumably it reads module-level weights; confirm
    # that those match `exp_memo`.
    W, b, p = load_W(filename)

    if dim_red_method == 'EMT_PCA':
        pca_filename = f"emt_pca_{d_red}.pkl"
    elif dim_red_method == 'PCA':
        pca_filename = f"pca_{d_red}.pkl"
    else:
        print("❌ PCA method not available")
        return

    with open(data_dir + pca_filename, "rb") as fr:
        [pca] = pk.load(fr)

    dt = p['numerical_ts'][-1] / 200
    X1_trpts = time_integration(pca.transform(mats[0]), T=p['numerical_ts'][-1], dt=dt)

    # Map each requested time onto an index of the integrated trajectory.
    scaling_factor = len(X1_trpts) / (target_t - source_t)
    scaled_intermediate_indices = [int(t * scaling_factor) for t in intermediate_t]

    # The "- 1" offset presumably skips a leading non-gene column -- TODO confirm.
    gene_index = df_reduced_emt.columns.get_loc(gene_of_interest) - 1
    kde_test_data = [pca.inverse_transform(pca.transform(mats[t]))[:, gene_index] for t in intermediate_t]

    predicted_distributions = {t: [] for t in intermediate_t}
    for i, time_idx in enumerate(scaled_intermediate_indices):
        if time_idx > max_i:
            continue
        X1_trpt = X1_trpts[time_idx]
        if np.isnan(X1_trpt).any():
            continue
        gene_expression_values = pca.inverse_transform(X1_trpt)[:, gene_index]
        predicted_distributions[intermediate_t[i]].extend(gene_expression_values)

    kde_predicted_data = [np.array(predicted_distributions[t]) for t in intermediate_t]

    # Sample-level versions of the histogram distances (shared bins, 50 as before).
    tv_statistic = _binned_statistic(total_variation_distance)
    kl_statistic = _binned_statistic(kl_divergence)
    # scipy's jensenshannon returns the JS *distance* (square root of the divergence).
    js_statistic = _binned_statistic(jensenshannon)

    metric_results = []
    for time_point, raw_test, raw_predicted in zip(intermediate_t, kde_test_data, kde_predicted_data):
        test_vals, predicted_vals = match_sample_sizes(raw_test, raw_predicted)

        ks_stat, ks_pval = ks_2samp(test_vals, predicted_vals)

        # permutation_test returns the observed statistic first, so each metric
        # is evaluated once on the real split instead of twice.
        tv, mean_perm_tv, std_perm_tv, p_tv = permutation_test(tv_statistic, test_vals, predicted_vals, num_permutations)
        kl, mean_perm_kl, std_perm_kl, p_kl = permutation_test(kl_statistic, test_vals, predicted_vals, num_permutations)
        js, mean_perm_js, std_perm_js, p_js = permutation_test(js_statistic, test_vals, predicted_vals, num_permutations)
        # scipy.stats.wasserstein_distance is the 1-Wasserstein distance; the
        # "W2" labels are kept because downstream cells select columns by name.
        w2, mean_perm_w2, std_perm_w2, p_w2 = permutation_test(wasserstein_distance, test_vals, predicted_vals, num_permutations)
        sink, mean_perm_sink, std_perm_sink, p_sink = permutation_test(sinkhorn_divergence, test_vals, predicted_vals, num_permutations)

        mean_diff = np.mean(test_vals) - np.mean(predicted_vals)
        std_diff = np.std(test_vals) - np.std(predicted_vals)
        t_stat, t_pval = ttest_ind(test_vals, predicted_vals, equal_var=False)

        metric_results.append({
            "Time": time_point,
            "Gene": gene_of_interest,
            "KS Test Statistic": ks_stat,
            "KS p-value": ks_pval,
            "W2 Distance": w2,
            "Permutation W2 ± Std": f"{mean_perm_w2:.4f} ± {std_perm_w2:.4f}",
            "p-value W2": p_w2,
            "Sinkhorn Distance": sink,
            "Permutation Sinkhorn ± Std": f"{mean_perm_sink:.4f} ± {std_perm_sink:.4f}",
            "p-value Sinkhorn": p_sink,
            "Total Variation Distance": tv,
            "Permutation TV ± Std": f"{mean_perm_tv:.4f} ± {std_perm_tv:.4f}",
            "p-value TV": p_tv,
            "KL Divergence": kl,
            "Permutation KL ± Std": f"{mean_perm_kl:.4f} ± {std_perm_kl:.4f}",
            "p-value KL": p_kl,
            "Jensen-Shannon Divergence": js,
            "Permutation JS ± Std": f"{mean_perm_js:.4f} ± {std_perm_js:.4f}",
            "p-value JS": p_js,
            "Mean Difference": mean_diff,
            "Standard Deviation Difference": std_diff,
            "T-test Statistic": t_stat,
            "T-test p-value": t_pval
        })

    if save_csv:
        output_dir = os.path.join(result_dir, 'output', exp_memo)
        os.makedirs(output_dir, exist_ok=True)
        df_results = pd.DataFrame(metric_results)
        csv_path = os.path.join(output_dir, f"statistical_metrics_{gene_of_interest}.csv")
        df_results.to_csv(csv_path, index=False)
        print(f"✅ Results saved to: {csv_path}")

    return metric_results




In [None]:
## CSV files for all the genes - Stem Cell data

import os
import numpy as np
import pandas as pd

# Define parameters
genes_of_interest = gene_names  # Set of genes
source_t, target_t = 0, 4
optimal_k = 2
index = 1
max_i = 200
intermediate_t = [1, 3]  # Intermediate time points
d_red = 2
random_state = 40
exp_memo = 'EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64'
result_dir = '%s/assets/Transport_genes/' % main_dir
output_dir = os.path.join(result_dir, 'output', exp_memo)

# Create directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Decision column -> p-value column it is derived from.
# Insertion order determines the column order of the saved CSV.
reject_columns = {
    "Reject Null Hypothesis (TV)": "p-value TV",
    "Reject Null Hypothesis (KL)": "p-value KL",
    "Reject Null Hypothesis (T-test)": "T-test p-value",
    "Reject Null Hypothesis (W2)": "p-value W2",
    "Reject Null Hypothesis (Sinkhorn)": "p-value Sinkhorn",
}

# One DataFrame per successfully processed gene
all_gene_results = []

for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        # Compute statistical metrics
        results = Compare_Distribution_Statistics(
            source_t, target_t, optimal_k, gene, index, max_i,
            intermediate_t=intermediate_t, d_red=d_red,
            random_state=random_state, exp_memo=exp_memo, num_permutations=100,
            save_csv=False  # Prevent saving individual CSVs for each gene
        )

        # The function returns None (or an empty list) when nothing could be
        # computed; report that directly instead of failing on a missing column.
        if not results:
            print(f"Error processing gene {gene}: no results were returned")
            continue

        df_results = pd.DataFrame(results)
        df_results["Gene"] = gene

        # "No" = fail to reject at the 5% level; anything else (incl. NaN) is "Yes"
        for reject_col, p_col in reject_columns.items():
            df_results[reject_col] = df_results[p_col].apply(lambda p: "No" if p > 0.05 else "Yes")
        all_gene_results.append(df_results)

    except Exception as e:
        # Best effort: one failing gene must not abort the whole batch.
        print(f"Error processing gene {gene}: {e}")

# Combine results for all genes into one DataFrame
if all_gene_results:
    combined_df = pd.concat(all_gene_results, ignore_index=True)

    # Save combined results for all genes
    combined_csv_path = os.path.join(output_dir, "all_genes_statistical_metrics.csv")
    combined_df.to_csv(combined_csv_path, index=False)
    print(f"All genes' results saved to {combined_csv_path}")

    # Columns exported per metric. The T-test table uses "Mean Difference",
    # which is the column Compare_Distribution_Statistics actually produces.
    metrics = {
        "TV": ["Gene", "Time", "Total Variation Distance", "Permutation TV ± Std", "p-value TV", "Reject Null Hypothesis (TV)"],
        "KL": ["Gene", "Time", "KL Divergence", "Permutation KL ± Std", "p-value KL", "Reject Null Hypothesis (KL)"],
        "W2": ["Gene", "Time", "W2 Distance", "Permutation W2 ± Std", "p-value W2", "Reject Null Hypothesis (W2)"],
        "Sinkhorn": ["Gene", "Time", "Sinkhorn Distance", "Permutation Sinkhorn ± Std", "p-value Sinkhorn", "Reject Null Hypothesis (Sinkhorn)"],
        "T-test": ["Gene", "Time", "Mean Difference", "T-test Statistic", "T-test p-value", "Reject Null Hypothesis (T-test)"],
    }

    for metric, cols in metrics.items():
        if all(col in combined_df.columns for col in cols):  # Ensure columns exist
            metric_df = combined_df[cols]
            metric_csv_path = os.path.join(output_dir, f"{metric}_metrics.csv")
            metric_df.to_csv(metric_csv_path, index=False)
            print(f"{metric} results saved to {metric_csv_path}")
        else:
            print(f"Warning: Some columns missing for {metric} metric.")

    # Print out genes that did not reject the null hypothesis

    # Genes where "Reject Null Hypothesis (TV)" is "No"
    genes_no_TV = combined_df[combined_df["Reject Null Hypothesis (TV)"] == "No"]["Gene"].unique()
    print("\nGenes that did NOT reject null hypothesis in TV:", genes_no_TV)

    # Genes where "Reject Null Hypothesis (KL)" is "No"
    genes_no_KL = combined_df[combined_df["Reject Null Hypothesis (KL)"] == "No"]["Gene"].unique()
    print("\nGenes that did NOT reject null hypothesis in KL:", genes_no_KL)

    # Genes where "Reject Null Hypothesis (W2)" is "No"
    genes_no_W2 = combined_df[combined_df["Reject Null Hypothesis (W2)"] == "No"]["Gene"].unique()
    print("\nGenes that did NOT reject null hypothesis in W2:", genes_no_W2)

    # Genes where "Reject Null Hypothesis (Sinkhorn)" is "No"
    genes_no_Sinkhorn = combined_df[combined_df["Reject Null Hypothesis (Sinkhorn)"] == "No"]["Gene"].unique()
    print("\nGenes that did NOT reject null hypothesis in Sinkhorn:", genes_no_Sinkhorn)

    # Genes where both TV and KL are "No" (in the same row, i.e. same time point)
    genes_no_TV_KL = combined_df[(combined_df["Reject Null Hypothesis (TV)"] == "No") &
                                 (combined_df["Reject Null Hypothesis (KL)"] == "No")]["Gene"].unique()
    print("\nGenes that did NOT reject null hypothesis in BOTH TV and KL:", genes_no_TV_KL)

else:
    print("No valid results generated.")



In [None]:
## Venn Diagram of null hypothesis results for Stem cell data
import os
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib_venn import venn3

exp_memo = "EMT_dim2-f_Lip=5e-2-t_size=50-network=64_64_64"


# Inputs are read from, and all outputs written to, this experiment's folder
output_dir = os.path.join(result_dir, 'output', exp_memo)

# Per-metric tables (each has "Gene", "Time" and a decision column)
tv_df = pd.read_csv(os.path.join(output_dir, "TV_metrics.csv"))
kl_df = pd.read_csv(os.path.join(output_dir, "KL_metrics.csv"))
sinkhorn_df = pd.read_csv(os.path.join(output_dir, "Sinkhorn_metrics.csv"))

# Name of the Yes/No decision column in each table
tv_col = "Reject Null Hypothesis (TV)"
kl_col = "Reject Null Hypothesis (KL)"
sinkhorn_col = "Reject Null Hypothesis (Sinkhorn)"

# The TV table decides which time points are reported
time_points = sorted(tv_df["Time"].unique())

for time_point in time_points:
    # Restrict every table to the rows of this time point
    tv_time_df = tv_df.loc[tv_df["Time"] == time_point]
    kl_time_df = kl_df.loc[kl_df["Time"] == time_point]
    sinkhorn_time_df = sinkhorn_df.loc[sinkhorn_df["Time"] == time_point]

    # Genes for which the null hypothesis was NOT rejected, per metric
    genes_no_tv = set(tv_time_df[tv_time_df[tv_col] == "No"]["Gene"].unique())
    genes_no_kl = set(kl_time_df[kl_time_df[kl_col] == "No"]["Gene"].unique())
    genes_no_sinkhorn = set(sinkhorn_time_df[sinkhorn_time_df[sinkhorn_col] == "No"]["Gene"].unique())

    # Membership in >=1, >=2 and all 3 of the sets above
    genes_at_least_one_no = genes_no_tv.union(genes_no_kl, genes_no_sinkhorn)
    genes_at_least_two_no = set().union(
        genes_no_tv.intersection(genes_no_kl),
        genes_no_tv.intersection(genes_no_sinkhorn),
        genes_no_kl.intersection(genes_no_sinkhorn),
    )
    genes_all_three_no = genes_no_tv.intersection(genes_no_kl, genes_no_sinkhorn)

    # Two-column summary: criterion text and comma-separated sorted gene names
    summary_rows = [
        ("At least one metric (TV, KL, or Sinkhorn) showing No", genes_at_least_one_no),
        ("At least two metrics (TV, KL, or Sinkhorn) showing No", genes_at_least_two_no),
        ("All three metrics (TV, KL, and Sinkhorn) showing No", genes_all_three_no),
    ]
    summary_df = pd.DataFrame({
        "Criteria": [label for label, _ in summary_rows],
        "Genes": [", ".join(sorted(members)) for _, members in summary_rows],
    })

    print(f"\n🔹 Summary of Genes by Null Hypothesis Rejection (Time {time_point}):")
    print(summary_df.to_string(index=False))

    summary_csv_path = os.path.join(output_dir, f"genes_null_hypothesis_summary_time_{time_point}.csv")
    summary_df.to_csv(summary_csv_path, index=False)
    print(f"✅ Summary for Time {time_point} saved to {summary_csv_path}")

    # Three-set Venn diagram of the non-rejecting genes
    plt.figure(figsize=(10, 8))
    venn = venn3(
        [genes_no_tv, genes_no_kl, genes_no_sinkhorn],
        set_labels=('TV', 'KL', 'Sinkhorn')
    )

    plt.title(f"Genes NOT Rejecting Null Hypothesis (Time {time_point})", fontsize=16)

    # Count annotations, stacked downwards from (x_pos, y_pos)
    x_pos = 0.6
    y_pos = 0.6
    step = 0.07
    count_notes = [
        f"Genes in ≥1 metric: {len(genes_at_least_one_no)}",
        f"Genes in ≥2 metrics: {len(genes_at_least_two_no)}",
        f"Genes in all 3 metrics: {len(genes_all_three_no)}",
    ]
    for line_no, note in enumerate(count_notes):
        plt.text(x_pos, y_pos - line_no * step, note, fontsize=12)

    plt.tight_layout()

    venn_path = os.path.join(output_dir, f"genes_venn_diagram_time_{time_point}.png")
    plt.savefig(venn_path, dpi=300)
    plt.show()

    print(f"✅ Venn diagram for Time {time_point} saved to {venn_path}")


In [None]:
## CSV files for all the genes - Stem Cell data

import os
import numpy as np
import pandas as pd

# Define parameters
genes_of_interest = ['DSP']  # Set of genes
source_t, target_t = 0, 4
optimal_k = 2
index = 1
max_i = 200
intermediate_t = [2]  # Intermediate time points
d_red = 8
random_state = 40
exp_memo = '72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64'
result_dir = '%s/assets/Transport_genes/' % main_dir
output_dir = os.path.join(result_dir, 'output', exp_memo)

# Create directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Decision column -> p-value column it is derived from.
# Insertion order determines the column order of the saved CSV.
reject_columns = {
    "Reject Null Hypothesis (TV)": "p-value TV",
    "Reject Null Hypothesis (KL)": "p-value KL",
    "Reject Null Hypothesis (T-test)": "T-test p-value",
    "Reject Null Hypothesis (W2)": "p-value W2",
    "Reject Null Hypothesis (Sinkhorn)": "p-value Sinkhorn",
}

# One DataFrame per successfully processed gene
all_gene_results = []

for gene in genes_of_interest:
    print(f"Processing gene: {gene}")
    try:
        # Compute statistical metrics
        results = Compare_Distribution_Statistics(
            source_t, target_t, optimal_k, gene, index, max_i,
            intermediate_t=intermediate_t, d_red=d_red,
            random_state=random_state, exp_memo=exp_memo, num_permutations=100,
            save_csv=False  # Prevent saving individual CSVs for each gene
        )

        # The function returns None (or an empty list) when nothing could be
        # computed; report that directly instead of failing on a missing column.
        if not results:
            print(f"Error processing gene {gene}: no results were returned")
            continue

        df_results = pd.DataFrame(results)
        df_results["Gene"] = gene

        # "No" = fail to reject at the 5% level; anything else (incl. NaN) is "Yes"
        for reject_col, p_col in reject_columns.items():
            df_results[reject_col] = df_results[p_col].apply(lambda p: "No" if p > 0.05 else "Yes")
        all_gene_results.append(df_results)

    except Exception as e:
        # Best effort: one failing gene must not abort the whole batch.
        print(f"Error processing gene {gene}: {e}")

# Combine results for all genes into one DataFrame
if all_gene_results:
    combined_df = pd.concat(all_gene_results, ignore_index=True)

    # Save combined results for all genes
    combined_csv_path = os.path.join(output_dir, "all_genes_statistical_metrics.csv")
    combined_df.to_csv(combined_csv_path, index=False)
    print(f"All genes' results saved to {combined_csv_path}")

    # Columns exported per metric. The T-test table uses "Mean Difference",
    # which is the column Compare_Distribution_Statistics actually produces.
    metrics = {
        "TV": ["Gene", "Time", "Total Variation Distance", "Permutation TV ± Std", "p-value TV", "Reject Null Hypothesis (TV)"],
        "KL": ["Gene", "Time", "KL Divergence", "Permutation KL ± Std", "p-value KL", "Reject Null Hypothesis (KL)"],
        "W2": ["Gene", "Time", "W2 Distance", "Permutation W2 ± Std", "p-value W2", "Reject Null Hypothesis (W2)"],
        "Sinkhorn": ["Gene", "Time", "Sinkhorn Distance", "Permutation Sinkhorn ± Std", "p-value Sinkhorn", "Reject Null Hypothesis (Sinkhorn)"],
        "T-test": ["Gene", "Time", "Mean Difference", "T-test Statistic", "T-test p-value", "Reject Null Hypothesis (T-test)"],
    }

    for metric, cols in metrics.items():
        if all(col in combined_df.columns for col in cols):  # Ensure columns exist
            metric_df = combined_df[cols]
            metric_csv_path = os.path.join(output_dir, f"{metric}_metrics.csv")
            metric_df.to_csv(metric_csv_path, index=False)
            print(f"{metric} results saved to {metric_csv_path}")
        else:
            print(f"Warning: Some columns missing for {metric} metric.")

    # Print out genes that did not reject the null hypothesis

    # Genes where "Reject Null Hypothesis (TV)" is "No"
    genes_no_TV = combined_df[combined_df["Reject Null Hypothesis (TV)"] == "No"]["Gene"].unique()
    print("\nGenes that did NOT reject null hypothesis in TV:", genes_no_TV)

    # Genes where "Reject Null Hypothesis (KL)" is "No"
    genes_no_KL = combined_df[combined_df["Reject Null Hypothesis (KL)"] == "No"]["Gene"].unique()
    print("\nGenes that did NOT reject null hypothesis in KL:", genes_no_KL)

    # Genes where "Reject Null Hypothesis (W2)" is "No"
    genes_no_W2 = combined_df[combined_df["Reject Null Hypothesis (W2)"] == "No"]["Gene"].unique()
    print("\nGenes that did NOT reject null hypothesis in W2:", genes_no_W2)

    # Genes where "Reject Null Hypothesis (Sinkhorn)" is "No"
    genes_no_Sinkhorn = combined_df[combined_df["Reject Null Hypothesis (Sinkhorn)"] == "No"]["Gene"].unique()
    print("\nGenes that did NOT reject null hypothesis in Sinkhorn:", genes_no_Sinkhorn)

    # Genes where both TV and KL are "No" (in the same row, i.e. same time point)
    genes_no_TV_KL = combined_df[(combined_df["Reject Null Hypothesis (TV)"] == "No") &
                                 (combined_df["Reject Null Hypothesis (KL)"] == "No")]["Gene"].unique()
    print("\nGenes that did NOT reject null hypothesis in BOTH TV and KL:", genes_no_TV_KL)

else:
    print("No valid results generated.")



In [None]:
## Combining genes (preprocess)
import os
import pandas as pd
import glob

result_dir = '%s/assets/Transport_genes/' % main_dir
csv_dir = os.path.join(result_dir, 'Sample 3')


# File-name prefixes to combine (one output file per prefix)
prefixes = ['sinkhorn']

# Mapping prefixes to their corresponding p-value column names
pval_columns = {
    'sinkhorn': 'p_Sinkhorn_1',
}


for prefix in prefixes:
    output_filename = os.path.join(csv_dir, f"{prefix}_metric.csv")

    # The pattern "<prefix>*.csv" also matches this cell's own output file.
    # Exclude it so that re-running the cell does not fold the previous
    # combined table back into the new one (which would duplicate every row).
    # Sorting makes the row order independent of directory listing order.
    matching_files = sorted(
        path for path in glob.glob(os.path.join(csv_dir, f"{prefix}*.csv"))
        if os.path.abspath(path) != os.path.abspath(output_filename)
    )

    if not matching_files:
        print(f"No files found for prefix '{prefix}'")
        continue

    # Read and concatenate files
    combined_df = pd.concat([pd.read_csv(path) for path in matching_files], ignore_index=True)

    # Add the Yes/No decision column when the p-value column is present
    p_col = pval_columns[prefix]
    if p_col in combined_df.columns:
        # prefix[:-1] drops the last character, giving the label "(SINKHOR)".
        # Kept as is: the "Sample 1" cell below selects the column by that name.
        combined_df[f'Reject Null Hypothesis ({prefix[:-1].upper()})'] = combined_df[p_col].apply(
            lambda p: 'No' if p > 0.05 else 'Yes'
        )
    else:
        print(f"⚠️ Warning: Column '{p_col}' not found in files for prefix '{prefix}'.")

    # Save the combined dataframe to CSV
    combined_df.to_csv(output_filename, index=False)

    print(f"✅ Combined {len(matching_files)} files into '{output_filename}' with hypothesis test results.")


In [None]:
## Sample 1 (only one time point)

import os
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib_venn import venn3

# Define output directory
output_dir = os.path.join(result_dir, 'output', exp_memo)

# Load existing metric CSV files
tv_df = pd.read_csv(os.path.join(output_dir, "TV_metrics.csv"))
kl_df = pd.read_csv(os.path.join(output_dir, "KL_metrics.csv"))
sinkhorn_df = pd.read_csv(os.path.join(output_dir, "sinkhorn_metrics.csv"))

# The Sinkhorn decision column is labelled "(SINKHOR)" by the combine cell and
# "(Sinkhorn)" by the all-genes cells; accept whichever this table carries.
# Falls back to the original label so a missing column still raises KeyError.
sinkhorn_reject_col = next(
    (label for label in ("Reject Null Hypothesis (SINKHOR)", "Reject Null Hypothesis (Sinkhorn)")
     if label in sinkhorn_df.columns),
    "Reject Null Hypothesis (SINKHOR)",
)

# Identify genes with "No" in each metric
genes_no_tv = set(tv_df.loc[tv_df["Reject Null Hypothesis (TV)"] == "No", "Gene"].unique())
genes_no_kl = set(kl_df.loc[kl_df["Reject Null Hypothesis (KL)"] == "No", "Gene"].unique())
genes_no_sinkhorn = set(sinkhorn_df.loc[sinkhorn_df[sinkhorn_reject_col] == "No", "Gene"].unique())

# Genes with at least one metric showing "No"
genes_at_least_one_no = genes_no_tv | genes_no_kl | genes_no_sinkhorn

# Genes with at least two metrics showing "No"
genes_at_least_two_no = (
    (genes_no_tv & genes_no_kl) | (genes_no_tv & genes_no_sinkhorn) | (genes_no_kl & genes_no_sinkhorn)
)

# Genes with all three metrics showing "No"
genes_all_three_no = genes_no_tv & genes_no_kl & genes_no_sinkhorn

# Two-column summary: criterion text and comma-separated sorted gene names
summary_df = pd.DataFrame({
    "Criteria": [
        "At least one metric (TV, KL, or Sinkhorn) showing No",
        "At least two metrics (TV, KL, or Sinkhorn) showing No",
        "All three metrics (TV, KL, and Sinkhorn) showing No"
    ],
    "Genes": [
        ", ".join(sorted(genes_at_least_one_no)),
        ", ".join(sorted(genes_at_least_two_no)),
        ", ".join(sorted(genes_all_three_no))
    ]
})

# Print results
print("\nSummary of Genes by Null Hypothesis Rejection:")
print(summary_df.to_string(index=False))

# Save to CSV
summary_csv_path = os.path.join(output_dir, "genes_null_hypothesis_summary.csv")
summary_df.to_csv(summary_csv_path, index=False)
print(f"\nSummary saved to {summary_csv_path}")



# Venn diagram
plt.figure(figsize=(12, 8))
venn = venn3(
    [genes_no_tv, genes_no_kl, genes_no_sinkhorn],
    set_labels=('TV', 'KL', 'SINKHORN')
)

plt.title("Genes NOT Rejecting Null Hypothesis", fontsize=16)

# Count annotations, stacked downwards from (x_pos, y_pos); the sets computed
# above are reused rather than rebuilt inside each label.
x_pos = 0.6
y_pos = 0.6
step = 0.07

plt.text(x_pos, y_pos, f"Genes in at least 1 metric: {len(genes_at_least_one_no)}", fontsize=12)
plt.text(x_pos, y_pos - step, f"Genes in at least 2 metrics: {len(genes_at_least_two_no)}", fontsize=12)
plt.text(x_pos, y_pos - 2*step, f"Genes in all 3 metrics: {len(genes_all_three_no)}", fontsize=12)

plt.tight_layout()

# Save figure
venn_path = os.path.join(output_dir, "genes_venn_diagram.png")
plt.savefig(venn_path, dpi=300)
plt.show()

print(f"✅ Venn diagram saved to {venn_path}")



In [None]:
## Showing distribution plots based on the gene sets which have similar distributions

## PDF combination for comparison distributions 
## save the gene expression dynamics png as pdf


import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from math import ceil
import matplotlib.lines as mlines


def create_pdf_from_gene_images(output_dir, exp_memo, gene_list, pdf_path, images_per_page=25, grid_size=(5, 5)):
    """
    Create a multi-page PDF that tiles per-gene PNG images in a grid.

    For every gene the file `<output_dir>/KDE_Intermediate_Only_<gene>.png` is
    looked up; files that do not exist are reported and skipped. Tiles carry no
    title and each image is stretched to fill its tile (`aspect='auto'`).

    Parameters:
        output_dir (str): Directory containing the PNG files.
        exp_memo (str): Accepted for call compatibility; not used here.
        gene_list (list): Genes whose PNG files should be included, in order.
        pdf_path (str): Path to save the output PDF file.
        images_per_page (int): Maximum number of images per page (default: 25).
            Values larger than the grid capacity are capped at rows * cols.
        grid_size (tuple): Grid size (rows, cols) for each page (default: 5x5).

    Raises:
        ValueError: If fewer than one image would fit on a page.
    """
    candidate_files = [
        f"{output_dir}/KDE_Intermediate_Only_{gene}.png" for gene in gene_list
    ]

    # Split into existing and missing files in one pass
    png_files = []
    missing_files = []
    for file in candidate_files:
        (png_files if os.path.exists(file) else missing_files).append(file)
    if missing_files:
        print(f"Warning: The following files are missing and will be skipped:\n{missing_files}")

    n_rows, n_cols = grid_size
    # Page by what actually fits: never more than the grid holds (otherwise
    # images would be dropped) and never more than requested (otherwise the
    # same image would appear on two consecutive pages).
    per_page = min(images_per_page, n_rows * n_cols)
    if per_page < 1:
        raise ValueError("images_per_page and grid_size must allow at least one image per page")

    total_pages = ceil(len(png_files) / per_page)

    with PdfPages(pdf_path) as pdf:
        for page in range(total_pages):
            # squeeze=False keeps `axes` a 2-D array even for a 1x1 or 1xN grid
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15), squeeze=False)
            page_files = png_files[page * per_page:(page + 1) * per_page]

            for slot, ax in enumerate(axes.flat):
                if slot < len(page_files):
                    ax.imshow(plt.imread(page_files[slot]), aspect='auto')
                ax.axis('off')  # No ticks/frames, also on unused tiles

            pdf.savefig(fig, dpi=300, bbox_inches='tight')
            plt.close(fig)  # Close the figure to free memory

    print(f"✅ PDF saved to {pdf_path} with original image resolution.")

# Example usage

exp_memo = "72GS_dim8-f_Lip=5e-2-t_size=50-network=64_64_64"

# Genes whose KDE plots are collected into the PDF
gene_list = [
    'DSP', 'ENPP5', 'EPB41L5', 'KRTCAP3',
    'MMP2', 'RAB25', 'SERINC2', 'TMEM45B',
]
## Selected genes (no difference by TV)  ['AXL', 'HNMT', 'TMEM45B', 'SSH3', 'SHROOM3', 'PRSS22', 'SERINC2', 'EVPL', 'GALNT3', 'DSP', 'ELMO3', 'KRTCAP3', 'KRT19', 'C1orf116', 'CDS1', 'INADL']
## Selected genes (no difference by TV and KL) ['HNMT', 'TMEM45B', 'SHROOM3', 'PRSS22', 'SERINC2', 'KRTCAP3', 'C1orf116', 'CDS1']

# Output PDF path (output_dir is whatever an earlier cell last assigned)
pdf_path = f"{output_dir}/selected_gene_expression_KDE_Intermediate_Only(two metrics).pdf"

# Six images per page on a 3-row by 2-column grid
create_pdf_from_gene_images(
    output_dir,
    exp_memo,
    gene_list,
    pdf_path,
    images_per_page=6,
    grid_size=(3, 2),
)
