In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from merge.modules.config import TransformerConfig
from merge.modules.transformer import Transformer


In [None]:
from merge.preprocessing.tokenizer import AutoTokenizer


# Build a tokenizer through the project's AutoTokenizer factory for the
# "bert-base-uncased" name.
tokenizer = AutoTokenizer.create("bert-base-uncased")

In [None]:
# Smoke test: encode a sample sentence and display whatever `encode` returns.
tokenizer.encode("Hello, my dog is cute")

In [None]:
from transformers import AutoTokenizer as HFAutoTokenizer


class HFTokenizerWrapper:
    """Wrapper for HuggingFace tokenizers that implements our interface.

    The HuggingFace tokenizer is loaded once and stored on ``self.tokenizer``.
    Calling the wrapper, or reading any regular attribute the wrapper does not
    define itself, is forwarded to that tokenizer.
    """

    def __init__(self, name: str):
        """Load the tokenizer identified by ``name``.

        Args:
            name: Identifier passed unchanged to
                ``HFAutoTokenizer.from_pretrained``.
        """
        self.tokenizer = HFAutoTokenizer.from_pretrained(name)

    def __getattr__(self, name):
        """Forward lookups that failed on the wrapper to the wrapped tokenizer.

        Python only calls ``__getattr__`` after normal attribute lookup has
        failed, so two kinds of names must not be forwarded:

        * ``tokenizer`` itself: if it is missing (for example while ``copy``
          or ``pickle`` is rebuilding an empty instance), reading
          ``self.tokenizer`` here would re-enter this method forever.
        * dunder names: protocol probes such as ``__deepcopy__`` or
          ``__setstate__`` must describe the wrapper, not the wrapped object.

        Raises:
            AttributeError: for the two cases above, or when the wrapped
                tokenizer does not have the attribute either.
        """
        if name == "tokenizer" or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.tokenizer, name)

    def __call__(self, *args, **kwargs):
        """Delegate the call to the underlying tokenizer, arguments unchanged."""
        return self.tokenizer(*args, **kwargs)

In [None]:
# Rebind `tokenizer` to the HuggingFace-backed wrapper; this replaces the
# tokenizer object created in the earlier cell.
tokenizer = HFTokenizerWrapper("bert-base-uncased")

In [None]:
# Exercise the wrapper's __call__ delegation with a positional extra argument.
# NOTE(review): the `2` is forwarded unchanged as the second positional
# argument of the underlying tokenizer; for HuggingFace tokenizers that slot
# is presumably `text_pair` — confirm that this call is intended.
text = "This is a test"
tokenizer(text, 2)

In [None]:
# Call the unwrapped HuggingFace tokenizer on the same text, presumably as a
# reference to compare against the wrapper's output.
hftokenizer = HFAutoTokenizer.from_pretrained("bert-base-uncased")
text = "This is a test"
hftokenizer(text)

In [None]:
from datasets import load_dataset

In [None]:
# Load the train split of HuggingFaceFW/fineweb-edu.
# name="sample-10BT" selects the 10BT sample configuration; files are cached
# under ./hf_cache and num_proc=8 is passed through to load_dataset.
fw = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train", cache_dir="./hf_cache", num_proc=8)


In [None]:
# Load the train split of roneneldan/TinyStories.
# NOTE: this rebinds `fw`, so every later cell that uses `fw` operates on
# TinyStories rather than on the fineweb-edu dataset loaded in the cell above.
fw = load_dataset("roneneldan/TinyStories", split="train", cache_dir="./hf_cache", num_proc=8)


In [None]:
from merge.preprocessing.tokenizer import AutoTokenizer

# Rebind `tokenizer` to the project tokenizer for "EleutherAI/gpt-neo-125m";
# this replaces the bert-base-uncased tokenizer objects created earlier.
tokenizer = AutoTokenizer.create("EleutherAI/gpt-neo-125m")

In [None]:
# Tokenize the "text" column. With batched=True the lambda receives a batch
# of rows per call; no truncation or padding arguments are passed here.
fw_tokenized = fw.map(lambda x: tokenizer(x["text"]), batched=True)

In [None]:
# Switch the dataset's output format to torch, restricted to the two listed
# columns. This mutates `fw_tokenized` in place.
fw_tokenized.set_format(type="torch", columns=["input_ids", "attention_mask"])

