In [1]:
import torch
import torch.nn.functional as F


Rotary positional encoding (RoPE)
GoGLU: GELU-gated linear unit, used as the feed-forward layer

In [2]:
def apply_rope(x, seq_dim, dim_head, base=10000.0):
    """Apply rotary positional embedding (RoPE) to ``x``.

    Uses the "rotate-half" convention: the last dimension is split into two
    halves ``x1, x2`` and each pair ``(x1[i], x2[i])`` at sequence position
    ``p`` is rotated by the angle ``p * base ** (-2 * i / dim_head)``.
    Rotations preserve vector norms, position 0 is left unchanged, and the
    dot product of two rotated vectors depends only on their relative offset.

    Args:
        x: tensor whose last dimension has size ``dim_head`` and whose
            dimension ``seq_dim`` indexes sequence positions 0, 1, 2, ...
        seq_dim: index of the sequence dimension of ``x`` (negative allowed);
            must not be the last (feature) dimension.
        dim_head: size of the rotated feature dimension; must be even and
            equal to ``x.shape[-1]``.
        base: frequency base of the rotation angles (10000 by default).

    Returns:
        Tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: if ``dim_head`` is odd, differs from ``x.shape[-1]``, or
            ``seq_dim`` refers to the last dimension.
    """
    if dim_head % 2 != 0:
        raise ValueError(f"dim_head must be even for RoPE, got {dim_head}")
    if x.shape[-1] != dim_head:
        raise ValueError(f"x.shape[-1]={x.shape[-1]} does not match dim_head={dim_head}")
    seq_dim = seq_dim % x.ndim
    if seq_dim == x.ndim - 1:
        raise ValueError("seq_dim must not be the feature (last) dimension")

    seq_len = x.shape[seq_dim]
    half = dim_head // 2
    pos = torch.arange(seq_len, dtype=torch.float32, device=x.device)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim_head, 2, dtype=torch.float32, device=x.device) / dim_head))
    angles = torch.outer(pos, inv_freq)           # (seq_len, dim_head // 2)
    angles = torch.cat((angles, angles), dim=-1)  # x1[i] and x2[i] share the same angle

    # Shape the tables so they broadcast along seq_dim and the feature dim only,
    # wherever seq_dim sits in x.
    table_shape = [1] * x.ndim
    table_shape[seq_dim] = seq_len
    table_shape[-1] = dim_head
    cos = angles.cos().reshape(table_shape).to(x.dtype)
    sin = angles.sin().reshape(table_shape).to(x.dtype)

    x1, x2 = x[..., :half], x[..., half:]
    rotated = torch.cat((-x2, x1), dim=-1)
    # Per pair: (x1*cos - x2*sin, x2*cos + x1*sin) -- a 2-D rotation.
    return x * cos + rotated * sin

In [3]:
class GoGLU(torch.nn.Module):
    """GELU-gated linear unit: ``proj2(x) * gelu(proj1(x))``.

    Maps ``input_dim`` features to ``hidden_dim`` features -- the output
    width is ``hidden_dim``, not ``input_dim``.

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dim: size of the last dimension of the output.
    """

    def __init__(self, input_dim, hidden_dim):
        super(GoGLU, self).__init__()
        # proj1 is created before proj2 so seeded initialization stays the same.
        self.proj1 = torch.nn.Linear(input_dim, hidden_dim)  # gate branch (passed through GELU)
        self.proj2 = torch.nn.Linear(input_dim, hidden_dim)  # value branch

    def forward(self, x):
        value = self.proj2(x)
        return value * F.gelu(self.proj1(x))

