<a href="https://colab.research.google.com/github/Futaba-Kosuke/STL10/blob/feature%2Ftrain/training/main_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# STL10の画像分類

## ○手法
複数のモデルでFine-tuningを行い、アンサンブル学習を行う。  
また、各モデルの正解率とアンサンブル学習した場合の認識精度を比較する。

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
 !ls

drive  images  sample_data


In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image

# バージョン確認 (Google Colab default)
print(torch.__version__)  # 1.4.0
print(torchvision.__version__)  # 0.5.0
print(np.__version__)  # 1.18.2
print(matplotlib.__version__)  # 3.2.1
print(Image.__version__)  # 7.0.0

1.4.0
0.5.0
1.18.2
3.2.1
7.0.0


In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [0]:
batch_size = 25
num_workers = 3

In [6]:
# データの読み込み
transform = {
  'train': transforms.Compose(
    [transforms.RandomHorizontalFlip(p=0.5),
     transforms.ToTensor(),
     transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3)),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
  ),
  'test': transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
  )
}
data_set = {
  x: torchvision.datasets.STL10(root='./images', split=x, download=True, transform=transform[x])
  for x in ('train', 'test')
}
data_size = {
  x: len(data_set[x]) for x in ('train', 'test')
}
print(data_size)

data_loaders = {
  x[0]: torch.utils.data.DataLoader(data_set[x[0]], batch_size=x[1], shuffle=x[2], num_workers=num_workers)
  for x in (('train', batch_size, True), ('test', 100, False))
}

Files already downloaded and verified
Files already downloaded and verified
{'train': 5000, 'test': 8000}


In [0]:
classes = ('airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck')

In [8]:
"""
各モデルで10エポック学習して検証
resnet: 90 %
alexnet: 79 %
vgg16: 93 %
densenet: 94 %
googlenet: 90 %
shufflenet: 62 %
mobilenet: 89 %
resnext: 93 %
wide_resnet: 94 %
mnasnet: 85 %

vgg16, densenet, wide_resnetを一旦採用、各々学習させてモデルを出力させた。
"""

# モデル構築
from torchvision import models
from torch import nn

vgg16 = models.vgg16(pretrained=True)
last_in_features = vgg16.classifier[6].in_features
vgg16.classifier[6] = nn.Linear(last_in_features, 10)
vgg16 = vgg16.to(device)

densenet = models.densenet161(pretrained=True)
last_in_features = densenet.classifier.in_features
densenet.classifier = nn.Linear(last_in_features, 10)
densenet = densenet.to(device)

wide_resnet = models.wide_resnet50_2(pretrained=True)
last_in_features = wide_resnet.fc.in_features
wide_resnet.fc = nn.Linear(last_in_features, 10)
wide_resnet = wide_resnet.to(device)

nets = {
    'vgg16': vgg16,
    'densenet': densenet,
    'wide_resnet': wide_resnet
}

nets

