# NMDesc AD Gene Feature Enrichment

This notebook summarizes the enrichment of several protein-level features among autosomal dominant (AD) genes across different variant categories.

Specifically, we:

- Load AD gene lists and feature flags for:
  - Pfam domains
  - Short Linear Motifs (SLiMs)
  - Molecular Recognition Features (MoRFs)
  - Post-Translational Modification (PTM) sites
  - Nuclear Localization Signals (NLS)
  - Low Complexity Sequences (LCS)
- Aggregate, for each gene category, the fraction of genes that carry each feature.
- Perform two-sided binomial tests comparing the feature fraction in each “case” group against the fraction observed in its matched control group (used as the null proportion).
- Visualize the percentage of genes with each feature across categories, including:
  - Percentage bars
  - Total gene counts per group
  - Pairwise p-values (optionally FDR-adjusted) between case and control.


In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import pandas as pd
import numpy as np
from scipy.stats import binomtest
from statsmodels.stats.multitest import multipletests

# ---------- Matplotlib style ----------
# Global figure defaults, set one rc group at a time.
plt.rc('font', family='Arial', size=8)
plt.rc('axes', linewidth=0.5, labelsize=8, titlesize=9, grid=True)
plt.rc('axes.spines', top=False, right=False)
plt.rc('xtick', labelsize=7)
plt.rc('ytick', labelsize=7)
plt.rc('legend', fontsize=7)
plt.rc('figure', dpi=300)
plt.rc('savefig', dpi=300, bbox='tight', pad_inches=0.1)
plt.rc('grid', linewidth=0.3, alpha=0.3)

# ---------- Colors ----------
# One hex colour per group label; each variant class has a case colour and a
# paler "_Control" colour.
COLOR_SCHEME = dict(
    Nonsense='#70AD47',
    Nonsense_Control='#C5E0B4',
    Plus1='#C65911',
    Plus1_Control='#F4B183',
    Minus1='#4472C4',
    Minus1_Control='#B4C7E7',
)

# Every label (case or control) mapped to the case colour of its variant class.
CATEGORY_BASE_COLOR = {
    label: COLOR_SCHEME[variant]
    for variant in ('Minus1', 'Plus1', 'Nonsense')
    for label in (variant, f'{variant}_Control')
}

def lighten_color(color, amount=0.35):
    """Blend a color toward white; used for the control bars.

    Each RGB channel ``c`` becomes ``(1 - amount) + amount * c``, so ``amount``
    is the share of the original color that is kept. A color that
    ``mcolors.to_rgb`` rejects with ``ValueError`` falls back to mid-grey
    before blending.

    Returns an ``(r, g, b)`` tuple of floats.
    """
    try:
        rgb = mcolors.to_rgb(color)
    except ValueError:
        rgb = (0.5, 0.5, 0.5)
    return tuple((1 - amount) + amount * channel for channel in rgb)


# ============================================================
#               LOAD ALL FOUR DATA FILES
# ============================================================

# AD gene list: distinct symbols from the 'gene' column
ad_genes_df = pd.read_csv('pli_AD_genes.csv')
ad_genes_list = pd.unique(ad_genes_df['gene']).tolist()
print(f"\nLoaded {len(ad_genes_list)} AD genes from pli_AD_genes.csv")
print(f"First few genes: {ad_genes_list[:5]}")

# Per-gene feature flag tables
motif_df = pd.read_csv('motif_gene_flags.csv')  # SLiM, MoRF, PTM, NLS, etc.
lcs_df = pd.read_csv('LCS_gene_flags.csv')      # LCS flags
pfam_df = pd.read_csv('pfam_gene_data.csv')     # Pfam flags

# Restrict every table to AD genes (a safety net in case the inputs are
# already AD-only); .copy() detaches the result from the unfiltered frame.
motif_df = motif_df.loc[motif_df['hgnc_symbol'].isin(ad_genes_list)].copy()
lcs_df = lcs_df.loc[lcs_df['hgnc_symbol'].isin(ad_genes_list)].copy()
pfam_df = pfam_df.loc[pfam_df['hgnc_symbol'].isin(ad_genes_list)].copy()


