In [3]:
import torch
from torch import nn

In [4]:
class RelativePosition(nn.Module):

    def __init__(self, num_units, max_relative_position):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length_q, length_k):
        range_vec_q = torch.arange(length_q)
        range_vec_k = torch.arange(length_k)
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
        distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
        final_mat = distance_mat_clipped + self.max_relative_position
        final_mat = torch.LongTensor(final_mat).cuda()
        embeddings = self.embeddings_table[final_mat].cuda()

        return embeddings

In [8]:
relpos = RelativePosition(32, 10).cuda()

In [23]:
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()

        assert hid_dim % n_heads == 0

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.max_relative_position = 2

        self.relative_position_k = RelativePosition(self.head_dim, self.max_relative_position)
        self.relative_position_v = RelativePosition(self.head_dim, self.max_relative_position)

        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)

        self.fc_o = nn.Linear(hid_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask = None):
        #query = [batch size, query len, hid dim]
        #key = [batch size, key len, hid dim]
        #value = [batch size, value len, hid dim]
        batch_size = query.shape[0]
        len_k = key.shape[1]
        len_q = query.shape[1]
        len_v = value.shape[1]

        query = self.fc_q(query)
        key = self.fc_k(key)
        value = self.fc_v(value)

        r_q1 = query.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        r_k1 = key.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        attn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2))

        print(f'attn1.shape: {attn1.shape}')

        r_q2 = query.permute(1, 0, 2).contiguous().view(len_q, batch_size*self.n_heads, self.head_dim)
        print(f'r_q2.shape: {r_q2.shape}')
        r_k2 = self.relative_position_k(len_q, len_k)
        print(f'r_k2.shape: {r_k2.shape}')
        print(f'r_k2.transpose(1, 2).shape: {r_k2.transpose(1, 2).shape}')
        attn2 = torch.matmul(r_q2, r_k2.transpose(1, 2)).transpose(0, 1)
        print(f'attn2.shape: {attn2.shape}')
        attn2 = attn2.contiguous().view(batch_size, self.n_heads, len_q, len_k)
        print(f'attn2.shape: {attn2.shape}')
        attn = (attn1 + attn2) / self.scale

        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e10)

        attn = self.dropout(torch.softmax(attn, dim = -1))

        #attn = [batch size, n heads, query len, key len]
        r_v1 = value.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        weight1 = torch.matmul(attn, r_v1)
        r_v2 = self.relative_position_v(len_q, len_v)
        weight2 = attn.permute(2, 0, 1, 3).contiguous().view(len_q, batch_size*self.n_heads, len_k)
        weight2 = torch.matmul(weight2, r_v2)
        weight2 = weight2.transpose(0, 1).contiguous().view(batch_size, self.n_heads, len_q, self.head_dim)

        x = weight1 + weight2

        #x = [batch size, n heads, query len, head dim]

        x = x.permute(0, 2, 1, 3).contiguous()

        #x = [batch size, query len, n heads, head dim]

        x = x.view(batch_size, -1, self.hid_dim)

        #x = [batch size, query len, hid dim]

        x = self.fc_o(x)

        #x = [batch size, query len, hid dim]

        return x

In [32]:
import torchinfo
b, t, d = 1, 1024, 1024
x = torch.rand(b, t, d).cuda()
torchinfo.summary(MultiHeadAttentionLayer(d, 1, 0., torch.device('cuda')).cuda(), input_data=(x, x, x))

attn1.shape: torch.Size([1, 1, 1024, 1024])
r_q2.shape: torch.Size([1024, 1, 1024])
r_k2.shape: torch.Size([1024, 1024, 1024])
r_k2.transpose(1, 2).shape: torch.Size([1024, 1024, 1024])
attn2.shape: torch.Size([1, 1024, 1024])
attn2.shape: torch.Size([1, 1, 1024, 1024])


Layer (type:depth-idx)                   Output Shape              Param #
MultiHeadAttentionLayer                  [1, 1024, 1024]           --
├─Linear: 1-1                            [1, 1024, 1024]           1,049,600
├─Linear: 1-2                            [1, 1024, 1024]           1,049,600
├─Linear: 1-3                            [1, 1024, 1024]           1,049,600
├─RelativePosition: 1-4                  [1024, 1024, 1024]        5,120
├─Dropout: 1-5                           [1, 1, 1024, 1024]        --
├─RelativePosition: 1-6                  [1024, 1024, 1024]        5,120
├─Linear: 1-7                            [1, 1024, 1024]           1,049,600
Total params: 4,208,640
Trainable params: 4,208,640
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 4.20
Input size (MB): 12.58
Forward/backward pass size (MB): 17213.42
Params size (MB): 16.83
Estimated Total Size (MB): 17242.84

