# Borzoi Variant Effect Prediction on Gene Expression
This notebook performs in silico mutagenesis to predict the effect of SNPs on gene expression.

**Features:**
- Configurable task filtering (assay, description, name keywords)
- Both TSS-centered and SNP-centered prediction strategies
- Averaging across all 4 Borzoi replicates
- Comprehensive validation checks
- Z-score normalization of effects


## 1. Setup and Imports


In [None]:

# Package dependencies
import os
from pathlib import Path
from typing import Dict, List, Optional
import grelu.resources
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import warnings
from utils.data_ingestion import (
    download_gtf_if_needed,
    get_gene_info,
    load_gene_annotations_for_gene,
    load_vcf,
)
from utils.genome import (
    create_centered_window,
    ensure_fasta_shortcut,
    get_output_bins_for_interval,
    validate_bins_in_output,
)
from utils.prediction import (
    filter_tasks,
    run_full_analysis,
)
# NOTE(review): with no category or module argument this suppresses every warning
# (deprecation, numerical, etc.) until the filter list is changed again.
warnings.filterwarnings('ignore')
print("Imports complete.")

## 2. Configuration
Set your parameters here:


In [None]:

# =============================================================================
# USER CONFIGURATION
# =============================================================================

# Gene of interest
GENE_NAME = "CBX8"  # Change this to your gene of interest

# VCF file path (set to your VCF file)
VCF_FILE = "path/to/your/variants.vcf"  # UPDATE THIS PATH

# Reference genome tag for grelu
REFERENCE_GENOME = "hg38"

# Paths and caching
# The default GTF cache location is a machine-specific absolute path. Set the
# GTF_CACHE_DIR environment variable to point it elsewhere without editing the
# notebook; when the variable is unset the original default is used unchanged.
GTF_CACHE_DIR = Path(
    os.environ.get(
        "GTF_CACHE_DIR",
        "/storage/home/mgc5166/work/Annotations/eQTL_annotations_for_susine/data/gtf_cache",
    )
)
GTF_SHORTCUT_DIR = Path("data/gtf_shortcuts")
FASTA_SHORTCUT_DIR = Path("data/fasta_shortcuts")
FASTA_SHORTCUT_RADIUS = 1_000_000  # +/- bp around TSS to cache
OUTPUT_DIR = Path("output")

# Task filtering (case-insensitive, uses 'contains').
# Set a value to None to skip that filter.
TASK_FILTER = {
    "assay": "rna",           # e.g., "rna", "cage", "atac"
    "description": "lung",    # e.g., "lung", "brain", "liver"
    "name": "gtex",           # e.g., "gtex", "encode"
}

# Aggregation method for predictions across tasks and bins
TASK_AGGFUNC = "mean"    # "mean" or "sum"
LENGTH_AGGFUNC = "sum"   # "mean" or "sum" (sum is typical for gene expression)

# Model replicates to use (0, 1, 2, 3 available)
MODEL_REPLICATES = [0, 1, 2, 3]

# Use the GPU when CUDA is available, otherwise fall back to the CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"Device name: {torch.cuda.get_device_name(0)}")


## 3. Download/Load GTF Annotations


In [None]:

# Locate the per-gene annotation shortcuts; download_gtf_if_needed is only
# called when at least one of the two CSV files is missing.
gtf_gene_shortcut = GTF_SHORTCUT_DIR / f"{GENE_NAME}_gene_annotations.csv"
gtf_exon_shortcut = GTF_SHORTCUT_DIR / f"{GENE_NAME}_exon_annotations.csv"
gtf_needed = not gtf_gene_shortcut.exists() or not gtf_exon_shortcut.exists()
gtf_path = (
    download_gtf_if_needed(GTF_CACHE_DIR, genome=REFERENCE_GENOME)
    if gtf_needed
    # Both shortcuts exist: pass the conventional cache path along without downloading
    # (per the original note, the loader is not expected to read it in this case)
    else GTF_CACHE_DIR / f"{REFERENCE_GENOME}_gencode.gtf.gz"
)

