In [21]:
import os
import re
import json
import time
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from pathlib import Path

# Required installations:
# pip install together pypdf

# For PDF processing
try:
    from pypdf import PdfReader
except ImportError:
    print("Please install: pip install pypdf")
    raise

# For Together AI API calls
try:
    from together import Together
except ImportError:
    print("Please install: pip install together")
    raise

@dataclass
class ValidationResult:
    """Outcome of checking one generated answer against the reference text."""
    is_valid: bool  # verdict parsed from the validator's reply
    confidence_score: float  # heuristic score derived from the reply wording
    issues: List[str]  # problems detected (left empty for valid answers)
    supported_claims: List[str]  # positive findings mentioned by the validator
    iteration_count: int  # 1-based attempt number that produced this result
    validation_response: str  # raw validator reply, or an error description

class DocumentProcessor:
    """Load reference documents (PDF or plain text) as length-capped strings."""

    def __init__(self, max_chars: int = 100_000):
        """
        Args:
            max_chars: Maximum number of characters kept from any loaded
                document; longer documents are truncated with a warning.
        """
        self.max_chars = max_chars

    def _truncate(self, text: str) -> str:
        """Cap ``text`` at ``self.max_chars`` characters, warning when it is cut."""
        if len(text) > self.max_chars:
            print(f"Warning: Document is long ({len(text)} chars). Truncating to {self.max_chars:,} characters.")
            return text[:self.max_chars]
        return text

    def get_PDF_text(self, file_path: str) -> str:
        """Extract the text of every page of a PDF, joined by blank lines.

        Args:
            file_path: Path of the PDF file to read.

        Returns:
            The extracted text, truncated to ``max_chars`` characters.

        Raises:
            Exception: If the file cannot be opened or parsed; the original
                error is attached as ``__cause__``.
        """
        try:
            with Path(file_path).open("rb") as f:
                reader = PdfReader(f)
                # ``or ""`` guards against a page yielding no text object,
                # which would otherwise make ``join`` fail.
                text = "\n\n".join(page.extract_text() or "" for page in reader.pages)
        except Exception as e:
            raise Exception(f"Error reading the PDF file: {str(e)}") from e

        return self._truncate(text)

    def load_text_document(self, file_path: str) -> str:
        """Read a UTF-8 text file.

        Args:
            file_path: Path of the text file to read.

        Returns:
            The file content truncated to ``max_chars`` characters, or an
            empty string if the file could not be read (the error is printed,
            not raised).
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except Exception as e:
            # Deliberate best-effort: callers treat "" as "no document".
            print(f"Error loading document: {e}")
            return ""
        return self._truncate(content)

class TogetherAIClient:
    """Thin wrapper around the Together chat-completions client."""

    def __init__(self, api_key: str, model: str = 'meta-llama/Llama-3.3-70B-Instruct-Turbo'):
        """Create the Together client and verify access with a test request.

        Args:
            api_key: Together AI API key.
            model: Model identifier used for every completion request.

        Raises:
            Exception: If the access test request fails.
        """
        self.api_key = api_key
        self.model = model

        # Initialize Together client
        self.client = Together(api_key=self.api_key)

        print(f"✅ Initialized Together AI client with model: {self.model}")
        print("🔄 Using REAL Together AI API")

        # Test API access
        self._test_api_access()

    def _test_api_access(self):
        """Send a minimal chat request to confirm the key and model work.

        Raises:
            Exception: If the request raises; the original error is chained.
        """
        try:
            test_response = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=50
            )
            if test_response and test_response.choices:
                print("✅ API access confirmed - ready to make real LLM calls!")
            else:
                print("⚠️ API test returned empty response")
        except Exception as e:
            print(f"⚠️ API test failed: {e}")
            raise Exception(f"Together AI API test failed: {e}") from e

    def run_llm(self, user_prompt: str, system_prompt: Optional[str] = None,
                temperature: float = 0.7, max_tokens: int = 4000) -> str:
        """Send one chat-completion request and return the reply text.

        Args:
            user_prompt: Content of the user message.
            system_prompt: Optional system message placed before the user message.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Completion token limit forwarded to the API.

        Returns:
            The text of the first choice. An empty string is returned when
            the API reports no content, so callers always receive a ``str``.

        Raises:
            Exception: If the request fails or the response has no choices;
                the original error is chained.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        messages.append({"role": "user", "content": user_prompt})

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            if not response.choices:
                # Clearer than the IndexError that indexing [0] would raise.
                raise ValueError("response contained no choices")
            content = response.choices[0].message.content

        except Exception as e:
            print(f"❌ Together AI API Error: {str(e)}")
            raise Exception(f"Together AI API call failed: {str(e)}") from e

        # The message content may be None; callers slice and measure the
        # result, so normalise to an empty string.
        return content if content is not None else ""

