/
base.py
158 lines (128 loc) · 4.7 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import math
from datetime import datetime, timedelta, tzinfo
from enum import Enum
from typing import Optional
from uuid import UUID, uuid4
import numpy as np
import pytz
from pydantic import BaseModel
from ..utils.embeddings import cosine_similarity, get_embedding
from ..utils.formatting import parse_array
from ..utils.parameters import (
IMPORTANCE_WEIGHT,
RECENCY_WEIGHT,
SIMILARITY_WEIGHT,
TIME_SPEED_MULTIPLIER,
)
class MemoryType(Enum):
    """Kinds of memory a ``SingleMemory`` can be tagged with.

    The member values are the lowercase strings emitted when a memory is
    serialized (``SingleMemory.db_dict`` writes ``type.value``).
    """

    OBSERVATION = "observation"
    REFLECTION = "reflection"
class SingleMemory(BaseModel):
    """One stored memory of an agent, plus helpers to score it against a query.

    Attributes:
        id: Unique identifier; generated with ``uuid4()`` when not supplied.
        agent_id: Identifier of the agent the memory belongs to.
        type: Kind of memory (see ``MemoryType``).
        description: Text of the memory.
        embedding: Vector compared against a query embedding via
            ``cosine_similarity``.
        importance: Integer weight used in the relevance score.
        created_at: Creation timestamp.
        last_accessed: Timestamp the recency decay is measured from.
        related_memory_ids: Identifiers of related memories.
    """

    id: UUID
    agent_id: UUID
    type: MemoryType
    description: str
    embedding: np.ndarray
    importance: int
    created_at: datetime
    last_accessed: datetime
    related_memory_ids: list[UUID]

    @property
    def recency(self) -> float:
        """Return the recency score ``0.99 ** elapsed``.

        ``elapsed`` is the time since ``last_accessed`` measured in units of
        ``1 / TIME_SPEED_MULTIPLIER`` hours, so a larger multiplier makes
        memories decay faster in wall-clock time.

        Side effect: a timezone-naive ``last_accessed`` is treated as UTC and
        the localized value is written back to the instance, so the aware
        ``datetime.now(pytz.utc)`` can be subtracted from it.
        """
        if self.last_accessed.tzinfo is None:
            self.last_accessed = pytz.utc.localize(self.last_accessed)

        elapsed = datetime.now(pytz.utc) - self.last_accessed
        # timedelta / timedelta yields a plain float number of decay units.
        last_retrieved_hours_ago = elapsed / timedelta(
            hours=1 / TIME_SPEED_MULTIPLIER
        )

        decay_factor = 0.99
        return math.pow(decay_factor, last_retrieved_hours_ago)

    @property
    def verbose_description(self) -> str:
        """Return the description followed by ``@`` and the creation time."""
        return f"{self.description} @ {self.created_at.strftime('%Y-%m-%d %H:%M:%S')}"

    class Config:
        # np.ndarray is not a type pydantic validates natively.
        arbitrary_types_allowed = True

    def __init__(
        self,
        agent_id: UUID,
        type: MemoryType,
        description: str,
        importance: int,
        embedding: np.ndarray,
        related_memory_ids: Optional[list[UUID]] = None,
        id: Optional[UUID] = None,
        created_at: Optional[datetime] = None,
        last_accessed: Optional[datetime] = None,
    ):
        """Build a memory, filling in defaults for the optional fields.

        Args:
            agent_id: Identifier of the owning agent.
            type: Kind of memory.
            description: Text of the memory.
            importance: Integer importance weight.
            embedding: Either an array-like, or a string that is handed to
                ``parse_array`` (presumably a serialized array as produced by
                ``db_dict`` - confirm against ``parse_array``).
            related_memory_ids: Related memory identifiers; ``None`` means
                no related memories.
            id: Existing identifier; a fresh ``uuid4()`` is used when ``None``.
            created_at: Creation time; the current UTC time is used when
                ``None``.
            last_accessed: Last access time; falls back to ``created_at``.

        Raises:
            ValueError: If the embedding does not end up as a ``np.ndarray``.
        """
        if id is None:
            id = uuid4()

        # Defaults are resolved per call. A ``datetime.now()`` default in the
        # signature would be evaluated once at import time and stamp every
        # memory with the module-load time; a ``[]`` default would be a single
        # list object shared by all calls.
        if created_at is None:
            created_at = datetime.now(tz=pytz.utc)
        if related_memory_ids is None:
            related_memory_ids = []

        if isinstance(embedding, str):
            embedding = parse_array(embedding)
        else:
            embedding = np.array(embedding)

        # Only the parse_array path can yield something other than an ndarray.
        if not isinstance(embedding, np.ndarray):
            raise ValueError("Embedding must be a numpy array")

        super().__init__(
            id=id,
            agent_id=agent_id,
            type=type,
            description=description,
            embedding=embedding,
            importance=importance,
            created_at=created_at,
            last_accessed=last_accessed or created_at,
            related_memory_ids=related_memory_ids,
        )

    def db_dict(self):
        """Return a dict of primitives (strings, ints, lists) for persistence.

        UUIDs become strings, the enum becomes its value, timestamps become
        ISO-8601 strings and the embedding becomes the string form of its
        list representation.
        """
        return {
            "id": str(self.id),
            "agent_id": str(self.agent_id),
            "type": self.type.value,
            "description": self.description,
            "embedding": str(self.embedding.tolist()),
            "importance": self.importance,
            "created_at": self.created_at.isoformat(),
            "last_accessed": self.last_accessed.isoformat()
            if self.last_accessed
            else None,
            "related_memory_ids": [
                str(related_memory_id) for related_memory_id in self.related_memory_ids
            ],
        }

    # Customize the printing behavior
    def __str__(self):
        """Return ``[TYPE] - description (importance)``."""
        return f"[{self.type.name}] - {self.description} ({round(self.importance, 1)})"

    def update_last_accessed(self):
        """Set ``last_accessed`` to the current UTC time."""
        self.last_accessed = datetime.now(tz=pytz.utc)

    async def similarity(self, query: str) -> float:
        """Return the cosine similarity between this memory and ``query``.

        The query is embedded with ``get_embedding`` on every call.
        """
        query_embedding = await get_embedding(query)
        return cosine_similarity(self.embedding, query_embedding)

    async def relevance(self, query: str) -> float:
        """Return the weighted sum of importance, similarity and recency.

        The weights are ``IMPORTANCE_WEIGHT``, ``SIMILARITY_WEIGHT`` and
        ``RECENCY_WEIGHT`` from the parameters module.
        """
        return (
            IMPORTANCE_WEIGHT * self.importance
            + SIMILARITY_WEIGHT * (await self.similarity(query))
            + RECENCY_WEIGHT * self.recency
        )
