In [2]:
import webdataset as wds   
from typing import Dict, Tuple, Generator
import numpy as np
import torch
from pprint import pp
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [40]:
class BaseDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        dataloader,
        length,
        sample_keys
        ) -> None:
        self.name = 'BaseDataset'
        self._length = length
        self.dataloader = iter(dataloader)
        self.sample_keys = sample_keys

    def __len__(self):
        return self._length
    
    def __getitem__(self, idx):
        
        sample = next(self.dataloader)
            
        if self.sample_keys is not None:
            sample = {
                key: sample[key] 
                for key in self.sample_keys
                if key in sample
            }
        
        return sample


class BaseBatcher:

    def __init__(
        self,
        sample_random_seq: bool = True,
        seq_min: int = 10,
        seq_max: int = 50,
        sample_keys: Tuple[str] = None,
        decoding_target: str = None,
        seed: int =  None,
        **kwargs
        ) -> None:
        assert seq_min > 0, "seq_min must be greater than 0"
        assert seq_min < seq_max, 'seq_min must be less than seq_max'
        self.sample_random_seq = sample_random_seq
        self.seq_min = seq_min
        self.seq_max = seq_max
        self.decoding_target = decoding_target
        self.sample_keys = sample_keys
        self.seed = seed
        if self.seed is not None:
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)
        
    def _make_dataloader(
        self,
        files,
        repeat: bool = True,
        n_shuffle_shards: int = 1000,
        n_shuffle_samples: int = 1000,
        batch_size: int = 1,
        num_workers: int = 0
        ) -> Generator[Dict[str, torch.tensor], None, None]:
        dataset = wds.WebDataset(files)

        if n_shuffle_shards is not None:
            dataset = dataset.shuffle(n_shuffle_shards)

        dataset = dataset.decode("pil").compose(self.get_samples) # .map(self.preprocess_sample)

        if repeat:
            dataset = dataset.repeat()
        
        if n_shuffle_samples is not None:
            dataset = dataset.shuffle(n_shuffle_samples)

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            num_workers=num_workers
        )

    def dataset(
        self,
        tarfiles: list,
        repeat: bool=True,
        length: int = 400000,
        n_shuffle_shards: int = 1000,
        n_shuffle_samples: int = 1000,
        num_workers: int = 0
        ) -> torch.utils.data.Dataset:
        """Create Pytorch dataset that can be used for training.

        Args:
        -----
            tarfiles: list
                List of paths to data files (ie., fMRI runs) used for training.
            repeat: bool
                If True, repeat the dataset indefinitely.
            length: int
                Maximum number of samples to yield from the dataset.
            n_shuffle_shards: int
                Buffer for shuffling of tarfiles during training.
            n_shuffle_samples: int
                Buffer for shuffling of samples during training.
            num_workers: int
                Number of workers to use for data loading.

        Returns:
        -----
            torch.utils.data.Dataset: Pytorch dataset.
        """
        dataloader = self._make_dataloader(
            files=tarfiles,
            repeat=repeat,
            n_shuffle_shards=n_shuffle_shards,
            n_shuffle_samples=n_shuffle_samples,
            num_workers=num_workers
        )
        return BaseDataset(
            dataloader=dataloader,
            length=length,
            sample_keys=self.sample_keys
        )

    def get_samples(self, src):
        for sample in src:
            #key = sample['__key__']
            out = dict()
            bold = None
            for key, value in sample.items():
                #print(f" {key=} {value=}")
                if key == "bold.pyd":
                    bold = np.array(value).astype(float)
                else:
                    out[key] = value

            if bold is not None:
                parts = [(0, bold.shape[0]//2 - 1), (bold.shape[0]//2, bold.shape[0])]
                for start, end in parts:
                    #out['__key__'] = f"{key} {start}-{end}"
                    out['start'] = start
                    out['end'] = end
                    out["inputs"] = torch.from_numpy(bold[start:end+1]).to(torch.float)
                    yield out.copy()

    #def preprocess_sample(
    #    self,
    #    sample
    #    ) -> Dict[str, torch.Tensor]:
    #    out = dict(__key__=sample["__key__"])
    #    t_r = sample["t_r.pyd"]

    #    label = None
    #    f_s = None
    #    if self.bold_dummy_mode and self.decoding_target is not None:
    #        label = np.random.choice([0, 1])
    #        f_s = np.array([1, 2, 4]) if label == 0 else np.array([6, 8, 10])                

    #    for key, value in sample.items():
    #        if key == "bold.pyd":

    #            bold = np.array(value).astype(float)
    #            
    #            if self.bold_dummy_mode:
    #                bold = self.make_bold_dummy(
    #                    bold_shape=bold.shape,
    #                    t_r=t_r,
    #                    f_s=f_s
    #                )

    #            seq_on, seq_len = self._sample_seq_on_and_len(bold_len=len(bold))
    #            bold = bold[seq_on:seq_on+seq_len]
    #            t_rs = np.arange(seq_len) * t_r
    #            attention_mask = np.ones(seq_len)
    #            bold = self._pad_seq_right_to_n(
    #                seq=bold,
    #                n=self.seq_max,
    #                pad_value=0
    #            )
    #            t_rs = self._pad_seq_right_to_n(
    #                seq=t_rs,
    #                n=self.seq_max,
    #                pad_value=0
    #            )
    #            attention_mask = self._pad_seq_right_to_n(
    #                seq=attention_mask,
    #                n=self.seq_max,
    #                pad_value=0
    #            )
    #            out["inputs"] = torch.from_numpy(bold).to(torch.float)
    #            out['t_rs'] = torch.from_numpy(t_rs).to(torch.float)
    #            out["attention_mask"] = torch.from_numpy(attention_mask).to(torch.long)
    #            out['seq_on'] = seq_on
    #            out['seq_len'] = seq_len

    #        elif key in {
    #            f"{self.decoding_target}.pyd",
    #            self.decoding_target
    #            }:
    #            out["labels"] = value

    #        else:
    #            out[key] = value
    #    
    #    if self.sample_keys is not None:
    #        out = {
    #            key: out[key] 
    #            for key in self.sample_keys
    #            if key in out
    #        }

    #    if label is not None:
    #        out['labels'] = label
    #    
    #    return out

In [41]:

#url = "https://github.com/athms/learning-from-brains/raw/b061ef97680365a074671e0c95581af426773f9e/data/upstream/ds000212/ds-ds000212_sub-[03-47]_task-dis_run-[1-6].tar"
#dataset = BaseBatcher(seq_max=300, sample_random_seq=False).dataset(tarfiles=[url])

dataset = BaseBatcher(seq_max=300, sample_random_seq=False).dataset(tarfiles=[str(f) for f in Path('.').glob('*.tar')])

In [42]:
loaded = [dataset[i] for i in range(50)]

In [43]:
for l in loaded:
    print('---')
    for f in l:
        if isinstance(l[f], torch.Tensor):
            if len(l[f].shape) == 1:
                print(f, l[f])
            else:
                print(f, l[f].shape, l[f].dtype)
        else:
            pp([f, l[f]])

---
['__key__', ['ds-ds000212_sub-34_task-dis_run-6']]
['__url__', ['ds-ds000212_sub-34_task-dis_run-6.tar']]
t_r.pyd tensor([2.])
start tensor([0])
end tensor([82])
inputs torch.Size([1, 83, 1024]) torch.float32
---
['__key__', ['ds-ds000212_sub-04_task-fb_run-1']]
['__url__', ['ds-ds000212_sub-04_task-fb_run-1.tar']]
t_r.pyd tensor([2.])
start tensor([0])
end tensor([67])
inputs torch.Size([1, 68, 1024]) torch.float32
---
['__key__', ['ds-ds000212_sub-13_task-dis_run-6']]
['__url__', ['ds-ds000212_sub-13_task-dis_run-6.tar']]
t_r.pyd tensor([2.])
start tensor([83])
end tensor([166])
inputs torch.Size([1, 83, 1024]) torch.float32
---
['__key__', ['ds-ds000212_sub-46_task-dis_run-2']]
['__url__', ['ds-ds000212_sub-46_task-dis_run-2.tar']]
t_r.pyd tensor([2.])
start tensor([83])
end tensor([166])
inputs torch.Size([1, 83, 1024]) torch.float32
---
['__key__', ['ds-ds000212_sub-32_task-dis_run-1']]
['__url__', ['ds-ds000212_sub-32_task-dis_run-1.tar']]
t_r.pyd tensor([2.])
start tensor([8