In [1]:
import chromadb
from chromadb.config import Settings
from datetime import datetime, timedelta
import time
import json

In [None]:
class ScoringSystem:
    def __init__(self):
        self.db = chromadb.ChromaDB(Settings.DB_PATH)

In [29]:
from chromadb import Collection


class AgentMemory:
    def __init__(self, agent_name, client_path = "./database/mem_test"):
        self.agent_name = agent_name
        self.client = chromadb.PersistentClient(path = client_path+f"/{agent_name}")
        self.static_memory = self.client.get_or_create_collection("static_memory")
        self.short_term_memory = self.client.get_or_create_collection("short_term_memory")
        self.long_term_memory = self.client.get_or_create_collection("long_term_memory")
        return
    
    def add_static_memory(self, overwrite_existing=False, personality_path='./agent_personalities'):
        """
        Adds static memory for an agent from a JSON file to the static memory database.
        
        Args:
            overwrite_existing (bool): Whether to overwrite existing entries with the same IDs.
        """
        
        with open(personality_path+f'/agent_{self.agent_name}.json') as f:
            data = json.load(f)
            
            # Retrieve existing IDs in the database
            existing_ids = set(self.static_memory.get()["ids"])
            
            new_documents = []
            new_metadatas = []
            new_ids = []
            
            for doc, meta, id_ in zip(data['documents'], data['metadatas'], data['ids']):
                if id_ in existing_ids:
                    if overwrite_existing:
                        # Overwrite: Remove existing entry first
                        self.static_memory.delete(ids=[id_])
                        new_documents.append(doc)
                        new_metadatas.append(meta)
                        new_ids.append(id_)
                    else:
                        # Skip: Do not add duplicate entries
                        print(f"Skipped adding memory with ID: {id_} (already exists)")
                else:
                    # Add new entries
                    new_documents.append(doc)
                    new_metadatas.append(meta)
                    new_ids.append(id_)
            
            # Add new/updated data to the database
            if new_ids:
                self.static_memory.add(
                    documents=new_documents,
                    metadatas=new_metadatas,
                    ids=new_ids,
                )
                print(f"Added {len(new_ids)} static memories for agent {self.agent_name} successfully")
            else:
                print(f"No new static memories added for agent {self.agent_name}")
        
        return
    
    def add_short_term_memory(self, event, timestamp=None, overwrite_existing=False):
        """
        Adds a log of an event to the short-term memory.
        
        Args:
            event (str): The event to be logged.
            timestamp (str): The timestamp of the event. Defaults to the current time if not provided.
            overwrite_existing (bool): Whether to overwrite existing entries with the same ID.
        """
        if timestamp is None:
            timestamp = datetime.now().isoformat()
        
        importance_score = self._calculate_importance_score(event)
        print(importance_score)
        event_id = f"stm_{timestamp}"
        
        # Retrieve existing IDs in the database
        existing_ids = set(self.short_term_memory.get()["ids"])
        
        if event_id in existing_ids:
            if overwrite_existing:
                # Overwrite: Remove existing entry first
                self.short_term_memory.delete(ids=[event_id])
                self.short_term_memory.add(
                    documents=[event], 
                    metadatas=[{"timestamp": timestamp, "importance_score": importance_score}], 
                    ids=[event_id]
                )
                print(f"Overwritten event in short-term memory with ID: {event_id}")
            else:
                # Skip: Do not add duplicate entry
                print(f"Skipped adding event to short-term memory (ID already exists): {event_id}")
        else:
            # Add new entry
            self.short_term_memory.add(
                    documents=[event], 
                    metadatas=[{"timestamp": timestamp, "importance_score": importance_score}], 
                    ids=[event_id]
                )
            print(f"Added event to short-term memory: {event}")
        return
    
    def migrate_to_long_term_memory(self):
        """
        Migrates important short-term memories to long-term memory.
        """
        stm = self.short_term_memory.get(include=['documents', 'metadatas'])
        for doc, meta in zip(stm['documents'], stm['metadatas']):
            # Simple importance score: Longer-lasting events get higher priority
            importance_score = len(doc.split())
            self.long_term_memory.add(
                documents=[doc],
                metadatas=[{**meta, "importance": importance_score}],
                ids=[f"ltm_{meta['timestamp']}"],
            )
            # Remove the migrated memory from short-term memory
            self.short_term_memory.delete(ids=[f"stm_{meta['timestamp']}"])
        return
    
    def decay_long_term_memory(self):
        """
        Decays the importance of long-term memories.
        """
        ltm = self.long_term_memory.get(include=['documents', 'metadatas'])
        for doc, meta, id_ in zip(ltm['documents'], ltm['metadatas'], ltm['ids']):
            if 'importance' in meta:
                meta['importance'] *= 0.9
                if meta['importance'] < 1:
                    self.long_term_memory.delete(ids=[id_])
                else:
                    self.long_term_memory.update(ids=[id_], metadatas=[meta])
        return
    
    def summarize_and_forget(self):
        """
        Summarizes and forgets less important long-term memories.
        """
        ltm = self.long_term_memory.get()
        low_importance_docs = []
        ids_to_delete = []

        for doc, meta, id_ in zip(ltm['documents'], ltm['metadatas'], ltm['ids']):
            if 'importance' in meta and meta['importance'] < 5:
                low_importance_docs.append(doc)
                ids_to_delete.append(id_)
                
        if low_importance_docs:
            summary = " ".join(low_importance_docs)
            timestamp = datetime.now().isoformat()
            self.long_term_memory.add(
                documents=[f"Summary: {summary}"],
                metadatas=[{"timestamp": timestamp, "importance": 5}],
                ids=[f"summary_{timestamp}"]
            )
            self.long_term_memory.delete(ids=ids_to_delete)
        return
    
    def query_memory(self, query: str):
        """
        Queries memory based on hierarchy.
        
        Args:
            query (str): The query text.
        
        Returns:
            dict: The query results from static, short-term, and long-term memory.
        """
        # Query static memory first
        static_results = self.static_memory.query(query_texts=[query], n_results=3)
        if static_results["documents"]:
            return static_results

        # Query short-term memory next
        short_term_results = self.short_term_memory.query(query_texts=[query], n_results=3)
        if short_term_results["documents"]:
            return short_term_results

        # Finally, query long-term memory
        long_term_results = self.long_term_memory.query(query_texts=[query], n_results=3)
        return long_term_results
    
    def get_stats(self):
        """
        Get statistics of the memory.
        
        Args:
            memory_type (str): The type of memory to get statistics for. Can be 'static', 'short-term', 'long-term', or 'all'.
        
        Returns:
            dict: The statistics of the memory.
        """
        static_stats = self._get_stats(self.static_memory)
        short_term_stats = self._get_stats(self.short_term_memory)
        long_term_stats = self._get_stats(self.long_term_memory)
        return {
            "static_memory": static_stats,
            "short_term_memory": short_term_stats,
            "long_term_memory": long_term_stats,
        }
    
    # TODO @Danit: Implement the importance score calculation
    def _calculate_importance_score(self, doc: str):
        """
        Calculate the importance score of a document. The parameters considered are Relevency, Recency, Frequency, Personal Impact
        
        """
        relevency  = 1
        recency = 1
        frequency = 1
        personal_impact = 1

        weights = {
            "relevency": 0.25,
            "recency": 0.25,
            "frequency": 0.25,
            "personal_impact": 0.25 
        }
        
        importance_score = relevency * weights["relevency"] + recency * weights["recency"] + frequency * weights["frequency"] + personal_impact * weights["personal_impact"]
        
        return importance_score
    
    def _get_stats(self, memory:Collection):
        """
        Get statistics of the memory/vector database.

        Args:
            memory (Collection): The memory to get statistics for.
        """
        memory_data = memory.get()
        n_documents = len(memory_data["documents"])
        return {
            "n_documents": n_documents,
        }
    
    def clear_memory(self, memory_type="all"):
        """
        Clear all the memory of specified type.
        
        Args:
            memory_type (str): The type of memory to clear. Can be 'static', 'short-term', 'long-term', or 'all'.
        """
        if memory_type == "all":
            try: self.static_memory.delete(ids=self.static_memory.get()["ids"])
            except: pass    # Memory is already empty
            try: self.short_term_memory.delete(ids=self.short_term_memory.get()["ids"])
            except: pass    # Memory is already empty
            try: self.long_term_memory.delete(ids=self.long_term_memory.get()["ids"])
            except: pass    # Memory is already empty
            print("All memory cleared successfully")
        elif memory_type == "static":
            try: self.static_memory.delete(ids=self.static_memory.get()["ids"])
            except: pass    # Memory is already empty
            print("Static memory cleared successfully")
        elif memory_type == "short-term":
            try: self.short_term_memory.delete(ids=self.short_term_memory.get()["ids"])
            except: pass    # Memory is already empty
            print("Short-term memory cleared successfully")
        elif memory_type == "long-term":
            try: self.long_term_memory.delete(ids=self.long_term_memory.get()["ids"])
            except: pass    # Memory is already empty
            print("Long-term memory cleared successfully")
        else:
            print("Invalid memory type. Please specify 'static', 'short-term', 'long-term', or 'all'.")
            
        return
    
    

