In [1]:
from math import pi, log
import torch
from torch import nn, einsum
from einops import rearrange, repeat

# Helper functions
def exists(val):
    """Tell whether `val` holds a value, i.e. is anything other than None."""
    if val is None:
        return False
    return True

def broadcat(tensors, dim = -1):
    """Concatenate tensors along `dim` after broadcasting every other axis.

    All tensors must have the same number of dimensions. On each axis other
    than `dim`, every tensor must either have size 1 or the largest size found
    on that axis; size-1 axes are expanded to the largest size. Sizes along
    `dim` are left untouched and may differ between tensors.

    Args:
        tensors: sequence of tensors to concatenate.
        dim: axis to concatenate along; negative values count from the end.

    Returns:
        The result of `torch.cat` over the expanded tensors along `dim`.

    Raises:
        AssertionError: if the tensors differ in number of dimensions, or if
            a non-concatenation axis holds sizes that cannot be broadcast.
        IndexError: if `dim` is outside the valid range for the tensors.
    """
    tensors = list(tensors)

    ranks = {len(t.shape) for t in tensors}
    assert len(ranks) == 1, 'tensors must all have the same number of dimensions'
    rank = ranks.pop()

    dim = (dim + rank) if dim < 0 else dim
    if not 0 <= dim < rank:
        # without this guard a too-negative `dim` would fall through to Python's
        # negative list indexing and silently pick the wrong axis
        raise IndexError(f'dimension out of range for tensors with {rank} dimensions')

    # sizes_per_axis[i] holds the size of axis i for every tensor
    sizes_per_axis = list(zip(*(tuple(t.shape) for t in tensors)))

    target_sizes = []
    for axis, sizes in enumerate(sizes_per_axis):
        if axis == dim:
            # placeholder, each tensor keeps its own size on the concat axis
            target_sizes.append(None)
            continue

        largest = max(sizes)
        # only size-1 axes can be expanded, so anything else must already match
        assert all(size == largest or size == 1 for size in sizes), 'invalid dimensions for broadcastable concatenation'
        target_sizes.append(largest)

    expanded = []
    for tensor in tensors:
        shape = [tensor.shape[axis] if axis == dim else size for axis, size in enumerate(target_sizes)]
        expanded.append(tensor.expand(*shape))

    return torch.cat(expanded, dim = dim)

# Rotary embedding helper functions
def rotate_half(x):
    """Map every adjacent feature pair (a, b) on the last axis to (-b, a).

    The last axis is viewed as consecutive pairs, each pair is swapped with
    the new first element negated, and the pairs are flattened back so the
    output has the same shape as the input.
    """
    pairs = rearrange(x, '... (d r) -> ... d r', r = 2)
    first, second = pairs[..., 0], pairs[..., 1]
    swapped = torch.stack((-second, first), dim = -1)
    return rearrange(swapped, '... d r -> ... (d r)')

def apply_rotary_emb(freqs, t, start_index = 0, scale = 1.):
    """Rotate a window of the last axis of `t` by the angles in `freqs`.

    The features `t[..., start_index : start_index + freqs.shape[-1]]` are
    replaced by `t * cos(freqs) * scale + rotate_half(t) * sin(freqs) * scale`;
    features to the left and right of that window pass through unchanged.

    Args:
        freqs: rotation angles; cast to the dtype and device of `t`. Its last
            axis sets the width of the rotated window.
        t: tensor whose last axis holds the features to rotate.
        start_index: first feature index of the rotated window.
        scale: multiplier applied to the rotated window (number or tensor).

    Returns:
        A tensor with the same shape as `t`.

    Raises:
        AssertionError: if the window `[start_index, start_index + rot_dim)`
            does not fit inside the last axis of `t`.
    """
    freqs = freqs.to(t)
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    # the whole window has to fit, not just its width: checking only `rot_dim`
    # let an overrunning window through and failed later with a shape error
    assert end_index <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim} starting at index {start_index}'

    t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    return torch.cat((t_left, t, t_right), dim = -1)

# Learned rotation helpers
def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
    """Apply the rotation angles in `rotations` to `t`.

    When `freq_ranges` is given, every angle is first multiplied by every
    entry of `freq_ranges` and the two trailing axes of that product are
    merged into one. Each angle is then repeated twice in a row along the
    last axis before being handed to `apply_rotary_emb`.
    """
    angles = rotations

    if freq_ranges is not None:
        # outer product against the frequency ranges adds a trailing axis `f`
        spread = einsum('..., f -> ... f', angles, freq_ranges)
        angles = rearrange(spread, '... r f -> ... (r f)')

    # one angle per feature pair -> one angle per feature
    paired_angles = repeat(angles, '... n -> ... (n r)', r = 2)
    return apply_rotary_emb(paired_angles, t, start_index = start_index)

