In [1]:
import math
import os
from typing import Optional

import timm
import torch
import torch.nn as nn
import torchvision

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Originally 8, changed to 4 (presumably the patch size, cf. patch_size=4 below; TODO confirm).

In [3]:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

In [4]:
# Use the GPU when PyTorch reports one as available; otherwise run on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Floating-point type used for the model weights and the image tensors.
DTYPE = torch.float32

In [5]:
# --- data loading ---
BATCH_SIZE = 75   # samples per mini-batch, used by both data loaders

# --- training schedule ---
epochs = 40       # number of passes over the training set

# --- SGD optimizer ---
lr = 1e-2         # initial learning rate
momentum = 0.9    # SGD momentum coefficient

In [6]:
# Per-channel mean / std handed to Normalize; both pipelines use the same values.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Short alias so the pipelines below stay readable.
_T = torchvision.transforms

# Training pipeline: random 32-pixel crop after 4 pixels of padding, resize to 32,
# random horizontal flip, then tensor conversion, normalisation and dtype cast.
transform1 = _T.Compose([
    _T.RandomCrop(32, padding=4),
    _T.Resize(32),
    _T.RandomHorizontalFlip(),
    _T.ToTensor(),
    _T.Normalize(_NORM_MEAN, _NORM_STD),
    _T.ConvertImageDtype(DTYPE),
])

# Evaluation pipeline: same resize / tensor conversion / normalisation / cast,
# but without the random crop and flip.
transform2 = _T.Compose([
    _T.Resize(32),
    _T.ToTensor(),
    _T.Normalize(_NORM_MEAN, _NORM_STD),
    _T.ConvertImageDtype(DTYPE),
])

In [7]:
from torchvision.models.vision_transformer import Encoder

