ResNet Feature Extraction & Saving

Extract latent features from the best-performing (champion) trained ResNet model for all genomic data samples, and save them to disk for downstream experiments.

In [None]:
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import h5py
import json
from tqdm import tqdm
from datetime import datetime
import warnings
# Silence every warning category so the notebook output stays readable.
warnings.filterwarnings('ignore')

# Paths are relative to the notebook's working directory (GitHub repo layout).
DATA_DIR = '../data/processed/all_datasets_images_rgb'      # one folder per dataset of .pt images
MODELS_DIR = '../data/processed/resnet_experiments/models'  # one folder per training experiment
SAVE_DIR = '../data/processed/resnet_latents'               # extracted features are written here

os.makedirs(SAVE_DIR, exist_ok=True)

# Prefer the first CUDA GPU when present; otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"   Memory: {gpu_total_bytes / 1e9:.1f}GB")

print(f"Save directory: {SAVE_DIR}")

In [None]:
# Load best model

def find_champion_model(models_dir=None):
    """Find the best-performing binary ResNet checkpoint.

    Scans every sub-directory of ``models_dir`` for a ``best_model.pth``
    checkpoint. It keeps only checkpoints whose ``num_classes`` is 2 and
    returns the one with the highest ``best_test_acc``. Sub-directories are
    visited in sorted order, so an accuracy tie is won by the alphabetically
    first experiment rather than by whatever order the filesystem returns.

    Args:
        models_dir: Directory holding one sub-directory per experiment.
            Defaults to the module-level ``MODELS_DIR``.

    Returns:
        ``(checkpoint_path, info)`` for the champion, or ``(None, None)`` when
        the directory is missing or contains no loadable binary checkpoint.
        ``info`` has the keys ``name``, ``architecture``, ``accuracy``,
        ``auc`` (0 when the checkpoint has no ``final_test_auc``) and ``path``.
    """
    if models_dir is None:
        models_dir = MODELS_DIR

    print("Searching for champion ResNet model...")

    if not os.path.exists(models_dir):
        print(f"Models directory not found: {models_dir}")
        print("Run ResNet training first!")
        return None, None

    # -inf (not 0) so a champion is still chosen when every accuracy is 0.
    best_acc = float('-inf')
    best_path = None
    best_info = None

    for model_dir in sorted(os.listdir(models_dir)):
        checkpoint_path = os.path.join(models_dir, model_dir, 'best_model.pth')
        if not os.path.isfile(checkpoint_path):
            continue

        try:
            # NOTE: weights_only=False unpickles arbitrary objects. Only point
            # this at checkpoints produced by this project's own training.
            checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

            # Only binary models are considered for feature extraction.
            if checkpoint['num_classes'] != 2:
                continue

            acc = checkpoint['best_test_acc']
            if acc > best_acc:
                best_acc = acc
                best_path = checkpoint_path
                best_info = {
                    'name': checkpoint['experiment_name'],
                    'architecture': checkpoint['architecture'],
                    'accuracy': acc,
                    'auc': checkpoint.get('final_test_auc', 0),
                    'path': checkpoint_path
                }
        except Exception as e:
            # Best-effort scan: one unreadable checkpoint must not hide the rest.
            print(f"Error loading {model_dir}: {str(e)[:50]}...")
            continue

    if best_path is None:
        print("No binary ResNet models found!")
        return None, None

    print("Champion model found:")
    print(f"   Name: {best_info['name']}")
    print(f"   Architecture: {best_info['architecture']}")
    print(f"   Accuracy: {best_info['accuracy']:.2f}%")
    print(f"   AUC: {best_info['auc']:.3f}")
    print(f"   Path: {best_path}")

    return best_path, best_info

