# Module 19: Agent Memory & Planning

**Goal:** Learn how to give agents persistent memory, manage context windows, and plan multi-step tasks.

**Prerequisites:** Module 18 (Tool Calling)

**Expected Runtime:** ~25 minutes

**Outputs:**
- Implemented different memory types
- Built context management strategies
- Created a planning agent

---

## Setup

In [None]:
import json
import re
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, field
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

## Part 1: Conversation Memory

In [None]:
@dataclass
class Message:
    role: str  # 'user', 'assistant', 'system'
    content: str
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    tokens: int = field(init=False, default=0)  # computed, not passed in
    
    def __post_init__(self):
        # Rough heuristic: ~4 characters per token for English text
        self.tokens = len(self.content) // 4

class ConversationMemory:
    """Basic conversation memory with sliding window."""
    
    def __init__(self, max_tokens: int = 4000):
        self.messages: List[Message] = []
        self.max_tokens = max_tokens
    
    def add(self, role: str, content: str):
        msg = Message(role=role, content=content)
        self.messages.append(msg)
        self._enforce_limit()
    
    def _enforce_limit(self):
        """Remove oldest messages if over token limit."""
        total = sum(m.tokens for m in self.messages)
        while total > self.max_tokens and len(self.messages) > 1:
            removed = self.messages.pop(0)
            total -= removed.tokens
            print(f"[Memory] Removed old message ({removed.tokens} tokens)")
    
    def get_context(self) -> List[Dict]:
        return [{"role": m.role, "content": m.content} for m in self.messages]
    
    def total_tokens(self) -> int:
        return sum(m.tokens for m in self.messages)

# Test basic memory with a small budget so the sliding window actually evicts
memory = ConversationMemory(max_tokens=30)

print("=== Conversation Memory ===")
memory.add("user", "Hi, I need help with order ORD-12345")
memory.add("assistant", "I'd be happy to help! Let me look up order ORD-12345.")
memory.add("user", "It hasn't arrived yet")
memory.add("assistant", "I see your order was shipped 3 days ago. It should arrive tomorrow.")
memory.add("user", "Can you provide tracking?")

print(f"\nTotal tokens: {memory.total_tokens()}")
print(f"Messages in memory: {len(memory.messages)}")
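
A refinement worth considering: the window above evicts strictly oldest-first, so a system prompt added at the start would be the first message dropped. The sketch below keeps system messages pinned and evicts only conversation turns. The class names `Msg` and `PinnedWindowMemory` are hypothetical, and the token count reuses the same rough ~4 characters/token heuristic.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Msg:
    role: str
    content: str
    tokens: int = field(init=False)

    def __post_init__(self) -> None:
        # Same rough heuristic as above: ~4 characters per token.
        self.tokens = len(self.content) // 4


class PinnedWindowMemory:
    """Sliding window that never evicts system messages (sketch)."""

    def __init__(self, max_tokens: int = 4000):
        self.max_tokens = max_tokens
        self.system: List[Msg] = []   # pinned, never evicted
        self.history: List[Msg] = []  # evicted oldest-first

    def add(self, role: str, content: str) -> None:
        msg = Msg(role, content)
        if role == "system":
            self.system.append(msg)
        else:
            self.history.append(msg)
        self._enforce_limit()

    def total_tokens(self) -> int:
        return sum(m.tokens for m in self.system + self.history)

    def _enforce_limit(self) -> None:
        # Drop the oldest non-system message until we fit, but always keep
        # the latest message so the model still sees the current turn.
        while self.total_tokens() > self.max_tokens and len(self.history) > 1:
            self.history.pop(0)

    def get_context(self) -> List[Dict[str, str]]:
        return [{"role": m.role, "content": m.content} for m in self.system + self.history]


mem = PinnedWindowMemory(max_tokens=30)
mem.add("system", "You are a support agent for an online store.")
mem.add("user", "Hi, I need help with order ORD-12345")
mem.add("assistant", "I'd be happy to help! Let me look up order ORD-12345.")
mem.add("user", "Can you provide tracking?")
print([m["role"] for m in mem.get_context()])  # ['system', 'assistant', 'user']
```

With a 30-token budget, the first user turn is evicted while the system prompt survives.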

## Part 2: Entity Memory

In [None]:
class EntityMemory:
    """Extract and track entities from conversation."""
    
    def __init__(self):
        self.entities = {
            "orders": {},
            "products": {},
            "dates": [],
            "amounts": []
        }
    
    def extract(self, text: str) -> Dict:
        """Extract entities from text."""
        extracted = {}
        
        # Orders
        orders = re.findall(r'ORD-\d+', text, re.IGNORECASE)
        if orders:
            extracted['orders'] = [o.upper() for o in orders]
            for order_id in extracted['orders']:
                if order_id not in self.entities['orders']:
                    self.entities['orders'][order_id] = {
                        'first_seen': datetime.now().isoformat(),
                        'mentions': 0
                    }
                self.entities['orders'][order_id]['mentions'] += 1
        
        # Amounts
        amounts = re.findall(r'\$[\d,]+\.?\d*', text)
        if amounts:
            extracted['amounts'] = amounts
            self.entities['amounts'].extend(amounts)
        
        # Dates
        dates = re.findall(r'\b(today|tomorrow|yesterday|\d{1,2}/\d{1,2}/\d{2,4})\b', text, re.IGNORECASE)
        if dates:
            extracted['dates'] = dates
            self.entities['dates'].extend(dates)
        
        return extracted
    
    def get_recent_order(self) -> Optional[str]:
        """Get the most frequently mentioned order (a proxy for the current topic)."""
        if not self.entities['orders']:
            return None
        # Highest mention count wins; ties go to the order seen first
        return max(self.entities['orders'].items(), key=lambda x: x[1]['mentions'])[0]

# Test entity memory
entity_mem = EntityMemory()

print("=== Entity Memory ===")
texts = [
    "I placed order ORD-12345 yesterday for $99.99",
    "When will ORD-12345 arrive?",
    "I also have a question about ORD-67890"
]

for text in texts:
    extracted = entity_mem.extract(text)
    print(f"\nText: '{text}'")
    print(f"Extracted: {extracted}")

print(f"\nMost discussed order: {entity_mem.get_recent_order()}")
print(f"All entities: {json.dumps(entity_mem.entities, indent=2, default=str)[:300]}...")

## Part 3: Summary Memory

In [None]:
class SummaryMemory:
    """Memory that summarizes old messages."""
    
    def __init__(self, recent_limit: int = 5):
        self.summary = ""
        self.recent_messages: List[Message] = []
        self.recent_limit = recent_limit
    
    def add(self, role: str, content: str):
        msg = Message(role=role, content=content)
        self.recent_messages.append(msg)
        
        # Summarize when we exceed limit
        if len(self.recent_messages) > self.recent_limit:
            self._summarize_oldest()
    
    def _summarize_oldest(self):
        """Summarize oldest messages and remove them."""
        # In production, this would call an LLM
        to_summarize = self.recent_messages[:2]
        self.recent_messages = self.recent_messages[2:]
        
        # Simulate summarization
        new_summary = self._mock_summarize(to_summarize)
        if self.summary:
            self.summary = f"{self.summary} {new_summary}"
        else:
            self.summary = new_summary
        
        print(f"[Summary] Compressed {len(to_summarize)} messages")
    
    def _mock_summarize(self, messages: List[Message]) -> str:
        """Mock summarization (in production, use LLM)."""
        topics = []
        for msg in messages:
            if 'order' in msg.content.lower():
                topics.append('order status')
            if 'return' in msg.content.lower():
                topics.append('return request')
            if 'help' in msg.content.lower():
                topics.append('general help')
        
        # dict.fromkeys de-duplicates while keeping first-seen order (set() would not)
        return f"User discussed: {', '.join(dict.fromkeys(topics)) or 'general inquiry'}."
    
    def get_context(self) -> str:
        context = ""
        if self.summary:
            context += f"Previous conversation summary: {self.summary}\n\n"
        context += "Recent messages:\n"
        for msg in self.recent_messages:
            context += f"{msg.role}: {msg.content}\n"
        return context

# Test summary memory
print("=== Summary Memory ===")
sum_mem = SummaryMemory(recent_limit=4)

conversation = [
    ("user", "Hi, I need help with my order"),
    ("assistant", "Of course! What's your order number?"),
    ("user", "It's ORD-12345"),
    ("assistant", "I found your order. It's currently in transit."),
    ("user", "When will it arrive?"),
    ("assistant", "Expected delivery is tomorrow."),
    ("user", "Can I get a refund instead?"),
]

for role, content in conversation:
    sum_mem.add(role, content)

print("\n=== Current Context ===")
print(sum_mem.get_context())
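
`_mock_summarize` stands in for an LLM call, and `_summarize_oldest` appends each new summary to the previous one, so the summary itself grows without bound. A common alternative is to ask the model to rewrite the running summary with the new turns folded in. The sketch below assumes a hypothetical `llm` callable (prompt in, completion out) so any client can be plugged in; the prompt wording is illustrative, and `fake_llm` exists only to make the example runnable.

```python
from typing import Callable, List, Tuple

# Hypothetical type: any function that takes a prompt string and returns
# the model's completion. Swap in a real LLM client here.
LLMFn = Callable[[str], str]


def build_summary_prompt(previous_summary: str, turns: List[Tuple[str, str]]) -> str:
    """Build a prompt asking the model to fold new turns into a running summary."""
    transcript = "\n".join(f"{role}: {content}" for role, content in turns)
    return (
        "Update the running summary of a customer-support conversation.\n"
        "Keep order IDs, amounts, dates, and unresolved requests verbatim.\n\n"
        f"Current summary:\n{previous_summary or '(none)'}\n\n"
        f"New messages:\n{transcript}\n\n"
        "Updated summary:"
    )


def summarize(previous_summary: str, turns: List[Tuple[str, str]], llm: LLMFn) -> str:
    """Return a new summary that replaces (rather than extends) the old one."""
    return llm(build_summary_prompt(previous_summary, turns)).strip()


def fake_llm(prompt: str) -> str:
    # Test double: "summarizes" by returning the last new message it was shown.
    new_part = prompt.split("New messages:\n")[1].split("\n\nUpdated summary:")[0]
    return new_part.splitlines()[-1]


summary = summarize("", [("user", "Hi, I need help with my order"),
                         ("assistant", "Of course! What's your order number?")], fake_llm)
print(summary)
```

Because each call replaces the summary, its length is bounded by what the model is asked to produce rather than by the length of the conversation.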

## Part 4: Planning Agent

In [None]:
@dataclass
class PlanStep:
    step_num: int
    action: str
    status: str = "pending"  # pending, in_progress, complete, failed
    result: Optional[str] = None

class PlanningAgent:
    """Agent that plans and executes multi-step tasks."""
    
    def __init__(self):
        self.current_plan: List[PlanStep] = []
        self.max_iterations = 10
    
    def create_plan(self, goal: str) -> List[PlanStep]:
        """Create a plan for a goal (simulated LLM call)."""
        goal_lower = goal.lower()
        
        if 'return' in goal_lower or 'refund' in goal_lower:
            return [
                PlanStep(1, "Look up order details"),
                PlanStep(2, "Verify return eligibility (30-day window)"),
                PlanStep(3, "Get reason for return"),
                PlanStep(4, "Generate return shipping label"),
                PlanStep(5, "Process refund to original payment method"),
            ]
        elif 'order' in goal_lower and 'status' in goal_lower:
            return [
                PlanStep(1, "Look up order by ID"),
                PlanStep(2, "Get shipping status"),
                PlanStep(3, "Format response for user"),
            ]
        else:
            return [
                PlanStep(1, "Understand user request"),
                PlanStep(2, "Gather necessary information"),
                PlanStep(3, "Execute task"),
                PlanStep(4, "Confirm with user"),
            ]
    
    def execute_step(self, step: PlanStep) -> str:
        """Execute a single step (simulated)."""
        step.status = "in_progress"
        
        # Simulate execution
        results = {
            "Look up order": "Found order ORD-12345, status: shipped",
            "Verify return": "Order within 30-day return window",
            "Get reason": "Customer selected 'Changed mind'",
            "Generate return": "Label generated: RET-98765",
            "Process refund": "Refund of $99.99 initiated",
            "shipping status": "Package in transit, ETA: tomorrow",
            "Format response": "Response prepared for user",
        }
        
        for key, result in results.items():
            if key.lower() in step.action.lower():
                step.result = result
                step.status = "complete"
                return result
        
        step.result = "Step completed"
        step.status = "complete"
        return step.result
    
    def run(self, goal: str) -> str:
        """Run the full planning loop."""
        print(f"\n=== Planning for: {goal} ===")
        
        # Create plan
        self.current_plan = self.create_plan(goal)
        print(f"\nPlan created with {len(self.current_plan)} steps:")
        for step in self.current_plan:
            print(f"  {step.step_num}. {step.action}")
        
        # Execute steps
        print("\n--- Execution ---")
        for i, step in enumerate(self.current_plan):
            if i >= self.max_iterations:
                print("Max iterations reached!")
                break
            
            print(f"\nStep {step.step_num}: {step.action}")
            result = self.execute_step(step)
            print(f"  → {result}")
        
        # Compile results
        completed = [s for s in self.current_plan if s.status == "complete"]
        return f"Completed {len(completed)}/{len(self.current_plan)} steps."

# Test planning agent
agent = PlanningAgent()
result = agent.run("Help me return order ORD-12345")
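
The `PlanStep.status` field allows `failed`, but the simulated `execute_step` always succeeds. A sketch of how failure might be handled, assuming each step carries a hypothetical callable that performs its tool call: stop at the first failure and leave later steps `pending` so a replanning pass could revisit them. The `Step` and `execute_plan` names are illustrative.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Step:
    step_num: int
    action: str
    run: Callable[[], str]   # hypothetical: performs this step's tool call
    status: str = "pending"  # pending, in_progress, complete, failed
    result: Optional[str] = None


def execute_plan(steps: List[Step]) -> str:
    """Run steps in order; stop at the first failure (sketch)."""
    for step in steps:
        step.status = "in_progress"
        try:
            step.result = step.run()
            step.status = "complete"
        except Exception as exc:  # a real agent would catch narrower errors
            step.status = "failed"
            step.result = f"{type(exc).__name__}: {exc}"
            # Later steps usually depend on earlier ones, so stop here and
            # leave them 'pending' for a replanning step.
            break
    done = sum(s.status == "complete" for s in steps)
    return f"Completed {done}/{len(steps)} steps."


def check_return_window() -> str:
    raise ValueError("order is outside the 30-day return window")


plan = [
    Step(1, "Look up order details", lambda: "Found order ORD-12345"),
    Step(2, "Verify return eligibility", check_return_window),
    Step(3, "Generate return shipping label", lambda: "Label generated"),
]
outcome = execute_plan(plan)
print(outcome)
print([s.status for s in plan])
```

Stopping is the conservative choice; whether to retry, skip, or replan after a failure depends on whether later steps rely on the failed one.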

## Part 5: ReAct Pattern

In [None]:
class ReActAgent:
    """Agent using Reasoning + Acting pattern."""
    
    def __init__(self):
        self.trace = []
        self.max_steps = 5
    
    def think(self, observation: Optional[str] = None) -> str:
        """Generate a thought (simulated LLM)."""
        if not self.trace:
            return "I need to understand what the user is asking for."
        
        if observation and 'shipped' in observation.lower():
            return "The order has been shipped. I should get tracking info."
        if observation and 'tracking' in observation.lower():
            return "I have the tracking info. I can now respond to the user."
        
        return "I should take the next logical action."
    
    def act(self, thought: str) -> tuple:
        """Decide on an action based on the thought."""
        lowered = thought.lower()
        if 'understand' in lowered:
            return 'get_order_status', {'order_id': 'ORD-12345'}
        # Check 'respond' before 'tracking': the final thought mentions both,
        # and matching 'tracking' first would repeat get_tracking until max_steps.
        if 'respond' in lowered:
            return 'respond_to_user', {'message': 'Your order is on the way!'}
        if 'tracking' in lowered:
            return 'get_tracking', {'order_id': 'ORD-12345'}
        return 'think_more', {}
    
    def observe(self, action: str, args: dict) -> str:
        """Get observation from action (simulated)."""
        observations = {
            'get_order_status': 'Order ORD-12345 status: shipped on Jan 15',
            'get_tracking': 'Tracking: 1Z999AA10123456784, ETA: Jan 18',
            'respond_to_user': 'Response sent to user',
        }
        return observations.get(action, 'No observation')
    
    def run(self, query: str) -> str:
        """Run the ReAct loop."""
        print(f"\n=== ReAct Agent ===")
        print(f"Query: {query}\n")
        
        observation = None
        
        for step in range(self.max_steps):
            # Thought
            thought = self.think(observation)
            print(f"Thought: {thought}")
            self.trace.append(('thought', thought))
            
            # Action
            action, args = self.act(thought)
            print(f"Action: {action}({args})")
            self.trace.append(('action', action, args))
            
            # Check for completion
            if action == 'respond_to_user':
                print(f"\n✓ Task complete: {args.get('message')}")
                return args.get('message')
            
            # Observation
            observation = self.observe(action, args)
            print(f"Observation: {observation}\n")
            self.trace.append(('observation', observation))
        
        return "Max steps reached"

# Test ReAct agent
react_agent = ReActAgent()
result = react_agent.run("Where is my order ORD-12345?")
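
In this simulation `think` and `act` are keyword rules. With a real model, the LLM produces text at each step and the loop must parse the chosen tool out of it. A minimal parser sketch, assuming a `Thought: ...` / `Action: tool[argument]` output format (a common ReAct-style convention, not a fixed standard; adapt the regexes to whatever format the prompt requests). `parse_react_step` is a hypothetical helper.

```python
import re
from typing import Optional, Tuple

# Assumed output format:
#   Thought: <free text>
#   Action: <tool_name>[<argument>]
ACTION_RE = re.compile(r"^Action:\s*(\w+)\[(.*)\]\s*$", re.MULTILINE)
THOUGHT_RE = re.compile(r"^Thought:\s*(.+)$", re.MULTILINE)


def parse_react_step(llm_output: str) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Return (thought, tool_name, argument); None for anything missing."""
    thought = THOUGHT_RE.search(llm_output)
    action = ACTION_RE.search(llm_output)
    return (
        thought.group(1).strip() if thought else None,
        action.group(1) if action else None,
        action.group(2) if action else None,
    )


sample = (
    "Thought: The order has been shipped. I should get tracking info.\n"
    "Action: get_tracking[ORD-12345]"
)
print(parse_react_step(sample))
```

When parsing fails (all `None`), a robust loop typically feeds an error message back as the observation instead of crashing, giving the model a chance to correct its format.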

## Part 6: Complete Agent with Memory

In [None]:
class CompleteAgent:
    """Agent with conversation memory, entity memory, and session state."""
    
    def __init__(self):
        self.conversation = ConversationMemory(max_tokens=500)
        self.entities = EntityMemory()
        self.session_state = {}
    
    def process(self, user_input: str) -> str:
        """Process user input and generate response."""
        # Add to conversation memory
        self.conversation.add("user", user_input)
        
        # Extract entities
        extracted = self.entities.extract(user_input)
        
        # Update session state
        if extracted.get('orders'):
            self.session_state['current_order'] = extracted['orders'][-1]
        
        # Generate response (simulated)
        response = self._generate_response(user_input)
        
        # Add response to memory
        self.conversation.add("assistant", response)
        
        return response
    
    def _generate_response(self, user_input: str) -> str:
        """Generate response using context."""
        lowered = user_input.lower()
        
        # Check for order reference
        order_id = self.session_state.get('current_order') or self.entities.get_recent_order()
        
        # 'arrive' covers delivery questions such as "When will it arrive?"
        if 'status' in lowered or 'where' in lowered or 'arrive' in lowered:
            if order_id:
                return f"Your order {order_id} is currently in transit and should arrive tomorrow."
            return "Could you provide your order number?"
        
        if 'tracking' in lowered:
            if order_id:
                return f"The tracking number for {order_id} is 1Z999AA10123456784."
            return "Which order would you like tracking for?"
        
        if 'return' in lowered or 'refund' in lowered:
            if order_id:
                return f"I can help you return {order_id}. It's within the 30-day return window. Would you like to proceed?"
            return "Sure, I can help with a return. What's the order number?"
        
        return "I'm here to help! What can I assist you with?"
    
    def get_status(self) -> Dict:
        return {
            'messages': len(self.conversation.messages),
            'tokens': self.conversation.total_tokens(),
            'entities': self.entities.entities,
            'session': self.session_state
        }

# Test complete agent
print("=== Complete Agent Demo ===")
agent = CompleteAgent()

conversation = [
    "Hi, I have a question about order ORD-12345",
    "When will it arrive?",
    "Can you give me the tracking number?",
    "Actually, I want to return it"
]

for msg in conversation:
    print(f"\nUser: {msg}")
    response = agent.process(msg)
    print(f"Agent: {response}")

print("\n=== Agent Status ===")
print(json.dumps(agent.get_status(), indent=2, default=str))
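
Everything in `CompleteAgent` lives in process memory, so it is lost when the session ends; the long-term memory and retrieval listed in the takeaways are not implemented in this module. A minimal sketch of both, assuming a JSON file as the store and plain keyword overlap as a stand-in for embedding-based similarity search. The `LongTermMemory` class is hypothetical.

```python
import json
import os
import re
import tempfile
from typing import List, Set


class LongTermMemory:
    """Facts persisted to a JSON file, retrieved by keyword overlap (sketch)."""

    def __init__(self, path: str):
        self.path = path
        self.facts: List[str] = []
        if os.path.exists(path):
            with open(path, encoding="utf-8") as f:
                self.facts = json.load(f)

    def remember(self, fact: str) -> None:
        self.facts.append(fact)
        # Rewrite the whole file on each write; fine for a small store.
        with open(self.path, "w", encoding="utf-8") as f:
            json.dump(self.facts, f)

    @staticmethod
    def _words(text: str) -> Set[str]:
        return set(re.findall(r"[a-z0-9-]+", text.lower()))

    def search(self, query: str, k: int = 2) -> List[str]:
        """Return up to k facts sharing the most words with the query."""
        q = self._words(query)
        scored = [(len(q & self._words(fact)), fact) for fact in self.facts]
        scored = [pair for pair in scored if pair[0] > 0]
        scored.sort(key=lambda pair: pair[0], reverse=True)  # ties keep insertion order
        return [fact for _, fact in scored[:k]]


path = os.path.join(tempfile.mkdtemp(), "memory.json")
session1 = LongTermMemory(path)
session1.remember("Customer returned ORD-12345 because the size was wrong")
session1.remember("Customer prefers email over phone contact")

# A new session (e.g. after a restart) reloads what was saved.
session2 = LongTermMemory(path)
hits = session2.search("Did I return ORD-12345 before?")
print(hits)
```

Keyword overlap misses paraphrases (here `return` does not match `returned`; only the order ID links the query to the fact). Production systems usually swap `search` for vector similarity, but the store, reload, and retrieve shape stays the same.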

## Part 7: TODO - Build Your Memory System

Design a memory system for a specific use case.

In [None]:
# TODO: Design a memory system for one of these scenarios:
# 1. Customer support bot that remembers past issues
# 2. Personal assistant that learns user preferences
# 3. Technical support that tracks troubleshooting steps

# Consider:
# - What entities need to be tracked?
# - What should be summarized vs kept in full?
# - What session state is important?

print("Design your memory system for a specific use case!")

## Part 8: TODO - Stakeholder Summary

Explain to a product manager:
1. Why agents need memory
2. The difference between short-term and long-term memory
3. How planning helps with complex tasks

### Your Summary:

*Write your explanation here...*

---

## Key Takeaways

1. **Memory types:** Context window, conversation (sliding window), entity, summary, long-term
2. **Context management:** Sliding window, summarization, retrieval
3. **Entity tracking:** Extract and remember key information
4. **Planning:** Break complex tasks into steps, either planned up front (plan-and-execute) or interleaved with observations (ReAct)
5. **Combine strategies** for robust memory systems

### Next Steps
- Explore the interactive playground
- Complete the quiz
- Move to Module 20: Guardrails