# 建立一个CNN对花朵图片进行分类
本例中使用了``torchvision.datasets.ImageFolder``以及``torch.utils.data.DataLoader``，根据本地图片文件生成训练集和测试集

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
from torchvision import transforms
import matplotlib.pyplot as plt

In [2]:
# Root of the flower dataset: ../data/flower_photos, with one folder per split.
data_dir = os.path.join('..', 'data')
flower_dir = os.path.join(data_dir, 'flower_photos')
# NOTE(review): the dataset/loader cell reads the 'train' and 'val' subfolders of
# flower_dir directly, so test_dir ('test') is not used there — confirm which
# split name actually exists on disk.
train_dir, test_dir = [os.path.join(flower_dir, split) for split in ('train', 'test')]

## transforms, ImageFolder, DataLoader
* 使用``torchvision.transforms.Compose``定义一些Data augmentation操作
* 使用``torchvision.datasets.ImageFolder``按子文件夹读取图片数据（每个子文件夹对应一个类别），并将transforms作为参数传入，在读取每张图片时对其进行处理
* 使用``torch.utils.data.DataLoader``定义Dataloader

### transforms.ToTensor()
将PIL Image或者ndarray转换为tensor（形状由(H, W, C)变为(C, H, W)），并且归一化至$[0, 1]$
* 注意事项：归一化至$[0, 1]$是直接除以255，且只对PIL Image或``np.uint8``类型的ndarray生效；其他类型的ndarray不会被缩放。若自己的ndarray数据尺度有变化，则需要自行处理。

### transforms.RandomRotation(degrees, resample=False, expand=False, center=None)
依degrees随机旋转一定角度

参数：
* ``degrees``- (sequence or float or int)，若为单个数，如30，则表示在(-30, +30)之间随机旋转；
若为sequence，如(30, 60)，则表示在30-60度之间随机旋转
* ``resample``- 重采样方法选择，可选 PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC，默认为最近邻（较新版本的torchvision中该参数已被``interpolation``参数取代）
* ``expand``- Optional expansion flag. If true, expands the output to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. Note that the expand flag assumes rotation around the center and no translation.
* ``center``- 旋转中心，为可选的(x, y)坐标，以图片左上角为原点。Default is the center of the image.

**需要注意，``transforms.Resize(size)``如果只传入一个int，则会将图片的短边缩放到size并保持长宽比：当height > width时，尺寸为(height, width)的图片会变为(size * height / width, size)。**

In [3]:
input_size = 224
batch_size = 64

# Training images get random augmentation (crop/flip/rotation). Validation
# images are only resized to a fixed square, so evaluation is deterministic.
# (Resize(input_size) followed by CenterCrop(input_size) is an alternative that
# keeps the aspect ratio instead of stretching the image.)
data_transforms = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ToTensor()
    ]),
    "val": transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor()
    ])
}

# One ImageFolder per split, read from flower_dir/train and flower_dir/val.
image_datasets = {x: ImageFolder(os.path.join(flower_dir, x), data_transforms[x]) for x in ["train", "val"]}

# Only the training set is shuffled. Accuracy and the sample-weighted loss over
# the whole validation set do not depend on batch order, so a fixed order keeps
# evaluation reproducible without changing the metrics.
train_loader, test_loader = [
    torch.utils.data.DataLoader(
        image_datasets[x],
        batch_size=batch_size,
        shuffle=(x == "train"),
        num_workers=4,
    )
    for x in ["train", "val"]
]

## 定义一个比较简单的多层卷积神经网络

