In [None]:
%load_ext lab_black
from pathlib import Path
import pickle
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib import ticker as tck
import seaborn as sns
import ipywidgets
import pydantic
import typing
import datetime
import itertools
import collections
import json
import gzip
import tqdm
import multiprocessing
from efficient_apriori import apriori
import math
import locale
from scipy.stats import linregress
from matplotlib.offsetbox import AnchoredText
from datetime import date
from IPython.display import HTML

# German number formatting for plot tick labels: switch the process to the
# de_DE locale, force "." as thousands separator in groups of three digits and
# let matplotlib's tick formatters honour the locale.
locale.setlocale(locale.LC_ALL, "de_DE")
# NOTE: `_override_localeconv` is a private CPython hook merged into localeconv().
locale._override_localeconv.update({"thousands_sep": ".", "grouping": [3, 3, 0]})
plt.rcParams["axes.formatter.use_locale"] = True
sns.set_theme(style="ticks")
# Centimetre-to-inch factor (matplotlib figure sizes are given in inches).
cm = 1 / 2.54
# NOTE(review): 29.7 x 42 cm are ISO A3 dimensions (A4 is 21 x 29.7 cm) -- confirm intent.
a4 = (29.7, 42)


class InfoboxProperty(pydantic.BaseModel):
    """Infobox property referenced by an :class:`InfoboxChange`."""

    # Optional type label of the property.
    # NOTE(review): no explicit default -- pydantic v1 treats this as default None,
    # pydantic v2 would make it required.
    propertyType: typing.Union[str, None]
    # Property name; used as the item identifier when bucketing properties.
    name: str


class InfoboxChange(pydantic.BaseModel):
    """A single property change inside an :class:`InfoboxRevision`."""

    # The property that was changed.
    property: InfoboxProperty
    # End of validity of the value, if known.
    valueValidTo: typing.Union[datetime.datetime, None] = None
    # Presumably the value after / before this change -- confirm against the data export.
    currentValue: typing.Union[str, None] = None
    previousValue: typing.Union[str, None] = None


class User(pydantic.BaseModel):
    """Author of a revision; both fields may be missing in the data."""

    username: typing.Union[str, None]
    id: typing.Union[int, None]


class InfoboxRevision(pydantic.BaseModel):
    """One infobox revision, parsed from a single JSON line via ``parse_raw``.

    Field names must match the JSON keys (no aliases are declared), hence the
    camelCase spelling.
    """

    revisionId: int
    pageTitle: str
    # Property changes made in this revision.
    changes: typing.Sequence[InfoboxChange]
    # Start of validity; its calendar day is used as the bucket key.
    validFrom: datetime.datetime
    attributes: typing.Union[typing.Dict[str, str], None]
    pageID: int
    revisionType: typing.Union[str, None]
    key: str
    template: typing.Union[str, None] = None
    position: typing.Union[int, None] = None
    user: typing.Union[User, None] = None
    validTo: typing.Union[datetime.datetime, None] = None


class ChangeBuckets(pydantic.BaseModel):
    """Per-day buckets of changed items produced from one input file (or "all")."""

    # Name of the source file the buckets were built from.
    filename: str
    # ISO date string -> sorted items (page IDs, property names, ...) changed that day.
    changes: typing.Dict[str, typing.Sequence[typing.Hashable]]


def sliding(seq, window_size):
    """Yield every contiguous slice of *seq* that is *window_size* long.

    Windows are produced left to right by slicing, so each has the same type as
    *seq*. Nothing is yielded when *window_size* exceeds ``len(seq)``.
    """
    last_start = len(seq) - window_size
    start = 0
    while start <= last_start:
        yield seq[start : start + window_size]
        start += 1


def overlapping_groups(groups, window_size):
    """Merge every run of *window_size* consecutive buckets into one set.

    Buckets are visited in the mapping's iteration order. Each merged set is
    stored under the first key of its window, so (for ``window_size >= 1``) the
    result has ``len(groups) - window_size + 1`` entries, or none when the
    window is larger than the mapping.
    """
    merged = {}
    for window in sliding(tuple(groups.keys()), window_size):
        merged[window[0]] = set().union(*(groups[key] for key in window))
    return merged

# Creating Buckets

In [None]:
def process_pageid(file):
    """Bucket the page IDs found in one pickled change file by day of change.

    The unpickled object is iterated; each element must provide
    ``value_valid_from`` (a datetime) and ``page_id``. Returns a
    :class:`ChangeBuckets` whose ``changes`` maps each ISO date to the sorted
    tuple of distinct page IDs changed that day, in ascending date order.
    """
    # NOTE: pickle.load can execute arbitrary code -- only use on trusted files.
    with open(file, "rb") as handle:
        infobox_changes = pickle.load(handle)
    pages_by_day = collections.defaultdict(set)
    for change in infobox_changes:
        day = change.value_valid_from.date().isoformat()
        pages_by_day[day].add(change.page_id)
    return ChangeBuckets(
        filename=file.name,
        changes={day: tuple(sorted(ids)) for day, ids in sorted(pages_by_day.items())},
    )


