In [41]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
from torchvision import transforms
from PIL import Image
import random
from pathlib import Path

## Dataset class

In [42]:
class CustomDataset(Dataset):
    """Map-style image dataset driven by ``<img_dir>/labels.txt``.

    Every non-blank line of ``labels.txt`` holds an image path followed by a
    label, separated by whitespace. The label is the last whitespace-separated
    token (so file names may contain spaces) and is converted with ``int()``
    when a sample is fetched.

    Args:
        img_dir: Directory that contains ``labels.txt``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
    """

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        # List of (file, label) string pairs, in the order they appear on disk.
        self.img_labels = [(file, label) for file, label in self._load_labels(img_dir)]

    def __len__(self):
        """Return the number of labelled entries."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for entry ``idx``; ``label`` is an int."""
        file, label = self.img_labels[idx]
        img_path = self._resolve_path(file)
        # convert() produces an independent in-memory image, so the file
        # handle opened by Image.open can be closed straight away.
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, int(label)

    def _resolve_path(self, file):
        """Resolve a path from ``labels.txt`` to something openable.

        Paths are looked up relative to ``img_dir`` first (absolute paths are
        left untouched by ``os.path.join``). If nothing exists there, the path
        is returned exactly as written so that label files with paths relative
        to the working directory keep working.
        """
        candidate = os.path.join(self.img_dir, file)
        return candidate if os.path.exists(candidate) else file

    def _load_labels(self, img_dir):
        """Parse ``labels.txt`` into a list of ``[file, label]`` pairs.

        Blank lines are skipped.

        Raises:
            ValueError: If a non-blank line does not contain both a path and
                a label.
        """
        labels_file = os.path.join(img_dir, 'labels.txt')
        labels = []
        with open(labels_file, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                entry = line.strip()
                if not entry:
                    continue
                # Split from the right so only the final token is the label.
                parts = entry.rsplit(None, 1)
                if len(parts) != 2:
                    raise ValueError(
                        f"{labels_file}:{line_no}: expected '<file> <label>', got {entry!r}"
                    )
                labels.append(parts)
        return labels

## Transforms

In [43]:
class RandomGreyscale(object):
    """Transform that converts an image to greyscale with probability ``p``.

    The converted image keeps three channels, so it can be fed to the same
    downstream transforms as an untouched RGB image.

    Args:
        p: Probability in ``[0, 1]`` that a given image is converted.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """Return ``img`` converted to 3-channel greyscale, or unchanged."""
        if random.random() < self.p:
            # torchvision spells the helper "to_grayscale"; the British
            # spelling does not exist and raises AttributeError.
            img = transforms.functional.to_grayscale(img, num_output_channels=3)
        return img

In [44]:
# Preprocessing pipeline for CustomDataset: fixed-size resize, occasional
# greyscale augmentation, tensor conversion, then per-channel normalisation
# with mean 0.5 and std 0.5.
custom_transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        RandomGreyscale(p=0.3),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

## DataLoader

In [45]:
# dataset = CustomDataset(img_dir="path_to_dir", transform=custom_transform)

# dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)

# for images, labels in dataloader:
#     # training loop goes here
#     pass

## IterableDataset class

In [47]:
class LargeImageDataset(IterableDataset):
    """Streaming image dataset that labels each image by its parent folder.

    Images are discovered recursively under ``root_dir``, so a layout such as
    ``root_dir/<class_name>/<image>`` yields ``<class_name>`` as the label.
    Images are decoded one at a time while iterating.

    When used with a multi-worker ``DataLoader`` each worker yields a disjoint
    slice of the files, so samples are not duplicated across workers.

    Args:
        root_dir: Directory to scan for images.
        transform: Optional callable applied to each RGB ``PIL.Image``.
    """

    # Lower-case suffixes accepted as images; matching is case-insensitive.
    IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

    def __iter__(self):
        """Yield ``(image, label)`` pairs for this process's share of files."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: this iterator owns every file.
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers

        for index, file_path in enumerate(self._get_image_files()):
            # Round-robin sharding; relies on _get_image_files returning the
            # same order in every worker.
            if index % num_workers != worker_id:
                continue
            image = self._load_image(file_path)
            label = self._get_label_from_path(file_path)
            yield image, label

    def _get_image_files(self):
        """Yield image paths under ``root_dir`` (recursive) in sorted order."""
        root_path = Path(self.root_dir)
        matches = (
            entry
            for entry in root_path.rglob('*')
            if entry.is_file() and entry.suffix.lower() in self.IMAGE_SUFFIXES
        )
        # Sorting gives a deterministic order that is identical in every
        # worker process.
        yield from sorted(matches)

    def _load_image(self, file_path):
        """Open ``file_path`` as RGB and apply ``transform`` if one was given."""
        # convert() returns an independent in-memory image, so the underlying
        # file handle can be released immediately.
        with Image.open(file_path) as handle:
            image = handle.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

    def _get_label_from_path(self, file_path):
        """Return the name of the directory that directly contains the file."""
        return file_path.parent.name

In [49]:
# Preprocessing pipeline for LargeImageDataset: fixed-size resize, tensor
# conversion, then per-channel normalisation with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# dataset = LargeImageDataset(root_dir='path/to/dataset', transform=transform)
# dataloader = DataLoader(dataset, batch_size=32, num_workers=4, pin_memory=True)

# for epoch in range(5):  # Assume 5 epochs
#     for batch_idx, (images, labels) in enumerate(dataloader):
#         pass