#Optimizing Model Training with PyTorch Dataset

#Efficient Data Handling with Datasets and DataLoaders

PyTorch provides the **torch.utils.data** module, whose Dataset and DataLoader classes simplify data loading.

A **Dataset** object is the first argument of the DataLoader constructor and specifies the data to load from. A custom map-style dataset subclasses Dataset and implements `__len__` and `__getitem__`.

Meanwhile, **DataLoader** lets you iterate over the dataset in batches and provides built-in support for shuffling, multiprocess loading (via `num_workers`), and more.

In [5]:
import torch
from torch.utils.data import Dataset, DataLoader

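class MyDataset(Dataset):
    """Toy map-style dataset: three 2-feature samples with binary labels."""

    def __init__(self):
        self.data = torch.tensor([[1, 2], [3, 4], [5, 6]])
        self.labels = torch.tensor([0, 1, 0])

    def __len__(self):
        # Number of samples; DataLoader uses this to build index batches.
        return len(self.data)

    def __getitem__(self, idx):
        # Return one (sample, label) pair for the given index.
        return self.data[idx], self.labels[idx]

dataset = MyDataset()
# batch_size=2 over 3 samples yields one full batch and one batch of size 1;
# shuffle=True reorders the samples at the start of every epoch.
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

for batch in dataloader:
    # Each batch is a (data, labels) pair of stacked tensors.
    print("Batch Data:\n ", batch[0])
    print("Batch Labels: ", batch[1])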

Batch Data:
  tensor([[3, 4],
        [1, 2]])
Batch Labels:  tensor([1, 0])
Batch Data:
  tensor([[5, 6]])
Batch Labels:  tensor([0])
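
The output above ends with a batch of a single sample because 3 samples do not divide evenly by a batch size of 2. The following is a minimal sketch of how this could be controlled, assuming the same toy values as above. The names `features`, `targets`, `toy_dataset`, `full_loader`, and `trimmed_loader` are illustrative and not part of the original notebook. It uses `TensorDataset` as a shortcut for wrapping tensors, a seeded `torch.Generator` to make the shuffle order repeatable, and `drop_last=True` to discard the incomplete final batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Same toy values as the example above (variable names are illustrative).
features = torch.tensor([[1, 2], [3, 4], [5, 6]])
targets = torch.tensor([0, 1, 0])

# TensorDataset wraps tensors that share a first dimension, so no custom class is needed.
toy_dataset = TensorDataset(features, targets)

# A seeded generator makes the shuffle order repeatable across runs.
generator = torch.Generator().manual_seed(0)
full_loader = DataLoader(toy_dataset, batch_size=2, shuffle=True, generator=generator)

# drop_last=True discards the final batch when it has fewer than batch_size samples.
trimmed_loader = DataLoader(toy_dataset, batch_size=2, shuffle=False, drop_last=True)

print("Batches with the partial batch kept:", len(full_loader))
print("Batches with drop_last=True:", len(trimmed_loader))

for batch_features, batch_targets in trimmed_loader:
    print(batch_features.shape, batch_targets.shape)
```

Dropping the last batch can be convenient when a model or loss expects a fixed batch size, at the cost of skipping a few samples per epoch; whether that trade-off is acceptable depends on the dataset size.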


#Enhancing Data Diversity through Augmentation

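Torchvision provides simple tools for applying random transformations to training images. Because the model sees a slightly different version of each image every epoch, it tends to generalize better to unseen data.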

In [6]:
import torchvision.transforms as transforms
from PIL import Image

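# Load the image as a PIL object (path from the original Colab session).
image = Image.open('/content/doraemon.jpg')

# Compose chains transforms and applies them in order.
transform = transforms.Compose([
    # Mirrors the image left-to-right with probability 0.5 (the default).
    transforms.RandomHorizontalFlip(),
    # Converts the PIL image to a float tensor of shape (C, H, W) scaled to [0, 1].
    transforms.ToTensor()
])

augmented_image = transform(image)
print("Augmented Image Shape:", augmented_image.shape)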

Augmented Image Shape: torch.Size([3, 1829, 1920])
