In [1]:
import torch

In [2]:
from dotenv import load_dotenv
# Load variables from a local .env file into the process environment.
# NOTE(review): presumably this provides the API key for the OpenAI client
# created further down — confirm.
load_dotenv()

True

In [12]:
def generate_synthetic_data(n_points, w_true, b_true):
    """Build a noisy linear dataset ``y = w_true * x + b_true + noise``.

    The x-coordinates are the integers ``1..n_points``; each y receives an
    independent ``torch.randn`` draw as additive noise.

    Args:
        n_points: Number of samples to generate.
        w_true: Slope of the underlying line.
        b_true: Intercept of the underlying line.

    Returns:
        A list of ``(x, y)`` tuples of plain Python numbers, ordered by
        increasing ``x``.
    """
    xs = torch.arange(1, n_points + 1)
    noise = torch.randn(n_points)
    ys = w_true * xs + b_true + noise
    # tolist() converts the tensors to plain Python scalars in one go.
    return list(zip(xs.tolist(), ys.tolist()))

def sample_starting_region(n_points, starting_region):
    """Draw random integer ``(w, b)`` starting pairs.

    Both coordinates are sampled with ``torch.randint`` using
    ``starting_region[0]`` as ``low`` and ``starting_region[1]`` as ``high``.

    Args:
        n_points: Number of pairs to draw.
        starting_region: Indexable ``[low, high]`` bounds shared by ``w`` and ``b``.

    Returns:
        A list of ``n_points`` tuples ``(w, b)`` of Python ints.
    """
    low = starting_region[0]
    high = starting_region[1]
    # All w values are drawn before any b value, so the random stream is
    # consumed in a fixed order.
    w_values = torch.randint(low=low, high=high, size=(n_points,)).tolist()
    b_values = torch.randint(low=low, high=high, size=(n_points,)).tolist()
    return list(zip(w_values, b_values))

# Draw the 5 integer (w, b) starting pairs that seed every optimisation run.
# NOTE(review): no random seed is set in the visible cells, so these pairs
# will differ after a kernel restart.
init_data = sample_starting_region(n_points=5, starting_region=[10, 20])  

In [13]:
# Show the sampled starting pairs as this cell's output.
init_data

[(15, 17), (18, 10), (17, 15), (10, 14), (11, 19)]

In [14]:
from openai import OpenAI

