# Efficient Zero-Shot Token Classification with GLiNER

In this notebook, we show how to use the [GLiNER library](https://github.com/Knowledgator/GLiNER) to perform zero-shot token classification, and how to use its predictions as suggestions for annotators in Argilla.

GLiNER is a Named Entity Recognition (NER) model that can identify arbitrary entity types using a bidirectional transformer encoder (BERT-like). It offers a practical middle ground between traditional NER models, which are limited to a predefined set of entity types, and Large Language Models (LLMs), which are flexible but often too large and costly for resource-constrained settings.

We will use the `GLiNER` class to classify spans within a given text according to a set of labels. 
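To build some intuition for how a span-based model can handle labels it was never trained on, the sketch below shows the core idea in plain Python: enumerate every candidate span up to a maximum width, score each span against each label embedding with a dot product followed by a sigmoid, and keep the span-label pairs above a threshold. The two-dimensional vectors are hand-made, hypothetical stand-ins chosen for illustration; in GLiNER they are produced by the transformer encoder, and the library adds further steps (such as resolving overlapping spans) that are omitted here.

```python
import math


def enumerate_spans(num_tokens, max_width):
    """Return every (start, end) token span, end exclusive, of at most max_width tokens."""
    return [
        (start, start + width)
        for start in range(num_tokens)
        for width in range(1, max_width + 1)
        if start + width <= num_tokens
    ]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def match_spans(span_vectors, label_vectors, threshold):
    """Score every span against every label; keep pairs whose sigmoid(dot product) >= threshold."""
    matches = []
    for span, span_vec in span_vectors.items():
        for label, label_vec in label_vectors.items():
            score = sigmoid(sum(s * l for s, l in zip(span_vec, label_vec)))
            if score >= threshold:
                matches.append((span, label, score))
    return matches


toy_tokens = ["Bill", "Gates", "founded", "Microsoft"]
candidate_spans = enumerate_spans(len(toy_tokens), max_width=2)

# Hypothetical toy embeddings for a few of the candidate spans;
# a real model computes span and label representations with its encoder
label_vectors = {"person": [1.0, 0.0], "organization": [0.0, 1.0]}
span_vectors = {
    (0, 2): [3.0, -2.0],   # "Bill Gates"
    (2, 3): [-1.0, -1.0],  # "founded"
    (3, 4): [-2.0, 3.0],   # "Microsoft"
}

matches = match_spans(span_vectors, label_vectors, threshold=0.5)
for (start, end), label, score in matches:
    print(" ".join(toy_tokens[start:end]), "=>", label, round(score, 2))
```

Because labels enter only as vectors compared against span vectors, a new label needs no retraining in this picture, only a new label representation.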

## Getting started

### Deploy the Argilla server

If you have already deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla by following [this guide](https://docs.argilla.io/latest/getting_started/quickstart/).

### Install dependencies

```python
!pip install -qqq argilla gliner transformers torch datasets
```

## Initializing our model

In this step, we initialize the model [knowledgator/gliner-bi-small-v1.0](https://huggingface.co/knowledgator/gliner-bi-small-v1.0). It is one of the smallest and most efficient GLiNER models available; for higher accuracy, you might consider a slightly larger model. For even more efficient inference, you might consider [ONNX conversion](https://github.com/Knowledgator/gliner?tab=readme-ov-file#onnx-convertion) of the current models.

```python
from gliner import GLiNER

# Initialize GLiNER with the small bi-encoder model
gliner_model = GLiNER.from_pretrained("knowledgator/gliner-bi-small-v1.0")

# Sample text for entity prediction
text = """
Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃˈtjɐnu ʁɔˈnaldu]; born 5 February 1985) is a Portuguese professional footballer who plays as a forward for and captains both Saudi Pro League club Al Nassr and the Portugal national team. Widely regarded as one of the greatest players of all time, Ronaldo has won five Ballon d'Or awards,[note 3] a record three UEFA Men's Player of the Year Awards, and four European Golden Shoes, the most by a European player. He has won 33 trophies in his career, including seven league titles, five UEFA Champions Leagues, the UEFA European Championship and the UEFA Nations League. Ronaldo holds the records for most appearances (183), goals (140) and assists (42) in the Champions League, goals in the European Championship (14), international goals (128) and international appearances (205). He is one of the few players to have made over 1,200 professional career appearances, the most by an outfield player, and has scored over 850 official senior career goals for club and country, making him the top goalscorer of all time.
"""

# Labels for entity prediction
labels = ["Person", "Award", "Date", "Competitions", "Teams"]  # for v2.1 use capital case for better performance

# Perform entity prediction, keeping only entities with a score of at least 0.5
entities = gliner_model.predict_entities(text, labels, threshold=0.5)

# Display predicted entities and their labels
for entity in entities:
    print(entity["text"], "=>", entity["label"])
```
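`predict_entities` returns a list of dictionaries; the prediction function later in this notebook relies on their `start`, `end`, `text`, `label` and `score` keys. As a hedged sketch of post-processing that output, the helper below (a hypothetical name, not part of GLiNER) groups entity texts by label, highest score first, and can apply a stricter score cut-off than the one used at prediction time. The sample entities are made up for illustration; they are not actual model output.

```python
from collections import defaultdict


def group_by_label(entities, min_score=0.0):
    """Group entity texts by label, ordered from highest to lowest score."""
    grouped = defaultdict(list)
    for ent in sorted(entities, key=lambda e: e["score"], reverse=True):
        if ent["score"] >= min_score:
            grouped[ent["label"]].append(ent["text"])
    return dict(grouped)


# Hypothetical predictions in the same shape as predict_entities output
sample_text = "Ronaldo plays for Al Nassr."
sample_entities = [
    {"start": 0, "end": 7, "text": "Ronaldo", "label": "Person", "score": 0.95},
    {"start": 8, "end": 13, "text": "plays", "label": "Award", "score": 0.31},
    {"start": 18, "end": 26, "text": "Al Nassr", "label": "Teams", "score": 0.62},
]

print(group_by_label(sample_entities, min_score=0.5))
```

Raising `min_score` here is a cheap way to inspect how a stricter threshold would change the results without re-running the model.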


## Create our dataset

We will take a look at [the tner/ontonotes5 dataset](https://huggingface.co/datasets/tner/ontonotes5) to understand its structure and the kind of data it contains, using the embedded Hugging Face Dataset Viewer.

<iframe
  src="https://huggingface.co/datasets/tner/ontonotes5/embed/viewer/ontonotes5/train"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>
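Each example in the dataset is a list of `tokens` (the code later in this notebook joins them with spaces to build the record text) together with token-level NER tags. If you wanted to compare GLiNER suggestions with those gold annotations, the BIO tags would first need to become character spans over the joined text, which is the `start`/`end`/`label` form Argilla uses for span questions. The sketch below is a hypothetical helper that does this. It assumes the tags are already strings such as `B-PERSON`; if the dataset stores them as integer ids, you would map them to strings with the dataset's label mapping first.

```python
def bio_to_char_spans(tokens, tags):
    """Convert token-level BIO tags into character spans over " ".join(tokens)."""
    spans = []
    current = None  # [start_char, end_char, label] of the entity being built
    offset = 0
    for token, tag in zip(tokens, tags):
        start, end = offset, offset + len(token)
        offset = end + 1  # skip the single space used to join tokens
        continues = tag.startswith("I-") and current is not None and current[2] == tag[2:]
        if continues:
            current[1] = end  # extend the open entity to cover this token
            continue
        if current is not None:
            spans.append(current)  # close the open entity
            current = None
        if tag != "O":
            # "B-X", or an "I-X" that does not continue an open X entity, starts a new span
            current = [start, end, tag[2:]]
    if current is not None:
        spans.append(current)
    return [{"start": s, "end": e, "label": label} for s, e, label in spans]


example_tokens = ["Bill", "Gates", "founded", "Microsoft", "."]
example_tags = ["B-PERSON", "I-PERSON", "O", "B-ORG", "O"]
example_text = " ".join(example_tokens)

for span in bio_to_char_spans(example_tokens, example_tags):
    print(example_text[span["start"]:span["end"]], "=>", span["label"])
```

The offsets are only valid for text built with exactly one space between tokens, which matches how the records are created in this notebook.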

```python
import argilla as rg

# Connect to your Argilla server (replace the placeholders with your own URL and API key)
client = rg.Argilla(api_url="<api_url>", api_key="<api_key>")

labels = [
    "CARDINAL",
    "DATE",
    "PERSON",
    "NORP",
    "GPE",
    "LAW",
    "PERCENT",
    "ORDINAL",
    "MONEY",
    "WORK_OF_ART",
    "FAC",
    "TIME",
    "QUANTITY",
    "PRODUCT",
    "LANGUAGE",
    "ORG",
    "LOC",
    "EVENT",
]

settings = rg.Settings(
    guidelines="Highlight the spans of text that belong to each of the specified categories. Spans must not overlap.",
    fields=[
        rg.TextField(
            name="text",
            title="Text",
            use_markdown=False,
        ),
    ],
    questions=[
        rg.SpanQuestion(
            name="span_label",
            field="text",
            labels=labels,
            title="Classify the tokens according to the specified categories.",
            allow_overlapping=False,
        )
    ],
)

dataset = rg.Dataset(
    name="token_classification_dataset",
    settings=settings,
    client=client,
)
dataset.create()
```

Next, we load the dataset and log its records to Argilla, without any predictions for now.

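```python
from datasets import load_dataset

# Load the first 2,100 examples of the test split
hf_dataset = load_dataset("tner/ontonotes5", split="test[:2100]")

# Rebuild each sentence from its tokens and wrap it in an Argilla record
records = [rg.Record(fields={"text": " ".join(row["tokens"])}) for row in hf_dataset]

dataset.records.log(records)
```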

## Create record predictions

We first create a basic prediction function. It predicts entities with GLiNER and converts them into the span representation Argilla expects (start, end, label). This function then provides initial suggestions for the NER labels of each record.

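```python
def predict_gliner(model, text, labels, threshold):
    # Pass threshold by keyword: the third positional parameter of predict_entities is flat_ner
    entities = model.predict_entities(text, labels, threshold=threshold)
    # Keep only start, end and label, the span format Argilla expects
    return [
        {k: v for k, v in ent.items() if k not in {"score", "text"}} for ent in entities
    ]


predict_gliner(
    model=gliner_model,
    text="This is a text about Bill Gates and Microsoft.",
    labels=labels,
    threshold=0.70,
)
```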
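The question above was configured with `allow_overlapping=False`, and span offsets must fall inside the record text. GLiNER's default flat decoding should already avoid overlapping spans, but a defensive check before logging suggestions is cheap. The helper below is a hypothetical sketch, not an Argilla or GLiNER function: it drops spans that are out of range, empty, or use an unknown label, then greedily keeps non-overlapping spans in order of start offset.

```python
def sanitize_spans(text, spans, allowed_labels):
    """Drop invalid spans, then greedily keep non-overlapping ones by start offset."""
    valid = [
        span
        for span in spans
        if 0 <= span["start"] < span["end"] <= len(text) and span["label"] in allowed_labels
    ]
    valid.sort(key=lambda span: (span["start"], span["end"]))
    kept = []
    last_end = 0
    for span in valid:
        if span["start"] >= last_end:  # no overlap with the previously kept span
            kept.append({"start": span["start"], "end": span["end"], "label": span["label"]})
            last_end = span["end"]
    return kept


example_sentence = "This is a text about Bill Gates and Microsoft."
raw_spans = [
    {"start": 21, "end": 31, "label": "PERSON"},  # "Bill Gates"
    {"start": 26, "end": 31, "label": "PERSON"},  # "Gates", overlaps the span above
    {"start": 36, "end": 45, "label": "ORG"},     # "Microsoft"
    {"start": 40, "end": 60, "label": "ORG"},     # ends past the end of the text
    {"start": 0, "end": 4, "label": "ANIMAL"},    # label not in the label set
]
print(sanitize_spans(example_sentence, raw_spans, allowed_labels={"PERSON", "ORG"}))
```

In the annotation loop, the output of `predict_gliner` could be passed through this helper together with the record text and the OntoNotes `labels` list.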

Next, we continuously loop over the dataset to retrieve pending records and update their model suggestions. This way, annotators always have the most recent suggestions available, which eases their annotation work.

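```python
import time

# Query for records that have not been annotated yet
pending_records_filter = rg.Filter(("status", "==", "pending"))

while True:
    pending_records = list(
        dataset.records(
            query=rg.Query(filter=pending_records_filter),
            limit=1,
        ).to_list(flatten=True)
    )
    if not pending_records:
        break  # every record has been annotated
    # Predict spans with GLiNER and attach them as suggestions for the "span_label" question
    updated_data = [
        {
            "span_label": predict_gliner(
                model=gliner_model, text=sample["text"], labels=labels, threshold=0.70
            ),
            "id": sample["id"],
        }
        for sample in pending_records
    ]
    # Records are matched by "id", so this updates their suggestions instead of creating new records
    dataset.records.log(records=updated_data)
    # Wait a bit before polling again
    time.sleep(5)
```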