Merged
Changes from all commits
28 commits
ecb6419
Remove `-accurate` mode until fixed
joshbickett Jan 4, 2024
3b4b9c4
Add `DECISION_PROMPT` and `LABELED_IMAGE_PROMPT`
joshbickett Jan 4, 2024
73d59a7
Add `labeling.py`
joshbickett Jan 4, 2024
511c918
Move `prompt` functions to `prompt.py`
joshbickett Jan 4, 2024
a2e7194
Rename `prompt.py` to `prompts.py`
joshbickett Jan 4, 2024
e98df7a
update `ansi_colors.py` to `styles.py`, and other small changes
joshbickett Jan 4, 2024
2a72434
Add prompt function `format_decision_prompt`, etc.
joshbickett Jan 4, 2024
af47ea0
Update file name to `actions.py` until we think of something better
joshbickett Jan 5, 2024
4f7ca64
Update function names `call_gpt_4_v`, etc.
joshbickett Jan 5, 2024
76467fc
Add `parse_click_content` to `labeling.py`
joshbickett Jan 5, 2024
111b5f8
Add `call_gpt_4_v_labeled` for set of mark prompting
joshbickett Jan 5, 2024
a167008
add `aiohttp`
joshbickett Jan 5, 2024
67c97e5
Update `requirements.txt`
joshbickett Jan 5, 2024
05a4733
Update `.gitignore`
joshbickett Jan 5, 2024
61584a6
Add `gpt-4-with-som` model option
joshbickett Jan 5, 2024
3fcb7e1
Fix `dict` issue for `call_gpt_4_v_labeled`
joshbickett Jan 6, 2024
eebabeb
Change naming from `predictions` to `detections`
joshbickett Jan 6, 2024
0d92682
Adjust file names
joshbickett Jan 7, 2024
b013002
Delete `prompt_util.py`
joshbickett Jan 7, 2024
0b1981b
Remove unnecessary folder structure
joshbickett Jan 7, 2024
ecc39a0
Add `best.pt` yolo weights
joshbickett Jan 7, 2024
075e365
More file name updates
joshbickett Jan 7, 2024
e30726f
Update to `os.py` and
joshbickett Jan 7, 2024
061746f
change name to `keyboard_type`
joshbickett Jan 7, 2024
bc831f8
Add `YOLO("./model/weights/best.pt")`
joshbickett Jan 7, 2024
44b0992
Update `yolo` directory
joshbickett Jan 7, 2024
b2ae0ba
Add new `best.pt`
joshbickett Jan 7, 2024
fa05718
Add fallback `Something went wrong`
joshbickett Jan 7, 2024
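Taken together, these commits add a `gpt-4-with-som` model option that pairs GPT-4 Vision with set-of-mark prompting: a YOLO model labels UI elements on the screenshot, and the model is asked to pick a label instead of raw coordinates. A minimal usage sketch, assuming the repo is installed and an OpenAI key is configured; the `objective` string here is hypothetical and not from the PR:

```python
# Minimal sketch (not part of this PR's diff): invoking the new
# `gpt-4-with-som` option through the async get_next_action dispatcher.
import asyncio

from operate.actions import get_next_action

messages = []  # prior conversation turns; empty on the first step
objective = "Open the browser and search for today's weather"  # hypothetical

# get_next_action is now async because the set-of-mark path issues two
# OpenAI requests concurrently (a decision prompt and a labeled-click prompt).
action = asyncio.run(get_next_action("gpt-4-with-som", messages, objective))
print(action)  # e.g. 'CLICK { "x": "12.34%", "y": "56.78%", ... }'
```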
12 changes: 2 additions & 10 deletions .gitignore
@@ -162,13 +162,5 @@ cython_debug/
.DS_Store

# Avoid sending testing screenshots up
screenshot.png
screenshot_with_grid.png
screenshot_with_labeled_grid.png
screenshot_mini.png
screenshot_mini_with_grid.png
grid_screenshot.png
grid_reflection_screenshot.png
reflection_screenshot.png
summary_screenshot.png
operate/screenshots/
*.png
operate/screenshots/
221 changes: 179 additions & 42 deletions operate/actions/api_interactions.py → operate/actions.py
@@ -3,36 +3,65 @@
import json
import base64
import re
import io
import asyncio
import aiohttp

from PIL import Image
from ultralytics import YOLO
import google.generativeai as genai
from operate.config.settings import Config
from operate.exceptions.exceptions import ModelNotRecognizedException
from operate.utils.screenshot_util import capture_screen_with_cursor, add_grid_to_image, capture_mini_screenshot_with_cursor
from operate.utils.action_util import get_last_assistant_message
from operate.utils.prompt_util import format_vision_prompt, format_accurate_mode_vision_prompt,format_summary_prompt
from operate.settings import Config
from operate.exceptions import ModelNotRecognizedException
from operate.utils.screenshot import (
capture_screen_with_cursor,
add_grid_to_image,
capture_mini_screenshot_with_cursor,
)
from operate.utils.os import get_last_assistant_message
from operate.prompts import (
format_vision_prompt,
format_accurate_mode_vision_prompt,
format_summary_prompt,
format_decision_prompt,
format_label_prompt,
)


from operate.utils.label import (
add_labels,
parse_click_content,
get_click_position_in_percent,
get_label_coordinates,
)
from operate.utils.style import (
ANSI_GREEN,
ANSI_RED,
ANSI_RESET,
)


# Load configuration
config = Config()

client = config.initialize_openai_client()

yolo_model = YOLO("./operate/model/weights/best.pt") # Load your trained model

def get_next_action(model, messages, objective, accurate_mode):
if model == "gpt-4-vision-preview":
content = get_next_action_from_openai(
messages, objective, accurate_mode)
return content

async def get_next_action(model, messages, objective):
if model == "gpt-4":
return call_gpt_4_v(messages, objective)
if model == "gpt-4-with-som":
return await call_gpt_4_v_labeled(messages, objective)
elif model == "agent-1":
return "coming soon"
elif model == "gemini-pro-vision":
content = get_next_action_from_gemini_pro_vision(
messages, objective
)
return content
return call_gemini_pro_vision(messages, objective)

raise ModelNotRecognizedException(model)


def get_next_action_from_openai(messages, objective, accurate_mode):
def call_gpt_4_v(messages, objective):
"""
Get the next action for Self-Operating Computer
"""
@@ -95,32 +124,14 @@ def get_next_action_from_openai(messages, objective, accurate_mode):

content = response.choices[0].message.content

if accurate_mode:
if content.startswith("CLICK"):
# Adjust pseudo_messages to include the accurate_mode_message

click_data = re.search(r"CLICK \{ (.+) \}", content).group(1)
click_data_json = json.loads(f"{{{click_data}}}")
prev_x = click_data_json["x"]
prev_y = click_data_json["y"]

if config.debug:
print(
f"Previous coords before accurate tuning: prev_x {prev_x} prev_y {prev_y}"
)
content = accurate_mode_double_check(
"gpt-4-vision-preview", pseudo_messages, prev_x, prev_y
)
assert content != "ERROR", "ERROR: accurate_mode_double_check failed"

return content

except Exception as e:
print(f"Error parsing JSON: {e}")
return "Failed take action after looking at the screenshot"


def get_next_action_from_gemini_pro_vision(messages, objective):
def call_gemini_pro_vision(messages, objective):
"""
Get the next action for Self-Operating Computer using Gemini Pro Vision
"""
@@ -172,14 +183,13 @@ def get_next_action_from_gemini_pro_vision(messages, objective):
return "Failed take action after looking at the screenshot"


# This function is not used. `-accurate` mode was removed for now until a new PR fixes it.
def accurate_mode_double_check(model, pseudo_messages, prev_x, prev_y):
"""
Reprompt OAI with additional screenshot of a mini screenshot centered around the cursor for further finetuning of clicked location
"""
print("[get_next_action_from_gemini_pro_vision] accurate_mode_double_check")
try:
screenshot_filename = os.path.join(
"screenshots", "screenshot_mini.png")
screenshot_filename = os.path.join("screenshots", "screenshot_mini.png")
capture_mini_screenshot_with_cursor(
file_path=screenshot_filename, x=prev_x, y=prev_y
)
@@ -191,8 +201,7 @@ def accurate_mode_double_check(model, pseudo_messages, prev_x, prev_y):
with open(new_screenshot_filename, "rb") as img_file:
img_base64 = base64.b64encode(img_file.read()).decode("utf-8")

accurate_vision_prompt = format_accurate_mode_vision_prompt(
prev_x, prev_y)
accurate_vision_prompt = format_accurate_mode_vision_prompt(prev_x, prev_y)

accurate_mode_message = {
"role": "user",
@@ -234,7 +243,7 @@ def summarize(model, messages, objective):
capture_screen_with_cursor(screenshot_filename)

summary_prompt = format_summary_prompt(objective)

if model == "gpt-4-vision-preview":
with open(screenshot_filename, "rb") as img_file:
img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
@@ -266,7 +275,135 @@ def summarize(model, messages, objective):
)
content = summary_message.text
return content

except Exception as e:
print(f"Error in summarize: {e}")
return "Failed to summarize the workflow"
return "Failed to summarize the workflow"


async def call_gpt_4_v_labeled(messages, objective):
time.sleep(1)
try:
screenshots_dir = "screenshots"
if not os.path.exists(screenshots_dir):
os.makedirs(screenshots_dir)

screenshot_filename = os.path.join(screenshots_dir, "screenshot.png")
# Call the function to capture the screen with the cursor
capture_screen_with_cursor(screenshot_filename)

with open(screenshot_filename, "rb") as img_file:
img_base64 = base64.b64encode(img_file.read()).decode("utf-8")

previous_action = get_last_assistant_message(messages)

img_base64_labeled, img_base64_original, label_coordinates = add_labels(
img_base64, yolo_model
)

decision_prompt = format_decision_prompt(objective, previous_action)
labeled_click_prompt = format_label_prompt(objective)

click_message = {
"role": "user",
"content": [
{"type": "text", "text": labeled_click_prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{img_base64_labeled}"
},
},
],
}
decision_message = {
"role": "user",
"content": [
{"type": "text", "text": decision_prompt},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{img_base64_original}"
},
},
],
}

click_messages = messages.copy()
click_messages.append(click_message)
decision_messages = messages.copy()
decision_messages.append(decision_message)

click_future = fetch_openai_response_async(click_messages)
decision_future = fetch_openai_response_async(decision_messages)

click_response, decision_response = await asyncio.gather(
click_future, decision_future
)

# Extracting the message content from the ChatCompletionMessage object
click_content = click_response.get("choices")[0].get("message").get("content")

decision_content = (
decision_response.get("choices")[0].get("message").get("content")
)

if not decision_content.startswith("CLICK"):
return decision_content

label_data = parse_click_content(click_content)

if label_data and "label" in label_data:
coordinates = get_label_coordinates(label_data["label"], label_coordinates)
image = Image.open(
io.BytesIO(base64.b64decode(img_base64))
) # Load the image to get its size
image_size = image.size # Get the size of the image (width, height)
click_position_percent = get_click_position_in_percent(
coordinates, image_size
)
if not click_position_percent:
print(
f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] Failed to get click position in percent. Trying another method {ANSI_RESET}"
)
return call_gpt_4_v(messages, objective)

x_percent = f"{click_position_percent[0]:.2f}%"
y_percent = f"{click_position_percent[1]:.2f}%"
click_action = f'CLICK {{ "x": "{x_percent}", "y": "{y_percent}", "description": "{label_data["decision"]}", "reason": "{label_data["reason"]}" }}'

else:
print(
f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] No label found. Trying another method {ANSI_RESET}"
)
return call_gpt_4_v(messages, objective)

return click_action

except Exception as e:
print(
f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] Something went wrong. Trying another method {ANSI_RESET}"
)
return call_gpt_4_v(messages, objective)


async def fetch_openai_response_async(messages):
url = "https://api.openai.com/v1/chat/completions"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {config.openai_api_key}",
}
data = {
"model": "gpt-4-vision-preview",
"messages": messages,
"frequency_penalty": 1,
"presence_penalty": 1,
"temperature": 0.7,
"max_tokens": 300,
}

async with aiohttp.ClientSession() as session:
async with session.post(
url, headers=headers, data=json.dumps(data)
) as response:
return await response.json()
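The new code path relies on `get_label_coordinates` and `get_click_position_in_percent` from `operate/utils/label.py`, whose implementations are not part of this diff. A rough sketch of the conversion the latter presumably performs, inferred from the `f"{value:.2f}%"` formatting above; this is hypothetical, not the PR's actual helper:

```python
# Hypothetical sketch of the label-to-percent conversion used by
# call_gpt_4_v_labeled; the real helper lives in operate/utils/label.py
# and is not shown in this diff.
def get_click_position_in_percent(coordinates, image_size):
    """coordinates: (x, y, width, height) in pixels for the chosen label's
    bounding box; image_size: (width, height) of the screenshot."""
    if not coordinates or not image_size:
        return None
    # Click the center of the bounding box, expressed as 0-100 percentages.
    center_x = coordinates[0] + coordinates[2] / 2
    center_y = coordinates[1] + coordinates[3] / 2
    return (center_x / image_size[0] * 100, center_y / image_size[1] * 100)
```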
Empty file removed operate/actions/__init__.py
Empty file removed operate/config/__init__.py