# LLM prompting for entity labeling
This notebook contains starter code for prompting an LLM API for the task of entity recognition.

In [None]:
!pip install ipytest
!pip install transformers
!pip install datasets
!pip install evaluate
!pip install seqeval
!pip install ratelimit
!pip install cohere

In [None]:
# This code block just contains standard setup code for running in Python
import json
import string
import re
import time
from tqdm.auto import tqdm
import random

# PyTorch imports
import torch
from torch.utils.data import DataLoader
import numpy as np
from transformers import AutoTokenizer, BertModel, DefaultDataCollator

from datasets import load_dataset

import evaluate
from ratelimit import limits
import cohere

In [24]:
# Fix the random seed(s) for reproducibility. Python's own `random` module is seeded too,
# because get_chat_history draws its few-shot demonstrations with random.sample.
random_seed = 8942764
random.seed(random_seed)
torch.random.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
np.random.seed(random_seed)

In [None]:
# Just a helper function for efficiently removing punctuation from a string.
# str.translate needs a translation table; passing string.punctuation itself would map
# code points 0-31 onto punctuation characters and leave real punctuation untouched.
_PUNCT_DELETE_TABLE = str.maketrans('', '', string.punctuation)

def strip_punct(s):
    """Return `s` with every character of string.punctuation removed."""
    return s.translate(_PUNCT_DELETE_TABLE)

In [None]:
# Set up LLM backend. They are mostly the same, but some slight differences.

# Initialize Cohere LLM client with your API key.
# You can register for an account here: https://dashboard.cohere.ai/welcome/register
# Then you can find your API key here: https://dashboard.cohere.com/api-keys
# Prefer supplying the key through the CO_API_KEY environment variable so it never gets saved
# into the notebook; the 'API-KEY' placeholder is only used when that variable is unset.
import os
co = cohere.Client(os.environ.get('CO_API_KEY', 'API-KEY'))
# Role names and message key used when building Cohere chat histories. Running the OpenAI
# setup cell later rebinds these three globals to OpenAI's values.
USER_STR = "USER"
SYSTEM_STR = "SYSTEM"
MSG_STR = "message"

In [None]:
# Here is how you can use the API to prompt the Cohere model. All the LLM APIs have pretty much the same format.
# Docs: https://docs.cohere.com/reference/chat

# We're providing one prompt format here, which we'll use as the "baseline" format.

# Here's an example of one way you can provide a prompt and demonstrations to the model, through the chat history.
# Here, we provide the initial prompt using the SYSTEM role, then provide each example (here, just one) as a USER, SYSTEM interaction.
# NOTE(review): the demonstration's answer turn also uses SYSTEM_STR rather than a chatbot/assistant
# role; this is the baseline format as given, but worth revisiting when tuning prompts.
# The turns are keyed by USER_STR/SYSTEM_STR/MSG_STR, so this cell assumes the Cohere setup
# cell's values ("USER"/"SYSTEM"/"message") are the ones currently bound.
chat_history = [
    # Task instructions: the entity types to tag and the inline tag format to use.
    {'role': SYSTEM_STR, MSG_STR:
     """You will be given input text containing different types of entities that you will label.
     This is the list of entity types to label: Deity, Mythological_king, Cretaceous_dinosaur, Aquatic_mammal, Aquatic_animal, Goddess.
     Label the enities by surrounding them with tags like '<Cretaceous_dinosaur> Beipiaognathus </Cretaceous_dinosaur>'."""
     }, 
     # One demonstration: the raw input text ...
     {'role': USER_STR, MSG_STR: """Text: Once paired in later myths with her Titan brother Hyperion as her husband, mild-eyed Euryphaessa, the far-shining one of the Homeric Hymn to Helios, was said to be the mother of Helios (the Sun), Selene (the Moon), and Eos (the Dawn)."""},
     # ... followed by the same text with its entities tagged.
     {'role': SYSTEM_STR, MSG_STR: """Labels: Once paired in later myths with her Titan brother <Deity> Hyperion </Deity> as her husband, mild-eyed Euryphaessa, the far-shining one of the Homeric Hymn to Helios, was said to be the mother of Helios (the Sun), <Goddess> Selene </Goddess> (the Moon), and <Goddess> Eos </Goddess> (the Dawn)."""
}]

