In [32]:
import requests
import json
import numpy as np
import re
import os

import matplotlib.pyplot as plt

In [157]:
import openai

# Read the API key from the environment so that a real credential never has to
# be typed into (and saved with) the notebook. The original placeholder is kept
# as the fallback value when OPENAI_API_KEY is not set.
openai.api_key = os.environ.get("OPENAI_API_KEY", "YOUR_OPENAI_KEY")

# Completion model passed as `engine` by LLM().
model = "gpt-3.5-turbo-instruct"
# LLM() subtracts the prompt length from this value to obtain max_tokens.
token_limit = 4096

def LLM(prompt, prompt_len, temperature=0):
  """Request a completion for `prompt` and return the text of the first choice.

  Args:
    prompt: Text sent to the completion endpoint.
    prompt_len: Amount subtracted from the module-level `token_limit` to get
      the `max_tokens` value of the request.
      NOTE(review): max_tokens is a token budget, so this presumably should be
      the prompt's token count - confirm what callers pass.
    temperature: Sampling temperature forwarded to the API (default 0).

  Returns:
    The 'text' field of the first entry in the response's 'choices'.
  """
  completion_budget = token_limit - prompt_len
  completion = openai.Completion.create(
      engine=model,
      prompt=prompt,
      max_tokens=completion_budget,
      temperature=temperature,
      stop=['\n---'],  # generation stops at the example separator
  )
  candidate_texts = list(map(lambda choice: choice['text'], completion['choices']))
  return candidate_texts[0]

In [147]:
# Disabled alternative: the same LLM() helper written against the chat
# completions endpoint ("gpt-3.5-turbo"). It is wrapped in a string literal so
# that none of it is executed. Note that this variant accepts `prompt_len` but
# never uses it, and does not pass a stop sequence.
"""
import openai

openai.api_key = "YOUR_OPENAI_KEY"


def LLM(prompt, prompt_len, temperature=0):
    model = "gpt-3.5-turbo"
  
    response = openai.ChatCompletion.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}],
        temperature=temperature)
    text = response['choices'][0]['message']['content']
    return text
"""

In [158]:
alphabet = [str(i) for i in range(10)] # classic
#alphabet = ['G', 'D', 'E', 'B', 'A', 'C', 'F', 'H', 'I', 'J']

# Lookup tables between a digit value (the position in `alphabet`) and its
# token. They are built once when this cell runs instead of on every call, so
# re-run the cell after switching to a different `alphabet`.
_value_to_token_map = dict(enumerate(alphabet))
_token_to_value_map = {token: value for value, token in enumerate(alphabet)}

def value_to_token(x):
    """Return the token at position `x` of `alphabet`.

    Raises:
        KeyError: if `x` is not a valid position in the alphabet.
    """
    return _value_to_token_map[x]

def token_to_value(x):
    """Return the position in `alphabet` of token `x`.

    Raises:
        KeyError: if `x` is not a token of the alphabet.
    """
    return _token_to_value_map[x]

def replace_numbers_with_tokens(s):
    """Return `s` with every decimal digit replaced by its alphabet token.

    Each digit character is converted with int() and mapped through
    value_to_token(); every other character is copied unchanged.

    Args:
        s: String to encode.

    Returns:
        The encoded string.
    """
    # str.isdecimal() is used rather than str.isdigit(): isdigit() is also true
    # for characters such as superscript digits ('²'), which int() rejects with
    # a ValueError. Those characters now pass through unchanged.
    return ''.join(
        value_to_token(int(char)) if char.isdecimal() else char
        for char in s
    )

def replace_tokens_with_numbers(s):
    """Return `s` with every alphabet token replaced by its digit value.

    A character found in `alphabet` is mapped through token_to_value() and
    written back as the decimal string of that value; any other character is
    copied unchanged.
    """
    decoded = [
        str(token_to_value(symbol)) if symbol in alphabet else symbol
        for symbol in s
    ]
    return ''.join(decoded)

In [159]:
# RGB palette taken from the code of the "LLMs as general pattern machines" paper.
# grid_to_img() indexes this list with a cell's integer value, so entry i is the
# (R, G, B) colour drawn for value i; there are ten entries (values 0-9).
colors = [(0, 0, 0),
          (0, 116, 217),
          (255, 65, 54),
          (46, 204, 6),
          (255, 220, 0),
          (170, 170, 170),
          (240, 18, 190),
          (255, 133, 27),
          (127, 219, 255),
          (135, 12, 37)]

