In [1]:
import itertools
import logging
import math
import time
import warnings

# torch
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Parameter
from tqdm import tqdm

In [32]:

class MAB(nn.Module):
    """Multihead Attention Block (Set Transformer style).

    Projects queries ``Q`` and keys ``K`` to width ``dim_V``, runs
    ``num_heads`` heads of softmax attention from the projected queries to the
    projected keys/values, and adds the projected queries back as a residual.
    It then applies a residual ReLU feed-forward layer (``fc_o``). When
    ``ln=True``, a LayerNorm is applied after each of the two residual sums.

    Args:
        dim_Q: feature width of the query input.
        dim_K: feature width of the key/value input.
        dim_V: output feature width; must be a multiple of ``num_heads``.
        num_heads: number of attention heads (must be positive).
        ln: if True, create the ``ln0``/``ln1`` LayerNorm submodules.

    Raises:
        ValueError: if ``num_heads`` is not positive or ``dim_V`` is not
            divisible by ``num_heads``. Without this check the mistake only
            surfaces later as an opaque tensor-shape error inside ``forward``.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()
        # Heads are formed by slicing the dim_V feature axis into equal chunks,
        # so an uneven split can never work; fail fast with a clear message.
        if num_heads <= 0 or dim_V % num_heads != 0:
            raise ValueError(
                f"dim_V ({dim_V}) must be divisible by num_heads ({num_heads}), "
                f"and num_heads must be positive"
            )
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        # The LayerNorms exist only when requested; forward() checks for them by name.
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        """Attend from ``Q`` to ``K``.

        Args:
            Q: query tensor, expected shape ``[B, Lq, dim_Q]``.
            K: key/value tensor, expected shape ``[B, Lk, dim_K]``. ``Lk`` may
                differ from ``Lq``; for example, PMA uses a few seed queries
                over a larger set.

        Returns:
            Tensor of shape ``[B, Lq, dim_V]``.
        """
        Q = self.fc_q(Q)                    # Q: [B, Lq, dim_V]
        K, V = self.fc_k(K), self.fc_v(K)   # K, V: [B, Lk, dim_V]
        batch = Q.size(0)

        # Fold heads into the batch axis. Feature chunk j becomes batch rows
        # [j*B, (j+1)*B), so a single bmm computes every head at once.
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)  # Q_: [B*heads, Lq, dim_split]
        K_ = torch.cat(K.split(dim_split, 2), 0)  # K_: [B*heads, Lk, dim_split]
        V_ = torch.cat(V.split(dim_split, 2), 0)  # V_: [B*heads, Lk, dim_split]

        # NOTE(review): scores are scaled by sqrt(dim_V), not by the per-head
        # width sqrt(dim_split) used in standard scaled dot-product attention.
        # This is kept deliberately so numerical outputs remain unchanged.
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)  # A: [B*heads, Lq, Lk]
        # Add the residual on the projected queries, then unfold the heads
        # back onto the feature axis.
        O = torch.cat((Q_ + A.bmm(V_)).split(batch, 0), 2)  # O: [B, Lq, dim_V]
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O

class SAB(nn.Module):
    """Set Attention Block: a set attends to itself through one MAB.

    Args:
        dim_in: feature width of the input set (used for both queries and keys).
        dim_out: output feature width (the inner MAB's ``dim_V``).
        num_heads: number of attention heads.
        ln: whether the inner MAB applies LayerNorm.
    """

    def __init__(self, dim_in, dim_out, num_heads, ln=False):
        super().__init__()
        # Queries and keys come from the same set, so both input widths are dim_in.
        self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)

    def forward(self, X):
        """Self-attention over ``X``: it is passed as both query and key."""
        queries = keys = X
        return self.mab(queries, keys)
        
class PMA(nn.Module):
    """Pooling by Multihead Attention.

    A learned set of ``num_seeds`` seed vectors attends over the input set
    through a MAB, so the result has ``num_seeds`` elements of width ``dim``
    per batch entry, regardless of the input set's length.

    Args:
        dim: feature width of both the seeds and the input elements.
        num_heads: number of attention heads passed to the inner MAB.
        num_seeds: number of learned seed vectors (the output set size).
        ln: whether the inner MAB applies LayerNorm.
    """

    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super().__init__()
        # The seeds are allocated uninitialized and Xavier-filled at once.
        # They are created before the MAB, which fixes the order in which the
        # global RNG is consumed during construction.
        self.S = nn.Parameter(torch.empty(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        """Pool ``X`` (expected shape ``[B, N, dim]``) into ``[B, num_seeds, dim]``."""
        # Tile the single seed set along the batch axis so every example gets its own queries.
        batch_size = X.shape[0]
        seeds = self.S.repeat(batch_size, 1, 1)
        return self.mab(seeds, X)


In [33]:
# Smoke test: pool a random batch of sets with PMA and check the output shape.
torch.manual_seed(0)  # fixed seed so the input, PMA weights and output are reproducible

head = 2

Batch = 3
seq_len = 4
base_dim = 6

dim = base_dim * head  # built as a multiple of `head`, as the attention heads require

# PMA's output set size equals its number of seeds and does not depend on the
# input length. It gets its own constant so the two are not conflated.
num_seeds = 4

# NOTE: `input` shadows the Python builtin; the name is kept in case later cells use it.
input = torch.randn(Batch, seq_len, dim)
pma = PMA(dim, head, num_seeds)
output = pma(input)
assert output.shape == (Batch, num_seeds, dim), output.shape
output.shape




torch.Size([3, 4, 12])