In [38]:
import os
import sys
import torch
import tarfile
import webdataset as wds
from skimage import io
import io as bio
import datasets.utils as dutils 
from jepa.src.masks.multiblock3d import MaskCollator as MB3DMaskCollator
from concurrent.futures import ThreadPoolExecutor, as_completed
from PIL import Image
import yaml

sys.path.append('/mnt/c/Users/Moham/Desktop/fMRI-foundation-model/jepa')

from jepa.src.models.vision_transformer import vit_tiny

In [33]:
class FMRIDataset(torch.utils.data.Dataset):
    """Map-style dataset over "func" images stored in tar archives or folders.

    Every entry of ``self.files`` is a ``(container_path, name)`` tuple where
    ``container_path`` is a ``.tar`` archive (``are_tars=True``) or a directory
    (``are_tars=False``) under ``root`` and ``name`` is a member/file whose
    name contains the substring ``"func"``.

    The index is built in sorted container order, so ``dataset[i]`` refers to
    the same file on every run regardless of thread scheduling or the order in
    which the OS lists directory entries.

    Args:
        root: Directory that holds the tar archives or the sample folders.
        transform: Optional callable applied to each loaded image.
        are_tars: If True, index ``*.tar`` files in ``root``; otherwise index
            the sub-directories of ``root``.
    """

    def __init__(self, root, transform=None, are_tars=True):
        self.root = root
        self.transform = transform
        self.are_tars = are_tars

        # Sort so the index -> file mapping does not depend on listdir order.
        entries = sorted(os.listdir(root))
        if self.are_tars:
            self.folders = [name for name in entries if name.endswith(".tar")]
        else:
            # Stray regular files in root cannot be listed as folders; skip them.
            self.folders = [
                name for name in entries
                if os.path.isdir(os.path.join(root, name))
            ]

        self.files = []
        self._load_input_files()

    def _load_from_tar(self, folder):
        """Return ``(tar_path, member_name)`` for each "func" member of one archive."""
        tar_path = os.path.join(self.root, folder)
        with tarfile.open(tar_path, 'r') as tar:
            # Directory members have no payload (extractfile returns None), so
            # they must not be indexed even if their name contains "func".
            return [
                (tar_path, member.name)
                for member in tar.getmembers()
                if "func" in member.name and not member.isdir()
            ]

    def _load_from_directory(self, folder):
        """Return ``(folder_path, file_name)`` for each "func" file of one folder."""
        folder_path = os.path.join(self.root, folder)
        return [
            (folder_path, name)
            for name in sorted(os.listdir(folder_path))
            if "func" in name and os.path.isfile(os.path.join(folder_path, name))
        ]

    def _load_input_files(self):
        """(Re)build ``self.files`` by scanning every container in parallel."""
        scan = self._load_from_tar if self.are_tars else self._load_from_directory
        indexed = []
        with ThreadPoolExecutor() as executor:
            # executor.map yields results in submission order, unlike
            # as_completed, which would make the sample order timing-dependent.
            for found in executor.map(scan, self.folders):
                indexed.extend(found)
        self.files = indexed

    def __len__(self):
        """Number of indexed "func" files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and apply ``transform`` if one was given."""
        fmri = self.load_file(self.files[idx])
        if self.transform is not None:
            fmri = self.transform(fmri)

        return fmri

    def load_file(self, file_path_tuple):
        """Read one image given a ``(container_path, name)`` tuple from ``self.files``.

        In tar mode the archive is re-opened on every call and the member is
        decoded with ``io.imread`` straight from the archive stream; in
        directory mode the two tuple parts are joined into a file path.
        """
        if self.are_tars:
            tar_path, member_name = file_path_tuple
            with tarfile.open(tar_path, 'r') as tar:
                member = tar.getmember(member_name)
                f = tar.extractfile(member)
                fmri = io.imread(f)
        else:
            file_path = os.path.join(*file_path_tuple)
            fmri = io.imread(file_path)

        return fmri


In [34]:
# Directory that holds the .tar shards indexed by FMRIDataset.
path = "/mnt/c/Users/Moham/Desktop/fMRI-foundation-model/data/wds/"
# Build the dataset here: the following cell iterates `data`, so leaving this
# line commented out makes a fresh "Restart & Run All" fail with NameError.
data = FMRIDataset(path)

In [4]:
# Smoke test: fetch only the first item of `data` (which must already be
# defined by an earlier cell) and print its shape, then stop iterating.
for i in data:
    print(i.shape)
    break

(1536, 3072)


In [35]:
# --- Data locations -------------------------------------------------------
# Local cache directory handed to wds.WebDataset.
cache_dir = "./cache"
# Single shard used as the WebDataset source (rebinds the earlier `path`).
path = "/mnt/c/Users/Moham/Desktop/fMRI-foundation-model/data/wds/000001.tar"

# --- Patch geometry -------------------------------------------------------
# Patch size shared by the mask collator and the ViT constructor.
patch_size = 8
frame_patch_size = 1

# --- Loader settings ------------------------------------------------------
batch_size = 1
num_workers = 1
# Nominal epoch size; divided by batch_size to get the loader length.
num_samples_per_epoch = 1024

In [36]:
def log_and_continue(exn):
    """Exception handler that reports the error and asks the caller to go on.

    Prints a one-line notice containing ``repr(exn)`` and always returns
    ``True``; nothing is re-raised.
    """
    notice = f'Handling webdataset error ({repr(exn)}). Ignoring.'
    print(notice)
    return True

