In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataclasses import dataclass
from model import Patch_Embedder, GlobalModel, LocalModel
import numpy as np

In [23]:
@dataclass
class Megabyte_Config:
    '''Hyper-parameters for the Megabyte model.

    Only the annotated attributes are dataclass fields and can be passed to the
    constructor. The unannotated ones are plain class attributes: they can be
    read from (or overridden on) an instance, but not set through ``__init__``.

    ``global_d_model`` is derived: it is recomputed for every instance in
    ``__post_init__`` as ``global_d_pre_patch * patch_size`` so that it follows
    the values actually passed to the constructor.
    '''
    debug : bool = False
    dtype : torch.dtype = torch.float16

    #initialization
    init_range: float = 0.02
    layer_norm_eps: float = 1e-5

    #patch_embedder
    patch_size: int = 4

    #global model
    global_d_pre_patch: int = 32
    # Class-level value for the default settings; every instance gets its own
    # value in __post_init__ based on its patch_size / global_d_pre_patch.
    global_d_model =  global_d_pre_patch * patch_size
    global_n_heads = 8
    global_d_head = 8
    global_n_layers = 2
    global_d_mlp = 64

    d_vocab : int = 256

    #local model
    #global_dropout = 0.1
    local_d_model = 16
    local_n_heads = 4
    local_d_head = 4
    local_n_layers = 2
    local_d_mlp = 8

    #task
    classification : bool = True
    n_classes = 10

    def __post_init__(self):
        # The class-body expression above is evaluated once with the default
        # values, so it would stay at 128 even for Megabyte_Config(patch_size=8).
        # Derive it again from this instance's fields.
        self.global_d_model = self.global_d_pre_patch * self.patch_size

#TODO: should there be special bytes for image_start, image_end, text_start, text_end?

