In [2]:
# Standard library
import importlib

# Third party
import torch
import matplotlib.pyplot as plt

# User
import base.encode
import attention

# Reload the local modules so edits made to them on disk are picked up
# without restarting the kernel.
importlib.reload(base.encode)
importlib.reload(attention)

# Import the names only after the reload so they refer to the freshly
# reloaded definitions rather than the stale ones.
from base.encode import get_encoder_decoder
from attention import SelfAttentionHead

# Print tensors in fixed-point notation (no scientific notation) and allow
# up to 200 characters per line before wrapping.
torch.set_printoptions(sci_mode=False, linewidth=200)


In [3]:
# Build the encoder/decoder pair from the Shakespeare text file. type='character' is
# presumably character-level tokenization (confirm in base.encode); the recorded output
# for the sample below has one integer id per input character.
encoder, decoder = get_encoder_decoder(training_data="../data/shakespear.txt", type='character')
# Sanity check: encode a short sample string; the result is displayed as the cell output.
encoder("Would you")

[35, 53, 59, 50, 42, 1, 63, 53, 59]

In [4]:
# Toy input: a short string whose characters act as the sequence positions.
data = "Would you"
input_length = len(data)  # number of sequence positions (one per character)

# Width of the stand-in embeddings and of the attention head's output.
input_embedding_dim = 2
out_dimension = 2

# Dedicated, explicitly seeded generator so the random values drawn below
# are the same on every fresh run.
g = torch.Generator().manual_seed(2147483647)

encoded_input = encoder(data)

# Random stand-in embeddings: sampled from a standard normal, NOT looked up
# from `encoded_input`. One row per character of `data`.
input_embedded = torch.randn(
    size=(input_length, input_embedding_dim), dtype=torch.float64, generator=g
)

# Display the token ids, the embeddings and their shape as the cell output.
encoded_input, input_embedded, input_embedded.shape


([35, 53, 59, 50, 42, 1, 63, 53, 59],
 tensor([[-0.9205, -0.8238],
         [ 0.5364, -1.5131],
         [ 0.1597,  0.6444],
         [-0.6822,  0.4506],
         [ 1.2922, -0.9028],
         [ 0.7594,  1.1730],
         [-0.3377,  1.0273],
         [ 1.6784,  0.9476],
         [ 0.1044, -1.3956]], dtype=torch.float64),
 torch.Size([9, 2]))

In [5]:
# Run a single self-attention head over the 9x2 stand-in embeddings.
# Shape walk-through for this cell's settings (input_embedding_dim=2, out_dimension=2).
# The projection shapes are presumed from the constructor arguments — TODO confirm in attention.py:
#   Q, K, V : 9x2 (input)  @ 2x2 (weights)      -> 9x2 each
#   scores  : 9x2 (Q)      @ 2x9 (K transposed) -> 9x9, one row of weights per position
#   output  : 9x9 (scores) @ 9x2 (V)            -> 9x2
a = SelfAttentionHead(input_embedding_dim, out_dimension, block_type="decoder", generator=g)
embeddings, scores = a(input_embedded) # embeddings: 9x2, scores: 9 columns per row (per the recorded output)

# In the recorded output every visible row of `scores` sums to 1 and is zero to the right
# of the diagonal, i.e. with block_type="decoder" a position puts no weight on later positions.
# NOTE(review): `g` is the same generator that produced `input_embedded`; if the head draws its
# initial weights from it, re-running this cell on its own would give different values — confirm.
embeddings, scores

(tensor([[-0.6987,  0.3197],
         [-1.2031, -0.0546],
         [-1.3092, -0.5607],
         [ 0.7820,  0.4099],
         [-2.4286, -1.0927],
         [-2.2006, -1.3453],
         [ 0.7242, -0.0397],
         [-2.3422, -1.3764],
         [-0.5217,  0.1533]], dtype=torch.float64),
 tensor([[    1.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000],
         [    0.7512,     0.2488,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000],
         [    0.1173,     0.5584,     0.3243,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000],
         [    0.0824,     0.0264,     0.4716,     0.4196,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000],
         [    0.0902,     0.7152,     0.0030,     0.0040,     0.1876,     0.0000,     0.0000,     0.0000,     0.0000],
         [    0.0033,     0.2114,     0.0115,     0.0018,     0.7422,     0.0298,     0.0000,     0.0000,     0.0000],
  