In [1]:
import json
import random
import copy

import time 
import numpy as np
import matplotlib.pyplot as plt
import cv2
import h5py
import os
import torch

from tqdm import tqdm
from PIL import Image
from sklearn.preprocessing import StandardScaler

import webdataset as wds
import sys

from utils import seed_everything

%load_ext autoreload
%autoreload 2

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device:", device)

device: cuda


In [2]:
# Feature extractors used to embed every stimulus image in the export cell.
#
# CLIP ViT-L/14 via the `Clipper` wrapper from models.py. Naming convention:
# the plain name returns hidden states, the `_last` variant returns the final
# embedding (see the printed shapes in the export cell: 257x768 vs 768).
import clip
from torchvision import transforms
from models import *  # provides Clipper / OpenClipper

clip_extractor = Clipper("ViT-L/14", device=device, hidden_state=True, norm_embs=False)
clip_extractor_last = Clipper("ViT-L/14", device=device, hidden_state=False, norm_embs=False)

# OpenCLIP ViT-H/14, following the same convention as the CLIP pair above
# (plain name = hidden states, `_last` = final embedding).
# NOTE(review): neither OpenCLIP extractor is used by the export cell; they only
# occupy device memory. Drop them if they stay unused.
openclip_extractor = OpenClipper('ViT-H-14', device=device, hidden_state=True, norm_embs=False)
openclip_extractor_last = OpenClipper('ViT-H-14', device=device, hidden_state=False, norm_embs=False)

# ImageBind is imported from a local checkout rather than an installed package.
import sys
IMAGEBIND_ROOT = "/fsx/proj-medarc/fmri/ImageBind"
sys.path.insert(0, IMAGEBIND_ROOT)
sys.path.insert(0, f"{IMAGEBIND_ROOT}/models")
import data
import imagebind_model
from imagebind_model import ModalityType

imagebind = imagebind_model.imagebind_huge(pretrained=True)
imagebind.eval().requires_grad_(False)
imagebind.to(device)

# Vision preprocessing for ImageBind. The Normalize statistics are defined for
# float images scaled to [0, 1], so callers must rescale uint8 pixels first.
imagebind_transform = transforms.Compose(
    [transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711))])

# Copy of ImageBind with the vision head and postprocessor replaced by identity,
# so the 'vision' output is the per-token trunk output instead of the final
# embedding. Use torch.nn explicitly rather than relying on the star import.
imagebind_hidden = copy.deepcopy(imagebind)
imagebind_hidden.modality_heads.vision = torch.nn.Identity()
imagebind_hidden.modality_postprocessors.vision = torch.nn.Identity()
imagebind_hidden.eval().requires_grad_(False)
imagebind_hidden.to(device)

# FCN-ResNet50 segmentation model, truncated to return the backbone's max-pool
# activation as an intermediate feature map.
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
fcn_weights = FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
fcn_model = create_feature_extractor(fcn_resnet50(weights=fcn_weights), return_nodes=["backbone.maxpool"]).to(device)
fcn_model.eval().requires_grad_(False);  # trailing ';' suppresses the module repr

ViT-L/14 cuda
ViT-L/14 cuda
ViT-H-14 cuda
ViT-H-14 cuda
THIS IS NOT WORKING CURRENTLY!




Note that the fMRI data have already been z-score normalized within each session and averaged across repeated presentations of the same image! The data are in fsaverage surface space (the "challenge" space), which may be equivalent to the nsdgeneral region.

In [3]:
# Export each subject's Algonauts data to webdataset shards under <base_path>/webdataset:
#   val/<subj>_0.tar      first `samples_per_shard` (shuffled) training samples
#   train/<subj>_<k>.tar  remaining training samples, `samples_per_shard` per shard
#   test/<subj>_<k>.tar   test images (no fMRI)
#   metadata_<subj>.json  sample counts
# Each sample holds the fMRI vector ("vert", train/val only), the index of its
# image in the sorted file list ("trial"), several image embeddings, and the raw
# uint8 CHW image.

# ImageNet statistics expected by the torchvision FCN backbone on [0, 1] input.
fcn_normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))


