# Attention

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Attention 가중치 계산

In [2]:
def attention(query, key, value):
    """Compute unscaled dot-product attention and return the context vectors.

    The scores are the matrix product of ``query`` with ``key`` transposed
    over its last two dimensions. They are normalised with a softmax over the
    last dimension and then used as weights for a weighted sum of ``value``.
    No scaling factor is applied to the scores. The shape of every
    intermediate tensor is printed as it is produced.

    Args:
        query: Query tensor.
        key: Key tensor; its last dimension must match that of ``query``.
        value: Value tensor; its second-to-last dimension must match that of
            ``key``.

    Returns:
        The attention-weighted sum of ``value`` (the context vectors).
    """
    # Step 1: raw attention scores (query . key^T).
    key_transposed = key.transpose(-2, -1)
    scores = query @ key_transposed
    print("Attention Score Shape:", scores.shape)

    # Step 2: softmax over the key axis turns scores into weights.
    attention_weights = scores.softmax(dim=-1)
    print("Attention Weight Shape:", attention_weights.shape)

    # Step 3: weighted sum of the values gives the final context vectors.
    context_vector = attention_weights @ value
    print("Context Vector Shape:", context_vector.shape)

    return context_vector

In [4]:
# Toy vocabulary: every token is mapped to its integer id, assigned in
# listing order starting from 0.
vocab = {
    token: index
    for index, token in enumerate(("나는", "학원에", "간다", "<pad>"))
}

# Number of distinct tokens in the vocabulary.
vocab_size = len(vocab)
# Width of each token embedding vector.
EMBEDDING_DIM = 4

In [6]:
# Look up the id of every token in the sentence, then add a leading
# dimension of size 1 so the ids form a single-row 2-D tensor.
inputs = ["나는", "학원에", "간다"]
inputs_ids = torch.tensor([vocab[word] for word in inputs]).unsqueeze(0)
inputs_ids

tensor([[0, 1, 2]])

In [7]:
# Output width of the query/key/value projections.
HIDDEN_DIM = 4

# 1. Embed the token ids.
embedding_layer = nn.Embedding(vocab_size, EMBEDDING_DIM)
inputs_embedded = embedding_layer(inputs_ids)

# 2. Independent linear projections for query, key and value.
W_query = nn.Linear(EMBEDDING_DIM, HIDDEN_DIM)
W_key = nn.Linear(EMBEDDING_DIM, HIDDEN_DIM)
W_value = nn.Linear(EMBEDDING_DIM, HIDDEN_DIM)

# Apply each projection to the same embedded input.
input_query, input_key, input_value = (
    projection(inputs_embedded) for projection in (W_query, W_key, W_value)
)

input_query.shape, input_key.shape, input_value.shape

(torch.Size([1, 3, 4]), torch.Size([1, 3, 4]), torch.Size([1, 3, 4]))

In [8]:
# Run the attention function on the projected query/key/value tensors.
context_vector = attention(
    query=input_query,
    key=input_key,
    value=input_value,
)
context_vector

Attention Score Shape: torch.Size([1, 3, 3])
Attention Weight Shape: torch.Size([1, 3, 3])
Context Vector Shape: torch.Size([1, 3, 4])


tensor([[[-0.4879,  0.4057,  0.2921,  0.3593],
         [-0.5818,  0.3976,  0.3334,  0.5729],
         [-0.3253,  0.3833,  0.2063,  0.0028]]], grad_fn=<UnsafeViewBackward0>)

### Seq2Seq 모델에 Attention 추가

