## Import dependencies

In [1]:
# Standard Library Imports
import os
import sys
from subprocess import call, DEVNULL
import json
import time
import traceback
import asyncio
import aiohttp
import re

# Third-Party Library Imports
import numpy as np
import boto3
from botocore import UNSIGNED
from botocore.client import Config
import webdataset as wds
import nibabel as nib
import pickle as pkl
from einops import rearrange
import torchvision.transforms as transforms
from PIL import Image
import torch
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from sklearn.preprocessing import StandardScaler
from litdata import optimize

## Helper functions

In [2]:
def reshape_to_2d(tensor):
    return rearrange(tensor, 'b h w c -> (b h) (c w)')

def reshape_to_original(tensor_2d, b=300, h=64, w=64, c=48):
    return rearrange(tensor_2d, '(b h) (c w) -> b h w c', b=b, h=h, w=w, c=c)

def header_to_dict(header):
    readable_header = {}
    for key, value in header.items():
        readable_header[key] = value
    return readable_header

def temporal_interp1d(fmri_data, change_TR):
    original_time_points = np.arange(fmri_data.shape[0])  # Time points: 0, 1, 2, ..., T-1
    new_time_points = np.arange(0, fmri_data.shape[0], change_TR)  # New time points: 0, 2, 4, ...

    reshaped_data = fmri_data.reshape(fmri_data.shape[0], -1)  # Reshape to (T, X*Y*Z)
    interpolate = interp1d(original_time_points, reshaped_data, kind='linear', axis=0, bounds_error=False, fill_value="extrapolate")
    resampled_fmri_data = interpolate(new_time_points).reshape((len(new_time_points),) + fmri_data.shape[1:])
    return resampled_fmri_data
    
def is_interactive():
    import __main__ as main
    return not hasattr(main, '__file__')

## Create dir to save dataset

In [3]:
temp_folder = '/scratch/temp_MNIs'
os.makedirs(temp_folder, exist_ok=True)
print(temp_folder)

# wds_folder = os.getcwd()+'/openneuro_wds'
# prefix = 'fmri_foundation_datasets/openneuro_MNI/'

wds_folder = os.getcwd()+'/nsd_litdata' #'/nsd_wds'
prefix = 'fmri_foundation_datasets/NSD_MNI/'

os.makedirs(wds_folder, exist_ok=True)
print(wds_folder)

s3_output_folder_name = prefix.split('/')[1]+'_litdata'
print(s3_output_folder_name)

/scratch/temp_MNIs
/weka/proj-fmri/paulscotti/fMRI-foundation-model/dataset_creation/wds_creation/nsd_litdata
NSD_MNI_litdata


In [4]:
# # delete saved files
# command = f"rm {temp_folder}/*"
# call(command,shell=True)

# delete saved files
command = f"rm {wds_folder}/*"
call(command,shell=True)

0

## Job

In [5]:
s3 = boto3.client('s3')
bucket_name = 'proj-fmri'

paginator = s3.get_paginator('list_objects_v2')
file_name_list = []
for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
    for obj in page.get('Contents', []):
        file_name = obj['Key']
        file_name_list.append(file_name)
print("len(file_name_list) =", len(file_name_list))

len(file_name_list) = 3696


In [6]:
# command = f"aws s3 sync s3://proj-fmri/{prefix[:-1]} {temp_folder}"
# print(command)
# call(command, shell=True, stdout=DEVNULL, stderr=DEVNULL)

# # print(f"\nchecking for file {file_name_list[-1]}...")
# # if not os.path.exists(file_name_list[-1]):
# #     time.sleep(20)

# print("\nready!")

In [8]:
from collections import defaultdict

# Initialize a dictionary to hold the categorized file paths
datasets = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
datasets_minmean = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

for file_name in file_name_list:
    parts = file_name.split('/')
    dataset_id = parts[2]
    subject_parts = parts[3].split('_')
    subject_id = subject_parts[0]  # Extract the subject identifier

    # Check for session identifier, default to "ses-01" if not present
    session_id = "ses-01"
    for part in subject_parts:
        if part.startswith("ses-"):
            session_id = part
            break
    
    datasets[dataset_id][subject_id][session_id].append(file_name)

