## Multiheaded attention mechanism

- The code below explains how to find the words that receive the most attention for an output label

In [151]:
import torch
import torch.nn as nn

In [152]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The last dimension of each input is split into ``heads`` chunks of
    ``head_dim = embed_size // heads`` features. The same bias-free linear
    maps (``values``, ``keys``, ``queries``) are applied to every head,
    attention is computed independently per head, and the heads are
    concatenated again before the final ``fc_out`` projection.

    Args:
        embed_size: Size of the last dimension of the inputs.
        heads: Number of attention heads; must divide ``embed_size``.

    Raises:
        ValueError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        if embed_size % heads != 0:
            # Without this check the failure only shows up later as an
            # opaque reshape error inside forward().
            raise ValueError(
                "embed_size ({}) must be divisible by heads ({})".format(
                    embed_size, heads
                )
            )
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        """Apply attention of ``query`` over ``keys``/``values``.

        Args:
            values: Tensor reshaped here to (N, value_len, heads, head_dim).
            keys: Tensor reshaped here to (N, key_len, heads, head_dim).
            query: Tensor reshaped here to (N, query_len, heads, head_dim).

        Returns:
            A tuple ``(out1, out2)``:
            ``out1`` is the concatenated per-head attention-weighted values,
            shape (N, query_len, heads * head_dim), *before* the output
            projection (these are weighted values, not the softmax weights);
            ``out2`` is ``fc_out(out1)``, shape (N, query_len, embed_size).
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding dimension into separate heads.
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        query = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(query)

        # Query/key dot products per head: (N, heads, query_len, key_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # Scale by sqrt(d_k), where d_k is the per-head dimension the dot
        # product was taken over (head_dim), then normalise over the keys.
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # Weighted sum of the values, then concatenate the heads again.
        out1 = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )
        out2 = self.fc_out(out1)

        return out1, out2
    
class TransformerBlock(nn.Module):
    """Self-attention followed by a position-wise feed-forward network.

    Each of the two sub-layers is wrapped as
    ``dropout(layer_norm(sublayer_output + residual))``.

    Args:
        embed_size: Feature size of the inputs and outputs.
        heads: Number of attention heads passed to ``SelfAttention``.
        dropout: Dropout probability used after both normalisations.
        forward_expansion: Multiplier for the hidden width of the
            feed-forward network.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        hidden_size = forward_expansion * embed_size

        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query):
        """Run the block.

        Returns:
            A tuple ``(attention1, out)``: the first output of
            ``SelfAttention`` (its pre-projection result), and the block
            output after both residual sub-layers.
        """
        attention1, projected = self.attention(value, key, query)

        # First sub-layer: residual around attention, normalise, dropout.
        attended = self.dropout(self.norm1(projected + query))

        # Second sub-layer: residual around the feed-forward network.
        out = self.dropout(self.norm2(self.feed_forward(attended) + attended))
        return attention1, out
    
    
class Encoder(nn.Module):
    """Token + position embeddings, one ``TransformerBlock``, and a classifier.

    Args:
        src_vocab_size: Number of rows in the word-embedding table.
        embed_size: Embedding / model feature size.
        num_layers: Accepted for interface compatibility but currently
            unused: exactly one ``TransformerBlock`` is built.
        heads: Number of attention heads.
        forward_expansion: Hidden-width multiplier of the feed-forward net.
        dropout: Dropout probability.
        max_length: Number of rows in the position-embedding table.
        seq_length: Sequence length the classifier head is sized for; the
            final linear layer takes ``seq_length * embed_size`` features,
            so inputs to ``forward`` must have exactly this length.
    """

    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout,
        max_length,
        seq_length,
    ):
        super().__init__()
        self.embed_size = embed_size
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        # NOTE: a single block is built regardless of ``num_layers``.
        self.layers = TransformerBlock(
            embed_size,
            heads,
            dropout=dropout,
            forward_expansion=forward_expansion,
        )

        self.dropout = nn.Dropout(dropout)
        # Classifier over the flattened sequence; the output size is fixed at 6.
        self.l1 = nn.Linear(in_features=seq_length * embed_size, out_features=6)

    def forward(self, x):
        """Classify a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids with shape (N, seq_length).

        Returns:
            A tuple ``(attention1, final1)``: the first output of the
            transformer block, and the softmax over the 6 classifier outputs
            with shape (N, 6).
        """
        N, seq_length = x.shape
        # Build the position indices on the same device as the input so the
        # embedding lookup also works when the model is not on the CPU.
        positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        attention1, out = self.layers(out, out, out)

        # Flatten (N, seq_length, embed_size) -> (N, seq_length * embed_size).
        out = out.reshape((N, -1))

        final1 = self.l1(out)
        final1 = torch.softmax(final1, dim=1)

        return attention1, final1
    
    


