In [3]:
import torch
import timm
from torch import nn
from einops import rearrange,repeat
from einops.layers.torch import Rearrange
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 64
learning_rate = 1e-4
num_epochs = 10

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))#这里加了一个数据归一化
])
 

trainset = torchvision.datasets.CIFAR10(root='D:/Dataset/CIFAR-10', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
 
testset = torchvision.datasets.CIFAR10(root='D:/Dataset/CIFAR-10', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)



def pair(i):
    return i if isinstance(i,tuple) else (i,i) 


class PreNorm(nn.Module):
    def __init__(self,dim,fn):#接受不同函数
        super().__init__()
        self.norm = nn.LayerNorm(dim)#进行层归一化
        self.fn = fn
    def forward(self, x, **kwargs):#接受任意数量的参数以达到接受不同函数
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):#前馈
    def __init__(self, dim, hidden_dim, dropout = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),#使用Gelu激活函数
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 3, dim_head = 64, dropout = 0.1):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5 

        self.attend = nn.Softmax(dim = -1)
       
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)  

        out = torch.matmul(attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
        

class Transformer(nn.Module):
    def __init__(self,embed_dim,depth,heads,dim_head,mlp_dim,dropout=0.1):
        super().__init__()
        self.layers=nn.ModuleList([])#存储一系列的模块
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(embed_dim, Attention(embed_dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(embed_dim, FeedForward(embed_dim, mlp_dim, dropout = dropout))
            ]))#每次新加attention和FeedForward层，共depth个
    def forward(self, x):
        for attention, feedforward in self.layers:  
            x = attention(x) + x  
            x = feedforward(x) + x#残差操作
        return x
        
    




class ViT(nn.Module):
    def __init__(self,*,image_size,patch_size,num_classes,embed_dim,depth,heads,mlp_dim,pool='cls',channels=3,dim_head=64,dropout=0.1,embed_dropout=0.1):
        super().__init__()
        image_height,image_width=pair(image_size)#成对赋值
        patch_height,patch_width=pair(patch_size)

        assert image_height%patch_height==0 and image_width%patch_width==0#保证image高度和宽度可以整除patch的高度和宽度

        num_patches=(image_height//patch_height)*(image_width//patch_width)#计算有多少个patch

        patch_dim=channels*patch_height*patch_width#计算patch维度
        assert pool in {'cls', 'mean'}#输出时选择这个cls token或者用平均池化

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),#转换维度，将batch_size，channels，height，width转换为batch_size,num_patches,patch_dim
            nn.Linear(patch_dim, embed_dim),#进行线性映射
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))#进行位置信息嵌入计算，parameter表示可学习，randn为初始化，+1因为要加入cls_token
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))#定义cls，与每个patch维度相同。因为每个图块被展平了，所以变为了1x1x192
        self.dropout = nn.Dropout(embed_dropout)

        self.transformer = Transformer(embed_dim, depth, heads, dim_head, mlp_dim, dropout)#transfomer块

        self.pool=pool
        self.to_latent = nn.Identity()#占位符

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)#线性层输出
        )

    def forward(self,image):
        x=self.to_patch_embedding(image)
        b, n, _ = x.shape#_表示不会显式使用

        cls_tokens=repeat(self.cls_token, '() n d -> b n d', b = b)
        x=torch.cat((cls_tokens, x), dim=1)
        x+=self.pos_embedding[:, :(n + 1)]  
        x=self.dropout(x)
        x=self.transformer(x)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)
        
pretrained_vit = timm.create_model('vit_base_patch16_224', pretrained=True)
pretrained_state_dict = pretrained_vit.state_dict()

model=ViT(
    image_size=32,#图片大小
    patch_size=4,#图块大小
    num_classes=10,#最后要分类的数量
    embed_dim=192,#patch embedding时的维度，即投射后的维度
    depth=6,#transformer的个数，对于CIFAR10来说，这里数量更少，heads同理
    heads=3,#自注意力头的数量
    mlp_dim=768,#transformer后线形层升维后的维数，这里选择乘4
    dropout=0.1,#丢弃神经元
    embed_dropout=0.1).to(device)#embedding时丢弃

pretrained_state_dict = {k: v for k, v in pretrained_state_dict.items() if k in model.state_dict() and model.state_dict()[k].shape == v.shape}
model.load_state_dict(pretrained_state_dict, strict=False)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=0.1)



def train():
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
 
            optimizer.zero_grad()
 
            outputs = model(inputs)
            loss = criterion(outputs, labels)
 
            loss.backward()
            optimizer.step()
 
            running_loss += loss.item()
 
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
 
        accuracy = 100 * correct / total
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(trainloader):.4f}, Accuracy: {accuracy:.2f}%')
 
def test():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
 
    accuracy = 100 * correct / total
    print(f'Accuracy of the model on the 10000 test images: {accuracy:.2f}%')
 
if __name__ == "__main__":
    train()
    test()
 

Files already downloaded and verified
Files already downloaded and verified
Epoch [1/10], Loss: 2.3096, Accuracy: 10.97%
Epoch [2/10], Loss: 2.3101, Accuracy: 10.03%
Epoch [3/10], Loss: 2.3079, Accuracy: 9.85%
Epoch [4/10], Loss: 2.3061, Accuracy: 9.81%
Epoch [5/10], Loss: 2.3052, Accuracy: 9.84%
Epoch [6/10], Loss: 2.3041, Accuracy: 9.92%
Epoch [7/10], Loss: 2.3037, Accuracy: 9.84%
Epoch [8/10], Loss: 2.3036, Accuracy: 9.70%
Epoch [9/10], Loss: 2.3031, Accuracy: 9.97%
Epoch [10/10], Loss: 2.3031, Accuracy: 9.79%
Accuracy of the model on the 10000 test images: 10.00%
