In [9]:
import torch
import torch.nn as nn

In [10]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    A table ``P`` of shape ``(1, max_len, num_hiddens)`` is precomputed once:
    even feature columns hold ``sin(pos / 10000^(2i / num_hiddens))`` and odd
    feature columns hold the matching ``cos`` values.

    Args:
        num_hiddens: Size of the feature (last) dimension of the inputs.
            Both even and odd sizes are supported.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions precomputed in the table. The second
            dimension of the input to ``forward`` must not exceed it.
    """

    def __init__(self, num_hiddens, dropout, max_len=4000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # angles[pos, i] = pos / 10000^(2i / num_hiddens); one column per sin/cos pair.
        positions = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        frequencies = torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        angles = positions / frequencies
        P = torch.zeros((1, max_len, num_hiddens))
        P[:, :, 0::2] = torch.sin(angles)
        # For odd num_hiddens there is one fewer odd column than even columns,
        # so drop the unpaired last angle column before filling the cosines.
        P[:, :, 1::2] = torch.cos(angles[:, :num_hiddens // 2])
        # Register as a buffer so the table follows module.to(...) / .half();
        # persistent=False keeps it out of state_dict, so existing checkpoints
        # (which never contained "P") still load unchanged.
        self.register_buffer("P", P, persistent=False)

    def forward(self, X):
        """Return ``dropout(X + P[:, :X.shape[1], :])``.

        Args:
            X: Tensor whose dimension 1 indexes sequence positions and whose
                last dimension has size ``num_hiddens``.

        Returns:
            Tensor with the same shape as ``X``.
        """
        # .to() is a no-op when the buffer already lives on X's device; it is
        # kept so inputs on another device than the module still work.
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)


In [13]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are each produced by a separate
    ``input_dim -> input_dim`` linear projection of the same input.

    Args:
        input_dim: Size of the last dimension of the input; also used as the
            scaling factor (its square root) for the attention scores.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.input_dim = input_dim
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        # Normalise each row of the score matrix over its last axis.
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        """Attend every position of ``x`` to every position of ``x``.

        Args:
            x: 3-D tensor (``torch.bmm`` requires three dimensions) whose
                last dimension has size ``input_dim``.

        Returns:
            A tuple ``(weighted, attention)``: the attention-weighted values,
            and the attention matrix whose rows sum to 1 along dimension 2.
        """
        q, k, v = self.query(x), self.key(x), self.value(x)
        scale = self.input_dim ** 0.5
        logits = torch.bmm(q, k.transpose(1, 2)) / scale
        attention = self.softmax(logits)
        return torch.bmm(attention, v), attention

In [14]:
class AttentionModel(nn.Module):
    """Classifier: frozen pretrained embeddings -> positional encoding ->
    self-attention -> mean pooling over dimension 1 -> two-layer MLP head.

    Args:
        vocabulary: Object exposing a ``vectors`` tensor that is passed to
            ``nn.Embedding.from_pretrained``.
        embedding_dim: Feature width used for the positional encoding,
            attention and first linear layer.
            NOTE(review): presumably must equal the width of
            ``vocabulary.vectors`` -- confirm against callers.
        hidden: Width of the hidden layer of the MLP head.
        output_dim: Number of output units.
        dropout: Dropout probability forwarded to ``PositionalEncoding``.
    """

    def __init__(self, vocabulary, embedding_dim=100, hidden=256, output_dim=2, dropout=0.1):
        super().__init__()
        # freeze=True leaves the pretrained embedding weights with requires_grad=False.
        self.embed = nn.Embedding.from_pretrained(vocabulary.vectors, freeze=True)
        self.posEncode = PositionalEncoding(embedding_dim, dropout)
        self.attention = SelfAttention(embedding_dim)
        self.fc1 = nn.Linear(embedding_dim, hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, output_dim)

    def forward(self, x):
        """Run the model on a batch of embedding indices.

        Args:
            x: Index tensor looked up in the embedding table.
                Assumes shape (batch, seq) -- TODO confirm.

        Returns:
            A tuple ``(output, weights)``: the output of the final linear
            layer, and the attention matrix returned by ``SelfAttention``.
        """
        encoded = self.posEncode(self.embed(x))
        attended, weights = self.attention(encoded)
        # Collapse dimension 1 by averaging before the MLP head.
        pooled = attended.mean(dim=1)
        logits = self.fc2(self.relu(self.fc1(pooled)))
        return logits, weights