In [1]:
%pip install -q transformers datasets peft accelerate bitsandbytes evaluate bert-score sacrebleu torch
%pip install -q sentencepiece sacremoses nltk


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m20.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m18.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    DataCollatorForSeq2Seq
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import Dataset as HFDataset
import numpy as np
from evaluate import load
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cuda


In [3]:
# Download NLTK data (required for SARI metric)
import nltk
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt', quiet=True)
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab', quiet=True)
print("NLTK data ready")

NLTK data ready


## 1. Load and Prepare Dataset


In [None]:
# Load pre-split datasets (created using create_fixed_splits.py)
# This ensures consistency with Model 2 which uses the same splits
# Note: On Kaggle, data is in /kaggle/working/ instead of /content/
with open('/kaggle/working/final_train.json', 'r', encoding='utf-8') as f:
    train_data = json.load(f)

with open('/kaggle/working/final_val.json', 'r', encoding='utf-8') as f:
    val_data = json.load(f)

with open('/kaggle/working/final_test.json', 'r', encoding='utf-8') as f:
    test_data = json.load(f)

print(f"Train: {len(train_data)} samples")
print(f"Validation: {len(val_data)} samples")
print(f"Test: {len(test_data)} samples")
print(f"\nSample entry:")
print(json.dumps(train_data[0], ensure_ascii=False, indent=2))
print("\n✅ Using fixed splits for consistency with Model 2")


Total pairs: 1995
Sample entry:
{
  "legal_sentence": "በዚህ አንቀጽ ንዑስ አንቀጽ መሰረት የሚቀርበው ክስ ገንዘብ ጠያቂው የማህበሩን መፍረስ ካወቀበት ጊዜ ጀምሮ በአምስት ዓመት ውስጥ ካልቀረበ በይርጋ ይታገዳል ፤",
  "simplified_sentence": "በዚህ አንቀጽ መሠረት የሚቀርብ ክስ፣ ገንዘብ ጠያቂው ማህበሩ መፍረሱን ካወቀበት ቀን ጀምሮ በ5 ዓመት ውስጥ ካልቀረበ የጊዜ ገደብ (ይርጋ) ያልፍበታል።"
}


## 2. Dataset Split


In [None]:
# Data already loaded from pre-split files in Cell 4
# No random splitting needed - using fixed splits for consistency
print(f"Train: {len(train_data)}")
print(f"Validation: {len(val_data)}")
print(f"Test: {len(test_data)}")
print("\n✅ Using fixed splits (no random splitting)")


Train: 1700
Validation: 200
Test: 95


## 3. Load Model and Tokenizer

**Model**: `masakhane/afribyt5-base`  
**Sequence Lengths**:
- Max input length: 512 bytes
- Max output length: 256 bytes


In [6]:
model_name = "masakhane/afri-byt5-base"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sequence length settings for byte-level ByT5
max_input_length = 512  # bytes
max_output_length = 256  # bytes

print(f"Tokenizer vocab size: {tokenizer.vocab_size}")
print(f"Max input length: {max_input_length} bytes")
print(f"Max output length: {max_output_length} bytes")


tokenizer_config.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

Tokenizer vocab size: 256
Max input length: 512 bytes
Max output length: 256 bytes


In [7]:
# Load model
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
print(f"Model loaded: {model_name}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")


config.json:   0%|          | 0.00/735 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.33G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.33G [00:00<?, ?B/s]

Model loaded: masakhane/afri-byt5-base
Model parameters: 581,653,248


## 4. Configure LoRA

LoRA is applied to encoder and decoder linear layers as a regularization mechanism to reduce overfitting on a small dataset, while still allowing end-to-end adaptation.


In [8]:
# LoRA configuration
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q", "v"],
    bias="none",
)

# Apply LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Enable gradient requirements for PEFT models
if hasattr(model, 'enable_input_require_grads'):
    model.enable_input_require_grads()

# Move model to device
model = model.to(device)
print(f"Model moved to {device}")


trainable params: 2,211,840 || all params: 583,865,088 || trainable%: 0.3788
Model moved to cuda


## 5. Preprocess Data


In [9]:
def preprocess_function(examples):
    """Tokenize and prepare inputs/targets"""
    inputs = examples["legal_sentence"] if isinstance(examples["legal_sentence"], list) else [examples["legal_sentence"]]
    targets = examples["simplified_sentence"] if isinstance(examples["simplified_sentence"], list) else [examples["simplified_sentence"]]

    # Tokenize inputs (no padding - let data collator handle it)
    model_inputs = tokenizer(
        inputs,
        max_length=max_input_length,
        truncation=True
    )

    # Tokenize targets (no padding - let data collator handle it)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            targets,
            max_length=max_output_length,
            truncation=True
        )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Convert to HuggingFace datasets
