In [1]:
import torch
import torch.nn as nn
import math

In [None]:
# 포지셔널 인코딩 핵심파트

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    PE[pos, 2i]   = sin(pos / 10000^(2i/dim))
    PE[pos, 2i+1] = cos(pos / 10000^(2i/dim))

    Args:
        dim: embedding dimension of the inputs (odd values are supported).
        seq_len_max: longest sequence length the table is precomputed for.
    """

    def __init__(self, dim, seq_len_max):
        # nn.Module must be initialised before register_buffer can be called.
        super(PositionalEncoding, self).__init__()

        PE = torch.zeros(seq_len_max, dim)

        # Column of positions, shape (seq_len_max, 1), so it broadcasts against div_term.
        position = torch.arange(seq_len_max, dtype=torch.float).unsqueeze(1)
        # Important: 1 / 10000^(2i/dim), computed in log space for numerical stability.
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))

        PE[:, 0::2] = torch.sin(position * div_term)
        # When dim is odd there is one fewer cosine column than sine column.
        PE[:, 1::2] = torch.cos(position * div_term[: dim // 2])

        # A buffer (not a Parameter): saved in state_dict and moved by .to(), but not trained.
        # The leading axis of size 1 lets it broadcast over the batch.
        self.register_buffer('PE', PE.unsqueeze(0))

    def forward(self, X):
        """Return X plus the encodings for its first X.size(1) positions.

        Assumes X is (batch, seq_len, dim) with seq_len <= seq_len_max — TODO confirm at call sites.
        """
        return X + self.PE[:, :X.size(1)]

In [None]:
# 멀티헤드 어텐션 핵심파트 
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        dim: model dimension of the inputs and the output.
        head_num: number of attention heads; must divide ``dim``.

    Raises:
        ValueError: if ``dim`` is not divisible by ``head_num``.
    """

    def __init__(self, dim, head_num):
        super(MultiHeadAttention, self).__init__()

        if dim % head_num != 0:
            raise ValueError(f"dim ({dim}) must be divisible by head_num ({head_num})")

        self.dim = dim
        self.head_num = head_num
        self.word_dim = dim // head_num  # per-head dimension (d_k)

        self.W_q = nn.Linear(dim, dim)
        self.W_k = nn.Linear(dim, dim)
        self.W_v = nn.Linear(dim, dim)
        self.W_o = nn.Linear(dim, dim)

    def scaled_dot_product(self, Q, K, V, mask = None):
        """Compute softmax(Q K^T / sqrt(d_k)) V per head.

        Q, K, V are (batch, head_num, seq_len, word_dim) as produced by ``split``.
        ``mask`` is optional and must broadcast to (batch, head_num, len_q, len_k);
        positions where ``mask == 0`` get a score of -1e9 and hence ~0 weight.
        """
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.word_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        probs = nn.functional.softmax(scores, dim = -1)

        # Weighted sum of the values.
        return torch.matmul(probs, V)

    def split(self, X):
        """Reshape (batch, seq_len, dim) -> (batch, head_num, seq_len, word_dim)."""
        batch_size, seq_len, _ = X.size()
        return X.reshape(batch_size, seq_len, self.head_num, self.word_dim).transpose(1, 2)

    def combine(self, X):
        """Inverse of ``split``: (batch, head_num, seq_len, word_dim) -> (batch, seq_len, dim)."""
        batch_size, _, seq_len, _ = X.size()
        return X.transpose(1, 2).reshape(batch_size, seq_len, self.dim)

    def forward(self, X_Q, X_K, X_V, mask = None):
        """Attend from X_Q over (X_K, X_V); returns a tensor shaped like X_Q."""
        Q = self.split(self.W_q(X_Q))
        K = self.split(self.W_k(X_K))
        V = self.split(self.W_v(X_V))

        heads = self.scaled_dot_product(Q, K, V, mask)
        # Merge the heads back together and map them through the output projection.
        return self.W_o(self.combine(heads))

In [4]:
# 트랜스포머 인코더 핵심파트

