In [1]:
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
from pathlib import Path
from collections import Counter
from matplotlib import pyplot as plt

# Dataset image sizes

Datasets

- Real faces: `ffhq_real_faces`
    - 3143 images
    - these are all in `png` format
- Diffusion-generated faces (set 1): `AIS-4SD/StableDiffusion-3-faces-20250203-1545`
    - 500 images
    - these are all in `png` format
- Diffusion-generated faces (set 2): `SFHQ-T2I`
    - 1724 images
    - these are all in `jpg` format

For each of these datasets, we are interested in the image size and whether it's consistent for all images within the dataset.

In [None]:
def get_image_sizes(image_folder_path: Path, extensions: tuple = (".png", ".jpg", ".jpeg")) -> list:
    """
    Collect the (width, height) size of every image file in a folder.

    Files are visited in sorted name order so the result is reproducible
    (``os.listdir`` order is arbitrary). Entries that are not image files, such as
    ``.DS_Store`` or sub-directories, are skipped instead of making PIL raise.

    Args:
        image_folder_path: Folder containing the images (a ``str`` is also accepted).
        extensions: Case-insensitive file suffixes that are treated as images.

    Returns:
        A list of ``(width, height)`` tuples, one per image, as reported by ``PIL.Image.size``.
    """
    folder = Path(image_folder_path)
    image_paths = sorted(
        path for path in folder.iterdir()
        if path.is_file() and path.suffix.lower() in extensions
    )
    image_sizes = []
    for image_path in tqdm(image_paths, desc=folder.name):
        # Image.open is lazy, so .size only needs the header. The context manager
        # closes the file handle, which would otherwise stay open until garbage collection.
        with Image.open(image_path) as img:
            image_sizes.append(img.size)
    return image_sizes

## Real images

In [None]:
real_images_path = Path("data/ffhq_real_faces")
# Dataset-specific name so another cell cannot silently overwrite this result.
real_image_sizes = get_image_sizes(real_images_path)
# Maps each (width, height) to the number of images with that size.
Counter(real_image_sizes)

All images in this dataset have size (1024, 1024)

## Synthetic images

### AIS-4SD

In [None]:
synth_images_1_path = Path("data/AIS-4SD/StableDiffusion-3-faces-20250203-1545")
# Dataset-specific name so another cell cannot silently overwrite this result.
synth_1_image_sizes = get_image_sizes(synth_images_1_path)
# Maps each (width, height) to the number of images with that size.
Counter(synth_1_image_sizes)

All images in this dataset have size (768, 768)

### SFHQ-T2I

In [None]:
synth_images_2_path = Path("data/SFHQ-T2I")
# Dataset-specific name so another cell cannot silently overwrite this result.
synth_2_image_sizes = get_image_sizes(synth_images_2_path)
# Maps each (width, height) to the number of images with that size.
Counter(synth_2_image_sizes)

All images in this dataset have size (1024, 1024)

## Summary

- All 3143 real images (`ffhq_real_faces`) have size (1024, 1024)
- Of the 2224 diffusion-generated images, the 1724 from `SFHQ-T2I` have size (1024, 1024), but the 500 from `AIS-4SD` have size (768, 768)

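I could upscale the smaller images, but it is safer (less likely to introduce image artifacts) to downscale the larger images to (768, 768). Resizing everything to one size also matters for the classifier: every real image is 1024×1024, so if only some synthetic images were 768×768, a model could partly learn to separate the classes by resolution rather than by image content.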

# Image pre-processing

I will experiment with image pre-processing techniques in this notebook, because it is easier here to display the images and see how various functions transform them.

A recap of PyTorch dataset functionality:
- `torch.utils.data.Dataset` stores the samples and their corresponding labels
    - Indexing it (`dataset[idx]`) returns one (input, label) pair
    - Pre-processing functions can be defined / called inside this class
    - A custom Dataset class must implement three methods: `__init__`, `__len__`, and `__getitem__`
- `torch.utils.data.DataLoader` wraps an iterable around the Dataset to enable easy access to the samples
    - Enables iteration through the dataset in batches
    - Provides built-in options for shuffling, parallel loading (multiple worker processes), etc.
    - Calls the Dataset's `__getitem__()` once per sample index and collates the results into a batch


In [9]:
# Root folder holding one sub-folder per source dataset.
data_root_dir = Path("data")

# Sub-folder (relative to data_root_dir) -> class name. Dict insertion order is
# preserved, and FaceImageDataset iterates the mapping in this order.
image_folder_names_to_labels = dict([
    ("ffhq_real_faces", "real"),
    ("AIS-4SD/StableDiffusion-3-faces-20250203-1545", "synthetic"),
    ("SFHQ-T2I", "synthetic"),
])

In [None]:
class FaceImageDataset:
    """
    Map-style dataset of face images labelled real (0) or synthetic (1).

    Implements ``__len__`` and ``__getitem__``, so it can be wrapped in a
    ``torch.utils.data.DataLoader``. Every image is converted to RGB and resized
    to ``img_size`` so that all source datasets share one resolution.
    """

    # Integer label for each class name used in the folder -> class-name mapping.
    CLASS_LABELS = {"real": 0, "synthetic": 1}
    # Case-insensitive suffixes treated as images; any other file is ignored.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, data_root_dir: Path, img_size: tuple = (768, 768), folder_labels: "dict | None" = None):
        """
        Args:
            data_root_dir: Path to directory containing image subdirectories (a ``str`` is also accepted).
            img_size: (width, height) that every image is resized to. The default matches the
                AIS-4SD images, so the larger images are downscaled rather than the smaller
                ones being upscaled.
            folder_labels: Mapping of sub-folder (relative to ``data_root_dir``) to class name,
                either "real" or "synthetic". Defaults to the notebook-level
                ``image_folder_names_to_labels``.
        """
        self.data_root_dir = Path(data_root_dir)
        self.img_size = tuple(img_size)
        self.folder_labels = image_folder_names_to_labels if folder_labels is None else folder_labels
        self.samples = []
        self.get_real_images()
        self.get_synth_images()

    def _collect_images(self, class_name: str) -> None:
        """Append ``(image_path, label)`` for every image in the folders mapped to ``class_name``."""
        label = self.CLASS_LABELS[class_name]
        folders = [folder for folder, name in self.folder_labels.items() if name == class_name]
        for folder in folders:
            # Sort so sample indices are reproducible (os.listdir order is arbitrary).
            for image_name in sorted(os.listdir(self.data_root_dir / folder)):
                if image_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((self.data_root_dir / folder / image_name, label))

    def get_real_images(self):
        """Add every image from folders mapped to "real", with label 0."""
        self._collect_images("real")

    def get_synth_images(self):
        """Add every image from folders mapped to "synthetic", with label 1."""
        self._collect_images("synthetic")

    def apply_transforms(self, image):
        """
        Convert a PIL image to RGB and resize it to ``self.img_size``.

        Args:
            image: A ``PIL.Image.Image``.

        Returns:
            A new RGB ``PIL.Image.Image`` of size ``self.img_size``, resized with bilinear
            interpolation (torchvision ``Resize``'s default).
        """
        # PNGs can carry an alpha channel or be greyscale; forcing RGB gives every sample 3 channels.
        return image.convert("RGB").resize(self.img_size, resample=Image.BILINEAR)

    def __len__(self):
        """Return the total number of samples"""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """
        Get one sample.

        Returns:
            image: RGB ``PIL.Image.Image`` resized to ``self.img_size``. It is not a tensor yet;
                add a tensor conversion (e.g. torchvision ``ToTensor``) before batching it with
                a DataLoader's default collate function.
            label: 0 for real, 1 for synthetic
        """
        image_path, label = self.samples[idx]
        # The transformed image is an independent copy, so the source file handle can be
        # closed here instead of staying open until garbage collection.
        with Image.open(image_path) as img:
            image = self.apply_transforms(img)
        return image, label
        