In [None]:
import pandas as pd
import os
import sys
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.colors as mcolors 
import matplotlib.cm as cm
import re
import textwrap

In [None]:
# Project root; every other location below is derived from it.
base_dir = Path("/path/to/project")

# Locations under the project root
data_dir = base_dir / "data" / "h5ad"
# Enrichr tables are read from this directory and the per-comparison CSVs
# produced later in the notebook are written back into it.
csv_dir = base_dir / "data" / "results" / "enrichR"

# Create the Enrichr directory (and any missing parents) if it is absent.
csv_dir.mkdir(parents=True, exist_ok=True)

In [None]:
import matplotlib as mpl

# pdf.fonttype 42 makes the PDF backend embed TrueType (Type 42) fonts
# rather than Matplotlib's default Type 3 fonts.
mpl.rcParams.update({'pdf.fonttype': 42})

In [None]:
# Output directory for the saved PDF figures. It is created up front because
# Matplotlib's savefig does not create missing parent directories.
fig_dir = base_dir / 'figures'
fig_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Enrichr GO Biological Process exports, named
#   <reference>_vs_<comparison>_<cell type>_<UP|DOWN>_GOBP2025_table.txt
filenames = [
    'Sham-GFP_vs_TBI-GFP_CA1-ProS_DOWN_GOBP2025_table.txt',
    'TBI-GFP_vs_TBI-VEGFC_CA1-ProS_UP_GOBP2025_table.txt',
    'TBI-GFP_vs_TBI-VEGFC_DG_UP_GOBP2025_table.txt'
]

# List to collect the annotated per-file DataFrames
df_list = []

for fname in filenames:
    file_path = os.path.join(csv_dir, fname)
    df = pd.read_csv(file_path, sep='\t')

    # Convert Odds Ratio to numeric; unparseable entries become NaN
    df['Odds Ratio'] = pd.to_numeric(df['Odds Ratio'], errors='coerce')

    # Parse the filename from the right so that only the last two fields
    # (cell type, direction) are positional; everything before them is the
    # comparison, e.g. 'TBI-GFP_vs_TBI-VEGFC'.
    stem = fname.replace('_GOBP2025_table.txt', '')
    name_parts = stem.rsplit('_', 2)
    if len(name_parts) != 3 or name_parts[0].count('_vs_') != 1:
        raise ValueError(
            f"Cannot parse '{fname}': expected "
            "<reference>_vs_<comparison>_<cell type>_<UP|DOWN>_GOBP2025_table.txt"
        )
    comparison_full, cell_type, direction = name_parts

    # Split comparison into reference and comparison groups
    reference_group, comparison_group = comparison_full.split('_vs_')

    # UP   -> genes are higher in the comparison group
    # DOWN -> genes are higher in the reference group
    # Anything else is rejected instead of being silently treated as DOWN.
    if direction == 'UP':
        higher_in_group = comparison_group
    elif direction == 'DOWN':
        higher_in_group = reference_group
    else:
        raise ValueError(
            f"Unexpected direction '{direction}' in '{fname}' (expected 'UP' or 'DOWN')"
        )

    # Annotate every row with where it came from
    df['comparison'] = comparison_full
    df['cell_type'] = cell_type
    df['higher_in_group'] = higher_in_group

    # Rows are kept exactly as exported: no filtering, no sorting, and
    # therefore no Rank column.

    # Save individual CSV to csv_dir
    out_fname = f"Enrichr_{comparison_full}_{cell_type}_higher_in_{higher_in_group}.csv"
    out_path = os.path.join(csv_dir, out_fname)
    df.to_csv(out_path, index=False)
    print(f"Saved {out_fname} with {df.shape[0]} rows")

    # Append to master list
    df_list.append(df)

# Combine into one master DataFrame
master_df = pd.concat(df_list, ignore_index=True)

# Drop the legacy p-value columns; errors='ignore' keeps this step from
# failing when an export does not contain them.
master_df = master_df.drop(
    columns=['Old P-value', 'Old Adjusted P-value'],
    errors='ignore'
)

# Display master_df
display(master_df)

In [None]:
# Per-comparison CSVs to read back from csv_dir
enrichr_files = [
    'Enrichr_Sham-GFP_vs_TBI-GFP_CA1-ProS_higher_in_Sham-GFP.csv',
    'Enrichr_TBI-GFP_vs_TBI-VEGFC_CA1-ProS_higher_in_TBI-VEGFC.csv',
    'Enrichr_TBI-GFP_vs_TBI-VEGFC_DG_higher_in_TBI-VEGFC.csv'
]

# Maps an identifier-style key to the table loaded from each file
df_dict = {}

