In [1]:
import torch
from torch import nn
import torch.nn.functional as F

# self attention layer
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values of width ``k``. Each
    projection is split into ``heads`` chunks of width ``k // heads`` that
    attend independently; the per-head results are concatenated and passed
    through a final linear layer.

    Args:
        k: Embedding width of the input and of the output.
        heads: Number of attention heads; must divide ``k`` evenly.
        mask: If True, apply a causal mask so that position ``i`` only
            attends to positions ``<= i``.
    """

    def __init__(self, k, heads=4, mask=False):
        super().__init__()
        assert k % heads == 0, f'embedding width {k} is not divisible by {heads} heads'
        self.k, self.heads, self.mask = k, heads, mask
        self.tokeys = nn.Linear(k, k, bias=False)
        self.toqueries = nn.Linear(k, k, bias=False)
        self.tovalues = nn.Linear(k, k, bias=False)
        self.unifyheads = nn.Linear(k, k)

    def forward(self, x):
        """Apply self-attention to ``x``.

        :param x: A (b, t, k) tensor, e.g. (batch_size=2, sequence_length=10, k=6).
        :return: A (b, t, k) tensor.
        """
        b, t, k = x.size()  # t is the sequence length, k the embedding width
        h = self.heads
        s = k // h  # width of a single head

        def fold_heads(projected):
            # (b, t, k) -> (b * h, t, s): heads are folded into the batch
            # dimension so that one bmm call handles every head.
            return projected.view(b, t, h, s).transpose(1, 2).contiguous().view(b * h, t, s)

        queries = fold_heads(self.toqueries(x))
        keys = fold_heads(self.tokeys(x))
        values = fold_heads(self.tovalues(x))

        # Raw attention logits, shape (b * h, t, t). Each dot product runs over
        # the s features of one head, so the scale is sqrt(s), not sqrt(k).
        dot = torch.bmm(queries, keys.transpose(1, 2)) / (s ** 0.5)

        if self.mask:
            # Entries strictly above the diagonal are "future" positions; set
            # them to -inf so they receive zero weight after the softmax.
            future = torch.triu(
                torch.ones(t, t, dtype=torch.bool, device=dot.device), diagonal=1
            )
            dot = dot.masked_fill(future, float('-inf'))

        dot = F.softmax(dot, dim=2)

        # Weighted sum of values, then unfold the heads: (b * h, t, s) -> (b, t, k).
        out = torch.bmm(dot, values).view(b, h, t, s)
        out = out.transpose(1, 2).contiguous().view(b, t, s * h)
        return self.unifyheads(out)

#### Transformer block

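* A block consists of a self-attention layer and a feed-forward (MLP) layer, each followed by a residual connection and layer normalization. The stacked blocks learn the intermediate representations.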

In [2]:
class TransformerBlock(nn.Module):
    """One transformer block: self-attention followed by a two-layer MLP.

    Each sub-layer's output is added to its input (residual connection) and
    the sum is layer-normalised.

    Args:
        k: Embedding width of the input and of the output.
        heads: Number of attention heads handed to ``SelfAttention``.
    """

    def __init__(self, k, heads):
        super().__init__()
        hidden = 4 * k  # width of the MLP's hidden layer

        self.attention = SelfAttention(k, heads=heads)
        self.norm1 = nn.LayerNorm(k)
        self.norm2 = nn.LayerNorm(k)
        self.ff = nn.Sequential(
            nn.Linear(k, hidden),
            nn.ReLU(),
            nn.Linear(hidden, k),
        )

    def forward(self, x):
        """Map a (b, t, k) tensor to a (b, t, k) tensor."""
        # Sub-layer 1: attention, skip connection, normalise.
        x = self.norm1(self.attention(x) + x)
        # Sub-layer 2: feed-forward, skip connection, normalise.
        return self.norm2(self.ff(x) + x)

#### Position and tokens embedding

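* For classification: token and position embeddings are summed, passed through the transformer blocks, then averaged over the sequence and projected to class scores.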

In [29]:
class Transformer(nn.Module):
    """Sequence classifier: embeddings, a stack of transformer blocks, a linear head.

    Args:
        k: Embedding width.
        heads: Number of attention heads in every block.
        depth: Number of transformer blocks.
        seq_length: Number of rows in the position-embedding table, i.e. the
            longest sequence the model accepts.
        num_tokens: Vocabulary size (rows in the token-embedding table).
        num_classes: Number of output classes.
        verbose: If True (the default), ``forward`` prints the intermediate
            tensor shapes. Pass False to silence it, e.g. in a training loop.
    """

    def __init__(self, k, heads, depth, seq_length, num_tokens, num_classes, verbose=True):
        super().__init__()

        self.num_tokens = num_tokens
        self.verbose = verbose
        self.token_emb = nn.Embedding(num_tokens, k)
        self.pos_emb = nn.Embedding(seq_length, k)

        # The sequence of transformer blocks that does all the heavy lifting
        self.tblocks = nn.Sequential(
            *[TransformerBlock(k=k, heads=heads) for _ in range(depth)]
        )

        # Maps the final output sequence to class logits
        self.toprobs = nn.Linear(k, num_classes)

    def _trace(self, *args):
        """Print ``args`` when verbose tracing is enabled."""
        if self.verbose:
            print(*args)

    def forward(self, x):
        """
        :param x: A (b, t) tensor of integer values representing
                  words (in some predetermined vocabulary).
        :return: A (b, c) tensor of log-probabilities over the
                 classes (where c is the nr. of classes).
        :raises IndexError: if t exceeds the number of position embeddings.
        """
        self._trace(f'input size is {x.size()}')
        # generate token embeddings
        tokens = self.token_emb(x)
        self._trace(f'encoded token size is {tokens.size()}')
        b, t, k = tokens.size()  # t is the sequence length, k the embedding width

        # Fail early with a readable message instead of an opaque
        # out-of-range error from the embedding lookup.
        max_len = self.pos_emb.num_embeddings
        if t > max_len:
            raise IndexError(
                f'sequence length {t} exceeds the {max_len} available position embeddings'
            )

        # generate position embeddings; the indices are created on the input's
        # device so the lookup also works when the model is not on the CPU
        positions = torch.arange(t, device=x.device)
        self._trace(positions)
        self._trace(f'The shape of position encoding space is {positions.size()}')
        positions = self.pos_emb(positions)[None, :, :].expand(b, t, k)
        self._trace(f'the shape of embedding position is {positions.size()}')

        x = tokens + positions
        self._trace(f'shape of x input to transformer block is {x.shape}')
        x = self.tblocks(x)

        # Average-pool over the t dimension and project to class
        # probabilities
        x = self.toprobs(x.mean(dim=1))
        return F.log_softmax(x, dim=1)

In [30]:
# Fixed seed so the random input and the weight initialisation are
# reproducible on a fresh kernel.
torch.manual_seed(0)

batch_size = 11
seq_len = 5
voc_size = 10
embed_token = 9
num_head = 3
depth = 4
num_classes = 2

# Random token ids in [1, voc_size); id 0 is never sampled here.
x = torch.randint(low=1, high=voc_size, size=(batch_size, seq_len))

# Build the classifier from the hyperparameters defined above, so the model
# and the sample input cannot drift apart.
transformer = Transformer(
    k=embed_token,
    heads=num_head,
    depth=depth,
    seq_length=seq_len,
    num_tokens=voc_size,
    num_classes=num_classes,
)

# Pass the input tensor to the forward method
output = transformer(x)

# One row of class log-probabilities per input sequence.
assert output.shape == (batch_size, num_classes)

input size is torch.Size([11, 5])
encoded token size is torch.Size([11, 5, 9])
tensor([0, 1, 2, 3, 4])
The shape of position encoding space is torch.Size([5])
the shape of embedding position is torch.Size([11, 5, 9])
shape of x input to transformer block is torch.Size([11, 5, 9])


#### General idea:

* The word sequence is fed to the embedding layers, which represent both the meaning of each token and its position in the sequence. The summed embeddings are then passed to the transformer blocks.