Copyright 2023 Google LLC. SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

# Install packages

In [None]:
!pip install openai==0.28
!pip install tiktoken
!pip install tqdm
!pip install matplotlib

# Import packages

In [None]:
import numpy as np
import openai
import tiktoken
from tqdm.auto import trange, tqdm
import time
import os
import json
from tqdm import tqdm
import re
from types import NoneType
import multiprocessing.dummy
from io import StringIO
from contextlib import redirect_stdout
import signal
from contextlib import contextmanager
import matplotlib.pyplot as plt
import sys
import ast
import copy

# Set up API key

In [None]:
# Prefer the OPENAI_API_KEY environment variable so a real key never has to be
# pasted into (and shared with) the notebook; the placeholder remains as a
# fallback that shows where a key would otherwise go.
openai.api_key = os.environ.get("OPENAI_API_KEY", "your_openai_api_key")

# Global variables

In [None]:
# Model used for every LLM call; query_llm routes names containing 'instruct'
# to the completions endpoint and everything else to chat completions.
ENGINE: str = 'gpt-3.5-turbo'

# Text markers. The generated program is cut out of the LLM response between
# the two CODE_* markers. ANSWER_TOKEN and MAX_TOKENS are configuration values
# that the cells shown here do not read (NOTE(review): confirm before removing).
ANSWER_TOKEN: str = 'Answer: '
CODE_START_TOKEN, CODE_END_TOKEN = "# CODE START", "# CODE END"
MAX_TOKENS: int = 4096

# Tokenizer matching ENGINE.
ENCODER = tiktoken.encoding_for_model(ENGINE)

# Running tallies, incremented by print_result / evaluate_coc / the main loop.
TOTAL_CORRECT_ANSWERS = FAILED_RUNS = 0

# Helper functions

In [None]:
def query_llm(prompt, max_tokens, stop=None, temperature=0):
    """
    Send `prompt` to the configured LLM (`ENGINE`) and return its reply text.

    Models whose name contains 'instruct' are called through
    `openai.Completion`; all other models receive the prompt as a single user
    message through `openai.ChatCompletion`.

    Args:
      prompt: The complete prompt text.
      max_tokens: Maximum number of tokens the model may generate.
      stop: Optional stop sequence(s), forwarded unchanged to the API.
      temperature: Sampling temperature, forwarded unchanged to the API.

    Returns:
      The model's output with leading/trailing whitespace stripped.

    Raises:
      TypeError: if `prompt` is not a string.
    """
    # Validate the prompt type explicitly; `assert type(x)` is always truthy.
    if not isinstance(prompt, str):
        raise TypeError(f"prompt must be a str, got {type(prompt).__name__}")

    if 'instruct' in ENGINE:
        response = openai.Completion.create(
            prompt=prompt, model=ENGINE, max_tokens=max_tokens,
            temperature=temperature, stop=stop)
        return response.choices[0]["text"].strip()

    messages = [{"role": "user", "content": prompt}]
    response = openai.ChatCompletion.create(
        messages=messages, model=ENGINE, max_tokens=max_tokens,
        temperature=temperature, stop=stop)
    return response['choices'][0]['message']['content'].strip()


In [None]:
def print_result(method, response, answer, correct_answer):
  """
  Log one question's outcome to output.txt and bump the score on a match.

  Appends three headed sections (the full LLM response, the model's answer
  and the reference answer) to 'output.txt'. The global
  TOTAL_CORRECT_ANSWERS is incremented when both answers are equal after
  str(), strip() and lower().

  Args:
    method: Name of the prompting method (e.g. 'CoC'). Accepted for the
      caller's convenience; it is not written anywhere.
    response: Raw LLM output.
    answer: Answer produced by the method.
    correct_answer: Ground-truth answer.
  """
  global TOTAL_CORRECT_ANSWERS

  sections = (
      ("#### Full Response ####", response),
      ("#### Model Answer ####", answer),
      ("#### Correct Answer ####", correct_answer),
  )
  with open('output.txt', 'a') as log_file:
    for header, value in sections:
      print(header, file=log_file)
      print(value, file=log_file)

  def normalize(value):
    return str(value).strip().lower()

  if normalize(answer) == normalize(correct_answer):
    TOTAL_CORRECT_ANSWERS += 1



