# Lab 3.4.6: Reasoning Pipeline - SOLUTIONS

This notebook contains complete solutions to all exercises from Lab 3.4.6.

In [None]:
import ollama
import json
import time
import hashlib
import re
from typing import Dict, List, Optional, Tuple, Callable
from dataclasses import dataclass, field
from collections import OrderedDict
from enum import Enum

# Model configuration: discover locally pulled Ollama models and pick one
# "fast" general-purpose model plus one R1-style reasoning model.
models = ollama.list()
# Older ollama-python clients return dicts keyed by 'name'; newer clients
# return Model objects exposing the name under 'model'. Both support .get(),
# so accept either key and drop entries with no usable name.
model_names = [
    name
    for name in (m.get('model') or m.get('name') for m in models.get('models', []))
    if name
]

# Find models
FAST_MODEL = None
REASONING_MODEL = None

for name in model_names:
    lowered = name.lower()
    is_reasoning = 'r1' in lowered
    # The size tag must be exactly 3b/7b/8b: the (?<!\d) guard stops larger
    # models such as "13b", "27b" or "33b" from being picked as the fast model.
    if not FAST_MODEL and not is_reasoning and re.search(r'(?<!\d)[378]b\b', lowered):
        FAST_MODEL = name
    if not REASONING_MODEL and is_reasoning:
        REASONING_MODEL = name

# Fall back to the first available model (or a well-known default) and, when no
# reasoning model is installed, reuse the fast model for reasoning queries.
FAST_MODEL = FAST_MODEL or (model_names[0] if model_names else "llama3.1:8b")
REASONING_MODEL = REASONING_MODEL or FAST_MODEL

print(f"Fast Model: {FAST_MODEL}")
print(f"Reasoning Model: {REASONING_MODEL}")

## Solution: Query Complexity Classifier

In [None]:
class QueryComplexity(Enum):
    """How much reasoning a query is judged to need.

    The string values are what appear in classifier logs and statistics.
    """
    SIMPLE = "simple"      # pipeline routes to the fast model with a plain prompt
    MODERATE = "moderate"  # pipeline routes to the fast model with a CoT suffix
    COMPLEX = "complex"    # pipeline routes to the reasoning model with a CoT suffix


