VICReg latent extraction and analysis

In [None]:
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import h5py
import json
from tqdm import tqdm
from datetime import datetime
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
import warnings
warnings.filterwarnings('ignore')

# Configuration - GitHub repository structure
# All paths are relative, so they resolve against the current working directory
# of the notebook/kernel.
DATA_DIR = '../data/processed/all_datasets_images_rgb'  # per-dataset folders of .pt image files
MODELS_DIR = '../data/processed/vicreg_experiments/models'  # one sub-folder per saved experiment
SAVE_DIR = '../data/processed/vicreg_latents'  # output location for features/metadata/summary

# Create save directory
# (exist_ok=True makes re-running this cell harmless)
os.makedirs(SAVE_DIR, exist_ok=True)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    # Report the first CUDA device only (index 0).
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")

print(f"Save directory: {SAVE_DIR}")

In [None]:
# Load best model

def find_champion_vicreg_model():
    """Find the best performing binary VICReg checkpoint under MODELS_DIR.

    Every sub-directory of MODELS_DIR is inspected for a
    'best_vicreg_model.pth' checkpoint. Only checkpoints whose 'num_classes'
    equals 2 are considered, and the one with the highest 'best_test_acc'
    (strictly greater than 0) wins. Directories are visited in sorted order so
    that ties are resolved the same way on every run.

    Returns:
        (checkpoint_path, info) for the champion, where info is a dict with the
        keys 'name', 'accuracy', 'auc' and 'path'; (None, None) when the models
        directory is missing or no usable binary checkpoint exists.
    """

    print("Searching for champion VICReg model...")

    if not os.path.exists(MODELS_DIR):
        print(f"Models directory not found: {MODELS_DIR}")
        print("Run VICReg training first!")
        return None, None

    best_acc = 0
    best_info = None

    # Search all saved models
    for model_dir in sorted(os.listdir(MODELS_DIR)):
        checkpoint_path = os.path.join(MODELS_DIR, model_dir, 'best_vicreg_model.pth')

        if not os.path.exists(checkpoint_path):
            continue

        try:
            # NOTE: weights_only=False unpickles arbitrary objects, so MODELS_DIR
            # must only contain checkpoints from a trusted source.
            checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

            # Only consider binary models for feature extraction
            if checkpoint['num_classes'] != 2:
                continue

            acc = checkpoint['best_test_acc']
            if acc > best_acc:
                # Build the full record first: if a key is missing the
                # exception fires before any "best" state is touched, so the
                # returned path and info can never refer to different models.
                candidate = {
                    'name': checkpoint['experiment_name'],
                    'accuracy': acc,
                    'auc': checkpoint.get('final_test_auc', 0),
                    'path': checkpoint_path
                }
                best_acc = acc
                best_info = candidate
        except Exception as e:
            # Best-effort scan: an unreadable checkpoint must not abort the search.
            print(f"Error loading {model_dir}: {str(e)[:50]}...")
            continue

    if best_info is None:
        print("No binary VICReg models found!")
        return None, None

    best_path = best_info['path']

    print(f"Champion VICReg model found:")
    print(f"   Name: {best_info['name']}")
    print(f"   Accuracy: {best_info['accuracy']:.2f}%")
    print(f"   AUC: {best_info['auc']:.3f}")
    print(f"   Path: {best_path}")

    return best_path, best_info