In [21]:
# TRs_per_sample = 32
# max_samples_per_tar = 30 # try to translate to around 1 Gb per tar
# max_TRs_per_tar = max_samples_per_tar * TRs_per_sample

# current_dataset = None
# current_subject = None

# MNI_mask = nib.load("/weka/proj-fmri/paulscotti/fMRI-foundation-model/dataset_creation/afni_conversion/tpl-MNI152NLin2009cAsym_res-02_T1w_brain.nii.gz").get_fdata()
# MNI_mask[MNI_mask>0]=1
# MNI_mask = MNI_mask.astype(bool)

# for dataset_id in ['ses-prffloc']:#list(datasets.keys()):
#     for subject_id in list(datasets[dataset_id].keys()):
#         for session_id in list(datasets[dataset_id][subject_id].keys()):
#             print(f"Processing {dataset_id} {subject_id} {session_id}")
#             # first get min and max values across all runs in session
#             run_count = 0
#             for file_name in datasets[dataset_id][subject_id][session_id]:
#                 temp_file_path = f"{temp_folder}/{file_name.split('/')[2]}/{file_name.split('/')[3]}"
#                 # temp_folder + '/' + file_name.split('/')[2] + '_' + file_name.split('/')[-1]
            
#                 # if not os.path.exists(temp_file_path):
#                 #     # s3.download_file(bucket_name, file_name, temp_file_path)
#                 #     command = f"aws s3 cp s3://proj-fmri/{file_name} {temp_file_path}"
#                 #     call(command,shell=True)

#                 if not os.path.exists(temp_file_path):
#                     raise Exception("s3 file not found")
            
#                 func_nii = nib.load(temp_file_path).get_fdata()
#                 func_nii = np.moveaxis(func_nii, -1, 0)
#                 data = func_nii[:,MNI_mask] # find normalization values only inside of the MNI brain mask

#                 # ignore outliers via standard deviation exclusion
#                 low = data.mean() - 2 * data.std()
#                 high = data.mean() + 2 * data.std()
#                 filtered_data = data[(data > low) & (data < high)]
#                 min_val = np.min(filtered_data)
#                 max_val = np.max(filtered_data)
                
#                 run_count +=1
#                 if run_count==1: 
#                     min = min_val
#                     max = max_val
#                 else:
#                     min += min_val
#                     max += max_val
                    
#             min /= run_count
#             max /= run_count
#             print(f"min = {min} | max = {max}")
#             datasets_minmean[dataset_id][subject_id][session_id] = [min,max]

Processing ses-prffloc sub-01 ses-prffloc
min = 0.0 | max = 1058.3333333333333


KeyError: 'ses-prffloc'

In [10]:
# with open('minmax_datasets.json', 'w') as file:
#     json.dump(datasets_minmean, file)
# json_dump = True
# print("done!")

In [32]:
with open('minmax_datasets.json', 'r') as file:
    datasets_minmean = json.load(file)
print("loaded!")

loaded!


In [33]:
def making_litdata(file_name):
    parts = file_name.split('/')
    dataset_id = parts[2]
    subject_parts = parts[3].split('_')
    subject_id = subject_parts[0]  # Extract the subject identifier
    
    # Check for session identifier, default to "ses-01" if not present
    session_id = "ses-01"
    for part in subject_parts:
        if part.startswith("ses-"):
            session_id = part
            break
            
    minmax = datasets_minmean[dataset_id][subject_id][session_id]
    min, max = minmax[0], minmax[1]
    
    temp_file_path = f"{temp_folder}/{file_name.split('/')[2]}/{file_name.split('/')[3]}"
    # temp_file_path = temp_folder + '/' + file_name.split('/')[2] + '_' + file_name.split('/')[-1]

    # if not os.path.exists(temp_file_path):
    #     # s3.download_file(bucket_name, file_name, temp_file_path)
    #     command = f"aws s3 cp s3://proj-fmri/{file_name} {temp_file_path}"
    #     call(command,shell=True)

    if not os.path.exists(temp_file_path):
        raise Exception("s3 file not found")
    
    func_nii = nib.load(temp_file_path).get_fdata()
    func_nii = np.moveaxis(func_nii, -1, 0)
    func_nii = func_nii[:,6:94,8:112,10:82].astype(np.float16) # [T, 97, 115, 97] to [T, 88, 104, 72]

    # normalize by min max
    func_nii = (func_nii - min) / (max - min)
    
    data = {
        "func": func_nii, 
    }
    return data