In [9]:
class Attention(nn.Module):
    """Additive attention over a sequence of encoder outputs.

    A single hidden state is paired with every encoder output, the pair is
    passed through a linear layer followed by ``tanh``, and the result is
    reduced to one score per position by a weighted sum with the learned
    vector ``v``.
    """

    def __init__(self, hidden_size):
        """Create the scoring layer and the learned scoring vector.

        Args:
            hidden_size: Output width of the scoring layer and length of
                ``v``; the scoring layer takes ``hidden_size * 2`` input
                features.
        """
        super().__init__()
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Parameter(torch.rand(hidden_size))

    def forward(self, hidden, encoder_outputs):
        """Return the context vector and the attention weights.

        Args:
            hidden: 2-D tensor ``(batch, features)`` holding one state per
                batch element.
            encoder_outputs: 3-D tensor ``(batch, seq_len, features)``.
                The feature widths of ``hidden`` and ``encoder_outputs``
                must add up to ``hidden_size * 2``.

        Returns:
            A tuple ``(context_vector, attention_weights)`` where
            ``context_vector`` is the weighted sum of ``encoder_outputs``
            over the sequence axis and ``attention_weights`` has shape
            ``(batch, seq_len)`` and sums to 1 along dim 1.
        """
        steps = encoder_outputs.shape[1]

        # Pair the single hidden state with every encoder position.
        tiled_hidden = hidden.unsqueeze(1).expand(-1, steps, -1)
        paired = torch.cat((tiled_hidden, encoder_outputs), dim=2)

        # One scalar score per position: v . tanh(W [hidden; encoder_out]).
        energy = self.attn(paired).tanh()
        attention_scores = (self.v * energy).sum(dim=2)
        attention_weights = F.softmax(attention_scores, dim=1)

        # (batch, 1, seq_len) x (batch, seq_len, features) -> (batch, 1, features)
        weighted = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context_vector = weighted.squeeze(1)
        return context_vector, attention_weights

In [10]:
class Seq2SeqWithAttention(nn.Module):
    """GRU encoder-decoder whose output layer also sees an attention context.

    The encoder GRU reads the source sequence. Its final hidden state is used
    both to initialise the decoder GRU and as the query of the ``Attention``
    module, which summarises the encoder outputs into one context vector per
    batch element. The decoder outputs are concatenated with that context and
    projected to ``output_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Build the encoder, decoder, attention and projection layers.

        Args:
            input_dim: Feature width of both the encoder and decoder inputs.
            hidden_dim: Hidden width of both GRUs and of the attention module.
            output_dim: Feature width of the model output.
        """
        super().__init__()
        self.encoder = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.decoder = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.attention = Attention(hidden_dim)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        # The decoder GRU expects hidden_dim-wide inputs, so raw decoder
        # inputs are projected from input_dim first.
        self.decoder_input_transform = nn.Linear(input_dim, hidden_dim)

    def forward(self, encoder_input, decoder_input):
        """Run one encode/decode pass.

        Args:
            encoder_input: Tensor ``(batch, src_len, input_dim)``.
            decoder_input: Tensor ``(batch, tgt_len, input_dim)``; any
                ``tgt_len >= 1`` is accepted.

        Returns:
            Tensor ``(batch, tgt_len, output_dim)``.
        """
        encoder_outputs, hidden = self.encoder(encoder_input)

        # The context is computed once, from the encoder's final hidden
        # state, and shared by every decoder step.
        context_vector, _ = self.attention(hidden[-1], encoder_outputs)

        decoder_input_ = self.decoder_input_transform(decoder_input)
        output, _ = self.decoder(decoder_input_, hidden)

        # torch.cat needs matching sizes on every non-concatenated dim, so
        # the context is expanded to one copy per decoder step. For a
        # single-step decoder input this is the same tensor as before.
        context = context_vector.unsqueeze(1).expand(-1, output.size(1), -1)
        combined = torch.cat((output, context), dim=2)
        return self.fc(combined)

In [11]:
# Toy dimensions for a single forward pass.
batch_size, seq_len = 1, 5
input_dim, hidden_dim, output_dim = 10, 20, 15

# Random source sequence and a one-step decoder input.
encoder_input = torch.randn(batch_size, seq_len, input_dim)
decoder_input = torch.randn(batch_size, 1, input_dim)

model = Seq2SeqWithAttention(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
)
model(encoder_input, decoder_input)

tensor([[[ 0.0658,  0.0541,  0.1407,  0.0151,  0.0493, -0.1026, -0.0659,
          -0.1211,  0.2446,  0.0917, -0.0483, -0.1916,  0.0835,  0.0582,
           0.1615]]], grad_fn=<ViewBackward0>)