# Evaluation

In [None]:
import ast
import json
import pickle
import re
import sys
from collections import Counter, defaultdict
from functools import reduce
from operator import itemgetter
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as stats
import seaborn as sns
import toolz
from rich import print
from tqdm.auto import tqdm

tqdm.pandas()
sns.set()

sys.path.insert(0, "../src")

from dataset import (
    TagAssociations,
    TagAugmenter,
    get_most_prevalent_tag,
    get_tag_ranking,
    normalise_wrt,
)
from evaluation import RelevanceAtK, RelevanceMethods

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

## Load data

In [None]:
# Sample-level table used throughout the notebook; later cells read its
# `sha256`, `label`, `subset`, `avclass_prev` and `avclass_curr` columns.
# The index is left as loaded (the `set_index` call is disabled), so `sha256`
# is addressed as a regular column below.
dataset = pd.read_pickle(
    "../data/processed/ember_with_avclass_dataset.pkl"
)  # .set_index('sha256')

In [None]:
# Stored similarity-search results, re-indexed by the `needle_sha256` column so
# that rows can be selected with `.loc[<hash>]`.
# NOTE(review): judging by the file names, the first table holds unlabelled
# queries searched against train and the second holds test queries searched
# against train+test -- confirm against the code that produced the pickles.
sim_unl_vs_train = pd.read_pickle(
    "../data/processed/xgb-sim-results/unlabelled_vs_train.pkl"
).set_index("needle_sha256")
sim_test_vs_traintest = pd.read_pickle(
    "../data/processed/xgb-sim-results/test_vs_train_test.pkl"
).set_index("needle_sha256")

In [None]:
# Rows with label == -1, split by whether either AVClass column is populated:
# "clean" = both columns missing, "dirty" = at least one present.
unl_clean = set(
    dataset.loc[
        (dataset["label"] == -1)
        & dataset["avclass_prev"].isna()
        & dataset["avclass_curr"].isna(),
        "sha256",
    ]
)
unl_dirty = set(
    dataset.loc[
        (dataset["label"] == -1)
        & (dataset["avclass_prev"].notna() | dataset["avclass_curr"].notna()),
        "sha256",
    ]
)
# Sanity check: the two groups together must contain exactly 200,000 hashes.
assert len(unl_clean) + len(unl_dirty) == 200_000

# Rows of the "test" subset, split by label: 0 -> clean, 1 -> dirty.
test_clean = set(
    dataset.loc[(dataset["label"] == 0) & (dataset["subset"] == "test"), "sha256"]
)
test_dirty = set(
    dataset.loc[(dataset["label"] == 1) & (dataset["subset"] == "test"), "sha256"]
)
assert len(test_clean) + len(test_dirty) == 200_000

## Eval

### Label homogeneity

In [None]:
# Evaluate the test-vs-(train+test) similarity results against the test hash sets.
sim = sim_test_vs_traintest
clean, dirty = test_clean, test_dirty

# Slices of the similarity table for each group of query hashes.
clean_subset = sim.loc[list(clean)]
dirty_subset = sim.loc[list(dirty)]

# Same sample table, keyed by hash, so labels can be looked up with `.loc`.
ds = dataset.set_index("sha256")


def agg(hs):
    """Tally the labels of the first K hits for K in 1, 10, 50 and 100.

    `hs` is the ordered collection of hit hashes for one query.  The result is
    a dict keyed "top-K"; each value maps a label found among `hs[:K]` (looked
    up in `ds`) to the number of times it occurs.
    """
    return {
        f"top-{cutoff}": ds.loc[hs[:cutoff], "label"].value_counts().to_dict()
        for cutoff in (1, 10, 50, 100)
    }


# One row per query, one "top-K" column per cut-off, each cell a label -> count dict.
out = sim["hits_sha256"].progress_apply(agg).apply(pd.Series)
# 1 when the query hash belongs to the `dirty` set, 0 otherwise.
out["true"] = [int(query_hash in dirty) for query_hash in out.index]

In [None]:
# Location of the stored label-homogeneity table.
lh_file = "../data/processed/eval-results/label_homogeneity_test_vs_traintest.pkl"
# Writing is disabled; the next line replaces any `out` already in memory with
# the copy stored on disk.
# out.to_pickle(lh_file)
out = pd.read_pickle(lh_file)

#### Histogram

In [None]:
# One panel per cut-off K (10, 50, 100); each panel overlays both query classes.
_, axs = plt.subplots(figsize=(14, 4), nrows=1, ncols=3)

