In [1]:
import torch
from tqdm import tqdm

from neural_stack.attention import MultiHeadAttention

In [15]:
# Problem size: the dataset below yields samples of shape (SEQ_LENGTH, DIM_MODEL).
SEQ_LENGTH = 10
DIM_MODEL = 64
NUM_HEADS = 4

# Optimisation budget.
BATCH_SIZE = 32
NUM_ITERS = 5000

LEARNING_RATE = 1e-2

# Seed torch's global RNG so the randomly generated training data (torch.randn)
# is the same after Restart -> Run All. Anything else that draws from the global
# RNG (presumably the model's weight initialisation) is seeded too.
# NOTE(review): bit-exact repeatability on CUDA is not guaranteed by seeding alone.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
class RandomSequenceGenerator(torch.utils.data.IterableDataset):
    """Endless stream of (sequence, reversed sequence) training pairs.

    Each sequence is a fresh ``(seq_length, dim_model)`` tensor of standard
    normal samples; its target is the same tensor with the order of the
    first dimension reversed.
    """

    def __init__(self, seq_length, dim_model):
        """Store the shape of the sequences to generate.

        Args:
            seq_length: size of the first dimension of each sample.
            dim_model: size of the second dimension of each sample.
        """
        super().__init__()
        self.seq_length = seq_length
        self.dim_model = dim_model

    def __iter__(self):
        """Yield ``(x, y)`` pairs forever, where ``y`` is ``x`` flipped along dim 0."""
        while True:
            sequence = torch.randn(self.seq_length, self.dim_model)
            # flip() returns a new tensor, so the input is left untouched.
            yield sequence, sequence.flip(0)

In [11]:
# Build the attention module and move it onto the selected device in one step.
multi_head_attn = MultiHeadAttention(
    dim_model=DIM_MODEL,
    num_heads=NUM_HEADS,
).to(device)

# Batches are drawn from the endless random-sequence stream.
flip_dataset = RandomSequenceGenerator(dim_model=DIM_MODEL, seq_length=SEQ_LENGTH)
dataloader = torch.utils.data.DataLoader(flip_dataset, batch_size=BATCH_SIZE)

# Adam over every parameter exposed by the attention module; mean-squared-error loss.
optimizer = torch.optim.Adam(params=multi_head_attn.parameters(), lr=LEARNING_RATE)
criterion = torch.nn.MSELoss().to(device)

In [16]:
# Train the attention module to output the reversed input sequence.
#
# The progress bar wraps range(NUM_ITERS) instead of the endless dataloader.
# Breaking out of a tqdm-wrapped iterable skips the bar's final increment,
# which left the previous run displayed at 4999/5000. Here the bar's own
# iterable runs out naturally, so it ends at NUM_ITERS/NUM_ITERS and closes
# itself. zip() asks the bar first, so no surplus batch is pulled from the
# dataloader once the budget is spent.
progressbar = tqdm(range(NUM_ITERS), total=NUM_ITERS)
for iter_idx, (x, y) in zip(progressbar, dataloader):
    x = x.to(device).float()
    y = y.to(device).float()

    # The same tensor is passed as all three arguments of the attention module.
    # NOTE(review): the recorded run ends at loss ~0.92, close to the unit
    # variance of the randn targets (roughly what predicting zeros would give).
    # If MultiHeadAttention adds no positional information, reversing the
    # sequence is presumably not learnable in this setup -- confirm.
    out, attn_scores = multi_head_attn(x, x, x)
    loss = criterion(out, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Refresh the displayed loss only every 10th step.
    if iter_idx % 10 == 0:
        progressbar.set_postfix({'loss': f'{loss.item():.4f}'})

100%|█████████▉| 4999/5000 [00:17<00:00, 291.20it/s, loss=0.9157]
