In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
from collections import defaultdict

# Seed the global NumPy and PyTorch random number generators with a fixed value
# so that repeated runs draw the same random numbers (the sampling methods below
# use both np.random and torch.multinomial)
np.random.seed(42)
torch.manual_seed(42)

class SamplingEvaluator:
    def __init__(self, model_name="gpt2", device="cuda" if torch.cuda.is_available() else "cpu"):
        """
        Initialize the sampling evaluator with a pre-trained model
        
        Loads the tokenizer and the causal language model identified by
        ``model_name``, moves the model to ``device`` and puts it in eval mode.
        
        Args:
            model_name: HuggingFace model identifier
            device: Device to run the model on (cuda or cpu)
        """
        self.device = device
        print(f"Loading model {model_name} on {device}...")
        # Load exactly the model that was requested; the identifier used to be
        # overwritten with a hard-coded one here, which silently ignored the
        # argument and made the log line above misleading.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        self.model.eval()
        print("Model loaded successfully")
        
    def get_next_token_logits(self, input_text):
        """
        Get logits for the next token following the input text
        
        Args:
            input_text: Input text to condition on
            
        Returns:
            logits: Unnormalized log probabilities for next token
            input_ids: Tokenized input ids
        """
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(input_ids)
            next_token_logits = outputs.logits[:, -1, :]
        return next_token_logits, input_ids
    
    def greedy_sampling(self, logits):
        """
        Pick the single most likely token (arg-max over the last dimension)
        
        Args:
            logits: Unnormalized log probabilities
            
        Returns:
            token_id: Id of the highest-scoring token
        """
        best_token = logits.argmax(dim=-1)
        return best_token.item()
    
    
    def test_speculative_sampling(self, prompt, draft_methods=None, target_methods=None, num_trials=5, max_length=30, max_draft_tokens=5):
        """
        Test combinations of sampling methods in a speculative sampling framework
        
        For every (draft method, target method) pair, ``num_trials`` trials are run.
        Each trial first times ``self.generate_sequence`` with the target method as a
        baseline, then times a draft/verify loop. ``self.model`` plays both the draft
        and the target role, so a combination differs only in the sampling strategy
        used on each side.
        
        One round of the loop drafts up to ``max_draft_tokens`` tokens, then checks
        them one by one: a draft token is accepted when it passes the probability
        ratio test AND equals the token the target method samples at that position.
        The first rejection discards the rest of the draft; if nothing was accepted,
        one token is sampled with the target method instead. The loop runs
        ``max_length // max_draft_tokens + 1`` rounds and stops early on an
        end-of-sequence token, so the number of generated tokens is not exactly
        ``max_length``.
        
        Args:
            prompt: Text prompt to start generation
            draft_methods: Dictionary of sampling methods to use for the draft model
            target_methods: Dictionary of sampling methods to use for the target model
            num_trials: Number of sequences to generate for each method combination
            max_length: Maximum sequence length
            max_draft_tokens: Maximum number of speculative tokens to generate in each step
            
        Returns:
            results: Dictionary with evaluation results for each combination, keyed by
                "<draft>-><target>". Token-level metrics, acceptance rate and
                tokens per second are pooled over all trials of the combination.
            
        Raises:
            ValueError: If ``max_draft_tokens`` is smaller than 1 or a method name is unknown
        """
        if max_draft_tokens < 1:
            raise ValueError(f"max_draft_tokens must be at least 1, got {max_draft_tokens}")
        
        if draft_methods is None:
            draft_methods = {
                "temperature": {"temperature": 1.0},
                "top_p": {"p": 0.9},
                "typical": {"mass": 0.9},
                "sde": {"num_outer_steps": 10, "num_inner_steps": 3}
            }
        
        if target_methods is None:
            target_methods = {
                "greedy": {},
                "temperature": {"temperature": 0.7},
                "top_k": {"k": 40},
                "typical": {"mass": 0.95}
            }
        
        results = {}
        eos_token_id = self.tokenizer.eos_token_id
        
        print(f"Testing {len(draft_methods) * len(target_methods)} combinations of speculative sampling...")
        
        for draft_name, draft_params in tqdm(draft_methods.items()):
            for target_name, target_params in target_methods.items():
                combination_name = f"{draft_name}->{target_name}"
                print(f"\nTesting combination: {combination_name}")
                
                texts = []
                baseline_times = []
                speculative_times = []
                total_accepted_tokens = 0
                total_draft_tokens = 0
                total_generated_tokens = 0
                # Token-level statistics are pooled over every trial of this
                # combination (they must not be reset per trial, otherwise only
                # the last trial would be reported)
                token_probs = []
                entropy_values = []
                
                for _ in range(num_trials):
                    # First, run target model alone to get baseline performance
                    start_time_baseline = time.time()
                    self.generate_sequence(
                        prompt, method=target_name, max_length=max_length, **target_params
                    )
                    baseline_times.append(time.time() - start_time_baseline)
                    
                    # Now run speculative sampling with the combination
                    start_time_spec = time.time()
                    input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
                    reached_eos = False
                    
                    for _ in range(max_length // max_draft_tokens + 1):  # Account for batching
                        # Draft phase: generate speculative tokens
                        draft_tokens = []
                        draft_logits_list = []
                        draft_input_ids = input_ids
                        
                        for _ in range(max_draft_tokens):
                            with torch.no_grad():
                                next_token_logits = self.model(draft_input_ids).logits[:, -1, :]
                                draft_logits_list.append(next_token_logits.clone())
                                next_token = self._sample_with_method(
                                    draft_name, next_token_logits, draft_params, "draft"
                                )
                            
                            draft_tokens.append(next_token)
                            next_token_tensor = torch.tensor([[next_token]]).to(self.device)
                            draft_input_ids = torch.cat((draft_input_ids, next_token_tensor), dim=1)
                            
                            # Stop drafting at an end-of-sequence token, so EOS can
                            # only ever be the last token of a draft
                            if next_token == eos_token_id:
                                break
                        
                        total_draft_tokens += len(draft_tokens)
                        if not draft_tokens:
                            break
                        
                        # Verification phase: check draft tokens with target model
                        accepted_draft_tokens = []
                        # Running context for verification. It grows by every
                        # accepted token so that each following draft token is
                        # judged against the full sequence generated so far.
                        verified_ids = input_ids
                        
                        with torch.no_grad():
                            target_logits = self.model(verified_ids).logits[:, -1, :]
                            
                            for i, draft_token in enumerate(draft_tokens):
                                target_token = self._sample_with_method(
                                    target_name, target_logits, target_params, "target"
                                )
                                
                                # Calculate probabilities for acceptance check
                                target_probs = F.softmax(target_logits, dim=-1)
                                draft_probs = F.softmax(draft_logits_list[i], dim=-1)
                                target_prob = target_probs[0, draft_token].item()
                                draft_prob = draft_probs[0, draft_token].item()
                                
                                # Acceptance test (simplified version of the test used in speculative sampling)
                                acceptance_prob = min(1.0, target_prob / (draft_prob + 1e-10))
                                
                                if not (np.random.random() < acceptance_prob and draft_token == target_token):
                                    # Reject this and all remaining draft tokens
                                    break
                                
                                accepted_draft_tokens.append(draft_token)
                                self._record_token_stats(target_logits, token_probs, entropy_values)
                                
                                next_token_tensor = torch.tensor([[draft_token]]).to(self.device)
                                verified_ids = torch.cat((verified_ids, next_token_tensor), dim=1)
                                
                                if draft_token == eos_token_id:
                                    reached_eos = True
                                    break
                                
                                # If there are more draft tokens, get the next target logits
                                if i < len(draft_tokens) - 1:
                                    target_logits = self.model(verified_ids).logits[:, -1, :]
                            
                            # Keep every accepted token
                            input_ids = verified_ids
                            total_accepted_tokens += len(accepted_draft_tokens)
                            total_generated_tokens += len(accepted_draft_tokens)
                            
                            # If no tokens were accepted, generate one token with the target model.
                            # target_logits still belongs to the unchanged input_ids
                            # here, so no extra forward pass is needed.
                            if not accepted_draft_tokens:
                                next_token = self._sample_with_method(
                                    target_name, target_logits, target_params, "target"
                                )
                                self._record_token_stats(target_logits, token_probs, entropy_values)
                                
                                next_token_tensor = torch.tensor([[next_token]]).to(self.device)
                                input_ids = torch.cat((input_ids, next_token_tensor), dim=1)
                                total_generated_tokens += 1
                                
                                if next_token == eos_token_id:
                                    reached_eos = True
                        
                        # Stop generating once an end-of-sequence token is part of the output
                        if reached_eos:
                            break
                    
                    speculative_times.append(time.time() - start_time_spec)
                    
                    # Decode the final sequence
                    texts.append(self.tokenizer.decode(input_ids[0], skip_special_tokens=True))
                
                # Calculate metrics across all trials
                total_speculative_time = sum(speculative_times)
                acceptance_rate = total_accepted_tokens / total_draft_tokens if total_draft_tokens > 0 else 0
                avg_baseline_time = sum(baseline_times) / len(baseline_times) if baseline_times else 0
                avg_speculative_time = total_speculative_time / len(speculative_times) if speculative_times else 0
                speedup = avg_baseline_time / avg_speculative_time if avg_speculative_time > 0 else 0
                tokens_per_second = total_generated_tokens / total_speculative_time if total_speculative_time > 0 else 0
                
                # Calculate token-level metrics
                if token_probs:
                    mean_prob = np.mean(token_probs)
                    mean_entropy = np.mean(entropy_values)
                    perplexity = np.exp(-np.mean(np.log(token_probs)))
                else:
                    mean_prob = 0
                    mean_entropy = 0
                    perplexity = 0
                
                # Calculate diversity metrics (each text is tokenized once and reused)
                encoded_texts = [self.tokenizer.encode(text) for text in texts]
                all_tokens = [token for tokens in encoded_texts for token in tokens]
                unique_token_ratio = len(set(all_tokens)) / len(all_tokens) if all_tokens else 0
                
                # Pairwise diversity: 1 - Jaccard similarity of the token sets of two texts
                token_sets = [set(tokens) for tokens in encoded_texts]
                diversity_scores = []
                for i, tokens_i in enumerate(token_sets):
                    for tokens_j in token_sets[i + 1:]:
                        if not tokens_i or not tokens_j:
                            continue
                        similarity = len(tokens_i & tokens_j) / len(tokens_i | tokens_j)
                        diversity_scores.append(1 - similarity)
                average_diversity = np.mean(diversity_scores) if diversity_scores else 0
                
                results[combination_name] = {
                    "texts": texts,
                    "metrics": {
                        "time_taken": avg_speculative_time,
                        "tokens_per_second": tokens_per_second,
                        "mean_prob": mean_prob,
                        "mean_entropy": mean_entropy,
                        "perplexity": perplexity,
                        "acceptance_rate": acceptance_rate,
                        "speedup_factor": speedup
                    },
                    "diversity": average_diversity,
                    "unique_token_ratio": unique_token_ratio
                }
                
                print(f"Results for {combination_name}:")
                print(f"  Acceptance Rate: {acceptance_rate:.2f}")
                print(f"  Speedup Factor: {speedup:.2f}x")
                if texts:
                    print(f"  Sample output: {texts[0][:100]}...")
        
        return results

    def _sample_with_method(self, method_name, logits, params, role):
        """
        Select one token id from ``logits`` with the sampling method called ``method_name``
        
        Args:
            method_name: Key of the sampling method (e.g. "greedy", "top_k", "typical")
            logits: Unnormalized log probabilities for the next token
            params: Keyword arguments forwarded to the sampling method (ignored for "greedy")
            role: "draft" or "target"; only used to word the error message
            
        Returns:
            token_id: Id of the selected token
            
        Raises:
            ValueError: If ``method_name`` is not a known sampling method
        """
        sampler_names = {
            "greedy": "greedy_sampling",
            "top_k": "top_k_sampling",
            "top_p": "top_p_sampling",
            "temperature": "temperature_sampling",
            "importance": "importance_sampling",
            "rejection": "rejection_sampling",
            "mcmc": "mcmc_sampling",
            "smc": "sequential_monte_carlo",
            "typical": "typical_sampling",
            "sde": "sde_sampling",
            "vesde": "vesde_sampling",
            "eagle": "eagle_sampling",
        }
        if method_name not in sampler_names:
            raise ValueError(f"Unknown {role} sampling method: {method_name}")
        
        # Look the method up only now, so just the samplers that are actually
        # requested have to exist on the instance
        sampler = getattr(self, sampler_names[method_name])
        if method_name == "greedy":
            return sampler(logits)
        return sampler(logits, **params)

    def _record_token_stats(self, logits, token_probs, entropy_values):
        """
        Append the top probability and the entropy of softmax(``logits``) to the running lists
        
        Args:
            logits: Unnormalized log probabilities for one position
            token_probs: List that receives the highest probability of the distribution
            entropy_values: List that receives the entropy of the distribution
        """
        probs = F.softmax(logits, dim=-1)
        top_prob, _ = torch.max(probs, dim=-1)
        token_probs.append(top_prob.item())
        
        # The small constant keeps log() finite for zero probabilities
        log_probs = torch.log(probs + 1e-10)
        entropy = -torch.sum(probs * log_probs, dim=-1)
        entropy_values.append(entropy.item())

    def visualize_speculative_results(self, results):
        """
        Plot and print a comparison of speculative sampling combinations
        
        Draws one bar chart per metric (at most ten combinations each, best first),
        saves the figure to "speculative_sampling_comparison.png" and prints a
        summary table ordered by speedup factor.
        
        Args:
            results: Dictionary with evaluation results
        """
        combinations = list(results.keys())
        
        metrics_to_plot = [
            ("time_taken", "Generation Time (s)"),
            ("tokens_per_second", "Tokens per Second"),
            ("acceptance_rate", "Acceptance Rate"),
            ("speedup_factor", "Speedup Factor"),
            ("diversity", "Output Diversity"),
            ("perplexity", "Perplexity")
        ]
        # Metrics for which a larger value is the better one
        higher_is_better = ["Tokens per Second", "Acceptance Rate", "Speedup Factor", "Output Diversity"]
        
        def lookup(combo_name, key):
            # A metric lives either in the nested "metrics" dict or at the top level
            entry = results[combo_name]
            if key in entry["metrics"]:
                return entry["metrics"][key]
            if key in entry:
                return entry[key]
            return 0
        
        fig, axes = plt.subplots(3, 2, figsize=(16, 16))
        axes = axes.flatten()
        
        for ax, (metric_key, metric_name) in zip(axes, metrics_to_plot):
            metric_values = [lookup(name, metric_key) for name in combinations]
            
            # Best value first: ascending by default, descending when higher is better
            order = np.argsort(metric_values)
            if metric_name in higher_is_better:
                order = order[::-1]
            # Show only the first ten bars to avoid overcrowding
            order = order[:10]
            
            bar_labels = [combinations[pos] for pos in order]
            bar_heights = [metric_values[pos] for pos in order]
            
            ax.bar(bar_labels, bar_heights)
            ax.set_title(metric_name)
            ax.set_xlabel("Sampling Combination (Draft->Target)")
            ax.set_ylabel("Value")
            ax.set_xticklabels(bar_labels, rotation=45, ha="right")
        
        plt.tight_layout()
        plt.savefig("speculative_sampling_comparison.png")
        plt.close()
        
        # Summary table, fastest combination first
        print("\n===== SPECULATIVE SAMPLING COMBINATIONS SUMMARY =====")
        print(f"{'Combination':<20} | {'Speed↑':<10} | {'Accept Rate':<10} | {'Diversity':<10} | {'Perplexity':<10}")
        print("-" * 70)
        
        speedups = [results[name]["metrics"].get("speedup_factor", 0) for name in combinations]
        
        for idx in np.argsort(speedups)[::-1]:
            combo = combinations[idx]
            combo_metrics = results[combo]["metrics"]
            speedup = combo_metrics.get("speedup_factor", 0)
            accept_rate = combo_metrics.get("acceptance_rate", 0)
            diversity = results[combo].get("diversity", 0)
            perplexity = combo_metrics.get("perplexity", 0)
            
            print(f"{combo:<20} | {speedup:<10.2f} | {accept_rate:<10.2f} | {diversity:<10.2f} | {perplexity:<10.2f}")



    def top_k_sampling(self, logits, k=50):
        """
        Sample from the k most likely tokens
        
        The k highest logits are renormalized with a softmax and one token is
        drawn from that restricted distribution.
        
        Args:
            logits: Unnormalized log probabilities
            k: Number of top tokens to consider; values larger than the
                vocabulary size are clamped to the vocabulary size
            
        Returns:
            token_id: Id of the selected token
        """
        # torch.topk raises when k exceeds the size of the dimension, so never
        # ask for more candidates than there are tokens
        k = min(k, logits.size(-1))
        top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)
        probs = F.softmax(top_k_logits, dim=-1)
        # token_idx is a position within the top-k list, not a vocabulary id
        token_idx = torch.multinomial(probs, 1).item()
        return top_k_indices[0, token_idx].item()
    
    def top_p_sampling(self, logits, p=0.9):
        """
        Nucleus sampling - draw from the smallest set of top tokens whose total probability exceeds p
        
        Args:
            logits: Unnormalized log probabilities
            p: Cumulative probability threshold
            
        Returns:
            token_id: Id of the selected token
        """
        ranked_logits, ranked_ids = torch.sort(logits, descending=True)
        cumulative = torch.cumsum(F.softmax(ranked_logits, dim=-1), dim=-1)
        
        # A ranked position is dropped when the mass of the positions before it
        # already exceeds p. Shifting the mask right by one keeps the first token
        # that crosses the threshold; the top token is never dropped.
        over_threshold = cumulative > p
        drop = torch.zeros_like(over_threshold)
        drop[..., 1:] = over_threshold[..., :-1]
        
        # Translate ranked positions back to vocabulary ids and mask them out
        dropped_ids = ranked_ids[drop]
        nucleus_logits = logits.clone()
        nucleus_logits[0, dropped_ids] = -float('Inf')
        
        nucleus_probs = F.softmax(nucleus_logits, dim=-1)
        return torch.multinomial(nucleus_probs, 1).item()
    
    def temperature_sampling(self, logits, temperature=0.7):
        """
        Sample with temperature - lower temperature makes distribution more peaked
        
        Args:
            logits: Unnormalized log probabilities
            temperature: Temperature parameter (0-1)
            
        Returns:
            token_id: Id of the selected token
        """
        if temperature == 0:
            return self.greedy_sampling(logits)
        
        scaled_logits = logits / temperature
        probs = F.softmax(scaled_logits, dim=-1)
        token_id = torch.multinomial(probs, 1).item()
        return token_id
    
    def importance_sampling(self, logits, num_samples=100):
        """
        Importance sampling - draw candidates from a flatter proposal distribution,
        weight each by target/proposal probability and resample one of them
        
        Args:
            logits: Unnormalized log probabilities
            num_samples: Number of candidates drawn from the proposal
            
        Returns:
            token_id: Id of the selected token
        """
        # Proposal: the same logits at a higher temperature, i.e. a broader distribution
        proposal_temp = 1.5
        proposal_probs = F.softmax(logits / proposal_temp, dim=-1)
        target_probs = F.softmax(logits, dim=-1)
        
        # Candidate token ids (drawn with replacement)
        samples = torch.multinomial(proposal_probs, num_samples, replacement=True)
        candidates = samples[0]
        
        # Importance weight of every candidate, gathered in one indexing step and
        # copied into a default (CPU) tensor
        ratios = target_probs[0, candidates] / proposal_probs[0, candidates]
        weights = torch.zeros(num_samples).copy_(ratios)
        weights = weights / weights.sum()
        
        # Resample one candidate according to the normalized weights
        chosen = torch.multinomial(weights, 1).item()
        return samples[0, chosen].item()
    
    def rejection_sampling(self, logits, max_attempts=100):
        """
        Rejection sampling - accept/reject samples based on the ratio to an upper bound
        
        Args:
            logits: Unnormalized log probabilities
            max_attempts: Maximum number of attempts before falling back
            
        Returns:
            token_id: Id of the selected token
        """
        # Target distribution
        target_probs = F.softmax(logits, dim=-1)
        
        # Find the maximum probability as our upper bound
        M = torch.max(target_probs).item() * 1.1  # Add 10% margin
        
        # Uniform proposal distribution over the vocabulary
        vocab_size = logits.shape[-1]
        
        for _ in range(max_attempts):
            # Sample uniformly from vocabulary
            proposal_idx = np.random.randint(0, vocab_size)
            
            # Acceptance probability
            accept_prob = target_probs[0, proposal_idx].item() / M
            
            # Accept or reject
            if np.random.random() < accept_prob:
                return proposal_idx
        
        # Fallback to greedy sampling if max attempts reached
        return self.greedy_sampling(logits)
    
    def mcmc_sampling(self, logits, num_steps=100, init_token=None):
        """
        Markov Chain Monte Carlo sampling over token ids with Metropolis-Hastings style accept/reject steps
        
        Args:
            logits: Unnormalized log probabilities
            num_steps: Number of MCMC steps
            init_token: Token id the chain starts from (a uniformly random id if None)
            
        Returns:
            token_id: Token id the chain is at after the last step
        """
        target_probs = F.softmax(logits, dim=-1).cpu().numpy().flatten()
        vocab_size = len(target_probs)
        last_index = vocab_size - 1
        # Standard deviation of the Gaussian step, a tenth of the vocabulary size
        step_scale = vocab_size / 10
        
        if init_token is None:
            state = np.random.randint(0, vocab_size)
        else:
            state = init_token
        
        for _ in range(num_steps):
            # Gaussian step around the current id, truncated to an int and
            # clamped to the valid id range
            candidate = int(np.random.normal(state, step_scale))
            candidate = min(last_index, max(0, candidate))
            
            # Move when a uniform draw falls below the probability ratio
            ratio = target_probs[candidate] / target_probs[state]
            if np.random.random() < ratio:
                state = candidate
        
        return state
    
    def sequential_monte_carlo(self, logits, num_particles=100, num_steps=5):
        """
        Sequential Monte Carlo (Particle Filtering) sampling

        Particles are token ids. Each step weights the particles by their
        target probability, resamples them in proportion to those weights and
        then perturbs them with integer-truncated Gaussian noise. After the
        last step the particle with the largest weight is returned.

        Args:
            logits: Unnormalized log probabilities
            num_particles: Number of particles (must be at least 1)
            num_steps: Number of resampling steps

        Returns:
            token_id: Id of the selected token, as a plain Python int

        Raises:
            ValueError: If num_particles is smaller than 1
        """
        if num_particles < 1:
            raise ValueError("num_particles must be at least 1")

        probs = F.softmax(logits, dim=-1).cpu().numpy().flatten()
        vocab_size = len(probs)

        def particle_weights(token_ids):
            # Normalized target probability of every particle. When all
            # particles sit on zero-probability tokens (or the sum is not
            # finite) the division would produce NaNs that np.random.choice
            # rejects, so fall back to uniform weights in that case.
            raw = probs[token_ids].astype(np.float64)
            total = raw.sum()
            if not np.isfinite(total) or total <= 0:
                return np.full(len(token_ids), 1.0 / len(token_ids))
            return raw / total

        # Initialize particles randomly
        particles = np.random.randint(0, vocab_size, size=num_particles)

        for _ in range(num_steps):
            weights = particle_weights(particles)

            # Resample particles based on weights
            indices = np.random.choice(num_particles, size=num_particles, p=weights)
            particles = particles[indices]

            # Mutation step. astype(int) truncates toward zero, so only draws
            # with magnitude >= 1 actually move a particle.
            noise = np.random.normal(0, 1, size=num_particles).astype(int)
            particles = np.clip(particles + noise, 0, vocab_size - 1)

        # Return the particle with the highest weight
        weights = particle_weights(particles)
        return int(particles[np.argmax(weights)])
    
    def typical_sampling(self, logits, mass=0.9):
        """
        Typical sampling - restrict sampling to the most "typical" tokens

        Tokens are ranked by how close their information content (-log p) is
        to the entropy of the distribution. Walking that ranking, tokens are
        kept while the running probability total stays at or below ``mass``;
        if even the first token exceeds it, that single token is kept. The
        result is sampled from the renormalized kept set.

        Args:
            logits: Unnormalized log probabilities, shaped [1, vocab_size]
            mass: Probability mass to include

        Returns:
            token_id: Id of the selected token
        """
        probs = F.softmax(logits, dim=-1)
        log_probs = torch.log(probs + 1e-10)

        # Entropy of the distribution and per-token information content
        entropy = -torch.sum(probs * log_probs, dim=-1)
        surprisal = -log_probs

        # Distance of every token from the expected information content
        gap = torch.abs(surprisal - entropy.unsqueeze(-1))

        # Rank tokens from most to least typical
        _, order = torch.sort(gap, dim=-1)
        running_mass = torch.cumsum(probs.gather(-1, order), dim=-1)

        keep_ranked = running_mass <= mass
        if not torch.any(keep_ranked):
            # Nothing fits under the budget: keep the single most typical token
            keep_ranked[0, 0] = True

        # Translate ranked positions back to vocabulary ids
        kept_ids = order.masked_select(keep_ranked)
        allowed = torch.zeros_like(logits, dtype=torch.bool)
        allowed[0, kept_ids] = True

        # Everything outside the kept set gets zero probability
        filtered_logits = logits.masked_fill(~allowed, -float('inf'))

        filtered_probs = F.softmax(filtered_logits, dim=-1)
        return torch.multinomial(filtered_probs, 1).item()
    
    def generate_sequence(self, prompt, method="greedy", max_length=50, **kwargs):
        """
        Generate a sequence using the specified sampling method
        
        Args:
            prompt: Starting text prompt
            method: Sampling method to use
            max_length: Maximum sequence length to generate
            **kwargs: Additional arguments for the sampling method
            
        Returns:
            generated_text: The generated text sequence
            metrics: Dictionary of performance metrics
        """
        start_time = time.time()
        
        # Tokenize the prompt
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        
        # Track token probabilities and entropies for evaluation
        token_probs = []
        entropy_values = []
        tokens_generated = 0
        
        # Generate sequence token by token
        for _ in range(max_length):
            with torch.no_grad():
                # Get logits for next token
                outputs = self.model(input_ids)
                next_token_logits = outputs.logits[:, -1, :]
                
                # Calculate probabilities for metrics
                probs = F.softmax(next_token_logits, dim=-1)
                top_prob, _ = torch.max(probs, dim=-1)
                token_probs.append(top_prob.item())
                
                # Calculate entropy
                log_probs = torch.log(probs + 1e-10)
                entropy = -torch.sum(probs * log_probs, dim=-1)
                entropy_values.append(entropy.item())
                
                # Choose next token based on selected method
                if method == "greedy":
                    next_token = self.greedy_sampling(next_token_logits)
                elif method == "top_k":
                    k = kwargs.get("k", 50)
                    next_token = self.top_k_sampling(next_token_logits, k=k)
                elif method == "top_p":
                    p = kwargs.get("p", 0.9)
                    next_token = self.top_p_sampling(next_token_logits, p=p)
                elif method == "temperature":
                    temp = kwargs.get("temperature", 0.7)
                    next_token = self.temperature_sampling(next_token_logits, temperature=temp)
                elif method == "importance":
                    num_samples = kwargs.get("num_samples", 100)
                    next_token = self.importance_sampling(next_token_logits, num_samples=num_samples)
                elif method == "rejection":
                    max_attempts = kwargs.get("max_attempts", 100)
                    next_token = self.rejection_sampling(next_token_logits, max_attempts=max_attempts)
                elif method == "mcmc":
                    num_steps = kwargs.get("num_steps", 100)
                    next_token = self.mcmc_sampling(next_token_logits, num_steps=num_steps)
                elif method == "smc":
                    num_particles = kwargs.get("num_particles", 100)
                    num_steps = kwargs.get("num_steps", 5)
                    next_token = self.sequential_monte_carlo(next_token_logits, num_particles=num_particles, num_steps=num_steps)
                elif method == "typical":
                    mass = kwargs.get("mass", 0.9)
                    next_token = self.typical_sampling(next_token_logits, mass=mass)
                
                elif method == "sde":
                    num_outer_steps = kwargs.get("num_outer_steps", 20)
                    num_inner_steps = kwargs.get("num_inner_steps", 5)
                    beta_min = kwargs.get("beta_min", 0.1)
                    beta_max = kwargs.get("beta_max", 20.0)
                    next_token = self.sde_sampling(next_token_logits, 
                                                num_outer_steps=num_outer_steps,
                                                num_inner_steps=num_inner_steps,
                                                beta_min=beta_min,
                                                beta_max=beta_max)
                elif method == "vesde":
                    num_outer_steps = kwargs.get("num_outer_steps", 20)
                    num_inner_steps = kwargs.get("num_inner_steps", 5)
                    beta_min = kwargs.get("beta_min", 0.1)
                    beta_max = kwargs.get("beta_max", 20.0)
                    next_token = self.vesde_sampling(next_token_logits,
                                                    num_outer_steps=num_outer_steps,
                                                    num_inner_steps=num_inner_steps,
                                                    beta_min=beta_min,
                                                    beta_max=beta_max)
                elif method == "eagle":
                    draft_scale = kwargs.get("draft_scale", 0.8)
                    max_attempts = kwargs.get("max_attempts", 5)
                    next_token = self.eagle_sampling(next_token_logits, 
                                                draft_scale=draft_scale,
                                                    max_attempts=max_attempts)
                else:
                    raise ValueError(f"Unknown sampling method: {method}")
                
                tokens_generated += 1
                
                # Append the chosen token to input
                next_token_tensor = torch.tensor([[next_token]]).to(self.device)
                input_ids = torch.cat((input_ids, next_token_tensor), dim=1)
                
                # Stop if we generate an end-of-sequence token
                if next_token == self.tokenizer.eos_token_id:
                    break
        
        end_time = time.time()
        
        # Decode the generated sequence
        generated_text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        
        # Calculate metrics
        metrics = {
            "time_taken": end_time - start_time,
            "tokens_per_second": tokens_generated / (end_time - start_time) if (end_time > start_time) else 0,
            "mean_prob": np.mean(token_probs) if token_probs else 0,
            "mean_entropy": np.mean(entropy_values) if entropy_values else 0,
            "sequence_probability": np.prod(token_probs) if token_probs else 0,
            "perplexity": np.exp(-np.mean(np.log(token_probs)) if token_probs else 0)
        }
        
        return generated_text, metrics
    
    def evaluate_methods(self, prompt, methods=None, num_generations=3, max_length=30):
        """
        Evaluate multiple sampling methods on the same prompt

        For every method, ``num_generations`` sequences are generated with
        ``generate_sequence``; their metrics are averaged and two diversity
        measures are computed over the tokenized outputs.

        Args:
            prompt: Text prompt to start generation
            methods: Dictionary mapping method name to its keyword arguments
                (a default set of methods is used when None)
            num_generations: Number of sequences to generate for each method
            max_length: Maximum sequence length

        Returns:
            results: Dictionary keyed by method name. Each entry holds
                "texts" (generated strings), "metrics" (per-metric mean over
                the generations), "unique_token_ratio" (distinct token ids
                divided by total token count) and "diversity" (mean pairwise
                1 - Jaccard similarity of token sets; higher is more diverse)
        """
        if methods is None:
            methods = {
                "greedy": {},
                "top_k": {"k": 50},
                "top_p": {"p": 0.9},
                "temperature": {"temperature": 0.7},
                "importance": {"num_samples": 100},
                "rejection": {"max_attempts": 100},
                "mcmc": {"num_steps": 50},
                "smc": {"num_particles": 50, "num_steps": 3},
                "typical": {"mass": 0.9},
                "sde": {"num_outer_steps": 20, "num_inner_steps": 5},
                "vesde": {"num_outer_steps": 20, "num_inner_steps": 5},
                "eagle": {"draft_scale": 0.8, "max_attempts": 5}
            }

        results = {}

        print(f"Evaluating {len(methods)} sampling methods on prompt: '{prompt}'")

        for method_name, params in tqdm(methods.items()):
            print(f"\nGenerating with {method_name}...")

            texts = []
            run_metrics = []
            for _ in range(num_generations):
                text, metrics = self.generate_sequence(prompt, method=method_name, max_length=max_length, **params)
                texts.append(text)
                run_metrics.append(metrics)

            # Average every metric over the generations (empty when nothing was generated)
            avg_metrics = {}
            if run_metrics:
                avg_metrics = {
                    key: np.mean([m[key] for m in run_metrics])
                    for key in run_metrics[0]
                }

            # Tokenize each text once and reuse the result for both diversity measures
            token_lists = [self.tokenizer.encode(text) for text in texts]
            token_sets = [set(tokens) for tokens in token_lists]

            total_tokens = sum(len(tokens) for tokens in token_lists)
            unique_tokens = set().union(*token_sets)
            unique_token_ratio = len(unique_tokens) / total_tokens if total_tokens > 0 else 0

            # Pairwise diversity: 1 - Jaccard similarity of the token sets
            # (higher is more diverse). Pairs with an empty text are skipped.
            diversity_scores = []
            for i, first in enumerate(token_sets):
                for second in token_sets[i + 1:]:
                    if not first or not second:
                        continue
                    similarity = len(first & second) / len(first | second)
                    diversity_scores.append(1 - similarity)

            results[method_name] = {
                "texts": texts,
                "metrics": avg_metrics,
                "unique_token_ratio": unique_token_ratio,
                "diversity": np.mean(diversity_scores) if diversity_scores else 0
            }

            if texts:
                print(f"Sample output ({method_name}): {texts[0][:100]}...")

        return results
    
    def visualize_results(self, results, prompt_num = 1):
        """
        Visualize the evaluation results

        Draws one bar chart per metric (3x2 grid) comparing the sampling
        methods, saves the figure as "sampling_comparison.png" in the current
        working directory and prints a summary table to stdout.

        Args:
            results: Dictionary with evaluation results, keyed by method name.
                Each entry may hold a "metrics" dictionary plus top-level
                "diversity" and "unique_token_ratio" values; missing values
                are plotted and printed as 0.
            prompt_num: Currently unused; kept so existing callers keep working.
                NOTE(review): every call writes the same file name, so plots
                for several prompts overwrite each other.
        """
        methods = list(results.keys())

        metrics_to_plot = [
            ("time_taken", "Generation Time (s)"),
            ("tokens_per_second", "Tokens per Second"),
            ("mean_entropy", "Mean Entropy"),
            ("diversity", "Output Diversity"),
            ("unique_token_ratio", "Unique Token Ratio"),
            ("perplexity", "Perplexity")
        ]

        def lookup(method, key):
            # Averaged metrics take precedence over top-level entries
            entry = results[method]
            averaged = entry.get("metrics", {})
            if key in averaged:
                return averaged[key]
            return entry.get(key, 0)

        # Create subplots
        fig, axes = plt.subplots(3, 2, figsize=(15, 15))
        axes = axes.flatten()

        # Explicit numeric bar positions so tick labels can be attached to
        # fixed ticks (set_xticklabels on its own is discouraged by Matplotlib)
        positions = list(range(len(methods)))

        for ax, (metric_key, metric_name) in zip(axes, metrics_to_plot):
            metric_values = [lookup(method, metric_key) for method in methods]

            ax.bar(positions, metric_values)
            ax.set_title(metric_name)
            ax.set_xlabel("Sampling Method")
            ax.set_ylabel("Value")
            ax.set_xticks(positions)
            ax.set_xticklabels(methods, rotation=45, ha="right")

        fig.tight_layout()
        fig.savefig("sampling_comparison.png")
        plt.close(fig)

        # Create a summary table
        print("\n===== SAMPLING METHODS EVALUATION SUMMARY =====")
        header = f"{'Method':<12} | {'Time (s)':<10} | {'Tokens/s':<10} | {'Diversity':<10} | {'Perplexity':<10}"
        print(header)
        print("-" * len(header))

        for method in methods:
            averaged = results[method].get("metrics", {})
            time_taken = averaged.get("time_taken", 0)
            tokens_per_sec = averaged.get("tokens_per_second", 0)
            diversity = results[method].get("diversity", 0)
            perplexity = averaged.get("perplexity", 0)

            print(f"{method:<12} | {time_taken:<10.2f} | {tokens_per_sec:<10.2f} | {diversity:<10.2f} | {perplexity:<10.2f}")



    # S3GM inspired sampling 
    def sde_sampling(self, logits, num_outer_steps=20, num_inner_steps=5, beta_min=0.1, beta_max=20.0, **kwargs):
        """
        Implements SDE-based sampling for token selection
        
        Args:
            logits: The logits from which to sample the next token
            num_outer_steps: Number of denoising steps
            num_inner_steps: Number of corrector steps per denoising step
            beta_min: Minimum noise level
            beta_max: Maximum noise level
            
        Returns:
            next_token: The selected token ID
        """
        # Get vocabulary size from logits
        vocab_size = logits.shape[-1]
        
        # Initialize with noise (similar to diffusion model approach)
        x = torch.randn(1, vocab_size).to(self.device)
        
        # Helper functions for SDE sampling
        def get_time_dependent_params(t):
            """Get time-dependent SDE parameters"""
            # Linear interpolation between beta_min and beta_max
            beta_t = beta_min + t * (beta_max - beta_min)
            # For VESDE, the diffusion coefficient is proportional to sqrt of beta
            sigma_t = np.sqrt(beta_t)
            return beta_t, sigma_t
        
        def get_score_model(combined_logits):
            """Convert logits to a score function (gradient of log probability)"""
            probs = F.softmax(combined_logits, dim=-1)
            scores = torch.log(probs + 1e-10)
            return scores
        
        def langevin_corrector(x, t, score_fn, snr=0.1):
            """Langevin dynamics corrector step for refining token probabilities"""
            _, sigma_t = get_time_dependent_params(t)
            
            # Step size based on SNR
            step_size = (snr * sigma_t) ** 2 * 2
            
            # Get score
            score = score_fn(x)
            
            # Langevin dynamics update
            noise = torch.randn_like(x)
            x_new = x + step_size * score + np.sqrt(2 * step_size) * noise
            
            return x_new
        
        # Progressive denoising (similar to reverse diffusion process)
        for i in range(num_outer_steps):
            t = 1.0 - i / (num_outer_steps - 1)  # Time from 1 to 0
            
            # Get score function based on model logits
            def score_fn(x_in):
                # Combine noisy distribution with model logits
                beta_t, _ = get_time_dependent_params(t)
                alpha_t = 1.0 - beta_t
                combined_logits = alpha_t * logits + np.sqrt(beta_t) * x_in
                return get_score_model(combined_logits)
            
            # Apply corrector steps (Langevin dynamics)
            for _ in range(num_inner_steps):
                x = langevin_corrector(x, t, score_fn)
        
        # Final token selection
        combined_logits = logits + 0.01 * x  # Small noise addition for exploration
        probs = F.softmax(combined_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).item()
        
        return next_token


     
    def vesde_sampling(self, logits, num_outer_steps=20, num_inner_steps=5, beta_min=0.1, beta_max=20.0, **kwargs):
        """
        Generate next token using Variance Exploding SDE sampling
        
        Args:
            logits: The logits from which to sample the next token
            num_outer_steps: Number of denoising steps
            num_inner_steps: Number of corrector steps per denoising step
            beta_min: Minimum noise level
            beta_max: Maximum noise level
            
        Returns:
            next_token: The selected token ID
        """
        # Get vocabulary size from logits
        vocab_size = logits.shape[-1]
        
        # Initialize with pure noise
        x = torch.randn(1, vocab_size).to(self.device) * beta_max
        
        # Helper function to get time-dependent parameters
        def get_time_dependent_params(t):
            beta_t = beta_min + t * (beta_max - beta_min)
            sigma_t = np.sqrt(beta_t)
            return beta_t, sigma_t
        
        # Discretized reverse-time SDE
        time_steps = torch.linspace(1.0, 0.0, num_outer_steps).to(self.device)
        for i in range(len(time_steps) - 1):
            t = time_steps[i]
            dt = time_steps[i] - time_steps[i + 1]
            
            # Current noise level
            beta_t, sigma_t = get_time_dependent_params(t.item())
            
            # Convert to tensor for pytorch operations
            beta_t_tensor = torch.tensor(beta_t, device=self.device)
            sigma_t_tensor = torch.tensor(sigma_t, device=self.device)
            
            # Get score estimate (gradient of log probability)
            probs = F.softmax(logits + x / (sigma_t ** 2), dim=-1)
            score = torch.log(probs + 1e-10)
            
            # Drift term
            drift = -0.5 * beta_t_tensor * x
            
            # Diffusion term
            diffusion = torch.sqrt(beta_t_tensor) * torch.randn_like(x)
            
            # Update using Euler-Maruyama discretization
            x = x + (drift + 0.5 * (sigma_t_tensor ** 2) * score) * dt + torch.sqrt(dt) * diffusion
            
            # Apply corrector steps (Langevin dynamics)
            for _ in range(num_inner_steps):
                noise_scale = torch.sqrt(torch.tensor(2.0) * beta_t_tensor * 0.1)  # 0.1 is step size
                noise = torch.randn_like(x) * noise_scale
                
                # Get score
                probs = F.softmax(logits + x / (sigma_t ** 2), dim=-1)
                score = torch.log(probs + 1e-10)
                
                # Update
                x = x + beta_t_tensor * score * 0.1 + noise
        
        # Final selection based on combined logits and noise
        probs = F.softmax(logits + x * 0.01, dim=-1)  # Small noise influence
        next_token = torch.multinomial(probs, num_samples=1).item()
        
        return next_token
    
    # EAGLE sampling from the first suggested sampling research paper
    def eagle_sampling(self, logits, draft_scale=0.8, max_attempts=5, **kwargs):
        """
        EAGLE-inspired draft-and-verify sampling with a simulated draft model

        No separate draft model is used: the draft distribution is obtained
        by scaling the target logits and adding Gaussian noise. Tokens drawn
        from the draft are accepted with probability
        min(1, p_target / p_draft); if none is accepted within
        ``max_attempts`` draws, a token is sampled from the target
        distribution directly.

        Args:
            logits: Logits from the target model, indexed as [0, token_id]
            draft_scale: Scale factor applied to the logits for the simulated draft
            max_attempts: Maximum number of draft tokens to try
            **kwargs: Accepted for call compatibility and ignored

        Returns:
            next_token: The selected token ID
        """
        # Simulated draft model: scaled target logits plus noise
        perturbation = torch.randn_like(logits) * 0.2
        draft_logits = logits * draft_scale + perturbation

        target_probs = F.softmax(logits, dim=-1)
        draft_probs = F.softmax(draft_logits, dim=-1)

        for _attempt in range(max_attempts):
            # Draft phase: propose a token
            candidate = torch.multinomial(draft_probs, num_samples=1).item()

            # Verification phase: compare target and draft probabilities
            p_target = target_probs[0, candidate].item()
            p_draft = draft_probs[0, candidate].item()
            acceptance_prob = min(1.0, p_target / (p_draft + 1e-10))

            if torch.rand(1).item() < acceptance_prob:
                return candidate

        # No draft token was accepted: sample from the target model instead
        return torch.multinomial(target_probs, num_samples=1).item()

