In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
device_eval = 'cuda:1'
# Model for evaluating the correctness of the prediction compared to the label
model_id_eval = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tok_eval = AutoTokenizer.from_pretrained(model_id_eval)
model_eval = AutoModelForCausalLM.from_pretrained(model_id_eval, torch_dtype='auto').to(device_eval)

# Generate triplets and questions

In [1]:
from process_data import *



2024-09-17 10:35:33,137 SequenceTagger predicts: Dictionary with 20 tags: <unk>, O, S-ORG, S-MISC, B-PER, E-PER, S-LOC, B-ORG, E-ORG, I-PER, S-PER, B-MISC, I-MISC, E-MISC, I-ORG, B-LOC, E-LOC, I-LOC, <START>, <STOP>


## Topic selection

In [None]:
def get_topic_size(topic_name):
    topic = json.loads(f'{{"instance of": "{topic_name}"}}')
    query_part1 = "SELECT ?subjectLabel ?relation ?objectLabel WHERE {"
    query_part2 = ""
    relation_object_pairs = convert_topic_to_symbol(topic)
    for pair in relation_object_pairs:
        query_part2 += f"\n?subject wdt:{pair[0]} wd:{pair[1]} ."
    query_part3 = """
        ?subject  ?relation  ?object.
        ?subject wikibase:identifiers ?subject_identifierCount.
        ?object wikibase:identifiers ?object_identifierCount.
        """
    query_part5 = """ 
        FILTER (?subject_identifierCount >= 8 && ?object_identifierCount >= 5) .  
        SERVICE wikibase:label { bd:serviceParam wikibase:language "[AUTO_LANGUAGE],en". }
    }
    LIMIT 8000
    """
    query = query_part1 + query_part2 + query_part3 + query_part5
    sparql.setQuery(query)
    sparql.setReturnFormat(JSON)
    results = sparql.query().convert()
    count = len(results['results']['bindings'])
    print(f"Topic {topic} size: {count}")
    return count

for t in []:
    get_topic_size(t)

In [None]:
# len = len(results['results']['bindings'])
# len = 0: ["invention", "animal species", "mineral", "Olympic Games", "train", "mathematics", "neuroscience", "robotics", 
#           "internet", "mobile phone", "3D printing", "bird", "academy awards", "movies", "movie", "grammy award", 'netflix series', 
#           'beverage', "climate", "astronomy",]
# len < 100: ["climate", "physics", "biology", "insect", "fish", "computer hardware", "plant", "sports team", "ecosystem", "reef", "wetland", "grassland",
#             'vehicle', 'airplane', 'bicycle', "animal", "chemical compound", "astronomical object", 'fruit', 'vegetable', 'cuisine', "planet", "physics", 
#             "chemistry", "mathematics", "biology", "geology", "ecology", "genetics", "space mission", "spacecraft", "particle", "species", "ecosystem", 
#             "hypothesis",]
# Error code 500: ["video game", "river", "protein", 'ship', "film", "human", "film", "human", "mountain","scientific journal", "gene", "album",
#                  "star", 'art_literary',  "painting", 'art_painting']

## Triplet Generation

