In [None]:
import os
import time
import math
import pandas as pd
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from datasets import load_dataset
from dotenv import load_dotenv

load_dotenv()

rag_dataset = load_dataset("neural-bridge/rag-dataset-1200")

rag_prompt = PromptTemplate.from_template("""You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
Question: {question} 
Context: {context} 
Answer:""")

qa_prompt = PromptTemplate.from_template("""You are an assistant for question-answering tasks. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
Question: {question} 
Answer:""")

In [3]:
def run_experiment(model, prompt, context, question, temperature, num_repeats=10, alpha=0.1, max_retries=5):
    data = []
    llm = ChatOpenAI(model=model, temperature=temperature).bind(logprobs=True)
    example_prompt = prompt.format(context=context, question=question) if context else prompt.format(question=question)
    
    for _ in range(num_repeats):
        full = None
        log_probs = []
        ema_log_prob = None
        skip = False

        retries = 0
        while retries < max_retries:
            try:
                start_time = time.time() 

                for chunk in llm.stream(example_prompt):
                    # Ensure streaming does not hang indefinitely
                    if time.time() - start_time > 60:  # Timeout after 60 seconds
                        raise TimeoutError("LLM streaming took too long.")

                    full = chunk if full is None else full + chunk
                    if "logprobs" in full.response_metadata:
                        for token in full.response_metadata["logprobs"]["content"]:
                            log_prob = token["logprob"]
                            log_probs.append(log_prob)
                            ema_log_prob = alpha * log_prob + (1 - alpha) * (ema_log_prob if ema_log_prob is not None else log_prob)

                break

            except Exception as e:
                retries += 1
                print(f"Attempt {retries} failed: {e}")

                if retries >= max_retries:
                    print(f"Max retries exceeded. Skipping.")
                    skip = True
                    break  

        if skip:
            continue  # Skip this iteration if max retries failed

        # Compute perplexities safely
        try:
            ppl = math.exp(-sum(log_probs) / len(log_probs)) if log_probs else None
        except OverflowError:
            ppl = float('inf')

        try:
            ema_ppl = math.exp(-ema_log_prob) if ema_log_prob else None
        except OverflowError:
            ema_ppl = float('inf')
        
        data.append({
            "Context": context,
            "Question": question,
            "Answer": full.content if full else "No response",
            "Perplexity": ppl,
            "EMA_Perplexity": ema_ppl,
            "Temperature": temperature,
            "Prompt_Type": "QA" if context is None else "RAG"
        })
    
    return data

def start():
    data = []
    num_qa = 5 # Using first 5 examples
    temperatures = [0, 1, 2]

    for i in range(num_qa):  
        context = rag_dataset['train'][i]['context']
        question = rag_dataset['train'][i]['question']
        
        for temp in temperatures:
            data.extend(run_experiment("gpt-4o-mini", rag_prompt, context, question, temp))
            data.extend(run_experiment("gpt-4o-mini", qa_prompt, None, question, temp))  # No context case

    df = pd.DataFrame(data)
    return df

Attempt 1 failed: LLM streaming took too long.
Attempt 1 failed: LLM streaming took too long.
Attempt 2 failed: LLM streaming took too long.
Attempt 1 failed: LLM streaming took too long.
Attempt 1 failed: LLM streaming took too long.
Attempt 2 failed: LLM streaming took too long.
Attempt 1 failed: LLM streaming took too long.
Attempt 2 failed: LLM streaming took too long.
Attempt 3 failed: LLM streaming took too long.
Attempt 4 failed: The model produced invalid content. Consider modifying your prompt if you are seeing this error persistently.
Attempt 1 failed: LLM streaming took too long.
Attempt 2 failed: LLM streaming took too long.
Attempt 3 failed: LLM streaming took too long.
Attempt 4 failed: LLM streaming took too long.
Attempt 5 failed: LLM streaming took too long.
Max retries exceeded. Skipping.
Attempt 1 failed: LLM streaming took too long.
Attempt 2 failed: LLM streaming took too long.
Attempt 3 failed: The model produced invalid content. Consider modifying your prompt if 

### Load existing results or start the experiment

In [None]:
if os.path.exists("results.csv"):
    df = pd.read_csv("results.csv")
else:
    df = start()
    df.to_csv("results.csv")

In [4]:
df.head()

Unnamed: 0,Context,Question,Answer,Perplexity,EMA_Perplexity,Temperature,Prompt_Type
0,Francisco Rogers found the answer to a search ...,Who found the answer to a search query collar ...,Francisco Rogers found the answer to a search ...,1.095639,1.075593,0,RAG
1,Francisco Rogers found the answer to a search ...,Who found the answer to a search query collar ...,Francisco Rogers found the answer to a search ...,1.103221,1.079131,0,RAG
2,Francisco Rogers found the answer to a search ...,Who found the answer to a search query collar ...,Francisco Rogers found the answer to a search ...,1.103234,1.078639,0,RAG
3,Francisco Rogers found the answer to a search ...,Who found the answer to a search query collar ...,Francisco Rogers found the answer to a search ...,1.113013,1.087556,0,RAG
4,Francisco Rogers found the answer to a search ...,Who found the answer to a search query collar ...,Francisco Rogers found the answer to a search ...,1.112027,1.086382,0,RAG


