In [1]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
from torch import optim
import math
import torch.nn.functional as F


from tqdm import tqdm

In [2]:
from torch.utils.data import DataLoader
train_data=torchvision.datasets.MNIST(
    root='MNIST',
    train=True,
    transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))
                               ]),
    download=True
)
test_data=torchvision.datasets.MNIST(
    root='MNIST',
    train=False,
    transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize((0.1307,), (0.3081,))
                               ]),
    download=True
)
train_load=DataLoader(dataset=train_data,batch_size=128,shuffle=True)
test_load=DataLoader(dataset=test_data,batch_size=128,shuffle=True)

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
class positionalEncoding(nn.Module):
    def __init__(self, patch_num, embedding_dim):
        super(positionalEncoding, self).__init__()
        self.P = torch.zeros((1, patch_num, embedding_dim))
        X = torch.arange(patch_num, dtype=torch.float32).reshape(-1, 1)/ torch.pow(10000, torch.arange(0, embedding_dim, 2, dtype=torch.float32) / embedding_dim)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, x):
        x = x + self.P[:, :x.shape[1], :].to(device)
        return x

In [5]:
# selfAttention
def selfAttention(queries, keys, values):
    d = queries.shape[-1]
    scores = torch.bmm(queries, keys.transpose(1, 2))/ math.sqrt(d)
    return torch.bmm(nn.Softmax(dim=-1)(scores), values)

def transpose_qkv(X, h):
# 输入的X：batch_size * n * em_dim => b * n * h * e/h =>(b*h) * n * e/h
    X = X.reshape(X.shape[0], X.shape[1], h, -1)
    X = X.permute(0, 2, 1, 3)
    return X.reshape(-1, X.shape[2], X.shape[3])