In [24]:
mha = MultiHeadAttentionLayer(64, 1, 0., torch.device('cuda')).cuda()

In [25]:
x = torch.rand(8, 10, 64).cuda()
mha(x, x, x)

attn1.shape: torch.Size([8, 1, 10, 10])
r_q2.shape: torch.Size([10, 8, 64])
r_k2.shape: torch.Size([10, 10, 64])
r_k2.transpose(1, 2).shape: torch.Size([10, 64, 10])
attn2.shape: torch.Size([8, 10, 10])
attn2.shape: torch.Size([8, 1, 10, 10])


tensor([[[ 0.0738, -0.4885,  0.2331,  ...,  0.0227,  0.3316,  0.2870],
         [ 0.0803, -0.4839,  0.2306,  ...,  0.0069,  0.3582,  0.2634],
         [ 0.0820, -0.4757,  0.2313,  ..., -0.0107,  0.3658,  0.2641],
         ...,
         [ 0.0847, -0.4652,  0.2279,  ..., -0.1148,  0.4093,  0.2623],
         [ 0.0886, -0.4624,  0.2303,  ..., -0.1338,  0.4151,  0.2642],
         [ 0.0823, -0.4887,  0.2152,  ..., -0.1397,  0.3939,  0.2835]],

        [[ 0.1091, -0.4367,  0.1713,  ...,  0.0523,  0.3382,  0.2279],
         [ 0.1169, -0.4242,  0.1681,  ...,  0.0371,  0.3630,  0.2041],
         [ 0.1197, -0.4176,  0.1696,  ...,  0.0187,  0.3759,  0.2026],
         ...,
         [ 0.1230, -0.4056,  0.1681,  ..., -0.0825,  0.4193,  0.2018],
         [ 0.1257, -0.4030,  0.1662,  ..., -0.1053,  0.4261,  0.2082],
         [ 0.1219, -0.4294,  0.1560,  ..., -0.1085,  0.4046,  0.2220]],

        [[ 0.1571, -0.3615,  0.2784,  ...,  0.0481,  0.3954,  0.2985],
         [ 0.1663, -0.3549,  0.2732,  ...,  0

In [10]:
relpos(5, 5).shape

torch.Size([5, 5, 32])

In [37]:
class RelativePositionalEncoder(nn.Module):
    def __init__(self, emb_dim, max_position=512):
        super(RelativePositionalEncoder, self).__init__()
        self.max_position = max_position
        self.embeddings_table = nn.Parameter(torch.Tensor(max_position * 2 + 1, emb_dim))
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, seq_len_q, seq_len_k):
        range_vec_q = torch.arange(seq_len_q)
        range_vec_k = torch.arange(seq_len_k)
        relative_matrix = range_vec_k[None, :] - range_vec_q[:, None]
        clipped_relative_matrix = torch.clamp(relative_matrix, -self.max_position, self.max_position)
        relative_position_matrix = clipped_relative_matrix + self.max_position
        embeddings = self.embeddings_table[relative_position_matrix]

        return embeddings


class T5RelativePositionalEncoder(nn.Module):
    def __init__(self, num_heads, max_position=512):
        super(T5RelativePositionalEncoder, self).__init__()
        self.max_position = max_position
        self.embeddings_table = nn.Embedding(max_position*max_position, num_heads)

    def forward(self, seq_len_q, seq_len_k):
        range_vec_q = torch.arange(seq_len_q).cuda()
        range_vec_k = torch.arange(seq_len_k).cuda()
        relative_position = range_vec_k[None, :] - range_vec_q[:, None]
        relative_position_clipped = torch.clamp(relative_position, -self.max_position, self.max_position)
        final_mat = relative_position_clipped + self.max_position
        embeddings = self.embeddings_table(final_mat)

        return embeddings

In [34]:
RelativePositionalEncoder(32, 10).cuda()(5, 5).shape

torch.Size([5, 5, 32])

In [38]:
T5RelativePositionalEncoder(8, 10).cuda()(5, 5).shape

torch.Size([5, 5, 8])