def load_champion_resnet():
    """Load the champion ResNet backbone for feature extraction.

    Locates the champion checkpoint with ``find_champion_model()``. It then
    rebuilds the same ``FineTunedResNet`` wrapper used in training and loads
    the checkpoint's full state dict into it. Only the backbone is returned:
    its ``fc`` head is ``Identity``, it is moved to ``device``, and it is put
    in eval mode.

    Returns:
        ``(feature_extractor, champion_info)``, or ``(None, None)`` when no
        champion checkpoint is available.

    Raises:
        ValueError: if the checkpoint names an architecture not in ``arch_map``.
    """
    champion_path, champion_info = find_champion_model()
    if champion_path is None:
        return None, None

    print("Loading champion ResNet...")

    import torchvision

    # Architecture name (as stored in the checkpoint) -> torchvision builder.
    arch_map = {
        'resnet34': torchvision.models.resnet34,
        'resnet50': torchvision.models.resnet50,
        'resnet50x2': torchvision.models.wide_resnet50_2,
        'resnet101x2': torchvision.models.wide_resnet101_2
    }

    # Width of the pooled backbone output feeding the classifier.
    feature_dims = {
        'resnet34': 512,
        'resnet50': 2048,
        'resnet50x2': 2048,
        'resnet101x2': 2048
    }

    # NOTE: weights_only=False unpickles arbitrary objects. Only load
    # checkpoints produced by this project's own training runs.
    checkpoint = torch.load(champion_path, map_location='cpu', weights_only=False)
    architecture = checkpoint['architecture']
    if architecture not in arch_map:
        raise ValueError(
            f"Unsupported architecture {architecture!r}; expected one of {sorted(arch_map)}"
        )

    # Mirror of the training-time wrapper. Parameter names and shapes must
    # match the checkpoint's state dict for the strict load below to succeed.
    class FineTunedResNet(torch.nn.Module):
        def __init__(self, architecture, num_classes, dropout=0.2):
            super().__init__()

            # weights=None: every backbone tensor is overwritten by the strict
            # load_state_dict below. Fetching ImageNet weights here would only
            # cost a download, and would fail on machines without network access.
            self.backbone = arch_map[architecture](weights=None)
            self.backbone.fc = torch.nn.Identity()

            # Freeze backbone (as in training)
            for param in self.backbone.parameters():
                param.requires_grad = False

            self.classifier = torch.nn.Sequential(
                torch.nn.Dropout(dropout),
                torch.nn.Linear(feature_dims[architecture], num_classes)
            )

        def forward(self, x):
            with torch.no_grad():
                features = self.backbone(x)
            return self.classifier(features)

    # Dropout has no parameters, so a missing dropout_rate cannot affect loading.
    dropout = checkpoint.get('model_config', {}).get('dropout_rate', 0.2)
    model = FineTunedResNet(architecture, checkpoint['num_classes'], dropout)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Keep only the backbone: its output is the latent feature vector.
    feature_extractor = model.backbone
    feature_extractor.to(device)
    feature_extractor.eval()

    print(f"Loaded {architecture} backbone for feature extraction")

    return feature_extractor, champion_info

