# Test Beam Search for Chinese Translation

Quick test to see if beam search avoids the comma truncation issue.

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Load the tokenizer and seq2seq model from the local checkpoint directory.
model_dir = "../models/nllb-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_dir,
    attn_implementation="eager",
)

# Pick the accelerator in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

model = model.to(device)
print(f"Device: {device}")

In [None]:
# Sample English sentence shared by every decoding test below.
english = "At 10:00pm, Sun Yijie, who had been pregnant for four months, was released on bail of NT$200,000."

# Set the tokenizer's source-language code before encoding, then move the
# encoded tensors onto the same device as the model.
tokenizer.src_lang = "eng_Latn"
encoded = tokenizer(english, return_tensors="pt")
inputs = encoded.to(device)

print(f"Input: {english}")
print()

In [None]:
# Test 1: Greedy (current approach)
# Baseline run: translate the sample sentence to Chinese (target code
# zho_Hans) and report whether the decoded text stops on a comma.
print("=" * 80)
print("GREEDY DECODING (current)")
print("=" * 80)

# No num_beams is passed here, so the decoding strategy comes from the model's
# generation defaults.  NOTE(review): confirm those defaults are greedy.
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids("zho_Hans"),
        max_new_tokens=200
    )

translation = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Chinese text typically uses the full-width comma (，) or the enumeration
# comma (、) rather than ASCII ",", so checking only "," would miss a
# translation that was cut off at a Chinese comma.
ends_with_comma = translation.rstrip().endswith((",", "，", "、"))

print(f"Output: {translation}")
# Counted on the raw generated ids, so any special tokens are included here
# even though they are stripped from the decoded text above.
print(f"Length: {len(outputs[0])} tokens")
print(f"Ends with comma: {ends_with_comma}")
print()

In [None]:
# Test 2: Beam Search (num_beams=5)
# Same translation request as the baseline, but decoded with 5 beams to see
# whether the output still stops on a comma.
print("=" * 80)
print("BEAM SEARCH (num_beams=5)")
print("=" * 80)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids("zho_Hans"),
        max_new_tokens=200,
        num_beams=5
    )

translation = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Chinese text typically uses the full-width comma (，) or the enumeration
# comma (、) rather than ASCII ",", so checking only "," would miss a
# translation that was cut off at a Chinese comma.
ends_with_comma = translation.rstrip().endswith((",", "，", "、"))

print(f"Output: {translation}")
# Counted on the raw generated ids, so any special tokens are included here
# even though they are stripped from the decoded text above.
print(f"Length: {len(outputs[0])} tokens")
print(f"Ends with comma: {ends_with_comma}")
print()

In [None]:
# Test 3: Beam Search + no_repeat_ngram
# Beam search with 5 beams plus no_repeat_ngram_size=3, to see whether adding
# the n-gram repetition constraint changes where the output stops.
print("=" * 80)
print("BEAM SEARCH + no_repeat_ngram_size=3")
print("=" * 80)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids("zho_Hans"),
        max_new_tokens=200,
        num_beams=5,
        no_repeat_ngram_size=3
    )

translation = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Chinese text typically uses the full-width comma (，) or the enumeration
# comma (、) rather than ASCII ",", so checking only "," would miss a
# translation that was cut off at a Chinese comma.
ends_with_comma = translation.rstrip().endswith((",", "，", "、"))

print(f"Output: {translation}")
# Counted on the raw generated ids, so any special tokens are included here
# even though they are stripped from the decoded text above.
print(f"Length: {len(outputs[0])} tokens")
print(f"Ends with comma: {ends_with_comma}")
print()

## Result

If beam search produces longer output: ✅ Use beam search in notebook 07  
If all three runs produce the same truncation: ❌ This is a fundamental NLLB model issue