Genomic structural variant (SV) classification with a self-supervised VICReg-pretrained ResNet50x2: the backbone is frozen and only a dropout + linear classification head is trained.

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, roc_auc_score, classification_report
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import json
from datetime import datetime
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')


# Configuration - paths follow the GitHub repository layout and are resolved
# relative to the notebook's working directory.
DATA_DIR = '../data/processed/all_datasets_images_rgb'
SAVE_DIR = '../data/processed/vicreg_experiments'
MODELS_DIR = '../data/processed/vicreg_experiments/models'
FIGURES_DIR = '../figures'

# Ensure every output directory exists up front (no-op if already present).
for _output_dir in (SAVE_DIR, MODELS_DIR, FIGURES_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# Use the first CUDA device when available, otherwise run on the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _has_cuda else 'cpu')
print(f"Using device: {device}")
if _has_cuda:
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")

# Shared training hyperparameters for every VICReg experiment.
VICREG_CONFIG = dict(
    learning_rate=1e-4,
    batch_size=32,
    weight_decay=1e-4,
    dropout_rate=0.2,
    epochs=20,
    patience=5,
)

In [None]:
# Dataset class

class GenomicDataset(Dataset):
    """Map-style dataset over saved genomic structural-variant image tensors.

    Each entry of ``data_list`` is a dict with at least ``'filepath'`` (a file
    readable by ``torch.load``) and ``'label'`` (an integer class id). The file
    may hold either a tensor or a dict whose ``'image'`` entry is the tensor.

    Args:
        data_list: Sequence of sample dicts as described above.
        transform: Optional torchvision transform, applied to a PIL image
            built from the tensor. When ``None`` the 3-channel float tensor
            in ``[0, 1]`` is returned directly.
    """

    def __init__(self, data_list, transform=None):
        self.data = data_list
        self.transform = transform

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _to_rgb_float(image):
        """Coerce an image tensor to exactly 3 leading channels of float in [0, 1]."""
        # A 2-D tensor is treated as a single-channel (H, W) image; without
        # this the channel logic below would slice the height axis instead.
        if image.dim() == 2:
            image = image.unsqueeze(0)

        channels = image.shape[0]
        if channels < 3:
            # Zero-pad the missing channels using the image's own dtype so the
            # concatenation does not rely on implicit type promotion.
            padding = torch.zeros(3 - channels, *image.shape[1:], dtype=image.dtype)
            image = torch.cat([image, padding], dim=0)
        elif channels > 3:
            image = image[:3]

        # Always return float so every sample (and every batch) has one dtype.
        image = image.float()
        # Values above 1 are taken to be on a 0-255 scale.
        if image.max() > 1:
            image = image / 255.0
        return image

    def __getitem__(self, idx):
        item = self.data[idx]

        try:
            data = torch.load(item['filepath'], map_location='cpu')
            image = data['image'] if isinstance(data, dict) else data
            image = self._to_rgb_float(image)
        except Exception as e:
            # Best-effort: an unreadable sample becomes a blank image instead of
            # aborting the whole epoch.
            print(f"Error loading {item['filepath']}: {e}")
            image = torch.zeros(3, 224, 224)

        if self.transform:
            image = self.transform(transforms.ToPILImage()(image))

        return image, torch.tensor(item['label'], dtype=torch.long)

In [None]:
# Data loading

def load_all_genomic_data():
    """Scan each known dataset directory under ``DATA_DIR`` and index its ``.pt`` samples.

    Filenames (minus ``.pt``) are split on ``_``; names with fewer than 8 fields
    are skipped. Field 2 is the truth label (only ``TP``/``FP`` are kept) and
    field 6 is the SV type (e.g. ``DEL``/``INS``). Missing dataset directories
    are reported and skipped.

    Returns:
        list[dict]: One record per sample with keys ``dataset``, ``filepath``,
        ``label_str``, ``svtype``, ``binary_label`` (1 = TP, 0 = FP) and
        ``multiclass_label`` (0 = FP, 1 = TP with SV type DEL, 2 = any other
        TP SV type). The list is empty when nothing matched.
    """
    from collections import Counter

    print("Loading all genomic data...")

    all_data = []
    datasets = ['HG002_GRCh37', 'HG002_GRCh38', 'HG005_GRCh38']

    for dataset_name in datasets:
        dataset_path = os.path.join(DATA_DIR, dataset_name)

        if not os.path.exists(dataset_path):
            print(f"Missing dataset: {dataset_path}")
            continue

        print(f"Loading {dataset_name}...")

        # Sorted so record order -- and therefore the seeded random split built
        # from it downstream -- does not depend on filesystem listing order.
        filenames = sorted(f for f in os.listdir(dataset_path) if f.endswith('.pt'))

        label_counts = Counter()
        for filename in tqdm(filenames, desc=f"Processing {dataset_name}"):
            parts = filename[:-3].split('_')  # drop the '.pt' suffix
            if len(parts) < 8:
                continue

            label = parts[2]   # TP or FP
            svtype = parts[6]  # INS, DEL, etc.
            if label not in ('TP', 'FP'):
                continue

            all_data.append({
                'dataset': dataset_name,
                'filepath': os.path.join(dataset_path, filename),
                'label_str': label,
                'svtype': svtype,
                'binary_label': 1 if label == 'TP' else 0,
                'multiclass_label': 0 if label == 'FP' else (1 if svtype == 'DEL' else 2),
            })
            label_counts[label] += 1

        tp_count, fp_count = label_counts['TP'], label_counts['FP']
        print(f"   {tp_count} TP, {fp_count} FP = {tp_count + fp_count} total")

    print("\nTotal dataset:")
    print(f"   Samples: {len(all_data)}")

    if not all_data:
        # Nothing matched; bail out before the percentage maths divides by zero.
        print("   No samples found - check DATA_DIR and the filename format.")
        return all_data

    tp_total = sum(1 for x in all_data if x['label_str'] == 'TP')
    fp_total = sum(1 for x in all_data if x['label_str'] == 'FP')

    print(f"   TP: {tp_total} ({100*tp_total/len(all_data):.1f}%)")
    print(f"   FP: {fp_total} ({100*fp_total/len(all_data):.1f}%)")

    # SV-type distribution among true positives.
    svtype_counts = Counter(x['svtype'] for x in all_data if x['label_str'] == 'TP')
    print(f"   SV types: {dict(svtype_counts)}")

    return all_data

def create_data_splits(all_data, genomes=None):
    """Build a stratified random 80/20 split plus leave-one-genome-out splits.

    Args:
        all_data: Sample records (needs ``label_str`` and ``dataset`` keys),
            e.g. from ``load_all_genomic_data``.
        genomes: Dataset names to hold out one at a time. Defaults to every
            dataset actually present in ``all_data`` (sorted), so a genome whose
            directory was missing does not produce a split with an empty test
            set. Explicitly listed genomes with no samples are skipped.

    Returns:
        dict: ``'80_20'`` and ``'holdout_<genome>'`` keys, each mapping to
        ``{'train': [...], 'test': [...]}``.

    Raises:
        ValueError: If ``all_data`` is empty.
    """
    if not all_data:
        raise ValueError("create_data_splits: all_data is empty; nothing to split")

    splits = {}

    # 1. Random 80/20 split, stratified on the TP/FP label.
    print("\nCreating 80/20 split...")
    train_80, test_20 = train_test_split(
        all_data,
        test_size=0.2,
        stratify=[x['label_str'] for x in all_data],
        random_state=42
    )
    splits['80_20'] = {'train': train_80, 'test': test_20}
    print(f"   Train: {len(train_80)}, Test: {len(test_20)}")

    # 2. Leave-one-genome-out splits.
    print("\nCreating leave-one-genome-out splits...")
    if genomes is None:
        genomes = sorted({x['dataset'] for x in all_data})

    for test_genome in genomes:
        train_data = [x for x in all_data if x['dataset'] != test_genome]
        test_data = [x for x in all_data if x['dataset'] == test_genome]

        if not test_data:
            # An empty test set would make evaluation divide by zero later.
            print(f"   {test_genome}: no samples, skipping holdout split")
            continue

        splits[f'holdout_{test_genome}'] = {'train': train_data, 'test': test_data}
        print(f"   {test_genome}: Train={len(train_data)}, Test={len(test_data)}")

    return splits

def create_transforms():
    """Return ``(train_transform, test_transform)`` with ImageNet normalization.

    Both pipelines resize to 224x224, convert to a tensor and normalize with
    the standard ImageNet per-channel mean/std. Only the training pipeline adds
    a random horizontal flip (p=0.5).
    """
    # Standard ImageNet RGB statistics (green mean is 0.456, not 0.458).
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(imagenet_mean, imagenet_std)

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        normalize,
    ])

    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    return train_transform, test_transform


