In [2]:
%load_ext autoreload
%autoreload 2

import os
import sys
import time
from pprint import pprint
from pathlib import Path
from random import randint

import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import interact
from tqdm.notebook import tqdm
import nibabel as nib
import glmsingle
from glmsingle.glmsingle import GLM_single
import bids
from bids import BIDSLayout
from scipy.ndimage import zoom, binary_dilation
import h5py
import nibabel as nib
from einops import rearrange

# `dir2` is the parent of the current working directory and `dir1` is the
# directory above that. `dir1` is appended to the module search path once,
# so re-running this cell does not add duplicate entries.
dir2 = os.path.abspath('..')
dir1 = os.path.dirname(dir2)
if dir1 not in sys.path:
    sys.path.append(dir1)

In [3]:
# Paths to the derivatives of the Kamitani dataset.
#
# The defaults are the original machine-specific locations. To run the notebook
# on another machine without editing this cell, set the environment variables
#   DEEP_IMAGE_RECONSTRUCTION_DERIVATIVES      -> derivatives_path
#   DEEP_IMAGE_RECONSTRUCTION_DERIVATIVES_SSD  -> derivatives_path_ssd
# before starting the kernel. When they are unset the behaviour is unchanged.
derivatives_path = Path(os.environ.get(
    'DEEP_IMAGE_RECONSTRUCTION_DERIVATIVES',
    'X:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\',
))
derivatives_path_ssd = Path(os.environ.get(
    'DEEP_IMAGE_RECONSTRUCTION_DERIVATIVES_SSD',
    'D:\\Datasets\\Deep-Image-Reconstruction\\derivatives\\',
))

# The fMRIPrep output folder lives inside the (non-SSD) derivatives directory.
dataset_path = derivatives_path / 'fmriprep-20.2.4'

In [5]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Read-only handle on the stimulus image archive. It is not closed in this
# cell, so the handle stays open for whatever runs afterwards.
stimulus_images = h5py.File(derivatives_path / 'stimulus_images.hdf5', mode="r")

In [7]:
# Load a CLIP checkpoint and keep only its visual sub-module.
import clip

model_name = 'ViT-B/32'
print(clip.available_models())

# `clip.load` returns the model together with its image preprocessing
# transform; `model` is then rebound to the `.visual` attribute only.
model, preprocess = clip.load(model_name, device=device)
model = model.visual

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']


In [9]:
# Features to save, as {module name: feature name}.
# The single entry uses the empty-string module name and stores that module's
# output under the name 'embedding'.
save_modules = dict([('', 'embedding')])

In [10]:
# Feature extraction
#
# Runs every stimulus image through `model`, captures the outputs of the
# modules listed in `save_modules` with forward hooks, and writes them to
# "<model_name>-features.hdf5" (one dataset per feature, one row per image).

from functools import partial
from tqdm.notebook import tqdm
from PIL import Image
from typing import Sequence, Dict

modules = dict(model.named_modules())

# Normalise `save_modules` to {module name: feature name}.
# A dict is used as-is; a sequence of module names stores each output under
# the module's own name. Anything else (including a bare string, which would
# otherwise be iterated character by character) is rejected up front instead
# of silently producing no features.
if isinstance(save_modules, Dict):
    hook_targets = dict(save_modules)
elif isinstance(save_modules, Sequence) and not isinstance(save_modules, str):
    hook_targets = {module_name: module_name for module_name in save_modules}
else:
    raise TypeError(
        "save_modules must be a dict {module name: feature name} or a "
        f"sequence of module names, got {type(save_modules).__name__}"
    )

unknown_modules = [name for name in hook_targets if name not in modules]
if unknown_modules:
    raise KeyError(f"Unknown module name(s) in save_modules: {unknown_modules}")

# Feed the model inputs in the dtype of its own weights. CLIP keeps fp16
# weights on CUDA but is converted to fp32 on CPU, so a hard-coded float16
# input breaks the CPU path. `conv1.weight.dtype` is the same reference CLIP
# uses for its own `dtype` property; fall back to the first parameter for
# models without a `conv1` attribute.
_first_conv = getattr(model, "conv1", None)
input_dtype = (
    _first_conv.weight.dtype if _first_conv is not None
    else next(model.parameters()).dtype
)

# Outputs captured during the current forward pass, keyed by feature name.
features = {}


def forward_hook(feature_name, module, x_in, x_out):
    """Store `module`'s output in `features[feature_name]` as a NumPy array.

    `feature_name` is bound with `functools.partial`; the remaining arguments
    are supplied by PyTorch's forward-hook mechanism. A leading dimension of
    size 1 is dropped so a single image yields an un-batched array.
    """
    if x_out.shape[0] == 1:
        x_out = x_out[0]
    features[feature_name] = x_out.detach().clone().cpu().numpy()


# The hooks do not depend on the image, so register them once for the whole
# run and always remove them afterwards, even if the loop raises.
hook_handles = []
try:
    for module_name, feature_name in hook_targets.items():
        hook_handle = modules[module_name].register_forward_hook(
            partial(forward_hook, feature_name)
        )
        hook_handles.append(hook_handle)

    with h5py.File(derivatives_path / f"{model_name.replace('/', '=')}-features.hdf5", "a") as f:
        images = list(enumerate(stimulus_images.items()))
        N = len(images)
        for i, (stimulus_id, stimulus_image) in tqdm(images):
            image_data = stimulus_image['data'][:]
            image = Image.fromarray(image_data)
            x = preprocess(image).unsqueeze(0).to(device=device, dtype=input_dtype)

            # Drop the previous image's outputs before the next forward pass.
            features.clear()
            with torch.no_grad():
                model(x)

            for feature_name, feature in features.items():
                # Creates the (N, *feature.shape) dataset on first use and
                # reuses it afterwards (the file is opened in append mode).
                dataset = f.require_dataset(feature_name, (N, *feature.shape), feature.dtype)
                dataset[i] = feature
finally:
    for hook_handle in hook_handles:
        hook_handle.remove()
            

  0%|          | 0/1300 [00:00<?, ?it/s]