# Classes
class RotaryEmbedding(nn.Module):
    """Produces rotary-embedding angles and applies them to queries and keys.

    Args:
        dim: number of features to rotate; `dim // 2` frequencies are created
            for the 'lang' and 'pixel' modes.
        custom_freqs: tensor of frequencies to use as-is; takes precedence
            over `freqs_for`.
        freqs_for: 'lang' -> `1 / theta ** (arange(0, dim, 2) / dim)`,
            'pixel' -> `linspace(1, max_freq / 2, dim // 2) * pi`,
            'constant' -> `ones(num_freqs)`.
        theta: base used by the 'lang' frequencies.
        max_freq: upper bound parameter of the 'pixel' frequencies.
        num_freqs: number of frequencies in 'constant' mode.
        learned_freq: whether the frequencies receive gradients.
        use_xpos: enable the per-position scale applied by
            `rotate_queries_and_keys`.
        xpos_scale_base: divisor applied to centred positions when computing
            the xpos scale exponent.
        interpolate_factor: positions are divided by this value; must be >= 1.
        theta_rescale_factor: theta is multiplied by
            `theta_rescale_factor ** (dim / (dim - 2))`.

    Raises:
        ValueError: if `freqs_for` is not one of the known modes.
        AssertionError: if `interpolate_factor < 1`.
    """

    def __init__(
        self,
        dim,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
        learned_freq = False,
        use_xpos = False,
        xpos_scale_base = 512,
        interpolate_factor = 1.,
        theta_rescale_factor = 1.
    ):
        super().__init__()

        # The rescale exponent dim / (dim - 2) is undefined for dim == 2. In that
        # case the only 'lang' frequency is 1 / theta ** 0 == 1 whatever theta is,
        # so skipping the rescale is exact and avoids a ZeroDivisionError.
        if dim != 2:
            theta *= theta_rescale_factor ** (dim / (dim - 2))

        if custom_freqs is not None:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        # cache       : cache_key -> angles returned by `forward`
        # cache_scale : cache_key -> scales returned by `get_scale`
        self.cache = dict()
        self.cache_scale = dict()
        self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)

        # positions are divided by this factor in `get_seq_pos`
        assert interpolate_factor >= 1.
        self.interpolate_factor = interpolate_factor

        # xpos: per-frequency decay base, only materialised when enabled
        self.use_xpos = use_xpos
        if not use_xpos:
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.scale_base = xpos_scale_base
        self.register_buffer('scale', scale)

    @staticmethod
    def _align_to_seq_dim(pos_emb, ndim, seq_dim):
        # Reshape a (seq_len, feature) tensor to (seq_len, 1, ..., 1, feature) so its
        # first axis broadcasts against axis `seq_dim` of an `ndim`-dimensional tensor.
        if seq_dim >= 0:
            seq_dim -= ndim
        gap = -seq_dim - 2
        if gap <= 0:
            return pos_emb
        return pos_emb.reshape(pos_emb.shape[0], *((1,) * gap), pos_emb.shape[-1])

    def get_seq_pos(self, seq_len, device, dtype, offset = 0):
        """Return positions `offset .. offset + seq_len - 1` divided by `interpolate_factor`."""
        return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor

    def rotate_queries_or_keys(self, t, seq_dim = -2, offset = 0):
        """Rotate `t` (queries or keys) by its sequence positions.

        Args:
            t: tensor whose last axis holds features and whose axis `seq_dim`
                holds sequence positions.
            seq_dim: axis of `t` that indexes the sequence.
            offset: value added to every position before rotating.

        Returns:
            The rotated tensor, same shape as `t`.

        Raises:
            AssertionError: if the module was built with `use_xpos = True`.
        """
        assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
        freqs = self.forward(lambda: self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset), cache_key = f'freqs:{seq_len}|offset:{offset}')
        # freqs come back as (seq_len, rot_dim); line the sequence axis up with `seq_dim`
        freqs = self._align_to_seq_dim(freqs, t.ndim, seq_dim)
        return apply_rotary_emb(freqs, t)

    def rotate_queries_and_keys(self, q, k, seq_dim = -2):
        """Rotate queries and keys together, applying the xpos scale.

        Queries are multiplied by the per-position scale and keys by its
        inverse. Positions are derived from the length of `q` along `seq_dim`.

        Returns:
            Tuple `(rotated_q, rotated_k)`.

        Raises:
            AssertionError: if the module was built with `use_xpos = False`.
        """
        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
        seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
        freqs = self.forward(lambda: seq, cache_key = f'freqs:{seq_len}')
        # a cached scale may predate a device move of the module, so move it as well as cast it
        scale = self.get_scale(lambda: seq, cache_key = f'scale:{seq_len}').to(device = device, dtype = dtype)
        freqs = self._align_to_seq_dim(freqs, q.ndim, seq_dim)
        scale = self._align_to_seq_dim(scale, q.ndim, seq_dim)
        rotated_q = apply_rotary_emb(freqs, q, scale = scale)
        rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1)
        return rotated_q, rotated_k

    def get_scale(self, t, cache_key = None):
        """Return the xpos scale for the 1-D positions `t`.

        Args:
            t: 1-D tensor of positions, or a zero-argument callable returning
                one (only invoked on a cache miss).
            cache_key: optional key under which the result is stored in
                `self.cache_scale`.

        Returns:
            Tensor of shape `(len(t), 2 * len(self.scale))`, where each
            per-frequency scale is repeated twice in a row.

        Raises:
            AssertionError: if the module was built with `use_xpos = False`.
        """
        assert self.use_xpos

        # scales live in their own dict so a key shared with `forward` cannot
        # hand back rotation angles in place of scales (or the reverse)
        if cache_key is not None and cache_key in self.cache_scale:
            return self.cache_scale[cache_key]

        if callable(t):
            t = t()

        # positions are centred on the middle of the sequence before scaling
        power = (t - len(t) // 2) / self.scale_base
        scale = self.scale ** power.unsqueeze(-1)
        # interleave (s0, s0, s1, s1, ...) to match the angle layout built in `forward`
        # and the adjacent pairing of `rotate_half`; both halves of a rotated pair
        # must share one scale
        scale = scale.repeat_interleave(2, dim = -1)

        if cache_key is not None:
            self.cache_scale[cache_key] = scale

        return scale

    def forward(self, t, cache_key = None):
        """Return rotation angles for positions `t`.

        Args:
            t: tensor of positions, or a zero-argument callable returning one
                (only invoked on a cache miss).
            cache_key: optional key under which the result is stored in
                `self.cache`. Ignored while the frequencies require gradients.

        Returns:
            Tensor of shape `(*t.shape, 2 * len(self.freqs))`: the outer
            product of positions and frequencies with every angle repeated
            twice in a row.
        """
        # Learnable frequencies change between optimiser steps, and a cached result
        # would also keep a stale autograd graph alive, so only cache fixed frequencies.
        use_cache = cache_key is not None and not self.freqs.requires_grad

        if use_cache and cache_key in self.cache:
            return self.cache[cache_key]

        if callable(t):
            t = t()

        freqs = self.freqs

        freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
        # one angle per feature pair -> one angle per feature
        freqs = freqs.repeat_interleave(2, dim = -1)

        if use_cache:
            self.cache[cache_key] = freqs

        return freqs

Certainly, let's break down the provided code step by step with detailed explanations:

**Step 1: Import Dependencies**
```python
from math import pi, log
import torch
from torch import nn, einsum
from einops import rearrange, repeat
```
- The code starts by importing necessary libraries:
  - `math.pi` and `math.log` for mathematical constants and functions.
  - `torch` for PyTorch functionality.
  - `nn` for PyTorch's neural network module.
  - `einsum` from PyTorch for Einstein summation notation.
  - `rearrange` and `repeat` from the `einops` library for tensor manipulation.

**Step 2: Define Helper Functions**
```python
def exists(val):
    return val is not None
```
- `exists(val)` is a helper function that checks if a value `val` exists (i.e., is not `None`). It returns `True` if the value exists, `False` otherwise.

```python
def broadcat(tensors, dim = -1):
    # Function to broadcast tensors along a specified dimension
    # ...
    return torch.cat(tensors, dim = dim)
```
- `broadcat(tensors, dim)` concatenates tensors along `dim` after broadcasting all of their other dimensions to a common shape.
  - It first checks that all tensors have the same number of dimensions.
  - Then, it normalizes `dim` so that a negative index counts from the last dimension.
  - Every dimension other than `dim` is expanded to the largest size found among the tensors, so size-1 dimensions are broadcast to match.
  - Finally, it concatenates the expanded tensors along `dim` using `torch.cat`.

**Step 3: Define Rotary Embedding Helper Functions**
```python
def rotate_half(x):
    # Function to rotate the input tensor by half along the last dimension
    # ...
    return rearrange(x, '... d r -> ... (d r)')
```
- `rotate_half(x)` is a function to rotate the input tensor `x` by half along its last dimension.
  - It rearranges the tensor `x` to group elements in pairs.
  - It unbinds the tensor into two parts, `x1` and `x2`.
  - It rotates the elements by swapping and negating them.
  - Finally, it rearranges and concatenates the rotated elements back into the original shape.

```python
def apply_rotary_emb(freqs, t, start_index = 0, scale = 1.):
    # Function to apply rotary embeddings to an input tensor 't' using specified frequencies 'freqs'
    # ...
    return torch.cat((t_left, t, t_right), dim = -1)
```
- `apply_rotary_emb(freqs, t, start_index, scale)` is a function to apply rotary embeddings to an input tensor `t` using the specified frequencies `freqs`.
  - It slices out the features of `t` from `start_index` to `start_index + freqs.shape[-1]`; features outside this range are left untouched.
  - The slice is rotated by computing `t * cos(freqs) + rotate_half(t) * sin(freqs)`, with both terms multiplied by `scale`.
  - The result has the same shape as `t`: the rotated slice is concatenated back between the untouched left and right features.

```python
def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
    # Function to apply learned rotations to an input tensor 't'
    # ...
    return apply_rotary_emb(rotations, t, start_index = start_index)
```
- `apply_learned_rotations(rotations, t, start_index, freq_ranges)` is a function to apply learned rotations to an input tensor `t`. It internally calls `apply_rotary_emb()` with the provided rotations.

**Step 4: Define Rotary Embedding Class**
```python
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, custom_freqs = None, freqs_for = 'lang', theta = 10000, max_freq = 10, num_freqs = 1, learned_freq = False, use_xpos = False, xpos_scale_base = 512, interpolate_factor = 1., theta_rescale_factor = 1.):
        # Constructor for RotaryEmbedding class
        # ...
```
- `RotaryEmbedding` is a PyTorch module for managing rotary embeddings.
- The constructor `__init__()` initializes various parameters and settings for rotary embeddings:
  - `dim`: The dimension of embeddings.
  - `custom_freqs`: Custom frequencies for rotary embeddings (optional).
  - `freqs_for`: Type of frequencies to use ('lang', 'pixel', 'constant').
  - `theta`: Base used to generate the 'lang' frequencies (`1 / theta ** (2i / dim)`); larger values give lower frequencies.
  - `max_freq`: Maximum frequency for pixel frequencies.
  - `num_freqs`: Number of frequencies for constant frequencies.
  - `learned_freq`: Whether to learn the frequencies.
  - `use_xpos`: Whether to use X-position (xpos) embeddings.
  - `xpos_scale_base`: Scaling factor for xpos embeddings.
  - `interpolate_factor`: Interpolation factor for sequence positions.
  - `theta_rescale_factor`: Rescaling factor for theta based on dimension.

**Step 5: Initialize Rotary Embeddings**
```python
        # Initialize rotary embedding parameters
        theta *= theta_rescale_factor ** (dim / (dim - 2))
```
- The code rescales `theta` based on the dimension to mitigate issues when applying rotary embeddings to longer sequences.

```python
        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')
```
- The code sets the frequencies for rotary embeddings based on different modalities:
  - `'lang'`: Based on language modeling, using decreasing frequencies.
  - `'pixel'`: For image data, using linearly spaced frequencies.
  - `'constant'`: Constant frequencies.
  - Custom frequencies can be provided using `custom_freqs`.

**Step 6: Initialize Cache and Parameters**
```python
        self.cache = dict()
        self.cache_scale = dict()
        self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
```
- The code initializes caches for storing computed values and wraps `freqs` in an `nn.Parameter`, which is trainable only when `learned_freq` is `True`.

```python
        # interpolation factors
        assert interpolate_factor >= 1.
        self.interpolate_factor = interpolate_factor
```
- The code sets the interpolation factor for sequence positions. It ensures that `interpolate_factor` is greater than or equal to 1.

```python
        # xpos
        self.use_xpos = use_xpos
        if not use_xpos:
            self.register_buffer('scale', None)
            return
```
- The code sets whether X-position (xpos) embeddings will be used. If not, it initializes a buffer for scaling, but it remains `None`.

```python
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.scale_base = xpos_scale_base
        self.register_buffer('scale', scale)
```
- If `use_xpos` is `True`, the code calculates and sets the scaling factor based on the dimension and scaling parameters.

**Step 7: Define Sequence Position Calculation**
```python
    def get_seq_pos(self, seq_len, device, dtype, offset = 0):
        return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor
```
- The `get_seq_pos()` method calculates sequence positions given sequence length, device, data type, and an optional offset. It accounts for interpolation.

**Step 8: Define Rotation of Queries and Keys**
```python
    def rotate_queries_or_keys(self, t, seq_dim = -2, offset = 0):
        assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
        freqs = self.forward(lambda: self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset), cache_key = f'freqs:{seq_len}|offset:{offset}')
        return apply_rotary_emb(freqs, t)
```
- The `rotate_queries_or_keys()` method applies rotary embeddings to either queries or keys. However, it cannot be used for length-extrapolatable rotary embeddings, so it raises an assertion error if `use_xpos` is `True`.
- It calculates the rotary embeddings using frequencies obtained from `forward()` based on the sequence positions.
- The rotary embeddings are applied to the input tensor `t` using `apply_rotary_emb()`.

**Step 9: Define Rotation of Queries and Keys with X-Position**
```python
    def rotate_queries_and_keys(self, q, k, seq_dim = -2):
        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
        seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
        freqs = self.forward(lambda: seq, cache_key = f'freqs:{seq_len}')
        scale = self.get_scale(lambda: seq, cache_key = f'scale:{seq_len}').to(dtype)
        rotated_q = apply_rotary_emb(freqs, q, scale = scale)
        rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1)
        return rotated_q, rotated_k
```
- The `rotate_queries_and_keys()` method is used when `use_xpos` is `True`, and it rotates both queries and keys.
- It calculates sequence positions and frequencies.
- The `get_scale()` method is used to obtain the scaling factor.
- Rotary embeddings are applied to both queries (`q`) and keys (`k`) using `apply_rotary_emb()`.

**Step 10: Get Scaling Factor for X-Position**
```python
    def get_scale(self, t, cache_key = None):
        assert self.use_xpos

        if exists(cache_key) and cache_key in self.cache:
            return self.cache[cache_key]

        if callable(t):
            t = t()

        scale = 1.
        if self.use_xpos:
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, 'n -> n 1')
            scale = torch.cat((scale, scale), dim = -1)

        if exists(cache_key):
            self.cache[cache_key] = scale

        return scale
```
- The `get_scale()` method calculates the scaling factor for X-position (xpos) embeddings.
- It checks the cache to see if the scaling factor has already been computed.
- If not, it calculates the scaling factor based on the sequence positions.
- The computed scale is cached for future use.

**Step 11: Calculate Rotary Embeddings**
```python
    def forward(self, t, cache_key = None):
        if exists(cache_key) and cache_key in self.cache:
            return self.cache[cache_key]

        if callable(t):
            t = t()

        freqs = self.freqs

        freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)

        if exists(cache_key):
            self.cache[cache_key] = freqs

        return freqs
```
- The `forward()` method calculates rotary embeddings based on the provided frequencies and sequence positions.
- It checks the cache to see if the embeddings have already been computed.
- If not, it calculates the embeddings using Einstein summation and repeats the frequencies.
- The computed embeddings are cached for future use when a `cache_key` is provided.

This code defines a flexible and configurable rotary embedding module that can be used in various applications, including natural language processing and computer vision, to enhance positional information in data. It provides options for selecting different types of frequencies and interpolation strategies, making it adaptable to different use cases.

## Rotary Embeddings - Pytorch

A standalone library for adding <a href="https://arxiv.org/abs/2104.09864">rotary embeddings</a> to transformers in Pytorch, following its success as <a href="https://blog.eleuther.ai/rotary-embeddings/">relative positional encoding</a>. Specifically it will make rotating information into any axis of a tensor easy and efficient, whether they be fixed positional or learned. This library will give you state of the art results for positional embedding, at little cost.

My gut also tells me there is something <a href="https://www.nature.com/articles/s41593-021-00821-9">more</a> to rotations that can be exploited in artificial neural networks.

## Install

```bash
$ pip install rotary-embedding-torch
```

## Usage

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

# instantiate the positional embedding in your transformer and pass to all your attention layers

rotary_emb = RotaryEmbedding(dim = 32)

# mock queries and keys - dimensions should end with (seq_len, feature dimension), and any number of preceding dimensions (batch, heads, etc)

q = torch.randn(1, 8, 1024, 64) # queries - (batch, heads, seq len, dimension of head)
k = torch.randn(1, 8, 1024, 64) # keys

# apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention)

q = rotary_emb.rotate_queries_or_keys(q)
k = rotary_emb.rotate_queries_or_keys(k)

# then do your attention with your queries (q) and keys (k) as usual
```

If you do all the steps above correctly, you should see a dramatic improvement during training

## Axial Rotary Embeddings

For easy use of 2d axial relative positional embedding, ie. vision transformers

```python
import torch
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding, broadcat

pos_emb = RotaryEmbedding(
    dim = 32,
    freqs_for = 'pixel',
    max_freq = 256
)

# queries and keys for frequencies to be rotated into

q = torch.randn(1, 256, 256, 64)
k = torch.randn(1, 256, 256, 64)

# get frequencies for each axial
# -1 to 1 has been shown to be a good choice for images and audio

freqs_h = pos_emb(torch.linspace(-1, 1, steps = 256), cache_key = 256)
freqs_w = pos_emb(torch.linspace(-1, 1, steps = 256), cache_key = 256)

# concat the frequencies along each axial
# broadcat function makes this easy without a bunch of expands

freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)