import openai
import os
import json
from typing import Dict, List, Any

# class SemanticEvaluator:
#     def __init__(self, openai_api_key: str = None):
#         """
#         Initialize the semantic evaluator with OpenAI API
        
#         Args:
#             openai_api_key (str, optional): OpenAI API key. 
#             If not provided, will try to read from environment variable.
#         """
#         if openai_api_key is None:
#             openai_api_key = os.getenv('OPENAI_API_KEY')
        
#         if not openai_api_key:
#             raise ValueError("OpenAI API key must be provided either as argument or in OPENAI_API_KEY environment variable")
        
#         openai.api_key = openai_api_key

#     def evaluate_response(
#         self, 
#         original_prompt: str, 
#         response: str, 
#         model: str = "gpt-3.5-turbo"
#     ) -> Dict[str, Any]:
#         """
#         Evaluate a response's semantic quality using OpenAI's model
        
#         Args:
#             original_prompt (str): The original input prompt
#             response (str): The generated response to evaluate
#             model (str, optional): OpenAI model to use for evaluation
        
#         Returns:
#             Dict containing evaluation metrics
#         """
#         try:
#             # Construct a detailed evaluation prompt
#             evaluation_prompt = f"""
#             Please evaluate the following response to the original prompt:

#             Original Prompt: "{original_prompt}"
#             Response: "{response}"

