### Configuration

In [None]:
import os

In [None]:
# ---------------------------------------------------------------------------
# Training hyper-parameters of the fine-tuned model.
# These are used only to rebuild the output directory name of the training run,
# so they must match the values that run was launched with.
# ---------------------------------------------------------------------------
MODEL_SIZE = "8B"
LORA_RANK = 64
EPOCHS = 2
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 2
LEARNING_RATE = 2e-4
WARMUP_STEPS = 10
MAX_SEQ_LENGTH = 4096

# Checkpoint with the lowest eval_loss according to trainer_state.json.
MODEL_PATH = os.path.abspath(
    os.path.join(
        "../../../outputs",
        f"qwen3_{MODEL_SIZE}_polish_inclusive_proofreading"
        f"_lora_r{LORA_RANK}_lr{LEARNING_RATE}_ep{EPOCHS}"
        f"_bs{BATCH_SIZE}_ga{GRADIENT_ACCUMULATION_STEPS}"
        f"_warmup{WARMUP_STEPS}_seq{MAX_SEQ_LENGTH}",
        "checkpoint-23000",
    )
)
print(f"Using best model from: {MODEL_PATH}")

# Sampling settings for generation; a low temperature keeps rewrites conservative.
TEMPERATURE = 0.3
TOP_P = 0.9
TOP_K = 50
MAX_NEW_TOKENS = 4096  # room for long output texts

# Input test set and where the final predictions are written.
TEST_FILE = "../../../data/taskA/test_B.jsonl"
OUTPUT_FILE = "predictions_test_B.jsonl"

### Setup Environment

In [None]:
# Redirect HuggingFace and Triton caches to a writable project directory
# (works around permission problems with the default cache location).
# Set POLEVAL_CACHE_DIR to use a different root without editing this cell.
# NOTE: these libraries typically read the variables when they are first
# imported, so run this cell before the model-loading cell in a fresh kernel.
import os
import warnings

CACHE_ROOT = os.environ.get("POLEVAL_CACHE_DIR", "/mnt/d/Pobrane/poleval-gender/.cache")

os.environ['HF_HOME'] = os.path.join(CACHE_ROOT, 'huggingface')
os.environ['TRANSFORMERS_CACHE'] = os.path.join(CACHE_ROOT, 'huggingface', 'transformers')
os.environ['HF_DATASETS_CACHE'] = os.path.join(CACHE_ROOT, 'huggingface', 'datasets')
os.environ['TRITON_CACHE_DIR'] = os.path.join(CACHE_ROOT, 'triton')

# Silence tqdm's UserWarnings (e.g. notebook widget fallbacks) during long runs.
warnings.filterwarnings('ignore', category=UserWarning, module='tqdm')

### Load the Fine-tuned Model

In [None]:
from unsloth import FastLanguageModel
import torch  # needed later for torch.inference_mode() in the generation cell

print(f"Loading model from {MODEL_PATH}...")

# Load the fine-tuned checkpoint with 4-bit weights, using the same context
# length that was configured for training.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_PATH, max_seq_length=MAX_SEQ_LENGTH, load_in_4bit=True
)
# Put Unsloth into its inference mode (its faster generation path).
FastLanguageModel.for_inference(model)

print("Model loaded successfully!")

### Load System Prompt

In [None]:
# Load the Polish system prompt used during training
with open('../../../system_prompts/proofreading/system_prompt_pl_proofreading', 'r', encoding='utf-8') as f:
    SYSTEM_PROMPT = f.read().strip()

print("System prompt loaded.")
print(f"System prompt length: {len(SYSTEM_PROMPT)} characters")

### Load Test Data

In [None]:
import json

def load_jsonl(file_path):
    """Load a JSONL file into a list of dictionaries.

    Args:
        file_path: Path to a UTF-8 encoded file with one JSON value per line.

    Returns:
        List of parsed records in file order. Blank or whitespace-only lines
        (e.g. a trailing empty line) are skipped instead of raising
        ``json.JSONDecodeError``.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

# Load test data
test_data = load_jsonl(TEST_FILE)
if not test_data:
    # Fail with a clear message instead of an IndexError on test_data[0] below.
    raise ValueError(f"No examples found in {TEST_FILE}")

# The generation cell skips items whose ipis_id was already processed, so
# duplicate ids would silently receive no prediction of their own.
duplicate_count = len(test_data) - len({item['ipis_id'] for item in test_data})
if duplicate_count:
    print(f"WARNING: {duplicate_count} duplicate ipis_id values in {TEST_FILE}; duplicates will be skipped during generation")

print(f"Loaded {len(test_data)} test examples from {TEST_FILE}")
print("\nFirst example:")
print(f"IPIS ID: {test_data[0]['ipis_id']}")
print(f"Prompt: {test_data[0]['prompt'][:100]}...")
print(f"Source: {test_data[0]['source'][:100]}...")

### Generate Predictions

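This cell generates gender-inclusive versions of all texts in the test set. Progress is checkpointed to `inference_checkpoints/` every `SAVE_INTERVAL` examples. Re-running the cell resumes from the latest checkpoint, so clear that directory before starting a run on a different test file.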

In [None]:
from tqdm.auto import tqdm
import os
import re

# Checkpointing: the full prediction list is rewritten every SAVE_INTERVAL
# examples (and on any interruption), so re-running this cell resumes the run.
CHECKPOINT_DIR = "inference_checkpoints"
SAVE_INTERVAL = 25
CHECKPOINT_PREFIX = "predictions_checkpoint_"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Loop-invariant generation settings. Sampling parameters are only passed when
# sampling is enabled, so TEMPERATURE = 0 gives plain greedy decoding.
GENERATION_KWARGS = dict(
    max_new_tokens=MAX_NEW_TOKENS,
    do_sample=TEMPERATURE > 0,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
if GENERATION_KWARGS["do_sample"]:
    GENERATION_KWARGS.update(temperature=TEMPERATURE, top_p=TOP_P, top_k=TOP_K)


def _save_checkpoint(preds):
    """Write all predictions so far to a numbered checkpoint file; return its path.

    The data goes to a temporary file that is then renamed over the target, so
    a crash mid-write never leaves a truncated "latest" checkpoint behind.
    """
    checkpoint_file = os.path.join(CHECKPOINT_DIR, f"{CHECKPOINT_PREFIX}{len(preds):05d}.jsonl")
    tmp_file = checkpoint_file + ".tmp"  # excluded from resume by the ".jsonl" suffix filter
    with open(tmp_file, 'w', encoding='utf-8') as f:
        for pred in preds:
            f.write(json.dumps(pred, ensure_ascii=False) + '\n')
    os.replace(tmp_file, checkpoint_file)
    return checkpoint_file


def _clean_response(response):
    """Strip chat-template markers that may remain in the decoded text."""
    if "<|im_end|>" in response:
        response = response.split("<|im_end|>")[0]
    if "<|im_start|>" in response:
        response = response.split("<|im_start|>")[-1]
        if "\n" in response:
            # Drop the first line after <|im_start|> (presumably the role header).
            response = response.split("\n", 1)[-1]
    return response


def _generate_prediction(item):
    """Run the model on one test item and return its prediction record."""
    # Same prompt format as training: system prompt + (task prompt + source text).
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": item['prompt'] + item['source']},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to("cuda")

    with torch.inference_mode():
        outputs = model.generate(**inputs, **GENERATION_KWARGS)

    # Decode only the newly generated tokens, not the echoed prompt.
    generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
    response = _clean_response(tokenizer.decode(generated_ids, skip_special_tokens=True))

    # Same record format as sample_proofreading.jsonl.
    return {
        "ipis_id": item['ipis_id'],
        "source": item['source'],
        "target": response.strip(),
        "normalised_target": None,  # filled in by a later normalization step
        # Full decoded sequence (prompt + answer) kept for debugging.
        "full_generated_text": tokenizer.decode(outputs[0], skip_special_tokens=True),
    }


# Resume from the most recent checkpoint, if any. Zero-padded counts make the
# lexicographic sort pick the largest one.
checkpoint_files = sorted(
    name for name in os.listdir(CHECKPOINT_DIR)
    if name.startswith(CHECKPOINT_PREFIX) and name.endswith(".jsonl")
)
if checkpoint_files:
    latest_checkpoint = checkpoint_files[-1]
    checkpoint_path = os.path.join(CHECKPOINT_DIR, latest_checkpoint)
    print(f"Found checkpoint: {latest_checkpoint}")
    print("Loading predictions from checkpoint...")

    with open(checkpoint_path, 'r', encoding='utf-8') as f:
        predictions = [json.loads(line) for line in f if line.strip()]

    processed_ids = {p['ipis_id'] for p in predictions}

    # Refuse to mix in checkpoints left over from a different test file/run.
    unknown_ids = processed_ids - {item['ipis_id'] for item in test_data}
    if unknown_ids:
        raise ValueError(
            f"{checkpoint_path} contains {len(unknown_ids)} ipis_id values that are not in {TEST_FILE}; "
            f"clear {CHECKPOINT_DIR}/ (or change CHECKPOINT_DIR) before running on this test file."
        )
    print(f"Resuming with {len(predictions)} examples already processed")
else:
    predictions = []
    processed_ids = set()
    print("Starting from scratch")

remaining = [item for item in test_data if item['ipis_id'] not in processed_ids]

print(f"\nGenerating predictions for {len(test_data)} examples...")
print(f"Parameters: temperature={TEMPERATURE}, top_p={TOP_P}, max_new_tokens={MAX_NEW_TOKENS}")
print()

try:
    for item in tqdm(remaining, initial=len(predictions), total=len(test_data), desc="Generating predictions"):
        # Duplicate ids within the test set are processed only once.
        if item['ipis_id'] in processed_ids:
            continue

        prediction = _generate_prediction(item)
        predictions.append(prediction)
        processed_ids.add(item['ipis_id'])

        # Periodic checkpoint every SAVE_INTERVAL predictions
        if len(predictions) % SAVE_INTERVAL == 0:
            checkpoint_file = _save_checkpoint(predictions)
            print(f"\nCheckpoint saved: {checkpoint_file} ({len(predictions)} predictions)")
            print(f"Latest prediction (first 100 chars): {prediction['target'][:100]}...")

    print(f"\nGenerated {len(predictions)} predictions successfully!")

except BaseException as e:
    # BaseException also covers KeyboardInterrupt (kernel interrupt), so progress
    # made since the last periodic checkpoint is persisted before re-raising.
    print(f"\nError occurred: {type(e).__name__}: {e}")
    if predictions:
        try:
            checkpoint_file = _save_checkpoint(predictions)
            print(f"Predictions saved up to example {len(predictions)} in {checkpoint_file}")
        except Exception as save_error:
            print(f"Could not save checkpoint: {save_error}")
    print("To resume, simply re-run this cell - it will load from the last checkpoint")
    raise

### Save Predictions

In [None]:
# Save predictions to JSONL file
with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
    for pred in predictions:
        f.write(json.dumps(pred, ensure_ascii=False) + '\n')

print(f"Predictions saved to {OUTPUT_FILE}")
print(f"Total predictions: {len(predictions)}")

### Preview Predictions

In [None]:
# Show a few example predictions
print("=" * 80)
print("EXAMPLE PREDICTIONS")
print("=" * 80)

for i in [0, 1, 2]:  # Show first 3 examples
    if i < len(predictions):
        print(f"\nExample {i+1}:")
        print(f"IPIS ID: {predictions[i]['ipis_id']}")
        print("-" * 80)
        print(f"SOURCE (first 200 chars):\n{predictions[i]['source'][:200]}...")
        print("-" * 80)
        print(f"TARGET (first 200 chars):\n{predictions[i]['target'][:200]}...")
        print("=" * 80)

### Statistics

In [None]:
# Calculate some basic statistics
import numpy as np


def _print_length_stats(label, lengths):
    """Print mean, median, min and max of a non-empty list of character lengths."""
    print(f"{label} text lengths (characters):")
    print(f"  Mean: {np.mean(lengths):.1f}")
    print(f"  Median: {np.median(lengths):.1f}")
    print(f"  Min: {np.min(lengths)}")
    print(f"  Max: {np.max(lengths)}")


source_lengths = [len(p['source']) for p in predictions]
target_lengths = [len(p['target']) for p in predictions]

print("PREDICTION STATISTICS")
print("=" * 80)
print(f"Total predictions: {len(predictions)}")
print()
if predictions:
    _print_length_stats("Source", source_lengths)
    print()
    _print_length_stats("Target", target_lengths)
    print()
    mean_source = np.mean(source_lengths)
    mean_target = np.mean(target_lengths)
    # The relative change is undefined when every source text is empty.
    relative = f"{(mean_target / mean_source - 1) * 100:.1f}%" if mean_source else "n/a"
    print(f"Average length increase: {mean_target - mean_source:.1f} chars ({relative})")
else:
    # np.min/np.max raise on empty input; report instead of crashing.
    print("No predictions to summarize.")
print("=" * 80)