# Meta MegaByte Model

## HParams

In [1]:
SEQ_LEN = 128
BATCH_SIZE = 1

## Load Dataset

In [36]:
NUM_BATCHES = int(1e5) - (2000 + 1700)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
PRIME_LEN = 100
SEQ_LEN = 8192

In [37]:
def cycle(loader):
    while True:
        for data in loader:
            yield data

In [38]:
import gzip
import torch
import numpy as np

with gzip.open('./data/enwik8.gz') as file:
    # strip original to 95M?
    x = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()

    # 95M, 5M (train, valid)
    train_x, valid_x = np.split(x, [int(90e6)])
    data_train, data_val = map(torch.from_numpy, (train_x, valid_x))

In [39]:
from torch.utils.data import DataLoader, Dataset

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

In [40]:
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

In [41]:
# def decode_token(token):
#     return str(chr(max(32, token)))

def decode_token(token):
    if 32 <= token <= 126:
        return str(chr(token))
    else:
        return ''

def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

In [42]:
txt = TextSamplerDataset(data_train, 128)

In [43]:
decode_tokens(txt[0])

' tires, a correct exhaust, and other street-legal items. The tech official (assuming the vehicle passes) will then use his white'

## View Decoded Dataset

In [34]:
decode_tokens(x[5000:5500])

"y:&amp;#945;&amp;#957;&amp;#945;&amp;#961;&amp;#967;&amp;#943;&amp;#945;|&amp;#945;&amp;#957;&amp;#945;&amp;#961;&amp;#967;&amp;#943;&amp;#945;]]'' (&quot;without [[archon]]s (ruler, chief, king)&quot;). Anarchism as a [[political philosophy]], is the belief that ''rulers'' are unnecessary and should be abolished, although there are differing interpretations of what this means. Anarchism also refers to related [[social movement]]s) that advocate the elimination of authoritarian institutions, par"

# Training

In [25]:
import megabyte

model = megabyte.MEGABYTE(
    num_tokens = 256,
    dim = (768, 512, 256),
    depth = (6, 4, 2),
    max_seq_len = (512, 4, 4),
    flash_attn = False
).cuda()

In [44]:
import contextlib
import random
import tqdm

with open('output.txt', 'w') as f:
    with contextlib.redirect_stdout(f):
        optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

        for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
            model.train()

            for __ in range(GRADIENT_ACCUMULATE_EVERY):
                loss = model(next(train_loader), return_loss = True)
                loss.backward()

            print(f'training loss: {loss.item()}')
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optim.step()
            optim.zero_grad()

            if i % VALIDATE_EVERY == 0:
                model.eval()
                with torch.no_grad():
                    loss = model(next(val_loader), return_loss = True)
                    print(f'validation loss: {loss.item()}')

            if i != 0 and i % GENERATE_EVERY == 0:
                model.eval()
                inp = random.choice(val_dataset)[:-1]
                prime_inp = inp[:PRIME_LEN]
                prime = decode_tokens(prime_inp)
                print(f'%s \n\n %s', (prime, '*' * 100))

                sample = model.generate(prime_inp[None, :])
                sample = sample.flatten(1)

                output_str = decode_tokens(sample[0][PRIME_LEN:])
                try:
                    print(output_str)
                except:
                    print("NOTE: ERROR DECODING STRING")

100%|██████████| 8092/8092 [02:21<00:00, 57.14it/s]7,  1.43s/it]
100%|██████████| 8092/8092 [02:17<00:00, 58.89it/s]2,  1.21s/it] 
100%|██████████| 8092/8092 [02:21<00:00, 57.31it/s]06,  1.26s/it] 
100%|██████████| 8092/8092 [02:22<00:00, 56.65it/s]48,  1.27s/it] 
100%|██████████| 8092/8092 [02:23<00:00, 56.55it/s]2:01,  1.55s/it]
100%|██████████| 8092/8092 [02:19<00:00, 58.01it/s]6:32,  1.24s/it] 
100%|██████████| 8092/8092 [02:23<00:00, 56.53it/s]3:39,  1.36s/it] 
100%|██████████| 8092/8092 [02:22<00:00, 56.62it/s]5:58,  1.27s/it] 
100%|██████████| 8092/8092 [02:17<00:00, 58.71it/s]2:20,  1.21s/it] 
100%|██████████| 8092/8092 [02:22<00:00, 56.95it/s]5:25,  1.27s/it] 
100%|██████████| 8092/8092 [02:21<00:00, 56.99it/s]0:54,  1.26s/it] 
100%|██████████| 8092/8092 [02:22<00:00, 56.93it/s]5:07,  1.27s/it] 
100%|██████████| 8092/8092 [02:22<00:00, 56.94it/s]5:23,  1.27s/it] 
100%|██████████| 8092/8092 [02:21<00:00, 57.13it/s]7:32,  1.50s/it] 
100%|██████████| 8092/8092 [02:22<00:00, 56.93

KeyboardInterrupt: 

In [45]:
torch.save(model.state_dict(), "./megabyte_25k_1.2836014032363892.pt")

In [46]:
torch.save(optim.state_dict(), "./megabyte(optim)_25k_1.2836014032363892.pt")

## Predict

In [47]:
def pred(prompt, prompt_len=100):
    model.eval()
    prime_inp = inp[:prompt_len]
    sample = model.generate(prime_inp[None, :])
    sample = sample.flatten(1)

    output_str = decode_tokens(sample[0][PRIME_LEN:])
    print(output_str)

In [48]:
pred("hi")

100%|██████████| 8092/8092 [02:24<00:00, 55.90it/s]


 a confusion, an influence onew [[progressionative]] spanning, opponents of and free of a horistian claim of pages, and [[bategory]] forms oference during the [[15 years]] and the most expe color increaseduced to its promain of the early theory became org.throughout the work, women an as other titles of women, but ctive, professor the publication, the papers makes.  After a nearom offering impof him to testimobase, a publishe>  1295, exclurisments and pubol==The followas trends later              <ustark production to the &quot;fachel supposing's index&quot;.  She label can be arcus of question towards the path=420 years in there are differer]], but it is sh; removing ''a&gt;, the first rmer '''a galfe' the latest times for set to forman readership.  of [[Scotland]] about 17% to oveek half 100. Undy Paris.  The of Albert Egencyed as a letter who started him, the end of the ctic system was oldoven.  (Both ra developed to berlindung in Chis created by the cage and wrote and the encyclus 