## Realización de un *sanity check* del mecanismo de atención

In [4]:
import torch

from attention.attention_factory import AttentionFactory

In [8]:
# Toy encoder outputs: batch item k (k = 1, 2, 3) is a 2x4 slab filled with
# the constant k, so every dot-product score can be checked by hand.
# NOTE(review): presumably laid out as (batch, seq_len, hidden) — confirm
# against the attention module's expected input layout.
outputs_encoder = torch.stack(
    [
        torch.full((2, 4), float(fill), dtype=torch.float32)
        for fill in (1, 2, 3)
    ]
)

# Decoder hidden state: the vector [1, 2, 3, 4] repeated three times along the
# middle axis, giving a (1, 3, 4) tensor.
# NOTE(review): presumably (num_layers, batch, hidden) — confirm.
hidden_state = torch.arange(1, 5, dtype=torch.float32).repeat(1, 3, 1)

print("outputs_encoder", outputs_encoder.shape)
print("hidden_state", hidden_state.shape)

outputs_encoder torch.Size([3, 2, 4])
hidden_state torch.Size([1, 3, 4])


In [12]:
# Build the attention module under test through the project factory.
# NOTE(review): the two `4` arguments presumably are the decoder and encoder
# hidden sizes (both tensors above have a last dimension of 4) — confirm
# against AttentionFactory.initialize_attention.
attention = AttentionFactory.initialize_attention("Dot-product", 4, 4)

# Score the decoder hidden state against each encoder output. Despite the
# name, the recorded output (rows of 10, 20, 30 — e.g. [1,2,3,4]·[1,1,1,1] = 10)
# shows these are raw, un-normalised dot products; softmax is applied afterwards.
attention_weights = attention(hidden_state, outputs_encoder)
print("attention_weights", attention_weights)

attention_weights tensor([[10., 10.],
        [20., 20.],
        [30., 30.]])


In [13]:
# Normalise the raw scores into a probability distribution along axis 1 (the
# two encoder positions in each row), then append a trailing singleton axis so
# the weights broadcast against the last axis of outputs_encoder.
normalized_vectors = attention_weights.softmax(dim=1)[..., None]
print("normalized_vectors", normalized_vectors)

normalized_vectors tensor([[[0.5000],
         [0.5000]],

        [[0.5000],
         [0.5000]],

        [[0.5000],
         [0.5000]]])


In [16]:
# Scale every encoder output vector by its attention weight: the (3, 2, 1)
# weights broadcast across the last axis of the (3, 2, 4) encoder outputs.
attention_output = torch.mul(outputs_encoder, normalized_vectors)
print("Shape", attention_output.shape)
print("attention_output:", attention_output)

Shape torch.Size([3, 2, 4])
attention_output: tensor([[[0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000]],

        [[1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000]],

        [[1.5000, 1.5000, 1.5000, 1.5000],
         [1.5000, 1.5000, 1.5000, 1.5000]]])


In [17]:
# Collapse the weighted encoder outputs along axis 1 (the encoder positions).
# keepdim=True leaves that axis as size 1, so the result stays 3-D: one
# weighted-sum vector per row, shape (3, 1, 4).
summed_vectors = attention_output.sum(dim=1, keepdim=True)

print("summed_vectors", summed_vectors)

summed_vectors tensor([[[1., 1., 1., 1.]],

        [[2., 2., 2., 2.]],

        [[3., 3., 3., 3.]]])


In [21]:
# Swap the first two axes of the hidden state, (1, 3, 4) -> (3, 1, 4), so it
# lines up row-for-row with summed_vectors for the concatenation that follows.
hidden_attention = torch.transpose(hidden_state, 0, 1)

print("Shape hidden", hidden_state.shape)
print("hidden_state", hidden_state)
print("---------------------------")
print("Shape hidden_transpose", hidden_attention.shape)
print("hidden_transpose", hidden_attention)

Shape hidden torch.Size([1, 3, 4])
hidden_state tensor([[[1., 2., 3., 4.],
         [1., 2., 3., 4.],
         [1., 2., 3., 4.]]])
---------------------------
Shape hidden_transpose torch.Size([3, 1, 4])
hidden_transpose tensor([[[1., 2., 3., 4.]],

        [[1., 2., 3., 4.]],

        [[1., 2., 3., 4.]]])


In [22]:
# Join each weighted-sum vector with the matching hidden-state vector along
# axis 2: (3, 1, 4) + (3, 1, 4) -> (3, 1, 8).
pieces_to_join = [summed_vectors, hidden_attention]
output_attention = torch.cat(pieces_to_join, dim=2)
print("output_attention", output_attention)

output_attention tensor([[[1., 1., 1., 1., 1., 2., 3., 4.]],

        [[2., 2., 2., 2., 1., 2., 3., 4.]],

        [[3., 3., 3., 3., 1., 2., 3., 4.]]])