In [None]:
# Few-shot prompt for code generation: each example pairs a question with a
# Python program between "# CODE START" and "# CODE END" that stores its
# result in `answer`. Helpers such as get_adjective_type are deliberately
# undefined; when such a line raises, evaluate_coc/show_trace have the LLM
# simulate it instead. Dictionary keys used in each example must match the
# keys that example defines, otherwise the model learns a broken pattern.
code_generation_prompt = """Q: Which sentence has the correct adjective order:
Options:
(A) rubber terrible ship
(B) terrible rubber ship
A:
# CODE START
import numpy as np
options = {"(A)": "rubber terrible ship", "(B)": "terrible rubber ship"}
priority = {"opinion": 1, "size": 2, "age": 3, "shape": 4, "color": 5, "origin": 6, "material": 7, "purpose": 8}
valid_types = list(priority.keys())
scores = []
for option, sentence in options.items():
  adjectives = sentence.split(" ")[:-1]
  order = [priority[get_adjective_type(adjective, valid_types, ret_type=str)] for adjective in adjectives]
  scores.append([order[i+1] > order[i] for i in range(len(order) - 1)].count(True))
answer = list(options.keys())[np.argmax(scores)]
# CODE END

Q: Today is Christmas Eve of 1937. What is the date 10 days ago in MM/DD/YYYY?
A:
# CODE START
import datetime
options = {"12/14/2026": "(A)", "12/14/1950": "(B)", "12/14/2007": "(C)", "12/14/1937": "(D)", "07/14/1938": "(E)", "12/14/1988": "(F)"}
today = datetime.date(year=1937, month=12, day=24)
date = today - datetime.timedelta(days=10)
answer = date.strftime("%m/%d/%Y")
# CODE END

Q: Recommend a movie similar to Star Wars Episode IV - A New Hope, Indiana Jones and the Last Crusade, Star Wars Episode V - The Empire Strikes Back, The Big Lebowski:
A:
# CODE START
ref_movies = ["Star Wars Episode IV - A New Hope", "Indiana Jones and the Last Crusade", "Star Wars Episode V - The Empire Strikes Back", "The Big Lebowski"]
ref_movie_infos = get_movie_genre_and_year(ref_movies, ret_type="list[tuple[str, int]]")
answer = get_most_similar_movie(ref_movies, ref_movie_infos, ret_type="str")
# CODE END

Q: bhgvynv is a hamster, sppqg is an ant, so bhgvynv is more likely to spin on a wheel than sppqg. True or False?
A:
# CODE START
entities = {"bhgvynv": "hamster", "sppqg": "ant"}
answer = is_more_likely_to_spin_on_a_wheel(entities["bhgvynv"], entities["sppqg"])
# CODE END

Q: kabazpvx is a bee, ixex is a dolphin, so kabazpvx is not less likely to be in a meadow than ixex. True or False?
A:
# CODE START
entities = {"kabazpvx": "bee", "ixex": "dolphin"}
answer = is_not_less_likely_to_be_in_a_meadow(entities["kabazpvx"], entities["ixex"])
# CODE END

Q: pcmzhdkoqcs walks around town while qjopibkxim does not, so qjopibkxim is not a less explorer than pcmzhdkoqcs. True or False?
A:
# CODE START
entities = {"pcmzhdkoqcs": "walks around town", "qjopibkxim": "does not walk about town"}
answer = is_not_a_less_explorer(entities["qjopibkxim"], entities["pcmzhdkoqcs"])
# CODE END

Q: trej waits for jrxwtuuau's return while jrxwtuuau does not, so trej is not less patient than jrxwtuuau. True or False?
A:
# CODE START
entities = {"trej": "waits for return", "jrxwtuuau": "does not wait for return"}
answer = is_not_less_patient(entities["trej"], entities["jrxwtuuau"])
# CODE END

"""