class FFN(nn.Module):
    """Position-wise feed-forward network: Linear(dim -> FFN_dim), ReLU, Linear(FFN_dim -> dim).

    Args:
        dim: size of the last input/output dimension.
        FFN_dim: width of the hidden layer.
    """

    def __init__(self, dim, FFN_dim):
        super().__init__()

        expand = nn.Linear(dim, FFN_dim)
        project = nn.Linear(FFN_dim, dim)
        self.FFN_layer = nn.Sequential(expand, nn.ReLU(), project)

    def forward(self, X):
        """Apply the two-layer MLP independently to every position of X."""
        out = self.FFN_layer(X)
        return out

class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    X -> LayerNorm(X + Dropout(SelfAttention(X))) -> LayerNorm(. + Dropout(FFN(.)))

    Args:
        dim: model dimension.
        head_num: number of attention heads.
        FFN_dim: hidden width of the feed-forward sub-layer.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, dim, head_num, FFN_dim, dropout):
        super(EncoderLayer, self).__init__()

        self.attention = MultiHeadAttention(dim, head_num)

        self.ffn = FFN(dim, FFN_dim)

        # Dropout is stateless, so a single module is shared by both sub-layers.
        self.dropout = nn.Dropout(dropout)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, X, mask = None):
        """Run X through self-attention and the FFN, each with a residual connection and LayerNorm."""
        # 1. Multi-head self-attention: Q, K and V all come from the same X.
        attn_output = self.attention(X, X, X, mask)
        attn_output = self.dropout(attn_output)
        X = self.norm1(X + attn_output)  # residual connection

        # 2. Feed-forward sub-layer (uses this layer's own FFN submodule).
        ffn_output = self.ffn(X)
        ffn_output = self.dropout(ffn_output)
        return self.norm2(X + ffn_output)

In [None]:
# 트랜스포머 디코더 핵심파트

class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    Each sub-layer is wrapped as LayerNorm(residual + Dropout(sublayer_output)):
      1. self-attention over the decoder input, gated by ``self_attn_mask``
         (presumably the causal / look-ahead mask — confirm at the call site);
      2. cross-attention: queries from the decoder, keys/values from ``enc_output``,
         gated by ``cross_attn_mask``;
      3. position-wise feed-forward network.

    Args:
        dim: model dimension.
        head_num: number of attention heads.
        FFN_dim: hidden width of the feed-forward sub-layer.
        dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, dim, head_num, FFN_dim, dropout):
        super().__init__()

        self.self_attention = MultiHeadAttention(dim, head_num)
        self.cross_attention = MultiHeadAttention(dim, head_num)
        self.ffn = FFN(dim, FFN_dim)

        # One LayerNorm per sub-layer, registered in order norm1, norm2, norm3.
        self.norm1, self.norm2, self.norm3 = [nn.LayerNorm(dim) for _ in range(3)]

        self.dropout = nn.Dropout(dropout)

    def forward(self, X, enc_output, cross_attn_mask, self_attn_mask):
        """Decode X while attending to ``enc_output``; the result has the same shape as X."""

        def add_and_norm(norm, residual, sublayer_out):
            # Dropout on the sub-layer output, residual add, then LayerNorm.
            return norm(residual + self.dropout(sublayer_out))

        X = add_and_norm(self.norm1, X, self.self_attention(X, X, X, self_attn_mask))
        X = add_and_norm(
            self.norm2, X,
            self.cross_attention(X, enc_output, enc_output, cross_attn_mask),
        )
        return add_and_norm(self.norm3, X, self.ffn(X))

# ViT

In [6]:
# 패치 임베딩

