In [1]:
import torch

from lm.attention import CausalAttention, MultiHeadAttention, SelfAttention

In [2]:
# Seed the global RNG once so the random tensors drawn below (and anything
# else that pulls from torch's default generator) are repeatable on
# Restart & Run All.
torch.manual_seed(789)

batch_size = 2
seq_len = 4
d_in = 6  # input dimension
inputs = torch.randn(batch_size, seq_len, d_in)  # (batch_size, seq_len, d_in)

sa = SelfAttention(d_in=d_in, d_out=2)

# Run the forward pass once and reuse the result for both prints; the module
# was previously called a second time only to read `.shape`.
sa_out = sa(inputs)
print(sa_out)
print('SelfAttention output shape:', sa_out.shape)

# Sanity check: one d_out-sized vector per input position.
assert sa_out.shape == (batch_size, seq_len, 2)

tensor([[[ 0.0605, -0.0033],
         [ 0.0730, -0.1587],
         [ 0.0059, -0.0838],
         [ 0.0672, -0.0157]],

        [[ 0.1103,  0.1985],
         [ 0.0171,  0.1990],
         [-0.0508,  0.1564],
         [-0.0349,  0.2016]]], grad_fn=<UnsafeViewBackward0>)
SelfAttention output shape: torch.Size([2, 4, 2])


In [3]:
# Fresh random batch with shape (batch_size, seq_len, d_in) = (2, 4, 6).
batch_size, seq_len, d_in = 2, 4, 6  # d_in is the input dimension
inputs = torch.randn(batch_size, seq_len, d_in)  # (batch_size, seq_len, d_in)

ca = CausalAttention(
    d_in=d_in,
    d_out=2,
    context_length=4,
    dropout=0.1,
)

# NOTE(review): the module is invoked twice on purpose. With dropout=0.1 each
# forward pass presumably draws from the global RNG (assuming `ca` is in
# training mode -- confirm against lm.attention), so removing one call would
# shift the random stream seen by any later cell. The tensor printed first
# and the tensor whose shape is printed second are therefore separate passes.
print(ca(inputs))
print('CausalAttention output shape:', ca(inputs).shape)

tensor([[[-0.3349,  0.2773],
         [-0.0405,  0.3809],
         [-1.1918, -0.0232],
         [ 0.1500, -0.0470]],

        [[-0.0974,  1.1801],
         [-0.5743,  1.8542],
         [-0.3454,  1.1430],
         [-0.0260,  0.2978]]], grad_fn=<UnsafeViewBackward0>)
CausalAttention output shape: torch.Size([2, 4, 2])


In [4]:
# Fresh random batch with shape (batch_size, seq_len, d_in); the values depend
# on the RNG state left behind by the earlier cells.
batch_size = 2
seq_len = 4
d_in = 6  # input dimension
inputs = torch.randn(batch_size, seq_len, d_in)  # (batch_size, seq_len, d_in)

mha = MultiHeadAttention(
    d_in=d_in, d_out=2, context_length=4, dropout=0.0, num_heads=2
)

# Run the forward pass once and reuse the result for both prints; the module
# was previously called a second time only to read `.shape`.
mha_out = mha(inputs)
print(mha_out)
print('MultiHeadAttention output shape:', mha_out.shape)

# Sanity check: the output keeps (batch_size, seq_len) and has d_out=2 features.
assert mha_out.shape == (batch_size, seq_len, 2)

tensor([[[-0.2387, -0.4613],
         [-0.0111, -0.2861],
         [ 0.0988, -0.2150],
         [ 0.2050, -0.2337]],

        [[ 0.1531, -0.4876],
         [ 0.0352, -0.5687],
         [-0.0486, -0.5851],
         [-0.0200, -0.5280]]], grad_fn=<ViewBackward0>)
MultiHeadAttention output shape: torch.Size([2, 4, 2])