class EnhancedComplexityClassifier:
    """
    Solution: Enhanced complexity classifier with configurable weights.
    
    A query is scored by summing weighted keyword hits, regex "math pattern"
    hits, and a few structural signals (length, sentence count, number of
    questions). The score is then mapped onto SIMPLE / MODERATE / COMPLEX
    using two thresholds.
    
    Features:
    - Configurable keyword weights
    - Pattern-based detection
    - Length and structure analysis
    - Detailed logging
    """
    
    # Weighted keywords (positive weights push towards COMPLEX)
    COMPLEX_KEYWORDS = {
        'solve': 3, 'calculate': 2, 'prove': 4, 'derive': 4,
        'analyze': 2, 'step by step': 3, 'explain how': 2,
        'why does': 2, 'compare and contrast': 3, 'what if': 2,
        'optimize': 3, 'debug': 2, 'implement': 3, 'algorithm': 3,
        'mathematical': 3, 'equation': 2, 'probability': 3,
        'logic puzzle': 4, 'proof': 4, 'theorem': 4,
    }
    
    # Negative weights push towards SIMPLE
    SIMPLE_KEYWORDS = {
        'what is': -2, 'who is': -2, 'when did': -2,
        'where is': -2, 'define': -1, 'capital of': -3,
        'how many': -1, 'true or false': -2, 'yes or no': -2,
        'list': -1, 'name': -1,
    }
    
    # (regex, weight) pairs, searched in the lower-cased query
    MATH_PATTERNS = [
        (r'\d+\s*[+\-*/]\s*\d+', 2),  # Arithmetic
        (r'\d+%\s*of', 2),  # Percentage
        (r'equation|formula', 3),
        (r'\$\d+', 1),  # Money
        (r'\d+\s*(mph|km|miles|meters)', 2),  # Word problems
    ]
    
    def __init__(self, simple_threshold: int = 0, complex_threshold: int = 4):
        """
        Args:
            simple_threshold: Scores <= this are classified SIMPLE.
            complex_threshold: Scores >= this are classified COMPLEX; anything
                strictly between the two thresholds is MODERATE. The SIMPLE
                check runs first, so it wins if the thresholds overlap.
        """
        self.simple_threshold = simple_threshold
        self.complex_threshold = complex_threshold
        # (query[:50], QueryComplexity) for every classify() call
        self.classifications = []
        # Per-query score breakdowns, recorded only when log_details=True
        self.detailed_logs = []
    
    @staticmethod
    def _has_keyword(text: str, keyword: str) -> bool:
        """Return True if *keyword* occurs in *text* starting at a word boundary.
        
        The leading word boundary keeps 'prove' from matching inside 'improve'
        and 'name' from matching inside 'rename'. Inflections such as 'proves'
        or 'names' still match.
        """
        return re.search(r'\b' + re.escape(keyword), text) is not None
    
    def classify(self, query: str, log_details: bool = False) -> QueryComplexity:
        """Classify query complexity with optional detailed logging.
        
        Args:
            query: Raw user query.
            log_details: If True, append the score breakdown to detailed_logs.
        
        Returns:
            SIMPLE if score <= simple_threshold, else COMPLEX if
            score >= complex_threshold, else MODERATE.
        """
        query_lower = query.lower()
        score = 0
        details = {'query': query[:50], 'factors': []}
        
        # Check complex keywords
        for keyword, weight in self.COMPLEX_KEYWORDS.items():
            if self._has_keyword(query_lower, keyword):
                score += weight
                details['factors'].append(f"+{weight} ('{keyword}')")
        
        # Check simple keywords
        for keyword, weight in self.SIMPLE_KEYWORDS.items():
            if self._has_keyword(query_lower, keyword):
                score += weight
                details['factors'].append(f"{weight} ('{keyword}')")
        
        # Check math patterns
        for pattern, weight in self.MATH_PATTERNS:
            if re.search(pattern, query_lower):
                score += weight
                details['factors'].append(f"+{weight} (pattern: {pattern[:20]})")
        
        # Length factor
        if len(query) > 200:
            score += 1
            details['factors'].append("+1 (long query)")
        if len(query) > 400:
            score += 1
            details['factors'].append("+1 (very long query)")
        
        # Multiple sentences: count only non-empty fragments between runs of
        # terminators. A trailing '.', an ellipsis ('...') or '?!' would
        # otherwise add phantom empty "sentences" to the count.
        sentence_count = sum(
            1 for fragment in re.split(r'[.!?]+', query) if fragment.strip()
        )
        if sentence_count > 3:
            score += 1
            details['factors'].append(f"+1 ({sentence_count} sentences)")
        
        # Question marks (multiple questions = more complex)
        q_count = query.count('?')
        if q_count > 1:
            score += q_count - 1
            details['factors'].append(f"+{q_count-1} (multiple questions)")
        
        # Classify
        details['score'] = score
        
        if score <= self.simple_threshold:
            result = QueryComplexity.SIMPLE
        elif score >= self.complex_threshold:
            result = QueryComplexity.COMPLEX
        else:
            result = QueryComplexity.MODERATE
        
        details['result'] = result.value
        self.classifications.append((query[:50], result))
        
        if log_details:
            self.detailed_logs.append(details)
        
        return result
    
    def get_stats(self) -> Dict:
        """Return the total classification count and a per-level breakdown."""
        counts = {c: 0 for c in QueryComplexity}
        for _, complexity in self.classifications:
            counts[complexity] += 1
        
        return {
            'total': len(self.classifications),
            'by_complexity': {c.value: count for c, count in counts.items()},
        }
    
    def print_detailed_log(self, n: int = 5):
        """Print the last *n* detailed classification logs (nothing if n <= 0)."""
        if n <= 0:
            # Guard needed because detailed_logs[-0:] would be the whole list.
            return
        for log in self.detailed_logs[-n:]:
            print(f"\nQuery: {log['query']}...")
            print(f"  Score: {log['score']} -> {log['result']}")
            print(f"  Factors: {', '.join(log['factors'][:5])}")


# Smoke-test the classifier on a handful of representative queries, keeping
# the per-query score breakdowns so they can be inspected afterwards.
classifier = EnhancedComplexityClassifier()

test_queries = [
    "What is the capital of France?",                         # factual lookup
    "Explain how neural networks learn.",                     # explanation request
    "Solve step by step: A train leaves at 9am at 60mph...",  # word problem
    "Is Python interpreted?",                                 # yes/no question
]

print("Enhanced Classifier Test:")
for q in test_queries:
    result = classifier.classify(q, log_details=True)
    print(f"  [{result.value:8}] {q[:50]}...")

print("\nDetailed Logs:")
classifier.print_detailed_log()

