In [19]:
import time
import math
import pandas as pd
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from datasets import load_dataset
from dotenv import load_dotenv

load_dotenv()

rag_dataset = load_dataset("neural-bridge/rag-dataset-1200")

rag_prompt = PromptTemplate.from_template("""You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
Question: {question} 
Context: {context} 
Answer:""")

qa_prompt = PromptTemplate.from_template("""You are an assistant for question-answering tasks. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
Question: {question} 
Answer:""")

In [21]:
def run_experiment(model, prompt, context, question, temperature, num_repeats=5, alpha=0.1, max_retries=5):
    data = []
    llm = ChatOpenAI(model=model, temperature=temperature).bind(logprobs=True)
    example_prompt = prompt.format(context=context, question=question) if context else prompt.format(question=question)
    
    for _ in range(num_repeats):
        full = None
        log_probs = []
        ema_log_prob = None
        skip = False

        retries = 0
        while retries < max_retries:
            try:
                start_time = time.time()  # Track start time

                for chunk in llm.stream(example_prompt):
                    # Ensure streaming does not hang indefinitely
                    if time.time() - start_time > 15:  # Timeout after 15 seconds
                        raise TimeoutError("LLM streaming took too long.")

                    full = chunk if full is None else full + chunk
                    if "logprobs" in full.response_metadata:
                        for token in full.response_metadata["logprobs"]["content"]:
                            log_prob = token["logprob"]
                            log_probs.append(log_prob)
                            ema_log_prob = alpha * log_prob + (1 - alpha) * (ema_log_prob if ema_log_prob is not None else log_prob)

                break  # Success, exit retry loop

            except Exception as e:
                retries += 1  # Ensure retries always increments
                print(f"Attempt {retries} failed: {e}")

                if retries >= max_retries:
                    print(f"Max retries exceeded. Skipping.")
                    skip = True
                    break  # Ensure we exit retry loop

        if skip:
            continue  # Skip this iteration if max retries failed

        # Compute perplexities safely
        try:
            ppl = math.exp(-sum(log_probs) / len(log_probs)) if log_probs else None
        except OverflowError:
            ppl = float('inf')

        try:
            ema_ppl = math.exp(-ema_log_prob) if ema_log_prob else None
        except OverflowError:
            ema_ppl = float('inf')
        
        data.append({
            "Context": context,
            "Question": question,
            "Answer": full.content if full else "No response",
            "Perplexity": ppl,
            "EMA_Perplexity": ema_ppl,
            "Temperature": temperature,
            "Prompt_Type": "QA" if context is None else "RAG"
        })
    
    return data

data = []
num_qa = 5 # Using first 5 examples
temperatures = [0, 1, 2]

for i in range(num_qa):  
    context = rag_dataset['train'][i]['context']
    question = rag_dataset['train'][i]['question']
    
    for temp in temperatures:
        data.extend(run_experiment("gpt-4o-mini", rag_prompt, context, question, temp))
        data.extend(run_experiment("gpt-4o-mini", qa_prompt, None, question, temp))  # No context case

df = pd.DataFrame(data)

Attempt 1 failed: LLM streaming took too long.
Attempt 1 failed: LLM streaming took too long.
Attempt 2 failed: LLM streaming took too long.
Attempt 3 failed: LLM streaming took too long.
Attempt 4 failed: The model produced invalid content. Consider modifying your prompt if you are seeing this error persistently.
Attempt 5 failed: LLM streaming took too long.
Max retries exceeded. Skipping.
Attempt 1 failed: LLM streaming took too long.
Attempt 1 failed: LLM streaming took too long.
Attempt 1 failed: LLM streaming took too long.
Attempt 2 failed: LLM streaming took too long.
Attempt 3 failed: LLM streaming took too long.
Attempt 4 failed: The model produced invalid content. Consider modifying your prompt if you are seeing this error persistently.
Attempt 1 failed: LLM streaming took too long.
Attempt 2 failed: LLM streaming took too long.
Attempt 3 failed: LLM streaming took too long.
Attempt 4 failed: The model produced invalid content. Consider modifying your prompt if you are seein

In [22]:
df.head()

Unnamed: 0,Context,Question,Answer,Perplexity,EMA_Perplexity,Temperature,Prompt_Type
0,Francisco Rogers found the answer to a search ...,Who found the answer to a search query collar ...,Francisco Rogers found the answer to the searc...,1.078585,1.071768,0,RAG
1,Francisco Rogers found the answer to a search ...,Who found the answer to a search query collar ...,Francisco Rogers found the answer to the searc...,1.078421,1.071751,0,RAG
2,Francisco Rogers found the answer to a search ...,Who found the answer to a search query collar ...,Francisco Rogers found the answer to a search ...,1.101121,1.07793,0,RAG
3,Francisco Rogers found the answer to a search ...,Who found the answer to a search query collar ...,Francisco Rogers found the answer to the searc...,1.066306,1.057514,0,RAG
4,Francisco Rogers found the answer to a search ...,Who found the answer to a search query collar ...,Francisco Rogers found the answer to a search ...,1.103505,1.079985,0,RAG


In [29]:
pd.set_option('display.max_rows', None) 
pd.set_option('display.max_columns', None)
pd.set_option('display.width', 1000)
pd.set_option('display.float_format', '{:.4f}'.format)

agg_metrics = df.groupby(["Question", "Temperature", "Prompt_Type"])[["Perplexity", "EMA_Perplexity"]].agg(['mean', 'std', 'min', 'max', 'median', 'count'])
print(agg_metrics)

                                                                                                                   Perplexity                                                                                                                                                                                                           EMA_Perplexity                                                                                                                                                                                                                  
                                                                                                                         mean                                                std                                                min                                                max median count                                               mean                                                std                                                min   