Analysis of SAE atoms that specialize in detecting Alu-mediated deletions.
This notebook identifies, validates, and characterizes Alu specialist atoms using:
- Statistical bias analysis
- UCSC RepeatMasker validation
- Genomic pattern analysis
- Biological interpretation

In [None]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter, defaultdict
import time
import requests
import json
from scipy import stats
from scipy.stats import chi2_contingency, fisher_exact, mannwhitneyu
import warnings
warnings.filterwarnings('ignore')

# Configuration: all paths are relative to the notebook's working directory.
processed_dir, models_dir, figures_dir = (
    "../data/processed",
    "../data/models",
    "../figures",
)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

In [None]:
# Load SAE Results and Statistics

# Load the data produced by previous notebooks
print("Loading SAE analysis results...")

# Load SAE embeddings and metadata.
# NOTE(review): `sv_info` is presumably one dict per variant (later cells read
# 'chrom', 'pos', 'svlen', 'svtype', 'truvari_class') — confirm against the producer.
sae_data = torch.load(f"{processed_dir}/sae_sv_embeddings.pt", map_location=device)
test_sv_info = sae_data["sv_info"]
print(f"Loaded {len(test_sv_info)} SV samples")

# Load atom statistics from interpretability analysis
try:
    stats_df = pd.read_csv(f"{models_dir}/sae_atom_statistics.csv")
    print(f"Loaded statistics for {len(stats_df)} atoms")
except FileNotFoundError:
    print("SAE statistics not found - run interpretability analysis first")
    stats_df = None

# Load SAE activations (assuming they exist from training).
# Bind both names up front so later cells can test `acts is not None`
# instead of failing with NameError when the files are missing.
acts = None
labels = None
try:
    acts = torch.load(f"{models_dir}/sae_activations.pt", map_location=device)
    labels = torch.load(f"{models_dir}/sae_labels.pt", map_location=device)
    print(f"Loaded activations: {acts.shape}")
except FileNotFoundError:
    print("SAE activations not found")

In [None]:
## Identify Alu Specialist Candidates

def identify_alu_specialist_candidates(stats_df, test_sv_info, acts, labels, min_activations=15):
    """
    Identify SAE atoms whose firing variants look like Alu-mediated deletions.

    An atom is screened when its ``support`` in ``stats_df`` is at least
    ``min_activations``. The variants it fires on are profiled by
    ``analyze_alu_potential``, which considers deletion rate, the fraction of sizes
    in 250-350 bp, the fraction in 290-310 bp, and median size / spread. The atom
    is kept when ``meets_alu_criteria`` accepts that profile. By default that means
    an Alu score >= 5, a deletion rate > 0.5, and more than 30% of variants sized
    250-350 bp.

    Parameters
    ----------
    stats_df : pandas.DataFrame or None
        Per-atom statistics; needs at least ``atom`` and ``support`` columns
        (``analyze_alu_potential`` also reads ``odds`` and ``p``).
    test_sv_info : sequence of dict
        Variant metadata, indexed by the row positions of ``acts``.
    acts : torch.Tensor or None
        Compared elementwise with each atom id and reduced with ``.any(dim=1)``.
        Presumably a (n_variants, k) tensor of active atom indices per variant —
        TODO confirm against the training notebook.
    labels
        Unused; kept for interface compatibility.
    min_activations : int
        Minimum ``support`` for an atom to be analysed.

    Returns
    -------
    dict
        ``{atom_id: analysis_dict}`` for accepted atoms; ``{}`` when ``stats_df``
        or ``acts`` is None.
    """

    print("IDENTIFYING ALU SPECIALIST CANDIDATES")
    print("=" * 50)

    if stats_df is None or acts is None:
        print("Missing required data")
        return {}

    candidates = {}

    # Only atoms with enough recorded activations give a meaningful profile.
    valid_atoms = stats_df[stats_df['support'] >= min_activations]['atom'].values
    print(f"Analyzing {len(valid_atoms)} atoms with ≥{min_activations} activations")

    for atom_id in valid_atoms:
        # Row positions (variants) of `acts` in which this atom id appears.
        firing_mask = (acts == atom_id).any(dim=1)
        firing_indices = torch.where(firing_mask)[0].cpu().numpy()
        if len(firing_indices) == 0:
            continue

        atom_variants = [test_sv_info[idx] for idx in firing_indices]
        analysis = analyze_alu_potential(atom_id, atom_variants, stats_df)

        if meets_alu_criteria(analysis):
            candidates[atom_id] = analysis
            print(f"✓ Atom {atom_id}: {analysis['alu_score']:.1f} Alu score")

    print(f"\n Found {len(candidates)} Alu specialist candidates")
    return candidates

