<a href="https://colab.research.google.com/github/SynFinAck/MemoryManager/blob/main/MemoryManager_tests.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers



In [None]:
class MemoryManager:
    """Tiered key/value store: active, short-term, long-term and a RAG cache.

    ``retrieve_memory`` looks a key up in the order
    active -> short-term -> long-term -> rag_cache and, as a side effect,
    moves values between tiers (see that method).

    NOTE: the methods call ``time.time()``; ``time`` is imported in the
    notebook's import cell, which must run before any method is used.
    """

    def __init__(self, short_term_ttl=60):
        """Create an empty manager.

        Args:
            short_term_ttl: Seconds a short-term entry may go unaccessed before
                the next lookup demotes it to long-term memory. Defaults to 60,
                the threshold that was previously hard-coded.
        """
        self.short_term_ttl = short_term_ttl
        self.active_memory = {}       # key -> value
        self.short_term_memory = {}   # key -> {"value": value, "last_accessed": time.time() stamp}
        self.long_term_memory = {}    # key -> value
        self.rag_cache = {}           # key -> value; consulted only as the last fallback

    def store_active_memory(self, key, value):
        """Store ``value`` under ``key`` in active memory, overwriting any old value."""
        self.active_memory[key] = value

    def store_short_term_memory(self, key, value):
        """Store ``value`` under ``key`` in short-term memory, stamped with the current time."""
        self.short_term_memory[key] = {
            "value": value,
            "last_accessed": time.time(),
        }

    def store_long_term_memory(self, key, value):
        """Store ``value`` under ``key`` in long-term memory, overwriting any old value."""
        self.long_term_memory[key] = value

    def retrieve_memory(self, key):
        """Return the value stored under ``key``, or ``None`` if no tier holds it.

        Side effects by tier of the hit:
          * active: none.
          * short-term, last accessed more than ``short_term_ttl`` seconds ago:
            the entry is removed from short-term and stored in long-term memory.
          * short-term, still fresh: its timestamp is refreshed and the value is
            copied into active memory.
          * long-term: the value is copied into short-term memory.
          * rag_cache: none.
        """
        if key in self.active_memory:
            return self.active_memory[key]

        if key in self.short_term_memory:
            entry = self.short_term_memory[key]
            value = entry["value"]
            now = time.time()
            # Measure staleness against the *previous* access time. Refreshing
            # the timestamp first (as before) made this difference ~0, so entries
            # could never expire.
            if now - entry["last_accessed"] > self.short_term_ttl:
                del self.short_term_memory[key]
                self.store_long_term_memory(key, value)
            else:
                entry["last_accessed"] = now
                self.store_active_memory(key, value)
            return value

        if key in self.long_term_memory:
            value = self.long_term_memory[key]
            self.store_short_term_memory(key, value)
            return value

        # Last fallback; yields None when the key is absent here too.
        return self.rag_cache.get(key)

    def cache_memory(self, key, value):
        """Store ``value`` under ``key`` in the RAG cache, overwriting any old value."""
        self.rag_cache[key] = value

In [None]:
import json
import time
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# Load the evaluation dataset. populate_memory / calculate_accuracy expect a
# top-level "conversations" list whose items each hold a "messages" list.
DATASET_PATH = '/content/drive/MyDrive/Datasets/context_dataset.json'
# JSON is UTF-8; don't depend on the platform's default locale encoding.
with open(DATASET_PATH, 'r', encoding='utf-8') as file:
    dataset = json.load(file)

In [None]:
def populate_memory(memory_manager, dataset):
    """Seed ``memory_manager``'s long-term tier with every message in ``dataset``.

    Visits ``dataset["conversations"][i]["messages"][j]`` in order and calls
    ``memory_manager.store_long_term_memory(message["key"], message["value"])``
    for each one. If the store overwrites on repeated keys (as a plain dict
    assignment does), the last message with a given key wins.
    """
    every_message = (
        msg
        for convo in dataset["conversations"]
        for msg in convo["messages"]
    )
    for msg in every_message:
        memory_manager.store_long_term_memory(msg["key"], msg["value"])

