In [None]:
import os

import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader

In [None]:


class IllinoisFacesDataset(Dataset):
    """Map-style dataset pairing image files with per-row labels from a CSV.

    Each CSV row names one image file (in ``filename_col``, relative to
    ``img_dir``); the row's columns are returned as a label dict.

    Args:
        img_dir: Directory containing the image files.
        csv_file: Path or file-like object accepted by ``pandas.read_csv``.
        transform: Optional callable applied to the loaded RGB ``PIL.Image``.
        filename_col: Name of the CSV column holding image file names.
        label_cols: Optional subset of columns to return as the label dict.
            ``None`` (default) returns every column of the row, including
            ``filename_col``, which matches the previous behavior.

    Raises:
        ValueError: If ``filename_col`` or any entry of ``label_cols`` is not a
            column of the CSV.
    """

    def __init__(self, img_dir, csv_file, transform=None,
                 filename_col="filename", label_cols=None):
        self.img_dir = img_dir
        self.labels = pd.read_csv(csv_file)
        self.transform = transform
        self.filename_col = filename_col
        self.label_cols = None if label_cols is None else list(label_cols)

        # Fail fast on a schema mismatch here rather than on the first
        # __getitem__ call, which may run inside a DataLoader worker.
        required = [filename_col] + (self.label_cols or [])
        missing = [col for col in required if col not in self.labels.columns]
        if missing:
            raise ValueError(f"CSV {csv_file!r} is missing column(s): {missing}")

    def __len__(self):
        """Return the number of rows (images) in the label CSV."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load the image at row position ``idx`` and return ``(image, label)``.

        ``image`` is the RGB image after ``transform`` (if any); ``label`` is a
        dict of the selected columns for that row.
        """
        row = self.labels.iloc[idx]
        img_path = os.path.join(self.img_dir, row[self.filename_col])

        # The context manager closes the underlying file handle promptly;
        # convert() returns a new image object, so it stays usable afterwards.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Label columns (e.g. race, sex) as a plain dict.
        selected = row if self.label_cols is None else row[self.label_cols]
        label = selected.to_dict()

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [None]:
# Configuration — replace the placeholder paths with the real dataset location.
IMG_DIR = "path/to/images"
CSV_FILE = "path/to/labels.csv"
IMAGE_SIZE = (128, 128)
BATCH_SIZE = 32
SEED = 42

# Preprocessing: resize, convert to a float tensor in [0, 1], then map each of
# the three RGB channels to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

dataset = IllinoisFacesDataset(
    img_dir=IMG_DIR,
    csv_file=CSV_FILE,
    transform=transform,
)

# A seeded generator makes the shuffle order reproducible across kernel restarts.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