def _tier_points(value, tiers):
    """Return the points of the first ``(threshold, points)`` tier that ``value`` strictly exceeds, else 0."""
    for threshold, points in tiers:
        if value > threshold:
            return points
    return 0


def analyze_alu_potential(atom_id, variants, stats_df):
    """Profile an atom's firing variants against an Alu-deletion signature.

    Parameters
    ----------
    atom_id
        Atom identifier; used to look up the atom's row in ``stats_df``.
    variants : list of dict
        Variant records. Reads ``svlen`` (absolute value taken; missing/None -> 0),
        ``svtype`` and ``truvari_class`` (missing -> ``'UNK'``).
    stats_df : pandas.DataFrame
        Needs columns ``atom``, ``odds``, ``p`` and ``support``. When
        ``variants`` is non-empty it must contain a row for ``atom_id``,
        otherwise ``.iloc[0]`` raises IndexError.

    Returns
    -------
    dict
        Size/type lists, counts and rates, ``median_size``, ``size_cv``
        (std/mean of sizes; inf when the mean is 0), and ``alu_score`` (0-10).
        The score is the sum of four parts:
        - deletion rate: up to 3 points;
        - fraction of sizes in 250-350 bp: up to 4 points;
        - fraction in 290-310 bp: up to 2 points;
        - median in 280-320 bp with CV < 0.4: 1 point.
        The atom's ``odds_ratio``, ``p_value`` and ``support`` are included for
        non-empty input. For an empty ``variants`` list, every count/rate metric
        is zero, so callers such as ``meets_alu_criteria`` can still read them.
    """
    if not variants:
        return {
            'atom_id': atom_id,
            'total_variants': 0,
            'sizes': [],
            'types': [],
            'truvari_classes': [],
            'deletion_count': 0,
            'deletion_rate': 0.0,
            'alu_sized_count': 0,
            'alu_sized_rate': 0.0,
            'optimal_alu_count': 0,
            'optimal_alu_rate': 0.0,
            'median_size': 0,
            'size_cv': float('inf'),
            'tp_count': 0,
            'tp_rate': 0.0,
            'alu_score': 0,
        }

    n_variants = len(variants)
    # `or 0` also covers records where svlen is present but None.
    sizes = [abs(v.get('svlen') or 0) for v in variants]
    types = [v.get('svtype', 'UNK') for v in variants]
    truvari_classes = [v.get('truvari_class', 'UNK') for v in variants]

    atom_stats = stats_df[stats_df['atom'] == atom_id].iloc[0]

    analysis = {
        'atom_id': atom_id,
        'total_variants': n_variants,
        'sizes': sizes,
        'types': types,
        'truvari_classes': truvari_classes,
        'odds_ratio': atom_stats['odds'],
        'p_value': atom_stats['p'],
        'support': atom_stats['support'],
    }

    deletions = sum(1 for t in types if t == 'DEL')
    alu_sized = sum(1 for s in sizes if 250 <= s <= 350)
    optimal_alu = sum(1 for s in sizes if 290 <= s <= 310)
    tp_count = sum(1 for tc in truvari_classes if tc in ('TP', 'tp_comp_vcf'))

    deletion_rate = deletions / n_variants
    alu_sized_rate = alu_sized / n_variants
    optimal_alu_rate = optimal_alu / n_variants
    tp_rate = tp_count / n_variants

    median_size = np.median(sizes)
    mean_size = np.mean(sizes)
    size_cv = np.std(sizes) / mean_size if mean_size > 0 else float('inf')

    alu_score = (
        _tier_points(deletion_rate, [(0.8, 3), (0.6, 2), (0.4, 1)])        # deletion bias (max 3)
        + _tier_points(alu_sized_rate, [(0.7, 4), (0.5, 3), (0.3, 2)])     # size clustering (max 4)
        + _tier_points(optimal_alu_rate, [(0.3, 2), (0.15, 1)])            # optimal size (max 2)
        + (1 if 280 <= median_size <= 320 and size_cv < 0.4 else 0)        # size precision (max 1)
    )

    analysis.update({
        'deletion_count': deletions,
        'deletion_rate': deletion_rate,
        'alu_sized_count': alu_sized,
        'alu_sized_rate': alu_sized_rate,
        'optimal_alu_count': optimal_alu,
        'optimal_alu_rate': optimal_alu_rate,
        'median_size': median_size,
        'size_cv': size_cv,
        'tp_count': tp_count,
        'tp_rate': tp_rate,
        'alu_score': alu_score,
    })

    return analysis

