In [1]:
import logging
import warnings

In [2]:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
warnings.filterwarnings('ignore')

In [3]:
logger.info("start")

INFO:__main__:start


In [4]:
from datasets import load_dataset

INFO:datasets:PyTorch version 2.7.1 available.


In [5]:
dataset = load_dataset('jfleg')

In [6]:
dataset["validation"]["sentence"][0]

'So I think we can not live if old people could not find siences and tecnologies and they did not developped . '

In [7]:
dataset["validation"]["corrections"][0]

['So I think we would not be alive if our ancestors did not develop sciences and technologies . ',
 'So I think we could not live if older people did not develop science and technologies . ',
 'So I think we can not live if old people could not find science and technologies and they did not develop . ',
 'So I think we can not live if old people can not find the science and technology that has not been developed . ']

In [8]:
len(dataset["validation"])

755

In [9]:
dataset

DatasetDict({
    validation: Dataset({
        features: ['sentence', 'corrections'],
        num_rows: 755
    })
    test: Dataset({
        features: ['sentence', 'corrections'],
        num_rows: 748
    })
})

In [10]:
from transformers import AutoTokenizer, T5ForConditionalGeneration
# MODELNAME = "vennify/t5-base-grammar-correction"
MODELNAME = "t5-small"
PREFIX = "grammar: "
tokenizer = AutoTokenizer.from_pretrained(MODELNAME)
model = T5ForConditionalGeneration.from_pretrained(MODELNAME)

In [11]:
# Define the input text with a task prefix
input_text = PREFIX + "he go to school yesterday."

# Tokenize the input
input_ids = tokenizer.encode(
    input_text, return_tensors="pt", max_length=128, truncation=True)

# Generate output (corrected text)
output = model.generate(
    input_ids,
    max_length=128,
    num_beams=5,  # Beam search for better quality
    early_stopping=True,
    repetition_penalty=2.5  # Penalize repetitive output
)

# Decode the generated text
corrected_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(f"Original: he go to school yesterday.")
print(f"Corrected: {corrected_text}")

Original: he go to school yesterday.
Corrected: grammar: he go to school yesterday.


In [12]:
tokenizer.decode(output[0], skip_special_tokens=True)

'grammar: he go to school yesterday.'

In [13]:
ds = load_dataset("dim/grammarly_coedit")

In [14]:
input_ids

tensor([[19519,    10,     3,    88,   281,    12,   496,  4981,     5,     1]])

In [15]:
type(dataset["validation"]["sentence"])

list

In [16]:
from transformers import T5Config

In [17]:
print(T5Config())

T5Config {
  "classifier_dropout": 0.0,
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 512,
  "dense_act_fn": "relu",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": false,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "num_decoder_layers": 6,
  "num_heads": 8,
  "num_layers": 6,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "transformers_version": "4.52.4",
  "use_cache": true,
  "vocab_size": 32128
}



In [18]:
from torch.utils.data import DataLoader, Dataset
import torch

In [19]:
class GrammarCorrectionDataset(Dataset):
    """A PyTorch Dataset for grammar correction tasks.

    This dataset takes input text with grammatical errors and their corrected target text,
    tokenizes them using the provided tokenizer, and prepares them for training a grammar
    correction model.

    Args:
                                    tokenizer: The tokenizer to use for encoding the text
                                    input_text (list): List of input texts containing grammatical errors
                                    target_text (list): List of corresponding corrected texts
                                    max_length (int, optional): Maximum sequence length for tokenization. Defaults to 256.

    Returns:
                                    dict: Dictionary containing:
                                                                    - input_ids: Tokenized and padded input text
                                                                    - attention_mask: Attention mask for input text
                                                                    - labels: Target labels for training (-100 for padding tokens)
    """

    def __init__(self, tokenizer, input_text, target_text, max_length=256):
        self.tokenizer = tokenizer
        self.input_text = input_text
        self.target_text = target_text
        self.max_length = max_length

    def __len__(self):
        return len(self.input_text)

    def __getitem__(self, id):
        source = str("grammar: " + self.input_text[id])
        target = str(self.target_text[id])

        source_tokens = self.tokenizer(
            source, return_tensors="pt", max_length=self.max_length,
            padding="max_length", truncation=True, return_attention_mask=True)

        target_tokens = self.tokenizer(
            target, return_tensors="pt", max_length=self.max_length,
            padding="max_length", truncation=True, return_attention_mask=True)

        labels = target_tokens["input_ids"].clone()
        labels[labels == tokenizer.pad_token_id] = -100

        return {
            "input_ids": source_tokens["input_ids"].flatten(),
            "attention_mask": source_tokens["attention_mask"].flatten(),
            "labels": labels.flatten(),
        }

In [20]:
dataset

DatasetDict({
    validation: Dataset({
        features: ['sentence', 'corrections'],
        num_rows: 755
    })
    test: Dataset({
        features: ['sentence', 'corrections'],
        num_rows: 748
    })
})

In [21]:
train_source = dataset["validation"]["sentence"]
train_labels = dataset["validation"]["corrections"]

val_source = dataset["test"]["sentence"]
val_labels = dataset["test"]["corrections"]

In [22]:
train_dataset = GrammarCorrectionDataset(tokenizer, train_source, train_labels)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

In [23]:
val_dataset = GrammarCorrectionDataset(tokenizer, val_source, val_labels)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=True)

In [24]:
MODELNAME

't5-small'

In [25]:
from transformers import TrainingArguments, Trainer

In [26]:

train_loader

<torch.utils.data.dataloader.DataLoader at 0x1e528926900>

In [27]:
for i, data in enumerate(train_loader):
    print(data)
    break