In [None]:
# VICReg model class

class VICRegResNet(nn.Module):
    """Frozen VICReg ResNet50x2 backbone with a trainable dropout + linear head.

    The backbone is loaded from ``torch.hub`` (``facebookresearch/vicreg``). If
    that fails, torchvision's ImageNet Wide ResNet50-2 is used instead. Backbone
    parameters are frozen and the backbone is kept in eval mode, so only the
    classifier head is trained.

    Args:
        num_classes: Number of output logits.
        dropout: Dropout probability applied to the backbone features.
    """

    def __init__(self, num_classes, dropout=0.2):
        super().__init__()

        self.num_classes = num_classes

        print("Loading Facebook's VICReg ResNet50x2...")

        try:
            # Load Facebook's pre-trained VICReg model
            self.backbone = torch.hub.load('facebookresearch/vicreg:main', 'resnet50x2', pretrained=True)
            print("   Successfully loaded VICReg ResNet50x2")

        except Exception as e:
            print(f"   Failed to load VICReg model: {e}")
            print("   Falling back to standard Wide ResNet50-2...")
            self.backbone = torchvision.models.wide_resnet50_2(weights='IMAGENET1K_V1')

        # Remove original classifier so the backbone returns pooled features.
        if hasattr(self.backbone, 'fc'):
            self.backbone.fc = nn.Identity()
        elif hasattr(self.backbone, 'head'):
            self.backbone.head = nn.Identity()

        # Measure the feature width rather than hard-coding 4096. The Wide
        # ResNet50-2 fallback produces 2048 features, so a fixed 4096-wide
        # Linear layer would fail at forward time.
        self.feature_dim = self._infer_feature_dim(self.backbone, default=4096)

        # Freeze backbone
        print("Freezing VICReg backbone...")
        for param in self.backbone.parameters():
            param.requires_grad = False
        # Eval mode keeps BatchNorm running statistics fixed (see train()).
        self.backbone.eval()

        # Trainable classifier
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.feature_dim, num_classes)
        )

        # Count parameters
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        print(f"   VICReg Model: {total_params:,} total, {trainable_params:,} trainable")

    @staticmethod
    def _infer_feature_dim(backbone, default=4096, image_size=224):
        """Return the flattened output width of ``backbone`` for one RGB image.

        Runs a single all-zero image through the backbone in eval mode, so
        BatchNorm statistics are untouched. The backbone's previous train/eval
        mode is restored afterwards. Returns ``default`` if the probe raises.
        """
        was_training = backbone.training
        backbone.eval()
        try:
            with torch.no_grad():
                out = backbone(torch.zeros(1, 3, image_size, image_size))
            return int(out.reshape(out.shape[0], -1).shape[1])
        except Exception:
            return default
        finally:
            backbone.train(was_training)

    def train(self, mode=True):
        """Set train/eval mode on the head but always keep the backbone in eval mode.

        The backbone is frozen, but ``nn.Module.train()`` would otherwise switch
        its BatchNorm layers to training mode. In that mode they normalize with
        per-batch statistics and overwrite the pretrained running mean/var, even
        under ``torch.no_grad()``.
        """
        super().train(mode)
        backbone = getattr(self, 'backbone', None)
        if backbone is not None:
            backbone.eval()
        return self

    def forward(self, x):
        # No autograd graph through the frozen backbone.
        with torch.no_grad():
            features = self.backbone(x)
        return self.classifier(features)

