In [1]:
import torch
import torchinfo
from model import Transformer
from torch.utils.data import DataLoader

In [2]:
with open("../../datasets/addition/dataset_train.pt", "rb") as f:
    train_BT = torch.load(f, weights_only=False)

with open("../../datasets/addition/dataset_test.pt", "rb") as f:
    test_BT = torch.load(f, weights_only=False)

train_BT

tensor([[1, 0, 6, 1],
        [0, 3, 6, 3],
        [5, 0, 6, 5],
        [4, 4, 6, 2],
        [2, 2, 6, 4],
        [3, 5, 6, 2],
        [2, 4, 6, 0],
        [5, 4, 6, 3],
        [3, 4, 6, 1],
        [1, 5, 6, 0],
        [0, 4, 6, 4],
        [4, 1, 6, 5],
        [1, 4, 6, 5],
        [5, 5, 6, 4],
        [4, 3, 6, 1],
        [5, 1, 6, 0],
        [2, 3, 6, 5],
        [1, 3, 6, 4]])

In [3]:
vocab_size = test_BT[0, 2].item() + 1
vocab_size

7

In [4]:
model = Transformer(
    vocab_size=vocab_size,
    d_model=3,
    n_heads=1,
    layers=1
)
model.compile()

torchinfo.summary(model)
model.train()

Transformer(
  (embedding): Embedding(7, 3)
  (encoder_layers): ModuleList(
    (0): TransformerEncoderLayer(
      (self_attn): MultiHeadSelfAttention(
        (q_proj): Linear(in_features=3, out_features=3, bias=True)
        (k_proj): Linear(in_features=3, out_features=3, bias=True)
        (v_proj): Linear(in_features=3, out_features=3, bias=True)
        (o_proj): Linear(in_features=3, out_features=3, bias=True)
      )
      (ff): PositionwiseFeedForward(
        (fc1): Linear(in_features=3, out_features=12, bias=True)
        (fc2): Linear(in_features=12, out_features=3, bias=True)
      )
      (ln1): LayerNorm((3,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((3,), eps=1e-05, elementwise_affine=True)
    )
  )
  (fc_out): Linear(in_features=3, out_features=7, bias=True)
)

In [5]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.0005
)

In [6]:
total_steps = 100000
validation_mod = 100
batch_size = 10

current_step = 0

while current_step < total_steps:
    for i in range(train_BT.shape[0] // batch_size):
        batch = train_BT[i * batch_size:(i + 1) * batch_size]  # [B, 4]

        pred = model(batch)  # [B, 4, vocab_size]

        batch_flat = batch.view(-1)                     # [B*4]
        pred_flat = pred.view(-1, vocab_size)           # [B*4, V]

        loss = criterion(pred_flat, batch_flat)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        current_step += 1

        # ---------- VALIDATION ----------
        if current_step % validation_mod == 0:
            model.eval()
            val_loss = 0.0
            correct = 0
            total = 0

            with torch.no_grad():
                for j in range(test_BT.shape[0] // batch_size):
                    val_batch = test_BT[j * batch_size:(j + 1) * batch_size]

                    val_pred = model(val_batch)          # [B, 4, V]

                    val_batch_flat = val_batch.view(-1)  # [B*4]
                    val_pred_flat = val_pred.view(-1, vocab_size)

                    val_loss += criterion(val_pred_flat, val_batch_flat).item()

                    preds = val_pred_flat.argmax(dim=1)
                    correct += (preds == val_batch_flat).sum().item()
                    total += val_batch_flat.numel()

            val_loss /= (test_BT.shape[0] // batch_size)
            val_acc = correct / total

            print(
                f"Step: {current_step} | "
                f"Train Loss: {loss.item():.4f} | "
                f"Val Loss: {val_loss:.4f} | "
                f"Val Acc: {val_acc:.4f}"
            )

            model.train()


Step: 100 | Train Loss: 1.8276 | Val Loss: 1.8290 | Val Acc: 0.3750
Step: 200 | Train Loss: 1.5846 | Val Loss: 1.6150 | Val Acc: 0.3750
Step: 300 | Train Loss: 1.4048 | Val Loss: 1.4479 | Val Acc: 0.4750
Step: 400 | Train Loss: 1.2882 | Val Loss: 1.3349 | Val Acc: 0.5500
Step: 500 | Train Loss: 1.1646 | Val Loss: 1.2143 | Val Acc: 0.6500
Step: 600 | Train Loss: 1.0121 | Val Loss: 1.0552 | Val Acc: 0.7500
Step: 700 | Train Loss: 0.8966 | Val Loss: 0.9346 | Val Acc: 0.8000
Step: 800 | Train Loss: 0.8084 | Val Loss: 0.8445 | Val Acc: 0.8000
Step: 900 | Train Loss: 0.7349 | Val Loss: 0.7697 | Val Acc: 0.8000
Step: 1000 | Train Loss: 0.6727 | Val Loss: 0.7067 | Val Acc: 0.8000
Step: 1100 | Train Loss: 0.6193 | Val Loss: 0.6531 | Val Acc: 0.8000
Step: 1200 | Train Loss: 0.5727 | Val Loss: 0.6068 | Val Acc: 0.8500
Step: 1300 | Train Loss: 0.5315 | Val Loss: 0.5657 | Val Acc: 0.8500
Step: 1400 | Train Loss: 0.4940 | Val Loss: 0.5281 | Val Acc: 1.0000
Step: 1500 | Train Loss: 0.4591 | Val Loss:

KeyboardInterrupt: 