In [1]:
# https://www.youtube.com/watch?v=U0s0f995w14&t
# https://arxiv.org/pdf/1706.03762.pdf
# https://peterbloem.nl/blog/transformers 

import torch
import torch.nn.functional as F
import torch.nn as nn




In [2]:
# Basic (unparameterised) self-attention, as in the Peter Bloem blog post.
# A minibatch of b sequences, each of t vectors of dimension k, is a tensor
# of shape (b, t, k); here b=32, t=3, k=2.
x = torch.randn((32, 3, 2))

# Raw scores: dot product of every vector with every other vector of the
# same sequence -> shape (b, t, t).
raw_weights = x @ x.transpose(-2, -1)

# Turn each row of scores into a distribution over the t positions.
weights = raw_weights.softmax(dim=-1)

# Each output vector is the weighted sum of the inputs -> shape (b, t, k).
y = weights @ x

In [8]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention (Vaswani et al., 2017, eq. 1).

    The last dimension (``embed_size``) of each input is split into ``heads``
    chunks of ``head_dim = embed_size // heads``. Each chunk is projected by
    the ``queries``/``keys``/``values`` linear layers (which act on the
    ``head_dim`` axis, so their weights are shared by all heads), each head
    attends independently, and the concatenated heads are mixed by ``fc_out``.

    Args:
        embed_size: size of the last dimension of the inputs and the output.
        heads: number of attention heads; must divide ``embed_size``.
    """

    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()

        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert self.head_dim * heads == embed_size, "Embed size needs to be divided by heads \n "

        # Projections computing the keys, queries and values of every head.
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # Mixes the concatenated head outputs back to embed_size.
        self.fc_out = nn.Linear(self.head_dim * heads, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over ``keys``/``values``.

        Args:
            values, keys, query: presumably tensors of shape
                (N, len, embed_size); ``keys`` and ``values`` must have the
                same length because the final einsum shares that axis.
            mask: ``None`` or a tensor broadcastable to
                (N, heads, query_len, key_len), e.g. (N, 1, 1, key_len).
                Positions where ``mask == 0`` get a score of -1e20, so they
                receive (numerically) zero attention weight.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads pieces: (N, len, heads, head_dim).
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # Learned projections; without these the layer has no trainable
        # attention parameters at all.
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Q K^T per head, contracting over head_dim:
        # (N, query_len, heads, head_dim) x (N, key_len, heads, head_dim)
        #   -> energy (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by sqrt(d_k), where d_k is the per-head key size (head_dim),
        # then normalise over the key axis (dim=3).
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # attention (N, heads, query_len, key_len) x values (N, value_len, heads, head_dim)
        #   -> (N, query_len, heads, head_dim), then flatten the last two dimensions.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)

In [9]:
class TransformerBlock(nn.Module):
    """One encoder layer with post-norm residual sub-layers.

    Self-attention and a position-wise feed-forward network are each followed
    by a residual add, a LayerNorm and dropout.

    Args:
        embed_size: model width (last dimension of the inputs and the output).
        heads: number of attention heads handed to ``SelfAttention``.
        dropout: dropout probability applied after each normalised sub-layer.
        forward_expansion: multiplier giving the feed-forward hidden width.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()

        hidden_size = forward_expansion * embed_size

        self.attention = SelfAttention(embed_size, heads=heads)

        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.ff = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run attention + feed-forward; the residual stream is ``query``.

        Returns a tensor with the same shape as ``query``.
        """
        attended = self.attention(value, key, query, mask)
        # Residual connection around the attention sub-layer.
        hidden = self.dropout(self.norm1(attended + query))

        projected = self.ff(hidden)
        # Residual connection around the feed-forward sub-layer.
        return self.dropout(self.norm2(projected + hidden))

In [None]:
class Encoder(nn.Module):
    """Placeholder for the Transformer encoder stack (not implemented yet).

    For now it only forwards any constructor arguments to ``nn.Module``; it
    registers no sub-modules and defines no ``forward``.
    TODO: stack TransformerBlocks with token/position embeddings.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

In [5]:
class Transformer(nn.Module):
    """Sequence classifier built from TransformerBlocks.

    Pipeline: token embeddings plus learned position embeddings, then ``depth``
    self-attention blocks, then mean-pooling over time and a linear head that
    returns log-probabilities.

    Args:
        k: embedding dimension (model width); must be divisible by ``heads``.
        heads: number of attention heads per block.
        depth: number of TransformerBlocks.
        seq_length: number of rows in the position-embedding table (longest
            sequence that can be embedded).
        num_tokens: vocabulary size.
        num_classes: number of output classes.
        dropout: dropout probability inside each block (default 0.0).
        forward_expansion: feed-forward width multiplier in each block
            (default 4).
    """

    def __init__(self, k, heads, depth, seq_length, num_tokens, num_classes,
                 dropout=0.0, forward_expansion=4):
        super().__init__()

        self.num_tokens = num_tokens
        self.token_emb = nn.Embedding(num_tokens, k)
        self.pos_emb = nn.Embedding(seq_length, k)

        # The stack of transformer blocks that does all the heavy lifting.
        # A ModuleList (not nn.Sequential) because each block's forward takes
        # (value, key, query, mask) rather than a single tensor.
        self.tblocks = nn.ModuleList(
            TransformerBlock(k, heads=heads, dropout=dropout,
                             forward_expansion=forward_expansion)
            for _ in range(depth)
        )

        # Maps the pooled output sequence to class logits.
        self.toprobs = nn.Linear(k, num_classes)

    def forward(self, x):
        """
        :param x: A (b, t) tensor of integer values representing
                  words (in some predetermined vocabulary).
        :return: A (b, c) tensor of log-probabilities over the
                 classes (where c is the nr. of classes).
        """
        # Generate token embeddings.
        tokens = self.token_emb(x)
        b, t, k = tokens.size()

        # Generate position embeddings on the same device as the input.
        positions = torch.arange(t, device=x.device)
        positions = self.pos_emb(positions)[None, :, :].expand(b, t, k)

        x = tokens + positions
        for block in self.tblocks:
            # Self-attention: value, key and query are all the current sequence.
            x = block(x, x, x, None)

        # Average-pool over the t dimension and project to class
        # log-probabilities.
        x = self.toprobs(x.mean(dim=1))
        return F.log_softmax(x, dim=1)