# Debug Chinese Translation Truncation Issue

**Problem:** Chinese translations are being truncated at the first comma.

**Goal:** Diagnose whether the issue is in:
1. Model generation (stopping too early)
2. Token decoding (decode function issue)
3. Tokenizer configuration (comma treated as EOS)

**Test cases:** Sentence pairs 702, 1107, and 616 from the user-reported examples.

In [None]:
# Import libraries
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import numpy as np

In [None]:
# Load the NLLB-1.3B checkpoint from the local models directory. Eager attention
# is requested explicitly -- presumably so attention weights can be returned when
# output_attentions=True is used (as in the notebook-07 generate call); TODO confirm.
model_dir = "../models/nllb-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, attn_implementation="eager")

# Pick the best available backend: CUDA first, then Apple MPS, otherwise CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

model = model.to(device)
print(f"Model loaded on device: {device}")

## Test Cases from User Examples

In [None]:
# English source sentences for the three user-reported pairs. The wording is
# simplified so the literals need no quote escaping. Each entry is a dict with
# 'id' (pair number) and 'en' (English text).
test_cases = [
    {'id': pair_id, 'en': english}
    for pair_id, english in (
        (702, "At 10:00pm, Sun Yijie, who had been pregnant for four months, was released on bail of NT$200,000."),
        (1107, "It is boundless. If you need and are brave enough to initiate crowdfunding, everything will become possible."),
        (616, "It was Mr. Dong's real intention to sign the agreement, which was legal and effective after being signed and sealed by all the parties concerned."),
    )
]

print(f"Loaded {len(test_cases)} test cases")

## Test 1: Current Implementation (max_length=128)

In [None]:
def test_translation_current(text, tokenizer, model, device):
    print("="*80)
    print("TEST 1: Current Implementation (max_length=128)")
    print("="*80)
    
    tokenizer.src_lang = "eng_Latn"
    inputs = tokenizer(text, return_tensors="pt").to(device)
    
    print(f"Input text: {text}")
    print(f"Input tokens: {len(inputs.input_ids[0])} tokens")
    print()
    
    # Generate with current parameters
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids("zho_Hans"),
            max_length=128
        )
    
    # Show raw output
    print("Raw output sequence:")
    print(f"  Length: {len(outputs[0])} tokens")
    print(f"  Token IDs (first 20): {outputs[0].tolist()[:20]}")
    print()
    
    # Show tokens
    output_tokens = tokenizer.convert_ids_to_tokens(outputs[0])
    print(f"Output tokens ({len(output_tokens)} total):")
    for i, token in enumerate(output_tokens[:30]):
        print(f"  [{i}] {repr(token)}")
    if len(output_tokens) > 30:
        print(f"  ... ({len(output_tokens) - 30} more tokens)")
    print()
    
    # Decode
    translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Decoded translation: {translation}")
    print(f"Translation length: {len(translation)} characters")
    
    # Check for truncation (ends with comma)
    if translation.rstrip().endswith(','):
        print("WARNING: Translation ends with comma - likely truncated!")
    
    return translation

# Baseline run on the first user sentence; later cells reuse `test_case`.
test_case = test_cases[0]
result = test_translation_current(
    text=test_case['en'],
    tokenizer=tokenizer,
    model=model,
    device=device,
)

## Test 2: Increased max_length (256)

In [None]:
def test_translation_longer(text, tokenizer, model, device):
    print("="*80)
    print("TEST 2: Increased max_length (256)")
    print("="*80)
    
    tokenizer.src_lang = "eng_Latn"
    inputs = tokenizer(text, return_tensors="pt").to(device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids("zho_Hans"),
            max_length=256
        )
    
    output_tokens = tokenizer.convert_ids_to_tokens(outputs[0])
    translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    print(f"Output length: {len(outputs[0])} tokens")
    print(f"Decoded translation: {translation}")
    
    if translation.rstrip().endswith(','):
        print("WARNING: Still ends with comma!")
    else:
        print("OK: Does not end with comma")
    
    return translation

