In [29]:
# Load the data
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid


class MyData(Dataset):
    """Image/label dataset stored as two parallel directories under one root.

    ``root_dir/image_dir`` holds the image files and ``root_dir/label_dir`` holds
    one text file per image. Both directory listings are sorted, so item ``idx``
    pairs the ``idx``-th image with the ``idx``-th label file. This relies on the
    image and label files sharing the same base names.
    """

    def __init__(self, root_dir, image_dir, label_dir, transform=None):
        """
        Args:
            root_dir: directory that contains ``image_dir`` and ``label_dir``.
            image_dir: sub-directory name holding the image files.
            label_dir: sub-directory name holding the label text files.
            transform: optional callable applied to each PIL image in
                ``__getitem__`` (e.g. a ``transforms.Compose``).
        """
        self.root_dir = root_dir
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.label_path = os.path.join(self.root_dir, self.label_dir)
        self.image_path = os.path.join(self.root_dir, self.image_dir)
        self.transform = transform
        # Label and image files have the same names, so sorting both listings the
        # same way keeps each image aligned with its label.
        self.image_list = sorted(os.listdir(self.image_path))
        self.label_list = sorted(os.listdir(self.label_path))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for item ``idx``.

        ``image`` is the opened PIL image, passed through ``self.transform`` when
        one was given. ``label`` is the first line of the matching label file,
        with surrounding whitespace (including a trailing newline) stripped.
        """
        img_item_path = os.path.join(self.image_path, self.image_list[idx])
        label_item_path = os.path.join(self.label_path, self.label_list[idx])
        img = Image.open(img_item_path)
        with open(label_item_path, 'r') as f:
            # readline() keeps the '\n'; strip it so labels compare cleanly.
            label = f.readline().strip()

        # Use the transform given to *this* instance, not a module-level global.
        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):
        """Number of image/label pairs; asserts both directories have the same count."""
        assert len(self.image_list) == len(self.label_list)
        return len(self.image_list)


# Per-image pipeline: Resize(400) with a single int scales the shorter edge to
# 400 px (aspect ratio kept), then transforms.ToTensor converts the PIL image to a Tensor.
transform = transforms.Compose([transforms.Resize(400), transforms.ToTensor()])
# Dataset root. Set the HYMENOPTERA_TRAIN_DIR environment variable to point at a
# local copy instead of editing this machine-specific absolute path.
root_dir = os.environ.get(
    "HYMENOPTERA_TRAIN_DIR",
    "/home/tzr/Documents/GitHubSYNC/PythonandMLLearning/Pytorch/demo1/hymenoptera_data/train/",
)
image_ants = "ants_image"
label_ants = "ants_label"
ants_dataset = MyData(root_dir, image_ants, label_ants, transform=transform)
image_bees = "bees_image"
label_bees = "bees_label"
bees_dataset = MyData(root_dir, image_bees, label_bees, transform=transform)


In [30]:
# Convert a single ant image to a Tensor and log it to TensorBoard.
img_path = os.path.join(root_dir, "ants_image", "6240329_72c01e663e.jpg")
img = Image.open(img_path)
tensor_trans = transforms.ToTensor()
# Name kept exactly as-is (spelling included) so existing references to it keep working.
tensro_img = tensor_trans(img)
# The context manager closes the writer (flushing pending events) even if add_image raises.
with SummaryWriter("logs") as writer:
    writer.add_image("Tensor_img", tensro_img)



In [38]:
# Normalize: per channel, output = (input - mean) / std.
img_path = os.path.join(root_dir, "ants_image", "0013035.jpg")
img = Image.open(img_path)
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)

with SummaryWriter("logs") as writer:
    # Log this cell's own tensor, not a same-looking name left over from an earlier cell.
    writer.add_image("Tensor_img", tensor_img)

    print(tensor_img[0][0][0])
    # Normalize's third positional parameter is `inplace`; passing anything truthy there
    # overwrites tensor_img. Pass only mean and std so the input tensor is left untouched.
    trans_norm = transforms.Normalize([0.1, 0.1, 0.1], [0.5, 0.5, 0.5])
    img_norm = trans_norm(tensor_img)
    # NOTE(review): normalized values can fall outside [0, 1]; add_image expects float
    # images in [0, 1], so the rendered picture may be clipped.
    writer.add_image("Normalise", img_norm, 3)  # global_step=3, i.e. "step three"
    # Same pixel after normalization, e.g. (0.3137 - 0.1) / 0.5 ≈ 0.4275.
    print(img_norm[0][0][0])


tensor(0.3137)
tensor(0.4275)


In [45]:
# Resize to a fixed (512, 512); the aspect ratio is NOT preserved.
print(img.size)
trans_resize = transforms.Resize((512, 512))
img_resize_pil = trans_resize(img)
# Convert the resized PIL image back to a Tensor for TensorBoard.
img_resize = tensor_trans(img_resize_pil)
writer.add_image("Resize", img_resize, 0)
# `writer` was closed by an earlier cell; SummaryWriter appears to reopen an event
# file on demand, but events are buffered, so flush to make them visible now.
writer.flush()


(768, 512)


In [53]:
# Proportional resize: a single int scales the shorter edge to 120 px, keeping aspect ratio.
print(img.size)
trans_resize2 = transforms.Resize(120)
# Compose chains the resize and the Tensor conversion into a single callable.
trans_compose = transforms.Compose([trans_resize2, tensor_trans])
img_resize_2 = trans_compose(img)
writer.add_image("Resize2", img_resize_2, 0)
writer.flush()  # events are buffered; flush so TensorBoard shows them now

# Inspect this cell's own result, not the previous cell's img_resize.
print(type(img_resize_2))

(768, 512)
<class 'torch.Tensor'>


In [54]:
# Record the installed torchvision version alongside the outputs above.
import torchvision

print(f"{torchvision.__version__}")

0.10.0