def load_champion_vicreg():
    """Load the champion VICReg model for feature extraction

    Locates the champion checkpoint via find_champion_vicreg_model(), rebuilds
    the classifier wrapper, restores its weights and hands back only the
    backbone, moved to the module-level `device` and switched to eval mode.

    Returns:
        (feature_extractor, champion_info) on success, or (None, None) when no
        champion checkpoint could be found.
    """

    champion_path, champion_info = find_champion_vicreg_model()
    if champion_path is None:
        return None, None

    print(f"Loading champion VICReg...")

    # Load checkpoint
    # NOTE: weights_only=False unpickles arbitrary objects; the checkpoint must
    # come from a trusted source.
    checkpoint = torch.load(champion_path, map_location='cpu', weights_only=False)

    # Recreate VICReg class (same as training)
    # NOTE(review): the training-time definition is not visible here; this copy
    # has to stay in sync with it for load_state_dict() below to succeed.
    class VICRegResNet(nn.Module):
        def __init__(self, num_classes, dropout=0.2):
            super().__init__()
            self.num_classes = num_classes

            print("Loading Facebook's VICReg ResNet50x2...")
            try:
                # Fetches the backbone through torch.hub from the
                # facebookresearch/vicreg repository.
                self.backbone = torch.hub.load('facebookresearch/vicreg:main', 'resnet50x2', pretrained=True)
                print("   Successfully loaded VICReg ResNet50x2")
            except Exception as e:
                print(f"   Failed to load VICReg model: {e}")
                print("   Falling back to standard Wide ResNet50-2...")
                # NOTE(review): torchvision's wide_resnet50_2 presumably exposes
                # 2048 features before `fc`, which would not match the
                # hard-coded 4096 below — confirm the fallback is usable.
                self.backbone = torchvision.models.wide_resnet50_2(weights='IMAGENET1K_V1')

            feature_dim = 4096  # VICReg ResNet50x2 feature dimension

            # Remove original classifier
            # (only one of the two attribute names is replaced, `fc` first)
            if hasattr(self.backbone, 'fc'):
                self.backbone.fc = nn.Identity()
            elif hasattr(self.backbone, 'head'):
                self.backbone.head = nn.Identity()

            # Freeze backbone
            for param in self.backbone.parameters():
                param.requires_grad = False

            # Classifier
            self.classifier = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(feature_dim, num_classes)
            )

        def forward(self, x):
            # The backbone runs without gradient tracking; only the classifier
            # head takes part in autograd.
            with torch.no_grad():
                features = self.backbone(x)
            return self.classifier(features)

    # Recreate and load model
    model = VICRegResNet(
        checkpoint['num_classes'],
        checkpoint['model_config']['dropout_rate']
    )
    # Restores both backbone and classifier weights from the checkpoint.
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Extract backbone for feature extraction
    # (the classifier head is discarded; callers receive the backbone output)
    feature_extractor = model.backbone
    feature_extractor.to(device)
    feature_extractor.eval()

    print(f"Loaded VICReg backbone for feature extraction")

    return feature_extractor, champion_info


In [None]:
# Dataset class and data loading / feature extraction
class FeatureExtractionDataset(Dataset):
    """Dataset yielding (image, info, success) triples for feature extraction.

    Each entry of `file_info_list` is a dict that must provide a 'filepath'
    pointing at a torch-saved tensor, or at a dict holding the tensor under
    'image'. Files that cannot be read are replaced by an all-zero 3x224x224
    image and reported with success=False instead of raising.
    """

    def __init__(self, file_info_list, transform=None):
        self.file_info = file_info_list
        self.transform = transform

    def __len__(self):
        return len(self.file_info)

    @staticmethod
    def _read_image(info):
        """Return (image, success) for one file-info record; never raises."""
        try:
            payload = torch.load(info['filepath'], map_location='cpu')
            image = payload['image'] if isinstance(payload, dict) else payload

            # Force exactly three channels: zero-pad when there are fewer,
            # keep the first three when there are more.
            n_channels = image.shape[0]
            if n_channels < 3:
                filler = torch.zeros(3 - n_channels, *image.shape[1:])
                image = torch.cat([image, filler], dim=0)
            elif n_channels > 3:
                image = image[:3]

            # Values above 1 are treated as 0-255 intensities and rescaled.
            if image.max() > 1:
                image = image.float() / 255.0
        except Exception:
            # Placeholder for corrupted/unreadable files.
            return torch.zeros(3, 224, 224), False

        return image, True

    def __getitem__(self, idx):
        info = self.file_info[idx]
        image, success = self._read_image(info)

        if not self.transform:
            return image, info, success

        # The transform pipeline is fed a PIL image built from the tensor.
        as_pil = transforms.ToPILImage()(image)
        return self.transform(as_pil), info, success

