In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataclasses import dataclass
import numpy as np

In [2]:
#from megabyte.model import Patch_Embedder, GlobalModel, LocalModel
from model import Patch_Embedder, GlobalModel, LocalModel

In [20]:
@dataclass
class Megabyte_Config:
    debug : bool = False
    dtype : torch.dtype = torch.float16

    #initialization
    init_range: float = 0.02
    layer_norm_eps: float = 1e-5

    #patch_embedder
    patch_size: int = 4

    #global model
    global_d_pre_patch: int = 64
    global_d_model =  global_d_pre_patch * patch_size
    global_n_heads = 8
    global_d_head = 64
    global_n_layers = 2
    global_d_mlp = 128

    d_vocab : int = 256

    #local model
    #global_dropout = 0.1
    local_d_model = 64
    local_n_heads = 8
    local_d_head = 16
    local_n_layers = 2 #TODO: should this be 8?
    local_d_mlp = 32

    #task
    classification : bool = False

#TODO: should there be special bytes for image_start, image_end, text_start, text_end?

class Megabyte(nn.Module):
    def __init__(self, cfg: Megabyte_Config):
        super().__init__()
        self.cfg = cfg
        self._name = "Megabyte"
        self.patch_embedder = Patch_Embedder(cfg)
        self.global_model = GlobalModel(cfg)
        self.local_model = LocalModel(cfg)
        self.local_pad = nn.Parameter(torch.randn(1, 1, cfg.local_d_model))
        self.global_to_local_proj = nn.Linear(cfg.global_d_pre_patch, cfg.local_d_model)
        self.byte_embedding_local = nn.Embedding(256, cfg.local_d_model)
        self.unembed = nn.Linear(cfg.local_d_model, 256)

        if self.cfg.classification:
            self.global_class_token = nn.Parameter(torch.randn(1, cfg.global_d_model))
            self.local_class_token = nn.Parameter(torch.randn(1, cfg.local_d_model))
            self.classification_head = nn.Linear(cfg.local_d_model, cfg.n_classes)

    def get_param_count(self) -> int:
        '''returns the number of parameters in the model'''
        return sum(p.numel() for p in self.parameters() if p.requires_grad) # all params with gradients

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Embedding and global model processing
        if self.cfg.classification:
            #add class token
            batch_size = x.shape[0]
            global_class_token = self.global_class_token.unsqueeze(0).repeat(batch_size, 1, 1)#.t  # Shape: [batch_size, 1, local_d_model]
            local_class_token = self.local_class_token.unsqueeze(0).repeat(batch_size, self.cfg.patch_size, 1)#.transpose(1,2)  # Shape: [batch_size, 1, local_d_model]

        #compute bytes_embedding for local model and offset by 1
        #print("input x.shape", x.shape)
        original_seq_length = x.size(1) #original sequence length 

        byte_embeddings_local = self.byte_embedding_local(x) # dim local_d_model
        seq_length = x.size(1)
        padded_length = (seq_length + self.cfg.patch_size - 1) // self.cfg.patch_size * self.cfg.patch_size
        pad_length = padded_length - seq_length
        offset_byte_embeddings_local = F.pad(byte_embeddings_local, (0, 0, 0, pad_length), "constant", 0) # shape [batch, seq_length  , local_d_model]

        # Insert class token at the start of the sequence
        if self.cfg.classification:
            byte_embeddings_local = torch.cat([local_class_token, byte_embeddings_local], dim=1)
        
        offset_byte_embeddings_local[:, 0, :] = self.local_pad
        embedded = self.patch_embedder(x)
        # add class token
        if self.cfg.classification:
            embedded = torch.cat([global_class_token, embedded], dim=1) # Shape: [batch_size, 1 + n_patches, global_d_model]
        
        global_out = self.global_model(embedded)
        batch_size, num_patches = global_out.shape[:2]

        reshaped = global_out.view(batch_size, num_patches, self.cfg.patch_size , self.cfg.global_d_pre_patch)
        offset_byte_embeddings_local = offset_byte_embeddings_local.view(batch_size, num_patches, self.cfg.patch_size , self.cfg.local_d_model)

        # Project each position to the dimension of the local model
        projected = self.global_to_local_proj(reshaped)

        #print("projected.shape", projected.shape)
        # Combine with byte embeddings
        combined = projected + offset_byte_embeddings_local

        # Process with local model
        local_out = self.local_model(combined) # shape [batch, n_patches, patch_size, local_d_model]
        unembedded = self.unembed(local_out) # shape [batch, n_patches, patch_size, 256]

        batch_size, num_patches, patch_size, d_local_model = unembedded.shape
        unembedded_flat = unembedded.view(batch_size * num_patches * patch_size, d_local_model)
        # Apply softmax to compute probability distribution over the vocabulary
        probs_flat = F.softmax(unembedded_flat, dim=-1)
        # Reshape back to original shape
        probs = probs_flat.reshape(batch_size, num_patches* patch_size, d_local_model)

        
        if original_seq_length < padded_length:
            # Compute the number of patches in the original sequence
            original_num_patches = original_seq_length // self.cfg.patch_size
            # If the original sequence length isn't a multiple of patch_size, add one to original_num_patches
            if original_seq_length % self.cfg.patch_size != 0:
                original_num_patches += 1
            #print("original_num_patches", original_num_patches)
            # Remove the padding from the sequence
            probs = probs[:, :original_seq_length, :]
            
        #print("output probs.shape", probs.shape)
        return probs




