# Multi-Task Fine-tuning: Translation + Proofreading

This notebook takes the already fine-tuned proofreading model and adds translation capability through a short fine-tuning run with a low learning rate.

**Strategy:**
- Start from the best proofreading checkpoint
- Train on translation data with a reduced learning rate (1e-4, half of the 2e-4 used for proofreading) to preserve proofreading knowledge
- Short training (1 epoch) to limit catastrophic forgetting
- The model should then be able to do both tasks

### Configuration

In [None]:
import os

# Hyperparameters of the proofreading run whose checkpoint we start from.
# They must match that run exactly, because they are encoded in its output
# directory name (see PROOFREADING_MODEL_PATH below).
BASE_MODEL_SIZE = "8B"
BASE_LORA_RANK = 64
BASE_EPOCHS = 2
BASE_BATCH_SIZE = 1
BASE_GRADIENT_ACCUMULATION_STEPS = 2
BASE_LEARNING_RATE = 2e-4
BASE_WARMUP_STEPS = 10
BASE_MAX_SEQ_LENGTH = 4096

# Absolute path of the proofreading checkpoint, rebuilt from the run's hyperparameters.
PROOFREADING_MODEL_PATH = os.path.abspath(
    f"../../../outputs/qwen3_{BASE_MODEL_SIZE}_polish_inclusive_proofreading"
    f"_lora_r{BASE_LORA_RANK}_lr{BASE_LEARNING_RATE}_ep{BASE_EPOCHS}"
    f"_bs{BASE_BATCH_SIZE}_ga{BASE_GRADIENT_ACCUMULATION_STEPS}"
    f"_warmup{BASE_WARMUP_STEPS}_seq{BASE_MAX_SEQ_LENGTH}/checkpoint-23000"
)

print(f"Starting from proofreading model: {PROOFREADING_MODEL_PATH}")

In [None]:
# Training configuration for translation fine-tuning
MAX_SEQ_LENGTH = 4096

# LoRA configuration - keep same rank as proofreading.
# NOTE: no new adapters are created in this notebook; training continues on the
# adapters stored in the proofreading checkpoint, which keep their own settings.
# LORA_RANK is only used to name OUTPUT_DIR; LORA_ALPHA / LORA_DROPOUT are not
# used anywhere else in this notebook and are kept for reference only.
LORA_RANK = 64
LORA_ALPHA = 64
LORA_DROPOUT = 0

# Training hyperparameters - reduced learning rate to preserve proofreading knowledge
LEARNING_RATE = 1e-4  # Half of the proofreading learning rate (BASE_LEARNING_RATE = 2e-4)
EPOCHS = 1  # Just 1 epoch to add translation capability
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 4  # Effective batch size = 4
WARMUP_STEPS = 5
WEIGHT_DECAY = 0.01

# How many times smaller this run's LR is than the proofreading run's. Derived
# rather than hard-coded so the summary printed below stays correct if either LR changes.
LR_REDUCTION_FACTOR = BASE_LEARNING_RATE / LEARNING_RATE

# Output directory (the run's hyperparameters are encoded in its name)
OUTPUT_DIR = f"../../../outputs/qwen3_{BASE_MODEL_SIZE}_multitask_proofreading+translation_lora_r{LORA_RANK}_lr{LEARNING_RATE}_ep{EPOCHS}_bs{BATCH_SIZE}_ga{GRADIENT_ACCUMULATION_STEPS}_warmup{WARMUP_STEPS}_seq{MAX_SEQ_LENGTH}"
OUTPUT_DIR = os.path.abspath(OUTPUT_DIR)

# Data paths
TRAIN_DATA_PATH = "../../../data/taskB/train.jsonl"

print(f"Output directory: {OUTPUT_DIR}")
print(f"Training data: {TRAIN_DATA_PATH}")
print("\nKey settings:")
print(f"  - Learning rate: {LEARNING_RATE} ({LR_REDUCTION_FACTOR:g}x lower than proofreading)")
print(f"  - Epochs: {EPOCHS}")
print(f"  - Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")

### Setup Environment

In [None]:
# Set cache directories for Hugging Face libraries and Triton.
# This cell runs before transformers/unsloth are imported in the next cells;
# presumably those libraries resolve their cache locations when imported, so
# keep this cell ahead of them.
import os
import warnings

# Root of all caches. Set POLEVAL_CACHE_ROOT in the environment to relocate the
# caches without editing the notebook; the default preserves the original location.
CACHE_ROOT = os.environ.get(
    "POLEVAL_CACHE_ROOT", "/home/adam/Downloads/poleval-gender-new/.cache"
)
HF_CACHE_ROOT = os.path.join(CACHE_ROOT, "huggingface")

os.environ['HF_HOME'] = HF_CACHE_ROOT
# TRANSFORMERS_CACHE is deprecated in recent transformers releases in favour of
# HF_HOME; it is still set here for older versions.
os.environ['TRANSFORMERS_CACHE'] = os.path.join(HF_CACHE_ROOT, "transformers")
os.environ['HF_DATASETS_CACHE'] = os.path.join(HF_CACHE_ROOT, "datasets")
os.environ['TRITON_CACHE_DIR'] = os.path.join(CACHE_ROOT, "triton")

