In [None]:
import random
import json
import numpy as np
from glob import glob
from tqdm import tqdm
from llmclient import LLMClient, get_llm_response
from utils import to_file, compute_acc

# Model under evaluation. Set this to "o3-mini" to run the reasoning-model
# configuration; every setting below is derived from this single name so the
# token budgets and the token-limit key cannot drift out of sync.
MODEL = "phi-3.5-moe-instruct"

# Token budgets:
#   tokens         -> plain (label-only) answers
#   cot_tokens     -> chain-of-thought answers
#   max_apo_tokens -> gradient / edit / paraphrase calls made by APO
if "o3" in MODEL:
    # for o3-mini
    tokens = 50000
    cot_tokens = 50000
    max_apo_tokens = 50000
else:
    # for most SLMs
    tokens = 3
    cot_tokens = 1024
    max_apo_tokens = 512

MAX_DATA_TO_EVAL = 1000 # Datasets are 2k each

# Seed the stdlib RNG so exemplar shuffling and APO sampling are repeatable
# after Restart Kernel -> Run All.
SEED = 42
random.seed(SEED)

# Every output file is written under a directory named after the model.
root_out = MODEL
# The o3 family is configured with a different token-limit key than the other models.
tkey = "max_tokens" if "o3" not in MODEL else "max_completion_tokens"
params = {tkey: tokens}
llm = LLMClient(params, MODEL)

# Shared functions

In [None]:
def get_predictions(dataset, max_calls, problem="parity", max_exemplars=10, use_desc="", desc="", shuffle_exemplars=False,
                    start_at=None, is_cot=False, stream=False, max_token_override=None, fname=None, debug=False):
    """
    Query the LLM for every datapoint in `dataset[start_at:max_calls]`.
    ---
    Params:
    dataset: indexable collection of dicts; each needs an "Entry" key (and "Path" for the hamiltonian problem)
    max_calls (int): upper bound on the index evaluated (capped at len(dataset))
    problem (str): task name; the prompt is built by the notebook-level function `get_prompt_for_<problem>`
    max_exemplars (int): number of shots forwarded to the prompt builder
    use_desc: system prompt / task description forwarded to the prompt builder
    desc (str): label shown on the progress bar
    shuffle_exemplars (bool): forwarded to the prompt builder
    start_at (int, None): index to resume from (None means 0)
    is_cot (bool): use the chain-of-thought token budget and CoT exemplars
    stream (bool): append each prediction to `fname` as soon as it is produced
    max_token_override (int, None): token budget that takes precedence over the defaults
    fname (str, None): output path; required when `stream` is True
    debug (bool): print prompts and responses
    ---
    Returns: list of dicts, each a copy of the datapoint plus a "Response" key.
    Raises: NotImplementedError for an unknown problem, ValueError when streaming
    without a file name, NameError when the prompt builder has not been defined yet.
    """
    supported_problems = ("parity", "pattern_matching", "vending_machine", "vending_machine_with_sum",
                          "hamiltonian", "stack", "maze_complete", "maze_solve", "reversal")
    # Fail before spending any LLM calls on a misconfigured run.
    if problem not in supported_problems:
        raise NotImplementedError("Problem \"{}\" not implemented!".format(problem))
    if stream and fname is None:
        raise ValueError("stream=True requires `fname` to be set")

    # Prompt builders live in later notebook cells, so resolve by name at call time.
    builder_name = "get_prompt_for_" + problem
    build_prompt = globals().get(builder_name)
    if build_prompt is None:
        raise NameError("{} is not defined; run the cell that defines it first".format(builder_name))

    predictions = []
    max_points = min(max_calls, len(dataset))
    if debug: print(f"Calling for max_points = {max_points}")

    # An explicit override wins; otherwise choose the CoT or the plain budget.
    if max_token_override is not None:
        budget = max_token_override
    elif is_cot:
        budget = cot_tokens
    else:
        budget = tokens
    llm.update_params({tkey: budget})

    if start_at is None:
        start_at = 0
    for i in tqdm(range(start_at, max_points), desc=desc):
        point = dataset[i]
        prediction = dict(point)
        entry = point["Entry"]

        if debug: print(f"Getting prompt for {problem}: entry: {entry}")

        # Hamiltonian is the only builder that also takes the candidate path.
        builder_args = (entry, point["Path"]) if problem == "hamiltonian" else (entry,)
        prompt = build_prompt(*builder_args, use_desc=use_desc, max_exemplars=max_exemplars,
                              shuffle_exemplars=shuffle_exemplars, is_cot=is_cot)

        if debug: print(f"Got prompt: {prompt}")
        actual_response = get_llm_response(llm, prompt, debug=debug)
        if debug: print(f"Got response: {actual_response}")
        prediction["Response"] = actual_response
        predictions.append(prediction)
        if stream:
            # Re-opened per prediction so finished work survives a crash mid-run.
            with open(fname, "a", encoding="utf-8") as f:
                f.write(json.dumps(prediction, ensure_ascii=False) + ",\n")

    return predictions


def call_and_solve_for(problem, suff, dataset, SHOTS=[2, 5, 10, 20, 50, 100], 
                       zero_shot_prompt = None, few_shot_prompt = None, shuffle_exemplars=False,
                       do_ood=False, out_dir_root="", is_cot=False, stream=True):
    """
    Run one full test suite (problem + ID/OOD + every shot count + prompt) through the LLM.
    Restarts are supported by replacing a shot count in SHOTS with a tuple:
        (shot, datapoint_it_broke_at)          for ID runs, or
        (shot, delta, datapoint_it_broke_at)   for OOD runs.
    Plain ints and restart tuples can be mixed in the same list.
    ---
    Params:
    problem (str): task name forwarded to `get_predictions`
    suff (str): suffix used in the output file names
    dataset: ID runs take the list of datapoints; OOD runs read `dataset[0]["delta_<d>"]`
    SHOTS: shot counts (or restart tuples) to evaluate
    zero_shot_prompt / few_shot_prompt: system prompt used when shots == 0 / shots != 0
    shuffle_exemplars, is_cot, stream: forwarded to `get_predictions`
    do_ood (bool): evaluate the four OOD deltas instead of the ID set
    out_dir_root (str): prefix of every output path
    """
    for spec in SHOTS:
        resume_index = None
        resume_delta = None
        n_shots = spec
        if type(spec) == tuple:
            # Restart tuple: last element is the datapoint index, first the shot count.
            resume_index = spec[-1]
            if len(spec) == 3: # OD
                resume_delta = spec[1]
            n_shots = spec[0]

        # Zero-shot and few-shot runs may carry different system prompts.
        chosen_prompt = zero_shot_prompt if n_shots == 0 else few_shot_prompt
        system_prompt = chosen_prompt if chosen_prompt is not None else ""

        if not do_ood:
            id_path = "{}preds_raw_id_{}_shot_{}.json".format(out_dir_root, n_shots, suff)
            id_preds = get_predictions(dataset, MAX_DATA_TO_EVAL, max_exemplars=n_shots, problem=problem,
                                       use_desc=system_prompt, desc="ID {}".format(n_shots), is_cot=is_cot,
                                       start_at=resume_index, shuffle_exemplars=shuffle_exemplars,
                                       stream=stream, fname=id_path)
            # Streaming already wrote every prediction; otherwise dump them now.
            if not stream:
                to_file(id_preds, id_path, dumpall=False)
            continue

        preds_by_delta = {}
        for delta in [0.2, 0.45, 0.65, 0.85]:
            # On a restart, skip the deltas that completed before the failure.
            if resume_delta is not None and delta < resume_delta:
                continue
            # Only the delta that broke resumes mid-way; the others start from 0.
            if resume_delta != delta:
                resume_index = None
            delta_key = "delta_{}".format(str(delta))
            tmp_path = "{}tmp_{}_preds_raw_od_{}_shot_{}.json".format(out_dir_root, str(delta), n_shots, suff)
            preds_by_delta[delta_key] = get_predictions(
                dataset[0][delta_key], MAX_DATA_TO_EVAL, max_exemplars=n_shots, problem=problem,
                use_desc=system_prompt, desc="OD {} | {}".format(n_shots, str(delta)), is_cot=is_cot,
                start_at=resume_index, shuffle_exemplars=shuffle_exemplars, stream=stream, fname=tmp_path)
        to_file(preds_by_delta, "{}preds_raw_od_{}_shot_{}.json".format(out_dir_root, n_shots, suff), dumpall=True)

In [None]:
def apo(initial_prompt, dataset, problem="parity", beam_width=4, search_depth=6, max_token_override=None, debug=False,
        train_offset=2000, minibatch_size=64, max_candidates=8):
    """
    Main loop for the APO algorithm. All are the defaults from the original paper.
    ---
    Params:
    initial_prompt: str: the initial prompt on which to run APO
    dataset: list of entries compatible with `get_predictions`
    problem: task name forwarded to `expand` and `select`
    beam_width: beam size forwarded to `select`. Defaults to 4
    search_depth: int to determine the depth (loops) at which you will run APO.
    max_token_override: token budget forwarded to `expand` (None keeps the notebook default)
    debug: print intermediate state and stop after the first iteration.
    train_offset: only `dataset[train_offset:]` is sampled for the expansion minibatch
    minibatch_size: number of datapoints sampled per iteration (capped at what is available)
    max_candidates: at most this many expanded prompts are handed to `select`
    ---
    Returns: list of (prompt, score) tuples sorted by decreasing score. The
    initial prompt is always present with score 0, because it is never evaluated here.
    """
    b0 = [(initial_prompt, 0)]
    # Our train set is 4k -- model will only see `minibatch_size` of the points past the offset.
    pool = dataset[train_offset:]
    for _ in range(search_depth):
        candidates = []
        # Cap k: random.sample raises ValueError when k exceeds the population size.
        subset = random.sample(pool, k=min(minibatch_size, len(pool)))
        if debug: print(f"b0: {b0}")
        for prompt, _score in b0:
            candidates += expand(prompt, subset, problem, max_token_override=max_token_override, debug=debug)
        # Sampling to avoid overrun
        _cands = random.sample(candidates, k=max_candidates) if len(candidates) > max_candidates else candidates
        if debug: print(f"chosen candidates {_cands}")
        if _cands:
            b0 += select(_cands, dataset, problem, beam_width)
        if debug: break
    b0.sort(key=lambda x: x[-1], reverse=True)
    return b0


def expand(p_candidate, subset, problem, max_token_override=None, max_errors=4, debug=False):
    """
    Expansion step of the beam search: evaluate one prompt on a minibatch and,
    if it makes mistakes, ask for improved successors.
    ---
    Params:
    p_candidate: the candidate prompt, passed to `get_predictions` as `use_desc`
    subset: list of entries compatible with `get_predictions` (the minibatch)
    problem: task name forwarded to `get_predictions`
    max_token_override: token budget forwarded to `gradient_and_edit`
    max_errors: maximum number of mispredicted datapoints handed to the gradient step
    debug: print intermediate state.
    ---
    Returns: list of successor prompts; `[p_candidate]` when no error was found.
    """
    if debug: print("Calling for predictions")
    resps = get_predictions(subset, len(subset), problem=problem, max_exemplars=2, use_desc=p_candidate, debug=debug, desc="APO")
    accuracy = compute_acc(resps, is_apo=True)[0]
    if debug: print(f"Called {len(subset)} predictions and got accuracy {accuracy}")

    errors = []
    for entry in resps:
        response = entry["Response"]
        # Empty / missing responses have no first character to inspect.
        if not response:
            continue
        first_char = response[0]
        # isdecimal (not isnumeric): characters such as "½" are numeric but int() rejects them.
        # Responses that do not start with a digit are skipped, not counted as errors.
        if first_char.isdecimal() and int(first_char) != entry["Label"]:
            errors.append(entry)
    errors = random.sample(errors, k=min(len(errors), max_errors))
    if debug: print(f"Selected {len(errors)}")

    # Minor hack (not in the paper): if no errors, we've converged
    if not errors:
        return [p_candidate]
    return gradient_and_edit(p_candidate, errors, max_token_override=max_token_override, debug=debug)


def gradient_and_edit(p_candidate, errors, max_token_override=None, num_reasons=4, edits_per_gradient=1, num_mc_samples=2, debug=False):
    """
    Run the pseudo gradient.
    All params and prompts are the defaults from the paper.
    ---
    Params:
    p_candidate: the candidate prompt to improve
    errors: list of mispredicted entries; each needs "Entry" and "Label" (plus "Path" for hamiltonian)
    max_token_override: token budget for the calls made here; None uses `max_apo_tokens`
    num_reasons: number of reasons required for the LLM to output as to why this was wrong
    edits_per_gradient: number of edits done for the prompts that were wrong
    num_mc_samples: the samples for the Monte-Carlo bit
    debug: debug.
    ---
    Returns: list of candidate prompts (one per gradient per Monte-Carlo sample);
    `[p_candidate]` when `errors` is empty, the same convention `expand` uses.
    """
    # Defaults from the paper
    def gradient_prompt(p, e, f): 
        resp = f"I'm trying to write a zero-shot classifier prompt.\nMy current prompt is:\n\"{p}\"\nBut this prompt gets the following examples wrong:\n{e}\ngive {f} reasons why the prompt could have gotten these examples wrong.\nWrap each reason with <START> and <END>"
        return resp

    def edition_prompt(p, e, g, n):
        resp = f"I'm trying to write a zero-shot classifier.\n My current prompt is:\n\"{p}\"\nBut it gets the following examples wrong:\n{e}\nBased on these examples the problem with this prompt is that:\n{g}\nBased on the above information, I wrote {n} different improved prompts.\nEach prompt is wrapped with <START> and <END>.\nThe {n} new prompts are:"
        return resp
    
    def mc_prompt(p):
        resp = f"Generate a variation of the following instruction while keeping the semantic meaning.\nInput: {p}\nOutput:"
        return resp

    def extract_tagged(text):
        # Note: the original paper doesn't specify parsing. We need to keep it or else this algorithm won't work.
        return [r.replace("<END>", "").strip() for r in text.split("<START>") if "<END>" in r]

    # Nothing to learn from: avoid indexing errors[0] and spending LLM calls.
    if not errors:
        return [p_candidate]

    llm.update_params({tkey: max_apo_tokens if max_token_override is None else max_token_override})
    try:
        # Direct call to model, postprocess "gradient"
        if "Path" not in errors[0]:
            error_string = "\n - ".join([p["Entry"] + ": " + str(p["Label"]) for p in errors])
        else:
            # Hamiltonian has a sliiiightly different signature
            error_string = "\n - ".join([f"{p['Entry']} | path: {p['Path']} : {p['Label']}" for p in errors])
        prompt = gradient_prompt(p_candidate, error_string, str(num_reasons))
        if debug: print(f"Prompt is \n{prompt}\n for the gradient step")
        response = get_llm_response(llm, prompt)
        if debug: print(f"Got \n{response}\n from the gradient step")

        # Edit the prompt -- one edit per gradient.
        edited_prompts = []
        for g in extract_tagged(response):
            prompt = edition_prompt(p_candidate, error_string, g.strip(), str(edits_per_gradient))
            response = get_llm_response(llm, prompt)
            if "4o" in MODEL or "turbo" in MODEL:
                edited_prompts.append(response.strip())
            else:
                # Other models: keep only the first line of the reply.
                edited_prompts.append(response.strip().split("\n")[0].strip())

        # Do MC search
        # Two candidates per instruction:
        if debug: print(f"Here are the edited prompts:\n{edited_prompts}")
        candidates = []
        for c in edited_prompts:
            # I would like to make this a batch call but I don't want to alter the original work.
            for _ in range(num_mc_samples):
                response1 = get_llm_response(llm, mc_prompt(c))
                # Same as before: no parsing in paper. Prefer the tagged span when present.
                tagged = extract_tagged(response1)
                if tagged:
                    response1 = tagged[0]
                candidates.append(response1)
    finally:
        # Always restore the default budget, even if an LLM call raised mid-way.
        llm.update_params({tkey: tokens})

    return candidates


def select(candidates, dataset, problem, beam_size, B=12):
    """
    Select candidates by successive rejection: each round scores the surviving
    prompts on a fresh random subset and drops the worst one.
    ---
    Params:
    candidates: a list of prompt candidates
    dataset: a list of entries compatible with `get_predictions`
    problem: task name forwarded to `get_predictions`
    beam_size: the size of the beam (bounds the number of rejection rounds)
    B: the query budget used to size each round's subset
    ---
    Returns: a one-element list holding the best (prompt, score) tuple, or [] when
    `candidates` is empty. With fewer than four candidates no round runs, so the
    first candidate is returned unevaluated with score 0.
    """
    # Nothing to select from: the original fell through to S[0] and raised IndexError.
    if not candidates:
        return []

    # Paper states that B in 12-50 keeps it steady.
    # This is a very confusing and not-very-well-written algorithm.
    # We'll implement it verbatim from the paper though.
    S = [(p, 0) for p in candidates]
    old_s = list(S)

    def get_n(i, T):
        # From the below, our maximum number of iterations (n) cannot be larger than T
        # Likewise, B gets too degenerate when close to T (T < B, otherwise you'll just get zeroes)
        left = 1/(0.5*sum([1/j for j in range(2, T + 1)]))
        right = (B-T)/(T + 1 - i)
        return int(np.ceil(left*right))

    # Says 1... n - 1 for n prompts; index shift python
    m = min(len(candidates) - 2, beam_size + 1)
    for i in range(1, m):
        # Cap at the dataset size: random.sample raises if k exceeds the population.
        subset = random.sample(dataset, k=min(get_n(i, m), len(dataset)))
        # Original paper isn't very clear here (i is not defined; We'll evaluate all prompts)
        S_exp = []
        for prompt, _ in S:
            resps = get_predictions(subset, len(subset), problem=problem, max_exemplars=2, use_desc=prompt)
            accuracy = compute_acc(resps, is_apo=True)[0]
            S_exp.append((prompt, accuracy))
        S_exp.sort(key = lambda x: x[-1], reverse=True) # Decreasing
        # Reject the lowest-scoring prompt of this round.
        S = S_exp[:-1]
    # We also need to address this corner case, not addressed in the paper
    if not S:
        S = old_s
    return [S[0]]


# PARITY

## Shared functions

In [None]:
# Each corpus file holds one JSON object per line. Context managers close the
# handles deterministically instead of leaking them until garbage collection.
# NOTE: `corpus_id` is a notebook-level name that the prompt builders read for
# exemplars and that later sections rebind, so re-run this cell before
# re-running any PARITY experiment.
with open("datasets/PARITY/corpus_parity_id_4k_new.json", "r", encoding="utf-8") as f:
    corpus_id = [json.loads(l) for l in f]
with open("datasets/PARITY/corpus_parity_id_4k_new_test.json", "r", encoding="utf-8") as f:
    corpus_id_test = [json.loads(l) for l in f]
with open("datasets/PARITY/corpus_parity_ood_4k_test_new_with_negs.json", "r", encoding="utf-8") as f:
    corpus_od_test_w_negs = [json.loads(l) for l in f]

problem = "parity"
out_dir_root = f"{root_out}/PARITY/"

def get_prompt_for_parity(point, use_desc="", max_exemplars=10, shuffle_exemplars=False, is_cot=False):
    """ Get the system and user prompts for this problem. 
    It takes as input a (raw) datapoint (so no labels), and some extra params to make it shiny.
    ---
    Params:
    point (str): your datapoint
    use_desc (str or list of str, ""): the system prompt / description of the task (optional).
        A list is treated as a vocabulary from which a random "word salad" is sampled.
    max_exemplars (int, 10): number of shots, taken from the head of the notebook-level `corpus_id`
    shuffle_exemplars (False): this shuffles exemplars positionally only. There's another notebook for the other bits.
    is_cot (False): do you need a full CoT exemplar?
    ---
    Returns: list of {"role", "content"} chat messages ending with the user turn for `point`.
    """
    is_text_desc = isinstance(use_desc, str)

    def build_cot_exemplar(ex, lab):
        count = ex.count("0")
        # Parity wording follows the (possibly noisy) label, not count % 2.
        parity_word = "even" if lab == "1" else "odd"
        exemplar = "Let's think and solve this step-by-step. " #"{}: \n".format(ex) <- this is for chat
        if is_text_desc:
            exemplar += "There are {} zeros in the string.\n".format(count)
            exemplar += "{} is an {} number.\n".format(count, parity_word)
        else:
            exemplar += f"{' '.join(random.sample(use_desc, k=2))} {count} {' '.join(random.sample(use_desc, k=4))}.\n"
            # [0]: interpolate the sampled word itself, not the repr of a one-element list.
            exemplar += f"{count} {' '.join(random.sample(use_desc, k=2))} {parity_word} {random.sample(use_desc, k=1)[0]}.\n"
        exemplar += "So the answer is {}".format(lab)
        return exemplar

    prompt = []
    base_prompt = use_desc if is_text_desc else " ".join(random.sample(use_desc, k=random.randint(15, 40)))
    if base_prompt != "":
        prompt.append({"role": "system", "content": base_prompt})

    # The slice is a copy, so shuffling never reorders corpus_id itself.
    exemplars = corpus_id[:max_exemplars]
    if shuffle_exemplars:
        random.shuffle(exemplars)
    # Iterate the (possibly shuffled) copy; indexing corpus_id here discarded the shuffle.
    for exemplar in exemplars:
        entry = exemplar["Entry"]
        label = str(exemplar["Label"])
        prompt.append({"role": "user", "content": entry + ": "})
        answer = build_cot_exemplar(entry, label) if is_cot else label
        prompt.append({"role": "assistant", "content": answer})

    prompt.append({"role": "user", "content": f"{point}: "})
    return prompt

## Modus ponens

In [None]:
# Plain few-shot baseline (no system prompt): in-distribution first, then OOD.
suff = "modus_ponens_{}".format(problem)

call_and_solve_for(
    problem=problem,
    suff=suff,
    dataset=corpus_id_test,
    out_dir_root=out_dir_root,
    shuffle_exemplars=False,
)
call_and_solve_for(
    problem=problem,
    suff=suff,
    dataset=corpus_od_test_w_negs,
    do_ood=True,
    out_dir_root=out_dir_root,
    shuffle_exemplars=False,
)

In [None]:
# Same baseline, with the exemplar order shuffled.
suff = "modus_ponens_{}_shuffled".format(problem)

call_and_solve_for(
    problem=problem,
    suff=suff,
    dataset=corpus_id_test,
    out_dir_root=out_dir_root,
    shuffle_exemplars=True,
)
call_and_solve_for(
    problem=problem,
    suff=suff,
    dataset=corpus_od_test_w_negs,
    do_ood=True,
    out_dir_root=out_dir_root,
    shuffle_exemplars=True,
)

## With description

In [None]:
# Task description shared by the zero-shot and the few-shot system prompts.
parity_desc = "This task is called PARITY. The strings in PARITY are generated from a probabilistic automaton.\n"
parity_desc += "Your job is to learn what is the likelihood of a string to be labeled 0 or 1, and output the correct label.\n"
parity_desc += "In the limit where the automaton is deterministic, if the number of zeros in the input string is even, the label is always 1. Else, it is 0.\n"

# Zero-shot variant: the description plus an answer format and one worked example.
parity_desc_zero_shot = parity_desc + "Give your answer as a single integer, and your reasoning in a new line.\n"
# "000" has an odd number of zeros, so by the rule stated above its label is 0
# (the example previously said 1, contradicting the description).
parity_desc_zero_shot += "For example:\n000:0\nThere are an odd number of zeros in the string.\n\n"

# Few-shot variant: the description followed by the exemplars ("Data").
parity_desc += "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n"
parity_desc += "Data:\n\n"

suff = f"w_description_{problem}"

call_and_solve_for(problem, suff=suff, dataset=corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)
call_and_solve_for(problem, suff=suff, dataset=corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)

suff = f"w_description_{problem}_shuffled"

call_and_solve_for(problem, suff=suff, dataset=corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)
call_and_solve_for(problem, suff=suff, dataset=corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)
#92

## Word Salad

In [None]:
# Word-salad vocabulary: the unique space-separated tokens of the first line of words.txt.
# The file is closed explicitly, and the set is sorted so the vocabulary order
# (and therefore what random.sample draws under a fixed seed) is the same on every run.
with open("datasets/words.txt", "r", encoding="utf-8") as f:
    first_line = f.readline()
parity_desc = sorted(set(l.strip().replace(".", " ") for l in first_line.split(" ")))
parity_desc_zero_shot = parity_desc

suff = f"word_salad_{problem}"

call_and_solve_for(problem, suff=suff, dataset=corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)
call_and_solve_for(problem, suff=suff, dataset=corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)


suff = f"word_salad_{problem}_shuffled"

