In [17]:
import warnings
import regex as re
from pathlib import Path
import spacy
from spacy import displacy
import pandas as pd
import numpy as np
import coreferee
from sentence_transformers import SentenceTransformer, util
from nltk.translate.bleu_score import SmoothingFunction


def resolve_coreference(text):
    doc = nlp(text)
    doc_list = list(doc)
    # doc._.coref_chains.print()
    resolving_indecies = []
    for _,item in enumerate(doc._.coref_chains):
        resolving_indecies.extend(item)
        
    for word in resolving_indecies:
        new_word = ""
        for index in word:
            if doc[index]._.coref_chains.resolve(doc[index]) is not None:
                temp = []
                for item in doc._.coref_chains.resolve(doc[index]):
                    temp.append(str(item))
                new_word = ", ".join(temp)
            
                doc_list[index] = new_word

    final_doc = []
    for item in doc_list:
        final_doc.append(str(item))
    return " ".join(final_doc)

def extract_subjects(sentence):
    subjects = {}
    verbIdx = 0
    for token in sentence:
        if token.pos_ == "VERB" or token.pos_ == "AUX":
            verbIdx += 1
            subjectFlag = False
            verb = token
            for child in token.children:
                if child.dep_ in ("nsubj", "csubj"):
                    subtree_tokens = [str(t) for t in child.subtree]
                    subjects[token] = (" ".join(subtree_tokens), verbIdx)
                    subjectFlag = True
                elif child.dep_ == "nsubjpass":
                    for child in verb.children:
                        if child.dep_ == "agent" and len(list(child.children)) > 0:
                            subject = [str(t) for t in list(child.children)[0].subtree]
                            subject = " ".join(subject)
                            break
                        else:
                            subject = "Unknown"
                    subjects[verb] = (subject, verbIdx)
                    subjectFlag = True
            if not subjectFlag:  # didn't find a normal subject
                if token.dep_ == "relcl":
                    subject = str(token.head)
                    subjects[token] = (subject, verbIdx)  # should get the subtree of the subject
                elif token.dep_ in ("advcl", "conj"):
                    verb = token.head
                    
                    if verb in subjects:
                        subjects[token] = (subjects[verb][0], verbIdx)
                    else:
                        subjects[token] = ("Unknown", verbIdx)  # replace "Unknown" with a suitable default
                elif token.dep_ == "xcomp":
                    verb = token.head
                    if verb in subjects:
                        subjects[token] = (subjects[verb][0], verbIdx)
                    else:
                        subjects[token] = ("Unknown", verbIdx)
                    for child in verb.subtree:
                        if child.dep_ in ("dobj", "dative", "pobj"):
                            subtree_tokens = [str(t) for t in child.subtree]
                            subjects[token] = (" ".join(subtree_tokens), verbIdx)
                            break
                else:
                    subjects[token] = ("Unknown", verbIdx)
                                        
    # (subject, verbIdx, verb)
    return [(v[0], k, v[1]) for k, v in subjects.items()]             
         
                            
def extract_objects(sentence):
    objects = []
    verbIdx = 0
    for token in sentence:
        if token.pos_ == "VERB" or token.pos_ == "AUX":
            verbIdx += 1
            for child in token.children:
                if child.dep_ in ("dobj", "dative", "attr", "oprd", "acomp","ccomp", "xcomp", "nsubjpass"):
                    subtree_tokens = [str(t) for t in child.subtree]
                    objects.append((" ".join(subtree_tokens), token, verbIdx))
                    
    return objects

def extract_state(sentence):
    states = []
    verbIdx = 0
    for token in sentence:
        if token.pos_ =="VERB" or token.pos_ == "AUX":
            verbIdx += 1
            for child in token.children:
                if child.dep_ == "prep":
                    subtree_tokens = [str(t) for t in child.subtree if t != child]
                    states.append(((" ".join(subtree_tokens), token, verbIdx)))
    return states