def meets_alu_criteria(analysis, min_score=5, min_deletion_rate=0.5, min_alu_sized_rate=0.3):
    """Return True when an atom's analysis qualifies it as an Alu specialist.

    Parameters
    ----------
    analysis : dict
        Output of ``analyze_alu_potential``. Missing keys count as 0, so partial
        analyses are rejected rather than raising KeyError.
    min_score : int
        Inclusive lower bound on ``alu_score``.
    min_deletion_rate : float
        ``deletion_rate`` must be strictly greater than this.
    min_alu_sized_rate : float
        ``alu_sized_rate`` must be strictly greater than this.
    """
    return (
        analysis.get('alu_score', 0) >= min_score
        and analysis.get('deletion_rate', 0.0) > min_deletion_rate
        and analysis.get('alu_sized_rate', 0.0) > min_alu_sized_rate
    )

In [None]:
# Run candidate identification.
# Default to an empty result so the validation, statistics and plotting cells
# still run (on no candidates) when statistics or activations were not loaded.
alu_candidates = {}
if stats_df is not None and acts is not None:
    alu_candidates = identify_alu_specialist_candidates(stats_df, test_sv_info, acts, labels)

In [None]:
# UCSC RepeatMasker Validation

def query_ucsc_repeatmasker(chrom, start, end, max_retries=3, genome='hg38', timeout=30):
    """Fetch RepeatMasker (``rmsk``) records for a region from the UCSC REST API.

    Parameters
    ----------
    chrom : str
        Chromosome name passed straight to the API. NOTE(review): UCSC expects
        its own naming (e.g. "chr1"); confirm the variant records use it.
    start, end : int
        Region bounds, passed through unchanged as the API's ``start``/``end``.
    max_retries : int
        Total number of attempts.
    genome : str
        UCSC assembly name (default ``'hg38'``).
    timeout : float
        Per-request timeout in seconds.

    Returns
    -------
    list
        The records under the ``rmsk`` key of the JSON response, or ``[]`` when
        every attempt failed. Callers therefore cannot tell "no repeats" from
        "query failed" by the return value alone; failures are printed.
    """

    url = "https://api.genome.ucsc.edu/getData/track"
    track = 'rmsk'
    params = {
        'genome': genome,
        'track': track,
        'chrom': chrom,
        'start': start,
        'end': end
    }

    for attempt in range(max_retries):
        try:
            response = requests.get(url, params=params, timeout=timeout)
            if response.status_code == 200:
                return response.json().get(track, [])
            print(f"API error {response.status_code}")
            # Client errors other than rate limiting (429), e.g. an unknown
            # chromosome, will not succeed on retry.
            if 400 <= response.status_code < 500 and response.status_code != 429:
                break
        except requests.RequestException as e:
            print(f"Request failed (attempt {attempt + 1}): {e}")

        # Exponential backoff before every retry, including after a non-200 reply.
        if attempt < max_retries - 1:
            time.sleep(2 ** attempt)

    return []

