In [1]:
import string

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from wordcloud import WordCloud
import numpy as np
import os

# Utils

In [2]:
def find_files(path, ext):
    files = []
    for file in os.listdir(path):
        if file.endswith(ext):
            files.append(file)
    return files

In [3]:
from typing import Iterable, Union

def get_most_impacting(report: dict, n=2):
    """
    Get the most impacting tokens from experiments.

    Args:
        report (dict): The input dictionary containing experiment data >> dict_keys(['idx', 'tokens', 'importance', 'logits', 'predicted_label', 'true_label'])
        n (int, optional): The number of tokens to consider. Defaults to 2.

    Returns:
        dict: A dictionary containing analysis results.
    """
    _report = {
        "predict": dict(),
        "true": dict(),
        "at_all": dict(),
    }
    for pred, true, importance, tokens in zip(
        report["predicted_label"],
        report["true_label"],
        report["importance"],
        report["tokens"]
    ):

        _iter_pred = zip(np.array(tokens), importance[pred])
        data_pred = sorted([(token, importance) for token, importance in _iter_pred],
                           key= lambda x: x[1],
                           reverse=True)
        data_pred = data_pred[:n]

        if pred==true:
            data_true = data_pred.copy()
            data_all = data_pred.copy()
        else:
            _iter_true = zip(np.array(tokens), importance[true])
            data_true = sorted([(token, importance) for token, importance in _iter_true],
                               key= lambda x: x[1],
                               reverse=True)
            data_true = data_true[:n]

            data_all = sorted(data_true + data_pred,
                              key= lambda x: x[1],
                              reverse=True)
            data_all = data_all[:n]

        for p, t, a in zip(data_pred, data_true, data_all):
            if p[1] > 0:
                _report["predict"][p[0]] = _report["predict"].get(p[0], 0) + p[1]
            if t[1] > 0:
                _report["true"][t[0]] = _report["true"].get(t[0], 0) + t[1]
            if a[1] > 0:
                _report["at_all"][a[0]] = _report["at_all"].get(a[0], 0) + a[1]
            
    return {
        "predict": sorted(_report["predict"].items(), key=lambda x: x[1], reverse=True)[:n],
        "true": sorted(_report["true"].items(), key=lambda x: x[1], reverse=True)[:n],
        "at_all": sorted(_report["at_all"].items(), key=lambda x: x[1], reverse=True)[:n],
    }


def get_importance(report: dict, d_tokens: Iterable):
    _report = {}
    for dt in d_tokens:
        _report[dt] = {
            "tp": list(),
            "tn": list(),
            "fp": list(),
            "fn": list(),
            "pl": list(),
            "nl": list(),
        }
    for pred, true, importance, tokens in zip(
        report["predicted_label"],
        report["true_label"],
        report["importance"],
        report["tokens"]
    ):
        for token, importance in zip(tokens, importance):
            token = token.replace('##', '').replace('Ġ', '')
            for dt in d_tokens:
                if dt == token:
                    _report[dt]["pl"].append(importance[1])
                    _report[dt]["nl"].append(importance[0])

            if pred==true:
                if pred > 0:
                    for dt in d_tokens:
                        if dt == token:
                            _report[dt]["tp"].append(importance[1])
                else:
                    for dt in d_tokens:
                        if dt == token:
                            _report[dt]["tn"].append(importance[0])
            else:
                if pred > 0:
                    for dt in d_tokens:
                        if dt == token:
                            _report[dt]["fp"].append(importance[1])
                else:
                    for dt in d_tokens:
                        if dt == token:
                            _report[dt]["fn"].append(importance[0])

    return _report

