In [1]:
%pip install -q transformers datasets peft accelerate bitsandbytes evaluate bert-score sacrebleu torch
%pip install -q sentencepiece sacremoses nltk


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m21.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    DataCollatorForSeq2Seq
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import Dataset as HFDataset
import numpy as np
from evaluate import load
import warnings
warnings.filterwarnings('ignore')

import random

# Use the GPU when one is available; later cells move the model to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed every RNG explicitly. The LoRA adapter weights are initialised in
# get_peft_model(), before the Trainer seeds itself from TrainingArguments.seed
# (default 42), so without this the adapter init is not reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices


Using device: cuda


In [3]:
# Download NLTK data (required for SARI metric)
# Make sure NLTK's Punkt sentence-tokenizer data is present (the original note
# says the SARI metric needs it). Both the classic "punkt" resource and the
# newer "punkt_tab" format are fetched, each only when it is not found locally.
import nltk

for _punkt_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_punkt_resource}")
    except LookupError:
        nltk.download(_punkt_resource, quiet=True)
print("NLTK data ready")

NLTK data ready


## 1. Load and Prepare Dataset


In [4]:
# Load the pre-split datasets (created with create_fixed_splits.py) so this model
# is trained and evaluated on exactly the same splits as Model 4.
# Only legal_sentence and simplified_sentence are kept; any other fields in the
# records (e.g. simplification_type) are ignored.
DATA_DIR = "/kaggle/input/legal-sentence-simplifier"
KEPT_FIELDS = ("legal_sentence", "simplified_sentence")


def _read_split(split_name):
    """Return the raw JSON records stored in ``regen_<split_name>.json``."""
    with open(f"{DATA_DIR}/regen_{split_name}.json", "r", encoding="utf-8") as f:
        return json.load(f)


def _keep_text_fields(records):
    """Project each record onto KEPT_FIELDS, preserving that key order.

    Raises KeyError if a record is missing one of the fields.
    """
    return [{field: item[field] for field in KEPT_FIELDS} for item in records]


train_data_raw = _read_split("train")
val_data_raw = _read_split("val")
test_data_raw = _read_split("test")

train_data = _keep_text_fields(train_data_raw)
val_data = _keep_text_fields(val_data_raw)
test_data = _keep_text_fields(test_data_raw)

# train_data[0] is displayed below and later cells need non-empty splits,
# so fail early with a clear message instead of an IndexError.
assert train_data and val_data and test_data, "one of the data splits is empty"

print(f"Train: {len(train_data)} samples")
print(f"Validation: {len(val_data)} samples")
print(f"Test: {len(test_data)} samples")
print("\nSample entry:")
print(json.dumps(train_data[0], ensure_ascii=False, indent=2))
print("\nUsing fixed splits for consistency with Model 4")
print("Note: simplification_type field is ignored - only using legal_sentence and simplified_sentence for training")


Train: 1700 samples
Validation: 200 samples
Test: 100 samples

Sample entry:
{
  "legal_sentence": "1596. እንደአግባቡ ማመልከቻውን ያቀረበው ሰው ወይም የከሰረው ሰው ንብረት ጠባቂ ተቆጣጣሪ ዳኛው በሰጠው ውሳኔ ላይ ውሳኔው በተሰጠ በአስር ቀናት ውስጥ ለፍርድ ቤት ይግባኝ ማቅረብ ይችላል ።",
  "simplified_sentence": "ባለቤቱ ወይም የንብረት ጠባቂው ዳኛው በሰጠው ውሳኔ ካልተስማሙ በ10 ቀናት ውስጥ ለፍርድ ቤት ይግባኝ ማለት ይችላሉ።"
}

Using fixed splits for consistency with Model 4
Note: simplification_type field is ignored - only using legal_sentence and simplified_sentence for training


## 2. Dataset Split


In [5]:
# The fixed splits were already loaded from disk in Cell 4, so nothing is
# re-split here; this cell only re-reports the split sizes.
for split_label, split_records in (
    ("Train", train_data),
    ("Validation", val_data),
    ("Test", test_data),
):
    print(f"{split_label}: {len(split_records)}")
print("\nUsing fixed splits (no random splitting)")


Train: 1700
Validation: 200
Test: 100

Using fixed splits (no random splitting)


## 3. Load Model and Tokenizer

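**Model**: `masakhane/afri-byt5-base`  
**Sequence Lengths** (ByT5 is byte-level: one token per UTF-8 byte, and each Ge'ez character is 3 bytes):
- Max input length: 512 bytes (including the instruction prefix)
- Max output length: 384 bytes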


In [6]:
model_name = "masakhane/afri-byt5-base"

# ByT5 is byte-level: one token per UTF-8 byte (plus special tokens). Ge'ez
# (Amharic) characters take 3 bytes each in UTF-8, so these budgets are far
# shorter in characters than they look.
max_input_length = 512   # input budget; also covers the instruction prefix added in preprocessing
max_output_length = 384  # raised from 256 to reduce truncation of target sentences

tokenizer = AutoTokenizer.from_pretrained(model_name)

print(f"Tokenizer vocab size: {tokenizer.vocab_size}")
for length_label, length_value in (
    ("Max input length", max_input_length),
    ("Max output length", max_output_length),
):
    print(f"{length_label}: {length_value} tokens")


tokenizer_config.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

Tokenizer vocab size: 256
Max input length: 512 tokens
Max output length: 384 tokens


In [7]:
# Load the pretrained seq2seq checkpoint and report its total parameter count.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
total_param_count = sum(param.numel() for param in model.parameters())
print(f"Model loaded: {model_name}")
print(f"Model parameters: {total_param_count:,}")


config.json:   0%|          | 0.00/735 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.33G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.33G [00:00<?, ?B/s]

Model loaded: masakhane/afri-byt5-base
Model parameters: 581,653,248


## 4. Configure LoRA

LoRA adapters are applied to the attention projections (`q`, `k`, `v`, `o`) of both the encoder and the decoder. Training only these low-rank adapters acts as a regularizer that limits overfitting on this small dataset, while still adapting the model end to end.


In [8]:
# LoRA adapters on all four attention projections (q, k, v, o).
# r=32 / alpha=64 doubles capacity over the earlier r=16 / alpha=32 run while
# keeping the alpha/r scaling at 2; the wider target list (previously only
# q and v) is meant to address underfitting.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    target_modules=["q", "v", "k", "o"],
    r=32,
    lora_alpha=64,
    lora_dropout=0.1,
    bias="none",
)

# Wrap the base model so only the adapter weights are trainable.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Gradient checkpointing is enabled later; with frozen embeddings the inputs must
# require grad so gradients can flow through the checkpointed blocks.
enable_input_grads = getattr(model, 'enable_input_require_grads', None)
if enable_input_grads is not None:
    enable_input_grads()

model = model.to(device)
print(f"Model moved to {device}")


trainable params: 8,847,360 || all params: 590,500,608 || trainable%: 1.4983
Model moved to cuda


## 5. Preprocess Data


In [9]:
def preprocess_function(examples):
    """Tokenize source/target sentences into model inputs and labels.

    Prepends an Amharic simplification instruction to every legal sentence and
    tokenizes it, truncated to ``max_input_length``. Tokenizes the simplified
    sentence as the labels, truncated to ``max_output_length``. No padding is
    applied here; ``DataCollatorForSeq2Seq`` pads each batch dynamically.

    Args:
        examples: mapping with "legal_sentence" and "simplified_sentence".
            Each value may be a single string or a list of strings
            (``Dataset.map(batched=True)`` passes lists).

    Returns:
        The tokenizer's encoding of the prompted inputs, with an added
        "labels" entry holding the target token ids.
    """
    def _as_list(value):
        # Accept both batched (list) and single-example (str) inputs.
        return value if isinstance(value, list) else [value]

    inputs = _as_list(examples["legal_sentence"])
    targets = _as_list(examples["simplified_sentence"])

    # Instruction prompt telling the model that simplification is the task.
    # English gloss (from the original author): "Simplify legal text to plain Amharic: "
    instruction = "የሕግ ቃላትን ለግለሰቦች ለመረዳት ቀላል አማርኛ ውስጥ አቅርብ: "
    inputs = [instruction + inp for inp in inputs]

    model_inputs = tokenizer(
        inputs,
        max_length=max_input_length,
        truncation=True
    )

    # `text_target=` replaces the deprecated `as_target_tokenizer()` context
    # manager. The targets are tokenized in a separate call so they can use
    # their own max length.
    labels = tokenizer(
        text_target=targets,
        max_length=max_output_length,
        truncation=True
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

def _tokenize_split(raw_dataset):
    """Run preprocess_function in batches and drop the original text columns."""
    return raw_dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=raw_dataset.column_names,
    )


# Train/validation are tokenized for the Trainer; the test split is kept as
# raw text (legal_sentence / simplified_sentence).
test_dataset = HFDataset.from_list(test_data)
train_dataset = _tokenize_split(HFDataset.from_list(train_data))
val_dataset = _tokenize_split(HFDataset.from_list(val_data))

print("Datasets tokenized and ready")


Map:   0%|          | 0/1700 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Datasets tokenized and ready


## 6. Setup Metrics

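**Primary Metric**: SARI (System output Against References and Inputs)  
**Secondary Metric**: BERTScore (multilingual)

Automatic metrics are complemented with qualitative evaluation on held-out legal sentences to verify preservation of legal meaning.

> **Note:** metric computation during training is currently disabled (`compute_metrics` returns an empty dict). Checkpoint selection and early stopping are therefore driven by validation loss, and SARI/BERTScore have to be computed separately on generated outputs.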


In [10]:
def compute_metrics(eval_pred):
    """Placeholder metric hook: ignores ``eval_pred`` and reports no metrics.

    SARI/BERTScore computation is switched off during training, so evaluation
    only logs the loss.

    NOTE(review): when a non-None compute_metrics is passed to ``Trainer``,
    ``Trainer.evaluate`` still gathers full predictions before calling it.
    Returning an empty dict therefore does not by itself speed evaluation up;
    passing ``compute_metrics=None`` does.
    """
    no_metrics = dict()
    return no_metrics



Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

## 7. Training Configuration

**Key Settings**:
- Label smoothing: 0.1
- Early stopping: patience of 3 evaluations (evaluated every 40 steps), monitoring validation loss
- Learning rate: 2e-4
- Batch size: 8
- Gradient accumulation: 8 (effective batch size: 64)
- Epochs: 6 (increased from 4)
- LR scheduler: Cosine with warmup_ratio=0.1


In [11]:
# Pads inputs and labels per batch (labels are padded with -100 by default so
# the loss ignores them). Passing `model` lets the collator build
# decoder_input_ids from the labels.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True
)

# Training arguments
training_args = TrainingArguments(
    output_dir="/kaggle/working/afribyt5-legal-simplification",
    overwrite_output_dir=True,
    num_train_epochs=6,  # Increased from 4 for better learning
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=8,  # Effective batch size: 8 x 8 = 64 per device
    learning_rate=2e-4,  # Keep current (don't lower unless instability appears)
    lr_scheduler_type="cosine",  # Cosine decay after warmup
    warmup_ratio=0.1,  # Warm up over the first 10% of steps
    logging_steps=50,
    eval_steps=40,  # Evaluate every 40 optimizer steps
    save_steps=40,  # Must line up with eval_steps for load_best_model_at_end
    eval_strategy="steps",
    save_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",  # SARI is not computed during training
    # Lower validation loss is better. With True, the "best" checkpoint would be
    # the highest-loss one, and early stopping would fire while loss was still
    # improving.
    greater_is_better=False,
    label_smoothing_factor=0.1,
    bf16=True,
    gradient_checkpointing=True,
    report_to="none",
)

print("Training arguments configured")


Training arguments configured


## 8. Trainer with Loss-based Early Stopping

The standard `Trainer` is used with `EarlyStoppingCallback` (patience: 3 evaluations, i.e. 120 training steps at `eval_steps=40`), monitoring validation loss.


## 9. Train Model


In [None]:
import inspect

# Transformers >= 4.46 renamed Trainer's `tokenizer` argument to
# `processing_class` (the old name is deprecated); use whichever name the
# installed version accepts.
_trainer_params = inspect.signature(Trainer.__init__).parameters
_tokenizer_kwarg = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

# Early stopping watches training_args.metric_for_best_model (eval_loss) and
# stops after 3 evaluations without improvement.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    # compute_metrics currently returns {}. Passing None instead lets
    # Trainer.evaluate run loss-only and skip gathering full logits. Pass
    # compute_metrics here again once SARI/BERTScore are re-enabled.
    compute_metrics=None,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    **{_tokenizer_kwarg: tokenizer},
)

print(f"Trainer initialized with {len(train_dataset)} training samples and {len(val_dataset)} validation samples")


In [13]:
import torch

# Release cached GPU memory and reset the peak-memory counters before training.
# These calls are CUDA-specific (reset_peak_memory_stats raises without a GPU),
# so guard them: the device cell explicitly falls back to CPU.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

# Gradient checkpointing (also requested via TrainingArguments) does not work
# with the decoder's KV cache, so caching is turned off for training.
model.gradient_checkpointing_enable()
model.config.use_cache = False

# Train
print("Starting training...")
train_result = trainer.train()
print("Training completed!")

# Save the final model and tokenizer side by side. With a PEFT-wrapped model,
# trainer.save_model writes the adapter weights rather than the full base model.
final_model_dir = "/kaggle/working/afribyt5-legal-simplification-final3"
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)


Starting training...


Step,Training Loss,Validation Loss


Training completed!


('./afribyt5-legal-simplification-final3/tokenizer_config.json',
 './afribyt5-legal-simplification-final3/special_tokens_map.json',
 './afribyt5-legal-simplification-final3/added_tokens.json')

## 10. Training Visualizations

Visualize training progress, validation metrics, and model performance.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Pull the logged history out of the trainer. Per-step training logs carry
# 'loss'; evaluation logs carry 'eval_loss' plus any metrics returned by
# compute_metrics (prefixed with 'eval_').
train_history = trainer.state.log_history
train_logs = [log for log in train_history if 'loss' in log and 'eval_loss' not in log]
eval_logs = [log for log in train_history if 'eval_loss' in log]

train_steps = [log['step'] for log in train_logs]
train_losses = [log['loss'] for log in train_logs]

eval_steps = [log['step'] for log in eval_logs]
eval_losses = [log.get('eval_loss', 0) for log in eval_logs]
eval_sari = [log.get('eval_sari', 0) for log in eval_logs]
# NOTE(review): the BERTScore log key depends on what compute_metrics returns
# (it currently returns nothing); 'eval_bertscore_f1' / 'eval_bertscore' are
# assumed names - confirm once metrics are re-enabled.
eval_bertscore = [log.get('eval_bertscore_f1', log.get('eval_bertscore', 0)) for log in eval_logs]


def _has_data(values):
    """Return True when the series is non-empty and has a positive value.

    Metrics missing from a log are recorded as 0 above, so an all-zero series
    means "not logged".
    """
    return bool(values) and any(v > 0 for v in values)


def _style_axis(ax, title, ylabel, label_size=12, title_size=14, legend_size=11, title_pad=None):
    """Apply the shared labels, title, grid, legend and background to one axis."""
    ax.set_xlabel('Training Steps', fontsize=label_size, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=label_size, fontweight='bold')
    ax.set_title(title, fontsize=title_size, fontweight='bold', pad=title_pad)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.legend(fontsize=legend_size)
    ax.set_facecolor('#f8f9fa')


def _print_summary(title, values, steps, higher_is_better=False):
    """Print initial/final/best value and relative improvement for one series.

    ``values`` and ``steps`` are parallel, non-empty lists. "Best" is the max
    when ``higher_is_better`` and the min otherwise; its step is the first
    occurrence.
    """
    best = max(values) if higher_is_better else min(values)
    first = values[0]
    print(f"\n{title}:")
    print(f"  Initial: {first:.4f}")
    print(f"  Final: {values[-1]:.4f}")
    print(f"  Best: {best:.4f} (at step {steps[values.index(best)]})")
    if first:  # avoid ZeroDivisionError on a zero starting value
        gain = (best - first) if higher_is_better else (first - best)
        print(f"  Improvement: {gain / first * 100:.2f}%")


# Layout: training loss across the top row; validation loss | SARI in the
# middle row; BERTScore | training-vs-validation loss in the bottom row.
fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3)

# Plot 1: Training Loss (top row, spans both columns)
ax1 = fig.add_subplot(gs[0, :])
ax1.plot(train_steps, train_losses, 'b-', linewidth=2.5, label='Training Loss', alpha=0.8)
_style_axis(ax1, 'Training Loss Over Time', 'Loss', label_size=13, title_size=15, legend_size=12, title_pad=15)

# Plot 2: Validation Loss
ax2 = fig.add_subplot(gs[1, 0])
if _has_data(eval_losses):
    ax2.plot(eval_steps, eval_losses, 'r-', linewidth=2.5, marker='o', markersize=5, label='Validation Loss', alpha=0.8)
    _style_axis(ax2, 'Validation Loss', 'Loss')
else:
    ax2.text(0.5, 0.5, 'No validation data', ha='center', va='center', fontsize=12)
    ax2.set_title('Validation Loss', fontsize=14, fontweight='bold')

# Plot 3: SARI Score
ax3 = fig.add_subplot(gs[1, 1])
if _has_data(eval_sari):
    ax3.plot(eval_steps, eval_sari, 'g-', linewidth=2.5, marker='s', markersize=5, label='SARI', alpha=0.8)
    _style_axis(ax3, 'SARI Score (Simplification Quality)', 'SARI Score')
    ax3.set_ylim(bottom=0)
else:
    ax3.text(0.5, 0.5, 'No SARI data', ha='center', va='center', fontsize=12)
    ax3.set_title('SARI Score', fontsize=14, fontweight='bold')

# Plot 4: BERTScore (bottom left)
ax4 = fig.add_subplot(gs[2, 0])
if _has_data(eval_bertscore):
    ax4.plot(eval_steps, eval_bertscore, 'm-', linewidth=2.5, marker='^', markersize=5, label='BERTScore F1', alpha=0.8)
    _style_axis(ax4, 'BERTScore F1', 'BERTScore F1')
else:
    ax4.text(0.5, 0.5, 'No BERTScore data', ha='center', va='center', fontsize=12)
    ax4.set_title('BERTScore', fontsize=14, fontweight='bold')

# Plot 5: Combined Loss Plot (Training vs Validation), bottom right
ax5 = fig.add_subplot(gs[2, 1])
ax5.plot(train_steps, train_losses, 'b-', linewidth=2, label='Training Loss', alpha=0.7)
if _has_data(eval_losses):
    ax5.plot(eval_steps, eval_losses, 'r-', linewidth=2, marker='o', markersize=4, label='Validation Loss', alpha=0.7)
    _style_axis(ax5, 'Training vs Validation Loss', 'Loss')
else:
    _style_axis(ax5, 'Training Loss', 'Loss')

# Overall title
fig.suptitle('Training Progress: AfriByT5 Legal Simplification', fontsize=18, fontweight='bold', y=0.995)
plt.tight_layout(rect=[0, 0, 1, 0.99])

# Save the high-resolution plot before plt.show(), so the file is written from
# the fully laid-out figure whatever the backend does to figures on show.
plot_path = "/kaggle/working/afribyt5-legal-simplification_training_plots.png"
fig.savefig(plot_path, dpi=300, bbox_inches='tight', facecolor='white')
plt.show()
print(f"\nTraining plots saved to: {plot_path}")

# Detailed summary statistics (zero entries mean "metric not logged" and are skipped)
print("\n" + "="*70)
print("TRAINING SUMMARY: AfriByT5 Legal Simplification")
print("="*70)
if train_losses:
    _print_summary("Training Loss", train_losses, train_steps)
if _has_data(eval_losses):
    valid_steps = [s for s, l in zip(eval_steps, eval_losses) if l > 0]
    valid_losses = [l for l in eval_losses if l > 0]
    _print_summary("Validation Loss", valid_losses, valid_steps)
if _has_data(eval_sari):
    valid_steps = [s for s, v in zip(eval_steps, eval_sari) if v > 0]
    valid_sari = [v for v in eval_sari if v > 0]
    _print_summary("SARI Score", valid_sari, valid_steps, higher_is_better=True)
print("="*70)