In [1]:
# %% Import Modules
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
# %% Dataset Preparation
vocab_size = 20
batch_size = 32
seq_len = 10
data_loader = [(torch.randint(0, vocab_size, (batch_size, seq_len)),
                torch.randint(0, vocab_size, (batch_size, seq_len))) for _ in range(100)]


In [13]:
# %% Transformer Model
class Transformer(nn.Module):
    def __init__(self, input_dim, output_dim, d_model, n_head, d_ff, num_layers, dropout):
        super(Transformer, self).__init__()
        self.src_embedding = nn.Embedding(input_dim, d_model)
        self.tgt_embedding = nn.Embedding(output_dim, d_model)
        self.positional_encoding = PositionalEncoding(d_model)

        self.encoder_layers = nn.ModuleList([EncoderBlock(d_model, n_head, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderBlock(d_model, n_head, d_ff, dropout) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_model, output_dim)

    def forward(self, src, tgt, tgt_mask):
        src = self.src_embedding(src)
        src = self.positional_encoding(src)
        for layer in self.encoder_layers:
            src = layer(src, None)

        tgt = self.tgt_embedding(tgt)
        tgt = self.positional_encoding(tgt)
        for layer in self.decoder_layers:
            tgt = layer(tgt, src, tgt_mask)

        return self.fc_out(tgt)
    
# %% Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

# %% Encoder Block
class EncoderBlock(nn.Module):
    def __init__(self, d_model, n_head, d_ff, dropout):
        super(EncoderBlock, self).__init__()
        self.attn = MultiHeadAttention(d_model, n_head)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        x = self.norm1(x + self.dropout(self.attn(x, x, x, mask)))
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x
    
# %% Decoder Block
class DecoderBlock(nn.Module):
    def __init__(self, d_model, n_head, d_ff, dropout):
        super(DecoderBlock, self).__init__()

        self.self_attn = MultiHeadAttention(d_model, n_head)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        self.cross_attn = MultiHeadAttention(d_model, n_head)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask):
        tgt = self.norm1(tgt + self.dropout1(self.self_attn(tgt, tgt, tgt, tgt_mask)))
        tgt = self.norm2(tgt + self.dropout2(self.cross_attn(tgt, memory, memory)))
        tgt = self.norm3(tgt + self.dropout3(self.ffn(tgt)))
        return tgt
    
# %% Multi-Head Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.attention = ScaleDotProductAttention()
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_concat = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
        q, k, v = self.split(q), self.split(k), self.split(v)
        out, _ = self.attention(q, k, v, mask)
        out = self.concat(out)
        out = self.w_concat(out)
        return out

    def split(self, tensor):
        batch_size, length, d_model = tensor.size()
        d_tensor = d_model // self.n_head
        tensor = tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2)
        return tensor

    def concat(self, tensor):
        batch_size, head, length, d_tensor = tensor.size()
        d_model = head * d_tensor
        tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)
        return tensor

# %% Scaled Dot-Product Attention
class ScaleDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)
        score = self.softmax(score)
        v = torch.matmul(score, v)
        return v, score

# %% Feed Forward Network
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))
    


In [14]:
d_model = 512
n_head = 8
d_ff = 2048
dropout = 0.1
num_layers = 6

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Transformer(vocab_size, vocab_size, d_model, n_head, d_ff, num_layers, dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [15]:
# %% Train Model Function
def train_model(model, data_loader, optimizer, criterion, device, num_epochs, seq_len):

    model.train()
    total_loss = 0

    for src, tgt in data_loader:
        src, tgt = src.to(device), tgt.to(device)
        mask = torch.tril(torch.ones((seq_len, seq_len), device=device))
        optimizer.zero_grad()
        output = model(src, tgt, mask)
        loss = criterion(output.view(-1, output.size(-1)), tgt.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(data_loader)

In [16]:
# %% Evaluate Model Function
def evaluate_model(model, data_loader, criterion, device, seq_len):

    model.eval()
    total_loss = 0

    with torch.no_grad():
        for src, tgt in data_loader:
            src, tgt = src.to(device), tgt.to(device)
            mask = torch.tril(torch.ones((seq_len, seq_len), device=device))
            output = model(src, tgt, mask)
            loss = criterion(output.view(-1, output.size(-1)), tgt.view(-1))
            total_loss += loss.item()

    return total_loss / len(data_loader)

In [17]:
# %% 실행 코드
num_epochs = 5
seq_len = 10

for epoch in range(num_epochs):
    # 훈련 실행
    train_loss = train_model(model, data_loader, optimizer, criterion, device, num_epochs, seq_len)

    # 평가 실행
    eval_loss = evaluate_model(model, data_loader, criterion, device, seq_len)

    # 결과 출력
    print(f"Epoch {epoch + 1}, Train Loss: {train_loss:.4f}, Eval Loss: {eval_loss:.4f}")

Epoch 1, Train Loss: 0.3015, Eval Loss: 0.0003
Epoch 2, Train Loss: 0.0004, Eval Loss: 0.0002
Epoch 3, Train Loss: 0.0002, Eval Loss: 0.0001
Epoch 4, Train Loss: 0.0002, Eval Loss: 0.0001
Epoch 5, Train Loss: 0.0001, Eval Loss: 0.0001
