# In [None]:
import os
import json
import numpy as np
import random
import time
import re
import tiktoken
import google.generativeai as genai
from tqdm import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from sentence_transformers import SentenceTransformer
from collections import defaultdict
import math
from openai import OpenAI

# --- Configuration ---
BASE_FEATURE_DIM = 384  # Length of the question feature vector (embedding, or TF-IDF+SVD padded to this size)
ANSWER_EMBED_DIM = 384  # Length of the embedding of a previous (failed) answer
CASCADE_LENGTH = 4      # Number of steps in the cascade (K); also used to normalise the step feature
UPDATE_FREQUENCY = 1    # Update JSON records after every question
USE_EMBEDDINGS = True   # Use sentence-transformer embeddings; False selects TF-IDF + SVD

# LinUCB parameters
ALPHA = 0.675           # Exploration weight: multiplies the confidence-width term of the UCB score
LAMBDA_REG = 0.45       # Ridge regularisation: each A matrix starts as LAMBDA_REG * identity

# Value Density & Cost Estimation parameters
BETA_COST = 1.0         # LCB width for per-model cost: LCB = mean - BETA_COST * standard error
EPSILON_VD = 1e-10      # Small positive value; used as the "budget exhausted" threshold and (divided by 10) as the cost floor
DEFAULT_FALLBACK_COST = 0.01  # Per-model cost estimate returned while fewer than 2 observations exist

# Question Cost Estimation parameters
BETA_COST_QUESTION = 1.0  # LCB width for per-question cost: LCB = mean - BETA_COST_QUESTION * standard error
DEFAULT_FALLBACK_COST_QUESTION = 0.04  # Per-question cost estimate returned while fewer than 2 observations exist
QUESTION_POOL_MULTIPLIER = 3  # How many times BATCH_SIZE to consider for question selection

# Pruning configuration
# NOTE(review): MAX_WORDS and REPETITION_THRESHOLD are not referenced in this part of the
# file; presumably they are consumed by answer-pruning code elsewhere -- confirm.
MAX_WORDS = 250
REPETITION_THRESHOLD = 3

# Batch parameters
BATCH_SIZE = 17

# In [None]:
# Credentials are read from the environment so that the file parses as shipped
# and no secret has to be hard-coded. Export OPENROUTER_API_KEY (and, to use a
# proxy or alternative endpoint, OPENROUTER_BASE_URL) before running.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_BASE_URL = os.environ.get("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")

# In [None]:
# Candidate LLMs keyed by their OpenRouter model id. "input_cost" / "output_cost"
# are prices per single token (written as price-per-million-tokens / 1e6); they
# are multiplied by prompt and completion token counts to obtain a call's cost.
MODELS_CONFIG = {
    "mistralai/mistral-small-3.1-24b-instruct": {"input_cost": 0.05 / 1e6, "output_cost": 0.15 / 1e6},
    "microsoft/phi-4": {"input_cost": 0.07 / 1e6, "output_cost": 0.14 / 1e6},
    "meta-llama/llama-4-maverick": {"input_cost": 0.17 / 1e6, "output_cost": 0.16 / 1e6},
    "google/gemini-2.0-flash-001": {"input_cost": 0.1 / 1e6, "output_cost": 0.4 / 1e6},
    "openai/gpt-4.1-nano": {"input_cost": 0.1 / 1e6, "output_cost": 0.4 / 1e6},
    "deepseek/deepseek-chat": {"input_cost": 0.38 / 1e6, "output_cost": 0.89 / 1e6},
}

# Model used to judge whether an answer matches the ground truth. It is not part
# of MODELS_CONFIG, so its calls are not priced by the cost calculation.
GRADER_MODEL_NAME = "google/gemini-2.0-flash-lite-001"

# Order of this list defines each model's index in the one-hot "last failed LLM" feature.
AVAILABLE_LLMS = list(MODELS_CONFIG.keys())
LLM_ID_DIM = len(AVAILABLE_LLMS)

# Feature dimensions are recalculated based on the number of available LLMs
# Context = last failed answer embedding + one-hot of last failed LLM + mean failed answer embedding
CONTEXT_FEATURE_DIM = ANSWER_EMBED_DIM + LLM_ID_DIM + ANSWER_EMBED_DIM
# Total = question features + 1 (normalised cascade step) + context
TOTAL_FEATURE_DIM = BASE_FEATURE_DIM + 1 + CONTEXT_FEATURE_DIM

# Initialize OpenRouter client
# (the OpenAI SDK client is pointed at the OpenRouter endpoint via base_url)
openrouter_client = OpenAI(
    base_url=OPENROUTER_BASE_URL,
    api_key=OPENROUTER_API_KEY,
)

# File paths
INPUT_JSON = "Math500.json"                        # Input dataset
RECORDS_PATH = "M93.json"                          # Per-question records
LINUCB_MODEL_PATH = "M93.npz"                      # LinUCB A/b matrices (cost estimator is saved next to it)
SUMMARY_STATS_PATH = "M93.txt"                     # Summary statistics
QUESTION_COST_PATH = "Question_Cost_Estimator.npz" # Question cost estimator state
BACKUP_SUFFIX = ".bak"                             # Appended to a records path to name its backup

def solve_0_1_knapsack(items, capacity):
    """
    Solve the 0/1 knapsack problem and return the names of the selected items.

    This is used to select a value-maximizing subset of LLMs that fits within a budget.

    Parameters:
    - items: List of dictionaries, each with 'name', 'value' (UCB score), and
      'weight' (LCB cost). Weights are expected to be non-negative.
    - capacity: Maximum total weight (budget) the knapsack can hold. May be
      float('inf') for an unlimited budget.

    Returns:
    - List of selected item names (last item first). Only items that strictly
      increase the total value are ever selected, so items with a value <= 0
      are never returned.

    Notes:
    Weights are discretised onto a fixed grid of 10000 cells spanning
    [0, capacity] and rounded *up*, so the true total weight of the returned
    subset never exceeds the capacity (up to floating-point tolerance). A fixed
    absolute scale would round weights far below the grid step down to zero and
    treat them as free, and would make the table size depend on the capacity.
    """
    if not items or capacity <= 0:
        return []

    if math.isinf(capacity):
        # Unlimited budget: nothing constrains the choice, take every item
        # that contributes positive value.
        return [item['name'] for item in reversed(items) if item['value'] > 0]

    resolution = 10000  # number of grid cells the capacity is divided into

    def to_cells(weight):
        """Convert a real weight to a whole number of grid cells (rounded up)."""
        if weight > capacity:
            # Can never fit; one cell beyond the table keeps it out of every solution.
            return resolution + 1
        # weight / capacity is in [0, 1], so this cannot overflow. The small
        # tolerance stops float noise from pushing an exact fit over the edge.
        return max(0, math.ceil(weight / capacity * resolution - 1e-6))

    n = len(items)
    scaled_weights = [to_cells(item['weight']) for item in items]
    dp = [[0] * (resolution + 1) for _ in range(n + 1)]

    for i in range(1, n + 1):
        item_weight = scaled_weights[i - 1]
        item_value = items[i - 1]['value']
        prev_row = dp[i - 1]
        row = dp[i]
        for w in range(resolution + 1):
            best = prev_row[w]
            if item_weight <= w:
                candidate = item_value + prev_row[w - item_weight]
                if candidate > best:
                    best = candidate
            row[w] = best

    # Backtrack to find which items were selected
    selected_items = []
    w = resolution
    for i in range(n, 0, -1):
        if dp[i][w] != dp[i - 1][w]:
            selected_items.append(items[i - 1]['name'])
            w -= scaled_weights[i - 1]
    return selected_items


