In [None]:
!pip install --no-index h5py

In [1]:
# Auto-reload edited Python modules (e.g. the local `lit_modules` package)
# before each cell runs, so code changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os

# Make the project root importable (presumably where `lit_modules` lives).
# NOTE: os.getcwd() is the kernel's working directory, not a script location;
# the project root is taken to be two levels above it.
script_dir = os.path.dirname(os.getcwd())  # parent of the current working directory
parent_dir = os.path.dirname(script_dir)  # one more level up: the project root

print(f"{script_dir = }")
print(f"{parent_dir = }")

# Append only once so re-running this cell does not keep growing sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

script_dir = '/lustre06/project/6067616/soroush1/idiosyncrasy/notebooks'
parent_dir = '/lustre06/project/6067616/soroush1/idiosyncrasy'


In [7]:
import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names

ModuleNotFoundError: No module named 'torchvision.models.feature_extraction'

In [6]:
from lit_modules.datamodule import MuriDataModule
from argparse import Namespace
from torchvision import models
from torchvision.models import resnet50
import torch as ch
import numpy as np

import torch
import re
from tqdm import tqdm
import h5py
from lit_modules.modules.utils import InferioTemporalLayer

ModuleNotFoundError: No module named 'torchvision.models.feature_extraction'

In [None]:
# Test-split DataModule for the MURI-1320 images. With `return_paths=True`
# the loader yields `(img, img_path)` pairs, which later cells rely on.
hparams = Namespace(
    data_dir="/scratch/soroush1/memorability/muri1320",
    image_size=224,
    batch_size=128,
    num_workers=4,
    change_labels=False,
    pin_memories=[False, False, False],  # one flag per split: [train, val, test]
    return_paths=True,
)
data_module = MuriDataModule(hparams)

# Only the test split is needed for feature extraction.
data_module.prepare_data()
data_module.setup("test")
print(f"test dataset size: {len(data_module.test_dataset)}")

test_dl = data_module.test_dataloader()

In [None]:
# Smoke test: push a single batch through the model before the full extraction.
# Needs `config`/`model` and `sort_batch_by_filename` from the cells further
# down, so run those first (previously `arch` was never defined -> NameError).
arch = config["arch"]
device = next(model.parameters()).device  # feed inputs to wherever the weights live
img, img_path = next(iter(test_dl))
with torch.no_grad():  # inference only
    # NOTE(review): if `model` returns a plain logits tensor (torchvision resnet50),
    # integer indexing selects a *sample*, not a layer — confirm the model
    # returns per-layer outputs that InferioTemporalLayer indexes into.
    output = model(img.to(device))[getattr(InferioTemporalLayer, arch.upper()).value]  # extract IT layer
    output, img_path = sort_batch_by_filename(output, img_path)
print(f"{arch}: IT output shape for one batch = {tuple(output.shape)}")

In [None]:
class BlurPoolConv2d(ch.nn.Module):
    """Wrap a conv layer so its input is blurred with a 3x3 binomial filter first.

    Applying a fixed low-pass filter before a strided convolution reduces the
    aliasing that downsampling would otherwise introduce. ``apply_blurpool``
    swaps this wrapper in for strided ``Conv2d`` layers with at least 16 input
    channels.

    Args:
        conv: the convolution to wrap (must expose ``in_channels``); it runs
            unchanged on the blurred input.
    """

    def __init__(self, conv):
        super().__init__()
        # Normalized 3x3 binomial kernel [1 2 1]^T [1 2 1] / 16 (weights sum to 1).
        default_filter = ch.tensor([[[[1, 2, 1], [2, 4, 2], [1, 2, 1]]]]) / 16.0
        # One copy per input channel for a depthwise (groups=in_channels) conv.
        filt = default_filter.repeat(conv.in_channels, 1, 1, 1)
        self.conv = conv
        # Buffer, not Parameter: follows .to(device) and is saved in state_dict,
        # but is never updated by the optimizer.
        self.register_buffer("blur_filter", filt)

    def forward(self, x):
        # Depthwise blur; stride 1 with padding 1 keeps the spatial size.
        # Use the fully qualified functional API: the `F` alias is never imported.
        blurred = ch.nn.functional.conv2d(
            x,
            self.blur_filter,
            stride=1,
            padding=(1, 1),
            groups=self.conv.in_channels,
            bias=None,
        )
        # Call the module (not .forward) so hooks on the wrapped conv still fire.
        return self.conv(blurred)