def analyze_alu_overlap(sv_info, repeatmasker_hits, max_distance=100):
    """Find Alu elements that overlap or lie near a structural variant.

    The variant is treated as the interval ``[pos, pos + |svlen|)``. A missing
    or None ``pos``/``svlen`` counts as 0. A RepeatMasker record counts as an Alu
    hit when ``repClass == 'SINE'`` and ``repFamily == 'Alu'``, and it either
    overlaps that interval or is within ``max_distance`` bp of it.

    NOTE(review): the same span formula is applied to every ``svtype``; for
    insertions ``svlen`` is not a reference span — confirm this is intended.

    Parameters
    ----------
    sv_info : dict
        Variant record; reads ``pos`` and ``svlen``.
    repeatmasker_hits : list of dict
        Records as returned by the UCSC ``rmsk`` track (``repClass``,
        ``repFamily``, ``repName``, ``strand``, ``genoStart``, ``genoEnd``,
        ``swScore``).
    max_distance : int
        Maximum gap (bp) between variant and element for a non-overlapping hit.

    Returns
    -------
    dict
        ``validated`` (any hit), ``alu_count``, ``all_hits`` in input order, and
        ``primary_subfamily``. The primary subfamily is the ``repName`` of the
        best hit: largest overlap, then highest ``swScore``, then smallest gap.
        It is ``'None'`` when there are no hits.
    """

    sv_start = sv_info.get('pos') or 0
    sv_end = sv_start + abs(sv_info.get('svlen') or 0)

    alu_hits = []
    for hit in repeatmasker_hits:
        if hit.get('repClass') != 'SINE' or hit.get('repFamily') != 'Alu':
            continue

        hit_start = hit.get('genoStart', 0)
        hit_end = hit.get('genoEnd', 0)

        overlap = max(0, min(sv_end, hit_end) - max(sv_start, hit_start))
        # True gap between the two intervals; 0 when they touch or overlap.
        distance = max(0, hit_start - sv_end, sv_start - hit_end)

        if overlap > 0 or distance <= max_distance:
            alu_hits.append({
                'repName': hit.get('repName', ''),
                'repFamily': hit.get('repFamily', ''),
                'strand': hit.get('strand', ''),
                'genoStart': hit_start,
                'genoEnd': hit_end,
                'overlap': overlap,
                'distance': distance,
                'swScore': hit.get('swScore', 0)
            })

    if alu_hits:
        best = max(alu_hits, key=lambda h: (h['overlap'], h['swScore'], -h['distance']))
        primary_subfamily = best['repName']
    else:
        primary_subfamily = 'None'

    return {
        'validated': len(alu_hits) > 0,
        'alu_count': len(alu_hits),
        'primary_subfamily': primary_subfamily,
        'all_hits': alu_hits
    }