class QuestionCostEstimator:
    """Running statistics of the total cost spent per question.

    Only the count, the sum and the sum of squares are stored, which is enough
    to derive the mean and the (population) variance of the observed costs.
    """

    def __init__(self):
        self.stats = {"count": 0, "total_cost": 0.0, "sum_sq_cost": 0.0}

    def update_stats(self, actual_question_cost):
        """Record the total cost that was actually spent on one question."""
        self.stats["count"] += 1
        self.stats["total_cost"] += actual_question_cost
        self.stats["sum_sq_cost"] += actual_question_cost ** 2

    def estimate_question_lcb_cost(self, beta=BETA_COST_QUESTION, default_cost=DEFAULT_FALLBACK_COST_QUESTION):
        """Estimate the Lower Confidence Bound (LCB) of cost for a question.

        Parameters:
        - beta: Number of standard errors subtracted from the mean.
        - default_cost: Value returned while fewer than two costs were recorded.

        Returns:
        - mean - beta * standard_error, floored at EPSILON_VD / 10 so the
          estimate is always strictly positive; or default_cost.
        """
        count = self.stats["count"]
        if count < 2:
            return default_cost

        mean_cost = self.stats["total_cost"] / count
        variance = (self.stats["sum_sq_cost"] / count) - (mean_cost ** 2)
        # Rounding can push a tiny variance slightly below zero.
        std_dev = math.sqrt(max(0, variance))
        std_error = std_dev / math.sqrt(count)
        lcb_cost = mean_cost - beta * std_error
        return max(EPSILON_VD / 10, lcb_cost)

    def save_state(self, file_path):
        """Write the statistics to a compressed .npz file. Returns True on success."""
        try:
            np.savez_compressed(file_path, **self.stats)
            return True
        except Exception as e:
            print(f"Error saving question cost estimator state: {e}")
            return False

    def load_state(self, file_path):
        """Restore the statistics from a file written by save_state.

        The current statistics are only replaced when every field was read
        successfully. Returns True on success, False otherwise.
        """
        try:
            # The context manager closes the underlying zip file handle.
            with np.load(file_path) as loaded:
                restored = {
                    "count": int(loaded["count"]),
                    "total_cost": float(loaded["total_cost"]),
                    "sum_sq_cost": float(loaded["sum_sq_cost"]),
                }
        except Exception as e:
            print(f"Error loading question cost estimator state: {e}")
            return False

        self.stats.update(restored)
        print(f"Loaded question cost estimator state from {file_path}")
        return True


class ModelCostEstimator:
    """Running per-model cost statistics, overall and per cascade position.

    For every model name the estimator keeps count / sum / sum of squares of the
    observed call costs, once aggregated over all calls ("overall") and once per
    cascade step ("by_position", keyed by the step number as a string).
    """

    def __init__(self, model_names):
        self.stats = {}
        for model_name in model_names:
            self.stats[model_name] = {
                "overall": self._empty_stats(),
                "by_position": {
                    str(i): self._empty_stats()
                    for i in range(1, CASCADE_LENGTH + 1)
                }
            }

    @staticmethod
    def _empty_stats():
        """Return a fresh, zeroed statistics record."""
        return {"count": 0, "total_cost": 0.0, "sum_sq_cost": 0.0}

    @staticmethod
    def _accumulate(stats, actual_cost):
        """Add one observed cost to a statistics record in place."""
        stats["count"] += 1
        stats["total_cost"] += actual_cost
        stats["sum_sq_cost"] += actual_cost ** 2

    @staticmethod
    def _lcb_from_stats(stats, beta):
        """Return mean - beta * standard error for a record, or None if it has < 2 samples."""
        count = stats["count"]
        if count < 2:
            return None
        mean_cost = stats["total_cost"] / count
        variance = (stats["sum_sq_cost"] / count) - (mean_cost ** 2)
        # Rounding can push a tiny variance slightly below zero.
        std_dev = math.sqrt(max(0, variance))
        std_error = std_dev / math.sqrt(count)
        # Floor keeps the estimate strictly positive.
        return max(EPSILON_VD / 10, mean_cost - beta * std_error)

    def update_stats(self, model_name, step, actual_cost):
        """Record the actual cost of one call of `model_name` made at cascade `step`.

        Unknown model names are ignored. A step outside 1..CASCADE_LENGTH only
        updates the overall statistics.
        """
        model_stats = self.stats.get(model_name)
        if model_stats is None:
            return
        # Update position-specific stats
        pos_stats = model_stats["by_position"].get(str(step))
        if pos_stats is not None:
            self._accumulate(pos_stats, actual_cost)
        # Update overall stats
        self._accumulate(model_stats["overall"], actual_cost)

    def estimate_lcb_cost(self, model_name, step, beta=BETA_COST, default_cost=DEFAULT_FALLBACK_COST):
        """Estimate the Lower Confidence Bound (LCB) of cost for a model at a specific step.

        The position-specific statistics are used when they hold at least two
        observations, otherwise the model's overall statistics, otherwise
        `default_cost`.
        """
        model_stats = self.stats.get(model_name)
        if model_stats is None:
            return default_cost

        # Try position-specific stats first
        pos_stats = model_stats["by_position"].get(str(step))
        if pos_stats is not None:
            estimate = self._lcb_from_stats(pos_stats, beta)
            if estimate is not None:
                return estimate

        # Fallback to overall stats
        estimate = self._lcb_from_stats(model_stats["overall"], beta)
        if estimate is not None:
            return estimate
        return default_cost

    def save_state(self, file_path):
        """Write all statistics to a compressed .npz file. Returns True on success."""
        try:
            save_dict = {}
            for model_name, model_stats in self.stats.items():
                overall = model_stats["overall"]
                save_dict[f"overall_count_{model_name}"] = overall["count"]
                save_dict[f"overall_total_cost_{model_name}"] = overall["total_cost"]
                save_dict[f"overall_sum_sq_cost_{model_name}"] = overall["sum_sq_cost"]
                for pos, pos_stats in model_stats["by_position"].items():
                    save_dict[f"pos{pos}_count_{model_name}"] = pos_stats["count"]
                    save_dict[f"pos{pos}_total_cost_{model_name}"] = pos_stats["total_cost"]
                    save_dict[f"pos{pos}_sum_sq_cost_{model_name}"] = pos_stats["sum_sq_cost"]
            np.savez_compressed(file_path, **save_dict)
            return True
        except Exception as e:
            print(f"Error saving model cost estimator state: {e}")
            return False

    def load_state(self, file_path):
        """Restore statistics from a file written by save_state.

        Only models known to this estimator are read; records missing from the
        file keep their current values. The in-memory statistics are replaced
        only after the whole file was read without error. Returns True on
        success, False otherwise.
        """
        def read_record(loaded, prefix, model_name):
            return {
                "count": int(loaded[f"{prefix}_count_{model_name}"]),
                "total_cost": float(loaded[f"{prefix}_total_cost_{model_name}"]),
                "sum_sq_cost": float(loaded[f"{prefix}_sum_sq_cost_{model_name}"]),
            }

        pending = []  # (target record, restored values) pairs applied after a clean read
        try:
            # The context manager closes the underlying zip file handle.
            with np.load(file_path) as loaded:
                for model_name, model_stats in self.stats.items():
                    if f"overall_count_{model_name}" in loaded:
                        pending.append((model_stats["overall"], read_record(loaded, "overall", model_name)))
                    for pos, pos_stats in model_stats["by_position"].items():
                        if f"pos{pos}_count_{model_name}" in loaded:
                            pending.append((pos_stats, read_record(loaded, f"pos{pos}", model_name)))
        except Exception as e:
            print(f"Error loading model cost estimator state: {e}")
            return False

        for target, restored in pending:
            target.update(restored)
        print(f"Loaded model cost estimator state from {file_path}")
        return True