# This is where you provide the final prompt that we want the model to complete to give us the answer.
# It ends with "Labels: ", mirroring the demonstration's answer prefix.
message = f"""Text: From her ideological conception, Taweret was closely grouped with (and is often indistinguishable from) several other protective hippopotamus goddesses: Ipet, Reret, and Hedjet.
Labels: """

# temperature=0.0 requests the least random decoding the API offers.
response = co.chat(
    model="command-r-plus",
    temperature=0.0,
    chat_history=chat_history,
    message=message
)

print(response.text)

In [None]:
# OpenAI backend
!pip install openai
from openai import OpenAI

# Use the API key that we 
client = OpenAI(api_key='OPENAI-KEY', base_url="https://cmu.litellm.ai")
USER_STR = "user"
SYSTEM_STR = "system"
MSG_STR = "content"

In [None]:
# Here is how you can use the API to prompt the OpenAI model using the same prompt as we used above for Cohere. 
# Docs: https://platform.openai.com/docs/api-reference
# Unlike the Cohere call, the final query is not a separate argument: it is simply the last
# user turn in `messages`.
# NOTE(review): the demonstration's answer is sent with SYSTEM_STR; OpenAI's usual role for
# model replies is "assistant". Kept as-is to match the baseline format.
messages = [
    # Task instructions: the entity types to tag and the inline tag format to use.
    {'role': SYSTEM_STR, MSG_STR:
     """You will be given input text containing different types of entities that you will label.
     This is the list of entity types to label: Deity, Mythological_king, Cretaceous_dinosaur, Aquatic_mammal, Aquatic_animal, Goddess.
     Label the enities by surrounding them with tags like '<Cretaceous_dinosaur> Beipiaognathus </Cretaceous_dinosaur>'."""
     }, 
     # One demonstration: the raw input text ...
     {'role': USER_STR, MSG_STR: """Text: Once paired in later myths with her Titan brother Hyperion as her husband, mild-eyed Euryphaessa, the far-shining one of the Homeric Hymn to Helios, was said to be the mother of Helios (the Sun), Selene (the Moon), and Eos (the Dawn)."""},
     # ... followed by the same text with its entities tagged.
     {'role': SYSTEM_STR, MSG_STR: """Labels: Once paired in later myths with her Titan brother <Deity> Hyperion </Deity> as her husband, mild-eyed Euryphaessa, the far-shining one of the Homeric Hymn to Helios, was said to be the mother of Helios (the Sun), <Goddess> Selene </Goddess> (the Moon), and <Goddess> Eos </Goddess> (the Dawn)."""},
     # The query to label; it ends with "Labels: " for the model to complete.
     {'role': USER_STR, MSG_STR: """Text: From her ideological conception, Taweret was closely grouped with (and is often indistinguishable from) several other protective hippopotamus goddesses: Ipet, Reret, and Hedjet.\nLabels: """}
]
print(messages)

# temperature=0.0 plus a fixed seed, to make the output as repeatable as the API allows.
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    temperature=0.0,
    seed=random_seed,
    messages=messages
)

print(response.choices[0].message.content)

# You can also print out the usage, in number of tokens. 
# Pricing is per input/output token, listed here: https://openai.com/pricing
print(f"Usage: {response.usage.prompt_tokens} input, {response.usage.completion_tokens} output, {response.usage.total_tokens} total tokens")

In [None]:
# Load the dataset: train/dev/test splits from local JSONL files.
from datasets import Dataset, ClassLabel, Sequence

