In [10]:
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
import glob
import cv2
import os

In [23]:
class CustomDataset(Dataset):
    def __init__(self):
        self.imgs_path = r"dataset/"
        file_list = glob.glob(self.imgs_path + "*")
        #print(f"File list: {file_list}")
        
        self.data = []
        self.class_counts = {}
        
        for class_path in file_list:
            class_name = os.path.basename(class_path)
            img_paths = glob.glob(os.path.join(class_path, "*.jpg")) + glob.glob(os.path.join(class_path, "*.jpeg"))
            self.class_counts[class_name] = len(img_paths)
            for img_path in img_paths:
                self.data.append([img_path, class_name])
                
        #self.data = [[img_path.replace("\\", "/"), class_name] for img_path, class_name in self.data]

        #print(self.data)
        
        self.class_map = {"cataract": 0, "diabetic_retinopathy": 1, "glaucoma": 2, "normal": 3}
        self.img_dim = (416, 416)
    
    def __len__(self):
        return len(self.data)
        
    
    def __getitem__(self, idx):
        img_path, class_name = self.data[idx]
        img = cv2.imread(img_path)
        img = cv2.resize(img, self.img_dim)
        class_id = self.class_map[class_name]
        img_tensor = torch.from_numpy(img)
        img_tensor = img_tensor.permute(2,0,1)
        class_id = torch.tensor([class_id])
        return img_tensor, class_id
    
#CustomDataset()

In [17]:
dataset = CustomDataset()
class_counts = dataset.class_counts
print(class_counts)

{'cataract': 938, 'diabetic_retinopathy': 1098, 'glaucoma': 906, 'normal': 1074}


In [8]:
#if __name__ == "__main__":
dataset = CustomDataset()
dataloader = DataLoader(dataset, batch_size = 64, shuffle = True)

In [61]:
for imgs, labels in dataloader:
    print("Batch of images has shape: ",imgs[0].shape)
    break
    print("Batch of labels has shape: ", labels[0].shape)

Batch of images has shape:  torch.Size([3, 416, 416])


In [62]:
for imgs, labels in dataloader:
    print("Batch of labels has shape: ", labels[0].shape)
    break

Batch of labels has shape:  torch.Size([1])