In [4]:
class GroupHeadedAttention(torch.nn.Module):
    """Grouped-query self-attention with rotary position embeddings.

    The ``num_heads`` query heads are split into ``num_heads // group_size``
    groups, and all query heads in a group share one key/value head.
    ``group_size == 1`` is ordinary multi-head attention; ``group_size ==
    num_heads`` is multi-query attention.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of query heads; must be divisible by ``group_size``.
        group_size: number of query heads sharing each key/value head.
        causal: if True (default, as a GPT needs), position ``t`` only
            attends to positions ``<= t``.

    Raises:
        ValueError: on non-divisible sizes or an odd head dimension (RoPE
            rotates feature pairs, so it needs an even head_dim).
    """

    def __init__(self, embed_dim, num_heads, group_size, causal=True):
        super(GroupHeadedAttention, self).__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim={embed_dim} is not divisible by num_heads={num_heads}")
        if num_heads % group_size != 0:
            raise ValueError(f"num_heads={num_heads} is not divisible by group_size={group_size}")
        self.num_heads = num_heads
        self.group_size = group_size
        self.head_dim = embed_dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError(f"head_dim={self.head_dim} must be even for rotary embeddings")
        self.num_kv_heads = num_heads // group_size
        self.causal = causal

        kv_dim = self.num_kv_heads * self.head_dim
        # Fused projection laid out as [queries | keys | values]; with
        # group_size == 1 it has 3 * embed_dim outputs, as before.
        self.qkv_proj = torch.nn.Linear(embed_dim, embed_dim + 2 * kv_dim)
        self.out_proj = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Attend over ``x`` of shape (batch, seq_len, embed_dim); returns the same shape."""
        batch_size, seq_len, embed_dim = x.size()
        q_dim = self.num_heads * self.head_dim
        kv_dim = self.num_kv_heads * self.head_dim
        q, k, v = self.qkv_proj(x).split([q_dim, kv_dim, kv_dim], dim=-1)

        # -> (batch, heads, seq_len, head_dim)
        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Rotate along the sequence axis (dim 2), not along an axis numbered seq_len.
        q = apply_rope(q, 2, self.head_dim)
        k = apply_rope(k, 2, self.head_dim)

        # Query head h reads key/value head h // group_size.
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)

        attn_scores = torch.einsum('bhnd,bhmd->bhnm', q, k) * self.head_dim ** -0.5
        if self.causal:
            # True above the diagonal = future positions that must not be attended.
            future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(1)
            attn_scores = attn_scores.masked_fill(future, float('-inf'))
        attn_probs = F.softmax(attn_scores, dim=-1)

        # Weighted sum of values, then merge heads back into embed_dim.
        attn_output = torch.einsum('bhnm,bhmd->bhnd', attn_probs, v)
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, q_dim)

        return self.out_proj(attn_output)

In [None]:
class TransformerBlock(torch.nn.Module):
    """Post-norm transformer block.

    Two sublayers, each applied as ``x = LayerNorm(x + sublayer(x))``:
    grouped-query self-attention, then a GoGLU feed-forward layer followed by
    a projection back to ``embed_dim``.

    Args:
        embed_dim: model width.
        num_heads: number of attention (query) heads.
        group_size: query heads per shared key/value head, passed to
            GroupHeadedAttention.
        ff_dim: hidden width of the GoGLU feed-forward layer.
    """

    def __init__(self, embed_dim, num_heads, group_size, ff_dim):
        super(TransformerBlock, self).__init__()
        self.attention = GroupHeadedAttention(embed_dim, num_heads, group_size)
        self.norm1 = torch.nn.LayerNorm(embed_dim)
        self.ff = GoGLU(embed_dim, ff_dim)
        # GoGLU outputs ff_dim features; project back so the residual add matches embed_dim.
        self.ff_down = torch.nn.Linear(ff_dim, embed_dim)
        self.norm2 = torch.nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply both sublayers; the output has the same shape as ``x``."""
        attn_out = self.attention(x)
        x = self.norm1(x + attn_out)
        ff_out = self.ff_down(self.ff(x))
        x = self.norm2(x + ff_out)

        return x

In [None]:
class GPTModel(torch.nn.Module):
    """Decoder stack: token embedding, ``num_layers`` TransformerBlocks, and a
    linear head mapping each position to ``vocab_size`` logits.

    There is no learned positional-embedding table here; position information
    is injected inside attention via rotary embeddings.

    Args:
        vocab_size: number of token ids (embedding rows / head outputs).
        embed_dim: model width.
        num_layers: number of stacked TransformerBlocks.
        num_heads: attention heads per block.
        group_size: query heads per shared key/value head.
        ff_dim: feed-forward hidden width per block.
    """

    def __init__(self, vocab_size, embed_dim, num_layers, num_heads, group_size, ff_dim):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, embed_dim)
        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerBlock(embed_dim, num_heads, group_size, ff_dim))
        self.layers = torch.nn.ModuleList(blocks)
        self.lm_head = torch.nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Map integer token ids ``x`` to logits with a trailing ``vocab_size`` dim."""
        hidden = self.embed(x)
        for block in self.layers:
            hidden = block(hidden)
        logits = self.lm_head(hidden)
        return logits

In [2]:
# Tokenizer cell. Hugging Face `transformers` is an extra dependency that is not
# needed by the model code above; fail with an actionable message if it is missing.
try:
    from transformers import AutoTokenizer
except ModuleNotFoundError as err:
    raise ModuleNotFoundError(
        "This cell needs Hugging Face `transformers`; install it into the kernel's "
        "environment with `%pip install transformers`, then re-run the notebook."
    ) from err

# NOTE(review): GPTModel's vocab_size should presumably match this tokenizer's
# vocabulary (e.g. len(tokenizer)) -- confirm when wiring the model to the tokenizer.
tokenizer = AutoTokenizer.from_pretrained('gpt2')

ModuleNotFoundError: No module named 'transformers'