In [None]:
import os
from pathlib import Path
import pandas as pd
from collections import defaultdict
from utils.get_processed_dataset import get_processed_dataset
from utils.utils import read_jsonl, get_major_entities
from omegaconf import OmegaConf
from configs.config import PRONOUNS_GROUPS, PLURAL_PRONOUNS, dataset_yaml, selected_keys
from configs.config_gen import NAME_TO_PREFIX
from tqdm.auto import tqdm
import hydra
from utils.qa_utils import write_qa_to_jsonl, get_mention_info, add_copelands_count
import dtale
import jsonlines
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import openai 
from openai import OpenAI
from dotenv import load_dotenv

In [None]:
# Work from the project root so the relative paths used in later cells resolve.
# `__file__` only exists when this code runs as a script; inside a notebook
# kernel it raises NameError, so fall back to the working directory.
try:
    _project_root = Path(__file__).resolve().parents[1]
except NameError:
    # Assumes the kernel was started in the notebook's own directory (the
    # Jupyter default) — TODO confirm. Reusing a previously stored root keeps
    # re-runs of this cell from climbing one directory higher each time.
    _project_root = globals().get("_project_root") or Path.cwd().resolve().parent
os.chdir(_project_root)

In [None]:
# Hydra keeps a single global instance per process and `hydra.initialize`
# raises if it is called while one is active. Clearing any leftover instance
# makes this cell safe to re-run without restarting the kernel.
from hydra.core.global_hydra import GlobalHydra

if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()

hydra.initialize(config_path="configs/args")
args = hydra.compose(config_name="args_qa.yaml")

# input_format_parts = ["{{", "}} (#This is the marked mention)"]
# output_format_parts = ["- Mention: ", "- Explanation:", "- The mention refers to:"]

In [3]:
def get_entity_gender(dataset_proc_all, major_entities_all, doc_key):
    """Derive a gender label for every major entity of one document.

    For each entity listed under ``major_entities_all[doc_key]``, the mention
    strings of its cluster are lower-cased and looked up in ``PRONOUNS_GROUPS``;
    mentions that are not keys of that mapping are ignored. The pronoun groups
    found are counted and ranked from most to least frequent, and the first
    ranked group with a known label decides the gender: group 2 is "Male",
    group 3 is "Female", groups 4 and 7 are "Neutral". Entities with no such
    group are labelled "Unknown".

    Args:
        dataset_proc_all: mapping ``doc_key -> document``; each document exposes
            ``"clusters_vs_mentions"`` and ``"mentions_vs_mentionstr"``.
        major_entities_all: mapping ``doc_key -> {"entity_id": [...],
            "entity_name": [...]}`` with parallel sequences.
        doc_key: key of the document to process.

    Returns:
        A pair ``(pron_info, gender_info)``, both keyed by entity name:
        ``pron_info`` maps to a dict ``{pronoun_group: count}`` ordered by
        descending count, ``gender_info`` maps to the gender label.
    """
    group_to_gender = {2: "Male", 3: "Female", 4: "Neutral", 7: "Neutral"}
    majors = major_entities_all[doc_key]

    pron_info = {}
    gender_info = {}
    for position, entity_id in enumerate(majors["entity_id"]):
        document = dataset_proc_all[doc_key]

        # Tally the pronoun group of every mention that is a known pronoun.
        group_counts = Counter()
        for mention in document["clusters_vs_mentions"][entity_id]:
            surface = document["mentions_vs_mentionstr"][mention].lower()
            if surface in PRONOUNS_GROUPS:
                group_counts[PRONOUNS_GROUPS[surface]] += 1

        # most_common() ranks by descending count; ties keep first-seen order.
        ranked_groups = dict(group_counts.most_common())

        entity_name = majors["entity_name"][position]
        pron_info[entity_name] = ranked_groups
        # The most frequent group that carries a label wins.
        gender_info[entity_name] = next(
            (group_to_gender[group] for group in ranked_groups if group in group_to_gender),
            "Unknown",
        )

    return pron_info, gender_info

In [None]:
# Load each dataset and accumulate three things across all of them:
#   - dataset_proc_all:      the processed documents
#   - major_entities_all:    the major entities read from the `train_me` file
#   - mention_info_df_final: per-mention info tagged with its dataset name
# The two dicts are merged with `update`, so a document key that appears in
# more than one dataset keeps the entry from the later dataset.
dataset_configs = OmegaConf.load(dataset_yaml)
major_entities_all = {}
dataset_proc_all = {}
mention_info_frames = []
for dataset_name in ["litbank", "fantasy"]:
    dataset_source = dataset_configs[dataset_name]["train_file"]
    tsv_addr = dataset_configs[dataset_name]["tsv"]
    doc_me = dataset_configs[dataset_name]["train_me"]
    dataset = read_jsonl(dataset_source)
    dataset_proc = get_processed_dataset(dataset, tsv_litbank=tsv_addr)
    dataset_proc_all.update(dataset_proc)
    major_entities = get_major_entities(doc_me)
    major_entities_all.update(major_entities)
    mention_info_df = get_mention_info(dataset_proc, major_entities, tsv_addr=tsv_addr)
    mention_info_df["dataset"] = dataset_name
    mention_info_frames.append(mention_info_df)