def process_property(file):
    """Bucket the names of changed infobox properties by day of revision.

    Parameters
    ----------
    file : path-like with a ``.name`` attribute (e.g. ``pathlib.Path``)
        Text file with one JSON-serialized :class:`InfoboxRevision` per line.

    Returns
    -------
    ChangeBuckets
        ``changes`` maps each ISO date (day of ``validFrom``) to the sorted
        tuple of distinct property names changed that day, in date order.
    """
    groups = collections.defaultdict(set)
    # Decode explicitly as UTF-8: the notebook switches the process locale to
    # de_DE, whose codeset (usually Latin-1) open() would otherwise use.
    with open(file, encoding="utf-8") as f:
        for line in f:
            revision = InfoboxRevision.parse_raw(line)
            groups[revision.validFrom.date().isoformat()].update(
                change.property.name for change in revision.changes
            )
    return ChangeBuckets(
        filename=file.name,
        changes={k: tuple(sorted(v)) for k, v in sorted(groups.items())},
    )


def process_template_property(file):
    """Bucket ``(template, property name)`` pairs of changes by day of revision.

    Parameters
    ----------
    file : path-like with a ``.name`` attribute (e.g. ``pathlib.Path``)
        Text file with one JSON-serialized :class:`InfoboxRevision` per line.

    Returns
    -------
    ChangeBuckets
        ``changes`` maps each ISO date (day of ``validFrom``) to the sorted
        tuple of distinct ``(str(template), property name)`` pairs; a missing
        template becomes the string ``"None"``.
    """
    groups = collections.defaultdict(set)
    # Decode explicitly as UTF-8: the notebook switches the process locale to
    # de_DE, whose codeset (usually Latin-1) open() would otherwise use.
    with open(file, encoding="utf-8") as f:
        for line in f:
            revision = InfoboxRevision.parse_raw(line)
            groups[revision.validFrom.date().isoformat()].update(
                (str(revision.template), change.property.name)
                for change in revision.changes
            )
    return ChangeBuckets(
        filename=file.name,
        changes={k: tuple(sorted(v)) for k, v in sorted(groups.items())},
    )


def process_page_property(file):
    """Bucket ``(page ID, property name)`` pairs of changes by day of revision.

    Parameters
    ----------
    file : path-like with a ``.name`` attribute (e.g. ``pathlib.Path``)
        Text file with one JSON-serialized :class:`InfoboxRevision` per line.

    Returns
    -------
    ChangeBuckets
        ``changes`` maps each ISO date (day of ``validFrom``) to the sorted
        tuple of distinct ``(pageID, property name)`` pairs, in date order.
    """
    groups = collections.defaultdict(set)
    # Decode explicitly as UTF-8: the notebook switches the process locale to
    # de_DE, whose codeset (usually Latin-1) open() would otherwise use.
    with open(file, encoding="utf-8") as f:
        for line in f:
            revision = InfoboxRevision.parse_raw(line)
            groups[revision.validFrom.date().isoformat()].update(
                (revision.pageID, change.property.name) for change in revision.changes
            )
    return ChangeBuckets(
        filename=file.name,
        changes={k: tuple(sorted(v)) for k, v in sorted(groups.items())},
    )


# `import tqdm` alone does not load the `tqdm.notebook` submodule, so
# `tqdm.notebook.tqdm` below would raise AttributeError in a fresh kernel.
import tqdm.notebook

fname = "./changesets-pageid-dfe.json.gz"

if not Path(fname).exists():
    # Build the page-ID-per-day buckets from scratch: bucket every pickle file in
    # a worker pool, then merge the per-file buckets day by day.
    groups = collections.defaultdict(set)
    files = [
        x
        for x in sorted(
            Path("../../data/custom-format-default-filtered/").rglob("*.pickle")
        )
        if x.is_file()
    ]
    with multiprocessing.Pool(2) as p:
        imap = p.imap(process_pageid, files)
        for cb in tqdm.notebook.tqdm(imap, total=len(files)):
            for k, v in cb.changes.items():
                groups[k].update(v)
    del files
    groups = {k: tuple(sorted(v)) for k, v in sorted(groups.items())}
    # Cache the merged buckets as gzip-compressed, compact JSON.
    with open(fname, "wb") as f:
        f.write(
            gzip.compress(
                ChangeBuckets(filename="all", changes=groups)
                .json(indent=None, separators=(",", ":"))
                .encode("utf-8")
            )
        )
else:
    # Reuse the cached buckets.
    with open(fname, "rb") as f:
        groups = ChangeBuckets.parse_raw(
            gzip.decompress(f.read()).decode("utf-8")
        ).changes

## Min/Max Support Filtering Analysis

In [None]:
@ipywidgets.interact(
    sma=ipywidgets.IntSlider(value=30, min=5, max=365, step=5, continuous_update=False)
)
def plot_freq(sma):
    """Plot the daily share of all known pages that changed, plus a moving average.

    A day's share is the number of page IDs in its bucket divided by the number
    of distinct page IDs across all buckets; days without a bucket become 0
    through daily resampling. *sma* is the window (in days) of the simple
    moving average drawn next to the raw series.
    """
    plt.figure(figsize=(20 * cm, 15 * cm), dpi=100)
    ax = plt.subplot(111)
    daily = pd.DataFrame(
        [(day, len(ids)) for day, ids in groups.items()], columns=["Date", "Daily"]
    )
    daily["Date"] = pd.to_datetime(daily["Date"])
    n_pages = len(set().union(*groups.values()))
    daily["Daily"] /= n_pages
    daily = daily.set_index("Date").resample("D").sum().reset_index()
    daily[f"SMA{sma}"] = daily["Daily"].rolling(sma).mean()
    long_form = daily.melt(id_vars=["Date"], var_name="Data", value_name="Percentage")

    sns.lineplot(x="Date", y="Percentage", hue="Data", data=long_form)
    ax.yaxis.set_major_formatter(tck.PercentFormatter(xmax=1))
    sns.despine(ax=ax)
    legend_kwargs = dict(bbox_to_anchor=(0.5, 1), ncol=2, title=None, frameon=False)
    sns.move_legend(ax, "lower center", **legend_kwargs)
    plt.show()

