In [3]:
import torch
from torch.utils.data import Dataset, DataLoader, Subset

import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from sklearn.model_selection import train_test_split
import numpy as np

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset that materialises another dataset in memory.

    The wrapped ``dataset`` is iterated once at construction time and every
    ``(image, label)`` pair it yields is stored in two parallel lists,
    ``self.data`` and ``self.targets``. The optional ``transform`` is applied
    lazily, each time a sample is fetched.
    """

    def __init__(self, dataset, transform=None):
        """Store all samples of ``dataset`` and remember ``transform``.

        Args:
            dataset: Iterable yielding 2-item ``(image, label)`` pairs.
            transform: Optional callable applied to the image on access.
        """
        images, labels = self.split_data_with_targets(dataset)
        self.data = images
        self.targets = labels
        self.transform = transform

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label_tensor)`` for the sample at ``index``.

        The image is passed through ``self.transform`` when one was given;
        the label is wrapped in a one-element ``torch.FloatTensor``.
        """
        raw_image = self.data[index]
        raw_label = self.targets[index]

        # Only transform when a callable was supplied at construction.
        sample = raw_image if self.transform is None else self.transform(raw_image)

        return sample, torch.FloatTensor([raw_label])

    def split_data_with_targets(self, dataset):
        """Split an iterable of ``(image, label)`` pairs into two lists.

        Returns:
            A tuple ``(images, labels)`` of equal-length lists, in the order
            the pairs were yielded by ``dataset``.
        """
        # Walk the dataset exactly once; each item must unpack into two values.
        pairs = [(image, label) for image, label in dataset]
        images = [image for image, _ in pairs]
        labels = [label for _, label in pairs]
        return images, labels

In [None]:
# Preprocessing pipeline applied to each image when it is fetched:
#   1. resize to 224x224,
#   2. convert to a tensor,
#   3. normalise each channel with the given mean / std
#      (these values match the commonly used ImageNet statistics).
transformations = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Read the images from the "dataset" directory (no transform at this stage),
# then wrap them so the transform above is applied on access.
dataset = datasets.ImageFolder("dataset")
custom_dataset = CustomDataset(dataset, transform=transformations)

1. Stratified sampling (train/validation split that preserves class proportions)

In [None]:
# Stratified 90/10 split of sample indices: passing the labels via `stratify`
# keeps the class proportions the same in the train and validation parts.
# `random_state` is fixed so the split is reproducible.
train_idx, validation_idx = train_test_split(
    np.arange(len(custom_dataset)),
    stratify=custom_dataset.targets,
    test_size=0.1,
    shuffle=True,
    random_state=1,
)

# Index-based views onto the same underlying dataset for each part.
training_dataset, validation_dataset = (
    Subset(custom_dataset, train_idx),
    Subset(custom_dataset, validation_idx),
)

2. Random shuffle sampling (no stratification; this cell reassigns the same split variables as option 1)

In [None]:
# Plain random 90/10 split of sample indices: indices are shuffled but no
# `stratify` argument is given, so class proportions are not enforced.
# `random_state` is fixed so the split is reproducible.
train_idx, validation_idx = train_test_split(
    np.arange(len(custom_dataset)),
    test_size=0.1,
    shuffle=True,
    random_state=1,
)

# Index-based views onto the same underlying dataset for each part.
training_dataset, validation_dataset = (
    Subset(custom_dataset, train_idx),
    Subset(custom_dataset, validation_idx),
)