In [5]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset

In [2]:
# Preprocessing pipeline applied to every image: convert it to a tensor,
# then normalise with mean 0.5 and std 0.5 (single-element tuples).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [3]:
# CIFAR-10 is stored under ./data and downloaded on first use.
# `train=True` requests the training split, `train=False` the test split.
train_datasets = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
test_datasets = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170M/170M [10:50<00:00, 262kB/s]    


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [16]:
# Implement a custom Data Set

class ExampleDataSet(Dataset):
    """Map-style dataset that pairs in-memory samples with their labels.

    Args:
        data: Indexable collection of samples supporting ``len()``. Each
            sample must be convertible to a ``float32`` tensor by
            ``torch.tensor``.
        labels: Indexable collection supporting ``len()`` with one label per
            sample. Each label must be convertible to an ``int64`` tensor.

    Raises:
        ValueError: If ``data`` and ``labels`` do not have the same length.
    """

    def __init__(self, data, labels):
        # Fail early: a mismatch would otherwise only surface later as an
        # IndexError (too few labels) or as silently unused labels (too many).
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length, "
                f"got {len(data)} and {len(labels)}"
            )
        self.data = data
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` at ``idx`` as ``(float32, int64)`` tensors."""
        sample = self.data[idx]
        label = self.labels[idx]
        return (
            torch.tensor(sample, dtype=torch.float32),
            torch.tensor(label, dtype=torch.long),
        )



In [14]:
# Let's implement a custom data loader for educational purposes
import random

from torch.utils.data.dataloader import _BaseDataLoaderIter
class CustomDataLoader(DataLoader):
    """Minimal single-process batch loader, written for educational purposes.

    Iterating yields batches as plain lists of ``(sample, label)`` tuples;
    samples are not collated into stacked tensors. The final batch may hold
    fewer than ``batch_size`` items.

    NOTE: ``DataLoader.__init__`` is not called, so the parent's samplers,
    collation and worker machinery are not configured. Iteration and length
    are fully provided by this class.

    Args:
        dataset: Indexable object supporting ``len()`` whose items unpack
            into a ``(sample, label)`` pair.
        shuffle: If True, the visiting order is randomised with the
            ``random`` module at construction and again at the start of
            every pass over the data.
        batch_size: Number of items per batch; must be a positive integer.
        num_workers: Stored only; loading always happens in the calling
            process.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer.
    """

    def __init__(self, dataset, shuffle=False, batch_size=1, num_workers=0):
        # A non-positive batch size would never produce a "full" batch and
        # would make the length computation divide by zero.
        if not isinstance(batch_size, int) or batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.indices = list(range(len(dataset)))

        if self.shuffle:
            self._shuffle_indices()

    def _shuffle_indices(self):
        """Shuffle the stored index order in place."""
        random.shuffle(self.indices)

    def __iter__(self):
        """Yield successive batches as lists of ``(sample, label)`` tuples."""
        # Reshuffle on every pass so each epoch sees a different order.
        if self.shuffle:
            self._shuffle_indices()

        # Snapshot the order so a later reshuffle (e.g. from a second,
        # concurrently consumed iterator) cannot disturb this pass.
        order = list(self.indices)

        for start in range(0, len(order), self.batch_size):
            batch = []
            for idx in order[start:start + self.batch_size]:
                sample, label = self.dataset[idx]
                batch.append((sample, label))
            # The last slice may be shorter than batch_size; it is still yielded.
            yield batch

    def __len__(self):
        """Return the number of batches per pass, counting a trailing partial batch."""
        # Ceiling division keeps this consistent with __iter__.
        return (len(self.indices) + self.batch_size - 1) // self.batch_size

In [20]:
# Toy dataset: five two-feature samples with alternating binary labels.
data = [
    [1.0, 2.0],
    [3.0, 4.0],
    [5.0, 6.0],
    [7.0, 8.0],
    [9.0, 10.0],
]
labels = [0, 1, 0, 1, 0]

# Wrap the raw lists in the custom dataset, then batch them two at a time
# in shuffled order.
daset = ExampleDataSet(data, labels)
loader = CustomDataLoader(daset, batch_size=2, shuffle=True, num_workers=1)

In [21]:
# Walk the loader once and print every batch. Each batch is a list of
# (sample, label) tuples; the last one may hold fewer than batch_size items.
for batch in loader:
    print(batch)

[(tensor([1., 2.]), tensor(0)), (tensor([5., 6.]), tensor(0))]
[(tensor([3., 4.]), tensor(1)), (tensor([7., 8.]), tensor(1))]
[(tensor([ 9., 10.]), tensor(0))]