class ResponseValidator:
    """Validate LLM responses against a reference document using a second LLM call.

    The validator's free-text reply is reduced to a verdict, a heuristic
    confidence score and short issue/claim lists by keyword matching.
    Verdict and confidence keywords are matched as whole words, so that
    "Validation" is not read as "VALID", "inaccurate" as "accurate",
    "unsupported" as "supported", or "carefully" as "fully".
    """

    # Explicit verdict markers requested by the system prompt.
    _INVALID_RE = re.compile(r"\bINVALID\b|\bNOT\s+VALID\b", re.IGNORECASE)
    _VALID_RE = re.compile(r"\bVALID\b", re.IGNORECASE)
    # Fallback wording used when no explicit verdict is present.
    _NEGATIVE_RE = re.compile(
        r"\bNOT\s+SUPPORTED\b|\bNOT\s+IN\s+THE\s+REFERENCE\b|\bNOT\s+ACCURATE\b"
        r"|\bUNSUPPORTED\b|\bINACCURATE\b",
        re.IGNORECASE,
    )
    _POSITIVE_RE = re.compile(r"\b(?:ACCURATE|SUPPORTED)\b", re.IGNORECASE)
    # "accurate" as a whole word and not directly negated.
    _ACCURATE_CLAIM_RE = re.compile(r"(?<!not )\baccurate\b")

    def __init__(self, llm_client: "TogetherAIClient",
                 max_reference_chars: int = 4000,
                 max_response_chars: int = 1500):
        """
        Args:
            llm_client: Client whose ``run_llm`` method performs the validation call.
            max_reference_chars: Number of leading reference characters shown
                to the validator. The default equals the 4000 characters the
                generator is shown, so facts the generator could legitimately
                use are also visible to the validator.
            max_response_chars: Number of leading response characters validated.
        """
        self.llm_client = llm_client
        self.max_reference_chars = max_reference_chars
        self.max_response_chars = max_response_chars

    def validate_response(self, response: str, reference_text: str, user_query: str) -> "ValidationResult":
        """Ask the LLM whether ``response`` is grounded in ``reference_text``.

        Args:
            response: Generated answer to check.
            reference_text: Reference document text.
            user_query: The question the answer responds to.

        Returns:
            A ``ValidationResult`` with ``iteration_count`` set to 1.

        Raises:
            Exception: Whatever ``llm_client.run_llm`` raises.
        """
        print("🔍 Running validation with second LLM...")

        validation_prompt = self._create_validation_prompt(response, reference_text, user_query)

        validation_response = self.llm_client.run_llm(
            user_prompt=validation_prompt,
            system_prompt="You are a strict fact-checker. Your job is to validate whether responses are based solely on provided reference material. Be thorough and critical. Respond with VALID or INVALID followed by your reasoning.",
            temperature=0.3  # Lower temperature for more consistent validation
        )

        print(f"📋 Validation response received: {validation_response[:100]}...")

        is_valid = self._parse_validation_response(validation_response)
        confidence_score = self._calculate_confidence(validation_response, is_valid)
        issues = self._extract_issues(validation_response) if not is_valid else []
        supported_claims = self._extract_supported_claims(validation_response)

        return ValidationResult(
            is_valid=is_valid,
            confidence_score=confidence_score,
            issues=issues,
            supported_claims=supported_claims,
            iteration_count=1,
            validation_response=validation_response
        )

    def _create_validation_prompt(self, response: str, reference_text: str, user_query: str) -> str:
        """Build the validator prompt from truncated reference and response text."""
        # Slicing already handles inputs shorter than the cap.
        ref_text_snippet = reference_text[:self.max_reference_chars]
        response_snippet = response[:self.max_response_chars]

        return f"""Task: Check if this response is based ONLY on the reference document provided.

REFERENCE DOCUMENT:
{ref_text_snippet}

RESPONSE TO VALIDATE:
{response_snippet}

ORIGINAL QUERY: {user_query}

Instructions:
1. Check if ALL facts in the response come from the reference document
2. Identify any information that is NOT in the reference document
3. Answer with VALID or INVALID
4. Provide specific reasoning

Your validation:"""

    def _parse_validation_response(self, validation_response: str) -> bool:
        """Reduce the validator's reply to a boolean verdict.

        Precedence: any INVALID / "not valid" marker means invalid; otherwise
        an explicit VALID means valid; otherwise negative wording means
        invalid and positive wording means valid. Unclear replies are
        treated as invalid.
        """
        if self._INVALID_RE.search(validation_response):
            return False
        if self._VALID_RE.search(validation_response):
            return True
        if self._NEGATIVE_RE.search(validation_response):
            return False
        if self._POSITIVE_RE.search(validation_response):
            return True

        # Conservative default - if unclear, assume invalid
        return False

    @staticmethod
    def _mentions_any(text: str, words: List[str]) -> bool:
        """Return True if any of ``words`` occurs in ``text`` as a whole word."""
        return any(re.search(rf"\b{re.escape(word)}\b", text) for word in words)

    def _calculate_confidence(self, validation_response: str, is_valid: bool) -> float:
        """Map hedging/emphasis words in the reply to a fixed confidence score.

        Strong wording is checked first, then moderate, then weak; a reply
        with none of them gets the default for its verdict.
        """
        response_lower = validation_response.lower()

        high_confidence_words = ["clearly", "definitely", "completely", "fully", "entirely", "exactly"]
        medium_confidence_words = ["mostly", "generally", "largely", "primarily"]
        low_confidence_words = ["somewhat", "partially", "unclear", "ambiguous"]

        if is_valid:
            high, medium, low, default = 0.9, 0.75, 0.6, 0.8
        else:
            high, medium, low, default = 0.2, 0.4, 0.5, 0.3

        if self._mentions_any(response_lower, high_confidence_words):
            return high
        if self._mentions_any(response_lower, medium_confidence_words):
            return medium
        if self._mentions_any(response_lower, low_confidence_words):
            return low
        return default

    def _extract_issues(self, validation_response: str) -> List[str]:
        """List the problems the validator's reply mentions.

        Intended for replies already judged invalid: when no specific
        pattern matches, a generic issue is returned so that an invalid
        verdict never comes with an empty issue list.
        """
        issues = []

        response_lower = validation_response.lower()

        # Common issue patterns
        if "not in the reference" in response_lower:
            issues.append("Contains information not found in reference document")
        if "external knowledge" in response_lower:
            issues.append("Uses external knowledge beyond reference material")
        if "hallucination" in response_lower or "made up" in response_lower:
            issues.append("Contains potentially fabricated information")
        if "inaccurate" in response_lower:
            issues.append("Contains inaccurate information")
        if "unsupported" in response_lower or "not supported" in response_lower:
            issues.append("Contains unsupported claims")

        if not issues:
            issues.append("Response does not adequately reflect reference document content")

        return issues

    def _extract_supported_claims(self, validation_response: str) -> List[str]:
        """List the positive findings the validator's reply mentions (at most 3)."""
        claims = []

        response_lower = validation_response.lower()

        # Whole-word match so "inaccurate" / "not accurate" are not counted.
        if self._ACCURATE_CLAIM_RE.search(response_lower):
            claims.append("Contains accurate information from reference")
        if "well-supported" in response_lower:
            claims.append("Claims are well-supported by reference material")
        if "directly from" in response_lower:
            claims.append("Information taken directly from reference document")

        return claims[:3]  # Return top 3 claims

