In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# automatically before code is executed, so edits to source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
# Move the working directory one level up; presumably so that the relative
# "answers/..." and "queries/..." paths used below resolve — TODO confirm layout.
# NOTE(review): the path is relative to the *current* working directory, so
# re-running this cell climbs one more level each time. Run it once per kernel.
os.chdir("../")

In [3]:
import ast
import yaml
import json
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Analysis of answers for `transpose_prompts`


## Load and clean the answers

In [4]:
# Folder holding the answer files for `transpose_prompts`
# (relative path, resolved against the current working directory).
answer_dir = os.path.join("answers", "transpose_prompts")

In [5]:
# Partition the answer files by whether their text contains the literal
# "Copy code" marker. Only the file names are kept; contents are re-read later.
has_copy_code = []
no_copy_code = []

for answer_file in sorted(os.listdir(answer_dir)):
    answer_file_path = os.path.join(answer_dir, answer_file)

    # Use a context manager so the handle is closed immediately instead of
    # leaving one open file per answer until garbage collection.
    with open(answer_file_path, "r") as answer_fh:
        answer_str = answer_fh.read().strip()

    if "Copy code" in answer_str:
        has_copy_code.append(answer_file)
    else:
        no_copy_code.append(answer_file)

print("Has copy code: ", len(has_copy_code))
print("No copy code: ", len(no_copy_code))

Has copy code:  192
No copy code:  15


The files that contain the string "Copy code" contain the answer formatted in a code block. These are likely to contain a potentially appropriate answer to the query.
Files containing only unstructured text are less likely to contain valid answers and will be analyzed separately.

In [6]:
# Parse every answer that contains a "Copy code" marker.
#   answers: file name -> parsed object
#   missed:  file name -> cleaned text that none of the parsers accepted
answers = {}
missed = {}

for answer_file in sorted(has_copy_code):
    answer_file_path = os.path.join(answer_dir, answer_file)
    # Context manager so the file handle is released right after reading.
    with open(answer_file_path, "r") as answer_fh:
        answer_str = answer_fh.read().strip()

    # Keep the text that follows the first "Copy code" marker
    # (up to the next marker, if there is one).
    answer_str = answer_str.split("Copy code")[1].strip()

    # When the answer assigns to `ranks`, keep only the literal after
    # "ranks =" up to its first closing brace, and restore that brace.
    if "ranks =" in answer_str:
        answer_str = answer_str.split("ranks =")[1].strip()
        answer_str = answer_str.split("}")[0].strip()
        answer_str = answer_str + "}"

    # Try the parsers from most to least permissive; the first one that does
    # not raise wins. `except Exception` (instead of a bare `except`) keeps
    # KeyboardInterrupt / SystemExit from being silently swallowed.
    # NOTE: yaml.safe_load also accepts plain scalars, so the parsed object is
    # not guaranteed to be a dict at this point.
    for parse in (yaml.safe_load, json.loads, ast.literal_eval):
        try:
            answer_dict = parse(answer_str)
        except Exception:
            continue
        break
    else:
        # No parser succeeded: record the cleaned text and move on.
        missed[answer_file] = answer_str
        continue

    answers[answer_file] = answer_dict

print("Missed: ", len(missed))
print("Answers: ", len(answers))

Missed:  7
Answers:  185


Let's now look at the other files.

In [7]:
# Parse the answers that have no "Copy code" marker; the whole file text is
# handed to the parsers (after the optional `ranks =` extraction).
#   additional_answers: file name -> parsed object
#   additional_missed:  file name -> text that none of the parsers accepted
additional_answers = {}
additional_missed = {}

for answer_file in sorted(no_copy_code):
    answer_file_path = os.path.join(answer_dir, answer_file)
    # Context manager so the file handle is released right after reading.
    with open(answer_file_path, "r") as answer_fh:
        answer_str = answer_fh.read().strip()

    # When the answer assigns to `ranks`, keep only the literal after
    # "ranks =" up to its first closing brace, and restore that brace.
    if "ranks =" in answer_str:
        answer_str = answer_str.split("ranks =")[1].strip()
        answer_str = answer_str.split("}")[0].strip()
        answer_str = answer_str + "}"

    # First parser that does not raise wins. `except Exception` (instead of a
    # bare `except`) lets KeyboardInterrupt / SystemExit propagate.
    # NOTE: yaml.safe_load also accepts plain scalars, so the parsed object is
    # not guaranteed to be a dict at this point.
    for parse in (yaml.safe_load, json.loads, ast.literal_eval):
        try:
            answer_dict = parse(answer_str)
        except Exception:
            continue
        break
    else:
        # No parser succeeded: record the text and move on.
        additional_missed[answer_file] = answer_str
        continue

    additional_answers[answer_file] = answer_dict

print("Additional missed: ", len(additional_missed))
print("Additional answers: ", len(additional_answers))

Additional missed:  0
Additional answers:  15


We can now merge the parsed answers together.

In [10]:
# Fold the results from the unstructured files into the main collections.
missed.update(additional_missed)
answers.update(additional_answers)

