In [1]:
import os 
from pathlib import Path
from dotenv import load_dotenv
import pickle
import pandas as pd
import torch
from tqdm import tqdm
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
load_dotenv()

True

In [3]:
base_path = Path.home()/Path(os.environ.get("SAVE_PATH"))

In [4]:
# loading slateq results
ALPHA = 0.25
SEEDS = [42, 5, 7, 97, 53]

In [5]:
def parse_data(data_dict: dict, alg_nam:str, seed: int)->pd.DataFrame:
    # convert from torch to float
    for k, v in data_dict.items():
        if isinstance(v, torch.Tensor):
            data_dict[k] = v.item()
        elif isinstance(v, list):
            if isinstance(v[0], torch.Tensor):
                data_dict[k] = [x.item() for x in v]
    data_df = pd.DataFrame(data_dict)
    data_df["alg_name"] = alg_nam
    data_df["seed"] = seed
    data_df = data_df.reset_index().rename(columns={"index": "episode"})
    
    return data_df

In [6]:
# load slateq results
res_df_list = []
for seed in tqdm(SEEDS):
    slateq_dict = pickle.load(open(base_path / Path(f"test_serving_observed_topic_slateq_{seed}")/ "logs_dict.pickle", "rb"))
    res_df = parse_data(slateq_dict, "SlateQ", seed)
    res_df_list.append(res_df)
# concat all results
final_df = pd.concat(res_df_list)

100%|██████████| 5/5 [00:00<00:00, 17.44it/s]


In [20]:
# load wp results
res_df_list = []

for seed in tqdm(SEEDS):
    for knearest in [20, 10, 5]:
        slateq_dict = pickle.load(open(base_path / Path(f"test_wa_{knearest}_serving_observed_topic_slateq_{seed}")/ "logs_dict.pickle", "rb"))
        res_df = parse_data(slateq_dict, f"Slate-Wolpertinger {knearest}%", seed)
        res_df_list.append(res_df)
# concat all results
wp_df = pd.concat(res_df_list)
final_df = pd.concat([final_df, wp_df])

100%|██████████| 5/5 [00:00<00:00,  5.78it/s]


In [21]:
final_df = final_df.reset_index(drop=True)

In [22]:
from scipy.stats import levene
from scipy.stats import shapiro
from scipy.stats import ttest_rel
from scipy.stats import wilcoxon
def stat_test(r1,r2):
    print(levene(r1, r2))
    differences = [a-b for a,b in zip(r1, r2)]
    print(shapiro(differences))
    print(ttest_rel(r1, r2))
    print(wilcoxon(r1, r2))

# Statistical tests

### SlateQ | WP20

In [23]:
print("===== SlateQ | WP20 =====")
sq_return = final_df.groupby(["alg_name", "seed"]).mean()["session_length"]["SlateQ"].values
wp_return = final_df.groupby(["alg_name", "seed"]).mean()["session_length"]["Slate-Wolpertinger 20%"].values
print("===== Return =====")
print("SlateQ",sq_return.mean(), sq_return.std())
print("WP",wp_return.mean(), wp_return.std())
stat_test(sq_return, wp_return)
print("===== Cum satisfaction =====")
sq_cum_satisfaction = final_df.groupby(["alg_name", "seed"]).mean()["ep_cum_reward"]["SlateQ"].values
wp_cum_satisfaction = final_df.groupby(["alg_name", "seed"]).mean()["ep_cum_reward"]["Slate-Wolpertinger 20%"].values
print("SlateQ",sq_cum_satisfaction.mean(), sq_cum_satisfaction.std())
print("WP",wp_cum_satisfaction.mean(), wp_cum_satisfaction.std())
stat_test(sq_cum_satisfaction, wp_cum_satisfaction)
print("===== Avg satisfaction =====")
sq_avg_satisfaction = final_df.groupby(["alg_name", "seed"]).mean()["ep_avg_reward"]["SlateQ"].values
wp_avg_satisfaction = final_df.groupby(["alg_name", "seed"]).mean()["ep_avg_reward"]["Slate-Wolpertinger 20%"].values
print("SlateQ",sq_avg_satisfaction.mean(), sq_avg_satisfaction.std())
print("WP",wp_avg_satisfaction.mean(), sq_avg_satisfaction.std())
stat_test(sq_avg_satisfaction, wp_avg_satisfaction)