class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and linearly embed each one.

    A Conv2d whose kernel size and stride both equal ``patch_size`` is equivalent
    to applying one shared linear projection to every patch.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of each patch.
        in_chans: number of input channels.
        embed_dim: size of each patch embedding.
    """

    def __init__(self, img_size = 224, patch_size = 16, in_chans = 3, embed_dim = 768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size

        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size = patch_size, stride = patch_size)

    def forward(self, x):
        """Map (batch, channels, height, width) images to (batch, num_patches, embed_dim) tokens."""
        # Unpacking enforces a 4-D (batch, channel, height, width) input.
        _batch, _chans, _height, _width = x.shape

        feature_map = self.proj(x)                  # (B, E, H // p, W // p)
        tokens = feature_map.flatten(start_dim=2)   # (B, E, N)
        return tokens.transpose(1, 2)               # (B, N, E)

In [7]:
# 어텐션
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection (ViT style).

    Args:
        dim: token embedding size; split evenly across ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        # 1 / sqrt(head_dim) scaling of the dot products.
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)  # one projection producing q, k and v
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Self-attend over the tokens of x, shape (batch, tokens, channels)."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim), then peel off q, k, v.
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        weights = torch.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
        mixed = weights @ v                                   # (B, heads, N, head_dim)
        merged = mixed.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj(merged)

In [None]:
# MLP
class Mlp(nn.Module):
    """Two-layer MLP used inside each ViT block: fc1 -> activation -> fc2.

    Args:
        in_features: size of the last input dimension.
        hidden_features: width of the hidden layer; defaults to ``in_features``.
        out_features: output size; defaults to ``in_features`` so the output can be
            added back onto the input as a residual.
        act_layer: activation class, instantiated with no arguments (nn.GELU by default).
    """

    def __init__(self, in_features, hidden_features = None, out_features = None, act_layer = nn.GELU):
        # super() must be called, not accessed as the bare `super` type.
        super().__init__()

        # Both widths fall back to in_features. Falling back to hidden_features for the
        # output would make the block's residual add fail with a shape mismatch.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Apply fc1, the activation, then fc2 over the last dimension of x."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

In [9]:
# Transformer Block
class Block(nn.Module):
    """Pre-norm Transformer block used by the ViT.

    x -> x + Attention(norm1(x)) -> x + Mlp(norm2(x))

    Args:
        dim: token embedding size.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        act_layer: activation class for the MLP.
        norm_layer: normalisation class, called as ``norm_layer(dim)``.
    """

    def __init__(self, dim, num_heads, mlp_ratio = 4., act_layer = nn.GELU, norm_layer = nn.LayerNorm):
        # super() must be called, not accessed as the bare `super` type.
        super().__init__()

        self.norm1 = norm_layer(dim)

        self.attn = Attention(dim, num_heads=num_heads)

        self.norm2 = norm_layer(dim)

        mlp_hidden_dim = int(dim * mlp_ratio)
        # Pass out_features explicitly so the MLP output always matches x for the residual add.
        self.mlp = Mlp(in_features = dim, hidden_features = mlp_hidden_dim,
                       out_features = dim, act_layer = act_layer)

    def forward(self, x):
        """Apply attention and MLP sub-layers, each normalised first and added residually."""
        # 1. Residual connection around self-attention.
        x = x + self.attn(self.norm1(x))

        # 2. Residual connection around the MLP.
        x = x + self.mlp(self.norm2(x))

        return x

In [None]:
class VisionTransformer(nn.Module):
    """ViT classifier.

    Pipeline: patchify, prepend a learnable [CLS] token, add learnable positional
    embeddings, run ``depth`` Transformer blocks, apply a final norm, and classify
    from the [CLS] token.

    With the defaults (img_size=28, patch_size=4, in_chans=1) the input is split into
    7 x 7 = 49 patches, giving 50 tokens including [CLS].

    Args:
        img_size, patch_size, in_chans: forwarded to PatchEmbed.
        num_classes: number of output logits; 0 or less replaces the head with nn.Identity.
        embed_dim: token embedding size.
        depth: number of Transformer blocks.
        num_heads: attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        norm_layer: normalisation class used in the blocks and the final norm.
    """

    def __init__(self, img_size=28, patch_size=4, in_chans=1, num_classes=10, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_features = embed_dim
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.depth = depth

        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
                                      in_chans=in_chans, embed_dim=embed_dim)
        token_count = self.patch_embed.num_patches + 1  # all patches plus the [CLS] token

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One learnable position vector per token, including [CLS].
        self.pos_embed = nn.Parameter(torch.zeros(1, token_count, embed_dim))

        self.blocks = nn.ModuleList(
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, norm_layer=norm_layer)
            for _ in range(depth)
        )

        self.norm = norm_layer(embed_dim)

        if num_classes > 0:
            self.head = nn.Linear(embed_dim, num_classes)
        else:
            self.head = nn.Identity()

    def forward(self, x):
        """Return class logits (or [CLS] features when the head is Identity) for images x."""
        tokens = self.patch_embed(x)

        # The learnable [CLS] token is shared across the batch; expand creates a view, not a copy.
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls, tokens), dim=1) + self.pos_embed

        for block in self.blocks:
            tokens = block(tokens)

        tokens = self.norm(tokens)
        return self.head(tokens[:, 0])