#             Provide a detailed evaluation focusing on:
#             1. Grammar (Score 0-10)
#             2. Coherence/Making Sense (Score 0-10)
#             3. Completeness of Answer (Score 0-10)
#             4. Brief explanation for each score

#             Respond in a JSON format:
#             {{
#                 "grammar_score": [0-10],
#                 "coherence_score": [0-10],
#                 "completeness_score": [0-10],
#                 "grammar_explanation": "...",
#                 "coherence_explanation": "...",
#                 "completeness_explanation": "..."
#             }}
#             """

#             # Make API call to OpenAI
#             response = openai.ChatCompletion.create(
#                 model=model,
#                 response_format={"type": "json_object"},
#                 messages=[
#                     {"role": "system", "content": "You are a precise evaluator of text responses."},
#                     {"role": "user", "content": evaluation_prompt}
#                 ],
#                 temperature=0.2  # Low temperature for consistent evaluation
#             )

#             # Parse the JSON response
#             evaluation = json.loads(response.choices[0].message.content)
            
#             # Calculate overall semantic score
#             evaluation['semantic_score'] = (
#                 evaluation.get('grammar_score', 0) + 
#                 evaluation.get('coherence_score', 0) + 
#                 evaluation.get('completeness_score', 0)
#             ) / 3

