In [1]:
# Install required dependencies
%pip install -q sentence-transformers torch scikit-learn numpy



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import json
import random
import pickle
import os
from collections import defaultdict
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

import warnings
warnings.filterwarnings('ignore')

# Set device
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

# Fixed simplification types
LABELS = [
    "sentence_splitting",
    "vocabulary_simplification",
    "structure_reordering",
    "deletion",
    "definition_or_expansion"
]

print(f"Simplification types: {LABELS}")


  from .autonotebook import tqdm as notebook_tqdm


Using device: mps
Simplification types: ['sentence_splitting', 'vocabulary_simplification', 'structure_reordering', 'deletion', 'definition_or_expansion']


## 1. Load and Prepare Dataset


In [3]:
# Load dataset from regen.json
DATA_PATH = "Dataset/final_dataset/regen.json"

with open(DATA_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

# Extract (sentence, simplification_type) pairs
examples = []
for item in data:
    if "legal_sentence" in item and "simplification_type" in item:
        examples.append((item["legal_sentence"], item["simplification_type"]))

print(f"Total examples: {len(examples)}")

# Check distribution of simplification types
type_counts = defaultdict(int)
for _, label in examples:
    type_counts[label] += 1

print("\nSimplification type distribution:")
for label in LABELS:
    count = type_counts.get(label, 0)
    print(f"  {label}: {count} ({count/len(examples)*100:.1f}%)")

# Show sample
print(f"\nSample example:")
print(f"  Sentence: {examples[0][0][:100]}...")
print(f"  Type: {examples[0][1]}")


Total examples: 2005

Simplification type distribution:
  sentence_splitting: 123 (6.1%)
  vocabulary_simplification: 1453 (72.5%)
  structure_reordering: 257 (12.8%)
  deletion: 131 (6.5%)
  definition_or_expansion: 41 (2.0%)

Sample example:
  Sentence: 1. በዚህ አንቀጽ ንዑስ አንቀጽ መሰረት የሚቀርበው ክስ ገንዘብ ጠያቂው የማህበሩን መፍረስ ካወቀበት ጊዜ ጀምሮ በአምስት ዓመት ውስጥ ካልቀረበ በይርጋ ይታገዳ...
  Type: vocabulary_simplification


## 2. Initialize Sentence Encoder


In [4]:
# Load multilingual sentence encoder
# This model is for strategy selection, not generation
encoder = SentenceTransformer(
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
)

encoder = encoder.to(device)
print(f"Encoder loaded: paraphrase-multilingual-MiniLM-L12-v2")
print(f"Encoder device: {device}")

# Test encoding
test_emb = encoder.encode(["Test sentence"], convert_to_numpy=True)
print(f"Embedding dimension: {test_emb.shape[1]}")


Encoder loaded: paraphrase-multilingual-MiniLM-L12-v2
Encoder device: mps
Embedding dimension: 384


## 3. Build Contrastive Training Pairs


In [5]:
def build_pairs(examples, max_pairs=8000):
    """
    Build positive and negative pairs for contrastive learning.
    - Positive pair: same simplification_type
    - Negative pair: different simplification_type
    """
    by_label = defaultdict(list)
    for sent, label in examples:
        by_label[label].append(sent)
    
    pairs = []
    labels = list(by_label.keys())
    
    # Ensure we have at least 2 examples per label for positive pairs
    valid_labels = [l for l in labels if len(by_label[l]) >= 2]
    
    if not valid_labels:
        raise ValueError("Need at least 2 examples per label for contrastive learning")
    
    print(f"Building pairs from {len(valid_labels)} labels...")
    
    while len(pairs) < max_pairs:
        # Positive pair
        label = random.choice(valid_labels)
        if len(by_label[label]) >= 2:
            s1, s2 = random.sample(by_label[label], 2)
            pairs.append((s1, s2, 1))
        
        # Negative pair
        if len(valid_labels) > 1:
            neg_label = random.choice([l for l in valid_labels if l != label])
            if len(by_label[label]) > 0 and len(by_label[neg_label]) > 0:
                s1 = random.choice(by_label[label])
                s3 = random.choice(by_label[neg_label])
                pairs.append((s1, s3, 0))
    
    return pairs[:max_pairs]

# Build training pairs
pairs = build_pairs(examples, max_pairs=8000)
print(f"\nTotal contrastive pairs: {len(pairs)}")
print(f"Positive pairs: {sum(1 for _, _, label in pairs if label == 1)}")
print(f"Negative pairs: {sum(1 for _, _, label in pairs if label == 0)}")


Building pairs from 5 labels...

Total contrastive pairs: 8000
Positive pairs: 4000
Negative pairs: 4000


## 4. Contrastive Dataset and Loss


In [15]:
class ContrastiveDataset(Dataset):
    """Dataset for contrastive learning pairs"""
    def __init__(self, pairs):
        self.pairs = pairs
    
    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        return self.pairs[idx]

def contrastive_collate_fn(batch):
    """
    Custom collate function to properly handle tuple batches.
    Batch is a list of tuples: [(s1, s2, label), ...]
    Returns: (s1_list, s2_list, labels_list)
    """
    s1_list = [item[0] for item in batch]
    s2_list = [item[1] for item in batch]
    labels_list = [item[2] for item in batch]
    return (s1_list, s2_list, labels_list)

def contrastive_loss(e1, e2, label, margin=0.5):
    """
    Contrastive loss with cosine similarity and margin.
    - Positive pairs (label=1): minimize distance
    - Negative pairs (label=0): maximize distance beyond margin
    """
    cosine = F.cosine_similarity(e1, e2)
    pos_loss = (1 - cosine) * label
    neg_loss = torch.clamp(cosine - margin, min=0.0) * (1 - label)
    return (pos_loss + neg_loss).mean()

# Create dataset and dataloader with custom collate function
dataset = ContrastiveDataset(pairs)
loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=contrastive_collate_fn)
print(f"Dataset size: {len(dataset)}")
print(f"Batches per epoch: {len(loader)}")