def grid_to_img(grid, scale=10, palette=None):
  """Render an integer grid as an RGB image with light separator lines.

  Args:
    grid: 2-D array-like of integers; each value is used as an index into
      `palette`.
    scale: Pixel pitch of one cell. A cell is painted as a
      (scale-1) x (scale-1) square and a one-pixel separator line is drawn
      every `scale` pixels. Defaults to 10, the value that used to be
      hard-coded.
    palette: Sequence of RGB triples indexed by cell value. Defaults to the
      module-level `colors` list.

  Returns:
    A uint8 array of shape (rows*scale + 1, cols*scale + 1, 3).
  """
  if palette is None:
    palette = colors
  grid = np.int32(grid)
  rows, cols = grid.shape[0], grid.shape[1]
  img = np.zeros((rows * scale + 1, cols * scale + 1, 3), dtype=np.uint8)
  for r in range(rows):
    for c in range(cols):
      # Fill the cell interior; the pixels at multiples of `scale` stay black.
      img[r*scale+1:(r+1)*scale, c*scale+1:(c+1)*scale, :] = palette[grid[r, c]]
  # Separator lines: blend 70% of the underlying pixel with 30% white. The
  # stride follows `scale` (it was a literal 10 before, which only matched the
  # cell size for scale == 10).
  new_img = img.copy()
  new_img[0::scale, :, :] = np.uint8(np.round((0.7 * np.float32(img[0::scale, :, :]) + 0.3 * 255)))
  new_img[:, 0::scale, :] = np.uint8(np.round((0.7 * np.float32(img[:, 0::scale, :]) + 0.3 * 255)))
  return new_img

def _show_example(ex):
  """Plot one example: input grid on the left, output grid on the right."""
  plt.subplot(1, 2, 1); plt.imshow(grid_to_img(ex["input"]))
  plt.subplot(1, 2, 2); plt.imshow(grid_to_img(ex["output"]))
  plt.show()

def show_task(task):
  """Display the training pairs and the first test pair of an ARC task.

  Args:
    task: Mapping with "train" and "test" lists; every entry has an "input"
      and an "output" grid accepted by grid_to_img().

  Each grid is rendered once (previously every grid was converted twice and
  the first result discarded).
  """
  print("TRAIN:")
  for ex in task["train"]:
    _show_example(ex)
  print("TEST:")
  # Only the first test example is displayed, as it is the only one that is
  # sent to the model.
  _show_example(task["test"][0])

In [162]:
def load_task(filename):
    """Open `filename` for reading and return its parsed JSON content."""
    with open(filename, 'r') as task_file:
        return json.load(task_file)

def process_arc_task(task):
    """Return a copy of `task` whose grids are numpy arrays.

    The result has the keys "train" and "test"; each is a list of dicts with
    the "input" and "output" of the corresponding entry passed through
    np.array().
    """
    def as_arrays(examples):
        # Convert both grids of every example, keeping the original order.
        converted = []
        for example in examples:
            converted.append({
                "input": np.array(example["input"]),
                "output": np.array(example["output"]),
            })
        return converted

    return {"train": as_arrays(task["train"]), "test": as_arrays(task["test"])}

def numpy_array_to_string(arr):
    """Serialise a 2-D array as text: cells joined by ',' and rows by newlines."""
    lines = []
    for row in arr:
        lines.append(','.join(str(cell) for cell in row))
    return '\n'.join(lines)

def string_to_numpy_array(s):
    """Parse text produced by numpy_array_to_string() back into an integer array.

    Rows are separated by newlines and cells by commas. Whitespace (including
    newlines) at the start and end of `s` is ignored, so a trailing newline no
    longer creates an empty row that fails to parse.

    Args:
        s: The text to parse.

    Returns:
        A numpy array with one row per line of `s`.

    Raises:
        ValueError: if a cell is not an integer literal (this includes empty
            input and empty cells).
    """
    rows = s.strip().split('\n')
    return np.array([[int(cell) for cell in row.split(',')] for row in rows])

