# Lab 3.4.5: Best-of-N with Reward Model - SOLUTIONS

This notebook contains complete solutions to all exercises from Lab 3.4.5.

In [None]:
import ollama
import json
import time
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass

LLM_MODEL = "qwen3:8b"

# Check for transformers
try:
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    HAS_TRANSFORMERS = True
    print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")
except ImportError:
    HAS_TRANSFORMERS = False
    print("Transformers not available - using heuristic reward model")

## Solution: ScoredResponse Dataclass

In [None]:
@dataclass
class ScoredResponse:
    """One generated candidate together with its reward score and timing."""

    # Raw text returned by the language model.
    response: str
    # Reward assigned to ``response``; callers may overwrite it after creation.
    score: float
    # Wall-clock seconds spent generating ``response``.
    generation_time: float
    # Token count for ``response``; stays 0 unless the creator supplies one.
    token_count: int = 0

    def __repr__(self):
        # Compact summary: report the response length instead of its full text.
        parts = (
            f"score={self.score:.3f}",
            f"len={len(self.response)}",
            f"time={self.generation_time:.2f}s",
        )
        return "ScoredResponse(" + ", ".join(parts) + ")"

## Solution: Enhanced Heuristic Reward Model

In [None]:
class EnhancedHeuristicRewardModel:
    """
    Solution: rule-based reward model built from simple text heuristics.

    Every response starts at a neutral 0.5 and is adjusted by:
    - Response length (bonus for mid-sized text, penalty for extremes)
    - Reasoning markers ("because", "first", "step", ...)
    - Answer clarity ("the answer is", "result:", "=", ...)
    - Structure (numbered items 1-5, bullet markers)
    - Repetition penalty (low ratio of distinct words)
    - Uncertainty penalty and confidence bonus
    - Keyword overlap between prompt and response

    All phrase checks are plain substring tests on the lower-cased response.
    """

    _REASONING_MARKERS = (
        'because', 'therefore', 'thus', 'hence',
        'first', 'second', 'then', 'finally',
        'step', 'reason', 'explanation',
    )
    _ANSWER_MARKERS = ('the answer is', 'answer:', 'result:', 'solution:', '=')
    _UNCERTAIN_MARKERS = (
        "i'm not sure", "i don't know", "might be", "could be", "perhaps",
    )
    _CONFIDENT_MARKERS = ('clearly', 'definitely', 'certainly')

    def __init__(self):
        self.name = "enhanced_heuristic"

    @staticmethod
    def _mentions_any(text: str, phrases) -> bool:
        """Return True when at least one phrase occurs as a substring of text."""
        for phrase in phrases:
            if phrase in text:
                return True
        return False

    @staticmethod
    def _length_adjustment(n_chars: int) -> float:
        """Map the character count of a response to its length bonus/penalty."""
        if 100 < n_chars < 800:
            return 0.15
        if 50 < n_chars < 1200:
            return 0.08
        if n_chars < 30:
            return -0.2  # Too short
        if n_chars > 2000:
            return -0.1  # Too verbose
        return 0.0

    def score(self, prompt: str, response: str) -> float:
        """
        Score ``response`` against ``prompt``.

        Args:
            prompt: The prompt the response answers; used only for the
                keyword-overlap term.
            response: The candidate text to evaluate.

        Returns:
            The heuristic score clamped to the range [0, 1].
        """
        text = response.lower()
        tokens = text.split()
        total = 0.5  # Neutral start

        # 1. Length
        total += self._length_adjustment(len(response))

        # 2. Reasoning markers: 0.03 per distinct marker found, capped at 0.15
        hits = sum(marker in text for marker in self._REASONING_MARKERS)
        total += min(hits * 0.03, 0.15)

        # 3. Clear answer indicators
        if self._mentions_any(text, self._ANSWER_MARKERS):
            total += 0.1

        # 4. Structure: numbered items ("1." .. "5." or "1)" .. "5)"), then bullets.
        #    Checked on the original-case text, as the markers have no letters.
        numbered_tags = [f'{i}{sep}' for i in range(1, 6) for sep in ('.', ')')]
        if self._mentions_any(response, numbered_tags):
            total += 0.08
        if self._mentions_any(response, ('- ', '* ')):
            total += 0.05

        # 5. Repetition penalty based on the share of distinct words
        if tokens:
            distinct_share = len(set(tokens)) / len(tokens)
            if distinct_share < 0.4:
                total -= 0.25
            elif distinct_share < 0.6:
                total -= 0.1

        # 6. Uncertainty penalty
        if self._mentions_any(text, self._UNCERTAIN_MARKERS):
            total -= 0.1

        # 7. Confidence bonus
        if self._mentions_any(text, self._CONFIDENT_MARKERS):
            total += 0.05

        # 8. Relevance: fraction of prompt words that reappear in the response
        prompt_vocab = set(prompt.lower().split())
        shared = prompt_vocab & set(tokens)
        relevance = len(shared) / max(len(prompt_vocab), 1)
        total += min(relevance * 0.1, 0.1)

        # Clamp into [0, 1]
        return min(1.0, max(0.0, total))


