## 03 Datasets & DataLoaders

### Loading a Dataset

In [4]:
# Load the built-in FashionMNIST dataset from torchvision (downloaded into ./data if missing)

from torchvision import datasets
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor

# Build the training split first, then the test split. Both live under ./data
# (downloaded on first use) and every image is converted with ToTensor().
training_set, testing_set = (
    datasets.FashionMNIST(
        root='data',
        train=is_train,
        download=True,
        transform=ToTensor(),
    )
    for is_train in (True, False)
)

In [6]:
# Create a custom dataset
# annotations_file: a CSV file that stores each image's file path (relative to img_dir) and its label
import pandas as pd
import os
from torchvision.io import decode_image

class CustomImageDataset(Dataset):
    """Map-style dataset of images listed in a CSV annotations file.

    Each row of the annotations file holds an image file name in column 0,
    resolved relative to ``img_dir``, and its label in column 1.

    Args:
        annotations_file: Path to the CSV annotations file.
        img_dir: Directory that the image paths in column 0 are relative to.
        transform: Optional callable applied to each decoded image.
        target_transform: Optional callable applied to each label.
        header: Forwarded to ``pandas.read_csv``. The default ``"infer"``
            keeps the original behaviour (first row is treated as column
            names); pass ``None`` for a header-less file so that its first
            sample is not silently consumed as a header row.
    """

    def __init__(self, annotations_file, img_dir, transform=None,
                 target_transform=None, header="infer"):
        self.img_dir = img_dir
        self.img_labels = pd.read_csv(annotations_file, header=header)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of annotated samples (rows in the CSV)."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair described by row ``idx``."""
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = decode_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        # Compare against None rather than relying on truthiness: a callable
        # that also defines __len__/__bool__ could otherwise be skipped.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

### DataLoaders

In [7]:
# Initialize dataloaders
from torch.utils.data import DataLoader

# Mini-batch size shared by both loaders.
BATCH_SIZE = 64
# shuffle=True reshuffles the training data every epoch. The test loader keeps
# a fixed order: shuffling does not change evaluation results and only makes
# test batches (and any per-batch diagnostics) non-reproducible.
train_dataloader = DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(testing_set, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
# Iterate through the DataLoader
# Peek at just the first mini-batch: print its index, then the shapes of X
# and y, one value per line; stop before fetching any further batches.
for batch, (X, y) in enumerate(train_dataloader):
    print(batch, X.shape, y.shape, sep="\n")
    break

0
torch.Size([64, 1, 28, 28])
torch.Size([64])