In [None]:
# Few-shot traces for the LMulator. show_trace appends the last three lines of
# the live trace plus "delta state:" to this text and lets the LLM complete the
# state change of the line being simulated, so the prompt deliberately ends
# with an open "# TRACE START / state: {}" record. Every example must be
# internally consistent (variables read must be the ones defined).
lmulator_prompt = """# TRACE START
state: {}
line: adjective = "red"
delta state: {'adjective': 'red'}
line: valid_types = {"opinion", "size", "age", "shape", "color", "origin", "material", "purpose"}
delta state: {'valid_types': {'opinion', 'size', 'age', 'shape', 'color', 'origin', 'material', 'purpose'}}
line: adj_type = get_adjective_type(adjective, valid_types, ret_type=str)
delta state: {'adj_type': 'color'}
# TRACE END

# TRACE START
state: {}
line: obj1 = "soda can"
delta state: {'obj1': 'soda can'}
line: is_obj1_recyclable = is_recyclable(obj1, ret_type=bool)
delta state: {'is_obj1_recyclable': True}
line: obj2 = "fruit"
delta state: {'obj2': 'fruit'}
line: is_obj2_recyclable = is_recyclable(obj2, ret_type=bool)
delta state: {'is_obj2_recyclable': False}
# TRACE END

# TRACE START
state: {}
line: num1 = 23
delta state: {'num1': 23}
line: num2 = 52
delta state: {'num2': 52}
line: sum_of_two = num1 + num2
delta state: {'sum_of_two': 75}
line: greater_than_one_hundred = is_greater(sum_of_two, 100)
delta state: {'greater_than_one_hundred': False}
# TRACE END

# TRACE START
state: {}
line: entities = {"kabazpvx": "bee", "ixex": "dolphin"}
delta state: {'entities': {'kabazpvx': 'bee', 'ixex': 'dolphin'}}
line: answer = is_not_less_likely_to_be_in_a_meadow(entities["kabazpvx"], entities["ixex"])
delta state: {'answer': True}
# TRACE END

# TRACE START
state: {}"""

In [None]:
# State shared between evaluate_coc (which runs the generated program) and the
# show_trace tracer (which records it and has the LLM simulate failing lines).
errors: "dict[int, str]" = {}           # index of a wrapped (previously failing) line -> its original text
error_lineno: "int | None" = None       # first failing line index reported during the current run
lines: "list[str] | None" = None        # source lines of the program currently being executed
trace_lines: "list[str]" = []           # accumulated "state:/line:/delta state:" trace text
last_state: "dict | None" = None        # previous locals snapshot, used to compute deltas

def get_delta_state(state, last_state):
  """
  Report which variables changed between two consecutive state snapshots.

  An entry of `state` is reported when its name is absent from `last_state`
  or its value compares unequal (`!=`) to the previous value. Names that
  disappeared since `last_state` are not reported.

  Args:
    state: Current snapshot, mapping variable name -> value.
    last_state: Previous snapshot in the same format.

  Returns:
    A new dict with only the added or changed entries of `state`, in the
    iteration order of `state`.
  """
  return {
      name: value
      for name, value in state.items()
      if name not in last_state or value != last_state[name]
  }

def get_state(frame):
  """
  This method 'captures' and returns the local variables that currently in
  the frame that we pass as an argument
  """
  state = {}
  for key, item in frame.f_locals.items():
    if isinstance(item, (bool, str, int, float, tuple, list, set, dict, NoneType)):
      state[key] = item
  return state

def _record_line(frame, current_line):
  """
  Append the state change leading up to `current_line`, then the line itself,
  to the global trace, and store the current snapshot as `last_state`.

  The first recorded line of a run (no previous snapshot) is preceded by
  "state: {}"; every later one by "delta state: {...}".
  """
  global last_state

  state = get_state(frame)
  if last_state is None:
    state_line = "state: {}"
  else:
    state_line = f"delta state: {get_delta_state(state, last_state)}"
  # Deep-copy so in-place mutation of lists/dicts in the frame still shows up
  # as a change in the next delta.
  last_state = copy.deepcopy(state)
  trace_lines.append(state_line)
  trace_lines.append(f"line: {current_line}")