def load_all_file_info():
    """Collect one metadata record per usable '.pt' file under DATA_DIR.

    File names are split on '_' after dropping the '.pt' suffix; a file is kept
    only when it has at least eight fields and its third field is 'TP' or 'FP'.
    The seventh field is recorded as the SV type. Dataset folders that do not
    exist are reported and skipped.

    Returns:
        list of dicts with the keys 'filename', 'filepath', 'dataset',
        'label_str', 'svtype', 'binary_label' and 'multiclass_label'.
    """

    print("Scanning all genomic data files...")

    all_file_info = []

    for dataset_name in ('HG002_GRCh37', 'HG002_GRCh38', 'HG005_GRCh38'):
        dataset_path = os.path.join(DATA_DIR, dataset_name)

        if not os.path.exists(dataset_path):
            print(f"Missing dataset: {dataset_path}")
            continue

        print(f"Scanning {dataset_name}...")
        pt_files = [name for name in os.listdir(dataset_path) if name.endswith('.pt')]

        for filename in tqdm(pt_files, desc=f"Processing {dataset_name}"):
            fields = filename[:-3].split('_')
            if len(fields) < 8:
                continue

            label_str = fields[2]  # TP or FP
            svtype = fields[6]     # INS, DEL, etc.
            if label_str not in ('TP', 'FP'):
                continue

            # Multiclass scheme: 0 = FP, 1 = TP deletion, 2 = any other TP.
            if label_str == 'FP':
                multiclass_label = 0
            elif svtype == 'DEL':
                multiclass_label = 1
            else:
                multiclass_label = 2

            all_file_info.append({
                'filename': filename,
                'filepath': os.path.join(dataset_path, filename),
                'dataset': dataset_name,
                'label_str': label_str,
                'svtype': svtype,
                'binary_label': int(label_str == 'TP'),
                'multiclass_label': multiclass_label
            })

    print(f"Found {len(all_file_info)} total files")

    def _tally(field):
        """Count records per value of `field`, in first-seen order."""
        counts = {}
        for record in all_file_info:
            value = record[field]
            counts[value] = counts.get(value, 0) + 1
        return counts

    # Show breakdown
    print("Dataset breakdown:")
    for dataset, count in _tally('dataset').items():
        print(f"   {dataset}: {count:,} files")

    print("Label breakdown:")
    for label, count in _tally('label_str').items():
        print(f"   {label}: {count:,} files")

    return all_file_info


