Evaluating models trained on MultiWOZ v2.2 against the v2.1, v2.2, v2.3 and v2.4 annotations

In [2]:
# reload edited project modules (src.*) automatically before each cell runs
%load_ext autoreload
%autoreload 2

In [3]:
# imports
import re
from copy import deepcopy
from pathlib import Path
from typing import *

import pandas as pd
import srsly
import swifter
import wandb
from datasets import load_from_disk
from hydra.utils import instantiate
from omegaconf import OmegaConf
from tqdm.auto import tqdm

pd.set_option("display.max_colwidth", None)
pd.set_option("display.max_columns", None)
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import ticker
from polyfuzz import PolyFuzz
from polyfuzz.models import TFIDF, EditDistance
from rapidfuzz import fuzz
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from src.data.utilities import (
    check_dict_equal,
    clean_slot_values,
    clean_time,
    complement_labels,
    diff_train,
    extract_domains,
    extract_slots,
    extract_values,
    remove_empty_slots,
)
from src.evaluation import compute_prf, diff, jga, prepare_states_eval, slot_metrics
from src.model import DSTTask

In [17]:
# Collect every per-checkpoint prediction file of experiment 1.
# Directory layout implied by the parents[...] lookups below:
#   <size>/<model>/<?>/<version>/<key>=<epoch>/<split>/*preds.parquet
list_dfs = []
path = Path("../preds")

for exp_path in list(path.rglob("*experiment_1")):
    preds_files = list(exp_path.rglob("*preds.parquet"))
    for p in tqdm(preds_files, desc=exp_path.name):
        split_dir, epoch_dir, version_dir = p.parents[0], p.parents[1], p.parents[2]

        # only the "v0" prediction runs are analysed here
        if version_dir.name != "v0":
            continue

        df = (
            pd.read_parquet(p)
            .assign(
                split=split_dir.name,
                epoch=int(epoch_dir.name.split("=")[1]),
                version=version_dir.name,
                model=p.parents[4].name,
                size=p.parents[5].name,
            )
            # clean predicted and previous states with remove_empty_slots
            # (presumably drops slots that carry no value)
            .assign(
                states=lambda df_: df_["states"].map(remove_empty_slots),
                previous_states=lambda df_: df_["previous_states"].map(remove_empty_slots),
            )
        )
        list_dfs.append(df)

experiment_1:   0%|          | 0/2133 [00:00<?, ?it/s]

In [5]:
# stack all checkpoints and show the number of distinct epochs / dialogues per (model, split, version)
df = pd.concat(list_dfs, ignore_index=True)
df.groupby(["model", "split", "version"])[["epoch", "dialogue_id"]].nunique()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,epoch,dialogue_id
model,split,version,Unnamed: 3_level_1,Unnamed: 4_level_1
mwoz22_ops_nohist+prev_2022-11-24T04-26-31,test,v4,20,1000
mwoz22_ops_nohist+prev_2022-11-24T04-26-31,validation,v4,20,1000
mwoz22_ops_nohist+prev_2022-12-05T14-00-06,test,v4,20,1000
mwoz22_ops_nohist+prev_2022-12-05T14-00-06,validation,v4,20,1000
mwoz22_ops_nohist+prev_2022-12-05T19-16-00,test,v4,20,1000
mwoz22_ops_nohist+prev_2022-12-05T19-16-00,validation,v4,20,1000
mwoz22_ops_nohist+prev_2022-12-06T00-36-24,test,v4,20,1000
mwoz22_ops_nohist+prev_2022-12-06T00-36-24,validation,v4,20,1000
mwoz22_ops_nohist+prev_2022-12-06T06-00-06,test,v4,20,1000
mwoz22_ops_nohist+prev_2022-12-06T06-00-06,validation,v4,20,1000


In [6]:
# convert predicted states with prepare_states_eval (the same transformation applied to the gold states below)
df["states"] = df["states"].swifter.apply(prepare_states_eval)

Pandas Apply:   0%|          | 0/1674600 [00:00<?, ?it/s]

In [7]:
# load gold annotations
list_dfs = []
for v in ("21", "22", "23", "24"):
    dataset_dict = load_from_disk(f"../data/processed/multiwoz_{v}")

    true_df = (
        pd.concat([dataset_dict[split].to_pandas() for split in ("test", "validation")])
        .sort_values(["dialogue_id", "turn_id"])
        .reset_index(drop=True)
        .assign(
            states=lambda df_: df_[f"states"].map(lambda ex: prepare_states_eval(remove_empty_slots(ex))), version=v
        )
    )

    list_dfs.append(true_df)


true_df = pd.concat(list_dfs).reset_index(drop=True)
print(len(true_df))
true_df = true_df.loc[true_df["usr_utt"] != "none", ["dialogue_id", "turn_id", "states", "version"]]
print(len(true_df))
true_df = true_df.set_index(["dialogue_id", "turn_id", "version"]).unstack(-1)
true_df.columns = [f"{i}_{j}" for i, j in true_df.columns]
true_df = true_df.reset_index()

66984
58984


In [8]:
# (pattern, replacement) pairs used to generate surface variants of gold values:
# spaced <-> unspaced compounds, US <-> UK spellings, and removal of "star".
regexs = []
for compound in ["guest house", "swimming pool", "night club", "concert hall"]:
    joined = compound.replace(" ", "")
    regexs.append((re.compile(compound, flags=re.IGNORECASE), joined))
    regexs.append((re.compile(joined, flags=re.IGNORECASE), compound))

for us_spelling, uk_spelling in [("theater", "theatre"), ("center", "centre")]:
    regexs.append((re.compile(us_spelling, flags=re.IGNORECASE), uk_spelling))
    regexs.append((re.compile(uk_spelling, flags=re.IGNORECASE), us_spelling))

regexs.append((re.compile("star", flags=re.IGNORECASE), ""))


def add_variations(state: Union[Dict, None], regexs) -> Union[Dict, None]:
    """Expand every slot value of a state with common surface-form variants.

    Args:
        state: mapping slot -> list of accepted values, or None.
        regexs: iterable of ``(compiled_pattern, replacement)`` pairs; each pair is
            applied to every value and the (stripped) result is kept as a variant.

    Returns:
        None if ``state`` is None, otherwise a copy of ``state`` where each value list
        is replaced by the de-duplicated set of: the original values, their regex
        variants, and the values without a leading "the ". The input is not mutated.
    """
    if state is None:
        return None

    new_state = deepcopy(state)
    for k, v_list in new_state.items():
        variants = set()
        for v in v_list:
            # always keep the original value, even when every pattern matches it
            # (or when `regexs` is empty)
            variants.add(v.strip())
            for pat, sub in regexs:
                variants.add(pat.sub(sub, v).strip())
            # drop a leading article; require the trailing space so that words such as
            # "theatre" are not truncated (str.lstrip("the") strips a *character set*)
            if v.startswith("the "):
                without_article = v[len("the "):].strip()
                if without_article:
                    variants.add(without_article)
        new_state[k] = list(variants)

    return new_state

In [9]:
# gold states enriched with spelling variants: one `new_states_<version>` column per version
for ver in ("21", "22", "23", "24"):
    true_df[f"new_states_{ver}"] = true_df[f"states_{ver}"].map(
        lambda ex: add_variations(ex, regexs)
    )

In [10]:
# add gold annotations
pred_df = pd.merge(
    df,
    true_df,
    on=["dialogue_id", "turn_id"],
    how="inner",
)
# assert len(df) == len(pred_df)

In [11]:
# compute correct predictions
for v in ("21", "22", "23", "24"):
    pred_df[f"jga_turn_{v}"] = pred_df.swifter.apply(
        lambda row: jga(row["states"], row[f"states_{v}"]),
        axis=1,
    )

    pred_df[f"new_jga_turn_{v}"] = pred_df.swifter.apply(
        lambda row: jga(row["states"], row[f"new_states_{v}"]),
        axis=1,
    )