call_and_solve_for(problem, suff=suff, dataset=corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)
call_and_solve_for(problem, suff=suff, dataset=corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)
#92

## CoT

In [None]:
# Task description shared by the zero-shot and the few-shot CoT system prompts.
parity_desc = "This task is called PARITY. The strings in PARITY are generated from a probabilistic automaton.\n"
parity_desc += "Your job is to learn what is the likelihood of a string to be labeled 0 or 1, and output the correct label.\n"
parity_desc += "In the limit where the automaton is deterministic, if the number of zeros in the input string is even, the label is always 1. Else, it is 0.\n"

parity_desc_zero_shot = parity_desc + "Give  your reasoning in a new line, and your answer as a single integer.\n"
# "000" has an odd number of zeros, so by the rule stated above the answer is 0
# (the worked example previously concluded 1, contradicting the description).
parity_desc_zero_shot += "For example:\n000:\nThere are an odd number of zeros in the string.\nSo the answer is 0.\n\n"

parity_desc += "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n"
parity_desc += "Data:\n\n"

suffix = f"CoT_{problem}"
# get_predictions also selects the CoT budget itself when is_cot=True.
llm.update_params({tkey: cot_tokens})

In [None]:
# Chain-of-thought evaluation with exemplars in their original order: ID, then OOD.
call_and_solve_for(
    problem, suffix, corpus_id_test,
    SHOTS=[0, 5, 10, 20, 50, 100],
    zero_shot_prompt=parity_desc_zero_shot,
    few_shot_prompt=parity_desc,
    shuffle_exemplars=False,
    do_ood=False,
    out_dir_root=out_dir_root,
    is_cot=True,
)

call_and_solve_for(
    problem, suffix, corpus_od_test_w_negs,
    SHOTS=[0, 5, 10, 20, 50, 100],
    zero_shot_prompt=parity_desc_zero_shot,
    few_shot_prompt=parity_desc,
    shuffle_exemplars=False,
    do_ood=True,
    out_dir_root=out_dir_root,
    is_cot=True,
)

In [None]:
# Chain-of-thought evaluation again, this time with shuffled exemplars.
suffix = "CoT_{}_shuffled".format(problem)

call_and_solve_for(
    problem, suffix, corpus_id_test,
    SHOTS=[0, 5, 10, 20, 50, 100],
    zero_shot_prompt=parity_desc_zero_shot,
    few_shot_prompt=parity_desc,
    shuffle_exemplars=True,
    do_ood=False,
    out_dir_root=out_dir_root,
    is_cot=True,
)

call_and_solve_for(
    problem, suffix, corpus_od_test_w_negs,
    SHOTS=[0, 5, 10, 20, 50, 100],
    zero_shot_prompt=parity_desc_zero_shot,
    few_shot_prompt=parity_desc,
    shuffle_exemplars=True,
    do_ood=True,
    out_dir_root=out_dir_root,
    is_cot=True,
)  #54

## Salad-of-Thought

In [None]:
# Salad-of-Thought: CoT-shaped exemplars whose wording is sampled from a word list.
# The file is closed explicitly, and the set is sorted so the vocabulary order
# (and therefore what random.sample draws under a fixed seed) is the same on every run.
with open("datasets/words.txt", "r", encoding="utf-8") as f:
    first_line = f.readline()
parity_desc = sorted(set(l.strip().replace(".", " ") for l in first_line.split(" ")))
parity_desc_zero_shot = parity_desc


suffix = f"SoT_{problem}"
llm.update_params({tkey: cot_tokens})

call_and_solve_for(problem, suffix, corpus_id_test, SHOTS=[0, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=True)

call_and_solve_for(problem, suffix, corpus_od_test_w_negs, SHOTS=[0, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=True)

In [None]:
# Salad-of-Thought with shuffled exemplars. The vocabulary is reloaded so this
# cell does not depend on the previous one having been run.
# The file is closed explicitly, and the set is sorted so the vocabulary order
# (and therefore what random.sample draws under a fixed seed) is the same on every run.
with open("datasets/words.txt", "r", encoding="utf-8") as f:
    first_line = f.readline()
parity_desc = sorted(set(l.strip().replace(".", " ") for l in first_line.split(" ")))
parity_desc_zero_shot = parity_desc


suffix = f"SoT_{problem}_shuffled"
llm.update_params({tkey: cot_tokens})

call_and_solve_for(problem, suffix, corpus_id_test, SHOTS=[0, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=True)

call_and_solve_for(problem, suffix, corpus_od_test_w_negs, SHOTS=[0, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=True) #689

## Automata encoded

In [None]:
# Source code of the generating automaton, embedded verbatim in the system prompt.
ecm_str_desc = "def ecm(P, is_neg=False, max_length = None):\n    # P should be either P or Q (test)\n    p00, p01, p0f = P[0]\n    p10, p11, p1f = P[1]\n\n    state = \"S0\"\n    end_state = \"S0\" if is_neg else \"S1\"\n    tape = \"0\" \n    steps = 0\n\n    while True:\n        x = pull()\n        #print(state, end_state, x, tape)\n        if state == end_state:\n            # Assume p0/1f is lower than the other two\n            if state == \"S0\":\n                if x <= p0f:\n                    break\n                else:\n                    if x <= p00:\n                        state = \"S0\"\n                        tape += \"1\"\n                    else:\n                        state = \"S1\"\n                        tape += \"0\"\n\n            else:\n                if x <= p1f:\n                    break\n                else:\n                    if x <= p11:\n                        state = \"S1\"\n                        tape += \"1\"\n                    else:\n                        state = \"S0\"\n                        tape += \"0\"\n        else:\n            # Only worry about probs\n            if state == \"S1\":\n                if x <= p10:\n                    state = \"S0\"\n                    tape += \"0\"\n                else:\n                    state = \"S1\"\n                    tape += \"1\"\n            else:\n                if x <= p00:\n                    state = \"S0\"\n                    tape += \"1\"\n                else:\n                    state = \"S1\"\n                    tape += \"0\"\n        steps += 1\n        if max_length is not None:\n            if steps >= max_length:\n                break\n    return tape"

parity_desc = "This task is called PARITY. The strings in PARITY are generated from a probabilistic automaton shown below\n"
parity_desc += "Your job is to learn what is the likelihood of a string to be labeled 0 or 1, and output the correct label. You can execute the automaton to decide what is the right label.\n"
parity_desc += "In the limit where the automaton is deterministic, if the number of zeros in the input string is even, the label is always 1. Else, it is 0.\n"

parity_desc += "Here's the automaton:\n\n"
parity_desc += ecm_str_desc
parity_desc += "\n\n"

parity_desc_zero_shot = parity_desc + "Give your answer as a single integer, and your reasoning in a new line.\n"
# "000" has an odd number of zeros, so by the rule stated above its label is 0
# (the example previously said 1, contradicting the description).
parity_desc_zero_shot += "For example:\n000:0\nThere are an odd number of zeros in the string.\n\n"

parity_desc += "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n"
parity_desc += "Data:\n\n"


suffix = f"w_automaton_{problem}"

# Back to the plain (label-only) budget after the CoT sections.
llm.update_params({tkey: tokens})

call_and_solve_for(problem, suffix, corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)

call_and_solve_for(problem, suffix, corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)
#988

suffix = f"w_automaton_{problem}_shuffled"

call_and_solve_for(problem, suffix, corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)

call_and_solve_for(problem, suffix, corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)
#988

## APO

In [None]:
# Seed prompt for APO: the few-shot task description, assembled in one expression.
parity_desc = (
    "This task is called PARITY. The strings in PARITY are generated from a probabilistic automaton.\n"
    "Your job is to learn what is the likelihood of a string to be labeled 0 or 1, and output the correct label.\n"
    "In the limit where the automaton is deterministic, if the number of zeros in the input string is even, the label is always 1. Else, it is 0.\n"
    "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n"
    "Data:\n\n"
)

# Use training for optimisation; the result is sorted by decreasing score.
p_hat_candidates = apo(parity_desc, corpus_id)
p_hat = p_hat_candidates[0][0]
score = p_hat_candidates[0][-1]

# Persist the winning prompt with its provenance so the evaluation cells can reload it.
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "w", encoding="utf-8") as f:
    json.dump(
        {
            "Prompt": p_hat,
            "Score": score,
            "InitialPrompt": parity_desc,
            "OtherCandidates": p_hat_candidates,
            "Problem": problem,
        },
        f,
        ensure_ascii=False,
    )

In [None]:
# Reload the optimised prompt from disk so this cell also works on a fresh kernel.
# The record written by the optimisation cell is a single JSON line; the context
# manager closes the handle instead of leaking it.
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "r", encoding="utf-8") as f:
    p_hat = json.loads(f.readline())["Prompt"]
suffix = f"apo_{problem}"

# APO is a zero-shot problem, but we'll humour ourselves and try it out with shots anyway.
call_and_solve_for(problem, suffix, corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)

call_and_solve_for(problem, suffix, corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)

In [None]:
# APO prompt with shuffled exemplars; relies on `p_hat` from the APO evaluation cell.
suffix = "apo_{}_shuffled".format(problem)

call_and_solve_for(
    problem, suffix, corpus_id_test,
    SHOTS=[0, 2, 5, 10, 20, 50, 100],
    shuffle_exemplars=True,
    zero_shot_prompt=p_hat,
    few_shot_prompt=p_hat,
    do_ood=False,
    out_dir_root=out_dir_root,
    is_cot=False,
)

call_and_solve_for(
    problem, suffix, corpus_od_test_w_negs,
    SHOTS=[0, 2, 5, 10, 20, 50, 100],
    shuffle_exemplars=True,
    zero_shot_prompt=p_hat,
    few_shot_prompt=p_hat,
    do_ood=True,
    out_dir_root=out_dir_root,
    is_cot=False,
)

# Pattern Matching Automata

## Shared functions

In [None]:
# Each corpus file holds one JSON object per line. Context managers close the
# handles deterministically instead of leaking them until garbage collection.
# NOTE: this rebinds the notebook-level `corpus_id`, `corpus_id_test`,
# `corpus_od_test_w_negs`, `problem` and `out_dir_root` used by the PARITY
# section, so PARITY cells must not be re-run after this one without reloading.
with open("datasets/Pattern_Matching/corpus_pattern_matching_id_4k_new.json", "r", encoding="utf-8") as f:
    corpus_id = [json.loads(l) for l in f]
with open("datasets/Pattern_Matching/corpus_pattern_matching_id_4k_new_test.json", "r", encoding="utf-8") as f:
    corpus_id_test = [json.loads(l) for l in f]
with open("datasets/Pattern_Matching/corpus_pattern_matching_ood_4k_new_test.json", "r", encoding="utf-8") as f:
    corpus_od_test_w_negs = [json.loads(l) for l in f]

problem = "pattern_matching"
out_dir_root = f"{root_out}/Pattern_Matching/"

def get_prompt_for_pattern_matching(point, use_desc="", max_exemplars=10, shuffle_exemplars=False, is_cot=False):
    """ Build the chat prompt (system, exemplars, query) for the pattern matching problem.
    It takes as input a (raw) datapoint (so no labels), and some extra params to make it shiny.
    Exemplars are the first `max_exemplars` records of the module-level `corpus_id`.
    ---
    Params:
    point (str): your datapoint
    use_desc (str or list of str, ""): the system prompt / description of the task (optional).
        When it is not a string it is treated as a vocabulary: the system prompt and the
        CoT narration are replaced by randomly sampled words ("word salad").
    max_exemplars (int, 10): number of shots
    shuffle_exemplars (False): this shuffles exemplars positionally only. There's another notebook for the other bits.
    is_cot (False): do you need a full CoT exemplar?
    ---
    Returns:
    list of {"role": ..., "content": ...} dicts: optional system message, one
    user/assistant pair per exemplar, and the final user message for `point`.
    """
    pattern = "abcabb"
    plain_text = isinstance(use_desc, str)

    def salad(k):
        """Return k distinct vocabulary words joined by spaces (word-salad mode only)."""
        return " ".join(random.sample(use_desc, k=k))

    def get_match(c, tally):
        """Tell whether character `c` extends the partial match held in `tally`."""
        if tally == "":
            return pattern[0] == c
        elif pattern == tally:
            return True
        elif len(tally) > len(pattern):
            return False
        else:
            return pattern[len(tally)] == c

    def build_cot_exemplar(ex, lab):
        """Narrate a character-by-character scan of `ex` that ends with the answer `lab`."""
        # The conclusion follows the given label rather than the scan result.
        has_pattern = lab == "1"
        exemplar = "\n{}: \n".format(ex)
        exemplar += "Let's think and solve this step-by-step. "
        if plain_text:
            exemplar += "We read the string character-by-character and keep a tally:\n"
        else:
            exemplar += f"{salad(11)}:\n"
        tally = ""
        if len(ex) >= len(pattern):
            # NOTE(review): on a mismatch the tally is simply cleared; the current
            # character is not re-tested as a fresh start. An occurrence overlapping
            # a failed partial match (e.g. "aabcabb") is therefore never found by
            # this narration, even though the final sentence follows the label.
            for i, c in enumerate(ex):
                is_match = get_match(c, tally)
                match_str = "a match" if is_match else "not a match"
                if plain_text:
                    exemplar += "We read \"{}\". It is {}. ".format(c, match_str)
                else:
                    exemplar += f"{salad(2)} \"{c}\". {salad(2)} {match_str}. "
                if is_match:
                    tally += c
                    if plain_text:
                        exemplar += "Our tally is: {}. ".format(tally)
                    else:
                        exemplar += f"{salad(3)}: {tally}. "
                    if tally == pattern and lab == "1":
                        if plain_text:
                            exemplar += "Our tally matches the pattern. "
                        else:
                            exemplar += f"{salad(5)}. "
                        break
                else:
                    if tally != "":
                        if plain_text:
                            exemplar += "We clear our tally. "
                        else:
                            exemplar += f"{salad(4)}. "
                    tally = ""
                if plain_text:
                    exemplar += "Now we move to the next character.\n"
                else:
                    exemplar += f"{salad(7)}.\n"
            # `i` is the index of the last character read (the loop runs at least once here).
            if lab == "0" or i == len(ex) - 1:
                if plain_text:
                    exemplar += "We have reached the end of the string. "
                else:
                    exemplar += f"{salad(8)}. "
        else:
            if plain_text:
                exemplar += "The length of the example is too short. "
            else:
                exemplar += f"{salad(8)}. "

        if plain_text:
            # Fixed: the negative case used to render as "is is not in the string".
            exemplar += "The pattern \"{}\" is {}.\n".format(pattern, "in the string" if has_pattern else "not in the string")
        else:
            pat_in_str = "in the string" if has_pattern else "is not in the string"
            exemplar += f"{salad(2)} \"{pattern}\" {random.choice(use_desc)} {pat_in_str}.\n"
        exemplar += "So the answer is {}".format(lab)
        return exemplar

    prompt = []
    base_prompt = use_desc if plain_text else " ".join(random.sample(use_desc, k=random.randint(15, 40)))
    if base_prompt != "":
        prompt.append({"role": "system", "content": base_prompt})

    # The slice is a copy, so shuffling never reorders `corpus_id` itself.
    exemplars = corpus_id[:max_exemplars]
    if shuffle_exemplars:
        random.shuffle(exemplars)
    # Fixed: iterate over the (possibly shuffled) exemplars. Indexing `corpus_id[i]`
    # here silently discarded the shuffle.
    for record in exemplars:
        entry = record["Entry"]
        label = str(record["Label"])
        prompt.append({"role": "user", "content": entry + ": "})
        answer = build_cot_exemplar(entry, label) if is_cot else label
        prompt.append({"role": "assistant", "content": answer})

    prompt.append({"role": "user", "content": f"{point}: "})
    return prompt

## Modus Ponens

In [None]:
suffix = f"modus_ponens_{problem}"

# No prompt or shot arguments are passed, so call_and_solve_for's defaults apply.
# The second run adds do_ood=True for the second test set.
for _extra in ({"dataset": corpus_id_test}, {"dataset": corpus_od_test_w_negs, "do_ood": True}):
    call_and_solve_for(problem=problem, suff=suffix, out_dir_root=out_dir_root, shuffle_exemplars=False, **_extra)

In [None]:
suffix = f"modus_ponens_{problem}_shuffled"

# Same two runs as the unshuffled variant, with shuffle_exemplars=True.
for _extra in ({"dataset": corpus_id_test}, {"dataset": corpus_od_test_w_negs, "do_ood": True}):
    call_and_solve_for(problem=problem, suff=suffix, out_dir_root=out_dir_root, shuffle_exemplars=True, **_extra)

## With description

In [None]:
# Task description shared by the zero-shot and few-shot system prompts.
parity_desc = "".join([
    "This is a pattern matching task. The strings in this task are generated from a probabilistic automaton.\n",
    "Each string is labelled 0 or 1 depending on whether the pattern \"abcabb\" is (1) or is not (0) in the string.\n",
    "Your job is to learn what is the likelihood of a string to be labelled 0 or 1, and output the correct label.\n",
    "In the limit where the automaton is deterministic, if the pattern is present in the string, the label is always 1. Else, it is 0.\n",
])

# Zero-shot variant: shared description plus an answer-format example.
parity_desc_zero_shot = "".join([
    parity_desc,
    "Give your answer as a single integer, and your reasoning in a new line.\n",
    "For example:\naabcabbb:1\nThe pattern is present in the string.\n\n",
])

# Few-shot variant: shared description plus the lead-in to the exemplars.
parity_desc = "".join([
    parity_desc,
    "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n",
    "Data:\n\n",
])

# Run order: unshuffled (first test set, then second), then the same pair shuffled.
for _shuffled in (False, True):
    suff = f"w_description_{problem}" + ("_shuffled" if _shuffled else "")
    for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
        call_and_solve_for(problem, suff, _dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                           zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                           shuffle_exemplars=_shuffled,
                           do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=False)

## Word salad

In [None]:
# Word-salad vocabulary: the distinct space-separated tokens on the first line of
# words.txt. The file is closed explicitly, and the tokens are sorted because set
# iteration order for strings can differ between interpreter runs, which would
# make seeded random.sample draws irreproducible.
with open("datasets/words.txt", "r", encoding="utf-8") as f:
    parity_desc = sorted({l.strip().replace(".", " ") for l in f.readline().split(" ")})
parity_desc_zero_shot = parity_desc

# Run order: unshuffled (first test set, then second), then the same pair shuffled.
for _shuffled in (False, True):
    suff = f"word_salad_{problem}" + ("_shuffled" if _shuffled else "")
    for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
        call_and_solve_for(problem, suff=suff, dataset=_dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                           zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                           shuffle_exemplars=_shuffled,
                           do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=False)


## CoT

In [None]:
# Task description shared by the zero-shot and few-shot system prompts.
parity_desc = "".join([
    "This is a pattern matching task. The strings in this task are generated from a probabilistic automaton.\n",
    "Each string is labelled 0 or 1 depending on whether the pattern \"abcabb\" is (1) or is not (0) in the string.\n",
    "Your job is to learn what is the likelihood of a string to be labelled 0 or 1, and output the correct label.\n",
    "In the limit where the automaton is deterministic, if the pattern is present in the string, the label is always 1. Else, it is 0.\n",
])

# Zero-shot variant: shared description plus an answer-format example.
parity_desc_zero_shot = "".join([
    parity_desc,
    "Give your answer as a single integer, and your reasoning in a new line.\n",
    "For example:\naabcabbb:1\nThe pattern is present in the string.\n\n",
])

# Few-shot variant: shared description plus the lead-in to the exemplars.
parity_desc = "".join([
    parity_desc,
    "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n",
    "Data:\n\n",
])

# Switch the client to the chain-of-thought token budget.
llm.update_params({tkey: cot_tokens})

In [None]:
suffix = f"CoT_{problem}"

# Chain-of-thought runs, exemplars in corpus order. Same call for both test sets.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=False,
                       do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=True)

In [None]:
suffix = f"CoT_{problem}_shuffled"

# Chain-of-thought runs with exemplar shuffling. Same call for both test sets.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=True,
                       do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=True)

## Salad-of-Thought

In [None]:
# Word-salad vocabulary: the distinct space-separated tokens on the first line of
# words.txt. The file is closed explicitly, and the tokens are sorted because set
# iteration order for strings can differ between interpreter runs, which would
# make seeded random.sample draws irreproducible.
with open("datasets/words.txt", "r", encoding="utf-8") as f:
    parity_desc = sorted({l.strip().replace(".", " ") for l in f.readline().split(" ")})
parity_desc_zero_shot = parity_desc

suffix = f"SoT_{problem}"
# Use the chain-of-thought token budget for these runs.
llm.update_params({tkey: cot_tokens})

# Same call for both test sets; only the dataset and the do_ood flag differ.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=False,
                       do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=True)

In [None]:
# Word-salad vocabulary: the distinct space-separated tokens on the first line of
# words.txt. The file is closed explicitly, and the tokens are sorted because set
# iteration order for strings can differ between interpreter runs, which would
# make seeded random.sample draws irreproducible.
with open("datasets/words.txt", "r", encoding="utf-8") as f:
    parity_desc = sorted({l.strip().replace(".", " ") for l in f.readline().split(" ")})
parity_desc_zero_shot = parity_desc

suffix = f"SoT_{problem}_shuffled"
# Use the chain-of-thought token budget for these runs.
llm.update_params({tkey: cot_tokens})

# Same call for both test sets; only the dataset and the do_ood flag differ.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=True,
                       do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=True)

## Automata

In [None]:
# Source text of the generating automaton, embedded verbatim in the prompt.
# The backslashes in the set-builder notation are written as "\\in" / "\\exists":
# "\i" and "\e" are invalid escape sequences in a normal string literal and
# trigger a warning at compile time. The resulting text is unchanged.
# Trailing spaces inside some lines are part of the original text and are kept.
ecm_str_desc = (
    'def pattern_matching(P, is_neg=False, max_length = 24_000):\n'
    '    """\n'
    '    Automaton generating strings for the pattern matching problem:\n'
    '    L = {w \\in {a, b, c}* : \\exists x, y \\in {a, b, c}* (w = x abcabb y)}.\n'
    '    """\n'
    '    # Transition matrix for the automaton\n'
    '    # T[i][j] means what you write on the tape when in Si and transitioning to Sj\n'
    "    # When there's a comma, random choices are needed.\n"
    '    # S0, # S1, # S2, # S3, # S4, # S5, # S6, # Sf\n'
    '    T = [["b,c", "a",  None, None, None, None, None],   # S0\n'
    '         ["c",   "a",  "b",  None, None, None, None],   # S1\n'
    '         ["b",   "a",  None, "c",  None, None, None],    # S2\n'
    '         ["b,c", None, None, None, "a",  None, None],   # S3\n'
    '         ["c",   "a",  None, None, None, "b",  None],   # S4\n'
    '         [None,  "a",  None, "c",  None, None, "b"],    # S5\n'
    '         [None,  None, None, None, None, None, "a,b,c"]]# S6\n'
    '    if is_neg:\n'
    '        T = [T[6], T[0], T[1], T[2], T[3], T[4], T[5]]\n'
    '\n'
    '    state = "S0"\n'
    '    end_states = ["S0"] if is_neg else ["S6"]\n'
    '    tape = "" # Start with the empty tape \n'
    '    steps = 0\n'
    '\n'
    '    # Transitions _to_ the state in question. \n'
    '    # - weights are handled by P\n'
    '    arr = [0, 1, 2, 3, 4, 5, 6, "f"] \n'
    '\n'
    '    while True:\n'
    '        ix = int(state[-1])\n'
    '        x = random.choices(arr, weights=P[ix], k=1)[0]\n'
    '        if x == "f":\n'
    '            break\n'
    '\n'
    '        next_state = "S" + str(x)\n'
    '        write_to = T[ix][x]\n'
    '        if "," in write_to:\n'
    '            write_to = random.choice(write_to.split(","))\n'
    '        tape += write_to\n'
    '        state = next_state\n'
    '        \n'
    '        steps += 1\n'
    '        if max_length is not None:\n'
    '            if steps >= max_length:\n'
    '                return None\n'
    '    if tape == "":\n'
    '        return None\n'
    '    return tape'
)

# Task description followed by the automaton's source code (ecm_str_desc).
parity_desc = "".join([
    "This is a pattern matching task. The strings in this task are generated from a probabilistic automaton.\n",
    "Each string is labelled 0 or 1 depending on whether the pattern \"abcabb\" is (1) or is not (0) in the string.\n",
    "Your job is to learn what is the likelihood of a string to be labelled 0 or 1, and output the correct label.\n",
    "In the limit where the automaton is deterministic, if the pattern is present in the string, the label is always 1. Else, it is 0.\n",
    "Here's the automaton:\n\n",
    ecm_str_desc,
    "\n\n",
])

