In [116]:
from datasets import load_dataset, Dataset, DatasetDict
import sentencepiece as spm
from pathlib import Path
from typing import Dict

def extract_long_translations(
        min_length: int = 200,
        max_length: int = 2000,
        sp_model_path: str = "../attention_is_all_you_need/BPE/en-hi.model",
        max_samples: int | None = None
    ) -> Dict[str, str]:
    # ── 0.  SentencePiece initialisation ───────────────────────────────
    sp = spm.SentencePieceProcessor()
    if not Path(sp_model_path).exists():
        raise FileNotFoundError(
            f"SentencePiece model not found at {sp_model_path}"
        )
    sp.load(sp_model_path)                              # turn0search6

    def sp_len(txt: str) -> int:
        return len(sp.encode_as_ids(txt))

    # Helper to add (en,hi) pairs if they meet length and quota
    results: dict[str, str] = {}

    def maybe_add(en: str, hi: str) -> None:
        if  min_length <= sp_len(en) <= max_length and min_length <= sp_len(hi) <= max_length:
            if max_samples is None or len(results) < max_samples:
                results[en] = hi

    # Helper that walks through *any* HF dataset object
    def walk_dataset(ds, nested_translation: bool) -> None:
        if isinstance(ds, DatasetDict):
            splits = ds.values()
        else:                           # bare Dataset
            splits = [ds]
        for split in splits:
            for ex in split:
                if nested_translation:                  # translation column
                    maybe_add(ex["translation"]["en"],
                              ex["translation"]["hi"])
                else:                                   # src / tgt style
                    maybe_add(ex["src"], ex["tgt"])

    # ── 1.  OPUS-100 (≈ 55 M pairs, en-hi subset) ──────────────────────
    walk_dataset(
        load_dataset("opus100", "en-hi"),                # turn0search0
        nested_translation=True,
    )

    # ── 2.  IITB EN-HI corpus ──────────────────────────────────────────
    walk_dataset(
        load_dataset("cfilt/iitb-english-hindi"),        # turn0search1
        nested_translation=True,
    )

    # ── 3.  Samanantar (AI4Bharat) ─────────────────────────────────────
    walk_dataset(
        load_dataset("ai4bharat/samanantar", "hi"),      # turn0search2
        nested_translation=False,                       # uses src / tgt
    )

    # ── 4.  PMIndiaSum: align the *two* directions ─────────────────────
    #        (Hindi article + EN summary)  ↔  (English article + HI summary)
    ds_hi = load_dataset("PMIndiaData/PMIndiaSum", data_dir="hindi-english")
    ds_en = load_dataset("PMIndiaData/PMIndiaSum", data_dir="english-hindi")

    # 4-A.  Build a single lookup table from *every* English-side split
    en_by_url = {}
    for split_name, split in ds_en.items():             # train / validation / test
        for ex in split:
            en_by_url[ex["source_url"]] = ex            # keep last seen, fine for 1-1 mapping

    # 4-B.  Iterate over every Hindi-side split and align
    for split_name, split in ds_hi.items():
        for ex_hi in split:
            eng_rec = en_by_url.get(ex_hi["target_url"])
            if eng_rec:
                maybe_add(eng_rec["text"], ex_hi["text"])
                
    return results


In [117]:
sentences = extract_long_translations(50, 2000)

In [118]:
english_encoded = sp.Encode(list(sentences.keys()))
hindi_encoded = sp.Encode(list(sentences.values()))

In [119]:
def analyze_encoded_lengths(english_encoded, hindi_encoded):
    
    # Calculate lengths
    en_lengths = [len(seq) for seq in english_encoded]
    hi_lengths = [len(seq) for seq in hindi_encoded]
    combined_lengths = [en for en, hi in zip(en_lengths, hi_lengths)]
    
    # Print summary
    print(f"Total pairs: {len(combined_lengths):,}")
    print(f"Combined length - Min: {min(combined_lengths)}, Max: {max(combined_lengths)}")
    print(f"Average combined length: {sum(combined_lengths)/len(combined_lengths):.1f} tokens")
    
    # Count by length buckets
    buckets = [(0, 100), (100, 200), (200, 300), (300, 500), (500, 750), (750, 1000), (1000, 1500),  (2000, float('inf'))]
    for min_len, max_len in buckets:
        count = sum(1 for l in combined_lengths if min_len <= l < max_len)
        label = f"{min_len}-{max_len}" if max_len != float('inf') else f"{min_len}+"
        print(f"{label} tokens: {count:,} pairs ({count/len(combined_lengths)*100:.1f}%)")
    
