In [None]:
# coqa dataset 

import json

from ie_utils import ie_utils

# Location of the CoQA dev split, relative to the notebook's working directory.
path_dataset = "datasets/coqa/coqa-dev-v1.0.json"

# Read through a context manager so the file handle is closed as soon as the
# JSON is parsed (a bare ``open()`` leaves it open until garbage collection).
with open(path_dataset, encoding="utf-8") as dataset_file:
    dataset = json.load(dataset_file)

# for story in tqdm(dataset['data']):
# The step-by-step cells below operate on the first story only.
target_data = dataset['data'][0]

In [None]:
# Basic fields of the story under inspection.
story_id = target_data['id']
story_content = target_data['story']
questions = [q_dict["input_text"] for q_dict in target_data['questions']]
rounds = len(questions)

# Answer sequences to merge: the primary annotation first, then every
# additional annotator's sequence when the story has any. A missing (or null)
# 'additional_answers' entry is handled explicitly instead of being hidden
# behind a bare ``except: pass`` that would also swallow unrelated errors.
raw_answers = [target_data['answers']]
raw_answers += (target_data.get('additional_answers') or {}).values()

# answers maps the 1-based round number to its de-duplicated answer strings,
# e.g. 1: ["ans1", "ans2"].
answers = {i_round: [] for i_round in range(1, rounds + 1)}

for raw_answer_seq in raw_answers:
    for raw_answer in raw_answer_seq:
        # setdefault tolerates an annotation whose turn_id falls outside
        # 1..rounds instead of raising KeyError.
        round_answers = answers.setdefault(raw_answer["turn_id"], [])

        # Keep the free-form answer first, then the supporting story span;
        # both are stripped of surrounding spaces and added only once.
        for candidate in (raw_answer["input_text"], raw_answer["span_text"]):
            candidate = candidate.strip(" ")
            if candidate not in round_answers:
                round_answers.append(candidate)

In [2]:
# 1.1.a Run single-dialogue declarative sentence generation
from ie_utils import ie_vanilla

# Render the dialogue as alternating "Round N - <question>?" /
# "Round N - <answer>." lines, using the first collected answer of each round.
# NOTE(review): "?" and "." are appended unconditionally; if the dataset's
# questions already end with "?" the prompt will contain "??" - confirm this
# is intended.
round_lines = []
for round_num, sentence_q in enumerate(questions, start=1):
    sentence_a = answers[round_num][0]
    round_lines.append(
        "Round {round_num} - {sentence_q}?\nRound {round_num} - {sentence_a}.\n".format(
            round_num=round_num,
            sentence_q=sentence_q,
            sentence_a=sentence_a,
        )
    )
combined = "".join(round_lines)

# Pass the rendered dialogue to the extraction helper and keep its result on
# the story under the "combined" key.
combined_content = ie_vanilla.extract_declerative_information(combined)
target_data["combined"] = combined_content

In [None]:
# 1.1.b full question extraction

# One entry per round, keyed "Round N", pairing the question with the first
# collected answer of that round.
combined_json = {
    "Round {round_num}".format(round_num=round_num): {
        "Question": question,
        "Answer": answers[round_num][0],
    }
    for round_num, question in enumerate(questions, start=1)
}

# Result of the question-resolution helper for the whole dialogue.
combined_json_content = ie_vanilla.question_resolution(combined_json)

# Store both the raw pairs and the helper's output on the story.
target_data.update(original_qa=combined_json, full_qa=combined_json_content)

target_data

In [None]:
# 1.2 Type, entities, and relation extraction

# 1.2.1 Topic extraction
# The helper returns a mapping; only its 'topic' entry is kept.
topic_result = ie_vanilla.extract_topic(combined_json_content)
topic_content = topic_result['topic']
target_data["topic"] = topic_content