# Same first sentence as Test 1, now with a 256-token cap.
result2 = test_translation_longer(
    text=test_case['en'], tokenizer=tokenizer, model=model, device=device
)

## Test 3: Use max_new_tokens instead of max_length

In [None]:
def test_translation_max_new_tokens(text, tokenizer, model, device):
    print("="*80)
    print("TEST 3: max_new_tokens=200 (better for variable input lengths)")
    print("="*80)
    
    tokenizer.src_lang = "eng_Latn"
    inputs = tokenizer(text, return_tensors="pt").to(device)
    
    print(f"Input length: {len(inputs.input_ids[0])} tokens")
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids("zho_Hans"),
            max_new_tokens=200
        )
    
    output_tokens = tokenizer.convert_ids_to_tokens(outputs[0])
    translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    print(f"Output length: {len(outputs[0])} tokens (input + new)")
    print(f"Decoded translation: {translation}")
    
    if translation.rstrip().endswith(','):
        print("WARNING: Still ends with comma!")
    else:
        print("OK: Translation appears complete")
    
    return translation

# Same first sentence, bounded by new-token count instead of total length.
result3 = test_translation_max_new_tokens(
    text=test_case['en'], tokenizer=tokenizer, model=model, device=device
)

## Test 4: Check for EOS token issues

In [None]:
def test_eos_tokens(tokenizer):
    print("="*80)
    print("TEST 4: EOS Token Investigation")
    print("="*80)
    
    # Check EOS token
    print(f"EOS token: {repr(tokenizer.eos_token)}")
    print(f"EOS token ID: {tokenizer.eos_token_id}")
    print()
    
    # Check Chinese comma (full-width)
    chinese_comma = "\uff0c"  # Unicode for full-width comma
    chinese_comma_ids = tokenizer.encode(chinese_comma, add_special_tokens=False)
    print(f"Chinese comma (full-width):")
    print(f"  Character: {chinese_comma}")
    print(f"  Token IDs: {chinese_comma_ids}")
    print(f"  Tokens: {tokenizer.convert_ids_to_tokens(chinese_comma_ids)}")
    print()
    
    # Check English comma
    english_comma = ","
    english_comma_ids = tokenizer.encode(english_comma, add_special_tokens=False)
    print(f"English comma:")
    print(f"  Token IDs: {english_comma_ids}")
    print(f"  Tokens: {tokenizer.convert_ids_to_tokens(english_comma_ids)}")
    print()
    
    # Check if comma matches EOS
    if tokenizer.eos_token_id in chinese_comma_ids:
        print("CRITICAL: Chinese comma contains EOS token ID!")
    elif tokenizer.eos_token_id in english_comma_ids:
        print("CRITICAL: English comma contains EOS token ID!")
    else:
        print("OK: Commas do not match EOS token ID")
    
    # Check language tokens
    print()
    print("Language tokens:")
    print(f"  eng_Latn: {tokenizer.convert_tokens_to_ids('eng_Latn')}")
    print(f"  zho_Hans: {tokenizer.convert_tokens_to_ids('zho_Hans')}")

# Tokenizer-only check; no generation involved.
test_eos_tokens(tokenizer=tokenizer)

## Test 5: Disable early stopping and add generation constraints

In [None]:
def test_translation_no_early_stop(text, tokenizer, model, device):
    print("="*80)
    print("TEST 5: Disable early stopping + explicit constraints")
    print("="*80)
    
    tokenizer.src_lang = "eng_Latn"
    inputs = tokenizer(text, return_tensors="pt").to(device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids("zho_Hans"),
            max_new_tokens=200,
            num_beams=1,
            do_sample=False,
            early_stopping=False
        )
    
    output_tokens = tokenizer.convert_ids_to_tokens(outputs[0])
    translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    print(f"Output length: {len(outputs[0])} tokens")
    print(f"Last 5 tokens: {output_tokens[-5:]}")
    print(f"Decoded translation: {translation}")
    
    if translation.rstrip().endswith(','):
        print("WARNING: Still ends with comma!")
    else:
        print("OK: Translation appears complete")
    
    return translation