# ============================================================
#         GROUP LABEL NORMALIZATION / ORDERING
# ============================================================

def rename_and_capitalize_groups(df):
    """Return a copy of ``df`` with a normalized ``group_label`` column.

    ``group_label`` is derived from ``group``: 'SNV' becomes 'Nonsense',
    'plus1' becomes 'Plus1' and 'minus1' becomes 'Minus1', with the same
    renaming applied to the matching '*_Control' values. Any other value,
    including labels that are already normalized, is carried over unchanged.
    The input frame is not modified.
    """
    raw_to_label = {'SNV': 'Nonsense', 'plus1': 'Plus1', 'minus1': 'Minus1'}

    # Expand each base pair into its case and control spellings.
    mapping = {}
    for raw, label in raw_to_label.items():
        mapping[raw] = label
        mapping[f'{raw}_Control'] = f'{label}_Control'

    labelled = df.copy()
    labelled['group_label'] = labelled['group'].replace(mapping)
    return labelled

def apply_custom_order(df):
    """Return a copy of ``df`` sorted into the fixed display order of groups.

    ``group_label`` is converted to an ordered categorical (Minus1, Plus1,
    Nonsense, each followed by its control), the rows are sorted by it and
    the index is renumbered from 0. The input frame is not modified.
    """
    display_order = [
        'Minus1', 'Minus1_Control',
        'Plus1', 'Plus1_Control',
        'Nonsense', 'Nonsense_Control',
    ]
    ordered_labels = pd.Categorical(
        df['group_label'], categories=display_order, ordered=True
    )
    return (
        df.assign(group_label=ordered_labels)
        .sort_values('group_label')
        .reset_index(drop=True)
    )

# Normalize the group labels of each table, then sort into display order.
motif_df = motif_df.pipe(rename_and_capitalize_groups).pipe(apply_custom_order)
lcs_df = lcs_df.pipe(rename_and_capitalize_groups).pipe(apply_custom_order)
pfam_df = pfam_df.pipe(rename_and_capitalize_groups).pipe(apply_custom_order)


# ============================================================
#         AGGREGATE DATA BY GROUP FOR EACH FEATURE
# ============================================================

def prepare_motif_data(df, motif_col):
    """
    Aggregate one feature flag by group, excluding NA values.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``group_label`` and ``motif_col``. Rows whose
        ``group_label`` is NA are ignored.
    motif_col : str
        Name of the flag column. Rows where the flag is NA are excluded
        from both counts; rows equal to ``True`` count as matched.

    Returns
    -------
    pandas.DataFrame
        Columns: group_label, matched, total, with one row per group in
        order of first appearance in ``df``. The three columns are present
        even when there are no groups, so callers can always index them.
    """
    columns = ['group_label', 'matched', 'total']
    summary_data = []

    for group in df['group_label'].dropna().unique():
        # Flags of this group with missing values removed.
        flags = df.loc[df['group_label'] == group, motif_col].dropna()
        summary_data.append({
            'group_label': group,
            'matched': int(flags.eq(True).sum()),
            'total': len(flags),
        })

    # Passing `columns` keeps the schema stable for an empty summary.
    return pd.DataFrame(summary_data, columns=columns)

# Feature name -> (source table, flag column).
# Pfam is read from pfam_df, LCS from lcs_df, the rest from motif_df.
motif_sources = dict(
    Pfam=(pfam_df, 'pfam_flag2'),
    SLiM=(motif_df, 'slim_flag2'),
    MoRF=(motif_df, 'morf_flag2'),
    PTM=(motif_df, 'ptm_flag2'),
    NLS=(motif_df, 'nls_flag2'),
    LCS=(lcs_df, 'LCS_flag2'),
)

