# Challenge related with a GEC system

# Table of contents:

- [Challenge related with a GEC system](#challenge-related-with-a-gec-system)
- [Explanation](#explanation)
- [Just Setting Up](#just-setting-up)
- [Dataset Download and Loader](#dataset-downloader-and-loader)
    - [Dataset Exploration and Parsing](#some-exploration-of-the-data-and-the-parsing)
- [Preprocessing for T5 Model](#t5-preprocessor)
- [Training the T5 Model](#trainer-class)
- [Inference Engines](#inference-engines)
    - [Base Inference Engine](#base-inference-engine-and-helper)
    - [T5 Inference Engine](#t5-inference-engine)
    - [Llama 3 Inference Engine](#llama-3-inference-engine)
- [Evaluation](#evaluator-class)
    - [T5 FCE evaluation](#evaluating-the-test-data-fce-with-t5)
    - [Llama3 FCE evaluation](#evaluating-test-data-fce-with-llama-3)

- [Final evaluation on medical data](#evaluating-our-medical-data)

# Explanation

# Just Setting Up

In [1]:
!pip install datasets transformers tqdm scikit-learn sentencepiece torch

You should consider upgrading via the '/Users/isaac/venvs/tensorflow/bin/python3 -m pip install --upgrade pip' command.[0m


In [1]:
## Imports

import os
# import openai
from datasets import Dataset, DatasetDict
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
from tqdm import tqdm
import time
import logging
from typing import List, Dict, Tuple
import os
import tarfile
import sys


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Create logger
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Remove all handlers associated with the root logger object (avoid duplicate logs)
for handler in logger.handlers[:]:
    logger.removeHandler(handler)

# Create handler that outputs to notebook cell
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(funcName)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)

In [3]:
# constants, ideally defined in other file and imported if done in a github repo

FCE_URL = "https://www.cl.cam.ac.uk/research/nl/bea2019st/data/fce_v2.1.bea19.tar.gz"
FCE_DEV_NAME = "fce.dev.gold.bea19.m2"
FCE_TRAIN_NAME = "fce.train.gold.bea19.m2"
FCE_TEST_NAME = "fce.test.gold.bea19.m2"
FCE_DOWNLOAD_DATASET_DIR = os.path.join(os.getcwd(), "data")
FCE_DATASET_DIR = os.path.join(os.getcwd(), "data/fce/m2")
SENTENCE_TAG = "S "
ANNOTATION_TAG = "A "
NO_EDIT_TAG = "noop"
SEP = "|||"
## We're gonna use T5 small for this task because we're just correcting spelling and we don't wait to wait hours for training.
MODEL_NAME = "t5-small"
FINETUNED_MODEL_OUTPUT_DIR = "./t5_finetuned"
PROMPT_PATH = os.path.join(os.getcwd(), "config/prompt.txt")
LLAMA3_ENDPOINT = "http://127.0.0.1:11434/api/generate"
TEXT_TO_REPLACE_IN_PROMPT = "<text_to_replace>"

# Dataset Downloader and Loader

## Class for orchestrating the downloading and Transformers propagation

In [5]:
### Dataset Loader class handling the dataset creation
import tarfile
from datasets import Dataset, DatasetDict
import requests
from typing import Optional, List, Dict


class M2DatasetLoader:
    """
    Loader for datasets in M2 format. Right now it's only tested for BCE2019

    Args:
        dataset_dir (str): Directory where the dataset is stored. Default is FCE_DATASET_DIR.
        train_file (str): Name of the training file. Default is FCE_TRAIN_NAME.
        dev_file (str): Name of the development file. Default is FCE_DEV_NAME.
        test_file (str): Name of the test file. Default is FCE_TEST_NAME.
        dataset_url (str): URL to download the dataset from. Default is FCE_URL.
        fce_download_dir (str): Directory to download the dataset to. Default is FCE_DOWNLOAD_DATASET_DIR.
    """
    def __init__(self,dataset_dir: str = FCE_DATASET_DIR, train_file: str = FCE_TRAIN_NAME,
                 dev_file: str = FCE_DEV_NAME, test_file: str = FCE_TEST_NAME, dataset_url: str = FCE_URL,
                 fce_download_dir: str = FCE_DOWNLOAD_DATASET_DIR):
        self.dataset_dir = dataset_dir
        self.fce_download_dir = fce_download_dir
        self.train_file = os.path.join(dataset_dir, train_file)
        self.dev_file = os.path.join(dataset_dir, dev_file)
        self.test_file = os.path.join(dataset_dir, test_file)
        self.url = dataset_url
        self.logger = logging.getLogger(self.__class__.__name__)
        self.dataset = None

    def download_and_extract(self) -> None:
        """
        Download the file and extract it into a directory if it does not exist
        """
        if not os.path.exists(self.fce_download_dir):
            os.makedirs(self.fce_download_dir, exist_ok=True)
            self.logger.info(f"Downloading BEA 2019 dataset...")
            tar_path = os.path.join(self.fce_download_dir, "bea19.tar.gz")
            with requests.get(self.url, stream=True) as r:
                with open(tar_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            self.logger.info("Extracting dataset...")
            with tarfile.open(tar_path, 'r:gz') as tar:
                tar.extractall(path=self.fce_download_dir)
            os.remove(tar_path)
            self.logger.info("BEA dataset downloaded and extracted successfully.")
        else:
            self.logger.warning("BEA dataset already exists locally.")

    def load_dataset(self) -> DatasetDict:
        """ Create the dataset in the Transformers format """
        dataset = DatasetDict({
            'train': Dataset.from_list(self._parse_m2_file(self.train_file)),
            'validation': Dataset.from_list(self._parse_m2_file(self.dev_file)),
            'test': Dataset.from_list(self._parse_m2_file(self.test_file))
        })
        self.logger.info(f"Loaded BEA dataset: {len(dataset['train'])} train, {len(dataset['validation'])} dev, {len(dataset['test'])} test")
        self.dataset = dataset
        return dataset
    
    def save_dataset(self, output_dir: str = os.path.join(os.getcwd(), FCE_DOWNLOAD_DATASET_DIR)) -> None:
        """
        Parses the M2 files and saves the resulting DatasetDict to disk in HuggingFace format.
        Args:
            output_dir (str): Directory to save the dataset.
        """
        if self.dataset is None:
            self.dataset = self.load_dataset()
        self.dataset.save_to_disk(output_dir + "parsed_fce_dataset")
        self.logger.info(f"Saved parsed dataset to {output_dir}")

    def _parse_m2_file(self, filepath: str) -> List[Dict[str, str]]:
      """
      Parses an M2 file and returns a list of {source, target} dictionaries,
      where each 'target' corresponds to one annotator's corrections.
      """
      data = []
      with open(filepath, 'r', encoding='utf-8') as f:
          lines = f.readlines()

      sentence = ""
      edits_by_annotator = dict()
      current_annotator = 0

      for line in lines + ['\n']:  # Add sentinel newline
          line = line.strip()
          if line.startswith(SENTENCE_TAG):
              if sentence and edits_by_annotator:
                  for annotator_id, edits in edits_by_annotator.items():
                      corrected = self._apply_m2_edits(sentence, edits)
                      data.append({'source': sentence, 'target': corrected})
              sentence = line[2:]
              edits_by_annotator = dict()
          elif line.startswith(ANNOTATION_TAG):
              parts = line[2:].split(SEP)
              span = list(map(int, parts[0].split()))
              error_type = parts[1]
              correction = parts[2]
              annotator_id = int(parts[-1])

              # Initialize list of edits for this annotator
              if annotator_id not in edits_by_annotator:
                  edits_by_annotator[annotator_id] = []
              edits_by_annotator[annotator_id].append((span, correction, error_type))
          elif line == "" and sentence:
              # End of current sentence block
              if not edits_by_annotator:
                  # No edits -> copy original sentence as is
                  data.append({'source': sentence, 'target': sentence})
              else:
                  for annotator_id, edits in edits_by_annotator.items():
                      corrected = self._apply_m2_edits(sentence, edits)
                      data.append({'source': sentence, 'target': corrected})
              sentence = ""
              edits_by_annotator = dict()
      return data


    def _apply_m2_edits(self, sentence: str, edits:str):
        """
        Applies M2 format edits to the original sentence.
        :param sentence: Original sentence (string)
        :param edits: List of (span, correction, error_type)
        :return: Corrected sentence
        """
        tokens = sentence.strip().split()
        # Sort edits by start index descending to avoid offset issues
        edits = sorted(edits, key=lambda x: x[0][0], reverse=True)

        for (start, end), correction, error_type in edits:
            if error_type == NO_EDIT_TAG:
                continue
            if start == -1 and end == -1:
                continue  # skip noop or invalid spans
            # Replace tokens[start:end] with correction tokens (split)
            correction_tokens = correction.strip().split()
            tokens = tokens[:start] + correction_tokens + tokens[end:]
        return ' '.join(tokens)


In [6]:
loader = M2DatasetLoader()

loader.download_and_extract()
dataset = loader.load_dataset()
loader.save_dataset(output_dir=os.getcwd() + "/data/")


2025-05-30 21:03:38,857 - INFO - M2DatasetLoader - load_dataset - Loaded BEA dataset: 28350 train, 2191 dev, 2695 test


Saving the dataset (1/1 shards): 100%|██████████| 28350/28350 [00:00<00:00, 3076466.80 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2191/2191 [00:00<00:00, 884137.01 examples/s] 
Saving the dataset (1/1 shards): 100%|██████████| 2695/2695 [00:00<00:00, 1218853.71 examples/s]

2025-05-30 21:03:38,896 - INFO - M2DatasetLoader - save_dataset - Saved parsed dataset to /Users/isaac/Developer/GEC-system/data/





## Some exploration of the data and the parsing

The file has one or more line per sentence, and one or more line per annotation, grouped together by a blank line.
If there si a *noop* tag the there is no edit.


We need a set of functions to actually put this corrections into the sentence to build up a dataset with "source" and "target"

In [7]:
# Need yo have the data downloaded to explore it. Will be automatized in the next class
train_data = os.path.join(FCE_DATASET_DIR, FCE_TRAIN_NAME)


with open(train_data, 'r') as file:
    data = file.read()

print(data[:1500])

S Dear Sir or Madam ,
A -1 -1|||noop|||-NONE-|||REQUIRED|||-NONE-|||0

S I am writing in order to express my disappointment about your musical show " Over the Rainbow " .
A 9 10|||R:PREP|||with|||REQUIRED|||-NONE-|||0

S I saws the show 's advertisement hanging up of a wall in London where I was spending my holiday with some friends . I convinced them to go there with me because I had heard good references about your Company and , above all , about the main star , Danny Brook .
A 1 2|||R:VERB:TENSE|||saw|||REQUIRED|||-NONE-|||0
A 8 9|||R:PREP|||on|||REQUIRED|||-NONE-|||0
A 36 37|||R:NOUN|||reviews|||REQUIRED|||-NONE-|||0
A 37 38|||R:PREP|||of|||REQUIRED|||-NONE-|||0
A 45 46|||R:PREP|||because of|||REQUIRED|||-NONE-|||0

S The problems started in the box office , where we asked for the discounts you announced in the advertisement , and the man who was selling the tickets said that they did n't exist .
A 3 4|||R:PREP|||at|||REQUIRED|||-NONE-|||0

S Moreover , the show was delayed forty -

To see what the parser is doing let's take two examples. One without edit and one with some edits.

Without edits: (First record o the file actually)

```S Dear Sir or Madam ,
A -1 -1|||noop|||-NONE-|||REQUIRED|||-NONE-|||0
```

In [8]:
print("Wrong")
print(dataset['train'][0].get("source"))
print("Corrected")
print(dataset['train'][0].get("target"))

Wrong
Dear Sir or Madam ,
Corrected
Dear Sir or Madam ,


Many edits, like third record of the file:

"I saws the show 's advertisement hanging up of a wall in London where I was spending my holiday with some friends .
I convinced them to go there with me because I had heard good references about your Company and , above all , about the main star , Danny Brook ."

```
A 1 2|||R:VERB:TENSE|||saw|||REQUIRED|||-NONE-|||0
A 8 9|||R:PREP|||on|||REQUIRED|||-NONE-|||0
A 36 37|||R:NOUN|||reviews|||REQUIRED|||-NONE-|||0
A 37 38|||R:PREP|||of|||REQUIRED|||-NONE-|||0
A 45 46|||R:PREP|||because of|||REQUIRED|||-NONE-|||0
```

Explanation:

It bascially changes *saws* by *saw*, then *of* by *on*, *references* by *reviews*, again *about* by *of* and finally *about* by *because of*

In [9]:
print("Wrong")
print(dataset['train'][2].get("source"))
print("Corrected")
print(dataset['train'][2].get("target"))
#print(dataset['validation'][0])
#print(dataset['test'][0])

Wrong
I saws the show 's advertisement hanging up of a wall in London where I was spending my holiday with some friends . I convinced them to go there with me because I had heard good references about your Company and , above all , about the main star , Danny Brook .
Corrected
I saw the show 's advertisement hanging up on a wall in London where I was spending my holiday with some friends . I convinced them to go there with me because I had heard good reviews of your Company and , above all , because of the main star , Danny Brook .


Lets see another example where we delete something (I looked for it in the file)

*consequently* needs to get deleted

```
S If you do n't agree , I will act consequently .
A 9 10|||U:ADV||||||REQUIRED|||-NONE-|||0
```

In [10]:
print("Wrong")
print(dataset['train'][8].get("source"))
print("Corrected")
print(dataset['train'][8].get("target"))
#print(dataset['validation'][0])
#print(dataset['test'][0])

Wrong
If you do n't agree , I will act consequently .
Corrected
If you do n't agree , I will act .


With punctuation

In this case, we need to add a comma and a quoute

```

S She began to read " Dear Carolin ..
A 4 4|||M:PUNCT|||,|||REQUIRED|||-NONE-|||0
A 8 8|||M:PUNCT|||"|||REQUIRED|||-NONE-|||0

```



In [11]:
print("Wrong")
print(dataset['train'][15].get("source"))
print("Corrected")
print(dataset['train'][15].get("target"))
#print(dataset['validation'][0])
#print(dataset['test'][0])

Wrong
She began to read " Dear Carolin ..
Corrected
She began to read , " Dear Carolin .. "


Now that we already have a parser and loader of the dataset for our needs using M2 file, we will create a simple preprocessor for the T5 model

# T5 PreProcessor

In [12]:
from transformers import PreTrainedTokenizer
from datasets import DatasetDict
from typing import Dict

class T5Preprocessor:
    """
    T5 Preprocessor for the data. Could be implemented as abstract class to have multiple preprocessors.

    Args:
        tokenizer (PreTrainedTokenizer): The tokenizer to use.
        max_lenght (int): The maximum lenght of the input. Default is 512.
        truncation (bool): Whether to truncate the input. Default is True.
        padding (str): The padding to use. Default is "max_length".

    """


    def __init__(self, tokenizer: PreTrainedTokenizer, truncation: bool = True, padding: str = "max_length"):
        self.tokenizer = tokenizer
        self.truncation = truncation
        self.padding = padding
        self.logger = logging.getLogger(self.__class__.__name__)

    def _preprocess_function(self, examples: Dict[str,str], max_length: int) -> Dict[str, List[List[int]]]:
        """
        Preprocess examples tokenizing them using the instance parameters.
        """
        inputs = ["correct grammar: " + s for s in examples['source']]
        targets = [t for t in examples['target']]
        model_inputs = self.tokenizer(inputs, max_length=max_length, truncation=self.truncation, padding=self.padding)
        labels = self.tokenizer(targets, max_length=max_length, truncation=self.truncation, padding=self.padding)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
    
    def _get_max_input_length(self, dataset: DatasetDict) -> int:
        """
        Get the maximum input length in the dataset.
        """
        max_length = 0
        for split in dataset:
            split_max = max(len(x) for x in dataset[split]['source'])
            self.logger.info(f"Max input_ids length in {split}: {split_max}")
            max_length = max(max_length, split_max)
            self.logger.info(f"Overall max input_ids length: {max_length}")
        return max_length
    
    
    def save_tokenized_dataset(self, dataset: DatasetDict, output_dir: str) -> None:
        """
        Save the tokenized dataset to disk.
        """
        self.logger.info(f"Saving tokenized dataset to {output_dir}...")
        dataset.save_to_disk(output_dir)
        self.logger.info("Tokenized dataset saved successfully.")

    def preprocess(self, dataset: DatasetDict, max_length: int) -> DatasetDict:
        """
        Preprocess the dataset.
        """
        self.logger.info("Preprocessing dataset...")
        tokenized_dataset = dataset.map(lambda examples: self._preprocess_function(examples, max_length), batched=True)
        self.logger.info("Dataset preprocessed")
        return tokenized_dataset

In [13]:
from transformers import T5Tokenizer
from datasets import load_from_disk

dataset = load_from_disk(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "parsed_fce_dataset"))

preprocessor = T5Preprocessor(tokenizer=T5Tokenizer.from_pretrained(MODEL_NAME))

max_length = preprocessor._get_max_input_length(dataset)  # This will give you the max input length in the dataset for saving the preprocessed dataset to train the model


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


2025-05-30 21:04:59,802 - INFO - T5Preprocessor - _get_max_input_length - Max input_ids length in train: 630
2025-05-30 21:04:59,803 - INFO - T5Preprocessor - _get_max_input_length - Overall max input_ids length: 630
2025-05-30 21:04:59,805 - INFO - T5Preprocessor - _get_max_input_length - Max input_ids length in validation: 510
2025-05-30 21:04:59,805 - INFO - T5Preprocessor - _get_max_input_length - Overall max input_ids length: 630
2025-05-30 21:04:59,807 - INFO - T5Preprocessor - _get_max_input_length - Max input_ids length in test: 454
2025-05-30 21:04:59,807 - INFO - T5Preprocessor - _get_max_input_length - Overall max input_ids length: 630


In [14]:
preprocessed_dataset = preprocessor.preprocess(dataset, max_length=max_length)
preprocessed_dataset.save_to_disk(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "preprocessed_fce_dataset"))

2025-05-30 21:04:59,812 - INFO - T5Preprocessor - preprocess - Preprocessing dataset...


Map: 100%|██████████| 28350/28350 [00:06<00:00, 4132.51 examples/s]
Map: 100%|██████████| 2191/2191 [00:00<00:00, 4314.68 examples/s]
Map: 100%|██████████| 2695/2695 [00:00<00:00, 4332.20 examples/s]

2025-05-30 21:05:07,839 - INFO - T5Preprocessor - preprocess - Dataset preprocessed



Saving the dataset (1/1 shards): 100%|██████████| 28350/28350 [00:00<00:00, 107178.48 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2191/2191 [00:00<00:00, 112135.40 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2695/2695 [00:00<00:00, 115969.36 examples/s]


Now what about the trainig class?

# Trainer class

In [15]:
from typing import Tuple, Optional
from transformers import T5ForConditionalGeneration, Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import DatasetDict
import torch


class T5Trainer:
    def __init__(self, model_name: str = MODEL_NAME, output_dir: str = FINETUNED_MODEL_OUTPUT_DIR,
                 logging_dir: str = "./finetune_logs", save_strategy: str = "epoch", resume_from_dir: Optional[str] = None,
                 batch_size: int = 8, mixed_precision: bool = False, early_stopping: bool = True,
                 eval_strategy: str = "epoch", early_stopping_patience: int = 3, early_stopping_threshold: float = 0.0, 
                 metric_for_best_model: str = "eval_loss", greater_is_better: bool = False):
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.output_dir = output_dir
        self.logging_dir = logging_dir
        self.save_strategy = save_strategy
        self.logger = logging.getLogger(self.__class__.__name__)
        self.resume_from_dir = resume_from_dir
        self.batch_size = batch_size
        self.mixed_precision = mixed_precision
        self.early_stopping = early_stopping
        self.eval_strategy = eval_strategy
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_threshold = early_stopping_threshold
        self.metric_for_best_model = metric_for_best_model
        self.greater_is_better = greater_is_better
        

    def _check_if_gpu_available(self) -> bool:
        """
        Check if GPU is available.
        """
        available = torch.cuda.is_available()

        if available:
            self.logger.info("GPU available")
        else:
            self.logger.warning("GPU not available")

        return available

    def _check_if_model_already_available(self) -> bool:
        """
        Check if model is already available.
        """
        available = os.path.exists(self.output_dir)

        if available:
            self.logger.info("Model already available")
        else:
            self.logger.warning("Model not available")

        return available

    def _create_unique_dir_for_model(self) -> str:

        """
        Create a unique directory for the model based on the current time.
        """
        current_time = time.strftime("%Y%m%d-%H%M%S")
        unique_dir = os.path.join(self.output_dir, current_time)
        os.makedirs(unique_dir, exist_ok=True)
        self.logger.info(f"Created unique directory for model: {unique_dir}")
        return unique_dir


    def train(self, tokenized_dataset:DatasetDict, epochs=3, learning_rate: float = 3e-4, eval_strategy = "epoch") -> Tuple[str, T5ForConditionalGeneration, T5Tokenizer]:
        """
        Train the T5 model.
        Args:
            tokenized_dataset (DatasetDict): The tokenized dataset.
            epochs (int): The number of epochs to train for.
        Returns:
            Tuple[Model_Dir (ID), T5ForConditionalGeneration, T5Tokenizer]: The trained model and tokenizer.
        """
        self.logger.info("Checking if GPU is available...")
        self._check_if_gpu_available()
        #self.logger.info("Checking if model is already available...")
        #self._check_if_model_already_available()
        # Determine the output directory based on whether we're resuming or starting fresh
        if self.resume_from_dir and os.path.exists(self.resume_from_dir):
            self.output_dir = self.resume_from_dir
            self.logger.info(f"Resuming training from directory: {self.output_dir}")
        else:
             self.output_dir = self._create_unique_dir_for_model()
             self.logger.info(f"Starting new training in directory: {self.output_dir}")

        # self.output_dir = self._create_unique_dir_for_model()


        training_args = TrainingArguments(
            output_dir=self.output_dir,
            per_device_train_batch_size=self.batch_size,
            num_train_epochs=epochs,
            eval_strategy = self.eval_strategy,
            save_strategy=self.save_strategy,
            logging_dir=self.logging_dir,
            learning_rate=learning_rate,
            report_to="none", #Needed to avoid wandb api key request.,
            # fp16=self.mixed_precision,  # Enable mixed precision training if specified
        )
        
        if self.early_stopping:
            training_args.load_best_model_at_end = True
            training_args.metric_for_best_model = self.metric_for_best_model
            training_args.greater_is_better = self.greater_is_better
            training_args.save_total_limit = 1
            training_args.eval_strategy = self.eval_strategy
        self.logger.info("Training arguments set")
        
        if self.mixed_precision:
            training_args.fp16 = True
            self.logger.info("Mixed precision training enabled")


        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized_dataset['train'],
            eval_dataset=tokenized_dataset['validation']
        )
        
        if self.early_stopping:
            trainer.add_callback(EarlyStoppingCallback(self.early_stopping_patience, self.early_stopping_threshold))
            self.logger.info("Early stopping callback added")

          # Resume training if a checkpoint exists in the output directory
        if self.resume_from_dir and os.path.exists(self.resume_from_dir):
             # The Trainer class automatically handles resuming from a directory if it exists
             # and contains a checkpoint. You don't need to explicitly load the state dicts
             # if you are using the Trainer's resume functionality.
             trainer.train(resume_from_checkpoint=self.output_dir)
        else:
            trainer.train()


        self.logger.info("T5 model trained")
        trainer.save_model(self.output_dir)
        self.model.save_pretrained(self.output_dir)
        self.tokenizer.save_pretrained(self.output_dir)
        return self.output_dir, self.model, self.tokenizer

In [16]:
def clean_empty_finetuned_model_dirs(output_dir: str = FINETUNED_MODEL_OUTPUT_DIR) -> None:
    """
    Clean empty directories in the output directory.
    Args:
        output_dir (str): The output directory to clean.
    """
    # Only check immediate children, do not traverse deeper
    for name in os.listdir(output_dir):
        path = os.path.join(output_dir, name)
        if os.path.isdir(path) and not os.listdir(path):
            os.rmdir(path)
            logger.info(f"Removed empty directory: {path}")
        else:
            logger.info(f"Directory {path} is not empty or not a directory, skipping.")

In [None]:
clean_empty_finetuned_model_dirs(os.getcwd() + "/models/t5_finetuned")

In [19]:
from datasets import load_from_disk

preprocessed_dataset = load_from_disk(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "preprocessed_fce_dataset"))


In [20]:
import torch, gc

gc.collect()
torch.cuda.empty_cache()

In [None]:


#trainer = T5Trainer(resume_from_dir="./t5_finetuned/20250530-172532")
# trainer = T5Trainer(batch_size=16, resume_from_dir="/Users/isaac/Developer/sample/t5_finetuned/20250530-121239/checkpoint-591")  # Enable mixed precision training

trainer = T5Trainer(batch_size=8)

model_dir, model, tokenizer = trainer.train(preprocessed_dataset)

In [None]:
model.save_pretrained("/Users/isaac/Developer/sample/t5_finetuned/20250530-121239")
tokenizer.save_pretrained("/Users/isaac/Developer/sample/t5_finetuned/20250530-121239")

('/Users/isaac/Developer/sample/t5_finetuned/20250530-121239/tokenizer_config.json',
 '/Users/isaac/Developer/sample/t5_finetuned/20250530-121239/special_tokens_map.json',
 '/Users/isaac/Developer/sample/t5_finetuned/20250530-121239/spiece.model',
 '/Users/isaac/Developer/sample/t5_finetuned/20250530-121239/added_tokens.json')

# Inference Engines

## Base Inference Engine and Helper

In [4]:
from abc import ABC, abstractmethod


class BaseInferenceEngine(ABC):
    def __init__(self):
        self.logger = logging.getLogger(self.__class__.__name__)
    @abstractmethod
    def correct_sentence(self, sentence: str) -> str:
        pass
    @abstractmethod
    def batch_correct(self, sentences: list, batch_size: int = 16) -> list:
        pass

## T5 Inference Engine

In [None]:
from typing import Dict, Union, Optional


class T5InferenceEngine(BaseInferenceEngine):
    def __init__(self, model_dir: str, max_length: Optional[int] = None):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.t5_model = T5ForConditionalGeneration.from_pretrained(model_dir)
        self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
        self.max_length = max_length if max_length is not None else self.tokenizer.model_max_length
        self.logger.info(f"T5 model loaded from {model_dir} with max length {self.max_length}")
        

    def correct_sentence(self, sentence:str) -> str:
        """
        Corrects a sentence using the T5 model.
        Args:
            sentence (str): The sentence to correct.
        Returns:
            str: The corrected sentence.
        """
        self.logger.info(f"Correcting sentence: {sentence}")
        inputs = self.tokenizer("correct grammar: " + sentence, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length)
        outputs = self.t5_model.generate(**inputs, max_length=self.max_length)
        corrected_sentence = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        self.logger.info(f"Corrected sentence: {corrected_sentence}")
        return corrected_sentence
    
    def batch_correct(self, sentences: list, batch_size: int = 16) -> list:
        """
        Corrects a batch of sentences using the T5 model.
        Args:
            sentences (list): List of sentences to correct.
            batch_size (int): Number of sentences per batch.
        Returns:
            list: List of corrected sentences.
        """
        corrected = []
        for i in range(0, len(sentences), batch_size):
            self.logger.info(f"Processing batch {i // batch_size + 1} with size {min(batch_size, len(sentences) - i)}")
            batch = sentences[i:i+batch_size]
            inputs = self.tokenizer(
                ["correct grammar: " + s for s in batch],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length
            )
            outputs = self.t5_model.generate(**inputs, max_length=self.max_length)
            batch_corrected = [self.tokenizer.decode(o, skip_special_tokens=True) for o in outputs]
            corrected.extend(batch_corrected)
        self.logger.info(f"Batch correction completed. Total corrected sentences: {len(corrected)}")
        return corrected

In [9]:
model_dir = "/Users/isaac/Developer/GEC-system/models/finished"
t5_inference_engine = T5InferenceEngine(model_dir=model_dir, max_length=650)
t5_inference_engine.max_length

2025-05-31 09:56:54,110 - INFO - T5InferenceEngine - __init__ - T5 model loaded from /Users/isaac/Developer/GEC-system/models/finished with max length 650


650

In [13]:
wrong_sentence = "You is a apple."
theorycally_corrected_sentence = "I have an apple."


corrected_sentence = t5_inference_engine.correct_sentence(wrong_sentence)

print(f"Wrong sentence: {wrong_sentence}")
print(f"Corrected sentence: {corrected_sentence}")

2025-05-31 09:57:14,637 - INFO - T5InferenceEngine - correct_sentence - Correcting sentence: You is a apple.
2025-05-31 09:57:14,715 - INFO - T5InferenceEngine - correct_sentence - Corrected sentence: You are an apple.
Wrong sentence: You is a apple.
Corrected sentence: You are an apple.


In [163]:
#%pip install nltk
import nltk
nltk.download('punkt')


[nltk_data] Downloading package punkt to /Users/isaac/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

## Llama 3 Inference Engine

See docs to run and query llama3 8b locally using ollama

https://ollama.com/library/llama3

In [43]:
from typing import Union
import json as _json
import asyncio
import aiohttp
import requests

class Llama3InferenceEngine(BaseInferenceEngine):
    model_name: str = "llama3"
    stream: bool = False
    response_format: dict = {
        "type": "object",
        "properties": {
            "original_text": {"type": "string"},
            "corrected_text": {"type": "string"}
        },
        "required": ["original_text", "corrected_text"]
    }
    
    def __init__(self, model_endpoint: str, prompt_path: str):
        self.logger = logging.getLogger(self.__class__.__name__)
        self.model_endpoint = model_endpoint
        self.prompt_path = prompt_path
        self.prompt = self._parse_prompt()
        
    def _parse_prompt(self) -> str:
        with open(self.prompt_path, 'r') as file:
            prompt = file.read()
        if not prompt:
            self.logger.error("Prompt is empty. Please check the prompt file.")
            raise ValueError("Prompt is empty. Please check the prompt file.")
        return prompt
    
    def _replace_prompt_variables(self, sentence: str) -> str:
        return self.prompt.replace(f"{TEXT_TO_REPLACE_IN_PROMPT}", sentence)

    def send_correct_request(self, sentence: str) -> Dict[str, Union[str, Dict[str, str]]]:
        prompt = self._replace_prompt_variables(sentence)
        try:
            response = requests.post(
                self.model_endpoint, 
                json={
                    "model": self.model_name,
                    "prompt": str(prompt),
                    "stream": self.stream,
                    "format": self.response_format
                }
            )
            response_data = response.json()
            if not response_data:
                self.logger.error("No corrected sentence returned from the model.")
                raise ValueError("No corrected sentence returned from the model.")
            if "response" in response_data and isinstance(response_data["response"], str):
                response_data["response"] = _json.loads(response_data["response"])
            return response_data
        except requests.exceptions.RequestException as e:
            self.logger.error(f"Error during model inference: {e}")
            raise RuntimeError(f"Error during model inference: {e}")
        
    def correct_sentence(self, sentence: str) -> str:
        self.logger.info(f"Correcting sentence: {sentence}")
        response = self.send_correct_request(sentence)
        if isinstance(response, dict):
            corrected_sentence = response.get("response", {}).get("corrected_text", "")
            self.logger.info(f"Corrected sentence: {corrected_sentence}")
            return corrected_sentence
        else:
            return ""
        
    def batch_correct(self, sentences: list, batch_size: int = 16) -> list:
        """Not implemented for Llama3InferenceEngine as it does not support full batch inference yet.
        This method is a placeholder to maintain interface consistency with BaseInferenceEngine.

        Args:
            sentences (list): _description_
            batch_size (int, optional): _description_. Defaults to 16.

        Returns:
            list: _description_
        """
        return super().batch_correct(sentences, batch_size)

    # --- ASYNC BATCH INFERENCE ---
    async def async_correct_sentence(self, session, sentence: str) -> str:
        prompt = self._replace_prompt_variables(sentence)
        payload = {
            "model": self.model_name,
            "prompt": str(prompt),
            "stream": self.stream,
            "format": self.response_format
        }
        self.logger.info(f"Sending async request for sentence: {sentence}")
        async with session.post(self.model_endpoint, json=payload) as resp:
            response_data = await resp.json()
            if not response_data:
                return ""
            if "response" in response_data and isinstance(response_data["response"], str):
                import json as _json
                response_data["response"] = _json.loads(response_data["response"])
                corrected_text = response_data.get("response", {}).get("corrected_text", "")
                self.logger.info(f"Async corrected sentence: {corrected_text}")
            
                return corrected_text if corrected_text else ""
            else:
                self.logger.error("Response format is not as expected.")
                return ""

    async def async_batch_correct(self, sentences, max_concurrent=5):
        timeout = aiohttp.ClientTimeout(total=60)
        semaphore = asyncio.Semaphore(max_concurrent)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async def sem_task(sentence):
                async with semaphore:
                    return await self.async_correct_sentence(session, sentence)
            tasks = [sem_task(s) for s in sentences]
            return await asyncio.gather(*tasks)

In [44]:
llama3_engine = Llama3InferenceEngine(model_endpoint=LLAMA3_ENDPOINT, prompt_path=PROMPT_PATH)

In [45]:
llama3_engine.send_correct_request("Helo")

{'model': 'llama3',
 'created_at': '2025-05-31T17:36:30.449223Z',
 'response': {'original_text': 'Helo', 'corrected_text': 'Hello'},
 'done': True,
 'done_reason': 'stop',
 'context': [128006,
  882,
  128007,
  271,
  2675,
  527,
  459,
  6335,
  6498,
  11311,
  11397,
  13,
  4718,
  3465,
  374,
  311,
  4495,
  904,
  69225,
  62172,
  11,
  43529,
  11,
  477,
  62603,
  6103,
  304,
  279,
  2768,
  11914,
  13,
  3234,
  539,
  2349,
  279,
  7438,
  477,
  1742,
  315,
  279,
  11914,
  13,
  8442,
  471,
  279,
  37065,
  11914,
  11,
  449,
  912,
  5217,
  16540,
  382,
  85664,
  25,
  16183,
  78,
  128009,
  128006,
  78191,
  128007,
  271,
  90,
  330,
  10090,
  4424,
  794,
  330,
  39,
  20782,
  498,
  330,
  20523,
  291,
  4424,
  794,
  330,
  9906,
  1,
  335],
 'total_duration': 4884461166,
 'load_duration': 10360125,
 'prompt_eval_count': 63,
 'prompt_eval_duration': 165838875,
 'eval_count': 19,
 'eval_duration': 893736917}

In [215]:
llama3_engine.correct_sentence("Helo")

2025-05-31 00:01:11,833 - INFO - Llama3InferenceEngine - correct_sentence - Correcting sentence: Helo
2025-05-31 00:01:12,732 - INFO - Llama3InferenceEngine - correct_sentence - Corrected sentence: Hello


'Hello'

# Evaluator Class

In [46]:
from sklearn.metrics import accuracy_score
from nltk.translate.gleu_score import sentence_gleu
from typing import Optional
import re

class Evaluator:
    """ Evaluator class for evaluating the performance of the inference engine on a dataset.
    Args:
        inference_engine (BaseInferenceEngine): The inference engine to use for evaluation.
        n_samples (Optional[int]): Number of samples to evaluate. If None, evaluates the entire test set. Default is None.
        predicted_dataset (Optional[DatasetDict]): A precomputed dataset with predictions. If provided, it will use this dataset instead of running inference. Default is None.
        dataset (DatasetDict): The dataset to evaluate. It also runs inference.
    """
    def __init__(self, inference_engine: BaseInferenceEngine, n_samples: Optional[int] = None, predicted_dataset: Optional[DatasetDict] = None, dataset: Optional[DatasetDict] = None):
        # self.dataset = dataset
        self.engine = inference_engine
        self.n_samples = n_samples
        self.test_data, self.references, self.predictions = None, None, None
        if dataset is None and predicted_dataset is None:
            raise ValueError("Either dataset or predicted_dataset must be provided.")
        if predicted_dataset is not None and isinstance(predicted_dataset, DatasetDict):
            self.test_data = predicted_dataset['test']['source']
            self.references = predicted_dataset['test']['target']
            self.predictions = predicted_dataset['test']['prediction']
        elif dataset is not None and isinstance(dataset, DatasetDict):
            self.dataset = dataset
        self.logger = logging.getLogger(self.__class__.__name__)

    @staticmethod
    def normalize_text(text):
        # Remove spaces before punctuation and ensure single space after
        text = re.sub(r'\s+([.,!?;:"])', r'\1', text)
        text = re.sub(r'([.,!?;:"])([^\s])', r'\1 \2', text)
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    def _get_samples(self) -> Tuple[List[str], List[str], List[str]]:
        """ Fetch samples from the test set."""
        test_data = self.dataset['test'] if self.test_data is None else self.dataset['test']
        n = min(self.n_samples, len(test_data)) if (self.n_samples is not None and self.n_samples > 0) else len(test_data)
        self.logger.info(f"Fetching {n} samples from the test set.")
        sample_sentences = [test_data[i]['source'] for i in range(n)]
        references = [test_data[i]['target'] for i in range(n)]
        predictions = self.engine.batch_correct(sample_sentences)
        return sample_sentences, references, predictions
    
    def  _get_samples_if_not_available(self):
        """
        Check if samples are available.
        """
        if self.test_data is None or self.references is None or self.predictions is None:
            self.logger.warning("Samples not available. Fetching samples...")
            self.test_data, self.references, self.predictions = self._get_samples()
        else:
            self.logger.info("Samples already available. Using cached samples.")

    def evaluate_accuracy(self) -> float:
        """ 
        Evaluate the accuracy of the predictions against the references.
        It computes the exact match accuracy.
        Returns:
            float: The exact match accuracy.
        """
        self._get_samples_if_not_available()
        norm_refs = [self.normalize_text(ref) for ref in self.references]
        norm_preds = [self.normalize_text(pred) for pred in self.predictions]
        accuracy = accuracy_score(norm_refs, norm_preds)
        self.logger.info(f"Exact match accuracy on test set: {accuracy:.4f}")
        return accuracy

    def evaluate_gleu(self):
        """
        Evaluate the GLEU score of the predictions against the references.
        It uses the nltk library to compute the GLEU score.
        Returns:
            float: The average GLEU score across all samples.
            
        """
        self._get_samples_if_not_available()
        gleu_scores = []
        for pred, ref in zip(self.predictions, self.references):
            ref_tokens = self.normalize_text(ref).split()
            pred_tokens = self.normalize_text(pred).split()
            gleu = sentence_gleu([ref_tokens], pred_tokens)
            gleu_scores.append(gleu)
        avg_gleu = sum(gleu_scores) / len(gleu_scores) if gleu_scores else 0.0
        self.logger.info(f"Average GLEU score on test set: {avg_gleu:.4f}")
        return avg_gleu

    # --- ASYNC BATCH INFERENCE FOR LLAMA3 ---
    async def evaluate_accuracy_async(self)-> Optional[float]:
        """ Evaluate the accuracy of the predictions against the references using async batch inference.
        It computes the exact match accuracy.
        This method is specifically designed for engines that support async batch inference.

        Returns:
            Optional[float]: The exact match accuracy if async batch inference is available, otherwise None.
        """
        # Fetch samples (sentences and references)
        test_data = self.dataset['test']
        n = min(self.n_samples, len(test_data)) if (self.n_samples is not None and self.n_samples > 0) else len(test_data)
        self.logger.info(f"Fetching {n} samples from the test set (async).")
        sample_sentences = [test_data[i]['source'] for i in range(n)]
        references = [test_data[i]['target'] for i in range(n)]
        # Use async batch correct if available
        if hasattr(self.engine, "async_batch_correct"):
            predictions = await self.engine.async_batch_correct(sample_sentences)
            self.test_data, self.references, self.predictions = sample_sentences, references, predictions
            norm_refs = [self.normalize_text(ref) for ref in self.references]
            norm_preds = [self.normalize_text(pred) for pred in self.predictions]
            accuracy = accuracy_score(norm_refs, norm_preds)
            self.logger.info(f"Exact match accuracy on test set: {accuracy:.4f}")
            return accuracy
        else:
            self.logger.warning("Async batch inference not available for this engine.")
            return None

    async def evaluate_gleu_async(self)-> Optional[float]:
        """ Evaluate the GLEU score of the predictions against the references using async batch inference.
        It uses the nltk library to compute the GLEU score.
        This method is specifically designed for engines that support async batch inference.

        Returns:
            Optional[float]: The average GLEU score across all samples if async batch inference is available, otherwise None.
        """
        # Fetch samples (sentences and references)
        test_data = self.dataset['test']
        n = min(self.n_samples, len(test_data)) if (self.n_samples is not None and self.n_samples > 0) else len(test_data)
        self.logger.info(f"Fetching {n} samples from the test set (async).")
        sample_sentences = [test_data[i]['source'] for i in range(n)]
        references = [test_data[i]['target'] for i in range(n)]
        # Use async batch correct if available
        if hasattr(self.engine, "async_batch_correct"):
            predictions = await self.engine.async_batch_correct(sample_sentences)
            self.test_data, self.references, self.predictions = sample_sentences, references, predictions
            gleu_scores = []
            for pred, ref in zip(self.predictions, self.references):
                ref_tokens = self.normalize_text(ref).split()
                pred_tokens = self.normalize_text(pred).split()
                gleu = sentence_gleu([ref_tokens], pred_tokens)
                gleu_scores.append(gleu)
            avg_gleu = sum(gleu_scores) / len(gleu_scores) if gleu_scores else 0.0
            self.logger.info(f"Average GLEU score on test set: {avg_gleu:.4f}")
            return avg_gleu
        else:
            self.logger.warning("Async batch inference not available for this engine.")
            return None

## Evaluating the test data (FCE with T5)

In [None]:
## This dataset was generated using the following command:
# python3 -m scripts.predictions_builder medical t5 fce /Users/isaac/Developer/GEC-system/data/fce_predicted.csv

import pandas as pd
predicted_df = pd.read_csv(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce_predicted.csv"))
predicted_dataset = DatasetDict({
    'test': Dataset.from_pandas(predicted_df)
})


# If you don't want to use it, comment the previous lines and uncomment the following 
# preprocessed_dataset = load_from_disk(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "preprocessed_fce_dataset"))
# from datasets import load_from_disk
# preprocessed_dataset = load_from_disk(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "preprocessed_fce_dataset"))
# t5_inference_engine = T5InferenceEngine(model_dir=model_dir, max_length=650)

# t5_evaluator= Evaluator(
#     inference_engine=t5_inference_engine,
#     dataset=preprocessed_dataset
# )
# accuracy = t5_evaluator.evaluate_accuracy()
# print("Accuracy:", accuracy)
# gleu = t5_evaluator.evaluate_gleu()
# print("GLEU:", gleu)


In [None]:

t5_evaluator = Evaluator(t5_inference_engine, predicted_dataset=predicted_dataset)

In [32]:
t5_evaluator.evaluate_accuracy()

2025-05-31 11:03:37,431 - INFO - Evaluator - _get_samples_if_not_available - Samples already available. Using cached samples.
2025-05-31 11:03:37,489 - INFO - Evaluator - evaluate_accuracy - Exact match accuracy on test set: 0.3800


0.3799628942486085

In [33]:
t5_evaluator.evaluate_gleu()

2025-05-31 11:03:46,076 - INFO - Evaluator - _get_samples_if_not_available - Samples already available. Using cached samples.
2025-05-31 11:03:46,226 - INFO - Evaluator - evaluate_gleu - Average GLEU score on test set: 0.7836


0.7835986906022476

## Evaluating test data (FCE) with Llama 3

In [151]:
llama3_evaluator = Evaluator(preprocessed_dataset, llama3_engine, n_samples=50)
llama3_evaluator.evaluate_accuracy()

2025-05-30 22:23:10,370 - INFO - Evaluator - _get_samples - Fetching 50 samples from the test set.
2025-05-30 22:23:10,443 - INFO - Llama3InferenceEngine - correct_sentence - Correcting sentence: Dear Mrs Smith ,
2025-05-30 22:23:14,485 - INFO - Llama3InferenceEngine - correct_sentence - Corrected sentence: Dear Ms. Smith,
2025-05-30 22:23:14,486 - INFO - Llama3InferenceEngine - correct_sentence - Correcting sentence: I am sad to read about Richard not being at his best .
2025-05-30 22:23:15,893 - INFO - Llama3InferenceEngine - correct_sentence - Corrected sentence: I am sad to hear that Richard is not doing well.
2025-05-30 22:23:15,894 - INFO - Llama3InferenceEngine - correct_sentence - Correcting sentence: I hope that he will recover soon and that he will make it to our conference .
2025-05-30 22:23:17,417 - INFO - Llama3InferenceEngine - correct_sentence - Corrected sentence: I hope that he will soon recover and be able to attend our conference.
2025-05-30 22:23:17,418 - INFO - Lla

0.12

In [152]:
llama3_evaluator.evaluate_gleu()

2025-05-30 22:25:14,157 - INFO - Evaluator - _get_samples_if_not_available - Samples already available. Using cached samples.
Average GLEU score on test set: 0.5892


0.589201589384381

In [148]:
llama3_evaluator.test_data

['Dear Mrs Smith ,',
 'I am sad to read about Richard not being at his best .',
 'I hope that he will recover soon and that he will make it to our conference .',
 'As far as organisation is concerned , adequate number of double rooms with shower or bath have been booked for your group at the Palace hotel located across the main street from the train station .',
 'To come to our college , you can either watch downhill to the lake front ( about 10 min ) or take the new cable car from the train station .',
 'The college is just stairs away from the bottom end .',
 'The end - of - conference party will take place on a boat cruising on the lake .',
 'It will start by a speech from the Director of the conference , followed by a meal .',
 'The recreative part will begin at 10 pm with a short musical comedy performed by some of our students and end with fireworks launched from the roof of the college at 2 am .',
 'As nights may be cold at this time of the year , may I suggest that you take wit

In [147]:
llama3_evaluator.predictions

['Dear Ms. Smith,',
 'I am sad to hear that Richard is not at his best',
 'I hope that he will soon recover and be able to attend our conference.',
 'As far as organization is concerned, an adequate number of double rooms with a shower or bath have been booked for your group at the Palace Hotel, located across the main street from the train station.',
 'To come to our college, you can either walk down to the lake front (about 10 minutes) or take the new cable car from the train station.',
 'The college is just a few steps away from the other end.',
 'The end-of-conference party will take place on a boat that cruises on the lake.',
 'It will start with a speech from the director of the conference, followed by a meal.',
 'The recreational part will begin at 10:00 pm with a short musical comedy performed by some of our students, and it will end with fireworks launched from the roof of the college at 2:00 am.',
 'As nights may be cold at this time of year, may I suggest that you take a pul

In [149]:
llama3_evaluator.references

['Dear Mrs Smith ,',
 'I am sad to read about Richard not being well .',
 'I hope that he will recover soon and that he will make it to our conference .',
 'As far as the organisation is concerned , an adequate number of double rooms with shower or bath have been booked for your group at the Palace Hotel , located across the main street from the train station .',
 'To come to our college , you can either walk downhill to the lakeside ( about 10 mins ) or take the new cable car from the train station .',
 'The college is just steps away from the bottom end .',
 'The end - of - conference party will take place on a boat cruising on the lake .',
 'It will start with a speech from the Director of the conference , followed by a meal .',
 'The recreational part will begin at 10 pm with a short musical comedy performed by some of our students and end with fireworks launched from the roof of the college at 2 am .',
 'As nights may be cold at this time of the year , may I suggest that you take 

# Evaluating our medical data

Finally we want to evaluate with our medical data to see how it behaves...

In [204]:
import pandas as pd
from datasets import Dataset

# Read your medical data CSV (update the path as needed)
medical_df = pd.read_csv("/Users/isaac/Developer/GEC-system/data/data.csv")  # columns should be ['source', 'target']
medical_df.rename(columns={'incorrect_sentence': 'source', 'correct_sentence': 'target'}, inplace=True)

# Convert to HuggingFace Dataset
medical_data = Dataset.from_pandas(medical_df)
medical_dataset_dict = DatasetDict({'test': medical_data})


medical_data

Dataset({
    features: ['source', 'target'],
    num_rows: 204
})

## Evaluating with T5

In [208]:
t5_evaluator = Evaluator(medical_dataset_dict, t5_inference_engine)

In [209]:
t5_evaluator.evaluate_accuracy()

2025-05-30 23:54:37,918 - INFO - Evaluator - _get_samples - Fetching 204 samples from the test set.
2025-05-30 23:54:37,928 - INFO - T5InferenceEngine - correct_sentence - Correcting sentence: The patient report severe dyspnea and bilateral lower extremity edema following administration of intravenous furosemide.
2025-05-30 23:54:38,184 - INFO - T5InferenceEngine - correct_sentence - Corrected sentence: The patient report severe dyspnea and bilateral lower extremity edema following administration of intravenous furosemide.
2025-05-30 23:54:38,185 - INFO - T5InferenceEngine - correct_sentence - Correcting sentence: No signs of cellulitis or necrotizing fasciitis were find during the comprehensive dermatological examination.
2025-05-30 23:54:38,377 - INFO - T5InferenceEngine - correct_sentence - Corrected sentence: No signs of cellulitis or necrotizing fasciitis were found during the comprehensive dermatological examination.
2025-05-30 23:54:38,377 - INFO - T5InferenceEngine - correct_se

0.1323529411764706

In [210]:
t5_evaluator.evaluate_gleu()

2025-05-30 23:55:26,156 - INFO - Evaluator - _get_samples_if_not_available - Samples already available. Using cached samples.
Average GLEU score on test set: 0.7898


0.7898051106345906

## Evaluating with Llama3

In [217]:
llama3_evaluator = Evaluator(medical_dataset_dict, llama3_engine)
llama3_evaluator.evaluate_accuracy()

2025-05-31 00:01:26,815 - INFO - Evaluator - _get_samples - Fetching 204 samples from the test set.
2025-05-31 00:01:26,832 - INFO - Llama3InferenceEngine - correct_sentence - Correcting sentence: The patient report severe dyspnea and bilateral lower extremity edema following administration of intravenous furosemide.
2025-05-31 00:01:29,151 - INFO - Llama3InferenceEngine - correct_sentence - Corrected sentence: The patient reported severe dyspnea and bilateral lower extremity edema following the administration of intravenous furosemide.
2025-05-31 00:01:29,151 - INFO - Llama3InferenceEngine - correct_sentence - Correcting sentence: No signs of cellulitis or necrotizing fasciitis were find during the comprehensive dermatological examination.
2025-05-31 00:01:31,135 - INFO - Llama3InferenceEngine - correct_sentence - Corrected sentence: No signs of cellulitis or necrotizing fasciitis were found during the comprehensive dermatological examination.
2025-05-31 00:01:31,136 - INFO - Llama3In

0.029411764705882353

In [218]:
llama3_evaluator.evaluate_gleu()

2025-05-31 00:08:24,623 - INFO - Evaluator - _get_samples_if_not_available - Samples already available. Using cached samples.
2025-05-31 00:08:24,640 - INFO - Evaluator - evaluate_gleu - Average GLEU score on test set: 0.6182


0.6182389959988361

## Using ERRANT Scorer on test data

### T5

The data is already predicted and stored as csv in data/fce_predicted.csv. Now we're gonna convert the predictions as txt as expected

In [63]:
import pandas as pd

# Read your medical data CSV (update the path as needed)
fce_predicted_df_t5 = pd.read_csv(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce_predicted.csv"))  # columns should be ['source', 'target']
with open(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce_predicted_t5.txt"), 'w') as f:
    for sentence in fce_predicted_df_t5['prediction']:
        f.write(Evaluator.normalize_text(sentence) + '\n')

with open(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce_wrong.txt"), 'w') as f:
    for sentence in fce_predicted_df_t5['source']:
        f.write(Evaluator.normalize_text(sentence) + '\n')


Generate system M2

In [57]:
!pip install --upgrade --force-reinstall numpy h5py

Collecting numpy
  Using cached numpy-2.0.2-cp39-cp39-macosx_14_0_arm64.whl (5.3 MB)
Collecting h5py
  Downloading h5py-3.13.0-cp39-cp39-macosx_11_0_arm64.whl (2.9 MB)
[K     |████████████████████████████████| 2.9 MB 1.6 MB/s eta 0:00:01
[?25hInstalling collected packages: numpy, h5py
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
  Attempting uninstall: h5py
    Found existing installation: h5py 3.8.0
    Uninstalling h5py-3.8.0:
      Successfully uninstalled h5py-3.8.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-macos 2.12.0 requires numpy<1.24,>=1.22, but you have numpy 2.0.2 which is incompatible.
scipy 1.10.1 requires numpy<1.27.0,>=1.19.5, but you have numpy 2.0.2 which is incompatible.[0m
Successfully installed h5py-3.13.0 numpy-2.

In [65]:
!errant_parallel -orig ./data/fce_wrong.txt -cor ./data/fce_predicted_t5.txt -out ./data/m2/fce_predicted_t5.m2

Loading resources...
Processing parallel files...


Run it as script

In [66]:
!errant_compare -hyp ./data/m2/fce_predicted_t5.m2 -ref ./data/fce/m2/fce.test.gold.bea19.m2


TP	FP	FN	Prec	Rec	F0.5
679	1951	3870	0.2582	0.1493	0.2253



### LLAMA

In [None]:
### LLAMA
import pandas as pd

# Read your medical data CSV (update the path as needed)
fce_predicted_df_llama = pd.read_csv(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce_predicted_llama.csv"))  # columns should be ['source', 'target']
with open(os.path.join(FCE_DOWNLOAD_DATASET_DIR, "fce_predicted_llama.txt"), 'w') as f:
    for sentence in fce_predicted_df_llama['prediction']:
        f.write(str(sentence) + '\n')


In [72]:
!errant_parallel -orig ./data/fce_wrong.txt -cor ./data/fce_predicted_llama.txt -out ./data/m2/fce_predicted_llama.m2

Loading resources...
Processing parallel files...


In [73]:
!errant_compare -hyp ./data/m2/fce_predicted_llama.m2 -ref ./data/fce/m2/fce.test.gold.bea19.m2


TP	FP	FN	Prec	Rec	F0.5
914	6726	3635	0.1196	0.2009	0.1302