def extract_and_save_vicreg_features():
    """Extract VICReg backbone features for every data sample and save them.

    Pipeline: load the champion backbone, scan the input files, run all images
    through the backbone in batches, then write three artefacts to SAVE_DIR:
    'vicreg_features.h5' (feature matrix plus provenance attributes),
    'vicreg_metadata.csv' (one row per sample, aligned with the feature rows
    through 'sample_idx') and 'vicreg_extraction_summary.json'.

    Unreadable input files still occupy a row (features of an all-zero image);
    they are marked with extraction_success=False in the metadata.

    Returns:
        (features_array, metadata_df, summary) on success, or
        (None, None, None) when no champion model or no input file is found.

    Raises:
        ValueError: if the backbone output is not (batch, 4096).
    """

    print("Starting comprehensive VICReg feature extraction...")

    # Load champion model
    feature_extractor, champion_info = load_champion_vicreg()
    if feature_extractor is None:
        print("Failed to load champion model!")
        return None, None, None

    # Load all file info
    all_file_info = load_all_file_info()
    if not all_file_info:
        # Without this guard an empty run would write empty outputs and then
        # crash on metadata_df['dataset'] (an empty DataFrame has no columns).
        print("No input files found - nothing to extract!")
        return None, None, None

    # Create transform (same as used in training)
    # NOTE(review): the conventional ImageNet mean is [0.485, 0.456, 0.406];
    # 0.458 is kept because it may mirror the training transform — confirm.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.458, 0.406], [0.229, 0.224, 0.225])
    ])

    # Create dataset
    dataset = FeatureExtractionDataset(all_file_info, transform)

    # Custom collate function to handle metadata
    def collate_fn(batch):
        images = torch.stack([item[0] for item in batch])
        infos = [item[1] for item in batch]  # Keep as list of dicts
        successes = torch.tensor([item[2] for item in batch])
        return images, infos, successes

    # Create dataloader
    # shuffle=False keeps feature rows in the same order as all_file_info;
    # pinned memory only helps (and only avoids a warning) when CUDA is used.
    dataloader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        pin_memory=(device.type == 'cuda'),
        collate_fn=collate_fn
    )

    # VICReg ResNet50x2 outputs 4096-dimensional features
    feature_dim = 4096

    # Prepare storage
    total_samples = len(all_file_info)
    features_array = np.zeros((total_samples, feature_dim), dtype=np.float32)
    metadata_list = []

    print(f"Extracting features for {total_samples:,} samples...")
    print(f"Feature dimension: {feature_dim}")

    sample_idx = 0
    failed_count = 0

    with torch.no_grad():
        for batch_images, batch_info, batch_success in tqdm(dataloader, desc="Extracting features"):
            batch_images = batch_images.to(device)

            # Extract features
            batch_features = feature_extractor(batch_images)
            batch_features_np = batch_features.cpu().numpy()

            # Fail with a readable message instead of a numpy broadcast error.
            if batch_features_np.ndim != 2 or batch_features_np.shape[1] != feature_dim:
                raise ValueError(
                    f"Unexpected backbone output shape {batch_features_np.shape}; "
                    f"expected (batch, {feature_dim})"
                )

            # Handle variable batch sizes (the last batch may be smaller)
            actual_batch_size = batch_features_np.shape[0]
            features_array[sample_idx:sample_idx + actual_batch_size] = batch_features_np

            # Process metadata
            for offset, (info, success) in enumerate(zip(batch_info, batch_success.tolist())):
                metadata_list.append({
                    'sample_idx': sample_idx + offset,
                    'filename': info['filename'],
                    'dataset': info['dataset'],
                    'label_str': info['label_str'],
                    'svtype': info['svtype'],
                    'binary_label': info['binary_label'],
                    'multiclass_label': info['multiclass_label'],
                    'extraction_success': success,
                    'filepath': info['filepath']
                })

                if not success:
                    failed_count += 1

            sample_idx += actual_batch_size

    print(f"Feature extraction complete!")
    print(f"   Successfully processed: {total_samples - failed_count:,}/{total_samples:,} samples")
    print(f"   Failed files: {failed_count:,}")

    # One timestamp for the whole run so the HDF5 attrs and the JSON agree.
    extraction_date = datetime.now().isoformat()

    # Save features to HDF5
    features_file = os.path.join(SAVE_DIR, 'vicreg_features.h5')
    print(f"Saving features to: {features_file}")

    with h5py.File(features_file, 'w') as f:
        f.create_dataset('features', data=features_array, compression='gzip')
        f.attrs['champion_model'] = champion_info['name']
        f.attrs['champion_path'] = champion_info['path']
        f.attrs['champion_accuracy'] = champion_info['accuracy']
        f.attrs['champion_auc'] = champion_info['auc']
        f.attrs['feature_dim'] = feature_dim
        f.attrs['total_samples'] = total_samples
        f.attrs['extraction_date'] = extraction_date

    # Save metadata to CSV
    metadata_file = os.path.join(SAVE_DIR, 'vicreg_metadata.csv')
    print(f"Saving metadata to: {metadata_file}")

    metadata_df = pd.DataFrame(metadata_list)
    metadata_df.to_csv(metadata_file, index=False)

    # Create summary (plain ints so the structure is JSON-serialisable)
    summary = {
        'extraction_date': extraction_date,
        'champion_model': {
            'name': champion_info['name'],
            'accuracy': champion_info['accuracy'],
            'auc': champion_info['auc'],
            'path': champion_info['path']
        },
        'data_summary': {
            'total_samples': int(total_samples),
            'successful_extractions': int(total_samples - failed_count),
            'failed_extractions': int(failed_count),
            'feature_dimension': int(feature_dim)
        },
        'files': {
            'features_file': features_file,
            'metadata_file': metadata_file,
            'features_size_mb': round(features_array.nbytes / 1024**2, 1)
        },
        'data_breakdown': {
            'datasets': {k: int(v) for k, v in metadata_df['dataset'].value_counts().items()},
            'labels': {k: int(v) for k, v in metadata_df['label_str'].value_counts().items()},
            'sv_types': {k: int(v) for k, v in metadata_df['svtype'].value_counts().items()}
        }
    }

    summary_file = os.path.join(SAVE_DIR, 'vicreg_extraction_summary.json')
    print(f"Saving summary to: {summary_file}")

    with open(summary_file, 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"\nVICREG FEATURE EXTRACTION COMPLETE!")
    print(f"Files saved in: {SAVE_DIR}")
    print(f"   Features: vicreg_features.h5 ({summary['files']['features_size_mb']} MB)")
    print(f"   Metadata: vicreg_metadata.csv")
    print(f"   Summary: vicreg_extraction_summary.json")

    return features_array, metadata_df, summary