Pandas Apply:   0%|          | 0/1474600 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/1474600 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/1474600 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/1474600 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/1474600 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/1474600 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/1474600 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/1474600 [00:00<?, ?it/s]

In [12]:
# metric columns interleaved per version: jga_turn_21, new_jga_turn_21, jga_turn_22, ...
eval_cols = [f"{prefix}jga_turn_{v}" for v in ("21", "22", "23", "24") for prefix in ("", "new_")]
index_cols = ["model", "version", "split", "size", "epoch"]

In [13]:
# mean per-turn JGA of every checkpoint; split off the models whose name does not contain "mwoz22"
results = pred_df.groupby(index_cols)[eval_cols].mean()
is_mwoz22 = results.index.get_level_values("model").str.contains("mwoz22")
other = results.loc[~is_mwoz22].copy()
results = results.loc[is_mwoz22]

In [14]:
# For each metric, pick per (model, version, size) the epoch with the best validation
# score and collect that checkpoint's scores in every split, in percent.
validation_results = []
for col in eval_cols:
    best_keys = results.groupby(index_cols[:-1])[col].idxmax().reset_index()
    # idxmax yields the full (model, version, split, size, epoch) key; dropping the split
    # lets the validation-selected checkpoint match its rows in every split
    best_ckpts = best_keys.loc[best_keys["split"] == "validation", col].map(
        lambda key: (key[0], key[1], key[3], key[4])
    )
    is_selected = results.index.droplevel("split").isin(best_ckpts)
    selected = results.loc[is_selected, [col]].assign(metric_name=col, metric=lambda df_: df_[col] * 100)
    validation_results.append(selected[["metric_name", "metric"]])

validation_results = pd.concat(validation_results).reset_index()
# experiment id "<state repr>_<context>" from "<dataset>_<state repr>_<context>_<timestamp>"
validation_results["exp"] = validation_results["model"].map(lambda name: "{1}_{2}".format(*name.split("_")))
# distribution over runs of each (version, split, size, exp, metric_name)
summary_keys = index_cols[1:-1] + ["exp", "metric_name"]
validation_results = validation_results.groupby(summary_keys)["metric"].describe().reset_index()

In [16]:
# test-split summary of the "v4" prediction runs, exported as LaTeX
validation_results_gold = validation_results.loc[validation_results["version"].eq("v4")].copy()

gold_test_latex = (
    validation_results_gold.loc[validation_results_gold["split"].eq("test")]
    .drop(columns=["version", "size", "split", "exp", "count"])
    .round(2)
    .to_latex(index=False)
)
print(gold_test_latex)

\begin{tabular}{lrrrrrrr}
\toprule
    metric\_name &  mean &  std &   min &   25\% &   50\% &   75\% &   max \\
\midrule
    jga\_turn\_21 & 53.02 & 0.20 & 52.69 & 53.02 & 53.02 & 53.15 & 53.21 \\
    jga\_turn\_22 & 56.02 & 0.15 & 55.78 & 55.98 & 56.09 & 56.12 & 56.13 \\
    jga\_turn\_23 & 50.74 & 0.19 & 50.42 & 50.72 & 50.81 & 50.83 & 50.90 \\
    jga\_turn\_24 & 86.26 & 0.34 & 85.85 & 86.03 & 86.25 & 86.44 & 86.73 \\
new\_jga\_turn\_21 & 57.59 & 0.10 & 57.49 & 57.54 & 57.58 & 57.58 & 57.75 \\
new\_jga\_turn\_22 & 60.71 & 0.14 & 60.47 & 60.72 & 60.73 & 60.78 & 60.85 \\
new\_jga\_turn\_23 & 56.20 & 0.13 & 56.05 & 56.09 & 56.21 & 56.31 & 56.36 \\
new\_jga\_turn\_24 & 88.29 & 0.35 & 87.85 & 88.10 & 88.24 & 88.52 & 88.73 \\
\bottomrule
\end{tabular}



In [52]:
# discard the "v2" prediction runs
validation_results = validation_results[validation_results["version"].ne("v2")]

In [53]:
# Test-split summary table of the base-size models, with human-readable labels for the
# state representation, the dialogue context and the dataset version (used as index).
all_results = (
    validation_results.loc[(validation_results["split"] == "test") & (validation_results["size"] == "base")]
    .drop(columns=["version", "split", "count"])
    .copy()
    # "exp" is "<state repr>_<context>", e.g. "cum_fullhist+nostate"
    .assign(
        state_repr=lambda df_: df_["exp"]
        .str.split("_", expand=True)[0]
        .map({"cum": "Cumulative", "ops": "Operations"}),
        context=lambda df_: df_["exp"]
        .str.split("_", expand=True)[1]
        .map(
            {
                "fullhist+nostate": "Full-history",
                "fullhist+prev": "Full-history + State",
                "nohist+prev": "State",
                "partialhist+prev": "4 Turns + State",
            }
        ),
        # "new_*" metrics are scored against the gold enriched with spelling variants
        metric_name=lambda df_: df_["metric_name"].map(
            {
                "jga_turn_21": "2.1",
                "new_jga_turn_21": "2.1 (fix labels)",
                "jga_turn_22": "2.2",
                "new_jga_turn_22": "2.2 (fix labels)",
                "jga_turn_23": "2.3",
                "new_jga_turn_23": "2.3 (fix labels)",
                "jga_turn_24": "2.4",
                "new_jga_turn_24": "2.4 (fix labels)",
            }
        ),
    )[["state_repr", "context", "metric_name", "mean", "std", "min", "25%", "50%", "75%", "max"]]
    # NOTE(review): "size" is not among the selected columns, so its rename below is a no-op
    .rename(
        columns={
            "state_repr": "State representation",
            "context": "Context",
            "size": "Model size",
            "metric_name": "Dataset version",
        }
    )
    .set_index(["State representation", "Context", "Dataset version"])
    .round(2)
)

In [54]:
# full table
# LaTeX of the full table; iloc[:, 1:] drops the first column ("State representation") after reset_index
print(all_results.reset_index().iloc[:, 1:].to_latex(index=False))

\begin{tabular}{llrrrrrrr}
\toprule
             Context &  Dataset version &  mean &  std &   min &   25\% &   50\% &   75\% &   max \\
\midrule
        Full-history &              2.1 & 49.47 & 0.22 & 49.12 & 49.39 & 49.55 & 49.63 & 49.66 \\
        Full-history &              2.2 & 56.72 & 0.49 & 56.17 & 56.38 & 56.61 & 57.18 & 57.28 \\
        Full-history &              2.3 & 47.67 & 0.59 & 47.10 & 47.19 & 47.48 & 48.14 & 48.44 \\
        Full-history &              2.4 & 56.89 & 0.51 & 56.16 & 56.65 & 57.04 & 57.08 & 57.50 \\
        Full-history & 2.1 (fix labels) & 50.69 & 0.37 & 50.45 & 50.47 & 50.60 & 50.61 & 51.34 \\
        Full-history & 2.2 (fix labels) & 57.01 & 0.45 & 56.43 & 56.81 & 56.89 & 57.42 & 57.51 \\
        Full-history & 2.3 (fix labels) & 49.51 & 0.63 & 48.74 & 48.94 & 49.81 & 49.85 & 50.19 \\
        Full-history & 2.4 (fix labels) & 63.11 & 0.83 & 61.98 & 62.89 & 63.13 & 63.27 & 64.30 \\
Full-history + State &              2.1 & 49.50 & 0.50 & 48.91 & 49.12

In [55]:
# rows scored against the fixed labels, with the "(fix labels)" suffix removed from the version name
exp1_table = all_results.reset_index()
exp1_table = exp1_table.loc[exp1_table["Dataset version"].str.contains("fix")].assign(
    **{
        # str.rstrip("(fix labels)") strips any trailing run of those *characters*
        # (e.g. a trailing "a", "s" or "1"); remove the literal suffix instead
        "Dataset version": lambda df_: df_["Dataset version"]
        .str.replace("(fix labels)", "", regex=False)
        .str.strip(),
    }
)
exp1_table

