In [1]:
import hail as hl
from scipy.optimize import minimize
import numpy as np

# Local Spark sizing for this session.
cpus = 128
# Memory in GB, computed as 3600 * cpus / 256 (1800 for cpus = 128).
# NOTE(review): presumably 3600 GB / 256 cores describes the host machine — confirm.
memory = int(3600*cpus/256)


# Scratch directory: used for Spark local files, Hail temp files, and the
# intermediate checkpoints written by later cells.
tmpdir = '/local/tmp'

# Spark settings passed to hl.init(); driver and executor get the same memory value.
config = {
    'spark.driver.memory': f'{memory}g',  # Set to total memory
    'spark.executor.memory': f'{memory}g',
    'spark.local.dir': tmpdir,
    'spark.ui.enabled': 'false'
}

# Start Hail on a local Spark master (local[cpus]) with all temp dirs pointed at tmpdir.
hl.init(spark_conf=config, master=f'local[{cpus}]', tmp_dir=tmpdir, local_tmpdir=tmpdir)

hl.plot.output_notebook()
%matplotlib inline

# Root directory for reference files and for input/output tables used throughout the notebook.
my_bucket = '/storage/zoghbi/home/u235147/merged_vars'

rg38 = hl.get_reference('GRCh38')

# One-time downloads of the GRCh38 FASTA and its index (left commented out).
#!wget -P {my_bucket} https://storage.googleapis.com/hail-common/references/Homo_sapiens_assembly38.fasta.fai
#!wget -P {my_bucket} https://storage.googleapis.com/hail-common/references/Homo_sapiens_assembly38.fasta.gz

# Attach the reference sequence (FASTA + index) to the GRCh38 reference genome object.
rg38.add_sequence(f'{my_bucket}/Homo_sapiens_assembly38.fasta.gz',
                  f'{my_bucket}/Homo_sapiens_assembly38.fasta.fai')

# GRCh37 -> GRCh38 liftover setup (currently disabled).
#!wget -P {my_bucket} https://storage.googleapis.com/hail-common/references/grch37_to_grch38.over.chain.gz
# rg37 = hl.get_reference('GRCh37')
# rg37.add_liftover(f'{my_bucket}/grch37_to_grch38.over.chain.gz', rg38)

