In [10]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import Dataset,Features,ClassLabel,Image

In [5]:
# Scan the image directory and parse the labels encoded in each file name.
# The first two underscore-separated fields of a name are read as the
# integer age and the integer gender code.
data_dir="data/UTKFace"
# os.listdir returns entries in arbitrary order; sort so the row order (and
# anything derived from it, such as index-based splits) is reproducible.
files=sorted(f for f in os.listdir(data_dir) if f.endswith(".jpg.chip.jpg"))
records=[]
for f in files:
    parts=f.split('_')
    try:
        age=int(parts[0])
        gender=int(parts[1])
    except (ValueError, IndexError):
        # Skip only names that do not start with "<int>_<int>"; any other
        # error (e.g. KeyboardInterrupt) is allowed to propagate instead of
        # being silently swallowed by a bare except.
        continue
    records.append([f,age,gender])

In [7]:
# Collect the parsed [filename, age, gender] records into one table.
df = pd.DataFrame(
    data=records,
    columns=['filename', 'age', 'gender'],
)
# Bare trailing expression: the notebook renders the frame as the cell output.
df

Unnamed: 0,filename,age,gender
0,100_0_0_20170112213500903.jpg.chip.jpg,100,0
1,100_0_0_20170112215240346.jpg.chip.jpg,100,0
2,100_1_0_20170110183726390.jpg.chip.jpg,100,1
3,100_1_0_20170112213001988.jpg.chip.jpg,100,1
4,100_1_0_20170112213303693.jpg.chip.jpg,100,1
...,...,...,...
23700,9_1_3_20161220222856346.jpg.chip.jpg,9,1
23701,9_1_3_20170104222949455.jpg.chip.jpg,9,1
23702,9_1_4_20170103200637399.jpg.chip.jpg,9,1
23703,9_1_4_20170103200814791.jpg.chip.jpg,9,1


In [15]:

class UTKFaceDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over the UTKFace images listed in a DataFrame.

    The base class is spelled ``torch.utils.data.Dataset`` in full on purpose:
    this notebook also executes ``from datasets import Dataset, ..., Image``,
    which binds the bare names ``Dataset`` and ``Image`` to the Hugging Face
    classes rather than to the PyTorch dataset base class and ``PIL.Image``.
    """

    def __init__(self, df, image_dir, transform=None):
        """
        Args:
            df (pandas.DataFrame): DataFrame with filename, age, and gender columns.
            image_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.df = df
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.df)

    def _load_image(self, path):
        """Reads the image file at `path` and returns it as an RGB PIL image."""
        # Imported locally so the class does not depend on the module-level
        # name `Image`, which this notebook binds to `datasets.Image`.
        # pil_loader opens the file, decodes it with PIL and applies
        # .convert("RGB"), so grayscale or other modes come back as 3-channel RGB.
        from torchvision.datasets.folder import pil_loader

        return pil_loader(path)

    def __getitem__(self, idx):
        """
        Fetches the sample at the given index.

        Args:
            idx (int or scalar torch.Tensor): The positional index of the sample.

        Returns:
            A dictionary with 'image', 'age', and 'gender' entries. 'age' is a
            float32 tensor and 'gender' is a long tensor; 'image' is the RGB
            image, or whatever `transform` returns for it when one was given.
        """
        # Some samplers hand out scalar tensors; DataFrame.iloc wants a plain int.
        if torch.is_tensor(idx):
            idx = idx.item()

        # 1. Get image path and labels from the dataframe (positional lookup)
        row = self.df.iloc[idx]
        img_name = os.path.join(self.image_dir, row['filename'])

        # 2. Load the image as RGB
        image = self._load_image(img_name)

        # 3. Convert labels to tensors. Casting to plain Python numbers first
        # keeps this independent of the scalar type pandas returns for the row.
        # Age is treated as a continuous value, so a float tensor is used.
        age = torch.tensor(float(row['age']), dtype=torch.float32)
        # Gender is a class index, so a long tensor is used for classification.
        gender = torch.tensor(int(row['gender']), dtype=torch.long)

        # 4. Apply transformations, if any
        if self.transform:
            image = self.transform(image)

        # 5. Return a dictionary containing the data
        return {'image': image, 'age': age, 'gender': gender}
