In [23]:
from pathlib import Path
from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

In [4]:
# Dataset root (relative to the notebook's working directory) with one
# sub-directory per split.
data_path = Path("lfw-deepfunneled")
train_path = data_path / "train"
test_path = data_path / "test"

# Images are laid out as <split>/<person>/<file>. The matches are sorted
# because glob yields entries in filesystem order, which can differ between
# machines and runs; a fixed order keeps every downstream table reproducible.
train_images = sorted(train_path.glob("*/*"))
test_images = sorted(test_path.glob("*/*"))

In [39]:
def get_info_one_image(img_path: Path) -> dict:
    """
    Read one image, validate it, and describe it.

    The image must decode to a 250x250 array with 3 channels and the file
    must have a ".jpg" extension; anything else fails an assertion so that
    malformed files surface immediately.

    Args:
        img_path: path to the image file; the name of its parent directory
            is used as the person label.

    Returns:
        info for a single image: a dict with the keys "person",
        "identifier", "file_extension", "height", "width" and "channels"

    Raises:
        AssertionError: if the decoded array is not 3-dimensional, if its
            height, width or channel count is unexpected, or if the file
            extension is not ".jpg".
    """
    # The context manager releases the file handle deterministically; the
    # array conversion happens inside it because decoding reads from the file.
    with Image.open(img_path) as img:
        shape = np.asarray(img).shape

    # Check the rank explicitly: a 2-D array (e.g. a grayscale image) would
    # otherwise fail on the unpacking below with an unhelpful ValueError.
    assert len(shape) == 3, f"{img_path} has shape {shape}, expected (height, width, channels)"
    H, W, C = shape
    assert H == 250, f"{img_path} height is {H}, expected 250"
    assert W == 250, f"{img_path} width is {W}, expected 250"
    assert C == 3, f"{img_path} channels is {C}, expected 3"

    # .name rather than .stem: a directory name containing a dot must not be
    # truncated at that dot.
    person = img_path.parent.name
    img_identifier = img_path.stem
    file_extension = img_path.suffix

    assert file_extension == ".jpg", f"{img_path} file extension is {file_extension}, expected .jpg"

    data = {"person": person,
            "identifier": img_identifier,
            "file_extension": file_extension,
            "height": H,
            "width": W,
            "channels": C}
    return data


def get_info_all_images(paths: list[Path]) -> pd.DataFrame:
    """
    Gather the per-image info of every path into a single table.

    Args:
        paths: image files to describe, each passed to get_info_one_image

    Returns:
        info for all images in paths: one row per image and one column per
        key of the per-image info dict, in the order the paths were given
    """
    columns: dict[str, list] = {}
    for image_path in paths:
        # Accumulate column-wise so the frame is built once, after the loop.
        for field, value in get_info_one_image(image_path).items():
            columns.setdefault(field, []).append(value)
    return pd.DataFrame(columns)

# One metadata table per split. Reading the files validates them as a side
# effect: an image that fails a check stops the run with an AssertionError.
train_info_df = get_info_all_images(paths=train_images)
test_info_df = get_info_all_images(paths=test_images)

In [49]:
# Train split: ~12k images, ~5.6k unique people, ~1.5k people with more than one image
print(train_info_df.shape[0])  # total number of images
person_groups = train_info_df.groupby("person")

print(len(person_groups))  # number of distinct people
print(person_groups.size().gt(1).sum())  # people with at least two images

12185
5603
1534


In [50]:
# Test split: ~1k images, ~750 unique people, ~150 people with more than one image
print(test_info_df.shape[0])  # total number of images
# NOTE: from here on person_groups refers to the test-split grouping.
person_groups = test_info_df.groupby("person")

print(len(person_groups))  # number of distinct people
print(person_groups.size().gt(1).sum())  # people with at least two images

1048
756
146
