In [44]:
import torch
from torch.utils import data
from PIL import Image   #  pip install pillow
import numpy as np
from torchvision import transforms
import glob
import torchvision

In [4]:
all_imgs_path = glob.glob('/home/bear/dl/data/dataset2/*.jpg')

In [5]:
all_imgs_path[:5]

['/home/bear/dl/data/dataset2/cloudy60.jpg',
 '/home/bear/dl/data/dataset2/rain184.jpg',
 '/home/bear/dl/data/dataset2/sunrise242.jpg',
 '/home/bear/dl/data/dataset2/rain189.jpg',
 '/home/bear/dl/data/dataset2/rain143.jpg']

In [6]:
weather_dataset = Mydataset(all_imgs_path)

In [7]:
len(weather_dataset)

1122

In [8]:
weather_dataset[23]

'/home/bear/dl/data/dataset2/sunrise356.jpg'

In [9]:
wh_dl = torch.utils.data.DataLoader(weather_dataset, batch_size=4)

In [10]:
next(iter(wh_dl))

['/home/bear/dl/data/dataset2/cloudy60.jpg',
 '/home/bear/dl/data/dataset2/rain184.jpg',
 '/home/bear/dl/data/dataset2/sunrise242.jpg',
 '/home/bear/dl/data/dataset2/rain189.jpg']

In [11]:
species = ['cloudy', 'rain', 'shine', 'sunrise']

In [12]:
species_to_idx = dict((c, i) for i, c in enumerate(species))

In [13]:
species_to_idx

{'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}

In [14]:
idx_to_species = dict((k, v) for v, k in species_to_idx.items())

NameError: name 'idx_to_species' is not defined

In [16]:
all_labels = []

for img in all_imgs_path:
    for i, c in enumerate(species):
        if c in img:
            all_labels.append(i)


In [17]:
all_labels[:5]

[0, 1, 3, 1, 1]

In [18]:
transform = transforms.Compose([
                    transforms.Resize((96, 96)),
                    transforms.ToTensor(),
])

In [19]:
class Mydataset(data.Dataset):
    def __init__(self, img_paths, labels, transform):
        self.imgs = img_paths
        self.labels = labels
        self.transforms = transform
        
    def __getitem__(self, index):
        img = self.imgs[index]
        label = self.labels[index]
        
        pil_img = Image.open(img)
        pil_img = pil_img.convert("RGB")    # 可选,建议都使用
        data = self.transforms(pil_img)
        
        return data, label
    
    def __len__(self):
        return len(self.imgs)

In [20]:
weather_dataset = Mydataset(all_imgs_path, all_labels, transform)

In [21]:
type(weather_dataset)

__main__.Mydataset

In [22]:
weather_dl = data.DataLoader(
                            weather_dataset,
                            batch_size=16,
                            shuffle=True,
                            num_workers=4
)

In [23]:
imgs_batch, labels_batch = next(iter(weather_dl))

In [25]:
imgs_batch.shape

torch.Size([16, 3, 96, 96])

In [26]:
labels_batch.shape

torch.Size([16])

In [27]:
index = np.random.permutation(len(all_imgs_path))

In [28]:
index

array([ 285,  250,  529, ..., 1053,  719,  522])

In [29]:
all_imgs_path = np.array(all_imgs_path)[index]

In [31]:
all_labels = np.array(all_labels)[index]

In [32]:
s = int(len(all_imgs_path)*0.8)

In [33]:
train_imgs = all_imgs_path[:s]
train_labels = all_labels[:s]
test_imgs = all_imgs_path[s:]
test_labels = all_labels[s:]

In [34]:
train_ds = Mydataset(train_imgs, train_labels, transform)

In [35]:
test_ds = Mydataset(test_imgs, test_labels, transform)

In [36]:
train_dl = data.DataLoader(train_ds,
                           batch_size=16,
                           shuffle=True)
test_dl = data.DataLoader(test_ds,
                           batch_size=16)

In [38]:
imgs, labels = next(iter(train_dl))

In [39]:
imgs.shape

torch.Size([16, 3, 96, 96])

In [40]:
class New_dataset(data.Dataset):
    def __init__(self, some_dataset):
        self.ds = some_dataset
    def __getitem__(self, index):
        img, label = self.ds[index]
        img = img.permute(1,2,0)
        return img, label
    def __len__(self):
        return len(self.ds)
        

In [41]:
train_new_dataset = New_dataset(train_ds)

In [42]:
img, label = train_new_dataset[2]

In [45]:
img_dir = r'/home/bear/dl/data/dataset/animals'

In [66]:
dataset =  torchvision.datasets.ImageFolder(
        img_dir
)

In [67]:
dataset.classes

['elephant', 'giraffe', 'lion', 'monkey', 'tiger']

In [48]:
dataset.class_to_idx

{'elephant': 0, 'giraffe': 1, 'lion': 2, 'monkey': 3, 'tiger': 4}

In [49]:
count = len(dataset)

In [50]:
count

1955

In [51]:
train_count = int(0.8*count)

In [52]:
train_count

1564

In [53]:
test_count = count - train_count

In [54]:
test_count

391

In [79]:
train_dataset, test_dataset = data.random_split(dataset, [train_count, test_count])

In [80]:
train_dataset=train_dataset.dataset


In [92]:
class New_dataset(data.Dataset):
    def __init__(self, some_dataset,transform):
        self.transforms = transform
        self.ds = some_dataset
    def __getitem__(self, index):
        img, label = self.ds[index]
        data = self.transforms(img)
        return data, label
    def __len__(self):
        return len(self.ds)

In [94]:
train_new_dataset = New_dataset(train_dataset, transform)

In [95]:
img, label = train_new_dataset[0]

In [96]:
img

tensor([[[0.7569, 0.7804, 0.8078,  ..., 0.7765, 0.7765, 0.7804],
         [0.7647, 0.7882, 0.8196,  ..., 0.7765, 0.7765, 0.7804],
         [0.7569, 0.7843, 0.8196,  ..., 0.7765, 0.7765, 0.7804],
         ...,
         [0.8275, 0.7725, 0.7608,  ..., 0.3451, 0.3294, 0.3255],
         [0.8314, 0.7765, 0.7294,  ..., 0.3333, 0.3255, 0.3490],
         [0.8039, 0.7647, 0.7059,  ..., 0.3216, 0.3216, 0.3333]],

        [[0.8314, 0.8510, 0.8706,  ..., 0.8471, 0.8471, 0.8510],
         [0.8471, 0.8627, 0.8863,  ..., 0.8471, 0.8471, 0.8510],
         [0.8549, 0.8706, 0.8902,  ..., 0.8471, 0.8471, 0.8510],
         ...,
         [0.7216, 0.6667, 0.6549,  ..., 0.3294, 0.3137, 0.3098],
         [0.7255, 0.6706, 0.6235,  ..., 0.3176, 0.3098, 0.3333],
         [0.6980, 0.6588, 0.6000,  ..., 0.3059, 0.3059, 0.3176]],

        [[0.8588, 0.8784, 0.8980,  ..., 0.8627, 0.8627, 0.8667],
         [0.8745, 0.8941, 0.9137,  ..., 0.8627, 0.8627, 0.8667],
         [0.8863, 0.9020, 0.9216,  ..., 0.8627, 0.8627, 0.

In [63]:
pil_img = Image.open('/home/bear/dl/data/dataset/animals/elephant/pic_081.jpg')

In [65]:
pil_img.show

<bound method Image.show of <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=276x183 at 0x7F4C69F8AE50>>