Unnamed: 0,State representation,Context,Dataset version,mean,std,min,25%,50%,75%,max
4,Cumulative,Full-history,2.1,50.69,0.37,50.45,50.47,50.6,50.61,51.34
5,Cumulative,Full-history,2.2,57.01,0.45,56.43,56.81,56.89,57.42,57.51
6,Cumulative,Full-history,2.3,49.51,0.63,48.74,48.94,49.81,49.85,50.19
7,Cumulative,Full-history,2.4,63.11,0.83,61.98,62.89,63.13,63.27,64.3
12,Cumulative,Full-history + State,2.1,50.71,0.55,50.04,50.47,50.57,50.98,51.51
13,Cumulative,Full-history + State,2.2,56.89,0.57,56.19,56.65,56.67,57.33,57.6
14,Cumulative,Full-history + State,2.3,49.45,0.54,48.81,49.08,49.5,49.67,50.2
15,Cumulative,Full-history + State,2.4,63.27,0.68,62.11,63.27,63.48,63.74,63.75
20,Cumulative,State,2.1,50.9,0.11,50.77,50.83,50.91,50.95,51.06
21,Cumulative,State,2.2,56.5,0.47,55.93,56.13,56.57,56.81,57.08


---

In [49]:
# mean JGA of every large-size experiment on the test split, as LaTeX
print(
    validation_results.loc[
        (validation_results["size"] == "large") & (validation_results["split"] == "test"),
        ["exp", "metric_name", "mean"],
    ].to_latex(index=False)
)

\begin{tabular}{llr}
\toprule
                 exp &     metric\_name &      mean \\
\midrule
cum\_fullhist+nostate &     jga\_turn\_21 & 51.058058 \\
cum\_fullhist+nostate &     jga\_turn\_22 & 57.297884 \\
cum\_fullhist+nostate &     jga\_turn\_23 & 48.209441 \\
cum\_fullhist+nostate &     jga\_turn\_24 & 58.518719 \\
cum\_fullhist+nostate & new\_jga\_turn\_21 & 51.939772 \\
cum\_fullhist+nostate & new\_jga\_turn\_22 & 57.365708 \\
cum\_fullhist+nostate & new\_jga\_turn\_23 & 49.769398 \\
cum\_fullhist+nostate & new\_jga\_turn\_24 & 65.816603 \\
ops\_partialhist+prev &     jga\_turn\_21 & 49.864352 \\
ops\_partialhist+prev &     jga\_turn\_22 & 56.429734 \\
ops\_partialhist+prev &     jga\_turn\_23 & 47.843190 \\
ops\_partialhist+prev &     jga\_turn\_24 & 58.003256 \\
ops\_partialhist+prev & new\_jga\_turn\_21 & 50.922409 \\
ops\_partialhist+prev & new\_jga\_turn\_22 & 56.741725 \\
ops\_partialhist+prev & new\_jga\_turn\_23 & 49.701574 \\
ops\_partialhist+prev & new\_jga\_turn\_24 &

In [59]:
# domains and slots extracted from each predicted state
for target_col, extractor in (("domain", extract_domains), ("slots", extract_slots)):
    pred_df[target_col] = pred_df["states"].swifter.apply(extractor)

Pandas Apply:   0%|          | 0/14251996 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/14251996 [00:00<?, ?it/s]

In [79]:
# Validation-selected checkpoints (same selection as above), restricted to base-size
# "v0" runs on the test split, each tagged with a "<model>+<version>+<epoch>" id.
val = []
for col in eval_cols:
    best_keys = results.groupby(index_cols[:-1])[col].idxmax().reset_index()
    # drop the split from the idxmax key: (model, version, size, epoch)
    best_ckpts = best_keys.loc[best_keys["split"] == "validation", col].map(
        lambda key: (key[0], key[1], key[3], key[4])
    )
    is_selected = results.index.droplevel("split").isin(best_ckpts)
    selected = results.loc[is_selected, [col]].assign(metric_name=col, metric=lambda df_: df_[col] * 100)
    val.append(selected[["metric_name", "metric"]])

val = pd.concat(val).reset_index()
val["exp"] = val["model"].map(lambda name: "{1}_{2}".format(*name.split("_")))
keep = (val["split"] == "test") & (val["size"] == "base") & (val["version"] == "v0")
val = val.loc[keep]
val["id"] = val["model"] + "+" + val["version"] + "+" + val["epoch"].astype(str)

In [78]:
# checkpoint id "<model>+<version>+<epoch>", same format as val["id"]
pred_df["id"] = pred_df["model"] + "+" + pred_df["version"] + "+" + pred_df["epoch"].astype(str)

In [80]:
# prediction rows of the validation-selected checkpoints only
best = pred_df.loc[pred_df["id"].isin(val["id"])]

In [82]:
# one row per (turn, domain)
domain = best.explode("domain")

In [120]:
# mean JGA per (checkpoint, domain) on the test split, then summary statistics over the runs of each experiment
per_domain = domain.groupby(index_cols + ["domain"])[eval_cols].mean()
per_domain = per_domain[per_domain.index.get_level_values("split") == "test"].reset_index()
per_domain["exp"] = per_domain["model"].map(lambda name: "{1}_{2}".format(*name.split("_")))
results_domain = (
    per_domain.groupby(["exp", "domain"])[eval_cols]
    .describe()
    .stack(0)
    .reset_index()
    .rename(columns={"level_2": "metric_name"})
)

In [130]:
# per-domain summaries against the fixed-label 2.2 and 2.4 gold
fixed_metrics = ["new_jga_turn_22", "new_jga_turn_24"]
stat_cols = ["mean", "std", "min", "25%", "50%", "75%", "max"]
results_domain[results_domain["metric_name"].isin(fixed_metrics)].set_index(["exp", "domain", "metric_name"])[stat_cols]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,mean,std,min,25%,50%,75%,max
exp,domain,metric_name,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
cum_fullhist+nostate,attraction,new_jga_turn_22,0.504243,0.007214,0.494210,0.498604,0.502585,0.508680,0.518152
cum_fullhist+nostate,attraction,new_jga_turn_24,0.582955,0.023106,0.535980,0.573464,0.584605,0.597156,0.618280
cum_fullhist+nostate,hotel,new_jga_turn_22,0.398062,0.009949,0.382136,0.389927,0.399691,0.402477,0.415277
cum_fullhist+nostate,hotel,new_jga_turn_24,0.437193,0.025453,0.393704,0.425919,0.441023,0.450499,0.483359
cum_fullhist+nostate,restaurant,new_jga_turn_22,0.543266,0.006537,0.530856,0.539142,0.542748,0.547721,0.552715
...,...,...,...,...,...,...,...,...,...
ops_partialhist+prev,restaurant,new_jga_turn_24,0.626825,0.014651,0.598258,0.619059,0.630791,0.633845,0.648310
ops_partialhist+prev,taxi,new_jga_turn_22,0.425136,0.014837,0.387247,0.419382,0.427907,0.434867,0.444618
ops_partialhist+prev,taxi,new_jga_turn_24,0.399982,0.025611,0.337481,0.388146,0.407752,0.421618,0.428351
ops_partialhist+prev,train,new_jga_turn_22,0.540713,0.009756,0.522136,0.535184,0.540989,0.545393,0.555217


In [108]:
# mean/std per domain for two experiments.
# NOTE(review): metric_name is not selected, so rows of all metrics are interleaved here
results_domain.loc[
    results_domain["exp"].isin(["cum_fullhist+nostate", "ops_nohist+prev"]), ["exp", "domain", "mean", "std"]
]

Unnamed: 0,exp,domain,mean,std
0,cum_fullhist+nostate,attraction,0.413930,0.007095
1,cum_fullhist+nostate,attraction,0.500851,0.007243
2,cum_fullhist+nostate,attraction,0.370806,0.009352
3,cum_fullhist+nostate,attraction,0.530398,0.018316
4,cum_fullhist+nostate,attraction,0.438027,0.007722
...,...,...,...,...
275,ops_nohist+prev,train,0.614188,0.012501
276,ops_nohist+prev,train,0.467423,0.006496
277,ops_nohist+prev,train,0.538241,0.007788
278,ops_nohist+prev,train,0.453017,0.005484