# Same first sentence with explicit greedy-decoding settings.
result5 = test_translation_no_early_stop(
    text=test_case['en'], tokenizer=tokenizer, model=model, device=device
)

## Test All Cases with Best Parameters

In [None]:
print("=" * 80)
print("TESTING ALL CASES WITH BEST PARAMETERS")
print("=" * 80)
print()

# The target-language BOS id is the same for every case, so resolve it once.
zho_hans_id = tokenizer.convert_tokens_to_ids("zho_Hans")

for test_case in test_cases:
    print("\n" + "=" * 80)
    print(f"Test Case {test_case['id']}")
    print("=" * 80)
    print(f"English: {test_case['en']}")
    print()

    tokenizer.src_lang = "eng_Latn"
    inputs = tokenizer(test_case['en'], return_tensors="pt").to(device)

    with torch.no_grad():
        # Baseline: the settings notebook 07 uses today.
        outputs_old = model.generate(
            **inputs,
            forced_bos_token_id=zho_hans_id,
            max_length=128,
        )
        # Candidate fix: explicit greedy decoding bounded by new-token count.
        outputs_new = model.generate(
            **inputs,
            forced_bos_token_id=zho_hans_id,
            max_new_tokens=200,
            num_beams=1,
            do_sample=False,
            early_stopping=False,
        )

    translation_old = tokenizer.decode(outputs_old[0], skip_special_tokens=True)
    translation_new = tokenizer.decode(outputs_new[0], skip_special_tokens=True)

    print(f"OLD (max_length=128):   {translation_old}")
    print(f"NEW (max_new_tokens=200): {translation_new}")
    print()
    print(f"OLD length: {len(outputs_old[0])} tokens, {len(translation_old)} chars")
    print(f"NEW length: {len(outputs_new[0])} tokens, {len(translation_new)} chars")

    # Character length is used as a rough proxy for "less truncated".
    char_gain = len(translation_new) - len(translation_old)
    if char_gain > 0:
        print(f"IMPROVED: +{char_gain} characters")
    elif translation_new == translation_old:
        print("NO CHANGE")
    else:
        print("DIFFERENT (not necessarily better)")

## Summary and Recommendations

In [None]:
# Print the diagnostic recap and the suggested generate() change for notebook 07
# as one block of text (each list entry is one output line).
# NOTE(review): per the transformers generation docs, early_stopping only affects
# beam search, and num_beams=1 / do_sample=False are likely already the defaults.
# The substantive change in this recommendation is probably
# max_length=128 -> max_new_tokens=200. Confirm against the Test 1-5 output
# before applying.
print("\n".join([
    "=" * 80,
    "DIAGNOSTIC SUMMARY",
    "=" * 80,
    "",
    "Tests performed:",
    "1. Current implementation (max_length=128)",
    "2. Increased max_length (256)",
    "3. Using max_new_tokens instead of max_length",
    "4. EOS token investigation",
    "5. Disabled early stopping with constraints",
    "",
    "RECOMMENDED FIX for notebook 07:",
    "-" * 80,
    "Replace:",
    "    outputs = model.generate(",
    "        **inputs,",
    "        forced_bos_token_id=tgt_lang_id,",
    "        output_attentions=True,",
    "        return_dict_in_generate=True,",
    "        max_length=128",
    "    )",
    "",
    "With:",
    "    outputs = model.generate(",
    "        **inputs,",
    "        forced_bos_token_id=tgt_lang_id,",
    "        output_attentions=True,",
    "        return_dict_in_generate=True,",
    "        max_new_tokens=200,",
    "        num_beams=1,",
    "        do_sample=False,",
    "        early_stopping=False",
    "    )",
    "-" * 80,
    "",
    "NOTE: After fixing, you will need to re-run notebook 07 to regenerate",
    "      all translations for the 2000 sentence pairs.",
]))