In [4]:
class Net(nn.Module):
    """Small four-block CNN that maps 3-channel 224x224 images to class logits.

    Each block is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2), which halves
    the spatial size: 224 -> 112 -> 56 -> 28 -> 14. The final 10x14x14 feature
    map is flattened and passed through three fully-connected layers.

    Args:
        num_classes: number of output logits. Defaults to 5, the value the
            network originally hard-coded.
    """

    def __init__(self, num_classes=5):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, padding=1)  # input 224 x 224
        self.conv2 = nn.Conv2d(6, 16, 3, padding=1)
        self.conv3 = nn.Conv2d(16, 24, 3, padding=1)
        self.conv4 = nn.Conv2d(24, 10, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(10 * 14 * 14, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return raw (un-softmaxed) logits of shape (batch, num_classes).

        Assumes ``x`` is (batch, 3, 224, 224); other sizes make fc1 raise.
        """
        x = self.pool(self.relu(self.conv1(x)))  # 112 x 112
        x = self.pool(self.relu(self.conv2(x)))  # 56 x 56
        x = self.pool(self.relu(self.conv3(x)))  # 28 x 28
        x = self.pool(self.relu(self.conv4(x)))  # 14 x 14
        # Flatten every dimension except the batch axis. view(-1, 10*14*14)
        # would silently fold spatial elements into the batch axis for larger
        # inputs (e.g. 448x448 turns 1 sample into 4). flatten keeps the batch
        # size, so a wrong input size fails loudly at fc1 instead.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

In [5]:
from torch import optim

# nn.CrossEntropyLoss applies log-softmax internally, so Net.forward returns raw
# logits and has no softmax layer of its own.
criterion = nn.CrossEntropyLoss()

# Adam over every trainable parameter of the network.
optimizer = optim.Adam(params=net.parameters(), lr=0.001)

In [6]:
for epoch in range(8):  # loop over the dataset multiple times
    # --- training phase ---
    # Re-enter training mode every epoch, because the evaluation phase below
    # switches the model to eval mode. This Net has no Dropout/BatchNorm layers,
    # so today's numbers are unaffected, but the switch keeps the loop correct
    # if such layers are added.
    net.train()
    train_correct = 0
    train_total = 0
    train_loss = 0.
    for inputs, labels in train_loader:
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Sample-weighted accumulation keeps the epoch averages exact even when
        # the last batch is smaller than batch_size.
        _, label_pred = torch.max(outputs, dim=1)
        train_correct += (labels == label_pred).sum().item()
        train_total += labels.shape[0]
        train_loss += loss.item() * labels.shape[0]

    train_loss /= train_total
    train_correct /= train_total  # now the training accuracy

    # print statistics
    print('Train Epoch: %d\nTrain: Loss: %.4f, accuracy: %.4f ' % (epoch, train_loss, train_correct), end='')

    # --- evaluation phase ---
    net.eval()
    test_correct = 0
    test_total = 0
    test_loss = 0.
    with torch.no_grad():
        for images, labels in test_loader:
            y_pred = net(images)
            _, label_pred = torch.max(y_pred, dim=1)
            test_correct += (labels == label_pred).sum().item()
            test_total += labels.shape[0]
            loss_batch = criterion(y_pred, labels)
            test_loss += loss_batch.item() * labels.shape[0]

    test_loss /= test_total
    test_correct /= test_total  # now the test accuracy
    print('Test: Loss: %.4f, accuracy: %.4f' % (test_loss, test_correct))

print('Finished Training')

Train Epoch: 0
Train: Loss: 1.5872, accuracy: 0.2426 Test: Loss: 1.5444, accuracy: 0.2762
Train Epoch: 1
Train: Loss: 1.4645, accuracy: 0.3199 Test: Loss: 1.4020, accuracy: 0.3769
Train Epoch: 2
Train: Loss: 1.3272, accuracy: 0.4157 Test: Loss: 1.2645, accuracy: 0.4259
Train Epoch: 3
Train: Loss: 1.2146, accuracy: 0.4845 Test: Loss: 1.2075, accuracy: 0.5007
Train Epoch: 4
Train: Loss: 1.1525, accuracy: 0.5172 Test: Loss: 1.2079, accuracy: 0.4517
Train Epoch: 5
Train: Loss: 1.1461, accuracy: 0.5227 Test: Loss: 1.1160, accuracy: 0.5388
Train Epoch: 6
Train: Loss: 1.1159, accuracy: 0.5237 Test: Loss: 1.0781, accuracy: 0.5578
Train Epoch: 7
Train: Loss: 1.0967, accuracy: 0.5390 Test: Loss: 1.0513, accuracy: 0.5782
Finished Training


In [7]:
# Final evaluation of the trained network on test_loader: sample-weighted mean
# loss and accuracy.
# Eval mode makes layers such as Dropout/BatchNorm behave deterministically.
# This Net has none today, so results are unchanged, but the call keeps the
# evaluation correct if they are added.
net.eval()
correct = 0
total = 0
loss = 0.
with torch.no_grad():
    for images, labels in test_loader:
        y_pred = net(images)
        _, label_pred = torch.max(y_pred, dim=1)
        correct += (labels == label_pred).sum().item()
        total += labels.shape[0]
        loss_batch = criterion(y_pred, labels)
        # Weight each batch's mean loss by its size so a short last batch
        # does not skew the average.
        loss += loss_batch.item() * labels.shape[0]

    loss /= total
    correct /= total  # now the accuracy
    print('Loss: {}, accuracy: {}'.format(loss, correct))

Loss: 1.0512882404586896, accuracy: 0.5782312925170068
