In [44]:
'''
TODO: This is to check what each model outputs as logprobs when given a question - and if the behaviour is different for:
1. Quants vs float16 variants
2. Between model families (Llama vs Qwen vs Gemma)
3. Between model sizes (14B vs 8B vs 3B vs 1B)
'''
import os
from datasets import load_dataset
import requests
import time
import json
import traceback
import numpy as np

TRUE_SYNONYMS = [
    "true",
    "correct",
    "truth",
    "yes",
    "right",
    "verdade",
]

FALSE_SYNONYMS = [
    "false",
    "incorrect",
    "wrong",
    "fake",
    "no",
    "not",
    "none",
]

In [39]:

def send_request(payload, url=None, max_retries=3, retry_delay=1):
    if url is None:
        raise ValueError("URL is not set")
    
    attempt = 0
    while attempt < max_retries:
        try:
            headers = {"Content-Type": "application/json"}
            response = requests.post(url, headers=headers, data=json.dumps(payload))
            response.raise_for_status()
            if attempt > 0:  # Only print success if we had to retry
                print("\033[92mRequest succeeded after retry\033[0m")  # Green text
            return response.json()
        except requests.exceptions.ConnectionError as e:
            attempt += 1
            if attempt < max_retries:
                print(f"\033[91mConnection error on attempt {attempt}/{max_retries}. Retrying in {retry_delay}s...\033[0m")  # Red text
                time.sleep(retry_delay)
            else:
                print(f"\033[91mFinal connection error after {max_retries} attempts: {e}\033[0m")  # Red text
                print("Connection error occurred. Continuing with available results.")
                exit(1)
        except Exception as e:
            attempt += 1
            if attempt < max_retries:
                print(f"\033[91mRequest error on attempt {attempt}/{max_retries}. Retrying in {retry_delay}s...\033[0m")  # Red text
                time.sleep(retry_delay)
            else:
                print(f"\033[91mFinal request error after {max_retries} attempts: {e}\033[0m")  # Red text
                print(payload)
                return None


def get_initial_response(task):
    try:
        url = task["url"]
        temp = task["temp"]
        max_tokens = task["max_tokens"]
        question_str = task["question_str"]

        payload = {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": question_str},
            ],
            "temperature": temp,
            "max_tokens": max_tokens,
        }

        response = send_request(payload, url)
        if not response:
            return None

        response_text = response["choices"][0]["message"]["content"]
        task["response_text"] = response_text
        return task
    except Exception as e:
        print(f"Error getting initial response for task: {e}")
        traceback.print_exc()
        return None

def get_confidence_score(task):
    """Ask the model to self-evaluate and extract token-level log-probs."""
    try:
        url = task["url"]
        temp = task["temp"]
        question_str = task["question_str"]
        response_text = task["response_text"]
        exp_type = task["exp_type"]
        model_name = task["model_name"]

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": question_str},
            {"role": "assistant", "content": response_text},
        ]

        if exp_type == "cot_exp":
            cot_text = task["cot_text"]
            messages.extend([
                {
                    "role": "user",
                    "content": (
                        "Before answering whether your above answer is correct, "
                        "please provide a detailed chain-of-thought explanation of "
                        "your reasoning. Explain step-by-step how you arrived at "
                        "your answer and why you think it is correct or might be "
                        "incorrect."
                    ),
                },
                {"role": "assistant", "content": cot_text},
            ])

        messages.append({
            "role": "user",
            "content": "Is the above answer correct? Answer only with the single word 'true' or 'false'.",
        })

        # Format request according to vLLM OpenAI-compatible API
        payload = {
            # "model": model_name,
            "messages": messages,
            "temperature": temp,
            "max_tokens": 1,
            "logprobs": True,
            "top_logprobs": 20
        }

        response = send_request(payload, url)
        if not response:
            return None

        # Extract logprobs from the response
        if "choices" in response and len(response["choices"]) > 0:
            choice = response["choices"][0]
            if "logprobs" in choice and "content" in choice["logprobs"]:
                # Get the logprobs for the tokens
                logprobs = []
                for token_info in choice["logprobs"]["content"]:
                    if "top_logprobs" in token_info:
                        for top_prob in token_info["top_logprobs"]:
                            token = top_prob["token"].lower()
                            if any(syn in token for syn in TRUE_SYNONYMS):
                                print(f"\033[92m{token}\033[0m (true)")
                                logprobs.append({
                                    "token": "true",
                                    "logprob": top_prob["logprob"]
                                })
                            elif any(syn in token for syn in FALSE_SYNONYMS):
                                print(f"\033[94m{token}\033[0m (false)")
                                logprobs.append({
                                    "token": "false", 
                                    "logprob": top_prob["logprob"]
                                })
                            else:
                                print(f"\033[91mShit that did not fall in the true or false category: {token}\033[0m")
                
                task["logprobs"] = logprobs
                return task

        print(f"Unexpected response structure: {response}")
        return None

    except Exception as e:
        print(f"Error getting confidence score: {e}")
        traceback.print_exc()
        return None

