Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the ability to capture and load snapshots of memory and message history #192

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions scripts/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dotenv import load_dotenv
from config import Config
import token_counter
import message_history

cfg = Config()

Expand Down Expand Up @@ -40,7 +41,6 @@ def chat_with_ai(
Args:
prompt (str): The prompt explaining the rules to the AI.
user_input (str): The input from the user.
full_message_history (list): The list of all messages sent between the user and the AI.
permanent_memory (list): The list of items in the AI's permanent memory.
token_limit (int): The maximum number of tokens allowed in the API call.

Expand Down Expand Up @@ -114,10 +114,10 @@ def chat_with_ai(
)

# Update full message history
full_message_history.append(
message_history.append(
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a significant or required change?

create_chat_message(
"user", user_input))
full_message_history.append(
message_history.append(
create_chat_message(
"assistant", assistant_reply))

Expand Down
9 changes: 4 additions & 5 deletions scripts/commands.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import browse
import json
import memory as mem
import memory
import datetime
import agent_manager as agents
import speak
Expand All @@ -15,7 +15,6 @@

cfg = Config()


def get_command(response):
try:
response_json = fix_and_parse_json(response)
Expand Down Expand Up @@ -178,14 +177,14 @@ def get_hyperlinks(url):

def commit_memory(string):
_text = f"""Committing memory with string "{string}" """
mem.permanent_memory.append(string)
memory.append(string)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, the stuff in this area makes sense to me.

return _text


def delete_memory(key):
if key >= 0 and key < len(mem.permanent_memory):
_text = "Deleting memory with key " + str(key)
del mem.permanent_memory[key]
memory.delete_memory(key)
print(_text)
return _text
else:
Expand All @@ -197,7 +196,7 @@ def overwrite_memory(key, string):
if int(key) >= 0 and key < len(mem.permanent_memory):
_text = "Overwriting memory with key " + \
str(key) + " and string " + string
mem.permanent_memory[key] = string
memory.overwrite_memory(key, string)
print(_text)
return _text
else:
Expand Down
6 changes: 5 additions & 1 deletion scripts/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class Config(metaclass=Singleton):
def __init__(self):
self.continuous_mode = False
self.speak_mode = False
self.snapshots_enabled = False
# TODO - make these models be self-contained, using langchain, so we can configure them once and call it good
self.fast_llm_model = os.getenv("FAST_LLM_MODEL", "gpt-3.5-turbo")
self.smart_llm_model = os.getenv("SMART_LLM_MODEL", "gpt-4")
Expand Down Expand Up @@ -70,4 +71,7 @@ def set_google_api_key(self, value: str):
self.google_api_key = value

def set_custom_search_engine_id(self, value: str):
self.custom_search_engine_id = value
self.custom_search_engine_id = value

def set_snapshots_enabled(self, value: bool) -> None:
    """Toggle the snapshots feature flag read by the snapshot subsystem."""
    self.snapshots_enabled = value
59 changes: 59 additions & 0 deletions scripts/data_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
import shelve
import memory
import message_history

class DataStore():
    """Abstract persistence interface for snapshots of memory and message history.

    Subclasses (e.g. ShelfDataStore) implement real storage. The base
    implementations report failure: the original stubs returned True,
    which made callers that check ``result is not True`` treat a no-op
    store as a successful save/load.
    """

    def persist_message_history(self, id):
        # No storage backend configured — report failure, not success.
        return False

    def load_message_history(self, id):
        return False

    def persist_memory(self, id):
        return False

    def load_memory(self, id):
        return False

instance = DataStore()

class ShelfDataStore(DataStore):
    """DataStore backed by the stdlib ``shelve`` module.

    Each snapshot id gets its own directory under *path* containing a
    shelf file named ``store``.
    """

    def __init__(self, path):
        self.path = path
        # Ensure path ends with a slash so f-string concatenation below
        # forms valid paths. endswith() also avoids the IndexError the
        # original `self.path[-1]` raised on an empty string.
        if not self.path.endswith("/"):
            self.path += "/"
        # (Removed unused self.message_history / self.memory attributes;
        # nothing in this module read them.)

    def _store_path(self, id):
        # Single source of truth for the shelf location. The original
        # built this string inline four times, once with a stray extra
        # "/" in persist_memory producing an inconsistent double-slash
        # path.
        return f"{self.path}{id}/store"

    def persist_message_history(self, id):
        os.makedirs(f"{self.path}{id}", exist_ok=True)
        with shelve.open(self._store_path(id)) as f:
            f["message_history"] = message_history.message_history
        return True

    def load_message_history(self, id):
        # (Removed the original `global message_history`: message_history
        # is an imported module here, so the statement had no effect.)
        try:
            with shelve.open(self._store_path(id)) as f:
                message_history.set_history(f["message_history"])
            return True
        except Exception as e:
            # NOTE(review): callers print the returned value on failure,
            # so the exception object (not False) is returned for its
            # message; a consistent True/False contract would be cleaner
            # but would change the caller-visible error text.
            return e

    def persist_memory(self, id):
        os.makedirs(f"{self.path}{id}", exist_ok=True)
        with shelve.open(self._store_path(id)) as f:
            f["memory"] = memory.permanent_memory
        return True

    def load_memory(self, id):
        try:
            with shelve.open(self._store_path(id)) as f:
                memory.permanent_memory = f["memory"]
            return True
        except Exception as e:
            return e
34 changes: 30 additions & 4 deletions scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import traceback
import yaml
import argparse

import message_history
import snapshots
import data_store as ds

def print_to_console(
title,
Expand Down Expand Up @@ -172,6 +174,8 @@ def construct_prompt():
Goals: {config.ai_goals}
Continue (y/n): """)
if should_continue.lower() == "n":
mem.clear_memory()
message_history.clear_history()
config = AIConfig()

if not config.ai_name:
Expand Down Expand Up @@ -248,6 +252,9 @@ def parse_arguments():
parser.add_argument('--speak', action='store_true', help='Enable Speak Mode')
parser.add_argument('--debug', action='store_true', help='Enable Debug Mode')
parser.add_argument('--gpt3only', action='store_true', help='Enable GPT3.5 Only Mode')
parser.add_argument('--enable-snapshots', action='store_true', help='Enable Snapshots')
parser.add_argument('--snapshot-path', type=str, default=None, required=False, help='Path for the snapshot directory')
parser.add_argument('--snapshot-id', type=str, default=None, required=False, help='ID of the snapshot to load')
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This all looks good right here

args = parser.parse_args()

if args.continuous:
Expand All @@ -266,6 +273,25 @@ def parse_arguments():
print_to_console("GPT3.5 Only Mode: ", Fore.GREEN, "ENABLED")
cfg.set_smart_llm_model(cfg.fast_llm_model)

if args.enable_snapshots:
print_to_console("Snapshots: ", Fore.GREEN, "ENABLED")
cfg.set_snapshots_enabled(True)
ds.instance = ds.ShelfDataStore("outputs/snapshots/")

if args.snapshot_path:
print_to_console("Snapshot Path: ", Fore.GREEN, args.snapshot_path)
ds.instance = ds.ShelfDataStore(args.snapshot_path)

if args.snapshot_id:
result = snapshots.load_snapshot(args.snapshot_id)
message_history = result["message_history"]
memory = result["memory"]
if message_history is not True:
print_to_console("Load Snapshot: ", Fore.RED, f"FAILED - {message_history}")
if memory is not True:
print_to_console("Load Snapshot: ", Fore.RED, f"FAILED - {memory}")
if message_history is True and memory is True:
print_to_console("Load Snapshot: ", Fore.GREEN, f"SUCCESSFULLY loaded {args.snapshot_id}")

# TODO: fill in llm values here

Expand All @@ -287,7 +313,7 @@ def parse_arguments():
assistant_reply = chat.chat_with_ai(
prompt,
user_input,
full_message_history,
message_history.message_history,
mem.permanent_memory,
cfg.fast_token_limit) # TODO: This hardcodes the model to use GPT3.5. Make this an argument

Expand Down Expand Up @@ -352,10 +378,10 @@ def parse_arguments():
# Check if there's a result from the command append it to the message
# history
if result is not None:
full_message_history.append(chat.create_chat_message("system", result))
message_history.append(chat.create_chat_message("system", result))
print_to_console("SYSTEM: ", Fore.YELLOW, result)
else:
full_message_history.append(
message_history.append(
chat.create_chat_message(
"system", "Unable to execute command"))
print_to_console("SYSTEM: ", Fore.YELLOW, "Unable to execute command")
Expand Down
55 changes: 55 additions & 0 deletions scripts/memory.py
Original file line number Diff line number Diff line change
@@ -1 +1,56 @@
import snapshots

permanent_memory = []

def append(string):
    """Add *string* to permanent memory, persisting a snapshot afterwards."""
    permanent_memory.append(string)
    status = snapshots.create_snapshot()["memory"]
    if status is not True:
        # Persistence is best-effort: report the failure and carry on.
        print("Failed to persist memory")
        print(status)
    return True

def delete(key):
    """Delete the memory entry at index *key*.

    Returns the key on success, or None when *key* is out of range.
    """
    if not (0 <= key < len(permanent_memory)):
        return None
    del permanent_memory[key]
    status = snapshots.create_snapshot()["memory"]
    if status is not True:
        print("Failed to persist memory")
        print(status)
    return key

def commit_memory(string):
    """Add *string* to permanent memory (alias of append()).

    The original body duplicated append() line for line; delegating
    keeps the snapshot/failure-reporting logic in one place. Behavior
    (mutation, prints, True return) is unchanged.
    """
    return append(string)

def delete_memory(key):
    """Delete the memory entry at index *key*; returns key or None.

    Bug fix: the original tested ``key in permanent_memory``, which is a
    *value*-membership test on the list, not an index-validity check —
    callers pass integer indices (see commands.py and the sibling
    delete()). Delegate to delete(), which performs the correct bounds
    check and the same snapshot/failure reporting.
    """
    return delete(key)

def overwrite_memory(key, string):
    """Overwrite the memory entry at index *key* with *string*.

    Returns the key on success, or None when *key* is out of range.
    Bug fix: ``key in permanent_memory`` was a value-membership test, so
    a valid index was rejected unless its stored value happened to equal
    the index itself; use a bounds check like delete() instead.
    """
    if 0 <= key < len(permanent_memory):
        permanent_memory[key] = string
        result = snapshots.create_snapshot()
        if result["memory"] is not True:
            print("Failed to persist memory")
            print(result["memory"])
        return key
    return None

def clear_memory():
    """Empty the permanent memory list in place and report success."""
    del permanent_memory[:]
    return True
28 changes: 28 additions & 0 deletions scripts/message_history.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import snapshots

message_history = []

def append(message):
    """Record *message* in the shared history, persisting a snapshot."""
    message_history.append(message)
    status = snapshots.create_snapshot()["message_history"]
    if status is not True:
        # Persistence is best-effort: report the failure and carry on.
        print("Failed to persist message history")
        print(status)
    return message

def overwrite(key, message):
    """Replace the history entry at index *key* with *message* and persist."""
    message_history[key] = message
    status = snapshots.create_snapshot()["message_history"]
    if status is not True:
        print("Failed to persist message history")
        print(status)
    return message

def set_history(history):
    """Replace the module-level history list wholesale (used when
    loading a snapshot)."""
    global message_history
    message_history = history

def clear_history():
    """Rebind the module-level history to a fresh empty list.

    Returns True unconditionally.
    """
    global message_history
    message_history = []
    return True
18 changes: 18 additions & 0 deletions scripts/snapshots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import time
from config import Config
import data_store as ds

cfg = Config()

def create_snapshot():
    """Persist the current message history and memory via the active data store.

    Returns a dict with per-store status under the keys
    "message_history" and "memory" (True on success).
    """
    if not cfg.snapshots_enabled:
        # Feature disabled: report success without touching the store.
        return {"message_history": True, "memory": True}
    snapshot_id = int(time.time())
    return {
        "message_history": ds.instance.persist_message_history(snapshot_id),
        "memory": ds.instance.persist_memory(snapshot_id),
    }

def load_snapshot(id):
    """Restore message history and memory from snapshot *id*.

    Returns a dict keyed by "message_history" and "memory"; each value
    is whatever the data store's load method returned (True on success).
    """
    return {
        "message_history": ds.instance.load_message_history(id),
        "memory": ds.instance.load_memory(id),
    }