In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Run from one directory above the kernel's starting directory so the relative
# paths used below (e.g. "answers/...", "queries/...") resolve. The starting
# directory is remembered on first run, so re-running this cell does not keep
# climbing further up the directory tree.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, ".."))

In [None]:
import ast
import yaml
import json
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from util import read_files
from util import count_edge_crossings
from util import bfs

# Analysis of answers for `transpose_prompts`


## Load and clean the answers

In [None]:
# Directory holding the saved answer files for the transpose_prompts3 prompt set.
answer_set_name = "transpose_prompts3"
answer_dir = os.path.join("answers", answer_set_name)

In [None]:
# Partition the answer files by whether the saved response contains the
# "Copy code" marker, i.e. whether the answer was rendered as a code block.
has_copy_code = []
no_copy_code = []

for answer_file in sorted(os.listdir(answer_dir)):
    answer_file_path = os.path.join(answer_dir, answer_file)

    # Context manager so each file handle is closed right after reading.
    with open(answer_file_path, "r", encoding="utf-8") as f:
        answer_str = f.read().strip()

    if "Copy code" in answer_str:
        has_copy_code.append(answer_file)
    else:
        no_copy_code.append(answer_file)

print("Has copy code: ", len(has_copy_code))
print("No copy code: ", len(no_copy_code))

The files that contain the string "Copy code" have the answer formatted in a code block, so they are the most likely to contain a usable answer to the query.
Files containing only unstructured text are less likely to contain valid answers and are analyzed separately.

In [None]:
# Parse the answers that contain a code block.
# - The text after the first "Copy code" marker is taken as the answer.
# - If it contains a `ranks =` assignment, only the `{...}` literal that follows
#   is kept (cut at the first "}", so a ranks literal with nested braces would
#   be truncated).
# - Parsers are tried in order (YAML, JSON, Python literal); the first success
#   wins. Answers no parser accepts are stored verbatim in `missed`.
# Note: YAML accepts most plain text as a string, so a successful parse here
# does not guarantee that the result is a dict.
answers = {}
missed = {}

for answer_file in sorted(has_copy_code):
    answer_file_path = os.path.join(answer_dir, answer_file)
    with open(answer_file_path, "r", encoding="utf-8") as f:
        answer_str = f.read().strip()

    answer_str = answer_str.split("Copy code")[1].strip()

    if "ranks =" in answer_str:
        answer_str = answer_str.split("ranks =")[1].strip()
        answer_str = answer_str.split("}")[0].strip() + "}"

    # `except Exception` instead of a bare `except:` so that KeyboardInterrupt /
    # SystemExit are not swallowed. Each parser raises its own error types (and
    # PyYAML can raise plain ValueError from its constructors), so catch broadly.
    for parse in (yaml.safe_load, json.loads, ast.literal_eval):
        try:
            answer_dict = parse(answer_str)
            break
        except Exception:
            continue
    else:
        # No parser succeeded.
        missed[answer_file] = answer_str
        continue

    answers[answer_file] = answer_dict

print("Missed: ", len(missed))
print("Answers: ", len(answers))

Let's now look at the other files, i.e. the answers without a "Copy code" block.

In [None]:
# Parse the answers without a code block. If a `ranks =` assignment is present,
# only the `{...}` literal after it is kept (cut at the first "}"). Otherwise the
# whole response is handed to the parsers. Parsers are tried in order (YAML,
# JSON, Python literal), and unparseable answers are stored verbatim in
# `additional_missed`.
# Note: YAML accepts most plain text as a string, so a successful parse here
# does not guarantee that the result is a dict.
additional_answers = {}
additional_missed = {}

for answer_file in sorted(no_copy_code):
    answer_file_path = os.path.join(answer_dir, answer_file)
    with open(answer_file_path, "r", encoding="utf-8") as f:
        answer_str = f.read().strip()

    if "ranks =" in answer_str:
        answer_str = answer_str.split("ranks =")[1].strip()
        answer_str = answer_str.split("}")[0].strip() + "}"

    # `except Exception` instead of a bare `except:` so that KeyboardInterrupt /
    # SystemExit are not swallowed; the first parser that succeeds wins.
    for parse in (yaml.safe_load, json.loads, ast.literal_eval):
        try:
            answer_dict = parse(answer_str)
            break
        except Exception:
            continue
    else:
        # No parser succeeded.
        additional_missed[answer_file] = answer_str
        continue

    additional_answers[answer_file] = answer_dict

print("Additional missed: ", len(additional_missed))
print("Additional answers: ", len(additional_answers))

We can now merge the parsed answers together.

In [None]:
# Merge the results of both parsing passes.
answers.update(additional_answers)
missed.update(additional_missed)

