In [1]:
import chromadb
from chromadb.config import Settings
from datetime import datetime, timedelta
import time
import json
from chromadb import Collection
import requests
import json
import math
import numpy as np
import dspy

In [2]:
URI_RATING="http://127.0.0.1:5000/rate"

In [None]:
lm = dspy.LM('ollama_chat/llama3.1:8b', api_base='http://localhost:11434', api_key='')
dspy.configure(lm=lm)

In [3]:
class AgentMemory:
    def __init__(self, agent_name, client_path = "./database/mem_test"):
        self.agent_name = agent_name
        self.client = chromadb.PersistentClient(path = client_path+f"/{agent_name}")
        self.static_memory = self.client.get_or_create_collection("static_memory")
        self.short_term_memory = self.client.get_or_create_collection("short_term_memory")
        self.long_term_memory = self.client.get_or_create_collection("long_term_memory")
        return
    
    def add_static_memory(self, overwrite_existing=False, personality_path='./agent_personalities'):
        """
        Adds static memory for an agent from a JSON file to the static memory database.
        
        Args:
            overwrite_existing (bool): Whether to overwrite existing entries with the same IDs.
        """
        
        with open(personality_path+f'/agent_{self.agent_name}.json') as f:
            data = json.load(f)
            
            # Retrieve existing IDs in the database
            existing_ids = set(self.static_memory.get()["ids"])
            
            new_documents = []
            new_metadatas = []
            new_ids = []
            
            for doc, meta, id_ in zip(data['documents'], data['metadatas'], data['ids']):
                if id_ in existing_ids:
                    if overwrite_existing:
                        # Overwrite: Remove existing entry first
                        self.static_memory.delete(ids=[id_])
                        new_documents.append(doc)
                        new_metadatas.append(meta)
                        new_ids.append(id_)
                    else:
                        # Skip: Do not add duplicate entries
                        print(f"Skipped adding memory with ID: {id_} (already exists)")
                else:
                    # Add new entries
                    new_documents.append(doc)
                    new_metadatas.append(meta)
                    new_ids.append(id_)
            
            # Add new/updated data to the database
            if new_ids:
                self.static_memory.add(
                    documents=new_documents,
                    metadatas=new_metadatas,
                    ids=new_ids,
                )
                print(f"Added {len(new_ids)} static memories for agent {self.agent_name} successfully")
            else:
                print(f"No new static memories added for agent {self.agent_name}")
        
        return
    
    def add_short_term_memory(self, event, timestamp=None, overwrite_existing=False):
        """
        Adds a log of an event to the short-term memory.
        
        Args:
            event (str): The event to be logged.
            timestamp (str): The timestamp of the event. Defaults to the current time if not provided.
            overwrite_existing (bool): Whether to overwrite existing entries with the same ID.
        """
        if timestamp is None:
            timestamp = datetime.now().isoformat()
        
        importance_score = self._initial_recency_and_importance_score_calculator(event,timestamp)
        print(importance_score)
        event_id = f"stm_{timestamp}"
        
        """ 
            for development purpose only to prevent multiple ids of same timestamp. though really difficult 
            to have such a event but for preventive mesures.If the agent want to rewrite a previous event then it would be better to have a 
            new event rather than correting a previous one
        """
        
        # Retrieve existing IDs in the database
        existing_ids = set(self.short_term_memory.get()["ids"])
        if event_id in existing_ids:
            if overwrite_existing:
                # Overwrite: Remove existing entry first
                self.short_term_memory.delete(ids=[event_id])
                self.short_term_memory.add(
                    documents=[event], 
                    metadatas=[{"timestamp": timestamp, "importance_score": importance_score}], 
                    ids=[event_id]
                )
                print(f"Overwritten event in short-term memory with ID: {event_id}")
            else:
                # Skip: Do not add duplicate entry
                print(f"Skipped adding event to short-term memory (ID already exists): {event_id}")
        else:
            # Add new entry
            self.short_term_memory.add(
                    documents=[event],
                    metadatas=[{"timestamp": timestamp, "importance_score": importance_score}], 
                    ids=[event_id]
                )
            print(f"Added event to short-term memory: {event}")
        return
    
    def migrate_to_long_term_memory(self):
        """
        Migrates the oldest document from short-term memory to long-term memory.
        The document is deleted from short-term memory fter migration.
        
        Returns:
            bool: True if migration was successful, False if no documents to migrate
        """
        # Get all documents from short-term memory
        stm_data = self.short_term_memory.get()
        
        length_stm = len(stm_data['documents'])
        
        # Check if there are any documents to migrate
        if length_stm <= 3 :
            print("Not enough documents to migrate to long-term memory")
            return
            
        # Zip together documents, metadata, and ids for easier processing
        memory_entries = list(zip(
            stm_data['documents'], 
            stm_data['metadatas'], 
            stm_data['ids']
        ))
        
        # Sort by timestamp in ascending order to get the oldest entry
        oldest_entries = sorted(
            memory_entries,
            key=lambda x: x[1]['timestamp']
        )[:length_stm-3]
        
        for i in range(len(oldest_entries)):
        
            oldest_document, oldest_metadata, oldest_id = oldest_entries[i]
            
            try:
                # Delete from short-term memory
                self.short_term_memory.delete(
                    ids=[oldest_id]
                )
                
            except Exception as e:
                pass
            
            if(oldest_metadata['importance_score'] < 0.5):
                continue
            
            # Create new ID for long-term memory
            ltm_id = f"ltm_{oldest_metadata['timestamp']}"
            
            try:
                # Add to long-term memory
                self.long_term_memory.add(
                    documents=[oldest_document],
                    metadatas=[oldest_metadata],
                    ids=[ltm_id]
                )
                
                print(f"Successfully migrated memory from short-term to long-term: {oldest_document}")
                
            except Exception as e:
                print(f"Error during migration: {str(e)}")
    
    def decay_long_term_memory(self):
        """
        Decays the importance of long-term memories.
        """
        ltm = self.long_term_memory.get(include=['documents', 'metadatas'])
        for doc, meta, id_ in zip(ltm['documents'], ltm['metadatas'], ltm['ids']):
            if 'importance' in meta:
                meta['importance'] *= 0.9
                if meta['importance'] < 1:
                    self.long_term_memory.delete(ids=[id_])
                else:
                    self.long_term_memory.update(ids=[id_], metadatas=[meta])
        return
        
    def summarize_and_forget(self):
        """
        
        summarize the 5 least important memories in the long-term memory and delete them.
        
        """
        ltm_data = self.long_term_memory.get()
        if len(ltm_data['documents']) >= 5:

            ltm_entries = list(zip(
                ltm_data['documents'],
                ltm_data['metadatas'],
                ltm_data['ids']
            ))
            
            
            least_important = sorted(
                ltm_entries,
                key=lambda x: x[1]['importance_score']
            )[:5]
            

            combined_content = " ".join(entry[0] for entry in least_important)
            avg_importance = sum(entry[1]['importance_score'] for entry in least_important) / 5
            
            try:
                question = f"Summarize the following content in exactly two lines:\n{combined_content}"
                gen = dspy.Predict('question -> answer')
                summary = gen(question = question).answer
                
                old_ids = [entry[2] for entry in least_important]
                self.long_term_memory.delete(ids=old_ids)
                
                # Add consolidated summary
                new_metadata = {
                    'timestamp': time.time(),
                    'importance_score': avg_importance,
                    'type': 'consolidated_summary'
                }
                
                self.long_term_memory.add(
                    documents=[summary],
                    metadatas=[new_metadata],
                    ids=[f"consolidated_{int(time.time())}"]
                )
                
                print(f"Successfully consolidated {len(least_important)} memories into summary")
                
            except Exception as e:
                print(f"Error during consolidation: {str(e)}")
                
        else:
            print("Not enough memories in long-term memory to consolidate")
    
    def query_memory(self, prompt: str, top_k: int = 3, score_threshold: float = 0.35) -> list[str]:
        """
        Smart querying mechanism to retrieve context from memories.

        Args:
            prompt (str): The query or prompt for which context is being retrieved.
            top_k (int): Number of top matches to retrieve from each memory.
            score_threshold (float): Threshold to filter out low score context.

        Returns:
            List[str]: The final context selected after smart querying and validation.
        """
        # Helper function to query a memory based on total score, not just similarity
        def query_local_memory(
                memory:Collection, 
                prompt:str, 
                top_k:int, 
                similarity_weight:float=1.0, 
                importance_weight:float=1.0
            ) -> list[str]:
            # Must change this to query based on importance score instead
            # First, query everything
            # query() return list of list of objects, while get() only returns list of objects
            try :
                results = memory.query(
                    query_texts=[prompt],
                    n_results=memory.count(),
                    include=["metadatas", "distances"]
                )
            except:
                return []
            # Now, calculate the similarity and total score
            similarities = [1/(1+dist) for dist in results["distances"][0]]
            # If the similarities are None, that means the query returns no results -> return empty context
            if not similarities:
                return []
            scores = np.array([
                (
                    importance_weight * meta['importance_score'] + 
                    similarity_weight * sim
                ) if 'importance_score' in meta else similarity_weight * sim
                for meta, sim in zip(results["metadatas"][0], similarities)
            ])
            # Sort the scores in descending order, get the top k ids
            top_k_indices = np.argsort(scores)[::-1][:top_k].astype(int)
            top_k_ids = [results["ids"][0][i] for i in top_k_indices]
            print(f"top_k_ids: {top_k_ids}")
            # Get the documents, metadatas, distances
            final_results = memory.get(
                ids=top_k_ids,
                include=["documents", "metadatas"]
            )
            documents = final_results["documents"] # list of string already
            metadatas = final_results["metadatas"] # list of json strings
            final_scores = scores[top_k_indices]
            distances = np.array(results["distances"][0])[top_k_indices]
            print(f"final_scores: {final_scores}")
            
            context_local = []
            for doc, meta, dist, score in zip(documents, metadatas, distances, final_scores):
                # Filter out irrelevant context
                if score >= score_threshold:
                    context_local.append(
                        {
                            "document": doc, 
                            "metadata": meta, 
                            "distances": dist,
                            "score": score
                        }
                    )
            return context_local

        # Priority order: Short-Term & Static -> Long-Term
        context = []
        for memory_name, memory in [("short_term", self.short_term_memory), ("static", self.static_memory)]:
            results = query_local_memory(memory, prompt, top_k)
            # Maybe change list to set
            context.extend(results)
            if len(context) > top_k:
                break  # Stop if we have enough context

        # If context is still insufficient, fall back to long-term memory
        if len(context) < top_k:
            print("Falling back to long-term memory...")
            long_term_results = query_local_memory(self.long_term_memory, prompt, top_k - len(context))
            context.extend(long_term_results)

        # Sort the context by highest score in descending order
        if context:
            context = sorted(context, key=lambda x: x["score"], reverse=True)
        return self._format_context_output(prompt, context)
    
    def plan_from_memory(
            self, 
            timestamp: str = None,
            top_k: int = 20, 
            score_threshold: float = 0.35
    ) -> None:
        """
        (NOT DONE! DON'T USE YET) Plan an action based on the context from memories and add it to long-term memory.

        Args:
            timestamp (str): The timestamp of the current game time. Defaults to the current time if not provided.
            top_k (int): Number of top matches to retrieve from all memory.
            score_threshold (float): Threshold to filter out low score context.
        """

        prompt_for_context = """
            Answer the following question based on the context from memories:
            What is your name?
            What is the summary of your character and traits?
            What is your passion project currently, and how do you plan to achieve it?
            What actions did you do recently, up until 1 day ago?
        """

        context = self.query_memory(
            prompt_for_context, 
            top_k=top_k, 
            score_threshold=score_threshold
        )

        if timestamp is None:
            timestamp = datetime.now().isoformat()

        prompt_for_plan = f"""
            Based on the given context, fill in the following template. 
            Do not hallucinate. If you don't know the answer, say "I don't know".
            <plan> is in format: 1) <action_1> at <time_1> 2) <action_2> at <time_2> 3) <action_3> at <time_3> ...
            Context: {context}

            Only return this part:

            Name: <name>
            Traits: <traits>
            Passion Project: <passion_project>
            Recent Actions: <recent_actions>
            Today is {timestamp}. Here is <name>'s plan today in broad strokes:
            <plan>
        """

        ## get the result from LLM server
        try:
            # response = requests.post(<URI>,json={'memory': prompt_for_plan})
            # response.raise_for_status()  # Raise an exception for bad status codes
            # result = response.json()
            # if isinstance(result, str):
                # result = json.loads(result) 

            # TODO: change this format to be compatible with the model's return
            # plan = result['plan']
            plan = "Plan: " + plan
        except requests.exceptions.RequestException as e:
            print(f"Error calling server: {e}")

        ## add to long-term memory, overwrite the old plan
        # delete the old plan if it exists
        try:    
            self.long_term_memory.delete(ids=[f"plan_1"])
        except: pass
        # add the new plan
        self.long_term_memory.add(
            documents=[plan],
            metadatas=[{"timestamp": timestamp, "importance_score": 8}], # Force model to prioritize planning with high score
            ids=[f"plan_1"] # There can be only one plan per agent per day.
        )

    def get_stats(self):
        """
        Get statistics of the memory.
        
        Args:
            memory_type (str): The type of memory to get statistics for. Can be 'static', 'short-term', 'long-term', or 'all'.
        
        Returns:
            dict: The statistics of the memory.
        """
        static_stats = self._get_stats(self.static_memory)
        short_term_stats = self._get_stats(self.short_term_memory)
        long_term_stats = self._get_stats(self.long_term_memory)
        return {
            "static_memory": static_stats,
            "short_term_memory": short_term_stats,
            "long_term_memory": long_term_stats,
        }

    def calculate_recency_score(timestamp: str, current_time: str = None, decay_factor: float = 0.995) -> float:
        """
        Calculate the recency score using an exponential decay function.
        
        Args:
            timestamp (str): The timestamp string in ISO format of when the memory was created/last accessed
            current_time (str): The current time string in ISO format to compare against (default: current time)
            decay_factor (float): The decay factor for the exponential decay (default: 0.995)
        
        Returns:
            float: A recency score between 0 and 1, where 1 indicates most recent
        """
        if timestamp is None:
            return 1.0  # Return maximum score if no timestamp provided
            
        # Convert timestamp strings to datetime objects
        if current_time is None:
            current_time = datetime.now().isoformat()
            
        timestamp_dt = datetime.fromisoformat(timestamp)
        current_time_dt = datetime.fromisoformat(current_time)
        
        # Calculate the time difference in hours
        time_diff = current_time_dt - timestamp_dt
        hours_diff = time_diff.total_seconds() / (3600*3)
        
        # Calculate recency score using exponential decay
        recency_score = math.pow(decay_factor, hours_diff)
        
        # Ensure the score is between 0 and 1
        recency_score = max(0.0, min(1.0, recency_score))
        
        return recency_score

    def _initial_recency_and_importance_score_calculator(self,doc: str, timestamp: str):
        """
        Calculate the importance score of a document. 
        The parameters considered are Relevency, Recency, Frequency, Personal Impact

        Args:
            doc (str): The document text
            timestamp (str): ISO format timestamp string
        """
        recency = self.calculate_recency_score(timestamp)

        # Call the memory rating server
        try:
            response = requests.post(URI_RATING,json={'memory': doc})
            response.raise_for_status()  # Raise an exception for bad status codes
            result = response.json()
            if isinstance(result, str):
                result = json.loads(result) 

            print(result["rating"])  # Now you can access the rating
            personal_impact = int(result['rating'])    
        except requests.exceptions.RequestException as e:
            print(f"Error calling memory rating server: {e}")
            personal_impact = 5  # Default value if the server call fails

        weights = {
            "recency": 0.25,
            "personal_impact": 0.25
        }
        
        importance_score = (
            recency * weights["recency"] + 
            personal_impact * weights["personal_impact"]
        )
        
        return importance_score

    def _format_context_output(self, prompt: str, context: list[str], separator: str = "\n---\n") -> str:
        """
        Formats the retrieved context into a final text output with separators.

        Args:
            prompt (str): The input prompt for which context is being retrieved.
            context (List[str]): List of retrieved documents to be used as context.
            separator (str): Separator string to distinguish between context documents.

        Returns:
            str: A formatted text output containing the context for the given prompt.
        """      
        # Header for the final output
        formatted_output = f"Context for Prompt: '{prompt}'\n\n"

        # NOTE: handle in case no matched context found here
        if not context:
            formatted_output += "No important context found"
            return formatted_output

        # Add each document separated by the specified separator
        for idx, document in enumerate(context, 1):
            # formatted_output += f"[Document {idx}]:\n{document['document'], document['distances']}{separator}"
            formatted_output += f"[Document {idx}]:\n{document['document']}{separator}"

        # Remove the last separator for cleaner formatting
        formatted_output = formatted_output.rstrip(separator)

        return formatted_output
    
    def _get_stats(self, memory:Collection):
        """
        Get statistics of the memory/vector database.

        Args:
            memory (Collection): The memory to get statistics for.
        """
        memory_data = memory.get()
        n_documents = len(memory_data["documents"])
        return {
            "n_documents": n_documents,
        }
    
    def clear_memory(self, memory_type="all"):
        """
        Clear all the memory of specified type.
        
        Args:
            memory_type (str): The type of memory to clear. Can be 'static', 'short-term', 'long-term', or 'all'.
        """
        if memory_type == "all":
            try: self.static_memory.delete(ids=self.static_memory.get()["ids"])
            except: pass    # Memory is already empty
            try: self.short_term_memory.delete(ids=self.short_term_memory.get()["ids"])
            except: pass    # Memory is already empty
            try: self.long_term_memory.delete(ids=self.long_term_memory.get()["ids"])
            except: pass    # Memory is already empty
            print("All memory cleared successfully")
        elif memory_type == "static":
            try: self.static_memory.delete(ids=self.static_memory.get()["ids"])
            except: pass    # Memory is already empty
            print("Static memory cleared successfully")
        elif memory_type == "short-term":
            try: self.short_term_memory.delete(ids=self.short_term_memory.get()["ids"])
            except: pass    # Memory is already empty
            print("Short-term memory cleared successfully")
        elif memory_type == "long-term":
            try: self.long_term_memory.delete(ids=self.long_term_memory.get()["ids"])
            except: pass    # Memory is already empty
            print("Long-term memory cleared successfully")
        else:
            print("Invalid memory type. Please specify 'static', 'short-term', 'long-term', or 'all'.")
            
        return
    