def filter_corrupted_images(sample):
    """Tell whether a sample carries every file the pipeline requires.

    Returns ``True`` only when all of ``func.png``, ``dataset.txt``,
    ``header.npy``, ``meansd.png`` and ``minmax.npy`` are keys of ``sample``;
    otherwise ``False`` so the sample can be dropped.
    """
    required_keys = ("func.png", "dataset.txt", "header.npy", "meansd.png", "minmax.npy")
    for key in required_keys:
        if key not in sample:
            return False
    return True

In [48]:
# Parse the vitt16.yaml pretraining config from disk into a Python object.
with open("/mnt/c/Users/Moham/Desktop/fMRI-foundation-model/jepa/configs/pretrain/vitt16.yaml", 'r') as y_file:
    # NOTE(review): yaml.safe_load would be enough if the file only contains
    # plain YAML types — confirm before switching.
    config_fnames = yaml.load(y_file, Loader=yaml.FullLoader)

# Value of the top-level "mask" entry, later passed to the mask collator as
# cfgs_mask. Assumes the YAML top level is a mapping; .get() then yields None
# when the key is absent.
mask_cnfg = config_fnames.get("mask")

In [52]:
# Parsed pretraining config. The previous call passed the *path string* to
# yaml.load, which parses that string as a YAML document (yielding the string
# back) and never opens the file. Reuse the config already parsed from that
# same vitt16.yaml file into `config_fnames`.
_params = config_fnames

# Collator that produces the multi-block 3D masks; mask settings come from the
# "mask" section of the YAML config.
mask_collator = MB3DMaskCollator(
    crop_size=64,
    num_frames=20,
    patch_size=patch_size,
    tubelet_size=1,
    cfgs_mask=mask_cnfg,
)

# Per-sample transform applied to each (func, minmax, meansd) tuple.
aug_transform = dutils.DataPrepper(
    masking_strategy="conservative",
    patch_depth=8,
    patch_height=8,
    patch_width=8,
    frame_patch_size=1,
    num_timepoints=20,
)

# Sample pipeline: drop incomplete samples, give the files short field names,
# decode the fields that need it, keep three fields as a tuple, then transform.
train_data = (
    wds.WebDataset(path, resampled=False, cache_dir=cache_dir, handler=log_and_continue)
    .select(filter_corrupted_images)
    .rename(
        key="__key__",
        func="func.png",
        header="header.npy",
        dataset="dataset.txt",
        minmax="minmax.npy",
        meansd="meansd.png",
    )
    .map_dict(
        func=dutils.grayscale_decoder,
        meansd=dutils.grayscale_decoder,
        minmax=dutils.numpy_decoder,
    )
    .to_tuple("func", "minmax", "meansd")
    .map(aug_transform)
)

# NOTE(review): the dataset is already batched by `.batched(batch_size)` and
# `batch_size=None` turns off the loader's own batching, so `collate_fn` is
# called on each pre-batched item rather than on a list of samples. The
# "Trying to resize storage that is not resizable" error recorded below is
# raised by default_collate inside mask_collator — confirm that the collator
# really expects this input layout.
train_dl = wds.WebLoader(
    train_data.batched(batch_size),
    pin_memory=True,
    shuffle=False,
    batch_size=None,
    collate_fn=mask_collator,
    num_workers=num_workers,
    # persistent_workers is only valid when worker processes are used.
    persistent_workers=num_workers > 0,
).with_epoch(num_samples_per_epoch // batch_size).with_length(num_samples_per_epoch // batch_size)


In [29]:
# Build the vit_tiny model and move it to the GPU (needs a CUDA device).
# NOTE(review): num_frames=48 here, while the mask collator is created with
# num_frames=20 — confirm the two are meant to differ.
model = vit_tiny(in_chans=12, num_frames=48, img_size=64, patch_size=patch_size).cuda()

In [51]:
# Length reported by the loader, i.e. the value given to
# .with_length(num_samples_per_epoch // batch_size) when it was built.
len(train_dl)

1024

In [53]:
# Walk the loader and print the shape of each item's first element after it
# has been moved to the GPU and had its axes reordered.
for i in train_dl:
    # Keep only element 0 of what the loader yields and move it to the GPU.
    i = i[0].cuda()
    # print(i[0].shape)
    # print(i[1].shape)
    # print(i[2].shape)
    # Move the last axis to position 2 (the five indices require a 5-D
    # tensor), then make the result contiguous in memory.
    i = i.permute(0, 1, -1, 2, 3).contiguous()
    print(i.shape)
    # Forward pass through the model is currently disabled.
    # o = model(i)
    # print(o.shape)
    

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4187403430>
Traceback (most recent call last):
  File "/home/mohammed/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/home/mohammed/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1409, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4187403430>
Traceback (most recent call last):
  File "/home/mohammed/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/home/mohammed/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    asser

(24, 64, 64, 48)
(24, 64, 64, 48)
batch 3


RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/mohammed/.local/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/mohammed/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 42, in fetch
    return self.collate_fn(data)
  File "/mnt/c/Users/Moham/Desktop/fMRI-foundation-model/jepa/src/masks/multiblock3d.py", line 56, in __call__
    collated_batch = torch.utils.data.default_collate(batch)
  File "/home/mohammed/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/home/mohammed/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 119, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/home/mohammed/.local/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 161, in collate_tensor_fn
    out = elem.new(storage).resize_(len(batch), *list(elem.size()))
RuntimeError: Trying to resize storage that is not resizable


(24, 64, 64, 48)
batch 3
(24, 64, 64, 48)
batch 3