# Zero-shot variant: shared description plus an answer-format example.
parity_desc_zero_shot = "".join([
    parity_desc,
    "Give your answer as a single integer, and your reasoning in a new line.\n",
    "For example:\naabcabbb:1\nThe pattern is present in the string.\n\n",
])

# Few-shot variant: shared description plus the lead-in to the exemplars.
parity_desc = "".join([
    parity_desc,
    "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n",
    "Data:\n\n",
])

# Back to the default token budget.
llm.update_params({tkey: tokens})

# Run order: unshuffled (first test set, then second), then the same pair shuffled.
for _shuffled in (False, True):
    suff = f"w_automaton_{problem}" + ("_shuffled" if _shuffled else "")
    for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
        call_and_solve_for(problem, suff, _dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                           zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                           shuffle_exemplars=_shuffled,
                           do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=False)

## APO

In [None]:
# Initial prompt handed to the prompt optimiser.
parity_desc = "".join([
    "This is a pattern matching task. The strings in this task are generated from a probabilistic automaton.\n",
    "Each string is labelled 0 or 1 depending on whether the pattern \"abcabb\" is (1) or is not (0) in the string.\n",
    "Your job is to learn what is the likelihood of a string to be labelled 0 or 1, and output the correct label.\n",
    "In the limit where the automaton is deterministic, if the pattern is present in the string, the label is always 1. Else, it is 0.\n",
    "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n",
    "Data:\n\n",
])

# Use training for optimisation
p_hat_candidates = apo(parity_desc, corpus_id, problem=problem)
# The first candidate is kept: its first field is the prompt, its last the score.
p_hat, score = p_hat_candidates[0][0], p_hat_candidates[0][-1]

# Persist the chosen prompt as a single JSON line, together with every candidate.
apo_record = {
    "Prompt": p_hat,
    "Score": score,
    "InitialPrompt": parity_desc,
    "OtherCandidates": p_hat_candidates,
    "Problem": problem,
}
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "w", encoding="utf-8") as f:
    f.write(json.dumps(apo_record, ensure_ascii=False))

In [None]:
# APO is a zero-shot problem, but we'll humour ourselves and try it out with shots anyway.
# Load the optimised prompt; the context manager closes the file handle
# instead of leaving it to the garbage collector.
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "r", encoding="utf-8") as f:
    p_hat = [json.loads(l) for l in f][0]["Prompt"]
llm.update_params({tkey: tokens})

# Run order: unshuffled (first test set, then second), then the same pair shuffled.
# shuffle_exemplars is only passed for the shuffled runs, so the unshuffled runs
# keep relying on call_and_solve_for's default.
for _extra, _tag in (({}, ""), ({"shuffle_exemplars": True}, "_shuffled")):
    suff = f"apo_{problem}" + _tag
    for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
        call_and_solve_for(problem, suff, _dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                           zero_shot_prompt=p_hat, few_shot_prompt=p_hat,
                           do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=False, **_extra)


# Vending Machine

## Shared Functions

In [None]:
# Each corpus file holds one JSON object per line. Reading inside `with`
# guarantees the handles are closed as soon as the lists are built.
with open("datasets/Vending_Machine/corpus_vending_machine_id_4k.json", "r", encoding="utf-8") as f:
    corpus_id = [json.loads(l) for l in f]
with open("datasets/Vending_Machine/corpus_vending_machine_id_4k_test.json", "r", encoding="utf-8") as f:
    corpus_id_test = [json.loads(l) for l in f]
with open("datasets/Vending_Machine/corpus_vending_machine_ood_4k_test.json", "r", encoding="utf-8") as f:
    corpus_od_test_w_negs = [json.loads(l) for l in f]

out_dir_root = f"{root_out}/Vending_Machine/"

def assert_string(string):
    """Replay a comma-separated vending-machine sequence and report its balance.

    Tokens starting with "+" add credit; any other non-empty token must be one
    of "biscuit", "coffee" or "soda" and subtracts its price.

    Returns:
        (balance, True) when every purchase could be afforded, otherwise
        (b, False) where b is chosen at random from the balances seen so far
        (the initial 0 plus the balance after each credit).
    """
    prices = {"biscuit": 20, "coffee": 15, "soda": 25}
    running = 0
    seen_balances = [0]
    # Empty tokens (e.g. from a trailing comma) are ignored.
    for token in filter(None, string.split(",")):
        if token.startswith("+"):
            running += int(token[1:])
            seen_balances.append(running)
            continue
        remaining = running - prices[token]
        if remaining < 0:
            return random.choice(seen_balances), False
        running = remaining
    return running, True

def get_prompt_for_vending_machine(point, use_desc="", max_exemplars=10, shuffle_exemplars=False, is_cot=False):
    '''Build the chat prompt (system, exemplars, query) for the vending machine problem.
    It takes as input a (raw) datapoint (so no labels), and some extra params to make it shiny.
    Exemplars are the first `max_exemplars` records of the module-level `corpus_id`.
    ---
    Params:
    point (str): your datapoint
    use_desc (str or list of str, ""): the system prompt / description of the task (optional).
        When it is not a string it is treated as a vocabulary: the system prompt and most
        of the CoT narration are replaced by randomly sampled words ("word salad").
    max_exemplars (int, 10): number of shots
    shuffle_exemplars (False): this shuffles exemplars positionally only. There's another notebook for the other bits.
    is_cot (False): do you need a full CoT exemplar?
    ---
    Returns:
    list of {"role": ..., "content": ...} dicts: optional system message, one
    user/assistant pair per exemplar, and the final user message for `point`.
    '''
    vals = {"biscuit": 20,
            "coffee": 15,
            "soda": 25}
    plain_text = isinstance(use_desc, str)

    def salad(k):
        """Return k distinct vocabulary words joined by spaces (word-salad mode only)."""
        return " ".join(random.sample(use_desc, k=k))

    def build_cot_exemplar(ex, lab):
        """Narrate the running balance for `ex`, then conclude with the label `lab`."""
        exemplar = f"{ex}: \n"
        balance = 0

        if plain_text:
            exemplar += "Let's think and solve this step-by-step. We start with a balance of 0.\n"
        else:
            exemplar += f"{salad(6)}. {salad(7)}.\n"
        # The last comma-separated field is the balance claimed by the entry, not an action.
        for e in ex.split(",")[:-1]:
            if e == "":
                continue
            if plain_text:
                exemplar += f"We read \"{e}\","
            else:
                exemplar += f"{salad(2)} \"{e}\","
            if e[0] == "+":
                balance += int(e[1:])
                if plain_text:
                    exemplar += f" so we add {e[1:]} to our current balance and we now have {balance}.\n"
                else:
                    exemplar += f" {salad(3)} {e[1:]} {salad(8)} {balance}.\n"
            else:
                # The price is always subtracted, so the narrated balance may go negative.
                balance -= vals[e]
                if plain_text:
                    exemplar += f" so we return a {e} and substract {vals[e]} from our balance and now we have {balance}.\n"
                else:
                    exemplar += f" {salad(3)} {e} {salad(2)} {vals[e]} {salad(7)} {balance}.\n"

        final_balance = ex.split(",")[-1]
        if plain_text:
            base = f"The machine's balance is {final_balance}"
        else:
            base = f"{salad(4)} {final_balance}"
        exemplar += f"Our final balance is {balance}. {base}. the answer is then {lab}"
        return exemplar

    prompt = []
    base_prompt = use_desc if plain_text else " ".join(random.sample(use_desc, k=random.randint(15, 40)))
    if base_prompt != "":
        prompt.append({"role": "system", "content": base_prompt})

    # The slice is a copy, so shuffling never reorders `corpus_id` itself.
    exemplars = corpus_id[:max_exemplars]
    if shuffle_exemplars:
        random.shuffle(exemplars)
    # Fixed: iterate over the (possibly shuffled) exemplars. Indexing `corpus_id[i]`
    # here silently discarded the shuffle.
    for record in exemplars:
        entry = record["Entry"]
        label = str(record["Label"])
        prompt.append({"role": "user", "content": entry + ": "})
        answer = build_cot_exemplar(entry, label) if is_cot else label
        prompt.append({"role": "assistant", "content": answer})

    prompt.append({"role": "user", "content": f"{point}: "})
    return prompt


def get_prompt_for_vending_machine_with_sum(point, use_desc="", max_exemplars=10, shuffle_exemplars=False, is_cot=False):
    '''
    Turn vending machine into an arithmetic problem.
    Note: we ignore the label on this one and precompute the correct label.
    Exemplars are the first `max_exemplars` records of the module-level `corpus_id`,
    with the last comma-separated field of each entry removed.
    ---
    Params:
    point (str): your datapoint
    use_desc (str or list of str, ""): the system prompt / description of the task (optional).
        When it is not a string it is treated as a vocabulary of words to sample from.
    max_exemplars (int, 10): number of shots
    shuffle_exemplars (False): this shuffles exemplars positionally only. There's another notebook for the other bits.
    is_cot (False): do you need a full CoT exemplar?
    ---
    Returns:
    list of {"role": ..., "content": ...} dicts: optional system message, one
    user/assistant pair per exemplar, and the final user message for `point`.
    '''
    vals = {"biscuit": 20,
            "coffee": 15,
            "soda": 25}
    plain_text = isinstance(use_desc, str)

    def salad(k):
        """Return k distinct vocabulary words joined by spaces (word-salad mode only)."""
        return " ".join(random.sample(use_desc, k=k))

    def get_label_for(en):
        """Return, as a string, total credit minus total price for the sequence `en`."""
        balance = 0
        for e in en.split(","):
            if e == "":
                continue
            if e[0] == "+":
                balance += int(e[1:])
            else:
                balance -= vals[e]
        return str(balance)

    def build_cot_exemplar(ex, lab):
        """Narrate the running balance for `ex`; the answer is the narrated balance (`lab` is unused)."""
        exemplar = f"{ex}: \n"
        balance = 0
        exemplar += "Let's think and solve this step-by-step. We start with a balance of 0.\n"
        for e in ex.split(","):
            if e == "":
                continue
            if plain_text:
                exemplar += f"We read \"{e}\","
            else:
                exemplar += f"{salad(2)} \"{e}\","
            if e[0] == "+":
                balance += int(e[1:])
                if plain_text:
                    exemplar += f" so we add {e[1:]} to our current balance and we now have {balance}.\n"
                else:
                    exemplar += f" {salad(3)} {e[1:]} {salad(8)} {balance}.\n"
            else:
                # The price is always subtracted, so the narrated balance may go negative.
                balance -= vals[e]
                if plain_text:
                    exemplar += f" so we return a {e} and substract {vals[e]} from our balance and now we have {balance}.\n"
                else:
                    exemplar += f" {salad(3)} {e} {salad(2)} {vals[e]} {salad(6)} {balance}.\n"

        exemplar += f"Our final balance is {balance}. The answer is then {balance}"
        return exemplar

    prompt = []
    base_prompt = use_desc if plain_text else " ".join(random.sample(use_desc, k=random.randint(15, 40)))
    if base_prompt != "":
        prompt.append({"role": "system", "content": base_prompt})

    # The slice is a copy, so shuffling never reorders `corpus_id` itself.
    exemplars = corpus_id[:max_exemplars]
    if shuffle_exemplars:
        random.shuffle(exemplars)
    # Fixed: iterate over the (possibly shuffled) exemplars. Indexing `corpus_id[i]`
    # here silently discarded the shuffle.
    for record in exemplars:
        # Drop the trailing field so the model has to compute the balance itself.
        entry = ",".join(record["Entry"].split(",")[:-1])
        prompt.append({"role": "user", "content": f"{entry}:"})
        answer = get_label_for(entry)
        if is_cot:
            answer = build_cot_exemplar(entry, answer)
        prompt.append({"role": "assistant", "content": answer})

    prompt.append({"role": "user", "content": f"{point}: "})
    return prompt


## Modus Ponens

In [None]:
problem = "vending_machine"

# No prompt or shot arguments are passed, so call_and_solve_for's defaults apply.
# Run order: unshuffled (first test set, then second), then the same pair shuffled.
for _shuffled in (False, True):
    suff = "modus_ponens_vending_machine" + ("_shuffled" if _shuffled else "")
    for _extra in ({"dataset": corpus_id_test}, {"dataset": corpus_od_test_w_negs, "do_ood": True}):
        call_and_solve_for(problem=problem, suff=suff, out_dir_root=out_dir_root,
                           shuffle_exemplars=_shuffled, **_extra)

In [None]:
problem = "vending_machine_with_sum"

# No prompt or shot arguments are passed, so call_and_solve_for's defaults apply.
# Run order: unshuffled (first test set, then second), then the same pair shuffled.
for _shuffled in (False, True):
    suff = "modus_ponens_vending_machine_sum" + ("_shuffled" if _shuffled else "")
    for _extra in ({"dataset": corpus_id_test}, {"dataset": corpus_od_test_w_negs, "do_ood": True}):
        call_and_solve_for(problem=problem, suff=suff, out_dir_root=out_dir_root,
                           shuffle_exemplars=_shuffled, **_extra)

## With description

In [None]:
# Task description shared by the zero-shot and few-shot system prompts.
parity_desc = "".join([
    "This is a vending machine. You are given a sequence of additions of balance (+10, +5, etc) or a selection (soda, biscuit, or coffee). The last number is always the remaining balance.\n",
    "Your job is to output 1 if the remaining balance is correct given the sequence, or 0 if it is not.\n",
    "Each soda is worth 25. Each biscuit is 20. Each coffee is 15. When someone selects a soda, biscuit, or coffee, the value of the item is subtracted from the balance.\n",
])

# Zero-shot variant: shared description plus an answer-format example.
parity_desc_zero_shot = "".join([
    parity_desc,
    "Give your answer as a single label (0 or 1), and your reasoning in a new line.\n",
    "For example:\n+10,+10,biscuit,0:1\nThe sequence is correct.\n\n",
])

# Few-shot variant: shared description plus the lead-in to the exemplars.
parity_desc = "".join([
    parity_desc,
    "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n",
    "Data:\n\n",
])

problem = "vending_machine"

# Run order: unshuffled (first test set, then second), then the same pair shuffled.
for _shuffled in (False, True):
    suff = "w_desc_vending_machine" + ("_shuffled" if _shuffled else "")
    for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
        call_and_solve_for(problem, suff, dataset=_dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                           zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                           shuffle_exemplars=_shuffled,
                           do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=False)

In [None]:
# Task description shared by the zero-shot and few-shot system prompts.
parity_desc = "".join([
    "You are a vending machine. You are given a sequence of additions of balance (+10, +5, etc) or a selection (soda, biscuit, or coffee).\n",
    "Your job is to output the remaining balance given the sequence.\n",
    "Each soda is worth 25. Each biscuit is 20. Each coffee is 15. When someone selects a soda, biscuit, or coffee, the value of the item is subtracted from the balance.\n",
])

# Zero-shot variant: shared description plus an answer-format example.
parity_desc_zero_shot = "".join([
    parity_desc,
    "Give your answer as a single integer, and your reasoning in a new line.\n",
    "For example:\n+10,+10,biscuit:0\nWe had a balance of 20, and biscuits are worth 20, so we are left with zero.\n\n",
])

# Few-shot variant: shared description plus the lead-in to the exemplars.
parity_desc = "".join([
    parity_desc,
    "Given the data below, determine what is the most likely balance for the given string and output ONLY the balance.\n",
    "Data:\n\n",
])

problem = "vending_machine_with_sum"

# Run order: unshuffled (first test set, then second), then the same pair shuffled.
for _shuffled in (False, True):
    suff = "w_desc_vending_machine_sum" + ("_shuffled" if _shuffled else "")
    for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
        call_and_solve_for(problem, suff, _dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                           zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                           shuffle_exemplars=_shuffled,
                           do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=False)

## Word Salad

In [None]:
# Word-salad "description": the pool of unique words found on the first line of the words file.
# The file is read inside a `with` block so the handle is closed (the original leaked it).
with open("datasets/words.txt", "r", encoding="utf-8") as _words_file:
    _first_line = _words_file.readline()
# sorted() gives the pool a stable order. Iterating a set of strings changes order between
# interpreter sessions (hash randomisation), which made seeded random.sample draws irreproducible.
parity_desc = sorted({w.strip().replace(".", " ") for w in _first_line.split(" ")})
parity_desc_zero_shot = parity_desc  # same pool is used for the zero-shot prompt
problem = "vending_machine"

# Ordered exemplars first, then shuffled; each on corpus_id_test and then on the OOD corpus.
for suff, _shuffled in (
    (f"word_salad_{problem}", False),
    (f"word_salad_{problem}_shuffled", True),
):
    for _eval_set, _ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
        call_and_solve_for(
            problem, suff, dataset=_eval_set,
            SHOTS=[0, 2, 5, 10, 20, 50, 100],
            zero_shot_prompt=parity_desc_zero_shot,
            few_shot_prompt=parity_desc,
            shuffle_exemplars=_shuffled,
            do_ood=_ood,
            out_dir_root=out_dir_root,
            is_cot=False,
        )

In [None]:
# Word-salad "description": the pool of unique words found on the first line of the words file.
# The file is read inside a `with` block so the handle is closed (the original leaked it).
with open("datasets/words.txt", "r", encoding="utf-8") as _words_file:
    _first_line = _words_file.readline()
# sorted() gives the pool a stable order. Iterating a set of strings changes order between
# interpreter sessions (hash randomisation), which made seeded random.sample draws irreproducible.
parity_desc = sorted({w.strip().replace(".", " ") for w in _first_line.split(" ")})
parity_desc_zero_shot = parity_desc  # same pool is used for the zero-shot prompt
problem = "vending_machine_with_sum"

# Ordered exemplars first, then shuffled; each on corpus_id_test and then on the OOD corpus.
for suff, _shuffled in (
    (f"word_salad_{problem}", False),
    (f"word_salad_{problem}_shuffled", True),
):
    for _eval_set, _ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
        call_and_solve_for(
            problem, suff, _eval_set,
            SHOTS=[0, 2, 5, 10, 20, 50, 100],
            zero_shot_prompt=parity_desc_zero_shot,
            few_shot_prompt=parity_desc,
            shuffle_exemplars=_shuffled,
            do_ood=_ood,
            out_dir_root=out_dir_root,
            is_cot=False,
        )

## CoT

In [None]:
# Shared task description for the label-verification variant (is the trailing balance correct?).
parity_desc = (
    "This is a vending machine. You are given a sequence of additions of balance (+10, +5, etc) or a selection (soda, biscuit, or coffee). The last number is always the remaining balance.\n"
    "Your job is to output 1 if the remaining balance is correct given the sequence, or 0 if it is not.\n"
    "Each soda is worth 25. Each biscuit is 20. Each coffee is 15. When someone selects a soda, biscuit, or coffee, the value of the item is subtracted from the balance.\n"
)

# Zero-shot prompt: description + answer format + one worked example.
parity_desc_zero_shot = (
    parity_desc
    + "Give your answer as a single label (0 or 1), and your reasoning in a new line.\n"
    + "For example:\n+10,+10,biscuit,0:1\nThe sequence is correct.\n\n"
)

# Few-shot prompt: description + instruction; the exemplars follow the "Data:" header.
parity_desc = parity_desc + (
    "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n"
    "Data:\n\n"
)

problem = "vending_machine"

# Switch the client to the CoT token budget before the CoT cells run.
llm.update_params({tkey: cot_tokens})


In [None]:
# CoT runs, exemplars in corpus order. `problem` comes from the setup cell above.
suff = f"CoT_{problem}"

# corpus_id_test (do_ood=False) first, then corpus_od_test_w_negs (do_ood=True).
for _eval_set, _ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(
        problem, suff, _eval_set,
        SHOTS=[0, 5, 10, 20, 50, 100],
        zero_shot_prompt=parity_desc_zero_shot,
        few_shot_prompt=parity_desc,
        shuffle_exemplars=False,
        do_ood=_ood,
        out_dir_root=out_dir_root,
        is_cot=True,
    )

In [None]:
# CoT runs with exemplar order shuffled. `problem` comes from the setup cell above.
suff = f"CoT_{problem}_shuffled"

# The trailing numbers are the original author's markers (708 on the suffix line,
# 965 on the OOD call); their meaning is not documented in the notebook.
for _eval_set, _ood in (
    (corpus_id_test, False),        # 708
    (corpus_od_test_w_negs, True),  # 965
):
    call_and_solve_for(
        problem, suff, _eval_set,
        SHOTS=[0, 5, 10, 20, 50, 100],
        zero_shot_prompt=parity_desc_zero_shot,
        few_shot_prompt=parity_desc,
        shuffle_exemplars=True,
        do_ood=_ood,
        out_dir_root=out_dir_root,
        is_cot=True,
    )

In [None]:
# Shared task description for the balance-prediction ("with_sum") variant.
parity_desc = (
    "You are a vending machine. You are given a sequence of additions of balance (+10, +5, etc) or a selection (soda, biscuit, or coffee).\n"
    "Your job is to output the remaining balance given the sequence.\n"
    "Each soda is worth 25. Each biscuit is 20. Each coffee is 15. When someone selects a soda, biscuit, or coffee, the value of the item is subtracted from the balance.\n"
)

# Zero-shot prompt: description + answer format + one worked example.
parity_desc_zero_shot = (
    parity_desc
    + "Give your answer as a single integer, and your reasoning in a new line.\n"
    + "For example:\n+10,+10,biscuit:0\nWe had a balance of 20, and biscuits are worth 20, so we are left with zero.\n\n"
)

# Few-shot prompt: description + instruction; the exemplars follow the "Data:" header.
parity_desc = parity_desc + (
    "Given the data below, determine what is the most likely balance for the given string and output ONLY the balance.\n"
    "Data:\n\n"
)

# Switch the client to the CoT token budget before the CoT cells run.
llm.update_params({tkey: cot_tokens})

problem = "vending_machine_with_sum"

In [None]:
# CoT runs, exemplars in corpus order. `problem` comes from the setup cell above.
suff = f"CoT_{problem}"

# The trailing numbers are the original author's markers; their meaning is not documented.
for _eval_set, _ood in (
    (corpus_id_test, False),        # 346
    (corpus_od_test_w_negs, True),  # 177
):
    call_and_solve_for(
        problem, suff, _eval_set,
        SHOTS=[0, 5, 10, 20, 50, 100],
        zero_shot_prompt=parity_desc_zero_shot,
        few_shot_prompt=parity_desc,
        shuffle_exemplars=False,
        do_ood=_ood,
        out_dir_root=out_dir_root,
        is_cot=True,
    )

In [None]:
# CoT runs with exemplar order shuffled. `problem` comes from the setup cell above.
suff = f"CoT_{problem}_shuffled"

# The trailing numbers are the original author's markers; their meaning is not documented.
for _eval_set, _ood in (
    (corpus_id_test, False),        # 346
    (corpus_od_test_w_negs, True),  # 891
):
    call_and_solve_for(
        problem, suff, _eval_set,
        SHOTS=[0, 5, 10, 20, 50, 100],
        zero_shot_prompt=parity_desc_zero_shot,
        few_shot_prompt=parity_desc,
        shuffle_exemplars=True,
        do_ood=_ood,
        out_dir_root=out_dir_root,
        is_cot=True,
    )

## SoT

In [None]:
# Word pool for the SoT runs: unique words found on the first line of the words file.
# The file is read inside a `with` block so the handle is closed (the original leaked it).
with open("datasets/words.txt", "r", encoding="utf-8") as _words_file:
    _first_line = _words_file.readline()
# sorted() gives the pool a stable order. Iterating a set of strings changes order between
# interpreter sessions (hash randomisation), which made seeded random.sample draws irreproducible.
parity_desc = sorted({w.strip().replace(".", " ") for w in _first_line.split(" ")})
parity_desc_zero_shot = parity_desc  # same pool is used for the zero-shot prompt
problem = "vending_machine"

# These runs pass is_cot=True, so use the CoT token budget.
llm.update_params({tkey: cot_tokens})

In [None]:
# SoT runs, exemplars in corpus order. `problem` comes from the setup cell above.
suff = f"SoT_{problem}"

# corpus_id_test (do_ood=False) first, then corpus_od_test_w_negs (do_ood=True).
for _eval_set, _ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(
        problem, suff, _eval_set,
        SHOTS=[0, 5, 10, 20, 50, 100],
        zero_shot_prompt=parity_desc_zero_shot,
        few_shot_prompt=parity_desc,
        shuffle_exemplars=False,
        do_ood=_ood,
        out_dir_root=out_dir_root,
        is_cot=True,
    )