---
### Run-time

In [None]:
# state representation and dialogue context taken from "<dataset>_<state repr>_<context>_<timestamp>"
name_parts = pred_df["model"].str.split("_", expand=True)
pred_df["state_repr"] = name_parts[1].str.strip()
pred_df["context"] = name_parts[2].str.strip()
del name_parts

In [None]:
# models with a directory under experiment_1/large are "large", all others "base"
large = [entry.name for entry in Path("../preds/experiment_1/large/").iterdir()]
pred_df["size"] = pred_df["model"].isin(large).map({True: "large", False: "base"})

In [None]:
# per-instance runtime; assumes `runtime` is measured per batch — TODO confirm
pred_df["runtime_instance"] = pred_df["runtime"] / pred_df["batch_size"]

In [None]:
# median and std of the per-instance runtime for each base-size (state_repr, context)
t = pred_df.loc[pred_df["size"] == "base"]
rt = t.groupby(["state_repr", "context"])["runtime_instance"].agg(["median", "std"])

In [None]:
# runtime relative to the fastest configuration
rt["relative"] = rt["median"] / rt["median"].min()

In [None]:
# runtime table as LaTeX: median per-instance runtime (scaled by 100) and ratio to the fastest setup
state_repr_names = {"cum": "Cumulative", "ops": "State operations"}
runtime_latex = (
    rt[["median", "relative"]]
    .reset_index()
    .assign(
        state_repr=lambda df_: df_["state_repr"].map(state_repr_names),
        median=lambda df_: df_["median"] * 100,
    )
    .round(2)
    .to_latex(index=False)
)
print(runtime_latex)

---

### Model selection and results

In [None]:
# metrics used for model selection, and the columns that identify one checkpoint
eval_cols = ["jga_turn_22", "jga_turn_24"]
index_cols = ["model", "epoch", "version", "size"]

In [None]:
# mean validation JGA of every checkpoint
val_scores = (
    pred_df.loc[pred_df["split"] == "validation", eval_cols + index_cols]
    .groupby(index_cols)[eval_cols]
    .mean()
    .reset_index()
)

# best epoch per (model, version) for each metric -> columns: model, version, metric, epoch
val_df = (
    val_scores.groupby(["model", "version"])
    .apply(lambda group: group.set_index(["epoch"])[eval_cols].idxmax())
    .reset_index()
    .melt(id_vars=["model", "version"], var_name="metric", value_name="epoch")
)

In [None]:
# mean test JGA of every checkpoint in long format: index_cols + metric + value
is_test = pred_df["split"] == "test"
test_df = (
    pred_df.loc[is_test, eval_cols + index_cols]
    .groupby(index_cols)[eval_cols]
    .mean()
    .reset_index()
    .melt(id_vars=index_cols, var_name="metric")
)

In [None]:
# test score of the checkpoint selected on validation (one per model, version, metric)
best_test_val = test_df.merge(
    val_df,
    on=["model", "version", "epoch", "metric"],
    how="inner",
).set_index(["model", "version", "epoch", "metric", "size"])

# oracle: best test score over all epochs
best_test = test_df.groupby(["model", "version", "metric", "size"]).max().drop(columns=["epoch"])

In [None]:
# selected-on-validation vs oracle test scores, side by side
results = best_test_val.join(best_test, lsuffix="_val", rsuffix="_oracle").reset_index()
# averaging below would mix data versions if more than one were present
assert results["version"].nunique() == 1
del results["version"]

# model names look like "<dataset>_<state repr>_<context>_<timestamp>"
_, results["state_repr"], results["context"], _ = zip(*results["model"].str.split("_"))

# mean/std over runs, in percent. Select the columns with a *list*: tuple selection
# (`gb["a", "b"]`) is no longer supported by pandas >= 2.0.
results = (
    results.groupby(["state_repr", "context", "size", "metric"])[["value_val", "value_oracle"]].agg(["mean", "std"])
    * 100
).reset_index()

# flatten the (column, statistic) MultiIndex; ("state_repr", "") -> "state_repr"
results.columns = [f"{i}-{j}".rstrip("-") for i, j in results.columns]

In [None]:
# selected-on-validation vs oracle test JGA per configuration (mean/std over models, in %)
results

In [None]:
# Attach the per-instance runtime to the results. `runtime` was never defined in this
# notebook; build it from `rt` (Run-time section), which is indexed by (state_repr, context)
# and covers base-size models only, so large-model rows get NaN instead of base timings.
# NOTE(review): confirm `rt` is the intended runtime source.
runtime = rt.reset_index().assign(size="base")
pd.merge(results, runtime, on=["state_repr", "context", "size"], how="left")


# .to_latex(float_format="{:0.2f}".format)

In [None]:
# display `results` again
results

In [None]:
# split e.g. "jga_turn_22" into the short metric name "jga_22" and the data version "22"
metric_parts = results.reset_index()["metric"].str.split("_", expand=True).fillna("")
data_version = metric_parts[2]
results["metric"] = (metric_parts[0] + "_" + data_version).str.strip("_")
results["data_version"] = data_version
del metric_parts, data_version

In [None]:
# NOTE(review): this cell expects a `results` frame with "model", "version", "epoch" and
# "value_val" columns; the model-selection cell above deletes "version" and aggregates "model"
# away, so with the current `results` this raises KeyError — TODO re-wire to the right frame.
# one row per (model, version), one "<column>-<metric>" column per value/metric pair
r = results.drop(columns=["data_version"]).set_index(["model", "version", "metric"]).unstack(-1)
r.columns = [f"{i}-{j}" for i, j in r.columns]

# "mean pm std" strings (in %) per experiment and column
table = (
    r.reset_index()
    .drop(columns=["epoch-jga_22", "epoch-jga_24"])
    .assign(exp=lambda df_: df_["model"].str.split("_").map(lambda ex: f"{ex[1]}_{ex[2]}"))
    .groupby("exp")
    .describe()
    .stack(0)
    .assign(
        mean=lambda df_: df_["mean"] * 100,
        std=lambda df_: df_["std"] * 100,
        metric=lambda df_: df_.apply(lambda row: f"{row['mean']:,.2f} pm {row['std']:,.2f}", axis=1),
    )["metric"]
    .unstack(-1)
    .reset_index()
    # positional column reorder — depends on the exact column order produced above
    .iloc[:, [0, 3, 1, 4, 2]]
)

table["state_repr"], table["dialogue_context"] = zip(*table["exp"].str.split("_"))
del table["exp"]

# readable labels; keep only the validation-selected scores for 2.2 and 2.4
table = table.assign(
    state_repr=lambda df_: df_["state_repr"].map({"cum": "Cumulative", "ops": "State operations"}),
    dialogue_context=lambda df_: df_["dialogue_context"].map(
        {
            "fullhist+nostate": "Full history",
            "fullhist+prev": "Full history and Previous state",
            "partialhist+prev": "Last 4 turns and Previous state",
            "nohist+prev": "Previous state",
        }
    ),
).set_index(["state_repr", "dialogue_context"])[["value_val-jga_22", "value_val-jga_24"]]
table.columns = ["2.2", "2.4"]

In [None]:
# export the selection table as LaTeX
print(table.to_latex())

In [None]:
# experiment id "<state repr>_<context>" from the model name
# NOTE(review): requires a "model" column, which the aggregated `results` no longer has
t = results["model"].str.split("_", expand=True)
results["mode"] = t[1] + "_" + t[2]
del t

In [None]:
# distribution of selected and oracle scores per experiment and metric
results.groupby(["mode", "metric"])[["value_val", "value_oracle"]].describe()

---
### JGA per dialogue

