1차 목표
- 인코더 구현

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# import torchvision.datasets as datasets
# import torchvision.transforms as transforms

In [3]:
# # 데이터
# transform = transforms.ToTensor()

# trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)


100%|██████████| 170M/170M [00:05<00:00, 29.0MB/s]


인코더에 필요한 것
- 포지셔널 인코딩
- 멀티헤드 어텐션
- Add & Norm (잔차 연결 후 Layer Normalization)
- Feed forward

## 1. PositionalEncoding
- pos(길이 seq_len)와 i(길이 d_model)는 길이가 달라 바로 연산할 수 없음 → 각각 (seq_len, 1), (1, d_model) 형태의 2차원으로 확장하면 브로드캐스팅으로 (seq_len, d_model) 각도 행렬을 얻음

In [15]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input.

    For position ``pos`` and feature index ``j``::

        PE[pos, j] = sin(pos / 10000 ** (2 * (j // 2) / d_model))   if j is even
        PE[pos, j] = cos(pos / 10000 ** (2 * (j // 2) / d_model))   if j is odd

    The table has no learnable parameters. It is rebuilt on every call, on the
    same device as the input.

    Args:
        d_model: model size used in the frequency denominator; expected to
            equal the last dimension of the input.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

    def forward(self, x):
        """Return ``x + PE``, with the table broadcast over the batch dimension.

        Args:
            x: tensor of shape (batch_size, seq_len, d_model).

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        _, seq_len, d_model = x.shape
        # Build the table where x lives: a CPU-only table would raise a
        # device-mismatch error when x is on an accelerator.
        device = x.device

        # pos: (seq_len, 1) and i: (1, d_model) broadcast to a (seq_len, d_model) grid.
        pos = torch.arange(seq_len, device=device).unsqueeze(1)
        i = torch.arange(d_model, device=device).unsqueeze(0)

        # Features 2k and 2k+1 share one frequency, hence i // 2.
        angle_rates = pos / (10000 ** (2 * (i // 2) / self.d_model))

        pos_emb = torch.zeros(seq_len, d_model, device=device)
        pos_emb[:, 0::2] = torch.sin(angle_rates[:, 0::2])  # even indices
        pos_emb[:, 1::2] = torch.cos(angle_rates[:, 1::2])  # odd indices

        # (1, seq_len, d_model) broadcasts over the batch dimension.
        return x + pos_emb.unsqueeze(0)

In [17]:
# Smoke test: with an all-zero input the output is exactly the positional table.
pe = PositionalEncoding(d_model=512)
x = torch.zeros((2, 10, 512))  # batch_size=2, seq_len=10, d_model=512

out = pe(x)
print(out.shape)  # expected: torch.Size([2, 10, 512])
print(out[0][0][:10])  # first 10 encoding values of the first token

torch.Size([2, 10, 512])
tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])


## 2. MultiHeadAttention

In [19]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head attention.

    Queries/keys are projected to ``d_k`` features in total and values to
    ``d_v``. These are split evenly across ``num_heads`` heads and attended per
    head. The heads are then concatenated and projected back to ``d_model``.

    Args:
        d_model: input/output feature size.
        d_k: total query/key projection size (split across heads).
        d_v: total value projection size (split across heads).
        num_heads: number of heads; must be positive and divide ``d_k`` and ``d_v``.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``d_k`` and ``d_v``. A non-divisible split would otherwise make
            ``view`` silently regroup features across tokens.
    """

    def __init__(self, d_model, d_k, d_v, num_heads):
        super().__init__()

        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if d_k % num_heads or d_v % num_heads:
            raise ValueError(
                f"d_k ({d_k}) and d_v ({d_v}) must be divisible by num_heads ({num_heads})"
            )

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.num_heads = num_heads
        self.d_k_head = d_k // num_heads  # per-head query/key size
        self.d_v_head = d_v // num_heads  # per-head value size

        # Creation order is kept stable so seeded weight init is reproducible.
        self.W_q = nn.Linear(d_model, d_k)
        self.W_k = nn.Linear(d_model, d_k)
        self.W_v = nn.Linear(d_model, d_v)

        self.W_o = nn.Linear(d_v, d_model)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: (batch_size, q_len, d_model).
            key: (batch_size, k_len, d_model).
            value: (batch_size, k_len, d_model).
            mask: optional tensor broadcastable to
                (batch_size, num_heads, q_len, k_len). Positions where it is
                False/0 are excluded from attention. A padding mask is
                typically shaped (batch_size, 1, 1, k_len).

        Returns:
            Tensor of shape (batch_size, q_len, d_model).
        """
        batch_size = query.size(0)

        # Project, split the last dim into heads, then move the head axis to
        # dim 1 so all heads run in parallel through batched matmul:
        # (batch, len, d) -> (batch, len, heads, d_head) -> (batch, heads, len, d_head)
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k_head).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k_head).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_v_head).transpose(1, 2)

        # (batch_size, num_heads, q_len, k_len)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.d_k_head ** 0.5
        if mask is not None:
            # Use the most negative finite value rather than -inf: a fully
            # masked row then softmaxes to a uniform distribution instead of NaN.
            attn_scores = attn_scores.masked_fill(
                mask == 0, torch.finfo(attn_scores.dtype).min
            )
        attn_probs = torch.softmax(attn_scores, dim=-1)

        # (batch_size, num_heads, q_len, d_v_head)
        attn_output = torch.matmul(attn_probs, V)

        # Merge heads back to (batch_size, q_len, d_v). contiguous() is needed
        # because transpose() returns a non-contiguous view that view() cannot reshape.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_v)

        # (batch_size, q_len, d_model)
        return self.W_o(attn_output)


In [21]:
# Smoke test for MultiHeadAttention with small hyperparameters.
batch_size, seq_len = 2, 4
d_model = d_k = d_v = 16
num_heads = 4

# Module under test.
mha = MultiHeadAttention(d_model, d_k, d_v, num_heads)

# Random inputs, drawn in the order query, key, value.
query, key, value = (torch.randn(batch_size, seq_len, d_model) for _ in range(3))

output = mha(query, key, value)

# The output must keep the (batch_size, seq_len, d_model) shape of the inputs.
assert tuple(output.shape) == (batch_size, seq_len, d_model)
print("✅ MultiHeadAttention 통과!")


✅ MultiHeadAttention 통과!


## 3. FFN
- $\mathrm{FFN}(x) = \mathrm{ReLU}(xW_1 + b_1)W_2 + b_2$



In [25]:
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Computes FFN(x) = ReLU(x W1 + b1) W2 + b2 over the last dimension of ``x``.

    Dropout is applied only to the hidden activation. The output is returned
    without dropout because ``ResidualLayerNorm`` already drops out the
    sub-layer output before the residual add. Dropping here as well would
    apply dropout twice on the FFN path.

    Args:
        d_model: input/output feature size.
        ff_dim: hidden feature size.
        dropout: drop probability for the hidden activation.
    """

    def __init__(self, d_model, ff_dim, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, ff_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(ff_dim, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(self.relu(self.linear1(x)))
        return self.linear2(hidden)


## 4. Norm
- $\mathrm{LayerNorm}(x + \mathrm{Dropout}(\mathrm{Sublayer}(x)))$

In [26]:
class ResidualLayerNorm(nn.Module):
    """Post-norm residual wrapper computing LayerNorm(x + Dropout(sublayer_output)).

    Args:
        d_model: size of the last dimension, which LayerNorm normalizes over.
        dropout: drop probability applied to the sub-layer output before the add.
    """

    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        # These attribute names double as state_dict keys, so they stay fixed.
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer_output):
        """Add the (dropped-out) sub-layer output to ``x``, then normalize."""
        summed = x + self.dropout(sublayer_output)
        normalized = self.layer_norm(summed)
        return normalized


## 5. EncoderLayer

In [None]:
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Data flow: self-attention -> add & norm -> feed-forward -> add & norm.

    Args:
        d_model: feature size of the input and output.
        num_heads: number of attention heads.
        ff_dim: hidden size of the feed-forward network.
        dropout: dropout probability for the FFN and both residual wrappers.
    """

    def __init__(self, d_model=512, num_heads=8, ff_dim=2048, dropout=0.1):
        super().__init__()
        # Sub-modules are created in a fixed order (attention, FFN, norms).
        # Reordering would change the seeded weight init and the state_dict layout.

        # Query/key/value projections all use the full d_model width; the
        # attention module splits that width across the heads.
        self.mha = MultiHeadAttention(d_model, d_model, d_model, num_heads)
        self.ffn = FeedForwardNetwork(d_model, ff_dim, dropout)
        self.norm1, self.norm2 = (ResidualLayerNorm(d_model, dropout) for _ in range(2))

    def forward(self, x):
        """Encode ``x`` and return a tensor of the same shape.

        Args:
            x: input tensor, presumably (batch_size, seq_len, d_model) — TODO confirm with callers.
        """
        # Sub-layer 1: self-attention (query = key = value = x), then add & norm.
        attended = self.norm1(x, self.mha(x, x, x))
        # Sub-layer 2: position-wise feed-forward, then add & norm.
        return self.norm2(attended, self.ffn(attended))
