In [1]:
import torchvision
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
from torchvision import transforms

In [2]:
class SsdData(Dataset):
    """Image/label dataset backed by two parallel directories.

    Files in ``root_dir/image_dir`` and ``root_dir/label_dir`` are paired by
    position after sorting their names: the i-th image (by sorted name) is
    matched with the i-th label file (by sorted name).

    Args:
        root_dir: Directory that contains both sub-directories.
        image_dir: Sub-directory holding the image files.
        label_dir: Sub-directory holding one text label file per image.
        transform: Optional callable applied to each PIL image (for example a
            ``torchvision.transforms.Compose``). ``None`` leaves the image as-is.

    Raises:
        ValueError: if the two directories contain a different number of
            entries, because positional pairing would then be meaningless.
    """

    def __init__(self, root_dir, image_dir, label_dir, transform=None):
        self.root_dir = root_dir
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.label_path = os.path.join(self.root_dir, self.label_dir)
        self.image_path = os.path.join(self.root_dir, self.image_dir)
        # os.listdir order is arbitrary; sorting makes the positional
        # image <-> label pairing deterministic.
        self.image_list = sorted(os.listdir(self.image_path))
        self.label_list = sorted(os.listdir(self.label_path))
        self.transform = transform
        # Validate the pairing up front (not with `assert`, which `python -O`
        # strips) so a mismatch fails at construction time.
        if len(self.image_list) != len(self.label_list):
            raise ValueError(
                f"{self.image_path} has {len(self.image_list)} entries but "
                f"{self.label_path} has {len(self.label_list)}"
            )

    def __getitem__(self, idx):
        """Return ``{'img': image, 'label': str}`` for the ``idx``-th pair.

        ``label`` is the first line of the label file with surrounding
        whitespace (including the trailing newline) removed. ``image`` is the
        PIL image, passed through ``self.transform`` when one was given.
        """
        img_item_path = os.path.join(self.image_path, self.image_list[idx])
        label_item_path = os.path.join(self.label_path, self.label_list[idx])

        # Image.open is lazy and keeps the file handle open; copy the pixel
        # data out inside `with` so the handle is closed deterministically.
        with Image.open(img_item_path) as opened:
            img = opened.copy()

        with open(label_item_path, 'r') as f:
            label = f.readline().strip()

        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'label': label}

    def __len__(self):
        """Number of image/label pairs (validated equal in ``__init__``)."""
        return len(self.image_list)

In [3]:
# Resize every image to a fixed 256x256 before converting it to a tensor, so
# all samples share the same height and width.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
root_dir = "./data/pictures/train"

In [7]:
# (image sub-directory, label sub-directory) under root_dir for each class.
image_ants, label_ants = "ants_image", "ants_label"
image_bees, label_bees = "bees_image", "bees_label"

ants_dataset = SsdData(root_dir, image_ants, label_ants, transform)
bees_dataset = SsdData(root_dir, image_bees, label_bees, transform)

# `+` on torch Datasets concatenates them: ants samples come first, then bees.
train_dataset = ants_dataset + bees_dataset

In [20]:
# Sanity check on the first transformed sample (recorded output: torch.Size([3, 256, 256])).
print(train_dataset[0]["img"].size())

torch.Size([3, 256, 256])


In [3]:
def collate_fn(batch):
    """Return the batch unchanged as a tuple of per-sample items.

    Replaces DataLoader's default collation, so samples are not stacked or
    merged: each element stays exactly what the dataset's ``__getitem__``
    returned.
    """
    return (*batch,)


dataloader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0,
    drop_last=True,  # discard the trailing partial batch
    collate_fn=collate_fn,
)

# For each batch, print: batch index, batch size, and the first sample's label and image shape.
for batch_idx, batch in enumerate(dataloader):
    first_sample = batch[0]
    print(batch_idx, len(batch), first_sample["label"], first_sample["img"].shape)




In [2]:
# `Image` is already imported in the first cell; no re-import here.
# Copy the pixels out inside `with` so the file handle Image.open keeps open is
# closed. The image is then rendered inline as the cell's last expression;
# Image.show() would instead launch an external OS viewer, which does not work
# on headless/remote Jupyter servers.
with Image.open("./data/backup/airplane.png") as opened:
    image = opened.copy()
image