class FeatureExtractor:
    """Turns questions and previous answers into numeric feature vectors.

    Questions are represented either by a sentence-transformer embedding or, as
    a fallback, by TF-IDF reduced with truncated SVD. Answers are only embedded
    when the sentence-transformer is in use; otherwise they map to zeros.
    """

    def __init__(self, feature_dim=BASE_FEATURE_DIM, use_embeddings=USE_EMBEDDINGS):
        self.feature_dim = feature_dim
        self.use_embeddings = use_embeddings
        if use_embeddings:
            try:
                self.embedding_model = SentenceTransformer('BAAI/bge-small-en-v1.5')
                print("Initialized sentence transformer embedding model")
            except Exception as e:
                print(f"Error initializing sentence transformer: {e}")
                print("Falling back to TF-IDF vectorization")
                self.use_embeddings = False
        if not self.use_embeddings:
            # Initialize TF-IDF vectorizer without restricting features yet
            self.vectorizer = TfidfVectorizer()
            # SVD is fitted in initialize() once the corpus is known
            self.svd = None

        self.initialized = False

    def initialize(self, questions):
        """Prepare the extractor for the given corpus of questions.

        In TF-IDF mode this fits the vectorizer and the SVD on the "problem"
        text of every question. In embedding mode nothing needs fitting. Either
        way the extractor is marked as initialized afterwards.
        """
        if not self.use_embeddings:
            all_text = [q["problem"] for q in questions]

            # Fit the vectorizer on all texts and get the document-term matrix
            self.vectorizer.fit(all_text)
            dtm = self.vectorizer.transform(all_text)

            # Use SVD for dimensionality reduction to get feature_dim dense features
            from sklearn.decomposition import TruncatedSVD

            # Cannot ask for more components than there are TF-IDF features
            n_components = min(self.feature_dim, dtm.shape[1])

            self.svd = TruncatedSVD(n_components=n_components)
            self.svd.fit(dtm)

            print(f"Using TF-IDF with SVD dimensionality reduction to {n_components} features")
            print(f"Explained variance ratio: {sum(self.svd.explained_variance_ratio_):.4f}")

        self.initialized = True

    def extract_features(self, question):
        """Return the feature vector for a question's "problem" text.

        Raises:
        - ValueError: if initialize() has not been called yet.
        """
        if not self.initialized:
            raise ValueError("Feature extractor not initialized. Call initialize() first.")

        text = question["problem"]

        if self.use_embeddings:
            # Use sentence transformer for embeddings
            features = self.embedding_model.encode([text])[0]
        else:
            # TF-IDF vector followed by the fitted SVD projection
            tfidf_vector = self.vectorizer.transform([text])
            features = self.svd.transform(tfidf_vector)[0]
            # The SVD may have fewer components than feature_dim; zero-pad up to it
            if len(features) < self.feature_dim:
                padding = np.zeros(self.feature_dim - len(features))
                features = np.concatenate([features, padding])
        return features

    def extract_answer_features(self, answer_text):
        """Return an ANSWER_EMBED_DIM-long feature vector for an answer string.

        Empty answers, TF-IDF mode and embedding errors all yield a zero vector.
        Embeddings of a different length are truncated or zero-padded.
        """
        if not answer_text:
            return np.zeros(ANSWER_EMBED_DIM)

        if not self.use_embeddings:
            # For simplicity, use zero vector if not using embeddings
            return np.zeros(ANSWER_EMBED_DIM)

        try:
            features = self.embedding_model.encode([answer_text])[0]
        except Exception as e:
            print(f"Error embedding answer: {e}")
            return np.zeros(ANSWER_EMBED_DIM)

        # Handle an embedding that doesn't match the expected dimension
        if len(features) > ANSWER_EMBED_DIM:
            features = features[:ANSWER_EMBED_DIM]
        elif len(features) < ANSWER_EMBED_DIM:
            padding = np.zeros(ANSWER_EMBED_DIM - len(features))
            features = np.concatenate([features, padding])
        return features

    def construct_feature_vector(self, base_features, step_i, failed_answers, failed_llm_ids, model_name_to_index):
        """
        Construct the augmented feature vector for LinUCB

        Layout: [base_features | normalized step | last failed answer embedding |
        one-hot of last failed LLM | mean of all failed answer embeddings].

        Parameters:
        - base_features: Features extracted from the question
        - step_i: Current step in the cascade (1-indexed)
        - failed_answers: List of previous failed answer strings
        - failed_llm_ids: List of previous failed LLM IDs
        - model_name_to_index: Mapping from model names to indices

        Returns:
        - Augmented feature vector

        Raises:
        - ValueError: if the resulting length differs from TOTAL_FEATURE_DIM.
        """
        # Normalize the step (1-indexed) by the cascade length
        normalized_step = np.array([step_i / CASCADE_LENGTH])

        # Context is only available after the first step
        has_history = step_i > 1

        last_answer_features = np.zeros(ANSWER_EMBED_DIM)
        avg_answer_features = np.zeros(ANSWER_EMBED_DIM)
        if has_history and failed_answers:
            # Embed every failed answer exactly once; the last one is reused
            # instead of being embedded a second time.
            all_answer_features = [self.extract_answer_features(ans) for ans in failed_answers]
            last_answer_features = all_answer_features[-1]
            avg_answer_features = np.mean(all_answer_features, axis=0)

        # One-hot encoding of the most recent failed LLM (all zeros if unknown)
        last_llm_onehot = np.zeros(LLM_ID_DIM)
        if has_history and failed_llm_ids:
            last_llm_index = model_name_to_index.get(failed_llm_ids[-1])
            if last_llm_index is not None:
                last_llm_onehot[last_llm_index] = 1.0

        augmented_features = np.concatenate([
            base_features,           # Dim: BASE_FEATURE_DIM
            normalized_step,         # Dim: 1
            last_answer_features,    # Dim: ANSWER_EMBED_DIM
            last_llm_onehot,         # Dim: LLM_ID_DIM
            avg_answer_features      # Dim: ANSWER_EMBED_DIM
        ])

        # Ensure the final dimension matches expectations
        if augmented_features.shape[0] != TOTAL_FEATURE_DIM:
            # This should not happen if dimensions are calculated correctly
            raise ValueError(f"Constructed feature vector dimension {augmented_features.shape[0]} != expected {TOTAL_FEATURE_DIM}")

        return augmented_features


