In [None]:
import time
from typing import List
from dataclasses import dataclass
import numpy as np

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, LlamaForCausalLM
from transformers import pipeline

In [None]:
# ============== set optimization experiment configurations ================
# These module-level constants are read by the cells below.
num_points = 10  # number of points (cities) in the randomly generated TSP instance
num_steps = 500  # number of optimization steps (iterations of the main loop)
max_num_pairs = 10  # maximum number of (trace, length) exemplars shown in the meta-prompt
num_decimals = 0  # decimals kept for coordinates and distances; 0 means truncate to int
num_starting_points = 5  # number of random initial tours that seed the optimization
num_decode_per_step = 8 # number of valid decoded tours to collect in each step

In [None]:
# Draw the point coordinates uniformly from [-100, 100); x is sampled before y.
# NOTE: the global NumPy RNG is not seeded here, so every run gets a new instance.
x, y = (
    np.random.uniform(low=-100, high=100, size=num_points) for _ in range(2)
)

# Quantize the coordinates: round to `num_decimals` places, or truncate to
# plain Python ints when no decimals are requested.
if num_decimals > 0:
    x = [np.round(value, num_decimals) for value in x]
    y = [np.round(value, num_decimals) for value in y]
else:
    x = [int(value) for value in x]
    y = [int(value) for value in y]

In [None]:
def solve_tsp(x, y, num_points, num_decimals):
    """Solve the TSP instance exactly with a bitmask dynamic program.

    States are ``(last_point, visited_mask)`` pairs explored breadth-first
    from point 0. Every transition sets exactly one more bit, so states are
    processed in order of increasing number of visited points and the best
    partial path of a state is final by the time it is expanded.

    Args:
        x: sequence of x coordinates, indexable by point id.
        y: sequence of y coordinates, indexable by point id.
        num_points (int): number of points to visit (ids ``0..num_points-1``).
        num_decimals (int): decimals kept in the returned distance; values
            ``<= 0`` truncate the distance to an ``int``.

    Returns:
        tuple: ``(gt_sol, min_dis)`` where ``gt_sol`` is the optimal visiting
        order starting at point 0 and ``min_dis`` is the length of the closed
        tour (including the edge back to point 0).
    """
    # With zero or one point there is no edge to travel, so the tour length
    # is 0 (previously the internal -1 sentinel leaked out as the distance).
    if num_points <= 1:
        trivial_dis = np.round(0.0, num_decimals) if num_decimals > 0 else 0
        return list(range(num_points)), trivial_dis

    full_mask = (1 << num_points) - 1
    start = (0, 1)
    best = {start: (0, [0])}  # state -> (path length, path)
    queue = [start]  # FIFO order; `head` marks the next state to expand
    queued = {start}  # states still waiting in `queue`, for O(1) membership
    head = 0
    min_dis = -1  # -1 means "no complete tour found yet"
    gt_sol = list(range(num_points))

    while head < len(queue):
        state = queue[head]
        head += 1
        queued.discard(state)
        p, status = state
        base_dis, base_path = best[state]
        for i in range(num_points):
            bit = 1 << i
            if status & bit:
                continue  # point i is already on this partial path
            new_status = status | bit
            new_dis = base_dis + np.sqrt((x[i] - x[p]) ** 2 + (y[i] - y[p]) ** 2)
            key = (i, new_status)
            if key not in best or new_dis < best[key][0]:
                best[key] = (new_dis, base_path + [i])
                if new_status == full_mask:
                    # All points visited: close the tour back to point 0.
                    tour_dis = new_dis + np.sqrt(
                        (x[i] - x[0]) ** 2 + (y[i] - y[0]) ** 2
                    )
                    if min_dis == -1 or tour_dis < min_dis:
                        min_dis = tour_dis
                        gt_sol = best[key][1][:]
                elif key not in queued:
                    queue.append(key)
                    queued.add(key)
    min_dis = np.round(min_dis, num_decimals) if num_decimals > 0 else int(min_dis)

    return gt_sol, min_dis

# Compute the exact optimum once so that it can be reported and excluded from
# the random starting tours.
gt_sol, min_dis = solve_tsp(x, y, num_points, num_decimals)

print("ground truth solution" + str(gt_sol))
print("min distance: ", min_dis)
gt_sol_str = ",".join([str(i) for i in gt_sol])
point_list = range(num_points)

# Every tour starts at point 0. With fewer than 3 points the only such tour is
# the ground truth itself, so the sampling loop below could never finish.
if num_starting_points > 0 and num_points < 3:
    raise ValueError(
        "num_points must be at least 3 to sample starting tours that differ "
        "from the ground-truth solution"
    )