def _sorted_pngs(img_dir):
    """Return the .png file names in `img_dir`, sorted by name.

    `os.listdir` order is arbitrary, while the Algonauts fMRI arrays are
    row-aligned with the sorted image list, so sorting is what pairs them.
    """
    return sorted(f for f in os.listdir(img_dir) if f.endswith('.png'))


def _load_images(img_dir, names):
    """Load the images `names` from `img_dir` and stack them into one array.

    Assumes all PNGs share one size and mode so they stack into (N, H, W, C).
    """
    frames = []
    for name in names:
        with Image.open(os.path.join(img_dir, name)) as im:
            frames.append(np.array(im))
    return np.array(frames)


def _embed_batch(imgs_chw):
    """Compute every image embedding for one uint8 (B, C, H, W) batch.

    Pixels are rescaled to [0, 1] first, because the preprocessing used here
    normalizes with mean/std statistics defined for that range.
    NOTE(review): assumes Clipper.embed_image also expects [0, 1] input (it is
    given the same tensor as ImageBind). Confirm in models.py.

    Returns a dict mapping webdataset field name (without ".npy") to a tensor.
    """
    imgs = torch.from_numpy(np.ascontiguousarray(imgs_chw)).to(device).float().div(255)
    with torch.no_grad():
        ib_inputs = {ModalityType.VISION: imagebind_transform(imgs)}
        return {
            "clip_vitl_hidden": clip_extractor.embed_image(imgs),
            "clip_vitl_final": clip_extractor_last.embed_image(imgs),
            "imagebind_hidden": imagebind_hidden(ib_inputs)['vision'],
            "imagebind_final": imagebind(ib_inputs)['vision'],
            "fcn_maxpool": fcn_model(fcn_normalize(imgs))['backbone.maxpool'],
        }


def _write_shard(path, first_key, imgs_chw, embs, trial_ids, fmri=None):
    """Write one webdataset tar at `path`.

    Sample keys are consecutive starting at `first_key`. "vert.npy" is only
    written when `fmri` is given (train/val splits).
    """
    # One device->host copy per batch instead of one per sample.
    embs_np = {name: t.detach().cpu().numpy() for name, t in embs.items()}
    with wds.TarWriter(path) as sink:  # closes the tar even if a write fails
        for j in range(len(imgs_chw)):
            sample = {"__key__": "sample%09d" % (first_key + j)}
            if fmri is not None:
                sample["vert.npy"] = fmri[j]
            sample["trial.npy"] = np.array([trial_ids[j]])
            for name, arr in embs_np.items():
                sample[name + ".npy"] = arr[j]
            sample["image.npy"] = imgs_chw[j]
            sink.write(sample)


