In [123]:
'''End-to-end deep reconstruction: Data preparation scripts.
Originally developed by Guohua Shen.
This script creates LMDB data of fMRI and images for training of end-to-end deep reconstruction.
'''


import csv
import fnmatch
import glob
import os
from datetime import datetime

import PIL.Image
#import caffe
import lmdb
import numpy as np
#from scipy.misc import imresize

import bdpy
import cv2
import matplotlib.pyplot as plt

# parameters

In [112]:
# Path of the subject's fMRI data file (h5) that is opened with bdpy.BData below
fmri_file = './data/fmri/sub-01_perceptionNaturalImageTraining_original_VC.h5'
# ROI selection expression handed to BData.select()
roi_selector = 'ROI_VC = 1'
# Directory with the stimulus images, stored as '<label>.JPEG'
# (despite the name this is a folder, not a single file)
imagenet_file = './data/images/training'

# Main

In [106]:
# Load the fMRI h5 file into a bdpy BData object
fmri_data_bd = bdpy.BData(fmri_file)

In [110]:
# Extract ImageNet labels: second column of the 'Label' field, one float per sample
fmri_labels = fmri_data_bd.get('Label')[:, 1].flatten()

# Each float is printed with six decimals ('%f'); the integer part becomes the
# zero-padded 8-digit number after 'n', the decimal part the number after '_'.
fmri_labels = [
    'n%08d_%d' % (int(wnid), int(image_no))
    for wnid, image_no in (('%f' % value).split('.') for value in fmri_labels)
]
print(len(fmri_labels))

6000


In [111]:
# Get fMRI data in the ROI described by `roi_selector`
fmri_data = fmri_data_bd.select(roi_selector)
# Print the shape of the selected data as a sanity check
print(fmri_data.shape)

(6000, 11726)


In [154]:
# Inspect one sample: read the image whose file name is the label at position
# `item` and take the fMRI row at the same position.
# (`signle_*` is a typo for `single_*`; the names are kept unchanged.)
item = 10
signle_image = cv2.imread(imagenet_file + f'/{fmri_labels[item]}.JPEG')
signle_fmri = fmri_data[item]

# Show the shapes of the image array and of the fMRI vector
print(signle_image.shape)
print(signle_fmri.shape)

(500, 500, 3)
(11726,)


In [173]:
# NOTE(review): list.sort() reorders `fmri_labels` in place. After this cell
# fmri_labels[i] no longer matches row i of `fmri_data` (the pairing used when
# indexing both with the same `item`). Prefer sorted(fmri_labels) bound to a
# new name if the aligned list is still needed afterwards.
fmri_labels.sort()

In [174]:
fmri_labels