In [None]:
# NOTE(review): the `*_clean` / `*_clean_complement` JGA columns and the `domains_<v>` /
# `dialogue_context_repr` columns are not created in the visible cells — TODO confirm where they come from.
# This cell also overwrites the global `eval_cols` used by earlier sections.
eval_cols = [
    "jga_turn_21",
    "jga_turn_21_clean",
    "jga_turn_21_clean_complement",
    "jga_turn_22",
    "jga_turn_22_clean",
    "jga_turn_23",
    "jga_turn_23_clean",
    "jga_turn_24",
    "jga_turn_24_clean",
    "jga_turn_23_clean_complement",
    "jga_turn_24_clean_complement",
]

cols = ["state_repr", "dialogue_context_repr", "domains_21", "domains_22", "domains_23", "domains_24"]

In [None]:
# mean JGA of each dialogue for every checkpoint
index_cols = [
    "dialogue_id",
    "model",
    "epoch",
    "split",
    "version",
]
diag_df = pred_df.groupby(index_cols)[eval_cols].mean()

In [None]:
# average over dialogues per checkpoint, then the best epoch per (model, split, version)
diag_df.groupby(index_cols[1:]).mean().groupby(["model", "split", "version"]).max()

In [None]:
# Mean JGA per (checkpoint, domain) for each gold version; `results` becomes a list of frames.
# NOTE(review): overwrites the globals `results` and `eval_cols`.
results = []
for v in tqdm(("21", "22", "23", "24")):
    eval_cols = [f"jga_turn_{v}", f"jga_turn_{v}_clean"]
    # no `_clean_complement` column is used for 2.2
    if v != "22":
        eval_cols += [f"jga_turn_{v}_clean_complement"]
    tmp = pred_df[["model", "split", "version", "epoch", f"domains_{v}"] + eval_cols]
    # one row per (turn, domain)
    tmp = tmp.explode(f"domains_{v}")
    tmp = tmp.groupby(["model", "split", "version", "epoch", f"domains_{v}"])[eval_cols].mean().reset_index()
    results.append(tmp)

In [None]:
# per-domain scores against the 2.1 gold (first frame from the previous cell)
a = results[0]

In [None]:
# average over domains for each checkpoint
a.set_index(["model", "split", "version", "epoch", "domains_21"]).unstack(-1).mean(1)

In [None]:
# states and predictions of one model, one column per prediction version
a = df.loc[
    df["model"] == "mwoz22_cum_fullhist+prev_2022-11-23T22-58-34",
    ["dialogue_id", "turn_id", "states", "predictions", "version", "epoch"],
]
a = a.set_index(["dialogue_id", "turn_id", "epoch", "version"]).unstack(-1).dropna()

In [None]:
# assumes exactly two prediction versions: (states, v_a), (states, v_b), (predictions, v_a), (predictions, v_b)
a.columns = ["s1", "s2", "p1", "p2"]

In [None]:
# flag turns whose states differ between the two versions (per diff_train)
a["diff"] = a.apply(lambda row: diff_train(row["s1"], row["s2"]), axis=1)
a["check"] = a["diff"].map(len) > 0

In [None]:
a = a.reset_index()

In [None]:
# inspect the differing turns of one dialogue at epoch 0
a.loc[(a["check"] == True) & (a["dialogue_id"] == "MUL0003.json") & (a["epoch"] == 0)]

In [None]:
# NOTE(review): "predictions" was renamed to p1/p2 above, so this raises KeyError on the current `a`
b = a.groupby(["dialogue_id", "turn_id", "epoch"])["predictions"].nunique()
b = b.loc[b > 1]

In [None]:
b.reset_index()

In [None]:
# disabled: `normalize_labels` is not imported in this notebook
# true_df["states_24_normalized"] = true_df.swifter.apply(lambda row: normalize_labels(row["states_22"], row["states_24"]), axis=1)

In [None]:
# every slot name occurring in the 2.2 gold states
slots = list({slot for state_slots in true_df["states_22"].map(extract_slots) for slot in state_slots})

In [None]:
def _unique_values(states, normalize=True):
    """Distinct slot values over a column of states (cleaned with clean_slot_values if ``normalize``)."""
    if normalize:
        extract = lambda ex: extract_values(clean_slot_values(ex))  # noqa: E731
    else:
        extract = extract_values
    return list(set([value for state_values in states.map(extract).tolist() for value in state_values]))


values_21, values_22, values_23, values_24 = (
    _unique_values(true_df[f"states_{ver}"]) for ver in ("21", "22", "23", "24")
)
values_pred = _unique_values(df["states"], normalize=False)

In [None]:
# fuzzy-match every 2.2 gold value against the 2.1 values with a TF-IDF and an edit-distance matcher
model = PolyFuzz([TFIDF(model_id="TFIDF"), EditDistance(model_id="EDIT")])
model.match(values_22, values_21)
# model.group()

# one row per (From, To) pair, one similarity column per matcher
d = (
    pd.concat([v.assign(dist=k) for k, v in model.get_matches().items()])
    .set_index(["From", "To", "dist"])
    .unstack(-1)
    .reset_index()
)
# assumes the unstacked matcher columns come out ordered EDIT, TFIDF
d.columns = ["From", "To", "edit", "tfidf"]
d = d.dropna(subset=["From", "To"])

# raw strings: "\d" in a plain literal is an invalid escape (SyntaxWarning on Python >= 3.12)
d["time"] = d["From"].str.contains(r"\d{2,}", regex=True) | d["To"].str.contains(r"\d{2,}", regex=True)
d["exact"] = d["From"] == d["To"]
# FIXME(review): `clean_text` is not imported anywhere in this notebook -> NameError; import it or drop this column
d["clean"] = d["From"].map(clean_text)
d["token_set"] = d.apply(lambda row: fuzz.partial_ratio(row["From"], row["To"]), axis=1) / 100

# keep similar-but-not-identical, non-numeric pairs
d = d.loc[
    (d["edit"] > 0)
    & (d["tfidf"] > 0)
    & ~d["time"]
    & ~d["exact"]
    # & (d["edit"] < 1)
    # & (d["tfidf"] < 1)
]

In [None]:
# pairs that both matchers score below 0.5
d.loc[(d["tfidf"] < 0.5) & (d["edit"] < 0.5)]

In [None]:
# values of each 2.2 gold state
true_df["values"] = true_df["states_22"].map(extract_values)

In [None]:
# turns whose 2.2 gold values contain exactly this value
s = "cotto?also,"
true_df.loc[true_df["values"].map(lambda ex: s in ex)]

In [None]:
# turns whose 2.2 system or user utterance mentions the value
# (the second operand used to repeat the system-utterance test)
# NOTE(review): `true_df` as built above keeps only state columns — keep sys_utt/usr_utt there for this to run
s = "holiday inn"
true_df.loc[(true_df["sys_utt_22"].str.contains(s)) | (true_df["usr_utt_22"].str.contains(s))]

In [None]:
# all predictions for one dialogue
df.loc[df["dialogue_id"] == "SNG01735.json"]

In [None]:
# NOTE(review): a "Group" column is presumably only available after model.group() (commented out above) — confirm
model.get_matches("TFIDF").set_index("Group")

In [None]:
# the distinct 2.2 gold values as a one-column frame
l = pd.DataFrame(values_22, columns=["labels"])

In [None]:
# distinct 2.2 labels containing a number with at least two digits (raw string avoids the invalid "\d" escape)
l.loc[l["labels"].str.contains(r"\d{2,}", regex=True), "labels"].unique()

In [None]:
# matched pairs involving numbers (times, prices, ...)
t = d.loc[d["time"] == True].copy()

In [None]:
# apply clean_time to the "From" value for comparison
t["from"] = t["From"].map(clean_time)
t

In [None]:
# match every 2.2 gold value against the predicted values with both matchers
model = PolyFuzz([TFIDF(model_id="TFIDF"), EditDistance(model_id="EDIT")])
model.match(values_22, values_pred)