# Processing/plotting order of the features (same as the mapping above).
motif_names = list(motif_sources)

# Per-feature summary table of matched/total counts per group.
motif_data = {}
for name in motif_names:
    df_src, col = motif_sources[name]
    motif_data[name] = prepare_motif_data(df_src, col)
    print(f"\n{name} motif summary:", motif_data[name], sep="\n")


# ============================================================
#                 BINOMIAL TEST STATS
# ============================================================

def calculate_p_values_binomial(df):
    """
    Two-sided binomial test comparing each experimental group
    to its corresponding Control group, with FDR correction.

    For every label ``X`` that has a matching ``X_Control`` row, the number
    of matched genes in ``X`` is tested against the proportion observed in
    ``X_Control``, which is passed to ``binomtest`` as the null probability.
    Pairs where either group has ``total == 0`` are skipped: there is no
    control proportion to test against, and ``binomtest`` rejects n < 1.

    The Benjamini-Hochberg adjustment is applied across the pairs of this
    one table only, not across tables from separate calls.

    Parameters
    ----------
    df : pandas.DataFrame
        Summary with columns ``group_label``, ``matched`` and ``total``.
        Only the first row is used for any given label.

    Returns
    -------
    pandas.DataFrame
        One row per tested pair with the columns listed in
        ``result_columns`` below; empty (but with those columns) when no
        pair could be tested.
    """
    result_columns = [
        'experimental_group', 'control_group',
        'exp_matched', 'exp_total',
        'ctrl_matched', 'ctrl_total',
        'control_proportion', 'p_value', 'significant',
        'p_value_fdr', 'significant_fdr',
    ]

    # Nothing to test; also covers a summary frame that has no columns at all.
    if df.empty:
        return pd.DataFrame(columns=result_columns)

    results = []
    base_groups = df['group_label'].str.replace('_Control', '', regex=False).unique()

    for base in base_groups:
        exp_mask = (df['group_label'] == base)
        ctrl_mask = (df['group_label'] == f'{base}_Control')
        if not exp_mask.any() or not ctrl_mask.any():
            continue

        exp = df.loc[exp_mask].iloc[0]
        ctrl = df.loc[ctrl_mask].iloc[0]

        exp_matched, exp_total = int(exp['matched']), int(exp['total'])
        ctrl_matched, ctrl_total = int(ctrl['matched']), int(ctrl['total'])

        # An empty control gives no null proportion; an empty experimental
        # group would make binomtest raise ValueError.
        if exp_total == 0 or ctrl_total == 0:
            continue

        p0 = ctrl_matched / ctrl_total
        res = binomtest(exp_matched, exp_total, p=p0, alternative='two-sided')

        results.append({
            'experimental_group': base,
            'control_group': f'{base}_Control',
            'exp_matched': exp_matched,
            'exp_total': exp_total,
            'ctrl_matched': ctrl_matched,
            'ctrl_total': ctrl_total,
            'control_proportion': p0,
            'p_value': res.pvalue,
            'significant': res.pvalue < 0.05
        })

    if not results:
        return pd.DataFrame(columns=result_columns)

    results_df = pd.DataFrame(results)
    results_df['p_value_fdr'] = multipletests(results_df['p_value'], method='fdr_bh')[1]
    results_df['significant_fdr'] = results_df['p_value_fdr'] < 0.05

    return results_df

# Per-feature table of case-vs-control test results.
motif_stats = {}
for name in motif_names:
    motif_stats[name] = calculate_p_values_binomial(motif_data[name])
    if motif_stats[name].empty:
        # No testable case/control pair for this feature: nothing to report.
        continue
    print(
        f"\n{name} statistics:",
        motif_stats[name][['experimental_group', 'p_value', 'p_value_fdr', 'significant_fdr']],
        sep="\n",
    )


# ============================================================
#                     PLOTTING HELPERS
# ============================================================

