# Custom Inference Handler for GECToR on Inference Endpoints

## Setup

In [16]:
# Enable the autoreload extension so edited modules (e.g. handler.py) are
# re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
# Enable Git LFS, then clone the model repository from the Hugging Face Hub.
!git lfs install
!git clone https://huggingface.co/andrewrreed/gector-deberta-large-5k

Git LFS initialized.
Cloning into 'gector-deberta-large-5k'...
remote: Enumerating objects: 12, done.[K
remote: Counting objects: 100% (1/1), done.[K
remote: Total 12 (delta 0), reused 1 (delta 0), pack-reused 11[K
Unpacking objects: 100% (12/12), 1.15 MiB | 7.28 MiB/s, done.


In [None]:
# setup cli with token
!huggingface-cli login
# Persist git credentials so later git commands do not prompt again.
# NOTE(review): the `store` helper keeps credentials on disk unencrypted —
# confirm this is acceptable on the machine running this notebook.
!git config --global credential.helper store

In [3]:
# Create a virtual environment in ./.venv.
!python -m venv .venv
# NOTE(review): every `!` command runs in its own subshell, so this activation
# does not carry over to later cells or to the running kernel. Presumably the
# kernel has to be launched from .venv for the environment to be used — confirm.
!source .venv/bin/activate

In [None]:
# download Gector library
!git clone https://github.com/gotutiyan/gector.git

# copy in necessary files
%cp gector/requirements.txt gector-deberta-large-5k/.
%cp -r gector/gector gector-deberta-large-5k/.

# remove the old gector repo
%rm -rf gector

# change dir
%cd gector-deberta-large-5k

# remove nvidia/triton dependencies
# (drops every requirement line that starts with "nvidia" or "triton", case-insensitively)
!grep -viE "^(nvidia|triton)" requirements.txt > temp_requirements.txt && mv temp_requirements.txt requirements.txt

# install reqs
# %pip (rather than !pip) installs into the environment of the running kernel
%pip install -r requirements.txt

# download verb form vocab
# -p: do not fail when the directory already exists (safe to re-run)
!mkdir -p data
# -f: fail on HTTP errors instead of saving an error page as the vocab file
!cd data && curl -fLO https://github.com/grammarly/gector/raw/master/data/verb-form-vocab.txt

## Create custom handler for Inference Endpoints

In [8]:
%%writefile handler.py
import os
import torch
from typing import Dict, List, Any
from transformers import AutoTokenizer
from gector import GECToR, predict, load_verb_dict


class EndpointHandler:
    """Custom Inference Endpoints handler that serves a GECToR model.

    The model, tokenizer and verb-form dictionary are loaded once in
    ``__init__`` and reused for every request handled by ``__call__``.
    """

    def __init__(self, path=""):
        """
        Load the model artifacts from ``path``.

        Args:
            path (str): Directory holding the model weights, the tokenizer files
                and ``data/verb-form-vocab.txt``.
        """
        self.model = GECToR.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.encode, self.decode = load_verb_dict(
            os.path.join(path, "data/verb-form-vocab.txt")
        )

        # Keep the chosen device on the instance so it can be inspected later.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Serving only: switch off training-time behaviour such as dropout.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        """
        Process the input data and return the predicted results.

        Args:
            data (Dict[str, Any]): The input data dictionary containing the following keys:
                - "inputs" (str | List[str]): One input string or a list of input strings.
                  A single string is treated as a one-element list.
                - "n_iterations" (int, optional): The number of iterations for prediction. Defaults to 5.
                - "batch_size" (int, optional): The batch size for prediction. Defaults to 2.
                - "keep_confidence" (float, optional): The confidence threshold for keeping predictions. Defaults to 0.0.
                - "min_error_prob" (float, optional): The minimum error probability for keeping predictions. Defaults to 0.0.

                The optional keys may also be nested under a "parameters" dict;
                a key given at the top level takes precedence over the nested one.

        Returns:
            List[str]: The strings returned by ``predict`` for the given inputs.
            An empty input list yields an empty list without running the model.

        Raises:
            KeyError: If "inputs" is missing from ``data``.
            TypeError: If "inputs" is neither a string nor a list/tuple.
        """
        srcs = data["inputs"]

        # A bare string would otherwise be iterated character by character.
        if isinstance(srcs, str):
            srcs = [srcs]
        elif isinstance(srcs, (list, tuple)):
            srcs = list(srcs)
        else:
            raise TypeError(
                f'"inputs" must be a string or a list of strings, got {type(srcs).__name__}'
            )

        # Nothing to correct: skip the model entirely.
        if not srcs:
            return []

        nested = data.get("parameters") or {}

        def option(name, default):
            # Top-level keys win so existing payloads behave exactly as before.
            if name in data:
                return data[name]
            return nested.get(name, default)

        return predict(
            model=self.model,
            tokenizer=self.tokenizer,
            srcs=srcs,
            encode=self.encode,
            decode=self.decode,
            keep_confidence=option("keep_confidence", 0.0),
            min_error_prob=option("min_error_prob", 0.0),
            n_iteration=option("n_iterations", 5),
            batch_size=option("batch_size", 2),
        )


Writing handler.py


## Test Handler

In [24]:
from handler import EndpointHandler

# Load the handler from the working directory, which is expected to be the
# model repository containing handler.py and the model files.
my_handler = EndpointHandler(path=".")

# prepare sample payload
payload = {
    "inputs": ["This is a correct sentence.", "this is a INcorrect sentence"],
    "n_iterations": 25,
}
# test
out = my_handler(payload)

# Sanity checks so a broken handler fails loudly instead of just printing.
assert len(out) == len(payload["inputs"]), "expected one prediction per input"
assert all(isinstance(sentence, str) for sentence in out), "predictions must be strings"
print(out)

Iteratoin 0. the number of to_be_processed: 2


100%|██████████| 1/1 [00:00<00:00,  3.80it/s]


Iteratoin 1. the number of to_be_processed: 1


100%|██████████| 1/1 [00:00<00:00,  5.88it/s]

['This is a correct sentence.', 'This is an incorrect sentence .']





## Push changes to Hub

In [None]:
# Stage everything in the working directory, commit it and push.
# NOTE(review): `*` is expanded by the shell, so generated files present in the
# working tree (e.g. __pycache__ directories) may be staged as well — confirm,
# or add a .gitignore before pushing.
!git add *
!git commit -m "add handler"
!git push