Skip to content

Commit

Permalink
Fixing bug with chat messages and prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
TheR1D committed Mar 28, 2023
1 parent f8fa845 commit a0eb1ad
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 12 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# pylint: disable=consider-using-with
setup(
name="shell_gpt",
version="0.8.0",
version="0.8.1",
packages=find_packages(),
install_requires=[
"typer~=0.7.0",
Expand Down
24 changes: 16 additions & 8 deletions sgpt/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,22 @@ def main(
if editor:
prompt = get_edited_prompt()

prompt = make_prompt.initial(prompt, shell, code)

if chat:
initiated = bool(OpenAIClient.chat_cache.get_messages(chat))
if initiated:
if shell or code:
raise BadParameter("Can't use --shell or --code for existing chat.")
prompt = make_prompt.chat_mode(prompt, shell, code)
if chat and OpenAIClient.chat_cache.exists(chat):
chat_history = OpenAIClient.chat_cache.get_messages(chat)
is_shell_chat = chat_history[0].endswith("###\nCommand:")
is_code_chat = chat_history[0].endswith("###\nCode:")
if is_shell_chat and code:
raise BadParameter(
f"Chat id:{chat} was initiated as shell assistant, can be used with --shell only"
)
if is_code_chat and shell:
raise BadParameter(
f"Chat id:{chat} was initiated as code assistant, can be used with --code only"
)

prompt = make_prompt.chat_mode(prompt, is_shell_chat, is_code_chat)
else:
prompt = make_prompt.initial(prompt, shell, code)

completion = get_completion(
messages=[{"role": "user", "content": prompt}],
Expand Down
7 changes: 5 additions & 2 deletions sgpt/cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from hashlib import md5
from pathlib import Path
from typing import List, Dict, Callable
from typing import List, Dict, Callable, Optional


class Cache:
Expand Down Expand Up @@ -122,9 +122,12 @@ def invalidate(self, chat_id: str):
file_path.unlink()

def get_messages(self, chat_id):
messages = self._read(self.storage_path / chat_id)
messages = self._read(chat_id)
return [f"{message['role']}: {message['content']}" for message in messages]

def exists(self, chat_id: Optional[str]) -> bool:
return chat_id and bool(self._read(chat_id))

def list(self):
# Get all files in the folder.
files = self.storage_path.glob("*")
Expand Down
2 changes: 1 addition & 1 deletion sgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_edited_prompt() -> str:

def echo_chat_messages(chat_id: str) -> None:
# Prints all messages from a specified chat ID to the console.
for index, message in enumerate(OpenAIClient.chat_cache.show(chat_id)):
for index, message in enumerate(OpenAIClient.chat_cache.get_messages(chat_id)):
color = "cyan" if index % 2 == 0 else "green"
typer.secho(message, fg=color)

Expand Down

0 comments on commit a0eb1ad

Please sign in to comment.