* Assumes 3 modalities.
* All modalities are processed in parallel.
* Dropout is omitted so that the outputs of the two implementations can be compared exactly.
* Q, K and V are computed in a single projection (its output is 3× as wide as a single Q/K/V projection). The projected tensor is split first by modality, then by head, and then into Q, K and V.
* A head projection is added after attention.
* A mask is applied to the 1st modality only (in CLIP, text encoding uses masked attention while image encoding does not).
* Different modalities have different numbers of heads: the 1st has 2 heads, the 2nd has 4 and the 3rd has 6. Every modality is encoded as if it had max(heads) = 6 heads, and the unnecessary heads are dropped before the last projection.
* Different modalities have different embedding dimensions, as in CLIP.

## Accounting for different numbers of heads
* Treat every modality as having as many heads as the modality with the most heads.
* Compute the encoding as usual, but in the last stage (the head projection) drop the unnecessary heads.

## Accounting for different embedding dims:
* Treat every modality as having the largest embedding dim of all the modalities.
* Zero out the additional embedding dims wherever they are not needed (masking them later does not suffice). This matters because the padded dims would otherwise change the dot product itself, and therefore the softmax applied after it.
* After computing the scaled dot-product attention, there is no need to mask the extra dims of the value vectors in the last projection, because q, k and v are already zero there. Just slice the output accordingly.


In [20]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import pdb

In [21]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
print(device)

cuda


Custom linear layer that masks connections between different modalities. During encoding, the different modalities must be encoded independently of each other.

In [22]:
class CustomLinear(nn.Module):
    """Linear layer (no bias by default) whose weight can be masked element-wise.

    The optional mask is multiplied into the weight before the matmul, so a
    zero entry cuts the connection from that input feature to that output
    feature. The fused attention layers use this to keep modalities from
    mixing inside a shared projection.

    The weight is fixed to all ones (a plain tensor, not an ``nn.Parameter``)
    so the fused and per-modality implementations can be checked for equal
    outputs; a random seed would not help because the matrices differ in size.
    Change this to a random initialisation once the check is no longer needed.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        mask: optional ``(out_features, in_features)`` tensor multiplied
            element-wise into the weight; ``None`` disables masking.
        bias: if True, add a learnable bias initialised to zeros.
    """

    def __init__(self, in_features, out_features, mask=None, bias=False):
        super().__init__()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.weight = torch.ones(out_features, in_features).to(device)
        if bias:
            # Zero-initialised: torch.Tensor(n) would hold uninitialised memory.
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.bias = None
        self.mask = mask

    def forward(self, input):
        # The weight and mask are plain tensors, so Module.to() does not move
        # them; follow the input's device instead (cached after the first move).
        self.weight = self.weight.to(input.device)
        weight = self.weight
        if self.mask is not None:
            self.mask = self.mask.to(input.device)
            # Masked entries remove connections between modalities.
            weight = weight * self.mask
        output = input.matmul(weight.t())
        if self.bias is not None:
            output = output + self.bias
        return output

Standard multi-head attention (dropout is omitted to remove randomness, so the outputs can be verified).

In [23]:
class AttentionLayer(nn.Module):
    """Single-head scaled dot-product attention built on one packed QKV projection.

    Also serves as the base of ``MultiHeadAttentionLayer``, which reuses
    ``qkv_proj`` and overrides ``forward``.

    Args:
        dim_head: embedding size of one head.
        num_heads: number of heads.
        num_modes: number of modalities sharing this layer.
        ip_dim: input feature size of one modality.
        dropout: probability for ``self.dropout``; the module is created but
            not applied, so outputs stay deterministic for verification.
    """

    def __init__(self, dim_head, num_heads, num_modes, ip_dim, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.dim_head = dim_head
        self.ip_dim = ip_dim
        # Width of each of Q, K and V.
        self.embed_dim = self.num_heads * self.dim_head * num_modes
        # A single projection emits Q, K and V side by side: [Q | K | V].
        self.qkv_proj = CustomLinear(self.ip_dim * num_modes, self.embed_dim * 3)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, attn_mask=None):
        """Attend from ``query`` to ``key``/``value`` with a single head.

        Args:
            query: ``(N, S, ip_dim * num_modes)`` tensor.
            key, value: ``(N, T, ip_dim * num_modes)`` tensors.
            attn_mask: optional tensor broadcastable to ``(N, S, T)``;
                1 = attend, 0 = blocked (blocked scores get a -1e9 penalty).

        Returns:
            ``(N, S, embed_dim)`` attention output.
        """
        width = self.embed_dim
        # Every input goes through the shared packed projection; keep the
        # Q, K or V third respectively.
        q = self.qkv_proj(query)[..., :width]
        k = self.qkv_proj(key)[..., width:2 * width]
        v = self.qkv_proj(value)[..., 2 * width:]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(width)
        if attn_mask is not None:
            scores = scores + ((1 - attn_mask) * -1e9).to(scores.device)
        # Normalise the scores into attention weights over the key positions.
        return torch.matmul(F.softmax(scores, dim=-1), v)

