<a href="https://colab.research.google.com/github/Futaba-Kosuke/STL10/blob/develop/testing/main_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Mount Google Drive so files under "My Drive" are reachable. This only works
# inside Colab; elsewhere the import fails, so skip the mount and let the
# notebook run against local files under ROOT.
try:
  from google.colab import drive
except ImportError:
  drive = None

if drive is not None:
  drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
import torch
import torchvision
from torchvision import transforms

In [0]:
# Root directory that contains the `images/` and `models/` folders.
# Paths below are built by string concatenation (e.g. ROOT + 'images'),
# so ROOT must keep its trailing slash.
# Set USE_DRIVE = True to read the copy stored on the mounted Google Drive
# instead of the current working directory.
USE_DRIVE = False
ROOT = './drive/My Drive/STL10/' if USE_DRIVE else './'

In [0]:
# Choose the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU. The bare expression displays the choice.
if torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")
device

device(type='cuda', index=0)

In [0]:
batch_size = 10  # number of images fed to the models per forward pass

In [0]:
# Load the evaluation images.
# Preprocessing: resize every image to 300x300, convert it to a tensor, then
# normalize each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
  transforms.Resize((300, 300)),
  transforms.ToTensor(),
  transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# ImageFolder reads one sub-directory per class under ROOT + 'images'.
data_set = torchvision.datasets.ImageFolder(root=ROOT + 'images', transform=transform)
print('data_size: ', len(data_set))

# Fixed order (shuffle=False) so repeated runs visit images identically.
data_loader = torch.utils.data.DataLoader(
  data_set, batch_size=batch_size, shuffle=False, num_workers=1
)
print('Finish load images')

# Display names, indexed by the integer labels produced by data_set.
# NOTE(review): ImageFolder assigns label indices from the sorted sub-directory
# names, so this tuple is only correct if those folders sort into this same
# order — confirm against data_set.classes.
classes = ('airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck')

data_size:  1000
Finish load images


In [0]:
# Load the models.
# Several fine-tuned networks are loaded so that their outputs can be combined
# as an ensemble in the evaluation cell.
from torch import nn
from torchvision import models


def _new_head(old_head, num_classes=10):
  """Return a fresh Linear layer with `old_head`'s input width and `num_classes` outputs.

  The replacement must have the same shape as the head the checkpoint was
  trained with, otherwise load_state_dict below raises on the size mismatch.
  """
  return nn.Linear(old_head.in_features, num_classes)


def _load_weights(net, file_name):
  """Move `net` to `device`, load the state dict at ROOT + 'models/' + file_name, return `net`.

  NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
  """
  net = net.to(device)
  state_dict = torch.load(ROOT + 'models/' + file_name, map_location=device)
  net.load_state_dict(state_dict)
  return net


vgg16 = models.vgg16()
vgg16.classifier[6] = _new_head(vgg16.classifier[6])
vgg16 = _load_weights(vgg16, 'vgg16.pth')

densenet = models.densenet161()
densenet.classifier = _new_head(densenet.classifier)
densenet = _load_weights(densenet, 'densenet.pth')

wide_resnet = models.wide_resnet50_2()
wide_resnet.fc = _new_head(wide_resnet.fc)
wide_resnet = _load_weights(wide_resnet, 'wide_resnet.pth')

nets = {
    'vgg16': vgg16,
    'densenet': densenet,
    'wide_resnet': wide_resnet
}
net_names = ('vgg16', 'densenet', 'wide_resnet')

print('Finish load models')

Finish load models


In [0]:
# Run inference with the ensemble and tally per-class accuracy.
for net_name in net_names:
  nets[net_name].eval()  # inference mode (affects dropout / batch-norm layers)

num_classes = len(classes)
class_correct = [0.] * num_classes
class_total = [0.] * num_classes

with torch.no_grad():
  for (inputs, labels) in data_loader:
    inputs = inputs.to(device)
    labels = labels.to(device)

    # Ensemble: sum the raw outputs of every model. Summing the outputs
    # directly (instead of adding into a zeros(batch_size, 10) buffer) stays
    # correct when the final batch is smaller than batch_size.
    outputs = sum(nets[net_name](inputs) for net_name in net_names)

    _, predicted = torch.max(outputs, 1)
    # One flag per image. No .squeeze(): for a one-image batch it would give a
    # 0-d tensor that cannot be indexed per sample.
    is_correct = predicted == labels

    # Convert to Python once per batch instead of one device sync per element.
    for label, correct in zip(labels.tolist(), is_correct.tolist()):
      class_correct[label] += correct
      class_total[label] += 1

# Print the accuracy.
for i in range(num_classes):
  if class_total[i] == 0:
    # A class with no images would otherwise raise ZeroDivisionError.
    print('Accuracy of %5s : n/a (no samples)' % classes[i])
  else:
    print('Accuracy of %5s : %02.01f %%' % (
      classes[i], 100 * class_correct[i] / class_total[i]))

total = sum(class_total)
if total:
  print('\nTop-1 Accuracy: %02.01f %%' % (100 * sum(class_correct) / total))
else:
  print('\nTop-1 Accuracy: n/a (no samples)')

Accuracy of airplane : 100.0 %
Accuracy of  bird : 100.0 %
Accuracy of   car : 100.0 %
Accuracy of   cat : 100.0 %
Accuracy of  deer : 100.0 %
Accuracy of   dog : 100.0 %
Accuracy of horse : 100.0 %
Accuracy of monkey : 100.0 %
Accuracy of  ship : 100.0 %
Accuracy of truck : 100.0 %

Top-1 Accuracy: 100.0 %
