In [None]:
import dataclasses
import io
import traceback

from matplotlib import pyplot as plt
import numpy as np
import PIL.Image

import enact
import llm_task
import asteroids



# Generic instruction given to the LLM task; the concrete function template is
# supplied separately through `code_gen_prompt`.
task_prompt = '''
Write a python function with the following signature.
Answer with a single code block.
'''
# Template of the `control` function the LLM must implement. The signature has
# to be valid python (note the trailing colon) so that the model is not nudged
# towards reproducing a syntax error.
code_gen_prompt = '''

```
def control(position: np.ndarray,
            goal: np.ndarray,
            orientation: float) -> Tuple[float, float]:
  """Control an asteroids-style ship to move to the goal.
  
  Args:
    position: A 2D vector representing the position of the ship.
    goal: A 2D vector representing the position of the goal.
    orientation: A float representing the orientation of the ship in radians.
  Returns:
    A pair of floats representing:
      torque: The amount of torque to apply to the ship, between -1 and 1.
      thrust: The amount of thrust to apply to the ship, between 0.0 and 1.
  """
  <INSERT IMPLEMENTATION HERE>
```
'''

@enact.typed_invokable(enact.Str, enact.Image)
class PolicyVisualizer(enact.Invokable):
  """Visualizes a policy provided (as a string of python source)."""

  def call(self, code: enact.Str) -> enact.Image:
    """Plots four sample trajectories of the policy defined by `code`.

    Args:
      code: Python source that defines a function
        `control(position, goal, orientation)` returning two numbers.

    Returns:
      A PNG-backed image showing a 2x2 grid of trajectory plots.

    Raises:
      KeyError: If executing `code` does not define `control`.
    """
    # SECURITY: `exec` runs the (LLM-generated) source with full interpreter
    # privileges. Only use this with output you are willing to execute.
    def_dict = {}
    exec(code, def_dict)
    control = def_dict['control']

    def policy(state: asteroids.State) -> asteroids.Action:
      # The state fields are indexed with [0] and the action gets a new
      # leading axis; presumably this is a batch axis of size one matching the
      # `(1,)` passed to `create_trajectory` below -- TODO confirm.
      return asteroids.Action(np.array(list(control(
        state.position[0],
        state.goal_position[0],
        state.rotation[0])))[np.newaxis])

    fig, axes = plt.subplots(2, 2)
    try:
      # One independently generated trajectory per subplot, row by row.
      for ax in axes.flat:
        t = asteroids.create_trajectory(policy, (1,), steps=300)
        asteroids.plot_trajectory(t, axis=ax)
      b = io.BytesIO()
      # Save through the figure handle instead of pyplot's implicit "current
      # figure" so the right figure is written regardless of global state.
      fig.savefig(b, format='png')
    finally:
      # pyplot keeps every figure alive until it is closed; without this each
      # visualized policy would leak one figure.
      plt.close(fig)
    b.seek(0)  # Rewind so the image is read from the start of the buffer.
    return enact.Image(PIL.Image.open(b))