# rotate in frequencies

q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)
```

## Length Extrapolatable Rotary Embeddings

In <a href="https://arxiv.org/abs/2212.10554v1">this paper</a>, they were able to fix length extrapolation issue with rotary embeddings by giving it a decay similar to ALiBi. They named this technique XPos, and you can use it by setting `use_xpos = True` on initialization.

This can only be used for autoregressive transformers

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

# instantiate the positional embedding in your transformer and pass to all your attention layers

rotary_emb = RotaryEmbedding(
    dim = 32,
    use_xpos = True   # set this to True to make rotary embeddings extrapolate better to sequence lengths greater than the one used at training time
)

# mock queries and keys - dimensions should end with (seq_len, feature dimension), and any number of preceding dimensions (batch, heads, etc)

q = torch.randn(1, 8, 1024, 64) # queries - (batch, heads, seq len, dimension of head)
k = torch.randn(1, 8, 1024, 64) # keys

# apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention)

# instead of using `rotate_queries_or_keys`, you will use `rotate_queries_and_keys`, the rest is taken care of

q, k = rotary_emb.rotate_queries_and_keys(q, k)
```

## Interpolating Sequence Positions

This MetaAI <a href="https://arxiv.org/abs/2306.15595">paper</a> proposes simply fine-tuning on interpolations of the sequence positions for extending to longer context length for pretrained models. They show this performs much better than simply fine-tuning on the same sequence positions but extended further.

