In [None]:
from datasets import load_dataset
import torch
import torchvision.transforms as T
from torch.utils.data import IterableDataset, DataLoader


# Stream the ESA Hubble training split instead of downloading it up front;
# with streaming=True, samples are fetched lazily while iterating.
dataset = load_dataset(
    "Supermaxman/esa-hubble",
    split="train",
    cache_dir="./data/hubble",
    streaming=True,
)


# shuffle() does not modify the dataset in place: it returns a new shuffled
# dataset, so the result has to be rebound or the call has no effect.
dataset = dataset.shuffle(seed=42)
# Per-image preprocessing for the Hubble samples. Despite the name, nothing
# is resized: each image is converted to a tensor and a random 128x128 patch
# is cropped from it. The final Normalize uses mean 0 / std 1, i.e. it leaves
# the values unchanged and only serves as a placeholder for real statistics.
resize_transform = T.Compose(
    [
        T.ToTensor(),
        T.RandomCrop(128),
        T.Normalize(mean=(0, 0, 0), std=(1, 1, 1), inplace=False),
    ]
)

def collate(batch):
    """Collate a list of dataset samples into one stacked image tensor.

    Every sample's ``"image"`` entry is run through the module-level
    ``resize_transform`` and the results are stacked along a new leading
    batch dimension. Only the images are returned; any other fields of the
    samples are dropped.
    """
    return torch.stack([resize_transform(item["image"]) for item in batch])

# One sample per batch, loaded in the main process (num_workers=0); the
# custom collate function applies the per-image transform and stacks.
dataloader = DataLoader(
    dataset,
    collate_fn=collate,
    batch_size=1,
    num_workers=0,
)

# Example usage: start two separate passes over the loader and, on each pass,
# print the first image of the first batch, then stop. Note that this prints
# the tensor values of a single image (images[0]), not the batch shape.

for j in range(2):
    for images in dataloader:
        print(images[0])
        break

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.transforms import functional as F

class FaceDataset(Dataset):
    """Map-style dataset of face images stored as ``root_dir/<subject>/<name>.png``.

    Only images are served; no subject label is returned with a sample.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Path to the directory with subject subfolders.
            transform (callable, optional): Optional transforms to be applied on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform

        # Gather image file paths from all subject directories. Both levels
        # are sorted so that the index -> file mapping is the same on every run.
        self.samples = []
        for subject in sorted(os.listdir(root_dir)):
            subject_dir = os.path.join(root_dir, subject)
            # Skip stray files sitting directly in root_dir.
            if not os.path.isdir(subject_dir):
                continue
            self.samples.extend(
                os.path.join(subject_dir, fname)
                for fname in sorted(os.listdir(subject_dir))
                if fname.endswith('.png')
            )

    def __len__(self):
        """Return the number of image files found under ``root_dir``."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load the image at ``index`` as RGB and apply the optional transform.

        Returns:
            The transformed sample, or the RGB PIL image when no transform
            was given.
        """
        img_path = self.samples[index]
        # Image.open keeps the file open until the pixel data is loaded;
        # convert() loads the data and returns an independent image, so the
        # file handle can be closed deterministically right afterwards.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        # Explicit None check: a callable transform is applied even if it
        # happens to evaluate as falsy (e.g. an empty container module).
        if self.transform is not None:
            image = self.transform(image)
        return image

# Preprocessing for the face images: resize to 128x128, then convert the PIL
# image to a tensor (ToTensor scales pixel values to [0, 1]).
# Normalization is currently disabled; the previously used settings are kept
# below for reference:
#     T.Normalize(mean=[0.485, 0.456, 0.406],
#                 std=[0.229, 0.224, 0.225])
transform = T.Compose(
    [
        T.Resize((128, 128)),
        T.ToTensor(),
    ]
)

# Build the face dataset and a shuffling loader that yields batches of 32
# images, loaded in the main process (num_workers=0).
dataset = FaceDataset(
    root_dir='data/faces/subjects_0-1999_72_imgs',
    transform=transform,
)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)

# Example usage: take only the first batch, report its shape
# ([batch_size, 3, 128, 128] given the RGB conversion and 128x128 resize)
# and display its first image in the notebook.
for images in dataloader:
    print(images.shape)
    display(F.to_pil_image(images[0]))
    break
    