In [4]:
def get_importance_all(report: dict, metric_function=np.mean, split=True):
    _report = {}

    for pred, true, importance, tokens in zip(
        report["predicted_label"],
        report["true_label"],
        report["importance"],
        report["tokens"]
    ):
        for token, importance in zip(tokens, importance):
            if token not in _report:
                _report[token] = {
                    "pl": list(),
                    "nl": list(),
                }

            _report[token]["pl"].append(importance[1])
            _report[token]["nl"].append(importance[0])

    # Calculate the metric for each token and store them in a list of tuples
    pl_metrics, nl_metrics = [], []
    for token, values in _report.items():
        pl_metrics.append((token, metric_function(values["pl"])))
        nl_metrics.append((token, metric_function(values["nl"])))

    # Sort the list of tuples by the metric value
    pl_metrics = sorted(pl_metrics, key=lambda x: x[1], reverse=True)
    nl_metrics = sorted(nl_metrics, key=lambda x: x[1], reverse=True)

    if split:
        pl_metrics = pl_metrics[:NUMBER_OF_IMPORTANT_BAR]
        nl_metrics = nl_metrics[:NUMBER_OF_IMPORTANT_BAR]

    return {
        "pl": pl_metrics,
        "nl": nl_metrics,
    }


In [5]:
def get_most_popular(report: dict, n=2):
    _report = {
        "predict": dict(),
        "true": dict(),
        "at_all": dict(),
    }
    for pred, true, importance, tokens in zip(
        report["predicted_label"],
        report["true_label"],
        report["importance"],
        report["tokens"]
    ):

        _iter_pred = zip(np.array(tokens), importance[pred])
        data_pred = sorted([(token, importance) for token, importance in _iter_pred],
                           key= lambda x: x[1],
                           reverse=True)
        data_pred = data_pred[:n]

        if pred==true:
            data_true = data_pred.copy()
            data_all = data_pred.copy()
        else:
            _iter_true = zip(np.array(tokens), importance[true])
            data_true = sorted([(token, importance) for token, importance in _iter_true],
                               key= lambda x: x[1],
                               reverse=True)
            data_true = data_true[:n]

            data_all = sorted(data_true + data_pred,
                              key= lambda x: x[1],
                              reverse=True)
            data_all = data_all[:n]

        for p, t, a in zip(data_pred, data_true, data_all):
            if p[1] > 0:
                _report["predict"][p[0]] = _report["predict"].get(p[0], 0) + 1
            if t[1] > 0:
                _report["true"][t[0]] = _report["true"].get(t[0], 0) + 1
            if a[1] > 0:
                _report["at_all"][a[0]] = _report["at_all"].get(a[0], 0) + 1

    return {
        "predict": sorted(_report["predict"].items(), key=lambda x: x[1], reverse=True)[:n],
        "true": sorted(_report["true"].items(), key=lambda x: x[1], reverse=True)[:n],
        "at_all": sorted(_report["at_all"].items(), key=lambda x: x[1], reverse=True)[:n],
    }


In [6]:
def read_report(path):
    report = np.load(path, allow_pickle=True).item()
    report["tokens"] = [[token.replace('##', '').replace('Ġ', '') for token in tokens ]for tokens in report["tokens"]]
    return report


# Analysis

In [7]:
BASE_PATH = "."
DATASET_NAME = "sst2"
EXPERIMENTS_PATH = os.path.join(BASE_PATH, DATASET_NAME)

FILES = find_files(EXPERIMENTS_PATH, ".npy")

NUMBER_OF_IMPORTANT_WC = 300
NUMBER_OF_IMPORTANT_BAR = 15

DEMO = False

In [8]:
reports = {
    file.split('.')[0]: read_report(os.path.join(EXPERIMENTS_PATH, file)) for file in FILES
}

## Word cloud

### Most important

In [9]:
save_wc_path = os.path.join(EXPERIMENTS_PATH, "most_important")
if os.path.exists(save_wc_path):
    print("Path already exists")
else:
    os.mkdir(save_wc_path)
    print("Path created")

Path already exists


In [10]:
for exp_name, report in reports.items():
    rep_imp = get_most_impacting(report, n=NUMBER_OF_IMPORTANT_WC)
    _name = f"{exp_name}_imp.png"
    path = os.path.join(save_wc_path, _name)
    fig, axs = plt.subplots(1, len(rep_imp), figsize=(4*len(rep_imp), 4))
    for i, k in enumerate(rep_imp.keys()):
        wordcloud = WordCloud(width=800, height=800, background_color="white")
        wordcloud.generate_from_frequencies(dict(rep_imp[k]))

        axs[i].imshow(wordcloud, interpolation="bilinear")
        axs[i].axis("off")

        axs[i].set_title(k)

    fig.suptitle(exp_name, fontsize=16)
    plt.tight_layout(pad=1.0)
    if DEMO:
        plt.show()
    else:
        plt.savefig(path, format="png")
    plt.close()

