In [15]:
import matplotlib.pyplot as plt
%matplotlib inline
from IPython import display
import time
import os

# Process-wide environment overrides, applied in one step.
# They are set in this first cell, ahead of the cells that import torch,
# so the device restriction is in place before CUDA is touched.
_PROXY_URL = "http://proxy.uec.ac.jp:8080/"
os.environ.update({
    # Route both HTTP and HTTPS traffic (e.g. the dataset download) via the proxy.
    "http_proxy": _PROXY_URL,
    "https_proxy": _PROXY_URL,
    # Expose only GPUs 0 and 1 to this process.
    "CUDA_VISIBLE_DEVICES": "0,1",
})

In [16]:
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

# Mini-batch size shared by the training and the test loader.
batch_size = 128

# Convert images to tensors, then normalise each of the three colour channels
# with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Use the `DataLoader` name imported at the top of this cell. The previous
# `torch.utils.data.DataLoader` spelling relied on `torch` having been imported
# by a *later* cell, which raises NameError on a fresh kernel (Restart & Run All).
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# The test split is not shuffled so evaluation order is stable between epochs.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

# Human-readable class names; presumably listed in CIFAR-10 label-index order (TODO confirm).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [17]:
import torch
from vit_pytorch import ViT
# Primary device: the first visible GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Small Vision Transformer configured for 32x32 inputs and 10 output classes.
net = ViT(
    image_size=32,
    patch_size=4,
    num_classes=10,
    dim=256,
    depth=3,
    heads=4,
    mlp_dim=256,
    dropout=0.1,
    emb_dropout=0.1
).to(device)

# Replicate the model over every CUDA device visible to this process instead of
# hard-coding [0, 1]: with a single visible GPU the fixed list makes
# DataParallel fail on the missing device 1. With two visible GPUs the list is
# still [0, 1]. On a CPU-only machine the list is empty, so `None` is passed and
# DataParallel falls back to its default device handling.
gpu_ids = list(range(torch.cuda.device_count()))
net = torch.nn.DataParallel(net, device_ids=gpu_ids or None)

In [18]:
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
epochs = 20


