In [1]:
import h5py
import numpy as np
import os
import pandas as pd
import json
from pycocotools.coco import COCO
import matplotlib.pyplot as plt

# Root of the raw NSD download and destination for the processed samples.
DATA_DIR="../data/NSD-raw"
OUTPUT_DIR="../data/NSD-processed"

# Subjects to keep, subjects whose columns are dropped, and sessions per subject.
SUBJECT=[1,2,5,7]
FILTERED_SUBJECT=[3,4,6,8]
NUM_SESSION=40

# Sub-directory holding the beta files, and the prefix / HDF5 dataset name used inside it.
PREP_TYPE="betas_fithrf_GLMdenoise_RR"
PREP_NAME="betas"

# Stimulus-related inputs all live under <DATA_DIR>/stimuli.
_stimuli_dir = os.path.join(DATA_DIR, "stimuli")
_annotation_dir = os.path.join(_stimuli_dir, "annotations")

stimuli_img = os.path.join(_stimuli_dir, "nsd_stimuli.hdf5")
stimuli_info = os.path.join(_stimuli_dir, "nsd_stim_info_merged.csv")
train_coco_annotations = os.path.join(_annotation_dir, "instances_train2017.json")
val_coco_annotations = os.path.join(_annotation_dir, "instances_val2017.json")

Load stimuli information

In [2]:
# Read the merged stimulus table; note that `stimuli_info` goes from holding
# the CSV path to holding the DataFrame, so this cell cannot be re-run as is.
stimuli_info = pd.read_csv(stimuli_info)
print(f"total number of stimuli: {len(stimuli_info)}")
# Remove every column whose name mentions a subject listed in FILTERED_SUBJECT.
excluded_tags = [f"subject{subj}" for subj in FILTERED_SUBJECT]
columns_to_drop = [
    col for col in stimuli_info.columns
    if any(tag in col for tag in excluded_tags)
]
stimuli_info = stimuli_info.drop(columns=columns_to_drop)
print(stimuli_info.columns)


total number of stimuli: 73000
Index(['Unnamed: 0', 'cocoId', 'cocoSplit', 'cropBox', 'loss', 'nsdId',
       'flagged', 'BOLD5000', 'shared1000', 'subject1', 'subject2', 'subject5',
       'subject7', 'subject1_rep0', 'subject1_rep1', 'subject1_rep2',
       'subject2_rep0', 'subject2_rep1', 'subject2_rep2', 'subject5_rep0',
       'subject5_rep1', 'subject5_rep2', 'subject7_rep0', 'subject7_rep1',
       'subject7_rep2'],
      dtype='object')


Load stimuli image

In [3]:
# Pull the whole 'imgBrick' dataset out of the stimulus HDF5 file into memory.
with h5py.File(stimuli_img, 'r') as h5_file:
    print(h5_file.keys())
    images = h5_file['imgBrick'][:]
# Quick sanity summary of the loaded array.
print(images.shape)
print(images.dtype)
pixel_min, pixel_max = images.min(), images.max()
print(f"min: {pixel_min}, max: {pixel_max}")

<KeysViewHDF5 ['imgBrick']>
(73000, 425, 425, 3)
uint8
min: 0, max: 255


Load COCO dataset annotations

In [4]:
# Replace the annotation file paths with loaded COCO API objects (train first, then val).
train_coco_annotations = COCO(train_coco_annotations)
val_coco_annotations = COCO(val_coco_annotations)

# Build a category-id -> category-name lookup from the train annotations.
categories = train_coco_annotations.loadCats(train_coco_annotations.getCatIds())
category_id_to_name = {}
for category in categories:
    category_id_to_name[category['id']] = category['name']
print(category_id_to_name)

