In [1]:
import pickle
import re

import torch
from thefuzz import fuzz, process
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# Load the pre-trained BERT NER model and its tokenizer from Hugging Face.
bert_base_NER = "dslim/bert-base-NER"
bert_base_NER_tokenizer = AutoTokenizer.from_pretrained(bert_base_NER)
bert_base_NER_model = AutoModelForTokenClassification.from_pretrained(bert_base_NER)

# Build the token-classification pipeline used by the extractor functions.
# Those functions read the 'entity_group' and 'word' keys of each result.
bert_base_NER_pipeline = pipeline(
    "ner",
    model=bert_base_NER_model,
    tokenizer=bert_base_NER_tokenizer,
    aggregation_strategy="simple",
    # Run on the GPU when PyTorch can see one; otherwise fall back to the CPU
    # so the notebook still runs on machines without CUDA.
    device="cuda" if torch.cuda.is_available() else "cpu",
)

  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [58]:
def get_best_match_person(user_query: str) -> list:
    '''
    Extract person names from user_query using bert_base_NER_pipeline.

    Entities whose 'entity_group' is 'PER' are collected. A piece whose text
    starts with '##' is glued onto the name currently being built (without a
    space); any other 'PER' entity starts a new name. An entity of any other
    group ends the name being built.

    Args:
        user_query: Free-text query to search for person names.

    Returns:
        A list of unique person names, in the order they first appear in the
        pipeline results. Empty when no 'PER' entity is found.
    '''
    ner_results = bert_base_NER_pipeline(user_query)
    person_names = []
    current_name = ''
    for entity in ner_results:
        if entity['entity_group'] != 'PER':
            # A non-person entity terminates the name in progress.
            if current_name:
                person_names.append(current_name.strip())
                current_name = ''
            continue

        # Drop leading/trailing punctuation and whitespace from the entity text.
        word = entity['word'].strip('.,!? ')
        if word.startswith('##'):
            # Continuation piece: remove the '##' marker and join without a space.
            current_name += word[2:]
            continue

        # A fresh person entity: store the previous name, then start a new one.
        if current_name:
            person_names.append(current_name.strip())
        current_name = word

    # Flush the name still being built when the results end.
    if current_name:
        person_names.append(current_name.strip())

    # Collapse runs of whitespace inside each name to a single space.
    person_names = [re.sub(r'\s+', ' ', name) for name in person_names]

    # De-duplicate while keeping first-occurrence order. list(set(...)) would
    # return the names in an arbitrary order that can differ between runs.
    return list(dict.fromkeys(person_names))

# Quick check of the person extractor on a sample query; the cell's last
# expression displays the returned list.
user_query = "What does Denzel Washington look like?"
get_best_match_person(user_query)

['Denzel Washington']

In [59]:
# Print the extracted person names for several differently phrased queries.
for user_query in (
    "Show me a picture of Halle Berry.",
    "What does Denzel Washington look like?",
    "Let me know what Sandra Bullock looks like.",
):
    print(get_best_match_person(user_query))


['Halle Berry']
['Denzel Washington']
['Sandra Bullock']


In [6]:
def get_best_match_MISC(user_query: str) -> list:
    '''
    Extract MISC and ORG entity names from user_query using
    bert_base_NER_pipeline.

    Entities whose 'entity_group' is 'MISC' or 'ORG' are collected. A piece
    whose text starts with '##' is glued onto the name currently being built
    (without a space); any other matching entity starts a new name. An entity
    of any other group ends the name being built.

    Args:
        user_query: Free-text query to search for MISC/ORG entities.

    Returns:
        A list of unique entity names, in the order they first appear in the
        pipeline results. Empty when no 'MISC' or 'ORG' entity is found.
    '''
    ner_results = bert_base_NER_pipeline(user_query)
    entity_names = []
    current_name = ''
    for entity in ner_results:
        if entity['entity_group'] not in ('MISC', 'ORG'):
            # An entity of another group terminates the name in progress.
            if current_name:
                entity_names.append(current_name.strip())
                current_name = ''
            continue

        # Drop leading/trailing punctuation and whitespace from the entity text.
        word = entity['word'].strip('.,!? ')
        if word.startswith('##'):
            # Continuation piece: remove the '##' marker and join without a space.
            current_name += word[2:]
            continue

        # A fresh matching entity: store the previous name, then start a new one.
        if current_name:
            entity_names.append(current_name.strip())
        current_name = word

    # Flush the name still being built when the results end.
    if current_name:
        entity_names.append(current_name.strip())

    # Collapse runs of whitespace inside each name to a single space.
    entity_names = [re.sub(r'\s+', ' ', name) for name in entity_names]

    # De-duplicate while keeping first-occurrence order. list(set(...)) would
    # return the names in an arbitrary order that can differ between runs.
    return list(dict.fromkeys(entity_names))

# Quick check of the MISC/ORG extractor on a sample query; the cell's last
# expression displays the returned list.
user_query = "Show me Interstellar"
get_best_match_MISC(user_query)

['Interstellar']