# Borzoi Variant Effect Prediction on Gene Expression

This notebook performs in silico mutagenesis to predict the effect of SNPs on gene expression.

**Features:**
- Configurable task filtering (assay, description, name keywords)
- Both TSS-centered and SNP-centered prediction strategies
- Averaging across all 4 Borzoi replicates
- Comprehensive validation checks
- Z-score normalization of effects

## 1. Setup and Imports

In [10]:
# Package dependencies
import numpy as np
import pandas as pd
import torch
import os
from pathlib import Path
from typing import List, Optional, Dict, Tuple
import warnings
warnings.filterwarnings('ignore')

# Disable W&B
# os.environ["WANDB_MODE"] = "disabled"
# os.environ["WANDB_DISABLED"] = "true"
# os.environ["WANDB_SILENT"] = "true"

import matplotlib.pyplot as plt
import seaborn as sns

print("Imports complete.")

Imports complete.


## 2. Configuration

Set your parameters here:

In [11]:
# =============================================================================
# USER CONFIGURATION
# =============================================================================

# Gene of interest
GENE_NAME = "CBX8"  # Change this to your gene of interest

# VCF file path (set to your VCF file)
VCF_FILE = "path/to/your/variants.vcf"  # UPDATE THIS PATH
GTF_CACHE_DIR = "/storage/home/mgc5166/work/Annotations/eQTL_annotations_for_susine/data/gtf_cache"

# Task filtering (case-insensitive, uses 'contains')
# Set to None to skip that filter
TASK_FILTER = {
    "assay": "rna",           # e.g., "rna", "cage", "atac"
    "description": "lung",    # e.g., "lung", "brain", "liver"
    "name": "gtex",           # e.g., "gtex", "encode"
}

# Aggregation method for predictions across tasks and bins
TASK_AGGFUNC = "mean"    # "mean" or "sum"
LENGTH_AGGFUNC = "sum"   # "mean" or "sum" (sum is typical for gene expression)

# Model replicates to use (0, 1, 2, 3 available)
MODEL_REPLICATES = [0, 1, 2, 3]

# Set device to GPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

def predict_on_sequence(
    model,
    seq: str,
    device: str = DEVICE
) -> np.ndarray:
    """
    Run model prediction on a single sequence.
    Returns array of shape (1, n_tasks, n_bins).
    """
    preds = model.predict_on_seqs([seq], device=device)  # Fixed: devices -> device
    return preds

print("Function redefined.")

Using device: cuda
Function redefined.


## 3. Download/Load GTF Annotations

In [12]:
def download_gtf_if_needed(cache_dir: Path, genome: str = "hg38") -> Path:
    """
    Download GENCODE GTF file if not present locally.
    Returns path to the GTF file.
    """
    gtf_path = cache_dir / f"{genome}_gencode.gtf.gz"
    
    if gtf_path.exists():
        print(f"GTF file already exists: {gtf_path}")
        return gtf_path
    
    print("Downloading GENCODE GTF file...")
    
    # GENCODE v44 for hg38
    url = "https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_44/gencode.v44.annotation.gtf.gz"
    
    import urllib.request
    urllib.request.urlretrieve(url, gtf_path)
    
    print(f"Downloaded GTF to: {gtf_path}")
    return gtf_path


def load_gene_annotations(gtf_path: Path) -> pd.DataFrame:
    """
    Load gene annotations from GTF file.
    Returns DataFrame with gene info including TSS.
    """
    import gzip
    
    print("Loading GTF file (this may take a moment)...")
    
    genes = []
    exons = []
    
    opener = gzip.open if str(gtf_path).endswith('.gz') else open
    
    with opener(gtf_path, 'rt') as f:
        for line in f:
            if line.startswith('#'):
                continue
            
            fields = line.strip().split('\t')
            if len(fields) < 9:
                continue
            
            chrom, source, feature, start, end, score, strand, frame, attributes = fields
            
            # Parse attributes
            attr_dict = {}
            for attr in attributes.split(';'):
                attr = attr.strip()
                if attr:
                    parts = attr.split(' ', 1)
                    if len(parts) == 2:
                        key, value = parts
                        attr_dict[key] = value.strip('"')
            
            gene_name = attr_dict.get('gene_name', '')
            gene_type = attr_dict.get('gene_type', '')
            
            if feature == 'gene':
                genes.append({
                    'chrom': chrom,
                    'start': int(start) - 1,  # Convert to 0-based
                    'end': int(end),
                    'gene_name': gene_name,
                    'gene_type': gene_type,
                    'strand': strand,
                })
            elif feature == 'exon':
                exons.append({
                    'chrom': chrom,
                    'start': int(start) - 1,
                    'end': int(end),
                    'gene_name': gene_name,
                    'strand': strand,
                })
    
    genes_df = pd.DataFrame(genes)
    exons_df = pd.DataFrame(exons)
    
    # Add TSS (start for + strand, end for - strand)
    genes_df['tss'] = genes_df.apply(
        lambda row: row['start'] if row['strand'] == '+' else row['end'] - 1,
        axis=1
    )
    
    print(f"Loaded {len(genes_df)} genes and {len(exons_df)} exons")
    return genes_df, exons_df


