In [6]:
import torch


def apply_rotary(x, sinusoidal_pos=None):
    if sinusoidal_pos is None:
        return x
    sin, cos = sinusoidal_pos
    # x.shape [batch, seq_len, 2]
    x1, x2 = x[..., 0::2], x[..., 1::2]
    # [cos_nθ, -sin_nθ] [x1]
    # [sin_nθ,  cos_nθ] [x2]
    # => [x1 * cos_nθ - x2 * sin_nθ, x1 * sin_nθ + x2 * cos_nθ]
    # 苏神的rotary，使用了下面的计算方法。
    # return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2, -1)
    # 考虑到矩阵乘法torch.einsum("bhmd,bhnd->bhmn", q, k)，因此可以直接在最后一个维度拼接（无需奇偶交错）
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

In [18]:
import numpy as np

dim = 120



array([[8.41470985e-01, 7.56337270e-01, 6.71063461e-01, 5.89918049e-01,
        5.15138753e-01, 4.47670835e-01, 3.87674234e-01, 3.34858341e-01,
        2.88695896e-01, 2.48555475e-01, 2.13780666e-01, 1.83735178e-01,
        1.57826640e-01, 1.35517378e-01, 1.16327471e-01, 9.98334166e-02,
        8.56644690e-02, 7.34978922e-02, 6.30538780e-02, 5.40905416e-02,
        4.63992235e-02, 3.98002019e-02, 3.41388540e-02, 2.92822593e-02,
        2.51162229e-02, 2.15426803e-02, 1.84774464e-02, 1.58482684e-02,
        1.35931453e-02, 1.16588799e-02, 9.99983333e-03, 8.57685383e-03,
        7.35635619e-03, 6.30953158e-03, 5.41166885e-03, 4.64157217e-03,
        3.98106119e-03, 3.41454224e-03, 2.92864038e-03, 2.51188379e-03,
        2.15443302e-03, 1.84784875e-03, 1.58489253e-03, 1.35935597e-03,
        1.16591414e-03, 9.99999833e-04, 8.57695793e-04, 7.35642188e-04,
        6.30957303e-04, 5.41169500e-04, 4.64158867e-04, 3.98107160e-04,
        3.41454881e-04, 2.92864452e-04, 2.51188641e-04, 2.154434

In [15]:
2 ** 2

4

In [52]:
class RotaryPositionalEmbeddings(nn.Module):
    """ 
    TODO: Implement RoPE introduced in the paper RoFormer: Enhanced Transformer with Rotary Position Embedding.
    Reference: https://arxiv.org/abs/2104.09864
    You will be implementing equation 34 in the paper (the computationally efficient form of the rotary matrix
    as covered in recitation).
    Refer to the "Example: Converting Math to PyTorch Code" slide from recitation if you need help translating 
    the equation into code.
    """

    def __init__(self, d: int, base: int = 10_000):
        super().__init__()
        assert d % 2 == 0
        self.base = base
        self.d = d

        self.position_encoding = 1.0 / (base ** ((torch.arange(1, self.d // 2 + 1).float() - 1) * 2 / dim))
        
        self.seq_len = None
        self.cos = None
        self.sin = None
        
    def _build_cache(self, x: torch.Tensor):
        """
        Compute the fixed variables that do not change during training (see recitation for more details).
        """

        self.seq_len = x.shape[-2]
        t = torch.arange(1, self.seq_len+1, device=x.device).type_as(self.position_encoding)
        emb = torch.einsum("i,j->ij", t, self.position_encoding)
        self.cos = emb.cos().view(1, 1, self.seq_len, self.d //2 )
        self.sin = emb.sin().view(1, 1, self.seq_len, self.d //2 )

        return self.cos, self.sin

    def forward(self, x: torch.Tensor):
        """
        Perform the forward pass with the input x, following equation 34 in the paper.
        """

        if x.shape[-2] != self.seq_len:
            self._build_cache(x)
        return self.cos, self.sin



In [60]:
embedding_dim = 16
seq_len = 10
batch_size = 4
head_num = 3

rotary_positional_embeddings = RotaryPositionalEmbeddings(embedding_dim)

# Create a dummy input tensor
input_features = torch.rand((batch_size, head_num, seq_len, embedding_dim))

# Forward pass
cos, sin = rotary_positional_embeddings(input_features)



x1, x2 = input_features[..., 0::2], input_features[..., 1::2]
output =  torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)


print(input_features.shape)
print(cos.shape)
print(output.shape)
        

torch.Size([4, 3, 10, 16])
torch.Size([1, 1, 10, 8])
torch.Size([4, 3, 10, 16])


In [58]:
dim = 100
base = 10000

(torch.arange(1, dim // 2 + 1).float() - 1) * 2 / dim
# ((torch.arange(1, self.d // 2 + 1).float() - 1) * 2
1.0 / (base ** ((torch.arange(1, self.d // 2 + 1).float() - 1) * 2 / self.d))



tensor([0.0000, 0.0200, 0.0400, 0.0600, 0.0800, 0.1000, 0.1200, 0.1400, 0.1600,
        0.1800, 0.2000, 0.2200, 0.2400, 0.2600, 0.2800, 0.3000, 0.3200, 0.3400,
        0.3600, 0.3800, 0.4000, 0.4200, 0.4400, 0.4600, 0.4800, 0.5000, 0.5200,
        0.5400, 0.5600, 0.5800, 0.6000, 0.6200, 0.6400, 0.6600, 0.6800, 0.7000,
        0.7200, 0.7400, 0.7600, 0.7800, 0.8000, 0.8200, 0.8400, 0.8600, 0.8800,
        0.9000, 0.9200, 0.9400, 0.9600, 0.9800])

In [None]:
class RotaryPositionalEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half):
        """Rotary positional embedding
        Reference : https://blog.eleuther.ai/rotary-embeddings/
        Paper: https://arxiv.org/pdf/2104.09864.pdf
        Args:
            dim: Dimension of embedding
            base: Base value for exponential
            precision: precision to use for numerical values
        """
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None
        self.precision = precision

    def forward(self, x, seq_len=None):
        """
        Args:
            x: Input x with T X B X C
            seq_len: Sequence length of input x
        """
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[:, None, None, :]
            self.sin_cached = emb.sin()[:, None, None, :]
        return self.cos_cached, self.sin_cached

In [29]:
dim = 100
base = 10000

torch.arange(1, dim+1, 2)
# base ** (torch.arange(0, dim, 2).float() / dim)
# 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

tensor([ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35,
        37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71,
        73, 75, 77, 79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99])

In [31]:
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
inv_freq

tensor([1.0000e+00, 8.3176e-01, 6.9183e-01, 5.7544e-01, 4.7863e-01, 3.9811e-01,
        3.3113e-01, 2.7542e-01, 2.2909e-01, 1.9055e-01, 1.5849e-01, 1.3183e-01,
        1.0965e-01, 9.1201e-02, 7.5858e-02, 6.3096e-02, 5.2481e-02, 4.3652e-02,
        3.6308e-02, 3.0200e-02, 2.5119e-02, 2.0893e-02, 1.7378e-02, 1.4454e-02,
        1.2023e-02, 1.0000e-02, 8.3176e-03, 6.9183e-03, 5.7544e-03, 4.7863e-03,
        3.9811e-03, 3.3113e-03, 2.7542e-03, 2.2909e-03, 1.9055e-03, 1.5849e-03,
        1.3183e-03, 1.0965e-03, 9.1201e-04, 7.5858e-04, 6.3096e-04, 5.2481e-04,
        4.3652e-04, 3.6308e-04, 3.0200e-04, 2.5119e-04, 2.0893e-04, 1.7378e-04,
        1.4454e-04, 1.2023e-04])

In [41]:
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(1, seq_len+1).type_as(inv_freq)
# t


freqs = torch.einsum("i,j->ij", t, inv_freq)
freqs
freqs.shape
# t.shape
# inv_freq.shape


torch.Size([10, 50])

In [45]:
emb = freqs
emb.shape

torch.Size([10, 50])

In [44]:
cos_cached = emb.cos()[:, None, None, :]
cos_cached.shape

torch.Size([10, 1, 1, 100])

In [46]:
z = t * inv_freq
z.shape

RuntimeError: The size of tensor a (10) must match the size of tensor b (50) at non-singleton dimension 0

In [None]:
# Create a dummy input tensor
input_features = torch.rand((10, 5, 10, 50))

## Graph


### 4.1


In [None]:
import json
import matplotlib.pyplot as plt

path = "/Users/yanjun/Documents/Chen/Master/24_Spring/GenAI/hw1/handout/out/4_1/train_logs.json"

# Load JSON data
data = json.loads(json_data)

# Extract the list of train_losses
train_losses = data["train_losses"]

# Plotting the line chart
plt.plot(train_losses, marker='o')
plt.title('Train Losses')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.show()
