SAE Embedding Extraction

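Extract Evo2 embeddings for structural variant (SV) calls from multiple datasets, for use as sparse autoencoder (SAE) training data.
This notebook reads the true-positive and false-positive call sets of several GIAB benchmark datasets, extracts a 1 kb reference window around each call, and computes a mean-pooled Evo2 layer-26 embedding for every window.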

In [None]:
# Install Evo2
# !pip install evo2

import sys
import os
import torch
import json
import pathlib
from pathlib import Path
import cyvcf2
import pysam
import random
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility (Python's `random` and PyTorch).
random.seed(42)
torch.manual_seed(42)

# Report the runtime so the notebook output records which stack produced it.
# (The original line read `torch..__version__`, which is a SyntaxError.)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# %%
# Directory that receives every artifact written by this notebook; it is
# created here (parents included) so later cells can write without checking.
processed_dir = Path("../data/processed")
os.makedirs(processed_dir, exist_ok=True)

# Environment variables for Evo2, set before the package is imported.
# NOTE(review): presumably NVTE_DISABLE_FP8 turns off FP8 kernels and
# CUDA_DEVICE_ORDER pins GPU numbering to PCI bus order — confirm.
os.environ.update({
    'NVTE_DISABLE_FP8': '1',
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
})

In [None]:
# Datasets to process. Every entry names the TP and FP call-set VCFs of one
# benchmark directory together with the reference FASTA to read windows from.
_RAW_DIR = '../data/raw'


def _dataset_files(bench_dir, ref_name):
    """Return the path mapping (two VCFs + reference) for one benchmark directory."""
    return {
        'tp_comp_vcf': f'{_RAW_DIR}/{bench_dir}/tp-comp.vcf.gz',
        'fp_vcf': f'{_RAW_DIR}/{bench_dir}/fp.vcf.gz',
        'ref': f'{_RAW_DIR}/{ref_name}',
    }


DATASETS = {
    'HG002_GRCh37': _dataset_files('bench-HG002_GRCh37_noqc', 'GRCh37.fna'),
    'HG002_GRCh38': _dataset_files('bench-HG002_GRCh38-GIABv3_noqc', 'GRCh38.fa'),
    'HG005_GRCh38': _dataset_files('bench-HG005_GRCh38_noqc', 'GRCh38.fa'),
}

print("Environment configured")
print("Datasets:", list(DATASETS))

In [None]:
# Reuse a previous extraction when its JSON cache is present; otherwise flag
# that the extraction cells below have to run.
sequences_path = processed_dir / "sae_sequences.json"

if not sequences_path.exists():
    sequences_loaded = False
    print("No existing sequence file found - will extract sequences")
else:
    print("Found existing sequence file, loading...")

    with sequences_path.open('r') as cache_file:
        data = json.load(cache_file)

    # The cache holds two parallel lists under these keys.
    sequences = data['sequences']
    sv_info = data['sv_info']

    sequences_loaded = True
    print(f"Loaded {len(sequences)} sequences")

In [None]:
if not sequences_loaded:
    print("Categorizing structural variants across all datasets...")

    # category name -> list of (global_index, dataset_name)
    all_sv_categories = defaultdict(list)
    # global_index -> (label, vcf_path, index of the record within its VCF, dataset_name)
    all_variant_labels = {}
    global_index = 0
    total_scanned = 0

    # Only autosomes and chrX/chrY named with a 'chr' prefix pass the filter.
    # NOTE(review): a VCF that names contigs without the prefix (e.g. '1')
    # would have every record dropped here — confirm for the GRCh37 files.
    valid_chroms = {f'chr{i}' for i in range(1, 23)} | {'chrX', 'chrY'}

    for dataset_name, files in DATASETS.items():
        print(f"Processing dataset: {dataset_name}")

        vcf_files = {
            'TP': files['tp_comp_vcf'],
            'FP': files['fp_vcf']
        }

        for label, vcf_path in vcf_files.items():
            print(f"  Processing {label} file...")

            # Catch OSError rather than FileNotFoundError: cyvcf2 appears to
            # report a missing or unreadable file as a plain IOError, which the
            # narrower clause would miss. FileNotFoundError is still covered.
            try:
                vcf = cyvcf2.VCF(vcf_path)
            except OSError as exc:
                print(f"  Warning: could not open {vcf_path} for {dataset_name} ({exc}), skipping...")
                continue

            try:
                for i, variant in enumerate(vcf):
                    total_scanned += 1

                    # Every scanned record consumes one global index, whether
                    # or not it survives the filters below.
                    record_index = global_index
                    global_index += 1

                    # SVLEN can be missing, None, or multi-valued (a tuple);
                    # reduce it to a single non-negative number before comparing.
                    svlen = variant.INFO.get('SVLEN', 0)
                    if isinstance(svlen, (tuple, list)):
                        svlen = svlen[0] if svlen else 0
                    svlen = abs(svlen or 0)

                    # Records without SVLEN count as length 0 and are dropped here.
                    if svlen < 50:
                        continue

                    if variant.CHROM not in valid_chroms:
                        continue

                    svtype = variant.INFO.get('SVTYPE', 'UNK')

                    # Categorize by type and size (INS/DEL are split at 1000 bp)
                    if svtype == 'INS':
                        category = 'INS_large' if svlen >= 1000 else 'INS_small'
                    elif svtype == 'DEL':
                        category = 'DEL_large' if svlen >= 1000 else 'DEL_small'
                    elif svtype in ['DUP', 'INV', 'TRA']:
                        category = svtype
                    else:
                        category = 'OTHER'

                    all_sv_categories[category].append((record_index, dataset_name))
                    all_variant_labels[record_index] = (label, vcf_path, i, dataset_name)
            finally:
                # Release the file handle even if a record fails to parse.
                vcf.close()

    print(f"Scanned {total_scanned} total variants across all datasets")
    print(f"Valid variants: {sum(len(indices) for indices in all_sv_categories.values())}")

    print("\nBy SV type:")
    for category, indices in all_sv_categories.items():
        print(f"  {category}: {len(indices)} variants")

    # Count by class (the label is the first field of each stored tuple)
    tp_count = sum(1 for entry in all_variant_labels.values() if entry[0] == 'TP')
    fp_count = sum(1 for entry in all_variant_labels.values() if entry[0] == 'FP')
    print(f"\nBy class: TP={tp_count}, FP={fp_count}")

In [None]:
def extract_sequences(indices, set_name, window_size=1000):
    """Extract a reference window around each selected variant.

    Relies on two module-level mappings: ``all_variant_labels`` (global index
    -> ``(label, vcf_path, file_idx, dataset_name)``) and ``DATASETS`` (dataset
    name -> file paths, including the ``'ref'`` FASTA).

    Args:
        indices: Iterable of global variant indices (keys of
            ``all_variant_labels``).
        set_name: Human-readable name of the variant set; used only in the
            skipped-variant log message.
        window_size: Total width of the window fetched around ``variant.POS``
            (half on each side). Defaults to 1000.

    Returns:
        ``(sequences, sv_info)``: two parallel lists, one upper-cased sequence
        string and one metadata dict per variant that passed the quality
        filters. Windows with more than 10% ``N`` or shorter than 80% of
        ``window_size`` are dropped, as are variants whose window could not be
        fetched from the reference.
    """
    sequences = []
    sv_info = []
    half_window = window_size // 2
    fetch_failures = 0
    first_fetch_error = None

    # Group the requested records by source file so each VCF is streamed once.
    # The label stored at categorization time is part of the key, so it does
    # not have to be re-derived from the file name.
    wanted_by_source = defaultdict(set)
    for idx in indices:
        label, vcf_path, file_idx, dataset_name = all_variant_labels[idx]
        wanted_by_source[(vcf_path, dataset_name, label)].add(file_idx)

    for (vcf_path, dataset_name, label), wanted in wanted_by_source.items():
        print(f"  Processing {len(wanted)} variants from {label} in {dataset_name}")
        last_wanted = max(wanted)

        # Load reference genome for this dataset
        fasta = pysam.FastaFile(DATASETS[dataset_name]['ref'])
        vcf = None
        try:
            vcf = cyvcf2.VCF(vcf_path)

            for i, variant in enumerate(vcf):
                if i > last_wanted:
                    # Nothing further in this file was requested.
                    break
                if i not in wanted:
                    continue

                # SVLEN can be missing, None, or multi-valued (a tuple).
                svlen = variant.INFO.get('SVLEN', 0)
                if isinstance(svlen, (tuple, list)):
                    svlen = svlen[0] if svlen else 0
                svlen = abs(svlen or 0)

                chrom = variant.CHROM
                pos = variant.POS
                start = max(0, pos - half_window)
                end = pos + half_window

                # Only the reference lookup is best-effort; failures are
                # counted and reported instead of being silently discarded.
                try:
                    sequence = fasta.fetch(chrom, start, end).upper()
                except (KeyError, ValueError, IndexError, OSError) as exc:
                    fetch_failures += 1
                    if first_fetch_error is None:
                        first_fetch_error = f"{chrom}:{start}-{end}: {exc}"
                    continue

                # Quality filters
                if sequence.count('N') > len(sequence) * 0.1:
                    continue
                if len(sequence) < window_size * 0.8:
                    continue

                sequences.append(sequence)
                sv_info.append({
                    'chrom': chrom,
                    'pos': pos,
                    'end': pos + svlen,
                    'svlen': svlen,
                    'svtype': variant.INFO.get('SVTYPE', 'UNK'),
                    'id': variant.ID or f"{chrom}_{pos}",
                    # Records without a QUAL value get a placeholder of 60.0.
                    'qual': variant.QUAL if variant.QUAL is not None else 60.0,
                    'truvari_class': label,
                    'dataset': dataset_name
                })

                if len(sequences) % 500 == 0:
                    print(f"    Processed {len(sequences)} sequences...")
        finally:
            # Close both handles even when parsing fails part-way through.
            if vcf is not None:
                vcf.close()
            fasta.close()

    if fetch_failures:
        print(
            f"  Skipped {fetch_failures} variants in {set_name}: reference "
            f"window could not be fetched (first error: {first_fetch_error})"
        )

    return sequences, sv_info

# %%
if not sequences_loaded:
    print("Extracting all sequences from datasets...")

    # Every variant recorded during categorization is a candidate.
    all_indices = list(all_variant_labels)
    sequences, sv_info = extract_sequences(all_indices, "all datasets")

    # Write both parallel lists to one JSON file so that a later run can load
    # them instead of extracting again.
    print("Saving sequences and metadata...")
    with sequences_path.open('w') as out_file:
        json.dump({'sequences': sequences, 'sv_info': sv_info}, out_file)


In [None]:
print(f"Total sequences: {len(sequences)}")

# Tally records per class, overall and within each dataset.
class_counts = defaultdict(int)
dataset_counts = defaultdict(lambda: defaultdict(int))

for record in sv_info:
    truvari_class = record['truvari_class']
    class_counts[truvari_class] += 1
    dataset_counts[record['dataset']][truvari_class] += 1

print("\nBy class:")
for class_name, count in class_counts.items():
    print(f"  {class_name}: {count}")

print("\nBy dataset and class:")
for dataset, counts in dataset_counts.items():
    print(f"  {dataset}: {sum(counts.values())} total")
    for class_name, count in counts.items():
        print(f"    {class_name}: {count}")

# Tally by SV type: INS and DEL are split into small/large at 1000 bp, every
# other type is counted under its own name.
type_counts = defaultdict(int)
for record in sv_info:
    svtype, svlen = record['svtype'], record['svlen']
    if svtype in ('INS', 'DEL'):
        size_tag = 'large' if svlen >= 1000 else 'small'
        category = f"{svtype}_{size_tag}"
    else:
        category = svtype
    type_counts[category] += 1

print("\nBy SV type:")
for category, count in type_counts.items():
    print(f"  {category}: {count}")


In [None]:
# Load the Evo2 model and run one forward pass as a smoke test. Any failure
# (missing package, model load, forward pass) leaves evo2_available False so
# the embedding cell below is skipped.
evo2_available = False

try:
    from evo2 import Evo2
    print("Evo2 imported successfully")

    print("Loading Evo2 model...")
    evo2_model = Evo2('evo2_7b_base')

    # Tokenize a short sequence and add a leading batch dimension.
    test_sequence = 'ATCGATCGATCGATCG'
    token_ids = evo2_model.tokenizer.tokenize(test_sequence)
    input_ids = torch.tensor(token_ids, dtype=torch.int).unsqueeze(0)

    if torch.cuda.is_available():
        input_ids = input_ids.cuda()

    with torch.no_grad():
        outputs, _ = evo2_model(input_ids)

    print(f"Evo2 working! Output shape: {outputs[0].shape}")
except Exception as e:
    print(f"Could not load Evo2: {e}")
    evo2_available = False
else:
    # Reached only when the import, the load and the forward pass all succeeded.
    evo2_available = True

In [None]:
class SV_Evo2_Encoder:
    """Extract Evo2 embeddings from layer 26 for SV sequences.

    A forward hook on ``evo2_wrapper.model.blocks[26]`` captures that block's
    output during each forward pass; the captured tensor is averaged over
    ``dim=1`` to give one vector per sequence. When the model exposes no
    ``blocks[26]``, the first element of the model output is pooled instead.

    Args:
        evo2_wrapper: Callable object exposing ``.model`` (a torch module),
            ``.tokenizer.tokenize(str)`` and ``wrapper(input_ids)`` returning
            ``(outputs, _)``.
    """

    def __init__(self, evo2_wrapper):
        self.evo2 = evo2_wrapper
        self.evo2.model.eval()
        self.layer_26_activations = None
        self.hook_registered = False
        # Handle returned by register_forward_hook, kept so the hook can be removed.
        self._hook_handle = None
        # Ensures the fallback notice is printed once, not on every call.
        self._fallback_warned = False
        print("SV_Evo2_Encoder initialized")

    def _register_hook(self):
        """Register hook to capture layer 26 activations (idempotent)."""
        if self.hook_registered:
            return

        def hook_fn(module, inputs, output):
            # Handle tuple output - take the first element (hidden states)
            if isinstance(output, tuple):
                self.layer_26_activations = output[0].detach()
            else:
                self.layer_26_activations = output.detach()

        # Register hook on layer 26
        blocks = getattr(self.evo2.model, 'blocks', None)
        if blocks is not None and len(blocks) > 26:
            self._hook_handle = blocks[26].register_forward_hook(hook_fn)
            self.hook_registered = True
            print("Registered hook for layer 26")
        elif not self._fallback_warned:
            print("Could not find layer 26, using output embeddings instead")
            self._fallback_warned = True

    def remove_hook(self):
        """Detach the layer 26 hook from the model, if one is registered."""
        if self._hook_handle is not None:
            self._hook_handle.remove()
            self._hook_handle = None
        self.hook_registered = False
        self.layer_26_activations = None

    def extract_embeddings_for_svs(self, sequences, batch_size=8):
        """Extract layer 26 embeddings for SV sequences.

        Args:
            sequences: Sequence of strings to embed.
            batch_size: Number of sequences grouped per outer iteration. Each
                sequence is still run through the model on its own (no
                padding), so this only affects bookkeeping, not the results.

        Returns:
            Tensor with one mean-pooled vector per input sequence, stacked
            along dim 0 and moved to the CPU. An empty input yields an empty
            ``(0, 0)`` tensor instead of failing inside ``torch.stack``.
        """
        if len(sequences) == 0:
            print("No sequences provided; returning an empty embedding tensor")
            return torch.empty((0, 0))

        self._register_hook()
        embeddings = []

        print(f"Extracting Evo2 layer 26 embeddings for {len(sequences)} sequences...")

        for i in range(0, len(sequences), batch_size):
            batch_sequences = sequences[i:i+batch_size]
            batch_embeddings = []

            for j, sequence in enumerate(batch_sequences):
                if (i + j) % 100 == 0:
                    print(f"  Processing {i + j + 1}/{len(sequences)}")

                # Tokenize sequence and add a leading batch dimension
                input_ids = torch.tensor(
                    self.evo2.tokenizer.tokenize(sequence),
                    dtype=torch.int
                ).unsqueeze(0)

                if torch.cuda.is_available():
                    input_ids = input_ids.cuda()

                # Reset hook storage so a stale activation from the previous
                # sequence can never be pooled by mistake
                self.layer_26_activations = None

                # Forward pass - hook will capture layer 26
                with torch.no_grad():
                    outputs, _ = self.evo2(input_ids)

                    # Use layer 26 activations if available, otherwise use outputs
                    if self.layer_26_activations is not None:
                        embeddings_tensor = self.layer_26_activations
                    else:
                        embeddings_tensor = outputs[0]  # Fallback to output

                    # Global average pooling to get fixed-size representation
                    # (assumes dim 1 is the sequence axis — TODO confirm)
                    pooled = embeddings_tensor.mean(dim=1).squeeze(0)
                    batch_embeddings.append(pooled.cpu())

            embeddings.extend(batch_embeddings)

        final_embeddings = torch.stack(embeddings)
        print(f"Extracted embeddings shape: {final_embeddings.shape}")
        return final_embeddings

In [None]:
if evo2_available:
    # Wrap the loaded model in the layer 26 encoder.
    sv_encoder = SV_Evo2_Encoder(evo2_model)

    # Embed every extracted sequence.
    print("Extracting Evo2 layer 26 embeddings for all sequences...")
    sv_embeddings = sv_encoder.extract_embeddings_for_svs(sequences, batch_size=4)

    # Store the embeddings next to the metadata that describes each row, plus
    # the layer and model name they were taken from.
    embeddings_path = processed_dir / "sae_sv_embeddings.pt"
    embedding_bundle = {
        'embeddings': sv_embeddings,
        'sv_info': sv_info,
        'layer': 26,
        'model': 'evo2_7b_base',
    }
    torch.save(embedding_bundle, embeddings_path)
    print(f"Saved embeddings: {sv_embeddings.shape}")

In [None]:
# Final report. `sv_embeddings` is only defined when Evo2 loaded and the
# embedding cell ran, so everything that refers to it is guarded; otherwise
# this cell would end in a NameError and claim a file that was never written.
if evo2_available:
    print("Embedding extraction complete!")
else:
    print("Embedding extraction skipped: Evo2 was not available")

print("\nDataset summary:")
print(f"  Total sequences: {len(sequences)}")
if evo2_available:
    print(f"  Total embeddings: {sv_embeddings.shape}")

print(f"\nFiles saved to {processed_dir}:")
print("  - sae_sequences.json (sequences and metadata)")
if evo2_available:
    print("  - sae_sv_embeddings.pt (embeddings and metadata)")