In [3]:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
import torch
from typing import List, Dict, Tuple
import logging


class FinancialQASystem:
    """Extractive question answering over financial text.

    Wraps a Hugging Face ``question-answering`` pipeline and normalises its
    output into a flat result dictionary. Diagnostics are emitted through
    ``self.logger`` (module logger) rather than printed, so callers control
    verbosity through the standard ``logging`` configuration.
    """

    # Maximum context size, counted in whitespace-separated words. This is a
    # rough proxy for model input size, not an exact token limit.
    MAX_CONTEXT_WORDS = 2000

    # Number of leading context characters echoed back in ``context_preview``.
    PREVIEW_CHARS = 200

    def __init__(self, model_name: str = "deepset/roberta-base-squad2"):
        """Load the tokenizer and model for *model_name* and build the pipeline.

        The pipeline is placed on CUDA device 0 when ``torch`` reports CUDA as
        available, otherwise on CPU (``device=-1``).
        """
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForQuestionAnswering.from_pretrained(model_name)
        self.pipeline = pipeline("question-answering",
                               model=self.model,
                               tokenizer=self.tokenizer,
                               device=0 if torch.cuda.is_available() else -1)
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _chunk_to_text(chunk):
        """Return the text of one context item.

        A dict carrying a ``'text'`` key yields that value; anything else
        (including a dict without ``'text'``) is converted with ``str``.
        """
        if isinstance(chunk, dict) and 'text' in chunk:
            return chunk['text']
        return str(chunk)

    @staticmethod
    def _failure_result(error: str) -> dict:
        """Build the result dictionary returned when no answer can be produced."""
        return {
            'answer': '',
            'confidence': 0.0,
            'qa_confidence': 0.0,
            'similarity_confidence': 0.0,
            'combined_confidence': 0.0,
            'source_chunks': [],
            'context_preview': '',
            'error': error
        }

    @staticmethod
    def _similarity_of(record: Dict) -> float:
        """Read the similarity value from an answer record.

        ``'similarity_score'`` is preferred; ``'similarity_confidence'`` (the
        key emitted by ``answer_question``) is accepted as a fallback.

        Raises:
            KeyError: if neither key is present.
        """
        for key in ('similarity_score', 'similarity_confidence'):
            if key in record:
                return record[key]
        raise KeyError('similarity_score')

    def answer_question(self, question: str, context) -> dict:
        """Answer *question* from *context* and return a flat result dict.

        Args:
            question: The question text.
            context: A string, a dict with a ``'text'`` key, a list of such
                items (joined with blank lines), or any other object (which
                is converted with ``str``).

        Returns:
            A dict with keys ``answer``, ``confidence``, ``qa_confidence``,
            ``similarity_confidence``, ``combined_confidence``,
            ``source_chunks`` and ``context_preview``. On failure the same
            keys are present with empty/zero values plus an ``error`` key
            holding the message. This method never raises.
        """
        # Boundary handler: callers rely on always getting a dict back, so any
        # failure below is logged and converted into an error result.
        try:
            if isinstance(context, list):
                context_text = "\n\n".join(
                    self._chunk_to_text(chunk) for chunk in context
                )
            else:
                context_text = self._chunk_to_text(context)

            self.logger.debug(
                "QA processing - Question length: %d, Context length: %d",
                len(question), len(context_text),
            )

            # Nothing to extract from: skip the model call entirely.
            if not question.strip() or not context_text.strip():
                self.logger.warning("QA skipped: empty question or context")
                return self._failure_result(
                    "Question and context must both be non-empty"
                )

            # Split once and reuse the word list for the length check and
            # the truncation.
            words = context_text.split()
            if len(words) > self.MAX_CONTEXT_WORDS:
                context_text = ' '.join(words[:self.MAX_CONTEXT_WORDS])
                self.logger.debug(
                    "Context truncated to %d words", self.MAX_CONTEXT_WORDS
                )

            result = self.pipeline(question=question, context=context_text)
            self.logger.debug("QA pipeline result: %s", result)

            score = result.get('score', 0.0)
            if len(context_text) > self.PREVIEW_CHARS:
                preview = context_text[:self.PREVIEW_CHARS] + "..."
            else:
                preview = context_text

            return {
                'answer': result.get('answer', ''),
                'confidence': score,
                'qa_confidence': score,
                'similarity_confidence': 0.8,  # Placeholder
                'combined_confidence': score,
                'source_chunks': [],
                'context_preview': preview
            }

        except Exception as e:
            self.logger.exception("QA Error: %s", e)
            return self._failure_result(str(e))

    def _select_best_answer(self, candidates: List[Dict], question: str) -> Dict:
        """Score every candidate and return the one with the highest score.

        Each candidate must provide ``'confidence'``, ``'answer'`` and a
        similarity value (``'similarity_score'`` or
        ``'similarity_confidence'``). A ``'combined_score'`` entry is written
        into every candidate dict as a side effect. *question* is currently
        unused.

        Raises:
            ValueError: if *candidates* is empty.
        """
        if not candidates:
            raise ValueError("candidates must contain at least one answer")

        for candidate in candidates:
            qa_score = candidate['confidence']
            similarity_score = self._similarity_of(candidate)
            answer_text = candidate['answer']

            # Grows linearly with word count and saturates at 10 words.
            answer_length_score = min(len(answer_text.split()) / 10, 1.0)
            # Small bonus when the answer contains at least one digit.
            numeric_bonus = 0.1 if any(char.isdigit() for char in answer_text) else 0

            candidate['combined_score'] = (
                qa_score * 0.5 +
                similarity_score * 0.3 +
                answer_length_score * 0.1 +
                numeric_bonus * 0.1
            )

        return max(candidates, key=lambda x: x['combined_score'])

    def _calculate_combined_confidence(self, answer_data: Dict) -> float:
        """Return ``0.7 * confidence + 0.3 * similarity`` for *answer_data*.

        The similarity value is read from ``'similarity_score'`` or, failing
        that, ``'similarity_confidence'``.
        """
        qa_conf = answer_data['confidence']
        sim_conf = self._similarity_of(answer_data)

        return (qa_conf * 0.7 + sim_conf * 0.3)

    def batch_qa(self, questions: List[str], retrieval_system, top_k: int = 3) -> List[Dict]:
        """Answer each question using contexts fetched from *retrieval_system*.

        For every question, ``retrieval_system.search(question, top_k=top_k)``
        supplies the context passed to ``answer_question``. Each returned
        result dict additionally carries the originating ``'question'``.
        """
        results = []

        for question in questions:
            contexts = retrieval_system.search(question, top_k=top_k)
            answer = self.answer_question(question, contexts)
            answer['question'] = question
            results.append(answer)

        return results