class OpenAIWrapper:
    """Callable convenience wrapper around the OpenAI chat-completions client."""

    def __init__(self, model="gpt-3.5-turbo-1106"):
        """Create the API client and remember which model to query.

        Args:
            model: Name of the chat model sent with every request.
        """
        self.model = model
        # NOTE(review): OpenAI() is built with no arguments, so credentials
        # presumably come from the environment — confirm.
        self.client = OpenAI()

    def __call__(self, system_prompt, user_prompt, temperature):
        """Send one system + user message pair and return the reply text.

        Args:
            system_prompt: Content of the ``system`` message.
            user_prompt: Content of the ``user`` message.
            temperature: Sampling temperature forwarded to the API.

        Returns:
            The message content of the first choice in the completion.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=temperature,
        )
        first_choice = response.choices[0]
        return first_choice.message.content

# Module-level LLM handle (default model) that generate_new_solutions calls.
llm = OpenAIWrapper()

In [34]:
import re

class ObjectiveFunction:
    """Sum-of-squared-errors loss of the line ``y = w * x + b`` on a fixed dataset."""

    def __init__(self, synthetic_data):
        """Keep the iterable of ``(x, y)`` samples the loss is evaluated on."""
        self.synthetic_data = synthetic_data

    def __call__(self, w, b):
        """Return the total squared error for ``(w, b)``, truncated to ``int``.

        Args:
            w: Candidate slope.
            b: Candidate intercept.

        Returns:
            ``int`` of the summed squared residuals over all stored samples.
        """
        def squared_error(sample):
            x, y = sample
            prediction = w * x + b
            return (y - prediction) ** 2

        total = sum(map(squared_error, self.synthetic_data))
        return int(total)

def generate_new_solutions(pairs, n=8):
    """Ask the LLM for ``n`` candidate ``(w, b)`` pairs that improve on ``pairs``.

    The prompt lists every supplied solution with its objective value and
    asks for one new pair with a lower value. The same prompt is sent ``n``
    times at temperature 1; replies from which no pair can be parsed are
    dropped.

    Args:
        pairs: Sequence of ``(w, b, value)`` triples. It is iterated in
            reverse, so a best-first (ascending value) input produces the
            descending order the prompt text promises.
        n: Number of independent completions to request.

    Returns:
        A list of parsed ``(w, b)`` tuples in reply order. It may hold fewer
        than ``n`` items, or be empty, when replies cannot be parsed.
    """
    system_prompt = "You are a helpful assistant"

    prompt_parts = [
        "Now you will help me minimize a function with two input variables w, b. "
        "I have some (w, b) pairs and the function values at those points. "
        "The pairs are arranged in descending order based on their function values, "
        "where lower values are better.\n\n"
    ]
    for w, b, value in reversed(pairs):
        prompt_parts.append(f"Input:\nw={w}, b={b}\nValue:\n{value}\n\n")
    prompt_parts.append(
        "Give me a new (w, b) pair that is different from all pairs above, "
        "and has a function value lower than all of the above. "
        "Do not write code. Do not ask me questions. "
        "Please ensure the output contains the new pair in this format: [w, b], "
        "where w and b are integer values."
    )
    user_prompt = "".join(prompt_parts)

    replies = [llm(system_prompt, user_prompt, temperature=1) for _ in range(n)]
    # Parse each reply exactly once, then filter out the unparseable ones.
    candidates = (parse_pair(reply) for reply in replies)
    return [pair for pair in candidates if pair is not None]

def parse_pair(text):
    """Extract a ``(w, b)`` pair from an LLM reply.

    Two notations are recognised: ``w=<num>, b=<num>`` (tried first) and
    ``[<num>, <num>]``. Only the first occurrence of the matching notation
    is used.

    Args:
        text: Raw reply text.

    Returns:
        ``(w, b)`` where each element is an ``int``, or a ``float`` when the
        number has digits after a decimal point; ``None`` when neither
        notation is found.
    """
    def to_number(token):
        # The pattern allows a trailing "." so a sentence-ending period
        # ("... b=14.") gets captured; strip it so the value stays an int.
        token = token.rstrip('.')
        return float(token) if '.' in token else int(token)

    match = (
        re.search(r'w\s*=\s*(-?\d+\.?\d*)\s*,\s*b\s*=\s*(-?\d+\.?\d*)\s*', text)
        or re.search(r'\[\s*(-?\d+\.?\d*)\s*,\s*(-?\d+\.?\d*)\s*\]', text)
    )
    if match is None:
        return None

    w_token, b_token = match.groups()
    return to_number(w_token), to_number(b_token)

In [35]:
from collections import namedtuple
from operator import attrgetter

# One evaluated candidate: parameters (w, b) and their objective value.
Solution = namedtuple("Solution", "w b value")

# Ground-truth (w, b) pairs the optimiser has to recover.
true_data = [(15, 14), (17, 17), (16, 10)]
# Maps each true pair to the list of step counts at which replicas found it.
results = {}
replicas = 5
max_steps = 30
for w_true, b_true in true_data:
    steps = []
    for _ in range(replicas):
        # Each replica gets fresh noisy data but the same starting pairs.
        synthetic_data = generate_synthetic_data(n_points=50, w_true=w_true, b_true=b_true)
        objective_function = ObjectiveFunction(synthetic_data)
        history = [Solution(w, b, objective_function(w, b)) for w, b in init_data]
        for step in range(1, max_steps + 1):
            # Only the 20 lowest-valued solutions seen so far go into the prompt.
            best_pairs = sorted(history, key=attrgetter("value"))[:20]
            new_solutions = generate_new_solutions(best_pairs, n=8)
            if not new_solutions:
                # No reply could be parsed; min() would raise ValueError on an
                # empty list, so this step is spent and the loop tries again.
                continue
            # Score every candidate once; min() keeps the first of any ties.
            scored = [Solution(w, b, objective_function(w, b)) for w, b in new_solutions]
            best = min(scored, key=attrgetter("value"))
            best_solution = (best.w, best.b)
            history.append(best)
            if best_solution == (w_true, b_true):
                steps.append(step)
                break
        # NOTE: a replica that never hits the true pair within max_steps adds
        # nothing to `steps`, so the statistics cover converged replicas only.
    results[(w_true, b_true)] = steps
        

In [38]:
# Report, for every true (w, b) pair in `results`, the mean and standard
# deviation of its recorded step counts.
import numpy as np
for true_pair, step_counts in results.items():
    mean_steps = np.mean(step_counts)
    std_steps = np.std(step_counts)
    print(f"{true_pair}: {mean_steps} +- {std_steps}")

(15, 14): 8.8 +- 4.445222154178574
(17, 17): 13.0 +- 4.09878030638384
(16, 10): 8.0 +- 4.06201920231798