# Test the enhanced model: score four hand-written responses to one prompt
# and print each score next to the first 60 characters of the response.
# The names bound here (reward_model, test_prompt, test_responses) are
# module-level and stay visible to later cells.
reward_model = EnhancedHeuristicRewardModel()

test_prompt = "Explain what machine learning is."
test_responses = [
    "ML is AI.",  # Too short (under 30 characters -> length penalty)
    # Mid-length answer containing the reasoning markers "first", "then", "finally".
    "Machine learning is a subset of artificial intelligence that enables systems to learn from data. First, data is collected. Then, the model learns patterns. Finally, it makes predictions.",
    # Contains the uncertainty markers "i'm not sure" and "might be".
    "I'm not sure, but I think machine learning might be something to do with computers learning stuff maybe?",
    # NOTE(review): with whitespace splitting this sample has 12 distinct words
    # out of 19 (~0.63), which is above the 0.6 repetition threshold, so the
    # repetition penalty does not appear to fire here - confirm this is intended.
    "Machine learning is a type of AI. Machine learning is AI. AI uses machine learning. Learning machines use AI.",  # Repetitive
]

print("Testing Enhanced Reward Model:")
print(f"Prompt: {test_prompt}\n")
for i, resp in enumerate(test_responses):
    score = reward_model.score(test_prompt, resp)
    # The trailing "..." is appended unconditionally, even when the response
    # is shorter than the 60-character preview.
    print(f"Response {i+1} (score: {score:.3f}): {resp[:60]}...")

## Solution: Best-of-N Sampler Class