class MultiHeadAttentionLayer(AttentionLayer):
    """Reference multi-head self-attention followed by a two-stage head projection.

    One instance is created per modality and the instances are run one after
    another. ``x`` has shape ``(N, S, ip_dim * num_modes)``, and so does the
    output. An optional ``attn_mask`` (1 = keep, 0 = block) becomes an
    additive -1e9 penalty on the blocked scores.
    """

    def __init__(self, dim_head, num_heads, num_modes, ip_dim, dropout=0.1):
        super().__init__(dim_head, num_heads, num_modes, ip_dim, dropout)
        self.num_heads = num_heads
        self.dim_head = dim_head
        self.num_modes = num_modes
        self.embed_dim = num_heads * dim_head * num_modes
        hidden_width = self.embed_dim * 4
        # Expand the attention output 4x, then map back to the input width.
        self.head_proj1 = CustomLinear(self.embed_dim, hidden_width)
        self.head_proj2 = CustomLinear(hidden_width, ip_dim * num_modes)

    def forward(self, x, attn_mask=None):
        batch, seq_len, _ = x.shape
        n_heads = self.num_heads
        packed_width = self.embed_dim * 3
        head_width = packed_width // (3 * n_heads)

        packed = self.qkv_proj(x)
        packed = packed.view(batch, seq_len, 3 * n_heads, head_width)
        packed = packed.transpose(1, 2)          # N x 3H x S x head_width
        q, k, v = torch.chunk(packed, 3, dim=1)  # H query heads, H key heads, H value heads

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.dim_head)
        if attn_mask is not None:
            penalty = (1 - attn_mask) * -1e9
            scores += penalty.to(q.device)

        weights = F.softmax(scores, dim=-1)
        context = torch.matmul(weights, v).transpose(1, 2)
        merged = context.reshape(batch, seq_len, packed_width // 3)
        hidden = self.head_proj1(merged)
        return self.head_proj2(hidden)

### Different number of heads and embedding dim for each modality

In [24]:
# One reference layer per modality. dim_head (the per-head embedding dim) and
# num_heads are both 2, 4 and 6 for modalities 1, 2 and 3 respectively.
multi_attn1, multi_attn2, multi_attn3 = (
    MultiHeadAttentionLayer(dim_head=size, num_heads=size, num_modes=1, ip_dim=5).to(device)
    for size in (2, 4, 6)
)

In [25]:
torch.manual_seed(0)
# One random input per modality: batch size 2, sequence length 4, dim 5.
x1 = torch.rand(2, 4, 5).to(device)
x2 = torch.rand(2, 4, 5).to(device)
x3 = torch.rand(2, 4, 5).to(device)
# Lower-triangular mask, used for modality 1 only.
attn_mask1 = torch.tril(torch.ones(4, 4)).to(device)

# Encode the modalities one after the other with their own layers.
x1 = multi_attn1(x1, attn_mask1)
x2 = multi_attn2(x2)
x3 = multi_attn3(x3)
out = torch.cat((x1, x2, x3), -1)  # Shape --> [N x S x (Ip_dim * num_modes)]

In [26]:
x1.shape  # modality 1 output after its layer: (batch, seq_len, ip_dim)

torch.Size([2, 4, 5])

In [27]:
out.shape # Shape --> [N x S x (Ip_dim * num_modes)]: the three modality outputs concatenated on the last axis

torch.Size([2, 4, 15])

 ### Multi-head attention with parallel processing of modalities
 * The mask must have shape (out_features x in_features) --> see the CustomLinear class.

In [28]:
class FuseAttentionLayer(nn.Module):
    """Base layer that processes several modalities in one fused projection.

    Every modality is padded to the largest head count and the largest
    per-head dimension, so that all modalities share one packed QKV
    projection. Block-diagonal weight masks keep the modalities from mixing:

    * ``mask`` (for ``qkv_proj``): each modality's Q/K/V rows read only that
      modality's ``ip_dim`` input columns.
    * ``mask1`` (for ``head_proj1``): each modality's kept attention features
      feed only that modality's 4x-wide hidden block.
    * ``mask2`` (for ``head_proj2``): each hidden block writes only that
      modality's ``ip_dim`` output columns.

    Args:
        dim_heads_list: per-head embedding size of each modality.
        num_heads_list: number of heads of each modality.
        num_modes: number of modalities; must equal the length of both lists.
        ip_dim: input (and output) feature size of every modality.
        dropout: probability for ``self.dropout``; the module is created but
            not applied, so outputs stay deterministic.

    Raises:
        ValueError: if either list does not have ``num_modes`` entries.
    """

    def __init__(self, dim_heads_list, num_heads_list, num_modes, ip_dim, dropout=0.1):
        super().__init__()
        if len(dim_heads_list) != num_modes or len(num_heads_list) != num_modes:
            raise ValueError(
                "dim_heads_list and num_heads_list must each have num_modes entries"
            )
        self.num_heads = max(num_heads_list)
        self.dim_heads = max(dim_heads_list)
        self.dim_heads_list = dim_heads_list
        self.num_heads_list = num_heads_list
        self.num_modes = num_modes
        self.ip_dim = ip_dim
        self.embed_dim = self.num_heads * self.dim_heads * num_modes

        # Q, K and V rows produced per modality (padded to max heads / dims).
        qkv_rows = self.num_heads * self.dim_heads * 3
        mask = torch.block_diag(*[torch.ones(qkv_rows, ip_dim)] * num_modes)

        # Attention features each modality keeps, and the width of its hidden
        # block in the head projection (4x expansion).
        kept = [h * d for h, d in zip(num_heads_list, dim_heads_list)]
        hidden = [4 * k for k in kept]
        # Masks are (out_features x in_features), as CustomLinear expects.
        mask1 = torch.block_diag(*[torch.ones(hid, k) for hid, k in zip(hidden, kept)])
        mask2 = torch.block_diag(*[torch.ones(ip_dim, hid) for hid in hidden])

        # *3 accounts for Q, K and V coming out of the same projection.
        self.qkv_proj = CustomLinear(self.ip_dim * num_modes, self.embed_dim * 3, mask)
        self.dropout = nn.Dropout(dropout)
        proj1_input = sum(kept)
        self.head_proj1 = CustomLinear(proj1_input, proj1_input * 4, mask1)
        self.head_proj2 = CustomLinear(proj1_input * 4, ip_dim * num_modes, mask2)

    def forward(self, query, key, value, attn_mask=None):
        # The fused Q/K/V layout is mode-major, so a generic single-head
        # attention over it is not meaningful; subclasses implement forward.
        raise NotImplementedError(
            "FuseAttentionLayer is a base class; use FuseMultiHeadAttentionLayer"
        )

class FuseMultiHeadAttentionLayer(FuseAttentionLayer):
    """Multi-head self-attention over all modalities in one fused pass.

    Intended to match running one ``MultiHeadAttentionLayer`` per modality and
    concatenating the outputs, but with a single projection/attention call for
    all modalities. Input and output are ``(N, S, ip_dim * num_modes)``, with
    the modalities concatenated along the last axis. An optional ``attn_mask``
    (1 = attend, 0 = blocked) is applied to the first modality only, as in
    CLIP, where text attention is masked and image attention is not.
    """

    def __init__(self, dim_heads_list, num_heads_list, num_modes, ip_dim, dropout=0.1):
        super().__init__(dim_heads_list, num_heads_list, num_modes, ip_dim, dropout)
        self.num_heads = max(num_heads_list)
        self.dim_heads = max(dim_heads_list)
        self.num_modes = num_modes
        self.embed_dim = self.num_heads * self.dim_heads * self.num_modes
        self.num_heads_list = num_heads_list
        self.dim_heads_list = torch.tensor(dim_heads_list)

        # After attention the features are laid out mode-major, then head,
        # then per-head dim, each mode padded to num_heads x dim_heads. Keep
        # the first num_heads_list[m] heads of mode m and, inside each head,
        # its first dim_heads_list[m] features.
        keep = []
        for mode, (n_heads, d_head) in enumerate(zip(num_heads_list, dim_heads_list)):
            for head in range(n_heads):
                start = (mode * self.num_heads + head) * self.dim_heads
                keep.append(torch.arange(start, start + d_head))
        self.dims_keep = torch.cat(keep)

    def forward(self, x, attn_mask=None):
        H = self.num_heads
        M = self.num_modes
        N, S, _ = x.shape
        D = self.embed_dim * 3          # width of the packed Q/K/V projection
        d_pad = D // (M * H * 3)        # padded per-head dim

        qkv = self.qkv_proj(x)                                    # N x S x D
        qkv = qkv.view(N, S, M, D // M).transpose(-3, -2)        # N x M x S x D/M
        qkv = qkv.view(N, M, S, 3 * H, d_pad).transpose(-3, -2)  # N x M x 3H x S x d_pad
        query, key, value = torch.chunk(qkv, 3, dim=2)           # each N x M x H x S x d_pad

        # Zero each modality's padded feature slots so they cannot change the
        # dot products (masking after the dot product would be too late).
        for m, d_head in enumerate(self.dim_heads_list):
            query[:, m, :, :, d_head:] = 0
            key[:, m, :, :, d_head:] = 0
            value[:, m, :, :, d_head:] = 0

        scores = torch.matmul(query, key.transpose(-2, -1))      # N x M x H x S x S
        # Scale every modality by its own (unpadded) head dim.
        per_mode = [
            chunk / math.sqrt(d_head)
            for chunk, d_head in zip(torch.chunk(scores, M, dim=1), self.dim_heads_list)
        ]
        if attn_mask is not None:
            # Only the first modality is masked.
            per_mode[0] = per_mode[0] + ((1 - attn_mask) * -1e9).to(query.device)
        scores = torch.cat(per_mode, dim=1)

        y = torch.matmul(F.softmax(scores, dim=-1), value)        # N x M x H x S x d_pad
        out = y.transpose(-3, -2).reshape(N, M, S, D // (3 * M))  # N x M x S x H*d_pad
        out = out.transpose(-3, -2).reshape(N, S, D // 3)         # N x S x M*H*d_pad
        # Drop padded heads and padded per-head dims before the projection.
        out = out[:, :, self.dims_keep.to(out.device)]
        out = self.head_proj1(out)                                # N x S x 4*len(dims_keep)
        return self.head_proj2(out)                               # N x S x ip_dim*num_modes

In [29]:
torch.manual_seed(0)
# Same seed as the per-modality run, so these are the same three inputs
# (batch size 2, sequence length 4, ip_dim 5).
x1 = torch.rand(2, 4, 5)
x2 = torch.rand(2, 4, 5)
x3 = torch.rand(2, 4, 5)
attn_mask = torch.tril(torch.ones(4, 4)).to(device)
num_heads_list = [2, 4, 6]  # heads per modality
dim_heads_list = [2, 4, 6]  # per-head embedding dim per modality
fuse_multi_attn = FuseMultiHeadAttentionLayer(
    dim_heads_list=dim_heads_list,
    num_heads_list=num_heads_list,
    num_modes=3,
    ip_dim=5,
).to(device)
# All modalities concatenated along the feature axis.
x_cat = torch.cat((x1, x2, x3), -1).to(device)
x_cat = fuse_multi_attn(x_cat, attn_mask)
out_fuse = x_cat

In [30]:
out_fuse.shape  # should equal out.shape from the per-modality run

torch.Size([2, 4, 15])

In [31]:
torch.allclose(out_fuse, out) # check that the fused and per-modality implementations give the same output

True

Time Comparison

In [32]:
import time

In [33]:
# Benchmark the standard implementation: three per-modality layers run one
# after the other, 1000 times. CUDA kernels run asynchronously, so
# synchronize before reading the clock; otherwise the timer can stop before
# queued GPU work has finished. perf_counter is the clock meant for intervals.
torch.manual_seed(0)
if device.type == "cuda":
    torch.cuda.synchronize()
start_time1 = time.perf_counter()
for _ in range(1000):
    x1 = torch.rand(2, 4, 5).to(device)
    x2 = torch.rand(2, 4, 5).to(device)
    x3 = torch.rand(2, 4, 5).to(device)

    out1 = multi_attn1(x1)
    out2 = multi_attn2(x2)
    out3 = multi_attn3(x3)

    out = torch.cat((out1, out2, out3), -1)
if device.type == "cuda":
    torch.cuda.synchronize()
end_time1 = time.perf_counter()
time_1 = end_time1 - start_time1

In [34]:
# Benchmark the fused implementation: all modalities in one layer call,
# 1000 times. Synchronize CUDA around the timed region so queued GPU work is
# included; perf_counter is the clock meant for measuring intervals.
torch.manual_seed(0)
if device.type == "cuda":
    torch.cuda.synchronize()
start_time2 = time.perf_counter()
for _ in range(1000):
    x1 = torch.rand(2, 4, 5)
    x2 = torch.rand(2, 4, 5)
    x3 = torch.rand(2, 4, 5)

    x_cat = torch.cat((x1, x2, x3), -1).to(device)
    out_fuse = fuse_multi_attn(x_cat)
if device.type == "cuda":
    torch.cuda.synchronize()
end_time2 = time.perf_counter()
time_2 = end_time2 - start_time2

In [35]:
# Elapsed seconds, rounded to two decimals for display.
time_std_multihead, time_my_multihead = round(time_1, 2), round(time_2, 2)

In [36]:
time_std_multihead # time for standard implementation (three per-modality layers, 1000 iterations, seconds)

8.95

In [37]:
time_my_multihead # time for modified (fused) implementation (1000 iterations, seconds)

8.07

In [38]:
# Speedup of the fused layer over running the three per-modality layers one
# after the other. The hoped-for ceiling with 3 modalities is about 3x,
# because the fused layer makes one set of projection/attention calls for all
# modalities instead of one set per modality. That ceiling is unverified; the
# measured ratio in the run shown is only ~1.1x.
time_std_multihead/time_my_multihead

1.1090458488228003