In [None]:
# SoT runs with exemplar order shuffled. `problem` comes from the setup cell above.
suff = f"SoT_{problem}_shuffled"

# corpus_id_test (do_ood=False) first, then corpus_od_test_w_negs (do_ood=True).
for _eval_set, _ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(
        problem, suff, _eval_set,
        SHOTS=[0, 5, 10, 20, 50, 100],
        zero_shot_prompt=parity_desc_zero_shot,
        few_shot_prompt=parity_desc,
        shuffle_exemplars=True,
        do_ood=_ood,
        out_dir_root=out_dir_root,
        is_cot=True,
    )

In [None]:
# Word pool for the SoT runs: unique words found on the first line of the words file.
# The file is read inside a `with` block so the handle is closed (the original leaked it).
with open("datasets/words.txt", "r", encoding="utf-8") as _words_file:
    _first_line = _words_file.readline()
# sorted() gives the pool a stable order. Iterating a set of strings changes order between
# interpreter sessions (hash randomisation), which made seeded random.sample draws irreproducible.
parity_desc = sorted({w.strip().replace(".", " ") for w in _first_line.split(" ")})
parity_desc_zero_shot = parity_desc  # same pool is used for the zero-shot prompt

# These runs pass is_cot=True, so use the CoT token budget.
llm.update_params({tkey: cot_tokens})

problem = "vending_machine_with_sum"

In [None]:
# SoT runs, exemplars in corpus order. `problem` comes from the setup cell above.
suff = f"SoT_{problem}"

# Both evaluation corpora must use shuffle_exemplars=False here. The OOD call used to pass
# True (copy-paste slip), which stored shuffled results under the non-shuffled suffix and
# duplicated the shuffled cell that follows.
# The trailing numbers are the original author's markers; their meaning is not documented.
for _eval_set, _ood in (
    (corpus_id_test, False),        # 346
    (corpus_od_test_w_negs, True),  # 891
):
    call_and_solve_for(
        problem, suff, _eval_set,
        SHOTS=[0, 5, 10, 20, 50, 100],
        zero_shot_prompt=parity_desc_zero_shot,
        few_shot_prompt=parity_desc,
        shuffle_exemplars=False,
        do_ood=_ood,
        out_dir_root=out_dir_root,
        is_cot=True,
    )

In [None]:
# SoT runs with exemplar order shuffled. `problem` comes from the setup cell above.
suff = f"SoT_{problem}_shuffled"

# The trailing number is the original author's marker; its meaning is not documented.
for _eval_set, _ood in (
    (corpus_id_test, False),        # 346
    (corpus_od_test_w_negs, True),
):
    call_and_solve_for(
        problem, suff, _eval_set,
        SHOTS=[0, 5, 10, 20, 50, 100],
        zero_shot_prompt=parity_desc_zero_shot,
        few_shot_prompt=parity_desc,
        shuffle_exemplars=True,
        do_ood=_ood,
        out_dir_root=out_dir_root,
        is_cot=True,
    )


## Automata

In [None]:
ecm_str_desc = "def vending_machine(P, is_neg=False, max_length = 24_000, debug=False):\n    \"\"\"\n    Automaton generating strings for the vending machine.\n    One coffee costs 15. One soda costs 25. One biscuit costs 20.\n    You have 5 and 10 coins.\n    This ends with a sequence of items and change (press \"change\" to end)\n    This vending machine continues the transactions until someone clicks \"end\".\n    \"\"\"\n    # Transition matrix for the automaton\n    T = [[None, \"+10\", None,        \"+5\", None, None,      \"change\"], # S0 = 0\n         [None, None, \"+10\",        None, \"+5\", None,      \"change\"], # S1 = 10\n         [\"biscuit\", None, \"coffee\", None, None, \"+5\",      \"change\"], # S2 = 20\n         [None, None, None,         None, \"+10\", None,     \"change\"], # S3 = 5\n         [\"coffee\", None,  None,    None, None, \"+10\",     \"change\"], # S4 = 15\n         [\"soda\", \"coffee\", None,   \"biscuit\", None, None,  \"change\"], # S5 = 25\n         [None, None, None,         None, None, None,      \"change\"], # Sf = change\n        ]\n\n    state = \"S0\"\n    end_states = [\"S0\"] if is_neg else [\"S6\"]\n    tape = \"\" # Start with the empty tape \n    steps = 0\n\n    # Transitions _to_ the state in question. 
\n    # - weights are handled by P\n    arr = [0, 1, 2, 3, 4, 5, \"f\"] \n    terminate = False\n\n    while True:\n        ix = int(state[-1])\n        if not terminate:\n            x = random.choices(arr, weights=P[ix], k=1)[0]\n        else:\n            x = \"f\"\n        if x == \"f\":\n            break\n\n        next_state = \"S\" + str(x)\n        if debug:\n            print(f\"from S{ix} to {next_state}\")\n        write_to = T[ix][x]\n        if \",\" in write_to:\n            write_to = random.choice(write_to.split(\",\"))\n        tape += write_to + \",\"\n        if debug:\n            print(tape)\n        state = next_state\n        \n        steps += 1\n        if max_length is not None:\n            if steps >= max_length:\n                terminate = True\n    if tape == \"\":\n        return None\n    return tape"

# Label-verification description with the generator's source code (ecm_str_desc, defined
# just above in this cell) embedded.
parity_desc = (
    "This is a vending machine. You are given a sequence of additions of balance (+10, +5, etc) or a selection (soda, biscuit, or coffee). The last number is always the remaining balance.\n"
    "Your job is to output 1 if the remaining balance is correct given the sequence, or 0 if it is not.\n"
    "Each soda is worth 25. Each biscuit is 20. Each coffee is 15. When someone selects a soda, biscuit, or coffee, the value of the item is subtracted from the balance.\n"
    "Here's the code for the vending machine:\n\n"
) + ecm_str_desc + "\n\n"

# Zero-shot prompt: description + code + answer format + one worked example.
parity_desc_zero_shot = (
    parity_desc
    + "Give your answer as a single label (0 or 1), and your reasoning in a new line.\n"
    + "For example:\n+10,+10,biscuit,0:1\nThe sequence is correct.\n\n"
)

# Few-shot prompt: description + code + instruction; the exemplars follow the "Data:" header.
parity_desc = parity_desc + (
    "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n"
    "Data:\n\n"
)

# Back to the default (non-CoT) token budget for these runs.
llm.update_params({tkey: tokens})
suff = "w_automaton_vending_machine"
problem = "vending_machine"

In [None]:
# Automaton-description runs; `problem`, `suff` and the prompts come from the setup cell above.
# shuffle_exemplars is deliberately not passed, so call_and_solve_for's default applies.
for _eval_set, _ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(
        problem, suff, _eval_set,
        SHOTS=[0, 2, 5, 10, 20, 50, 100],
        zero_shot_prompt=parity_desc_zero_shot,
        few_shot_prompt=parity_desc,
        do_ood=_ood,
        out_dir_root=out_dir_root,
        is_cot=False,
    )

In [None]:
# Automaton-description runs with exemplar order shuffled.
suff = "w_automaton_vending_machine_shuffled"

# corpus_id_test (do_ood=False) first, then corpus_od_test_w_negs (do_ood=True).
for _eval_set, _ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(
        problem, suff, _eval_set,
        SHOTS=[0, 2, 5, 10, 20, 50, 100],
        zero_shot_prompt=parity_desc_zero_shot,
        few_shot_prompt=parity_desc,
        shuffle_exemplars=True,
        do_ood=_ood,
        out_dir_root=out_dir_root,
        is_cot=False,
    )

In [None]:
ecm_str_desc = "def vending_machine(P, is_neg=False, max_length = 24_000, debug=False):\n    \"\"\"\n    Automaton generating strings for the vending machine.\n    One coffee costs 15. One soda costs 25. One biscuit costs 20.\n    You have 5 and 10 coins.\n    This ends with a sequence of items and change (press \"change\" to end)\n    This vending machine continues the transactions until someone clicks \"end\".\n    \"\"\"\n    # Transition matrix for the automaton\n    T = [[None, \"+10\", None,        \"+5\", None, None,      \"change\"], # S0 = 0\n         [None, None, \"+10\",        None, \"+5\", None,      \"change\"], # S1 = 10\n         [\"biscuit\", None, \"coffee\", None, None, \"+5\",      \"change\"], # S2 = 20\n         [None, None, None,         None, \"+10\", None,     \"change\"], # S3 = 5\n         [\"coffee\", None,  None,    None, None, \"+10\",     \"change\"], # S4 = 15\n         [\"soda\", \"coffee\", None,   \"biscuit\", None, None,  \"change\"], # S5 = 25\n         [None, None, None,         None, None, None,      \"change\"], # Sf = change\n        ]\n\n    state = \"S0\"\n    end_states = [\"S0\"] if is_neg else [\"S6\"]\n    tape = \"\" # Start with the empty tape \n    steps = 0\n\n    # Transitions _to_ the state in question. 
\n    # - weights are handled by P\n    arr = [0, 1, 2, 3, 4, 5, \"f\"] \n    terminate = False\n\n    while True:\n        ix = int(state[-1])\n        if not terminate:\n            x = random.choices(arr, weights=P[ix], k=1)[0]\n        else:\n            x = \"f\"\n        if x == \"f\":\n            break\n\n        next_state = \"S\" + str(x)\n        if debug:\n            print(f\"from S{ix} to {next_state}\")\n        write_to = T[ix][x]\n        if \",\" in write_to:\n            write_to = random.choice(write_to.split(\",\"))\n        tape += write_to + \",\"\n        if debug:\n            print(tape)\n        state = next_state\n        \n        steps += 1\n        if max_length is not None:\n            if steps >= max_length:\n                terminate = True\n    if tape == \"\":\n        return None\n    return tape"

# Balance-prediction description with the generator's source code (ecm_str_desc, defined
# just above in this cell) embedded.
parity_desc = (
    "You are a vending machine. You are given a sequence of additions of balance (+10, +5, etc) or a selection (soda, biscuit, or coffee).\n"
    "Your job is to output the remaining balance given the sequence.\n"
    "Each soda is worth 25. Each biscuit is 20. Each coffee is 15. When someone selects a soda, biscuit, or coffee, the value of the item is subtracted from the balance.\n"
    "Here's the code for the vending machine:\n\n"
) + ecm_str_desc + "\n\n"

# Zero-shot prompt: description + code + answer format + one worked example.
parity_desc_zero_shot = (
    parity_desc
    + "Give your answer as a single integer, and your reasoning in a new line.\n"
    + "For example:\n+10,+10,biscuit:0\nWe had a balance of 20, and biscuits are worth 20, so we are left with zero.\n\n"
)

# Few-shot prompt: description + code + instruction; the exemplars follow the "Data:" header.
parity_desc = parity_desc + (
    "Given the data below, determine what is the most likely balance for the given string and output ONLY the balance.\n"
    "Data:\n\n"
)

# Back to the default (non-CoT) token budget for these runs.
llm.update_params({tkey: tokens})
suff = "w_automaton_vending_machine_with_sum"

In [None]:
# Automaton-description runs for the balance-prediction variant.
problem = "vending_machine_with_sum"
suff = "w_automaton_vending_machine_with_sum"

# shuffle_exemplars is deliberately not passed, so call_and_solve_for's default applies.
for _eval_set, _ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(
        problem, suff, _eval_set,
        SHOTS=[0, 2, 5, 10, 20, 50, 100],
        zero_shot_prompt=parity_desc_zero_shot,
        few_shot_prompt=parity_desc,
        do_ood=_ood,
        out_dir_root=out_dir_root,
        is_cot=False,
    )

In [None]:
# Automaton-description runs for the balance-prediction variant, exemplar order shuffled.
problem = "vending_machine_with_sum"
suff = "w_automaton_vending_machine_with_sum_shuffled"

# corpus_id_test (do_ood=False) first, then corpus_od_test_w_negs (do_ood=True).
for _eval_set, _ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(
        problem, suff, _eval_set,
        SHOTS=[0, 2, 5, 10, 20, 50, 100],
        zero_shot_prompt=parity_desc_zero_shot,
        few_shot_prompt=parity_desc,
        shuffle_exemplars=True,
        do_ood=_ood,
        out_dir_root=out_dir_root,
        is_cot=False,
    )

## APO

In [None]:
# Initial prompt handed to APO for the label-verification task.
parity_desc = (
    "This is a vending machine. You are given a sequence of additions of balance (+10, +5, etc) or a selection (soda, biscuit, or coffee). The last number is always the remaining balance.\n"
    "Your job is to output 1 if the remaining balance is correct given the sequence, or 0 if it is not.\n"
    "Each soda is worth 25. Each biscuit is 20. Each coffee is 15. When someone selects a soda, biscuit, or coffee, the value of the item is subtracted from the balance.\n"
    "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n"
    "Data:\n\n"
)

problem = "vending_machine"
suff = f"apo_{problem}"

# Use training for optimisation
p_hat_candidates = apo(parity_desc, corpus_id, problem=problem)
# The first candidate is kept: its first element is the prompt, its last element the score.
p_hat, score = p_hat_candidates[0][0], p_hat_candidates[0][-1]

# Persist the chosen prompt (plus every candidate) as a single JSON line, so the evaluation
# cell can reload it from disk.
_apo_record = {
    "Prompt": p_hat,
    "Score": score,
    "InitialPrompt": parity_desc,
    "OtherCandidates": p_hat_candidates,
    "Problem": problem,
}
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "w", encoding="utf-8") as f:
    f.write(json.dumps(_apo_record, ensure_ascii=False))


In [None]:
problem = "vending_machine"

# Reload the optimised prompt written by the APO cell (a single JSON line). The `with` block
# closes the handle, which the original one-liner leaked.
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "r", encoding="utf-8") as _apo_file:
    p_hat = json.loads(_apo_file.readline())["Prompt"]
llm.update_params({tkey: tokens})

# APO is a zero-shot problem, but we'll humour ourselves and try it out with shots anyway.
# First pass: shuffle_exemplars is not passed (call_and_solve_for's default applies);
# second pass: shuffle_exemplars=True. Each pass covers corpus_id_test, then the OOD corpus.
for suff, _shuffle_kw in (
    (f"apo_{problem}", {}),
    (f"apo_{problem}_shuffled", {"shuffle_exemplars": True}),
):
    for _eval_set, _ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
        call_and_solve_for(
            problem, suff, _eval_set,
            SHOTS=[0, 2, 5, 10, 20, 50, 100],
            zero_shot_prompt=p_hat,
            few_shot_prompt=p_hat,
            do_ood=_ood,
            out_dir_root=out_dir_root,
            is_cot=False,
            **_shuffle_kw,
        )


In [None]:
# Initial prompt handed to APO for the balance-prediction task.
parity_desc = (
    "You are a vending machine. You are given a sequence of additions of balance (+10, +5, etc) or a selection (soda, biscuit, or coffee).\n"
    "Your job is to output the remaining balance given the sequence.\n"
    "Each soda is worth 25. Each biscuit is 20. Each coffee is 15. When someone selects a soda, biscuit, or coffee, the value of the item is subtracted from the balance.\n"
    "Given the data below, determine what is the most likely balance for the given string and output ONLY the balance.\n"
    "Data:\n\n"
)

problem = "vending_machine_with_sum"
suff = f"apo_{problem}"

# Use training for optimisation
p_hat_candidates = apo(parity_desc, corpus_id, problem=problem)
# The first candidate is kept: its first element is the prompt, its last element the score.
p_hat, score = p_hat_candidates[0][0], p_hat_candidates[0][-1]

# Persist the chosen prompt (plus every candidate) as a single JSON line, so the evaluation
# cell can reload it from disk.
_apo_record = {
    "Prompt": p_hat,
    "Score": score,
    "InitialPrompt": parity_desc,
    "OtherCandidates": p_hat_candidates,
    "Problem": problem,
}
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "w", encoding="utf-8") as f:
    f.write(json.dumps(_apo_record, ensure_ascii=False))

In [None]:
# Set `problem` explicitly instead of relying on the value left by the optimisation cell.
# When that (expensive) cell is skipped because the prompt is already on disk, the stale
# `problem`/`suff` from the previous APO section would otherwise load and overwrite the
# wrong experiment.
problem = "vending_machine_with_sum"

# Reload the optimised prompt written by the APO cell (a single JSON line). The `with` block
# closes the handle, which the original one-liner leaked.
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "r", encoding="utf-8") as _apo_file:
    p_hat = json.loads(_apo_file.readline())["Prompt"]
llm.update_params({tkey: tokens})

# APO is a zero-shot problem, but we'll humour ourselves and try it out with shots anyway.
# First pass: shuffle_exemplars is not passed (call_and_solve_for's default applies);
# second pass: shuffle_exemplars=True. Each pass covers corpus_id_test, then the OOD corpus.
for suff, _shuffle_kw in (
    (f"apo_{problem}", {}),
    (f"apo_{problem}_shuffled", {"shuffle_exemplars": True}),
):
    for _eval_set, _ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
        call_and_solve_for(
            problem, suff, _eval_set,
            SHOTS=[0, 2, 5, 10, 20, 50, 100],
            zero_shot_prompt=p_hat,
            few_shot_prompt=p_hat,
            do_ood=_ood,
            out_dir_root=out_dir_root,
            is_cot=False,
            **_shuffle_kw,
        )


# Hamiltonian

## Shared functions

In [None]:
# Hamiltonian corpora are JSON-lines files (one datapoint per line). Each file is read inside
# a `with` block so its handle is closed; the original one-liners leaked all three.
with open("datasets/Hamiltonian/corpus_hamiltonian_id_4k.json", "r", encoding="utf-8") as _corpus_file:
    corpus_id = [json.loads(line) for line in _corpus_file]
with open("datasets/Hamiltonian/corpus_hamiltonian_id_4k_test.json", "r", encoding="utf-8") as _corpus_file:
    corpus_id_test = [json.loads(line) for line in _corpus_file]
with open("datasets/Hamiltonian/corpus_hamiltonian_ood_4k_test.json", "r", encoding="utf-8") as _corpus_file:
    corpus_od_test_w_negs = [json.loads(line) for line in _corpus_file]

out_dir_root = f"{root_out}/Hamiltonian/"
problem = "hamiltonian"

# Example graph used in the zero-shot prompt: vertex -> list of vertices it has an edge to.
sample_graph = {'0': ['2', '3'], '1': ['0', '2', '3'], '2': ['0', '3'], '3': ['0']}

# NOTE(review): the prompt says "adjacency matrix form", but sample_graph (and what the CoT
# builder indexes) is an adjacency-list dict. Wording left untouched so results stay
# comparable with earlier runs; confirm whether it should be corrected.
parity_desc = "You are a theorem verifier. You are given a directed graph in adjacency matrix form, and a path.\n"
parity_desc += "Your job is to output whether the path given is a valid Hamiltonian path (1) or is not (0) for the given graph.\n"
parity_desc += "The root will always be the vertex labelled as \"0\".\n"

# Zero-shot prompt: description + answer format + one worked (negative) example.
parity_desc_zero_shot = parity_desc + "Give your answer as a single value (0 or 1), and your reasoning in a new line.\n"
parity_desc_zero_shot += f"For example:\n{sample_graph} | path: 0,2,3,1 : 0\nThis path is not a Hamiltonian path because there is no path to vertex 1 from 3.\n\n"

# Few-shot prompt: description + instruction; the exemplars follow the "Data:" header.
parity_desc += "Given the data below, determine whether the given path is a Hamiltonian path for the directed graph (0 or 1) and output ONLY the label.\n"
parity_desc += "Data:\n\n"


def get_prompt_for_hamiltonian(point, path=None, use_desc="", max_exemplars=10, shuffle_exemplars=False, is_cot=False):
    """ Get the system and user prompts for this problem.
    It takes as input a (raw) datapoint (so no labels), and some extra params to make it shiny.
    Exemplars are the first `max_exemplars` entries of the module-level `corpus_id`; each entry
    is read through its "Entry" (graph), "Path" and "Label" keys.
    ---
    Params:
    point: your datapoint (the graph); it is formatted into the final user turn as-is
    path (None): the candidate path for `point`, formatted into the final user turn as-is
    use_desc (str or list of str, ""): the system prompt / description of the task (optional).
        A non-str value is treated as a pool of words: the system prompt and the CoT narration
        then become random word samples (the pool must be large enough for random.sample).
    max_exemplars (int, 10): number of shots (capped by the size of `corpus_id`)
    shuffle_exemplars (False): this shuffles exemplars positionally only. There's another notebook for the other bits.
    is_cot (False): do you need a full CoT exemplar?
    ---
    Returns:
    A list of {"role": ..., "content": ...} dicts: an optional system turn, one user/assistant
    pair per exemplar, and the unanswered user turn for `point`.
    """
    plain_text = isinstance(use_desc, str)

    def salad(k):
        """ k distinct words sampled from the word pool, joined by spaces. """
        return " ".join(random.sample(use_desc, k=k))

    def build_cot_exemplar(graph, path, lab):
        """ Narrate a step-by-step check of `path` (comma-separated vertices) against `graph`
        (vertex -> children) and finish with the known answer `lab`. """
        exemplar = f"{graph} | path: {path} : \n"
        exemplar += "Let's think and solve this step-by-step. We start at vertex 0.\n"
        path = path.split(",")
        if plain_text:
            exemplar += f"There are {len(path)} vertices in the path, and {len(graph)} vertices in the graph.\n"
        else:
            exemplar += f"{salad(2)} {len(path)}"
            exemplar += f" {salad(5)} {len(graph)} {salad(4)}\n"

        # A Hamiltonian path visits every vertex exactly once, so the lengths must agree.
        if len(path) != len(graph):
            lenn = "shorter" if len(path) < len(graph) else "longer"
            if plain_text:
                exemplar += f"The path is {lenn}, so it can't be a Hamiltonian path.\n"
            else:
                # Newline added: this branch used to glue the answer onto the salad sentence,
                # unlike every other branch.
                exemplar += f"{salad(3)} {lenn} {salad(7)}.\n"
            exemplar += f"The answer is then {lab}"
            return exemplar

        if plain_text:
            exemplar += "Since the path and graph match, we now traverse it.\n"
        else:
            exemplar += f"{salad(10)}.\n"

        # The start vertex counts as visited. Without it, a path that returns to its first
        # vertex was narrated as "we haven't seen it" and then contradicted by the label.
        seen = {path[0]}
        for idx, vertex in enumerate(path):
            if plain_text:
                exemplar += f"We are at vertex {vertex} in the path. "
            else:
                exemplar += f"{salad(4)} {vertex} {salad(3)}.\n"
            v = graph[str(vertex)]

            p = path[idx + 1] if idx + 1 < len(path) else "the last vertex"
            if plain_text:
                exemplar += f"This vertex has children \"{v}\", and the next vertex in the path is {p}. "
            else:
                exemplar += f"{salad(4)} \"{v}\", {salad(7)} {p}.\n"

            if p == "the last vertex":
                continue

            if p not in graph:
                if plain_text:
                    exemplar += f"The next vertex in the path, {p}, is not in the graph, so this is not a valid path.\nThe answer is then {lab}."
                else:
                    exemplar += f"{salad(6)}, {p}, {salad(12)}.\n"
                    exemplar += f"The answer is then {lab}."
                return exemplar

            if p not in v:
                if plain_text:
                    exemplar += f"The vertex is not in the children set, so this is not a valid path.\nThe answer is then {lab}."
                else:
                    exemplar += f"{salad(8)}, {salad(7)}.\n"
                    exemplar += f"The answer is then {lab}."
                return exemplar
            if plain_text:
                exemplar += "The vertex is in the children set. "
            else:
                exemplar += f"{salad(7)}.\n"

            if p in seen:
                if plain_text:
                    exemplar += f"But we already saw {p}, so this can't be a Hamiltonian path.\nThe answer is then {lab}."
                else:
                    exemplar += f"{salad(4)} {p}, {salad(7)}.\n"
                    exemplar += f"The answer is then {lab}."
                return exemplar
            if plain_text:
                exemplar += "Since we haven't seen it, we continue.\n"
            else:
                exemplar += f"{salad(5)}, {salad(2)}.\n"
            seen.add(p)

        if plain_text:
            exemplar += f"We traversed all nodes without repetition.\nThe answer is then {lab}."
        else:
            exemplar += f"{salad(6)}.\n"
            exemplar += f"The answer is then {lab}."
        return exemplar

    prompt = []
    base_prompt = use_desc if plain_text else salad(random.randint(15, 40))
    if base_prompt != "":
        prompt.append({"role": "system", "content": base_prompt})

    # Slice copy: shuffling it never reorders corpus_id itself. max(..., 0) keeps a negative
    # shot count meaning "no exemplars", as with the original range() loop.
    exemplars = corpus_id[:max(max_exemplars, 0)]
    if shuffle_exemplars:
        random.shuffle(exemplars)
    # Iterate the (possibly shuffled) copy. The original indexed corpus_id[i], which silently
    # ignored the shuffle and raised IndexError when max_exemplars exceeded the corpus size.
    for shot in exemplars:
        shot_graph = shot["Entry"]
        shot_path = shot["Path"]
        shot_label = str(shot["Label"])
        prompt.append({"role": "user", "content": f"{shot_graph} | path: {shot_path}:"})
        answer = build_cot_exemplar(shot_graph, shot_path, shot_label) if is_cot else shot_label
        prompt.append({"role": "assistant", "content": answer})
    user_str = f"{point} | path: {path} :"
    prompt.append({"role": "user", "content": user_str})
    return prompt


