In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
import time
from pprint import pprint
from pathlib import Path
from random import randint

import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import interact
from tqdm.notebook import tqdm
import nibabel as nib
import glmsingle
from glmsingle.glmsingle import GLM_single
import bids
from bids import BIDSLayout
from scipy.ndimage import zoom, binary_dilation
import h5py
import nibabel as nib
from einops import rearrange

# Make the grandparent of the current working directory importable
# (presumably the project root that holds local packages — confirm).
dir2 = os.path.abspath(os.pardir)
dir1 = os.path.dirname(dir2)
if dir1 not in sys.path:
    sys.path.append(dir1)

In [None]:
# Root of the processed dataset. Set the TC2SEE_DATASET_ROOT environment
# variable to run on a machine without this drive layout; the default is the
# original hard-coded location.
dataset_root = Path(os.environ.get('TC2SEE_DATASET_ROOT', 'E:\\fmri_processing\\results'))
tc2see_version = 3  # TC2See release (1, 2 or 3); the v1/v2 vs v3 stimulus cells below are picked by hand, not by this value
# dataset_path = dataset_root / f"TC2See_v{tc2see_version}"
dataset_path = dataset_root
derivatives_path = dataset_path / 'derivatives_TC2See'

In [None]:
# Create h5 file for v1 or v2 stimulus images
#
# One group per image file: '<stimulus_name>/data' holds the pixel array and the
# group attributes hold 'class_id' and 'bird_id' parsed from the file name.
# NOTE: the v3 cell below writes the same 'stimulus-images.hdf5' in mode 'w',
# so only whichever of the two cells runs last survives.
from PIL import Image

stimulus_images_path = Path('E:\\Decoding\\bird_data\\docs\\cropped')

# Only these extensions are treated as stimuli; anything else in the folder
# (sub-folders, Thumbs.db, desktop.ini, ...) would crash the name parsing.
image_suffixes = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff', '.webp'}

with h5py.File(derivatives_path / 'stimulus-images.hdf5', 'w') as f:
    stimulus_names = []
    # Sorted so the stored 'stimulus_names' order is reproducible; iterdir()
    # order depends on the filesystem.
    image_file_paths = sorted(
        path for path in stimulus_images_path.iterdir()
        if path.is_file() and path.suffix.lower() in image_suffixes
    )
    for image_file_path in image_file_paths:
        stimulus_name = image_file_path.stem
        stimulus_names.append(stimulus_name)

        # Expected stem: '<class_id>.<image_id>', with the bird id as the last
        # character of image_id (presumably a single digit — confirm).
        # maxsplit=1 so further dots stay inside image_id instead of raising.
        class_id, image_id = stimulus_name.split('.', 1)
        class_id = int(class_id)
        bird_id = int(image_id[-1])

        with Image.open(image_file_path) as image:
            data = np.array(image)
        f[f'{stimulus_name}/data'] = data
        f[stimulus_name].attrs['class_id'] = class_id
        f[stimulus_name].attrs['bird_id'] = bird_id
    f.attrs['stimulus_names'] = stimulus_names

In [None]:
# Create h5 file for v3 stimulus images
#
# Collects every stimulus shown to sub-03 in the 'bird' task from the BIDS
# events files, then copies those images into 'stimulus-images.hdf5' with one
# group per stimulus ('<stimulus_name>/data' + a 'class_name' attribute).
# NOTE: writes the same file as the v1/v2 cell above (mode 'w' replaces it).
from PIL import Image

dataset_layout = BIDSLayout(dataset_path / 'TC2See')
derivatives_layout = BIDSLayout(derivatives_path / 'fmriprep')

# Restrict to *_events.tsv: filtering on the extension alone would also return
# any other TSV file for this subject/task and try to parse it as an events log.
events_files = dataset_layout.get(
    subject='03',
    task='bird',
    suffix='events',
    extension='tsv'
)
if not events_files:
    # Fail with a clear message instead of np.concatenate's "need at least one array".
    raise FileNotFoundError(
        f"No task-bird events files for subject 03 found under {dataset_path / 'TC2See'}"
    )

events_dfs = [
    pd.read_csv(events_file.path, sep='\t')
    for events_file in events_files
]
# Drop rows whose stimulus is '+' (presumably the fixation cross).
events_dfs = [df[df['stimulus'] != '+'] for df in events_dfs]
# Unique stimulus paths over all runs; np.unique also sorts them, so the stored
# order is deterministic.
stimulus_paths = np.unique(np.concatenate([np.array(df['stimulus']) for df in events_dfs]))
# Paths containing 'hash' are skipped (TODO: document what these entries are).
stimulus_names = [Path(p).stem for p in stimulus_paths if 'hash' not in p]

stimulus_images_path = Path('X:\\Datasets\\EEG\\Things-concepts-and-images\\Main\\images')

