-
-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
Copy pathgenerate_answer_from_image_node.py
127 lines (104 loc) · 3.83 KB
/
generate_answer_from_image_node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
GenerateAnswerFromImageNode Module
"""
import asyncio
import base64
from typing import List, Optional
import aiohttp
from .base_node import BaseNode
class GenerateAnswerFromImageNode(BaseNode):
    """
    GenerateAnswerFromImageNode analyzes images from the state dictionary using the OpenAI API
    and updates the state with the consolidated answers.

    The node reads the raw images from ``state["screenshots"]`` and an optional
    ``state["user_prompt"]``, sends one chat-completions request per image, and
    stores the space-joined replies under
    ``state["answer"]["consolidated_analysis"]``.

    Args:
        input (str): Input expression forwarded to ``BaseNode``.
        output (List[str]): Output keys forwarded to ``BaseNode``.
        node_config (Optional[dict]): Node configuration. The methods below read
            ``node_config["config"]["llm"]["model"]`` and, optionally,
            ``node_config["config"]["llm"]["api_key"]``.
        node_name (str): Name of the node, used in log messages.
    """

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "GenerateAnswerFromImageNode",
    ):
        super().__init__(node_name, "node", input, output, 2, node_config)

    async def process_image(self, session, api_key, image_data, user_prompt):
        """
        Send a single image to the OpenAI chat-completions endpoint.

        Args:
            session: Open ``aiohttp`` client session used to issue the POST request.
            api_key (str): API key sent as the ``Bearer`` token.
            image_data (bytes): Raw image bytes; they are base64-encoded and
                embedded in the request as a ``data:image/jpeg`` URL.
            user_prompt (str): Text part of the user message.

        Returns:
            str: The message content of the first choice, or ``"No response"``
            when the reply has no choices or no content.
        """
        base64_image = base64.b64encode(image_data).decode("utf-8")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

        payload = {
            "model": self.node_config["config"]["llm"]["model"],
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": user_prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            },
                        },
                    ],
                }
            ],
            "max_tokens": 300,
        }

        async with session.post(
            "https://api.openai.com/v1/chat/completions", headers=headers, json=payload
        ) as response:
            result = await response.json()

        # A missing key and an empty list are both treated as "no choices";
        # indexing an empty list here would otherwise raise IndexError.
        choices = result.get("choices") or [{}]
        if "choices" not in result or not result["choices"]:
            # Surface API-side failures instead of silently hiding them
            # behind the "No response" placeholder.
            self.logger.warning(
                f"No choices returned for image request: {result.get('error')}"
            )

        message = choices[0].get("message") or {}
        # `content` may be present but null; fall back so the caller's
        # " ".join(...) always receives strings.
        return message.get("content") or "No response"

    async def execute_async(self, state: dict) -> dict:
        """
        Processes images from the state, generates answers,
        consolidates the results, and updates the state asynchronously.

        Args:
            state (dict): Graph state. ``state["screenshots"]`` is read as the
                list of images (defaults to an empty list) and
                ``state["user_prompt"]`` as the prompt (defaults to
                ``"Extract information from the image"``).

        Returns:
            dict: The same state object, with ``state["answer"]`` set to
            ``{"consolidated_analysis": <replies joined by a single space>}``.

        Raises:
            ValueError: If the configured model is not one of the supported models.
        """
        self.logger.info(f"--- Executing {self.node_name} Node ---")

        # Validate the configuration before doing any other work.
        supported_models = ("gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4")
        # Only the last path segment is compared, so "provider/gpt-4o" is accepted.
        model_name = self.node_config["config"]["llm"]["model"].split("/")[-1]
        if model_name not in supported_models:
            raise ValueError(
                "The model provided\n"
                "is not supported. Supported models are:\n"
                f"{', '.join(supported_models)}."
            )

        images = state.get("screenshots", [])
        api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "")
        # The prompt is the same for every image, so look it up once.
        user_prompt = state.get("user_prompt", "Extract information from the image")

        async with aiohttp.ClientSession() as session:
            tasks = [
                self.process_image(session, api_key, image_data, user_prompt)
                for image_data in images
            ]
            # gather() keeps the replies in the same order as the input images.
            analyses = await asyncio.gather(*tasks)

        consolidated_analysis = " ".join(analyses)
        state["answer"] = {"consolidated_analysis": consolidated_analysis}

        return state

    def execute(self, state: dict) -> dict:
        """
        Wrapper to run the asynchronous execute_async function in a synchronous context.

        Args:
            state (dict): Graph state passed through to ``execute_async``.

        Returns:
            dict: The state returned by ``execute_async``.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop is running in this thread: asyncio.run can own one.
            return asyncio.run(self.execute_async(state))

        # A loop is already running in this thread (e.g. a notebook or an async
        # caller). Calling run_until_complete() on it raises
        # "This event loop is already running", so run the coroutine on a fresh
        # loop in a worker thread and block until it finishes.
        from concurrent.futures import ThreadPoolExecutor

        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, self.execute_async(state)).result()