In [1]:
import requests
import re
import time
import json
from bs4 import BeautifulSoup
from typing import List, Dict, Tuple

In [2]:
class LongContextTest:
    def __init__(self, wikipedia_url: str, ollama_url: str = "http://localhost:11434"):
        self.wikipedia_url = wikipedia_url
        self.ollama_url = ollama_url
        self.base_text = self._extract_wikipedia_text()
        self.sentences = self._split_into_sentences(self.base_text)
        
    def _extract_wikipedia_text(self) -> str:
        try:
            print(f"Fetching Wikipedia page: {self.wikipedia_url}")
            response = requests.get(self.wikipedia_url)
            response.raise_for_status()
            
            soup = BeautifulSoup(response.content, 'html.parser')
            
            # Remove unwanted elements
            for element in soup(['script', 'style', 'sup', 'table', 'div.navbox', 'div.infobox']):
                element.decompose()
            
            # Extract main content (usually in div with id="mw-content-text")
            content = soup.find('div', {'id': 'mw-content-text'})
            if not content:
                content = soup.find('div', {'class': 'mw-parser-output'})
            if not content:
                content = soup
            
            # Get all paragraphs
            paragraphs = content.find_all('p')
            text = ' '.join([p.get_text().strip() for p in paragraphs if p.get_text().strip()])
            
            # Clean up text
            text = re.sub(r'\[.*?\]', '', text)  # Remove citation brackets
            text = re.sub(r'\s+', ' ', text)     # Normalize whitespace
            
            print(f"Extracted {len(text)} characters from Wikipedia")
            return text
            
        except Exception as e:
            print(f"Error fetching Wikipedia page: {e}")
            return ""
        
    def _split_into_sentences(self, text: str) -> List[str]:
        sentences = re.split(r'[.!?]\s+', text)
        return [s.strip() + '.' for s in sentences if len(s.strip()) > 10]
    
    def _approximate_tokens(self, text: str) -> int:
        return len(text) // 4
    
    def _get_text_chunk(self, target_tokens: int) -> str:
        target_chars = target_tokens * 4
        current_chars = 0
        result = []
        
        for sentence in self.sentences:
            if current_chars + len(sentence) > target_chars and result:
                break
            result.append(sentence)
            current_chars += len(sentence)
        
        return ' '.join(result)
    
    def create_test_questions(self) -> List[Dict]:
        questions = [
            {
                "question": "How many daily passengers use the rail system in Greater Tokyo?",
                "needle": "40 million passengers (counted twice if transferring between operators) use the rail system daily (14.6 billion annually)",
                "expected_answer": "40 million"
            },
            {
                "question": "What is the busiest train station in the world by passenger throughput?",
                "needle": "Shinjuku Station is the busiest train station in the world by passenger throughput.",
                "expected_answer": "Shinjuku Station"
            },
            {
                "question": "How many rail lines does JR East operate within Greater Tokyo?",
                "needle": "In total, JR alone operates 23 lines within the Greater Tokyo area.",
                "expected_answer": "23 lines"
            }
        ]
        return questions
    
    def call_ollama(self, prompt: str, model: str = "llama3") -> str:
        try:
            response = requests.post(
                f"{self.ollama_url}/api/generate",
                json={
                    "model": model,
                    "prompt": prompt,
                    "stream": False,
                    "options": {"temperature": 0.1}  # Low temperature for consistency
                },
                timeout=60
            )
            return response.json().get("response", "")
        except Exception as e:
            return f"Error: {str(e)}"
    
    def create_context(self, needle: str, context_tokens: int) -> str:
        if context_tokens <= 400:  # Just needle + 200 before/after
            before_tokens = after_tokens = 200
        else:
            # Distribute extra tokens
            extra_tokens = context_tokens - 400
            before_tokens = 200 + extra_tokens // 2
            after_tokens = 200 + extra_tokens - extra_tokens // 2
        
        before_text = self._get_text_chunk(before_tokens)
        after_text = self._get_text_chunk(after_tokens)
        
        context = f"{before_text} {needle} {after_text}"
        return context
    
    def test_single_question(self, question: Dict, context_tokens: int, trials: int = 10) -> Dict:
        context = self.create_context(question["needle"], context_tokens)
        
        prompt = f"""Based on the following text about Tokyo's transport system, answer the question briefly and accurately.

                    Text: {context}

                    Question: {question["question"]}

                    Answer:"""
        
        correct_answers = 0
        responses = []
        
        for i in range(trials):
            response = self.call_ollama(prompt)
            responses.append(response)
            
            # Simple check if expected answer is in response
            if question["expected_answer"].lower() in response.lower():
                correct_answers += 1
            
            time.sleep(0.5)  # Small delay between requests
        
        accuracy = correct_answers / trials
        actual_tokens = self._approximate_tokens(context)
        
        return {
            "context_tokens": actual_tokens,
            "target_tokens": context_tokens,
            "accuracy": accuracy,
            "correct_answers": correct_answers,
            "total_trials": trials,
            "sample_responses": responses[:3]  # Keep first 3 responses as samples
        }
    
    def run_full_test(self, max_context_tokens: int = 8000, step_size: int = 400) -> Dict:
        questions = self.create_test_questions()
        results = {}
        
        print(f"Running long-context efficacy test...")
        print(f"Questions: {len(questions)}")
        print(f"Max context: {max_context_tokens} tokens")
        print(f"Step size: {step_size} tokens")
        print("-" * 50)
        
        for i, question in enumerate(questions):
            print(f"\nTesting Question {i+1}: {question['question'][:60]}...")
            question_results = []
            
            context_sizes = range(400, max_context_tokens + 1, step_size)
            
            for context_tokens in context_sizes:
                print(f"  Context size: {context_tokens} tokens", end=" ")
                
                result = self.test_single_question(question, context_tokens)
                question_results.append(result)
                
                print(f"- Accuracy: {result['accuracy']:.1%}")
                
                # Early stopping if we can't fit more context
                if result["context_tokens"] >= max_context_tokens * 0.9:
                    break
            
            results[f"question_{i+1}"] = {
                "question": question["question"],
                "expected_answer": question["expected_answer"],
                "results": question_results
            }
        
        return results
    
    def print_summary(self, results: Dict):
        print("\n" + "="*60)
        print("LONG CONTEXT EFFICACY TEST RESULTS")
        print("="*60)
        
        for q_key, q_data in results.items():
            print(f"\n{q_key.upper()}: {q_data['question']}")
            print(f"Expected: {q_data['expected_answer']}")
            print("-" * 40)
            
            for result in q_data["results"]:
                tokens = result["context_tokens"]
                accuracy = result["accuracy"]
                print(f"  {tokens:5d} tokens: {accuracy:5.1%} accuracy")
        
        # Overall accuracy trend
        print("\n" + "="*40)
        print("ACCURACY TRENDS")
        print("="*40)
        
        # Calculate average accuracy at each context length
        context_sizes = set()
        for q_data in results.values():
            for result in q_data["results"]:
                context_sizes.add(result["context_tokens"])
        
        for size in sorted(context_sizes):
            accuracies = []
            for q_data in results.values():
                for result in q_data["results"]:
                    if result["context_tokens"] == size:
                        accuracies.append(result["accuracy"])
            
            if accuracies:
                avg_accuracy = sum(accuracies) / len(accuracies)
                print(f"  {size:5d} tokens: {avg_accuracy:5.1%} average accuracy")