Dataset size: 8000
Batches per epoch: 500


In [None]:
# Debug: Check batch structure (with custom collate function)
print("Checking batch structure...")
test_loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=contrastive_collate_fn)
test_batch = next(iter(test_loader))
print(f"Batch type: {type(test_batch)}")
print(f"Batch length: {len(test_batch)}")
print(f"Batch structure: (s1_list, s2_list, labels_list)")
print(f"s1_list length: {len(test_batch[0])}")
print(f"s2_list length: {len(test_batch[1])}")
print(f"labels_list length: {len(test_batch[2])}")
print(f"Labels: {test_batch[2]}")
print(f"Label types: {[type(l) for l in test_batch[2]]}")


Checking batch structure...
Batch type: <class 'tuple'>
Batch length: 3
Batch structure: (s1_list, s2_list, labels_list)
s1_list length: 2
s2_list length: 2
labels_list length: 2
Labels: [1, 0]
Label types: [<class 'int'>, <class 'int'>]


## 5. Train Contrastive Model


In [17]:
# Set up optimizer
optimizer = AdamW(encoder.parameters(), lr=2e-5)

# Training loop
encoder.train()
EPOCHS = 3  # Light fine-tuning, enough for strategy separation

print("Starting contrastive training...")
print(f"Epochs: {EPOCHS}, Batch size: 16, Device: {device}\n")

for epoch in range(EPOCHS):
    total_loss = 0.0
    num_batches = 0
    
    for batch_idx, batch in enumerate(loader):
        # Batch is now (s1_list, s2_list, labels_list) thanks to custom collate_fn
        s1_list, s2_list, labels_list = batch
        
        # Convert labels to tensor (should already be numeric: 1 or 0)
        labels = torch.tensor([float(label) for label in labels_list], dtype=torch.float32).to(device)
        
        # Encode sentences using forward pass (for training with gradients)
        # sentence-transformers' encode() uses no_grad, so we use the model's forward directly
        # Get embeddings with gradients enabled
        features1 = encoder.tokenize(s1_list)
        features1 = {k: v.to(device) for k, v in features1.items()}
        e1 = encoder(features1)['sentence_embedding']
        
        features2 = encoder.tokenize(s2_list)
        features2 = {k: v.to(device) for k, v in features2.items()}
        e2 = encoder(features2)['sentence_embedding']
        
        # Compute loss
        loss = contrastive_loss(e1, e2, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
        
        # Progress update every 50 batches
        if (batch_idx + 1) % 50 == 0:
            print(f"  Batch {batch_idx + 1}/{len(loader)}, Loss: {loss.item():.4f}")
    
    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1}/{EPOCHS} | Average Loss: {avg_loss:.4f}\n")

