In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

# Positional encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    The table is precomputed once for ``max_len`` positions and stored as a
    non-trainable buffer named ``pe`` with shape ``(1, max_len, d_model)``.
    Even feature columns hold sines and odd feature columns hold cosines.

    Args:
        d_model: Size of the feature dimension. Both even and odd sizes are
            supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so the cosine half must be trimmed to d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Leading batch axis of size 1 lets the table broadcast over batches.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        ``x`` is indexed as ``(batch, seq_len, d_model)``; ``seq_len`` must
        not exceed the ``max_len`` given at construction, otherwise the
        addition fails to broadcast.
        """
        return x + self.pe[:, :x.size(1), :]

# Multi-head self-attention
class MultiHeadAttention(nn.Module):
    """Scaled dot-product self-attention split across several heads.

    Queries, keys and values are produced by a single fused linear layer and
    the concatenated head outputs are projected back to ``d_model``.

    Args:
        d_model: Size of the feature dimension of the input and output.
        num_heads: Number of attention heads; must be positive and divide
            ``d_model`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Fused linear projection producing Q, K and V in one pass
        self.qkv_linear = nn.Linear(d_model, d_model * 3)
        self.out_linear = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, seq_len, seq_len)``. Positions where the
                mask equals zero (or ``False``) are excluded from attention.
                A query row that is masked everywhere yields NaN, because the
                softmax is taken over only ``-inf`` scores.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.
        """
        batch_size, seq_length, d_model = x.shape

        # Project to Q, K, V and split into heads:
        # (3, batch, num_heads, seq_len, d_k) after the permute.
        qkv = self.qkv_linear(x).view(batch_size, seq_length, 3, self.num_heads, self.d_k)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)

        # Attention scores, scaled by sqrt(d_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # -inf scores receive zero weight after the softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.softmax(scores)

        # Weight V by the attention distribution, then merge the heads back
        context = torch.matmul(attn, v).permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, d_model)
        return self.out_linear(context)

# Position-wise feed-forward network
class PositionwiseFeedForward(nn.Module):
    """Two linear layers with a ReLU in between.

    The first layer maps ``d_model`` features to ``d_ff``; the second maps
    them back to ``d_model``. Both act on the last dimension of the input.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.fc2 = nn.Linear(in_features=d_ff, out_features=d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Expand, apply the non-linearity, then project back."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

# Transformer encoder layer
class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a feed-forward network.

    Each sub-layer is wrapped in a residual connection, and layer
    normalization is applied after the residual sum.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = PositionwiseFeedForward(d_model, d_ff)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Run both sub-layers on ``x`` and return the normalized result."""
        # Sub-layer 1: self-attention + residual, then normalize
        attended = self.norm1(x + self.attention(x))
        # Sub-layer 2: feed-forward + residual, then normalize
        return self.norm2(attended + self.ffn(attended))

# Transformer encoder
class TransformerEncoder(nn.Module):
    """A stack of ``num_layers`` encoder layers applied one after another."""

    def __init__(self, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(TransformerEncoderLayer(d_model, num_heads, d_ff))

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the final output."""
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden)
        return hidden

# Transformer model
class Transformer(nn.Module):
    """Encoder-only Transformer mapping token ids to per-position vocabulary scores.

    Pipeline: token embedding -> positional encoding -> encoder stack ->
    linear projection onto ``vocab_size`` outputs.
    """

    def __init__(self, d_model=512, num_heads=8, d_ff=2048, num_layers=6, vocab_size=10000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder = TransformerEncoder(d_model, num_heads, d_ff, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return the output projection of the encoded sequence for token ids ``x``."""
        embedded = self.embedding(x)
        encoded = self.encoder(self.positional_encoding(embedded))
        logits = self.fc_out(encoded)
        return logits

# Build the model with its default hyperparameters and run a smoke test
model = Transformer()
# One sequence of 10 random token ids drawn from [0, 10000)
test_input = torch.randint(low=0, high=10000, size=(1, 10))

test_output = model(test_input)
# Expected: torch.Size([1, 10, 10000]) -- one score per vocabulary entry for
# each position (raw linear outputs; no softmax is applied).
print(test_output.shape)


torch.Size([1, 10, 10000])