## Modus Ponens

In [None]:
# Modus-ponens runs: only the problem, suffix, corpus and output directory are passed, so
# call_and_solve_for's defaults decide everything else. Optional flags are added only
# where the original calls had them.
for suff, _shuffle_kw in (
    (f"modus_ponens_{problem}", {}),
    (f"modus_ponens_{problem}_shuffled", {"shuffle_exemplars": True}),
):
    for _eval_set, _ood_kw in ((corpus_id_test, {}), (corpus_od_test_w_negs, {"do_ood": True})):
        call_and_solve_for(
            problem=problem,
            suff=suff,
            dataset=_eval_set,
            out_dir_root=out_dir_root,
            **_ood_kw,
            **_shuffle_kw,
        )

## With description

In [None]:
# NOTE(review): parity_desc and parity_desc_zero_shot are not assigned in this cell; it
# uses whatever an earlier cell left in them. Confirm they hold the description intended
# for this problem before running (Restart & Run All would otherwise reuse a stale value).
_splits = ((corpus_id_test, False), (corpus_od_test_w_negs, True))
for _tag, _shuffle_kwargs in (("", {}), ("_shuffled", {"shuffle_exemplars": True})):
    suff = f"w_desc_{problem}{_tag}"
    # In-distribution first, then out-of-distribution.
    for _dataset, _is_ood in _splits:
        call_and_solve_for(problem, suff, _dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                           zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                           do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=False,
                           **_shuffle_kwargs)
#866

## Word Salad

In [None]:
# Word-salad control: the "description" is a pool of words rather than a str. The prompt
# builders visible in this notebook treat any non-str use_desc as a pool to sample from
# (presumably call_and_solve_for forwards these prompts as use_desc -- confirm).
# The context manager closes the file; the previous bare open() leaked the handle.
with open("datasets/words.txt", "r", encoding="utf-8") as _words_file:
    _first_line = _words_file.readline()
# sorted() instead of list(): iteration order of a set of strings changes between
# interpreter runs (hash randomisation), which made seeded sampling irreproducible.
parity_desc = sorted({_w.strip().replace(".", " ") for _w in _first_line.split(" ")})
parity_desc_zero_shot = parity_desc

_splits = ((corpus_id_test, False), (corpus_od_test_w_negs, True))
for _shuffled in (False, True):
    suff = f"word_salad_{problem}" + ("_shuffled" if _shuffled else "")
    # In-distribution first, then out-of-distribution.
    for _dataset, _is_ood in _splits:
        call_and_solve_for(problem, suff, _dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                           zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                           shuffle_exemplars=_shuffled, do_ood=_is_ood,
                           out_dir_root=out_dir_root, is_cot=False)

## CoT

In [None]:
# Plain-English task description, rebuilt here because the word-salad section above
# replaced parity_desc with a list of words.
_task_intro = (
    "You are a theorem verifier. You are given a directed graph in adjacency matrix form, and a path.\n"
    "Your job is to output whether the path given is a valid Hamiltonian path (1) or is not (0) for the given graph.\n"
    "The root will always be the vertex labelled as \"0\".\n"
)

# Zero-shot variant: asks for the label plus reasoning and shows one worked example.
parity_desc_zero_shot = _task_intro + (
    "Give your answer as a single value (0 or 1), and your reasoning in a new line.\n"
    f"For example:\n{sample_graph} | path: 0,2,3,1 : 0\n"
    "This path is not a Hamiltonian path because there is no path to vertex 1 from 3.\n\n"
)

# Few-shot variant: label only; the exemplars follow the "Data:" marker.
parity_desc = _task_intro + (
    "Given the data below, determine whether the given path is a Hamiltonian path for the directed graph (0 or 1) and output ONLY the label.\n"
    "Data:\n\n"
)

# Switch the client to the CoT token limit for the cells below.
llm.update_params({tkey: cot_tokens})

In [None]:
suff = f"CoT_{problem}"

# Chain-of-thought exemplars, corpus order: in-distribution first, then OOD.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suff, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=True)


In [None]:
suff = f"CoT_{problem}_shuffled"

# Chain-of-thought exemplars with shuffle_exemplars=True: in-distribution first, then OOD.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suff, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=True, do_ood=_is_ood,
                       out_dir_root=out_dir_root, is_cot=True)


## Salad-of-Thought

In [None]:
# Salad-of-Thought: CoT-style exemplars whose explanatory text is drawn from a word pool.
# The context manager closes the file; the previous bare open() leaked the handle.
with open("datasets/words.txt", "r", encoding="utf-8") as _words_file:
    _first_line = _words_file.readline()
# sorted() instead of list(): iteration order of a set of strings changes between
# interpreter runs (hash randomisation), which made seeded sampling irreproducible.
parity_desc = sorted({_w.strip().replace(".", " ") for _w in _first_line.split(" ")})
parity_desc_zero_shot = parity_desc


suffix = f"SoT_{problem}"
llm.update_params({tkey: cot_tokens})

# In-distribution first, then out-of-distribution.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=False, do_ood=_is_ood,
                       out_dir_root=out_dir_root, is_cot=True)

In [None]:
# Salad-of-Thought with shuffled exemplar order.
# The context manager closes the file; the previous bare open() leaked the handle.
with open("datasets/words.txt", "r", encoding="utf-8") as _words_file:
    _first_line = _words_file.readline()
# sorted() instead of list(): iteration order of a set of strings changes between
# interpreter runs (hash randomisation), which made seeded sampling irreproducible.
parity_desc = sorted({_w.strip().replace(".", " ") for _w in _first_line.split(" ")})
parity_desc_zero_shot = parity_desc


suffix = f"SoT_{problem}_shuffled"
llm.update_params({tkey: cot_tokens})

# In-distribution first, then out-of-distribution.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=True, do_ood=_is_ood,
                       out_dir_root=out_dir_root, is_cot=True)
#689

## Automaton Encoded

In [None]:
# Python source that is pasted verbatim into the "automaton encoded" system prompt.
# One literal per line of the embedded code; whitespace inside the quotes (including
# trailing spaces) is part of the prompt text.
# The two comments in the embedded code were written as "\#", an invalid escape sequence
# that put a stray backslash into the prompt (and made the shown code invalid Python);
# they are plain "#" now.
ecm_str_desc = (
    "def is_acyclic(graph):\n"
    "    in_degree = [0 for _ in graph.keys()]\n"
    "    q = []\n"
    "    for u, edges in graph.items():\n"
    "        for v in edges:\n"
    "            in_degree[int(v)] += 1\n"
    "    for u in graph.keys():\n"
    "        if in_degree[int(u)] == 0:\n"
    "            q.append(u)\n"
    "    # BFS traversal\n"
    "    visited = 0\n"
    "    while q:\n"
    "        u = q.pop(0)\n"
    "        visited += 1\n"
    "        for v in graph[u]:\n"
    "            in_degree[int(v)] -= 1\n"
    "            if in_degree[int(v)] == 0:\n"
    "                q.append(v)\n"
    "\n"
    "    return visited == len(graph.keys())\n"
    "\n"
    "\n"
    "def random_graph(n_vertices, weights=None, fully_connected=False, acyclic=True):\n"
    "    graph = []\n"
    "    vertices = {str(i): [] for i in range(n_vertices)}\n"
    "    these_entities = random.choices(entities, k=n_vertices)\n"
    "\n"
    "    for u in vertices.keys():\n"
    "        if acyclic:\n"
    "            target_sources = [v for v in vertices.keys() if u != v and u not in vertices[v]]\n"
    "        else:\n"
    "            target_sources = [v for v in vertices.keys() if u != v]            \n"
    "        if not fully_connected:\n"
    "            if weights is None:\n"
    "                target_sources = [v for v in target_sources if random.randint(0, 1) == 1]\n"
    "            else:\n"
    "                target_sources = [v for v in target_sources if random.uniform(0, 1) < weights[int(v)]]\n"
    "        vertices[u] = target_sources\n"
    "\n"
    "    if not is_acyclic(vertices):\n"
    "        if acyclic:\n"
    "            return None\n"
    "    \n"
    "    return vertices\n"
    "\n"
    "\n"
    "def find_hamiltonian_cycle(graph):\n"
    "\n"
    "    def is_safe_to_add(v, pos, path): \n"
    "        if str(v) not in graph[str(path[pos-1])]: \n"
    "            return False\n"
    "        return v not in path # if v in path return false\n"
    "\n"
    "    def hamiltonian_cycle_util(path, pos): \n"
    "        if pos == len(graph): \n"
    "            if str(path[0]) in graph[str(path[pos-1])][path[0]]: \n"
    "                return True\n"
    "            else:\n"
    "                return False\n"
    "        for v in range(1, len(graph)): \n"
    "            if is_safe_to_add(v, pos, path): \n"
    "                path[pos] = v \n"
    "                if hamiltonian_cycle_util(path, pos+1):\n"
    "                    return True\n"
    "                path[pos] = -1\n"
    "        return False\n"
    "\n"
    "    path = [-1 for _ in graph.keys()]\n"
    "    path[0] = 0\n"
    "    if not hamiltonian_cycle_util(path, 1): \n"
    "        return None\n"
    "    return [{\"Index\": i, \"Vertex\": p} for i, p in enumerate(path)]\n"
)

# System prompt for the "automaton encoded" runs: the plain task description followed by
# the source code held in ecm_str_desc.
parity_desc = "You are a theorem verifier. You are given a directed graph in adjacency matrix form, and a path, both generated by the code below.\n"
parity_desc += "Your job is to output whether the path given is a valid Hamiltonian path (1) or is not (0) for the given graph.\n"
parity_desc += "The root will always be the vertex labelled as \"0\".\n"
# A sentence about the number of zeros in an input string used to follow here. It describes
# a different (parity) task and contradicts the Hamiltonian-path instructions above, so it
# was removed from the prompt.

parity_desc += "Here's the code:\n\n"
parity_desc += ecm_str_desc
parity_desc += "\n\n"

# Zero-shot variant: asks for the label plus reasoning and shows one worked example.
parity_desc_zero_shot = parity_desc + "Give your answer as a single value (0 or 1), and your reasoning in a new line.\n"
parity_desc_zero_shot += f"For example:\n{sample_graph} | path: 0,2,3,1 : 0\nThis path is not a Hamiltonian path because there is no path to vertex 1 from 3.\n\n"

# Few-shot variant: label only; the exemplars follow the "Data:" marker.
parity_desc += "Given the data below, determine whether the given path is a Hamiltonian path for the directed graph (0 or 1) and output ONLY the label.\n"
parity_desc += "Data:\n\n"

# Back to the short, label-only token limit (earlier sections switch to cot_tokens).
llm.update_params({tkey: tokens})

_splits = ((corpus_id_test, False), (corpus_od_test_w_negs, True))
for _shuffled in (False, True):
    suffix = f"w_automaton_{problem}" + ("_shuffled" if _shuffled else "")
    # In-distribution first, then out-of-distribution.
    for _dataset, _is_ood in _splits:
        call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                           zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                           shuffle_exemplars=_shuffled, do_ood=_is_ood,
                           out_dir_root=out_dir_root, is_cot=False)
#988

## APO

In [None]:
# Plain task description used as the starting prompt for APO.
_task_intro = (
    "You are a theorem verifier. You are given a directed graph in adjacency matrix form, and a path.\n"
    "Your job is to output whether the path given is a valid Hamiltonian path (1) or is not (0) for the given graph.\n"
    "The root will always be the vertex labelled as \"0\".\n"
)

parity_desc_zero_shot = _task_intro + (
    "Give your answer as a single value (0 or 1), and your reasoning in a new line.\n"
    f"For example:\n{sample_graph} | path: 0,2,3,1 : 0\n"
    "This path is not a Hamiltonian path because there is no path to vertex 1 from 3.\n\n"
)

parity_desc = _task_intro + (
    "Given the data below, determine whether the given path is a Hamiltonian path for the directed graph (0 or 1) and output ONLY the label.\n"
    "Data:\n\n"
)

# Use training for optimisation. The first candidate is the one kept: its first element
# is stored as the prompt and its last element as the score.
p_hat_candidates = apo(parity_desc, corpus_id, problem=problem)
_best_candidate = p_hat_candidates[0]
p_hat, score = _best_candidate[0], _best_candidate[-1]

# Persist the chosen prompt (plus everything needed to audit it) as a single JSON line.
_apo_record = {
    "Prompt": p_hat,
    "Score": score,
    "InitialPrompt": parity_desc,
    "OtherCandidates": p_hat_candidates,
    "Problem": problem,
}
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "w", encoding="utf-8") as f:
    f.write(json.dumps(_apo_record, ensure_ascii=False))

In [None]:
suff = f"apo_{problem}"

# Read back the prompt stored in apo_prompt_<problem>.json (first line of the file).
# The context manager closes the file; the previous bare open() leaked the handle.
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "r", encoding="utf-8") as _apo_file:
    p_hat = json.loads(_apo_file.readline())["Prompt"]

# APO is a zero-shot problem, but we'll humour ourselves and try it out with shots anyway.
# In-distribution first, then out-of-distribution.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suff, _dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                       zero_shot_prompt=p_hat, few_shot_prompt=p_hat,
                       do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=False)


In [None]:
suff = f"apo_{problem}_shuffled"

# Read back the prompt stored in apo_prompt_<problem>.json (first line of the file).
# The context manager closes the file; the previous bare open() leaked the handle.
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "r", encoding="utf-8") as _apo_file:
    p_hat = json.loads(_apo_file.readline())["Prompt"]

# APO is a zero-shot problem, but we'll humour ourselves and try it out with shots anyway.
# In-distribution first, then out-of-distribution.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suff, _dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100], shuffle_exemplars=True,
                       zero_shot_prompt=p_hat, few_shot_prompt=p_hat,
                       do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=False)


# Stack

## Shared functions

In [None]:
# Each corpus file holds one JSON object per line. The files are read inside context
# managers so the handles are closed (the previous bare open() calls leaked them).
# corpus_id supplies the exemplars (and the APO training data); the two *_test corpora
# are the in-distribution and out-of-distribution evaluation sets.
with open("datasets/Stack/corpus_stack_id_4k.json", "r", encoding="utf-8") as _corpus_file:
    corpus_id = [json.loads(_line) for _line in _corpus_file]
with open("datasets/Stack/corpus_stack_id_4k_test.json", "r", encoding="utf-8") as _corpus_file:
    corpus_id_test = [json.loads(_line) for _line in _corpus_file]
with open("datasets/Stack/corpus_stack_ood_4k_test.json", "r", encoding="utf-8") as _corpus_file:
    corpus_od_test_w_negs = [json.loads(_line) for _line in _corpus_file]

problem = "stack"
out_dir_root = f"{root_out}/Stack/"


def get_prompt_for_stack(point, use_desc="", max_exemplars=10, shuffle_exemplars=False, is_cot=False):
    """ Get the system and user prompts for this problem, as a list of chat messages.
    It takes as input a (raw) datapoint (so no labels), and some extra params to make it shiny.

    The result is an optional system message, then one user/assistant pair per exemplar
    (taken from the head of the global `corpus_id`), then `point` as the final user message.
    ---
    Params:
    point (str): your datapoint
    use_desc (str or list of str, ""): the system prompt / description of the task (optional).
        A str is used verbatim ("" means no system message). Anything else is treated as a
        pool of words: 15-40 random words form the system message and, with `is_cot`, random
        words replace the explanatory text of every CoT exemplar.
    max_exemplars (int, 10): number of shots (fewer if `corpus_id` is shorter)
    shuffle_exemplars (False): this shuffles exemplars positionally only (the same exemplars,
        in a random order). There's another notebook for the other bits.
    is_cot (False): do you need a full CoT exemplar?

    Returns:
    list of dict: messages of the form {"role": ..., "content": ...}.
    """
    word_salad = not isinstance(use_desc, str)

    def salad():
        """One random word from the pool (only called in word-salad mode)."""
        return " ".join(random.sample(use_desc, k=1))

    def build_cot_exemplar(ex, lab):
        """Replay the operations of entry `ex` step by step and finish with label `lab`.

        `ex` is three newline-separated fields: the initial stack, the comma-separated
        operations, and the claimed final stack.
        """
        initial, sequence, final = ex.split("\n")
        exemplar = "Let's think and solve this step-by-step. "
        this_string = initial
        for state in sequence.split(","):
            shown = this_string if this_string != "" else "empty"
            if word_salad:
                exemplar += f"{salad()} {shown}.\n{salad()}: \"{state}\". "
            else:
                exemplar += f"Our stack is {shown}.\nWe read: \"{state}\". "

            if state == "pop":
                if this_string == "":
                    # Popping an empty stack is a no-op.
                    if word_salad:
                        exemplar += f"{salad()}\n"
                    else:
                        exemplar += "We have no elements in our string, so we ignore it.\n"
                else:
                    popped, remaining = this_string[-1], this_string[:-1]
                    shown_remaining = remaining if remaining != "" else "empty"
                    if word_salad:
                        exemplar += f"{salad()} \"{popped}\" {salad()} {shown_remaining}.\n"
                    else:
                        exemplar += f"We pop \"{popped}\" and our new stack is {shown_remaining}.\n"
                    this_string = remaining
            elif state == "push":
                # The symbol to push arrives as the next item of the sequence.
                if word_salad:
                    exemplar += f"{salad()}\n"
                else:
                    exemplar += "We get ready to push to the stack.\n"
            elif state in ["0", "1"]:
                this_string = this_string + state
                if word_salad:
                    exemplar += f"{salad()} \"{state}\" {salad()} {this_string}.\n"
                else:
                    exemplar += f"We push \"{state}\" to the stack and our new stack is {this_string}.\n"
            elif state == "stop":
                if word_salad:
                    exemplar += f"{salad()}\n"
                else:
                    exemplar += "We terminate.\n"
                break

        if this_string == "":
            this_string = "empty"

        if word_salad:
            exemplar += f"{salad()} {this_string} {salad()} {final}.\n"
        else:
            exemplar += f"Our final stack is {this_string} and the solution says {final}.\n"
        exemplar += "So the answer is {}".format(lab)
        return exemplar

    prompt = []
    base_prompt = " ".join(random.sample(use_desc, k=random.randint(15, 40))) if word_salad else use_desc
    if base_prompt != "":
        prompt.append({"role": "system", "content": base_prompt})

    # The slice is a copy, so shuffling it leaves corpus_id itself untouched.
    exemplars = corpus_id[:max_exemplars]
    if shuffle_exemplars:
        random.shuffle(exemplars)
    # Iterate over the (possibly shuffled) copy. Indexing corpus_id here, as the code
    # used to, silently discarded the shuffle.
    for exemplar in exemplars:
        prompt.append({"role": "user", "content": exemplar["Entry"] + ": "})
        answer = str(exemplar["Label"])
        if is_cot:
            answer = build_cot_exemplar(exemplar["Entry"], answer)
        prompt.append({"role": "assistant", "content": answer})

    prompt.append({"role": "user", "content": f"{point}: "})
    return prompt


## Modus ponens

In [None]:
suff = f"modus_ponens_{problem}"

# Exemplar-only runs in corpus order: in-distribution first, then OOD (the only call
# that is given do_ood).
for _dataset, _ood_kwargs in ((corpus_id_test, {}), (corpus_od_test_w_negs, {"do_ood": True})):
    call_and_solve_for(problem=problem, suff=suff, dataset=_dataset, out_dir_root=out_dir_root,
                       shuffle_exemplars=False, **_ood_kwargs)

In [None]:
suff = f"modus_ponens_{problem}_shuffled"

# Exemplar-only runs with shuffle_exemplars=True: in-distribution first, then OOD (the
# only call that is given do_ood).
for _dataset, _ood_kwargs in ((corpus_id_test, {}), (corpus_od_test_w_negs, {"do_ood": True})):
    call_and_solve_for(problem=problem, suff=suff, dataset=_dataset, out_dir_root=out_dir_root,
                       shuffle_exemplars=True, **_ood_kwargs)

## With description

In [None]:
# Plain-English description of the stack task, shared by both prompt variants.
parity_desc = "This is a stack simulator. You will be given (in three lines) an initial state, a sequence of operations, and a final state.\n"
parity_desc += "Your job is to determine whether the final state is correct given the initial state and a sequence of operations.\n"
parity_desc += "You must always output 0 (incorrect) or 1 (correct).\n"
parity_desc += "Pop operations on an empty stack are ignored.\n"
parity_desc += "Push is always followed by the symbol that is pushed.\nThe only allowable symbols are 0 and 1, and the only allowable operations are push, pop, and stop.\n"

# Zero-shot variant: asks for the label plus reasoning and shows one worked example.
parity_desc_zero_shot = parity_desc + "Give your answer as a single integer, and your reasoning in a new line.\n"
# The example's final state (000) is the right outcome of push 1 / pop on 000, so its
# label must be 1 (correct). It used to read "000: 0", contradicting both the explanation
# that follows it and the 0/1 convention stated above.
parity_desc_zero_shot += "For example:\n000\npush,1,pop,stop,\n000: 1\nThe label is correct because pushing and popping the same element returns the original state, which matches the final state.\n\n"

# Few-shot variant: label only; the exemplars follow the "Data:" marker.
parity_desc += "Given the data below, determine what is the most likely label for the given initial state, sequence of operations, and a final state; and output ONLY the label.\n"
parity_desc += "Data:\n\n"

# Description + exemplars, first in corpus order and then with shuffled exemplars.
_splits = ((corpus_id_test, False), (corpus_od_test_w_negs, True))
for _shuffled in (False, True):
    suff = f"w_description_{problem}_shuffled" if _shuffled else f"w_description_{problem}"
    # In-distribution first, then out-of-distribution.
    for _dataset, _is_ood in _splits:
        call_and_solve_for(problem, suff=suff, dataset=_dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                           zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                           shuffle_exemplars=_shuffled, do_ood=_is_ood,
                           out_dir_root=out_dir_root, is_cot=False)

## Word Salad

In [None]:
# Word-salad control: the "description" is a pool of words rather than a str;
# get_prompt_for_stack samples from any non-str use_desc (presumably call_and_solve_for
# forwards these prompts as use_desc -- confirm).
# The context manager closes the file; the previous bare open() leaked the handle.
with open("datasets/words.txt", "r", encoding="utf-8") as _words_file:
    _first_line = _words_file.readline()
# sorted() instead of list(): iteration order of a set of strings changes between
# interpreter runs (hash randomisation), which made seeded sampling irreproducible.
parity_desc = sorted({_w.strip().replace(".", " ") for _w in _first_line.split(" ")})
parity_desc_zero_shot = parity_desc

_splits = ((corpus_id_test, False), (corpus_od_test_w_negs, True))
for _shuffled in (False, True):
    suff = f"word_salad_{problem}" + ("_shuffled" if _shuffled else "")
    # In-distribution first, then out-of-distribution.
    for _dataset, _is_ood in _splits:
        call_and_solve_for(problem, suff=suff, dataset=_dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                           zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                           shuffle_exemplars=_shuffled, do_ood=_is_ood,
                           out_dir_root=out_dir_root, is_cot=False)
#92

## CoT

In [None]:
# Plain-English description of the stack task, rebuilt here because the word-salad
# section above replaced parity_desc with a list of words.
parity_desc = "This is a stack simulator. You will be given (in three lines) an initial state, a sequence of operations, and a final state.\n"
parity_desc += "Your job is to determine whether the final state is correct given the initial state and a sequence of operations.\n"
parity_desc += "You must always output 0 (incorrect) or 1 (correct).\n"
parity_desc += "Pop operations on an empty stack are ignored.\n"
parity_desc += "Push is always followed by the symbol that is pushed.\nThe only allowable symbols are 0 and 1, and the only allowable operations are push, pop, and stop.\n"

