In [41]:
import torch.nn as nn
import torch.nn.functional as F
import torch, math
class MHAtt(nn.Module):
    def __init__(self, channel):
        super(MHAtt, self).__init__()
        HIDDEN_SIZE = 512
        self.HIDDEN_SIZE = HIDDEN_SIZE
        self.linear_v = Conv1DBlock((channel, channel, channel), 1)
        self.linear_k = Conv1DBlock((channel, channel, channel), 1)
        self.linear_q = Conv1DBlock((channel, channel, channel), 1)
        self.linear_merge = Conv1DBlock((channel, channel, channel), 1)
        self.HEAD = 1
        self.dropout = nn.Dropout(0.5)

    def forward(self, v, k, q, mask):
        n_batches = q.size(0)
        v = self.linear_v(v).view(
            n_batches,
            -1,
            self.HEAD,
            self.HIDDEN_SIZE
        ).transpose(1, 2)  # b, head, seq, hidden_dim

        k = self.linear_k(k).view(
            n_batches,
            -1,
            self.HEAD,
            self.HIDDEN_SIZE
        ).transpose(1, 2)

        q = self.linear_q(q).view(
            n_batches,
            -1,
            self.HEAD,
            self.HIDDEN_SIZE
        ).transpose(1, 2)

        atted = self.att(v, k, q, mask)
        print(atted.shape)
        atted = atted.transpose(1, 2).contiguous().view(
            n_batches,
            self.HIDDEN_SIZE,
            -1,
        )
        print(atted.shape)
        atted = self.linear_merge(atted)

        return atted

    def att(self, value, key, query, mask):
        d_k = query.size(-1) # hidden dim

        scores = torch.matmul(
            query, key.transpose(-2, -1)
        ) / math.sqrt(d_k) # (b, head, seq_q, hidden_dim) x (b, head, hidden_dim, seq_k) -> (b,head,seq_q,seq_k)

        if mask is not None: # mask(b, seq_q)
            scores = scores.masked_fill(mask, -1e9)  # value 中 padding部分会
        att_map = F.softmax(scores, dim=-1)  # query中每个词 在 所有value上的概率分布
        att_map = self.dropout(att_map)
        return torch.matmul(att_map, value) # (b,head,seq_q,seq_k) x (b, head, seq_k, hidden_)


# ---------------------------
# ---- Feed Forward Nets ----
# ---------------------------
    
class FFN(nn.Module):
    def __init__(self, channel):
        super(FFN, self).__init__()
        
        self.mlp = Conv1DBlock((channel, channel, channel), 1)

    def forward(self, x):
        return self.mlp(x)


# ------------------------
# ---- Self Attention ----
# ------------------------
    
class SA(nn.Module):
    def __init__(self, channel):
        super(SA, self).__init__()

        self.mhatt = MHAtt(channel)
        self.ffn = FFN(channel)

        self.dropout1 = nn.Dropout(0.5)
        self.norm1 = nn.BatchNorm1d(channel)

        self.dropout2 = nn.Dropout(0.5)
        self.norm2 = nn.BatchNorm1d(channel)

    def forward(self, x, x_mask):
        x = self.norm1(x + self.dropout1(
            self.mhatt(x, x, x, x_mask)
        )) # (b, seq_q, hidden_dim)

        x = self.norm2(x + self.dropout2(
            self.ffn(x)
        ))

        return x


# -------------------------------
# ---- Self Guided Attention ----
# -------------------------------
class SGA(nn.Module):
    def __init__(self, channel):
        super(SGA, self).__init__()

        self.mhatt1 = MHAtt(channel)
        self.mhatt2 = MHAtt(channel)
        self.ffn = FFN(channel)

        self.dropout1 = nn.Dropout(0.5)
        self.norm1 = nn.BatchNorm1d(channel)

        self.dropout2 = nn.Dropout(0.5)
        self.norm2 = nn.BatchNorm1d(channel)

        self.dropout3 = nn.Dropout(0.5)
        self.norm3 = nn.BatchNorm1d(channel)

    def forward(self, x, y, x_mask, y_mask):
        x = self.norm1(x + self.dropout1(
            self.mhatt1(x, x, x, x_mask)
        ))

        x = self.norm2(x + self.dropout2(
            self.mhatt2(y, y, x, y_mask)
        ))

        x = self.norm3(x + self.dropout3(
            self.ffn(x)
        ))

        return x
class SGA_last(nn.Module):
    def __init__(self, channel):
        super(SGA_last, self).__init__()

        self.mhatt1 = MHAtt(channel)
        self.mhatt2 = MHAtt(channel)
        self.ffn = FFN(channel)

        self.dropout1 = nn.Dropout(0.5)
        self.norm1 = nn.BatchNorm1d(channel)

        self.dropout2 = nn.Dropout(0.5)
        self.norm2 = nn.BatchNorm1d(channel)

        self.dropout3 = nn.Dropout(0.5)
        self.norm3 = nn.BatchNorm1d(channel)

    def forward(self, x, y, x_mask, y_mask):
        x = self.norm1(x + self.dropout1(
            self.mhatt1(x, x, x, x_mask)
        ))

        # x = self.norm2(x + self.dropout2(
        #     self.mhatt2(y, y, x, y_mask)
        # ))

        x = self.norm2(self.dropout2(
            self.mhatt2(x, x, y, x_mask)
        ))

        x = self.norm3(self.dropout3(
            self.ffn(x)
        ))

        return x
class MCA_ED(nn.Module):
    def __init__(self, channel):
        super(MCA_ED, self).__init__()

        self.enc_list = nn.ModuleList([SA(channel) for _ in range(3)])
        self.dec_list = nn.ModuleList([SGA(channel) for _ in range(2)])

        self.dec_last = SGA_last(channel)

    def forward(self, x, y, x_mask, y_mask):
        # Get hidden vector
        for enc in self.enc_list:
            x = enc(x, x_mask)

        for dec in self.dec_list:
            y = dec(y, x, y_mask, x_mask)

        y = self.dec_last(y, x, y_mask, x_mask)
        return x, y

In [42]:
a = MCA_ED(512)
b = torch.randn((8,512,128))
c = torch.randn((8,512,128))
d = (torch.randn((8,128)) < 0.05).unsqueeze(1).unsqueeze(2)
e = (torch.randn((8,128)) < 0.05).unsqueeze(1).unsqueeze(2)
c = a(b,c,d,e)
print(c.shape)

torch.Size([8, 1, 128, 512])
torch.Size([8, 512, 128])


RuntimeError: The size of tensor a (512) must match the size of tensor b (128) at non-singleton dimension 2

In [26]:
class Conv1DBNReLU(nn.Module):
    def __init__(self, in_channel, out_channel, ksize):
        super(Conv1DBNReLU, self).__init__()
        self.conv = nn.Conv1d(in_channel, out_channel, ksize, bias=False)
        self.bn = nn.BatchNorm1d(out_channel)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
class Conv1DBlock(nn.Module):
    def __init__(self, channels, ksize):
        super(Conv1DBlock, self).__init__()
        self.conv = nn.ModuleList()
        for i in range(len(channels)-2):
            self.conv.append(Conv1DBNReLU(channels[i], channels[i+1], ksize))
        self.conv.append(nn.Conv1d(channels[-2], channels[-1], ksize))

    def forward(self, x):
        for conv in self.conv:
            x = conv(x)
        return x