In [11]:
for exp_name, report in reports.items():
    rep_imp = get_most_impacting(report, n=NUMBER_OF_IMPORTANT_BAR)
    _name = f"{exp_name}_imp_bar.png"
    path = os.path.join(save_wc_path, _name)
    fig, axs = plt.subplots(1, len(rep_imp), figsize=(6*len(rep_imp), 6))

    for i, k in enumerate(rep_imp.keys()):

        tokens = [item[0] for item in rep_imp[k]]
        freqs = [item[1] for item in rep_imp[k]]

        axs[i].bar(tokens, freqs)
        axs[i].set_title(k)
        axs[i].set_xticklabels(tokens, rotation=90, ha='center')


        axs[i].set_title(k)

    fig.suptitle(exp_name, fontsize=16)
    plt.tight_layout(pad=1.0)
    if DEMO:
        plt.show()
    else:
        plt.savefig(path, format="png")
    plt.close()


  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')


### Most popular

In [12]:
save_mp_path = os.path.join(EXPERIMENTS_PATH, "most_popular")
if os.path.exists(save_mp_path):
    print("Path already exists")
else:
    os.mkdir(save_mp_path)
    print("Path created")

Path already exists


In [13]:
for exp_name, report in reports.items():
    rep_mp = get_most_popular(report, n=NUMBER_OF_IMPORTANT_WC)
    _name = f"{exp_name}_mp.png"
    path = os.path.join(save_mp_path, _name)
    fig, axs = plt.subplots(1, len(rep_mp), figsize=(4*len(rep_mp), 4))
    for i, k in enumerate(rep_mp.keys()):
        wordcloud = WordCloud(width=800, height=800, background_color="white")
        wordcloud.generate_from_frequencies(dict(rep_mp[k]))

        axs[i].imshow(wordcloud, interpolation="bilinear")
        axs[i].axis("off")

        axs[i].set_title(k)

    fig.suptitle(exp_name, fontsize=16)
    plt.tight_layout(pad=1.0)
    if DEMO:
        plt.show()
    else:
        plt.savefig(path, format="png")
    plt.close()


## for particular set of tokens

In [14]:
EXP_and_NAME = {
    # "true positive": 'tp',
    # "true negative": 'tn',
    # "false positive": 'fp',
    # "false negative": 'fn',
    # "positive predicted": 'tp+fp',
    # "negative predicted": 'tn+fn',
    "real positive": 'tp+fn',
    "real negative": 'tn+fp',
    "effect on positivity": "pl",
    "effect on negativity": "nl",
}

In [15]:
save_mp_path = os.path.join(EXPERIMENTS_PATH, "token effect")
if os.path.exists(save_mp_path):
    print("Path already exists")
else:
    os.mkdir(save_mp_path)
    print("Path created")

Path already exists


In [16]:
token_group, token_set = [
    ("punctuation", list(string.punctuation)),
    ("neutral tokens", ["a", "an", "what", "the", "you", "one", "it", "this", "movie", "i", "with", "if"]),
    ('multy type', [
        "respect", "impact", "loved"
        "a", "an", "the", "one", "it", "this",
        ".", ":", ";", "(", ")", "?",
        "killing", "stupid",
    ])
][2]

_save_mp_path = os.path.join(save_mp_path, token_group)

if not os.path.exists(_save_mp_path):
    os.mkdir(_save_mp_path)
    print("Path created")


