# Attention Is All You Need
https://arxiv.org/abs/1706.03762

## Self-Attention Block

![Screenshot%20from%202021-08-20%2013-28-04.png](attachment:Screenshot%20from%202021-08-20%2013-28-04.png)



![Screenshot%20from%202021-08-20%2013-28-35.png](attachment:Screenshot%20from%202021-08-20%2013-28-35.png)

In [2]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding is split into ``heads`` chunks of size ``head_dim``. Each chunk is
    projected by the value/key/query layers and attended independently. The heads are
    then concatenated and mixed by ``fc_out``.

    NOTE: the value/key/query projections are ``head_dim -> head_dim`` layers whose
    weights are shared by every head. This simplifies the paper's separate per-head
    ``d_model -> d_k`` projections.

    Args:
        embed_size: model dimension of the inputs and output; must be divisible by ``heads``.
        heads: number of attention heads (parts the embedding is split into).
    """

    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size or head size not correct"

        # Projections for values, keys and queries, applied to each head's slice.
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # Output projection that mixes the concatenated heads.
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over ``keys``/``values``.

        Args:
            values: tensor of shape (N, value_len, embed_size).
            keys: tensor of shape (N, key_len, embed_size); key_len must equal value_len.
            query: tensor of shape (N, query_len, embed_size).
            mask: None, or a tensor broadcastable to (N, heads, query_len, key_len).
                Positions where ``mask == 0`` receive (effectively) zero attention weight.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads pieces: (N, len, heads, head_dim).
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # Apply the learned projections to every head's slice.
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # queries: (N, query_len, heads, head_dim); keys: (N, key_len, heads, head_dim)
        # energy:  (N, heads, query_len, key_len) -- per-head query/key dot products.
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        if mask is not None:
            # Large negative fill so masked positions get ~0 weight after softmax.
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k) = sqrt(head_dim) as in the paper, then normalise over keys.
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # attention: (N, heads, query_len, key_len); values: (N, value_len, heads, head_dim)
        # einsum -> (N, query_len, heads, head_dim), then flatten the last two dims.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)

## Transformer block

![Screenshot%20from%202021-08-20%2017-49-06.png](attachment:Screenshot%20from%202021-08-20%2017-49-06.png)