def extract_time(sentence):
    times = {}
    verbIdx = 0
    for token in sentence:
        if token.pos_ == "VERB" or token.pos_ == "AUX":
            verbIdx += 1
            for child in token.subtree:
                if child.ent_type_ == "DATE" or child.ent_type_ == "TIME":
                    times[child.text] = (token, verbIdx)
                    
    return [(k, v[0], v[1]) for k, v in times.items()]

def extract_location(sentence):
    locations = {}
    verbIdx = 0
    for token in sentence:
        if token.pos_ == "VERB" or token.pos_ == "AUX":
            verbIdx += 1
            for child in token.subtree:
                if child.ent_type_ in ("GPE", "LOC", "FAC"):
                    locations[child.text] = (token, verbIdx)
                    
    return [(k, v[0], v[1]) for k, v in locations.items()]
                    

def extract_facts(sentence):
    sentence = nlp(sentence)
    states = extract_state(sentence)
    subjects = extract_subjects(sentence)
    objects = extract_objects(sentence)
    times = extract_time(sentence)
    locations = extract_location(sentence)
    
    facts = pd.DataFrame(columns=["Subject", "Relation", "verbIdx", "Objects", "States", "Times", "Locations"])
    
    for subject in subjects: #(Aly, is, 1), (Ziad,is, 2) 
        currentSubject = subject[0]
        verb = subject[1].lemma_
        verbIdx = subject[2]
        mask = (facts['Subject'] != currentSubject) | (facts['Relation'] != verb)
        if mask.all():
            new_row = pd.DataFrame([{"Subject": currentSubject, "Relation": verb, "verbIdx": verbIdx, "Objects": [], "States": [], "Times": [], "Locations": []}])
            facts = pd.concat([facts, new_row], ignore_index=True)

    for obj in objects: #(happy, is, 1), (good, is, 2)
        currentObj = obj[0]
        verb = obj[1].lemma_
        verbIdx = obj[2]
        mask = (facts['Relation'] == verb) & (facts['verbIdx'] == verbIdx)
        if mask.any():
            oldObjects = list(facts.loc[mask, "Objects"].values[0])
            oldObjects.append(currentObj)
            for idx in facts.loc[mask].index:
                facts.at[idx, "Objects"] = oldObjects
            
    for state in states:
        currentState = state[0]
        verb = state[1].lemma_
        verbIdx = state[2]
        mask = (facts['Relation'] == verb) & (facts['verbIdx'] == verbIdx)
        if mask.any():
            oldStates = list(facts.loc[mask, "States"].values[0])
            oldStates.append(currentState)
            for idx in facts.loc[mask].index:
                facts.at[idx, "States"] = oldStates
            
    for time in times:
        currentTime = time[0]
        verb = time[1].lemma_
        verbIdx = time[2]
        mask = (facts['Relation'] == verb) & (facts['verbIdx'] == verbIdx)
        if mask.any():
            oldTimes = list(facts.loc[mask, "Times"].values[0])
            oldTimes.append(currentTime)
            for idx in facts.loc[mask].index:
                facts.at[idx, "Times"] = oldTimes
            
    for location in locations:
        currentLocation = location[0]
        verb = location[1].lemma_
        verbIdx = location[2]
        mask = (facts['Relation'] == verb) & (facts['verbIdx'] == verbIdx)
        if mask.any():
            oldLocations = list(facts.loc[mask, "Locations"].values[0])
            oldLocations.append(currentLocation)
            for idx in facts.loc[mask].index:
                facts.at[idx, "Locations"] = oldLocations
            
    facts = facts.drop(columns=["verbIdx"])
    return facts
        
def preprocess_context(doc):
    text = doc.strip()
    text.replace(".", ",")
    resolved_text = resolve_coreference(text)
    resolved_text = resolved_text.strip()
    resolved_text = resolved_text.replace("  ", " ").replace(" ,", ",").replace(" .", ".").replace("\n", "")
    return resolved_text

def join_sentences_facts(sentences):
    all_facts = pd.DataFrame(columns=["Subject", "Relation", "Objects", "States", "Times", "Locations"])
    for sentence in sentences:
        facts = extract_facts(sentence)
        all_facts = pd.concat([all_facts, facts])
    all_facts = all_facts.groupby(["Subject", "Relation"], as_index=False).agg({
        "Objects": lambda x: [item for sublist in x for item in sublist],
        "States": lambda x: [item for sublist in x for item in sublist],
        "Times": lambda x: [item for sublist in x for item in sublist],
        "Locations": lambda x: [item for sublist in x for item in sublist]
    })
    return all_facts

