In [None]:
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# Append the parent of the current working directory to the module search path so
# that the `src` package imported below can be found (presumably the notebook is
# run from a subdirectory of the project root -- verify if the layout changes).
sys.path.append(os.path.abspath("../"))

# Project-local dataset class. Kept after the sys.path change above, which is
# presumably what makes `src` importable.
from src.datamodule import MXFaceDataset

%load_ext autoreload
%autoreload 2

In [None]:
# Dataset location: $DATASET_DIR/FIQA-Datasets/TrainDatasets/casia_webface.
# A missing DATASET_DIR environment variable raises KeyError here.
dataset_base = Path(os.environ["DATASET_DIR"])
root_dir = dataset_base.joinpath("FIQA-Datasets", "TrainDatasets", "casia_webface")

# The only preprocessing step is the conversion to a tensor.
transform = transforms.Compose([transforms.ToTensor()])
dataset = MXFaceDataset(root_dir=str(root_dir), transform=transform)

num_images = len(dataset)
print(f"Number of images in dataset: {num_images:,}")

# The samples are assumed to be sorted by label, so the label of the last sample
# plus one is taken as the number of identities (labels presumably start at 0).
img, label = dataset[-1]
num_identities = label.item() + 1
print(f"Number of identities in dataset: {num_identities:,}")
print(f"img: {type(img)}, label: {type(label)}")

In [None]:
BATCH_SIZE = 16
SEED = 0  # fixed seed so that re-running the cell draws the same batch

# A dedicated generator makes the shuffling reproducible without touching
# torch's global RNG state.
shuffle_rng = torch.Generator()
shuffle_rng.manual_seed(SEED)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,  # Load data only from the main process
    generator=shuffle_rng,
)
# Pull a single batch for inspection: x holds the images, y the labels.
batch = next(iter(dataloader))
x, y = batch
print(x.shape)

In [None]:
# Visualize batch
# Size the grid from the batch that was actually loaded instead of assuming a
# fixed 4x4 layout, so the cell keeps working when BATCH_SIZE changes or the
# loader yields fewer samples than requested.
num_images = x.shape[0]
num_cols = int(np.ceil(np.sqrt(num_images)))
num_rows = int(np.ceil(num_images / num_cols))
# squeeze=False always returns a 2D array of axes, even for a 1x1 grid.
fig, axs = plt.subplots(num_rows, num_cols, squeeze=False)
fig.subplots_adjust(hspace=0.5, wspace=-0.5)
axs = np.ravel(axs)  # flatten the 2D grid of axes into a 1D array
for i in range(num_images):
    img = x[i].numpy()
    img = img.transpose(1, 2, 0)  # Convert CHW to HWC
    label = y[i]  # label is torch.Tensor
    axs[i].imshow(img)
    # Hide ticks and tick labels; the class label is shown below each image instead.
    axs[i].tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    axs[i].set_xlabel(label.item())
# Blank out grid cells that did not receive an image.
for unused_ax in axs[num_images:]:
    unused_ax.axis("off")
plt.show()