In [1]:
from pathlib import Path
from collections import namedtuple

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

## ハイパーパラメータの設定

In [3]:
class Hparams():
    batch_size = 16
    input_size = 784
    output_size = 10
    learning_rate = 0.01
    epoch_num = 5

hparams = Hparams()

## データの用意

In [5]:
transform = transforms.Compose(
    [transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 実際はnum_workersは1でいい.
trainset = datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=hparams.batch_size,
                                         shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100.0%

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


## ネットワークの設定

In [13]:
class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        # CNNによる特徴量作成
        self.features = self.make_feature()
        # 全結合による分類
        self.classifier = self.make_classifier(num_classes)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
    
    def make_feature(self):
        return nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(5),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(5),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
    
    def make_classifier(self, num_classes):
        return nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

## 学習の設定

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AlexNet(10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=hparams.learning_rate)

In [15]:
for epoch in range(hparams.epoch_num):
    model.train()
    for idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
          epoch, epoch , hparams.epoch_num,
          100. * epoch / hparams.epoch_num, loss.item()))

    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss += criterion(output, target).item() 
            pred = output.argmax(dim=1, keepdim=True) # 予測の最大値を取得
            correct += pred.eq(target.view_as(pred)).sum().item()
        loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


Test set: Average loss: 0.1106, Accuracy: 3306/10000 (33%)


Test set: Average loss: 0.0978, Accuracy: 4190/10000 (42%)


Test set: Average loss: 0.0908, Accuracy: 4684/10000 (47%)


Test set: Average loss: 0.0809, Accuracy: 5423/10000 (54%)



KeyboardInterrupt: 