In [None]:
class FeatureExtractionDataset(Dataset):
    """Dataset for feature extraction that keeps each file's metadata alongside its image.

    Each item is ``(image, info, success)``:

    * ``info`` is the unchanged per-file dict from ``file_info_list``.
    * ``success`` is False when the file could not be loaded or reshaped. A
      blank 3x224x224 image is substituted so the batch can still be formed.

    Args:
        file_info_list: list of dicts. Only the ``'filepath'`` key is read here.
        transform: optional callable applied to a PIL image built from the
            tensor with ``transforms.ToPILImage``.
    """

    def __init__(self, file_info_list, transform=None):
        self.file_info = file_info_list
        self.transform = transform

    def __len__(self):
        return len(self.file_info)

    @staticmethod
    def _to_three_channels(image):
        """Coerce a (C, H, W) or (H, W) tensor to exactly three leading channels."""
        # A bare (H, W) tensor has no channel axis. Without this, the height
        # would be mistaken for a channel count and the image sliced by rows.
        if image.dim() == 2:
            image = image.unsqueeze(0)

        channels = image.shape[0]
        if channels < 3:
            # Zero-pad the missing channels.
            padding = torch.zeros(3 - channels, *image.shape[1:])
            image = torch.cat([image, padding], dim=0)
        elif channels > 3:
            image = image[:3]
        return image

    def __getitem__(self, idx):
        info = self.file_info[idx]

        try:
            # Files hold either a bare tensor or a dict with an 'image' tensor.
            data = torch.load(info['filepath'], map_location='cpu')
            image = data['image'] if isinstance(data, dict) else data
            image = self._to_three_channels(image)

            # Rescale 0-255 data to [0, 1]; data already <= 1 is left untouched.
            if image.max() > 1:
                image = image.float() / 255.0

            success = True

        except Exception:
            # Best-effort: an unreadable file yields a blank image plus
            # success=False so the caller can flag it instead of crashing.
            image = torch.zeros(3, 224, 224)
            success = False

        if self.transform:
            image = self.transform(transforms.ToPILImage()(image))

        return image, info, success

In [None]:
# Data loading and feature extraction functions

def load_all_file_info(datasets=None, data_dir=None):
    """Collect metadata for every labelled ``.pt`` image file.

    Each filename (minus ``.pt``) is split on ``'_'``. Field 2 must be ``TP``
    or ``FP``, and field 6 is taken as the SV type. Files with fewer than 8
    fields, or with any other label, are skipped. Files are visited in sorted
    order, so the returned order (and therefore the row order of any feature
    matrix built from it) is reproducible across runs and filesystems.

    Args:
        datasets: dataset sub-directory names to scan. Defaults to
            HG002_GRCh37, HG002_GRCh38 and HG005_GRCh38.
        data_dir: root directory containing the dataset folders. Defaults to
            ``DATA_DIR``.

    Returns:
        list of dicts with the keys ``filename``, ``filepath``, ``dataset``,
        ``label_str``, ``svtype``, ``binary_label`` (1 = TP, 0 = FP) and
        ``multiclass_label`` (0 = FP, 1 = TP with SV type DEL, 2 = any other TP).
    """
    from collections import Counter

    if datasets is None:
        datasets = ['HG002_GRCh37', 'HG002_GRCh38', 'HG005_GRCh38']
    if data_dir is None:
        data_dir = DATA_DIR

    print("Scanning all genomic data files...")

    all_file_info = []

    for dataset_name in datasets:
        dataset_path = os.path.join(data_dir, dataset_name)

        if not os.path.exists(dataset_path):
            print(f"Missing dataset: {dataset_path}")
            continue

        print(f"Scanning {dataset_name}...")
        filenames = sorted(f for f in os.listdir(dataset_path) if f.endswith('.pt'))

        for filename in tqdm(filenames, desc=f"Processing {dataset_name}"):
            parts = filename[:-3].split('_')
            if len(parts) < 8:
                continue

            label_str = parts[2]  # TP or FP
            svtype = parts[6]     # INS, DEL, etc.
            if label_str not in ('TP', 'FP'):
                continue

            all_file_info.append({
                'filename': filename,
                'filepath': os.path.join(dataset_path, filename),
                'dataset': dataset_name,
                'label_str': label_str,
                'svtype': svtype,
                'binary_label': 1 if label_str == 'TP' else 0,
                'multiclass_label': 0 if label_str == 'FP' else (1 if svtype == 'DEL' else 2)
            })

    print(f"Found {len(all_file_info)} total files")

    # Counter keeps first-seen order, matching the order datasets were scanned.
    datasets_count = Counter(info['dataset'] for info in all_file_info)
    labels_count = Counter(info['label_str'] for info in all_file_info)

    print("Dataset breakdown:")
    for dataset, count in datasets_count.items():
        print(f"   {dataset}: {count:,} files")

    print("Label breakdown:")
    for label, count in labels_count.items():
        print(f"   {label}: {count:,} files")

    return all_file_info


