# Importing Libraries

In [None]:
import os
import random
from dataclasses import dataclass
from typing import Optional

import numpy as np
from einops import rearrange

# PyTorch
import torch
import torch.nn as nn

# Hyperparameters

In [None]:
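# Notation:
# P  = patch size (bytes per patch)
# T  = sequence length (bytes per sequence)
# K  = number of patches = T / P (T must be divisible by P)
# Dg = global embedding dimension per byte (the global model's width is P * Dg)
# Dl = local embedding dimension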

In [None]:
@dataclass
class CONFIG:
    """Notebook-wide settings: model shape, dataset split, device, training and seed."""

    debug: bool = False

    # Model
    # Byte values 0-255 plus PAD_ID (256) and EOS_ID (257) only need 258 ids;
    # 512 leaves unused headroom in the embedding tables.
    vocab_size: int = 512
    patch_size: int = 4  # P: bytes per patch
    sequence_length: int = 1024  # T: bytes per sequence
    # K = T / P. Left as None so __post_init__ derives it from this instance's
    # sequence_length and patch_size; a class-body expression would be frozen
    # to the default values (1024 // 4) for every instance.
    num_patch: Optional[int] = None
    flash_attention: bool = False
    ## Global model
    global_dim: int = 256  # Dg; the global transformer runs at width patch_size * global_dim
    global_num_layers: int = 4
    global_num_heads: int = 32
    global_dim_feedforward: int = 512
    global_dropout: float = 0.1
    ## Local model
    local_dim: int = 256  # Dl
    local_num_layers: int = 4
    local_num_heads: int = 8
    local_dim_feedforward: int = 512
    local_dropout: float = 0.1
    ## Special tokens
    PAD_ID: int = 256
    EOS_ID: int = 257

    # Dataset
    validation_size: float = 0.2

    # Device (filled in after configure_device() runs; None until then)
    device: Optional[torch.device] = None

    # Training
    batch_size: int = 2
    learning_rate: float = 2e-5
    epochs: int = 100

    # Seed
    seed: int = 42

    def __post_init__(self):
        """Validate the sequence/patch geometry and derive num_patch.

        Raises:
            ValueError: if patch_size does not tile sequence_length exactly, or
                if a special-token id does not fit in vocab_size.
        """
        # The model reshapes "(t p)", so patches must tile the sequence exactly.
        if (
            self.patch_size <= 0
            or self.sequence_length <= 0
            or self.sequence_length % self.patch_size != 0
        ):
            raise ValueError(
                f"sequence_length ({self.sequence_length}) must be a positive "
                f"multiple of patch_size ({self.patch_size})"
            )
        if self.num_patch is None:
            self.num_patch = self.sequence_length // self.patch_size
        # Special-token ids index the embedding tables, so they must fit.
        if max(self.PAD_ID, self.EOS_ID) >= self.vocab_size:
            raise ValueError(
                f"PAD_ID ({self.PAD_ID}) and EOS_ID ({self.EOS_ID}) must be "
                f"smaller than vocab_size ({self.vocab_size})"
            )

config = CONFIG()

# Reproducibility

In [None]:
def set_seed(seed):
    """Seed every random number generator the notebook uses.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    and configures cuDNN for deterministic kernel selection.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python interpreters started after this point; the current
    # process's str hashing was fixed when it started.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU, not only the current one;
    # it is a no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes kernels and may choose different algorithms
    # from run to run, which defeats the deterministic setting above.
    torch.backends.cudnn.benchmark = False
    print(f"Seed: {seed}")
    
# Seed all RNGs from the configured seed before any model is built.
set_seed(seed=config.seed)

# Device

In [None]:
def configure_device():
    """Choose the compute device and print which one is used.

    Returns:
        torch.device: ``cuda`` when PyTorch sees a GPU, otherwise ``cpu``.
    """
    if not torch.cuda.is_available():
        print("> Running on CPU")
        return torch.device("cpu")

    gpu_count = torch.cuda.device_count()
    print("> Running on GPU", end=' | ')
    print("Num of GPUs: ", gpu_count)
    return torch.device("cuda")

# `config` already exists, and its dataclass __init__ stored device=None on the
# instance. Assigning only the class attribute left `config.device` as None,
# so set the instance explicitly. The class attribute is kept in sync for any
# code that reads CONFIG.device.
config.device = configure_device()
CONFIG.device = config.device

# Debug

# Dataset

# Model

## Transformer

In [None]:
class TransformerDecoder(nn.Module):
    """Decoder-only transformer: a stack of causally masked self-attention layers.

    There is no cross-attention; each position attends only to itself and
    earlier positions along the sequence axis.

    Args:
        n_layers: number of stacked layers.
        nhead: attention heads per layer; must divide d_model.
        d_model: width of the input and output features.
        dim_feedforward: hidden width of each layer's feed-forward network.
        dropout: dropout probability inside each layer.
        flash_attention: recorded on the module. NOTE(review): it is not yet
            used to select an attention kernel.

    Raises:
        ValueError: if d_model is not divisible by nhead.
    """

    def __init__(self, n_layers, nhead, d_model, dim_feedforward=2048,
                 dropout=0.1, flash_attention=False):
        super().__init__()
        if d_model % nhead != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by nhead ({nhead})"
            )
        self.flash_attention = flash_attention
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,  # inputs are (batch, seq, d_model)
        )
        self.layers = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, x):
        """Apply the causal stack to x of shape (batch, seq, d_model); returns the same shape."""
        seq_len = x.size(1)
        # True marks a forbidden (future) position for attention.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        return self.layers(x, mask=causal_mask)

## MEGABYTE

In [None]:
class MegaByteDecoder(nn.Module):
    """Two-level MEGABYTE byte decoder.

    A global model runs over patches of ``config.patch_size`` bytes, shifted
    right by one patch. A local model runs within each patch, shifted right by
    one byte. The local model's input is its byte embeddings plus the global
    model's output for that patch.
    """

    def __init__(self, config: CONFIG):
        super().__init__()
        self.config = config
        self.patch_size = config.patch_size

        # Global model: one token per patch, formed by concatenating the
        # patch's byte embeddings, hence width patch_size * global_dim.
        self.global_embedding = nn.Embedding(config.vocab_size, config.global_dim)
        # Learned per-byte positions, registered here (not in forward) so they
        # are trained and moved between devices with the module.
        self.global_positional_encoding = nn.Embedding(config.sequence_length, config.global_dim)
        self.global_model = TransformerDecoder(
            n_layers=config.global_num_layers,
            nhead=config.global_num_heads,
            d_model=self.patch_size * config.global_dim,
            dim_feedforward=config.global_dim_feedforward,
            dropout=config.global_dropout,
            flash_attention=config.flash_attention
        )

        # Maps each byte slot of the global output to the local width; an
        # identity (no extra parameters) when the two widths already agree.
        if config.global_dim == config.local_dim:
            self.global_to_local = nn.Identity()
        else:
            self.global_to_local = nn.Linear(config.global_dim, config.local_dim)

        self.local_embedding = nn.Embedding(config.vocab_size, config.local_dim)
        # Learned position within a patch (0 .. patch_size - 1).
        self.local_positional_encoding = nn.Embedding(config.patch_size, config.local_dim)
        self.local_model = TransformerDecoder(
            n_layers=config.local_num_layers,
            nhead=config.local_num_heads,
            d_model=config.local_dim,
            dim_feedforward=config.local_dim_feedforward,
            dropout=config.local_dropout,
            flash_attention=config.flash_attention
        )

        self.pad_id = config.PAD_ID
        self.eos_id = config.EOS_ID

    def _check_length(self, seq_len):
        """Raise ValueError unless seq_len is a positive multiple of patch_size within sequence_length."""
        if seq_len <= 0 or seq_len % self.patch_size != 0:
            raise ValueError(
                f"sequence length must be a positive multiple of "
                f"patch_size={self.patch_size}, got {seq_len}"
            )
        if seq_len > self.config.sequence_length:
            raise ValueError(
                f"sequence length {seq_len} exceeds configured "
                f"sequence_length={self.config.sequence_length}"
            )

    def forward(self, bytes):
        """Run the global and local models over a batch of byte ids.

        Args:
            bytes: 2-D integer tensor (batch, seq_len). seq_len must be a
                positive multiple of patch_size and at most
                config.sequence_length.

        Returns:
            Tensor (batch, seq_len, width) of local-model outputs, one per
            byte. Presumably width == config.local_dim, assuming local_model
            preserves its d_model. NOTE(review): no projection to vocab_size
            is applied here.

        Raises:
            ValueError: if seq_len violates the constraints above.
        """
        batch_size, seq_len = bytes.shape
        self._check_length(seq_len)
        bytes_global, bytes_local = self.prepare_input(bytes)

        # Global model
        global_positions = torch.arange(seq_len, device=bytes.device)
        global_bytes_embedded = (
            self.global_embedding(bytes_global)
            + self.global_positional_encoding(global_positions)
        )
        global_input = rearrange(global_bytes_embedded, "b (t p) e -> b t (p e)", p=self.patch_size)
        global_output = self.global_model(global_input)

        # Local model
        local_positions = torch.arange(self.patch_size, device=bytes.device)
        local_bytes_embedded = (
            self.local_embedding(bytes_local)
            + self.local_positional_encoding(local_positions)
        )
        global_output_rearranged = rearrange(global_output, "b t (p e) -> (b t) p e", p=self.patch_size)
        local_input = local_bytes_embedded + self.global_to_local(global_output_rearranged)
        local_output = self.local_model(local_input)

        # Fold patches back into one sequence per batch element.
        return rearrange(local_output, "(b t) l v -> b (t l) v", b=batch_size)

    def prepare_input(self, bytes):
        """Build the right-shifted inputs for the global and local models.

        Args:
            bytes: 2-D integer tensor (batch, seq_len), seq_len divisible by
                patch_size.

        Returns:
            bytes_global: (batch, seq_len). The sequence shifted right by one
                whole patch; the first patch is PAD_ID.
            bytes_local: (batch * seq_len / patch_size, patch_size). Each patch
                shifted right by one byte; its first slot is PAD_ID.
        """
        # Global input: prepend one full patch of padding, drop the last patch.
        padding_global = bytes.new_full((bytes.shape[0], self.patch_size), self.pad_id)
        bytes_global = torch.cat((padding_global, bytes[:, : -self.patch_size]), dim=-1)

        # Local input: one row per patch.
        bytes_input = rearrange(bytes, "b (t p) -> (b t) p", p=self.patch_size)

        # Shift by ONE byte within each patch. Slicing [:, :-patch_size] on a
        # (N, patch_size) tensor is empty and would leave only padding.
        padding_local = bytes_input.new_full((bytes_input.shape[0], 1), self.pad_id)
        bytes_local = torch.cat((padding_local, bytes_input[:, :-1]), dim=-1)

        return bytes_global, bytes_local

    @torch.no_grad()
    def generate(self, bytes):
        """Autoregressive byte generation (not implemented yet).

        Raises:
            NotImplementedError: always, instead of silently returning None.
        """
        raise NotImplementedError("MegaByteDecoder.generate is not implemented yet")

In [None]:
# Instantiate the model from the notebook-wide configuration.
megabyte = MegaByteDecoder(config=config)

# Train

# Inference