In [None]:
import os
import faiss
import pickle
import numpy as np
from typing import List, Dict, Optional, Union
from sentence_transformers import SentenceTransformer
from dataclasses import dataclass
import google.generativeai as genai
import dotenv
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class SearchResult:
    """One retrieved text chunk with its retrieval score and report metadata.

    Instances are created by ``GermanReportsRAG.search_chunks``; apart from
    ``rank``, ``score`` and ``chunk`` every field is copied from the metadata
    entry stored for the chunk.
    """
    rank: int  # 1-based position within the returned result list
    score: float  # score reported by the FAISS search for this chunk
    chunk: str  # text of the chunk
    company: str  # value of the metadata entry's 'company' key
    year: str  # value of the metadata entry's 'year' key
    filename: str  # value of the metadata entry's 'filename' key
    chunk_id: int  # value of the metadata entry's 'chunk_id' key
    file_path: str  # value of the metadata entry's 'file_path' key

class GermanReportsRAG:
    """RAG system for German company reports with FAISS backend and Google Gemini"""
    
    def __init__(
        self,
        index_path: str = "german_reports_index.faiss",
        metadata_path: str = "german_reports_metadata.pkl",
        chunks_path: str = "german_reports_chunks.pkl",
        model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        gemini_api_key: Optional[str] = None,
        gemini_model: str = "gemini-1.5-flash"
    ):
        """
        Build a ready-to-use RAG system.

        Remembers the model names, loads the retrieval components from disk
        and configures the Gemini client.

        Args:
            index_path: Location of the FAISS index file
            metadata_path: Location of the pickled metadata
            chunks_path: Location of the pickled text chunks
            model_name: Sentence-transformer model used to embed queries
            gemini_api_key: Gemini API key; when omitted, ``_setup_gemini``
                looks it up in the environment
            gemini_model: Name of the Gemini model that writes the answers
        """
        # The two loaders below read these attributes, so assign them first.
        self.gemini_model = gemini_model
        self.model_name = model_name

        self._load_index(index_path, metadata_path, chunks_path)
        self._setup_gemini(gemini_api_key)

        logger.info(f"RAG system initialized with {len(self.chunks)} chunks")
    
    def _load_index(self, index_path: str, metadata_path: str, chunks_path: str):
        """
        Load the FAISS index, metadata, text chunks and the embedding model.

        Sets ``self.index``, ``self.metadata``, ``self.chunks`` and
        ``self.encoder``. A warning is logged when the three stores do not
        hold the same number of entries.

        Args:
            index_path: Path to FAISS index file
            metadata_path: Path to metadata pickle file
            chunks_path: Path to chunks pickle file

        Raises:
            Exception: Any error raised while loading is logged and re-raised
                unchanged.
        """
        try:
            self.index = faiss.read_index(index_path)

            # NOTE(security): pickle.load can execute code embedded in the
            # file, so these paths must only point to trusted files.
            with open(metadata_path, 'rb') as f:
                self.metadata = pickle.load(f)

            with open(chunks_path, 'rb') as f:
                self.chunks = pickle.load(f)

            self.encoder = SentenceTransformer(self.model_name)

            logger.info(f"✓ Loaded FAISS index with {self.index.ntotal} vectors")
            logger.info(f"✓ Loaded {len(self.metadata)} metadata entries")
            logger.info(f"✓ Loaded {len(self.chunks)} text chunks")

            # search_chunks uses the ids returned by FAISS as positions into
            # self.metadata and self.chunks, so differing sizes mean some
            # hits cannot be resolved. Warn early instead of failing later.
            if not (self.index.ntotal == len(self.metadata) == len(self.chunks)):
                logger.warning(
                    "Index, metadata and chunks differ in size "
                    f"({self.index.ntotal} / {len(self.metadata)} / {len(self.chunks)}); "
                    "search results may be incomplete"
                )

        except Exception as e:
            logger.error(f"Error loading index components: {e}")
            raise
    
    def _setup_gemini(self, api_key: Optional[str]):
        """
        Configure the Google Gemini client and create the generative model.

        The key is taken from ``api_key`` when given. Otherwise a ``.env``
        file is loaded (if one can be found) and the ``GEMINI_API_KEY``
        environment variable is read. Sets ``self.gemini``.

        Args:
            api_key: Gemini API key, or None to fall back to the environment

        Raises:
            ValueError: If no API key could be determined.
            Exception: Any other setup error is logged and re-raised unchanged.
        """
        try:
            if not api_key:
                # dotenv is imported by this module but was never invoked, so
                # a key stored in a .env file stayed invisible to os.getenv.
                # load_dotenv does not override variables that are already set.
                dotenv.load_dotenv()
                api_key = os.getenv('GEMINI_API_KEY')

            if not api_key:
                raise ValueError("Gemini API key not provided. Set GEMINI_API_KEY environment variable or pass api_key parameter")

            genai.configure(api_key=api_key)

            # Initialize the model
            self.gemini = genai.GenerativeModel(self.gemini_model)

            logger.info(f"✓ Gemini model {self.gemini_model} initialized")

        except Exception as e:
            logger.error(f"Error setting up Gemini: {e}")
            raise
    
    def get_available_companies(self) -> List[str]:
        """Get list of all available companies in the index"""
        companies = list(set(meta['company'] for meta in self.metadata))
        return sorted(companies)
    
    def get_available_years(self, company: Optional[str] = None) -> List[str]:
        """Get list of all available years, optionally filtered by company"""
        if company:
            years = [meta['year'] for meta in self.metadata if meta['company'].lower() == company.lower()]
        else:
            years = [meta['year'] for meta in self.metadata]
        return sorted(list(set(years)))
    
    def _filter_indices_by_metadata(
        self, 
        company: Optional[str] = None, 
        year: Optional[Union[str, int]] = None
    ) -> List[int]:
        """
        Filter chunk indices based on company and/or year.
        
        Args:
            company: Company name to filter by
            year: Year to filter by
            
        Returns:
            List of valid chunk indices
        """
        valid_indices = []
        
        for i, meta in enumerate(self.metadata):
            # Check company filter
            if company and meta['company'].lower() != company.lower():
                continue
            
            # Check year filter
            if year and str(meta['year']) != str(year):
                continue
            
            valid_indices.append(i)
        
        return valid_indices
    
    def search_chunks(
        self,
        query: str,
        k: int = 10,
        company: Optional[str] = None,
        year: Optional[Union[str, int]] = None,
        min_score: float = 0.0
    ) -> List[SearchResult]:
        """
        Search for relevant chunks with optional company/year filtering.

        The query is embedded, L2-normalized and searched in the FAISS index.
        Hits are kept when their score is at least ``min_score`` and their
        position passes the company/year filter. When a filter excludes part
        of the corpus, the number of candidates requested from FAISS is
        enlarged step by step (up to the whole index) until ``k`` results are
        found, so matches ranked below the global top hits are not lost.

        Args:
            query: Search query
            k: Number of results to return; values <= 0 yield an empty list
            company: Filter by company name
            year: Filter by year
            min_score: Minimum similarity score threshold

        Returns:
            List of SearchResult objects, at most ``k`` long. An empty list is
            also returned when nothing matches or when any error occurs
            (errors are logged, not raised).
        """
        try:
            # Get filtered indices
            valid_indices = self._filter_indices_by_metadata(company, year)

            if not valid_indices:
                logger.warning(f"No chunks found for company={company}, year={year}")
                return []

            total = self.index.ntotal
            if k <= 0 or total <= 0:
                # Nothing can be returned; avoid asking FAISS for zero results.
                return []

            # A set gives O(1) membership tests inside the candidate loop.
            allowed = set(valid_indices)
            # Only enlarge the search when the filter actually removes entries.
            is_filtered = len(allowed) < len(self.metadata)

            # Create query embedding
            query_embedding = self.encoder.encode([query], convert_to_numpy=True)
            faiss.normalize_L2(query_embedding)

            fetch_k = min(k * 3, total)
            while True:
                scores, indices = self.index.search(query_embedding, fetch_k)

                results = []
                for score, raw_idx in zip(scores[0], indices[0]):
                    idx = int(raw_idx)
                    # FAISS pads missing neighbours with -1.
                    if idx == -1 or score < min_score or idx not in allowed:
                        continue

                    meta = self.metadata[idx]
                    results.append(SearchResult(
                        rank=len(results) + 1,
                        score=float(score),
                        chunk=self.chunks[idx],
                        company=meta['company'],
                        year=meta['year'],
                        filename=meta['filename'],
                        chunk_id=meta['chunk_id'],
                        file_path=meta['file_path']
                    ))

                    # Stop when we have enough results
                    if len(results) >= k:
                        break

                if len(results) >= k or not is_filtered or fetch_k >= total:
                    break
                # Too few filtered hits among the candidates: look deeper.
                fetch_k = min(fetch_k * 4, total)

            logger.info(f"Found {len(results)} relevant chunks for query: '{query}'")
            return results

        except Exception as e:
            logger.error(f"Error during search: {e}")
            return []
    
    def generate_answer(
        self,
        query: str,
        search_results: List[SearchResult],
        language: str = "German",
        include_sources: bool = True
    ) -> str:
        """
        Generate answer using Gemini based on search results.
        
        Args:
            query: Original user query
            search_results: List of relevant chunks
            language: Language for the response
            include_sources: Whether to include source information
            
        Returns:
            Generated answer
        """
        if not search_results:
            return f"Keine relevanten Informationen gefunden für die Anfrage: '{query}'"
        
        # Prepare context from search results
        context_parts = []
        for result in search_results:
            source_info = f"[{result.company} {result.year}]"
            context_parts.append(f"{source_info}: {result.chunk}")
        
        context = "\n\n".join(context_parts)
        
        # Create prompt for Gemini
        prompt = self._create_rag_prompt(query, context, language, include_sources, search_results)
        
        try:
            # Generate response with Gemini
            response = self.gemini.generate_content(prompt)
            return response.text
            
        except Exception as e:
            logger.error(f"Error generating answer with Gemini: {e}")
            return f"Fehler beim Generieren der Antwort: {str(e)}"
    
    def _create_rag_prompt(
        self,
        query: str,
        context: str,
        language: str,
        include_sources: bool,
        search_results: List[SearchResult]
    ) -> str:
        """Create the RAG prompt for Gemini"""
        
        sources_info = ""
        if include_sources and search_results:
            unique_sources = {}
            for result in search_results:
                key = f"{result.company}_{result.year}"
                if key not in unique_sources:
                    unique_sources[key] = f"- {result.company} Jahresbericht {result.year}"
            sources_info = f"\n\nVerfügbare Quellen:\n" + "\n".join(unique_sources.values())
        
        if language.lower() == "german":
            prompt = f"""Sie sind ein Experte für die Analyse deutscher Unternehmensberichte. Basierend auf den bereitgestellten Informationen aus Geschäftsberichten, beantworten Sie die folgende Frage präzise und umfassend.

FRAGE: {query}

KONTEXT AUS GESCHÄFTSBERICHTEN:
{context}

ANWEISUNGEN:
1. Antworten Sie ausschließlich auf Deutsch
2. Basieren Sie Ihre Antwort nur auf den bereitgestellten Informationen
3. Wenn Sie spezifische Zahlen oder Fakten erwähnen, geben Sie das Unternehmen und Jahr an
4. Falls die Informationen nicht ausreichen, sagen Sie das deutlich
5. Strukturieren Sie Ihre Antwort klar und logisch
6. {"Fügen Sie am Ende eine Liste der verwendeten Quellen hinzu" if include_sources else ""}

ANTWORT:"""
        else:
            prompt = f"""You are an expert in analyzing German company reports. Based on the provided information from annual reports, answer the following question precisely and comprehensively.

QUESTION: {query}

CONTEXT FROM BUSINESS REPORTS:
{context}

INSTRUCTIONS:
1. Answer in {language}
2. Base your answer only on the provided information
3. When mentioning specific numbers or facts, indicate the company and year
4. If the information is insufficient, state this clearly
5. Structure your answer clearly and logically
6. {"Include a list of sources used at the end" if include_sources else ""}

ANSWER:"""
        
        return prompt + sources_info
    
    def ask(
        self,
        query: str,
        company: Optional[str] = None,
        year: Optional[Union[str, int]] = None,
        k: int = 5,
        language: str = "German",
        include_sources: bool = True,
        min_score: float = 0.3
    ) -> Dict[str, any]:
        """
        Main RAG function: search and generate answer.
        
        Args:
            query: User question
            company: Filter by company name
            year: Filter by year
            k: Number of chunks to retrieve
            language: Response language
            include_sources: Include source information
            min_score: Minimum similarity threshold
            
        Returns:
            Dictionary with answer, sources, and metadata
        """
        # Search for relevant chunks
        search_results = self.search_chunks(
            query=query,
            k=k,
            company=company,
            year=year,
            min_score=min_score
        )
        
        # Generate answer
        answer = self.generate_answer(
            query=query,
            search_results=search_results,
            language=language,
            include_sources=include_sources
        )
        
        # Prepare response
        response = {
            'query': query,
            'answer': answer,
            'num_sources': len(search_results),
            'sources': [
                {
                    'company': r.company,
                    'year': r.year,
                    'filename': r.filename,
                    'score': r.score,
                    'chunk_preview': r.chunk[:200] + "..." if len(r.chunk) > 200 else r.chunk
                }
                for r in search_results
            ],
            'filters_applied': {
                'company': company,
                'year': year
            }
        }
        
        return response
    
    def interactive_chat(self):
        """
        Run a console chat loop on top of ``ask``.

        Recognized commands (matched on the first word, case-insensitive):
            exit / quit          leave the chat
            companies            list the available companies
            years [company]      list the available years, optionally for one
                                 company (the name may contain spaces)
            filter company=X year=Y
                                 set the filters used for later questions;
                                 words following ``company=X`` without an
                                 ``=`` are treated as part of the company name
            clear                remove all filters

        Any other non-empty input is sent to ``ask`` as a question with the
        current filters. Ctrl-C or end of input (EOF) ends the chat; other
        errors are printed and the loop continues.
        """
        print("🤖 German Reports RAG System")
        print("Verfügbare Unternehmen:", ", ".join(self.get_available_companies()))
        # Years are converted to str because the metadata may store them as
        # numbers (the filter code also compares them via str()).
        print("Verfügbare Jahre:", ", ".join(map(str, self.get_available_years())))
        print("\nCommands:")
        print("- 'exit' oder 'quit' zum Beenden")
        print("- 'companies' für verfügbare Unternehmen")
        print("- 'years [company]' für verfügbare Jahre")
        print("- 'filter company=BMW year=2023' um Filter zu setzen")
        print("- 'clear' um Filter zu löschen")

        current_company = None
        current_year = None

        while True:
            try:
                # Show current filters
                active_filters = []
                if current_company:
                    active_filters.append(f"Company: {current_company}")
                if current_year:
                    active_filters.append(f"Year: {current_year}")
                filter_info = f" [{', '.join(active_filters)}]" if active_filters else ""

                user_input = input(f"\n💬{filter_info} Ihre Frage: ").strip()

                if not user_input:
                    continue

                lowered = user_input.lower()
                tokens = user_input.split()
                # Compare whole words so questions that merely start with the
                # letters of a command are not swallowed as commands.
                command = tokens[0].lower()

                if lowered in ('exit', 'quit'):
                    print("Auf Wiedersehen!")
                    break

                if lowered == 'companies':
                    print("Verfügbare Unternehmen:", ", ".join(self.get_available_companies()))
                    continue

                if command == 'years':
                    # Everything after the command is the company name, so
                    # names made of several words work too.
                    company = " ".join(tokens[1:]) or None
                    years = self.get_available_years(company)
                    suffix = ' für ' + company if company else ''
                    print(f"Verfügbare Jahre{suffix}: {', '.join(map(str, years))}")
                    continue

                if command == 'filter':
                    # Parse filter command: filter company=BMW year=2023
                    last_key = None
                    for token in tokens[1:]:
                        key, separator, value = token.partition('=')
                        if separator:
                            last_key = key.lower()
                            if last_key == 'company':
                                current_company = value
                            elif last_key == 'year':
                                current_year = value
                        elif last_key == 'company':
                            # Continuation of a multi-word company name.
                            current_company = f"{current_company} {token}".strip()
                    print(f"Filter gesetzt - Company: {current_company}, Year: {current_year}")
                    continue

                if lowered == 'clear':
                    current_company = None
                    current_year = None
                    print("Filter gelöscht")
                    continue

                # Process the question
                print("🔍 Suche nach relevanten Informationen...")
                response = self.ask(
                    query=user_input,
                    company=current_company,
                    year=current_year,
                    k=5,
                    language="German",
                    include_sources=True
                )

                print(f"\n📋 Antwort ({response['num_sources']} Quellen):")
                print("=" * 50)
                print(response['answer'])

                if response['sources']:
                    print(f"\n📚 Verwendete Quellen:")
                    for i, source in enumerate(response['sources'], 1):
                        print(f"{i}. {source['company']} {source['year']} (Score: {source['score']:.3f})")

            except (KeyboardInterrupt, EOFError):
                # EOFError must end the loop as well: once stdin is exhausted
                # every further input() fails, which would otherwise repeat
                # the generic error branch below forever.
                print("\n\nAuf Wiedersehen!")
                break
            except Exception as e:
                print(f"Fehler: {e}")


# Example usage: executed only when this file is run as a script
if __name__ == "__main__":
    # Build the RAG system from the index, metadata and chunk files
    # given as relative paths.
    rag = GermanReportsRAG(
        index_path="german_reports_index.faiss",
        metadata_path="german_reports_metadata.pkl",
        chunks_path="german_reports_chunks.pkl",
        gemini_api_key=None  # No key passed: GEMINI_API_KEY is read from the environment
    )
    
    # Example 1: question without company/year filter, retrieving up to
    # 10 chunks. The query asks "How high was the revenue?".
    print("=== Example 1: General Question ===")
    response = rag.ask("Wie hoch war der Umsatz?", k=10)
    print(f"Answer: {response['answer']}")
    print(f"Sources: {len(response['sources'])}")
    

INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cpu
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
INFO:__main__:✓ Loaded FAISS index with 1874 vectors
INFO:__main__:✓ Loaded 1874 metadata entries
INFO:__main__:✓ Loaded 1874 text chunks
ERROR:__main__:Error setting up Gemini: Gemini API key not provided. Set GEMINI_API_KEY environment variable or pass api_key parameter


ValueError: Gemini API key not provided. Set GEMINI_API_KEY environment variable or pass api_key parameter