In [None]:
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
from pathlib import Path
from collections import Counter
from matplotlib import pyplot as plt

# Dataset image sizes

Datasets

- Real faces: `ffhq_real_faces`
    - 3143 images
- Diffusion-generated faces (set 1): `AIS-4SD/StableDiffusion-3-faces-20250203-1545`
    - 500 images
- Diffusion-generated faces (set 2): `SFHQ-T2I`
    - 1724 images

For each of these datasets, we are interested in the image size and whether it's consistent for all images within the dataset.

In [None]:
def get_image_sizes(image_folder_path: Path) -> list:
    """
    Collect the size of every image file located directly inside a folder.

    Args:
        image_folder_path: Folder containing the images. Every regular file in
            it is expected to be an image that PIL can open.

    Returns:
        A list with one `Image.size` tuple, i.e. (width, height) in pixels,
        per file, ordered by file name.

    Raises:
        FileNotFoundError: If `image_folder_path` does not exist.
        PIL.UnidentifiedImageError: If a file in the folder is not a
            recognisable image.
    """
    # Sub-directories are skipped because Image.open cannot read them, and the
    # paths are sorted because directory listing order is otherwise arbitrary.
    image_paths = sorted(
        path for path in Path(image_folder_path).iterdir() if path.is_file()
    )
    image_sizes = []
    for image_path in tqdm(image_paths):
        # Image.open keeps the underlying file open until the image is closed;
        # the context manager releases the handle as soon as the size is read,
        # so scanning thousands of files does not exhaust file descriptors.
        with Image.open(image_path) as img:
            image_sizes.append(img.size)
    return image_sizes

## Real images

In [None]:
# Read the size of every file in the real-faces folder, then tally how many
# images share each distinct size (the Counter is shown as the cell output).
# NOTE(review): `image_sizes` is re-assigned by the other dataset cells in this
# notebook, so it only holds the sizes of whichever dataset cell ran last.
real_images_path = Path("data/ffhq_real_faces")
image_sizes = get_image_sizes(real_images_path)
Counter(image_sizes)

All images in this dataset have size (1024, 1024)

## Synthetic images

### AIS-4SD

In [None]:
# Read the size of every file in the AIS-4SD Stable Diffusion folder, then
# tally how many images share each distinct size (shown as the cell output).
# NOTE(review): `image_sizes` is also assigned by the other dataset cells, so
# it only holds the sizes of whichever dataset cell ran last.
synth_images_1_path = Path("data/AIS-4SD/StableDiffusion-3-faces-20250203-1545")
image_sizes = get_image_sizes(synth_images_1_path)
Counter(image_sizes)

All images in this dataset have size (768, 768)

### SFHQ-T2I

In [None]:
# Read the size of every file in the SFHQ-T2I folder, then tally how many
# images share each distinct size (shown as the cell output).
# NOTE(review): `image_sizes` is also assigned by the other dataset cells, so
# it only holds the sizes of whichever dataset cell ran last.
synth_images_2_path = Path("data/SFHQ-T2I")
image_sizes = get_image_sizes(synth_images_2_path)
Counter(image_sizes)

All images in this dataset have size (1024, 1024)

## Summary

- All 3143 real images have size (1024, 1024)
- Of the 2224 diffusion-generated images, the 1724 SFHQ-T2I images have size (1024, 1024) and the 500 AIS-4SD images have size (768, 768)

I could upscale the smaller images, but it would be safer (less likely to introduce image artifacts) to reduce the size of the larger images to (768, 768).

# Image pre-processing

I will experiment with image pre-processing techniques in this notebook, as it will be easier to display the images and understand how they are transformed by various functions.

A recap of PyTorch dataset functionality:
- `torch.utils.data.Dataset` stores the samples and their corresponding labels
    - Each time it is indexed (e.g. `dataset[i]`), it returns an [input, label] pair
    - Pre-processing functions can be defined / called inside this class
    - A custom Dataset class must implement three functions: `__init__`, `__len__`, and `__getitem__`
- `torch.utils.data.DataLoader` wraps an iterable around the Dataset to enable easy access to the samples   
    - Enables iteration through the dataset in batches
    - Provides access to in-built functions for shuffling, parallel processing etc
    - Calls the `__getitem__()` function from the Dataset class to create a batch of data


In [None]:
# Class labels:
#   0 is `real`
#   1 is `synthetic`
# Each key is a dataset folder name (the same names appear under `data/` in
# the paths used elsewhere in this notebook); each value is the label for that folder.

image_folder_names_to_labels = {
    "ffhq_real_faces": 0,
    "AIS-4SD/StableDiffusion-3-faces-20250203-1545": 1,
    "SFHQ-T2I": 1
}

In [None]:
class FaceImageDataset:
    """
    Map-style dataset of face images paired with an integer class label.

    The dataset indexes every file found in the configured folders when it is
    constructed, and only opens an image when that sample is requested.
    Every image is resized to `img_size` so that samples from datasets with
    different native resolutions end up with the same dimensions.

    Args:
        imgs_path: Root directory that contains the dataset folders.
        folder_labels: Mapping of folder name (relative to `imgs_path`) to the
            label given to every image in that folder. Defaults to the
            notebook-level `image_folder_names_to_labels`.
        img_size: Target (width, height) passed to PIL's `Image.resize`.
        transform: Optional callable applied to the resized PIL image, e.g. a
            conversion to a tensor. When omitted, the PIL image is returned.

    Attributes:
        samples: List of (image path, label) pairs, grouped by folder in the
            order of `folder_labels` and sorted by path within each folder.

    Raises:
        FileNotFoundError: If one of the folders does not exist. Failing here
            is deliberate: silently dropping a folder would skew the classes.
    """

    def __init__(self, imgs_path="data/", folder_labels=None, img_size=(768, 768), transform=None):
        self.imgs_path = imgs_path
        self.img_size = img_size
        self.transform = transform

        if folder_labels is None:
            # Resolved at call time so the notebook-level mapping is the default.
            folder_labels = image_folder_names_to_labels

        self.samples = []
        for folder_name, label in folder_labels.items():
            folder = Path(imgs_path) / folder_name
            # Sorting gives a reproducible sample order; sub-directories are
            # skipped because they cannot be opened as images.
            self.samples.extend(
                (path, label) for path in sorted(folder.iterdir()) if path.is_file()
            )

    def __len__(self):
        """Return the total number of images across all configured folders."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load one sample.

        Args:
            idx: Integer position of the sample in `self.samples`.

        Returns:
            An (image, label) pair. The image has been resized to
            `self.img_size` and passed through `self.transform` if one was given.
        """
        img_path, label = self.samples[idx]
        # resize() returns a new, fully loaded image, so the result remains
        # usable after the context manager closes the source file.
        with Image.open(img_path) as source_img:
            img = source_img.resize(self.img_size)
        if self.transform is not None:
            img = self.transform(img)
        return img, label