In [None]:
def get_freq(g):
    """Count in how many buckets each ID occurs.

    Parameters
    ----------
    g : Mapping[str, Iterable[Hashable]]
        Buckets keyed by date string; each value holds the IDs of that bucket.

    Returns
    -------
    pandas.DataFrame
        Columns ``ID``, ``Count`` (number of buckets containing the ID) and
        ``Frequency`` (``Count`` divided by the number of buckets in *g*),
        sorted by descending ``Count`` and then ascending ``ID``.
    """
    counts = collections.Counter(item for bucket in g.values() for item in bucket)
    freqs = pd.DataFrame(counts.items(), columns=["ID", "Count"])
    # Normalise by the buckets actually passed in -- not the global `groups` --
    # so the frequency equals apriori's support for overlapping windows too.
    freqs["Frequency"] = freqs["Count"] / len(g)
    return freqs.sort_values(["Count", "ID"], ascending=[False, True])


freqs = get_freq(g=groups)
freqs  # rendered as the cell output

In [None]:
@ipywidgets.interact(
    bucket_size=ipywidgets.IntSlider(
        value=1, min=1, max=21, step=1, continuous_update=False
    ),
    bins=ipywidgets.IntSlider(
        value=50, min=5, max=100, step=5, continuous_update=False
    ),
    lower=ipywidgets.FloatSlider(
        value=0.05,
        min=0.0,
        max=1.0,
        step=0.001,
        readout_format=".1%",
        continuous_update=False,
    ),
    upper=ipywidgets.FloatSlider(
        value=0.1,
        min=0.0,
        max=1.0,
        step=0.001,
        readout_format=".1%",
        continuous_update=False,
    ),
)
def plot_frequency_hist(bucket_size, bins, lower, upper):
    """Histogram of per-ID bucket frequencies with a min/max filter highlighted.

    Buckets are first merged into overlapping windows of *bucket_size* days.
    IDs whose frequency lies below *lower* or above *upper* are marked as
    "Filtered"; their share of all IDs is printed in the top-right corner.
    """
    plt.figure(figsize=(20 * cm, 15 * cm), dpi=100)
    ax = plt.subplot(111)
    table = get_freq(overlapping_groups(groups, bucket_size))
    frequency = table["Frequency"]
    table["Filtered"] = (frequency < lower) | (frequency > upper)
    sns.histplot(
        data=table,
        x="Frequency",
        hue="Filtered",
        stat="percent",
        bins=bins,
        multiple="stack",
        log_scale=(False, True),
        ax=ax,
    )
    sns.move_legend(ax, "center left", bbox_to_anchor=(1, 0.5), ncol=1, frameon=False)
    filtered_share = table["Filtered"].sum() / len(table)
    ax.annotate(
        f"Filtered: {filtered_share:.2%} IDs",
        xy=(1, 1),
        xycoords="axes fraction",
        xytext=(0, 0),
        textcoords="offset points",
        ha="right",
        va="top",
    )
    ax.xaxis.set_major_formatter(tck.PercentFormatter(xmax=1))
    ax.yaxis.set_major_formatter(tck.PercentFormatter(xmax=100, decimals=4))
    sns.despine(ax=ax)
    plt.show()

In [None]:
# Count, per page, on how many days strictly between 2018-09-01 and 2019-09-01
# it was changed, then show a random sample of pages changed on 80-200 of those
# days as links to Wikipedia.
tss = (
    pd.DataFrame(
        (
            (ts, item)
            for ts in sorted(
                ts
                for ts in groups.keys()
                if datetime.date(2018, 9, 1)
                < datetime.datetime.strptime(ts, "%Y-%m-%d").date()
                < datetime.date(2019, 9, 1)
            )
            for item in groups[ts]
        ),
        columns=["Date", "Page ID"],
    )
    .set_index("Date")
    .sort_index()
)
# Name the counts explicitly: pandas < 2 returns an unnamed Series (column 0
# after reset_index) but pandas >= 2 names it "count", where renaming column 0
# did nothing and tss["Changes"] raised KeyError.
tss = tss.value_counts().rename("Changes").reset_index()
tss["Page ID"] = tss["Page ID"].apply(
    lambda pid: f'<a href="https://en.wikipedia.org/?curid={pid}">{pid}</a>'
)

candidate_pages = tss[tss["Changes"].between(80, 200)]
HTML(
    # sample() raises when asked for more rows than exist, so cap the size.
    candidate_pages.sample(min(15, len(candidate_pages)))
    .sort_values(["Changes", "Page ID"], ascending=[False, True])
    .to_html(escape=False)
)

# Apriori Association Rule Mining

In [None]:
%%time
n_days = 5
min_support = 0.18
min_confidence = 0.75
max_length = 2

# Transactions: the merged page-ID sets of n_days-long sliding windows.
data = tuple(overlapping_groups(groups, n_days).values())