def show_trace(frame, event, arg):
  """
  sys.settrace hook that records and partly simulates the generated program.

  Only frames of the `get_answer` wrapper built by evaluate_coc are handled.
  For every other frame the hook returns None, so those frames are not traced
  line by line.

  On a "line" event (fired just before the line runs):
    * "try:", "except:" and "pass" lines are skipped;
    * "return answer" records the last delta and a "final state: ..." line;
    * an ordinary line (index not in `errors`) is recorded with the two
      spaces of wrapper indentation removed;
    * a line whose index is in `errors` (it failed before and is now wrapped
      in try/except) is recorded with four leading spaces removed. The LLM is
      then asked to predict its state change from the last three trace lines,
      and the predicted variables are written into `frame.f_locals`. This
      relies on CPython propagating f_locals edits made inside a trace
      function back to the frame.
  On an "exception" event, the index of the first failing line that is not
  yet wrapped is stored in the global `error_lineno` for evaluate_coc.

  Args:
    frame: The frame being traced.
    event: Trace event name ("call", "line", "exception", ...).
    arg: Event-specific argument (unused).

  Returns:
    show_trace itself for `get_answer` frames so tracing continues, else None.

  Raises:
    ValueError: if the LLM's prediction parses to something other than a dict.
      Parse errors from ast.literal_eval propagate unchanged.
  """
  global error_lineno

  # The LLM-generated code runs inside get_answer(); filtering on it keeps
  # unrelated notebook/runtime frames out of the trace.
  if frame.f_code.co_name != "get_answer":
    return None

  # f_lineno is 1-based; `lines` is the 0-based list of the exec'd source.
  lineno = frame.f_lineno - 1

  if event == "line":
    current_line = lines[lineno]
    stripped = current_line.strip()
    if stripped in ["try:", "except:", "pass"]:
      pass
    elif stripped == "return answer":
      # The wrapper's `return answer` is always the second-to-last line.
      assert lineno == len(lines) - 2, "return answer is at the wrong line" # Second to last line
      state = get_state(frame)
      assert last_state is not None
      trace_lines.append(f"delta state: {get_delta_state(state, last_state)}")
      trace_lines.append(f"final state: {state}")
    elif lineno not in errors:
      # evaluate_coc indented the generated code by two spaces; undo that.
      assert current_line[:2] == "  ", f"Python: actual line to run doesn't have two leading spaces: {current_line} {lines}"
      _record_line(frame, current_line[2:])
    else:
      # A wrapped line carries at least four leading spaces (wrapper + try:).
      assert current_line[:4] == "    ", f"LLM: actual line to run doesn't have four leading spaces: {current_line} {lines}"
      _record_line(frame, current_line[4:])

      # Due to the context length constraint, only feed in the last three
      # lines of the trace.
      prompt = lmulator_prompt + "\n" + "\n".join(trace_lines[-3:]) + "\n" + "delta state:"
      llm_result = query_llm(prompt, max_tokens=32, stop=["\nline:"])

      new_program_state = ast.literal_eval(llm_result.strip())
      if not isinstance(new_program_state, dict):
        raise ValueError("new program state is not a valid dict")
      # Inject the simulated values before the (failing) line executes.
      frame.f_locals.update(new_program_state)

  elif event == "exception":
    # Keep only the first failure of this run, and ignore lines that are
    # already wrapped in try/except.
    if error_lineno is None and lineno not in errors:
      error_lineno = lineno

  return show_trace

# Install the tracer as soon as this cell runs. The hook is per-thread and is
# cleared if it ever raises, which is why evaluate_coc and the main loop
# re-install it before running generated code.
sys.settrace(show_trace)

