In [1]:
# @title Imports

import functools
import os
from typing import Callable

import numpy as np
import pandas as pd
from scipy.stats import pearsonr, spearmanr
from sklearn import metrics
from sklearn.metrics import average_precision_score, roc_auc_score


In [2]:
# @title Helper functions.

def calculate_tissue_weighted_metric(
    predictions_df: pd.DataFrame,
    metric_fn: Callable[[np.ndarray, np.ndarray], float],
    prediction_col: str = 'prediction',
    target_col: str = 'target',
    tissue_col: str = 'tissue'
) -> float:
    """
    Scores every tissue separately and averages the scores, weighted by row count.

    For each group of `tissue_col`, rows with a missing target or prediction
    are dropped, `metric_fn(targets, predictions)` is evaluated on what is
    left, and the per-tissue values are combined as
    sum(value * n_rows) / sum(n_rows).

    A tissue is left out of the average when:
      * fewer than 2 rows remain after dropping missing values,
      * `metric_fn` raises ValueError (a message is printed), or
      * `metric_fn` returns a non-finite value (a message is printed).

    Args:
        predictions_df: DataFrame containing predictions, targets, and tissue info.
        metric_fn: Callable taking (y_true, y_pred) and returning one float.
                   Functions that return tuples (e.g. scipy's spearmanr) must
                   be wrapped so that only the scalar is returned.
        prediction_col: Name of the column with prediction scores.
        target_col: Name of the column with target labels/values.
        tissue_col: Name of the column indicating the tissue/group.

    Returns:
        The row-count-weighted mean of the per-tissue metric, or NaN when the
        input is None/empty or no tissue could be scored.

    Raises:
        ValueError: If any of the three named columns is missing.
    """
    if predictions_df is None or predictions_df.empty:
        print("Warning: predictions_df is empty, returning NaN.")
        return float('nan')

    required_cols = [prediction_col, target_col, tissue_col]
    if any(col not in predictions_df.columns for col in required_cols):
        raise ValueError(f"DataFrame must contain columns: {required_cols}")

    # One (tissue, metric value, weight) record per tissue that could be scored.
    scored_tissues = []

    for tissue_name, tissue_rows in predictions_df.groupby(tissue_col):
        usable_rows = tissue_rows.dropna(subset=[target_col, prediction_col])
        n_usable = len(usable_rows)

        # A metric needs at least two observations to be meaningful.
        if n_usable < 2:
            continue

        try:
            value = metric_fn(usable_rows[target_col], usable_rows[prediction_col])
            # Evaluated inside the try on purpose: a non-scalar result makes
            # the truth test raise ValueError, which is handled just below.
            value_is_finite = bool(np.isfinite(value))
        except ValueError as e:
            # Typically single-class groups for AUROC/AUPRC, plus anything
            # else metric_fn reports as a ValueError.
            print(f"Could not calculate metric for tissue {tissue_name}: {e}")
            continue

        if not value_is_finite:
            print(f"Skipping tissue {tissue_name} (metric_fn returned non-finite value)")
            continue

        scored_tissues.append((tissue_name, value, n_usable))

    if len(scored_tissues) == 0:
        print("Warning: No tissues had scorable metric values, returning NaN.")
        return float('nan')

    per_tissue = pd.DataFrame(
        scored_tissues, columns=['tissue', 'metric_value', 'count']
    )

    # The row count of each tissue acts as its weight.
    total_count = per_tissue['count'].sum()
    if total_count == 0:
        print("Warning: Total count for weighting is zero, returning NaN.")
        return float('nan')

    weighted_sum = per_tissue['metric_value'].mul(per_tissue['count']).sum()
    return weighted_sum / total_count