# Chronological split: validation and test each get ceil(20 %) of the windows,
# training keeps the earlier remainder.
val_size = math.ceil(len(data) * 0.2)
test_size = math.ceil(len(data) * 0.2)
val_start = len(data) - (val_size + test_size)
test_start = len(data) - test_size
train_data = data[:val_start]
val_data = data[val_start:test_start]
test_data = data[test_start:]
del data, val_size, test_size, val_start, test_start

itemsets, rules = apriori(
    train_data,
    min_support=min_support,
    min_confidence=min_confidence,
    max_length=max_length,
    verbosity=1,
)
del n_days

# One row per rule, indexed by (RHS, LHS) as frozensets.
rule_records = [
    (
        frozenset(rule.rhs),
        frozenset(rule.lhs),
        rule.confidence,
        rule.support,
        rule.lift,
        rule.conviction,
    )
    for rule in rules
]
df = pd.DataFrame(
    rule_records,
    columns=["RHS", "LHS", "Confidence", "Support", "Lift", "Conviction"],
)
df = df.set_index(["RHS", "LHS"]).sort_index()
del rule_records
display(df.describe().T.style.format("{:.2f}"))
display(df.sort_values("Lift", ascending=False))

In [None]:
# For every ID, per split (train/val/test): how many transactions contain it
# ("*_occurences") and how many transactions of that split remain from the
# ID's first appearance in the split onwards, inclusive ("*_total").
occurences = collections.defaultdict(
    lambda: {
        "train_occurences": 0,
        "train_total": 0,
        "val_occurences": 0,
        "val_total": 0,
        "test_occurences": 0,
        "test_total": 0,
    }
)  # : Dict[ID, Dict[str, int]] holding the six counters above
seen = set()
for i, group in tqdm.tqdm(
    enumerate(itertools.chain(train_data, val_data, test_data)),
    total=len(train_data) + len(val_data) + len(test_data),
):
    # s: split of transaction i; j: global index one past the end of that split.
    # `seen` is reset at the first transaction of val and of test, so "first
    # appearance" is tracked separately per split.
    if i < len(train_data):
        s, j = "train", len(train_data)
    elif i < len(train_data) + len(val_data):
        s, j = "val", len(train_data) + len(val_data)
        if i == len(train_data):
            seen = set()
    elif i < len(train_data) + len(val_data) + len(test_data):
        s, j = "test", len(train_data) + len(val_data) + len(test_data)
        if i == len(train_data) + len(val_data):
            seen = set()
    # IDs new to this split: transactions left in the split, this one included.
    for id in group - seen:
        occurences[id][f"{s}_total"] = j - i
    seen |= group
    for id in group:
        occurences[id][f"{s}_occurences"] += 1
del seen
rcdf = pd.DataFrame.from_dict(occurences, orient="index")
del occurences
# Share of train transactions (from the ID's first train appearance on) that
# contain the ID; NaN for IDs never seen in train (index alignment).
rcdf["train_precision"] = (
    rcdf[rcdf["train_total"] > 0]["train_occurences"]
    / rcdf[rcdf["train_total"] > 0]["train_total"]
)

# NOTE(review): the mask also keeps IDs with train_total > 0 but val_total == 0,
# yet those rows compute 0 / 0 and become NaN anyway, so effectively only IDs
# that appear in val get a value. Confirm whether such IDs were meant to be 0 %.
rcdf["val_precision"] = (
    rcdf[rcdf[["train_total", "val_total"]].sum(axis=1) > 0]["val_occurences"]
    / rcdf[rcdf[["train_total", "val_total"]].sum(axis=1) > 0]["val_total"]
)

# Same construction (and the same caveat) for the test split.
rcdf["test_precision"] = (
    rcdf[rcdf[["train_total", "val_total", "test_total"]].sum(axis=1) > 0][
        "test_occurences"
    ]
    / rcdf[rcdf[["train_total", "val_total", "test_total"]].sum(axis=1) > 0][
        "test_total"
    ]
)

# Overall share: all occurrences divided by the number of transactions from the
# ID's first appearance in any split up to the end of the test split.
rcdf["precision"] = rcdf[["train_occurences", "val_occurences", "test_occurences"]].sum(
    axis=1
) / np.where(
    rcdf["train_total"] > 0,
    rcdf["train_total"] + len(val_data) + len(test_data),
    np.where(
        rcdf["val_total"] > 0, rcdf["val_total"] + len(test_data), rcdf["test_total"]
    ),
)
display(rcdf)
display(
    rcdf[["train_precision", "val_precision", "test_precision", "precision"]]
    .describe()
    .T.drop(columns=["count"])
    .style.format("{:.2%}")
)
# Mean, standard deviation, median and median absolute deviation per column.
for x in ("train_precision", "val_precision", "test_precision", "precision"):
    print(f"{x:>15s}: ", end="")
    x = rcdf[x]
    print(
        f"μ = {x.mean():.2%}, σ = {x.std():.2%}, median = {x.median():.2%}, mad = {(x.median() - x).abs().median():.2%}"
    )

In [None]:
# Confusion counts of every rule on the training transactions. A rule predicts
# its RHS when its LHS is contained in a transaction; keys are
# (RHS contained, LHS contained), i.e. (actual, predicted).
train_counts = [
    collections.Counter((rhs <= s, lhs <= s) for s in train_data)
    for rhs, lhs in tqdm.tqdm(df.index, total=len(df))
]
# Assign whole columns (instead of one .loc cell per rule) so the columns also
# exist -- empty -- when there are no rules, and the loop stays cheap.
for column, key in (
    ("TN (train)", (False, False)),
    ("FP (train)", (False, True)),
    ("FN (train)", (True, False)),
    ("TP (train)", (True, True)),
):
    df[column] = [counts[key] for counts in train_counts]
