In [1]:
import torch
from torch import nn
import torch.nn.functional as F

class RNN(nn.Module):
    """Bidirectional RNN text classifier over frozen pretrained embeddings.

    Tokens are embedded with ``vocabulary.vectors`` (kept frozen), run through a
    single-layer bidirectional ``nn.RNN``, and the final hidden states of the two
    directions are averaged. The averaged state is re-weighted by a learned,
    softmax-normalised gate (``compute_attention``) and projected to
    ``output_dim`` logits.

    Args:
        vocabulary: object exposing ``.vectors``, a 2-D tensor of pretrained
            embeddings with one row per token id (presumably a torchtext
            ``Vocab`` -- confirm against the caller).
        embedding_dim: width of each embedding vector; must equal
            ``vocabulary.vectors.shape[1]``.
        hidden_dim: hidden size of each RNN direction.
        output_dim: number of output logits.

    Raises:
        ValueError: if ``embedding_dim`` does not match the width of
            ``vocabulary.vectors``.
    """

    def __init__(self, vocabulary, embedding_dim=100, hidden_dim=100, output_dim=2):
        super().__init__()
        # freeze=True keeps the pretrained vectors fixed during training
        # (equivalent to setting weight.requires_grad = False afterwards).
        self.embedding = nn.Embedding.from_pretrained(vocabulary.vectors, freeze=True)
        if self.embedding.embedding_dim != embedding_dim:
            # Fail fast here instead of with an opaque shape error inside nn.RNN.
            raise ValueError(
                f"embedding_dim={embedding_dim} does not match the width of "
                f"vocabulary.vectors ({self.embedding.embedding_dim})"
            )
        self.rnn = nn.RNN(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)
        # forward() averages the two directions, so the head sees hidden_dim features.
        self.out = nn.Linear(hidden_dim, output_dim)
        self.hidden = hidden_dim  # not read within this class; kept as a public attribute
        self.att = nn.Linear(hidden_dim, hidden_dim)

    def compute_attention(self, x):
        """Gate ``x`` with learned weights normalised by softmax over dim 1.

        For the 2-D (batch, hidden_dim) tensor that ``forward`` passes in, the
        softmax runs over the hidden features, so this re-weights hidden units
        rather than attending over time steps.

        Args:
            x: tensor whose last dimension is ``hidden_dim`` and which has at
                least two dimensions.

        Returns:
            ``(x * w, w)`` where ``w = softmax(tanh(att(x)), dim=1)`` has the
            same shape as ``x``.
        """
        w = F.softmax(torch.tanh(self.att(x)), dim=1)
        return x * w, w

    def forward(self, X):
        """Classify a batch of token-id sequences.

        Args:
            X: integer tensor of token ids, assumed (batch, seq_len) because the
                RNN is ``batch_first`` -- TODO confirm with caller.

        Returns:
            ``(logits, attention)``: logits of shape (batch, output_dim) and the
            (batch, hidden_dim) gate weights from ``compute_attention``.
        """
        embedded = self.embedding(X)
        # h_n has shape (num_layers * 2, batch, hidden_dim) regardless of
        # batch_first; the per-step outputs are not used by this model.
        _, h_n = self.rnn(embedded)
        # Average the final states of the forward and backward directions.
        pooled = h_n.mean(dim=0)
        gated, attention = self.compute_attention(pooled)
        return self.out(gated), attention