for sub in [1]:  # [1,2,3,4,5,6,7,8]; subjects 1, 2, 5, 7 completed all trials
    subj = f'subj0{sub}'
    print(subj)

    seed_everything(0)

    samples_per_shard = 300  # samples per tar

    base_path = f"/fsx/proj-medarc/fmri/natural-scenes-dataset/algonauts_data/dataset/{subj}"
    wds_path = base_path + "/webdataset"
    os.makedirs(wds_path, exist_ok=True)
    print("base_path:", base_path)

    # Training images and fMRI (rows aligned with the sorted image list).
    train_img_dir = base_path + "/training_split/training_images"
    train_png_files = _sorted_pngs(train_img_dir)
    num_train = len(train_png_files)
    print("# training: ", num_train)
    train_imgs = _load_images(train_img_dir, train_png_files)

    lh_fmri = np.load(base_path + "/training_split/training_fmri/lh_training_fmri.npy")
    rh_fmri = np.load(base_path + "/training_split/training_fmri/rh_training_fmri.npy")
    train_fmri = np.hstack((lh_fmri, rh_fmri))
    assert len(train_fmri) == num_train, "fMRI rows must match the number of training images"

    # Shuffle images and fMRI with the same permutation so they stay paired.
    permutation = np.random.permutation(num_train)
    train_imgs = train_imgs[permutation]
    train_fmri = train_fmri[permutation]

    # The first shard becomes the validation split. Z-score each vertex using
    # statistics from the training rows only, so validation does not leak in.
    n_val = min(samples_per_shard, num_train)
    fit_rows = train_fmri[n_val:] if num_train > n_val else train_fmri
    train_fmri = StandardScaler().fit(fit_rows).transform(train_fmri)

    # Test images (no fMRI available).
    test_img_dir = base_path + "/test_split/test_images"
    test_png_files = _sorted_pngs(test_img_dir)
    num_test = len(test_png_files)
    print("# test: ", num_test)
    test_imgs = _load_images(test_img_dir, test_png_files)

    # Train & val: shard 0 -> val/, the rest -> train/ (renumbered from 0).
    # Keys run continuously across val and train shards.
    os.makedirs(wds_path + "/val", exist_ok=True)
    os.makedirs(wds_path + "/train", exist_ok=True)
    for idx, i in enumerate(tqdm(range(0, num_train, samples_per_shard))):
        cur_imgs = np.moveaxis(train_imgs[i:i+samples_per_shard], -1, 1)  # NHWC -> NCHW
        cur_samps = train_fmri[i:i+samples_per_shard]
        embs = _embed_batch(cur_imgs)
        if idx == 0:  # val
            print("vert", cur_samps.shape)
            for name, emb in embs.items():
                print(name, emb.shape)
            print("image", cur_imgs.shape)
            shard_path = wds_path + f"/val/{subj}_{idx}.tar"
        else:
            shard_path = wds_path + f"/train/{subj}_{idx-1}.tar"
        # "trial" records the pre-shuffle index, i.e. the position in the sorted file list.
        _write_shard(shard_path, first_key=i, imgs_chw=cur_imgs, embs=embs,
                     trial_ids=permutation[i:i+samples_per_shard], fmri=cur_samps)

    # Test: unshuffled, so "trial" is simply the position in the sorted file list.
    os.makedirs(wds_path + "/test", exist_ok=True)
    for idx, i in enumerate(tqdm(range(0, num_test, samples_per_shard))):
        cur_imgs = np.moveaxis(test_imgs[i:i+samples_per_shard], -1, 1)
        embs = _embed_batch(cur_imgs)
        _write_shard(wds_path + f"/test/{subj}_{idx}.tar", first_key=i, imgs_chw=cur_imgs,
                     embs=embs, trial_ids=range(i, i + len(cur_imgs)))

    # So that we have the info, write the split sizes to a metadata file.
    with open(wds_path + f"/metadata_{subj}.json", "w") as f:
        json.dump({
            'train': num_train - n_val,
            'val': n_val,
            'test': num_test,
            'lh_samps': lh_fmri.shape[-1],
            'rh_samps': rh_fmri.shape[-1],
            'totals': num_train + num_test,
        }, f, indent=4)

subj01
base_path: /fsx/proj-medarc/fmri/natural-scenes-dataset/algonauts_data/dataset/subj01
# training:  9841
# test:  159


  0%|          | 0/33 [00:00<?, ?it/s]



vert (300, 39548)
clip_vitl_hidden torch.Size([300, 257, 768])
clip_vitl_final torch.Size([300, 768])
imagebind_hidden torch.Size([300, 257, 1280])
imagebind_final torch.Size([300, 1024])
fcn_maxpool torch.Size([300, 64, 107, 107])
image (300, 3, 425, 425)


  0%|          | 0/1 [00:00<?, ?it/s]

# Upload to the Hugging Face Hub

In [37]:
from huggingface_hub import upload_file, notebook_login

# Interactive Hugging Face login so that the uploads in the next cell are authenticated.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [38]:
# Push each subject's exported shards and metadata to the HF dataset repo.
# Shard lists are discovered on disk rather than hard-coded, so this works for
# any subject and any shard count produced by the export cell.
import glob
import os
import re
from tqdm import tqdm

DATASET_ROOT = "/fsx/proj-medarc/fmri/natural-scenes-dataset/algonauts_data/dataset"
HF_REPO_ID = "pscotti/naturalscenesdataset"


def _shard_index(path):
    """Numeric shard index of a '<subj>_<k>.tar' path, so shard 10 sorts after 9."""
    return int(re.search(r"_(\d+)\.tar$", path).group(1))


def _upload(local_path, path_in_repo):
    """Upload one local file into the dataset repo at `path_in_repo`."""
    upload_file(
        path_or_fileobj=local_path,
        path_in_repo=path_in_repo,
        repo_id=HF_REPO_ID,
        repo_type="dataset")


