In [None]:
import torch
import torch.nn as nn
import numpy as np

In [None]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention over batched 3-D tensors.

    Attention scores are the batched product ``q @ k^T`` divided by
    ``temperature``; they are normalised with a softmax over the key axis
    (dim 2), passed through dropout, and used to average the values.
    """

    def __init__(self, temperature=1, attn_dropout=0.1):
        """
        Args:
            temperature: divisor applied to the raw dot-product scores.
            attn_dropout: dropout probability applied to the attention weights.
        """
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        """Attend from the queries to the keys and mix the values.

        Args:
            q: queries, B x T1 x V.
            k: keys, B x T2 x V.
            v: values, B x T2 x V.
            mask: optional tensor that is inverted with ``~`` and used as a
                fill mask, so True marks positions that MAY be attended to
                (presumably boolean and broadcastable to B x T1 x T2).

        Returns:
            Tuple ``(output, attn)``: the attended values (B x T1 x V) and the
            attention weights after dropout (B x T1 x T2).
        """
        keys_t = k.transpose(1, 2)  # B x T2 x V -> B x V x T2
        scores = torch.bmm(q, keys_t) / self.temperature  # B x T1 x T2

        blocked = None if mask is None else ~mask

        if blocked is not None:
            # -inf scores receive zero probability from the softmax.
            scores = scores.masked_fill(blocked, -np.inf)

        weights = self.softmax(scores)

        if blocked is not None:
            # Rows that are blocked everywhere come out of the softmax as NaN;
            # overwrite every blocked position with an explicit zero.
            weights = weights.masked_fill(blocked, 0.)

        weights = self.dropout(weights)

        # (B x T1 x T2) @ (B x T2 x V) -> B x T1 x V
        return torch.bmm(weights, v), weights

In [None]:
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    """Build a sinusoidal position-encoding table.

    Entry ``[pos, j]`` is ``sin(pos / 10000 ** (2 * (j // 2) / d_hid))`` for
    even ``j`` and the cosine of the same angle for odd ``j``.

    Args:
        n_position: number of positions (rows) in the table.
        d_hid: encoding dimensionality (columns).
        padding_idx: row to overwrite with zeros. ``None`` or a negative
            value leaves the table untouched.

    Returns:
        A ``torch.FloatTensor`` of shape ``(n_position, d_hid)``.

    Raises:
        IndexError: if ``padding_idx`` is >= ``n_position``.
    """
    positions = np.arange(n_position, dtype=np.float64)[:, None]  # n_position x 1
    # Columns 2i and 2i+1 share one frequency: the exponent is 2i / d_hid for both.
    exponents = 2 * (np.arange(d_hid) // 2) / d_hid
    sinusoid_table = positions / np.power(10000.0, exponents)[None, :]

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    # Accept None as "no padding row" as well as a negative index, so
    # callers passing None no longer hit a TypeError on the comparison.
    if padding_idx is not None and padding_idx >= 0:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ScaledDotProductAttention.

    Keys, queries and values are each projected to ``n_heads * head_dim``
    features by a bias-free linear layer, split into heads, attended
    independently per head, concatenated again and projected to ``emb_dim``.
    """

    def __init__(self, input_dim: int, head_dim: int, n_heads: int, emb_dim: int):
        """
        Args:
            input_dim: feature size of the incoming key/query/value tensors.
            head_dim: feature size of each attention head.
            n_heads: number of attention heads.
            emb_dim: feature size of the returned tensor.
        """
        super().__init__()
        self.head_dim = head_dim
        self.n_heads = n_heads
        # One fused projection per role; its output is split into heads in forward().
        self.w_key = nn.Linear(input_dim, head_dim * n_heads, bias=False)
        self.w_query = nn.Linear(input_dim, head_dim * n_heads, bias=False)
        self.w_value = nn.Linear(input_dim, head_dim * n_heads, bias=False)

        # Scores are divided by sqrt(head_dim).
        self.attn = ScaledDotProductAttention(temperature=np.power(head_dim, 0.5))

        self.proj = nn.Linear(n_heads * head_dim, emb_dim, bias=False)

    def _split_heads(self, x):
        """Reshape B x T x (n_heads * head_dim) into (n_heads * B) x T x head_dim."""
        b, t, _ = x.size()
        x = x.view(b, t, self.n_heads, self.head_dim)
        # B x T x n_heads x head_dim -> n_heads x B x T x head_dim -> n_heads*B x T x head_dim
        return x.permute(2, 0, 1, 3).contiguous().view(-1, t, self.head_dim)

    def forward(self, k, q, v):
        """Apply multi-head attention.

        Note the argument order: keys first, then queries, then values.

        Args:
            k: keys, B x T_k x input_dim.
            q: queries, B x T_q x input_dim.
            v: values, B x T_k x input_dim.

        Returns:
            Tensor of shape B x T_q x emb_dim.
        """
        # Each input is split using its own length, so T_q may differ from T_k.
        b, t_q, _ = q.size()

        k = self._split_heads(self.w_key(k))
        q = self._split_heads(self.w_query(q))
        v = self._split_heads(self.w_value(v))

        output, _ = self.attn(q, k, v)

        # n_heads*B x T_q x head_dim -> B x T_q x n_heads x head_dim -> B x T_q x n_heads*head_dim
        output = output.view(self.n_heads, b, t_q, self.head_dim)
        output = output.permute(1, 2, 0, 3).contiguous().view(b, t_q, -1)

        return self.proj(output)

In [None]:
class FFNet(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    The output has the same feature size as the input.
    """

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float = 0.1):
        """
        Args:
            input_dim: feature size of the input and of the output.
            hidden_dim: feature size of the hidden layer.
            dropout: dropout probability applied after the activation
                (defaults to the previously hard-coded 0.1).
        """
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)

        self.activ = nn.ReLU()
        self.do = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``."""
        hidden = self.activ(self.linear1(x))
        return self.linear2(self.do(hidden))
        

In [None]:
class FFTBlock(nn.Module):
    """Self-attention followed by a feed-forward network.

    Each of the two sub-layers is wrapped in a residual connection and then
    layer-normalised.
    """

    # Was misspelled `__init`, so the constructor never ran and no
    # sub-modules were registered.
    def __init__(self, input_dim, head_dim, n_heads, hidden_dim):
        """
        Args:
            input_dim: feature size of the block's input and output.
            head_dim: feature size of each attention head.
            n_heads: number of attention heads.
            hidden_dim: hidden size of the feed-forward network.
        """
        super().__init__()
        # The attention output size equals input_dim so the residual sum is well-formed.
        self.multiheadattn = MultiHeadAttention(input_dim, head_dim, n_heads, input_dim)
        self.ffnet = FFNet(input_dim, hidden_dim)
        # `nn.LinearNorm` does not exist in torch.nn; normalise over the feature axis.
        self.ln1 = nn.LayerNorm(input_dim)
        self.ln2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        """Run the block on ``x``; the result has the same shape as ``x``."""
        # Self-attention sub-layer: x serves as keys, queries and values.
        residual = x
        x = self.multiheadattn(x, x, x)
        x = self.ln1(residual + x)

        # Feed-forward sub-layer.
        residual = x
        x = self.ffnet(x)
        x = self.ln2(residual + x)

        return x