<a href="https://colab.research.google.com/github/Tkht44/Vit_vs_CNN/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# トレーニングデータのインポート & チェック

In [48]:
# import needed file
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn


In [49]:
# Load CIFAR-10 (downloaded to ./data on first run).
# The pipeline is named `transform`, not `transforms`: assigning to `transforms` would
# shadow the torchvision.transforms module, so re-running this cell would fail with
# AttributeError on `transforms.Compose`.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # Map each RGB channel from [0, 1] to [-1, 1]: x_norm = (x - 0.5) / 0.5
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Hyperparameters
batch_size = 100
epochs = 10

# Training data (reshuffled every epoch)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# Test data: evaluation metrics do not depend on order, so keep it fixed for reproducibility
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)


In [None]:
# Training set: 50000 RGB images, 32x32 pixels each
print(trainset.data.shape)
# -> (50000, 32, 32, 3)

# Test set: 10000 RGB images, 32x32 pixels each
print(testset.data.shape)
# -> (10000, 32, 32, 3)

# The class names are used repeatedly below, so keep them in their own variable.
classes = trainset.classes
print(classes)
# -> ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

In [None]:
def imshow(img):
    """Display a normalised (C, H, W) image tensor with matplotlib.

    Args:
        img: image tensor normalised with per-channel mean 0.5 / std 0.5
            (the Normalize used for the CIFAR-10 loaders), e.g. the output of
            ``torchvision.utils.make_grid``. CPU, GPU, and grad-tracking
            tensors are all accepted.
    """
    # Undo Normalize(0.5, 0.5): x_norm = (x - 0.5) / 0.5  =>  x = x_norm / 2 + 0.5
    img = img / 2 + 0.5
    # matplotlib needs a CPU numpy array; detach/cpu make this safe for any tensor.
    npimg = img.detach().cpu().numpy()
    # torch stores images as (C, H, W); matplotlib expects (H, W, C).
    npimg = np.transpose(npimg, (1, 2, 0))
    # Clip float round-off so imshow does not warn about values slightly outside [0, 1].
    plt.imshow(np.clip(npimg, 0.0, 1.0))
    plt.show()

# Preview a few training images together with their labels.
# make_grid lays out 8 images per row by default, so 8 images fill exactly one row.
n_show = 8
dataiter = iter(trainloader)
images, labels = next(dataiter)
n_show = min(n_show, len(labels))
imshow(torchvision.utils.make_grid(images[:n_show]))
# Print exactly one label per displayed image, in grid order.
print(' '.join('%5s' % classes[labels[j]] for j in range(n_show)))

In [None]:
# 別ファイルのimport
import cnn
import vit

In [None]:
# Use the first GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Seed torch's RNG so weight initialisation (and later DataLoader shuffling) is
# reproducible; otherwise the CNN-vs-ViT comparison can change from run to run.
SEED = 0
torch.manual_seed(SEED)

# Model hyperparameters
cnn_net = cnn.Net(
    class_num=10,
    conv_size=[3,6,16],
    linear_size=[120,84]
)
vit_net = vit.Net(
    image_size=32, # height = width = image_size
    patch_size=4,
    n_classes=10,
    dim=256,
    depth=3,
    n_heads=4,
    mlp_dim=256,
)

cnn_net.to(device)
vit_net.to(device)