class LinUCBModel:
    """Disjoint LinUCB: one ridge-regression model (A, b) per candidate LLM.

    For each model, A starts as lambda_reg * identity and b as zeros. The
    estimated reward for a feature vector x is x . theta with theta = A^-1 b,
    and the exploration bonus is alpha * sqrt(x^T A^-1 x). The class also owns a
    ModelCostEstimator and tracks the time of the last call to each model.
    """

    def __init__(self, model_names, feature_dim=TOTAL_FEATURE_DIM, alpha=ALPHA,
                 lambda_reg=LAMBDA_REG, beta_cost=BETA_COST,
                 default_fallback_cost=DEFAULT_FALLBACK_COST):
        self.model_names = model_names
        self.feature_dim = feature_dim
        self.alpha = alpha
        self.lambda_reg = lambda_reg
        self.beta_cost = beta_cost
        self.default_fallback_cost = default_fallback_cost

        self.model_name_to_index = {name: i for i, name in enumerate(model_names)}

        # Initialize cost estimator
        self.cost_estimator = ModelCostEstimator(model_names)

        # Initialize model parameters
        self.models = {}
        for model_name in model_names:
            self.models[model_name] = {
                'A': np.identity(feature_dim) * lambda_reg,
                'b': np.zeros(feature_dim),
                'last_call_time': 0
            }

    def update_reward_only(self, model_name, feature_vector, reward):
        """
        Update the LinUCB model parameters based on observed reward

        Parameters:
        - model_name: Name of the model to update
        - feature_vector: Feature vector used for selection
        - reward: Binary reward (1 for success, 0 for failure)
        """
        model = self.models[model_name]
        # A <- A + x x^T
        model['A'] += np.outer(feature_vector, feature_vector)
        # b <- b + reward * x
        model['b'] += feature_vector * reward

    def get_predicted_reward(self, model_name, feature_vector):
        """
        Get just the predicted reward (p_ia) for a given model and feature vector

        Parameters:
        - model_name: Name of the model
        - feature_vector: Feature vector for prediction

        Returns:
        - Predicted reward (p_ia); 0.0 if the linear system cannot be solved
        """
        model = self.models[model_name]
        try:
            # Calculate theta using the model's A matrix and b vector
            theta = np.linalg.solve(model['A'], model['b'])
            return float(feature_vector.dot(theta))
        except np.linalg.LinAlgError:
            try:
                # Alternative way to calculate theta if direct solve fails
                A_inv = np.linalg.inv(model['A'])
                theta = A_inv.dot(model['b'])
                return float(feature_vector.dot(theta))
            except Exception:
                # Best effort: report no expected reward, but let
                # KeyboardInterrupt / SystemExit propagate.
                return 0.0

    def calculate_ucb_scores(self, feature_vector):
        """
        Calculate UCB scores for model selection

        Parameters:
        - feature_vector: Feature vector used for model selection

        Returns a dictionary mapping each model name to
        {"p_ia": expected reward, "e_ia": exploration bonus, "ucb_score": their sum}.
        All three are 0.0 for a model whose matrices cannot be processed.
        """
        scores = {}

        for model_name in self.model_names:
            model = self.models[model_name]
            try:
                # With A = L L^T, ||L^-1 x||^2 equals x^T A^-1 x
                L = np.linalg.cholesky(model['A'])
                theta = np.linalg.solve(model['A'], model['b'])
                z = np.linalg.solve(L, feature_vector)
                ucb_term = self.alpha * np.sqrt(np.sum(z**2))
                expected_reward = feature_vector.dot(theta)
            except np.linalg.LinAlgError:
                try:
                    # Fallback via the explicit inverse
                    A_inv = np.linalg.inv(model['A'])
                    theta = A_inv.dot(model['b'])
                    ucb_term = self.alpha * np.sqrt(feature_vector.dot(A_inv).dot(feature_vector))
                    expected_reward = feature_vector.dot(theta)
                except Exception:
                    # Best effort: neutral scores, but let KeyboardInterrupt /
                    # SystemExit propagate.
                    expected_reward = 0.0
                    ucb_term = 0.0

            scores[model_name] = {
                "p_ia": float(expected_reward),
                "e_ia": float(ucb_term),
                "ucb_score": float(expected_reward + ucb_term)
            }
        return scores

    def register_model_call(self, model_name):
        """Register that a model was called and update its last call time"""
        self.models[model_name]['last_call_time'] = time.time()

    def respect_rate_limit(self, model_name):
        """Wait if necessary to respect the model's rate limit.

        Client-side waiting only happens when the model's MODELS_CONFIG entry
        defines a positive 'rpm' (requests per minute). Entries without 'rpm'
        are not throttled here.
        """
        model_cfg = MODELS_CONFIG.get(model_name)
        if not model_cfg or "rpm" not in model_cfg:
            return

        rpm = model_cfg["rpm"]
        if not rpm or rpm <= 0:
            # A missing or non-positive limit cannot be converted to a delay.
            return

        min_seconds_between_calls = 60.0 / rpm
        time_since_last_call = time.time() - self.models[model_name]['last_call_time']
        if time_since_last_call < min_seconds_between_calls:
            time.sleep(min_seconds_between_calls - time_since_last_call)

    def save_model_state(self, file_path):
        """Save the model state to a file using numpy's compressed format

        The cost estimator is saved alongside, at the same path with '.npz'
        replaced by '_cost_estimator.npz'.
        """
        save_dict = {}
        for model_name, model in self.models.items():
            save_dict[f'A_{model_name}'] = model['A']
            save_dict[f'b_{model_name}'] = model['b']

        # Save to compressed numpy format
        np.savez_compressed(file_path, **save_dict)

        # Save cost estimator state
        cost_estimator_path = file_path.replace('.npz', '_cost_estimator.npz')
        self.cost_estimator.save_state(cost_estimator_path)

    def load_model_state(self, file_path):
        """Load the model state from a file

        Matrices are restored for every known model present in the file. Entries
        whose shape does not match this model's feature dimension are skipped
        with a warning, because they would break later updates. Returns True on
        success, False if the file could not be read.
        """
        try:
            # The context manager closes the underlying zip file handle.
            with np.load(file_path) as loaded:
                for model_name, model in self.models.items():
                    a_key, b_key = f'A_{model_name}', f'b_{model_name}'
                    if a_key not in loaded or b_key not in loaded:
                        continue
                    A, b = loaded[a_key], loaded[b_key]
                    if A.shape != (self.feature_dim, self.feature_dim) or b.shape != (self.feature_dim,):
                        print(f"Warning: saved state for {model_name} has shape {A.shape}/{b.shape}, "
                              f"expected feature dimension {self.feature_dim}. Skipping.")
                        continue
                    model['A'] = A
                    model['b'] = b

            print(f"Loaded LinUCB model state from {file_path}")

            # Load cost estimator state
            cost_estimator_path = file_path.replace('.npz', '_cost_estimator.npz')
            if os.path.exists(cost_estimator_path):
                self.cost_estimator.load_state(cost_estimator_path)

            return True

        except Exception as e:
            print(f"Error loading model state: {e}")
            return False



