# In [None]:
import os
import json
import numpy as np
import random
import time
import re
import tiktoken
import google.generativeai as genai
from tqdm import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from sentence_transformers import SentenceTransformer
import math
from collections import defaultdict
from openai import OpenAI
import math


# --- Feature dimensions and cascade settings ---
# Dimension of the question embedding.
BASE_FEATURE_DIM = 384
# Dimension of an answer embedding.
ANSWER_EMBED_DIM = 384
# K: number of attempts in the cascade.
CASCADE_LENGTH = 4
# Update the JSON records after every question.
UPDATE_FREQUENCY = 1
# Use sentence embeddings instead of TF-IDF features.
USE_EMBEDDINGS = True

# --- LinUCB parameters ---
# Exploration-exploitation trade-off parameter.
ALPHA = 0.675
# Regularization parameter used when initializing the LinUCB matrices.
LAMBDA_REG = 0.45

# --- Value density & cost estimation parameters ---
# Width multiplier of the LCB interval for the LLM output-cost estimate.
BETA_COST = 1.0
# Small positive value used in the value-density calculation.
EPSILON_VD = 1e-10
# Cost assumed when there is not enough data for an estimate.
DEFAULT_FALLBACK_COST = 0.0001

# In [None]:
# Credentials and endpoint are read from the environment so that the file is
# valid Python and no secret has to be pasted into it.  Export
# OPENROUTER_API_KEY (and, optionally, OPENROUTER_BASE_URL) before running.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_BASE_URL = os.environ.get("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")

# Per-token prices for every candidate model.  The figures are written per
# million tokens and divided by 1e6 to obtain a per-token cost
# (presumably USD -- confirm against the provider's price list).
MODELS_CONFIG = {
    "mistralai/mistral-small-3.1-24b-instruct": {"input_cost": 0.05 / 1e6, "output_cost": 0.15 / 1e6},
    "microsoft/phi-4": {"input_cost": 0.07 / 1e6, "output_cost": 0.14 / 1e6},
    "meta-llama/llama-4-maverick": {"input_cost": 0.17 / 1e6, "output_cost": 0.16 / 1e6},
    "google/gemini-2.0-flash-001": {"input_cost": 0.1 / 1e6, "output_cost": 0.4 / 1e6},
    "openai/gpt-4.1-nano": {"input_cost": 0.1 / 1e6, "output_cost": 0.4 / 1e6},
    "deepseek/deepseek-chat": {"input_cost": 0.38 / 1e6, "output_cost": 0.89 / 1e6},
}

# In [None]:
# Model asked to judge whether an attempt matches the reference answer.
GRADER_MODEL_NAME = "google/gemini-2.0-flash-lite-001"

# Input dataset and the artefacts written by this run (all share the
# "M92V2" stem).
INPUT_JSON = "Math500.json"
RECORDS_PATH = "M92V2.json"
LINUCB_MODEL_PATH = "M92V2.npz"
SUMMARY_STATS_PATH = "M92V2.txt"
# Suffix appended to a file name when an existing file is backed up.
BACKUP_SUFFIX = ".bak"

# Candidate models, in MODELS_CONFIG insertion order; one one-hot slot each.
AVAILABLE_LLMS = [*MODELS_CONFIG]
LLM_ID_DIM = len(AVAILABLE_LLMS)

# Context = last-answer embedding + one-hot of the last LLM + mean-answer embedding.
CONTEXT_FEATURE_DIM = 2 * ANSWER_EMBED_DIM + LLM_ID_DIM
# Full vector = question embedding + normalized step + context.
TOTAL_FEATURE_DIM = BASE_FEATURE_DIM + 1 + CONTEXT_FEATURE_DIM

# Shared OpenAI-compatible client pointed at OpenRouter.
# You might need to add default headers if required by your OpenRouter setup/account:
# default_headers={"HTTP-Referer": "YOUR_SITE_URL", "X-Title": "YOUR_APP_NAME"}
openrouter_client = OpenAI(
    api_key=OPENROUTER_API_KEY,
    base_url=OPENROUTER_BASE_URL,
)

