In [1]:
import torch
import torch.nn as nn

from torch.utils.data import DataLoader, Dataset, TensorDataset, ConcatDataset


### torch.utils.data
Provides the building blocks for:
- Representing datasets
- Wrapping datasets
- Loading them in batches, with shuffling and multiprocessing
- Sampling strategies

Putting it Together (Hierarchy + Differences)

- Dataset vs IterableDataset:

    Dataset (map-style) → random access by index via `__getitem__` (usually together with `__len__`).<br>
    IterableDataset → stream-like; implements only `__iter__`.

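- Subset, ConcatDataset, ChainDataset:
    Wrappers around existing datasets: Subset selects a set of indices (slice), ConcatDataset joins map-style datasets end to end (merge), ChainDataset chains IterableDatasets one after another.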

- Samplers:
    Decide the order (and, for weighted sampling, the probability) in which indices are drawn.
    SequentialSampler, RandomSampler, WeightedRandomSampler, SubsetRandomSampler, etc.

- BatchSampler:
    Groups sampled indices into batches.

- DataLoader:
    Orchestrates everything: takes dataset (or iterable), sampler, batching, workers, pinning, etc.

Mental Model 🧠

Think of it like a data factory pipeline:
1. Dataset/IterableDataset = the raw materials.

2. Sampler = decides which pieces to pick.

3. BatchSampler = packs them into boxes.

4. DataLoader = the delivery truck (runs worker processes, collates samples into batches, pins memory; `shuffle=True` simply makes it use a RandomSampler).

5. Wrappers (Subset, ConcatDataset, etc.) = tools to rearrange/resize your raw material.

In [8]:
# 1. Dataset: It defines the interface for all datasets in PyTorch.
# Super flexible → you can load anything (images, CSV rows, text files, tensors, API calls, etc).
class customDataset(Dataset):
    """Toy map-style dataset of random features and binary labels.

    Args:
        n_samples: number of samples (rows) to generate. Defaults to 10.
        n_features: number of features per sample. Defaults to 3.
        device: device the tensors are moved to. ``None`` selects ``"mps"``
            when Apple's MPS backend is available and falls back to ``"cpu"``
            otherwise, so the notebook also runs on non-Apple machines.
    """

    def __init__(self, n_samples=10, n_features=3, device=None):
        if device is None:
            device = "mps" if torch.backends.mps.is_available() else "cpu"
        # NOTE: keeping data on an accelerator inside a Dataset is usually
        # avoided when a DataLoader uses num_workers > 0; the common pattern
        # is to move each batch to the device inside the training loop.
        self.X = torch.randn((n_samples, n_features)).to(device)
        self.y = torch.randint(0, 2, (n_samples,)).to(device)  # labels: 0, 1

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair at ``index``."""
        return (self.X[index], self.y[index])

In [9]:
a = customDataset()
a[3]  # indexing with [] calls __getitem__; the bare last expression is displayed

(tensor([ 0.4090,  0.7271, -1.0578], device='mps:0'), tensor(0, device='mps:0'))


In [13]:
# 2. TensorDataset:
'''
A ready-made implementation of Dataset.
Stores tensors of the same length and returns them as tuples.
You don’t have to subclass anything. It just pairs tensors sample-by-sample.
'''

class myDataset():
    """Holds random features/labels and wraps them in a ``TensorDataset``.

    Args:
        n_samples: number of samples (rows) to generate. Defaults to 10.
        n_features: number of features per sample. Defaults to 3.
        device: device the tensors are moved to. ``None`` selects ``"mps"``
            when Apple's MPS backend is available and falls back to ``"cpu"``
            otherwise.
    """

    def __init__(self, n_samples=10, n_features=3, device=None):
        if device is None:
            device = "mps" if torch.backends.mps.is_available() else "cpu"
        self.X = torch.randn((n_samples, n_features)).to(device)
        self.y = torch.randint(0, 2, (n_samples,)).to(device)  # labels: 0, 1

    def couple_data_labels(self):
        """Pair features and labels sample-by-sample.

        Returns:
            A ``TensorDataset`` whose item ``i`` is the tuple ``(self.X[i], self.y[i])``.
        """
        return TensorDataset(self.X, self.y)

    def batch_size(self):
        """Return the feature shape of a single sample, e.g. ``torch.Size([3])``.

        Despite the name, this is not a DataLoader batch size. Indexing a
        ``TensorDataset`` yields an ``(x, y)`` tuple, so the features are
        unpacked before reading ``.shape`` (a tuple has no ``.shape``).
        """
        features, _label = self.couple_data_labels()[0]
        return features.shape

    