This approach uses onnx runtime for optimization and it delivers a slight enhanced performance of the inference time. However it is not significant, particularly when the model is cached.

In [3]:
import torch
import onnxruntime as ort
from gramformer import Gramformer
from transformers import T5ForConditionalGeneration, T5Tokenizer
import time
import os

class SuperOptimizedGrammarCorrector:
    def __init__(self, use_onnx=True, use_quantization=True):
        self.model_name = "prithivida/grammar_error_correcter_v1"
        self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
        
        # Path for ONNX model
        self.onnx_path = "grammar_corrector.onnx"
        
        if use_onnx and os.path.exists(self.onnx_path):
            # Use existing ONNX model
            print("Loading existing ONNX model")
            self.onnx_session = ort.InferenceSession(self.onnx_path)
            self.use_onnx = True
        elif use_onnx:
            # Create and export ONNX model
            print("Creating new ONNX model")
            model = T5ForConditionalGeneration.from_pretrained(self.model_name)
            model.eval()
            
            # Create dummy input for export
            dummy_input = self.tokenizer("gec: This is a test.", return_tensors="pt")
            
            # Export to ONNX
            try:
                torch.onnx.export(
                    model,
                    (dummy_input.input_ids,),
                    self.onnx_path,
                    export_params=True,
                    opset_version=12,
                    input_names=['input_ids'],
                    output_names=['output'],
                    dynamic_axes={
                        'input_ids': {0: 'batch', 1: 'sequence'},
                        'output': {0: 'batch', 1: 'sequence'}
                    }
                )
                self.onnx_session = ort.InferenceSession(self.onnx_path)
                self.use_onnx = True
                print("Successfully created ONNX model")
            except Exception as e:
                print(f"Failed to create ONNX model: {e}")
                self.use_onnx = False
                self.model = model
        else:
            # Use PyTorch model with or without quantization
            model = T5ForConditionalGeneration.from_pretrained(self.model_name)
            model.eval()
            
            if use_quantization:
                try:
                    # Apply dynamic quantization
                    self.model = torch.quantization.quantize_dynamic(
                        model, {torch.nn.Linear}, dtype=torch.qint8
                    )
                    print("Successfully applied quantization")
                except Exception as e:
                    print(f"Failed to apply quantization: {e}")
                    self.model = model
            else:
                self.model = model
            
            self.use_onnx = False
    
    def correct(self, sentence, max_length=128):
        # Preprocess
        input_text = f"gec: {sentence}"
        
        if self.use_onnx:
            # Use ONNX runtime for inference
            encoded = self.tokenizer(input_text, return_tensors="pt")
            ort_inputs = {
                'input_ids': encoded.input_ids.numpy()
            }
            ort_outputs = self.onnx_session.run(None, ort_inputs)
            # Process ONNX outputs and decode
            return self.tokenizer.decode(ort_outputs[0][0], skip_special_tokens=True)
        else:
            # Use PyTorch model
            with torch.inference_mode():
                input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids
                outputs = self.model.generate(
                    input_ids=input_ids,
                    max_length=max_length,
                    num_beams=5,
                    early_stopping=True
                )
                return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

# Test function
def test_super_optimized():
    incorrect = "He have been working on this project for three year."
    
    # Test regular Gramformer
    print("Testing regular Gramformer...")
    start_time = time.time()
    gf = Gramformer(models=1, use_gpu=False)
    corrected = list(gf.correct(incorrect, max_candidates=1))[0]
    regular_time = time.time() - start_time
    print(f"Regular time: {regular_time:.4f} seconds")
    print(f"Corrected: {corrected}")
    
    # Test super optimized implementation
    print("\nTesting super optimized implementation...")
    start_time = time.time()
    corrector = SuperOptimizedGrammarCorrector(use_onnx=True, use_quantization=True)
    corrected = corrector.correct(incorrect)
    optimized_time = time.time() - start_time
    print(f"Optimized time: {optimized_time:.4f} seconds")
    print(f"Corrected: {corrected}")
    
    # Calculate speedup
    speedup = regular_time / optimized_time
    print(f"\nSpeedup: {speedup:.2f}x")

# Run the test
test_super_optimized()

Testing regular Gramformer...
[Gramformer] Grammar error correct/highlight model loaded..
Regular time: 4.9096 seconds
Corrected: He has been working on this project for three years.

Testing super optimized implementation...
Creating new ONNX model
Failed to create ONNX model: You have to specify either decoder_input_ids or decoder_inputs_embeds
Optimized time: 4.3665 seconds
Corrected: He has been working on this project for three years.

Speedup: 1.12x
