In [None]:
import torch
from torch.nn.functional import dropout

from architectures import (
    EncoderDecoderTransformer,
    EncoderOnlyTransformer,
    DecoderOnlyTransformer
)

In [None]:
# --- Dummy-batch dimensions (used when sampling the random inputs) ---
batch_size = 8
encoder_seq_length = 150
decoder_seq_length = 10

# --- Vocabulary sizes (also the exclusive upper bound of the sampled ids) ---
encoder_vocab_size = 480
decoder_vocab_size = 720

# --- Model hyperparameters forwarded to every architecture below ---
d_model = 128
num_heads = 4
d_hiddens = [64]  # passed through as `d_hiddens`; exact meaning lives in `architectures` (TODO confirm)
dropout_probability = 0.25
num_encoder_layers = 4
num_decoder_layers = 4
max_seq_length = 1 << 14  # 16384, forwarded as `max_seq_length`

In [None]:
# Sample the dummy token ids from a dedicated, seeded generator so that a
# fresh "Restart & Run All" reproduces the exact same inputs, without
# touching the global torch RNG state.
input_rng = torch.Generator().manual_seed(0)

# Integer ids drawn uniformly from [0, vocab_size), one row per batch element.
encoder_input = torch.randint(
    low=0,
    high=encoder_vocab_size,
    size=(batch_size, encoder_seq_length),
    generator=input_rng,
)
decoder_input = torch.randint(
    low=0,
    high=decoder_vocab_size,
    size=(batch_size, decoder_seq_length),
    generator=input_rng,
)

In [None]:
# Build the encoder-decoder variant from the shared configuration.
encoder_decoder_transformer = EncoderDecoderTransformer(
    encoder_vocab_size=encoder_vocab_size,
    decoder_vocab_size=decoder_vocab_size,
    d_model=d_model,
    max_seq_length=max_seq_length,
    num_heads=num_heads,
    d_hiddens=d_hiddens,
    dropout_probability=dropout_probability,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    d_decoder_output=decoder_vocab_size
)

# This forward pass only serves to report the output size, so skip autograd
# bookkeeping instead of building a graph that is thrown away immediately.
with torch.no_grad():
    encoder_decoder_output = encoder_decoder_transformer(encoder_input, decoder_input)

print(encoder_decoder_output.size())

In [None]:
# Build the encoder-only variant from the shared configuration.
encoder_only_transformer = EncoderOnlyTransformer(
    vocab_size=encoder_vocab_size,
    d_model=d_model,
    max_seq_length=max_seq_length,
    num_heads=num_heads,
    d_hiddens=d_hiddens,
    dropout_probability=dropout_probability,
    num_encoder_layers=num_encoder_layers,
    d_encoder_output=encoder_vocab_size
)

# This forward pass only serves to report the output size, so skip autograd
# bookkeeping instead of building a graph that is thrown away immediately.
with torch.no_grad():
    encoder_only_output = encoder_only_transformer(encoder_input)

print(encoder_only_output.size())

In [None]:
# Build the decoder-only variant from the shared configuration.
decoder_only_transformer = DecoderOnlyTransformer(
    vocab_size=decoder_vocab_size,
    d_model=d_model,
    max_seq_length=max_seq_length,
    num_heads=num_heads,
    d_hiddens=d_hiddens,
    dropout_probability=dropout_probability,
    num_decoder_layers=num_decoder_layers,
    d_decoder_output=decoder_vocab_size
)

# This forward pass only serves to report the output size, so skip autograd
# bookkeeping instead of building a graph that is thrown away immediately.
with torch.no_grad():
    decoder_only_output = decoder_only_transformer(decoder_input)

print(decoder_only_output.size())