# Importing Libraries

In [None]:
import os
import random
from dataclasses import dataclass
from typing import Optional

import numpy as np
from einops import rearrange

# PyTorch
import torch
import torch.nn as nn

# Hyperparameters

In [None]:
# Patch size = P
# Sequence length = T
# Number of patches = K = T/P
# Global embedding dimension = Dg
# Local embedding dimension = Dl

In [None]:
@dataclass
class CONFIG:
    """Hyperparameters and runtime settings used throughout the notebook.

    `patch_num` is a derived value: when it is left as None, `__post_init__`
    fills it in as `sequence_length // patch_size`, so it follows any override
    of those two fields. An explicitly passed `patch_num` is kept as given.
    `device` defaults to None until a device is assigned.
    """

    debug: bool = False

    # Model
    # 256 characters + 2 special tokens (PAD_ID, EOS_ID) occupy IDs 0..257.
    # NOTE(review): 512 is larger than the 258 IDs in use - confirm intended.
    vocab_size: int = 512
    patch_size: int = 4  # P
    sequence_length: int = 1024  # T
    patch_num: Optional[int] = None  # K = T // P, resolved in __post_init__
    ## Global model
    global_emb_dim: int = 512
    global_num_layers: int = 4
    global_num_heads: int = 32
    ## Local model
    local_emb_dim: int = 128
    local_num_layers: int = 4
    local_num_heads: int = 8
    ## Special tokens
    PAD_ID: int = 256
    EOS_ID: int = 257

    # Dataset
    validation_size: float = 0.2

    # Device
    device: Optional[torch.device] = None

    # Training
    batch_size: int = 2
    learning_rate: float = 2e-5
    epochs: int = 100

    # Seed
    seed: int = 42

    def __post_init__(self):
        # Derive the patch count from the *instance* values; a class-level
        # default expression would be frozen at the default T and P.
        if self.patch_num is None:
            self.patch_num = self.sequence_length // self.patch_size


config = CONFIG()

# Reproducibility

In [None]:
def set_seed(seed):
    """Seed every RNG used in the notebook and make cuDNN deterministic.

    Seeds Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices),
    exports PYTHONHASHSEED, and prints the seed that was applied.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker
    # processes); the hash seed of the running interpreter is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes convolution algorithms per run, which can pick
    # different kernels between runs; it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False
    print(f"Seed: {seed}")
    
# Apply the configured seed before any model or data code runs.
set_seed(seed=config.seed)

# Device

In [None]:
def configure_device():
    """Select the torch device to run on and report the choice.

    Returns:
        torch.device("cuda") when CUDA is available (the GPU count is
        printed as well), otherwise torch.device("cpu").
    """
    # Guard clause: no CUDA means CPU, nothing else to report.
    if not torch.cuda.is_available():
        print("> Running on CPU")
        return torch.device("cpu")

    gpu_count = torch.cuda.device_count()
    print("> Running on GPU", end=' | ')
    print("Num of GPUs: ", gpu_count)
    return torch.device("cuda")

# Store the device on the live `config` instance. Assigning only to the class
# (`CONFIG.device = ...`) leaves `config.device` as None, because the
# dataclass __init__ already stored an instance-level `device=None` that
# shadows the class attribute.
config.device = configure_device()
# Keep the class-level attribute in sync for any code that reads CONFIG.device.
CONFIG.device = config.device

# Debug

# Dataset

# Model

## Transformer

## MegaByte

In [None]:
class GlobalModel(nn.Module):
    """Global model component of the MegaByte decoder (skeleton).

    The constructor only records the global hyperparameters from the config;
    no layers are created yet and `forward` is a placeholder.
    """

    def __init__(self, config: CONFIG):
        super().__init__()
        self.config = config

        # Global-model hyperparameters, copied out of the shared config.
        self.emb_dim = config.global_emb_dim
        self.num_layers = config.global_num_layers
        self.num_heads = config.global_num_heads

    def forward(self, x):
        """Placeholder forward pass; not implemented, returns None."""
        return None
    

In [None]:
class LocalModel(nn.Module):
    """Local model component of the MegaByte decoder (skeleton).

    The constructor only records the local hyperparameters from the config;
    no layers are created yet and `forward` is a placeholder.
    """

    def __init__(self, config: CONFIG):
        super().__init__()
        self.config = config

        # Local-model hyperparameters, copied out of the shared config.
        self.emb_dim = config.local_emb_dim
        self.num_layers = config.local_num_layers
        self.num_heads = config.local_num_heads

    def forward(self, x):
        """Placeholder forward pass; not implemented, returns None."""
        return None
    

In [None]:
class MegaByteDecoder(nn.Module):
    """MegaByte decoder skeleton combining a global and a local model.

    Only `prepare_input` is implemented; `forward` and `generate` are
    placeholders that currently return None.
    """

    def __init__(self, config: CONFIG):
        super(MegaByteDecoder, self).__init__()
        self.config = config
        self.global_model = GlobalModel(config)
        self.local_model = LocalModel(config)

        self.pad_id = config.PAD_ID
        # prepare_input reads self.patch_size, so it must be set here.
        self.patch_size = config.patch_size

    def forward(self, bytes):
        """Placeholder forward pass; not implemented, returns None."""
        pass

    def prepare_input(self, bytes):
        """Build the right-shifted inputs for the global and local models.

        Args:
            bytes: 2-D integer tensor of token IDs, indexed as
                (batch, sequence). The sequence length must be divisible by
                `self.patch_size` (required by the rearrange pattern).

        Returns:
            A tuple `(bytes_global, bytes_local)`:
              * bytes_global - same shape as `bytes`; the input shifted right
                by one whole patch, with the first patch filled with PAD.
              * bytes_local - shape (batch * num_patches, patch_size); each
                patch shifted right by one position, with PAD in the first
                position.
        """
        # Global input: prepend one patch of PAD and drop the last patch.
        padding_global = bytes.new_full((bytes.shape[0], self.patch_size), self.pad_id)
        bytes_global = torch.cat((padding_global, bytes[:, : -self.patch_size]), -1)

        # Split the sequence into patches, one patch per row.
        bytes_input = rearrange(bytes, "b (t p) -> (b t) p", p=self.patch_size)

        # Local input: shift by ONE position inside each patch. Shifting by
        # patch_size would slice away the whole patch and leave only PAD.
        padding_local = bytes_input.new_full((bytes_input.shape[0], 1), self.pad_id)
        bytes_local = torch.cat((padding_local, bytes_input[:, :-1]), -1)

        return bytes_global, bytes_local

    def generate(self, bytes):
        """Placeholder for generation; not implemented, returns None."""
        pass

In [None]:
# Instantiate the decoder from the shared notebook configuration.
megabyte = MegaByteDecoder(config=config)

# Train

# Inference