class BatchBudgetCascade:
    """Runs a multi-step LLM cascade for one question at a time.

    At every step the LinUCB model scores the candidate LLMs, a subset is chosen
    (best UCB in training, knapsack under the remaining budget in testing), and
    the chosen models are queried one after another until one answers correctly.
    """

    def __init__(self, feature_extractor, linucb_model, cascade_length=CASCADE_LENGTH):
        self.feature_extractor = feature_extractor
        self.linucb_model = linucb_model
        self.cascade_length = cascade_length
        self.used_models_for_this_question = []

    def grade_with_gemma12b(self, llm_raw_answer, ground_truth_latex):
        """Grade an answer against the ground truth using the grader model via OpenRouter.

        Despite the method name, the model queried is GRADER_MODEL_NAME.

        Returns True for an exact string match or when the grader's reply
        contains "true"; False for empty answers, replies containing "false",
        ambiguous replies and API errors.
        """
        if not llm_raw_answer or not llm_raw_answer.strip():
            return False
        if llm_raw_answer == ground_truth_latex:
            return True

        prompt = f"Expression 1: {llm_raw_answer}\nExpression 2: {ground_truth_latex}\n\n"
        prompt += "Expression 2 is the answer and expression is attempt by student, look at their final answer only which might be boxed, does student get the final expected answer？ Respond with only the word 'True' or 'False'."
        try:
            time.sleep(0.5) # Simple rate limiting for the grader
            api_response = openrouter_client.chat.completions.create(
                model=GRADER_MODEL_NAME,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0, max_tokens=10
            )
            grader_response_text = api_response.choices[0].message.content.strip().lower()
            if "true" in grader_response_text:
                return True
            if "false" in grader_response_text:
                return False
            print(f"Warning: Grader returned ambiguous response: {grader_response_text}")
            return False
        except Exception as e:
            print(f"Error calling grader model {GRADER_MODEL_NAME} via OpenRouter: {e}")
            return False

    def format_prompt(self, question, failed_answers=None, failed_llm_ids=None):
        """Build the solving prompt, listing previous incorrect attempts if any.

        Attempts are only listed when both lists are non-empty; they are paired
        by position and the shorter list determines how many are shown.
        """
        # NOTE(review): the prompt wording (including its typos) is kept verbatim,
        # since editing it would change what the queried models see.
        prompt = f"Solve the following math problem: {question['problem']}\n\n"
        prompt += "Also provide an explnation in one/serveral very short yet consice complete sentence within 75 words in total.\n"
        prompt += "At the end, clearly state your final answer in LaTeX format, enclosed within \\boxed{}.\n"
        prompt += "For example: 'The final answer is \\boxed{x=5}'."
        if failed_answers and failed_llm_ids:
            prompt += "\nNote: The following previous attempts were incorrect. Please provide a different solution:\n"
            for i, (llm_id, answer) in enumerate(zip(failed_llm_ids, failed_answers)):
                prompt += f"- Attempt {i+1} (by {llm_id}) led to: {answer}\n"
        return prompt

    def calculate_token_cost(self, model_name, prompt, response_text, usage_info=None):
        """Calculate the actual cost. For OpenRouter, relies on usage_info from API response.

        Parameters:
        - model_name: Key into MODELS_CONFIG providing the per-token prices.
        - prompt / response_text: Only used for a rough estimate (4 characters
          per token) when usage_info is missing.
        - usage_info: Mapping with "prompt_tokens" and "completion_tokens".

        Returns:
        - {"total_cost": float}, plus an "error" message when the cost is an
          estimate or the model is unknown (cost 0.0 in that case).
        """
        total_cost = 0.0
        error_message = None
        if model_name in MODELS_CONFIG:
            model_cfg = MODELS_CONFIG[model_name]
            if usage_info:
                # A present-but-None count is treated as zero tokens.
                input_tokens = usage_info.get("prompt_tokens") or 0
                output_tokens = usage_info.get("completion_tokens") or 0
            else:
                error_message = f"Usage info not available for {model_name}. Cost is a rough estimate."
                input_tokens = len(prompt) // 4
                output_tokens = len(response_text) // 4 if response_text else 0
            total_cost = (input_tokens * model_cfg["input_cost"]) + (output_tokens * model_cfg["output_cost"])
        else:
            error_message = f"Model {model_name} not found in config for cost calculation."
        result = {"total_cost": total_cost}
        if error_message:
            result["error"] = error_message
            print(f"Cost calculation warning for {model_name}: {error_message}")
        return result

    def query_llm(self, model_name, prompt):
        """Query the specified LLM and return its response and cost.

        Returns:
        - (raw_response, parsed_answer, cost_data). On success both answers are
          the stripped message text (empty if the API returned no content). On
          any error the result is ("", None, estimated cost for the prompt).
        """
        try:
            self.linucb_model.respect_rate_limit(model_name)
            if model_name not in MODELS_CONFIG:
                raise ValueError(f"Model {model_name} is not configured in MODELS_CONFIG.")
            api_response = openrouter_client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": prompt}],
            )
            # The API may return a message without content; treat it as an
            # empty answer so the reported usage is still charged correctly.
            content = api_response.choices[0].message.content
            answer_text = content.strip() if content else ""
            usage_info = api_response.usage.model_dump() if api_response.usage else None
            cost_data = self.calculate_token_cost(model_name, prompt, answer_text, usage_info=usage_info)
            self.linucb_model.register_model_call(model_name)
            return answer_text, answer_text, cost_data # Returns (raw_response, parsed_answer, cost_data)
        except Exception as e:
            print(f"Error querying LLM {model_name}: {e}")
            self.linucb_model.register_model_call(model_name)
            cost_data = self.calculate_token_cost(model_name, prompt, "", usage_info=None)
            return "", None, cost_data

    def run_cascade_single_question(self, question, question_cost_estimator, is_training=False, initial_question_budget=None):
        """
        Run the cascade with knapsack-based model selection for a single question.

        Parameters:
        - question: Mapping with at least "problem" and "answer".
        - question_cost_estimator: Supplies the default test budget and receives
          the total cost spent on this question.
        - is_training: If True, use infinite budget and no LLM reselection.
        - initial_question_budget: Budget for this question (used in testing).

        Returns:
        - A record dict describing the outcome, total cost and per-step log.

        Every incorrect, non-empty answer is remembered together with the model
        that produced it, so that later prompts list the failed attempts and the
        next step's feature vector carries the failure context.
        """
        base_features = self.feature_extractor.extract_features(question)
        failed_answers, failed_llm_ids, self.used_models_for_this_question = [], [], []
        question_total_cost, current_attempts_log = 0.0, []
        final_status = "Failure"

        # Set budget based on training/testing mode
        if is_training:
            current_remaining_budget = float('inf')
            print(f"Training Mode - Unlimited Budget")
        else:
            current_remaining_budget = initial_question_budget if initial_question_budget is not None else question_cost_estimator.estimate_question_lcb_cost()
            print(f"Test Mode - Initial Budget: ${current_remaining_budget:.8f}")

        # Knapsack-based Cascade Loop
        for i in range(1, self.cascade_length + 1):
            print(f"Step {i}")
            if not is_training and current_remaining_budget <= EPSILON_VD:
                print(f"  Insufficient budget remaining (${current_remaining_budget:.8f}). Aborting cascade.")
                break

            x_i = self.feature_extractor.construct_feature_vector(base_features, i, failed_answers, failed_llm_ids, self.linucb_model.model_name_to_index)
            available_llms_for_step = [m for m in AVAILABLE_LLMS if m not in self.used_models_for_this_question] if is_training else list(AVAILABLE_LLMS)

            # Prepare items for knapsack by estimating value (UCB) and weight (LCB cost)
            knapsack_items, llm_info_for_step = [], {}
            ucb_scores = self.linucb_model.calculate_ucb_scores(x_i)
            knapsack_capacity = float('inf') if is_training else current_remaining_budget

            for model_name in available_llms_for_step:
                if model_name not in ucb_scores: continue
                lcb_cost = self.linucb_model.cost_estimator.estimate_lcb_cost(model_name, i)
                if is_training or lcb_cost <= knapsack_capacity:
                    knapsack_items.append({"name": model_name, "value": ucb_scores[model_name]["ucb_score"], "weight": lcb_cost})
                llm_info_for_step[model_name] = {"ucb_score": ucb_scores[model_name]["ucb_score"], "lcb_cost": lcb_cost}

            if not knapsack_items:
                # Scores and cost estimates depend on the step, so a later step
                # may still find a model that fits; move on rather than abort.
                print("  No models fit within the remaining budget for this step.")
                continue

            # In training, select best UCB; in testing, use knapsack solver
            if is_training:
                chosen_llm_names = [max(knapsack_items, key=lambda x: x['value'])['name']]
            else:
                chosen_llm_names = solve_0_1_knapsack(knapsack_items, knapsack_capacity)

            if not chosen_llm_names:
                print("  Knapsack solution is empty. No combination fits within budget.")
                continue

            # Execute chosen subset sequentially, ordered by UCB score
            step_successful = False
            chosen_llms_sorted = sorted(chosen_llm_names, key=lambda name: llm_info_for_step[name]["ucb_score"], reverse=True)
            print(f"  Attempting LLMs in UCB score order: {chosen_llms_sorted}")

            for model_name in chosen_llms_sorted:
                lcb_cost = llm_info_for_step[model_name]["lcb_cost"]
                if not is_training and current_remaining_budget < lcb_cost:
                    print(f"    Skipping {model_name}: Estimated cost (${lcb_cost:.8f}) exceeds remaining budget (${current_remaining_budget:.8f}).")
                    continue

                prompt = self.format_prompt(question, failed_answers, failed_llm_ids)
                raw_response, parsed_answer, cost_data = self.query_llm(model_name, prompt)
                actual_cost = cost_data.get("total_cost", 0.0)
                question_total_cost += actual_cost
                if not is_training:
                    current_remaining_budget -= actual_cost

                self.linucb_model.cost_estimator.update_stats(model_name, i, actual_cost)
                is_correct = self.grade_with_gemma12b(raw_response, question['answer'])
                reward = 1 if is_correct else 0
                self.linucb_model.update_reward_only(model_name, x_i, reward)
                self.used_models_for_this_question.append(model_name)
                print(f"    {model_name}: Correct={is_correct}, Cost=${actual_cost:.8f}")

                if is_correct:
                    step_successful = True
                    final_status = "Success"
                    break # Success, break from inner sequential loop

                # Remember the failed attempt. Empty responses (e.g. API errors)
                # carry no information and are left out; both lists are always
                # extended together so they stay aligned.
                if raw_response and raw_response.strip():
                    failed_answers.append(raw_response)
                    failed_llm_ids.append(model_name)

            current_attempts_log.append({"step": i, "chosen_subset": chosen_llm_names, "is_successful": step_successful})
            if step_successful:
                break # Success, break from outer cascade loop

        question_cost_estimator.update_stats(question_total_cost)
        return {
            "problem": question["problem"], "ground_truth_answer": question["answer"],
            "final_status": final_status, "total_cost": question_total_cost,
            "initial_question_budget": initial_question_budget if not is_training else "Unlimited",
            "steps_taken": len(current_attempts_log), "knapsack_steps": current_attempts_log,
            "training_mode": is_training
        }

