In [2]:
import pickle
import pandas as pd
import numpy as np

# Memo of already-unpickled stats files, keyed by their relative path.
cached = {}


def load(data_name, fold, group_f, rating, ll, nth, topk):
    """Return the unpickled stats object for one experiment configuration.

    All seven arguments are interpolated into the file path
    ``data_splits/<data_name>/<fold>/CEs/stats_<group_f>_<rating>_<ll>_<nth>_<topk>.pickle``.

    A successful read is memoised in the module-level ``cached`` dict, so
    later calls with the same arguments return the very same object without
    touching the disk. A missing file yields a fresh empty dict and is NOT
    memoised, so the file is looked for again on the next call.
    """
    path = f"data_splits/{data_name}/{fold}/CEs/stats_{group_f}_{rating}_{ll}_{nth}_{topk}.pickle"

    # Fast path: this file was read before.
    try:
        return cached[path]
    except KeyError:
        pass

    try:
        with open(path, "rb") as handle:
            stats = pickle.load(handle)
    except FileNotFoundError:
        return {}

    cached[path] = stats
    return stats


In [3]:
def get_val_for_confs(
    valnames,
    dataname=["yelp", "netflix", "amazon"],
    aggf=["sum", "mean", "disjunction"],
    rating=["rating", "binary"],
    ll=["median", "0.1", "no_spn"],
    ordering=["score", "nth"],
    topk=[5, 10],
):
    """Pool values over the cross product of the given configuration options.

    ``valnames`` may be a single key or a list of keys; a non-list is wrapped
    into a one-element list before being forwarded. Every other argument is
    an iterable of options. ``get_val_for_conf`` is called once per
    combination, with ``dataname`` varying slowest and ``topk`` fastest.

    Returns ``(all_vals, all_fails)``: the concatenation of every per-config
    value list and the sum of every per-config failure count.

    The list defaults are only iterated, never mutated.
    """
    names = valnames if type(valnames) == list else [valnames]

    # One tuple per configuration, in the same order as six nested loops.
    grid = (
        (dset, agg, rat, ll_opt, order, k)
        for dset in dataname
        for agg in aggf
        for rat in rating
        for ll_opt in ll
        for order in ordering
        for k in topk
    )

    collected, fail_total = [], 0
    for combo in grid:
        found, missed = get_val_for_conf(names, *combo)
        collected += found
        fail_total += missed
    return collected, fail_total

def get_val_for_conf(valnames, dataname, aggf, rating, ll, ordering, topk):
    """Gather the requested fields from every record of one configuration.

    For folds 0, 1 and 2 the stats are fetched with ``load`` and treated as a
    two-level mapping (outer key -> inner key -> record dict).

    Args:
        valnames: list of record keys to read. With one key the raw value is
            collected; with several, a list of the values in the same order.
        dataname, aggf, rating, ll, ordering, topk: forwarded to ``load`` to
            select the stats file.

    Returns:
        ``(vals, fails)``: the collected values, and the number of records
        for which no value could be produced.

    Fallback: when a fold has no stats and ``ll == "no_spn"``, the stats
    stored under aggregation ``"mean"`` are used instead. If in that case
    ``valnames == ["counterfactual_ll"]``, the values are not read from the
    records; each record's ``"counterfactual"`` is passed to
    ``eval_sample_lls`` with the requested ``aggf``.
    """
    vals = []
    fails = 0
    for fold in range(3):
        data = load(dataname, fold, aggf, rating, ll, ordering, topk)
        if len(data) == 0 and ll == "no_spn":
            # No stats for this aggregation: reuse the ones stored for "mean".
            data = load(dataname, fold, "mean", rating, ll, ordering, topk)
            if valnames == ["counterfactual_ll"]:
                ce_vecs = []
                for CEs in data.values():
                    for ce in CEs.values():
                        if "counterfactual" in ce:
                            ce_vecs.append(ce["counterfactual"])
                        else:
                            fails += 1
                lls = eval_sample_lls(ce_vecs, dataname, fold, aggf, rating)
                if isinstance(lls, float):
                    # eval_sample_lls returns a bare np.nan (not a list) when
                    # its model files are missing; `vals += nan` would raise
                    # TypeError. Count these records as failures instead, so
                    # every record still contributes to vals or to fails.
                    fails += len(ce_vecs)
                else:
                    vals += lls
                continue
        for CEs in data.values():
            for ce in CEs.values():
                if all(valname in ce for valname in valnames):
                    if len(valnames) == 1:
                        vals.append(ce[valnames[0]])
                    else:
                        vals.append([ce[valname] for valname in valnames])
                else:
                    fails += 1
    return vals, fails


In [4]:
# Memo of (spn, groups) pairs keyed by the SPN pickle path.
spns_cache = {}