def paqtl_auprc(df, seed=0):
  """Computes AUPRC on one randomly chosen positive/negative pair per 'PI' value.

  Rows with target == 1 are paired with every row with target == 0 that
  shares the same 'PI' value. One pair is then sampled per 'PI' value, and
  the positives and negatives of the sampled pairs are stacked back into a
  single table that is scored with average precision. 'PI' values that lack
  either a positive or a negative row contribute nothing.

  Args:
    df: DataFrame with 'PI', 'target' and 'prediction' columns.
    seed: Random state used when sampling one pair per 'PI' value. Defaults
      to 0; pass different values to draw independent resamples.

  Returns:
    The average precision score of the resampled, 1:1 matched table.
  """
  # Only these three columns are read below. Dropping the others before the
  # merge keeps the pos x neg product narrow and does not affect which rows
  # get sampled.
  slim = df[['PI', 'target', 'prediction']]

  # 1. Separate positives and negatives.
  pos = slim[slim['target'] == 1]
  neg = slim[slim['target'] == 0]

  # 2. Merge on 'PI' to find all valid matched pairs.
  matched = pd.merge(
      pos,
      neg,
      on='PI',
      how='inner',
      suffixes=('_pos', '_neg')
  )

  # 3. Group by 'PI' and sample exactly one pair per group.
  # This ensures 1:1 matching controlled by PI.
  sampled_pairs = matched.groupby('PI').sample(n=1, random_state=seed)

  # 4. Reconstruct a single dataframe for AUPRC calculation by stacking the
  # positive and negative halves of the sampled pairs.
  df_sampled = pd.concat([
      sampled_pairs[['prediction_pos', 'target_pos']].rename(
          columns={'prediction_pos': 'prediction', 'target_pos': 'target'}
      ),
      sampled_pairs[['prediction_neg', 'target_neg']].rename(
          columns={'prediction_neg': 'prediction', 'target_neg': 'target'}
      )
  ])
  return metrics.average_precision_score(
      df_sampled['target'], df_sampled['prediction'])

# The scikit-learn classification scorers already take (y_true, y_score) and
# return a single float, so they are used as-is.
auprc_fn = average_precision_score
auroc_fn = roc_auc_score


def spearman_fn(y_true, y_pred):
  """Returns only the correlation coefficient from scipy's spearmanr result."""
  return spearmanr(y_true, y_pred)[0]


In [3]:
# @title Eval configs.
PREDS_PATH = 'https://storage.googleapis.com/alphagenome/evals'


def _auprc_over_all_rows(df):
  """Average precision of df['prediction'] against df['target'], all rows pooled."""
  return metrics.average_precision_score(df['target'], df['prediction'])


def _pearsonr_over_all_rows(df):
  """Pearson correlation statistic between df['target'] and df['prediction']."""
  return pearsonr(df['target'], df['prediction']).statistic


# Per-tissue scorers combined into a row-count-weighted mean.
_tissue_weighted_auprc = functools.partial(
    calculate_tissue_weighted_metric, metric_fn=auprc_fn)
_tissue_weighted_auroc = functools.partial(
    calculate_tissue_weighted_metric, metric_fn=auroc_fn)
_tissue_weighted_spearmanr = functools.partial(
    calculate_tissue_weighted_metric, metric_fn=spearman_fn)

_SIGN_FLIP_NOTE = 'The sign of the correlation flips depending on which allele is labelled REF vs. ALT.'