You can use this by setting the `interpolate_factor` on initialization to a value greater than `1.` (ex. if pretrained model was trained on 2048, setting `interpolate_factor = 2.` would allow fine-tuning to `2048 x 2. = 4096`)

Update: someone in the community has reported that it does not work well. Please email me if you see either a positive or negative result.

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

rotary_emb = RotaryEmbedding(
    dim = 32,
    interpolate_factor = 2.    # add this line of code to pretrained model and fine-tune for ~1000 steps, as shown in paper
)
```

## Citations

```bibtex
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
```

```bibtex
@inproceedings{Sun2022ALT,
    title     = {A Length-Extrapolatable Transformer},
    author    = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
    year      = {2022}
}
```

```bibtex
@inproceedings{Chen2023ExtendingCW,
    title   = {Extending Context Window of Large Language Models via Positional Interpolation},
    author  = {Shouyuan Chen and Sherman Wong and Liangjian Chen and Yuandong Tian},
    year    = {2023}
}
```

```bibtex
@misc{bloc97-2023,
    title   = {NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.},
    author  = {/u/bloc97},
    url     = {https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/}
}
```

References : https://github.com/lucidrains/rotary-embedding-torch and https://github.com/lucidrains/rotary-embedding-torch/commit/947f26fb74d1b61f5c1da169e80f14cab4e94f00