In [None]:
!pip install -qU transformers sentence-transformers
!pip install scipy spacy nltk textstat
!python -m spacy download en_core_web_sm

In [None]:
import requests
import json
from requests.exceptions import RequestException
from typing import Optional, Type
from pydantic import BaseModel

class SingletonMeta(type):
    """Metaclass that hands out a single shared instance per class.

    The first call constructs the instance with the given arguments; every later
    call returns that same object and ignores its arguments.
    """
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]

class APIManager(metaclass=SingletonMeta):
    """HTTP client for the completion server used by the rest of the notebook.

    Because of ``SingletonMeta`` every ``APIManager()`` call returns the same
    instance, so all requests share one ``requests.Session``.
    """

    def __init__(self):
        self.session = requests.Session()

    def _api_request(self, method, url, **kwargs):
        """Perform an HTTP request and return the decoded JSON body.

        ``kwargs`` are forwarded unchanged to ``requests.Session.request``.

        Raises:
            Exception: when ``requests`` raises a ``RequestException`` (including
                a non-2xx status reported by ``raise_for_status``); the original
                error is chained as ``__cause__``.
        """
        try:
            response = self.session.request(method, url, **kwargs)
            response.raise_for_status()
            return response.json()
        except RequestException as e:
            # Chain the cause so the underlying HTTP/connection error stays visible.
            raise Exception(f'API request error: {e}') from e

    def run(self, prompt: str, temperature: float = 0.0, api_url: str = "http://localhost:8000",
            model_schema: Optional[Type[BaseModel]] = None):
        """
        Send a completion request to the vLLM API server.

        Parameters:
        - prompt: text prompt to generate from
        - temperature: controls randomness (0.0 = deterministic)
        - api_url: base URL of the API server (on Kaggle use ngrok to get the reverse proxy URL)
        - model_schema: optional pydantic model class (e.g., ScoreModel); its JSON
          schema is sent as ``output_schema`` and the generated text is parsed as JSON

        Returns:
        - the generated text, or the parsed JSON value when model_schema is given

        Raises:
        - Exception: on HTTP/connection errors (see _api_request)
        - ValueError: if the response has no ``choices[0]['text']``, or if the text
          cannot be parsed as JSON when model_schema is given
        """
        request_payload = {
            'prompt': prompt,
            'max_tokens': 2048,
            'temperature': temperature,
            'top_p': 0.9,
            'min_p': 0.0,
            'top_k': 0,
            'typical_p': 1.0,
            'tfs': 1.0,
            'top_a': 0.0,
            'repetition_penalty': 1.0,
            'min_new_tokens': 200,
            'no_repeat_ngram_size': 0,
            'num_beams': 1,
            'seed': -1,
            'add_bos_token': True,
            'truncation_length': 8192,
            'ban_eos_token': False,
            'skip_special_tokens': True,
        }

        # Add output_schema only if model_schema is provided
        if model_schema:
            request_payload['output_schema'] = model_schema.model_json_schema()

        response = self._api_request('POST', f'{api_url}/v1/completions', json=request_payload)
        try:
            generated_text = response['choices'][0]['text']
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError(f"Unexpected API response format: {str(response)[:100]}...") from e

        if not model_schema:
            return generated_text

        try:
            # Valid JSON is only guaranteed if the server actually enforces
            # `output_schema`; otherwise this parse can fail.
            return json.loads(generated_text)
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse response as JSON: {e}. Response: {generated_text[:100]}...") from e

# Module-level client used by LLMInterface.ask_llm_remote. Because APIManager
# uses the SingletonMeta metaclass, any later APIManager() call returns this
# same instance.
api_manager = APIManager()

In [None]:
import os
import math
import json
import numpy as np
import re
import torch
import matplotlib.pyplot as plt
import statistics
from collections import deque, defaultdict
from difflib import SequenceMatcher
from IPython.display import display
import networkx as nx
from pydantic import BaseModel, Field
from typing import Union, List
%matplotlib inline

# For Metrics
from collections import Counter
from itertools import tee
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import cosine
import spacy
import nltk
nltk.download('punkt') # Not needed on Kaggle
from nltk.tokenize import sent_tokenize
from textstat import flesch_reading_ease
from textblob import TextBlob
from sentence_transformers import SentenceTransformer

## Pydantic models for guided decoding with vLLM
class ScoreModel(BaseModel):
    # Expected shape of a scoring reply: {"score": <int in [0, 100]>}.
    # PromptManager.extract_structured_data validates "score" replies with it.
    # NOTE(review): documented with comments rather than a class docstring,
    # because pydantic v2 copies a class docstring into model_json_schema()'s
    # "description". That schema is sent to the server by APIManager.run.
    score: int = Field(..., ge=0, le=100, description="Score between 0 and 100")

class FeedbackModel(BaseModel):
    # Expected shape of a feedback reply: {"issues": [...]}, as requested by
    # PromptManager.generate_feedback_prompt. A single string is also accepted,
    # so consumers must handle `issues` being a str.
    # Passed as model_schema by PromptManager.generate_feedback and used for
    # validation in PromptManager.extract_structured_data.
    # NOTE(review): no class docstring on purpose; pydantic v2 would copy it
    # into the JSON schema that is sent to the server.
    issues: Union[str, List[str]]

class LLMInterface:
    """Format prompts for a chat template and send them to the remote LLM.

    ``INSTRUCTION_SET`` names the active template; all methods read it.
    """
    INSTRUCTION_SET = 'llama3'

    # Turn markers for each supported template. "end" (when present) closes a
    # turn. Templates without an "input" marker (vicuna) get the input text
    # appended without a header.
    _INSTRUCTION_SETS = {
        "alpaca": {
            "instruction": "### Instruction:\n",
            "input": "### Input:\n",
            "response": "### Response:\n"
        },
        "vicuna": {
            "instruction": "### USER:\n",
            "response": "### ASSISTANT:\n"
            # No "input" for Vicuna
        },
        "llama3": {
            "instruction": "<|start_header_id|>system<|end_header_id|>\n\n",
            "input": "<|start_header_id|>user<|end_header_id|>\n\n",
            "response": "<|start_header_id|>assistant<|end_header_id|>\n\n",
            "end": "<|eot_id|>\n"
        },
        "chatml": {
            "instruction": "<|im_start|>system\n",
            "input": "<|im_start|>user\n",
            "response": "<|im_start|>assistant\n",
            "end": "<|im_end|>\n"
        }
    }

    # Text marking the start of the assistant reply. ask_llm_remote keeps only
    # what follows the last occurrence of it.
    _RESPONSE_DELIMITERS = {
        'alpaca': "### Response:\n",
        'vicuna': "### ASSISTANT:",
        'llama3': "assistant\n\n",
        'chatml': "<|im_start|>assistant\n"
    }

    @classmethod
    def ask_llm_remote(cls, prompt, model_schema=None):
        """Send *prompt* to the API and return the model's reply.

        Temperature is 1.0 for the "creative" topic and 0.5 otherwise. When
        *model_schema* is given the parsed JSON value is returned unchanged.
        Otherwise the text after the last response delimiter of the active
        template is returned stripped; if there is no delimiter, the text is
        returned as-is. Returns None if the API call raises.
        """
        temperature = 1.0 if PromptManager.get_topic() == 'creative' else 0.5
        try:
            full_response = api_manager.run(prompt, temperature=temperature, model_schema=model_schema)
        except Exception as e:
            print(f"Error calling LLM API: {e}")
            return None

        if not isinstance(full_response, str):
            # Parsed JSON from a schema-guided call: there is no text to split.
            return full_response

        response_delimiter = cls._RESPONSE_DELIMITERS[cls.INSTRUCTION_SET]
        if response_delimiter in full_response:
            return full_response.split(response_delimiter)[-1].strip()
        return full_response

    @classmethod
    def replace_placeholders(cls, prompt_parts):
        """Render ``(text, part)`` pairs into a prompt for the active instruction set.

        *prompt_parts* is a list of ``(text, part)`` tuples, or one such tuple.
        ``part`` is "instruction", "input" or "response". Instruction/input parts
        become closed turns. A "response" part is not a turn of its own: its
        text is placed once, right after the final assistant header, as a
        prefill (the last one wins).

        Raises:
            ValueError: if ``INSTRUCTION_SET`` is not a known template.
        """
        selected_set = cls._INSTRUCTION_SETS.get(cls.INSTRUCTION_SET)
        if not selected_set:
            raise ValueError(f"Instruction set {cls.INSTRUCTION_SET} is not defined.")

        if isinstance(prompt_parts, tuple):
            prompt_parts = [prompt_parts]

        end_marker = selected_set.get("end", "")
        final_prompt = ""
        prepend_response = ""
        for text, part in prompt_parts:
            if part == "response":
                # Prefill only; emitting it here too would duplicate the text.
                prepend_response = text
            elif part in selected_set:
                final_prompt += f"{selected_set[part]}{text}{end_marker}"
            elif part == "input":
                # Template has no "input" marker: merge the text in without a header.
                final_prompt += f"{text}{end_marker}"

        final_prompt += f"{selected_set['response']}{prepend_response}"
        return final_prompt

    @classmethod
    def set_instruction_set(cls, instruction_set):
        """Select the prompt template; raise ValueError for an unknown name."""
        if instruction_set in cls._INSTRUCTION_SETS:
            cls.INSTRUCTION_SET = instruction_set
        else:
            raise ValueError(f"Unsupported instruction set: {instruction_set}")

            