#             return evaluation

#         except Exception as e:
#             print(f"Error in semantic evaluation: {e}")
#             return {
#                 "error": str(e),
#                 "semantic_score": 0
#             }

#     def batch_evaluate_responses(
#         self, 
#         original_prompt: str, 
#         responses: List[str], 
#         model: str = "gpt-3.5-turbo"
#     ) -> List[Dict[str, Any]]:
#         """
#         Batch evaluate multiple responses
        
#         Args:
#             original_prompt (str): The original input prompt
#             responses (List[str]): List of responses to evaluate
#             model (str, optional): OpenAI model to use for evaluation
        
#         Returns:
#             List of evaluation dictionaries
#         """
#         return [
#             self.evaluate_response(original_prompt, response, model) 
#             for response in responses
#         ]

# def main():
#     # Example usage
#     openai_api_key = os.getenv('OPENAI_API_KEY')  # Make sure to set this environment variable
#     evaluator = SemanticEvaluator(openai_api_key)

#     # Sample prompts and responses
#     prompts_and_responses = [
#         {
#             "prompt": "Explain the concept of artificial intelligence",
#             "responses": [
#                 "AI is a technology that makes computers smart.",
#                 "Artificial Intelligence (AI) is a sophisticated field of computer science that focuses on creating intelligent machines capable of simulating human-like cognitive processes, including learning, problem-solving, perception, and decision-making. It encompasses various subfields like machine learning, neural networks, and deep learning, enabling systems to analyze complex data, recognize patterns, and make autonomous decisions across diverse domains such as healthcare, finance, robotics, and more."
#             ]
#         },
#         {
#             "prompt": "What is the capital of France?",
#             "responses": [
#                 "Paris is the capital.",
#                 "Pari is the captial of Fronce"
#             ]
#         }
#     ]

