In [1]:
# @title Install AlphaGenome

# @markdown Run this cell to install AlphaGenome.
from IPython.display import clear_output
! pip install alphagenome
clear_output()

In [2]:
# @title Setup and imports.

from io import StringIO
from alphagenome import colab_utils
from alphagenome.data import genome
from alphagenome.models import dna_client, variant_scorers
from google.colab import data_table, files
import pandas as pd
from tqdm import tqdm

data_table.enable_dataframe_formatter()

# Load the model.
dna_model = dna_client.create(colab_utils.get_api_key())

In [None]:
# @title Score batch of variants with aggregation

# Load VCF file containing variants.
vcf_file = 'brca1.csv'  # @param


vcf = pd.read_csv(vcf_file, sep=',') # Changed separator to comma
vcf = vcf.rename(columns={
    'chr': 'CHROM',
    'pos': 'POS',
    'ref': 'REF',
    'alt': 'ALT'
}) # Rename columns to match expected names
# Create variant_id column
vcf['variant_id'] = vcf['CHROM'].astype(str) + ':' + \
                    vcf['POS'].astype(str) + ':' + \
                    vcf['REF'].astype(str) + '>' + \
                    vcf['ALT'].astype(str)

required_columns = ['variant_id', 'CHROM', 'POS', 'REF', 'ALT']
for column in required_columns:
  if column not in vcf.columns:
    raise ValueError(f'VCF file is missing required column: {column}.')

organism = 'human'  # @param ["human", "mouse"] {type:"string"}

# @markdown Specify length of sequence around variants to predict:
sequence_length = '1MB'  # @param ["16KB", "100KB", "500KB", "1MB"] { type:"string" }
sequence_length = dna_client.SUPPORTED_SEQUENCE_LENGTHS[
    f'SEQUENCE_LENGTH_{sequence_length}'
]

# @markdown Specify which scorers to use to score your variants:
score_rna_seq = True  # @param { type: "boolean"}
score_cage = True  # @param { type: "boolean" }
score_procap = True  # @param { type: "boolean" }
score_atac = True  # @param { type: "boolean" }
score_dnase = True  # @param { type: "boolean" }
score_chip_histone = True  # @param { type: "boolean" }
score_chip_tf = True  # @param { type: "boolean" }
score_polyadenylation = True  # @param { type: "boolean" }
score_splice_sites = True  # @param { type: "boolean" }
score_splice_site_usage = True  # @param { type: "boolean" }
score_splice_junctions = True  # @param { type: "boolean" }

# @markdown Other settings:
download_predictions = True  # @param { type: "boolean" }

# Parse organism specification.
organism_map = {
    'human': dna_client.Organism.HOMO_SAPIENS,
    'mouse': dna_client.Organism.MUS_MUSCULUS,
}
organism = organism_map[organism]

# Parse scorer specification.
scorer_selections = {
    'rna_seq': score_rna_seq,
    'cage': score_cage,
    'procap': score_procap,
    'atac': score_atac,
    'dnase': score_dnase,
    'chip_histone': score_chip_histone,
    'chip_tf': score_chip_tf,
    'polyadenylation': score_polyadenylation,
    'splice_sites': score_splice_sites,
    'splice_site_usage': score_splice_site_usage,
    'splice_junctions': score_splice_junctions,
}

all_scorers = variant_scorers.RECOMMENDED_VARIANT_SCORERS
selected_scorers = [
    all_scorers[key]
    for key in all_scorers
    if scorer_selections.get(key.lower(), False)
]

# Remove any scorers or output types that are not supported for the chosen organism.
unsupported_scorers = [
    scorer
    for scorer in selected_scorers
    if (
        organism.value
        not in variant_scorers.SUPPORTED_ORGANISMS[scorer.base_variant_scorer]
    )
    | (
        (scorer.requested_output == dna_client.OutputType.PROCAP)
        & (organism == dna_client.Organism.MUS_MUSCULUS)
    )
]
if len(unsupported_scorers) > 0:
  print(
      f'Excluding {unsupported_scorers} scorers as they are not supported for'
      f' {organism}.'
  )
  for unsupported_scorer in unsupported_scorers:
    selected_scorers.remove(unsupported_scorer)


