Skip to content

Commit

Permalink
Merge pull request #30 from WhatTheFuzz/feature/openai-updates
Browse files Browse the repository at this point in the history
Feature/openai updates
  • Loading branch information
WhatTheFuzz committed Apr 17, 2024
2 parents 4d520c0 + e7be587 commit 136a0e1
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion plugin.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@
"openai"
]
},
"version": "2.0.1",
"version": "3.0.1",
"minimumbinaryninjaversion": 3200
}
13 changes: 7 additions & 6 deletions src/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from pathlib import Path

import openai
from openai.api_resources.model import Model
from openai.error import APIError
from openai import APIError

from binaryninja.function import Function
from binaryninja.lowlevelil import LowLevelILFunction
Expand All @@ -18,6 +17,7 @@

from . query import Query
from . c import Pseudo_C
from . exceptions import NoAPIKeyException


class Agent:
Expand Down Expand Up @@ -45,7 +45,7 @@ def __init__(self,
path_to_api_key: Optional[Path]=None) -> None:

# Read the API key from the environment variable.
openai.api_key = self.read_api_key(path_to_api_key)
self.client = openai.OpenAI(api_key=self.read_api_key(filename=path_to_api_key))

assert bv is not None, 'BinaryView is None. Check how you called this function.'
# Set instance attributes.
Expand Down Expand Up @@ -87,12 +87,12 @@ def read_api_key(self, filename: Optional[Path]=None) -> str:
except FileNotFoundError:
log.log_error(f'Could not find API key file at {filename}.')

raise APIError('No API key found. Refer to the documentation to add the '
raise NoAPIKeyException('No API key found. Refer to the documentation to add the '
'API key.')

def is_valid_model(self, model: str) -> bool:
'''Checks if the model is valid by querying the OpenAI API.'''
models: list[Model] = openai.Model.list().data
models: list = self.client.models.list().data
return model in [m.id for m in models]

def get_model(self) -> str:
Expand Down Expand Up @@ -206,7 +206,8 @@ def rename_variable(self, response: str) -> None:

def send_query(self, query: str, callback: Optional[Callable]=None) -> None:
'''Sends a query to the engine and prints the response.'''
query = Query(query_string=query,
query = Query(client=self.client,
query_string=query,
model=self.model,
max_token_count=self.get_token_count(),
callback_function=callback)
Expand Down
3 changes: 3 additions & 0 deletions src/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
class NoAPIKeyException(Exception):
pass

class RegisterSettingsGroupException(Exception):
pass

Expand Down
9 changes: 5 additions & 4 deletions src/query.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Optional
import openai
from openai import Client
from binaryninja.plugin import BackgroundTaskThread
from binaryninja.log import log_debug, log_info

class Query(BackgroundTaskThread):

def __init__(self, query_string: str, model: str,
def __init__(self, client: Client, query_string: str, model: str,
max_token_count: int, callback_function: Optional[Callable]=None) -> None:
BackgroundTaskThread.__init__(self,
initial_progress_text="",
can_cancel=False)
self.client: Client = client
self.query_string: str = query_string
self.model: str = model
self.max_token_count: int = max_token_count
Expand All @@ -23,15 +24,15 @@ def run(self) -> None:
log_debug(f'Sending query: {self.query_string}')

if self.model in ["gpt-3.5-turbo","gpt-4","gpt-4-32k"]:
response = openai.ChatCompletion.create(
response = self.client.chat.completions.create(
model=self.model,
messages=[{"role":"user","content":self.query_string}],
max_tokens=self.max_token_count,
)
# Get the response text.
result: str = response.choices[0].message.content
else:
response = openai.Completion.create(
response = self.client.chat.completions.create(
model=self.model,
prompt=self.query_string,
max_tokens=self.max_token_count,
Expand Down

0 comments on commit 136a0e1

Please sign in to comment.