===== SlateQ | WP20 =====
===== Return =====
SlateQ 247.35999999999999 2.2637866507248403
WP 246.90900000000002 1.4865947665722512
LeveneResult(statistic=0.532552629151727, pvalue=0.48635139355725243)
ShapiroResult(statistic=0.9046893119812012, pvalue=0.4363434314727783)
Ttest_relResult(statistic=0.5531666966167407, pvalue=0.609609268612966)
WilcoxonResult(statistic=7.0, pvalue=1.0)
===== Cum satisfaction =====
SlateQ 42.994156853437424 2.2732131750562172
WP 42.58386329197884 1.425497232594593
LeveneResult(statistic=0.6073407830235915, pvalue=0.45822243316908684)
ShapiroResult(statistic=0.8972166180610657, pvalue=0.39469367265701294)
Ttest_relResult(statistic=0.49896060945606324, pvalue=0.6440000797701579)
WilcoxonResult(statistic=7.0, pvalue=1.0)
===== Avg satisfaction =====
SlateQ 0.6475132667906582 0.04269675827454641
WP 0.6299324461407959 0.04269675827454641
LeveneResult(statistic=0.48436659653185354, pvalue=0.50616309618521)
ShapiroResult(statistic=0.8875937461853027, pvalue=0.345

In [24]:
print("===== SlateQ | WP10 =====")
sq_return = final_df.groupby(["alg_name", "seed"]).mean()["session_length"]["SlateQ"].values
wp_return = final_df.groupby(["alg_name", "seed"]).mean()["session_length"]["Slate-Wolpertinger 10%"].values
print("===== Return =====")
print("SlateQ",sq_return.mean(), sq_return.std())
print("WP",wp_return.mean(), wp_return.std())
stat_test(sq_return, wp_return)
print("===== Cum satisfaction =====")
sq_cum_satisfaction = final_df.groupby(["alg_name", "seed"]).mean()["ep_cum_reward"]["SlateQ"].values
wp_cum_satisfaction = final_df.groupby(["alg_name", "seed"]).mean()["ep_cum_reward"]["Slate-Wolpertinger 10%"].values
print("SlateQ",sq_cum_satisfaction.mean(), sq_cum_satisfaction.std())
print("WP",wp_cum_satisfaction.mean(), wp_cum_satisfaction.std())
stat_test(sq_cum_satisfaction, wp_cum_satisfaction)
print("===== Avg satisfaction =====")
sq_avg_satisfaction = final_df.groupby(["alg_name", "seed"]).mean()["ep_avg_reward"]["SlateQ"].values
wp_avg_satisfaction = final_df.groupby(["alg_name", "seed"]).mean()["ep_avg_reward"]["Slate-Wolpertinger 10%"].values
print("SlateQ",sq_avg_satisfaction.mean(), sq_avg_satisfaction.std())
print("WP",wp_avg_satisfaction.mean(), sq_avg_satisfaction.std())
stat_test(sq_avg_satisfaction, wp_avg_satisfaction)

===== SlateQ | WP10 =====
===== Return =====
SlateQ 247.35999999999999 2.2637866507248403
WP 241.548 5.145801783978859
LeveneResult(statistic=1.6060896573162176, pvalue=0.24068633911610735)
ShapiroResult(statistic=0.863184928894043, pvalue=0.23991072177886963)
Ttest_relResult(statistic=3.4477882098482775, pvalue=0.02610716489372576)
WilcoxonResult(statistic=0.0, pvalue=0.0625)
===== Cum satisfaction =====
SlateQ 42.994156853437424 2.2732131750562172
WP 37.37873866891861 4.9499202480279445
LeveneResult(statistic=1.506594875109685, pvalue=0.25454979103348285)
ShapiroResult(statistic=0.8609022498130798, pvalue=0.23150594532489777)
Ttest_relResult(statistic=3.513304841453682, pvalue=0.024598324235586534)
WilcoxonResult(statistic=0.0, pvalue=0.0625)
===== Avg satisfaction =====
SlateQ 0.6475132667906582 0.04269675827454641
WP 0.5488295829081907 0.04269675827454641
LeveneResult(statistic=1.5169382720486726, pvalue=0.25305690972702344)
ShapiroResult(statistic=0.8474172353744507, pvalue=0.1864

In [25]:
print("===== SlateQ | WP5 =====")
sq_return = final_df.groupby(["alg_name", "seed"]).mean()["session_length"]["SlateQ"].values
wp_return = final_df.groupby(["alg_name", "seed"]).mean()["session_length"]["Slate-Wolpertinger 5%"].values
print("===== Return =====")
print("SlateQ",sq_return.mean(), sq_return.std())
print("WP",wp_return.mean(), wp_return.std())
stat_test(sq_return, wp_return)
print("===== Cum satisfaction =====")
sq_cum_satisfaction = final_df.groupby(["alg_name", "seed"]).mean()["ep_cum_reward"]["SlateQ"].values
wp_cum_satisfaction = final_df.groupby(["alg_name", "seed"]).mean()["ep_cum_reward"]["Slate-Wolpertinger 5%"].values
print("SlateQ",sq_cum_satisfaction.mean(), sq_cum_satisfaction.std())
print("WP",wp_cum_satisfaction.mean(), wp_cum_satisfaction.std())
stat_test(sq_cum_satisfaction, wp_cum_satisfaction)
print("===== Avg satisfaction =====")
sq_avg_satisfaction = final_df.groupby(["alg_name", "seed"]).mean()["ep_avg_reward"]["SlateQ"].values
wp_avg_satisfaction = final_df.groupby(["alg_name", "seed"]).mean()["ep_avg_reward"]["Slate-Wolpertinger 5%"].values
print("SlateQ",sq_avg_satisfaction.mean(), sq_avg_satisfaction.std())
print("WP",wp_avg_satisfaction.mean(), sq_avg_satisfaction.std())
stat_test(sq_avg_satisfaction, wp_avg_satisfaction)

===== SlateQ | WP5 =====
===== Return =====
SlateQ 247.35999999999999 2.2637866507248403
WP 220.365 14.686274885075518
LeveneResult(statistic=3.4706744654915185, pvalue=0.0994750623089661)
ShapiroResult(statistic=0.8664183616638184, pvalue=0.25222185254096985)
Ttest_relResult(statistic=3.9869971796225854, pvalue=0.01630568089689091)
WilcoxonResult(statistic=0.0, pvalue=0.0625)
===== Cum satisfaction =====
SlateQ 42.994156853437424 2.2732131750562172
WP 17.240114057302474 14.16579164408293
LeveneResult(statistic=3.4653556089240296, pvalue=0.09969351573056447)
ShapiroResult(statistic=0.8645584583282471, pvalue=0.24508187174797058)
Ttest_relResult(statistic=3.959106480975237, pvalue=0.016690130200763213)
WilcoxonResult(statistic=0.0, pvalue=0.0625)
===== Avg satisfaction =====
SlateQ 0.6475132667906582 0.04269675827454641
WP 0.2426655572046293 0.04269675827454641
LeveneResult(statistic=2.753107181598788, pvalue=0.13564876475325738)
ShapiroResult(statistic=0.8323652744293213, pvalue=0.1448