In [6]:
! pip install pyfaidx



### Download human genome

In [7]:
! wget http://hgdownload.cse.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz
! gunzip hg38.fa.gz

--2026-02-12 21:01:54--  http://hgdownload.cse.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz
Resolving hgdownload.cse.ucsc.edu (hgdownload.cse.ucsc.edu)... 128.114.119.163
Connecting to hgdownload.cse.ucsc.edu (hgdownload.cse.ucsc.edu)|128.114.119.163|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 983659424 (938M) [application/x-gzip]
Saving to: ‘hg38.fa.gz’


2026-02-12 21:02:09 (63.4 MB/s) - ‘hg38.fa.gz’ saved [983659424/983659424]

gzip: hg38.fa already exists; do you wish to overwrite (y or n)? y
y
^C


### Download clinVar nucleotide variant database

In [8]:
! wget https://ftp.ncbi.nlm.nih.gov/pub/clinvar/vcf_GRCh38/clinvar.vcf.gz

--2026-02-12 21:09:08--  https://ftp.ncbi.nlm.nih.gov/pub/clinvar/vcf_GRCh38/clinvar.vcf.gz
Resolving ftp.ncbi.nlm.nih.gov (ftp.ncbi.nlm.nih.gov)... 130.14.250.7, 130.14.250.10, 130.14.250.11, ...
Connecting to ftp.ncbi.nlm.nih.gov (ftp.ncbi.nlm.nih.gov)|130.14.250.7|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 186444347 (178M) [application/x-gzip]
Saving to: ‘clinvar.vcf.gz.1’


2026-02-12 21:09:13 (39.0 MB/s) - ‘clinvar.vcf.gz.1’ saved [186444347/186444347]



In [9]:
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
import pandas as pd
import numpy as np
from sklearn.metrics import roc_curve, auc, roc_auc_score
import matplotlib.pyplot as plt
from pyfaidx import Fasta
import gzip
import subprocess
import os
import pickle
import json

### Predict human variants using PhyloGPN
Aiming for Figure 3a in Albors et al.

In [10]:
# Load PhyloGPN model and tokenizer
model_name = "songlab/PhyloGPN"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.eval()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(f"Using device: {device}")

# Load human reference genome
genome = Fasta('hg38.fa')  # Update path as needed

def get_sequence_context_fixed(chrom, pos, genome, context_size=240):
    """Get 481bp sequence centered at position with proper coordinate handling"""
    # Convert to 0-based indexing for Fasta
    pos_0based = pos - 1

    start = max(0, pos_0based - context_size)
    end = pos_0based + context_size + 1

    # Extract sequence
    seq = str(genome[chrom][start:end]).upper()

    # Verify the center nucleotide
    expected_center_idx = len(seq) // 2
    actual_ref = seq[expected_center_idx] if expected_center_idx < len(seq) else 'N'

    # Pad if necessary to get exactly 481bp
    if len(seq) < 481:
        pad_needed = 481 - len(seq)
        pad_left = pad_needed // 2
        pad_right = pad_needed - pad_left
        seq = 'N' * pad_left + seq + 'N' * pad_right

        # Update center index after padding
        expected_center_idx += pad_left
    elif len(seq) > 481:
        # Trim to exactly 481bp, keeping center
        excess = len(seq) - 481
        start_trim = excess // 2
        seq = seq[start_trim:start_trim + 481]
        expected_center_idx -= start_trim

    # Final center nucleotide after all processing
    final_center = seq[240]  # Should always be position 240 in 481bp sequence

    return seq, final_center