# Plot styling keyed by the value of the `true` column (0 / 1).
meta = {
    0: {"color": "skyblue", "label": "benign", "alpha": 0.5},
    1: {"color": "salmon", "label": "malicious", "alpha": 1},
}
# 21 edges -> 20 equal-width bins spanning [0, 1].
bins = np.linspace(0, 1, endpoint=True, num=21)

for i, k in enumerate([10, 50, 100]):
    ax = axs[i]
    # Class 1 is drawn first (opaque), class 0 second (alpha 0.5).
    for l in [1, 0]:
        # For queries whose `true` value is `l`: the count stored under key `l`
        # in the `top-k` dict (0 if absent), divided by k.
        xs = (
            out.query(f"true == {l}")[f"top-{k}"]
            .apply(lambda d: d.get(l, 0) / k)
            .to_list()
        )
        sns.histplot(xs, stat="percent", bins=bins, ax=ax, **meta[l])
        # Axis cosmetics are re-applied on every inner iteration; the second
        # pass repeats the same settings.
        ax.set_xticks(bins[::2])
        ax.set_yscale("log")
        # Fixed y tick positions, labelled as percentages.
        iy = [0.1, 1, 5, 10, 50, 90]
        ax.set_yticks(iy, labels=[f"{i}%" for i in iy])
        ax.set_title(f"Top {k} hits")
        ax.legend()
        # ax.set_xlabel(f"Fraction of hits matching labels with query")

plt.suptitle(f"Fraction of hits matching labels with query")
plt.tight_layout()
# Trailing no-op; presumably there so the cell ends on a statement rather than
# an expression whose value would be echoed.
pass

#### ECDF

In [None]:
# One panel per cut-off K (10, 50, 100); each panel holds one curve per query class.
_, axs = plt.subplots(figsize=(12, 4), nrows=1, ncols=3)

# Line styling keyed by the value of the `true` column (0 / 1).
meta = {
    0: {"c": "skyblue", "label": "benign"},
    1: {"c": "salmon", "label": "malicious"},
}

for i, k in enumerate([10, 50, 100]):
    ax = axs[i]
    for l in [0, 1]:
        # Sorted per-query values: count stored under key `l` in the `top-k`
        # dict (0 if absent), divided by k, for queries whose `true` value is `l`.
        xs = sorted(
            out.query(f"true == {l}")[f"top-{k}"]
            .apply(lambda d: d.get(l, 0) / k)
            .to_list()
        )
        # Evenly spaced values from 0 to 1, one per sorted sample, used as the
        # cumulative fraction on the y axis.
        ix = np.linspace(0, 1, endpoint=True, num=len(xs))
        ax.plot(xs, ix, **meta[l])
        # Axis cosmetics are re-applied on every inner iteration; the second
        # pass repeats the same settings.
        ax.set_yscale("log")
        # x ticks every 0.1, printed with one decimal.
        tk = np.linspace(0, 1, num=11)
        ax.set_xticks(tk)
        ax.set_xticklabels([f"{i:.1f}" for i in tk])
        ax.set_xlabel(f"Fraction of hits")
        ax.set_ylabel(f"Fraction of all samples")
        ax.set_title(f"Top {k} hits")
        # ax.grid(True, which='both')
        ax.legend()

plt.suptitle(f"Fraction of hits matching labels with query")
plt.tight_layout()
# Trailing no-op; presumably there so the cell ends on a statement rather than
# an expression whose value would be echoed.
pass

### Relevance@K

In [None]:
def relevance_at_k(res):
    """Average the value at index 2 of the first K entries of `res`.

    The average is taken for K = 1, 10, 50 and 100 and returned as a
    `pandas.Series` indexed "top-1", "top-10", "top-50", "top-100".  When `res`
    has fewer than K entries, all of them are averaged.
    """
    averages = {}
    for cutoff in (1, 10, 50, 100):
        head = res[:cutoff]
        averages[f"top-{cutoff}"] = np.mean([entry[2] for entry in head])
    return pd.Series(averages)

#### Table summary