In [None]:
# Utility functions for loading saved latents

def load_saved_vicreg_features():
    """Load previously saved VICReg features from SAVE_DIR.

    Reads 'vicreg_features.h5', 'vicreg_metadata.csv' and
    'vicreg_extraction_summary.json'. All three must exist; this is checked
    up front so a partially written output directory is reported before the
    (potentially large) feature matrix is read.

    Returns:
        (features, metadata_df, summary), or (None, None, None) when any of the
        three files is missing.
    """

    features_file = os.path.join(SAVE_DIR, 'vicreg_features.h5')
    metadata_file = os.path.join(SAVE_DIR, 'vicreg_metadata.csv')
    summary_file = os.path.join(SAVE_DIR, 'vicreg_extraction_summary.json')

    required_files = (
        ('Features', features_file),
        ('Metadata', metadata_file),
        ('Summary', summary_file),
    )
    for label, path in required_files:
        if not os.path.exists(path):
            print(f"{label} file not found: {path}")
            return None, None, None

    print(f"Loading saved VICReg features...")

    # Load features
    with h5py.File(features_file, 'r') as f:
        features = f['features'][:]
        attrs = dict(f.attrs)

    # Load metadata
    metadata_df = pd.read_csv(metadata_file)

    # Load summary
    with open(summary_file, 'r') as f:
        summary = json.load(f)

    print(f"Loaded {len(features):,} VICReg feature vectors")
    print(f"   Feature dimension: {features.shape[1]}")

    # The numeric format spec must only be applied when the attribute exists;
    # formatting a placeholder string with ':.2f' raises ValueError.
    champion_accuracy = attrs.get('champion_accuracy')
    if champion_accuracy is None:
        print("   Champion accuracy: unknown")
    else:
        print(f"   Champion accuracy: {champion_accuracy:.2f}%")

    return features, metadata_df, summary

def get_features_by_split(features, metadata_df, split_type='holdout_HG005_GRCh38'):
    """Get train/test features for a specific data split.

    Args:
        features: array whose rows are aligned positionally with the rows of
            `metadata_df`.
        metadata_df: DataFrame with at least the columns 'dataset' and
            'binary_label' (and 'label_str' for the '80_20' split). Its index
            may be arbitrary; rows are always matched by position.
        split_type: '80_20' for a stratified random 80/20 split (seed 42), or
            'holdout_<dataset>' to test on one dataset and train on the rest.

    Returns:
        (X_train, X_test, y_train, y_test).
    """

    print(f"Creating {split_type} split...")

    if split_type == '80_20':
        # Random 80/20 split
        from sklearn.model_selection import train_test_split
        train_indices, test_indices = train_test_split(
            range(len(metadata_df)),
            test_size=0.2,
            stratify=metadata_df['label_str'],
            random_state=42
        )
    else:
        # Leave-one-genome-out split
        test_genome = split_type.replace('holdout_', '')
        # Work with row positions rather than index labels: `features` is
        # indexed by position, so labels from a filtered or re-indexed
        # DataFrame would select the wrong rows (or raise).
        is_test = (metadata_df['dataset'] == test_genome).to_numpy()
        train_indices = np.flatnonzero(~is_test)
        test_indices = np.flatnonzero(is_test)
        if test_indices.size == 0:
            print(f"   Warning: no samples found for dataset '{test_genome}'")

    labels = metadata_df['binary_label'].to_numpy()

    X_train = features[train_indices]
    X_test = features[test_indices]
    y_train = labels[train_indices]
    y_test = labels[test_indices]

    print(f"   Train: {len(X_train)} samples")
    print(f"   Test: {len(X_test)} samples")
    print(f"   Train class balance: {np.bincount(y_train)}")
    print(f"   Test class balance: {np.bincount(y_test)}")

    return X_train, X_test, y_train, y_test

In [None]:
# Analysis functions