class Node:
    """One candidate answer in the iterative-refinement search tree.

    A node stores the query, the answer it starts from, the feedback generated
    for that answer and a refined answer. Evaluation rewards are kept in a
    bounded window and summarised into ``Q_value`` by ``update_Q``.
    """
    # Counts root nodes only; child ids are derived from the parent's id.
    node_counter = 0

    def __init__(self, query, answer, feedback=None, refined_answer=None, parent=None):
        # The id is generated before this node is added to parent.children,
        # so the child number counts only the siblings that already exist.
        self.id = self.generate_id(parent)
        self.query = query
        self.answer = answer
        self.feedback = feedback
        self.refined_answer = refined_answer
        self.parent = parent
        self.children = set()
        self.visits = 0
        self.rewards = deque(maxlen=100)  # Store only the last 100 rewards
        # Running sum / sum of squares of the rewards currently in the window.
        self.reward_sum = 0
        self.reward_sum_squared = 0
        self.Q_value = 0
        self.previous_Q_value = 0
        self.max_children = 5  # create_child does not check this limit
        self.importance_weight = 1.0
        self.depth = 0 if parent is None else parent.depth + 1

    @classmethod
    def generate_id(cls, parent):
        """Return a hierarchical id for a new node.

        Roots get ``"<n>"``, numbered from 0. A child gets
        ``"<parent id>.<k>"``, where k is the parent's current child count + 1.
        """
        if parent is None:
            cls.node_counter += 1
            return str(cls.node_counter - 1)
        else:
            parent_id = parent.id
            child_number = len(parent.children) + 1
            return f"{parent_id}.{child_number}"

    def __str__(self):
        return f"Node_{self.id}"

    def increment_visits(self):
        self.visits += 1

    def update_Q(self):
        """Recompute ``Q_value`` and remember the previous one.

        With rewards, Q = 0.5 * (min + mean) of the reward window. Without
        rewards, Q is the best child Q (or 0 for a leaf). The result is clamped
        to [0, 100].
        """
        # Update previous Q-value
        self.previous_Q_value = self.Q_value

        # Formula from the paper based on the minimum and average of the rewards
        if self.rewards:
            min_reward = min(self.rewards)
            avg_reward = sum(self.rewards) / len(self.rewards)
            self.Q_value = 0.5 * (min_reward + avg_reward)

        else:
            # If there are no rewards, and no children, default to Q-value of 0
            if not self.children:
                self.Q_value = 0
            else:
                # Optionally include children's Q-values
                self.Q_value = max(child.Q_value for child in self.children)

        # Ensure Q-value stays within the range
        self.Q_value = max(0, min(100, self.Q_value))

    def add_reward(self, reward):
        """Append *reward* to the window.

        When the window is full the deque drops its oldest entry, so that
        entry is first subtracted from the running sums.
        """
        if len(self.rewards) == self.rewards.maxlen:
            old_reward = self.rewards[0]
            self.reward_sum -= old_reward
            self.reward_sum_squared -= old_reward ** 2

        self.rewards.append(reward)
        self.reward_sum += reward
        self.reward_sum_squared += reward ** 2

    def get_ancestors(self):
        """ Collect all ancestor nodes, nearest (parent) first. """
        ancestors = []
        current = self.parent
        while current:
            ancestors.append(current)
            current = current.parent
        return ancestors

    def analyze_historical_performance(self):
        """Summarise the issues most often raised in ancestors' feedback.

        Looks at up to 5 nearest ancestors and weights each ancestor's issues by
        1 / (distance rank). Issues whose normalised text is >80% similar are
        merged. Returns a "[Common Issues]" block listing the top 3, or None
        when no ancestor has issues.
        """
        ancestors = self.get_ancestors()[:5]  # Limit to 5 most recent ancestors
        issues = defaultdict(float)

        def normalize_text(text):
            return re.sub(r'[^\w\s]', '', text.lower())

        def fuzzy_match(a, b, threshold=0.8):
            return SequenceMatcher(None, normalize_text(a), normalize_text(b)).ratio() > threshold

        for i, ancestor in enumerate(ancestors):
            if not ancestor.feedback or 'issues' not in ancestor.feedback:
                continue
            raw_issues = ancestor.feedback['issues']
            # FeedbackModel allows a single string; wrap it so we iterate over
            # issues rather than over the characters of the string.
            if isinstance(raw_issues, str):
                raw_issues = [raw_issues]
            weight = 1 / (i + 1)  # More recent ancestors have higher weight
            for issue in raw_issues:
                issue_text = normalize_text(str(issue))
                for existing_issue in issues:
                    if fuzzy_match(issue_text, existing_issue):
                        issues[existing_issue] += weight
                        break
                else:
                    issues[issue_text] = weight

        if issues:
            common_issues = sorted(issues.items(), key=lambda x: x[1], reverse=True)[:3]
            # Format in a structured way
            formatted_issues = "\n".join(f"- {issue}" for issue, _ in common_issues)
            return f"[Common Issues]\n{formatted_issues}"
        return None

    def generate_feedback(self):
        """Ask the LLM (via PromptManager) to critique this node's answer."""
        self.feedback = PromptManager.generate_feedback(self.query, self.answer)

    def refine_answer(self):
        """Produce ``refined_answer`` from the parent's best answer and this node's feedback.

        ``refined_answer`` is left unchanged when PromptManager returns None.
        """
        historical_insights = self.analyze_historical_performance()
        historical_context = ""
        if historical_insights:
            historical_context = f"Note: Previous answers often struggled with:\n{historical_insights}."

        if self.parent:
            base_answer = self.parent.refined_answer if self.parent.refined_answer else self.parent.answer
        else:
            base_answer = self.answer if self.answer else "No previous answer available."

        refined_answer = PromptManager.refine_answer(self.query, base_answer, self.feedback, historical_context)
        if refined_answer is not None:
            self.refined_answer = refined_answer

    def self_evaluate(self, scoring_method, evaluator):
        """Evaluate the node using the provided evaluator instance."""
        return evaluator.evaluate_node(node=self, scoring_method=scoring_method)

    def create_child(self):
        """Create a child that starts from this node's refined (or original) answer."""
        new_node = Node(
            query=self.query,
            answer=self.refined_answer if self.refined_answer else self.answer,
            feedback=None,
            refined_answer=None,
            parent=self
        )
        self.children.add(new_node)
        return new_node