In [4]:
agent_name = "alex"

agent_memory = AgentMemory(agent_name)

agent_memory.add_static_memory()

Skipped adding memory with ID: char_1 (already exists)
Skipped adding memory with ID: rule_1 (already exists)
Skipped adding memory with ID: env_1 (already exists)
Skipped adding memory with ID: rule_2 (already exists)
No new static memories added for agent alex


In [16]:
# # ADD DUMMY DATA TO SHORT TERM AND LONG TERM MEMORY

# with open ("./agent_personalities/agent_alex_short_mem.json") as f:
#     stm_data = json.load(f)

# with open ("./agent_personalities/agent_alex_long_mem.json") as f:
#     ltm_data = json.load(f)

# agent_memory.short_term_memory.add(documents=stm_data['documents'], metadatas=stm_data['metadatas'], ids=stm_data['ids'])
# agent_memory.long_term_memory.add(documents=ltm_data['documents'], metadatas=ltm_data['metadatas'], ids=ltm_data['ids'])

# print()
# print(agent_memory.get_stats())


In [5]:
# Example Usage
prompt = "What is my name?"
context = agent_memory.query_memory(prompt, top_k=5, score_threshold=0.35)
print("Retrieved Context:")
print(context)


Number of requested results 5 is greater than number of elements in index 3, updating n_results = 3
Number of requested results 5 is greater than number of elements in index 4, updating n_results = 4