with h5py.File(derivatives_path / 'stimulus-images.hdf5', 'w') as f:
    for stimulus_name in stimulus_names:
        # The image lives in a folder named after everything before the last '_'
        # of the stimulus name: '<class_name>_<suffix>' -> '<class_name>/'.
        class_name = '_'.join(stimulus_name.split('_')[:-1])
        image_file_path = stimulus_images_path / class_name / f'{stimulus_name}.jpg'

        with Image.open(image_file_path) as image:
            data = np.array(image)

        f[f'{stimulus_name}/data'] = data
        f[stimulus_name].attrs['class_name'] = class_name
    f.attrs['stimulus_names'] = stimulus_names

In [None]:
# Row count of the second events file after the '+' rows were filtered out.
events_dfs[1].shape[0]

In [None]:
# Use the GPU when torch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Read-only handle to the stimulus images; left open so later cells can read from it.
stimulus_images = h5py.File(derivatives_path / 'stimulus-images.hdf5', mode="r")

In [None]:
# Load a CLIP model and keep only its image encoder
import clip

print(clip.available_models())
model_name = 'ViT-B/32'
model, preprocess = clip.load(model_name, device=device)
# clip.load keeps fp16 weights when loaded on CUDA, while every cell below
# feeds float32 images. Casting the encoder to float32 avoids the
# float/half dtype-mismatch error and fp16 overflow (NaN/Inf) in activations.
# On CPU the weights are already float32, so this is a no-op there.
model = model.visual.float()

In [None]:
@interact(stimulus_id=stimulus_images.keys())
def select_module(stimulus_id):
    """Show one stimulus and count NaN/Inf values around ``model.conv1``.

    The convolution is evaluated twice: once through the module, and once through
    an explicit ``torch.nn.functional.conv2d`` call built from the module's own
    weight, bias, stride, padding, dilation and groups. The two NaN/Inf counts
    can then be compared.

    Args:
        stimulus_id: key of a stimulus group in ``stimulus_images``.
    """
    image = Image.fromarray(stimulus_images[stimulus_id]['data'][:])
    plt.imshow(image)

    conv1 = model.conv1
    # Feed the input in the conv's own weight dtype: float32 input against fp16
    # weights (what clip.load produces on CUDA) raises a dtype-mismatch error.
    x = preprocess(image).unsqueeze(0).to(device=device, dtype=conv1.weight.dtype)
    print(x.isnan().sum(), x.isinf().sum())

    print(conv1)
    with torch.no_grad():
        x_conv1 = conv1(x)
        # Same convolution by hand. Every hyper-parameter comes from the module:
        # the old hard-coded stride=32 only matches /32 patch ViTs.
        x_manual = torch.nn.functional.conv2d(
            x,
            conv1.weight,
            conv1.bias,
            stride=conv1.stride,
            padding=conv1.padding,
            dilation=conv1.dilation,
            groups=conv1.groups,
        )
    print(conv1.weight.isnan().sum(), conv1.weight.isinf().sum())
    print(x_conv1.isnan().sum(), x_conv1.isinf().sum())
    print(x_manual.isnan().sum(), x_manual.isinf().sum())

In [None]:
# Feature visualizer
from PIL import Image
from functools import partial
import math
from einops import rearrange

def tokens_to_feature_map(x):
    """Lay a ViT token sequence out as a (1, C, d, d) patch feature map.

    Args:
        x: 3-D tensor of tokens, sequence-first ``(L, N, C)``. A batch-first
            ``(1, L, C)`` tensor is also accepted and transposed first. Position
            0 is taken to be the class token, as prepended by CLIP's
            VisionTransformer; the remaining ``L - 1`` tokens must form a d x d
            grid. Only batch item 0 is used.

    Returns:
        A ``(1, C, d, d)`` tensor, or ``None`` if ``L - 1`` is not a positive
        perfect square.
    """
    if x.shape[0] == 1 and x.shape[1] > 1:
        # Batch-first with a single image -> sequence-first.
        x = x.transpose(0, 1)
    n_patches = x.shape[0] - 1
    if n_patches < 1:
        return None
    d = math.isqrt(n_patches)
    if d * d != n_patches:
        return None
    # Drop the class token at index 0 (not the last token) and keep batch item 0.
    patches = x[1:, 0]  # (d*d, C), row-major over the patch grid
    return patches.T.reshape(patches.shape[1], d, d)[None]


def vis_features(x):
    """Interactively browse the channels of a captured feature tensor.

    Non-tensors (e.g. tuple outputs) are reported by type and skipped. 3-D
    tensors are converted with ``tokens_to_feature_map``. Anything that is not
    4-D ``(N, C, H, W)`` afterwards is skipped. Returns ``None`` in every case.
    """
    if not isinstance(x, torch.Tensor):
        print(type(x))
        return
    x = x.float().cpu()
    print(x.shape, x.dtype)

    if x.dim() == 3:
        x = tokens_to_feature_map(x)
        if x is None:
            print('token count is not 1 + a square number; cannot lay it out as a grid')
            return

    if x.dim() != 4:
        return
    N, C = x.shape[:2]

    print(x.mean(), x.std())

    @interact(i=(0, N - 1), c=(0, C - 1))
    def plot_feature_map(i, c):
        fig, ax = plt.subplots(figsize=(8, 8))
        im = ax.imshow(x[i, c], cmap="gray")
        fig.colorbar(im, ax=ax)
        plt.show()
        plt.close(fig)


