# Tutorial for Splitting Datasets - Part 2

The first part of this tutorial dealt with the general concept of splitting
tabular data stored in pandas dataframes. In this second part, we focus on the
`Dataset` class, which PyTorch provides as the standard way of supplying data
to the training procedure of a neural network.

|                |   |
|----------------|---|
| Requirements   | - Basic Python skills |
| Learning Goals | - Understanding the basics of a pandas dataframe <br/>- Concept of splitting data into multiple partitions <br/>- Various splitting strategies for dataframes <br/>- Understanding cross-validation <br/>- Application of cross-validation in a practical use case |
| Limitations    | - The tutorial only handles pandas dataframes and a NumPy array in the practical use case. |


```python
import uuid
from typing import Any, List

import torch
from torch.utils.data import Dataset
```

## The Dataset Object

Before we dive into the actual data splitting, we introduce the `Dataset` class in general.
While PyTorch can create a dataset directly from existing tensors (`TensorDataset`), and
`torchvision` can do so from a directory of images (`ImageFolder`), writing a custom dataset
class yields maximal flexibility and requires only a little setup.
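For comparison, here is a minimal sketch of the `TensorDataset` route, using small made-up feature and label tensors (the names `features`, `labels`, and `tensor_ds` are illustrative only). Indexing such a dataset returns a tuple holding the corresponding row of every tensor that was passed in.

```python
import torch
from torch.utils.data import TensorDataset

# Made-up toy data: 10 samples with 3 features each and one integer label per sample
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10) % 2

# All tensors must have the same size in their first dimension
tensor_ds = TensorDataset(features, labels)

x, y = tensor_ds[4]  # tuple: (row 4 of features, element 4 of labels)
print(len(tensor_ds), x.tolist(), y.item())  # → 10 [12.0, 13.0, 14.0] 0
```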

As shown in the code snippet below, creating a dataset requires implementing two
methods. The most important one is `__getitem__`. It receives an index `idx`, starting
from 0, and should return the training sample corresponding to that index. The return
type is not restricted in any way. Tensors, NumPy arrays, numbers, and tuples, lists, or
dictionaries of these are batched automatically by the default collate function of the
`DataLoader` that consumes the dataset, while custom data types require a custom collate
function. However, this is an advanced topic and is not covered in this tutorial.

The second method to implement is `__len__`, which simply returns the number of samples
in the dataset. The `DataLoader` relies on it, for example to determine the range of valid
indices when sampling and the number of batches per epoch.

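```python
class MyDataset(Dataset):
    def __len__(self) -> int:
        # Return the total number of samples in the dataset
        raise NotImplementedError

    def __getitem__(self, idx: int) -> Any:
        # Return the sample at position idx (0 <= idx < len(self))
        raise NotImplementedError
```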
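To illustrate the automatic batching described above, the following sketch uses a hypothetical `SquaresDataset` whose `__getitem__` returns a pair of tensors; with the default collate function, the `DataLoader` stacks these into one tensor per tuple position. The class name and toy data are assumptions made for this example.

```python
from typing import Tuple

import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i**2) as float tensors."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        x = torch.tensor([float(idx)])   # shape (1,)
        y = torch.tensor(float(idx**2))  # 0-dim tensor
        return x, y


loader = DataLoader(SquaresDataset(10), batch_size=4, shuffle=False)
batches = list(loader)

# 10 samples with batch_size=4 -> batches of 4, 4 and 2 samples
xb, yb = batches[0]
print(len(loader), xb.shape, yb.tolist())  # → 3 torch.Size([4, 1]) [0.0, 1.0, 4.0, 9.0]
```

The last batch is smaller because 10 is not divisible by 4; passing `drop_last=True` to the `DataLoader` would discard it.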

To showcase the basic functionality of a `Dataset`, we create a dataset that holds
a list of 100 randomly generated strings as dummy data.
`__len__` simply returns the number of elements in this list, and `__getitem__`
returns the list element at the requested index.

```python
class MyDataset(Dataset):
    def __init__(self) -> None:
        num_samples = 100
        # This just generates a list of random strings
        self.data = [uuid.uuid4().hex.upper()[0:6] for _ in range(num_samples)]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> str:
        return self.data[idx]
```

`__len__` and `__getitem__` belong to Python's *magic methods* (also called special or
dunder methods). Python calls them implicitly when the corresponding built-in function
or operator is applied to an object: calling `len(ds)` on our dataset invokes
`ds.__len__()`, and the bracket operator `ds[idx]`, familiar from list and dictionary
access, invokes `ds.__getitem__(idx)`.
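The dispatch from built-ins to magic methods is plain Python and does not depend on PyTorch. The following sketch uses a hypothetical `Recorder` class that logs which method Python calls for `len(...)` and `[...]`.

```python
class Recorder:
    """Hypothetical class that logs which magic method Python invokes."""

    def __init__(self) -> None:
        self.calls = []

    def __len__(self) -> int:
        self.calls.append("__len__")
        return 3

    def __getitem__(self, idx: int) -> str:
        self.calls.append(f"__getitem__({idx})")
        return "abc"[idx]


r = Recorder()
n = len(r)   # Python calls r.__len__()
item = r[1]  # Python calls r.__getitem__(1)
print(n, item, r.calls)  # → 3 b ['__len__', '__getitem__(1)']
```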

```python
ds = MyDataset()
print(f'Number of samples in dataset: {len(ds)}')
print('Samples:', ds[0], ds[42], ds[77])
```

```
Number of samples in dataset: 100
Samples: 9A6439 B6B39F 6D5DC2
```


As shown above, we can extract samples of our choice with the `[]` operator.

In the next steps, we show several ways of splitting and organizing such a `Dataset`
object. We distinguish between *internal* and *external* splitting.

## Internal Splitting

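With internal splitting, we refer to subsetting an existing dataset object itself.
A convenient utility for this is the `random_split` function that ships with PyTorch.
It takes the dataset and the desired `lengths`, given either as absolute sample counts
or, since PyTorch 1.13, as fractions summing to 1, and returns one `Subset` per entry.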


```python
from torch.utils.data import random_split

train_ds, test_ds = random_split(dataset=ds, lengths=(0.8, 0.2))

print(f'Samples in train dataset: {len(train_ds)}')
print(f'Samples in test dataset: {len(test_ds)}')
```

```
Samples in train dataset: 80
Samples in test dataset: 20
```
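`random_split` does not copy any data: each returned part is a `Subset` view on the original dataset. The sketch below checks this, using a stand-in `StringDataset` with deterministic strings instead of the random ones above (class name and seed are assumptions for illustration). It also verifies that the train and test indices are disjoint and together cover all 100 samples.

```python
from torch import Generator
from torch.utils.data import Dataset, Subset, random_split


class StringDataset(Dataset):
    """Stand-in for the tutorial's MyDataset: 100 deterministic strings."""

    def __init__(self, num_samples: int = 100) -> None:
        self.data = [f"sample-{i:03d}" for i in range(num_samples)]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> str:
        return self.data[idx]


ds = StringDataset()
train_ds, test_ds = random_split(ds, (0.8, 0.2), generator=Generator().manual_seed(0))

# Both parts are Subset views on the same underlying dataset object
print(isinstance(train_ds, Subset), train_ds.dataset is ds)  # → True True

# The index sets are disjoint and together cover every sample exactly once
train_idx, test_idx = set(train_ds.indices), set(test_ds.indices)
print(len(train_idx & test_idx), len(train_idx | test_idx))  # → 0 100

# Indexing a Subset maps its local index to a position in the original dataset
print(train_ds[0] == ds[train_ds.indices[0]])  # → True
```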


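Each call to `random_split` draws a new random permutation of the indices, so repeating
the split yields a different partition every time:

```python
for i in range(5):
    train_ds, test_ds = random_split(dataset=ds, lengths=(0.8, 0.2))
    print(f'({i}): {train_ds[0]} {train_ds[42]} {train_ds[77]}')
```

```
(0): 8F4CAF F8258E A9E40B
(1): 7147C9 C1B6D4 0050D4
(2): 7B30DC 8F4CAF 732D9A
(3): A1C6C8 E6F0CE 768AD0
(4): 0050D4 4B402F F3B2D3
```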


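To make the split reproducible, we pass a `torch.Generator` with a fixed seed via the
`generator` argument. Since a fresh generator is created and seeded in every iteration,
all five splits are identical:

```python
from torch import Generator

for i in range(5):
    gen = Generator().manual_seed(1337)
    train_ds, test_ds = random_split(ds, (0.8, 0.2), generator=gen)
    print(f'({i}): {train_ds[0]} {train_ds[42]} {train_ds[77]}')
```

```
(0): 9D4892 A1C6C8 F5F7AA
(1): 9D4892 A1C6C8 F5F7AA
(2): 9D4892 A1C6C8 F5F7AA
(3): 9D4892 A1C6C8 F5F7AA
(4): 9D4892 A1C6C8 F5F7AA
```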


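If the generator is seeded only once, outside the loop, its state advances with every
call. The splits then differ from each other, but the whole sequence is reproducible,
and the first split matches the seeded split from the previous cell:

```python
gen = Generator().manual_seed(1337)

for i in range(5):
    train_ds, test_ds = random_split(ds, (0.8, 0.2), generator=gen)
    print(f'({i}): {train_ds[0]} {train_ds[42]} {train_ds[77]}')
```

```
(0): 9D4892 A1C6C8 F5F7AA
(1): F4694A F5F7AA 55C342
(2): DAB541 B7EF5B 1ED98E
(3): 440303 F4694A 0C7EF7
(4): EC4CB9 B7EF5B DAB541
```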
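The behavior of a generator that is seeded once can be checked directly: within one run the splits differ, yet re-seeding reproduces the entire sequence. This sketch uses a `TensorDataset` over `torch.arange(100)` as a stand-in for `MyDataset` and a hypothetical helper `split_sequence` that compares index lists rather than printed samples.

```python
import torch
from torch.utils.data import TensorDataset, random_split

ds = TensorDataset(torch.arange(100))  # stand-in dataset with 100 samples


def split_sequence(seed: int, repeats: int = 5):
    """Hypothetical helper: seed one generator, then draw several splits from it."""
    gen = torch.Generator().manual_seed(seed)
    return [random_split(ds, (0.8, 0.2), generator=gen)[0].indices for _ in range(repeats)]


first_run = split_sequence(1337)
second_run = split_sequence(1337)

# Within one run the splits differ, because the generator state advances ...
print(first_run[0] == first_run[1])  # → False
# ... but re-seeding reproduces the whole sequence of splits
print(first_run == second_run)  # → True
```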


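Alternatively, we can save the generator state with `get_state()` and restore it with
`set_state()` after each split. This resets the generator, so every iteration again
produces the same split:

```python
gen = Generator().manual_seed(1337)
gen_state = gen.get_state()

for i in range(5):
    train_ds, test_ds = random_split(ds, (0.8, 0.2), generator=gen)
    gen.set_state(gen_state)
    print(f'({i}): {train_ds[0]} {train_ds[42]} {train_ds[77]}')
```

```
(0): 9D4892 A1C6C8 F5F7AA
(1): 9D4892 A1C6C8 F5F7AA
(2): 9D4892 A1C6C8 F5F7AA
(3): 9D4892 A1C6C8 F5F7AA
(4): 9D4892 A1C6C8 F5F7AA
```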


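`random_split` is not limited to two partitions. For example, a train/validation/test
split with a 60/20/20 ratio looks like this:

```python
train_ds, val_ds, test_ds = random_split(dataset=ds, lengths=(0.6, 0.2, 0.2))
```

Each returned object is a `Subset`, which stores the positions of its samples in the
original dataset in its `indices` attribute:

```python
train_idxs = train_ds.indices
print(train_idxs)
```

```
[0, 46, 24, 15, 95, 25, 28, 26, 17, 39, 9, 14, 11, 79, 56, 6, 59, 18, 3, 58, 83, 89, 74, 4, 43, 75, 80, 52, 60, 68, 65, 98, 38, 33, 86, 35, 47, 91, 22, 82, 21, 84, 61, 30, 70, 31, 20, 99, 8, 94, 12, 63, 36, 19, 73, 2, 66, 42, 77, 69]
```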
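Besides fractions, `lengths` also accepts absolute sample counts that must add up to the dataset size. A minimal sketch, again with a `TensorDataset` stand-in and an arbitrary seed, comparing both forms for the 60/20/20 split: since 60, 20, and 20 are exactly the sizes the fractions produce for 100 samples, the same seed yields identical index lists.

```python
import torch
from torch.utils.data import TensorDataset, random_split

ds = TensorDataset(torch.arange(100))  # stand-in dataset with 100 samples

# Absolute sample counts (must sum to len(ds)) ...
splits_abs = random_split(ds, [60, 20, 20], generator=torch.Generator().manual_seed(0))
# ... and fractions (must sum to 1) describe the same partition sizes here
splits_frac = random_split(ds, [0.6, 0.2, 0.2], generator=torch.Generator().manual_seed(0))

print([len(s) for s in splits_abs], [len(s) for s in splits_frac])  # → [60, 20, 20] [60, 20, 20]

# With the same seed, both calls select exactly the same indices
print(all(a.indices == b.indices for a, b in zip(splits_abs, splits_frac)))  # → True
```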


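Conversely, a `Subset` can be created directly from a dataset and a list of indices.
This allows us to recreate a previously drawn split exactly, for example after saving
the indices to disk:

```python
from torch.utils.data import Subset

train_ds = Subset(dataset=ds, indices=train_idxs)
```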
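One way to make a split persistent is to write the index lists to a file and rebuild the subsets from it later. The sketch below assumes JSON as the storage format and a hypothetical file `splits.json` in a temporary directory; the dataset is again a `TensorDataset` stand-in.

```python
import json
import os
import tempfile

import torch
from torch.utils.data import Subset, TensorDataset, random_split

ds = TensorDataset(torch.arange(100))  # stand-in dataset with 100 samples
train_ds, val_ds, test_ds = random_split(
    ds, (0.6, 0.2, 0.2), generator=torch.Generator().manual_seed(1337)
)

# Store the index lists of all partitions (hypothetical file name "splits.json")
split_file = os.path.join(tempfile.mkdtemp(), "splits.json")
with open(split_file, "w") as f:
    json.dump({"train": train_ds.indices, "val": val_ds.indices, "test": test_ds.indices}, f)

# Later, e.g. in another script: rebuild identical subsets without calling random_split
with open(split_file) as f:
    stored = json.load(f)
restored_train = Subset(ds, stored["train"])

print(restored_train.indices == train_ds.indices)  # → True
```

Storing indices is only meaningful as long as the order of samples in the underlying dataset does not change between runs.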

## External Splitting

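With external splitting, the partitioning happens before any dataset object is created.
Instead of subsetting one dataset that contains all samples, we split the list of sample
identifiers (`keys`, e.g. file names) outside the dataset and create a separate dataset
instance for each partition. The dataset below only maps an index to a key and loads the
corresponding item from the data location `root`; the actual loading logic in `load_item`
depends on the data format and is left open here.

```python
class MyDataset(Dataset):
    def __init__(self, keys: List[str], root: str) -> None:
        self.root = root  # location of the data, e.g. a directory
        self.keys = keys  # identifiers of the samples in this partition only

    def load_item(self, key: str) -> Any:
        # Load and return the sample identified by `key` from `self.root`
        raise NotImplementedError

    def __len__(self) -> int:
        return len(self.keys)

    def __getitem__(self, idx: int) -> Any:
        key = self.keys[idx]
        return self.load_item(key)
```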
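A minimal sketch of the external workflow, assuming hypothetical image file names as keys and a `KeyDataset` whose `load_item` merely returns the file path instead of reading the file. The keys are shuffled with a seeded `random.Random` from the standard library (used here in place of `random_split`), split 80/20, and each partition gets its own dataset instance.

```python
import os
import random
from typing import List

from torch.utils.data import Dataset


class KeyDataset(Dataset):
    """Variant of the template above; load_item only builds the path (hypothetical)."""

    def __init__(self, keys: List[str], root: str) -> None:
        self.root = root
        self.keys = keys

    def load_item(self, key: str) -> str:
        # A real implementation would read and decode the file; here we return its path
        return os.path.join(self.root, key)

    def __len__(self) -> int:
        return len(self.keys)

    def __getitem__(self, idx: int) -> str:
        return self.load_item(self.keys[idx])


# Hypothetical sample identifiers, e.g. image file names
all_keys = [f"img_{i:03d}.png" for i in range(100)]

# Split the keys outside the dataset with a seeded shuffle (80/20)
keys = all_keys.copy()
random.Random(1337).shuffle(keys)
n_train = int(0.8 * len(keys))
train_keys, test_keys = keys[:n_train], keys[n_train:]

# One independent dataset object per partition
train_ds = KeyDataset(train_keys, root="data")
test_ds = KeyDataset(test_keys, root="data")
print(len(train_ds), len(test_ds))  # → 80 20
```

Because the split lives in the plain key lists, they can be saved alongside the data (for example as text files) to document exactly which samples belong to which partition.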