In [2]:
import torch
from torch import nn
import torch.nn.functional as F

## Self Attention (basic)

In [5]:
# attention class
class SelfAttention(nn.Module):
    """Multi-head self-attention with "wide" heads.

    Every head projects the input to a full ``k``-dimensional query, key and
    value, so each projection maps ``k -> k * heads``. The per-head outputs
    are concatenated and projected back to ``k`` by ``self.combine``.

    Args:
        k: embedding size of the input (the last dimension of ``x``).
        heads: number of attention heads.
    """

    def __init__(self, k, heads=8):
        super().__init__()
        self.k, self.heads = k, heads

        # Q, K and V projections for all heads at once: each output holds the
        # ``heads`` per-head vectors of size k concatenated together.
        self.query = nn.Linear(k, k * heads, bias=False)
        self.keys = nn.Linear(k, k * heads, bias=False)
        self.values = nn.Linear(k, k * heads, bias=False)

        # Project the concatenated head outputs back down to k dimensions.
        self.combine = nn.Linear(k * heads, k)

    def forward(self, x):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor of shape (batch, seq_len, k).

        Returns:
            Tensor of shape (batch, seq_len, k).

        Raises:
            ValueError: if the last dimension of ``x`` is not ``self.k``.
        """
        batch_size, seq_len, k = x.size()
        h = self.heads
        if k != self.k:
            raise ValueError(
                f"expected input embedding size {self.k}, got {k}"
            )

        # Each projection yields (b, t, h*k); reshape to (b, t, h, k) so every
        # head gets its own dimension.
        queries = self.query(x).view(batch_size, seq_len, h, k)
        keys = self.keys(x).view(batch_size, seq_len, h, k)
        values = self.values(x).view(batch_size, seq_len, h, k)

        # The attention computation is identical for every head, so fold the
        # heads into the batch dimension: (b, t, h, k) -> (b, h, t, k) -> (b*h, t, k).
        queries = queries.transpose(1, 2).contiguous().view(batch_size * h, seq_len, k)
        keys = keys.transpose(1, 2).contiguous().view(batch_size * h, seq_len, k)
        values = values.transpose(1, 2).contiguous().view(batch_size * h, seq_len, k)

        # Instead of dividing the dot products by sqrt(k), scale the queries
        # and the keys by k**(1/4) each; the product is scaled by sqrt(k) overall.
        queries = queries / (k ** (1 / 4))
        keys = keys / (k ** (1 / 4))

        # Raw attention weights: (b*h, t, t).
        dot = torch.bmm(queries, keys.transpose(1, 2))

        # Normalize over the key positions.
        dot = F.softmax(dot, dim=2)

        # Weighted sum of values: (b*h, t, k). Unfold to (b, h, t, k), which
        # matches the layout produced by the fold above.
        out = torch.bmm(dot, values).view(batch_size, h, seq_len, k)

        # Swap heads and sequence back, (b, h, t, k) -> (b, t, h, k), then
        # concatenate the heads into (b, t, h*k) and project to (b, t, k).
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, h * k)
        return self.combine(out)