In [5]:
dataset = load_dataset("openai/gsm8k", "main")

README.md:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

In [47]:
from utils import format_gsm8k_question

gsm8k_question = format_gsm8k_question(dataset['test']['question'][327])

task = {
    "idx": "test",
    "dataset_type": "gsm8k",
    # "dataset": dataset,
    "dataset_split": "test",
    "exp_type": "zs_exp",
    "url": "http://localhost:8000/v1/chat/completions",
    "temp": 0.0,
    "max_tokens": 1024,
    "model_name": "llama 3.2 3b",
    "question_str": gsm8k_question,
}

task = get_initial_response(task)
print(task)

{'idx': 'test', 'dataset_type': 'gsm8k', 'dataset_split': 'test', 'exp_type': 'zs_exp', 'url': 'http://localhost:8000/v1/chat/completions', 'temp': 0.0, 'max_tokens': 1024, 'model_name': 'llama 3.2 3b', 'question_str': 'Given the following problem, reason and give a final answer to the problem.\nProblem: Oscar has 24 lollipops and eats 2 on his way to school.  He passes 14 out to his friends.  He buys twice as many lollipops on his way home as he gave to his friends.  He eats 3 more that night and 2 more in the morning.  How many lollipops does Oscar have?\nYour response should end with "The final answer is [answer]" where [answer] is the response to the problem.', 'response_text': "To solve this problem, let's break it down step by step.\n\n1. Oscar starts with 24 lollipops and eats 2 on his way to school. \n   24 - 2 = 22 lollipops\n\n2. He passes 14 out to his friends. \n   22 - 14 = 8 lollipops\n\n3. He buys twice as many lollipops on his way home as he gave to his friends. \n   Si

In [48]:
print(task['question_str'])
print(task["response_text"])

Given the following problem, reason and give a final answer to the problem.
Problem: Oscar has 24 lollipops and eats 2 on his way to school.  He passes 14 out to his friends.  He buys twice as many lollipops on his way home as he gave to his friends.  He eats 3 more that night and 2 more in the morning.  How many lollipops does Oscar have?
Your response should end with "The final answer is [answer]" where [answer] is the response to the problem.
To solve this problem, let's break it down step by step.

1. Oscar starts with 24 lollipops and eats 2 on his way to school. 
   24 - 2 = 22 lollipops

2. He passes 14 out to his friends. 
   22 - 14 = 8 lollipops

3. He buys twice as many lollipops on his way home as he gave to his friends. 
   Since he gave 14 lollipops to his friends, he buys 2 * 14 = 28 lollipops.
   8 + 28 = 36 lollipops

4. He eats 3 more that night and 2 more in the morning. 
   36 - 3 - 2 = 31 lollipops

The final answer is 31.


In [49]:
task['exp_type'] = "zs_exp"
task.keys()

dict_keys(['idx', 'dataset_type', 'dataset_split', 'exp_type', 'url', 'temp', 'max_tokens', 'model_name', 'question_str', 'response_text'])

In [51]:
task = get_confidence_score(task)
probs = {"true": 0.0, "false": 0.0}

for item in task['logprobs']:
    token = item["token"].lower()
    is_true_synonym = any(synonym in token for synonym in TRUE_SYNONYMS)
    is_false_synonym = any(synonym in token for synonym in FALSE_SYNONYMS)

    if is_true_synonym:
        probs["true"] += np.exp(item["logprob"])
    elif is_false_synonym:
        probs["false"] += np.exp(item["logprob"])

p_true = (
    probs["true"] / (probs["true"] + probs["false"])
    if (probs["true"] + probs["false"]) > 0
    else 0.0
)
print(f"Probability of true: {p_true}")

[94mfalse[0m (false)
[92mtrue[0m (true)
[94mfalse[0m (false)
[92mtrue[0m (true)
[94mfalse[0m (false)
[92mġuntrue[0m (true)
[92mtrue[0m (true)
[94m_false[0m (false)
[94m(false[0m (false)
[94m/false[0m (false)
[94m.false[0m (false)
[94mġfalse[0m (false)
[92m_true[0m (true)
[92m(true[0m (true)
[94m=false[0m (false)
[92mincorrect[0m (true)
[94mno[0m (false)
[91mShit that did not fall in the true or false category: the[0m
[94mwrong[0m (false)
[94mĉfalse[0m (false)
Probability of true: 0.07492146896772653