#     # Evaluate responses
#     for item in prompts_and_responses:
#         print(f"\nPrompt: {item['prompt']}")
#         evaluations = evaluator.batch_evaluate_responses(
#             item['prompt'], 
#             item['responses']
#         )
        
#         for i, (response, evaluation) in enumerate(zip(item['responses'], evaluations), 1):
#             print(f"\nResponse {i}: {response}")
#             print("Evaluation:")
#             for key, value in evaluation.items():
#                 print(f"  {key.replace('_', ' ').title()}: {value}")

class SemanticEvaluator:
    """Grade generated text with an OpenAI chat model acting as a judge.

    For each (prompt, response) pair the judge is asked for grammar,
    coherence and completeness scores (0-10) plus short explanations, returned
    as a JSON object. The three scores are averaged into ``semantic_score``.
    """

    # Keys of the numeric scores that are averaged into ``semantic_score``.
    _SCORE_KEYS = ("grammar_score", "coherence_score", "completeness_score")

    def __init__(self, api_key: str = None):
        """
        Initialize the semantic evaluator with OpenAI API client

        Args:
            api_key (str, optional): OpenAI API key.
            If not provided, will try to read from environment variable.
        """
        if api_key is None:
            api_key = os.getenv('OPENAI_API_KEY')
        self.client = openai.OpenAI(api_key=api_key)
        # The client above already carries the key; the module-level attribute
        # is still set so code relying on that global keeps working.
        openai.api_key = api_key

    @staticmethod
    def _coerce_score(value: Any) -> float:
        """Convert one judge-reported score to ``float``.

        The prompt template displays scores as ``[0-10]``, so a reply may wrap
        the number in a one-element list or quote it as a string. Both forms
        are accepted here; anything else raises ``TypeError``/``ValueError``,
        which the caller reports through its error dictionary.
        """
        if isinstance(value, (list, tuple)) and len(value) == 1:
            value = value[0]
        return float(value)

    def evaluate_response(
        self,
        original_prompt: str,
        response: str,
        model: str = "gpt-4o"
    ) -> Dict[str, Any]:
        """
        Evaluate a response's semantic quality using OpenAI's model

        Args:
            original_prompt (str): The original input prompt
            response (str): The generated response to evaluate
            model (str, optional): OpenAI model to use for evaluation

        Returns:
            Dict with the judge's JSON fields plus ``semantic_score`` (mean of
            the three scores; a missing score counts as 0). If the API call,
            JSON parsing or score conversion fails, the dict instead contains
            ``error`` (the exception text) and ``semantic_score`` of 0.
        """
        try:
            # Ask the judge for a JSON object so the reply can be parsed directly.
            completion = self.client.chat.completions.create(
                model=model,
                response_format={"type": "json_object"},
                messages=[
                    {
                        "role": "system", 
                        "content": "You are a precise evaluator of text responses. Provide a JSON evaluation."
                    },
                    {
                        "role": "user", 
                        "content": f"""
                        Please evaluate the following response to the original prompt:

                        Original Prompt: "{original_prompt}"
                        Response: "{response}"

                        Provide a detailed evaluation focusing on:
                        1. Grammar (Score 0-10)
                        2. Coherence/Making Sense (Score 0-10)
                        3. Completeness of Answer (Score 0-10)
                        4. Brief explanation for each score

                        Respond strictly in this JSON format:
                        {{
                            "grammar_score": [0-10],
                            "coherence_score": [0-10],
                            "completeness_score": [0-10],
                            "grammar_explanation": "...",
                            "coherence_explanation": "...",
                            "completeness_explanation": "..."
                        }}
                        """
                    }
                ],
                temperature=0.2  # Low temperature for consistent evaluation
            )

            # Parse the JSON response
            evaluation = json.loads(completion.choices[0].message.content)

            # Average the three scores. Each one is coerced first so that a
            # reply such as [8] or "8" does not discard the whole evaluation.
            scores = [
                self._coerce_score(evaluation.get(key, 0))
                for key in self._SCORE_KEYS
            ]
            evaluation['semantic_score'] = sum(scores) / len(scores)

            return evaluation

        except Exception as e:
            # Deliberate best-effort boundary: one failed judgement must not
            # abort a whole batch, so the failure is reported in-band.
            print(f"Error in semantic evaluation: {e}")
            return {
                "error": str(e),
                "semantic_score": 0
            }

    def batch_evaluate_responses(
        self,
        original_prompt: str,
        responses: List[str],
        model: str = "gpt-4o"
    ) -> List[Dict[str, Any]]:
        """
        Batch evaluate multiple responses

        Responses are judged one at a time, in order, with one API request
        per response.

        Args:
            original_prompt (str): The original input prompt
            responses (List[str]): List of responses to evaluate
            model (str, optional): OpenAI model to use for evaluation

        Returns:
            List of evaluation dictionaries, aligned with ``responses``
        """
        return [
            self.evaluate_response(original_prompt, response, model)
            for response in responses
        ]