In [None]:
class BestOfNSampler:
    """
    Solution: Complete Best-of-N sampling implementation.

    Draws N candidate responses for a prompt via ``ollama.chat``, scores each
    one with a reward model, and returns the highest-scoring candidate.

    Features:
    - Configurable N and temperature
    - Detailed statistics tracking
    - Score distribution analysis
    - Performance metrics

    Attributes:
        llm_model: Model name passed to ``ollama.chat``.
        reward_model: Object exposing ``score(prompt, response) -> float``.
        default_n: Candidate count used when a call does not supply ``n``.
        default_temperature: Temperature used when a call does not supply one.
        total_samples: Number of candidates generated so far.
        total_queries: Number of ``sample`` calls so far.
        score_history: Every score produced by ``score_candidates``, in order.
    """

    def __init__(
        self,
        llm_model: str = LLM_MODEL,
        reward_model=None,
        default_n: int = 5,
        default_temperature: float = 0.7,
    ):
        self.llm_model = llm_model
        # Explicit None check so that a supplied reward model is never replaced
        # just because it happens to evaluate as falsy.
        self.reward_model = (
            reward_model if reward_model is not None
            else EnhancedHeuristicRewardModel()
        )
        self.default_n = default_n
        self.default_temperature = default_temperature

        # Statistics
        self.total_samples = 0
        self.total_queries = 0
        self.score_history = []

    def _resolve_n(self, n: Optional[int]) -> int:
        """Return the effective candidate count, defaulting to ``default_n``.

        Raises:
            ValueError: If the resulting count is smaller than 1.
        """
        count = self.default_n if n is None else n
        if count < 1:
            raise ValueError(f"n must be a positive integer, got {count!r}")
        return count

    def generate_candidates(
        self,
        prompt: str,
        n: Optional[int] = None,
        temperature: Optional[float] = None,
        max_tokens: int = 512,
    ) -> List[ScoredResponse]:
        """Generate N candidate responses.

        Args:
            prompt: User message sent to the model.
            n: Number of candidates; ``None`` means ``default_n``.
            temperature: Sampling temperature; ``None`` means
                ``default_temperature``. An explicit ``0.0`` is honoured.
            max_tokens: Passed to Ollama as the ``num_predict`` option.

        Returns:
            One unscored ``ScoredResponse`` (score 0.0) per generation, in
            generation order.

        Raises:
            ValueError: If the effective ``n`` is smaller than 1.
        """
        n = self._resolve_n(n)
        # ``is None`` rather than ``or``: 0.0 is a legitimate temperature and
        # must not be swapped for the default.
        if temperature is None:
            temperature = self.default_temperature

        candidates = []

        for _ in range(n):
            start_time = time.perf_counter()

            response = ollama.chat(
                model=self.llm_model,
                messages=[{"role": "user", "content": prompt}],
                options={"temperature": temperature, "num_predict": max_tokens}
            )

            elapsed = time.perf_counter() - start_time
            response_text = response['message']['content']

            candidates.append(ScoredResponse(
                response=response_text,
                score=0.0,
                generation_time=elapsed,
                # Crude character-based estimate, not a tokenizer count.
                token_count=len(response_text) // 4,
            ))

            self.total_samples += 1

        return candidates

    def score_candidates(
        self,
        prompt: str,
        candidates: List[ScoredResponse]
    ) -> List[ScoredResponse]:
        """Score all candidates with the reward model.

        Each candidate's ``score`` is overwritten in place and also appended
        to ``score_history``. The same list object is returned.
        """
        for candidate in candidates:
            candidate.score = self.reward_model.score(prompt, candidate.response)
            self.score_history.append(candidate.score)

        return candidates

    def sample(
        self,
        prompt: str,
        n: Optional[int] = None,
        temperature: Optional[float] = None,
        max_tokens: int = 512,
        verbose: bool = False,
    ) -> Tuple[ScoredResponse, List[ScoredResponse]]:
        """
        Run Best-of-N sampling.

        Args:
            prompt: User message sent to the model.
            n: Number of candidates; ``None`` means ``default_n``.
            temperature: Sampling temperature; ``None`` means
                ``default_temperature``.
            max_tokens: Generation limit forwarded to ``generate_candidates``.
            verbose: When True, print progress and the ranked scores.

        Returns:
            Tuple of (best_response, all_candidates). On ties the earliest
            generated candidate wins.

        Raises:
            ValueError: If the effective ``n`` is smaller than 1.
        """
        n = self._resolve_n(n)
        self.total_queries += 1

        if verbose:
            print(f"Best-of-{n} Sampling")
            print(f"Prompt: {prompt[:50]}...")

        # Generate
        if verbose:
            print(f"\nGenerating {n} candidates...")
        candidates = self.generate_candidates(prompt, n, temperature, max_tokens)

        # Score
        if verbose:
            print("Scoring candidates...")
        candidates = self.score_candidates(prompt, candidates)

        # Find best
        best = max(candidates, key=lambda x: x.score)

        if verbose:
            print("\nScore distribution:")
            ranked = sorted(candidates, key=lambda x: x.score, reverse=True)
            for rank, candidate in enumerate(ranked, start=1):
                # Identity, not score equality: only the returned candidate is
                # flagged, even when several candidates tie on score.
                marker = " <-- BEST" if candidate is best else ""
                print(f"  {rank}. Score: {candidate.score:.3f}{marker}")

        return best, candidates

    def get_statistics(self) -> Dict:
        """Get sampling statistics.

        Returns:
            ``{'message': 'No samples yet'}`` while nothing has been scored;
            otherwise query/sample counts plus mean, min, max and standard
            deviation over ``score_history``.
        """
        if not self.score_history:
            return {'message': 'No samples yet'}

        return {
            'total_queries': self.total_queries,
            'total_samples': self.total_samples,
            'avg_n': self.total_samples / max(self.total_queries, 1),
            'avg_score': sum(self.score_history) / len(self.score_history),
            'min_score': min(self.score_history),
            'max_score': max(self.score_history),
            'score_std': self._std(self.score_history),
        }

    def _std(self, values: List[float]) -> float:
        """Calculate the population standard deviation (0.0 for < 2 values)."""
        if len(values) < 2:
            return 0.0
        mean = sum(values) / len(values)
        variance = sum((x - mean) ** 2 for x in values) / len(values)
        return variance ** 0.5