def get_colors_for_groups(group_labels):
    """Return one bar color per group label, in the order given.

    Labels containing 'Control' get a lightened version of their entry in
    ``CATEGORY_BASE_COLOR``; all other labels use their ``COLOR_SCHEME``
    entry directly.
    """
    return [
        lighten_color(CATEGORY_BASE_COLOR[label]) if 'Control' in label
        else COLOR_SCHEME[label]
        for label in group_labels
    ]

# Vertical offsets for the p-value annotations; both are multiplied by the
# tallest bar's percentage where they are used.
BRACKET_RAISE = 0.12  # extra height added when placing a bracket
PTEXT_RAISE   = 0.03  # height of the p-value text above its bracket


def bracket_top_for_pair(bar_heights, i, j, max_pct):
    """Return the y position of the bracket joining bars ``i`` and ``j``.

    The bracket sits above the taller of the two bars, raised by a label
    offset (0.022 * max_pct) plus ``BRACKET_RAISE * max_pct``.
    """
    label_clearance = 0.022 * max_pct
    taller_bar = max(bar_heights[i], bar_heights[j])
    return taller_bar + label_clearance + BRACKET_RAISE * max_pct


# Long-form feature names used in figure titles.
_MOTIF_FULL_NAMES = {
    'Pfam': 'Pfam Domain',
    'SLiM': 'Short Linear Motif',
    'MoRF': 'Molecular Recognition Feature',
    'PTM': 'Post-Translational Modification',
    'NLS': 'Nuclear Localization Signal',
    'LCS': 'Low Complexity Sequence',
}


def _format_p_value(p_val, significant):
    """Format a p-value for a bracket label, appending '*' when significant."""
    if p_val < 0.001:
        text = "p < 0.001"
    elif p_val < 0.01:
        text = f"p = {p_val:.3f}"
    else:
        text = f"p = {p_val:.2f}"
    if significant:
        text += "*"
    return text