# Loss function shared by both models
criterion = nn.CrossEntropyLoss()
# Optimizers for BOTH models, with identical settings so the comparison is like-for-like.
LEARNING_RATE = 0.001
MOMENTUM = 0.9
optimizer_cnn = torch.optim.SGD(cnn_net.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
optimizer_vit = torch.optim.SGD(vit_net.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)


In [None]:
# Show the layer structure of both models (CNN first, then ViT).
for model in (cnn_net, vit_net):
    print(model)

# ViTでの学習

In [None]:
# Train the ViT, evaluating on the test set after every epoch.
for epoch in range(epochs):
    # eval() below persists across epochs, so switch back to training mode each epoch.
    vit_net.train()
    train_loss_sum, train_correct, train_seen = 0.0, 0, 0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer_vit.zero_grad()
        outputs = vit_net(inputs)
        loss = criterion(outputs, labels)

        # Backpropagation
        loss.backward()
        optimizer_vit.step()

        # Accumulate sample-weighted sums so a short final batch does not skew the averages
        # (criterion is a default, mean-reduced CrossEntropyLoss, so loss * n is the batch sum).
        batch_n = labels.size(0)
        train_loss_sum += loss.item() * batch_n
        train_correct += (outputs.argmax(dim=1) == labels).sum().item()
        train_seen += batch_n
    train_loss = train_loss_sum / max(train_seen, 1)
    train_acc = train_correct / max(train_seen, 1)

    vit_net.eval()
    test_loss_sum, test_correct, test_seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = vit_net(inputs)
            batch_n = labels.size(0)
            test_loss_sum += criterion(outputs, labels).item() * batch_n
            test_correct += (outputs.argmax(dim=1) == labels).sum().item()
            test_seen += batch_n
    test_loss = test_loss_sum / max(test_seen, 1)
    test_acc = test_correct / max(test_seen, 1)

    print(f'Epoch {epoch+1} : train acc. {train_acc:.2f} train loss {train_loss:.2f}')
    print(f'Epoch {epoch+1} : test acc. {test_acc:.2f} test loss {test_loss:.2f}')
print('Finished Training')

# CNNでの学習

In [None]:
# Train the CNN, evaluating on the test set after every epoch.
for epoch in range(epochs):
    # eval() below persists across epochs, so switch back to training mode each epoch.
    cnn_net.train()
    train_loss_sum, train_correct, train_seen = 0.0, 0, 0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer_cnn.zero_grad()
        outputs = cnn_net(inputs)
        loss = criterion(outputs, labels)

        # Backpropagation
        loss.backward()
        optimizer_cnn.step()

        # Accumulate sample-weighted sums so a short final batch does not skew the averages
        # (criterion is a default, mean-reduced CrossEntropyLoss, so loss * n is the batch sum).
        batch_n = labels.size(0)
        train_loss_sum += loss.item() * batch_n
        train_correct += (outputs.argmax(dim=1) == labels).sum().item()
        train_seen += batch_n
    train_loss = train_loss_sum / max(train_seen, 1)
    train_acc = train_correct / max(train_seen, 1)

    cnn_net.eval()
    test_loss_sum, test_correct, test_seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = cnn_net(inputs)
            batch_n = labels.size(0)
            test_loss_sum += criterion(outputs, labels).item() * batch_n
            test_correct += (outputs.argmax(dim=1) == labels).sum().item()
            test_seen += batch_n
    test_loss = test_loss_sum / max(test_seen, 1)
    test_acc = test_correct / max(test_seen, 1)

    print(f'Epoch {epoch+1} : train acc. {train_acc:.2f} train loss {train_loss:.2f}')
    print(f'Epoch {epoch+1} : test acc. {test_acc:.2f} test loss {test_loss:.2f}')
print('Finished Training')

In [None]:
# Persist the trained weights so the comparison section can reload them without retraining.
cnn_Path = './cnn_cifar_net.pth'
vit_Path = './vit_cifar_net.pth'
for trained_net, weights_path in ((cnn_net, cnn_Path), (vit_net, vit_Path)):
    torch.save(trained_net.state_dict(), weights_path)

# 比較

In [None]:
# Rebuild both models and load the saved weights, so this comparison section can be
# run without retraining.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


cnn_Path = './cnn_cifar_net.pth'
vit_Path = './vit_cifar_net.pth'
# The constructor arguments must match those used for training; otherwise
# load_state_dict (strict by default) raises on missing, unexpected, or mis-shaped keys.
cnn_net = cnn.Net(
    class_num=10,
    conv_size=[3,6,16],
    linear_size=[120,84]
)
vit_net = vit.Net(
    image_size=32, # height = width = image_size
    patch_size=4,
    n_classes=10,
    dim=256,
    depth=3,
    n_heads=4,
    mlp_dim=256,
)
# Newly constructed modules start in training mode; switch to inference mode for evaluation.
# (Done before loading so the last line still displays the load_state_dict result.)
cnn_net.eval()
vit_net.eval()

cnn_net.load_state_dict(torch.load(cnn_Path, map_location=device))
vit_net.load_state_dict(torch.load(vit_Path, map_location=device))

cpu


<All keys matched successfully>

In [None]:
# Per-class test accuracy of the CNN and the ViT, side by side.
n_classes = len(classes)
cnn_class_correct = [0.0] * n_classes
cnn_class_total = [0.0] * n_classes
vit_class_correct = [0.0] * n_classes
vit_class_total = [0.0] * n_classes

# Inference mode (nn.Module defaults to training mode after construction).
cnn_net.eval()
vit_net.eval()
# Run each model on the device its weights live on, then bring the predictions back
# to the CPU to compare them with the labels from testloader.
cnn_device = next(cnn_net.parameters()).device
vit_device = next(vit_net.parameters()).device
with torch.no_grad():
    for images, labels in testloader:
        _, cnn_predicted = torch.max(cnn_net(images.to(cnn_device)), 1)
        _, vit_predicted = torch.max(vit_net(images.to(vit_device)), 1)
        c = cnn_predicted.cpu() == labels
        v = vit_predicted.cpu() == labels
        # Iterate over the actual batch length: the last batch can be smaller than batch_size.
        for i, label in enumerate(labels.tolist()):
            cnn_class_correct[label] += c[i].item()
            vit_class_correct[label] += v[i].item()
            cnn_class_total[label] += 1
            vit_class_total[label] += 1
print('Accuracy of \t : \t \t CNN \t ViT')
for i in range(n_classes):
    # max(..., 1) guards against a class that never appears in the test set.
    cnn_pct = 100 * cnn_class_correct[i] / max(cnn_class_total[i], 1)
    vit_pct = 100 * vit_class_correct[i] / max(vit_class_total[i], 1)
    print('Accuracy of %5s \t:\t %2d %% \t %2d %%' % (classes[i], cnn_pct, vit_pct))

Accuracy of 	 : 	 CNN 	 ViT
Accuracy of airplane 	:	 44 % 	 55 %
Accuracy of automobile 	:	 57 % 	 67 %
Accuracy of  bird 	:	 23 % 	 37 %
Accuracy of   cat 	:	 34 % 	 35 %
Accuracy of  deer 	:	 31 % 	 35 %
Accuracy of   dog 	:	 48 % 	 48 %
Accuracy of  frog 	:	 57 % 	 62 %
Accuracy of horse 	:	 58 % 	 49 %
Accuracy of  ship 	:	 48 % 	 65 %
Accuracy of truck 	:	 62 % 	 45 %