split_files = {
    'train': 'dinos_and_deities_train_bio.jsonl',
    'dev': 'dinos_and_deities_dev_bio_sm.jsonl',
    'test': 'dinos_and_deities_test_bio_nolabels.jsonl',
}
data_splits = load_dataset('json', data_files=split_files)

# The .labels file holds the BIO tag names separated by whitespace; a tag's position in the file
# is the integer id used in the 'ner_tags' column. Build both directions of the mapping.
label_names_fname = "dinos_and_deities_train_bio.jsonl.labels"
with open(label_names_fname) as label_file:
    labels_int2str = label_file.read().split()
print(f"Labels: {labels_int2str}")
labels_str2int = {label: idx for idx, label in enumerate(labels_int2str)}

# Entity types with the B-/I- prefix removed; 'O' strips to '' and is skipped.
orig_labels = {label[2:] for label in labels_str2int if label[2:]}
print(f"Orig labels: {orig_labels}")

# Optional: data_splits.cast_column("ner_tags", Sequence(ClassLabel(names=labels_int2str)))
print(data_splits)

In [None]:
# Let's inspect a single example
# (index 5 of the dev split; every field is printed as indented JSON).
dev_example = data_splits['dev'][5]

print(json.dumps(dev_example, indent=4))

In [99]:
# Ok, now let's make the prompting a bit more programmatic. This function turns one dataset
# example into the final user message, using the same format as the hand-written prompt above.

def get_message(example):
    """Build the query turn for `example`.

    Args:
        example: dataset row with a 'content' field holding the raw sentence.

    Returns:
        'Text: <content>' followed by a newline and an empty 'Labels: ' slot for the model to fill.
    """
    return 'Text: {}\nLabels: '.format(example['content'])

In [101]:
# Next we're going to implement a function to return the chat_history, but in order to do that we first need
# to be able to convert labeled examples from the dataset into a format that makes more sense for the model,
# in this case the HTML-style format we specified in the example. That's the task for this function: take
# an example from the dataset as input, and return a string that has tagged the text with labels in the given
# HTML-style format.
# 
def convert_bio_to_prompt(example):
    """Render a BIO-labeled example as text with HTML-style entity tags.

    Each entity span is wrapped exactly once: a B- tag, followed by any I- tags of the same type,
    becomes e.g. ``<Aquatic_mammal> blue whale </Aquatic_mammal>``. Two adjacent B- tags stay two
    separate spans. A single trailing '.', ':' or ',' on a span's last token is moved outside the
    closing tag (``<Goddess> Ipet </Goddess>,``). An I- tag that does not continue an open span of
    the same type starts a new span.

    Args:
        example: dict with parallel lists 'tokens' (words) and 'ner_strings' (BIO tag strings).

    Returns:
        The tokens joined by single spaces, with opening/closing tags inserted around each span.
    """
    tokens = example['tokens']
    tags = example['ner_strings']
    n = min(len(tokens), len(tags))  # like zip(): ignore any unpaired tail
    trailing_punct = ('.', ':', ',')
    pieces = []
    open_type = None  # entity type of the span currently open, or None

    for idx in range(n):
        token, tag = tokens[idx], tags[idx]
        if tag == 'O':
            pieces.append(token)
            continue

        entity_type = tag[2:]
        # Open a new span unless this is an I- tag continuing the span already open.
        if not (tag.startswith('I-') and open_type == entity_type):
            pieces.append(f"<{entity_type}>")
            open_type = entity_type

        next_tag = tags[idx + 1] if idx + 1 < n else 'O'
        if next_tag == f"I-{entity_type}":
            # The span continues onto the next token; keep this token intact.
            pieces.append(token)
            continue

        # Last token of the span: close it, keeping trailing punctuation outside the tag.
        if token and token[-1] in trailing_punct:
            pieces.append(token[:-1])
            pieces.append(f"</{entity_type}>{token[-1]}")
        else:
            pieces.append(token)
            pieces.append(f"</{entity_type}>")
        open_type = None

    # Skip empty pieces (e.g. a labeled token that was only punctuation) to avoid double spaces.
    return " ".join(piece for piece in pieces if piece)