## Solution: LRU Cache with TTL

In [None]:
class LRUCacheWithTTL:
    """
    Solution: LRU Cache with Time-To-Live support.
    
    Features:
    - LRU eviction when full
    - TTL-based expiration (checked lazily when an entry is read)
    - Detailed statistics (hits, misses, expirations, hit rate)
    - Keys are an MD5 digest of the (model, query) pair
    """
    
    def __init__(self, max_size: int = 100, ttl_seconds: float = 3600):
        """
        Args:
            max_size: Maximum number of entries; 0 or less disables caching.
            ttl_seconds: Entries older than this many seconds count as misses.
        """
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.cache = OrderedDict()  # key -> response, least recently used first
        self.timestamps = {}        # key -> time.time() when the entry was stored
        self.hits = 0
        self.misses = 0
        self.expirations = 0
    
    def _make_key(self, query: str, model: str) -> str:
        """Create a cache key from the (model, query) pair.
        
        The pair is JSON-encoded before hashing so the boundary between the two
        strings is unambiguous. A plain f"{model}:{query}" would map
        ("llama3.1:8b", "x") and ("llama3.1", "8b:x") to the same key, because
        Ollama model names themselves contain ':'.
        """
        combined = json.dumps([model, query])
        return hashlib.md5(combined.encode()).hexdigest()
    
    def _is_expired(self, key: str) -> bool:
        """Return True if *key* has no timestamp or is older than the TTL."""
        if key not in self.timestamps:
            return True
        age = time.time() - self.timestamps[key]
        return age > self.ttl_seconds
    
    def get(self, query: str, model: str) -> Optional[str]:
        """Return the cached response, or None if absent or expired.
        
        Expired entries are removed on access and counted as both an
        expiration and a miss.
        """
        key = self._make_key(query, model)
        
        if key not in self.cache:
            self.misses += 1
            return None
        
        if self._is_expired(key):
            # Remove expired entry
            del self.cache[key]
            self.timestamps.pop(key, None)
            self.expirations += 1
            self.misses += 1
            return None
        
        # Move to end (most recently used)
        self.cache.move_to_end(key)
        self.hits += 1
        return self.cache[key]
    
    def set(self, query: str, model: str, response: str):
        """Cache *response*, evicting the least recently used entry if full."""
        if self.max_size <= 0:
            # Caching disabled. Without this guard the eviction step below
            # would call on an empty dict and raise.
            return
        
        key = self._make_key(query, model)
        
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.max_size:
            # Evict the least recently used entry (front of the OrderedDict)
            oldest_key, _ = self.cache.popitem(last=False)
            self.timestamps.pop(oldest_key, None)
        
        self.cache[key] = response
        self.timestamps[key] = time.time()
    
    def clear(self):
        """Remove all entries and reset statistics."""
        self.cache.clear()
        self.timestamps.clear()
        self.hits = 0
        self.misses = 0
        self.expirations = 0
    
    def get_stats(self) -> Dict:
        """Get cache statistics."""
        total = self.hits + self.misses
        return {
            'size': len(self.cache),
            'max_size': self.max_size,
            'ttl_seconds': self.ttl_seconds,
            'hits': self.hits,
            'misses': self.misses,
            'expirations': self.expirations,
            'hit_rate': self.hits / total if total > 0 else 0,
        }


# Quick sanity check of the cache: one stored entry should be returned, an
# unknown query should come back as None, and the stats should reflect both.
cache = LRUCacheWithTTL(max_size=10, ttl_seconds=60)
cache.set(query="test query", model="model", response="test response")

print(f"Cache hit: {cache.get('test query', 'model')[:20]}...")
print(f"Cache miss: {cache.get('unknown', 'model')}")
print(f"Stats: {cache.get_stats()}")

## Solution: Complete Reasoning Pipeline

In [None]:
@dataclass
class PipelineResponse:
    """Everything returned by a single reasoning-pipeline query.

    Field order (and therefore positional construction order) is:
    response, model_used, complexity, from_cache, latency, thinking_tokens.
    """
    response: str                 # generated (or cached) answer text
    model_used: str               # name of the model the answer is attributed to
    complexity: QueryComplexity   # classifier verdict for the query
    from_cache: bool              # True when served from the response cache
    latency: float                # elapsed seconds for this query
    thinking_tokens: int = field(default=0)  # rough estimate; 0 unless computed by the caller


