Adapted from the Towards Data Science article "A detailed guide to PyTorch's nn.Transformer module": https://towardsdatascience.com/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1

In [1]:
# Some useful settings for interactive work:
# automatically reload imported modules (e.g. the local helper module imported
# below) before running code, so edits to their source files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [7]:
import random

import numpy as np
import torch

# NOTE(review): presumably the local helper module accompanying the article
# (generate_random_data / batchify_data / Transformer / fit / predict), not the
# Hugging Face `transformers` package; the `tf` alias is unrelated to TensorFlow.
import transformers as tf

In [3]:
# Seed every RNG the pipeline may draw from (synthetic data generation, weight
# initialisation, dropout) so that "Restart & Run All" reproduces the same
# data, losses and predictions.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

N_TRAIN_SAMPLES = 9000
N_VAL_SAMPLES = 3000

train_data = tf.generate_random_data(N_TRAIN_SAMPLES)
val_data = tf.generate_random_data(N_VAL_SAMPLES)

# batchify_data prints the resulting batch count and size; the recorded output
# (562 and 187 batches of 16 for 9000 and 3000 samples) shows that the final
# incomplete batch is dropped.
train_dataloader = tf.batchify_data(train_data)
val_dataloader = tf.batchify_data(val_data)

562 batches of size 16
187 batches of size 16


In [4]:
# Run on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Vocabulary of 4 token ids: 0 and 1 make up the sequences, while 2 and 3
# appear to mark sequence start/end (every example below is wrapped in 2 ... 3).
# The model is moved to `device` right after construction; otherwise it stays
# wherever tf.Transformer creates it while inputs are placed on `device`.
# NOTE(review): assumes tf.fit places batches on the same cuda-if-available
# device — TODO confirm against the helper module.
model = tf.Transformer(
    num_tokens=4,
    dim_model=8,
    num_heads=2,
    num_encoder_layers=3,
    num_decoder_layers=3,
    dropout_p=0.1,
).to(device)



In [5]:
# Train and validate; fit prints per-epoch losses and returns the training and
# validation loss histories (presumably one entry per epoch).
N_EPOCHS = 10
train_loss_list, validation_loss_list = tf.fit(model, train_dataloader, val_dataloader, N_EPOCHS)

Training and validating model
------------------------- Epoch 1 -------------------------
Training loss: 0.6471
Validation loss: 0.4179

------------------------- Epoch 2 -------------------------
Training loss: 0.4294
Validation loss: 0.3990

------------------------- Epoch 3 -------------------------
Training loss: 0.4062
Validation loss: 0.3798

------------------------- Epoch 4 -------------------------
Training loss: 0.3850
Validation loss: 0.3484

------------------------- Epoch 5 -------------------------
Training loss: 0.3584
Validation loss: 0.3067

------------------------- Epoch 6 -------------------------
Training loss: 0.3308
Validation loss: 0.2690

------------------------- Epoch 7 -------------------------
Training loss: 0.3073
Validation loss: 0.2438

------------------------- Epoch 8 -------------------------
Training loss: 0.2847
Validation loss: 0.2145

------------------------- Epoch 9 -------------------------
Training loss: 0.2694
Validation loss: 0.2057

-------

In [10]:
# Inference sanity check: feed hand-written sequences to the trained model and
# print the continuation it generates. Every sequence is wrapped in the
# start/end token ids below; only the 0/1 payload between them is printed.
SOS_TOKEN, EOS_TOKEN = 2, 3

# Build inputs on whatever device the model's weights actually live on, so this
# cell cannot hit a CPU/GPU mismatch regardless of how the model was placed.
device = next(model.parameters()).device

example_payloads = [
    [0, 0, 0, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1],
    [1, 0, 1, 0, 1, 0, 1, 0],
    [0, 1, 0, 1, 0, 1, 0, 1],
    [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
    [0, 1],
]
examples = [
    torch.tensor([[SOS_TOKEN, *payload, EOS_TOKEN]], dtype=torch.long, device=device)
    for payload in example_payloads
]

# Eval mode disables dropout layers (the model was built with dropout_p=0.1),
# so predictions are deterministic; tf.predict may already do this itself.
model.eval()
with torch.no_grad():
    for idx, example in enumerate(examples):
        result = tf.predict(model, example)
        print(f"Example {idx}")
        print(f"Input: {example.view(-1).tolist()[1:-1]}")
        print(f"Continuation: {result[1:-1]}")
        print()

Example 0
Input: [0, 0, 0, 0, 0, 0, 0, 0]
Continuation: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

Example 1
Input: [1, 1, 1, 1, 1, 1, 1, 1]
Continuation: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

Example 2
Input: [1, 0, 1, 0, 1, 0, 1, 0]
Continuation: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]

Example 3
Input: [0, 1, 0, 1, 0, 1, 0, 1]
Continuation: [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]

Example 4
Input: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
Continuation: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]

Example 5
Input: [0, 0, 1, 0, 0, 1]
Continuation: [0, 1, 0, 1, 0, 1, 0, 1, 0]

