In [1]:
import torch
import torch.nn as nn
from models.transformer_pytorch import TransformerPyTorch
from models.transformer import Transformer
from hyperparameters import hyperparameters

# Sizes of the vocabulary and of the dummy batch used by this smoke test.
vocab_size = 32000
batch_size = 2
seq_len = 5

# Both transformer implementations are configured from the same project
# hyperparameters, so the common constructor arguments are collected once
# and unpacked into each constructor call below.
transformer_cfg = hyperparameters.transformer
shared_model_kwargs = dict(
    d_model=transformer_cfg.hidden_size,
    num_heads=transformer_cfg.num_heads,
    d_ff=transformer_cfg.encoder_ffn_embed_dim,
    # The encoder and decoder stacks are given the same depth.
    num_encoder_layers=transformer_cfg.num_hidden_layers,
    num_decoder_layers=transformer_cfg.num_hidden_layers,
    dropout=transformer_cfg.dropout,
    max_len=transformer_cfg.max_len,
)

# Implementation from models.transformer_pytorch; it takes a single
# vocabulary size.
pytorch_model: nn.Module = TransformerPyTorch(
    vocab_size=vocab_size,
    **shared_model_kwargs,
)

# Implementation from models.transformer; it takes separate source and
# target vocabulary sizes, both set to the same value here.
own_model = Transformer(
    src_vocab_size=vocab_size,
    tgt_vocab_size=vocab_size,
    **shared_model_kwargs,
)

# Cross entropy averaged over the scored targets. Targets equal to 0 are
# skipped (presumably 0 is the padding id -- confirm against the tokenizer).
criterion = nn.CrossEntropyLoss(reduction="mean", ignore_index=0)

# Dummy data: random token ids drawn from [1, vocab_size). Because the lower
# bound is 1, id 0 -- the index the criterion ignores -- never appears, so
# every target position contributes to the loss. Lower the bound to 0 if you
# want the ignored-index (padding) path to be exercised as well.
src = torch.randint(1, vocab_size, (batch_size, seq_len))
tgt = torch.randint(1, vocab_size, (batch_size, seq_len))

# Shift the target by one position: the decoder is fed tokens 0..T-2 and is
# scored against tokens 1..T-1, so both tensors have shape [B, T-1].
decoder_in = tgt[:, :-1]
labels = tgt[:, 1:]

# NOTE(review): neither model is switched to eval() here, so if they are
# nn.Modules left in their default training mode and dropout > 0, the printed
# losses include dropout noise -- confirm whether that is intended.
#
# The losses are only printed, never back-propagated, so autograd recording
# is disabled for both forward passes.
with torch.no_grad():
    for caption, candidate in (
        ("Dummy test loss =", pytorch_model),
        ("Dummy test loss on own model =", own_model),
    ):
        # Each model is expected to return logits of shape
        # [B, T-1, vocab_size]; CrossEntropyLoss wants the class dimension
        # second, hence the transpose to [B, vocab_size, T-1].
        logits = candidate(src, decoder_in).transpose(1, 2)

        # reduction="mean" makes this a scalar tensor (not [B, T-1]).
        loss = criterion(logits, labels)
        print(caption, loss.item())

Dummy test loss = 10.64527416229248
Dummy test loss on own model = 10.447006225585938