@enact.typed_invokable(enact.Str, llm_task.ProcessedOutput)
class PolicyChecker(enact.Invokable):
  """Checks LLM-generated policy code and requests a critique of it.

  The checks run in order and stop at the first failure: the answer must be a
  single fenced code block, the code must execute, it must define a callable
  `control`, a smoke-test call must succeed and yield two float-convertible
  values, and finally the critique requested for the plotted trajectories
  must be empty.
  """

  def call(self, input: enact.Str) -> llm_task.ProcessedOutput:
    """Validates `input` and returns either the code or a correction.

    Args:
      input: The raw answer, expected to be one block fenced by "```"
        (optionally opened with "```python").

    Returns:
      A `ProcessedOutput` with `output` set to the extracted code and
      `correction=None` on success; otherwise `output=None` and `correction`
      describing the first problem that was found.
    """
    fence = '```'
    python_fence = '```python'

    def reject(correction: str) -> llm_task.ProcessedOutput:
      # All failure paths share the same shape: no output, one correction.
      return llm_task.ProcessedOutput(output=None, correction=correction)

    # "```python" also starts with "```", so one check covers both openers.
    if not input.startswith(fence):
      return reject('Input must start with "```".')
    if not input.endswith(fence):
      return reject('Input must end with "```".')
    if fence in input[len(fence):-len(fence)]:
      return reject('Input must be a single code block.')

    # Strip exactly the opening and closing fence markers.
    opening = python_fence if input.startswith(python_fence) else fence
    code = input[len(opening):-len(fence)]

    # SECURITY: `exec` runs the (LLM-generated) source with full interpreter
    # privileges. Only use this with output you are willing to execute.
    def_dict = {}
    try:
      exec(code, def_dict)
    except Exception as e:
      return reject(
        f'Your code raised an exception while parsing: {e}\n{traceback.format_exc()}')

    control = def_dict.get('control')
    if not callable(control):
      return reject('Your code did not define a `control` function.')

    # Smoke test with a trivial input before showing anything to the user.
    try:
      result = control(np.zeros((2,)), np.zeros((2,)), 0.0)
    except Exception as e:
      return reject(
        f'Your code raised an exception while running: {e}\n{traceback.format_exc()}')

    try:
      # The prompt asks for (torque, thrust); only convertibility is checked.
      torque, thrust = result
      float(torque)
      float(thrust)
    except (TypeError, ValueError):
      # TypeError: result is not iterable or an element is not number-like.
      # ValueError: wrong number of elements or a non-numeric string.
      return reject('Your code could not be unpacked into two float values.')

    request_critique = enact.RequestInput(
      enact.Str, 'Please critique the policy or leave empty if ok.')
    critique = request_critique(PolicyVisualizer()(enact.Str(code)))

    if critique != '':
      return reject(f'User critique: {critique}')

    return llm_task.ProcessedOutput(output=code, correction=None)
      

# LLM task that writes code from a signature template. It is seeded with one
# worked example: a template for `add` and the fenced answer expected for it.
code_gen = llm_task.Task(task_prompt=task_prompt)
code_gen.add_example(
  '```def add(x: int, y: int):\n  <INSERT IMPLEMENTATION HERE>```',
  '```python\ndef add(x: int, y: int):\n  return x + y\n```',
)

@enact.typed_invokable(enact.NoneResource, enact.Str)
@dataclasses.dataclass
class CreatePolicy(enact.Invokable):
  """Input-free invokable that runs the code generation task."""
  # Task that is invoked with the `control` function template.
  code_gen: llm_task.Task

  def call(self):
    """Invokes `code_gen` on the module-level `code_gen_prompt`.

    Returns:
      Whatever the wrapped task returns for that prompt.
    """
    prompt = enact.Str(code_gen_prompt)
    return self.code_gen(prompt)


# Store used as the active context for the commits below and for the GUI cell.
store = enact.Store()
with store:
  # Attach a committed PolicyChecker as the task's post-processor.
  # NOTE(review): presumably the task feeds each LLM answer through it and
  # retries with the returned correction, up to `max_retries` times -- confirm
  # against llm_task.Task.
  code_gen.post_processor = enact.commit(PolicyChecker())
  code_gen.max_retries = 10
  # Top-level invokable that is later handed to the GUI.
  create_policy = CreatePolicy(code_gen)


In [None]:
with store:
  # Build a GUI around the committed `create_policy` invokable.
  # The Image/Str pair below mirrors the `enact.RequestInput` call made by
  # PolicyChecker, which requests a Str and is invoked with an Image.
  # NOTE(review): exact semantics of `input_required_outputs` and
  # `input_required_inputs` should be confirmed against the enact.GUI docs.
  gui = enact.GUI(
    enact.commit(create_policy),
    input_required_outputs=[enact.Image],
    input_required_inputs=[enact.Str])
  # NOTE(review): `share=True` presumably exposes the GUI beyond localhost;
  # this pipeline executes generated code, so confirm that this is intended.
  gui.launch(share=True)