del train_counts, column, key

df[["TP (train)", "FP (train)", "TN (train)", "FN (train)"]] = df[
    ["TP (train)", "FP (train)", "TN (train)", "FN (train)"]
].astype(int)
df["Recall (train)"] = df["TP (train)"] / df[["TP (train)", "FN (train)"]].sum(axis=1)
# Confidence on the training data equals TP / (TP + FP), i.e. train precision.
df["F1 (train)"] = (
    2
    * df[["Confidence", "Recall (train)"]].product(axis=1)
    / df[["Confidence", "Recall (train)"]].sum(axis=1)
)
# Base rate: share of transactions containing the RHS.
df["Precision Random (train)"] = df[["TP (train)", "FN (train)"]].sum(axis=1) / df[
    ["TP (train)", "FP (train)", "TN (train)", "FN (train)"]
].sum(axis=1)

# Validation Set

In [None]:
# Confusion counts of every rule on the validation transactions; keys are
# (RHS contained, LHS contained), i.e. (actual, predicted).
val_counts = [
    collections.Counter((rhs <= s, lhs <= s) for s in val_data)
    for rhs, lhs in tqdm.tqdm(df.index, total=len(df))
]
# Assign whole columns (instead of one .loc cell per rule) so the columns also
# exist -- empty -- when there are no rules.
for column, key in (
    ("TN (val)", (False, False)),
    ("FP (val)", (False, True)),
    ("FN (val)", (True, False)),
    ("TP (val)", (True, True)),
):
    df[column] = [counts[key] for counts in val_counts]
del val_counts, column, key
df[["TP (val)", "FP (val)", "TN (val)", "FN (val)"]] = df[
    ["TP (val)", "FP (val)", "TN (val)", "FN (val)"]
].astype(int)
df["Precision (val)"] = df["TP (val)"] / (df["TP (val)"] + df["FP (val)"])
df["Recall (val)"] = df["TP (val)"] / (df["TP (val)"] + df["FN (val)"])
df["F1 (val)"] = (
    2
    * (df["Precision (val)"] * df["Recall (val)"])
    / (df["Precision (val)"] + df["Recall (val)"])
)
df["Accuracy (val)"] = (df["TP (val)"] + df["TN (val)"]) / df[
    ["TP (val)", "FP (val)", "TN (val)", "FN (val)"]
].sum(axis=1)
df["FPR (val)"] = df["FP (val)"] / df[["FP (val)", "TN (val)"]].sum(axis=1)
# Base rate: share of validation transactions containing the RHS.
df["Precision Random (val)"] = (df["TP (val)"] + df["FN (val)"]) / df[
    ["TP (val)", "FP (val)", "TN (val)", "FN (val)"]
].sum(axis=1)

df.sort_values(
    ["Precision (val)", "F1 (val)", "Recall (val)", "Accuracy (val)"], ascending=False
)

# Test Set

In [None]:
# Confusion counts of every rule on the test transactions; keys are
# (RHS contained, LHS contained), i.e. (actual, predicted).
test_counts = [
    collections.Counter((rhs <= s, lhs <= s) for s in test_data)
    for rhs, lhs in tqdm.tqdm(df.index, total=len(df))
]
# Assign whole columns (instead of one .loc cell per rule) so the columns also
# exist -- empty -- when there are no rules.
for column, key in (
    ("TN", (False, False)),
    ("FP", (False, True)),
    ("FN", (True, False)),
    ("TP", (True, True)),
):
    df[column] = [counts[key] for counts in test_counts]
del test_counts, column, key
df[["TP", "FP", "TN", "FN"]] = df[["TP", "FP", "TN", "FN"]].astype(int)
df["Precision"] = df["TP"] / (df["TP"] + df["FP"])
df["Recall"] = df["TP"] / (df["TP"] + df["FN"])
df["F1"] = 2 * (df["Precision"] * df["Recall"]) / (df["Precision"] + df["Recall"])
df["Accuracy"] = (df["TP"] + df["TN"]) / df[["TP", "FP", "TN", "FN"]].sum(axis=1)
df["FPR"] = df["FP"] / df[["FP", "TN"]].sum(axis=1)
# Base rate: share of test transactions containing the RHS.
df["Precision Random"] = (df["TP"] + df["FN"]) / df[["TP", "FP", "TN", "FN"]].sum(
    axis=1
)

df.sort_values(["Precision", "F1", "Recall", "Accuracy"], ascending=False)