loading annotations into memory...
Done (t=18.13s)
creating index...
index created!
loading annotations into memory...
Done (t=0.83s)
creating index...
index created!
{1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus', 7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light', 11: 'fire hydrant', 13: 'stop sign', 14: 'parking meter', 15: 'bench', 16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear', 24: 'zebra', 25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag', 32: 'tie', 33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard', 37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove', 41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle', 46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon', 51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange', 56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut', 61: 'cake', 62: 'chair', 63: 'c

Load the fMRI responses and save each fMRI trial to its own .npy file.
This only needs to be run once.

In [None]:
# Split every session's beta file into one .npy file per trial so that
# individual trials can later be loaded without reading a whole session.
for subj in SUBJECT:
    print(f"processing subject {subj}")
    for session in range(1, NUM_SESSION+1):
        response_file = os.path.join(
            DATA_DIR,
            f"sub-{subj:02d}",
            PREP_TYPE,
            f"{PREP_NAME}_session{session:02d}.hdf5")
        identifier = f"sub-{subj:02d}_ses-{session:02d}"
        print(identifier)
        # Read the full session dataset into memory, then close the file.
        with h5py.File(response_file, 'r') as f:
            responses = f[PREP_NAME][:]
        print(f"loaded {response_file}")
        # Trials go into a per-session folder next to the source HDF5 file.
        save_dir = os.path.join(os.path.dirname(response_file), identifier)
        os.makedirs(save_dir, exist_ok=True)
        # The first axis indexes trials; the file suffix is the 0-based trial position.
        for i, trial in enumerate(responses):
            np.save(os.path.join(save_dir, f"{identifier}_res-{i:03d}.npy"), trial)
    break # break after first subject

Alignment: [fmri, image, coco_id, nsd_id, label, coco_mode, sub_id]
fMRI: 3 repetitions [3, x, y, z], float16
image: RGB format [h, w, c], uint8
coco_id: unique id of the natural stimulus in the COCO dataset, int
nsd_id: unique id of the natural stimulus in NSD, int
coco_mode: train/test mode of COCO, str
sub_id: unique id of the subject in NSD, int

In [5]:
# response_dict[subject][nsdId] -> trial indices, cocoId and row position of
# every stimulus that the subject has a non-zero entry for.
response_dict = {}
# Report the smallest and largest image id present in the train annotation index.
train_image_ids = train_coco_annotations.imgToAnns.keys()
min_cocoId = min(train_image_ids)
max_cocoId = max(train_image_ids)
print(f"min_cocoId: {min_cocoId}, max_cocoId: {max_cocoId}")

for subject_id in SUBJECT:
    print(f"processing subject {subject_id}")
    subject_entries = {}
    response_dict[subject_id] = subject_entries
    shown_column = f"subject{subject_id}"
    repetition_columns = [f"subject{subject_id}_rep{rep}" for rep in range(0, 3)]
    for index, row in stimuli_info.iterrows():
        # A zero in the subject column means there is nothing to record for this row.
        if row[shown_column] == 0:
            continue
        subject_entries[row['nsdId']] = {
            'response_indicies': [row[column] for column in repetition_columns],
            'cocoId': row['cocoId'],
            'coco_index': index,
        }

min_cocoId: 9, max_cocoId: 581929
processing subject 1
processing subject 2
processing subject 5
processing subject 7


Mindeye setup

In [6]:
def get_huggingface_urls(commit='main',subj=1):
    """Return the (train, val, test) shard URL patterns for one subject.

    The train and test patterns contain an unexpanded brace range
    (``{0..17}`` / ``{0..1}``); the val pattern names a single shard.

    Args:
        commit: revision segment inserted after ``resolve/`` in the URL.
        subj: subject number, appended after the literal ``subj0``.

    Returns:
        Tuple ``(train_url, val_url, test_url)`` of strings.
    """
    split_root = (
        "https://huggingface.co/datasets/pscotti/naturalscenesdataset/resolve/"
        + commit
        + "/webdataset_avg_split"
    )
    train_url = "{}/train/train_subj0{}_".format(split_root, subj) + "{0..17}.tar"
    val_url = "{}/val/val_subj0{}_0.tar".format(split_root, subj)
    test_url = "{}/test/test_subj0{}_".format(split_root, subj) + "{0..1}.tar"
    return train_url, val_url, test_url

In [7]:
import braceexpand
from tqdm import tqdm
import requests
import webdataset as wds
import random

Download dataset

In [8]:
# Expand the brace patterns into explicit shard URLs for subject 1.
train_url, val_url, test_url = get_huggingface_urls("main",1)
train_url = list(braceexpand.braceexpand(train_url))
val_url = list(braceexpand.braceexpand(val_url))
test_url = list(braceexpand.braceexpand(test_url))
print(len(train_url), len(val_url), len(test_url))

MINDEYE_NSD_DIR = os.path.join(DATA_DIR, "mindeye")
os.makedirs(MINDEYE_NSD_DIR, exist_ok=True)


def _shard_destination(url, dest_dir):
    """Return the local path for `url`: its last path segment inside `dest_dir`."""
    return os.path.join(dest_dir, url.rsplit('/', 1)[-1])


def _download_shard(url, dest_dir, chunk_size=1 << 20, timeout=60):
    """Stream one shard to disk and return its local path.

    The body is written to a ``.part`` file and renamed only after the
    download finished, so an interrupted run never leaves a truncated file
    under the final name.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    destination = _shard_destination(url, dest_dir)
    partial_path = destination + ".part"
    print(f"\nDownloading {url} to {destination}...")
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(partial_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=chunk_size):
                file.write(chunk)
    os.replace(partial_path, destination)
    return destination


# Check every shard individually (not just the first train shard) so that a
# previously interrupted run is completed instead of being silently skipped.
for split_name, split_urls in (("train", train_url), ("val", val_url), ("test", test_url)):
    missing_urls = [
        url for url in split_urls
        if not os.path.exists(_shard_destination(url, MINDEYE_NSD_DIR))
    ]
    if not missing_urls:
        continue
    print(f"Downloading {split_name} data...")
    for url in tqdm(missing_urls):
        _download_shard(url, MINDEYE_NSD_DIR)

18 1 2


In [9]:
batch_size = 1
num_worker_batches = 4
num_train = 8559 + 300
voxels_key="nsdgeneral.npy"
to_tuple=["voxels", "images", "coco"]


def _make_split_dataset(urls):
    """Build the decoding pipeline shared by the train, val and test shards.

    Each yielded batch is a ``(voxels, images, coco)`` tuple, in the order
    given by `to_tuple`, holding up to `batch_size` samples.
    """
    return (
        wds.WebDataset(urls, resampled=False, cache_dir=MINDEYE_NSD_DIR)
        .decode("torch")
        .rename(images="jpg;png", voxels=voxels_key, trial="trial.npy", coco="coco73k.npy", reps="num_uniques.npy")
        .to_tuple(*to_tuple)
        .batched(batch_size, partial=True)
    )


train_data = _make_split_dataset(train_url)
val_data = _make_split_dataset(val_url)
test_data = _make_split_dataset(test_url)

# nsdId -> cocoId lookup built once, instead of filtering the whole table for
# every sample. Keeping the first row per nsdId mirrors the `.values[0]`
# selection of the per-sample filter it replaces.
_unique_nsd_rows = stimuli_info.drop_duplicates(subset="nsdId", keep="first")
nsd_to_coco = dict(zip(_unique_nsd_rows["nsdId"].tolist(), _unique_nsd_rows["cocoId"].to_numpy()))


def _collect_coco_ids(dataset):
    """Return the cocoId of the first sample of every batch in `dataset`.

    The third element of each batch is used as an nsdId and only its
    ``[0][0]`` entry is read (with batch_size = 1 there is one sample per batch).

    Raises:
        KeyError: if an id from the dataset is absent from `stimuli_info`.
    """
    coco_ids = []
    for _voxels, _images, nsd_ids in dataset:
        coco_ids.append(nsd_to_coco[int(nsd_ids[0][0])])
    return coco_ids


print(f"start processing train data")
train_coco_ids = _collect_coco_ids(train_data)

print(f"start processing val data")
val_coco_ids = _collect_coco_ids(val_data)

print(f"start processing test data")
test_coco_ids = _collect_coco_ids(test_data)

print(f"num of train coco ids: {len(train_coco_ids)}, num of val coco ids: {len(val_coco_ids)}, num of test coco ids: {len(test_coco_ids)}")


start processing train data
start processing val data
start processing test data
num of train coco ids: 8559, num of val coco ids: 300, num of test coco ids: 982


In [10]:
# Show the filtered stimulus table as this cell's output.
stimuli_info

Unnamed: 0.1,Unnamed: 0,cocoId,cocoSplit,cropBox,loss,nsdId,flagged,BOLD5000,shared1000,subject1,...,subject1_rep2,subject2_rep0,subject2_rep1,subject2_rep2,subject5_rep0,subject5_rep1,subject5_rep2,subject7_rep0,subject7_rep1,subject7_rep2
0,0,532481,val2017,"(0, 0, 0.1671875, 0.1671875)",0.100000,0,False,False,False,0,...,0,0,0,0,0,0,0,0,0,0
1,1,245764,val2017,"(0, 0, 0.125, 0.125)",0.000000,1,False,False,False,0,...,0,0,0,0,0,0,0,13985,14176,28603
2,2,385029,val2017,"(0, 0, 0.125, 0.125)",0.000000,2,False,False,False,0,...,0,0,0,0,0,0,0,0,0,0
3,3,311303,val2017,"(0, 0, 0.16640625, 0.16640625)",0.125000,3,False,False,False,0,...,0,0,0,0,0,0,0,0,0,0
4,4,393226,val2017,"(0, 0, 0.125, 0.125)",0.133333,4,False,False,False,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
72995,72995,518071,train2017,"(0, 0, 0.125, 0.125)",0.000000,72995,False,False,False,0,...,0,0,0,0,0,0,0,0,0,0
72996,72996,255930,train2017,"(0, 0, 0.125, 0.125)",0.125000,72996,False,False,False,0,...,0,0,0,0,0,0,0,0,0,0
72997,72997,255934,train2017,"(0, 0, 0.1, 0.1)",0.000000,72997,False,False,False,0,...,0,0,0,0,0,0,0,0,0,0
72998,72998,518080,train2017,"(0.125, 0.125, 0, 0)",0.000000,72998,False,False,False,0,...,0,0,0,0,0,0,0,5585,11846,14495


In [12]:
num_train, num_val, num_test = 0, 0, 0

# Divisor that turns a trial index spanning all sessions into a
# (session, position within that session's saved trial files) pair.
TRIALS_PER_SESSION = 750


def _load_trial_response(subject_id, response_index):
    """Load the saved .npy volume of one trial for one subject.

    Index 1 maps to session 1 / file ``res-000``, index 750 to session 1 /
    file ``res-749`` and index 751 to session 2 / file ``res-000``.
    """
    session_offset, trial_offset = divmod(response_index - 1, TRIALS_PER_SESSION)
    identifier = f"sub-{subject_id:02d}_ses-{session_offset + 1:02d}"
    response_file = os.path.join(
        DATA_DIR,
        f"sub-{subject_id:02d}",
        PREP_TYPE,
        identifier,
        f"{identifier}_res-{trial_offset:03d}.npy")
    return np.load(response_file) # [x, y, z]


# cocoId -> split name with constant-time lookup. Splits are inserted in
# reverse priority so that, should an id occur in several lists, the result
# matches a train -> val -> test check order.
split_of_coco_id = {}
for split_name, split_ids in (("test", test_coco_ids), ("val", val_coco_ids), ("train", train_coco_ids)):
    for split_coco_id in split_ids:
        split_of_coco_id[int(split_coco_id)] = split_name
split_counts = {"train": 0, "val": 0, "test": 0}

for subject_id in response_dict.keys():
    print(f"processing subject {subject_id}")
    subject_entries = response_dict[subject_id]
    # Sanity check: total number of trial indices recorded for this subject.
    response_array = np.concatenate(
        [entry['response_indicies'] for entry in subject_entries.values()], axis=0)
    assert len(response_array) == 30000, f"response_array length is {len(response_array)}"

    for nsd_id, entry in subject_entries.items():
        response_indicies = entry['response_indicies']
        coco_id = entry['cocoId']

        # Decide the split first so that no trial file is read for a sample
        # that ends up being skipped.
        sub_dir = split_of_coco_id.get(int(coco_id))
        if sub_dir is None:
            print(f"coco_id {coco_id} is not in train_coco_ids, val_coco_ids, or test_coco_ids")
            continue
        split_counts[sub_dir] += 1

        # Stack the repetitions into one array: [N, x, y, z]
        fmri = np.stack(
            [_load_trial_response(subject_id, response_index) for response_index in response_indicies],
            axis=0)

        # Category names come from the train annotations, falling back to the
        # val annotations when the train index has nothing for this image.
        coco_annotations = train_coco_annotations.imgToAnns[coco_id]
        if len(coco_annotations) == 0:
            coco_annotations = val_coco_annotations.imgToAnns[coco_id]
        text_annotations = [category_id_to_name[ann['category_id']] for ann in coco_annotations]

        sample_dict = {
            'fmri': fmri,
            'cocoId': coco_id,
            'label': text_annotations,
            # the image is picked by the row position stored with the entry
            'image': images[entry['coco_index']],
            'response_indicies': response_indicies,
            'nsdId': nsd_id,
            'subjectId': subject_id,
        }

        # np.save stores the dict as a pickled object array, so it has to be
        # read back with np.load(..., allow_pickle=True).
        save_dir = os.path.join(OUTPUT_DIR, f"sub-{subject_id:02d}", sub_dir)
        os.makedirs(save_dir, exist_ok=True)
        np.save(os.path.join(save_dir, f"sample_dict_{nsd_id}.npy"), sample_dict)

    break # break after first subject
num_train, num_val, num_test = split_counts["train"], split_counts["val"], split_counts["test"]
print(f"num_train: {num_train}, num_val: {num_val}, num_test: {num_test}")

processing subject 1


coco_id 190756 is not in train_coco_ids, val_coco_ids, or test_coco_ids
coco_id 111036 is not in train_coco_ids, val_coco_ids, or test_coco_ids
coco_id 512194 is not in train_coco_ids, val_coco_ids, or test_coco_ids
coco_id 50165 is not in train_coco_ids, val_coco_ids, or test_coco_ids
coco_id 106389 is not in train_coco_ids, val_coco_ids, or test_coco_ids
coco_id 524470 is not in train_coco_ids, val_coco_ids, or test_coco_ids
coco_id 524486 is not in train_coco_ids, val_coco_ids, or test_coco_ids
coco_id 3348 is not in train_coco_ids, val_coco_ids, or test_coco_ids
coco_id 527643 is not in train_coco_ids, val_coco_ids, or test_coco_ids
coco_id 528116 is not in train_coco_ids, val_coco_ids, or test_coco_ids
coco_id 44592 is not in train_coco_ids, val_coco_ids, or test_coco_ids
coco_id 5962 is not in train_coco_ids, val_coco_ids, or test_coco_ids
coco_id 531828 is not in train_coco_ids, val_coco_ids, or test_coco_ids
coco_id 8998 is not in train_coco_ids, val_coco_ids, or test_coco_ids
