In [2]:
import hydra
import importlib
import torch
import torch.nn as nn

In [2]:
def get_config(overrides=None, config_name="config", config_path="configs"):
    """Compose the Hydra config used throughout this notebook.

    Args:
        overrides: iterable of Hydra override strings (``key=value``).
            ``None`` (the default) uses the MDLM / OpenWebText overrides below.
        config_name: name of the primary config to compose.
        config_path: Hydra config directory, passed to ``hydra.initialize``
            (relative path; resolution rules are Hydra's).

    Returns:
        The config object produced by ``hydra.compose``.
    """
    if overrides is None:
        overrides = [
            'model=tiny',
            'data=openwebtext-split',
            'wandb.name=mdlm-owt',
            'parameterization=subs',
            'model.length=1024',
            'eval.compute_generative_perplexity=True',
            'sampling.steps=1000',
        ]
    # Scope Hydra's global initialization to this call so the function can be
    # invoked again from another cell.
    with hydra.initialize(version_base=None, config_path=config_path):
        config = hydra.compose(config_name=config_name, overrides=list(overrides))
    return config

config = get_config()
config.model

{'name': 'tiny', 'type': 'ddit', 'hidden_size': 256, 'cond_dim': 64, 'length': 1024, 'n_blocks': 8, 'n_heads': 8, 'scale_by_sigma': True, 'dropout': 0.1, 'tie_word_embeddings': False}

In [59]:
import models.dit
# Re-execute the module so edits to models/dit.py are picked up without a kernel restart.
importlib.reload(models.dit)
# Explicit import (instead of `import *`) after the reload, so `DIT` is the reloaded class
# and it is clear where the name comes from. Only `DIT` is used by the cells shown here.
from models.dit import DIT

In [None]:
torch.manual_seed(0)  # reproducible weights and inputs for this smoke test

B = 2                          # batch size
T = 512                        # token window
C = config.model.hidden_size   # embedding dimension, read from the config (256 for model=tiny)
V = 2**10                      # vocab size

dit = DIT(config, V)

# Token ids drawn over the whole vocabulary (same range as the dit2 cell below).
x = torch.randint(0, V, (B, T))

# Second argument: one conditioning value per sequence (presumably the noise level sigma).
out = dit(x, torch.zeros(B))
assert out.shape[:2] == (B, T), out.shape
out



tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], grad_fn=<ViewBackward0>)

In [10]:
# Fresh forward pass: backward() frees the autograd graph by default, so `out`
# must be recomputed before the backward cell can be run again.
out = dit(x, torch.zeros(len(x)))

In [11]:
# Clear gradients from any previous run; otherwise re-executing the forward/backward
# cells silently accumulates into each parameter's .grad.
dit.zero_grad()
out.sum().backward()
# Sanity check that backprop actually reached the model's parameters.
assert any(p.grad is not None for p in dit.parameters()), "no parameter gradients were populated"

---

## Re-instantiate `DIT` and rerun the forward pass

In [None]:
# Rebuild DIT and run one forward pass on random token ids in [0, 10).
# Rebinds the notebook-level `dit`, `x` and `out`.
dit = DIT(config, V)

x = torch.randint(low=0, high=10, size=(B, T))


out = dit(x, torch.zeros(x.size(0)))
out

In [8]:
# Small shapes for the dit2 experiments. Note: rebinds B, T, C, V and x from the
# cells above (x becomes a float tensor instead of integer token ids).
B, T = 2, 4                       # batch size, sequence length
C, H = 32, 2                      # embedding dimension, number of heads
V = 1 << 10                       # vocab size (1024)
x = torch.rand(B, T, C)           # random float activations
t = torch.randint(0, 100, (B,))   # one integer in [0, 100) per batch element

In [5]:
# Conditioning sizes for dit2. cond_dim is passed to dit2.DIT below;
# t_dim is not referenced by any cell shown here.
t_dim, cond_dim = 8, 12

In [6]:
import dit2
importlib.reload(dit2)
from dit2 import DiTBlock, Rotary, TimestepEmbedder 

In [9]:
# Random token ids over the full vocabulary, shape (B, T).
indices = torch.randint(0, V, (B, T))
# sigma is presumably the per-sequence noise level (passed as dit_2's second
# argument). It must not be sampled from [0, V): V is the vocabulary size and
# has nothing to do with the noise scale. Use a float in [0, 1), consistent
# with the float `torch.zeros(B)` conditioning used for the first model.
sigma = torch.rand(B)

dit_2 = dit2.DIT(
    V=V,                # vocabulary size
    C=C,                # embedding dimension
    H=H,                # number of heads
    cond_dim=cond_dim,  # internal dimension for conditioning
    N=3,                # number of blocks
    p=0.1,              # probability of dropout
)


In [10]:
out_2 = dit_2(indices, sigma)
# Invariant: the output keeps the (batch, sequence) leading dims of `indices`.
assert out_2.shape[:2] == indices.shape, out_2.shape
out_2



tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], grad_fn=<ViewBackward0>)