In [None]:
def plot_axis(x, y, xrange=(0, 1), yrange=(0, 1), oneone=True, linreg=True):
    """Plot a 2-D histogram of two rule metrics from the global ``df``.

    Parameters
    ----------
    x, y : str
        Column names of ``df`` for the horizontal and vertical axis.
    xrange, yrange : tuple of float
        Axis limits; both axes are formatted as percentages.
    oneone : bool
        Draw the dotted identity line ``y = x`` across *xrange*.
    linreg : bool
        Overlay a least-squares fit (rows with NaN in *x* or *y* dropped) and
        annotate its intercept, slope and R².
    """
    plt.figure(figsize=(12.5 * cm, 12.5 * cm), dpi=100)
    ax = plt.subplot(111)
    ax.xaxis.set_major_formatter(tck.PercentFormatter(xmax=1))
    ax.yaxis.set_major_formatter(tck.PercentFormatter(xmax=1))
    ax.set_xlim(*xrange)
    ax.set_ylim(*yrange)
    sns.despine(ax=ax)
    sns.histplot(x=x, y=y, data=df, ax=ax)
    if oneone:
        ax.plot(xrange, xrange, linestyle=":", color="grey")
    if linreg:
        sns.regplot(x=x, y=y, data=df, scatter=False, ci=None, truncate=False, ax=ax)
        regdata = df[[x, y]].dropna()
        fit = linregress(regdata[x], regdata[y])
        a, b, r2 = fit.intercept, fit.slope, fit.rvalue ** 2
        ax.add_artist(
            AnchoredText(
                # `+` format spec prints the slope's own sign ("0.10-0.20x"),
                # instead of the former "0.10+-0.20x".
                f"$f(x)={a:.2f}{b:+.2f}x$\n$R^2 = {r2:.2f}$",
                frameon=False,
                loc="lower right",
            )
        )
    plt.show()


plot_axis(x="Precision Random", y="Precision")
plot_axis(x="Precision Random (val)", y="Precision (val)")
plot_axis(x="Precision Random (val)", y="Precision Random")

plot_axis(
    x="Confidence",
    y="Precision (val)",
    xrange=(df["Confidence"].min(), 1),
    oneone=False,
    linreg=False,
)
plot_axis(
    x="Confidence",
    y="Precision",
    xrange=(df["Confidence"].min(), 1),
    oneone=False,
    linreg=False,
)
plot_axis(x="Precision (val)", y="Precision")
del plot_axis

In [None]:
fig = plt.figure(figsize=(12.5 * cm, 12.5 * cm), dpi=144)
ax = plt.subplot(111)
# For every threshold p on validation precision: share of rules kept and the
# mean / median / 5th percentile of the test precision of the kept rules.
threshold_rows = []
for p in np.linspace(0, 1, 101):
    kept = df.loc[df["Precision (val)"] > p, "Precision"]
    threshold_rows.append(
        (p, len(kept) / len(df), kept.mean(), kept.median(), kept.quantile(0.05))
    )
data_ = pd.DataFrame(
    data=threshold_rows,
    columns=[
        "Threshold Precision (val)",
        "% Rules left",
        "Mean Precision",
        "Median Precision",
        "5th Percentile Precision",
    ],
)
del threshold_rows, p, kept
long_form = data_.melt(
    id_vars=["Threshold Precision (val)"], var_name="Type", value_name="Value"
)
sns.lineplot(x="Threshold Precision (val)", y="Value", hue="Type", data=long_form, ax=ax)
del long_form
for axis in (ax.xaxis, ax.yaxis):
    axis.set_major_formatter(tck.PercentFormatter(xmax=1))
del axis
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
# Marks the 95 % validation-precision cut-off.
ax.plot([0.95, 0.95], [0, 1], linestyle=":", color="grey")
sns.despine(ax=ax)
sns.move_legend(ax, "center left", bbox_to_anchor=(1, 0.5), ncol=1, frameon=False)
plt.show()
display(data_.tail(25))

# Activity Change between train and val

In [None]:
# Relative change of each rule's base rate (share of transactions containing
# the RHS) from the training to the validation split.
baseline_train = df["Precision Random (train)"]
baseline_val = df["Precision Random (val)"]
df["Activity Change"] = baseline_val / baseline_train - 1
del baseline_train, baseline_val
fig = plt.figure(figsize=(12.5 * cm, 12.5 * cm), dpi=100)
ax = plt.subplot(111)
df["Activity Change"].hist(bins=15, ax=ax)
ax.xaxis.set_major_formatter(tck.PercentFormatter(xmax=1))
sns.despine(ax=ax)

# Sample: Precision Advantage over Random in val data

In [None]:
# For a few rules whose validation precision beats their random baseline, trace
# the cumulative (expanding) precision advantage over random across the
# validation windows.
bucket_dates = tuple(groups.keys())  # built once, not once per generated row
advantage_rules = df[(df["Precision (val)"] - df["Precision Random (val)"]) > 0]
tdf = pd.DataFrame(
    (
        (
            # Windows are keyed by their first day, so val window j starts here.
            bucket_dates[len(train_data) + j],
            rhs,
            lhs,
            rhs <= s,
            lhs <= s,
        )
        # sample() raises when asked for more rows than exist, so cap the size.
        for rhs, lhs in advantage_rules.sample(min(5, len(advantage_rules))).index
        for j, s in enumerate(val_data)
    ),
    columns=["Date", "RHS", "LHS", "RHS?", "LHS?"],
)
tdf["Date"] = pd.to_datetime(tdf["Date"])
tdf = (
    tdf.set_index("Date")
    .groupby(["RHS", "LHS", pd.Grouper(freq="D")])
    .sum()
    .astype(bool)
    .sort_index()
)
tdf["TP"] = tdf["LHS?"] & tdf["RHS?"]
tdf["TN"] = ~tdf["LHS?"] & ~tdf["RHS?"]
tdf["FP"] = tdf["LHS?"] & ~tdf["RHS?"]
tdf["FN"] = ~tdf["LHS?"] & tdf["RHS?"]

# Running totals per rule over time.
tdf = (
    tdf.drop(columns=["LHS?", "RHS?"])
    .reset_index(level=[0, 1])
    .groupby(["RHS", "LHS"])
    .expanding()
    .sum()
    .astype(int)
)