In [153]:
# A single training example: one sequence of 9 token ids.
x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0]])

# Deliberately tiny configuration so the intermediate shapes are easy to read.
model = Encoder(
    src_vocab_size=10,
    max_length=100,
    seq_length=9,
    embed_size=4,
    heads=2,
    num_layers=1,
    forward_expansion=1,
    dropout=0,
)

# The forward pass returns the block's first attention output and the class
# probabilities.
attention1, final1 = model(x)



In [154]:
# Show the shape of the class-probability tensor.
print("result shape", final1.shape, sep="\n")

# Drop all size-1 dimensions (here the batch dimension) before inspecting it.
attention1 = attention1.squeeze()
print("Scaled dot product attention shape", attention1.shape, sep="\n")

result shape
torch.Size([1, 6])
Scaled dot product attention shape
torch.Size([9, 4])


# Below is just an experiment on using the attention to get the most important words for each label
- Here, each class probability is multiplied by the attention

In [155]:
attention1 # Attention-weighted values (not the softmax weights): seq length by embedding (9 by 4)

tensor([[-5.4913e-04,  1.0866e-01,  3.8197e-02, -4.8900e-01],
        [ 1.6612e-02, -2.2768e-01,  3.7639e-02, -5.3477e-01],
        [ 2.5642e-03, -7.2879e-01,  3.7813e-02, -5.4182e-01],
        [ 1.8606e-02, -3.3708e-01,  3.7900e-02, -4.9620e-01],
        [ 6.1051e-03,  2.1642e-03,  3.8481e-02, -5.4020e-01],
        [ 2.1993e-03,  7.4170e-02,  3.8115e-02, -4.9558e-01],
        [-4.7350e-03,  1.7263e-01,  3.7927e-02, -4.8462e-01],
        [ 1.6801e-02, -2.3931e-01,  3.8206e-02, -4.9756e-01],
        [ 1.7887e-02, -3.8195e-01,  3.9050e-02, -4.4195e-01]],
       grad_fn=<SqueezeBackward0>)

In [156]:
# Class probabilities of the only example in the batch, as a plain Python list.
classes_list = final1[0].tolist()
print("classes probabiliteis list", classes_list, sep="\n")

classes probabiliteis list
[0.0628705620765686, 0.13802635669708252, 0.10018505156040192, 0.15317615866661072, 0.15744292736053467, 0.38829895853996277]


In [157]:
for class_idx, class_prob in enumerate(classes_list):
    # A `class_prob > 0.5` filter is left disabled because this is not a real
    # example.

    # Weight the attention output by this class's probability and score every
    # sequence position by summing over the embedding axis.
    # NOTE(review): class_prob comes from a softmax, so it is a positive scalar
    # and cannot change the ranking -- every class yields the same positions.
    position_scores = (attention1 * class_prob).sum(dim=1)

    # torch.topk returns (values, indices); only the two best positions are
    # reported.
    top_positions = torch.topk(position_scores, k=2).indices
    print("class ", class_idx)
    print("Important sequence Index ", top_positions.tolist())
    

class  0
Important sequence Index  [6, 0]
class  1
Important sequence Index  [6, 0]
class  2
Important sequence Index  [6, 0]
class  3
Important sequence Index  [6, 0]
class  4
Important sequence Index  [6, 0]
class  5
Important sequence Index  [6, 0]
