In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F

In [3]:
# All tensors below come from torch.randn; seeding here makes
# "Restart & Run All" reproduce the same numbers every time.
torch.manual_seed(0)

B = 1   # batch size
E = 30  # word-embedding size
T = 5   # input sentence length (chosen arbitrarily)
H = 50  # hidden-state size

In [20]:
# torch.autograd.Variable is deprecated (PyTorch >= 0.4): plain tensors carry autograd state.
# Random stand-in for an embedded input sentence, shaped (B, T, E) because the GRU is batch_first.
inputs = torch.randn(B, T, E)
# Initial GRU hidden state, shaped (1, B, H) for a single-layer, unidirectional GRU.
hidden = torch.zeros(1, B, H)

In [21]:
gru = nn.GRU(input_size=E, hidden_size=H, batch_first=True)  # encoder RNN: E-dim inputs -> H-dim states

In [49]:
# Run the encoder. The final hidden state gets its own name instead of overwriting
# `hidden`, so re-running this cell always starts from the zero initial state.
encoder_hiddens, encoder_last_hidden = gru(inputs, hidden)

In [23]:
encoder_hiddens.shape  # per-step encoder outputs, (B, T, H)

torch.Size([1, 5, 50])

In [50]:
# Random stand-in for the decoder state s_t, shaped (1, B, H); no deprecated Variable wrapper needed.
decoder_hidden = torch.randn(1, B, H)

In [51]:
decoder_hidden.shape  # (1, B, H)

torch.Size([1, 1, 50])

## Attention

일단 배치 차원을 생략하고(B = 1) 2차원 텐서로 구해본다.

In [52]:
# Remove the batch dimension (B == 1) so the attention steps below work on 2-D tensors.
encoder_hiddens = encoder_hiddens.squeeze(dim=0)  # (1, T, H) -> (T, H)
decoder_hidden = decoder_hidden.squeeze(dim=1)    # (1, 1, H) -> (1, H)

print(encoder_hiddens.shape, decoder_hidden.shape)

torch.Size([5, 50]) torch.Size([1, 50])


### 1. dot product 

$$e_{ti} = s_t^Th_i$$

In [53]:
# e_ti = s_t^T h_i, computed one encoder step at a time to mirror the formula.
# In current PyTorch, Tensor.dot returns a 0-dim tensor, and torch.cat refuses
# to concatenate 0-dim tensors, so the per-step scores are joined with torch.stack.
step_scores = []
for h_i in encoder_hiddens:  # iterates over the T rows of the (T, H) tensor
    step_scores.append(h_i.dot(decoder_hidden[0]))

scores = torch.stack(step_scores)  # (T,)

In [54]:
scores  # raw attention scores e_ti, one per encoder step

Variable containing:
-3.0896
-2.7235
-1.1030
-0.2602
 0.2127
[torch.FloatTensor of size 5]

In [55]:
# The same scores as a single matrix product: (T, H) @ (H, 1) -> (T, 1).
scores = encoder_hiddens @ decoder_hidden.t()
scores

Variable containing:
-3.0896
-2.7235
-1.1030
-0.2602
 0.2127
[torch.FloatTensor of size 5x1]

$$\alpha_{ti}^e=\frac{\exp(e_{ti})}{\sum_{j=1}^{n} \exp(e_{tj})}$$

In [56]:
attn_dist = F.softmax(scores, dim=0)  # normalise over the T encoder steps (dim 0)
print(attn_dist.sum())  # sanity check: the weights sum to 1

Variable containing:
 1
[torch.FloatTensor of size 1]



$$c_t^e = \sum_{i=1}^{n} \alpha_{ti}^e h_i^e$$

In [57]:
# c_t = sum_i alpha_ti * h_i as one matrix product: (1, T) @ (T, H) -> (1, H).
# The weights must be the softmax-normalised attn_dist, not the raw scores;
# otherwise the result is not a convex combination of the encoder states.
context_vector = torch.matmul(attn_dist.transpose(0, 1), encoder_hiddens)
print(context_vector.size())

torch.Size([1, 50])


## TODO: general 방식의 Attention 구현해 보기

1. $e_{ti} = s_t^T W_{attn}^e h_i$ — attention score
2. $\alpha_{ti}^e=\frac{\exp(e_{ti})}{\sum_{j=1}^{n} \exp(e_{tj})}$ — attention distribution
3. $c_t^e = \sum_{i=1}^{n} \alpha_{ti}^e h_i^e$ — context vector

In [48]:
# Random 2-D stand-ins for the TODO above: encoder states (T, H) and decoder state (1, H).
# Sizes come from the config constants rather than hard-coded 5 / 50.
encoder_hiddens = torch.randn(T, H)
decoder_hidden = torch.randn(1, H)

## Attention module 

In [58]:
from attention import Attention

In [66]:
# Larger sizes for exercising the batched Attention module.
B, T, H = 32, 10, 50

In [67]:
# Hidden size comes from H instead of a hard-coded 50, so changing H above keeps the module in sync.
attn = Attention(H, method='general')

In [74]:
# Batched random stand-ins; plain tensors replace the deprecated Variable wrapper.
encoder_hiddens = torch.randn(B, T, H)  # (B, T, H)
decoder_hidden = torch.randn(B, 1, H)   # (B, 1, H)

In [75]:
context_vector = attn(decoder_hidden, encoder_hiddens)  # decoder state first, then encoder states

In [76]:
context_vector.shape

torch.Size([32, 1, 50])