In [1]:
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor as ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image

model = VisionEncoderDecoderModel.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTImageProcessor.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained(
    "nlpconnect/vit-gpt2-image-captioning")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


max_length = 16
num_beams = 4
gen_kwargs = {
    "max_length": max_length, 
    "num_beams": num_beams, 
    "return_dict_in_generate": True, 
    "output_hidden_states": True
}

  from .autonotebook import tqdm as notebook_tqdm
Downloading: 100%|██████████| 4.50k/4.50k [00:00<00:00, 2.20MB/s]
Downloading: 100%|██████████| 937M/937M [00:33<00:00, 29.4MB/s] 
Downloading: 100%|██████████| 228/228 [00:00<00:00, 112kB/s]
Downloading: 100%|██████████| 241/241 [00:00<00:00, 124kB/s]
Downloading: 100%|██████████| 779k/779k [00:00<00:00, 1.07MB/s]
Downloading: 100%|██████████| 446k/446k [00:00<00:00, 595kB/s] 
Downloading: 100%|██████████| 1.29M/1.29M [00:00<00:00, 1.40MB/s]
Downloading: 100%|██████████| 120/120 [00:00<00:00, 60.9kB/s]


In [3]:
import re
import os
import os.path as osp

from PIL import Image
import numpy as np
from torch.utils.data import Dataset


class Algonauts2023Raw(Dataset):
    """
        Load original data for Algonauts2023 dataset
    """

    def __init__(self, data_path: str, hemisphere: str = "L", transform=None, train: bool = True, return_img_ids: bool = False):
        """
            Initialize a torch.utils.data.Dataset object for algonauts2023 dataset

            Args:
                data_path,              str, path to the algonauts2023 dataset which contains only ONE subject
                hemisphere,             str, select which hemisphere of the brain to be modeled
                                            can ONLY select "L" or "R"
                                            and ONLY applicable when train is TRUE
                transform,              torchvision.transform methods, apply normalization to the dataset
                train,                  bool, training data will be loaded if True. Test data otherwise.
                return_img_ids,         bool, return image ids, only used for feature extraction
        """

        # collect data paths
        path_struct = osp.join(data_path, "{}_split")
        self.dataset = list()
        self.transform = transform
        self.train = train
        self.return_img_ids = return_img_ids

        if train:
            shared_path = osp.join(
                path_struct.format("training"), "training_{}")
            if hemisphere == "L":
                self.fmri = np.load(osp.join(shared_path.format(
                    "fmri"), "lh_training_fmri.npy"))
            elif hemisphere == "R":
                self.fmri = np.load(osp.join(shared_path.format(
                    "fmri"), "rh_training_fmri.npy"))

            self.feature_path = shared_path.format("images")

        else:
            self.feature_path = osp.join(
                path_struct.format("test"), "test_images")

        self.dataset = list(os.listdir(self.feature_path))

        # sorted in ascending order if not train set
        if not train:
            self.dataset = sorted(self.dataset, key=lambda x: int(
                re.findall("\d{4}", x)[0]) - 1)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index: int):
        """
            Load designated sample

            Arg:
                index,          int, sample id

            Returns:
                image,          np.ndarray, the 3d numpy array of the image used to retrive fmri data
                fmri,           np.ndarray, the hemisphere FMRI data generated by the image
                img_ids,        str, image ids, only used for feature extraction
        """

        feat_file = self.dataset[index]
        feat_idx = int(re.findall("\d{4}", feat_file)[0]) - 1
        img = Image.open(osp.join(self.feature_path, feat_file))
        if img.mode != "RGB":
            img = img.convert(mode="RGB")

        if self.transform:
            img = self.transform(img)

        if self.return_img_ids:
            return img, self.fmri[feat_idx] if self.train else 0, feat_file
        else:
            return img, self.fmri[feat_idx] if self.train else 0


In [26]:
subj = "subj08"
train = True
tpe = "training" if train else "test"
path = "/mnt/data/{}".format(subj)
save = "/mnt/data/{}/{}_split/{}_features/vit-gpt2-image-captioning/decoder-raw".format(
    subj, tpe, tpe)


In [27]:
dset = Algonauts2023Raw(path, train=train, return_img_ids=True)

In [28]:
import os
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from tqdm import tqdm
import numpy as np
from PIL import Image


def func(x):
    
    imgs = list()
    ids = list()
    for img, _, id in x:
        imgs.append(img)
        ids.append(id)
        
    return imgs, ids


model.eval()

ids = list()
features = list()
for img, id in tqdm(DataLoader(dset, batch_size=64, num_workers=12, collate_fn=func)):
    
    pixel_values = feature_extractor(
        images=img, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    feats = model.generate(pixel_values, **gen_kwargs)
    feats = [x for x in feats.encoder_hidden_states]
    feats = torch.stack(feats[-4:]).cpu()
    
    features.append(feats)
    ids.append(id)


100%|██████████| 138/138 [18:01<00:00,  7.84s/it]


In [None]:
for feats, id in tqdm(zip(features, ids)):

    for i in range(len(id)):
        hs = feats[:,i]
        hs = hs.numpy().astype(np.float32)

        if not os.path.isdir(os.path.join(save)):
            os.makedirs(save)

        np.save(os.path.join(save, id[i].split(".")[0]+".npy"), hs)

129it [06:08,  2.90s/it]