# Load gene/exon tables, preferring the gene-specific shortcut over a full parse
genes_df, exons_df, used_gtf_shortcut = load_gene_annotations_for_gene(
    gtf_path, GENE_NAME, GTF_SHORTCUT_DIR
)

# Summarise the gene's coordinates, strand, TSS and exons
gene_info = get_gene_info(GENE_NAME, genes_df, exons_df)
print(
    f"Gene: {gene_info['gene_name']}",
    f"  Location: {gene_info['chrom']}:{gene_info['start']}-{gene_info['end']}",
    f"  Strand: {gene_info['strand']}",
    f"  TSS: {gene_info['tss']}",
    f"  Exons loaded: {len(gene_info['exons'])}",
    f"  Annotations source: {'shortcut' if used_gtf_shortcut else 'full GTF parse'}",
    sep="\n",
)


In [None]:

# Cache a reference FASTA window of +/- FASTA_SHORTCUT_RADIUS bp around the TSS
fasta_shortcut_kwargs = {
    "chrom": gene_info["chrom"],
    "tss": gene_info["tss"],
    "gene_name": GENE_NAME,
    "shortcut_dir": FASTA_SHORTCUT_DIR,
    "genome": REFERENCE_GENOME,
    "flank_radius": FASTA_SHORTCUT_RADIUS,
}
FASTA_SHORTCUT_META = ensure_fasta_shortcut(**fasta_shortcut_kwargs)

# Report the span recorded in the returned metadata
print(
    f"FASTA shortcut span: {FASTA_SHORTCUT_META['chrom']}:{FASTA_SHORTCUT_META['start']}-"
    f"{FASTA_SHORTCUT_META['end']} (radius {FASTA_SHORTCUT_META['radius']} bp)"
)

# Cell output: first exon rows
gene_info['exons'].head()


## 4. Load Borzoi Models


In [None]:
def load_borzoi_replicate(rep: int):
    """Fetch one Borzoi human replicate through ``grelu.resources.load_model``.

    Args:
        rep: Replicate number; it is interpolated into the model name
            ``human_rep{rep}``.

    Returns:
        The object returned by ``grelu.resources.load_model`` for that
        project/model name.
    """
    print(f"Loading Borzoi human_rep{rep}...")
    return grelu.resources.load_model(project="borzoi", model_name=f"human_rep{rep}")

# Load every requested replicate, keyed by its replicate number
models = {rep: load_borzoi_replicate(rep) for rep in MODEL_REPLICATES}

# The first listed replicate supplies the shared model parameters used below
first_replicate = MODEL_REPLICATES[0]
model = models[first_replicate]
print(f"\nLoaded {len(models)} model replicates")

In [None]:

# Show the training data parameters, leaving out the "intervals" entry
train_params = model.data_params['train']
print("Model data parameters:")
for key, value in train_params.items():
    if key == "intervals":
        continue
    print(f"  {key}: {value}")

# Key parameters (values in the trailing comments are the ones noted by the notebook author)
SEQ_LEN = train_params["seq_len"]  # 524288
LABEL_LEN = train_params["label_len"]  # 196608
BIN_SIZE = train_params["bin_size"]  # 32
CROP_LEN = model.model_params.get("crop_len", 0)  # Cropped bins on each side

# Editable region: the input span left after removing CROP_LEN bins (in bp) from each end
CROP_BP = CROP_LEN * BIN_SIZE
EDITABLE_START, EDITABLE_END = CROP_BP, SEQ_LEN - CROP_BP

print(
    f"\nKey dimensions:",
    f"  Input sequence length: {SEQ_LEN:,} bp",
    f"  Output length: {LABEL_LEN:,} bp",
    f"  Bin size: {BIN_SIZE} bp",
    f"  Crop length: {CROP_LEN} bins ({CROP_BP:,} bp)",
    f"  Editable region: {EDITABLE_START:,} - {EDITABLE_END:,} ({EDITABLE_END - EDITABLE_START:,} bp)",
    sep="\n",
)