def multiConcat(X, h):
# 输入的X：(b*h) * n * e/h => b * h * n * e/h=> b * n * h * e/h => b * n * e
    X = X.reshape(-1, h, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

# multi-head
class multiHeadAttention(nn.Module):
    def __init__(self, input_dim, embedding_dim, h):
        super(multiHeadAttention, self).__init__()
        self.h = h
        self.x2q = nn.Linear(input_dim, embedding_dim, False)
        self.x2k = nn.Linear(input_dim, embedding_dim, False)
        self.x2v = nn.Linear(input_dim, embedding_dim, False)
        self.selfAttention = selfAttention
        self.wo = nn.Linear(embedding_dim, embedding_dim, False)
        self.forRes = nn.Linear(embedding_dim, input_dim, False)
    def forward(self, x):
        q = self.x2q(x).to(device)
        k = self.x2k(x).to(device)
        v = self.x2v(x).to(device)
        
        queries = transpose_qkv(q, self.h)
        keys = transpose_qkv(k, self.h)        
        values = transpose_qkv(v, self.h)
        
        output = self.selfAttention(queries, keys, values)
        output_concat = multiConcat(output,self.h)
        final_output = self.forRes(self.wo(output_concat)) + x
        return final_output

In [6]:
class MLP(nn.Module):
    def __init__(self, input_size, output_size):
        super(MLP, self).__init__()
        self.mlp_block = torch.nn.Sequential()
        self.mlp_block.add_module("linear1", nn.Linear(input_size, 1024))
        self.mlp_block.add_module("ReLU1", nn.ReLU())
        self.mlp_block.add_module("linear2", nn.Linear(1024, 512))
        self.mlp_block.add_module("ReLU2", nn.ReLU())
        self.mlp_block.add_module("linear3", nn.Linear(512, output_size))
    def forward(self, x):
        # print(f"mlp_input_size{x.shape}")
        return F.softmax(self.mlp_block(x),dim=1)

In [7]:
class Net(nn.Module):
    def __init__(self, patch_num, embedding_dim, h, output_size):
        super(Net, self).__init__()
        #position_encoding
        self.positionalEncoding = positionalEncoding(patch_num, embedding_dim)
        #multi-ihead-attention
        self.multihead = multiHeadAttention(embedding_dim, embedding_dim, h)
        #mlp
        self.mlp = MLP(patch_num * embedding_dim, output_size)
    def forward(self, x):
        # print(f"ori input.shape{x.shape}")
        x = self.positionalEncoding(x)
        # print(f"after poscode,x.shape{x.shape}")
        x = self.multihead(x)
        # print(f"after multihead,x.shape{x.shape}")
        # n * w=>flatten
        x = x.flatten(-2)
        # print(f"after flatten,x.shape{x.shape}")
        x = self.mlp(x)
        return x

In [8]:
net = Net(16, 512, 8, 10).to(device)
print(net)

Net(
  (positionalEncoding): positionalEncoding()
  (multihead): multiHeadAttention(
    (x2q): Linear(in_features=512, out_features=512, bias=False)
    (x2k): Linear(in_features=512, out_features=512, bias=False)
    (x2v): Linear(in_features=512, out_features=512, bias=False)
    (wo): Linear(in_features=512, out_features=512, bias=False)
    (forRes): Linear(in_features=512, out_features=512, bias=False)
  )
  (mlp): MLP(
    (mlp_block): Sequential(
      (linear1): Linear(in_features=8192, out_features=1024, bias=True)
      (ReLU1): ReLU()
      (linear2): Linear(in_features=1024, out_features=512, bias=True)
      (ReLU2): ReLU()
      (linear3): Linear(in_features=512, out_features=10, bias=True)
    )
  )
)


In [9]:
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
          f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
#         print(f"x  {x.shape}")
        x = self.proj(x)
#         print(f"proj  {x.shape}")
        x = x.flatten(2)
#         print(f"flatten    {x.shape}")
        x = x.transpose(1, 2)
#         print(f"trans{x.shape}")
        return x  
    
# x = torch.randn((100,1,28,28))
# mypatchtest = PatchEmbed(28, 7, 1, 512)
# res = mypatchtest(x)
# print(res.shape)

In [10]:
num_epoches = 5

optimizer = optim.SGD(net.parameters(), lr=0.1)# Adam梯度下降
lossCal = nn.CrossEntropyLoss()
pic2Patches = PatchEmbed(28, 7, 1, 512).to(device)#res.size = torch.Size([batch_size, 16, 512])

In [11]:
losses = []
acces = []
eval_losses = []
eval_acces = []

for epoch in tqdm(range(num_epoches), desc = "training network"):
    train_loss = 0
    train_acc = 0
    num_correct = 0
    for img, label in train_load:
        img = img.to(device)
        # print(f"berore topatch{img.shape}")
        img = pic2Patches(img)
        # print(f"after topatch{img.shape}")
        label = label.to(device)
        # 前向传播
        
        res = net(img)
        loss = lossCal(res, label)
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # loss.item()是将零维张量转换为浮点数
        train_loss += loss.item()
        # 计算分类的准确率
        _, pred = res.max(1)
        num_correct += torch.sum(pred == label)
        acc = num_correct / img.shape[0]
        train_acc +=acc
    losses.append(train_loss / len(train_load))
    acces.append(train_acc / len(train_load))
    tqdm.write(f"epoch number: {epoch+1,num_epoches}loss：{train_loss/len(train_load)}")
    tqdm.write(f"   corrrect{int(10000 * num_correct/len(train_data)) / 100}%")


training network:  20%|█████████████▍                                                     | 1/5 [00:10<00:41, 10.46s/it]

epoch number: (1, 5)loss：1.7041966554198438
   corrrect78.3%


training network:  40%|██████████████████████████▊                                        | 2/5 [00:20<00:29,  9.93s/it]

epoch number: (2, 5)loss：1.5538828538170755
   corrrect91.44%


training network:  60%|████████████████████████████████████████▏                          | 3/5 [00:29<00:19,  9.73s/it]

epoch number: (3, 5)loss：1.5346541203923825
   corrrect93.23%


training network:  80%|█████████████████████████████████████████████████████▌             | 4/5 [00:38<00:09,  9.63s/it]

epoch number: (4, 5)loss：1.5218887674783084
   corrrect94.46%


training network: 100%|███████████████████████████████████████████████████████████████████| 5/5 [00:48<00:00,  9.69s/it]

epoch number: (5, 5)loss：1.5130044295589553
   corrrect95.26%



