ViT: a Transformer-based image classification network

Code source: https://juejin.cn/post/7277490138518175763

Author: ZackSock

Source: 稀土掘金 (Juejin)

Copyright belongs to the author. For commercial reproduction, contact the author for authorization; for non-commercial reproduction, credit the source.

### Image classification with a ViT network

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [None]:
# Prepare the data
num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

A local file was found, but it seems to be incomplete or outdated because the auto file hash does not match the original value of 85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7 so we will re-download the data.
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
  2252800/169001437 [..............................] - ETA: 1:13:44

In [None]:
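# Configure the hyperparameters used for training

learning_rate = 0.001
# Weight decay coefficient
weight_decay = 0.0001
batch_size = 256
num_epochs = 100
# Input images are resized to image_size x image_size
image_size = 72
# Side length of a single patch
patch_size = 6
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
# Hidden units of the MLP inside each Transformer block
transformer_units = [
    projection_dim * 2,
    projection_dim,
]
transformer_layers = 8
# Hidden units of the classification head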
mlp_head_units = [2048, 1024]
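
The patch arithmetic in these settings can be checked by hand. The following is a minimal NumPy sketch, not part of the original notebook, that splits one dummy 72x72 RGB image into non-overlapping 6x6 patches; the names `patches_per_side`, `patch_dim` and `image` are illustrative assumptions.

```python
import numpy as np

image_size = 72
patch_size = 6
channels = 3

patches_per_side = image_size // patch_size        # 12 patches along each axis
num_patches = patches_per_side ** 2                # 144 patches per image
patch_dim = patch_size * patch_size * channels     # 108 values per flattened patch

# Dummy image (illustrative data only) so the split can be inspected.
image = np.arange(image_size * image_size * channels, dtype=np.float32)
image = image.reshape(image_size, image_size, channels)

# Split rows and columns into (patch index, offset inside patch), bring the two
# patch indices together, then flatten every patch into one vector.
patches = (
    image.reshape(patches_per_side, patch_size, patches_per_side, patch_size, channels)
    .transpose(0, 2, 1, 3, 4)
    .reshape(num_patches, patch_dim)
)

print(patches.shape)
```

Each row of `patches` is one flattened patch, ordered left to right and then top to bottom. The Transformer treats these rows as its input sequence.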

In [None]:
# Data augmentation can be added to improve generalization

# Preprocessing and data augmentation layers
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)
# Compute the mean and variance of the training data for the Normalization layer
data_augmentation.layers[0].adapt(x_train)
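
The model-building cell in this notebook calls `Patches`, `PatchEncoder` and `mlp`, but none of them is defined here. Below is a minimal sketch of what they might look like, modeled on the public Keras ViT example. It is an assumption and not necessarily the code from the source article; in particular the GELU activation in `mlp` and the learned position embedding are choices taken from that example.

```python
import tensorflow as tf
from tensorflow.keras import layers


def mlp(x, hidden_units, dropout_rate):
    # Feed-forward block: one Dense + dropout pair per entry in hidden_units.
    for units in hidden_units:
        x = layers.Dense(units, activation="gelu")(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


class Patches(layers.Layer):
    """Cuts each image into non-overlapping square patches and flattens them."""

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # (batch, rows, cols, patch_dim) -> (batch, rows * cols, patch_dim)
        patch_dim = patches.shape[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dim])


class PatchEncoder(layers.Layer):
    """Projects each patch linearly and adds a learned position embedding."""

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # The position embedding broadcasts over the batch dimension.
        return self.projection(patch) + self.position_embedding(positions)
```

With the settings used in this notebook, `Patches(6)` would turn a batch of 72x72x3 images into 144 patch vectors of length 108 each, and `PatchEncoder(144, 64)` would map those to 144 tokens of width 64.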

In [None]:
# Build the ViT model from the preceding pieces (expects Patches, PatchEncoder and mlp to be defined)

def create_vit_classifier():
    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
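    # Stack several Transformer encoder blocks
    for _ in range(transformer_layers):
        # Layer norm first, then self-attention (query and value are both x1)
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Residual connection around the attention
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer norm first, then the feed-forward MLP
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Residual connection around the MLP
        encoded_patches = layers.Add()([x3, x2])
    # Classification head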
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    logits = layers.Dense(num_classes)(features)
    model = keras.Model(inputs=inputs, outputs=logits)
    return model
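
Because the head flattens all patch tokens rather than pooling them, the first Dense layer of the head is large. The arithmetic below is a rough sketch, assuming that `mlp` is a plain stack of Dense layers with biases (the helper itself is not shown in this notebook); `dense_params` is a hypothetical helper introduced only for this calculation.

```python
num_patches = 144
projection_dim = 64
mlp_head_units = [2048, 1024]
num_classes = 100

# Flatten() concatenates every patch token into one long vector.
flattened = num_patches * projection_dim


def dense_params(n_in, n_out):
    # Weights plus biases of a fully connected layer.
    return n_in * n_out + n_out


# Layer widths of the head: flattened tokens -> 2048 -> 1024 -> num_classes.
sizes = [flattened] + mlp_head_units + [num_classes]
head_params = [dense_params(a, b) for a, b in zip(sizes[:-1], sizes[1:])]

print(flattened, head_params, sum(head_params))
```

Under these assumptions the first head layer alone accounts for most of the roughly 21 million head parameters, which is one reason the notebook applies dropout of 0.5 around it.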

In [None]:
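# Train the ViT

vit = create_vit_classifier()
vit.compile(
    # Equivalent to 'adam' here (its default learning rate is also 0.001);
    # weight_decay defined above is not applied by plain Adam
    optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
    # The model outputs raw logits (no softmax), so from_logits=True is required
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['acc']
)
vit.fit(
    x_train, y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    # Keras documents validation_data as a tuple (x_val, y_val)
    validation_data=(x_test, y_test)
)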
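
The `from_logits=True` setting means the loss applies the softmax itself. The NumPy sketch below illustrates that idea on made-up logits and labels; it is not the Keras implementation, and the function names are hypothetical.

```python
import numpy as np


def softmax(logits):
    # Subtract the row maximum for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def sparse_ce_from_logits(labels, logits):
    # Log-softmax computed directly from logits, then pick the true class.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]


# Illustrative values only: two samples, three classes, integer labels.
logits = np.array([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
labels = np.array([0, 1])

loss_from_logits = sparse_ce_from_logits(labels, logits)
loss_from_probs = -np.log(softmax(logits)[np.arange(len(labels)), labels])

print(loss_from_logits, loss_from_probs)
```

Both routes give the same per-sample loss. If the model already ended in a softmax layer, passing its output with `from_logits=True` would apply softmax twice and distort the loss.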