for csv_name in enrichr_files:
    table = pd.read_csv(csv_dir / csv_name)

    # Key is the filename without '.csv', with hyphens and spaces replaced by
    # underscores, e.g. Enrichr_Sham_GFP_vs_TBI_GFP_CA1_ProS_higher_in_Sham_GFP
    key = re.sub(r'[- ]', '_', csv_name.replace('.csv', ''))

    df_dict[key] = table
    print(f"Loaded {key} → {len(table)} rows")

In [None]:
def plot_barplot(df, x_col, y_col, hue_col=None, title=None, top_n=None, figsize=(8, 6)):
    """
    Draw a vertical seaborn barplot of ``y_col`` against ``x_col`` and show it.

    Args:
        df: input dataframe (not modified)
        x_col: column plotted on the x-axis (categories)
        y_col: column plotted on the y-axis (numeric)
        hue_col: optional column used to colour the bars
        title: optional plot title
        top_n: optional; when given, only the ``top_n`` rows with the largest
            ``y_col`` values are plotted
        figsize: figure size passed to ``plt.figure``

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    plot_data = df
    if top_n is not None:
        # Keep only the rows with the largest y values
        plot_data = plot_data.sort_values(by=y_col, ascending=False).head(top_n)

    plt.figure(figsize=figsize)
    ax = sns.barplot(data=plot_data, x=x_col, y=y_col, hue=hue_col)

    # Slanted, right-aligned tick labels
    plt.xticks(rotation=45, ha='right')
    ax.set_ylabel(y_col)
    ax.set_xlabel(x_col)

    if title:
        ax.set_title(title)

    plt.tight_layout()
    plt.show()

In [None]:
def plot_horizontal_barplot(
    df,
    y_col='Term',
    x_col='Combined Score',
    color_by=None,
    title=None,
    top_n=None,
    row_indices=None,
    figsize=(6, 8),
    cmap_name='plasma',
    wrap_width=25,
    remove_go_id=True,
    force_gradient=True,
    gradient_min=None,
    gradient_max=None,
    save_path=None
):
    """
    Reusable horizontal barplot for GO terms with line-wrapped y-axis labels.

    Args:
        df: input dataframe (not modified).
        y_col: column holding the bar labels. Values are expected to be
            unique among the plotted rows; seaborn aggregates repeated labels
            into a single bar.
        x_col: numeric column giving the bar lengths.
        color_by: optional numeric column mapped through ``cmap_name`` to
            colour each bar; a matching colorbar is added. When None, every
            bar is drawn in steelblue.
        title: optional plot title.
        top_n: optional; plot the ``top_n`` rows with the largest ``x_col``.
            Ignored when ``row_indices`` is given.
        row_indices: optional iterable of row positions (ints or int-like
            strings) selected with ``df.iloc``, plotted in the given order.
        figsize: figure size in inches.
        cmap_name: Matplotlib colormap name used with ``color_by``.
        wrap_width: maximum characters per line of a wrapped label.
        remove_go_id: strip a trailing ``(GO:1234567)`` tag from each label.
        force_gradient: when True, ``gradient_min`` / ``gradient_max``
            override the data range of the colour scale and a zero-width
            range is widened; when False the data range is used as is.
        gradient_min: lower bound of the colour scale (see ``force_gradient``).
        gradient_max: upper bound of the colour scale (see ``force_gradient``).
        save_path: optional destination passed to ``savefig``. For path-like
            values the parent directory is created if it is missing.

    Returns:
        None. The figure is shown with ``plt.show()``.

    Note:
        Calls ``sns.set_style`` and ``sns.set_context``, which change the
        seaborn defaults for any plot drawn afterwards.
    """

    sns.set_style("white")
    sns.set_context("talk", font_scale=0.95)

    # Row selection: explicit positions win over top_n
    if row_indices is not None:
        df = df.iloc[[int(i) for i in row_indices]]
    elif top_n is not None:
        df = df.sort_values(by=x_col, ascending=False).head(top_n)

    cmap = None
    norm = None
    bar_colors = None
    if color_by is not None:
        color_values = df[color_by]
        vmin, vmax = color_values.min(), color_values.max()
        if force_gradient:
            if gradient_min is not None:
                vmin = gradient_min
            if gradient_max is not None:
                vmax = gradient_max
            if vmin == vmax:
                # Widen a zero-width range downwards by 5% of |vmax|, falling
                # back to 1.0 when vmax is 0, so vmin always ends up below
                # vmax (vmax * 0.95 fails for zero and negative values).
                vmin = vmax - (abs(vmax) * 0.05 or 1.0)
        norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9;
        # pyplot.get_cmap is the supported lookup.
        cmap = plt.get_cmap(cmap_name)
        bar_colors = cmap(norm(color_values.to_numpy()))

    fig, ax = plt.subplots(figsize=figsize)

    # Draw uniformly coloured bars first. saturation=1 stops seaborn from
    # desaturating the colour, and no palette is passed because
    # palette-without-hue is deprecated in seaborn 0.13.
    sns.barplot(
        data=df,
        y=y_col,
        x=x_col,
        color='steelblue',
        saturation=1,
        ax=ax
    )

    if bar_colors is not None:
        # Bars are created in row order (for unique labels), so each one gets
        # the exact colormap colour that the colorbar shows for its value.
        for bar, rgba in zip(ax.patches, bar_colors):
            bar.set_facecolor(rgba)

    labels = [str(term) for term in df[y_col]]
    if remove_go_id:
        labels = [re.sub(r"\s*\(GO:\d+\)", "", label) for label in labels]
    labels_wrapped = ['\n'.join(textwrap.wrap(label, wrap_width)) for label in labels]
    if labels_wrapped:
        # Pin the tick positions before replacing the labels; Matplotlib
        # warns when set_yticklabels is used without fixed ticks.
        ax.set_yticks(ax.get_yticks())
        ax.set_yticklabels(labels_wrapped)

    if color_by is not None:
        sm = cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, fraction=0.03, pad=0.04, aspect=18)
        cbar.set_label(color_by)

    ax.tick_params(axis='x', colors='black')
    ax.tick_params(axis='y', colors='black')
    ax.xaxis.label.set_color('black')
    ax.set_ylabel('')
    if title:
        ax.set_title(title, fontsize=14, color='black')

    sns.despine(ax=ax, left=False, bottom=False)

    fig.tight_layout()

    if save_path:
        if isinstance(save_path, (str, os.PathLike)):
            # savefig does not create missing directories
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, bbox_inches='tight')

    plt.show()

In [None]:
# CA1-ProS terms from the "higher in Sham-GFP" table, i.e. down in TBI-GFP.
# Rows are chosen by position in the loaded table and drawn in this order.
plot_horizontal_barplot(
    df_dict['Enrichr_Sham_GFP_vs_TBI_GFP_CA1_ProS_higher_in_Sham_GFP'],
    y_col='Term',
    x_col='Combined Score',
    row_indices=[2, 0, 4, 6, 9],
    title='Down in TBI-GFP vs Sham-GFP (CA1-ProS Glut)',
    figsize=(2, 5),
    # Colour bars by odds ratio on a fixed 5-200 scale
    color_by='Odds Ratio',
    cmap_name='Blues',
    force_gradient=True,
    gradient_min=5,
    gradient_max=200,
    save_path=fig_dir / "go_enrich_sham_vs_tbi_ca1pros_down.pdf",
)

In [None]:
# CA1-ProS terms from the "higher in TBI-VEGFC" table.
# Rows are chosen by position in the loaded table and drawn in this order.
plot_horizontal_barplot(
    df_dict['Enrichr_TBI_GFP_vs_TBI_VEGFC_CA1_ProS_higher_in_TBI_VEGFC'],
    y_col='Term',
    x_col='Combined Score',
    row_indices=[0, 5, 6, 7, 1],
    title='Up in TBI-VEGFC vs. TBI-GFP (CA1-ProS Glut)',
    figsize=(2, 5),
    # Colour bars by odds ratio; the fixed 5-200 range keeps the colour
    # scale consistent across plots.
    color_by='Odds Ratio',
    cmap_name='Reds',
    force_gradient=True,
    gradient_min=5,
    gradient_max=200,
    save_path=fig_dir / "go_enrich_tbi-gfp_vs_tbi-vegfc_ca1pros_up.pdf",
)

In [None]:
# DG terms from the "higher in TBI-VEGFC" table.
# Rows are chosen by position in the loaded table and drawn in this order.
plot_horizontal_barplot(
    df=df_dict['Enrichr_TBI_GFP_vs_TBI_VEGFC_DG_higher_in_TBI_VEGFC'],
    x_col='Combined Score',
    y_col='Term',
    color_by='Odds Ratio',
    title='Up in TBI-VEGFC vs. TBI-GFP (DG Glut)',
    row_indices=['0', '1', '3', '4', '5'],
    figsize=(2, 5),
    cmap_name='Reds',
    force_gradient=True,
    gradient_min=5,
    gradient_max=200,
    # The plotted table is the UP / higher-in-TBI-VEGFC set, so the file is
    # named "_up" (it was previously saved as "..._DG_down.pdf").
    save_path=fig_dir / "go_enrich_tbi-gfp_vs_tbi-vegfc_DG_up.pdf"
)