def change_subject_relation(factsDF):
    for index, row in factsDF.iterrows():
        factsDF.loc[index, "Subject"] = [row['Subject']]
        factsDF.loc[index, "Relation"] = [row['Relation']]
    return factsDF

def similarity(factRow, questionRow, column):
    if len(factRow[column]) == 0 or len(questionRow[column]) == 0 or factRow[column] == ["Unknown"] or questionRow[column] == ["Unknown"]:
        return 0
    columnString = " ".join(factRow[column])
    questionString = " ".join(questionRow[column])
    embeddingFact = model.encode(columnString)
    embeddingQuestion = model.encode(questionString)
    return util.cos_sim(embeddingFact, embeddingQuestion)

        
def cost_function(factsDf, questionFact, excludeColumns=[]):
    cost = 0
    maxFactIdx = 0
    columnNames = ["Subject","Relation", "Objects", "States", "Times", "Locations"]
    for column in excludeColumns:
        columnNames.remove(column)
    for factIdx, factRow in factsDf.iterrows():
        currCost = 0
        for _, questionRow in questionFact.iterrows():
            if len(factRow[excludeColumns[0]]) == 0:
                continue
            for column in columnNames:
                currCost += similarity(factRow, questionRow, column)
        if currCost > cost:
            cost = currCost
            maxFactIdx = factIdx
    return maxFactIdx, cost

def process_question_context(question, doc):
    splitted_question = question.split(" ")
    question_type = splitted_question[0].lower()
    question_nlp = nlp(question)
    if question_nlp[0].ent_type_ == "DATE":
        question_type = "when"
    resolved_doc = preprocess_context(doc)
    cleaned_doc = nlp(resolved_doc)
    sentences = [one_sentence.text.strip() for one_sentence in cleaned_doc.sents]
    
    questionDF = extract_facts(question)
    factsDF = join_sentences_facts(sentences)
    
    newFactsDF = change_subject_relation(factsDF)
    newQuestionDF = change_subject_relation(questionDF)
    
    return newFactsDF, newQuestionDF, question_type

def get_answer(factsDF, questionDF, question_type):
    correctIdx, _ = cost_function(factsDF, questionDF, excludeColumns=[excludesPerQuestionType[question_type]])
    answer = factsDF.loc[correctIdx, excludesPerQuestionType[question_type]]
    if answer == []:
        answer = factsDF.loc[correctIdx, "States"]    
    return " ".join(answer)
    

# if __name__ == "__main__":
nlp = spacy.load('en_core_web_md')
nlp.add_pipe("merge_entities")
nlp.add_pipe("merge_noun_chunks")
nlp.add_pipe('coreferee')
model = SentenceTransformer("all-MiniLM-L6-v2")

excludesPerQuestionType = {
    "when": "Times",
    "where": "Locations",
    "who": "Subject",
    "what": "Objects",
    "how": "States"
}   
    
# doc = """
# Lionel Andrés "Leo" Messi was born in 24 June 1987 is an Argentine professional footballer plays as a forward for and captains both Major League Soccer club Inter Miami and the Argentina national team.
# He played in Barcelona in 2010.
# Widely regarded as one of the greatest players of all time, Messi has won a record eight Ballon d'Or awards, a record six European Golden Shoes, and was named the world's best player for a record eight times by FIFA.
# Until 2021, he had spent his entire professional career with Barcelona, where he won a club-record 34 trophies, including ten La Liga titles, seven Copa del Rey titles, and the UEFA Champions League four times.
# With his country, he won the 2021 Copa América and the 2022 FIFA World Cup. A prolific goalscorer and creative playmaker, Messi holds the records for most goals, hat-tricks, and assists in La Liga. He has the most international goals by a South American male. Messi has scored over 800 senior career goals for club and country, and the most goals for a single club.
# """
# question = "how did messi play?"
# factsDF, questionDF, question_type = process_question_context(question, doc)
# answer = get_answer(factsDF, questionDF, question_type)