## 5. Filter Tasks


In [None]:
# Forward the three configured keyword filters; a key absent from TASK_FILTER becomes None
task_filter_kwargs = {
    field: TASK_FILTER.get(field) for field in ("assay", "description", "name")
}
filtered_tasks, task_indices = filter_tasks(model, **task_filter_kwargs)

print(f"Filtered to {len(task_indices)} tasks")
print(f"\nSample of filtered tasks:")

# Cell output: first ten matching tasks
filtered_tasks.head(10)


In [None]:

# Abort before any prediction work if the filters selected no tasks
n_matching_tasks = len(task_indices)
if n_matching_tasks == 0:
    no_task_message = (
        f"No tasks match the filter criteria: {TASK_FILTER}\n"
        "Please adjust TASK_FILTER in the configuration cell."
    )
    raise ValueError(no_task_message)
print(f"Found {len(task_indices)} matching tasks")


## 6. Load and Validate VCF


In [None]:

# Read the configured VCF when it exists; otherwise fabricate demo variants near the TSS
if os.path.exists(VCF_FILE):
    variants_df = load_vcf(VCF_FILE)
else:
    print(f"VCF file not found: {VCF_FILE}")
    print("Creating synthetic demo variants around the gene TSS...")

    # Fixed seed so the demo set is reproducible; the order of the random draws
    # below (positions, REF, ALT, then per-row ALT redraws) determines the result.
    np.random.seed(42)
    n_variants = 20
    nucleotides = ['A', 'C', 'G', 'T']

    # NOTE: REF alleles here are random draws, not bases read from the reference genome
    demo_vcf = pd.DataFrame({
        'chrom': [gene_info['chrom']] * n_variants,
        'pos': gene_info['tss'] + np.random.randint(-100000, 100000, n_variants),
        'rsID': [f'rs{i}' for i in range(n_variants)],
        'ref': np.random.choice(nucleotides, n_variants),
        'alt': np.random.choice(nucleotides, n_variants),
    })

    # Redraw ALT, row by row, until it differs from REF
    for row_label in demo_vcf.index:
        while demo_vcf.at[row_label, 'ref'] == demo_vcf.at[row_label, 'alt']:
            demo_vcf.at[row_label, 'alt'] = np.random.choice(nucleotides)

    demo_vcf['pos_0based'] = demo_vcf['pos'] - 1
    variants_df = demo_vcf

print(f"\nLoaded {len(variants_df)} variants")
variants_df.head()


## 7. Validation Functions and Window Calculations


In [None]:

# Calculate number of output bins: the output span (LABEL_LEN bp) divided into
# BIN_SIZE-bp bins, using floor division.
N_OUTPUT_BINS = LABEL_LEN // BIN_SIZE
print(f"Number of output bins: {N_OUTPUT_BINS}")


In [None]:

# Create TSS-centered window for gene bin calculations.
# Only the first value returned by create_centered_window (the window start) is kept;
# the second is discarded — presumably the window end, confirm in utils.genome.
tss_window_start, _ = create_centered_window(gene_info['tss'], SEQ_LEN)
print(f"TSS-centered window start: {tss_window_start:,}")

In [None]:

# Map the gene's exons to output bins relative to the TSS-centered window
gene_exon_bins_tss = get_output_bins_for_interval(
    model, gene_info['exons'], tss_window_start
)
print("Gene Exon Bins (TSS-centered window):")
print(gene_exon_bins_tss.head())

# Check the exon bins against the number of output bins and report the outcome
is_valid, msg = validate_bins_in_output(gene_exon_bins_tss, N_OUTPUT_BINS)
print(f"\nValidation: {msg}")
print("All gene bins are within output field" if is_valid else f"WARNING: {msg}")