# One entry per evaluation, keyed by eval name. Each entry holds:
#   output_type:     ';'-separated output type names.
#   metric_name:     label of the metric (None where no label is recorded).
#   metric_fn:       callable applied to the eval's predictions DataFrame.
#   reported_metric: value that the recomputed metric is compared against.
#   notes:           optional free-text caveat.
evals = {
    'clinvar_splice_site_region': {
        'output_type': 'SPLICE_SITE_USAGE;RNA_SEQ;SPLICE_SITE_POSITIONS;SPLICE_JUNCTIONS;SPLICE_SITES',
        'metric_name': 'auprc_max_abs_track_aggregation',
        'metric_fn': _auprc_over_all_rows,
        'reported_metric': 0.5699,
    },
    'clinvar_noncoding': {
        'output_type': 'RNA_SEQ;SPLICE_SITE_POSITIONS;SPLICE_SITE_USAGE;SPLICE_JUNCTIONS;SPLICE_SITES',
        'metric_name': 'auprc_max_abs_track_aggregation',
        'metric_fn': _auprc_over_all_rows,
        'reported_metric': 0.6588,
    },
    'clinvar_missense': {
        'output_type': 'SPLICE_SITE_USAGE;SPLICE_JUNCTIONS;SPLICE_SITES;RNA_SEQ;SPLICE_SITE_POSITIONS',
        'metric_name': 'auprc_max_abs_track_aggregation',
        'metric_fn': _auprc_over_all_rows,
        'reported_metric': 0.1792,
    },
    'sqtl_variant_causality_gene_human': {
        'output_type': 'RNA_SEQ;SPLICE_JUNCTIONS;SPLICE_SITES;SPLICE_SITE_USAGE;SPLICE_SITE_POSITIONS',
        'metric_name': 'tissue_weighted_mean_auprc',
        'metric_fn': _tissue_weighted_auprc,
        'reported_metric': 0.7644,
    },
    'mfass_splicing': {
        'output_type': 'SPLICE_SITES_LOGITS;SPLICE_SITE_USAGE;SPLICE_JUNCTIONS',
        'metric_name': 'all_tissues_auprc',
        'metric_fn': _auprc_over_all_rows,
        'reported_metric': 0.5120,
    },
    'eqtl_variant_borzoi_sign_human': {
        'output_type': 'RNA_SEQ',
        'metric_name': 'tissue_weighted_mean_auroc',
        'metric_fn': _tissue_weighted_auroc,
        'reported_metric': 0.810077,
    },
    'eqtl_variant_catalogue_causality_gene_balanced_human': {
        'output_type': 'RNA_SEQ',
        'metric_name': 'tissue_weighted_mean_auroc',
        'metric_fn': _tissue_weighted_auroc,
        'reported_metric': 0.713255,
    },
    'eqtl_variant_borzoi_coefficient_human': {
        'output_type': 'RNA_SEQ',
        'metric_name': 'tissue_weighted_mean_spearmanr',
        'metric_fn': _tissue_weighted_spearmanr,
        'reported_metric': 0.500588,
    },
    'enhancer_gene_linking_e2g': {
        'output_type': 'RNA_SEQ',
        'metric_name': None,
        'metric_fn': _auprc_over_all_rows,
        'reported_metric': 0.7490,
    },
    'paqtl_variant_causality_human': {
        'output_type': 'RNA_SEQ',
        'metric_name': 'PAS_10000_threshold_average_auprc',
        'metric_fn': paqtl_auprc,
        'reported_metric': 0.6294,
        'notes': 'The reported metric is from many permutations, here we just do 1 resample.',
    },
    'caqtl_african_variant_causality_human': {
        'output_type': 'DNASE',
        'metric_name': "auprc_mean_abs_track_aggregation_ontology_curie:['EFO:0002784']",
        'metric_fn': _auprc_over_all_rows,
        'reported_metric': 0.5643,
    },
    'caqtl_european_variant_causality_human': {
        'output_type': 'DNASE',
        'metric_name': "auprc_mean_abs_track_aggregation_ontology_curie:['EFO:0002784']",
        'metric_fn': _auprc_over_all_rows,
        'reported_metric': 0.3638,
    },
    'dsqtl_yoruba_variant_causality_human': {
        'output_type': 'DNASE',
        'metric_name': "auprc_mean_abs_track_aggregation_ontology_curie:['EFO:0002784']",
        'metric_fn': _auprc_over_all_rows,
        'reported_metric': 0.6308,
    },
    'caqtl_african_variant_coefficient_human': {
        'output_type': 'DNASE',
        'metric_name': "pearsonr_mean_track_aggregation_ontology_curie:['EFO:0002784']",
        'metric_fn': _pearsonr_over_all_rows,
        'reported_metric': 0.7368,
    },
    'caqtl_european_variant_coefficient_human': {
        'output_type': 'DNASE',
        'metric_name': "pearsonr_mean_track_aggregation_ontology_curie:['EFO:0002784']",
        'metric_fn': _pearsonr_over_all_rows,
        'reported_metric': 0.5916,
    },
    'dsqtl_yoruba_variant_coefficient_human': {
        'output_type': 'DNASE',
        'metric_name': "pearsonr_mean_track_aggregation_ontology_curie:['EFO:0002784']",
        'metric_fn': _pearsonr_over_all_rows,
        'reported_metric': 0.8323,
        'notes': _SIGN_FLIP_NOTE,
    },
    'caqtl_microglia_variant_coefficient_human': {
        'output_type': 'DNASE',
        'metric_name': "pearsonr_mean_track_aggregation_ontology_curie:['CL:0000862']",
        'metric_fn': _pearsonr_over_all_rows,
        'reported_metric': 0.6357,
        'notes': _SIGN_FLIP_NOTE,
    },
    'caqtl_smc_variant_coefficient_human': {
        'output_type': 'ATAC',
        'metric_name': "pearsonr_mean_track_aggregation_ontology_curie:['UBERON:0002079']",
        'metric_fn': _pearsonr_over_all_rows,
        'reported_metric': 0.687,
    },
    'bqtl_spi1_variant_coefficient_human': {
        'output_type': 'CHIP_TF',
        'metric_name': 'pearsonr_mean_track_aggregation_EFO:0002784_SPI1',
        'metric_fn': _pearsonr_over_all_rows,
        'reported_metric': 0.549967,
    },
    'bqtl_spi1_variant_causality_human': {
        'output_type': 'CHIP_TF',
        'metric_name': 'auprc_mean_abs_track_aggregation_EFO:0002784_SPI1',
        'metric_fn': _auprc_over_all_rows,
        'reported_metric': 0.4952,
    },
}

