In this script we try to optimize the T5 transformer which is directly behind the gramformer. Results seem to be worse.

In [3]:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
import time
from gramformer import Gramformer

class OptimizedGrammarCorrector:
    def __init__(self):
        # Load model and tokenizer directly
        self.model_name = "prithivida/grammar_error_correcter_v1"
        self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
        
        # Optimize model for inference
        self.model.eval()
        
    def correct(self, sentence, max_length=128):
        # Directly process the input
        with torch.inference_mode():
            input_ids = self.tokenizer(f"gec: {sentence}", 
                                     return_tensors="pt", 
                                     padding=True).input_ids
            
            # Generate output ids
            outputs = self.model.generate(
                input_ids=input_ids,
                max_length=max_length,
                num_beams=5,
                early_stopping=True
            )
            
            # Decode and return
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

# Test function
def test_custom_corrector():
    incorrect = "He have been working on this project for three year."
    
    # Initialize and time our custom corrector
    start_time = time.time()
    corrector = OptimizedGrammarCorrector()
    corrected = corrector.correct(incorrect)
    custom_time = time.time() - start_time
    
    print(f"Custom implementation: {custom_time:.4f} seconds")
    print(f"Corrected: {corrected}")
    
    # Compare with original Gramformer
    start_time = time.time()
    gf = Gramformer(models=1, use_gpu=False)
    original_corrected = list(gf.correct(incorrect, max_candidates=1))[0]
    gramformer_time = time.time() - start_time
    
    print(f"\nOriginal Gramformer: {gramformer_time:.4f} seconds")
    print(f"Corrected: {original_corrected}")
    
    # Calculate speedup
    if custom_time > 0:
        speedup = gramformer_time / custom_time
        print(f"\nSpeedup: {speedup:.2f}x")

# Run the test
test_custom_corrector()

Custom implementation: 3.9616 seconds
Corrected: He has been working on this project for three years.
[Gramformer] Grammar error correct/highlight model loaded..

Original Gramformer: 4.8940 seconds
Corrected: He has been working on this project for three years.

Speedup: 1.24x
