In [None]:
import json
import os
import sys
from pathlib import Path

import pandas as pd

# Project root is taken to be the parent of the kernel's working directory.
# NOTE(review): this assumes the notebook is launched from a folder one level
# below the repository root - confirm.
DIR_HOME = Path(os.getcwd()).parent
DIR_CONVERSATION = DIR_HOME / "data" / "conversations"

# Make the project's `src` package importable. Guarded so that re-running this
# cell does not keep appending the same entry to sys.path.
if str(DIR_HOME) not in sys.path:
    sys.path.append(str(DIR_HOME))

from src.utils import cohen_d, norm_diff_stdev
from src.metrics import SentenceBERTDiversity, Length

# Decode explicitly as UTF-8: without `encoding`, open() falls back to the
# locale's preferred encoding, which can mis-decode non-ASCII text on some
# platforms.
with open(DIR_CONVERSATION / "text-davinci-003-single-response.json", encoding="utf-8") as f:
    responses = json.load(f)

# Metric objects applied to each response's "completion" field by the scoring
# cell. Results of `grp_metrics` are indexed with [0] there; results of
# `ind_metrics` are stored as returned.
grp_metrics = [SentenceBERTDiversity("paraphrase-MiniLM-L3-v2")]
ind_metrics = [Length()]

In [None]:
# Score every response record in place: each metric's value is written back
# into the record under the metric's own `name`.
for record in responses:
    completion = record["completion"]
    # Group metrics: only element [0] of the metric's result is kept.
    record.update({m.name: m(completion)[0] for m in grp_metrics})
    # Individual metrics: the metric's result is stored as returned.
    record.update({m.name: m(completion) for m in ind_metrics})

# One row per record, indexed by the (qid, cid) pair.
df_responses = pd.DataFrame(responses)
df_responses = df_responses.set_index(["qid", "cid"])
df_responses.head()

In [None]:
# Every cid in (base_cid, last_cid] is compared against the single baseline
# configuration `base_cid`.
base_cid = 0
last_cid = 10
last_qid = 23  # upper bound (inclusive) of the qid range used by the per-question loop below

print("\nDiversity comparison >>>")
# The baseline selection does not depend on the test cid, so it is taken once
# here instead of being recomputed on every loop iteration.
diversity_base = df_responses.xs(base_cid, level="cid").sentencebert_diversity.values
stats = []
for cid in range(base_cid + 1, last_cid + 1):
    # All rows for this test configuration (one value per remaining index entry).
    diversity_test = df_responses.xs(cid, level="cid").sentencebert_diversity.values
    # Both statistics come from src.utils and are formatted to 3 decimals,
    # so the resulting columns hold strings, not floats.
    stats.append({"base_cfg": base_cid,
                  "test_cfg": cid,
                  "cohen_d": f"{cohen_d(diversity_base, diversity_test):.3f}",
                  "norm_diff_stdev": f"{norm_diff_stdev(diversity_base, diversity_test):.3f}"})
print(pd.DataFrame(stats))

print("\nLength comparison >>>")


def _response_length(question_id, config_id):
    """Return the `length` entry of the first row selected for (question_id, config_id)."""
    selected = df_responses.xs((question_id, config_id), level=["qid", "cid"])
    return selected.length.values[0]


# One output row per (test configuration, question) pair, each compared
# against the baseline configuration for the same question.
stats = []
for test_cid in range(base_cid + 1, last_cid + 1):
    for question_id in range(last_qid + 1):
        base_len = _response_length(question_id, base_cid)
        test_len = _response_length(question_id, test_cid)
        # Statistics are stored as strings formatted to 3 decimals.
        row = {
            "base_cfg": base_cid,
            "test_cfg": test_cid,
            "qid": question_id,
            "cohen_d": f"{cohen_d(base_len, test_len):.3f}",
            "norm_diff_stdev": f"{norm_diff_stdev(base_len, test_len):.3f}",
        }
        stats.append(row)
print(pd.DataFrame(stats))