In [None]:
from pathlib import Path
from random import choice, sample, seed
from typing import List, Tuple, cast

from numpy import ndarray
from matplotlib.pyplot import subplots, Figure # type: ignore
from PIL.Image import open as open_image
from torch.cuda import is_available
from torchvision.transforms import Grayscale, Resize, ToTensor, Normalize, Compose # type: ignore
from torchvision.datasets import ImageFolder # type: ignore

In [None]:
# Compute device for later cells: the GPU when CUDA is usable, otherwise the CPU.
DEVICE = "cpu"
if is_available():
    DEVICE = "cuda"

In [None]:
# Dataset root: a "data" directory under the notebook's current working directory.
path = Path.cwd().joinpath("data")

In [None]:
# Candidate images laid out as <path>/<class_name>/<image_file>.
# - Only regular files with an image extension are kept, so sub-directories and
#   hidden files (e.g. .DS_Store, which pathlib's "*" does match) never reach PIL.
# - The list is sorted so its order (and any seeded sampling from it later) does
#   not depend on filesystem enumeration order.
# NOTE(review): IMAGE_SUFFIXES is meant to match torchvision's default ImageFolder
# extensions so this list agrees with the dataset built below — verify against the
# installed torchvision version.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"}
image_list = sorted(
    candidate
    for candidate in path.glob("*/*")
    if candidate.is_file() and candidate.suffix.lower() in IMAGE_SUFFIXES
)
if not image_list:
    # Fail with a clear message instead of an IndexError from choice([]).
    raise FileNotFoundError(f"No images found under {path} (expected <class>/<image> layout)")

image_path = choice(image_list)
# Copy inside a context manager: the copy holds the pixel data in memory, so the
# underlying file handle is closed right away instead of being left open.
with open_image(image_path) as opened_image:
    image = opened_image.copy()

print(f"Random Image Path : {image_path}")
# The class is the parent directory's full name (.stem would cut names containing dots).
print(f"Image Class : {image_path.parent.name}")
print(f"Image Height : {image.height}")
print(f"Image Width : {image.width}")
image

In [None]:
# Preprocessing pipeline applied to every image:
#   1. convert to a single grayscale channel,
#   2. resize to a fixed IMAGE_SIZE,
#   3. convert to a tensor,
#   4. normalize with mean=0.5 / std=0.5, i.e. (x - 0.5) / 0.5, which maps
#      values in [0, 1] to [-1, 1] (assuming ToTensor output is in [0, 1]).
IMAGE_SIZE = (128, 128)
transform_steps = [
    Grayscale(num_output_channels=1),
    Resize(IMAGE_SIZE),
    ToTensor(),
    Normalize(mean=(0.5,), std=(0.5,)),
]
transforms = Compose(transform_steps)

transforms

In [None]:
def plot_transformed_image(image_paths: List[Path], transform: Compose, n: int = 3, r_seed: int = 42):
    """Plot ``n`` randomly sampled images next to their transformed versions.

    For each sampled path, one 1x2 figure is created: the original image on the
    left and ``transform(image)`` on the right. Each figure is titled with the
    image's class, i.e. the name of its parent directory.

    Args:
        image_paths: Candidate image files to sample from.
        transform: Callable applied to each opened PIL image. Its result is
            permuted with ``.permute(1, 2, 0)``, so it is expected to be a
            channels-first tensor (C, H, W).
        n: Number of images to plot. Clamped to ``len(image_paths)`` so a short
            list does not make ``sample`` raise ``ValueError``.
        r_seed: Seed for the sampling RNG, so the same images are chosen on
            every run.
    """
    from random import Random  # private RNG: avoids reseeding the global `random` state

    # Random(r_seed).sample draws exactly what seed(r_seed); sample(...) would,
    # without clobbering the module-level generator used elsewhere (e.g. choice()).
    rng = Random(r_seed)
    chosen_paths = rng.sample(image_paths, k=min(n, len(image_paths)))
    for image_path in chosen_paths:
        with open_image(image_path) as img:
            fig, ax = cast(Tuple[Figure, ndarray], subplots(1, 2))  # type: ignore
            ax[0].imshow(img)
            ax[0].set_title(f"Original Size : {img.size}")
            ax[0].axis(False)

            # Channels-first (C, H, W) -> channels-last (H, W, C) for matplotlib.
            transformed_image = transform(img).permute(1, 2, 0)  # type: ignore
            # A single-channel result is shown as a 2-D array with a gray colormap.
            # Otherwise it would be drawn in false colour and depend on matplotlib
            # accepting an (H, W, 1) shape.
            if transformed_image.shape[-1] == 1:
                shown_image = transformed_image.squeeze(-1)
            else:
                shown_image = transformed_image
            ax[1].imshow(shown_image, cmap="gray")  # cmap is ignored for RGB(A) data
            ax[1].set_title(f"Transformed \nsize : {transformed_image.shape}")  # type: ignore
            ax[1].axis(False)

            # Use .name (not .stem) so class directory names containing dots are not truncated.
            fig.suptitle(f"Class : {image_path.parent.name}", fontsize=16)  # type: ignore

In [None]:
# Visual sanity check of the preprocessing: seeded random samples, original vs. transformed.
plot_transformed_image(image_paths=image_list, transform=transforms)

In [None]:
# Dataset over the <path>/<class_name>/<image> directory layout. Every image is
# passed through `transforms`; targets get no extra transform.
# NOTE(review): this covers the whole `data` directory — no train/test split is
# made here despite the `train_data` name.
dataset_root = path.as_posix()
train_data = ImageFolder(
    root=dataset_root,
    target_transform=None,
    transform=transforms,
)
train_data