In [1]:
# Common imports and constants
import os

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

from cvfgaou import hailtools, gctools

# Root of the workspace bucket (presumably a gs:// URL); every input and output
# path in this notebook is built under it. Indexing os.environ raises KeyError
# immediately if the variable is unset.
BUCKET = os.environ["WORKSPACE_BUCKET"]

In [2]:
# Genes of interest: gene symbols matched against the 'Gene name' column of the
# gene metadata table. Kept in alphabetical order.
genes = """
    APP     BAP1    BARD1   BRCA1   BRCA2   BRIP1
    CALM1   CALM2   CALM3   GCK     KCNH2   KCNQ4
    MSH2    OTC     PALB2   PRKN    PTEN    RAD51C
    RAD51D  SCN5A   SNCA    TARDBP  TP53    VWF
""".split()

In [3]:
import hail as hl

hl.init()

# Path to the WGS exome-split Hail MatrixTable, provided by the workspace
# environment. Indexing os.environ (consistent with how BUCKET is read) fails
# fast with a KeyError naming the missing variable, instead of silently
# yielding None and failing later inside hl.read_matrix_table.
wgs_mt_path = os.environ["WGS_EXOME_SPLIT_HAIL_PATH"]
wgs_mt_path

In [4]:
# Load the WGS MatrixTable from the path resolved above and print its schema
# (row/column/entry fields) for inspection.
wgs_mt = hl.read_matrix_table(wgs_mt_path)
wgs_mt.describe()

In [5]:
# Bergquist et al. thresholds 10.1016/j.gim.2025.101402

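# Benign Moderate: <= 0.099  (a moderate cutoff must be stricter than the supporting one; 0.199 would be inverted -- TODO verify against the paper's table)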
# Benign Supporting: <= 0.170
# Pathogenic Supporting: >= 0.792
# Pathogenic Moderate: >= 0.906
# Pathogenic Strong: >= 0.990

In [6]:
# The AlphaMissense table is too large to load in one go, so open it as a lazy
# reader that yields DataFrames of 500,000 rows each; a later cell scans it once
# and keeps only the rows that fall inside the genes of interest.
# NOTE: the chunked reader is single-pass -- once iterated it is exhausted, so
# re-run this cell before re-running any cell that loops over `am_df`.
#
# Columns: #CHROM, POS, REF, ALT, genome, uniprot_id, transcript_id,
#          protein_variant, am_pathogenicity, am_class
am_df = pd.read_table(
    f'{BUCKET}/precomputed/AlphaMissense_hg38.tsv.gz',
    chunksize=500_000,
    header=3,  # rows 0-2 precede the column-name row and are skipped
)

In [7]:
# We need the genomic interval (chromosome, start, end) of each gene of
# interest, taken from the gene metadata table.

# Primary-assembly chromosome names as they appear (without a 'chr' prefix) in
# the metadata's 'Chromosome/scaffold name' column. Patch/alt scaffolds
# (e.g. 'HG2334_PATCH') are dropped wholesale instead of one by one; rows on
# them would only produce 'chr<scaffold>' names, which presumably never occur
# in AlphaMissense's #CHROM column -- verify if new genes are added.
PRIMARY_CHROMOSOMES = [str(i) for i in range(1, 23)] + ['X', 'Y', 'MT']

intervals_df = (
    pd.read_table(
        f'{BUCKET}/aux_data/gene_metadata.txt',
        usecols=['Gene stable ID version', 'Chromosome/scaffold name', 'Gene start (bp)', 'Gene end (bp)', 'Gene name'],
        dtype=str,
    )
    .drop_duplicates()
    # Filter to genes of interest
    .loc[lambda df: df['Gene name'].isin(genes)]
    # Cleanup: keep primary chromosomes only
    .loc[lambda df: df['Chromosome/scaffold name'].isin(PRIMARY_CHROMOSOMES)]
    # Nicer notation, matching the 'chrN' style used when matching AlphaMissense rows
    .assign(Chrom=lambda df: 'chr' + df['Chromosome/scaffold name'])
    .astype({'Gene start (bp)': int, 'Gene end (bp)': int})
)

# Sanity checks: each gene should resolve to exactly one interval. A missing gene
# is silently skipped downstream, and a gene with several overlapping intervals
# would have the same AlphaMissense rows collected more than once.
genes_without_interval = sorted(set(genes) - set(intervals_df['Gene name']))
genes_with_multiple_intervals = sorted(
    intervals_df.loc[intervals_df['Gene name'].duplicated(), 'Gene name'].unique()
)
if genes_without_interval:
    print(f'No interval found for: {", ".join(genes_without_interval)}')
if genes_with_multiple_intervals:
    print(f'Multiple intervals found for: {", ".join(genes_with_multiple_intervals)}')

# Check
intervals_df

In [8]:
from collections import defaultdict

# Single pass over the chunked AlphaMissense reader: for every chunk, keep only
# the rows inside a gene-of-interest interval, then stitch each gene's slices
# together at the end. A gene's interval may straddle a chunk boundary; its
# slices come from disjoint chunks, so concatenation does not duplicate rows.
am_gene_df_lists = defaultdict(list)
gene_intervals = intervals_df[['Gene start (bp)', 'Gene end (bp)', 'Gene name', 'Chrom']]

for chunk in tqdm(am_df, desc='AlphaMissense chunks'):
    # Only intervals on chromosomes present in this chunk can match
    chunk_intervals = gene_intervals[gene_intervals['Chrom'].isin(chunk['#CHROM'].unique())]

    for start, end, gene, chrom in chunk_intervals.itertuples(index=False):
        in_gene = chunk[
            (chunk['#CHROM'] == chrom) &
            chunk['POS'].between(start, end)  # inclusive on both ends
        ]
        # Skip empty slices rather than accumulating zero-row frames per chunk
        if not in_gene.empty:
            am_gene_df_lists[gene].append(in_gene)

# `am_df` is a single-pass reader: if an earlier run of this cell already
# consumed it, nothing is collected here. Fail loudly instead of silently
# producing no output in the export loop.
if not am_gene_df_lists:
    raise RuntimeError(
        'No AlphaMissense rows collected; re-run the cell that creates `am_df` '
        '(the chunked reader is single-pass) and check `intervals_df`.'
    )

am_gene_dfs = {
    gene: pd.concat(df_list, ignore_index=True)
    for gene, df_list in tqdm(am_gene_df_lists.items())
}

# Genes with no AlphaMissense rows are absent from `am_gene_dfs`; make that visible.
genes_without_variants = sorted(set(genes) - am_gene_dfs.keys())
if genes_without_variants:
    print(f'No AlphaMissense variants found for: {", ".join(genes_without_variants)}')

In [9]:
# Load the gene-specific AlphaMissense calibration table.
# The first CSV column becomes the index, so rows are looked up by gene name
# (e.g. am_gene_thresholds_df.loc[gene, label]). Threshold columns are presumably
# named like 'PP3_Moderate' / 'BP4_Supporting' -- see gs_threshold_map below.
am_gene_thresholds_df = pd.read_csv(
    f'{BUCKET}/calibrations/pillarg_AM_gene_specific_thresh.csv',
    index_col=0
)

# For each of our classification labels (e.g. 'Pathogenic moderate'): the column
# of the gene-specific table holding its threshold (e.g. 'PP3_Moderate') and the
# comparison that selects variants meeting it. Pathogenic (PP3) levels keep
# scores >= threshold; benign (BP4) levels keep scores <= threshold.
# NOTE(review): str.title() turns 'very strong' into 'Very Strong', so the
# very-strong columns are looked up as e.g. 'PP3_Very Strong' (with a space);
# confirm the CSV uses that exact spelling, otherwise the lookup raises KeyError.
gs_threshold_map = dict(
    (
        f'{direction} {strength}',
        (f'{criterion}_{strength.title()}', comparator),
    )
    for direction, criterion, comparator in [
        ('Pathogenic', 'PP3', pd.Series.ge),
        ('Benign', 'BP4', pd.Series.le),
    ]
    for strength in ['very strong', 'strong', 'moderate', 'supporting']
)

In [10]:
# ClinVar bins table; passed unchanged to hailtools.get_exposure_package for
# every variant class in the export loop below.
clinvar_bins_df = pd.read_csv(f'{BUCKET}/clinvar/clinvar-bins.csv.gz')

In [11]:
from itertools import chain

# All outputs of this run live under one dated directory.
CLASSES_DIR = f'{BUCKET}/classes_2025-07-22'

# Genome-wide calibrated AlphaMissense thresholds (Bergquist et al.
# 10.1016/j.gim.2025.101402) as (classification, comparator, threshold).
# Benign levels keep scores <= threshold and pathogenic levels keep scores
# >= threshold, so a stronger benign level needs a *lower* cutoff than a weaker
# one. A Benign moderate cutoff of 0.199 would be looser than Benign supporting
# (0.170) and therefore inverted; 0.099 is used instead.
# TODO: confirm 0.099 against the paper's table.
BERGQUIST_CLASSIFIER = 'Calibrated (Bergquist et al. 10.1016/j.gim.2025.101402)'
BERGQUIST_THRESHOLDS = (
    ('Benign moderate', pd.Series.le, 0.099),
    ('Benign supporting', pd.Series.le, 0.170),
    ('Pathogenic supporting', pd.Series.ge, 0.792),
    ('Pathogenic moderate', pd.Series.ge, 0.906),
    ('Pathogenic strong', pd.Series.ge, 0.990),
)


def _iter_variant_classes(gene, per_gene_df):
    """Yield (classifier, classification, variant subset) triples for one gene.

    Covers, in order:
      * AlphaMissense's own `am_class` labels ('Author Reported'),
      * the genome-wide Bergquist et al. calibrated thresholds,
      * the gene-specific calibrated thresholds, when `gene` is in
        `am_gene_thresholds_df`; levels with a missing (NaN) threshold are skipped.

    Subsets may be empty; the caller skips those. Evaluated lazily.
    """
    scores = per_gene_df['am_pathogenicity']

    # Author-reported scores; also make the label text nicer
    for am_class, class_df in per_gene_df.groupby('am_class'):
        yield 'Author Reported', am_class.replace('_', ' ').capitalize(), class_df

    # Genome-wide calibrated thresholds
    for classification, compare, threshold in BERGQUIST_THRESHOLDS:
        yield BERGQUIST_CLASSIFIER, classification, per_gene_df[compare(scores, threshold)]

    # Gene-specific thresholds
    if gene in am_gene_thresholds_df.index:
        for classification, (label, compare) in gs_threshold_map.items():
            threshold = am_gene_thresholds_df.loc[gene, label]
            if pd.notna(threshold):
                yield 'Calibrated (gene-specific)', classification, per_gene_df[compare(scores, threshold)]


for gene, per_gene_df in tqdm(am_gene_dfs.items()):

    exposures_file = f'{CLASSES_DIR}/exposures/alphamissense_{gene}.parquet'
    clinvar_file = f'{CLASSES_DIR}/clinvar_maps/alphamissense_{gene}.parquet'
    af_file = f'{CLASSES_DIR}/af_maps/alphamissense_{gene}.parquet'

    # Resume support: the exposures file is written last (below), so its
    # presence means this gene was completed by a previous run.
    if gctools.blob_exists(exposures_file):
        continue

    gene_result_dfs = []
    clinvar_classes_dfs = []
    joint_af_map = {}

    for classifier, classification, variant_df in tqdm(_iter_variant_classes(gene, per_gene_df), desc=gene, leave=False):

        if variant_df.empty:
            continue

        try:
            exposure_df, af_map, clinvar_df = hailtools.get_exposure_package(
                variant_df,
                wgs_mt,
                clinvar_bins_df,
                contig_col='#CHROM',
                pos_col='POS',
                ref_col='REF',
                alt_col='ALT',
                metadata_dict={
                    'Dataset': 'AlphaMissense',
                    'Gene': gene,
                    'Classifier': classifier,
                    'Classification': classification
                }
            )
        except Exception:
            # Show which variant class failed, then propagate the original error.
            print(f'Failed on {gene}, {classifier}, {classification}:')
            print(variant_df)
            raise

        clinvar_classes_dfs.append(clinvar_df)
        joint_af_map.update(af_map)
        gene_result_dfs.append(exposure_df)

    if clinvar_classes_dfs:
        pd.concat(clinvar_classes_dfs, ignore_index=True).to_parquet(clinvar_file, index=False)
    if joint_af_map:
        pd.Series(joint_af_map).to_frame(name='AF').to_parquet(af_file)
    # Written last: acts as the per-gene completion marker checked above.
    if gene_result_dfs:
        pd.concat(gene_result_dfs, ignore_index=True).to_parquet(exposures_file, index=False)