# In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
import numpy as np
from typing import List, Dict, Optional, Tuple
import os
from dataclasses import dataclass
import logging
import faiss

@dataclass
class CAGConfig:
    """Tunable settings shared by the CAG cache manager and the generator.

    Fields are declared in positional-constructor order; every field has a
    default so ``CAGConfig()`` yields a usable configuration.
    """
    # Most entries CacheManager keeps before it starts evicting old ones.
    max_cache_size: int = 1000
    # Cache-hit cutoff: CacheManager accepts the nearest stored entry only when
    # the FAISS L2 distance is strictly below this value (smaller = closer).
    similarity_threshold: float = 0.85
    # Upper bound on the number of tokens produced per generation call.
    max_new_tokens: int = 100
    # Softmax temperature for sampling; values <= 0 select greedy argmax.
    temperature: float = 0.7
    # NOTE(review): not read anywhere in this file — presumably meant as the
    # number of cache candidates to retrieve; confirm before relying on it.
    top_k_cache: int = 5
    # Target size, in tokens, of each chunk when prefilling a document.
    chunk_size: int = 512

class CacheManager:
    """Bounded store of KV caches, retrievable by embedding similarity.

    Entries are evicted in insertion (FIFO) order once ``config.max_cache_size``
    entries are held. Lookup uses a FAISS ``IndexFlatL2`` index, so
    ``config.similarity_threshold`` is compared against the L2 distance FAISS
    reports: a match requires a distance strictly *below* the threshold.
    """

    def __init__(self, config: CAGConfig, embedding_dim: int = 768):
        """Create an empty manager.

        Args:
            config: Supplies ``max_cache_size`` and ``similarity_threshold``.
            embedding_dim: Length of the embedding vectors that will be added
                and queried (default 768, the previously hard-coded value).
        """
        self.config = config
        self.embedding_dim = embedding_dim
        self.cache_store: Dict[str, DynamicCache] = {}
        # Row i of the FAISS index belongs to self.cache_keys[i]. IndexFlat
        # shifts later ids down when rows are removed, so the key list and the
        # index stay aligned as long as both are edited at the same position.
        self.embeddings_index = faiss.IndexFlatL2(embedding_dim)
        self.cache_keys: List[str] = []

    @staticmethod
    def _as_faiss_row(embedding: np.ndarray) -> np.ndarray:
        """Return ``embedding`` as the contiguous float32 (1, d) array FAISS expects."""
        return np.ascontiguousarray(embedding, dtype=np.float32).reshape(1, -1)

    def _remove_at(self, position: int) -> None:
        """Drop the entry at ``position`` from the key list, the store and the index."""
        key = self.cache_keys.pop(position)
        del self.cache_store[key]
        self.embeddings_index.remove_ids(np.array([position], dtype=np.int64))

    def add_to_cache(self, key: str, cache: DynamicCache, embedding: np.ndarray):
        """Store ``cache`` under ``key`` and index its ``embedding``.

        Re-adding an existing key replaces the old entry instead of creating a
        duplicate. When the manager is full, the oldest entries are evicted
        first. A non-positive ``max_cache_size`` means nothing is retained.
        """
        row = self._as_faiss_row(embedding)

        if key in self.cache_store:
            self._remove_at(self.cache_keys.index(key))

        if self.config.max_cache_size <= 0:
            return

        # Evict oldest (FIFO) entries; remove the index row *before* forgetting
        # its position (the original looked the key up after popping it).
        while len(self.cache_keys) >= self.config.max_cache_size:
            self._remove_at(0)

        self.cache_store[key] = cache
        self.cache_keys.append(key)
        self.embeddings_index.add(row)

    def find_similar_cache(self, query_embedding: np.ndarray) -> Optional[Tuple[str, DynamicCache]]:
        """Return ``(key, cache)`` for the nearest stored embedding, or ``None``.

        ``None`` is returned when the manager is empty or when the nearest
        entry's L2 distance is not strictly below ``config.similarity_threshold``.
        """
        if not self.cache_keys:
            return None

        distances, indices = self.embeddings_index.search(self._as_faiss_row(query_embedding), 1)
        best = int(indices[0][0])
        # FAISS reports -1 when it has no neighbour to return.
        if best < 0 or not (distances[0][0] < self.config.similarity_threshold):
            return None
        key = self.cache_keys[best]
        return key, self.cache_store[key]

