In [1]:
import hydra
import importlib
import torch
import torch.nn as nn

In [2]:
def get_config():
    """Compose the Hydra config used throughout this notebook.

    Initializes Hydra with ``config_path="configs"`` and composes the config
    named ``"config"`` with a fixed list of command-line style overrides.

    Returns:
        The object returned by ``hydra.compose``.
    """
    # Same "key=value" override strings one would pass on the command line.
    overrides = [
        'model=tiny',
        'data=openwebtext-split',
        'wandb.name=mdlm-owt',
        'parameterization=subs',
        'model.length=1024',
        'eval.compute_generative_perplexity=True',
        'sampling.steps=1000',
    ]
    with hydra.initialize(version_base=None, config_path="configs"):
        return hydra.compose(config_name="config", overrides=overrides)

# Compose the config once at notebook scope; a later cell passes `config` to DIT.
config = get_config()
# Display the `model` section of the composed config.
config.model

{'name': 'tiny', 'type': 'ddit', 'hidden_size': 256, 'cond_dim': 64, 'length': 1024, 'n_blocks': 8, 'n_heads': 8, 'scale_by_sigma': True, 'dropout': 0.1, 'tie_word_embeddings': False}

In [3]:
import models.dit
importlib.reload(models.dit)
from models.dit import *

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Smoke test: build a DIT from the composed config and run one forward pass
# on random token ids.
# NOTE(review): T is 512 while the config override sets model.length=1024 —
# confirm the model accepts a shorter window.
B = 16     # batch size
T = 512    # token window
C = 256    #  embedding dimension (in config); not used in this cell
V = 2**10  # vocab size

# DIT is presumably provided by the star import from models.dit — TODO confirm.
dit = DIT(config, V)

# Token ids are sampled from [0, 10), not [0, V).
# NOTE(review): confirm the small range is intended.
x = torch.randint(0, 10, (B, T) )


# The second argument is an all-zero tensor of length B; presumably the
# per-sample noise level expected by the model — TODO confirm.
out = dit(x, torch.zeros(B) )
out



tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0., 

<br>
<br>
<br>
<br>

