# [CNNの原型であるLeNet-5を用いて洋服の画像を分類するレシピ](https://axross-recipe.com/recipes/23)

## 環境準備

In [None]:
# Check the GPU: IPython shell escape that runs the `nvidia-smi` command.
!nvidia-smi

In [None]:
# Report which Python interpreter version this notebook runs on.
import platform

print(f"python {platform.python_version()}")

In [None]:
# Install PyTorch 1.6.0 / torchvision 0.7.0 built for CUDA 10.1 (IPython shell
# escape running pip against the PyTorch wheel index).
# NOTE(review): these pinned +cu101 wheels are old and may not exist for the
# Python/CUDA version of the current runtime; confirm, or drop the pin.
!pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html

## データセットの準備

In [None]:
# Load the Fashion-MNIST train and test splits under ./data (download=True),
# converting each image to a tensor with ToTensor.
from torchvision import transforms
from torchvision.datasets import FashionMNIST

# Both splits share the same root, transform and download settings.
fashion_kwargs = dict(transform=transforms.ToTensor(), download=True)
train_dataset = FashionMNIST('data', train=True, **fashion_kwargs)
test_dataset = FashionMNIST('data', train=False, **fashion_kwargs)

In [None]:
# Inspect the dataset object's attributes (vars() returns its __dict__);
# the notebook echoes the result as the cell's last expression.
vars(train_dataset)

In [None]:
# Print the sizes of the stored image tensor and the label tensor.
print(train_dataset.data.size())
print(train_dataset.targets.size())

In [None]:
# Display the stored pixel data of the first training image
# (echoed by the notebook as the cell's last expression).
train_dataset.data[0]

In [None]:
# Render the first training image as a grayscale picture.
import matplotlib.pyplot as plt

plt.imshow(train_dataset.data[0], cmap='gray')
plt.show()

In [None]:
# Show the ground-truth label (a class index) of the first training image.
train_dataset.targets[0]

In [None]:
# List the class names; a label value is an index into this list.
train_dataset.classes

In [None]:
# Map the first training image's label index to its class name.
train_dataset.classes[train_dataset.targets[0]]

In [None]:
# Wrap the datasets in mini-batch loaders of `batch_size` samples.
from torch.utils.data import DataLoader

batch_size = 128
# Shuffle only the training data; it is reshuffled every epoch.
# Evaluation only aggregates counts, so shuffling the test set adds nothing,
# and a fixed order keeps evaluation runs reproducible.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## LeNet-5の訓練

In [None]:
# パッケージのインポート
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# LeNet5の実装
class LeNet5(nn.Module):
    """LeNet-5 style CNN for single-channel 28x28 images.

    Feature sizes for an (N, 1, 28, 28) input:
      conv1 (5x5, padding 2) -> 6x28x28, max-pool 2x2 -> 6x14x14,
      conv2 (5x5, no padding) -> 16x10x10, max-pool 2x2 -> 16x5x5,
      then fully connected layers 400 -> 120 -> 84 -> num_classes.
    ReLU follows every layer except the last.

    Args:
        num_classes: Number of output scores (default 10).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)

        # One parameter-free pooling module is reused after both conv layers.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # 16 * 5 * 5 matches the conv2 + pool output for a 28x28 input.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (un-softmaxed) class scores of shape (N, num_classes)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))

        # Flatten per sample. The previous view(-1, 400) silently regrouped
        # elements into a wrong batch size for non-28x28 inputs. flatten(x, 1)
        # keeps the batch dimension, so fc1 rejects a mismatched feature size.
        x = torch.flatten(x, 1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No activation: nn.CrossEntropyLoss expects raw scores.
        return self.fc3(x)

In [None]:
# Instantiate the classifier and print its layer structure.
net = LeNet5()
print(net)

In [None]:
# Define the loss function (cross-entropy over class scores) and the
# optimizer (Adam, learning rate 1e-3) over the network's parameters.
# NOTE(review): the PyTorch optim docs recommend moving a model to its device
# before constructing its optimizer; in this notebook the move happens in a
# later cell. Confirm this ordering is intended.
loss_fcn = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3)

In [None]:
# Use the GPU when CUDA is available, otherwise the CPU, and move the
# model there.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(device)
net = net.to(device)

In [None]:
# Train the classifier for `epochs` passes over the training set and report
# the mean per-sample loss after each epoch.
from tqdm import tqdm

net.train()
epochs = 100
for epoch in tqdm(range(epochs)):
    epoch_loss = 0.0
    for inputs, labels in train_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = loss_fcn(outputs, labels)
        loss.backward()
        optimizer.step()
        # loss_fcn uses the default 'mean' reduction, so scale by batch size.
        # This keeps the epoch figure a per-sample mean when the last batch
        # is smaller.
        epoch_loss += loss.item() * inputs.size(0)
    # tqdm.write prints above the progress bar instead of garbling it,
    # which a plain print inside a tqdm loop does.
    tqdm.write(f'loss: {epoch_loss / len(train_dataloader.dataset)}')

## 精度の検証

In [None]:
# Overall accuracy (percent) of the trained classifier on the test set.
net.eval()
correct, total = 0, 0
with torch.no_grad():
    for batch_x, batch_y in tqdm(test_dataloader):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        # Predicted class = index of the highest score in each row.
        predicted = net(batch_x).argmax(dim=1)
        total += batch_y.size(0)
        correct += int((predicted == batch_y).sum())
print(f'テストデータのAccuracyは {100 * correct / total} です。')

In [None]:
# Per-class accuracy (percent) of the trained classifier on the test set.
num_classes = len(train_dataset.classes)
net.eval()
# Tally hits and totals per class on the device with bincount. This replaces
# a per-sample Python loop that called .item() (a host sync) for every sample.
corrects = torch.zeros(num_classes, dtype=torch.long, device=device)
totals = torch.zeros(num_classes, dtype=torch.long, device=device)
with torch.no_grad():
    for inputs, labels in tqdm(test_dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        predicts = net(inputs).argmax(dim=1)
        # No .squeeze(): it turned a size-1 final batch into a 0-d tensor
        # that could not be indexed per sample.
        hits = labels[predicts == labels]
        corrects += torch.bincount(hits, minlength=num_classes)
        totals += torch.bincount(labels, minlength=num_classes)
corrects = corrects.tolist()
totals = totals.tolist()
for i, class_name in enumerate(train_dataset.classes):
    if totals[i] == 0:
        # Avoid ZeroDivisionError for a class absent from the test set.
        print(f'{class_name} のテストデータがありません。')
        continue
    print(f'{class_name} のAccuracyは {100 * corrects[i] / totals[i]} です。')