In [None]:
# Check the BIO -> tagged-text conversion on the dev example loaded above.
label_example = convert_bio_to_prompt(dev_example)
print(label_example)

In [None]:
# Sample selection: strip the B-/I- prefixes so each token carries just its entity type (or 'O'),
# which makes it easy to measure how many entity types each training example covers.
# TODO: balanced selection ("10 examples for each label") is not implemented yet;
# get_chat_history currently draws its demonstrations with random.sample.
dataset = data_splits['train']
train_ner_strings = dataset['ner_strings']
print(train_ner_strings[0])
plain_tags = [[tag if tag == 'O' else tag[2:] for tag in ner_strings] for ner_strings in train_ner_strings]
# Preview a few rows rather than dumping the tags of the entire training split.
print(plain_tags[:5])

In [None]:
# Rank training examples in decreasing order of tag diversity (number of distinct tag types,
# 'O' included), keeping each example's original index so it can be looked up in the dataset.
# `tag_type_count` is rebuilt here from plain_tags, one Counter per example, so this cell runs
# on a fresh kernel. NOTE(review): presumably this matches the earlier, now-missing definition.
from collections import Counter
tag_type_count = sorted(
    enumerate(Counter(tags) for tags in plain_tags),
    key=lambda pair: len(pair[1]),
    reverse=True,
)
# Preview only the most diverse examples.
print(tag_type_count[:10])

In [None]:
# Keep just the dataset indices, most tag-diverse examples first.
sorted_indices = [pair[0] for pair in tag_type_count]
print(sorted_indices)

In [200]:
# Now we can write a function that takes the number of shots, dataset, list of entity types, and 
# convert_bio_to_prompt function, and returns the chat_history (a list of maps) structured as in 
# the example.
#
def get_chat_history(shots, dataset, entity_types_list, convert_bio_to_prompt_fn):
    """Build a chat history: task instructions plus `shots` randomly drawn labeled demonstrations.

    Args:
        shots: number of demonstrations, drawn without replacement via random.sample.
        dataset: labeled split; each row needs 'content' plus whatever convert_bio_to_prompt_fn reads.
        entity_types_list: entity type names to list in the instructions. Types from the default list
            keep its order and any others follow alphabetically, so the prompt text is deterministic
            even when a set is passed. If empty/None, the default list is used.
        convert_bio_to_prompt_fn: renders a labeled row as tagged text.

    Returns:
        List of {'role': ..., MSG_STR: ...} dicts: one SYSTEM_STR instruction turn, then a
        USER_STR ("Text: ...") / SYSTEM_STR ("Labels: ...") pair per demonstration.

    Raises:
        ValueError: from random.sample if `shots` is negative or larger than len(dataset).
    """
    default_types = ['Deity', 'Mythological_king', 'Cretaceous_dinosaur', 'Aquatic_mammal', 'Aquatic_animal', 'Goddess']
    requested = set(entity_types_list) if entity_types_list else set(default_types)
    ordered_types = [t for t in default_types if t in requested] + sorted(requested - set(default_types), key=str)

    # Same wording/whitespace as the original baseline prompt (including the 'enities' typo),
    # so results stay comparable when the default entity types are requested.
    system_prompt = (
        "You will be given input text containing different types of entities that you will label.\n"
        f"     This is the list of entity types to label: {', '.join(ordered_types)}. \n"
        "     Label the enities by surrounding them with tags like '<Cretaceous_dinosaur> Beipiaognathus </Cretaceous_dinosaur>'.\n"
        "     "
    )
    history = [{'role': SYSTEM_STR, MSG_STR: system_prompt}]
    for idx in random.sample(range(len(dataset)), shots):
        row = dataset[idx]  # fetch once; indexing a datasets.Dataset builds a new dict each time
        history.append({'role': USER_STR, MSG_STR: f"Text: {row['content']}"})
        history.append({'role': SYSTEM_STR, MSG_STR: f"Labels: {convert_bio_to_prompt_fn(row)}"})
    return history

