In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Define the Transformer encoder stack and the "decoder" (here just a linear classification head)
class TransformerEncoder(nn.Module):
    """Stack of ``nn.TransformerEncoderLayer`` blocks followed by one LayerNorm.

    Args:
        num_layers: How many encoder layers to stack.
        hidden_size: Model width (``d_model``) of every layer and of the
            final normalization.
        num_heads: Number of attention heads in each layer.
        dropout: Dropout probability handed to each layer.

    Each layer's feed-forward width is fixed at ``4 * hidden_size``. All other
    layer options are left at their PyTorch defaults (``batch_first`` is not
    set, so per the PyTorch docs the layers expect sequence-first input).
    """

    def __init__(self, num_layers, hidden_size, num_heads, dropout):
        super().__init__()
        feedforward_size = hidden_size * 4
        blocks = []
        for _ in range(num_layers):
            blocks.append(
                nn.TransformerEncoderLayer(
                    d_model=hidden_size,
                    nhead=num_heads,
                    dim_feedforward=feedforward_size,
                    dropout=dropout,
                )
            )
        # Attribute names are kept as-is so state_dict keys stay stable.
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        """Run ``x`` through every layer in order, then through the final LayerNorm.

        No attention or padding mask is passed to the layers.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)

class TransformerDecoder(nn.Module):
    """Output head: one linear projection from hidden features to class scores.

    Args:
        hidden_size: Size of the last dimension of the incoming features.
        num_classes: Number of output scores produced per feature vector.
    """

    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Return ``self.fc(x)``; no activation or softmax is applied."""
        return self.fc(x)


In [None]:
# Define the full model: input embedding, Transformer encoder, and the linear "decoder" head
class Transformer(nn.Module):
    """Classifier built from a linear embedding, an encoder stack and a linear head.

    Args:
        img_size: Size of the last input dimension consumed by the embedding
            ``nn.Linear`` (the demo below passes ``224 * 224``).
        num_layers: Number of encoder layers.
        hidden_size: Embedding width, shared by the encoder and the head input.
        num_heads: Number of attention heads per encoder layer.
        dropout: Dropout probability used inside the encoder layers.
        num_classes: Number of output scores.
    """

    def __init__(self, img_size, num_layers, hidden_size, num_heads, dropout, num_classes):
        super().__init__()
        # Linear projection followed by LayerNorm; built in this order so that
        # parameter initialisation and state_dict keys are unchanged.
        project = nn.Linear(img_size, hidden_size)
        normalize = nn.LayerNorm(hidden_size)
        self.embedding = nn.Sequential(project, normalize)
        self.encoder = TransformerEncoder(
            num_layers=num_layers,
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout=dropout,
        )
        self.decoder = TransformerDecoder(hidden_size=hidden_size, num_classes=num_classes)

    def forward(self, x):
        """Embed ``x``, encode it, average over dim 0 and project to class scores.

        NOTE(review): assumes ``x`` is laid out as (batch, sequence, img_size),
        which is how the demo cell calls it; confirm for other callers.
        """
        tokens = self.embedding(x)
        # Swap the first two dims before the encoder (batch-first -> sequence-first
        # under the layout assumption above).
        seq_first = tokens.transpose(0, 1)
        encoded = self.encoder(seq_first)
        # Mean over dim 0 collapses the leading (sequence) dimension.
        pooled = encoded.mean(dim=0)
        return self.decoder(pooled)

In [None]:
# Smoke test: build the model and push one random batch through it.
model = Transformer(
    img_size=224 * 224,
    num_layers=4,
    hidden_size=512,
    num_heads=8,
    dropout=0.1,
    num_classes=10,
)
x = torch.randn(10, 32, 224, 224)
# Merge the last two dims: (10, 32, 224, 224) -> (10, 32, 224 * 224).
y = model(x.flatten(start_dim=2))
print(y.shape)  # prints: torch.Size([10, 10])