In [None]:
# Training function

def _snapshot_state_dict(model):
    """Return a detached CPU copy of ``model.state_dict()``.

    ``state_dict()`` returns references to the live parameter/buffer tensors,
    so a shallow ``dict.copy()`` would keep following later optimizer updates.
    Cloning freezes the values as they are right now.
    """
    return {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}


def _train_one_epoch(model, loader, optimizer, epoch):
    """Run one optimization pass over ``loader``; return ``(accuracy_pct, mean_batch_loss)``."""
    model.train()
    # The backbone is frozen: keep it in eval mode so its BatchNorm running
    # statistics are not updated by training batches.
    backbone = getattr(model, 'backbone', None)
    if backbone is not None:
        backbone.eval()

    correct, total, loss_sum = 0, 0, 0.0
    for images, labels in tqdm(loader, desc=f"Epoch {epoch+1}"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    accuracy = 100. * correct / total if total else 0.0
    mean_loss = loss_sum / len(loader) if len(loader) else 0.0
    return accuracy, mean_loss


def _evaluate(model, loader, num_classes):
    """Score ``model`` on ``loader``; return ``(accuracy_pct, auc)``.

    AUC uses the softmax probability of class 1 and is computed only for binary
    tasks. It is 0 for multi-class tasks or when the targets contain a single
    class.
    """
    model.eval()
    correct, total = 0, 0
    probs, targets = [], []

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            if num_classes == 2:
                probs.extend(F.softmax(outputs, dim=1)[:, 1].cpu().numpy())
                targets.extend(labels.cpu().numpy())

    accuracy = 100. * correct / total if total else 0.0
    auc = roc_auc_score(targets, probs) if num_classes == 2 and len(set(targets)) > 1 else 0
    return accuracy, auc


def _save_vicreg_checkpoint(model_save_dir, experiment_name, num_classes, best_model_state,
                            best_test_acc, best_test_auc, final_test_auc, history, feature_dim):
    """Write ``best_vicreg_model.pth`` plus a JSON summary; return the checkpoint path."""
    checkpoint = {
        'model_state_dict': best_model_state,
        'model_type': 'VICReg',
        'num_classes': num_classes,
        'experiment_name': experiment_name,
        'best_test_acc': best_test_acc,
        'best_test_auc': best_test_auc,
        'final_test_auc': final_test_auc,
        'config': VICREG_CONFIG,
        'history': history,
        'model_config': {
            'feature_dim': feature_dim,
            'dropout_rate': VICREG_CONFIG['dropout_rate']
        }
    }

    checkpoint_path = os.path.join(model_save_dir, 'best_vicreg_model.pth')
    torch.save(checkpoint, checkpoint_path)
    print(f"   Model saved to: {checkpoint_path}")

    info_path = os.path.join(model_save_dir, 'vicreg_model_info.json')
    model_info = {
        'experiment_name': experiment_name,
        'model_type': 'VICReg',
        'num_classes': num_classes,
        'best_test_acc': best_test_acc,
        'best_test_auc': best_test_auc,
        'final_test_auc': final_test_auc,
        'checkpoint_path': checkpoint_path,
        'saved_at': datetime.now().isoformat()
    }
    with open(info_path, 'w') as f:
        json.dump(model_info, f, indent=2)

    print(f"   Model info saved to: {info_path}")
    return checkpoint_path


def train_vicreg_model(num_classes, train_data, test_data, experiment_name):
    """Train the VICReg linear-probe classifier on one split and save its best checkpoint.

    Args:
        num_classes: 2 selects each record's ``binary_label``; any other value
            selects ``multiclass_label``.
        train_data: Training sample records (from ``load_all_genomic_data``).
        test_data: Evaluation sample records.
        experiment_name: Sub-directory of ``MODELS_DIR`` that receives the outputs.

    Returns:
        dict: Summary containing ``best_test_acc``, ``best_test_auc`` (AUC at
        the best epoch), ``final_test_auc`` (AUC at the last epoch run), the
        per-epoch ``history``, ``model_path`` (``None`` if nothing was saved)
        and ``model_save_dir``.

    Raises:
        ValueError: If ``train_data`` or ``test_data`` is empty.

    NOTE(review): the best epoch and early stopping are both selected on
    ``test_data``, so ``best_test_acc`` is optimistically biased. A separate
    validation split would give an unbiased test estimate.
    """
    print(f"\nTraining VICReg ({num_classes}-class) - {experiment_name}")

    if not train_data or not test_data:
        raise ValueError(
            f"{experiment_name}: empty split (train={len(train_data)}, test={len(test_data)})"
        )

    # Create model save directory
    model_save_dir = os.path.join(MODELS_DIR, experiment_name)
    os.makedirs(model_save_dir, exist_ok=True)
    print(f"Model will be saved to: {model_save_dir}")

    # Transforms, label selection and datasets
    train_transform, test_transform = create_transforms()
    label_key = 'binary_label' if num_classes == 2 else 'multiclass_label'
    train_dataset = GenomicDataset([{**item, 'label': item[label_key]} for item in train_data], train_transform)
    test_dataset = GenomicDataset([{**item, 'label': item[label_key]} for item in test_data], test_transform)

    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=VICREG_CONFIG['batch_size'],
                              shuffle=True, num_workers=4, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=VICREG_CONFIG['batch_size'],
                             shuffle=False, num_workers=4, pin_memory=pin_memory)

    model = VICRegResNet(num_classes, VICREG_CONFIG['dropout_rate']).to(device)

    # Only the head is trainable; the frozen backbone never receives gradients.
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=VICREG_CONFIG['learning_rate'],
        weight_decay=VICREG_CONFIG['weight_decay']
    )

    # Halve the LR when test accuracy plateaus ('verbose' is deprecated, omitted).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3
    )

    best_test_acc = 0
    best_test_auc = 0
    test_auc = 0  # defined even if the epoch loop never runs
    best_model_state = None
    patience_counter = 0
    history = []

    for epoch in range(VICREG_CONFIG['epochs']):
        train_acc, train_loss = _train_one_epoch(model, train_loader, optimizer, epoch)
        test_acc, test_auc = _evaluate(model, test_loader, num_classes)

        scheduler.step(test_acc)
        current_lr = optimizer.param_groups[0]['lr']

        if best_model_state is None or test_acc > best_test_acc:
            best_test_acc = test_acc
            best_test_auc = test_auc
            best_model_state = _snapshot_state_dict(model)
            patience_counter = 0
            print(f"   New best: {test_acc:.2f}% - Model state saved!")
        else:
            patience_counter += 1

        history.append({
            'epoch': epoch + 1,
            'train_acc': train_acc,
            'test_acc': test_acc,
            'test_auc': test_auc,
            'train_loss': train_loss,
            'lr': current_lr
        })

        print(f"   Epoch {epoch+1}: Train={train_acc:.1f}%, Test={test_acc:.1f}%, AUC={test_auc:.3f}, LR={current_lr:.1e}")

        if patience_counter >= VICREG_CONFIG['patience']:
            print(f"   Early stopping at epoch {epoch+1}")
            break

    checkpoint_path = None
    if best_model_state is not None:
        checkpoint_path = _save_vicreg_checkpoint(
            model_save_dir, experiment_name, num_classes, best_model_state,
            best_test_acc, best_test_auc, test_auc, history,
            getattr(model, 'feature_dim', 4096),
        )

    print(f"   Best test accuracy: {best_test_acc:.2f}%")

    return {
        'model_type': 'VICReg',
        'num_classes': num_classes,
        'experiment': experiment_name,
        'best_test_acc': best_test_acc,
        'best_test_auc': best_test_auc,
        'final_test_auc': test_auc,
        'history': history,
        'model_path': checkpoint_path,
        'model_save_dir': model_save_dir
    }