# In [None]:

def _recover_json_objects(json_path):
    """Best-effort salvage of flat JSON objects from a file that failed to parse.

    Scans the raw text for brace-delimited spans that contain no nested braces
    and keeps every span that parses as JSON on its own.

    Returns:
        A list of the recovered objects; empty when the file cannot be read or
        nothing could be salvaged.
    """
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Could not re-read {json_path} for recovery: {e}")
        return []

    # Simple recovery attempt - only objects without nested braces can match.
    matches = re.findall(r'\{[^{}]*\}', content)
    if not matches:
        return []

    print(f"Attempted to recover {len(matches)} JSON objects.")
    recovered_data = []
    for match in matches:
        try:
            recovered_data.append(json.loads(match))
        except json.JSONDecodeError:
            # A brace-delimited span that is not valid JSON; skip it.
            continue
    return recovered_data


def load_dataset(json_path):
    """Load the MCQ dataset from a JSON file with robust error handling.

    Accepted layouts:
      * an object whose ``"test"`` key holds the list of questions, or
      * a flat list of questions (a warning is printed, the list is returned).

    Any other top-level JSON value (number, string, null, object without a
    ``"test"`` list) yields an empty list instead of raising.

    Parameters:
    - json_path: Path of the JSON file to read.

    Returns:
    - The list of questions, or ``[]`` when the file is missing or holds no
      usable data. When the file is not valid JSON, whatever flat objects can
      be salvaged from its text are returned instead.
    """
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            raw_data = json.load(f)
    except FileNotFoundError:
        print(f"Dataset file {json_path} not found.")
        return []
    except json.JSONDecodeError:
        print(f"Error decoding JSON from {json_path}. File may be corrupted.")
        return _recover_json_objects(json_path)

    # Only a dict can carry the "test" key; checking the type first avoids a
    # TypeError on scalar/null payloads and on lists that contain "test".
    if isinstance(raw_data, dict) and isinstance(raw_data.get("test"), list):
        return raw_data["test"]

    print(f"Warning: 'test' key not found or not a list in {json_path}. Assuming flat list structure or empty.")
    return raw_data if isinstance(raw_data, list) else []

def save_records_with_backup(records, json_path):
    """Save records to a JSON file, keeping the previous file as a backup.

    The new content is first written to a sibling temporary file
    (``json_path + ".tmp"``). Only after serialization succeeds is the
    existing file moved to ``json_path + BACKUP_SUFFIX`` and the temporary
    file moved into place, so a failed or interrupted write never leaves a
    truncated records file behind.

    Parameters:
    - records: JSON-serializable object to store.
    - json_path: Destination path of the records file.

    Returns:
    - True when the new file is in place, False otherwise (the error is
      printed, not raised).
    """
    backup_path = json_path + BACKUP_SUFFIX
    tmp_path = json_path + ".tmp"

    # Serialize to a scratch file first; the live file is untouched so far.
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(records, f, indent=4)
    except Exception as e:
        print(f"Error saving records: {e}")
        try:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        except OSError as cleanup_error:
            print(f"Failed to remove temporary file {tmp_path}: {cleanup_error}")
        return False

    # Create backup of existing file if it exists (best effort, as before).
    backup_created = False
    if os.path.exists(json_path):
        try:
            os.replace(json_path, backup_path)
            backup_created = True
        except Exception as e:
            print(f"Failed to create backup: {e}")

    # Move the fully written file into place.
    try:
        os.replace(tmp_path, json_path)
        return True
    except OSError as e:
        print(f"Error saving records: {e}")

        # Try to restore from the backup made above if the final move failed.
        if backup_created and os.path.exists(backup_path):
            try:
                os.replace(backup_path, json_path)
            except OSError as restore_error:
                print(f"Failed to restore backup: {restore_error}")

        return False

def initialize_json_files():
    """Ensure RECORDS_PATH exists and holds a JSON list.

    * Missing file: an empty JSON array is written.
    * Existing file that is not decodable JSON, or whose top-level value is
      not a list: the file is moved to ``RECORDS_PATH + BACKUP_SUFFIX`` and a
      fresh empty array is written in its place.
    * Existing file holding a JSON list: left untouched.
    """
    def _write_empty_records():
        with open(RECORDS_PATH, 'w', encoding='utf-8') as f:
            json.dump([], f)  # Empty array

    if not os.path.exists(RECORDS_PATH):
        _write_empty_records()
        return

    # Validate existing file
    try:
        with open(RECORDS_PATH, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError):
        data = None  # Treated like any other non-list payload below

    if not isinstance(data, list):
        # os.replace overwrites an existing backup on every platform;
        # os.rename raises on Windows when the destination already exists.
        os.replace(RECORDS_PATH, RECORDS_PATH + BACKUP_SUFFIX)
        _write_empty_records()