dd = pd.concat([v.assign(dist=k) for k, v in model.get_matches().items()]).set_index(["From", "To", "dist"]).unstack(-1)
# assumes the unstacked matcher columns come out ordered EDIT, TFIDF
dd.columns = ["edit", "tfidf"]
dd = dd.reset_index()

# flag pairs where either side looks like a clock time (h:mm / hh:mm); raw string avoids the invalid "\d" escape
time_pattern = r"\d{1,2}:\d{1,2}"
dd["time"] = dd["From"].str.contains(time_pattern, regex=True) | dd["To"].str.contains(time_pattern, regex=True)

In [None]:
# 2.1-vs-2.2 matches next to 2.2-vs-prediction matches; the shared edit/tfidf/time columns get _x/_y suffixes
pd.merge(d, dd, on=["From", "To"], how="left")

In [None]:
# prediction matches involving clock times
dd.loc[dd["time"] == True]

In [None]:
# numeric pairs; take a copy because the next cell adds a column to `t` (avoids SettingWithCopyWarning)
t = d.loc[d["time"]].copy()

In [None]:
# apply clean_time to the "From" value for comparison
t["from"] = t["From"].map(clean_time)

In [None]:
t

In [None]:
# pairs that at least one matcher does not score as identical
d.loc[(d["edit"] < 1) | (d["tfidf"] < 1)]

In [None]:
# NOTE(review): `d` has no "EDIT" column (renamed to "edit" when `d` was built) -> KeyError; TODO fix or remove
d.loc[d["EDIT"]]

In [None]:
# disabled: `serializers` / `pred_df_all` are not defined in this notebook
# s = serializers["mwoz22_ops_2022-11-17T13-58-22"]
# pred_df_all["rec"] = pred_df_all.swifter.apply(lambda row: s.deserialize(row["predictions"], row["previous_states"])[1], axis=1)

In [None]:
# add gold annotations
# attach the gold annotations again; the inner join must not change the number of rows
pred_df = df.merge(true_df, on=["dialogue_id", "turn_id"], how="inner")
assert len(df) == len(pred_df)

In [None]:
# NOTE(review): only checkpoint epoch 11 is analysed below — magic number, TODO document the choice
pred_df = pred_df[pred_df["epoch"].eq(11)]

In [None]:
# `<col>_clean` copies of predicted and gold states, passed through clean_slot_values
for state_col in ("states", "states_22", "states_24"):
    pred_df[state_col + "_clean"] = pred_df[state_col].swifter.apply(clean_slot_values)

In [None]:
# compute correct predictions
for version in (2, 4):
    # for version in (1, 2, 3, 4):
    pred_df[f"correct_2{version}"] = pred_df.swifter.apply(
        # lambda row: jga(row["rec"], row[f"states_2{version}"]),
        lambda row: jga(row["states"], row[f"states_2{version}"]),
        axis=1,
    )
    pred_df[f"correct_2{version}_clean"] = pred_df.swifter.apply(
        lambda row: jga(row["states_clean"], row[f"states_2{version}_clean"]),
        axis=1,
    )

In [None]:
def _complement_24(row):
    """Cleaned 2.4 gold state combined with the cleaned 2.2 gold via complement_labels."""
    return complement_labels(row["states_22_clean"], row["states_24_clean"])


pred_df["states_24_complement"] = pred_df.swifter.apply(_complement_24, axis=1)

In [None]:
# JGA of the cleaned predictions against the complemented 2.4 gold (only 2.4 has a complement column).
# Set the version explicitly instead of relying on the loop variable leaked from an earlier cell.
version = 4
pred_df[f"correct_2{version}_clean_complement"] = pred_df.swifter.apply(
    lambda row: jga(row["states_clean"], row[f"states_2{version}_complement"]),
    axis=1,
)

In [None]:
# mean JGA per checkpoint and split under each scoring variant
pred_df.groupby(["model", "epoch", "version", "split"])[
    ["correct_22", "correct_22_clean", "correct_24", "correct_24_clean", "correct_24_clean_complement"]
].mean()

In [None]:
pred_df

In [None]:
# disabled: slots are not needed for the metrics below
# pred_df["slots"] = pred_df["states"].swifter.apply(extract_slots)

In [None]:
# compute correct predictions
for version in (2, 4):
    pred_df[f"states_metrics_2{version}"] = pred_df.swifter.apply(
        # lambda row: compute_prf(row["rec"], row[f"states_2{version}"], num_slots=30),
        lambda row: compute_prf(row["states"], row[f"states_2{version}"], num_slots=30),
        axis=1,
    )

In [None]:
# unpack the (f1, recall, precision, accuracy, relative_accuracy) tuples into one column each
for version in (2, 4):
    per_turn = list(pred_df[f"states_metrics_2{version}"].values)
    f1, recall, precision, accuracy, relative_accuracy = zip(*per_turn)
    for metric_name, metric_values in (
        ("f1", f1),
        ("recall", recall),
        ("precision", precision),
        ("accuracy", accuracy),
        ("relative_accuracy", relative_accuracy),
    ):
        pred_df[f"{metric_name}_2{version}"] = metric_values

In [None]:
jga_cols = ["correct_22", "correct_24"]
# f1_22, f1_24, recall_22, recall_24, ...
metrics = [
    f"{metric_name}_2{ver}"
    for metric_name in ["f1", "recall", "precision", "accuracy", "relative_accuracy"]
    for ver in [2, 4]
]
# mean of every metric per checkpoint and split
results = (
    pred_df.groupby(["split", "epoch", "model", "version"])[jga_cols + metrics]
    .mean()
    .reset_index()
)

In [None]:
def _scores_for(split_name):
    """Rows of `results` for one split, without the split column."""
    return results.loc[results["split"] == split_name].drop(columns=["split"])


# validation scores indexed by checkpoint; test scores kept flat for filtering
results_val = _scores_for("validation").set_index(["model", "epoch", "version"]).copy()
results_test = _scores_for("test").copy()

In [None]:
# best validation checkpoint (model, epoch, version) per model/version for each JGA column,
# in long format: data_version = JGA column name, value = checkpoint key
res = results_val.groupby(["model", "version"])[jga_cols].idxmax()
res = res.melt(var_name="data_version")
# res["data_version"] = res["data_version"].str.split("_", expand=True)[1]

In [None]:
res

In [None]:
# test scores of each validation-selected checkpoint, with the selection metric renamed to "jga"
list_dfs = []
for data_version, (model, epoch, version) in res.values:
    is_ckpt = (
        (results_test["epoch"] == epoch) & (results_test["model"] == model) & (results_test["version"] == version)
    )
    t = results_test.loc[is_ckpt, ["epoch", "model", "version", data_version] + metrics]
    t = t.rename(columns={data_version: "jga"})
    # "correct_22" -> "22"
    t["data_version"] = data_version.split("_")[1]
    list_dfs.append(t)

In [None]:
# test scores of all validation-selected checkpoints
pd.concat(list_dfs)

---
### Errors of the Ops models

In [None]:
# columns kept for the error analysis: checkpoint/turn identifiers, then texts, states and scores
id_cols = ["model", "epoch", "version", "split", "dialogue_id", "turn_id"]
cols = id_cols + [
    "sys_utt_24",
    "usr_utt_24",
    "previous_states",
    "predictions",
    "states_clean",
    "states_24_clean_complement",
    "jga_turn_24_clean_complement",
    "dialogue_context_repr",
]

In [None]:
# rows of the models whose name contains "ops", restricted to the analysis columns
is_ops_model = pred_df["model"].str.contains("ops")
ops_df = pred_df.loc[is_ops_model, cols].copy()

In [None]:
# predicted-vs-gold state diff for the turns scored as wrong; None for every other turn
is_wrong = ops_df["jga_turn_24_clean_complement"] == False
err_df = ops_df.loc[is_wrong]
diffs = err_df.swifter.apply(lambda row: diff(row["states_clean"], row["states_24_clean_complement"]), axis=1)

