-
Notifications
You must be signed in to change notification settings - Fork 43.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add the ability to capture and load snapshots of memory and message history #192
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
import browse | ||
import json | ||
import memory as mem | ||
import memory | ||
import datetime | ||
import agent_manager as agents | ||
import speak | ||
|
@@ -15,7 +15,6 @@ | |
|
||
cfg = Config() | ||
|
||
|
||
def get_command(response): | ||
try: | ||
response_json = fix_and_parse_json(response) | ||
|
@@ -178,14 +177,14 @@ def get_hyperlinks(url): | |
|
||
def commit_memory(string): | ||
_text = f"""Committing memory with string "{string}" """ | ||
mem.permanent_memory.append(string) | ||
memory.append(string) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice, the stuff in this area makes sense to me. |
||
return _text | ||
|
||
|
||
def delete_memory(key): | ||
if key >= 0 and key < len(mem.permanent_memory): | ||
_text = "Deleting memory with key " + str(key) | ||
del mem.permanent_memory[key] | ||
memory.delete_memory(key) | ||
print(_text) | ||
return _text | ||
else: | ||
|
@@ -197,7 +196,7 @@ def overwrite_memory(key, string): | |
if int(key) >= 0 and key < len(mem.permanent_memory): | ||
_text = "Overwriting memory with key " + \ | ||
str(key) + " and string " + string | ||
mem.permanent_memory[key] = string | ||
memory.overwrite_memory(key, string) | ||
print(_text) | ||
return _text | ||
else: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import os | ||
import shelve | ||
import memory | ||
import message_history | ||
|
||
class DataStore(): | ||
|
||
def persist_message_history(self, id): | ||
return True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about return False on everything for the abstract base class? Or if errors are involved, throwing the NotImplementedError? |
||
|
||
def load_message_history(self, id): | ||
return True | ||
|
||
def persist_memory(self ,id): | ||
return True | ||
|
||
def load_memory(self, id): | ||
return True | ||
|
||
instance = DataStore() | ||
|
||
class ShelfDataStore(DataStore): | ||
|
||
def __init__(self, path): | ||
self.path = path | ||
# Ensure path ends with a slash | ||
if self.path[-1] != "/": | ||
self.path += "/" | ||
self.message_history = [] | ||
self.memory = [] | ||
|
||
def persist_message_history(self, id): | ||
os.makedirs(f"{self.path}{id}", exist_ok=True) | ||
with shelve.open(f"{self.path}{id}/store") as f: | ||
f["message_history"] = message_history.message_history | ||
return True | ||
|
||
def load_message_history(self, id): | ||
try: | ||
with shelve.open(f"{self.path}{id}/store") as f: | ||
global message_history | ||
message_history.set_history(f["message_history"]) | ||
return True | ||
except Exception as e: | ||
return e | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function returns either |
||
|
||
def persist_memory(self, id): | ||
os.makedirs(f"{self.path}{id}", exist_ok=True) | ||
with shelve.open(f"{self.path}/{id}/store") as f: | ||
f["memory"] = memory.permanent_memory | ||
return True | ||
|
||
def load_memory(self, id): | ||
try: | ||
with shelve.open(f"{self.path}{id}/store") as f: | ||
memory.permanent_memory = f["memory"] | ||
return True | ||
except Exception as e: | ||
return e |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,9 @@ | |
import traceback | ||
import yaml | ||
import argparse | ||
|
||
import message_history | ||
import snapshots | ||
import data_store as ds | ||
|
||
def print_to_console( | ||
title, | ||
|
@@ -172,6 +174,8 @@ def construct_prompt(): | |
Goals: {config.ai_goals} | ||
Continue (y/n): """) | ||
if should_continue.lower() == "n": | ||
mem.clear_memory() | ||
message_history.clear_history() | ||
config = AIConfig() | ||
|
||
if not config.ai_name: | ||
|
@@ -248,6 +252,9 @@ def parse_arguments(): | |
parser.add_argument('--speak', action='store_true', help='Enable Speak Mode') | ||
parser.add_argument('--debug', action='store_true', help='Enable Debug Mode') | ||
parser.add_argument('--gpt3only', action='store_true', help='Enable GPT3.5 Only Mode') | ||
parser.add_argument('--enable-snapshots', action='store_true', help='Enable Snapshots') | ||
parser.add_argument('--snapshot-path', type=str, default=None, required=False, help='Path for the snapshot directory') | ||
parser.add_argument('--snapshot-id', type=str, default=None, required=False, help='ID of the snapshot to load') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This all looks good right here |
||
args = parser.parse_args() | ||
|
||
if args.continuous: | ||
|
@@ -266,6 +273,25 @@ def parse_arguments(): | |
print_to_console("GPT3.5 Only Mode: ", Fore.GREEN, "ENABLED") | ||
cfg.set_smart_llm_model(cfg.fast_llm_model) | ||
|
||
if args.enable_snapshots: | ||
print_to_console("Snapshots: ", Fore.GREEN, "ENABLED") | ||
cfg.set_snapshots_enabled(True) | ||
ds.instance = ds.ShelfDataStore("outputs/snapshots/") | ||
|
||
if args.snapshot_path: | ||
print_to_console("Snapshot Path: ", Fore.GREEN, args.snapshot_path) | ||
ds.instance = ds.ShelfDataStore(args.snapshot_path) | ||
|
||
if args.snapshot_id: | ||
result = snapshots.load_snapshot(args.snapshot_id) | ||
message_history = result["message_history"] | ||
memory = result["memory"] | ||
if message_history is not True: | ||
print_to_console("Load Snapshot: ", Fore.RED, f"FAILED - {message_history}") | ||
if memory is not True: | ||
print_to_console("Load Snapshot: ", Fore.RED, f"FAILED - {memory}") | ||
if message_history is True and memory is True: | ||
print_to_console("Load Snapshot: ", Fore.GREEN, f"SUCCESSFULLY loaded {args.snapshot_id}") | ||
|
||
# TODO: fill in llm values here | ||
|
||
|
@@ -287,7 +313,7 @@ def parse_arguments(): | |
assistant_reply = chat.chat_with_ai( | ||
prompt, | ||
user_input, | ||
full_message_history, | ||
message_history.message_history, | ||
mem.permanent_memory, | ||
cfg.fast_token_limit) # TODO: This hardcodes the model to use GPT3.5. Make this an argument | ||
|
||
|
@@ -352,10 +378,10 @@ def parse_arguments(): | |
# Check if there's a result from the command append it to the message | ||
# history | ||
if result is not None: | ||
full_message_history.append(chat.create_chat_message("system", result)) | ||
message_history.append(chat.create_chat_message("system", result)) | ||
print_to_console("SYSTEM: ", Fore.YELLOW, result) | ||
else: | ||
full_message_history.append( | ||
message_history.append( | ||
chat.create_chat_message( | ||
"system", "Unable to execute command")) | ||
print_to_console("SYSTEM: ", Fore.YELLOW, "Unable to execute command") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,56 @@ | ||
import snapshots | ||
|
||
permanent_memory = [] | ||
|
||
def append(string): | ||
permanent_memory.append(string) | ||
result = snapshots.create_snapshot() | ||
if result["memory"] is not True: | ||
print("Failed to persist memory") | ||
print(result["memory"]) | ||
return True | ||
|
||
def delete(key): | ||
if key >= 0 and key < len(permanent_memory): | ||
del permanent_memory[key] | ||
result = snapshots.create_snapshot() | ||
if result["memory"] is not True: | ||
print("Failed to persist memory") | ||
print(result["memory"]) | ||
return key | ||
else: | ||
return None | ||
|
||
def commit_memory(string): | ||
permanent_memory.append(string) | ||
result = snapshots.create_snapshot() | ||
if result["memory"] is not True: | ||
print("Failed to persist memory") | ||
print(result["memory"]) | ||
return True | ||
|
||
def delete_memory(key): | ||
if key in permanent_memory: | ||
del permanent_memory[key] | ||
result = snapshots.create_snapshot() | ||
if result["memory"] is not True: | ||
print("Failed to persist memory") | ||
print(result["memory"]) | ||
return key | ||
else: | ||
return None | ||
|
||
def overwrite_memory(key, string): | ||
if key in permanent_memory: | ||
permanent_memory[key] = string | ||
result = snapshots.create_snapshot() | ||
if result["memory"] is not True: | ||
print("Failed to persist memory") | ||
print(result["memory"]) | ||
return key | ||
else: | ||
return None | ||
|
||
def clear_memory(): | ||
permanent_memory.clear() | ||
return True |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import snapshots | ||
|
||
message_history = [] | ||
|
||
def append(message): | ||
message_history.append(message) | ||
result = snapshots.create_snapshot() | ||
if result["message_history"] is not True: | ||
print("Failed to persist message history") | ||
print(result["message_history"]) | ||
return message | ||
|
||
def overwrite(key, message): | ||
message_history[key] = message | ||
result = snapshots.create_snapshot() | ||
if result["message_history"] is not True: | ||
print("Failed to persist message history") | ||
print(result["message_history"]) | ||
return message | ||
|
||
def set_history(history): | ||
global message_history | ||
message_history = history | ||
|
||
def clear_history(): | ||
global message_history | ||
message_history = [] | ||
return True |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import time | ||
from config import Config | ||
import data_store as ds | ||
|
||
cfg = Config() | ||
|
||
def create_snapshot(): | ||
if not cfg.snapshots_enabled: | ||
return { "message_history": True, "memory": True } | ||
current_unix_time = int(time.time()) | ||
message_history = ds.instance.persist_message_history(current_unix_time) | ||
memory = ds.instance.persist_memory(current_unix_time) | ||
return { "message_history": message_history, "memory": memory } | ||
|
||
def load_snapshot(id): | ||
message_history = ds.instance.load_message_history(id) | ||
memory = ds.instance.load_memory(id) | ||
return { "message_history": message_history, "memory": memory } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this a significant or required change?