In [6]:
import torch
from transformer import AttentionLayer  # Assuming your implementation is in transformer.py

# Smoke test: one forward pass through AttentionLayer, checking the output shape.

# Seed the global RNG so the random input below is identical on every
# Restart & Run All (also fixes the layer's initial weights, assuming
# AttentionLayer relies on torch's default initialisers).
torch.manual_seed(0)

# Define batch size, sequence length, and model dimensions
batch_size = 2
seq_length = 10
hidden_size = 1024  # Hidden size from your transformer model

# Number of attention heads and size per head
num_attention_heads = 4
# `//` truncates silently, so fail loudly if the heads do not tile hidden_size.
assert hidden_size % num_attention_heads == 0, (
    "hidden_size must be divisible by num_attention_heads"
)
size_per_head = hidden_size // num_attention_heads

# Create random input tensors (simulating an input to the attention layer)
input_tensor = torch.randn(batch_size, seq_length, hidden_size)

# Create the Attention Layer instance
attention_layer = AttentionLayer(
    num_attention_heads=num_attention_heads,
    size_per_head=size_per_head,
)

# Forward pass through the attention layer.
# The same tensor is passed for both arguments (self-attention: query and key
# are the same).
output = attention_layer(input_tensor, input_tensor)

# Check output shape: turn the visual check into an enforced invariant so a
# regression breaks the run instead of hiding in the printed text.
expected_shape = (batch_size, seq_length, num_attention_heads * size_per_head)
assert tuple(output.shape) == expected_shape, (
    f"unexpected output shape {tuple(output.shape)}, expected {expected_shape}"
)
print(f"Attention layer output shape: {output.shape}")
print(output)

Attention layer output shape: torch.Size([2, 10, 1024])
tensor([[[-0.1422, -0.0356, -0.0091,  ..., -0.0275, -0.1920,  0.0517],
         [ 0.0030, -0.0731, -0.2004,  ...,  0.0390, -0.0783, -0.0441],
         [-0.0606, -0.1078, -0.0843,  ..., -0.0563, -0.0026, -0.0169],
         ...,
         [-0.1833,  0.0234, -0.1320,  ...,  0.0100, -0.0881, -0.0189],
         [-0.0416, -0.0453, -0.0318,  ..., -0.0214, -0.0495, -0.0151],
         [-0.0712, -0.0732, -0.0458,  ...,  0.0342, -0.1147,  0.0081]],

        [[ 0.0795, -0.2990,  0.0414,  ..., -0.0633, -0.1813,  0.0455],
         [ 0.0201, -0.1773,  0.0545,  ...,  0.0185, -0.1866,  0.0798],
         [ 0.1127, -0.3299,  0.1252,  ..., -0.0645, -0.2336, -0.0304],
         ...,
         [ 0.0591, -0.2137, -0.0392,  ..., -0.0780, -0.2110, -0.0872],
         [ 0.0594, -0.2677,  0.0189,  ...,  0.0406, -0.2667, -0.1104],
         [ 0.0672, -0.3054,  0.1785,  ..., -0.0820, -0.1941, -0.0598]]],
       grad_fn=<ViewBackward0>)


In [7]:
import torch
from transformer import Transformer

# Smoke test: one forward pass through the full Transformer stack, checking
# that the output has the same (batch, sequence, hidden) shape as the input.

# Seed the global RNG so the random input below is identical on every
# Restart & Run All (also fixes the model's initial weights, assuming
# Transformer relies on torch's default initialisers).
torch.manual_seed(0)

# Define the test parameters
batch_size = 2
seq_length = 10
hidden_size = 1024  # Should match the model's hidden size

# Create a random input tensor
input_tensor = torch.randn(batch_size, seq_length, hidden_size)

# Define other parameters for the transformer layer
num_hidden_layers = 6
num_attention_heads = 4
intermediate_size = 2048

# Instantiate the transformer
transformer = Transformer(
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    intermediate_size=intermediate_size,
)

# Perform the forward pass
output = transformer(input_tensor)