# Silence all warnings (library deprecation notices included) to keep cell output short.
warnings.filterwarnings('ignore')

### Load the Proofreading Model

In [None]:
from unsloth import FastLanguageModel
import torch

print("Loading proofreading model...")
print(f"From: {PROOFREADING_MODEL_PATH}")

# Fail early with a clear message if the checkpoint directory is missing. Otherwise
# the string would presumably be treated as a Hugging Face Hub repo id and fail later
# with a less obvious lookup/download error.
if not os.path.isdir(PROOFREADING_MODEL_PATH):
    raise FileNotFoundError(
        f"Proofreading checkpoint directory not found: {PROOFREADING_MODEL_PATH}"
    )

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = PROOFREADING_MODEL_PATH,
    max_seq_length = MAX_SEQ_LENGTH,
    dtype = None,  # Auto-detect
    load_in_4bit = True,  # base weights loaded quantized to 4 bits
)

print("✓ Proofreading model loaded!")
print(f"Model type: {type(model).__name__}")

### Continue Training with Existing LoRA Adapters

The model already has LoRA adapters from proofreading training. We'll continue training with the same adapters to add translation capability.

In [None]:
# The model already has LoRA adapters from proofreading training, so no new ones
# are added here; training continues on the existing adapters so the model can
# learn translation while (hopefully) preserving proofreading knowledge.
#
# Verify that the loaded adapters are actually trainable before spending GPU time:
# with zero trainable parameters the optimizer would have nothing to update.
n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
if n_trainable_params == 0:
    raise RuntimeError(
        "The loaded model has no trainable parameters; the LoRA adapters from "
        f"{PROOFREADING_MODEL_PATH} appear not to be loaded in trainable mode."
    )

print("✓ Model already has LoRA adapters - ready for continued training!")
print(f"Model type: {type(model).__name__}")
print(f"Trainable parameters: {n_trainable_params:,}")

### Load Translation Training Data

In [None]:
import json
from datasets import Dataset

def load_translation_data(file_path):
    """Load translation training examples from a JSON Lines file.

    Each non-blank line must be a JSON object with the keys ``prompt``, ``source``,
    ``target``, ``prompt_language``, ``source_language`` and ``target_language``.
    Only those keys are kept; any other keys in a record are dropped.

    Args:
        file_path: Path to the ``.jsonl`` file (read as UTF-8).

    Returns:
        A list of dicts, one per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks one of the required keys.
    """
    required_keys = (
        'prompt',
        'source',
        'target',
        'prompt_language',
        'source_language',
        'target_language',
    )
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Skip whitespace-only lines (e.g. an extra empty line at the end of
            # the file), which json.loads would otherwise reject.
            if not line.strip():
                continue
            item = json.loads(line)
            data.append({key: item[key] for key in required_keys})
    return data

# Load training data
train_data = load_translation_data(TRAIN_DATA_PATH)
if not train_data:
    # Fail fast with a clear message instead of an IndexError on train_data[0] below.
    raise ValueError(f"No training examples found in {TRAIN_DATA_PATH}")

print(f"Loaded {len(train_data)} training examples")
print("\nFirst example:")
print(f"  Direction: {train_data[0]['source_language']} → {train_data[0]['target_language']}")
print(f"  Prompt: {train_data[0]['prompt'][:80]}...")
print(f"  Source: {train_data[0]['source'][:80]}...")
print(f"  Target: {train_data[0]['target'][:80]}...")

### Load System Prompts

In [None]:
def _read_system_prompt(path):
    """Return the text of a UTF-8 system-prompt file with surrounding whitespace stripped."""
    with open(path, 'r', encoding='utf-8') as prompt_file:
        return prompt_file.read().strip()


# Translation system prompts, one per instruction language.
SYSTEM_PROMPT_EN = _read_system_prompt('../../../system_prompts/translation/system_prompt_en_translation')
SYSTEM_PROMPT_PL = _read_system_prompt('../../../system_prompts/translation/system_prompt_pl_translation')

print(
    "System prompts loaded.",
    f"English prompt: {len(SYSTEM_PROMPT_EN)} chars",
    f"Polish prompt: {len(SYSTEM_PROMPT_PL)} chars",
    sep="\n",
)

### Prepare Dataset for Training

In [None]:
def format_translation_prompt(example):
    """Render one translation example as a chat-formatted training string.

    The system prompt is chosen from ``example['prompt_language']``. It is English
    when that value is ``'EN'``, compared case-insensitively and ignoring surrounding
    whitespace; otherwise it is Polish. The user turn is the task prompt directly
    followed by the source text, and the assistant turn is the reference translation.

    Args:
        example: Mapping with at least the keys ``prompt``, ``source``, ``target``
            and ``prompt_language``.

    Returns:
        ``{"text": str}``: the conversation rendered with the tokenizer's chat
        template (as text, not token ids, and without a trailing generation prompt).
    """
    # Normalise the language tag so e.g. 'en' or ' EN ' does not silently fall
    # through to the Polish system prompt.
    prompt_language = str(example['prompt_language']).strip().upper()
    system_prompt = SYSTEM_PROMPT_EN if prompt_language == 'EN' else SYSTEM_PROMPT_PL

    # NOTE(review): prompt and source are concatenated with no separator. This assumes
    # each prompt already ends with its own delimiter (newline/colon/space). Confirm
    # against the data and keep it identical to the inference-time formatting.
    user_message = example['prompt'] + example['source']

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": example['target']},
    ]

    # Full conversation including the assistant answer, so no generation prompt is appended.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )

    return {"text": text}