class RAGSystem:
    """
    Main RAG system implementing:
    - Prompt expansion with reference documents
    - Output validation against documents using second LLM
    - Iteration when validation fails
    """

    # Leading reference characters placed in each generation prompt.
    MAX_REFERENCE_CHARS = 4000

    def __init__(self, reference_document_path: str, llm_client: "TogetherAIClient"):
        """Load the reference document and wire up the validator.

        Args:
            reference_document_path: Path of the reference file. Paths ending
                in ``.pdf`` (any letter case) are read as PDF, everything
                else as UTF-8 text.
            llm_client: Client used both for generation and for validation.
        """
        self.doc_processor = DocumentProcessor()
        self.llm_client = llm_client
        self.validator = ResponseValidator(llm_client)

        # Load reference document
        print(f"📄 Loading reference document: {reference_document_path}")

        # Case-insensitive so that e.g. "REPORT.PDF" is not read as text.
        if reference_document_path.lower().endswith('.pdf'):
            self.reference_text = self.doc_processor.get_PDF_text(reference_document_path)
        else:
            self.reference_text = self.doc_processor.load_text_document(reference_document_path)

        print(f"✅ Loaded reference document ({len(self.reference_text)} characters)")

        if len(self.reference_text) < 100:
            print("⚠️ Warning: Reference document seems very short. Please check the file.")

    def expand_prompt_with_reference(self, user_query: str) -> str:
        """Return the reference text (capped), an instruction, and the question.

        Args:
            user_query: Question appended after the instruction text.
        """
        # Keep reference document shorter for better API compatibility
        reference_snippet = self.reference_text[:self.MAX_REFERENCE_CHARS]

        linking_text = "\n\nBased ONLY on the information provided above, please answer the following question. Use only the information from the reference material and do not add external knowledge.\n\nQuestion: "

        return reference_snippet + linking_text + user_query

    def generate_and_validate_response(self, user_query: str, max_iterations: int = 3) -> Dict:
        """Generate an answer, validate it, and retry with a refined query.

        Args:
            user_query: The user's question.
            max_iterations: Maximum number of generate/validate attempts.

        Returns:
            A dict with keys ``user_query``, ``reference_document_used``,
            ``iterations`` (one entry per completed attempt),
            ``final_response``, ``final_validation``, ``success`` and
            ``total_iterations``. If no attempt validates, the last completed
            attempt is reported; if no attempt completed at all (generation
            failed immediately, or ``max_iterations`` is not positive),
            ``final_response`` and ``final_validation`` stay ``None``.
        """
        print(f"\n🚀 Starting RAG process for query: {user_query[:100]}...")

        results = {
            'user_query': user_query,
            'reference_document_used': True,
            'iterations': [],
            'final_response': None,
            'final_validation': None,
            'success': False,
            'total_iterations': 0
        }

        current_query = user_query
        generation_failed = False

        for iteration in range(max_iterations):
            print(f"\n{'='*60}")
            print(f"ITERATION {iteration + 1}")
            print('='*60)

            # Step 1: Expand prompt with reference document
            print("📝 Expanding prompt with reference document...")
            expanded_prompt = self.expand_prompt_with_reference(current_query)

            # Step 2: Generate response using LLM
            print("🤖 Generating response with Together AI...")
            try:
                response = self.llm_client.run_llm(
                    user_prompt=expanded_prompt,
                    system_prompt="You are a helpful assistant that answers questions based solely on provided reference material. Be accurate and only use information from the reference document.",
                    temperature=0.7
                )
            except Exception as e:
                print(f"❌ Error generating response: {e}")
                generation_failed = True
                break

            print(f"📄 Response generated ({len(response)} characters)")
            print(f"Preview: {response[:200]}...")

            # Step 3: Validate response against reference document
            print("🔍 Validating response with second LLM...")
            try:
                validation = self.validator.validate_response(response, self.reference_text, user_query)
                validation.iteration_count = iteration + 1
            except Exception as e:
                print(f"❌ Error during validation: {e}")
                # A failed validation call counts as an invalid attempt.
                validation = ValidationResult(
                    is_valid=False,
                    confidence_score=0.5,
                    issues=[f"Validation error: {str(e)}"],
                    supported_claims=[],
                    iteration_count=iteration + 1,
                    validation_response=f"Error during validation: {str(e)}"
                )

            results['iterations'].append({
                'iteration': iteration + 1,
                'query_used': current_query,
                'expanded_prompt_length': len(expanded_prompt),
                'response': response,
                'validation': validation
            })
            results['total_iterations'] = iteration + 1

            print(f"📊 Validation: {'✅ VALID' if validation.is_valid else '❌ INVALID'}")
            print(f"📊 Confidence: {validation.confidence_score:.2f}")

            if validation.is_valid:
                results['final_response'] = response
                results['final_validation'] = validation
                results['success'] = True
                print("🎉 Validation successful - process complete!")
                break

            print(f"⚠️ Validation failed. Issues: {validation.issues}")
            # Modify query for next iteration
            current_query = self._refine_query_for_retry(user_query, validation.issues, iteration)
            print("🔄 Refining query for next iteration...")

        if not results['success']:
            # Fall back to the last completed attempt, if there is one; the
            # list is empty when generation failed on the first attempt.
            if results['iterations']:
                last_attempt = results['iterations'][-1]
                results['final_response'] = last_attempt['response']
                results['final_validation'] = last_attempt['validation']

            if generation_failed:
                print("⚠️ Stopped early because response generation failed")
            else:
                print("⚠️ Maximum iterations reached without successful validation")

        return results

    def _refine_query_for_retry(self, original_query: str, issues: List[str], iteration: int) -> str:
        """Return the original query with a stricter instruction appended.

        The instruction is selected by ``iteration`` (0-based) from a fixed
        list, with a generic fallback beyond the list. ``issues`` is accepted
        for interface stability but does not currently influence the text.
        """
        refinements = [
            f"{original_query}\n\nPlease be very specific and only use information that is explicitly stated in the reference document. Do not make inferences or add external knowledge.",
            f"{original_query}\n\nPlease focus only on factual information that can be directly quoted or paraphrased from the provided reference material.",
            f"{original_query}\n\nPlease provide a response that strictly adheres to the content of the reference document, citing specific sections where possible."
        ]

        if iteration < len(refinements):
            return refinements[iteration]
        return f"{original_query}\n\nPlease answer using only the exact information provided in the reference document."