def analyze_vicreg_learned_features():
    """Analyze what VICReg learned by training simple classifiers on its features.

    Loads the saved features, makes a stratified 80/20 split (seed 42),
    standardises the features using statistics of the training part only, and
    fits a logistic regression and a random forest on the binary TP/FP labels.
    Each model's accuracy is compared with the champion accuracy stored in the
    extraction summary (a percentage, converted to a fraction here).

    Returns:
        (results, features, metadata_df), where `results` maps a model name to
        a dict with 'accuracy', 'auc', 'y_true', 'y_pred', 'y_pred_proba' and
        'classification_report'. When no saved features are available the
        function returns (None, None, None), so the documented three-way
        unpacking works in both cases.
    """

    print("ANALYZING VICREG LEARNED FEATURES")
    print("="*50)

    # Load features
    features, metadata_df, summary = load_saved_vicreg_features()
    if features is None:
        print("No saved features found. Run feature extraction first.")
        return None, None, None

    # Get labels
    labels = metadata_df['binary_label'].values

    # Champion accuracy is stored as a percentage; sklearn reports fractions.
    champion_acc_pct = summary['champion_model']['accuracy']
    original_acc = champion_acc_pct / 100  # Convert to decimal

    print(f"Loaded VICReg features:")
    print(f"   Shape: {features.shape}")
    print(f"   Champion accuracy: {champion_acc_pct:.2f}%")

    print(f"Label distribution:")
    print(f"   TP: {labels.sum()} ({100*labels.mean():.1f}%)")
    print(f"   FP: {len(labels) - labels.sum()} ({100*(1-labels.mean()):.1f}%)")

    # Split data for analysis - 80/20 with stratification
    X_train, X_test, y_train, y_test = train_test_split(
        features, labels, test_size=0.2, random_state=42, stratify=labels
    )

    print(f"\nData split (80/20 stratified):")
    print(f"   Train: {len(X_train):,} samples ({y_train.sum():,} TP, {len(y_train) - y_train.sum():,} FP)")
    print(f"   Test:  {len(X_test):,} samples ({y_test.sum():,} TP, {len(y_test) - y_test.sum():,} FP)")

    # Scale features (fit on train only so no test statistics leak in)
    print(f"\nScaling VICReg features...")
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Test different models on VICReg features
    models = {
        'Logistic Regression': LogisticRegression(random_state=42, max_iter=1000),
        'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1),
    }

    results = {}

    print(f"\nTesting models on VICReg learned features:")
    print("-" * 50)

    for model_name, model in models.items():
        print(f"\nTraining {model_name}...")

        # Train on VICReg features
        model.fit(X_train_scaled, y_train)

        # Predictions
        y_pred = model.predict(X_test_scaled)
        y_pred_proba = model.predict_proba(X_test_scaled)[:, 1]

        # Metrics
        accuracy = accuracy_score(y_test, y_pred)
        auc = roc_auc_score(y_test, y_pred_proba)

        # Detailed classification report
        report = classification_report(y_test, y_pred, target_names=['FP', 'TP'], output_dict=True)

        results[model_name] = {
            'accuracy': accuracy,
            'auc': auc,
            'y_true': y_test,
            'y_pred': y_pred,
            'y_pred_proba': y_pred_proba,
            'classification_report': report
        }

        print(f"   Results:")
        print(f"      Accuracy: {accuracy:.3f} ({accuracy*100:.1f}%)")
        print(f"      AUC: {auc:.3f}")
        print(f"      FP Precision: {report['FP']['precision']:.3f}")
        print(f"      FP Recall: {report['FP']['recall']:.3f}")
        print(f"      TP Precision: {report['TP']['precision']:.3f}")
        print(f"      TP Recall: {report['TP']['recall']:.3f}")

        # Compare to original VICReg performance
        gap = original_acc - accuracy
        if accuracy >= original_acc - 0.01:  # Allow 1% tolerance
            print(f"      MATCHES VICReg! ({accuracy:.3f} ≈ {original_acc:.3f})")
        elif accuracy >= original_acc - 0.05:  # Within 5%
            print(f"      Close to VICReg: {gap:.3f} gap ({gap*100:.1f}% points)")
        else:
            print(f"      Gap to VICReg: {gap:.3f} ({gap*100:.1f}% points)")

    # Summary
    print(f"\nVICREG LEARNED FEATURES ANALYSIS SUMMARY")
    print("=" * 60)
    print(f"{'Model':<20} | {'Accuracy':<10} | {'AUC':<10} | {'vs VICReg':<15}")
    print("-" * 60)

    for model_name, result in results.items():
        acc = result['accuracy']
        auc = result['auc']

        if acc >= original_acc - 0.01:
            comparison = "MATCHES!"
        elif acc >= original_acc - 0.05:
            comparison = "Close"
        else:
            comparison = f"-{(original_acc - acc)*100:.1f}%"

        print(f"{model_name:<20} | {acc:.3f}     | {auc:.3f}     | {comparison:<15}")

    # Key insight
    best_acc = max(result['accuracy'] for result in results.values())
    if best_acc >= 0.90:
        print(f"\nKEY INSIGHT: VICReg backbone learned EXCELLENT representations ({best_acc:.1%})")
        print(f"   Most of the {champion_acc_pct:.2f}% performance comes from learned features!")
    elif best_acc >= 0.80:
        print(f"\nKEY INSIGHT: VICReg backbone learned GOOD representations ({best_acc:.1%})")
        print(f"   Combination of learned features + final classifier = {champion_acc_pct:.2f}%")
    else:
        print(f"\nKEY INSIGHT: VICReg performance mainly from final classifier")
        print(f"   Learned features contribute {best_acc:.1%}, classifier adds the rest")

    return results, features, metadata_df

