# In [2]:
import math
import torch
from torch import nn, Tensor

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


class SelfAttention(nn.Module):
    """Single-head self-attention classifier on top of frozen pretrained embeddings.

    Token ids are embedded, enriched with positional encodings, passed through
    one scaled dot-product self-attention layer, mean-pooled over the sequence
    dimension and projected to ``output_dim`` scores.

    Args:
        vocabulary: Object exposing the pretrained embedding matrix as
            ``vocabulary.vectors``.
        embedding_dim: Width of the embedding vectors; it is used as the input
            size of the query/key/value projections and of the positional
            encoding.
        hidden_dim: Width of the query/key/value projections.
        output_dim: Number of output scores per sequence.
    """

    def __init__(self, vocabulary, embedding_dim=100, hidden_dim=256, output_dim=2):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(vocabulary.vectors)
        # Keep the pretrained vectors fixed during training.
        self.embedding.weight.requires_grad = False
        self.hidden = hidden_dim

        self.input_dim = embedding_dim
        self.query = nn.Linear(embedding_dim, hidden_dim)
        self.key = nn.Linear(embedding_dim, hidden_dim)
        self.value = nn.Linear(embedding_dim, hidden_dim)
        # Normalise each query's scores over the key positions (last axis).
        self.softmax = nn.Softmax(dim=2)
        self.position = PositionalEncoding(embedding_dim)
        self.out = nn.Linear(hidden_dim, output_dim)

    def forward(self, X):
        """Classify a batch of token-id sequences.

        Args:
            X: Token ids of shape ``[batch_size, len_seq]``.

        Returns:
            A tuple ``(output, attention)`` where ``output`` has shape
            ``[batch_size, output_dim]`` and ``attention`` has shape
            ``[batch_size, len_seq, len_seq]`` with every row summing to 1.
        """
        x = self.embedding(X)  # [batch_size, len_seq, embedding_dim]

        # The positional encoding layer indexes positions along dim 0
        # (sequence-first layout), so switch to [len_seq, batch_size, dim]
        # for that call and back to batch-first afterwards. Without this the
        # encoding would be selected by batch index instead of token position.
        x = self.position(x.transpose(0, 1)).transpose(0, 1)

        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        # Scaled dot-product attention: divide by sqrt(d_k), where d_k is the
        # width of the query/key vectors (hidden_dim), not the embedding width.
        scores = torch.bmm(queries, keys.transpose(1, 2)) / (self.hidden ** 0.5)
        attention = self.softmax(scores)
        weighted = torch.bmm(attention, values)
        # Mean-pool over every sequence position (no padding mask is applied).
        weighted = torch.mean(weighted, dim=1)

        output = self.out(weighted)
        return output, attention

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence-first tensor.

    A table of shape ``[max_len, 1, d_model]`` is precomputed once: even
    feature columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``. It is registered as a buffer, so it moves with the
    module across devices and is not a trainable parameter.

    Args:
        d_model: Feature width of the inputs (even or odd).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Largest sequence length the precomputed table covers.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # drop the surplus frequency; for even d_model this slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``

        Returns:
            Tensor of the same shape as ``x`` with the encoding for positions
            ``0 .. seq_len - 1`` added (broadcast over the batch axis),
            followed by dropout.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)