def _collate_with_metadata(batch):
    """Collate ``(image, info, success)`` items produced by ``FeatureExtractionDataset``.

    Images are stacked into one tensor. The per-file ``info`` dicts are kept
    as a plain list instead of being merged by the default collate. Success
    flags become a tensor. This lives at module level because a nested
    function cannot be pickled, and DataLoader workers started with the
    ``spawn`` method need to pickle it.
    """
    images = torch.stack([item[0] for item in batch])
    infos = [item[1] for item in batch]
    successes = torch.tensor([item[2] for item in batch])
    return images, infos, successes


def extract_and_save_resnet_features():
    """Extract ResNet backbone features for every data file and save them.

    Loads the champion binary ResNet backbone and runs every file returned by
    ``load_all_file_info()`` through it. Writes three files to ``SAVE_DIR``,
    each prefixed with the architecture name:

    * ``<arch>_features.h5``: ``features`` dataset of shape
      (N, feature_dim), float32, plus champion/extraction attributes.
    * ``<arch>_metadata.csv``: one row per feature row, in the same order.
    * ``<arch>_extraction_summary.json``: run summary and label breakdowns.

    Rows for files that failed to load hold the features of a blank image.
    They are marked ``extraction_success == False`` in the metadata.

    Returns:
        ``(features_array, metadata_df, summary)``, or ``(None, None, None)``
        when no champion model or no input files are found.

    Raises:
        RuntimeError: if the number of feature rows produced does not match
            the number of input files.
    """
    print("Starting comprehensive ResNet feature extraction...")

    feature_extractor, champion_info = load_champion_resnet()
    if feature_extractor is None:
        print("Failed to load champion model!")
        return None, None, None

    all_file_info = load_all_file_info()
    total_samples = len(all_file_info)
    if total_samples == 0:
        # Without this guard an empty HDF5 file would be written, and building
        # the summary would then crash: the empty metadata frame has no columns.
        print("No genomic data files found - nothing to extract!")
        return None, None, None

    # Preprocessing: resize, to tensor, ImageNet normalization.
    # NOTE(review): the green-channel mean was 0.458, a typo for the standard
    # ImageNet value 0.456 that the pretrained (frozen) backbone expects.
    # Confirm the training notebook uses the same constants.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    dataloader = DataLoader(
        FeatureExtractionDataset(all_file_info, transform),
        batch_size=32,
        shuffle=False,  # row i of the feature matrix must match all_file_info[i]
        num_workers=4,
        pin_memory=True,
        collate_fn=_collate_with_metadata
    )

    # Width of the backbone output for each supported architecture.
    feature_dims = {
        'resnet34': 512,
        'resnet50': 2048,
        'resnet50x2': 2048,
        'resnet101x2': 2048
    }
    architecture = champion_info['architecture']
    feature_dim = feature_dims[architecture]

    features_array = np.zeros((total_samples, feature_dim), dtype=np.float32)
    metadata_list = []

    print(f"Extracting features for {total_samples:,} samples...")
    print(f"Feature dimension: {feature_dim}")

    sample_idx = 0
    failed_count = 0

    with torch.no_grad():
        for batch_images, batch_info, batch_success in tqdm(dataloader, desc="Extracting features"):
            batch_features_np = feature_extractor(batch_images.to(device)).cpu().numpy()

            # The last batch may be smaller than batch_size.
            batch_size = batch_features_np.shape[0]
            features_array[sample_idx:sample_idx + batch_size] = batch_features_np

            for offset, (info, success) in enumerate(zip(batch_info, batch_success)):
                succeeded = success.item()
                metadata_list.append({
                    'sample_idx': sample_idx + offset,
                    'filename': info['filename'],
                    'dataset': info['dataset'],
                    'label_str': info['label_str'],
                    'svtype': info['svtype'],
                    'binary_label': info['binary_label'],
                    'multiclass_label': info['multiclass_label'],
                    'extraction_success': succeeded,
                    'filepath': info['filepath']
                })
                if not succeeded:
                    failed_count += 1

            sample_idx += batch_size

    if sample_idx != total_samples:
        raise RuntimeError(
            f"Feature rows ({sample_idx}) do not match input files ({total_samples})"
        )

    print("Feature extraction complete!")
    print(f"   Successfully processed: {total_samples - failed_count:,}/{total_samples:,} samples")
    print(f"   Failed files: {failed_count:,}")

    # One timestamp shared by the HDF5 attributes and the JSON summary.
    extraction_date = datetime.now().isoformat()
    # Plain floats: NumPy float32 / 0-d tensors are not JSON-serializable.
    accuracy = float(champion_info['accuracy'])
    auc = float(champion_info['auc'])

    features_file = os.path.join(SAVE_DIR, f"{architecture}_features.h5")
    print(f"Saving features to: {features_file}")

    with h5py.File(features_file, 'w') as f:
        f.create_dataset('features', data=features_array, compression='gzip')
        f.attrs['champion_model'] = champion_info['name']
        f.attrs['champion_path'] = champion_info['path']
        f.attrs['architecture'] = architecture
        f.attrs['champion_accuracy'] = accuracy
        f.attrs['champion_auc'] = auc
        f.attrs['feature_dim'] = feature_dim
        f.attrs['total_samples'] = total_samples
        f.attrs['extraction_date'] = extraction_date

    metadata_file = os.path.join(SAVE_DIR, f"{architecture}_metadata.csv")
    print(f"Saving metadata to: {metadata_file}")

    metadata_df = pd.DataFrame(metadata_list)
    metadata_df.to_csv(metadata_file, index=False)

    summary = {
        'extraction_date': extraction_date,
        'champion_model': {
            'name': champion_info['name'],
            'architecture': architecture,
            'accuracy': accuracy,
            'auc': auc,
            'path': champion_info['path']
        },
        'data_summary': {
            'total_samples': int(total_samples),
            'successful_extractions': int(total_samples - failed_count),
            'failed_extractions': int(failed_count),
            'feature_dimension': int(feature_dim)
        },
        'files': {
            'features_file': features_file,
            'metadata_file': metadata_file,
            'features_size_mb': round(features_array.nbytes / 1024**2, 1)
        },
        'data_breakdown': {
            'datasets': {k: int(v) for k, v in metadata_df['dataset'].value_counts().items()},
            'labels': {k: int(v) for k, v in metadata_df['label_str'].value_counts().items()},
            'sv_types': {k: int(v) for k, v in metadata_df['svtype'].value_counts().items()}
        }
    }

    summary_file = os.path.join(SAVE_DIR, f"{architecture}_extraction_summary.json")
    print(f"Saving summary to: {summary_file}")

    with open(summary_file, 'w') as f:
        json.dump(summary, f, indent=2)

    print("\nFEATURE EXTRACTION COMPLETE!")
    print(f"Files saved in: {SAVE_DIR}")
    print(f"   Features: {os.path.basename(features_file)} ({summary['files']['features_size_mb']} MB)")
    print(f"   Metadata: {os.path.basename(metadata_file)}")
    print(f"   Summary: {os.path.basename(summary_file)}")

    return features_array, metadata_df, summary

