In [1]:
import hail as hl
from scipy.optimize import minimize
import numpy as np

cpus = 96
memory = int(3600*cpus/256)


tmpdir = '/local/tmp'

config = {
    'spark.driver.memory': f'{memory}g',  #Set to total memory
    'spark.executor.memory': f'{memory}g',
    'spark.local.dir': tmpdir,
    'spark.ui.enabled': 'false'
}

hl.init(spark_conf=config, master=f'local[{cpus}]', tmp_dir=tmpdir, local_tmpdir=tmpdir)

hl.plot.output_notebook()
%matplotlib inline

rg38 = hl.get_reference('GRCh38')

#!wget https://storage.googleapis.com/hail-common/references/Homo_sapiens_assembly38.fasta.fai
#!wget https://storage.googleapis.com/hail-common/references/Homo_sapiens_assembly38.fasta.gz

rg38.add_sequence('Homo_sapiens_assembly38.fasta.gz',
                            'Homo_sapiens_assembly38.fasta.fai')

#!wget https://storage.googleapis.com/hail-common/references/grch37_to_grch38.over.chain.gz
# rg37 = hl.get_reference('GRCh37') 
# rg37.add_liftover('grch37_to_grch38.over.chain.gz', rg38)

my_bucket = '/storage/zoghbi/home/u235147/merged_vars'