def analyze_saved_features():
    """Print an overview of the feature files saved in SAVE_DIR.

    For every '*_features.h5' file found, the shared
    'vicreg_extraction_summary.json' is read and its key figures are printed;
    a short notice is printed instead when that summary file is absent.
    Returns None.
    """

    print("ANALYZING SAVED VICREG FEATURES")
    print("="*50)

    if not os.path.exists(SAVE_DIR):
        print("Features directory not found!")
        return

    feature_files = [name for name in os.listdir(SAVE_DIR) if name.endswith('_features.h5')]

    if len(feature_files) == 0:
        print("No saved feature files found!")
        return

    print(f"Found {len(feature_files)} VICReg feature file(s):")

    # The summary location does not depend on the individual feature file.
    summary_path = os.path.join(SAVE_DIR, 'vicreg_extraction_summary.json')

    for _ in feature_files:
        if not os.path.exists(summary_path):
            print(f"\nVICREG: Summary file missing")
            continue

        # Load summary
        with open(summary_path, 'r') as handle:
            summary = json.load(handle)

        print(f"\nVICREG FEATURES:")
        print(f"   Champion: {summary['champion_model']['name']}")
        print(f"   Accuracy: {summary['champion_model']['accuracy']:.2f}%")
        print(f"   Samples: {summary['data_summary']['total_samples']:,}")
        print(f"   Feature dim: {summary['data_summary']['feature_dimension']}")
        print(f"   File size: {summary['files']['features_size_mb']} MB")
        print(f"   Extraction date: {summary['extraction_date'][:10]}")

In [None]:
# Usage

# Emit the whole usage banner with one call; empty entries are the blank
# separator lines between sections.
print("\n".join((
    "MAIN FUNCTIONS:",
    "   features, metadata_df, summary = extract_and_save_vicreg_features()",
    "   features, metadata_df, summary = load_saved_vicreg_features()",
    "   results, features, metadata_df = analyze_vicreg_learned_features()",
    "",
    "DATA SPLITS:",
    "   X_train, X_test, y_train, y_test = get_features_by_split(features, metadata_df, 'holdout_HG005_GRCh38')",
    "",
    "ANALYZE FEATURES:",
    "   analyze_saved_features()",
    "",
    "WORKFLOW:",
    "   1. Automatically finds champion VICReg model",
    "   2. Extracts 4096-dim features for all ~50K genomic samples",
    "   3. Saves features to HDF5 (efficient storage)",
    "   4. Analyzes what VICReg learned through ML experiments",
    "   5. Compares latent features to original VICReg performance",
)))

# To run feature extraction:
# features, metadata_df, summary = extract_and_save_vicreg_features()

# To analyze learned features:
# results, features, metadata_df = analyze_vicreg_learned_features()

# To load saved features:
# features, metadata_df, summary = load_saved_vicreg_features()

# To analyze what's available:
# analyze_saved_features()

In [None]:
# Run the full pipeline: find the champion model, extract features for all
# samples and write the feature/metadata/summary files to SAVE_DIR.
# All three results are None when no champion model could be loaded.
features, metadata_df, summary = extract_and_save_vicreg_features()