In [34]:
filtered_file_name_list = [file for file in file_name_list if "prffloc" not in file]
print("prffloc encountered an error so for sake of time im just skipping it")

optimize(
    fn=making_litdata,  # The function applied over each input.
    inputs=filtered_file_name_list,  # Provide any inputs. The fn is applied on each item.
    output_dir=wds_folder,  # The directory where the optimized data are stored.
    num_workers=16,  # The number of workers. The inputs are distributed among them.
    chunk_bytes="256MB"  # The maximum number of bytes to write into a data chunk.
)

prffloc encountered an error so for sake of time im just skipping it
Create an account on https://lightning.ai/ to optimize your data faster using multiple nodes and large machines.
Storing the files under /weka/proj-fmri/paulscotti/fMRI-foundation-model/dataset_creation/wds_creation/nsd_litdata
Setup started with fast_dev_run=False.
Setup finished in 0.146 seconds. Found 3600 items to process.
Starting 20 workers with 3600 items.
Workers are ready ! Starting data processing...


Progress:   0%|                                                                 | 0/3600 [00:00<?, ?it/s]

Rank 5 inferred the following `['numpy']` data format.
Rank 6 inferred the following `['numpy']` data format.
Rank 12 inferred the following `['numpy']` data format.
Rank 3 inferred the following `['numpy']` data format.
Rank 7 inferred the following `['numpy']` data format.
Rank 2 inferred the following `['numpy']` data format.
Rank 8 inferred the following `['numpy']` data format.
Rank 17 inferred the following `['numpy']` data format.
Rank 18 inferred the following `['numpy']` data format.
Rank 11 inferred the following `['numpy']` data format.
Rank 1 inferred the following `['numpy']` data format.
Rank 13 inferred the following `['numpy']` data format.
Rank 10 inferred the following `['numpy']` data format.
Rank 15 inferred the following `['numpy']` data format.
Rank 9 inferred the following `['numpy']` data format.
Rank 4 inferred the following `['numpy']` data format.
Rank 19 inferred the following `['numpy']` data format.
Rank 16 inferred the following `['numpy']` data format.
R

In [None]:
command = f"aws s3 cp --recursive {wds_folder} s3://proj-fmri/fmri_foundation_datasets/{s3_output_folder_name}"
call(command,shell=True)

upload: nsd_litdata/chunk-0-0.bin to s3://proj-fmri/fmri_foundation_datasets/NSD_MNI_litdata/chunk-0-0.bin
upload: nsd_litdata/chunk-0-10.bin to s3://proj-fmri/fmri_foundation_datasets/NSD_MNI_litdata/chunk-0-10.bin
upload: nsd_litdata/chunk-0-1.bin to s3://proj-fmri/fmri_foundation_datasets/NSD_MNI_litdata/chunk-0-1.bin
upload: nsd_litdata/chunk-0-100.bin to s3://proj-fmri/fmri_foundation_datasets/NSD_MNI_litdata/chunk-0-100.bin
upload: nsd_litdata/chunk-0-101.bin to s3://proj-fmri/fmri_foundation_datasets/NSD_MNI_litdata/chunk-0-101.bin
upload: nsd_litdata/chunk-0-102.bin to s3://proj-fmri/fmri_foundation_datasets/NSD_MNI_litdata/chunk-0-102.bin
upload: nsd_litdata/chunk-0-104.bin to s3://proj-fmri/fmri_foundation_datasets/NSD_MNI_litdata/chunk-0-104.bin
upload: nsd_litdata/chunk-0-106.bin to s3://proj-fmri/fmri_foundation_datasets/NSD_MNI_litdata/chunk-0-106.bin
upload: nsd_litdata/chunk-0-103.bin to s3://proj-fmri/fmri_foundation_datasets/NSD_MNI_litdata/chunk-0-103.bin
upload: nsd