def remove_prefix(state_dict, prefix):
    """
    Strip ``prefix`` from state-dict keys, keeping only the keys that carry it.

    Args:
    state_dict (dict): Mapping of parameter names to values.
    prefix (str): Leading string to strip from each key (e.g. "module.").

    Returns:
    dict: New dict with ``prefix`` removed from every matching key, in the
    original order. Keys that do not start with ``prefix`` are dropped.
    """
    cut = len(prefix)
    stripped = {}
    for name, tensor in state_dict.items():
        if not name.startswith(prefix):
            continue  # non-prefixed keys are discarded, not passed through
        stripped[name[cut:]] = tensor
    return stripped


def match_and_load_weights(checkpoint_state_dict, model, prefix="module."):
    """
    Select the checkpoint weights that fit ``model``'s state dict.

    Despite the name, nothing is loaded here: the caller passes the returned
    dict to ``model.load_state_dict``.

    Args:
    checkpoint_state_dict (dict): State dictionary from the checkpoint.
    model (torch.nn.Module): Model whose ``state_dict()`` keys and shapes are matched.
    prefix (str): Prefix stripped from checkpoint keys (typically "module." from
        a wrapped/parallel model). If no key carries it, keys are used unchanged.

    Returns:
    dict: ``{key: weight}`` for every checkpoint entry whose (prefix-stripped)
    key exists in the model with the same shape. Skipped entries are printed.
    """
    cleaned_checkpoint_state_dict = remove_prefix(checkpoint_state_dict, prefix)
    if not cleaned_checkpoint_state_dict and checkpoint_state_dict:
        # No key carried the prefix: fall back to the original keys instead of
        # silently matching nothing (remove_prefix drops non-prefixed keys).
        cleaned_checkpoint_state_dict = dict(checkpoint_state_dict)

    model_state_dict = model.state_dict()
    matched_weights = {}

    for ckpt_key, ckpt_weight in cleaned_checkpoint_state_dict.items():
        if ckpt_key not in model_state_dict:
            print(
                f"Layer {ckpt_key} from checkpoint not found in the model state dict."
            )
            continue
        expected_shape = tuple(model_state_dict[ckpt_key].shape)
        # Non-tensor entries (no .shape) are passed through unchecked.
        weight_shape = tuple(getattr(ckpt_weight, "shape", expected_shape))
        if weight_shape != expected_shape:
            print(
                f"Layer {ckpt_key} shape mismatch: checkpoint {weight_shape} "
                f"vs model {expected_shape}."
            )
            continue
        matched_weights[ckpt_key] = ckpt_weight

    return matched_weights

def create_model_and_scaler(config):
    """Build a torchvision model from ``config`` and load its checkpoint weights.

    Expected ``config`` keys:
        arch (str): name of a constructor in ``torchvision.models``.
        weights (str): path of a checkpoint read with ``ch.load``.
        use_blurpool (bool): wrap qualifying strided convs in ``BlurPoolConv2d``.

    Returns:
        The model on CUDA when available (else CPU) with the matched checkpoint
        weights loaded. Only the model is returned; no scaler is created.
    """
    arch, weights, use_blurpool = (
        config[key] for key in ("arch", "weights", "use_blurpool")
    )
    device = ch.device("cuda" if ch.cuda.is_available() else "cpu")
    print(f"{device = }")

    model = getattr(models, arch)(pretrained=None)

    def apply_blurpool(mod: ch.nn.Module):
        # Depth-first walk: replace strided convs with >= 16 input channels in
        # place, and recurse into every other child.
        for child_name, child in mod.named_children():
            is_conv = isinstance(child, ch.nn.Conv2d)
            if is_conv and np.max(child.stride) > 1 and child.in_channels >= 16:
                setattr(mod, child_name, BlurPoolConv2d(child))
                continue
            apply_blurpool(child)

    if use_blurpool:
        apply_blurpool(model)

    # Show the first few keys before and after prefix stripping / matching.
    ckpt = ch.load(weights, weights_only=True, map_location=device)
    print(f"{list(ckpt.keys())[:10] = }")
    ckpt = match_and_load_weights(ckpt, model)
    print(f"{list(ckpt.keys())[:10] = }")

    model = model.to(device)
    model.load_state_dict(ckpt)
    return model

