In [2]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

In [3]:

class MyData(Dataset):
    """Image/label dataset backed by two sub-directories of ``root_dir``.

    Images are read from ``root_dir/image_dir`` and labels from
    ``root_dir/label_dir``; each label file holds the label text on its first
    line. Both directory listings are sorted independently and paired by
    position, so this relies on image and label files sharing the same base
    names (e.g. ``x.jpg`` / ``x.txt``) so that they sort into the same order.

    Args:
        root_dir: Directory containing ``image_dir`` and ``label_dir``.
        image_dir: Sub-directory with the image files.
        label_dir: Sub-directory with one text label file per image.
        transform: Callable applied to the loaded PIL image (e.g. a torchvision
            ``Compose``). When ``None`` the RGB PIL image is returned as-is.

    Each item is a dict ``{'img': <transformed image>, 'label': <str>}``.
    """

    def __init__(self, root_dir, image_dir, label_dir, transform=None):
        self.root_dir = root_dir
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.label_path = os.path.join(self.root_dir, self.label_dir)
        self.image_path = os.path.join(self.root_dir, self.image_dir)
        # Image and label files share base names, so sorting both listings the
        # same way makes index i refer to the same sample in each list.
        self.image_list = sorted(os.listdir(self.image_path))
        self.label_list = sorted(os.listdir(self.label_path))
        self.transform = transform

    @staticmethod
    def _load_image(path):
        """Load an image fully into memory as RGB and close the underlying file."""
        # Image.open is lazy and keeps the file open; convert() forces a full
        # load and returns a detached copy, so the handle can be released here.
        # Converting to RGB also guarantees 3 channels (grayscale/RGBA/palette
        # files would otherwise yield 1- or 4-channel tensors after ToTensor).
        with Image.open(path) as im:
            return im.convert("RGB")

    @staticmethod
    def _load_label(path):
        """Return the first line of a label file with surrounding whitespace removed."""
        # The `with` block closes the file automatically.
        with open(path, 'r') as f:
            # strip() drops the trailing newline readline() keeps when present.
            return f.readline().strip()

    def __getitem__(self, idx):
        img = self._load_image(os.path.join(self.image_path, self.image_list[idx]))
        label = self._load_label(os.path.join(self.label_path, self.label_list[idx]))
        if self.transform is not None:
            img = self.transform(img)  # e.g. PIL image -> tensor
        return {'img': img, 'label': label}

    def __len__(self):
        # Positional pairing only makes sense if both directories have the same count.
        assert len(self.image_list) == len(self.label_list)
        return len(self.image_list)

In [24]:
# Show the kernel's working directory: the relative paths used below
# ("../dataset/...", "../logs/...") are resolved against it.
cwd_now = os.getcwd()
current_directory = cwd_now
print("Current working directory:", current_directory)

Current working directory: d:\CODE\Pycharm\Codes\AI\scripts


In [4]:
if __name__ == '__main__':
    # Resize every image to the same 256x256 size, then convert PIL -> tensor.
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    # Relative to the kernel's current working directory.
    root_dir = "../dataset/ant_bee/train"

    image_ants = "ants_image"
    label_ants = "ants_label"
    ants_dataset = MyData(root_dir, image_ants, label_ants, transform)

    image_bees = "bees_image"
    label_bees = "bees_label"
    bees_dataset = MyData(root_dir, image_bees, label_bees, transform)

    # `+` on Datasets builds a ConcatDataset: all ant samples, then all bee samples.
    train_dataset = ants_dataset + bees_dataset

    # Worker processes are spawned on Windows (and macOS by default) and must
    # import MyData from __main__, which fails for classes defined in a notebook;
    # load in the main process instead. Raise this only when MyData lives in a .py module.
    NUM_WORKERS = 0
    # shuffle=True so training batches are not all-ants followed by all-bees.
    dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=NUM_WORKERS)

In [31]:
    # TensorBoard logging: one sample image plus a toy scalar curve.
    SAMPLE_IDX = 119  # arbitrary sample from the concatenated ants+bees dataset
    # The context manager closes (and flushes) the writer even if logging raises,
    # e.g. when SAMPLE_IDX is out of range.
    with SummaryWriter("../logs/ant_bee/") as writer:
        writer.add_image('image_test', train_dataset[SAMPLE_IDX]['img'], dataformats='CHW')
        for i in range(100):
            writer.add_scalar('scalar_test', i * i, i)  # tag, scalar value (y), global step (x)

In [5]:
    # Visual sanity check: show one sample before and after normalization.
    idx = 124
    # Index once: every __getitem__ call re-reads the image and label from disk.
    sample = train_dataset[idx]
    img_tensor, label = sample['img'], sample['label']
    print(img_tensor.shape)
    # These match the ImageNet mean/std used by torchvision's pretrained models.
    trans_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    img_nor_tensor = trans_norm(img_tensor)
    # Normalized values lie outside [0, 1]; ToPILImage multiplies float tensors by
    # 255 and casts to uint8, so out-of-range values would overflow and garble the
    # picture. Min-max rescale to [0, 1] purely for display.
    lo, hi = img_nor_tensor.min(), img_nor_tensor.max()
    img_nor_display = (img_nor_tensor - lo) / (hi - lo).clamp_min(1e-8)
    to_pil = transforms.ToPILImage()
    img = to_pil(img_tensor)
    img_nor = to_pil(img_nor_display)
    img.show()
    img_nor.show()

torch.Size([3, 256, 256])