## Preprocess the prompt before sending it to the LLM
class PromptManager:
    """Build the feedback / refine / evaluation prompts and post-process replies.

    All state is class-level: ``topic`` selects the labels in ``terminology``
    and the markup that ``refine_answer`` expects the model to produce.
    """
    topic = "creative"  # Default topic
    terminology = {
        "math": {"task": "query", "response": "Answer"},
        "creative": {"task": "Task", "response": "Story"}
    }
    # Per topic: tags that must appear in a refined reply, and the tag whose
    # content is kept as the refined answer.
    _REFINE_MARKUP = {
        "math": (("think", "verify"), "answer"),
        "creative": (("think",), "story"),
    }

    @classmethod
    def set_topic(cls, new_topic):
        """Switch the active topic; raise ValueError if it has no terminology entry."""
        if new_topic in cls.terminology:
            cls.topic = new_topic
        else:
            raise ValueError(f"Unsupported topic: {new_topic}")

    @classmethod
    def get_topic(cls):
        """Return the active topic name."""
        return cls.topic

    @classmethod
    def get_terminology(cls):
        """Return the task/response labels for the active topic."""
        return cls.terminology[cls.topic]

    @classmethod
    def generate_feedback_prompt(cls, task, response):
        """Build the prompt asking the LLM to critique *response* as {"issues": [...]}."""
        terms = cls.get_terminology()
        feedback_context = "Review the response critically and identify areas for improvement."
        prompt_parts = [
            (f"{terms['task']}: {task}\n{terms['response']}: {response}\n", "instruction"),
            (feedback_context + "\nProvide feedback in the following JSON format: {\"issues\": [\"issue1\", \"issue2\", ...]}\n", "input")
        ]
        return LLMInterface.replace_placeholders(prompt_parts).lstrip()

    @staticmethod
    def format_json_with_tags(json_data):
        """Format JSON data with markup tags based on the keys.

        Each key becomes a "[Title Case]" header. List values become "- item"
        lines; other values are written as-is. Empty/None input gives "".
        """
        if not json_data:
            return ""

        formatted_text = ""
        for key, value in json_data.items():
            formatted_text += f"[{key.replace('_', ' ').title()}]\n"
            if isinstance(value, list):
                formatted_text += "\n".join([f"- {item}" for item in value]) + "\n"
            else:
                formatted_text += f"{value}\n"
        return formatted_text.strip()

    @classmethod
    def generate_refine_prompt(cls, task, base_response, feedback, historical_context=""):
        """Build the prompt asking the LLM to improve *base_response* using *feedback*."""
        terms = cls.get_terminology()
        topic_instructions = {
            "math": "Based on the previous response, solve the query. Start by reasoning step by step within <think></think>, then verify with <verify></verify>, finally give the final answer <answer></answer>.",
            "creative": "Based on the previous rejected draft, improve your work. Start by reasoning step by step within <think></think> then output the full story <story></story>."
        }

        instruction = topic_instructions.get(cls.topic, "No specific instructions available for the chosen topic.")

        # Format feedback with markup tags instead of using raw JSON
        formatted_feedback = cls.format_json_with_tags(feedback)

        # Historical context goes at the end of the context, just before the
        # instruction, to emphasize the changes that keep recurring.
        history = f"{historical_context}\n" if historical_context else ""
        prompt = [
            (f"{terms['task']}: {task}\n", "instruction"),
            (f"{terms['response']}:\n{base_response}\n[FEEDBACK]:\n{formatted_feedback}\n{history}{instruction}\n", "input")
        ]
        return LLMInterface.replace_placeholders(prompt).lstrip()

    @staticmethod
    def extract_structured_data(response, expected_key, match_all=False):
        """
        Return ``expected_key`` from a JSON reply, or None if it cannot be parsed.

        A dict (already-parsed, schema-guided reply) is read directly. Otherwise
        the outermost ``{...}`` span of the text is parsed: "issues" and "score"
        are validated with FeedbackModel / ScoreModel, and other keys use plain
        JSON. ``match_all`` is accepted for compatibility but currently unused.
        """
        if isinstance(response, dict):
            # If response is already a dict (from using model_schema), extract directly
            return response.get(expected_key)

        if not isinstance(response, str):
            response = str(response)

        # Try to find JSON content
        json_pattern = r'\{[\s\S]+\}'
        match = re.search(json_pattern, response)

        if not match:
            return None

        json_str = match.group()

        try:
            # Use different Pydantic models based on the expected key
            if expected_key == 'issues':
                parsed_data = FeedbackModel.model_validate_json(json_str)
                return parsed_data.issues
            elif expected_key == 'score':
                parsed_data = ScoreModel.model_validate_json(json_str)
                return parsed_data.score
            else:
                # Fall back to basic JSON parsing for other keys
                data = json.loads(json_str)
                return data.get(expected_key)
        except (ValueError, AttributeError, TypeError):
            # Invalid JSON / failed validation (both ValueError subclasses) or a
            # non-object JSON value without .get(): treat as "not found".
            return None

    @classmethod
    def generate_feedback(cls, task, response):
        """Ask the LLM for {"issues": [...]} feedback, retrying up to 5 times.

        Falls back to a generic issue list when no valid reply is obtained.
        """
        feedback_prompt = cls.generate_feedback_prompt(task, response)
        print("--------------\nFeedback prompt\n--------------\n", feedback_prompt)

        attempts = 0
        valid_feedback = False
        feedback_answer = {}

        while not valid_feedback and attempts < 5:
            raw_feedback = LLMInterface.ask_llm_remote(feedback_prompt, FeedbackModel)
            print(f"Feedback (Attempt {attempts + 1}/5):\n{raw_feedback}")

            # Schema-guided replies arrive as parsed JSON; anything else
            # (including None on API errors) triggers a retry.
            if isinstance(raw_feedback, dict) and 'issues' in raw_feedback:
                feedback_answer = raw_feedback
                valid_feedback = True
            else:
                print("Unexpected feedback format. Retrying...")

            attempts += 1

        if not valid_feedback:
            print("Failed to obtain valid structured feedback after multiple attempts. Using default feedback.")
            feedback_answer = {'issues': ['Unable to generate specific feedback.']}

        print("--------------\nFeedback answer\n--------------\n", feedback_answer)
        return feedback_answer

    @classmethod
    def _extract_refined_content(cls, response):
        """Return the payload tag's content (left-stripped), or None if required markup is missing."""
        markup = cls._REFINE_MARKUP.get(cls.topic)
        if markup is None or not isinstance(response, str):
            return None
        required_tags, payload_tag = markup
        for tag in required_tags:
            if not re.search(f"<{tag}>.*?</{tag}>", response, re.DOTALL):
                return None
        # First opening tag up to the first closing tag after it.
        payload = re.search(f"<{payload_tag}>(.*?)</{payload_tag}>", response, re.DOTALL)
        return payload.group(1).lstrip() if payload else None

    @classmethod
    def refine_answer(cls, task, base_response, feedback, historical_context="", max_retries=3):
        """Ask the LLM for an improved response and extract its final content.

        Makes up to ``1 + max_retries`` requests. Every reply is checked for the
        markup the topic requires (see ``_REFINE_MARKUP``). The first reply that
        has it is reduced to the payload tag's content and returned. If no reply
        has it, the last reply received is returned (left-stripped). If every
        request failed, None is returned.
        """
        refine_prompt = cls.generate_refine_prompt(task, base_response, feedback, historical_context)
        print("--------------\nRefine prompting\n--------------\n", refine_prompt)

        refined_response = None
        for attempt in range(max_retries + 1):
            if attempt > 0:
                print(f"Attempt n°{attempt} - Required markup not found. Resending reprompting for the refine answer")
            raw_response = LLMInterface.ask_llm_remote(refine_prompt)
            if raw_response is None:
                # ask_llm_remote returns None when the API call fails.
                print("No response received from the LLM.")
                continue
            refined_response = raw_response.lstrip()
            content = cls._extract_refined_content(refined_response)
            if content is not None:
                refined_response = content
                break

        print("--------------\nRefined response\n--------------\n", refined_response)
        return refined_response

    @classmethod
    def generate_evaluation_prompt(cls, task, refined_response):
        """Build the prompt asking the LLM to score *refined_response* as {"score": int}."""
        terms = cls.get_terminology()
        prompt = [
            (f"{terms['task']}:{task}\n{terms['response']}:{refined_response}\n", "instruction"),
            ("As an expert, analyze the output critically, then give it a score. The score must be a number between 0 and 100. JSON format: {\"score\": 0}\n", "input")
        ]
        return LLMInterface.replace_placeholders(prompt)
        
