In [1]:
import torch

from transformer import *

# Parámetros

In [2]:
# Model hyper-parameters and demo batch dimensions.
vocab_size = 10000      # token ids are drawn from [0, vocab_size)
d_model = 512           # embedding / hidden width
num_layers = 6
num_heads = 8
d_ff = 2048
dropout = 0.1
max_seq_length = 10     # also used as the actual sequence length of the demo batch
batch_size = 2
SEED = 42

# Seed torch's global RNG so the random inputs below are reproducible on
# Restart & Run All (and, assuming they use torch's default generator, any
# weight initialisation / dropout in later cells as well).
torch.manual_seed(SEED)

# Random token ids, shape (batch_size, max_seq_length).
tokens = torch.randint(0, vocab_size, (batch_size, max_seq_length))
# All-ones mask, shape (batch_size, 1, max_seq_length, max_seq_length).
# Presumably 1 means "may attend", so nothing is masked out — TODO confirm
# against the transformer module's masking convention.
src_mask = torch.ones(batch_size, 1, max_seq_length, max_seq_length)

# Encoder Transformer

In [3]:
# Build the encoder stack. Arguments are positional, so their order matters.
# The class name is spelled `TranformerEncoder` (sic), presumably exactly as
# exported by the `transformer` module.
transformer_encoder = TranformerEncoder(
    vocab_size,
    d_model,
    num_layers,
    num_heads,
    d_ff,
    dropout,
    max_seq_length,
)

In [4]:
# Encoder forward pass: token ids plus attention mask in, one d_model-wide
# vector per token out, i.e. shape (batch_size, max_seq_length, d_model).
output = transformer_encoder(
    tokens,
    src_mask,
)

In [5]:
# Sanity-check the encoder output: one d_model-dim vector per input token.
assert output.shape == (batch_size, max_seq_length, d_model), output.shape
output.shape

torch.Size([2, 10, 512])


In [6]:
# Number of target classes for the demo classification head.
num_classes = 5
classifier = ClassifierHead(d_model, num_classes=num_classes)

In [7]:
# Pool each sequence by taking the encoder vector at position 0, BERT-[CLS]
# style. NOTE(review): `tokens` are random ids with no dedicated CLS token
# prepended, so position 0 is only a stand-in here.
# (The misspelled name `cls_respresentattion` is kept because the regression
# cell below reads it.)
cls_respresentattion = output[:, 0, :]
logits = classifier(cls_respresentattion)

# The head's output carries grad_fn=LogSoftmaxBackward0, so these are
# log-probabilities rather than raw logits. Each row exponentiates to a
# distribution summing to 1. The variable name is kept in case later cells
# reference it.
row_sums = logits.exp().sum(dim=-1)
assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5), row_sums

print("Output shape:", logits.shape)
print("Log-probabilities:", logits)

Output shape: torch.Size([2, 5])
Logits: tensor([[-1.9659, -2.1014, -1.3732, -1.4973, -1.3445],
        [-2.3736, -1.6140, -1.0863, -1.8565, -1.5414]],
       grad_fn=<LogSoftmaxBackward0>)


In [8]:
# Regression head: maps each pooled d_model vector to `output_dim` scalars.
regressor = RegressionHead(d_model, output_dim=1)
regression_output = regressor(cls_respresentattion)
# One prediction row per batch element, one column per output dimension.
assert regression_output.shape == (cls_respresentattion.shape[0], 1), regression_output.shape
print("Regression output shape:", regression_output.shape)

Regression output shape: torch.Size([2, 1])


In [9]:
# Display the raw regression predictions (one scalar per batch element).
# The head was just constructed and never trained, so the values are arbitrary.
regression_output

tensor([[-0.5168],
        [-0.2422]], grad_fn=<AddmmBackward0>)

# Decoder Transformer

In [10]:
def generate_tgt_mask(seq_len: int, device: "torch.device | str | None" = None) -> "torch.Tensor":
    """Build a causal (look-ahead) mask for decoder self-attention.

    Args:
        seq_len: Length of the target sequence.
        device: Optional device for the mask. Pass the device of the decoder
            inputs when running off-CPU so the mask need not be moved.
            Defaults to None (torch's default device), which matches the
            previous behaviour.

    Returns:
        Tensor of shape (1, 1, seq_len, seq_len) with ones on and below the
        diagonal and zeros above it, so position i can only attend to
        positions <= i. The two leading singleton dims let it broadcast over
        the batch and (presumably) head dimensions.
    """
    causal = torch.tril(torch.ones(seq_len, seq_len, device=device))
    # Equivalent to .unsqueeze(0).unsqueeze(0): add broadcastable leading dims.
    return causal[None, None]

In [13]:
tgt_seq_len = max_seq_length
# Stand-in decoder inputs: uniform random vectors of shape
# (batch_size, tgt_seq_len, d_model) instead of embedded target tokens.
tgt = torch.rand(batch_size, tgt_seq_len, d_model)
tgt_mask = generate_tgt_mask(tgt_seq_len)

decoder_layer = DecoderLayer(d_model, num_heads, d_ff, dropout)
# `output` (the encoder result) is presumably the memory for cross-attention.
# NOTE(review): no source/padding mask is passed for cross-attention —
# confirm DecoderLayer's signature and its default for that argument.
output_decoder_layer = decoder_layer(tgt, output, tgt_mask)
# A single decoder layer should preserve the (batch, seq, d_model) shape of its input.
assert output_decoder_layer.shape == tgt.shape, output_decoder_layer.shape
print("Decoder output shape:", output_decoder_layer.shape)

Decoder output shape: torch.Size([2, 10, 512])


In [14]:
# Inspect the decoder-layer output, a (batch_size, tgt_seq_len, d_model) tensor.
# torch's repr elides the middle rows and columns.
output_decoder_layer

tensor([[[-1.4046, -1.8119, -0.1540,  ..., -1.4842,  0.5221, -1.7169],
         [-2.3092, -2.1600,  1.4733,  ...,  0.7517,  0.7313, -1.3751],
         [ 0.5899,  0.4951,  0.8977,  ..., -0.8767, -0.1960, -0.3707],
         ...,
         [-0.8261, -0.5276, -0.2341,  ..., -0.3555, -0.9859, -0.6180],
         [-0.2501, -0.1086,  1.3521,  ..., -0.3142, -0.7225, -0.9077],
         [-0.1234, -0.0985,  1.5096,  ...,  0.7190, -1.6096,  1.1190]],

        [[-2.0849, -0.6344,  1.2268,  ...,  0.5142,  0.0996,  1.2028],
         [-2.0038, -2.5707,  2.0834,  ...,  1.1457, -2.0182, -0.3672],
         [-1.4734, -1.6599,  0.5760,  ...,  0.8894, -1.4828,  1.0271],
         ...,
         [ 0.2398, -0.4471,  0.7239,  ..., -0.4634,  0.3087, -1.0807],
         [-1.6472, -1.6188,  0.7646,  ...,  0.9864,  0.4554, -0.2269],
         [ 0.6739, -0.6478,  0.0313,  ..., -1.1234, -1.3882, -1.3978]]],
       grad_fn=<NativeLayerNormBackward0>)