lengths = analyze_encoded_lengths(english_encoded, hindi_encoded)
lengths

Total pairs: 816,270
Combined length - Min: 50, Max: 1997
Average combined length: 76.6 tokens
0-100 tokens: 715,935 pairs (87.7%)
100-200 tokens: 89,813 pairs (11.0%)
200-300 tokens: 6,615 pairs (0.8%)
300-500 tokens: 2,126 pairs (0.3%)
500-750 tokens: 760 pairs (0.1%)
750-1000 tokens: 412 pairs (0.1%)
1000-1500 tokens: 435 pairs (0.1%)
2000+ tokens: 0 pairs (0.0%)


In [123]:
import random
import numpy as np

def sample_by_length_buckets(english_encoded, hindi_encoded, target_samples=6000, random_seed=42):
    """
    Sample data to limit the number of pairs in specific length buckets
    
    Args:
        english_encoded: List of encoded English sequences
        hindi_encoded: List of encoded Hindi sequences  
        target_samples: Target number of samples for 0-100 and 100-200 buckets
        random_seed: Random seed for reproducibility
    
    Returns:
        Tuple of (sampled_english, sampled_hindi, sampling_info)
    """
    random.seed(random_seed)
    np.random.seed(random_seed)
    
    # Calculate lengths for each pair
    en_lengths = [len(seq) for seq in english_encoded]
    hi_lengths = [len(seq) for seq in hindi_encoded]
    combined_lengths = [en for en, hi in zip(en_lengths, hi_lengths)]
    
    # Create buckets with indices
    buckets = {
        '0-100': [],
        '100-200': [], 
        '200-300': [],
        '300-500': [],
        '500-750': [],
        '750-1000': [],
        '1000-1500': [],
        '1500-2000': [],
        '2000+': []
    }
    
    # Assign each pair to appropriate bucket
    for i, length in enumerate(combined_lengths):
        if 0 <= length < 100:
            buckets['0-100'].append(i)
        elif 100 <= length < 200:
            buckets['100-200'].append(i)
        elif 200 <= length < 300:
            buckets['200-300'].append(i)
        elif 300 <= length < 500:
            buckets['300-500'].append(i)
        elif 500 <= length < 750:
            buckets['500-750'].append(i)
        elif 750 <= length < 1000:
            buckets['750-1000'].append(i)
        elif 1000 <= length < 1500:
            buckets['1000-1500'].append(i)
        elif 1500 <= length < 2000:
            buckets['1500-2000'].append(i)
        else:
            buckets['2000+'].append(i)
    
    # Sample indices to keep
    selected_indices = []
    sampling_info = {}
    
    for bucket_name, indices in buckets.items():
        original_count = len(indices)
        
        if bucket_name in ['0-100', '100-200']:
            # Sample target_samples from these buckets
            if original_count <= target_samples:
                sampled_indices = indices
                sampled_count = original_count
            else:
                sampled_indices = random.sample(indices, target_samples)
                sampled_count = target_samples
        else:
            # Keep all indices from other buckets
            sampled_indices = indices
            sampled_count = original_count
        
        selected_indices.extend(sampled_indices)
        sampling_info[bucket_name] = {
            'original': original_count,
            'sampled': sampled_count,
            'percentage_kept': (sampled_count / original_count * 100) if original_count > 0 else 0
        }
    
    # Sort indices to maintain some order
    selected_indices.sort()
    
    # Create sampled datasets
    sampled_english = [english_encoded[i] for i in selected_indices]
    sampled_hindi = [hindi_encoded[i] for i in selected_indices]
    
    return sampled_english, sampled_hindi, sampling_info