# Download and load GTF
gtf_path = download_gtf_if_needed(GTF_CACHE_DIR)
genes_df, exons_df = load_gene_annotations(gtf_path)

TypeError: unsupported operand type(s) for /: 'str' and 'str'

In [None]:
def get_gene_info(gene_name: str, genes_df: pd.DataFrame, exons_df: pd.DataFrame) -> Dict:
    """
    Get gene information including TSS and exon coordinates.
    """
    # Find gene (case-insensitive)
    gene_mask = genes_df['gene_name'].str.upper() == gene_name.upper()
    
    if not gene_mask.any():
        raise ValueError(f"Gene '{gene_name}' not found in annotations")
    
    gene_info = genes_df[gene_mask].iloc[0]
    
    # Get exons for this gene
    gene_exons = exons_df[exons_df['gene_name'].str.upper() == gene_name.upper()].copy()
    gene_exons = gene_exons.drop_duplicates(subset=['chrom', 'start', 'end'])
    
    result = {
        'gene_name': gene_info['gene_name'],
        'chrom': gene_info['chrom'],
        'start': gene_info['start'],
        'end': gene_info['end'],
        'strand': gene_info['strand'],
        'tss': gene_info['tss'],
        'exons': gene_exons[['chrom', 'start', 'end']].reset_index(drop=True),
    }
    
    print(f"\nGene: {result['gene_name']}")
    print(f"  Location: {result['chrom']}:{result['start']}-{result['end']}")
    print(f"  Strand: {result['strand']}")
    print(f"  TSS: {result['tss']}")
    print(f"  Number of exons: {len(result['exons'])}")
    
    return result


# Get info for our gene of interest
gene_info = get_gene_info(GENE_NAME, genes_df, exons_df)
gene_info['exons'].head()


Gene: CBX8
  Location: chr17:79792131-79801683
  Strand: -
  TSS: 79801682
  Number of exons: 11


Unnamed: 0,chrom,start,end
0,chr17,79796929,79797077
1,chr17,79796496,79796540
2,chr17,79796249,79796315
3,chr17,79796056,79796123
4,chr17,79792131,79795558


## 4. Load Borzoi Models

In [None]:
import grelu.resources

def load_borzoi_replicate(rep: int):
    """Load a single Borzoi replicate model."""
    print(f"Loading Borzoi human_rep{rep}...")
    model = grelu.resources.load_model(
        project="borzoi",
        model_name=f"human_rep{rep}",
    )
    return model


# Load all replicates
models = {}
for rep in MODEL_REPLICATES:
    models[rep] = load_borzoi_replicate(rep)

# Use first model for parameters
model = models[MODEL_REPLICATES[0]]

print(f"\nLoaded {len(models)} model replicates")



Loading Borzoi human_rep0...


