In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
# Switch to the parent directory so that the relative paths used later
# (e.g. "queries/rank_prompts") are resolved from there.
# NOTE(review): not idempotent — every re-run of this cell moves one level further up.
os.chdir("..")

In [3]:
import ast
import numpy as np
import pandas as pd

from util import bfs

# In-context learning prompts

In this notebook we will create a set of prompts to test the in-context learning capabilities of ChatGPT when applied to graph visualization problems.

In [4]:
# Seed NumPy's global random state; the np.random.choice calls further down
# draw from this global state.
seed = 42
np.random.seed(seed)

## Rank assignment prompts

We will start with the problem of rank assignment. We will generate the new prompts by randomly sampling `k_samples` graphs from other files and prepending their correct solutions to the prompt asking for the solution of the current graph.

In [5]:
# Task instruction placed at the start of every in-context prompt; the solved
# examples and the final query are appended after it when the prompts are built.
prompt_base = """
Perform a rank assignment on the graph. Use node 0 as a source for the graph. Each node must be assigned to a rank that is equal to the shortest path between that node and the source. Thus, node 0 will be assigned to rank 0, and the neighbors of node 0 will be assigned to rank 1. 
Write no explanations, only respond with the id of each node and the rank it has been assigned to in a format <id> - <rank>.\n
"""

In [6]:
# Directory containing the original rank-assignment query files
rank_prompt_query_dir = "queries/rank_prompts"
# Stored as a set so the current file can be excluded with a set difference
# when the in-context examples are sampled.
all_rank_prompt_files = set(os.listdir(rank_prompt_query_dir))

In [7]:
def read_prompt_edge_list(query: str) -> list:
    """Extract the edge list embedded in a rank-assignment query.

    The edge list is the Python literal located between the marker
    ``edge connections:`` and the text ``Perform a rank assignment``.

    Args:
        query: Full text of a rank-assignment query.

    Returns:
        The parsed literal converted to a list.

    Raises:
        IndexError: If ``edge connections:`` does not occur in ``query``.
        ValueError: If the extracted text is not a valid Python literal
            (``ast.literal_eval`` may also raise ``SyntaxError``).
    """
    after_marker = query.split("edge connections:")[1]
    before_instruction = after_marker.split("Perform a rank assignment")[0]
    literal_text = before_instruction.strip()
    return list(ast.literal_eval(literal_text))

def rank_assignment_to_formatted_str(rank_assignment: dict) -> str:
    """Render a rank assignment as ``<id> - <rank>`` lines.

    Args:
        rank_assignment: Mapping from a rank to an iterable of the node ids
            assigned to that rank.

    Returns:
        One newline-terminated ``<node> - <rank>`` line per node, following the
        iteration order of the mapping and of each node collection. An empty
        mapping yields an empty string.
    """
    return "".join(
        f"{node} - {rank}\n"
        for rank, nodes in rank_assignment.items()
        for node in nodes
    )

def minimize_prompt(query: str) -> str:
    """Return ``query`` without its last two newline-separated lines.

    NOTE(review): presumably the dropped lines hold the task instruction of the
    original query, which the in-context prompt states only once — confirm
    against the query files.

    Args:
        query: Text of a query.

    Returns:
        The remaining lines joined with ``"\\n"``; an empty string when the
        query has two lines or fewer.
    """
    kept_lines = query.split("\n")[:-2]
    return "\n".join(kept_lines)

In [8]:
# Number of solved examples placed before each query in an in-context prompt
k_samples = 3
# Number of queries for which an in-context prompt is generated
sample_size = 100

# Sample `sample_size` distinct queries from all_rank_prompt_files.
# The file names are sorted first: the iteration order of a set of strings can
# change between interpreter sessions (string hash randomization), so sampling
# from the unsorted set would select different files after a kernel restart
# even though the seed is fixed.
sampled_rank_prompt_files = np.random.choice(
    sorted(all_rank_prompt_files), sample_size, replace=False
)
print(f"Sampled {len(sampled_rank_prompt_files)} queries")