# Test the sampler: run one verbose Best-of-5 query and show the winner.
# This calls ollama.chat once per candidate, so it presumably needs a
# reachable Ollama server with LLM_MODEL available - verify before running.
sampler = BestOfNSampler(default_n=5)

# Rebinds the module-level test_prompt used by the reward-model test above.
test_prompt = "What are three benefits of regular exercise?"
best, all_candidates = sampler.sample(test_prompt, verbose=True)

print("\n" + "="*50)
print("BEST RESPONSE:")
print("="*50)
# Only the first 500 characters of the winning response are shown.
print(best.response[:500])

## Solution: Greedy vs Best-of-N Comparison

In [None]:
def compare_greedy_vs_bon(
    prompts: List[str],
    sampler: "BestOfNSampler",
    n: int = 5,
    max_tokens: int = 512,
) -> Dict:
    """
    Solution: Compare greedy decoding vs Best-of-N.

    For every prompt, one greedy response (temperature 0) is generated and
    scored with ``sampler.reward_model``; then ``sampler.sample`` is run with
    ``n`` candidates. Per-prompt results and a summary are printed.

    Args:
        prompts: Prompts to evaluate.
        sampler: Sampler providing ``llm_model``, ``reward_model`` and
            ``sample``. The annotation is quoted so this function can be
            defined before the sampler class.
        n: Number of candidates per Best-of-N run.
        max_tokens: Generation limit applied to both the greedy call and the
            Best-of-N candidates, so the two are compared under equal budgets.

    Returns:
        Dict with ``'greedy'`` and ``'bon'`` entries (each holding parallel
        ``'scores'`` and ``'times'`` lists) and an ``'improvements'`` list of
        ``bon_score - greedy_score`` per prompt. All lists are empty when
        ``prompts`` is empty.
    """
    results = {
        'greedy': {'scores': [], 'times': []},
        'bon': {'scores': [], 'times': []},
        'improvements': [],
    }

    # The summary averages over the prompt count; bail out early instead of
    # dividing by zero when there is nothing to compare.
    if not prompts:
        print("No prompts provided; nothing to compare.")
        return results

    total = len(prompts)

    for index, prompt in enumerate(prompts, start=1):
        print(f"\nPrompt {index}/{total}: {prompt[:40]}...")

        # Greedy (temperature=0, n=1)
        greedy_start = time.time()
        response = ollama.chat(
            model=sampler.llm_model,
            messages=[{"role": "user", "content": prompt}],
            options={"temperature": 0.0, "num_predict": max_tokens}
        )
        greedy_time = time.time() - greedy_start
        greedy_score = sampler.reward_model.score(prompt, response['message']['content'])

        results['greedy']['scores'].append(greedy_score)
        results['greedy']['times'].append(greedy_time)

        # Best-of-N
        bon_start = time.time()
        best, _ = sampler.sample(prompt, n=n, max_tokens=max_tokens)
        bon_time = time.time() - bon_start

        improvement = best.score - greedy_score
        results['bon']['scores'].append(best.score)
        results['bon']['times'].append(bon_time)
        results['improvements'].append(improvement)

        print(f"  Greedy: {greedy_score:.3f} ({greedy_time:.1f}s)")
        print(f"  BoN-{n}: {best.score:.3f} ({bon_time:.1f}s)")
        print(f"  Improvement: {improvement:+.3f}")

    # Summary
    print("\n" + "="*50)
    print("SUMMARY")
    print("="*50)

    avg_greedy = sum(results['greedy']['scores']) / total
    avg_bon = sum(results['bon']['scores']) / total
    avg_improvement = sum(results['improvements']) / total

    print(f"Average Greedy Score: {avg_greedy:.3f}")
    print(f"Average BoN Score:    {avg_bon:.3f}")
    print(f"Average Improvement:  {avg_improvement:+.3f}")

    # A tie (improvement == 0) does not count as a Best-of-N win.
    bon_wins = sum(1 for imp in results['improvements'] if imp > 0)
    print(f"\nBoN outperformed: {bon_wins}/{total} ({bon_wins/total:.0%})")

    return results


