In [77]:
import copy  # save model weights
import glob  # collect file paths
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F  # built-in activation functions
import torch.optim as optim
import torchvision
from PIL import Image
from torch import nn
from torch.optim import lr_scheduler  # learning-rate decay
from torch.utils import data
from torch.utils.data import DataLoader
from torchvision import transforms

#### 自定义类输入数据

In [1]:
class MyDataset(data.Dataset):
    """Map-style dataset that loads images from disk and pairs them with labels.

    Args:
        img_paths: Sequence of image file paths.
        img_labels: Sequence of labels aligned index-for-index with ``img_paths``.
        transform: Optional callable applied to the RGB ``PIL.Image``. When it
            is ``None`` the converted PIL image itself is returned.

    Raises:
        ValueError: If ``img_paths`` and ``img_labels`` differ in length.
    """

    def __init__(self, img_paths, img_labels, transform=None):
        # A length mismatch would otherwise pair images with the wrong labels
        # (or fail with an IndexError only when the tail is reached).
        if len(img_paths) != len(img_labels):
            raise ValueError(
                "img_paths and img_labels must have the same length, got "
                f"{len(img_paths)} and {len(img_labels)}"
            )
        self.imgs = img_paths  # image file paths
        self.img_labels = img_labels  # label for each path
        self.transforms = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        img_path = self.imgs[index]
        label = self.img_labels[index]

        # The context manager releases the file handle deterministically;
        # convert() returns a separate image, so the result outlives the file.
        with Image.open(img_path) as raw_img:
            pil_img = raw_img.convert("RGB")  # normalise grayscale/RGBA/palette inputs

        if self.transforms is None:
            return pil_img, label
        return self.transforms(pil_img), label

    def __len__(self):
        """Number of samples, i.e. the number of image paths."""
        return len(self.imgs)


NameError: name 'data' is not defined

#### 数据预处理

In [None]:
# Dataset location. A raw string keeps the Windows backslashes literal instead
# of relying on unrecognised escape sequences such as "\C" or "\d".
# NOTE(review): absolute, machine-specific path - adjust for your environment.
# sorted() makes the listing order deterministic across platforms.
img_paths = sorted(glob.glob(r'D:\CODE\Code_Python\Pytorch\第8章\dataset2\*.jpg'))
if not img_paths:
    raise FileNotFoundError("No .jpg images matched the dataset glob pattern")

# Class names and the two lookup tables between name and integer index.
species = ['cloudy', 'rain', 'shine', 'sunrise']
species_to_idx = {specie: i for i, specie in enumerate(species)}
idx_to_species = {v: k for k, v in species_to_idx.items()}

# Derive one label per image from its file name. Only the base name is
# inspected so directory names cannot produce a match, and every file must
# match exactly one class - otherwise labels would drift out of alignment
# with img_paths.
img_labels = []
for path in img_paths:
    file_name = os.path.basename(path)
    matches = [idx for idx, specie in enumerate(species) if specie in file_name]
    if len(matches) != 1:
        raise ValueError(
            f"Expected exactly one class name in {file_name!r}, "
            f"found {[species[idx] for idx in matches]}"
        )
    img_labels.append(matches[0])

# Preprocessing pipeline.
transform = transforms.Compose([
    transforms.Resize((96, 96)),  # bring every image to the same size
    transforms.ToTensor(),  # PIL image -> float tensor
])

weather_ds = MyDataset(img_paths, img_labels, transform)
batch_size = 16
# The variable name (sic) is kept as-is because other cells may refer to it.
wheather_dl = data.DataLoader(weather_ds, batch_size, shuffle=True)

imgs_batch, labels_batch = next(iter(wheather_dl))

# Preview up to six samples from the first batch in a 2x3 grid.
for i, (image, label) in enumerate(zip(imgs_batch[:6], labels_batch[:6])):
    plt.subplot(2, 3, i + 1)
    plt.imshow(image.permute(1, 2, 0))  # channel axis moved last for imshow
    plt.title(idx_to_species.get(label.item()))

#### 训练/测试数据处理

In [98]:
# Paths and labels are indexed together below, so they must line up.
if len(img_paths) != len(img_labels):
    raise ValueError(
        "img_paths and img_labels must have the same length, got "
        f"{len(img_paths)} and {len(img_labels)}"
    )

# Random permutation of sample indices. It has to be applied to BOTH lists
# before slicing; otherwise the split is taken from the unshuffled listing.
order = np.random.permutation(len(img_paths))
shuffled_paths = [img_paths[i] for i in order]
shuffled_labels = [img_labels[i] for i in order]

img_slice = int(len(img_paths) * 0.8)  # first 80% for training

# Training data
train_imgs = shuffled_paths[:img_slice]
train_labels = shuffled_labels[:img_slice]
# Test data
test_imgs = shuffled_paths[img_slice:]
test_labels = shuffled_labels[img_slice:]

batch_size = 16
# Training set: batched and reshuffled on every pass over the loader
train_ds = MyDataset(train_imgs, train_labels, transform)
train_dl = data.DataLoader(train_ds, batch_size, shuffle=True)
# Test set: batched, order left fixed
test_ds = MyDataset(test_imgs, test_labels, transform)
test_dl = data.DataLoader(test_ds, batch_size)

#### 自定义类实现 (height, width, channel)

In [100]:
class changeShape(data.Dataset):
    """Dataset wrapper that reorders the axes of each image with ``permute(1, 2, 0)``.

    Every item of the wrapped dataset must unpack into an ``(image, label)``
    pair; the label is passed through untouched.
    """

    def __init__(self, dataset):
        # Keep a reference to the wrapped dataset.
        self.ds = dataset

    def __len__(self):
        # Same number of samples as the wrapped dataset.
        return len(self.ds)

    def __getitem__(self, index):
        tensor, target = self.ds[index]
        # Axis 0 moves to the end; axes 1 and 2 shift forward.
        return tensor.permute(1, 2, 0), target