In [3]:
class Rotary(torch.nn.Module):
  """Builds and caches the cos/sin tables for rotary position embeddings.

  The tables are laid out as (batch, seq_len, qkv, head, dim) with size-1
  batch and head axes so they broadcast against a packed qkv tensor. The
  third slot of the qkv axis (index 2, i.e. v) is set to cos=1 / sin=0 so
  that applying the rotation leaves v unchanged.

  Args:
    dim: Feature dimension to rotate. One inverse frequency is created for
      every second index in ``range(0, dim, 2)``; for even ``dim`` the last
      axis of the returned tables has size ``dim``.
    base: Base of the geometric progression of inverse frequencies.
  """

  def __init__(self, dim, base=10_000):
    super().__init__()
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Registered as a buffer so it follows the module across .to()/.cuda().
    self.register_buffer('inv_freq', inv_freq)
    self.seq_len_cached = None
    self.cos_cached = None
    self.sin_cached = None

  def forward(self, x, seq_dim=1):
    """Returns the (cos, sin) tables for the sequence length of ``x``.

    Only the shape of ``x`` along ``seq_dim`` and its device are read; the
    values of ``x`` are not used.

    Args:
      x: Tensor whose size along ``seq_dim`` gives the sequence length.
      seq_dim: Axis of ``x`` that indexes sequence positions.

    Returns:
      Tuple ``(cos, sin)``, each of shape ``(1, seq_len, 3, 1, D)`` where
      ``D`` is twice the number of inverse frequencies, located on
      ``x.device``. The same cached tensor objects are returned on every
      call until the sequence length or device changes.
    """
    seq_len = x.shape[seq_dim]
    # The cached tables are plain attributes, not buffers, so they do not
    # follow the module when it is moved. Rebuild when the device changes as
    # well as when the sequence length changes; otherwise a call with the
    # same length on a new device would return tables on the old device.
    cache_is_stale = (
      self.cos_cached is None
      or seq_len != self.seq_len_cached
      or self.cos_cached.device != x.device)
    if cache_is_stale:
      self.seq_len_cached = seq_len
      # Positions 0..seq_len-1 on inv_freq's device, cast to its dtype.
      t = torch.arange(
        seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
      # Outer product: freqs[p, k] = p * inv_freq[k].
      # clone() is kept from the original; presumably it protects the
      # buffer — TODO confirm whether it is still needed.
      freqs = torch.einsum("i,j->ij", t, self.inv_freq.clone())
      # Duplicate along the feature axis so both halves share the angles.
      emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
      # dims are: batch, seq_len, qkv, head, dim
      self.cos_cached = emb.cos()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
      self.sin_cached = emb.sin()[None, :, None, None, :].repeat(1, 1, 3, 1, 1)
      # This makes the transformation on v an identity.
      self.cos_cached[:, :, 2, :, :].fill_(1.)
      self.sin_cached[:, :, 2, :, :].fill_(0.)

    return self.cos_cached, self.sin_cached

In [69]:
def rotate_half(x):
  """Swaps the two halves of the last axis of ``x``, negating the second.

  The last axis is split at ``x.shape[-1] // 2``; the result is the negated
  trailing part followed by the leading part, concatenated on the last axis.
  """
  half = x.shape[-1] // 2
  leading = x[..., :half]
  trailing = x[..., half:]
  return torch.cat((-trailing, leading), dim=-1)

In [78]:
# NOTE(review): the recorded tensor is (2, 2, 10) with values up to 97, which
# does not match the `x` built with torch.randint(0, 10, (B, T)) earlier; it
# was presumably assigned in a cell that is no longer in the notebook.
x

tensor([[[68, 13, 24, 92, 81, 81, 66, 48,  9, 12],
         [53, 50, 45, 39, 81, 50, 74, 16, 95, 77]],

        [[ 6, 40, 77, 37, 26, 88,  6, 84, 10, 87],
         [97, 56, 38, 20, 92, 81, 61, 17, 72,  2]]])

In [79]:
# Demo: the second half of the last axis is negated and moved in front of the
# first half.
rotate_half(x)

tensor([[[-81, -66, -48,  -9, -12,  68,  13,  24,  92,  81],
         [-50, -74, -16, -95, -77,  53,  50,  45,  39,  81]],

        [[-88,  -6, -84, -10, -87,   6,  40,  77,  37,  26],
         [-81, -61, -17, -72,  -2,  97,  56,  38,  20,  92]]])

In [76]:
# Tiny shapes so the cos/sin tables can be inspected by eye.
# NOTE(review): this rebinds the notebook-level B, T, C set in the DIT cell,
# and `input` shadows the Python builtin of the same name.
B, T, C = 2, 4, 8
input = torch.ones((B, T, C))

# Rotary.forward reads only the size of `input` along dim 1 and its device.
r = Rotary(C)
cos, sin = r(input)
# Axes are: batch (size 1), position, q/k/v slot, head (size 1), feature.
cos.shape

torch.Size([1, 4, 3, 1, 8])

In [82]:
# sin table for slot 0 of the qkv axis: one row per position, one column per
# feature. The two halves of each row are equal because the angles are
# concatenated with themselves in Rotary.forward.
sin[0,:,0,0,:]

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8415, 0.0998, 0.0100, 0.0010, 0.8415, 0.0998, 0.0100, 0.0010],
        [0.9093, 0.1987, 0.0200, 0.0020, 0.9093, 0.1987, 0.0200, 0.0020],
        [0.1411, 0.2955, 0.0300, 0.0030, 0.1411, 0.2955, 0.0300, 0.0030]])

In [55]:
# cos table for slot 0 of the qkv axis, first half of the feature axis only
# (the second half is a copy of the first).
cos[0,:,0,0,:cos.shape[-1]//2]

tensor([[ 1.0000,  1.0000,  1.0000,  1.0000],
        [ 0.5403,  0.9950,  0.9999,  1.0000],
        [-0.4161,  0.9801,  0.9998,  1.0000],
        [-0.9900,  0.9553,  0.9996,  1.0000]])

In [39]:
# Inverse frequencies held by the module, one per index in range(0, dim, 2).
# NOTE(review): the recorded output has 2 entries, but Rotary(8) as built in
# the demo cell would hold 4; this output appears to come from an earlier run
# with a smaller dim. Re-run after restarting the kernel.
r.inv_freq

tensor([1.0000, 0.0100])

In [41]:
# Rebuild by hand the position vector that Rotary.forward creates, here for a
# hard-coded length of 4, cast to the type of inv_freq.
t = torch.arange(4).type_as(r.inv_freq)
t

tensor([0., 1., 2., 3.])

In [42]:
# Outer product of positions and inverse frequencies:
# freqs[p, k] = t[p] * inv_freq[k], the same einsum used in Rotary.forward.
freqs = torch.einsum("i,j->ij", t, r.inv_freq.clone())
freqs

tensor([[0.0000, 0.0000],
        [1.0000, 0.0100],
        [2.0000, 0.0200],
        [3.0000, 0.0300]])

In [48]:
# Duplicate the angles along the last axis, as Rotary.forward does.
emb = torch.cat((freqs, freqs), dim=-1)
print(emb.shape)
# Insert size-1 batch, qkv and head axes (before the repeat over qkv).
# NOTE(review): the recorded last-axis size of 4 corresponds to an inv_freq
# with 2 entries, not to the Rotary(8) instance built in the demo cell.
emb[None, :, None, None, :].shape

torch.Size([4, 4])


torch.Size([1, 4, 1, 1, 4])