In [None]:
def evaluate_coc(prompt, query, correct_answer, max_trials=10):
  """
  Run one Chain-of-Code (CoC) question end to end and score it.

  Steps:
    1. Ask the LLM for code (few-shot `prompt` + `query`) and extract the text
       between CODE_START_TOKEN and CODE_END_TOKEN. If either marker is
       missing, FAILED_RUNS is incremented and the function returns.
    2. Indent the code by two spaces and wrap it in `get_answer()` so
       show_trace can recognise its frame, then exec it.
    3. Whenever a run raises, wrap the failing line reported by show_trace
       (`error_lineno`) in try/except/pass and run again. On the next run
       show_trace asks the LLM to simulate that line instead.
    4. When a run yields a non-None `answer`, score it with print_result.

  Args:
    prompt: Few-shot code-generation prompt.
    query: The question, formatted as "Q: ...".
    correct_answer: Reference answer used for scoring.
    max_trials: Maximum number of times the program is executed.

  Raises:
    RuntimeError: if a failure cannot be attributed to a line of the generated
      code (e.g. it does not compile), or no answer is produced within
      `max_trials` runs.
  """
  global errors
  global error_lineno
  global lines
  global trace_lines
  global last_state
  global FAILED_RUNS

  # Start from a clean slate: tracer state left over from a previous question
  # would make show_trace treat unrelated line indices as simulated lines.
  errors = {}
  error_lineno = None
  trace_lines = []
  last_state = None

  coc_response = query_llm(prompt + "\n\n" + query, max_tokens=1024)

  if CODE_START_TOKEN not in coc_response or CODE_END_TOKEN not in coc_response:
    FAILED_RUNS += 1
    return
  code_to_run = coc_response.split(CODE_START_TOKEN)[1].split(CODE_END_TOKEN)[0].strip()

  # Wrap the code inside get_answer(); show_trace strips the two extra spaces
  # again when recording lines.
  body = "\n".join("  " + l for l in code_to_run.split("\n"))
  code_to_run = f"""def get_answer():
{body}
  return answer
answer = get_answer()"""
  lines = code_to_run.split("\n")

  coc_answer = None
  for num_trial in range(max_trials):
    with open('output.txt', 'a') as f:
      print(f"Trial is {num_trial+1}", file=f)

    if sys.gettrace() is None:
      sys.settrace(show_trace)
    assert sys.gettrace() is not None, "get trace is None"

    # Fresh namespace per run; exec stores `get_answer` and `answer` here.
    local_vars = {}
    try:
      exec(code_to_run, globals(), local_vars)
      coc_answer = local_vars["answer"]
      assert coc_answer is not None
      break
    except Exception as e:
      if error_lineno is None:
        raise RuntimeError(
            "generated code failed without show_trace reporting a failing line"
        ) from e

      # Wrap the failing line in try/except/pass so the next run gets past it.
      line = lines[error_lineno]
      num_indent = len(line) - len(line.lstrip())
      lines[error_lineno] = "  " + line
      lines.insert(error_lineno, " " * num_indent + "try:")
      lines.insert(error_lineno + 2, " " * num_indent + "except:")
      lines.insert(error_lineno + 3, " " * (num_indent + 2) + "pass")
      code_to_run = "\n".join(lines)

      # The failing line moved down by one (below the new "try:") and every
      # later line moved down by three; shift earlier-recorded indices to match.
      errors = {(k + 3 if k > error_lineno else k): v for k, v in errors.items()}
      errors[error_lineno + 1] = line

      # Reset per-run tracer state.
      error_lineno = None
      trace_lines = []
      last_state = None
  else:
    raise RuntimeError(f"no answer produced after {max_trials} trials")

  print_result('CoC', coc_response, coc_answer, correct_answer)

In [None]:
# Evaluate every RICA statement with CoC and log the results to output.txt.
# Load the dataset first so the file is not held open during the long run.
with open("RICA_for_CoC.json", 'r') as f:
  data = json.load(f)

TOTAL_CORRECT_ANSWERS = 0
FAILED_RUNS = 0

for index, item in enumerate(data):
  # Re-install the tracer: Python clears it if it raised on an earlier item.
  sys.settrace(show_trace)

  with open('output.txt', 'a') as f:
    print(f"RICA Statement no.{index}", file=f)
  print(f"RICA Statement no.{index}")

  query = f"Q: {item['question']}".strip()

  try:
    evaluate_coc(code_generation_prompt, query, item['answer'])
  except Exception as e:
    # Keep the exception type: many failures here are bare AssertionErrors
    # whose message is empty.
    print(f"{type(e).__name__}: {e}")
    FAILED_RUNS += 1

  with open('output.txt', 'a') as f:
    print("\n", file=f)

# Stop tracing so later notebook cells do not pay for show_trace on every call.
sys.settrace(None)

summary = [
    f"Total correct answers: {TOTAL_CORRECT_ANSWERS}",
    f"Total failed runs: {FAILED_RUNS}",
]
with open('output.txt', 'a') as f:
  for message in summary:
    print(message, file=f)
for message in summary:
  print(message)