In [1]:
from PIL import Image, UnidentifiedImageError
from pathlib import Path
import numpy as np
import torch

In [2]:
# Dataset root: ../data/cats/Data, resolved against the kernel's working directory.
working_dir = Path(".").resolve()
root_dir = working_dir.parent.joinpath("data", "cats", "Data")
assert root_dir.exists(), f"Root directory {root_dir} does not exist"

In [3]:
# Decode every PNG under root_dir into an (H, W, 3) uint8 RGB array.
# Files Pillow cannot parse are reported and deleted; any other failure is
# reported and skipped so one bad file does not abort the whole scan.


def _load_rgb(path: Path) -> np.ndarray:
    """Verify the image file at `path` and return its pixels as an (H, W, 3) RGB array.

    Image.verify() only works on an image freshly opened from disk and leaves
    it unusable afterwards, so the file is opened a second time for decoding.

    Raises:
        UnidentifiedImageError / OSError / SyntaxError: the file is not a
            readable image (as reported by Pillow).
        AssertionError: the decoded array is not 3-channel.
    """
    with Image.open(path) as img:
        img.verify()  # Verify that it is an image (must run before any decoding)
    with Image.open(path) as img:
        data = np.array(img.convert("RGB"))
    # convert("RGB") should always produce 3 channels; kept as cheap sanity checks.
    assert data.ndim == 3, f"Image {path} is not RGB image"
    assert data.shape[2] == 3, f"Image {path} does not have 3 channels"
    return data


def _remove_invalid(path: Path, err: Exception) -> None:
    """Report an unreadable image and delete it; a failed delete is reported, not raised."""
    print(f"Invalid image {path}: {err}")
    try:
        path.unlink()  # Remove the invalid file
    except OSError as unlink_err:
        print(f"Could not remove {path}: {unlink_err}")


images = []
shapes = []

# sorted() fixes the dataset order (and therefore the saved tensor) across runs and
# filesystems, and materializes the file list before any file is deleted.
for file in sorted(root_dir.rglob("*.png")):
    try:
        data = _load_rgb(file)
    except (UnidentifiedImageError, SyntaxError, OSError) as e:
        # Pillow signals corrupt/unrecognized files with these exception types.
        _remove_invalid(file, e)
        continue
    except Exception as e:
        print(f"Error processing {file}: {e}")
        continue

    images.append(data)
    shapes.append(data.shape)

In [4]:
# Sanity check: np.stack in the next cell needs every image to share one shape,
# so fail fast here with a readable message instead of only displaying mismatches.
EXPECTED_SHAPE = (64, 64, 3)
# reshape keeps the array 2-D even when `shapes` is empty, so the comparison below
# broadcasts cleanly instead of hitting a (0,) vs (3,) broadcast mismatch.
shapes_arr = np.array(shapes).reshape(-1, len(EXPECTED_SHAPE))
shape_mismatch = np.where(shapes_arr != EXPECTED_SHAPE)  # (image indices, dim indices)
bad_images = np.unique(shape_mismatch[0])
assert bad_images.size == 0, (
    f"{bad_images.size} image(s) differ from {EXPECTED_SHAPE}, "
    f"e.g. at indices {bad_images[:10].tolist()}"
)
shape_mismatch

(array([], dtype=int64), array([], dtype=int64))

In [5]:
# Stack all images into one uint8 tensor and cache it next to the data directory.
images_nhwc = np.stack(images)  # (N, H, W, 3), uint8
# Free the per-image arrays. NOTE: re-run the loading cell before re-running this one.
del images

# permute() only returns a strided view of the NHWC buffer; .contiguous() materializes
# real channels-first (N, 3, H, W) memory. torch.save keeps a tensor's strides, so
# without it the cached file would load as a non-contiguous tensor (ops that need
# compatible strides, such as .view(), could then fail).
images_ds = torch.from_numpy(images_nhwc).permute(0, 3, 1, 2).contiguous()
del images_nhwc  # drop the NHWC copy; images_ds now owns its own memory

torch.save(images_ds, root_dir.parent / "cats_raw.pt")

In [6]:
# Per-channel normalization statistics, in R, G, B order (the channel order the
# images are decoded in). Both are 1-D float32 tensors of length 3.
IMAGENET_MEAN = torch.as_tensor((0.485, 0.456, 0.406), dtype=torch.float32)
IMAGENET_STD = torch.as_tensor((0.229, 0.224, 0.225), dtype=torch.float32)

In [None]:
def normalize_imagenet(images: torch.Tensor, mean=None, std=None) -> torch.Tensor:
    """Scale 0-255 channels-first images to [0, 1] and standardize each channel.

    Args:
        images: tensor of shape (N, C, H, W) with values in [0, 255]
            (e.g. the uint8 tensor cached as cats_raw.pt).
        mean: per-channel means, length C; defaults to IMAGENET_MEAN.
        std: per-channel standard deviations, length C; defaults to IMAGENET_STD.

    Returns:
        float32 tensor of the same shape: (images / 255 - mean) / std.
    """
    mean = IMAGENET_MEAN if mean is None else mean
    std = IMAGENET_STD if std is None else std

    scaled = images.to(torch.float32) / 255.0
    # Reshape stats to (1, C, 1, 1) so they broadcast over batch and spatial dims,
    # and move them to the images' device so GPU inputs work as well.
    mean = mean.to(device=scaled.device, dtype=torch.float32).reshape(1, -1, 1, 1)
    std = std.to(device=scaled.device, dtype=torch.float32).reshape(1, -1, 1, 1)
    return (scaled - mean) / std


# Opt-in: set to True to (re)build the normalized dataset cache from cats_raw.pt.
SAVE_TRANSFORMED = False
if SAVE_TRANSFORMED:
    # weights_only=True: the cache holds a plain tensor, so no arbitrary unpickling is needed.
    images_raw = torch.load(root_dir.parent / "cats_raw.pt", weights_only=True)
    torch.save(normalize_imagenet(images_raw), root_dir.parent / "cats_transformed.pt")