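The optimization approach here was dynamic quantization. It reduces the model size, but the measured time increased significantly. The cause has not been isolated: in the recorded run below, both timings include model loading, and the optimized timing also includes the TorchScript attempt and the quantization step, so the figures do not measure inference alone.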

In [3]:
import torch
from gramformer import Gramformer
import time

def create_optimized_gramformer():
    """Build a CPU Gramformer whose correction model has been optimized.

    Two best-effort optimizations are applied in sequence, each one building
    on the result of the previous step:

    1. Dynamic int8 quantization of every ``torch.nn.Linear`` layer.
    2. TorchScript compilation (``torch.jit.script``) of the resulting model.

    A step that raises is reported with ``print`` and skipped; the model
    produced by the steps that did succeed is kept.

    Returns:
        The ``Gramformer`` instance, with ``correction_model`` replaced by the
        most optimized variant that could be built (or left as the original
        model, in eval mode, if both steps failed).
    """
    # Initialize regular Gramformer (single model, CPU only)
    gf = Gramformer(models=1, use_gpu=False)

    # Work on a local reference and assign back once at the end, so that a
    # later step can never overwrite the result of an earlier one with a
    # variant derived from the un-optimized model.
    model = gf.correction_model

    # Inference only: switch the module to eval mode.
    model.eval()

    # 1. Dynamic quantization of the Linear layers to int8.
    try:
        model = torch.quantization.quantize_dynamic(
            model,
            {torch.nn.Linear},
            dtype=torch.qint8,
        )
        print("Successfully applied quantization")
    except Exception as e:
        # Best effort: keep the un-quantized model and carry on.
        print(f"Could not apply quantization: {e}")

    # 2. TorchScript compilation of whatever step 1 produced.
    #    torch.jit.script compiles from source and needs no example input.
    #    On the Hugging Face T5 model this step has been observed to fail with
    #    "Comprehension ifs are not supported yet".
    # NOTE(review): Gramformer.correct presumably needs more than forward()
    # from correction_model (e.g. a generate method); a scripted module may
    # not expose it -- confirm before relying on the scripted path.
    try:
        model = torch.jit.script(model)
        print("Successfully optimized the model with TorchScript")
    except Exception as e:
        # Best effort: keep the model from step 1.
        print(f"Could not apply TorchScript optimization: {e}")

    # Install the final model on the Gramformer instance.
    gf.correction_model = model

    return gf

# Test function
def test_optimized_gramformer():
    incorrect = "He have been working on this project for three year."
    
    # Test regular Gramformer
    print("Testing regular Gramformer...")
    start_time = time.time()
    regular_gf = Gramformer(models=1, use_gpu=False)
    corrected = list(regular_gf.correct(incorrect, max_candidates=1))[0]
    regular_time = time.time() - start_time
    print(f"Regular time: {regular_time:.4f} seconds")
    print(f"Corrected: {corrected}")
    
    # Test optimized Gramformer
    print("\nTesting optimized Gramformer...")
    start_time = time.time()
    optimized_gf = create_optimized_gramformer()
    corrected = list(optimized_gf.correct(incorrect, max_candidates=1))[0]
    optimized_time = time.time() - start_time
    print(f"Optimized time: {optimized_time:.4f} seconds")
    print(f"Corrected: {corrected}")
    
    # Show speedup
    speedup = regular_time / optimized_time
    print(f"\nSpeedup: {speedup:.2f}x")

# Run the comparison. This constructs Gramformer twice on CPU (once directly
# and once inside create_optimized_gramformer), so the cell is slow to run.
test_optimized_gramformer()

Testing regular Gramformer...
[Gramformer] Grammar error correct/highlight model loaded..
Regular time: 5.2717 seconds
Corrected: He has been working on this project for three years.

Testing optimized Gramformer...
[Gramformer] Grammar error correct/highlight model loaded..
Could not apply TorchScript optimization: Comprehension ifs are not supported yet:
  File "c:\Users\andre\miniconda3\envs\miaa\lib\site-packages\transformers\models\t5\modeling_t5.py", line 1179
    
        if not return_dict:
            return tuple(
                v
                for v in [

Successfully applied quantization
Optimized time: 15.8964 seconds
Corrected: He has been working on this project for three years.

Speedup: 0.33x