def validate_alu_candidates_with_repeatmasker(candidates, test_sv_info, acts, max_per_atom=50,
                                              random_state=None):
    """Check candidate atoms' firing variants against UCSC RepeatMasker Alu records.

    For each atom id in ``candidates``, the variants it fires on (rows of
    ``acts`` containing that id) are collected. When there are more than
    ``max_per_atom`` of them, a random subset of that size is drawn without
    replacement. Each variant's span ±200 bp is queried with
    ``query_ucsc_repeatmasker`` and scored with ``analyze_alu_overlap``.

    Parameters
    ----------
    candidates : dict
        Only the keys (atom ids) are used.
    test_sv_info : sequence of dict
        Variant records indexed by row of ``acts``; reads ``chrom``, ``pos``,
        ``svlen``, ``svtype``, ``truvari_class``.
    acts : torch.Tensor
        Compared elementwise with each atom id and reduced with ``.any(dim=1)``.
        Presumably (n_variants, k) active-atom indices — TODO confirm.
    max_per_atom : int
        Cap on API queries per atom.
    random_state : int, numpy.random.Generator or None
        Seed for the subsampling; pass a value to make the sampled variants,
        and therefore the validation rates, reproducible.

    Returns
    -------
    dict
        ``{atom_id: [validation dict, ...]}``. Each dict is the
        ``analyze_alu_overlap`` result plus ``variant_idx``, ``chrom``, ``pos``,
        ``svlen``, ``svtype`` and ``truvari_class``.

    Variants without a chromosome are skipped rather than queried.
    NOTE(review): a query that fails inside ``query_ucsc_repeatmasker`` may come
    back as an empty hit list and would then count as "not validated", so rates
    are a lower bound when API errors are printed.
    """

    print("UCSC REPEATMASKER VALIDATION")
    print("=" * 40)

    rng = np.random.default_rng(random_state)
    validation_results = {}

    for atom_id in candidates:
        print(f"\n Validating Atom {atom_id}")

        # Row positions (variants) of `acts` in which this atom id appears.
        firing_mask = (acts == atom_id).any(dim=1)
        firing_indices = torch.where(firing_mask)[0].cpu().numpy()

        # Limit to a reasonable number for the API; keep the sample in index order.
        if len(firing_indices) > max_per_atom:
            firing_indices = np.sort(rng.choice(firing_indices, max_per_atom, replace=False))

        atom_validations = []

        for i, idx in enumerate(firing_indices):
            sv = test_sv_info[idx]
            chrom = sv.get('chrom')

            print(f"\r   Processing {i+1}/{len(firing_indices)}: {chrom}:{sv.get('pos')}", end="")

            if not chrom:
                print(" [Skipped: missing chrom]", end="")
                continue

            try:
                # Query RepeatMasker with a ±200bp window around the variant span.
                pos = sv.get('pos', 0)
                window_start = max(0, pos - 200)
                window_end = pos + abs(sv.get('svlen', 0)) + 200

                repeatmasker_hits = query_ucsc_repeatmasker(chrom, window_start, window_end)
                validation = analyze_alu_overlap(sv, repeatmasker_hits)
            except Exception as e:  # best-effort: one malformed variant must not abort the run
                print(f" [Error: {e}]", end="")
                continue

            validation.update({
                'variant_idx': int(idx),
                'chrom': chrom,
                'pos': sv.get('pos'),
                'svlen': sv.get('svlen'),
                'svtype': sv.get('svtype'),
                'truvari_class': sv.get('truvari_class')
            })
            atom_validations.append(validation)

            # Rate limiting between API calls.
            time.sleep(0.3)

        validation_results[atom_id] = atom_validations

        validated_count = sum(1 for v in atom_validations if v['validated'])
        validation_rate = validated_count / len(atom_validations) * 100 if atom_validations else 0

        print(f"\n    Validation rate: {validation_rate:.1f}% ({validated_count}/{len(atom_validations)})")

    return validation_results

In [None]:
# Run RepeatMasker validation on every candidate atom found above.
validation_results = validate_alu_candidates_with_repeatmasker(
    alu_candidates,
    test_sv_info,
    acts=acts,
)

In [None]:
# Statistical Analysis of Validation Results

def _wilson_interval(successes, total, alpha=0.05):
    """Return the two-sided Wilson score ``(low, high)`` interval for a binomial proportion.

    Unlike the normal-approximation (Wald) interval, it does not collapse to
    zero width when ``successes`` is 0 or ``total``, and it is better behaved
    for small samples such as the per-atom validation sets here.
    """
    from scipy.stats import norm

    z = norm.ppf(1 - alpha / 2)
    p_hat = successes / total
    z2_over_n = z * z / total
    denom = 1 + z2_over_n
    center = (p_hat + z2_over_n / 2) / denom
    half_width = z * np.sqrt(p_hat * (1 - p_hat) / total + z2_over_n / (4 * total)) / denom
    # Clamp away floating-point noise at the [0, 1] boundaries.
    return max(0.0, float(center - half_width)), min(1.0, float(center + half_width))