Running on Apache Spark version 3.5.4
Welcome to
     __  __     <>__
    / /_/ /__  __/ /
   / __  / _ `/ / /
  /_/ /_/\_,_/_/_/   version 0.2.134-952ae203dbbe
LOGGING: writing to /storage/zoghbi/home/u235147/VarPredBrowser/notebooks/hail-20260110-1722-0.2.134-952ae203dbbe.log


In [4]:
# CLINVAR DATA
# Import the ClinVar VCF rows, keep records that have at least one alternate allele,
# and collapse them to one row per locus holding an array of per-variant structs
# (ref, alt, significance, review status, molecular consequence, variation id).
# NOTE(review): no P/LP or missense filter is applied in this cell — every imported
# record with an alt allele is collected. Confirm whether filtering happens downstream.
clinvar_file = 'clinvar_20251013.vcf.gz'
#!wget -P {my_bucket} https://ftp.ncbi.nlm.nih.gov/pub/clinvar/vcf_GRCh38/weekly/{clinvar_file}

# Map unprefixed contig names ('1'..'22', 'X', 'Y', 'MT') to 'chr'-prefixed names.
recode = {str(contig): f"chr{contig}" for contig in [*range(1, 23), 'X', 'Y']}
recode['MT'] = 'chrM'

clinvar_mt = hl.import_vcf(
    f'{my_bucket}/{clinvar_file}',
    reference_genome='GRCh38',
    contig_recoding=recode,
    skip_invalid_loci=True,
    force_bgz=True,
)
clinvar_data = clinvar_mt.rows()

# Drop records that carry only a reference allele (alleles[1] is read below).
has_alt_allele = hl.len(clinvar_data.alleles) > 1
clinvar_data = clinvar_data.filter(has_alt_allele)

clinvar_info = clinvar_data.info
mc_entries = clinvar_info.MC
mc_available = hl.is_defined(mc_entries) & (hl.len(mc_entries) > 0)

# One struct per ClinVar record; list-valued INFO fields are joined with commas.
clinvar_record = hl.struct(
    ref = clinvar_data.alleles[0],
    alt = clinvar_data.alleles[1],
    significance = hl.delimit(clinvar_info.CLNSIG, ','),
    status = hl.delimit(clinvar_info.CLNREVSTAT, ','),
    # Second '|'-separated token of the first MC entry, or missing when MC is absent/empty
    mol_csq = hl.if_else(
        mc_available,
        mc_entries[0].split('|')[1],
        hl.missing(hl.tstr)
    ),
    variation_id = clinvar_data.rsid
)

clinvar_by_locus = clinvar_data.group_by('locus').aggregate(
    clinvar_variants = hl.agg.collect(clinvar_record)
)

clinvar_by_locus = clinvar_by_locus.checkpoint(f'{tmpdir}/clinvar_by_locus.ht', overwrite=True)

[Stage 2:>                                                          (0 + 1) / 1]

In [None]:
# TRAINING LABELS
# Read the training-variant CSV, derive a GRCh38 locus from ID_38, and count per locus
# how many rows fall into each label bucket (all rows and High_Qual == 'True' rows).
train_raw = hl.import_table('/local/Missense_Predictor_copy/_rgc_train_vars.csv', delimiter=',')

# ID_38 is split on '-'; token 0 is used as the contig (prefixed with 'chr'), token 1 as the position.
id_tokens = train_raw.ID_38.split('-')
train_keyed = train_raw.key_by(
    locus = hl.locus('chr' + id_tokens[0], hl.int(id_tokens[1]), reference_genome='GRCh38')
)

# Columns are compared as strings (import_table is called without type imputation).
is_unlabelled = train_keyed.label == 'unlabelled'
is_labelled = train_keyed.label == 'labelled'
is_high_qual = train_keyed.High_Qual == 'True'

train_data = train_keyed.group_by('locus').aggregate(
    train_counts = hl.struct(
        unlabelled = hl.agg.count_where(is_unlabelled),
        labelled = hl.agg.count_where(is_labelled),
        unlabelled_high_qual = hl.agg.count_where(is_unlabelled & is_high_qual),
        labelled_high_qual = hl.agg.count_where(is_labelled & is_high_qual)
    )
)
train_data = train_data.checkpoint(f'{tmpdir}/train_data.ht', overwrite=True)

[Stage 6:>                                                          (0 + 1) / 1]

In [8]:
# DBNSFP SCORES
# Reduce the variant-level dbNSFP table to one row per (locus, transcript), keeping
# the maximum of each selected score. Percentile columns are intentionally not
# selected here (per the original notes they are computed later in Polars).

ht = hl.read_table('/local/Missense_Predictor_copy/Data/dbnsfp/All_missense_with_impute_mane_select_final_with_perc_with_impute_con_perc.ht')

# Score fields to carry forward; any name not present in the table is skipped silently.
score_cols = [
    'AlphaMissense_am_pathogenicity',
    'ESM1b_score',
    'RGC_MTR.MTR',
    'Non_Neuro_CCR.resid_pctile',
    'AlphaSync.plddt',
    'AlphaSync.plddt10',
    'AlphaSync.relasa',
    'AlphaSync.relasa10',
]

# Rebuild a GRCh38 locus from the flattened 'locus.contig' / 'locus.position' fields.
ht = ht.annotate(
    locus = hl.locus(ht['locus.contig'], ht['locus.position'], reference_genome='GRCh38')
)

# Output name -> source expression; dots in field names are replaced with underscores.
all_cols = list(ht.row)
select_dict = {
    source_name.replace('.', '_'): ht[source_name]
    for source_name in score_cols
    if source_name in all_cols
}

ht_selected = ht.select(
    'locus',
    'alleles',
    'Ensembl_transcriptid',
    **select_dict
)

# One max_<score> aggregator per selected score.
agg_dict = {}
for clean_name in select_dict:
    agg_dict[f'max_{clean_name}'] = hl.agg.max(ht_selected[clean_name])

dbnsfp_ht = ht_selected.group_by('locus', 'Ensembl_transcriptid').aggregate(**agg_dict)

dbnsfp_ht = dbnsfp_ht.checkpoint(f'{tmpdir}/dbnsfp_scores.ht', overwrite=True)
print(f"dbNSFP scores table: {dbnsfp_ht.count()} rows, {len(list(dbnsfp_ht.row))} columns")



dbNSFP scores table: 27778551 rows, 10 columns


In [None]:
# DBNSFP STACKED SCORES
# Keep the individual per-variant AlphaMissense and ESM1b scores (with their exome
# percentiles) as arrays of {alt, score, percentile} structs per (locus, transcript).
# This sits alongside the max-per-locus table built from the same source.

ht_stacked = hl.read_table('/local/Missense_Predictor_copy/Data/dbnsfp/All_missense_with_impute_mane_select_final_with_perc_with_impute_con_perc.ht')

# Rebuild a GRCh38 locus from the flattened contig/position fields.
ht_stacked = ht_stacked.annotate(
    locus = hl.locus(ht_stacked['locus.contig'], ht_stacked['locus.position'], reference_genome='GRCh38')
)

# Only the fields needed for the stacked arrays.
ht_stacked = ht_stacked.select(
    'locus',
    'alleles',
    'Ensembl_transcriptid',
    'AlphaMissense_am_pathogenicity',
    'AlphaMissense_am_pathogenicity_exome_perc',
    'ESM1b_score',
    'ESM1b_score_exome_perc'
)


def _collect_alt_scores(table, score_field, percentile_field):
    """Aggregator collecting {alt, score, percentile} structs for one predictor.

    `alt` is alleles[1] of each row; `score_field` / `percentile_field` name the
    row fields of `table` that supply the other two values.
    """
    return hl.agg.collect(hl.struct(
        alt=table.alleles[1],
        score=table[score_field],
        percentile=table[percentile_field]
    ))


dbnsfp_stacked_ht = ht_stacked.group_by('locus', 'Ensembl_transcriptid').aggregate(
    dbnsfp_stacked = hl.struct(
        AlphaMissense = _collect_alt_scores(
            ht_stacked,
            'AlphaMissense_am_pathogenicity',
            'AlphaMissense_am_pathogenicity_exome_perc'
        ),
        ESM1b = _collect_alt_scores(
            ht_stacked,
            'ESM1b_score',
            'ESM1b_score_exome_perc'
        )
    )
)

dbnsfp_stacked_ht = dbnsfp_stacked_ht.checkpoint(f'{tmpdir}/dbnsfp_stacked.ht', overwrite=True)
print(f"dbNSFP stacked table row count: {dbnsfp_stacked_ht.count()}")



dbNSFP stacked table row count: 27778551


In [6]:
import polars
import pandas as pd
from pathlib import Path

# Report how many MANE Select genes have at least one row in the domain-track parquet.

# MANE summary table, restricted to rows whose MANE_status is 'MANE Select'
mane_path = Path(f'{my_bucket}/data/raw/MANE.GRCh38.summary.txt.gz')
mane_all = pd.read_csv(mane_path, sep='\t', compression='gzip')
df_mane = mane_all.loc[mane_all['MANE_status'] == 'MANE Select'].copy()

# Domain-track output
output_path = Path(f'{my_bucket}/output/mane_domain_track_v2.parquet')
df_output = pd.read_parquet(output_path)

# Gene symbol sets (nulls dropped) and their difference
all_mane_genes = set(df_mane['symbol'].dropna().unique())
genes_with_domains = set(df_output['gene_symbol'].dropna().unique())
genes_without_domains = all_mane_genes - genes_with_domains

print(f"Total MANE Select genes: {len(all_mane_genes)}")
print(f"Genes with domains: {len(genes_with_domains)}")
print(f"Genes without domains: {len(genes_without_domains)}")
print("\nFirst 10 examples of genes without domains:")

# Alphabetically first ten genes lacking domains, with transcript identifiers
# taken from the first matching MANE row.
example_genes = sorted(genes_without_domains)[:10]
for rank, gene in enumerate(example_genes, start=1):
    gene_info = df_mane.loc[df_mane['symbol'] == gene].iloc[0]
    print(f"{rank}. {gene}")
    print(f"   Transcript: {gene_info.get('Ensembl_nuc', 'N/A')}")
    print(f"   RefSeq: {gene_info.get('RefSeq_nuc', 'N/A')}")
    print()

Total MANE Select genes: 19338
Genes with domains: 7411
Genes without domains: 11927

First 10 examples of genes without domains:
1. A1BG
   Transcript: ENST00000263100.8
   RefSeq: NM_130786.4

2. A1CF
   Transcript: ENST00000373997.8
   RefSeq: NM_014576.4

3. A2ML1
   Transcript: ENST00000299698.12
   RefSeq: NM_144670.6

4. AAAS
   Transcript: ENST00000209873.9
   RefSeq: NM_015665.6

5. AACS
   Transcript: ENST00000316519.11
   RefSeq: NM_023928.5

6. AADACL2
   Transcript: ENST00000356517.4
   RefSeq: NM_207365.4

7. AADAT
   Transcript: ENST00000337664.9
   RefSeq: NM_016228.4

8. AAGAB
   Transcript: ENST00000261880.10
   RefSeq: NM_024666.5

9. AAK1
   Transcript: ENST00000409085.9
   RefSeq: NM_014911.5

10. AAMDC
   Transcript: ENST00000393427.7
   RefSeq: NM_024684.4



In [None]:
# DOMAIN INFORMATION
# Build a per-transcript array of representative domains by joining the
# transcript -> UniProt mapping (parquet) with the cached InterPro JSON, then
# push the records through Spark into a Hail table.
import json
import polars as pl

# Spark session backing the running Hail context
spark_session = hl.utils.java.Env.spark_session()

# Distinct (transcript, UniProt accession) pairs
mapping_df = (
    pl.read_parquet(f'{my_bucket}/output/mane_domain_track_v2.parquet')
    .select(['transcript_id_ensembl', 'protein_id_uniprot'])
    .unique()
)
print(f"Loaded {len(mapping_df)} transcript-to-UniProt mappings")

# Cached representative domains; the JSON object is looked up by UniProt id below
with open('/storage/zoghbi/home/u235147/VarPredBrowser/data/cache/interpro_representative_domains.json', 'r') as f:
    rep_domains = json.load(f)
print(f"Loaded representative domains for {len(rep_domains)} proteins")

# One flat record per (transcript, domain); transcripts whose protein has no cached
# domains contribute nothing.
domain_records = []
for mapping in mapping_df.iter_rows(named=True):
    transcript_id = mapping['transcript_id_ensembl'].split('.')[0]  # drop the '.N' version suffix
    uniprot_id = mapping['protein_id_uniprot']
    if uniprot_id not in rep_domains:
        continue
    domain_records.extend(
        {
            'transcript_id': transcript_id,
            # Fall back to the member-database id when no InterPro id is present
            'domain_id': d.get('interpro_id') or d.get('member_db_id', ''),
            'domain_name': d.get('domain_name', ''),
            'domain_type': d.get('domain_type', ''),
            'source_db': d.get('source_db', ''),
            'start_aa': d.get('start_aa', 0),
            'end_aa': d.get('end_aa', 0),
        }
        for d in rep_domains[uniprot_id]
    )

print(f"Created {len(domain_records)} domain records")

# Polars -> pandas -> Spark -> Hail
domain_df = pl.DataFrame(domain_records)
spark_df = spark_session.createDataFrame(domain_df.to_pandas())
domain_ht = hl.Table.from_spark(spark_df)

# Collapse to one row per transcript with an array of domain structs
domain_ht = domain_ht.key_by('transcript_id')
domain_struct = hl.struct(
    domain_id = domain_ht.domain_id,
    domain_name = domain_ht.domain_name,
    domain_type = domain_ht.domain_type,
    source_db = domain_ht.source_db,
    start_aa = domain_ht.start_aa,
    end_aa = domain_ht.end_aa
)
domain_agg = domain_ht.group_by('transcript_id').aggregate(
    domains = hl.agg.collect(domain_struct)
)

domain_agg = domain_agg.checkpoint(f'{tmpdir}/domains.ht', overwrite=True)
print(f"Domain aggregation table row count: {domain_agg.count()}")

Loaded 7411 transcript-to-UniProt mappings
Loaded representative domains for 16255 proteins
Created 20913 domain records




Domain aggregation table row count: 7217


In [None]:
# CONSTRAINT PREDICTIONS
# Load per-variant model predictions and collapse them to one row per
# (transcript, locus), with one array of {alt, pred, n_pred} structs per model.
# Source: /local/Missense_Predictor_copy/Results/Inference/Predictions/AOU_RGC_All_preds.ht

preds_ht = hl.read_table('/local/Missense_Predictor_copy/Results/Inference/Predictions/AOU_RGC_All_preds.ht')

# Create locus from ID_38 (format: chr-pos-ref-alt); the contig token gets a 'chr' prefix
preds_ht = preds_ht.annotate(
    locus = hl.locus('chr' + preds_ht.ID_38.split('-')[0], hl.int(preds_ht.ID_38.split('-')[1]), reference_genome='GRCh38')
)

# Select the prediction columns we need
preds_ht = preds_ht.select(
    'locus',
    'alleles',
    'Ensembl_transcriptid',
    'Constraint_200_otu_No_RGC_aapos_AON4_pred',
    'Constraint_200_otu_No_RGC_aapos_AON4_n_pred',
    'Core_200_otu_No_RGC_aapos_AON4_pred',
    'Core_200_otu_No_RGC_aapos_AON4_n_pred',
    'Complete_200_otu_No_RGC_aapos_AON4_pred',
    'Complete_200_otu_No_RGC_aapos_AON4_n_pred'
)

# Constraint-minus-Core difference. Not consumed by the aggregation below.
preds_ht = preds_ht.annotate(
    Const_Core_diff_200_otu_No_RGC_aapos_AON4_pred = preds_ht.Constraint_200_otu_No_RGC_aapos_AON4_pred - preds_ht.Core_200_otu_No_RGC_aapos_AON4_pred
)


def _collect_model_preds(table, model):
    """Aggregator yielding {alt, pred, n_pred} structs for `model`, sorted by alt allele.

    hl.agg.collect does not guarantee element order, so sorting the table before
    group_by cannot be relied on to order the collected arrays. The arrays are
    sorted explicitly instead, which makes the per-locus order deterministic and
    identical across the three models.
    """
    prefix = f'{model}_200_otu_No_RGC_aapos_AON4'
    collected = hl.agg.collect(hl.struct(
        alt=table.alleles[1],
        pred=table[f'{prefix}_pred'],
        n_pred=table[f'{prefix}_n_pred']
    ))
    return hl.sorted(collected, key=lambda rec: rec.alt)


# Group by (transcript, locus); no table-level order_by is needed because each
# collected array is sorted inside the aggregation.
preds_agg = preds_ht.group_by('Ensembl_transcriptid', 'locus').aggregate(
    preds = hl.struct(
        Constraint = _collect_model_preds(preds_ht, 'Constraint'),
        Core = _collect_model_preds(preds_ht, 'Core'),
        Complete = _collect_model_preds(preds_ht, 'Complete')
    )
)

preds_agg = preds_agg.checkpoint(f'{tmpdir}/preds.ht', overwrite=True)



In [None]:
# VARIANT CONSEQUENCES (Extract from rgc_scaled.ht and group by locus)
# Source: /storage/zoghbi/home/u235147/merged_vars/rgc_scaled.ht
# Field: most_deleterious_consequence_cds
# REFACTORED: Using hl.struct() instead of hl.tuple() for named field access

# Define consequence categorization function
def categorize_consequence(csq):
    """
    Collapse a VEP consequence term into a coarse category.

    Returns a Hail string expression that is 'plof', 'missense' or 'synonymous'
    when `csq` is a member of the corresponding term set (checked in that
    order), and 'other' otherwise.
    """
    category_terms = (
        ('plof', {'frameshift_variant', 'stop_gained', 'splice_donor_variant',
                  'splice_acceptor_variant', 'start_lost', 'stop_lost'}),
        ('missense', {'missense_variant', 'inframe_insertion', 'inframe_deletion',
                      'protein_altering_variant'}),
        ('synonymous', {'synonymous_variant'}),
    )

    # Chain one .when() per category; the first matching branch wins.
    builder = hl.case()
    for category, terms in category_terms:
        builder = builder.when(hl.set(terms).contains(csq), category)
    return builder.default('other')

# Load rgc_scaled.ht which contains consequence annotations
print("Loading rgc_scaled.ht for variant consequences...")
rgc_scaled = hl.read_table(f'{my_bucket}/rgc_scaled.ht')

# Keep only rows with a defined CDS consequence, and only the fields needed below:
# `region` (carries the transcript id) and the raw consequence term. Key fields
# (locus / alleles are read below) are retained automatically by select().
csq_ht = rgc_scaled.filter(hl.is_defined(rgc_scaled.most_deleterious_consequence_cds))
csq_ht = csq_ht.select(
    'region',  # Contains transcript-position info
    csq_raw = csq_ht.most_deleterious_consequence_cds
)

# Add categorized consequence (plof / missense / synonymous / other)
csq_ht = csq_ht.annotate(
    csq_category = categorize_consequence(csq_ht.csq_raw)
)

# Extract transcript ID from region (format: ENST00000123456-100)
csq_ht = csq_ht.annotate(
    transcript_id = csq_ht.region.split('-')[0]
)

# Group by (transcript, locus) and collect structs with named fields: {alt, csq}
csq_agg = csq_ht.group_by('transcript_id', 'locus').aggregate(
    variant_consequences = hl.agg.collect(hl.struct(
        alt=csq_ht.alleles[1],  # Alternate allele
        csq=csq_ht.csq_category
    ))
)

# Checkpoint for efficiency.
# The merge step reads f'{tmpdir}/variant_consequences.ht', so the table must be
# written to exactly that path (previously it went to a '{tmpdir}/tmp/' subfolder
# that nothing reads, leaving the merge to pick up a stale or missing table).
csq_agg = csq_agg.checkpoint(f'{tmpdir}/variant_consequences.ht', overwrite=True)
print(f"Consequence table row count: {csq_agg.count()}")

# Show distribution of consequence categories (counted per variant, before grouping)
print("\nConsequence category distribution:")
csq_ht.group_by('csq_category').aggregate(n=hl.agg.count()).show()

Loading rgc_scaled.ht for variant consequences...




Consequence table row count: 33271388

Consequence category distribution:




csq_category,n
str,int64
"""missense""",72752726
"""plof""",4119703
"""synonymous""",22941731


In [None]:
# VEP AA_POS
# Derive an amino-acid position per (locus, transcript) from VEP's protein_start.
# NOTE(review): the original notes state protein_start already accounts for strand
# direction — not verifiable from this cell.

vep_source = hl.read_table('/storage/zoghbi/home/u235147/merged_vars/vep.ht')

# Transcript targeted by each row: the part of `regions` before the first '-'
row_transcript = vep_source.regions.split('-')[0]

# Keep only the VEP consequences whose (version-stripped) transcript id matches the
# row's transcript. Hail's split() takes a regex, hence the escaped dot.
vep_ht = vep_source.annotate(
    transcript_consequences = vep_source.vep.transcript_consequences.filter(
        lambda tc: tc.transcript_id.split('\\.')[0] == row_transcript
    ),
    transcript_id = row_transcript
)

# One row per remaining transcript consequence
vep_exploded = vep_ht.explode(vep_ht.transcript_consequences)

# Key by (locus, transcript_id) for joining with the base table; alleles are not part of the key.
rebuilt_locus = hl.locus(
    vep_exploded.locus.contig,
    vep_exploded.locus.position,
    reference_genome='GRCh38'
)
vep_exploded = vep_exploded.key_by(
    locus = rebuilt_locus,
    transcript_id = vep_exploded.transcript_id
)

# Keep just the key plus the amino-acid position
vep_aa_pos = vep_exploded.select(
    aa_pos_vep = vep_exploded.transcript_consequences.protein_start
)

# Checkpoint for performance
vep_aa_pos = vep_aa_pos.checkpoint(f'{tmpdir}/vep_aa_pos.ht', overwrite=True)
print(f"VEP aa_pos table row count: {vep_aa_pos.count()}")

[Stage 0:>                                                        (0 + 96) / 96]

In [2]:
# ALLELE FREQUENCIES
# Load the three variant-level allele-frequency sources (RGC, gnomAD exomes,
# gnomAD genomes), key each by (locus, alleles), and print their schemas.

print("Loading allele frequency tables...")
af_key = ('locus', 'alleles')
rgc_ht = hl.read_table(
    '/storage/zoghbi/data/sharing/hail_tables/no_anno/rgc.ht'
).key_by(*af_key)
gnomad_exomes_ht = hl.read_table(
    '/storage/zoghbi/data/sharing/hail_tables/gnomadV4_exomes/gnomadV4exomes_snvs_sex.ht'
).key_by(*af_key)
gnomad_genomes_ht = hl.read_table(
    '/storage/zoghbi/data/sharing/hail_tables/gnomadV4_genomes/gnomadV4genomes_snvs.ht'
).key_by(*af_key)

print("Tables loaded and keyed by (locus, alleles)")

# Schemas, printed for reference when writing the aggregation code
schema_listing = (
    ("\n=== RGC Table Schema ===", rgc_ht),
    ("\n=== gnomAD Exomes Table Schema ===", gnomad_exomes_ht),
    ("\n=== gnomAD Genomes Table Schema ===", gnomad_genomes_ht),
)
for banner, af_table in schema_listing:
    print(banner)
    af_table.describe()

Loading allele frequency tables...
Tables loaded and keyed by (locus, alleles)

=== RGC Table Schema ===
----------------------------------------
Global fields:
    None
----------------------------------------
Row fields:
    'locus': locus<GRCh38> 
    'alleles': array<str> 
    'rsid': str 
    'qual': float64 
    'filters': set<str> 
    'info': struct {
        AFR_AC: float64, 
        AFR_AF: float64, 
        AFR_AN: float64, 
        ALL_AC: int32, 
        ALL_AF: float64, 
        ALL_AN: int32, 
        AMI_AC: int32, 
        AMI_AF: float64, 
        AMI_AN: int32, 
        ASH_AC: float64, 
        ASH_AF: float64, 
        ASH_AN: float64, 
        BI_AC: float64, 
        BI_AF: float64, 
        BI_AN: float64, 
        C_EUR_AC: float64, 
        C_EUR_AF: float64, 
        C_EUR_AN: float64, 
        EAS_AC: float64, 
        EAS_AF: float64, 
        EAS_AN: float64, 
        EUR_AC: float64, 
        EUR_AF: float64, 
        EUR_AN: float64, 
        E_AFR_AC: f

In [None]:
# BASE TABLE
# Everything else is merged onto this table. It is read keyed by the fields printed
# below, trimmed to the columns of interest, and re-keyed by (locus, transcript_id).
base_ht = hl.read_table('/storage/zoghbi/home/u235147/merged_vars/tmp/constraint_metrics_by_locus_rgc_glaf.ht')

# Key fields are carried through select() automatically and must not be listed in it
key_fields = list(base_ht.key)
print(f"Base table key fields (auto-included): {key_fields}")

# Fixed identifying columns (transcript_id is a plain row field at this point)
core_cols = ['HGNC', 'chrom', 'pos', 'aa_pos', 'transcript_id']

# Candidate columns: every row field that is not part of the key, in table order
non_key_cols = [name for name in base_ht.row if name not in key_fields]

# Columns whose name contains 'oe' or 'vir'
rgc_oe_vir_cols = [name for name in non_key_cols if 'oe' in name or 'vir' in name]


def _is_rgc_count_col(name):
    """True for rgc_* count/obs/prob_mu columns, excluding the deprecated *_max_af ones."""
    if not name.startswith('rgc_'):
        return False
    if '_max_af' in name:
        return False
    return '_count' in name or '_obs_' in name or '_prob_mu' in name


rgc_count_cols = [name for name in non_key_cols if _is_rgc_count_col(name)]

cols_to_keep = core_cols + rgc_oe_vir_cols + rgc_count_cols
print(f"Keeping {len(cols_to_keep)} columns from base table (+ key fields):")
print(f"  - Core: {len(core_cols)}")
print(f"  - RGC oe/vir: {len(rgc_oe_vir_cols)}")
print(f"  - RGC count/obs: {len(rgc_count_cols)}")

base_ht = base_ht.select(*cols_to_keep).key_by('locus', 'transcript_id')

# Load join target tables (each is indexed by locus below)
gnomadV4_exomes_coverage = hl.read_table('/storage/zoghbi/data/sharing/hail_tables/gnomadV4_exomes_coverage_struct.ht')
# NOTE(review): the variable is named gnomadV4_* but the table read here is the
# gnomAD v3 coverage struct (and its field below is gnomADV3_coverage).
gnomadV4_genomes_coverage = hl.read_table('/storage/zoghbi/data/sharing/hail_tables/gnomadV3_coverage_struct.ht')
phylop_ht = hl.read_table('/storage/zoghbi/data/sharing/hail_tables/phyloPscores_hg38_final.ht')

# GNOMAD CONSTRAINT METRICS (Section 6 of browser-data-refactor.md)
# NOTE: This table has no key, must key by locus before joining
gnomad_constraint_ht = hl.read_table(f'{my_bucket}/constraint_metrics_by_locus.ht')
gnomad_constraint_ht = gnomad_constraint_ht.key_by('locus')  # Key by locus for joining

# Dynamically get gnomAD constraint columns: every row field whose name contains 'oe' or 'vir'
# NOTE: Source table already has gnomad_ prefix, so we copy columns as-is
gnomad_constraint_cols = [col for col in gnomad_constraint_ht.row if 'oe' in col or 'vir' in col]
print(f"Found {len(gnomad_constraint_cols)} gnomAD constraint columns (oe/vir)")

# OPTIMIZED JOINS: index each table once, park the matched row in a temporary
# underscore-prefixed struct field, then extract individual fields from it below.
base_ht = base_ht.annotate(
    _exome_cov = gnomadV4_exomes_coverage[base_ht.locus],
    _genome_cov = gnomadV4_genomes_coverage[base_ht.locus],
    _phylop = phylop_ht[base_ht.locus],
    _gnomad_constraint = gnomad_constraint_ht[base_ht.locus],
)

# GNOMAD COVERAGE - All thresholds (Task 2 of browser-data-refactor.md)
# Extract 12 total coverage columns: 6 thresholds x 2 cohorts
coverage_thresholds = [10, 15, 20, 25, 30, 50]
base_ht = base_ht.annotate(
    # Exome coverage (gnomAD v4) - 6 thresholds
    **{f'gnomad_exomes_over_{t}': base_ht._exome_cov.gnomADV4_coverage[f'over_{t}'] for t in coverage_thresholds},
    # Genome coverage (gnomAD v3) - 6 thresholds; explicitly cast to float64
    **{f'gnomad_genomes_over_{t}': hl.float64(base_ht._genome_cov.gnomADV3_coverage[f'over_{t}']) for t in coverage_thresholds},
    # PhyloP conservation scores (Task 6)
    phylop_scores_447way = base_ht._phylop.phylop_scores['447way'],
    phylop_scores_100way = base_ht._phylop.phylop_scores['100way'],
)

# Dynamically annotate gnomAD constraint columns (oe/vir)
# Source table already has gnomad_ prefix, so copy columns as-is (no additional prefix)
base_ht = base_ht.annotate(**{
    col: base_ht._gnomad_constraint[col] for col in gnomad_constraint_cols
})

# Drop temporary join structs now that their fields have been extracted
base_ht = base_ht.drop('_exome_cov', '_genome_cov', '_phylop', '_gnomad_constraint')

# ALLELE FREQUENCIES (Task 4 of browser-data-refactor.md)
# The AF tables are keyed by (locus, alleles) while the base table is joined on locus,
# so each AF table is first collapsed to one row per locus holding an array of
# per-variant structs.
print("Aggregating allele frequencies by locus...")


def _gnomad_variants_by_locus(table, cohort_field, out_field):
    """Collapse a gnomAD table to one row per locus.

    `cohort_field` names the struct on `table` that carries AC / AN / filters;
    `out_field` names the resulting array of {alt, ac, an, af, filters} structs.
    af is AC / AN when AN > 0 and missing otherwise.
    """
    cohort = table[cohort_field]
    per_variant = hl.struct(
        alt = table.alleles[1],
        ac = cohort.AC,
        an = cohort.AN,
        af = hl.if_else(
            cohort.AN > 0,
            cohort.AC / cohort.AN,
            hl.missing(hl.tfloat64)
        ),
        filters = cohort.filters
    )
    return table.group_by('locus').aggregate(**{out_field: hl.agg.collect(per_variant)})


# RGC: AF / AC / AN are read directly from the ALL_* INFO fields
rgc_info = rgc_ht.info
rgc_by_locus = rgc_ht.group_by('locus').aggregate(
    rgc_variants = hl.agg.collect(hl.struct(
        alt = rgc_ht.alleles[1],
        af = rgc_info.ALL_AF,
        ac = rgc_info.ALL_AC,
        an = rgc_info.ALL_AN,
        filters = rgc_ht.filters
    ))
)

# gnomAD exomes and genomes share one layout; only the cohort struct name differs
gnomad_exomes_by_locus = _gnomad_variants_by_locus(
    gnomad_exomes_ht, 'gnomadV4_exomes', 'gnomad_exomes_variants'
)
gnomad_genomes_by_locus = _gnomad_variants_by_locus(
    gnomad_genomes_ht, 'gnomadV4_genomes', 'gnomad_genomes_variants'
)

# Attach the three per-locus arrays to the base table
base_locus = base_ht.locus
base_ht = base_ht.annotate(
    rgc_variants = rgc_by_locus[base_locus].rgc_variants,
    gnomad_exomes_variants = gnomad_exomes_by_locus[base_locus].gnomad_exomes_variants,
    gnomad_genomes_variants = gnomad_genomes_by_locus[base_locus].gnomad_genomes_variants,
)

base_ht = base_ht.checkpoint(f'{tmpdir}/base_with_coverage_and_af.ht', overwrite=True)


# Load the preprocessed tables checkpointed by the earlier cells
clinvar_ht = hl.read_table(f'{tmpdir}/clinvar_by_locus.ht')
train_ht = hl.read_table(f'{tmpdir}/train_data.ht')
dbnsfp_ht = hl.read_table(f'{tmpdir}/dbnsfp_scores.ht')
dbnsfp_stacked_ht = hl.read_table(f'{tmpdir}/dbnsfp_stacked.ht')
domain_ht = hl.read_table(f'{tmpdir}/domains.ht')
preds_ht = hl.read_table(f'{tmpdir}/preds.ht')

# Merge ClinVar variants array (Task 3 - new format, replaces old clinvar struct)
clinvar_ht = clinvar_ht.key_by('locus')
merged = base_ht.annotate(
    clinvar_variants = clinvar_ht[base_ht.locus].clinvar_variants
)

# Merge Training Labels (by locus) - FLATTEN the train_counts struct to top-level columns
train_ht = train_ht.key_by('locus')
_train = train_ht[merged.locus]
merged = merged.annotate(
    train_unlabelled = _train.train_counts.unlabelled,
    train_labelled = _train.train_counts.labelled,
    train_unlabelled_high_qual = _train.train_counts.unlabelled_high_qual,
    train_labelled_high_qual = _train.train_counts.labelled_high_qual,
)

# Merge dbNSFP scores (by locus, transcript) - FLATTEN to top level
# Every non-key column of the scores table is copied across under its own name
dbnsfp_ht = dbnsfp_ht.key_by('locus', 'Ensembl_transcriptid')
dbnsfp_cols = [col for col in dbnsfp_ht.row if col not in ['locus', 'Ensembl_transcriptid']]
_dbnsfp = dbnsfp_ht[merged.locus, merged.transcript_id]
merged = merged.annotate(**{
    col: _dbnsfp[col] for col in dbnsfp_cols
})

# Merge dbNSFP stacked scores (by locus, transcript) - FLATTEN to top level
# Index the stacked table ONCE and pull both arrays from the matched struct, instead
# of issuing a separate join per output field.
dbnsfp_stacked_ht = dbnsfp_stacked_ht.key_by('locus', 'Ensembl_transcriptid')
_stacked = dbnsfp_stacked_ht[merged.locus, merged.transcript_id].dbnsfp_stacked
merged = merged.annotate(
    AlphaMissense_stacked = _stacked.AlphaMissense,
    ESM1b_stacked = _stacked.ESM1b
)

merged = merged.checkpoint(f'{tmpdir}/merged_with_dbnsfp.ht', overwrite=True)

# Merge VEP aa_pos FIRST (needed for domain filtering below)
# Best-effort: if the VEP checkpoint cannot be read, aa_pos_vep is still created (as
# missing) so that the domain filter below and any later consumer see the field
# instead of failing on a missing-field lookup.
try:
    vep_aa_pos_ht = hl.read_table(f'{tmpdir}/vep_aa_pos.ht')
    merged = merged.annotate(
        aa_pos_vep = vep_aa_pos_ht[merged.locus, merged.transcript_id].aa_pos_vep
    )
    print("Added VEP aa_pos annotation")
except Exception as e:
    print(f"VEP aa_pos not available yet: {e}")
    # NOTE(review): int32 assumed to match VEP's protein_start type — confirm against vep.ht
    merged = merged.annotate(aa_pos_vep = hl.missing(hl.tint32))

# Merge Domain information (by transcript) - FILTER BY AMINO ACID POSITION
# Only include domains where this position's aa_pos_vep falls within [start_aa, end_aa].
# Rows with no domains for the transcript, or with no aa_pos_vep, get a missing array.
domain_ht = domain_ht.key_by('transcript_id')
_all_domains = domain_ht[merged.transcript_id].domains
merged = merged.annotate(
    domains = hl.if_else(
        hl.is_defined(_all_domains) & hl.is_defined(merged.aa_pos_vep),
        _all_domains.filter(lambda d:
            (merged.aa_pos_vep >= d.start_aa) & (merged.aa_pos_vep <= d.end_aa)
        ),
        hl.missing(_all_domains.dtype)
    )
)

# Merge Constraint Predictions (by transcript, locus) - FLATTEN to top level
# Each of the three fields is the per-model array stored in the `preds` struct.
preds_ht = preds_ht.key_by('Ensembl_transcriptid', 'locus')
_preds = preds_ht[merged.transcript_id, merged.locus].preds
merged = merged.annotate(
    Constraint = _preds.Constraint,
    Core = _preds.Core,
    Complete = _preds.Complete,
)

# Merge Variant Consequences (by transcript, locus)
# NOTE(review): this reads the checkpoint from f'{tmpdir}/variant_consequences.ht';
# verify the cell that builds the consequence table writes to this exact path.
csq_ht = hl.read_table(f'{tmpdir}/variant_consequences.ht')
csq_ht = csq_ht.key_by('transcript_id', 'locus')
merged = merged.annotate(
    variant_consequences = csq_ht[merged.transcript_id, merged.locus].variant_consequences
)

# VEP aa_pos already merged above (before domain filtering)

# Checkpoint the fully merged table to its persistent location (outside tmpdir)
merged = merged.checkpoint(f'/storage/zoghbi/home/u235147/VarPredBrowser/notebooks/merged_browser_data.ht', overwrite=True)
print(f"Merged table row count: {merged.count()}")

Base table key fields (auto-included): ['locus', 'region']
Keeping 122 columns from base table (+ key fields):
  - Core: 5
  - RGC oe/vir: 105
  - RGC count/obs: 12
Found 471 gnomAD constraint columns (oe/vir)
Aggregating allele frequencies by locus...




Added VEP aa_pos annotation
Added 6 cross-norm percentile columns




Merged table row count: 33387608


In [None]:
# Export merged table to a single parquet file.
# Spark writes a folder of part files, so the table is first written to a temporary
# folder as one partition and the lone part file is then moved to the final path.
merged_ht = hl.read_table(f'/storage/zoghbi/home/u235147/VarPredBrowser/notebooks/merged_browser_data.ht')


output_folder = f'{my_bucket}/rgc_browser_data_merged_tmp'
output_file = f'{my_bucket}/rgc_browser_data_merged.parquet'

# Repartition to single partition and export to parquet folder
merged_repartitioned = merged_ht.repartition(1)
spark_df = merged_repartitioned.to_spark()
spark_df.write.mode('overwrite').parquet(output_folder)

print(f"Exported to folder: {output_folder}")

# Consolidate Spark output folder to single parquet file
import os
import shutil

# Find the part file(s) in the output folder
part_files = sorted(
    f for f in os.listdir(output_folder)
    if f.startswith('part-') and f.endswith('.parquet')
)

# Exactly one part file is expected. Zero means the write produced nothing usable;
# more than one would mean moving only the first and then deleting the rest with the
# temp folder, silently truncating the export. Fail loudly before touching anything.
if len(part_files) != 1:
    raise RuntimeError(
        f"Expected exactly one parquet part file in {output_folder}, "
        f"found {len(part_files)}: {part_files}"
    )

part_file_path = os.path.join(output_folder, part_files[0])
# Remove existing output file if it exists
if os.path.exists(output_file):
    os.remove(output_file)
# Move part file to final location
shutil.move(part_file_path, output_file)
# Clean up the temp folder
shutil.rmtree(output_folder)
print(f"Consolidated to: {output_file}")

[Stage 44:>                                                         (0 + 1) / 1]

In [None]:
# Calculate exome-wide and cross-normalized percentiles using Polars
# This runs AFTER parquet export and BEFORE preprocessing.
# The work is done by the external calculate_percentiles.py script; per the original
# notes it uses rank(method='average') for tie handling (not verifiable from this cell).

# Input: the consolidated parquet export. Output: a sibling file with percentile columns added.
# Note that `output_file` is rebound here to the percentile-augmented path.
input_file = f'{my_bucket}/rgc_browser_data_merged.parquet'
output_file = f'{my_bucket}/rgc_browser_data_merged_with_perc.parquet'

!python /storage/zoghbi/home/u235147/VarPredBrowser/scripts/calculate_percentiles.py \
    --input {input_file} \
    --output {output_file}

# Remind the reader that preprocessing should consume the percentile-augmented file
print(f"\nPercentile calculation complete. Use {output_file} for preprocessing.")

In [8]:
# Run the browser preprocessing script on the percentile-augmented parquet and write
# its outputs under VarPredBrowser/data/. The input path is hard-coded here rather
# than interpolated from a notebook variable.
!python /storage/zoghbi/home/u235147/VarPredBrowser/scripts/preprocess_browser_data.py \
      --input /storage/zoghbi/home/u235147/merged_vars/rgc_browser_data_merged_with_perc.parquet \
      --output /storage/zoghbi/home/u235147/VarPredBrowser/data/

VARPRED BROWSER - DATA PREPROCESSING
Generating axis tables for all filter modes

Loading data from /storage/zoghbi/home/u235147/merged_vars/rgc_browser_data_merged.parquet...
Loaded 33,387,608 total positions
  Columns: 287

Detected columns:
  Chromosome: chrom
  Position: pos
  Gene: HGNC

Chromosome distribution (top 10):
  chr1: 3,414,807
  chr2: 2,477,732
  chr19: 2,149,560
  chr11: 1,980,678
  chr3: 1,921,014
  chr17: 1,918,143
  chr12: 1,732,842
  chr6: 1,690,167
  chr7: 1,565,442
  chr5: 1,562,490

Processing filter: all_sites
  All positions where any variant is possible

Applying filter...
✓ Kept 33,387,608 positions
Sorting by chromosome and position...
Generating compressed coordinates...

Column breakdown:
  - Core (idx, chrom, pos, gene): 4
  - RGC raw metrics: 142
  - ClinVar: 4
  - Training labels: 4
  - dbNSFP scores: 9
  - gnomAD coverage: 2
  - phyloP: 0
  - gnomAD constraint: 0
  - Constraint predictions: 3
  - Stacked scores: 0
  - Variant consequences: 0
  - Perc