In [None]:
# Model loading and analysis

def load_saved_vicreg_model(checkpoint_path):
    """Rebuild a ``VICRegResNet`` from a checkpoint written by ``train_vicreg_model``.

    Args:
        checkpoint_path: Path to a ``best_vicreg_model.pth`` file.

    Returns:
        tuple: ``(model, checkpoint)``. The model is on ``device`` and in eval
        mode, ready for inference. ``checkpoint`` is the raw loaded dict.
    """
    print(f"Loading VICReg model from: {checkpoint_path}")

    # weights_only=False unpickles arbitrary objects: only load checkpoints
    # from a trusted source (here, files this notebook wrote itself).
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    num_classes = checkpoint['num_classes']
    dropout_rate = checkpoint['model_config']['dropout_rate']

    model = VICRegResNet(num_classes, dropout_rate)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    # Disable dropout so predictions from the returned model are deterministic.
    model.eval()

    print(f"Loaded VICReg ({num_classes}-class)")
    print(f"   Best accuracy: {checkpoint['best_test_acc']:.2f}%")
    print(f"   Final AUC: {checkpoint['final_test_auc']:.3f}")

    return model, checkpoint


def analyze_saved_vicreg_models():
    """Analyze all saved VICReg models"""

    print("ANALYZING SAVED VICREG MODELS")
    print("="*50)

    if not os.path.exists(MODELS_DIR):
        print("No models directory found")
        return None

    results = []

    for model_dir in os.listdir(MODELS_DIR):
        model_path = os.path.join(MODELS_DIR, model_dir)
        checkpoint_path = os.path.join(model_path, 'best_vicreg_model.pth')

        if os.path.exists(checkpoint_path):
            try:
                checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
                results.append({
                    'Model': 'VICReg-ResNet50x2',
                    'Classes': f"{checkpoint['num_classes']}-class",
                    'Split': checkpoint['experiment_name'].split('_')[-1] if '_' in checkpoint['experiment_name'] else 'unknown',
                    'Accuracy': checkpoint['best_test_acc'],
                    'AUC': checkpoint.get('final_test_auc', 0),
                    'Full_Name': checkpoint['experiment_name']
                })
            except Exception as e:
                print(f"Error loading {model_dir}: {str(e)[:100]}...")

    if not results:
        print("No saved VICReg models found!")
        return None

    # Create DataFrame
    df = pd.DataFrame(results)

    print(f"Found {len(df)} saved VICReg models")

    # Performance table
    print(f"\nTOP VICREG MODELS BY ACCURACY:")
    top_models = df.nlargest(10, 'Accuracy')[['Model', 'Classes', 'Split', 'Accuracy', 'AUC']]
    print(top_models.to_string(index=False, float_format='%.2f'))

    # Best by category
    print(f"\nBEST BY CATEGORY:")

    # Best binary
    binary_models = df[df['Classes'] == '2-class']
    if len(binary_models) > 0:
        binary_best = binary_models.nlargest(1, 'Accuracy').iloc[0]
        print(f"   Binary: {binary_best['Model']} ({binary_best['Split']}) - {binary_best['Accuracy']:.2f}%")

    # Best 3-class
    multiclass_models = df[df['Classes'] == '3-class']
    if len(multiclass_models) > 0:
        multiclass_best = multiclass_models.nlargest(1, 'Accuracy').iloc[0]
        print(f"   3-class: {multiclass_best['Model']} ({multiclass_best['Split']}) - {multiclass_best['Accuracy']:.2f}%")

    # Overall champion
    overall_best = df.nlargest(1, 'Accuracy').iloc[0]
    print(f"\nOVERALL VICREG CHAMPION:")
    print(f"   {overall_best['Full_Name']}: {overall_best['Accuracy']:.2f}% (AUC: {overall_best['AUC']:.3f})")

    # CSV-Filter comparison
    csv_target = 94.94
    gap = csv_target - overall_best['Accuracy']
    print(f"\nCSV-FILTER COMPARISON:")
    print(f"   Target: {csv_target}%")
    print(f"   VICReg best: {overall_best['Accuracy']:.2f}%")
    print(f"   Gap: {gap:.2f}%")

    if gap <= 0:
        print(f"   VICREG BEATS CSV-FILTER!")
    elif gap <= 2:
        print(f"   Very close! Excellent self-supervised performance.")
    else:
        print(f"   Good self-supervised performance on realistic data!")

    return df