class RelatedMemory(BaseModel):
    """Pairs a ``SingleMemory`` with a relevance score."""

    memory: SingleMemory
    relevance: float

    def __str__(self) -> str:
        """Return a one-line summary: the memory's description and its score."""
        return f"SingleMemory: {self.memory.description}, Relevance: {self.relevance}"
async def get_relevant_memories(
    query: str, memories: list[SingleMemory], k: int = 5
) -> list[SingleMemory]:
    """Return the ``k`` most relevant memories for ``query``, oldest first.

    Every memory is scored with ``SingleMemory.relevance``. The ``k``
    highest-scoring memories are kept and then ordered by ``created_at``
    ascending. No filtering by memory type is performed here.

    Args:
        query: Text the memories are scored against.
        memories: Candidate memories.
        k: Maximum number of memories to return; values <= 0 yield ``[]``.

    Returns:
        At most ``k`` memories, sorted chronologically by ``created_at``.
    """
    # Guard first: with a negative k, ``[:k]`` would drop the last |k| items
    # instead of selecting none, and there is no point scoring anything when
    # nothing can be returned.
    if k <= 0 or not memories:
        return []

    # NOTE(review): each memory.relevance(query) call awaits
    # get_embedding(query) again, so the same query is embedded once per
    # memory unless get_embedding caches - worth confirming.
    memories_with_relevance = [
        RelatedMemory(memory=memory, relevance=await memory.relevance(query))
        for memory in memories
    ]

    # Highest relevance first; sorted() is stable, so ties keep input order.
    sorted_by_relevance = sorted(
        memories_with_relevance, key=lambda scored: scored.relevance, reverse=True
    )

    # Unwrap the top k back into plain SingleMemory objects.
    top_memories = [scored.memory for scored in sorted_by_relevance[:k]]

    # Present the selection chronologically, oldest memories first.
    return sorted(top_memories, key=lambda memory: memory.created_at)