# Wrap the raw records in a Hugging Face Dataset, then render every record into a
# single chat-formatted "text" column (all original columns are removed).
train_dataset = Dataset.from_list(train_data)
train_dataset = train_dataset.map(
    format_translation_prompt,
    remove_columns=train_dataset.column_names,
)

print(
    f"Formatted {len(train_dataset)} training examples",
    "\nExample formatted text (first 500 chars):",
    train_dataset[0]['text'][:500] + "...",
    sep="\n",
)

### Setup Training Arguments

In [None]:
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

from unsloth import UnslothTrainingArguments

# Query GPU capabilities once: bf16 where supported, otherwise fp16 mixed precision.
USE_BF16 = is_bfloat16_supported()

# Built with the same argument class and values the trainer is configured with
# (including MLflow reporting), so the summary printed below describes the run that
# actually happens rather than a diverging copy.
training_args = UnslothTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    warmup_steps=WARMUP_STEPS,
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    fp16=not USE_BF16,
    bf16=USE_BF16,
    logging_steps=10,
    optim="adamw_8bit",
    seed=42,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,  # keep only the 3 most recent checkpoints on disk
    load_best_model_at_end=False,
    report_to="mlflow",
)

EFFECTIVE_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS

print("Training arguments configured:")
print(f"  - Effective batch size: {EFFECTIVE_BATCH_SIZE}")
print(f"  - Learning rate: {LEARNING_RATE}")
print(f"  - Total epochs: {EPOCHS}")
print(f"  - Reporting to: {training_args.report_to}")
print(f"  - Estimated steps: {len(train_dataset) // EFFECTIVE_BATCH_SIZE * EPOCHS}")

### Initialize Trainer

In [None]:
from unsloth import UnslothTrainer, UnslothTrainingArguments
from transformers import DataCollatorForLanguageModeling

# Use the `training_args` object from the "Setup Training Arguments" cell instead of
# an inline duplicate, so the configuration printed there is exactly what the trainer
# runs with and the two cannot drift apart.
trainer = UnslothTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LENGTH,
    # Causal-LM collator (mlm=False): labels are the input ids, so the loss covers the
    # whole rendered conversation (system prompt and user turn included), not only
    # the assistant's translation.
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    args=training_args,
)

print("Trainer initialized")

### Start Training

This will fine-tune the proofreading model on translation data with a small learning rate to add translation capability while preserving proofreading skills.

In [None]:
# Banner describing the run, followed by an empty line.
print(
    "=" * 80,
    "STARTING MULTI-TASK FINE-TUNING",
    "=" * 80,
    "Base model: Proofreading checkpoint",
    "New task: Translation (PL⇄EN)",
    f"Strategy: Low LR ({LEARNING_RATE}), {EPOCHS} epoch",
    "=" * 80,
    "",
    sep="\n",
)

# Blocks until training finishes; the returned stats are summarised at the end of the notebook.
trainer_stats = trainer.train()

print("\n" + "=" * 80, "TRAINING COMPLETE!", "=" * 80, sep="\n")

### Save Final Model

In [None]:
# Persist the trained model and its tokenizer side by side in OUTPUT_DIR/final_model.
final_model_path = os.path.join(OUTPUT_DIR, "final_model")
for artifact in (model, tokenizer):
    artifact.save_pretrained(final_model_path)

print(
    f"Model saved to: {final_model_path}",
    "\nThis model should now be able to:",
    "  1. Perform gender-inclusive proofreading (original task)",
    "  2. Translate PL⇄EN with gender inclusivity (new task)",
    sep="\n",
)

### Training Statistics

In [None]:
# Summarise the run using the metrics dict returned by trainer.train().
print("TRAINING STATISTICS")
print("=" * 80)
print(f"Total training time: {trainer_stats.metrics['train_runtime']:.2f} seconds")
print(f"Training samples/second: {trainer_stats.metrics['train_samples_per_second']:.2f}")
# `train_loss` is the training loss averaged over the whole run, not the loss of the
# last step, so it is labelled accordingly. Per-step losses are logged every
# `logging_steps` during training.
print(f"Average training loss: {trainer_stats.metrics['train_loss']:.4f}")
print("=" * 80)