In [None]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Toy self-attention layer with fixed, hand-picked weights.

    The projection matrices are hard-coded so that every intermediate tensor
    can be verified by hand.  This is deliberately *unnormalised* attention:
    the raw ``queries @ keys^T`` scores weight the values directly, with no
    scaling and no softmax.

    Every call to ``forward`` prints the keys, queries, attention scores and
    values to stdout before returning the context vectors.
    """

    def __init__(self):
        super().__init__()
        # ``torch.tensor`` with float literals replaces the legacy
        # ``torch.Tensor(...)`` constructor; values and dtype are the same.
        self.W_query = nn.Parameter(torch.tensor([[0.0, 1.0], [0.0, 1.0]]))
        self.W_key = nn.Parameter(torch.tensor([[1.0, 1.0], [1.0, 1.0]]))
        self.W_value = nn.Parameter(
            torch.tensor([[0.2, 0.5, 0.4], [0.3, 0.9, 0.2]])
        )

    def forward(self, x):
        """Compute unnormalised self-attention over ``x``.

        Args:
            x: Tensor of shape ``(..., seq_len, 2)``.  The last dimension must
                be 2 to match the rows of the weight matrices; any leading
                dimensions are treated as batch dimensions.

        Returns:
            Tensor of shape ``(..., seq_len, 3)``: for each position, the sum
            of all value vectors weighted by the raw attention scores.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query

        # Swap only the last two dimensions.  ``keys.T`` would reverse *every*
        # dimension, which is wrong (and deprecated) for batched input; for a
        # plain 2-D input both forms give the same matrix.
        attn_scores = queries @ keys.transpose(-2, -1)

        print("Keys")
        print(keys,"\n\n")

        print("Querys")
        print(queries,"\n\n")

        print("Attn")
        print(attn_scores,"\n\n")

        values = x @ self.W_value

        print("Values")
        print(values,"\n\n")

        context_vec = attn_scores @ values
        return context_vec

In [None]:
# Build the self-attention demo module; its weights are the fixed values
# hard-coded in SelfAttention.__init__.
s = SelfAttention()

In [None]:
# Inspect the query projection matrix (echoed as the cell's output).
s.W_query

Parameter containing:
tensor([[0., 1.],
        [0., 1.]], requires_grad=True)

In [None]:
# Example 1: only the first token is non-zero, so the second row of every
# intermediate tensor comes out as zeros.
with torch.no_grad():  # forward pass only; no gradients are needed here
    result = s(torch.tensor([[0.0, 1.0], [0.0, 0.0]]))

print("Result", result, sep="\n")

Keys
tensor([[1., 1.],
        [0., 0.]]) 


Querys
tensor([[0., 1.],
        [0., 0.]]) 


Attn
tensor([[1., 0.],
        [0., 0.]]) 


Values
tensor([[0.3000, 0.9000, 0.2000],
        [0.0000, 0.0000, 0.0000]]) 


Result
tensor([[0.3000, 0.9000, 0.2000],
        [0.0000, 0.0000, 0.0000]])


In [None]:
# Example 2: both tokens are non-zero, so every attention score is 1 and each
# output row is the plain sum of the two value vectors.
with torch.no_grad():  # forward pass only; no gradients are needed here
    result = s(torch.tensor([[0.0, 1.0], [1.0, 0.0]]))

print("Result", result, sep="\n")

Keys
tensor([[1., 1.],
        [1., 1.]]) 


Querys
tensor([[0., 1.],
        [0., 1.]]) 


Attn
tensor([[1., 1.],
        [1., 1.]]) 


Values
tensor([[0.3000, 0.9000, 0.2000],
        [0.2000, 0.5000, 0.4000]]) 


Result
tensor([[0.5000, 1.4000, 0.6000],
        [0.5000, 1.4000, 0.6000]])


In [None]:
class CrossAttention(nn.Module):
    """Toy cross-attention layer with fixed, hand-picked weights.

    Queries are projected from one sequence (``spanish``) while keys and
    values are projected from another (``english``).  As a by-hand teaching
    example the scores are left *unnormalised*: no scaling and no softmax is
    applied before they weight the values.

    Every call to ``forward`` prints the keys, queries, attention scores and
    values to stdout before returning the result.
    """

    def __init__(self):
        super().__init__()
        # ``torch.tensor`` with float literals replaces the legacy
        # ``torch.Tensor(...)`` constructor; values and dtype are the same.
        self.W_query = nn.Parameter(torch.tensor([[0.0, 1.0], [0.0, 1.0]]))
        self.W_key = nn.Parameter(torch.tensor([[1.0, 1.0], [1.0, 1.0]]))
        self.W_value = nn.Parameter(torch.tensor([[0.2, 0.5], [0.4, 0.3]]))

    def forward(self, spanish, english):
        """Attend from ``spanish`` positions over ``english`` positions.

        Args:
            spanish: Tensor of shape ``(..., n_queries, 2)``; source of the
                queries.
            english: Tensor of shape ``(..., n_keys, 2)``; source of the keys
                and values.  Leading (batch) dimensions must be
                broadcast-compatible with those of ``spanish``.

        Returns:
            Tensor of shape ``(..., n_queries, 2)``: for each query position,
            the sum of the value vectors weighted by the raw attention scores.
        """
        queries_1 = spanish @ self.W_query
        keys_2 = english @ self.W_key

        print("Keys")
        print(keys_2,"\n\n")

        print("Querys")
        print(queries_1,"\n\n")

        # Swap only the last two dimensions.  ``keys_2.T`` would reverse
        # *every* dimension, which breaks batched input; for a plain 2-D
        # input both forms give the same matrix.
        attn_scores = queries_1 @ keys_2.transpose(-2, -1)

        print("Attn")
        print(attn_scores,"\n\n")

        values_2 = english @ self.W_value

        print("Values")
        print(values_2,"\n\n")

        return attn_scores @ values_2

In [None]:
# Build the cross-attention demo module; its weights are the fixed values
# hard-coded in CrossAttention.__init__.
c = CrossAttention()

In [None]:
# Cross-attention example: two Spanish tokens (queries) attend over three
# English tokens (keys and values).  Each row is one token's 2-d embedding.
with torch.no_grad():  # forward pass only; no gradients are needed here
    result = c(
        torch.tensor([[0.0, 1.0],      # MI
                      [1.0, 0.5]]),    # AUTO
        torch.tensor([[1.0, 1.0],      # MY
                      [0.2, 0.5],      # LITTLE
                      [0.5, 0.0]]),    # PONY
    )

print("Result", result, sep="\n")

Keys
tensor([[2.0000, 2.0000],
        [0.7000, 0.7000],
        [0.5000, 0.5000]]) 


Querys
tensor([[0.0000, 1.0000],
        [0.0000, 1.5000]]) 


Attn
tensor([[2.0000, 0.7000, 0.5000],
        [3.0000, 1.0500, 0.7500]]) 


Values
tensor([[0.6000, 0.8000],
        [0.2400, 0.2500],
        [0.1000, 0.2500]]) 


Result
tensor([[1.4180, 1.9000],
        [2.1270, 2.8500]])