In [15]:
def get_triplets(topic_name):
    with open(f'../data/topic/{topic_name}.json', 'r', encoding='utf-8') as topics_file:
        topics = topics_file.readlines()
    data = []
    # topic_name = 'country'
    # for topic in [f'{{"instance of": "{topic_name}"}}\n']:
    for topic in topics:
        if topic:
            topic = json.loads(topic)
            print(topic)
            query_part1 = "SELECT ?subjectLabel ?relation ?objectLabel WHERE {"
            query_part2 = ""
            relation_object_pairs = convert_topic_to_symbol(topic)
            for pair in relation_object_pairs:
                query_part2 += f"\n?subject wdt:{pair[0]} wd:{pair[1]} ."

            query_part3 = """
                ?subject  ?relation  ?object.
                ?subject wikibase:identifiers ?subject_identifierCount.
                ?object wikibase:identifiers ?object_identifierCount.
                """
            query_part5 = """ 
                FILTER (?subject_identifierCount >= 8 && ?object_identifierCount >= 5) .  
                SERVICE wikibase:label { bd:serviceParam wikibase:language "[AUTO_LANGUAGE],en". }
            }
            LIMIT 5000
            """
            query = query_part1 + query_part2 + query_part3 + query_part5
            sparql.setQuery(query)
            sparql.setReturnFormat(JSON)
            results = sparql.query().convert()
            time.sleep(1)
            # print(f"results: {results}")
            if "results" in results:
                # Create a list to store the data
                with ThreadPoolExecutor() as executor:
                    futures = [executor.submit(process_result, result) for result in results["results"]["bindings"]]
                    # Use tqdm to show the progress bar while waiting for all tasks to complete
                    for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
                        result_data = future.result()
                        if result_data:
                            data.append(result_data)

    fact_triplets = pd.DataFrame(data)  # data/triplet/raw/ store raw triplets before removing duplicate (subject, relation) pairs
    fact_triplets.to_csv(f'../data/triplet/raw/{topic_name}.csv', index=False)

In [None]:
# get_triplets('places_city') # 4:37:42  # get_triplets('places_landmark') # 4:37:42  # get_triplets('entertainment_music_genre') # 28:21

for topic in ['business_corporation', 'business_brand', 'business_industry']:
    get_triplets(topic)

To avoid question with multiple answers, check and remove (node, relation) with duplicates as shown below:\
{"subject": "Thailand", "relation": "diplomatic relation", "object": "Russia"}\
{"subject": "Thailand", "relation": "diplomatic relation", "object": "Brunei"}

In [32]:
directory = '../data/triplet/raw'
for filename in sorted(os.listdir(directory)):
    if filename.endswith('.csv'):
        fact_triplets = pd.read_csv(f'{directory}/{filename}')
        # print(fact_triplets.shape)
        fact_triplets.dropna(inplace=True)
        condition1 = fact_triplets.apply(lambda row: any(val.startswith('Q') and val[1:].isdigit() for val in row.values), axis=1)
        condition2 = fact_triplets.apply(lambda row: any(val.startswith('http') for val in row.values), axis=1)
        fact_triplets = fact_triplets[~(condition1 | condition2)]  

        # remove_pairs, all_pairs = set(), []
        # for i in fact_triplets.index:
        #     subject, relation = fact_triplets.loc[i, 'subjectLabel'], fact_triplets.loc[i, 'relation']
        #     if (subject, relation) in all_pairs:
        #         remove_pairs.add((subject, relation))
        #     else:
        #         all_pairs.append((subject, relation))
        # fact_triplets_new = fact_triplets[~fact_triplets.apply(lambda row: (row['subjectLabel'], row['relation']) in remove_pairs, axis=1)]
        fact_triplets_new = fact_triplets.drop_duplicates(subset=['subjectLabel', 'relation'], keep=False)
        print(f"{filename:<30} fact_triplets.shape: {fact_triplets.shape}, fact_triplets_new.shape: {fact_triplets_new.shape}", end=' ')
        print(f"Number of rows with subject==object: {fact_triplets_new[fact_triplets_new.subjectLabel == fact_triplets_new.objectLabel].shape[0]}")
        fact_triplets_new = fact_triplets_new[fact_triplets_new.subjectLabel != fact_triplets_new.objectLabel]
        fact_triplets_new.to_csv(f'../data/triplet/{filename}', index=False)

