# DeBERTa Tokenizer Implementation and Demonstration (From Scratch)

This notebook implements a simplified, word-level DeBERTa-style tokenizer from scratch (with a per-character fallback for unknown words; PyTorch is used only to return tensors), along with demonstrations of its usage.

In [None]:
import torch
import re
from collections import OrderedDict
from typing import List, Dict, Union, Optional

class DeBERTaTokenizer:
    """Simplified, word-level tokenizer using DeBERTa/BERT-style special tokens.

    Text is optionally lower-cased and split by ``self.pat`` into English
    contractions, runs of letters, runs of digits and runs of other
    non-whitespace symbols; whitespace itself is discarded. A piece found in
    ``self.vocab`` becomes one token. Any other piece is broken into single
    characters, and characters missing from the vocabulary become ``'[UNK]'``.
    No subword (BPE / WordPiece) merging is performed.

    Attributes:
        max_len: Default sequence length used by ``encode`` for padding and
            truncation when ``max_length`` is not given.
        do_lower_case: Whether ``tokenize`` lower-cases its input first.
        vocab: Mapping token -> id.
        ids_to_tokens: Reverse mapping id -> token. Keep it in sync with
            ``vocab`` if ``vocab`` is replaced directly.
        special_tokens: Mapping special token -> id. Its keys are the tokens
            that ``decode(skip_special_tokens=True)`` removes.
    """

    # Python's ``re`` module has no ``\p{L}`` / ``\p{N}`` classes, so the
    # Unicode categories are approximated: ``[^\W\d_]`` = letters,
    # ``\d`` = decimal digits, ``[^\s\w]|_`` = any other non-whitespace symbol.
    # Whitespace matches no alternative, so ``findall`` simply skips it.
    _TOKEN_PATTERN = r"'s|'t|'re|'ve|'m|'ll|'d|[^\W\d_]+|\d+|(?:[^\s\w]|_)+"

    def __init__(self, vocab_file: Optional[str] = None, max_len: int = 512, do_lower_case: bool = True):
        """Create a tokenizer.

        Args:
            vocab_file: Optional path to a one-token-per-line vocabulary file
                (see ``load_vocab``). Without it the vocabulary holds only the
                five special tokens, with ids 0-4.
            max_len: Default padding/truncation length for ``encode``.
            do_lower_case: Lower-case text before splitting it.
        """
        self.max_len = max_len
        self.do_lower_case = do_lower_case
        self.special_tokens: Dict[str, int] = {
            '[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, '[MASK]': 4
        }
        self.vocab: "OrderedDict[str, int]" = OrderedDict()
        self.ids_to_tokens: Dict[int, str] = {}
        if vocab_file:
            # load_vocab replaces vocab, special_tokens and ids_to_tokens.
            self.load_vocab(vocab_file)
        else:
            self.vocab.update(self.special_tokens)
            self.ids_to_tokens = {v: k for k, v in self.vocab.items()}

        self.pat = re.compile(self._TOKEN_PATTERN)

    def load_vocab(self, vocab_file: str):
        """Replace the vocabulary with the contents of a UTF-8 text file.

        Each non-blank line (stripped of surrounding whitespace) is a token
        whose id is its 0-based line number. Blank lines are skipped, leaving
        their id unused. A repeated token keeps the id of its first occurrence.
        Special tokens not present in the file are appended after the highest
        file id, so no two tokens ever share an id. ``special_tokens`` is
        updated to the ids actually used.
        """
        vocab: "OrderedDict[str, int]" = OrderedDict()
        with open(vocab_file, 'r', encoding='utf-8') as f:
            for index, line in enumerate(f):
                token = line.strip()
                if token and token not in vocab:
                    vocab[token] = index

        next_id = max(vocab.values(), default=-1) + 1
        for token in self.special_tokens:
            if token not in vocab:
                vocab[token] = next_id
                next_id += 1

        self.vocab = vocab
        self.special_tokens = {token: vocab[token] for token in self.special_tokens}
        self.ids_to_tokens = {v: k for k, v in vocab.items()}

    def tokenize(self, text: str) -> List[str]:
        """Split ``text`` into vocabulary tokens (see the class docstring)."""
        if self.do_lower_case:
            text = text.lower()
        tokens: List[str] = []
        for piece in self.pat.findall(text):
            if piece in self.vocab:
                tokens.append(piece)
            else:
                # Character-level fallback for out-of-vocabulary pieces.
                tokens.extend(ch if ch in self.vocab else '[UNK]' for ch in piece)
        return tokens

    def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
        """Map tokens to ids; unknown tokens map to the ``'[UNK]'`` id."""
        unk_id = self.vocab['[UNK]']
        return [self.vocab.get(token, unk_id) for token in tokens]

    def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
        """Map ids to tokens; unknown ids map to ``'[UNK]'``."""
        return [self.ids_to_tokens.get(idx, '[UNK]') for idx in ids]

    def encode(self, text: Union[str, List[str]], text_pair: Optional[Union[str, List[str]]] = None, 
               max_length: Optional[int] = None, padding: bool = True, truncation: bool = True) -> Dict[str, torch.Tensor]:
        """Encode one text (optionally paired) or a batch of texts.

        Args:
            text: A string, or a list of strings for batch encoding.
            text_pair: Second segment(s). For a batch it must be a list with
                one entry per item of ``text``.
            max_length: Target length. Defaults to ``self.max_len``.
            padding: Right-pad with ``'[PAD]'`` up to the target length.
            truncation: Cut sequences longer than the target length.

        Returns:
            ``{'input_ids', 'attention_mask'}``. Each value is a 1-D long tensor
            for a single string, or a 2-D ``(batch, length)`` tensor for a
            list. With ``padding=False``, every sequence in a batch must end up
            the same length, otherwise ``torch.stack`` raises.

        Raises:
            ValueError: If a batch ``text_pair`` has a different length than
                ``text``.
        """
        if isinstance(text, str):
            return self._encode_single(text, text_pair, max_length, padding, truncation)

        pairs = [None] * len(text) if text_pair is None else list(text_pair)
        if len(pairs) != len(text):
            raise ValueError(
                f"text_pair has {len(pairs)} items but text has {len(text)}"
            )

        encodings = [
            self._encode_single(t, p, max_length, padding, truncation)
            for t, p in zip(text, pairs)
        ]
        if not encodings:
            # torch.stack cannot stack an empty list; return empty batches.
            width = (self.max_len if max_length is None else max_length) if padding else 0
            empty = torch.empty((0, width), dtype=torch.long)
            return {'input_ids': empty, 'attention_mask': empty.clone()}

        return {
            key: torch.stack([enc[key] for enc in encodings])
            for key in ('input_ids', 'attention_mask')
        }

    def _encode_single(self, text: str, text_pair: Optional[str] = None, 
                       max_length: Optional[int] = None, padding: bool = True, 
                       truncation: bool = True) -> Dict[str, torch.Tensor]:
        """Encode ``[CLS] text [SEP] (text_pair [SEP])`` into 1-D long tensors.

        Raises:
            ValueError: If truncation is needed but the target length is
                below 2 (it could not hold both ``[CLS]`` and ``[SEP]``).
        """
        limit = self.max_len if max_length is None else max_length

        tokens = ['[CLS]'] + self.tokenize(text) + ['[SEP]']
        if text_pair:
            tokens += self.tokenize(text_pair) + ['[SEP]']

        if truncation and len(tokens) > limit:
            if limit < 2:
                raise ValueError("max_length must be at least 2 to fit [CLS] and [SEP]")
            # Keep the first limit-1 tokens and re-terminate with [SEP].
            tokens = tokens[:limit - 1] + ['[SEP]']

        ids = self.convert_tokens_to_ids(tokens)
        # Build the mask from the real length rather than from id values, so
        # only padding positions are ever masked out.
        attention_mask = [1] * len(ids)

        if padding and len(ids) < limit:
            pad_count = limit - len(ids)
            ids += [self.vocab['[PAD]']] * pad_count
            attention_mask += [0] * pad_count

        return {
            'input_ids': torch.tensor(ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long)
        }

    def decode(self, ids: Union[List[int], torch.Tensor], skip_special_tokens: bool = True) -> str:
        """Turn ids back into a space-joined token string.

        With ``skip_special_tokens`` every key of ``special_tokens`` is
        removed, including ``'[UNK]'``.
        """
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
        tokens = self.convert_ids_to_tokens(ids)
        if skip_special_tokens:
            special = set(self.special_tokens)
            tokens = [token for token in tokens if token not in special]
        return ' '.join(tokens)

    def add_tokens(self, new_tokens: List[str]) -> int:
        """Append tokens not already in the vocabulary; return how many were added.

        New ids start after the current highest id, so they cannot collide
        with existing ids even when the id space has gaps.
        """
        added = 0
        next_id = max(self.vocab.values(), default=-1) + 1
        for token in new_tokens:
            if token not in self.vocab:
                self.vocab[token] = next_id
                self.ids_to_tokens[next_id] = token
                next_id += 1
                added += 1
        return added