# Zero-shot variant: asks for the label plus reasoning and shows one worked example.
parity_desc_zero_shot = parity_desc + "Give your answer as a single integer, and your reasoning in a new line.\n"
# The example's final state (000) is the right outcome of push 1 / pop on 000, so its
# label must be 1 (correct). It used to read "000: 0", contradicting both the explanation
# that follows it and the 0/1 convention stated above.
parity_desc_zero_shot += "For example:\n000\npush,1,pop,stop,\n000: 1\nThe label is correct because pushing and popping the same element returns the original state, which matches the final state.\n\n"

# Few-shot variant: the exemplars follow the "Data:" marker.
parity_desc += "Given the data below, determine what is the most likely label for the given initial state, sequence of operations, and a final state; and output ONLY the label.\n"
parity_desc += "Data:\n\n"

suffix = f"CoT_{problem}"
# Switch the client to the CoT token limit for the cells below.
llm.update_params({tkey: cot_tokens})

In [None]:
# Uses `suffix`, the prompts and the token limit set by the preceding cell.
# Chain-of-thought exemplars, corpus order: in-distribution first, then OOD.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=False, do_ood=_is_ood,
                       out_dir_root=out_dir_root, is_cot=True)

In [None]:
suffix = f"CoT_{problem}_shuffled"

# Chain-of-thought exemplars with shuffle_exemplars=True: in-distribution first, then OOD.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=True, do_ood=_is_ood,
                       out_dir_root=out_dir_root, is_cot=True)

## Salad-of-Thought

In [None]:
# Salad-of-Thought: CoT-style exemplars whose explanatory text is drawn from a word pool.
# The context manager closes the file; the previous bare open() leaked the handle.
with open("datasets/words.txt", "r", encoding="utf-8") as _words_file:
    _first_line = _words_file.readline()
# sorted() instead of list(): iteration order of a set of strings changes between
# interpreter runs (hash randomisation), which made seeded sampling irreproducible.
parity_desc = sorted({_w.strip().replace(".", " ") for _w in _first_line.split(" ")})
parity_desc_zero_shot = parity_desc


suffix = f"SoT_{problem}"
llm.update_params({tkey: cot_tokens})

# In-distribution first, then out-of-distribution.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=False, do_ood=_is_ood,
                       out_dir_root=out_dir_root, is_cot=True)

In [None]:
# Salad-of-Thought with shuffled exemplar order.
# The context manager closes the file; the previous bare open() leaked the handle.
with open("datasets/words.txt", "r", encoding="utf-8") as _words_file:
    _first_line = _words_file.readline()
# sorted() instead of list(): iteration order of a set of strings changes between
# interpreter runs (hash randomisation), which made seeded sampling irreproducible.
parity_desc = sorted({_w.strip().replace(".", " ") for _w in _first_line.split(" ")})
parity_desc_zero_shot = parity_desc


suffix = f"SoT_{problem}_shuffled"
llm.update_params({tkey: cot_tokens})

# In-distribution first, then out-of-distribution.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=True, do_ood=_is_ood,
                       out_dir_root=out_dir_root, is_cot=True)

## Automata encoded

In [None]:
# Python source of a push/pop stack simulator (`pushpop`), kept as one string so it can be
# pasted verbatim into the system prompt built below ("Here's the code: ...").
# Presumably this is the generator that produced the Stack corpora -- confirm.
ecm_str_desc = "def pushpop(k, min_seq=None, this_tape=None, probs = [0.25, 0.25, 0.25, 0.5]):\n    \"\"\"\n    Probs: [ Pr[push], Pr[pop], Pr[stop], Pr[0] if push ]\n    So Pr[0] = 1 - Pr[1].\n    \"\"\"\n    initial_string = \"\".join([str(j) for j in random.choices([0, 1], k = k)])\n    if this_tape is not None:\n        initial_string = this_tape\n    _probs = probs[:3]\n    state_sequences = []\n    is_stop = False\n    count = 0\n    while not is_stop:\n        next_state = random.choices([\"push\", \"pop\", \"stop\"], weights=_probs, k=1)[0]\n        count += 1\n        if next_state != \"stop\":\n            state_sequences.append(next_state)\n            if next_state == \"push\":\n                state_sequences.append(random.choices([\"0\", \"1\"], \n                                                      weights=[probs[-1], 1 - probs[-1]], \n                                                      k=1)[0])\n        else:\n            if min_seq is None or min_seq <= count:\n                is_stop = True\n                state_sequences.append(next_state)\n                break\n\n    this_string = initial_string[:]\n    for i in range(len(state_sequences)):\n        state = state_sequences[i]\n        if state == \"pop\":\n            if this_string == \"\":\n                continue\n            if len(this_string) == 1:\n                this_string = \"\"\n            else:\n                this_string = this_string[:-1]\n        if state == \"push\":\n            continue\n        if state in [\"0\", \"1\"]:\n            this_string = this_string + state\n        if state == \"stop\":\n            break\n\n    if this_string == \"\":\n        this_string = \"empty\"\n    return initial_string, state_sequences, this_string\n"

# System prompt for the "automata encoded" runs: the plain task description followed by
# the simulator source held in ecm_str_desc.
parity_desc = "This is a stack push-pop simulator. You will be given (in three lines) an initial state, a sequence of operations, and a final state.\n"
parity_desc += "Your job is to determine whether the final state is correct given the initial state and a sequence of operations.\n"
parity_desc += "You must always output 0 (incorrect) or 1 (correct).\n"
parity_desc += "The stack is simulated by the code shown below.\n"
parity_desc += "Pop operations on an empty stack are ignored.\n"
parity_desc += "Push is always followed by the symbol that is pushed.\nThe only allowable symbols are 0 and 1, and the only allowable operations are push, pop, and stop.\n"

parity_desc += f"Here's the code:\n{ecm_str_desc}\n"

# Zero-shot variant: asks for the label plus reasoning and shows one worked example.
parity_desc_zero_shot = parity_desc + "Give your answer as a single integer, and your reasoning in a new line.\n"
# The example's final state (000) is the right outcome of push 1 / pop on 000, so its
# label must be 1 (correct). It used to read "000: 0", contradicting both the explanation
# that follows it and the 0/1 convention stated above.
parity_desc_zero_shot += "For example:\n000\npush,1,pop,stop,\n000: 1\nThe label is correct because pushing and popping the same element returns the original state, which matches the final state.\n\n"

# Few-shot variant: label only; the exemplars follow the "Data:" marker.
parity_desc += "Given the data below, determine what is the most likely label for the given initial state, sequence of operations, and a final state; and output ONLY the label.\n"
parity_desc += "Data:\n\n"

# Back to the short, label-only token limit (earlier sections switch to cot_tokens).
llm.update_params({tkey: tokens})

_splits = ((corpus_id_test, False), (corpus_od_test_w_negs, True))
for _shuffled in (False, True):
    suffix = f"w_automaton_{problem}" + ("_shuffled" if _shuffled else "")
    # In-distribution first, then out-of-distribution.
    for _dataset, _is_ood in _splits:
        call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                           zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                           shuffle_exemplars=_shuffled, do_ood=_is_ood,
                           out_dir_root=out_dir_root, is_cot=False)
#988

## APO

In [None]:
# Task description shared by the zero-shot and few-shot prompts (adjacent literals are concatenated).
parity_desc = (
    "This is a stack push-pop simulator. You will be given (in three lines) an initial state, a sequence of operations, and a final state.\n"
    "Your job is to determine whether the final state is correct given the initial state and a sequence of operations.\n"
    "You must always output 0 (incorrect) or 1 (correct).\n"
    "Pop operations on an empty stack are ignored.\n"
    "Push is always followed by the symbol that is pushed.\nThe only allowable symbols are 0 and 1, and the only allowable operations are push, pop, and stop.\n"
)

# Zero-shot variant: shared description + answer format + one worked example.
# NOTE(review): the example is labelled 0 while its explanation describes a matching final state — confirm.
parity_desc_zero_shot = parity_desc + (
    "Give your answer as a single integer, and your reasoning in a new line.\n"
    "For example:\n000\npush,1,pop,stop,\n000: 0\nThe label is correct because pushing and popping the same element returns the original state, which matches the final state.\n\n"
)

# Few-shot variant: this is the seed prompt handed to APO below.
parity_desc = parity_desc + (
    "Given the data below, determine what is the most likely label for the given initial state, sequence of operations, and a final state; and output ONLY the label.\n"
    "Data:\n\n"
)

# Use training for optimisation; the first candidate is kept as the winner,
# with its prompt in the first slot and its score in the last slot.
p_hat_candidates = apo(parity_desc, corpus_id, problem=problem)
p_hat, score = p_hat_candidates[0][0], p_hat_candidates[0][-1]

# Persist the winner together with every candidate so the evaluation cells can reload it.
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "w", encoding="utf-8") as f:
    json.dump(
        {"Prompt": p_hat, "Score": score, "InitialPrompt": parity_desc, "OtherCandidates": p_hat_candidates, "Problem": problem},
        f,
        ensure_ascii=False,
    )

In [None]:
suff = f"apo_{problem}"

# Reload the optimised prompt written by the APO cell. The record is a single JSON line;
# the context manager closes the handle (the previous bare open() never did).
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "r", encoding="utf-8") as f:
    p_hat = json.loads(f.readline())["Prompt"]
# APO is a zero-shot problem, but we'll humour ourselves and try it out with shots anyway.
# In-distribution test set.
call_and_solve_for(problem, suff, corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)

# Out-of-distribution test set.
call_and_solve_for(problem, suff, corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)


In [None]:
suff = f"apo_{problem}_shuffled"

# Reload the optimised prompt written by the APO cell. The record is a single JSON line;
# the context manager closes the handle (the previous bare open() never did).
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "r", encoding="utf-8") as f:
    p_hat = json.loads(f.readline())["Prompt"]
# APO is a zero-shot problem, but we'll humour ourselves and try it out with shots anyway.
# In-distribution test set, exemplars shuffled.
call_and_solve_for(problem, suff, corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100], shuffle_exemplars=True,
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)

# Out-of-distribution test set, exemplars shuffled.
call_and_solve_for(problem, suff, corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100], shuffle_exemplars=True,
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)


# Maze Complete

## Shared functions

In [None]:
# Load the MazeComplete corpora (one JSON object per line). Each file is read inside a
# context manager so its handle is closed as soon as the list is built.
with open("datasets/MazeComplete/corpus_maze_complete_id_4k.json", "r", encoding="utf-8") as f:
    corpus_id = [json.loads(l) for l in f]  # exemplar pool / APO training data
with open("datasets/MazeComplete/corpus_maze_complete_id_4k_test.json", "r", encoding="utf-8") as f:
    corpus_id_test = [json.loads(l) for l in f]  # in-distribution test set
with open("datasets/MazeComplete/corpus_maze_complete_ood_4k_test.json", "r", encoding="utf-8") as f:
    corpus_od_test_w_negs = [json.loads(l) for l in f]  # out-of-distribution test set

problem = "maze_complete"
out_dir_root = f"{root_out}/MazeComplete/"


def get_neighbours(split_string, i, j, c = " "):
    """
    Neighbours of cell (i, j) in the 4-connected (up/down/left/right) neighbourhood
    whose character equals c.
    ---
    Params:
    split_string (list of str): the maze, one string per row (rows may differ in length)
    i, j (int): row and column of the cell whose neighbours are inspected
    c (str, " "): the character a neighbour must hold to be reported

    Returns a dict with two parallel lists: "Neighbours" holds the (row, col)
    coordinates and "Position" the matching direction names, always in the
    order up, down, left, right.

    Candidates outside the grid are skipped. Without this check a negative
    index silently wraps around to the opposite edge and an index past the end
    raises IndexError.
    """
    neighbours, neighbour_positions = [], []
    for n, (a, b) in [("up", (i - 1, j)), ("down", (i + 1, j)), ("left", (i, j - 1)), ("right", (i, j + 1))]:
        # Bound the column by the candidate's own row: rows are not guaranteed to be equally long.
        if a < 0 or b < 0 or a >= len(split_string) or b >= len(split_string[a]):
            continue
        if split_string[a][b] == c:
            neighbours.append((a, b))
            neighbour_positions.append(n)
    return {"Neighbours": neighbours, "Position": neighbour_positions}


def get_prompt_for_maze_complete(point, use_desc="", max_exemplars=10, shuffle_exemplars=False, is_cot=False):
    """ Get the system and user prompts for this problem. 
    It takes as input a (raw) datapoint (so no labels), and some extra params to make it shiny.
    ---
    Params:
    point (str): your datapoint
    use_desc (str or list of str, ""): the system prompt / description of the task (optional).
        When it is not a string it is treated as a word list: the system prompt and the
        CoT narration are then built from randomly sampled words ("word salad").
    max_exemplars (int, 10): number of shots, taken from the head of the global `corpus_id`
    shuffle_exemplars (False): this shuffles exemplars positionally only. There's another notebook for the other bits.
    is_cot (False): do you need a full CoT exemplar?

    Returns a list of chat messages ({"role": ..., "content": ...}): an optional system
    message, one user/assistant pair per exemplar, and a final user message holding `point`.
    """
    # A string description selects the readable narration; anything else selects word salad.
    plain = isinstance(use_desc, str)

    def build_cot_exemplar(_ex, lab):
        """Write a step-by-step rationale for one exemplar entry, ending with its label `lab`.

        The rationale scans the maze rows for "?", then searches two steps of free
        (" ") neighbours for a cell adjacent to the "+" path, narrating each step.
        """
        ex = _ex.replace("Solved maze:", "").replace("Missing moves:", "").strip()
        maze_split_string = ex.split("\n")
        # The last line holds the proposed moves; everything before it is the maze.
        answer = maze_split_string[-1]
        maze_split_string = maze_split_string[:-1]
        # Fail early (same exception type as before) when the maze has no "?" marker.
        if not any("?" in s for s in maze_split_string):
            raise IndexError("maze exemplar contains no \"?\" marker")

        # The trailing newline keeps this sentence from running into the first line report.
        exemplar = f"Let's think and solve this step-by-step.\nWe begin at line 0.\n"

        line_in_maze_ix, line_question_ix = None, None
        for j, maze_line in enumerate(maze_split_string):
            contains = "?" in maze_line
            if plain:
                exemplar += f"This line {'contains' if contains else 'does not contain'} \"?\".\n"
            else:
                exemplar += f"{' '.join(random.sample(use_desc, k=1))} {'contains' if contains else 'does not contain'} \"?\".\n"
            if contains:
                line_in_maze_ix = j
                line_question_ix = maze_line.find("?")
                if plain:
                    exemplar += f"The \"?\" character is at position {line_question_ix} in the line. We will now perform a search on the neighbours to find the path.\n"
                else:
                    exemplar += f"{random.choice(use_desc)} \"?\" {' '.join(random.sample(use_desc, k=1))} {line_in_maze_ix} {' '.join(random.sample(use_desc, k=1))}.\n"
                break
            else:
                if plain:
                    exemplar += f"We move on then to line {j + 1}.\n"
                else:
                    exemplar += f"{' '.join(random.sample(use_desc, k=6))} {j + 1}.\n"

        # This is a very lazy DFS algorithm but since we only cover three steps, it is easier this way.
        question_mark_neighbours = get_neighbours(maze_split_string, line_in_maze_ix, line_question_ix, c=" ")
        found = False
        buffer = []
        if plain:
            exemplar += f"This has neighbours: {question_mark_neighbours['Position']} at {question_mark_neighbours['Neighbours']}.\n"
        else:
            exemplar += f"{' '.join(random.sample(use_desc, k=1))}: {question_mark_neighbours['Position']} {random.choice(use_desc)} {question_mark_neighbours['Neighbours']}.\n"
        # First step: every free cell next to "?".
        for neighbours, positions in zip(question_mark_neighbours['Neighbours'], question_mark_neighbours['Position']):
            if found: break
            ix, iy = neighbours
            next_neighbours = get_neighbours(maze_split_string, ix, iy, c=" ")
            buffer = [positions]
            if plain:
                exemplar += f"We select the neighbour at {neighbours} (\"{positions}\") and add it to our buffer. Our buffer is: {buffer}.\n"
                exemplar += f"This has neighbours: {next_neighbours['Position']} at {next_neighbours['Neighbours']}.\n"
            else:
                exemplar += f"{' '.join(random.sample(use_desc, k=1))} (\"{positions}\") {' '.join(random.sample(use_desc, k=1))}: {buffer}.\n"
                exemplar += f"{' '.join(random.sample(use_desc, k=1))}: {next_neighbours['Position']} {random.choice(use_desc)} {next_neighbours['Neighbours']}.\n"
            # Second step: every free cell next to the first-step cell.
            for _neighbours, _positions in zip(next_neighbours['Neighbours'], next_neighbours['Position']):
                if found: break
                buffer = [positions, _positions]
                if plain:
                    exemplar += f"\tWe select the neighbour at {_neighbours} (\"{_positions}\") and add it to our buffer. Our buffer is: {buffer}.\n"
                else:
                    exemplar += f"\t{' '.join(random.sample(use_desc, k=1))} {_neighbours} (\"{_positions}\") {' '.join(random.sample(use_desc, k=2))}: {buffer}.\n"
                jx, jy = _neighbours
                # Check if this one neighbours/connects the path.
                last_neighbours = get_neighbours(maze_split_string, jx, jy, c="+")
                if plain:
                    exemplar += f"\tThis one has the following available neighbours connecting to the path: {last_neighbours['Position']} at {last_neighbours['Neighbours']}.\n"
                else:
                    exemplar += f"\t{' '.join(random.sample(use_desc, k=1))}: {last_neighbours['Position']} {random.choice(use_desc)} {last_neighbours['Neighbours']}.\n"
                if last_neighbours['Neighbours'] != []:
                    # Third step: the first "+" neighbour closes the gap.
                    plus_coordinates = last_neighbours["Neighbours"][0]
                    plus_direction = last_neighbours["Position"][0]
                    buffer.append(plus_direction)
                    if plain:
                        exemplar += f"\t\tThis has a \"+\" neighbour at {plus_coordinates} (\"{plus_direction}\"), so it connects to the path.\n"
                        exemplar += f"\t\tWe add it to our buffer. Our buffer is now {buffer}.\n"
                    else:
                        exemplar += f"\t\t{' '.join(random.sample(use_desc, k=1))} \"+\" {' '.join(random.sample(use_desc, k=1))} {plus_coordinates} (\"{plus_direction}\"), {' '.join(random.sample(use_desc, k=1))}.\n"
                        exemplar += f"\t\t{' '.join(random.sample(use_desc, k=1))} {buffer}.\n"
                    found = True
                    break
                else:
                    if plain:
                        exemplar += "\t\tIt does not connect to the path, so we remove it from our buffer.\n"
                    else:
                        exemplar += f"\t\t{' '.join(random.sample(use_desc, k=1))}.\n"
        if plain:
            exemplar += f"We are done!\nOur final set of positions is {','.join(buffer)} and the solution says {answer}.\n"
        else:
            exemplar += f"{' '.join(random.sample(use_desc, k=1))}!\n{' '.join(random.sample(use_desc, k=1))} {','.join(buffer)} {' '.join(random.sample(use_desc, k=1))} {answer}.\n"
        exemplar += f"So the answer is {lab}"
        return exemplar

    prompt = []
    base_prompt = use_desc if plain else " ".join(random.sample(use_desc, k=random.randint(15, 40)))
    if base_prompt != "":
         prompt.append({"role": "system", "content": base_prompt})
    # Slicing copies, so shuffling never reorders the global corpus.
    exemplars = corpus_id[:max_exemplars]
    if shuffle_exemplars:
        random.shuffle(exemplars)
    # Iterate over the (possibly shuffled) copy; indexing corpus_id here would ignore the shuffle.
    for exemplar_point in exemplars:
        prompt.append({"role": "user", "content": exemplar_point["Entry"] + ": "})
        if is_cot:
            ex = build_cot_exemplar(exemplar_point["Entry"], str(exemplar_point["Label"]))
        else:
            ex = str(exemplar_point["Label"])
        prompt.append({"role": "assistant", "content": ex})
    user_str = f"{point}: "
    prompt.append({"role": "user", "content": user_str})
    return prompt


## Modus ponens

In [None]:
# Few-shot baseline that leaves the prompts at call_and_solve_for's defaults; exemplars in corpus order.
suff = f"modus_ponens_{problem}"

# In-distribution test set.
call_and_solve_for(
    problem=problem,
    suff=suff,
    dataset=corpus_id_test,
    shuffle_exemplars=False,
    out_dir_root=out_dir_root,
)
# Out-of-distribution test set.
call_and_solve_for(
    problem=problem,
    suff=suff,
    dataset=corpus_od_test_w_negs,
    shuffle_exemplars=False,
    do_ood=True,
    out_dir_root=out_dir_root,
)

In [None]:
# Same baseline as the unshuffled run, but with exemplar positions shuffled.
suff = f"modus_ponens_{problem}_shuffled"

# In-distribution test set.
call_and_solve_for(
    problem=problem,
    suff=suff,
    dataset=corpus_id_test,
    shuffle_exemplars=True,
    out_dir_root=out_dir_root,
)
# Out-of-distribution test set.
call_and_solve_for(
    problem=problem,
    suff=suff,
    dataset=corpus_od_test_w_negs,
    shuffle_exemplars=True,
    do_ood=True,
    out_dir_root=out_dir_root,
)

## With description

In [None]:
# Worked example embedded in the zero-shot prompt.
sample_maze = "Solved maze:\n#S##\n#+##\n#? #\n#  #\n# +#\n# +E#\n####\nMissing moves:\ndown,right,down"

# Task description shared by both prompt variants (adjacent literals are concatenated).
parity_desc = (
    "You are helping me complete a maze. You will be given a maze almost solved, and sequence of moves to finish solving it.\n"
    "Your job is to determine whether the moves are correct and will lead to solving the maze solved.\n"
    "You must always output 0 (incorrect) or 1 (correct).\n"
    "The path you must complete is denoted by uninterrupted \"+\", and your completion starts at \"?\". Walls are denoted by \"#\", and the start and end are \"S\" and \"E\", respectively.\n"
    "The first move you must verify is the one connecting the path to \"?\".\n"
)

# Zero-shot variant: shared description + answer format + the worked example.
parity_desc_zero_shot = parity_desc + (
    "Give your answer as a single integer, and your reasoning in a new line.\n"
    f"For example:{sample_maze}\n1\nThe label is 1 because \"?\" is above and to the left to the last \"+\" from the path, so moving down,right,down is the right move to connect the path.\n\n"
)

# Few-shot variant: shared description + instruction to output only the label.
parity_desc = parity_desc + (
    "Given the data below, determine what is the most likely label for the given maze and moves; and output ONLY the label.\n"
    "Data:\n\n"
)

# Run order: unshuffled (ID, OOD), then shuffled (ID, OOD); each pass writes under its own suffix.
for suff, _shuffled in ((f"w_description_{problem}", False), (f"w_description_{problem}_shuffled", True)):
    for _test_set, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
        call_and_solve_for(problem, suff=suff, dataset=_test_set, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                           zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                           shuffle_exemplars=_shuffled, do_ood=_is_ood,
                           out_dir_root=out_dir_root, is_cot=False)

## Word Salad

In [None]:
# Word-salad "description": the de-duplicated words of the first line of words.txt.
# The file is read inside a context manager so the handle is closed (the bare open() leaked it).
with open("datasets/words.txt", "r", encoding="utf-8") as f:
    parity_desc = list(set([l.strip().replace(".", " ") for l in f.readlines()[0].split(" ")]))
# Both prompt variants share the same word list object.
parity_desc_zero_shot = parity_desc

suff = f"word_salad_{problem}"

# Exemplars in corpus order: in-distribution, then out-of-distribution.
call_and_solve_for(problem, suff=suff, dataset=corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)
call_and_solve_for(problem, suff=suff, dataset=corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)


# Same two runs with shuffled exemplars, written under a separate suffix.
suff = f"word_salad_{problem}_shuffled"

call_and_solve_for(problem, suff=suff, dataset=corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)
call_and_solve_for(problem, suff=suff, dataset=corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)
#92

## CoT

In [None]:
# Worked example embedded in the zero-shot prompt.
sample_maze = "Solved maze:\n#S##\n#+##\n#? #\n#  #\n# +#\n# +E#\n####\nMissing moves:\ndown,right,down"

# Task description shared by both prompt variants, joined from its individual sentences.
parity_desc = "".join([
    "You are helping me complete a maze. You will be given a maze almost solved, and sequence of moves to finish solving it.\n",
    "Your job is to determine whether the moves are correct and will lead to solving the maze solved.\n",
    "You must always output 0 (incorrect) or 1 (correct).\n",
    "The path you must complete is denoted by uninterrupted \"+\", and your completion starts at \"?\". Walls are denoted by \"#\", and the start and end are \"S\" and \"E\", respectively.\n",
    "The first move you must verify is the one connecting the path to \"?\".\n",
])