# Score variants in the VCF file.
# Create a list to store ONLY the aggregated summaries (small memory footprint)
aggregated_data = []

print("Processing variants and aggregating on-the-fly...")

# Iterate through variants one by one
for i, vcf_row in tqdm(vcf.iterrows(), total=len(vcf)):
    variant = genome.Variant(
        chromosome=str(vcf_row.CHROM),
        position=int(vcf_row.POS),
        reference_bases=vcf_row.REF,
        alternate_bases=vcf_row.ALT,
        name=vcf_row.variant_id,
    )
    interval = variant.reference_interval.resize(sequence_length)

 # ... (inside your loop) ...

    # 1. Score ONLY this single variant
    variant_scores = dna_model.score_variant(
        interval=interval,
        variant=variant,
        variant_scorers=selected_scorers,
        organism=organism,
    )

    # 2. Convert immediately to a temporary DataFrame
    df_single = variant_scorers.tidy_scores([variant_scores])

    # ---------------------------------------------------------
    # MODIFICATION: Filter for Breast Cancer Contexts ONLY
    # ---------------------------------------------------------
    if not df_single.empty:
        # Define keywords to identify breast cancer relevant samples
        # You can add more cell lines (e.g., 'MDA-MB', 'SK-BR-3') if needed.
        keywords = ['breast', 'mammary', 'mcf', 't47d', 'zr-75', 'hs-578t']

        # Create a mask to check both 'gtex_tissue' and 'biosample_name'
        # We convert to string first to handle any NaN values safely
        mask = (
            df_single['gtex_tissue'].astype(str).str.contains('|'.join(keywords), case=False) |
            df_single['biosample_name'].astype(str).str.contains('|'.join(keywords), case=False)
        )

        # Apply the filter
        df_single = df_single[mask]
    # ---------------------------------------------------------

    # 3. Calculate aggregates immediately (now only on breast cancer data)
    if not df_single.empty:
        stats = {
            'variant_id': str(vcf_row.variant_id),
            # ... (rest of your aggregation logic remains the same) ... # Ensure string ID

            # Calculate Max
            'raw_score_max': df_single['raw_score'].max(),
            'quantile_score_max': df_single['quantile_score'].max(),

            # Calculate Min
            'raw_score_min': df_single['raw_score'].min(),
            'quantile_score_min': df_single['quantile_score'].min(),

            # Calculate Mean
            'raw_score_mean': df_single['raw_score'].mean(),
            'quantile_score_mean': df_single['quantile_score'].mean(),
        }
        aggregated_data.append(stats)

    # 4. Explicitly delete heavy objects to free RAM immediately
    del variant_scores
    del df_single

# ---------------------------------------------------------
# After the loop, create the final summary DataFrames
# ---------------------------------------------------------

# Convert the lightweight list of dicts to a DataFrame
df_agg_all = pd.DataFrame(aggregated_data)

# Split into the separate DataFrames you needed
agg_max = df_agg_all[['variant_id', 'raw_score_max', 'quantile_score_max']]
agg_min = df_agg_all[['variant_id', 'raw_score_min', 'quantile_score_min']]
agg_mean = df_agg_all[['variant_id', 'raw_score_mean', 'quantile_score_mean']]

# Save aggregated scores (Same logic as before)
if download_predictions:
    agg_max.to_csv('brca1_variant_scores_max.csv', index=False)
    files.download('brca1_variant_scores_max.csv')

    agg_min.to_csv('brca1_variant_scores_min.csv', index=False)
    files.download('brca1_variant_scores_min.csv')

    agg_mean.to_csv('brca1_variant_scores_mean.csv', index=False)
    files.download('brca1_variant_scores_mean.csv')