# Checkpoint to analyse; "arch" must name a torchvision.models constructor.
config = {
    "weights": "weights/experiment_different_initialization/resnet50-0/final_weights.pt",
    "arch": "resnet50",
    "use_blurpool": True,
}
# .eval() puts BatchNorm/Dropout in inference mode and returns the module;
# chaining it into the assignment avoids dumping the full model repr as output.
model = create_model_and_scaler(config).eval()

In [None]:
def sort_batch_by_filename(tensor, filenames):
    """Reorder a batch so its rows follow the number in each image's file name.

    Args:
        tensor: batch whose first dimension lines up with ``filenames``.
        filenames: sequence of image paths or file names; each base name must
            contain a digit run (the first run in the base name is the sort key).

    Returns:
        tuple: ``(sorted_tensor, sorted_filenames)`` with ``sorted_filenames`` a
        tuple. The sort is stable, so equal keys keep their batch order. An
        empty batch is returned unchanged.

    Raises:
        ValueError: if the lengths differ or a base name contains no digits.
    """
    if len(tensor) != len(filenames):
        raise ValueError(
            f"tensor has {len(tensor)} rows but {len(filenames)} filenames were given"
        )

    def _numeric_key(path):
        # Search only the base name: if paths are absolute, directory parts such
        # as ".../soroush1/.../muri1320/" also contain digits and would give
        # every file the same key.
        match = re.search(r"\d+", os.path.basename(path))
        if match is None:
            raise ValueError(f"no number found in file name: {path!r}")
        return int(match.group())

    keys = [_numeric_key(fname) for fname in filenames]
    order = sorted(range(len(keys)), key=keys.__getitem__)
    if not order:
        return tensor, tuple(filenames)

    sorted_tensor = torch.stack([tensor[i] for i in order])
    sorted_filenames = tuple(filenames[i] for i in order)
    return sorted_tensor, sorted_filenames

def extract_and_concatenate_features(model, test_dl, arch):
    """Run ``model`` over ``test_dl`` and collect flattened IT-layer features.

    Args:
        model: network to run (caller is expected to have called ``.eval()``);
            its output is indexed with ``InferioTemporalLayer.<ARCH>.value``.
        test_dl: iterable of ``(images, image_paths)`` batches.
        arch: architecture name; upper-cased to look up the IT-layer entry.

    Returns:
        tuple: ``(features, image_paths)`` — a NumPy array with one flattened
        row per image, and the list of image paths in the same row order
        (each batch sorted by file number via ``sort_batch_by_filename``).

    Raises:
        ValueError: if ``test_dl`` yields no batches.
    """
    # Inputs must live on the same device as the weights: the model may have
    # been moved to CUDA while the loader yields CPU tensors.
    device = next(model.parameters()).device
    # Loop-invariant: resolve the IT-layer selector once.
    layer_key = getattr(InferioTemporalLayer, arch.upper()).value

    all_outputs = []
    all_img_paths = []
    with torch.no_grad():  # inference only; no autograd graph needed
        for img, img_path in tqdm(test_dl):
            output = model(img.to(device))[layer_key]  # extract IT layer
            output, img_path = sort_batch_by_filename(output, img_path)
            all_outputs.append(output.cpu())  # accumulate on CPU to spare GPU memory
            all_img_paths.extend(img_path)
            del img, output  # drop device references before the next batch

    if not all_outputs:
        raise ValueError("test_dl yielded no batches; nothing to extract")

    concatenated_output = torch.cat(all_outputs, dim=0)
    del all_outputs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release cached blocks once, after the loop

    # One flattened feature row per image; reshape also copes with
    # non-contiguous tensors where view() would fail.
    reshaped_output = (
        concatenated_output.reshape(concatenated_output.size(0), -1).detach().numpy()
    )
    return reshaped_output, all_img_paths