In [None]:
# Now we can put all of those together to prompt the model more automagically!

# For Cohere:
num_shots = 20
# get_chat_history keys its turns by the global USER_STR/SYSTEM_STR/MSG_STR, and the OpenAI setup
# cell above rebinds those to "user"/"system"/"content". Normalize to Cohere's schema (upper-case
# roles, 'message' key) so this cell works whichever backend constants are currently bound.
cohere_history = [
    {'role': turn['role'].upper(), 'message': turn.get('message', turn.get('content'))}
    for turn in get_chat_history(num_shots, data_splits['train'], orig_labels, convert_bio_to_prompt)
]
response = co.chat(
    model="command-r-plus",
    temperature=0.0,
    chat_history=cohere_history,
    message=get_message(dev_example)
)
print(response.text)

In [None]:
# For OpenAI: build the prompt programmatically (zero demonstrations here) and query the model.
num_shots = 0

chat_history = get_chat_history(num_shots, data_splits['train'], orig_labels, convert_bio_to_prompt)
# The query is just one more user turn appended to the history.
message = {'role': USER_STR, MSG_STR: get_message(dev_example)}
chat_history += [message]
print(chat_history)

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    temperature=0.0,
    seed=random_seed,
    messages=chat_history,
)

reply_text = response.choices[0].message.content
print(reply_text)
# The first 7 characters show whether the reply starts with the "Labels:" prefix.
print(reply_text[:7])

In [204]:
# Now let's wrap that call in a function that takes shots and an example, calls the API and returns the response.

# Cohere:
def call_api_cohere(shots, example, max_retries=8, base_delay=1.0):
    """Prompt the Cohere model with `shots` demonstrations and return its reply text.

    Args:
        shots: number of training demonstrations to include in the chat history.
        example: dataset row to label (passed to get_message).
        max_retries: number of API attempts before giving up (at least one attempt is made).
        base_delay: initial back-off in seconds; doubled after each failure, capped at 60s.

    Returns:
        The model's response text.

    Raises:
        The last exception raised by `co.chat` once every attempt has failed. Errors raised while
        building the prompt (e.g. ValueError for too many shots) propagate immediately instead of
        being retried forever.
    """
    # Sample the demonstrations once, outside the retry loop: re-sampling on every retry would
    # consume extra draws from `random` and make the prompt depend on how many calls failed.
    history = get_chat_history(shots, data_splits['train'], orig_labels, convert_bio_to_prompt)
    # get_chat_history keys turns by the global USER_STR/SYSTEM_STR/MSG_STR, which the OpenAI setup
    # cell rebinds; normalize to Cohere's upper-case roles and 'message' key.
    history = [
        {'role': turn['role'].upper(), 'message': turn.get('message', turn.get('content'))}
        for turn in history
    ]
    query = get_message(example)

    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            response = co.chat(
                model="command-r-plus",
                temperature=0.0,
                chat_history=history,
                message=query
            )
            return response.text
        except Exception as err:
            tqdm.write(f"Caught exception: {err}")
            if attempt == attempts - 1:
                raise
            # Exponential back-off so rate limits get a chance to reset.
            time.sleep(min(base_delay * 2 ** attempt, 60.0))