# Zero-shot variant: shared description + answer format + the worked example.
parity_desc_zero_shot = "".join([
    parity_desc,
    "Give your answer as a single integer, and your reasoning in a new line.\n",
    f"For example:{sample_maze}\n1\nThe label is 1 because \"?\" is above and to the left to the last \"+\" from the path, so moving down,right,down is the right move to connect the path.\n\n",
])

# Few-shot variant: shared description + instruction to output only the label.
parity_desc = "".join([
    parity_desc,
    "Given the data below, determine what is the most likely label for the given maze and moves; and output ONLY the label.\n",
    "Data:\n\n",
])

# Output suffix for the CoT runs in the following cell, and the larger CoT completion budget.
suffix = f"CoT_{problem}"
llm.update_params({tkey: cot_tokens})

In [None]:
# CoT evaluation with exemplars in corpus order: in-distribution first, then out-of-distribution.
# Relies on `suffix` and the prompts defined in the previous cell.
for _test_set, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _test_set, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=False, do_ood=_is_ood,
                       out_dir_root=out_dir_root, is_cot=True)
#190

In [None]:
# CoT evaluation again, this time with exemplar positions shuffled.
suffix = f"CoT_{problem}_shuffled"

# In-distribution first, then out-of-distribution.
for _test_set, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _test_set, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=True, do_ood=_is_ood,
                       out_dir_root=out_dir_root, is_cot=True)

## Salad-of-Thought

In [None]:
# Word list for Salad-of-Thought: the de-duplicated words of the first line of words.txt.
# The file is read inside a context manager so the handle is closed (the bare open() leaked it).
with open("datasets/words.txt", "r", encoding="utf-8") as f:
    parity_desc = list(set([l.strip().replace(".", " ") for l in f.readlines()[0].split(" ")]))
# Both prompt variants share the same word list object.
parity_desc_zero_shot = parity_desc


suffix = f"SoT_{problem}"
# CoT-style exemplars need the larger completion budget.
llm.update_params({tkey: cot_tokens})

# In-distribution test set, exemplars in corpus order.
call_and_solve_for(problem, suffix, corpus_id_test, SHOTS=[0, 5, 10, 20, 50, 100], #0, 
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=True)

# Out-of-distribution test set, exemplars in corpus order.
call_and_solve_for(problem, suffix, corpus_od_test_w_negs, SHOTS=[0, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=True)

In [None]:
# Word list for Salad-of-Thought: the de-duplicated words of the first line of words.txt.
# The file is read inside a context manager so the handle is closed (the bare open() leaked it).
with open("datasets/words.txt", "r", encoding="utf-8") as f:
    parity_desc = list(set([l.strip().replace(".", " ") for l in f.readlines()[0].split(" ")]))
# Both prompt variants share the same word list object.
parity_desc_zero_shot = parity_desc


suffix = f"SoT_{problem}_shuffled"
# CoT-style exemplars need the larger completion budget.
llm.update_params({tkey: cot_tokens})
# Last one for this problem
# In-distribution test set, exemplars shuffled.
call_and_solve_for(problem, suffix, corpus_id_test, SHOTS=[0, 5, 10, 20, 50, 100], #0, 
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=True)

# Out-of-distribution test set, exemplars shuffled.
call_and_solve_for(problem, suffix, corpus_od_test_w_negs, SHOTS=[0, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=True)

## Automata encoded

In [None]:
# Source of the maze-corruption routine, shown verbatim to the model inside the prompt.
# The literal must sit on ONE physical line: a single-quoted string cannot span lines.
# NOTE(review): the un-doubled \n inside split(\"\n\") and \"\n\".join(...) becomes a real
# newline in the prompt text; left as-is to keep the prompt unchanged — confirm it is intended.
ecm_str_desc = "def get_neighbours(split_string, i, j):\n    \"\"\"\n    Neighbours in a Moore neighbourhood of i,j\n    \"\"\"\n    neighbours = []\n    for a, b in [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]:\n        if split_string[a][b] == \"+\":\n            neighbours.append((a, b))\n    return neighbours\n\n\ndef add_noise(maze, solved_maze_string, ceilings, debug=False):\n    split_string = solved_maze_string.split(\"\n\")\n    \n    c = Counter()\n    solns = maze.solutions[0]\n    positions = {r[0]: [] for r in solns}\n    for r in solns:\n        row, position = r\n        c[row] += 1\n        positions[row].append(position)\n    \n    ins = lambda s, i: s[:i] + \" \" + s[i + 1:]\n    ins2 = lambda s, i: s[:i] + \"?\" + s[i + 1:]\n    \n    ix = random.choice([k for k in positions.keys() if k > ceilings[0] and k < ceilings[1]])\n    point = random.choice(positions[ix])\n    new_split_string = [s for s in split_string]\n    answer = \"\"\n    \n    neighbours = get_neighbours(split_string, ix, point)\n    i, j = random.choice(neighbours)\n\n    iy, jy = 0, 0\n    if i == ix:\n        if j == point - 1:\n            new_split_string[i] = ins2(split_string[i], j)\n            new_split_string[ix] = ins(new_split_string[ix], point)\n            neighbours_2 = get_neighbours(new_split_string, ix, point)\n            iy, jy = ix, point\n        else: #j == point + 1:\n            new_split_string[i] = ins(split_string[i], j)\n            new_split_string[ix] = ins2(new_split_string[ix], point)\n            neighbours_2 = get_neighbours(new_split_string, i, j)\n            iy, jy = i, j\n        answer = \"right,right\"\n    elif i == ix - 1: #j == point (moore)\n            new_split_string[i] = ins2(split_string[i], j)\n            new_split_string[ix] = ins(new_split_string[ix], point)\n            neighbours_2 = get_neighbours(new_split_string, ix, point)\n            iy, jy = ix, point\n            answer = \"down,down\"\n    elif i == ix + 1: #j == point:\n            new_split_string[i] = ins(split_string[i], j)\n            new_split_string[ix] = ins2(new_split_string[ix], point)\n            neighbours_2 = get_neighbours(new_split_string, i, j)\n            iy, jy = i, j\n            answer = \"down,down\"\n\n    i, j  = random.choice(neighbours_2)\n\n    if i == iy:\n        if j == jy - 1:\n            new_split_string[i] = ins(new_split_string[i], j)\n            new_split_string[iy] = ins(new_split_string[iy], jy)\n        else: #j == point + 1:\n            new_split_string[i] = ins(new_split_string[i], j)\n            new_split_string[iy] = ins(new_split_string[iy], jy)\n        answer += \",right\"\n    elif i == iy - 1: #j == point (moore)\n            new_split_string[i] = ins(new_split_string[i], j)\n            new_split_string[iy] = ins(new_split_string[iy], jy)\n            answer += \",down\"\n    elif i == iy + 1: #j == point:\n            new_split_string[i] = ins(new_split_string[i], j)\n            new_split_string[iy] = ins(new_split_string[iy], jy)\n            answer += \",down\"\n    \n    final_string = \"\n\".join(new_split_string)\n    return final_string, answer\n\n"
# Worked example embedded in the zero-shot prompt.
sample_maze = "Solved maze:\n#S##\n#+##\n#? #\n#  #\n# +#\n# +E#\n####\nMissing moves:\ndown,right,down"

# Task description plus the reference code (ecm_str_desc), shared by both prompt variants.
parity_desc = (
    "You are helping me complete a maze. You will be given a maze almost solved, and sequence of moves to finish solving it, along with code to determine what are the positions of the neighbours.\n"
    "Your job is to determine whether the moves are correct and will lead to solving the maze solved.\n"
    "You must always output 0 (incorrect) or 1 (correct).\n"
    "The path you must complete is denoted by uninterrupted \"+\", and your completion starts at \"?\". Walls are denoted by \"#\", and the start and end are \"S\" and \"E\", respectively.\n"
    "The first move you must verify is the one connecting the path to \"?\".\n"
) + f"Here's the code:\n{ecm_str_desc}\n"

# Zero-shot variant: shared description + answer format + the worked example.
parity_desc_zero_shot = parity_desc + (
    "Give your answer as a single integer, and your reasoning in a new line.\n"
    f"For example:{sample_maze}\n1\nThe label is 1 because \"?\" is above and to the left to the last \"+\" from the path, so moving down,right,down is the right move to connect the path.\n\n"
)

# Few-shot variant: shared description + instruction to output only the label.
parity_desc = parity_desc + (
    "Given the data below, determine what is the most likely label for the given maze and moves; and output ONLY the label.\n"
    "Data:\n\n"
)

suffix = f"w_automaton_{problem}"

# Back to the short (non-CoT) completion budget.
llm.update_params({tkey: tokens})

# Exemplars in corpus order: in-distribution, then out-of-distribution.
for _test_set, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _test_set, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=False, do_ood=_is_ood,
                       out_dir_root=out_dir_root, is_cot=False)
#988

suffix = f"w_automaton_{problem}_shuffled"

# Same two runs with shuffled exemplars.
for _test_set, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _test_set, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=True, do_ood=_is_ood,
                       out_dir_root=out_dir_root, is_cot=False)
#988

## APO

In [None]:
# Worked example embedded in the zero-shot prompt.
sample_maze = "Solved maze:\n#S##\n#+##\n#? #\n#  #\n# +#\n# +E#\n####\nMissing moves:\ndown,right,down"

# Task description shared by both prompt variants (adjacent literals are concatenated).
parity_desc = (
    "You are helping me complete a maze. You will be given a maze almost solved, and sequence of moves to finish solving it.\n"
    "Your job is to determine whether the moves are correct and will lead to solving the maze solved.\n"
    "You must always output 0 (incorrect) or 1 (correct).\n"
    "The path you must complete is denoted by uninterrupted \"+\", and your completion starts at \"?\". Walls are denoted by \"#\", and the start and end are \"S\" and \"E\", respectively.\n"
    "The first move you must verify is the one connecting the path to \"?\".\n"
)

# Zero-shot variant: shared description + answer format + the worked example.
parity_desc_zero_shot = parity_desc + (
    "Give your answer as a single integer, and your reasoning in a new line.\n"
    f"For example:{sample_maze}\n1\nThe label is 1 because \"?\" is above and to the left to the last \"+\" from the path, so moving down,right,down is the right move to connect the path.\n\n"
)

# Few-shot variant: this is the seed prompt handed to APO below.
parity_desc = parity_desc + (
    "Given the data below, determine what is the most likely label for the given maze and moves; and output ONLY the label.\n"
    "Data:\n\n"
)

# Use training for optimisation; the first candidate is kept as the winner,
# with its prompt in the first slot and its score in the last slot.
p_hat_candidates = apo(parity_desc, corpus_id, problem=problem)
p_hat, score = p_hat_candidates[0][0], p_hat_candidates[0][-1]

# Persist the winner together with every candidate so the evaluation cells can reload it.
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "w", encoding="utf-8") as f:
    json.dump(
        {"Prompt": p_hat, "Score": score, "InitialPrompt": parity_desc, "OtherCandidates": p_hat_candidates, "Problem": problem},
        f,
        ensure_ascii=False,
    )

In [None]:
suff = f"apo_{problem}"

# Reload the optimised prompt written by the APO cell. The record is a single JSON line;
# the context manager closes the handle (the previous bare open() never did).
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "r", encoding="utf-8") as f:
    p_hat = json.loads(f.readline())["Prompt"]
# APO is a zero-shot problem, but we'll humour ourselves and try it out with shots anyway.
# In-distribution test set.
call_and_solve_for(problem, suff, corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)

# Out-of-distribution test set.
call_and_solve_for(problem, suff, corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)


In [None]:
suff = f"apo_{problem}_shuffled"

# Reload the optimised prompt written by the APO cell. The record is a single JSON line;
# the context manager closes the handle (the previous bare open() never did).
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "r", encoding="utf-8") as f:
    p_hat = json.loads(f.readline())["Prompt"]
# APO is a zero-shot problem, but we'll humour ourselves and try it out with shots anyway.
# In-distribution test set, exemplars shuffled.
call_and_solve_for(problem, suff, corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100], shuffle_exemplars=True,
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)

# Out-of-distribution test set, exemplars shuffled.
call_and_solve_for(problem, suff, corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100], shuffle_exemplars=True,
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)


# Maze Solve

## Shared functions

In [None]:
# Load the MazeSolve corpora (one JSON object per line). Each file is read inside a
# context manager so its handle is closed as soon as the list is built.
with open("datasets/MazeSolve/corpus_maze_solve_id_4k.json", "r", encoding="utf-8") as f:
    corpus_id = [json.loads(l) for l in f]  # exemplar pool / APO training data
with open("datasets/MazeSolve/corpus_maze_solve_id_4k_test.json", "r", encoding="utf-8") as f:
    corpus_id_test = [json.loads(l) for l in f]  # in-distribution test set
with open("datasets/MazeSolve/corpus_maze_solve_ood_4k_test.json", "r", encoding="utf-8") as f:
    corpus_od_test_w_negs = [json.loads(l) for l in f]  # out-of-distribution test set

problem = "maze_solve"
out_dir_root = f"{root_out}/MazeSolve/"


def get_neighbours(split_string, i, j, c = " "):
    """
    Neighbours of cell (i, j) in the 4-connected (up/down/left/right) neighbourhood
    whose character equals c.
    ---
    Params:
    split_string (list of str): the maze, one string per row (rows may differ in length)
    i, j (int): row and column of the cell whose neighbours are inspected
    c (str, " "): the character a neighbour must hold to be reported

    Returns a dict with two parallel lists: "Neighbours" holds the (row, col)
    coordinates and "Position" the matching direction names, always in the
    order up, down, left, right. Candidates outside the grid are skipped.
    """
    neighbours, neighbour_positions = [], []
    def skip(m, n):
        """True when (m, n) lies outside the grid."""
        if m < 0 or n < 0:
            return True
        if m > len(split_string) - 1:
            return True
        # Bound the column by row m itself, not row 0: on ragged mazes the first row's
        # length either raises IndexError (shorter row) or hides valid cells (longer row).
        return n > len(split_string[m]) - 1

    for n, (a, b) in [("up", (i - 1, j)), ("down", (i + 1, j)), ("left", (i, j - 1)), ("right", (i, j + 1))]:
        if skip(a, b):
            continue
        if split_string[a][b] == c:
            neighbours.append((a, b))
            neighbour_positions.append(n)
    return {"Neighbours": neighbours, "Position": neighbour_positions}


def get_prompt_for_maze_solve(point, use_desc="", max_exemplars=10, shuffle_exemplars=False, is_cot=False):
    """ Get the system and user prompts for this problem.
    It takes as input a (raw) datapoint (so no labels), and some extra params to make it shiny.
    ---
    Params:
    point (str): your datapoint
    use_desc (str or list of str, ""): the system prompt / description of the task (optional).
        A non-string value is treated as a word list: random words are sampled from it for the
        system prompt and for the CoT exemplar text.
    max_exemplars (int, 10): number of shots
    shuffle_exemplars (False): this shuffles exemplars positionally only. There's another notebook for the other bits.
    is_cot (False): do you need a full CoT exemplar?

    Returns:
    list of {"role", "content"} messages: an optional system message, one user/assistant pair per
    exemplar taken from the module-level corpus_id, and a final user message holding `point`.
    """
    plain_text = isinstance(use_desc, str)

    def salad(k=1):
        # k distinct random words from the word list, joined by spaces
        return ' '.join(random.sample(use_desc, k=k))

    def build_cot_exemplar(_ex, lab):
        # Drop the section headers; what remains is the maze rows followed by the comma-separated moves
        ex = _ex.replace("Solved maze:", "").replace("Solution:", "").strip()
        maze_split_string = ex.split("\n")
        answer = maze_split_string[-1].split(",")
        maze_split_string = maze_split_string[:-1]

        if plain_text:
            exemplar = "Let's think and solve this step-by-step.\nWe begin under \"S\", at line 1, position (1, 1).\n"
        else:
            exemplar = f"{salad()}\n{salad()} \"S\", {salad()} 1.\n"
            exemplar += f"{random.choice(use_desc)} (1, 1).\n"

        # Replay the claimed moves from (1, 1), following "+" cells until "E" is adjacent
        i, j = (1, 1)
        sequence = []
        for k, _state in enumerate(answer):
            is_exit = get_neighbours(maze_split_string, i, j, c="E")
            neighbours = get_neighbours(maze_split_string, i, j, c="+")

            if is_exit["Position"] != []:
                sequence.append(is_exit["Position"][0])
                if plain_text:
                    exemplar += f"The exit is in the neighbours, at \"{is_exit['Position'][0]}\". We add it to our buffer and terminate."
                    exemplar += f" Our buffer is now: {sequence}.\n"
                else:
                    exemplar += f"{salad()} \"{is_exit['Position'][0]}\". {salad()}."
                    exemplar += f" {salad()}: {sequence}.\n"
                break

            if plain_text:
                exemplar += f"The {'next' if k != 0 else 'first'} move in the answer is \"{_state}\". The list of available neighbours for this move is {neighbours['Position']}.\n"
            else:
                exemplar += f"{salad()} \"{_state}\"."
                exemplar += f"{salad()} {neighbours['Position']}.\n"
            if _state in neighbours['Position']:
                sequence.append(_state)
                i, j = neighbours['Neighbours'][neighbours["Position"].index(_state)]
                if plain_text:
                    exemplar += f"This move is available, so we add it to our buffer. Our buffer is now: {sequence}.\n"
                else:
                    exemplar += f"{salad(k=2)}: {sequence}\n"
            else:
                if plain_text:
                    exemplar += "This move is not available, so it cannot be the correct answer. We terminate.\n"
                else:
                    exemplar += f"{salad()}.\n"
                break

        if plain_text:
            exemplar += f"We are done!\nOur final set of moves is {','.join(sequence)} and the solution says {','.join(answer)}.\n"
        else:
            exemplar += f"{salad()}!\n"
            exemplar += f"{salad()} {','.join(sequence)} {salad()} {','.join(answer)}.\n"
        exemplar += f"So the answer is {lab}"
        return exemplar

    prompt = []
    base_prompt = use_desc if plain_text else " ".join(random.sample(use_desc, k=random.randint(15, 40)))
    if base_prompt != "":
        prompt.append({"role": "system", "content": base_prompt})
    # The slice is a copy, so shuffling it leaves corpus_id untouched
    exemplars = corpus_id[:max_exemplars]
    if shuffle_exemplars:
        random.shuffle(exemplars)
    # Iterate over the (possibly shuffled) copy; indexing corpus_id here would discard the shuffle
    for exemplar in exemplars:
        prompt.append({"role": "user", "content": exemplar["Entry"] + ": "})
        ex = str(exemplar["Label"])
        if is_cot:
            ex = build_cot_exemplar(exemplar["Entry"], str(exemplar["Label"]))
        prompt.append({"role": "assistant", "content": ex})
    user_str = f"{point}: "
    prompt.append({"role": "user", "content": user_str})
    return prompt

## Modus ponens

In [None]:
suff = "modus_ponens_{}".format(problem)

# Same settings for both test corpora (ID first, then OOD); only the OOD call passes do_ood.
for _dataset, _extra_kwargs in ((corpus_id_test, {}), (corpus_od_test_w_negs, {"do_ood": True})):
    call_and_solve_for(problem=problem, suff=suff, dataset=_dataset, out_dir_root=out_dir_root,
                       shuffle_exemplars=False, **_extra_kwargs)

In [None]:
suff = "modus_ponens_{}_shuffled".format(problem)

# Same as the unshuffled run, but with shuffle_exemplars switched on; only the OOD call passes do_ood.
for _dataset, _extra_kwargs in ((corpus_id_test, {}), (corpus_od_test_w_negs, {"do_ood": True})):
    call_and_solve_for(problem=problem, suff=suff, dataset=_dataset, out_dir_root=out_dir_root,
                       shuffle_exemplars=True, **_extra_kwargs)

## With description

In [None]:
# Worked example embedded in the zero-shot prompt
sample_maze = "Solved maze:\n#S##\n#+##\n#? #\n#  #\n# +#\n# +E#\n####\nSolution:\ndown,down,down,down,right,down,right"

# Task description shared by the zero-shot and few-shot prompts
parity_desc = (
    "You are helping verify a solution to a maze. You will be given a maze, and sequence of moves that claim solving it.\n"
    "Your job is to determine whether the moves are correct and will lead to correctly solving the maze.\n"
    "You must always output 0 (incorrect) or 1 (correct).\n"
    "Walls are denoted by \"#\", and the start and end are \"S\" and \"E\", respectively.\n"
)

# Zero-shot variant: asks for the label plus reasoning and shows one worked example
parity_desc_zero_shot = (
    parity_desc
    + "Give your answer as a single integer, and your reasoning in a new line.\n"
    + f"For example:{sample_maze}\n1\nThe solution is correct because following the moves will lead to E. There are only two right moves required, and the rest must be downs.\n\n"
)

# Few-shot variant: label only, followed by the exemplars
parity_desc = parity_desc + (
    "Given the data below, determine what is the most likely label for the given maze and moves; and output ONLY the label.\n"
    "Data:\n\n"
)

# Run order: unshuffled ID, unshuffled OOD, shuffled ID, shuffled OOD
for _shuffled in (False, True):
    suff = f"w_description_{problem}_shuffled" if _shuffled else f"w_description_{problem}"
    for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
        call_and_solve_for(problem, suff=suff, dataset=_dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                           zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                           shuffle_exemplars=_shuffled, do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=False)

## Word Salad

In [None]:
# Word-salad "description": the unique space-separated words on the first line of words.txt.
# It is a list rather than a str; get_prompt_for_maze_solve samples random words from a non-str use_desc
# (presumably call_and_solve_for forwards these prompts as use_desc — confirm).
with open("datasets/words.txt", "r", encoding="utf-8") as f:
    # sorted() gives a stable word order; the order of a bare set changes between kernel sessions
    parity_desc = sorted({l.strip().replace(".", " ") for l in f.readline().split(" ")})
parity_desc_zero_shot = parity_desc

suff = f"word_salad_{problem}"

call_and_solve_for(problem, suff=suff, dataset=corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)
call_and_solve_for(problem, suff=suff, dataset=corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)


suff = f"word_salad_{problem}_shuffled"

call_and_solve_for(problem, suff=suff, dataset=corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)
call_and_solve_for(problem, suff=suff, dataset=corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)
#92

## CoT

In [None]:
# Worked example embedded in the zero-shot prompt
sample_maze = "Solved maze:\n#S##\n#+##\n#? #\n#  #\n# +#\n# +E#\n####\nSolution:\ndown,down,down,down,right,down,right"

# Task description shared by the zero-shot and few-shot prompts
parity_desc = (
    "You are helping verify a solution to a maze. You will be given a maze, and sequence of moves that claim solving it.\n"
    "Your job is to determine whether the moves are correct and will lead to correctly solving the maze.\n"
    "You must always output 0 (incorrect) or 1 (correct).\n"
    "Walls are denoted by \"#\", and the start and end are \"S\" and \"E\", respectively.\n"
)

# Zero-shot variant: asks for the label plus reasoning and shows one worked example
parity_desc_zero_shot = (
    parity_desc
    + "Give your answer as a single integer, and your reasoning in a new line.\n"
    + f"For example:{sample_maze}\n1\nThe solution is correct because following the moves will lead to E. There are only two right moves required, and the rest must be downs.\n\n"
)

# Few-shot variant: label only, followed by the exemplars
parity_desc = parity_desc + (
    "Given the data below, determine what is the most likely label for the given maze and moves; and output ONLY the label.\n"
    "Data:\n\n"
)

# Output suffix for the CoT runs, and the larger CoT token budget on the shared client
suffix = "CoT_{}".format(problem)
llm.update_params({tkey: cot_tokens})

In [None]:
# CoT evaluation without exemplar shuffling: ID corpus first, then OOD.
# `suffix` and both prompts are the ones defined in the CoT setup cell.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=False, do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=True)

In [None]:
suffix = "CoT_{}_shuffled".format(problem)

# CoT evaluation with exemplar shuffling requested: ID corpus first, then OOD
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=True, do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=True)

## Salad-of-Thought

In [None]:
# Word list for Salad-of-Thought: the unique space-separated words on the first line of words.txt.
# A non-str description makes get_prompt_for_maze_solve fill its CoT exemplars with random words.
with open("datasets/words.txt", "r", encoding="utf-8") as f:
    # sorted() gives a stable word order; the order of a bare set changes between kernel sessions
    parity_desc = sorted({l.strip().replace(".", " ") for l in f.readline().split(" ")})
parity_desc_zero_shot = parity_desc


suffix = f"SoT_{problem}"
llm.update_params({tkey: cot_tokens})