top_k_ids: ['exp_2', 'exp_1', 'exp_3']
final_scores: [0.4404222  0.36492197 0.35260854]
top_k_ids: ['env_1', 'rule_1', 'char_1', 'rule_2']
final_scores: [0.51141153 0.43352968 0.37362446 0.32506791]
Retrieved Context:
Context for Prompt: 'What did i do in the forest?'

[Document 1]:
Name: Alex, Age: 25, Role: Warrior
---
[Document 2]:
Timestamp: 2023-10-01T10:00:00Z, Location: Coastal Town, Interaction: Met with local fishermen to discuss the impact of recent storms on fish populations, Actions: Collected water samples for analysis, Responses: Fishermen expressed concerns about declining fish stocks
---
[Document 3]:
Rule: Fire burns wood
---
[Document 4]:
Environment: The forest is dense with trees and has wild animals
---
[Document 5]:
Timestamp: 2023-10-02T14:30:00Z, Location: Forest, Interaction: Observed wildlife and noted changes in animal behavior due to seasonal shifts, Actions: Set up camera traps to monitor animal movements, Responses: Noticed an increase in nocturnal activit

In [6]:
agent_memory.short_term_memory.get()

{'ids': ['exp_1', 'exp_2', 'exp_3'],
 'embeddings': None,
 'documents': ['Timestamp: 2023-10-01T10:00:00Z, Location: Coastal Town, Interaction: Met with local fishermen to discuss the impact of recent storms on fish populations, Actions: Collected water samples for analysis, Responses: Fishermen expressed concerns about declining fish stocks',
  'Timestamp: 2023-10-02T14:30:00Z, Location: Forest, Interaction: Observed wildlife and noted changes in animal behavior due to seasonal shifts, Actions: Set up camera traps to monitor animal movements, Responses: Noticed an increase in nocturnal activity among certain species',
  'Timestamp: 2023-10-03T09:00:00Z, Location: Research Lab, Interaction: Analyzed water samples collected from the coastal town, Actions: Conducted tests to measure pollution levels, Responses: Found elevated levels of contaminants, likely from recent industrial activity'],
 'uris': None,
 'data': None,
 'metadatas': [{'type': 'recent_experience'},
  {'type': 'recent_exp

In [18]:
agent_memory.long_term_memory.get()

{'ids': [],
 'embeddings': None,
 'documents': [],
 'uris': None,
 'data': None,
 'metadatas': [],
 'included': [<IncludeEnum.documents: 'documents'>,
  <IncludeEnum.metadatas: 'metadatas'>]}