In [1]:
%pip install -q transformers datasets peft accelerate bitsandbytes evaluate bert-score sacrebleu torch
%pip install -q sentencepiece sacremoses nltk


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m47.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    DataCollatorForSeq2Seq
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import Dataset as HFDataset
import numpy as np
from evaluate import load
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cuda


In [3]:
# Download NLTK data (required for SARI metric)
import nltk
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt', quiet=True)
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab', quiet=True)
print("NLTK data ready")

NLTK data ready


## 1. Load and Prepare Dataset


In [4]:
# Load pre-split datasets (created using create_fixed_splits.py)
# This ensures consistency with Model 3 which uses the same splits
# Note: simplification_type will be used during training to condition the model
with open('/content/regen_train.json', 'r', encoding='utf-8') as f:
    train_data = json.load(f)

with open('/content/regen_val.json', 'r', encoding='utf-8') as f:
    val_data = json.load(f)

with open('/content/regen_test.json', 'r', encoding='utf-8') as f:
    test_data = json.load(f)

print(f"Train: {len(train_data)} samples")
print(f"Validation: {len(val_data)} samples")
print(f"Test: {len(test_data)} samples")
print(f"\nSample entry:")
print(json.dumps(train_data[0], ensure_ascii=False, indent=2))
print("\nUsing fixed splits for consistency with Model 3")
print("Note: simplification_type will be used during training to condition the model on the type of simplification needed")


Train: 1700 samples
Validation: 200 samples
Test: 100 samples

Sample entry:
{
  "legal_sentence": "1596. እንደአግባቡ ማመልከቻውን ያቀረበው ሰው ወይም የከሰረው ሰው ንብረት ጠባቂ ተቆጣጣሪ ዳኛው በሰጠው ውሳኔ ላይ ውሳኔው በተሰጠ በአስር ቀናት ውስጥ ለፍርድ ቤት ይግባኝ ማቅረብ ይችላል ።",
  "simplified_sentence": "ባለቤቱ ወይም የንብረት ጠባቂው ዳኛው በሰጠው ውሳኔ ካልተስማሙ በ10 ቀናት ውስጥ ለፍርድ ቤት ይግባኝ ማለት ይችላሉ።",
  "simplification_type": "structure_reordering"
}

Using fixed splits for consistency with Model 3
Note: simplification_type will be used during training to condition the model on the type of simplification needed


## 2. Dataset Split


In [5]:
# Data already loaded from pre-split files in Cell 4
# No random splitting needed - using fixed splits for consistency
print(f"Train: {len(train_data)}")
print(f"Validation: {len(val_data)}")
print(f"Test: {len(test_data)}")
print("\nUsing fixed splits (no random splitting)")


Train: 1700
Validation: 200
Test: 100

Using fixed splits (no random splitting)


## 3. Load Model and Tokenizer

**Model**: `masakhane/afribyt5-base`  
**Sequence Lengths**:
- Max input length: 512 bytes
- Max output length: 256 bytes


In [6]:
model_name = "masakhane/afri-byt5-base"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sequence length settings for byte-level ByT5
max_input_length = 512  # tokens
max_output_length = 384  # tokens (increased from 256 to reduce truncation)

print(f"Tokenizer vocab size: {tokenizer.vocab_size}")
print(f"Max input length: {max_input_length} tokens")
print(f"Max output length: {max_output_length} tokens")


tokenizer_config.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

Tokenizer vocab size: 256
Max input length: 512 tokens
Max output length: 384 tokens


In [7]:
# Load model
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
print(f"Model loaded: {model_name}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")


config.json:   0%|          | 0.00/735 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.33G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.33G [00:00<?, ?B/s]

Model loaded: masakhane/afri-byt5-base
Model parameters: 581,653,248


## 4. Configure LoRA

LoRA is applied to encoder and decoder linear layers as a regularization mechanism to reduce overfitting on a small dataset, while still allowing end-to-end adaptation.


In [8]:
# LoRA configuration
# Increased capacity: r=32, alpha=64 to address underfitting risk
# Expanded target_modules to include all attention components for better learning
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=32,  # Increased from 16
    lora_alpha=64,  # Increased from 32 (proportional to r)
    lora_dropout=0.1,
    target_modules=["q", "v", "k", "o"],  # Expanded from ["q", "v"] to include all attention components
    bias="none",
)

# Apply LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Enable gradient requirements for PEFT models
if hasattr(model, 'enable_input_require_grads'):
    model.enable_input_require_grads()

# Move model to device
model = model.to(device)
print(f"Model moved to {device}")


trainable params: 8,847,360 || all params: 590,500,608 || trainable%: 1.4983
Model moved to cuda


## 5. Preprocess Data