def create_visualization(df, motif_name, stats_df=None, save_prefix="", use_fdr=True):
    """Draw a single-panel bar chart of the percentage of genes with a feature.

    Each group gets one bar labelled with its percentage and its gene count
    (``n``). When ``stats_df`` is given, a bracket with the p-value is drawn
    over every experimental/control pair that is present in ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Summary with ``group_label``, ``matched`` and ``total``. If a
        ``percent`` column (a fraction, not a percentage) is present it is
        used instead of ``matched / total``.
    motif_name : str
        Short feature name; used in the y-axis label, the title lookup and
        the output file name.
    stats_df : pandas.DataFrame, optional
        Test results with ``experimental_group``, ``control_group`` and the
        p-value/significance columns.
    save_prefix : str
        If non-empty, the figure is written to
        ``{save_prefix}_{motif_name.lower()}_singlepanel.png``.
    use_fdr : bool
        Label brackets with ``p_value_fdr``/``significant_fdr`` when those
        columns exist; otherwise, or when False, use the raw columns.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # --- Prepare data ---
    if 'percent' not in df.columns:
        df = df.copy()
        # A group with total == 0 has no defined fraction; plot it as 0%
        # instead of letting NaN/inf reach the axis limits.
        usable_total = df['total'].where(df['total'] > 0)
        df['percent'] = (df['matched'] / usable_total).fillna(0.0)

    groups = df['group_label'].tolist()
    # Also guards a caller-supplied 'percent' column holding NaN/inf.
    percentages = [
        pct if np.isfinite(pct) else 0.0
        for pct in (df['percent'] * 100).tolist()
    ]
    total_counts = df['total'].tolist()
    colors = get_colors_for_groups(groups)
    x = np.arange(len(groups))
    bar_w = 0.6

    # --- Single-panel figure ---
    fig, ax1 = plt.subplots(1, 1, figsize=(3.4, 3.0))

    full_name = _MOTIF_FULL_NAMES.get(motif_name, motif_name)
    y_lab = f'Percentage with {motif_name} (%)'
    title_pct = f'{full_name} Enrichment'

    # --- Percentage bars ---
    bars1 = ax1.bar(
        x, percentages, width=bar_w, color=colors, alpha=0.85,
        edgecolor='black', linewidth=0.3
    )

    ax1.set_xlabel('Gene Category', fontsize=8, fontweight='bold')
    ax1.set_ylabel(y_lab, fontsize=8, fontweight='bold')
    ax1.set_title(title_pct, fontsize=9, fontweight='bold', pad=10)

    ax1.set_xticks(x)
    ax1.set_xticklabels([g.replace('_', '\n') for g in groups], fontsize=7, ha='center')

    # Every offset below scales with ymax_pct, so it must be positive: fall
    # back to 1.0 when there are no bars or all bars are at 0%.
    ymax_pct = max(percentages, default=0.0)
    if ymax_pct <= 0:
        ymax_pct = 1.0
    ax1.set_ylim(0, ymax_pct * 1.35)

    # --- % labels + n labels ---
    for bar, pct, tot in zip(bars1, percentages, total_counts):
        h = bar.get_height()
        x_mid = bar.get_x() + bar.get_width() / 2.
        ax1.text(x_mid, h + 0.022 * ymax_pct,
                 f'{pct:.1f}%', ha='center', va='bottom', fontsize=6, fontweight='bold')
        ax1.text(x_mid, h / 2,
                 f'n={tot}', ha='center', va='center', fontsize=6, fontweight='bold',
                 color='black', bbox=dict(boxstyle='round,pad=0.15',
                 facecolor='white', alpha=0.8, edgecolor='none'))

    # --- p-value brackets ---
    if stats_df is not None and not stats_df.empty:
        p_col = 'p_value_fdr' if use_fdr and 'p_value_fdr' in stats_df.columns else 'p_value'
        sig_col = 'significant_fdr' if use_fdr and 'significant_fdr' in stats_df.columns else 'significant'

        local_top = ymax_pct
        for _, row in stats_df.iterrows():
            eg, cg = row['experimental_group'], row['control_group']
            if eg not in groups or cg not in groups:
                continue
            ei, ci = groups.index(eg), groups.index(cg)
            by = bracket_top_for_pair(percentages, ei, ci, ymax_pct)

            # Bracket: short downward ticks at both ends of a horizontal bar.
            ax1.plot([ei, ei, ci, ci],
                     [by - 0.01 * ymax_pct, by, by, by - 0.01 * ymax_pct],
                     color='black', linewidth=0.6)

            text_y = by + PTEXT_RAISE * ymax_pct
            ax1.text((ei + ci) / 2, text_y,
                     _format_p_value(row[p_col], row[sig_col]),
                     ha='center', va='bottom', fontsize=6, fontweight='bold')
            local_top = max(local_top, text_y)

        # Grow the axis if an annotation ended up above the default headroom.
        ax1.set_ylim(0, max(local_top * 1.05, ymax_pct * 1.35))

    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.tick_params(length=2, width=0.5)

    fig.tight_layout()

    if save_prefix:
        fname = f"{save_prefix}_{motif_name.lower()}_singlepanel.png"
        # Save this figure explicitly rather than whichever one is current.
        fig.savefig(
            fname, dpi=300, bbox_inches='tight', pad_inches=0.1,
            facecolor='white', edgecolor='none'
        )

    return fig


# ============================================================
#                BUILD FIGURES FOR ALL MOTIFS
# ============================================================

# One figure per feature, in motif_names order.
for name in motif_names:
    print(f"\nGenerating visualization for {name}...")
    fig = create_visualization(
        df=motif_data[name],
        motif_name=name,
        stats_df=motif_stats[name],
        save_prefix="AD_NMDesc_reg",  # use "" to skip writing PNG files
        use_fdr=False,                # set True to label with FDR-adjusted p-values
    )
    plt.show()