# Sample reference document (Python Programming Guide).
# create_sample_documents() writes this text verbatim to
# python_programming_guide.txt, which the demonstration then loads as the
# reference document. The test queries in run_rag_demonstration() are written
# against this content, so edits here can change which queries validate.
SAMPLE_PYTHON_GUIDE = """
Python Programming Guide - Data Structures and Algorithms

Chapter 1: Introduction to Python Data Structures

Python provides several built-in data structures that are essential for effective programming. Understanding these structures is crucial for writing efficient and maintainable code.

Lists are the most versatile data structure in Python. They are ordered, mutable collections that can store elements of different data types. Lists support various operations including append(), remove(), insert(), and indexing with square brackets. For example, my_list = [1, 2, 'hello'] creates a list with mixed data types.

Dictionaries are another fundamental data structure, implementing key-value pairs for fast lookups and data organization. They are unordered (in Python versions before 3.7) and use curly braces for definition. Dictionary operations include get(), keys(), values(), and items() methods.

Tuples are immutable ordered collections, making them useful for storing fixed data that shouldn't change. They are defined using parentheses and are often used for returning multiple values from functions.

Sets are unordered collections of unique elements, perfect for removing duplicates and performing mathematical set operations like union, intersection, and difference.

Chapter 2: Algorithm Complexity

Understanding time and space complexity is essential for writing efficient code. Big O notation provides a way to describe algorithm performance. Common complexities include O(1) for constant time, O(n) for linear time, and O(n²) for quadratic time operations.

Chapter 3: Common Algorithms

Sorting algorithms like quicksort and mergesort are fundamental to computer science. Quicksort has an average time complexity of O(n log n) but can degrade to O(n²) in worst-case scenarios. Mergesort consistently maintains O(n log n) complexity.

Search algorithms include linear search with O(n) complexity and binary search with O(log n) complexity for sorted arrays.

Chapter 4: Best Practices

When working with data structures, always consider the time and space complexity of your operations. Choose the appropriate data structure based on your specific use case and performance requirements.
"""