def save_h5(data, model_name, task_name, dst_path: str, layer_name: str = "it", ):
    """
    Save a feature array to ``<dst_path>/<model_name>_<task_name>_1.h5``.

    The array is written to the HDF5 dataset ``"features"``; opening with mode
    "w" overwrites any existing file at that path.

    Args:
    data (array-like): Features to save (must have ``.shape``), e.g. the NumPy
        array returned by ``extract_and_concatenate_features``.
    model_name (str): Name of the model used to extract features.
    task_name (str): Name of the task or dataset.
    dst_path (str): Directory to save the HDF5 file (created if missing).
    layer_name (str): Layer the features came from; used only in the log line,
        not in the file name.

    Returns:
    str: Path of the written HDF5 file.
    """
    print(
        f"Model: {model_name}, Task: {task_name}, Layer: {layer_name}, Shape: {data.shape}"
    )

    os.makedirs(dst_path, exist_ok=True)

    # The "_1" suffix is part of the naming scheme callers use to find the file.
    h5_file = os.path.join(dst_path, f"{model_name}_{task_name}_1.h5")

    with h5py.File(h5_file, "w") as hf:
        hf.create_dataset("features", data=data)

    print(f"Features saved to {h5_file}")
    return h5_file


def get_config(model_name: str, checkpoint: str):
    """Return the run configuration dict for ``model_name``.

    Every value is fixed except ``arch`` (set to ``model_name``) and
    ``checkpoint``; the config describes a 1000-class classification task.
    """
    config = dict(
        arch=model_name,
        use_blurpool=True,
        pretrained=False,
        num_classes=1000,
        lr=0.0001,
        weight_decay=1e-4,
        momentum=0.9,
        nesterov=True,
        norm_mean=[0.485, 0.456, 0.406],
        norm_std=[0.229, 0.224, 0.225],
        task_type="classification",
        experiment="one",
        optimizer="sgd",
        scheduler="plateau",
        step_size=30,
        max_epochs=100,
        random_training=False,
        use_ckpt=True,
        checkpoint=checkpoint,
    )
    return config

def get_data(
    input_size: int,
    data_dir: str = "/scratch/soroush1/memorability/muri1320",
    batch_size: int = 128,
    num_workers: int = 4,
):
    """Build the MURI-1320 test DataLoader for a given input resolution.

    Args:
        input_size: Passed to ``MuriDataModule`` as ``image_size``.
        data_dir: Dataset root (defaults to the original cluster path).
        batch_size: Batch size forwarded via the hparams.
        num_workers: Worker count forwarded via the hparams (presumably used
            by the DataLoader — confirm in ``MuriDataModule``).

    Returns:
        The loader from ``MuriDataModule.test_dataloader()``; ``return_paths=True``
        is set so batches carry image paths alongside the images.
    """
    hparams = Namespace(
        data_dir=data_dir,
        image_size=input_size,
        batch_size=batch_size,
        num_workers=num_workers,
        change_labels=False,
        pin_memories=[False, False, False],  # [train, val, test]
        return_paths=True,
    )

    data_module = MuriDataModule(hparams)
    data_module.prepare_data()
    data_module.setup("test")  # only the test split is used here

    print(f"test dataset size: {len(data_module.test_dataset)}")

    return data_module.test_dataloader()

In [None]:
# Extract IT activations for the model loaded above.
# Use the loaded architecture's name so the IT-layer lookup and the output
# file name match the checkpoint (this was hard-coded to "alexnet" while
# `config` loads a resnet50 checkpoint).
model_name = config["arch"]

concatenated_features, all_image_paths = extract_and_concatenate_features(model, test_dl, model_name)

print(f"Concatenated features shape: {concatenated_features.shape}")
print(f"Total number of images: {len(all_image_paths)}")

task_name = "imagenet_shuffle"
dst_dir = "./features"

save_h5(concatenated_features, model_name, task_name, dst_dir)

# Reload the saved file and check that it round-trips exactly.
expected_file = os.path.join(dst_dir, f"{model_name}_{task_name}_1.h5")
with h5py.File(expected_file, "r") as hf:
    saved_data = hf["features"][:]

np.testing.assert_array_equal(saved_data, concatenated_features)
saved_data.shape