def analyze_results(records_to_analyze, question_cost_estimator=None):
    """Analyze cascade records, print a text report and save it to disk.

    The report covers overall success/cost figures, per-position success,
    knapsack (per-step) metrics, per-model performance, LCB cost-estimate
    accuracy and per-batch performance. It is written to SUMMARY_STATS_PATH
    with ``.txt`` replaced by ``_<TRAIN|TEST|OVERALL>.txt``; the label is
    taken from the ``"dataset"`` field of the first record.

    Step dictionaries may come in two shapes:
      * the detailed shape with ``chosen_subset_by_knapsack``,
        ``knapsack_capacity_for_step``, ``actual_cost_for_step`` and
        ``llm_attempts``;
      * the compact shape with only ``step``, ``chosen_subset`` and
        ``is_successful``.
    Missing detailed fields are treated as "no selected models", an
    "Unlimited" budget, zero step cost and no LLM attempts, so compact
    records are summarized instead of raising ``KeyError``.

    Parameters:
    - records_to_analyze: List of per-question record dicts.
    - question_cost_estimator: Optional estimator; when given and its
      ``stats["count"]`` is positive, a question-cost section is added.

    Returns:
    - None. Prints "No records to analyze" and returns early on empty input.
    """
    if not records_to_analyze:
        print("No records to analyze")
        return

    def _new_call_stats():
        return {"calls": 0, "successes": 0, "total_cost": 0.0}

    def _new_model_metrics():
        stats = _new_call_stats()
        stats["by_position"] = defaultdict(_new_call_stats)
        return stats

    def _new_lcb_stats():
        return {"total_lcb_estimate": 0.0, "total_actual_cost": 0.0, "count": 0}

    # --- Overall figures (records_to_analyze is non-empty here) ---
    total_questions = len(records_to_analyze)
    successful_records = [r for r in records_to_analyze if r["final_status"] == "Success"]
    successful_questions = len(successful_records)
    success_rate = successful_questions / total_questions

    # Success by position: index of the first successful step within a record.
    successes_by_position = [0] * CASCADE_LENGTH
    for record in successful_records:
        for index, step in enumerate(record["knapsack_steps"]):
            if step["is_successful"]:
                if index < CASCADE_LENGTH:
                    successes_by_position[index] += 1
                break
    per_position_success = [count / total_questions for count in successes_by_position]

    total_steps = sum(r["steps_taken"] for r in records_to_analyze)
    total_cost = sum(r["total_cost"] for r in records_to_analyze)
    avg_steps = total_steps / total_questions
    avg_cost = total_cost / total_questions

    # Average cost for successful questions only
    if successful_records:
        avg_cost_success = sum(r["total_cost"] for r in successful_records) / successful_questions
    else:
        avg_cost_success = 0

    # --- Accumulators ---
    model_metrics = {model: _new_model_metrics() for model in AVAILABLE_LLMS}
    lcb_cost_accuracy = {model: _new_lcb_stats() for model in AVAILABLE_LLMS}
    knapsack_metrics = {
        "total_steps": 0,
        "avg_models_per_step": 0,
        "models_selected_count": 0,
        "budget_utilization": 0.0,
        "total_budget_allocated": 0.0,
        "by_position": defaultdict(lambda: {"steps": 0, "models_selected": 0, "budget_allocated": 0.0, "actual_cost": 0.0})
    }

    for record in records_to_analyze:
        for step in record["knapsack_steps"]:
            position = step["step"]
            # Fall back to the compact step shape when detailed keys are absent.
            chosen_models = step.get("chosen_subset_by_knapsack", step.get("chosen_subset")) or []
            capacity = step.get("knapsack_capacity_for_step", "Unlimited")
            step_cost = step.get("actual_cost_for_step", 0.0)
            attempts = step.get("llm_attempts") or []

            pos_stats = knapsack_metrics["by_position"][position]
            knapsack_metrics["total_steps"] += 1
            knapsack_metrics["models_selected_count"] += len(chosen_models)

            # Budget figures only make sense for steps with a finite capacity.
            if capacity != "Unlimited":
                step_budget = float(capacity)
                knapsack_metrics["total_budget_allocated"] += step_budget
                pos_stats["budget_allocated"] += step_budget

            pos_stats["steps"] += 1
            pos_stats["models_selected"] += len(chosen_models)
            pos_stats["actual_cost"] += step_cost

            # Process individual LLM attempts
            for attempt in attempts:
                model = attempt["knapsack_model"]
                is_correct = attempt["is_correct"]
                cost = attempt["actual_cost"]

                # Records may mention a model that is no longer configured.
                if model not in model_metrics:
                    model_metrics[model] = _new_model_metrics()
                if model not in lcb_cost_accuracy:
                    lcb_cost_accuracy[model] = _new_lcb_stats()

                if "lcb_cost_estimate" in attempt:
                    lcb_stats = lcb_cost_accuracy[model]
                    lcb_stats["total_lcb_estimate"] += attempt["lcb_cost_estimate"]
                    lcb_stats["total_actual_cost"] += cost
                    lcb_stats["count"] += 1

                # Update overall and position-specific model stats together.
                for stats in (model_metrics[model], model_metrics[model]["by_position"][position]):
                    stats["calls"] += 1
                    stats["total_cost"] += cost
                    if is_correct:
                        stats["successes"] += 1

    # --- Derived metrics ---
    if knapsack_metrics["total_steps"] > 0:
        knapsack_metrics["avg_models_per_step"] = knapsack_metrics["models_selected_count"] / knapsack_metrics["total_steps"]
        if knapsack_metrics["total_budget_allocated"] > 0:
            knapsack_metrics["budget_utilization"] = total_cost / knapsack_metrics["total_budget_allocated"]

    for data in model_metrics.values():
        for stats in (data, *data["by_position"].values()):
            if stats["calls"] > 0:
                stats["success_rate"] = stats["successes"] / stats["calls"]
                stats["avg_cost"] = stats["total_cost"] / stats["calls"]
            else:
                stats["success_rate"] = 0
                stats["avg_cost"] = 0

    for data in lcb_cost_accuracy.values():
        if data["count"] > 0:
            data["avg_lcb_estimate"] = data["total_lcb_estimate"] / data["count"]
            data["avg_actual_cost"] = data["total_actual_cost"] / data["count"]
            data["lcb_accuracy"] = data["avg_lcb_estimate"] / data["avg_actual_cost"] if data["avg_actual_cost"] > 0 else 0
        else:
            data["avg_lcb_estimate"] = 0
            data["avg_actual_cost"] = 0
            data["lcb_accuracy"] = 0

    for data in knapsack_metrics["by_position"].values():
        if data["steps"] > 0:
            data["avg_models_per_step"] = data["models_selected"] / data["steps"]
            if data["budget_allocated"] > 0:
                data["avg_budget"] = data["budget_allocated"] / data["steps"]
                data["budget_utilization"] = data["actual_cost"] / data["budget_allocated"]
            else:
                data["avg_budget"] = "Unlimited"
                data["budget_utilization"] = 0
            data["avg_cost"] = data["actual_cost"] / data["steps"]

    # --- Batch performance (consecutive slices of BATCH_SIZE records) ---
    batch_metrics = []
    for start in range(0, total_questions, BATCH_SIZE):
        batch = records_to_analyze[start:start + BATCH_SIZE]
        batch_success = sum(1 for r in batch if r["final_status"] == "Success")
        batch_cost = sum(r["total_cost"] for r in batch)
        batch_metrics.append({
            "batch_idx": start // BATCH_SIZE,
            "batch_size": len(batch),
            "success_count": batch_success,
            "success_rate": batch_success / len(batch),
            "total_cost": batch_cost,
            "avg_cost": batch_cost / len(batch)
        })

    # --- Summary text (collected in a list and joined once) ---
    dataset_type = {"train": "TRAIN", "test": "TEST"}.get(records_to_analyze[0].get("dataset"), "OVERALL")
    parts = [
        f"=== LinUCB CASCADE WITH KNAPSACK SELECTION - {dataset_type} SET RESULTS ===\n\n",
        f"Batch Size: {BATCH_SIZE}\n",
        f"LLM Cost Beta: {BETA_COST}\n",
        f"Question Cost Beta: {BETA_COST_QUESTION}\n",
        f"Cascade Length (K): {CASCADE_LENGTH}\n\n",
        "=== OVERALL RESULTS ===\n",
        f"Total Questions: {total_questions}\n",
        f"Success Rate: {success_rate:.4f}\n",
        f"Average Steps: {avg_steps:.4f}\n",
        f"Average Cost per Question: ${avg_cost:.8f}\n",
        f"Average Cost per Successful Question: ${avg_cost_success:.8f}\n",
        "Success Rate by Position:\n",
    ]
    for i, rate in enumerate(per_position_success):
        parts.append(f"  Position {i+1}: {rate:.4f}\n")

    # Add question cost estimation info if available
    if question_cost_estimator and question_cost_estimator.stats["count"] > 0:
        avg_question_cost = question_cost_estimator.stats["total_cost"] / question_cost_estimator.stats["count"]
        estimated_lcb = question_cost_estimator.estimate_question_lcb_cost()

        parts.append("\n=== QUESTION COST ESTIMATION ===\n")
        parts.append(f"Total Questions Processed: {question_cost_estimator.stats['count']}\n")
        parts.append(f"Average Question Cost: ${avg_question_cost:.8f}\n")
        parts.append(f"Current LCB Question Cost Estimate: ${estimated_lcb:.8f}\n")

    # Add knapsack metrics
    parts.append("\n=== KNAPSACK METRICS ===\n")
    parts.append(f"Average Models Selected per Step: {knapsack_metrics['avg_models_per_step']:.4f}\n")
    if knapsack_metrics["total_budget_allocated"] > 0:
        parts.append(f"Overall Budget Utilization: {knapsack_metrics['budget_utilization']:.4f}\n")
    parts.append("By Position:\n")
    for position, data in sorted(knapsack_metrics["by_position"].items()):
        if data["steps"] > 0:
            parts.append(f"  Position {position}: {data['avg_models_per_step']:.2f} models/step, ")
            if data["avg_budget"] != "Unlimited":
                parts.append(f"Budget=${data['avg_budget']:.8f}, ")
                parts.append(f"Utilization={data['budget_utilization']:.4f}, ")
            else:
                parts.append("Budget=Unlimited, ")
            parts.append(f"Cost=${data['avg_cost']:.8f}\n")

    parts.append("\n=== MODEL PERFORMANCE ===\n")
    for model, metrics in model_metrics.items():
        if metrics["calls"] > 0:
            parts.append(f"\n{model}:\n")
            parts.append(f"  Overall: {metrics['successes']}/{metrics['calls']} = {metrics['success_rate']:.4f}\n")
            parts.append(f"  Average Cost: ${metrics['avg_cost']:.8f}\n")
            parts.append("  By Position:\n")
            for position, pos_data in sorted(metrics["by_position"].items()):
                if pos_data["calls"] > 0:
                    parts.append(f"    Pos {position}: {pos_data['successes']}/{pos_data['calls']} = {pos_data['success_rate']:.4f}, Avg Cost: ${pos_data['avg_cost']:.8f}\n")

    parts.append("\n=== LLM LCB COST ESTIMATION ACCURACY ===\n")
    for model, data in lcb_cost_accuracy.items():
        if data["count"] > 0:
            parts.append(f"{model}: Avg LCB Est: ${data['avg_lcb_estimate']:.8f}, Avg Actual: ${data['avg_actual_cost']:.8f}, Ratio: {data['lcb_accuracy']:.4f}\n")

    parts.append("\n=== BATCH PERFORMANCE ===\n")
    for batch in batch_metrics:
        parts.append(f"Batch {batch['batch_idx']}: {batch['success_count']}/{batch['batch_size']} = {batch['success_rate']:.4f}, Total Cost: ${batch['total_cost']:.8f}, Avg Cost: ${batch['avg_cost']:.8f}\n")

    parts.append(f"\nTotal Overall Cost (All Questions): ${total_cost:.8f}\n")
    summary = "".join(parts)

    # Print and save summary
    print(summary)

    # Save to appropriate file based on dataset type
    summary_file_path = SUMMARY_STATS_PATH.replace(".txt", f"_{dataset_type}.txt")
    with open(summary_file_path, 'w', encoding='utf-8') as f:
        f.write(summary)


