In [None]:
from transformers import BertTokenizerFast
import json
from pathlib import Path

## 1. Load BERT Tokenizer

In [None]:
# Instantiate the fast BERT tokenizer from the 'bert-base-uncased' checkpoint
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Report what was loaded: vocabulary size and the tokenizer's maximum sequence length
print(
    "✓ Loaded BERT tokenizer",
    f"Vocab size: {tokenizer.vocab_size:,}",
    f"Model max length: {tokenizer.model_max_length:,}",
    sep="\n",
)

## 2. Load Sample Data from SQuAD

In [None]:
# Location of the SQuAD v1.1 training file, relative to the notebook's working directory
SQUAD_TRAIN_PATH = Path('../archive/train-v1.1.json')

# Load the training set (JSON with a top-level 'data' list of articles)
with SQUAD_TRAIN_PATH.open('r', encoding='utf-8') as f:
    train_data = json.load(f)

# Extract a sample question-context pair:
# first article -> first paragraph -> first question -> first listed answer
sample_article = train_data['data'][0]
sample_paragraph = sample_article['paragraphs'][0]
sample_qa = sample_paragraph['qas'][0]
sample_answer = sample_qa['answers'][0]

context = sample_paragraph['context']
question = sample_qa['question']
answer_text = sample_answer['text']
answer_start = sample_answer['answer_start']

# Slice the context at the annotated offset; it must reproduce the answer text
answer_in_context = context[answer_start:answer_start + len(answer_text)]

print("Sample Data:")
print("="*80)
print(f"Question: {question}")
print(f"\nContext: {context[:200]}...")
print(f"\nAnswer: '{answer_text}'")
print(f"Answer starts at character position: {answer_start}")
print(f"\nVerification: '{answer_in_context}'")

# Fail loudly instead of relying on a visual check: every later cell assumes
# that answer_start really indexes answer_text inside context.
assert answer_in_context == answer_text, (
    f"answer_start={answer_start} does not point at the answer text in the context"
)

## 3. Basic Tokenization

In [None]:
# Split the question on its own into tokens and show them with their count
question_tokens = tokenizer.tokenize(question)
print("Question tokens:", question_tokens, sep="\n")
print(f"\nNumber of tokens: {len(question_tokens)}")

In [None]:
# Tokenize just the first 100 characters of the context.
# The cut is character-based, so the last word may be incomplete.
context_sample = context[:100]
context_tokens = tokenizer.tokenize(context_sample)
print("Context sample tokens:", context_tokens, sep="\n")
print(f"\nNumber of tokens: {len(context_tokens)}")

## 4. Question-Context Pair Tokenization

For QA, we need to tokenize the question and the context together, using a special format:
```
[CLS] question [SEP] context [SEP]
```

In [None]:
# Tokenize question and context as a pair: [CLS] question [SEP] context [SEP]
encoding = tokenizer(
    question,
    context,
    # Truncate only the context (second sequence) so the question is never cut;
    # this matches the strategy used in the long-context/stride section.
    truncation='only_second',
    padding='max_length',   # pad every example up to max_length
    max_length=384,
    return_tensors='pt'     # PyTorch tensors with a leading batch dimension
)

print("Encoding keys:", encoding.keys())
print(f"\nInput IDs shape: {encoding['input_ids'].shape}")
print(f"Attention mask shape: {encoding['attention_mask'].shape}")
print(f"Token type IDs shape: {encoding['token_type_ids'].shape}")

In [None]:
# Map the ids of the (single) encoded pair back to text to inspect the layout
first_example_ids = encoding['input_ids'][0]
decoded = tokenizer.decode(first_example_ids)

print("Decoded tokens (first 300 chars):")
preview = decoded[:300]
print(preview, "...")

## 5. Understanding Token Type IDs

Token type IDs distinguish question tokens (0) from context tokens (1).

In [None]:
# Show the first 30 tokens of the encoded pair next to their token type id
tokens = tokenizer.convert_ids_to_tokens(encoding['input_ids'][0])
token_type_ids = encoding['token_type_ids'][0].tolist()

print("Token | Type (0=Question, 1=Context)")
print("="*50)
# Walk tokens and type ids in lockstep; slicing caps the table at 30 rows
for i, (token, token_type) in enumerate(zip(tokens[:30], token_type_ids[:30])):
    if token_type == 0:
        segment = "QUESTION"
    else:
        segment = "CONTEXT"
    print(f"{i:3d}. {token:15s} | {token_type} ({segment})")

## 6. Offset Mappings - The Key to Answer Span Conversion

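Offset mappings give, for every token, the `(start, end)` character span it covers in the original text. Offsets are relative to the string the token came from (question tokens index into the question, context tokens into the context), and special tokens such as `[CLS]` and `[SEP]` are mapped to `(0, 0)`.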

In [None]:
# Tokenize with offset mappings: one (char_start, char_end) pair per token
encoding_with_offsets = tokenizer(
    question,
    context,
    # Truncate only the context (second sequence) so the question is never cut;
    # this matches the strategy used in the long-context/stride section.
    truncation='only_second',
    max_length=384,
    return_offsets_mapping=True,
    return_tensors='pt'
)

# Index 0 selects the single example in the batch
offset_mapping = encoding_with_offsets['offset_mapping'][0]
print(f"Offset mapping shape: {offset_mapping.shape}")
print(f"\nFirst 20 offset mappings (char start, char end):")
print(offset_mapping[:20])

## 7. Converting Character Positions to Token Positions

This is crucial for finding the answer span in tokens.

In [None]:
def find_answer_span(answer_start_char, answer_text, offset_mapping, sequence_ids):
    """
    Convert a character-level answer position to token-level positions.

    Only tokens whose sequence id is 1 (the context) are considered, so
    question tokens and special tokens can never be matched even though
    their offsets may numerically overlap the answer position.

    Args:
        answer_start_char: Character position in the context where the answer starts
        answer_text: The answer text; its length gives the exclusive end position
        offset_mapping: Sequence (tensor or list) of (start, end) character
            offsets, one pair per token
        sequence_ids: Per-token sequence id; tokens with any value other than 1
            (e.g. 0 for question tokens, None for special tokens) are skipped

    Returns:
        (start_token_idx, end_token_idx): index of the context token containing
        the first answer character and index of the context token containing
        the last answer character (inclusive). Either value is None when no
        context token covers that position (for example, the answer was
        truncated away). Both are None when answer_text is empty.
    """
    # A zero-length answer has no last character; matching it could otherwise
    # yield an end token that precedes the start token.
    if not answer_text:
        return None, None

    answer_end_char = answer_start_char + len(answer_text)

    start_token_idx = None
    end_token_idx = None

    # Single pass: remember the first context token covering each boundary
    for idx, (tok_start, tok_end) in enumerate(offset_mapping):
        if sequence_ids[idx] != 1:
            continue

        # Normalise tensor scalars to plain ints before comparing
        tok_start, tok_end = int(tok_start), int(tok_end)

        # Start boundary: token span is [tok_start, tok_end)
        if start_token_idx is None and tok_start <= answer_start_char < tok_end:
            start_token_idx = idx

        # End boundary is exclusive, so it must fall in (tok_start, tok_end]
        if end_token_idx is None and tok_start < answer_end_char <= tok_end:
            end_token_idx = idx

        if start_token_idx is not None and end_token_idx is not None:
            break

    return start_token_idx, end_token_idx

# Per-token sequence ids of the first (only) encoded example; the span search
# uses them to restrict matching to context tokens
sequence_ids = encoding_with_offsets.sequence_ids(0)

# Translate the character-level answer position into token indices
start_token, end_token = find_answer_span(
    answer_start_char=answer_start,
    answer_text=answer_text,
    offset_mapping=offset_mapping,
    sequence_ids=sequence_ids,
)

print(f"Answer: '{answer_text}'")
print(f"Character position: {answer_start} to {answer_start + len(answer_text)}")
print(f"\nToken position: {start_token} to {end_token}")

In [None]:
# Sanity check: decode the located token span and compare it with the answer text
if start_token is None or end_token is None:
    print("Could not find answer span in tokens")
else:
    # end_token is inclusive, hence the +1 on the slice bound
    answer_span = slice(start_token, end_token + 1)
    answer_token_ids = encoding_with_offsets['input_ids'][0][answer_span]
    decoded_answer = tokenizer.decode(answer_token_ids)

    print("Verification:")
    print(f"Original answer: '{answer_text}'")
    print(f"Decoded from tokens: '{decoded_answer}'")
    # Case-insensitive containment, because the decoded text may differ in case
    print(f"\nMatch: {answer_text.lower() in decoded_answer.lower()}")

## 8. Visualize Token Alignment

In [None]:
# Print the answer tokens with a little surrounding context and their offsets
span_found = start_token is not None and end_token is not None
if span_found:
    tokens = tokenizer.convert_ids_to_tokens(encoding_with_offsets['input_ids'][0])

    print("Tokens around the answer:")
    print("="*80)

    # Window: up to 5 tokens on each side of the answer, clamped to the sequence
    window_start = max(0, start_token - 5)
    window_end = min(len(tokens), end_token + 6)

    for idx, token in enumerate(tokens[window_start:window_end], start=window_start):
        # Answer tokens get a ">>> " prefix, the rest are padded to align
        marker = ">>> " if start_token <= idx <= end_token else "    "
        print(f"{marker}{idx:3d}. {token:20s} {offset_mapping[idx]}")

## 9. Handling Long Contexts (Stride)

When the context exceeds `max_length`, we use a stride to split it into overlapping windows.

In [None]:
# Example with a longer context
long_context = context * 3  # Artificially create a long context

MAX_LENGTH = 384  # tokens per chunk
DOC_STRIDE = 128  # tokens of overlap between consecutive chunks

# Tokenize with stride: the context is split into several overlapping chunks,
# each paired with the full question
encoding_with_stride = tokenizer(
    question,
    long_context,
    truncation='only_second',  # Only truncate context, not question
    max_length=MAX_LENGTH,
    stride=DOC_STRIDE,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding='max_length'
)

chunk_input_ids = encoding_with_stride['input_ids']

print(f"Number of chunks created: {len(chunk_input_ids)}")
# No return_tensors was requested, so each chunk is a plain Python list:
# measure it with len() (lists have no .shape attribute).
print(f"\nEach chunk has {len(chunk_input_ids[0])} tokens")
print(f"Stride (overlap): {DOC_STRIDE} tokens")

## 10. Summary and Key Takeaways

In [None]:
# Recap of the notebook, held in one literal so the printed text is easy to edit
summary_text = """Key Concepts for BERT Tokenization in QA:

1. **Special Token Format**: [CLS] question [SEP] context [SEP]

2. **Token Type IDs**:
   - 0 = question tokens
   - 1 = context tokens

3. **Offset Mappings**:
   - Maps each token to its character position in original text
   - Essential for converting answer_start (char) to token positions

4. **Answer Span Conversion**:
   - Character position (answer_start) → Token position (start_token_idx)
   - Use offset_mapping to find which tokens contain the answer

5. **Handling Long Contexts**:
   - Use truncation='only_second' to preserve question
   - Use stride for overlapping windows
   - Set return_overflowing_tokens=True

6. **Sequence IDs**:
   - Use encoding.sequence_ids() to identify context vs question
   - Important for ensuring answer is only in context portion

Next Steps:
- Build PyTorch Dataset that handles this tokenization
- Implement answer span finding logic in batch processing
- Create DataLoader for training
"""
print(summary_text)