ops_df["diffs"] = None
ops_df.loc[is_wrong, "diffs"] = diffs
del err_df, diffs, is_wrong

In [None]:
def _pred_minus_true(ex):
    """Items of the diff's "pred-true" side as a tuple, or None when absent."""
    if ex is None or "pred-true" not in ex:
        return None
    return tuple(ex["pred-true"].items())


def _true_minus_pred(ex):
    """Flattened (slot, value) pairs of the diff's "true-pred" side, or None when absent."""
    if ex is None or "true-pred" not in ex:
        return None
    return tuple((slot, value) for slot, values in ex["true-pred"].items() for value in values)


ops_df["pred_true"] = ops_df["diffs"].swifter.apply(_pred_minus_true)
ops_df["true_pred"] = ops_df["diffs"].swifter.apply(_true_minus_pred)

In [None]:
def wrongly_predicted(ex):
    """Slots present on both sides of a state diff, i.e. predicted with a wrong value.

    Args:
        ex: diff dict with "pred-true" and "true-pred" mappings keyed by slot, or None.

    Returns:
        None if ``ex`` is None or lacks either side; otherwise a tuple of
        ``(slot, predicted_entry, tuple(gold_entries))`` in "pred-true" order.
    """
    if ex is None or "pred-true" not in ex or "true-pred" not in ex:
        return None
    predicted, gold = ex["pred-true"], ex["true-pred"]
    return tuple((slot, predicted[slot], tuple(gold[slot])) for slot in predicted if slot in gold)


def over_predicted(ex):
    """Slots on the diff's "pred-true" side that have no "true-pred" counterpart.

    Args:
        ex: diff dict with a "pred-true" mapping (and optionally "true-pred"), or None.

    Returns:
        None if ``ex`` is None or has no "pred-true"; otherwise a tuple of
        ``(slot, predicted_entry)`` in "pred-true" order.
    """
    if ex is None or "pred-true" not in ex:
        return None
    gold = ex.get("true-pred", {})
    return tuple((slot, entry) for slot, entry in ex["pred-true"].items() if slot not in gold)


def under_predicted(ex):
    """Slots present only in the "true-pred" side of a state diff.

    Args:
        ex: a diff dict with optional "true-pred" (slot -> iterable of values) and
            "pred-true" entries, or None.

    Returns:
        None when `ex` is None or has no "true-pred"; otherwise a tuple of
        `(slot, tuple(values))` for every "true-pred" slot that is not a "pred-true" key.
    """
    if ex is None or "true-pred" not in ex:
        return None
    pred_true = ex["pred-true"] if "pred-true" in ex else {}
    return tuple((slot, tuple(values)) for slot, values in ex["true-pred"].items() if slot not in pred_true)


# Split each turn's diff into wrongly-, under- and over-predicted slot tuples
# (columns created in this order: wrong, under, over).
for column, extractor in (
    ("wrong", wrongly_predicted),
    ("under", under_predicted),
    ("over", over_predicted),
):
    ops_df[column] = ops_df["diffs"].swifter.apply(extractor)
del column, extractor

In [None]:
# Distribution of the most frequent errors: for each slot, the share (%) of its occurrences
# in `states_24_clean_complement` that ended up in the `wrong` tuples, computed separately
# for the validation and test splits.


def slot_error_rate(frame: pd.DataFrame, split: str) -> pd.Series:
    """Return, per slot, 100 * (# entries in `wrong`) / (# occurrences in the gold states) for `split`.

    `frame` must carry the `split`, `states_24_clean_complement` and `wrong` columns, where
    `wrong` holds None or a tuple of `(slot, pred-true value, true-pred values)` entries.
    Slots missing from either count get NaN. The result is sorted in descending order.
    """
    in_split = frame["split"] == split
    # `Series.explode` has no column argument (its only parameter is `ignore_index`).
    slot_count = frame.loc[in_split, "states_24_clean_complement"].map(extract_slots).explode().value_counts()
    has_wrong = in_split & frame["wrong"].notna() & (frame["wrong"] != ())
    wrong_slots = frame.loc[has_wrong, "wrong"].explode().map(lambda ex: ex[0]).value_counts()
    counts = pd.DataFrame({"n": slot_count, "n_wrong": wrong_slots})
    return (counts["n_wrong"] / counts["n"] * 100).sort_values(ascending=False)


v_values = slot_error_rate(ops_df, "validation")
t_values = slot_error_rate(ops_df, "test")

plot_data = (
    pd.DataFrame({"test": t_values, "validation": v_values})
    .rename_axis("slot")
    .reset_index()
    .melt(id_vars=["slot"], value_name="perc", var_name="split")
    .sort_values("perc", ascending=False)
)
ax = sns.barplot(data=plot_data, y="slot", x="perc", hue="split", orient="h")
ax.set(
    xlabel="% of gold occurrences predicted with a wrong value",
    ylabel="slot",
    title="Wrongly-predicted slots per split",
);

In [None]:
# Total number of wrong / over- / under-predicted slot entries after epoch 10 (None counts as 0).
# Per-column `Series.map` instead of `DataFrame.applymap`, which is deprecated since pandas 2.1.
ops_df.loc[ops_df["epoch"] > 10, ["wrong", "over", "under"]].apply(
    lambda col: col.map(lambda ex: len(ex) if ex is not None else 0)
).sum()

In [None]:
# Inspect the gold rows of a single dialogue.
true_df[true_df["dialogue_id"].eq("PMUL2088.json")]

In [None]:
# Turns with at least one error column set (but a non-empty `wrong` tuple when set),
# after epoch 3 and where the model output is not "none"; one row per distinct
# (turn, wrong entry, under, over, context representation).
# To inspect a single run, add e.g. `& (ops_df["model"] == "<run name>")` to the mask.
t = ops_df.loc[
    (ops_df["wrong"].notna() | ops_df["under"].notna() | ops_df["over"].notna())
    & (ops_df["wrong"] != ())
    & (ops_df["predictions"] != "none")
    & (ops_df["epoch"] > 3),
    [
        "dialogue_context_repr",
        "dialogue_id",
        "turn_id",
        "sys_utt_24",
        "usr_utt_24",
        "previous_states",
        "predictions",
        "diffs",
        "wrong",
        "under",
        "over",
        "epoch",
        "model",
    ],
].copy()  # explicit copy: new columns are assigned below (avoids SettingWithCopyWarning)

# Rows per (dialogue, turn) relative to the number of distinct (model, epoch) pairs, in %.
n_trials = t[["model", "epoch"]].drop_duplicates().shape[0]
t["turn_occurrence"] = t.groupby(["dialogue_id", "turn_id"])["turn_id"].transform("size") / n_trials * 100

t = t.explode("wrong")
t["wrong_slot"] = t["wrong"].map(lambda ex: ex[0] if ex is not None else None)

t = t.drop(columns=["model", "epoch"])
t = t.drop_duplicates(subset=["dialogue_id", "turn_id", "wrong", "under", "over", "dialogue_context_repr"])
t = t.sort_values(
    ["turn_occurrence", "dialogue_id", "dialogue_context_repr", "turn_id"], ascending=[False, True, True, True]
)

In [None]:
# First 50 rows of the current error table.
t.iloc[:50]

In [None]:
# Wrongly-predicted entries only; keep those whose slot name occurs in the raw model output.
wt = t[t["wrong"].notna() & (t["wrong"] != ())].drop(columns=["over", "under"])
wt[wt.apply(lambda r: r["wrong_slot"] in r["predictions"], axis=1)].head(50)

#### Under-predicted slots

In [None]:
# Turns with a non-empty `under` tuple, after epoch 3 and where the model output is not
# "none"; one row per distinct (turn, under entry, context representation).
# NOTE(review): presumably `under` lists gold slots missing from the prediction — confirm.
# To inspect a single run, add e.g. `& (ops_df["model"] == "<run name>")` to the mask.
t = ops_df.loc[
    ops_df["under"].notna() & (ops_df["under"] != ()) & (ops_df["predictions"] != "none") & (ops_df["epoch"] > 3),
    [
        "dialogue_context_repr",
        "dialogue_id",
        "turn_id",
        "sys_utt_24",
        "usr_utt_24",
        "previous_states",
        "predictions",
        "diffs",
        "under",
        "epoch",
        "model",
    ],
].copy()  # explicit copy: new columns are assigned below (avoids SettingWithCopyWarning)