In [None]:
# Utility functions

def load_saved_resnet_features(architecture=None):
    """Load features previously written by ``extract_and_save_resnet_features``.

    Args:
        architecture: file prefix to load, e.g. ``'resnet50'``. When None, the
            alphabetically first ``*_features.h5`` file in ``SAVE_DIR`` is used.

    Returns:
        ``(features, metadata_df, summary)``: the ``features`` array from the
        HDF5 file, the metadata CSV as a DataFrame, and the summary JSON as a
        dict. Returns ``(None, None, None)`` when no feature file exists or
        any of the three files is missing.
    """
    suffix = '_features.h5'

    if architecture is None:
        # Sorted so the auto-detected file does not depend on listdir order.
        h5_files = sorted(f for f in os.listdir(SAVE_DIR) if f.endswith(suffix))
        if not h5_files:
            print("No saved feature files found!")
            return None, None, None

        features_file = os.path.join(SAVE_DIR, h5_files[0])
        architecture = h5_files[0][:-len(suffix)]
        print(f"Auto-detected architecture: {architecture}")
    else:
        features_file = os.path.join(SAVE_DIR, f"{architecture}{suffix}")

    metadata_file = os.path.join(SAVE_DIR, f"{architecture}_metadata.csv")
    summary_file = os.path.join(SAVE_DIR, f"{architecture}_extraction_summary.json")

    if not os.path.exists(features_file):
        print(f"Features file not found: {features_file}")
        return None, None, None
    if not os.path.exists(metadata_file):
        print(f"Metadata file not found: {metadata_file}")
        return None, None, None
    if not os.path.exists(summary_file):
        print(f"Summary file not found: {summary_file}")
        return None, None, None

    print("Loading saved features...")

    with h5py.File(features_file, 'r') as f:
        features = f['features'][:]
        attrs = dict(f.attrs)

    metadata_df = pd.read_csv(metadata_file)

    with open(summary_file, 'r') as f:
        summary = json.load(f)

    # Format the accuracy only when present: ':.2f' on the string 'unknown'
    # would raise ValueError.
    accuracy = attrs.get('champion_accuracy')
    accuracy_str = f"{accuracy:.2f}%" if accuracy is not None else 'unknown'

    print(f"Loaded {len(features):,} feature vectors")
    print(f"   Architecture: {attrs.get('architecture', 'unknown')}")
    print(f"   Feature dimension: {features.shape[1]}")
    print(f"   Champion accuracy: {accuracy_str}")

    return features, metadata_df, summary