In [4]:
# Recompute every configured metric from its published predictions table and
# print it next to the reported value.
for eval_name, c in evals.items():
  print(f"\nEval: {eval_name}")

  # PREDS_PATH is a URL, so the pieces are joined with an explicit '/'.
  # os.path.join uses the host OS separator (a backslash on Windows), which
  # would produce an invalid address.
  filepath = f"{PREDS_PATH.rstrip('/')}/{eval_name}_predictions.feather"
  predictions = pd.read_feather(filepath)

  # Each metric_fn takes the whole predictions table and returns one number.
  recomputed_metric = c['metric_fn'](predictions)
  print(f"  Reported:   {c['reported_metric']}")
  print(f"  Recomputed: {recomputed_metric:.4f}")

  # 'notes' is optional; only some evals carry a caveat.
  if c.get('notes'):
    print(f"  Notes:      {c['notes']}")


Eval: clinvar_splice_site_region
  Reported:   0.5699
  Recomputed: 0.5699

Eval: clinvar_noncoding
  Reported:   0.6588
  Recomputed: 0.6588

Eval: clinvar_missense
  Reported:   0.1792
  Recomputed: 0.1792

Eval: sqtl_variant_causality_gene_human
  Reported:   0.7644
  Recomputed: 0.7645

Eval: mfass_splicing
  Reported:   0.512
  Recomputed: 0.5124

Eval: eqtl_variant_borzoi_sign_human
  Reported:   0.810077
  Recomputed: 0.8101

Eval: eqtl_variant_catalogue_causality_gene_balanced_human
  Reported:   0.713255
  Recomputed: 0.7133

Eval: eqtl_variant_borzoi_coefficient_human
  Reported:   0.500588
  Recomputed: 0.5006

Eval: enhancer_gene_linking_e2g
  Reported:   0.749
  Recomputed: 0.7488

Eval: paqtl_variant_causality_human
  Reported:   0.6294
  Recomputed: 0.6288
  Notes:      The reported metric is from many permutations, here we just do 1 resample.

Eval: caqtl_african_variant_causality_human
  Reported:   0.5643
  Recomputed: 0.5641

Eval: caqtl_european_variant_causality_h