# Transformer

<div style="display: flex; align-items: center;">
    <img src="../imgs/Transformer.jpg" alt="Transformer architecture diagram" width="300" style="margin-right: 20px;">
    <div>
        <p>The Transformer is a deep learning architecture built around the self-attention mechanism, which changed how sequences are modeled. Self-attention lets the model process all elements of a sequence in parallel rather than step by step as in recurrent neural networks, which greatly improves computational efficiency on parallel hardware. Multi-head attention runs several attention operations side by side, so the model can capture different kinds of relationships within the same sequence and learn richer features.</p>
        <p>Because self-attention on its own ignores the order of its inputs, the Transformer adds positional encodings so that the model can use the position of each element in the sequence. In each encoder and decoder layer, the output of the attention sublayer is passed to a feedforward network for further feature extraction. To stabilize the training of deep stacks, each sublayer is wrapped in a residual connection and combined with layer normalization; the residual connections mitigate vanishing gradients and make deep networks practical to train.</p>
        <p>These design choices quickly made the Transformer the dominant architecture in natural language processing, particularly for machine translation, text summarization, and question answering. Its flexibility and representational power have also carried over to other fields such as speech recognition and image processing, making it one of the most influential models in deep learning.</p>
    </div>
</div>
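
As a rough illustration of the self-attention mechanism described above, the following minimal sketch computes single-head scaled dot-product attention in PyTorch. The tensor sizes and the three projection layers (`w_q`, `w_k`, `w_v`) are illustrative assumptions, not taken from the text; a real Transformer layer adds multiple heads, dropout, and an output projection.

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    """Single-head attention: softmax(q @ k^T / sqrt(d_k)) @ v."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (batch, seq, seq)
    weights = torch.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ v, weights

torch.manual_seed(0)
x = torch.randn(2, 5, 16)  # (batch, seq_len, d_model), illustrative sizes

# Hypothetical projection layers; in a real model these are learned.
w_q, w_k, w_v = (torch.nn.Linear(16, 16, bias=False) for _ in range(3))

out, weights = scaled_dot_product_attention(w_q(x), w_k(x), w_v(x))
print(out.shape)      # torch.Size([2, 5, 16])
print(weights.shape)  # torch.Size([2, 5, 5])
```

Every output position is a weighted mix of all positions in the sequence, computed with a few matrix multiplications, which is why the whole sequence can be processed in parallel.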

# Vision Transformer

<div style="display: flex; align-items: center;">
    <img src="../imgs/ViT.jpg" alt="Schematic diagram of the ViT model" width="600" style="margin-right: 20px;">
    <div>
        <p>ViT (Vision Transformer) is a model proposed by Google in 2020 that applies the Transformer directly to image classification; many later works build on it. The idea is simple: split the image into fixed-size patches and map each patch to a patch embedding with a linear projection, analogous to words and word embeddings in NLP. Since the Transformer takes a sequence of token embeddings as input, the patch embeddings can be fed into it for feature extraction and classification. As the schematic diagram shows, ViT uses only the encoder of the Transformer; the original Transformer also has a decoder, which is used for sequence-to-sequence tasks such as machine translation.</p>
    </div>
</div>
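
The patch-embedding step can be sketched in a few lines. This is a minimal illustration, assuming a 32x32 RGB input, 8x8 patches, and an embedding width of 768; the batch size and variable names are chosen only for the example. A convolution whose kernel size and stride both equal the patch size applies the same linear projection to every non-overlapping patch.

```python
import torch
import torch.nn as nn

image_size, patch_size, channels, embed_dim = 32, 8, 3, 768
images = torch.randn(4, channels, image_size, image_size)  # (N, C, H, W)

# Kernel size == stride == patch size: one shared linear projection per patch.
patch_embed = nn.Conv2d(channels, embed_dim, kernel_size=patch_size, stride=patch_size)

tokens = patch_embed(images)     # (N, embed_dim, H/P, W/P) = (4, 768, 4, 4)
tokens = tokens.flatten(2)       # (4, 768, 16): merge the 4x4 grid into 16 patches
tokens = tokens.transpose(1, 2)  # (4, 16, 768): a sequence of 16 patch embeddings

num_patches = (image_size // patch_size) ** 2
print(num_patches)   # 16
print(tokens.shape)  # torch.Size([4, 16, 768])
```

The resulting tensor has the same layout as a batch of token embeddings in NLP, so it can be passed to a Transformer encoder without further changes.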

Before building the model, let's review the parameters of the new layer:

**nn.TransformerEncoderLayer** is PyTorch's implementation of a single Transformer encoder layer: a self-attention sublayer followed by a feedforward network. It is the building block of the encoder and can be stacked multiple times with nn.TransformerEncoder to form the full Transformer encoder. Its parameters are as follows:

`d_model`: The feature dimension of the layer's input and output, i.e. the size of the last dimension of an input shaped (seq_len, batch, d_model).

`nhead`: The number of heads in the layer's multi-head attention. `d_model` must be divisible by `nhead`.

`dim_feedforward`: (optional) The hidden dimension of the feedforward network inside the layer. Default is 2048.

`dropout`: (optional) The dropout probability for the dropout layer inside the transformer block. Default is 0.1.

`activation`: (optional) The activation function to use in the feedforward network. Default is "relu".

`batch_first`: (optional) If True, the input and output tensors are provided as (batch, seq, feature). If False, they are provided as (seq, batch, feature). Default is False.

`bias`: (optional) If False, the Linear and LayerNorm layers inside the block do not learn an additive bias. Default is True.

`layer_norm_eps`: (optional) The epsilon constant used in the layer normalization components. Default is 1e-5.

`norm_first`: (optional) If True, layer normalization is applied before the attention and feedforward sublayers (pre-norm). If False, it is applied after each sublayer (post-norm). Default is False.
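
To see how these parameters fit together, here is a small sketch that builds a two-layer encoder and checks the output shape. The sizes (`d_model=64`, `nhead=4`, and so on) are arbitrary values chosen to keep the example fast; they are not the settings used for the model in this chapter.

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(
    d_model=64,           # feature size of every token
    nhead=4,              # 64 / 4 = 16 features per attention head
    dim_feedforward=128,  # hidden size of the feedforward network
    dropout=0.1,
    activation="relu",
    batch_first=True,     # inputs are (batch, seq, feature)
    norm_first=False,     # post-norm, the default
)
encoder = nn.TransformerEncoder(layer, num_layers=2)
encoder.eval()  # disable dropout for a deterministic forward pass

x = torch.randn(8, 16, 64)  # (batch, seq, feature)
with torch.no_grad():
    y = encoder(x)
print(y.shape)  # torch.Size([8, 16, 64])
```

The encoder preserves the input shape. Note that with the default `batch_first=False`, the same tensor would be interpreted as 8 time steps in a batch of 16 sequences, so the flag has to match the layout of the data.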

## Train ViT on CIFAR100

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ViT(nn.Module):
    def __init__(self, input_channels, image_size, patch_size, num_classes, num_heads, num_encoder_layers, dim_feedforward):
        super(ViT, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_dim = input_channels * patch_size * patch_size

        self.conv1 = nn.Conv2d(input_channels, 768, kernel_size=patch_size, stride=patch_size, bias=False)  # Patch embedding
        self.positional_encoding = nn.Parameter(self._generate_positional_encoding(self.num_patches, 768))

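        # batch_first=True because forward() passes tokens as (batch, num_patches, d_model);
        # with the default (False) the batch axis would be treated as the sequence axis.
        encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=num_heads, dim_feedforward=dim_feedforward, batch_first=True)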
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)

        self.fc = nn.Linear(768, num_classes)
        
    def _generate_positional_encoding(self, num_patches, d_model, dtype=torch.float):
        """
        Generate the 1D sinusoidal positional encoding of the original Transformer,
        with shape (1, num_patches, d_model)
        """
        position = torch.arange(num_patches, dtype=dtype).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=dtype) * -(torch.log(torch.tensor(10000.0)) / d_model))
        positional_encoding = torch.zeros((num_patches, d_model), dtype=dtype)
        positional_encoding[:, 0::2] = torch.sin(position * div_term)
        positional_encoding[:, 1::2] = torch.cos(position * div_term)
        return positional_encoding.unsqueeze(0)

    def forward(self, x):
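        x = self.conv1(x).flatten(2)  # Patchify: (N, C, H, W) -> (N, 768, num_patches)
        x = x.transpose(1, 2)  # (N, 768, num_patches) -> (N, num_patches, 768)
        x = x + self.positional_encoding[:, :x.size(1), :]  # The parameter already lives on the model's device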
        x = self.transformer_encoder(x)
        x = x.mean(dim=1)  # Global average pooling
        x = self.fc(x)
        return x
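
The sinusoidal encoding built by `_generate_positional_encoding` can be checked in isolation. The standalone function below is a sketch that mirrors the same formula outside the class; its name and the chosen sizes are assumptions made for illustration.

```python
import math
import torch

def sinusoidal_positional_encoding(num_positions, d_model):
    """Return a (num_positions, d_model) table of sin/cos position features."""
    position = torch.arange(num_positions, dtype=torch.float).unsqueeze(1)  # (L, 1)
    # Frequencies fall geometrically from 1 towards 1/10000 across feature pairs.
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(num_positions, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even feature indices
    pe[:, 1::2] = torch.cos(position * div_term)  # odd feature indices
    return pe

pe = sinusoidal_positional_encoding(16, 768)
print(pe.shape)   # torch.Size([16, 768])
print(pe[0, :4])  # position 0 gives sin(0)=0, cos(0)=1: tensor([0., 1., 0., 1.])
```

Each position index along the flattened patch sequence gets a distinct row, so the encoding is one-dimensional over that sequence rather than over the 2D patch grid.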

In [2]:
import sys
import torch.nn as nn
sys.path.append('../tools')
from CIFAR10 import CIFAR10Trainer
from CIFAR100 import CIFAR100Trainer

Vision Transformers (ViTs) are highly effective on large-scale datasets, but without pre-training they often underperform Convolutional Neural Networks (CNNs) on smaller or simpler datasets, for several reasons. ViTs need substantial data to make use of their large capacity, so they overfit easily when data is limited. Their design focuses on capturing global dependencies, which can be more than is needed for the local pattern recognition that small datasets call for. Without the feature-rich initialization that pre-training provides, ViTs also struggle to learn from scratch, whereas the architecture of CNNs is inherently suited to spatial hierarchies and adapts quickly to the available data. The run below shows this pattern: training loss keeps falling through epoch 50, while validation accuracy levels off at about 41% after roughly epoch 33. This chapter therefore only introduces the concept and structure of ViT. You can explore application scenarios yourself, and I will also try to demonstrate them in later chapters.
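
To put a number on the capacity argument, the following sketch counts the trainable parameters of an encoder stack with the same sizes as the configuration used in this chapter (`d_model=768`, 8 heads, `dim_feedforward=2048`, 3 layers). It covers the encoder only, not the patch embedding or the classifier head, and the comparison with the 50,000 CIFAR-100 training images is a rough heuristic rather than a formal measure.

```python
import torch.nn as nn

d_model, nhead, dim_ff, num_layers = 768, 8, 2048, 3
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff)
encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

# Count every trainable weight and bias in the encoder stack.
num_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)

train_images = 50_000  # size of the CIFAR-100 training set
print(f"{num_params:,} encoder parameters")
print(f"about {num_params / train_images:.0f} parameters per training image")
```

With about 16.5 million encoder parameters and only tens of thousands of images, the model has ample room to memorize the training set, which is consistent with the gap between training loss and validation accuracy.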

In [None]:
model = ViT(input_channels=3, image_size=32, patch_size=8, num_classes=100, num_heads=8, num_encoder_layers=3, dim_feedforward=2048)
trainer = CIFAR100Trainer(model, loss='CE', lr=0.01, optimizer='SGD', batch_size=128, epoch=50, model_type='classification')
trainer.train()
trainer.test()

Files already downloaded and verified
Files already downloaded and verified
2024-05-20 18:47:33
Epoch 1 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:26<00:00, 13.51it/s, train_loss=4.14]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 16.99it/s, val_acc=9.78, val_loss=0.0313]


2024-05-20 18:48:01
Epoch 2 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:27<00:00, 12.86it/s, train_loss=3.78]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 13.61it/s, val_acc=12.5, val_loss=0.0298]


2024-05-20 18:48:31
Epoch 3 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:27<00:00, 12.88it/s, train_loss=3.59]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 13.89it/s, val_acc=16.4, val_loss=0.0283]


2024-05-20 18:49:02
Epoch 4 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:28<00:00, 12.49it/s, train_loss=3.44]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 14.31it/s, val_acc=18.5, val_loss=0.0273]


2024-05-20 18:49:33
Epoch 5 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:27<00:00, 12.81it/s, train_loss=3.31]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 14.26it/s, val_acc=20.3, val_loss=0.0265]


2024-05-20 18:50:03
Epoch 6 / 50


[Train]: 100%|██████████████████████████| 352/352 [00:28<00:00, 12.48it/s, train_loss=3.2]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 15.35it/s, val_acc=21.4, val_loss=0.0257]


2024-05-20 18:50:34
Epoch 7 / 50


[Train]: 100%|██████████████████████████| 352/352 [00:28<00:00, 12.36it/s, train_loss=3.1]
[Valid]: 100%|████████████████████████| 40/40 [00:02<00:00, 14.28it/s, val_acc=24.2, val_loss=0.025]


2024-05-20 18:51:05
Epoch 8 / 50


[Train]: 100%|████████████████████████████| 352/352 [00:28<00:00, 12.19it/s, train_loss=3]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 15.43it/s, val_acc=25.8, val_loss=0.0243]


2024-05-20 18:51:37
Epoch 9 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:27<00:00, 12.59it/s, train_loss=2.91]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 13.34it/s, val_acc=26.7, val_loss=0.0236]


2024-05-20 18:52:08
Epoch 10 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:28<00:00, 12.28it/s, train_loss=2.81]
[Valid]: 100%|████████████████████████| 40/40 [00:02<00:00, 15.72it/s, val_acc=29.7, val_loss=0.023]


2024-05-20 18:52:39
Epoch 11 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:29<00:00, 11.90it/s, train_loss=2.72]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 12.77it/s, val_acc=29.2, val_loss=0.0227]


2024-05-20 18:53:11
Epoch 12 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:29<00:00, 11.90it/s, train_loss=2.65]
[Valid]: 100%|████████████████████████| 40/40 [00:02<00:00, 13.90it/s, val_acc=31.3, val_loss=0.022]


2024-05-20 18:53:44
Epoch 13 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:31<00:00, 11.32it/s, train_loss=2.58]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 12.63it/s, val_acc=32.8, val_loss=0.0219]


2024-05-20 18:54:18
Epoch 14 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:31<00:00, 11.20it/s, train_loss=2.51]
[Valid]: 100%|█████████████████████████| 40/40 [00:03<00:00, 10.89it/s, val_acc=33, val_loss=0.0213]


2024-05-20 18:54:53
Epoch 15 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:30<00:00, 11.68it/s, train_loss=2.44]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 12.90it/s, val_acc=33.5, val_loss=0.0211]


2024-05-20 18:55:27
Epoch 16 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:32<00:00, 10.81it/s, train_loss=2.38]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 12.08it/s, val_acc=34.6, val_loss=0.0207]


2024-05-20 18:56:02
Epoch 17 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:32<00:00, 10.98it/s, train_loss=2.33]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 12.10it/s, val_acc=35.7, val_loss=0.0205]


2024-05-20 18:56:38
Epoch 18 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:31<00:00, 11.10it/s, train_loss=2.27]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 12.43it/s, val_acc=36.2, val_loss=0.0202]


2024-05-20 18:57:13
Epoch 19 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:31<00:00, 11.18it/s, train_loss=2.21]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 12.37it/s, val_acc=37.3, val_loss=0.0198]


2024-05-20 18:57:47
Epoch 20 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:31<00:00, 11.15it/s, train_loss=2.16]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 10.49it/s, val_acc=37.6, val_loss=0.0197]


2024-05-20 18:58:23
Epoch 21 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:32<00:00, 10.77it/s, train_loss=2.11]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 14.36it/s, val_acc=37.8, val_loss=0.0196]


2024-05-20 18:58:58
Epoch 22 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:30<00:00, 11.54it/s, train_loss=2.06]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 13.76it/s, val_acc=37.3, val_loss=0.0198]


2024-05-20 18:59:32
Epoch 23 / 50


[Train]: 100%|████████████████████████████| 352/352 [00:28<00:00, 12.19it/s, train_loss=2]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 13.54it/s, val_acc=39.2, val_loss=0.0192]


2024-05-20 19:00:04
Epoch 24 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:29<00:00, 12.09it/s, train_loss=1.95]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 13.00it/s, val_acc=39.1, val_loss=0.0191]


2024-05-20 19:00:36
Epoch 25 / 50


[Train]: 100%|██████████████████████████| 352/352 [00:28<00:00, 12.26it/s, train_loss=1.9]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 14.63it/s, val_acc=39.7, val_loss=0.0192]


2024-05-20 19:01:07
Epoch 26 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:29<00:00, 12.08it/s, train_loss=1.84]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 14.62it/s, val_acc=39.3, val_loss=0.0192]


2024-05-20 19:01:39
Epoch 27 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:28<00:00, 12.42it/s, train_loss=1.79]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 13.55it/s, val_acc=39.1, val_loss=0.0192]


2024-05-20 19:02:10
Epoch 28 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:27<00:00, 12.61it/s, train_loss=1.74]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 15.74it/s, val_acc=39.6, val_loss=0.0192]


2024-05-20 19:02:41
Epoch 29 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:26<00:00, 13.23it/s, train_loss=1.68]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 16.78it/s, val_acc=40.7, val_loss=0.0191]


2024-05-20 19:03:10
Epoch 30 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:28<00:00, 12.56it/s, train_loss=1.63]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 14.88it/s, val_acc=40.9, val_loss=0.0189]


2024-05-20 19:03:41
Epoch 31 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:29<00:00, 11.98it/s, train_loss=1.58]
[Valid]: 100%|████████████████████████| 40/40 [00:03<00:00, 10.60it/s, val_acc=40.4, val_loss=0.019]


2024-05-20 19:04:14
Epoch 32 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:31<00:00, 11.30it/s, train_loss=1.52]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 10.69it/s, val_acc=40.2, val_loss=0.0192]


2024-05-20 19:04:49
Epoch 33 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:30<00:00, 11.47it/s, train_loss=1.46]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 11.89it/s, val_acc=41.3, val_loss=0.0189]


2024-05-20 19:05:23
Epoch 34 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:31<00:00, 11.32it/s, train_loss=1.39]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 14.49it/s, val_acc=41.8, val_loss=0.0192]


2024-05-20 19:05:57
Epoch 35 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:31<00:00, 11.26it/s, train_loss=1.32]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 14.34it/s, val_acc=40.4, val_loss=0.0195]


2024-05-20 19:06:31
Epoch 36 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:30<00:00, 11.51it/s, train_loss=1.28]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 13.43it/s, val_acc=41.3, val_loss=0.0193]


2024-05-20 19:07:04
Epoch 37 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:30<00:00, 11.68it/s, train_loss=1.22]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 14.01it/s, val_acc=41.5, val_loss=0.0195]


2024-05-20 19:07:37
Epoch 38 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:30<00:00, 11.50it/s, train_loss=1.15]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 14.48it/s, val_acc=41.3, val_loss=0.0197]


2024-05-20 19:08:11
Epoch 39 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:30<00:00, 11.53it/s, train_loss=1.09]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 14.92it/s, val_acc=41.1, val_loss=0.0198]


2024-05-20 19:08:44
Epoch 40 / 50


[Train]: 100%|█████████████████████████| 352/352 [00:28<00:00, 12.23it/s, train_loss=1.02]
[Valid]: 100%|█████████████████████████| 40/40 [00:03<00:00, 12.88it/s, val_acc=41, val_loss=0.0204]


2024-05-20 19:09:16
Epoch 41 / 50


[Train]: 100%|████████████████████████| 352/352 [00:29<00:00, 11.96it/s, train_loss=0.954]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 12.92it/s, val_acc=40.1, val_loss=0.0206]


2024-05-20 19:09:48
Epoch 42 / 50


[Train]: 100%|████████████████████████| 352/352 [00:28<00:00, 12.50it/s, train_loss=0.893]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 15.16it/s, val_acc=41.6, val_loss=0.0206]


2024-05-20 19:10:19
Epoch 43 / 50


[Train]: 100%|████████████████████████| 352/352 [00:28<00:00, 12.36it/s, train_loss=0.828]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 12.43it/s, val_acc=40.8, val_loss=0.0209]


2024-05-20 19:10:51
Epoch 44 / 50


[Train]: 100%|████████████████████████| 352/352 [00:29<00:00, 11.99it/s, train_loss=0.771]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 15.00it/s, val_acc=40.5, val_loss=0.0216]


2024-05-20 19:11:23
Epoch 45 / 50


[Train]: 100%|████████████████████████| 352/352 [00:28<00:00, 12.46it/s, train_loss=0.701]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 12.69it/s, val_acc=41.2, val_loss=0.0217]


2024-05-20 19:11:54
Epoch 46 / 50


[Train]: 100%|████████████████████████| 352/352 [00:28<00:00, 12.57it/s, train_loss=0.642]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 13.70it/s, val_acc=40.6, val_loss=0.0221]


2024-05-20 19:12:25
Epoch 47 / 50


[Train]: 100%|████████████████████████| 352/352 [00:29<00:00, 12.00it/s, train_loss=0.579]
[Valid]: 100%|███████████████████████| 40/40 [00:02<00:00, 15.64it/s, val_acc=40.5, val_loss=0.0221]


2024-05-20 19:12:57
Epoch 48 / 50


[Train]: 100%|████████████████████████| 352/352 [00:28<00:00, 12.26it/s, train_loss=0.525]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 12.62it/s, val_acc=41.1, val_loss=0.0223]


2024-05-20 19:13:29
Epoch 49 / 50


[Train]: 100%|████████████████████████| 352/352 [00:31<00:00, 11.33it/s, train_loss=0.487]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 12.77it/s, val_acc=41.9, val_loss=0.0226]


2024-05-20 19:14:03
Epoch 50 / 50


[Train]: 100%|████████████████████████| 352/352 [00:30<00:00, 11.59it/s, train_loss=0.438]
[Valid]: 100%|███████████████████████| 40/40 [00:03<00:00, 11.02it/s, val_acc=40.6, val_loss=0.0232]
