# Importing Libraries

In [1]:
import os
import random
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from dataclasses import dataclass
from einops import rearrange

# pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

# Hyperparameters

In [2]:
# Patch size = P
# Sequence length = T
# Number of patches = K = T/P
# global embedding dimension = Dg
# local embedding dimension = Dl

In [3]:
@dataclass
class CONFIG:
    debug: bool = False
    
    # Model
    V: int = 512  # 256 ASCII characters + 2 special tokens
    P: int = 4
    T: int = 1024
    K: int = T // P  # Number of patches
    
    ## Global model
    D_G: int = 256
    n_layers_G: int = 8
    n_heads_G: int = 16
    d_head_G: int = D_G // n_heads_G
    d_ff_G: int = D_G * 4
    dropout_G: float = 0.2
    
    ## Local model
    D_L: int = 128
    n_layers_L: int = 4
    n_heads_L: int = 8
    d_head_L: int = D_L // n_heads_L
    d_ff_L: int = D_L * 4
    dropout_L: float = 0.2
    
    flash_attention: bool = False
    
    # Vocabulary
    PAD_ID: int = 256
    EOS_ID: int = 257
    
    # data
    validation_size: float = 0.2
    
    # Device
    device: torch.device = None
    
    # Training
    epochs: int = 2
    batch_size: int = 16
    learning_rate: float = 2e-5
    
    # Generation
    max_len: int = 1024
    
    # Seed
    seed: int = 42

# Reproducibility