## 8. Main Analysis: Variant Effect Prediction


In [None]:
# Announce the workload: every replicate is run under both centering strategies
print(
    f"Running analysis on {len(variants_df)} variants...",
    f"Using {len(models)} model replicates and 2 centering strategies",
    f"(Total of {len(models) * 2} predictions per allele per variant)\n",
    sep="\n",
)

# Keyword settings forwarded unchanged to run_full_analysis
analysis_settings = {
    "seq_len": SEQ_LEN,
    "editable_start": EDITABLE_START,
    "editable_end": EDITABLE_END,
    "n_output_bins": N_OUTPUT_BINS,
    "task_agg": TASK_AGGFUNC,
    "length_agg": LENGTH_AGGFUNC,
    "device": DEVICE,
    "genome": REFERENCE_GENOME,
    "fasta_meta": FASTA_SHORTCUT_META,
}
SNP_annotations = run_full_analysis(
    variants_df, models, gene_info, task_indices, **analysis_settings
)

## 9. Visualization and Normalization

In [None]:

# Z-score the per-variant differences using the mean/std of the non-missing values
valid_diffs = SNP_annotations['diff'].dropna()
diff_mean = diff_std = float('nan')

if len(valid_diffs) <= 1:
    # Fewer than two usable values: no spread can be estimated
    SNP_annotations['normalized_diff'] = np.nan
    print("Warning: Not enough valid differences for normalization")
else:
    diff_mean, diff_std = valid_diffs.mean(), valid_diffs.std()
    if diff_std > 0:
        centered = SNP_annotations['diff'] - diff_mean
        SNP_annotations['normalized_diff'] = centered / diff_std
    else:
        # No positive spread: every variant gets a Z-score of zero
        SNP_annotations['normalized_diff'] = 0.0
        print("Warning: Zero standard deviation in differences")

print(f"Normalization complete.")
print(f"  Mean diff: {diff_mean:.6f}")
print(f"  Std diff: {diff_std:.6f}")

In [None]:
# Side-by-side histograms of the raw (Alt - Ref) differences and their Z-scores
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Raw differences
ax1 = axes[0]
valid_diff = SNP_annotations['diff'].dropna()
raw_mean = valid_diff.mean()
raw_std = valid_diff.std()
ax1.hist(valid_diff, bins=30, edgecolor='black', alpha=0.7, color='steelblue')
ax1.axvline(x=0, color='red', linestyle='--', linewidth=2, label='No effect')
ax1.axvline(x=raw_mean, color='green', linestyle='-', linewidth=2, label=f'Mean: {valid_diff.mean():.4f}')
ax1.set_xlabel('Expression Difference (Alt - Ref)', fontsize=12)
ax1.set_ylabel('Count', fontsize=12)
ax1.set_title(f'Raw Expression Deltas\n{GENE_NAME} ({len(valid_diff)} SNPs)', fontsize=14)
ax1.legend()
ax1.grid(alpha=0.3)

# Normalized differences
ax2 = axes[1]
valid_norm = SNP_annotations['normalized_diff'].dropna()
ax2.hist(valid_norm, bins=30, edgecolor='black', alpha=0.7, color='coral')

# Z = 0 corresponds to the mean raw difference, not to "no effect". Because the
# Z-score is (diff - mean) / std, a raw difference of 0 sits at -mean / std, so
# the two reference lines are drawn separately (same colours as the left panel).
ax2.axvline(x=0, color='green', linestyle='-', linewidth=2, label='Mean effect (Z = 0)')
if len(valid_diff) > 1 and raw_std > 0:
    ax2.axvline(x=-raw_mean / raw_std, color='red', linestyle='--', linewidth=2, label='No effect')