class StreamlitQAInterface:
    """Adapter that reshapes ``FinancialQASystem`` answers for a Streamlit UI."""

    def __init__(self):
        """Create the underlying ``FinancialQASystem`` with its default model."""
        self.qa_system = FinancialQASystem()

    @staticmethod
    def _error_payload(message) -> dict:
        """Build the response returned to the UI when answering failed."""
        return {
            "answer": f"Error: {str(message)}",
            "confidence": 0.0,
            "source": "Error",
            "status": "error",
            "qa_confidence": 0.0,
            "combined_confidence": 0.0
        }

    def process_streamlit_question(self, question, context):
        """Answer *question* from *context* and return a UI-ready dict.

        Returns:
            A dict with keys ``answer``, ``confidence``, ``source``,
            ``status`` (``"success"`` or ``"error"``), ``qa_confidence`` and
            ``combined_confidence``. This method never raises.
        """
        try:
            result = self.qa_system.answer_question(question, context)

            # The QA system reports failures through an 'error' key instead
            # of raising, so that key must be checked explicitly; otherwise
            # a failed lookup would be presented as a successful empty answer.
            error = result.get("error")
            if error:
                return self._error_payload(error)

            return {
                # The QA system always sets 'answer' (possibly to ''), so
                # fall back on falsiness rather than on a missing key.
                "answer": result.get("answer") or "No answer found",
                "confidence": result.get("confidence", 0.0),
                "source": result.get("context_preview", "Unknown"),
                "status": "success",
                "qa_confidence": result.get("qa_confidence", 0.0),
                "combined_confidence": result.get("combined_confidence", 0.0)
            }
        except Exception as e:
            return self._error_payload(e)


# Build the module-level interface instance used by the Streamlit app.
# Constructing StreamlitQAInterface instantiates FinancialQASystem, which
# loads the tokenizer and model via from_pretrained, so this runs once at
# cell/module execution time; the prints bracket that load as progress output.
print("🤖 Loading FinancialQA model...")
streamlit_qa = StreamlitQAInterface()
print("✅ FinancialQA System initialized for Streamlit")


🤖 Loading FinancialQA model...


Device set to use cpu


✅ FinancialQA System initialized for Streamlit
