In [32]:
import torch
import math
import torch.nn as nn
import numpy as np

In [3]:
sentence="This is the text which I will be using to learn about Embeddings, Transformer"

In [4]:
dc={v:k for k,v in enumerate(sorted(sentence.replace(",","").lower().split()))}
dc

{'about': 0,
 'be': 1,
 'embeddings': 2,
 'i': 3,
 'is': 4,
 'learn': 5,
 'text': 6,
 'the': 7,
 'this': 8,
 'to': 9,
 'transformer': 10,
 'using': 11,
 'which': 12,
 'will': 13}

In [5]:
sentence_emb=[]
for i in sentence.replace(",","").lower().split():
    sentence_emb.append(dc[i])
sentence_input=torch.tensor(sentence_emb)
sentence_input

tensor([ 8,  4,  7,  6, 12,  3, 13,  1, 11,  9,  5,  0,  2, 10])

In [6]:
vocab_size=50000
torch.manual_seed(42)
embed=torch.nn.Embedding(vocab_size,3)
embedded_sentence=embed(sentence_input).detach()
embedded_sentence

tensor([[ 1.2791,  1.2964,  0.6105],
        [-0.7279, -0.5594, -0.7688],
        [ 1.0783,  0.8008,  1.6806],
        [-0.4974,  0.4396, -0.7581],
        [ 0.3189, -0.4245,  0.3057],
        [ 1.6487, -0.3925, -1.4036],
        [-0.7746, -1.5576,  0.9956],
        [-2.1055,  0.6784, -1.2345],
        [-0.8712, -0.2234,  1.7174],
        [ 1.3347, -0.2316,  0.0418],
        [ 0.7624,  1.6423, -0.1596],
        [ 1.9269,  1.4873,  0.9007],
        [-0.0431, -1.6047, -0.7521],
        [-0.2516,  0.8599, -1.3847]])

In [7]:
torch.manual_seed(42)
d=embedded_sentence.size()[-1]
d_q,d_k,d_v=2,2,4
W_q=torch.nn.Parameter(torch.rand(d_q,d))
w_k=torch.nn.Parameter(torch.rand(d_k,d))
w_v=torch.nn.Parameter(torch.rand(d_v,d))
q=embedded_sentence@W_q.T
k=embedded_sentence@w_k.T
v=embedded_sentence@w_v.T
q

tensor([[ 2.5485,  2.1001],
        [-1.4484, -1.3787],
        [ 2.3276,  2.3570],
        [-0.3269, -0.7611],
        [ 0.0100,  0.3239],
        [ 0.5581,  0.5850],
        [-1.7274, -0.7529],
        [-1.7095, -2.4968],
        [-0.3155,  0.1090],
        [ 0.9816,  1.2151],
        [ 2.1143,  1.2768],
        [ 3.4058,  2.9704],
        [-1.7942, -1.1198],
        [ 0.0347, -0.7377]], grad_fn=<MmBackward0>)

In [8]:
x_2=embedded_sentence[2]
query_2=W_q@x_2
torch.softmax(query_2@k.T/math.sqrt(d_k),dim=0)@v

tensor([2.8590, 2.3916, 1.4863, 1.9616], grad_fn=<SqueezeBackward3>)

In [9]:
class SelfAttention(nn.Module):
    def __init__(self,d_in,d_qk,d_v):
        super(SelfAttention,self).__init__()
        self.d_in=d_in
        self.d_qk=d_qk
        torch.manual_seed(42)
        self.W_query=torch.nn.Parameter(torch.randn(d_in,d_qk))
        self.W_key=torch.nn.Parameter(torch.randn(d_in,d_qk))
        self.W_value=torch.nn.Parameter(torch.randn(d_in,d_v))
    
    def forward(self,input):
        embed_in=input
        q=embed_in@self.W_query
        k=embed_in@self.W_key
        v=embed_in@self.W_value
        out=torch.softmax(q@k.T/math.sqrt(self.d_qk),dim=0)@v
        return out
        

In [10]:
self=SelfAttention(3,2,4)
self(embedded_sentence)