## Evaluate the answer through different metrics
class Evaluator:
    def __init__(self, topic="creative"):
        self.topic = topic
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self._gpt2_model = None
        self._gpt2_tokenizer = None
        self._tfidf = None
        self._nlp = None
        self._sentence_model = None

        # Set metrics configuration based on topic
        self.configure_metrics_for_topic()

    def configure_metrics_for_topic(self):
        """Configure metrics weights and flags based on the topic"""
        # Default configuration
        # Computation intensive : llm_autoeval > coherence > perplexity > diversity > entity_consistency > sentiment_consistency > readability
        self.metrics_config = {
            'perplexity': True,
            'readability': True,
            'coherence': True,
            'entity_consistency': True,
            'sentiment_consistency': True,
            'diversity': True,
            'llm_autoeval': True
        }
        
        # Default weights
        self.metric_weights = {
            'perplexity': 0.05,
            'readability': 0.10,
            'coherence': 0.25,
            'entity_consistency': 0.20,
            'sentiment_consistency': 0.05,
            'diversity': 0.15,
            'llm_autoeval': 0.20
        }
        
        # Adjust weights and enable/disable metrics based on topic
        if self.topic == "math":
            # For math, prioritize accuracy and clarity
            self.metrics_config.update({
                'perplexity': True,
                'readability': True,
                'coherence': True,
                'entity_consistency': True,
                'sentiment_consistency': False,  # Less relevant for math
                'diversity': False,             # Less relevant for math
                'llm_autoeval': True
            })
            
            self.metric_weights.update({
                'perplexity': 0.05,            
                'readability': 0.15,           # Higher for math (clarity)
                'coherence': 0.30,             # Higher for logical flow
                'entity_consistency': 0.20,    # Renamed to term_consistency for math
                'llm_autoeval': 0.30           # Higher weight on LLM evaluation
            })
            
        elif self.topic == "creative":
            # For creative, prioritize coherence, diversity, and readability
            self.metrics_config.update({
                'perplexity': True,
                'readability': True,
                'coherence': True,
                'entity_consistency': True,
                'sentiment_consistency': True,  # Important for creative
                'diversity': True,             # Important for creative
                'llm_autoeval': True
            })
            
            self.metric_weights.update({
                'perplexity': 0.05,
                'readability': 0.15,           # Higher for creative writing
                'coherence': 0.20,             # Still important but less than math
                'entity_consistency': 0.15,    # Slightly lower than default
                'sentiment_consistency': 0.10,  # Higher for creative writing
                'diversity': 0.20,             # Higher for creative writing
                'llm_autoeval': 0.15           # Moderate weight on LLM evaluation
            })

    @property
    def gpt2_model(self):
        if self._gpt2_model is None:
            print("Loading GPT-2 model for perplexity calculation...")
            self._gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
        return self._gpt2_model

    @property
    def gpt2_tokenizer(self):
        if self._gpt2_tokenizer is None:
            self._gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        return self._gpt2_tokenizer

    @property
    def tfidf(self):
        if self._tfidf is None:
            self._tfidf = TfidfVectorizer()
        return self._tfidf

    @property
    def nlp(self):
        if self._nlp is None:
            print("Loading spaCy model...")
            self._nlp = spacy.load("en_core_web_sm")
        return self._nlp

    @property
    def sentence_model(self):
        if self._sentence_model is None:
            print("Loading SentenceTransformer model for coherence...")
            self._sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
        return self._sentence_model
    
    def calculate_perplexity(self, text, max_length=1024, max_perplexity=100):
        """Score the fluency of *text* under GPT-2 on a 0-100 scale (higher is better).

        The text is split into sentences, tokenized with the GPT-2 tokenizer and
        packed into chunks of at most ``max_length`` tokens (default 1024, the
        GPT-2 context size). A sentence longer than ``max_length`` is cut into
        ``max_length``-sized pieces. Each chunk's mean loss is weighted by the
        number of tokens it predicts. The resulting perplexity is
        log-compressed so that ``max_perplexity`` or worse maps to 0.

        Returns 0 when there is nothing to score, i.e. no chunk has at least two
        tokens (e.g. empty text).
        """
        chunks = []
        current_chunk = []
        for sentence in sent_tokenize(text):
            sentence_tokens = self.gpt2_tokenizer.encode(sentence, add_special_tokens=False)
            # Split over-long sentences so that no chunk exceeds the context size.
            for start in range(0, len(sentence_tokens), max_length):
                piece = sentence_tokens[start:start + max_length]
                # Only flush a non-empty chunk, so no empty chunk is ever emitted.
                if current_chunk and len(current_chunk) + len(piece) > max_length:
                    chunks.append(current_chunk)
                    current_chunk = []
                current_chunk.extend(piece)

        if current_chunk:
            chunks.append(current_chunk)

        total_loss = 0.0
        total_predicted = 0
        for chunk in chunks:
            # With labels=inputs the model scores every token after the first,
            # so a chunk of n tokens contributes n - 1 predictions. A 1-token
            # chunk has none and would produce a NaN loss.
            n_predicted = len(chunk) - 1
            if n_predicted < 1:
                continue
            inputs = torch.tensor(chunk, dtype=torch.long).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.gpt2_model(inputs, labels=inputs)
            total_loss += outputs.loss.item() * n_predicted
            total_predicted += n_predicted

        if total_predicted == 0:
            return 0

        raw_perplexity = np.exp(total_loss / total_predicted)

        # Normalize perplexity: log transformation, then map to 0-100
        log_perplexity = np.log1p(raw_perplexity)
        normalized_perplexity = 100 * (1 - log_perplexity / np.log1p(max_perplexity))

        return max(0, min(100, normalized_perplexity))  # Ensure it stays within [0, 100]

    def calculate_diversity(self, text):        
        doc = self.nlp(text.lower())
        tokens = [token.text for token in doc if not token.is_punct and not token.is_space]
        
        if not tokens:
            return 0
        
        # Distinct unigrams
        unigrams = set(tokens)
        unigram_diversity = len(unigrams) / len(tokens) if tokens else 0
        
        # Distinct bigrams
        bigrams = list(self.ngrams(tokens, 2))
        distinct_bigrams = set(bigrams)
        bigram_diversity = len(distinct_bigrams) / len(bigrams) if bigrams else 0
        
        # Distinct trigrams
        trigrams = list(self.ngrams(tokens, 3))
        distinct_trigrams = set(trigrams)
        trigram_diversity = len(distinct_trigrams) / len(trigrams) if trigrams else 0
        
        # Type-Token Ratio (TTR)
        ttr = len(set(tokens)) / len(tokens) if tokens else 0
        
        # Vocabulary richness (using hapax legomena)
        vocab = Counter(tokens)
        hapax = len([word for word, freq in vocab.items() if freq == 1])
        hapax_ratio = hapax / len(tokens) if tokens else 0
        
        # Shannon entropy for unigrams
        token_counts = Counter(tokens)
        token_probs = [count / len(tokens) for count in token_counts.values()]
        unigram_entropy = -sum(p * math.log2(p) for p in token_probs)
        
        # Shannon entropy for bigrams
        if bigrams:
            bigram_counts = Counter(bigrams)
            bigram_probs = [count / len(bigrams) for count in bigram_counts.values()]
            bigram_entropy = -sum(p * math.log2(p) for p in bigram_probs)
        else:
            bigram_entropy = 0
        
        # Shannon entropy for trigrams
        if trigrams:
            trigram_counts = Counter(trigrams)
            trigram_probs = [count / len(trigrams) for count in trigram_counts.values()]
            trigram_entropy = -sum(p * math.log2(p) for p in trigram_probs)
        else:
            trigram_entropy = 0
        
        # Max possible entropy is log2(num_unique_tokens)
        max_unigram_entropy = math.log2(len(unigrams)) if unigrams else 0
        max_bigram_entropy = math.log2(len(distinct_bigrams)) if distinct_bigrams else 0
        max_trigram_entropy = math.log2(len(distinct_trigrams)) if distinct_trigrams else 0
        
        # Normalized entropy (between 0 and 1)
        norm_unigram_entropy = unigram_entropy / max_unigram_entropy if max_unigram_entropy else 0
        norm_bigram_entropy = bigram_entropy / max_bigram_entropy if max_bigram_entropy else 0
        norm_trigram_entropy = trigram_entropy / max_trigram_entropy if max_trigram_entropy else 0
        
        # Calculate the diversity score including both original metrics and entropy
        diversity_score = (unigram_diversity + bigram_diversity + trigram_diversity + 
                            ttr + hapax_ratio + 
                            norm_unigram_entropy + norm_bigram_entropy + norm_trigram_entropy) / 8
        
        return diversity_score

    @staticmethod
    def ngrams(tokens, n):
        iterables = tee(tokens, n)
        for i, sub_iterable in enumerate(iterables):
            for _ in range(i):
                next(sub_iterable, None)
        return zip(*iterables)
    
    def calculate_readability(self, text):
        """Return the Flesch reading-ease score of *text*, clamped to [0, 100].

        The raw Flesch score is unbounded (it can go negative or above 100 for
        unusual text). Clamping keeps it on the same 0-100 scale that
        evaluate_node uses for the other metrics.
        """
        score = flesch_reading_ease(text)
        # Upper bound first, then lower bound (same order as min() then max()).
        capped = score if score < 100 else 100
        return capped if capped > 0 else 0
        
    def calculate_coherence(self, text):
        """Estimate how topically coherent *text* is, on a 0-1 scale.

        For texts with at least three sentences, each sentence is embedded with
        ``self.sentence_model``. Every sentence is compared (cosine similarity)
        with every other sentence. The mean of those off-diagonal similarities
        is mapped from [-1, 1] onto [0, 1].

        Texts with fewer than three sentences fall back to
        ``calculate_coherence_original``, which uses TF-IDF similarity of
        consecutive sentences. Note that the fallback's scale is not
        re-centred the way this method's is.

        Returns:
            A built-in ``float`` in [0, 1]. Returns 0.5, the middle of the
            scale, if the similarity is not finite. The value is converted to a
            built-in float so callers can ``json.dumps`` it; embeddings are
            often float32, and NumPy float32 is not JSON serializable.
        """
        sentences = sent_tokenize(text)

        # Fall back to original method for very short texts
        if len(sentences) < 3:
            return self.calculate_coherence_original(text)

        # Use sentence embeddings instead of TF-IDF
        embeddings = self.sentence_model.encode(sentences)

        # Pairwise cosine similarities, promoted to float64 for stable averaging
        similarities = np.asarray(cosine_similarity(embeddings), dtype=np.float64)

        # Mean similarity of each sentence to all *other* sentences:
        # row sum minus the self-similarity on the diagonal, over n - 1 peers.
        n = len(sentences)
        per_sentence = (similarities.sum(axis=1) - np.diag(similarities)) / (n - 1)
        overall_coherence = float(np.mean(per_sentence))

        # Check for NaN or infinite values
        if not math.isfinite(overall_coherence):
            print("Warning: Invalid overall coherence. Returning default score.")
            # Middle of the 0-1 scale; evaluate_node multiplies this by 100.
            return 0.5

        # Ensure overall_coherence is within [-1, 1] range
        overall_coherence = max(-1.0, min(1.0, overall_coherence))

        # Normalize to 0-1 range
        return (overall_coherence + 1) / 2

    def calculate_coherence_original(self, text):
        """TF-IDF coherence: mean cosine similarity of consecutive sentence pairs.

        Kept as the fallback for short texts and for comparison with
        ``calculate_coherence``. Sentence i is compared only with sentence
        i + 1. A pair in which either sentence has an all-zero TF-IDF vector
        scores 0.

        Returns:
            A built-in ``float``, or 0 when there are fewer than two sentences
            or the vectorizer cannot extract any vocabulary.
        """
        sentences = sent_tokenize(text)

        # One sentence has no consecutive pair, and fitting the vectorizer on
        # an empty list of sentences raises, so bail out early.
        if len(sentences) < 2:
            return 0

        try:
            sentence_vectors = self.tfidf.fit_transform(sentences)
        except ValueError as err:
            # NOTE(review): assumes self.tfidf is a scikit-learn vectorizer,
            # which raises ValueError ("empty vocabulary") when no sentence
            # contains a token it keeps, e.g. punctuation-only text.
            print(f"Warning: TF-IDF coherence unavailable ({err}). Returning 0.")
            return 0

        dense_vectors = sentence_vectors.toarray()
        coherence_scores = []
        for vec1, vec2 in zip(dense_vectors[:-1], dense_vectors[1:]):
            # Cosine distance is undefined for zero vectors; treat as no similarity.
            if np.all(vec1 == 0) or np.all(vec2 == 0):
                similarity = 0
            else:
                similarity = 1 - cosine(vec1, vec2)
            coherence_scores.append(similarity)

        return float(np.mean(coherence_scores))

    def calculate_entity_consistency(self, text):
        """Measure how often named entities in *text* are repeated.

        Entities found by ``self.nlp`` are lower-cased. The score is
        ``1 - unique_entities / total_mentions``: it grows toward 1 as the same
        entities are mentioned again, and is 0 when every mention is distinct.

        Text with no entities at all returns 1.
        """
        mentions = [entity.text.lower() for entity in self.nlp(text).ents]
        if not mentions:
            return 1
        return 1 - len(set(mentions)) / len(mentions)

    def calculate_sentiment_consistency(self, text):
        """Score how uniform the sentiment of *text* is across its sentences.

        Each sentence's TextBlob polarity (documented by TextBlob as lying in
        [-1, 1]) is computed. The population standard deviation of those
        polarities is subtracted from 1, so identical polarity everywhere
        gives 1.0.

        Returns:
            A built-in ``float``. Returns 1.0 for text with no sentences,
            because ``np.std([])`` is NaN and would otherwise propagate into
            the node's final score.
        """
        sentences = sent_tokenize(text)
        if not sentences:
            return 1.0
        polarities = [TextBlob(sentence).sentiment.polarity for sentence in sentences]
        return float(1 - np.std(polarities))

    def evaluate_node(self, node, base_evaluations=3, min_evaluations=3, attempts_per_evaluation=3, scoring_method='lowest'):
        """Score ``node.refined_answer`` with the enabled metrics and record the reward.

        Each metric switched on in ``self.metrics_config`` produces a score on
        a 0-100 scale. If ``llm_autoeval`` is enabled, the LLM is asked
        ``max(min_evaluations, base_evaluations - node.depth)`` times, and the
        valid answers are reduced with ``scoring_method``.

        The final score is the weighted average (``self.metric_weights``) of
        the scores that were actually obtained. An LLM evaluation that yields
        no valid score therefore no longer drags the result toward zero.
        Final scores above 95 are reduced by 5.

        Parameters:
        - node: node to evaluate; ``depth``, ``refined_answer`` and ``id`` are
          read and ``add_reward`` is called with the final score
        - base_evaluations, min_evaluations: control the number of LLM evaluations
        - attempts_per_evaluation: retries passed to ``get_llm_score``
        - scoring_method: 'highest', 'lowest', 'average', 'weighted_average' or 'median'

        Returns:
        A JSON string: ``{'score', 'metrics', 'llm_score'}`` on success,
        otherwise ``{'score': 0}``.

        Raises:
        - ValueError: unknown ``scoring_method``. This is only checked when at
          least one LLM score was obtained.
        - KeyError: an enabled metric has no entry in ``self.metric_weights``.
        """
        num_evaluations = max(min_evaluations, base_evaluations - node.depth)
        answer = node.refined_answer

        # Only consider weights of active metrics
        active_weights = {metric: weight
                          for metric, weight in self.metric_weights.items()
                          if self.metrics_config.get(metric, False)}

        # Objective metrics, each mapped to 0-100. float() turns NumPy scalars
        # (e.g. float32 from embeddings) into built-ins so json.dumps works.
        metric_scores = {}

        if self.metrics_config.get('perplexity', False):
            perplexity = self.calculate_perplexity(answer)
            metric_scores['perplexity'] = float(100 * (1 / (1 + perplexity)))

        if self.metrics_config.get('readability', False):
            metric_scores['readability'] = float(self.calculate_readability(answer))

        if self.metrics_config.get('coherence', False):
            metric_scores['coherence'] = float(self.calculate_coherence(answer) * 100)

        if self.metrics_config.get('entity_consistency', False):
            metric_scores['entity_consistency'] = float(self.calculate_entity_consistency(answer) * 100)

        if self.metrics_config.get('sentiment_consistency', False):
            metric_scores['sentiment_consistency'] = float(self.calculate_sentiment_consistency(answer) * 100)

        if self.metrics_config.get('diversity', False):
            metric_scores['diversity'] = float(self.calculate_diversity(answer) * 100)

        # LLM-based evaluations
        llm_final_score = None
        if self.metrics_config.get('llm_autoeval', False):
            valid_scores = []
            weights = []
            for evaluation in range(num_evaluations):
                llm_score = self.get_llm_score(node, attempts_per_evaluation)
                if llm_score is not None:
                    valid_scores.append(llm_score)
                    # Earlier evaluations weigh more (used by 'weighted_average')
                    weights.append(num_evaluations - evaluation)
                print(f"Evaluation {evaluation + 1}: LLM Score = {llm_score}")

            if valid_scores:
                if scoring_method == 'highest':
                    llm_final_score = max(valid_scores)
                elif scoring_method == 'lowest':
                    llm_final_score = min(valid_scores)
                elif scoring_method == 'average':
                    llm_final_score = sum(valid_scores) / len(valid_scores)
                elif scoring_method == 'weighted_average':
                    llm_final_score = sum(score * weight for score, weight in zip(valid_scores, weights)) / sum(weights)
                elif scoring_method == 'median':
                    llm_final_score = statistics.median(valid_scores)
                else:
                    raise ValueError(f"Unknown scoring method: {scoring_method}")

        # Pair each obtained score with its weight
        weighted_scores = [(score, active_weights[metric]) for metric, score in metric_scores.items()]
        if llm_final_score is not None:
            weighted_scores.append((llm_final_score, active_weights['llm_autoeval']))

        # Normalize by the weights of the scores actually obtained, not by the
        # weight of every active metric (a failed LLM eval must not dilute).
        used_weight = sum(weight for _, weight in weighted_scores)

        if weighted_scores and used_weight > 0:
            final_score = sum(score * weight for score, weight in weighted_scores) / used_weight

            if final_score > 95:
                final_score -= 5  # Curb excessive scores

            node.add_reward(final_score)
            score_json = json.dumps({'score': final_score, 'metrics': metric_scores, 'llm_score': llm_final_score})
            print(f"Final score for Node_{node.id}: {score_json}")
            return score_json

        default_score = 0  # Maintaining the original default score
        node.add_reward(default_score)
        score_json = json.dumps({'score': default_score})
        print(f"Failed to obtain any valid scores for Node_{node.id}. Default score: {score_json}")
        return score_json

    def get_llm_score(self, node, attempts_per_evaluation):
        """Ask the remote LLM to grade ``node.refined_answer`` and return the first valid score.

        Up to ``attempts_per_evaluation`` requests are made. A response is valid
        when it is a dict whose ``'score'`` is a real number (bool excluded) in
        [0, 100].

        Returns:
            The score, or ``None`` if no attempt produced a valid one.
        """
        for _ in range(attempts_per_evaluation):
            evaluation_prompt = PromptManager.generate_evaluation_prompt(node.query, node.refined_answer)
            response = LLMInterface.ask_llm_remote(evaluation_prompt, ScoreModel)
            print("Auto-eval:", response)

            if not isinstance(response, dict):
                continue

            score = response.get('score')
            # Reject non-numeric scores (e.g. a string or None when the server
            # ignores the schema) and retry, instead of raising TypeError on
            # the range comparison. bool is excluded since it subclasses int.
            if isinstance(score, (int, float)) and not isinstance(score, bool) and 0 <= score <= 100:
                return score
        return None
    
