In [None]:
import os

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from joypy import joyplot
from matplotlib import cm
from matplotlib.colors import Normalize
from tqdm import tqdm

# Render with the non-interactive PDF backend: every figure built below is
# saved to a .pdf file and closed, never shown on screen.
matplotlib.use(backend='pdf')

# --- GLOBAL DISTRIBUTION RIDGELINE PLOTS ---
# These plots show the distribution of each feature across binned time (all cell types together).
# Each ridge is filled with a viridis colour encoding the feature's mean within that time bin,
# on a fixed [-1, 1] σ colour scale so colours are comparable between features (means outside
# that range take the colormap's end colours).
global_plot_dir = "full_distribution_plots"
# savefig does not create missing directories, so make sure the target exists up front.
os.makedirs(global_plot_dir, exist_ok=True)

# The colour scale is identical for every feature, so build it once outside the loop.
custom_vmin = -1
custom_vmax = 1
norm = Normalize(vmin=custom_vmin, vmax=custom_vmax)
colormap = cm.viridis

for feature_of_interest in tqdm(plotting_columns, desc="Global distribution plots"):
    subset = full_df2[['time_bin', feature_of_interest, 'cell_type']]
    # One mean per time bin (groupby sorts the bins by default).
    # NOTE(review): assumes joyplot draws ridges in the same sorted group order — confirm.
    mean_values = subset.groupby('time_bin')[feature_of_interest].mean()
    colors = [colormap(norm(value)) for value in mean_values]

    fig, axes = joyplot(
        data=subset,
        by="time_bin",
        column=feature_of_interest,
        figsize=(5, 6),
        kind="kde",
        overlap=0.7,
        color=colors,
        x_range=[-3, 3],
    )

    # Labeling: the shared y label is attached to the 7th ridge axis, only when it exists.
    if len(axes) > 6:
        axes[6].set_ylabel('Binned time (hrs)', fontsize=12, loc='center')
    plt.xlabel('Standardized Feature (σ)', fontsize=12)
    plt.title(fixed_columns.get(feature_of_interest, feature_of_interest), fontsize=14)

    # Colorbar for the mean-value colouring; the mappable is standalone, so an empty array suffices.
    sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=axes, orientation='vertical', pad=0.1)
    cbar.set_label('Mean Value of Feature (σ)', fontsize=10)

    fig.savefig(
        os.path.join(global_plot_dir, f"{feature_of_interest}_distribution_plot.pdf"),
        format="pdf",
        bbox_inches="tight",
    )
    # Close exactly this figure so open figures do not accumulate across features.
    plt.close(fig)

# --- CELL TYPE-WISE RIDGELINE PLOTS ---
# These plots show the distribution of each feature across binned time, separated by cell type.
# Every time bin gets one unfilled KDE outline per cell type, drawn in a fixed colour from the
# seaborn "colorblind" palette.
colorblind_palette = sns.color_palette("colorblind")
color_mapping = {
    'basal': colorblind_palette[0],
    'goblet': colorblind_palette[1],
    'mcc': colorblind_palette[2],
    'ic': colorblind_palette[4],
    'ssc': colorblind_palette[8]
}
lower_bound = -3
upper_bound = 3

celltype_plot_dir = "celltype_distribution_plots"
# savefig does not create missing directories, so make sure the target exists up front.
os.makedirs(celltype_plot_dir, exist_ok=True)

# Column order, line colours and legend entries all follow color_mapping's insertion order,
# which is the same for every feature, so compute them once.
columns_to_plot = ['time_bin'] + list(color_mapping)
colors = list(color_mapping.values())

for feature_of_interest in tqdm(plotting_columns, desc="Cell type-wise distribution plots"):
    subset = full_df3[['time_bin', feature_of_interest, 'cell_type']].copy()
    # Create a column per cell type holding the feature value for rows of that type and NaN
    # elsewhere, so joyplot draws one curve per cell type. Series.where does this vectorised,
    # replacing a per-row Python apply (and needs no numpy import).
    feature_values = subset[feature_of_interest]
    for cell_type in color_mapping:
        subset[cell_type] = feature_values.where(subset['cell_type'] == cell_type)
    data_to_plot = subset[columns_to_plot]

    fig, axes = joyplot(
        data=data_to_plot,
        by="time_bin",
        figsize=(5, 5),
        kind="kde",
        overlap=0.6,
        color=colors,
        fill=False,
        x_range=[lower_bound, upper_bound],
    )

    # Make the overlapping outlines partially transparent.
    for ax in axes:
        for line in ax.get_lines():
            line.set_alpha(0.7)

    plt.xlabel('Standardized Feature (σ)', fontsize=12)
    # The shared y label is attached to the 3rd ridge axis, only when it exists.
    if len(axes) > 2:
        axes[2].set_ylabel('Binned time (hrs)', fontsize=12, loc='center')
    plt.suptitle(fixed_columns.get(feature_of_interest, feature_of_interest), fontsize=14)
    plt.tight_layout()

    # Legend built from proxy lines, one per cell type, placed outside the plot on the right.
    legend_patches = [
        plt.Line2D([0], [0], color=color, lw=2, label=cell_type)
        for cell_type, color in color_mapping.items()
    ]
    plt.legend(
        handles=legend_patches,
        title="Cell type",
        loc="center left",
        bbox_to_anchor=(1.05, 0.5),
        borderaxespad=0.0,
    )

    fig.savefig(
        os.path.join(celltype_plot_dir, f"{feature_of_interest}_celltype_distribution_plot.pdf"),
        format="pdf",
        bbox_inches="tight",
    )
    # Close exactly this figure so open figures do not accumulate across features.
    plt.close(fig)