In [1]:
import timm
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
import imageio
from PIL import Image
import numpy as np
from tqdm.auto import tqdm

In [2]:
# from pprint import pprint
# model_names = timm.list_models('*vit*t*')
# pprint(model_names)

In [3]:
class CFG:
    model_name = 'vit_base_patch16_224_in21k'
    model_path = './vit_base_patch16_224_in21k.pth'
    pretrained = True
    inp_channels = 3
    batch_size = 256
    epoch = 7
    out_features = 10
    img_h = 224
    img_w = 224
    dropout = 0.2
    seed = 42
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
   

In [4]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize(CFG.img_h),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=CFG.batch_size,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=CFG.batch_size,
                                         shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [5]:
class CustomModel(nn.Module):
    def __init__(
        self, model_name=CFG.model_name, n_class=CFG.out_features, pretrained=CFG.pretrained):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes = CFG.out_features)
      
    def forward(self, x):
        x = self.backbone(x)
        output = x
        return output

In [6]:
def train():
    model = CustomModel().to(CFG.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    
    for epoch in tqdm(range(CFG.epoch)):
        running_loss = 0.0
        for i, data in tqdm(enumerate(trainloader, 0)):
            inputs, labels = data
            inputs = inputs.to(CFG.device)
            labels = labels.to(CFG.device)


            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:    # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                    (epoch + 1, i + 1, running_loss / 2000))

                running_loss = 0.0

    torch.save(model.state_dict(), CFG.model_path)

In [7]:
def infer():
    model = CustomModel().to(CFG.device)
    model.load_state_dict(torch.load(CFG.model_path))
    correct = 0
    total = 0
    with torch.no_grad():
        for data in tqdm(testloader):
            images, labels = data
            images = images.to(CFG.device)
            labels = labels.to(CFG.device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %f %%' % (
        100 * correct / total))
    

In [8]:
if __name__ == '__main__':
    print('**************train_start**************')
    train()
    print('**************infer_start**************')
    infer()

**************train_start**************


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=7.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…




HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…



**************infer_start**************


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))


Accuracy of the network on the 10000 test images: 98.770000 %


<img src='result.png'>

ViT-B/16  CIFAR-10における論文とほぼ同等の精度が確認できた