In [1]:
import torch 
from torch import nn


In [10]:
# Rotary Embedding and apply_rotary_pos_emb function
class RotaryEmbedding(nn.Module):
    """Precompute the rotary position embedding (RoPE) sinusoid table.

    The table uses an interleaved layout. For every rotated feature pair ``i``
    it stores ``cos(pos * inv_freq[i])`` at index ``2*i`` and
    ``sin(pos * inv_freq[i])`` at index ``2*i + 1``. Cosines sit at even
    indices and sines at odd indices, which is the layout
    ``apply_rotary_pos_emb`` reads.

    Args:
        dim: Feature dimension to rotate. Must be even, because RoPE rotates
            features in pairs.
        base: Base of the geometric progression of rotation frequencies.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        if dim % 2:
            raise ValueError(f"RotaryEmbedding requires an even dim, got {dim}")
        # One frequency per feature pair: base ** (-2i / dim) for i in [0, dim // 2).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, seq_len):
        """Return the interleaved cos/sin table for positions ``0 .. seq_len - 1``.

        Args:
            seq_len: Number of positions to generate.

        Returns:
            Tensor of shape ``[1, seq_len, dim]``, on the same device and dtype
            as the ``inv_freq`` buffer. The leading size-1 axis lets it
            broadcast over a batch dimension.
        """
        # Positions as a column vector [seq_len, 1], matching inv_freq's dtype/device.
        t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq).unsqueeze(1)
        # Rotation angle for every (position, pair) via broadcasting: [seq_len, dim // 2].
        freqs = t * self.inv_freq.unsqueeze(0)
        # Interleave cos/sin so pair i occupies indices (2i, 2i + 1): [seq_len, dim].
        # Returning raw angles here would be misread as cos/sin by a consumer
        # that splits the table with [..., 0::2] / [..., 1::2].
        emb = torch.stack((freqs.cos(), freqs.sin()), dim=-1).flatten(-2)
        return emb[None, :, :]

def apply_rotary_pos_emb(q, k, sinusoidal_pos):
    """Rotate query and key features by their positions (RoPE).

    Features are treated as consecutive pairs ``(x[2i], x[2i+1])``. Pair ``i``
    is rotated by the angle whose cosine is ``sinusoidal_pos[..., 2i]`` and
    whose sine is ``sinusoidal_pos[..., 2i+1]``. Each rotated pair is written
    back to the same two indices, so the output keeps the input's feature
    layout.

    Args:
        q: Query tensor. Its last dimension is the (even) feature size.
        k: Key tensor with the same last dimension as ``q``.
        sinusoidal_pos: Interleaved cos/sin table. After taking every other
            element of its last axis, it must broadcast against ``q`` and ``k``
            (e.g. a ``[1, seq_len, dim]`` tensor).

    Returns:
        Tuple ``(q_rot, k_rot)`` of rotated queries and keys. Each has the
        broadcast shape of its input with the table.
    """
    # Even indices hold cos(angle), odd indices hold sin(angle): one angle per pair.
    cos, sin = sinusoidal_pos[..., 0::2], sinusoidal_pos[..., 1::2]

    def _rotate_pairs(x):
        # Apply the 2-D rotation [[cos, -sin], [sin, cos]] to each (even, odd) pair.
        x_even, x_odd = x[..., 0::2], x[..., 1::2]
        rot_even = x_even * cos - x_odd * sin
        rot_odd = x_even * sin + x_odd * cos
        # Stack to [..., dim // 2, 2] and flatten so pair i lands back at (2i, 2i + 1).
        # Concatenating the two halves instead would permute the feature layout.
        return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)

    return _rotate_pairs(q), _rotate_pairs(k)

In [30]:
a = torch.randn((5, 3))
# Concatenate `a` with itself along the last axis: [5, 3] -> [5, 6]
b = torch.cat([a, a], dim=-1)
print(a.shape, b.shape)
# Two equivalent ways to prepend a size-1 axis: None-indexing and unsqueeze(0)
c, d = b[None, :, :], b.unsqueeze(0)
print(c.shape, d.shape)

torch.Size([5, 3]) torch.Size([5, 6])
torch.Size([1, 5, 6]) torch.Size([1, 5, 6])


In [2]:
base = 10
# NOTE: d_model is odd, so arange(0, d_model, 2) yields ceil(15 / 2) = 8 frequencies;
# RoPE proper rotates feature pairs and needs an even d_model.
d_model = 15
# Fall back to CPU so the notebook also runs on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# theta_i = base ** (-2i / d_model): one rotation frequency per feature pair.
theta = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model)).to(device)
# Token positions 0..4.
seq_id = torch.arange(0, 5, device=device)

In [10]:
seq_id[None, :].shape, theta[:, None].shape  # row of positions [1, 5], column of frequencies [8, 1]

(torch.Size([1, 5]), torch.Size([8, 1]))

In [15]:
# Broadcasted outer product [8, 1] * [1, 5] -> [8, 5] (frequency x position);
# this is the transpose of a [seq_len, dim // 2] angle layout.
b = theta[:, None] * seq_id[None, :]
print(b.shape)

torch.Size([8, 5])


In [18]:
import torch

# Two tensors whose shapes broadcast against each other
a = torch.ones(1, 5)         # shape [1, 5]
b = torch.full((8, 1), 2.0)  # shape [8, 1], every entry is 2
print(b)
print(a)
# Element-wise multiplication broadcasts both operands to [8, 5]
c = a * b

print(c.shape)  # torch.Size([8, 5])
print(c)

tensor([[2.],
        [2.],
        [2.],
        [2.],
        [2.],
        [2.],
        [2.],
        [2.]])
tensor([[1., 1., 1., 1., 1.]])
torch.Size([8, 5])
tensor([[2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2.]])