{'densenet': DenseNet(
   (features): Sequential(
     (conv0): Conv2d(3, 96, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
     (norm0): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu0): ReLU(inplace=True)
     (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
     (denseblock1): _DenseBlock(
       (denselayer1): _DenseLayer(
         (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (relu1): ReLU(inplace=True)
         (conv1): Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
         (norm2): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (relu2): ReLU(inplace=True)
         (conv2): Conv2d(192, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       )
       (denselayer2): _DenseLayer(
         (norm1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_run

In [0]:
# 損失・最適関数の定義
from torch import optim

net_names = ('vgg16', 'densenet', 'wide_resnet')

criterion = {
    i: nn.CrossEntropyLoss()
    for i in net_names
}
optimizer = {
    i: optim.SGD(nets[i].parameters(), lr=0.001, momentum=0.9)
    for i in net_names
}
scheduler = {
    i: optim.lr_scheduler.StepLR(optimizer[i], step_size=10, gamma=0.5)  # 10エポックごとに学習率が1/10に更新
    for i in net_names
}

In [0]:
# 学習用関数
def train_net(net, criterion, optimizer, scheduler, num_epochs):
  for epoch in range(num_epochs):
    print('Epoch %d/%d' % (epoch, num_epochs - 1))
    print('-' * 10)

    running_loss_sum = 0
    train_loss_sum = 0
    test_loss_sum = 0

    # モデルの更新
    net.train()
    for i, (inputs, labels) in enumerate(data_loaders['train']):
      inputs = inputs.to(device)
      labels = labels.to(device)
      # 勾配の初期化
      optimizer.zero_grad()
      # 予測
      outputs = net(inputs)
      # 損失の導出
      loss = criterion(outputs, labels)
      # 逆伝播
      loss.backward()
      # 勾配の更新
      optimizer.step()

      running_loss_sum += loss.item()
      train_loss_sum += loss.item()
      if i % 50 == 49:
        print('[%d] running_loss: %.6f' % (i + 1, running_loss_sum / 50))
        running_loss_sum = 0

    train_cnt = i

    # モデルの評価
    net.eval()
    cnt_correct = 0
    for i, (inputs, labels) in enumerate(data_loaders['test']):
      inputs = inputs.to(device)
      labels = labels.to(device)

      # 予測
      outputs = net(inputs)
      # 損失の導出
      loss = criterion(outputs, labels)
      # lossの加算
      test_loss_sum += loss.item()

      _, predicted = torch.max(outputs, 1)
      is_correct = (predicted == labels).squeeze()
      
      for j in range(len(is_correct)):
        cnt_correct += is_correct[j].item()  # 正解なら1, 不正解なら0

    test_cnt = i

    print('')
    print('train_loss_ave:\t%.6f' % (train_loss_sum / train_cnt))
    print('test_loss_ave:\t%.6f' % (test_loss_sum / test_cnt))
    print('test_accuracy: \t%.6f %%' % (cnt_correct / (test_cnt)))

    net.train()
    scheduler.step()

    if cnt_correct / (test_cnt) > 96:
      break

In [11]:
num_epochs = 100
net_name = 'vgg16'
train_net(nets[net_name], criterion[net_name], optimizer[net_name], scheduler[net_name], num_epochs)

Epoch 0/99
----------
[50] running_loss: 1.096508
[100] running_loss: 0.473394
[150] running_loss: 0.539791
[200] running_loss: 0.488359

train_loss_ave:	0.652777
test_loss_ave:	0.295101
test_accuracy: 	91.379747 %
Epoch 1/99
----------
[50] running_loss: 0.364733
[100] running_loss: 0.360205
[150] running_loss: 0.310101
[200] running_loss: 0.351651

train_loss_ave:	0.348415
test_loss_ave:	0.285618
test_accuracy: 	91.784810 %
Epoch 2/99
----------
[50] running_loss: 0.209470
[100] running_loss: 0.259208
[150] running_loss: 0.266822
[200] running_loss: 0.223910

train_loss_ave:	0.241058
test_loss_ave:	0.259650
test_accuracy: 	92.873418 %
Epoch 3/99
----------
[50] running_loss: 0.224871
[100] running_loss: 0.201740
[150] running_loss: 0.219978
[200] running_loss: 0.228263

train_loss_ave:	0.219812
test_loss_ave:	0.226928
test_accuracy: 	93.898734 %
Epoch 4/99
----------
[50] running_loss: 0.159555
[100] running_loss: 0.155166
[150] running_loss: 0.211822
[200] running_loss: 0.174620

tr

In [12]:
num_epochs = 100
net_name = 'densenet'
train_net(nets[net_name], criterion[net_name], optimizer[net_name], scheduler[net_name], num_epochs)

Epoch 0/99
----------
[50] running_loss: 1.498270
[100] running_loss: 0.594414
[150] running_loss: 0.488722
[200] running_loss: 0.415724

train_loss_ave:	0.753048
test_loss_ave:	0.225588
test_accuracy: 	93.810127 %
Epoch 1/99
----------
[50] running_loss: 0.336897
[100] running_loss: 0.284185
[150] running_loss: 0.290808
[200] running_loss: 0.292722

train_loss_ave:	0.302666
test_loss_ave:	0.186107
test_accuracy: 	95.075949 %
Epoch 2/99
----------
[50] running_loss: 0.207897
[100] running_loss: 0.173028
[150] running_loss: 0.177993
[200] running_loss: 0.210680

train_loss_ave:	0.193366
test_loss_ave:	0.180358
test_accuracy: 	95.303797 %
Epoch 3/99
----------
[50] running_loss: 0.169465
[100] running_loss: 0.130233
[150] running_loss: 0.164517
[200] running_loss: 0.152371

train_loss_ave:	0.154921
test_loss_ave:	0.182660
test_accuracy: 	95.227848 %
Epoch 4/99
----------
[50] running_loss: 0.119960
[100] running_loss: 0.111904
[150] running_loss: 0.135021
[200] running_loss: 0.126202

tr

In [13]:
num_epochs = 100
net_name = 'wide_resnet'
train_net(nets[net_name], criterion[net_name], optimizer[net_name], scheduler[net_name], num_epochs)

Epoch 0/99
----------
[50] running_loss: 1.820399
[100] running_loss: 0.844741
[150] running_loss: 0.614791
[200] running_loss: 0.563531

train_loss_ave:	0.965694
test_loss_ave:	0.286087
test_accuracy: 	91.949367 %
Epoch 1/99
----------
[50] running_loss: 0.402876
[100] running_loss: 0.400832
[150] running_loss: 0.389181
[200] running_loss: 0.374166

train_loss_ave:	0.393732
test_loss_ave:	0.226656
test_accuracy: 	94.025316 %
Epoch 2/99
----------
[50] running_loss: 0.293318
[100] running_loss: 0.275283
[150] running_loss: 0.224843
[200] running_loss: 0.243830

train_loss_ave:	0.260622
test_loss_ave:	0.199234
test_accuracy: 	94.974684 %
Epoch 3/99
----------
[50] running_loss: 0.209196
[100] running_loss: 0.190019
[150] running_loss: 0.185692
[200] running_loss: 0.180369

train_loss_ave:	0.192280
test_loss_ave:	0.206805
test_accuracy: 	94.468354 %
Epoch 4/99
----------
[50] running_loss: 0.134862
[100] running_loss: 0.136458
[150] running_loss: 0.150713
[200] running_loss: 0.158662

tr

In [20]:
# 各モデルでそれぞれ予測した場合（test）
for net_name in net_names:
  nets[net_name].eval()

  class_correct = list(0. for i in range(10))
  class_total = list(0. for i in range(10))

  with torch.no_grad():
    for (inputs, labels) in data_loaders['test']:
      inputs = inputs.to(device)
      labels = labels.to(device)

      outputs = nets[net_name](inputs)
      
      _, predicted = torch.max(outputs, 1)
      is_correct = (predicted == labels).squeeze()

      for i in range(len(is_correct)):
        label = labels[i]
        class_correct[label] += is_correct[i].item()  # 正解なら1, 不正解なら0
        class_total[label] += 1

  print(net_name)

  for i in range(10):
    print('Accuracy of %5s : %2.01f %%' % (
      classes[i], 100 * class_correct[i] / class_total[i]))
    
  print('Accuracy Ave: %2.01f %% \n' % (100 * sum(class_correct) / sum(class_total)))

vgg16
Accuracy of airplane : 96.2 %
Accuracy of  bird : 95.6 %
Accuracy of   car : 96.2 %
Accuracy of   cat : 90.6 %
Accuracy of  deer : 93.8 %
Accuracy of   dog : 91.8 %
Accuracy of horse : 94.5 %
Accuracy of monkey : 96.0 %
Accuracy of  ship : 98.8 %
Accuracy of truck : 95.5 %
Accuracy Ave: 94.9 % 

densenet
Accuracy of airplane : 96.1 %
Accuracy of  bird : 97.5 %
Accuracy of   car : 97.6 %
Accuracy of   cat : 91.4 %
Accuracy of  deer : 93.6 %
Accuracy of   dog : 91.6 %
Accuracy of horse : 94.5 %
Accuracy of monkey : 95.4 %
Accuracy of  ship : 98.6 %
Accuracy of truck : 92.2 %
Accuracy Ave: 94.9 % 

wide_resnet
Accuracy of airplane : 97.5 %
Accuracy of  bird : 95.5 %
Accuracy of   car : 97.1 %
Accuracy of   cat : 88.4 %
Accuracy of  deer : 96.2 %
Accuracy of   dog : 92.2 %
Accuracy of horse : 93.2 %
Accuracy of monkey : 96.2 %
Accuracy of  ship : 98.4 %
Accuracy of truck : 94.0 %
Accuracy Ave: 94.9 % 



In [21]:
# 各モデルでそれぞれ予測した場合（train）
for net_name in net_names:
  nets[net_name].eval()

  class_correct = list(0. for i in range(10))
  class_total = list(0. for i in range(10))

  with torch.no_grad():
    for (inputs, labels) in data_loaders['train']:
      inputs = inputs.to(device)
      labels = labels.to(device)

      outputs = nets[net_name](inputs)
      
      _, predicted = torch.max(outputs, 1)
      is_correct = (predicted == labels).squeeze()

      for i in range(len(is_correct)):
        label = labels[i]
        class_correct[label] += is_correct[i].item()  # 正解なら1, 不正解なら0
        class_total[label] += 1

  print(net_name)

  for i in range(10):
    print('Accuracy of %5s : %2.01f %%' % (
      classes[i], 100 * class_correct[i] / class_total[i]))
    
  print('Accuracy Ave: %2.01f %% \n' % (100 * sum(class_correct) / sum(class_total)))

vgg16
Accuracy of airplane : 99.8 %
Accuracy of  bird : 99.4 %
Accuracy of   car : 99.8 %
Accuracy of   cat : 98.6 %
Accuracy of  deer : 98.6 %
Accuracy of   dog : 98.2 %
Accuracy of horse : 99.6 %
Accuracy of monkey : 99.6 %
Accuracy of  ship : 100.0 %
Accuracy of truck : 99.8 %
Accuracy Ave: 99.3 % 

densenet
Accuracy of airplane : 99.8 %
Accuracy of  bird : 99.2 %
Accuracy of   car : 100.0 %
Accuracy of   cat : 98.4 %
Accuracy of  deer : 98.8 %
Accuracy of   dog : 97.6 %
Accuracy of horse : 99.4 %
Accuracy of monkey : 98.4 %
Accuracy of  ship : 100.0 %
Accuracy of truck : 99.2 %
Accuracy Ave: 99.1 % 

wide_resnet
Accuracy of airplane : 100.0 %
Accuracy of  bird : 99.8 %
Accuracy of   car : 99.8 %
Accuracy of   cat : 99.8 %
Accuracy of  deer : 99.8 %
Accuracy of   dog : 100.0 %
Accuracy of horse : 100.0 %
Accuracy of monkey : 100.0 %
Accuracy of  ship : 100.0 %
Accuracy of truck : 100.0 %
Accuracy Ave: 99.9 % 



In [37]:
# アンサンブル学習した場合
for net_name in net_names:
  nets[net_name].eval()

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

with torch.no_grad():
  for (inputs, labels) in data_loaders['test']:
    inputs = inputs.to(device)
    labels = labels.to(device)

    outputs = torch.zeros([100, 10])
    outputs = outputs.to(device)
    # アンサンブル学習
    for net_name in net_names:
      outputs += nets[net_name](inputs)

    _, predicted = torch.max(outputs, 1)
    
    is_correct = (predicted == labels).squeeze()

    for i in range(len(is_correct)):
      label = labels[i]
      class_correct[label] += is_correct[i].item()  # 正解なら1, 不正解なら0
      class_total[label] += 1

for i in range(10):
  print('Accuracy of %5s : %2.01f %%' % (
    classes[i], 100 * class_correct[i] / class_total[i]))
  
print('Accuracy Ave: %2.01f %%' % (100 * sum(class_correct) / sum(class_total)))

Accuracy of airplane : 97.5 %
Accuracy of  bird : 97.5 %
Accuracy of   car : 97.4 %
Accuracy of   cat : 93.0 %
Accuracy of  deer : 95.9 %
Accuracy of   dog : 94.5 %
Accuracy of horse : 95.6 %
Accuracy of monkey : 98.0 %
Accuracy of  ship : 99.2 %
Accuracy of truck : 95.8 %
Accuracy Ave: 96.4 %


In [0]:
# 保存
for net_name in net_names:
  PATH = './drive/My Drive/' + net_name + '.pth'
  torch.save(nets[net_name].state_dict(), PATH)

## ○結果
各モデルのテストデータでのTop-1 Accuracyは94.9%。  
アンサンブル学習を行った場合のTop-1 Accuracyは96.4%となった。