def print_sampling_summary(sampling_info):
    """Print a summary of the sampling results"""
    print("Sampling Summary:")
    print("=" * 50)
    
    total_original = sum(info['original'] for info in sampling_info.values())
    total_sampled = sum(info['sampled'] for info in sampling_info.values())
    
    for bucket_name, info in sampling_info.items():
        original = info['original']
        sampled = info['sampled']
        percentage = info['percentage_kept']
        
        if original > 0:
            print(f"{bucket_name:12}: {original:7,} -> {sampled:6,} ({percentage:5.1f}% kept)")
        else:
            print(f"{bucket_name:12}: {original:7,} -> {sampled:6,} (N/A)")
    
    print("-" * 50)
    print(f"{'Total':12}: {total_original:7,} -> {total_sampled:6,} ({total_sampled/total_original*100:.1f}% kept)")

def analyze_sampled_lengths(english_encoded, hindi_encoded):
    """Analyze the length distribution of the sampled data"""
    # Calculate lengths
    en_lengths = [len(seq) for seq in english_encoded]
    hi_lengths = [len(seq) for seq in hindi_encoded]
    combined_lengths = [en + hi for en, hi in zip(en_lengths, hi_lengths)]
    
    # Print summary
    print(f"\nSampled Data Analysis:")
    print(f"Total pairs: {len(combined_lengths):,}")
    print(f"Combined length - Min: {min(combined_lengths)}, Max: {max(combined_lengths)}")
    print(f"Average combined length: {sum(combined_lengths)/len(combined_lengths):.1f} tokens")
    
    # Count by length buckets
    buckets = [(0, 100), (100, 200), (200, 300), (300, 500), (500, 750), (750, 1000), (1000, 1500), (1500, 2000), (2000, float('inf'))]
    for min_len, max_len in buckets:
        count = sum(1 for l in combined_lengths if min_len <= l < max_len)
        if max_len == float('inf'):
            label = f"{min_len}+"
        else:
            label = f"{min_len}-{max_len}"
        print(f"{label:12} tokens: {count:6,} pairs ({count/len(combined_lengths)*100:.1f}%)")

sampled_english, sampled_hindi, sampling_info = sample_by_length_buckets(
    english_encoded, 
    hindi_encoded, 
    target_samples=6000,
    random_seed=42
)

print_sampling_summary(sampling_info)
analyze_sampled_lengths(sampled_english, sampled_hindi)

# Replace original data with sampled data
english_encoded = sampled_english
hindi_encoded = sampled_hindi

print(f"\nData successfully sampled!")
print(f"Original dataset reduced from 816,270 to {len(english_encoded):,} pairs")

Sampling Summary:
0-100       :       0 ->      0 (N/A)
100-200     : 714,972 ->  6,000 (  0.8% kept)
200-300     :  76,054 -> 76,054 (100.0% kept)
300-500     :  19,929 -> 19,929 (100.0% kept)
500-750     :   2,822 ->  2,822 (100.0% kept)
750-1000    :     801 ->    801 (100.0% kept)
1000-1500   :     723 ->    723 (100.0% kept)
1500-2000   :     394 ->    394 (100.0% kept)
2000+       :     575 ->    575 (100.0% kept)
--------------------------------------------------
Total       : 816,270 -> 107,298 (13.1% kept)

Sampled Data Analysis:
Total pairs: 107,298
Combined length - Min: 100, Max: 3946
Average combined length: 291.7 tokens
0-100        tokens:      0 pairs (0.0%)
100-200      tokens:  6,000 pairs (5.6%)
200-300      tokens: 76,054 pairs (70.9%)
300-500      tokens: 19,929 pairs (18.6%)
500-750      tokens:  2,822 pairs (2.6%)
750-1000     tokens:    801 pairs (0.7%)
1000-1500    tokens:    723 pairs (0.7%)
1500-2000    tokens:    394 pairs (0.4%)
2000+        tokens:    575 