In [None]:
import torch  # Importing the PyTorch library

# Defining a custom PyTorch module for Rotary Position Embeddings
class Rotary(torch.nn.Module):
    """Compute and cache the cos/sin tables used by rotary position embeddings.

    The module stores one inverse frequency per pair of feature dimensions:
    ``inv_freq[j] = 1 / base ** (2j / dim)`` for ``j = 0, 1, ...``.
    For ``dim = 8`` and ``base = 10000`` this is
    ``tensor([1.0, 0.1, 0.01, 0.001])``.

    Args:
        dim: Size of the feature dimension the rotation is applied to.
        base: Base of the geometric progression of frequencies.
    """

    def __init__(self, dim, base=10000):
        super().__init__()

        # theta_j in the RoPE formulation: one frequency per dimension pair.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

        # A buffer is part of the module state (follows .to()/.cuda(), is
        # saved in state_dict) but is not a trainable parameter.
        self.register_buffer("inv_freq", inv_freq)

        # Cache for the most recently computed tables; filled by forward().
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        """Return the cosine and sine tables for the sequence length of ``x``.

        Only ``x.shape[seq_dim]`` and ``x.device`` are read from ``x``; its
        values and remaining dimensions are not used.

        Args:
            x: Input tensor whose sequence length and device are used.
            seq_dim: Index of the sequence dimension of ``x`` (default 1).

        Returns:
            A ``(cos, sin)`` tuple. Each tensor has shape
            ``(seq_len, 1, 1, 2 * len(inv_freq))`` (i.e. ``(seq_len, 1, 1, dim)``
            for even ``dim``) and lives on ``x.device``. Row ``m`` holds the
            cosine / sine of ``m * inv_freq`` repeated twice along the last
            dimension, matching the half-split layout used by ``rotate_half``.
        """
        seq_len = x.shape[seq_dim]

        # Recompute when the length changes, and also when the cached tables
        # live on a different device than x; otherwise a module reused across
        # devices would hand back tensors on the stale device.
        cache_is_stale = (
            seq_len != self.seq_len_cached
            or self.cos_cached is None
            or self.cos_cached.device != x.device
        )

        if cache_is_stale:
            self.seq_len_cached = seq_len

            # Positions m = 0, 1, ..., seq_len - 1 in the dtype of inv_freq.
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)

            # Outer product: freqs[m, j] = m * inv_freq[j].
            # With inv_freq = [1.0, 0.1, 0.01, 0.001] and seq_len = 3:
            #   [[0.0, 0.0, 0.00, 0.000],
            #    [1.0, 0.1, 0.01, 0.001],
            #    [2.0, 0.2, 0.02, 0.002]]
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)

            # Duplicate the angles so that element j and element j + dim/2
            # share the same angle; row 1 above becomes
            #   [1.0, 0.1, 0.01, 0.001, 1.0, 0.1, 0.01, 0.001]
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            # Two singleton axes are inserted after the sequence axis so the
            # tables broadcast against 4-D query/key tensors.
            self.cos_cached = emb.cos()[:, None, None, :]
            self.sin_cached = emb.sin()[:, None, None, :]

        return self.cos_cached, self.sin_cached

# Define a helper function for rotating the second half of a tensor
def rotate_half(x):
    """Split the last dimension of ``x`` in two and return ``(-second, first)``.

    The first ``x.shape[-1] // 2`` elements form the first half and the rest
    form the second half; the result has the same shape as ``x``.
    """
    midpoint = x.shape[-1] // 2
    first = x[..., :midpoint]
    second = x[..., midpoint:]
    # Explicit non-negative axis: dim=-1 triggers a bug in torch < 1.8.0
    last_axis = x.dim() - 1
    return torch.cat((-second, first), dim=last_axis)

# Define a function for applying rotary positional embeddings
@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply the rotary transform ``x * cos + rotate_half(x) * sin`` to q and k.

    Returns a tuple ``(q_rotated, k_rotated)``. ``cos`` and ``sin`` are
    combined with ``q`` and ``k`` elementwise, so they must broadcast
    against them.
    """
    q_rotated = (q * cos) + (rotate_half(q) * sin)
    k_rotated = (k * cos) + (rotate_half(k) * sin)
    return q_rotated, k_rotated


The `emb` tensor built in `Rotary.forward` is the result of concatenating `freqs` with itself along the last dimension. For `dim = 8`, `base = 10000` (so `inv_freq = [1.0, 0.1, 0.01, 0.001]`) and a sequence of length 10, its values are (rounded to four decimals):

```python
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.1000, 0.0100, 0.0010, 1.0000, 0.1000, 0.0100, 0.0010],
        [2.0000, 0.2000, 0.0200, 0.0020, 2.0000, 0.2000, 0.0200, 0.0020],
        [3.0000, 0.3000, 0.0300, 0.0030, 3.0000, 0.3000, 0.0300, 0.0030],
        [4.0000, 0.4000, 0.0400, 0.0040, 4.0000, 0.4000, 0.0400, 0.0040],
        [5.0000, 0.5000, 0.0500, 0.0050, 5.0000, 0.5000, 0.0500, 0.0050],
        [6.0000, 0.6000, 0.0600, 0.0060, 6.0000, 0.6000, 0.0600, 0.0060],
        [7.0000, 0.7000, 0.0700, 0.0070, 7.0000, 0.7000, 0.0700, 0.0070],
        [8.0000, 0.8000, 0.0800, 0.0080, 8.0000, 0.8000, 0.0800, 0.0080],
        [9.0000, 0.9000, 0.0900, 0.0090, 9.0000, 0.9000, 0.0900, 0.0090]], device='cuda:0')
```

- This is a 2D tensor with 10 rows and 8 columns.
- Each row corresponds to a position `m` in the sequence (0 to 9), and each column corresponds to one feature dimension.
- The entries are rotation angles in radians, not yet embeddings: the cosine and sine of these angles are taken afterwards to build `cos_cached` and `sin_cached`.

Let's break down one row (e.g., the row for position 1) to understand it in more detail:

```
[1.0000, 0.1000, 0.0100, 0.0010, 1.0000, 0.1000, 0.0100, 0.0010]
```

- This row holds the angles for position `m = 1`.
- The first four values `[1.0000, 0.1000, 0.0100, 0.0010]` are `m * inv_freq`, one angle per frequency.
- The next four values `[1.0000, 0.1000, 0.0100, 0.0010]` are the same angles repeated, because `freqs` is concatenated with itself.
- As a result, column `j` and column `j + 4` share the same angle. These are exactly the two elements that `rotate_half` pairs up, so each such pair of features is rotated by one angle. The row for position 0 is all zeros, meaning no rotation is applied at the first position.

Overall, this tensor contains the rotation angles for every position in the sequence: row `m` is `m * inv_freq` repeated twice. The specific values depend only on the position, on `dim`, and on `base`.

 Let's use an example to illustrate how the `rotate_half` and `apply_rotary_pos_emb` functions work in practice.

**Example Scenario:**
Suppose we have a transformer model that processes a sequence of words or tokens. Each word is represented as a vector, and we want to add rotary positional embeddings to these vectors to capture the position information.

We'll start with a simple example where our sequence has only 4 positions, and we'll use a 2-dimensional vector representation for simplicity.

```python
# Example input data (sequence of vectors)
q = torch.tensor([
    [0.1, 0.2],  # Position 0
    [0.3, 0.4],  # Position 1
    [0.5, 0.6],  # Position 2
    [0.7, 0.8]   # Position 3
])

# Cosine and sine components of rotary position embeddings
cos = torch.tensor([
    [0.0, 1.0],  # Position 0
    [0.6, 0.8],  # Position 1
    [0.9, 0.4],  # Position 2
    [0.3, 0.9]   # Position 3
])

sin = torch.tensor([
    [1.0, 0.0],  # Position 0
    [0.8, 0.6],  # Position 1
    [0.4, 0.9],  # Position 2
    [0.9, 0.3]   # Position 3
])
```

Now, let's apply the `rotate_half` and `apply_rotary_pos_emb` functions to this example.

```python
# Apply rotary positional embeddings using apply_rotary_pos_emb
result_q, result_k = apply_rotary_pos_emb(q, q, cos, sin)

print("Result (Query):")
print(result_q)

print("Result (Key):")
print(result_k)
```

**Output:**

```
Result (Query):
tensor([[-0.2000,  0.2000],  # Position 0
        [-0.1400,  0.5000],  # Position 1 (Rotated)
        [ 0.2100,  0.6900],  # Position 2 (Rotated)
        [-0.5100,  0.9300]]) # Position 3 (Rotated)

Result (Key):
tensor([[-0.2000,  0.2000],  # Position 0
        [-0.1400,  0.5000],  # Position 1 (Rotated)
        [ 0.2100,  0.6900],  # Position 2 (Rotated)
        [-0.5100,  0.9300]]) # Position 3 (Rotated)
```

Here's what's happening step by step:

1. We have an input tensor `q` representing word vectors at different positions in the sequence. We also have cosine (`cos`) and sine (`sin`) components of rotary position embeddings for each position.

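2. We apply `apply_rotary_pos_emb` to the query and key tensors (in this example `q` is passed for both). For each position it computes `x * cos + rotate_half(x) * sin` element by element, once for the query and once for the key. For position 1, for instance, `[0.3, 0.4] * [0.6, 0.8] + [-0.4, 0.3] * [0.8, 0.6] = [-0.14, 0.50]`.

3. The `rotate_half` function splits the last dimension into two halves `(x1, x2)` and returns `(-x2, x1)`. When `cos` and `sin` come from the same angle, combining `x * cos` with `rotate_half(x) * sin` rotates each pair `(x1[j], x2[j])` by that angle, which is how the position is encoded. Note that the `cos` and `sin` values in this toy example are arbitrary illustrative numbers, not the cosine and sine of a common angle as the `Rotary` module would produce.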

4. The resulting `result_q` and `result_k` tensors contain the query and key vectors with rotary positional embeddings applied.

In this example, we've effectively added positional information to the input vectors using rotary positional embeddings, which can help the model understand the position of each word in the sequence during processing.

References : https://blog.eleuther.ai/rotary-embeddings/