In [1]:
import torch
import torch.nn as nn

In [2]:
from math import sqrt


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input ``x``.

    Args:
        input_dim: Size of the last dimension of ``x`` (the embedding width).
        d_model: Total width of the query/key/value projections; it is split
            evenly across the heads.
        n_heads: Number of attention heads. Must be positive and divide
            ``d_model`` exactly.

    Raises:
        ValueError: If ``n_heads`` is not positive or ``d_model`` is not a
            multiple of ``n_heads``.
    """

    def __init__(self, input_dim, d_model, n_heads):
        super().__init__()

        # Validate up front: a non-divisible split would otherwise only
        # surface later as an opaque shape error in ``forward``'s ``view``.
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"n_heads ({n_heads})"
            )

        self.input_dim = input_dim
        self.d_prime_model = d_model // n_heads  # per-head width D'
        self.n_heads = n_heads

        self.wq = nn.Linear(input_dim, d_model)  # query projection
        self.wk = nn.Linear(input_dim, d_model)  # key projection
        self.wv = nn.Linear(input_dim, d_model)  # value projection
        self.dense = nn.Linear(d_model, d_model)  # output projection after the heads are merged
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, t):
        """Reshape (B, S, D) to (B, H, S, D') so every head owns its own S x D' matrix."""
        return t.view(t.shape[0], t.shape[1], self.n_heads, self.d_prime_model).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (B, S, input_dim).
            mask: Optional tensor holding 1 at positions to hide and 0
                elsewhere. A 3-D (B, S, S) mask gets a head axis inserted;
                any other rank must already broadcast against (B, H, S, S).

        Returns:
            Tensor of shape (B, S, d_model).
        """
        q = self._split_heads(self.wq(x))
        k = self._split_heads(self.wk(x))
        v = self._split_heads(self.wv(x))

        # (B, H, S, D') @ (B, H, D', S) -> (B, H, S, S): per-head similarity scores,
        # scaled by sqrt(D').
        score = torch.matmul(q, k.transpose(-1, -2))
        score = score / sqrt(self.d_prime_model)

        if mask is not None:
            if mask.dim() == score.dim() - 1:
                # (B, S, S) -> (B, 1, S, S) so the mask broadcasts over the heads.
                mask = mask.unsqueeze(1)
            # Masked positions carry 1; adding a large negative value drives
            # their softmax weight to ~0.
            score = score + (mask * -1e9)

        score = self.softmax(score)  # (B, H, S, S)
        result = torch.matmul(score, v)  # (B, H, S, D')
        result = result.transpose(1, 2).contiguous()  # (B, S, H, D')
        result = result.view(result.shape[0], result.shape[1], -1)  # (B, S, D)
        return self.dense(result)

  except round(d_model/n_heads,0) % n_heads is not 0:


In [3]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention and a feed-forward network, each
    followed by dropout, a residual connection and layer normalization.

    Args:
        input_dim: Width of the tensor fed to the attention projections.
            Because of the residual ``x1 + x`` and ``LayerNorm(d_model)``,
            the input must in practice be ``d_model`` wide.
        d_model: Width of the attention output and of the layer output.
        dff: Hidden width of the feed-forward network.
        dropout: Dropout probability applied after each sub-layer.
        n_heads: Number of attention heads (defaults to 4, the value that
            used to be hard-coded).
    """

    def __init__(self, input_dim, d_model, dff, dropout, n_heads=4):
        super().__init__()

        self.input_dim = input_dim
        self.d_model = d_model
        self.dff = dff

        self.MHA = MultiHeadAttention(input_dim, d_model, n_heads)

        self.ffn = nn.Sequential(
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model),
        )

        # Dropout against overfitting, one per sub-layer.
        self.dropout1 = nn.Dropout(p=dropout)
        self.dropout2 = nn.Dropout(p=dropout)

        # Layer normalization, one per sub-layer.
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded unchanged to the attention module."""
        # Sub-layer 1: multi-head attention -> dropout -> residual + layer norm.
        x1 = self.MHA(x, mask)
        x1 = self.dropout1(x1)
        x1 = self.layer_norm1(x1 + x)

        # Sub-layer 2: feed-forward -> dropout -> residual + layer norm.
        x2 = self.ffn(x1)
        x2 = self.dropout2(x2)
        x2 = self.layer_norm2(x2 + x1)

        return x2  # final encoder-layer output

In [9]:
class DecoderLayer(nn.Module):
    """One decoder block: (masked) self-attention and a feed-forward network,
    each followed by dropout, a residual connection and layer normalization.

    NOTE(review): this layer only attends over its own input; it has no
    attention over an encoder output.

    Args:
        input_dim: Width of the tensor fed to the attention projections.
            Because of the residual ``x1 + x`` and ``LayerNorm(d_model)``,
            the input must in practice be ``d_model`` wide.
        d_model: Width of the attention output and of the layer output.
        dff: Hidden width of the feed-forward network.
        dropout: Dropout probability applied after each sub-layer.
        n_heads: Number of attention heads (defaults to 4, the value that
            used to be hard-coded).
    """

    def __init__(self, input_dim, d_model, dff, dropout, n_heads=4):
        super().__init__()

        self.input_dim = input_dim
        self.d_model = d_model
        self.dff = dff

        # Multi-head self-attention
        self.MHA = MultiHeadAttention(input_dim, d_model, n_heads=n_heads)

        # Position-wise feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model),
        )

        # One dropout and one layer norm per sub-layer
        self.dropout1 = nn.Dropout(p=dropout)
        self.dropout2 = nn.Dropout(p=dropout)
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded unchanged to the attention module."""
        # Sub-layer 1: self-attention -> dropout -> residual + layer norm.
        x1 = self.MHA(x, mask)
        x1 = self.dropout1(x1)
        x1 = self.layer_norm1(x1 + x)

        # Sub-layer 2: feed-forward -> dropout -> residual + layer norm.
        x2 = self.ffn(x1)
        x2 = self.dropout2(x2)
        x2 = self.layer_norm2(x2 + x1)

        return x2  # final decoder-layer output

In [10]:
class Encoder(nn.Module):
    """Stack of ``EncoderLayer`` blocks on top of token + learned position embeddings.

    Args:
        num_layers: Number of encoder layers.
        input_dim: Number of rows of the token embedding table, i.e. the
            exclusive upper bound on source token ids.
        d_model: Embedding width and width of every layer's input/output.
        dff: Hidden width of each layer's feed-forward network.
        dropout: Dropout probability used inside each layer.
        max_len: Number of positions in the learned positional table; longer
            sequences cannot be embedded.
    """

    def __init__(self, num_layers, input_dim, d_model, dff, dropout=0.1, max_len=1000):
        super(Encoder, self).__init__()

        self.positional_encoding = nn.Embedding(max_len, d_model)  # learned positions
        self.embedding = nn.Embedding(input_dim, d_model)

        # The layers consume the d_model-wide embeddings, so their input width
        # is d_model. Passing ``input_dim`` (a vocabulary size) here only
        # worked when it happened to equal d_model.
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model, d_model, dff, dropout) for _ in range(num_layers)]
        )

    def forward(self, x, mask=None):
        """Encode token ids ``x`` of shape (batch_size, seq_len).

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        # Add the positional embedding for positions 0 .. seq_len-1
        # (broadcast over the batch).
        positions = torch.arange(0, x.size(1), device=x.device).unsqueeze(0)
        x = self.embedding(x) + self.positional_encoding(positions)

        # Pass through the encoder layers in order.
        for layer in self.layers:
            x = layer(x, mask)

        return x  # (batch_size, seq_len, d_model)