# Safety check: keep only answers shaped like {rank: [node, ...]}.
# - Parsing can succeed on non-dicts (YAML turns plain prose into a string).
# - A rank whose value is not a list/tuple (e.g. None, a number or a bare
#   string) cannot be iterated as an ordered list of nodes by the later cells.
# Rejected answers are moved to `missed`.
non_dict = []
for k, v in answers.items():
    if not isinstance(v, dict):
        print("{} does not have a dict : {}".format(k, v))
        non_dict.append((k, v))
    elif not all(isinstance(nodes, (list, tuple)) for nodes in v.values()):
        print("{} has a rank that is not a list of nodes : {}".format(k, v))
        non_dict.append((k, v))

for k, v in non_dict:
    del answers[k]
    missed[k] = v

print("Total missed: ", len(missed))
print("Total answers: ", len(answers))

## Load and clean the prompts

In [None]:
# Directory holding the query (prompt) files for the transpose_prompts3 prompt set.
query_set_name = "transpose_prompts3"
queries_dir = os.path.join("queries", query_set_name)

In [None]:
# Parse every query file into its edge list and its original layer assignment:
#   queries[file] = {"edges": <literal after "edges = ">, "ranks": {key: nodes}}
# The edge list is read from the remainder of the line after "edges = ". The
# ranks section runs from "ranks = " to the first blank line, with one
# "Layer <key>: <nodes>" entry per line.
queries = {}
LAYER_PREFIX = "layer "

for query_file in sorted(os.listdir(queries_dir)):
    query_file_path = os.path.join(queries_dir, query_file)
    with open(query_file_path, "r", encoding="utf-8") as f:
        query_str = f.read().strip()

    query_edges = query_str.split("edges = ")[1].split("\n")[0].strip()
    query_edges = ast.literal_eval(query_edges)

    rank_lines = query_str.split("ranks = ")[1].split("\n\n")[0].strip().split("\n")
    rank_entries = []
    for line in rank_lines:
        line = line.strip()
        if not line:
            # Whitespace-only lines would otherwise produce an empty dict entry.
            continue
        # Drop the leading "Layer " word (case-insensitively) so that
        # "Layer 0: [...]" becomes the dict entry "0: [...]". Stripping first
        # avoids cutting into the text when a line is indented.
        if line.lower().startswith(LAYER_PREFIX):
            line = line[len(LAYER_PREFIX):].strip()
        rank_entries.append(line)
    query_ranks = ast.literal_eval("{" + ",".join(rank_entries) + "}")

    queries[query_file] = {"edges": query_edges, "ranks": query_ranks}

print("Number of queries: ", len(queries))

## Safety check: all nodes are there

In [None]:
# Node set of each query, collected from the endpoints of its edge list.
all_nodes_queries = {
    query_file: {endpoint for edge in query_dict["edges"] for endpoint in (edge[0], edge[1])}
    for query_file, query_dict in queries.items()
}

# Print the number of nodes in the graphs of each query
all_nodes_orig_number = [len(node_set) for node_set in all_nodes_queries.values()]
print("\nUnique number of nodes in the graphs of each query:")
print(np.unique(all_nodes_orig_number, return_counts=True))

# Node set of each parsed answer: the union of all of its rank lists.
all_nodes_answers = {
    answer_file: {node for rank_nodes in answer_dict.values() for node in rank_nodes}
    for answer_file, answer_dict in answers.items()
}

# Print the number of nodes in the graphs of each answer
all_nodes_answer_number = [len(node_set) for node_set in all_nodes_answers.values()]
print("\nUnique number of nodes in the graphs of each answer:")
print(np.unique(all_nodes_answer_number, return_counts=True))

# An answer counts as "correct" here only if it uses exactly the query's nodes.
correct_nodes_answers = {}
incorrect_nodes_answers = {}
for answer_file, answer_dict in answers.items():
    same_nodes = all_nodes_answers[answer_file] == all_nodes_queries[answer_file]
    bucket = correct_nodes_answers if same_nodes else incorrect_nodes_answers
    bucket[answer_file] = answer_dict

print("\nCorrect nodes answers: ", len(correct_nodes_answers))
print("Incorrect nodes answers: ", len(incorrect_nodes_answers))

We can now look at the answers whose set of nodes differs from that of the query. We split them by whether the answer contains more nodes than the query ("increased") or not ("decreased").

In [None]:
# Bucket the node-mismatched answers by node count: more nodes than the query
# -> increased; otherwise -> decreased. The `decreased` bucket therefore also
# receives answers with the same node count but a different node set.
increased_nodes_answers = {}
decreased_nodes_answers = {}

for fname, parsed in incorrect_nodes_answers.items():
    has_more_nodes = len(all_nodes_answers[fname]) > len(all_nodes_queries[fname])
    target = increased_nodes_answers if has_more_nodes else decreased_nodes_answers
    target[fname] = parsed

print("\nIncreased nodes answers: ", len(increased_nodes_answers))
print("Decreased nodes answers: ", len(decreased_nodes_answers))