# print("========================================================")
# print("Question: ", question)
# print("Answer: ", answer)




In [18]:
import random
from datasets import load_dataset

dataset = load_dataset("rajpurkar/squad")
# train = dataset['train']
validation = dataset['validation']

random.seed(42)
import getopt
import sys

import pandas as pd
import json 
import regex as re
from tqdm import tqdm
import nltk

In [23]:

def QuestionStartsWith_Accuracy(dataset, startsWith):

    correct = 0
    EM = 0
    BLEU = 0
    total = 0
    errors = []
    corrects = []
    empties = []
    kolo = 0
    for item in tqdm(dataset):
        random_number = random.randint(0, len(dataset))
        kolo += 1
        included_question = False
        try:
            context = item['context']
            context = re.sub(' +', ' ', context)
            
            question = item['question']
            tempQuestion = question.lower()
            question = re.sub(' +', ' ', question)
            
            answer = item['answers']['text'][0]
            title = item['title']
    
            # check if question starts with startsWith
            for start in startsWith:
                if tempQuestion.startswith(start):
                    included_question = True
                    break
            
            if included_question:
    
                total += 1
                if total == 74:
                    pass
                    
                factsDF, questionDF, question_type = process_question_context(question, context)
                outputAnswer = get_answer(factsDF, questionDF, question_type)
                if outputAnswer == "":
                    empties.append(kolo)
                    outputAnswer = "No_Answer_Found"
                
                # print("Question: " , question)
                # print("Answer: ", answer, "-------" , "Our Answer: ", outputAnswer)
                if outputAnswer in answer or answer in outputAnswer:
                    correct += 1
                    corrects.append(kolo)
                # else:
                #     print("Question: " , question)
                #     print("Answer: ", answer, "-------" , "Our Answer: ", outputAnswer)
                if outputAnswer == answer:
                    EM += 1

                n = min(len(outputAnswer.split()), 4)
                if n == 0:
                    BLEUscore = 0
                else:
                    weights = [1.0/n]*n
                    smoothie = SmoothingFunction().method4
                    BLEUscore = nltk.translate.bleu_score.sentence_bleu([answer], outputAnswer, weights=weights, smoothing_function=smoothie)
                    BLEU += BLEUscore
                    

            if kolo % 1000 == 0:
                print(f"Correct: {correct}, out of {total}: {100*correct/total}%")
                print(f"EM: {EM}, out of {total}: {100*EM/total}%")
                print(f"BLEU: {BLEU}, out of {total}: {100*BLEU/total}%")
        
        except Exception as e:
            print("title: ", title)
            print("Error in question number: ", total)
            print("Question: ", question)
            print("Answer: ", answer)
            print("Error: ", e)
            errors.append(total)
            # total -= 1
            print("\n\n")
    
    if total != 0:
        print(f"Correct: {correct}, out of {total}: {100*correct/total}%")
        print(f"EM: {EM}, out of {total}: {100*EM/total}%")
        print(f"BLEU: {BLEU}, out of {total}: {100*BLEU/total}%")
    else:
        print("No Questions found with the given starting word")
    print("Errors: ", errors)
    print("Empties: ", len(empties), empties)
    return correct, total, errors, corrects


if __name__ == "__main__":

    startsWith = ["who "]
    x = QuestionStartsWith_Accuracy(validation, startsWith)

  referred_head_lexeme.similarity(referring_head_lexeme)
  9%|▉         | 998/10570 [03:01<06:54, 23.08it/s]

Correct: 48, out of 203: 23.645320197044335%
EM: 25, out of 203: 12.31527093596059%
BLEU: 50.59556901173984, out of 203: 24.923925621546722%


  referred_head_lexeme.similarity(referring_head_lexeme)
  referred_head_lexeme.similarity(referring_head_lexeme)
  referred_head_lexeme.similarity(referring_head_lexeme)
 19%|█▉        | 1990/10570 [04:40<18:47,  7.61it/s]  