art_sculpture.csv              fact_triplets.shape: (1088, 3), fact_triplets_new.shape: (554, 3) Number of rows with subject==object: 4
business_brand.csv             fact_triplets.shape: (3577, 3), fact_triplets_new.shape: (1430, 3) Number of rows with subject==object: 13
business_corporation.csv       fact_triplets.shape: (851, 3), fact_triplets_new.shape: (340, 3) Number of rows with subject==object: 1
business_industry.csv          fact_triplets.shape: (1508, 3), fact_triplets_new.shape: (605, 3) Number of rows with subject==object: 8
entertainment_anime.csv        fact_triplets.shape: (627, 3), fact_triplets_new.shape: (360, 3) Number of rows with subject==object: 1
entertainment_music_genre.csv  fact_triplets.shape: (3007, 3), fact_triplets_new.shape: (1652, 3) Number of rows with subject==object: 19
entertainment_song.csv         fact_triplets.shape: (5227, 3), fact_triplets_new.shape: (2654, 3) Number of rows with subject==object: 25
event_film.csv                 fact_triplets

## Generate Questions from knowledge graph triplets

In [5]:
import os
import pandas as pd
model_ls = ['mistralai/Mistral-7B-Instruct-v0.3', 'meta-llama/Meta-Llama-3-8B-Instruct']  # 3.1, 'lmsys/vicuna-7b-v1.5'
model_id_format_ls = [e.split('/')[-1].replace('-', '_').lower() for e in model_ls]
model_id_format = model_id_format_ls[0]
print(f'model_id: {model_id_format}')

folder_unfiltered = f"../data/questions/unfiltered/{model_id_format}"
folder_hallu = f"../data/questions/hallucination_all/{model_id_format}"
folder_hallu_100 = f"../data/questions/hallucination/{model_id_format}_100"

model_id: mistral_7b_instruct_v0.3


In [3]:
folder_unfiltered = f"../data/questions/unfiltered/{model_id_format}"
print(f'folder_unfiltered: {folder_unfiltered}')
for filename in os.listdir('../data/triplet'):
    if filename.endswith('.csv'):
        fact_triplets_new = pd.read_csv(f'../data/triplet/{filename}')
        # if os.path.exists(f"{folder_unfiltered}/{filename}"):
        #     continue
        if not os.path.exists(folder_unfiltered):
            os.makedirs(folder_unfiltered)
        
        with open(os.path.join("../data/topic", filename.replace('.csv', '.json')), 'r', encoding='utf-8') as topics_file:
            topics = topics_file.readlines()
        topic = json.loads(topics[0])
        first_pair = next(iter(topic.items()))
        # print(f'filename: {filename}, topic: {first_pair[1]}')

        question_ls = []
        for i in tqdm(fact_triplets_new.index):
            subject, relation, object = fact_triplets_new.loc[i, 'subjectLabel'], fact_triplets_new.loc[i, 'relation'], fact_triplets_new.loc[i, 'objectLabel']
            question = generate_question(subject, relation, object, first_pair[1])
            if question:
                question_ls.append(question)
        df = pd.DataFrame(question_ls)
        df.rename(columns={'question': 'question_rule_based'}, inplace=True)
        df.to_csv(f"{folder_unfiltered}/{filename}", index=False) # [['subject', 'relation', 'object', 'question']]
        print(f"Topic: {first_pair[1]}, fact_triplets_new.shape: {fact_triplets_new.shape}, df_question.shape: {df.shape}", end=' ')
        print(f"Number of rows with label != object: {df[df.object != df.label].shape[0]}")

folder_unfiltered: ../data/questions/unfiltered/meta_llama_3.1_8b_instruct


100%|██████████| 202/202 [00:02<00:00, 83.47it/s]


Topic: recurring sporting event, fact_triplets_new.shape: (202, 3), df_question.shape: (115, 5) Number of rows with label != object: 0


100%|██████████| 523/523 [00:06<00:00, 81.26it/s]


Topic: forest, fact_triplets_new.shape: (523, 3), df_question.shape: (261, 5) Number of rows with label != object: 0


100%|██████████| 339/339 [00:03<00:00, 91.66it/s]


Topic: corporation, fact_triplets_new.shape: (339, 3), df_question.shape: (191, 5) Number of rows with label != object: 0


100%|██████████| 18513/18513 [03:10<00:00, 97.00it/s] 


Topic: city, fact_triplets_new.shape: (18513, 3), df_question.shape: (6805, 5) Number of rows with label != object: 0


