In [4]:
import torch
import torch.nn.functional as F

class Attention:
    def __init__(self,input,wq,wk,wv):
        self.input = input
        self.wq = wq
        self.wk = wk
        self.wv = wv

    def qkv(self):
        self.q = self.input@self.wq
        self.k = self.input@self.wk
        self.v = self.input@self.wv
        return self.q,self.k,self.v
    
    def get_k_transpose(self):
        self.k_transpose =  torch.transpose(self.k,dim0=0,dim1 =1)
        return self.k_transpose

    def get_attention_scores(self):
        self.attention_scores = (self.q@self.k_transpose)/torch.sqrt(torch.tensor(2))
        return self.attention_scores
    
    def get_attention_outputs(self):
        self.attention_outputs = F.softmax(self.attention_scores)@self.v
        return self.attention_outputs
    


In [7]:
# Five rows of 6-dimensional features shared by every attention head.
input = torch.tensor(
    [
        [1, 1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [0, 0, 1, 1, 0, 0],
        [0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1, 1],
    ],
    dtype=torch.float64,
)

# Head 1 projections: only input columns 0 and 1 have non-zero weights.
wq1 = torch.tensor(
    [[10, 0], [0, 10], [0, 0], [0, 0], [0, 0], [0, 0]],
    dtype=torch.float64,
)
wk1 = torch.tensor(
    [[2, 0], [0, 2], [0, 0], [0, 0], [0, 0], [0, 0]],
    dtype=torch.float64,
)
wv1 = torch.tensor(
    [[1, 0], [0, 1], [0, 0], [0, 0], [0, 0], [0, 0]],
    dtype=torch.float64,
)

# Shape check: the input is (5, 6) and each projection is (6, 2).
input.shape, wq1.shape, wk1.shape, wv1.shape

(torch.Size([5, 6]),
 torch.Size([6, 2]),
 torch.Size([6, 2]),
 torch.Size([6, 2]))

In [8]:
# Head 1: run the stages in pipeline order (projections, key transpose,
# scores, outputs), keeping every intermediate tensor for inspection.
head1 = Attention(input, wq1, wk1, wv1)
q1, k1, v1 = head1.qkv()
k1_transpose = head1.get_k_transpose()
attention_scores_1 = head1.get_attention_scores()
attention_output_1 = head1.get_attention_outputs()

# Report each intermediate result.
print("Head 1")
print("Q", q1)
print("K", k1)
print("V", v1)
print("K Transpose", k1_transpose)
print("Attention Scores", attention_scores_1)
print("Attention Output", attention_output_1)

Head 1
Q tensor([[10., 10.],
        [10.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.]], dtype=torch.float64)
K tensor([[2., 2.],
        [2., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], dtype=torch.float64)
V tensor([[1., 1.],
        [1., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], dtype=torch.float64)
K Transpose tensor([[2., 2., 0., 0., 0.],
        [2., 0., 0., 0., 0.]], dtype=torch.float64)
Attention Scores tensor([[28.2843, 14.1421,  0.0000,  0.0000,  0.0000],
        [14.1421, 14.1421,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]], dtype=torch.float64)
Attention Output tensor([[1.0000, 1.0000],
        [1.0000, 0.5000],
        [0.4000, 0.2000],
        [0.4000, 0.2000],
        [0.4000, 0.2000]], dtype=torch.float64)


  self.attention_outputs = F.softmax(self.attention_scores)@self.v


In [9]:
# Head 2 projections: only input columns 2 and 3 have non-zero weights.
wq2 = torch.tensor(
    [[0, 0], [0, 0], [10, 0], [0, 10], [0, 0], [0, 0]],
    dtype=torch.float64,
)
wk2 = torch.tensor(
    [[0, 0], [0, 0], [2, 0], [0, 2], [0, 0], [0, 0]],
    dtype=torch.float64,
)
wv2 = torch.tensor(
    [[0, 0], [0, 0], [1, 0], [0, 1], [0, 0], [0, 0]],
    dtype=torch.float64,
)

# Shape check against the shared input tensor.
input.shape, wq2.shape, wk2.shape, wv2.shape

(torch.Size([5, 6]),
 torch.Size([6, 2]),
 torch.Size([6, 2]),
 torch.Size([6, 2]))

In [10]:
# Head 2: run the stages in pipeline order (projections, key transpose,
# scores, outputs), keeping every intermediate tensor for inspection.
head2 = Attention(input, wq2, wk2, wv2)
q2, k2, v2 = head2.qkv()
k2_transpose = head2.get_k_transpose()
attention_scores_2 = head2.get_attention_scores()
attention_output_2 = head2.get_attention_outputs()

# Report each intermediate result.
print("Head 2")
print("Q", q2)
print("K", k2)
print("V", v2)
print("K Transpose", k2_transpose)
print("Attention Scores", attention_scores_2)
print("Attention Output", attention_output_2)

Head 2
Q tensor([[ 0.,  0.],
        [ 0.,  0.],
        [10., 10.],
        [ 0., 10.],
        [ 0.,  0.]], dtype=torch.float64)
K tensor([[0., 0.],
        [0., 0.],
        [2., 2.],
        [0., 2.],
        [0., 0.]], dtype=torch.float64)
V tensor([[0., 0.],
        [0., 0.],
        [1., 1.],
        [0., 1.],
        [0., 0.]], dtype=torch.float64)
K Transpose tensor([[0., 0., 2., 0., 0.],
        [0., 0., 2., 2., 0.]], dtype=torch.float64)
Attention Scores tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, 28.2843, 14.1421,  0.0000],
        [ 0.0000,  0.0000, 14.1421, 14.1421,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]], dtype=torch.float64)