In [4]:
def set_seed(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    print(f"Seed: {seed}")
    
set_seed(CONFIG.seed)

Seed: 42


# Device

In [5]:
def configure_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        num_gpu = torch.cuda.device_count()
        print("> Running on GPU", end=' | ')
        print("Num of GPUs: ", num_gpu)
    else:
        device = torch.device("cpu")
        print("> Running on CPU")
    return device

CONFIG.device = configure_device()

> Running on GPU | Num of GPUs:  1


# Debug

In [6]:
if CONFIG.debug:
    CONFIG.epochs = 1

# Dataset

In [7]:
# data path
dataset_path = 'data/'
# shakespeare data
shakespeare_dataset = dataset_path + 'shakespeare.txt'

In [8]:
# read the dataset
with open(shakespeare_dataset, 'r', encoding='utf-8') as f:
    shakespeare_text = f.read()

In [9]:
print(shakespeare_text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [10]:
print(f'Total number of characters in the text: {len(shakespeare_text)}')

Total number of characters in the text: 1115394


In [11]:
def encode(text):
    return [char for char in text.encode('utf-8')]

def decode(encoded_text):
    return bytes(encoded_text).decode('utf-8')

In [12]:
print(encode('Hello, World!'))
print(decode(encode('Hello, World!')))

[72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33]
Hello, World!


In [13]:
shakespeare_tokens = torch.tensor(encode(shakespeare_text))

# Preprocessing

In [14]:
# Train Validation Split
train_size = int(len(shakespeare_tokens) * (1 - CONFIG.validation_size))
train_tokens = shakespeare_tokens[:train_size]
validation_tokens = shakespeare_tokens[train_size:]
print(f'Total number of tokens in the training set: {len(train_tokens)}')
print(f'Total number of tokens in the validation set: {len(validation_tokens)}')

Total number of tokens in the training set: 892315
Total number of tokens in the validation set: 223079


In [15]:
class ShakespeareDataset(Dataset):
    def __init__(self, tokens, context_length):
        self.tokens = tokens
        self.context_length= context_length
        
    def __len__(self):
        return len(self.tokens) - self.context_length
    
    def __getitem__(self, idx):
        return self.tokens[idx:idx+self.context_length], self.tokens[idx+1:idx+self.context_length+1]

train_dataset = ShakespeareDataset(train_tokens, CONFIG.T)
validation_dataset = ShakespeareDataset(validation_tokens, CONFIG.T)
train_loader = DataLoader(train_dataset, batch_size=CONFIG.batch_size, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=CONFIG.batch_size, shuffle=False)

In [16]:
sample_x, sample_y = next(iter(train_loader))
sample_x, sample_y = sample_x.to(CONFIG.device), sample_y.to(CONFIG.device)
print(sample_x.shape, sample_y.shape)
print(sample_x[0])
print(sample_y[0])

torch.Size([16, 1024]) torch.Size([16, 1024])
tensor([101, 101, 100,  ...,  10,  10,  66], device='cuda:0')
tensor([101, 100,  32,  ...,  10,  66,  69], device='cuda:0')


# Model

In [17]:
class Head(nn.Module):
    def __init__(self, d_embed: int, d_head: int):
        super().__init__()
        self.d_embed = d_embed
        self.d_head = d_head
        
        self.query = nn.Linear(self.d_embed, self.d_head)
        self.key = nn.Linear(self.d_embed, self.d_head)
        self.value = nn.Linear(self.d_embed, self.d_head)
        
    def forward(self, x):  # [batch_size, block_size, d_embed]
        query = self.query(x)  # [batch_size, block_size, d_head]
        key = self.key(x)  # [batch_size, block_size, d_head]
        value = self.value(x)  # [batch_size, block_size, d_head]
        
        attention = torch.matmul(query, key.transpose(-2, -1)) / np.sqrt(self.d_head)  # [batch_size, block_size, block_size]
        attention = F.softmax(attention, dim=-1)
        x = torch.matmul(attention, value)  # [batch_size, block_size, d_head]
        return x

In [18]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads: int, d_head: int, d_embed: int):
        super().__init__()
        self.heads = nn.ModuleList([Head(d_embed=d_embed, d_head=d_head) for _ in range(n_heads)])
        
    def forward(self, x):  # [batch_size, block_size, d_embed]
        return torch.cat([head(x) for head in self.heads], dim=-1)  # [batch_size, block_size, d_embed]

In [19]:
class MLP(nn.Module):
    def __init__(self, d_embed: int, d_ff: int, dropout: float):
        super().__init__()
        self.dropout = dropout
        
        self.fc1 = nn.Linear(d_embed, d_ff)
        self.fc2 = nn.Linear(d_ff, d_embed)
        
    def forward(self, x):
        x = F.gelu(self.fc1(x))  # [batch_size, block_size, d_ff]
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc2(x)  # [batch_size, block_size, d_embed]
        return x

In [20]:
class Decoder(nn.Module):
    def __init__(self, n_heads: int, d_head: int, d_embed: int, d_ff: int, dropout: float):
        super().__init__()
        self.attention = MultiHeadAttention(n_heads=n_heads, d_head=d_head, d_embed=d_embed)
        self.norm1 = nn.LayerNorm(d_embed)
        
        self.mlp = MLP(d_embed=d_embed, d_ff=d_ff, dropout=dropout)
        self.norm2 = nn.LayerNorm(d_embed)
        
    def forward(self, x):  # [batch_size, block_size, d_embed]
        x = x + self.attention(self.norm1(x))  # [batch_size, block_size, d_embed]
        x = x + self.mlp(self.norm2(x))  # [batch_size, block_size, vocab_size]
        return x

In [21]:
class PatchEmbedder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(CONFIG.V, CONFIG.D_G)
        self.positional_embedding = nn.Embedding(CONFIG.T, CONFIG.D_G)
        
    def forward(self, bytes):  # [batch_size, sequence_length]
        assert CONFIG.T % CONFIG.P == 0, "Sequence length must be divisible by patch size"
        
        bytes = self.embedding(bytes) + self.positional_embedding(torch.arange(CONFIG.T, device=bytes.device))  # [batch_size, sequence_len, d_embed]
        bytes = rearrange(bytes, "b (k p) d -> b k (p d)", b=bytes.shape[0], k=CONFIG.K, p=CONFIG.P, d=CONFIG.D_G)  # [batch_size, num_patches, patch_size * d_embed]
        return bytes

In [22]:
class GlobalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embedder = PatchEmbedder()
        self.decoder = Decoder(n_heads=CONFIG.n_heads_G, d_head=CONFIG.d_head_G, d_embed=CONFIG.P*CONFIG.D_G, d_ff=CONFIG.d_ff_G, dropout=CONFIG.dropout_G)
        self.linear = nn.Linear(CONFIG.D_G, CONFIG.D_L)
        
    def forward(self, bytes):  # [batch_size, sequence_length]
        x = self.patch_embedder(bytes)  # [batch_size, num_patches, patch_size * d_embed]
        for _ in range(CONFIG.n_layers_G):
            x = self.decoder(x)  # [batch_size, num_patches, patch_size * d_embed]
        x = rearrange(x, "b k (p d) -> (b k) p d", b=bytes.shape[0], k=CONFIG.K, p=CONFIG.P, d=CONFIG.D_G)  # [batch_size * num_patches, patch_size, d_embed]
        x = self.linear(x)  # [batch_size * num_patches, patch_size, local_d_embed]
        return x

In [23]:
class LocalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(CONFIG.V, CONFIG.D_L)
        self.local_transformer = Decoder(n_heads=CONFIG.n_heads_L, d_head=CONFIG.d_head_L, d_embed=CONFIG.D_L, d_ff=CONFIG.d_ff_L, dropout=CONFIG.dropout_L)
        self.linear = nn.Linear(CONFIG.D_L, CONFIG.V)
        
    def forward(self, local_input, global_output):  # [batch_size * num_patches, patch_size], [batch_size * num_patches, patch_size, local_d_embed]
        x = self.embedding(local_input) + global_output  # [batch_size * num_patches, patch_size, local_d_embed]
        for _ in range(CONFIG.n_layers_L):
            x = self.local_transformer(x)  # [batch_size * num_patches, patch_size, local_d_embed]
        x = self.linear(x)  # [batch_size * num_patches, patch_size, vocab_size]
        x = rearrange(x, "(b k) p v -> b (k p) v", k=CONFIG.K, p=CONFIG.P, v=CONFIG.V)  # [batch_size, sequence_length, vocab_size]
        return x

In [24]:
class MEGABYTE(nn.Module):
    def __init__(self):
        super().__init__()
        self.global_model = GlobalModel()
        self.local_model = LocalModel()
        
    def forward(self, bytes):  # [batch_size, sequence_length]
        global_input, local_input = self.prepare_input(bytes)  # [batch_size, sequence_length], [batch_size * num_patches, patch_size]
        global_output = self.global_model(global_input)  # [batch_size * num_patches, patch_size, local_d_embed]
        local_output = self.local_model(local_input, global_output)  # [batch_size, sequence_length, vocab_size]
        return local_output
        
    def prepare_input(self, bytes):  # [batch_size, sequence_length]
        global_padding = bytes.new(bytes.shape[0], CONFIG.P).fill_(CONFIG.PAD_ID)  # [batch_size, patch_size]
        global_input = torch.cat((global_padding, bytes[:, :-CONFIG.P]), dim=-1)  # [batch_size, sequence_length]
        
        bytes_input = rearrange(bytes, "b (k p) -> (b k) p", p=CONFIG.P)  # [batch_size * num_patches, patch_size]
        local_padding = bytes_input.new(bytes_input.shape[0], 1).fill_(CONFIG.PAD_ID)  # [patch_size]
        local_input = torch.cat((local_padding, bytes_input[:, :-1]), dim=-1)  # [batch_size * num_patches, patch_size]
        return global_input, local_input
    
    def loss(self, bytes, y):  # y: [batch_size, sequence_length]
        y = rearrange(y, "b t -> (b t)")  # [batch_size * sequence_length]
        logits = self.forward(bytes)  # [batch_size, sequence_length, vocab_size]
        logits = rearrange(logits, "b t v -> (b t) v", v=CONFIG.V)  # [batch_size * sequence_length, vocab_size]
        return F.cross_entropy(logits, y, ignore_index=CONFIG.PAD_ID)
    
    @torch.no_grad()
    def generate(self, bytes, max_len=None):
        self.eval()
        if max_len is None:
            max_len = CONFIG.max_len
        generated = bytes
        for _ in range(max_len):
            logits = self.forward(bytes)  # [batch_size, sequence_length, vocab_size]
            logits = logits[:, -1, :]  # [batch_size, vocab_size], get the last token
            _, next_token = torch.max(logits, dim=-1)  # [batch_size]
            generated = torch.cat((generated, next_token.unsqueeze(-1)), dim=-1)  # [batch_size, sequence_length]
        return generated

In [25]:
megabyte = MEGABYTE().to(CONFIG.device)
print(megabyte)

MEGABYTE(
  (global_model): GlobalModel(
    (patch_embedder): PatchEmbedder(
      (embedding): Embedding(512, 256)
      (positional_embedding): Embedding(1024, 256)
    )
    (decoder): Decoder(
      (attention): MultiHeadAttention(
        (heads): ModuleList(
          (0-15): 16 x Head(
            (query): Linear(in_features=1024, out_features=16, bias=True)
            (key): Linear(in_features=1024, out_features=16, bias=True)
            (value): Linear(in_features=1024, out_features=16, bias=True)
          )
        )
      )
      (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=1024, out_features=1024, bias=True)
        (fc2): Linear(in_features=1024, out_features=1024, bias=True)
      )
      (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
    (linear): Linear(in_features=256, out_features=128, bias=True)
  )
  (local_model): LocalModel(
    (embedding): Embedding(512, 128)
    (l

In [26]:
loss = megabyte.loss(sample_x, sample_y)
print(loss)

generated = megabyte.generate(sample_x)
print(decode(generated[0].tolist()[0]))

RuntimeError: The size of tensor a (1024) must match the size of tensor b (256) at non-singleton dimension 2

# Training

In [None]:
def train(model):
    model = model.to(CONFIG.device)
    #criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=CONFIG.learning_rate)
    
    train_loss = []
    validation_loss = []
    
    for epoch in range(CONFIG.epochs):
        model.train()
        running_loss = 0.0
        for x, y in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{CONFIG.epochs}'):
            x, y = x.to(CONFIG.device), y.to(CONFIG.device)
            optimizer.zero_grad()
            loss = model.loss(x, y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        train_loss.append(running_loss / len(train_loader))
        print(f'Training Loss: {running_loss / len(train_loader)}')
        
        model.eval()
        running_loss = 0.0
        with torch.no_grad():
            for x, y in validation_loader:
                x, y = x.to(CONFIG.device), y.to(CONFIG.device)
                loss = model.loss(x, y)
                running_loss += loss.item()
                
        validation_loss.append(running_loss / len(validation_loader))
        print(f'Validation Loss: {running_loss / len(validation_loader)}')
        
    plt.plot(train_loss, label='Training Loss')
    plt.plot(validation_loss, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

In [None]:
train(megabyte)

# Inference

In [None]:
generated = megabyte.generate(sample_x)
print(decode(generated[0].tolist()))

# Evaluation