In [9]:
def preprocess_function(examples):
    """Tokenize and prepare inputs/targets with simplification instruction and type

    The model learns to condition on simplification_type during training.
    This helps the model apply the appropriate simplification strategy.
    """
    inputs = examples["legal_sentence"] if isinstance(examples["legal_sentence"], list) else [examples["legal_sentence"]]
    targets = examples["simplified_sentence"] if isinstance(examples["simplified_sentence"], list) else [examples["simplified_sentence"]]
    sim_types = examples.get("simplification_type", ["unknown"] * len(inputs))
    if not isinstance(sim_types, list):
        sim_types = [sim_types]

    # Add instruction prompt with simplification_type to guide the model
    # This teaches the model to condition on the type of simplification needed
    base_instruction = "የሕግ ቃላትን ለግለሰቦች ለመረዳት ቀላል አማርኛ ውስጥ አቅርብ: "  # "Simplify legal text to plain Amharic: "

    # Map simplification_type to Amharic labels for the prompt
    type_map = {
        "vocabulary_simplification": "[የቃላት ማቃለል]",
        "sentence_splitting": "[የዓረፍተ ነገር መከፋፈል]",
        "deletion": "[መሻር]",
        "unknown": "[አጠቃላይ]"
    }

    # Create inputs with type information
    inputs = [base_instruction + type_map.get(sim_type, f"[{sim_type}]") + " " + inp
              for inp, sim_type in zip(inputs, sim_types)]

    # Tokenize inputs (no padding - let data collator handle it)
    model_inputs = tokenizer(
        inputs,
        max_length=max_input_length,
        truncation=True
    )

    # Tokenize targets (no padding - let data collator handle it)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            targets,
            max_length=max_output_length,
            truncation=True
        )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Convert to HuggingFace datasets
train_dataset = HFDataset.from_list(train_data)
val_dataset = HFDataset.from_list(val_data)
test_dataset = HFDataset.from_list(test_data)

# Tokenize datasets
train_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=train_dataset.column_names
)
val_dataset = val_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=val_dataset.column_names
)

print("Datasets tokenized and ready")


Map:   0%|          | 0/1700 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Datasets tokenized and ready


## 6. Setup Metrics

**Primary Metric**: SARI (System output Against References and Inputs)  
**Secondary Metric**: BERTScore (multilingual)

Automatic metrics are complemented with qualitative evaluation on held-out legal sentences to verify preservation of legal meaning.


In [10]:
# Load metrics
sari_metric = load("sari")
bertscore_metric = load("bertscore")
bleu_metric = load("sacrebleu")

def compute_metrics(eval_pred):
    """Compute SARI and BERTScore"""
    predictions, labels = eval_pred

    # Handle nested arrays and convert to numpy if needed
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    if isinstance(labels, tuple):
        labels = labels[0]

    # Convert to numpy arrays if they're not already
    predictions = np.array(predictions)
    labels = np.array(labels)

    # If predictions are logits (shape has extra dimension), take argmax
    if len(predictions.shape) > 1 and predictions.shape[-1] > 1:
        predictions = np.argmax(predictions, axis=-1)

    # Replace -100 in labels with pad_token_id for decoding (data collator uses -100 for ignored tokens)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    # Decode predictions and labels
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Compute BERTScore
    bertscore_result = bertscore_metric.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        lang="am",  # Amharic
        device=device
    )

    return {
        "bertscore_f1": np.mean(bertscore_result["f1"])
    }

# Note: SARI requires source sentences, which we'll compute separately during evaluation

# Note: SARI requires source sentences, which we'll compute separately during evaluation


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

## 7. Training Configuration

**Key Settings**:
- Label smoothing: 0.1
- Early stopping: patience 2 epochs, monitor validation SARI
- Learning rate: 2e-4
- Batch size: 8
- Gradient accumulation: 8 (effective batch size: 64)
- Epochs: 6 (increased from 4)
- LR scheduler: Cosine with warmup_ratio=0.1


In [11]:
# Data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True
)

# Training arguments
training_args = TrainingArguments(
    output_dir="./afribyt5-legal-simplification",
    overwrite_output_dir=True,
    num_train_epochs=6,  # Increased from 4 for better learning
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=8,  # Increased from 4 (effective batch size: 64)
    learning_rate=2e-4,  # Keep current (don't lower unless instability appears)
    lr_scheduler_type="cosine",  # Added: cosine decay for better convergence
    warmup_ratio=0.1,  # Added: warmup ratio (replaces warmup_steps)
    logging_steps=50,
    eval_steps=200,  # More frequent evaluation (from 300)
    save_steps=200,  # More frequent saving (from 300)
    eval_strategy="steps",
    save_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="eval_sari",
    greater_is_better=True,
    label_smoothing_factor=0.1,
    bf16=True,
    gradient_checkpointing=True,
    report_to="none",
)

