# Create a Custom Dataset of images with PyTorch

In [13]:
import os
import pandas as pd
import torch
from skimage import io

# Dataset base class:
from torch.utils.data import Dataset, DataLoader, random_split

## Create a child class using Dataset as base class

In [4]:
class CatsAndDogsDataset(Dataset):
    """Image dataset driven by an annotation table read from a CSV file.

    For every row of the parsed table, column 0 holds the image file name
    (joined onto ``root_dir``) and column 1 holds a label convertible to
    ``int``.

    Args:
        csv_file: Path of the annotation CSV, parsed with ``pd.read_csv``.
        root_dir: Directory that the file names in column 0 are relative to.
        transform: Optional callable applied to each loaded image.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load one ``(image, label)`` pair.

        Args:
            idx: Integer row position. A single-element tensor is accepted
                and converted to a plain Python number first.

        Returns:
            Tuple ``(image, y_label)`` where ``image`` is the result of
            ``io.imread`` (passed through ``transform`` when one was given)
            and ``y_label`` is a 0-dim tensor holding the integer label.
        """
        # Callers may index with a tensor; pandas' .iloc wants a plain
        # integer, so unwrap it up front.
        if torch.is_tensor(idx):
            idx = idx.item()

        file_name = self.annotations.iloc[idx, 0]
        img_path = os.path.join(self.root_dir, file_name)
        image = io.imread(img_path)
        y_label = torch.tensor(int(self.annotations.iloc[idx, 1]))

        # Compare against None explicitly: a callable that happens to be
        # falsy (e.g. an empty container module) must still be applied.
        if self.transform is not None:
            image = self.transform(image)

        return (image, y_label)

## Create the dataset, the train/test split, and the data loaders

In [5]:
# Build the full dataset from the annotation CSV and the image directory.
dataset = CatsAndDogsDataset(
    csv_file='dataset/cats_dogs.csv',
    root_dir='dataset/cats_dogs_resized',
)

In [11]:
# 70% / 30% train/test split. The sizes are derived from len(dataset) because
# random_split requires integer lengths that sum to the dataset size; for a
# 10-sample dataset this still gives 7 and 3.
n_train = len(dataset) * 7 // 10
n_test = len(dataset) - n_train

# Seeded generator so the same samples land in each subset on every run.
split_generator = torch.Generator().manual_seed(42)
train_set, test_set = random_split(dataset, [n_train, n_test], generator=split_generator)

In [14]:
batch_size = 64

# Training batches are reshuffled at every epoch.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# The test set keeps a fixed order so evaluation runs are comparable.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

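The dataset is ready: `train_loader` and `test_loader` now yield batches of `(image, label)` pairs. Images are read from disk only when a sample is requested.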