call_and_solve_for(problem, suffix, corpus_id_test, SHOTS=[0, 5, 10, 20, 50, 100], #
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=True)

call_and_solve_for(problem, suffix, corpus_od_test_w_negs, SHOTS=[0, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=True)

In [None]:
# Word list for Salad-of-Thought: the unique space-separated words on the first line of words.txt.
# A non-str description makes get_prompt_for_maze_solve fill its CoT exemplars with random words.
with open("datasets/words.txt", "r", encoding="utf-8") as f:
    # sorted() gives a stable word order; the order of a bare set changes between kernel sessions
    parity_desc = sorted({l.strip().replace(".", " ") for l in f.readline().split(" ")})
parity_desc_zero_shot = parity_desc


suffix = f"SoT_{problem}_shuffled"
llm.update_params({tkey: cot_tokens})

call_and_solve_for(problem, suffix, corpus_id_test, SHOTS=[0, 5, 10, 20, 50, 100], #0, 
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=True)

call_and_solve_for(problem, suffix, corpus_od_test_w_negs, SHOTS=[0, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=True)

## Automata encoded

In [None]:
# Worked example embedded in the zero-shot prompt
sample_maze = "Solved maze:\n#S##\n#+##\n#? #\n#  #\n# +#\n# +E#\n####\nSolution:\ndown,down,down,down,right,down,right"
# NOTE(review): the prompt below promises "some DFS code", but ecm_str is empty, so the
# "Here's the code:" section carries no code — confirm whether the snippet was meant to be filled in.
ecm_str = ""

# Task description shared by the zero-shot and few-shot prompts, including the code section
parity_desc = (
    "You are helping verify a solution to a maze. You will be given a maze, and sequence of moves that claim solving it, along with some DFS code that could help in solving it.\n"
    "Your job is to determine whether the moves are correct and will lead to correctly solving the maze.\n"
    "You must always output 0 (incorrect) or 1 (correct).\n"
    "Walls are denoted by \"#\", and the start and end are \"S\" and \"E\", respectively.\n"
) + f"Here's the code:\n{ecm_str}\n"

# Zero-shot variant: asks for the label plus reasoning and shows one worked example
parity_desc_zero_shot = (
    parity_desc
    + "Give your answer as a single integer, and your reasoning in a new line.\n"
    + f"For example:{sample_maze}\n1\nThe solution is correct because following the moves will lead to E. There are only two right moves required, and the rest must be downs.\n\n"
)

# Few-shot variant: label only, followed by the exemplars
parity_desc = parity_desc + (
    "Given the data below, determine what is the most likely label for the given maze and moves; and output ONLY the label.\n"
    "Data:\n\n"
)

# Back to the short (non-CoT) token budget on the shared client
llm.update_params({tkey: tokens})

# Run order: unshuffled ID, unshuffled OOD, shuffled ID, shuffled OOD
for _shuffled in (False, True):
    suffix = f"w_automaton_{problem}_shuffled" if _shuffled else f"w_automaton_{problem}"
    for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
        call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                           zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                           shuffle_exemplars=_shuffled, do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=False)
#988

## APO

In [None]:
# Worked example embedded in the zero-shot prompt
sample_maze = "Solved maze:\n#S##\n#+##\n#? #\n#  #\n# +#\n# +E#\n####\nSolution:\ndown,down,down,down,right,down,right"

# Task description shared by the zero-shot and few-shot prompts
parity_desc = (
    "You are helping verify a solution to a maze. You will be given a maze, and sequence of moves that claim solving it.\n"
    "Your job is to determine whether the moves are correct and will lead to correctly solving the maze.\n"
    "You must always output 0 (incorrect) or 1 (correct).\n"
    "Walls are denoted by \"#\", and the start and end are \"S\" and \"E\", respectively.\n"
)

# Zero-shot variant (built for parity with the other cells; not passed to apo below)
parity_desc_zero_shot = (
    parity_desc
    + "Give your answer as a single integer, and your reasoning in a new line.\n"
    + f"For example:{sample_maze}\n1\nThe solution is correct because following the moves will lead to E. There are only two right moves required, and the rest must be downs.\n\n"
)

# Few-shot variant: this is the initial prompt handed to the optimiser
parity_desc = parity_desc + (
    "Given the data below, determine what is the most likely label for the given maze and moves; and output ONLY the label.\n"
    "Data:\n\n"
)

p_hat_candidates = apo(parity_desc, corpus_id, problem=problem) # Use training for optimisation
# The first candidate is kept: its first field is the prompt and its last field the score
p_hat, score = p_hat_candidates[0][0], p_hat_candidates[0][-1]

# Persist the chosen prompt and all candidates as a single JSON record for the evaluation cells
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "w", encoding="utf-8") as f:
    f.write(json.dumps(
        {
            "Prompt": p_hat,
            "Score": score,
            "InitialPrompt": parity_desc,
            "OtherCandidates": p_hat_candidates,
            "Problem": problem,
        },
        ensure_ascii=False,
    ))

In [None]:
suff = f"apo_{problem}"

# Reload the prompt written to apo_prompt_{problem}.json by the APO optimisation cell.
# The context manager closes the file handle instead of leaving it to the garbage collector.
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "r", encoding="utf-8") as f:
    p_hat = [json.loads(l) for l in f][0]["Prompt"]
# APO is a zero-shot problem, but we'll humour ourselves and try it out with shots anyway.
call_and_solve_for(problem, suff, corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)

call_and_solve_for(problem, suff, corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)


In [None]:
suff = f"apo_{problem}_shuffled"

# Reload the prompt written to apo_prompt_{problem}.json by the APO optimisation cell.
# The context manager closes the file handle instead of leaving it to the garbage collector.
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "r", encoding="utf-8") as f:
    p_hat = [json.loads(l) for l in f][0]["Prompt"]
# APO is a zero-shot problem, but we'll humour ourselves and try it out with shots anyway.
call_and_solve_for(problem, suff, corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100], shuffle_exemplars=True,
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)

call_and_solve_for(problem, suff, corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100], shuffle_exemplars=True,
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)


# Reversal

## Shared functions

In [None]:
# Reversal corpora, stored as JSON Lines (one JSON record per line).
# These rebind the corpus_id / *_test globals, so every prompt builder that reads corpus_id
# draws its exemplars from the Reversal data from here on. Each file is closed once read.
with open("datasets/Reversal/corpus_reversal_id_4k.json", "r", encoding="utf-8") as f:
    corpus_id = [json.loads(l) for l in f]
with open("datasets/Reversal/corpus_reversal_id_4k_test.json", "r", encoding="utf-8") as f:
    corpus_id_test = [json.loads(l) for l in f]
with open("datasets/Reversal/corpus_reversal_ood_4k_test.json", "r", encoding="utf-8") as f:
    corpus_od_test_w_negs = [json.loads(l) for l in f]

# Problem key and output directory used by every Reversal cell below
problem = "reversal"
out_dir_root = f"{root_out}/Reversal/"

def split_by_token(s):
    """
    Map the first character of `s` to the alphabet token that starts with it.

    Returns None when the first character does not start any known token.
    """
    first_char_to_token = {
        "c": "chtte",
        "g": "gfx",
        "%": "%",
        "l": "ltintprk",
        "¯": "¯\\_(ツ)_/¯",
    }
    return first_char_to_token.get(s[0])

def split_by_token_mirror(s):
    """
    Map the first character of `s` to the alphabet token that ENDS with it.

    This is the counterpart of split_by_token for text scanned from its right-hand end.
    Returns None when the character does not end any known token.
    """
    last_char_to_token = {
        "e": "chtte",
        "x": "gfx",
        "%": "%",
        "k": "ltintprk",
        "¯": "¯\\_(ツ)_/¯",
    }
    return last_char_to_token.get(s[0])


def get_prompt_for_reversal(point, use_desc="", max_exemplars=10, shuffle_exemplars=False, is_cot=False):
    """ Get the system and user prompts for this problem.
    It takes as input a (raw) datapoint (so no labels), and some extra params to make it shiny.
    ---
    Params:
    point (str): your datapoint
    use_desc (str or list of str, ""): the system prompt / description of the task (optional).
        A non-string value is treated as a word list: random words are sampled from it for the
        system prompt and for the CoT exemplar text.
    max_exemplars (int, 10): number of shots
    shuffle_exemplars (False): this shuffles exemplars positionally only. There's another notebook for the other bits.
    is_cot (False): do you need a full CoT exemplar?

    Returns:
    list of {"role", "content"} messages: an optional system message, one user/assistant pair per
    exemplar taken from the module-level corpus_id, and a final user message holding `point`.
    """
    plain_text = isinstance(use_desc, str)

    def salad():
        # one random word from the word list
        return ' '.join(random.sample(use_desc, k=1))

    def build_cot_exemplar(ex, lab):
        left, right = ex.split("#")
        exemplar = "Let's think and solve this step-by-step. "
        if plain_text:
            exemplar += f"We have the strings LEFT: {left} and RIGHT: {right}\n"
            exemplar += "We begin checking LEFT from the left and RIGHT from the right.\n"
        else:
            exemplar += f"{salad()} LEFT: {left} and RIGHT: {right}.\n"
            exemplar += f"{salad()}\n"
        ii = 0
        # Walk LEFT forwards and RIGHT backwards one token at a time. The loop spans the longer
        # string so that a length difference shows up as a mismatch (token None) instead of
        # wrapping around via a negative index or being reported as "the same".
        while ii < max(len(left), len(right)):
            mirror = len(right) - 1 - ii
            left_token = split_by_token(left[ii]) if ii < len(left) else None
            right_token = split_by_token_mirror(right[mirror]) if mirror >= 0 else None
            if plain_text:
                exemplar += f"We observe to the left: {left_token} and to the right: {right_token}.\n"
            else:
                exemplar += f"{salad()} left: {left_token} {salad()} right: {right_token}.\n"
            # An unrecognised or missing token can never count as a match
            if left_token is not None and left_token == right_token:
                if plain_text:
                    exemplar += "The tokens are the same, so we move to the next.\n"
                else:
                    exemplar += f"{salad()}.\n"
                ii += len(left_token)
            else:
                # NOTE(review): "the not same" looks like a typo for "not the same"; left untouched
                # so the exemplar text stays identical to previous runs.
                if plain_text:
                    exemplar += "The tokens are the not same, so we have our answer.\n"
                else:
                    exemplar += f"{salad()}.\n"
                break

        # Only claim equality when both strings were consumed completely
        if ii == len(left) == len(right):
            if plain_text:
                exemplar += "We have reached the end of the string, and both strings are the same.\n"
            else:
                exemplar += f"{salad()}.\n"

        exemplar += "So the answer is {}".format(lab)
        return exemplar

    prompt = []
    base_prompt = use_desc if plain_text else " ".join(random.sample(use_desc, k=random.randint(15, 40)))
    if base_prompt != "":
        prompt.append({"role": "system", "content": base_prompt})
    # The slice is a copy, so shuffling it leaves corpus_id untouched
    exemplars = corpus_id[:max_exemplars]
    if shuffle_exemplars:
        random.shuffle(exemplars)
    # Iterate over the (possibly shuffled) copy; indexing corpus_id here would discard the shuffle
    for exemplar in exemplars:
        prompt.append({"role": "user", "content": exemplar["Entry"] + ": "})
        ex = str(exemplar["Label"])
        if is_cot:
            ex = build_cot_exemplar(exemplar["Entry"], str(exemplar["Label"]))
        prompt.append({"role": "assistant", "content": ex})
    user_str = f"{point}: "
    prompt.append({"role": "user", "content": user_str})
    return prompt


## Modus ponens

In [None]:
suff = "modus_ponens_{}".format(problem)

# Same settings for both test corpora (ID first, then OOD); only the OOD call passes do_ood.
for _dataset, _extra_kwargs in ((corpus_id_test, {}), (corpus_od_test_w_negs, {"do_ood": True})):
    call_and_solve_for(problem=problem, suff=suff, dataset=_dataset, out_dir_root=out_dir_root,
                       shuffle_exemplars=False, **_extra_kwargs)

In [None]:
suff = "modus_ponens_{}_shuffled".format(problem)

# Same as the unshuffled run, but with shuffle_exemplars switched on; only the OOD call passes do_ood.
for _dataset, _extra_kwargs in ((corpus_id_test, {}), (corpus_od_test_w_negs, {"do_ood": True})):
    call_and_solve_for(problem=problem, suff=suff, dataset=_dataset, out_dir_root=out_dir_root,
                       shuffle_exemplars=True, **_extra_kwargs)

## With description

In [None]:
# Task description shared by the zero-shot and few-shot prompts
parity_desc = "This is a string detection task. The strings in this task are generated from a probabilistic automaton.\n"
parity_desc += "Each input is of the form LEFT#RIGHT. Each string is labelled 0 or 1 depending on whether the RIGHT pattern is (1) or is not (0) a reversal of LEFT.\n"
parity_desc += "Your job is to learn what is the likelihood of a string to be labelled 0 or 1, and output the correct label.\n"

# Zero-shot variant: asks for the label plus reasoning and shows one worked example
parity_desc_zero_shot = parity_desc + "Give your answer as a single integer, and your reasoning in a new line.\n"
# "\n" here, not "\a": "\a" is the ASCII BEL control character and also swallowed the example's leading "a"
parity_desc_zero_shot += "For example:\nabc#cba:1\nThe right string is a reversal of the left string.\n\n"

# Few-shot variant: label only, followed by the exemplars
parity_desc += "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n"
parity_desc += "Data:\n\n"

suff = f"w_description_{problem}"

call_and_solve_for(problem, suff=suff, dataset=corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)
call_and_solve_for(problem, suff=suff, dataset=corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)

suff = f"w_description_{problem}_shuffled"

call_and_solve_for(problem, suff=suff, dataset=corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)
call_and_solve_for(problem, suff=suff, dataset=corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)

## Word Salad

In [None]:
# Word-salad "description": the unique space-separated words on the first line of words.txt.
# It is a list rather than a str; get_prompt_for_reversal samples random words from a non-str use_desc
# (presumably call_and_solve_for forwards these prompts as use_desc — confirm).
with open("datasets/words.txt", "r", encoding="utf-8") as f:
    # sorted() gives a stable word order; the order of a bare set changes between kernel sessions
    parity_desc = sorted({l.strip().replace(".", " ") for l in f.readline().split(" ")})
parity_desc_zero_shot = parity_desc

suff = f"word_salad_{problem}"

call_and_solve_for(problem, suff=suff, dataset=corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)
call_and_solve_for(problem, suff=suff, dataset=corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)


suff = f"word_salad_{problem}_shuffled"

call_and_solve_for(problem, suff=suff, dataset=corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)
call_and_solve_for(problem, suff=suff, dataset=corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)
#10 0.85

## CoT

In [None]:
# Task description shared by the zero-shot and few-shot prompts
parity_desc = "This is a string detection task. The strings in this task are generated from a probabilistic automaton.\n"
parity_desc += "Each input is of the form LEFT#RIGHT. Each string is labelled 0 or 1 depending on whether the RIGHT pattern is (1) or is not (0) a reversal of LEFT.\n"
parity_desc += "Your job is to learn what is the likelihood of a string to be labelled 0 or 1, and output the correct label.\n"

# Zero-shot variant: asks for the label plus reasoning and shows one worked example
parity_desc_zero_shot = parity_desc + "Give your answer as a single integer, and your reasoning in a new line.\n"
# "\n" here, not "\a": "\a" is the ASCII BEL control character and also swallowed the example's leading "a"
parity_desc_zero_shot += "For example:\nabc#cba:1\nThe right string is a reversal of the left string.\n\n"

# Few-shot variant: label only, followed by the exemplars
parity_desc += "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n"
parity_desc += "Data:\n\n"

# Output suffix for the CoT runs, and the larger CoT token budget on the shared client
suffix = f"CoT_{problem}"
llm.update_params({tkey: cot_tokens})

In [None]:
# CoT evaluation without exemplar shuffling: ID corpus first, then OOD.
# `suffix` and both prompts are the ones defined in the CoT setup cell.
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=False, do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=True)

In [None]:
suffix = "CoT_{}_shuffled".format(problem)

# CoT evaluation with exemplar shuffling requested: ID corpus first, then OOD
for _dataset, _is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(problem, suffix, _dataset, SHOTS=[0, 5, 10, 20, 50, 100],
                       zero_shot_prompt=parity_desc_zero_shot, few_shot_prompt=parity_desc,
                       shuffle_exemplars=True, do_ood=_is_ood, out_dir_root=out_dir_root, is_cot=True)

## Salad-of-Thought

In [None]:
# Word list for Salad-of-Thought: the unique space-separated words on the first line of words.txt.
# A non-str description makes get_prompt_for_reversal fill its CoT exemplars with random words.
with open("datasets/words.txt", "r", encoding="utf-8") as f:
    # sorted() gives a stable word order; the order of a bare set changes between kernel sessions
    parity_desc = sorted({l.strip().replace(".", " ") for l in f.readline().split(" ")})
parity_desc_zero_shot = parity_desc


suffix = f"SoT_{problem}"
llm.update_params({tkey: cot_tokens})

call_and_solve_for(problem, suffix, corpus_id_test, SHOTS=[0, 5, 10, 20, 50, 100], #
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=True)

call_and_solve_for(problem, suffix, corpus_od_test_w_negs, SHOTS=[0, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=False,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=True)

In [None]:
# Word list for Salad-of-Thought: the unique space-separated words on the first line of words.txt.
# A non-str description makes get_prompt_for_reversal fill its CoT exemplars with random words.
with open("datasets/words.txt", "r", encoding="utf-8") as f:
    # sorted() gives a stable word order; the order of a bare set changes between kernel sessions
    parity_desc = sorted({l.strip().replace(".", " ") for l in f.readline().split(" ")})
parity_desc_zero_shot = parity_desc


suffix = f"SoT_{problem}_shuffled"
llm.update_params({tkey: cot_tokens})

call_and_solve_for(problem, suffix, corpus_id_test, SHOTS=[0, 5, 10, 20, 50, 100], #0, 
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=True)

call_and_solve_for(problem, suffix, corpus_od_test_w_negs, SHOTS=[0, 5, 10, 20, 50, 100],
                   zero_shot_prompt = parity_desc_zero_shot, few_shot_prompt = parity_desc, shuffle_exemplars=True,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=True)

## Automata encoded

In [None]:
ecm_str = "ALPHABET = [\"gfx\", \"chtte\", \"%\", \"ltintprk\", \"¯\\_(ツ)_/¯\"]\nMIN_LEN = 5\n\ndef reversal_tape(P):\n    # n + 1 states: ALPHABET + \"final\"\n    tape = []\n    end_state = \"stop\"\n    current_state = \"start\"\n    states = [a for a in ALPHABET] + [end_state]\n    while True:\n        next_state = random.choices(states, weights=P[current_state])[0]\n        if next_state == end_state:\n            break\n        else:\n            tape.append(next_state)\n        current_state = next_state\n\n    return tape\n"

parity_desc = "This is a string detection task. The strings in this task are generated from a probabilistic automaton, described in the code below.\n"
parity_desc += "Each input is of the form LEFT#RIGHT. Each string is labelled 0 or 1 depending on whether the RIGHT pattern is (1) or is not (0) a reversal of LEFT.\n"
parity_desc += "Your job is to learn what is the likelihood of a string to be labelled 0 or 1, and output the correct label.\n"

parity_desc += f"Here's the code:\n{ecm_str}\n"

parity_desc_zero_shot = parity_desc + "Give your answer as a single integer, and your reasoning in a new line.\n"
parity_desc_zero_shot += "For example:\abc#cba:1\nThe right string is a reversal of the left string.\n\n"

parity_desc += "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n"
parity_desc += "Data:\n\n"

# Evaluate the automaton-augmented prompt with exemplars in their original order.
suffix = f"w_automaton_{problem}"

# Reset the completion-token budget to the default (non-CoT) value before running.
llm.update_params({tkey: tokens})

# Same sweep over shot counts on the in-distribution test set first, then the OOD one.
for test_corpus, is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(
        problem,
        suffix,
        test_corpus,
        SHOTS=[0, 2, 5, 10, 20, 50, 100],
        zero_shot_prompt=parity_desc_zero_shot,
        few_shot_prompt=parity_desc,
        shuffle_exemplars=False,
        do_ood=is_ood,
        out_dir_root=out_dir_root,
        is_cot=False,
    )
# 988  -- NOTE(review): bare number left by the author; its meaning is not evident from this cell.

# Evaluate the automaton-augmented prompt again, this time with shuffled exemplars.
suffix = f"w_automaton_{problem}_shuffled"

# Same sweep over shot counts on the in-distribution test set first, then the OOD one.
for test_corpus, is_ood in ((corpus_id_test, False), (corpus_od_test_w_negs, True)):
    call_and_solve_for(
        problem,
        suffix,
        test_corpus,
        SHOTS=[0, 2, 5, 10, 20, 50, 100],
        zero_shot_prompt=parity_desc_zero_shot,
        few_shot_prompt=parity_desc,
        shuffle_exemplars=True,
        do_ood=is_ood,
        out_dir_root=out_dir_root,
        is_cot=False,
    )
# 988  -- NOTE(review): bare number left by the author; its meaning is not evident from this cell.

## APO

In [None]:
# Seed prompt for APO: the same task statement as the automaton variant, but without the generator code.
parity_desc = "This is a string detection task. The strings in this task are generated from a probabilistic automaton.\n"
parity_desc += "Each input is of the form LEFT#RIGHT. Each string is labelled 0 or 1 depending on whether the RIGHT pattern is (1) or is not (0) a reversal of LEFT.\n"
parity_desc += "Your job is to learn what is the likelihood of a string to be labelled 0 or 1, and output the correct label.\n"

# Zero-shot variant: branches off before the few-shot "Data:" trailer is appended.
parity_desc_zero_shot = parity_desc + "Give your answer as a single integer, and your reasoning in a new line.\n"
# The example must start with "\n" + "abc": the previous "\abc" was parsed as the BEL escape (\a)
# followed by "bc", which put a control character in the prompt and dropped the leading "a".
parity_desc_zero_shot += "For example:\nabc#cba:1\nThe right string is a reversal of the left string.\n\n"

parity_desc += "Given the data below, determine what is the most likely label for the given string and output ONLY the label.\n"
parity_desc += "Data:\n\n"

p_hat_candidates = apo(parity_desc, corpus_id, problem=problem) # Use training for optimisation
# The first candidate is kept: its first element is stored as the prompt and its last as the score.
# NOTE(review): presumably apo() returns candidates ordered best-first -- confirm against apo().
p_hat = p_hat_candidates[0][0]
score = p_hat_candidates[0][-1]

# Persist the chosen prompt as a single JSON record; the evaluation cells read it back from this path.
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "w", encoding="utf-8") as f:
    f.write(json.dumps({"Prompt": p_hat, "Score": score, "InitialPrompt": parity_desc, "OtherCandidates": p_hat_candidates, "Problem": problem}, ensure_ascii=False))

In [None]:
# Evaluate the APO-optimised prompt (exemplar shuffling left at call_and_solve_for's default).
suff = f"apo_{problem}"

# Load the optimised prompt from the first JSON record of the saved file.
# A context manager is used so the file handle is closed deterministically
# (the previous bare open() inside a comprehension never closed it).
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "r", encoding="utf-8") as f:
    p_hat = json.loads(f.readline())["Prompt"]
# APO is a zero-shot problem, but we'll humour ourselves and try it out with shots anyway.
call_and_solve_for(problem, suff, corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)

call_and_solve_for(problem, suff, corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100],
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)


In [None]:
# Evaluate the APO-optimised prompt with shuffled exemplars.
suff = f"apo_{problem}_shuffled"

# Load the optimised prompt from the first JSON record of the saved file.
# A context manager is used so the file handle is closed deterministically
# (the previous bare open() inside a comprehension never closed it).
with open(f"{out_dir_root}/apo_prompt_{problem}.json", "r", encoding="utf-8") as f:
    p_hat = json.loads(f.readline())["Prompt"]
# APO is a zero-shot problem, but we'll humour ourselves and try it out with shots anyway.
call_and_solve_for(problem, suff, corpus_id_test, SHOTS=[0, 2, 5, 10, 20, 50, 100], shuffle_exemplars=True,
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=False, out_dir_root=out_dir_root, is_cot=False)

call_and_solve_for(problem, suff, corpus_od_test_w_negs, SHOTS=[0, 2, 5, 10, 20, 50, 100], shuffle_exemplars=True,
                   zero_shot_prompt = p_hat, few_shot_prompt = p_hat,
                   do_ood=True, out_dir_root=out_dir_root, is_cot=False)