class ModelCostEstimator:
    """Running statistics of observed output cost per model.

    For every model the estimator keeps a count, a sum and a sum of squares of
    the observed output costs, both overall and per cascade position
    ("1" .. str(CASCADE_LENGTH)).  From these it derives a lower confidence
    bound (LCB) on the output cost.
    """

    def __init__(self, model_names):
        """Create empty statistics for every name in ``model_names``."""
        self.stats = {
            model_name: {
                "overall": self._empty_bucket(),
                "by_position": {
                    str(i): self._empty_bucket()
                    for i in range(1, CASCADE_LENGTH + 1)
                },
            }
            for model_name in model_names
        }

    @staticmethod
    def _empty_bucket():
        """Return a fresh accumulator with zero observations."""
        return {"count": 0, "total_output_cost": 0.0, "sum_sq_output_cost": 0.0}

    @staticmethod
    def _accumulate(bucket, cost):
        """Add one observed ``cost`` to ``bucket`` in place."""
        bucket["count"] += 1
        bucket["total_output_cost"] += cost
        bucket["sum_sq_output_cost"] += cost ** 2

    @staticmethod
    def _lcb_from_bucket(bucket, beta):
        """Return the LCB for ``bucket``, or None when it has fewer than 2 samples."""
        count = bucket["count"]
        if count < 2:
            return None
        mean_cost = bucket["total_output_cost"] / count
        # E[x^2] - E[x]^2 can dip below zero through rounding; clamp it.
        variance = max(0, (bucket["sum_sq_output_cost"] / count) - (mean_cost ** 2))
        std_error = math.sqrt(variance) / math.sqrt(count)
        # Keep the estimate strictly positive so it can be used as a divisor.
        return max(EPSILON_VD / 10, mean_cost - beta * std_error)

    def update_stats(self, model_name, step, actual_output_cost):
        """Record ``actual_output_cost`` for ``model_name`` at cascade ``step``.

        Unknown models are ignored.  A step outside the tracked positions only
        updates the model's overall statistics.
        """
        model_stats = self.stats.get(model_name)
        if model_stats is None:
            return
        position_bucket = model_stats["by_position"].get(str(step))
        if position_bucket is not None:
            self._accumulate(position_bucket, actual_output_cost)
        self._accumulate(model_stats["overall"], actual_output_cost)

    def estimate_lcb_output_cost(self, model_name, step, beta=BETA_COST, default_cost=DEFAULT_FALLBACK_COST):
        """Estimate the Lower Confidence Bound (LCB) of output cost for a model at a specific step.

        Position-specific statistics are preferred; the model's overall
        statistics are the fallback; ``default_cost`` is returned when neither
        has at least two observations (or the model is unknown).
        """
        model_stats = self.stats.get(model_name)
        if model_stats is not None:
            candidates = (model_stats["by_position"].get(str(step)), model_stats["overall"])
            for bucket in candidates:
                if bucket is None:
                    continue
                lcb_cost = self._lcb_from_bucket(bucket, beta)
                if lcb_cost is not None:
                    return lcb_cost
        return default_cost

    def save_state(self, file_path):
        """Write all statistics to ``file_path`` as a compressed .npz archive.

        Returns True on success, False (after printing the error) otherwise.
        """
        try:
            save_dict = {}
            for model_name, model_stats in self.stats.items():
                overall = model_stats["overall"]
                save_dict[f"overall_count_{model_name}"] = overall["count"]
                save_dict[f"overall_total_output_cost_{model_name}"] = overall["total_output_cost"]
                save_dict[f"overall_sum_sq_output_cost_{model_name}"] = overall["sum_sq_output_cost"]
                for pos, pos_stats in model_stats["by_position"].items():
                    save_dict[f"pos{pos}_count_{model_name}"] = pos_stats["count"]
                    save_dict[f"pos{pos}_total_output_cost_{model_name}"] = pos_stats["total_output_cost"]
                    save_dict[f"pos{pos}_sum_sq_output_cost_{model_name}"] = pos_stats["sum_sq_output_cost"]
            np.savez_compressed(file_path, **save_dict)
            return True
        except Exception as e:
            print(f"Error saving model cost estimator state: {e}")
            return False

    def load_state(self, file_path):
        """Restore statistics previously written by ``save_state``.

        Only models already tracked by this instance are restored; entries
        missing from the archive keep their current values.  Returns True on
        success, False (after printing the error) otherwise.
        """
        try:
            # The context manager closes the underlying archive file handle.
            with np.load(file_path) as loaded:
                for model_name, model_stats in self.stats.items():
                    if f"overall_count_{model_name}" in loaded:
                        overall = model_stats["overall"]
                        overall["count"] = int(loaded[f"overall_count_{model_name}"])
                        overall["total_output_cost"] = float(loaded[f"overall_total_output_cost_{model_name}"])
                        overall["sum_sq_output_cost"] = float(loaded[f"overall_sum_sq_output_cost_{model_name}"])
                    for pos, pos_stats in model_stats["by_position"].items():
                        if f"pos{pos}_count_{model_name}" in loaded:
                            pos_stats["count"] = int(loaded[f"pos{pos}_count_{model_name}"])
                            pos_stats["total_output_cost"] = float(loaded[f"pos{pos}_total_output_cost_{model_name}"])
                            pos_stats["sum_sq_output_cost"] = float(loaded[f"pos{pos}_sum_sq_output_cost_{model_name}"])
            print(f"Loaded model cost estimator state from {file_path}")
            return True
        except Exception as e:
            print(f"Error loading model cost estimator state: {e}")
            return False


class FeatureExtractor:
    """Build feature vectors for questions, answers and cascade history.

    Uses a sentence-transformer embedding model when ``use_embeddings`` is
    true and the model can be created; otherwise falls back to TF-IDF
    followed by truncated SVD for question features.
    """

    def __init__(self, feature_dim=BASE_FEATURE_DIM, use_embeddings=USE_EMBEDDINGS):
        """Set up the embedding model, or the TF-IDF vectorizer as a fallback."""
        self.feature_dim = feature_dim
        self.use_embeddings = use_embeddings
        if use_embeddings:
            try:
                self.embedding_model = SentenceTransformer('BAAI/bge-small-en-v1.5')
                print("Initialized sentence transformer embedding model")
            except Exception as e:
                print(f"Error initializing sentence transformer: {e}\nFalling back to TF-IDF.")
                self.use_embeddings = False
        if not self.use_embeddings:
            self.vectorizer = TfidfVectorizer()
            self.svd = None  # fitted in initialize()
        self.initialized = False

    def initialize(self, questions):
        """Fit the TF-IDF/SVD pipeline on ``questions`` (no-op fit in embedding mode).

        ``questions`` is an iterable of dicts with a "question" key.  Must be
        called before ``extract_features``.
        """
        if not self.use_embeddings:
            from sklearn.decomposition import TruncatedSVD
            all_text = [q["question"] for q in questions]
            self.vectorizer.fit(all_text)
            dtm = self.vectorizer.transform(all_text)
            # SVD cannot produce more components than there are TF-IDF terms.
            n_components = min(self.feature_dim, dtm.shape[1])
            self.svd = TruncatedSVD(n_components=n_components)
            self.svd.fit(dtm)
            print(f"Using TF-IDF with SVD dimensionality reduction to {n_components} features")
            print(f"Explained variance ratio: {sum(self.svd.explained_variance_ratio_):.4f}")
        self.initialized = True

    def extract_features(self, question):
        """Extract features from a question (no options for Math500).

        Raises:
            ValueError: if ``initialize`` has not been called.
        """
        if not self.initialized:
            raise ValueError("Feature extractor not initialized. Call initialize() first.")
        text = question["question"]
        if self.use_embeddings:
            return self.embedding_model.encode([text])[0]

        tfidf_vector = self.vectorizer.transform([text])
        features = self.svd.transform(tfidf_vector)[0]
        # Zero-pad when SVD produced fewer than feature_dim components.
        if len(features) < self.feature_dim:
            padding = np.zeros(self.feature_dim - len(features))
            features = np.concatenate([features, padding])
        return features

    def extract_answer_features(self, answer_text):
        """Return an ANSWER_EMBED_DIM-long vector for ``answer_text``.

        A zero vector is returned for empty text, when embeddings are disabled,
        or when embedding fails.
        """
        if not answer_text or not self.use_embeddings:
            # Without an embedding model answers carry no signal: use zeros.
            return np.zeros(ANSWER_EMBED_DIM)
        try:
            features = self.embedding_model.encode([answer_text])[0]
            if features is None:
                print("Warning: Embedding for answer text resulted in None. Returning zero vector.")
                return np.zeros(ANSWER_EMBED_DIM)
            # Truncate or zero-pad if the model's width differs from ANSWER_EMBED_DIM.
            if len(features) > ANSWER_EMBED_DIM:
                features = features[:ANSWER_EMBED_DIM]
            elif len(features) < ANSWER_EMBED_DIM:
                features = np.concatenate([features, np.zeros(ANSWER_EMBED_DIM - len(features))])
            return features
        except Exception as e:
            print(f"Error embedding answer: {e}")
            return np.zeros(ANSWER_EMBED_DIM)

    def construct_feature_vector(self, base_features, step_i, failed_answers, failed_llm_ids, model_name_to_index):
        """
        Construct the augmented feature vector for LinUCB.
        - base_features: Features extracted from the question
        - step_i: Current step in the cascade (1-indexed)
        - failed_answers: List of previous failed answer strings
        - failed_llm_ids: List of previous failed LLM IDs
        - model_name_to_index: Mapping from model names to indices

        Layout: [base_features | step_i / CASCADE_LENGTH | last failed answer
        embedding | one-hot of last failed LLM | mean failed answer embedding].
        History parts are all zeros at step 1 or when there is no history.

        Raises:
            ValueError: if the result is not TOTAL_FEATURE_DIM long.
        """
        normalized_step = np.array([step_i / CASCADE_LENGTH])
        last_answer_features = np.zeros(ANSWER_EMBED_DIM)
        avg_answer_features = np.zeros(ANSWER_EMBED_DIM)
        last_llm_onehot = np.zeros(LLM_ID_DIM)

        if step_i > 1:
            if failed_answers:
                # Embed every failed answer exactly once; the last embedding is
                # reused instead of running the encoder on it a second time.
                answer_embeddings = [self.extract_answer_features(ans) for ans in failed_answers]
                last_answer_features = answer_embeddings[-1]
                avg_answer_features = np.mean(answer_embeddings, axis=0)
            if failed_llm_ids:
                last_llm_name = failed_llm_ids[-1]
                if last_llm_name in model_name_to_index:
                    last_llm_onehot[model_name_to_index[last_llm_name]] = 1.0

        context_features = np.concatenate([last_answer_features, last_llm_onehot, avg_answer_features])
        augmented_features = np.concatenate([base_features, normalized_step, context_features])

        # Ensure the final dimension matches expectations
        if augmented_features.shape[0] != TOTAL_FEATURE_DIM:
            raise ValueError(f"Constructed feature vector dimension {augmented_features.shape[0]} != expected {TOTAL_FEATURE_DIM}")
        return augmented_features


