# Lab 3.4.6: Reasoning Pipeline

**Module:** 3.4 - Test-Time Compute & Reasoning  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐⭐ (Advanced)

---

## Learning Objectives

By the end of this lab, you will:
- [ ] Build an intelligent query router that detects complexity
- [ ] Route simple queries to fast models, complex ones to reasoning models
- [ ] Implement caching for repeated reasoning patterns
- [ ] Measure overall latency and quality improvements
- [ ] Deploy a production-ready adaptive reasoning system

---

## Prerequisites

- Completed Labs 3.4.1-3.4.5
- Multiple models available in Ollama (fast + reasoning)
- Understanding of CoT, self-consistency, and reward models

---

## Real-World Context

In production, you can't use a reasoning model for every query - it's too slow and expensive. The solution: **intelligent routing**.

**The Problem:**
- Simple queries ("What's the capital of France?") don't need R1
- Complex queries ("Solve this optimization problem") do
- Using R1 for everything: slow, expensive
- Using fast model for everything: inaccurate on hard problems

**The Solution:** Build a pipeline that:
1. Classifies query complexity
2. Routes to appropriate model
3. Caches results for efficiency

**Industry Examples:**
- **ChatGPT:** GPT-5 uses a router that picks between a fast model and a deeper reasoning model for each request
- **Claude:** Extended thinking lets the caller set how much reasoning budget a request gets
- **Cursor:** Routes coding queries to specialized models

---

## ELI5: Adaptive Reasoning Pipeline

> **Imagine a hospital emergency room...**
>
> Not every patient needs to see the top specialist!
>
> **Triage Nurse (Router):** Quickly assesses each patient
> - Minor scrape? → Nurse can handle it (fast model)
> - Complex symptoms? → See the specialist (reasoning model)
>
> **Our AI Pipeline:**
> - Simple question? → Fast 8B model (milliseconds)
> - Complex reasoning? → R1 70B model (seconds, but accurate)
>
> **Bonus:** If someone asks the same question twice,
> we just look up the previous answer (cache)!

```
Query ─> [Complexity Classifier] ─┬─> Simple ─> Fast Model (8B) ─┐
                                  │                              ├─> Response
                                  └─> Complex ─> R1 Model (70B) ─┘
                                        ↑
                                    [Cache: Skip if seen before]
```

---

## Part 1: Setup

In [None]:
import json
import time
import hashlib
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Callable
from dataclasses import dataclass, field
from collections import OrderedDict
from enum import Enum

import ollama

# List available models
models = ollama.list()
# Newer ollama-python releases expose the tag as 'model'; older ones used 'name'
model_names = [m.get('model') or m.get('name', 'unknown') for m in models.get('models', [])]

print("Available models:")
for name in model_names:
    print(f"  - {name}")

In [None]:
# Configure models
# Fast model: smaller, quicker responses
FAST_MODEL = None
for name in model_names:
    if any(x in name.lower() for x in ['7b', '8b', '3b']) and 'r1' not in name.lower():
        FAST_MODEL = name
        break

# Reasoning model: larger, more capable
REASONING_MODEL = None
for name in model_names:
    if 'r1' in name.lower():
        REASONING_MODEL = name
        break

# Fallback: use the first available model tagged 70b, 32b, or 14b as reasoning
if not REASONING_MODEL:
    for name in model_names:
        if any(x in name.lower() for x in ['70b', '32b', '14b']):
            REASONING_MODEL = name
            break

# If still not found, use same model for both
if not FAST_MODEL:
    FAST_MODEL = model_names[0] if model_names else "qwen3:8b"
if not REASONING_MODEL:
    REASONING_MODEL = FAST_MODEL

print(f"Fast model: {FAST_MODEL}")
print(f"Reasoning model: {REASONING_MODEL}")

# Note if using same model
if FAST_MODEL == REASONING_MODEL:
    print("\nNote: Using same model for both. The pipeline will still demonstrate routing.")
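
The substring checks above are brittle: `'3b'` also matches a tag such as `llama2:13b`, which would then be picked as the fast model. A minimal sketch of a stricter alternative, assuming Ollama-style tags where the size appears as a number followed by `b`. The helper names `param_billions` and `pick_models` and the 9B cutoff are hypothetical and not part of the lab:

```python
import re
from typing import List, Optional, Tuple

def param_billions(tag: str) -> Optional[float]:
    """Extract the parameter count in billions from a tag like 'qwen3:8b'."""
    match = re.search(r'(?<![\d.])(\d+(?:\.\d+)?)b\b', tag.lower())
    return float(match.group(1)) if match else None

def pick_models(tags: List[str], fast_max_b: float = 9.0) -> Tuple[Optional[str], Optional[str]]:
    """Return (fast, reasoning) model tags, or None where nothing qualifies."""
    sized = [(param_billions(t), t) for t in tags]
    sized = [(s, t) for s, t in sized if s is not None]
    # Fast: smallest non-R1 model at or below the cutoff
    fast_pool = [(s, t) for s, t in sized if s <= fast_max_b and 'r1' not in t.lower()]
    # Reasoning: largest R1 model, else the largest model of any kind
    r1_pool = [(s, t) for s, t in sized if 'r1' in t.lower()]
    reasoning_pool = r1_pool or sized
    fast = min(fast_pool)[1] if fast_pool else None
    reasoning = max(reasoning_pool)[1] if reasoning_pool else None
    return fast, reasoning

print(pick_models(['llama2:13b', 'qwen3:8b', 'deepseek-r1:70b', 'mistral:latest']))
```

Tags without a size suffix (such as `mistral:latest`) are ignored by this sketch, so the notebook's final fallbacks would still be needed.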

---

## Part 2: Query Complexity Classifier

The first component: detect whether a query needs reasoning or not.

In [None]:
class QueryComplexity(Enum):
    """Query complexity levels."""
    SIMPLE = "simple"      # Factual, direct answer
    MODERATE = "moderate"  # Some reasoning needed
    COMPLEX = "complex"    # Multi-step reasoning required


class ComplexityClassifier:
    """
    Classify query complexity using heuristics and patterns.
    
    This uses rule-based classification. For production, you could
    use a small trained classifier or an LLM call.
    """
    
    # Keywords indicating complexity
    COMPLEX_KEYWORDS = [
        'solve', 'calculate', 'prove', 'derive', 'analyze',
        'step by step', 'explain how', 'why does', 'compare and contrast',
        'what if', 'optimize', 'debug', 'implement', 'algorithm',
        'mathematical', 'equation', 'probability', 'logic puzzle',
    ]
    
    SIMPLE_KEYWORDS = [
        'what is', 'who is', 'when did', 'where is', 'define',
        'capital of', 'how many', 'true or false', 'yes or no',
    ]
    
    # Patterns indicating math problems
    MATH_PATTERNS = [
        r'\d+\s*[+\-*/]\s*\d+',  # Basic arithmetic
        r'\d+%\s*of',            # Percentage
        r'equation|formula',      # Math terms
        r'\$\d+',                # Money calculations
    ]
    
    def __init__(self):
        self.call_count = 0
        self.classifications = []
    
    def classify(self, query: str) -> QueryComplexity:
        """
        Classify a query's complexity.
        
        Returns:
            QueryComplexity enum value
        """
        self.call_count += 1
        query_lower = query.lower()
        
        # Score based on indicators
        complexity_score = 0
        
        # Check complex keywords
        for keyword in self.COMPLEX_KEYWORDS:
            if keyword in query_lower:
                complexity_score += 2
        
        # Check simple keywords (reduce score)
        for keyword in self.SIMPLE_KEYWORDS:
            if keyword in query_lower:
                complexity_score -= 1
        
        # Check math patterns
        for pattern in self.MATH_PATTERNS:
            if re.search(pattern, query_lower):
                complexity_score += 2
        
        # Length heuristic (longer = more complex)
        if len(query) > 200:
            complexity_score += 1
        if len(query) > 500:
            complexity_score += 1
        
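        # Multiple sentences often mean complex.
        # Count non-empty segments so trailing punctuation does not add a phantom
        # sentence (decimals such as "1.5" still split, so this stays approximate).
        sentences = [s for s in re.split(r'[.!?]+', query) if s.strip()]
        sentence_count = len(sentences)
        if sentence_count > 3:
            complexity_score += 1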
        
        # Classify based on score
        if complexity_score <= 0:
            result = QueryComplexity.SIMPLE
        elif complexity_score <= 3:
            result = QueryComplexity.MODERATE
        else:
            result = QueryComplexity.COMPLEX
        
        self.classifications.append((query[:50], result))
        return result
    
    def get_stats(self) -> Dict:
        """Get classification statistics."""
        counts = {c: 0 for c in QueryComplexity}
        for _, complexity in self.classifications:
            counts[complexity] += 1
        return {
            'total': self.call_count,
            'by_complexity': {c.value: count for c, count in counts.items()},
        }

In [None]:
# Test the classifier
classifier = ComplexityClassifier()

test_queries = [
    "What is the capital of France?",
    "Explain how photosynthesis works.",
    "Solve this equation step by step: 3x + 7 = 22",
    "Who wrote Romeo and Juliet?",
    "A train leaves at 9am traveling at 60mph. Another train leaves at 10am at 80mph. The stations are 280 miles apart. When do they meet?",
    "Is Python a programming language?",
    "Implement a function to find the longest palindromic substring in O(n) time complexity.",
]

print("Query Complexity Classification:")
print("="*70)
for query in test_queries:
    complexity = classifier.classify(query)
    print(f"[{complexity.value.upper():8}] {query[:60]}...")
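
When a query lands in an unexpected bucket, it helps to see which signals fired. The sketch below re-implements only the keyword and pattern parts of the scoring (it skips the length and sentence-count heuristics) with abbreviated copies of the lists above; `explain_score` is an illustrative helper, not part of `ComplexityClassifier`:

```python
import re
from typing import List, Tuple

# Abbreviated copies of the classifier's signal lists, for illustration only
COMPLEX_KEYWORDS = ['solve', 'calculate', 'step by step', 'explain how', 'equation']
SIMPLE_KEYWORDS = ['what is', 'who is', 'capital of', 'how many']
MATH_PATTERNS = [r'\d+\s*[+\-*/]\s*\d+', r'\d+%\s*of', r'equation|formula', r'\$\d+']

def explain_score(query: str) -> Tuple[int, List[Tuple[str, int]]]:
    """Return the total score and the (signal, delta) pairs that produced it."""
    q = query.lower()
    fired = [(f"complex keyword '{k}'", 2) for k in COMPLEX_KEYWORDS if k in q]
    fired += [(f"simple keyword '{k}'", -1) for k in SIMPLE_KEYWORDS if k in q]
    fired += [(f"math pattern {p}", 2) for p in MATH_PATTERNS if re.search(p, q)]
    return sum(delta for _, delta in fired), fired

score, signals = explain_score("Solve this equation step by step: 3x + 7 = 22")
for name, delta in signals:
    print(f"{delta:+d}  {name}")
print("total:", score)
```

For this query the total is 8. The word "equation" is counted twice, once as a keyword and once as a math pattern, while `3x + 7` does not match the arithmetic pattern because of the `x`.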

---

## Part 3: Response Cache

Cache reasoning results to avoid recomputation.

In [None]:
class LRUCache:
    """
    Least Recently Used (LRU) cache for responses.
    
    Caches query-response pairs. When full, evicts least recently used entries.
    """
    
    def __init__(self, max_size: int = 100):
        """
        Initialize cache.
        
        Args:
            max_size: Maximum number of entries to cache
        """
        self.max_size = max_size
        self.cache = OrderedDict()  # Ordered from least to most recently used
        self.hits = 0
        self.misses = 0
    
    def _make_key(self, query: str, model: str) -> str:
        """Create a cache key from query and model."""
        combined = f"{model}:{query}"
        return hashlib.md5(combined.encode()).hexdigest()
    
    def get(self, query: str, model: str) -> Optional[str]:
        """
        Get cached response if available.
        
        Returns:
            Cached response or None
        """
        key = self._make_key(query, model)
        
        if key in self.cache:
            # Move to end (most recently used)
            self.cache.move_to_end(key)
            self.hits += 1
            return self.cache[key]
        
        self.misses += 1
        return None
    
    def set(self, query: str, model: str, response: str):
        """
        Cache a response.
        """
        key = self._make_key(query, model)
        
        # If already exists, update and move to end
        if key in self.cache:
            self.cache.move_to_end(key)
            self.cache[key] = response
            return
        
        # If full, evict oldest
        if len(self.cache) >= self.max_size:
            self.cache.popitem(last=False)  # Remove oldest
        
        self.cache[key] = response
    
    def clear(self):
        """Clear the cache."""
        self.cache.clear()
        self.hits = 0
        self.misses = 0
    
    def get_stats(self) -> Dict:
        """Get cache statistics."""
        total = self.hits + self.misses
        hit_rate = self.hits / total if total > 0 else 0
        return {
            'size': len(self.cache),
            'max_size': self.max_size,
            'hits': self.hits,
            'misses': self.misses,
            'hit_rate': hit_rate,
        }


# Test cache
cache = LRUCache(max_size=10)

# Add some entries
cache.set("What is 2+2?", "fast_model", "The answer is 4.")
cache.set("Explain gravity.", "reasoning_model", "Gravity is...")

# Test retrieval
print("Cache test:")
print(f"  Hit: {cache.get('What is 2+2?', 'fast_model')[:30]}...")
print(f"  Miss: {cache.get('Unknown query', 'fast_model')}")
print(f"  Stats: {cache.get_stats()}")
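
Exact-match keys miss queries that differ only in case or spacing. A small sketch of normalizing the query before hashing, as one possible refinement of `_make_key`. The names `normalize_query` and `make_key` are illustrative, and case-folding is an assumption that would be wrong for case-sensitive input such as code snippets:

```python
import hashlib
import re

def normalize_query(query: str) -> str:
    """Lowercase, trim, and collapse runs of whitespace."""
    return re.sub(r'\s+', ' ', query.strip().lower())

def make_key(query: str, model: str) -> str:
    """Hash the model name together with the normalized query."""
    combined = f"{model}:{normalize_query(query)}"
    return hashlib.md5(combined.encode()).hexdigest()

# Same question with different spacing and case maps to the same key
print(make_key("What is 2+2?", "fast") == make_key("  what is   2+2? ", "fast"))
# The model name is still part of the key
print(make_key("What is 2+2?", "fast") == make_key("What is 2+2?", "reasoning"))
```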

---

## Part 4: Building the Reasoning Pipeline

In [None]:
@dataclass
class PipelineResponse:
    """Response from the reasoning pipeline."""
    response: str
    model_used: str
    complexity: QueryComplexity
    from_cache: bool
    latency: float
    thinking_tokens: int = 0
    
    def __repr__(self):
        cache_str = "[CACHED]" if self.from_cache else ""
        return f"PipelineResponse({self.complexity.value}, {self.model_used}, {self.latency:.2f}s) {cache_str}"


class ReasoningPipeline:
    """
    Adaptive reasoning pipeline that routes queries intelligently.
    
    Features:
    - Complexity-based routing
    - Response caching
    - Performance tracking
    
    Optimized for DGX Spark: can run fast model (8B) and reasoning
    model (70B) both in memory thanks to 128GB unified memory.
    """
    
    def __init__(
        self,
        fast_model: str = FAST_MODEL,
        reasoning_model: str = REASONING_MODEL,
        cache_size: int = 100,
    ):
        """
        Initialize the pipeline.
        
        Args:
            fast_model: Model for simple queries
            reasoning_model: Model for complex queries
            cache_size: Maximum cache entries
        """
        self.fast_model = fast_model
        self.reasoning_model = reasoning_model
        
        self.classifier = ComplexityClassifier()
        self.cache = LRUCache(max_size=cache_size)
        
        # Performance tracking
        self.total_queries = 0
        self.fast_model_calls = 0
        self.reasoning_model_calls = 0
        self.total_latency = 0.0
        self.saved_latency = 0.0  # From cache hits
    
    def query(
        self,
        query: str,
        use_cache: bool = True,
        force_reasoning: bool = False,
        verbose: bool = False,
    ) -> PipelineResponse:
        """
        Process a query through the pipeline.
        
        Args:
            query: The user query
            use_cache: Whether to use caching
            force_reasoning: Force use of reasoning model
            verbose: Print detailed info
        
        Returns:
            PipelineResponse with result and metadata
        """
        self.total_queries += 1
        start_time = time.time()
        
        # Step 1: Classify complexity
        complexity = self.classifier.classify(query)
        if verbose:
            print(f"Complexity: {complexity.value}")
        
        # Step 2: Choose model
        if force_reasoning or complexity == QueryComplexity.COMPLEX:
            model = self.reasoning_model
            use_cot = True
        elif complexity == QueryComplexity.MODERATE:
            # Use fast model with CoT for moderate
            model = self.fast_model
            use_cot = True
        else:
            model = self.fast_model
            use_cot = False
        
        if verbose:
            print(f"Model selected: {model}")
        
        # Step 3: Check cache
        if use_cache:
            cached = self.cache.get(query, model)
            if cached is not None:
                latency = time.time() - start_time
                self.saved_latency += 2.0  # Estimate of saved time
                
                if verbose:
                    print("Cache hit!")
                
                return PipelineResponse(
                    response=cached,
                    model_used=model,
                    complexity=complexity,
                    from_cache=True,
                    latency=latency,
                )
        
        # Step 4: Generate response
        if model == self.reasoning_model:
            self.reasoning_model_calls += 1
        else:
            self.fast_model_calls += 1
        
        # Build prompt
        if use_cot:
            prompt = f"{query}\n\nLet's think step by step:"
        else:
            prompt = query
        
        response = ollama.chat(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            options={"temperature": 0.0, "num_predict": 1024}
        )
        
        response_text = response['message']['content']
        latency = time.time() - start_time
        self.total_latency += latency
        
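        # Approximate thinking tokens (for R1) as whitespace-separated words inside <think> blocks
        think_blocks = re.findall(r'<think>(.*?)</think>', response_text, re.DOTALL)
        thinking_tokens = sum(len(block.split()) for block in think_blocks)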
        
        # Step 5: Cache the response
        if use_cache:
            self.cache.set(query, model, response_text)
        
        if verbose:
            print(f"Response generated in {latency:.2f}s")
        
        return PipelineResponse(
            response=response_text,
            model_used=model,
            complexity=complexity,
            from_cache=False,
            latency=latency,
            thinking_tokens=thinking_tokens,
        )
    
    def get_stats(self) -> Dict:
        """Get pipeline statistics."""
        return {
            'total_queries': self.total_queries,
            'fast_model_calls': self.fast_model_calls,
            'reasoning_model_calls': self.reasoning_model_calls,
            'total_latency': self.total_latency,
            'avg_latency': self.total_latency / max(self.total_queries, 1),
            'saved_latency': self.saved_latency,
            'cache_stats': self.cache.get_stats(),
            'classifier_stats': self.classifier.get_stats(),
        }

In [None]:
# Create the pipeline
pipeline = ReasoningPipeline(
    fast_model=FAST_MODEL,
    reasoning_model=REASONING_MODEL,
    cache_size=50,
)

print("Pipeline created:")
print(f"  Fast model: {pipeline.fast_model}")
print(f"  Reasoning model: {pipeline.reasoning_model}")
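
The pipeline returns and caches the raw response text, which for R1-style models may include the reasoning wrapped in `<think>...</think>`. A sketch of separating the reasoning from the final answer, assuming that tag format; `split_thinking` is a hypothetical helper, not something the pipeline above calls:

```python
import re
from typing import Tuple

THINK_RE = re.compile(r'<think>(.*?)</think>', re.DOTALL)

def split_thinking(text: str) -> Tuple[str, str]:
    """Return (thinking, answer) from a response that may contain <think> blocks."""
    thinking = "\n".join(block.strip() for block in THINK_RE.findall(text))
    answer = THINK_RE.sub("", text).strip()
    return thinking, answer

raw = "<think>60 * 1.6 = 96, 96 / 1.5 = 64</think>\nThe speed is 64 km/h."
thinking, answer = split_thinking(raw)
print("Thinking:", thinking)
print("Answer:", answer)
```

Showing only `answer` to the user while logging `thinking` keeps responses short without losing the trace.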

---

## Part 5: Testing the Pipeline

In [None]:
# Test queries of varying complexity
test_queries = [
    # Simple (should use fast model)
    "What is the capital of Japan?",
    "Is Python a programming language?",
    
    # Moderate by intent (fast model with CoT); the keyword heuristic scores both as simple
    "Explain what machine learning is in simple terms.",
    "What are the main differences between Python and JavaScript?",
    
    # Complex by intent (reasoning model); the heuristic scores the bat-and-ball query as moderate
    "Solve step by step: If a train leaves at 9am traveling at 60mph, and another train leaves at 10am traveling at 80mph, when do they meet if they started 280 miles apart?",
    "A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost? Think carefully.",
]

print("Testing Pipeline\n" + "="*70)

for query in test_queries:
    print(f"\nQuery: {query[:60]}...")
    
    result = pipeline.query(query, verbose=True)
    
    print(f"  Model: {result.model_used}")
    print(f"  Complexity: {result.complexity.value}")
    print(f"  Latency: {result.latency:.2f}s")
    print(f"  Response: {result.response[:100]}...")

In [None]:
# Test caching by repeating a query
print("\nTesting Cache:")
print("="*50)

query = "What is the capital of Japan?"

# First call (cache miss): clear the cache first, because the previous cell already cached this query
pipeline.cache.clear()
result1 = pipeline.query(query)
print(f"First call: {result1.latency:.3f}s (from_cache: {result1.from_cache})")

# Second call (cache hit)
result2 = pipeline.query(query)
print(f"Second call: {result2.latency:.3f}s (from_cache: {result2.from_cache})")

speedup = result1.latency / result2.latency if result2.latency > 0 else 0
print(f"Speedup: {speedup:.0f}x faster with cache!")

---

## Part 6: Performance Analysis

In [None]:
def run_benchmark(
    pipeline: ReasoningPipeline,
    queries: List[str],
    repetitions: int = 2,
) -> Dict:
    """
    Run a benchmark on the pipeline.
    
    Runs each query multiple times to test caching.
    """
    results = {
        'queries': len(queries),
        'repetitions': repetitions,
        'total_calls': len(queries) * repetitions,
        'latencies': [],
        'by_complexity': {},
        'cache_hits': 0,
        'cache_misses': 0,
    }
    
    for rep in range(repetitions):
        for query in queries:
            result = pipeline.query(query)
            
            results['latencies'].append(result.latency)
            
            if result.from_cache:
                results['cache_hits'] += 1
            else:
                results['cache_misses'] += 1
            
            # Track by complexity
            comp = result.complexity.value
            if comp not in results['by_complexity']:
                results['by_complexity'][comp] = {'count': 0, 'total_latency': 0}
            results['by_complexity'][comp]['count'] += 1
            results['by_complexity'][comp]['total_latency'] += result.latency
    
    # Calculate averages
    results['avg_latency'] = sum(results['latencies']) / len(results['latencies'])
    results['cache_hit_rate'] = results['cache_hits'] / results['total_calls']
    
    for comp in results['by_complexity']:
        data = results['by_complexity'][comp]
        data['avg_latency'] = data['total_latency'] / data['count']
    
    return results

In [None]:
# Create a fresh pipeline for benchmarking
benchmark_pipeline = ReasoningPipeline(
    fast_model=FAST_MODEL,
    reasoning_model=REASONING_MODEL,
)

# Mixed complexity queries
benchmark_queries = [
    # Simple
    "What year did World War II end?",
    "Who painted the Mona Lisa?",
    
    # Moderate by intent; the keyword heuristic scores both as simple
    "Explain what an API is.",
    "What's the difference between HTTP and HTTPS?",
    
    # Complex
    "Calculate step by step: A car travels 60 miles in 1.5 hours. What is its average speed in km/h? (1 mile = 1.6 km)",
]

print("Running benchmark...")
benchmark_results = run_benchmark(benchmark_pipeline, benchmark_queries, repetitions=2)

# Print results
print("\n" + "="*60)
print("BENCHMARK RESULTS")
print("="*60)

print(f"\nTotal queries: {benchmark_results['total_calls']}")
print(f"Average latency: {benchmark_results['avg_latency']:.2f}s")
print(f"Cache hit rate: {benchmark_results['cache_hit_rate']:.0%}")

print("\nLatency by Complexity:")
for comp, data in benchmark_results['by_complexity'].items():
    print(f"  {comp}: {data['avg_latency']:.2f}s avg ({data['count']} calls)")
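
An average hides tail latency, especially when near-zero cache hits are mixed with multi-second model calls. A minimal sketch of nearest-rank percentiles; the `percentile` helper and the sample values are made up for illustration:

```python
import math
from typing import List

def percentile(values: List[float], pct: float) -> float:
    """Nearest-rank percentile: smallest value with at least pct% of samples at or below it."""
    if not values:
        raise ValueError("percentile() needs at least one value")
    ordered = sorted(values)
    rank = max(1, math.ceil(pct / 100 * len(ordered)))
    return ordered[rank - 1]

sample_latencies = [0.21, 0.30, 0.25, 4.10, 0.0001, 0.0002]  # made-up values
for p in (50, 95, 99):
    print(f"p{p}: {percentile(sample_latencies, p):.4f}s")
```

In this notebook the same helper could be applied to `benchmark_results['latencies']`.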

In [None]:
# Compare: Pipeline vs Always-Reasoning
print("\n" + "="*60)
print("PIPELINE vs ALWAYS-REASONING COMPARISON")
print("="*60)

# Get pipeline stats
stats = benchmark_pipeline.get_stats()

# Estimate what always-reasoning would cost.
# The pipeline latency below is measured; the always-reasoning figure is a rough
# estimate that assumes 3s per reasoning call, so treat the comparison as indicative.
fast_calls = stats['fast_model_calls']
reasoning_calls = stats['reasoning_model_calls']
avg_reasoning_latency = 3.0  # Assumed: 3s per reasoning-model call

pipeline_latency = stats['total_latency']
always_reasoning_latency = (fast_calls + reasoning_calls) * avg_reasoning_latency

print(f"\n{'Metric':<30} {'Pipeline':<15} {'Always-Reasoning':<15}")
print("-"*60)
print(f"{'Fast model calls':<30} {fast_calls:<15} {0:<15}")
print(f"{'Reasoning model calls':<30} {reasoning_calls:<15} {fast_calls + reasoning_calls:<15}")
print(f"{'Est. total latency':<30} {pipeline_latency:<15.1f}s {always_reasoning_latency:<15.1f}s")

savings = always_reasoning_latency - pipeline_latency
savings_pct = savings / always_reasoning_latency * 100 if always_reasoning_latency > 0 else 0

print(f"\nEstimated savings: {savings:.1f}s ({savings_pct:.0f}%)")
print("\nNote: Actual savings depend on your specific query mix.")

---

## Part 7: Production Considerations

In [None]:
production_tips = """
╔══════════════════════════════════════════════════════════════════════╗
║             PRODUCTION DEPLOYMENT TIPS                               ║
╠══════════════════════════════════════════════════════════════════════╣
║                                                                      ║
║  1. CLASSIFIER IMPROVEMENTS:                                         ║
║     - Train a small ML classifier on labeled examples                ║
║     - Use an LLM to classify (costs 1 extra call, but more accurate) ║
║     - A/B test different routing thresholds                          ║
║                                                                      ║
║  2. CACHING STRATEGIES:                                              ║
║     - Use Redis/Memcached for distributed caching                    ║
║     - Implement semantic similarity for cache lookup                 ║
║     - Set TTL (time-to-live) for stale data                          ║
║     - Cache by embedding similarity, not exact match                 ║
║                                                                      ║
║  3. MONITORING:                                                      ║
║     - Track latency percentiles (p50, p95, p99)                      ║
║     - Monitor model accuracy by complexity level                     ║
║     - Alert on cache hit rate drops                                  ║
║     - Log routing decisions for analysis                             ║
║                                                                      ║
║  4. FALLBACK STRATEGIES:                                             ║
║     - If reasoning model is slow, timeout and use fast model         ║
║     - Implement circuit breakers for model failures                  ║
║     - Have a "confident answer" threshold before responding          ║
║                                                                      ║
║  5. DGX SPARK OPTIMIZATION:                                          ║
║     - Keep both models in memory (128GB is enough!)                  ║
║     - Use NVFP4 for reasoning model if available                     ║
║     - Batch similar queries together                                 ║
║     - Use tensor parallelism for 70B+ models                         ║
║                                                                      ║
╚══════════════════════════════════════════════════════════════════════╝
"""

print(production_tips)
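
The TTL tip can be sketched as a small variant of the LRU cache from Part 3. This is an illustration under assumptions, not a drop-in replacement: `TTLCache`, its `clock` parameter, and the one-hour default are hypothetical choices, and keys are passed in directly rather than hashed from query and model.

```python
import time
from collections import OrderedDict
from typing import Callable, Optional

class TTLCache:
    """LRU cache whose entries also expire after ttl_seconds."""

    def __init__(self, max_size: int = 100, ttl_seconds: float = 3600.0,
                 clock: Callable[[], float] = time.monotonic):
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.clock = clock  # injectable so expiry can be tested without sleeping
        self._data = OrderedDict()  # key -> (stored_at, value)

    def get(self, key: str) -> Optional[str]:
        entry = self._data.get(key)
        if entry is None:
            return None
        stored_at, value = entry
        if self.clock() - stored_at > self.ttl_seconds:
            del self._data[key]  # expired: drop it and report a miss
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return value

    def set(self, key: str, value: str) -> None:
        if key in self._data:
            self._data.move_to_end(key)
        elif len(self._data) >= self.max_size:
            self._data.popitem(last=False)  # evict least recently used
        self._data[key] = (self.clock(), value)

cache_with_ttl = TTLCache(max_size=2, ttl_seconds=60)
cache_with_ttl.set("capital-of-japan", "Tokyo")
print(cache_with_ttl.get("capital-of-japan"))
```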

---

## Common Mistakes

### Mistake 1: Not Updating the Cache for Updated Models

```python
# Wrong: the cache key is (model name, query), so re-pulling the same tag
# with new weights keeps serving answers generated by the old weights
ollama.pull(REASONING_MODEL)
# Old cached responses are still returned!

# Right: clear the cache whenever a model's weights change
ollama.pull(REASONING_MODEL)
pipeline.cache.clear()
```

### Mistake 2: Over-Routing to Reasoning Model

```python
# Wrong: Classifier is too aggressive
if any_math_word_present(query):  # "one", "two" trigger reasoning
    use_reasoning_model()

# Right: More nuanced classification
if requires_calculation(query) and multi_step(query):
    use_reasoning_model()
```

### Mistake 3: Not Handling Timeouts

```python
# Wrong: No fallback
response = reasoning_model.generate(query, timeout=30)
# If timeout, user gets nothing!

# Right: Graceful fallback
try:
    response = reasoning_model.generate(query, timeout=30)
except TimeoutError:
    response = fast_model.generate(query)  # Fallback
```
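
The snippet above is pseudocode. One way to get the same behavior with plain callables is to run the slow call in a worker thread and stop waiting after a deadline. This is a sketch under assumptions: `generate_with_fallback` is a hypothetical helper, `primary` and `fallback` stand in for functions that wrap the reasoning and fast models, and the abandoned call keeps running in the background rather than being cancelled.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable

def generate_with_fallback(
    query: str,
    primary: Callable[[str], str],
    fallback: Callable[[str], str],
    timeout_s: float = 30.0,
) -> str:
    """Return primary(query), or fallback(query) if primary is too slow or raises."""
    pool = ThreadPoolExecutor(max_workers=1)
    future = pool.submit(primary, query)
    try:
        return future.result(timeout=timeout_s)
    except Exception:  # covers the timeout as well as errors raised by primary
        return fallback(query)
    finally:
        # Do not block on the slow call; the worker thread finishes on its own
        pool.shutdown(wait=False)

def slow_model(query: str) -> str:
    time.sleep(0.5)  # stand-in for a slow reasoning call
    return f"slow answer to: {query}"

def fast_model(query: str) -> str:
    return f"fast answer to: {query}"

print(generate_with_fallback("2 + 2?", slow_model, fast_model, timeout_s=0.05))
```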

---

## Checkpoint

You've learned:
- ✅ How to classify query complexity
- ✅ How to route to appropriate models
- ✅ How to implement response caching
- ✅ How to measure pipeline performance
- ✅ Production deployment considerations

---

## Cleanup and Summary

In [None]:
# Final summary
print("="*70)
print("MODULE 3.4 COMPLETE: TEST-TIME COMPUTE & REASONING")
print("="*70)

print("""
You've mastered:

1. Chain-of-Thought Prompting
   - Zero-shot and few-shot CoT
   - When and why it improves accuracy

2. Self-Consistency
   - Multiple reasoning paths + majority voting
   - Temperature and N tuning

3. Reasoning Models (DeepSeek-R1)
   - <think> tokens and GRPO training
   - Running 70B models on DGX Spark

4. Model Comparison
   - Quantifying reasoning advantages
   - Token economy analysis

5. Reward Models & Best-of-N
   - Scoring responses for quality
   - Generate N, pick best strategy

6. Adaptive Reasoning Pipeline
   - Complexity classification
   - Intelligent routing
   - Response caching

""")

print("Next: Module 3.5 - RAG Systems & Vector Databases")
print("="*70)

import gc
gc.collect()
print("\nMemory cleaned up. Great work!")