# Rows per (dialogue, turn) relative to the number of distinct (model, epoch) pairs, in %.
n_trials = t[["model", "epoch"]].drop_duplicates().shape[0]
t["turn_occurrence"] = t.groupby(["dialogue_id", "turn_id"])["turn_id"].transform("size") / n_trials * 100

t = t.explode("under")
t["under_slot"] = t["under"].map(lambda ex: ex[0] if ex is not None else None)

t = t.drop(columns=["model", "epoch"])
t = t.drop_duplicates(subset=["dialogue_id", "turn_id", "under", "dialogue_context_repr"])
t = t.sort_values(
    ["turn_occurrence", "dialogue_id", "dialogue_context_repr", "turn_id"], ascending=[False, True, True, True]
)

In [None]:
# Frequency of each under-predicted slot in the deduplicated table above.
ax = t["under_slot"].value_counts().plot.barh()
ax.set(title="Under-predicted slots", xlabel="count", ylabel="slot");

In [None]:
# Under-predicted `hotel-name` rows whose context representation contains "prev".
t[t["dialogue_context_repr"].str.contains("prev") & t["under_slot"].eq("hotel-name")].head(50)

In [None]:
# Inspect the gold rows of a single dialogue.
true_df[true_df["dialogue_id"].eq("MUL0039.json")]

---

In [None]:
# One row per ((slot, value) pair from "pred-true", epoch): number of rows (`size`) and of
# distinct models (`nunique`) in which it occurs.
tmp = (
    ops_df.explode("pred_true")
    .groupby(["pred_true", "epoch"])["model"]
    .agg(["size", "nunique"])
    .reset_index()
)

In [None]:
# Slot name = first element of each (slot, value) pair, whitespace-stripped.
tmp = tmp.assign(slot=tmp["pred_true"].map(lambda pair: pair[0]).str.strip())

In [None]:
# Occurrence totals summed over epochs: per slot and per exact (slot, value) error.
tmp = tmp.assign(
    tot_per_slot=tmp.groupby("slot")["size"].transform("sum"),
    tot_per_error=tmp.groupby("pred_true")["size"].transform("sum"),
)

In [None]:
# Load the dataset schema; later cells check slot names against it with `in schema` / `.isin(schema)`.
# NOTE(review): presumably a mapping (or list) keyed by slot names such as "hotel-name" — TODO confirm.
schema = srsly.read_yaml("../data/processed/mwoz_all_versions/schema.yaml")

In [None]:
# Flag predicted slot names that exist in the schema (False = slot name not in the schema).
tmp = tmp.assign(in_schema=tmp["slot"].isin(schema))

In [None]:
# tmp.to_excel("errors.xlsx", index=False)

In [None]:
# Errors whose slot name is not part of the schema.
tmp.loc[~tmp["in_schema"]]

In [None]:
# Total "pred-true" occurrences per slot, most frequent first.
pred_slot = tmp.groupby("slot")["size"].sum().sort_values(ascending=False).to_frame(name="counts")

In [None]:
# Slot frequencies over all `true_df["states_22"]` states.
slots = (
    pd.Series([slot for turn_slots in true_df["states_22"].map(extract_slots) for slot in turn_slots])
    .value_counts()
    .to_frame("counts_true")
)

In [None]:
# Attach the gold slot frequencies to the predicted-error counts (index = slot name).
pred_slot = pred_slot.join(slots, how="left")

In [None]:
# Replace error counts by their ranks (ties get the average rank, NaN stays NaN).
pred_slot["counts"] = pred_slot["counts"].rank(method="average")

In [None]:
# Replace gold counts by their ranks (ties get the average rank, NaN stays NaN).
pred_slot["counts_true"] = pred_slot["counts_true"].rank(method="average")

In [None]:
# Slots ordered by error rank, then by gold-frequency rank (highest first).
pred_slot.sort_values(by=["counts", "counts_true"], ascending=False)

In [None]:
# Highest rank in each column.
pred_slot.max(axis=0)

In [None]:
# Turns from "nohist" Ops runs where any "pred-true" pair concerns the slot `error` and the raw
# output contains an explicit `INSERT <slot>` op; one row per (dialogue, turn).
# Checks every (slot, value) pair (previously only the first one, which also raised
# IndexError on empty tuples) and matches the op text literally (regex=False).
error = "attraction-name"
ops_df.loc[
    ops_df["pred_true"].map(lambda pairs: isinstance(pairs, tuple) and any(slot == error for slot, _ in pairs))
    & ops_df["predictions"].str.contains(f"INSERT {error}", regex=False)
    & ops_df["model"].str.contains("nohist", regex=False)
].drop_duplicates(subset=["dialogue_id", "turn_id"])

In [None]:
# Turns where the exact (slot, value) pair `error` is among the "pred-true" pairs and the raw
# output contains the matching `INSERT <slot> = <value>` op.
error = ("attraction-name", "adc")
ops_df.loc[
    ops_df["pred_true"].map(lambda pairs: error in pairs if pairs is not None else False)
    # regex=False: slot values may contain regex metacharacters such as "." or "("
    & ops_df["predictions"].str.contains(f"INSERT {error[0]} = {error[1]}", regex=False)
]

In [None]:
# Distinct predicted slot names that are not in the schema.
tmp.loc[~tmp["in_schema"], "slot"].unique()

In [None]:
# Number of distinct (slot, value) errors.
tmp["pred_true"].nunique(dropna=True)

In [None]:
# Set of slot names appearing in the errors.
{pair[0] for pair in tmp["pred_true"]}

In [None]:
# Total number of error occurrences across all epochs.
tmp.loc[:, "size"].sum()

In [None]:
# Number of Ops-model rows.
ops_df.shape[0]

In [None]:
# Rows of one (dialogue, turn) whose "pred-true" pairs are exactly `error`.
# Fixed: the turn filter now uses `turn` instead of a hard-coded 1.
diag = "MUL0409.json"
turn = 1
error = (("attraction-area", "center"),)

ops_df.loc[(ops_df["pred_true"] == error) & (ops_df["dialogue_id"] == diag) & (ops_df["turn_id"] == turn)]

In [None]:
# Gold rows of the dialogue inspected above.
true_df[true_df["dialogue_id"].eq(diag)]

In [None]:
# Predicted states restricted to slot names found in `schema` (None stays None).
ops_df["states_clean_schema"] = ops_df["states_clean"].swifter.apply(
    lambda states: None if states is None else {slot: value for slot, value in states.items() if slot in schema}
)

In [None]:
# Turn-level JGA of the schema-filtered predictions against `states_22_clean`.
ops_df["jga_turn_22_clean_schema"] = ops_df.swifter.apply(
    lambda r: jga(r["states_clean_schema"], r["states_22_clean"]),
    axis=1,
)

In [None]:
# Mean turn-level JGA per (model, epoch, version, split) with and without the schema filter,
# and the difference `diff` = filtered - unfiltered.
# Columns are selected with a list: selecting with a tuple of keys (`gb["a", "b"]`) was
# removed in pandas 2.0.
a = (
    ops_df.groupby(["model", "epoch", "version", "split"])[["jga_turn_22_clean", "jga_turn_22_clean_schema"]]
    .mean()
    .reset_index()
)
a["diff"] = a["jga_turn_22_clean_schema"] - a["jga_turn_22_clean"]

In [None]:
# Summary statistics of the JGA change from the schema filter.
a.loc[:, "diff"].describe()

In [None]:
# Test-split rows where the schema filter changes mean JGA by less than 0.003 (including decreases).
a[a["diff"].lt(0.003) & a["split"].eq("test")]