# Helper function to create a small vocabulary for demonstration
def create_small_vocab():
    """Return a 14-entry toy vocabulary as an ``OrderedDict`` of token -> id.

    The five special tokens take ids 0-4, followed by nine lower-case words
    with ids 5-13, in listing order.
    """
    entries = [
        '[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]',
        'hello', 'world', 'how', 'are', 'you',
        'this', 'is', 'a', 'test',
    ]
    return OrderedDict((entry, position) for position, entry in enumerate(entries))

# Initialize tokenizer with a small vocabulary
small_vocab = create_small_vocab()
tokenizer = DeBERTaTokenizer()
# Swap in the demo vocabulary and rebuild the id -> token map to match it.
tokenizer.vocab = small_vocab
tokenizer.ids_to_tokens = dict(zip(small_vocab.values(), small_vocab.keys()))

print("DeBERTa Tokenizer implemented from scratch and initialized with a small vocabulary.")

## 1. Basic Tokenization

In [None]:
text = "hello world how are you"
tokens = tokenizer.tokenize(text)
# Show the raw input next to the token list produced from it.
print(f"Original text: {text}", f"Tokenized: {tokens}", sep="\n")

## 2. Encoding Single Texts

In [None]:
text = "hello world"
# A single string yields 1-D tensors; padding defaults to on, up to tokenizer.max_len.
encoding = tokenizer.encode(text)
print(
    f"Original text: {text}",
    f"Encoded input_ids: {encoding['input_ids']}",
    f"Attention mask: {encoding['attention_mask']}",
    sep="\n",
)

