In [58]:
import torch
from torch.utils import data
from PIL import Image   #  pip install pillow
import numpy as np
from torchvision import transforms

In [59]:
import matplotlib.pyplot as plt
%matplotlib inline

In [60]:
import glob
import os

In [61]:
# Raw string: the Windows path's backslashes (\C, \P, \d, \*, ...) are not valid
# escape sequences and trigger DeprecationWarning/SyntaxWarning in a normal literal.
# NOTE(review): absolute local path — adjust DATA_DIR for your machine.
DATA_DIR = r'D:\CODE\Code_Python\Pytorch\第8章\dataset2'

# glob returns matches in filesystem order; sort so the file list does not
# depend on the platform/filesystem it runs on.
all_imgs_path = sorted(glob.glob(os.path.join(DATA_DIR, '*.jpg')))

In [62]:
# Peek at the first few discovered image paths.
all_imgs_path[:3]

['D:\\CODE\\Code_Python\\Pytorch\\第8章\\dataset2\\cloudy1.jpg',
 'D:\\CODE\\Code_Python\\Pytorch\\第8章\\dataset2\\cloudy10.jpg',
 'D:\\CODE\\Code_Python\\Pytorch\\第8章\\dataset2\\cloudy100.jpg']

In [63]:
# Class names; the position of each name in this list is its integer label.
species = 'cloudy rain shine sunrise'.split()

In [64]:
# Map each class name to its integer label (its position in `species`).
species_to_idx = {name: idx for idx, name in enumerate(species)}

In [65]:
# Inspect the name -> label mapping.
species_to_idx

{'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}

In [66]:
# Inverse mapping (label -> class name), used to title the preview images.
idx_to_species = {idx: name for name, idx in species_to_idx.items()}

In [67]:
# Inspect the label -> name mapping.
idx_to_species

{0: 'cloudy', 1: 'rain', 2: 'shine', 3: 'sunrise'}

In [68]:
# Derive one class label per image from its file name (e.g. "rain117.jpg" -> 1).
# Only the basename is searched, so directory names (a "train" folder contains
# "rain") cannot add spurious labels. Each image must match exactly one species,
# which keeps all_labels aligned one-to-one with all_imgs_path.
all_labels = []

for img_path in all_imgs_path:
    file_name = os.path.basename(img_path)
    matches = [idx for idx, name in enumerate(species) if name in file_name]
    if len(matches) != 1:
        raise ValueError(
            f'{file_name!r} should contain exactly one of {species}, matched {matches}'
        )
    all_labels.append(matches[0])

In [107]:
# NOTE(review): the recorded output is an ndarray, i.e. this cell was re-run
# after the shuffle cells (execution count 107); on a fresh Restart & Run All
# it shows the unshuffled Python list instead.
all_labels[-5:]

array([2, 1, 3, 2, 1])

In [108]:
# NOTE(review): the recorded output is an ndarray, i.e. this cell was re-run
# after the shuffle cells (execution count 108); on a fresh Restart & Run All
# it shows the unshuffled Python list instead.
all_imgs_path[-5:]

array(['D:\\CODE\\Code_Python\\Pytorch\\第8章\\dataset2\\shine136.jpg',
       'D:\\CODE\\Code_Python\\Pytorch\\第8章\\dataset2\\rain117.jpg',
       'D:\\CODE\\Code_Python\\Pytorch\\第8章\\dataset2\\sunrise107.jpg',
       'D:\\CODE\\Code_Python\\Pytorch\\第8章\\dataset2\\shine170.jpg',
       'D:\\CODE\\Code_Python\\Pytorch\\第8章\\dataset2\\rain40.jpg'],
      dtype='<U55')

In [71]:
# Shuffle with a fixed seed so the same permutation (and therefore the same
# train/test split) is produced on every run.
SEED = 42
rng = np.random.default_rng(SEED)
index = rng.permutation(len(all_imgs_path))

In [72]:
# A permutation of 0..len(all_imgs_path)-1, applied to both paths and labels.
index

array([343, 957, 715, ..., 775, 592, 448])

In [73]:
# Reorder the paths with the shared permutation `index` (the labels get the same one).
# NOTE: this rebinds all_imgs_path to its shuffled self; re-running this cell
# without also re-running the label cell misaligns paths and labels.
all_imgs_path = np.take(np.array(all_imgs_path), index)

In [74]:
# First few shuffled paths; compare against the shuffled labels below.
all_imgs_path[:5]

array(['D:\\CODE\\Code_Python\\Pytorch\\第8章\\dataset2\\rain138.jpg',
       'D:\\CODE\\Code_Python\\Pytorch\\第8章\\dataset2\\sunrise271.jpg',
       'D:\\CODE\\Code_Python\\Pytorch\\第8章\\dataset2\\shine53.jpg',
       'D:\\CODE\\Code_Python\\Pytorch\\第8章\\dataset2\\cloudy176.jpg',
       'D:\\CODE\\Code_Python\\Pytorch\\第8章\\dataset2\\cloudy32.jpg'],
      dtype='<U55')

In [75]:
# Apply the same permutation to the labels so they stay aligned with the paths.
# NOTE: like the path cell, this rebinds all_labels; re-running it alone re-shuffles.
all_labels = np.take(np.array(all_labels), index)

In [76]:
# Shuffled labels; each should match the class name in the corresponding path above.
all_labels[:5]

array([1, 3, 2, 0, 0])

In [77]:
# Fraction of the shuffled samples used for training; the rest form the test set.
TRAIN_FRACTION = 0.8
s = int(len(all_imgs_path) * TRAIN_FRACTION)

In [78]:
# Number of samples in the training split.
s

897

In [79]:
# First `s` shuffled samples go to training, the remainder to testing.
train_imgs, test_imgs = all_imgs_path[:s], all_imgs_path[s:]
train_labels, test_labels = all_labels[:s], all_labels[s:]

In [80]:
# Resize every image to 96x96, then convert the PIL image to a float tensor.
transform = transforms.Compose([transforms.Resize((96, 96)), transforms.ToTensor()])

In [81]:
class Mydataset(data.Dataset):
    """Map-style dataset yielding ``(sample, label)`` pairs from image files.

    Args:
        img_paths: Indexable sequence of image file paths.
        labels: Indexable sequence of class labels, aligned with ``img_paths``.
        transform: Callable applied to the RGB PIL image. If ``None``, the
            PIL image itself is returned as the sample.

    Raises:
        ValueError: if ``img_paths`` and ``labels`` differ in length.
    """

    def __init__(self, img_paths, labels, transform=None):
        # A length mismatch would silently pair images with the wrong labels.
        if len(img_paths) != len(labels):
            raise ValueError(
                f'img_paths and labels differ in length: {len(img_paths)} != {len(labels)}'
            )
        self.imgs = img_paths
        self.labels = labels
        self.transforms = transform

    def __getitem__(self, index):
        """Load image ``index``, convert it to RGB, apply the transform, return (sample, label)."""
        img_path = self.imgs[index]
        label = self.labels[index]

        # The context manager releases the file handle immediately; convert()
        # returns a new, fully loaded image, so it stays valid after close.
        # Converting to RGB is optional but recommended (uniform 3 channels).
        with Image.open(img_path) as pil_img:
            rgb_img = pil_img.convert("RGB")

        # Named `sample` rather than `data` to avoid shadowing torch.utils.data.
        sample = self.transforms(rgb_img) if self.transforms is not None else rgb_img
        return sample, label

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.imgs)