print("\nAggregated scores (max):")
display(agg_max.head()) # Use head() to avoid printing huge tables
print("\nAggregated scores (min):")
display(agg_min.head())
print("\nAggregated scores (mean):")
display(agg_mean.head())

Processing variants and aggregating on-the-fly...


100%|██████████| 8/8 [00:13<00:00,  1.68s/it]


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


Aggregated scores (max):


Unnamed: 0,variant_id,raw_score_max,quantile_score_max
0,chr17:41276135:T>G,0.024837,0.999302
1,chr17:41276135:T>C,0.03144,0.99884
2,chr17:41276135:T>A,0.032645,0.999483
3,chr17:41276134:T>G,0.023589,0.997402
4,chr17:41276134:T>C,0.038402,0.998813



Aggregated scores (min):


Unnamed: 0,variant_id,raw_score_min,quantile_score_min
0,chr17:41276135:T>G,-0.051805,-0.998361
1,chr17:41276135:T>C,-0.034535,-0.998323
2,chr17:41276135:T>A,-0.037059,-0.996728
3,chr17:41276134:T>G,-0.042626,-0.997936
4,chr17:41276134:T>C,-0.021362,-0.996876



Aggregated scores (mean):


Unnamed: 0,variant_id,raw_score_mean,quantile_score_mean
0,chr17:41276135:T>G,-0.002094,-0.069305
1,chr17:41276135:T>C,-0.001044,-0.117726
2,chr17:41276135:T>A,0.000473,0.048312
3,chr17:41276134:T>G,-0.000474,-0.004724
4,chr17:41276134:T>C,0.000343,0.033802