[34m[1mwandb[0m: Currently logged in as: [33mmgc5166[0m ([33mmgc5166-penn-state[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Downloading large artifact 'human_rep0:latest', 711.80MB. 1 files...
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 00:00:01.5 (477.4MB/s)


Loading Borzoi human_rep1...


[34m[1mwandb[0m: Downloading large artifact 'human_rep1:latest', 711.80MB. 1 files...
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 00:00:11.4 (62.4MB/s)


Loading Borzoi human_rep2...


[34m[1mwandb[0m: Downloading large artifact 'human_rep2:latest', 711.80MB. 1 files...
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 00:00:11.0 (64.5MB/s)


Loading Borzoi human_rep3...


[34m[1mwandb[0m: Downloading large artifact 'human_rep3:latest', 711.80MB. 1 files...
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 00:00:10.3 (69.2MB/s)



Loaded 4 model replicates


In [None]:
# Display model parameters
print("Model data parameters:")
for key in model.data_params['train'].keys():
    if key != "intervals":
        print(f"  {key}: {model.data_params['train'][key]}")

# Key parameters
SEQ_LEN = model.data_params["train"]["seq_len"]  # 524288
LABEL_LEN = model.data_params["train"]["label_len"]  # 196608
BIN_SIZE = model.data_params["train"]["bin_size"]  # 32
CROP_LEN = model.model_params.get("crop_len", 0)  # Cropped bins on each side

# Calculate the editable region (input positions that affect output)
CROP_BP = CROP_LEN * BIN_SIZE
EDITABLE_START = CROP_BP  # First position in input that produces output
EDITABLE_END = SEQ_LEN - CROP_BP  # Last position

print(f"\nKey dimensions:")
print(f"  Input sequence length: {SEQ_LEN:,} bp")
print(f"  Output length: {LABEL_LEN:,} bp")
print(f"  Bin size: {BIN_SIZE} bp")
print(f"  Crop length: {CROP_LEN} bins ({CROP_BP:,} bp)")
print(f"  Editable region: {EDITABLE_START:,} - {EDITABLE_END:,} ({EDITABLE_END - EDITABLE_START:,} bp)")

Model data parameters:
  seq_len: 524288
  label_len: 196608
  genome: hg38
  bin_size: 32

Key dimensions:
  Input sequence length: 524,288 bp
  Output length: 196,608 bp
  Bin size: 32 bp
  Crop length: 5120 bins (163,840 bp)
  Editable region: 163,840 - 360,448 (196,608 bp)


## 5. Filter Tasks

In [None]:
def filter_tasks(
    model,
    assay: Optional[str] = None,
    description: Optional[str] = None,
    name: Optional[str] = None,
) -> Tuple[pd.DataFrame, List[int]]:
    """
    Filter model tasks based on keywords (case-insensitive contains).
    Returns filtered tasks dataframe and list of task indices.
    """
    tasks = pd.DataFrame(model.data_params["tasks"])
    
    mask = pd.Series([True] * len(tasks))
    
    if assay is not None:
        mask &= tasks["assay"].astype(str).str.lower().str.contains(assay.lower(), na=False)
    
    if description is not None:
        mask &= tasks["description"].astype(str).str.lower().str.contains(description.lower(), na=False)
    
    if name is not None:
        mask &= tasks["name"].astype(str).str.lower().str.contains(name.lower(), na=False)
    
    filtered_tasks = tasks[mask]
    task_indices = filtered_tasks.index.tolist()
    
    return filtered_tasks, task_indices


# Apply our filter
filtered_tasks, task_indices = filter_tasks(
    model,
    assay=TASK_FILTER.get("assay"),
    description=TASK_FILTER.get("description"),
    name=TASK_FILTER.get("name"),
)

print(f"Filtered to {len(task_indices)} tasks")
print(f"\nSample of filtered tasks:")
filtered_tasks.head(10)

Filtered to 3 tasks

Sample of filtered tasks:


Unnamed: 0,name,file,clip,clip_soft,scale,sum_stat,strand_pair,description,assay,sample
7566,GTEX-1399S-1726-SM-5L3DI.1,/home/drk/tillage/datasets/human/rna/recount3/...,768,384,0.01,sum_sqrt,7566,RNA:lung,RNA,lung
7567,GTEX-14AS3-0926-SM-5TDD6.1,/home/drk/tillage/datasets/human/rna/recount3/...,768,384,0.01,sum_sqrt,7567,RNA:lung,RNA,lung
7568,GTEX-14JG1-0926-SM-5YY8W.1,/home/drk/tillage/datasets/human/rna/recount3/...,768,384,0.01,sum_sqrt,7568,RNA:lung,RNA,lung


In [None]:
# Validation: ensure we have tasks
if len(task_indices) == 0:
    raise ValueError(
        f"No tasks match the filter criteria: {TASK_FILTER}\n"
        "Please adjust TASK_FILTER in the configuration cell."
    )

print(f"✓ Found {len(task_indices)} matching tasks")

✓ Found 3 matching tasks


## 6. Load and Validate VCF

In [None]:
def load_vcf(vcf_path: str) -> pd.DataFrame:
    """
    Load VCF file into a DataFrame.
    Handles standard VCF format.
    """
    if not os.path.exists(vcf_path):
        raise FileNotFoundError(f"VCF file not found: {vcf_path}")
    
    # Determine if gzipped
    import gzip
    opener = gzip.open if vcf_path.endswith('.gz') else open
    
    # Find header line to get column names
    header_line = None
    skip_rows = 0
    
    with opener(vcf_path, 'rt') as f:
        for line in f:
            if line.startswith('##'):
                skip_rows += 1
            elif line.startswith('#CHROM'):
                header_line = line.strip().lstrip('#').split('\t')
                skip_rows += 1
                break
            else:
                break
    
    if header_line is None:
        # Assume standard VCF columns
        header_line = ['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO']
    
    # Load data
    df = pd.read_csv(
        vcf_path,
        sep='\t',
        comment='#',
        header=None,
        names=header_line[:8],  # Only take standard columns
        usecols=range(min(8, len(header_line))),
    )
    
    # Standardize column names
    df = df.rename(columns={
        'CHROM': 'chrom',
        'POS': 'pos',
        'ID': 'rsID',
        'REF': 'ref',
        'ALT': 'alt',
    })
    
    # Ensure chrom has 'chr' prefix
    if not df['chrom'].iloc[0].startswith('chr'):
        df['chrom'] = 'chr' + df['chrom'].astype(str)
    
    # Convert position to 0-based
    df['pos_0based'] = df['pos'] - 1
    
    # Filter to SNPs only (single base ref and alt)
    is_snp = (df['ref'].str.len() == 1) & (df['alt'].str.len() == 1)
    n_before = len(df)
    df = df[is_snp].reset_index(drop=True)
    print(f"Filtered {n_before - len(df)} non-SNP variants, {len(df)} SNPs remaining")
    
    return df


# Load VCF
# For demo, create a synthetic VCF if file doesn't exist
if not os.path.exists(VCF_FILE):
    print(f"VCF file not found: {VCF_FILE}")
    print("Creating synthetic demo variants around the gene TSS...")
    
    # Create demo variants near the gene TSS
    np.random.seed(42)
    n_variants = 20
    
    demo_vcf = pd.DataFrame({
        'chrom': [gene_info['chrom']] * n_variants,
        'pos': gene_info['tss'] + np.random.randint(-100000, 100000, n_variants),
        'rsID': [f'rs{i}' for i in range(n_variants)],
        'ref': np.random.choice(['A', 'C', 'G', 'T'], n_variants),
        'alt': np.random.choice(['A', 'C', 'G', 'T'], n_variants),
    })
    # Ensure ref != alt
    for i in range(len(demo_vcf)):
        while demo_vcf.loc[i, 'ref'] == demo_vcf.loc[i, 'alt']:
            demo_vcf.loc[i, 'alt'] = np.random.choice(['A', 'C', 'G', 'T'])
    
    demo_vcf['pos_0based'] = demo_vcf['pos'] - 1
    variants_df = demo_vcf
else:
    variants_df = load_vcf(VCF_FILE)

print(f"\nLoaded {len(variants_df)} variants")
variants_df.head()

VCF file not found: path/to/your/variants.vcf
Creating synthetic demo variants around the gene TSS...

Loaded 20 variants


Unnamed: 0,chrom,pos,rsID,ref,alt,pos_0based
0,chr17,79823640,rs0,T,G,79823639
1,chr17,79848549,rs1,C,T,79848548
2,chr17,79833614,rs2,C,T,79833613
3,chr17,79805376,rs3,C,G,79805375
4,chr17,79821561,rs4,T,C,79821560


## 7. Validation Functions and Window Calculations

In [None]:
def create_centered_window(center_pos: int, seq_len: int = SEQ_LEN) -> Tuple[int, int]:
    """
    Create a window of length seq_len centered on center_pos.
    Returns (start, end) as 0-based coordinates.
    """
    half_len = seq_len // 2
    start = center_pos - half_len
    end = start + seq_len
    return start, end


def position_in_window(pos: int, window_start: int, window_end: int) -> int:
    """
    Convert genomic position to position within window.
    Returns -1 if position is outside window.
    """
    if window_start <= pos < window_end:
        return pos - window_start
    return -1


def is_in_editable_region(pos_in_window: int) -> bool:
    """
    Check if position within window is in the editable region
    (i.e., will affect model output, not in cropped flanks).
    """
    return EDITABLE_START <= pos_in_window < EDITABLE_END


def get_output_bins_for_interval(
    model, 
    interval_df: pd.DataFrame, 
    window_start: int
) -> pd.DataFrame:
    """
    Get output bin indices for genomic intervals.
    Uses model.input_intervals_to_output_bins.
    """
    return model.input_intervals_to_output_bins(
        intervals=interval_df,
        start_pos=window_start
    )


def validate_bins_in_output(
    bin_df: pd.DataFrame,
    n_output_bins: int
) -> Tuple[bool, str]:
    """
    Validate that all bins are within valid output range.
    Returns (is_valid, message).
    """
    min_bin = bin_df['start'].min()
    max_bin = bin_df['end'].max()
    
    if min_bin < 0:
        return False, f"Bins extend before output start (min_bin={min_bin})"
    if max_bin > n_output_bins:
        return False, f"Bins extend beyond output end (max_bin={max_bin}, n_output={n_output_bins})"
    
    return True, f"All bins valid (range: {min_bin} - {max_bin}, output size: {n_output_bins})"


# Calculate number of output bins
N_OUTPUT_BINS = LABEL_LEN // BIN_SIZE
print(f"Number of output bins: {N_OUTPUT_BINS}")

Number of output bins: 6144


In [None]:
# =============================================================================
# TSS-CENTERED VALIDATION
# =============================================================================

# Create window centered on TSS
tss_window_start, tss_window_end = create_centered_window(gene_info['tss'])

print("TSS-Centered Window:")
print(f"  Window: {gene_info['chrom']}:{tss_window_start:,}-{tss_window_end:,}")
print(f"  TSS position: {gene_info['tss']:,}")
print(f"  TSS position in window: {gene_info['tss'] - tss_window_start:,}")

# Check which SNPs fall within the editable region
variants_df['pos_in_tss_window'] = variants_df['pos_0based'].apply(
    lambda x: position_in_window(x, tss_window_start, tss_window_end)
)
variants_df['in_tss_window'] = variants_df['pos_in_tss_window'] >= 0
variants_df['in_editable_region_tss'] = variants_df['pos_in_tss_window'].apply(
    lambda x: is_in_editable_region(x) if x >= 0 else False
)

n_in_window = variants_df['in_tss_window'].sum()
n_in_editable = variants_df['in_editable_region_tss'].sum()

print(f"\nSNP Validation (TSS-centered):")
print(f"  SNPs in window: {n_in_window}/{len(variants_df)}")
print(f"  SNPs in editable region: {n_in_editable}/{len(variants_df)}")

if n_in_editable == 0:
    print("  ⚠️ WARNING: No SNPs in editable region for TSS-centered analysis")
else:
    print(f"  ✓ {n_in_editable} SNPs can be analyzed with TSS-centered window")

TSS-Centered Window:
  Window: chr17:79,539,538-80,063,826
  TSS position: 79,801,682
  TSS position in window: 262,144

SNP Validation (TSS-centered):
  SNPs in window: 20/20
  SNPs in editable region: 19/20
  ✓ 19 SNPs can be analyzed with TSS-centered window


In [None]:
# Get gene bins for TSS-centered window
gene_exon_bins_tss = get_output_bins_for_interval(
    model, 
    gene_info['exons'],
    tss_window_start
)

print("Gene Exon Bins (TSS-centered window):")
print(gene_exon_bins_tss.head())

# Validate bins are in output
is_valid, msg = validate_bins_in_output(gene_exon_bins_tss, N_OUTPUT_BINS)
print(f"\nValidation: {msg}")
if is_valid:
    print("✓ All gene bins are within output field")
else:
    print(f"⚠️ WARNING: {msg}")

Gene Exon Bins (TSS-centered window):
   start   end
0   2923  2929
1   2909  2912
2   2902  2905
3   2896  2899
4   2773  2881

Validation: All bins valid (range: 2773 - 3073, output size: 6144)
✓ All gene bins are within output field


## 8. Prediction Functions

In [None]:
import grelu.sequence.format

def get_sequence_for_interval(
    chrom: str,
    start: int,
    end: int,
    genome: str = "hg38"
) -> str:
    """
    Get genomic sequence for an interval.
    """
    interval_df = pd.DataFrame({
        'chrom': [chrom],
        'start': [start],
        'end': [end],
        'strand': ['+'],
    })
    
    seqs = grelu.sequence.format.convert_input_type(
        interval_df,
        output_type="strings",
        genome=genome
    )
    return seqs[0]


def mutate_sequence(seq: str, pos: int, new_base: str) -> str:
    """
    Mutate a single position in the sequence.
    """
    return seq[:pos] + new_base + seq[pos+1:]


def predict_on_sequence(
    model,
    seq: str,
    device: str = DEVICE
) -> np.ndarray:
    """
    Run model prediction on a single sequence.
    Returns array of shape (1, n_tasks, n_bins).
    """
    preds = model.predict_on_seqs([seq], devices=device)
    return preds


def aggregate_predictions(
    preds: np.ndarray,
    task_indices: List[int],
    bin_indices: List[int],
    task_aggfunc: str = TASK_AGGFUNC,
    length_aggfunc: str = LENGTH_AGGFUNC,
) -> float:
    """
    Aggregate predictions over specified tasks and bins.
    
    Args:
        preds: Array of shape (1, n_tasks, n_bins)
        task_indices: List of task indices to use
        bin_indices: List of bin indices to use
        task_aggfunc: 'mean' or 'sum' for aggregating across tasks
        length_aggfunc: 'mean' or 'sum' for aggregating across bins
    
    Returns:
        Scalar aggregated prediction value
    """
    # Select tasks and bins
    selected = preds[0, task_indices, :][:, bin_indices]  # (n_selected_tasks, n_selected_bins)
    
    # Aggregate across bins (length axis)
    if length_aggfunc == "mean":
        agg_bins = np.mean(selected, axis=1)
    else:
        agg_bins = np.sum(selected, axis=1)
    
    # Aggregate across tasks
    if task_aggfunc == "mean":
        result = np.mean(agg_bins)
    else:
        result = np.sum(agg_bins)
    
    return float(result)


def get_all_bin_indices(bin_df: pd.DataFrame) -> List[int]:
    """
    Get all unique bin indices from start/end DataFrame.
    Clips to valid output range.
    """
    all_bins = set()
    for _, row in bin_df.iterrows():
        start = max(0, int(row['start']))
        end = min(N_OUTPUT_BINS, int(row['end']))
        if start < end:
            all_bins.update(range(start, end))
    return sorted(all_bins)


print("Prediction functions defined.")

Prediction functions defined.


## 9. Main Analysis: Variant Effect Prediction

In [None]:
def analyze_variant(
    variant_row: pd.Series,
    models: Dict,
    gene_info: Dict,
    task_indices: List[int],
    centering: str = "tss",  # "tss" or "snp"
) -> Dict:
    """
    Analyze a single variant across all model replicates.
    
    Args:
        variant_row: Row from variants DataFrame
        models: Dict of model replicates
        gene_info: Gene information dictionary
        task_indices: Task indices to use
        centering: "tss" or "snp"
    
    Returns:
        Dict with ref/alt predictions for each replicate
    """
    chrom = variant_row['chrom']
    snp_pos = variant_row['pos_0based']
    ref_allele = variant_row['ref'].upper()
    alt_allele = variant_row['alt'].upper()
    
    # Determine window based on centering strategy
    if centering == "tss":
        center = gene_info['tss']
    else:  # snp
        center = snp_pos
    
    window_start, window_end = create_centered_window(center)
    
    # Position of SNP within window
    snp_pos_in_window = snp_pos - window_start
    
    # Check if SNP is in editable region
    if not (0 <= snp_pos_in_window < SEQ_LEN):
        return None  # SNP outside window
    
    if centering == "tss" and not is_in_editable_region(snp_pos_in_window):
        return None  # SNP in cropped region
    
    # Get gene bin indices for this window
    gene_exon_bins = get_output_bins_for_interval(
        list(models.values())[0],
        gene_info['exons'],
        window_start
    )
    bin_indices = get_all_bin_indices(gene_exon_bins)
    
    # Skip if no valid bins
    if len(bin_indices) == 0:
        return None
    
    # Get sequence
    ref_seq = get_sequence_for_interval(chrom, window_start, window_end)
    
    # Verify reference allele matches
    seq_at_pos = ref_seq[snp_pos_in_window].upper()
    ref_match = seq_at_pos == ref_allele
    
    # Create alt sequence
    alt_seq = mutate_sequence(ref_seq, snp_pos_in_window, alt_allele)
    
    # Run predictions for each replicate
    ref_preds = []
    alt_preds = []
    
    for rep, model in models.items():
        ref_pred = predict_on_sequence(model, ref_seq)
        alt_pred = predict_on_sequence(model, alt_seq)
        
        ref_agg = aggregate_predictions(ref_pred, task_indices, bin_indices)
        alt_agg = aggregate_predictions(alt_pred, task_indices, bin_indices)
        
        ref_preds.append(ref_agg)
        alt_preds.append(alt_agg)
    
    return {
        'ref_preds': ref_preds,
        'alt_preds': alt_preds,
        'ref_mean': np.mean(ref_preds),
        'alt_mean': np.mean(alt_preds),
        'ref_match': ref_match,
        'seq_at_pos': seq_at_pos,
        'n_bins': len(bin_indices),
        'centering': centering,
    }


print("Analysis function defined.")

Analysis function defined.


In [None]:
from tqdm import tqdm

def run_full_analysis(
    variants_df: pd.DataFrame,
    models: Dict,
    gene_info: Dict,
    task_indices: List[int],
) -> pd.DataFrame:
    """
    Run analysis on all variants with both centering strategies.
    """
    results = []
    
    for idx, row in tqdm(variants_df.iterrows(), total=len(variants_df), desc="Analyzing variants"):
        # TSS-centered analysis
        tss_result = analyze_variant(
            row, models, gene_info, task_indices, centering="tss"
        )
        
        # SNP-centered analysis
        snp_result = analyze_variant(
            row, models, gene_info, task_indices, centering="snp"
        )
        
        # Combine results
        record = {
            'rsID': row['rsID'],
            'chrom': row['chrom'],
            'position': row['pos'],
            'ref_allele': row['ref'],
            'alt_allele': row['alt'],
        }
        
        # TSS-centered results
        if tss_result is not None:
            record['tss_ref_exp'] = tss_result['ref_mean']
            record['tss_alt_exp'] = tss_result['alt_mean']
            record['tss_ref_match'] = tss_result['ref_match']
        else:
            record['tss_ref_exp'] = np.nan
            record['tss_alt_exp'] = np.nan
            record['tss_ref_match'] = np.nan
        
        # SNP-centered results
        if snp_result is not None:
            record['snp_ref_exp'] = snp_result['ref_mean']
            record['snp_alt_exp'] = snp_result['alt_mean']
            record['snp_ref_match'] = snp_result['ref_match']
        else:
            record['snp_ref_exp'] = np.nan
            record['snp_alt_exp'] = np.nan
            record['snp_ref_match'] = np.nan
        
        results.append(record)
    
    results_df = pd.DataFrame(results)
    
    # Calculate average across centering strategies (8 predictions total: 4 reps x 2 centering)
    results_df['ref_gene_exp_level'] = results_df[['tss_ref_exp', 'snp_ref_exp']].mean(axis=1)
    results_df['alt_gene_exp_level'] = results_df[['tss_alt_exp', 'snp_alt_exp']].mean(axis=1)
    
    # Calculate differences
    results_df['diff'] = results_df['alt_gene_exp_level'] - results_df['ref_gene_exp_level']
    
    return results_df


# Run the analysis
print(f"Running analysis on {len(variants_df)} variants...")
print(f"Using {len(models)} model replicates and 2 centering strategies")
print(f"(Total of {len(models) * 2} predictions per allele per variant)\n")

SNP_annotations = run_full_analysis(
    variants_df,
    models,
    gene_info,
    task_indices,
)

Running analysis on 20 variants...
Using 4 model replicates and 2 centering strategies
(Total of 8 predictions per allele per variant)



Analyzing variants:   0%|          | 0/20 [00:01<?, ?it/s]


AcceleratorError: CUDA error: no kernel image is available for execution on the device
Search for `cudaErrorNoKernelImageForDevice' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
import torch
print(torch.__version__)
print(torch.version.cuda)

2.9.1+cu128
12.8


In [None]:
help(model.predict_on_seqs)

Help on method predict_on_seqs in module grelu.lightning:

predict_on_seqs(x: Union[str, List[str]], device: Union[str, int] = 'cpu') -> numpy.ndarray method of grelu.lightning.LightningModel instance
    A simple function to return model predictions directly
    on a batch of a single batch of sequences in string
    format.
    
    Args:
        x: DNA sequences as a string or list of strings.
        device: Index of the device to use
    
    Returns:
        A numpy array of predictions.



In [None]:
# First, verify PyTorch sees the GPU
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"Device name: {torch.cuda.get_device_name(0)}")

CUDA available: True
Device count: 1
Device name: Tesla P100-PCIE-12GB


In [None]:
# Z-score normalization
valid_diffs = SNP_annotations['diff'].dropna()

if len(valid_diffs) > 1:
    diff_mean = valid_diffs.mean()
    diff_std = valid_diffs.std()
    
    if diff_std > 0:
        SNP_annotations['normalized_diff'] = (
            (SNP_annotations['diff'] - diff_mean) / diff_std
        )
    else:
        SNP_annotations['normalized_diff'] = 0.0
        print("Warning: Zero standard deviation in differences")
else:
    SNP_annotations['normalized_diff'] = np.nan
    print("Warning: Not enough valid differences for normalization")

print(f"Normalization complete.")
print(f"  Mean diff: {diff_mean:.6f}")
print(f"  Std diff: {diff_std:.6f}")

In [None]:
# Display results
print(f"\nSNP Annotation Results for {GENE_NAME}:")
print(f"  Total variants analyzed: {len(SNP_annotations)}")
print(f"  Variants with valid predictions: {SNP_annotations['diff'].notna().sum()}")

# Select columns for display
display_cols = [
    'rsID', 'chrom', 'position', 'ref_allele', 'alt_allele',
    'ref_gene_exp_level', 'alt_gene_exp_level', 'diff', 'normalized_diff'
]

SNP_annotations[display_cols].head(20)

## 10. Visualization: Expression Delta Histograms

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Raw differences
ax1 = axes[0]
valid_diff = SNP_annotations['diff'].dropna()
ax1.hist(valid_diff, bins=30, edgecolor='black', alpha=0.7, color='steelblue')
ax1.axvline(x=0, color='red', linestyle='--', linewidth=2, label='No effect')
ax1.axvline(x=valid_diff.mean(), color='green', linestyle='-', linewidth=2, label=f'Mean: {valid_diff.mean():.4f}')
ax1.set_xlabel('Expression Difference (Alt - Ref)', fontsize=12)
ax1.set_ylabel('Count', fontsize=12)
ax1.set_title(f'Raw Expression Deltas\n{GENE_NAME} ({len(valid_diff)} SNPs)', fontsize=14)
ax1.legend()
ax1.grid(alpha=0.3)

# Normalized differences
ax2 = axes[1]
valid_norm = SNP_annotations['normalized_diff'].dropna()
ax2.hist(valid_norm, bins=30, edgecolor='black', alpha=0.7, color='coral')
ax2.axvline(x=0, color='red', linestyle='--', linewidth=2, label='No effect')

# Add significance thresholds
ax2.axvline(x=-2, color='purple', linestyle=':', linewidth=1.5, label='|Z| = 2')
ax2.axvline(x=2, color='purple', linestyle=':', linewidth=1.5)

ax2.set_xlabel('Z-Score (Normalized Expression Difference)', fontsize=12)
ax2.set_ylabel('Count', fontsize=12)
ax2.set_title(f'Z-Score Normalized Expression Deltas\n{GENE_NAME} ({len(valid_norm)} SNPs)', fontsize=14)
ax2.legend()
ax2.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('expression_delta_histograms.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nSummary Statistics:")
print(f"  Raw diff - Mean: {valid_diff.mean():.6f}, Std: {valid_diff.std():.6f}")
print(f"  Raw diff - Min: {valid_diff.min():.6f}, Max: {valid_diff.max():.6f}")
print(f"  |Z| > 2: {(valid_norm.abs() > 2).sum()} variants ({100*(valid_norm.abs() > 2).mean():.1f}%)")

In [None]:
# Compare TSS-centered vs SNP-centered predictions
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Reference expression comparison
ax1 = axes[0]
valid_mask = SNP_annotations['tss_ref_exp'].notna() & SNP_annotations['snp_ref_exp'].notna()
ax1.scatter(
    SNP_annotations.loc[valid_mask, 'tss_ref_exp'],
    SNP_annotations.loc[valid_mask, 'snp_ref_exp'],
    alpha=0.6
)
# Add diagonal line
lims = [
    min(ax1.get_xlim()[0], ax1.get_ylim()[0]),
    max(ax1.get_xlim()[1], ax1.get_ylim()[1])
]
ax1.plot(lims, lims, 'r--', alpha=0.75, zorder=0)
ax1.set_xlabel('TSS-Centered Prediction', fontsize=12)
ax1.set_ylabel('SNP-Centered Prediction', fontsize=12)
ax1.set_title('Reference Allele Expression\n(TSS vs SNP centering)', fontsize=14)
ax1.grid(alpha=0.3)

# Difference comparison
ax2 = axes[1]
tss_diff = SNP_annotations['tss_alt_exp'] - SNP_annotations['tss_ref_exp']
snp_diff = SNP_annotations['snp_alt_exp'] - SNP_annotations['snp_ref_exp']
valid_mask = tss_diff.notna() & snp_diff.notna()

ax2.scatter(tss_diff[valid_mask], snp_diff[valid_mask], alpha=0.6)
lims = [
    min(ax2.get_xlim()[0], ax2.get_ylim()[0]),
    max(ax2.get_xlim()[1], ax2.get_ylim()[1])
]
ax2.plot(lims, lims, 'r--', alpha=0.75, zorder=0)
ax2.set_xlabel('TSS-Centered Effect', fontsize=12)
ax2.set_ylabel('SNP-Centered Effect', fontsize=12)
ax2.set_title('Variant Effect (Alt-Ref)\n(TSS vs SNP centering)', fontsize=14)
ax2.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('centering_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

# Correlation
if valid_mask.sum() > 2:
    corr = np.corrcoef(tss_diff[valid_mask], snp_diff[valid_mask])[0, 1]
    print(f"Correlation between TSS and SNP-centered effects: {corr:.4f}")

## 11. Export Results

In [None]:
# Export full results
output_file = f'{GENE_NAME}_variant_effects.csv'
SNP_annotations.to_csv(output_file, index=False)
print(f"Results saved to: {output_file}")

# Display final summary
print(f"\n" + "="*60)
print(f"ANALYSIS COMPLETE")
print(f"="*60)
print(f"Gene: {GENE_NAME}")
print(f"Chromosome: {gene_info['chrom']}")
print(f"TSS: {gene_info['tss']:,}")
print(f"\nTask Filter: {TASK_FILTER}")
print(f"Number of tasks used: {len(task_indices)}")
print(f"\nModel replicates: {MODEL_REPLICATES}")
print(f"Centering strategies: TSS-centered, SNP-centered")
print(f"\nVariants analyzed: {len(SNP_annotations)}")
print(f"Variants with valid predictions: {SNP_annotations['diff'].notna().sum()}")
print(f"\nTop variants by |Z-score|:")
top_variants = SNP_annotations.nlargest(5, 'normalized_diff', keep='first')[display_cols]
print(top_variants.to_string())

## 12. Validation Summary

In [None]:
print("VALIDATION CHECKS SUMMARY")
print("="*60)

# Check 1: Reference allele matching
tss_ref_match = SNP_annotations['tss_ref_match'].dropna()
snp_ref_match = SNP_annotations['snp_ref_match'].dropna()

print(f"\n1. Reference Allele Matching:")
if len(tss_ref_match) > 0:
    print(f"   TSS-centered: {tss_ref_match.sum()}/{len(tss_ref_match)} matched ({100*tss_ref_match.mean():.1f}%)")
if len(snp_ref_match) > 0:
    print(f"   SNP-centered: {snp_ref_match.sum()}/{len(snp_ref_match)} matched ({100*snp_ref_match.mean():.1f}%)")

# Check 2: SNPs in editable region (TSS-centered)
print(f"\n2. SNPs in Editable Region (TSS-centered):")
n_valid_tss = SNP_annotations['tss_ref_exp'].notna().sum()
print(f"   {n_valid_tss}/{len(SNP_annotations)} variants ({100*n_valid_tss/len(SNP_annotations):.1f}%)")

# Check 3: Gene bins in output (TSS-centered)
print(f"\n3. Gene Bins in Output Field:")
is_valid, msg = validate_bins_in_output(gene_exon_bins_tss, N_OUTPUT_BINS)
status = "✓" if is_valid else "✗"
print(f"   {status} {msg}")

# Check 4: Missing predictions
print(f"\n4. Missing Predictions:")
n_missing_tss = SNP_annotations['tss_ref_exp'].isna().sum()
n_missing_snp = SNP_annotations['snp_ref_exp'].isna().sum()
print(f"   TSS-centered: {n_missing_tss} missing")
print(f"   SNP-centered: {n_missing_snp} missing")

print(f"\n" + "="*60)