In [None]:
def summarize(df, qs):
    """Describe the distribution of each "top-K" column of `df`.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns "top-1", "top-10", "top-50" and "top-100".
    qs : sequence of float
        Quantiles given as fractions in [0, 1] (0.95 -> 95th percentile).

    Returns
    -------
    pandas.DataFrame
        One row per "top-K" column with the columns "mean", "std",
        "percentiles", "skew" and "kurtosis".  "percentiles" is a dict mapping
        the percentile (0-100, plain Python float) to its value rounded to
        three decimals (also a plain Python float).
    """

    def summarize_single(xs, qs):
        # Fractions -> percentages.  Rounding removes float noise from the
        # scaling (e.g. 0.07 * 100 == 7.000000000000001) so keys can be looked
        # up by their nominal value, such as d[7] or d[95].
        pct = np.round(np.asarray(qs, dtype=float) * 100, 10)
        ps = np.round(np.percentile(xs, pct), 3)
        return {
            "mean": np.mean(xs),
            "std": np.std(xs),
            # Built-in floats keep the dict's text form parseable by
            # ast.literal_eval after a CSV round trip, whatever repr the
            # installed NumPy uses for its scalars.
            "percentiles": dict(zip(pct.tolist(), ps.tolist())),
            "skew": stats.skew(xs),
            "kurtosis": stats.kurtosis(xs),
        }

    # ---

    out = {}

    for k in [1, 10, 50, 100]:
        t = f"top-{k}"
        out[t] = summarize_single(df[t].to_numpy(), qs)

    # Rows are the "top-K" names, columns are the statistics.
    return pd.DataFrame(out).transpose()

In [None]:
%%time

# Hash sets used to split each experiment's results into clean / dirty queries.
datasets = {
    "unl_vs_train": {
        "clean": unl_clean,
        "dirty": unl_dirty,
    },
    "test_vs_traintest": {
        "clean": test_clean,
        "dirty": test_dirty,
    },
}

# Quantiles as fractions: 0.01 followed by 0.05, 0.10, ..., 1.0.
qs = np.array([0.01, *np.linspace(0, 1, endpoint=True, num=21)[1:]])  # percentiles
out = []

for kind in ["class", "fam"]:
    # The loop variable is deliberately not called `dataset`: that name holds
    # the sample DataFrame loaded at the top of the notebook and must survive
    # this cell.
    for dataset_name, labels in datasets.items():
        # Raw string so `\w`, `\d` and `\.` reach the regex engine untouched;
        # the dataset name and the dot before "pkl" are matched literally.
        name_pattern = re.compile(
            rf"{re.escape(dataset_name)}_top_100_rank_(?P<rank_by>\w+)"
            r"_occur_(?P<thr>0\.\d+)_rel_(?P<rel>\w+)\.pkl"
        )

        for file in sorted(
            Path(f"../data/processed/eval-results-fix/{kind}-ranking").glob(
                f"{dataset_name}*.pkl"
            )
        ):
            # read file and append the per-query "top-K" relevance columns
            df = pd.read_pickle(file)
            df = pd.concat([df, df["results"].progress_apply(relevance_at_k)], axis=1)

            # extract the experiment settings from the file name
            m = name_pattern.match(file.name)
            assert m is not None, f"unexpected result file name: {file.name}"
            attr = m.groupdict()
            assert attr["rank_by"].lower() == kind

            # summarise results for clean (0), dirty (1) and their union ("both")
            for l, subset in enumerate(
                (labels["clean"], labels["dirty"], labels["clean"] | labels["dirty"])
            ):
                # `subset` is referenced by name inside the query string.
                tmp_df = df.query("sha256.isin(@subset)")
                out.append(
                    summarize(tmp_df, qs)
                    .reset_index(names="Top-K")
                    .assign(
                        **{
                            "dataset": dataset_name,
                            "co_occur_thr": float(attr["thr"]),
                            "rank": attr["rank_by"],
                            "relevance_func": attr["rel"],
                            "label": "both" if l == 2 else l,
                            "subset_size": len(tmp_df),
                        }
                    )
                )

# Single long table: one row per (file, subset, top-K).
out = pd.concat(out, ignore_index=True)

In [None]:
# out.to_csv("../data/processed/eval-results-fix/paper_results_all_split.csv", index=False)

In [None]:
# Reload the stored summary table; it replaces any `out` already in memory.
out = pd.read_csv("../data/processed/eval-results-fix/paper_results_all_split.csv")

# Percentiles kept for display.  The column holds the text form of a dict, so
# each cell is parsed first and then reduced to the selected entries, in order.
ps = [1, 10, 50, 95]
out["percentiles"] = (
    out["percentiles"]
    .map(ast.literal_eval)
    .map(lambda parsed: list(map(parsed.__getitem__, ps)))
)
out

In [None]:
# Build the side-by-side (label 0 | label 1) table for one experiment.
DATASET = "unl_vs_train"
print("%", DATASET)

