# Challenge related with a GEC system

# Table of contents:

- [Challenge related with a GEC system](#challenge-related-with-a-gec-system)
- [Explanation](#explanation)
- [Just Setting Up](#just-setting-up)
- [Dataset Download and parser](#dataset-downloader-and-parser)
    - [Dataset Exploration and Parsing](#some-exploration-of-the-data-and-the-parsing)
- [Preprocessing for T5 Model](#t5-preprocessor)
- [Training the T5 Model](#trainer-class)
- [Inference Engines](#inference-engines)
    - [Base Inference Engine](#base-inference-engine-and-helper)
    - [T5 Inference Engine](#t5-inference-engine)
    - [Llama 3 Inference Engine](#llama-3-inference-engine)
- [Evaluation](#evaluation)
    - [Evaluation helpers](#evaluator-class)
    - [Evaluation of the FCE test data with exact match and GLEU](#evaluating-fce-test-data)
        - [T5](#evaluating-the-test-dataset-fce-with-t5-using-gleu-and-exact-match)
        - [LLAMA3](#evaluating-test-data-fce-with-llama-3-using-gleu-and-exact-match)
    - [Evaluation of Medical data with exact match and GLEU](#evaluating-our-medical-data)
        - [T5](#evaluating-with-t5-using-exact-match-and-gleu)
        - [LLAMA](#evaluating-with-llama3-using-exact-match-and-gleu)
    - [Evaluation of ]

# Explanation

# Just Setting Up

In [None]:
!pip install datasets transformers tqdm scikit-learn sentencepiece torch

In [3]:
## Imports

import logging
import os
import sys
import tarfile
import time
from typing import Dict, List, Tuple

# import openai
from datasets import Dataset, DatasetDict
from tqdm import tqdm
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments

In [3]:
# Create logger
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Remove all handlers associated with the root logger object (avoid duplicate logs)
for handler in logger.handlers[:]:
    logger.removeHandler(handler)

# Create handler that outputs to notebook cell
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(funcName)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)

In [6]:
# Constants. In a proper GitHub repo these would live in their own module and be imported.

FCE_URL = "https://www.cl.cam.ac.uk/research/nl/bea2019st/data/fce_v2.1.bea19.tar.gz"
FCE_DEV_NAME = "fce.dev.gold.bea19.m2"
FCE_TRAIN_NAME = "fce.train.gold.bea19.m2"
FCE_TEST_NAME = "fce.test.gold.bea19.m2"
FCE_DOWNLOAD_DATASET_DIR = os.path.join(os.getcwd(), "data")
FCE_DATASET_DIR = os.path.join(os.getcwd(), "data/fce/m2")
SENTENCE_TAG = "S "
ANNOTATION_TAG = "A "
NO_EDIT_TAG = "noop"
SEP = "|||"
## We're going to use T5 small for this task because we're just correcting spelling and we don't want to wait hours for training.
MODEL_NAME = "t5-small"
FINETUNED_MODEL_OUTPUT_DIR = "./t5_finetuned"
GENERAL_PROMPT_PATH = os.path.join(os.getcwd(), "config/prompt_general.txt")
LLAMA3_ENDPOINT = "http://127.0.0.1:11434/api/generate"
TEXT_TO_REPLACE_IN_PROMPT = "<text_to_replace>"

# Dataset Downloader and Parser

This class orchestrates the download of the BEA19 files and parses them so they can be stored and reused.

In [6]:
### Dataset Loader class handling the dataset creation
import tarfile
from datasets import Dataset, DatasetDict
import requests
from typing import Optional, List, Dict


class M2DatasetLoader:
    """
    Loader for datasets in M2 format. Right now it's only tested with the BEA 2019 FCE data.

    Args:
        dataset_dir (str): Directory where the dataset is stored. Default is FCE_DATASET_DIR.
        train_file (str): Name of the training file. Default is FCE_TRAIN_NAME.
        dev_file (str): Name of the development file. Default is FCE_DEV_NAME.
        test_file (str): Name of the test file. Default is FCE_TEST_NAME.
        dataset_url (str): URL to download the dataset from. Default is FCE_URL.
        fce_download_dir (str): Directory to download the dataset to. Default is FCE_DOWNLOAD_DATASET_DIR.
    """
    def __init__(self,dataset_dir: str = FCE_DATASET_DIR, train_file: str = FCE_TRAIN_NAME,
                 dev_file: str = FCE_DEV_NAME, test_file: str = FCE_TEST_NAME, dataset_url: str = FCE_URL,
                 fce_download_dir: str = FCE_DOWNLOAD_DATASET_DIR):
        self.dataset_dir = dataset_dir
        self.fce_download_dir = fce_download_dir
        self.train_file = os.path.join(dataset_dir, train_file)
        self.dev_file = os.path.join(dataset_dir, dev_file)
        self.test_file = os.path.join(dataset_dir, test_file)
        self.url = dataset_url
        self.logger = logging.getLogger(self.__class__.__name__)
        self.dataset = None

    def download_and_extract(self) -> None:
        """
        Download the file and extract it into a directory if it does not exist
        """
        if not os.path.exists(self.fce_download_dir):
            os.makedirs(self.fce_download_dir, exist_ok=True)
            self.logger.info("Downloading BEA 2019 dataset...")
            tar_path = os.path.join(self.fce_download_dir, "bea19.tar.gz")
            with requests.get(self.url, stream=True) as r:
                r.raise_for_status()  # Fail early instead of saving an HTTP error page as the archive
                with open(tar_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            self.logger.info("Extracting dataset...")
            with tarfile.open(tar_path, 'r:gz') as tar:
                tar.extractall(path=self.fce_download_dir)
            os.remove(tar_path)
            self.logger.info("BEA dataset downloaded and extracted successfully.")
        else:
            self.logger.warning("BEA dataset already exists locally.")

    def load_dataset(self) -> DatasetDict:
        """ Create the dataset in the Transformers format """
        dataset = DatasetDict({
            'train': Dataset.from_list(self._parse_m2_file(self.train_file)),
            'validation': Dataset.from_list(self._parse_m2_file(self.dev_file)),
            'test': Dataset.from_list(self._parse_m2_file(self.test_file))
        })
        self.logger.info(f"Loaded BEA dataset: {len(dataset['train'])} train, {len(dataset['validation'])} dev, {len(dataset['test'])} test")
        self.dataset = dataset
        return dataset
    
    def save_dataset(self, output_dir: str = FCE_DOWNLOAD_DATASET_DIR) -> None:
        """
        Parses the M2 files (if not loaded yet) and saves the resulting DatasetDict to disk in HuggingFace format.
        Args:
            output_dir (str): Directory to save the dataset. The data is written to "<output_dir>/parsed_fce_dataset".
        """
        if self.dataset is None:
            self.dataset = self.load_dataset()
        # os.path.join works whether or not output_dir ends with a path separator
        save_path = os.path.join(output_dir, "parsed_fce_dataset")
        self.dataset.save_to_disk(save_path)
        self.logger.info(f"Saved parsed dataset to {save_path}")

    def _parse_m2_file(self, filepath: str) -> List[Dict[str, str]]:
        """
        Parses an M2 file and returns a list of {source, target} dictionaries,
        where each 'target' corresponds to one annotator's corrections.
        """
        data = []
        with open(filepath, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        sentence = ""
        edits_by_annotator = dict()

        for line in lines + ['\n']:  # Add sentinel newline
            line = line.strip()
            if line.startswith(SENTENCE_TAG):
                if sentence:
                    if edits_by_annotator:
                        for annotator_id, edits in edits_by_annotator.items():
                            corrected = self._apply_m2_edits(sentence, edits)
                            data.append({'source': sentence, 'target': corrected})
                    else:
                        data.append({'source': sentence, 'target': sentence})
                sentence = line[2:]
                edits_by_annotator = dict()
            elif line.startswith(ANNOTATION_TAG):
                parts = line[2:].split(SEP)
                span = list(map(int, parts[0].split()))
                error_type = parts[1]
                correction = parts[2]
                annotator_id = int(parts[-1])
                if annotator_id not in edits_by_annotator:
                    edits_by_annotator[annotator_id] = []
                edits_by_annotator[annotator_id].append((span, correction, error_type))
            elif line == "" and sentence:
                if edits_by_annotator:
                    for annotator_id, edits in edits_by_annotator.items():
                        corrected = self._apply_m2_edits(sentence, edits)
                        data.append({'source': sentence, 'target': corrected})
                else:
                    data.append({'source': sentence, 'target': sentence})
                sentence = ""
                edits_by_annotator = dict()
        return data

    def _apply_m2_edits(self, sentence: str, edits: list):
        """
        Applies M2 format edits to the original sentence.
        :param sentence: Original sentence (string)
        :param edits: List of (span, correction, error_type)
        :return: Corrected sentence
        """
        tokens = sentence.strip().split()
        offset = 0
        for (span, correction, error_type) in edits:
            start, end = span
            if error_type == NO_EDIT_TAG or (start == -1 and end == -1):
                continue
            # Adjust indices by current offset
            start_adj = start + offset
            end_adj = end + offset
            correction_tokens = correction.strip().split() if correction.strip() else []
            tokens = tokens[:start_adj] + correction_tokens + tokens[end_adj:]
            offset += len(correction_tokens) - (end - start)
        return ' '.join(tokens)
    
    @staticmethod
    def most_common_edit_types(gold_m2_path: str, n: int = 10) -> List[Tuple[str, int]]:
        """
        Returns the most common edit types in the gold M2 file.
        :param gold_m2_path: Path to the gold M2 file.
        :param n: Number of most common edit types to return.
        :return: List of tuples (edit_type, count).
        """
        from collections import Counter
        edit_types = Counter()
        
        with open(gold_m2_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.startswith(ANNOTATION_TAG):
                    parts = line[2:].split(SEP)
                    error_type = parts[1]
                    edit_types[error_type] += 1
        
        return edit_types.most_common(n)


In [None]:
loader = M2DatasetLoader()

loader.download_and_extract()
dataset = loader.load_dataset()
loader.save_dataset(output_dir=os.getcwd() + "/data/fce/")


## Some exploration of the data and the parsing

Each sentence is one `S` line followed by one or more `A` (annotation) lines, and these blocks are separated by a blank line.
If an annotation carries the *noop* tag, then there is no edit.


We need a set of functions that apply these corrections to the sentence, so we can build a dataset with "source" and "target" fields.
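
To make the format concrete, here is a minimal sketch of reading one block into a sentence and its edit records. The helper name `parse_m2_block` and the dictionary keys are assumptions made for illustration; the loader class above does the real parsing.

```python
from typing import List, Tuple

SEP = "|||"  # field separator used inside annotation lines


def parse_m2_block(block: str) -> Tuple[str, List[dict]]:
    """Split one M2 block into its source sentence and a list of edit records."""
    sentence = ""
    edits = []
    for line in block.strip().splitlines():
        if line.startswith("S "):
            sentence = line[2:]
        elif line.startswith("A "):
            fields = line[2:].split(SEP)
            start, end = map(int, fields[0].split())
            edits.append({
                "span": (start, end),        # token indices in the source sentence
                "error_type": fields[1],     # e.g. R:PREP, or noop when nothing changes
                "correction": fields[2],     # replacement text, empty for a deletion
                "annotator": int(fields[-1]),
            })
    return sentence, edits


block = """S I am writing in order to express my disappointment about your musical show " Over the Rainbow " .
A 9 10|||R:PREP|||with|||REQUIRED|||-NONE-|||0"""

sentence, edits = parse_m2_block(block)
print(sentence.split()[9], "->", edits[0]["correction"])
```

The span `9 10` points at the token *about*, which the annotator replaces with *with*.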

In [405]:
# The data needs to be downloaded before exploring it (M2DatasetLoader.download_and_extract automates this)
train_data = os.path.join(FCE_DATASET_DIR, FCE_TRAIN_NAME)


with open(train_data, 'r') as file:
    data = file.read()

print(data[:1500])

S Dear Sir or Madam ,
A -1 -1|||noop|||-NONE-|||REQUIRED|||-NONE-|||0

S I am writing in order to express my disappointment about your musical show " Over the Rainbow " .
A 9 10|||R:PREP|||with|||REQUIRED|||-NONE-|||0

S I saws the show 's advertisement hanging up of a wall in London where I was spending my holiday with some friends . I convinced them to go there with me because I had heard good references about your Company and , above all , about the main star , Danny Brook .
A 1 2|||R:VERB:TENSE|||saw|||REQUIRED|||-NONE-|||0
A 8 9|||R:PREP|||on|||REQUIRED|||-NONE-|||0
A 36 37|||R:NOUN|||reviews|||REQUIRED|||-NONE-|||0
A 37 38|||R:PREP|||of|||REQUIRED|||-NONE-|||0
A 45 46|||R:PREP|||because of|||REQUIRED|||-NONE-|||0

S The problems started in the box office , where we asked for the discounts you announced in the advertisement , and the man who was selling the tickets said that they did n't exist .
A 3 4|||R:PREP|||at|||REQUIRED|||-NONE-|||0

S Moreover , the show was delayed forty -

To see what the parser is doing, let's take two examples: one without edits and one with several edits.

Without edits (the first record of the file):

```
S Dear Sir or Madam ,
A -1 -1|||noop|||-NONE-|||REQUIRED|||-NONE-|||0
```

In [406]:
print("Wrong")
print(dataset['train'][0].get("source"))
print("Corrected")
print(dataset['train'][0].get("target"))

Wrong
Dear Sir or Madam ,
Corrected
Dear Sir or Madam ,


Many edits, like the third record of the file:

"I saws the show 's advertisement hanging up of a wall in London where I was spending my holiday with some friends .
I convinced them to go there with me because I had heard good references about your Company and , above all , about the main star , Danny Brook ."

```
A 1 2|||R:VERB:TENSE|||saw|||REQUIRED|||-NONE-|||0
A 8 9|||R:PREP|||on|||REQUIRED|||-NONE-|||0
A 36 37|||R:NOUN|||reviews|||REQUIRED|||-NONE-|||0
A 37 38|||R:PREP|||of|||REQUIRED|||-NONE-|||0
A 45 46|||R:PREP|||because of|||REQUIRED|||-NONE-|||0
```

Explanation:

It basically replaces *saws* with *saw*, *of* with *on*, *references* with *reviews*, the first *about* with *of*, and finally the second *about* with *because of*.
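
Spans always refer to token positions in the original sentence, so once an edit changes the number of tokens, every later span has to be shifted. The sketch below shows that bookkeeping with a running offset. It assumes edits arrive sorted by position and do not overlap; `apply_edits` is an illustrative stand-in, and the second sentence is a made-up fragment rather than a record from the file.

```python
from typing import List, Tuple

# (start, end, correction): token indices refer to the ORIGINAL sentence
Edit = Tuple[int, int, str]


def apply_edits(sentence: str, edits: List[Edit]) -> str:
    tokens = sentence.split()
    offset = 0  # how far indices have shifted because of earlier edits
    for start, end, correction in edits:
        if start == -1 and end == -1:  # noop annotation, nothing to change
            continue
        new_tokens = correction.split()  # empty list means a deletion
        tokens[start + offset:end + offset] = new_tokens
        offset += len(new_tokens) - (end - start)
    return " ".join(tokens)


# Deletion taken from the training file: the correction field is empty
deleted = apply_edits("If you do n't agree , I will act consequently .", [(9, 10, "")])

# Hypothetical fragment: the first edit adds a token, so the second span shifts by one
shifted = apply_edits("about the main star , Danny Brok .",
                      [(0, 1, "because of"), (6, 7, "Brook")])

print(deleted)
print(shifted)
```

Without the offset, the second edit in the made-up fragment would overwrite *Danny* instead of *Brok*.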

In [407]:
print("Wrong")
print(dataset['train'][2].get("source"))
print("Corrected")
print(dataset['train'][2].get("target"))
#print(dataset['validation'][0])
#print(dataset['test'][0])

Wrong
I saws the show 's advertisement hanging up of a wall in London where I was spending my holiday with some friends . I convinced them to go there with me because I had heard good references about your Company and , above all , about the main star , Danny Brook .
Corrected
I saw the show 's advertisement hanging up on a wall in London where I was spending my holiday with some friends . I convinced them to go there with me because I had heard good reviews of your Company and , above all , because of the main star , Danny Brook .


Let's see another example where something is deleted (I looked for it in the file).

*consequently* needs to be deleted

```
S If you do n't agree , I will act consequently .
A 9 10|||U:ADV||||||REQUIRED|||-NONE-|||0
```

In [408]:
print("Wrong")
print(dataset['train'][8].get("source"))
print("Corrected")
print(dataset['train'][8].get("target"))
#print(dataset['validation'][0])
#print(dataset['test'][0])

Wrong
If you do n't agree , I will act consequently .
Corrected
If you do n't agree , I will act .


With punctuation

In this case, we need to add a comma and a quote

```

S She began to read " Dear Carolin ..
A 4 4|||M:PUNCT|||,|||REQUIRED|||-NONE-|||0
A 8 8|||M:PUNCT|||"|||REQUIRED|||-NONE-|||0

```



In [427]:
print("Wrong")
print(dataset['train'][15].get("source"))
print("Corrected")
print(dataset['train'][15].get("target"))
#print(dataset['validation'][0])
#print(dataset['test'][0])

Wrong
She began to read " Dear Carolin ..
Corrected
She began to read , " Dear Carolin " ..


In [444]:
from datasets import load_from_disk
parsed_dataset = load_from_disk(os.path.join(os.getcwd(), "data/fce/parsed_fce_dataset"))

parsed_dataset['test'][1266]  # Inspect one example from the test split of the parsed dataset

{'source': 'There are eye - max cinema , museums , garelty .',
 'target': 'There are eye - max cinema , museums , a gallery .'}

Now that we have a parser and loader that build the dataset we need from the M2 files, we will create a simple preprocessor for the T5 model.

# T5 PreProcessor

Preprocessor class that handles tokenization of the data used by the T5 model

In [4]:
from transformers import PreTrainedTokenizer
from datasets import DatasetDict
from typing import Dict

class T5Preprocessor:
    """
    T5 Preprocessor for the data. Could be implemented as abstract class to have multiple preprocessors.

    Args:
        tokenizer (PreTrainedTokenizer): The tokenizer to use.
        truncation (bool): Whether to truncate the input. Default is True.
        padding (str): The padding to use. Default is "max_length".

    Note: the maximum length is not a constructor argument; it is passed to `preprocess`.

    """


    def __init__(self, tokenizer: PreTrainedTokenizer, truncation: bool = True, padding: str = "max_length"):
        self.tokenizer = tokenizer
        self.truncation = truncation
        self.padding = padding
        self.logger = logging.getLogger(self.__class__.__name__)

    def _preprocess_function(self, examples: Dict[str, List[str]], max_length: int) -> Dict[str, List[List[int]]]:
        """
        Preprocess examples tokenizing them using the instance parameters.
        """
        inputs = ["correct grammar: " + s for s in examples['source']]
        targets = [t for t in examples['target']]
        model_inputs = self.tokenizer(inputs, max_length=max_length, truncation=self.truncation, padding=self.padding)
        labels = self.tokenizer(targets, max_length=max_length, truncation=self.truncation, padding=self.padding)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
    
    @staticmethod
    def _get_max_input_length(dataset: DatasetDict) -> int:
        """
        Get the maximum length of the 'source' strings across all splits.
        Note that this counts characters, not tokens.
        """
        max_length = 0
        for split in dataset:
            split_max = max(len(x) for x in dataset[split]['source'])
            #self.logger.info(f"Max input_ids length in {split}: {split_max}")
            max_length = max(max_length, split_max)
            #self.logger.info(f"Overall max input_ids length: {max_length}")
        return max_length
    
    
    def save_tokenized_dataset(self, dataset: DatasetDict, output_dir: str) -> None:
        """
        Save the tokenized dataset to disk.
        """
        self.logger.info(f"Saving tokenized dataset to {output_dir}...")
        dataset.save_to_disk(output_dir)
        self.logger.info("Tokenized dataset saved successfully.")

    def preprocess(self, dataset: DatasetDict, max_length: int) -> DatasetDict:
        """
        Preprocess the dataset.
        """
        self.logger.info("Preprocessing dataset...")
        tokenized_dataset = dataset.map(lambda examples: self._preprocess_function(examples, max_length), batched=True)
        self.logger.info("Dataset preprocessed")
        return tokenized_dataset

In [14]:
from transformers import T5Tokenizer
from datasets import load_from_disk

dataset = load_from_disk(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce/parsed_fce_dataset"))
max_length = T5Preprocessor._get_max_input_length(dataset)  # This will give you the max input length in the dataset for saving the preprocessed dataset to train the model

preprocessor = T5Preprocessor(tokenizer=T5Tokenizer.from_pretrained(MODEL_NAME))



In [None]:
preprocessed_dataset = preprocessor.preprocess(dataset, max_length=max_length)
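# No leading slash in the second argument: os.path.join would otherwise discard FCE_DOWNLOAD_DATASET_DIR
preprocessed_dataset.save_to_disk(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce/preprocessed_fce_dataset"))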
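
One thing that may be worth considering: `_preprocess_function` pads the targets to `max_length`, so the `labels` contain pad token ids. Hugging Face sequence-to-sequence models skip label positions set to -100 when computing the loss, so a common extra step is to replace the padding ids with -100. The sketch below is only an illustration with made-up token ids; `mask_padding` is a hypothetical helper and not part of the class above.

```python
from typing import List

IGNORE_INDEX = -100  # label value that the cross-entropy loss skips in Hugging Face models


def mask_padding(label_ids: List[List[int]], pad_token_id: int) -> List[List[int]]:
    """Replace padding ids in the labels so they do not contribute to the loss."""
    return [
        [IGNORE_INDEX if token == pad_token_id else token for token in sequence]
        for sequence in label_ids
    ]


# Made-up ids for illustration; for T5 the pad token id is 0 and the end-of-sequence id is 1
labels = [[37, 1712, 1, 0, 0], [8, 1, 0, 0, 0]]
masked = mask_padding(labels, pad_token_id=0)
print(masked)
```

In the preprocessor this would be applied to `labels["input_ids"]` with `self.tokenizer.pad_token_id` before assigning `model_inputs["labels"]`.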

Now what about the training class?

# Trainer class

Class created to handle training in a repeatable way

In [7]:
from typing import Tuple, Optional
from transformers import T5ForConditionalGeneration, Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import DatasetDict
import torch


class T5Trainer:
    def __init__(self, model_name: str = MODEL_NAME, output_dir: str = FINETUNED_MODEL_OUTPUT_DIR,
                 logging_dir: str = "./finetune_logs", save_strategy: str = "epoch", resume_from_dir: Optional[str] = None,
                 batch_size: int = 8, mixed_precision: bool = False, early_stopping: bool = True,
                 eval_strategy: str = "epoch", early_stopping_patience: int = 3, early_stopping_threshold: float = 0.0, 
                 metric_for_best_model: str = "eval_loss", greater_is_better: bool = False):
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.output_dir = output_dir
        self.logging_dir = logging_dir
        self.save_strategy = save_strategy
        self.logger = logging.getLogger(self.__class__.__name__)
        self.resume_from_dir = resume_from_dir
        self.batch_size = batch_size
        self.mixed_precision = mixed_precision
        self.early_stopping = early_stopping
        self.eval_strategy = eval_strategy
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_threshold = early_stopping_threshold
        self.metric_for_best_model = metric_for_best_model
        self.greater_is_better = greater_is_better
        

    def _check_if_gpu_available(self) -> bool:
        """
        Check if GPU is available.
        """
        available = torch.cuda.is_available()

        if available:
            self.logger.info("GPU available")
        else:
            self.logger.warning("GPU not available")

        return available

    def _check_if_model_already_available(self) -> bool:
        """
        Check if model is already available.
        """
        available = os.path.exists(self.output_dir)

        if available:
            self.logger.info("Model already available")
        else:
            self.logger.warning("Model not available")

        return available

    def _create_unique_dir_for_model(self) -> str:

        """
        Create a unique directory for the model based on the current time.
        """
        current_time = time.strftime("%Y%m%d-%H%M%S")
        unique_dir = os.path.join(self.output_dir, current_time)
        os.makedirs(unique_dir, exist_ok=True)
        self.logger.info(f"Created unique directory for model: {unique_dir}")
        return unique_dir


    def train(self, tokenized_dataset:DatasetDict, epochs=3, learning_rate: float = 3e-4, eval_strategy = "epoch") -> Tuple[str, T5ForConditionalGeneration, T5Tokenizer]:
        """
        Train the T5 model.
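        Args:
            tokenized_dataset (DatasetDict): The tokenized dataset.
            epochs (int): The number of epochs to train for.
            learning_rate (float): The learning rate passed to TrainingArguments.
            eval_strategy (str): Currently unused; the value given to the constructor is applied.
        Returns:
            Tuple[str, T5ForConditionalGeneration, T5Tokenizer]: The model directory (ID), the trained model and the tokenizer.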
        """
        self.logger.info("Checking if GPU is available...")
        self._check_if_gpu_available()
        #self.logger.info("Checking if model is already available...")
        #self._check_if_model_already_available()
        # Determine the output directory based on whether we're resuming or starting fresh
        if self.resume_from_dir and os.path.exists(self.resume_from_dir):
            self.output_dir = self.resume_from_dir
            self.logger.info(f"Resuming training from directory: {self.output_dir}")
        else:
            self.output_dir = self._create_unique_dir_for_model()
            self.logger.info(f"Starting new training in directory: {self.output_dir}")

        # self.output_dir = self._create_unique_dir_for_model()


        # Collect every argument first so TrainingArguments validates them together at construction time
        training_kwargs = dict(
            output_dir=self.output_dir,
            per_device_train_batch_size=self.batch_size,
            num_train_epochs=epochs,
            eval_strategy=self.eval_strategy,
            save_strategy=self.save_strategy,
            logging_dir=self.logging_dir,
            learning_rate=learning_rate,
            report_to="none",  # Needed to avoid the wandb API key request
            fp16=self.mixed_precision,  # Enable mixed precision training if specified
        )

        if self.early_stopping:
            # EarlyStoppingCallback needs the best model to be tracked and reloaded at the end
            training_kwargs.update(
                load_best_model_at_end=True,
                metric_for_best_model=self.metric_for_best_model,
                greater_is_better=self.greater_is_better,
                save_total_limit=1,
            )

        training_args = TrainingArguments(**training_kwargs)
        self.logger.info("Training arguments set")

        if self.mixed_precision:
            self.logger.info("Mixed precision training enabled")


        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized_dataset['train'],
            eval_dataset=tokenized_dataset['validation']
        )
        
        if self.early_stopping:
            trainer.add_callback(EarlyStoppingCallback(self.early_stopping_patience, self.early_stopping_threshold))
            self.logger.info("Early stopping callback added")

        # Resume training if a directory to resume from was given and exists
        if self.resume_from_dir and os.path.exists(self.resume_from_dir):
            # The Trainer restores the model, optimizer and scheduler state from the checkpoint,
            # so there is no need to load the state dicts explicitly. Note that a string value
            # must point at a saved checkpoint (for example a "checkpoint-<step>" folder).
            trainer.train(resume_from_checkpoint=self.output_dir)
        else:
            trainer.train()


        self.logger.info("T5 model trained")
        trainer.save_model(self.output_dir)
        self.model.save_pretrained(self.output_dir)
        self.tokenizer.save_pretrained(self.output_dir)
        return self.output_dir, self.model, self.tokenizer

In [None]:
from datasets import load_from_disk

preprocessed_dataset = load_from_disk(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce/preprocessed_fce_dataset"))

In [None]:


#trainer = T5Trainer(resume_from_dir="./t5_finetuned/20250530-172532")
# trainer = T5Trainer(batch_size=16, resume_from_dir="/Users/isaac/Developer/sample/t5_finetuned/20250530-121239/checkpoint-591")  # Enable mixed precision training

trainer = T5Trainer(batch_size=8)

model_dir, model, tokenizer = trainer.train(preprocessed_dataset)
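
The trainer uses early stopping with a patience of 3 evaluations on `eval_loss`. To show what that means in practice, here is a simplified, library-free sketch of patience-based stopping. It is written in the spirit of the callback and is not the Transformers implementation; the loss values are invented.

```python
from typing import List, Optional


def epoch_of_early_stop(eval_losses: List[float], patience: int = 3,
                        threshold: float = 0.0) -> Optional[int]:
    """Return the 1-based epoch at which training would stop, or None if it runs to the end."""
    best = None
    evaluations_without_improvement = 0
    for epoch, loss in enumerate(eval_losses, start=1):
        # Lower is better for a loss; the improvement must exceed the threshold
        if best is None or (loss < best and best - loss > threshold):
            best = loss
            evaluations_without_improvement = 0
        else:
            evaluations_without_improvement += 1
            if evaluations_without_improvement >= patience:
                return epoch
    return None


# Invented validation losses, one per epoch
losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.65, 0.64]
stop_epoch = epoch_of_early_stop(losses, patience=3)
print(stop_epoch)
```

With these numbers the loss stops improving after epoch 3, so three evaluations later (epoch 6) training ends and the last value is never reached. Note that with the default of `epochs=3` in `train`, a patience of 3 cannot trigger; it only matters for longer runs.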

# Inference Engines

## Base Inference Engine and Helper

In [8]:
from abc import ABC, abstractmethod


class BaseInferenceEngine(ABC):
    def __init__(self):
        self.logger = logging.getLogger(self.__class__.__name__)
    @abstractmethod
    def correct_sentence(self, sentence: str) -> str:
        pass
    @abstractmethod
    def batch_correct(self, sentences: list, batch_size: int = 16) -> list:
        pass

## T5 Inference Engine

The inference engine orchestrates model loading and exposes methods for single-sentence and batch inference, so correcting text is a single method call.

In [9]:
from typing import Dict, Union, Optional


class T5InferenceEngine(BaseInferenceEngine):
    def __init__(self, model_dir: str, max_length: Optional[int] = None):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.t5_model = T5ForConditionalGeneration.from_pretrained(model_dir)
        self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
        self.max_length = max_length if max_length is not None else self.tokenizer.model_max_length
        self.logger.info(f"T5 model loaded from {model_dir} with max length {self.max_length}")
        

    def correct_sentence(self, sentence:str) -> str:
        """
        Corrects a sentence using the T5 model.
        Args:
            sentence (str): The sentence to correct.
        Returns:
            str: The corrected sentence.
        """
        self.logger.info(f"Correcting sentence: {sentence}")
        inputs = self.tokenizer("correct grammar: " + sentence, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length)
        outputs = self.t5_model.generate(**inputs, max_length=self.max_length)
        corrected_sentence = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        self.logger.info(f"Corrected sentence: {corrected_sentence}")
        return corrected_sentence
    
    def batch_correct(self, sentences: list, batch_size: int = 16) -> list:
        """
        Corrects a batch of sentences using the T5 model.
        Args:
            sentences (list): List of sentences to correct.
            batch_size (int): Number of sentences per batch.
        Returns:
            list: List of corrected sentences.
        """
        corrected = []
        for i in range(0, len(sentences), batch_size):
            self.logger.info(f"Processing batch {i // batch_size + 1} with size {min(batch_size, len(sentences) - i)}")
            batch = sentences[i:i+batch_size]
            inputs = self.tokenizer(
                ["correct grammar: " + s for s in batch],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length
            )
            outputs = self.t5_model.generate(**inputs, max_length=self.max_length)
            batch_corrected = [self.tokenizer.decode(o, skip_special_tokens=True) for o in outputs]
            corrected.extend(batch_corrected)
        self.logger.info(f"Batch correction completed. Total corrected sentences: {len(corrected)}")
        return corrected
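
The batching in `batch_correct` is plain list slicing: the input is walked in steps of `batch_size`, and the last batch is simply shorter when the number of sentences is not a multiple of the batch size. A minimal sketch of just that part, with a hypothetical helper name and placeholder sentences:

```python
from typing import Iterator, List


def iter_batches(items: List[str], batch_size: int = 16) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Placeholder inputs; the real engine would prefix each one with "correct grammar: "
sentences = [f"sentence {i}" for i in range(5)]
batches = list(iter_batches(sentences, batch_size=2))
sizes = [len(batch) for batch in batches]
print(sizes)
```

Because the tokenizer is called with `padding=True` per batch, each batch is only padded to its own longest sentence.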

In [16]:
model_dir = "/Users/isaac/Developer/GEC-system/models/improved"
t5_inference_engine = T5InferenceEngine(model_dir=model_dir, max_length=650)
t5_inference_engine.max_length

650

In [20]:
wrong_sentence = "You is a apple."
theorycally_corrected_sentence = "I have an apple."


corrected_sentence = t5_inference_engine.correct_sentence(wrong_sentence)

print(f"Wrong sentence: {wrong_sentence}")
print(f"Corrected sentence: {corrected_sentence}")

2025-06-01 14:20:44,073 - INFO - T5InferenceEngine - correct_sentence - Correcting sentence: You is a apple.
2025-06-01 14:20:44,397 - INFO - T5InferenceEngine - correct_sentence - Corrected sentence: You are an apple.
Wrong sentence: You is a apple.
Corrected sentence: You are an apple.


In [163]:
#%pip install nltk
import nltk
nltk.download('punkt')


[nltk_data] Downloading package punkt to /Users/isaac/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

## Llama 3 Inference Engine

See the docs to run and query Llama 3 8B locally using Ollama:

https://ollama.com/library/llama3
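
Before the full class, here is a small sketch of the two pieces that matter most when talking to the `/api/generate` endpoint: building the request body and unpacking the answer, whose `response` field arrives as a JSON string. No server is contacted; the prompt template and the response below are hand-written stand-ins, and the helper names are assumptions.

```python
import json
from typing import Any, Dict

PLACEHOLDER = "<text_to_replace>"  # same marker as TEXT_TO_REPLACE_IN_PROMPT


def build_payload(template: str, sentence: str) -> Dict[str, Any]:
    """Fill the prompt template and assemble the request body."""
    return {
        "model": "llama3",
        "prompt": template.replace(PLACEHOLDER, sentence),
        "stream": False,
        # JSON schema the model output should follow
        "format": {
            "type": "object",
            "properties": {
                "original_text": {"type": "string"},
                "corrected_text": {"type": "string"},
            },
            "required": ["original_text", "corrected_text"],
        },
        "options": {"temperature": 0.2, "seed": 0},
    }


def extract_correction(response_data: Dict[str, Any]) -> str:
    """Return the corrected text, or an empty string if the answer is not usable."""
    raw = response_data.get("response", "")
    try:
        parsed = json.loads(raw) if isinstance(raw, str) else raw
    except json.JSONDecodeError:
        return ""
    return parsed.get("corrected_text", "") if isinstance(parsed, dict) else ""


# Hypothetical template and a hand-written response for illustration
template = "Correct the grammar of this text: <text_to_replace>"
payload = build_payload(template, "You is a apple.")
fake_response = {"response": json.dumps(
    {"original_text": "You is a apple.", "corrected_text": "You are an apple."})}
corrected = extract_correction(fake_response)
print(payload["prompt"])
print(corrected)
```

Returning an empty string on malformed JSON is a design choice for this sketch; raising an error would be equally reasonable if silent failures are a concern during evaluation.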

In [10]:
from typing import Union, Dict
import json as _json
import asyncio
import aiohttp
import requests

class Llama3InferenceEngine(BaseInferenceEngine):
    model_name: str = "llama3"
    stream: bool = False
    response_format: dict = {
        "type": "object",
        "properties": {
            "original_text": {"type": "string"},
            "corrected_text": {"type": "string"}
        },
        "required": ["original_text", "corrected_text"]
    }
    
    options: Dict[str, Union[str, int, float]] = {
        "temperature": 0.2,
        "top_k": 20,
        "top_p": 0.5,
        "seed": 0
    }
    
    
    def __init__(self, model_endpoint: str, prompt_path: str):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.model_endpoint = model_endpoint
        self.prompt_path = prompt_path
        self.prompt = self._parse_prompt()
        self.logger.info(f"Llama3InferenceEngine initialized with model endpoint: {self.model_endpoint} and prompt path: {self.prompt_path}, options: {self.options}")
        
    def _parse_prompt(self) -> str:
        with open(self.prompt_path, 'r') as file:
            prompt = file.read()
        if not prompt:
            self.logger.error("Prompt is empty. Please check the prompt file.")
            raise ValueError("Prompt is empty. Please check the prompt file.")
        return prompt
    
    def _replace_prompt_variables(self, sentence: str) -> str:
        return self.prompt.replace(f"{TEXT_TO_REPLACE_IN_PROMPT}", sentence)

    def send_correct_request(self, sentence: str) -> Dict[str, Union[str, Dict[str, str]]]:
        prompt = self._replace_prompt_variables(sentence)
        try:
            response = requests.post(
                self.model_endpoint, 
                json={
                    "model": self.model_name,
                    "prompt": str(prompt),
                    "stream": self.stream,
                    "format": self.response_format,
                    "options": self.options
                }
            )
            response.raise_for_status()  # HTTPError is a RequestException, so it is handled below
            response_data = response.json()
            if not response_data:
                self.logger.error("No corrected sentence returned from the model.")
                raise ValueError("No corrected sentence returned from the model.")
            if "response" in response_data and isinstance(response_data["response"], str):
                response_data["response"] = _json.loads(response_data["response"])
            return response_data
        except requests.exceptions.RequestException as e:
            self.logger.error(f"Error during model inference: {e}")
            raise RuntimeError(f"Error during model inference: {e}")
        
    def correct_sentence(self, sentence: str) -> str:
        self.logger.info(f"Correcting sentence: {sentence}")
        response = self.send_correct_request(sentence)
        if isinstance(response, dict):
            corrected_sentence = response.get("response", {}).get("corrected_text", "")
            self.logger.info(f"Corrected sentence: {corrected_sentence}")
            return corrected_sentence
        else:
            return ""
        
    def batch_correct(self, sentences: list, batch_size: int = 16) -> list:
        """Delegate to `BaseInferenceEngine.batch_correct`.

        Llama3InferenceEngine does not support full batch inference yet, so this
        override only keeps the interface consistent with BaseInferenceEngine.
        For concurrent requests, see `async_batch_correct`.

        Args:
            sentences (list): Sentences to correct.
            batch_size (int, optional): Batch size passed to the base implementation. Defaults to 16.

        Returns:
            list: Corrected sentences.
        """
        return super().batch_correct(sentences, batch_size)

    # --- ASYNC BATCH INFERENCE ---
    async def async_correct_sentence(self, session, sentence: str) -> str:
        prompt = self._replace_prompt_variables(sentence)
        payload = {
            "model": self.model_name,
            "prompt": str(prompt),
            "stream": self.stream,
            "format": self.response_format,
            "options": self.options
        }
        self.logger.info(f"Sending async request for sentence: {sentence}")
        async with session.post(self.model_endpoint, json=payload) as resp:
            response_data = await resp.json()
            if not response_data:
                return ""
            if "response" in response_data and isinstance(response_data["response"], str):
                import json as _json
                response_data["response"] = _json.loads(response_data["response"])
                corrected_text = response_data.get("response", {}).get("corrected_text", "")
                self.logger.info(f"Async corrected sentence: {corrected_text}")
            
                return corrected_text if corrected_text else ""
            else:
                self.logger.error("Response format is not as expected.")
                return ""

    async def async_batch_correct(self, sentences, max_concurrent=5):
        timeout = aiohttp.ClientTimeout(total=60)
        semaphore = asyncio.Semaphore(max_concurrent)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async def sem_task(sentence):
                async with semaphore:
                    return await self.async_correct_sentence(session, sentence)
            tasks = [sem_task(s) for s in sentences]
            return await asyncio.gather(*tasks)
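
The concurrency pattern used by `async_batch_correct` (a semaphore that caps in-flight requests, plus `asyncio.gather` to collect results in input order) can be tried without a running model. The following is a minimal sketch under assumptions: `fake_correct` is a hypothetical stand-in for the HTTP call, and `bounded_gather` is an illustrative name, not part of the engine.

```python
import asyncio


async def fake_correct(sentence: str) -> str:
    """Hypothetical stand-in for the HTTP request to the model endpoint."""
    await asyncio.sleep(0.01)  # simulate network latency
    return sentence.capitalize()


async def bounded_gather(sentences, max_concurrent=5):
    """Run fake_correct on every sentence with at most max_concurrent calls in flight."""
    semaphore = asyncio.Semaphore(max_concurrent)
    in_flight = 0
    peak = 0  # highest number of simultaneous calls observed

    async def worker(sentence):
        nonlocal in_flight, peak
        async with semaphore:
            in_flight += 1
            peak = max(peak, in_flight)
            try:
                return await fake_correct(sentence)
            finally:
                in_flight -= 1

    # gather returns results in the same order as the input sentences
    results = await asyncio.gather(*(worker(s) for s in sentences))
    return results, peak


results, peak = asyncio.run(
    bounded_gather([f"sentence {i}" for i in range(12)], max_concurrent=3)
)
print(results[:2], "peak concurrency:", peak)
```

Inside a Jupyter notebook an event loop is already running, so `asyncio.run` raises an error there; use `await bounded_gather(...)` directly in a cell instead.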

In [61]:
llama3_engine = Llama3InferenceEngine(model_endpoint=LLAMA3_ENDPOINT, prompt_path=GENERAL_PROMPT_PATH)

2025-06-01 20:02:56,687 - INFO - Llama3InferenceEngine - __init__ - Llama3InferenceEngine initialized with model endpoint: http://127.0.0.1:11434/api/generate and prompt path: /Users/isaac/Developer/GEC-system/config/prompt_general.txt, options: {'temperature': 0.0, 'seed': 123, 'top_k': 10, 'top_p': 0.5}


In [62]:
llama3_engine.send_correct_request("Helo")

{'model': 'llama3',
 'created_at': '2025-06-02T02:03:03.061195Z',
 'response': {'original_text': 'Helo', 'corrected_text': 'Hello'},
 'done': True,
 'done_reason': 'stop',
 'context': [128006, 882, 128007, 271, 2675, 527, 4689, 7580, ...

In [215]:
llama3_engine.correct_sentence("Helo")

2025-05-31 00:01:11,833 - INFO - Llama3InferenceEngine - correct_sentence - Correcting sentence: Helo
2025-05-31 00:01:12,732 - INFO - Llama3InferenceEngine - correct_sentence - Corrected sentence: Hello


'Hello'
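
As the output above shows, the endpoint returns the model answer in the `response` field, which arrives as a JSON string and has to be decoded before `corrected_text` can be read. The helper below is a small defensive sketch of that step; `extract_corrected_text` is a hypothetical name and not part of the engine. It only assumes the `response` and `corrected_text` field names seen in the output.

```python
import json
from typing import Any, Dict


def extract_corrected_text(response_data: Dict[str, Any]) -> str:
    """Return `corrected_text` from a response payload, or "" if it is missing or malformed."""
    payload = response_data.get("response")
    if isinstance(payload, str):
        try:
            payload = json.loads(payload)  # the model answer arrives as a JSON string
        except json.JSONDecodeError:
            return ""
    if not isinstance(payload, dict):
        return ""
    corrected = payload.get("corrected_text", "")
    return corrected if isinstance(corrected, str) else ""


raw = {
    "model": "llama3",
    "response": '{"original_text": "Helo", "corrected_text": "Hello"}',
    "done": True,
}
print(extract_corrected_text(raw))
```

Returning an empty string on malformed output keeps a long evaluation run going; the affected sentence then simply counts as a wrong prediction.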

# Evaluation

To evaluate the systems we need to choose metrics. We use three:

- Exact Match

Why do we use it?
Exact match is strict: it measures how often the system output is identical to the reference correction in the test data. It is fully deterministic, which makes it a good fit for systems where only perfect corrections are acceptable.

- GLEU

Why do we use it?
It is a sentence-level variant of BLEU based on n-gram overlap. It is less strict than exact match, so it gives partial credit and reflects incremental improvements better than an all-or-nothing comparison.

- ERRANT score

Why do we use it?
It is a standard metric for GEC that evaluates how well the system identifies and corrects specific errors, giving a detailed edit-level assessment. It compares the edits in the system output against the edits in a reference M2 file (which we had to create).



Note:
*The first two evaluations are done with the Evaluator class defined below, using the precomputed predictions included in the repository (fce_predicted.csv for T5, fce_predicted_llama.csv for Llama 3). The predictions were generated with a script included in the repository (/scripts/predictions_builder); instructions for running it are there.*
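
To make the first two metrics concrete, here is a standard-library sketch of exact match and of a GLEU-style score (the matched 1- to 4-gram count divided by the larger of the two n-gram totals, i.e. the minimum of n-gram precision and recall). `simple_gleu` and `ngram_counts` are illustrative names. The notebook itself uses NLTK's `sentence_gleu`; for a single reference this sketch is expected to agree with it, but treat that as an assumption to verify.

```python
from collections import Counter
from typing import List


def ngram_counts(tokens: List[str], max_n: int = 4) -> Counter:
    """Count all n-grams of length 1..max_n in a token list."""
    counts = Counter()
    for n in range(1, max_n + 1):
        for i in range(len(tokens) - n + 1):
            counts[tuple(tokens[i:i + n])] += 1
    return counts


def simple_gleu(reference: str, prediction: str, max_n: int = 4) -> float:
    """Matched n-grams over the larger n-gram total (min of precision and recall)."""
    ref = ngram_counts(reference.split(), max_n)
    hyp = ngram_counts(prediction.split(), max_n)
    overlap = sum((ref & hyp).values())  # clipped n-gram matches
    denom = max(sum(ref.values()), sum(hyp.values()))
    return overlap / denom if denom else 0.0


def exact_match(reference: str, prediction: str) -> int:
    """1 if the strings are identical, 0 otherwise."""
    return int(reference == prediction)


target = "I will now give you some useful information."
prediction = "I will give you now some useful information."
print(exact_match(target, prediction), simple_gleu(target, prediction))
```

A single misplaced word makes exact match 0, while the GLEU-style score still credits the n-grams that are in the right place.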





## Evaluator Class

In [42]:
from sklearn.metrics import accuracy_score
from nltk.translate.gleu_score import sentence_gleu
from typing import Optional
import re

class Evaluator:
    """ Evaluator class for evaluating the performance of the inference engine on a dataset.
    Args:
        inference_engine (BaseInferenceEngine): The inference engine to use for evaluation.
        n_samples (Optional[int]): Number of samples to evaluate. If None, evaluates the entire test set. Default is None.
        predicted_dataset (Optional[DatasetDict]): A precomputed dataset with predictions. If provided, it will use this dataset instead of running inference. Default is None.
        dataset (Optional[DatasetDict]): The dataset to evaluate. When given without predicted_dataset, the evaluator runs inference on its 'test' split. Default is None.
    """
    def __init__(self, inference_engine: BaseInferenceEngine, n_samples: Optional[int] = None, predicted_dataset: Optional[DatasetDict] = None, dataset: Optional[DatasetDict] = None):
        self.dataset = dataset  # stays None when only predicted_dataset is provided
        self.engine = inference_engine
        self.n_samples = n_samples
        self.test_data, self.references, self.predictions = None, None, None
        if dataset is None and predicted_dataset is None:
            raise ValueError("Either dataset or predicted_dataset must be provided.")
        if predicted_dataset is not None and isinstance(predicted_dataset, DatasetDict):
            self.test_data = predicted_dataset['test']['source']
            self.references = predicted_dataset['test']['target']
            self.predictions = predicted_dataset['test']['prediction']
        elif dataset is not None and isinstance(dataset, DatasetDict):
            self.dataset = dataset
        self.logger = logging.getLogger(self.__class__.__name__)

    @staticmethod
    def normalize_text(text):
        # Remove spaces before punctuation and ensure single space after
        text = re.sub(r'\s+([.,!?;:"])', r'\1', text)
        text = re.sub(r'([.,!?;:"])([^\s])', r'\1 \2', text)
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    def _get_samples(self) -> Tuple[List[str], List[str], List[str]]:
        """ Fetch samples from the test set."""
        test_data = self.dataset['test']
        n = min(self.n_samples, len(test_data)) if (self.n_samples is not None and self.n_samples > 0) else len(test_data)
        self.logger.info(f"Fetching {n} samples from the test set.")
        sample_sentences = [test_data[i]['source'] for i in range(n)]
        references = [test_data[i]['target'] for i in range(n)]
        predictions = self.engine.batch_correct(sample_sentences)
        return sample_sentences, references, predictions
    
    def _get_samples_if_not_available(self):
        """
        Check if samples are available.
        """
        if self.test_data is None or self.references is None or self.predictions is None:
            self.logger.warning("Samples not available. Fetching samples...")
            self.test_data, self.references, self.predictions = self._get_samples()
        else:
            self.logger.info("Samples already available. Using cached samples.")

    def evaluate_accuracy(self) -> float:
        """ 
        Evaluate the accuracy of the predictions against the references.
        It computes the exact match accuracy.
        Returns:
            float: The exact match accuracy.
        """
        self._get_samples_if_not_available()
        norm_refs = [self.normalize_text(str(ref)) for ref in self.references]
        norm_preds = [self.normalize_text(str(pred)) for pred in self.predictions]
        accuracy = accuracy_score(norm_refs, norm_preds)
        self.logger.info(f"Exact match accuracy on test set: {accuracy:.4f}")
        return accuracy

    def evaluate_gleu(self):
        """
        Evaluate the GLEU score of the predictions against the references.
        It uses the nltk library to compute the GLEU score.
        Returns:
            float: The average GLEU score across all samples.
            
        """
        self._get_samples_if_not_available()
        gleu_scores = []
        for pred, ref in zip(self.predictions, self.references):
            ref_tokens = self.normalize_text(str(ref)).split()
            pred_tokens = self.normalize_text(str(pred)).split()
            gleu = sentence_gleu([ref_tokens], pred_tokens)
            gleu_scores.append(gleu)
        avg_gleu = sum(gleu_scores) / len(gleu_scores) if gleu_scores else 0.0
        self.logger.info(f"Average GLEU score on test set: {avg_gleu:.4f}")
        return avg_gleu

    # --- ASYNC BATCH INFERENCE FOR LLAMA3 ---
    async def evaluate_accuracy_async(self)-> Optional[float]:
        """ Evaluate the accuracy of the predictions against the references using async batch inference.
        It computes the exact match accuracy.
        This method is specifically designed for engines that support async batch inference.

        Returns:
            Optional[float]: The exact match accuracy if async batch inference is available, otherwise None.
        """
        # Fetch samples (sentences and references)
        test_data = self.dataset['test']
        n = min(self.n_samples, len(test_data)) if (self.n_samples is not None and self.n_samples > 0) else len(test_data)
        self.logger.info(f"Fetching {n} samples from the test set (async).")
        sample_sentences = [test_data[i]['source'] for i in range(n)]
        references = [test_data[i]['target'] for i in range(n)]
        # Use async batch correct if available
        if hasattr(self.engine, "async_batch_correct"):
            predictions = await self.engine.async_batch_correct(sample_sentences)
            self.test_data, self.references, self.predictions = sample_sentences, references, predictions
            # Cast to str, as in evaluate_accuracy, so non-string values do not break normalization
            norm_refs = [self.normalize_text(str(ref)) for ref in self.references]
            norm_preds = [self.normalize_text(str(pred)) for pred in self.predictions]
            accuracy = accuracy_score(norm_refs, norm_preds)
            self.logger.info(f"Exact match accuracy on test set: {accuracy:.4f}")
            return accuracy
        else:
            self.logger.warning("Async batch inference not available for this engine.")
            return None

    async def evaluate_gleu_async(self)-> Optional[float]:
        """ Evaluate the GLEU score of the predictions against the references using async batch inference.
        It uses the nltk library to compute the GLEU score.
        This method is specifically designed for engines that support async batch inference.

        Returns:
            Optional[float]: The average GLEU score across all samples if async batch inference is available, otherwise None.
        """
        # Fetch samples (sentences and references)
        test_data = self.dataset['test']
        n = min(self.n_samples, len(test_data)) if (self.n_samples is not None and self.n_samples > 0) else len(test_data)
        self.logger.info(f"Fetching {n} samples from the test set (async).")
        sample_sentences = [test_data[i]['source'] for i in range(n)]
        references = [test_data[i]['target'] for i in range(n)]
        # Use async batch correct if available
        if hasattr(self.engine, "async_batch_correct"):
            predictions = await self.engine.async_batch_correct(sample_sentences)
            self.test_data, self.references, self.predictions = sample_sentences, references, predictions
            gleu_scores = []
            for pred, ref in zip(self.predictions, self.references):
                ref_tokens = self.normalize_text(str(ref)).split()
                pred_tokens = self.normalize_text(str(pred)).split()
                gleu = sentence_gleu([ref_tokens], pred_tokens)
                gleu_scores.append(gleu)
            avg_gleu = sum(gleu_scores) / len(gleu_scores) if gleu_scores else 0.0
            self.logger.info(f"Average GLEU score on test set: {avg_gleu:.4f}")
            return avg_gleu
        else:
            self.logger.warning("Async batch inference not available for this engine.")
            return None
    
    @staticmethod
    def evaluate_single(pred, ref):
        norm_pred = Evaluator.normalize_text(pred)
        norm_ref = Evaluator.normalize_text(ref)
        exact = int(norm_pred == norm_ref)
        gleu = sentence_gleu([norm_ref.split()], norm_pred.split())
        return {"exact_match": exact, "gleu": gleu}
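
Because every metric is computed on normalized text, it helps to see what `normalize_text` actually does. The snippet below copies its three regular expressions into a standalone function purely for illustration. Note the side effect on decimal numbers: a space is inserted after any punctuation mark that is directly followed by a non-space character, so `3.5` becomes `3. 5`. Predictions and references are normalized the same way, so the comparison stays consistent.

```python
import re


def normalize_text(text: str) -> str:
    """Standalone copy of Evaluator.normalize_text, for illustration only."""
    text = re.sub(r'\s+([.,!?;:"])', r'\1', text)         # drop spaces before punctuation
    text = re.sub(r'([.,!?;:"])([^\s])', r'\1 \2', text)  # force a space after punctuation
    text = re.sub(r'\s+', ' ', text)                      # collapse repeated whitespace
    return text.strip()


print(normalize_text("Hello , world !"))   # Hello, world!
print(normalize_text("Hello ,world"))      # Hello, world
print(normalize_text("Dose is 3.5 mg ."))  # Dose is 3. 5 mg.
```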

## Evaluating FCE test data

### Evaluating the test dataset (FCE with T5) using GLEU and exact match

In [17]:
## This dataset was generated using the following command:
# python3 -m scripts.predictions_builder medical t5 fce /Users/isaac/Developer/GEC-system/data/fce_predicted.csv

import pandas as pd
predicted_fce_df = pd.read_csv(os.path.join("/Users/isaac/Developer/GEC-system/data/fce_predicted.csv"))
predicted_fce_dataset = DatasetDict({
    'test': Dataset.from_pandas(predicted_fce_df)
})


# If you don't want to use the precomputed predictions, comment out the previous lines and uncomment the following:
# from datasets import load_from_disk
# preprocessed_dataset = load_from_disk(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "preprocessed_fce_dataset"))
# t5_inference_engine = T5InferenceEngine(model_dir=model_dir, max_length=650)

# t5_evaluator= Evaluator(
#     inference_engine=t5_inference_engine,
#     dataset=preprocessed_dataset
# )
# accuracy = t5_evaluator.evaluate_accuracy()
# print("Accuracy:", accuracy)
# gleu = t5_evaluator.evaluate_gleu()
# print("GLEU:", gleu)


In [18]:

t5_fce_evaluator = Evaluator(t5_inference_engine, predicted_dataset=predicted_fce_dataset)

In [52]:
t5_fce_evaluator.evaluate_accuracy()

2025-06-01 19:29:45,350 - INFO - Evaluator - _get_samples_if_not_available - Samples already available. Using cached samples.
2025-06-01 19:29:45,416 - INFO - Evaluator - evaluate_accuracy - Exact match accuracy on test set: 0.3833


0.38330241187384045

In [53]:
t5_fce_evaluator.evaluate_gleu()

2025-06-01 19:29:46,749 - INFO - Evaluator - _get_samples_if_not_available - Samples already available. Using cached samples.
2025-06-01 19:29:46,898 - INFO - Evaluator - evaluate_gleu - Average GLEU score on test set: 0.7843


0.7842619115222491

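These are fairly good results for a small model: about 38% of the predictions match the reference exactly, and the average GLEU of 0.78 shows that many of the remaining outputs are close to it.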

In [32]:
import random

sample_indices = random.sample(range(len(predicted_fce_dataset['test'])), 10)
for idx in sample_indices:
    src = Evaluator.normalize_text(predicted_fce_dataset['test'][idx]['source'])
    tgt = Evaluator.normalize_text(predicted_fce_dataset['test'][idx]['target'])
    pred = Evaluator.normalize_text(predicted_fce_dataset['test'][idx]['prediction'])
    print(f"Source: {src}\nTarget: {tgt}\nPrediction: {pred}\n")
    result = t5_fce_evaluator.evaluate_single(pred, tgt)
    print(f"Exact Match: {result['exact_match']}, GLEU: {result['gleu']:.4f}\n{'-'*60}\n")

Source: I will give you now some useful information.
Target: I will now give you some useful information.
Prediction: I will give you now some useful information.

Exact Match: 0, GLEU: 0.5000
------------------------------------------------------------

Source: But also, the museum have a lot of gardens with several statues.
Target: But also the museum has a lot of gardens with several statues.
Prediction: But also, the museum has a lot of gardens with several statues.

Exact Match: 0, GLEU: 0.8333
------------------------------------------------------------

Source: The hotel is just three busstops away from our college.
Target: The hotel is just three bus stops away from our college.
Prediction: The hotel is just three bus stops away from our college.

Exact Match: 1, GLEU: 1.0000
------------------------------------------------------------

Source: I hope to give you, all the information need and please, if you want more information or somenthing is not clear, please do n't esitate

### Evaluating test data (FCE with Llama 3) using GLEU and exact match

In [44]:
predicted_fce_llama_df = pd.read_csv("/Users/isaac/Developer/GEC-system/data/fce_predicted_llama.csv")
llama3_engine = Llama3InferenceEngine(model_endpoint=LLAMA3_ENDPOINT, prompt_path=GENERAL_PROMPT_PATH)

llama3_fce_evaluator = Evaluator(llama3_engine, predicted_dataset=DatasetDict({'test': Dataset.from_pandas(predicted_fce_llama_df)}))
llama3_fce_evaluator.evaluate_accuracy()

2025-06-01 14:35:16,007 - INFO - Llama3InferenceEngine - __init__ - Llama3InferenceEngine initialized with model endpoint: http://127.0.0.1:11434/api/generate and prompt path: /Users/isaac/Developer/GEC-system/config/prompt_general.txt, options: {'temperature': 0.0, 'seed': 123, 'top_k': 10, 'top_p': 0.5}
2025-06-01 14:35:16,021 - INFO - Evaluator - _get_samples_if_not_available - Samples already available. Using cached samples.
2025-06-01 14:35:16,065 - INFO - Evaluator - evaluate_accuracy - Exact match accuracy on test set: 0.1325


0.13246753246753246

In [45]:
llama3_fce_evaluator.evaluate_gleu()

2025-06-01 14:35:18,217 - INFO - Evaluator - _get_samples_if_not_available - Samples already available. Using cached samples.
2025-06-01 14:35:18,367 - INFO - Evaluator - evaluate_gleu - Average GLEU score on test set: 0.6072


0.6072433371173874

Exact match is worse than with T5 (0.13 vs 0.38), which we attribute to a poorly structured prompt that has to cover a lot of edge cases.

## Evaluating our medical data

Finally, we evaluate on our medical data to see how the systems behave.

In [33]:
import pandas as pd
from datasets import Dataset

# Read the medical data CSV (update the path as needed)
medical_df = pd.read_csv("/Users/isaac/Developer/GEC-system/data/data.csv")  # original columns: ['incorrect_sentence', 'correct_sentence']
# Rename to the ['source', 'target'] column names the Evaluator expects
medical_df.rename(columns={'incorrect_sentence': 'source', 'correct_sentence': 'target'}, inplace=True)

# Convert to HuggingFace Dataset
medical_data = Dataset.from_pandas(medical_df)
medical_dataset_dict = DatasetDict({'test': medical_data})


medical_data

Dataset({
    features: ['source', 'target'],
    num_rows: 204
})

### Evaluating with T5 using exact match and GLEU

In [34]:
## Precomputed predictions for the medical dataset

medical_t5_predictions_df = pd.read_csv("/Users/isaac/Developer/GEC-system/data/medical_predicted.csv")

t5_medical_evaluator = Evaluator(t5_inference_engine, predicted_dataset=DatasetDict({'test': Dataset.from_pandas(medical_t5_predictions_df)}))

In [35]:
t5_medical_evaluator.evaluate_accuracy()

0.1323529411764706

In [36]:
t5_medical_evaluator.evaluate_gleu()

0.7898051106345906

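Exact match is poor (0.13), which we attribute to the lack of medical context in the training data. GLEU is still high (0.79), presumably because most tokens in each sentence need no change.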

### Evaluating with Llama3 using exact match and GLEU

In [374]:
# Precomputed predictions for the medical dataset using Llama3

llama3_medical_predictions_df = pd.read_csv("/Users/isaac/Developer/GEC-system/data/medical_predicted_llama.csv")

llama3_medical_evaluator = Evaluator(llama3_engine, predicted_dataset=DatasetDict({'test': Dataset.from_pandas(llama3_medical_predictions_df)}))
llama3_medical_evaluator.evaluate_accuracy()

2025-06-01 13:08:28,092 - INFO - Evaluator - _get_samples_if_not_available - Samples already available. Using cached samples.
2025-06-01 13:08:28,099 - INFO - Evaluator - evaluate_accuracy - Exact match accuracy on test set: 0.5049


0.5049019607843137

In [375]:
llama3_medical_evaluator.evaluate_gleu()

2025-06-01 13:08:31,968 - INFO - Evaluator - _get_samples_if_not_available - Samples already available. Using cached samples.
2025-06-01 13:08:31,988 - INFO - Evaluator - evaluate_gleu - Average GLEU score on test set: 0.8691


0.869140455496451

Pretty good results, thanks to prompt tuning.

## Using ERRANT Scorer on test data

### Generate gold FCE data using our framework 


We need to generate our own gold M2 data rather than reuse the BEA-2019 annotations, because of the post-processing we apply (for example, standardizing punctuation by removing spaces before commas and periods). I also noticed some examples where the BEA-2019 gold annotation is not compatible with the M2 files generated here: in the end the edits are the same, but they are annotated differently.

In [54]:
from datasets import load_from_disk

preprocessed_dataset = load_from_disk(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce/preprocessed_fce_dataset"))

In [55]:
# Generate gold FCE data using our framework 

import pandas as pd

data = pd.read_csv(os.path.join(os.getcwd(), "data", "fce_predicted.csv"))["target"].tolist()

with open(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce_gold.txt"), 'w') as f:
    for sentence in data:
        f.write(Evaluator.normalize_text(sentence) + "\n")
    

In [56]:
!errant_parallel -orig ./data/fce_wrong.txt -cor ./data/fce_gold.txt -out ./data/m2/fce_gold.m2

Loading resources...
Processing parallel files...


### T5

The predictions are already stored as a CSV in data/fce_predicted.csv. Now we convert them to the plain-text files (one sentence per line) that ERRANT expects.

In [57]:
import pandas as pd

# Read the precomputed T5 predictions for FCE (update the path as needed)
fce_predicted_df_t5 = pd.read_csv(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce_predicted.csv"))  # columns should include ['source', 'target', 'prediction']
with open(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce_predicted_t5.txt"), 'w') as f:
    for sentence in fce_predicted_df_t5['prediction']:
        f.write(Evaluator.normalize_text(sentence) + '\n')

with open(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce_wrong.txt"), 'w') as f:
    for sentence in fce_predicted_df_t5['source']:
        f.write(Evaluator.normalize_text(sentence) + '\n')


Generate system M2

In [None]:
# !pip install --upgrade --force-reinstall numpy h5py

In [58]:
!errant_parallel -orig ./data/fce_wrong.txt -cor ./data/fce_predicted_t5.txt -out ./data/m2/fce_predicted_t5.m2

Loading resources...
Processing parallel files...


Run it as script

In [59]:
!errant_compare -hyp ./data/m2/fce_predicted_t5.m2 -ref ./data/m2/fce_gold.m2


TP	FP	FN	Prec	Rec	F0.5
1452	1411	3312	0.5072	0.3048	0.4477



Again, these are solid results for a baseline grammatical error correction (GEC) system built on a small model.

Precision (0.5072) is much higher than recall (0.3048), meaning the system is conservative: it makes fewer edits, but they are more likely to be correct.

Recall could be improved (the system misses many errors: 3312 false negatives), but this is common for GEC systems.
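
The precision, recall and F0.5 columns in the `errant_compare` tables follow directly from the TP/FP/FN counts. The function below is a minimal sketch for recomputing them; `precision_recall_f` is an illustrative name, and returning 0.0 when a denominator is zero is an assumption of this sketch, not necessarily ERRANT's own convention.

```python
def precision_recall_f(tp: int, fp: int, fn: int, beta: float = 0.5):
    """Compute precision, recall and F-beta from edit-level counts."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0  # assumed edge-case convention
    recall = tp / (tp + fn) if (tp + fn) else 0.0     # assumed edge-case convention
    b2 = beta ** 2
    denom = b2 * precision + recall
    f_beta = (1 + b2) * precision * recall / denom if denom else 0.0
    return precision, recall, f_beta


# Counts from the T5 run on FCE reported above
p, r, f = precision_recall_f(tp=1452, fp=1411, fn=3312)
print(f"Prec={p:.4f} Rec={r:.4f} F0.5={f:.4f}")
```

With beta = 0.5, precision is weighted more heavily than recall, which is why a conservative system scores relatively well on this metric.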


### LLAMA

### Running analysis over the most common edits in the fce data so we can finetune the prompt

In [46]:
M2DatasetLoader.most_common_edit_types(gold_m2_path="./data/m2/fce_gold.m2")

[('noop', 903),
 ('R:OTHER', 617),
 ('R:NOUN', 576),
 ('R:SPELL', 423),
 ('M:DET', 335),
 ('R:ORTH', 263),
 ('R:PREP', 254),
 ('R:VERB', 233),
 ('R:VERB:TENSE', 165),
 ('R:DET', 134)]
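
For readers without the repository at hand, counting edit types in an M2 file can be sketched as follows. This is a hypothetical re-implementation for illustration, not the actual `M2DatasetLoader.most_common_edit_types`. It assumes the usual M2 layout, where annotation lines start with `A ` and fields are separated by `|||`, with the edit type in the second field. `SAMPLE_M2` is made-up data.

```python
from collections import Counter

# Made-up M2 content: "S" lines hold the source sentence, "A" lines hold one edit each
SAMPLE_M2 = """S The museum have a lot of garden .
A 2 3|||R:VERB:SVA|||has|||REQUIRED|||-NONE-|||0
A 6 7|||R:NOUN:NUM|||gardens|||REQUIRED|||-NONE-|||0

S This sentence is fine .
A -1 -1|||noop|||-NONE-|||REQUIRED|||-NONE-|||0
"""


def count_edit_types(m2_text: str) -> Counter:
    """Count the edit type of every annotation line in M2-formatted text."""
    counts = Counter()
    for line in m2_text.splitlines():
        if line.startswith("A "):
            fields = line[2:].split("|||")  # span, type, correction, ...
            counts[fields[1]] += 1
    return counts


edit_counts = count_edit_types(SAMPLE_M2)
print(sorted(edit_counts.items()))
```

A `noop` entry marks a sentence that needs no correction, which explains why it tops the FCE list above.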

In [44]:
### LLAMA
import pandas as pd

# Read the precomputed Llama 3 predictions for FCE (update the path as needed)
fce_predicted_df_llama = pd.read_csv(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce_predicted_llama.csv"))  # columns should include ['source', 'target', 'prediction']
with open(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce_predicted_llama.txt"), 'w') as f:
    for sentence in fce_predicted_df_llama['prediction']:
        f.write(Evaluator.normalize_text(str(sentence)) + '\n')


In [48]:
!errant_parallel -orig ./data/fce_wrong.txt -cor ./data/fce_predicted_llama.txt -out ./data/m2/fce_predicted_llama.m2

Loading resources...
Processing parallel files...


In [49]:
!errant_compare -hyp ./data/m2/fce_predicted_llama.m2 -ref ./data/m2/fce_gold.m2


TP	FP	FN	Prec	Rec	F0.5
1667	3509	3097	0.3221	0.3499	0.3273



In [52]:
import random

sample_indices = random.sample(range(len(fce_predicted_df_llama)), 10)
for idx in sample_indices:
    src = Evaluator.normalize_text(fce_predicted_df_llama['source'][idx])
    tgt = Evaluator.normalize_text(fce_predicted_df_llama['target'][idx])
    pred = Evaluator.normalize_text(fce_predicted_df_llama['prediction'][idx])
    print(f"Source: {src}\nTarget: {tgt}\nPrediction: {pred}\n")
    result = Evaluator.evaluate_single(pred, tgt)  # static method, independent of the T5 evaluator
    print(f"Exact Match: {result['exact_match']}, GLEU: {result['gleu']:.4f}\n{'-'*60}\n")

Source: No jeans and tee - shirts.
Target: No jeans or tee shirts.
Prediction: No jeans and T-shirts.

Exact Match: 0, GLEU: 0.2143
------------------------------------------------------------

Source: It is going to be very lovely and enjoyable because I have lots of surprises for students.
Target: It is going to be very lovely and enjoyable because I have lots of surprises for the students.
Prediction: It is going to be very lovely and enjoyable because I have lots of surprises for the students.

Exact Match: 1, GLEU: 1.0000
------------------------------------------------------------

Source: With reference to the information that you had requested, the hotel that had been booked is the Holiday Inn, in New Port and to get to the University of Wales wich is not far from the hotel, you only need to take a divertion where clearly indicate Carleon and once that you are on the main road all you need to do is to follow the country road which take you direct to the place.
Target: With refe

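Llama 3 finds slightly more of the gold edits than T5 (recall 0.3499 vs 0.3048) but over-corrects heavily (3509 false positives), so precision drops to 0.3221. We attribute this to a lack of context and task-specific training. As a prompt-only baseline it is still not that bad.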

## ERRANT score on medical data

### Generate m2 files including gold

In [223]:
medical_data = pd.read_csv(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "medical_predicted.csv"))


with open(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "medical_wrong.txt"), 'w') as f:
    for sentence in medical_data['source']:
        f.write(Evaluator.normalize_text(sentence) + '\n')
        
with open(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "medical_gold.txt"), 'w') as f:
    for sentence in medical_data['target']:
        f.write(sentence + '\n')


In [224]:
!errant_parallel -orig ./data/medical_wrong.txt -cor ./data/medical_gold.txt -out ./data/m2/medical_gold.m2

Loading resources...
Processing parallel files...


### T5

In [38]:
import pandas as pd

# Read the precomputed T5 predictions for the medical data (update the path as needed)
medical_predicted_df_t5 = pd.read_csv(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "medical_predicted.csv"))  # columns should include ['source', 'target', 'prediction']
with open(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "medical_predicted_t5.txt"), 'w') as f:
    for sentence in medical_predicted_df_t5['prediction']:
        f.write(Evaluator.normalize_text(sentence) + '\n')



In [228]:
!errant_parallel -orig ./data/medical_wrong.txt -cor ./data/medical_predicted_t5.txt -out ./data/m2/medical_predicted_t5.m2

Loading resources...
Processing parallel files...


In [229]:
!errant_compare -hyp ./data/m2/medical_predicted_t5.m2 -ref ./data/m2/medical_gold.m2


TP	FP	FN	Prec	Rec	F0.5
31	184	208	0.1442	0.1297	0.141



In [43]:
import random

sample_indices = random.sample(range(len(medical_predicted_df_t5)), 10)
for idx in sample_indices:
    src = Evaluator.normalize_text(medical_predicted_df_t5['source'][idx])
    tgt = Evaluator.normalize_text(medical_predicted_df_t5['target'][idx])
    pred = Evaluator.normalize_text(medical_predicted_df_t5['prediction'][idx])
    print(f"Source: {src}\nTarget: {tgt}\nPrediction: {pred}\n")
    result = Evaluator.evaluate_single(pred, tgt)
    print(f"Exact Match: {result['exact_match']}, GLEU: {result['gleu']:.4f}\n{'-'*60}\n")

Source: Patient develop diabetic ketoacidosis with severe dehydration and electrolyte imbalances requiring intensive care management.
Target: The patient developed diabetic ketoacidosis with severe dehydration and electrolyte imbalances requiring intensive care management.
Prediction: Patient develops diabetic ketoacidosis with severe dehydration and electrolyte imbalances requiring intensive care management.

Exact Match: 0, GLEU: 0.7778
------------------------------------------------------------

Source: The radiation oncologist recommend intensity-modulated radiation therapy for locally advanced prostate adenocarcinoma.
Target: The radiation oncologist recommended intensity-modulated radiation therapy for locally advanced prostate adenocarcinoma.
Prediction: The radiation oncologist recommend intensity-modulated radiation therapy for locally advanced prostate adenocarcinoma.

Exact Match: 0, GLEU: 0.7619
------------------------------------------------------------

Source: Slit-lam

The lack of medical data during training results in a poor score (F0.5 of 0.141).

### LLAMA

### Running analysis over the most common edits in the medical data so we can finetune the prompt

In [358]:
M2DatasetLoader.most_common_edit_types(gold_m2_path="./data/m2/medical_gold.m2")

[('R:VERB:TENSE', 177),
 ('M:DET', 36),
 ('R:VERB:FORM', 7),
 ('R:VERB:SVA', 6),
 ('R:PREP', 5),
 ('M:OTHER', 1),
 ('R:OTHER', 1),
 ('U:OTHER', 1),
 ('M:VERB:TENSE', 1),
 ('R:ADV', 1)]

In [53]:
import pandas as pd

# Read the precomputed Llama 3 predictions for the medical data (update the path as needed)
medical_predicted_df_llama = pd.read_csv(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "medical_predicted_llama.csv"))  # columns should include ['source', 'target', 'prediction']
with open(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "medical_predicted_llama.txt"), 'w') as f:
    for sentence in medical_predicted_df_llama['prediction']:
        f.write(Evaluator.normalize_text(str(sentence)) + '\n')



In [388]:
!errant_parallel -orig ./data/medical_wrong.txt -cor ./data/medical_predicted_llama.txt -out ./data/m2/medical_predicted_llama.m2

Loading resources...
Processing parallel files...


In [389]:
M2DatasetLoader.most_common_edit_types(gold_m2_path="./data/m2/medical_predicted_llama.m2")

[('R:VERB:TENSE', 155),
 ('M:DET', 60),
 ('R:NOUN:NUM', 38),
 ('R:VERB:SVA', 21),
 ('R:NOUN', 17),
 ('R:VERB:FORM', 10),
 ('R:SPELL', 8),
 ('R:OTHER', 8),
 ('R:PREP', 5),
 ('M:OTHER', 4)]

In [390]:
!errant_compare -hyp ./data/m2/medical_predicted_llama.m2 -ref ./data/m2/medical_gold.m2


TP	FP	FN	Prec	Rec	F0.5
181	156	58	0.5371	0.7573	0.5703
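
The precision, recall and F0.5 columns follow directly from the TP/FP/FN counts. The sketch below reproduces them with the usual F-beta definition (beta = 0.5 weights precision twice as heavily as recall). The function name `precision_recall_fbeta` is an illustrative assumption, not part of ERRANT.

```python
from typing import Tuple


def precision_recall_fbeta(tp: int, fp: int, fn: int, beta: float = 0.5) -> Tuple[float, float, float]:
    """Compute precision, recall and F-beta from edit-level counts."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision == 0.0 and recall == 0.0:
        return precision, recall, 0.0
    beta_sq = beta ** 2
    f_beta = (1 + beta_sq) * precision * recall / (beta_sq * precision + recall)
    return precision, recall, f_beta


# Counts taken from the errant_compare output above.
p, r, f = precision_recall_fbeta(tp=181, fp=156, fn=58)
print(f"Prec={p:.4f} Rec={r:.4f} F0.5={f:.4f}")  # Prec=0.5371 Rec=0.7573 F0.5=0.5703
```

Because F0.5 favours precision, the 156 false positives pull the score down much more than the 58 missed edits do.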



In [54]:
import random

sample_indices = random.sample(range(len(medical_predicted_df_llama)), 10)
for idx in sample_indices:
    src = Evaluator.normalize_text(medical_predicted_df_llama['source'][idx])
    tgt = Evaluator.normalize_text(medical_predicted_df_llama['target'][idx])
    pred = Evaluator.normalize_text(medical_predicted_df_llama['prediction'][idx])
    print(f"Source: {src}\nTarget: {tgt}\nPrediction: {pred}\n")
    result = t5_fce_evaluator.evaluate_single(pred, tgt)
    print(f"Exact Match: {result['exact_match']}, GLEU: {result['gleu']:.4f}\n{'-'*60}\n")

Source: Electrophysiology study show prolonged QT interval and increased risk for torsades de pointes arrhythmia.
Target: Electrophysiology study showed prolonged QT interval and increased risk for torsades de pointes arrhythmia.
Prediction: Electrophysiology studies showed prolonged QT interval and an increased risk for torsades de pointes arrhythmia.

Exact Match: 0, GLEU: 0.6852
------------------------------------------------------------

Source: Spirometry testing show severe obstructive pattern with reduced forced expiratory volume in one second.
Target: Spirometry testing showed severe obstructive pattern with reduced forced expiratory volume in one second.
Prediction: Spirometry testing showed severe obstructive pattern with reduced forced expiratory volume in one second.

Exact Match: 1, GLEU: 1.0000
------------------------------------------------------------

Source: Systolic and diastolic blood pressure is in normal range following administration of antihypertensive medicat

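Reasonable results for prompt engineering alone: recall is high (0.76), but precision is lower (0.54) because the model also makes edits that rarely or never appear in the gold annotations, such as R:NOUN:NUM (38) and R:NOUN (17).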

# How can we improve?

## General data

### LLAMA

Options include using a larger model, adding more few-shot examples to the prompt, or fine-tuning the model on the BEA-2019 data.
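
As a rough illustration of what passing more examples in the prompt could look like, here is a minimal sketch of a few-shot prompt builder. The function name `build_few_shot_prompt`, the instruction wording and the example pairs are all assumptions made for illustration. They are not the prompt used by the notebook's Llama 3 inference engine.

```python
from typing import List, Tuple


def build_few_shot_prompt(examples: List[Tuple[str, str]], sentence: str) -> str:
    """Build a plain-text few-shot prompt from (incorrect, corrected) pairs (hypothetical helper)."""
    lines = [
        "Correct the grammar of the sentence. Change only what is wrong "
        "and return only the corrected sentence.",
        "",
    ]
    for source, target in examples:
        lines.append(f"Input: {source}")
        lines.append(f"Output: {target}")
        lines.append("")
    # The final block is left open so the model completes the correction.
    lines.append(f"Input: {sentence}")
    lines.append("Output:")
    return "\n".join(lines)


# Made-up example pairs, for illustration only.
few_shot_examples = [
    ("She go to school yesterday.", "She went to school yesterday."),
    ("I have seen movie last week.", "I saw a movie last week."),
]

prompt = build_few_shot_prompt(few_shot_examples, "He have two brother.")
print(prompt)
```

Picking example pairs that cover the most frequent edit types in the target data is one way to decide which examples to include.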

### T5

We used T5-small, so we could switch to T5-base or another larger model and fine-tune it again for longer (I trained for only 3 epochs because of GPU limitations).

## Medical data

### T5

To improve this model, I suggest taking the checkpoint already fine-tuned for general-domain correction and fine-tuning it further on the medical data to see whether the scores improve.

In [69]:
import pandas as pd
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split

# Load the medical data and rename its columns to the ['source', 'target'] schema used elsewhere

medical_df = pd.read_csv(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "data.csv"))
medical_df.rename(columns={'incorrect_sentence': 'source', 'correct_sentence': 'target'}, inplace=True)

# Split into train (80%), validation (10%), test (10%)
train_df, temp_df = train_test_split(medical_df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

medical_split_dataset = DatasetDict({
    "train": Dataset.from_pandas(train_df.reset_index(drop=True)),
    "validation": Dataset.from_pandas(val_df.reset_index(drop=True)),
    "test": Dataset.from_pandas(test_df.reset_index(drop=True)),
})


In [None]:
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load the T5 model and tokenizer already fine-tuned on general GEC data (from model_dir)
# so they can be fine-tuned further on the medical data

finetune_model = T5ForConditionalGeneration.from_pretrained(model_dir)
finetune_tokenizer = T5Tokenizer.from_pretrained(model_dir)

# Create a new T5Preprocessor using the loaded tokenizer
medical_preprocessor = T5Preprocessor(tokenizer=finetune_tokenizer)

# Get max input length for medical data
medical_max_length = T5Preprocessor._get_max_input_length(medical_split_dataset)

# Preprocess the medical dataset
medical_tokenized = medical_preprocessor.preprocess(medical_split_dataset, max_length=medical_max_length)

# Define a new output directory for the finetuned model
medical_finetune_output_dir = "./models/t5_finetuned_medical"

# Subclass T5Trainer so that, after the parent constructor runs, its model and tokenizer
# are replaced with the checkpoint loaded above
class CustomT5Trainer(T5Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = finetune_model
        self.tokenizer = finetune_tokenizer

medical_trainer = CustomT5Trainer(
    output_dir=medical_finetune_output_dir,
    batch_size=8
)

# Train (finetune) the model on the medical data
medical_model_dir, medical_model, medical_tokenizer = medical_trainer.train(medical_tokenized, epochs=10)

In [73]:
# Evaluate the fine-tuned model on the test split of the medical dataset
medical_evaluator = Evaluator(
    inference_engine=T5InferenceEngine(model_dir=medical_model_dir, max_length=650),
    dataset=medical_split_dataset
)
medical_evaluator.evaluate_accuracy()

2025-06-01 22:01:59,929 - INFO - T5InferenceEngine - __init__ - T5 model loaded from ./t5_finetuned_medical/20250601-220037 with max length 650
2025-06-01 22:01:59,935 - INFO - Evaluator - _get_samples - Fetching 21 samples from the test set.
2025-06-01 22:01:59,936 - INFO - T5InferenceEngine - batch_correct - Processing batch 1 with size 16
2025-06-01 22:02:01,000 - INFO - T5InferenceEngine - batch_correct - Processing batch 2 with size 5
2025-06-01 22:02:01,805 - INFO - T5InferenceEngine - batch_correct - Batch correction completed. Total corrected sentences: 21
2025-06-01 22:02:01,807 - INFO - Evaluator - evaluate_accuracy - Exact match accuracy on test set: 0.9048


0.9047619047619048

In [74]:
medical_evaluator.evaluate_gleu()

2025-06-01 22:02:13,434 - INFO - Evaluator - _get_samples_if_not_available - Samples already available. Using cached samples.
2025-06-01 22:02:13,436 - INFO - Evaluator - evaluate_gleu - Average GLEU score on test set: 0.9847


0.9847041847041847

These numbers are not very informative because the test split has only 21 samples (an exact match of 0.9048 corresponds to 19 of 21 sentences), but they do suggest an improvement.
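
To put a number on how uncertain a 21-sample estimate is, one option is a Wilson score interval for the exact-match rate. This is a sketch under the assumption that 0.9048 corresponds to 19 correct sentences out of 21. The helper `wilson_interval` is illustrative and not part of the notebook's `Evaluator`.

```python
import math
from typing import Tuple


def wilson_interval(successes: int, total: int, z: float = 1.96) -> Tuple[float, float]:
    """Wilson score interval for a binomial proportion (z = 1.96 gives about 95% coverage)."""
    if total <= 0:
        raise ValueError("total must be positive")
    p = successes / total
    denom = 1 + z ** 2 / total
    center = (p + z ** 2 / (2 * total)) / denom
    half_width = z * math.sqrt(p * (1 - p) / total + z ** 2 / (4 * total ** 2)) / denom
    return center - half_width, center + half_width


# 19 exact matches out of 21 test sentences (assumed from the 0.9048 accuracy above).
low, high = wilson_interval(19, 21)
print(f"95% interval for exact match: {low:.3f} to {high:.3f}")
```

The interval spans roughly 0.71 to 0.97, so the true exact-match rate could be well below the reported 0.90. A larger held-out set would narrow it.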

### LLAMA

We could use a larger model with more few-shot examples in the prompt, or even fine-tune it on all the medical data we have available. This is not done in this notebook.