<a href="https://colab.research.google.com/github/Jatish-Khanna/AISamples/blob/main/Day3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
import tiktoken
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import json
import hashlib
from enum import Enum

class ContextStrategy(Enum):
    SLIDING_WINDOW = "sliding_window"
    SUMMARIZATION = "summarization"
    SEMANTIC_COMPRESSION = "semantic_compression"
    HIERARCHICAL = "hierarchical"

@dataclass
class MessageMetadata:
    timestamp: datetime
    token_count: int
    importance_score: float = 0.0
    message_type: str = "normal"  # system, user, assistant, summary
    thread_id: Optional[str] = None

@dataclass
class ConversationMemory:
    messages: List[Dict[str, str]] = field(default_factory=list)
    metadata: List[MessageMetadata] = field(default_factory=list)
    summaries: List[Dict[str, str]] = field(default_factory=list)
    persona: str = "assistant"
    total_tokens: int = 0
    created_at: datetime = field(default_factory=datetime.now)
    last_accessed: datetime = field(default_factory=datetime.now)
    context_strategy: ContextStrategy = ContextStrategy.SLIDING_WINDOW

class AdvancedContextManager:
    def __init__(self, max_tokens: int = 4000, strategy: ContextStrategy = ContextStrategy.SLIDING_WINDOW):
        self.max_tokens = max_tokens
        self.strategy = strategy
        self.encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
        self.summary_threshold = max_tokens * 0.7  # Summarize when 70% full

    def count_tokens(self, text: str) -> int:
        """Count tokens in text"""
        return len(self.encoding.encode(text))

    def calculate_importance_score(self, message: Dict[str, str], context: ConversationMemory) -> float:
        """Calculate importance score for message prioritization"""
        content = message.get("content", "")
        role = message.get("role", "")

        score = 0.0

        # Role-based scoring
        if role == "system":
            score += 10.0
        elif role == "user":
            score += 5.0
        elif role == "assistant":
            score += 3.0

        # Content-based scoring
        importance_keywords = [
            "important", "remember", "key", "critical", "essential",
            "problem", "solution", "error", "bug", "issue"
        ]

        for keyword in importance_keywords:
            if keyword.lower() in content.lower():
                score += 2.0

        # Length-based scoring (longer messages often more important)
        score += min(len(content) / 100, 5.0)

        # Recency boost (recent messages slightly more important)
        if len(context.messages) > 0:
            position = len(context.messages)
            recency_boost = min(position / 10, 2.0)
            score += recency_boost

        return score

    def add_message(self, memory: ConversationMemory, message: Dict[str, str]) -> ConversationMemory:
        """Add message with metadata tracking"""
        memory.last_accessed = datetime.now()

        # Calculate metadata
        token_count = self.count_tokens(message.get("content", ""))
        importance_score = self.calculate_importance_score(message, memory)

        metadata = MessageMetadata(
            timestamp=datetime.now(),
            token_count=token_count,
            importance_score=importance_score,
            message_type=message.get("role", "normal")
        )

        # Add to memory
        memory.messages.append(message)
        memory.metadata.append(metadata)
        memory.total_tokens += token_count

        # Apply context management strategy
        return self._apply_context_strategy(memory)

    def _apply_context_strategy(self, memory: ConversationMemory) -> ConversationMemory:
        """Apply the selected context management strategy"""
        if memory.total_tokens <= self.max_tokens:
            return memory

        if memory.context_strategy == ContextStrategy.SLIDING_WINDOW:
            return self._sliding_window_strategy(memory)
        elif memory.context_strategy == ContextStrategy.SUMMARIZATION:
            return self._summarization_strategy(memory)
        elif memory.context_strategy == ContextStrategy.SEMANTIC_COMPRESSION:
            return self._semantic_compression_strategy(memory)
        elif memory.context_strategy == ContextStrategy.HIERARCHICAL:
            return self._hierarchical_strategy(memory)

        return memory

    def _sliding_window_strategy(self, memory: ConversationMemory) -> ConversationMemory:
        """Remove oldest messages while preserving system messages"""
        while memory.total_tokens > self.max_tokens and len(memory.messages) > 1:
            # Find first non-system message to remove
            for i, (msg, meta) in enumerate(zip(memory.messages, memory.metadata)):
                if msg.get("role") != "system":
                    # Remove message and metadata
                    removed_msg = memory.messages.pop(i)
                    removed_meta = memory.metadata.pop(i)
                    memory.total_tokens -= removed_meta.token_count
                    break
            else:
                # If only system messages left, remove oldest
                if memory.messages:
                    removed_msg = memory.messages.pop(0)
                    removed_meta = memory.metadata.pop(0)
                    memory.total_tokens -= removed_meta.token_count

        return memory

    def _summarization_strategy(self, memory: ConversationMemory) -> ConversationMemory:
        """Summarize older messages when approaching token limit"""
        if memory.total_tokens < self.summary_threshold:
            return memory

        # Find messages to summarize (older half)
        mid_point = len(memory.messages) // 2
        messages_to_summarize = memory.messages[:mid_point]

        if len(messages_to_summarize) < 4:  # Need enough content to summarize
            return self._sliding_window_strategy(memory)

        # Create summary
        summary_content = self._create_summary(messages_to_summarize)
        summary_message = {
            "role": "system",
            "content": f"[SUMMARY] Previous conversation: {summary_content}"
        }

        # Calculate tokens saved
        tokens_to_remove = sum(meta.token_count for meta in memory.metadata[:mid_point])
        summary_tokens = self.count_tokens(summary_message["content"])

        # Replace messages with summary
        memory.messages = [summary_message] + memory.messages[mid_point:]
        memory.metadata = [MessageMetadata(
            timestamp=datetime.now(),
            token_count=summary_tokens,
            importance_score=8.0,  # Summaries are important
            message_type="summary"
        )] + memory.metadata[mid_point:]

        memory.total_tokens = memory.total_tokens - tokens_to_remove + summary_tokens
        memory.summaries.append(summary_message)

        return memory

    def _create_summary(self, messages: List[Dict[str, str]]) -> str:
        """Create a summary of messages (simplified version)"""
        # In production, you'd use an LLM for better summaries
        content_parts = []
        current_topic = None

        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")[:200]  # Truncate for summary

            if role == "user":
                content_parts.append(f"User asked: {content}")
            elif role == "assistant":
                content_parts.append(f"Assistant responded: {content}")

        return " | ".join(content_parts[-5:])  # Last 5 exchanges

    def _semantic_compression_strategy(self, memory: ConversationMemory) -> ConversationMemory:
        """Remove messages with lowest importance scores"""
        if len(memory.messages) <= 2:  # Keep minimum viable conversation
            return memory

        # Sort by importance score (ascending)
        indexed_items = list(enumerate(zip(memory.messages, memory.metadata)))
        indexed_items.sort(key=lambda x: x[1][1].importance_score)

        # Remove lowest importance messages until under token limit
        to_remove = []
        tokens_to_remove = 0

        for idx, (msg, meta) in indexed_items:
            if memory.total_tokens - tokens_to_remove <= self.max_tokens:
                break
            if msg.get("role") != "system":  # Preserve system messages
                to_remove.append(idx)
                tokens_to_remove += meta.token_count

        # Remove messages (in reverse order to maintain indices)
        for idx in sorted(to_remove, reverse=True):
            memory.messages.pop(idx)
            removed_meta = memory.metadata.pop(idx)
            memory.total_tokens -= removed_meta.token_count

        return memory

    def _hierarchical_strategy(self, memory: ConversationMemory) -> ConversationMemory:
        """Organize messages in hierarchical structure with different retention policies"""
        # Implement tiered retention: recent (full), medium (compressed), old (summary)
        now = datetime.now()

        recent_threshold = timedelta(minutes=10)
        medium_threshold = timedelta(hours=1)

        recent_msgs = []
        medium_msgs = []
        old_msgs = []

        for msg, meta in zip(memory.messages, memory.metadata):
            age = now - meta.timestamp
            if age <= recent_threshold:
                recent_msgs.append((msg, meta))
            elif age <= medium_threshold:
                medium_msgs.append((msg, meta))
            else:
                old_msgs.append((msg, meta))

        # Keep all recent, compress medium, summarize old
        final_messages = []
        final_metadata = []

        # Add recent messages as-is
        for msg, meta in recent_msgs:
            final_messages.append(msg)
            final_metadata.append(meta)

        # Compress medium messages (every other message)
        for i, (msg, meta) in enumerate(medium_msgs):
            if i % 2 == 0 or msg.get("role") == "system":
                final_messages.append(msg)
                final_metadata.append(meta)

        # Summarize old messages
        if old_msgs:
            old_messages_only = [msg for msg, meta in old_msgs]
            summary = self._create_summary(old_messages_only)
            summary_msg = {
                "role": "system",
                "content": f"[ARCHIVED] Earlier conversation: {summary}"
            }
            final_messages.insert(0, summary_msg)
            final_metadata.insert(0, MessageMetadata(
                timestamp=now,
                token_count=self.count_tokens(summary_msg["content"]),
                importance_score=7.0,
                message_type="archive"
            ))

        # Update memory
        memory.messages = final_messages
        memory.metadata = final_metadata
        memory.total_tokens = sum(meta.token_count for meta in final_metadata)

        return memory

    def get_context_stats(self, memory: ConversationMemory) -> Dict:
        """Get detailed context statistics"""
        if not memory.messages:
            return {"total_messages": 0, "total_tokens": 0}

        message_types = {}
        importance_distribution = []

        for msg, meta in zip(memory.messages, memory.metadata):
            msg_type = msg.get("role", "unknown")
            message_types[msg_type] = message_types.get(msg_type, 0) + 1
            importance_distribution.append(meta.importance_score)

        return {
            "total_messages": len(memory.messages),
            "total_tokens": memory.total_tokens,
            "token_utilization": memory.total_tokens / self.max_tokens,
            "message_types": message_types,
            "avg_importance": sum(importance_distribution) / len(importance_distribution),
            "strategy": memory.context_strategy.value,
            "summaries_created": len(memory.summaries),
            "session_duration": (memory.last_accessed - memory.created_at).total_seconds() / 60
        }

    def export_conversation(self, memory: ConversationMemory) -> str:
        """Export conversation for analysis or backup"""
        export_data = {
            "conversation": [
                {
                    "message": msg,
                    "metadata": {
                        "timestamp": meta.timestamp.isoformat(),
                        "token_count": meta.token_count,
                        "importance_score": meta.importance_score,
                        "message_type": meta.message_type
                    }
                }
                for msg, meta in zip(memory.messages, memory.metadata)
            ],
            "summaries": memory.summaries,
            "stats": self.get_context_stats(memory),
            "persona": memory.persona
        }

        return json.dumps(export_data, indent=2)