def create_sample_documents():
    """Write the bundled sample guide to disk for use as a reference file."""
    # The guide lands in the current working directory (overwritten if present).
    guide_path = Path('python_programming_guide.txt')
    guide_path.write_text(SAMPLE_PYTHON_GUIDE, encoding='utf-8')

    for line in ("✅ Created sample documents:", "   - python_programming_guide.txt"):
        print(line)

def _default_test_queries() -> List[str]:
    """Return the built-in demo queries written against the sample guide."""
    return [
        # Query 1: Well-covered topic (should validate successfully)
        "What are the main data structures available in Python and what are their characteristics?",

        # Query 2: Specific technical question (should validate)
        "What is the time complexity of quicksort and how does it compare to mergesort?",

        # Query 3: Comparison question (should work with reference)
        "What are the differences between lists and tuples in Python?",

        # Query 4: Out-of-scope question (should fail validation initially)
        "Can you explain machine learning algorithms and their implementation in Python?",

        # Query 5: Algorithm complexity question (should validate)
        "Explain Big O notation and provide examples of different complexity classes."
    ]


def _print_query_result(result: Dict) -> None:
    """Print the outcome of one query: success flag, iterations, answer."""
    print("\n--- FINAL RESULTS ---")
    print(f"✅ Success: {result['success']}")
    print(f"🔄 Total Iterations: {result['total_iterations']}")

    final_validation = result['final_validation']
    if final_validation:
        print(f"📊 Final Confidence: {final_validation.confidence_score:.2f}")
        if final_validation.issues:
            print(f"⚠️ Issues: {final_validation.issues}")

    print("\n📄 Final Response:")
    print("-" * 50)
    print(result['final_response'])
    print("-" * 50)