def eval_sample_lls(samples, data_name, fold, group_f, rating):
    """Score each sample with the stored model for one configuration.

    The model is unpickled from
    ``data_splits/<data_name>/<fold>/models/spn_<group_f>_<rating>.pickle``.
    Group definitions come from ``groups.pickle`` and are translated to index
    arrays through the mapping in ``items_<rating>.pickle``. The resulting
    ``(spn, groups)`` pair is memoised in ``spns_cache`` and reused on later
    calls, without touching the disk.

    Each sample is converted to an array and reduced to one value per group:
    ``"sum"``, ``"mean"``, or ``"disjunction"`` (1 if the group's sum is
    positive, else 0). The reduced vector is passed to ``spn.compute_ll``.

    Args:
        samples: iterable of array-likes that the group index arrays can index.
        data_name, fold, group_f, rating: select the files described above.

    Returns:
        A list with one ``spn.compute_ll`` result per sample. If any of the
        three files is missing, prints ``"No SPN found"`` and returns the
        scalar ``np.nan`` (not a list), so callers must check for it.

    Raises:
        ValueError: if ``group_f`` is not one of the three supported
            aggregations and there is at least one sample to score.
    """
    path = f"data_splits/{data_name}/{fold}/models/spn_{group_f}_{rating}.pickle"
    if path in spns_cache:
        spn, groups = spns_cache[path]
    else:
        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        try:
            with open(path, "rb") as f:
                spn, _ = pickle.load(f)
            with open(f"data_splits/{data_name}/{fold}/groups.pickle", "rb") as f:
                (groups, _) = pickle.load(f)
            with open(f"data_splits/{data_name}/{fold}/items_{rating}.pickle", "rb") as f:
                item_map = pickle.load(f)
        except FileNotFoundError:
            print("No SPN found")
            return np.nan
        groups = [np.array([item_map[item] for item in items]) for items in groups]
        spns_cache[path] = spn, groups

    lls = []
    for sample in samples:
        cf = np.array(sample)
        if group_f == "sum":
            spn_cf = [cf[g].sum() for g in groups]
        elif group_f == "mean":
            spn_cf = [cf[g].mean() for g in groups]
        elif group_f == "disjunction":
            spn_cf = [(cf[g].sum() > 0).astype(int) for g in groups]
        else:
            raise ValueError(f"unsupported group_f: {group_f!r}")
        lls.append(spn.compute_ll(np.array(spn_cf)))
    return lls

In [5]:
# Median of every recorded "time_solving" value over all default configurations.
vals, fails = get_val_for_confs(valnames="time_solving")
median_time_solving = np.median(vals)
print(median_time_solving)

11.075806282504345


In [4]:
# LaTeX table rows: per dataset, one row per ordering. Each cell is the share
# of records that contain a "counterfactual" entry: len(vals) / (len(vals) + fails).
for data in ["amazon", "yelp", "netflix"]:
    print(data)
    for nth in ["nth", "score"]:
        for ll in ["0.1", "median", "no_spn"]:
            for rat in ["binary", "rating"]:
                for topk in [[5, 10]]:
                    # Only "mean" is reported for "no_spn" and for "rating".
                    all_fs = ["disjunction", "sum", "mean"]
                    if ll == "no_spn" or rat == "rating":
                        all_fs = ["mean"]
                    for af in all_fs:
                        vals, fails = get_val_for_confs("counterfactual", dataname=[data], aggf=[af], rating=[rat], ll=[ll], topk=topk, ordering=[nth])
                        total = len(vals) + fails
                        # With no records at all (e.g. missing stats files),
                        # print "nan" instead of raising ZeroDivisionError.
                        rate = len(vals) / total if total else float("nan")
                        print(f"{rate:.2f}", end = " & ")
        print("\\\\")


amazon
1.00 & 0.99 & 0.97 & 0.97 & 0.92 & 1.00 & 0.00 & 0.00 & 1.00 & 1.00 & \\
1.00 & 0.95 & 1.00 & 1.00 & 0.94 & 0.95 & 0.00 & 0.00 & 1.00 & 1.00 & \\
yelp
1.00 & 1.00 & 0.30 & 0.07 & 0.86 & 0.00 & 0.00 & 0.00 & 1.00 & 1.00 & \\
1.00 & 1.00 & 0.95 & 0.88 & 1.00 & 0.00 & 0.00 & 0.00 & 1.00 & 1.00 & \\
netflix
1.00 & 1.00 & 0.70 & 0.51 & 0.99 & 0.00 & 0.00 & 0.33 & 1.00 & 1.00 & \\
1.00 & 1.00 & 0.89 & 0.62 & 1.00 & 0.00 & 0.00 & 0.33 & 1.00 & 1.00 & \\


In [5]:
# Mean "timeout" value over all records that have both requested keys.
val = ["timeout", "counterfactual"]