class LinUCBModel:
    """One LinUCB arm (matrix ``A`` and vector ``b``) per candidate LLM.

    Also owns a ``ModelCostEstimator`` used to estimate each model's cost
    during budget-aware (test phase) selection.
    """

    def __init__(self, model_names, feature_dim=TOTAL_FEATURE_DIM, alpha=ALPHA, lambda_reg=LAMBDA_REG):
        """Initialise every arm with A = lambda_reg * I and b = 0."""
        self.model_names = model_names
        self.feature_dim = feature_dim
        self.alpha = alpha
        self.lambda_reg = lambda_reg
        self.model_name_to_index = {name: i for i, name in enumerate(model_names)}
        self.cost_estimator = ModelCostEstimator(model_names)
        self.models = {
            model_name: {
                'A': np.identity(feature_dim) * lambda_reg,
                'b': np.zeros(feature_dim),
                'last_call_time': 0
            } for model_name in model_names
        }

    def update_reward_only(self, model_name, feature_vector, reward):
        """Update the LinUCB model parameters based on observed reward."""
        model = self.models[model_name]
        model['A'] += np.outer(feature_vector, feature_vector)
        model['b'] += feature_vector * reward

    def calculate_ucb_scores(self, feature_vector):
        """Calculate UCB scores for model selection.

        Returns a dict mapping each model name to ``p_ia`` (expected reward),
        ``e_ia`` (exploration bonus) and ``ucb_score`` (their sum).  An arm
        whose matrix cannot be inverted gets all zeros.
        """
        scores = {}
        for model_name in self.model_names:
            model = self.models[model_name]
            try:
                A_inv = np.linalg.inv(model['A'])
            except np.linalg.LinAlgError:
                scores[model_name] = {"p_ia": 0.0, "e_ia": 0.0, "ucb_score": 0.0}
                continue
            theta = A_inv.dot(model['b'])
            # x^T A^-1 x is non-negative in exact arithmetic; clamp rounding
            # noise so the square root cannot produce NaN.
            quad_form = max(0.0, float(feature_vector.dot(A_inv).dot(feature_vector)))
            ucb_term = self.alpha * np.sqrt(quad_form)
            expected_reward = feature_vector.dot(theta)
            scores[model_name] = {
                "p_ia": float(expected_reward),
                "e_ia": float(ucb_term),
                "ucb_score": float(expected_reward + ucb_term)
            }
        return scores

    def select_model_value_density(self, feature_vector, step, prompt, remaining_budget, is_test_phase, reward_importance, cost_importance, used_llms_in_current_cascade=None):
        """Select a model: pure UCB in the train phase, cost-normalized UCB in the test phase.

        Args:
            feature_vector: context vector for this step.
            step: 1-indexed cascade position (used for the cost estimate).
            prompt: prompt text; its length / 4 approximates the input tokens.
            remaining_budget: budget left for this question (test phase only).
            is_test_phase: selects the budget-aware strategy when true.
            reward_importance, cost_importance: accepted for interface
                compatibility; not used by the selection logic.
            used_llms_in_current_cascade: models that may not be picked again.

        Returns:
            ``(best_model, scores)`` where ``best_model`` is None when no model
            is selectable and ``scores`` holds per-model diagnostics.
        """
        scores = self.calculate_ucb_scores(feature_vector)
        if not scores:
            return None, {}
        used_llms = [] if used_llms_in_current_cascade is None else used_llms_in_current_cascade

        # Rough token count: about four characters per token.
        approx_prompt_tokens = len(prompt) // 4
        for model_name, score_info in scores.items():
            try:
                input_price = MODELS_CONFIG[model_name]["input_cost"]
            except (KeyError, TypeError):
                print(f"Warning: Input cost configuration missing for {model_name}. Using default.")
                input_price = DEFAULT_FALLBACK_COST
            deterministic_input_cost = approx_prompt_tokens * input_price
            lcb_output_cost_estimate = self.cost_estimator.estimate_lcb_output_cost(model_name, step)
            score_info['total_estimated_cost'] = float(deterministic_input_cost + lcb_output_cost_estimate)

        if is_test_phase:
            best_model = self._select_budget_aware(scores, used_llms, remaining_budget)
        else:
            best_model = self._select_pure_ucb(scores, used_llms)
        return best_model, scores

    @staticmethod
    def _select_pure_ucb(scores, used_llms):
        """Train phase: pick the unused model with the highest UCB score (annotates ``scores``)."""
        best_model = None
        max_selection_score = -float('inf')
        for model_name, score_info in scores.items():
            score_info['budget_sufficient'] = True  # budget is not enforced in training
            if model_name in used_llms:
                selection_score = -float('inf')
                score_info['selection_strategy'] = 'Prohibited (Train - Re-selection)'
            else:
                selection_score = score_info["ucb_score"]
                score_info['selection_strategy'] = 'Pure UCB (Train)'
            score_info['selection_score'] = float(selection_score)
            if selection_score > max_selection_score:
                max_selection_score = selection_score
                best_model = model_name
        return best_model

    @staticmethod
    def _select_budget_aware(scores, used_llms, remaining_budget):
        """Test phase: drop used/unaffordable models, then maximise UCB / cost-based denominator."""
        # Stage A: keep models that are unused and whose estimated cost fits the budget.
        eligible = []
        for model_name, score_info in scores.items():
            score_info['remaining_budget'] = float(remaining_budget)
            if model_name in used_llms:
                score_info['selection_strategy'] = 'Prohibited (Test - Re-selection)'
                score_info['selection_score'] = -float('inf')
            elif score_info['total_estimated_cost'] <= remaining_budget:
                score_info['budget_sufficient'] = True
                eligible.append((model_name, score_info['total_estimated_cost'], score_info))
            else:
                score_info['budget_sufficient'] = False
                score_info['selection_strategy'] = 'Budget Pruned (Test)'
                score_info['selection_score'] = -float('inf')

        if not eligible:
            return None  # everything was pruned by budget or re-selection

        # Stage B: the largest eligible cost normalises the others (never zero).
        max_eligible_cost = max(max(cost for _, cost, _ in eligible), EPSILON_VD)

        # Stage C: score = UCB / (1 + (1 + relative cost)); highest wins.
        best_model = None
        max_selection_score = -float('inf')
        for model_name, model_cost, score_info in eligible:
            ucb_score = score_info["ucb_score"]
            cost_norm_factor = 1 + ((model_cost + EPSILON_VD) / (max_eligible_cost + EPSILON_VD))
            stable_denominator = 1.0 + max(cost_norm_factor, EPSILON_VD)
            selection_score = ucb_score / stable_denominator

            score_info['selection_strategy'] = 'Double/Norm (Test)'
            score_info['selection_score_terms'] = {
                'ucb_score_numerator': float(ucb_score),
                'stable_denominator_used': float(stable_denominator),
            }
            score_info['selection_score'] = float(selection_score)

            if selection_score > max_selection_score:
                max_selection_score = selection_score
                best_model = model_name
        return best_model

    def register_model_call(self, model_name):
        """Record the current time as the last call time of ``model_name``."""
        self.models[model_name]['last_call_time'] = time.time()

    def respect_rate_limit(self, model_name):
        """Wait if necessary to respect the model's rate limit.

        Only models whose MODELS_CONFIG entry defines an "rpm" value are
        throttled; for all others this returns immediately.
        """
        model_cfg = MODELS_CONFIG.get(model_name)
        if not model_cfg or "rpm" not in model_cfg:
            return
        min_seconds_between_calls = 60.0 / model_cfg["rpm"]
        time_since_last_call = time.time() - self.models[model_name]['last_call_time']
        if time_since_last_call < min_seconds_between_calls:
            time.sleep(min_seconds_between_calls - time_since_last_call)

    def save_model_state(self, file_path):
        """Save the model state to a file using numpy's compressed format.

        The cost estimator is written next to it, with '.npz' replaced by
        '_cost_estimator.npz' in ``file_path``.
        """
        save_dict = {}
        for model_name, model in self.models.items():
            save_dict[f'A_{model_name}'] = model['A']
        for model_name, model in self.models.items():
            save_dict[f'b_{model_name}'] = model['b']
        np.savez_compressed(file_path, **save_dict)
        # Save cost estimator state
        self.cost_estimator.save_state(file_path.replace('.npz', '_cost_estimator.npz'))

    def load_model_state(self, file_path):
        """Load arm parameters (and, if present, cost-estimator state) from disk.

        Arms missing from the archive, or stored with a shape that does not
        match ``feature_dim``, keep their current parameters.  Returns True on
        success, False (after printing the error) otherwise.
        """
        try:
            expected_A = (self.feature_dim, self.feature_dim)
            expected_b = (self.feature_dim,)
            # The context manager closes the underlying archive file handle.
            with np.load(file_path) as loaded:
                for model_name, model in self.models.items():
                    a_key, b_key = f'A_{model_name}', f'b_{model_name}'
                    if a_key not in loaded or b_key not in loaded:
                        continue
                    A, b = loaded[a_key], loaded[b_key]
                    if A.shape != expected_A or b.shape != expected_b:
                        # A stale checkpoint would otherwise crash later in dot().
                        print(f"Warning: saved state for {model_name} has shapes {A.shape}/{b.shape}, "
                              f"expected {expected_A}/{expected_b}. Keeping current parameters.")
                        continue
                    model['A'] = A
                    model['b'] = b
            print(f"Loaded LinUCB model state from {file_path}")
            # Load cost estimator state
            cost_estimator_path = file_path.replace('.npz', '_cost_estimator.npz')
            if os.path.exists(cost_estimator_path):
                self.cost_estimator.load_state(cost_estimator_path)
            return True
        except Exception as e:
            print(f"Error loading model state: {e}")
            return False