tensor([[-0.0585,  0.1223,  0.0948,  0.0433],
        [ 0.0680, -0.2547, -0.1070,  0.0849],
        [-1.1127,  0.5667,  0.8144, -0.2457],
        [ 0.3429, -0.3360, -0.2759,  0.1794],
        [-0.3058,  0.0626,  0.1939, -0.0554],
        [11.2437, -4.7408, -7.4778,  3.6876],
        [-2.5044,  0.5327,  1.4455, -0.9125],
        [ 0.3871, -0.4229, -0.3273,  0.2025],
        [-6.3978,  3.2540,  3.9641, -3.6095],
        [ 0.0235, -0.0909, -0.0281,  0.0596],
        [ 0.4361, -0.1815, -0.2718,  0.1964],
        [-0.0658,  0.1768,  0.1211,  0.0490],
        [ 0.0238, -0.3170, -0.1049,  0.0770],
        [ 3.9775, -1.9090, -2.7097,  1.3658]], grad_fn=<MmBackward0>)

In [11]:
class MultiHeadAttention(nn.Module):
    def __init__(self,d_in,d_kq,d_v,num_heads):
        super(MultiHeadAttention,self).__init__()
        self.d_in=d_in
        self.heads=nn.ModuleList(SelfAttention(d_in,d_kq,d_v) for _ in range(num_heads))

    def forward(self,input):
        
        return torch.cat([head(input) for head in self.heads],dim=-1)

In [12]:
attn=MultiHeadAttention(3,2,4,6)
attn(embedded_sentence)

