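# Few-shot NER

Candidate entity spans are extracted from the sentence with spaCy dependency patterns. They are then labelled by a zero-shot NLI classifier whose candidate labels are the entity groups plus a few example words per group. The example words' scores are summed into their group, and only the best-scoring group is kept for each span.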

In [1408]:
# displacy color palette, consumed in order by the entity groups (the options
# cell zips it with the group names, so only the first few entries are used).
colors = (
    ["#7aecec", "#bfeeb7", "#feca74", "#ff9561", "#aa9cfc", "#c887fb", "#9cc9cc", "#ffeb80"]
    + ["#ff8197"] * 2
    + ["#f0d0ff"]
    + ["#bfe1d9"] * 2
    + ["#e4e7d2"] * 5
)

In [1409]:
from transformers import pipeline
import re
from spacy.matcher import DependencyMatcher 
import spacy
from spacy import displacy
from spacy.pipeline import EntityRuler

In [1410]:
# spaCy supplies tokenization, POS tags and the dependency parse. Its statistical
# NER component is disabled because entities come from the EntityRuler instead.
SPACY_MODEL = "en_core_web_sm"
NLI_MODEL = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"

nlp = spacy.load(SPACY_MODEL, disable=["ner"])
# The ruler is rebuilt after classification. Adding it here keeps the later
# `nlp.remove_pipe("entity_ruler")` call valid on a fresh kernel.
ruler = nlp.add_pipe("entity_ruler", config={"overwrite_ents": True})

# Zero-shot classifier used to assign an entity group to each candidate span.
zero_shot_classifier = pipeline(task="zero-shot-classification", model=NLI_MODEL)



In [1564]:
# Sample input sentence for the demo.
text = (
    "I taken the lunch, yesterday in Apple company, "
    "it was a mashed potato with smoked beef"
)

In [1565]:
# Entity groups -> example words. Both the group names and the example words
# are offered to the zero-shot classifier as candidate labels, so they must be
# spelled correctly for the NLI model to understand them.
org_entities = {"fruit"      :    ["apple", 'avocado', 'banana', 'orange'],
                "vegetable"  :    ["carrot", "cabbage", 'broccoli'],
                "meat"       :    ["beef", "lamb", 'chicken'],
                "fish"       :    ["salmon", "shrimp", "tuna"],
                "spices"     :    ["angelica", "allspice", "cumin"],
                "ingredient" :    ['oil', 'salt', 'sugar'],
                "vehicle"    :    ['bus', 'truck'],
                "other"      :    []}

# Map every candidate label (example word or group name) back to its group.
org_entities_map = {word: group for group, words in org_entities.items() for word in words}
org_entities_map.update({group: group for group in org_entities})

# Flat candidate-label list: every example word (in group order), then the group names.
entities = [word for words in org_entities.values() for word in words] + list(org_entities)

# Minimum aggregated group score for a span to be turned into an entity.
# NOTE(review): with the pipeline's default single-label mode the label scores
# sum to ~1, so the best group always scores >= 1/len(org_entities) and this
# threshold filters essentially nothing. Raise it if filtering is intended.
CONF_TH = 0.01

# Preprocessing

In [1566]:
def clean(txt):
    '''
    Normalize raw input text by stripping surrounding whitespace and lowercasing it.

    Punctuation and inner whitespace are left untouched; only the outer
    whitespace and the letter case change.

    NOTE(review): lowercasing removes capitalization cues (e.g. "Apple"), which
    the spaCy tagger may rely on to emit PROPN. Confirm this is intended.

    Args:
      - txt (string) -> the raw text.

    Returns:
      string -> the stripped, lower-cased text.
    '''
    return txt.strip().lower()

In [1567]:
# Normalized version of the sample sentence; this is what gets parsed below.
clean_text = clean(txt=text)

# POS & Dependency Tree

In [1568]:
# Parse the cleaned text; the POS tags and dependency tree are what the
# DependencyMatcher patterns below match against.
doc = nlp(clean_text)

In [1569]:
# Optional: uncomment to visualize the dependency parse that the patterns below match against.
# displacy.render(doc, style='dep')

In [1570]:
def _pos_node(node_id, pos):
    '''Anchor node of a DependencyMatcher pattern: a single token with the given POS tag.'''
    return {"RIGHT_ID": node_id, "RIGHT_ATTRS": {"POS": pos}}


def _pos_relation(anchor_id, rel_op, pos):
    '''Second node, named "subject", linked to `anchor_id` via the spaCy operator `rel_op`.'''
    return {"LEFT_ID": anchor_id, "REL_OP": rel_op, "RIGHT_ID": "subject", "RIGHT_ATTRS": {"POS": pos}}


# spaCy relation operators used here:
#   A <  B : A is the immediate dependent of B
#   A << B : A is a dependent of B somewhere along the dep -> head chain
#   A .  B : A immediately precedes B
patterns = [
    [_pos_node("adj", "ADJ"), _pos_relation("adj", "<", "NOUN")],          # adjective attached to a noun
    [_pos_node("noun", "NOUN")],                                           # bare noun
    [_pos_node("noun", "NOUN"), _pos_relation("noun", ".", "NOUN")],       # noun directly followed by a noun
    [_pos_node("pnoun", "PROPN")],                                         # bare proper noun
    [_pos_node("pnoun", "PROPN"), _pos_relation("pnoun", "<<", "NOUN")],   # proper noun under a noun
    [_pos_node("pnoun", "PROPN"), _pos_relation("pnoun", "<<", "PROPN")],  # proper noun under a proper noun
    [_pos_node("verb", "VERB"), _pos_relation("verb", "<", "NOUN")],       # verb attached to a noun
]

# Register every pattern under one match key. Downstream code only uses the
# matched token indices, not which pattern fired.
matcher = DependencyMatcher(nlp.vocab)
matcher.add("entities", patterns)

In [1571]:
# Turn every dependency match into a contiguous Doc span.
# DependencyMatcher returns token indices in pattern-node order, not document
# order. For relations such as "<<" the second node can precede the anchor, and
# slicing doc[ids[0]:ids[1]+1] would then produce an empty span. Taking the
# min/max of the indices always yields a non-empty span covering every matched
# token, for patterns of any size.
matches = matcher(doc)
spans = []
for _match_id, token_ids in matches:
    start, end = min(token_ids), max(token_ids) + 1
    spans.append(doc[start:end])
# Resolve overlaps: filter_spans prefers the (first) longest span.
filterd_spans = list(spacy.util.filter_spans(spans))

In [1572]:
# Inspect the candidate spans that will be sent to the zero-shot classifier.
filterd_spans

[lunch, yesterday, apple company, mashed potato, smoked beef]

In [1573]:
# Classify every candidate span against all candidate labels (example words and
# group names). The zero-shot pipeline rejects an empty sequence list, so skip
# the call when the matcher found nothing; later cells then see no results.
candidate_texts = [str(s) for s in filterd_spans]
if candidate_texts:
    results = zero_shot_classifier(candidate_texts, list(entities))
    # Depending on the transformers version, a single input may come back as a
    # bare dict instead of a one-element list; normalize so the loops below
    # can always iterate over results.
    if isinstance(results, dict):
        results = [results]
else:
    results = []

In [1574]:
def get_highest_score_label(res, label_map=None):
    '''
    Aggregate one zero-shot result by entity group and return the best group.

    Each candidate label is mapped to its group, and the scores of all labels
    belonging to the same group are summed.

    Args:
      - res (dict) -> one zero-shot classifier output with parallel
        'labels' and 'scores' lists.
      - label_map (dict, optional) -> candidate label -> group name.
        Defaults to the notebook-level `org_entities_map`.

    Returns:
      tuple (group, summed_score) for the highest-scoring group. On a tie, the
      group encountered first in `res['labels']` wins.

    Raises:
      - KeyError if a label is missing from `label_map`.
      - ValueError if `res` contains no labels.
    '''
    if label_map is None:
        label_map = org_entities_map
    group_scores = {}
    for sc, label in zip(res['scores'], res['labels']):
        group = label_map[label]
        group_scores[group] = group_scores.get(group, 0) + sc
    # max() returns the first maximal item, matching the previous stable
    # sort(reverse=True)[0] tie-breaking.
    return max(group_scores.items(), key=lambda item: item[1])

def _to_top_group(res):
    '''Rewrite one classifier output so it keeps only its best entity group and that group's summed score.'''
    label, score = get_highest_score_label(res)
    return {"sequence": res['sequence'], "labels": [label], "scores": [score]}


# Replace each raw classifier output with its single best group, keeping the
# same result shape (parallel one-element "labels"/"scores" lists).
results = [_to_top_group(res) for res in results]

In [1575]:
# Index the collapsed results by their span text for lookup when relabelling entities.
results_map = dict((r['sequence'], r) for r in results)

In [1576]:
# displacy options: one color per entity group (zip stops at the shorter of
# the two sequences) and the group names as the initial list of labels to
# draw. NOTE(review): 'distance' looks like a dependency-visualizer setting;
# confirm it has any effect with style="ent".
group_names = list(org_entities)
options = {
    "colors": dict(zip(group_names, colors)),
    "ents": list(group_names),
    'distance': 90,
}

In [1577]:
# Rebuild the EntityRuler from scratch so that re-running this cell does not
# accumulate patterns from earlier runs.
nlp.remove_pipe("entity_ruler")
ruler = nlp.add_pipe("entity_ruler", config={"overwrite_ents": True})

# Register each confidently classified span as a phrase pattern. A string
# pattern is tokenized by the pipeline itself, so it matches exactly the tokens
# spaCy produces. Whitespace-splitting the text instead never matches spans
# whose tokens are not whitespace-separated (hyphens, punctuation, ...).
new_patterns = [
    {"label": res['labels'][0], "pattern": res['sequence']}
    for res in results
    if res['scores'][0] > CONF_TH
]
if new_patterns:
    ruler.add_patterns(new_patterns)
    

In [1578]:
# Re-run the full pipeline, now including the rebuilt EntityRuler, so that
# doc.ents holds the classified spans.
doc = nlp(clean_text)

In [1579]:
# Decorate each entity label with its group confidence, e.g. "meat (87%)", and
# register a matching color / ents entry so that displacy renders it.
ents = doc.ents
for span in ents:
    confidence = float(results_map[str(span)]['scores'][0])
    decorated = f"{span.label_} ({confidence:.0%})"
    options["colors"][decorated] = options["colors"].get(span.label_.lower(), None)
    options["ents"].append(decorated)
    span.label_ = decorated
doc.ents = ents

In [1580]:
# Render the entity visualization using the group colors and the
# confidence-decorated labels registered in `options`.
displacy.render(doc, style="ent", options=options)