# Concatenate once at the end instead of growing a frame inside the loop:
# concatenating onto an empty seed DataFrame is deprecated in recent pandas
# (it affects result dtypes) and re-copies the data on every iteration.
mention_info_df_final = pd.concat(mention_info_frames)

In [5]:
def get_gpt_gender_output(entity_name, doc_key):
    """Ask gpt-4o-mini to classify an entity name as Male, Female or Neutral.

    The response is streamed and concatenated; the gender is whatever text
    follows the last "Gender:" marker in the reply.

    Args:
        entity_name: phrase (name, title, ...) inserted into the prompt.
        doc_key: document identifier, used only for log messages and echoed
            back in the result.

    Returns:
        A dict with keys "doc_key", "entity_name", "gender", "prompted"
        (always True) and "output_string" (the raw reply). An empty dict is
        returned when the request raises, when the stream ends with the
        "content_filter" finish reason, or when the reply is empty.

    Raises:
        KeyError: if OPENAI_API_KEY is not set in the environment.
    """
    load_dotenv()
    prompt = f"""Classify the given phrase based on its gender. If the phrase clearly indicates a male name or title, classify it as 'Male.' If the phrase indicates a female name or title, classify it as 'Female.' If the phrase does not specify a male or female name/title, or it is ambiguous, classify it as 'Neutral.' The phrase may include titles, names, or other identifiers. Your options are: Male, Female, Neutral.

Follow the format below:
Explanation: 
Gender: 

The phrase is: \"\"\"{entity_name}\"\"\""""
    ## Create an api call to gpt-4o-mini
    client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
    conversation = [{"role": "user", "content": prompt},]
    output_string = ""
    # Initialised up front so an empty stream cannot leave it unbound below.
    finish_reason = None
    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=conversation,
            temperature=0.0,
            max_tokens=4095,
            stream=True,
        )

        for chunk in completion:
            # Skip chunks that carry no choice instead of failing on [0].
            if not chunk.choices:
                continue
            choice = chunk.choices[0]
            if choice.delta.content is not None:
                output_string += choice.delta.content
            # Remember the last finish reason the stream reported.
            if choice.finish_reason is not None:
                finish_reason = choice.finish_reason

    except Exception as e:
        print(f"Error in Document ID: {doc_key} and Entity Name: {entity_name}")
        print("Error: ", e)
        return {}

    if finish_reason == "content_filter":
        print(f"Content Filter Error in Document ID: {doc_key} and Entity Name: {entity_name}")
        print(f"Content Filter Error: {completion}")
        return {}

    if output_string != "":
        predicted_answer = output_string.split("Gender:")[-1].strip()
        print(output_string)
        return {
            "doc_key": doc_key,
            "entity_name": entity_name,
            "gender": predicted_answer,
            "prompted": True,
            "output_string": output_string
        }

    return {}


In [None]:
# Build one record per major entity: the gender comes from the pronoun
# heuristic when it yields a label, otherwise from a gpt-4o-mini query.
gender_info_list = []
metadata_dest = "data/metadata/gender_info.csv"
num_count = 0
for document in tqdm(dataset_proc_all):
    pron_info, gender_info = get_entity_gender(dataset_proc_all, major_entities_all, document)
    for name, gender in gender_info.items():
        if gender == "Unknown":
            num_count += 1
            record = get_gpt_gender_output(name, document)
            if not record:
                # The GPT call failed and returned {}. Keep the identifiers so
                # the entity stays traceable instead of becoming a blank row.
                record = {
                    "doc_key": document,
                    "entity_name": name,
                    "gender": "Unknown",
                    "prompted": True,
                    "output_string": "N/A"
                }
        else:
            record = {
                "doc_key": document,
                "entity_name": name,
                "gender": gender,
                "prompted": False,
                "output_string": "N/A"
            }
        gender_info_list.append(record)

print(f"Number of entities with unknown gender: {num_count}")
## Convert the list of dictionaries to a dataframe
gender_info_df = pd.DataFrame(gender_info_list)
## Save the dataframe to a csv (create the target directory on first run)
Path(metadata_dest).parent.mkdir(parents=True, exist_ok=True)
gender_info_df.to_csv(metadata_dest, index=False)
dtale.show(gender_info_df,subprocess=False,host='localhost', port=40001)