def get_features_by_split(features, metadata_df, split_type='holdout_HG005_GRCh38',
                          test_size=0.2, random_state=42):
    """Split features into train/test sets with ``binary_label`` targets.

    Args:
        features: array whose row i corresponds to row i (by position) of
            ``metadata_df``.
        metadata_df: DataFrame with ``'dataset'``, ``'label_str'`` and
            ``'binary_label'`` columns. Its index labels are ignored and rows
            are matched by position.
        split_type: ``'80_20'`` for a random split stratified on
            ``label_str``, or ``'holdout_<dataset>'`` to use every row of that
            dataset as the test set.
        test_size: test fraction for the ``'80_20'`` split.
        random_state: seed for the ``'80_20'`` split.

    Returns:
        ``(X_train, X_test, y_train, y_test)``.

    Raises:
        ValueError: if a holdout split names a dataset absent from
            ``metadata_df``. That would otherwise silently give an empty test set.
    """
    print(f"Creating {split_type} split...")

    if split_type == '80_20':
        # Random stratified split
        from sklearn.model_selection import train_test_split
        train_indices, test_indices = train_test_split(
            range(len(metadata_df)),
            test_size=test_size,
            stratify=metadata_df['label_str'],
            random_state=random_state
        )
    else:
        # Leave-one-genome-out split
        test_genome = split_type.replace('holdout_', '')
        is_test = (metadata_df['dataset'] == test_genome).to_numpy()
        if not is_test.any():
            available = metadata_df['dataset'].dropna().unique().tolist()
            raise ValueError(
                f"No rows for holdout dataset {test_genome!r}; available: {available}"
            )
        # Positional indices: both `features[...]` and `.iloc` are positional,
        # so index labels would misalign on a non-default DataFrame index.
        train_indices = np.flatnonzero(~is_test)
        test_indices = np.flatnonzero(is_test)

    X_train = features[train_indices]
    X_test = features[test_indices]
    y_train = metadata_df.iloc[train_indices]['binary_label'].values
    y_test = metadata_df.iloc[test_indices]['binary_label'].values

    print(f"   Train: {len(X_train)} samples")
    print(f"   Test: {len(X_test)} samples")
    print(f"   Train class balance: {np.bincount(y_train)}")
    print(f"   Test class balance: {np.bincount(y_test)}")

    return X_train, X_test, y_train, y_test