Correct: 79, out of 319: 24.764890282131663%
EM: 40, out of 319: 12.539184952978056%
BLEU: 78.71942414422021, out of 319: 24.67693546840759%


  referred_head_lexeme.similarity(referring_head_lexeme)
 28%|██▊       | 2992/10570 [06:00<04:02, 31.21it/s]

Correct: 108, out of 396: 27.272727272727273%
EM: 56, out of 396: 14.141414141414142%
BLEU: 104.7792733956671, out of 396: 26.459412473653305%


 38%|███▊      | 3999/10570 [07:08<13:55,  7.87it/s]  

Correct: 131, out of 470: 27.872340425531913%
EM: 62, out of 470: 13.191489361702128%
BLEU: 119.30639816845809, out of 470: 25.384340035842147%


  referred_head_lexeme.similarity(referring_head_lexeme)
  referred_head_lexeme.similarity(referring_head_lexeme)
 47%|████▋     | 4980/10570 [08:09<08:42, 10.70it/s]

Correct: 153, out of 519: 29.479768786127167%
EM: 68, out of 519: 13.102119460500964%
BLEU: 135.03689774926175, out of 519: 26.01867008656296%


 57%|█████▋    | 6000/10570 [10:04<08:40,  8.79it/s]  

Correct: 171, out of 619: 27.62520193861066%
EM: 79, out of 619: 12.762520193861066%
BLEU: 154.8708122870352, out of 619: 25.01951733231586%


  referred_head_lexeme.similarity(referring_head_lexeme)
  referred_head_lexeme.similarity(referring_head_lexeme)
  referred_head_lexeme.similarity(referring_head_lexeme)
  referred_head_lexeme.similarity(referring_head_lexeme)
  referred_head_lexeme.similarity(referring_head_lexeme)
  referred_head_lexeme.similarity(referring_head_lexeme)
 66%|██████▌   | 6997/10570 [12:08<29:56,  1.99it/s]

Correct: 194, out of 712: 27.247191011235955%
EM: 96, out of 712: 13.48314606741573%
BLEU: 182.02166040716153, out of 712: 25.56483994482606%


 68%|██████▊   | 7232/10570 [12:12<01:04, 51.42it/s]

title:  Harvard_University
Error in question number:  718
Question:  Who is the Costa Rican President that went to Harvard?
Answer:  José María Figueres
Error:  0





 76%|███████▌  | 7999/10570 [13:34<09:09,  4.68it/s]

Correct: 221, out of 802: 27.556109725685786%
EM: 110, out of 802: 13.71571072319202%
BLEU: 203.96608911324253, out of 802: 25.432180687436723%


  referred_head_lexeme.similarity(referring_head_lexeme)
  referred_head_lexeme.similarity(referring_head_lexeme)
  referred_head_lexeme.similarity(referring_head_lexeme)
  referred_head_lexeme.similarity(referring_head_lexeme)
 85%|████████▌ | 8993/10570 [15:06<00:17, 88.93it/s]

Correct: 243, out of 911: 26.67398463227223%
EM: 120, out of 911: 13.172338090010976%
BLEU: 225.508760416421, out of 911: 24.753980287203188%


  referred_head_lexeme.similarity(referring_head_lexeme)
  referred_head_lexeme.similarity(referring_head_lexeme)
  referred_head_lexeme.similarity(referring_head_lexeme)
 95%|█████████▍| 9999/10570 [16:47<02:18,  4.12it/s]

Correct: 271, out of 998: 27.15430861723447%
EM: 137, out of 998: 13.72745490981964%
BLEU: 252.76438396360533, out of 998: 25.32709258152358%


100%|██████████| 10570/10570 [17:56<00:00,  9.82it/s]

Correct: 294, out of 1059: 27.762039660056658%
EM: 151, out of 1059: 14.258734655335221%
BLEU: 273.20585499633404, out of 1059: 25.798475448190185%
Errors:  [718]
Empties:  0 []





In [20]:
# [174, 315, 326, 349, 472, 488, 493, 498, 530, 667]
# when: 48%