class EnhancedCAG:
    """Cache-Augmented Generation: prefill a document into a KV cache once and
    answer questions by continuing generation from that cache."""

    # Number of highest-scoring tokens kept when sampling (temperature > 0).
    # Previously hard-coded inline; override per instance or subclass.
    sampling_top_k: int = 10

    def __init__(
        self,
        model_name: str,
        config: Optional[CAGConfig] = None,
        device: Optional[str] = None
    ):
        """Load the tokenizer and causal LM for ``model_name``.

        Args:
            model_name: Model identifier passed to ``from_pretrained``.
            config: Generation/chunking settings. When omitted, a fresh
                ``CAGConfig()`` is created per instance. The old default
                object was built once at definition time and shared by all
                instances.
            device: Defaults to "cuda" when available, otherwise "cpu". It only
                selects the dtype (float16 on "cuda", float32 otherwise).
                Weight placement is left to ``device_map="auto"``, so inputs
                are always moved to ``self.model.device``.
        """
        self.config = config if config is not None else CAGConfig()
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            device_map="auto"
        )

        self.cache_manager = CacheManager(self.config)
        self.setup_logging()

    def setup_logging(self):
        """Setup logging for monitoring cache performance"""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    def _select_next_token(self, logits: torch.Tensor) -> torch.Tensor:
        """Pick one next-token id per row of ``logits``.

        Uses top-k sampling when ``config.temperature > 0``; otherwise greedy.
        Returns a (batch, 1) tensor of vocabulary ids.
        """
        if self.config.temperature > 0:
            scaled = logits / self.config.temperature
            k = min(self.sampling_top_k, scaled.shape[-1])
            top_values, top_indices = torch.topk(scaled, k=k, dim=-1)
            # Softmax in float32 to avoid half-precision issues in multinomial.
            probs = torch.softmax(top_values.float(), dim=-1)
            choice = torch.multinomial(probs, num_samples=1)
            # multinomial returns a position inside the top-k slice; map it back
            # to a real vocabulary id (the original returned the position itself).
            return torch.gather(top_indices, -1, choice)
        return torch.argmax(logits, dim=-1, keepdim=True)

    def generate_with_cache(
        self,
        input_ids: torch.Tensor,
        past_key_values: Optional[DynamicCache] = None
    ) -> torch.Tensor:
        """Decode up to ``config.max_new_tokens`` tokens after ``input_ids``.

        Args:
            input_ids: Prompt token ids. A batch-first (batch, seq) layout is
                assumed, matching the ``logits[:, -1, :]`` indexing below.
            past_key_values: Cache to continue from. A Hugging Face
                ``DynamicCache`` is extended in place by each forward pass.

        Returns:
            Only the newly generated ids (batch, n_new). This includes the EOS
            token when generation stopped on it.
        """
        device = self.model.device
        origin_len = input_ids.shape[-1]
        input_ids = input_ids.to(device)
        output_ids = input_ids.clone()
        # First step feeds the whole prompt (prefill); later steps feed only
        # the token just produced, relying on the cache for context.
        next_token = input_ids
        eos_id = self.tokenizer.eos_token_id
        # Rows that have emitted EOS; stop once every row has.
        finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=device)

        with torch.no_grad():
            for _ in range(self.config.max_new_tokens):
                # Temperature is applied in _select_next_token, not passed to
                # the forward pass (which does not take a temperature argument).
                outputs = self.model(
                    input_ids=next_token,
                    past_key_values=past_key_values,
                    use_cache=True
                )

                next_token = self._select_next_token(outputs.logits[:, -1, :])
                output_ids = torch.cat([output_ids, next_token], dim=-1)
                past_key_values = outputs.past_key_values

                if eos_id is not None:
                    finished |= next_token.squeeze(-1) == eos_id
                    if bool(finished.all()):
                        break

        return output_ids[:, origin_len:]

    def process_document(self, document_path: str) -> DynamicCache:
        """Read a UTF-8 text file and prefill it, chunk by chunk, into one cache.

        Returns an empty ``DynamicCache`` for a document with no words.
        """
        with open(document_path, 'r', encoding='utf-8') as f:
            text = f.read()

        device = self.model.device
        cache = DynamicCache()

        for index, chunk in enumerate(self._chunk_text(text)):
            # Chunks form one continuous sequence, so only the first one gets
            # the tokenizer's special tokens (e.g. BOS).
            input_ids = self.tokenizer(
                chunk, return_tensors="pt", add_special_tokens=(index == 0)
            ).input_ids.to(device)
            if input_ids.shape[-1] == 0:
                continue
            with torch.no_grad():
                # Hidden states are not needed; requesting them only costs memory.
                outputs = self.model(
                    input_ids=input_ids,
                    past_key_values=cache,
                    use_cache=True
                )
            cache = outputs.past_key_values

        return cache

    def _chunk_text(self, text: str) -> List[str]:
        """Split ``text`` on whitespace into chunks of about ``config.chunk_size`` tokens.

        Token counts are summed per word (without special tokens). This is an
        approximation of how the joined chunk tokenizes. A single word longer
        than ``chunk_size`` becomes its own chunk. No empty chunks are produced.
        """
        chunks: List[str] = []
        current_chunk: List[str] = []
        current_length = 0

        for word in text.split():
            word_length = len(self.tokenizer.encode(word, add_special_tokens=False))
            # Only flush a non-empty chunk; the original emitted "" when the
            # very first word already exceeded chunk_size.
            if current_chunk and current_length + word_length > self.config.chunk_size:
                chunks.append(" ".join(current_chunk))
                current_chunk, current_length = [], 0
            current_chunk.append(word)
            current_length += word_length

        if current_chunk:
            chunks.append(" ".join(current_chunk))
        return chunks

    def answer_question(self, question: str, document_cache: DynamicCache) -> str:
        """Generate an answer to ``question`` conditioned on ``document_cache``.

        The cache is restored to its pre-call length afterwards, even on error.
        This way every question sees only the document, not earlier Q&A turns.
        """
        # Record the length *before* generation. The original read it at the
        # start of each call, after the previous answer had already been
        # appended, so the cleanup never removed anything.
        origin_len = document_cache.get_seq_length()

        # The cached document already starts with special tokens (if any), so
        # the question only gets them when the cache is empty.
        input_ids = self.tokenizer(
            question + "\n", return_tensors="pt", add_special_tokens=(origin_len == 0)
        ).input_ids

        try:
            generated_ids = self.generate_with_cache(input_ids, document_cache)
        finally:
            self._clean_cache(document_cache, origin_len)

        return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    def _clean_cache(self, cache: DynamicCache, origin_len: int):
        """Truncate ``cache`` back to its first ``origin_len`` token positions."""
        if hasattr(cache, "crop"):
            # crop() also updates the cache's internal seen-token bookkeeping,
            # which direct tensor slicing leaves stale.
            cache.crop(origin_len)
            return
        # Fallback for transformers versions whose DynamicCache lacks crop().
        for i in range(len(cache.key_cache)):
            cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
            cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]

# Example usage
def main():
    """Demo: cache ``document.txt`` once, then answer a few questions from it."""
    settings = CAGConfig(
        max_cache_size=1000,
        similarity_threshold=0.85,
        max_new_tokens=100,
        temperature=0.7,
        top_k_cache=5,
        chunk_size=512,
    )
    engine = EnhancedCAG(model_name="mistralai/Mistral-7B-Instruct-v0.1", config=settings)

    # The document is prefilled a single time; every question reuses its cache.
    doc_cache = engine.process_document("document.txt")

    prompts = (
        "What is the main topic of this document?",
        "What are the key findings?",
    )
    for question in prompts:
        answer = engine.answer_question(question, doc_cache)
        print(f"Q: {question}")
        print(f"A: {answer}\n")


if __name__ == "__main__":
    main()