class ImportanceSampler:
    """Assign each child node an importance weight proportional to its Q-value."""

    def __init__(self):
        # Stateless: all weights are written onto the tree nodes themselves.
        pass

    def update_importance_weights(self, node):
        """
        Walk the subtree rooted at *node* and set ``importance_weight`` on every child.

        A child's weight is its share of its siblings' total Q-value. When that
        total is not positive, the siblings share the weight equally. Each
        child's own subtree is updated right after the child itself (depth-first).
        """
        kids = node.children
        if not kids:
            return

        q_total = sum(child.Q_value for child in kids)
        uniform_share = 1.0 / len(kids)

        for child in kids:
            child.importance_weight = child.Q_value / q_total if q_total > 0 else uniform_share
            self.update_importance_weights(child)

    
# Use matplotlib to display the tree graph and each node's Q-value
class MCTSVisualizer:
    """Two-panel live view of an MCTS run.

    The left panel shows the Q-value trajectory of every node across the
    iterations passed to ``update``. The right panel shows the current tree
    structure, with each node labeled by its id and Q-value.
    """

    def __init__(self):
        self.fig, (self.ax1, self.ax2) = plt.subplots(1, 2, figsize=(20, 10))
        self.all_nodes = []
        self.iterations = []
        # node id -> list of (iteration, Q_value) samples, one per update() call.
        # Needed because both axes are cleared and redrawn on every update.
        self.q_history = {}

    def update(self, iteration, tree):
        """Record the tree's Q-values at *iteration*, redraw both panels and display the figure."""
        self.iterations.append(iteration)
        self.all_nodes = self._get_all_nodes(tree.root)
        for node in self.all_nodes:
            self.q_history.setdefault(node.id, []).append((iteration, node.Q_value))

        self.ax1.clear()
        self.ax2.clear()

        self._plot_q_values()
        self._plot_tree(tree)

        # Lay out this visualizer's figure, not whatever figure pyplot considers current.
        self.fig.tight_layout()

        display(self.fig)

    def _get_all_nodes(self, node):
        """Return *node* and all of its descendants in depth-first order."""
        nodes = [node]
        for child in node.children:
            nodes.extend(self._get_all_nodes(child))
        return nodes

    def _plot_q_values(self):
        """Plot each node's recorded Q-value history, labeled at its latest point."""
        for node_id, samples in self.q_history.items():
            xs = [iteration for iteration, _ in samples]
            ys = [q_value for _, q_value in samples]
            self.ax1.plot(xs, ys, marker='o', alpha=0.5)
            self.ax1.annotate(f"Node {node_id}", (xs[-1], ys[-1]),
                              xytext=(5, 5), textcoords='offset points', fontsize=8)
        self.ax1.set_title("Q-values of all nodes over iterations")
        self.ax1.set_xlabel("Iteration")
        self.ax1.set_ylabel("Q-value")

    def _plot_tree(self, tree):
        """Draw the tree with a spring layout; labels show node id and Q-value."""
        G = nx.Graph()
        self._build_graph(tree.root, G)
        pos = nx.spring_layout(G)
        nx.draw(G, pos, ax=self.ax2, with_labels=False, node_size=30)
        self.ax2.set_title("MCTS Tree Structure")

        # Add node labels
        labels = {node: f"{node.id}\n{node.Q_value:.2f}" for node in G.nodes()}
        nx.draw_networkx_labels(G, pos, labels, font_size=8, ax=self.ax2)

    def _build_graph(self, node, G):
        """Add *node* and its subtree to graph *G*, with parent-child edges."""
        G.add_node(node)
        for child in node.children:
            G.add_edge(node, child)
            self._build_graph(child, G)

    def save_figure(self, filename="mcts_visualization.png"):
        """Save the current figure to *filename* at 300 dpi."""
        self.fig.savefig(filename, bbox_inches='tight', dpi=300)
        print(f"Figure saved as {filename}")
            