def _print_evaluation_summary(results_summary: List[Dict], query_count: int) -> None:
    """Print aggregate statistics over the processed queries.

    ``query_count`` is the number of queries attempted; ``results_summary``
    holds only those that were processed without raising.
    """
    print(f"\n{'='*70}")
    print("EVALUATION SUMMARY")
    print('='*70)

    successful_validations = sum(1 for r in results_summary if r['success'])
    total_iterations = sum(r['total_iterations'] for r in results_summary)
    avg_iterations = total_iterations / len(results_summary)

    final_confidences = [r['final_validation'].confidence_score for r in results_summary if r['final_validation']]
    avg_confidence = sum(final_confidences) / len(final_confidences) if final_confidences else 0

    print(f"📊 Test Queries: {query_count}")
    print(f"✅ Successful Validations: {successful_validations}/{query_count} ({successful_validations/query_count*100:.1f}%)")
    print(f"🔄 Average Iterations per Query: {avg_iterations:.1f}")
    print(f"📈 Average Final Confidence: {avg_confidence:.2f}")

    print("\n--- KEY FINDINGS ---")
    print("✅ Real Together AI API integration successful")
    print("✅ Prompt expansion with reference documents working")
    print("✅ Real LLM validation effectively catches issues")
    print("✅ Iteration process allows recovery from validation failures")
    print("✅ System works best with queries matching reference content")