In [17]:
for exp_name, report in reports.items():
    rep_mp = get_importance(report, d_tokens=token_set)
    _name = f"{exp_name}_{token_group}.png"
    path = os.path.join(_save_mp_path, _name)
    fig, axs = plt.subplots(len(EXP_and_NAME)//2, 2, figsize=(14, 4*len(EXP_and_NAME)))

    for idx, exp in enumerate(EXP_and_NAME.keys()):
        i,j = idx//2, idx%2
        label_list = list(str(EXP_and_NAME[exp]).split('+'))
        temp_dict = dict()
        for ts in token_set:
            for l in label_list:
                if rep_mp[ts][l]:
                    temp_dict[ts] = list(temp_dict.get(ts, [])) + list(rep_mp[ts][l])

        axs[i][j].boxplot(temp_dict.values(), labels=temp_dict.keys())
        axs[i][j].set_title(exp)

    fig.suptitle(exp_name, fontsize=16)
    plt.tight_layout(pad=1.0)
    if DEMO:
        plt.show()
    else:
        plt.savefig(path, format="png")
    plt.close()

In this section, we would get importance with 'get_importance' function and then for [pl, nl], we would calculate mean for each token and plot 'NUMBER_OF_IMPORTANT_BAR' bar with most impact on that label

In [18]:
_save_path = os.path.join(save_mp_path, "most important aggregated")
if not os.path.exists(_save_path):
    os.mkdir(_save_path)
    print("Path created")

for exp_name, report in reports.items():
    rep_mp = get_importance_all(report, metric_function=np.mean)
    _name = f"{exp_name}_aggr_bar.png"
    path = os.path.join(_save_path, _name)
    fig, axs = plt.subplots(1, len(rep_mp), figsize=(6*len(rep_mp), 6))

    for i, k in enumerate(rep_mp.keys()):
        tokens = [item[0] for item in rep_mp[k]]
        freqs = [item[1] for item in rep_mp[k]]

        axs[i].bar(tokens, freqs)
        axs[i].set_title(k)
        axs[i].set_xticklabels(tokens, rotation=90, ha='center')

        axs[i].set_title(k)

    fig.suptitle(exp_name, fontsize=16)
    plt.tight_layout(pad=1.0)
    if DEMO:
        plt.show()
    else:
        plt.savefig(path, format="png")
    plt.close()


  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')
  axs[i].set_xticklabels(tokens, rotation=90, ha='center')


## generate [token-mean-variance] dataset

In [19]:
from IPython.display import display, HTML
from pandas.plotting import table

_save_path = os.path.join(save_mp_path, "trends")
if not os.path.exists(_save_path):
    os.mkdir(_save_path)
    print("Path created")

for exp_name, report in reports.items():
    rep_mp = get_importance_all(
        report,
        metric_function=np.mean,
        split=False
    )
    rep_vp = get_importance_all(
        report,
        metric_function=np.var,
        split=False
    )
    _name = f"{exp_name}.csv"
    path = os.path.join(_save_path, _name)
    img_path = os.path.join(_save_path, exp_name + "-desc.png")

    tokens = [token for token, _ in rep_mp['pl']]
    data = [
        {"token": token, "mean0": 0, "mean1": 0, "variance0": 0, "variance1": 0} \
        for token in tokens
    ]

    for mm1, vv1, mm0, vv0 in zip(rep_mp["pl"], rep_vp["pl"], rep_mp["nl"], rep_vp["nl"]):
        idx_v = tokens.index(vv1[0])
        idx_m = tokens.index(mm1[0])
        data[idx_v]["variance1"] = vv1[1]
        data[idx_m]["mean1"] = mm1[1]
        idx_v = tokens.index(vv0[0])
        idx_m = tokens.index(mm0[0])
        data[idx_v]["variance0"] = vv0[1]
        data[idx_m]["mean0"] = mm0[1]

    df = pd.DataFrame.from_dict(data)

    fig, ax = plt.subplots(figsize=(4, 3))
    ax.axis('off')
    table(ax, df.describe(), loc='center')

    if DEMO:
        print(exp_name)
        display((df.describe()))
        # display(HTML(df.to_html(index=False)))
    else:
        # df.to_csv(path, index=False)
        plt.savefig(img_path, dpi=300, bbox_inches='tight')
        plt.close()


Now, add variance to the dataset above