In [1]:
import torch
import torch.nn.functional as F

# Minimal single-head scaled dot-product self-attention:
#     attn_output = softmax(Q @ K^T / sqrt(d)) @ V
# where Q, K and V are separate linear projections of the same input tensor.

# Example input: batch of 2 sequences, each with 3 tokens, and embedding size of 4
hidden_states = torch.tensor([[[1, 0, 1, 0], [0, 1, 0, 1], [1, 1, 1, 1]],
                              [[1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 1, 1]]], dtype=torch.float32)

# Read the embedding size from the input rather than repeating the literal, so the
# projection layers and the score scaling stay in sync if the example tensor changes.
embed_dim = hidden_states.size(-1)

# nn.Linear draws its initial weights from the global RNG; seed it so that
# re-running this cell on a fresh kernel reproduces the same printed tensor.
torch.manual_seed(0)

# Linear projections for queries, keys, and values
q_proj = torch.nn.Linear(embed_dim, embed_dim)
k_proj = torch.nn.Linear(embed_dim, embed_dim)
v_proj = torch.nn.Linear(embed_dim, embed_dim)

# Compute queries, keys, and values (each keeps the input shape: 2 x 3 x embed_dim)
queries = q_proj(hidden_states)
keys = k_proj(hidden_states)
values = v_proj(hidden_states)

# Compute attention scores: token-to-token dot products (2 x 3 x 3),
# scaled by sqrt(embed_dim)
scores = torch.matmul(queries, keys.transpose(-2, -1)) / embed_dim ** 0.5

# Apply softmax over the key axis so each query's weights sum to 1
attn_weights = F.softmax(scores, dim=-1)

# Compute weighted sum of values
attn_output = torch.matmul(attn_weights, values)

print(attn_output)

tensor([[[ 0.0342, -0.0846,  0.8417,  0.2226],
         [ 0.0309, -0.0841,  0.8209,  0.2075],
         [ 0.0297, -0.0901,  0.8084,  0.1971]],

        [[-0.0596,  0.0722,  0.7060,  0.1779],
         [-0.2214,  0.1443,  0.6843,  0.1402],
         [-0.1990,  0.1342,  0.6855,  0.1452]]], grad_fn=<UnsafeViewBackward0>)


In [1]:
import torch
import sys
import os

# from ..modeling_phonelm import PhoneLMAttention
# from configuration_phonelm import PhoneLMConfig

# Smoke test: run PhoneLMAttention once on random input and print the shapes it returns.

# modeling_phonelm and configuration_phonelm are imported below as top-level modules,
# so the directory that contains them must be on sys.path first.
current_dir = os.getcwd()

# Parent of the working directory, normalised to an absolute path.
parent_dir = os.path.abspath(os.path.dirname(current_dir))
# Only insert once: re-running this cell must not stack duplicate sys.path entries.
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from modeling_phonelm import PhoneLMAttention
from configuration_phonelm import PhoneLMConfig

# Seed the global RNG so the random input below is repeatable across runs
# (this presumably also fixes the layer's weight initialisation - confirm in PhoneLMAttention).
torch.manual_seed(0)

config = PhoneLMConfig()

attention_layer = PhoneLMAttention(config)

batch_size = 2
seq_length = 10
hidden_size = config.hidden_size

# Random activations of shape (batch_size, seq_length, hidden_size).
hidden_states = torch.randn(batch_size, seq_length, hidden_size)
# All-zero mask of shape (batch_size, 1, 1, seq_length); presumably additive, so zeros
# leave every position visible - verify against PhoneLMAttention.
attention_mask = torch.zeros(batch_size, 1, 1, seq_length) # no mask
# Positions 0..seq_length-1, repeated for every sequence in the batch.
position_ids = torch.arange(seq_length).unsqueeze(0).repeat(batch_size, 1)

print(hidden_states.shape)
print(attention_mask.shape)
print(position_ids.shape)

# The layer returns a 3-tuple; output_attentions=True asks for the attention weights as well.
attn_output, attn_weights, past_key_value = attention_layer(hidden_states, attention_mask, position_ids, output_attentions=True)

print(attn_output.shape)
print(attn_weights.shape)
print(past_key_value)


2025-01-16 11:03:34.882375: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-16 11:03:34.914279: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-01-16 11:03:34.914312: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-01-16 11:03:34.915221: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-16 11:03:34.920873: I tensorflow/core/platform/cpu_feature_guar

torch.Size([2, 10, 4096])
torch.Size([2, 1, 1, 10])
torch.Size([2, 10])
torch.Size([2, 10, 4096])
torch.Size([2, 32, 10, 10])
None
