In [1]:
import torch 

### PyTorch 加载数据
- Dataset 提供一种方式去获取数据以及其 label
    - 如何获取每一个数据及其 label
    - 告诉我们总共有多少数据
- DataLoader 从 Dataset 中取数据，并以不同的形式（例如按 batch 打包、打乱顺序）提供给后面的网络

下面展示如何基于蚂蚁/蜜蜂数据集 hymenoptera_data 构建 Dataset。


In [2]:
# Read one sample image from the hymenoptera (ants / bees) training set.
import os
from glob import glob

import cv2

# Relative to the notebook's working directory (the same root the dataset cell
# uses) instead of a machine-specific absolute path, so the cell runs on any
# checkout of the project.
HYMENOPTERA_TRAIN_DIR = os.path.join("_data", "hymenoptera_data", "hymenoptera_data", "train")
ants_train_dir = os.path.join(HYMENOPTERA_TRAIN_DIR, "ants")
bees_train_dir = os.path.join(HYMENOPTERA_TRAIN_DIR, "bees")

# glob() returns paths in arbitrary order; sort so that a fixed index always
# refers to the same file.
ants_train_paths = sorted(glob(os.path.join(ants_train_dir, "*.jpg")))

SAMPLE_INDEX = 29
if len(ants_train_paths) <= SAMPLE_INDEX:
    raise FileNotFoundError(
        f"expected more than {SAMPLE_INDEX} *.jpg files in {ants_train_dir!r}, "
        f"found {len(ants_train_paths)}"
    )

# cv2.imread returns None instead of raising when a file cannot be read.
img = cv2.imread(ants_train_paths[SAMPLE_INDEX])
if img is None:
    raise OSError(f"could not read image: {ants_train_paths[SAMPLE_INDEX]!r}")

In [3]:
from torch.utils.data import Dataset
import os 
import cv2  
from glob import glob

class MyData(Dataset):
    """Image-folder dataset: every ``*.jpg`` file in ``root_dir/label`` is one sample.

    Args:
        root_dir: Directory that contains one sub-directory per class.
        label: Name of the class sub-directory; also returned as the label of
            every sample.

    Each item is a tuple ``(img, label)`` where ``img`` is the image decoded by
    OpenCV and converted to RGB channel order, and ``label`` is the ``label``
    string passed to the constructor.
    """

    def __init__(self, root_dir, label):
        self.root_dir = root_dir
        self.label = label
        self.path = os.path.join(root_dir, label)
        # glob() returns files in arbitrary order; sort so that a given index
        # maps to the same image on every run and every machine.
        self.img_paths = sorted(glob(os.path.join(self.path, "*.jpg")))

    # idx must be an integer: adding datasets (ds1 + ds2) does not work otherwise.
    def __getitem__(self, idx):
        """Return ``(img, label)`` for the sample at integer position ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range.
            OSError: If the image file cannot be read.
        """
        img_path = self.img_paths[idx]
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise OSError(f"could not read image: {img_path!r}")
        # OpenCV decodes colour images as BGR; convert to RGB, the order that
        # TensorBoard's add_image and torchvision transforms assume.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img, self.label

    def __len__(self):
        """Return the number of ``*.jpg`` files found for this label."""
        return len(self.img_paths)


# Build one dataset per class folder, then join them into the training set.
train_root = '_data/hymenoptera_data/hymenoptera_data/train'
ants_train_data, bees_train_data = (
    MyData(train_root, class_name) for class_name in ('ants', 'bees')
)
# `+` on Dataset objects concatenates them: ants samples first, then bees.
train_data = ants_train_data + bees_train_data

### TensorBoard 的使用

In [4]:
# In VS Code, TensorBoard is opened through a command.
from torch.utils.tensorboard import SummaryWriter

# Using the writer as a context manager closes it even if one of the logging
# calls below raises.
with SummaryWriter("_log") as writer:
    # Per the original note, train_data[0][0].shape is (375, 500, 3), i.e.
    # height, width, channel; dataformats tells add_image about that layout.
    writer.add_image("test", train_data[0][0], 0, dataformats="HWC")
    for i in range(100):
        # Arguments: tag (unique identifier of the curve), y value, x value.
        # Note the logged value carries a +500 offset that the tag does not mention.
        writer.add_scalar("y=x^1.5", i**1.5+500, i)

### 常见的transform

1. Compose: 类似于 sklearn 的 Pipeline，将多个 transform 按顺序组合在一起
2. ToTensor: 将图像（PIL Image 或 ndarray）转换为 tensor
3. Normalize: 按通道 (channel) 用给定的均值和标准差做标准化
4. Resize: 缩放至指定尺寸
5. RandomCrop: 随机裁剪出指定尺寸的区域


In [14]:
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

# Using the writer as a context manager closes it even if one of the
# transforms or logging calls below raises.
with SummaryWriter("./_log") as writer:
    # ToTensor: ndarray image -> tensor (the printed sizes below show the
    # result in channel-first layout).
    trans_tensor = transforms.ToTensor()
    img_ndarray = train_data[0][0]
    img_tensor = trans_tensor(img_ndarray)
    writer.add_image("ToTensor", img_tensor)

    # Normalize: per-channel (x - mean) / std with mean = std = 0.5 for each
    # of the three channels.
    trans_norm = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    img_norm = trans_norm(img_tensor)
    writer.add_image("Normalize", img_norm)

    # Resize: print the size before and after to show the effect.
    print(img_tensor.size())
    trans_resize = transforms.Resize((512, 512))
    img_resize = trans_resize(img_tensor)
    print(img_resize.size())
    writer.add_image("Resize", img_resize)

    # Compose: chain the three transforms above; it starts from the ndarray
    # because ToTensor is the first step.
    trans_compose = transforms.Compose([trans_tensor, trans_norm, trans_resize])
    img_r = trans_compose(img_ndarray)
    writer.add_image("compose", img_r)

    # RandomCrop: log four crops under the same tag, one per step. No seed is
    # set, so the crops differ between runs.
    trans_rc = transforms.RandomCrop((300, 300))
    for i in range(4):
        img_rc = trans_rc(img_tensor)
        writer.add_image("RandomCrop", img_rc, i)


torch.Size([3, 375, 500])
torch.Size([3, 512, 512])


### torchvision 中的数据集

数据集样例

In [6]:
import torchvision
from torch.utils.tensorboard import SummaryWriter

# Transform applied to every image when a sample is fetched from the dataset.
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
# download=True fetches CIFAR10 into ./_data; the recorded cell output shows
# "Files already downloaded and verified" when the files are already there.
# NOTE(review): `tran_set` looks like a typo for `train_set`; the name is kept
# as-is in case other cells reference it.
tran_set = torchvision.datasets.CIFAR10(root="./_data", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./_data", train=False, transform=dataset_transform, download=True)

# To inspect a sample: `img, target = test_set[0]`; `test_set.classes[target]`
# is the corresponding class name.

# Log the first 10 test images under one tag, one image per step. Using the
# writer as a context manager closes it even if a logging call raises.
with SummaryWriter("./_log/p10") as writer:
    for i in range(10):
        img, t = test_set[i]
        writer.add_image("test_set", img, i)



Files already downloaded and verified
Files already downloaded and verified