# Safety check: a usable answer must have parsed to a plain dict.
# Collect the offenders first, because `answers` cannot shrink while it is
# being iterated.
non_dict = []
for k in answers:
    v = answers[k]
    if type(v) is dict:
        continue
    print("{} does not have a dict : {}".format(k, v))
    non_dict.append((k, v))

# Reclassify every non-dict result as missed.
for k, v in non_dict:
    answers.pop(k)
    missed[k] = v

print("Total missed: ", len(missed))
print("Total answers: ", len(answers))

grafo8177.88.txt does not have a dict : len(ranks) sorted_}
Total missed:  8
Total answers:  199


## Load and clean the prompts

In [11]:
# Folder holding the prompt (query) files for `transpose_prompts`
# (relative path, resolved against the current working directory).
queries_dir = os.path.join("queries", "transpose_prompts")

In [12]:
# Load every prompt and extract the graph it describes.
#   queries: file name -> {"edges": <parsed literal>, "ranks": <parsed literal>}
queries = {}

for query_file in sorted(os.listdir(queries_dir)):
    queries[query_file] = {}
    query_file_path = os.path.join(queries_dir, query_file)

    # Context manager so the handle is closed right after reading instead of
    # leaking one open file per prompt.
    with open(query_file_path, "r") as query_fh:
        query_str = query_fh.read().strip()

    # Take the remainder of the line that follows the first "edges = " and
    # evaluate it as a Python literal.
    query_edges = query_str.split("edges = ")[1]
    query_edges = query_edges.split("\n")[0].strip()
    query_edges = ast.literal_eval(query_edges)

    # Same extraction for the line that follows the first "ranks = ".
    query_ranks = query_str.split("ranks = ")[1]
    query_ranks = query_ranks.split("\n")[0].strip()
    query_ranks = ast.literal_eval(query_ranks)

    queries[query_file]["edges"] = query_edges
    queries[query_file]["ranks"] = query_ranks

print("Number of queries: ", len(queries))

Number of queries:  207


## Safety check: all nodes are there

In [13]:
# Nodes of each query graph: the first two entries of every edge.
all_nodes_queries = {}
for query_file, query_dict in queries.items():
    endpoints = set()
    for edge in query_dict["edges"]:
        endpoints.update((edge[0], edge[1]))
    all_nodes_queries[query_file] = endpoints

# Print the number of nodes in the graphs of each query
all_nodes_orig_number = [len(node_set) for node_set in all_nodes_queries.values()]
print("\nUnique number of nodes in the graphs of each query:")
print(np.unique(all_nodes_orig_number, return_counts=True))

# Nodes of each answer: the union of the node collections stored under every key.
all_nodes_answers = {}
for answer_file, answer_dict in answers.items():
    listed = set()
    for nodes in answer_dict.values():
        listed.update(nodes)
    all_nodes_answers[answer_file] = listed

# Print the number of nodes in the graphs of each answer
all_nodes_answer_number = [len(node_set) for node_set in all_nodes_answers.values()]
print("\nUnique number of nodes in the graphs of each answer:")
print(np.unique(all_nodes_answer_number, return_counts=True))

# An answer is "correct" here when its node set equals the node set of the
# query stored under the same file name.
correct_nodes_answers = {}
incorrect_nodes_answers = {}
for answer_file, answer_dict in answers.items():
    answer_nodes, query_nodes = all_nodes_answers[answer_file], all_nodes_queries[answer_file]
    target = correct_nodes_answers if answer_nodes == query_nodes else incorrect_nodes_answers
    target[answer_file] = answer_dict

print("\nCorrect nodes answers: ", len(correct_nodes_answers))
print("Incorrect nodes answers: ", len(incorrect_nodes_answers))


Unique number of nodes in the graphs of each query:
(array([10, 11]), array([ 79, 128]))

Unique number of nodes in the graphs of each answer:
(array([10, 11]), array([ 81, 118]))

Correct nodes answers:  195
Incorrect nodes answers:  4


We can now take a closer look at the answers whose set of nodes differs from that of the corresponding query graph, separating the answers that contain more nodes than the query from those that contain fewer.

In [14]:
# Break the answers with a wrong node set down by how the node count changed.
increased_nodes_answers = {}
decreased_nodes_answers = {}
# Same number of nodes as the query but different node identities. These used
# to fall into the `else` branch and were wrongly reported as "decreased".
altered_nodes_answers = {}

for answer_file, answer_dict in incorrect_nodes_answers.items():
    answer_nodes = all_nodes_answers[answer_file]
    query_nodes = all_nodes_queries[answer_file]

    if len(answer_nodes) > len(query_nodes):
        increased_nodes_answers[answer_file] = answer_dict
    elif len(answer_nodes) < len(query_nodes):
        decreased_nodes_answers[answer_file] = answer_dict
    else:
        altered_nodes_answers[answer_file] = answer_dict

print("\nIncreased nodes answers: ", len(increased_nodes_answers))
print("Decreased nodes answers: ", len(decreased_nodes_answers))
# Only reported when it occurs, so the usual two-line summary is unchanged.
if altered_nodes_answers:
    print("Same-size but different nodes answers: ", len(altered_nodes_answers))


Increased nodes answers:  0
Decreased nodes answers:  4
