In [111]:
import numpy as np
import pandas as pd
import tensorflow as tf
import torch

In [55]:
## self attention
def attention(embedded_matrix):
    d_model=embedded_matrix.shape[1]
    dk=d_model

    wq=np.random.rand(d_model,dk)
    wk=np.random.rand(d_model,dk)
    wv=np.random.rand(d_model,dk)

    Q=embedded_matrix @ wq
    K=embedded_matrix @ wk
    V=embedded_matrix @ wv

    s_scores=tf.nn.softmax(Q @ K.T/np.sqrt(dk),axis=-1)
    attention=s_scores @ V
    return attention

In [56]:
## multihead self attention
def multihead_attention(embedded_matrix,head): ##embedded matrix is the list of encoded words in the sentence 
    em_dimension=embedded_matrix.shape[1]
    dk=em_dimension//head
    heads_output=[]
    for i in range(head):
        wq=np.random.rand(em_dimension,dk)
        wk=np.random.rand(em_dimension,dk)
        wv=np.random.rand(em_dimension,dk)

        # Q=np.dot(embedded_matrix,wq)
        # K=np.dot(embedded_matrix,wk)
        # V=np.dot(embedded_matrix,wv)

        Q,K,V=embedded_matrix @ wq,embedded_matrix @ wk,embedded_matrix @ wv
        
        # S=np.dot(Q,K.T)
        # s_scores=tf.nn.softmax(tf.constant(S/np.sqrt(dk)),axis=-1)
        s_scores=tf.nn.softmax(Q @ K.T/np.sqrt(dk),axis=-1)
        head_output=np.dot(s_scores,V)
        heads_output.append(head_output)

    multihead_output=np.concatenate(heads_output,axis=-1)
    return multihead_output

In [155]:
## optimised multihead self attention using tensorflow

def tf_optimised_multihead_attention(embedded_matrix,head):
    seq_len,d_model=embedded_matrix.shape
    dk=d_model//head

    wq=tf.random.normal((d_model,d_model))
    wk=tf.random.normal((d_model,d_model))
    wv=tf.random.normal((d_model,d_model))

    Q=tf.matmul(embedded_matrix,wq)
    K=tf.matmul(embedded_matrix,wk)
    V=tf.matmul(embedded_matrix,wv)

    Q=tf.reshape(Q,(seq_len,head,dk))
    K=tf.reshape(K,(seq_len,head,dk))
    V=tf.reshape(V,(seq_len,head,dk))

    Q=tf.transpose(Q,(1,0,2))
    K=tf.transpose(K,(1,0,2))
    V=tf.transpose(V,(1,0,2))

    print(Q.shape)

    attention=tf.matmul(tf.nn.softmax(tf.matmul(Q,K,transpose_b=True)/tf.sqrt(float(dk)),axis=-1),V)
    attention=tf.transpose(attention,(1,0,2))
    attention=tf.reshape(attention,(seq_len,d_model))
    return attention


In [176]:
## multihead self attention using pytorch

def torch_optimised_multihead_attention(embedded_matrix,head):
    seq_len,d_model=embedded_matrix.shape
    dk=d_model//head

    torch.manual_seed(0) ## ?
    wq=torch.randn(d_model,d_model)
    wk=torch.randn(d_model,d_model)
    wv=torch.randn(d_model,d_model)

    Q=torch.matmul(embedded_matrix,wq).reshape(seq_len,head,dk).transpose(0,1)
    K=torch.matmul(embedded_matrix,wk).reshape(seq_len,head,dk).transpose(0,1)
    V=torch.matmul(embedded_matrix,wv).reshape(seq_len,head,dk).transpose(0,1)

    # Q=Q.reshape(seq_len,head,dk).transpose(0,1)
    # K=K.reshape(seq_len,head,dk).transpose(0,1)
    # V=V.reshape(seq_len,head,dk).transpose(0,1)

    attention=torch.matmul(torch.nn.functional.softmax(torch.matmul(Q,K.transpose(1,2))/dk**0.5),V)
    attention=attention.transpose(0,1).reshape(seq_len,d_model)
    return attention

In [177]:
embedded_matrix_tf=tf.random.normal((5,512)) # 5 word of 512 dimensions (original transfomers use 512 dimensions embedding)
embedded_matrix_torch=torch.randn((5,512))
# embedded_matrix_torch=torch.tensor(embedded_matrix_tf)
head=8                                    #head cout=8 (original transformers uses 8 headed self attention)
encoded_y_torch=pd.DataFrame(torch_optimised_multihead_attention(embedded_matrix_torch,head))
encoded_y_tf=pd.DataFrame(tf_optimised_multihead_attention(embedded_matrix_tf,head))

encoded_y_torch

(8, 5, 64)


  attention=torch.matmul(torch.nn.functional.softmax(torch.matmul(Q,K.transpose(1,2))/dk**0.5),V)


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,502,503,504,505,506,507,508,509,510,511
0,2.286448,-7.218716,-65.536964,-0.393804,-22.461563,6.10844,26.791262,2.077224,-34.954956,-21.879604,...,-61.047943,-39.060551,25.494331,14.293823,23.726423,27.52273,-12.154991,-17.126074,-13.940536,-27.326046
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-21.212711,46.038944,26.101774,13.921482,-3.878663,-32.376938,26.782457,-0.17693,-13.745987,-29.037703
2,16.871656,-0.509242,-32.699669,23.849894,14.592422,-24.102676,-5.402025,13.767224,-13.874695,-7.470405,...,-21.290718,28.004072,21.878265,30.222874,50.605042,-51.244186,1.379009,15.707146,-43.34565,-20.717903
3,31.411276,-10.957131,-15.481497,-8.431929,15.86073,-1.52431,32.276199,-14.779412,-8.207357,-25.022463,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,-0.009623,-0.018298,-0.017522,0.003621,-0.003066,0.022919,0.008482,-0.022592,0.006021,-0.009234,...,-14.560917,-3.401998,2.62621,5.854098,-0.343407,-0.398744,2.144411,-2.582226,-3.19256,-0.843531


In [178]:
encoded_y_tf

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,502,503,504,505,506,507,508,509,510,511
0,5.864044,37.820801,33.560844,-18.382177,13.682326,14.15955,3.782744,37.861504,28.479715,30.034122,...,-30.948334,1.257328,14.466679,13.083627,-26.345745,39.289833,21.016628,18.336861,45.738056,25.368822
1,-2.741022,-30.424591,28.03273,-14.077796,26.095919,-37.326965,-16.796021,-1.663698,-17.564228,16.308138,...,-12.720152,25.720636,35.871601,20.399868,20.66868,18.890152,-23.575016,10.963066,-4.893365,13.514553
2,35.684231,41.466179,-12.864634,10.458164,23.111021,0.891163,-9.494844,-15.491872,-23.477005,-23.491817,...,-2.792551,-13.303019,6.409123,77.613075,18.483149,-25.753014,24.391832,6.413098,-22.439007,-13.136946
3,5.864044,37.820801,33.560844,-18.382177,13.682326,14.15955,3.782744,37.861504,28.479715,30.034122,...,-30.948334,1.257328,14.466679,13.083627,-26.345745,39.289833,21.016628,18.336861,45.738056,25.368822
4,-29.682817,-9.594486,-23.035667,-6.67629,17.82419,-16.444164,4.94716,-20.572647,27.474737,-13.265723,...,-31.634958,42.192703,-8.291959,-20.961369,-2.794889,16.876553,21.521496,-3.419733,-20.114613,14.617207