In [214]:
# OpenAI:
def call_api_openai(shots, example, max_retries=8, base_delay=1.0):
    """Prompt the OpenAI model with `shots` demonstrations and return its reply text.

    Args:
        shots: number of training demonstrations to include in the chat history.
        example: dataset row to label (passed to get_message).
        max_retries: number of API attempts before giving up (at least one attempt is made).
        base_delay: initial back-off in seconds; doubled after each failure, capped at 60s.

    Returns:
        The content of the first choice's message.

    Raises:
        The last exception raised by the API call once every attempt has failed. Errors raised while
        building the prompt (e.g. ValueError for too many shots) propagate immediately instead of
        being retried forever.
    """
    # Sample the demonstrations once, outside the retry loop: re-sampling on every retry would
    # consume extra draws from `random` and make the prompt depend on how many calls failed.
    history = get_chat_history(shots, data_splits['train'], orig_labels, convert_bio_to_prompt)
    # get_chat_history keys turns by the global USER_STR/SYSTEM_STR/MSG_STR, which the Cohere setup
    # cell binds to "USER"/"SYSTEM"/"message"; normalize to OpenAI's lower-case roles and 'content'.
    messages = [
        {'role': turn['role'].lower(), 'content': turn.get('content', turn.get('message'))}
        for turn in history
    ]
    messages.append({'role': 'user', 'content': get_message(example)})

    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                temperature=0.0,
                seed=random_seed,  # same seed as the interactive OpenAI cells
                messages=messages
            )
            return response.choices[0].message.content
        except Exception as err:
            tqdm.write(f"Caught exception: {err}")
            if attempt == attempts - 1:
                raise
            # Exponential back-off so rate limits get a chance to reset.
            time.sleep(min(base_delay * 2 ** attempt, 60.0))

In [215]:
# Now we want to be able to evaluate the model, in order to compare it to e.g. the fine-tuned BERT model.
# In order to do this, we need to write the reverse of the convert_bio_to_prompt function, so that we can
# convert in the other direction, from the generated response in prompt format, back to bio for evaluation
# using seqeval.

# The input to this function is the string response from the model, and the output should be a list of 
# text BIO labels corresponding to the labeling implied by the tagged output produced by the model, as 
# well as the list of tokens (since the generative model could return something different than we gave it,
# and we need to handle that somehow in the eval).

def convert_response_to_bio(response):
    """Parse a tagged model response back into BIO labels and whitespace-delimited words.

    The response is expected in the prompt format, e.g.
    ``Labels: From <Goddess> her</Goddess> ideological conception, ...``. An optional leading
    "Labels:" or "Text:" prefix (after any leading whitespace) is dropped. Tags may or may not be
    separated from words by spaces. A fragment made only of punctuation (typically what follows
    a closing tag, as in ``</Goddess>,``) is glued onto the previous word so the words line up
    with the original whitespace tokenization.

    Args:
        response: raw text returned by the model.

    Returns:
        (labels, text): parallel lists of BIO tag strings and words; both are empty when the
        response contains no words.
    """
    open_tag = re.compile(r'<[A-Za-z_]+>')
    close_tag = re.compile(r'</[A-Za-z_]+>')

    def is_open(tok):
        return open_tag.fullmatch(tok) is not None

    def is_close(tok):
        return close_tag.fullmatch(tok) is not None

    def is_punct(tok):
        # Every character is punctuation (a substring test against string.punctuation would
        # accept e.g. '()' but miss '"),').
        return bool(tok) and all(ch in string.punctuation for ch in tok)

    def continued(label):
        # 'B-Goddess' / 'I-Goddess' -> 'I-Goddess'
        return 'I-' + label.split('-', 1)[1]

    response = response.lstrip()
    for prefix in ('Labels:', 'Text:'):
        if response.startswith(prefix):
            response = response[len(prefix):]
            break

    # Split into words and tags, dropping whitespace-only and empty pieces.
    tokens = [tok for tok in re.split(r'(\s+|<[A-Za-z_]+>|</[A-Za-z_]+>)', response) if tok.strip()]

    text = []
    labels = []
    current_label = None  # B- label waiting for the first word after an opening tag
    for i, token in enumerate(tokens):
        prev_tok = tokens[i - 1] if i > 0 else ''
        next_tok = tokens[i + 1] if i + 1 < len(tokens) else ''

        if text and is_punct(token):
            text[-1] += token
            continue
        if is_open(token):
            current_label = 'B-' + token[1:-1]
            continue
        if is_close(token):
            current_label = None
            continue

        if i > 0 and (is_open(prev_tok) or is_close(prev_tok) or is_close(next_tok)):
            # Word directly after a tag or directly before a closing tag.
            if current_label:
                label = current_label
                current_label = None
            elif labels and labels[-1] != 'O' and not is_close(prev_tok):
                label = continued(labels[-1])
            else:
                label = 'O'
        elif labels and labels[-1] != 'O' and not is_close(prev_tok) and not is_punct(prev_tok):
            # Interior word of a span that is still open.
            label = continued(labels[-1])
        else:
            label = 'O'
        labels.append(label)
        text.append(token)

    return labels, text

