# Tests

This notebook contains all the tests used to validate the model.

In [7]:
from T1000 import *

import torch

# Seed PyTorch's global RNG so the random test inputs drawn below are
# identical on every Restart & Run All.
torch.manual_seed(0)

# Default model configuration used by the tests in this notebook.
test_config = GPTConfig()

# Number of positions in the test sequence.
n_context = 10
# Random integer token ids in [0, d_vocab), one per position.
x_vector = torch.randint(0, test_config.d_vocab, (n_context,))
# Random float input of shape (n_context, d_model), uniform in [0, 1);
# used by the tests that take already-embedded input.
x_embeded = torch.rand((n_context, test_config.d_model))

## Test 1: Attention Head

Test that a single attention head is working correctly.

In [8]:
# Build a single attention head from the test config and run the
# embedded input through it.
single_attention = AttentionHead(test_config)
y = single_attention(x_embeded)

# The output must keep the input's shape and contain no NaN entries.
assert x_embeded.shape == y.shape
assert not torch.isnan(y).any()

## Test 2: Multi-Head Attention

Test that the multi-head attention is working correctly.

In [9]:
# Build the multi-headed attention layer from the test config and run
# the embedded input through it.
multi_attention = MultiHeadedAttention(test_config)
y = multi_attention(x_embeded)

# The output must keep the input's shape and contain no NaN entries.
assert x_embeded.shape == y.shape
assert not torch.isnan(y).any()

## Test 3: MLP

Test that the MLP is working correctly.

In [10]:
# Build the MLP from the test config and run the embedded input through it.
mlp = MLP(test_config)
y = mlp(x_embeded)

# The output must keep the input's shape and contain no NaN entries.
assert x_embeded.shape == y.shape
assert not torch.isnan(y).any()

## Test 4: Transformer Block

Test that the transformer is working correctly: given a vector of token ids, it should return one vocabulary-sized row of outputs per input position.

In [11]:
# Instantiate the Transformer from the test config and feed it the
# vector of random token ids.
transformer = Transformer(test_config)
y = transformer(x_vector)
print(y)

# Expect one row per input position, each d_vocab wide, with no NaN entries.
assert (n_context, test_config.d_vocab) == y.shape
assert not torch.isnan(y).any()

tensor([[-5.0241e-01,  3.0332e-01, -2.1832e+00,  ..., -1.4108e+00,
         -5.5624e-01,  4.0513e-01],
        [-2.9042e-01,  1.3649e-01, -1.8997e+00,  ..., -1.4271e+00,
         -3.0434e-01,  2.5214e-01],
        [-2.3843e-01,  7.0428e-02, -1.7417e+00,  ..., -1.2985e+00,
         -1.8418e-01,  1.4210e-01],
        ...,
        [-5.5133e-02, -2.4541e-02, -1.3909e+00,  ..., -1.0647e+00,
         -2.9231e-02, -9.7207e-04],
        [ 6.0265e-03, -6.3036e-02, -1.3164e+00,  ..., -1.0311e+00,
          2.3282e-02, -3.7926e-02],
        [ 6.5875e-03, -8.1598e-02, -1.3266e+00,  ..., -1.0392e+00,
          1.1540e-02, -1.3776e-02]], grad_fn=<AddmmBackward0>)


In [12]:
# Smoke test on a fresh config: a vector of 10 random token ids should
# come back as one d_vocab-wide row per input position.
test_config = GPTConfig()
test_x = torch.randint(low=0, high=test_config.d_vocab, size=(10,))
test_T1000 = Transformer(test_config)
# Call the module itself rather than .forward() so that any registered
# hooks run, consistent with how the other tests in this notebook
# invoke their models.
test_y = test_T1000(test_x)
print(test_x.shape)
print(test_y.shape)

# Turn the printed shapes into an actual check, and reject NaN outputs.
assert test_y.shape == (test_x.shape[0], test_config.d_vocab)
assert not torch.isnan(test_y).any()

torch.Size([10])
torch.Size([10, 10000])