class BatchBudgetCascade:
    """Run a multi-attempt LLM cascade per question, driven by a LinUCB model.

    At every step a model is selected through ``linucb_model``, queried via the
    module-level ``openrouter_client``, graded by the grader model, and the
    resulting 0/1 reward and observed output cost are fed back to the learner.
    """

    def __init__(self, feature_extractor, linucb_model, cascade_length=CASCADE_LENGTH):
        """Store the collaborators and the maximum number of attempts per question."""
        self.feature_extractor = feature_extractor
        self.linucb_model = linucb_model
        self.cascade_length = cascade_length

    def format_prompt(self, question, failed_answers=None, failed_llm_ids=None):
        """Build the solver prompt, listing earlier incorrect attempts when both histories are given."""
        prompt = f"Solve the following math problem: {question['question']}\n\n"
        prompt += "Also provide an explnation in one/serveral very short yet consice complete sentence within 75 words in total.\n"
        prompt += "At the end, clearly state your final answer in LaTeX format, enclosed within \\boxed{}.\n"
        prompt += "For example: 'The final answer is \\boxed{x=5}'."
        if failed_answers and failed_llm_ids:
            prompt += "\n\nNote: The following previous attempts were incorrect. Please provide a different solution:\n"
            for attempt_no, (answer_info, llm_id) in enumerate(zip(failed_answers, failed_llm_ids), start=1):
                prompt += f"- Attempt {attempt_no} (by {llm_id}) led to: {answer_info}\n"
        return prompt

    def parse_llm_answer(self, answer_text):
        """Return ``(answer_for_grading, explanation)`` for a raw model reply.

        No extraction is performed: the stripped raw text is graded and the
        untouched raw text is kept as the explanation.  Empty/None input
        yields ``("", "")``.
        """
        if not answer_text:
            return "", ""
        return answer_text.strip(), answer_text

    def grade_with_gemma12b(self, llm_answer_latex, ground_truth_latex):
        """Grade an answer against the ground truth using the grader model.

        Returns False for a None answer and True for an exact string match
        without calling the grader.  Any grader error or ambiguous reply is
        treated as incorrect.
        """
        if llm_answer_latex is None or llm_answer_latex == ground_truth_latex:
            return llm_answer_latex is not None
        if isinstance(llm_answer_latex, str) and not llm_answer_latex.strip():
            # A blank answer (e.g. after a failed query) cannot be correct;
            # skip the grader call and its delay.
            return False

        prompt = f"Expression 1: {llm_answer_latex}\nExpression 2: {ground_truth_latex}\n\n"
        prompt += "Expression 2 is the answer and expression is attempt by student,  look at their final answer only which might be boxed, does student get the final expected answer？ Respond with only the word 'True' or 'False'."
        try:
            time.sleep(0.5)  # Simple rate limiting for the grader
            api_response = openrouter_client.chat.completions.create(
                model=GRADER_MODEL_NAME,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,  # Keep it deterministic for grading
                max_tokens=10
            )
            grader_response_text = api_response.choices[0].message.content.strip().lower()
            if "true" in grader_response_text:
                return True
            if "false" not in grader_response_text:
                print(f"Warning: Grader returned ambiguous response: {grader_response_text}")
            return False
        except Exception as e:
            print(f"Error calling grader model: {e}")
            return False

    def calculate_token_cost(self, model_name, prompt, response_text, usage_info=None):
        """Calculate the cost of one call, preferring token counts from ``usage_info``.

        Without ``usage_info`` the token counts are approximated as
        ``len(text) // 4``.  Returns a dict with ``total_cost``,
        ``input_cost``, ``output_cost``, ``input_tokens`` and
        ``output_tokens``, plus ``error`` when the figures are an estimate or
        the model is not configured (all costs are then 0.0).
        """
        input_tokens, output_tokens = 0, 0
        input_cost, output_cost = 0.0, 0.0
        error_message = None

        model_cfg = MODELS_CONFIG.get(model_name)
        if model_cfg is None:
            error_message = f"Model {model_name} not found in config for cost calculation."
        else:
            if usage_info:
                input_tokens = usage_info.get("prompt_tokens", 0) or 0
                output_tokens = usage_info.get("completion_tokens", 0) or 0
            else:
                error_message = f"Usage info not available for {model_name}. Cost is a rough estimate."
                input_tokens = len(prompt) // 4
                output_tokens = len(response_text) // 4 if response_text else 0
            input_cost = input_tokens * model_cfg["input_cost"]
            output_cost = output_tokens * model_cfg["output_cost"]

        # The input/output split is reported so the caller can feed the real
        # output cost to the cost estimator.
        result = {
            "total_cost": input_cost + output_cost,
            "input_cost": input_cost,
            "output_cost": output_cost,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        }
        if error_message:
            result["error"] = error_message
            print(f"Cost calculation warning for {model_name}: {error_message}")
        return result

    def query_llm(self, model_name, prompt):
        """Send ``prompt`` to ``model_name`` and return ``(raw_text, (answer, explanation), cost_data)``.

        On any failure the error is printed and ``("", ("", ""), cost_data)``
        is returned, with ``cost_data`` estimated from the prompt alone.
        """
        try:
            if model_name not in MODELS_CONFIG:
                raise ValueError(f"Model {model_name} is not configured in MODELS_CONFIG.")
            self.linucb_model.respect_rate_limit(model_name)
            try:
                api_response = openrouter_client.chat.completions.create(
                    model=model_name,
                    messages=[{"role": "user", "content": prompt}],
                )
            finally:
                # Every attempted call counts toward the rate limit, whether
                # it succeeded or not.
                self.linucb_model.register_model_call(model_name)

            raw_response_full = api_response.choices[0].message.content
            usage_info = None
            if api_response.usage:
                usage_info = {
                    "prompt_tokens": api_response.usage.prompt_tokens,
                    "completion_tokens": api_response.usage.completion_tokens,
                }
            cost_data = self.calculate_token_cost(model_name, prompt, raw_response_full, usage_info=usage_info)
            parsed_answer_tuple = self.parse_llm_answer(raw_response_full)
            return raw_response_full, parsed_answer_tuple, cost_data
        except Exception as e:
            print(f"Error querying LLM {model_name}: {e}")
            cost_data = self.calculate_token_cost(model_name, prompt, "", usage_info=None)
            return "", ("", ""), cost_data

    def run_cascade_single_question(self, question, current_question_budget=float('inf'), is_test_phase=False):
        """
        Run the Value Density-based LinUCB cascade for a single question.
        - question: The question to answer
        - current_question_budget: Budget limit for this question (used in test phase)
        - is_test_phase: Whether we're in the test phase (applies budget constraints)

        Returns a record dict with the question, ground truth, final status,
        total cost, number of steps taken and the per-step attempt log.
        """
        base_features = self.feature_extractor.extract_features(question)
        failed_answers, failed_llm_ids, used_llms_this_question = [], [], []
        question_total_cost, current_attempts_log = 0.0, []
        final_status = "Failure"  # Default to failure, will update on success
        remaining_budget = current_question_budget if is_test_phase else float('inf')

        # Value Density Cascade Loop
        for i in range(1, self.cascade_length + 1):
            print(f"Step {i}")
            # 1. Construct feature vector using history
            x_i = self.feature_extractor.construct_feature_vector(base_features, i, failed_answers, failed_llm_ids, self.linucb_model.model_name_to_index)
            # 2. Format prompt with context
            prompt = self.format_prompt(question, failed_answers, failed_llm_ids)
            # 3. Select model based on learned strategy, respecting budget
            # NOTE(review): REWARD_IMPORTANCE and COST_IMPORTANCE are not defined in
            # the part of the file reviewed here -- confirm they exist at module level.
            chosen_model, scores = self.linucb_model.select_model_value_density(x_i, i, prompt, remaining_budget, is_test_phase, REWARD_IMPORTANCE, COST_IMPORTANCE, used_llms_this_question)

            if not chosen_model:
                final_status = "Failure - Budget Exceeded or No Suitable Model" if is_test_phase else "Failure - All LLMs Used (Train)"
                print(f"Stopping cascade: {final_status}")
                break

            chosen_scores = scores[chosen_model]
            print(f"Selected: {chosen_model}, Strategy: {chosen_scores.get('selection_strategy', 'N/A')}, Score: {chosen_scores.get('selection_score', -float('inf')):.4f}")
            # 4. Query the chosen model
            raw_response, (answer_for_grading, explanation_text), cost_data = self.query_llm(chosen_model, prompt)
            # 5. Extract costs
            actual_total_cost = cost_data.get("total_cost", 0.0)
            actual_output_cost = cost_data.get("output_cost", 0.0)
            # 6. Update budget
            question_total_cost += actual_total_cost
            remaining_budget -= actual_total_cost
            # 7. Update cost estimator with actual observed output cost
            self.linucb_model.cost_estimator.update_stats(chosen_model, i, actual_output_cost)
            # 8. Determine outcome using the grader
            is_correct = self.grade_with_gemma12b(answer_for_grading, question['ground_truth_answer'])
            reward = 1 if is_correct else 0
            print(f"Answer: {answer_for_grading}, Correct: {is_correct}, Cost: ${actual_total_cost:.8f}")
            # 9. Update LinUCB for the chosen model
            self.linucb_model.update_reward_only(chosen_model, x_i, reward)
            # 10. Log the attempt
            current_attempts_log.append({
                "step": i, "chosen_model": chosen_model, "chosen_model_cost": cost_data,
                "is_correct": is_correct, "reward_ri": reward, "llm_answer": answer_for_grading,
                "llm_explanation": explanation_text, "scores_per_arm": scores
            })
            # 11. Handle Outcome
            if is_correct:
                final_status = "Success"
                print(f"Success in step {i}!")
                break

            # Failure: Update history for the next step
            failed_answers.append(answer_for_grading if answer_for_grading else "Unknown")
            failed_llm_ids.append(chosen_model)
            used_llms_this_question.append(chosen_model)

        return {
            "question": question["question"], "ground_truth_answer": question["ground_truth_answer"],
            "final_status": final_status, "total_cost": question_total_cost,
            "steps_taken": len(current_attempts_log), "attempts": current_attempts_log,
            "is_test_phase": is_test_phase, "question_budget": current_question_budget if is_test_phase else None
        }