# Test prompts for the greedy vs Best-of-N comparison.
test_prompts = [
    "What is the difference between a list and a tuple in Python?",
    "Explain how photosynthesis works.",
    "Describe the process of making coffee.",
]

# Run comparison (uncomment to execute).
# Left disabled because each prompt triggers one greedy generation plus
# n sampled generations (3 prompts x (1 + 5) = 18 model calls here).
# comparison = compare_greedy_vs_bon(test_prompts, sampler, n=5)
print("Uncomment the comparison call to run.")

## Solution: Experiment with Different N Values

In [None]:
def experiment_with_n(
    prompt: str,
    sampler: "BestOfNSampler",
    n_values: Optional[List[int]] = None,
) -> Dict:
    """
    Solution: Experiment with different N values.

    Runs one trial per entry of ``n_values`` on the same prompt. ``n == 1``
    is a single greedy generation (temperature 0); any larger ``n`` goes
    through ``sampler.sample``. A per-N line and a diminishing-returns
    analysis are printed.

    Args:
        prompt: Prompt used for every trial.
        sampler: Sampler providing ``llm_model``, ``reward_model`` and
            ``sample``. The annotation is quoted so this function can be
            defined before the sampler class.
        n_values: Candidate counts to try; defaults to ``[1, 3, 5, 10, 15]``.

    Returns:
        Dict mapping each N to ``best_score``, ``all_scores``, ``mean_score``,
        ``time`` and ``time_per_sample``. Empty when ``n_values`` is empty.

    Raises:
        ValueError: If any entry of ``n_values`` is smaller than 1.
    """
    # None sentinel instead of a mutable list default shared across calls.
    if n_values is None:
        n_values = [1, 3, 5, 10, 15]

    # Reject bad sizes before spending any generation time on earlier entries.
    invalid = [n for n in n_values if n < 1]
    if invalid:
        raise ValueError(f"n_values must all be >= 1, got {invalid!r}")

    results = {}

    print(f"Prompt: {prompt[:50]}...\n")

    for n in n_values:
        start = time.time()

        if n == 1:
            # Greedy
            response = ollama.chat(
                model=sampler.llm_model,
                messages=[{"role": "user", "content": prompt}],
                options={"temperature": 0.0, "num_predict": 512}
            )
            score = sampler.reward_model.score(prompt, response['message']['content'])
            all_scores = [score]
        else:
            best, candidates = sampler.sample(prompt, n=n)
            score = best.score
            all_scores = [c.score for c in candidates]

        elapsed = time.time() - start

        results[n] = {
            'best_score': score,
            'all_scores': all_scores,
            'mean_score': sum(all_scores) / len(all_scores),
            'time': elapsed,
            'time_per_sample': elapsed / n,
        }

        print(f"N={n:2}: best={score:.3f}, mean={results[n]['mean_score']:.3f}, time={elapsed:.1f}s")

    if not results:
        return results

    # Analysis
    print("\n" + "="*50)
    print("DIMINISHING RETURNS ANALYSIS")
    print("="*50)

    # Baseline is the smallest N that was run. That is the greedy N=1 trial
    # when present; otherwise it is the smallest Best-of-N trial, so the
    # analysis no longer requires 1 to be in n_values or to be listed first.
    baseline_n = min(results)
    baseline = results[baseline_n]['best_score']
    for n in sorted(results):
        if n == baseline_n:
            continue
        improvement = results[n]['best_score'] - baseline
        efficiency = improvement / results[n]['time'] if results[n]['time'] > 0 else 0
        # Signed format so a regression prints as "-0.050" rather than "+-0.050".
        print(f"N={n:2}: {improvement:+.3f} improvement, efficiency={efficiency:.4f}/s")

    return results


# Test (uncomment to run).
# Left disabled because it is slow: N=1 makes one greedy generation and every
# other N makes N sampled generations (1 + 3 + 5 + 10 = 19 model calls here).
# experiment_results = experiment_with_n(
#     "Explain the concept of recursion in programming.",
#     sampler,
#     n_values=[1, 3, 5, 10]
# )
print("Uncomment the experiment call to run.")

## Key Takeaways

1. **Temperature > 0** is essential for diverse candidates
2. **N=3-5** usually provides good balance
3. **Diminishing returns** after N~10
4. **Reward model quality** directly impacts selection quality
5. **Time/cost tradeoff**: 5x samples ≈ 5x cost, but often <5x improvement