* Assuming 3 modalities  
* All the modalities are processed in parallel 
* Dropout and the final head projection are not applied, so that the outputs of the two implementations can be compared directly

In [36]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import pdb

In [37]:
class CustomLinear(nn.Module):
    """Linear map with an all-ones weight and an optional element-wise weight mask.

    The weight is fixed to ones (instead of being randomly initialised) so that
    layers of different sizes produce directly comparable outputs.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        mask: optional tensor multiplied element-wise into the
            ``(out_features, in_features)`` weight before it is applied; zero
            entries remove the corresponding input-to-output connection.
        bias: when True, a learnable bias initialised to zero is added.
    """

    def __init__(self, in_features, out_features, mask = None, bias=False):
        super().__init__()
        # Kept from the original implementation: constructing a layer reseeds
        # the global RNG, which the rest of the notebook's numbers depend on.
        torch.manual_seed(0)
        if mask is not None:
            mask = torch.as_tensor(mask)
        # Registered as buffers so Module.to()/.double()/.cuda() carry them
        # along with the module; non-persistent so state_dict() is unchanged.
        self.register_buffer("weight", torch.ones(out_features, in_features), persistent=False)
        self.register_buffer("mask", mask, persistent=False)
        if bias:
            # torch.Tensor(n) returns uninitialised memory; start from zeros instead.
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.bias = None

    def forward(self, input):
        """Return ``input @ (weight * mask).T`` (mask skipped when None), plus bias if present."""
        # The mask zeroes weight entries that would connect different modalities.
        weight = self.weight if self.mask is None else self.weight * self.mask
        output = input.matmul(weight.t())
        if self.bias is not None:
            output = output + self.bias
        return output

Standard multi-head attention