class MCTSTree:
    """Monte Carlo search tree over candidate answers for a single query."""

    def __init__(self, query, min_depth1_nodes=3, importance_sampling=False):
        self.root = self.initialize_root(query, use_dummy=False)
        self.min_depth1_nodes = min_depth1_nodes
        self.use_importance_sampling = importance_sampling
        self.importance_sampler = ImportanceSampler() if importance_sampling else None

    def initialize_root(self, query, use_dummy=False):
        """ Initialize the tree with either a dummy or a model-generated answer at the root """
        answer = ""
        if use_dummy:
            answer = "I don't know how to solve this query."
        else:
            query_f = (f"{query}", "instruction")
            query_formatted = LLMInterface.replace_placeholders(query_f)
            answer = LLMInterface.ask_llm_remote(query_formatted)
        print("--------------\nRoot answer\n--------------\n", answer)
        root_node = Node(query, answer)
        return root_node

    def select_node(self, current_node):
        """Select the most promising node to explore next.

        The root is returned until it has ``min_depth1_nodes`` children.
        Otherwise the tree is descended by UCB. At a node that can still take
        children, the node itself competes with its children, and is returned
        for expansion if it wins.
        """
        if current_node == self.root and len(current_node.children) < self.min_depth1_nodes:
            print(f"Expanding root node: Node_{current_node.id} (depth: {current_node.depth})")
            print(f"Q-value = {current_node.Q_value:.4f}, Previous Q-value = {current_node.previous_Q_value:.4f}, Visits = {current_node.visits}")
            print(f"Number of children: {len(current_node.children)}")
            print(f"Max children allowed: {current_node.max_children}")
            return current_node  # Always expand root until min_depth1_nodes is reached

        while current_node.children:
            if len(current_node.children) < current_node.max_children:
                unexpanded_child = self.select_candidate_nodes([current_node] + list(current_node.children))
                if unexpanded_child == current_node:
                    return current_node  # Expand this node
                else:
                    current_node = unexpanded_child  # Move to the selected child
            else:
                current_node = self.select_candidate_nodes(list(current_node.children))
        return current_node

    def select_candidate_nodes(self, nodes, exploration_weight=1.5):
        """Return the node in *nodes* with the highest UCB value (None if empty).

        UCB = Q + c * sqrt(ln(parent_visits + 1) / visits). Unvisited nodes get
        +inf so they are tried first. With importance sampling, the UCB is
        multiplied by the node's ``importance_weight``.
        """
        if not nodes:
            return None

        total_visits = sum(node.visits for node in nodes)
        ucb_values = []
        for node in nodes:
            if node.visits == 0:
                ucb = float('inf')  # Ensure unvisited nodes are selected first
            else:
                exploitation = node.Q_value
                parent_visits = node.parent.visits if node.parent else total_visits
                exploration = exploration_weight * math.sqrt(math.log(parent_visits + 1) / (node.visits))
                if self.use_importance_sampling:
                    importance_weight = node.importance_weight
                    ucb = (exploitation + exploration) * importance_weight
                else:
                    ucb = exploitation + exploration
            ucb_values.append((node, ucb))
            print(f"Node_{node.id} (depth: {node.depth}): UCB value = {ucb:.4f} (Q-value = {node.Q_value:.4f}, Previous Q-value = {node.previous_Q_value:.4f}, Visits = {node.visits})")

        selected_node, selected_ucb = max(ucb_values, key=lambda x: x[1])
        print(f"Selected Node_{selected_node.id} with UCB value {selected_ucb:.4f}")
        print(f"Number of children of the selected node: {len(selected_node.children)}")
        print(f"Max children allowed of the selected node: {selected_node.max_children}")
        return selected_node

    def expand_node(self, node):
        """Expand the selected node by creating a new child."""
        print("Creating new child node")
        return node.create_child()

    def simulate(self, node, evaluator):
        """Simulate the node's performance (feedback, refinement, self-evaluation)."""
        node.generate_feedback()
        node.refine_answer()
        node.self_evaluate('lowest', evaluator)

    def termination_check(self, node, current_iteration, max_iterations=8, threshold=80):
        """Return ``(stop, reason)``: stop at the iteration budget or once node.Q_value >= threshold."""
        if current_iteration >= max_iterations:
            return True, "Maximum iterations reached"
        elif node.Q_value >= threshold:  # Threshold for high-quality solution
            return True, "[Early Stopping] High quality solution found"
        return False, ""

    def backpropagate(self, node):
        """Propagate a new evaluation from *node* up to the root.

        Every node on the path has its visit count incremented. Each node's Q
        is recomputed with ``update_Q`` and then blended with its best child:
        Q' = 0.5 * (Q + max child Q).

        Once a node's Q changes by less than 1e-3, the ancestors' Q-values are
        left alone, since their best-child input did not change. Their visit
        counts are still incremented, so the parent/child visit counts used by
        UCB stay consistent.
        """
        while node is not None:
            node.increment_visits()
            old_Q = node.Q_value

            # Initially update Q-value based on its own rewards (handled in update_Q)
            node.update_Q()

            # Now adjust Q-value based on the formula Q'(a)
            if node.children:
                max_child_Q = max(child.Q_value for child in node.children)
                node.Q_value = 0.5 * (node.Q_value + max_child_Q)

            converged = abs(old_Q - node.Q_value) < 1e-3
            node = node.parent

            if converged:
                # Q barely changed: stop recomputing Q, but still count the visit on every ancestor.
                while node is not None:
                    node.increment_visits()
                    node = node.parent
                break

    def iterate_tree(self, node):
        """Iterate through all nodes in the tree (pre-order)."""
        yield node
        for child in node.children:
            yield from self.iterate_tree(child)

    def get_best_path(self):
        """Get the path to the node with the highest Q-value in the entire tree."""
        best_node = self.find_best_node(self.root)
        path = []
        current = best_node
        while current:
            path.append(current)
            current = current.parent
        return list(reversed(path))  # Reverse to get path from root to best node

    def find_best_node(self, node):
        """Recursively find the node with the highest Q-value in the subtree (ties keep the shallower node)."""
        best_node = node
        for child in node.children:
            child_best = self.find_best_node(child)
            if child_best.Q_value > best_node.Q_value:
                best_node = child_best
        return best_node
    

