In [1]:
%uv pip install datasets transformers huggingface_hub


[2mUsing Python 3.12.6 environment at: /usr/local[0m
[37m⠋[0m [2mResolving dependencies...                                                     [0m[2K[37m⠋[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mResolving dependencies...                                                     [0m[2K[37m⠙[0m [2mdatasets==4.4.1                                                               [0m[2K[37m⠙[0m [2mtransformers==4.56.0                                                          [0m[2K[37m⠙[0m [2mhuggingface-hub==0.34.4                                                       [0m[2K[37m⠹[0m [2mhuggingface-hub==0.34.4                                                       [0m[2K[37m⠹[0m [2mfilelock==3.13.1                                                              [0m[2K[37m⠹[0m [2mnumpy==2.1.2                                                                  [0m[2K[37m⠹[0m [2mpyarrow==22

In [3]:
import json
import os
import difflib
import re
from collections import defaultdict, Counter
from typing import List, Dict, Tuple, Optional
from multiprocessing import cpu_count

import datasets
from datasets import Dataset, DatasetDict, load_dataset, concatenate_datasets
from transformers import AutoTokenizer, RobertaTokenizerFast
from huggingface_hub import HfApi, create_repo, login


In [2]:
import gc

In [5]:
# Read the Hugging Face access token from the environment so the secret is never
# stored in the notebook. An unset or empty variable yields None, which keeps the
# name defined for the cells that reference it (login and the `if HF_TOKEN:` upload guard).
HF_TOKEN = os.environ.get("HF_TOKEN") or None

In [6]:
# --- Run configuration ---

# Checkpoint name used to load the tokenizer, and the truncation limit (in tokens)
# handed to it.
MODEL_NAME = "IRIIS-RESEARCH/RoBERTa_Nepali_125M"
MAX_SEQUENCE_LENGTH = 128

# Source dataset on the Hub.
RAW_DATASET_NAME = "sumitaryal/nepali_grammatical_error_correction"

# Destination repository: "<user>/<dataset name>".
HF_USERNAME = "DipeshChaudhary"
FINAL_DATASET_NAME = "nepali-gector-mlm-guesser-dataset"
REPO_ID = "/".join((HF_USERNAME, FINAL_DATASET_NAME))

# Leave two cores free, but never go below a single worker.
NUM_WORKERS = max(cpu_count() - 2, 1)


In [7]:
# Authenticate this session with the Hugging Face Hub.
# HF_TOKEN must already be defined by an earlier cell, otherwise this raises NameError.
login(token=HF_TOKEN)
# Only reached when login() returned without raising.
print("✅ Hugging Face login successful.")

✅ Hugging Face login successful.


In [8]:
# Echo the active configuration so it is recorded alongside the run output.
config_summary = [
    "--- MLM DATASET PROCESSOR ---",
    f"Model: {MODEL_NAME}",
    f"Max Seq Length: {MAX_SEQUENCE_LENGTH}",
    f"Workers: {NUM_WORKERS}",
    f"Output Repo: {REPO_ID}",
    "-" * 70 + "\n",  # closing rule followed by one blank line
]
print("\n".join(config_summary))


--- MLM DATASET PROCESSOR ---
Model: IRIIS-RESEARCH/RoBERTa_Nepali_125M
Max Seq Length: 128
Workers: 106
Output Repo: DipeshChaudhary/nepali-gector-mlm-guesser-dataset
----------------------------------------------------------------------



In [9]:

# Load the tokenizer published with the checkpoint and keep the ids of its
# special tokens as module-level constants.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
MASK_TOKEN_ID, CLS_TOKEN_ID, SEP_TOKEN_ID, PAD_TOKEN_ID = (
    tokenizer.mask_token_id,
    tokenizer.cls_token_id,
    tokenizer.sep_token_id,
    tokenizer.pad_token_id,
)

# --- Enhanced GEC Vocabulary (From 01-gec-token-gector-tag-dataset-processor.ipynb) ---
class EnhancedNepaliGECVocabulary:
    """Two-way lookup between GEC edit-tag names and their integer ids.

    Attributes:
        KEEP_ID, DELETE_ID, REPLACE_ID, APPEND_ID, SWAP_NEXT_ID, SWAP_PREV_ID,
        MERGE_NEXT_ID, MERGE_PREV_ID, SPLIT_ID, UNKNOWN_ID: ids 0 through 9.
        tag_to_id: maps a tag name such as "$KEEP" to its id.
        id_to_tag: inverse of ``tag_to_id``.
    """

    def __init__(self):
        # Ids are consecutive integers assigned in declaration order.
        (
            self.KEEP_ID,
            self.DELETE_ID,
            self.REPLACE_ID,
            self.APPEND_ID,
            self.SWAP_NEXT_ID,
            self.SWAP_PREV_ID,
            self.MERGE_NEXT_ID,
            self.MERGE_PREV_ID,
            self.SPLIT_ID,
            self.UNKNOWN_ID,
        ) = range(10)

        # Tag names in id order: the position in this tuple is the tag's id.
        tag_names = (
            "$KEEP", "$DELETE", "$REPLACE", "$APPEND",
            "$SWAP_NEXT", "$SWAP_PREV", "$MERGE_NEXT", "$MERGE_PREV",
            "$SPLIT", "$UNKNOWN",
        )
        self.tag_to_id = dict(zip(tag_names, range(len(tag_names))))
        self.id_to_tag = dict(enumerate(tag_names))


# Shared instance used by the tagging and MLM-processing functions.
VOCAB = EnhancedNepaliGECVocabulary()

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/968 [00:00<?, ?B/s]

In [10]:
def calculate_levenshtein_opcodes(incorrect_words: List[str], correct_words: List[str]) -> List[Tuple[str, int, int, int, int]]:
    """Return the edit opcodes that turn ``incorrect_words`` into ``correct_words``.

    The alignment is computed by ``difflib.SequenceMatcher`` with its automatic
    junk heuristic switched off.

    Args:
        incorrect_words: Words of the erroneous sentence.
        correct_words: Words of the reference sentence.

    Returns:
        A list of ``(tag, i1, i2, j1, j2)`` tuples where ``tag`` is one of
        ``'equal'``, ``'replace'``, ``'delete'`` or ``'insert'``;
        ``incorrect_words[i1:i2]`` corresponds to ``correct_words[j1:j2]``.
    """
    matcher = difflib.SequenceMatcher(a=incorrect_words, b=correct_words, autojunk=False)
    return matcher.get_opcodes()

def generate_word_level_tags(incorrect_words: List[str], 
                             correct_words: List[str],
                             vocabulary: EnhancedNepaliGECVocabulary) -> Tuple[List[int], List[Tuple[str, int, int, int, int]]]:
    """
    Assigns one GEC tag id to every word of the incorrect sentence.

    The words are aligned with ``calculate_levenshtein_opcodes`` and each opcode
    is translated as follows:

    * ``'equal'``   -> ``vocabulary.KEEP_ID`` (the initial value of every tag)
    * ``'replace'`` -> ``vocabulary.REPLACE_ID`` for every incorrect word in the span
    * ``'delete'``  -> ``vocabulary.DELETE_ID`` for every incorrect word in the span
    * ``'insert'``  -> no tag is written; in particular ``APPEND_ID`` is never
      produced by this function, so callers that need the inserted words must
      read them from the returned opcodes.

    Args:
        incorrect_words: Words of the erroneous sentence.
        correct_words: Words of the reference sentence.
        vocabulary: Provider of the ``KEEP_ID`` / ``REPLACE_ID`` / ``DELETE_ID`` ids.

    Returns:
        ``(tags, opcodes)`` where ``tags`` has exactly ``len(incorrect_words)``
        entries and ``opcodes`` is the raw alignment the tags were derived from.
    """
    opcodes = calculate_levenshtein_opcodes(incorrect_words, correct_words)
    tags = [vocabulary.KEEP_ID] * len(incorrect_words)

    # Only these two operations change the tag of an existing incorrect word.
    op_to_tag = {
        'replace': vocabulary.REPLACE_ID,
        'delete': vocabulary.DELETE_ID,
    }

    # Opcode spans never overlap, so each word is written at most once and no
    # "already tagged" bookkeeping is required.
    for op, i_start, i_end, _j_start, _j_end in opcodes:
        tag_id = op_to_tag.get(op)
        if tag_id is not None:
            tags[i_start:i_end] = [tag_id] * (i_end - i_start)

    return tags, opcodes

In [16]:
def _group_token_positions(word_ids: List[Optional[int]]) -> Dict[int, List[int]]:
    """Groups token positions by the word index reported for each token."""
    positions: Dict[int, List[int]] = defaultdict(list)
    for token_pos, word_id in enumerate(word_ids):
        if word_id is not None:  # entries without a word index can never match a word
            positions[word_id].append(token_pos)
    return positions


def _align_edit_targets(opcodes: List[Tuple[str, int, int, int, int]],
                        correct_words: List[str]) -> Tuple[Dict[int, str], Dict[int, str]]:
    """Derives the target word for every maskable position from the opcodes.

    Returns:
        ``(replacements, insertions)``:
        ``replacements[i]`` is the correct word aligned with incorrect word ``i``;
        ``insertions[i]`` is the first word inserted directly after incorrect word ``i``.
    """
    replacements: Dict[int, str] = {}
    insertions: Dict[int, str] = {}
    for op, i_start, i_end, j_start, j_end in opcodes:
        if op == 'replace':
            # Pair words position by position inside the span. Incorrect words
            # beyond the end of a shorter correct span get no target.
            for inc_idx, cor_idx in zip(range(i_start, i_end), range(j_start, j_end)):
                replacements[inc_idx] = correct_words[cor_idx]
        elif op == 'insert' and i_start > 0 and j_end > j_start:
            # An insert has i_start == i_end; the word before it is i_start - 1.
            insertions[i_start - 1] = correct_words[j_start]
    return replacements, insertions


def process_for_mlm(example: Dict) -> Dict:
    """
    Maps a (corrupt, correct) sentence pair to MLM-ready input ids and labels.

    Masks are created for words tagged ``$REPLACE`` and, if the tagger emits
    them, ``$APPEND``:

    * ``$REPLACE``: every sub-token of the incorrect word becomes the mask token;
      the label at the first of them is the first sub-token id of the aligned
      correct word. All other labels stay ``-100``.
    * ``$APPEND``: one mask token is inserted right after the word's last
      sub-token, labelled with the first sub-token id of the first inserted word.
      Nothing is inserted once the sequence has reached ``MAX_SEQUENCE_LENGTH``.

    Args:
        example: Row with string fields ``'incorrect_sentence'`` and
            ``'correct_sentence'``.

    Returns:
        Dict with ``'mlm_input_ids'``, ``'mlm_labels'`` (same length as the input
        ids), ``'mlm_tag_count'`` (number of masked targets) and the two
        original sentences.
    """
    inc_sent = example['incorrect_sentence']
    cor_sent = example['correct_sentence']

    # Whitespace split defines the word units used for alignment.
    incorrect_words = inc_sent.split()
    correct_words = cor_sent.split()

    word_tags, opcodes = generate_word_level_tags(incorrect_words, correct_words, VOCAB)
    replacements, insertions = _align_edit_targets(opcodes, correct_words)

    # truncation=True already bounds the length, so no separate over-length
    # branch is needed and the correct sentence does not have to be tokenized.
    inc_encoding = tokenizer(
        inc_sent,
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        add_special_tokens=True
    )

    inc_input_ids = list(inc_encoding['input_ids'])
    mlm_labels = [-100] * len(inc_input_ids)  # -100 = ignored position

    # NOTE(review): this assumes the tokenizer's word indices coincide with the
    # whitespace-split word indices above - confirm for this tokenizer.
    token_positions = _group_token_positions(inc_encoding.word_ids())

    mlm_tag_count = 0
    shift = 0  # number of mask tokens inserted so far; later positions move right by this much

    for word_idx, tag in enumerate(word_tags):
        original_positions = token_positions.get(word_idx)
        if not original_positions:
            continue  # word has no tokens (e.g. it was cut off by truncation)

        if tag == VOCAB.REPLACE_ID:
            target_word = replacements.get(word_idx)
            if target_word is None:
                continue

            # NOTE(review): the word is encoded in isolation; verify that this
            # yields the same first sub-token as the word would get mid-sentence.
            target_ids = tokenizer.encode(target_word, add_special_tokens=False)
            if not target_ids:
                continue

            for pos in original_positions:
                inc_input_ids[pos + shift] = MASK_TOKEN_ID
            mlm_labels[original_positions[0] + shift] = target_ids[0]
            mlm_tag_count += 1

        elif tag == VOCAB.APPEND_ID:
            inserted_word = insertions.get(word_idx)
            if inserted_word is None or len(inc_input_ids) >= MAX_SEQUENCE_LENGTH:
                continue

            inserted_ids = tokenizer.encode(inserted_word, add_special_tokens=False)
            if not inserted_ids:
                continue

            insertion_point = original_positions[-1] + shift + 1
            inc_input_ids.insert(insertion_point, MASK_TOKEN_ID)
            mlm_labels.insert(insertion_point, inserted_ids[0])
            shift += 1  # keep the positions of all following words aligned
            mlm_tag_count += 1

    return {
        'mlm_input_ids': inc_input_ids,
        'mlm_labels': mlm_labels,
        'mlm_tag_count': mlm_tag_count,
        # Keep original text for verification
        'incorrect_sentence': inc_sent,
        'correct_sentence': cor_sent
    }

In [17]:
# STEP 1 banner.
step_rule = "=" * 70
print(step_rule)
print(f"STEP 1: LOADING RAW DATASET: {RAW_DATASET_NAME}")
print(step_rule)

# Fetch every split, then merge 'train' and 'valid' into a single dataset.
raw_dataset_splits = load_dataset(RAW_DATASET_NAME)
raw_dataset = concatenate_datasets(
    [raw_dataset_splits[split_name] for split_name in ('train', 'valid')]
)

def filter_correct_sentences(example):
    """Predicate that drops already-correct pairs.

    Returns True (keep the row) when the two sentences differ after stripping
    leading/trailing whitespace, and False when they are identical.
    """
    incorrect = example['incorrect_sentence'].strip()
    correct = example['correct_sentence'].strip()
    return incorrect != correct

def _print_step_banner(*lines: str) -> None:
    """Prints the given lines framed by two 70-character '=' rules."""
    rule = "=" * 70
    print(rule)
    for line in lines:
        print(line)
    print(rule)


print(f"Total rows before filtering: {len(raw_dataset)}")
# Keep only the pairs in which the incorrect sentence differs from the correct one.
error_dataset = raw_dataset.filter(
    filter_correct_sentences,
    num_proc=NUM_WORKERS
)
print(f"Total rows after filtering (only errors): {len(error_dataset)}")

_print_step_banner(
    "STEP 2: GENERATING MLM INPUTS AND LABELS",
    f"Using {NUM_WORKERS} workers.",
)

# Build the MLM inputs/labels row by row; the original columns are dropped and
# the sentences are re-emitted by process_for_mlm itself.
mlm_dataset = error_dataset.map(
    process_for_mlm,
    remove_columns=error_dataset.column_names,
    num_proc=NUM_WORKERS,
    # Output length varies per example, so rows are processed one at a time.
    batched=False,
)

# Drop examples for which no mask target was produced.
mlm_dataset = mlm_dataset.filter(lambda x: x['mlm_tag_count'] > 0, num_proc=NUM_WORKERS)
print(f"\nTotal rows after MLM mask generation and filtering (REPLACE/APPEND targets): {len(mlm_dataset)}")

# --- SPLITTING ---
_print_step_banner("STEP 3: CREATING TRAIN/VALID SPLITS")

# 90/10 split with a fixed seed so the partition is reproducible.
final_split = mlm_dataset.train_test_split(test_size=0.1, seed=42)
final_dataset = DatasetDict({
    'train': final_split['train'],
    'valid': final_split['test'],
})

print("Final dataset splits:")
print(final_dataset)

# --- UPLOAD TO HUGGING FACE ---
_print_step_banner("STEP 4: UPLOADING DATASET TO HUGGING FACE HUB")

# Tracks whether the push really happened so the closing message stays truthful.
upload_succeeded = False
try:
    if HF_TOKEN:
        final_dataset.push_to_hub(
            repo_id=REPO_ID,
            commit_message="Initial MLM dataset for GEC Guesser/Suggestor (REPLACE and APPEND focus)",
            private=True,
            token=HF_TOKEN
        )
        upload_succeeded = True
        print(f"\n🎉 Successfully uploaded dataset to: {REPO_ID}")
    else:
        print("⚠️ HF_TOKEN not set. Skipping upload. Dataset is ready locally.")

except Exception as e:
    # Best effort: the processed dataset is still available in memory, so report
    # the failure (including the exception type) instead of aborting the cell.
    print("\n❌ UPLOAD FAILED! Please check your HF_TOKEN and permissions.")
    print(f"{type(e).__name__}: {e}")

print("\n" + "=" * 70)
if upload_succeeded:
    print("ALL STEPS COMPLETE.")
else:
    print("PROCESSING COMPLETE, BUT THE DATASET WAS NOT UPLOADED.")
print(f"Next Step: Fine-tune {MODEL_NAME} for MLM on this dataset.")
print("=" * 70)

STEP 1: LOADING RAW DATASET: sumitaryal/nepali_grammatical_error_correction
Total rows before filtering: 8130496
Total rows after filtering (only errors): 8130496
STEP 2: GENERATING MLM INPUTS AND LABELS
Using 106 workers.


Map (num_proc=106):   0%|          | 0/8130496 [00:00<?, ? examples/s]

Filter (num_proc=106):   0%|          | 0/8130496 [00:00<?, ? examples/s]


Total rows after MLM mask generation and filtering (REPLACE/APPEND targets): 4557509
STEP 3: CREATING TRAIN/VALID SPLITS
Final dataset splits:
DatasetDict({
    train: Dataset({
        features: ['incorrect_sentence', 'correct_sentence', 'mlm_input_ids', 'mlm_labels', 'mlm_tag_count'],
        num_rows: 4101758
    })
    valid: Dataset({
        features: ['incorrect_sentence', 'correct_sentence', 'mlm_input_ids', 'mlm_labels', 'mlm_tag_count'],
        num_rows: 455751
    })
})
STEP 4: UPLOADING DATASET TO HUGGING FACE HUB


Uploading the dataset shards:   0%|          | 0/6 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

                                        :   0%|          |  524kB /  157MB            

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

                                        :   2%|2         | 3.67MB /  158MB            

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

                                        :   2%|2         | 3.67MB /  158MB            

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

                                        :   2%|2         | 3.67MB /  157MB            

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

                                        :   2%|2         | 3.67MB /  157MB            

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

                                        :   2%|2         | 3.67MB /  158MB            

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ? shards/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

                                        :   7%|6         | 7.35MB /  105MB            


🎉 Successfully uploaded dataset to: DipeshChaudhary/nepali-gector-mlm-guesser-dataset

ALL STEPS COMPLETE.
Next Step: Fine-tune IRIIS-RESEARCH/RoBERTa_Nepali_125M for MLM on this dataset.


In [18]:
final_dataset['train'][0]

{'incorrect_sentence': 'पछिल्लो समय सहज रुपमा भारतीय सरकाले पेन्सन दिन आनाकानी गर्न थालेपछि उनिहरूले विभिन्न निकायमा धाउप बाध्यता बनेको हो ।',
 'correct_sentence': 'पछिल्लो समय सहज रुपमा भारतीय सरकाले पेन्सन दिन आनाकानी गर्न थालेपछि उनिहरूले विभिन्न निकायमा धाउनुपर्ने बाध्यता बनेको हो ।',
 'mlm_input_ids': [1676,
  867,
  2057,
  1120,
  1800,
  35914,
  15240,
  837,
  21380,
  636,
  4131,
  10315,
  1198,
  1149,
  8240,
  6,
  6,
  6,
  5432,
  1911,
  586,
  488],
 'mlm_labels': [-100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  33713,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100],
 'mlm_tag_count': 1}