100%|██████████| 1633/1633 [00:17<00:00, 91.08it/s]


Topic: music genre, fact_triplets_new.shape: (1633, 3), df_question.shape: (300, 5) Number of rows with label != object: 0


100%|██████████| 2629/2629 [00:28<00:00, 93.63it/s] 


Topic: song, fact_triplets_new.shape: (2629, 3), df_question.shape: (1935, 5) Number of rows with label != object: 0


100%|██████████| 423/423 [00:04<00:00, 89.13it/s]


Topic: programming language, fact_triplets_new.shape: (423, 3), df_question.shape: (311, 5) Number of rows with label != object: 0


100%|██████████| 2392/2392 [00:24<00:00, 98.92it/s] 


Topic: human, fact_triplets_new.shape: (2392, 3), df_question.shape: (1313, 5) Number of rows with label != object: 0


100%|██████████| 359/359 [00:03<00:00, 93.60it/s]


Topic: anime, fact_triplets_new.shape: (359, 3), df_question.shape: (299, 5) Number of rows with label != object: 0


100%|██████████| 3184/3184 [00:34<00:00, 91.11it/s] 


Topic: tourist attraction, fact_triplets_new.shape: (3184, 3), df_question.shape: (1832, 5) Number of rows with label != object: 0


100%|██████████| 597/597 [00:07<00:00, 84.68it/s]


Topic: industry, fact_triplets_new.shape: (597, 3), df_question.shape: (170, 5) Number of rows with label != object: 0


100%|██████████| 1788/1788 [00:19<00:00, 92.64it/s]


Topic: human, fact_triplets_new.shape: (1788, 3), df_question.shape: (1052, 5) Number of rows with label != object: 0


100%|██████████| 338/338 [00:03<00:00, 100.79it/s]


Topic: film festival, fact_triplets_new.shape: (338, 3), df_question.shape: (166, 5) Number of rows with label != object: 0


100%|██████████| 440/440 [00:04<00:00, 92.24it/s]


Topic: revolution, fact_triplets_new.shape: (440, 3), df_question.shape: (156, 5) Number of rows with label != object: 0


100%|██████████| 550/550 [00:06<00:00, 90.97it/s]


Topic: sculpture, fact_triplets_new.shape: (550, 3), df_question.shape: (377, 5) Number of rows with label != object: 0


100%|██████████| 830/830 [00:09<00:00, 85.58it/s]


Topic: disease, fact_triplets_new.shape: (830, 3), df_question.shape: (505, 5) Number of rows with label != object: 0


100%|██████████| 766/766 [00:08<00:00, 93.41it/s]


Topic: volcano, fact_triplets_new.shape: (766, 3), df_question.shape: (429, 5) Number of rows with label != object: 0


100%|██████████| 141/141 [00:01<00:00, 89.58it/s]


Topic: symptom, fact_triplets_new.shape: (141, 3), df_question.shape: (60, 5) Number of rows with label != object: 0


100%|██████████| 219/219 [00:02<00:00, 88.53it/s]


Topic: database, fact_triplets_new.shape: (219, 3), df_question.shape: (138, 5) Number of rows with label != object: 0


100%|██████████| 1417/1417 [00:14<00:00, 95.22it/s] 


Topic: brand, fact_triplets_new.shape: (1417, 3), df_question.shape: (787, 5) Number of rows with label != object: 0


100%|██████████| 4339/4339 [00:48<00:00, 89.30it/s]


Topic: human, fact_triplets_new.shape: (4339, 3), df_question.shape: (2687, 5) Number of rows with label != object: 0


100%|██████████| 822/822 [00:09<00:00, 90.73it/s]


Topic: software, fact_triplets_new.shape: (822, 3), df_question.shape: (633, 5) Number of rows with label != object: 0


100%|██████████| 236/236 [00:02<00:00, 84.86it/s]


Topic: medication, fact_triplets_new.shape: (236, 3), df_question.shape: (53, 5) Number of rows with label != object: 0