def task_to_prompt(task):
    """Build the few-shot prompt for a task.

    Every training pair is written as an "input:" grid followed by an
    "output:" grid and a "---" separator line. The prompt ends with the first
    test input and an open "output:" header for the model to complete.
    """
    sections = [
        "input:\n" + numpy_array_to_string(demo['input']) + "\n"
        + "output:\n" + numpy_array_to_string(demo['output']) + "\n---\n"
        for demo in task['train']
    ]
    # Only the first test example is included in the prompt.
    sections.append(
        "input:\n" + numpy_array_to_string(task['test'][0]['input']) + "\n" + "output:\n"
    )
    return "".join(sections)

# os.listdir() returns names in arbitrary order, so the listings are sorted:
# tasks_ids below selects tasks by position, and that selection must be the
# same on every run and every machine.
tasks_training = sorted(os.listdir("ARC/data/training/"))
tasks_evaluation = sorted(os.listdir("ARC/data/evaluation/"))

# All training tasks first, then all evaluation tasks, each in filename order.
tasks_json = []

for file in [os.path.join("ARC/data/training/", file) for file in tasks_training]:
    tasks_json.append(load_task(file))

for file in [os.path.join("ARC/data/evaluation/", file) for file in tasks_evaluation]:
    tasks_json.append(load_task(file))

#tasks_ids = [6, 14, 33, 34, 37, 59, 72, 79, 85, 89, 96]
#tasks_json = [tasks_json[i] for i in tasks_ids]

# Keep only the tasks at positions 200..399 of the combined list.
tasks_ids = list(range(200, 400))
tasks_json = [tasks_json[i] for i in tasks_ids]

In [163]:
# Prompts longer than this many characters are skipped.
MAX_PROMPT_CHARS = 4000

# Raw (decoded) model answer per task, keyed by the task's position in tasks_json.
responses = {}

for task_index, task in enumerate(tasks_json):
    # Label every line with the task's real index taken from tasks_ids instead
    # of a hard-coded "+200" offset, so the labels stay correct when tasks_ids
    # is changed. (The loop variable is no longer called `id`, which shadowed
    # the builtin.)
    task_label = str(tasks_ids[task_index])

    task_np = process_arc_task(task)
    prompt = task_to_prompt(task_np)
    prompt = replace_numbers_with_tokens(prompt)

    if len(prompt) > MAX_PROMPT_CHARS:
        print(task_label + " task too long (prompt len={})".format(len(prompt)))
        continue

    # NOTE(review): len(prompt) is a character count, while LLM() subtracts it
    # from token_limit to get max_tokens - confirm this approximation is intended.
    response = LLM(prompt, len(prompt))
    response = replace_tokens_with_numbers(response)
    responses[task_index] = response

    try:
        response_np = string_to_numpy_array(response)
    except Exception:
        # Deliberate best-effort: any answer that cannot be parsed into a grid
        # is reported and counted as a miss, and the loop moves on.
        print(task_label + " decoding error (task len : {})".format(len(prompt)))
        continue

    # Score: exact match against the expected output of the first test example.
    correct_np = np.array(task_np['test'][0]['output'])
    print(task_label + " " + str(np.array_equal(response_np, correct_np)))

200 False
201 False
202 False
203 False
204 True
205 False
206 False
207 False
208 False
209 task too long (prompt len=6246)
210 False
211 task too long (prompt len=5672)
212 False
213 False
214 False
215 False
216 False
217 True
218 False
219 False
220 False
221 True
222 False
223 False
224 False
225 False
226 True
227 False
228 False
229 False
230 False
231 False
232 True
233 False
234 False
235 False
236 False
237 False
238 False
239 False
240 task too long (prompt len=5672)
241 False
242 False
243 False
244 False
245 False
246 True
247 True
248 False
249 False
250 False
251 True
252 task too long (prompt len=7988)
253 False
254 decoding error (task len : 3283)
255 True
256 task too long (prompt len=16291)
257 False
258 False
259 True
260 False
261 False
262 False
263 False
264 False
265 False
266 False
267 True
268 False
269 False
270 False
271 False
272 False
273 False
274 False
275 False
276 False
277 False
278 False
279 False
280 False
281 False
282 False
283 True
284 False
285 