Utilized torch.compile optimizagtion and the inference time was significantly optimized! It also uses a batch processing approach for multiple sentences.

In [2]:
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch
import time

class OptimizedT5Corrector:
    def __init__(self):
        # Load model and tokenizer directly
        self.model_name = "prithivida/grammar_error_correcter_v1"
        self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
        
        # Optimize model for inference
        self.model.eval()
        
        # Use torch.compile for PyTorch 2.0+ (significant speedup)
        if hasattr(torch, 'compile'):
            try:
                self.model = torch.compile(self.model)
                print("Successfully applied torch.compile optimization")
            except Exception as e:
                print(f"Could not apply torch.compile: {e}")
        
        # Optimize memory usage
        self.model.config.use_cache = True
        
    def correct(self, sentence, max_length=128):
        # Apply inference optimizations
        with torch.inference_mode():
            # Prepare input - the "gec:" prefix is important for the model
            input_text = f"gec: {sentence}"
            input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids
            
            # Optimize generation parameters for speed
            outputs = self.model.generate(
                input_ids=input_ids,
                max_length=max_length,
                num_beams=2,  #2 # Reduced from 5 for speed
                early_stopping=True,
                use_cache=True  # Enable KV caching for faster generation
            )
            
            # Decode output
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

# Performance testing function
def benchmark_speed():
    incorrect = "He have been working on this project for three year."
    
    # Test with Gramformer
    print("Testing original Gramformer...")
    from gramformer import Gramformer
    start = time.time()
    gf = Gramformer(models=1, use_gpu=False)
    corrected = list(gf.correct(incorrect, max_candidates=1))[0]
    original_time = time.time() - start
    print(f"Original time: {original_time:.4f} seconds")
    print(f"Corrected: {corrected}")
    
    # Test with optimized implementation
    print("\nTesting optimized implementation...")
    start = time.time()
    corrector = OptimizedT5Corrector()
    corrected = corrector.correct(incorrect)
    optimized_time = time.time() - start
    print(f"Optimized time: {optimized_time:.4f} seconds")
    print(f"Corrected: {corrected}")
    
    # Calculate speedup
    if optimized_time > 0:
        speedup = original_time / optimized_time
        print(f"\nSpeedup: {speedup:.2f}x")
    
    # Test multiple sentences for throughput comparison
    test_sentences = [
        "He have been working on this project for three year.",
        "She dont want to go to the movie.",
        "They was walking to the store yesterday.",
        "The cats is playing with the yarn.",
        "We has completed our assignment."
    ]
    
    print("\nBenchmarking multiple sentence throughput...")
    
    # Original Gramformer
    start = time.time()
    gf = Gramformer(models=1, use_gpu=False)
    for sent in test_sentences:
        corrected = list(gf.correct(sent, max_candidates=1))[0]
    original_batch_time = time.time() - start
    print(f"Original batch time: {original_batch_time:.4f} seconds")
    
    # Optimized implementation
    start = time.time()
    corrector = OptimizedT5Corrector()
    for sent in test_sentences:
        corrected = corrector.correct(sent)
    optimized_batch_time = time.time() - start
    print(f"Optimized batch time: {optimized_batch_time:.4f} seconds")
    
    # Calculate batch speedup
    if optimized_batch_time > 0:
        batch_speedup = original_batch_time / optimized_batch_time
        print(f"Batch speedup: {batch_speedup:.2f}x")

# Run the benchmark
benchmark_speed()

Testing original Gramformer...
[Gramformer] Grammar error correct/highlight model loaded..
Original time: 5.0279 seconds
Corrected: He has been working on this project for three years.

Testing optimized implementation...
Successfully applied torch.compile optimization
Optimized time: 2.8769 seconds
Corrected: He has been working on this project for three years.

Speedup: 1.75x

Benchmarking multiple sentence throughput...
[Gramformer] Grammar error correct/highlight model loaded..
Original batch time: 13.1276 seconds
Successfully applied torch.compile optimization
Optimized batch time: 7.1054 seconds
Batch speedup: 1.85x