def main():
    """Compare sampling methods on a set of prompts and judge their outputs.

    For every prompt, each configured sampling method generates text through
    ``SamplingEvaluator.evaluate_methods``; the results are visualised, printed,
    and then scored by ``SemanticEvaluator``. Afterwards a speculative-sampling
    experiment is run for several draft/target method combinations and
    visualised.
    """
    # Local import: only this entry point reads the environment.
    import os

    # Initialize the evaluator
    evaluator = SamplingEvaluator(model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    # Prompts cover an open-ended continuation, factual and arithmetic
    # questions, a deliberately misspelled prompt, small talk, and a long
    # context paragraph.
    prompts = [
        "Artificial intelligence has transformed the way we",
        "What is the capital of France? ",
        "What is 2+5?",
        "artifitial  intelige has chang the way we",
        "hey , how are you?",
        (
            "Artificial intelligence has transformed the way we live, work, and interact by revolutionizing countless industries and enhancing daily experiences. "
            "From enabling personalized recommendations in entertainment and shopping to streamlining operations in healthcare and finance, AI has reshaped efficiency and decision-making. It powers autonomous vehicles, advances scientific research, and augments human creativity with tools for art, writing, and design. Moreover, AI fosters innovation by automating complex tasks, analyzing vast datasets, and uncovering patterns that were once beyond human reach."
            " This transformative technology is seamlessly integrating into our lives, shaping a smarter, more connected future."
        ),
    ]

    openai_api_key = os.getenv('OPENAI_API_KEY')  # Make sure to set this environment variable
    # Report only whether the key exists. Printing the key itself or the whole
    # environment would leak secrets into logs and saved notebook output.
    if openai_api_key:
        print("OPENAI_API_KEY is set")
    else:
        print("OPENAI_API_KEY is not set")

    response_evaluator = SemanticEvaluator(openai_api_key)

    for prompt in prompts:
        # Experiment ideas:
        #   - creative tasks such as coming up with a joke
        #   - mathematical questions
        #   - badly formed prompts (grammar, vagueness, ...)
        #   - single-answer vs. long-answer prompts
        #   - formal vs. informal tone
        #   - longer vs. shorter prompts

        # Define methods to evaluate with parameters. Built inside the loop so
        # every prompt gets a fresh dict, as before.
        methods = {
            "greedy": {},
            "top_k": {"k": 50},
            "top_p": {"p": 0.9},
            "temperature": {"temperature": 0.7},
            "importance": {"num_samples": 100},
            "rejection": {"max_attempts": 100},
            "mcmc": {"num_steps": 50},
            "smc": {"num_particles": 50, "num_steps": 3},
            "typical": {"mass": 0.9},
            "sde": {"num_outer_steps": 20, "num_inner_steps": 5},
            "vesde": {"num_outer_steps": 20, "num_inner_steps": 5},
            "eagle": {"draft_scale": 0.8, "max_attempts": 5}
        }

        # Evaluate all methods
        results = evaluator.evaluate_methods(prompt, methods, num_generations=3)
        # Visualize results
        evaluator.visualize_results(results)

        # Print a sample of each method's output and collect the texts to judge.
        # Only this prompt's items are collected: judging an accumulated list
        # here would re-send every earlier prompt to the API on each iteration.
        print("\n===== SAMPLE OUTPUTS =====")
        prompt_items = []
        for method, result in results.items():
            print(f"\n{method.upper()}:")
            print(result["texts"][0])
            print(result)
            print(result["texts"])

            prompt_items.append({
                "prompt": prompt,
                # Pass the texts themselves (presumably a list of strings —
                # verify against evaluate_methods). Wrapping them in another
                # list made the judge receive one list object instead of each
                # generated string.
                "responses": list(result["texts"]),
            })

        for item in prompt_items:
            print(f"\nPrompt: {item['prompt']}")
            evaluations = response_evaluator.batch_evaluate_responses(
                item['prompt'],
                item['responses']
            )

            for i, (response, evaluation) in enumerate(zip(item['responses'], evaluations), 1):
                print(f"\nResponse {i}: {response}")
                print("Evaluation:")
                for key, value in evaluation.items():
                    print(f"  {key.replace('_', ' ').title()}: {value}")

    # Define methods to test as combination
    draft_methods = {
        # "temperature": {"temperature": 1.0},
        "top_p": {"p": 0.9},
        "mcmc": {"num_steps": 50}
    }

    target_methods = {
        "greedy": {},
        # "typical": {"mass": 0.95},
        "mcmc": {"num_steps": 50}
    }

    # Run the test
    results = evaluator.test_speculative_sampling(
        prompt="Once upon a time",
        draft_methods=draft_methods,
        target_methods=target_methods,
        num_trials=3
    )

    # Visualize results
    evaluator.visualize_speculative_results(results)
# Reference link (Yellowbrick token frequency distribution docs):
# https://www.scikit-yb.org/en/latest/api/text/freqdist.html#token-frequency-distribution
# NOTE(review): no use of it is visible nearby — presumably a pointer for
# future text analysis; confirm or remove.
# Run the experiments only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

Loading model TinyLlama/TinyLlama-1.1B-Chat-v1.0 on cpu...
Model loaded successfully
look the api key is 
Evaluating 12 sampling methods on prompt: 'Artificial intelligence has transformed the way we'


  0%|          | 0/12 [00:00<?, ?it/s]


Generating with greedy...


  8%|▊         | 1/12 [01:31<16:44, 91.27s/it]

Sample output (greedy): Artificial intelligence has transformed the way we live, work, and communicate. It has enabled us to...

Generating with top_k...


 17%|█▋        | 2/12 [02:55<14:31, 87.11s/it]

Sample output (top_k): Artificial intelligence has transformed the way we live and work for good, it...
The Impact of Artif...

Generating with top_p...


 25%|██▌       | 3/12 [04:24<13:10, 87.83s/it]

Sample output (top_p): Artificial intelligence has transformed the way we communicate, both in social media and on platform...

Generating with temperature...


 33%|███▎      | 4/12 [05:54<11:49, 88.73s/it]

Sample output (temperature): Artificial intelligence has transformed the way we communicate, shop, and learn. It has enabled us t...

Generating with importance...


 42%|████▏     | 5/12 [07:24<10:24, 89.26s/it]

Sample output (importance): Artificial intelligence has transformed the way we conduct various business and entertainment activi...

Generating with rejection...


 50%|█████     | 6/12 [08:51<08:50, 88.47s/it]

Sample output (rejection): Artificial intelligence has transformed the way we live, work, and communicate. It has enabled us to...

Generating with mcmc...