def analyze_validation_statistics(validation_results, alu_candidates):
    """Summarise RepeatMasker validation per atom and tally Alu subfamilies.

    Parameters
    ----------
    validation_results : dict
        ``{atom_id: [validation dict, ...]}``; each dict needs ``validated`` and
        ``primary_subfamily``. Atoms with an empty list are skipped.
    alu_candidates : dict
        ``{atom_id: analysis}``; must contain every atom that has validations
        and provide its ``alu_score``.

    Returns
    -------
    tuple
        ``(validation_stats, subfamily_counts)`` where ``validation_stats`` maps
        atom id to ``validated``, ``total``, ``rate``, ``ci_low``, ``ci_high``
        (95% Wilson score interval on the rate) and ``alu_score``, and
        ``subfamily_counts`` is a ``Counter`` of primary subfamilies over
        validated variants.
    """

    print("STATISTICAL VALIDATION ANALYSIS")
    print("=" * 40)

    validation_stats = {}

    for atom_id, validations in validation_results.items():
        if not validations:
            continue

        validated_count = sum(1 for v in validations if v['validated'])
        total_count = len(validations)
        validation_rate = validated_count / total_count
        ci_low, ci_high = _wilson_interval(validated_count, total_count, alpha=0.05)
        alu_score = alu_candidates[atom_id]['alu_score']

        validation_stats[atom_id] = {
            'validated': validated_count,
            'total': total_count,
            'rate': validation_rate,
            'ci_low': ci_low,
            'ci_high': ci_high,
            'alu_score': alu_score
        }

        print(f"Atom {atom_id}: {validation_rate:.1%} validated ({validated_count}/{total_count})")
        print(f"   95% CI: [{ci_low:.1%}, {ci_high:.1%}]")
        print(f"   Alu Score: {alu_score:.1f}")

    # Subfamily analysis over every validated variant with a named subfamily.
    print(f"\n ALU SUBFAMILY DISTRIBUTION:")

    all_subfamilies = [
        v['primary_subfamily']
        for validations in validation_results.values()
        for v in validations
        if v['validated'] and v['primary_subfamily'] != 'None'
    ]
    subfamily_counts = Counter(all_subfamilies)
    n_subfamilies = len(all_subfamilies)
    for subfamily, count in subfamily_counts.most_common():
        percentage = count / n_subfamilies * 100
        print(f"   {subfamily}: {count} ({percentage:.1f}%)")

    return validation_stats, subfamily_counts

# Summarise validation rates per atom and count the Alu subfamilies hit.
validation_stats, subfamily_counts = analyze_validation_statistics(
    validation_results,
    alu_candidates,
)


In [None]:
# Visualization of Results

