In [5]:
import os
import cv2
import torch
from torch.utils.data import Dataset, DataLoader

In [6]:
class BettaDataset(Dataset):
    def __init__(self):
        self.images_dir = "betta"
        image_list = os.listdir(self.images_dir)
        print("Dataset contains ", len(image_list), " betta images.")
        self.data = []
        for image_name in image_list:
            image_label = "betta"
            image_path = self.images_dir + os.sep + image_name
            self.data.append([image_path, image_label])
        self.class_map = {"betta" : 0}
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idex):
        image_path, class_name = self.data[idex]
        image = cv2.imread(image_path)
        class_id = self.class_map[class_name]
        image_tensor = torch.from_numpy(image) / 255
        image_tensor = image_tensor.permute(2, 0, 1)
        return image_tensor, class_id

In [7]:
# Load the dataset
dataset = BettaDataset()
data_loader = DataLoader(dataset, batch_size=4, shuffle=True)

Dataset contains  1356  betta images.


In [8]:
dataset.__getitem__(5)

(tensor([[[0.5098, 0.5294, 0.5451,  ..., 0.6392, 0.6196, 0.6078],
          [0.5176, 0.5333, 0.5529,  ..., 0.6078, 0.6196, 0.6314],
          [0.5216, 0.5373, 0.5490,  ..., 0.6078, 0.6235, 0.6275],
          ...,
          [0.3843, 0.3843, 0.4078,  ..., 0.4667, 0.4471, 0.4627],
          [0.3608, 0.3569, 0.3686,  ..., 0.4314, 0.4275, 0.4588],
          [0.3529, 0.3804, 0.3961,  ..., 0.4314, 0.4471, 0.4863]],
 
         [[0.6510, 0.6471, 0.6431,  ..., 0.6706, 0.6471, 0.6353],
          [0.6471, 0.6510, 0.6510,  ..., 0.6392, 0.6471, 0.6588],
          [0.6431, 0.6471, 0.6549,  ..., 0.6392, 0.6471, 0.6549],
          ...,
          [0.6980, 0.7059, 0.7294,  ..., 0.5804, 0.5608, 0.5765],
          [0.6902, 0.6902, 0.7098,  ..., 0.5569, 0.5412, 0.5725],
          [0.6824, 0.7137, 0.7373,  ..., 0.5569, 0.5725, 0.6118]],
 
         [[0.6824, 0.6667, 0.6510,  ..., 0.5529, 0.5176, 0.5059],
          [0.6824, 0.6706, 0.6667,  ..., 0.5216, 0.5176, 0.5294],
          [0.6784, 0.6745, 0.6706,  ...,