# Week 3 Day 12: Tokenization at Scale & Sequence Preparation

## Overview
In this notebook, we'll explore tokenization at scale and efficient sequence preparation for language model training. We'll focus on:
- Implementing efficient tokenization with HuggingFace Tokenizers
- Comparing different tokenization algorithms
- Implementing sequence packing and masking for efficient training

In [None]:
# Import necessary libraries
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
import re
import requests
import time
import random
import os
from typing import List, Dict, Tuple, Optional
from collections import Counter
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers, processors
from tokenizers.models import BPE, WordPiece, Unigram
from tokenizers.trainers import BpeTrainer, WordPieceTrainer, UnigramTrainer
from tokenizers.pre_tokenizers import Whitespace, ByteLevel

# Reproducibility: seed every RNG this notebook relies on (stdlib, NumPy, PyTorch)
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

## 1. Data Preparation and Tokenizer Training

In [None]:
# Download and prepare data
# Plain-text books that together form the tokenizer training corpus.
text_urls = {
    "fiction": "https://www.gutenberg.org/files/1342/1342-0.txt",
    "science": "https://www.gutenberg.org/files/2009/2009-0.txt"
}
os.makedirs("data", exist_ok=True)
all_text_path = "data/all_texts.txt"

# Reuse an existing non-empty corpus so "Restart & Run All" does not re-download it.
if not (os.path.exists(all_text_path) and os.path.getsize(all_text_path) > 0):
    # Fetch everything into memory first: a failed request must not leave a
    # half-written corpus file on disk.
    downloaded_texts = []
    for url in text_urls.values():
        response = requests.get(url, timeout=60)
        # Fail loudly instead of silently training on an HTTP error page.
        response.raise_for_status()
        downloaded_texts.append(response.text)
    with open(all_text_path, "w", encoding="utf-8") as f_all:
        # Newline separator keeps the last line of one book from fusing with
        # the first line of the next.
        f_all.write("\n".join(downloaded_texts))

# Train a BPE tokenizer
tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)
trainer = BpeTrainer(
    vocab_size=8000,
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
    # Seed the vocabulary with the full byte-level alphabet so bytes that never
    # occur in the training corpus still get a token instead of mapping to [UNK].
    initial_alphabet=ByteLevel.alphabet(),
)
tokenizer.train([all_text_path], trainer)
# Byte-level decoder turns the byte-level symbols back into readable text on decode.
tokenizer.decoder = decoders.ByteLevel()
os.makedirs("tokenizers", exist_ok=True)
tokenizer.save("tokenizers/bpe_tokenizer.json")
print('BPE tokenizer trained and saved.')

## 2. Sequence Packing

In [None]:
def create_packed_sequences(texts, tokenizer, max_seq_len=128):
    """Greedily pack tokenized texts into rows of exactly ``max_seq_len`` ids.

    Texts are encoded, truncated to ``max_seq_len`` tokens, sorted from longest
    to shortest, and appended to the current row until the next text would not
    fit; the row is then right-padded and a new one is started.

    Args:
        texts: Iterable of strings to tokenize.
        tokenizer: Object exposing ``encode(text).ids`` and ``token_to_id(token)``.
        max_seq_len: Length of every packed row (must be positive).

    Returns:
        ``(packed_sequences, sequence_mappings)``: two parallel lists of lists.
        ``packed_sequences[r][c]`` is a token id (the ``[PAD]`` id, or 0 if the
        tokenizer has none, for padding). ``sequence_mappings[r][c]`` is the
        index in ``texts`` the token came from, or -1 for padding.

    Raises:
        ValueError: If ``max_seq_len`` is not positive.
    """
    if max_seq_len <= 0:
        raise ValueError("max_seq_len must be a positive integer")

    pad_id = tokenizer.token_to_id("[PAD]")
    if pad_id is None:
        pad_id = 0

    def _pad_row(row_ids, row_mapping):
        # Right-pad one row (and its mapping) up to max_seq_len.
        pad_len = max_seq_len - len(row_ids)
        return row_ids + [pad_id] * pad_len, row_mapping + [-1] * pad_len

    # Truncate up front: a text longer than max_seq_len can never fit in a row,
    # and leaving it untruncated would produce rows of inconsistent length.
    encoded = [
        (orig_idx, tokenizer.encode(text).ids[:max_seq_len])
        for orig_idx, text in enumerate(texts)
    ]
    # Longest first; the sort is stable so equal-length texts keep input order.
    encoded.sort(key=lambda item: len(item[1]), reverse=True)

    packed_sequences, sequence_mappings = [], []
    current_sequence, current_mapping = [], []

    for orig_idx, ids in encoded:
        # Only flush a non-empty row; flushing an empty one would emit a row
        # made entirely of padding.
        if current_sequence and len(current_sequence) + len(ids) > max_seq_len:
            row, mapping = _pad_row(current_sequence, current_mapping)
            packed_sequences.append(row)
            sequence_mappings.append(mapping)
            current_sequence, current_mapping = [], []

        current_sequence.extend(ids)
        current_mapping.extend([orig_idx] * len(ids))

    if current_sequence:
        row, mapping = _pad_row(current_sequence, current_mapping)
        packed_sequences.append(row)
        sequence_mappings.append(mapping)

    return packed_sequences, sequence_mappings

# Read the whole corpus back from disk
with open(all_text_path, encoding='utf-8') as f:
    text_content = f.read()

# Split after sentence-ending punctuation, drop blank fragments, keep the first 50
sentences = re.split(r'(?<=[.!?])\s+', text_content)
non_empty_sentences = filter(None, map(str.strip, sentences))
sample_sentences = list(non_empty_sentences)[:50]

