In [44]:
import os
import cv2
import torch
from torch.utils.data import Dataset

class Image_Dataset2(Dataset):
    """Dataset of image files found recursively under a directory.

    Every file under ``root_dir`` (searched with ``os.walk``) whose name ends with one of
    ``extensions`` (compared case-insensitively) is included. Each item is read with OpenCV,
    resized to ``image_res`` x ``image_res`` and scaled to float32 in [0, 1] by dividing by 255.

    Returned tensor layout depends on ``grayscale``:
      * grayscale=True  -> shape (1, image_res, image_res)  (channel-first)
      * grayscale=False -> shape (image_res, image_res, 3), RGB order (channel-last)

    NOTE(review): the two modes use different channel layouts; most torch models expect
    channel-first (C, H, W). Kept as-is so existing callers keep working.
    """

    def __init__(self, root_dir: str, image_res: int = 256, grayscale: bool = False,
                 extensions: tuple = ('.jpg', '.jpeg', '.png')):
        """Index all matching image files under ``root_dir``.

        Args:
            root_dir: Directory to search recursively for images.
            image_res: Side length in pixels that every image is resized to.
            grayscale: Load images as single-channel grayscale instead of RGB.
            extensions: Tuple of file suffixes to include; matched case-insensitively.

        Raises:
            FileNotFoundError: If ``root_dir`` is not an existing directory.
        """
        if not os.path.isdir(root_dir):
            # os.walk silently yields nothing for a missing path; surface typos early.
            raise FileNotFoundError(f"Image directory not found: {root_dir}")

        self.root_dir = root_dir
        self.image_res = image_res
        self.grayscale = grayscale

        suffixes = tuple(ext.lower() for ext in extensions)
        # Sorted so the index -> file mapping is reproducible (os.walk order is
        # filesystem-dependent).
        self.images = sorted(
            os.path.join(root, file)
            for root, _, files in os.walk(root_dir)
            for file in files
            if file.lower().endswith(suffixes)
        )

        # Debug: Print the number of images found
        print(f"Number of images found: {len(self.images)}")

    def __len__(self) -> int:
        """Return the number of indexed image files."""
        return len(self.images)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load, resize and normalize the image at position ``idx``.

        Returns:
            float32 tensor in [0, 1]; (1, R, R) when grayscale, else (R, R, 3) RGB,
            where R is ``image_res``.

        Raises:
            OSError: If OpenCV cannot decode the file.
        """
        image_path = self.images[idx]
        read_flag = cv2.IMREAD_GRAYSCALE if self.grayscale else cv2.IMREAD_COLOR
        image = cv2.imread(image_path, read_flag)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising; fail here
            # with the offending path instead of a cryptic error inside cv2.resize.
            raise OSError(f"Could not read image: {image_path}")
        image = cv2.resize(image, (self.image_res, self.image_res))

        if self.grayscale:
            # Add a leading channel axis: (H, W) -> (1, H, W).
            image_tensor = (torch.tensor(image, dtype=torch.float32) / 255.0).unsqueeze(0)
        else:
            # OpenCV decodes color images in BGR order; convert to RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image_tensor = torch.tensor(image, dtype=torch.float32) / 255.0

        return image_tensor

In [None]:
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Configuration for a quick visual sanity check of the dataset.
DATA_DIR = "../datasets/animalFaces"
image_size = 128
grayscale = False

train_data = Image_Dataset2(root_dir=DATA_DIR, image_res=image_size, grayscale=grayscale)
assert len(train_data) > 0, f"no images found under {DATA_DIR}"

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)

# Pull a single batch to check shapes and show one sample.
batch = next(iter(train_loader))
print(batch.shape)

fig, ax = plt.subplots()
if grayscale:
    # Grayscale items are (1, H, W); drop the channel axis for imshow.
    ax.imshow(batch[0].squeeze(0).numpy(), cmap="gray")
else:
    # Color items are already (H, W, 3) RGB in [0, 1], which imshow accepts directly.
    # No permute: swapping axes 0 and 1 would display the image transposed.
    ax.imshow(batch[0].numpy())
ax.set_title("First image of a training batch")
ax.axis("off")
plt.show()
