<img src="../assets/CoLLIE_blue.png" alt="GoLLIE" width="200"/>

# Named Entity Recognition with GoLLIE

This notebook is an example of how to run Named Entity Recognition with GoLLIE.  
In this example, we will use the CoNLL03 guidelines: https://www.clips.uantwerpen.be/conll2003/ner/

You can modify the script to run any Named Entity Recognition task you want.

### Import requirements

See the requirements.txt file in the main directory to install the required dependencies.

In [1]:
import sys
sys.path.append("../") # Add the GoLLIE base directory to sys path

In [178]:
import inspect
import logging
import tempfile
from typing import Dict, List, Type

import black
from jinja2 import Template

from src.model.load_model import load_model
from src.tasks.utils_typing import AnnotationList

logging.basicConfig(level=logging.INFO)

## Load GoLLIE

We will load GoLLIE-7B from the Hugging Face Hub.

- Set force_auto_device_map=True if you want to use the GPU
- Set quantization=4 if the model doesn't fit in your GPU memory
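
As a rough guide for choosing between full precision and quantization, the memory needed for the weights alone can be estimated from the parameter count and the number of bits per parameter. The sketch below is only a back-of-the-envelope calculation: the function name is hypothetical, the 7 billion parameter count is a round-number assumption, and activations, the generation cache and any quantization overhead are ignored.

```python
def estimate_weight_memory_mb(num_parameters: int, bits_per_parameter: int) -> float:
    """Approximate memory needed to hold the weights only, in MB (1 MB = 10**6 bytes)."""
    total_bytes = num_parameters * bits_per_parameter / 8
    return total_bytes / 1_000_000


# Assumption: a "7B" model has roughly 7 billion parameters.
ASSUMED_PARAMETERS = 7_000_000_000

for label, bits in [("float32", 32), ("bfloat16", 16), ("8-bit", 8), ("4-bit", 4)]:
    size_mb = estimate_weight_memory_mb(ASSUMED_PARAMETERS, bits)
    print(f"{label:>8}: ~{size_mb:,.0f} MB")
```

The float32 estimate is in the same range as the memory footprint that the loading log reports for this model, which is why 4-bit quantization is suggested when GPU memory is limited.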

In [80]:
model, tokenizer = load_model(
    inference=True,
    model_weights_name_or_path="codellama/CodeLlama-7b-hf",
    quantization=None,
    use_lora=True,
    lora_weights_name_or_path="/ikerlariak/osainz006/models/collie/CoLLIE+-7b_CodeLLaMA",
    force_auto_device_map=False,
    use_flash_attention=False,
)

INFO:root:Loading model model from codellama/CodeLlama-7b-hf
INFO:root:We will load the model using the following device map: None and max_memory: None
INFO:root:Loading model with dtype: None


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

INFO:root:Model dtype: torch.float32
INFO:root:Total model memory footprint: 27491.065856 MB
INFO:root:Loading pretrained LORA weights from /ikerlariak/osainz006/models/collie/CoLLIE+-7b_CodeLLaMA
INFO:root:
LoRA config:
{'default': LoraConfig(peft_type='LORA', auto_mapping=None, base_model_name_or_path='/gaueko1/hizkuntza-ereduak/Code-LLaMA/huggingface/7b', revision=None, task_type='CAUSAL_LM', inference_mode=True, r=8, target_modules=['gate_proj', 'o_proj', 'down_proj', 'k_proj', 'up_proj', 'v_proj', 'q_proj'], lora_alpha=16, lora_dropout=0.05, fan_in_fan_out=False, bias='none', modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None)}

INFO:root:Merging LoRA layers into the model for faster inference.


In [179]:
print(model.device)

cpu


## Define the guidelines

First, we will define the labels and guidelines for the task. We define them as Python classes.


In [33]:
from typing import List

from src.tasks.utils_typing import Entity, dataclass

"""Entity definitions

The entity definitions are derived from the official ConLL2003 guidelines:
https://www.clips.uantwerpen.be/conll2003/ner/
Based on: Nancy Chinchor, Erica Brown, Lisa Ferro, Patty Robinson,
           "1999 Named Entity Task Definition". MITRE and SAIC, 1999.
"""


@dataclass
class Person(Entity):
    """first, middle and last names of people, animals and fictional characters aliases."""

    span: str  # Such as: "Clinton", "Dole", "Arafat", "Yeltsin", "Lebed"


@dataclass
class Organization(Entity):
    """Companies (press agencies, studios, banks, stock markets, manufacturers, cooperatives) subdivisions of
    companies (newsrooms) brands political movements (political parties, terrorist organisations) government bodies
    (ministries, councils, courts, political unions of countries (e.g. the {\it U.N.})) publications (magazines, newspapers,
    journals) musical companies (bands, choirs, opera companies, orchestras public organisations (schools, universities,
    charities other collections of people (sports clubs, sports teams, associations, theaters companies, religious orders,
    youth organisations."""

    span: str  # Such as: "Reuters", "U.N.", "NEW YORK", "CHICAGO", "PUK"


@dataclass
class Location(Entity):
    """Roads (streets, motorways) trajectories regions (villages, towns, cities, provinces, countries, continents,
    dioceses, parishes) structures (bridges, ports, dams) natural locations (mountains, mountain ranges, woods, rivers,
    wells, fields, valleys, gardens, nature reserves, allotments, beaches, national parks) public places (squares, opera
    houses, museums, schools, markets, airports, stations, swimming pools, hospitals, sports facilities, youth centers,
    parks, town halls, theaters, cinemas, galleries, camping grounds, NASA launch pads, club houses, universities,
    libraries, churches, medical centers, parking lots, playgrounds, cemeteries) commercial places (chemists, pubs,
    restaurants, depots, hostels, hotels, industrial parks, nightclubs, music venues) assorted buildings (houses, monasteries,
    creches, mills, army barracks, castles, retirement homes, towers, halls, rooms, vicarages, courtyards) abstract
    ``places'' (e.g. {\it the free world})"""

    span: str  # Such as: "U.S.", "Germany", "Britain", "Australia", "England"


@dataclass
class Miscellaneous(Entity):
    """Words of which one part is a location, organisation, miscellaneous, or person adjectives and other words derived
    from a word which is location, organisation, miscellaneous, or person religions political ideologies nationalities
    languages programs events (conferences, festivals, sports competitions, forums, parties, concerts) wars sports related
    names (league tables, leagues, cups titles (books, songs, films, stories, albums, musicals, TV programs) slogans eras
    in time types (not brands) of objects (car types, planes, motorbikes)"""

    span: str  # Such as: "Russian", "German", "British", "French", "Dutch"


ENTITY_DEFINITIONS: List[Entity] = [
    Person,
    Organization,
    Location,
    Miscellaneous,
]
    
if __name__ == "__main__":
    cell_txt = In[-1]

### Print the guidelines to guidelines.py

Due to IPython limitations, we need to write the content of the previous cell to a file and import the contents of that file.

In [134]:
with open("guidelines.py","w",encoding="utf8") as python_guidelines:
    print(cell_txt,file=python_guidelines)

from guidelines import *

We use inspect.getsource to get the guidelines as strings, one per class.

In [135]:
guidelines = [inspect.getsource(definition) for definition in ENTITY_DEFINITIONS]
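
The reason for the round trip through a file is that `inspect.getsource` reads the source code from the file that defines the class. The following self-contained sketch reproduces the same idea outside a notebook. It uses the standard-library `dataclasses.dataclass` as a stand-in for the GoLLIE `Entity` base class, and the helper name `definitions_as_source` and the module name `demo_guidelines` are hypothetical.

```python
import importlib.util
import inspect
import sys
import tempfile
from pathlib import Path

# Stand-in guidelines; the real notebook uses the classes from src.tasks.utils_typing.
GUIDELINES_SOURCE = '''from dataclasses import dataclass


@dataclass
class Person:
    """Names of people."""

    span: str


@dataclass
class Location:
    """Names of places."""

    span: str
'''


def definitions_as_source(source, class_names, module_name="demo_guidelines"):
    """Write `source` to a file, import it, and return the source of each class."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = Path(tmp_dir) / f"{module_name}.py"
        path.write_text(source, encoding="utf8")
        spec = importlib.util.spec_from_file_location(module_name, path)
        module = importlib.util.module_from_spec(spec)
        # inspect locates the defining file through sys.modules.
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
            return [inspect.getsource(getattr(module, name)) for name in class_names]
        finally:
            sys.modules.pop(module_name, None)


for definition in definitions_as_source(GUIDELINES_SOURCE, ["Person", "Location"]):
    print(definition)
```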

## Define input sentence

Here we define the input sentence and the gold labels.

You can define an empty list as gold labels if you don't have gold annotations.

In [136]:
text = "Japan began the defence of their Asian Cup title with a lucky 2-1 win against Syria in a Group C championship match on Friday ."
gold = [Location(span="Japan"),Miscellaneous(span="Asian Cup"),Location(span="Syria")]

## Fill the template

For NER, we will use the following prompt template:

```Python
# The following lines describe the task definition
{%- for definition in guidelines %}
{{ definition }}
{%- endfor %}

# This is the text to analyze
text = {{ text.__repr__() }}

# The annotation instances that take place in the text above are listed here
result = [
{%- for ann in annotations %}
    {{ ann }},
{%- endfor %}
]

```

This template is stored in `templates/prompt.txt`.
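
The template inserts the sentence with `text.__repr__()` rather than wrapping it in quotes by hand. A short sketch of why this matters: `repr` of a Python string yields a valid string literal, so quotes and backslashes inside the sentence cannot break the generated code. The helper name and the example sentence are made up for illustration.

```python
import ast


def as_assignment(sentence: str) -> str:
    """Build the `text = ...` line the way the template does, via repr()."""
    return f"text = {sentence!r}"


# A sentence containing both kinds of quotes and a backslash.
sentence = 'He called it a "lucky" 2-1 win, didn\'t he? \\ Friday .'
line = as_assignment(sentence)
print(line)

# The generated line is valid Python and round-trips to the original string.
ast.parse(line)
assert ast.literal_eval(line.split("=", 1)[1].strip()) == sentence
```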

In [181]:
with open("../templates/prompt.txt", "rt") as f:
    template = Template(f.read())
text = template.render(guidelines=guidelines, text=text, annotations=gold, gold=gold)

### Black Code Formatter

We use the Black Code Formatter to automatically unify all the prompts to the same format. 

https://github.com/psf/black

In [182]:
black_mode = black.Mode()
text = black.format_str(text, mode=black_mode)

### Print the filled and formatted template

In [184]:
print(text)

# The following lines describe the task definition
@dataclass
class Person(Entity):
    """first, middle and last names of people, animals and fictional characters aliases."""

    span: str  # Such as: "Clinton", "Dole", "Arafat", "Yeltsin", "Lebed"


@dataclass
class Organization(Entity):
    """Companies (press agencies, studios, banks, stock markets, manufacturers, cooperatives) subdivisions of
    companies (newsrooms) brands political movements (political parties, terrorist organisations) government bodies
    (ministries, councils, courts, political unions of countries (e.g. the {\it U.N.})) publications (magazines, newspapers,
    journals) musical companies (bands, choirs, opera companies, orchestras public organisations (schools, universities,
    charities other collections of people (sports clubs, sports teams, associations, theaters companies, religious orders,
    youth organisations."""

    span: str  # Such as: "Reuters", "U.N.", "NEW YORK", "CHICAGO", "PUK"


@datacla

## Prepare model inputs

We remove everything after `result =` to run inference with the model.

In [107]:
prompt, _ = text.split("result =")
prompt = prompt + "result ="
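
Note that unpacking `text.split("result =")` into two names raises a `ValueError` if the marker happens to appear more than once, for example inside the sentence being analyzed. A slightly more defensive variant, shown here as a sketch with a hypothetical helper name, cuts at the last occurrence with `str.rpartition`:

```python
def build_prompt(filled_template: str, marker: str = "result =") -> str:
    """Keep everything up to and including the last occurrence of `marker`."""
    head, separator, _ = filled_template.rpartition(marker)
    if not separator:
        raise ValueError(f"marker {marker!r} not found in the template")
    return head + separator


example = (
    "# This is the text to analyze\n"
    "text = 'The code says result = 3 .'\n"
    "\n"
    "result = [\n"
    "    Miscellaneous(span='code'),\n"
    "]\n"
)
print(build_prompt(example))
```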

In [108]:
model_input = tokenizer(prompt, add_special_tokens=True, return_tensors="pt")
model_input["input_ids"] = model_input["input_ids"][:, :-1]

## Run GoLLIE

We generate the predictions using GoLLIE. 

We use `num_beams=1` and `do_sample=False` in our experiments, but feel free to experiment 😊

- If you are running the model on CPU, this cell will take ~1 minute
- If you are running the model on CUDA, this cell will take <10 seconds

In [109]:
%%time

model_ouput = model.generate(
    input_ids=model_input.input_ids.to(model.device),
    max_new_tokens=128,
    do_sample=False,
    min_new_tokens=0,
    num_beams=1,
    num_return_sequences=1,
)


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


CPU times: user 15min 22s, sys: 15.9 s, total: 15min 37s
Wall time: 47.5 s


### Print the results

In [128]:
for y, x in enumerate(model_ouput):
    print(f"Answer {y}")
    print(tokenizer.decode(x,skip_special_tokens=True).split("result = ")[-1])

Answer 0
[
    Location(span="Japan"),
    Miscellaneous(span="Asian Cup"),
    Location(span="Syria"),
]



## Parse the output

The output is valid Python code, so we can execute it to parse the results and get a list with the annotations 🤯

We define the AnnotationList class to parse the output with a single line of code. The `AnnotationList.from_output` function filters any label that we did not define (hallucinations) to prevent getting an `undefined class` error. 

In [140]:
result = AnnotationList.from_output(tokenizer.decode(model_ouput[0],skip_special_tokens=True).split("result = ")[-1],task_module="guidelines")
print(result)

[Location(span='Japan'), Miscellaneous(span='Asian Cup'), Location(span='Syria')]
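
To illustrate the idea of filtering undefined labels, here is a minimal sketch that parses a generated list with the standard-library `ast` module and keeps only calls to known classes. It is not the `AnnotationList.from_output` implementation; the function name, the stand-in dataclasses and the `Dragon` label are assumptions made for the example.

```python
import ast
from dataclasses import dataclass


@dataclass
class Location:
    span: str


@dataclass
class Miscellaneous:
    span: str


KNOWN_LABELS = {cls.__name__: cls for cls in (Location, Miscellaneous)}


def parse_annotations(output: str, known_labels=KNOWN_LABELS):
    """Parse `[Label(span=...), ...]` and drop labels that were never defined."""
    tree = ast.parse(output.strip(), mode="eval")
    if not isinstance(tree.body, ast.List):
        raise ValueError("expected a Python list literal")
    annotations = []
    for node in tree.body.elts:
        if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)):
            continue  # not of the form Label(...)
        label = known_labels.get(node.func.id)
        if label is None:
            continue  # hallucinated label, skip it
        try:
            kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords if kw.arg}
            annotations.append(label(**kwargs))
        except (ValueError, TypeError):
            continue  # malformed arguments, skip this annotation
    return annotations


generated = '[Location(span="Japan"), Dragon(span="Smaug"), Miscellaneous(span="Asian Cup")]'
print(parse_annotations(generated))
```

Because the arguments are read with `ast.literal_eval` instead of being executed, this variant never runs code produced by the model.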


Each label is an instance of one of the defined classes:

In [162]:
type(result[0])

guidelines.Location

In [163]:
result[0].span

'Japan'

## Evaluate the result

Finally, we will evaluate the outputs from the model.

First, we define a scorer. For Named Entity Recognition, we will use the `SpanScorer` class.

We need to define the valid_types for the scorer, which will be the labels that we have defined. 

In [175]:
from src.tasks.utils_scorer import SpanScorer

class CoNLL03EntityScorer(SpanScorer):
    """CoNLL03 Entity identification and classification scorer."""

    valid_types: List[Type] = ENTITY_DEFINITIONS

    def __call__(self, reference: List[Entity], predictions: List[Entity]) -> Dict[str, Dict[str, float]]:
        output = super().__call__(reference, predictions)
        return {"entities": output["spans"]}


### Instantiate the scorer

In [None]:
scorer = CoNLL03EntityScorer()

### Compute F1 

In [186]:
scorer(reference=[gold],predictions=[result])

{'entities': {'precision': 1.0,
  'recall': 1.0,
  'f1-score': 1.0,
  'class_scores': {'Location': {'tp': 2,
    'total_pos': 2,
    'total_pre': 2,
    'precision': 1.0,
    'recall': 1.0,
    'f1-score': 1.0},
   'Miscellaneous': {'tp': 1,
    'total_pos': 1,
    'total_pre': 1,
    'precision': 1.0,
    'recall': 1.0,
    'f1-score': 1.0}}}}
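
The relation between the counts and the scores can be sketched as follows. This is not the `SpanScorer` implementation; it assumes that `tp` is the number of predictions that match a gold annotation, `total_pos` the number of gold annotations and `total_pre` the number of predictions. The function name and the `(label, span)` tuple representation are hypothetical.

```python
from collections import Counter


def span_scores(reference, predictions):
    """Precision, recall and F1 over (label, span) pairs, counting duplicates."""
    ref_counts = Counter(reference)
    pred_counts = Counter(predictions)
    tp = sum((ref_counts & pred_counts).values())  # matched pairs
    total_pos = sum(ref_counts.values())  # gold annotations
    total_pre = sum(pred_counts.values())  # predicted annotations
    precision = tp / total_pre if total_pre else 0.0
    recall = tp / total_pos if total_pos else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "tp": tp,
        "total_pos": total_pos,
        "total_pre": total_pre,
        "precision": precision,
        "recall": recall,
        "f1-score": f1,
    }


gold_pairs = [("Location", "Japan"), ("Miscellaneous", "Asian Cup"), ("Location", "Syria")]
predicted_pairs = [("Location", "Japan"), ("Miscellaneous", "Asian Cup"), ("Person", "Syria")]
print(span_scores(gold_pairs, predicted_pairs))
```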