In [None]:
# 1.2.2 Entity Types
# Entity types are derived from the topic (step 1.2.1) and the combined
# dialogue text (step 1.1.a).
target_data["entity_types"] = entity_types_content = ie_vanilla.entity_types(
    input_topic=topic_content,
    input_dialogue=combined_content,
)

In [None]:
# 1.2.3 All entities, all relations for dev datasets
# Extract entities and relations from the combined dialogue text, constrained
# by the entity types from step 1.2.2.
target_data["full_entity_relation"] = all_entities_content = ie_vanilla.entity_relations(
    entity_types=entity_types_content,
    input_text=combined_content,
)

In [None]:
# 1.2.4.a Round_entity_relation list
# Build the per-round subgraphs from the full entity and relation lists
# (step 1.2.3) and the combined dialogue text.
full_entity_list = all_entities_content["entities"]
full_relation_list = all_entities_content["relations"]
round_subgraphs_content = ie_vanilla.round_subgraph(
    entity_list=full_entity_list,
    relation_list=full_relation_list,
    dialogue_content=combined_content,
)
target_data["round_subgraph"] = round_subgraphs_content

In [None]:
# 1.2.4.b Handle unseen round entities and update entity list and relation list

all_entity_types = target_data["entity_types"]
all_entities = target_data["full_entity_relation"]["entities"]
all_relations = target_data["full_entity_relation"]["relations"]


def wash_round_subgraphs(story, entity_types, entities, relations):
    """Pass every round's entities and relations through the set tools.

    Builds an ``entity_relation_set_tools`` object from the given seed lists,
    then, for each round, replaces the Question/Answer ``entities`` with the
    flattened results of ``wash_entity`` and the Question/Answer ``relations``
    with the flattened results of ``wash_relation``. The per-round order
    (question entities, answer entities, question relations, answer relations)
    is kept because the tools object may be stateful.

    Args:
        story: story dict holding "round_subgraph" (mutated in place).
        entity_types: entity types handed to the tools constructor.
        entities: seed entity list handed to the tools constructor.
        relations: seed relation list handed to the tools constructor.

    Returns:
        The tools object, so the caller can read its accumulated state.

    Raises:
        Whatever the tools, ``flatten`` or the dict lookups raise.
    """
    tools = ie_utils.ie_tools.entity_relation_set_tools(entity_types, entities, relations)

    # NOTE(review): the original iterated story["label"], but nothing in this
    # notebook ever sets a "label" key. It is still honoured when present;
    # otherwise the rounds of the subgraph itself are used.
    round_keys = list(story.get("label", story["round_subgraph"]).keys())

    for round_key in round_keys:
        round_graph = story["round_subgraph"][round_key]
        for field, wash in (("entities", tools.wash_entity), ("relations", tools.wash_relation)):
            for side in ("Question", "Answer"):
                # NOTE(review): `flatten` is not defined or imported anywhere
                # in this notebook - it must already be in scope.
                round_graph[side][field] = flatten([wash(item) for item in round_graph[side][field]])

    story["entity_group_member_mapping"] = tools.entity_group_member_mapping_cache
    return tools


erstools = None
try:
    erstools = wash_round_subgraphs(target_data, all_entity_types, all_entities, all_relations)
    target_data["DATA_PROCESS_FLAG"]=0 # normal
except Exception as primary_error:
    # Report the first failure instead of swallowing it with a bare except.
    print("Entity washing failed, rebuilding from the first entity/relation: ", primary_error)
    try:
        # Fallback: seed the tools with only the first entity and relation.
        # NOTE(review): the original fallback called a bare
        # `entity_relation_set_tools(..., lm)`; neither that name nor `lm` is
        # defined in this notebook, so the call form of the primary attempt is
        # used here. Confirm whether a language-model argument is required.
        erstools = wash_round_subgraphs(target_data, all_entity_types, all_entities[0:1], all_relations[0:1])
        target_data["DATA_PROCESS_FLAG"]=10 # rebuild subgraphs 
    except Exception as e:
        print("Error in entity relation extraction: ", e)
        target_data["DATA_PROCESS_FLAG"]=500 # failed extraction