Sampled 100 queries


In [9]:
def _read_rank_prompt(file_name: str) -> str:
    """Return the text of `file_name` inside `rank_prompt_query_dir`, closing the file afterwards."""
    with open(os.path.join(rank_prompt_query_dir, file_name)) as prompt_file:
        return prompt_file.read()


# Sorted once up front so that, for a fixed seed, the sampled examples do not
# depend on the iteration order of the set, which can change between
# interpreter sessions (string hash randomization).
sorted_rank_prompt_files = sorted(all_rank_prompt_files)

# Maps each sampled query file name to its in-context learning prompt
rank_prompts_icl = {}

for rank_prompt_file in sampled_rank_prompt_files:
    # Sample `k_samples` other prompt files different from the current one
    candidate_files = [
        file_name
        for file_name in sorted_rank_prompt_files
        if file_name != rank_prompt_file
    ]
    other_rank_prompt_files = np.random.choice(
        candidate_files, k_samples, replace=False
    )
    assert rank_prompt_file not in other_rank_prompt_files

    # Read the current prompt and the other prompts
    rank_prompt = _read_rank_prompt(rank_prompt_file)
    other_rank_prompts = [
        _read_rank_prompt(other_rank_prompt_file)
        for other_rank_prompt_file in other_rank_prompt_files
    ]

    # Extract the edge list for both the current prompt and the other prompts.
    # The edge list of the current prompt is only needed by the disabled
    # expected-answer computation below; parsing it still raises early when the
    # current query is malformed.
    rank_prompt_edge_list = read_prompt_edge_list(rank_prompt)
    other_rank_prompt_edge_lists = [
        read_prompt_edge_list(other_rank_prompt)
        for other_rank_prompt in other_rank_prompts
    ]

    # Compute the correct rank assignment for the other prompts, using node 0 as source.
    # NOTE(review): `bfs` comes from `util`; presumably it returns a mapping
    # rank -> nodes, which is what rank_assignment_to_formatted_str consumes.
    other_rank_prompt_rank_assignments = [
        bfs(other_rank_prompt_edge_list, 0)
        for other_rank_prompt_edge_list in other_rank_prompt_edge_lists
    ]

    # Convert the rank assignments in the expected format, i.e. <id> - <rank> one per row
    other_rank_prompt_rank_assignments_str = [
        rank_assignment_to_formatted_str(other_rank_prompt_rank_assignment)
        for other_rank_prompt_rank_assignment in other_rank_prompt_rank_assignments
    ]

    # Build the prompt: instruction, then the solved examples, then the open query
    solved_examples = [
        "Input:\n{}\nAnswer:\n{}\n".format(
            minimize_prompt(other_rank_prompt), rank_assignment_str
        )
        for other_rank_prompt, rank_assignment_str in zip(
            other_rank_prompts, other_rank_prompt_rank_assignments_str
        )
    ]
    prompt = prompt_base + "".join(solved_examples)
    prompt += "Input:\n{}\nAnswer:\n".format(minimize_prompt(rank_prompt))

    # Disabled: expected answer for the current query
    # expected_answer = bfs(rank_prompt_edge_list, 0)
    # expected_answer_str = rank_assignment_to_formatted_str(expected_answer)

    rank_prompts_icl[rank_prompt_file] = prompt.strip()

    # Drop the per-iteration texts so they do not linger in the kernel namespace
    del rank_prompt, other_rank_prompts, prompt


In [10]:
# Output directory for the generated in-context prompts; created if missing,
# left untouched if it already exists.
rank_prompts_icl_query_dir = "queries/rank_prompts_icl"
os.makedirs(rank_prompts_icl_query_dir, exist_ok=True)

In [11]:
# Save each in-context prompt under the file name of the query it was built for
for rank_prompt_file, prompt in rank_prompts_icl.items():
    output_path = os.path.join(rank_prompts_icl_query_dir, rank_prompt_file)
    with open(output_path, "w") as out_file:
        out_file.write(prompt)