## Import libraries

Import Gramformer, T5 and torch for optimization

In [1]:
from transformers import T5ForConditionalGeneration, T5Tokenizer
from gramformer import Gramformer
import torch
import time

  from .autonotebook import tqdm as notebook_tqdm


## Ignore warnings

In [2]:
import warnings
warnings.filterwarnings("ignore")

## Optimize T5 Transformer (Gramformer Base model) using torch.compile

In [3]:
class OptimizedT5Corrector:
    def __init__(self):
        # Load model and tokenizer directly
        self.model_name = "prithivida/grammar_error_correcter_v1"
        self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
        
        # Optimize model for inference
        self.model.eval()
        
        # Use torch.compile for PyTorch 2.0+ (significant speedup)
        if hasattr(torch, 'compile'):
            try:
                self.model = torch.compile(self.model)
                print("Successfully applied torch.compile optimization")
            except Exception as e:
                print(f"Could not apply torch.compile: {e}")
        
        # Optimize memory usage
        self.model.config.use_cache = True
        
    def correct(self, sentence, max_length=128):
        # Apply inference optimizations
        with torch.inference_mode():
            # Prepare input - the "gec:" prefix is important for the model
            input_text = f"gec: {sentence}"
            input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids
            
            # Optimize generation parameters for speed
            outputs = self.model.generate(
                input_ids=input_ids,
                max_length=max_length,
                num_beams=2,  #2 # Reduced from 5 for speed
                early_stopping=True,
                use_cache=True  # Enable KV caching for faster generation
            )
            
            # Decode output
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

In [7]:
incorrect = "He have been working on this project for three year."
incorrect = "Me don tthink that this new helicopter is going to be helping the fire crew combact the fire"
incorrect = "The coffe was spilled ar over de console of my brand new Porche. This car is very special to me and this shall not be hapening again"
    
# Test with Gramformer
print("Testing original Gramformer...")
start = time.time()
gf = Gramformer(models=1, use_gpu=False)
corrected = list(gf.correct(incorrect, max_candidates=1))[0]
original_time = time.time() - start
print(f"Original time: {original_time:.4f} seconds")
print(f"Corrected: {corrected}")

# Test with optimized implementation
print("\nTesting optimized implementation...")
start = time.time()
corrector = OptimizedT5Corrector()
corrected = corrector.correct(incorrect)
optimized_time = time.time() - start
print(f"Optimized time: {optimized_time:.4f} seconds")
print(f"Corrected: {corrected}")

# Calculate speedup
if optimized_time > 0:
    speedup = original_time / optimized_time
    print(f"\nSpeedup: {speedup:.2f}x")

Testing original Gramformer...
[Gramformer] Grammar error correct/highlight model loaded..
Original time: 9.1652 seconds
Corrected: The coffee was spilled over de console of my brand new Porche. This car is very special to me and this shall not be happening again.

Testing optimized implementation...
Successfully applied torch.compile optimization
Optimized time: 5.1128 seconds
Corrected: The coffee was spilled over de console of my brand new Porche. This car is very special to me and this shall not be happening again.

Speedup: 1.79x