def _score_messages(dataset, memory_manager=None):
    """Score every message in ``dataset``; shared by the two public metrics.

    With a ``memory_manager``, a message passes when
    ``memory_manager.retrieve_memory(message["key"])`` equals ``message["value"]``.
    Without one, it passes when ``message["value"]`` equals
    ``message.get("expected_value")``.

    Returns:
        ``(percent_passed, failed)``. ``percent_passed`` is in [0, 100], or 0
        when the dataset has no messages. ``failed`` is a list of dicts with keys
        ``"key"``, ``"expected_value"`` (the value compared against) and
        ``"retrieved_value"`` (the value that was checked).
    """
    total = 0
    passed = 0
    failed = []
    for conversation in dataset["conversations"]:
        for message in conversation["messages"]:
            total += 1
            # Compare with None explicitly: a manager object defining
            # __len__/__bool__ could be falsy and silently skip retrieval.
            if memory_manager is not None:
                expected = message["value"]
                actual = memory_manager.retrieve_memory(message["key"])
            else:
                expected = message.get("expected_value", None)
                actual = message["value"]
            if actual == expected:
                passed += 1
            else:
                # Record the value that was actually compared against. The old
                # code logged message.get("expected_value") in the manager
                # branch, which was usually None and hid the real mismatch.
                failed.append({"key": message["key"], "expected_value": expected, "retrieved_value": actual})
    percent = (passed / total) * 100 if total > 0 else 0
    return percent, failed


def calculate_accuracy(dataset, memory_manager=None):
    """Return ``(accuracy_percent, failed_messages)`` for ``dataset``.

    With ``memory_manager``, each message's key is retrieved and compared to
    its ``"value"``. Note that ``retrieve_memory`` may move entries between
    tiers. Without a manager, each ``"value"`` is compared to its
    ``"expected_value"``.
    """
    return _score_messages(dataset, memory_manager)


def calculate_memory_retention(dataset, memory_manager=None):
    """Return ``(retention_percent, failed_messages)`` for ``dataset``.

    NOTE(review): this computes the same metric as ``calculate_accuracy``. When
    it runs after that function with the same manager, it sees the tier moves
    the earlier pass caused.
    """
    return _score_messages(dataset, memory_manager)

In [None]:

# --- Experiment driver -------------------------------------------------------
memory_manager = MemoryManager()

# NOTE(review): `tokenizer` and `llm_model` are not used in this cell. Presumably
# a later cell needs them; confirm before dropping the GPT-2 download.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
llm_model = AutoModelForCausalLM.from_pretrained(model_name)

# Seed long-term memory from the dataset loaded earlier.
populate_memory(memory_manager, dataset)

# Scored with the manager. Order matters: retrieve_memory moves entries between
# tiers, so the retention pass sees the state left behind by the accuracy pass.
accuracy_with_memory_manager, failed_messages_accuracy = calculate_accuracy(dataset, memory_manager)
memory_retention_with_memory_manager, failed_messages_retention = calculate_memory_retention(dataset, memory_manager)

# Baseline without a manager: each message's "value" vs its "expected_value".
accuracy_without_memory_manager, failed_messages_accuracy_without = calculate_accuracy(dataset)
memory_retention_without_memory_manager, failed_messages_retention_without = calculate_memory_retention(dataset)


def _print_failed_messages(heading, failures):
    """Print ``heading`` and then one line per failed-message record."""
    print(heading)
    for record in failures:
        print(f"Key: {record['key']}, Expected Value: {record['expected_value']}, Retrieved Value: {record['retrieved_value']}")


print("With Memory Manager:")
print("Accuracy: {:.2f}%".format(accuracy_with_memory_manager))
print("Memory Retention: {:.2f}%".format(memory_retention_with_memory_manager))
_print_failed_messages("\nFailed Messages (Accuracy - With Memory Manager):", failed_messages_accuracy)
_print_failed_messages("\nFailed Messages (Memory Retention - With Memory Manager):", failed_messages_retention)

print("\nWithout Memory Manager:")
print("Accuracy: {:.2f}%".format(accuracy_without_memory_manager))
print("Memory Retention: {:.2f}%".format(memory_retention_without_memory_manager))
_print_failed_messages("\nFailed Messages (Accuracy - Without Memory Manager):", failed_messages_accuracy_without)
_print_failed_messages("\nFailed Messages (Memory Retention - Without Memory Manager):", failed_messages_retention_without)

With Memory Manager:
Accuracy: 70.00%
Memory Retention: 70.00%

Failed Messages (Accuracy - With Memory Manager):
Key: confirmation, Expected Value: None, Retrieved Value: Your appointment is scheduled for 2 PM next Thursday.
Key: appointment_request, Expected Value: None, Retrieved Value: I need to schedule a dental appointment.
Key: availability, Expected Value: None, Retrieved Value: Is the laptop currently in stock?
Key: confirmation, Expected Value: None, Retrieved Value: Your appointment is scheduled for 2 PM next Thursday.
Key: reminder, Expected Value: None, Retrieved Value: Please arrive 15 minutes early for your appointment.
Key: resolution, Expected Value: None, Retrieved Value: Restarting the modem fixed the issue. Thank you for your help.
Key: product_inquiry, Expected Value: None, Retrieved Value: I'm looking to purchase a new laptop.
Key: features, Expected Value: None, Retrieved Value: What are the key features of the software?
Key: pricing, Expected Value: None, Retrie