class ProductionReasoningPipeline:
    """
    Solution: Production-ready adaptive reasoning pipeline.
    
    COMPLEX queries are routed to the reasoning model and everything else goes
    to the fast model. MODERATE and COMPLEX prompts get a chain-of-thought
    suffix. Responses are cached per (query, model). If the chosen model
    raises, the query is retried once on the fast model without CoT.
    
    Features:
    - Intelligent complexity-based routing
    - TTL-enabled response caching
    - Fallback handling
    - Comprehensive metrics
    - Configurable thresholds, cache size/TTL and request timeout
    """
    
    def __init__(
        self,
        fast_model: str = FAST_MODEL,
        reasoning_model: str = REASONING_MODEL,
        cache_size: int = 100,
        cache_ttl: float = 3600,
        timeout: float = 30.0,
        simple_threshold: int = 0,
        complex_threshold: int = 4,
    ):
        """
        Args:
            fast_model: Model used for SIMPLE/MODERATE queries and fallbacks.
            reasoning_model: Model used for COMPLEX queries.
            cache_size: Maximum number of cached responses.
            cache_ttl: Seconds before a cached response expires.
            timeout: HTTP timeout (seconds) applied to every Ollama request.
            simple_threshold: Classifier score at or below which a query is SIMPLE.
            complex_threshold: Classifier score at or above which a query is COMPLEX.
        """
        self.fast_model = fast_model
        self.reasoning_model = reasoning_model
        self.timeout = timeout
        # Dedicated client so `timeout` is actually enforced: the module-level
        # ollama.chat() helper takes no timeout argument.
        self._client = ollama.Client(timeout=timeout)
        
        self.classifier = EnhancedComplexityClassifier(
            simple_threshold=simple_threshold,
            complex_threshold=complex_threshold,
        )
        self.cache = LRUCacheWithTTL(max_size=cache_size, ttl_seconds=cache_ttl)
        
        # Metrics
        self.total_queries = 0
        self.fast_calls = 0
        self.reasoning_calls = 0
        self.cache_hits = 0
        self.fallbacks = 0
        self.total_latency = 0.0  # seconds, summed over non-cached queries only
        self.errors = []          # str() of each exception that triggered a fallback
    
    def query(
        self,
        query: str,
        use_cache: bool = True,
        force_model: Optional[str] = None,
        verbose: bool = False,
    ) -> PipelineResponse:
        """Process a query through the pipeline.
        
        Args:
            query: The user query.
            use_cache: Read from and write to the response cache.
            force_model: If given, bypass routing and use this model.
            verbose: Print routing, cache and latency diagnostics.
        
        Returns:
            A PipelineResponse describing the answer and how it was produced.
        
        Raises:
            Whatever the Ollama client raises if both the selected model and
            the fast-model fallback fail.
        """
        self.total_queries += 1
        start_time = time.perf_counter()
        
        # Classify
        complexity = self.classifier.classify(query)
        if verbose:
            print(f"Complexity: {complexity.value}")
        
        # Select model: an explicit override wins, otherwise route by complexity
        if force_model:
            model = force_model
        elif complexity == QueryComplexity.COMPLEX:
            model = self.reasoning_model
        else:
            model = self.fast_model
        
        # Chain-of-thought prompting for anything beyond a simple lookup
        use_cot = complexity in (QueryComplexity.MODERATE, QueryComplexity.COMPLEX)
        
        if verbose:
            print(f"Model: {model}, CoT: {use_cot}")
        
        # Check cache
        if use_cache:
            cached = self.cache.get(query, model)
            # Compare against None: an empty cached response is still a hit.
            if cached is not None:
                self.cache_hits += 1
                latency = time.perf_counter() - start_time
                if verbose:
                    print("Cache HIT")
                return PipelineResponse(
                    response=cached,
                    model_used=model,
                    complexity=complexity,
                    from_cache=True,
                    latency=latency,
                )
        
        # Build prompt
        prompt = f"{query}\n\nLet's think step by step:" if use_cot else query
        
        # Generate with fallback
        try:
            response = self._client.chat(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                options={"temperature": 0.0, "num_predict": 1024},
            )
            response_text = response['message']['content']
        except Exception as e:
            # Best-effort fallback: record the error, then retry once on the
            # fast model with the bare query (no CoT) and a smaller token budget.
            # If the retry also fails, its exception propagates to the caller.
            self.fallbacks += 1
            self.errors.append(str(e))
            
            if verbose:
                print(f"Error with {model}, falling back...")
            
            response = self._client.chat(
                model=self.fast_model,
                messages=[{"role": "user", "content": query}],
                options={"temperature": 0.0, "num_predict": 512},
            )
            response_text = response['message']['content']
            model = self.fast_model
            self.fast_calls += 1
        else:
            # NOTE: when fast_model == reasoning_model, every successful call
            # here is counted as a reasoning call.
            if model == self.reasoning_model:
                self.reasoning_calls += 1
            else:
                self.fast_calls += 1
        
        latency = time.perf_counter() - start_time
        self.total_latency += latency
        
        # Rough thinking-token estimate: ~4 characters per token inside
        # <think>...</think> blocks emitted by R1-style models.
        thinking_matches = re.findall(r'<think>.*?</think>', response_text, re.DOTALL)
        thinking_tokens = sum(len(m) // 4 for m in thinking_matches)
        
        # Cache response under the model that actually produced it
        if use_cache:
            self.cache.set(query, model, response_text)
        
        if verbose:
            print(f"Latency: {latency:.2f}s")
        
        return PipelineResponse(
            response=response_text,
            model_used=model,
            complexity=complexity,
            from_cache=False,
            latency=latency,
            thinking_tokens=thinking_tokens,
        )
    
    def get_stats(self) -> Dict:
        """Get comprehensive pipeline statistics.
        
        avg_latency averages only over queries that reached a model (cache
        hits are excluded). routing_efficiency is the fraction of model calls
        that went to the fast model.
        """
        return {
            'total_queries': self.total_queries,
            'fast_calls': self.fast_calls,
            'reasoning_calls': self.reasoning_calls,
            'cache_hits': self.cache_hits,
            'cache_hit_rate': self.cache_hits / max(self.total_queries, 1),
            'fallbacks': self.fallbacks,
            'total_latency': self.total_latency,
            'avg_latency': self.total_latency / max(self.total_queries - self.cache_hits, 1),
            'routing_efficiency': self.fast_calls / max(self.fast_calls + self.reasoning_calls, 1),
            'cache_stats': self.cache.get_stats(),
            'classifier_stats': self.classifier.get_stats(),
        }
    
    def print_dashboard(self):
        """Print a dashboard of pipeline statistics."""
        stats = self.get_stats()
        
        print("\n" + "=" * 60)
        print("REASONING PIPELINE DASHBOARD")
        print("=" * 60)
        
        print("\nQuery Statistics:")
        print(f"  Total queries:      {stats['total_queries']}")
        print(f"  Fast model calls:   {stats['fast_calls']}")
        print(f"  Reasoning calls:    {stats['reasoning_calls']}")
        print(f"  Cache hits:         {stats['cache_hits']} ({stats['cache_hit_rate']:.1%})")
        print(f"  Fallbacks:          {stats['fallbacks']}")
        
        print("\nPerformance:")
        print(f"  Total latency:      {stats['total_latency']:.1f}s")
        print(f"  Avg latency:        {stats['avg_latency']:.2f}s")
        print(f"  Routing efficiency: {stats['routing_efficiency']:.1%}")
        
        class_stats = stats['classifier_stats']['by_complexity']
        print("\nComplexity Distribution:")
        for comp, count in class_stats.items():
            print(f"  {comp}: {count}")
        
        print("=" * 60)

## Solution: Test the Pipeline

In [None]:
# Build a pipeline with a small, short-lived (5 minute) cache for the demo.
pipeline = ProductionReasoningPipeline(
    cache_ttl=300,
    cache_size=50,
    reasoning_model=REASONING_MODEL,
    fast_model=FAST_MODEL,
)

# One query per complexity level, then a repeat of the first query so the
# response cache gets exercised.
test_queries = [
    "What is the capital of Japan?",     # Simple
    "Explain machine learning.",         # Moderate
    "Solve step by step: 3x + 7 = 22",   # Complex
    "What is the capital of Japan?",     # Repeat -> cache lookup
]

print("Testing Pipeline:\n")
for q in test_queries:
    print(f"Q: {q}")
    result = pipeline.query(q, verbose=True)
    print(f"  Response: {result.response[:80]}...")
    print("")

pipeline.print_dashboard()

## Key Takeaways

1. **Complexity classification** routes queries to appropriate models
2. **Caching with TTL** prevents stale data while saving compute
3. **Fallback handling** improves reliability by retrying a failed call once on the fast model
4. **Metrics tracking** enables optimization
5. **DGX Spark** can run both models simultaneously (128GB memory)