In [1]:
import dataclasses

import numpy as np

import enact
import chat
import llm_task

import asteroids_2goal as asteroids
import demo_utils

# Instruction text passed to llm_task.Task as `task_prompt` when `code_gen`
# is constructed further down in this cell.
task_prompt = '''
Write a concise python function with the following signature.
Answer with a single code block.'''
# Function stub that CreatePolicy.call wraps in enact.Str and sends to the
# task. The <INSERT IMPLEMENTATION HERE> marker is the part the model is asked
# to fill in (same placeholder convention as the example registered on
# `code_gen`).
# NOTE(review): the stub refers to `np` and `Tuple` without importing them;
# presumably the policy checker provides those names when it runs the
# generated code -- confirm.
code_gen_prompt = '''

```
def control_two_goal(position: np.ndarray,
                     goal1: np.ndarray,
                     goal2: np.ndarray,
                     orientation: float) -> Tuple[float, float]:
  """Control an asteroids-style ship to move to both goals.
  
  Args:
    position: A 2D vector representing the position of the ship.
    goal1: A 2D vector representing the position of the first goal.
    goal2: A 2D vector representing the position of the second goal.
    orientation: A float representing the orientation of the ship in radians.
  Returns:
    A pair of floats representing:
      torque: The amount of torque to apply to the ship, between -1 and 1.
      thrust: The amount of thrust to apply to the ship, between 0.0 and 1
  """
  <INSERT IMPLEMENTATION HERE>
```
'''

# Code-generation task: a GPT-4 chat agent driven by `task_prompt`.
gpt4_agent = chat.GPTAgent(model='gpt-4')
code_gen = llm_task.Task(chat_agent=gpt4_agent, task_prompt=task_prompt)

# Register one worked example: a stub containing the placeholder, followed by
# the completed function wrapped in a python code fence.
code_gen.add_example(
  '''```def add(x: int, y: int):\n  <INSERT IMPLEMENTATION HERE>```''',
  '''```python\ndef add(x: int, y: int):\n  return x + y\n```''',
)

@enact.typed_invokable(enact.NoneResource, enact.Str)
@dataclasses.dataclass
class CreatePolicy(enact.Invokable):
  """Invokable that asks the code-generation task for a policy.

  Attributes:
    code_gen: Task that is called with the code-generation prompt.
  """
  code_gen: llm_task.Task

  def call(self):
    """Calls `code_gen` on `code_gen_prompt` and returns whatever it yields."""
    request = enact.Str(code_gen_prompt)
    return self.code_gen(request)


# Dummy arguments handed to the policy checker: three distinct 2-element zero
# vectors followed by a float. Presumably these line up with
# control_two_goal's (position, goal1, goal2, orientation) parameters.
sample_input = (*(np.zeros(2) for _ in range(3)), 0.0)

# Build the checker class from the asteroids environment hooks; it is
# instantiated and committed as the task's post-processor later in this cell.
PolicyChecker = demo_utils.get_policy_checker(
  asteroids.Action,
  asteroids.create_trajectory,
  asteroids.plot_trajectory,
  sample_input,
  func_name='control_two_goal',
)

# Store with the default backend. To persist results between sessions, pass
# `backend=enact.FileBackend(<path>)` with a project-relative or configurable
# path instead of a machine-specific absolute one.
store = enact.Store()
with store:
  # Commit a checker instance and attach its reference as the task's
  # post-processor.
  post_processor_ref = enact.commit(PolicyChecker())
  code_gen.post_processor = post_processor_ref
  # Presumably the number of extra attempts allowed when the post-processor
  # rejects an answer -- confirm against llm_task.Task.
  code_gen.max_retries = 2
  create_policy = CreatePolicy(code_gen)

In [2]:
import enact.gradio as gradio

with store:
  # Commit the policy creator first, then hand its reference to the GUI.
  create_policy_ref = enact.commit(create_policy)
  gui = gradio.gradio.GUI(
    create_policy_ref,
    input_required_inputs=[enact.Str],
    input_required_outputs=[enact.Image])
  # share=True also serves the app on a public gradio.live URL (see the cell
  # output); with debug=True the call stays in the foreground until
  # interrupted.
  gui.launch(debug=True, share=True)

  from .autonotebook import tqdm as notebook_tqdm


Running on local URL:  http://127.0.0.1:7861
Running on public URL: https://a4bd7ee2eb9400d8b2.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7861 <> https://a4bd7ee2eb9400d8b2.gradio.live
