In [1]:
import math

import torch
import torch.nn as nn

# Baseline: PyTorch's exact multi-head attention applied to random inputs.
# With batch_first=True, inputs and outputs are laid out as (batch, seq, embed).
embed_dim, num_heads = 128, 8
seq_length, batch_size = 10, 16

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# Query, key and value are drawn in that order from U[0, 1).
q, k, v = (torch.rand(batch_size, seq_length, embed_dim) for _ in range(3))

output, attn_output_weights = mha(q, k, v)
print(output.shape, attn_output_weights.shape, sep="\n")

torch.Size([16, 10, 128])
torch.Size([16, 10, 10])


In [2]:
from panther.nn import Performers


def create_projection_matrix(m, d, seed=42, scaling=False) -> torch.Tensor:
    """Build an ``(m, d)`` random projection matrix whose rows are block-orthogonal.

    Rows are produced in blocks of ``d``. Each block is the transposed ``Q``
    factor of a QR decomposition of a ``d x d`` standard-normal matrix, so the
    rows inside one block are mutually orthonormal. If ``m`` is not a multiple
    of ``d``, the last block is truncated to the remaining rows. Every row is
    then rescaled:

    * ``scaling=True``: each row gets norm ``sqrt(d)``.
    * ``scaling=False``: row ``i`` is multiplied by the Euclidean norm of an
      independent ``d``-dimensional standard-normal vector (a chi(d) sample).

    Note: this calls ``torch.manual_seed(seed)``, i.e. it reseeds PyTorch's
    *global* RNG, which also affects every later random call in the process.

    Args:
        m: Number of rows to produce. ``0`` yields an empty ``(0, d)`` tensor.
        d: Row dimension; must be positive.
        seed: Seed passed to ``torch.manual_seed``.
        scaling: Selects the row-norm scheme described above.

    Returns:
        A float tensor of shape ``(m, d)``.

    Raises:
        ValueError: If ``d < 1`` or ``m < 0``.
    """
    if d < 1:
        raise ValueError(f"d must be positive, got {d}")
    if m < 0:
        raise ValueError(f"m must be non-negative, got {m}")

    torch.manual_seed(seed)  # Set seed for reproducibility (global RNG)

    def _orthonormal_rows() -> torch.Tensor:
        # QR of a square Gaussian matrix yields an orthonormal Q; transposing
        # turns its orthonormal columns into rows.
        q, _ = torch.linalg.qr(torch.randn(d, d))
        return q.T

    # Orthogonality is only guaranteed within one d x d block, hence the
    # block-by-block construction (a full block plus an optional partial one).
    nb_full_blocks, remaining_rows = divmod(m, d)
    block_list = [_orthonormal_rows() for _ in range(nb_full_blocks)]
    if remaining_rows > 0:
        block_list.append(_orthonormal_rows()[:remaining_rows])

    if not block_list:
        # m == 0: torch.vstack would reject an empty list.
        return torch.empty(0, d)

    final_matrix = torch.vstack(block_list)

    if scaling:
        multiplier = torch.full((m,), math.sqrt(d))
    else:
        multiplier = torch.norm(torch.randn(m, d), dim=1)

    # Row-wise scaling by broadcasting: same result as
    # diag(multiplier) @ final_matrix without building an (m, m) matrix.
    return multiplier.unsqueeze(1) * final_matrix


def test_softmax_noncausal_attention_block_output_shape():
    """Check that a ``Performers`` attention block returns the input's shape.

    Builds random query/key/value tensors of shape
    ``(batch_size, seq_length, embed_dim)``, runs them through a ``Performers``
    module constructed with only ``embed_dim``, ``num_heads`` and
    ``num_random_features`` (presumably the default non-causal softmax-kernel
    configuration, as the name suggests — confirm against ``panther.nn``),
    prints the output shape, and asserts it equals the input shape.

    Raises:
        AssertionError: If the output shape is not
            ``(batch_size, seq_length, embed_dim)``.
    """
    embed_dim = 128
    num_heads = 8
    seq_length = 10
    batch_size = 16
    num_random_features = 350
    q = torch.rand(batch_size, seq_length, embed_dim)
    k = torch.rand(batch_size, seq_length, embed_dim)
    v = torch.rand(batch_size, seq_length, embed_dim)
    mha = Performers(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_random_features=num_random_features,
    )
    attention_block_output = mha(q, k, v)
    print(attention_block_output.shape)

    # Previously this "test" only printed, so it could never fail.
    expected_shape = (batch_size, seq_length, embed_dim)
    assert attention_block_output.shape == expected_shape, (
        f"expected {expected_shape}, got {tuple(attention_block_output.shape)}"
    )


