In [5]:
!pip install einops

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0


In [6]:
import torch
import torch.nn as nn

from einops.layers.torch import Rearrange

In [7]:
class RotaryEmbedding(nn.Module):
    """Learnably-scaled sinusoidal positional features, meant to be added to embeddings.

    Despite the name, this is not RoPE (which rotates query/key pairs inside
    attention). Each position ``p`` (shifted by ``+0.5``) is mapped to::

        [sin(p * w_0) * s_0, ..., sin(p * w_{h-1}) * s_{h-1},
         cos(p * w_0) * s_0, ..., cos(p * w_{h-1}) * s_{h-1}]

    where ``h = embed_size // 2``, ``w_i = base ** (-i / h)`` is a geometric
    frequency schedule running from 1 down toward ``1 / base``, and ``s_i`` is
    a learnable per-frequency scale (initialised to 1).

    Args:
        embed_size: Size of the output feature dimension. Must be a positive
            even integer so the sin and cos halves fill it exactly.
        base: Base of the geometric frequency schedule (default 10000).

    Raises:
        ValueError: If ``embed_size`` is not a positive even integer.
    """

    def __init__(self, embed_size, base=10000.0):
        super().__init__()
        if embed_size <= 0 or embed_size % 2 != 0:
            raise ValueError(
                f"embed_size must be a positive even integer, got {embed_size}"
            )
        self.embed_size = embed_size
        half = embed_size // 2
        # A geometric, decreasing schedule stays finite for any embed_size.
        # The former `2 ** linspace(0, half - 1, half)` overflowed float32 to inf
        # once half > 128 (e.g. the default 768), and sin(inf) is NaN.
        freqs = base ** (-torch.arange(half, dtype=torch.float32) / half)
        # A buffer (unlike a plain attribute) follows .to(device) / .half().
        # It is non-persistent so state_dict keys stay exactly {"scale"}.
        self.register_buffer("freqs", freqs, persistent=False)
        self.scale = nn.Parameter(torch.ones(half))

    def forward(self, x):
        """Map positions to positional features.

        Args:
            x: Tensor of positions with any shape ``(...)``; integer or float.

        Returns:
            Tensor of shape ``(..., embed_size)``.
        """
        # Positions are shifted by half a step (this offset does not center them).
        angles = (x + 0.5).unsqueeze(-1) * self.freqs  # (..., embed_size // 2)
        # The scale has one entry per frequency, so apply it to each half
        # rather than to the 2x-wide concatenation.
        return torch.cat(
            [torch.sin(angles) * self.scale, torch.cos(angles) * self.scale], dim=-1
        )

In [8]:
class GPT2WithRotary(nn.Module):
    """Decoder-style language model built from ``GPT2Block`` layers.

    Pipeline: token embedding, plus positional features from
    ``RotaryEmbedding``, then ``num_layers`` ``GPT2Block`` modules applied in
    sequence, then a linear projection to ``vocab_size`` unnormalised scores.

    Args:
        vocab_size: Number of token ids (embedding rows / output scores).
        embed_size: Hidden width used throughout the model.
        heads: Forwarded unchanged to every ``GPT2Block``.
        num_layers: Number of ``GPT2Block`` layers stacked in ``self.blocks``.
        ff_hidden_size: Forwarded unchanged to every ``GPT2Block``.
        dropout: Forwarded unchanged to every ``GPT2Block``.

    NOTE(review): ``GPT2Block`` is defined elsewhere in the notebook; causal
    masking, if any, presumably happens inside it. Confirm this.
    Unlike the reference GPT-2, there is no final LayerNorm before ``fc``, no
    dropout on the summed embeddings, and no weight tying between ``fc`` and
    ``embedding``. Confirm these differences are intended.
    """

    def __init__(self, vocab_size, embed_size=768, heads=12, num_layers=12, ff_hidden_size=3072, dropout=0.1):
        super().__init__()
        # Creation order matters: it fixes the order in which random weight
        # initialisation runs, and the state_dict key order.
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rotary_embedding = RotaryEmbedding(embed_size)
        layers = []
        for _ in range(num_layers):
            layers.append(GPT2Block(embed_size, heads, ff_hidden_size, dropout))
        self.blocks = nn.Sequential(*layers)
        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(self, x):
        """Return output-layer scores for the token ids in ``x``.

        Positions ``0 .. L-1`` are generated from dimension 1 of the embedded
        tensor. ``x`` is presumably ``(batch, seq)`` token ids; TODO confirm.
        """
        hidden = self.embedding(x)
        seq_len = hidden.size(1)
        positions = torch.arange(seq_len, device=hidden.device).float()
        hidden = hidden + self.rotary_embedding(positions)
        return self.fc(self.blocks(hidden))