# Enforce the shape invariant rather than relying on reading the printed value.
expected_shape = (batch_size, seq_length, hidden_size)
assert tuple(output.shape) == expected_shape, (
    f"unexpected output shape {tuple(output.shape)}, expected {expected_shape}"
)

# Print the shape of the output tensor
print(f"Transformer output shape: {output.shape}")


Transformer output shape: torch.Size([2, 10, 1024])


In [12]:
import torch
import torch.nn as nn
from transformer import AttentionLayer  # Assuming your implementation is in transformer.py

# Gradient test: forward pass through AttentionLayer, MSE loss against a random
# target, backward pass, then confirm gradients reached the input tensor.

# Seed the global RNG so the random input/target below are identical on every
# Restart & Run All (also fixes the layer's initial weights, assuming
# AttentionLayer relies on torch's default initialisers).
torch.manual_seed(0)

# Define test parameters
batch_size = 2
seq_length = 10
hidden_size = 1024

# Number of attention heads and size per head
num_attention_heads = 4
# `//` truncates silently, so fail loudly if the heads do not tile hidden_size.
assert hidden_size % num_attention_heads == 0, (
    "hidden_size must be divisible by num_attention_heads"
)
size_per_head = hidden_size // num_attention_heads

# Create random input tensors and targets
# requires_grad=True makes the input a leaf that autograd tracks, so its .grad
# is populated by backward().
input_tensor = torch.randn(batch_size, seq_length, hidden_size, requires_grad=True)
target_tensor = torch.randn(batch_size, seq_length, hidden_size)

# Create the Attention Layer instance
attention_layer = AttentionLayer(
    num_attention_heads=num_attention_heads,
    size_per_head=size_per_head,
)

# Define a loss function (e.g., MSELoss)
loss_fn = nn.MSELoss()

# Forward pass through the attention layer.
# The same tensor is passed for both arguments (self-attention: query and key
# are the same).
output = attention_layer(input_tensor, input_tensor)

# Compute loss
loss = loss_fn(output, target_tensor)
assert torch.isfinite(loss), "loss is not finite"
print(f"Loss: {loss}")

# Perform backpropagation
loss.backward()

# Check gradients for the input tensor: enforce that backprop actually reached
# the input instead of relying on eyeballing the dump below.
assert input_tensor.grad is not None, "no gradient reached the input tensor"
assert input_tensor.grad.shape == input_tensor.shape, (
    "gradient shape does not match input shape"
)
print(f"Input tensor gradients: {input_tensor.grad}")


Loss: 1.037908673286438
Input tensor gradients: tensor([[[-2.5584e-05,  9.1630e-06,  4.3835e-05,  ..., -3.5799e-05,
          -1.8641e-06, -5.5898e-05],
         [-3.8660e-05,  2.2877e-05,  2.1538e-05,  ..., -3.1388e-05,
           1.9090e-05, -3.3681e-05],
         [-1.6117e-05,  2.4816e-05,  6.8442e-05,  ..., -1.6960e-05,
           1.4013e-05, -3.7118e-05],
         ...,
         [-2.1360e-05,  2.0879e-05,  2.5869e-05,  ..., -3.3268e-05,
           1.9415e-05, -2.6377e-05],
         [-6.1441e-06,  1.1794e-05,  6.2522e-05,  ..., -2.0422e-05,
           8.8140e-06, -3.0288e-05],
         [-1.9210e-05,  2.8447e-05,  2.7604e-05,  ..., -2.5945e-05,
           2.7842e-05, -3.4455e-05]],

        [[-1.7136e-06,  2.1871e-05,  2.5296e-05,  ..., -1.9575e-05,
           3.2393e-05, -2.5606e-05],
         [-1.8344e-05, -1.6368e-07,  1.9154e-05,  ..., -1.9563e-06,
           1.9994e-05,  1.3611e-06],
         [-7.1946e-06,  8.0717e-07, -7.1339e-06,  ..., -1.3799e-05,
           2.6126e-05, -2.52