tdf["Precision"] = tdf["TP"] / tdf[["TP", "FP"]].sum(axis=1)
tdf["Precision Random"] = tdf[["TP", "FN"]].sum(axis=1) / tdf[
    ["TP", "FP", "TN", "FN"]
].sum(axis=1)

tdf["Precision - Precision Random"] = tdf["Precision"] - tdf["Precision Random"]

tdf.reset_index(inplace=True)
tdf["Rule"] = (
    tdf["LHS"].apply(lambda items: ", ".join(str(it) for it in sorted(items)))
    + " -> "
    + tdf["RHS"].apply(lambda items: ", ".join(str(it) for it in sorted(items)))
)

fig = plt.figure(figsize=(40 * cm, 30 * cm), dpi=100)
ax = plt.subplot(111)
sns.lineplot(x="Date", y="Precision - Precision Random", hue="Rule", data=tdf, ax=ax)
ax.set_ylim(-1, 1)
ax.yaxis.set_major_formatter(tck.PercentFormatter(xmax=1))
sns.move_legend(
    ax,
    "center left",
    bbox_to_anchor=(1, 0.5),
    ncol=1,
    frameon=False,
)
sns.despine(ax=ax)

## Final Rules

In [None]:
# Keep only rules whose validation precision exceeds t_val and report their
# test-set metrics.
t_val = 0.95
metric_columns = ["Precision", "Recall", "F1", "Accuracy"]
final_rules = df[df["Precision (val)"] > t_val]
display(final_rules[metric_columns])
display(final_rules[metric_columns].describe().T)

final_rules[metric_columns].hist(figsize=(10, 5))
# Derive the title from t_val so it cannot drift from the actual threshold.
plt.suptitle(
    f"Association Rules results histogram on test set with Val Precision > {t_val}"
)
plt.tight_layout()

In [None]:
# One row per RHS with the list of all final-rule LHS itemsets predicting it.
lhs_by_rhs = final_rules.reset_index().groupby("RHS")["LHS"]
exported_rules = lhs_by_rhs.apply(list).reset_index()
del lhs_by_rhs
exported_rules.to_pickle("rules_dict.pickle")

In [None]:
# For every day's bucket and every exported RHS, record which of its LHS
# itemsets are fully contained in that bucket (i.e. which rules fire).
exported_rules_eval = []
rhs_lhs_pairs = [(rule.RHS, rule.LHS) for rule in exported_rules.itertuples()]
# `day`, not `date`: keeps the imported `date` name from being shadowed.
for day, bucket in tqdm.tqdm(groups.items(), total=len(groups)):
    items = set(bucket)  # converted once per day instead of once per rule
    for rhs, lhss in rhs_lhs_pairs:
        exported_rules_eval.append((day, rhs, tuple(lhs <= items for lhs in lhss)))
pd.DataFrame(exported_rules_eval, columns=["Date", "RHS", "LHSs"]).set_index(
    ["RHS", "Date"]
).sort_index().to_csv("rules_active.csv")

# Appendix

In [None]:
@ipywidgets.interact(
    x=ipywidgets.ToggleButtons(options=["Confidence", "Support", "Lift"]),
    y=ipywidgets.ToggleButtons(options=["Precision", "Recall", "F1", "Accuracy"]),
    support=ipywidgets.FloatRangeSlider(
        value=(0.0, 1.0),
        min=0.0,
        max=1.0,
        step=0.001,
        readout_format=".1%",
        continuous_update=False,
    ),
    confidence=ipywidgets.FloatRangeSlider(
        value=(0.0, 1.0),
        min=0.0,
        max=1.0,
        step=0.001,
        readout_format=".1%",
        continuous_update=False,
    ),
    q=ipywidgets.IntSlider(value=4, min=1, max=10, step=1, continuous_update=False),
)
def plot_correlation(x, y, support, confidence, q):
    """Explore how a rule-mining metric (*x*) relates to a test metric (*y*).

    Rules are restricted to those whose support and confidence lie inside the
    selected ranges. Two joint plots are drawn: a 2-D histogram, and a scatter
    plot coloured by *q* equal-width bins of *y* over [0, 1].
    """
    temp = df[
        df["Support"].between(*support) & df["Confidence"].between(*confidence)
    ].copy()

    def _format_axes(grid):
        # Shared limits and tick formatting for both joint plots.
        grid.ax_joint.set_xlim(temp[x].min(), temp[x].max())
        if x in {"Confidence", "Support"}:
            grid.ax_joint.xaxis.set_major_formatter(tck.PercentFormatter(xmax=1))
        grid.ax_joint.set_ylim(0, 1)
        grid.ax_joint.yaxis.set_major_formatter(tck.PercentFormatter(xmax=1))

    g = sns.jointplot(
        x=x,
        y=y,
        data=temp,
        kind="hist",
        height=20 * cm,
    )
    _format_axes(g)
    plt.show()

    # Bin y on the fixed edges 0, 1/q, ..., 1 so each "≤ p%" label matches its
    # bin. Passing the integer q to pd.cut would split the observed min..max
    # range of y instead and mislabel the bins.
    edges = np.linspace(0, 1, q + 1)
    temp[f"{y} Interval"] = pd.cut(
        temp[y],
        bins=edges,
        labels=[f"≤ {upper:.0%}" for upper in edges[1:]],
        include_lowest=True,
    )
    g = sns.jointplot(
        x=x,
        y=y,
        hue=f"{y} Interval",
        palette=sns.color_palette("mako", n_colors=q),
        data=temp,
        kind="scatter",
        height=20 * cm,
        marginal_kws={"common_norm": False},
    )
    _format_axes(g)
    plt.show()

