In [None]:
"""
Script to decode token IDs from example.txt using Qwen3 tokenizer.
Handles special memory tokens (IDs larger than vocab size) by replacing them with '<MT>'.
Handles label masking tokens (-100) by replacing them with '<-100>'.
"""

import ast
from transformers import AutoTokenizer
from typing import List

def safe_decode_with_mem_tokens(tokenizer, token_ids: List[int]) -> str:
    """
    Decode token IDs to text, substituting placeholders for non-vocabulary IDs.

    Each ID is mapped to a token string individually and the resulting token
    list is joined with ``tokenizer.convert_tokens_to_string``:

    * ``-100`` (label masking value) becomes ``'<-100>'``.
    * IDs above ``tokenizer.vocab_size + 50`` become ``'<MT>'`` (memory token).
    * IDs at or below that threshold are looked up with
      ``tokenizer.convert_ids_to_tokens``; if the lookup yields ``None``
      (the tokenizer has no token for that ID) the ID is also rendered as
      ``'<MT>'`` instead of passing ``None`` on to the string conversion.

    Args:
        tokenizer: HuggingFace tokenizer
        token_ids: List of token IDs to decode

    Returns:
        Decoded string with memory tokens replaced by '<MT>' and -100 replaced by '<-100>'
    """
    # NOTE(review): the +50 headroom presumably covers special tokens added on
    # top of the base vocabulary (vocab_size excludes them) - confirm.
    max_lookup_id = tokenizer.vocab_size + 50

    tokens = []
    for token_id in token_ids:
        if token_id == -100:
            # Label masking token
            tokens.append("<-100>")
            continue

        token = None
        if token_id <= max_lookup_id:
            token = tokenizer.convert_ids_to_tokens([token_id])[0]

        # Either the ID is beyond the lookup range or the tokenizer does not
        # know it: in both cases treat it as a memory token.
        tokens.append("<MT>" if token is None else token)

    # Join token strings back into text; the placeholders pass through as-is.
    return tokenizer.convert_tokens_to_string(tokens)


In [None]:
# Qwen3 tokenizer from HuggingFace
print("Loading Qwen3 tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)

# Read the example.txt file
print("Reading example.txt...")
with open('example.txt', 'r', encoding='utf-8') as f:
    content = f.read()

# Parse the content (assuming it's a dictionary-like string without the outer braces)
print("Parsing token IDs...")
try:
    # Try to evaluate as a Python literal
    data = ast.literal_eval("{" + content + "}")
except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
    # literal_eval rejected the text (these are the exceptions it documents);
    # fall back to pulling the known fields out with regular expressions.
    import re
    data = {}

    # (field name, characters allowed inside the brackets); only 'labels'
    # admits '-' because it may contain the -100 masking value.
    field_specs = (
        ("input_ids", r"\d,\s"),
        ("prompt_answer_ids", r"\d,\s"),
        ("labels", r"\d,\s-"),
    )
    for field_name, allowed_chars in field_specs:
        field_match = re.search(rf"'{field_name}':\s*\[([{allowed_chars}]+)\]", content)
        if field_match:
            data[field_name] = [
                int(x.strip()) for x in field_match.group(1).split(',') if x.strip()
            ]

print(f"\nVocabulary size: {tokenizer.vocab_size}")
print("=" * 80)

# len(tokenizer) includes added special tokens, whereas vocab_size is only the
# base vocabulary; IDs at or beyond it have no token and are memory tokens.
known_token_count = len(tokenizer)

# Decode each field
for key, token_ids in data.items():
    if isinstance(token_ids, list) and token_ids:
        print(f"\n{key}:")
        print(f"  Length: {len(token_ids)} tokens")

        # Count memory tokens
        mem_token_count = sum(1 for tid in token_ids if tid >= known_token_count)
        if mem_token_count > 0:
            print(f"  Memory tokens: {mem_token_count}")

        # Decode
        decoded = safe_decode_with_mem_tokens(tokenizer, token_ids)
        print("  Decoded text:")
        print(decoded)
        print()