def select_questions_by_lcb_cost(remaining_questions, question_cost_estimator, batch_size, pool_multiplier=QUESTION_POOL_MULTIPLIER):
    """
    Select the next batch of questions based on LCB cost estimates

    The first ``batch_size * pool_multiplier`` questions form the candidate
    pool. The estimator is queried once per candidate (it is called without
    the question itself), the ``batch_size`` cheapest candidates are picked,
    and ``remaining_questions`` is reordered IN PLACE to: selected questions,
    then the non-selected pool questions, then everything beyond the pool.
    Relative order inside each of the three groups is the original order;
    ties in cost keep the original order because the sort is stable.

    Parameters:
    - remaining_questions: List of unprocessed questions (mutated in place
      unless the pool is no larger than one batch)
    - question_cost_estimator: The estimator for question costs
    - batch_size: Number of questions to select
    - pool_multiplier: How many times batch_size to consider

    Returns:
    - Selected questions for the next batch
    """
    pool_size = min(len(remaining_questions), batch_size * pool_multiplier)
    if pool_size <= batch_size:
        # Not enough questions for a meaningful pool: take the first batch as is.
        return remaining_questions[:batch_size]

    candidate_pool = remaining_questions[:pool_size]

    # One LCB estimate per candidate, paired with the candidate's pool index.
    question_costs = [
        (idx, question_cost_estimator.estimate_question_lcb_cost())
        for idx in range(pool_size)
    ]
    question_costs.sort(key=lambda pair: pair[1])  # ascending cost, stable

    # Indices of the cheapest candidates, restored to original order.
    selected_indices = sorted(idx for idx, _ in question_costs[:batch_size])
    selected_lookup = set(selected_indices)
    selected_questions = [candidate_pool[idx] for idx in selected_indices]

    # Walk the pool in order rather than iterating a set of indices, whose
    # iteration order is not guaranteed and could shuffle the leftovers.
    non_selected_pool = [
        question for idx, question in enumerate(candidate_pool)
        if idx not in selected_lookup
    ]

    # Update the remaining_questions list (in-place): selected first, then the
    # non-selected pool, then the untouched tail.
    remaining_questions[:] = selected_questions + non_selected_pool + remaining_questions[pool_size:]

    return selected_questions


In [None]:
def main():
    """Run the full pipeline: load data, train, test, checkpoint and report.

    Steps:
      1. Validate the records file and load the dataset from INPUT_JSON.
      2. Build the feature extractor, LinUCB model, question cost estimator
         and cascade, restoring saved estimator/model state when present.
      3. Split the FULL dataset 20% / 80% into training and test questions,
         then drop questions already present in the saved records. Splitting
         before filtering keeps each question in the same phase when a run is
         resumed.
      4. Process training questions (``is_training=True``) and test questions
         (fixed per-question budget), checkpointing every UPDATE_FREQUENCY
         questions.
      5. Always save progress and print/write the train and test analyses,
         even after a KeyboardInterrupt.
    """
    print("Starting LinUCB Cascade with Knapsack Selection")
    initialize_json_files()
    dataset = load_dataset(INPUT_JSON)
    if not dataset:
        print("Dataset is empty or could not be loaded. Exiting.")
        return

    feature_extractor = FeatureExtractor()
    feature_extractor.initialize(dataset)
    linucb_model = LinUCBModel(model_names=AVAILABLE_LLMS)
    question_cost_estimator = QuestionCostEstimator()
    if os.path.exists(QUESTION_COST_PATH):
        question_cost_estimator.load_state(QUESTION_COST_PATH)
    cascade = BatchBudgetCascade(feature_extractor, linucb_model)

    all_records = []
    processed_questions = set()
    if os.path.exists(RECORDS_PATH):
        try:
            with open(RECORDS_PATH, 'r') as f:
                all_records = json.load(f)
            processed_questions = {r["problem"] for r in all_records}
            pending_count = sum(1 for q in dataset if q["problem"] not in processed_questions)
            print(f"Loaded {len(all_records)} existing records. {pending_count} new questions to process.")
        except Exception as e:
            print(f"Could not load existing records: {e}")

    if os.path.exists(LINUCB_MODEL_PATH):
        linucb_model.load_model_state(LINUCB_MODEL_PATH)

    # Split first, filter second: the boundary is fixed by the full dataset so
    # a resumed run does not push former test questions into training.
    train_size = int(len(dataset) * 0.2)
    train_dataset = [q for q in dataset[:train_size] if q["problem"] not in processed_questions]
    test_dataset = [q for q in dataset[train_size:] if q["problem"] not in processed_questions]
    print(f"Split dataset into {len(train_dataset)} training and {len(test_dataset)} testing questions.")

    def _save_progress():
        """Persist records, LinUCB state and question-cost state."""
        save_records_with_backup(all_records, RECORDS_PATH)
        linucb_model.save_model_state(LINUCB_MODEL_PATH)
        question_cost_estimator.save_state(QUESTION_COST_PATH)

    # Using a fixed, tuned budget for reproducibility in the test phase
    test_budget = 0.00014217

    try:
        # --- Training Phase ---
        print("\n=== PROCESSING TRAINING SET ===")
        for idx, question in enumerate(train_dataset):
            print(f"\nProcessing Training Question #{idx + 1}/{len(train_dataset)}")
            question_record = cascade.run_cascade_single_question(question, question_cost_estimator, is_training=True)
            question_record["dataset"] = "train"
            all_records.append(question_record)
            if (idx + 1) % UPDATE_FREQUENCY == 0:
                _save_progress()

        # --- Testing Phase ---
        print("\n=== PROCESSING TEST SET ===")
        for idx, question in enumerate(test_dataset):
            print(f"\nProcessing Test Question #{idx + 1}/{len(test_dataset)}")
            question_record = cascade.run_cascade_single_question(question, question_cost_estimator, is_training=False, initial_question_budget=test_budget)
            question_record["dataset"] = "test"
            all_records.append(question_record)
            if (idx + 1) % UPDATE_FREQUENCY == 0:
                _save_progress()

    except KeyboardInterrupt:
        print("\nProcessing interrupted. Saving progress...")
    finally:
        _save_progress()

        # Analyze and report results
        train_records = [r for r in all_records if r.get("dataset") == "train"]
        test_records = [r for r in all_records if r.get("dataset") == "test"]
        print("\n--- FINAL ANALYSIS ---")
        analyze_results(train_records, question_cost_estimator)
        analyze_results(test_records, question_cost_estimator)
        print("\nLinUCB Cascade with Knapsack Selection completed!")

# Run the pipeline only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()