In [24]:
import torch
import torch.nn as nn
import math

In [144]:
class CGAEncoder(nn.Module):
    def __init__(self, Nq, D):
        super(CGAEncoder, self).__init__()
        self.D, self.attn_weights = D, None
        self.W_q = nn.Linear(D, D)
        self.W_k = nn.Linear(D, D)
        self.W_v = nn.Linear(D, D)
    
    # query: (B, Nq, D), query_mask: (B, Nq), key: (B, L, L, D), value: (B, L, L, D), kv_mask: (B, L, L)
    def forward(self, query, query_mask, key, value, kv_mask):
        B, L, D =  key.shape[0], key.shape[1], key.shape[3]
        # key, value: (B, L, L, D) -> key, value: (B, L*L, D)
        key, value = torch.reshape(key, (B, L*L, D)), torch.reshape(value, (B, L*L, D))
        query, key, value  = self.W_q(query), self.W_k(key), self.W_v(value) 
        # query: (B, Nq, D), key, value: (B, L*L, D) -> attn_weights: (B, Nq, L*L)
        if query.dim() == 2:
            query = query.unsqueeze(1)
        attn_weights = torch.matmul(query, torch.transpose(key, 2, 1))/math.sqrt(D)
        # kv_mask: (B, L, L) -> kv_mask: (B, 1, L*L)
        mask = kv_mask.reshape(B, L*L).unsqueeze(1)
        if query_mask is not None:
            # query_mask: (B, Nq) -> query_mask: (B, Nq, 1)
            query_mask = query_mask.float().unsqueeze(2)
            # query_mask: (B, Nq, 1), kv_mask: (B, 1, L*L) -> mask: (B, Nq, L*L)
            mask = mask*query_mask
        # attn_weights: (B, Nq, L*L)
        attn_weights = attn_weights * mask
        print(attn_weights.shape, mask.shape)
        # attn_weights: (B, Nq, L*L)
        attn_weights = attn_weights.masked_fill(mask == 0, -1e9)
        # attn_weights: (B, Nq, L*L)
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        # attn_weights: (B, Nq, L*L), value: (B, L*L, D) -> attn_out: (B, Nq, D)
        attn_out = torch.matmul(attn_weights, value)
        self.attn_weights = attn_weights
        return attn_out

In [145]:
B, L, L, D, Nq = 3, 4, 4, 8, 5
query, key = torch.rand(B, Nq, D), torch.rand(B,L,L,D)
query_mask, kv_mask = torch.ones(B, Nq),torch.ones(B, L, L) 
query_mask[:, -1], kv_mask[:, :, -1] = 0, 0

In [146]:
encoder = CGAEncoder(Nq, D)
encoder(query, query_mask, key, key, kv_mask)

torch.Size([3, 5, 16]) torch.Size([3, 5, 16])


tensor([[[ 0.3037,  0.0427, -0.0457,  0.0515,  0.0088,  0.2548, -0.4338,
          -0.0245],
         [ 0.3043,  0.0422, -0.0453,  0.0521,  0.0080,  0.2542, -0.4340,
          -0.0245],
         [ 0.3021,  0.0448, -0.0416,  0.0572,  0.0045,  0.2549, -0.4374,
          -0.0211],
         [ 0.3027,  0.0451, -0.0430,  0.0571,  0.0072,  0.2544, -0.4340,
          -0.0224],
         [ 0.3407,  0.0574, -0.0623,  0.0828, -0.0292,  0.2581, -0.4014,
          -0.0187]],

        [[ 0.4021,  0.1335, -0.1149,  0.2367, -0.0140,  0.2556, -0.3511,
           0.0902],
         [ 0.3987,  0.1343, -0.1117,  0.2381, -0.0118,  0.2536, -0.3526,
           0.0911],
         [ 0.3958,  0.1379, -0.1113,  0.2357, -0.0118,  0.2583, -0.3545,
           0.0912],
         [ 0.3964,  0.1372, -0.1109,  0.2380, -0.0126,  0.2560, -0.3539,
           0.0928],
         [ 0.3657,  0.1286, -0.0823,  0.1888, -0.0025,  0.2950, -0.3846,
           0.0573]],

        [[ 0.2624,  0.0963, -0.0411,  0.1983,  0.0965,  0.1712, -0

In [147]:
print(encoder.attn_weights)

tensor([[[0.0891, 0.0750, 0.0907, 0.0000, 0.0806, 0.0917, 0.0785, 0.0000,
          0.0903, 0.0872, 0.0722, 0.0000, 0.0788, 0.0798, 0.0861, 0.0000],
         [0.0905, 0.0760, 0.0879, 0.0000, 0.0796, 0.0909, 0.0780, 0.0000,
          0.0912, 0.0866, 0.0713, 0.0000, 0.0807, 0.0816, 0.0856, 0.0000],
         [0.0858, 0.0804, 0.0826, 0.0000, 0.0769, 0.0923, 0.0864, 0.0000,
          0.0842, 0.0848, 0.0704, 0.0000, 0.0809, 0.0836, 0.0915, 0.0000],
         [0.0856, 0.0779, 0.0863, 0.0000, 0.0810, 0.0911, 0.0819, 0.0000,
          0.0849, 0.0839, 0.0729, 0.0000, 0.0832, 0.0824, 0.0888, 0.0000],
         [0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625,
          0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625, 0.0625]],

        [[0.0783, 0.0919, 0.0909, 0.0000, 0.0912, 0.0825, 0.0747, 0.0000,
          0.0817, 0.0798, 0.0808, 0.0000, 0.0756, 0.0922, 0.0804, 0.0000],
         [0.0798, 0.0889, 0.0909, 0.0000, 0.0874, 0.0817, 0.0786, 0.0000,
          0.0837, 0.0822, 0.07

In [148]:
query_sent = torch.rand(B, D)
out = encoder(query_sent, None, key, key, kv_mask)

torch.Size([3, 1, 16]) torch.Size([3, 1, 16])


In [152]:
print(out.squeeze(1).shape)

torch.Size([3, 8])


In [153]:
a = torch.rand(B, 1, D)
b = torch.rand(B, D, L*L)
c = torch.bmm(a, b)
print(c.shape)

torch.Size([3, 1, 16])