100%|██████████| 2761/2761 [00:31<00:00, 88.69it/s]


Topic: country, fact_triplets_new.shape: (2761, 3), df_question.shape: (1936, 5) Number of rows with label != object: 0


100%|██████████| 2592/2592 [00:27<00:00, 94.57it/s] 


Topic: human, fact_triplets_new.shape: (2592, 3), df_question.shape: (1546, 5) Number of rows with label != object: 0


100%|██████████| 335/335 [00:03<00:00, 96.45it/s]


Topic: glacier, fact_triplets_new.shape: (335, 3), df_question.shape: (107, 5) Number of rows with label != object: 0
folder_unfiltered: ../data/questions/unfiltered/mistral_7b_instruct_v0.3


100%|██████████| 202/202 [00:02<00:00, 89.12it/s]


Topic: recurring sporting event, fact_triplets_new.shape: (202, 3), df_question.shape: (115, 5) Number of rows with label != object: 0


100%|██████████| 523/523 [00:06<00:00, 81.28it/s]


Topic: forest, fact_triplets_new.shape: (523, 3), df_question.shape: (261, 5) Number of rows with label != object: 0


100%|██████████| 339/339 [00:03<00:00, 91.57it/s]


Topic: corporation, fact_triplets_new.shape: (339, 3), df_question.shape: (191, 5) Number of rows with label != object: 0


100%|██████████| 18513/18513 [03:11<00:00, 96.57it/s] 


Topic: city, fact_triplets_new.shape: (18513, 3), df_question.shape: (6805, 5) Number of rows with label != object: 0


100%|██████████| 1633/1633 [00:18<00:00, 90.34it/s]


Topic: music genre, fact_triplets_new.shape: (1633, 3), df_question.shape: (300, 5) Number of rows with label != object: 0


100%|██████████| 2629/2629 [00:28<00:00, 93.30it/s]


Topic: song, fact_triplets_new.shape: (2629, 3), df_question.shape: (1935, 5) Number of rows with label != object: 0


100%|██████████| 423/423 [00:04<00:00, 88.85it/s]


Topic: programming language, fact_triplets_new.shape: (423, 3), df_question.shape: (311, 5) Number of rows with label != object: 0


100%|██████████| 2392/2392 [00:24<00:00, 98.76it/s] 


Topic: human, fact_triplets_new.shape: (2392, 3), df_question.shape: (1313, 5) Number of rows with label != object: 0


100%|██████████| 359/359 [00:03<00:00, 93.29it/s]


Topic: anime, fact_triplets_new.shape: (359, 3), df_question.shape: (299, 5) Number of rows with label != object: 0


100%|██████████| 3184/3184 [00:35<00:00, 90.77it/s] 


Topic: tourist attraction, fact_triplets_new.shape: (3184, 3), df_question.shape: (1832, 5) Number of rows with label != object: 0


100%|██████████| 597/597 [00:06<00:00, 87.29it/s]


Topic: industry, fact_triplets_new.shape: (597, 3), df_question.shape: (170, 5) Number of rows with label != object: 0


100%|██████████| 1788/1788 [00:19<00:00, 91.34it/s]


Topic: human, fact_triplets_new.shape: (1788, 3), df_question.shape: (1052, 5) Number of rows with label != object: 0


100%|██████████| 338/338 [00:03<00:00, 100.17it/s]


Topic: film festival, fact_triplets_new.shape: (338, 3), df_question.shape: (166, 5) Number of rows with label != object: 0


100%|██████████| 440/440 [00:04<00:00, 91.89it/s]


Topic: revolution, fact_triplets_new.shape: (440, 3), df_question.shape: (156, 5) Number of rows with label != object: 0


100%|██████████| 550/550 [00:06<00:00, 90.52it/s]


Topic: sculpture, fact_triplets_new.shape: (550, 3), df_question.shape: (377, 5) Number of rows with label != object: 0


100%|██████████| 830/830 [00:09<00:00, 85.33it/s]


