# 🔫 Zero-shot Named Entity Recognition with Flair

### **TL;DR**: 

You can use Rubrix to analyze and validate the NER predictions from the new zero-shot model provided by the Flair NLP library.

This is useful for quickly bootstrapping a training set (using Rubrix [*Annotation Mode*](../reference/rubrix_webapp_reference.rst#annotation-mode)) as well as integrating with weak-supervision workflows.

![wnut zeroshot explore](https://github.com/recognai/rubrix-materials/raw/main/tutorials/zeroshot_ner/zeroshot_ner.gif)

## Install dependencies

In [None]:
%pip install datasets flair -qqq

## Setup Rubrix

**If you are new to Rubrix, visit and ⭐ star Rubrix for more materials like this and detailed docs**: [Github repo](https://github.com/recognai/rubrix)

If you have not installed and launched Rubrix, check the [Setup and Installation guide](https://docs.rubrix.ml/en/latest/getting_started/setup%26installation.html).


Once installed, you only need to import Rubrix:

In [None]:
import rubrix as rb

## Load the `wnut_17` dataset

In this example, we'll use a challenging NER dataset, the "WNUT 17: Emerging and Rare entity recognition" dataset, which focuses on unusual, previously unseen entities in the context of emerging discussions. Because its entities are hard to recognize, it is a good stress test for the quality of our zero-shot predictions.

Let's load the test set from the Hugging Face Hub:

In [None]:
from datasets import load_dataset

dataset = load_dataset("wnut_17", split=["test"])

In [None]:
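# `dataset` is a list holding only the test split, so take its first element.
# Keep each entity type once (BIO tags repeat it as B-/I-) and split only on the
# first hyphen so that "B-creative-work" becomes "creative-work", not "creative".
tag_names = dataset[0].features["ner_tags"].feature.names
wnut_labels = list(dict.fromkeys(tag.split("-", 1)[1] for tag in tag_names if "-" in tag))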
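The `ner_tags` names follow the BIO scheme, so every entity type appears twice (with a `B-` and an `I-` prefix), and one type, `creative-work`, itself contains a hyphen. The pure-Python sketch below illustrates that pitfall on a hard-coded tag list that is assumed to mirror the `wnut_17` tag names; the helper `entity_types` is hypothetical and not part of any library.

```python
def entity_types(tag_names):
    """Return the unique entity types in BIO tag names, preserving their order."""
    types = []
    for tag in tag_names:
        if "-" not in tag:  # skip the "O" (outside) tag
            continue
        # Split only on the first hyphen: "B-creative-work" -> "creative-work"
        entity_type = tag.split("-", 1)[1]
        if entity_type not in types:
            types.append(entity_type)
    return types


# Assumed to mirror the wnut_17 tag names (illustrative, hard-coded)
tags = ["O", "B-corporation", "I-corporation", "B-creative-work", "I-creative-work",
        "B-group", "I-group", "B-location", "I-location",
        "B-person", "I-person", "B-product", "I-product"]

print(entity_types(tags))
# ['corporation', 'creative-work', 'group', 'location', 'person', 'product']

# Splitting on every hyphen truncates "creative-work" and keeps duplicates
naive = [t.split("-")[1] for t in tags if "-" in t]
print(naive[:4])
# ['corporation', 'corporation', 'creative', 'creative']
```

Passing duplicated or truncated labels to the zero-shot tagger would make it search for a "creative" entity type that does not exist in the gold annotations.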

## Configure Flair TARSTagger

Now let's configure our NER model, following [Flair's documentation](https://github.com/flairNLP/flair/blob/master/resources/docs/TUTORIAL_10_TRAINING_ZERO_SHOT_MODEL.md#use-case-2-zero-shot-named-entity-recognition-ner-with-tars).

In [None]:
from flair.models import TARSTagger
from flair.data import Sentence

# Load zero-shot NER tagger
tars = TARSTagger.load('tars-ner')

# Define labels for named entities using wnut labels
labels = wnut_labels
tars.add_and_switch_to_new_task('task 1', labels, label_type='ner')

Let's test it with one example!

In [None]:
sentence = Sentence(" ".join(dataset[0][0]['tokens']))

In [None]:
tars.predict(sentence)

# Creating the prediction as a list of tuples (label, start_char, end_char)
prediction = [
    (entity.get_labels()[0].value, entity.start_pos, entity.end_pos)
    for entity in sentence.get_spans("ner")
]
prediction
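Each tuple pairs a label with character offsets into the sentence text. A quick way to sanity-check predictions is to slice the text with those offsets. The sketch below assumes end offsets are exclusive (Python slice semantics); the sentence, the predictions and the helper `span_texts` are hypothetical and only illustrate the format, they are not real model output.

```python
def span_texts(text, prediction):
    """Pair each predicted label with the substring its character offsets cover."""
    return [(label, text[start:end]) for label, start, end in prediction]


# Hypothetical sentence and predictions, for illustration only
text = "I love the new Nintendo Switch from Japan"
prediction = [("product", 15, 30), ("location", 36, 41)]

print(span_texts(text, prediction))
# [('product', 'Nintendo Switch'), ('location', 'Japan')]
```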

## Predict over `wnut_17` and log into `rubrix`

Now let's run the tagger over the whole test set and log the predictions into `rubrix`:

In [None]:
records = []
for record in dataset[0]:
    input_text = " ".join(record["tokens"])

    # Run the zero-shot tagger on the whitespace-joined tokens
    sentence = Sentence(input_text)
    tars.predict(sentence)
    prediction = [
        (entity.get_labels()[0].value, entity.start_pos, entity.end_pos)
        for entity in sentence.get_spans("ner")
    ]

    # Building TokenClassificationRecord
    records.append(
        rb.TokenClassificationRecord(
            text=input_text,
            tokens=[token.text for token in sentence],
            prediction=prediction,
            prediction_agent="tars-ner",
        )
    )

# Log all records in a single call instead of one call per record
rb.log(records, name='tars_ner_wnut_17', metadata={"split": "test"})
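Because each record's text is built as `" ".join(record["tokens"])`, the character offsets in `prediction` can be mapped back onto the original `wnut_17` tokens, for example to compare zero-shot predictions with the gold `ner_tags` or to feed them into a weak-supervision workflow. The pure-Python sketch below does this with two hypothetical helpers, `token_offsets` and `spans_to_bio`. It assumes end offsets are exclusive and that predicted spans start and end on whitespace-token boundaries; Flair's own tokenizer may split some tokens (such as punctuation) differently, in which case a span may only partially cover a whitespace token and that token is left as `O`.

```python
def token_offsets(tokens):
    """Character (start, end) offsets of each token in " ".join(tokens)."""
    offsets, pos = [], 0
    for tok in tokens:
        offsets.append((pos, pos + len(tok)))
        pos += len(tok) + 1  # +1 for the joining space
    return offsets


def spans_to_bio(tokens, prediction):
    """Convert (label, start, end) character spans into one BIO tag per token."""
    offsets = token_offsets(tokens)
    tags = ["O"] * len(tokens)
    for label, start, end in prediction:
        # Tokens fully covered by the span, in order
        inside = [i for i, (s, e) in enumerate(offsets) if s >= start and e <= end]
        for n, i in enumerate(inside):
            tags[i] = ("B-" if n == 0 else "I-") + label
    return tags


# Hypothetical tokens and predictions, for illustration only
tokens = ["I", "love", "the", "new", "Nintendo", "Switch", "from", "Japan"]
prediction = [("product", 15, 30), ("location", 36, 41)]

print(spans_to_bio(tokens, prediction))
# ['O', 'O', 'O', 'O', 'B-product', 'I-product', 'O', 'B-location']
```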