# Gender-Inclusive Translation Inference - Test B

This notebook runs inference for Task B (gender-sensitive Polish⇄English translation) using the LoRA checkpoint fine-tuned for the proofreading task; the model received no translation-specific fine-tuning.

### Configuration

In [None]:
import os

In [None]:
# Model configuration - hyperparameters of the proofreading training run.
# Apart from MAX_SEQ_LENGTH (also passed to the model loader), these values are
# only used to rebuild the training run's output directory name below, so they
# must match the training run exactly for MODEL_PATH to resolve.
MODEL_SIZE = "8B"
LORA_RANK = 64
EPOCHS = 2
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 2
LEARNING_RATE = 2e-4  # renders as "0.0002" in the directory name
WARMUP_STEPS = 10
MAX_SEQ_LENGTH = 4096
CHECKPOINT_STEP = 23000  # best proofreading checkpoint of that run

# Path to the trained proofreading model - using the best checkpoint
RUN_NAME = (
    f"qwen3_{MODEL_SIZE}_polish_inclusive_proofreading_lora_r{LORA_RANK}"
    f"_lr{LEARNING_RATE}_ep{EPOCHS}_bs{BATCH_SIZE}_ga{GRADIENT_ACCUMULATION_STEPS}"
    f"_warmup{WARMUP_STEPS}_seq{MAX_SEQ_LENGTH}"
)
MODEL_PATH = os.path.abspath(f"../../../outputs/{RUN_NAME}/checkpoint-{CHECKPOINT_STEP}")
print(f"Using model from: {MODEL_PATH}")

# Inference parameters
TEMPERATURE = 0.3  # Lower temperature for more precise translations
TOP_P = 0.9
TOP_K = 50
MAX_NEW_TOKENS = 4096  # Enough for longer texts

# File paths
TEST_FILE = "../../../data/taskB/test_B.jsonl"
OUTPUT_FILE = "predictions_translation_test_B.jsonl"

### Setup Environment

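**IMPORTANT:** After running this cell, restart the kernel and then run all cells from the beginning. The cache environment variables only take effect if they are set before any HuggingFace library (e.g. via `unsloth`) is imported in the kernel.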

In [None]:
# Fix HuggingFace cache permissions: point all HuggingFace / Triton caches at a
# writable location. This must run before the model-loading cell imports unsloth.

import os
import warnings

# Single root for every cache directory. Defaults to the original machine path;
# set the POLEVAL_CACHE_ROOT environment variable to use another location.
CACHE_ROOT = os.environ.get('POLEVAL_CACHE_ROOT', '/media/adam/NVME_500/poleval-gender-new/.cache')

os.environ['HF_HOME'] = f'{CACHE_ROOT}/huggingface'
# NOTE: recent transformers releases deprecate TRANSFORMERS_CACHE in favour of
# HF_HOME; it is still set here for older versions.
os.environ['TRANSFORMERS_CACHE'] = f'{CACHE_ROOT}/huggingface/transformers'
os.environ['HF_DATASETS_CACHE'] = f'{CACHE_ROOT}/huggingface/datasets'
os.environ['TRITON_CACHE_DIR'] = f'{CACHE_ROOT}/triton'

# Silence UserWarnings emitted from the tqdm module.
warnings.filterwarnings('ignore', category=UserWarning, module='tqdm')

### Load the Fine-tuned Model

In [None]:
from unsloth import FastLanguageModel
import torch  # later cells use torch.inference_mode() during generation

print(f"Loading model from {MODEL_PATH}...")

# Load the fine-tuned checkpoint in 4-bit with the configured context length.
loader_options = {
    "model_name": MODEL_PATH,
    "max_seq_length": MAX_SEQ_LENGTH,
    "load_in_4bit": True,
}
model, tokenizer = FastLanguageModel.from_pretrained(**loader_options)

# Switch unsloth into inference mode for faster generation.
FastLanguageModel.for_inference(model)

print("Model loaded successfully!")

### Load System Prompts

We'll load both English and Polish system prompts and select the appropriate one based on the prompt_language field in each test example.

In [None]:
def _read_system_prompt(path):
    """Return the whitespace-stripped text of a UTF-8 system prompt file."""
    with open(path, 'r', encoding='utf-8') as handle:
        return handle.read().strip()


# One system prompt per instruction language; each example later selects one
# via its prompt_language field.
SYSTEM_PROMPT_EN = _read_system_prompt('../../../system_prompts/translation/system_prompt_en_translation')
SYSTEM_PROMPT_PL = _read_system_prompt('../../../system_prompts/translation/system_prompt_pl_translation')

prompts_by_language = (("English", SYSTEM_PROMPT_EN), ("Polish", SYSTEM_PROMPT_PL))

print("System prompts loaded.")
for language, prompt in prompts_by_language:
    print(f"{language} system prompt length: {len(prompt)} characters")
for language, prompt in prompts_by_language:
    print(f"\nFirst 200 characters of {language} prompt:")
    print(prompt[:200] + "...")

### Load Test Data

In [None]:
import json

def load_jsonl(file_path):
    """Load a JSON Lines file into a list of dictionaries.

    Blank or whitespace-only lines (e.g. an extra empty line at the end of the
    file) are skipped instead of making ``json.loads`` raise.

    Args:
        file_path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        One parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            data.append(json.loads(line))
    return data

# Load test data and echo the first example as a quick sanity check
test_data = load_jsonl(TEST_FILE)
print(f"Loaded {len(test_data)} test examples from {TEST_FILE}")

first_example = test_data[0]
print("\nFirst example:")
for label, key in (
    ("IPIS ID", 'ipis_id'),
    ("Prompt", 'prompt'),
    ("Source language", 'source_language'),
    ("Target language", 'target_language'),
):
    print(f"{label}: {first_example[key]}")
print(f"Source (first 200 chars): {first_example['source'][:200]}...")

### Analyze Dataset

Let's check the distribution of translation directions and prompt languages in the test set.

In [None]:
from collections import Counter


def _print_shares(counts, total):
    """Print every key of `counts` with its count and its percentage of `total`."""
    for key, n in counts.items():
        print(f"  {key}: {n} examples ({n/total*100:.1f}%)")


# Distribution of translation directions (e.g. "PL → EN")
translation_directions = Counter(
    "{} → {}".format(example['source_language'], example['target_language'])
    for example in test_data
)
print("Translation directions in test set:")
_print_shares(translation_directions, len(test_data))

# Distribution of the language the instruction prompt is written in
prompt_languages = Counter(example['prompt_language'] for example in test_data)
print("\nPrompt languages:")
_print_shares(prompt_languages, len(test_data))

### Try Base Model (Without Fine-tuning)

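The proofreading model was not trained for translation, so comparing it with the base Qwen model (without fine-tuning) would show how much translation ability comes from the base model itself. **TODO:** this section currently contains no code; all cells below use the fine-tuned proofreading checkpoint.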

### Generate Predictions

The generation loop (after the single-example test below) translates every text in the test set, handling both PL→EN and EN→PL directions.

### Test Translation (Single Example)

Let's test the translation on a single example first to verify it's working correctly.

In [None]:
# Indices of test_data to spot-check before the full run. Extend the list
# (e.g. with one index per translation direction) to check both directions.
test_indices = [0]

print("="*80)
print("TESTING TRANSLATION ON SAMPLE EXAMPLES")
print("="*80)

for test_idx in test_indices:
    # Skip indices that do not exist in the loaded test set instead of raising IndexError.
    if not -len(test_data) <= test_idx < len(test_data):
        print(f"\nSkipping index {test_idx}: test set has only {len(test_data)} examples")
        continue
    item = test_data[test_idx]

    print(f"\n{'='*80}")
    print(f"Example {test_idx + 1}: {item['ipis_id']}")
    print(f"Direction: {item['source_language']} → {item['target_language']}")
    print(f"Prompt language: {item['prompt_language']}")
    print(f"{'='*80}")

    # Construct the prompt: instruction text immediately followed by the source text
    user_message = item['prompt'] + item['source']

    # Select the system prompt matching the instruction language
    system_prompt = SYSTEM_PROMPT_EN if item['prompt_language'] == 'EN' else SYSTEM_PROMPT_PL

    print("\nSystem prompt (first 150 chars):")
    print(system_prompt[:150] + "...")

    print("\nUser message (first 200 chars):")
    print(user_message[:200] + "...")

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message}
    ]

    # Apply chat template (as text, with the assistant turn opened for generation)
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    print(f"\nFull prompt length: {len(text)} characters")
    print("\nFull prompt (last 300 chars before generation):")
    print("..." + text[-300:])

    # Tokenize and move to GPU
    inputs = tokenizer(text, return_tensors="pt").to("cuda")

    # Generate prediction (sampling only when TEMPERATURE > 0)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            top_k=TOP_K,
            do_sample=TEMPERATURE > 0,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the generated tokens (everything after the prompt)
    generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Fallback clean-up in case chat-template markers survive decoding
    if "<|im_end|>" in response:
        response = response.split("<|im_end|>")[0]
    if "<|im_start|>" in response:
        response = response.split("<|im_start|>")[-1]
        if "\n" in response:
            response = response.split("\n", 1)[-1]

    print("\nGENERATED TRANSLATION:")
    print(response.strip())

    print("\nStats:")
    print(f"  - Source length: {len(item['source'])} chars")
    print(f"  - Target length: {len(response.strip())} chars")
    print(f"  - Is same as source: {response.strip() == item['source']}")

print(f"\n{'='*80}")
print("TEST COMPLETE")
print("="*80)

In [None]:
import json
import os

from tqdm.auto import tqdm

# Checkpoints let a long generation run resume after a crash or a kernel interrupt.
CHECKPOINT_DIR = "inference_checkpoints_translation"
SAVE_INTERVAL = 10  # write a checkpoint after every SAVE_INTERVAL predictions
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def save_checkpoint(preds):
    """Write all predictions so far to a numbered checkpoint file and return its path.

    The prediction count is zero-padded so that lexicographic order of the
    filenames equals numeric order, which the resume logic below relies on.
    """
    path = os.path.join(CHECKPOINT_DIR, f"predictions_checkpoint_{len(preds):05d}.jsonl")
    with open(path, 'w', encoding='utf-8') as f:
        for pred in preds:
            f.write(json.dumps(pred, ensure_ascii=False) + '\n')
    return path


def generate_translation(item):
    """Translate one test example with the loaded model and return the stripped output.

    The user message is the item's ``prompt`` followed by its ``source``; the
    system prompt is chosen by the item's ``prompt_language``.
    """
    user_message = item['prompt'] + item['source']
    system_prompt = SYSTEM_PROMPT_EN if item['prompt_language'] == 'EN' else SYSTEM_PROMPT_PL
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to("cuda")

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            top_k=TOP_K,
            do_sample=TEMPERATURE > 0,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the generated tokens (excluding the input prompt)
    generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Fallback clean-up in case chat-template markers survive decoding
    if "<|im_end|>" in response:
        response = response.split("<|im_end|>")[0]
    if "<|im_start|>" in response:
        response = response.split("<|im_start|>")[-1]
        if "\n" in response:
            response = response.split("\n", 1)[-1]
    return response.strip()


# Resume from the newest checkpoint, if any
checkpoint_files = sorted(f for f in os.listdir(CHECKPOINT_DIR) if f.startswith("predictions_checkpoint_"))
if checkpoint_files:
    latest_checkpoint = checkpoint_files[-1]
    checkpoint_path = os.path.join(CHECKPOINT_DIR, latest_checkpoint)
    print(f"Found checkpoint: {latest_checkpoint}")
    print("Loading predictions from checkpoint...")

    with open(checkpoint_path, 'r', encoding='utf-8') as f:
        predictions = [json.loads(line) for line in f if line.strip()]

    start_idx = len(predictions)
    # Resuming slices test_data at start_idx, so the checkpoint must contain exactly
    # the first start_idx examples of this test file. This guards against stale
    # checkpoints left over from a different test file or run.
    checkpoint_ids = [p['ipis_id'] for p in predictions]
    expected_ids = [example['ipis_id'] for example in test_data[:start_idx]]
    if checkpoint_ids != expected_ids:
        raise ValueError(
            f"Checkpoint {checkpoint_path} does not match the first {start_idx} examples "
            f"of {TEST_FILE}. Remove or rename {CHECKPOINT_DIR!r} to start over."
        )
    print(f"Resuming from example {start_idx} ({len(predictions)} already processed)")
else:
    predictions = []
    start_idx = 0
    print("Starting from scratch")

print(f"\nGenerating predictions for {len(test_data)} examples...")
print(f"Parameters: temperature={TEMPERATURE}, top_p={TOP_P}, max_new_tokens={MAX_NEW_TOKENS}")
print()

try:
    for item in tqdm(test_data[start_idx:], initial=start_idx, total=len(test_data), desc="Generating translations"):
        target = generate_translation(item)

        # Prediction entry in the Task B format: input fields plus the generated target
        prediction = {
            "source_resource_id": item['source_resource_id'],
            "ipis_id": item['ipis_id'],
            "prompt": item['prompt'],
            "source": item['source'],
            "target": target,
            "prompt_language": item['prompt_language'],
            "source_language": item['source_language'],
            "target_language": item['target_language']
        }
        predictions.append(prediction)

        # Save checkpoint every SAVE_INTERVAL examples
        if len(predictions) % SAVE_INTERVAL == 0:
            checkpoint_file = save_checkpoint(predictions)
            print(f"\n✓ Checkpoint saved: {checkpoint_file} ({len(predictions)} predictions)")
            print(f"Direction: {prediction['source_language']} → {prediction['target_language']}")
            print(f"Latest translation (first 150 chars): {target[:150]}...")
except BaseException as e:
    # BaseException also covers KeyboardInterrupt (kernel interrupt). Persist every
    # finished prediction, including those since the last periodic checkpoint,
    # then re-raise so the failure stays visible.
    print(f"\nError occurred: {type(e).__name__}: {e}")
    if predictions:
        checkpoint_file = save_checkpoint(predictions)
        print(f"Predictions saved up to example {len(predictions)} in {checkpoint_file}")
    print("To resume, simply re-run this cell - it will load from the last checkpoint")
    raise
else:
    print(f"\n✓ Generated {len(predictions)} translations successfully!")

### Save Predictions

In [None]:
import json
import os

# Save predictions twice: a .jsonl copy for reference and a copy with a .tsv
# extension for PolEval submission. Both files contain the same JSON Lines
# content (one JSON object per line); only the extension differs.
output_file_jsonl = OUTPUT_FILE
# splitext swaps only the final extension; str.replace would also rewrite a
# ".jsonl" elsewhere in the path and would do nothing for other extensions.
output_file_tsv = os.path.splitext(OUTPUT_FILE)[0] + '.tsv'

if len(predictions) != len(test_data):
    print(f"⚠ Warning: {len(predictions)} predictions for {len(test_data)} test examples - output is incomplete")


def _write_jsonl(path, records):
    """Write records as UTF-8 JSON Lines, keeping non-ASCII characters unescaped."""
    with open(path, 'w', encoding='utf-8') as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + '\n')


for output_path in (output_file_jsonl, output_file_tsv):
    _write_jsonl(output_path, predictions)

print("✓ Predictions saved to:")
print(f"  - {output_file_jsonl}")
print(f"  - {output_file_tsv} (for PolEval submission)")
print(f"Total predictions: {len(predictions)}")

### Preview Predictions

In [None]:
# Show example predictions for both translation directions
print("="*80)
print("EXAMPLE PREDICTIONS")
print("="*80)

# Split predictions by translation direction
pl_to_en = [p for p in predictions if p['source_language'] == 'PL' and p['target_language'] == 'EN']
en_to_pl = [p for p in predictions if p['source_language'] == 'EN' and p['target_language'] == 'PL']

N_PREVIEW = 2        # examples shown per direction
PREVIEW_CHARS = 250  # characters of source/target shown per example


def _preview_direction(title, source_name, target_name, preds, n=N_PREVIEW, width=PREVIEW_CHARS):
    """Print the first `n` predictions of one direction with truncated source and target.

    Prints nothing when `preds` is empty.
    """
    if not preds:
        return
    print("\n" + "="*80)
    print(title)
    print("="*80)
    for i, pred in enumerate(preds[:n], start=1):
        print(f"\nExample {i}:")
        print(f"IPIS ID: {pred['ipis_id']}")
        print("-"*80)
        print(f"SOURCE ({source_name}, first {width} chars):\n{pred['source'][:width]}...")
        print("-"*80)
        print(f"TARGET ({target_name}, first {width} chars):\n{pred['target'][:width]}...")
        print("="*80)


_preview_direction("POLISH → ENGLISH EXAMPLES", "Polish", "English", pl_to_en)
_preview_direction("ENGLISH → POLISH EXAMPLES", "English", "Polish", en_to_pl)

### Statistics

In [None]:
# Calculate statistics by translation direction
import numpy as np
from collections import Counter

print("PREDICTION STATISTICS")
print("="*80)
print(f"Total predictions: {len(predictions)}")
print()

# Recompute the per-direction subsets here so this cell does not depend on the
# preview cell having been run first.
pl_to_en = [p for p in predictions if p['source_language'] == 'PL' and p['target_language'] == 'EN']
en_to_pl = [p for p in predictions if p['source_language'] == 'EN' and p['target_language'] == 'PL']

# Statistics by translation direction
directions = Counter(f"{p['source_language']} → {p['target_language']}" for p in predictions)
print("Translation directions:")
for direction, count in directions.items():
    print(f"  {direction}: {count} examples ({count/len(predictions)*100:.1f}%)")
print()


def _print_length_stats(title, src_lang, tgt_lang, preds):
    """Print mean/median character lengths of sources and targets and the mean length change.

    `preds` must be non-empty (callers check this before calling).
    """
    src_lengths = [len(p['source']) for p in preds]
    tgt_lengths = [len(p['target']) for p in preds]
    src_mean = np.mean(src_lengths)
    tgt_mean = np.mean(tgt_lengths)
    print(f"{title}:")
    print(f"  Source ({src_lang}) length - Mean: {src_mean:.1f}, Median: {np.median(src_lengths):.1f}")
    print(f"  Target ({tgt_lang}) length - Mean: {tgt_mean:.1f}, Median: {np.median(tgt_lengths):.1f}")
    print(f"  Average length change: {tgt_mean - src_mean:.1f} chars ({(tgt_mean / src_mean - 1) * 100:.1f}%)")


if pl_to_en:
    _print_length_stats("Polish → English", "PL", "EN", pl_to_en)
    print()

if en_to_pl:
    _print_length_stats("English → Polish", "EN", "PL", en_to_pl)

print("="*80)