# Save the MCTS state at any iteration and restart from that point
class MCTSStateManager:
    """Save an MCTS tree to JSON and restore it so a run can resume from a checkpoint."""

    @staticmethod
    def save_state(tree, iteration, filename):
        """Write ``tree.root`` (recursively) and the *iteration* count to *filename* as JSON."""
        state = {
            'iteration': iteration,
            'tree': MCTSStateManager._serialize_tree(tree.root)
        }
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(state, f, indent=2)
        print(f"State saved to {filename}")

    @staticmethod
    def load_state(filename):
        """Load a checkpoint written by ``save_state``.

        Returns:
            ``(tree, iteration)``, or ``(None, 0)`` if *filename* does not
            exist. The returned tree carries the default ``MCTSTree`` settings
            (``min_depth1_nodes=3``, no importance sampling); callers should
            override them if their run uses different options.
        """
        if not os.path.exists(filename):
            print(f"No saved state found at {filename}")
            return None, 0

        with open(filename, 'r', encoding='utf-8') as f:
            state = json.load(f)

        # Bypass MCTSTree.__init__: it asks the LLM for a fresh root answer,
        # which would be thrown away immediately in favour of the saved root.
        # The attribute values below mirror MCTSTree.__init__'s defaults.
        tree = MCTSTree.__new__(MCTSTree)
        tree.min_depth1_nodes = 3
        tree.use_importance_sampling = False
        tree.importance_sampler = None
        tree.root = MCTSStateManager._deserialize_tree(state['tree'], None)
        return tree, state['iteration']

    @staticmethod
    def _to_builtin(value):
        """Return NumPy scalars as built-in Python numbers; other values unchanged.

        The json module cannot encode NumPy scalars such as numpy.float32,
        which expose ``.item()`` for the conversion.
        """
        to_item = getattr(value, 'item', None)
        return to_item() if callable(to_item) else value

    @staticmethod
    def _serialize_tree(node):
        """Convert *node* and its subtree into nested, JSON-encodable dicts."""
        to_builtin = MCTSStateManager._to_builtin
        return {
            'id': to_builtin(node.id),
            'query': node.query,
            'answer': node.answer,
            'feedback': node.feedback,
            'refined_answer': node.refined_answer,
            'visits': to_builtin(node.visits),
            'rewards': [to_builtin(reward) for reward in node.rewards],
            'Q_value': to_builtin(node.Q_value),
            'previous_Q_value': to_builtin(node.previous_Q_value),
            'max_children': to_builtin(node.max_children),
            'importance_weight': to_builtin(node.importance_weight),
            'depth': to_builtin(node.depth),
            'children': [MCTSStateManager._serialize_tree(child) for child in node.children],
        }

    @staticmethod
    def _deserialize_tree(data, parent):
        """Rebuild a Node subtree from the dict produced by ``_serialize_tree``."""
        node = Node(
            query=data['query'],
            answer=data['answer'],
            feedback=data['feedback'],
            refined_answer=data['refined_answer'],
            parent=parent
        )
        node.id = data['id']
        node.visits = data['visits']
        node.rewards = deque(data['rewards'], maxlen=100)
        node.Q_value = data['Q_value']
        node.previous_Q_value = data['previous_Q_value']
        node.max_children = data['max_children']
        node.importance_weight = data['importance_weight']
        node.depth = data['depth']

        for child_data in data['children']:
            child = MCTSStateManager._deserialize_tree(child_data, node)
            node.children.add(child)

        return node
    
    
