In [None]:
# Load the Adience dataset





In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import tarfile

class FacesDataset(Dataset):
    """Image/label dataset built from tab-separated fold listing files.

    Each fold file is read without a header row; column 0 is taken as an
    image path relative to ``<root_dir>/faces`` and column 1 as the label.
    NOTE(review): this two-column layout is an assumption carried over from
    the original code -- confirm it against the actual fold files.
    """

    def __init__(self, root_dir, fold_files, transform=None):
        """
        Args:
            root_dir (string): Directory with the extracted faces data
            fold_files (list): List of paths to fold data txt files
            transform (callable, optional): Optional transform to be applied on images
        """
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # Extract the tar.gz file if it hasn't been extracted yet
        faces_dir = os.path.join(root_dir, "faces")
        if not os.path.exists(faces_dir):
            tar_path = os.path.join(root_dir, "faces.tar.gz")
            if os.path.exists(tar_path):
                print(f"Extracting {tar_path}...")
                with tarfile.open(tar_path, "r:gz") as tar:
                    # Use the "data" extraction filter when this Python
                    # provides it, so archive members cannot write outside
                    # root_dir (path traversal, absolute paths, bad links).
                    if hasattr(tarfile, "data_filter"):
                        tar.extractall(path=root_dir, filter="data")
                    else:
                        tar.extractall(path=root_dir)
                print("Extraction complete!")

        # Load all fold data files
        for fold_file in fold_files:
            fold_data = pd.read_csv(fold_file, sep='\t', header=None)
            # Walk the two columns in lockstep instead of building a Series
            # per row with iterrows().
            self.samples.extend(
                (os.path.join(faces_dir, rel_path), label)
                for rel_path, label in zip(fold_data[0], fold_data[1])
            )

    def __len__(self):
        """Return the number of (image path, label) samples collected."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is opened with PIL, converted to RGB and passed through
        ``self.transform`` when one was given. If the file cannot be opened
        or decoded (``OSError``), the error is printed and a placeholder is
        returned instead: a zero tensor of shape (3, 224, 224) when a
        transform is set, otherwise a blank 224x224 RGB PIL image. Errors
        raised by the transform itself are not swallowed.
        """
        img_path, label = self.samples[idx]

        # Only the file load is best-effort; keeping the transform outside
        # the try block lets transform bugs surface instead of silently
        # turning every sample into a placeholder.
        try:
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')
        except OSError as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a placeholder image and the label
            placeholder = torch.zeros((3, 224, 224)) if self.transform else Image.new('RGB', (224, 224))
            return placeholder, label

        if self.transform:
            image = self.transform(image)

        return image, label

# Preprocessing pipeline: resize every image to 224x224, convert it to a
# tensor, then normalize each channel with the mean/std listed below
# (these match the commonly used ImageNet channel statistics).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Data lives under `root_dir`; there is one listing file per fold, for
# fold indices 0 through 4.
root_dir = "data"
fold_files = [
    os.path.join(root_dir, f"fold_{fold_idx}_data.txt")
    for fold_idx in range(5)
]

# Build the dataset over all folds and wrap it in a shuffling loader that
# yields batches of 32 using 4 worker processes.
dataset = FacesDataset(root_dir, fold_files, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

# Smoke test: pull a single batch, report its shapes, and stop.
for images, labels in dataloader:
    # Training/validation code would go here
    print(f"Batch shape: {images.shape}, Labels shape: {labels.shape}")
    break

# Cell output: KeyboardInterrupt