In [26]:
import torch
def text_to_bytes(texts: list[str]) -> torch.Tensor:
    '''converts text to bytes and returns [batch, seq_len] tensor 
    and pads to max_seq_len in batch with zeros'''
    n_bytes = torch.nn.utils.rnn.pad_sequence([torch.Tensor(list(text.encode('ascii', 'replace'))).to(dtype=torch.long) for text in texts], batch_first=True)
    """ if torch.max(n_bytes) > 255:
        raise ValueError('max byte value is greater than 255') """
    return n_bytes 

def bytes_to_text(bytes: torch.Tensor) -> list[str]:
    '''converts bytes in torch.Tensor to text'''
    texts = []
    bytes = bytes.to(dtype=torch.uint8)
    for i in range(bytes.size(0)): # iter over batch
        texts.append(''.join([chr(b) for b in bytes[i].tolist()]))
    return texts

def megabyte_collate_fn(batch, type="image"):
    if type == "image":
        images, labels = zip(*batch)
        images = np.stack(images) # we need it to be a numpy array to use patch_images
        out_bytes = img_to_bytes(images) # should be integer type
        labels = torch.Tensor(labels).to(torch.int64)
        return out_bytes, labels
        #print("bytes.shape", bytes.shape)
    elif type == "text":
        texts = [item['text'] for item in batch]  # Extract 'text' from each item in the batch
        out_bytes = torch.Tensor(text_to_bytes(texts))  # Convert texts to bytes
        return out_bytes


In [22]:
#import dataset
#from megabyte.utils import text_to_bytes, bytes_to_text
from utils import text_to_bytes, bytes_to_text
import datasets
from torch.utils.data import DataLoader
from datasets import load_dataset
from functools import partial


In [6]:
model = Megabyte(Megabyte_Config())


In [7]:
dataset = load_dataset("roneneldan/TinyStories")


Repo card metadata block was not found. Setting CardData to empty.


In [27]:
train_dataset = dataset["train"]
test_dataset = dataset["validation"]
collate_fn = partial(megabyte_collate_fn, type="text")

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)


In [9]:
loss_fn = torch.nn.CrossEntropyLoss()

In [28]:
train_losses = []
test_losses = []
epochs = 10
device = "cpu"
model.to(device)
model.train()
for epoch in range(epochs):
    for batch in train_dataloader:
        text_bytes = batch.to(device)
        max_value = text_bytes.reshape(text_bytes.size(0)*text_bytes.size(1))
        print("max_value.shape", max_value.shape)
        max_value = max_value[max_value > 255]
        print("max_value.shape", max_value.shape)
        out = model.forward(text_bytes)
        out = out.reshape(out.size(0)*out.size(1), out.size(2))
        text_bytes = text_bytes.view(-1)
        loss = loss_fn(out, text_bytes)
        loss.backward()
        train_losses.append(loss.item())
        print(loss.item())
        break
        
    with torch.no_grad():
        model.eval()
        avg_test_loss = []
        for batch in test_dataloader:
            text_bytes = batch.to(device)
            out = model.forward(text_bytes)
            #batch_dim, patch_dim, local_dim, token= out.shape
            text_bytes = text_bytes.view(-1)
            out = out.reshape(out.size(0)*out.size(1), out.size(2))
            loss = loss_fn(out, text_bytes)
            avg_test_loss.append(loss.item())
            test_losses.append(loss.item())
        print(f"val_loss_mean Epoch {epoch+1}/{epochs}.. is {np.mean(avg_test_loss)}")
    model.train()

#save model
torch.save(model.state_dict(), "nr_1_model.pt")


max_value.shape torch.Size([128192])
max_value.shape torch.Size([0])
5.547283172607422


[E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument


KeyboardInterrupt: 