In [None]:
# Main experiments

def run_all_vicreg_experiments():
    """Train one VICReg model per (class setup, data split) pair and report results.

    Intermediate results are rewritten to ``SAVE_DIR/vicreg_results.csv``
    after every successful run. A failing run is reported and skipped.

    Returns:
        list[dict]: The summaries returned by ``train_vicreg_model`` for the
        runs that completed.
    """
    print("RUNNING ALL VICREG EXPERIMENTS")
    print("="*60)

    all_data = load_all_genomic_data()
    splits = create_data_splits(all_data)

    class_setups = [2, 3]  # Binary and 3-class
    split_names = list(splits.keys())

    total_experiments = len(class_setups) * len(split_names)
    print(f"\nPlanning {total_experiments} VICReg experiments:")
    print("   Model: VICReg ResNet50x2")
    print(f"   Class setups: {class_setups}")
    print(f"   Data splits: {split_names}")

    all_results = []

    for num_classes in class_setups:
        for split_name in split_names:
            train_data = splits[split_name]['train']
            test_data = splits[split_name]['test']

            experiment_name = f"vicreg_{num_classes}class_{split_name}"

            try:
                result = train_vicreg_model(num_classes, train_data, test_data, experiment_name)
                all_results.append(result)

                # Save intermediate results
                results_df = pd.DataFrame(all_results)
                results_df.to_csv(os.path.join(SAVE_DIR, 'vicreg_results.csv'), index=False)

            except Exception as e:
                print(f"Failed {experiment_name}: {e}")
            finally:
                # Release cached GPU memory after every run, including failed
                # ones (an OOM is exactly when freeing it matters most).
                torch.cuda.empty_cache()

    print("\nVICREG EXPERIMENT SUMMARY:")
    print(f"   Completed: {len(all_results)}/{total_experiments}")

    if all_results:
        results_df = pd.DataFrame(all_results)

        print("\nVICREG RESULTS:")
        for _, result in results_df.iterrows():
            print(f"   {result['experiment']}: {result['best_test_acc']:.2f}% (AUC: {result['final_test_auc']:.3f})")

        overall_best = results_df.loc[results_df['best_test_acc'].idxmax()]
        print(f"\nOVERALL BEST VICREG: {overall_best['best_test_acc']:.2f}%")
        print(f"   Experiment: {overall_best['experiment']}")
        print(f"   Saved at: {overall_best['model_path']}")

        # CSV-Filter comparison (target accuracy in percent)
        csv_filter_target = 94.94
        if overall_best['best_test_acc'] >= csv_filter_target:
            print(f"BEAT CSV-FILTER! (+{overall_best['best_test_acc'] - csv_filter_target:.2f}%)")
        else:
            print(f"Gap to CSV-Filter: {csv_filter_target - overall_best['best_test_acc']:.2f}%")

        print("\nSAVED VICREG MODELS:")
        for _, result in results_df.iterrows():
            if result['model_path']:
                print(f"   {result['experiment']}: {result['model_path']}")

    return all_results

