In [37]:
import os
import sys
import torch
import tarfile
import webdataset as wds
from skimage import io
import io as bio
import fMRI_MAE.utils as utils 
from concurrent.futures import ThreadPoolExecutor, as_completed
from PIL import Image

sys.path.append('/mnt/c/Users/Moham/Desktop/fMRI-foundation-model/jepa')

from jepa.src.models.vision_transformer import vit_tiny

In [48]:
class FMRIDataset(torch.utils.data.Dataset):
    def __init__(self, root, transform=None, are_tars=True):
        self.root = root
        self.transform = transform
        self.folders = os.listdir(root)
        self.are_tars = are_tars

        self.files = []
        if self.are_tars:
            self.folders = [folder for folder in self.folders if folder[-4:] == ".tar"]
        
        self._load_input_files()

    def _load_from_tar(self, folder):
            with tarfile.open(os.path.join(self.root, folder), 'r') as tar:
                files = tar.getmembers()
                return [(os.path.join(self.root, folder), file.name) for file in files if "func" in file.name]

    def _load_from_directory(self, folder):
        folder_path = os.path.join(self.root, folder)
        return [(folder_path, f) for f in os.listdir(folder_path) if "func" in f and os.path.isfile(os.path.join(folder_path, f))]

    def _load_input_files(self):
        with ThreadPoolExecutor() as executor:
            futures = []
            for folder in self.folders:
                if self.are_tars:
                    futures.append(executor.submit(self._load_from_tar, folder))
                else:
                    futures.append(executor.submit(self._load_from_directory, folder))

            for future in as_completed(futures):
                self.files.extend(future.result())

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_path_tuple = self.files[idx]
        fmri = self.load_file(file_path_tuple)        
        if self.transform:
            fmri = self.transform(fmri)
        
        return fmri

    def load_file(self, file_path_tuple):
        if self.are_tars:
            tar_path, member_name = file_path_tuple
            with tarfile.open(tar_path, 'r') as tar:
                member = tar.getmember(member_name)
                f = tar.extractfile(member)
                fmri = io.imread(f)
        else:
            file_path = os.path.join(*file_path_tuple)
            fmri = io.imread(file_path)
        
        return fmri


In [49]:
path = "/mnt/c/Users/Moham/Desktop/fMRI-foundation-model/data/wds/"
data = FMRIDataset(path)

In [50]:
for i in data:
    print(i.shape)
    
    break

OSError: Could not find a backend to open `<ExFileObject name='/mnt/c/Users/Moham/Desktop/fMRI-foundation-model/data/wds/000009.tar'>`` with iomode `r`.

In [2]:
path = "/mnt/c/Users/Moham/Desktop/fMRI-foundation-model/data/wds/000000.tar"
cache_dir = "./cache"

patch_size = 8
frame_patch_size = 1
num_samples_per_epoch = 1024
batch_size = 1
num_workers = 1

In [22]:
with tarfile.open(path, 'r') as tar:
    print(tar.getnames())

['000000.dataset.txt', '000000.func.png', '000000.header.npy', '000000.meansd.png', '000000.minmax.npy', '000001.dataset.txt', '000001.func.png', '000001.header.npy', '000001.meansd.png', '000001.minmax.npy', '000002.dataset.txt', '000002.func.png', '000002.header.npy', '000002.meansd.png', '000002.minmax.npy', '000003.dataset.txt', '000003.func.png', '000003.header.npy', '000003.meansd.png', '000003.minmax.npy', '000004.dataset.txt', '000004.func.png', '000004.header.npy', '000004.meansd.png', '000004.minmax.npy', '000005.dataset.txt', '000005.func.png', '000005.header.npy', '000005.meansd.png', '000005.minmax.npy', '000006.dataset.txt', '000006.func.png', '000006.header.npy', '000006.meansd.png', '000006.minmax.npy', '000007.dataset.txt', '000007.func.png', '000007.header.npy', '000007.meansd.png', '000007.minmax.npy', '000008.dataset.txt', '000008.func.png', '000008.header.npy', '000008.meansd.png', '000008.minmax.npy', '000009.dataset.txt', '000009.func.png', '000009.header.npy', '

In [12]:
def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    print(f'Handling webdataset error ({repr(exn)}). Ignoring.')
    return True

def filter_corrupted_images(sample):
    """If all the required files are not present don't use them."""
    correct_data = ("func.png" in sample and "dataset.txt" in sample and "header.npy" in sample and "meansd.png" in sample and "minmax.npy" in sample)
    return correct_data

aug_transform = utils.DataPrepper(
    masking_strategy="conservative",
    patch_depth=patch_size,
    patch_height=patch_size,
    patch_width=patch_size,
    frame_patch_size=frame_patch_size,
    num_timepoints=12
)

In [17]:
model = vit_tiny(in_chans=12, num_frames=48, img_size=64, patch_size=patch_size).cuda()

In [19]:
for i in train_dl:
    i = i[0].cuda()
    # print(i[0].shape)
    # print(i[1].shape)
    # print(i[2].shape)
    i = i.permute(0, 1, -1, 2, 3)
    o = model(i)
    print(o.shape)
    break

torch.Size([1, 1536, 192])