init_sols = []
while len(init_sols) < num_starting_points:
    # Keep point 0 first and shuffle only the remaining ids, instead of
    # drawing full permutations and discarding those that start elsewhere.
    sol = np.concatenate(([0], 1 + np.random.permutation(num_points - 1)))
    print(f"\r{len(init_sols)}-{num_starting_points} {sol}", end='')
    sol_str = ",".join([str(i) for i in sol])
    if sol_str == gt_sol_str:
        continue  # never seed the search with the optimal tour
    # Store plain Python ints so the tours display and serialize cleanly.
    init_sols.append(sol.tolist())
print()  # finish the carriage-return progress line
init_sols

In [None]:

def evaluate_distance(x, y, trace, num_decimals):  # pylint: disable=invalid-name
    """Return the length of the closed tour described by ``trace``.

    The tour visits the points in the order given and then returns from the
    last point to the first one.

    Args:
        x: sequence of x coordinates, indexable by point id.
        y: sequence of y coordinates, indexable by point id.
        trace: sequence of point ids in visiting order.
        num_decimals (int): decimals kept in the result; values ``<= 0``
            truncate the distance to an ``int``.

    Returns:
        The tour length, or ``-1`` when the trace is empty or contains an id
        that is not a valid point index. Negative ids count as invalid: they
        would otherwise silently wrap around to the end of the coordinate
        lists.
    """
    try:
        num_ids = len(trace)
        if num_ids == 0:
            return -1
        num_coords = min(len(x), len(y))
        dis = 0
        # The modulo makes the final iteration the closing edge back to the
        # first point, so it gets the same validation as every other edge.
        for pos in range(num_ids):
            id0 = trace[pos]
            id1 = trace[(pos + 1) % num_ids]
            if not 0 <= id0 < num_coords:
                return -1
            dis += np.sqrt((x[id0] - x[id1]) ** 2 + (y[id0] - y[id1]) ** 2)
    except (IndexError, KeyError, TypeError, ValueError):
        # Malformed trace (wrong type, non-integer id, ...): report as invalid.
        return -1
    return np.round(dis, num_decimals) if num_decimals > 0 else int(dis)

# Containers that record the optimization history.
old_value_pairs_set = set()  # unique (trace string, distance) pairs
old_value_pairs_with_i_step = []  # format: [(trace, dis = f(trace), i_step)]
meta_prompts_dict = {}  # format: {i_step: meta_prompt}
raw_outputs_dict = {}  # format: {i_step: raw_outputs}

# Score each random starting tour and register it under step index -1.
for sol in init_sols:
    dis = evaluate_distance(x, y, sol, num_decimals)
    sol_str = ",".join(map(str, sol))
    old_value_pairs_set.add((sol_str, dis))
    old_value_pairs_with_i_step.append((sol_str, dis, -1))

print("\n================ run optimization ==============")
print("initial points: " + str([(trace,) for trace, _ in old_value_pairs_set]))
print("initial values: " + str([length for _, length in old_value_pairs_set]))

In [None]:
# Build the Hugging Face text-generation pipeline used as the optimizer LLM.
# The instruction-tuned checkpoint is active; the base checkpoint is kept below
# as a commented-out alternative.
model_id = "meta-llama/Llama-3.2-3B-Instruct"
# model_id = "meta-llama/Llama-3.2-3B"
pipe = pipeline(
    "text-generation",
    model=model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # Optional 8-bit quantization, currently disabled:
    # model_kwargs={
    #     "quantization_config": {"load_in_8bit": True},
    # }
)
# NOTE(review): `device` is not referenced by any other code visible in this
# notebook; the pipeline's placement comes from device_map="auto" above.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:

def gen_meta_prompt(
    old_value_pairs_set,
    x,  # pylint: disable=invalid-name
    y,
    max_num_pairs=100,
):
    """Generate the meta-prompt for optimization.

    The prompt lists the point coordinates, then up to ``max_num_pairs`` of
    the shortest previously seen traces (longest first, so the best trace is
    last), and finally the instruction asking for a new, shorter trace.

    Args:
        old_value_pairs_set (set): ``(trace_string, length)`` pairs seen so far.
        x (sequence): the x coordinates, in point-id order.
        y (sequence): the y coordinates, in point-id order.
        max_num_pairs (int): the maximum number of exemplars in the
            meta-prompt. Values ``<= 0`` include no exemplars.

    Returns:
        meta_prompt (str): the generated meta-prompt.
    """
    # Sort by descending length so the shortest traces sit at the end.
    ranked_pairs = sorted(old_value_pairs_set, key=lambda pair: -pair[1])
    # A plain ``[-max_num_pairs:]`` slice would keep *everything* for 0,
    # because ``[-0:]`` is the whole list.
    old_value_pairs = ranked_pairs[-max_num_pairs:] if max_num_pairs > 0 else []

    old_value_pairs_substr = "\n".join(
        f"<trace> {trace} </trace> with length: {dis}"
        for trace, dis in old_value_pairs
    ).strip()
    coordinates = ", ".join(
        f"({i}): ({xi}, {yi})" for i, (xi, yi) in enumerate(zip(x, y))
    )

    meta_prompt = "You are given a list of points with coordinates below:\n"
    meta_prompt += coordinates
    meta_prompt += ".\nBelow are some previous traces and their lengths. The traces are arranged in descending order based on their lengths, where lower values are better."
    meta_prompt += "\n"
    meta_prompt += old_value_pairs_substr
    meta_prompt += "\n"
    meta_prompt += "Find a shorter trace that is different from previously mentioned traces and describe the newly found trace in natural language without the need for detail explanation. The trace should traverse all points exactly once. The trace should start with '<trace>' and end with </trace>."
    return meta_prompt

def extract_string(input_string):
    """Parse the point ids inside the first ``<trace>...</trace>`` pair.

    Args:
        input_string (str): text that may contain a tagged trace.

    Returns:
        list[int] | str: the comma-separated integers found between the first
        ``<trace>`` and the next ``</trace>``; items that are not integers are
        skipped. Returns the empty string ``""`` when either tag is missing
        (kept for backward compatibility; both results are falsy and have
        length 0 when nothing is found).
    """
    _, opening, remainder = input_string.partition("<trace>")
    if not opening:
        return ""
    body, closing, _ = remainder.partition("</trace>")
    if not closing:
        return ""

    parsed_list = []
    for item in body.split(","):
        try:
            parsed_list.append(int(item.strip()))
        except ValueError:
            # Not an integer (e.g. stray words from the model): skip it.
            continue
    return parsed_list

In [None]:
def chat(msgs):
    """Generate one assistant reply for the conversation ``msgs``.

    Calls the module-level ``pipe`` with sampling enabled, takes the content
    of the last message of the first returned generation, appends it to
    ``msgs`` in place as an ``assistant`` turn, and returns it.

    Args:
        msgs (list[dict]): chat messages with ``role`` and ``content`` keys;
            mutated in place.

    Returns:
        str: the content of the generated assistant message.
    """
    sampling_options = {
        "max_new_tokens": 512,
        "do_sample": True,
        "temperature": 0.8,
        "top_p": 0.95,
    }
    generations = pipe(msgs, **sampling_options)
    reply = generations[0]['generated_text'][-1]['content']
    msgs.append({"role": "assistant", "content": reply})
    return reply

# Main optimization loop: at every step build a meta-prompt from the best
# traces found so far, sample new traces from the LLM, and record the valid ones.
messages = []
valid_point_ids = list(range(num_points))
for i_step in range(num_steps):
    print(f"\nStep {i_step}:")
    meta_prompt = gen_meta_prompt(
        old_value_pairs_set,
        x,
        y,
        max_num_pairs=max_num_pairs,
    )
    # Start a fresh single-turn conversation at every step. The meta-prompt
    # already carries the relevant history, while accumulating every prompt
    # and reply would grow the context without bound over `num_steps` steps.
    messages = [{
        "role": "user",
        "content": meta_prompt
    }]

    print(meta_prompt)
    meta_prompts_dict[i_step] = meta_prompt
    raw_outputs = []
    parsed_outputs = []
    while len(parsed_outputs) < num_decode_per_step:
        # chat() appends its reply to the list it receives; pass a copy so
        # every sample is drawn from the same prompt-only conversation.
        raw_output = chat(list(messages))
        for string in raw_output.split("\n"):
            parsed_output = extract_string(string)
            # Accept only genuine tours: a permutation of all point ids that
            # starts at point 0. This also rejects negative or out-of-range ids.
            if (
                not parsed_output
                or parsed_output[0] != 0
                or sorted(parsed_output) != valid_point_ids
            ):
                continue
            dis = evaluate_distance(x, y, parsed_output, num_decimals)
            if dis == -1:
                continue
            sol_str = ",".join([str(i) for i in parsed_output])
            old_value_pairs_set.add((sol_str, dis))
            # Keep the per-step history in the documented
            # (trace, dis, i_step) format as well.
            old_value_pairs_with_i_step.append((sol_str, dis, i_step))
            parsed_outputs.append(parsed_output)
            raw_outputs.append(string)
    print(f"new trace with length {dis}")
    raw_outputs_dict[i_step] = raw_outputs