In [None]:
# Sentinels stored in `new_crossings` when an answer cannot be scored. They are
# negative so that they are distinguishable from scored answers, which are
# selected with `new_crossings >= 0`.
WRONG_FORMAT, MISSING_NODES, EXTRA_NODES = -1, -2, -3

In [None]:
# For every query, record the crossing count of the original layout and of the
# model's layout. When the answer could not be scored, a negative sentinel is
# recorded instead (extra nodes, missing nodes, or unparseable/absent answer).
results = []
for query, query_data in queries.items():
    original = count_edge_crossings(query_data["edges"], query_data["ranks"])

    if query in correct_nodes_answers:
        new = count_edge_crossings(query_data["edges"], answers[query])
    elif query in increased_nodes_answers:
        new = EXTRA_NODES
    elif query in decreased_nodes_answers:
        new = MISSING_NODES
    else:
        new = WRONG_FORMAT

    results.append({"query": query, "original_crossings": original, "new_crossings": new})

results_df = pd.DataFrame(results)

In [None]:
# Per-query crossing counts. Negative `new_crossings` values are the
# WRONG_FORMAT / MISSING_NODES / EXTRA_NODES sentinels, not crossing counts.
results_df

In [None]:
# Summarize the outcome of every query in one bar chart. Scored answers
# (new_crossings >= 0) are compared with the original crossing count; the rest
# are counted by their failure sentinel. `improved_answers` is also used by the
# next cell. seaborn, matplotlib and pandas come from the imports cell.

# Scored answers with fewer crossings than the original layout.
improved_answers = [r for r in results if 0 <= r["new_crossings"] < r["original_crossings"]]

data = {
    "Improved answers": len(improved_answers),
    "Worsened answers": sum(
        1 for r in results
        if r["new_crossings"] >= 0 and r["new_crossings"] > r["original_crossings"]
    ),
    "Equivalent answers": sum(
        1 for r in results
        if r["new_crossings"] >= 0 and r["new_crossings"] == r["original_crossings"]
    ),
    "Malformed answers": sum(1 for r in results if r["new_crossings"] == WRONG_FORMAT),
    "Missing nodes": sum(1 for r in results if r["new_crossings"] == MISSING_NODES),
    "Extra nodes": sum(1 for r in results if r["new_crossings"] == EXTRA_NODES),
}

# One row per outcome category, for plotting.
df = pd.DataFrame.from_dict(data, orient="index", columns=["Value"])

sns.set(rc={"figure.figsize": (15, 3)})
fig, ax = plt.subplots()
sns.barplot(x=df.index, y="Value", data=df, color="skyblue", ax=ax)
ax.set(ylabel="Count", title="Outcome of each answer relative to the original layout")
plt.show()

In [None]:
# Print, for every answer that reduced the number of edge crossings, the
# original layout and edges followed by the model's proposed ranks.
for improved in improved_answers:
    query_name = improved["query"]
    original = queries[query_name]
    print("original ranks: ", original["ranks"])
    print("original edges: ", original["edges"])
    print("response ranks:", answers[query_name])
    print("\n\n")

In [None]:
def ranked_edges(ranks: dict, edges: list) -> list:
    """Return the (u, v) endpoint pairs from `edges` whose two endpoints both appear in some layer of `ranks`."""
    ranked_nodes = {node for layer in ranks.values() for node in layer}
    return [
        (edge[0], edge[1])
        for edge in edges
        if edge[0] in ranked_nodes and edge[1] in ranked_nodes
    ]


def visualize_graph(ranks: dict, edges: list):
    """Draw a layered graph with nodes placed in the order given by `ranks`.

    Args:
        ranks: mapping from layer key to the ordered nodes of that layer.
        edges: sequence of edges; only the first two items of each edge
            (its endpoints) are used.

    All nodes are inserted layer by layer *before* any edge is added. Node
    insertion order, and hence the within-layer order used by
    `nx.multipartite_layout`, therefore follows `ranks`, and the drawing matches
    the layout whose crossings are counted. Edges with an endpoint that is not
    in any layer are skipped, since such a node would lack the `layer` attribute
    the layout relies on.
    """
    import networkx as nx

    G = nx.Graph()
    for layer_n, layer in ranks.items():
        G.add_nodes_from(layer, layer=layer_n)
    G.add_edges_from(ranked_edges(ranks, edges))

    pos = nx.multipartite_layout(G, subset_key="layer")
    nx.draw(G, pos, with_labels=True)
    plt.show()
        

In [None]:
# Preview the first N_PREVIEW_GRAPHS query graphs in their original layouts.
N_PREVIEW_GRAPHS = 10
for k, q in list(queries.items())[:N_PREVIEW_GRAPHS]:
    print(k, q)
    visualize_graph(q["ranks"], q["edges"])
    print("\n\n")