In [52]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ViT(nn.Module):
    def __init__(
        self,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        dropout=0.1,
    ):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embedding = nn.Parameter(
            # torch.zeros(1, (image_size // patch_size) ** 2 + 1, dim)
            torch.zeros(1, 13 + 1, dim)
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=dim,
                nhead=heads,
                dim_feedforward=mlp_dim,
                dropout=dropout,
                batch_first=True,  # Đặt batch_first=True
            ),
            num_layers=depth,
        )
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):   # x shape (2, 13, 3) = (batch_size, num_patches, dim)
        print(f"self.cls_token: {self.cls_token}")
        print(f"self.pos_embedding: {self.pos_embedding}")
        print(f"x: {x}")
        print(x.shape)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        # cls_tokens = x[:, 0].unsqueeze(1)
        print(cls_tokens.shape)

        one_hot = F.one_hot(torch.argmax(x, dim=2), num_classes=x.size(2)).float()
        print(one_hot)
        x = torch.cat((one_hot, x), dim=2)
        print(f"After concat vs one_hot: {x}")
        x = torch.cat((cls_tokens, x), dim=1)
        print(f"After concat vs cls_token: {x}")
        print(x.shape)
        x += self.pos_embedding
        print(f"After add pos_embedding: {x}")
        print(x.shape)
        x = self.transformer(x)
        print(x.shape)
        x = self.head(x[:, 0])
        print(x.shape)
        print(x)
        return x


# Ví dụ sử dụng
model = ViT(
   num_classes=3, dim=6, depth=3, heads=2, mlp_dim=12
)
x = torch.randn(2, 13, 3)  # Một batch gồm 2 ảnh
output = model(x)

self.cls_token: Parameter containing:
tensor([[[0., 0., 0., 0., 0., 0.]]], requires_grad=True)
self.pos_embedding: Parameter containing:
tensor([[[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]], requires_grad=True)
x: tensor([[[-0.9285, -0.8449, -0.5990],
         [-0.2176,  1.1140,  0.7230],
         [ 0.2371,  0.2902, -0.4879],
         [ 1.0518, -1.4160, -1.5377],
         [-0.5393,  0.7882, -2.0191],
         [-0.8573,  0.2572,  0.9746],
         [-1.3147, -0.3570, -1.3314],
         [-0.8846, -1.3643, -1.0141],
         [-0.6322, -0.7458, -0.2021],
      

In [53]:
import torch

# Giả sử model là một đối tượng mô hình ViT
num_params = sum(p.numel() for p in model.parameters())
print(num_params)

1173


In [49]:
(97*3)*2

582

In [54]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class H39_63_ViT(nn.Module):
    def __init__(
        self,
        num_classes=3,
        dim=6,
        depth=3,
        heads=2,
        mlp_dim=12,
        dropout=0.1,
    ):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, 13 + 1, dim))
        self.transformer = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=dim,
                nhead=heads,
                dim_feedforward=mlp_dim,
                dropout=dropout,
                batch_first=True,  # Đặt batch_first=True
            ),
            num_layers=depth,
        )
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):  # x shape (2, 13, 3) = (batch_size, num_patches, dim)
        # print(f"self.cls_token: {self.cls_token}")
        # print(f"self.pos_embedding: {self.pos_embedding}")
        # print(f"x: {x}")
        # print(x.shape)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        # print(cls_tokens.shape)

        one_hot = F.one_hot(torch.argmax(x, dim=2), num_classes=x.size(2)).float()
        # print(one_hot)
        x = torch.cat((one_hot, x), dim=2)
        # print(f"After concat vs one_hot: {x}")
        x = torch.cat((cls_tokens, x), dim=1)
        # print(f"After concat vs cls_token: {x}")
        # print(x.shape)
        x += self.pos_embedding
        # print(f"After add pos_embedding: {x}")
        # print(x.shape)
        x = self.transformer(x)
        # print(x.shape)
        x = self.head(x[:, 0])
        # print(x.shape)
        # print(x)
        return x


if __name__ == "__main__":
    # Ví dụ sử dụng
    model = ViT(num_classes=3, dim=6, depth=3, heads=2, mlp_dim=12)
    x = torch.randn(2, 13, 3)  # Một batch gồm 2 ảnh
    output = model(x)

    num_params = sum(p.numel() for p in model.parameters())
    print(num_params)


self.cls_token: Parameter containing:
tensor([[[0., 0., 0., 0., 0., 0.]]], requires_grad=True)
self.pos_embedding: Parameter containing:
tensor([[[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]], requires_grad=True)
x: tensor([[[ 1.1078,  0.6078, -0.7931],
         [-0.6969, -1.3807,  1.0727],
         [-0.6976, -1.2132,  0.7139],
         [ 0.6209, -1.0919,  0.7523],
         [ 0.0340,  1.2201, -0.1346],
         [-1.6126, -0.9207, -1.0670],
         [-0.5852,  1.7451,  0.6311],
         [-0.4041, -1.2397,  1.8388],
         [ 1.9605,  0.1487, -0.1732],
      

In [58]:
x = torch.randn(26, 3)
print(x.shape, x)
y = x.view(-1, 13, 3)
print(y.shape, y)

torch.Size([26, 3]) tensor([[ 1.4476,  0.7937, -0.6119],
        [-0.0121,  0.1928,  1.8161],
        [-0.7225, -0.0851,  0.9332],
        [ 1.1289,  0.7718, -0.2430],
        [-0.2996,  0.6422,  0.6020],
        [ 1.2265,  1.2118, -0.7797],
        [ 0.1727,  1.5401,  0.1851],
        [-0.6264, -0.8372,  0.2559],
        [-1.6641, -0.6122,  2.0225],
        [ 0.1420,  0.7673,  0.2306],
        [-1.2340, -0.7469,  0.7317],
        [ 0.0363, -0.5170, -1.2705],
        [-0.7391,  0.2611,  0.9240],
        [-1.6722, -1.3749,  0.1121],
        [ 3.5994, -1.2491,  1.4166],
        [-0.5055,  0.1773, -0.3726],
        [-0.8148, -0.0719, -0.1190],
        [ 1.7486, -2.1217, -0.6415],
        [-0.0672, -1.4461,  1.5898],
        [ 0.3514, -1.6013, -0.4099],
        [ 1.7133, -0.4120, -1.4850],
        [-0.2850,  0.3422,  0.9327],
        [ 1.8040,  0.2909, -0.3727],
        [ 1.0165,  1.2955,  0.0141],
        [ 0.8261, -0.5779, -0.0935],
        [ 0.4408, -0.2276,  0.6525]])
torch.Size([2, 13