Unnamed: 0,variant_id,scored_interval,gene_id,gene_name,gene_type,gene_strand,junction_Start,junction_End,output_type,variant_scorer,...,biosample_type,biosample_life_stage,data_source,endedness,genetically_modified,transcription_factor,histone_mark,gtex_tissue,raw_score,quantile_score
0,chr3:58394738:A>T,chr3:57870450-58919026:.,,,,,,,ATAC,"CenterMaskScorer(requested_output=ATAC, width=...",...,primary_cell,adult,encode,paired,False,,,,-0.004951,-0.259272
168,chr3:58394738:A>T,chr3:57870450-58919026:.,,,,,,,DNASE,"CenterMaskScorer(requested_output=DNASE, width...",...,primary_cell,adult,encode,paired,False,,,,-0.030459,-0.680793
2109,chr3:58394738:A>T,chr3:57870450-58919026:.,,,,,,,CHIP_HISTONE,CenterMaskScorer(requested_output=CHIP_HISTONE...,...,primary_cell,adult,encode,single,False,,H3K27ac,,-0.000978,-0.301775
2110,chr3:58394738:A>T,chr3:57870450-58919026:.,,,,,,,CHIP_HISTONE,CenterMaskScorer(requested_output=CHIP_HISTONE...,...,primary_cell,adult,encode,single,False,,H3K27me3,,0.000441,0.080577
2111,chr3:58394738:A>T,chr3:57870450-58919026:.,,,,,,,CHIP_HISTONE,CenterMaskScorer(requested_output=CHIP_HISTONE...,...,primary_cell,adult,encode,single,False,,H3K36me3,,0.000000,-0.092027
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
91363,chr16:1135446:G>T,chr16:611158-1659734:.,ENSG00000290756,ENSG00000290756,lncRNA,-,,,RNA_SEQ,GeneMaskLFCScorer(requested_output=RNA_SEQ),...,primary_cell,adult,encode,single,False,,,,-0.000391,-0.402670
91364,chr16:1135446:G>T,chr16:611158-1659734:.,ENSG00000292400,ENSG00000292400,lncRNA,-,,,RNA_SEQ,GeneMaskLFCScorer(requested_output=RNA_SEQ),...,primary_cell,adult,encode,single,False,,,,0.000173,0.280661
91365,chr16:1135446:G>T,chr16:611158-1659734:.,ENSG00000292423,ENSG00000292423,lncRNA,-,,,RNA_SEQ,GeneMaskLFCScorer(requested_output=RNA_SEQ),...,primary_cell,adult,encode,single,False,,,,-0.001031,-0.641886
91366,chr16:1135446:G>T,chr16:611158-1659734:.,ENSG00000292431,ENSG00000292431,lncRNA,-,,,RNA_SEQ,GeneMaskLFCScorer(requested_output=RNA_SEQ),...,primary_cell,adult,encode,single,False,,,,0.000020,0.103452


In [4]:
# @title Score batch of variants with aggregation and Performance Evaluation
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve, roc_curve

# Load VCF file containing variants.
vcf_file = 'clinvar_train.csv'  # @param

vcf = pd.read_csv(vcf_file, sep=',') # Changed separator to comma
vcf = vcf.rename(columns={
    'chr': 'CHROM',
    'pos': 'POS',
    'ref': 'REF',
    'alt': 'ALT'
}) # Rename columns to match expected names

# ---------------------------------------------------------
# [新增/修改] 修复染色体前缀 (Fix Chromosome Prefix)
# ---------------------------------------------------------
# 确保 CHROM 列是字符串格式
vcf['CHROM'] = vcf['CHROM'].astype(str)
# 如果没有 'chr' 前缀，则手动添加
vcf['CHROM'] = vcf['CHROM'].apply(lambda x: x if x.startswith('chr') else 'chr' + x)

# Create variant_id column (现在的 variant_id 也会自动带上 chr 前缀)
vcf['variant_id'] = vcf['CHROM'].astype(str) + ':' + \
                    vcf['POS'].astype(str) + ':' + \
                    vcf['REF'].astype(str) + '>' + \
                    vcf['ALT'].astype(str)

# ---------------------------------------------------------
# Required columns (Including 'label')
# ---------------------------------------------------------
required_columns = ['variant_id', 'CHROM', 'POS', 'REF', 'ALT', 'label']
for column in required_columns:
  if column not in vcf.columns:
    raise ValueError(f'VCF file is missing required column: {column}.')

organism = 'human'  # @param ["human", "mouse"] {type:"string"}

# @markdown Specify length of sequence around variants to predict:
sequence_length = '1MB'  # @param ["16KB", "100KB", "500KB", "1MB"] { type:"string" }
sequence_length = dna_client.SUPPORTED_SEQUENCE_LENGTHS[
    f'SEQUENCE_LENGTH_{sequence_length}'
]

# @markdown Specify which scorers to use to score your variants:
score_rna_seq = True  # @param { type: "boolean"}
score_cage = True  # @param { type: "boolean" }
score_procap = True  # @param { type: "boolean" }
score_atac = True  # @param { type: "boolean" }
score_dnase = True  # @param { type: "boolean" }
score_chip_histone = True  # @param { type: "boolean" }
score_chip_tf = True  # @param { type: "boolean" }
score_polyadenylation = True  # @param { type: "boolean" }
score_splice_sites = True  # @param { type: "boolean" }
score_splice_site_usage = True  # @param { type: "boolean" }
score_splice_junctions = True  # @param { type: "boolean" }

# @markdown Other settings:
download_predictions = True  # @param { type: "boolean" }

# Parse organism specification.
organism_map = {
    'human': dna_client.Organism.HOMO_SAPIENS,
    'mouse': dna_client.Organism.MUS_MUSCULUS,
}
organism = organism_map[organism]

# Parse scorer specification.
scorer_selections = {
    'rna_seq': score_rna_seq,
    'cage': score_cage,
    'procap': score_procap,
    'atac': score_atac,
    'dnase': score_dnase,
    'chip_histone': score_chip_histone,
    'chip_tf': score_chip_tf,
    'polyadenylation': score_polyadenylation,
    'splice_sites': score_splice_sites,
    'splice_site_usage': score_splice_site_usage,
    'splice_junctions': score_splice_junctions,
}

all_scorers = variant_scorers.RECOMMENDED_VARIANT_SCORERS
selected_scorers = [
    all_scorers[key]
    for key in all_scorers
    if scorer_selections.get(key.lower(), False)
]

# Remove any scorers or output types that are not supported for the chosen organism.
unsupported_scorers = [
    scorer
    for scorer in selected_scorers
    if (
        organism.value
        not in variant_scorers.SUPPORTED_ORGANISMS[scorer.base_variant_scorer]
    )
    | (
        (scorer.requested_output == dna_client.OutputType.PROCAP)
        & (organism == dna_client.Organism.MUS_MUSCULUS)
    )
]
if len(unsupported_scorers) > 0:
  print(
      f'Excluding {unsupported_scorers} scorers as they are not supported for'
      f' {organism}.'
  )
  for unsupported_scorer in unsupported_scorers:
    selected_scorers.remove(unsupported_scorer)


# Score variants in the VCF file.
aggregated_data = []

print("Processing variants and aggregating on-the-fly...")

# Iterate through variants one by one
for i, vcf_row in tqdm(vcf.iterrows(), total=len(vcf)):
    variant = genome.Variant(
        chromosome=str(vcf_row.CHROM),
        position=int(vcf_row.POS),
        reference_bases=vcf_row.REF,
        alternate_bases=vcf_row.ALT,
        name=vcf_row.variant_id,
    )
    interval = variant.reference_interval.resize(sequence_length)

    # 1. Score ONLY this single variant
    try:
        variant_scores = dna_model.score_variant(
            interval=interval,
            variant=variant,
            variant_scorers=selected_scorers,
            organism=organism,
        )

        # 2. Convert immediately to a temporary DataFrame
        df_single = variant_scorers.tidy_scores([variant_scores])

        # ---------------------------------------------------------
        # REMOVED FILTER: Now calculating based on ALL tissues/contexts
        # ---------------------------------------------------------

        # 3. Calculate aggregates immediately
        if not df_single.empty:
            stats = {
                'variant_id': str(vcf_row.variant_id),
                # Store the label
                'label': vcf_row['label'],

                # Calculate Max
                'raw_score_max': df_single['raw_score'].max(),
                'quantile_score_max': df_single['quantile_score'].max(),

                # Calculate Min
                'raw_score_min': df_single['raw_score'].min(),
                'quantile_score_min': df_single['quantile_score'].min(),

                # Calculate Mean
                'raw_score_mean': df_single['raw_score'].mean(),
                'quantile_score_mean': df_single['quantile_score'].mean(),
            }
            aggregated_data.append(stats)

        # 4. Explicitly delete heavy objects
        del variant_scores
        del df_single

    except Exception as e:
        print(f"Error processing variant {vcf_row.variant_id}: {e}")
        continue

# ---------------------------------------------------------
# After the loop, create the final summary DataFrames
# ---------------------------------------------------------

# Convert the lightweight list of dicts to a DataFrame
df_agg_all = pd.DataFrame(aggregated_data)

# Check if we have data to evaluate
if df_agg_all.empty:
    print("No data was aggregated (check filters or input). Exiting.")
else:
    # ---------------------------------------------------------
    # Performance Evaluation
    # ---------------------------------------------------------
    print("\n" + "="*40)
    print("PERFORMANCE EVALUATION (vs Label)")
    print("="*40)

    # Ensure labels are numeric.
    # Assumes labels are already 0/1 or can be coerced to numbers.
    y_true = pd.to_numeric(df_agg_all['label'], errors='coerce')

    # Drop NaNs if label conversion failed for some rows
    valid_mask = ~y_true.isna()
    if valid_mask.sum() < len(y_true):
        print(f"Warning: {len(y_true) - valid_mask.sum()} rows had non-numeric labels and were skipped.")

    y_true = y_true[valid_mask]
    df_eval = df_agg_all[valid_mask]

    def evaluate_metric(y_true, y_pred, metric_name):
        """Helper to print metrics"""
        # Handle cases where score might be NaN
        clean_mask = ~np.isnan(y_pred)
        if clean_mask.sum() == 0:
            print(f"  {metric_name}: No valid predictions.")
            return

        y_t = y_true[clean_mask]
        y_p = y_pred[clean_mask]

        # Check if we have both classes to calculate AUC
        if len(np.unique(y_t)) < 2:
            print(f"  {metric_name}: Only one class present in labels. Cannot compute AUC.")
            return

        try:
            auroc = roc_auc_score(y_t, y_p)
            auprc = average_precision_score(y_t, y_p)
            print(f"  [{metric_name}]")
            print(f"    AUROC: {auroc:.4f}")
            print(f"    AUPRC: {auprc:.4f}")
        except ValueError as e:
            print(f"  {metric_name}: Error calculating metrics ({e})")

    # Evaluate Max Scores
    print("\n--- Strategy: MAX Score ---")
    evaluate_metric(y_true, df_eval['raw_score_max'], 'Raw Score Max')
    evaluate_metric(y_true, df_eval['quantile_score_max'], 'Quantile Score Max')

    # Evaluate Min Scores
    print("\n--- Strategy: MIN Score ---")
    evaluate_metric(y_true, df_eval['raw_score_min'], 'Raw Score Min')
    evaluate_metric(y_true, df_eval['quantile_score_min'], 'Quantile Score Min')

    # Evaluate Mean Scores
    print("\n--- Strategy: MEAN Score ---")
    evaluate_metric(y_true, df_eval['raw_score_mean'], 'Raw Score Mean')
    evaluate_metric(y_true, df_eval['quantile_score_mean'], 'Quantile Score Mean')
    print("="*40 + "\n")

    # Save aggregated scores
    agg_max = df_agg_all[['variant_id', 'label', 'raw_score_max', 'quantile_score_max']]
    agg_min = df_agg_all[['variant_id', 'label', 'raw_score_min', 'quantile_score_min']]
    agg_mean = df_agg_all[['variant_id', 'label', 'raw_score_mean', 'quantile_score_mean']]

    if download_predictions:
        agg_max.to_csv('brca1_variant_scores_max.csv', index=False)
        files.download('brca1_variant_scores_max.csv')

        agg_min.to_csv('brca1_variant_scores_min.csv', index=False)
        files.download('brca1_variant_scores_min.csv')

        agg_mean.to_csv('brca1_variant_scores_mean.csv', index=False)
        files.download('brca1_variant_scores_mean.csv')

    print("\nAggregated scores (max) head:")
    display(agg_max.head())

Processing variants and aggregating on-the-fly...


  0%|          | 0/3850 [00:00<?, ?it/s]

Error processing variant chrMT:12706:T>C: <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.INVALID_ARGUMENT
	details = "Unrecognized chromosome 'chrMT' for organism."
	debug_error_string = "UNKNOWN:Error received from peer ipv4:108.177.12.95:443 {grpc_status:3, grpc_message:"Unrecognized chromosome \'chrMT\' for organism."}"
>

PERFORMANCE EVALUATION (vs Label)

--- Strategy: MAX Score ---
  [Raw Score Max]
    AUROC: 0.5435
    AUPRC: 0.7304
  [Quantile Score Max]
    AUROC: 0.5528
    AUPRC: 0.7287

--- Strategy: MIN Score ---
  [Raw Score Min]
    AUROC: 0.4941
    AUPRC: 0.6682
  [Quantile Score Min]
    AUROC: 0.4506
    AUPRC: 0.6402

--- Strategy: MEAN Score ---
  [Raw Score Mean]
    AUROC: 0.5656
    AUPRC: 0.7441
  [Quantile Score Mean]
    AUROC: 0.5002
    AUPRC: 0.6685



<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


Aggregated scores (max) head:


Unnamed: 0,variant_id,label,raw_score_max,quantile_score_max
0,chrX:38297331:T>C,0,0.371268,0.997839
1,chr13:32337870:C>T,0,0.283936,0.997684
2,chr11:2847899:G>A,0,0.110233,0.997936
3,chr13:32339667:G>A,0,0.048962,0.989902
4,chr2:38071026:G>C,0,0.832705,0.999703