if erstools is not None:
    # Expose the tools' lists under the notebook-level names.
    # NOTE(review): they are not written back explicitly to
    # target_data["full_entity_relation"] - confirm whether they should be.
    all_entities = erstools.all_entities
    all_relations = erstools.all_relations


In [None]:
# Keep a flag already recorded by step 1.2.4.b (0 = normal, 10 = rebuilt
# subgraphs) and only default to 0 when none exists; an unconditional
# assignment would erase the 10. To mark a failed extraction by hand, assign
# 500 explicitly instead.
target_data.setdefault("DATA_PROCESS_FLAG", 0) # or 500 for failed extraction

In [None]:
import copy
import json
from ie_utils import ie_vanilla
from ie_utils import ie_utils

# Batch run: apply steps 1.1.a - 1.2.4.b to every story of the loaded dataset
# and save the results keyed by story id.

MAX_ATTEMPTS = 3  # extraction attempts per story before it is flagged as failed


def _collect_reference_answers(story):
    """Return ``(questions, answers)`` for one story.

    ``questions`` lists the question texts in dialogue order. ``answers`` maps
    the 1-based ``turn_id`` to that turn's de-duplicated answer strings: for
    every annotation its ``input_text`` followed by its ``span_text`` (both
    stripped of surrounding spaces), in first-seen order. The primary
    ``answers`` sequence is visited before any ``additional_answers``.
    """
    questions = [q_dict["input_text"] for q_dict in story["questions"]]

    answer_sequences = [story["answers"]]
    answer_sequences += (story.get("additional_answers") or {}).values()

    answers = {round_num: [] for round_num in range(1, len(questions) + 1)}
    for answer_sequence in answer_sequences:
        for raw_answer in answer_sequence:
            round_answers = answers.setdefault(raw_answer["turn_id"], [])
            for candidate in (raw_answer["input_text"], raw_answer["span_text"]):
                candidate = candidate.strip(" ")
                if candidate not in round_answers:
                    round_answers.append(candidate)
    return questions, answers


def _wash_story_rounds(story, entities, relations):
    """Pass each round's entities/relations of ``story`` through the set tools.

    Mutates ``story["round_subgraph"]`` in place and stores the tools' mapping
    cache under ``"entity_group_member_mapping"``. Per round the order is
    question entities, answer entities, question relations, answer relations,
    because the tools object may be stateful.
    """
    tools = ie_utils.ie_tools.entity_relation_set_tools(story["entity_types"], entities, relations)

    # NOTE(review): the original iterated story["label"], a key nothing in this
    # notebook sets. It is honoured when present; otherwise the rounds of the
    # subgraph itself are used.
    round_keys = list(story.get("label", story["round_subgraph"]).keys())

    for round_key in round_keys:
        round_graph = story["round_subgraph"][round_key]
        for field, wash in (("entities", tools.wash_entity), ("relations", tools.wash_relation)):
            for side in ("Question", "Answer"):
                # NOTE(review): `flatten` is not defined or imported anywhere
                # in this notebook - it must already be in scope.
                round_graph[side][field] = flatten([wash(item) for item in round_graph[side][field]])

    story["entity_group_member_mapping"] = tools.entity_group_member_mapping_cache


def _normalise_round_subgraphs(story):
    """Step 1.2.4.b. Return 0 on a normal run, 10 when the fallback was needed.

    The fallback seeds the tools with only the first entity and relation. If
    the fallback fails too, its exception propagates to the caller.
    """
    entities = story["full_entity_relation"]["entities"]
    relations = story["full_entity_relation"]["relations"]
    try:
        _wash_story_rounds(story, entities, relations)
        return 0  # normal
    except Exception as primary_error:
        print(f"Entity washing failed for story_id {story['id']}, rebuilding subgraphs: {primary_error}")

    # NOTE(review): the original fallback called a bare
    # `entity_relation_set_tools(..., lm)`; neither name is defined in this
    # notebook, so the primary call form is reused. Confirm whether a
    # language-model argument is required.
    _wash_story_rounds(story, entities[0:1], relations[0:1])
    return 10  # rebuild subgraphs