In [30]:
agent_name = "alex"

agent_memory = AgentMemory(agent_name)
agent_memory.add_static_memory()
agent_memory.add_short_term_memory("Alex interacted with a wild animal", timestamp="2021-09-01T12:00:00")
agent_memory.add_short_term_memory("Alex found a hidden treasure", timestamp="2021-09-02T12:00:00")
# agent_memory.migrate_to_long_term_memory()


1.0
Skipped adding event to short-term memory (ID already exists): stm_2021-09-01T12:00:00
1.0
Skipped adding event to short-term memory (ID already exists): stm_2021-09-02T12:00:00


In [4]:
# Example Queries
print("Querying Static Memory:", agent_memory.query_memory("Fire burns"))
print("Querying Long-Term Memory:", agent_memory.query_memory("Alex found"))

Querying Static Memory: {'ids': [['rule_1', 'env_1', 'char_1']], 'embeddings': None, 'documents': [['Rule: Fire burns wood', 'Environment: The forest is dense with trees and has wild animals', 'Name: Alex, Age: 25, Role: Warrior']], 'uris': None, 'data': None, 'metadatas': [[{'type': 'rule'}, {'type': 'environment'}, {'type': 'character_info'}]], 'distances': [[0.5438721637512911, 1.7832853887789941, 1.8180477565724145]], 'included': [<IncludeEnum.distances: 'distances'>, <IncludeEnum.documents: 'documents'>, <IncludeEnum.metadatas: 'metadatas'>]}
Querying Long-Term Memory: {'ids': [['char_1', 'rule_1', 'env_1']], 'embeddings': None, 'documents': [['Name: Alex, Age: 25, Role: Warrior', 'Rule: Fire burns wood', 'Environment: The forest is dense with trees and has wild animals']], 'uris': None, 'data': None, 'metadatas': [[{'type': 'character_info'}, {'type': 'rule'}, {'type': 'environment'}]], 'distances': [[1.024590927687949, 1.6749894819160671, 1.7917170688820596]], 'included': [<Incl

In [5]:
# Decay demonstration
print("Decaying Long-Term Memory...")
for _ in range(3):
    agent_memory.decay_long_term_memory()
    time.sleep(0.1)

# Summarization demonstration
print("Summarizing Low Importance Memories...")
agent_memory.summarize_and_forget()

Decaying Long-Term Memory...
Summarizing Low Importance Memories...