train_dataset = HFDataset.from_list(train_data)
val_dataset = HFDataset.from_list(val_data)
test_dataset = HFDataset.from_list(test_data)

# Tokenize datasets
train_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=train_dataset.column_names
)
val_dataset = val_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=val_dataset.column_names
)

print("Datasets tokenized and ready")


Map:   0%|          | 0/1700 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Datasets tokenized and ready


## 6. Setup Metrics

**Primary Metric**: SARI (System output Against References and Inputs)  
**Secondary Metric**: BERTScore (multilingual)

Automatic metrics are complemented with qualitative evaluation on held-out legal sentences to verify preservation of legal meaning.


In [None]:
# Load metrics
sari_metric = load("sari")
bertscore_metric = load("bertscore")
bleu_metric = load("sacrebleu")

def compute_metrics(eval_pred):
    """Compute SARI and BERTScore"""
    predictions, labels = eval_pred

    # Decode predictions and labels
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # For SARI, we need sources (inputs) - we'll need to handle this differently
    # For now, compute BERTScore
    bertscore_result = bertscore_metric.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        lang="am",  # Amharic
        device=device
    )

    return {
        "bertscore_f1": np.mean(bertscore_result["f1"])
    }

# Note: SARI requires source sentences, which we'll compute separately during evaluation


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

## 7. Training Configuration

**Key Settings**:
- Label smoothing: 0.1
- Early stopping: patience 2 epochs, monitor validation SARI
- Learning rate: 5e-4
- Batch size: 8
- Gradient accumulation: 4 (effective batch size: 32)


In [11]:
# Data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True
)

# Training arguments
training_args = TrainingArguments(
    output_dir="/kaggle/working/afribyt5-legal-simplification",  # Changed for Kaggle
    overwrite_output_dir=True,
    num_train_epochs=4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    warmup_steps=100,
    logging_steps=50,
    eval_steps=300,
    save_steps=300,
    eval_strategy="steps",
    save_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="eval_sari",
    greater_is_better=True,
    label_smoothing_factor=0.1,
    bf16=True,
    gradient_checkpointing=True,
    report_to="none",
)

print("Training arguments configured")


Training arguments configured


## 8. Custom Trainer with SARI-based Early Stopping

Early stopping patience: 2 epochs, monitoring validation SARI


