
"""
Adapted from
[OLMo](https://github.com/allenai/OLMo),
[mistral](https://github.com/mistralai/mistral-src), and
[GPT-FAST](https://github.com/pytorch-labs/gpt-fast).
"""

In [7]:
from importlib.metadata import version
import torch

print("TORCH VERSION :", version("torch"))
# Choose the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print('GPU  : ', device.upper())

TORCH VERSION : 2.2.1
GPU  :  CUDA


In [13]:
from dataclasses import dataclass
from typing import Optional, Union

import torch.nn.functional as F
from torch import nn

In [39]:
## Preprocessing

In [16]:
@dataclass
class ModelArgs:
    """Hyper-parameters for the model and its training run.

    Values the notebook had not decided yet (previously blank or ``"?"``)
    are now ``None`` placeholders, or a valid placeholder for ``d_model``.
    Set them explicitly before building a model.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_head``, or if
            ``n_kv_head`` is set but does not evenly divide ``n_head``.
    """

    n_layer: int = 12
    # input
    d_model: int = 4096  # placeholder (was left blank); must be divisible by n_head
    seq_len: int = 500
    vocab_size: int = -1  # fill this after tokenization
    # Norm
    norm_eps: float = 1e-5
    # RoPE
    rope_base: float = 10000.0
    scaling_factor: float = 1.0
    # Attention
    n_head: int = 32
    n_kv_head: Optional[int] = None  # placeholder (was "?"); if set, must divide n_head
    sliding_window: Optional[int] = None  # placeholder (was "?")
    # FeedForward
    intermediate_size: Optional[int] = None
    # Training
    batch_size: Optional[int] = None
    # Annotated so it is a real (overridable) dataclass field rather than a
    # hidden class attribute; the default is still chosen at class creation.
    device: str = (
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )

    def __post_init__(self) -> None:
        """Validate that the attention head configuration is consistent."""
        if self.n_head <= 0 or self.d_model % self.n_head != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by a positive n_head ({self.n_head})"
            )
        if self.n_kv_head is not None and (
            self.n_kv_head <= 0 or self.n_head % self.n_kv_head != 0
        ):
            raise ValueError(
                f"n_kv_head ({self.n_kv_head}) must be positive and divide n_head ({self.n_head})"
            )

In [17]:
class RotaryEmbedding(nn.Module):
    """
    [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).

    Precomputes the unit complex rotations ``exp(i * m * theta_j)`` for
    positions ``m`` in ``[0, seq_len)`` and frequency index ``j`` in
    ``[0, h_dim / 2)``, and applies them to query/key tensors in ``forward``.

    Args:
        h_dim: per-head feature size; must be a positive even integer because
            features are rotated in consecutive pairs.
        seq_len: maximum number of positions to precompute.
        base: frequency base, ``theta_j = base ** (-2j / h_dim)``.
        device: device on which to build the rotation table.
        scaling_factor: positions are divided by this before computing angles.

    Raises:
        ValueError: if ``h_dim`` is not a positive even integer.
    """

    def __init__(self, h_dim: int, seq_len: int, base: Union[int, float] = 10000,
                 device: Optional[str] = None, scaling_factor: Union[int, float] = 1.0):
        super().__init__()
        if h_dim <= 0 or h_dim % 2 != 0:
            raise ValueError(f"h_dim must be a positive even integer, got {h_dim}")

        # Long-term decay: theta_j = base ** (-2j / h_dim), j = 0 .. h_dim/2 - 1.
        # float32 (not int16) so large seq_len / h_dim values cannot overflow.
        exponents = torch.arange(0, h_dim, 2, dtype=torch.float32, device=device) / h_dim  # (h_dim/2)
        theta = 1.0 / (base ** exponents)  # (h_dim/2)
        m = torch.arange(seq_len, dtype=torch.float32, device=device) / scaling_factor  # (seq_len)
        omega = torch.outer(m, theta)  # (seq_len, h_dim/2)
        # Unit-magnitude complex numbers exp(i * omega).
        self.omega_complex = torch.polar(torch.ones_like(omega), omega)  # (seq_len, h_dim/2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` by its position.

        Args:
            x: tensor shaped ``(batch, seq, n_heads, h_dim)`` with
                ``seq <= seq_len``.

        Returns:
            A tensor with the same shape and dtype as ``x``.

        Raises:
            ValueError: if ``x`` is not 4-D or ``seq`` exceeds the precomputed length.
        """
        if x.dim() != 4:
            raise ValueError(f"expected x of shape (batch, seq, n_heads, h_dim), got {tuple(x.shape)}")
        seq = x.shape[1]
        if seq > self.omega_complex.shape[0]:
            raise ValueError(
                f"sequence length {seq} exceeds precomputed length {self.omega_complex.shape[0]}"
            )

        # Pair up the last dim as (real, imag): (b, seq, n_heads, h_dim/2) complex.
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        # Only the first `seq` positions are needed; follow x's device.
        omega = self.omega_complex[:seq].to(x.device)
        omega = omega.unsqueeze(0).unsqueeze(2)  # (seq, h_dim/2) -> (1, seq, 1, h_dim/2)
        x_rotated = x_complex * omega  # (b, seq, n_heads, h_dim/2)
        x_out = torch.view_as_real(x_rotated)  # (b, seq, n_heads, h_dim/2, 2)
        # Computation happens in float32; return in the caller's dtype.
        return x_out.reshape(*x.shape).type_as(x)

In [28]:
# Small int64 demo tensor: tensor([1, 2, 3]).
x = torch.tensor(list(range(1, 4)))
x

tensor([1, 2, 3])

In [37]:
# Scratch cell exploring repeat_interleave. Only the last expression is
# displayed by the notebook; the others are evaluated and discarded.
x = torch.tensor([1, 2, 3])
torch.repeat_interleave(x, 2)  # tensor([1, 1, 2, 2, 3, 3])
y = torch.tensor([[1, 2], [3, 4]])
# With no `dim`, the input would be flattened first: tensor([1, 1, 2, 2, 3, 3, 4, 4]).
y.repeat_interleave(3, dim=1)  # tensor([[1, 1, 1, 2, 2, 2], [3, 3, 3, 4, 4, 4]])
y.repeat_interleave(torch.tensor([1, 2]), dim=0)  # row 0 once, row 1 twice
# Same result; output_size states the output length along dim 0 up front.
y.repeat_interleave(torch.tensor([1, 2]), dim=0, output_size=3)

tensor([[1, 2],
        [3, 4],
        [3, 4]])

tensor([0, 1, 1, 2, 2, 2])