In [4]:
def get_gpt_dialogue_output(text, doc_key, mention_ind):
    """Ask gpt-4o-mini whether the sentence holding the marked mention is dialogue.

    The response is streamed and concatenated; the verdict is read from the
    text after the last "Is the sentence a dialogue or part of dialogue:"
    marker in the reply.

    Args:
        text: passage containing the marked mention, inserted into the prompt.
        doc_key: document identifier, used for log messages and echoed back.
        mention_ind: mention identifier, used for log messages and echoed back.

    Returns:
        A dict with keys "doc_key", "mention_ind", "is_dialogue" (bool) and
        "output_string" (the raw reply). An empty dict is returned when the
        request raises, when the stream ends with the "content_filter" finish
        reason, or when the reply is empty.

    Raises:
        KeyError: if OPENAI_API_KEY is not set in the environment.
    """
    load_dotenv()
    # NOTE(review): `{{mention}}` renders as `{mention}` (single braces) in the
    # f-string below — confirm this matches how mentions are marked in `text`.
    prompt = f"""Read the text given below. The text has an entity mention marked within \"\"\" {{mention}} (#This is the marked mention) \"\"\". Extract the mention and the sentence it occurs. Determine if the sentence is a dialogue or part of a dialogue.

Text:
{text}

Follow the below format:
- Mention:
- Sentence:
- Is the sentence a dialogue or part of dialogue: True/False"""

    ## Create an api call to gpt-4o-mini
    client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
    conversation = [{"role": "user", "content": prompt},]
    output_string = ""
    # Initialised up front so an empty stream cannot leave it unbound below.
    finish_reason = None
    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=conversation,
            temperature=0.0,
            max_tokens=4095,
            stream=True,
        )

        for chunk in completion:
            # Skip chunks that carry no choice instead of failing on [0].
            if not chunk.choices:
                continue
            choice = chunk.choices[0]
            if choice.delta.content is not None:
                output_string += choice.delta.content
            # Remember the last finish reason the stream reported.
            if choice.finish_reason is not None:
                finish_reason = choice.finish_reason

    except Exception as e:
        print(f"Error in Document ID: {doc_key} and Mention Ind: {mention_ind}")
        print("Error: ", e)
        return {}

    if finish_reason == "content_filter":
        print(f"Content Filter Error in Document ID: {doc_key} and Mention Ind: {mention_ind}")
        print(f"Content Filter Error: {completion}")
        return {}

    if output_string != "":
        predicted_answer = output_string.split("Is the sentence a dialogue or part of dialogue:")[-1].strip()
        # Tolerate trivial decoration around the verdict ("true", "True.",
        # "**True**") rather than silently recording it as False.
        normalized_answer = predicted_answer.strip("*`'\"").strip().lower()
        is_dialogue = normalized_answer.startswith("true")
        print(output_string)
        # Computed outside the f-string: nesting the same quote type inside an
        # f-string is a syntax error before Python 3.12.
        print(f"Predicted Answer: {is_dialogue}")
        return {
            "doc_key": doc_key,
            "mention_ind": mention_ind,
            "is_dialogue": is_dialogue,
            "output_string": output_string
        }

    return {}

In [None]:
# Classify every pronoun ("PRON") question in the QA test split as dialogue
# or not, and store the results as a CSV.
dataset_address = "data/qas/data/qas_test.jsonl"
with jsonlines.open(dataset_address) as reader:
    qa_data = list(reader)

pron_dialogue_info = []
metadata_dest = "data/metadata/pron_dialogue_info.csv"

for obj in tqdm(qa_data):
    doc_key = obj["doc_key"]
    text = obj["text"]
    category = obj["category"]
    if category == "PRON":
        record = get_gpt_dialogue_output(text, doc_key, obj["mention_ind"])
        if not record:
            # The GPT call failed and returned {}. Keep the identifiers so the
            # mention stays traceable instead of becoming a blank row.
            record = {
                "doc_key": doc_key,
                "mention_ind": obj["mention_ind"],
                "is_dialogue": None,
                "output_string": "N/A"
            }
        pron_dialogue_info.append(record)

## Convert the list of dictionaries to a dataframe
pron_dialogue_info_df = pd.DataFrame(pron_dialogue_info)
## Save the dataframe to a csv (create the target directory on first run)
Path(metadata_dest).parent.mkdir(parents=True, exist_ok=True)
pron_dialogue_info_df.to_csv(metadata_dest, index=False)

In [None]:
# Display the pronoun dialogue results with dtale, served from this process on localhost.
# NOTE(review): this reuses port 40001, the same port as the dtale.show call for gender_info_df — confirm the two do not conflict.
dtale.show(pron_dialogue_info_df,subprocess=False,host='localhost', port=40001)