In [38]:
class AttentionLayer(nn.Module):
    """Single-head scaled dot-product attention built on ``CustomLinear`` projections.

    Args:
        dim_head: embedding size per head.
        num_heads: number of heads (only used here to size the projections).
        num_modes: number of modalities concatenated along the feature axis.
        ip_dim: feature size of a single modality.
        dropout: probability for the ``nn.Dropout`` module, which is created
            but not applied in ``forward``.
    """

    def __init__(self, dim_head, num_heads, num_modes, ip_dim, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.dim_head = dim_head
        self.ip_dim = ip_dim
        self.embed_dim = self.num_heads * self.dim_head * num_modes
        in_features = self.ip_dim * num_modes
        self.query_proj = CustomLinear(in_features, self.embed_dim)
        self.key_proj = CustomLinear(in_features, self.embed_dim)
        self.value_proj = CustomLinear(in_features, self.embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, attn_mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query, key, value: tensors whose last dimension is
                ``ip_dim * num_modes``; ``key`` and ``value`` must share their
                sequence length.
            attn_mask: optional mask broadcastable to the score matrix, with 1
                for positions that may be attended to and 0 for blocked ones.

        Returns:
            The attention-weighted sum of the projected values.
        """
        query = self.query_proj(query)
        key = self.key_proj(key)
        value = self.value_proj(value)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.embed_dim)
        if attn_mask is not None:
            # Blocked positions get a large negative score so softmax sends them to ~0.
            scores = scores + (1 - attn_mask) * -1e9
        # Normalise the scores into attention weights before mixing the values;
        # without the softmax the additive mask above cannot work.
        return torch.matmul(F.softmax(scores, dim=-1), value)

class MultiHeadAttentionLayer(AttentionLayer):
    """Multi-head scaled dot-product attention on top of ``AttentionLayer``'s projections.

    The projected features are split into ``num_heads`` chunks, attention is
    computed independently per chunk, and the per-head results are concatenated
    back along the feature axis. ``head_proj`` is created but not applied.
    """

    def __init__(self, dim_head, num_heads, num_modes, ip_dim, dropout=0.1):
        super().__init__(dim_head, num_heads, num_modes, ip_dim, dropout)
        self.num_heads = num_heads
        self.dim_head = dim_head
        self.num_modes = num_modes
        self.embed_dim = self.num_heads * self.dim_head * self.num_modes
        self.head_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _to_heads(self, projected, batch, length):
        """Reshape (batch, length, embed_dim) into (batch, heads, length, embed_dim // heads)."""
        per_head = self.embed_dim // self.num_heads
        return projected.view(batch, length, self.num_heads, per_head).transpose(1, 2)

    def forward(self, query, key, value, attn_mask=None):
        """Return the concatenated per-head attention output, shape (batch, query_len, embed_dim)."""
        _, tgt_len, _ = query.shape
        batch, src_len, _ = value.shape

        q = self._to_heads(self.query_proj(query), batch, tgt_len)
        k = self._to_heads(self.key_proj(key), batch, src_len)
        v = self._to_heads(self.value_proj(value), batch, src_len)

        # Scale by sqrt of embed_dim / (heads * modes), i.e. sqrt(dim_head).
        scale = math.sqrt(self.embed_dim / (self.num_heads * self.num_modes))
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        if attn_mask is not None:
            # 1 = keep, 0 = block: blocked positions receive a large negative score.
            scores += ((1 - attn_mask) * -1e9).to(q.device)

        attended = torch.matmul(F.softmax(scores, dim=-1), v)
        return attended.transpose(1, 2).reshape(batch, tgt_len, self.embed_dim)

In [39]:
# Baseline layer that handles one modality per call; dim_head is the embedding dim per head.
multi_attn = MultiHeadAttentionLayer(
    dim_head=4,
    num_heads=2,
    num_modes=1,
    ip_dim=4,
)

In [40]:
# One random input per modality, each with batch size 2, sequence length 3, feature dim 4.
x1, x2, x3 = (torch.rand(2, 3, 4) for _ in range(3))

# Run the baseline layer on every modality in turn (self-attention: query = key = value).
out1, out2, out3 = (multi_attn(x, x, x) for x in (x1, x2, x3))

# Concatenate the per-modality outputs along the feature axis.
out = torch.cat((out1, out2, out3), -1)

Multi-head attention with parallel processing of modalities

In [41]:
class FuseAttentionLayer(nn.Module):
    """Single-head attention over concatenated modalities with masked projections.

    The query/key/value projections share a mask with one block of ones per
    modality, so the output features of modality ``m`` only depend on the input
    features of modality ``m``.

    Args:
        dim_head: embedding size per head.
        num_heads: number of heads.
        num_modes: number of modalities concatenated along the feature axis.
        ip_dim: feature size of a single modality.
        dropout: probability for the ``nn.Dropout`` module, which is created
            but not applied in ``forward``.
    """

    def __init__(self, dim_head, num_heads, num_modes, ip_dim, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.dim_head = dim_head
        self.ip_dim = ip_dim
        rows_per_mode = num_heads * dim_head
        # Rows for the first modality: ones over its own ip_dim input columns,
        # zeros over the columns that belong to the other modalities.
        first_rows = torch.hstack((
            torch.ones(rows_per_mode, ip_dim),
            torch.zeros(rows_per_mode, ip_dim * (num_modes - 1)),
        ))
        # Shifting that pattern right by ip_dim columns per modality and
        # stacking the results produces the block-diagonal mask.
        mask = torch.vstack([
            torch.roll(first_rows, shifts=mode * ip_dim, dims=-1)
            for mode in range(num_modes)
        ])
        self.embed_dim = self.num_heads * self.dim_head * num_modes
        in_features = self.ip_dim * num_modes
        self.query_proj = CustomLinear(in_features, self.embed_dim, mask)
        self.key_proj = CustomLinear(in_features, self.embed_dim, mask)
        self.value_proj = CustomLinear(in_features, self.embed_dim, mask)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, attn_mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query, key, value: tensors whose last dimension is
                ``ip_dim * num_modes``; ``key`` and ``value`` must share their
                sequence length.
            attn_mask: optional mask broadcastable to the score matrix, with 1
                for positions that may be attended to and 0 for blocked ones.

        Returns:
            The attention-weighted sum of the projected values.
        """
        query = self.query_proj(query)
        key = self.key_proj(key)
        value = self.value_proj(value)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.embed_dim)
        if attn_mask is not None:
            # Blocked positions get a large negative score so softmax sends them to ~0.
            scores = scores + (1 - attn_mask) * -1e9
        # Normalise the scores into attention weights before mixing the values;
        # without the softmax the additive mask above cannot work.
        return torch.matmul(F.softmax(scores, dim=-1), value)

class FuseMultiHeadAttentionLayer(FuseAttentionLayer):
    """Multi-head attention that processes all modalities in a single pass.

    Inputs carry the modalities concatenated along the feature axis. After the
    masked projections the features are split per modality and per head, and
    attention is computed independently for every (modality, head) pair.
    ``head_proj`` is created but not applied.
    """

    def __init__(self, dim_head, num_heads, num_modes, ip_dim, dropout=0.1):
        super().__init__(dim_head, num_heads, num_modes, ip_dim, dropout)
        self.num_heads = num_heads
        self.dim_head = dim_head
        self.num_modes = num_modes
        self.embed_dim = self.num_heads * self.dim_head * self.num_modes
        self.head_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _split_modes_and_heads(self, projected):
        """Reshape (N, L, D) into (N, modes, heads, L, D // (modes * heads))."""
        batch, length, _ = projected.shape
        per_head = self.embed_dim // (self.num_modes * self.num_heads)
        split = projected.view(batch, length, self.num_modes, self.num_heads, per_head)
        return split.permute(0, 2, 3, 1, 4)

    def forward(self, query, key, value, attn_mask=None):
        """Return the attention output with shape (N, S, embed_dim).

        Args:
            query: (N, S, ip_dim * num_modes).
            key, value: (N, T, ip_dim * num_modes); T may differ from S.
            attn_mask: optional mask broadcastable to (N, modes, heads, S, T),
                with 1 for positions that may be attended to and 0 for blocked ones.
        """
        N, S, _ = query.shape
        # Key and value are reshaped with their own length T, not the query
        # length S, so cross-attention with T != S works.
        q = self._split_modes_and_heads(self.query_proj(query))
        k = self._split_modes_and_heads(self.key_proj(key))
        v = self._split_modes_and_heads(self.value_proj(value))

        per_head = self.embed_dim // (self.num_heads * self.num_modes)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(per_head)
        if attn_mask is not None:
            # Blocked positions get a large negative score so softmax sends them to ~0.
            scores = scores + ((1 - attn_mask) * -1e9).to(q.device)

        attended = torch.matmul(F.softmax(scores, dim=-1), v)  # (N, modes, heads, S, per_head)
        # Bring the sequence axis forward again and merge modes/heads back into features.
        return attended.permute(0, 3, 1, 2, 4).reshape(N, S, self.embed_dim)

In [42]:
# Fused layer that handles all three modalities in one call; dim_head is the embedding dim per head.
fuse_multi_attn = FuseMultiHeadAttentionLayer(
    dim_head=4,
    num_heads=2,
    num_modes=3,
    ip_dim=4,
)

# Concatenate the modalities along the feature axis.
x_cat = torch.cat((x1, x2, x3), -1)

# Self-attention on the concatenated input: query = key = value.
out_fuse = fuse_multi_attn(x_cat, x_cat, x_cat)

In [52]:
# Display the output of the fused layer for visual inspection.
out_fuse

tensor([[[2.6792, 2.6792, 2.6792, 2.6792, 2.6792, 2.6792, 2.6792, 2.6792,
          2.3100, 2.3100, 2.3100, 2.3100, 2.3100, 2.3100, 2.3100, 2.3100,
          2.5008, 2.5008, 2.5008, 2.5008, 2.5008, 2.5008, 2.5008, 2.5008],
         [2.6011, 2.6011, 2.6011, 2.6011, 2.6011, 2.6011, 2.6011, 2.6011,
          2.3598, 2.3598, 2.3598, 2.3598, 2.3598, 2.3598, 2.3598, 2.3598,
          2.4687, 2.4687, 2.4687, 2.4687, 2.4687, 2.4687, 2.4687, 2.4687],
         [2.6354, 2.6354, 2.6354, 2.6354, 2.6354, 2.6354, 2.6354, 2.6354,
          2.2682, 2.2682, 2.2682, 2.2682, 2.2682, 2.2682, 2.2682, 2.2682,
          2.4149, 2.4149, 2.4149, 2.4149, 2.4149, 2.4149, 2.4149, 2.4149]],

        [[2.7431, 2.7431, 2.7431, 2.7431, 2.7431, 2.7431, 2.7431, 2.7431,
          2.0634, 2.0634, 2.0634, 2.0634, 2.0634, 2.0634, 2.0634, 2.0634,
          2.4275, 2.4275, 2.4275, 2.4275, 2.4275, 2.4275, 2.4275, 2.4275],
         [2.6861, 2.6861, 2.6861, 2.6861, 2.6861, 2.6861, 2.6861, 2.6861,
          2.0877, 2.0877, 2.0877

In [51]:
# Check that the fused output equals the concatenated per-modality outputs
# (within torch.allclose's default tolerances).
torch.allclose(out_fuse, out)

True

Time Comparison

In [44]:
import time

In [45]:
# Time 10000 iterations of the standard approach: one layer call per modality,
# followed by concatenating the three outputs.
torch.manual_seed(0)
# perf_counter is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() can jump when the system clock is adjusted.
start_time1 = time.perf_counter()
for _ in range(10000):
    x1 = torch.rand(2, 3, 4)
    x2 = torch.rand(2, 3, 4)
    x3 = torch.rand(2, 3, 4)

    out1 = multi_attn(x1, x1, x1)
    out2 = multi_attn(x2, x2, x2)
    out3 = multi_attn(x3, x3, x3)

    out = torch.cat((out1, out2, out3), -1)
end_time1 = time.perf_counter()
time_1 = end_time1 - start_time1

In [46]:
# Time 10000 iterations of the fused approach: concatenate the three
# modalities and make a single layer call.
torch.manual_seed(0)
# perf_counter is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() can jump when the system clock is adjusted.
start_time2 = time.perf_counter()
for _ in range(10000):
    x1 = torch.rand(2, 3, 4)
    x2 = torch.rand(2, 3, 4)
    x3 = torch.rand(2, 3, 4)

    x_cat = torch.cat((x1, x2, x3), -1)
    out_fuse = fuse_multi_attn(x_cat, x_cat, x_cat)
end_time2 = time.perf_counter()
time_2 = end_time2 - start_time2

In [47]:
# Round both elapsed times to two decimals for display.
time_std_multihead, time_my_multihead = round(time_1, 2), round(time_2, 2)

In [48]:
time_std_multihead # elapsed seconds for the standard (one call per modality) implementation

5.6

In [49]:
time_my_multihead # elapsed seconds for the modified (single fused call) implementation

2.61

In [50]:
# Speed-up of the modified implementation over the standard one.
# Expected to be around 3x since there are 3 modalities: the standard
# implementation makes one layer call per modality, whereas the modified
# implementation makes a single call regardless of the number of modalities.
# NOTE(review): the measured ratio is lower than 3, presumably because of
# per-iteration overhead common to both loops -- confirm with a profiler.
time_std_multihead/time_my_multihead

2.1455938697318007