def load_math500_dataset(json_path):
    """Load the Math500 dataset from a JSON file with robust error handling.

    The file is expected to hold a JSON object whose "test" key maps to a
    list of items, each carrying "problem" and "answer" (and optionally
    "unique_id").

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        A list of question dicts with the keys "question", "options",
        "ground_truth_answer" and "unique_id". An empty list is returned
        (after printing the error) when the file cannot be read, is not
        valid JSON, or does not have the expected structure.
    """
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            raw_data = json.load(f)
        # Assuming data is under the key "test"
        return [{
            "question": item["problem"],
            "options": [],  # Math500 has no MCQs
            "ground_truth_answer": item["answer"],
            "unique_id": item.get("unique_id", "Unknown")
        } for item in raw_data["test"]]
    except (OSError, ValueError, KeyError, TypeError) as e:
        # OSError covers missing/unreadable paths, ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError, and TypeError covers a
        # payload whose top level (or items) is not a mapping.
        print(f"Error loading dataset from {json_path}: {e}")
        return []

def _discard_file(path):
    """Best-effort removal of a leftover staging file."""
    try:
        os.remove(path)
    except OSError:
        # Nothing was staged, or it is already gone; there is nothing to undo.
        pass


def save_records_with_backup(records, json_path):
    """Write records to json_path as JSON, keeping the previous file as a backup.

    The records are first serialised to a staging file next to the target.
    Only once that succeeds is the existing file moved to
    ``json_path + BACKUP_SUFFIX`` and the staging file swapped into place, so
    a failed or interrupted write never leaves a truncated records file.

    Args:
        records: JSON-serialisable object to persist.
        json_path: Destination path of the records file.

    Returns:
        True if the new file is in place, False otherwise (the error is
        printed).
    """
    backup_path = json_path + BACKUP_SUFFIX
    staging_path = json_path + ".tmp"

    # Stage the new content without touching the current records file.
    try:
        with open(staging_path, 'w') as f:
            json.dump(records, f, indent=4)
    except Exception as e:
        # Deliberately broad: saving is best-effort and json.dump may raise
        # TypeError/ValueError for unserialisable data as well as OSError.
        print(f"Error saving records: {e}")
        _discard_file(staging_path)
        return False

    backup_created = False
    if os.path.exists(json_path):
        try:
            os.replace(json_path, backup_path)
            backup_created = True
        except OSError as e:
            # A missing backup is not fatal; still try to store the new data.
            print(f"Failed to create backup: {e}")

    try:
        os.replace(staging_path, json_path)
    except OSError as e:
        print(f"Error saving records: {e}")
        # Put the previous file back if it was moved aside a moment ago.
        if backup_created:
            try:
                os.replace(backup_path, json_path)
            except OSError as restore_error:
                print(f"Failed to restore backup: {restore_error}")
        _discard_file(staging_path)
        return False
    return True

