In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os

# Make the project root importable from this notebook.
# NOTE: both paths are derived from the kernel's current working directory
# (os.getcwd()), not from the location of the notebook file, so the result
# depends on where the kernel was started.
script_dir = os.path.dirname(os.getcwd())  # parent of the current working directory
parent_dir = os.path.dirname(os.path.dirname(script_dir))  # two more levels up from script_dir


print(f"{script_dir = }")
print(f"{parent_dir = }")

# Guard the append so that re-running this cell does not pile up duplicate
# entries on sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

script_dir = '/lustre06/project/6067616/soroush1/idiosyncrasy/notebooks/experiments'
parent_dir = '/lustre06/project/6067616/soroush1/idiosyncrasy'


In [26]:
import h5py
import numpy as np
from tqdm import tqdm

from lit_modules.datamodule import MuriDataModule
from argparse import Namespace

import lightning as L

In [16]:
import os
from torch.utils.data import Dataset, Subset
from torchvision import transforms
import PIL.Image
import pandas as pd
import torch

class MuriDataset(Dataset):
    """Image dataset described by a ``meta.csv`` file inside ``root``.

    ``meta.csv`` must provide an ``img_path`` column (absolute paths, or paths
    relative to ``root``) and a ``labels`` column.

    Args:
        root: Directory that contains ``meta.csv``; relative image paths are
            resolved against it.
        transforms: Optional callable applied to the loaded RGB ``PIL`` image.
    """

    def __init__(self, root: str, transforms=None):
        self.root = root
        self.transforms = transforms

        # Read metadata
        self.meta_data = pd.read_csv(os.path.join(root, "meta.csv"))

        # Resolve every path individually instead of inspecting only the first
        # row: os.path.join returns an absolute second argument unchanged, so
        # absolute entries are kept as they are, relative ones are anchored at
        # ``root``, and an empty metadata file no longer raises IndexError.
        self.meta_data["img_path"] = self.meta_data["img_path"].apply(
            lambda path: os.path.join(root, path)
        )

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return len(self.meta_data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        Raises:
            FileNotFoundError: If the image file for this row does not exist.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Get image path and verify it exists
        img_path = self.meta_data["img_path"].iloc[idx]
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        # Load and transform image
        try:
            # Image.open is lazy and keeps the file open; the context manager
            # releases the handle as soon as convert() has produced a copy.
            with PIL.Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
            # Explicit None check: a transform object may define __len__ or
            # __bool__ and must not be skipped just because it is "falsy".
            if self.transforms is not None:
                image = self.transforms(image)
        except Exception as e:
            print(f"Error loading image {img_path}: {str(e)}")
            raise

        # Get label
        label = self.meta_data["labels"].iloc[idx]

        return image, label

def get_transform(input_size: int = 256):
    """Build the preprocessing pipeline: resize, tensor conversion, normalisation.

    Args:
        input_size: Value handed to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` applying Resize -> ToTensor -> Normalize.

    NOTE(review): per the torchvision docs, ``Resize`` with a single int scales
    the shorter edge and keeps the aspect ratio, so non-square inputs would not
    come out square -- confirm that all dataset images are square.
    """
    # Per-channel statistics (the widely used ImageNet mean / std values).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

# Dataset / dataloader configuration.
# DATA_ROOT is the single place where the dataset location is defined; both
# meta.csv (read by MuriDataset) and test.csv are looked up underneath it.
DATA_ROOT = "/scratch/soroush1/memorability/muri1320"
input_size = 256
batch_size = 32

# Full dataset: every image listed in meta.csv
full_dataset = MuriDataset(
    root=DATA_ROOT,
    transforms=get_transform(input_size)
)
print(f"Dataset size: {len(full_dataset)}")

# Load test indices; they are used below as positions into the full dataset
df = pd.read_csv(os.path.join(DATA_ROOT, "test.csv"))
test_indices = df.values.flatten()
print(f"{test_indices.shape = }")

# `dataset` is the test subset only; the full dataset stays reachable as
# `full_dataset` instead of being overwritten.
dataset = Subset(full_dataset, test_indices)

# Create dataloader (shuffle=False keeps rows in test_indices order)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

# Smoke-test the dataset: size and a single sample
print(f"Dataset size: {len(dataset)}")
sample_image, sample_label = dataset[0]
print(f"Image shape: {sample_image.shape}")
print(f"Label: {sample_label}")

# Smoke-test the dataloader: inspect the first batch only
for batch_idx, (images, labels) in enumerate(dataloader):
    print(f"Batch {batch_idx}")
    print(f"Batch image shape: {images.shape}")
    print(f"Batch labels shape: {labels.shape}")
    break

Dataset size: 1320
test_indices.shape = (660,)
Dataset size: 660
Image shape: torch.Size([3, 256, 256])
Label: 5
Batch 0
Batch image shape: torch.Size([32, 3, 256, 256])
Batch labels shape: torch.Size([32])


In [28]:
import torch
import torch.nn as nn
from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor
from enum import Enum

class InferioTemporalLayer(Enum):
    """Per-architecture layer / graph-node name used as the feature read-out point.

    Every value is a plain string. In this notebook the RESNET50 entry is passed
    to torchvision's ``create_feature_extractor`` as a return node and is used as
    the key into the extractor's output.
    NOTE(review): presumably each entry names the layer treated as the
    inferior-temporal (IT) analogue of that network -- confirm where these
    choices come from and that each name exists in the corresponding model.

    Caveat: ``Enum`` treats members that share a value as aliases of the first
    member defined with that value. RESNET101 is therefore the same object as
    RESNET50 and VIT_B_32 the same object as VIT_B_16: their ``.name`` reports
    the first name, and iterating the enum yields 8 members, not 10. Access via
    ``.value`` is unaffected. Do not reorder the members without taking this
    into account.
    """
    ALEXNET = "features.12"
    RESNET50 = "layer3.2.bn1"
    RESNET101 = "layer3.2.bn1"  # same value as RESNET50 -> alias of RESNET50
    VGG16 = "features.30"
    VGG19 = "features.36"
    INCEPTION_V3 = "Mixed_7a.branch3x3_1.bn"
    VIT_B_16 = "encoder.layers.encoder_layer_8.mlp"
    VIT_B_32 = "encoder.layers.encoder_layer_8.mlp"  # same value as VIT_B_16 -> alias of VIT_B_16
    EFFICIENTNET_B0 = "features.6.2.stochastic_depth"
    RESNET18 = "layer4.0.relu"

def create_feature_generator(weights_path=None):
    """Build a ResNet-50 feature extractor that exposes the configured IT layer.

    Args:
        weights_path: Optional path to a checkpoint readable by ``torch.load``
            whose content is handed to ``load_state_dict``. When falsy, the
            network is left as constructed with ``weights=None``.

    Returns:
        The module produced by ``create_feature_extractor``, switched to
        evaluation mode. It returns the activation of the node named by
        ``InferioTemporalLayer.RESNET50``.
    """
    backbone = resnet50(weights=None)

    if weights_path:
        # Load onto the CPU first; the caller decides where the model runs.
        state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
        backbone.load_state_dict(state_dict)

    # Node whose activation the extractor should return.
    node_name = InferioTemporalLayer.RESNET50.value
    extractor = create_feature_extractor(backbone, return_nodes=[node_name])

    # eval() returns the module itself, so it can be returned directly.
    return extractor.eval()

def extract_batch_features(batch, feature_extractor, device, layer_name=None):
    """Run one batch through the extractor and flatten the selected activation.

    Args:
        batch: ``(images, labels)`` pair as produced by the dataloader.
        feature_extractor: Callable mapping an image batch to a mapping of
            ``{node_name: activation}``.
        device: Device the images are moved to before the forward pass.
        layer_name: Key of the activation to read from the extractor output.
            Defaults to ``InferioTemporalLayer.RESNET50.value``, which keeps
            the previous hard-coded behaviour.

    Returns:
        ``(features, labels)`` where ``features`` has shape ``(batch, -1)`` and
        stays on ``device``; ``labels`` is returned untouched.
    """
    if layer_name is None:
        layer_name = InferioTemporalLayer.RESNET50.value

    images, labels = batch
    images = images.to(device)
    with torch.no_grad():
        activations = feature_extractor(images)[layer_name]
        # reshape() instead of view(): view() raises on non-contiguous
        # activations (e.g. channels_last or permuted outputs), while reshape()
        # copies only when it has to and is otherwise identical.
        features = activations.reshape(activations.size(0), -1)
    return features, labels

def extract_and_save_features(dataloader, feature_extractor, output_path, batch_size=32):
    """
    Extract flattened features from all images and save to H5 file

    The output file holds two datasets: ``features`` with shape
    ``(len(dataloader.dataset), feature_dim)`` as float32 and ``labels`` with
    shape ``(len(dataloader.dataset),)`` as int64. Rows are written in the
    order the dataloader yields them.

    Args:
        dataloader: Yields ``(images, labels)`` batches; ``dataloader.dataset``
            must support ``len()``.
        feature_extractor: Module accepted by ``extract_batch_features``; it is
            moved to CUDA device 0 when available, otherwise to the CPU.
        output_path: Destination HDF5 file; an existing file is overwritten.
        batch_size: Unused. Kept only so existing callers keep working; the
            batch size is whatever the dataloader was built with.

    Raises:
        ValueError: If ``dataloader.dataset`` is empty.
        RuntimeError: If the dataloader yields a different number of samples
            than ``len(dataloader.dataset)`` (e.g. ``drop_last=True``), which
            would otherwise leave never-written rows in the file.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    feature_extractor = feature_extractor.to(device)  # Move feature extractor to device

    # Calculate total number of samples
    total_samples = len(dataloader.dataset)
    if total_samples == 0:
        raise ValueError("dataloader.dataset is empty; there are no features to extract")

    # Probe one batch to learn the flattened feature width
    sample_batch = next(iter(dataloader))
    sample_features, _ = extract_batch_features(sample_batch, feature_extractor, device)
    feature_dim = sample_features.shape[1]  # This will be the flattened dimension

    # Row-aligned chunks of roughly 1 MiB. Chunking by up to 1000 full rows
    # produces chunks of hundreds of MB for wide feature maps and can exceed
    # HDF5's 4 GiB per-chunk limit; whole rows per chunk keep per-sample reads cheap.
    target_chunk_bytes = 1 << 20
    bytes_per_row = feature_dim * np.dtype(np.float32).itemsize
    rows_per_chunk = max(1, min(total_samples, target_chunk_bytes // max(bytes_per_row, 1)))

    with h5py.File(output_path, 'w') as f:
        # Create datasets with flattened features
        features_dataset = f.create_dataset(
            'features',
            shape=(total_samples, feature_dim),
            dtype=np.float32,
            chunks=(rows_per_chunk, feature_dim)
        )
        labels_dataset = f.create_dataset(
            'labels',
            shape=(total_samples,),
            dtype=np.int64
        )

        # Process batches
        start_idx = 0
        for batch in tqdm(dataloader, desc="Extracting features"):
            # Extract flattened features
            features, labels = extract_batch_features(batch, feature_extractor, device)

            # Convert to numpy and move to CPU
            features = features.cpu().numpy()
            labels = labels.cpu().numpy()

            # Calculate end index for this batch
            end_idx = start_idx + features.shape[0]

            # Save to H5 file
            features_dataset[start_idx:end_idx] = features
            labels_dataset[start_idx:end_idx] = labels

            # Update start index
            start_idx = end_idx

            # No torch.cuda.empty_cache() here: it does not give PyTorch more
            # usable memory and only slows the loop down by forcing fresh
            # allocations on every batch.

    # Every preallocated row must have been written; otherwise the file would
    # silently contain rows that were never filled in.
    if start_idx != total_samples:
        raise RuntimeError(
            f"Wrote {start_idx} samples but expected {total_samples}; "
            "the dataloader did not yield every sample of its dataset"
        )

In [29]:
# Run the extraction for the test split and write the result next to the notebook
output_path = 'resnet50_features.h5'

# NOTE: the weights path is relative, so it is resolved against the kernel's working directory
feature_extractor = create_feature_generator(weights_path="weights/resnet50.pth")

# Writes the 'features' and 'labels' datasets to output_path
extract_and_save_features(dataloader, feature_extractor, output_path)

# Re-open the file read-only and report what was stored
with h5py.File(output_path, 'r') as f:
    saved_features_shape = f['features'].shape
    saved_labels_shape = f['labels'].shape
    print("Saved features shape:", saved_features_shape)
    print("Saved labels shape:", saved_labels_shape)

Extracting features: 100%|██████████| 21/21 [00:03<00:00,  5.45it/s]

Saved features shape: (660, 65536)
Saved labels shape: (660,)