{'input_ids': tensor([[19519,    10,    86,    82,  3474,     3,     6,    27,   317, 20356,
            33,   143,   494,  1727,   231,   394,   145,    79,   310,    33,
             3,     5,     1,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,  

In [28]:
dataset["validation"]["sentence"][0]  # type: ignore

'So I think we can not live if old people could not find siences and tecnologies and they did not developped . '

In [29]:
tokenizer(
    # type: ignore
    str("grammar: " + dataset["validation"]["sentence"][0]),
    max_length=128, padding="max_length", truncation=True, return_tensors="pt",
    return_attention_mask=True)

{'input_ids': tensor([[19519,    10,   264,    27,   317,    62,    54,    59,   619,     3,
            99,   625,   151,   228,    59,   253,   108,  1433,     7,    11,
             3,  5822,    29,  4137,     7,    11,    79,   410,    59,  1344,
          3138,     3,     5,     1,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,  

In [30]:
text = "ArithmeticError"

In [31]:
# Define the input text with a task prefix
input_text = PREFIX + "he go to school yesterday."

# Tokenize the input
input_ids = tokenizer.encode(
    input_text, return_tensors="pt", max_length=128, truncation=True)

# Generate output (corrected text)
output = model.generate(
    input_ids,
    max_length=128,
    num_beams=5,  # Beam search for better quality
    early_stopping=True,
    repetition_penalty=2.5  # Penalize repetitive output
)

# Decode the generated text
corrected_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(f"Original: he go to school yesterday.")
print(f"Corrected: {corrected_text}")

Original: he go to school yesterday.
Corrected: grammar: he go to school yesterday.


In [32]:
wi_dataset = load_dataset("wi_locness", "wi", trust_remote_code=True)

In [33]:
wi_dataset["train"]

Dataset({
    features: ['id', 'userid', 'cefr', 'text', 'edits'],
    num_rows: 3000
})

In [34]:
wi_dataset["train"]["text"][0]

'My town is a medium size city with eighty thousand inhabitants. It has a high density population because its small territory. Despite of it is an industrial city, there are many shops and department stores.  I recommend visiting the artificial lake in the certer of the city which is surrounded by a park. Pasteries are very common and most of them offer the special dessert from the city. There are a comercial zone along the widest street of the city where you can find all kind of establishments: banks, bars, chemists, cinemas, pet shops, restaurants, fast food restaurants, groceries, travel agencies, supermarkets and others. Most of the shops have sales and offers at least three months of the year: January, June and August. The quality of the products and services are quite good, because there are a huge competition, however I suggest you taking care about some fakes or cheats.'

In [35]:
wi_dataset["train"]["edits"][0].keys()

dict_keys(['start', 'end', 'text'])

In [36]:
starts = wi_dataset["train"]["edits"][0]["start"]

In [37]:
ends = wi_dataset["train"]["edits"][0]["end"]

In [38]:
wi_dataset["train"]["text"][0][starts[0]:ends[0]]

'medium size'

In [39]:
wi_dataset["train"]["edits"][0]["text"][0]

'medium-sized'

In [40]:
text = wi_dataset["train"]["text"][0]
edits = wi_dataset["train"]["edits"][0]

In [41]:
edits

{'start': [13,
  77,
  104,
  126,
  134,
  256,
  306,
  375,
  396,
  402,
  476,
  484,
  579,
  671,
  774,
  804,
  808,
  826,
  838,
  850,
  857,
  862,
  868],
 'end': [24,
  78,
  104,
  133,
  136,
  262,
  315,
  379,
  399,
  411,
  480,
  498,
  588,
  671,
  777,
  807,
  810,
  835,
  845,
  856,
  861,
  867,
  873],
 'text': ['medium-sized',
  '-',
  ' of',
  'Although',
  '',
  'center',
  None,
  'of',
  'is',
  'commercial',
  'kinds',
  'businesses',
  'grocers',
  ' in',
  'is',
  'is',
  '',
  '. However,',
  'recommend',
  'be',
  'careful',
  'of',
  '']}

In [42]:
edits_list = list(zip(edits["start"], edits["end"], edits["text"]))
edits_list.sort(key=lambda x: x[0], reverse=True)

str = text
for start, end, replacement in edits_list:
    print(replacement)
    if replacement == None:
        replacement = ""
    str = str[:start] + replacement + str[end:]


of
careful
be
recommend
. However,

is
is
 in
grocers
businesses
kinds
commercial
is
of
None
center

Although
 of
-
medium-sized


In [43]:
str

'My town is a medium-sized city with eighty thousand inhabitants. It has a high-density population because of its small territory. Although  it is an industrial city, there are many shops and department stores.  I recommend visiting the artificial lake in the center of the city which is surrounded by a park.  are very common and most of them offer the special dessert of the city. There is a commercial zone along the widest street of the city where you can find all kinds of businesses: banks, bars, chemists, cinemas, pet shops, restaurants, fast food restaurants, grocers, travel agencies, supermarkets and others. Most of the shops have sales and offers in at least three months of the year: January, June and August. The quality of the products and services is quite good, because there is huge competition. However, I recommend you be careful of fakes or cheats.'

In [44]:
edits_list.sort(key=lambda x: x[0], reverse=True)
edits_list

[(868, 873, ''),
 (862, 867, 'of'),
 (857, 861, 'careful'),
 (850, 856, 'be'),
 (838, 845, 'recommend'),
 (826, 835, '. However,'),
 (808, 810, ''),
 (804, 807, 'is'),
 (774, 777, 'is'),
 (671, 671, ' in'),
 (579, 588, 'grocers'),
 (484, 498, 'businesses'),
 (476, 480, 'kinds'),
 (402, 411, 'commercial'),
 (396, 399, 'is'),
 (375, 379, 'of'),
 (306, 315, None),
 (256, 262, 'center'),
 (134, 136, ''),
 (126, 133, 'Although'),
 (104, 104, ' of'),
 (77, 78, '-'),
 (13, 24, 'medium-sized')]

In [45]:
wi_dataset.keys()

dict_keys(['train', 'validation'])

In [46]:
wi_ds = {}
for sets in wi_dataset.keys():
    for features in wi_dataset[sets]:
        incorrect_text = ""
        correct_text = ""
        if "text" in features and "edits" in features:
            incorrect_text = features["text"]  # type: ignore
            # print(incorrect_text)
            correct_text = features["text"]

            correct_list = features["edits"]

            edits_list = list(zip(correct_list["start"],
                                  correct_list["end"], correct_list["text"]))
            edits_list.sort(key=lambda x: x[0], reverse=True)

            for start, end, replacement in edits_list:
                if replacement == None:
                    replacement = ""
                correct_text = correct_text[:start] + \
                    replacement + correct_text[end:]

        if sets not in wi_ds:
            wi_ds[sets] = {"incorrect_text": [], "correct_text": []}

        if incorrect_text != correct_text:
            wi_ds[sets]["incorrect_text"].append(incorrect_text)
            wi_ds[sets]["correct_text"].append(correct_text)        # break

In [47]:
wi_ds["train"]["correct_text"][0]

'My town is a medium-sized city with eighty thousand inhabitants. It has a high-density population because of its small territory. Although  it is an industrial city, there are many shops and department stores.  I recommend visiting the artificial lake in the center of the city which is surrounded by a park.  are very common and most of them offer the special dessert of the city. There is a commercial zone along the widest street of the city where you can find all kinds of businesses: banks, bars, chemists, cinemas, pet shops, restaurants, fast food restaurants, grocers, travel agencies, supermarkets and others. Most of the shops have sales and offers in at least three months of the year: January, June and August. The quality of the products and services is quite good, because there is huge competition. However, I recommend you be careful of fakes or cheats.'

In [48]:
import numpy as np

In [49]:
np.max([len(wi_ds["train"]["incorrect_text"][i].split())
        for i in range(len(wi_ds["train"]["incorrect_text"]))])

np.int64(1551)

In [50]:
wi_ds.keys()

dict_keys(['train', 'validation'])

In [51]:
paws = load_dataset("paws", "labeled_final")

In [52]:
paraphrases = paws.filter(lambda x: x['label'] == 1)

In [53]:
paraphrases

DatasetDict({
    train: Dataset({
        features: ['id', 'sentence1', 'sentence2', 'label'],
        num_rows: 21829
    })
    test: Dataset({
        features: ['id', 'sentence1', 'sentence2', 'label'],
        num_rows: 3536
    })
    validation: Dataset({
        features: ['id', 'sentence1', 'sentence2', 'label'],
        num_rows: 3539
    })
})

In [54]:
# Setup comprehensive logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('training.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

In [55]:
def verify_environment():
    """Verify computational environment and requirements"""
    logger.info("🔍 ENVIRONMENT VERIFICATION")
    logger.info("=" * 50)
    logger.info(f"   PyTorch: {torch.__version__}")
    logger.info(f"   CUDA Available: {torch.cuda.is_available()}")

    if torch.cuda.is_available():
        gpu_props = torch.cuda.get_device_properties(0)
        logger.info(f"   GPU: {torch.cuda.get_device_name()}")
        logger.info(f"   GPU Memory: {gpu_props.total_memory / 1e9:.1f} GB")
        logger.info(f"   CUDA Version: {torch.version.cuda}")
    else:
        logger.warning("   ⚠️  No GPU detected - training will be slower")

    # Check available memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        memory_allocated = torch.cuda.memory_allocated() / 1e9
        memory_reserved = torch.cuda.memory_reserved() / 1e9
        logger.info(f"   GPU Memory Used: {memory_allocated:.1f} GB")
        logger.info(f"   GPU Memory Reserved: {memory_reserved:.1f} GB")

    logger.info("✅ Environment verification complete!")
    logger.info("=" * 50)


verify_environment()

INFO:__main__:🔍 ENVIRONMENT VERIFICATION
INFO:__main__:   PyTorch: 2.7.1+cpu
INFO:__main__:   CUDA Available: False
INFO:__main__:✅ Environment verification complete!


In [56]:
train_source = ["grammar: " +
                sentence for sentence in dataset["validation"]["sentence"]]
train_target = [correction[0]
                for correction in dataset["validation"]["corrections"]]

In [57]:
source_tokens = tokenizer(train_source, max_length=256,
                          truncation=True, padding=False)
target_tokens = tokenizer(train_target, max_length=256,
                          truncation=True, padding=False)

In [58]:
token_dataset = source_tokens
token_dataset["labels"] = target_tokens["input_ids"]

In [59]:
token_dataset.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [60]:
dataset.column_names

{'validation': ['sentence', 'corrections'],
 'test': ['sentence', 'corrections']}

In [61]:
def preprocess(sources, target, max_length=256):
    dataset = {}
    source = [f"grammar: {sentence}" for sentence in sources]
    targets = [correction[0]
               for correction in target]

    source_tokens = tokenizer(source, max_length=max_length,
                              truncation=True, padding=False,
                              return_tensors=None)
    target_tokens = tokenizer(targets, max_length=max_length,
                              truncation=True, padding=False,
                              return_tensors=None)
    # print(target_tokens)

    dataset["input_ids"] = source_tokens["input_ids"]
    dataset["attention_mask"] = source_tokens["attention_mask"]
    dataset["labels"] = target_tokens["input_ids"]

    return dataset

In [62]:
train_ds = dataset["validation"].map(lambda ds: preprocess(dataset["validation"]["sentence"],
                                                           dataset["validation"]["corrections"]),
                                     batched=True,
                                     remove_columns=dataset["validation"].column_names)

Map:   0%|          | 0/755 [00:00<?, ? examples/s]

In [63]:
val_ds = dataset["test"].map(lambda ds: preprocess(dataset["test"]["sentence"],
                                                   dataset["test"]["corrections"]),
                             batched=True,
                             remove_columns=dataset["test"].column_names)

Map:   0%|          | 0/748 [00:00<?, ? examples/s]

In [64]:
val_ds

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 748
})

In [65]:
training_arg = TrainingArguments(
    output_dir=r"D:\MScDataScience\9.Research_Methods\Assignment\Assignment2\Checkpoints",
    # Basic setup
    overwrite_output_dir=False,          # Overwrite output directory if exists
    do_train=True,                       # Whether to run training
    do_eval=False,                       # Whether to run evaluation
    do_predict=False,                    # Whether to run predictions

    # Training hyperparameters
    num_train_epochs=1.0,                # Number of training epochs
    # Max training steps (overrides epochs if set)
    max_steps=-1,
    per_device_train_batch_size=8,       # Batch size per device during training
    per_device_eval_batch_size=8,        # Batch size per device during evaluation
    gradient_accumulation_steps=1,        # Steps to accumulate gradients

    # Learning rate and optimization
    learning_rate=5e-5,                  # Initial learning rate
    weight_decay=0.0,                    # Weight decay coefficient
    adam_beta1=0.9,                      # Beta1 for Adam optimizer
    adam_beta2=0.999,                    # Beta2 for Adam optimizer
    adam_epsilon=1e-8,                   # Epsilon for Adam optimizer
    max_grad_norm=1.0,                   # Max gradient norm for clipping

    # Learning rate scheduling
    lr_scheduler_type="linear",          # Type of LR scheduler
    warmup_ratio=0.0,                    # Ratio of warmup steps
    warmup_steps=0,
)

In [66]:
from transformers import DataCollatorForSeq2Seq

In [67]:
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
    return_tensors="pt"
)

In [68]:
trainer = Trainer(
    model=model,                        # The model to train/evaluate/predict
    args=training_arg,                         # TrainingArguments instance
    data_collator=data_collator,                # Function to collate batch data
    train_dataset=train_ds,                # Training dataset
    eval_dataset=val_ds,                 # Evaluation dataset
    tokenizer=tokenizer,                    # Tokenizer for the model
)

In [69]:
trainer.train()

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss


TrainOutput(global_step=95, training_loss=1.1725575497275904, metrics={'train_runtime': 146.0693, 'train_samples_per_second': 5.169, 'train_steps_per_second': 0.65, 'total_flos': 9238399647744.0, 'train_loss': 1.1725575497275904, 'epoch': 1.0})

In [70]:
dataset["validation"][0]

{'sentence': 'So I think we can not live if old people could not find siences and tecnologies and they did not developped . ',
 'corrections': ['So I think we would not be alive if our ancestors did not develop sciences and technologies . ',
  'So I think we could not live if older people did not develop science and technologies . ',
  'So I think we can not live if old people could not find science and technologies and they did not develop . ',
  'So I think we can not live if old people can not find the science and technology that has not been developed . ']}

In [71]:
import re

In [72]:
augmented_data = []
punc = re.compile(r'\s+([.!?,:;])')
for items in dataset["validation"]:
    print(type(items))
    source = punc.sub(r"\1", items["sentence"])
    print(source)
    targets = items["corrections"]
    for correction in targets:
        print(re.sub(punc, r"\1", correction))
        if correction.strip():
            augmented_data.append({
                "sentence": f"grammar: {source}",
                "correction": correction,
                "original_sentence": source,
                "all_corrections": targets
            })
    break

<class 'dict'>
So I think we can not live if old people could not find siences and tecnologies and they did not developped. 
So I think we would not be alive if our ancestors did not develop sciences and technologies. 
So I think we could not live if older people did not develop science and technologies. 
So I think we can not live if old people could not find science and technologies and they did not develop. 
So I think we can not live if old people can not find the science and technology that has not been developed. 


In [73]:
from typing import Dict, List, Tuple, Optional
from sklearn.model_selection import train_test_split
import re

In [74]:
split_number = r'\b(\d+(?:\s+\d+)+)\b'

In [75]:
class JFLEGDataset:
    """
    A comprehensive dataset processor for JFLEG (JHU FLuency-Extended GUG) grammar correction data.

    This class handles the complete pipeline for preparing JFLEG data for T5-based grammar correction
    training, including data loading, preprocessing, augmentation, tokenization, and train/validation/test
    splitting. The JFLEG dataset contains 1,511 examples with 4 human-written corrections each, focusing
    on fluency improvements rather than minimal edits.

    Key Features:
            - Comprehensive text preprocessing to handle formatting issues
            - Data augmentation using all 4 JFLEG corrections per sentence
            - Proper tokenization for T5 sequence-to-sequence training
            - Train/validation/test splitting with preserved evaluation metadata
            - Temperature-scaled mixing support for multi-task learning

    Dataset Sources:
            - Training: JFLEG validation split with 4x augmentation (~6,044 examples)
            - Validation/Test: JFLEG test split without augmentation, then split 90%/10%

    Attributes:
            tokenizer (T5Tokenizer): T5 tokenizer for text processing
            max_length (int): Maximum sequence length for tokenization
            test_split_ratio (float): Proportion of validation data to use for testing
            train_data (Dataset): JFLEG validation split used for training
            validation_data (Dataset): JFLEG test split used for validation/testing

    Example:
            >>> from transformers import T5Tokenizer
            >>> tokenizer = T5Tokenizer.from_pretrained("t5-base")
            >>> dataset = JFLEGDataset(tokenizer, max_length=256, test_split_ratio=0.10)
            >>> train_data, val_data, test_data = dataset.create_train_val_test_datasets()

    References:
            - JFLEG Paper: Napoles et al. (2017) "JFLEG: A Fluency Corpus and Benchmark 
              for Grammatical Error Correction"
            - Dataset: https://huggingface.co/datasets/jhu-clsp/jfleg
    """

    def __init__(self, tokenizer, max_length=256, test_split_ratio=0.10):
        """
        Initialize the JFLEG dataset processor with specified configuration.

        Sets up the dataset processor with the provided tokenizer and configuration
        parameters, then loads the raw JFLEG datasets for subsequent processing.

        Args:
                tokenizer (T5Tokenizer): HuggingFace T5 tokenizer instance for text processing.
                        Must be a properly initialized T5 tokenizer (e.g., from t5-base).
                max_length (int, optional): Maximum sequence length for tokenization. 
                        Sequences longer than this will be truncated. Defaults to 256.
                        Recommended range: 128-512 depending on GPU memory constraints.
                test_split_ratio (float, optional): Proportion of validation data to reserve 
                        for final testing. Must be between 0.0 and 1.0. Defaults to 0.10 (10%).
                        The remaining validation data will be used for model validation during training.

        Raises:
                ValueError: If test_split_ratio is not between 0.0 and 1.0
                TypeError: If tokenizer is not a valid T5Tokenizer instance

        Note:
                The JFLEG dataset splits are used as follows:
                - JFLEG 'validation' split → Training data (with augmentation)
                - JFLEG 'test' split → Validation and test data (split according to test_split_ratio)

                This approach follows standard practice since JFLEG's validation split is larger
                and more suitable for training, while the test split is reserved for evaluation.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.test_split_ratio = test_split_ratio

        # Validate test_split_ratio
        if not 0.0 <= test_split_ratio <= 1.0:
            raise ValueError(
                f"test_split_ratio must be between 0.0 and 1.0, got {test_split_ratio}")

        # Load the JFLEG datasets
        print(f"[INFO] Initializing JFLEG Dataset Processor...")
        print(
            f"[INFO] Max length: {max_length}, Test split ratio: {test_split_ratio:.1%}")

        self.train_data = load_dataset("jfleg", split="validation")
        self.validation_data = load_dataset("jfleg", split="test")

        print(
            f"[INFO] Loaded JFLEG validation split: {len(self.train_data)} examples")
        print(
            f"[INFO] Loaded JFLEG test split: {len(self.validation_data)} examples")

    def _preprocess(self, text):
        """
        Preprocess and normalize text by fixing common formatting issues.

        This method performs comprehensive text cleaning to handle poorly formatted
        text, such as OCR output or text with inconsistent spacing. It fixes issues
        with numbers, punctuation, quotes, and whitespace normalization.

        Args:
                        text (str): The input text to preprocess. Can be None or empty string.

        Returns:
                        str: The preprocessed and normalized text, or the original input if
                                        it's not a valid string.

        Transformations performed:
                        - Removes multiple consecutive dashes (-- → "")
                        - Fixes decimal formatting (0 . 1 → 0.1)
                        - Fixes fraction formatting (1 / 2 → 1/2)
                        - Removes leading zeros in decimals (00.5 → 0.5)
                        - Joins split numbers (1 2 3 4 → 1234)
                        - Fixes punctuation spacing (word , → word,)
                        - Normalizes quote spacing (" word " → "word")
                        - Collapses multiple spaces to single spaces
                        - Strips leading and trailing whitespace
        """

        # if not text or not isinstance(text, str):
        #     return text

        # Step 1: Remove unwanted characters (double dashes, etc.)
        text = re.sub(r"-{2,}", "", text)

        # Step 2: Fix decimal numbers (0 . 1 → 0.1)
        text = re.sub(r"(\d+)\s+\.\s+(\d+)", r"\1.\2", text)

        # Step 3: Fix fractions (1 / 2 → 1/2)
        text = re.sub(r"(\d+)\s+/\s+(\d+)", r"\1/\2", text)

        # Step 4: Fix leading zeros in decimals (00 . 5 → 0.5)
        text = re.sub(r"\b0+(\d+)\.(\d+)", r"\1.\2", text)

        # Step 5: Split number handling (any length)
        text = re.sub(r"\b(\d+(?:\s+\d+)+)\b",
                      lambda m: m.group(1).replace(" ", ""), text)

        # Step 6: Fix punctuation spacing (, . ! ? : ;)
        text = re.sub(r"\s+([,.!?:;])", r"\1", text)

        # Step 7: Fix double quote spacing
        text = re.sub(r'\s+"', '"', text)  # Remove space before quote
        text = re.sub(r'"\s+', '"', text)  # Remove space after quote

        # Step 8: Normalize multiple spaces to single space
        text = re.sub(r"\s{2,}", " ", text)

        # Step 9: Remove leading/trailing spaces
        text = text.strip()

        return text

    def _apply_augmentation(self, data, augment=True):
        """
        Apply data augmentation to JFLEG dataset using all available corrections.

        This function processes JFLEG examples to create augmented data by utilizing
        all 4 human-written corrections per sentence. Each original sentence is paired with
        each of its corrections to create multiple training examples, significantly increasing
        the dataset size and providing the model with diverse correction targets.

        Args:
                        data (List[Dict]): List of JFLEG dataset examples, where each example contains:
                                        - 'sentence' (str): Original grammatically incorrect sentence
                                        - 'corrections' (List[str]): List of 4 human-written corrections
                        augment (bool, optional): Whether to use all corrections for augmentation.
                                        - If True: Creates 4 examples per input (uses all corrections)
                                        - If False: Creates 1 example per input (uses only first correction)
                                        Default is True.

        Returns:
                        List[Dict]: Augmented dataset where each dictionary contains:
                                        - 'input' (str): Preprocessed input with "grammar: " prefix
                                        - 'target' (str): Preprocessed target correction
                                        - 'processed_sentence' (str): Preprocessed original sentence
                                        - 'processed_corrections' (List[str]): All 4 preprocessed corrections for evaluation
                                        - 'raw_original' (str): Unprocessed original sentence (for debugging)
                                        - 'raw_corrections' (List[str]): Unprocessed corrections (for debugging)
        """
        # storage for augmented data
        augmented_data = []
        for items in data:
            # getting original sentence -- incorrect
            original_sentence = items["sentence"]
            # formatting the incorrect sentence
            processed_sentence = self._preprocess(original_sentence)

            # getting all the original corrected sentences
            corrections = items["corrections"]

            # formatting all the corrected sentences -- evaluation
            processed_corrections = []
            # looping over all 4 corrections
            for correction in corrections:
                if correction.strip():  # Skip empty corrections
                    # storing all the processed corrections
                    processed_corrections.append(self._preprocess(correction))

            # looping over processed corrections
            for processed_correction in processed_corrections:
                # creating a dataset
                augmented_data.append({
                    "input": f"grammar: {processed_sentence}",
                    "target": processed_correction,
                    "processed_sentence": processed_sentence,
                    "processed_corrections": processed_corrections,
                    "raw_original": original_sentence,
                    "raw_corrections": corrections
                })
                # checking if to augment or not
                if not augment:
                    break
        # displaying the length of data
        print("[INFO] Length of Dataset is: ", len(augmented_data))
        return augmented_data

    def _apply_tokenization(self, data):
        """
        Apply tokenization to preprocessed JFLEG dataset examples for T5 model training.

        This function converts text data (input sentences and target corrections) into 
        tokenized format suitable for T5 model training. It processes both the input 
        grammar correction task and the target correction, creating the necessary 
        input_ids, attention_mask, and labels required by the HuggingFace Trainer.

        Args:
                        data (Dict): A single preprocessed example containing:
                        - 'input' (str): Preprocessed input text with "grammar: " prefix
                        - 'target' (str): Preprocessed target correction text
                        - 'processed_sentence' (str): Preprocessed original sentence (preserved but not tokenized)
                        - 'processed_corrections' (List[str]): All preprocessed corrections (preserved but not tokenized)
                        - 'raw_original' (str): Raw original sentence (preserved but not tokenized)
                        - 'raw_corrections' (List[str]): Raw corrections (preserved but not tokenized)


        Returns:
                        Dict: Tokenized example ready for model training containing:
                                        - 'input_ids' (List[int]): Token IDs for the input sequence
                                        - 'attention_mask' (List[int]): Attention mask for input (1 for real tokens, 0 for padding)
                                        - 'labels' (List[int]): Token IDs for the target sequence (used for loss computation)

        Tokenization Settings:
                        - max_length (int): Maximum sequence length (defined by self.max_length)
                        - truncation (bool): True - truncates sequences longer than max_length
                        - padding (bool): False - no padding applied (Trainer handles dynamic padding)
                        - return_tensors: None - returns Python lists instead of PyTorch tensors
        """
        # tokenizing the input of the dataset
        input_encodings = self.tokenizer(data["input"],
                                         max_length=self.max_length,
                                         truncation=True,
                                         padding=False,  # trainer handles the dynamic padding
                                         return_tensors=None)  # returns lists not tensor
        # tokenizing the target of the dataset
        target_encodings = self.tokenizer(data["target"],
                                          max_length=self.max_length,
                                          truncation=True,
                                          padding=False,  # trainer handles the dynamic padding
                                          return_tensors=None)  # returns lists not tensor

        return {
            "input_ids": input_encodings["input_ids"],
            "attention_mask": input_encodings["attention_mask"],
            "labels": target_encodings["input_ids"]
        }

    def create_train_val_test_datasets(self):
        """
        Create training, validation, and test datasets with proper augmentation and tokenization.

        This function orchestrates the complete data processing pipeline for JFLEG grammar 
        correction training. It applies data augmentation, converts to HuggingFace datasets,
        applies tokenization, and splits the data into appropriate train/validation/test sets
        while preserving essential evaluation metadata.

        Processing Pipeline:
                        1. Apply augmentation to training data (4x expansion using all corrections)
                        2. Apply augmentation to validation data (no expansion, uses first correction only)
                        3. Convert Python lists to HuggingFace Datasets
                        4. Apply tokenization using .map() for efficiency
                        5. Split validation data into validation and test sets (90%/10%)
                        6. Preserve evaluation metadata for proper GLEU scoring

        Data Sources:
                        - Training: JFLEG validation split with 4x augmentation (~6,044 examples)
                        - Validation/Test: JFLEG test split without augmentation, then split 90%/10%

        Returns:
                        Tuple[Dataset, List[Dict], List[Dict]]: A tuple containing:
                                        - train_dataset (Dataset): HuggingFace Dataset with tokenized training examples
                                        - val_data (List[Dict]): List of tokenized validation examples with metadata
                                        - test_data (List[Dict]): List of tokenized test examples with metadata

        Data Augmentation Strategy:
                        - Training: augment=True (uses all 4 JFLEG corrections per sentence)
                        - Validation: augment=False (uses only first correction per sentence)
        """

        from datasets import Dataset
        from sklearn.model_selection import train_test_split

        print("[INFO] Creating datasets with augmentation and tokenization...")

        # Step 1: Apply augmentation (returns Python lists)
        print("[INFO] Applying augmentation to training data...")
        train_augmented_list = self._apply_augmentation(
            self.train_data, augment=True)

        print("[INFO] Applying augmentation to validation data...")
        val_augmented_list = self._apply_augmentation(
            self.validation_data, augment=False)

        # Step 2: Convert Python lists to HuggingFace Datasets
        train_augmented_data = Dataset.from_list(train_augmented_list)
        val_augmented_data = Dataset.from_list(val_augmented_list)

        # Step 3: Apply tokenization using map
        print("\n[INFO] Tokenizing training data...")
        train_augmented_map_data = train_augmented_data.map(
            lambda example: self._apply_tokenization(example),
            batched=False,
            remove_columns=["input", "target"],
            desc="Tokenizing Training Data"
        )

        print("[INFO] Tokenizing validation data...")
        val_augmented_map_data = val_augmented_data.map(
            lambda example: self._apply_tokenization(example),
            batched=False,
            remove_columns=["input", "target"],
            desc="Tokenizing Validation Data"
        )

        # Step 4: Split validation dataset into validation and test sets
        print(
            f"\n[INFO] Splitting Validation Data ({100-self.test_split_ratio*100:.0f}%/{self.test_split_ratio*100:.0f}%)...")
        val_data, test_data = train_test_split(
            list(val_augmented_map_data),
            test_size=self.test_split_ratio,
            random_state=42
        )

        # Convert Python validation and test lists to HuggingFace Datasets
        val_data = Dataset.from_list(val_data)
        test_data = Dataset.from_list(test_data)

        # Summary
        print(f"\nDataset Creation Complete:")
        print(f"\t[INFO] Training Dataset:   {len(train_augmented_map_data)}")
        print(f"\t[INFO] Validation Dataset: {len(val_data)}")
        print(f"\t[INFO] Test Dataset:       {len(test_data)}")

        return train_augmented_map_data, val_data, test_data

In [76]:
dataset = JFLEGDataset(tokenizer)

[INFO] Initializing JFLEG Dataset Processor...
[INFO] Max length: 256, Test split ratio: 10.0%


[INFO] Loaded JFLEG validation split: 755 examples
[INFO] Loaded JFLEG test split: 748 examples


In [77]:
train, val, test = dataset.create_train_val_test_datasets()

[INFO] Creating datasets with augmentation and tokenization...
[INFO] Applying augmentation to training data...
[INFO] Length of Dataset is:  3016
[INFO] Applying augmentation to validation data...
[INFO] Length of Dataset is:  747

[INFO] Tokenizing training data...


Tokenizing Training Data:   0%|          | 0/3016 [00:00<?, ? examples/s]

[INFO] Tokenizing validation data...


Tokenizing Validation Data:   0%|          | 0/747 [00:00<?, ? examples/s]


[INFO] Splitting Validation Data (90%/10%)...

Dataset Creation Complete:
	[INFO] Training Dataset:   3016
	[INFO] Validation Dataset: 672
	[INFO] Test Dataset:       75


In [78]:
train

Dataset({
    features: ['processed_sentence', 'processed_corrections', 'raw_original', 'raw_corrections', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 3016
})

In [79]:
val

Dataset({
    features: ['processed_sentence', 'processed_corrections', 'raw_original', 'raw_corrections', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 672
})

In [80]:
test

Dataset({
    features: ['processed_sentence', 'processed_corrections', 'raw_original', 'raw_corrections', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 75
})

In [81]:
logger.info("this is a test for logger")

INFO:__main__:this is a test for logger


In [82]:
from evaluate import load
import nltk
from nltk.translate.gleu_score import sentence_gleu, corpus_gleu

In [83]:
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

In [99]:
from collections import Counter

In [109]:
class GrammarEvaluation:
    """
    Comprehensive Grammar Correction Evaluation Framework.

    This class provides a complete evaluation suite for grammar correction systems,
    implementing industry-standard metrics specifically designed for assessing
    grammatical error correction quality. It combines fluency assessment (GLEU),
    semantic preservation (BERTScore), linguistic quality (METEOR), and comprehensive
    text statistics.

    The evaluation framework is designed for transformer-based models like T5, BERT,
    and other sequence-to-sequence architectures fine-tuned on datasets such as JFLEG,
    BEA-2019, or custom grammar correction corpora.

    Attributes:
        bertscore: HuggingFace BERTScore evaluator for semantic similarity
        meteor: HuggingFace METEOR evaluator for linguistic quality
        metrics_result (Dict): Storage for all computed evaluation metrics

    Performance Benchmarks:
        - GLEU: >0.50 acceptable, >0.55 good, >0.60 excellent
        - BERTScore F1: >0.75 acceptable, >0.80 good, >0.85 excellent
        - METEOR: >0.40 acceptable, >0.45 good, >0.55 excellent

    References:
        - GLEU: Napoles et al. (2017) "JFLEG: A fluency corpus and benchmark"
        - BERTScore: Zhang et al. (2020) "BERTScore: Evaluating Text Generation with BERT"
        - METEOR: Banerjee & Lavie (2005) "METEOR: An automatic metric for MT evaluation"
    """

    def __init__(self):
        """
        Initialize the Grammar Evaluation framework.

        Sets up evaluation metrics and initializes the results storage structure.
        Loads pre-trained models for BERTScore and METEOR evaluation from HuggingFace.

        Initializes:
            - BERTScore evaluator with microsoft/deberta-xlarge-mnli model
            - METEOR evaluator with default configuration
            - Results dictionary with nested structure for all metrics

        Raises:
            ImportError: If required packages (pandas) are not installed

        Note:
            First initialization may take time to download evaluation models.
            Internet connection required for downloading pre-trained models.
        """
        # Load HuggingFace evaluation metrics
        self.bertscore = load("bertscore")  # Semantic similarity evaluation
        self.meteor = load('meteor')        # Linguistic quality evaluation

        # Initialize results storage with hierarchical structure
        self.metrics_result = {
            "bertscore": {              # Semantic preservation metrics
                "precision": 0.0,       # BERTScore precision
                "recall": 0.0,          # BERTScore recall
                # BERTScore F1 (primary semantic metric)
                "f1": 0.0
            },
            "meteor": 0.0,              # Linguistic quality score
            "gleu": 0.0,                # Primary grammar correction metric
            "stats": {}                 # Comprehensive text statistics
        }

    def compute_gleu(self, predictions, references):
        """
        Compute GLEU (Generalized Language Evaluation Understanding) scores for grammar correction.

        GLEU is specifically designed for grammar correction evaluation and handles multiple
        reference corrections better than traditional BLEU. It measures fluency improvement
        while accounting for acceptable variation in correction approaches.

        Args:
            predictions (List[str]): List of model-generated grammar corrections.
            references (List[List[str]]): List of reference correction lists. Each inner list
                contains multiple valid corrections for the same source sentence (e.g., JFLEG
                provides 4 human corrections per sentence).

        Returns:
            None: Stores the average GLEU score in self.metrics_result["gleu"].

        Notes:
            - Uses NLTK's sentence_gleu function for computation
            - Applies lowercase normalization and word tokenization
            - Handles empty or invalid references gracefully with exception handling
            - Scores range from 0.0 (no match) to 1.0 (perfect match)
            - For grammar correction: >0.50 acceptable, >0.55 good, >0.60 excellent

        Raises:
            Exception: Catches and handles any GLEU computation errors by assigning 0.0 score.
        """

        # storage for gleu score
        gleu_scores = []

        # looping over all predictions and references
        for preds, refs in zip(predictions, references):
            # converting predictions to tokens
            preds_tokens = nltk.word_tokenize(preds.lower())
            # converting all reference to tokens
            refs_tokens = [nltk.word_tokenize(ref.lower())
                           for ref in refs if ref.strip()]

            try:
                # computing the score for the gleu
                score = sentence_gleu(refs_tokens, preds_tokens)
                # updating the storage
                gleu_scores.append(score)
            except Exception:
                gleu_scores.append(0.0)

        # computing the average gleu
        self.metrics_result["gleu"] = np.mean(gleu_scores)

    def compute_bertscore(self, predictions, references):
        """
        Compute BERTScore metrics (precision, recall, F1) for grammar correction evaluation.

        BERTScore measures semantic similarity using contextual embeddings, making it ideal
        for evaluating whether grammar corrections preserve semantic meaning while improving
        fluency. For each prediction, scores are computed against all available references
        and then averaged.

        Args:
            predictions (List[str]): List of model-generated grammar corrections.
            references (List[List[str]]): List of reference correction lists. Each inner 
                list contains multiple valid corrections for the same source sentence 
                (e.g., JFLEG provides 4 human corrections per sentence).

        Returns:
            None: Currently stores results in local variables but doesn't persist them.

        Process:
            1. For each prediction, compute BERTScore against each of its references
            2. Average precision, recall, F1 across references for that prediction
            3. Collect averaged scores across all predictions

        Notes:
            - Uses microsoft/deberta-xlarge-mnli for optimal semantic similarity detection
            - Skips empty references automatically
            - For grammar correction: F1 > 0.85 indicates excellent semantic preservation
            - Each BERTScore call processes one prediction-reference pair
        """

        # storage for bertscore metrics
        precisions = []
        recalls = []
        f1s = []

        # looping over all predictions and references
        for preds, refs in zip(predictions, references):
            # storage for per prediction against its 4 references
            precision = []
            recall = []
            f1 = []

            for ref in refs:
                if ref.strip():
                    # computing bertscore
                    score = self.bertscore.compute(predictions=[preds],
                                                   references=[ref],
                                                   lang="en",
                                                   model_type="microsoft/deberta-xlarge-mnli")
                    # updating the local storage
                    precision.append(score["precision"][0])
                    recall.append(score["recall"][0])
                    f1.append(score["f1"][0])

            # updating bertscore mertics with average
            precisions.append(np.mean(precision))
            recalls.append(np.mean(recall))
            f1s.append(np.mean(f1))

        # computing the average bertscore
        self.metrics_result["bertscore"] = {
            "precision": np.mean(precisions),
            "recall": np.mean(recalls),
            "f1": np.mean(f1s)
        }

    def compute_meteor(self, predictions, references):
        """
        Compute METEOR (Metric for Evaluation of Translation with Explicit ORdering) scores 
        for grammar correction evaluation.

        METEOR is particularly valuable for grammar correction as it incorporates:
        - Exact word matching
        - Stem matching (handles morphological variations like "running" vs "runs")
        - Synonym matching (recognizes semantically equivalent words)
        - Word order penalties

        This makes it superior to BLEU for grammar correction where morphological changes
        and lexical substitutions are common.

        Args:
            predictions (List[str]): List of model-generated grammar corrections.
            references (List[List[str]]): List of reference correction lists. Each inner 
                list contains multiple valid corrections for the same source sentence.

        Returns:
            None: Stores the average METEOR score in self.metrics_result["meteor"].

        Process:
            1. For each prediction, compute METEOR against each of its references
            2. Average METEOR scores across references for that prediction  
            3. Average across all predictions for final score

        Notes:
            - METEOR scores range from 0.0 to 1.0 (higher is better)
            - For grammar correction: >0.40 acceptable, >0.45 good, >0.55 excellent
            - Handles morphological variations better than BLEU
            - Includes recall-oriented evaluation (unlike BLEU's precision focus)
        """

        # storage for meteor metrics
        meteors = []

        # looping over all predictions and references
        for preds, refs in zip(predictions, references):
            # storage for per prediction against its 4 references
            meteor = []
            for ref in refs:
                if ref.strip():
                    # computing meteor
                    meteor.append(self.meteor.compute(predictions=[preds],
                                                      references=[ref])["meteor"])

            # updating meteor mertics with avergae
            meteors.append(np.mean(meteor))

        # computing the average meteor
        self.metrics_result["meteor"] = np.mean(meteors)

    def compute_stats(self, predictions, references):
        """
        Compute comprehensive text statistics for grammar correction evaluation.

        This method analyzes various aspects of model predictions versus reference corrections,
        providing detailed insights into model behavior patterns, text properties, and 
        vocabulary usage. All statistics are stored in self.metrics_result["stats"].

        Args:
            predictions (List[str]): List of model-generated grammar corrections.
            references (List[List[str]]): List of reference correction lists. Each inner 
                list contains multiple valid corrections for the same source sentence 
                (e.g., JFLEG provides 4 human corrections per sentence).

        Returns:
            None: All statistics are stored in self.metrics_result["stats"] dictionary.

        Statistics Computed:

            **Sample Information:**
            - num_samples: Total number of predictions evaluated
            - total_references: Total number of reference corrections across all sentences
            - avg_references_per_sentence: Average references available per sentence

            **Length Statistics (Word-level):**
            - avg/min/max/std_prediction_length: Prediction length statistics in words
            - avg/min/max/std_reference_length: Reference length statistics in words

            **Character-level Statistics:**
            - avg/std_prediction_char_length: Character count statistics for predictions
            - avg/std_reference_char_length: Character count statistics for references

            **Length Change Analysis:**
            - avg/std_length_difference: Difference between prediction and first reference lengths
            - positive_length_changes: Count where prediction > reference length (expansion)
            - negative_length_changes: Count where prediction < reference length (compression)  
            - no_length_changes: Count where prediction == reference length (preserved)

            **Vocabulary Analysis:**
            - unique_words_in_predictions: Unique word count in all predictions
            - unique_words_in_references: Unique word count in all references
            - vocab_overlap: Common words between predictions and references
            - vocab_overlap_ratio: Overlap ratio (intersection/union of vocabularies)

        Notes:
            - Length differences computed against first reference for each sentence
            - Word counting uses lowercase normalization
            - Empty references are skipped in processing
            - Vocabulary analysis helps assess model's lexical diversity
        """

        # storage for computing statistics
        pred_lengths = []                   # predictions length
        pred_char_lengths = []              # predictions char length
        all_ref_lengths = []                # reference length
        all_ref_char_lengths = []           # reference char length
        ref_counts = []                     # np. of reference per sentence
        length_diffs = []                   # word difference in preds and refs
        pred_word_counts = Counter()        # unique word counts in prediction
        ref_word_counts = Counter()         # unique word counts in reference
        # no. of len(prediction) > len(reference)
        positive_changes = 0
        # no. of len(prediction) < len(reference)
        negative_changes = 0
        # no. of len(prediction) == len(reference)
        no_changes = 0

        # looping over all predictions and references
        for pred, refs in zip(predictions, references):
            # prediction statistics
            pred_len = len(pred.split())
            pred_char_len = len(pred)
            pred_lengths.append(pred_len)
            pred_char_lengths.append(pred_char_len)

            # prediction word counts
            pred_word_counts.update(pred.lower().split())

            # reference statistics
            ref_counts.append(len(refs))

            # looping for all references for this prediction
            for ref in refs:
                if ref.strip():
                    ref_len = len(ref.split())
                    ref_char_len = len(ref)
                    all_ref_lengths.append(ref_len)
                    all_ref_char_lengths.append(ref_char_len)
                    ref_word_counts.update(ref.lower().split())

            # length difference analysis (compare with first reference)
            if refs:
                ref_len = len(refs[0].split()) if refs[0].strip() else 0
                length_diff = pred_len - ref_len
                length_diffs.append(length_diff)

                # counting changes
                if length_diff > 0:
                    positive_changes += 1
                elif length_diff < 0:
                    negative_changes += 1
                else:
                    no_changes += 1

        # updating the stats in the metrics_result...

        # Sample information (standardized naming)
        self.metrics_result["stats"]["num_samples"] = len(predictions)
        self.metrics_result["stats"]["total_references"] = sum(ref_counts)
        self.metrics_result["stats"]["avg_references_per_sentence"] = np.mean(
            ref_counts)

        # Prediction statistics
        self.metrics_result["stats"]["avg_prediction_length"] = np.mean(
            pred_lengths)
        self.metrics_result["stats"]["min_prediction_length"] = np.min(
            pred_lengths)
        self.metrics_result["stats"]["max_prediction_length"] = np.max(
            pred_lengths)
        self.metrics_result["stats"]["std_prediction_length"] = np.std(
            pred_lengths)
        self.metrics_result["stats"]["avg_prediction_char_length"] = np.mean(
            pred_char_lengths)
        self.metrics_result["stats"]["std_prediction_char_length"] = np.std(
            pred_char_lengths)

        # Reference statistics
        if all_ref_lengths:  # Handle empty case
            self.metrics_result["stats"]["avg_reference_length"] = np.mean(
                all_ref_lengths)
            self.metrics_result["stats"]["min_reference_length"] = np.min(
                all_ref_lengths)
            self.metrics_result["stats"]["max_reference_length"] = np.max(
                all_ref_lengths)
            self.metrics_result["stats"]["std_reference_length"] = np.std(
                all_ref_lengths)
            self.metrics_result["stats"]["avg_reference_char_length"] = np.mean(
                all_ref_char_lengths)
            self.metrics_result["stats"]["std_reference_char_length"] = np.std(
                all_ref_char_lengths)
        else:
            self.metrics_result["stats"]["avg_reference_length"] = 0.0
            self.metrics_result["stats"]["min_reference_length"] = 0
            self.metrics_result["stats"]["max_reference_length"] = 0
            self.metrics_result["stats"]["std_reference_length"] = 0.0
            self.metrics_result["stats"]["avg_reference_char_length"] = 0.0
            self.metrics_result["stats"]["std_reference_char_length"] = 0.0

        # Length difference statistics
        if length_diffs:
            self.metrics_result["stats"]["avg_length_difference"] = np.mean(
                length_diffs)
            self.metrics_result["stats"]["std_length_difference"] = np.std(
                length_diffs)
            # FIXED: Use pre-calculated variables instead of redundant sum() operations
            self.metrics_result["stats"]["positive_length_changes"] = positive_changes
            self.metrics_result["stats"]["negative_length_changes"] = negative_changes
            self.metrics_result["stats"]["no_length_changes"] = no_changes

        # Vocabulary statistics
        self.metrics_result["stats"]["unique_words_in_predictions"] = len(
            pred_word_counts)
        self.metrics_result["stats"]["unique_words_in_references"] = len(
            ref_word_counts)
        self.metrics_result["stats"]["vocab_overlap"] = len(
            set(pred_word_counts.keys()) & set(ref_word_counts.keys()))

        # Vocabulary overlap ratio
        if len(pred_word_counts) > 0 and len(ref_word_counts) > 0:
            self.metrics_result["stats"]["vocab_overlap_ratio"] = (
                self.metrics_result["stats"]["vocab_overlap"] /
                len(set(pred_word_counts.keys()) | set(ref_word_counts.keys()))
            )
        else:
            self.metrics_result["stats"]["vocab_overlap_ratio"] = 0.0

    def evaluate(self, predictions, references):
        """
        Perform comprehensive evaluation of grammar correction predictions.

        This is the main evaluation method that computes all metrics and statistics
        for grammar correction assessment. It provides a complete analysis including
        fluency (GLEU), semantic preservation (BERTScore), linguistic quality (METEOR),
        and comprehensive text statistics.

        Args:
            predictions (List[str]): List of model-generated grammar corrections.
            references (List[List[str]]): List of reference correction lists. Each inner 
                list contains multiple valid corrections for the same source sentence 
                (e.g., JFLEG provides 4 human corrections per sentence).

        Returns:
            Dict[str, Any]: Complete evaluation results containing:
                - "gleu": GLEU score (float)
                - "meteor": METEOR score (float) 
                - "bertscore": Dict with precision, recall, f1 scores
                - "stats": Dict with comprehensive text statistics

        Evaluation Metrics Computed:

            **Core Grammar Correction Metrics:**
            - GLEU: Primary metric for grammar correction fluency assessment
            - BERTScore: Semantic preservation evaluation (precision, recall, F1)
            - METEOR: Linguistic quality with morphological awareness

            **Comprehensive Statistics:**
            - Sample counts and reference information
            - Length statistics (words and characters)
            - Length change analysis
            - Vocabulary analysis and overlap metrics

        Performance Benchmarks:
            - GLEU: >0.50 acceptable, >0.55 good, >0.60 excellent
            - BERTScore F1: >0.75 acceptable, >0.80 good, >0.85 excellent  
            - METEOR: >0.40 acceptable, >0.45 good, >0.55 excellent

        Notes:
            - Progress information is printed during computation
            - Results are stored in self.metrics_result and returned
            - Statistics are displayed as formatted pandas DataFrame
            - All computations handle edge cases gracefully
        """
        print(f"[INFO] Evaluating {len(predictions)} Predictions...")
        # gleu score -- primary for grammar correction
        print("\t[INFO] Computing GLEU Score...")
        self.compute_gleu(predictions, references)

        # bertscore -- sementic preservation
        print("\t[INFO] Computing BERTScore--Precision Recall & F1...")
        self.compute_bertscore(predictions, references)

        # meteor -- linquistic quality
        print("\t[INFO] Computing METEOR Score...")
        self.compute_meteor(predictions, references)

        # statistics
        print("\t[INFO] Computing Comprehensive Statistics...")
        self.compute_stats(predictions, references)

        print("Evaluation Complete:")
        # printing the metrics...
        print(f"\t[INFO] GLEU: {self.metrics_result['gleu']:.4f}")
        print(f"\t[INFO] METEOR: {self.metrics_result['meteor']:.4f}")
        print("\t[INFO] BERTSCORE:")
        print(
            f"\t\t[INFO] Precision: {self.metrics_result['bertscore']['precision']:.4f}")
        print(
            f"\t\t[INFO] Recall: {self.metrics_result['bertscore']['recall']:.4f}")
        print(f"\t\t[INFO] F1: {self.metrics_result['bertscore']['f1']:.4f}")

        # printing the statistics...
        print("\n\t[INFO] Statistics:")
        try:
            import pandas as pd
            stats_df = pd.DataFrame(
                list(self.metrics_result["stats"].items()),
                columns=["Metric", "Value"]
            )
            print(stats_df.to_string(index=False))
        except ImportError:
            print("\t[WARNING] Pandas not available, printing raw statistics:")
            for key, value in self.metrics_result["stats"].items():
                print(f"\t\t{key}: {value}")

        # Return complete results for further processing
        return self.metrics_result

In [85]:
bertscore = load("bertscore")
meteor = load("meteor")

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\agarw\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\agarw\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:\Users\agarw\AppData\Roaming\nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [89]:
predictions = ["hello there"]
references = [["hello there", "hi there"]]
bertscore.compute(predictions=predictions,
                  references=[references[0][0]],
                  lang="en",
                  model_type="microsoft/deberta-xlarge-mnli")

{'precision': [1.0],
 'recall': [1.0],
 'f1': [1.0],
 'hashcode': 'microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.12(hug_trans=4.52.4)'}

In [90]:
meteor.compute(predictions=predictions, references=[
               references[0]])

{'meteor': np.float64(0.9375)}

In [91]:
for preds, refs in zip(predictions, references):
    precision = []
    recall = []
    f1 = []
    for ref in refs:
        if ref.strip():
            score = (bertscore.compute(predictions=[preds],
                                       references=[ref],
                                       lang="en",
                                       model_type="microsoft/deberta-xlarge-mnli"))
            print(score)
            precision.append(score["precision"][0])
            recall.append(score["recall"][0])
            f1.append(score["f1"][0])
    print(np.mean(precision), np.mean(recall), np.mean(f1))

{'precision': [1.0], 'recall': [1.0], 'f1': [1.0], 'hashcode': 'microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.12(hug_trans=4.52.4)'}
{'precision': [0.944251537322998], 'recall': [0.944251537322998], 'f1': [0.944251537322998], 'hashcode': 'microsoft/deberta-xlarge-mnli_L40_no-idf_version=0.3.12(hug_trans=4.52.4)'}
0.972125768661499 0.972125768661499 0.972125768661499


In [92]:
list(zip(predictions, references))

[('hello there', ['hello there', 'hi there'])]

In [94]:
predictions[0].split()

['hello', 'there']

In [None]:
# Reference statistics
all_ref_lengths = []
ref_counts = []

for refs in val["processed_corrections"]:
    ref_counts.append(len(refs))
    for ref in refs:
        if ref.strip():
            all_ref_lengths.append(len(ref.split()))

In [None]:
class T5LoRATrainer:
    def __init__(model_name="t5-small", ):