['n01518878_10042',
 'n01518878_10042',
 'n01518878_10042',
 'n01518878_10042',
 'n01518878_10042',
 'n01518878_12028',
 'n01518878_12028',
 'n01518878_12028',
 'n01518878_12028',
 'n01518878_12028',
 'n01518878_14075',
 'n01518878_14075',
 'n01518878_14075',
 'n01518878_14075',
 'n01518878_14075',
 'n01518878_14910',
 'n01518878_14910',
 'n01518878_14910',
 'n01518878_14910',
 'n01518878_14910',
 'n01518878_5958',
 'n01518878_5958',
 'n01518878_5958',
 'n01518878_5958',
 'n01518878_5958',
 'n01518878_7346',
 'n01518878_7346',
 'n01518878_7346',
 'n01518878_7346',
 'n01518878_7346',
 'n01518878_7579',
 'n01518878_7579',
 'n01518878_7579',
 'n01518878_7579',
 'n01518878_7579',
 'n01518878_8432',
 'n01518878_8432',
 'n01518878_8432',
 'n01518878_8432',
 'n01518878_8432',
 'n01639765_22407',
 'n01639765_22407',
 'n01639765_22407',
 'n01639765_22407',
 'n01639765_22407',
 'n01639765_32862',
 'n01639765_32862',
 'n01639765_32862',
 'n01639765_32862',
 'n01639765_32862',
 'n01639765_37122',


In [171]:
# Number of distinct image labels among all samples
len(set(fmri_labels))

1200

In [172]:
6000/1200

5.0

In [155]:
from torch.utils.data import Dataset
import bdpy
import cv2
import numpy as np

class CustomDataLoader(Dataset):
    """Dataset that pairs every fMRI sample with the image shown for it.

    Sample ``i`` is the tuple ``(image, fmri)`` where ``image`` is read from
    ``<imagenet_folder>/<label_i>.JPEG`` with ``cv2.imread`` (so it is in
    OpenCV's BGR channel order) and ``fmri`` is row ``i`` of the data selected
    by ``roi_selector``.

    Args:
        fmri_file: Path of the fMRI h5 file opened with ``bdpy.BData``.
        imagenet_folder: Directory that contains the ``<label>.JPEG`` images.
        roi_selector: Selection expression passed to ``BData.select``.
        transform: Optional callable applied to the loaded image only.

    Raises:
        ValueError: If the number of labels differs from the number of fMRI rows.
    """

    def __init__(self, fmri_file:str, imagenet_folder:str, roi_selector:str = 'ROI_VC = 1', transform = None):
        super(CustomDataLoader, self).__init__()

        self.imagenet_folder = imagenet_folder
        self.transform = transform

        # Load h5 file
        fmri_data_bd = bdpy.BData(fmri_file)

        # Get ImageNet labels (second column of 'Label') and convert them to file names
        raw_labels = fmri_data_bd.get('Label')[:, 1].flatten()
        self.fmri_labels = [self._label_to_name(value) for value in raw_labels]

        # Get fMRI data in the ROI
        self.fmri_data = fmri_data_bd.select(roi_selector)

        # __getitem__ indexes labels and fMRI rows with the same position, so
        # a length mismatch would silently pair the wrong data.
        if len(self.fmri_labels) != len(self.fmri_data):
            raise ValueError(
                'Number of labels (%d) does not match number of fMRI samples (%d)'
                % (len(self.fmri_labels), len(self.fmri_data)))

    @staticmethod
    def _label_to_name(label):
        """Convert a float label into its image name.

        The value is printed with six decimals; the integer part becomes the
        zero-padded 8-digit number after ``n`` and the decimal part the number
        after ``_`` (e.g. ``1518878.005958`` -> ``'n01518878_5958'``).
        """
        wnid, image_no = ('%f' % label).split('.')
        return 'n%08d_%d' % (int(wnid), int(image_no))

    def __getitem__(self, item):
        """Return ``(image, fmri)`` for position ``item``.

        Raises:
            FileNotFoundError: If the image file for the label does not exist.
            OSError: If the file exists but OpenCV cannot decode it.
        """
        image_path = os.path.join(self.imagenet_folder, self.fmri_labels[item] + '.JPEG')

        # cv2.imread returns None instead of raising, which would only surface
        # later as an obscure error, so fail here with the offending path.
        if not os.path.isfile(image_path):
            raise FileNotFoundError('Image file not found: %s' % image_path)
        image = cv2.imread(image_path)
        if image is None:
            raise OSError('Could not read image file: %s' % image_path)

        fmri = self.fmri_data[item]

        if self.transform is not None:
            image = self.transform(image)

        return image, fmri

    def __len__(self):
        """Number of samples, i.e. the number of labels."""
        return len(self.fmri_labels)

In [156]:
# Build the dataset for sub-01 with the default ROI selector and no image transform
dataset = CustomDataLoader(
    fmri_file='./data/fmri/sub-01_perceptionNaturalImageTraining_original_VC.h5',
    imagenet_folder='./data/images/training',
)

In [161]:
from torch.utils.data import DataLoader

# Wrap the dataset in a DataLoader that yields batches of 1200 samples.
# NOTE(review): no `transform` is set on the dataset, so batching presumably
# relies on every image having the same shape — confirm.
dataloader = DataLoader(dataset, batch_size=1200)

In [162]:
# Fetch only the first batch: x holds the images, y the fMRI data
for x, y in dataloader:
    break

# ETC

In [27]:
# Settings ---------------------------------------------------------------

# Image size: for image jittering the images are prepared larger than 227 x 227
img_size = (248, 248)

# fMRI data: one dictionary per subject
fmri_data_table = [
    {
        'subject': 'sub-01',
        'data_file': './data/fmri/sub-01_perceptionNaturalImageTraining_original_VC.h5',
        'roi_selector': 'ROI_VC = 1',
        'output_dir': './lmdb/sub-01',
    },
]

In [28]:
# Image data
# Directory that is searched for stimulus images
image_dir = './data/images/training'
# Glob pattern (relative to image_dir) that selects the image files
image_file_pattern = '*.JPEG'

In [64]:
# Work on the first subject entry of the table
sbj = fmri_data_table[0]

# Load the subject's fMRI data and take the raw label column (second column of 'Label')
fmri_data_bd = bdpy.BData(sbj['data_file'])
fmri_labels = fmri_data_bd.get('Label')[:, 1].flatten()

# fMRI labels: each float is printed with six decimals ('%f'); the integer part
# becomes the zero-padded 8-digit number after 'n', the decimal part the number after '_'.
fmri_labels = [
    'n%08d_%d' % (int(wnid), int(image_no))
    for wnid, image_no in (('%f' % value).split('.') for value in fmri_labels)
]

# Restrict the fMRI data to the subject's ROI
fmri_data = fmri_data_bd.select(sbj['roi_selector'])

In [69]:
fmri_data_bd.show_metadata()

| Key         | Description          |
|-------------|----------------------|
| VoxelData   | 1 = VoxelData        |
| Run         | 1 = Run              |
| Block       | 1 = Block            |
| Label       | 1 = Label            |
| image_index | Label stimulus_index |
| stimulus_id | Label stimulus_id    |
| voxel_x     | Voxel x coordinate   |
| voxel_y     | Voxel y coordinate   |
| voxel_z     | Voxel z coordinate   |
| ROI_V1      | 1 = ROI V1           |
| ROI_V2      | 1 = ROI V2           |
| ROI_V3      | 1 = ROI V3           |
| ROI_V4      | 1 = ROI V4           |
| ROI_LOC     | 1 = ROI LOC          |
| ROI_FFA     | 1 = ROI FFA          |
| ROI_PPA     | 1 = ROI PPA          |
| ROI_LVC     | 1 = ROI LVC          |
| ROI_HVC     | 1 = ROI HVC          |
| ROI_VC      | 1 = ROI VC           |


In [89]:
# Shape of what get('ROI_VC') returns; expected to be (6000, 11726)
fmri_data_bd.get('ROI_VC').shape

(6000, 11726)

In [72]:
fmri_data.shape

(6000, 11726)

In [62]:
fmri_labels

['n01639765_47681',
 'n12596148_6386',
 'n04254680_1190',
 'n01645776_9743',
 'n03472535_10217',
 'n02943871_3559',
 'n04373894_26002',
 'n04154565_19381',
 'n13111881_11914',
 'n03397947_2934',
 'n04442312_19042',
 'n02236241_6815',
 'n03623556_17722',
 'n02472293_13453',
 'n01704323_10031',
 'n03359137_31119',
 'n02692877_31838',
 'n04090263_7624',
 'n04123740_13353',
 'n02439033_9911',
 'n02226429_17165',
 'n02445715_13221',
 'n04555897_5017',
 'n03746005_2516',
 'n03209910_22305',
 'n04044716_12763',
 'n01963571_7996',
 'n03400231_21509',
 'n03445777_8313',
 'n03793489_13628',
 'n03646296_7086',
 'n02084071_14426',
 'n03512147_23',
 'n02055803_6095',
 'n02432983_9121',
 'n02885462_19412',
 'n03544143_3031',
 'n11978233_18962',
 'n07756951_14073',
 'n04284002_21913',
 'n01518878_5958',
 'n12582231_39329',
 'n02090827_9305',
 'n01855672_14199',
 'n03602883_15696',
 'n07734017_8706',
 'n04376876_28833',
 'n03345487_8958',
 'n02799175_15904',
 'n03187595_7260',
 'n02503517_10053',
 'n0

In [92]:
# List the image files (full paths). glob returns matches in arbitrary,
# filesystem-dependent order; sorting keeps the list, and the serial numbers
# later derived from the position in it, identical across runs and machines.
images_list = sorted(glob.glob(os.path.join(image_dir, image_file_pattern)))
images_list

['./data/images/training\\n01518878_10042.JPEG',
 './data/images/training\\n01518878_12028.JPEG',
 './data/images/training\\n01518878_14075.JPEG',
 './data/images/training\\n01518878_14910.JPEG',
 './data/images/training\\n01518878_5958.JPEG',
 './data/images/training\\n01518878_7346.JPEG',
 './data/images/training\\n01518878_7579.JPEG',
 './data/images/training\\n01518878_8432.JPEG',
 './data/images/training\\n01639765_22407.JPEG',
 './data/images/training\\n01639765_32862.JPEG',
 './data/images/training\\n01639765_37122.JPEG',
 './data/images/training\\n01639765_40261.JPEG',
 './data/images/training\\n01639765_44823.JPEG',
 './data/images/training\\n01639765_47681.JPEG',
 './data/images/training\\n01639765_48759.JPEG',
 './data/images/training\\n01639765_52902.JPEG',
 './data/images/training\\n01645776_10130.JPEG',
 './data/images/training\\n01645776_10758.JPEG',
 './data/images/training\\n01645776_8522.JPEG',
 './data/images/training\\n01645776_8879.JPEG',
 './data/images/training\\

In [94]:
# Image label (file name without directory and extension) --> file path
images_table = dict(
    (os.path.splitext(os.path.basename(path))[0], path)
    for path in images_list
)
images_table

{'n01518878_10042': './data/images/training\\n01518878_10042.JPEG',
 'n01518878_12028': './data/images/training\\n01518878_12028.JPEG',
 'n01518878_14075': './data/images/training\\n01518878_14075.JPEG',
 'n01518878_14910': './data/images/training\\n01518878_14910.JPEG',
 'n01518878_5958': './data/images/training\\n01518878_5958.JPEG',
 'n01518878_7346': './data/images/training\\n01518878_7346.JPEG',
 'n01518878_7579': './data/images/training\\n01518878_7579.JPEG',
 'n01518878_8432': './data/images/training\\n01518878_8432.JPEG',
 'n01639765_22407': './data/images/training\\n01639765_22407.JPEG',
 'n01639765_32862': './data/images/training\\n01639765_32862.JPEG',
 'n01639765_37122': './data/images/training\\n01639765_37122.JPEG',
 'n01639765_40261': './data/images/training\\n01639765_40261.JPEG',
 'n01639765_44823': './data/images/training\\n01639765_44823.JPEG',
 'n01639765_47681': './data/images/training\\n01639765_47681.JPEG',
 'n01639765_48759': './data/images/training\\n01639765_4

In [97]:
# Image label (file name without directory and extension) --> 1-based serial
# number, taken from the position of the file in images_list
label_table = {
    os.path.splitext(os.path.basename(path))[0]: serial
    for serial, path in enumerate(images_list, start=1)
}
label_table

{'n01518878_10042': 1,
 'n01518878_12028': 2,
 'n01518878_14075': 3,
 'n01518878_14910': 4,
 'n01518878_5958': 5,
 'n01518878_7346': 6,
 'n01518878_7579': 7,
 'n01518878_8432': 8,
 'n01639765_22407': 9,
 'n01639765_32862': 10,
 'n01639765_37122': 11,
 'n01639765_40261': 12,
 'n01639765_44823': 13,
 'n01639765_47681': 14,
 'n01639765_48759': 15,
 'n01639765_52902': 16,
 'n01645776_10130': 17,
 'n01645776_10758': 18,
 'n01645776_8522': 19,
 'n01645776_8879': 20,
 'n01645776_9361': 21,
 'n01645776_9576': 22,
 'n01645776_9693': 23,
 'n01645776_9743': 24,
 'n01664990_10648': 25,
 'n01664990_13731': 26,
 'n01664990_16740': 27,
 'n01664990_18293': 28,
 'n01664990_19129': 29,
 'n01664990_19923': 30,
 'n01664990_65': 31,
 'n01664990_7133': 32,
 'n01704323_10031': 33,
 'n01704323_10239': 34,
 'n01704323_10394': 35,
 'n01704323_5092': 36,
 'n01704323_8008': 37,
 'n01704323_8767': 38,
 'n01704323_9172': 39,
 'n01704323_9812': 40,
 'n01726692_18809': 41,
 'n01726692_21053': 42,
 'n01726692_30357': 

In [99]:
# Look up one sample: sample indexes are 1-based, the label list is 0-based
sample_index=1
sample_label = fmri_labels[sample_index - 1]  # Sample label (file name)
sample_label_num = label_table[sample_label]  # Serial number of that label in label_table

KeyError: 1

In [100]:
sample_label

'n01639765_47681'

In [101]:
sample_label_num

14

In [None]:
# Create LMDB data -------------------------------------------------------
#
# For every subject two LMDB databases are written below sbj['output_dir']:
# 'fmri' (one ROI-normalized fMRI vector per sample) and 'images' (the matching
# resized stimulus image). Both use the same shuffled sample order, which is
# also saved to 'sample_index_list.txt'.
#
# NOTE(review): this cell calls caffe.io.array_to_datum and
# caffe.proto.caffe_pb2.Datum, but `import caffe` is commented out at the top
# of the notebook. caffe has to be imported before this cell can run.

for sbj in fmri_data_table:

    # np.savetxt does not create missing directories, so make sure the output
    # directory exists before anything is written into it.
    os.makedirs(sbj['output_dir'], exist_ok=True)

    # Load fMRI data
    print('Loading %s' % sbj['data_file'])
    fmri_data_bd = bdpy.BData(sbj['data_file'])

    # Load image files. glob returns files in arbitrary order; sorting makes the
    # serial numbers in label_table reproducible across runs and machines.
    images_list = sorted(glob.glob(os.path.join(image_dir, image_file_pattern)))  # List of image files (full path)
    images_table = {os.path.splitext(os.path.basename(f))[0]: f
                    for f in images_list}                                         # Image label to file path table
    label_table = {os.path.splitext(os.path.basename(f))[0]: i
                   for i, f in enumerate(images_list, start=1)}                   # Image label to serial number table

    # Get image labels in the fMRI data
    fmri_labels = fmri_data_bd.get('Label')[:, 1].flatten()

    # Convert image labels in fMRI data from float to file name labels (str)
    fmri_labels = ['n%08d_%d' % (int(wnid), int(image_no))
                   for wnid, image_no in (('%f' % a).split('.') for a in fmri_labels)]

    # Get sample indexes (1-based)
    n_sample = fmri_data_bd.dataset.shape[0]

    index_start = 1
    index_end = n_sample
    index_step = 1

    sample_index_list = range(index_start, index_end + 1, index_step)

    # Shuffle the sample indexes
    sample_index_list = np.random.permutation(sample_index_list)

    # Save the sample indexes so the stored order can be recovered later
    save_name = 'sample_index_list.txt'
    np.savetxt(os.path.join(sbj['output_dir'], save_name), sample_index_list, fmt='%d')

    # Get fMRI data in the ROI
    fmri_data = fmri_data_bd.select(sbj['roi_selector'])

    # Normalize fMRI data (z-score per column)
    fmri_data_mean = np.mean(fmri_data, axis=0)
    fmri_data_std = np.std(fmri_data, axis=0)
    # Constant columns have zero std; divide them by 1 so they become 0 instead of NaN/inf
    fmri_data_std = np.where(fmri_data_std == 0, 1, fmri_data_std)

    fmri_data = (fmri_data - fmri_data_mean) / fmri_data_std

    # Create lmdb for fMRI data
    map_size = 100 * 1024 * len(sample_index_list) * 10

    # The environment is used as a context manager so it is closed after writing
    with lmdb.open(os.path.join(sbj['output_dir'], 'fmri'), map_size=map_size) as env, \
            env.begin(write=True) as txn:
        for j0, sample_index in enumerate(sample_index_list):

            sample_label = fmri_labels[sample_index - 1]  # Sample label (file name)
            sample_label_num = label_table[sample_label]  # Sample label (serial number)

            print('Index %d, sample %d' % (j0 + 1, sample_index))
            print('Data label: %d (%s)' % (sample_label_num, sample_label))
            print(' ')

            # fMRI data in the sample
            sample_data = fmri_data[sample_index - 1, :]
            sample_data = np.float64(sample_data)  # Datum should be double float (?)
            sample_data = np.reshape(sample_data, (sample_data.size, 1, 1))  # Num voxel x 1 x 1

            datum = caffe.io.array_to_datum(sample_data)
            datum.label = sample_label_num  # Datum.label should be int (int32)
            # The encode is only essential in Python 3
            str_id = '%05d' % (j0 + 1)
            txn.put(str_id.encode('ascii'), datum.SerializeToString())

    # Create lmdb for images
    print('----------------------------------------')
    print('Images')

    map_size = 30 * 1024 * len(sample_index_list) * 10

    with lmdb.open(os.path.join(sbj['output_dir'], 'images'), map_size=map_size) as env, \
            env.begin(write=True) as txn:
        for j0, sample_index in enumerate(sample_index_list):

            sample_label = fmri_labels[sample_index - 1]  # Sample label (file name)
            sample_label_num = label_table[sample_label]  # Sample label (serial number)

            print('Index %d, sample %d' % (j0 + 1, sample_index))
            print('Data label: %d (%s)' % (sample_label_num, sample_label))
            print(' ')

            # Load the image, force 3-channel RGB (covers monochrome, palette,
            # RGBA and CMYK files) and resize with bilinear interpolation.
            # scipy.misc.imresize no longer exists; it took (height, width),
            # whereas PIL's resize takes (width, height).
            image_file = images_table[sample_label]
            with PIL.Image.open(image_file) as img_pil:
                img_pil = img_pil.convert('RGB').resize((img_size[1], img_size[0]),
                                                        PIL.Image.BILINEAR)
            img = np.asarray(img_pil)

            # h x w x c --> c x h x w
            img = img.transpose(2, 0, 1)

            # RGB --> BGR
            img = img[::-1]

            datum = caffe.proto.caffe_pb2.Datum()
            datum.channels = img.shape[0]
            datum.height = img.shape[1]
            datum.width = img.shape[2]
            datum.data = img.tobytes()
            datum.label = sample_label_num

            str_id = '%05d' % (j0 + 1)
            txn.put(str_id.encode('ascii'), datum.SerializeToString())

print('Done!')