In [None]:
import json
import re
from analysis import Analysis, guesser_vs_oracle_update, stepwise_guesser_annotations
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
from collections import defaultdict 
from IPython.display import HTML
import ipywidgets as widgets
from IPython.display import display
import nltk

# Global figure styling for every plot in this notebook:
# a larger default font size and seaborn's "darkgrid" background.
plt.rcParams['font.size'] = 13
sns.set_style(style="darkgrid")

In [None]:
# Load each backend's generated dialogues + oracle annotations, summarise the
# number of questions asked per dialogue, and compare against two reference
# lines ("Optimal" and "Baseline") derived from the candidate-set size.
#
# NOTE(review): `backends` (iterable of backend directory names) and `setting`
# (sub-directory name) are not defined anywhere in the visible notebook; they
# must come from an earlier cell. TODO: move them into the configuration cell
# so Restart Kernel -> Run All works from a fresh kernel.

all_questions = []    # one record per dialogue (feeds the boxplot)
summary_rows = []     # one record per backend (feeds the summary table)
num_candidates = None  # candidate-set size; must agree across all backends

for backend in backends:

    print(f"Processing {backend}...")

    generation_dir = f"../data/generation/{backend}/{setting}"

    # Explicit UTF-8 so results do not depend on the platform's default encoding.
    with open(f"{generation_dir}/dialogues.txt", encoding="utf-8") as f:
        dialogues = f.read()

    with open(f"{generation_dir}/oracle_annotations.json", encoding="utf-8") as f:
        oracle_annotations = json.load(f)

    analysis = Analysis(dialogues, oracle_annotations, setting)

    # The Optimal/Baseline references are drawn once for all backends, so every
    # backend must share the same candidate-set size. Previously the values of
    # whichever backend happened to run last were used silently.
    if num_candidates is None:
        num_candidates = analysis.num_candidates
    elif analysis.num_candidates != num_candidates:
        raise ValueError(
            f"Backend {backend!r} has {analysis.num_candidates} candidates but "
            f"earlier backends have {num_candidates}; a single Optimal/Baseline "
            "reference would be misleading."
        )

    avg_qs = analysis.average_questions()

    # Summary table row
    summary_rows.append({
        "Backend": backend,
        "Average Questions": round(avg_qs, 2)
    })

    # Store full distribution for boxplot
    for q in analysis.questions_dist:
        all_questions.append({
            "Backend": backend,
            "Num Questions": q
        })

if num_candidates is None:
    # Without this check an empty `backends` fails later with an opaque
    # NameError on `analysis`.
    raise ValueError("`backends` is empty: nothing to summarise or plot.")


# =========================
# Reference lines
# =========================

# Computed once and shared by the table and the plot.
# NOTE(review): Optimal = log2(N) + 0.5 is presumably the binary-search
# (halving) optimum plus a half-question correction, and Baseline = N / 2 the
# expected cost of asking about one candidate at a time -- confirm.
optimal_questions = np.log2(num_candidates) + 0.5
baseline_questions = num_candidates / 2


# =========================
# Summary table
# =========================

summary_df = pd.DataFrame(summary_rows)

summary_df["Optimal"] = optimal_questions
summary_df["Baseline"] = baseline_questions

display(summary_df)


# =========================
# Boxplot
# =========================

plot_df = pd.DataFrame(all_questions)

fig, ax = plt.subplots(figsize=(8, 6))

sns.boxplot(
    data=plot_df,
    x="Backend",
    y="Num Questions",
    width=0.6,
    ax=ax
)

ax.axhline(
    optimal_questions,
    color='tab:pink',
    linestyle='--',
    linewidth=2,
    label='Optimal'
)

ax.axhline(
    baseline_questions,
    color='tab:gray',
    linestyle='--',
    linewidth=2,
    label='Baseline'
)

ax.set_title("Number of Questions per Dialogue (All Backends)")
ax.set_ylabel("Num Questions")
ax.legend()

plt.show()