In [11]:
class Decoder(nn.Module):
    """Stack of ``DecoderLayer`` blocks followed by a vocabulary projection and softmax.

    Args:
        num_layers: Number of decoder layers.
        input_dim: Kept for interface compatibility; the layers are built
            with ``d_model`` as their input width (see ``__init__``).
        d_model: Embedding width and width of every layer's input/output.
        dff: Hidden width of each layer's feed-forward network.
        vocab_size: Target vocabulary size (embedding rows and output width).
        dropout: Dropout probability used inside each layer.
        max_len: Number of positions in the learned positional table; longer
            sequences cannot be embedded.
    """

    def __init__(self, num_layers, input_dim, d_model, dff, vocab_size, dropout=0.1, max_len=1000):
        super(Decoder, self).__init__()

        self.positional_encoding = nn.Embedding(max_len, d_model)  # learned positions
        self.embedding = nn.Embedding(vocab_size, d_model)

        # The layers consume the d_model-wide embeddings, so their input width
        # is d_model. Passing ``input_dim`` here only worked when it happened
        # to equal d_model.
        self.layers = nn.ModuleList(
            [DecoderLayer(d_model, d_model, dff, dropout) for _ in range(num_layers)]
        )

        # Projection from d_model to vocabulary scores, then normalization.
        self.linear = nn.Linear(d_model, vocab_size)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, tgt, memory, tgt_mask=None):
        """Decode target token ids ``tgt`` of shape (batch_size, seq_len).

        Args:
            tgt: Target token ids.
            memory: Encoder output. NOTE(review): accepted but never read in
                this method, so the result does not depend on it.
            tgt_mask: Optional mask forwarded to every layer.

        Returns:
            Probabilities over the vocabulary. These are already softmaxed,
            not raw logits.
        """
        # Add the positional embedding for positions 0 .. seq_len-1
        # (broadcast over the batch).
        positions = torch.arange(0, tgt.size(1), device=tgt.device).unsqueeze(0)
        tgt = self.embedding(tgt) + self.positional_encoding(positions)

        # Pass through the decoder layers with the target mask.
        for layer in self.layers:
            tgt = layer(tgt, tgt_mask)

        # Project to vocabulary size, then convert to probabilities.
        logits = self.linear(tgt)
        probs = self.softmax(logits)

        return probs

