# Beaver Tutorial 4: GWAS Analysis (Data Scientist)

This tutorial runs a **Genome-Wide Association Study (GWAS)** pipeline with PLINK as a privacy-preserving collaboration: your analysis code runs on the data owner's machine, and only summary results come back to you.

## Overview

- **You (Data Scientist)**: Define the GWAS analysis pipeline
- **Data Owner**: Provides real genomic data and runs PLINK on their machine
- **Results**: You receive Manhattan plots, QQ plots, and significant SNPs (not raw genotypes)

## How to Run This Tutorial

### Option 1: Two Browser Tabs (Solo Testing)
1. Create a session with yourself in BioVault
2. Open two Jupyter tabs from the same session
3. Run the DO notebook in one tab, this notebook in the other

### Option 2: With a Collaborator
1. Create a session with your collaborator
2. They run the DO notebook, you run this DS notebook

---

## Step 1: Setup

In [None]:
!uv pip install pandas numpy matplotlib -q
print("Dependencies installed!")

In [None]:
import beaver
from beaver import Twin
import time

bv = beaver.ctx()
session = bv.active_session()

print(f"You are: {bv.user}")
print(f"Session peer: {session.peer}")

## Step 2: Wait for GWAS Data from DO

**Run DO notebook Steps 1-4 first!**

In [None]:
print("Waiting for DO to publish 'gwas_data'...")

gwas_data = None
for _ in range(120):  # Wait up to 2 minutes
    peer_vars = session.peer_remote_vars
    if "gwas_data" in peer_vars:
        gwas_data = peer_vars["gwas_data"].load(inject=False, auto_accept=True)
        print(f"\nLoaded GWAS data info from {session.peer}!")
        break
    time.sleep(1)
    print(".", end="", flush=True)

if gwas_data is None:
    print("\nTimeout! Make sure DO has run Steps 1-4.")
else:
    display(gwas_data)
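This cell and the one in Step 6 use the same fixed-iteration polling loop. The pattern can be factored into a small helper. The sketch below is a minimal, Beaver-independent version: `wait_for` is a hypothetical name, and `check` is any zero-argument callable that returns `None` until the thing you are waiting for exists.

```python
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def wait_for(check: Callable[[], Optional[T]], timeout: float = 120.0,
             interval: float = 1.0) -> Optional[T]:
    """Poll `check()` until it returns something other than None.

    Returns that value, or None once `timeout` seconds have elapsed.
    """
    deadline = time.monotonic() + timeout
    while True:
        value = check()
        if value is not None:
            return value
        if time.monotonic() >= deadline:
            return None
        time.sleep(interval)
```

With Beaver, `check` could be a lambda that returns `session.peer_remote_vars["gwas_data"].load(inject=False, auto_accept=True)` when `"gwas_data"` is present and `None` otherwise. Using `time.monotonic()` keeps the timeout correct even if the system clock changes.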

In [None]:
# Preview the mock data info (what we can see)
if gwas_data:
    mock = gwas_data.public
    print("Mock Dataset Info (safe to preview):")
    print(f"  Dataset 1: {mock['dataset1_name']}")
    print(f"    Samples: {mock['n_samples_1']}")
    print(f"    Variants: {mock['n_variants_1']}")
    print(f"  Dataset 2: {mock['dataset2_name']}")
    print(f"    Samples: {mock['n_samples_2']}")
    print(f"    Variants: {mock['n_variants_2']}")

## Step 3: Define the GWAS Pipeline Function

This function will be sent to the DO for execution on their private data.
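It calls PLINK 1.9 through the `plink` command, which must be on the `PATH` of whichever machine runs the function: the DO's machine for the real data, and yours for the local mock test in Step 4. The pipeline steps are: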
1. Merge datasets
2. Run PCA for population stratification
3. Perform association testing
4. Generate Manhattan and QQ plots
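Before computing PCs, the function drops strand-ambiguous SNPs (A/T and C/G). The two alleles of these SNPs are complements of each other, so a strand flip between datasets cannot be detected from the alleles alone. The sketch below is a stdlib-only version of that filter over PLINK `.bim` lines (six whitespace-separated columns: chromosome, variant ID, cM, bp, allele 1, allele 2). `ambiguous_snp_ids` and `is_strand_ambiguous` are hypothetical helper names, not part of PLINK or Beaver.

```python
from typing import Iterable, List

# Complementary base for each allele; A/T and C/G pairs are palindromic
COMPLEMENT = {"A": "T", "T": "A", "C": "G", "G": "C"}


def is_strand_ambiguous(a1: str, a2: str) -> bool:
    """True when the allele pair reads the same on both strands (A/T or C/G)."""
    return COMPLEMENT.get(a1.upper()) == a2.upper()


def ambiguous_snp_ids(bim_lines: Iterable[str]) -> List[str]:
    """Return variant IDs from .bim lines whose alleles are A/T or C/G."""
    ids = []
    for line in bim_lines:
        fields = line.split()
        if len(fields) >= 6 and is_strand_ambiguous(fields[4], fields[5]):
            ids.append(fields[1])
    return ids
```

Passing an open `.bim` file handle works because iterating over a file yields lines. You would write the resulting IDs one per line and pass that file to `--exclude`, as the pipeline does.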

In [None]:
@bv
def run_gwas_pipeline(data: dict) -> dict:
    """
    Run a complete GWAS pipeline using PLINK.
    
    Steps:
    1. Merge two datasets
    2. Run PCA for population stratification
    3. Perform logistic regression with PCA covariates
    4. Generate Manhattan and QQ plots
    
    Returns summary statistics (not raw genotypes).
    """
    import os
    import sys
    import subprocess
    import tempfile
    from pathlib import Path
    import pandas as pd
    import numpy as np
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator
    
    print("="*60)
    print("GWAS ANALYSIS PIPELINE")
    print("="*60)
    
    # Extract dataset info
    DATASET1 = data['dataset1_prefix']
    DATASET2 = data['dataset2_prefix']
    
    print(f"Dataset 1: {data['dataset1_name']}")
    print(f"Dataset 2: {data['dataset2_name']}")
    
    # Create work directory
    WORKDIR = Path(tempfile.mkdtemp(prefix="gwas_"))
    COMBINED = WORKDIR / "combined_qc"
    LOGS_DIR = WORKDIR / "logs"
    LOGS_DIR.mkdir(exist_ok=True)
    
    # Parameters
    N_PCS = 10
    THREADS = 4
    PLINK = "plink"
    ANNOTATION_PVAL = 1e-5
    GW_SIG = 5e-8
    
    def run_plink(args, log_name):
        """Run PLINK with logging."""
        cmd = [PLINK] + args
        log_path = LOGS_DIR / f"{log_name}.log"
        print(f"\n>> {' '.join(cmd[:6])}...")
        
        with log_path.open("w") as log_file:
            proc = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
            )
            for line in proc.stdout:
                log_file.write(line)
            ret = proc.wait()
        
        if ret != 0:
            raise subprocess.CalledProcessError(ret, cmd)
        return ret
    
    def count_samples_and_cases(fam_path):
        """Count samples, cases, controls from FAM file."""
        total, cases, controls = 0, 0, 0
        with open(fam_path) as f:
            for line in f:
                if not line.strip():
                    continue
                total += 1
                fields = line.split()
                if len(fields) >= 6:
                    pheno = fields[5]
                    if pheno == "2":
                        cases += 1
                    elif pheno == "1":
                        controls += 1
        return total, cases, controls
    
    def count_snps(bim_path):
        """Count SNPs from BIM file."""
        count = 0
        with open(bim_path) as f:
            for _ in f:
                count += 1
        return count
    
    # ================================================================
    # STEP 1: MERGE DATASETS
    # ================================================================
    print("\n" + "="*60)
    print("STEP 1: Merging datasets")
    print("="*60)
    
    try:
        run_plink(
            ["--bfile", DATASET1, "--bmerge", DATASET2, "--out", str(COMBINED)],
            log_name="merge",
        )
    except subprocess.CalledProcessError:
        # When alleles conflict across datasets (often a strand mismatch),
        # PLINK writes the offending variant IDs to <out>-merge.missnp
        missnp_path = Path(f"{COMBINED}-merge.missnp")
        if not missnp_path.exists():
            raise  # Some other merge error: see logs/merge.log in WORKDIR
        print("Found strand issues, flipping and re-merging...")
        # Write the flipped copy to WORKDIR, not next to the DO's input files
        flipped = str(WORKDIR / "dataset2_flipped")
        run_plink(
            ["--bfile", DATASET2, "--flip", str(missnp_path), "--make-bed", "--out", flipped],
            log_name="flip",
        )
        run_plink(
            ["--bfile", DATASET1, "--bmerge", flipped, "--out", str(COMBINED)],
            log_name="merge_retry",
        )
    
    print("Datasets merged successfully!")
    
    # Get sample counts
    fam_path = Path(f"{COMBINED}.fam")
    total_samples, cases, controls = count_samples_and_cases(fam_path)
    total_snps = count_snps(Path(f"{COMBINED}.bim"))
    
    print(f"Combined: {total_samples} samples, {total_snps} SNPs")
    print(f"Cases: {cases}, Controls: {controls}")
    
    # ================================================================
    # STEP 2: PCA FOR POPULATION STRATIFICATION
    # ================================================================
    print("\n" + "="*60)
    print("STEP 2: PCA for population stratification")
    print("="*60)
    
    # LD pruning
    print("LD pruning...")
    run_plink(
        ["--bfile", str(COMBINED), "--indep-pairwise", "50", "5", "0.2", "--out", f"{COMBINED}_pruned"],
        log_name="ld_prune",
    )
    
    # Extract pruned SNPs
    run_plink(
        ["--bfile", str(COMBINED), "--extract", f"{COMBINED}_pruned.prune.in", "--make-bed", "--out", f"{COMBINED}_pruned_data"],
        log_name="extract_pruned",
    )
    
    # Remove AT/GC SNPs
    atgc_snps = WORKDIR / "atgc_snps.txt"
    with open(f"{COMBINED}_pruned_data.bim") as infile, open(atgc_snps, "w") as outfile:
        for line in infile:
            fields = line.strip().split()
            if len(fields) >= 6:
                rsid, a1, a2 = fields[1], fields[4], fields[5]
                if (a1, a2) in [("A", "T"), ("T", "A"), ("C", "G"), ("G", "C")]:
                    outfile.write(rsid + "\n")
    
    run_plink(
        ["--bfile", f"{COMBINED}_pruned_data", "--exclude", str(atgc_snps), "--make-bed", "--out", f"{COMBINED}_pruned_noambig"],
        log_name="remove_ambig",
    )
    
    # Run PCA
    print(f"Computing {N_PCS} principal components...")
    run_plink(
        ["--bfile", f"{COMBINED}_pruned_noambig", "--pca", str(N_PCS), "--threads", str(THREADS), "--out", f"{COMBINED}_pca"],
        log_name="pca",
    )
    
    # Add header to eigenvec
    eigenvec_path = Path(f"{COMBINED}_pca.eigenvec")
    header = " ".join(["FID", "IID"] + [f"PC{i}" for i in range(1, N_PCS + 1)]) + "\n"
    content = eigenvec_path.read_text()
    eigenvec_path.write_text(header + content)
    
    print("PCA complete!")
    
    # ================================================================
    # STEP 3: ASSOCIATION TESTING
    # ================================================================
    print("\n" + "="*60)
    print("STEP 3: Association testing")
    print("="*60)
    
    covar_names = ",".join([f"PC{i}" for i in range(1, N_PCS + 1)])
    
    print("Running logistic regression...")
    run_plink(
        ["--bfile", str(COMBINED), "--logistic", "hide-covar", "--covar", str(eigenvec_path),
         "--covar-name", covar_names, "--ci", "0.95", "--threads", str(THREADS), "--out", f"{COMBINED}_gwas"],
        log_name="gwas",
    )
    print("Association testing complete!")
    
    # ================================================================
    # STEP 4: GENERATE PLOTS
    # ================================================================
    print("\n" + "="*60)
    print("STEP 4: Generating plots")
    print("="*60)
    
    # Load results
    gwas_results = f"{COMBINED}_gwas.assoc.logistic"
    df = pd.read_csv(gwas_results, sep=r"\s+")
    
    # Clean P values
    df = df[df['P'].notna()]
    df = df[df['P'] != 'NA']
    df['P'] = pd.to_numeric(df['P'], errors='coerce')
    df = df.dropna(subset=['P'])
    df = df[(df['P'] > 0) & (df['P'] <= 1)]
    
    # Calculate -log10(p)
    df['NEGLOG10P'] = -np.log10(df['P'])
    df = df[np.isfinite(df['NEGLOG10P'])]
    
    # Convert CHR to numeric
    df['CHR'] = df['CHR'].replace({'X': 23, 'Y': 24, 'MT': 25, 'M': 25})
    df['CHR'] = pd.to_numeric(df['CHR'], errors='coerce')
    df = df.dropna(subset=['CHR', 'BP'])
    df['CHR'] = df['CHR'].astype(int)
    df['BP'] = df['BP'].astype(int)
    
    print(f"Loaded {len(df):,} SNPs")
    
    # Count significant
    gw_significant = df[df['P'] < GW_SIG]
    suggestive = df[df['P'] < ANNOTATION_PVAL]
    
    print(f"Genome-wide significant (P < {GW_SIG}): {len(gw_significant)}")
    print(f"Suggestive (P < {ANNOTATION_PVAL}): {len(suggestive)}")
    
    # --- Manhattan Plot ---
    print("\nGenerating Manhattan plot...")
    df = df.sort_values(['CHR', 'BP']).reset_index(drop=True)
    
    df['cumulative_pos'] = 0
    chr_centers = []
    last_pos = 0
    
    for chrom in sorted(df['CHR'].unique()):
        chr_df = df[df['CHR'] == chrom]
        chr_len = chr_df['BP'].max()
        df.loc[df['CHR'] == chrom, 'cumulative_pos'] = chr_df['BP'] + last_pos
        chr_centers.append(last_pos + chr_len / 2)
        last_pos += chr_len
    
    fig, ax = plt.subplots(figsize=(16, 6))
    colors = ['#3182bd', '#9ecae1']
    chr_list = sorted(df['CHR'].unique())
    
    for idx, chrom in enumerate(chr_list):
        chr_df = df[df['CHR'] == chrom]
        ax.scatter(chr_df['cumulative_pos'], chr_df['NEGLOG10P'],
                  c=colors[idx % 2], s=5, alpha=0.7, linewidths=0)
    
    gw_line = -np.log10(GW_SIG)
    sugg_line = -np.log10(ANNOTATION_PVAL)
    
    ax.axhline(y=gw_line, color='red', linestyle='--', linewidth=1.5,
               label=f'Genome-wide sig. (P={GW_SIG:.0e})', alpha=0.7)
    ax.axhline(y=sugg_line, color='blue', linestyle='--', linewidth=1,
               label=f'Suggestive (P={ANNOTATION_PVAL:.0e})', alpha=0.7)
    
    # Annotate top SNPs
    top_snps = df[df['P'] < ANNOTATION_PVAL].copy()
    if len(top_snps) > 20:
        top_snps = top_snps.nsmallest(20, 'P')
    
    for _, snp in top_snps.iterrows():
        ax.annotate(snp['SNP'],
                   xy=(snp['cumulative_pos'], snp['NEGLOG10P']),
                   xytext=(5, 5), textcoords='offset points',
                   fontsize=7, alpha=0.8)
        ax.scatter([snp['cumulative_pos']], [snp['NEGLOG10P']],
                  c='red', s=30, marker='D', zorder=5)
    
    ax.set_xticks(chr_centers)
    ax.set_xticklabels([str(c) for c in chr_list])
    ax.set_xlabel('Chromosome', fontsize=12, fontweight='bold')
    ax.set_ylabel('-log₁₀(P)', fontsize=12, fontweight='bold')
    ax.set_title('Genome-Wide Association Study Results', fontsize=14, fontweight='bold')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.2)
    plt.tight_layout()
    plt.show()
    
    # --- QQ Plot ---
    print("Generating QQ plot...")
    pvals = df['P'].dropna()
    pvals = pvals[pvals > 0]
    
    observed = -np.log10(sorted(pvals))
    n = len(observed)
    expected = -np.log10(np.arange(1, n + 1) / (n + 1))
    
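    # Genomic inflation factor: median 1-df chi-square statistic divided by
    # its expected median under the null (~0.455). The p -> chi-square map is
    # monotonic, so converting the median p-value is enough.
    from statistics import NormalDist
    std_normal = NormalDist()
    median_chisq = std_normal.inv_cdf(float(np.median(pvals)) / 2) ** 2
    lambda_gc = median_chisq / std_normal.inv_cdf(0.75) ** 2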
    
    fig, ax = plt.subplots(figsize=(7, 7))
    ax.scatter(expected, observed, s=10, alpha=0.6, c='#3182bd')
    max_val = max(max(expected), max(observed))
    ax.plot([0, max_val], [0, max_val], 'r--', linewidth=2, alpha=0.7, label='Expected')
    ax.text(0.05, 0.95, f'λ = {lambda_gc:.3f}', transform=ax.transAxes, fontsize=11,
           verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    ax.set_xlabel('Expected -log₁₀(P)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Observed -log₁₀(P)', fontsize=12, fontweight='bold')
    ax.set_title('QQ Plot', fontsize=14, fontweight='bold')
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.2)
    plt.tight_layout()
    plt.show()
    
    # ================================================================
    # PREPARE RESULTS
    # ================================================================
    print("\n" + "="*60)
    print("GWAS ANALYSIS COMPLETE!")
    print("="*60)
    
    # Get top SNPs list
    top_snps_list = []
    if len(suggestive) > 0:
        for _, row in suggestive.nsmallest(20, 'P').iterrows():
            top_snps_list.append({
                'snp': row['SNP'],
                'chr': int(row['CHR']),
                'bp': int(row['BP']),
                'p': float(row['P']),
                'or': float(row['OR']) if 'OR' in row else None,
            })
    
    # Return summary (NOT raw genotypes)
    return {
        'status': 'complete',
        'total_samples': total_samples,
        'cases': cases,
        'controls': controls,
        'total_snps': total_snps,
        'tested_snps': len(df),
        'gw_significant_count': len(gw_significant),
        'suggestive_count': len(suggestive),
        'lambda_gc': float(lambda_gc),
        'top_snps': top_snps_list,
        'n_pcs': N_PCS,
    }
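The Manhattan plot places chromosomes end to end. It offsets each base-pair position by the summed lengths of the preceding chromosomes, using each chromosome's largest observed `BP` as its length. The sketch below reproduces that arithmetic in plain Python so it can be checked in isolation. `cumulative_positions` is a hypothetical name, and the input is assumed to be a list of `(chr, bp)` integer pairs.

```python
from typing import Dict, List, Tuple


def cumulative_positions(
    variants: List[Tuple[int, int]],
) -> Tuple[List[int], Dict[int, float]]:
    """Map (chrom, bp) pairs to x-axis positions for a Manhattan plot.

    Returns the cumulative position of each variant (in input order) and the
    x-axis center of each chromosome, for tick labels.
    """
    # Chromosome "length" = largest observed bp, as in the pipeline above
    chrom_len: Dict[int, int] = {}
    for chrom, bp in variants:
        chrom_len[chrom] = max(chrom_len.get(chrom, 0), bp)

    offsets: Dict[int, int] = {}
    centers: Dict[int, float] = {}
    running = 0
    for chrom in sorted(chrom_len):
        offsets[chrom] = running
        centers[chrom] = running + chrom_len[chrom] / 2
        running += chrom_len[chrom]

    positions = [bp + offsets[chrom] for chrom, bp in variants]
    return positions, centers
```

Because the "length" is the largest observed position rather than the true chromosome length, sparsely genotyped chromosomes appear narrower. That is usually acceptable for a visual overview.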

## Step 4: Test on Mock Data

Run the GWAS pipeline on mock data first to verify it works.
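Before requesting a run on real data, it can help to check that the mock summary is internally consistent. The sketch below assumes `r` is the dict returned by `run_gwas_pipeline` and checks only the keys that function returns. `check_gwas_summary` is a hypothetical helper, not part of Beaver, and its default threshold mirrors `ANNOTATION_PVAL`.

```python
from typing import Any, Dict, List


def check_gwas_summary(r: Dict[str, Any], suggestive_p: float = 1e-5) -> List[str]:
    """Return a list of consistency problems found in a GWAS summary dict."""
    problems = []
    if r.get("status") != "complete":
        problems.append(f"status is {r.get('status')!r}, expected 'complete'")
    if r.get("cases", 0) + r.get("controls", 0) > r.get("total_samples", 0):
        problems.append("cases + controls exceeds total_samples")
    if r.get("tested_snps", 0) > r.get("total_snps", 0):
        problems.append("tested_snps exceeds total_snps")
    # Every genome-wide significant SNP is also suggestive (5e-8 < 1e-5)
    if r.get("gw_significant_count", 0) > r.get("suggestive_count", 0):
        problems.append("more genome-wide significant SNPs than suggestive ones")
    for snp in r.get("top_snps", []):
        if not 0 < snp["p"] < suggestive_p:
            problems.append(f"{snp['snp']}: p={snp['p']} outside (0, {suggestive_p})")
    return problems
```

After the cell below has run, `check_gwas_summary(result.public)` should return an empty list.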

In [None]:
result = None  # Defined up front so later cells don't raise NameError
if gwas_data:
    print("Testing GWAS pipeline on mock data...\n")
    print("(Note: Mock data is small, so results will be limited)\n")
    
    result = run_gwas_pipeline(gwas_data)
    
    print("\n=== Mock Result ===")
    if result.public:
        r = result.public
        print(f"Status: {r.get('status', 'unknown')}")
        print(f"Samples: {r.get('total_samples', 'N/A')}")
        print(f"SNPs tested: {r.get('tested_snps', 'N/A')}")

In [None]:
# View captured figures
if result and result.public_figures:
    print(f"Captured {len(result.public_figures)} figure(s):")
    result.show_figures("public")

## Step 5: Request GWAS on Real Data

In [None]:
if result:
    result.request_private()
    print("\nGWAS request sent to DO!")
    print("Run DO notebook Steps 5-9 now...")

## Step 6: Monitor GWAS Progress

Watch the DO's progress via the live variable.
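The monitoring cell in this step prints a row whenever the reported step changes and stops once `status` is `"complete"`. The sketch below expresses the same logic as a pure function over a sequence of progress snapshots, which makes it easy to test without a live session. It assumes each snapshot is a dict with the `step`, `total_steps`, `current_task` and `status` keys the cell reads. `progress_rows` is a hypothetical name.

```python
from typing import Dict, Iterable, Iterator


def progress_rows(snapshots: Iterable[Dict]) -> Iterator[str]:
    """Yield one formatted table row per step change, ending at 'complete'."""
    last_step = None
    for p in snapshots:
        step = p.get("step", 0)
        status = p.get("status", "unknown")
        if step != last_step or status in ("complete", "starting"):
            last_step = step
            yield "{:^10} {:^20} {:^15}".format(
                f"{step}/{p.get('total_steps', '?')}",
                p.get("current_task", "")[:20],
                status,
            )
        if status == "complete":
            return
```

Because the function is a generator, it can consume a lazily produced stream of snapshots, such as one reloaded from the live variable every couple of seconds.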

In [None]:
print("Waiting for progress variable...")

progress = None
for _ in range(60):
    peer_vars = session.peer_remote_vars
    if "gwas_progress" in peer_vars:
        progress = peer_vars["gwas_progress"].load(inject=False, auto_accept=True)
        print("\nFound gwas_progress!")
        break
    time.sleep(1)
    print(".", end="", flush=True)

In [None]:
# Monitor GWAS progress
if progress:
    print("\nMonitoring GWAS progress...\n")
    print("{:^10} {:^20} {:^15}".format("Step", "Task", "Status"))
    print("-" * 50)
    
    last_step = -1
    for _ in range(300):  # Poll every 2 s for up to ~10 minutes
        try:
            progress = session.peer_remote_vars["gwas_progress"].load(
                inject=False, auto_accept=True
            )
            p = progress.public
            
            step = p.get("step", 0)
            status = p.get("status", "unknown")
            
            if step != last_step or status in ["complete", "starting"]:
                last_step = step
                task = p.get("current_task", "")
                total = p.get("total_steps", "?")
                
                print("{:^10} {:^20} {:^15}".format(
                    f"{step}/{total}",
                    task[:20],
                    status
                ))
                
                if status == "complete":
                    print("\nGWAS analysis complete!")
                    break
        except Exception:
            pass
        
        time.sleep(2)

## Step 7: Receive GWAS Results

In [None]:
# Wait for approved results
final_result = bv.wait_for_response(result, timeout=900)  # 15 min timeout

if final_result:
    print("\n" + "="*60)
    print("GWAS RESULTS RECEIVED!")
    print("="*60)
    
    private = final_result.private
    print("\nSummary Statistics:")
    print(f"  Total Samples: {private.get('total_samples', 'N/A')}")
    print(f"  Cases: {private.get('cases', 'N/A')}")
    print(f"  Controls: {private.get('controls', 'N/A')}")
    print(f"  Total SNPs: {private.get('total_snps', 'N/A')}")
    print(f"  Tested SNPs: {private.get('tested_snps', 'N/A')}")
    print("\nResults:")
    print(f"  Genome-wide significant: {private.get('gw_significant_count', 0)}")
    print(f"  Suggestive: {private.get('suggestive_count', 0)}")
    # Format lambda only when present; ':.4f' would fail on an 'N/A' string
    lambda_gc = private.get('lambda_gc')
    lambda_str = f"{lambda_gc:.4f}" if lambda_gc is not None else "N/A"
    print(f"  Genomic inflation (λ): {lambda_str}")
else:
    print("Timeout waiting for results.")

In [None]:
# Display top SNPs
import math

if final_result and final_result.private.get('top_snps'):
    print("\n=== Top Significant SNPs ===")
    print("{:<15} {:>5} {:>12} {:>12} {:>8}".format("SNP", "CHR", "BP", "P-value", "OR"))
    print("-" * 55)
    
    for snp in final_result.private['top_snps'][:10]:
        # Show "N/A" instead of a misleading 0 when the OR is missing or NaN
        odds_ratio = snp.get('or')
        if odds_ratio is None or math.isnan(odds_ratio):
            or_str = "N/A"
        else:
            or_str = f"{odds_ratio:.2f}"
        print("{:<15} {:>5} {:>12,} {:>12.2e} {:>8}".format(
            snp['snp'][:15],
            snp['chr'],
            snp['bp'],
            snp['p'],
            or_str,
        ))

In [None]:
# Display GWAS figures from real data
if final_result and final_result.private_figures:
    print(f"\n=== GWAS Plots from REAL Data ({len(final_result.private_figures)} figures) ===")
    final_result.show_figures("private")

In [None]:
# Compare mock vs real
if final_result:
    print("\n=== Mock vs Real Comparison ===")
    print("{:<25} {:>15} {:>15}".format("", "Mock Data", "Real Data"))
    print("-" * 55)
    
    mock = result.public
    real = final_result.private
    
    print("{:<25} {:>15} {:>15}".format(
        "Samples",
        mock.get('total_samples', 'N/A'),
        real.get('total_samples', 'N/A')
    ))
    print("{:<25} {:>15} {:>15}".format(
        "SNPs Tested",
        mock.get('tested_snps', 'N/A'),
        real.get('tested_snps', 'N/A')
    ))
    print("{:<25} {:>15} {:>15}".format(
        "GW Significant",
        mock.get('gw_significant_count', 0),
        real.get('gw_significant_count', 0)
    ))
    print("{:<25} {:>15} {:>15}".format(
        "Suggestive",
        mock.get('suggestive_count', 0),
        real.get('suggestive_count', 0)
    ))

## Summary

Congratulations! You've completed a full privacy-preserving GWAS analysis:

1. **Received mock data info** to understand the dataset structure
2. **Defined a GWAS pipeline** using PLINK commands
3. **Tested locally** on mock data
4. **Sent analysis request** to the data owner
5. **Monitored progress** via live variables
6. **Received results** including plots and significant SNPs

### What You Got

- Manhattan plot showing association signals
- QQ plot with genomic inflation factor
- Summary statistics (sample counts, SNP counts)
- List of significant/suggestive SNPs with p-values
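The genomic inflation factor λ compares the median association chi-square statistic with its expected median under the null hypothesis, which is about 0.455 for 1 degree of freedom. Values near 1 suggest little inflation. Values well above 1 can point to residual population stratification, cryptic relatedness or other confounding, although highly polygenic traits also raise λ. The stdlib-only sketch below converts two-sided p-values back to 1-df chi-square statistics via the standard normal quantile function. `genomic_inflation` and `chisq1_from_p` are hypothetical names.

```python
from statistics import NormalDist, median
from typing import Sequence

_STD_NORMAL = NormalDist()
# Median of a 1-df chi-square: square of the 75th percentile of N(0, 1), ~0.4549
CHISQ1_MEDIAN = _STD_NORMAL.inv_cdf(0.75) ** 2


def chisq1_from_p(p: float) -> float:
    """Two-sided p-value -> 1-df chi-square statistic (z squared).

    Uses the lower tail, inv_cdf(p / 2), so very small p-values stay valid.
    """
    return _STD_NORMAL.inv_cdf(p / 2) ** 2


def genomic_inflation(pvalues: Sequence[float]) -> float:
    """lambda_GC = median observed chi-square / expected null median."""
    stats = [chisq1_from_p(p) for p in pvalues if 0 < p <= 1]
    return median(stats) / CHISQ1_MEDIAN
```

Evenly spaced p-values, which are what an uninflated null produces, give λ = 1. Halving every p-value pushes λ well above 1.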

### What You Didn't Get

- Raw genotype data (.bed files)
- Individual-level genotypes
- Full association results for all SNPs

### Privacy Preserved!

The data owner maintained control over their genomic data while you were able to run a complete GWAS analysis and receive aggregate results.