print("Training complete!")
encoder.eval()


Starting contrastive training...
Epochs: 3, Batch size: 16, Device: mps

  Batch 50/500, Loss: 0.2734
  Batch 100/500, Loss: 0.2473
  Batch 150/500, Loss: 0.1986
  Batch 200/500, Loss: 0.2624
  Batch 250/500, Loss: 0.1853
  Batch 300/500, Loss: 0.1816
  Batch 350/500, Loss: 0.1489
  Batch 400/500, Loss: 0.1053
  Batch 450/500, Loss: 0.1186
  Batch 500/500, Loss: 0.0984
Epoch 1/3 | Average Loss: 0.2082

  Batch 50/500, Loss: 0.1257
  Batch 100/500, Loss: 0.1357
  Batch 150/500, Loss: 0.0933
  Batch 200/500, Loss: 0.1250
  Batch 250/500, Loss: 0.0670
  Batch 300/500, Loss: 0.0672
  Batch 350/500, Loss: 0.1396
  Batch 400/500, Loss: 0.1311
  Batch 450/500, Loss: 0.0874
  Batch 500/500, Loss: 0.1621
Epoch 2/3 | Average Loss: 0.1276

  Batch 50/500, Loss: 0.1121
  Batch 100/500, Loss: 0.0916
  Batch 150/500, Loss: 0.1848
  Batch 200/500, Loss: 0.0625
  Batch 250/500, Loss: 0.0833
  Batch 300/500, Loss: 0.0717
  Batch 350/500, Loss: 0.0955
  Batch 400/500, Loss: 0.0948
  Batch 450/500, Loss:

SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False, 'architecture': 'BertModel'})
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)

## 6. Build Class Centroids


In [None]:
# Build centroids for each simplification type
# Centroids represent the "average" embedding for each strategy type
encoder.eval()

centroids = {}
print("Building centroids for each simplification type...")

with torch.no_grad():
    for label in LABELS:
        # Get all sentences for this label
        sents = [s for s, l in examples if l == label]
        
        if len(sents) == 0:
            print(f"  Warning: No examples found for {label}")
            continue
        
        # Encode all sentences for this label
        emb = encoder.encode(sents, convert_to_numpy=True, show_progress_bar=False)
        
        # Compute centroid (mean embedding)
        centroids[label] = emb.mean(axis=0)
        print(f"  {label}: {len(sents)} examples, centroid shape: {centroids[label].shape}")

print(f"\nCentroids built for {len(centroids)} types")


Building centroids for each simplification type...
  sentence_splitting: 123 examples, centroid shape: (384,)
  vocabulary_simplification: 1453 examples, centroid shape: (384,)
  structure_reordering: 257 examples, centroid shape: (384,)
  deletion: 131 examples, centroid shape: (384,)
  definition_or_expansion: 41 examples, centroid shape: (384,)

Centroids built for 5 types


## 7. Save Model and Centroids


In [None]:
# Create save directory
SAVE_DIR = "models/contrastive_strategy_selector"
os.makedirs(SAVE_DIR, exist_ok=True)

# Save the fine-tuned encoder
encoder.save(SAVE_DIR)
print(f"Encoder saved to: {SAVE_DIR}")

# Save centroids
centroids_path = os.path.join(SAVE_DIR, "centroids.pkl")
with open(centroids_path, "wb") as f:
    pickle.dump(centroids, f)
print(f"Centroids saved to: {centroids_path}")

# Save labels for reference
labels_path = os.path.join(SAVE_DIR, "labels.txt")
with open(labels_path, "w", encoding="utf-8") as f:
    for label in LABELS:
        f.write(f"{label}\n")
print(f"Labels saved to: {labels_path}")

print("\n✅ All files saved successfully!")


Encoder saved to: models/contrastive_strategy_selector
Centroids saved to: models/contrastive_strategy_selector/centroids.pkl
Labels saved to: models/contrastive_strategy_selector/labels.txt

✅ All files saved successfully!


## 8. Load Saved Model (for future use)


In [None]:
# Example: How to load the saved model and centroids
# Uncomment and run this cell when you want to load a previously trained model

"""
LOAD_DIR = "models/contrastive_strategy_selector"

# Load encoder
loaded_encoder = SentenceTransformer(LOAD_DIR)
loaded_encoder = loaded_encoder.to(device)

# Load centroids
with open(os.path.join(LOAD_DIR, "centroids.pkl"), "rb") as f:
    loaded_centroids = pickle.load(f)

print("Model and centroids loaded successfully!")
"""