Attention Output tensor([[0.2000, 0.4000],
        [0.2000, 0.4000],
        [1.0000, 1.0000],
        [0.5000, 1.0000],
        [0.2000, 0.4000]], dtype=torch.float64)


  self.attention_outputs = F.softmax(self.attention_scores)@self.v


In [13]:
# Head 3 projections: only input columns 4 and 5 have non-zero weights.
wq3 = torch.tensor(
    [[0, 0], [0, 0], [0, 0], [0, 0], [10, 0], [0, 10]],
    dtype=torch.float64,
)
wk3 = torch.tensor(
    [[0, 0], [0, 0], [0, 0], [0, 0], [2, 0], [0, 2]],
    dtype=torch.float64,
)
wv3 = torch.tensor(
    [[0, 0], [0, 0], [0, 0], [0, 0], [1, 0], [0, 1]],
    dtype=torch.float64,
)

# Shape check against the shared input tensor.
input.shape, wq3.shape, wk3.shape, wv3.shape

(torch.Size([5, 6]),
 torch.Size([6, 2]),
 torch.Size([6, 2]),
 torch.Size([6, 2]))

In [14]:
# Head 3: run the stages in pipeline order (projections, key transpose,
# scores, outputs), keeping every intermediate tensor for inspection.
head3 = Attention(input, wq3, wk3, wv3)
q3, k3, v3 = head3.qkv()
k3_transpose = head3.get_k_transpose()
attention_scores_3 = head3.get_attention_scores()
attention_output_3 = head3.get_attention_outputs()

# Report each intermediate result.
print("Head 3")
print("Q", q3)
print("K", k3)
print("V", v3)
print("K Transpose", k3_transpose)
print("Attention Scores", attention_scores_3)
print("Attention Output", attention_output_3)

Head 3
Q tensor([[ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [ 0.,  0.],
        [10., 10.]], dtype=torch.float64)
K tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [2., 2.]], dtype=torch.float64)
V tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [1., 1.]], dtype=torch.float64)
K Transpose tensor([[0., 0., 0., 0., 2.],
        [0., 0., 0., 0., 2.]], dtype=torch.float64)
Attention Scores tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, 28.2843]], dtype=torch.float64)
Attention Output tensor([[0.2000, 0.2000],
        [0.2000, 0.2000],
        [0.2000, 0.2000],
        [0.2000, 0.2000],
        [1.0000, 1.0000]], dtype=torch.float64)


  self.attention_outputs = F.softmax(self.attention_scores)@self.v


In [18]:
# Join the three head outputs side by side along the feature axis (dim=1).
head_outputs = [attention_output_1, attention_output_2, attention_output_3]
concatenated_outputs = torch.cat(head_outputs, dim=1)
concatenated_outputs

tensor([[1.0000, 1.0000, 0.2000, 0.4000, 0.2000, 0.2000],
        [1.0000, 0.5000, 0.2000, 0.4000, 0.2000, 0.2000],
        [0.4000, 0.2000, 1.0000, 1.0000, 0.2000, 0.2000],
        [0.4000, 0.2000, 0.5000, 1.0000, 0.2000, 0.2000],
        [0.4000, 0.2000, 0.2000, 0.4000, 1.0000, 1.0000]], dtype=torch.float64)

In [19]:
# Output projection: the 6x6 identity matrix in float64.
w0 = torch.eye(6, dtype=torch.float64)
w0.shape

torch.Size([6, 6])

In [21]:
# Apply the output projection to the concatenated head outputs.
output = torch.matmul(concatenated_outputs, w0)
output

tensor([[1.0000, 1.0000, 0.2000, 0.4000, 0.2000, 0.2000],
        [1.0000, 0.5000, 0.2000, 0.4000, 0.2000, 0.2000],
        [0.4000, 0.2000, 1.0000, 1.0000, 0.2000, 0.2000],
        [0.4000, 0.2000, 0.5000, 1.0000, 0.2000, 0.2000],
        [0.4000, 0.2000, 0.2000, 0.4000, 1.0000, 1.0000]], dtype=torch.float64)