In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

Self-attention layer

torch.bmm(input, mat2, *, deterministic=False, out=None) → Tensor
Performs a batch matrix-matrix product of matrices stored in input and mat2.

input and mat2 must be 3-D tensors each containing the same number of matrices.

If input is a $(b \times n \times m)$ tensor, mat2 is a $(b \times m \times p)$ tensor, out will be a $(b \times n \times p)$ tensor.

$$\text{out}_i = \text{input}_i \mathbin{@} \text{mat2}_i$$

In [None]:
class SelfAttention(nn.Module):
  """Multi-head self-attention in which every head works at the full width ``k``.

  Each of the ``heads`` heads gets its own ``k``-dimensional query, key and
  value projection (so each projection maps ``k -> k * heads``). The
  concatenated head outputs are mapped back to ``k`` by ``unifyheads``.

  Args:
    k: embedding dimension of the input (last axis of ``x``).
    heads: number of attention heads.
  """

  def __init__(self, k, heads=8):
    super(SelfAttention, self).__init__()
    self.k, self.heads = k, heads

    self.tokeys = nn.Linear(k, k * heads, bias=False)
    self.toqueries = nn.Linear(k, k * heads, bias=False)
    self.tovalues = nn.Linear(k, k * heads, bias=False)
    self.unifyheads = nn.Linear(k * heads, k)

  def forward(self, x):
    """Apply multi-head self-attention.

    Args:
      x: tensor of shape ``(b, t, k)`` (batch, sequence length, embedding).

    Returns:
      Tensor of shape ``(b, t, k)``.

    Raises:
      ValueError: if the last dimension of ``x`` is not ``self.k``.
    """
    b, t, k = x.size()
    h = self.heads
    if k != self.k:
      raise ValueError(f"expected input embedding dim {self.k}, got {k}")

    queries = self.toqueries(x).view(b, t, h, k)
    keys = self.tokeys(x).view(b, t, h, k)
    values = self.tovalues(x).view(b, t, h, k)

    # Fold heads into the batch dimension so one bmm handles all heads at once.
    # transpose() yields a non-contiguous tensor, so contiguous() must precede view().
    queries = queries.transpose(1, 2).contiguous().view(b * h, t, k)
    keys = keys.transpose(1, 2).contiguous().view(b * h, t, k)
    values = values.transpose(1, 2).contiguous().view(b * h, t, k)

    # Scale queries and keys by k**(1/4) each (i.e. the dot product by sqrt(k))
    # so the softmax inputs do not grow with the embedding size.
    queries = queries / (k ** (1 / 4))
    keys = keys / (k ** (1 / 4))

    # Raw attention weights w_ij = q_i . k_j, shape (b*h, t, t).
    dot = torch.bmm(queries, keys.transpose(1, 2))
    # Softmax over the key axis so each query's weights sum to 1.
    dot = F.softmax(dot, dim=2)

    # Weighted sum of values, then unfold heads: (b*h, t, k) -> (b, h, t, k).
    out = torch.bmm(dot, values).view(b, h, t, k)
    # Swap h and t back; again contiguous() is required before view().
    out = out.transpose(1, 2).contiguous().view(b, t, h * k)

    # Merge the concatenated heads down to a k-dimensional output.
    return self.unifyheads(out)


Simple transformer block
inputs -> self-attention -> (add residual) layer normalization -> MLP -> (add residual) layer normalization -> outputs

In [None]:
class TransformerBlock(nn.Module):
  """Post-norm transformer block.

  Self-attention and a position-wise feed-forward MLP are each wrapped in a
  residual connection followed by layer normalization.

  Args:
    k: embedding dimension.
    heads: number of attention heads passed to ``SelfAttention``.
  """

  def __init__(self, k, heads):
    super(TransformerBlock, self).__init__()

    self.attention = SelfAttention(k, heads=heads)

    # LayerNorm over the last (embedding) axis of size k.
    self.norm1 = nn.LayerNorm(k)
    self.norm2 = nn.LayerNorm(k)

    # Position-wise MLP with a 4x hidden expansion.
    self.ff = nn.Sequential(
        nn.Linear(k, 4 * k),
        nn.ReLU(),
        nn.Linear(4 * k, k)
    )

  def forward(self, x):
    """Map ``x`` of shape ``(b, t, k)`` to a tensor of the same shape."""
    attended = self.attention(x)
    x = self.norm1(attended + x)
    fedforward = self.ff(x)

    return self.norm2(fedforward + x)

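Classification transformer: token and position embeddings, a stack of transformer blocks, mean pooling over the sequence, and a linear classifier.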

In [None]:
class Transformer(nn.Module):
  """Sequence classifier built from stacked transformer blocks.

  Pipeline: token embedding + learned position embedding -> ``depth``
  ``TransformerBlock``s -> mean over the sequence -> linear layer ->
  log-softmax.

  Args:
    k: embedding dimension.
    heads: attention heads per block.
    depth: number of transformer blocks.
    seq_length: size of the position-embedding table (maximum sequence length).
    num_tokens: vocabulary size.
    num_class: number of output classes.
  """

  def __init__(self, k, heads, depth, seq_length, num_tokens, num_class):
    super(Transformer, self).__init__()

    self.num_tokens = num_tokens
    self.token_emb = nn.Embedding(num_embeddings=num_tokens, embedding_dim=k)
    self.pos_emb = nn.Embedding(num_embeddings=seq_length, embedding_dim=k)

    self.tblocks = nn.Sequential(
        *[TransformerBlock(k=k, heads=heads) for _ in range(depth)]
    )

    self.toprobs = nn.Linear(k, num_class)

  def forward(self, x):
    """Classify a batch of token-index sequences.

    Args:
      x: integer tensor of shape ``(b, t)`` with ``t <= seq_length``.

    Returns:
      Log-probabilities of shape ``(b, num_class)``.
    """
    tokens = self.token_emb(x)
    b, t, k = tokens.size()

    # Position indices 0..t-1 must be integers on the input's device;
    # torch.Tensor(t) would instead allocate t uninitialized floats.
    positions = torch.arange(t, device=x.device)
    positions = self.pos_emb(positions)[None, :, :].expand(b, t, k)

    x = tokens + positions
    x = self.tblocks(x)
    # Average over the sequence axis: one vector per example.
    x = x.mean(dim=1)
    x = self.toprobs(x)

    return F.log_softmax(x, dim=1)


# Backward-compatible alias for the original (misspelled) class name.
Trasformer = Transformer