def compare_vicreg_to_resnet():
    """Print the best saved VICReg run alongside notes about the ResNet baselines.

    ResNet baseline numbers are not loaded here; only the VICReg side is
    quantified.

    Returns:
        The DataFrame from ``analyze_saved_vicreg_models()``, or ``None`` when
        there are no saved VICReg results.
    """
    print("COMPARING VICREG TO RESNET BASELINES")
    print("=" * 50)

    summary = analyze_saved_vicreg_models()
    if summary is None:
        print("No VICReg results to compare")
        return None

    top_run = summary.nlargest(1, 'Accuracy').iloc[0]

    report_lines = (
        "\nPERFORMANCE COMPARISON:",
        f"   Best VICReg: {top_run['Accuracy']:.2f}% ({top_run['Full_Name']})",
        "   Self-supervised pre-training vs supervised ImageNet pre-training",
        "   VICReg uses learned representations from self-supervised learning",
    )
    for line in report_lines:
        print(line)

    return summary

In [None]:
# Usage

# Print a short usage guide for the functions defined in this notebook.
print("\n".join((
    "MAIN FUNCTIONS:",
    "   results = run_all_vicreg_experiments()",
    "   analysis_df = analyze_saved_vicreg_models()",
    "   comparison = compare_vicreg_to_resnet()",
    "",
    "MODEL LOADING:",
    "   model, checkpoint = load_saved_vicreg_model('/path/to/model.pth')",
    "",
    "EXPERIMENT DETAILS:",
    "   VICReg ResNet50x2 (Facebook self-supervised pre-trained)",
    "   2 class setups (binary TP/FP, 3-class FP/DEL/INS)",
    "   4 data splits (80/20 + 3 leave-one-genome-out)",
    "   = 8 total experiments",
    "   All models saved with full checkpoints",
)))

# Train every experiment:   results = run_all_vicreg_experiments()
# Summarize saved models:   analysis_df = analyze_saved_vicreg_models()
# Compare against ResNet:   comparison = compare_vicreg_to_resnet()

In [None]:
# Train and save one model for every (class setup, data split) pair; the
# returned list holds the per-run summaries from train_vicreg_model.
results = run_all_vicreg_experiments()