def mcts_sr_algorithm(query, topic, min_depth1_nodes=3, iterations=8, qvalue_threshold=80, importance_sampling=False, save_interval=4, load_file=None):
    """Run MCTS self-refine for *query* and report the best answer found.

    Parameters:
    - query: the instruction/question to answer
    - topic: prompt topic passed to PromptManager and Evaluator (e.g. "math", "creative")
    - min_depth1_nodes: the root is expanded until it has at least this many children
    - iterations: total iteration budget, counting iterations already completed
      in a loaded checkpoint
    - qvalue_threshold: stop early once the selected node's Q-value reaches this value
    - importance_sampling: scale UCB values by ImportanceSampler weights
    - save_interval: write a checkpoint every N iterations (0 disables)
    - load_file: optional checkpoint from MCTSStateManager.save_state to resume from

    Side effects: prints progress, displays and saves the visualization, and
    writes JSON checkpoints including "mcts_state_final.json". Returns None.
    """
    # Set the topic for PromptManager
    PromptManager.set_topic(topic)

    # Single instance of Evaluator will be reused for all node evaluations.
    evaluator = Evaluator(topic=topic)

    # Load state if a load file is provided
    tree, start_iteration = None, 0
    if load_file:
        tree, start_iteration = MCTSStateManager.load_state(load_file)
        if tree is None:
            print("Tree not found in the file. Starting over at Iteration #1")
        else:
            print(f"Loading the MCTS from previous state from iteration {start_iteration}")
            # A loaded tree has default settings; apply this run's options so
            # that importance sampling (if requested) has a sampler to call.
            tree.min_depth1_nodes = min_depth1_nodes
            tree.use_importance_sampling = importance_sampling
            tree.importance_sampler = ImportanceSampler() if importance_sampling else None

    if tree is None:
        tree = MCTSTree(query, min_depth1_nodes=min_depth1_nodes, importance_sampling=importance_sampling)
        start_iteration = 0
        # Simulate and backpropagate root node for new tree
        tree.simulate(tree.root, evaluator)
        tree.backpropagate(tree.root)

    visualizer = MCTSVisualizer()
    termination_reason = None
    # Number of iterations actually finished (differs from `iterations` after early stopping)
    completed_iterations = start_iteration

    for i in range(start_iteration, iterations):
        completed_iterations = i + 1
        print(f"\n---- Iteration {i + 1} ----")

        # Selection
        selected_node = tree.select_node(tree.root)
        selected_node.increment_visits()

        # Expansion
        if len(selected_node.children) < selected_node.max_children:
            new_node = tree.expand_node(selected_node)

            # Simulation
            tree.simulate(new_node, evaluator)

            # Backpropagation
            tree.backpropagate(new_node)
        else:
            print("Node fully expanded, backtracking...")

        # Update importance weights if using importance sampling
        if importance_sampling:
            tree.importance_sampler.update_importance_weights(tree.root)

        # Check for termination conditions
        terminated, reason = tree.termination_check(selected_node, i + 1, max_iterations=iterations, threshold=qvalue_threshold)
        if terminated:
            termination_reason = reason
            print(f"Termination condition met: {reason}")
            break

        # Update visualization
        visualizer.update(i + 1, tree)

        # Save state at specified intervals if save_interval is not set to 0
        if save_interval > 0 and (i + 1) % save_interval == 0:
            MCTSStateManager.save_state(tree, i + 1, f"mcts_state_iteration_{i+1}.json")

    # If loop completes without early termination
    if termination_reason is None:
        termination_reason = "Maximum iterations reached"

    # Final best path
    final_best_path = tree.get_best_path()
    best_node = final_best_path[-1]

    print("\n---- Final Results ----")
    print(f"Termination reason: {termination_reason}")
    print(f"Best node: Node_{best_node.id}")
    print(f"Best Q value: {best_node.Q_value}")
    print(f"Best answer: {best_node.refined_answer or best_node.answer}")

    visualizer.update(completed_iterations, tree)
    visualizer.save_figure()

    # Save final state as JSON, recording the iterations actually completed
    # so resuming from this file after an early stop continues correctly.
    MCTSStateManager.save_state(tree, completed_iterations, "mcts_state_final.json")

    return

# Example run. Available topics: "math" or "creative"
query = "Write a short story about a cat finding an owner."
print("Starting MCTS algorithm for: ", "'" + query + "'")
mcts_sr_algorithm(
    query,
    topic="creative",
    min_depth1_nodes=3,
    iterations=5,
    qvalue_threshold=80,
    importance_sampling=False,
    save_interval=3,
    load_file=None,
)