# Imports

This notebook uses `matplotlib` for visualization. Install it first if it is missing:

In [None]:
# install using pip (comment kept on its own line so it cannot end up in the pip arguments)
%pip install matplotlib
# or install using conda:
# %conda install matplotlib

In [1]:
import torch
from torch.utils.data import Dataset
from matplotlib import pyplot as plt

# Data Handling

We now look at a two-class classification problem on a generated toy dataset.\
The data is stored in a PyTorch `Dataset` subclass. On its own, the subclass only exposes the number of samples and access by index; paired with a `DataLoader`, it gives you utilities such as automatic shuffling and batching for the training loop.

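Each class is drawn from a 2D normal distribution with standard deviation 1. Class A (label 0) has mean 1 in both coordinates, and class B (label 1) has mean -1 in both coordinates.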

In [None]:


class TwoClassDataset(Dataset):
    """Toy dataset: two 2D Gaussian blobs with 100 samples each."""

    def __init__(self):
        # Class A (label 0): 100 points drawn around (1, 1)
        features_a = torch.normal(mean=1.0, std=1.0, size=(100, 2))
        labels_a = torch.zeros(100, dtype=torch.long)
        # Class B (label 1): 100 points drawn around (-1, -1)
        features_b = torch.normal(mean=-1.0, std=1.0, size=(100, 2))
        labels_b = torch.ones(100, dtype=torch.long)

        self.data = torch.cat((features_a, features_b), dim=0)  # X: 200 x 2 matrix, columns [x1, x2]
        self.labels = torch.cat((labels_a, labels_b), dim=0)  # Y: vector of length 200; 0 = class A, 1 = class B

    def __len__(self):
        # number of samples in the dataset
        return self.data.size(dim=0)

    def __getitem__(self, index):
        # one (features, label) pair
        return self.data[index], self.labels[index]


dataset = TwoClassDataset()

Plotting the 2D dataset shows how the two classes are distributed and how much they overlap:

In [None]:
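# plot both classes, colored by label
fig, ax = plt.subplots()
scatter = ax.scatter(dataset.data[:, 0], dataset.data[:, 1], c=dataset.labels, cmap='Paired')
# ax.legend already attaches the legend to the axes, so no extra add_artist call is needed
handles, _ = scatter.legend_elements()
ax.legend(handles, ["A (label 0)", "B (label 1)"], title="Classes")
ax.set_xlabel("x1")
ax.set_ylabel("x2")
plt.show()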
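
As a numeric counterpart to the plot, the per-class sample means can be checked directly. The following is a minimal sketch that regenerates a stand-in for `dataset.data` and `dataset.labels` so that it runs on its own; the seed value and the names `data`, `labels`, `mean_a` and `mean_b` are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed, only to make the sketch repeatable

# Stand-in with the same layout as the dataset above: 200 x 2 features, labels 0/1
data = torch.cat((torch.normal(mean=1.0, std=1.0, size=(100, 2)),
                  torch.normal(mean=-1.0, std=1.0, size=(100, 2))), dim=0)
labels = torch.cat((torch.zeros(100, dtype=torch.long),
                    torch.ones(100, dtype=torch.long)), dim=0)

# A boolean mask selects the rows of one class; the mean is taken per column
mean_a = data[labels == 0].mean(dim=0)
mean_b = data[labels == 1].mean(dim=0)

print("class A mean:", mean_a)  # expected to lie close to (1, 1)
print("class B mean:", mean_b)  # expected to lie close to (-1, -1)
```

In the notebook itself, the same masks should work on `dataset.data` and `dataset.labels`. With only 100 samples per class, the sample means will deviate slightly from the true means.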

In [None]:
print("first sample (input, label):", dataset[0])
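
The shuffling and batching mentioned at the start of this section come from wrapping the dataset in a `DataLoader`. The sketch below illustrates this under a few assumptions: it uses a `TensorDataset` stand-in with the same shapes so that it runs on its own, and the batch size of 32, the seed and the name `stand_in` are chosen only for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # assumption: fixed seed so the shuffle order is repeatable

# Stand-in with the same layout as TwoClassDataset: 200 samples, 2 features, labels 0/1
features = torch.cat((torch.normal(mean=1.0, std=1.0, size=(100, 2)),
                      torch.normal(mean=-1.0, std=1.0, size=(100, 2))), dim=0)
labels = torch.cat((torch.zeros(100, dtype=torch.long),
                    torch.ones(100, dtype=torch.long)), dim=0)
stand_in = TensorDataset(features, labels)

# shuffle=True draws a new sample order each epoch; batch_size groups samples into tensors
loader = DataLoader(stand_in, batch_size=32, shuffle=True)

batch_shapes = []
for batch_features, batch_labels in loader:
    batch_shapes.append((tuple(batch_features.shape), tuple(batch_labels.shape)))

print(len(loader), "batches per epoch")   # 200 samples / 32 per batch, rounded up
print("first batch:", batch_shapes[0])
print("last batch:", batch_shapes[-1])    # the last batch holds the 8 leftover samples
```

In the notebook, `DataLoader(dataset, batch_size=32, shuffle=True)` should behave the same way, since `TwoClassDataset` implements `__len__` and `__getitem__`. Without shuffling, the first 100 samples would all belong to class A, so early batches would contain a single class.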