# Add significance thresholds
ax2.axvline(x=-2, color='purple', linestyle=':', linewidth=1.5, label='|Z| = 2')
ax2.axvline(x=2, color='purple', linestyle=':', linewidth=1.5)
ax2.set_xlabel('Z-Score (Normalized Expression Difference)', fontsize=12)
ax2.set_ylabel('Count', fontsize=12)
ax2.set_title(f'Z-Score Normalized Expression Deltas\n{GENE_NAME} ({len(valid_norm)} SNPs)', fontsize=14)
ax2.legend()
ax2.grid(alpha=0.3)
plt.tight_layout()

# Make sure the output directory exists before writing the figure
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
plt.savefig(OUTPUT_DIR / f'{GENE_NAME}_expression_delta_histograms.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nSummary Statistics:")
print(f"  Raw diff - Mean: {valid_diff.mean():.6f}, Std: {valid_diff.std():.6f}")
print(f"  Raw diff - Min: {valid_diff.min():.6f}, Max: {valid_diff.max():.6f}")
print(f"  |Z| > 2: {(valid_norm.abs() > 2).sum()} variants ({100*(valid_norm.abs() > 2).mean():.1f}%)")


In [None]:

# Compare TSS-centered vs SNP-centered predictions
def _draw_identity_line(ax):
    """Overlay a dashed red y = x line spanning the axes' current x/y limits."""
    lower = min(ax.get_xlim()[0], ax.get_ylim()[0])
    upper = max(ax.get_xlim()[1], ax.get_ylim()[1])
    span = [lower, upper]
    ax.plot(span, span, 'r--', alpha=0.75, zorder=0)


fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax1, ax2 = axes[0], axes[1]

# Left panel: reference-allele expression under the two centering strategies
ref_mask = SNP_annotations['tss_ref_exp'].notna() & SNP_annotations['snp_ref_exp'].notna()
ref_pairs = SNP_annotations.loc[ref_mask, ['tss_ref_exp', 'snp_ref_exp']]
ax1.scatter(ref_pairs['tss_ref_exp'], ref_pairs['snp_ref_exp'], alpha=0.6)
_draw_identity_line(ax1)
ax1.set_xlabel('TSS-Centered Prediction', fontsize=12)
ax1.set_ylabel('SNP-Centered Prediction', fontsize=12)
ax1.set_title('Reference Allele Expression\n(TSS vs SNP centering)', fontsize=14)
ax1.grid(alpha=0.3)

# Right panel: Alt - Ref effect under the two centering strategies
tss_diff = SNP_annotations['tss_alt_exp'] - SNP_annotations['tss_ref_exp']
snp_diff = SNP_annotations['snp_alt_exp'] - SNP_annotations['snp_ref_exp']
valid_mask = tss_diff.notna() & snp_diff.notna()
ax2.scatter(tss_diff[valid_mask], snp_diff[valid_mask], alpha=0.6)
_draw_identity_line(ax2)
ax2.set_xlabel('TSS-Centered Effect', fontsize=12)
ax2.set_ylabel('SNP-Centered Effect', fontsize=12)
ax2.set_title('Variant Effect (Alt-Ref)\n(TSS vs SNP centering)', fontsize=14)
ax2.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / f'{GENE_NAME}_centering_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

# Pearson correlation of the two effect estimates, only with more than two paired values
if valid_mask.sum() > 2:
    corr = np.corrcoef(tss_diff[valid_mask], snp_diff[valid_mask])[0, 1]
    print(f"Correlation between TSS and SNP-centered effects: {corr:.4f}")


## 10. Export Results


In [None]:

# Export full results
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
output_file = OUTPUT_DIR / f"{GENE_NAME}_variant_effects.csv"
SNP_annotations.to_csv(output_file, index=False)
print(f"Results saved to: {output_file}")

# Display final summary
print("\n" + "=" * 60)
print("ANALYSIS COMPLETE")
print("=" * 60)
print(f"Gene: {GENE_NAME}")
print(f"Chromosome: {gene_info['chrom']}")
print(f"TSS: {gene_info['tss']:,}")
print(f"\nTask Filter: {TASK_FILTER}")
print(f"Number of tasks used: {len(task_indices)}")
print(f"\nModel replicates: {MODEL_REPLICATES}")
print("Centering strategies: TSS-centered, SNP-centered")
print(f"\nVariants analyzed: {len(SNP_annotations)}")
print(f"Variants with valid predictions: {SNP_annotations['diff'].notna().sum()}")
print("\nTop variants by |Z-score|:")
display_cols = ['rsID', 'chrom', 'position', 'ref_allele', 'alt_allele', 'diff', 'normalized_diff']

# Rank on the absolute Z-score so strongly negative effects are listed too; ranking
# on the signed column would only ever surface the largest positive effects.
# The helper column exists only on the temporary copy made by assign() and is
# dropped again by the display_cols selection.
abs_z_col = '_abs_normalized_diff'
top_variants = (
    SNP_annotations
    .assign(**{abs_z_col: SNP_annotations['normalized_diff'].abs()})
    .nlargest(5, abs_z_col, keep='first')[display_cols]
)
print(top_variants.to_string())

## 11. Validation Summary


In [None]:
print("VALIDATION CHECKS SUMMARY")
print("="*60)

# Check 1: share of variants whose reference allele matched, per centering strategy
tss_ref_match = SNP_annotations['tss_ref_match'].dropna()
snp_ref_match = SNP_annotations['snp_ref_match'].dropna()
print(f"\n1. Reference Allele Matching:")
if len(tss_ref_match) > 0:
    print(f"   TSS-centered: {tss_ref_match.sum()}/{len(tss_ref_match)} matched ({100*tss_ref_match.mean():.1f}%)")
if len(snp_ref_match) > 0:
    print(f"   SNP-centered: {snp_ref_match.sum()}/{len(snp_ref_match)} matched ({100*snp_ref_match.mean():.1f}%)")

# Check 2: variants that have a TSS-centered reference prediction
print(f"\n2. SNPs in Editable Region (TSS-centered):")
n_valid_tss = SNP_annotations['tss_ref_exp'].notna().sum()
print(f"   {n_valid_tss}/{len(SNP_annotations)} variants ({100*n_valid_tss/len(SNP_annotations):.1f}%)")

# Check 3: re-run the exon-bin check for the TSS-centered window
print(f"\n3. Gene Bins in Output Field:")
is_valid, msg = validate_bins_in_output(gene_exon_bins_tss, N_OUTPUT_BINS)
status = "OK" if is_valid else "FAIL"
print(f"   {status} {msg}")

# Check 4: list every variant lacking a successful score under either strategy
print(f"\n4. Missing Predictions:")
both_succeeded = (SNP_annotations['tss_status'] == 'success') & (SNP_annotations['snp_status'] == 'success')
missing_mask = ~both_succeeded
missing_variants = SNP_annotations[missing_mask]
print(f"   Total variants with any missing score: {len(missing_variants)}")
if len(missing_variants) == 0:
    print("   All variants have both TSS- and SNP-centered scores.")
else:
    print("   Detailed missing summaries:")
    strategy_columns = (
        ("TSS", 'tss_status', 'tss_missing_reason'),
        ("SNP", 'snp_status', 'snp_missing_reason'),
    )
    for _, row in missing_variants.iterrows():
        # One "<strategy> (<reason>)" entry per strategy that did not succeed
        missing_parts = [
            f"{label} ({row[reason_col]})"
            for label, status_col, reason_col in strategy_columns
            if row[status_col] != 'success'
        ]
        rel_pos = int(row['relative_to_tss'])
        missing_str = ', '.join(missing_parts)
        print(
            f"     - {row['rsID']} at {row['chrom']}:{row['position']} (offset {rel_pos:+} from TSS): "
            f"missing {missing_str}"
        )
print(f"\n" + "="*60)
