# AD Gene Feature Enrichment Downstream of Early PTC Variants

This notebook tests whether specific protein features are enriched downstream of
variants that introduce an early premature termination codon (PTC). These variants include:

- **Minus1** (frameshift, –1)
- **Plus1** (frameshift, +1)
- **Nonsense** (stopgain)

We focus specifically on **autosomal dominant (AD) genes**, and for each feature
(Pfam domains, SLiMs, MoRFs, PTMs, NLS signals, and LCS regions) we:

1. Load per-variant binary feature flags  
2. Subset to AD genes only  
3. Aggregate, per variant category, whether each feature appears downstream of the PTC-causing variant  
4. Compare each case group with its matched control group using a two-sided binomial test  
5. Generate clean, publication-style bar plots showing:  
   - % of variants with the feature  
   - sample sizes per category  
   - statistical significance (case vs. control)  

This notebook serves as a compact workflow for generating the motif enrichment
panels used for downstream figure preparation.
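
To make step 4 of the workflow concrete, the sketch below computes a two-sided exact binomial p-value from scratch, using the control group's proportion as the null rate. It is an illustration only: the notebook itself calls `scipy.stats.binomtest`, and the function names and counts here (`binom_pmf`, `two_sided_binom_p`, the toy counts) are hypothetical.

```python
from math import comb

def binom_pmf(k, n, p):
    """Probability of exactly k successes in n trials at success rate p."""
    return comb(n, k) * p**k * (1 - p)**(n - k)

def two_sided_binom_p(k, n, p0):
    """Sum the probabilities of all outcomes no more likely than k under p0."""
    observed = binom_pmf(k, n, p0)
    # Small relative tolerance so tied outcomes are not lost to rounding error
    cutoff = observed * (1 + 1e-7)
    total = sum(binom_pmf(i, n, p0) for i in range(n + 1)
                if binom_pmf(i, n, p0) <= cutoff)
    return min(1.0, total)

# Hypothetical counts, for illustration only
ctrl_matched, ctrl_total = 50, 100   # control: 50% carry the feature
case_matched, case_total = 0, 10     # case: none carry the feature

p0 = ctrl_matched / ctrl_total       # control proportion used as the null rate
p_value = two_sided_binom_p(case_matched, case_total, p0)
print(f"null rate = {p0:.2f}, p = {p_value:.4f}")
```

With these toy counts only the two extreme outcomes (0 or 10 matches) are as unlikely as the observed one, so the p-value is 2/1024. Note that this design treats the control rate as fixed, so sampling uncertainty in the control group is not reflected in the p-value.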


## Import required libraries

We begin by importing the core scientific Python packages used throughout the
analysis. These include:

- **pandas / numpy** for data handling and array operations  
- **matplotlib** for generating publication-style plots  
- **scipy** for performing binomial tests  
- **statsmodels** for multiple-testing correction (FDR)

These imports support the statistical comparisons and visualizations generated in
the notebook.


In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import pandas as pd
import numpy as np
from scipy.stats import binomtest
from statsmodels.stats.multitest import multipletests

## Load and prepare AD gene feature data

We load the per-variant motif data and LCS data, remove duplicate variants, and
subset both tables to **autosomal dominant (AD) genes** using the AD gene list
from `pli_AD_genes.csv`.  

The notebook then standardizes the variant category labels (Minus1, Plus1,
Nonsense, and their Control groups) and applies a consistent, ordered grouping
so that all downstream analyses and plots use the same category sequence.
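
As a minimal sketch of the relabeling and ordering step, the example below applies the same label mapping to a toy table and converts `group` to an ordered categorical. The `Variant_Key` values and rows are made up for illustration.

```python
import pandas as pd

ORDER = ['Minus1', 'Minus1_Control', 'Plus1', 'Plus1_Control',
         'Nonsense', 'Nonsense_Control']

# Toy rows with made-up keys; raw labels follow the mapping used in the notebook
toy = pd.DataFrame({
    'Variant_Key': ['v1', 'v2', 'v3', 'v4'],
    'group': ['SNV', 'plus1_Control', 'minus1', 'SNV_Control'],
})

toy['group'] = toy['group'].replace({
    'SNV': 'Nonsense', 'SNV_Control': 'Nonsense_Control',
    'plus1': 'Plus1', 'plus1_Control': 'Plus1_Control',
    'minus1': 'Minus1', 'minus1_Control': 'Minus1_Control',
})

# An ordered categorical sorts by position in ORDER, not alphabetically
toy['group'] = pd.Categorical(toy['group'], categories=ORDER, ordered=True)
toy = toy.sort_values('group').reset_index(drop=True)

print(toy['group'].tolist())
```

One consequence worth keeping in mind: any label that is not listed in `ORDER` becomes `NaN` after the conversion, so an unexpected raw label would silently drop out of the grouped summaries.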


In [None]:
# ---------- Load and prepare data ----------
# Main motif data
df = pd.read_csv('motif_max_AD.csv')
df = df.drop_duplicates(subset='Variant_Key').copy()

# Subset to the autosomal dominant (AD) gene list
ad_genes_list = pd.read_csv('pli_AD_genes.csv')['gene'].unique().tolist()
#print(f"\nLoaded {len(ad_genes_list)} AD genes from pli_AD_genes.csv")
#print(f"\nFirst few genes from ad_genes_list: {ad_genes_list[:5]}")     
df = df[df['hgnc_symbol'].isin(ad_genes_list)].copy()
# LCS data from separate CSV
lcs_df = pd.read_csv('LCS_max_AD.csv')
lcs_df = lcs_df.drop_duplicates(subset='Variant_Key').copy()
lcs_df = lcs_df[lcs_df['hgnc_symbol'].isin(ad_genes_list)].copy()
#print(lcs_df.columns)

ORDER = ['Minus1','Minus1_Control','Plus1','Plus1_Control','Nonsense','Nonsense_Control']

def rename_and_capitalize_groups(df):
    df = df.copy()
    df['group'] = df['group'].replace({
        'SNV':'Nonsense','SNV_Control':'Nonsense_Control',
        'plus1':'Plus1','plus1_Control':'Plus1_Control',
        'minus1':'Minus1','minus1_Control':'Minus1_Control'
    })
    df['group'] = pd.Categorical(df['group'], categories=ORDER, ordered=True)
    return df.sort_values('group').reset_index(drop=True)

def apply_custom_order(df):
    # Re-apply the shared ORDER so both tables use an identical category sequence
    df = df.copy()
    df['group'] = pd.Categorical(df['group'], categories=ORDER, ordered=True)
    return df.sort_values('group').reset_index(drop=True)

# Apply to both dataframes
df = apply_custom_order(rename_and_capitalize_groups(df))
lcs_df = apply_custom_order(rename_and_capitalize_groups(lcs_df))

## Plotting style, color scheme, and helper functions

This section sets up the plotting environment used for all figures.  
We define:

- **Matplotlib style settings** for consistent, publication-ready formatting  
  (font sizes, gridlines, DPI, axis appearance).
- **Color mappings** for each variant category (Minus1, Plus1, Nonsense and their
  Control counterparts), along with a lightweight function for generating softened
  colors when needed.
- **Helper functions** used throughout the notebook:
  - `get_colors_for_groups()` to map categories to colors  
  - `prepare_motif_data()` to aggregate TRUE/FALSE feature flags within each group  
  - `calculate_p_values_binomial()` to compute case–control p-values with
    FDR correction  
  - small utilities for significance bracket placement in the bar plots  

These components provide the standardized visual and statistical framework used
for all downstream motif enrichment plots.
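
The NA handling in `prepare_motif_data()` determines the denominators used later, so a small sketch may help. It uses a pandas `groupby` rather than the notebook's explicit loop, and the toy flag values are hypothetical.

```python
import pandas as pd

# Toy flags (hypothetical); None marks a variant with no usable annotation
toy = pd.DataFrame({
    'group': ['Plus1', 'Plus1', 'Plus1', 'Plus1_Control', 'Plus1_Control'],
    'slim_flag': [True, False, None, True, None],
})

# Keep only informative rows, so NA counts toward neither numerator nor denominator
informative = toy.dropna(subset=['slim_flag']).copy()
informative['slim_flag'] = informative['slim_flag'].astype(bool)

summary = (
    informative.groupby('group')['slim_flag']
    .agg(matched='sum', total='count')
    .reset_index()
)
summary['percent'] = 100 * summary['matched'] / summary['total']
print(summary)
```

Here `Plus1` has three variants but a `total` of 2, because the variant with a missing flag is excluded. Its percentage is therefore 50%, not 33%.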


In [None]:
# ---------------- Style ----------------
plt.rcParams.update({
    'font.family': 'Arial', 'font.size': 8,
    'axes.linewidth': 0.5, 'axes.labelsize': 8, 'axes.titlesize': 9,
    'xtick.labelsize': 7, 'ytick.labelsize': 7,
    'axes.spines.top': False, 'axes.spines.right': False,
    'axes.grid': True, 'grid.linewidth': 0.3, 'grid.alpha': 0.3,
    'figure.dpi': 300, 'savefig.dpi': 300, 'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.1
})

# ---------------- Colors ----------------
COLOR_SCHEME = {
    'Nonsense': '#70AD47', 'Nonsense_Control': '#C5E0B4',
    'Plus1': '#C65911', 'Plus1_Control': '#F4B183',
    'Minus1': '#4472C4', 'Minus1_Control': '#B4C7E7',
}

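def lighten_color(color, amount=0.35):
    """Blend a color toward white; a smaller `amount` gives a lighter result."""
    try:
        r, g, b = mcolors.to_rgb(color)
    except ValueError:
        r, g, b = (0.5, 0.5, 0.5)  # fall back to mid-gray for unparseable colors
    return (1-amount)+amount*r, (1-amount)+amount*g, (1-amount)+amount*b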

# ---------------- Helpers ----------------
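def get_colors_for_groups(labels):
    """Map each group label to its bar color."""
    out = []
    for lab in labels:
        if 'Control' in lab:
            # Controls use a softened version of the matching case color
            base = lab.replace('_Control', '')
            out.append(lighten_color(COLOR_SCHEME[base]))
        else:
            out.append(COLOR_SCHEME[lab])
    return out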

def prepare_motif_data(df, motif_col):
    """Aggregate motif data by group, excluding NA values"""
    summary_data = []
    
    for group in df['group'].unique():
        group_df = df[df['group'] == group]
        
        # Exclude NA values - only count TRUE and FALSE
        non_na_df = group_df[group_df[motif_col].notna()]
        
        # Count TRUE values (matched)
        matched = (non_na_df[motif_col] == True).sum()
        
        # Total is count of non-NA values (TRUE + FALSE)
        total = len(non_na_df)
        
        summary_data.append({
            'group': group,
            'matched': matched,
            'total': total
        })
    
    summary = pd.DataFrame(summary_data)
    return summary

# ---------- Stats with Binomial Test ----------
def calculate_p_values_binomial(df):
    """Calculate p-values using binomial test with FDR correction"""
    results = []
    base_groups = df['group'].str.replace('_Control','',regex=False).unique()
    
    for base in base_groups:
        if base not in df['group'].values or f'{base}_Control' not in df['group'].values:
            continue
            
        exp  = df.loc[df['group']==base].iloc[0]
        ctrl = df.loc[df['group']==f'{base}_Control'].iloc[0]
        
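        # Skip if either group has no informative variants: the control rate
        # would be undefined, and binomtest needs at least one trial
        if ctrl['total'] == 0 or exp['total'] == 0:
            continue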
        
        # Control proportion as expected rate
        p0 = ctrl['matched'] / ctrl['total']
        
        # Binomial test
        res = binomtest(int(exp['matched']), int(exp['total']),
                       p=p0, alternative='two-sided')
        
        results.append({
            'experimental_group': base,
            'control_group': f'{base}_Control',
            'exp_matched': int(exp['matched']),
            'exp_total': int(exp['total']),
            'ctrl_matched': int(ctrl['matched']),
            'ctrl_total': int(ctrl['total']),
            'control_proportion': p0,
            'p_value': res.pvalue,
            'significant': res.pvalue < 0.05
        })
    
    results_df = pd.DataFrame(results)
    
    # Add FDR correction if we have results
    if not results_df.empty:
        results_df['p_value_fdr'] = multipletests(results_df['p_value'], method='fdr_bh')[1]
        results_df['significant_fdr'] = results_df['p_value_fdr'] < 0.05
    
    return results_df

# ---------- Bracket height control ----------
BRACKET_RAISE = 0.12
PTEXT_RAISE   = 0.03

def bracket_top_for_pair(bar_heights, i, j, max_pct):
    top_pair  = max(bar_heights[i], bar_heights[j])
    lbl_off   = 0.022 * max_pct
    return top_pair + lbl_off + BRACKET_RAISE * max_pct

## Create motif enrichment figure (percentages and counts)

The `create_visualization` function generates a two-panel figure for a single
motif or feature (e.g., SLiM, MoRF, PTM, NLS, LCS, Pfam domains):

- **Panel A (left):**  
  - Bar plot of the **percentage** of variants whose downstream region contains
    the feature, across all variant categories.  
  - Bars are colored by variant category, labeled with both the percentage and
    sample size (`n`).  
  - Case vs. control pairs are annotated with binomial test p-values, using
    FDR-adjusted values if requested.

- **Panel B (right):**  
  - Stacked bar plot of **raw counts**: matched variants (feature present) vs.
    non-matched variants (feature absent) for each category.  
  - Bars are labeled with the number of matched variants and total variants.

The function also derives titles and axis labels from the motif name, scales the
y-axis limits to the tallest bar (leaving headroom for labels and brackets), and,
when `save_prefix` is non-empty, saves the figure as a 300 dpi PNG.
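
The p-value annotation on each bracket follows a three-tier rule. The sketch below isolates that rule in a standalone helper; `format_p_label` is a hypothetical name used only for illustration, since the notebook applies the same logic inline.

```python
def format_p_label(p_val, significant=False):
    """Build the bracket label: three tiers of precision plus an optional asterisk."""
    if p_val < 0.001:
        text = "p < 0.001"
    elif p_val < 0.01:
        text = f"p = {p_val:.3f}"   # three decimals for small p-values
    else:
        text = f"p = {p_val:.2f}"   # two decimals otherwise
    return text + "*" if significant else text

# Hypothetical p-values, flagged as significant when below 0.05
for p in (0.0004, 0.0042, 0.046, 0.2):
    print(format_p_label(p, significant=p < 0.05))
```

Because values of 0.01 and above are rounded to two decimals, a p-value such as 0.046 is displayed as `p = 0.05*`. The asterisk, not the rounded number, indicates which side of the threshold it falls on.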


In [None]:
# ---------- Figure creation ----------
def create_visualization(df, motif_name, stats_df=None, save_prefix="", use_fdr=True):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6.69, 3.0))

    if 'percent' not in df.columns:
        df = df.copy()
        df['percent'] = df['matched'] / df['total']

    groups         = df['group'].tolist()
    percentages    = (df['percent'] * 100).tolist()
    matched_counts = df['matched'].tolist()
    total_counts   = df['total'].tolist()

    colors = get_colors_for_groups(groups)
    x      = np.arange(len(groups))
    bar_w  = 0.6

    # --- Dynamic labels with full names in titles ---
    motif_full_names = {
        'SLiM': 'Short Linear Motif',
        'MoRF': 'Molecular Recognition Feature',
        'PTM': 'Post-Translational Modification',
        'NLS': 'Nuclear Localization Signal',
        'LCS': 'Low Complexity Sequence',
        'Domains': 'Pfam Domain'
    }
    
    full_name = motif_full_names.get(motif_name, motif_name)
    y_lab     = f'Percentage with {motif_name} (%)'  # acronym keeps the y-axis label short
    title_pct = f'{full_name} Enrichment'
    title_cnt = f'{full_name} Counts'

    # ---- Panel A: Percentages + p-value brackets ----
    bars1 = ax1.bar(x, percentages, width=bar_w, color=colors, alpha=0.8,
                    edgecolor='black', linewidth=0.3)
    ax1.set_xlabel('Variant Category', fontsize=8, fontweight='bold')
    ax1.set_ylabel(y_lab, fontsize=8, fontweight='bold')
    ax1.set_title(title_pct, fontsize=9, fontweight='bold', pad=10)
    ax1.set_xticks(x)
    ax1.set_xticklabels([g.replace('_','\n') for g in groups], fontsize=7, ha='center')

    ymax_pct = max(percentages) if percentages else 1.0
    ax1.set_ylim(0, ymax_pct * 1.35)

    # % and n labels
    for bar, pct, tot in zip(bars1, percentages, total_counts):
        h = bar.get_height()
        ax1.text(bar.get_x()+bar.get_width()/2., h + 0.022*ymax_pct, f'{pct:.1f}%',
                 ha='center', va='bottom', fontsize=6, fontweight='bold')
        ax1.text(bar.get_x()+bar.get_width()/2., (h/2 if h>0 else 0.2), f'n={tot}',
                 ha='center', va='center', fontsize=6, fontweight='bold',
                 color='black', bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.8, edgecolor='none'))

    # p-value brackets (using FDR-corrected p-values if requested)
    if stats_df is not None and not stats_df.empty and len(groups) > 1:
        local_top = ymax_pct
        
        # Choose which p-value column to use
        p_col = 'p_value_fdr' if use_fdr and 'p_value_fdr' in stats_df.columns else 'p_value'
        sig_col = 'significant_fdr' if use_fdr and 'significant_fdr' in stats_df.columns else 'significant'
        
        for _, row in stats_df.iterrows():
            eg, cg = row['experimental_group'], row['control_group']
            if eg in groups and cg in groups:
                ei, ci = groups.index(eg), groups.index(cg)
                by = bracket_top_for_pair(percentages, ei, ci, ymax_pct)
                ax1.plot([ei, ei, ci, ci],
                         [by - 0.01*ymax_pct, by, by, by - 0.01*ymax_pct],
                         color='black', linewidth=0.6)
                
                p_val = row[p_col]
                ptxt = ("p < 0.001" if p_val < 0.001
                        else (f"p = {p_val:.3f}" if p_val < 0.01
                              else f"p = {p_val:.2f}"))
                if row[sig_col]: ptxt += "*"
                
                ax1.text((ei+ci)/2, by + PTEXT_RAISE*ymax_pct, ptxt,
                         ha='center', va='bottom', fontsize=6, fontweight='bold')
                local_top = max(local_top, by + PTEXT_RAISE*ymax_pct)
        ax1.set_ylim(0, max(local_top*1.05, ymax_pct*1.35))

    # ---- Panel B: Raw counts ----
    ax2.bar(x, matched_counts, width=bar_w, color=colors, alpha=0.8,
            edgecolor='black', linewidth=0.3)
    ax2.bar(x, [t-m for t, m in zip(total_counts, matched_counts)],
            bottom=matched_counts, width=bar_w, color='#E5E5E5', alpha=0.8,
            edgecolor='black', linewidth=0.3)
    ax2.set_xlabel('Variant Category', fontsize=8, fontweight='bold')
    ax2.set_ylabel('Count', fontsize=8, fontweight='bold')
    ax2.set_title(title_cnt, fontsize=9, fontweight='bold', pad=10)
    ax2.set_xticks(x)
    ax2.set_xticklabels([g.replace('_','\n') for g in groups], fontsize=7, ha='center')

    ymax_ct = max(total_counts) if total_counts else 1
    ax2.set_ylim(0, ymax_ct * 1.10)

    # Count labels
    for i, (m, t) in enumerate(zip(matched_counts, total_counts)):
        if m > t * 0.1:
            ax2.text(i, m/2, f'{m}', ha='center', va='center',
                     fontsize=6, fontweight='bold', color='white')
        ax2.text(i, t + 0.012*ymax_ct, f'{t}', ha='center', va='bottom',
                 fontsize=6, fontweight='bold')

    for ax in (ax1, ax2):
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.tick_params(length=2, width=0.5)

    plt.tight_layout()

    if save_prefix:
        fname = f"{save_prefix}_{motif_name.lower()}_motif_enrichment.png"
        plt.savefig(fname, dpi=300, bbox_inches='tight', pad_inches=0.1,
                    facecolor='white', edgecolor='none')
    return fig

## Aggregate motif data, run statistics, and generate figures

In the final step, we loop over each motif/feature and:

1. **Aggregate data by variant category**  
   - For each flag column (Pfam domains, SLiMs, MoRFs, PTMs, NLS) in the main
     dataframe, and the LCS flag from the separate LCS table, we build a
     per-group summary of:
       - number of variants with the feature (`matched`)
       - total number of informative variants, i.e. those with a non-missing flag (`total`).

2. **Run binomial tests for case vs. control**  
   - For each motif, we call `calculate_p_values_binomial()` to compare every
     case category (Minus1, Plus1, Nonsense) to its matched control group and
     store the raw p-values alongside their FDR-adjusted counterparts.
   - Because the function is called once per motif, the FDR correction spans
     only the three comparisons within that motif, not all tests in the notebook.

3. **Create and display plots**  
   - We pass the aggregated data and statistics into `create_visualization()` to
     generate and display the paired percentage + count plots for each motif.
   - The call below passes `use_fdr=False`, so the brackets show uncorrected
     p-values and the asterisk marks uncorrected p < 0.05.
   - Figures are also saved to disk with the prefix
     `"AD_variant_downstream"` for downstream use in the manuscript figures.
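
The FDR adjustment in step 2 uses `method='fdr_bh'`, i.e. the Benjamini-Hochberg procedure. The sketch below reimplements that adjustment in plain Python to show what happens to the raw p-values of one feature. The function name and the p-values are hypothetical, and the notebook itself relies on `statsmodels.stats.multitest.multipletests`.

```python
def benjamini_hochberg(p_values):
    """Return Benjamini-Hochberg adjusted p-values in the original order."""
    m = len(p_values)
    # Indices of the p-values from smallest to largest
    order = sorted(range(m), key=lambda i: p_values[i])
    adjusted = [0.0] * m
    running_min = 1.0
    # Walk from the largest p-value down so adjusted values stay monotone
    for rank in range(m, 0, -1):
        idx = order[rank - 1]
        running_min = min(running_min, p_values[idx] * m / rank)
        adjusted[idx] = running_min
    return adjusted

# Hypothetical raw p-values for the three case/control pairs of one feature
raw = {'Minus1': 0.01, 'Plus1': 0.04, 'Nonsense': 0.03}
adj = benjamini_hochberg(list(raw.values()))
for name, p_raw, p_adj in zip(raw, raw.values(), adj):
    print(f"{name}: raw={p_raw:.3f}, BH={p_adj:.3f}")
```

In this toy case the adjusted values are 0.03, 0.04, and 0.04, so all three comparisons stay below 0.05. With smaller margins, a pair that is significant on its raw p-value can lose significance after adjustment, which is why the choice of `use_fdr` affects which brackets carry an asterisk.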


In [None]:
# ---------- Prepare data for each motif ----------
motif_columns_downstream = ['domains_flag', 'slim_flag', 'morf_flag', 'ptm_flag', 'nls_flag']
motif_names = ['Domains', 'SLiM', 'MoRF', 'PTM', 'NLS']

# Create dataframes for main motifs
motif_data = {}
for col, name in zip(motif_columns_downstream, motif_names):
    motif_data[name] = prepare_motif_data(df, col)
    # print(f"\n{name} motif summary:")
    # print(motif_data[name])

# Add LCS from separate file
motif_data['LCS'] = prepare_motif_data(lcs_df, 'LCS_flag')
# print(f"\nLCS motif summary:")
# print(motif_data['LCS'])

# ---------- Calculate statistics for all motifs ----------
motif_stats = {}
for name in motif_data.keys():
    motif_stats[name] = calculate_p_values_binomial(motif_data[name])
    # if not motif_stats[name].empty:
    #     print(f"\n{name} statistics:")
    #     print(motif_stats[name][['experimental_group', 'p_value', 'p_value_fdr', 'significant', 'significant_fdr']])

# ---------- Build figures for each motif ----------
for name in motif_data.keys():
    # print(f"\nGenerating visualization for {name}...")
    fig = create_visualization(
        motif_data[name],
        name,
        motif_stats[name],
        save_prefix="AD_variant_downstream",
        use_fdr=False
    )
    plt.show()

## Summary

This notebook loads AD gene variant data, aggregates downstream motif features
across variant categories, performs case–control binomial tests, and generates
publication-ready enrichment plots for all features.