class Megabyte(nn.Module):
    '''Byte-level model built from a patch embedder, a global model and a local model.

    The global model runs over patch embeddings. Its output is split into one
    vector per byte position, projected to the local width and added to the
    (shifted-by-one) local byte embeddings before the local model runs.

    Notes:
        * When ``cfg.classification`` is true a learned class token is prepended
          as an extra patch (global side) and as ``patch_size`` extra positions
          (local side). ``classification_head`` is created but is not used by
          ``forward``.
        * ``forward`` returns probabilities (softmax already applied), not logits.
    '''

    def __init__(self, cfg: Megabyte_Config):
        super().__init__()
        self.cfg = cfg
        self._name = "Megabyte"
        self.patch_embedder = Patch_Embedder(cfg)
        self.global_model = GlobalModel(cfg)
        self.local_model = LocalModel(cfg)
        # Learned embedding placed at position 0 of the shifted local input.
        self.local_pad = nn.Parameter(torch.randn(1, 1, cfg.local_d_model))
        self.global_to_local_proj = nn.Linear(cfg.global_d_pre_patch, cfg.local_d_model)
        # Vocabulary size comes from the config (default 256) instead of a literal.
        self.byte_embedding_local = nn.Embedding(cfg.d_vocab, cfg.local_d_model)
        self.unembed = nn.Linear(cfg.local_d_model, cfg.d_vocab)

        if self.cfg.classification:
            self.global_class_token = nn.Parameter(torch.randn(1, cfg.global_d_model))
            self.local_class_token = nn.Parameter(torch.randn(1, cfg.local_d_model))
            self.classification_head = nn.Linear(cfg.local_d_model, cfg.n_classes)

    def get_param_count(self) -> int:
        '''returns the number of parameters in the model'''
        return sum(p.numel() for p in self.parameters() if p.requires_grad) # all params with gradients

    def _debug_shape(self, label: str, tensor: torch.Tensor) -> None:
        '''prints ``label`` and the tensor's shape when cfg.debug is enabled'''
        if self.cfg.debug:
            print(label, tensor.shape)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Run the model on a batch of byte ids.

        Args:
            x: integer tensor of byte ids, indexed as [batch, seq_len]. The
               reshapes below require the (class-token-extended) sequence to
               split evenly into patches of ``cfg.patch_size``.

        Returns:
            Probabilities over the vocabulary with four dimensions
            (batch, patch, position in patch, vocabulary), assuming the local
            model preserves the shape of its input -- TODO confirm in LocalModel.
        '''
        cfg = self.cfg
        batch_size = x.shape[0]

        # Local byte embeddings (width local_d_model).
        byte_embeddings_local = self.byte_embedding_local(x)
        if cfg.classification:
            # [batch, patch_size, local_d_model]: one full patch of class tokens.
            local_class_token = self.local_class_token.unsqueeze(0).expand(batch_size, cfg.patch_size, -1)
            byte_embeddings_local = torch.cat([local_class_token, byte_embeddings_local], dim=1)
        self._debug_shape("byte_embeddings_local.shape", byte_embeddings_local)

        # Shift the local inputs right by one position: pad one slot at the
        # front, fill it with the learned pad embedding, drop the last byte.
        offset_byte_embeddings_local = F.pad(byte_embeddings_local, (0, 0, 1, 0), "constant", 0)
        self._debug_shape("offset_byte_embeddings_local.shape", offset_byte_embeddings_local)
        offset_byte_embeddings_local[:, 0, :] = self.local_pad
        offset_byte_embeddings_local = offset_byte_embeddings_local[:, :-1, :]

        # Global path: patch embeddings (plus optional class patch) -> global model.
        self._debug_shape("input tensor", x)
        embedded = self.patch_embedder(x)
        if cfg.classification:
            # [batch, 1, global_d_model]: the class token acts as one extra patch.
            global_class_token = self.global_class_token.unsqueeze(0).expand(batch_size, -1, -1)
            embedded = torch.cat([global_class_token, embedded], dim=1)
        global_out = self.global_model(embedded)
        batch_size, num_patches, _ = global_out.shape

        # Split every patch vector into patch_size chunks of global_d_pre_patch.
        reshaped = global_out.view(batch_size, num_patches, cfg.patch_size, cfg.global_d_pre_patch)
        self._debug_shape("shape offset_byte_embeddings_local", offset_byte_embeddings_local)
        offset_byte_embeddings_local = offset_byte_embeddings_local.view(
            batch_size, num_patches, cfg.patch_size, cfg.local_d_model
        )
        self._debug_shape("reshaped.shape", reshaped)

        # Project each position to the dimension of the local model
        projected = self.global_to_local_proj(reshaped)
        self._debug_shape("projected.shape", projected)

        # Combine with byte embeddings
        self._debug_shape("offset_byte_embeddings_local.shape", offset_byte_embeddings_local)
        self._debug_shape("projected.shape", projected)
        combined = projected + offset_byte_embeddings_local

        # Process with local model
        self._debug_shape("combined.shape", combined)
        local_out = self.local_model(combined)
        logits = self.unembed(local_out)

        # Softmax over the vocabulary (last) dimension. Working on the last
        # dimension directly gives the same values as flattening first.
        probs = F.softmax(logits, dim=-1)
        if cfg.debug:
            print("probs_flat.shape", probs.flatten(0, 2).shape)
        self._debug_shape("probs.shape", probs)
        return probs

In [24]:
def text_to_bytes(texts: list[str]) -> torch.Tensor:
    '''converts texts to their UTF-8 bytes and returns a [batch, seq_len] long tensor,
    padded with zeros to the longest sequence in the batch.

    Every value is in 0..255, so it is a valid index for a 256-entry byte
    vocabulary even for non-ASCII text (which takes several bytes per
    character). An empty list of texts gives an empty [0, 0] tensor.

    Raises:
        UnicodeEncodeError: if a text contains characters that cannot be
            encoded as UTF-8 (e.g. lone surrogates).
    '''
    # Build integer tensors directly instead of going through a float tensor.
    encoded = [torch.tensor(list(text.encode("utf-8")), dtype=torch.long) for text in texts]
    if not encoded:
        # pad_sequence cannot handle an empty list of sequences
        return torch.empty((0, 0), dtype=torch.long)
    return torch.nn.utils.rnn.pad_sequence(encoded, batch_first=True)

def bytes_to_text(bytes: torch.Tensor) -> list[str]:
    '''converts a [batch, seq_len] tensor of byte values to a list of texts.

    Each row is decoded as UTF-8; byte sequences that are not valid UTF-8 are
    replaced with U+FFFD instead of raising. Zero padding is not stripped and
    shows up as NUL characters.
    '''
    # NOTE: the parameter name shadows the builtin `bytes`, hence bytearray below.
    byte_rows = bytes.to(dtype=torch.uint8)
    return [bytearray(row).decode("utf-8", errors="replace") for row in byte_rows.tolist()]



In [14]:
def megabyte_collate_fn(batch, type="image"):
    '''Collate a list of (sample, label) pairs into (bytes, labels).

    Args:
        batch: iterable of (sample, label) pairs.
        type: "image" or "text"; selects how samples are turned into bytes.

    Returns:
        A tuple (bytes, labels).
        * "image": samples are stacked into one numpy array and passed to
          img_to_bytes; labels become an int64 tensor.
          NOTE(review): img_to_bytes is not defined in the visible notebook --
          confirm where it comes from.
        * "text": samples are passed to text_to_bytes and its integer tensor is
          returned as is; labels are returned unchanged as a tuple.

    Raises:
        ValueError: if type is neither "image" nor "text".
    '''
    # Reject unknown types up front instead of failing later on an unbound local.
    if type not in ("image", "text"):
        raise ValueError(f"unsupported batch type {type!r}; expected 'image' or 'text'")

    samples, labels = zip(*batch)
    if type == "image":
        images = np.stack(samples) # we need it to be a numpy array to use patch_images
        byte_tensor = img_to_bytes(images) # should be integer type
        # Build the int64 tensor directly; a float32 detour loses precision for large ints.
        labels = torch.tensor(labels, dtype=torch.int64)
    else:
        # Keep the integer dtype: these values are used as embedding indices,
        # so they must not be converted to a float tensor.
        byte_tensor = text_to_bytes(list(samples))
    return byte_tensor, labels

In [49]:
#import dataset
from utils import text_to_bytes, bytes_to_text
import datasets
from torch.utils.data import DataLoader

# Smoke test: encode two strings and run them through a freshly initialised model.
text = "from marcy to madison square"
# NOTE: `bytes` shadows the builtin; the name is kept because later cells read it.
bytes = text_to_bytes([text, "to aim at yout, you to smithereens, cock sucker take one from your team and i need you to rememebr one thing"])
model = Megabyte(Megabyte_Config())
# Call the module itself rather than .forward() so registered hooks are run.
out = model(bytes)


In [50]:
# Show the shape of the model output and of the input byte tensor, one per line.
print(out.shape, bytes.shape, sep="\n")


torch.Size([216, 256])
torch.Size([216])


out.shape torch.Size([216, 256])
bytes.shape torch.Size([216])
loss tensor(5.5448, grad_fn=<NllLossBackward0>)


In [16]:
from functools import partial
# Bind the text branch of the collate function for the DataLoader.
collate_fn = partial(megabyte_collate_fn, type="text")
#load dataset
# NOTE(review): load_dataset is called without a split and the result is passed
# straight to DataLoader -- confirm that this yields (text, label) pairs as
# megabyte_collate_fn expects.
dataset = datasets.load_dataset('tiny_shakespeare')
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

epochs = 1
# Prefer the MPS backend (the original target); fall back to CUDA, then CPU,
# so the cell also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

for epoch in range(epochs):
    for batch in dataloader:
        # The collate function returns (bytes, labels); only the bytes go to the model.
        text_bytes = batch[0].to(device)
        out = model(text_bytes)
        batch_dim, patch_dim, local_dim, token = out.shape
        # Drop the first patch and flatten to [positions, vocab].
        # NOTE(review): with cfg.classification the first patch is the prepended class token.
        out = out[:, 1:, :, :].reshape(batch_dim * (patch_dim - 1) * local_dim, token)
        # Flatten the targets of the *current* batch (not the global `bytes` from the smoke test).
        text_bytes = text_bytes.reshape(-1)
        # TODO: compute the loss from `out` / `text_bytes` and take an optimizer step.