for sub in [0]:  # np.arange(8)
    subj = f'subj0{sub+1}'
    print(f"{subj}...")
    wds_dir = f"{DATASET_ROOT}/{subj}/webdataset"

    for split in ("train", "val", "test"):
        shards = sorted(glob.glob(f"{wds_dir}/{split}/{subj}_*.tar"), key=_shard_index)
        if not shards:
            # Fail loudly rather than silently uploading an incomplete dataset.
            raise FileNotFoundError(f"no {split} shards found in {wds_dir}/{split}")
        for shard in tqdm(shards, desc=split):
            _upload(shard, f"algonauts/{split}/{os.path.basename(shard)}")

    _upload(f"{wds_dir}/metadata_{subj}.json", f"algonauts/metadata_{subj}.json")

    print('done!')

subj01...


  0%|                                                    | 0/32 [00:00<?, ?it/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_0.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

  3%|█▍                                          | 1/32 [00:34<17:37, 34.11s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_1.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

  6%|██▊                                         | 2/32 [01:06<16:31, 33.06s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_2.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

  9%|████▏                                       | 3/32 [01:38<15:45, 32.62s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_3.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 12%|█████▌                                      | 4/32 [02:11<15:19, 32.85s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_4.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 16%|██████▉                                     | 5/32 [02:44<14:43, 32.73s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_5.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 19%|████████▎                                   | 6/32 [03:19<14:31, 33.54s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_6.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 22%|█████████▋                                  | 7/32 [03:57<14:37, 35.10s/it]

subj01_7.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

 25%|███████████                                 | 8/32 [04:31<13:49, 34.55s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_8.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 28%|████████████▍                               | 9/32 [05:08<13:35, 35.46s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_9.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 31%|█████████████▍                             | 10/32 [05:45<13:11, 36.00s/it]

subj01_10.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

 34%|██████████████▊                            | 11/32 [06:22<12:38, 36.11s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_11.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 38%|████████████████▏                          | 12/32 [06:56<11:54, 35.75s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_12.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 41%|█████████████████▍                         | 13/32 [07:29<11:02, 34.88s/it]

subj01_13.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

 44%|██████████████████▊                        | 14/32 [08:02<10:17, 34.28s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_14.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 47%|████████████████████▏                      | 15/32 [08:35<09:32, 33.69s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_15.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 50%|█████████████████████▌                     | 16/32 [09:13<09:20, 35.05s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_16.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 53%|██████████████████████▊                    | 17/32 [09:47<08:41, 34.79s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_17.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 56%|████████████████████████▏                  | 18/32 [10:22<08:07, 34.81s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_18.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 59%|█████████████████████████▌                 | 19/32 [10:57<07:32, 34.83s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_19.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 62%|██████████████████████████▉                | 20/32 [11:31<06:56, 34.70s/it]

subj01_20.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

 66%|████████████████████████████▏              | 21/32 [12:05<06:19, 34.52s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_21.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 69%|█████████████████████████████▌             | 22/32 [12:42<05:51, 35.11s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_22.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 72%|██████████████████████████████▉            | 23/32 [13:20<05:25, 36.13s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_23.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 75%|████████████████████████████████▎          | 24/32 [13:53<04:42, 35.26s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_24.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 78%|█████████████████████████████████▌         | 25/32 [14:26<04:01, 34.43s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_25.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 81%|██████████████████████████████████▉        | 26/32 [14:58<03:22, 33.78s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_26.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 84%|████████████████████████████████████▎      | 27/32 [15:30<02:46, 33.25s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_27.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 88%|█████████████████████████████████████▋     | 28/32 [16:03<02:12, 33.22s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_28.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 91%|██████████████████████████████████████▉    | 29/32 [16:41<01:43, 34.42s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_29.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 94%|████████████████████████████████████████▎  | 30/32 [17:14<01:08, 34.16s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_30.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

 97%|█████████████████████████████████████████▋ | 31/32 [17:49<00:34, 34.29s/it]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_31.tar:   0%|          | 0.00/1.39G [00:00<?, ?B/s]

100%|███████████████████████████████████████████| 32/32 [18:21<00:00, 34.42s/it]


Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_0.tar:   0%|          | 0.00/1.73G [00:00<?, ?B/s]

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

subj01_0.tar:   0%|          | 0.00/890M [00:00<?, ?B/s]

done!