def _run_epoch(model, loader, loss_fn, device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given the model is put in training mode and its
    parameters are updated after every batch; otherwise the model is put in
    evaluation mode and gradient tracking is disabled.

    Both metrics are averaged per *sample*, not per batch, so a shorter final
    batch is not over-weighted (the dataset sizes used here are not multiples
    of the batch size).

    Args:
        model: Module mapping a batch of inputs to class logits.
        loader: Iterable yielding batches whose first two items are
            ``(inputs, labels)``.
        loss_fn: Criterion returning the mean loss over a batch.
        device: Device the batches are moved to before the forward pass.
        optimizer: Optimizer for a training pass, or ``None`` to evaluate.

    Returns:
        Tuple ``(mean_loss, accuracy)`` of Python floats; ``(0.0, 0.0)`` when
        the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()

    loss_sum = 0.0
    num_correct = 0
    num_samples = 0
    with torch.set_grad_enabled(training):
        for data in loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_len = labels.size(0)
            # loss is the batch mean, so scale it back to a per-batch sum.
            loss_sum += loss.item() * batch_len
            # .item() keeps the running totals as plain Python numbers.
            num_correct += (outputs.argmax(dim=1) == labels).sum().item()
            num_samples += batch_len

    denom = max(num_samples, 1)  # guard against an empty loader
    return loss_sum / denom, num_correct / denom


for epoch in tqdm(range(0, epochs)):
    epoch_train_loss, epoch_train_acc = _run_epoch(net, train_loader, criterion, device, optimizer)
    epoch_test_loss, epoch_test_acc = _run_epoch(net, test_loader, criterion, device)
    print(f'Epoch {epoch+1} : train acc. {epoch_train_acc:.2f} train loss {epoch_train_loss:.2f}')
    print(f'Epoch {epoch+1} : test acc. {epoch_test_acc:.2f} test loss {epoch_test_loss:.2f}')

  5%|▌         | 1/20 [00:09<02:56,  9.31s/it]

Epoch 1 : train acc. 0.15 train loss 2.26
Epoch 1 : test acc. 0.19 test loss 2.17


 10%|█         | 2/20 [00:17<02:41,  8.95s/it]

Epoch 2 : train acc. 0.19 train loss 2.15
Epoch 2 : test acc. 0.23 test loss 2.09


 15%|█▌        | 3/20 [00:25<02:27,  8.68s/it]

Epoch 3 : train acc. 0.21 train loss 2.10
Epoch 3 : test acc. 0.25 test loss 2.06


 20%|██        | 4/20 [00:33<02:16,  8.51s/it]

Epoch 4 : train acc. 0.23 train loss 2.07
Epoch 4 : test acc. 0.26 test loss 2.02


 25%|██▌       | 5/20 [00:41<02:06,  8.41s/it]

Epoch 5 : train acc. 0.26 train loss 2.01
Epoch 5 : test acc. 0.28 test loss 1.95


 30%|███       | 6/20 [00:49<01:55,  8.28s/it]

Epoch 6 : train acc. 0.28 train loss 1.97
Epoch 6 : test acc. 0.30 test loss 1.93


 35%|███▌      | 7/20 [00:57<01:46,  8.18s/it]

Epoch 7 : train acc. 0.28 train loss 1.95
Epoch 7 : test acc. 0.31 test loss 1.90


 40%|████      | 8/20 [01:05<01:38,  8.17s/it]

Epoch 8 : train acc. 0.29 train loss 1.93
Epoch 8 : test acc. 0.31 test loss 1.88


 45%|████▌     | 9/20 [01:13<01:29,  8.16s/it]

Epoch 9 : train acc. 0.30 train loss 1.91
Epoch 9 : test acc. 0.32 test loss 1.87


 50%|█████     | 10/20 [01:21<01:21,  8.11s/it]

Epoch 10 : train acc. 0.31 train loss 1.88
Epoch 10 : test acc. 0.33 test loss 1.84


 55%|█████▌    | 11/20 [01:30<01:12,  8.09s/it]

Epoch 11 : train acc. 0.32 train loss 1.86
Epoch 11 : test acc. 0.35 test loss 1.80


 60%|██████    | 12/20 [01:38<01:04,  8.06s/it]

Epoch 12 : train acc. 0.34 train loss 1.83
Epoch 12 : test acc. 0.36 test loss 1.77


 65%|██████▌   | 13/20 [01:46<00:56,  8.06s/it]

Epoch 13 : train acc. 0.35 train loss 1.81
Epoch 13 : test acc. 0.37 test loss 1.76


 70%|███████   | 14/20 [01:54<00:48,  8.12s/it]

Epoch 14 : train acc. 0.35 train loss 1.79
Epoch 14 : test acc. 0.37 test loss 1.74


 75%|███████▌  | 15/20 [02:02<00:40,  8.10s/it]

Epoch 15 : train acc. 0.36 train loss 1.77
Epoch 15 : test acc. 0.38 test loss 1.71


 80%|████████  | 16/20 [02:10<00:32,  8.08s/it]

Epoch 16 : train acc. 0.37 train loss 1.75
Epoch 16 : test acc. 0.39 test loss 1.70


 85%|████████▌ | 17/20 [02:18<00:24,  8.14s/it]

Epoch 17 : train acc. 0.38 train loss 1.73
Epoch 17 : test acc. 0.40 test loss 1.68


 90%|█████████ | 18/20 [02:26<00:16,  8.13s/it]

Epoch 18 : train acc. 0.39 train loss 1.71
Epoch 18 : test acc. 0.41 test loss 1.66


 95%|█████████▌| 19/20 [02:34<00:08,  8.13s/it]

Epoch 19 : train acc. 0.39 train loss 1.70
Epoch 19 : test acc. 0.41 test loss 1.64


100%|██████████| 20/20 [02:42<00:00,  8.09s/it]

Epoch 20 : train acc. 0.40 train loss 1.68
Epoch 20 : test acc. 0.43 test loss 1.63





### 結果
- Epoch 20 : train acc. 0.40 train loss 1.68
- Epoch 20 : test acc. 0.43 test loss 1.63
- CNN等と比較すると，単純なViTをCIFAR-10に適用した場合の精度は低い．