'\nLOAD_DIR = "models/contrastive_strategy_selector"\n\n# Load encoder\nloaded_encoder = SentenceTransformer(LOAD_DIR)\nloaded_encoder = loaded_encoder.to(device)\n\n# Load centroids\nwith open(os.path.join(LOAD_DIR, "centroids.pkl"), "rb") as f:\n    loaded_centroids = pickle.load(f)\n\nprint("Model and centroids loaded successfully!")\n'

## 9. Hybrid Inference System (Heuristics + Contrastive)


In [22]:
def heuristic_predict(sentence):
    """
    Heuristic rules for high-confidence strategy selection.
    Returns simplification_type if confident, None otherwise.
    """
    words = sentence.split()
    word_count = len(words)
    
    # Sentence splitting: very long sentences
    if word_count > 40:
        return "sentence_splitting"
    
    # Deletion: presence of known boilerplate phrases
    boilerplate_phrases = [
        "እንደተጠበቀ ሆኖ",
        "በማንኛውም ሁኔታ",
        "ያለ አግባብ",
        "እንደተጠበቀ",
    ]
    if any(phrase in sentence for phrase in boilerplate_phrases):
        return "deletion"
    
    # Structure reordering: multiple conjunctions/clauses
    conjunctions = ["እና", "ወይም", "ቢሆንም", "ነገር ግን"]
    conjunction_count = sum(1 for conj in conjunctions if conj in sentence)
    if conjunction_count >= 3:
        return "structure_reordering"
    
    # No high-confidence heuristic match
    return None

def contrastive_predict(sentence, encoder, centroids):
    """
    Use contrastive embeddings to predict simplification type.
    Returns the type with highest cosine similarity to its centroid.
    """
    # Encode the sentence
    emb = encoder.encode([sentence], convert_to_numpy=True)[0]
    
    # Compute similarity to each centroid
    sims = {}
    for label, centroid in centroids.items():
        similarity = cosine_similarity([emb], [centroid])[0][0]
        sims[label] = similarity
    
    # Return type with highest similarity
    return max(sims, key=sims.get), sims

def select_simplification_type(sentence, encoder, centroids):
    """
    Hybrid selector: heuristics first, contrastive fallback.
    Returns (simplification_type, method_used, confidence_scores)
    """
    # Try heuristics first
    heuristic_result = heuristic_predict(sentence)
    if heuristic_result:
        return heuristic_result, "heuristic", None
    
    # Fallback to contrastive
    predicted_type, sims = contrastive_predict(sentence, encoder, centroids)
    return predicted_type, "contrastive", sims

print("Hybrid inference system ready!")


Hybrid inference system ready!


In [None]:
# Test on a few examples
print("Testing hybrid strategy selector:\n")

test_indices = [0, 10, 50, 100, 200]  # Sample different examples

for idx in test_indices:
    if idx < len(examples):
        sentence, true_label = examples[idx]
        
        # Get prediction
        pred_type, method, sims = select_simplification_type(sentence, encoder, centroids)
        
        # Display results
        print(f"Example {idx + 1}:")
        print(f"  Sentence: {sentence[:80]}...")
        print(f"  True label: {true_label}")
        print(f"  Predicted: {pred_type} (via {method})")
        if sims:
            print(f"  Similarities: {', '.join([f'{k}: {v:.3f}' for k, v in sorted(sims.items(), key=lambda x: x[1], reverse=True)])}")
        print(f"  Match: {'✅' if pred_type == true_label else '❌'}")
        print()


Testing hybrid strategy selector:

Example 1:
  Sentence: 1. በዚህ አንቀጽ ንዑስ አንቀጽ መሰረት የሚቀርበው ክስ ገንዘብ ጠያቂው የማህበሩን መፍረስ ካወቀበት ጊዜ ጀምሮ በአምስት ዓመት...
  True label: vocabulary_simplification
  Predicted: structure_reordering (via contrastive)
  Similarities: structure_reordering: 0.996, vocabulary_simplification: 0.946, deletion: 0.798, sentence_splitting: 0.407, definition_or_expansion: -0.189
  Match: ❌