In [82]:
# Dataset over all images (train + test), used here for the batch preview.
# ("wheather" spelling kept because later cells refer to this name.)
wheather_dataset = Mydataset(img_paths=all_imgs_path, labels=all_labels, transform=transform)

In [83]:
# Sanity check: the object is an instance of the custom Dataset class.
type(wheather_dataset)

__main__.Mydataset

In [84]:
# Number of samples per DataLoader batch.
BATCH_SIZE = 16

In [85]:
# Shuffled loader over the full dataset, used to grab one batch for inspection.
wheather_dl = data.DataLoader(wheather_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [86]:
# Draw a single shuffled batch to inspect tensor shapes and preview images.
imgs_batch, labels_batch = next(iter(wheather_dl))

In [87]:
# (batch, channels, height, width)
imgs_batch.shape

torch.Size([16, 3, 96, 96])

In [88]:
# One class label per image in the batch.
labels_batch.shape

torch.Size([16])

In [None]:
# Preview the last six images of the sampled batch, titled with their class names.
fig, axes = plt.subplots(2, 3, figsize=(12, 8))
for ax in axes.flat:
    ax.axis('off')  # hide ticks; also leaves any unused panel blank

for ax, img, label in zip(axes.flat, imgs_batch[-6:], labels_batch[-6:]):
    # (C, H, W) tensor -> (H, W, C) array, the channel-last layout imshow expects.
    ax.imshow(img.permute(1, 2, 0).numpy())
    ax.set_title(idx_to_species.get(label.item()))

fig.tight_layout()

In [90]:
# Training split wrapped as a Dataset.
train_ds = Mydataset(img_paths=train_imgs, labels=train_labels, transform=transform)

In [91]:
# Test split wrapped as a Dataset (same transform as training).
test_ds = Mydataset(img_paths=test_imgs, labels=test_labels, transform=transform)

In [92]:
# Training loader: reshuffled every epoch; uses the shared BATCH_SIZE constant
# instead of a duplicated literal.
train_dl = data.DataLoader(train_ds,
                           batch_size=BATCH_SIZE,
                           shuffle=True)

In [93]:
# Evaluation loader: fixed order (no shuffling) so test results are reproducible;
# uses the shared BATCH_SIZE constant instead of a duplicated literal.
test_dl = data.DataLoader(test_ds,
                          batch_size=BATCH_SIZE,
                          shuffle=False)

In [94]:
# Pull one batch from the training loader as a smoke test.
imgs, labels = next(iter(train_dl))

In [95]:
# (batch, channels, height, width)
imgs.shape

torch.Size([16, 3, 96, 96])

In [96]:
class New_dataset(data.Dataset):
    """Wrap a dataset of ``(image_tensor, label)`` pairs and move channels last.

    Each image is reordered with ``permute(1, 2, 0)``, so a ``(C, H, W)``
    tensor comes back as ``(H, W, C)``; labels pass through unchanged.

    Args:
        some_dataset: Indexable, sized dataset returning ``(image, label)``.
    """

    def __init__(self, some_dataset):
        self.ds = some_dataset

    def __len__(self):
        """Same length as the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, index):
        """Return the wrapped sample with its image permuted to channel-last."""
        image, target = self.ds[index]
        return image.permute(1, 2, 0), target

In [97]:
# Channel-last view of the training set.
train_new_dataset = New_dataset(some_dataset=train_ds)

In [98]:
# Fetch one sample through the wrapper; its image comes back channel-last.
img, label = train_new_dataset[2]

In [99]:
# (height, width, channels)
img.shape

torch.Size([96, 96, 3])