In [1]:
from gramformer import Gramformer
import torch
import spacy
import random

### Original GrammarModel

In [48]:
class GrammarModel(Gramformer):
    """
    Grammar correction model.
    """
    def __init__(self, models=1, use_gpu=False, seed=1212):
        self.gm = super().__init__(models=1, use_gpu=False)


    def grammar_correction(self,last_user_input):
        """
        Generate a corrected sentence and a message to the user with the correction.
        """
        corrected_sentence = self.correct(last_user_input, max_candidates=1)
        corrected_sentence = list(corrected_sentence)[0]
        message_styles = [
            "I think you meant: ",
            "Oh, you mean: ",
            "This would be better said like this: "
        ]

        if corrected_sentence != last_user_input:
            correction_message = f"{random.choice(message_styles)} \"{corrected_sentence}\" "
        else:
             correction_message = None

        return corrected_sentence, correction_message


    def add_correction_to_chat_history(self, chat_history):
        """
        Append the message to the user to the chat history.
        Return the corrected sentence.
        """
        last_user_input = chat_history[-1].get('text')
        corrected_sentence, correction_message = self.grammar_correction(last_user_input)
        error_types = self.get_edits(last_user_input, corrected_sentence)

        if correction_message:
            chat_history.append(
                {
                    'sender': 'bot',
                    'text': correction_message,
                    'correction': True
                }
            )
        return chat_history       


    def _get_edits(self, input_sentence, corrected_sentence):
        """
        Return the type of the error.
        """
        input_sentence = self.annotator.parse(input_sentence)
        corrected_sentence = self.annotator.parse(corrected_sentence)
        alignment = self.annotator.align(input_sentence, corrected_sentence)
        edits = self.annotator.merge(alignment)

        if len(edits) == 0:  
            return []

        edit_annotations = []
        for e in edits:
            e = self.annotator.classify(e)
            edit_annotations.append(e.type[2:])
                
        if len(edit_annotations) > 0:
            return edit_annotations
        else:    
            return []

In [4]:
gm = GrammarModel(models = 1, use_gpu=False)

[Gramformer] Grammar error correct/highlight model loaded..


### Showcase punctuation and casing errors (GrammarModel)

In [5]:
# Example sentences
ex1= "Hi bot!" # should not be corrected to "Bot"
ex2= "Hello" # should not be corrected to "Hello." or "Hello!"

In [32]:
error_types_ex1 = []
error_types_ex2 = []
corrected_sentences_ex1 = []
corrected_sentences_ex2 = []

In [43]:
correct_sentence, message = gm.grammar_correction(ex2)
corrected_sentences_ex2.append(correct_sentence)

error_types = gm.get_edits(ex2, correct_sentence)
error_types_ex2.extend(error_types)

print(correct_sentence, error_types)

Hello. ['OTHER']


In [45]:
# summary of tracked errors for sentence ex1
print(f"Original: {ex1}\nError Types: {error_types_ex1}\nSuggested Corrections: {corrected_sentences_ex1}")

Original: Hi bot!
Error Types: ['OTHER', 'ORTH', 'OTHER', 'NOUN', 'OTHER', 'OTHER', 'ORTH']
Suggested Corrections: ['Hi Bi-Bo!', 'Hi Bot!', 'Hi!', 'Hello BOTH!', 'Hi Bot!', 'Hi bot!']


In [46]:
# summary of tracked errors for sentence ex2
print(f"Original: {ex2}\nError Types: {error_types_ex2}\nSuggested Corrections: {corrected_sentences_ex2}")

Original: Hello
Error Types: ['OTHER', 'OTHER', 'OTHER', 'OTHER', 'OTHER', 'OTHER']
Suggested Corrections: ['Hello!', 'Hello.', 'Hello.', 'Hello.', 'Hello.', 'Hello.']


### Remove correction for error types ORTH, OTHER (and PUNCT?)

In [327]:
class GrammarModel2(Gramformer):
    """
    Grammar correction model.
    """
    def __init__(self, models=1, use_gpu=False, seed=1212):
        self.gm = super().__init__(models=1, use_gpu=False)
        self.ignore_errors = ['OTHER', 'ORTH']


    def grammar_correction(self,last_user_input):
        """
        Generate a corrected sentence and a message to the user with the correction.
        """
        corrected_sentence = self.correct(last_user_input, max_candidates=1)
        corrected_sentence = list(corrected_sentence)[0]
        message_styles = [
            "I think you meant: ",
            "Oh, you mean: ",
            "This would be better said like this: "
        ]

        if corrected_sentence != last_user_input:
            correction_message = f"{random.choice(message_styles)} \"{corrected_sentence}\" "
        else:
             correction_message = None

        return corrected_sentence, correction_message


    def add_correction_to_chat_history(self, chat_history):
        """
        Append the message to the user to the chat history.
        Return the corrected sentence.
        """
        last_user_input = chat_history[-1].get('text')
        corrected_sentence, correction_message = self.grammar_correction(last_user_input)
        error_types = self.get_edits(last_user_input, corrected_sentence)
        overlap_ignore_errors = any(item in error_types for item in self.ignore_errors)

        if correction_message and (overlap_ignore_errors is False):
            chat_history.append(
                {
                    'sender': 'bot',
                    'text': correction_message,
                    'correction': True,
                    'error_type': error_types
                }
            )
        return chat_history       


    def _get_edits(self, input_sentence, corrected_sentence):
        """
        Return the type of the error.
        """
        input_sentence = self.annotator.parse(input_sentence)
        corrected_sentence = self.annotator.parse(corrected_sentence)
        alignment = self.annotator.align(input_sentence, corrected_sentence)
        edits = self.annotator.merge(alignment)

        if len(edits) == 0:  
            return []

        edit_annotations = []
        for e in edits:
            e = self.annotator.classify(e)
            edit_annotations.append(e.type[2:])
                
        if len(edit_annotations) > 0:
            return edit_annotations
        else:    
            return []

In [328]:
gm2 = GrammarModel2(models=1, use_gpu=False)

[Gramformer] Grammar error correct/highlight model loaded..


In [341]:
chat_history_ex1 = [{'sender': 'User', 'text': 'Hi bot'}]

In [342]:
chat_history = gm2.add_correction_to_chat_history(chat_history_ex1)

In [343]:
chat_history

[{'sender': 'User', 'text': 'Hi bot'},
 {'sender': 'bot',
  'text': 'This would be better said like this:  "Hi booch!" ',
  'correction': True,
  'error_type': ['NOUN']}]