Example 11:
  Sentence: 11. ቦርድ አስፈላጊ መስል በታየው ጊዜ ሁሉ ከአባላቱ መካከል ጉዳዮችን መርምረው የመፍትሔ ሏሳብ የሚያቀርቡ ኮሚቴዎችን ማቋቋም...
  True label: vocabulary_simplification
  Predicted: vocabulary_simplification (via contrastive)
  Similarities: vocabulary_simplification: 0.996, structure_reordering: 0.979, deletion: 0.926, sentence_splitting: 0.147, definition_or_expansion: -0.115
  Match: ✅

Example 51:
  Sentence: 51. ጉባኤው ሥራ ከመጀመሩ በፊት ስብሰባው ላይ ከተገኙት ባለአክሲዮኖች መካከል በጉባኤው ላይ የተደረገው ውይይት እና የተላለፈ...
  True label: sentence_splitting
  Predicted: sentence_splitting (via contrastive)
  Similarities: sentence_splitting: 

## 11. Evaluation: Accuracy by Method


In [None]:
# Evaluate on all examples
print("Evaluating hybrid system on full dataset...\n")

correct = 0
heuristic_count = 0
contrastive_count = 0
heuristic_correct = 0
contrastive_correct = 0

for sentence, true_label in examples:
    pred_type, method, _ = select_simplification_type(sentence, encoder, centroids)
    
    if pred_type == true_label:
        correct += 1
    
    if method == "heuristic":
        heuristic_count += 1
        if pred_type == true_label:
            heuristic_correct += 1
    else:
        contrastive_count += 1
        if pred_type == true_label:
            contrastive_correct += 1

total = len(examples)
accuracy = correct / total * 100
heuristic_accuracy = (heuristic_correct / heuristic_count * 100) if heuristic_count > 0 else 0
contrastive_accuracy = (contrastive_correct / contrastive_count * 100) if contrastive_count > 0 else 0

print(f"Overall Results:")
print(f"  Total examples: {total}")
print(f"  Correct predictions: {correct} ({accuracy:.2f}%)")
print(f"\nBy Method:")
print(f"  Heuristic:")
print(f"    Used: {heuristic_count} ({heuristic_count/total*100:.1f}%)")
print(f"    Accuracy: {heuristic_accuracy:.2f}%")
print(f"  Contrastive:")
print(f"    Used: {contrastive_count} ({contrastive_count/total*100:.1f}%)")
print(f"    Accuracy: {contrastive_accuracy:.2f}%")


Evaluating hybrid system on full dataset...

Overall Results:
  Total examples: 2005
  Correct predictions: 1039 (51.82%)

By Method:
  Heuristic:
    Used: 158 (7.9%)
    Accuracy: 36.08%
  Contrastive:
    Used: 1847 (92.1%)
    Accuracy: 53.17%


## 12. Integration Function for AfriByT5

This function can be imported and used in your AfriByT5 inference pipeline.


In [25]:
def get_strategy_selector(model_dir="models/contrastive_strategy_selector", device=None):
    """
    Load the trained strategy selector (encoder + centroids).
    Returns a function that can be used to select simplification_type.
    
    Usage:
        selector = get_strategy_selector()
        simplification_type = selector("Your legal sentence here")
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    
    # Load encoder
    encoder = SentenceTransformer(model_dir)
    encoder = encoder.to(device)
    
    # Load centroids
    centroids_path = os.path.join(model_dir, "centroids.pkl")
    with open(centroids_path, "rb") as f:
        centroids = pickle.load(f)
    
    # Return a callable function
    def select_type(sentence):
        pred_type, method, _ = select_simplification_type(sentence, encoder, centroids)
        return pred_type
    
    return select_type

# Example usage:
# selector = get_strategy_selector()
# simplification_type = selector("Your legal sentence")
# Then use this type to condition AfriByT5: f"simplify | {simplification_type}: {sentence}"

print("Integration function ready!")
print("\nTo use in your AfriByT5 pipeline:")
print("  from contrastive_strategy_selector import get_strategy_selector")
print("  selector = get_strategy_selector()")
print("  simplification_type = selector(legal_sentence)")
print("  prompt = f'simplify | {simplification_type}: {legal_sentence}'")


Integration function ready!

To use in your AfriByT5 pipeline:
  from contrastive_strategy_selector import get_strategy_selector
  selector = get_strategy_selector()
  simplification_type = selector(legal_sentence)
  prompt = f'simplify | {simplification_type}: {legal_sentence}'