class VisionTransformer(nn.Module):
    """Vision Transformer image classifier (https://arxiv.org/abs/2010.11929).

    The image is cut into ``patch_size`` x ``patch_size`` patches by a strided
    convolution, a learnable class token is prepended to the patch sequence,
    the sequence is run through ``Encoder`` and the class-token output is
    mapped to ``num_classes`` logits by a single linear layer.

    Args:
        image_size: Side length used to size the token sequence,
            ``(image_size // patch_size) ** 2 + 1`` tokens.
        patch_size: Side length of each square patch (conv kernel and stride).
        num_layers: Forwarded to ``Encoder``.
        num_heads: Forwarded to ``Encoder``.
        hidden_dim: Embedding width of every token.
        mlp_dim: Forwarded to ``Encoder``.
        dropout: Forwarded to ``Encoder``.
        attention_dropout: Forwarded to ``Encoder``.
        num_classes: Number of output logits.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 10,
    ):
        super().__init__()

        # Only a notice is printed here; construction continues regardless.
        if image_size % patch_size:
            print("Input shape indivisible by patch size!")

        # Keep the configuration on the instance.
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes

        # Patch embedding: one output position per non-overlapping patch.
        self.conv_proj = nn.Conv2d(
            3, hidden_dim, kernel_size=patch_size, stride=patch_size
        )

        # Learnable class token, initialised to zeros.
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))

        # One token per patch plus the class token.
        patches_per_side = image_size // patch_size
        seq_length = patches_per_side * patches_per_side + 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
        )
        self.seq_length = seq_length

        # Classification head applied to the class-token output.
        self.heads = nn.Sequential(nn.Linear(hidden_dim, num_classes))

        # Re-initialise the patch projection: truncated normal weights with
        # std = sqrt(1 / fan_in), zero bias.
        kernel_h, kernel_w = self.conv_proj.kernel_size
        fan_in = self.conv_proj.in_channels * kernel_h * kernel_w
        nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
        nn.init.zeros_(self.conv_proj.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        """Turn an ``(n, c, h, w)`` image batch into ``(n, patches, hidden_dim)`` tokens.

        The input height and width are not checked against ``image_size``.
        """
        batch, _, height, width = x.shape
        num_patches = (height // self.patch_size) * (width // self.patch_size)

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        embedded = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, n_h * n_w)
        embedded = embedded.reshape(batch, self.hidden_dim, num_patches)
        # (n, hidden_dim, n_h * n_w) -> (n, n_h * n_w, hidden_dim)
        return embedded.transpose(1, 2)

    def forward(self, x: torch.Tensor):
        """Return ``(n, num_classes)`` logits for an ``(n, c, h, w)`` image batch."""
        tokens = self._process_input(x)

        # Broadcast the single class token over the batch and put it first.
        cls_tokens = self.class_token.expand(tokens.shape[0], -1, -1)
        sequence = torch.cat([cls_tokens, tokens], dim=1)

        encoded = self.encoder(sequence)

        # Classify from the output at the class-token position (index 0).
        return self.heads(encoded[:, 0])

In [8]:
# Loader options common to the training and the test split.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=8, pin_memory=True)

# Training split: augmented pipeline (transform1), reshuffled by the loader.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform1,
)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_kwargs)

# Test split: non-augmented pipeline (transform2), fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform2,
)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
# Token width; the MLP width below is four times this value.
_HIDDEN_DIM = 768

# 32x32 inputs split into 4x4 patches, 12 encoder layers with 12 heads each,
# dropout of 0.2 on both the encoder and the attention, 10 output classes.
net = VisionTransformer(
    image_size=32, patch_size=4,
    num_layers=12, num_heads=12,
    hidden_dim=_HIDDEN_DIM, mlp_dim=4 * _HIDDEN_DIM,
    dropout=0.2, attention_dropout=0.2,
    num_classes=10,
)

In [10]:
# Switch to training mode, then move the model to the target device / dtype
# before the optimizer is built from its parameters.
net.train()
net = net.to(device=DEVICE, dtype=DTYPE)

# Classification loss.
loss_func = torch.nn.CrossEntropyLoss()

# SGD with momentum; the learning rate is multiplied by 0.9 every 2 scheduler
# steps (the scheduler is stepped once per epoch inside `tran`).
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

In [11]:
def tran(epoch):
    """Train ``net`` for one pass over ``trainloader`` and step the LR scheduler.

    Uses the module-level ``net``, ``trainloader``, ``optimizer``,
    ``loss_func``, ``scheduler`` and ``DEVICE``. Prints the 1-based epoch
    number together with the sum of the per-batch losses.

    Args:
        epoch: Zero-based epoch index (only used for the printed message).
    """
    net.train()
    epoch_loss = 0.0

    for images, targets in trainloader:
        images, targets = images.to(DEVICE), targets.to(DEVICE)

        # Clear gradients left over from the previous batch.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        batch_loss = loss_func(net(images), targets)
        batch_loss.backward()
        optimizer.step()

        # Accumulate the scalar loss (a sum over batches, not a mean).
        epoch_loss += batch_loss.item()

    # One scheduler step per epoch.
    scheduler.step()
    print('[%d] loss: %.3f' % (epoch + 1, epoch_loss))

In [12]:
def test():
    net.eval()
    all_counter=0
    correct_counter=0
    for i, data in enumerate(testloader, 0):
        inputs, labels = data
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        out = net(inputs)
        out = out.detach().cpu().argmax(1)
        t = labels.cpu()
        for m in range(len(t)):
            all_counter += 1
            if t[m] == out[m]:
                correct_counter += 1

    print(correct_counter, all_counter, correct_counter / all_counter)
    return (correct_counter / all_counter)

In [13]:
# Train for `epochs` epochs, evaluating after each one and checkpointing the
# model whenever the test accuracy beats the best value seen so far.
correctRate = 0

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save on a fresh run.
os.makedirs("checkpoint", exist_ok=True)

for i in range(epochs):
    tran(i)
    r = test()
    if r > correctRate:
        correctRate = r
        print("best: ", r , " in NO: ", i)
        # The whole module is pickled after moving it to the CPU (presumably so
        # the checkpoint can be loaded on a machine without a GPU), then the
        # model is moved back to DEVICE to continue training.
        torch.save(net.cpu(), "checkpoint/trans_vit2.pth")
        net = net.to(DEVICE)

[1] loss: 1346.418
3757 10000 0.3757
best:  0.3757  in NO:  0
[2] loss: 1093.607
4256 10000 0.4256
best:  0.4256  in NO:  1
[3] loss: 1003.605
4695 10000 0.4695
best:  0.4695  in NO:  2
[4] loss: 951.874
4842 10000 0.4842
best:  0.4842  in NO:  3
[5] loss: 904.770
5267 10000 0.5267
best:  0.5267  in NO:  4
[6] loss: 876.180
5370 10000 0.537
best:  0.537  in NO:  5
[7] loss: 840.450
5724 10000 0.5724
best:  0.5724  in NO:  6
[8] loss: 822.195
5667 10000 0.5667
[9] loss: 797.509
5819 10000 0.5819
best:  0.5819  in NO:  8
[10] loss: 778.983
5718 10000 0.5718
[11] loss: 760.455
6031 10000 0.6031
best:  0.6031  in NO:  10
[12] loss: 745.992
5966 10000 0.5966
[13] loss: 725.470
6262 10000 0.6262
best:  0.6262  in NO:  12
[14] loss: 713.989
6261 10000 0.6261
[15] loss: 697.575
6273 10000 0.6273
best:  0.6273  in NO:  14
[16] loss: 688.999
6408 10000 0.6408
best:  0.6408  in NO:  15
[17] loss: 671.607
6478 10000 0.6478
best:  0.6478  in NO:  16
[18] loss: 660.171
6541 10000 0.6541
best:  0.654