In [1]:
import torch
import torch.nn as nn
from torch.testing import assert_close
from AxialEncoding import AxialEncoding

In [None]:
def test_axial_encoding():
    """Smoke-test ``AxialEncoding`` on a small, fixed configuration.

    Checks performed, in order:
      1. ``params1`` and ``params2`` are registered as ``nn.Parameter``.
      2. The output has shape ``(batch_size, context_length, d_out)``.
      3. For an all-zero input, *every* batch element equals the manually
         assembled encoding, where position ``i`` takes row ``i % n1`` of
         ``params1`` for the first ``d1`` features and row ``i // n1`` of
         ``params2`` for the remaining ``d2`` features.
      4. Back-propagating ``out.sum()`` yields non-zero gradients for both
         parameter tensors.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    # Setup
    d1, d2 = 4, 6
    d_out = d1 + d2
    n1, n2 = 3, 5
    context_length = n1 * n2
    batch_size = 2

    # Initialize the model
    model = AxialEncoding(d_out=d_out, context_length=context_length,
                          n1=n1, n2=n2, d1=d1, d2=d2)

    # Ensure parameters are not accidentally detached
    assert isinstance(model.params1, nn.Parameter)
    assert isinstance(model.params2, nn.Parameter)

    # Dummy input: all zeros, so the output can be compared directly with
    # the positional encoding.
    x = torch.zeros(batch_size, context_length, d_out)
    out = model(x)

    # Check shape
    assert out.shape == (batch_size, context_length, d_out), \
        f"Unexpected output shape: {out.shape}"

    # Manual reconstruction of expected positional encoding.
    # Built under no_grad so the reference tensor is a plain constant and is
    # not attached to the autograd graph of the parameters.
    expected_encoding = torch.empty(context_length, d_out)
    with torch.no_grad():
        for i in range(context_length):
            r = i % n1
            s = i // n1
            expected_encoding[i, :d1] = model.params1[r]
            expected_encoding[i, d1:] = model.params2[s]

    # Check that every batch element (not only the first) matches the manual
    # construction; expand() creates a broadcast view without copying.
    assert_close(out, expected_encoding.expand(batch_size, -1, -1),
                 rtol=1e-5, atol=1e-5)

    # Check gradients flow
    assert out.requires_grad, "Output is detached from the parameters"
    out.sum().backward()
    assert model.params1.grad is not None, "Gradient for params1 missing"
    assert model.params2.grad is not None, "Gradient for params2 missing"
    # A present-but-all-zero gradient would also mean nothing is learned.
    assert bool(torch.any(model.params1.grad != 0)), "Gradient for params1 is all zeros"
    assert bool(torch.any(model.params2.grad != 0)), "Gradient for params2 is all zeros"

    print("✅ AxialEncoding test passed.")



In [3]:
# Run the test: a failed check raises AssertionError; on success the
# function prints a confirmation line.
test_axial_encoding()

✅ AxialEncoding test passed.