tensor([[-0.0585,  0.1223,  0.0948,  0.0433, -0.0585,  0.1223,  0.0948,  0.0433,
         -0.0585,  0.1223,  0.0948,  0.0433, -0.0585,  0.1223,  0.0948,  0.0433,
         -0.0585,  0.1223,  0.0948,  0.0433, -0.0585,  0.1223,  0.0948,  0.0433],
        [ 0.0680, -0.2547, -0.1070,  0.0849,  0.0680, -0.2547, -0.1070,  0.0849,
          0.0680, -0.2547, -0.1070,  0.0849,  0.0680, -0.2547, -0.1070,  0.0849,
          0.0680, -0.2547, -0.1070,  0.0849,  0.0680, -0.2547, -0.1070,  0.0849],
        [-1.1127,  0.5667,  0.8144, -0.2457, -1.1127,  0.5667,  0.8144, -0.2457,
         -1.1127,  0.5667,  0.8144, -0.2457, -1.1127,  0.5667,  0.8144, -0.2457,
         -1.1127,  0.5667,  0.8144, -0.2457, -1.1127,  0.5667,  0.8144, -0.2457],
        [ 0.3429, -0.3360, -0.2759,  0.1794,  0.3429, -0.3360, -0.2759,  0.1794,
          0.3429, -0.3360, -0.2759,  0.1794,  0.3429, -0.3360, -0.2759,  0.1794,
          0.3429, -0.3360, -0.2759,  0.1794,  0.3429, -0.3360, -0.2759,  0.1794],
        [-0.3058,  0.062

In [13]:
class Cross_Attention(nn.Module):
    def __init__(self,d_in,d_kq,d_v):
        super(Cross_Attention,self).__init__()
        self.d_in=d_in
        self.d_kq=d_kq
        torch.manual_seed(42)
        self.w_q=torch.nn.Parameter(torch.randn(d_in,d_kq))
        self.w_k=torch.nn.Parameter(torch.randn(d_in,d_kq))
        self.w_v=torch.nn.Parameter(torch.randn(d_in,d_v))

    def forward(self,x1,x2):
        q=x1@self.w_q
        k=x2@self.w_k
        attn_scores=(q@k.T)
        attn_weights=torch.softmax(attn_scores/math.sqrt(self.d_kq),dim=-1)
        v=x2@self.w_v
        out=attn_weights@v
        return out

In [14]:
cross_attn=Cross_Attention(3,2,4)
cross_attn(embedded_sentence,torch.rand(8,3))

tensor([[ 1.8606, -0.1237, -1.0357,  0.4919],
        [ 1.8258, -0.2491, -1.0506,  0.5200],
        [ 1.8476, -0.0117, -1.0007,  0.4498],
        [ 1.8465, -0.2448, -1.0596,  0.5269],
        [ 1.8279, -0.1494, -1.0266,  0.4863],
        [ 1.8699, -0.3668, -1.1015,  0.5790],
        [ 1.8078, -0.0634, -0.9958,  0.4469],
        [ 1.8413, -0.2617, -1.0612,  0.5308],
        [ 1.8290,  0.0176, -0.9848,  0.4304],
        [ 1.8415, -0.1997, -1.0458,  0.5094],
        [ 1.8727, -0.1958, -1.0601,  0.5204],
        [ 1.8661, -0.1045, -1.0334,  0.4879],
        [ 1.8130, -0.2709, -1.0498,  0.5224],
        [ 1.8705, -0.3159, -1.0892,  0.5609]], grad_fn=<MmBackward0>)

In [15]:
class CausalAttention(nn.Module):
    def __init__(self,d_in,d_kq,d_v):
        super(CausalAttention,self).__init__()

        self.d_kq=d_kq
        self.w_q=torch.nn.Parameter(torch.randn(d_in,d_kq))
        self.w_k=torch.nn.Parameter(torch.randn(d_in,d_kq))
        self.w_v=torch.nn.Parameter(torch.randn(d_in,d_v))

    def forward(self,x):
        q=x@self.w_q
        k=x@self.w_k
        v=x@self.w_v
        attn_score=q@k.T
        block_size=attn_score.size()[0]
        # mask=torch.tril(torch.ones(block_size,block_size))
        # row_sums=mask.sum(dim=1,keepdim=True)
        # masked_norm=mask/row_sums

        # print(masked_norm)
        # attn_weights=torch.softmax(attn_score/math.sqrt(self.d_kq),dim=-1)*masked_norm
        mask=torch.triu(torch.ones(block_size,block_size),diagonal=1)
        masked=attn_score.masked_fill(mask.bool(),-torch.inf)
        print(masked)
        attn_weights=torch.softmax(masked/math.sqrt(self.d_kq),dim=-1)
        print(attn_weights)
        out=attn_weights@v
        return out
                

In [16]:
causal_attn=CausalAttention(3,2,4)
causal_attn(torch.rand(4,3))

tensor([[ 0.0404,    -inf,    -inf,    -inf],
        [ 0.0243,  0.0136,    -inf,    -inf],
        [-0.0068,  0.0353, -0.2131,    -inf],
        [ 0.1780,  0.0615,  0.2613,  0.1035]], grad_fn=<MaskedFillBackward0>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5019, 0.4981, 0.0000, 0.0000],
        [0.3455, 0.3559, 0.2986, 0.0000],
        [0.2544, 0.2343, 0.2699, 0.2414]], grad_fn=<SoftmaxBackward0>)


tensor([[ 0.0598,  0.1203, -1.5357, -0.8078],
        [ 0.0743,  0.1289, -1.1333, -0.5509],
        [ 0.0492,  0.2673, -1.0864, -0.6465],
        [ 0.2381,  0.4362, -1.7540, -0.7039]], grad_fn=<MmBackward0>)

## Positional Encoding

d_model=3

PE(pos,2i) = sin(pos/10000**2i/dmodel )
PE(pos,2i+1) = cos(pos/10000**2i/dmodel )

In [131]:
def positional_encoding(tokens,d_model):
    position = torch.arange(len(tokens)).unsqueeze(0).T
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))

    PE = torch.empty(len(tokens), d_model)
    PE[:, 0::2] = torch.sin(position / div_term) 
    PE[:,1::2]=torch.cos(position/div_term)[:,:int(d_model/2)]
    return PE
    

In [132]:
input=positional_encoding(sentence_input,3)+embedded_sentence
input.size()

torch.Size([14, 3])

In [130]:
attn(input).size()

torch.Size([14, 24])