Running on Apache Spark version 3.5.4
Welcome to
     __  __     <>__
    / /_/ /__  __/ /
   / __  / _ `/ / /
  /_/ /_/\_,_/_/_/   version 0.2.134-952ae203dbbe
LOGGING: writing to /storage/zoghbi/home/u235147/merged_vars/hail-20251219-1435-0.2.134-952ae203dbbe.log


2025-12-19 14:40:58.940 Hail: INFO: Ordering unsorted dataset with network shuffle
2025-12-19 14:41:25.972 Hail: INFO: Coerced sorted dataset
2025-12-19 14:41:45.885 Hail: INFO: wrote table with 27778551 rows in 96 partitions to /storage/zoghbi/home/u235147/merged_vars/tmp/preds.ht
2025-12-19 14:42:17.503 Hail: INFO: Coerced sorted dataset
2025-12-19 14:42:28.445 Hail: INFO: Ordering unsorted dataset with network shuffle
2025-12-19 14:43:23.631 Hail: INFO: Coerced sorted dataset
2025-12-19 14:43:36.936 Hail: INFO: Ordering unsorted dataset with network shuffle
2025-12-19 14:45:12.650 Hail: INFO: wrote table with 33387608 rows in 40 partitions to /storage/zoghbi/home/u235147/merged_vars/tmp/merged_browser_data.ht


In [2]:
# CLINVAR DATA: collapse ClinVar records to one row per locus with the number of records
# at that locus plus, per record, its review status (CLNREVSTAT), clinical significance
# (CLNSIG) and molecular consequence (MC), each record's values joined with ','.
# NOTE: an earlier version of this comment said "P/LP missense variants", but no such
# filter was ever applied — every record is kept by default. Set
# CLINVAR_PLP_MISSENSE_ONLY = True to keep only records with a CLNSIG value in
# CLINVAR_PLP_LABELS (exact match) and a 'missense_variant' molecular consequence.
CLINVAR_PLP_MISSENSE_ONLY = False
CLINVAR_PLP_LABELS = hl.literal({'Pathogenic', 'Likely_pathogenic', 'Pathogenic/Likely_pathogenic'})

clinvar_file = 'clinvar_20251013.vcf.gz'
#!wget https://ftp.ncbi.nlm.nih.gov/pub/clinvar/vcf_GRCh38/weekly/{clinvar_file}
# NOTE: the import below reads {my_bucket}/{clinvar_file}, not {my_bucket}/tmp/.
#!gsutil cp {clinvar_file} {my_bucket}/tmp/

# ClinVar uses bare contig names (1..22, X, Y, MT); map them to GRCh38 'chr' names.
recode = {str(i): f"chr{i}" for i in list(range(1, 23)) + ['X', 'Y']}
recode['MT'] = 'chrM'
clinvar_data = hl.import_vcf(f'{my_bucket}/{clinvar_file}', reference_genome='GRCh38', contig_recoding=recode, skip_invalid_loci=True, force_bgz=True).rows()
clinvar_data = clinvar_data.key_by('locus', 'alleles')
clinvar_data = clinvar_data.select(
    clinvar_status = clinvar_data.info.CLNREVSTAT,
    clinvar_label = clinvar_data.info.CLNSIG,
    # MC entries look like "SO:0001583|missense_variant"; keep the term after '|'.
    clinvar_var_type = clinvar_data.info.MC.map(lambda x: x.split('|')[1])
)

if CLINVAR_PLP_MISSENSE_ONLY:
    # Records with a missing CLNSIG or MC evaluate to missing and are dropped by filter().
    clinvar_data = clinvar_data.filter(
        clinvar_data.clinvar_label.any(lambda sig: CLINVAR_PLP_LABELS.contains(sig))
        & clinvar_data.clinvar_var_type.contains('missense_variant')
    )

# Group by locus - count records and collect each record's metadata as a ','-joined string.
clinvar_by_locus = clinvar_data.group_by('locus').aggregate(
    clinvar_count = hl.agg.count(),
    clinvar_status_list = hl.agg.collect(hl.delimit(clinvar_data.clinvar_status, ',')),
    clinvar_label_list = hl.agg.collect(hl.delimit(clinvar_data.clinvar_label, ',')),
    clinvar_var_type_list = hl.agg.collect(hl.delimit(clinvar_data.clinvar_var_type, ','))
)

clinvar_by_locus = clinvar_by_locus.checkpoint(f'{my_bucket}/tmp/clinvar_by_locus.ht', overwrite=True)

[Stage 1:>                                                          (0 + 1) / 1]

In [3]:
# TRAINING LABELS: one row per locus holding, in `train_counts`, how many training
# variants at that locus are unlabelled / labelled, overall and restricted to High_Qual.
train_data = hl.import_table('/local/Missense_Predictor_copy/_rgc_train_vars.csv', delimiter=',')

# ID_38 is assumed to be "<chrom>-<pos>-<ref>-<alt>" without the 'chr' prefix — TODO confirm.
id_fields = train_data.ID_38.split('-')
train_data = train_data.key_by(
    locus = hl.locus('chr' + id_fields[0], hl.int(id_fields[1]), reference_genome='GRCh38')
)

# import_table reads every column as str, hence the string comparisons.
is_unlabelled = train_data.label == 'unlabelled'
is_labelled = train_data.label == 'labelled'
is_high_qual = train_data.High_Qual == 'True'

train_data = train_data.group_by('locus').aggregate(
    train_counts = hl.struct(
        unlabelled = hl.agg.count_where(is_unlabelled),
        labelled = hl.agg.count_where(is_labelled),
        unlabelled_high_qual = hl.agg.count_where(is_unlabelled & is_high_qual),
        labelled_high_qual = hl.agg.count_where(is_labelled & is_high_qual),
    )
)
train_data = train_data.checkpoint(f'{my_bucket}/tmp/train_data.ht', overwrite=True)

[Stage 6:>                                                          (0 + 1) / 1]

In [3]:
# LOCUS-LEVEL dbNSFP SCORES: collapse variant-level scores to the max value per
# (locus, transcript); every output column is named `max_<sanitized column>`.
ht = hl.read_table('/local/Missense_Predictor_copy/Data/dbnsfp/All_missense_with_impute_mane_select_final_with_perc_with_impute_con_perc.ht')

# OE/VIR exome-percentile columns are discovered by name.
all_cols = list(ht.row)
oe_vir_cols = [col for col in all_cols if ('_oe_' in col or '_vir_' in col) and '_exome_perc' in col]

# Score columns to include (dotted names are literal field names in this table,
# presumably produced by flattening structs).
score_cols = ['AlphaMissense_am_pathogenicity', 'RGC_MTR.MTR', 'RGC_MTR.MTRpercentile_exome', 
              'Non_Neuro_CCR.resid_pctile', 'ESM1b_score', 'AlphaSync.plddt', 'AlphaSync.plddt10',
              'AlphaSync.relasa', 'AlphaSync.relasa10', 'AlphaMissense_am_pathogenicity_exome_perc', 
              'ESM1b_score_exome_perc', 'AlphaSync.plddt_exome_perc', 'AlphaSync.plddt10_exome_perc',
              'AlphaSync.relasa_exome_perc', 'AlphaSync.relasa10_exome_perc']

# Requested score columns that are absent used to be skipped silently; report them.
missing_score_cols = [col for col in score_cols if col not in all_cols]
if missing_score_cols:
    print(f'WARNING: {len(missing_score_cols)} score column(s) not found and skipped: {missing_score_cols}')

# Create locus column
ht = ht.annotate(
    locus = hl.locus(ht['locus.contig'], ht['locus.position'], reference_genome='GRCh38')
)

# Build the selection with ONE naming rule for every column ('.' -> '_'), so score and
# OE/VIR columns are treated alike. dict.fromkeys de-duplicates source names while keeping
# order; a collision between two different source columns is an error, not an overwrite.
select_dict = {}
for col in dict.fromkeys([c for c in score_cols if c in all_cols] + oe_vir_cols):
    new_name = col.replace('.', '_')
    if new_name in select_dict:
        raise ValueError(f'Column name collision after sanitizing: {col!r} -> {new_name!r}')
    select_dict[new_name] = ht[col]

ht_selected = ht.select(
    'locus',
    'alleles', 
    'Ensembl_transcriptid',
    **select_dict
)

# Group by locus and transcript, take max of each numeric column
agg_dict = {f'max_{col}': hl.agg.max(ht_selected[col]) for col in select_dict}

dbnsfp_ht = ht_selected.group_by('locus', 'Ensembl_transcriptid').aggregate(**agg_dict)

dbnsfp_ht = dbnsfp_ht.checkpoint(f'{my_bucket}/tmp/dbnsfp_scores.ht', overwrite=True)



In [None]:
code_block = """
#!/usr/bin/env python3
"""
MANE Select Domain Annotation Track Pipeline - Version 2
=========================================================

Simplified approach using UniProt REST API to fetch InterPro domains.

Output schema:
- transcript_id_ensembl: MANE Select transcript (ENST...)
- protein_id_uniprot: UniProt accession
- domain_id_interpro: InterPro ID
- domain_name: Domain name
- domain_start_aa: Amino acid start (1-based)
- domain_end_aa: Amino acid end (1-based)
- source_db: Source database (Pfam, etc.)

Author: Bioinformatics Pipeline Assistant
Date: December 2024
"""

import os
import sys
import gzip
import logging
import argparse
import time
import subprocess
import shutil
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, Dict, List, Tuple, Set
from collections import defaultdict
import json

import requests
import pandas as pd
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('mane_domain_pipeline_v2.log')
    ]
)
logger = logging.getLogger(__name__)


# =============================================================================
# Configuration
# =============================================================================

@dataclass
class PipelineConfig:
    """Pipeline configuration."""
    
    data_dir: Path = field(default_factory=lambda: Path("data"))
    raw_dir: Path = field(default_factory=lambda: Path("data/raw"))
    cache_dir: Path = field(default_factory=lambda: Path("data/cache"))
    output_dir: Path = field(default_factory=lambda: Path("output"))
    
    # MANE Select
    mane_summary_url: str = "https://ftp.ncbi.nlm.nih.gov/refseq/MANE/MANE_human/current/MANE.GRCh38.v1.4.summary.txt.gz"
    
    # UniProt API
    uniprot_batch_size: int = 100
    uniprot_api_delay: float = 0.1  # Seconds between API calls
    
    # Output
    output_filename: str = "mane_domain_track_v2"
    
    def __post_init__(self):
        for d in [self.data_dir, self.raw_dir, self.cache_dir, self.output_dir]:
            d.mkdir(parents=True, exist_ok=True)


# =============================================================================
# Download Utilities
# =============================================================================

def download_file(url: str, output_path: Path, description: str = None) -> Path:
    """Download a file using aria2c or requests."""
    if output_path.exists():
        logger.info(f"File already exists: {output_path}")
        return output_path
    
    logger.info(f"Downloading: {url}")
    
    aria2c_path = shutil.which('aria2c')
    if aria2c_path:
        try:
            cmd = [
                'aria2c', '--max-connection-per-server=8', '--split=8',
                '--min-split-size=1M', '--file-allocation=none',
                '--continue=true', '--auto-file-renaming=false',
                '-d', str(output_path.parent), '-o', output_path.name, url
            ]
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode == 0 and output_path.exists():
                logger.info(f"Downloaded: {output_path}")
                return output_path
        except Exception as e:
            logger.warning(f"aria2c failed: {e}, falling back to requests")
    
    response = requests.get(url, stream=True)
    response.raise_for_status()
    total_size = int(response.headers.get('content-length', 0))
    
    with open(output_path, 'wb') as f:
        with tqdm(total=total_size, unit='iB', unit_scale=True, desc=description) as pbar:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                pbar.update(len(chunk))
    
    return output_path


# =============================================================================
# MANE Select Parsing
# =============================================================================

def parse_mane_summary(summary_path: Path) -> pd.DataFrame:
    """Parse MANE summary to get transcript-protein mappings."""
    logger.info(f"Parsing MANE summary: {summary_path}")
    
    df = pd.read_csv(summary_path, sep='\t', compression='gzip')
    
    # Filter to MANE Select only
    if 'MANE_status' in df.columns:
        df = df[df['MANE_status'] == 'MANE Select'].copy()
    
    logger.info(f"Found {len(df)} MANE Select entries")
    return df


# =============================================================================
# UniProt Mapping
# =============================================================================

def map_ensembl_to_uniprot(
    ensembl_protein_ids: List[str],
    id_mapping_path: Path
) -> Dict[str, str]:
    """Map Ensembl protein IDs to UniProt accessions."""
    logger.info("Mapping Ensembl proteins to UniProt")
    
    # Build lookup set (strip versions)
    ensp_set = set()
    for ensp in ensembl_protein_ids:
        if pd.notna(ensp):
            ensp_set.add(str(ensp))
            ensp_set.add(str(ensp).split('.')[0])
    
    logger.info(f"Looking up {len(ensp_set)} Ensembl protein IDs")
    
    # Parse ID mapping file
    ensp_to_uniprot = {}
    reviewed_accs = set()
    
    with gzip.open(id_mapping_path, 'rt') as f:
        for line in tqdm(f, desc="Parsing ID mapping"):
            parts = line.strip().split('\t')
            if len(parts) != 3:
                continue
            
            uniprot_acc, id_type, id_val = parts
            
            if id_type == 'Ensembl_PRO':
                id_base = id_val.split('.')[0]
                if id_val in ensp_set or id_base in ensp_set:
                    # Store both versioned and unversioned
                    if id_val not in ensp_to_uniprot:
                        ensp_to_uniprot[id_val] = []
                    ensp_to_uniprot[id_val].append(uniprot_acc)
                    
                    if id_base not in ensp_to_uniprot:
                        ensp_to_uniprot[id_base] = []
                    if uniprot_acc not in ensp_to_uniprot[id_base]:
                        ensp_to_uniprot[id_base].append(uniprot_acc)
            
            elif id_type == 'UniProtKB-ID' and '_HUMAN' in id_val:
                reviewed_accs.add(uniprot_acc)
    
    # Select best UniProt accession for each ENSP
    final_mapping = {}
    for ensp, accs in ensp_to_uniprot.items():
        if len(accs) == 1:
            final_mapping[ensp] = accs[0]
        else:
            # Prefer reviewed
            reviewed = [a for a in accs if a in reviewed_accs]
            if reviewed:
                final_mapping[ensp] = reviewed[0]
            else:
                final_mapping[ensp] = accs[0]
    
    logger.info(f"Mapped {len(final_mapping)} Ensembl proteins to UniProt")
    return final_mapping


# =============================================================================
# UniProt API Domain Fetching
# =============================================================================

def is_family_domain(domain: Dict) -> bool:
    """
    Check if a domain is Family-level (should be filtered out).
    
    Filters out:
    - Domains with domain_type == 'Family' or 'family' (case-insensitive)
    """
    domain_type = str(domain.get('domain_type', '')).lower()
    
    # Filter Family-level domains (case-insensitive)
    if domain_type == 'family':
        return True
    
    return False


def fetch_uniprot_domains_batch(
    uniprot_accessions: List[str],
    config: PipelineConfig
) -> Dict[str, List[Dict]]:
    """
    Fetch InterPro domain annotations from UniProt REST API.
    
    Uses the UniProt API to get domain features for each protein.
    """
    logger.info(f"Fetching domains for {len(uniprot_accessions)} proteins via UniProt API")
    
    domains = {}
    
    # Process in batches
    for i in tqdm(range(0, len(uniprot_accessions), config.uniprot_batch_size), 
                  desc="Fetching from UniProt"):
        batch = uniprot_accessions[i:i + config.uniprot_batch_size]
        
        # Build query for batch
        acc_query = ' OR '.join([f'accession:{acc}' for acc in batch])
        
        url = "https://rest.uniprot.org/uniprotkb/search"
        params = {
            'query': acc_query,
            'format': 'json',
            'fields': 'accession,xref_interpro',
            'size': len(batch)
        }
        
        try:
            response = requests.get(url, params=params)
            response.raise_for_status()
            data = response.json()
            
            for entry in data.get('results', []):
                acc = entry.get('primaryAccession', '')
                
                # Extract InterPro cross-references
                xrefs = entry.get('uniProtKBCrossReferences', [])
                interpro_refs = [x for x in xrefs if x.get('database') == 'InterPro']
                
                if interpro_refs:
                    domains[acc] = []
                    for ref in interpro_refs:
                        ipr_id = ref.get('id', '')
                        props = {p.get('key'): p.get('value') for p in ref.get('properties', [])}
                        
                        domains[acc].append({
                            'interpro_id': ipr_id,
                            'domain_name': props.get('EntryName', ''),
                        })
            
            time.sleep(config.uniprot_api_delay)
            
        except requests.exceptions.RequestException as e:
            logger.warning(f"API error for batch starting at {i}: {e}")
    
    logger.info(f"Found domains for {len(domains)} proteins")
    return domains


def fetch_interpro_domains_direct(
    uniprot_accessions: List[str],
    config: PipelineConfig,
    cache_path: Optional[Path] = None
) -> Dict[str, List[Dict]]:
    """
    Fetch InterPro domains using InterPro API directly.
    
    This gives us the actual domain coordinates.
    """
    logger.info(f"Fetching domains for {len(uniprot_accessions)} proteins via InterPro API")
    
    # Check cache
    if cache_path and cache_path.exists():
        logger.info(f"Loading cached domains from {cache_path}")
        with open(cache_path, 'r') as f:
            cached = json.load(f)
        # Drop Family-level domains (Family type and PANTHER)
        cached = {
            acc: [d for d in feats if not is_family_domain(d)]
            for acc, feats in cached.items()
            if any(not is_family_domain(d) for d in feats)
        }
        logger.info(f"Filtered out Family-level domains from cache")
        return cached
    
    domains = {}
    
    for acc in tqdm(uniprot_accessions, desc="Fetching from InterPro"):
        # Use InterPro API to get protein matches
        url = f"https://www.ebi.ac.uk/interpro/api/entry/interpro/protein/uniprot/{acc}"
        
        try:
            response = requests.get(url, headers={'Accept': 'application/json'})
            
            if response.status_code == 200:
                data = response.json()
                
                domains[acc] = []
                for result in data.get('results', []):
                    metadata = result.get('metadata', {})
                    ipr_id = metadata.get('accession', '')
                    ipr_name = metadata.get('name', '')
                    ipr_type = metadata.get('type', '')
                    source_db = metadata.get('source_database', '')
                    
                    # Get protein locations
                    proteins = result.get('proteins', [])
                    for protein in proteins:
                        for location in protein.get('entry_protein_locations', []):
                            for fragment in location.get('fragments', []):
                                domains[acc].append({
                                    'interpro_id': ipr_id,
                                    'domain_name': ipr_name,
                                    'domain_type': ipr_type,
                                    'source_db': source_db,
                                    'start_aa': fragment.get('start', 0),
                                    'end_aa': fragment.get('end', 0)
                                })
            
            time.sleep(config.uniprot_api_delay)
            
        except requests.exceptions.RequestException as e:
            logger.debug(f"Error fetching {acc}: {e}")
    
    logger.info(f"Found domains for {len(domains)} proteins")
    
    # Filter out Family-level domains (Family type and PANTHER)
    filtered_domains = {
        acc: [d for d in feats if not is_family_domain(d)]
        for acc, feats in domains.items()
        if any(not is_family_domain(d) for d in feats)
    }
    
    # Cache results
    if cache_path:
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        with open(cache_path, 'w') as f:
            json.dump(filtered_domains, f)
        logger.info(f"Cached filtered domains to {cache_path} (Family-level domains excluded)")
    
    logger.info(f"Filtered out Family-level domains (Family type and PANTHER)")
    return filtered_domains


# =============================================================================
# Main Pipeline
# =============================================================================

def run_pipeline(config: PipelineConfig):
    """Run the pipeline."""
    
    logger.info("=" * 70)
    logger.info("MANE Select Domain Annotation Track Pipeline v2")
    logger.info("=" * 70)
    
    # Step 1: Download MANE summary
    logger.info("\n[Step 1] Downloading MANE Select data...")
    mane_summary_path = download_file(
        config.mane_summary_url,
        config.raw_dir / "MANE.GRCh38.summary.txt.gz",
        "MANE Summary"
    )
    
    # Step 2: Parse MANE summary
    logger.info("\n[Step 2] Parsing MANE Select...")
    mane_df = parse_mane_summary(mane_summary_path)
    
    # Step 3: Map to UniProt
    logger.info("\n[Step 3] Mapping to UniProt...")
    
    # Download ID mapping if needed
    id_mapping_url = "https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/idmapping/by_organism/HUMAN_9606_idmapping.dat.gz"
    id_mapping_path = download_file(
        id_mapping_url,
        config.raw_dir / "HUMAN_9606_idmapping.dat.gz",
        "UniProt ID Mapping"
    )
    
    # Get Ensembl protein IDs from MANE
    ensembl_proteins = mane_df['Ensembl_prot'].dropna().tolist()
    
    # Map to UniProt
    ensp_to_uniprot = map_ensembl_to_uniprot(ensembl_proteins, id_mapping_path)
    
    # Add UniProt accession to MANE dataframe
    mane_df['UniProt_acc'] = mane_df['Ensembl_prot'].apply(
        lambda x: ensp_to_uniprot.get(str(x), ensp_to_uniprot.get(str(x).split('.')[0], '')) 
        if pd.notna(x) else ''
    )
    
    # Get unique UniProt accessions
    uniprot_accessions = mane_df['UniProt_acc'].dropna()
    uniprot_accessions = [a for a in uniprot_accessions if a != '']
    uniprot_accessions = list(set(uniprot_accessions))
    
    logger.info(f"Found {len(uniprot_accessions)} unique UniProt accessions")
    
    # Step 4: Fetch InterPro domains
    logger.info("\n[Step 4] Fetching InterPro domain annotations...")
    
    cache_path = config.cache_dir / "interpro_domains_cache.json"
    interpro_domains = fetch_interpro_domains_direct(
        uniprot_accessions, config, cache_path
    )
    
    # Step 5: Build output table
    logger.info("\n[Step 5] Building output table...")
    
    records = []
    
    for _, row in tqdm(mane_df.iterrows(), total=len(mane_df), desc="Building records"):
        enst = row.get('Ensembl_nuc', '')
        ensp = row.get('Ensembl_prot', '')
        uniprot = row.get('UniProt_acc', '')
        gene_symbol = row.get('symbol', '')
        
        if not uniprot or uniprot not in interpro_domains:
            continue
        
        for domain in interpro_domains[uniprot]:
            records.append({
                'transcript_id_ensembl': enst,
                'protein_id_ensembl': ensp,
                'protein_id_uniprot': uniprot,
                'gene_symbol': gene_symbol,
                'domain_id_interpro': domain.get('interpro_id', ''),
                'domain_name': domain.get('domain_name', ''),
                'domain_type': domain.get('domain_type', ''),
                'source_db': domain.get('source_db', ''),
                'domain_start_aa': domain.get('start_aa', 0),
                'domain_end_aa': domain.get('end_aa', 0)
            })
    
    df = pd.DataFrame(records)
    logger.info(f"Created {len(df)} domain annotation records")
    
    # Step 6: Write output
    logger.info("\n[Step 6] Writing Parquet output...")
    
    if len(df) > 0:
        # Sort by transcript
        df = df.sort_values(['transcript_id_ensembl', 'domain_start_aa'])
        
        # Write Parquet
        output_path = config.output_dir / f"{config.output_filename}.parquet"
        df.to_parquet(output_path, index=False, compression='snappy')
        logger.info(f"Wrote {len(df)} records to: {output_path}")
        
        # Also write TSV for easy inspection
        tsv_path = config.output_dir / f"{config.output_filename}.tsv"
        df.to_csv(tsv_path, sep='\t', index=False)
        logger.info(f"Wrote TSV to: {tsv_path}")
    else:
        logger.warning("No domain records to write!")
    
    # Summary
    logger.info("\n" + "=" * 70)
    logger.info("Pipeline Complete - Summary")
    logger.info("=" * 70)
    logger.info(f"MANE Select transcripts: {len(mane_df)}")
    logger.info(f"Transcripts with UniProt: {len(mane_df[mane_df['UniProt_acc'] != ''])}")
    logger.info(f"UniProt proteins with domains: {len(interpro_domains)}")
    logger.info(f"Total domain annotations: {len(df)}")
    if len(df) > 0:
        logger.info(f"Unique InterPro entries: {df['domain_id_interpro'].nunique()}")
        logger.info(f"Genes with domains: {df['gene_symbol'].nunique()}")
    
    return df


def main():
    parser = argparse.ArgumentParser(description="MANE Domain Track Pipeline v2")
    parser.add_argument('--data-dir', type=Path, default=Path('data'))
    parser.add_argument('--output-dir', type=Path, default=Path('output'))
    
    args = parser.parse_args()
    
    config = PipelineConfig(
        data_dir=args.data_dir,
        raw_dir=args.data_dir / 'raw',
        cache_dir=args.data_dir / 'cache',
        output_dir=args.output_dir
    )
    
    run_pipeline(config)


if __name__ == '__main__':
    main()

"""

with open("mane_domain_pipeline_v2.py", "w") as f:
    f.write(code_block)

!python mane_domain_pipeline_v2.py

In [7]:
import polars as pl

# Tally InterPro entry types in the generated domain track (Family-type entries are
# filtered out by the pipeline, so they should not appear here).
domain_track = f'{my_bucket}/output/mane_domain_track_v2.parquet'
df = pl.read_parquet(domain_track)
df.get_column('domain_type').value_counts()

domain_type,count
str,u32
"""Coiled_coil""",85
"""Disordered""",30
"""Homologous_superfamily""",12747
"""Domain""",13902
"""Conserved_site""",1849
"""Repeat""",901


In [8]:
import pandas as pd
from pathlib import Path

# Which MANE Select genes received no domain rows from the pipeline?
# Paths are relative to the working directory the pipeline script ran in.
mane_path = Path('data/raw/MANE.GRCh38.summary.txt.gz')
df_mane = pd.read_csv(mane_path, sep='\t', compression='gzip')
df_mane = df_mane.loc[df_mane['MANE_status'].eq('MANE Select')].copy()

output_path = Path('output/mane_domain_track_v2.parquet')
df_output = pd.read_parquet(output_path)

all_mane_genes = set(df_mane['symbol'].dropna())
genes_with_domains = set(df_output['gene_symbol'].dropna())
genes_without_domains = all_mane_genes.difference(genes_with_domains)

# First MANE row per gene symbol, used for the transcript/RefSeq lookups below.
first_row_by_gene = df_mane.drop_duplicates(subset='symbol', keep='first').set_index('symbol')

print(f"Total MANE Select genes: {len(all_mane_genes)}")
print(f"Genes with domains: {len(genes_with_domains)}")
print(f"Genes without domains: {len(genes_without_domains)}")
print("\nFirst 10 examples of genes without domains:")
for rank, gene in enumerate(sorted(genes_without_domains)[:10], start=1):
    gene_row = first_row_by_gene.loc[gene]
    print(f"{rank}. {gene}")
    print(f"   Transcript: {gene_row.get('Ensembl_nuc', 'N/A')}")
    print(f"   RefSeq: {gene_row.get('RefSeq_nuc', 'N/A')}")
    print()

Total MANE Select genes: 19338
Genes with domains: 7410
Genes without domains: 11928

First 10 examples of genes without domains:
1. A1BG
   Transcript: ENST00000263100.8
   RefSeq: NM_130786.4

2. A1CF
   Transcript: ENST00000373997.8
   RefSeq: NM_014576.4

3. A2ML1
   Transcript: ENST00000299698.12
   RefSeq: NM_144670.6

4. AAAS
   Transcript: ENST00000209873.9
   RefSeq: NM_015665.6

5. AACS
   Transcript: ENST00000316519.11
   RefSeq: NM_023928.5

6. AADACL2
   Transcript: ENST00000356517.4
   RefSeq: NM_207365.4

7. AADAT
   Transcript: ENST00000337664.9
   RefSeq: NM_016228.4

8. AAGAB
   Transcript: ENST00000261880.10
   RefSeq: NM_024666.5

9. AAK1
   Transcript: ENST00000409085.9
   RefSeq: NM_014911.5

10. AAMDC
   Transcript: ENST00000393427.7
   RefSeq: NM_024684.4



In [2]:
# DOMAIN INFORMATION: load the MANE domain track written by the pipeline above and
# collapse it to one row per (version-stripped) Ensembl transcript.
spark_session = hl.utils.java.Env.spark_session()

# Read domain parquet
domain_df = spark_session.read.parquet(f'{my_bucket}/output/mane_domain_track_v2.parquet')

# Convert to Hail table
domain_ht = hl.Table.from_spark(domain_df)

# Strip the version (ENST00000263100.8 -> ENST00000263100) for joining; Hail's split
# takes a regex, hence the escaped dot.
domain_ht = domain_ht.annotate(
    transcript_id = domain_ht.transcript_id_ensembl.split('\\.')[0]
)

# group_by groups on transcript_id directly, so the former key_by (an extra full sort) is
# dropped. hl.agg.collect does not guarantee element order, so each transcript's domains
# are sorted by (start, end, InterPro ID) to make the array deterministic.
domain_struct = hl.struct(
    domain_id = domain_ht.domain_id_interpro,
    domain_name = domain_ht.domain_name,
    domain_type = domain_ht.domain_type,
    source_db = domain_ht.source_db,
    start_aa = domain_ht.domain_start_aa,
    end_aa = domain_ht.domain_end_aa
)
domain_agg = domain_ht.group_by('transcript_id').aggregate(
    domains = hl.sorted(
        hl.agg.collect(domain_struct),
        key=lambda d: hl.tuple([d.start_aa, d.end_aa, d.domain_id])
    )
)

domain_agg = domain_agg.checkpoint(f'{my_bucket}/tmp/domains.ht', overwrite=True)

[Stage 2:>                                                          (0 + 1) / 1]

In [2]:
# Inspect the schema of the raw prediction table aggregated in the next cell.
preds_path = '/local/Missense_Predictor_copy/Results/Inference/Predictions/AOU_RGC_All_preds.ht'
preds_ht = hl.read_table(preds_path)
preds_ht.describe()

----------------------------------------
Global fields:
    None
----------------------------------------
Row fields:
    'ID_38': str 
    'genename': str 
    'alleles': array<str> 
    'Ensembl_transcriptid': str 
    'aapos': int32 
    'aaref': str 
    'Uniprot_acc': str 
    'aaalt': str 
    'plddt_lt_60': bool 
    'Constraint_1000_General_pred': float32 
    'Constraint_1000_General_pred_std': float32 
    'Constraint_1000_General_label': float32 
    'Constraint_1000_General_n_pred': int32 
    'Core_1000_General_pred': float32 
    'Core_1000_General_pred_std': float32 
    'Core_1000_General_label': float32 
    'Core_1000_General_n_pred': int32 
    'Complete_1000_General_pred': float32 
    'Complete_1000_General_pred_std': float32 
    'Complete_1000_General_label': float32 
    'Complete_1000_General_n_pred': int32 
    'Constraint_200_rgc_zero_General_pred': float32 
    'Constraint_200_rgc_zero_General_pred_std': float32 
    'Constraint_200_rgc_zero_General_label': 

In [3]:
# CONSTRAINT PREDICTIONS: one row per (Ensembl_transcriptid, locus). `preds` holds, for each
# model (Constraint, Core, Complete), an array of (alt allele, pred, n_pred) tuples — one per
# variant at that locus — sorted by alt allele.
# Source: /local/Missense_Predictor_copy/Results/Inference/Predictions/AOU_RGC_All_preds.ht
preds_ht = hl.read_table('/local/Missense_Predictor_copy/Results/Inference/Predictions/AOU_RGC_All_preds.ht')

# Create locus from ID_38 (assumed "<chrom>-<pos>-<ref>-<alt>" without the 'chr' prefix)
preds_ht = preds_ht.annotate(
    locus = hl.locus('chr' + preds_ht.ID_38.split('-')[0], hl.int(preds_ht.ID_38.split('-')[1]), reference_genome='GRCh38')
)

# Select the prediction columns we need
preds_ht = preds_ht.select(
    'locus',
    'alleles',
    'Ensembl_transcriptid',
    'Constraint_1000_General_pred',
    'Constraint_1000_General_n_pred',
    'Core_1000_General_pred',
    'Core_1000_General_n_pred',
    'Complete_1000_General_pred',
    'Complete_1000_General_n_pred'
)

# NOTE: the former `Const_Core_diff_1000_General_pred` annotation
# (Constraint pred - Core pred) was never used in the aggregation below and has been
# removed; compute it here and add it to `preds` if it is needed downstream.

# No order_by before group_by: a full-table sort is expensive and hl.agg.collect does not
# guarantee element order anyway. Sorting each small collected array is cheap and
# deterministic (tuples compare by alt allele first, then pred, then n_pred).
preds_alt = preds_ht.alleles[1]
preds_agg = preds_ht.group_by('Ensembl_transcriptid', 'locus').aggregate(
    preds = hl.struct(
        Constraint = hl.sorted(hl.agg.collect(hl.tuple([preds_alt, preds_ht.Constraint_1000_General_pred, preds_ht.Constraint_1000_General_n_pred]))),
        Core = hl.sorted(hl.agg.collect(hl.tuple([preds_alt, preds_ht.Core_1000_General_pred, preds_ht.Core_1000_General_n_pred]))),
        Complete = hl.sorted(hl.agg.collect(hl.tuple([preds_alt, preds_ht.Complete_1000_General_pred, preds_ht.Complete_1000_General_n_pred])))
    )
)

preds_agg = preds_agg.checkpoint(f'{my_bucket}/tmp/preds.ht', overwrite=True)



In [None]:
# Base table: every other source below is left-joined onto this one.
# Rekeyed by (locus, transcript_id) so per-transcript sources can join on both fields.
base_ht = hl.read_table(f'{my_bucket}/tmp/constraint_metrics_by_locus_rgc_glaf.ht')
base_ht = base_ht.key_by('locus', 'transcript_id')

# Coverage tables, joined by locus.
# NOTE: despite its variable name, the "genomes" table is read from a gnomAD v3 path and
# exposes a `gnomADV3_coverage` field.
gnomadV4_exomes_coverage = hl.read_table('/storage/zoghbi/data/sharing/hail_tables/gnomadV4_exomes_coverage_struct.ht')
gnomadV4_genomes_coverage = hl.read_table('/storage/zoghbi/data/sharing/hail_tables/gnomadV3_coverage_struct.ht')
# NOTE(review): `all_sites` is not used in this cell; kept in case later cells rely on it.
all_sites = hl.read_table(f'{my_bucket}/rgc_scaled.ht')

base_ht = base_ht.annotate(
    gnomad_exomes_over_20 = gnomadV4_exomes_coverage[base_ht.locus].gnomADV4_coverage.over_20,
    gnomad_genomes_over_20 = gnomadV4_genomes_coverage[base_ht.locus].gnomADV3_coverage.over_20
)

# Load preprocessed tables
clinvar_ht = hl.read_table(f'{my_bucket}/tmp/clinvar_by_locus.ht')
train_ht = hl.read_table(f'{my_bucket}/tmp/train_data.ht')
dbnsfp_ht = hl.read_table(f'{my_bucket}/tmp/dbnsfp_scores.ht')
domain_ht = hl.read_table(f'{my_bucket}/tmp/domains.ht')
preds_ht = hl.read_table(f'{my_bucket}/tmp/preds.ht')

# Merge ClinVar (by locus)
clinvar_ht = clinvar_ht.key_by('locus')
merged = base_ht.annotate(
    clinvar = clinvar_ht[base_ht.locus]
)

# Merge training labels (by locus)
train_ht = train_ht.key_by('locus')
merged = merged.annotate(
    training = train_ht[merged.locus]
)

# Merge dbNSFP scores (by locus, transcript)
dbnsfp_ht = dbnsfp_ht.key_by('locus', 'Ensembl_transcriptid')
merged = merged.annotate(
    dbnsfp = dbnsfp_ht[merged.locus, merged.transcript_id]
)

# Merge domain information (by transcript)
domain_ht = domain_ht.key_by('transcript_id')
merged = merged.annotate(
    domains = domain_ht[merged.transcript_id].domains
)

# Merge constraint predictions (by transcript, locus).
# `preds` is a struct with fields Constraint / Core / Complete, each an array of
# (alt_allele, pred, n_pred) tuples as built by the predictions cell.
preds_ht = preds_ht.key_by('Ensembl_transcriptid', 'locus')
merged = merged.annotate(
    preds = preds_ht[merged.transcript_id, merged.locus].preds
)

# Also expose each model's array as a top-level field (the `preds` struct itself is kept).
merged = merged.annotate(**merged.preds)

# Checkpoint the merged table
merged = merged.checkpoint(f'{my_bucket}/tmp/merged_browser_data.ht', overwrite=True)
print(f"Merged table row count: {merged.count()}")



Merged table row count: 33387608


In [None]:
# Export the merged browser table to Parquet for the downstream preprocessing scripts.
# CHR_FILTER: a contig name (e.g. 'chr2') for a quick single-chromosome export,
# or None to export every chromosome.
CHR_FILTER = 'chr2'  # Set to None for all chromosomes

merged_ht = hl.read_table(f'{my_bucket}/tmp/merged_browser_data.ht')

if not CHR_FILTER:
    output_file = f'{my_bucket}/rgc_browser_data_merged.parquet'
else:
    print(f"Filtering to {CHR_FILTER}...")
    merged_ht = merged_ht.filter(merged_ht.locus.contig == CHR_FILTER)
    print(f"Rows after filter: {merged_ht.count()}")
    output_file = f'{my_bucket}/rgc_browser_data_{CHR_FILTER}.parquet'

# One partition -> Spark writes a single part-file inside the output directory.
merged_repartitioned = merged_ht.repartition(1)
spark_df = merged_repartitioned.to_spark()
(
    spark_df.write
    .mode('overwrite')
    .parquet(output_file)
)

print(f"Exported to: {output_file}")

In [8]:
# Inspect the schema of the base constraint-metrics table.
# NOTE: this rebinds `base_ht` to the table as stored (not rekeyed by (locus, transcript_id)
# as in the merge cell), so re-run the merge cell before relying on that keying.
base_ht = hl.read_table(f'{my_bucket}/tmp/constraint_metrics_by_locus_rgc_glaf.ht')
base_ht.describe()


----------------------------------------
Global fields:
    None
----------------------------------------
Row fields:
    'region': str 
    'locus': locus<GRCh38> 
    'rgc_any_prob_mu_exomes_XX_XY': float64 
    'rgc_any_obs_exomes_XX_XY': int64 
    'rgc_any_count': int64 
    'rgc_any_max_af': float64 
    'rgc_syn_prob_mu_exomes_XX_XY': float64 
    'rgc_syn_obs_exomes_XX_XY': int64 
    'rgc_syn_max_af': float64 
    'rgc_syn_count': int64 
    'rgc_mis_prob_mu_exomes_XX_XY': float64 
    'rgc_mis_obs_exomes_XX_XY': int64 
    'rgc_mis_max_af': float64 
    'rgc_mis_count': int64 
    'rgc_stop_gained_prob_mu_exomes_XX_XY': float64 
    'rgc_stop_gained_obs_exomes_XX_XY': int64 
    'rgc_stop_gained_max_af': float64 
    'rgc_stop_gained_count': int64 
    'aa_pos': float64 
    'rgc_mis_exomes_XX_XY_vir_length_af0epos00': int32 
    'rgc_mis_exomes_XX_XY_vir_mu_exp_af0epos00': float64 
    'rgc_mis_exomes_XX_XY_vir_depth_af0epos00': float64 
    'rgc_mis_exomes_XX_XY_mean_vir_ex

In [11]:
!cd gosling_mvp && python preprocess_mis_all.py 

COMPRESSED GENOME VIEWER - DATA PREPROCESSING
Generating axis tables for all filter modes

Loading data from ../rgc_browser_data_merged.parquet...
✓ Loaded 33,387,608 total positions
  Columns: 278

Detected columns:
  Chromosome: chrom
  Position: pos
  Gene: HGNC

Chromosome distribution (top 10):
  chr1: 3,414,807
  chr2: 2,477,732
  chr19: 2,149,560
  chr11: 1,980,678
  chr3: 1,921,014
  chr17: 1,918,143
  chr12: 1,732,842
  chr6: 1,690,167
  chr7: 1,565,442
  chr5: 1,562,490

Processing filter: any_count_gt0
  All positions where any variant is possible

Applying filter...
✓ Kept 33,387,608 positions
Sorting by chromosome and position...
Generating compressed coordinates...

Column breakdown:
  - Core (idx, chrom, pos, gene): 4
  - RGC raw metrics: 142
  - ClinVar: 4
  - Training labels: 4
  - dbNSFP scores: 9
  - Constraint predictions: 3
  - Percentiles: 105
  - Domains: 1
  Total columns: 273

Saving axis table...
✓ Saved: data/any_count_gt0.parquet
  Size: 15434.78 MB

Generat