# Row filters, joined with "&" into a single query expression.
# NOTE(review): the cell that builds the summary names this column "Top-K";
# confirm that the CSV on disk really has a `top_k` header.
_base = ['(rank == "FAM")', "(co_occur_thr == 0.9)", '(top_k != "top-50")']
_extra = [f'(dataset == "{DATASET}")', ('label != "both"')]
# "skew" has to be selected here, otherwise the rounding below raises KeyError.
cols = ["relevance_func", "top_k", "mean", "std", "skew", "percentiles", "label"]

cur = out.query("&".join(_base + _extra))[cols].reset_index(drop=True)
cur["mean (std)"] = cur.apply(
    lambda row: f'{row["mean"]:.3f} ({row["std"]:.3f})', axis=1
)
cur["skew"] = cur["skew"].round(3)
cur = cur.drop(["mean", "std"], axis=1)

# Key both halves by (relevance_func, top_k) before joining them column-wise.
# Concatenating on the positional row index would pair nothing up, because the
# label-0 and label-1 rows occupy disjoint positions in `cur`.
_key = ["relevance_func", "top_k"]
_shown = ["mean (std)", "percentiles"]
c0 = cur.query('label == "0"').set_index(_key)[_shown]
c1 = cur.query('label == "1"').set_index(_key)[_shown]
cur = pd.concat([c0, c1], axis=1)

# cur.query('label == "1"')
# print(cur.to_latex().replace('top-', '').replace('[', '').replace(']', ''))

#### Plots

In [None]:
def plot_results(xs: list[float], ax1):
    """Draw an ECDF line and a percent histogram of `xs` on a shared x axis.

    The ECDF goes on `ax1` (left y axis, labelled "ECDF"); the histogram goes
    on a twin axis created from `ax1` (right y axis, labelled "Percent").
    """
    ordered = sorted(xs)
    # Evenly spaced values from 0 to 1, one per sorted sample.
    cumulative = np.linspace(0, 1, num=len(ordered))

    ax1.plot(ordered, cumulative, lw=2, color="skyblue")
    ax1.set_ylabel("ECDF")

    hist_ax = ax1.twinx()
    sns.histplot(
        ordered, stat="percent", bins=32, alpha=0.9, ax=hist_ax, color="salmon"
    )
    hist_ax.set_ylabel("Percent")


def ecdf(data):
    """Return every 100th point of the empirical CDF of `data`.

    The result is a dict with "x" (sorted values) and "y" (evenly spaced
    values from 0 to 1, one per sample), both thinned to every 100th element
    starting with the first.
    """
    every_100th = slice(None, None, 100)
    ordered = sorted(data)
    levels = np.linspace(0, 1, num=len(data))
    return dict(x=ordered[every_100th], y=levels[every_100th])


def show_exp(df, clean_hashes, dirty_hashes, K, ax):
    df["relevance"] = df["results"].apply(lambda res: np.mean([r[2] for r in res[:K]]))
    clean = df.query("sha256.isin(@clean_hashes)")["relevance"]
    dirty = df.query("sha256.isin(@dirty_hashes)")["relevance"]
    print(f"{len(clean)=:,d} | {len(dirty)=:,d}")

    bins = np.linspace(0, 1, 21)
    stat = "percent"

    fig, ax = plt.subplots(figsize=(12, 6), nrows=2, ncols=2, sharex=False)
    ax[1, 1].remove()
    ax[1, 0].invert_yaxis()

    ax[0, 0].set_xlabel(f"Precision@{K}")
    ax[1, 0].set_xlabel(f"Precision@{K}")
    sns.histplot(
        clean, ax=ax[0, 0], bins=bins, stat=stat, color="skyblue", label="~clean"
    )
    sns.histplot(
        dirty, ax=ax[1, 0], bins=bins, stat=stat, color="salmon", label="~dirty"
    )

    ax[0, 0].legend()
    ax[1, 0].legend()
    plt.subplots_adjust(hspace=0)
    ax[1, 0].set_xticks(bins[::2])
    ax[0, 0].set_title("Histogram of precision")

    sns.lineplot(**ecdf(clean), ax=ax[0, 1], lw=2, color="skyblue", label="~clean")
    sns.lineplot(**ecdf(dirty), ax=ax[0, 1], lw=2, color="salmon", label="~dirty")
    ax[0, 1].set_xticks(bins[::2])
    ax[0, 1].set_xlabel(f"Precision@{K}")
    ax[0, 1].set_ylabel("Fraction of samples")
    ax[0, 1].set_title("Empirical CDF")