def compute_phylogpn_score_reliable(model, tokenizer, sequence, ref_nucleotide, alt_nucleotide, device):
    """Compute PhyloGPN score with reliable F81 interpretation"""

    # Verify the center nucleotide matches what we expect
    center_idx = len(sequence) // 2
    actual_center = sequence[center_idx]

    if actual_center != ref_nucleotide:
        print(f"Warning: Center nucleotide is {actual_center}, expected {ref_nucleotide}")

    # Get model prediction for the reference sequence
    inputs = tokenizer(sequence, return_tensors="pt", max_length=512, truncation=True, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model(**inputs)

        # Extract θ parameters for each nucleotide
        theta_A = output['A'][0, 0].item()
        theta_C = output['C'][0, 0].item()
        theta_G = output['G'][0, 0].item()
        theta_T = output['T'][0, 0].item()

    # Create mapping
    nucleotide_theta = {
        'A': theta_A,
        'C': theta_C,
        'G': theta_G,
        'T': theta_T
    }

    # Method that showed best performance: max comparison
    # Higher score = more pathogenic
    all_thetas = [theta_A, theta_C, theta_G, theta_T]
    max_theta = max(all_thetas)
    alt_theta = nucleotide_theta[alt_nucleotide]

    pathogenicity_score = max_theta - alt_theta

    # Also compute F81 stationary probability ratio for comparison
    theta_tensor = torch.tensor(all_thetas, device=device)
    probs = F.softmax(theta_tensor, dim=0)
    prob_dict = dict(zip(['A', 'C', 'G', 'T'], probs.tolist()))

    # Add small epsilon to avoid log(0)
    epsilon = 1e-10
    ref_prob = max(prob_dict[ref_nucleotide], epsilon)
    alt_prob = max(prob_dict[alt_nucleotide], epsilon)

    # Log likelihood ratio
    llr = np.log(alt_prob) - np.log(ref_prob)
    # Convert to pathogenicity score (negative because lower prob = more pathogenic)
    prob_pathogenicity_score = -llr

    return {
        'max_comparison': pathogenicity_score,
        'prob_ratio': prob_pathogenicity_score,
        'theta_values': nucleotide_theta,
        'probabilities': prob_dict,
        'center_match': actual_center == ref_nucleotide
    }

# Load ClinVar data with better filtering
def load_clinvar_variants_filtered(vcf_path, max_variants=None):
    """Load ClinVar variants with better filtering"""
    variants = []

    open_func = gzip.open if vcf_path.endswith('.gz') else open

    with open_func(vcf_path, 'rt') as f:
        for line_num, line in enumerate(f):
            if line.startswith('#'):
                continue

            fields = line.strip().split('\t')
            if len(fields) < 8:
                continue

            chrom = fields[0]
            pos = int(fields[1])
            ref = fields[3]
            alt = fields[4]
            info = fields[7]

            # Skip multi-allelic variants and non-SNVs
            if ',' in alt or len(ref) != 1 or len(alt) != 1 or ref == alt:
                continue

            # Parse CLNSIG
            clnsig = None
            for item in info.split(';'):
                if item.startswith('CLNSIG='):
                    clnsig = item.split('=')[1]
                    break

            if not clnsig:
                continue

            # More strict classification
            clnsig_lower = clnsig.lower()

            # Only include high-confidence classifications
            if ('pathogenic' in clnsig_lower and 'likely_pathogenic' not in clnsig_lower and
                'benign' not in clnsig_lower and 'uncertain' not in clnsig_lower):
                label = 1  # Pathogenic
            elif ('benign' in clnsig_lower and 'likely_benign' not in clnsig_lower and
                  'pathogenic' not in clnsig_lower and 'uncertain' not in clnsig_lower):
                label = 0  # Benign
            else:
                continue

            variants.append({
                'chrom': chrom,
                'pos': pos,
                'ref': ref,
                'alt': alt,
                'label': label,
                'clnsig': clnsig
            })

            if max_variants and len(variants) >= max_variants:
                break

    return pd.DataFrame(variants)

# Compute scores for all variants
def compute_scores_for_variants_improved(variants_df, model, tokenizer, genome, device):
    """Compute pathogenicity scores with improved coordinate handling"""
    max_comp_scores = []
    prob_ratio_scores = []
    labels = []
    coordinate_issues = 0

    for idx, row in variants_df.iterrows():
        try:
            # Get sequence context with coordinate verification
            chrom = row['chrom'] if row['chrom'].startswith('chr') else f"chr{row['chrom']}"
            sequence, actual_center = get_sequence_context_fixed(chrom, row['pos'], genome)

            # Skip if coordinate doesn't match
            if actual_center != row['ref']:
                coordinate_issues += 1
                print(f"Coordinate issue at {chrom}:{row['pos']} - expected {row['ref']}, got {actual_center}")
                continue

            # Compute scores
            results = compute_phylogpn_score_reliable(model, tokenizer, sequence, row['ref'], row['alt'], device)

            max_comp_scores.append(results['max_comparison'])
            prob_ratio_scores.append(results['prob_ratio'])
            labels.append(row['label'])

        except Exception as e:
            print(f"Error processing variant at {row['chrom']}:{row['pos']}: {e}")
            continue

        # Print progress
        if (idx + 1) % 100 == 0:
            print(f"Processed {idx + 1}/{len(variants_df)} variants")

    print(f"Coordinate issues encountered: {coordinate_issues}/{len(variants_df)}")

    return {
        'max_comp_scores': np.array(max_comp_scores),
        'prob_ratio_scores': np.array(prob_ratio_scores),
        'labels': np.array(labels)
    }

# Plot ROC curve
def plot_comparison_roc(results):
    """Plot ROC curves for both methods"""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Max comparison method
    fpr1, tpr1, _ = roc_curve(results['labels'], results['max_comp_scores'])
    auc1 = auc(fpr1, tpr1)

    ax1.plot(fpr1, tpr1, color='blue', lw=2, label=f'Max Comparison (AUC = {auc1:.3f})')
    ax1.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
    ax1.set_xlim([0.0, 1.0])
    ax1.set_ylim([0.0, 1.05])
    ax1.set_xlabel('False Positive Rate')
    ax1.set_ylabel('True Positive Rate')
    ax1.set_title('Max Comparison Method')
    ax1.legend(loc="lower right")
    ax1.grid(True, alpha=0.3)

    # Probability ratio method
    fpr2, tpr2, _ = roc_curve(results['labels'], results['prob_ratio_scores'])
    auc2 = auc(fpr2, tpr2)

    ax2.plot(fpr2, tpr2, color='red', lw=2, label=f'Probability Ratio (AUC = {auc2:.3f})')
    ax2.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
    ax2.set_xlim([0.0, 1.0])
    ax2.set_ylim([0.0, 1.05])
    ax2.set_xlabel('False Positive Rate')
    ax2.set_ylabel('True Positive Rate')
    ax2.set_title('Probability Ratio Method')
    ax2.legend(loc="lower right")
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print(f"Max Comparison AUC: {auc1:.3f}")
    print(f"Probability Ratio AUC: {auc2:.3f}")

    return auc1, auc2

# Main execution
if __name__ == "__main__":
    # Load ClinVar variants
    clinvar_vcf_path = "clinvar.vcf.gz"
    print("Loading ClinVar variants...")

    # Use more variants with stricter filtering
    # variants_df = load_clinvar_variants_filtered(clinvar_vcf_path, max_variants=5000)
    variants_df = load_clinvar_variants_paper_protocol(clinvar_vcf_path, max_variants=20000)[10000:20000]

    print(f"Loaded {len(variants_df)} ClinVar variants")
    print(f"Pathogenic: {sum(variants_df['label'])}, Benign: {len(variants_df) - sum(variants_df['label'])}")

    # Compute scores with improved coordinate handling
    print("Computing scores with coordinate verification...")
    results = compute_scores_for_variants_improved(variants_df, model, tokenizer, genome, device)

    print(f"Successfully computed scores for {len(results['labels'])} variants")

    # Plot comparison
    auc1, auc2 = plot_comparison_roc(results)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/376 [00:00<?, ?B/s]

configuration_phylogpn.py:   0%|          | 0.00/472 [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/songlab/PhyloGPN:
- configuration_phylogpn.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


tokenizer_config.json:   0%|          | 0.00/873 [00:00<?, ?B/s]

tokenization_phylogpn.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/songlab/PhyloGPN:
- tokenization_phylogpn.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


special_tokens_map.json:   0%|          | 0.00/43.0 [00:00<?, ?B/s]

modeling_phylogpn.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/songlab/PhyloGPN:
- modeling_phylogpn.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/333M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/485 [00:00<?, ?it/s]

AttributeError: 'PhyloGPNModel' object has no attribute 'all_tied_weights_keys'

## Transfer the model to the mouse genome

Download the liftOver tool.

In [None]:
! wget http://hgdownload.soe.ucsc.edu/admin/exe/linux.x86_64/liftOver

# Make it executable
! chmod +x liftOver

! gunzip hg38ToMm10.over.chain.gz

Download the mouse genome

In [None]:
! wget http://hgdownload.soe.ucsc.edu/goldenPath/mm10/bigZips/mm10.fa.gz

In [None]:
! gunzip mm10.fa.gz

Download the human-mouse genome mapping



In [None]:
! wget http://hgdownload.soe.ucsc.edu/goldenPath/hg38/liftOver/hg38ToMm10.over.chain.gz

Predict variants on the mouse genome.

In [None]:
# Load PhyloGPN model
model_name = "songlab/PhyloGPN"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.eval()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(f"Using device: {device}")

# Load genomes - only need mouse genome
mouse_genome = Fasta('mm10.fa')

def create_bed_file(variants_df, filename):
    """Create BED file from variants for liftOver"""
    with open(filename, 'w') as f:
        for _, row in variants_df.iterrows():
            chrom = row['chrom'] if row['chrom'].startswith('chr') else f"chr{row['chrom']}"
            # BED format: chrom start end name
            f.write(f"{chrom}\t{row['pos']-1}\t{row['pos']}\t{row['chrom']}:{row['pos']}:{row['ref']}>{row['alt']}\n")

def liftover_coordinates(bed_file, chain_file, output_file):
    """Use UCSC liftOver to convert coordinates"""
    # Try different possible locations for liftOver
    liftover_paths = ['liftOver', './liftOver', '/usr/local/bin/liftOver']

    liftover_cmd = None
    for path in liftover_paths:
        try:
            subprocess.run([path], capture_output=True)
            liftover_cmd = path
            break
        except FileNotFoundError:
            continue

    if liftover_cmd is None:
        print("liftOver not found. Please download it:")
        print("wget http://hgdownload.soe.ucsc.edu/admin/exe/linux.x86_64/liftOver")
        print("chmod +x liftOver")
        return False

    cmd = [liftover_cmd, bed_file, chain_file, output_file, 'unmapped.bed']
    result = subprocess.run(cmd, capture_output=True, text=True)
    return result.returncode == 0

def parse_liftover_results(lifted_file):
    """Parse liftOver results and create mapping dictionary"""
    mapping = {}

    if not os.path.exists(lifted_file):
        return mapping

    with open(lifted_file, 'r') as f:
        for line in f:
            fields = line.strip().split('\t')
            if len(fields) >= 4:
                mouse_chrom = fields[0]
                mouse_pos = int(fields[2])  # End position in BED
                variant_id = fields[3]

                # Extract human info from variant ID
                parts = variant_id.split(':')
                human_chrom = parts[0]
                human_pos = int(parts[1])
                ref_alt = parts[2]
                ref, alt = ref_alt.split('>')

                mapping[(human_chrom, human_pos)] = (mouse_chrom, mouse_pos, ref, alt)

    return mapping

def get_sequence_context(genome, chrom, pos, context_size=240):
    """Get 481bp sequence centered at position"""
    # Convert to 0-based indexing
    pos_0based = pos - 1

    start = max(0, pos_0based - context_size)
    end = pos_0based + context_size + 1

    # Extract sequence
    seq = str(genome[chrom][start:end]).upper()

    # Pad if necessary to get exactly 481bp
    if len(seq) < 481:
        pad_needed = 481 - len(seq)
        pad_left = pad_needed // 2
        pad_right = pad_needed - pad_left
        seq = 'N' * pad_left + seq + 'N' * pad_right
    elif len(seq) > 481:
        # Trim to exactly 481bp, keeping center
        excess = len(seq) - 481
        start_trim = excess // 2
        seq = seq[start_trim:start_trim + 481]

    return seq

def compute_phylogpn_llr(model, tokenizer, sequence, ref_nucleotide, alt_nucleotide, device):
    """Compute LLR following PhyloGPN paper protocol"""
    inputs = tokenizer(sequence, return_tensors="pt", max_length=512, truncation=True, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model(**inputs)

        # Extract θ parameters
        theta_A = output['A'][0, 0].item()
        theta_C = output['C'][0, 0].item()
        theta_G = output['G'][0, 0].item()
        theta_T = output['T'][0, 0].item()

    # Compute stationary probabilities
    theta_tensor = torch.tensor([theta_A, theta_C, theta_G, theta_T], device=device)
    pi = F.softmax(theta_tensor, dim=0)

    # Get probabilities for ref and alt
    nucleotide_to_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    ref_idx = nucleotide_to_idx[ref_nucleotide]
    alt_idx = nucleotide_to_idx[alt_nucleotide]

    ref_prob = pi[ref_idx].item()
    alt_prob = pi[alt_idx].item()

    # Add epsilon to avoid log(0)
    epsilon = 1e-10
    ref_prob = max(ref_prob, epsilon)
    alt_prob = max(alt_prob, epsilon)

    # LLR = log(P(alt)) - log(P(ref))
    llr = np.log(alt_prob) - np.log(ref_prob)

    return llr

def load_clinvar_variants(vcf_path, max_variants=1000):
    """Load ClinVar variants"""
    variants = []

    open_func = gzip.open if vcf_path.endswith('.gz') else open

    with open_func(vcf_path, 'rt') as f:
        for line_num, line in enumerate(f):
            if line.startswith('#'):
                continue

            fields = line.strip().split('\t')
            if len(fields) < 8:
                continue

            chrom = fields[0]
            pos = int(fields[1])
            ref = fields[3]
            alt = fields[4]
            info = fields[7]

            # Skip multi-allelic variants and non-SNVs
            if ',' in alt or len(ref) != 1 or len(alt) != 1 or ref == alt:
                continue

            # Parse CLNSIG
            clnsig = None
            for item in info.split(';'):
                if item.startswith('CLNSIG='):
                    clnsig = item.split('=')[1]
                    break

            if not clnsig:
                continue

            # Parse star rating
            review_status = "no_assertion"
            for item in info.split(';'):
                if item.startswith('CLNREVSTAT='):
                    review_status = item.split('=')[1]
                    break

            # Require at least one star (skip "no assertion criteria")
            if "no_assertion" in review_status:
                continue

            # Simple classification
            clnsig_lower = clnsig.lower()
            if ('pathogenic' in clnsig_lower and 'benign' not in clnsig_lower) or \
               ('likely_pathogenic' in clnsig_lower):
                label = 1
            elif ('benign' in clnsig_lower and 'pathogenic' not in clnsig_lower) or \
                 ('likely_benign' in clnsig_lower):
                label = 0
            else:
                continue

            variants.append({
                'chrom': chrom,
                'pos': pos,
                'ref': ref,
                'alt': alt,
                'label': label,
                'review': review_status
            })

            if max_variants and len(variants) >= max_variants:
                break

    return pd.DataFrame(variants)

def test_mouse_approaches(variants_df, model, tokenizer, mouse_genome, device, save_file=None):
    """Test both approaches on mouse sequences only"""

    print("Using liftOver for coordinate mapping...")

    # Create BED file for liftOver
    create_bed_file(variants_df, 'variants.bed')

    # Run liftOver
    if liftover_coordinates('variants.bed', 'hg38ToMm10.over.chain', 'variants_mm10.bed'):
        mapping = parse_liftover_results('variants_mm10.bed')
        print(f"Successfully mapped {len(mapping)} out of {len(variants_df)} variants using liftOver")
    else:
        print("liftOver failed!")
        return None, None, None

    mouse_approach1_llrs = []  # Human ref>alt on mouse sequence
    mouse_approach2_llrs = []  # Mouse center>human alt
    labels = []
    variant_info = []

    mapped_variants = 0

    for idx, row in variants_df.iterrows():
        try:
            # Skip if not mapped
            key = (row['chrom'], row['pos'])
            if key not in mapping:
                continue

            # Get mouse coordinates and sequence
            mouse_chrom, mouse_pos, human_ref, human_alt = mapping[key]
            mouse_seq = get_sequence_context(mouse_genome, mouse_chrom, mouse_pos)
            mouse_center = mouse_seq[240]

            # Only proceed if mouse center is a valid nucleotide
            if mouse_center not in ['A', 'C', 'G', 'T']:
                continue

            # Approach 1: Human ref>alt on mouse sequence
            mouse_approach1_llr = compute_phylogpn_llr(model, tokenizer, mouse_seq, human_ref, human_alt, device)

            # Approach 2: Mouse center>human alt
            mouse_approach2_llr = compute_phylogpn_llr(model, tokenizer, mouse_seq, mouse_center, human_alt, device)

            mouse_approach1_llrs.append(mouse_approach1_llr)
            mouse_approach2_llrs.append(mouse_approach2_llr)
            labels.append(row['label'])

            # Store additional info for analysis
            variant_info.append({
                'human_chrom': row['chrom'],
                'human_pos': row['pos'],
                'human_ref': human_ref,
                'human_alt': human_alt,
                'mouse_chrom': mouse_chrom,
                'mouse_pos': mouse_pos,
                'mouse_center': mouse_center,
                'label': row['label'],
                'mouse_approach1_llr': mouse_approach1_llr,
                'mouse_approach2_llr': mouse_approach2_llr
            })

            mapped_variants += 1

        except Exception as e:
            if 'not found' not in str(e):
                print(f"Error processing variant {row['chrom']}:{row['pos']}: {e}")
            continue

        if (idx + 1) % 500 == 0:
            print(f"Processed {idx + 1}/{len(variants_df)} variants, mapped {mapped_variants}")

    print(f"Final stats: {mapped_variants} mapped variants")

    # Convert to numpy arrays
    mouse_approach1_llrs = np.array(mouse_approach1_llrs)
    mouse_approach2_llrs = np.array(mouse_approach2_llrs)
    labels = np.array(labels)

    # Save results if requested
    if save_file:
        # Save numpy arrays
        np.savez(save_file,
                 mouse_approach1_llrs=mouse_approach1_llrs,
                 mouse_approach2_llrs=mouse_approach2_llrs,
                 labels=labels)

        # Save detailed info as pickle
        with open(f"{save_file}_details.pkl", 'wb') as f:
            pickle.dump(variant_info, f)

        # Also save a JSON version for easier sharing
        with open(f"{save_file}_details.json", 'w') as f:
            json.dump(variant_info, f, indent=2)

        print(f"Saved results to {save_file}")

    return mouse_approach1_llrs, mouse_approach2_llrs, labels, variant_info

def plot_mouse_comparison_roc(mouse_approach1_llrs, mouse_approach2_llrs, labels, save_plot=None):
    """Plot ROC curves comparing only mouse approaches"""

    # Calculate ROC curves
    fpr_mouse1, tpr_mouse1, _ = roc_curve(labels, -mouse_approach1_llrs)
    auc_mouse1 = auc(fpr_mouse1, tpr_mouse1)

    fpr_mouse2, tpr_mouse2, _ = roc_curve(labels, -mouse_approach2_llrs)
    auc_mouse2 = auc(fpr_mouse2, tpr_mouse2)

    # Create plot
    plt.figure(figsize=(10, 8))

    plt.plot(fpr_mouse1, tpr_mouse1, color='red', lw=2,
             label=f'Approach 1: Human ref>alt (AUC = {auc_mouse1:.3f})')
    plt.plot(fpr_mouse2, tpr_mouse2, color='green', lw=2,
             label=f'Approach 2: Mouse center>alt (AUC = {auc_mouse2:.3f})')

    plt.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('PhyloGPN Performance on Mouse Sequences: Comparison of Approaches')
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)

    # Add explanatory text
    plt.text(0.05, 0.3,
             "Approach 1: Human ref>alt on mouse sequence\n"
             "Approach 2: Mouse center>human alt",
             bbox={'facecolor': 'white', 'alpha': 0.8, 'pad': 10})

    plt.tight_layout()

    # Save plot if requested
    if save_plot:
        plt.savefig(save_plot, dpi=300, bbox_inches='tight')
        print(f"Saved plot to {save_plot}")

    plt.show()

    # Print detailed statistics
    print("\n=== MOUSE APPROACHES COMPARISON ===")
    print(f"Approach 1 (Human ref>alt) AUC: {auc_mouse1:.3f}")
    print(f"Approach 2 (Mouse center>alt) AUC: {auc_mouse2:.3f}")

    # Calculate correlation between approaches
    corr_approaches = np.corrcoef(mouse_approach1_llrs, mouse_approach2_llrs)[0, 1]
    print(f"Correlation between approaches: {corr_approaches:.3f}")

    print(f"\nPathogenic variants (n={sum(labels == 1)}):")
    print(f"  Approach 1: {mouse_approach1_llrs[labels == 1].mean():.3f} ± {mouse_approach1_llrs[labels == 1].std():.3f}")
    print(f"  Approach 2: {mouse_approach2_llrs[labels == 1].mean():.3f} ± {mouse_approach2_llrs[labels == 1].std():.3f}")

    print(f"Benign variants (n={sum(labels == 0)}):")
    print(f"  Approach 1: {mouse_approach1_llrs[labels == 0].mean():.3f} ± {mouse_approach1_llrs[labels == 0].std():.3f}")
    print(f"  Approach 2: {mouse_approach2_llrs[labels == 0].mean():.3f} ± {mouse_approach2_llrs[labels == 0].std():.3f}")

    return auc_mouse1, auc_mouse2, corr_approaches

# Main execution
if __name__ == "__main__":
    print("Comparing PhyloGPN approaches on mouse sequences...")

    # Load ClinVar variants
    clinvar_vcf_path = "clinvar.vcf.gz"
    variants_df = load_clinvar_variants(clinvar_vcf_path, max_variants=20000)[10000:20000]

    print(f"Loaded {len(variants_df)} variants")
    print(f"Pathogenic: {sum(variants_df['label'])}, Benign: {len(variants_df) - sum(variants_df['label'])}")

    # Test both approaches on mouse sequences only
    results = test_mouse_approaches(
        variants_df, model, tokenizer, mouse_genome, device,
        save_file="phylogpn_mouse_results"  # Save results to this file
    )

    if results[0] is not None:
        mouse_approach1_llrs, mouse_approach2_llrs, labels, variant_info = results

        # Plot ROC comparison
        auc_mouse1, auc_mouse2, corr = plot_mouse_comparison_roc(
            mouse_approach1_llrs, mouse_approach2_llrs, labels,
            save_plot="phylogpn_mouse_roc.png"  # Save plot to this file
        )

        # Interpretation
        print("\n=== INTERPRETATION ===")
        if auc_mouse2 > auc_mouse1:
            print(f"Approach 2 (Mouse center>alt, AUC={auc_mouse2:.3f}) outperforms")
            print(f"Approach 1 (Human ref>alt, AUC={auc_mouse1:.3f}) by {(auc_mouse2-auc_mouse1):.3f}")
            print("\nThis suggests that:")
            print("1. The model is sensitive to the actual nucleotide at the center position")
            print("2. Evolutionary divergence between species matters for prediction")
        else:
            print(f"Approach 1 (Human ref>alt, AUC={auc_mouse1:.3f}) outperforms")
            print(f"Approach 2 (Mouse center>alt, AUC={auc_mouse2:.3f}) by {(auc_mouse1-auc_mouse2):.3f}")
            print("\nThis suggests that:")
            print("1. The model captures conservation patterns that transfer across species")
            print("2. PhyloGPN has learned evolutionary constraints independent of local context")
    else:
        print("Could not complete the analysis")