In [98]:
# Inspect the tokenizer's padding token (id, string). The recorded output
# shows id 50256, i.e. '<|endoftext|>'.
tokenizer.pad_token_id, tokenizer.pad_token

(50256, '<|endoftext|>')

In [102]:
max_seq_length = 1024


def _tokenize_fixed_length(batch):
    """Tokenize a batch of rows to exactly ``max_seq_length`` tokens each.

    Longer texts are truncated and shorter ones are padded, as requested by
    the ``truncation`` and ``padding`` arguments. ``tokenizer`` and
    ``max_seq_length`` are read from the notebook namespace at call time.
    """
    return tokenizer(
        batch["text"],
        padding="max_length",
        max_length=max_seq_length,
        truncation=True,
    )


# batched=True hands the helper a batch of rows per call.
fw_test = fw.map(_tokenize_fixed_length, batched=True)

Map:  15%|█▌        | 325000/2119719 [00:59<05:29, 5438.60 examples/s]


KeyboardInterrupt: 

In [79]:
import torch.nn.functional as F
from torch import Tensor
from jaxtyping import Int, Float

def process_sequences(
    dataset,
    min_seq_length: int,
    max_seq_length: int,
    pad_token_id: int = 0,
):
    """Filter, split and pad tokenized sequences to a fixed length.

    Every row of the result has exactly ``max_seq_length`` tokens:

    * rows with fewer than ``min_seq_length`` tokens are dropped;
    * rows longer than ``max_seq_length`` are cut into consecutive chunks of
      ``max_seq_length`` tokens; a trailing chunk shorter than
      ``min_seq_length`` is dropped;
    * anything shorter than ``max_seq_length`` is right-padded, with
      ``pad_token_id`` in ``input_ids`` and ``0`` in ``attention_mask``.

    Args:
        dataset: Object exposing ``filter``, ``map`` and ``column_names`` and
            holding ``input_ids`` / ``attention_mask`` columns (assumed to be a
            HuggingFace ``datasets.Dataset`` — confirm against the caller).
        min_seq_length: Minimum number of real tokens a row or chunk must have
            to be kept.
        max_seq_length: Length of every output row. Must be at least 1.
        pad_token_id: Value used to pad ``input_ids``. Defaults to ``0`` to
            match the previous behavior; pass ``tokenizer.pad_token_id`` to pad
            with the tokenizer's own pad token.

    Returns:
        Whatever ``dataset.map`` returns: a dataset with only the
        ``input_ids`` and ``attention_mask`` columns.

    Raises:
        ValueError: if ``max_seq_length`` is smaller than 1.
    """
    if max_seq_length < 1:
        raise ValueError(f"max_seq_length must be >= 1, got {max_seq_length}")

    # First filter short sequences (non-batched: x['input_ids'] is one row).
    dataset = dataset.filter(lambda x: len(x['input_ids']) >= min_seq_length)

    def split_and_pad(batch):
        """Turn a batch of rows into a flat list of fixed-length rows."""
        out = {
            'input_ids': [],
            'attention_mask': []
        }
        # Iterate over the batch axis so each row is handled as a 1-D
        # sequence; slicing below is therefore along the token axis.
        for row_ids, row_mask in zip(batch['input_ids'], batch['attention_mask']):
            row_ids = torch.as_tensor(row_ids)
            row_mask = torch.as_tensor(row_mask)
            total_length = row_ids.shape[0]

            # max(..., 1) keeps a single (empty) chunk for a zero-length row,
            # so a row that passed the filter is still padded and emitted.
            for start in range(0, max(total_length, 1), max_seq_length):
                end = start + max_seq_length
                chunk_ids = row_ids[start:end]
                chunk_mask = row_mask[start:end]
                chunk_length = chunk_ids.shape[0]

                # Only keep the chunk if it has enough real tokens.
                if chunk_length < min_seq_length:
                    continue

                padding_length = max_seq_length - chunk_length
                if padding_length > 0:
                    chunk_ids = F.pad(chunk_ids, (0, padding_length), value=pad_token_id)
                    chunk_mask = F.pad(chunk_mask, (0, padding_length), value=0)

                out['input_ids'].append(chunk_ids)
                out['attention_mask'].append(chunk_mask)
        return out

    # Apply the transformation
    return dataset.map(
        split_and_pad,
        remove_columns=dataset.column_names,
        batched=True,
        batch_size=1,  # Process one at a time since rows can have different lengths
    )