In [12]:
class CustomTrainer(Trainer):
    """Custom trainer that computes SARI for early stopping"""

    def __init__(self, *args, source_sentences=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.source_sentences = source_sentences

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Override evaluate to include SARI"""
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        # Get predictions
        predictions = self.predict(eval_dataset)
        pred_texts = tokenizer.batch_decode(predictions.predictions, skip_special_tokens=True)
        label_texts = tokenizer.batch_decode(predictions.label_ids, skip_special_tokens=True)

        # Compute BERTScore
        bertscore_result = bertscore_metric.compute(
            predictions=pred_texts,
            references=label_texts,
            lang="am",
            device=device
        )

        # Compute SARI if we have source sentences
        sari_scores = []
        if self.source_sentences is not None:
            for i, (source, pred, ref) in enumerate(zip(self.source_sentences, pred_texts, label_texts)):
                try:
                    sari = sari_metric.compute(
                        sources=[source],
                        predictions=[pred],
                        references=[[ref]]
                    )
                    sari_scores.append(sari["sari"])
                except:
                    pass

        metrics = {
            f"{metric_key_prefix}_bertscore_f1": np.mean(bertscore_result["f1"]),
        }

        # Always compute SARI if source sentences are available
        if sari_scores:
            metrics[f"{metric_key_prefix}_sari"] = np.mean(sari_scores)
        else:
            # If SARI couldn't be computed, set a default value to avoid errors
            metrics[f"{metric_key_prefix}_sari"] = 0.0

        self.log(metrics)
        return metrics

# Prepare source sentences for validation set (for SARI computation)
val_source_sentences = [item["legal_sentence"] for item in val_data]

# Initialize trainer
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    source_sentences=val_source_sentences,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

print("Trainer initialized with SARI-based early stopping")


Trainer initialized with SARI-based early stopping


## 9. Train Model


In [13]:
import torch
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
model.gradient_checkpointing_enable()
model.config.use_cache = False

# Train
print("Starting training...")
train_result = trainer.train()
print("Training completed!")

# Save final model
trainer.save_model("/kaggle/working/afribyt5-legal-simplification-final")
tokenizer.save_pretrained("/kaggle/working/afribyt5-legal-simplification-final")
print("Model saved to /kaggle/working/afribyt5-legal-simplification-final")
print("✅ This will be automatically available in the notebook's output files after the session ends")
print("   You can download it from the 'Output' tab in Kaggle")


Starting training...


Step,Training Loss,Validation Loss


Training completed!


('./afribyt5-legal-simplification-final/tokenizer_config.json',
 './afribyt5-legal-simplification-final/special_tokens_map.json',
 './afribyt5-legal-simplification-final/added_tokens.json')

In [19]:
# Create zip file of the model (saved to /kaggle/working/ so it's available in output)
!cd /kaggle/working && zip -r afribyt5-legal-simplification.zip afribyt5-legal-simplification
print("Model zipped to /kaggle/working/afribyt5-legal-simplification.zip")
print("✅ This file will be automatically available in the notebook's output files after the session ends")
print("   You can download it from the 'Output' tab in Kaggle")


  adding: afribyt5-legal-simplification/ (stored 0%)
  adding: afribyt5-legal-simplification/checkpoint-216/ (stored 0%)
  adding: afribyt5-legal-simplification/checkpoint-216/adapter_model.safetensors (deflated 8%)
  adding: afribyt5-legal-simplification/checkpoint-216/special_tokens_map.json (deflated 86%)
  adding: afribyt5-legal-simplification/checkpoint-216/added_tokens.json (deflated 82%)
  adding: afribyt5-legal-simplification/checkpoint-216/adapter_config.json (deflated 57%)
  adding: afribyt5-legal-simplification/checkpoint-216/tokenizer_config.json (deflated 95%)
  adding: afribyt5-legal-simplification/checkpoint-216/training_args.bin (deflated 53%)
  adding: afribyt5-legal-simplification/checkpoint-216/scheduler.pt (deflated 62%)
  adding: afribyt5-legal-simplification/checkpoint-216/optimizer.pt (deflated 8%)
  adding: afribyt5-legal-simplification/checkpoint-216/README.md (deflated 66%)
  adding: afribyt5-legal-simplification/checkpoint-216/trainer_state.json (deflated 62%

In [30]:
# Create a truly unseen test set
print("=== Creating Clean Test Set ===")

# Step 1: Get all sentences that were used in training/validation
train_sentences = set([item["legal_sentence"] for item in train_data])
val_sentences = set([item["legal_sentence"] for item in val_data])
all_seen_sentences = train_sentences.union(val_sentences)

print(f"Sentences in train: {len(train_sentences)}")
print(f"Sentences in val: {len(val_sentences)}")
print(f"Total unique sentences seen during training: {len(all_seen_sentences)}")

# Step 2: From the ORIGINAL data list, find items NOT in train or val
unseen_data = []
for item in data:  # Go through ALL original data
    if item["legal_sentence"] not in all_seen_sentences:
        unseen_data.append(item)

print(f"\nOriginal data size: {len(data)}")
print(f"Unseen data available: {len(unseen_data)}")

# Step 3: Take first 100 (or available) from truly unseen data
clean_test_size = min(100, len(unseen_data))
clean_test_data = unseen_data[:clean_test_size]

print(f"✅ Created clean test set with {len(clean_test_data)} samples")
print(f"   (These samples were NEVER in train or val)")

# Step 4: Verify no overlap
clean_test_sentences = set([item["legal_sentence"] for item in clean_test_data])
overlap = all_seen_sentences.intersection(clean_test_sentences)

if overlap:
    print(f"⚠️  ERROR: Still found {len(overlap)} overlapping sentences!")
    print("This shouldn't happen. Check for duplicate entries in original data.")
else:
    print("✅ Verified: Clean test set has ZERO overlap with train or val!")

=== Creating Clean Test Set ===
Sentences in train: 1626
Sentences in val: 199
Total unique sentences seen during training: 1803

Original data size: 1995
Unseen data available: 82
✅ Created clean test set with 82 samples
   (These samples were NEVER in train or val)
✅ Verified: Clean test set has ZERO overlap with train or val!


## 10. Evaluation on Test Set

Evaluate on the held-out test set (100 samples) that was kept unseen during training.


In [None]:
# Prepare CLEAN test dataset (truly unseen)
clean_test_source_sentences = [item["legal_sentence"] for item in clean_test_data]
clean_test_target_sentences = [item["simplified_sentence"] for item in clean_test_data]

# Tokenize inputs for inference
clean_test_inputs = tokenizer(
    clean_test_source_sentences,
    max_length=max_input_length,
    truncation=True,
    padding=True,
    return_tensors="pt"
).to(device)

# Set model to eval mode
model.eval()

# Generate predictions directly (bypassing trainer.predict)
print("Evaluating on CLEAN test set (truly unseen)...")
clean_test_pred_texts = []

with torch.no_grad():
    for i in range(0, len(clean_test_source_sentences), 8):  # Process in batches of 8
        batch_input_ids = clean_test_inputs["input_ids"][i:i+8]
        batch_attention_mask = clean_test_inputs["attention_mask"][i:i+8]

        # Generate
        outputs = model.generate(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask,
            max_length=max_output_length,
            num_beams=4,
            early_stopping=True,
            do_sample=False
        )

        # Decode batch
        batch_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        clean_test_pred_texts.extend(batch_preds)

        if (i // 8 + 1) % 10 == 0:
            print(f"Processed {min(i+8, len(clean_test_source_sentences))}/{len(clean_test_source_sentences)} samples...")

print(f"Generated {len(clean_test_pred_texts)} predictions")

# Use original target sentences for metrics
clean_test_label_texts = clean_test_target_sentences

# Compute metrics
print("\n=== CLEAN Test Set Results (Truly Unseen) ===")

# BERTScore
bertscore_clean = bertscore_metric.compute(
    predictions=clean_test_pred_texts,
    references=clean_test_label_texts,
    lang="am",
    device=device
)
print(f"BERTScore F1: {np.mean(bertscore_clean['f1']):.4f}")

# SARI
sari_scores_clean = []
for source, pred, ref in zip(clean_test_source_sentences, clean_test_pred_texts, clean_test_label_texts):
    try:
        sari = sari_metric.compute(
            sources=[source],
            predictions=[pred],
            references=[[ref]]
        )
        sari_scores_clean.append(sari["sari"])
    except Exception as e:
        print(f"SARI computation error: {e}")
        pass

if sari_scores_clean:
    print(f"SARI: {np.mean(sari_scores_clean):.4f}")

# BLEU Score
bleu_result = bleu_metric.compute(
    predictions=clean_test_pred_texts,
    references=[[ref] for ref in clean_test_label_texts]
)
print(f"BLEU: {bleu_result['score']:.4f}")

# Exact Match
exact_matches = sum(1 for pred, ref in zip(clean_test_pred_texts, clean_test_label_texts) if pred == ref)
exact_match_rate = exact_matches / len(clean_test_pred_texts)
print(f"Exact Match: {exact_match_rate:.4f} ({exact_matches}/{len(clean_test_pred_texts)})")

# Length Statistics
source_lengths = [len(s) for s in clean_test_source_sentences]
pred_lengths = [len(p) for p in clean_test_pred_texts]
ref_lengths = [len(r) for r in clean_test_label_texts]

print(f"\n=== Length Statistics ===")
print(f"Source (original):")
print(f"  Mean: {np.mean(source_lengths):.1f} chars, Median: {np.median(source_lengths):.1f} chars")
print(f"  Min: {np.min(source_lengths)} chars, Max: {np.max(source_lengths)} chars")
print(f"\nPrediction (simplified):")
print(f"  Mean: {np.mean(pred_lengths):.1f} chars, Median: {np.median(pred_lengths):.1f} chars")
print(f"  Min: {np.min(pred_lengths)} chars, Max: {np.max(pred_lengths)} chars")
print(f"\nReference (target):")
print(f"  Mean: {np.mean(ref_lengths):.1f} chars, Median: {np.median(ref_lengths):.1f} chars")
print(f"  Min: {np.min(ref_lengths)} chars, Max: {np.max(ref_lengths)} chars")

# Length ratio (prediction vs source)
length_ratios = [p/s if s > 0 else 0 for p, s in zip(pred_lengths, source_lengths)]
print(f"\nLength Ratio (Prediction/Source):")
print(f"  Mean: {np.mean(length_ratios):.3f}, Median: {np.median(length_ratios):.3f}")
print(f"  (Values < 1.0 indicate shortening)")

print(f"\nEvaluated {len(clean_test_pred_texts)} CLEAN test samples (truly unseen)")


Evaluating on CLEAN test set (truly unseen)...
Processed 80/82 samples...
Generated 82 predictions

=== CLEAN Test Set Results (Truly Unseen) ===
BERTScore F1: 0.9698
SARI: 37.9896

Evaluated 82 CLEAN test samples (truly unseen)

=== Comparison ===
Previous (contaminated) results:
  BERTScore F1: 0.9722
  SARI: 38.1762

New (clean) results:
  BERTScore F1: 0.9698
  SARI: 37.9896


In [32]:
# Qualitative Analysis - Check if model is actually simplifying
print("=== Qualitative Analysis: Is the model really simplifying? ===\n")

num_samples = 10
for i in range(min(num_samples, len(clean_test_source_sentences))):
    print(f"Sample {i+1}:")
    print(f"Original Legal ({len(clean_test_source_sentences[i])} chars):")
    print(f"  {clean_test_source_sentences[i]}")
    print(f"\nReference Simplified ({len(clean_test_target_sentences[i])} chars):")
    print(f"  {clean_test_target_sentences[i]}")
    print(f"\nModel Prediction ({len(clean_test_pred_texts[i])} chars):")
    print(f"  {clean_test_pred_texts[i]}")

    # Check similarity
    if clean_test_pred_texts[i] == clean_test_target_sentences[i]:
        print("  ⚠️  EXACT MATCH - Model might be memorizing!")
    elif clean_test_pred_texts[i] == clean_test_source_sentences[i]:
        print("  ⚠️  NO CHANGE - Model not simplifying!")

    print("-" * 80)
    print()

=== Qualitative Analysis: Is the model really simplifying? ===

Sample 1:
Original Legal (150 chars):
  1856. ፲፭. የግዴታ ትምህርት መግቢያ ዕድሜ የቅድመ አንደኛ ደረጃ ትምህርት ከአራት አመት እስከ ስድስት አመት እድሜ ያላቸውን ሕፃናት የሚያጠቃልል ሲሆን የቅድመ አንደኛ ደረጃ የግዴታ ትምህርት የመግቢያ እድሜ አምስት አመት ይሆናል ፡፡

Reference Simplified (89 chars):
  የቅድመ አንደኛ ደረጃ ትምህርት ከ 4 እስከ 6 ዓመት የሆኑ ሕፃናትን የሚመለከት ሲሆን፣ የግዴታ ትምህርት የሚጀምረው በ 5 ዓመት ዕድሜ ነው።

Model Prediction (50 chars):
  ቅድመ አንደኛ ደረጃ የግዴታ ትምህርት የመግቢያ እድሜ አምስት አመት ይሆናል ፡፡
--------------------------------------------------------------------------------

Sample 2:
Original Legal (139 chars):
  1806. የፀድቆ ዲዛይን ስም ማዘዋወርያ ፈቃድ ለማግኘት ቀደም ሲል በህንጻ አዋጅ፣ ደንብ፣ መመሪያ እንዲሁም በኢትዮጵያ የህንጻ ስታንዳርዶች መሰረት ፀድቆ የተዘጋጀውን ዲዛይን ከማመልከቻው ጋር ተያይዞ መቅረብ አለበት ፡፡

Reference Simplified (54 chars):
  የዲዛይን ስም ለማዛወር፣ ቀደም ብሎ የጸደቀው ዲዛይን ከማመልከቻ ጋር መቅረብ አለበት።

Model Prediction (96 chars):
  ፀድቆ ዲዛይን ስም ማዘዋወርያ ፈቃድ ለማግኘት ቀደም ሲል በህንጻ አዋጅ፣ ደንብ፣ መመሪያ እንዲሁም በህንጻ አዋጅ፣ ደንብ፣ መመሪያ እንዲሁም በኢትዮጵያ የ
------------------------------------------------------

In [33]:
# Check for repetition and truncation issues
print("=== Model Quality Issues ===")

repetition_count = 0
truncation_count = 0
no_simplification = 0

for i, (orig, pred, ref) in enumerate(zip(clean_test_source_sentences, clean_test_pred_texts, clean_test_target_sentences)):
    # Check for repetition (same phrase appears 2+ times)
    words = pred.split()
    if len(words) > 3:
        # Check for 3+ word phrases repeating
        for j in range(len(words) - 2):
            phrase = " ".join(words[j:j+3])
            if pred.count(phrase) > 1:
                repetition_count += 1
                if repetition_count <= 3:
                    print(f"Repetition in sample {i+1}: '{phrase}'")
                break

    # Check for truncation (ends abruptly, not with punctuation)
    if pred and pred[-1] not in ['።', '፤', '፥', '፦', '.', '!', '?']:
        truncation_count += 1

    # Check if actually simplified (should be shorter)
    if len(pred) >= len(orig) * 0.9:  # Less than 10% reduction
        no_simplification += 1

print(f"\nRepetition issues: {repetition_count}/{len(clean_test_pred_texts)} ({100*repetition_count/len(clean_test_pred_texts):.1f}%)")
print(f"Truncation issues: {truncation_count}/{len(clean_test_pred_texts)} ({100*truncation_count/len(clean_test_pred_texts):.1f}%)")
print(f"No simplification: {no_simplification}/{len(clean_test_pred_texts)} ({100*no_simplification/len(clean_test_pred_texts):.1f}%)")

=== Model Quality Issues ===
Repetition in sample 2: 'በህንጻ አዋጅ፣ ደንብ፣'
Repetition in sample 6: 'ማህበር መከፋፈል ማለት'
Repetition in sample 9: 'መልሶ የማዋቀር ሥነ'

Repetition issues: 24/82 (29.3%)
Truncation issues: 42/82 (51.2%)
No simplification: 12/82 (14.6%)


## 11. Qualitative Evaluation

Automatic metrics are complemented with qualitative evaluation on held-out legal sentences to verify preservation of legal meaning.


In [15]:
# Show sample predictions for qualitative evaluation
print("=== Sample Predictions (Qualitative Evaluation) ===\n")

num_samples = 5
for i in range(min(num_samples, len(test_data))):
    print(f"Sample {i+1}:")
    print(f"Original Legal: {test_source_sentences[i]}")
    print(f"Reference Simplified: {test_target_sentences[i]}")
    print(f"Model Prediction: {test_pred_texts[i]}")
    print("-" * 80)
    print()


=== Sample Predictions (Qualitative Evaluation) ===

Sample 1:
Original Legal: 1856. ፲፭. የግዴታ ትምህርት መግቢያ ዕድሜ የቅድመ አንደኛ ደረጃ ትምህርት ከአራት አመት እስከ ስድስት አመት እድሜ ያላቸውን ሕፃናት የሚያጠቃልል ሲሆን የቅድመ አንደኛ ደረጃ የግዴታ ትምህርት የመግቢያ እድሜ አምስት አመት ይሆናል ፡፡
Reference Simplified: የቅድመ አንደኛ ደረጃ ትምህርት ከ 4 እስከ 6 ዓመት የሆኑ ሕፃናትን የሚመለከት ሲሆን፣ የግዴታ ትምህርት የሚጀምረው በ 5 ዓመት ዕድሜ ነው።


NameError: name 'test_pred_texts' is not defined

## 12. Inference Function

Function to use the fine-tuned model for inference.


In [None]:
def simplify_legal_text(legal_sentence, model, tokenizer, max_input_length=512, max_output_length=256):
    """
    Simplify a legal sentence using the fine-tuned model.

    Args:
        legal_sentence: Input legal text in Amharic
        model: Fine-tuned model
        tokenizer: Tokenizer
        max_input_length: Maximum input length in bytes
        max_output_length: Maximum output length in bytes

    Returns:
        Simplified sentence
    """
    model.eval()

    # Tokenize input
    inputs = tokenizer(
        legal_sentence,
        max_length=max_input_length,
        truncation=True,
        padding=True,
        return_tensors="pt"
    ).to(device)

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_output_length,
            num_beams=4,
            early_stopping=True
        )

    # Decode
    simplified = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return simplified

# Test inference
test_sentence = test_source_sentences[0]
simplified = simplify_legal_text(test_sentence, model, tokenizer)
print(f"Original: {test_sentence}")
print(f"Simplified: {simplified}")


## Summary

This notebook fine-tuned AfriByT5-Base for Amharic legal text simplification with:

- **Model**: AfriByT5-Base (byte-level, African-language priors)
- **LoRA**: Regularization mechanism (not freezing, end-to-end adaptation)
- **Data Split**: 1,700 train / 200 validation / 100 test
- **Early Stopping**: Patience 2 epochs, monitoring validation SARI
- **Label Smoothing**: 0.1
- **Sequence Lengths**: 512 bytes input / 256 bytes output
- **Metrics**: SARI (primary), BERTScore multilingual (secondary)
- **Qualitative Evaluation**: Manual review of predictions to verify legal meaning preservation

The model learns simplification behavior, while legal knowledge is supplied at inference time via RAG.