def analyze_saved_features():
    """Print a short report for every ``*_features.h5`` file in ``SAVE_DIR``.

    For each feature file, the matching ``<arch>_extraction_summary.json`` is
    read. The report shows champion name and accuracy, sample count, feature
    dimension, file size and extraction date. A missing summary is reported
    rather than raised. Output is print-only; the function returns None.
    """
    print("ANALYZING SAVED RESNET FEATURES")
    print("=" * 50)

    if not os.path.exists(SAVE_DIR):
        print("Features directory not found!")
        return

    suffix = '_features.h5'
    feature_files = [name for name in os.listdir(SAVE_DIR) if name.endswith(suffix)]
    if not feature_files:
        print("No saved feature files found!")
        return

    print(f"Found {len(feature_files)} feature file(s):")

    for feature_file in feature_files:
        arch = feature_file.replace(suffix, '')
        summary_path = os.path.join(SAVE_DIR, f"{arch}_extraction_summary.json")

        if not os.path.exists(summary_path):
            print(f"\n{arch.upper()}: Summary file missing")
            continue

        with open(summary_path, 'r') as handle:
            report = json.load(handle)

        print(f"\n{arch.upper()} FEATURES:")
        print(f"   Champion: {report['champion_model']['name']}")
        print(f"   Accuracy: {report['champion_model']['accuracy']:.2f}%")
        print(f"   Samples: {report['data_summary']['total_samples']:,}")
        print(f"   Feature dim: {report['data_summary']['feature_dimension']}")
        print(f"   File size: {report['files']['features_size_mb']} MB")
        print(f"   Extraction date: {report['extraction_date'][:10]}")

In [None]:
def _print_usage_guide():
    """Print a cheat sheet of this notebook's entry points and the extraction workflow."""
    sections = (
        ("MAIN FUNCTION:", (
            "   features, metadata_df, summary = extract_and_save_resnet_features()",
        )),
        ("LOAD SAVED FEATURES:", (
            "   features, metadata_df, summary = load_saved_resnet_features()",
            "   X_train, X_test, y_train, y_test = get_features_by_split(features, metadata_df, 'holdout_HG005_GRCh38')",
        )),
        ("ANALYZE FEATURES:", (
            "   analyze_saved_features()",
        )),
        ("WORKFLOW:", (
            "   1. Finds best ResNet model",
            "   2. Extracts features for all ~50K genomic samples",
            "   3. Saves features to HDF5 (efficient storage)",
            "   4. Saves metadata to CSV (sample info)",
            "   5. Creates reusable feature datasets for ML experiments",
        )),
    )
    for position, (heading, entries) in enumerate(sections):
        if position > 0:
            print()  # blank line between sections
        print(heading)
        for entry in entries:
            print(entry)


_print_usage_guide()

# Run the full extraction (writes HDF5 / CSV / JSON into SAVE_DIR):
# features, metadata_df, summary = extract_and_save_resnet_features()

# Reload previously saved features:
# features, metadata_df, summary = load_saved_resnet_features()

# List the feature files that already exist:
# analyze_saved_features()

In [None]:
# Run the full pipeline: load the champion backbone, extract features for every
# data file, and write the HDF5 / CSV / JSON outputs into SAVE_DIR.
features, metadata_df, summary = extract_and_save_resnet_features()