# Testing and demonstration
if __name__ == "__main__":
    # Test different strategies
    strategies = [
        ContextStrategy.SLIDING_WINDOW,
        ContextStrategy.SUMMARIZATION,
        ContextStrategy.SEMANTIC_COMPRESSION,
        ContextStrategy.HIERARCHICAL
    ]

    for strategy in strategies:
        print(f"\n=== Testing {strategy.value} ===")

        manager = AdvancedContextManager(max_tokens=500, strategy=strategy)
        memory = ConversationMemory(context_strategy=strategy)

        # Simulate conversation
        test_messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is machine learning?"},
            {"role": "assistant", "content": "Machine learning is a subset of artificial intelligence..."},
            {"role": "user", "content": "Can you explain neural networks?"},
            {"role": "assistant", "content": "Neural networks are computing systems inspired by biological neural networks..."},
            {"role": "user", "content": "This is very important: I need to remember this for my exam tomorrow."},
            {"role": "assistant", "content": "I'll help you remember the key concepts for your exam..."},
        ]

        for msg in test_messages:
            memory = manager.add_message(memory, msg)

        stats = manager.get_context_stats(memory)
        print(f"Final stats: {stats}")
        print(f"Messages retained: {len(memory.messages)}")


=== Testing sliding_window ===
Final stats: {'total_messages': 7, 'total_tokens': 65, 'token_utilization': 0.13, 'message_types': {'system': 1, 'user': 3, 'assistant': 3}, 'avg_importance': 6.795714285714287, 'strategy': 'sliding_window', 'summaries_created': 0, 'session_duration': 3.3e-05}
Messages retained: 7

=== Testing summarization ===
Final stats: {'total_messages': 7, 'total_tokens': 65, 'token_utilization': 0.13, 'message_types': {'system': 1, 'user': 3, 'assistant': 3}, 'avg_importance': 6.795714285714287, 'strategy': 'summarization', 'summaries_created': 0, 'session_duration': 1.2833333333333333e-05}
Messages retained: 7

=== Testing semantic_compression ===
Final stats: {'total_messages': 7, 'total_tokens': 65, 'token_utilization': 0.13, 'message_types': {'system': 1, 'user': 3, 'assistant': 3}, 'avg_importance': 6.795714285714287, 'strategy': 'semantic_compression', 'summaries_created': 0, 'session_duration': 6.316666666666667e-06}
Messages retained: 7

=== Testing hierar