def initialize_json_files():
    """Ensure RECORDS_PATH exists and holds a JSON list.

    A missing file is created containing an empty list. An existing file that
    is not valid JSON, or whose top-level value is not a list, is moved to
    ``RECORDS_PATH + BACKUP_SUFFIX`` and replaced by an empty list.
    """
    if not os.path.exists(RECORDS_PATH):
        with open(RECORDS_PATH, 'w') as f:
            json.dump([], f)
        return

    try:
        with open(RECORDS_PATH, 'r') as f:
            if not isinstance(json.load(f), list):
                raise ValueError("records file does not contain a JSON list")
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass, so one clause covers
        # both malformed JSON and a wrong top-level type.
        # os.replace (unlike os.rename) also overwrites an existing backup on
        # Windows, matching how save_records_with_backup rotates files.
        os.replace(RECORDS_PATH, RECORDS_PATH + BACKUP_SUFFIX)
        with open(RECORDS_PATH, 'w') as f:
            json.dump([], f)


def _new_model_stats():
    """Return a zeroed per-model accumulator with a per-position breakdown."""
    return {
        "calls": 0, "successes": 0, "total_cost": 0.0,
        "by_position": defaultdict(lambda: {"calls": 0, "successes": 0, "total_cost": 0.0}),
    }