In [None]:
# Regression tests for convert_response_to_bio, run in-notebook via ipytest. The input below was
# constructed to simulate various spacing/tokenization scenarios around the tags and does not
# necessarily reflect "correct" labeling wrt the training data.
import ipytest
ipytest.autoconfig()

_HTML_BODY = 'From <Goddess> her</Goddess> ideological conception, <Goddess> the deity Taweret </Goddess> was closely grouped with (and is often indistinguishable from) several other protective <Aquatic_mammal>hippopotamus</Aquatic_mammal> <Goddess>goddesses </Goddess>: <Goddess> Ipet ("the Nurse")</Goddess>, <Goddess>Reret ("the Sow") </Goddess>, and <Goddess>Hedjet ("the White One")</Goddess>.'
_EXPECTED_LABELS = ['O', 'B-Goddess', 'O', 'O', 'B-Goddess', 'I-Goddess', 'I-Goddess', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Aquatic_mammal', 'B-Goddess', 'B-Goddess', 'I-Goddess', 'I-Goddess', 'B-Goddess', 'I-Goddess', 'I-Goddess', 'O', 'B-Goddess', 'I-Goddess', 'I-Goddess', 'I-Goddess']
_EXPECTED_TEXT = ['From', 'her', 'ideological', 'conception,', 'the', 'deity', 'Taweret', 'was', 'closely', 'grouped', 'with', '(and', 'is', 'often', 'indistinguishable', 'from)', 'several', 'other', 'protective', 'hippopotamus', 'goddesses:', 'Ipet', '("the', 'Nurse"),', 'Reret', '("the', 'Sow"),', 'and', 'Hedjet', '("the', 'White', 'One").']


def _check_conversion(html_str):
    """Convert `html_str`, echo the result, and compare it with the expected labels and words."""
    labels, text = convert_response_to_bio(html_str)
    print(labels)
    print(text)
    assert labels == _EXPECTED_LABELS
    assert text == _EXPECTED_TEXT


def test_convert_html_to_bio():
    _check_conversion(_HTML_BODY)


def test_convert_html_to_bio_labels():
    # Same input, prefixed with the 'Labels: ' marker the model is prompted to emit.
    _check_conversion('Labels: ' + _HTML_BODY)


ipytest.run('-vv')  # '-vv' for increased verbosity

In [217]:
# Now we can put all of the above together to evaluate!
metric = evaluate.load("seqeval")

def run_eval(dataset, shots, backend):
    """Prompt the LLM on every example of `dataset` and score its tags with seqeval.

    Args:
        dataset: labeled split; each example has 'tokens' and integer 'ner_tags'
            (decoded through labels_int2str).
        shots: number of in-context demonstrations per prompt.
        backend: "openai" or "cohere" (case-insensitive).

    Returns:
        The result of ``metric.compute`` over all examples.

    Raises:
        ValueError: if `backend` names neither supported backend (any other value used to fall
            through to Cohere silently).
    """
    backend_name = backend.lower()
    if backend_name == "openai":
        call_api = call_api_openai
    elif backend_name == "cohere":
        call_api = call_api_cohere
    else:
        raise ValueError(f"Unknown backend {backend!r}; expected 'openai' or 'cohere'")

    # Collect locally and score once at the end, so an interrupted run (API failure,
    # KeyboardInterrupt) leaves nothing buffered in the shared `metric` for the next run.
    all_predictions = []
    all_references = []
    for example in tqdm(dataset, total=len(dataset), desc="Evaluating", position=tqdm._get_free_pos()):
        # String list of gold labels (BIO)
        true_labels = [labels_int2str[tag_id] for tag_id in example['ner_tags']]
        example_tokens = example['tokens']

        response_text = call_api(shots, example)

        # String list of predicted labels (BIO), plus the words the model actually produced.
        predictions, generated_tokens = convert_response_to_bio(response_text)

        # Handle the case where the generated text doesn't align with the input text: evaluate
        # everything up to the first token where the two diverge (ignoring punctuation, since a
        # lost paren is not catastrophic) and predict 'O' for everything after that.
        matching_elements = [strip_punct(i) == strip_punct(j) for i, j in zip(example_tokens, generated_tokens)]
        if False in matching_elements:
            last_matching_idx = matching_elements.index(False)
        else:
            last_matching_idx = min(len(generated_tokens), len(example_tokens))
        predictions = predictions[:last_matching_idx] + ['O'] * (len(example_tokens) - last_matching_idx)

        all_predictions.append(predictions)
        all_references.append(true_labels)

    return metric.compute(predictions=all_predictions, references=all_references)

In [None]:
# Run the eval on the dev set. Set dev_examples_to_take to a positive N to evaluate only the
# first N dev examples (cheaper while iterating); 0 means the whole split.
dev_examples_to_take = 0

if dev_examples_to_take > 0:
    dev_set = data_splits['dev'].select(range(dev_examples_to_take))
else:
    dev_set = data_splits['dev']

for num_shots in (0, 1, 5, 10, 20, 40):
    result = run_eval(dev_set, shots=num_shots, backend='openai')
    print(result)

## Output for Evaluation

In the following cells, run your LLM tagger on the test data and produce a list of lists of BIO tags, with one list per sentence, e.g. 

```
[
    [
        "B-Aquatic_animal",
        "I-Aquatic_animal",
        "I-Aquatic_animal",
...
        "O",
        "O",
        "B-Aquatic_animal",
        "I-Aquatic_animal"
    ],
    [...]
]
```

Serialize your predictions into a file named `test_predictions_llm_baseline.json` for your initial attempt at an LLM tagger.

In [None]:
# Label the test split with the LLM and write one BIO tag list per sentence.
TEST_SHOTS = 5  # in-context demonstrations per test prompt

test_prediction = []
for example in tqdm(data_splits['test'], total=len(data_splits['test']), desc="Testing", position=tqdm._get_free_pos()):
    response_text = call_api_openai(TEST_SHOTS, example)
    predictions, generated_tokens = convert_response_to_bio(response_text)

    # Align with the input tokens the same way run_eval does: keep predictions only up to the
    # first token where the generated text diverges (ignoring punctuation), then pad with 'O'.
    # Plain truncate/pad would shift every later label after a dropped or inserted word.
    example_tokens = example['tokens']
    matching = [strip_punct(a) == strip_punct(b) for a, b in zip(example_tokens, generated_tokens)]
    n_aligned = matching.index(False) if False in matching else min(len(generated_tokens), len(example_tokens))
    test_prediction.append(predictions[:n_aligned] + ['O'] * (len(example_tokens) - n_aligned))

# One tag list per test sentence, each exactly as long as its token list.
assert len(test_prediction) == len(data_splits['test'])
assert all(len(tags) == len(ex['tokens']) for tags, ex in zip(test_prediction, data_splits['test']))

# Save the predictions to a file
with open("test_predictions_llm_baseline.json", "w") as f:
    json.dump(test_prediction, f)