vals, fails = get_val_for_confs(val)
# Split the [timeout, counterfactual] pairs into two parallel lists.
touts, ces = map(list, zip(*vals)) if vals else ([], [])

print(f"{np.mean(touts):.3f}", end=" & ")

0.086 & 

In [6]:
# Per `ll` setting, print two ratios as a LaTeX table row.
val = ["timeout", "counterfactual"]

for ll in ["0.1", "median", "no_spn"]:
    all_touts, _ = get_val_for_confs("timeout", ll=[ll])
    vals, fails = get_val_for_confs(val, ll=[ll])
    if len(vals) != 0:
        touts, ces = [list(v) for v in zip(*vals)]
    else:
        touts, ces = [], []

    # fails minus the timeouts of records that have "timeout" but not both keys.
    # Presumably the failures not explained by a timeout; verify.
    non_timeout_fails = fails - np.sum(all_touts) + np.sum(touts)

    # A zero denominator (e.g. fails == 0) is expected here. Keep the nan/inf
    # result that gets printed, but do not emit a RuntimeWarning for it.
    with np.errstate(divide="ignore", invalid="ignore"):
        share_of_records = non_timeout_fails / len(all_touts)
        share_of_fails = non_timeout_fails / fails

    print(f"{share_of_records:.2f}", end = " & ")
    print(f"{share_of_fails:.2f}", end = " & ")
    print(" \\\\")

0.03 & 0.23 &  \\
0.65 & 0.99 &  \\
0.00 & nan &  \\


  print(f"{(fails - np.sum(all_touts) + np.sum(touts)) / fails:.2f}", end = " & ")


In [7]:
# LaTeX table rows: per dataset and ordering, "mean \pm std" of
# "counterfactual_ll" and then "distance" for each (rating, ll) combination.
for data in ["amazon", "yelp", "netflix"]:
    print(data)
    for nth in ["nth", "score"]:
        print(f" & {nth} & ", end="")
        for val in ["counterfactual_ll", "distance"]:
            for rat in ["binary", "rating"]:
                # "binary" uses disjunction and includes "median";
                # "rating" uses mean and skips "median".
                if rat == "binary":
                    ll_options = ["0.1", "median", "no_spn"]
                else:
                    ll_options = ["0.1", "no_spn"]
                for ll in ll_options:
                    for topk in [[5, 10]]:
                        af = "disjunction" if rat == "binary" else "mean"
                        vals, fails = get_val_for_confs(val, dataname=[data], aggf=[af], rating=[rat], ll=[ll], topk=topk, ordering=[nth])
                        mean = np.mean(vals)
                        std = np.std(vals)
                        # Means above 300 are printed without decimals.
                        digits = 0 if mean > 300 else 2
                        print(f"${mean:.{digits}f} \\pm {std:.{digits}f}$", end = " & ")
        print(" \\\\")

amazon
 & nth & $-18.44 \pm 12.23$ & $-12.08 \pm 3.16$ & $-19.57 \pm 13.45$ & $1695 \pm 31$ & $1693 \pm 34$ & $1.16 \pm 0.49$ & $2.60 \pm 3.24$ & $1.08 \pm 0.43$ & $2.11 \pm 2.16$ & $0.66 \pm 0.38$ &  \\
 & score & $-18.40 \pm 11.93$ & $-12.21 \pm 3.42$ & $-19.61 \pm 13.42$ & $1695 \pm 35$ & $1694 \pm 34$ & $1.19 \pm 0.46$ & $2.82 \pm 3.94$ & $1.11 \pm 0.37$ & $2.33 \pm 3.19$ & $0.54 \pm 0.37$ &  \\
yelp
 & nth & $-89.11 \pm 38.68$ & $-78.33 \pm 13.72$ & $-103.21 \pm 46.04$ & $3185 \pm 113$ & $3122 \pm 148$ & $2.71 \pm 2.04$ & $4.16 \pm 7.69$ & $1.65 \pm 0.98$ & $2.54 \pm 1.98$ & $1.07 \pm 0.86$ &  \\
 & score & $-89.23 \pm 39.80$ & $-82.42 \pm 14.69$ & $-103.06 \pm 45.99$ & $3184 \pm 84$ & $3122 \pm 148$ & $2.44 \pm 1.57$ & $5.02 \pm 10.89$ & $1.62 \pm 0.89$ & $6.92 \pm 5.17$ & $1.01 \pm 0.79$ &  \\
netflix
 & nth & $-23.44 \pm 9.10$ & $-19.67 \pm 2.89$ & $-23.61 \pm 9.25$ & $381 \pm 44$ & $367 \pm 60$ & $0.71 \pm 1.03$ & $4.07 \pm 8.34$ & $0.69 \pm 0.90$ & $0.53 \pm 1.13$ & $0.39 \pm