# Remove Unwanted Tokens from Translation Predictions

This notebook cleans translation predictions by removing unwanted tokens that the model may generate, such as:
- `<think>` tags and their contents
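- Other meta-tokens or reasoning artifacts, such as conversational preambles at the start of the output (e.g. `Okay,`, `Here is`, `Translation:`)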

These tokens should not be part of the actual translation output.

### Configuration

In [None]:
import json
import re
import os

# Input predictions to clean. The cleaned copy goes to the same directory,
# with "_cleaned" inserted before the file extension.
INPUT_FILE = "../02_inference/predictions_translation_test_B_multitask.jsonl"
OUTPUT_FILE = "{0}_cleaned{1}".format(*os.path.splitext(INPUT_FILE))

print(f"Input file: {INPUT_FILE}\nOutput file: {OUTPUT_FILE}")

### Define Cleaning Function

In [None]:
# Conversational preambles the model sometimes emits before the actual
# translation (add more patterns as needed). They are tried once each, in
# this order, against the start of the text.
_UNWANTED_PREFIXES = [
    'think',
    'Think',
    'Okay,',
    'Sure,',
    'Let me',
    'I will',
    'Here is',
    'Translation:',
    'Output:',
]


def _compile_prefix_pattern(prefix):
    """Build the regex that strips one unwanted ``prefix`` from the start of a text.

    - A prefix that already ends in ``:`` is a label (e.g. "Translation:").
      Only the label and the whitespace after it are removed, so a translation
      on the same line survives.
    - Any other prefix starts a preamble sentence. The prefix plus the rest of
      that sentence, up to and including the first ``:`` or newline, is removed.
    - A prefix ending in a letter or digit must end on a word boundary, so
      'think' does not match 'thinking' and 'I will' does not match 'I willingly'.
    """
    escaped = re.escape(prefix)
    if prefix.endswith(':'):
        return re.compile(rf'^{escaped}\s*')
    guard = r'(?!\w)' if prefix[-1:].isalnum() else ''
    return re.compile(rf'^{escaped}{guard}[^\n:]*[:\n]?\s*')


_PREFIX_PATTERNS = [_compile_prefix_pattern(p) for p in _UNWANTED_PREFIXES]
_THINK_BLOCK_RE = re.compile(r'<think>.*?</think>', flags=re.DOTALL)
_EXCESS_NEWLINES_RE = re.compile(r'\n{3,}')


def clean_translation(text):
    """
    Remove unwanted tokens from translation output.

    Steps, in order:
    1. Drop complete ``<think>...</think>`` blocks, including multiline ones.
    2. If a ``</think>`` is still present, its opening tag was never generated
       (for example, it was part of the prompt). Everything up to and including
       the last ``</think>`` is treated as reasoning and dropped.
    3. Remove any stray ``<think>`` tag. Its content is kept, because an
       unclosed block may hold the only translation text.
    4. Strip the conversational preambles listed in ``_UNWANTED_PREFIXES``.
       Each prefix is tried once, in list order, against the current text.
    5. Trim surrounding whitespace and collapse 3+ consecutive newlines to 2.

    Note: the prefix heuristic is lossy. A genuine translation whose first
    line starts with, e.g., "I will" or "Here is" loses that line.

    Args:
        text (str): Raw translation output

    Returns:
        str: Cleaned translation
    """
    text = _THINK_BLOCK_RE.sub('', text)

    # Dangling closing tag: the reasoning precedes it without an opening tag.
    closing = '</think>'
    last_close = text.rfind(closing)
    if last_close != -1:
        text = text[last_close + len(closing):]

    text = text.replace('<think>', '')

    for pattern in _PREFIX_PATTERNS:
        # The patterns are anchored with ^, so at most one match at the start.
        text = pattern.sub('', text.strip(), count=1)

    text = text.strip()
    return _EXCESS_NEWLINES_RE.sub('\n\n', text)


# Test the function on hand-written inputs whose expected output is known.
test_cases = [
    "<think>\n\n</think>\n\nThis is the translation.",
    "think about it\n\nThis is the translation.",
    "Okay, here is the translation:\nThis is the translation.",
    "This is already clean.",
    "<think>Some reasoning</think>\n\n\nActual translation here.",
]
# Expected cleaned output for each entry of test_cases (same order).
expected_outputs = [
    "This is the translation.",
    "This is the translation.",
    "This is the translation.",
    "This is already clean.",
    "Actual translation here.",
]

print("Testing cleaning function:")
print("="*80)
failures = []
for i, (test, expected) in enumerate(zip(test_cases, expected_outputs), start=1):
    cleaned = clean_translation(test)
    status = "OK" if cleaned == expected else "FAIL"
    print(f"\nTest {i}: {status}")
    # repr() makes newlines visible, which is what this cleaner manipulates.
    print(f"  Input:  {test[:60]!r}")
    print(f"  Output: {cleaned[:60]!r}")
    if cleaned != expected:
        print(f"  Expected: {expected[:60]!r}")
        failures.append(i)

# Stop the notebook on a regression, before any predictions are rewritten.
# This uses an explicit raise rather than assert, so it also fires under -O.
if failures:
    raise AssertionError(f"clean_translation failed test case(s): {failures}")

### Load and Analyze Predictions

In [None]:
# Load predictions: one JSON object per line. Blank lines (e.g. an extra
# newline at EOF) are skipped rather than crashing json.loads.
with open(INPUT_FILE, 'r', encoding='utf-8') as f:
    predictions = [json.loads(line) for line in f if line.strip()]

print(f"Loaded {len(predictions)} predictions")

# Analyze how many need cleaning, i.e. clean_translation changes the text.
needs_cleaning = sum(
    1 for pred in predictions if clean_translation(pred['target']) != pred['target']
)

# Guard the percentage so an empty input file does not raise ZeroDivisionError.
needs_cleaning_pct = needs_cleaning / len(predictions) * 100 if predictions else 0.0
print(f"Predictions that need cleaning: {needs_cleaning} ({needs_cleaning_pct:.1f}%)")

### Show Examples Before/After Cleaning

In [None]:
# Show a few predictions that cleaning actually changes, before and after.
MAX_EXAMPLES = 5

print("="*80)
print("EXAMPLES OF CLEANING")
print("="*80)

examples_shown = 0
for pred in predictions:
    if examples_shown >= MAX_EXAMPLES:
        break
    target = pred['target']
    cleaned = clean_translation(target)
    if target == cleaned:
        continue

    examples_shown += 1
    print(f"\nExample {examples_shown} (ID: {pred['ipis_id']}):")
    print(f"Direction: {pred['source_language']} → {pred['target_language']}")
    print("\nBEFORE (first 200 chars):")
    print(target[:200])
    print("\nAFTER (first 200 chars):")
    print(cleaned[:200])
    print("-"*80)

if examples_shown == 0:
    print("\nNo predictions needed cleaning.")

### Clean All Predictions

In [None]:
# Clean all predictions and collect simple before/after statistics.
cleaned_predictions = []
cleaning_stats = {
    'total': len(predictions),
    'cleaned': 0,
    'unchanged': 0,
    'avg_length_before': 0,
    'avg_length_after': 0,
}

total_length_before = 0
total_length_after = 0

for pred in predictions:
    original_target = pred['target']
    cleaned_target = clean_translation(original_target)

    total_length_before += len(original_target)
    total_length_after += len(cleaned_target)

    if original_target != cleaned_target:
        cleaning_stats['cleaned'] += 1
    else:
        cleaning_stats['unchanged'] += 1

    # Create new prediction with only required fields (matching sample_translation.tsv format)
    cleaned_pred = {
        'ipis_id': pred['ipis_id'],
        'source': pred['source'],
        'target': cleaned_target
    }

    cleaned_predictions.append(cleaned_pred)

# Guard every division so an empty input reports zeros instead of raising
# ZeroDivisionError.
n_total = cleaning_stats['total']
if n_total:
    cleaning_stats['avg_length_before'] = total_length_before / n_total
    cleaning_stats['avg_length_after'] = total_length_after / n_total
cleaned_pct = cleaning_stats['cleaned'] / n_total * 100 if n_total else 0.0
unchanged_pct = cleaning_stats['unchanged'] / n_total * 100 if n_total else 0.0

print("Cleaning Statistics:")
print("="*80)
print(f"Total predictions: {cleaning_stats['total']}")
print(f"Cleaned: {cleaning_stats['cleaned']} ({cleaned_pct:.1f}%)")
print(f"Unchanged: {cleaning_stats['unchanged']} ({unchanged_pct:.1f}%)")
print(f"Avg length before: {cleaning_stats['avg_length_before']:.1f} chars")
print(f"Avg length after: {cleaning_stats['avg_length_after']:.1f} chars")
print(f"Avg reduction: {cleaning_stats['avg_length_before'] - cleaning_stats['avg_length_after']:.1f} chars")

### Save Cleaned Predictions

In [None]:
# Serialize once as JSONL (one JSON object per line, non-ASCII kept as-is).
# The same text is written to both output files below.
serialized = ''.join(
    json.dumps(pred, ensure_ascii=False) + '\n' for pred in cleaned_predictions
)

with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
    f.write(serialized)

print(f"Cleaned predictions saved to: {OUTPUT_FILE}")
print(f"Total predictions: {len(cleaned_predictions)}")

# Also save with a .tsv extension for the PolEval submission. The content is
# still JSONL; only the extension differs.
# NOTE(review): this assumes the expected submission format really is JSONL
# despite the .tsv name. Confirm against sample_translation.tsv.
# splitext swaps only the final extension. The previous str.replace approach
# also rewrote '.jsonl' elsewhere in the path, and wrote over OUTPUT_FILE
# itself when it lacked a '.jsonl' suffix.
OUTPUT_FILE_TSV = os.path.splitext(OUTPUT_FILE)[0] + '.tsv'
with open(OUTPUT_FILE_TSV, 'w', encoding='utf-8') as f:
    f.write(serialized)

print(f"Also saved as: {OUTPUT_FILE_TSV}")

### Verify Output

In [None]:
# Load and verify the cleaned file by re-reading what was written.
with open(OUTPUT_FILE, 'r', encoding='utf-8') as f:
    loaded_cleaned = [json.loads(line) for line in f]

print("Verification:")
print("="*80)
print(f"Original predictions: {len(predictions)}")
print(f"Cleaned predictions: {len(loaded_cleaned)}")
print(f"Match: {len(predictions) == len(loaded_cleaned)}")

# Verify format matches sample_translation.tsv: same field names in the same
# order, checked on every record rather than only the first one.
print("\nFormat verification:")
expected_fields = ['ipis_id', 'source', 'target']
if loaded_cleaned:
    actual_fields = list(loaded_cleaned[0].keys())
    print(f"Expected fields: {expected_fields}")
    print(f"Actual fields: {actual_fields}")
    bad_records = [
        idx for idx, record in enumerate(loaded_cleaned)
        if list(record.keys()) != expected_fields
    ]
    print(f"Format matches: {not bad_records}")
    if bad_records:
        print(f"Records with unexpected fields: {len(bad_records)} "
              f"(first at index {bad_records[0]})")

# Show a few random examples (unseeded, so the sample differs per run)
import random
print("\nRandom sample of cleaned predictions:")
print("="*80)
for i in random.sample(range(len(loaded_cleaned)), min(3, len(loaded_cleaned))):
    pred = loaded_cleaned[i]
    print(f"\nID: {pred['ipis_id']}")
    print(f"Source (first 100 chars): {pred['source'][:100]}...")
    print(f"Target (first 150 chars): {pred['target'][:150]}...")
    print("-"*80)