def _extract_story(story):
    """Run steps 1.1.a - 1.2.4.b on ``story`` in place; return its process flag.

    Intermediate results are passed along as local variables, so no step
    depends on notebook-level state left behind by another story.
    """
    questions, answers = _collect_reference_answers(story)
    # First collected answer of every round, paired with its question.
    round_qa = [(round_num, question, answers[round_num][0])
                for round_num, question in enumerate(questions, start=1)]

    # 1.1.a declarative sentence generation
    combined = "".join(
        "Round {round_num} - {sentence_q}?\nRound {round_num} - {sentence_a}.\n".format(
            round_num=round_num, sentence_q=question, sentence_a=answer)
        for round_num, question, answer in round_qa
    )
    combined_content = ie_vanilla.extract_declerative_information(combined)
    story["combined"] = combined_content

    # 1.1.b full question extraction
    combined_json = {
        "Round {round_num}".format(round_num=round_num): {"Question": question, "Answer": answer}
        for round_num, question, answer in round_qa
    }
    combined_json_content = ie_vanilla.question_resolution(combined_json)
    story["original_qa"] = combined_json
    story["full_qa"] = combined_json_content

    # 1.2.1 topic
    topic_content = ie_vanilla.extract_topic(combined_json_content)['topic']
    story["topic"] = topic_content

    # 1.2.2 entity types
    entity_types_content = ie_vanilla.entity_types(input_topic=topic_content, input_dialogue=combined_content)
    story["entity_types"] = entity_types_content

    # 1.2.3 all entities and relations
    all_entities_content = ie_vanilla.entity_relations(entity_types=entity_types_content, input_text=combined_content)
    story["full_entity_relation"] = all_entities_content

    # 1.2.4.a per-round subgraphs
    story["round_subgraph"] = ie_vanilla.round_subgraph(
        entity_list=all_entities_content["entities"],
        relation_list=all_entities_content["relations"],
        dialogue_content=combined_content,
    )

    # 1.2.4.b wash round entities/relations
    return _normalise_round_subgraphs(story)


def process_story(raw_story):
    """Return a processed copy of ``raw_story`` with ``DATA_PROCESS_FLAG`` set.

    The pipeline is attempted up to ``MAX_ATTEMPTS`` times, each time on a
    fresh deep copy so a failed attempt leaves no partial state behind for the
    next one. Flag values: 0 normal, 10 subgraphs rebuilt by the fallback,
    500 every attempt failed (the last partial result is returned).
    """
    last_error = None
    for _attempt in range(MAX_ATTEMPTS):
        story = copy.deepcopy(raw_story)
        try:
            story["DATA_PROCESS_FLAG"] = _extract_story(story)
            return story
        except Exception as error:
            # `error` is unbound when the except block ends; keep a reference.
            last_error = error

    print(f"Failed to process data for story_id {raw_story['id']} after {MAX_ATTEMPTS} retries. Error: {last_error}")
    story["DATA_PROCESS_FLAG"] = 500
    return story


# NOTE(review): the original seeded this container with a deep copy of
# `target_data` (a single story) and iterated its keys, which yields field
# names rather than stories and mutates the dict during iteration. Iterating
# dataset["data"] matches the per-story-id output written below - confirm this
# is the intended input.
result_container = {}
for raw_story in dataset["data"]:
    result_container[raw_story["id"]] = process_story(raw_story)

# Save the processed data
output_path = "data/extracted_dev_all_final.json"
with open(output_path, "w") as f:
    json.dump(result_container, f)