print("Training arguments configured")


Training arguments configured


## 8. Custom Trainer with SARI-based Early Stopping

Early stopping patience: 2 epochs, monitoring validation SARI


In [12]:
class CustomTrainer(Trainer):
    """Custom trainer that computes SARI for early stopping"""

    def __init__(self, *args, source_sentences=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.source_sentences = source_sentences

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Override evaluate to include SARI"""
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        # Get predictions
        predictions = self.predict(eval_dataset)
        pred_texts = tokenizer.batch_decode(predictions.predictions, skip_special_tokens=True)
        label_texts = tokenizer.batch_decode(predictions.label_ids, skip_special_tokens=True)

        # Compute BERTScore
        bertscore_result = bertscore_metric.compute(
            predictions=pred_texts,
            references=label_texts,
            lang="am",
            device=device
        )

        # Compute SARI if we have source sentences
        sari_scores = []
        if self.source_sentences is not None:
            for i, (source, pred, ref) in enumerate(zip(self.source_sentences, pred_texts, label_texts)):
                try:
                    sari = sari_metric.compute(
                        sources=[source],
                        predictions=[pred],
                        references=[[ref]]
                    )
                    sari_scores.append(sari["sari"])
                except:
                    pass

        metrics = {
            f"{metric_key_prefix}_bertscore_f1": np.mean(bertscore_result["f1"]),
        }

        # Always compute SARI if source sentences are available
        if sari_scores:
            metrics[f"{metric_key_prefix}_sari"] = np.mean(sari_scores)
        else:
            # If SARI couldn't be computed, set a default value to avoid errors
            metrics[f"{metric_key_prefix}_sari"] = 0.0

        self.log(metrics)
        return metrics

# Prepare source sentences for validation set (for SARI computation)
val_source_sentences = [item["legal_sentence"] for item in val_data]

# Initialize trainer
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    source_sentences=val_source_sentences,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

print("Trainer initialized with SARI-based early stopping")


Trainer initialized with SARI-based early stopping


## 9. Train Model


In [13]:
import torch
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
model.gradient_checkpointing_enable()
model.config.use_cache = False

# Train
print("Starting training...")
train_result = trainer.train()
print("Training completed!")

# Save final model
trainer.save_model("./afribyt5-legal-simplification-final4")
tokenizer.save_pretrained("./afribyt5-legal-simplification-final4")


Starting training...


Step,Training Loss,Validation Loss


Training completed!


('./afribyt5-legal-simplification-final4/tokenizer_config.json',
 './afribyt5-legal-simplification-final4/special_tokens_map.json',
 './afribyt5-legal-simplification-final4/added_tokens.json')

## 10. Training Visualizations

Visualize training progress, validation metrics, and model performance.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Extract training history from trainer state
train_history = trainer.state.log_history

# Separate training and evaluation logs
train_logs = [log for log in train_history if 'loss' in log and 'eval_loss' not in log]
eval_logs = [log for log in train_history if 'eval_loss' in log]

# Extract data
train_steps = [log['step'] for log in train_logs]
train_losses = [log['loss'] for log in train_logs]

eval_steps = [log['step'] for log in eval_logs]
eval_losses = [log.get('eval_loss', 0) for log in eval_logs]
eval_sari = [log.get('eval_sari', 0) for log in eval_logs]
eval_bertscore = [log.get('eval_bertscore_f1', 0) for log in eval_logs]

# Create comprehensive figure
fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3)

# Plot 1: Training Loss (top left, spans 2 columns)
ax1 = fig.add_subplot(gs[0, :])
ax1.plot(train_steps, train_losses, 'b-', linewidth=2.5, label='Training Loss', alpha=0.8)
ax1.set_xlabel('Training Steps', fontsize=13, fontweight='bold')
ax1.set_ylabel('Loss', fontsize=13, fontweight='bold')
ax1.set_title('Training Loss Over Time', fontsize=15, fontweight='bold', pad=15)
ax1.grid(True, alpha=0.3, linestyle='--')
ax1.legend(fontsize=12, loc='best')
ax1.set_facecolor('#f8f9fa')

# Plot 2: Validation Loss
ax2 = fig.add_subplot(gs[1, 0])
if eval_losses and any(l > 0 for l in eval_losses):
    ax2.plot(eval_steps, eval_losses, 'r-', linewidth=2.5, marker='o', markersize=5, label='Validation Loss', alpha=0.8)
    ax2.set_xlabel('Training Steps', fontsize=12, fontweight='bold')
    ax2.set_ylabel('Loss', fontsize=12, fontweight='bold')
    ax2.set_title('Validation Loss', fontsize=14, fontweight='bold')
    ax2.grid(True, alpha=0.3, linestyle='--')
    ax2.legend(fontsize=11)
    ax2.set_facecolor('#f8f9fa')
else:
    ax2.text(0.5, 0.5, 'No validation data', ha='center', va='center', fontsize=12)
    ax2.set_title('Validation Loss', fontsize=14, fontweight='bold')

# Plot 3: SARI Score
ax3 = fig.add_subplot(gs[1, 1])
if eval_sari and any(s > 0 for s in eval_sari):
    ax3.plot(eval_steps, eval_sari, 'g-', linewidth=2.5, marker='s', markersize=5, label='SARI', alpha=0.8)
    ax3.set_xlabel('Training Steps', fontsize=12, fontweight='bold')
    ax3.set_ylabel('SARI Score', fontsize=12, fontweight='bold')
    ax3.set_title('SARI Score (Simplification Quality)', fontsize=14, fontweight='bold')
    ax3.grid(True, alpha=0.3, linestyle='--')
    ax3.legend(fontsize=11)
    ax3.set_ylim(bottom=0)
    ax3.set_facecolor('#f8f9fa')
else:
    ax3.text(0.5, 0.5, 'No SARI data', ha='center', va='center', fontsize=12)
    ax3.set_title('SARI Score', fontsize=14, fontweight='bold')

# Plot 4: BERTScore F1
ax4 = fig.add_subplot(gs[2, 0])
if eval_bertscore and any(b > 0 for b in eval_bertscore):
    ax4.plot(eval_steps, eval_bertscore, 'm-', linewidth=2.5, marker='^', markersize=5, label='BERTScore F1', alpha=0.8)
    ax4.set_xlabel('Training Steps', fontsize=12, fontweight='bold')
    ax4.set_ylabel('BERTScore F1', fontsize=12, fontweight='bold')
    ax4.set_title('BERTScore F1 (Semantic Similarity)', fontsize=14, fontweight='bold')
    ax4.grid(True, alpha=0.3, linestyle='--')
    ax4.legend(fontsize=11)
    ax4.set_ylim(bottom=0, top=1)
    ax4.set_facecolor('#f8f9fa')
else:
    ax4.text(0.5, 0.5, 'No BERTScore data', ha='center', va='center', fontsize=12)
    ax4.set_title('BERTScore F1', fontsize=14, fontweight='bold')

# Plot 5: Combined Loss Plot (Training vs Validation)
ax5 = fig.add_subplot(gs[2, 1])
if eval_losses and any(l > 0 for l in eval_losses):
    ax5.plot(train_steps, train_losses, 'b-', linewidth=2, label='Training Loss', alpha=0.7)
    ax5.plot(eval_steps, eval_losses, 'r-', linewidth=2, marker='o', markersize=4, label='Validation Loss', alpha=0.7)
    ax5.set_xlabel('Training Steps', fontsize=12, fontweight='bold')
    ax5.set_ylabel('Loss', fontsize=12, fontweight='bold')
    ax5.set_title('Training vs Validation Loss', fontsize=14, fontweight='bold')
    ax5.grid(True, alpha=0.3, linestyle='--')
    ax5.legend(fontsize=11)
    ax5.set_facecolor('#f8f9fa')
else:
    ax5.plot(train_steps, train_losses, 'b-', linewidth=2, label='Training Loss', alpha=0.7)
    ax5.set_xlabel('Training Steps', fontsize=12, fontweight='bold')
    ax5.set_ylabel('Loss', fontsize=12, fontweight='bold')
    ax5.set_title('Training Loss', fontsize=14, fontweight='bold')
    ax5.grid(True, alpha=0.3, linestyle='--')
    ax5.legend(fontsize=11)
    ax5.set_facecolor('#f8f9fa')

# Add overall title
fig.suptitle('Training Progress: AfriByT5 Legal Simplification', fontsize=18, fontweight='bold', y=0.995)

plt.tight_layout(rect=[0, 0, 1, 0.99])
plt.show()

# Save high-resolution plot
plot_path = "./afribyt5-legal-simplification_training_plots.png"
fig.savefig(plot_path, dpi=300, bbox_inches='tight', facecolor='white')
print(f"\nTraining plots saved to: {plot_path}")

# Print detailed summary statistics
print("\n" + "="*70)
print("TRAINING SUMMARY: AfriByT5 Legal Simplification")
print("="*70)
if train_losses:
    print(f"\nTraining Loss:")
    print(f"  Initial: {train_losses[0]:.4f}")
    print(f"  Final: {train_losses[-1]:.4f}")
    print(f"  Best: {min(train_losses):.4f} (at step {train_steps[train_losses.index(min(train_losses))]})")
    print(f"  Improvement: {((train_losses[0] - min(train_losses)) / train_losses[0] * 100):.2f}%")
if eval_losses and any(l > 0 for l in eval_losses):
    valid_losses = [l for l in eval_losses if l > 0]
    valid_steps = [eval_steps[i] for i, l in enumerate(eval_losses) if l > 0]
    print(f"\nValidation Loss:")
    print(f"  Initial: {valid_losses[0]:.4f}")
    print(f"  Final: {valid_losses[-1]:.4f}")
    print(f"  Best: {min(valid_losses):.4f} (at step {valid_steps[valid_losses.index(min(valid_losses))]})")
    print(f"  Improvement: {((valid_losses[0] - min(valid_losses)) / valid_losses[0] * 100):.2f}%")
if eval_sari and any(s > 0 for s in eval_sari):
    valid_sari = [s for s in eval_sari if s > 0]
    valid_steps = [eval_steps[i] for i, s in enumerate(eval_sari) if s > 0]
    print(f"\nSARI Score:")
    print(f"  Initial: {valid_sari[0]:.4f}")
    print(f"  Final: {valid_sari[-1]:.4f}")
    print(f"  Best: {max(valid_sari):.4f} (at step {valid_steps[valid_sari.index(max(valid_sari))]})")
    print(f"  Improvement: {((max(valid_sari) - valid_sari[0]) / valid_sari[0] * 100):.2f}%")
if eval_bertscore and any(b > 0 for b in eval_bertscore):
    valid_bertscore = [b for b in eval_bertscore if b > 0]
    valid_steps = [eval_steps[i] for i, b in enumerate(eval_bertscore) if b > 0]
    print(f"\nBERTScore F1:")
    print(f"  Initial: {valid_bertscore[0]:.4f}")
    print(f"  Final: {valid_bertscore[-1]:.4f}")
    print(f"  Best: {max(valid_bertscore):.4f} (at step {valid_steps[valid_bertscore.index(max(valid_bertscore))]})")
    print(f"  Improvement: {((max(valid_bertscore) - valid_bertscore[0]) / valid_bertscore[0] * 100):.2f}%")
print("="*70)

In [14]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [15]:
!zip -r afribyt5-legal-simplification-final4.zip afribyt5-legal-simplification-final4
!cp afribyt5-legal-simplification-final4.zip /content/drive/MyDrive/

  adding: afribyt5-legal-simplification-final4/ (stored 0%)
  adding: afribyt5-legal-simplification-final4/adapter_model.safetensors (deflated 8%)
  adding: afribyt5-legal-simplification-final4/special_tokens_map.json (deflated 86%)
  adding: afribyt5-legal-simplification-final4/added_tokens.json (deflated 82%)
  adding: afribyt5-legal-simplification-final4/adapter_config.json (deflated 57%)
  adding: afribyt5-legal-simplification-final4/tokenizer_config.json (deflated 95%)
  adding: afribyt5-legal-simplification-final4/training_args.bin (deflated 53%)
  adding: afribyt5-legal-simplification-final4/README.md (deflated 66%)


In [16]:
# Create a truly unseen test set
print("=== Creating Clean Test Set ===")

# Step 1: Get all sentences that were used in training/validation
train_sentences = set([item["legal_sentence"] for item in train_data])
val_sentences = set([item["legal_sentence"] for item in val_data])
all_seen_sentences = train_sentences.union(val_sentences)

print(f"Sentences in train: {len(train_sentences)}")
print(f"Sentences in val: {len(val_sentences)}")
print(f"Total unique sentences seen during training: {len(all_seen_sentences)}")

# Step 2: From the ORIGINAL data list, find items NOT in train or val
unseen_data = []
for item in data:  # Go through ALL original data
    if item["legal_sentence"] not in all_seen_sentences:
        unseen_data.append(item)

print(f"\nOriginal data size: {len(data)}")
print(f"Unseen data available: {len(unseen_data)}")

# Step 3: Take first 100 (or available) from truly unseen data
clean_test_size = min(100, len(unseen_data))
clean_test_data = unseen_data[:clean_test_size]

print(f"Created clean test set with {len(clean_test_data)} samples")
print(f"   (These samples were NEVER in train or val)")

# Step 4: Verify no overlap
clean_test_sentences = set([item["legal_sentence"] for item in clean_test_data])
overlap = all_seen_sentences.intersection(clean_test_sentences)

if overlap:
    print(f"ERROR: Still found {len(overlap)} overlapping sentences!")
    print("This shouldn't happen. Check for duplicate entries in original data.")
else:
    print("Verified: Clean test set has ZERO overlap with train or val!")

=== Creating Clean Test Set ===
Sentences in train: 1667
Sentences in val: 200
Total unique sentences seen during training: 1857


NameError: name 'data' is not defined

## 10. Evaluation on Test Set

Evaluate on the held-out test set (100 samples) that was kept unseen during training.


In [None]:
# Prepare CLEAN test dataset (truly unseen)
clean_test_source_sentences = [item["legal_sentence"] for item in clean_test_data]
clean_test_target_sentences = [item["simplified_sentence"] for item in clean_test_data]
# Keep simplification_type for analysis
clean_test_simplification_types = [item.get("simplification_type", "unknown") for item in clean_test_data]

# Add instruction prompt with simplification_type for inference (same as training)
base_instruction = "የሕግ ቃላትን ለግለሰቦች ለመረዳት ቀላል አማርኛ ውስጥ አቅርብ: "
type_map = {
    "vocabulary_simplification": "[የቃላት ማቃለል]",
    "sentence_splitting": "[የዓረፍተ ነገር መከፋፈል]",
    "deletion": "[መሻር]",
    "unknown": "[አጠቃላይ]"
}
prompted_test_sentences = [
    base_instruction + type_map.get(sim_type, f"[{sim_type}]") + " " + sent
    for sent, sim_type in zip(clean_test_source_sentences, clean_test_simplification_types)
]

# Tokenize inputs for inference
clean_test_inputs = tokenizer(
    prompted_test_sentences,
    max_length=max_input_length,
    truncation=True,
    padding=True,
    return_tensors="pt"
).to(device)

# Set model to eval mode
model.eval()

# Generate predictions directly (bypassing trainer.predict)
print("Evaluating on CLEAN test set (truly unseen)...")
clean_test_pred_texts = []

with torch.no_grad():
    for i in range(0, len(clean_test_source_sentences), 8):  # Process in batches of 8
        batch_input_ids = clean_test_inputs["input_ids"][i:i+8]
        batch_attention_mask = clean_test_inputs["attention_mask"][i:i+8]

        # Generate with improved parameters
        outputs = model.generate(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask,
            max_new_tokens=384,  # Use max_new_tokens instead of max_length (tokens, not bytes)
            num_beams=4,
            early_stopping=True,
            do_sample=False,
            repetition_penalty=1.3,  # Penalize repetition (decoding fix)
            no_repeat_ngram_size=3,  # Prevent 3-gram repetition (decoding fix)
            length_penalty=0.7  # Encourage shorter outputs (reduced from 1.1 to help with truncation)
        )

        # Decode batch
        batch_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        clean_test_pred_texts.extend(batch_preds)

        if (i // 8 + 1) % 10 == 0:
            print(f"Processed {min(i+8, len(clean_test_source_sentences))}/{len(clean_test_source_sentences)} samples...")

print(f"Generated {len(clean_test_pred_texts)} predictions")

# Use original target sentences for metrics
clean_test_label_texts = clean_test_target_sentences

# Compute metrics
print("\n=== CLEAN Test Set Results (Truly Unseen) ===")

# BERTScore
bertscore_clean = bertscore_metric.compute(
    predictions=clean_test_pred_texts,
    references=clean_test_label_texts,
    lang="am",
    device=device
)
print(f"BERTScore F1: {np.mean(bertscore_clean['f1']):.4f}")

# SARI
sari_scores_clean = []
for source, pred, ref in zip(clean_test_source_sentences, clean_test_pred_texts, clean_test_label_texts):
    try:
        sari = sari_metric.compute(
            sources=[source],
            predictions=[pred],
            references=[[ref]]
        )
        sari_scores_clean.append(sari["sari"])
    except Exception as e:
        print(f"SARI computation error: {e}")
        pass

if sari_scores_clean:
    print(f"SARI: {np.mean(sari_scores_clean):.4f}")

# BLEU Score
bleu_result = bleu_metric.compute(
    predictions=clean_test_pred_texts,
    references=[[ref] for ref in clean_test_label_texts]
)
print(f"BLEU: {bleu_result['score']:.4f}")

# Exact Match
exact_matches = sum(1 for pred, ref in zip(clean_test_pred_texts, clean_test_label_texts) if pred == ref)
exact_match_rate = exact_matches / len(clean_test_pred_texts)
print(f"Exact Match: {exact_match_rate:.4f} ({exact_matches}/{len(clean_test_pred_texts)})")

# Length Statistics
source_lengths = [len(s) for s in clean_test_source_sentences]
pred_lengths = [len(p) for p in clean_test_pred_texts]
ref_lengths = [len(r) for r in clean_test_label_texts]

print(f"\n=== Length Statistics ===")
print(f"Source (original):")
print(f"  Mean: {np.mean(source_lengths):.1f} chars, Median: {np.median(source_lengths):.1f} chars")
print(f"  Min: {np.min(source_lengths)} chars, Max: {np.max(source_lengths)} chars")
print(f"\nPrediction (simplified):")
print(f"  Mean: {np.mean(pred_lengths):.1f} chars, Median: {np.median(pred_lengths):.1f} chars")
print(f"  Min: {np.min(pred_lengths)} chars, Max: {np.max(pred_lengths)} chars")
print(f"\nReference (target):")
print(f"  Mean: {np.mean(ref_lengths):.1f} chars, Median: {np.median(ref_lengths):.1f} chars")
print(f"  Min: {np.min(ref_lengths)} chars, Max: {np.max(ref_lengths)} chars")

# Length ratio (prediction vs source)
length_ratios = [p/s if s > 0 else 0 for p, s in zip(pred_lengths, source_lengths)]
print(f"\nLength Ratio (Prediction/Source):")
print(f"  Mean: {np.mean(length_ratios):.3f}, Median: {np.median(length_ratios):.3f}")
print(f"  (Values < 1.0 indicate shortening)")

print(f"\nEvaluated {len(clean_test_pred_texts)} CLEAN test samples (truly unseen)")


In [None]:
# Qualitative Analysis - Check if model is actually simplifying
print("=== Qualitative Analysis: Is the model really simplifying? ===\n")

num_samples = 10
for i in range(min(num_samples, len(clean_test_source_sentences))):
    print(f"Sample {i+1}:")
    print(f"Original Legal ({len(clean_test_source_sentences[i])} chars):")
    print(f"  {clean_test_source_sentences[i]}")
    print(f"\nReference Simplified ({len(clean_test_target_sentences[i])} chars):")
    print(f"  {clean_test_target_sentences[i]}")
    print(f"\nModel Prediction ({len(clean_test_pred_texts[i])} chars):")
    print(f"  {clean_test_pred_texts[i]}")

    # Check similarity
    if clean_test_pred_texts[i] == clean_test_target_sentences[i]:
        print("EXACT MATCH - Model might be memorizing!")
    elif clean_test_pred_texts[i] == clean_test_source_sentences[i]:
        print("NO CHANGE - Model not simplifying!")

    print("-" * 80)
    print()

In [None]:
# Analyze performance by simplification_type
print("=== Performance Analysis by Simplification Type ===\n")

from collections import defaultdict

# Group predictions by simplification_type
type_metrics = defaultdict(lambda: {"sari": [], "bertscore": [], "count": 0, "truncation": 0, "repetition": 0})

for i, (source, pred, ref, sim_type) in enumerate(zip(
    clean_test_source_sentences,
    clean_test_pred_texts,
    clean_test_label_texts,
    clean_test_simplification_types
)):
    sim_type = sim_type if sim_type else "unknown"
    type_metrics[sim_type]["count"] += 1

    # Compute SARI for this sample
    try:
        sari = sari_metric.compute(
            sources=[source],
            predictions=[pred],
            references=[[ref]]
        )
        type_metrics[sim_type]["sari"].append(sari["sari"])
    except:
        pass

    # Compute BERTScore for this sample
    try:
        bert = bertscore_metric.compute(
            predictions=[pred],
            references=[ref],
            lang="am",
            device=device
        )
        type_metrics[sim_type]["bertscore"].append(np.mean(bert["f1"]))
    except:
        pass

    # Check truncation
    if pred and pred[-1] not in ['።', '፤', '፥', '፦', '.', '!', '?']:
        type_metrics[sim_type]["truncation"] += 1

    # Check repetition
    words = pred.split()
    if len(words) > 3:
        for j in range(len(words) - 2):
            phrase = " ".join(words[j:j+3])
            if pred.count(phrase) > 1:
                type_metrics[sim_type]["repetition"] += 1
                break

# Print results by type
print("Performance breakdown by simplification type:\n")
for sim_type in sorted(type_metrics.keys()):
    metrics = type_metrics[sim_type]
    count = metrics["count"]

    print(f"--- {sim_type.upper()} ({count} samples) ---")
    if metrics["sari"]:
        print(f"  Mean SARI: {np.mean(metrics['sari']):.4f}")
    if metrics["bertscore"]:
        print(f"  Mean BERTScore F1: {np.mean(metrics['bertscore']):.4f}")
    print(f"  Truncation: {metrics['truncation']}/{count} ({100*metrics['truncation']/count:.1f}%)")
    print(f"  Repetition: {metrics['repetition']}/{count} ({100*metrics['repetition']/count:.1f}%)")
    print()

print(f"\nTotal types analyzed: {len(type_metrics)}")


In [None]:
# Check for repetition and truncation issues
print("=== Model Quality Issues ===")

repetition_count = 0
truncation_count = 0
no_simplification = 0

for i, (orig, pred, ref) in enumerate(zip(clean_test_source_sentences, clean_test_pred_texts, clean_test_target_sentences)):
    # Check for repetition (same phrase appears 2+ times)
    words = pred.split()
    if len(words) > 3:
        # Check for 3+ word phrases repeating
        for j in range(len(words) - 2):
            phrase = " ".join(words[j:j+3])
            if pred.count(phrase) > 1:
                repetition_count += 1
                if repetition_count <= 3:
                    print(f"Repetition in sample {i+1}: '{phrase}'")
                break

    # Check for truncation (ends abruptly, not with punctuation)
    if pred and pred[-1] not in ['።', '፤', '፥', '፦', '.', '!', '?']:
        truncation_count += 1

    # Check if actually simplified (should be shorter)
    if len(pred) >= len(orig) * 0.9:  # Less than 10% reduction
        no_simplification += 1

print(f"\nRepetition issues: {repetition_count}/{len(clean_test_pred_texts)} ({100*repetition_count/len(clean_test_pred_texts):.1f}%)")
print(f"Truncation issues: {truncation_count}/{len(clean_test_pred_texts)} ({100*truncation_count/len(clean_test_pred_texts):.1f}%)")
print(f"No simplification: {no_simplification}/{len(clean_test_pred_texts)} ({100*no_simplification/len(clean_test_pred_texts):.1f}%)")

## 11. Qualitative Evaluation

Automatic metrics are complemented with qualitative evaluation on held-out legal sentences to verify preservation of legal meaning.


In [None]:
# Show sample predictions for qualitative evaluation
print("=== Sample Predictions (Qualitative Evaluation) ===\n")

num_samples = 5
for i in range(min(num_samples, len(test_data))):
    print(f"Sample {i+1}:")
    print(f"Original Legal: {test_source_sentences[i]}")
    print(f"Reference Simplified: {test_target_sentences[i]}")
    print(f"Model Prediction: {test_pred_texts[i]}")
    print("-" * 80)
    print()


## 12. Inference Function

Function to use the fine-tuned model for inference.


In [None]:
def simplify_legal_text(legal_sentence, model, tokenizer, simplification_type=None, max_input_length=512, max_output_length=384):
    # ...
    base_instruction = "የሕግ ቃላትን ለግለሰቦች ለመረዳት ቀላል አማርኛ ውስጥ አቅርብ: "
    type_map = {
        "vocabulary_simplification": "[የቃላት ማቃለል]",
        "sentence_splitting": "[የዓረፍተ ነገር መከፋፈል]",
        "deletion": "[መሻር]",
        "unknown": "[አጠቃላይ]",
        None: "[አጠቃላይ]"
    }
    type_label = type_map.get(simplification_type, "[አጠቃላይ]")
    prompted_input = base_instruction + type_label + " " + legal_sentence

    # Tokenize input with prompt
    inputs = tokenizer(
        prompted_input,
        max_length=max_input_length,
        truncation=True,
        padding=True,
        return_tensors="pt"
    ).to(device)

    # Generate with improved parameters
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_output_length,  # Use max_new_tokens instead of max_length
            num_beams=4,
            early_stopping=True,
            repetition_penalty=1.3,  # Penalize repetition
            no_repeat_ngram_size=3,  # Prevent 3-gram repetition
            length_penalty=0.7  # Encourage shorter outputs (reduced from 1.1)
        )

    # Decode
    simplified = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return simplified

# Test inference
test_sentence = test_source_sentences[0]
simplified = simplify_legal_text(test_sentence, model, tokenizer)
print(f"Original: {test_sentence}")
print(f"Simplified: {simplified}")


## Summary

This notebook fine-tuned AfriByT5-Base for Amharic legal text simplification with:

- **Model**: AfriByT5-Base (byte-level, African-language priors)
- **LoRA**: Regularization mechanism (not freezing, end-to-end adaptation)
  - **Capacity**: r=32, alpha=64 (increased from r=16, alpha=32)
  - **Target modules**: ["q", "v", "k", "o"] (expanded from ["q", "v"])
- **Data Split**: 1,700 train / 200 validation / 82 test (truly unseen)
- **Training**:
  - **Epochs**: 6 (increased from 4)
  - **Gradient accumulation**: 8 (effective batch size: 64)
  - **LR scheduler**: Cosine with warmup_ratio=0.1
  - **Early stopping**: Patience 2 epochs, monitoring validation SARI
- **Label Smoothing**: 0.1
- **Sequence Lengths**: 512 tokens input / 384 tokens output (increased from 256)
- **Generation**:
  - **Repetition penalty**: 1.3
  - **No-repeat n-gram size**: 3
  - **Length penalty**: 1.1
- **Metrics**: SARI (primary), BERTScore multilingual (secondary)
- **Qualitative Evaluation**: Manual review of predictions to verify legal meaning preservation

The model learns simplification behavior, while legal knowledge is supplied at inference time via RAG.