def _add_rates(stats):
    """Derive success_rate and avg_cost from the raw counters in stats."""
    calls = stats["calls"]
    stats["success_rate"] = stats["successes"] / calls if calls > 0 else 0
    stats["avg_cost"] = stats["total_cost"] / calls if calls > 0 else 0


def _phase_overview(phase_records):
    """Aggregate question-level statistics for the records of one phase."""
    question_count = len(phase_records)
    solved = [r for r in phase_records if r["final_status"] == "Success"]
    solved_count = len(solved)

    # A success is attributed to the (0-indexed) position of its last attempt.
    successes_by_position = [0] * CASCADE_LENGTH
    for record in solved:
        position = len(record["attempts"]) - 1
        if 0 <= position < CASCADE_LENGTH:
            successes_by_position[position] += 1

    total_cost = sum(r["total_cost"] for r in phase_records)
    total_steps = sum(r["steps_taken"] for r in phase_records)
    solved_cost = sum(r["total_cost"] for r in solved)

    if question_count > 0:
        per_position_success = [count / question_count for count in successes_by_position]
    else:
        per_position_success = [0] * CASCADE_LENGTH

    return {
        "questions": question_count,
        "successes": solved_count,
        "success_rate": solved_count / question_count if question_count > 0 else 0,
        "avg_steps": total_steps / question_count if question_count > 0 else 0,
        "per_position_success": per_position_success,
        "total_cost": total_cost,
        "avg_cost": total_cost / question_count if question_count > 0 else 0,
        "avg_cost_success": solved_cost / solved_count if solved_count > 0 else 0,
    }


def _collect_model_metrics(phase_records, lcb_accuracy=None):
    """Accumulate per-model (and per-position) attempt statistics for one phase.

    When lcb_accuracy is given, it is updated in place with the logged LCB
    output-cost estimate and the actual output cost of every attempt that
    recorded both.
    """
    metrics = {model: _new_model_stats() for model in AVAILABLE_LLMS}
    for record in phase_records:
        for attempt in record["attempts"]:
            model = attempt["chosen_model"]
            cost_data = attempt.get("chosen_model_cost") or {}
            cost = cost_data.get("total_cost", 0.0)

            # Records may mention a model that is no longer configured.
            if model not in metrics:
                metrics[model] = _new_model_stats()
            stats = metrics[model]
            slot = stats["by_position"][attempt["step"]]

            for bucket in (stats, slot):
                bucket["calls"] += 1
                bucket["total_cost"] += cost
                if attempt["is_correct"]:
                    bucket["successes"] += 1

            if lcb_accuracy is None or "output_cost" not in cost_data:
                continue
            arm_scores = attempt.get("scores_per_arm")
            arm_scores = arm_scores.get(model) if isinstance(arm_scores, dict) else None
            if isinstance(arm_scores, dict) and "lcb_output_cost_estimate" in arm_scores:
                tracker = lcb_accuracy.setdefault(
                    model, {"total_lcb_estimate": 0.0, "total_actual_cost": 0.0, "count": 0})
                tracker["total_lcb_estimate"] += arm_scores["lcb_output_cost_estimate"]
                tracker["total_actual_cost"] += cost_data["output_cost"]
                tracker["count"] += 1

    for stats in metrics.values():
        _add_rates(stats)
        for slot in stats["by_position"].values():
            _add_rates(slot)
    return metrics


def _format_phase_overview(label, overview, extra_lines=()):
    """Render the question-level summary lines of one phase."""
    lines = [
        f"Total Questions: {overview['questions']}\n",
        f"Success Rate: {overview['success_rate']:.4f}\n",
        f"Average Steps ({label}): {overview['avg_steps']:.4f}\n",
        f"Average Cost per Question: ${overview['avg_cost']:.8f}\n",
        f"Average Cost per Successful Question: ${overview['avg_cost_success']:.8f}\n",
    ]
    lines.extend(extra_lines)
    lines.append("Success Rate by Position:\n")
    lines.extend(f"  Position {i+1}: {rate:.4f}\n"
                 for i, rate in enumerate(overview["per_position_success"]))
    return lines


def _format_model_metrics(metrics, with_positions):
    """Render the per-model lines of one phase, skipping models never called."""
    lines = []
    for model, stats in metrics.items():
        if stats["calls"] <= 0:
            continue
        lines.append(f"{model}:\n")
        lines.append(f"  Overall: {stats['successes']}/{stats['calls']} = {stats['success_rate']:.4f}\n")
        lines.append(f"  Average Cost: ${stats['avg_cost']:.8f}\n")
        if not with_positions:
            continue
        lines.append("  By Position:\n")
        for pos in sorted(stats["by_position"]):
            slot = stats["by_position"][pos]
            if slot["calls"] > 0:
                lines.append(f"    Pos {pos}: {slot['successes']}/{slot['calls']} = {slot['success_rate']:.4f}, "
                             f"Avg Cost: ${slot['avg_cost']:.8f}\n")
    return lines