Topic: disease, fact_triplets_new.shape: (830, 3), df_question.shape: (505, 5) Number of rows with label != object: 0


100%|██████████| 766/766 [00:08<00:00, 93.20it/s]


Topic: volcano, fact_triplets_new.shape: (766, 3), df_question.shape: (429, 5) Number of rows with label != object: 0


100%|██████████| 141/141 [00:01<00:00, 89.26it/s]


Topic: symptom, fact_triplets_new.shape: (141, 3), df_question.shape: (60, 5) Number of rows with label != object: 0


100%|██████████| 219/219 [00:02<00:00, 88.57it/s]


Topic: database, fact_triplets_new.shape: (219, 3), df_question.shape: (138, 5) Number of rows with label != object: 0


100%|██████████| 1417/1417 [00:14<00:00, 95.03it/s] 


Topic: brand, fact_triplets_new.shape: (1417, 3), df_question.shape: (787, 5) Number of rows with label != object: 0


100%|██████████| 4339/4339 [00:48<00:00, 89.37it/s]


Topic: human, fact_triplets_new.shape: (4339, 3), df_question.shape: (2687, 5) Number of rows with label != object: 0


100%|██████████| 822/822 [00:09<00:00, 90.81it/s]


Topic: software, fact_triplets_new.shape: (822, 3), df_question.shape: (633, 5) Number of rows with label != object: 0


100%|██████████| 236/236 [00:02<00:00, 85.07it/s]


Topic: medication, fact_triplets_new.shape: (236, 3), df_question.shape: (53, 5) Number of rows with label != object: 0


100%|██████████| 2761/2761 [00:31<00:00, 88.74it/s]


Topic: country, fact_triplets_new.shape: (2761, 3), df_question.shape: (1936, 5) Number of rows with label != object: 0


100%|██████████| 2592/2592 [00:27<00:00, 95.20it/s] 


Topic: human, fact_triplets_new.shape: (2592, 3), df_question.shape: (1546, 5) Number of rows with label != object: 0


100%|██████████| 335/335 [00:03<00:00, 96.37it/s]

Topic: glacier, fact_triplets_new.shape: (335, 3), df_question.shape: (107, 5) Number of rows with label != object: 0





In [25]:
folder_unfiltered, model_id_format_ls[1:], other_model_id

('../data/questions/unfiltered/mistral_7b_instruct_v0.3',
 ['meta_llama_3_8b_instruct'],
 'meta_llama_3_8b_instruct')

In [24]:
import shutil
for other_model_id in model_id_format_ls[1:]:
    shutil.copytree(folder_unfiltered, f"../data/questions/unfiltered/{other_model_id}")
f"../data/questions/unfiltered/{other_model_id}"

'../data/questions/unfiltered/meta_llama_3_8b_instruct'

In [27]:
# # Remove all the output columns for all the files under "../data/questions/unfiltered/{other_model_id}"
# for filename in os.listdir(f"../data/questions/unfiltered/{other_model_id}"):
#     df = pd.read_csv(f"../data/questions/unfiltered/{model_id_format}/{filename}")
#     df.drop(columns=[f'output_mistral_7b_instruct_v0.3'], inplace=True)
#     df.to_csv(f"../data/questions/unfiltered/{other_model_id}/{filename}", index=False)

In [5]:
shape_data = {}
for model_id in model_ls:
    model_id_format = model_id.split('/')[-1].replace('-', '_').lower()
    folder_unfiltered = f"../data/questions/unfiltered/{model_id_format}"
    for filename in sorted(os.listdir(folder_unfiltered)):
        if filename.endswith('.csv'):
            df = pd.read_csv(f'{folder_unfiltered}/{filename}')
            topic = filename.replace('.csv', '')
            if topic not in shape_data:
                shape_data[topic] = {}
            shape_data[topic][model_id_format] = df.shape

shape_df = pd.DataFrame.from_dict(shape_data, orient='index')
# shape_df.columns = model_columns
shape_df.index.name = 'Topic'
shape_df.reset_index(inplace=True)