In [12]:
import numpy as np

def get_angles(pos, i, d_model):
    """Return ``pos`` scaled by the per-dimension rate ``1 / 10000 ** (2 * (i // 2) / d_model)``.

    ``pos`` and ``i`` may be scalars or NumPy arrays; the usual broadcasting
    rules apply to the product.
    """
    # Dimensions 2k and 2k+1 share one exponent because of the floor division.
    exponent = (2 * (i // 2)) / np.float32(d_model)
    return pos * (1 / np.power(10000, exponent))

def positional_encoding(position, d_model):
    """Build a sine/cosine position table as a float tensor of shape (1, position, d_model).

    Even columns hold the sine of the angle from ``get_angles`` and odd
    columns hold the cosine.
    """
    row_positions = np.arange(position)[:, np.newaxis]
    col_indices = np.arange(d_model)[np.newaxis, :]
    table = get_angles(row_positions, col_indices, d_model)

    # Overwrite the raw angles in place: sine on even columns, cosine on odd ones.
    table[:, 0::2] = np.sin(table[:, 0::2])
    table[:, 1::2] = np.cos(table[:, 1::2])

    # Prepend a leading axis of size 1 before converting to a tensor.
    return torch.FloatTensor(table[np.newaxis, ...])


In [14]:
class Transformer(nn.Module):
    """Encoder-decoder wrapper that runs ``Encoder`` on the source and ``Decoder`` on the target.

    The same ``num_layers``, ``input_dim``, ``d_model``, ``dff`` and
    ``dropout`` are handed to both halves; ``vocab_size`` goes to the
    decoder only.
    """

    def __init__(self, num_layers, input_dim, d_model, dff, vocab_size, dropout=0.1):
        super().__init__()

        self.encoder = Encoder(
            num_layers=num_layers,
            input_dim=input_dim,
            d_model=d_model,
            dff=dff,
            dropout=dropout,
        )
        self.decoder = Decoder(
            num_layers=num_layers,
            input_dim=input_dim,
            d_model=d_model,
            dff=dff,
            vocab_size=vocab_size,
            dropout=dropout,
        )

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src``, then return the decoder's output for ``tgt``."""
        encoded = self.encoder(src, src_mask)  # encoder output, handed to the decoder as memory
        return self.decoder(tgt, encoded, tgt_mask)


# Model smoke test: build a Transformer and run one forward pass on random token ids.
# NOTE(review): no random seed is set, so the printed values differ from run to run.
batch_size = 32
seq_len = 10
input_dim = 512  # passed to Transformer as input_dim; also the upper bound for source token ids below
d_model = 512  # model width
dff = 2048  # feed-forward hidden width
vocab_size = 10000  # vocabulary size; also the upper bound for target token ids below
dropout = 0.1  # dropout probability

# Example inputs
src = torch.randint(0, input_dim, (batch_size, seq_len))  # (batch_size, seq_len) source sequence
tgt = torch.randint(0, vocab_size, (batch_size, seq_len))  # (batch_size, seq_len) target sequence

# Build the Transformer model
transformer = Transformer(num_layers=6, input_dim=input_dim, d_model=d_model, dff=dff, vocab_size=vocab_size, dropout=dropout)

# Forward pass
output = transformer(src, tgt)

print(output.shape)  # (batch_size, seq_len, vocab_size)
# NOTE(review): this prints the whole (32, 10, 10000) tensor repr; a slice or summary statistic would keep the output short.
print(output)

torch.Size([32, 10, 10000])
tensor([[[1.2259e-04, 1.9482e-04, 1.0767e-04,  ..., 1.7829e-04,
          1.3503e-04, 7.0349e-05],
         [6.9761e-05, 1.8570e-04, 5.8900e-05,  ..., 1.1058e-04,
          4.9348e-05, 5.6825e-05],
         [6.0946e-05, 1.1298e-04, 9.6323e-05,  ..., 7.2866e-05,
          7.6753e-05, 8.9623e-05],
         ...,
         [1.2898e-04, 2.4305e-04, 6.3202e-05,  ..., 1.8230e-04,
          4.3429e-05, 7.0600e-05],
         [2.4772e-05, 2.2185e-04, 1.4624e-04,  ..., 6.2857e-05,
          4.3389e-05, 1.2707e-04],
         [4.5139e-05, 5.7067e-05, 3.6142e-05,  ..., 6.3688e-05,
          9.1383e-05, 3.9680e-05]],

        [[1.3846e-04, 8.4230e-05, 3.9005e-05,  ..., 8.8730e-05,
          1.2755e-04, 3.7343e-05],
         [2.9293e-05, 9.9224e-05, 9.8912e-05,  ..., 8.0741e-05,
          1.0988e-04, 7.2386e-05],
         [1.7453e-05, 1.1815e-04, 2.4781e-04,  ..., 4.4950e-05,
          6.0875e-05, 1.2215e-04],
         ...,
         [1.0958e-04, 1.0414e-04, 7.3630e-05,  ...,