# 06 - Context Retrieval with Google Gemini API

This notebook builds on-demand context retrieval for extracted keyphrases using Google's Gemini API, with Wikipedia as a fallback source. We will:

1. Configure the Gemini API client
2. Build retrieval functions for keyphrases
3. Implement Wikipedia fallback
4. Create a production-ready retrieval service

---

## API Setup

Get your free API key from [Google AI Studio](https://aistudio.google.com/)

## Setup and Imports

In [None]:
import json
import os
import re
import time
from getpass import getpass
from pathlib import Path
from typing import List, Optional

import google.generativeai as genai
import wikipediaapi
from dotenv import load_dotenv
from pydantic import BaseModel

load_dotenv()

## Configuration

Create a `.env` file in your project root with:
```
GEMINI_API_KEY=your_api_key_here
```

In [None]:
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

if not GEMINI_API_KEY:
    print("Warning: GEMINI_API_KEY not found in environment.")
    print("Add GEMINI_API_KEY=... to your .env file (see above) to skip this prompt.")
    # getpass hides the typed key, so it never lands in the cell output or
    # the saved .ipynb (input() would echo it in plain text).
    GEMINI_API_KEY = getpass("Enter your Gemini API key: ").strip()

if not GEMINI_API_KEY:
    # Fail here with a clear message instead of on the first API call.
    raise ValueError("A Gemini API key is required to continue.")

genai.configure(api_key=GEMINI_API_KEY)
print("Gemini API configured.")

## Initialize Gemini Model

In [None]:
# Notebook-level sampling settings and client; retrieve_context_gemini below
# reads both of these globals.
generation_config = genai.GenerationConfig(
    max_output_tokens=500,
    temperature=0.3,
    top_k=40,
    top_p=0.8,
)
model = genai.GenerativeModel('gemini-1.5-flash')

print("Gemini 1.5 Flash model initialized.")

## Basic Context Retrieval

In [None]:
def retrieve_context_gemini(keyphrases, max_retries=3, gemini_model=None, config=None):
    """
    Retrieve educational context for keyphrases using the Gemini API.

    Args:
        keyphrases: List of keyphrase strings.
        max_retries: Total number of request attempts. Values below 1 are
            treated as 1, so at least one request is always made. Failed
            attempts are followed by exponential backoff (1s, 2s, 4s, ...)
            before the next one.
        gemini_model: Object with ``generate_content(prompt, generation_config=...)``.
            Defaults to the notebook-level ``model``.
        config: Generation config passed to ``generate_content``.
            Defaults to the notebook-level ``generation_config``.

    Returns:
        dict mapping topic names (as written by the model) to context strings;
        {} when ``keyphrases`` is empty or every attempt fails (the last error
        is printed).
    """
    if not keyphrases:
        return {}

    if gemini_model is None:
        gemini_model = model
    if config is None:
        config = generation_config

    prompt = f"""Provide brief, factual educational context for each of these topics. 
For each topic, give 2-3 sentences of objective information that would help someone 
understand and think critically about the topic.

Topics: {', '.join(keyphrases)}

Format your response as:
TOPIC: [topic name]
CONTEXT: [2-3 sentence explanation]

Do not include opinions. Focus on established facts and key considerations."""

    # Previously max_retries <= 0 skipped the loop entirely and made no request.
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            response = gemini_model.generate_content(
                prompt,
                generation_config=config
            )
            return parse_context_response(response.text, keyphrases)
        except Exception as e:
            if attempt < attempts - 1:
                time.sleep(2 ** attempt)
                continue
            print(f"Gemini API error: {e}")

    return {}

In [None]:
# Matches a section label such as "TOPIC: x", "**TOPIC:** x", "**Topic**: x",
# "### CONTEXT: y" or "- TOPIC: x". Group 1 is the label, group 2 the rest of the line.
_SECTION_LABEL_RE = re.compile(
    r"^[#>*_\-\s]*(TOPIC|CONTEXT)[*_\s]*:[*_\s]*(.*)$",
    re.IGNORECASE,
)


def parse_context_response(response_text, keyphrases):
    """
    Parse Gemini's ``TOPIC: ... / CONTEXT: ...`` response into a dict.

    Labels are matched case-insensitively and may be wrapped in Markdown
    emphasis or preceded by heading/bullet markers, because models often
    bold them. Every other non-empty line after a TOPIC label is appended to
    that topic's context. Topics that end up with no context are dropped.

    If no labelled topic is found at all, the first keyphrase that appears in
    the raw text is mapped to the first 500 characters of that text.

    Args:
        response_text: Raw text returned by the model.
        keyphrases: The keyphrases that were requested (used only for the fallback).

    Returns:
        dict mapping topic name (as written by the model) to its context string.
    """
    if not response_text:
        return {}

    result = {}
    current_topic = None
    current_context = []

    for raw_line in response_text.strip().split('\n'):
        line = raw_line.strip()
        label_match = _SECTION_LABEL_RE.match(line)
        label = label_match.group(1).upper() if label_match else None

        if label == 'TOPIC':
            if current_topic and current_context:
                result[current_topic] = ' '.join(current_context)
            # Strip stray emphasis such as "TOPIC: **Climate Change**".
            current_topic = label_match.group(2).strip().strip('*').strip()
            current_context = []
        elif label == 'CONTEXT':
            value = label_match.group(2).strip()
            # A bare "CONTEXT:" line would otherwise add a leading space.
            if value:
                current_context.append(value)
        elif current_topic and line:
            current_context.append(line)

    if current_topic and current_context:
        result[current_topic] = ' '.join(current_context)

    if not result:
        lowered_text = response_text.lower()
        for kp in keyphrases:
            # Skip empty keyphrases: '' is a substring of every text.
            if kp and kp.lower() in lowered_text:
                result[kp] = response_text[:500]
                break

    return result

In [None]:
test_keyphrases = ["climate change", "renewable energy", "carbon emissions"]

print("Testing Gemini Context Retrieval...")
context = retrieve_context_gemini(test_keyphrases)

# Show each topic with its context clipped to 200 characters.
print("\nRetrieved Context:")
for topic_name, topic_context in context.items():
    print(f"\n{topic_name}:")
    if len(topic_context) > 200:
        preview = f"{topic_context[:200]}..."
    else:
        preview = topic_context
    print(f"  {preview}")

## Wikipedia Fallback

In [None]:
wiki = wikipediaapi.Wikipedia(
    language='en',
    user_agent='SocraticPath/1.0 (https://github.com/socraticpath)',
)


def retrieve_context_wikipedia(keyphrase, max_chars=500):
    """
    Look up a single keyphrase on English Wikipedia using the module-level ``wiki`` client.

    The page summary is clipped to ``max_chars``. When clipping happens and a
    period falls beyond the halfway point of the clipped text, the text is cut
    back to end on that period.

    Returns:
        dict with 'summary', 'url' and 'title', or None if the page does not
        exist or the lookup raises (the error is printed).
    """
    try:
        page = wiki.page(keyphrase)
        if not page.exists():
            return None

        full_summary = page.summary
        clipped = full_summary[:max_chars]
        if len(full_summary) > max_chars:
            cut = clipped.rfind('.')
            if cut > max_chars // 2:
                clipped = clipped[:cut + 1]

        return {'summary': clipped, 'url': page.fullurl, 'title': page.title}
    except Exception as e:
        print(f"Wikipedia error for '{keyphrase}': {e}")

    return None

In [None]:
print("Testing Wikipedia Fallback...")

wiki_result = retrieve_context_wikipedia("machine learning")
if wiki_result is not None:
    title, url, summary = wiki_result['title'], wiki_result['url'], wiki_result['summary']
    print(f"\nTitle: {title}")
    print(f"URL: {url}")
    print(f"Summary: {summary[:200]}...")

## Combined Retrieval Service

In [None]:
class RetrievalResult(BaseModel):
    """One retrieved context snippet for a single keyphrase."""

    # The keyphrase as passed to the retriever.
    keyphrase: str
    # The context text retrieved for that keyphrase.
    context: str
    # Origin of the context; ContextRetriever uses 'gemini' or 'wikipedia'.
    source: str
    # Source page URL; ContextRetriever only sets it for Wikipedia results.
    url: Optional[str] = None


class ContextRetriever:
    """
    Context retrieval service.

    Asks Gemini for short explanations of all keyphrases in one request, then
    falls back to Wikipedia page summaries for any keyphrase Gemini did not cover.
    """

    def __init__(self, api_key=None, max_summary_chars=500):
        """
        Args:
            api_key: If given, passed to ``genai.configure``. Note this is a
                module-level call on ``genai``, not per-instance configuration.
            max_summary_chars: Maximum length of Wikipedia fallback summaries.
        """
        if api_key:
            genai.configure(api_key=api_key)

        self.model = genai.GenerativeModel('gemini-1.5-flash')
        # Contact info in the User-Agent, matching the notebook-level client.
        self.wiki = wikipediaapi.Wikipedia(
            user_agent='SocraticPath/1.0 (https://github.com/socraticpath)',
            language='en'
        )
        # Same sampling settings as the notebook-level generation_config and
        # the values written to retrieval_config/config.json.
        self.generation_config = genai.GenerationConfig(
            temperature=0.3,
            max_output_tokens=500,
            top_p=0.8,
            top_k=40
        )
        self.max_summary_chars = max_summary_chars

    def retrieve(self, keyphrases, use_fallback=True):
        """
        Retrieve context for a list of keyphrases.

        Gemini is queried once for all keyphrases. Each keyphrase is then matched
        to one of the returned topics (see ``_match_key``). Keyphrases without
        usable Gemini context are looked up on Wikipedia when ``use_fallback``
        is set. Keyphrases found in neither source are omitted.

        Args:
            keyphrases: List of keyphrase strings
            use_fallback: Whether to use Wikipedia if Gemini fails

        Returns:
            List of RetrievalResult objects, in keyphrase order
        """
        if not keyphrases:
            # Avoid sending an empty "Topics:" prompt to the API.
            return []

        gemini_context = self._retrieve_gemini(keyphrases)

        results = []
        for kp in keyphrases:
            matched_key = self._match_key(kp, gemini_context)
            gemini_text = gemini_context.get(matched_key) if matched_key is not None else None

            if gemini_text:
                results.append(RetrievalResult(
                    keyphrase=kp,
                    context=gemini_text,
                    source='gemini'
                ))
            elif use_fallback:
                wiki_result = self._retrieve_wikipedia(kp)
                if wiki_result:
                    results.append(RetrievalResult(
                        keyphrase=kp,
                        context=wiki_result['summary'],
                        source='wikipedia',
                        url=wiki_result['url']
                    ))

        return results

    @staticmethod
    def _match_key(keyphrase, topics):
        """
        Return the topic key that best matches ``keyphrase``, or None.

        An exact case-insensitive match wins. Otherwise the first topic that
        contains, or is contained in, the keyphrase is used. This stops e.g.
        "learning" from grabbing "Deep Learning" when "Learning" is also present.
        """
        target = keyphrase.strip().lower()
        normalized = [(key, key.strip().lower()) for key in topics]
        for key, norm in normalized:
            if norm == target:
                return key
        for key, norm in normalized:
            if norm and (target in norm or norm in target):
                return key
        return None

    def _retrieve_gemini(self, keyphrases):
        """Ask Gemini for context on all keyphrases in one request; {} on any error."""
        prompt = f"""Provide brief, factual educational context for these topics.
For each, give 2-3 objective sentences.

Topics: {', '.join(keyphrases)}

Format:
TOPIC: [name]
CONTEXT: [explanation]"""

        try:
            response = self.model.generate_content(
                prompt,
                generation_config=self.generation_config
            )
            return parse_context_response(response.text, keyphrases)
        except Exception as e:
            print(f"Gemini retrieval error: {e}")
            return {}

    @staticmethod
    def _truncate_summary(text, max_chars):
        """
        Clip ``text`` to ``max_chars``, preferring to end on a sentence.

        Text already within the limit is returned unchanged. Otherwise, if a
        period falls beyond the halfway point of the clipped text, the text is
        cut back to end on it.
        """
        if len(text) <= max_chars:
            return text
        clipped = text[:max_chars]
        last_period = clipped.rfind('.')
        if last_period > max_chars // 2:
            clipped = clipped[:last_period + 1]
        return clipped

    def _retrieve_wikipedia(self, keyphrase):
        """Look up ``keyphrase`` on Wikipedia; return a summary/url/title dict or None."""
        try:
            page = self.wiki.page(keyphrase)
            if page.exists():
                return {
                    'summary': self._truncate_summary(page.summary, self.max_summary_chars),
                    'url': page.fullurl,
                    'title': page.title,
                }
        except Exception as e:
            # Best-effort fallback: report the failure instead of hiding it,
            # but keep going so other keyphrases can still be retrieved.
            print(f"Wikipedia retrieval error for '{keyphrase}': {e}")
        return None

    def retrieve_single(self, keyphrase):
        """Convenience method for single keyphrase retrieval."""
        results = self.retrieve([keyphrase])
        return results[0] if results else None

In [None]:
retriever = ContextRetriever()
test_phrases = ["artificial intelligence", "neural networks", "deep learning"]

print("Testing Combined Retriever...")
results = retriever.retrieve(test_phrases)

# One block per result: source tag, keyphrase, clipped context, optional URL.
print("\nRetrieval Results:")
for item in results:
    print(f"\n[{item.source.upper()}] {item.keyphrase}")
    print(f"  {item.context[:150]}...")
    if item.url:
        print(f"  Source: {item.url}")

## Enhanced Context for Socratic Questions

In [None]:
def enhance_socratic_context(user_input, keyphrases, retriever):
    """
    Gather retrieved context for ``keyphrases`` and package it for Socratic
    question generation.

    Args:
        user_input: The user's original statement; passed through unchanged.
        keyphrases: Keyphrases to look up.
        retriever: Any object whose ``retrieve(keyphrases)`` returns items
            exposing ``keyphrase``, ``context``, ``source`` and ``url``.

    Returns:
        dict with 'user_input', 'keyphrases', 'combined_context' (one
        "keyphrase: context" paragraph per result, separated by blank lines),
        'sources' (one dict per result) and 'retrieval_results' (the object
        returned by ``retriever.retrieve``).
    """
    retrieval_results = retriever.retrieve(keyphrases)

    paragraphs = []
    source_records = []
    for item in retrieval_results:
        paragraphs.append(f"{item.keyphrase}: {item.context}")
        source_records.append(dict(keyphrase=item.keyphrase, source=item.source, url=item.url))

    return dict(
        user_input=user_input,
        keyphrases=keyphrases,
        combined_context="\n\n".join(paragraphs),
        sources=source_records,
        retrieval_results=retrieval_results,
    )

In [None]:
user_input = "I think vaccines are dangerous and the government is hiding the truth about their side effects."
keyphrases = ["vaccines", "vaccine safety", "clinical trials"]

enhanced = enhance_socratic_context(user_input, keyphrases, retriever)
combined_context = enhanced['combined_context']
source_list = enhanced['sources']

print("Enhanced Context for Socratic Question Generation:")
print("=" * 60)
print(f"\nUser Input: {enhanced['user_input']}")
print(f"\nKeyphrases: {enhanced['keyphrases']}")
print(f"\nCombined Context ({len(combined_context)} chars):")
print(combined_context[:500])
print(f"\nSources: {len(source_list)} items")
for src in source_list:
    print(f"  - {src['keyphrase']} ({src['source']})")

## Rate Limiting and Caching Considerations

In [None]:
from functools import lru_cache
import hashlib

class CachedContextRetriever(ContextRetriever):
    """
    Context retriever with a bounded in-memory cache and simple rate limiting.

    - Cache entries are keyed on the exact, ordered keyphrases plus the
      ``use_fallback`` flag. A hit therefore returns results in the order and
      casing the caller asked for, and no-fallback results are never served
      to a fallback request.
    - Eviction is least-recently-used.
    - Empty result lists are not cached, so a transient API failure is retried
      on the next call instead of being remembered.
    - Uncached requests are spaced at least ``min_request_interval`` seconds
      apart, measured from the end of the previous request.
    """

    def __init__(self, api_key=None, cache_size=100, min_request_interval=0.5):
        super().__init__(api_key)
        self._cache = {}
        # cache_size <= 0 disables caching (previously it crashed on eviction).
        self._cache_size = cache_size
        self._last_request_time = 0
        self._min_request_interval = min_request_interval

    def _get_cache_key(self, keyphrases, use_fallback=True):
        """Generate a cache key from the ordered keyphrases and the fallback flag."""
        payload = repr((tuple(keyphrases), bool(use_fallback)))
        return hashlib.md5(payload.encode()).hexdigest()

    def retrieve(self, keyphrases, use_fallback=True):
        """Retrieve with caching and rate limiting; same contract as ContextRetriever.retrieve."""
        cache_key = self._get_cache_key(keyphrases, use_fallback)

        if cache_key in self._cache:
            # Re-insert to mark as most recently used (dicts keep insertion order).
            cached = self._cache.pop(cache_key)
            self._cache[cache_key] = cached
            # Return a copy so callers cannot mutate the cached list.
            return list(cached)

        elapsed = time.time() - self._last_request_time
        if elapsed < self._min_request_interval:
            time.sleep(self._min_request_interval - elapsed)

        results = super().retrieve(keyphrases, use_fallback)

        self._last_request_time = time.time()

        if results and self._cache_size > 0:
            # Evict least-recently-used entries (front of the dict) until there is room.
            while len(self._cache) >= self._cache_size:
                del self._cache[next(iter(self._cache))]
            self._cache[cache_key] = list(results)

        return results

In [None]:
cached_retriever = CachedContextRetriever()


def _time_cached_lookup(heading, time_format):
    """Print ``heading``, run one lookup for "quantum computing", and print its wall time."""
    print(heading)
    started = time.time()
    lookup = cached_retriever.retrieve(["quantum computing"])
    print(f"  Time: {time.time() - started:{time_format}}s")
    return lookup


results1 = _time_cached_lookup("First request (not cached):", ".2f")
results2 = _time_cached_lookup("\nSecond request (cached):", ".4f")

## Save Configuration

In [None]:
OUTPUT_PATH = Path("../models/retrieval_config")
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
config_file = OUTPUT_PATH / "config.json"

# Snapshot of the retrieval settings used in this notebook.
config = {
    "gemini_model": "gemini-1.5-flash",
    "generation_config": {
        "temperature": 0.3,
        "max_output_tokens": 500,
        "top_p": 0.8,
        "top_k": 40,
    },
    "wikipedia_config": {
        "language": "en",
        "max_summary_chars": 500,
    },
    "cache_config": {
        "enabled": True,
        "max_size": 100,
        "min_request_interval_seconds": 0.5,
    },
}

config_file.write_text(json.dumps(config, indent=2))

print(f"Configuration saved to {config_file}")

## Summary

**Retrieval Architecture:**

1. **Primary**: Gemini 1.5 Flash (fast, cost-effective)
2. **Fallback**: Wikipedia API (reliable, free, includes source URLs)

**Key Classes:**
- `ContextRetriever`: Basic retrieval with fallback
- `CachedContextRetriever`: `ContextRetriever` plus in-memory caching and rate limiting

**API Usage:**
```python
retriever = CachedContextRetriever()
results = retriever.retrieve(['topic1', 'topic2'])
```

---

**Next Step**: Proceed to `07_inference_pipeline.ipynb` to build the complete inference pipeline.