In [1]:
!pip install torch



You should consider upgrading via the 'C:\Users\49176\PycharmProjects\pythonProject\venv\Scripts\python.exe -m pip install --upgrade pip' command.


In [1]:
import h5py
import torch
import random
from pathlib import Path
from typing import List
from torch.utils.data import DataLoader, Dataset

In [2]:
# Generate some random data
# Dedicated, seeded generators make the demo data identical on every
# Restart & Run All without touching the global RNG state of `random`/`torch`.
DATA_SEED = 0
shape_rng = random.Random(DATA_SEED)
value_rng = torch.Generator().manual_seed(DATA_SEED)

# 20 tensors keyed '0'..'19', each of shape (N, 10) with 10 <= N <= 30
random_tensors = {
    f'{i}': torch.rand((shape_rng.randint(10, 30), 10), generator=value_rng)
    for i in range(20)
}

# Save those data to some preferred file format
# One HDF5 dataset per tensor, named after the dictionary key
with h5py.File('random.h5', 'w') as hf:
    for idx, random_ten in random_tensors.items():
        hf.create_dataset(idx, data=random_ten.detach().numpy())

In [None]:
# Showing content of an h5 file
# NOTE(review): this import sits in the middle of the notebook; presumably it
# belongs with the other imports in the first code cell - confirm before moving.
import nexusformat.nexus as nx

# Load random.h5 through nexusformat, print its `tree` attribute, then close it
f = nx.nxload('random.h5')
print(f.tree)
f.close()
# random.h5 contains tensors of shape Nx10 where N is variable (10 <= N <= 30)

In [None]:
# Define a dataset tailored to the data that should be used
class FancyDataset(Dataset):
    """Map-style dataset that serves one reduced vector per HDF5 dataset.

    Every key found at the top level of the HDF5 file is one sample. A sample
    is read with ``[:, :]`` (so it needs at least two dimensions), soft-maxed
    along its last axis and averaged over its first axis (see ``fancy_func``).

    The HDF5 handle is opened lazily on first access of ``data`` and is left
    out of the pickled state. An open ``h5py.File`` cannot be pickled, which
    would otherwise break ``DataLoader(num_workers > 0)`` whenever the dataset
    has to be sent to worker processes (e.g. with the spawn start method).
    Call ``close()`` to release the file explicitly.
    """

    # Dataset ... map-style dataset
    def __init__(self, h5_path: Path):
        """Record the file location and build the index -> key map.

        Args:
            h5_path: Path of the HDF5 file to read samples from.

        Raises:
            OSError: Propagated from ``h5py.File`` if the file cannot be opened.
        """
        self.h5_path = Path(h5_path)
        # Created on demand by the ``data`` property
        self._h5_file = None

        # use as "index map": integer position -> dataset key.
        # The file is only opened briefly here so no handle is kept around
        # before the first sample is actually requested.
        with h5py.File(self.h5_path, 'r') as h5_file:
            self.ids_list = list(h5_file.keys())

        # some additional stuff
        self.softmax = torch.nn.Softmax(dim=-1)

    @property
    def data(self):
        """Read-only HDF5 file handle, opened on first use in this process."""
        handle = getattr(self, '_h5_file', None)
        if handle is None:
            handle = h5py.File(self.h5_path, 'r')
            self._h5_file = handle
        return handle

    def close(self) -> None:
        """Close the HDF5 handle if one is open; safe to call repeatedly."""
        handle = getattr(self, '_h5_file', None)
        if handle is not None:
            handle.close()
            self._h5_file = None

    def __del__(self):
        # Best-effort cleanup: errors are deliberately ignored because a
        # finalizer may run during interpreter shutdown.
        try:
            self.close()
        except Exception:
            pass

    def __getstate__(self):
        # Drop the open handle from the pickled state; the receiving process
        # re-opens the file lazily through ``data``.
        state = self.__dict__.copy()
        state['_h5_file'] = None
        return state

    # return the number of elements in the dataset
    def __len__(self) -> int:
        """Number of samples, i.e. number of keys collected at construction."""
        return len(self.ids_list)

    # return item at specific index
    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load the sample at position ``idx`` and reduce it with ``fancy_func``.

        Args:
            idx: Integer position into ``ids_list``.

        Returns:
            The result of ``fancy_func`` applied to the stored array.
        """
        identifier = self.ids_list[idx]
        idx_element = torch.from_numpy(self.data[identifier][:, :])
        return self.fancy_func(idx_element)

    # some function that does something
    def fancy_func(self, input: torch.Tensor) -> torch.Tensor:
        """Softmax over the last axis, then mean over the first axis.

        The parameter is named ``input`` (shadowing the builtin) to keep
        keyword calls from existing code working.
        """
        return self.softmax(input).mean(dim=0)

# Other types:
#   IterableDataset, TensorDataset, ConcatDataset, ...

In [None]:
# Wrap FancyDataset around the data
# Point the dataset at the HDF5 file and display the resulting object
h5_file_path = Path('random.h5')
fancy = FancyDataset(h5_path=h5_file_path)
fancy

In [None]:
# Defines some collate function that is useful
def collate(data: List[torch.Tensor]):
    """Merge the per-sample tensors of one batch into a single tensor.

    Args:
        data: Sample tensors of one batch; ``torch.stack`` requires them to
            share the same shape.

    Returns:
        One tensor with an additional leading axis that indexes the samples.
    """
    stacked_batch = torch.stack(data)  # new axis is inserted at position 0
    return stacked_batch

# In this case, the default_collate function is also able to do this

In [None]:
# Create dataloader
# Seed torch's global RNG first; manual_seed hands back the generator object,
# which is what gets printed here.
seeded_generator = torch.manual_seed(42)
print(seeded_generator)
dataloader = DataLoader(
    dataset=fancy,
    batch_size=5,
    shuffle=True,
    collate_fn=collate,
)
# Alternative relying on the default collate function:
# dataloader = DataLoader(fancy, batch_size=5, shuffle=True)

# DataLoader with default Sampler = index sampler with integral indices
# Custom Samplers = possible to use non-integral indices/keys
dataloader

In [None]:
# Use actual dataloader
# Run one complete pass over the dataloader and keep every batch in order
data = list(dataloader)
data

In [None]:
# Recreate output of actual dataloader with manual use of collate and batch forming
# Only valid for the first full iteration over the dataloader
# (presumably the shuffle order obtained after torch.manual_seed(42) - verify
# against the batches actually drawn if the setup changes)
indices = [[8, 14, 17, 19, 1], [15, 18, 9, 3, 11], [4, 0, 10, 7, 13], [16, 12, 6, 2, 5]]


# Helper functions
def batch_fancy(batch):
    """Get the items of one batch from the dataset ``fancy``."""
    return [fancy[x] for x in batch]


def fancy_indices(ind):
    """Get the items of multiple batches from the dataset ``fancy``."""
    return [batch_fancy(single_batch) for single_batch in ind]


def collate_fancy(fan_list):
    """Apply the collate function to each of multiple batches."""
    return [collate(fancy_batches) for fancy_batches in fan_list]


handcraft = collate_fancy(fancy_indices(indices))
handcraft

In [None]:
# Check if manual dataloader equals automatic creation of dataloader
# Pair the batches position by position; zip stops at the shorter sequence
batch_pairs = zip(data, handcraft)
for idx, (dataloader_sample, handcraft_sample) in enumerate(batch_pairs):
    print(f'batches #{idx} identical: {torch.allclose(dataloader_sample, handcraft_sample)}')