# 🧬 Paralog Detection and Summary (Ensembl API)

This notebook retrieves **human gene paralogs** using the [Ensembl REST API](https://rest.ensembl.org/), caches results locally,  
and generates both **gene-level** and **variant-type-level** summaries of paralog counts and sequence identity.  

The results are used in the paper to quantify how often genes in different variant categories (e.g., *Minus1*, *Plus1*, *Nonsense*, and their respective controls) have annotated paralogs.

---

## 📘 Overview of this cell

**Purpose:**  
Fetch paralog information from Ensembl for each gene in a variant-annotated dataset (`pli_AD_genes.csv`),  
store the results in a local cache, and compute per-variant-type statistics.

**Main steps:**
1. **Define helper functions**
   - `get_ensembl_paralogs()`: calls Ensembl's `/homology/symbol/homo_sapiens/{gene}?type=paralogues` endpoint  
     and extracts each paralog's Ensembl gene ID, `dn_ds`, percent identity (`perc_id`), and homology `type`.
   - `save_paralog_cache()` / `load_paralog_cache()`: manage a local `paralog_cache.pkl` file to avoid re-fetching existing results.
   - `count_paralogs_by_variant_type()`: orchestrates the workflow. It:
     - Loads or creates the paralog cache  
     - Queries Ensembl for any uncached genes  
     - Counts the number of paralogs per gene  
     - Aggregates and summarizes counts and sequence identity by variant category  
     - Outputs two CSVs:
       - `paralog_summary_by_category.csv`  
       - `paralog_detailed_by_gene.csv`

2. **Pre-process input gene list**
   - Loads `pli_AD_genes.csv`
   - Normalizes variant categories (renames `SNV` → `Nonsense` and `SNV_Control` → `Nonsense_Control`)

3. **Run analysis**
   - Optionally deletes any old cache to force fresh API retrieval  
   - Calls `count_paralogs_by_variant_type()`  
   - Saves and prints summary tables and top genes with the most paralogs.
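
To make the extraction step concrete, here is a minimal sketch of how a homology payload could be flattened into the four fields listed above. The payload is hand-written and hypothetical (made-up IDs and values), not a real Ensembl response, and `extract_paralogs` is an illustrative helper name. Where the identity value sits in the response (on the `target` member here, with a top-level fallback) is an assumption.

```python
def extract_paralogs(payload):
    """Flatten a homology payload into one record per paralog."""
    records = []
    entries = payload.get("data", [])
    if not entries:
        return records
    for homology in entries[0].get("homologies", []):
        target = homology.get("target", {})
        if not target.get("id"):
            continue  # skip records without a target gene ID
        records.append({
            "paralog_gene": target["id"],
            "dn_ds": homology.get("dn_ds"),
            # Assumption: identity sits on the target member; else fall back to top level
            "perc_id": target.get("perc_id", homology.get("perc_id")),
            "type": homology.get("type", ""),
        })
    return records


# Hypothetical payload (made-up IDs and values)
example_payload = {
    "data": [{
        "id": "ENSG_QUERY",
        "homologies": [
            {"type": "within_species_paralog", "dn_ds": None,
             "target": {"id": "ENSG_PARALOG_1", "perc_id": 62.5}},
            {"type": "other_paralog", "dn_ds": None,
             "target": {"id": "ENSG_PARALOG_2", "perc_id": 31.0}},
            {"type": "other_paralog", "target": {}},  # no ID: ignored
        ],
    }]
}

paralogs = extract_paralogs(example_payload)
print(len(paralogs), [p["paralog_gene"] for p in paralogs])
```

An empty or missing `data` list yields an empty result, which is how a gene without annotated paralogs would appear.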

---

### 💾 Outputs

| File | Description |
|------|--------------|
| `paralog_cache.pkl` | Serialized dictionary `{gene → list of paralog dicts}` for fast reuse |
| `paralog_summary_by_category.csv` | Per-category summary of paralog frequencies and averages |
| `paralog_detailed_by_gene.csv` | Per-gene paralog counts and average sequence identity |
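
As a rough illustration of the cache layout in the table, the sketch below builds a tiny in-memory dictionary of the same shape, round-trips it through `pickle`, and derives the per-gene columns (`num_paralogs`, `avg_sequence_identity`, `has_paralogs`). Gene names and identity values are invented, and `summarize_gene` is a hypothetical helper. The `0` used when no identity value is available mirrors the notebook's convention.

```python
import pickle


def summarize_gene(gene, paralogs):
    """Derive the per-gene columns from one cache entry."""
    identities = [p["perc_id"] for p in paralogs if p.get("perc_id") is not None]
    avg_identity = sum(identities) / len(identities) if identities else 0
    return {
        "gene": gene,
        "num_paralogs": len(paralogs),
        "avg_sequence_identity": avg_identity,
        "has_paralogs": len(paralogs) > 0,
    }


# Toy cache with invented entries: {gene -> list of paralog dicts}
toy_cache = {
    "GENE_A": [
        {"paralog_gene": "ENSG_X", "dn_ds": None, "perc_id": 40.0, "type": "other_paralog"},
        {"paralog_gene": "ENSG_Y", "dn_ds": None, "perc_id": 60.0, "type": "other_paralog"},
    ],
    "GENE_B": [],  # queried, but no paralogs recorded
}

# Serialize and restore in memory, as the .pkl file would on disk
restored = pickle.loads(pickle.dumps(toy_cache))
rows = [summarize_gene(gene, entries) for gene, entries in restored.items()]
for row in rows:
    print(row)
```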

### 🧠 Notes

- Requests are throttled client-side with a 0.2 s pause between calls (at most ≈5 requests/s).
- The cache is saved every 500 genes during fetching to limit data loss if the run is interrupted.
- Running this on the full gene list may take several minutes depending on connectivity.
- Figures in the published manuscript were generated using a cached paralog
dataset at the time of analysis. Because Ensembl REST API results can change over time
(or fail transiently), re-running this notebook may yield slightly different numerical values.
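
One way to make transient failures visible, sketched here under stated assumptions, is to keep failed genes out of the cache so that a later run retries them rather than recording them as having no paralogs. This is a suggested variation, not a description of the notebook's code: `fetch_with_checkpoints`, `fetch`, and `save` are hypothetical names, and the fake fetcher stands in for the real API call.

```python
import time


def fetch_with_checkpoints(genes, fetch, save, cache=None, pause=0.2, every=500):
    """Fetch uncached genes with a pause between calls and periodic saves."""
    cache = {} if cache is None else cache
    failed = []
    for i, gene in enumerate(genes, start=1):
        if gene in cache:
            continue
        try:
            cache[gene] = fetch(gene)
        except Exception:
            failed.append(gene)  # left uncached so a later run can retry it
        time.sleep(pause)  # client-side throttle
        if i % every == 0:
            save(cache)  # intermediate checkpoint
    save(cache)  # final save
    return cache, failed


def fake_fetch(gene):
    """Stand-in for the API call; fails for one gene to simulate an outage."""
    if gene == "GENE_B":
        raise ConnectionError("simulated transient failure")
    return [{"paralog_gene": f"{gene}_paralog", "perc_id": 50.0}]


saves = []  # records the cache size at each save
cache, failed = fetch_with_checkpoints(
    ["GENE_A", "GENE_B", "GENE_C"],
    fake_fetch,
    save=lambda current: saves.append(len(current)),
    pause=0,
    every=2,
)
print(sorted(cache), failed, saves)
```

Keeping failures separate has a cost: genes that Ensembl genuinely cannot resolve would be retried on every run, so a real implementation might distinguish error types.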

---


In [None]:
import requests
import pandas as pd
from tqdm import tqdm
import time
import pickle

def get_ensembl_paralogs(gene_symbol):
    """
    Get annotated paralogs for a human gene symbol from the Ensembl Compara REST API.

    Returns a list of dicts (empty if the gene has no annotated paralogs).
    Raises an exception on network errors or non-200 responses so the caller
    can count failures instead of silently treating them as "no paralogs".
    """
    url = f"https://rest.ensembl.org/homology/symbol/homo_sapiens/{gene_symbol}"
    params = {'type': 'paralogues'}
    headers = {"Content-Type": "application/json"}  # request a JSON response

    response = requests.get(url, params=params, headers=headers, timeout=15)
    response.raise_for_status()
    data = response.json()

    paralogs = []
    if data.get('data'):
        for homology in data['data'][0].get('homologies', []):
            target = homology.get('target', {})
            paralog_id = target.get('id', '')  # Ensembl gene ID of the paralog
            if paralog_id:
                paralogs.append({
                    'paralog_gene': paralog_id,
                    'dn_ds': homology.get('dn_ds'),
                    # Percent identity is usually reported on the target member;
                    # fall back to a top-level value if present
                    'perc_id': target.get('perc_id', homology.get('perc_id')),
                    'type': homology.get('type', '')
                })
    return paralogs

def save_paralog_cache(cache, filename='paralog_cache.pkl'):
    """Save paralog cache to disk"""
    with open(filename, 'wb') as f:
        pickle.dump(cache, f)
    print(f"Paralog cache saved to {filename}")

def load_paralog_cache(filename='paralog_cache.pkl'):
    """Load paralog cache from disk"""
    try:
        with open(filename, 'rb') as f:
            cache = pickle.load(f)
            print(f"Paralog cache loaded: {len(cache)} genes")
            return cache
    except FileNotFoundError:
        print("No existing paralog cache found - will create new one")
        return {}
    except Exception as e:
        print(f"Error loading paralog cache: {e}")
        return {}

def count_paralogs_by_variant_type(df, gene_column='gene', variant_type_column='variant_type', 
                                   cache_filename='paralog_cache.pkl'):
    """
    Count paralogs for genes grouped by variant type
    
    Parameters:
    -----------
    df : DataFrame
        DataFrame containing genes and their variant types
    gene_column : str
        Name of column containing gene symbols
    variant_type_column : str
        Name of column containing variant type categories
    cache_filename : str
        Path to cache file for storing API results
    
    Returns:
    --------
    summary_df : DataFrame
        Summary statistics by variant type
    detailed_df : DataFrame
        Detailed information for each gene
    """
    print("Counting paralogs by variant type...")
    
    # Load existing cache
    gene_paralog_cache = load_paralog_cache(cache_filename)
    
    # Get unique genes
    unique_genes = df[gene_column].unique()
    print(f"Total unique genes: {len(unique_genes)}")
    
    # Find genes not in cache
    genes_to_fetch = [gene for gene in unique_genes if gene not in gene_paralog_cache]
    print(f"Genes already cached: {len(unique_genes) - len(genes_to_fetch)}")
    print(f"Genes to fetch: {len(genes_to_fetch)}")
    
    # Fetch paralog data for missing genes
    if genes_to_fetch:
        print("Fetching paralog data for missing genes...")
        failed_genes = 0
        
        for i, gene in enumerate(tqdm(genes_to_fetch, desc="Fetching paralogs")):
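            try:
                paralogs = get_ensembl_paralogs(gene)
                gene_paralog_cache[gene] = paralogs
            except Exception:
                # Record the failure; the gene is cached as having no paralogs
                failed_genes += 1
                gene_paralog_cache[gene] = []
            finally:
                time.sleep(0.2)  # Client-side throttle, also applied after failures

            # Save cache every 500 genes (outside the try block so that a
            # failed save is not mistaken for a failed API call)
            if (i + 1) % 500 == 0:
                save_paralog_cache(gene_paralog_cache, cache_filename)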
        
        print(f"API failures: {failed_genes}")
        save_paralog_cache(gene_paralog_cache, cache_filename)
    
    # Count paralogs for each gene
    print("Counting paralogs for each gene...")
    gene_results = []
    
    for gene in tqdm(unique_genes, desc="Processing genes"):
        paralogs = gene_paralog_cache.get(gene, [])
        num_paralogs = len(paralogs)
        
        # Get variant types for this gene (a gene might appear in multiple rows)
        gene_rows = df[df[gene_column] == gene]
        variant_types = gene_rows[variant_type_column].unique()
        
        # Calculate average sequence identity if paralogs exist
        avg_seq_identity = 0
        if num_paralogs > 0:
            seq_ids = [p['perc_id'] for p in paralogs if p.get('perc_id') is not None]
            if seq_ids:
                avg_seq_identity = sum(seq_ids) / len(seq_ids)
        
        gene_results.append({
            'gene': gene,
            'variant_type': ', '.join(variant_types),  # In case gene has multiple types
            'num_paralogs': num_paralogs,
            'avg_sequence_identity': avg_seq_identity,
            'has_paralogs': num_paralogs > 0
        })
    
    detailed_df = pd.DataFrame(gene_results)
    
    # Create summary by variant type
    print("\nCreating summary statistics by variant type...")
    summary_results = []
    
    for variant_type in df[variant_type_column].unique():
        # Get genes of this variant type
        genes_of_type = df[df[variant_type_column] == variant_type][gene_column].unique()
        
        # Get their paralog counts
        type_data = detailed_df[detailed_df['gene'].isin(genes_of_type)]
        
        total_genes = len(genes_of_type)
        genes_with_paralogs = type_data['has_paralogs'].sum()
        total_paralogs = type_data['num_paralogs'].sum()
        avg_paralogs_per_gene = type_data['num_paralogs'].mean()
        
        # Calculate stats only for genes with paralogs
        genes_with_p = type_data[type_data['has_paralogs']]
        avg_paralogs_when_present = genes_with_p['num_paralogs'].mean() if len(genes_with_p) > 0 else 0
        avg_seq_identity = genes_with_p['avg_sequence_identity'].mean() if len(genes_with_p) > 0 else 0
        
        summary_results.append({
            'variant_type': variant_type,
            'total_genes': total_genes,
            'genes_with_paralogs': genes_with_paralogs,
            'percent_with_paralogs': (genes_with_paralogs / total_genes * 100) if total_genes > 0 else 0,
            'total_paralogs': total_paralogs,
            'avg_paralogs_per_gene': avg_paralogs_per_gene,
            'avg_paralogs_when_present': avg_paralogs_when_present,
            'avg_sequence_identity': avg_seq_identity
        })
    
    summary_df = pd.DataFrame(summary_results)
    summary_df = summary_df.sort_values('percent_with_paralogs', ascending=False)
    
    # Print summary
    print("\n" + "="*80)
    print("PARALOG ANALYSIS SUMMARY BY VARIANT TYPE")
    print("="*80)
    print(summary_df.to_string(index=False))
    print("="*80)
    
    return summary_df, detailed_df

# Load and prepare the data
print("Loading pli_AD_genes.csv...")
df = pd.read_csv('pli_AD_genes.csv')

print(f"Loaded {len(df)} rows")
print(f"Columns: {df.columns.tolist()}")
print(f"\nOriginal category values:")
print(df['category'].value_counts())

# Rename SNV to Nonsense (both with and without _Control)
df['category'] = df['category'].replace({
    'SNV': 'Nonsense',
    'SNV_Control': 'Nonsense_Control'
})
print(f"\nUpdated category values:")
print(df['category'].value_counts())

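# Optionally delete the old cache to force a fresh retrieval from Ensembl.
# Set FORCE_REFRESH = False to reuse an existing paralog_cache.pkl
# (e.g., when reproducing previously generated results).
import os
FORCE_REFRESH = True
cache_file = 'paralog_cache.pkl'
if FORCE_REFRESH and os.path.exists(cache_file):
    os.remove(cache_file)
    print(f"\n*** Deleted old cache file: {cache_file} ***")
    print("Will fetch fresh data from Ensembl API\n")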

# Run the paralog analysis
summary_df, detailed_df = count_paralogs_by_variant_type(
    df, 
    gene_column='gene',
    variant_type_column='category'
)

# Save results
summary_df.to_csv('paralog_summary_by_category.csv', index=False)
detailed_df.to_csv('paralog_detailed_by_gene.csv', index=False)
print("\nResults saved to:")
print("  - paralog_summary_by_category.csv")
print("  - paralog_detailed_by_gene.csv")

# Show top genes with most paralogs
print("\nTop 20 genes with most paralogs:")
print(detailed_df.sort_values('num_paralogs', ascending=False).head(20).to_string(index=False))

## 📊 Paralog Enrichment by Variant Category

This section takes the paralog summary tables produced above and generates the
two-panel figure used in the paper:

- **Panel A – Paralog Presence:**  
  Percentage of genes with ≥1 annotated paralog in each variant category, with
  binomial test p-values comparing each experimental group to its matched control.

- **Panel B – Counts:**  
  Raw counts of genes with and without paralogs in each category.

---

### 1. Load summary tables

We load the two CSVs generated by the Ensembl paralog query step:

- `paralog_summary_by_category.csv`  
  One row per **variant category** with columns such as:
  - `variant_type`
  - `total_genes`
  - `genes_with_paralogs`
  - `percent_with_paralogs`
- `paralog_detailed_by_gene.csv`  
  One row per **gene** with the number of paralogs and mean sequence identity.
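
For orientation, the per-category columns can be derived from gene-level rows roughly as follows. This is a dependency-free sketch with invented counts, and `summarize_category` is a hypothetical helper rather than a function from the notebook.

```python
def summarize_category(gene_rows):
    """Compute per-category summary columns from gene-level rows."""
    total = len(gene_rows)
    with_paralogs = sum(1 for row in gene_rows if row["num_paralogs"] > 0)
    total_paralogs = sum(row["num_paralogs"] for row in gene_rows)
    return {
        "total_genes": total,
        "genes_with_paralogs": with_paralogs,
        "percent_with_paralogs": with_paralogs / total * 100 if total else 0,
        "avg_paralogs_per_gene": total_paralogs / total if total else 0,
    }


# Four hypothetical genes from one category
toy_rows = [
    {"gene": "GENE_A", "num_paralogs": 0},
    {"gene": "GENE_B", "num_paralogs": 3},
    {"gene": "GENE_C", "num_paralogs": 1},
    {"gene": "GENE_D", "num_paralogs": 0},
]
summary = summarize_category(toy_rows)
print(summary)
```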

The code then:

1. Sets consistent **Matplotlib styling** (fonts, DPI, grid, etc.) so the
   exported PNG matches the figure aesthetics used in the paper.
2. Defines a color palette mapping each category and its control to paired colors
   (e.g., Minus1 vs. Minus1_Control).

---

### 2. Normalize category labels and ordering

To keep the plot readable and consistent:

- Raw `variant_type` labels are mapped to unified names:
  - `SNV → Nonsense`
  - `plus1 → Plus1`
  - `minus1 → Minus1`
  - and their `_Control` counterparts.
- A `group_label` column is created and ordered as:  
  **Minus1, Minus1_Control, Plus1, Plus1_Control, Nonsense, Nonsense_Control**

The summary table is then renamed to a generic format:

- `genes_with_paralogs → matched`
- `total_genes → total`

so downstream plotting and statistics code can treat all groups in a uniform way.
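
The mapping, ordering, and column renaming can be sketched without pandas as below. `LABEL_MAP` and `GROUP_ORDER` restate the names given in the text, `normalize_rows` is a hypothetical helper, and the counts are invented.

```python
LABEL_MAP = {
    "SNV": "Nonsense", "SNV_Control": "Nonsense_Control",
    "plus1": "Plus1", "plus1_Control": "Plus1_Control",
    "minus1": "Minus1", "minus1_Control": "Minus1_Control",
}
GROUP_ORDER = ["Minus1", "Minus1_Control", "Plus1", "Plus1_Control",
               "Nonsense", "Nonsense_Control"]


def normalize_rows(rows):
    """Map raw labels to unified names, rename columns, and sort by GROUP_ORDER."""
    out = []
    for row in rows:
        # Labels that are already normalized pass through unchanged
        label = LABEL_MAP.get(row["variant_type"], row["variant_type"])
        out.append({
            "group_label": label,
            "matched": row["genes_with_paralogs"],
            "total": row["total_genes"],
        })
    # GROUP_ORDER.index raises ValueError for a label outside the six groups
    return sorted(out, key=lambda r: GROUP_ORDER.index(r["group_label"]))


# Invented summary rows in arbitrary order
raw_rows = [
    {"variant_type": "SNV_Control", "genes_with_paralogs": 30, "total_genes": 60},
    {"variant_type": "minus1", "genes_with_paralogs": 12, "total_genes": 20},
    {"variant_type": "Plus1", "genes_with_paralogs": 9, "total_genes": 18},
]
ordered = normalize_rows(raw_rows)
print([r["group_label"] for r in ordered])
```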

---

### 3. Binomial test for enrichment vs. control

For each base category (`Minus1`, `Plus1`, `Nonsense`), we compare its fraction
of genes with paralogs to the fraction observed in its matched control. The control
fraction defines the expected rate:

$$
p_0 = \frac{\text{matched}_{\text{control}}}{\text{total}_{\text{control}}}
$$

Then we run an exact **binomial test** on the experimental group's counts
(`matched` successes out of `total` genes):

- **Null hypothesis:** the experimental group has paralog frequency $p_0$ (`p0` in the code), which the test treats as a fixed value
- **Alternative:** two-sided (enriched or depleted vs. control)
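
To show what the test computes, here is a small pure-Python sketch of a two-sided exact binomial test. It uses the common definition of summing the probabilities of all outcomes that are no more likely than the observed one. The counts are invented, `binom_two_sided_p` is a hypothetical helper, and the analysis itself relies on `scipy.stats.binomtest`. This direct version is only suitable for small `n`.

```python
from math import comb


def binom_pmf(k, n, p):
    """Probability of exactly k successes in n trials with success rate p."""
    return comb(n, k) * p**k * (1 - p) ** (n - k)


def binom_two_sided_p(k, n, p0):
    """Two-sided exact p-value: total probability of outcomes no more likely than k."""
    observed = binom_pmf(k, n, p0)
    threshold = observed * (1 + 1e-7)  # small tolerance for floating-point ties
    total = 0.0
    for i in range(n + 1):
        prob = binom_pmf(i, n, p0)
        if prob <= threshold:
            total += prob
    return min(1.0, total)


# Invented counts: control 50/100 with paralogs, experimental 8/10
p0 = 50 / 100
p_value = binom_two_sided_p(8, 10, p0)
print(f"p0 = {p0}, p = {p_value:.6f}")
```

With these numbers the p-value is P(X ≤ 2) + P(X ≥ 8) = 112/1024 ≈ 0.109, since outcomes 0 to 2 and 8 to 10 are the ones no more probable than observing 8 under a rate of 0.5.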

Results:

- Raw p-values (`p_value`)
- FDR-corrected p-values (`p_value_fdr`) via Benjamini–Hochberg
- Flags for significance with and without FDR correction
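
The Benjamini–Hochberg adjustment behind `p_value_fdr` can be sketched in a few lines. This is an illustrative pure-Python version with invented p-values; `benjamini_hochberg` is a hypothetical helper, and the analysis itself calls `statsmodels.stats.multitest.multipletests` with `method='fdr_bh'`.

```python
def benjamini_hochberg(p_values):
    """Return BH-adjusted p-values in the same order as the input."""
    m = len(p_values)
    order = sorted(range(m), key=lambda i: p_values[i])  # indices by ascending p
    adjusted = [0.0] * m
    running_min = 1.0
    # Walk from the largest p-value down, enforcing monotonicity
    for rank in range(m, 0, -1):
        idx = order[rank - 1]
        running_min = min(running_min, p_values[idx] * m / rank)
        adjusted[idx] = running_min
    return adjusted


# Three invented raw p-values, one per experimental/control comparison
raw = [0.01, 0.04, 0.03]
adjusted = benjamini_hochberg(raw)
print([round(p, 4) for p in adjusted])
```

Here the smallest p-value is multiplied by 3 (0.01 → 0.03), while the middle one (0.03 × 3/2 = 0.045) is capped at the adjusted value of the largest (0.04), which keeps the adjusted values monotone.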

These p-values are later rendered as text labels on the brackets in Panel A.

---

### 4. Figure construction

The helper function `create_paralog_visualization()` builds a **two-panel figure**:

1. **Panel A – Paralog Presence (left):**
   - Bar plot of `% with paralogs = matched / total × 100` per category.
   - Bars are colored by category/control pairing.
   - Each bar is labeled with:
     - The percentage (above the bar).
     - The number of genes (`n=`) centered on the bar.
   - If statistics are available:
     - Brackets connect each experimental group to its control.
     - p-values (optionally FDR-adjusted) are written above the bracket.
     - Asterisk (`*`) marks significant comparisons.

2. **Panel B – Counts (right):**
   - Stacked bar plot:
     - **Bottom (colored) segment:** genes with paralogs.
     - **Top (gray) segment:** genes without paralogs.
   - Bars are labeled with:
     - Number of genes with paralogs (inside the colored segment, when large enough).
     - Total genes in the category (above each bar).

Stylistic details (spine removal, tick settings, grid lines, and tight layout)
are tuned to match the figure in the manuscript and export cleanly at 300 DPI.

The final figure is shown inline and, if `save_prefix` is provided, saved as
`{save_prefix}_paralog_enrichment.png`. With `save_prefix="AD_genes"`, as used here, the file is:

- `AD_genes_paralog_enrichment.png`
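
The bracket annotations boil down to two small pieces of logic: formatting the p-value label and computing the four points of the bracket line. The sketch below isolates both without Matplotlib. `format_p_label` and `bracket_points` are hypothetical helper names, and the thresholds follow the formatting rules used in the plotting code.

```python
def format_p_label(p_val, significant):
    """Format a p-value for display; append '*' when flagged significant."""
    if p_val < 0.001:
        text = "p < 0.001"
    elif p_val < 0.01:
        text = f"p = {p_val:.3f}"
    else:
        text = f"p = {p_val:.2f}"
    return text + "*" if significant else text


def bracket_points(x_left, x_right, y_top, tick):
    """Return x and y coordinates for a bracket with short downward ticks."""
    xs = [x_left, x_left, x_right, x_right]
    ys = [y_top - tick, y_top, y_top, y_top - tick]
    return xs, ys


labels = [format_p_label(p, p < 0.05) for p in (0.0004, 0.004, 0.27)]
xs, ys = bracket_points(0, 1, 50.0, 0.5)
print(labels)
print(xs, ys)
```

Passing `xs` and `ys` to `ax.plot` draws the bracket, and the label is placed at the midpoint `(x_left + x_right) / 2` slightly above `y_top`.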


In [None]:
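import pandas as pd  # also imported in the first cell; repeated so this cell can run from the saved CSVs
import matplotlib.pyplot as plt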
import matplotlib.colors as mcolors
import numpy as np
from scipy.stats import binomtest
from statsmodels.stats.multitest import multipletests


summary_df = pd.read_csv('paralog_summary_by_category.csv')
detailed_df = pd.read_csv('paralog_detailed_by_gene.csv')

# ---------- Matplotlib style ----------
plt.rcParams.update({
    'font.family': 'Arial',
    'font.size': 8,
    'axes.linewidth': 0.5,
    'axes.labelsize': 8,
    'axes.titlesize': 9,
    'xtick.labelsize': 7,
    'ytick.labelsize': 7,
    'legend.fontsize': 7,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.1,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.grid': True,
    'grid.linewidth': 0.3,
    'grid.alpha': 0.3
})

# ---------- Colors ----------
COLOR_SCHEME = {
    'Nonsense': '#70AD47',
    'Nonsense_Control': '#C5E0B4',
    'Plus1': '#C65911',
    'Plus1_Control': '#F4B183',
    'Minus1': '#4472C4',
    'Minus1_Control': '#B4C7E7',
}
CATEGORY_BASE_COLOR = {
    'Minus1': COLOR_SCHEME['Minus1'],
    'Minus1_Control': COLOR_SCHEME['Minus1'],
    'Plus1': COLOR_SCHEME['Plus1'],
    'Plus1_Control': COLOR_SCHEME['Plus1'],
    'Nonsense': COLOR_SCHEME['Nonsense'],
    'Nonsense_Control': COLOR_SCHEME['Nonsense'],
}

def lighten_color(color, amount=0.35):
    try:
        r, g, b = mcolors.to_rgb(color)
    except ValueError:
        r, g, b = (0.5, 0.5, 0.5)
    return (1 - amount) + amount * r, (1 - amount) + amount * g, (1 - amount) + amount * b

# ---------- Prepare paralog data ----------
# summary_df was loaded above from 'paralog_summary_by_category.csv'


# ---------- Label normalization ----------
def rename_and_capitalize_groups(df):
    df = df.copy()
    df['group_label'] = df['variant_type'].replace({
        'SNV': 'Nonsense', 'SNV_Control': 'Nonsense_Control',
        'plus1': 'Plus1', 'plus1_Control': 'Plus1_Control',
        'minus1': 'Minus1', 'minus1_Control': 'Minus1_Control'
    })
    return df

def apply_custom_order(df):
    order = ['Minus1','Minus1_Control','Plus1','Plus1_Control','Nonsense','Nonsense_Control']
    df = df.copy()
    df['group_label'] = pd.Categorical(df['group_label'], categories=order, ordered=True)
    return df.sort_values('group_label').reset_index(drop=True)

# Apply transformations
paralog_df = apply_custom_order(rename_and_capitalize_groups(summary_df))

# Rename columns to match expected format
paralog_df = paralog_df.rename(columns={
    'genes_with_paralogs': 'matched',
    'total_genes': 'total'
})

print("\nParalog data summary:")
print(paralog_df[['group_label', 'matched', 'total', 'percent_with_paralogs']])

# ---------- Stats with Binomial Test ----------
def calculate_p_values_binomial(df):
    """Calculate p-values using binomial test with FDR correction"""
    results = []
    base_groups = df['group_label'].str.replace('_Control','',regex=False).unique()
    
    for base in base_groups:
        if base not in df['group_label'].values or f'{base}_Control' not in df['group_label'].values:
            continue
            
        exp  = df.loc[df['group_label']==base].iloc[0]
        ctrl = df.loc[df['group_label']==f'{base}_Control'].iloc[0]
        
        # Skip if control has zero total
        if ctrl['total'] == 0:
            continue
        
        # Control proportion as expected rate
        p0 = ctrl['matched'] / ctrl['total']
        
        # Binomial test
        res = binomtest(int(exp['matched']), int(exp['total']),
                       p=p0, alternative='two-sided')
        
        results.append({
            'experimental_group': base,
            'control_group': f'{base}_Control',
            'exp_matched': int(exp['matched']),
            'exp_total': int(exp['total']),
            'ctrl_matched': int(ctrl['matched']),
            'ctrl_total': int(ctrl['total']),
            'control_proportion': p0,
            'p_value': res.pvalue,
            'significant': res.pvalue < 0.05
        })
    
    results_df = pd.DataFrame(results)
    
    # Add FDR correction if we have results
    if not results_df.empty:
        results_df['p_value_fdr'] = multipletests(results_df['p_value'], method='fdr_bh')[1]
        results_df['significant_fdr'] = results_df['p_value_fdr'] < 0.05
    
    return results_df

# Calculate statistics
paralog_stats = calculate_p_values_binomial(paralog_df)
if not paralog_stats.empty:
    print("\nParalog statistics:")
    print(paralog_stats[['experimental_group', 'p_value', 'significant']])

def get_colors_for_groups(group_labels):
    out = []
    for label in group_labels:
        if 'Control' in label:
            out.append(lighten_color(CATEGORY_BASE_COLOR[label]))
        else:
            out.append(COLOR_SCHEME[label])
    return out

# ---------- Bracket height control ----------
BRACKET_RAISE = 0.12
PTEXT_RAISE   = 0.03

def bracket_top_for_pair(bar_heights, i, j, max_pct):
    top_pair  = max(bar_heights[i], bar_heights[j])
    lbl_off   = 0.022 * max_pct
    return top_pair + lbl_off + BRACKET_RAISE * max_pct

# ---------- Figure creation ----------
def create_paralog_visualization(df, stats_df=None, save_prefix="", use_fdr=False):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6.69, 3.0))

    # Calculate percentage if not already present
    if 'percent' not in df.columns:
        df = df.copy()
        df['percent'] = df['matched'] / df['total']

    groups         = df['group_label'].tolist()
    percentages    = (df['percent'] * 100).tolist()
    matched_counts = df['matched'].tolist()
    total_counts   = df['total'].tolist()

    colors = get_colors_for_groups(groups)
    x      = np.arange(len(groups))
    bar_w  = 0.6

    # --- Labels ---
    y_lab     = 'Percentage with Paralogs (%)'
    title_pct = 'Paralog Presence'
    title_cnt = 'Counts'

    # ---- Panel A: Percentages + p-value brackets ----
    bars1 = ax1.bar(x, percentages, width=bar_w, color=colors, alpha=0.8,
                    edgecolor='black', linewidth=0.3)
    ax1.set_xlabel('Gene Category', fontsize=8, fontweight='bold')
    ax1.set_ylabel(y_lab, fontsize=8, fontweight='bold')
    ax1.set_title(title_pct, fontsize=9, fontweight='bold', pad=10)
    ax1.set_xticks(x)
    ax1.set_xticklabels([g.replace('_','\n') for g in groups], fontsize=7, ha='center')

    ymax_pct = max(percentages) if percentages else 1.0
    ax1.set_ylim(0, ymax_pct * 1.35)

    # % and n labels
    for bar, pct, tot in zip(bars1, percentages, total_counts):
        h = bar.get_height()
        ax1.text(bar.get_x()+bar.get_width()/2., h + 0.022*ymax_pct, f'{pct:.1f}%',
                 ha='center', va='bottom', fontsize=6, fontweight='bold')
        ax1.text(bar.get_x()+bar.get_width()/2., (h/2 if h>0 else 0.2), f'n={tot}',
                 ha='center', va='center', fontsize=6, fontweight='bold',
                 color='black', bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.8, edgecolor='none'))

    # p-value brackets
    if stats_df is not None and not stats_df.empty and len(groups) > 1:
        local_top = ymax_pct
        
        # Choose which p-value column to use
        p_col = 'p_value_fdr' if use_fdr and 'p_value_fdr' in stats_df.columns else 'p_value'
        sig_col = 'significant_fdr' if use_fdr and 'significant_fdr' in stats_df.columns else 'significant'
        
        for _, row in stats_df.iterrows():
            eg, cg = row['experimental_group'], row['control_group']
            if eg in groups and cg in groups:
                ei, ci = groups.index(eg), groups.index(cg)
                by = bracket_top_for_pair(percentages, ei, ci, ymax_pct)
                ax1.plot([ei, ei, ci, ci],
                         [by - 0.01*ymax_pct, by, by, by - 0.01*ymax_pct],
                         color='black', linewidth=0.6)
                
                p_val = row[p_col]
                ptxt = ("p < 0.001" if p_val < 0.001
                        else (f"p = {p_val:.3f}" if p_val < 0.01
                              else f"p = {p_val:.2f}"))
                if row[sig_col]: ptxt += "*"
                
                ax1.text((ei+ci)/2, by + PTEXT_RAISE*ymax_pct, ptxt,
                         ha='center', va='bottom', fontsize=6, fontweight='bold')
                local_top = max(local_top, by + PTEXT_RAISE*ymax_pct)
        ax1.set_ylim(0, max(local_top*1.05, ymax_pct*1.35))

    # ---- Panel B: Raw counts ----
    ax2.bar(x, matched_counts, width=bar_w, color=colors, alpha=0.8,
            edgecolor='black', linewidth=0.3, label='With Paralogs')
    ax2.bar(x, [t-m for t, m in zip(total_counts, matched_counts)],
            bottom=matched_counts, width=bar_w, color='#E5E5E5', alpha=0.8,
            edgecolor='black', linewidth=0.3, label='Without Paralogs')
    ax2.set_xlabel('Gene Category', fontsize=8, fontweight='bold')
    ax2.set_ylabel('Count', fontsize=8, fontweight='bold')
    ax2.set_title(title_cnt, fontsize=9, fontweight='bold', pad=10)
    ax2.set_xticks(x)
    ax2.set_xticklabels([g.replace('_','\n') for g in groups], fontsize=7, ha='center')

    ymax_ct = max(total_counts) if total_counts else 1
    ax2.set_ylim(0, ymax_ct * 1.10)

    # Count labels
    for i, (m, t) in enumerate(zip(matched_counts, total_counts)):
        if m > t * 0.1:
            ax2.text(i, m/2, f'{m}', ha='center', va='center',
                     fontsize=6, fontweight='bold', color='white')
        ax2.text(i, t + 0.012*ymax_ct, f'{t}', ha='center', va='bottom',
                 fontsize=6, fontweight='bold')

    for ax in (ax1, ax2):
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.tick_params(length=2, width=0.5)

    plt.tight_layout()

    if save_prefix:
        fname = f"{save_prefix}_paralog_enrichment.png"
        plt.savefig(fname, dpi=300, bbox_inches='tight', pad_inches=0.1,
                    facecolor='white', edgecolor='none')
    return fig

# ---------- Build figure ----------
print("\nGenerating visualization for Paralogs...")
fig = create_paralog_visualization(
    paralog_df, 
    paralog_stats, 
    save_prefix="AD_genes",
    use_fdr=False
)
plt.show()

## 📦 Paralog Count Distribution & Statistical Comparison

This section generates the **second figure** used in the manuscript:  
a boxplot showing the distribution of paralog counts per gene across variant categories
(Minus1, Plus1, Nonsense, and their matched control groups).

### 🔧 What this cell does

1. **Rebuilds gene-level dataframe (`detailed_expanded`)**  
   From the earlier Ensembl paralog query results (`detailed_df`), this step:
   - extracts each gene's number of annotated paralogs,
   - assigns each gene to the correct variant category,
   - normalizes category names (e.g., `plus1` → `Plus1`),
   - enforces a consistent category order for plotting.

2. **Runs Mann–Whitney U tests**  
   For each experimental group vs. its matched control, a two-sided Mann–Whitney U test
   compares the distributions of per-gene paralog counts. Alongside each test we report:
   - group sizes (*n*)  
   - median paralog count  
   - mean paralog count  
   
   This non-parametric test is appropriate because paralog counts are discrete and often skewed.
   FDR-adjusted p-values (Benjamini–Hochberg) are also computed to account for multiple comparisons.

3. **Creates the paralog count distribution boxplot**  
   The function:
   - draws a grouped boxplot across the six categories,
   - colors each category/control pair consistently with Figure 1,
   - annotates each box with *n* (number of genes per group),
   - adds significance brackets and p-values above relevant comparisons,
   - saves the output figure as  
     **`AD_genes_paralog_distribution_boxplot.png`**.
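
As a sketch of what the Mann–Whitney U statistic measures, the pure-Python version below counts, over all experimental/control gene pairs, how often the experimental gene has more paralogs (ties count one half). Only the statistic is computed here; the p-value in the analysis comes from `scipy.stats.mannwhitneyu`. The paralog counts are invented and `mann_whitney_u` is a hypothetical helper.

```python
def mann_whitney_u(x, y):
    """U statistic for sample x: pairs where x > y, plus half of the ties."""
    u = 0.0
    for a in x:
        for b in y:
            if a > b:
                u += 1.0
            elif a == b:
                u += 0.5
    return u


# Invented paralog counts per gene
exp_counts = [0, 2, 5, 9]
ctrl_counts = [0, 1, 1, 3]

u_exp = mann_whitney_u(exp_counts, ctrl_counts)
u_ctrl = mann_whitney_u(ctrl_counts, exp_counts)
n_pairs = len(exp_counts) * len(ctrl_counts)
print(u_exp, u_ctrl, u_exp / n_pairs)
```

Dividing U by the number of pairs gives the probability that a randomly chosen experimental gene has more paralogs than a randomly chosen control gene (with ties split evenly), which is often easier to interpret than U itself.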

### 📘 Interpretation

This plot visualizes whether genes harboring different variant types tend to have:
- more paralogs (redundancy),
- fewer paralogs (uniqueness / singletons),
- or no significant difference.

It directly complements the paralog-presence barplot by capturing the **shape of the paralog count distribution** rather than just presence/absence.

---


In [None]:
from scipy.stats import mannwhitneyu

# ---------- Rebuild gene-level DF with group labels ----------
# detailed_df comes from the first paralog-analysis cell
# columns: gene, variant_type, num_paralogs, avg_sequence_identity, has_paralogs

detailed_expanded = detailed_df.copy()

# Normalize category names to match plotting order
detailed_expanded["group_label"] = detailed_expanded["variant_type"].replace({
    "SNV": "Nonsense",
    "SNV_Control": "Nonsense_Control",
    "plus1": "Plus1",
    "plus1_Control": "Plus1_Control",
    "minus1": "Minus1",
    "minus1_Control": "Minus1_Control",
})

# Ensure consistent category order
group_order = ["Minus1", "Minus1_Control",
               "Plus1", "Plus1_Control",
               "Nonsense", "Nonsense_Control"]

detailed_expanded["group_label"] = pd.Categorical(
    detailed_expanded["group_label"],
    categories=group_order,
    ordered=True
)
detailed_expanded = detailed_expanded.sort_values("group_label").reset_index(drop=True)

# ---------- Calculate statistics for paralog counts ----------
def calculate_paralog_count_stats(df):
    """Calculate Mann-Whitney U test for paralog counts between experimental and control groups"""
    results = []
    base_groups = df["group_label"].astype(str).str.replace("_Control", "", regex=False).unique()
    
    for base in base_groups:
        exp_mask  = df["group_label"] == base
        ctrl_mask = df["group_label"] == f"{base}_Control"
        if not exp_mask.any() or not ctrl_mask.any():
            continue
            
        exp_data  = df.loc[exp_mask,  "num_paralogs"].values
        ctrl_data = df.loc[ctrl_mask, "num_paralogs"].values
        
        # Skip if either group is empty
        if len(exp_data) == 0 or len(ctrl_data) == 0:
            continue
        
        # Mann-Whitney U test (non-parametric, good for count data)
        statistic, p_value = mannwhitneyu(exp_data, ctrl_data, alternative="two-sided")
        
        results.append({
            "experimental_group": base,
            "control_group": f"{base}_Control",
            "exp_n": len(exp_data),
            "exp_median": np.median(exp_data),
            "exp_mean": np.mean(exp_data),
            "ctrl_n": len(ctrl_data),
            "ctrl_median": np.median(ctrl_data),
            "ctrl_mean": np.mean(ctrl_data),
            "U_statistic": statistic,
            "p_value": p_value,
            "significant": p_value < 0.05,
        })
    
    results_df = pd.DataFrame(results)
    
    # Add FDR correction if we have results
    if not results_df.empty:
        results_df["p_value_fdr"] = multipletests(results_df["p_value"], method="fdr_bh")[1]
        results_df["significant_fdr"] = results_df["p_value_fdr"] < 0.05
    
    return results_df

# Calculate stats for paralog counts
paralog_count_stats = calculate_paralog_count_stats(detailed_expanded)
if not paralog_count_stats.empty:
    print("\nParalog count statistics (Mann-Whitney U test):")
    print(paralog_count_stats[["experimental_group", "exp_median", "ctrl_median", "p_value", "significant"]])

def create_paralog_distribution_boxplot_with_stats(df, stats_df=None, save_prefix="", use_fdr=False):
    fig, ax = plt.subplots(1, 1, figsize=(6.69, 4.5))
    
    groups = ["Minus1", "Minus1_Control",
              "Plus1", "Plus1_Control",
              "Nonsense", "Nonsense_Control"]
    colors = get_colors_for_groups(groups)  # defined in the previous plotting cell
    
    # Prepare data for each group
    data_by_group = [df[df["group_label"] == g]["num_paralogs"].values for g in groups]
    
    # Create box plot
    # Tick labels are set further below (including n per group), so no labels
    # argument is passed here; this also avoids the deprecated `labels` keyword
    bp = ax.boxplot(
        data_by_group,
        patch_artist=True,
        widths=0.6,
        boxprops=dict(linewidth=0.8),
        whiskerprops=dict(linewidth=0.8),
        capprops=dict(linewidth=0.8),
        medianprops=dict(linewidth=1.5, color="red"),
    )
    
    # Color the boxes
    for patch, color in zip(bp["boxes"], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.8)
    
    ax.set_xlabel("Gene Categories", fontsize=8, fontweight="bold")
    ax.set_ylabel("Number of Paralogs per Gene", fontsize=8, fontweight="bold")
    ax.set_title("Distribution of Paralog Counts", fontsize=9, fontweight="bold", pad=10)
    
    # Add n to x-tick labels
    current_labels = [g.replace("_", "\n") for g in groups]
    new_labels = [f"{label}\nn={len(data)}" for label, data in zip(current_labels, data_by_group)]
    ax.set_xticklabels(new_labels, fontsize=7, ha="center")
    
    # Get y-axis limits for bracket placement
    ymin, ymax = ax.get_ylim()
    
    # Add p-value brackets
    if stats_df is not None and not stats_df.empty:
        # Choose which p-value column to use
        p_col = "p_value_fdr" if use_fdr and "p_value_fdr" in stats_df.columns else "p_value"
        sig_col = "significant_fdr" if use_fdr and "significant_fdr" in stats_df.columns else "significant"
        
        max_vals = [data.max() if len(data) > 0 else 0 for data in data_by_group]
        bracket_height_increment = (ymax - ymin) * 0.08
        current_bracket_y = max(max_vals) + (ymax - ymin) * 0.05
        
        for _, row in stats_df.iterrows():
            eg, cg = row["experimental_group"], row["control_group"]
            if eg in groups and cg in groups:
                ei = groups.index(eg) + 1  # boxplot positions are 1-indexed
                ci = groups.index(cg) + 1
                
                pair_max = max(max_vals[ei-1], max_vals[ci-1])
                by = pair_max + (ymax - ymin) * 0.08
                
                # Draw bracket
                ax.plot(
                    [ei, ei, ci, ci],
                    [by - (ymax - ymin) * 0.01, by, by, by - (ymax - ymin) * 0.01],
                    color="black",
                    linewidth=0.6,
                )
                
                # p-value text
                p_val = row[p_col]
                ptxt = (
                    "p < 0.001" if p_val < 0.001
                    else (f"p = {p_val:.3f}" if p_val < 0.01 else f"p = {p_val:.2f}")
                )
                if row[sig_col]:
                    ptxt += "*"
                
                ax.text(
                    (ei + ci) / 2,
                    by + (ymax - ymin) * 0.02,
                    ptxt,
                    ha="center",
                    va="bottom",
                    fontsize=6,
                    fontweight="bold",
                )
                
                current_bracket_y = max(current_bracket_y, by + (ymax - ymin) * 0.06)
        
        # Adjust y-limit to accommodate brackets
        ax.set_ylim(ymin, current_bracket_y + (ymax - ymin) * 0.05)
    
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.tick_params(length=2, width=0.5)
    ax.grid(True, alpha=0.3, linewidth=0.3, axis="y")
    
    plt.tight_layout()
    
    if save_prefix:
        plt.savefig(
            f"{save_prefix}_paralog_distribution_boxplot.png",
            dpi=300,
            bbox_inches="tight",
            pad_inches=0.1,
            facecolor="white",
            edgecolor="none",
        )
    return fig

# ---------- Generate box plot with statistics ----------
print("\nGenerating box plot with statistics...")
fig2 = create_paralog_distribution_boxplot_with_stats(
    detailed_expanded, 
    paralog_count_stats,
    save_prefix="AD_genes",
    use_fdr=False
)
plt.show()


# 🎉 Notebook Complete

This notebook performs the full paralog analysis used in the manuscript:

1. **Queries paralogs from Ensembl** for all genes in the dataset  
2. **Generates a cached gene-level paralog table** with sequence identity information  
3. **Computes variant-type-level paralog summaries**  
4. **Produces both figures used in the paper**:
   - **Paralog Presence Barplot**
   - **Paralog Count Distribution Boxplot**

Because the analysis depends on live Ensembl REST API responses, small differences
may occur if the notebook is re-run at a later time (e.g., updated annotations or
transient API failures).

If using this notebook for reproduction:
- Ensure stable internet access for the API step  
- Do **not** delete the generated `paralog_cache.pkl` unless intentionally re-querying Ensembl  
- Results will be identical only when the same cache and the same gene list are used  
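
One possible way to check that two runs used the same cache file is to record a checksum next to the results. The sketch below is a suggestion rather than part of the pipeline: `file_sha256` is a hypothetical helper, and the toy cache written to a temporary directory stands in for `paralog_cache.pkl`. Note that pickle output can differ between Python or pickle protocol versions, so differing hashes do not necessarily mean differing content.

```python
import hashlib
import os
import pickle
import tempfile


def file_sha256(path, chunk_size=1 << 20):
    """Return the SHA-256 hex digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Toy cache with invented entries, written to a temporary file
toy_cache = {"GENE_A": [{"paralog_gene": "ENSG_X", "perc_id": 40.0}], "GENE_B": []}
with tempfile.TemporaryDirectory() as tmp_dir:
    cache_path = os.path.join(tmp_dir, "paralog_cache.pkl")
    with open(cache_path, "wb") as handle:
        pickle.dump(toy_cache, handle)
    fingerprint = file_sha256(cache_path)
    fingerprint_again = file_sha256(cache_path)

print(fingerprint[:12], fingerprint == fingerprint_again)
```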

Thank you for using this analysis pipeline!