In [None]:
@ipywidgets.interact(
    x=ipywidgets.ToggleButtons(options=["Confidence", "Support", "Lift"]),
    y=ipywidgets.ToggleButtons(options=["Confidence", "Support", "Lift"]),
)
def plot_correlation(x, y):
    """Joint histogram of two rule-mining metrics of all rules in ``df``.

    Confidence and support axes are formatted as percentages. Note that this
    definition rebinds the name ``plot_correlation`` from the previous cell.
    """
    g = sns.jointplot(
        x=x,
        y=y,
        data=df.copy(),
        kind="hist",
        height=20 * cm,
    )
    percent_metrics = {"Confidence", "Support"}
    for metric, axis in ((x, g.ax_joint.xaxis), (y, g.ax_joint.yaxis)):
        if metric in percent_metrics:
            axis.set_major_formatter(tck.PercentFormatter(xmax=1))
    plt.show()

# Evaluation per Predictor

In [None]:
# Not used, keep for future if we don't care weighing rules
"""
df2 = (
    df.reset_index()[["LHS", "RHS"]]
    .apply(lambda x: pd.Series([x["LHS"]] + [(rhs,) for rhs in x["RHS"]]), axis=1)
    .melt(id_vars=[0])
    .dropna()[[0, "value"]]
    .rename(columns={0: "LHS", "value": "RHS"})
)
df2["RHS"] = df2["RHS"].apply(lambda t: t[0])
df2 = (
    df2.groupby("RHS")["LHS"]
    .apply(lambda x: frozenset(x))
    .reset_index()
    .set_index("RHS")
)
df2
"""

# Final rules only (val precision above t_val), regrouped per single RHS item:
# each row of df2 is one RHS item whose "LHS" column maps every LHS itemset
# predicting it to that rule's lift.
df2 = (
    df[df["Precision (val)"] > t_val]
    .reset_index()[["LHS", "RHS", "Lift"]]
    .apply(
        lambda x: pd.Series([x["LHS"], x["Lift"]] + [(rhs,) for rhs in x["RHS"]]),
        axis=1,
    )
    .melt(id_vars=[0, 1])
    .dropna()
    .drop(columns=["variable"])
    .rename(columns={0: "LHS", 1: "Lift", "value": "RHS"})
)
df2["RHS"] = df2["RHS"].apply(lambda t: t[0])
df2 = (
    df2.groupby("RHS")[["LHS", "Lift"]]
    .apply(lambda grp: dict(zip(grp["LHS"], grp["Lift"])))
    .reset_index()
    .set_index("RHS")
    .rename(columns={0: "LHS"})
)

# An RHS item is predicted for a test transaction when the lift-weighted share
# of its LHS itemsets contained in that transaction exceeds this threshold.
signal_threshold = 0.8

predictor_counts = []
for item, lhs_lifts in tqdm.tqdm(df2["LHS"].items(), total=len(df2)):
    total = sum(lhs_lifts.values())  # same for every transaction, so hoisted
    counts = collections.Counter()
    for s in test_data:
        signal = sum(lift for lhs, lift in lhs_lifts.items() if lhs <= s)
        # Key: (item actually present, item predicted).
        counts[(item in s, (signal / total) > signal_threshold)] += 1
    predictor_counts.append(counts)
# Assign whole columns (instead of one .loc cell per predictor).
for column, key in (
    ("TN", (False, False)),
    ("FP", (False, True)),
    ("FN", (True, False)),
    ("TP", (True, True)),
):
    df2[column] = [c[key] for c in predictor_counts]
del predictor_counts, column, key
df2[["TP", "FP", "TN", "FN"]] = df2[["TP", "FP", "TN", "FN"]].astype(int)
df2["Precision"] = df2["TP"] / (df2["TP"] + df2["FP"])
df2["Recall"] = df2["TP"] / (df2["TP"] + df2["FN"])
df2["F1"] = 2 * (df2["Precision"] * df2["Recall"]) / (df2["Precision"] + df2["Recall"])
df2["Accuracy"] = (df2["TP"] + df2["TN"]) / df2[["TP", "FP", "TN", "FN"]].sum(axis=1)

df2["Precision Random"] = df2[["TP", "FN"]].sum(axis=1) / df2[
    ["TP", "FP", "TN", "FN"]
].sum(axis=1)

display(df2.sort_values(["Precision", "F1", "Recall", "Accuracy"], ascending=False))

# Micro-averaged metrics over all predictors (pooled confusion counts).
prec = df2["TP"].sum() / df2[["TP", "FP"]].sum().sum()
rec = df2["TP"].sum() / df2[["TP", "FN"]].sum().sum()
acc = df2[["TP", "TN"]].sum().sum() / df2[["TP", "FP", "TN", "FN"]].sum().sum()
f1 = 2 * prec * rec / (prec + rec)
# Previously computed but never shown (only the last expression is rendered).
print(
    f"Micro-averaged: precision = {prec:.2%}, recall = {rec:.2%}, "
    f"F1 = {f1:.2%}, accuracy = {acc:.2%}"
)

df2[["Precision", "Recall", "F1", "Accuracy"]].describe().T