def run_rag_demonstration(test_queries: Optional[List[str]] = None,
                          max_iterations: int = 2) -> Optional[List[Dict]]:
    """Run the complete RAG demonstration with Together AI.

    Reads the API key from the ``TOGETHER_API_KEY`` environment variable,
    writes the sample reference document, then generates and validates an
    answer for each query, printing progress and a final summary.

    Args:
        test_queries: Queries to run; defaults to the five built-in demo queries.
        max_iterations: Maximum generate/validate attempts per query.

    Returns:
        The list of per-query result dicts (queries that raised are omitted),
        or ``None`` if the API key is missing or initialization failed.
    """
    print("="*70)
    print("RAG System with Real Together AI Implementation")
    print("="*70)

    # Get API key from environment
    api_key = os.environ.get("TOGETHER_API_KEY")
    if not api_key:
        print("❌ TOGETHER_API_KEY not found in environment variables!")
        print("Please set your Together AI API key:")
        print("For Colab: Use the secrets method or direct input")
        return None

    if test_queries is None:
        test_queries = _default_test_queries()

    # Create sample documents
    create_sample_documents()

    try:
        # Initialize system components with REAL Together AI
        print("\n🔧 Initializing Together AI client...")
        llm_client = TogetherAIClient(api_key=api_key)

        print("🔧 Initializing RAG system...")
        rag_system = RAGSystem('python_programming_guide.txt', llm_client)

    except Exception as e:
        print(f"❌ Error initializing system: {e}")
        return None

    results_summary = []

    for i, query in enumerate(test_queries, 1):
        print(f"\n{'='*70}")
        print(f"TEST QUERY {i}/{len(test_queries)}: {query}")
        print('='*70)

        try:
            result = rag_system.generate_and_validate_response(query, max_iterations=max_iterations)
            results_summary.append(result)
            _print_query_result(result)
        except Exception as e:
            # One failing query must not abort the remaining ones.
            print(f"❌ Error processing query: {e}")

    # Summary statistics
    if results_summary:
        _print_evaluation_summary(results_summary, len(test_queries))

    return results_summary

# Main execution for Google Colab.
# Resolves the API key (Colab secrets first, interactive prompt as fallback),
# exports it through the TOGETHER_API_KEY environment variable, and runs the
# demonstration.
if __name__ == "__main__":

    # For Google Colab - get API key
    try:
        # Method 1: Try Colab secrets first
        from google.colab import userdata
        TOGETHER_API_KEY = userdata.get('TOGETHER_API_KEY')
        if TOGETHER_API_KEY:
            os.environ['TOGETHER_API_KEY'] = TOGETHER_API_KEY
            print("✅ Successfully loaded API key from Colab secrets")
        else:
            print("❌ No API key found in Colab secrets")
    except ImportError:
        print("Not running in Google Colab")
    except Exception as e:
        print(f"Error accessing Colab secrets: {e}")

    # Method 2: Direct input if secrets don't work
    if not os.environ.get("TOGETHER_API_KEY"):
        import getpass
        print("🔑 Enter your Together AI API key:")
        api_key = getpass.getpass("API Key: ")
        if api_key:
            os.environ['TOGETHER_API_KEY'] = api_key
            print("✅ API key set successfully!")

    # Run the complete demonstration
    try:
        evaluation_results = run_rag_demonstration()

        if evaluation_results:
            print("\n🎉 RAG SYSTEM DEMONSTRATION COMPLETE!")
            print("✅ Real Together AI API calls successful")
            # Report the real count: queries that raised are not in the results.
            print(f"✅ {len(evaluation_results)} test queries processed")
            print("✅ Validation and iteration working")
        else:
            print("\n❌ Demo could not complete due to setup issues.")

    except KeyboardInterrupt:
        print("\n\n⏹️ Process interrupted by user.")
    except Exception as e:
        print(f"\n❌ Unexpected error: {e}")
        print("Please check your API key and internet connection.")

✅ Successfully loaded API key from Colab secrets
RAG System with Real Together AI Implementation
✅ Created sample documents:
   - python_programming_guide.txt

🔧 Initializing Together AI client...
✅ Initialized Together AI client with model: meta-llama/Llama-3.3-70B-Instruct-Turbo
🔄 Using REAL Together AI API
✅ API access confirmed - ready to make real LLM calls!
🔧 Initializing RAG system...
📄 Loading reference document: python_programming_guide.txt
✅ Loaded reference document (2210 characters)

TEST QUERY 1/5: What are the main data structures available in Python and what are their characteristics?

🚀 Starting RAG process for query: What are the main data structures available in Python and what are their characteristics?...

ITERATION 1
📝 Expanding prompt with reference document...
🤖 Generating response with Together AI...
📄 Response generated (871 characters)
Preview: According to the provided reference material, the main data structures available in Python are:

1. Lists: They are o