# Pack the sample into fixed-length rows using the function's default row length
packed_sequences, sequence_mappings = create_packed_sequences(
    sample_sentences,
    tokenizer,
)
print(f'Created {len(packed_sequences)} packed sequences.')

## 3. Attention Masking for Packed Sequences

In [None]:
def create_packed_attention_mask(sequence_mapping, seq_len):
    """Build an additive causal attention mask for one packed row.

    Position ``i`` may attend to position ``j`` only when ``j <= i``, both
    positions carry the same id in ``sequence_mapping``, and neither is padding
    (id -1). Every position may always attend to itself, so no row is fully
    masked: a row of all ``-inf`` turns into NaN under softmax.

    Args:
        sequence_mapping: Per-position source id for the row; -1 marks padding.
        seq_len: Number of positions; must equal ``len(sequence_mapping)``.

    Returns:
        Float tensor of shape ``(seq_len, seq_len)`` holding 0.0 where attention
        is allowed and ``-inf`` where it is blocked.

    Raises:
        ValueError: If ``seq_len`` does not match the mapping length.
    """
    if len(sequence_mapping) != seq_len:
        raise ValueError(
            f"seq_len ({seq_len}) does not match mapping length ({len(sequence_mapping)})"
        )

    doc_ids = torch.tensor(sequence_mapping, dtype=torch.long)
    query_ids = doc_ids.unsqueeze(1)  # (seq_len, 1): id of the attending position
    key_ids = doc_ids.unsqueeze(0)    # (1, seq_len): id of the attended position

    # Causal part: block everything strictly above the diagonal (future positions)
    future = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    # Prevent attention between different sequences
    cross_sequence = query_ids != key_ids
    # Prevent attention to/from padding
    touches_padding = (query_ids == -1) | (key_ids == -1)

    blocked = future | cross_sequence | touches_padding
    # Re-open the diagonal so padding rows keep one finite entry (avoids NaN softmax)
    blocked = blocked & ~torch.eye(seq_len, dtype=torch.bool)

    return torch.zeros(seq_len, seq_len).masked_fill(blocked, float('-inf'))

# Visualize a mask
if packed_sequences:
    sample_mask = create_packed_attention_mask(sequence_mappings[0], len(sequence_mappings[0]))
    # The additive mask holds -inf for blocked pairs. Non-finite values break the
    # heatmap's automatic colour scaling, so plot a finite 0/1 matrix instead:
    # 1 = attention allowed, 0 = blocked.
    allowed = torch.isfinite(sample_mask).float()

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        allowed.numpy(),
        cmap='Blues_r',
        vmin=0.0,
        vmax=1.0,
        ax=ax,
        cbar_kws={'label': '1 = may attend, 0 = blocked'},
    )
    ax.set_title('Packed Sequence Attention Mask')
    ax.set_xlabel('Key position')
    ax.set_ylabel('Query position')
    plt.show()

## 4. Efficient Data Pipeline

In [None]:
class PackedSequenceDataset(torch.utils.data.Dataset):
    """Next-token-prediction dataset over packed rows.

    Each item is built from one packed row of token ids and its parallel
    mapping (source id per position, -1 for padding).

    Args:
        packed_sequences: List of token-id rows.
        sequence_mappings: List of per-position source-id rows, parallel to
            ``packed_sequences``.

    Raises:
        ValueError: If the two lists differ in length.
    """

    def __init__(self, packed_sequences, sequence_mappings):
        if len(packed_sequences) != len(sequence_mappings):
            raise ValueError(
                "packed_sequences and sequence_mappings must have the same length "
                f"(got {len(packed_sequences)} and {len(sequence_mappings)})"
            )
        self.packed_sequences = packed_sequences
        self.sequence_mappings = sequence_mappings

    def __len__(self):
        """Return the number of packed rows."""
        return len(self.packed_sequences)

    def __getitem__(self, idx):
        """Return the shifted inputs/targets and masks for row ``idx``.

        Returns:
            Dict with:
              - ``input_ids``: the row without its last token.
              - ``targets``: the row without its first token.
              - ``attention_mask``: result of ``create_packed_attention_mask``
                for the input positions.
              - ``loss_mask``: bool tensor aligned with ``targets``; True only
                where the target is a real token from the same source text as
                its input position (False for padding and for the first token
                of the next packed text).
        """
        sequence = torch.tensor(self.packed_sequences[idx], dtype=torch.long)
        mapping = self.sequence_mappings[idx]

        # Standard one-step shift for next-token prediction
        input_ids = sequence[:-1]
        targets = sequence[1:]
        attention_mask = create_packed_attention_mask(mapping[:-1], len(input_ids))

        # Mark which targets are meaningful: padding targets and targets that
        # cross from one packed text into the next should not count in a loss.
        doc_ids = torch.tensor(mapping, dtype=torch.long)
        input_doc_ids, target_doc_ids = doc_ids[:-1], doc_ids[1:]
        loss_mask = (target_doc_ids != -1) & (target_doc_ids == input_doc_ids)

        return {
            "input_ids": input_ids,
            "targets": targets,
            "attention_mask": attention_mask,
            "loss_mask": loss_mask,
        }

# Wrap the packed rows in a Dataset and serve them in shuffled batches of four
dataset = PackedSequenceDataset(packed_sequences, sequence_mappings)
dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=4,
)

# Smoke test: pull one batch (only if there is any data) and report tensor shapes
if len(dataloader) > 0:
    batch = next(iter(dataloader))
    input_shape = batch['input_ids'].shape
    mask_shape = batch['attention_mask'].shape
    print(f"Input shape: {input_shape}")
    print(f"Attention mask shape: {mask_shape}")