## 3. Encoding Text Pairs

In [None]:
text_a = "hello world"
text_b = "how are you"
# The second positional argument is the paired segment: [CLS] A [SEP] B [SEP].
encoding = tokenizer.encode(text_a, text_b)
print(
    f"Text A: {text_a}",
    f"Text B: {text_b}",
    f"Encoded input_ids: {encoding['input_ids']}",
    f"Attention mask: {encoding['attention_mask']}",
    sep="\n",
)

## 4. Batch Encoding

In [None]:
texts = ["hello world", "how are you"]
encodings = tokenizer.encode(texts)  # a list input is encoded item by item, then stacked
print(
    f"Batch texts: {texts}",
    f"Encoded input_ids: {encodings['input_ids']}",
    f"Attention masks: {encodings['attention_mask']}",
    sep="\n",
)

## 5. Decoding

In [None]:
ids = [2, 5, 6, 3, 0, 0]  # [CLS] hello world [SEP] [PAD] [PAD]
decoded = tokenizer.decode(ids)  # skip_special_tokens defaults to True
print(f"Original ids: {ids}", f"Decoded text: {decoded}", sep="\n")

## 6. Handling Unknown Tokens

In [None]:
text = "hello unknown world"
# "unknown" is not in the demo vocabulary, so it falls back to per-character lookup.
tokens = tokenizer.tokenize(text)
print(f"Original text: {text}", f"Tokenized: {tokens}", sep="\n")

## 7. Adding New Tokens

In [None]:
new_tokens = ['new', 'tokens']
num_added = tokenizer.add_tokens(new_tokens)
print(f"Number of tokens added: {num_added}")

# Once added to the vocabulary, these words are looked up as whole tokens.
text = "hello new tokens world"
tokens = tokenizer.tokenize(text)
print(f"Original text: {text}", f"Tokenized with new tokens: {tokens}", sep="\n")