In [1]:
import pickle 
from transformers import pipeline
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.model_selection import train_test_split
from datasets import Dataset
import accelerate
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pre-trained NER pipeline
ner_pipeline = pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english", aggregation_strategy="simple")

Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [3]:
def extract_ner(sentence):
    # Perform Named Entity Recognition
    ner_results = ner_pipeline(sentence)

    # Extract and print the entity names
    entities = [entity['word'] for entity in ner_results if entity['entity_group'] in ['PER', 'LOC', 'ORG', 'MISC']]
    return entities

extract_ner("Steven Spielberg is a famous director who was born in Cincinnati, Ohio.")

['Steven Spielberg', 'Cincinnati', 'Ohio']

In [4]:
triplets, uri_to_label, label_to_uri, label_list = pickle.load(open("formatted_data.pkl", "rb"))

In [5]:
# read te id2label json file
with open('id2tag.json') as json_file:
    id2label = json.load(json_file)

# read the pytorch model "movie_tag_model" from the file
model = BertForSequenceClassification.from_pretrained("movie_tag_model")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def classify_query(query, model, tokenizer, id2label):
    encoded_dataset = tokenizer(query, padding=True, truncation=True, return_tensors="pt")

    # Perform the classification
    model.eval()
    with torch.no_grad():
        output = model(**encoded_dataset)
        predicted_label = torch.argmax(output.logits, dim=1).item()

    # Map the predicted label to the corresponding tag
    tag = id2label[str(predicted_label)]

    return tag

In [6]:
classify_query("when was abc released", model, tokenizer, id2label)

'release_date'

In [7]:
wikidata_attributes = {
    "revenue": "box office",
    "director": "director",
    "actor": "cast member",
    "release_date": "publication date",
    "genre": "genre",
    "rating": "IMDb ID",
    #"budget": "budget",
    "producer": "producer",
    "screenwriter": "screenwriter",
    #"birth_date": "date of birth",
    "birth_place": "place of birth"
}

In [8]:
tag2uri = {
    tag:label_to_uri[label] for tag, label in wikidata_attributes.items()
}

In [9]:
tag2uri

{'revenue': 'http://www.wikidata.org/prop/direct/P2142',
 'director': 'http://www.wikidata.org/prop/direct/P57',
 'actor': 'http://www.wikidata.org/prop/direct/P161',
 'release_date': 'http://www.wikidata.org/prop/direct/P577',
 'genre': 'http://www.wikidata.org/prop/direct/P136',
 'rating': 'http://www.wikidata.org/prop/direct/P345',
 'producer': 'http://www.wikidata.org/entity/Q47541952',
 'screenwriter': 'http://www.wikidata.org/prop/direct/P58',
 'birth_place': 'http://www.wikidata.org/prop/direct/P19'}

In [50]:
def answer_query(query):
    # Extract named entities from the query
    entities = extract_ner(query)

    # Classify the query
    tag = classify_query(query, model, tokenizer, id2label)

    # Retrieve the corresponding Wikidata attribute
    uri = tag2uri[tag]

    if(len(entities) == 0):
        return "I'm sorry, I couldn't understand the query."
    else:
        en_uri = label_to_uri[entities[0]]
        # print(en_uri, uri)
        # print(triplets[(en_uri, uri)])
        try:
            if(tag in ["rating", "revenue", "budget", "release_date"] ):
                print(triplets[(en_uri, uri)])
            else:
                ans = triplets[(en_uri, uri)]
                ans_labels = [uri_to_label[label] for label in ans]
                print(", ".join(ans_labels))
        except:
            return "I'm sorry, I couldn't find the answer to your question."
        
# answer_query("Who starred in Inception?")
# answer_query("What is the genre of Inception?")
# answer_query("What is the genre of Kung Fu Panda?")
# answer_query("Who directed Inception?")
# answer_query("When was Inception released?")
# answer_query("When was Kung Fu Panda released?")
answer_query("When was Iron Man released?")

['2008-04-30']