def analyze_results(records):
    """Analyze and print results from the records.

    Splits the records into train and test phases by their "is_test_phase"
    flag (records without the flag count as training only), computes
    question-level and per-model statistics plus the LCB output-cost
    estimation ratio for the test phase, prints the summary and writes it to
    SUMMARY_STATS_PATH.

    Args:
        records: List of question records as produced by the cascade; each
            needs "final_status", "total_cost", "steps_taken" and "attempts".

    Returns:
        None. Prints "No records to analyze" and returns early when records
        is empty.
    """
    if not records:
        print("No records to analyze")
        return

    # Both filters share the same default so the two phases never overlap.
    train_records = [r for r in records if not r.get("is_test_phase", False)]
    test_records = [r for r in records if r.get("is_test_phase", False)]

    train = _phase_overview(train_records)
    test = _phase_overview(test_records)

    # Count budget exceeded in test phase
    budget_exceeded_count = sum(1 for r in test_records
                                if r["final_status"] == "Failure - Budget Exceeded")
    budget_exceeded_rate = budget_exceeded_count / test["questions"] if test["questions"] > 0 else 0

    # Track LCB cost estimation accuracy (output cost only, test phase only)
    lcb_cost_accuracy = {model: {"total_lcb_estimate": 0.0, "total_actual_cost": 0.0, "count": 0}
                         for model in AVAILABLE_LLMS}
    model_metrics = {
        "train": _collect_model_metrics(train_records),
        "test": _collect_model_metrics(test_records, lcb_cost_accuracy),
    }

    for data in lcb_cost_accuracy.values():
        if data["count"] > 0:
            data["avg_lcb_estimate"] = data["total_lcb_estimate"] / data["count"]
            data["avg_actual_cost"] = data["total_actual_cost"] / data["count"]
            data["lcb_accuracy"] = (data["avg_lcb_estimate"] / data["avg_actual_cost"]
                                    if data["avg_actual_cost"] > 0 else 0)
        else:
            data["avg_lcb_estimate"] = 0
            data["avg_actual_cost"] = 0
            data["lcb_accuracy"] = 0

    # NOTE(review): the "(30%)"/"(70%)" labels below are hard-coded, while
    # main() derives the train size as 20% of the dataset; confirm which split
    # is intended.
    parts = [
        "=== BUDGET-AWARE LinUCB CASCADE (MATH LATEX RESPONSE) ===\n\n",
        f"LLM Output Cost Beta: {BETA_COST}\n\n",
        "=== TRAIN PHASE RESULTS (30%) ===\n",
    ]
    parts.extend(_format_phase_overview("Train", train))

    parts.append("\n=== TEST PHASE RESULTS (70%) ===\n")
    parts.extend(_format_phase_overview(
        "Test", test,
        extra_lines=[f"Budget Exceeded Count: {budget_exceeded_count} ({budget_exceeded_rate:.4f})\n"]))

    # The train section omits the per-position breakdown.
    parts.append("\n=== MODEL PERFORMANCE (TRAIN PHASE) ===\n")
    parts.extend(_format_model_metrics(model_metrics["train"], with_positions=False))

    parts.append("\n=== MODEL PERFORMANCE (TEST PHASE) ===\n")
    parts.extend(_format_model_metrics(model_metrics["test"], with_positions=True))

    parts.append("\n=== LLM LCB OUTPUT COST ESTIMATION ACCURACY (TEST PHASE) ===\n")
    for model, data in lcb_cost_accuracy.items():
        if data["count"] > 0:
            parts.append(f"{model}: Avg LCB Est: ${data['avg_lcb_estimate']:.8f}, "
                         f"Avg Actual: ${data['avg_actual_cost']:.8f}, Ratio: {data['lcb_accuracy']:.4f}\n")

    train_total_cost = train["total_cost"]
    test_total_cost = test["total_cost"]
    parts.append(f"\nTotal Overall Cost (Train): ${train_total_cost:.8f}\n")
    parts.append(f"Total Overall Cost (Test): ${test_total_cost:.8f}\n")
    parts.append(f"Total Overall Cost: ${(train_total_cost + test_total_cost):.8f}\n")

    # Print and save summary
    summary = "".join(parts)
    print(summary)

    with open(SUMMARY_STATS_PATH, 'w') as f:
        f.write(summary)



def main():
    """Run the LinUCB cascade experiment end to end.

    Loads the dataset, asks how many questions to sample, skips questions that
    already appear in RECORDS_PATH, restores a saved LinUCB state if one
    exists, then runs the cascade over the remaining questions (the first 20%
    as the training phase, the rest as the budgeted test phase). Records and
    model state are saved as the run progresses and again on exit, after which
    the results are analyzed.
    """
    print("Starting LinUCB Cascade with Value Density & Budget (LaTeX Math Response)")
    initialize_json_files()
    dataset_full = load_math500_dataset(INPUT_JSON)
    if not dataset_full:
        print("Dataset is empty or could not be loaded. Exiting.")
        return

    # Fall back to (at most) 10 questions when the answer is missing,
    # non-numeric or below 1; never request more than the dataset holds.
    default_count = min(10, len(dataset_full))
    try:
        num_to_process = int(input(f"How many questions to process? (1-{len(dataset_full)}): "))
    except (ValueError, EOFError):
        num_to_process = default_count
    if num_to_process < 1:
        num_to_process = default_count
    dataset = random.sample(dataset_full, min(num_to_process, len(dataset_full)))
    print(f"Using {len(dataset)} questions from the Math500 dataset")

    feature_extractor = FeatureExtractor()
    feature_extractor.initialize(dataset)
    linucb_model = LinUCBModel(model_names=AVAILABLE_LLMS)
    cascade = BatchBudgetCascade(feature_extractor, linucb_model)

    # Load existing records to determine starting point
    all_records = []
    if os.path.exists(RECORDS_PATH):
        try:
            with open(RECORDS_PATH, 'r') as f:
                all_records = json.load(f)
            processed_questions = {r["question"] for r in all_records}
            dataset = [q for q in dataset if q["question"] not in processed_questions]
            print(f"Found {len(all_records)} existing records, {len(dataset)} questions remaining")
        except (OSError, ValueError, KeyError, TypeError) as e:
            print(f"Error loading existing records: {e}")

    if os.path.exists(LINUCB_MODEL_PATH):
        linucb_model.load_model_state(LINUCB_MODEL_PATH)

    train_size = int(0.2 * len(dataset))
    print(f"Train size: {train_size}, Test size: {len(dataset) - train_size}")

    # Fixed budget for reproducibility; training questions run unbudgeted.
    test_phase_budget = 0.00014217

    try:
        for idx, question in enumerate(dataset):
            is_test_phase = idx >= train_size
            # A budget applies only in the test phase
            current_question_budget = float('inf')
            if is_test_phase:
                current_question_budget = test_phase_budget
                print(f"Test phase budget: ${current_question_budget:.8f}")

            print(f"\nProcessing question {idx+1}/{len(dataset)} ({'Test' if is_test_phase else 'Train'}): {question['question'][:80]}...")
            question_record = cascade.run_cascade_single_question(question, current_question_budget, is_test_phase)
            all_records.append(question_record)

            # Save records and states periodically, and at the end of training
            if (idx + 1) % UPDATE_FREQUENCY == 0 or idx == train_size - 1:
                save_records_with_backup(all_records, RECORDS_PATH)
                linucb_model.save_model_state(LINUCB_MODEL_PATH)

    except KeyboardInterrupt:
        print("\nProcessing interrupted. Saving progress...")
    finally:
        # Runs on normal completion, on interrupt and on unexpected errors.
        save_records_with_backup(all_records, RECORDS_PATH)
        linucb_model.save_model_state(LINUCB_MODEL_PATH)
        analyze_results(all_records)
        print("LinUCB Cascade completed!")

# Call main() only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()