test_softmax_noncausal_attention_block_output_shape()


torch.Size([16, 10, 128])


In [3]:
def test_backprop(num_steps=100, lr=0.001):
    """Smoke-test backprop through ``Performers`` and Adam updates of its weights.

    One fixed batch of random ``(batch_size, seq_length, embed_dim)``
    query/key/value tensors is fed through the layer ``num_steps`` times. Each
    step uses the sum of the output as the loss, backpropagates, takes an Adam
    step, and prints the loss.

    The sum objective is unbounded below: nothing stops the optimizer from
    driving the outputs arbitrarily negative. The printed losses are therefore
    expected to keep decreasing rather than converge. This is a plumbing check,
    not a training run.

    Args:
        num_steps: Number of optimizer steps.
        lr: Adam learning rate.

    Raises:
        AssertionError: If no parameter receives a non-zero gradient on the
            first step, a loss becomes non-finite, or (for ``num_steps > 1``)
            the final loss is not lower than the first one.
    """
    embed_dim = 128
    num_heads = 8
    seq_length = 10
    batch_size = 16
    num_random_features = 350
    q = torch.rand(batch_size, seq_length, embed_dim)
    k = torch.rand(batch_size, seq_length, embed_dim)
    v = torch.rand(batch_size, seq_length, embed_dim)
    mha = Performers(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_random_features=num_random_features,
    )
    optimizer = torch.optim.Adam(mha.parameters(), lr=lr)
    losses = []
    for step in range(num_steps):
        attention_block_output = mha(q, k, v)
        loss = torch.sum(attention_block_output)
        loss.backward()
        if step == 0:
            # Before the first update, confirm gradients actually reached the weights.
            assert any(
                p.grad is not None and bool(p.grad.abs().sum() > 0)
                for p in mha.parameters()
            ), "no parameter received a non-zero gradient"
        assert bool(torch.isfinite(loss.detach())), f"non-finite loss at step {step}"
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())
        print(losses[-1])

    if len(losses) > 1:
        assert losses[-1] < losses[0], (
            f"loss did not decrease: first={losses[0]}, last={losses[-1]}"
        )


test_backprop()

97.03648376464844
-1105.4796142578125
-2349.63623046875
-3687.888916015625
-5174.4677734375
-6857.13916015625
-8774.76953125
-10958.888671875
-13435.8193359375
-16228.6337890625
-19358.806640625
-22847.16796875
-26713.87109375
-30978.18359375
-35658.4921875
-40772.578125
-46337.96484375
-52372.1484375
-58892.609375
-65916.9375
-73462.859375
-81548.609375
-90192.9296875
-99414.609375
-109232.0703125
-119663.390625
-130727.1015625
-142442.296875
-154828.515625
-167904.875
-181689.4375
-196199.0625
-211449.484375
-227455.890625
-244233.90625
-261800.28125
-280171.65625
-299361.90625
-319383.5
-340250.84375
-361981.5
-384591.125
-408089.34375
-432482.46875
-457778.84375
-483989.625
-511124.59375
-539192.125
-568201.5
-598160.375
-629073.625
-660947.75
-693790.5
-727608.875
-762406.3125
-798188.6875
-834964.25
-872740.8125
-911524.25
-951319.5625
-992134.8125
-1033974.1875
-1076840.875
-1120741.875
-1165682.75
-1211664.75
-1258690.875
-1306765.75
-1355889.75
-1406056.375
-1457260.25
-150949