def create_alu_analysis_visualizations(alu_candidates, validation_results, validation_stats,
                                       subfamily_counts=None):
    """Draw a 2x3 summary figure of the Alu specialist analysis and save it as a PNG.

    Panels:
    1. Alu score per candidate atom.
    2. Validation rate per atom, with its confidence interval.
    3. Score vs. validation rate, with a linear trend when possible.
    4. Size histogram of the highest-scoring atom.
    5. Pie chart of the top 8 Alu subfamilies.
    6. Summary table.

    Parameters
    ----------
    alu_candidates : dict
        ``{atom_id: analysis}``; each analysis needs ``alu_score``. Panel 4 also
        uses ``sizes``. Every atom in ``validation_stats`` must be present.
    validation_results : dict
        ``{atom_id: [validation dict, ...]}``. It is used to count subfamilies
        when ``subfamily_counts`` is not supplied.
    validation_stats : dict
        ``{atom_id: {'rate', 'ci_low', 'ci_high', 'validated', 'total', ...}}``.
    subfamily_counts : collections.Counter, optional
        Precomputed subfamily counts. When omitted they are derived from
        ``validation_results``: primary subfamilies of validated variants,
        excluding ``'None'``.

    Creates ``figures_dir`` if needed, writes ``alu_specialist_analysis.png``
    there, and shows the figure.
    """
    import os

    if subfamily_counts is None:
        subfamily_counts = Counter(
            v['primary_subfamily']
            for validations in validation_results.values()
            for v in validations
            if v['validated'] and v['primary_subfamily'] != 'None'
        )

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('SAE Alu Specialist Analysis Results', fontsize=16, fontweight='bold')

    # 1. Alu scores by atom
    ax = axes[0, 0]
    atom_ids = list(alu_candidates.keys())
    alu_scores = [alu_candidates[aid]['alu_score'] for aid in atom_ids]

    bars = ax.bar([f'A{aid}' for aid in atom_ids], alu_scores,
                  color='lightcoral', alpha=0.7, edgecolor='black')
    ax.set_ylabel('Alu Specialist Score')
    ax.set_title('Alu Specialist Scores by Atom')
    # Scores top out at 10; the headroom keeps labels of maximal bars visible.
    ax.set_ylim(0, 11)

    for bar, score in zip(bars, alu_scores):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.1,
                f'{score:.1f}', ha='center', va='bottom')

    # 2. Validation rates with confidence intervals
    ax = axes[0, 1]
    val_atoms = list(validation_stats.keys())
    val_rates = [validation_stats[aid]['rate'] for aid in val_atoms]

    bars = ax.bar([f'A{aid}' for aid in val_atoms], val_rates,
                  color='lightblue', alpha=0.7, edgecolor='black')
    ax.set_ylabel('Validation Rate')
    ax.set_title('RepeatMasker Validation Rates')
    ax.set_ylim(0, 1)

    for bar, aid in zip(bars, val_atoms):
        rate = bar.get_height()
        # matplotlib rejects negative error lengths; clamp in case an interval
        # does not contain the point estimate.
        lower = max(0.0, rate - validation_stats[aid]['ci_low'])
        upper = max(0.0, validation_stats[aid]['ci_high'] - rate)
        ax.errorbar(bar.get_x() + bar.get_width() / 2, rate,
                    yerr=[[lower], [upper]], fmt='none', color='black', capsize=3)

    # 3. Score vs validation correlation
    ax = axes[0, 2]
    scores_for_corr = [alu_candidates[aid]['alu_score'] for aid in val_atoms]
    ax.scatter(scores_for_corr, val_rates, s=100, alpha=0.7, color='green')

    # A straight-line fit needs at least two distinct x values.
    if len(set(scores_for_corr)) > 1:
        slope, intercept = np.polyfit(scores_for_corr, val_rates, 1)
        xs = np.array([min(scores_for_corr), max(scores_for_corr)], dtype=float)
        ax.plot(xs, slope * xs + intercept, "r--", alpha=0.8)

    ax.set_xlabel('Alu Specialist Score')
    ax.set_ylabel('Validation Rate')
    ax.set_title('Score vs Validation Correlation')

    for score, rate, aid in zip(scores_for_corr, val_rates, val_atoms):
        ax.annotate(f'A{aid}', (score, rate), xytext=(5, 5),
                    textcoords='offset points', fontsize=8)

    # 4. Size distribution for the top-scoring candidate
    ax = axes[1, 0]
    if alu_candidates:
        top_atom = max(alu_candidates, key=lambda x: alu_candidates[x]['alu_score'])

        if 'sizes' in alu_candidates[top_atom]:
            sizes = alu_candidates[top_atom]['sizes']
            ax.hist(sizes, bins=30, alpha=0.7, color='orange', edgecolor='black')
            ax.axvspan(250, 350, alpha=0.3, color='red', label='Alu Range')
            ax.axvspan(290, 310, alpha=0.4, color='darkred', label='Optimal Alu')
            ax.set_xlabel('SV Size (bp)')
            ax.set_ylabel('Count')
            ax.set_title(f'Size Distribution - Atom {top_atom}')
            ax.legend()

    # 5. Subfamily distribution (top 8)
    ax = axes[1, 1]
    if subfamily_counts:
        subfamilies, counts = zip(*subfamily_counts.most_common()[:8])
        ax.pie(counts, labels=subfamilies, autopct='%1.1f%%', startangle=90)
        ax.set_title('Alu Subfamily Distribution')

    # 6. Validation summary table
    ax = axes[1, 2]
    ax.axis('off')

    table_data = []
    for aid, atom_stats in validation_stats.items():
        table_data.append([
            f'Atom {aid}',
            f'{atom_stats["rate"]:.1%}',
            f'{atom_stats["validated"]}/{atom_stats["total"]}',
            f'{alu_candidates[aid]["alu_score"]:.1f}'
        ])

    if table_data:
        table = ax.table(cellText=table_data,
                         colLabels=['Atom', 'Val Rate', 'Count', 'Score'],
                         cellLoc='center',
                         loc='center')
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 1.5)
        ax.set_title('Validation Summary', pad=20)

    plt.tight_layout()
    os.makedirs(figures_dir, exist_ok=True)
    plt.savefig(f'{figures_dir}/alu_specialist_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

# Render and save the summary figure for the candidates, validations and statistics above.
create_alu_analysis_visualizations(
    alu_candidates=alu_candidates,
    validation_results=validation_results,
    validation_stats=validation_stats,
)