modules = dict(model.named_modules())


@interact(module_name=modules.keys(), stimulus_id=stimulus_images.keys())
def select_module(module_name, stimulus_id):
    """Capture one module's output for one stimulus and pass it to ``vis_features``.

    A forward hook is placed on ``modules[module_name]`` for a single forward
    pass and is always removed afterwards. Otherwise every widget change would
    leave another hook on the model, and those hooks would keep firing in all
    later forward passes, including the feature-extraction cell.

    Args:
        module_name: key from ``model.named_modules()`` ('' is the model itself).
        stimulus_id: key of a stimulus group in ``stimulus_images``.
    """
    image = stimulus_images[stimulus_id]['data'][:]
    print(image.min(), image.max())
    image = Image.fromarray(image)
    x = preprocess(image).unsqueeze(0).to(device).to(torch.float32)
    print(x.shape)
    print(x.mean().item(), x.min().item(), x.max().item(), x.std().item())

    features = {}

    def forward_hook(module, x_in, x_out):
        # Some modules return tuples (nn.MultiheadAttention returns
        # (output, weights)); only tensors can be cloned, and vis_features
        # reports the rest by type.
        features[module_name] = x_out.clone() if isinstance(x_out, torch.Tensor) else x_out

    hook_handle = modules[module_name].register_forward_hook(forward_hook)
    try:
        # Match the model's parameter dtype so an fp16 model still runs.
        model_dtype = next(model.parameters()).dtype
        with torch.no_grad():
            model(x.to(model_dtype))
    finally:
        hook_handle.remove()

    if module_name not in features:
        print(f'{module_name!r} was not called during the forward pass')
        return
    vis_features(features[module_name])

In [None]:
# Features to save: module name (as in model.named_modules()) -> dataset name in
# the features file. The empty name '' is the root module, i.e. model's own output.
save_modules = {'': 'embedding'}

In [None]:
# Feature extraction
#
# Runs every stimulus in `stimulus_images` through `model` once and stores the
# output of each module named in `save_modules` in
# '<model_name with / -> =>-features.hdf5'. Row i of every dataset belongs to
# the i-th stimulus in `stimulus_images` iteration order. Those ids are also
# written to the file's 'stimulus_names' attribute so rows can be matched back.
# The file is opened in append mode: if it already holds a dataset of a
# different shape (e.g. a run with a different number of stimuli),
# require_dataset will raise.

from collections.abc import Mapping
from functools import partial
from typing import Sequence, Dict

from PIL import Image
from tqdm.notebook import tqdm

modules = dict(model.named_modules())

# Normalise `save_modules` to {module_name: feature_name}. A plain collection of
# module names stores each output under its own module name.
if isinstance(save_modules, Mapping):
    feature_names = dict(save_modules)
else:
    feature_names = {module_name: module_name for module_name in save_modules}

# Inputs are cast to the model's parameter dtype so fp16 and fp32 models both run.
model_dtype = next(model.parameters()).dtype

features = {}

def forward_hook(feature_name, module, x_in, x_out):
    # One image per forward pass: drop the leading batch axis of size 1.
    if x_out.shape[0] == 1:
        x_out = x_out[0]
    features[feature_name] = x_out.clone().cpu().numpy()

output_path = derivatives_path / f"{model_name.replace('/', '=')}-features.hdf5"
hook_handles = []
try:
    # Register the hooks once for the whole run (not once per image). They are
    # removed in `finally`, so an error mid-run cannot leave them on the model.
    for module_name, feature_name in feature_names.items():
        hook_handles.append(
            modules[module_name].register_forward_hook(partial(forward_hook, feature_name))
        )

    with h5py.File(output_path, "a") as f:
        images = list(enumerate(stimulus_images.items()))
        N = len(images)
        for i, (stimulus_id, stimulus_image) in tqdm(images):
            image = Image.fromarray(stimulus_image['data'][:])
            x = preprocess(image).unsqueeze(0).to(device=device, dtype=model_dtype)

            features.clear()  # never write a previous image's output into row i
            with torch.no_grad():
                model(x)

            for feature_name, feature in features.items():
                f.require_dataset(feature_name, (N, *feature.shape), feature.dtype)
                f[feature_name][i] = feature

        # Record which stimulus each row belongs to.
        f.attrs['stimulus_names'] = [name for _, (name, _group) in images]
finally:
    for hook_handle in hook_handles:
        hook_handle.remove()
            