In [2]:
!pip install torch torchvision opencv-python numpy Pillow scikit-learn matplotlib gunicorn Flask-Cors


Collecting gunicorn
  Using cached gunicorn-23.0.0-py3-none-any.whl.metadata (4.4 kB)
Collecting Flask-Cors
  Using cached flask_cors-5.0.1-py3-none-any.whl.metadata (961 bytes)
Using cached gunicorn-23.0.0-py3-none-any.whl (85 kB)
Using cached flask_cors-5.0.1-py3-none-any.whl (11 kB)
Installing collected packages: gunicorn, Flask-Cors
Successfully installed Flask-Cors-5.0.1 gunicorn-23.0.0


In [4]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
import cv2
import numpy as np

class UTKFaceDataset(Dataset):
    """UTKFace images paired with (age bucket, gender) labels parsed from filenames.

    File names are expected to follow the UTKFace convention
    ``age_gender_race_date&time.jpg``; only the first two fields are used.
    Each sample is converted to RGB, run through an OpenCV Haar-cascade face
    detector, and cropped to the largest detected face before ``transform``
    is applied.

    ``__getitem__`` returns ``None`` when no face is detected. PyTorch's
    default collate function cannot batch ``None``, so build the loader with
    ``DataLoader(dataset, collate_fn=UTKFaceDataset.collate_skip_none)``.
    """

    # Lower-case file extensions treated as images when listing ``image_dir``.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")

    # Haar cascade shared by all instances within one process. Loaded lazily on
    # first use (not in __init__) so the XML is parsed once rather than once per
    # sample, and so dataset instances carry no OpenCV object and remain
    # picklable when sent to DataLoader worker processes.
    _face_cascade = None

    def __init__(self, image_dir, transform=None, scale_factor=1.1, min_neighbors=4):
        """
        Args:
            image_dir (string): Path to the directory containing images.
            transform (callable, optional): Optional transform applied to the
                cropped face, which is passed in as an HxWx3 uint8 RGB numpy array.
            scale_factor (float): ``scaleFactor`` forwarded to
                ``CascadeClassifier.detectMultiScale``.
            min_neighbors (int): ``minNeighbors`` forwarded to
                ``CascadeClassifier.detectMultiScale``.
        """
        self.image_dir = image_dir
        self.transform = transform
        self.scale_factor = scale_factor
        self.min_neighbors = min_neighbors
        # os.listdir returns entries in arbitrary order; sort for a deterministic
        # index -> file mapping. Hidden files and non-image entries
        # (.DS_Store, .ipynb_checkpoints, notes, ...) would crash label parsing.
        self.images = sorted(
            name for name in os.listdir(image_dir)
            if not name.startswith(".")
            and name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Number of image files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(face, (age_bucket, gender))`` for sample ``idx``, or ``None``.

        ``age_bucket`` is ``age // 5`` and ``gender`` is the integer second
        field of the file name. ``face`` is the cropped RGB face after
        ``transform`` (if any). ``None`` is returned when no face is detected.

        Raises:
            ValueError / IndexError: if the file name does not start with
                ``<int>_<int>_``.
        """
        image_name = self.images[idx]
        image_path = os.path.join(self.image_dir, image_name)

        # Parse the label from the filename: age_gender_race_date&time
        label = image_name.split('_')
        age = int(label[0]) // 5  # Convert age to a 5-year bucket (age // 5)
        gender = int(label[1])  # 0 for male, 1 for female

        # Force 3-channel RGB: grayscale or RGBA files would otherwise make
        # cv2.cvtColor(..., COLOR_RGB2GRAY) fail. The context manager also
        # closes the file handle that PIL keeps open for lazy loading.
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image.convert("RGB"))

        image = self.detect_and_crop_face(image)

        # No face detected: signal it to the caller (see collate_skip_none).
        if image is None:
            return None

        if self.transform is not None:
            image = self.transform(image)

        return image, (age, gender)

    @classmethod
    def _get_face_cascade(cls):
        """Load the frontal-face Haar cascade once per process and cache it."""
        if cls._face_cascade is None:
            cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
            cascade = cv2.CascadeClassifier(cascade_path)
            if cascade.empty():
                raise RuntimeError(f"Failed to load Haar cascade from {cascade_path}")
            cls._face_cascade = cascade
        return cls._face_cascade

    def detect_and_crop_face(self, image):
        """
        Detect a face with OpenCV's Haar cascade classifier and crop to it.

        Args:
            image (np.ndarray): HxWx3 RGB image.

        Returns:
            np.ndarray or None: the crop of the largest detected face, or
            ``None`` if no face is detected.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        faces = self._get_face_cascade().detectMultiScale(
            gray, self.scale_factor, self.min_neighbors
        )

        if len(faces) == 0:
            return None  # No face detected

        # Keep the largest box (by area), assuming the main subject's face is
        # the biggest detection; the first entry carries no such guarantee.
        x, y, w, h = max(faces, key=lambda box: box[2] * box[3])
        return image[y:y + h, x:x + w]

    @staticmethod
    def collate_skip_none(batch):
        """DataLoader ``collate_fn`` that drops ``None`` samples (no face found).

        The remaining samples are batched with PyTorch's ``default_collate``.
        Returns ``None`` if every sample in the batch was dropped, so the
        training loop must skip ``None`` batches.
        """
        from torch.utils.data.dataloader import default_collate

        kept = [sample for sample in batch if sample is not None]
        if not kept:
            return None
        return default_collate(kept)