In [3]:
def main():
    # Wikipedia URL for Tokyo transport
    wikipedia_url = "https://en.wikipedia.org/wiki/Transport_in_Greater_Tokyo"
    
    # Initialize and run test
    test = LongContextTest(wikipedia_url)
    
    if not test.base_text:
        print("Failed to extract text from Wikipedia. Exiting.")
        return
    
    print(f"Successfully loaded {len(test.base_text)} characters of text")
    
    results = test.run_full_test(max_context_tokens=4000, step_size=400)
    results = test.run_full_test(max_context_tokens=4000, step_size=400)
    
    # Print results
    test.print_summary(results)
    
    # Save results to file
    with open("long_context_test_results.json", "w") as f:
        json.dump(results, f, indent=2)
    
    print(f"\nResults saved to: long_context_test_results.json")


In [4]:
if __name__ == "__main__":
    main()

Fetching Wikipedia page: https://en.wikipedia.org/wiki/Transport_in_Greater_Tokyo
Extracted 13472 characters from Wikipedia
Successfully loaded 13472 characters of text
Running long-context efficacy test...
Questions: 3
Max context: 4000 tokens
Step size: 400 tokens
--------------------------------------------------

Testing Question 1: How many daily passengers use the rail system in Greater Tok...
  Context size: 400 tokens - Accuracy: 100.0%
  Context size: 800 tokens - Accuracy: 100.0%
  Context size: 1200 tokens - Accuracy: 100.0%
  Context size: 1600 tokens - Accuracy: 100.0%
  Context size: 2000 tokens - Accuracy: 100.0%
  Context size: 2400 tokens - Accuracy: 100.0%
  Context size: 2800 tokens - Accuracy: 100.0%
  Context size: 3200 tokens - Accuracy: 100.0%
  Context size: 3600 tokens - Accuracy: 100.0%
  Context size: 4000 tokens - Accuracy: 100.0%

Testing Question 2: What is the busiest train station in the world by passenger ...
  Context size: 400 tokens - Accuracy: 100.0