In [6]:
def custom_float_format(x):
    if abs(x) > 10:
        return '{:.4e}'.format(x)  # scientific notation for numbers > 10
    else:
        return '{:.4f}'.format(x)  # standard float format for numbers <= 10

pd.set_option('display.max_rows', None) 
pd.set_option('display.max_columns', None)
pd.set_option('display.width', 1000)
pd.set_option('display.float_format', custom_float_format)

agg_metrics = df.groupby(["Question", "Temperature", "Prompt_Type"])[["Perplexity", "EMA_Perplexity"]].agg(['mean', 'median', 'std', 'min', 'max', 'count'])
print(agg_metrics)

                                                                           Perplexity                                                    EMA_Perplexity                                                     
                                                                                 mean     median        std         min        max count           mean     median         std         min         max count
Question                                           Temperature Prompt_Type                                                                                                                                  
What are some of the potential negative impacts... 0           QA              1.2214     1.2315     0.0168      1.1938     1.2338    10         1.3026     1.3322      0.0805      1.2107      1.4467    10
                                                               RAG             1.2458     1.2489     0.0143      1.2196     1.2626    10         1.3078     1.2865      0.0580      

### Median comparison

By comparing **median perplexity** between RAG and QA (baseline), we assess **model predictability**. Lower perplexity indicates more confident, predictable outputs. If RAG consistently shows **lower perplexity than QA**, it suggests that the added context helps the model generate more certain responses. Conversely, higher perplexity implies that RAG may be introducing confusion or noise. This differential helps evaluate whether RAG improves or degrades response reliability compared to the standalone LLM.

In [12]:
# Extract and group Perplexity and EMA_Perplexity medians
median_ppl = agg_metrics['Perplexity']['median']
median_ema = agg_metrics['EMA_Perplexity']['median']

grouped_ppl = median_ppl.groupby(['Temperature', 'Prompt_Type']).median().unstack()
grouped_ema = median_ema.groupby(['Temperature', 'Prompt_Type']).median().unstack()

# Compute the difference: RAG - QA
grouped_ppl['RAG_minus_QA'] = grouped_ppl['RAG'] - grouped_ppl['QA']
grouped_ema['RAG_minus_QA'] = grouped_ema['RAG'] - grouped_ema['QA']

print("Median Perplexity (RAG - QA):")
print(grouped_ppl)

print("\nMedian EMA_Perplexity (RAG - QA):")
print(grouped_ema)

Median Perplexity (RAG - QA):
Prompt_Type     QA        RAG  RAG_minus_QA
Temperature                                
0           1.0487     1.1035        0.0548
1           1.0487     1.1862        0.1376
2           1.6997 3.7677e+94    3.7677e+94

Median EMA_Perplexity (RAG - QA):
Prompt_Type     QA        RAG  RAG_minus_QA
Temperature                                
0           1.0590     1.0797        0.0207
1           1.0590     1.3563        0.2973
2           1.5231 1.5126e+25    1.5126e+25


### Standard deviation comparison

Standard deviation reflects how **consistent** the model's predictions are. A **lower std** means more stable outputs across generations. By comparing the std of RAG vs QA, we can assess if context improves or destabilizes predictability. A **negative RAG_minus_QA** value suggests RAG is **more consistent**, which is desirable for reliability in production systems.


In [15]:
# Analyze standard deviation of Perplexity and EMA_Perplexity
std_ppl = agg_metrics['Perplexity']['std']
std_ema = agg_metrics['EMA_Perplexity']['std']

grouped_std_ppl = std_ppl.groupby(['Temperature', 'Prompt_Type']).median().unstack()
grouped_std_ema = std_ema.groupby(['Temperature', 'Prompt_Type']).median().unstack()

# Compute difference in std: RAG - QA
grouped_std_ppl['RAG_minus_QA'] = grouped_std_ppl['RAG'] - grouped_std_ppl['QA']
grouped_std_ema['RAG_minus_QA'] = grouped_std_ema['RAG'] - grouped_std_ema['QA']

print("Mean Std of Perplexity (RAG - QA):")
print(grouped_std_ppl)

print("\nMean Std of EMA_Perplexity (RAG - QA):")
print(grouped_std_ema)

Mean Std of Perplexity (RAG - QA):
Prompt_Type     QA    RAG  RAG_minus_QA
Temperature                            
0           0.0039 0.0056        0.0017
1           0.1543 0.0721       -0.0823
2           0.8114    NaN           NaN

Mean Std of EMA_Perplexity (RAG - QA):
Prompt_Type     QA         RAG  RAG_minus_QA
Temperature                                 
0           0.0071      0.0168        0.0097
1           0.0430      0.2803        0.2373
2           0.4646 1.5979e+108   1.5979e+108