display(shape_df)

Unnamed: 0,Topic,meta_llama_3.1_8b_instruct,mistral_7b_instruct_v0.3,vicuna_7b_v1.5
0,art_sculpture,"(358, 6)","(377, 6)","(377, 6)"
1,business_brand,"(781, 6)","(787, 5)","(787, 6)"
2,business_corporation,"(191, 6)","(191, 5)","(191, 6)"
3,business_industry,"(170, 6)","(170, 5)","(170, 6)"
4,entertainment_anime,"(296, 6)","(299, 6)","(299, 6)"
5,entertainment_music_genre,"(300, 6)","(300, 6)","(300, 6)"
6,entertainment_song,"(1932, 6)","(1935, 6)","(1935, 6)"
7,event_film,"(166, 6)","(166, 6)","(166, 6)"
8,event_history,"(156, 6)","(156, 6)","(156, 6)"
9,event_sport,"(115, 6)","(115, 6)","(115, 6)"


In [9]:
shape_data = {}
for model_id in model_ls:
    model_id_format = model_id.split('/')[-1].replace('-', '_').lower()
    for filename in sorted(os.listdir(folder_hallu_100)):
        if filename.endswith('.csv'):
            df = pd.read_csv(f'{folder_hallu_100}/{filename}')
            topic = filename.replace('.csv', '')
            if model_id_format not in shape_data:
                shape_data[model_id_format] = {}
            shape_data[model_id_format][topic] = df.shape

shape_df = pd.DataFrame.from_dict(shape_data)
shape_df.index.name = 'Topic'
shape_df = shape_df.reset_index()
shape_df.sort_values(by=['Topic'])

Unnamed: 0,Topic,meta_llama_3.1_8b_instruct,mistral_7b_instruct_v0.3
0,art_sculpture,"(105, 8)","(105, 8)"
1,business_brand,"(100, 26)","(100, 26)"
2,business_corporation,"(100, 26)","(100, 26)"
3,business_industry,"(105, 8)","(105, 8)"
4,entertainment_anime,"(100, 26)","(100, 26)"
5,entertainment_music_genre,"(105, 8)","(105, 8)"
6,entertainment_song,"(105, 8)","(105, 8)"
7,event_film,"(47, 8)","(47, 8)"
8,event_history,"(88, 8)","(88, 8)"
9,event_sport,"(32, 8)","(32, 8)"


## Checking data

In [4]:
# Check for duplication and NaN values

remove_relation = ["topic's main category", "topic's main template", "described by source", "Commons category", "on focus list of Wikimedia project"]

for filename in os.listdir(folder_unfiltered):
    df = pd.read_csv(f"{folder_unfiltered}/{filename}")
    df_dup = df[df.duplicated(['subject', 'relation'], keep=False)]
    if len(df_dup) > 0:  # check duplicate (subject, relation) pairs
        print(f"In {filename}, there are {len(df_dup)} questions with duplicate (subject, relation) pairs:")
        
    if len(df[df['subject'] == df['object']]) > 0:  # Check if subject == object
        print(f"In {filename}, there are {len(df[df['subject'] == df['object']])} questions where subject == object")
        df = df[df['subject'] != df['object']]

    # for relation_check in remove_relation:
    #     if relation_check in df['relation'].to_list():
    #         print(f'Check {relation_check} relation for {filename}')

    if df[df.isna().any(axis=1)].shape[0] > 0:
        print(f"In {filename}, there are {df[df.isna().any(axis=1)].shape[0]} NaN values.")
        # df = df.dropna(subset=[f'output_{model_id_format}'])

    if len(df[df['label'] != df['object']]) > 0:
        print(f"In {filename}, there are {len(df[df['label'] != df['object']])} questions where label != object")
    else:
        df = df.drop(columns=['label'])  # label column equals to the object column
        df.to_csv(f"{folder_unfiltered}/{filename}", index=False)

In human_writer.csv, there are 2 questions where subject == object
