# Analysis of Model Predictions

In [1]:
import sys, os, json
from collections import defaultdict, Counter
from itertools import product, combinations, permutations
from importlib import reload
from IPython.display import HTML, display
import tabulate
import matplotlib.pyplot as plt
from typing import List, Any, Dict, Tuple, Callable, Iterable, Union, Set
Numeric = Union[int, float]
from graph import Graph, Node     # mtool graphs
import networkx as nx, numpy as np, pandas as pd
from networkx import NetworkXNoPath

# HACK: functional-style postfix helpers on the built-in ``list`` type, so that
# ``lst.map(f)``, ``lst.filter(f)``, ``lst.count(f)`` and ``lst.size()`` work.
import gc


def postfix_map(l, f):
    """Return a new list holding ``f(item)`` for every item of `l`."""
    return list(map(f, l))


def postfix_filter(l, f):
    """Return a new list with the items of `l` for which ``f(item)`` is truthy."""
    return list(filter(f, l))


def postfix_count(l, f):
    """Count items of `l`, by predicate or by value.

    If `f` is callable, return the number of items for which ``f(item)`` is
    truthy. Otherwise return the number of items equal to `f`, so that the
    ordinary ``lst.count(value)`` call keeps working after this function
    replaces ``list.count`` below.

    NOTE: a callable `f` is always treated as a predicate, so occurrences of
    a callable object stored in the list cannot be counted through this
    function.
    """
    if callable(f):
        return sum(1 for item in l if f(item))
    # Identity-or-equality comparison, the same test used by ``in``.
    return sum(1 for item in l if item is f or item == f)


def postfix_size(l):
    """Return the number of items in `l`."""
    return len(l)


# ``list.__dict__`` is a read-only proxy; the first object it refers to is
# used here as the writable namespace of ``list``.
# NOTE(review): this relies on interpreter internals - confirm it still works
# when upgrading Python.
_list_namespace = gc.get_referents(list.__dict__)[0]
_list_namespace['map'] = postfix_map
_list_namespace['filter'] = postfix_filter
_list_namespace['count'] = postfix_count
_list_namespace['size'] = postfix_size

# Root of the project inside the nlp-architect repo, and its analysis folder.
libert_dir = "/data/home/ayalklei/nlp-architect/nlp_architect/models/libert"
analysis_dir = libert_dir + "/analysis"

# Useful "meta" data: number of sentences per domain.
# (Can be computed from the line counts of the raw_sentences.txt files:
#  `wc $libert_dir/analysis/raw_sentences/*.txt`)
num_sents = dict(device=3834, restaurants=5842, laptops=3846)
# Domain names, in the insertion order of `num_sents`.
domains = [*num_sents]
ud_enhancement_formalisms = [
    "eud",
    "eud_pp",
    "bart",
    "eud_pp_bart",
]
# useful general utils
def display_table(table):
    """ Render `table` as an HTML table and show it in the notebook output. """
    html_markup = tabulate.tabulate(table, tablefmt='html')
    display(HTML(html_markup))

def display_ndict(nested_dict: Dict[str, Dict[str, Any]], 
                  with_mean=True, 
                  pprint: Callable = None,
                  precision: int = 4):
    """ Display two-level nested dict as a pretty table.

    The outer keys become row labels and the keys of the first inner dict
    become column labels; every inner dict is indexed with those column keys.

    :param nested_dict: mapping of row label -> {column label -> value}.
        An empty dict is displayed as a header-only table.
    :param with_mean: if true, append a "mean" column holding the mean of
        each row's raw (un-formatted) values.
    :param pprint: optional cell formatter applied to every value (and to the
        row mean). By default floats are rounded to `precision` decimals and
        every other value is shown unchanged.
    :param precision: number of decimals kept by the default formatter.
    """
    if not pprint:
        def pprint(x):
            return float(f"{x:.{precision}f}") if isinstance(x, float) else x

    row_labels = list(nested_dict.keys())
    # Column labels come from the first row; an empty dict has no columns.
    column_labels = list(next(iter(nested_dict.values()), {}).keys())

    header = ["-"] + column_labels
    if with_mean:
        header.append("mean")
    as_tabular = [header]
    for row in row_labels:
        raw_values = [nested_dict[row][col] for col in column_labels]
        cells = [row] + [pprint(value) for value in raw_values]
        if with_mean:
            # Average the raw values, not the formatted cells: a custom
            # `pprint` may return strings, which cannot be averaged.
            cells.append(pprint(np.mean(raw_values)))
        as_tabular.append(cells)
    display_table(as_tabular)

def display_absa_graph(graph: Graph, method="displacy"):
    """ Visualize `graph` and print its input sentence.

    :param graph: the graph to show; this function reads `graph.input`,
        `graph.opinion_spans` and `graph.aspect_spans`, and calls
        `graph.dot`, `graph.tikz` or `graph.displacy` depending on `method`.
    :param method: "dot" writes "dot_example.dot" and returns a graphviz
        `Source` loaded from that file; "tikz" writes "tikz_example.tex" and
        returns None; any other value prints the opinion and aspect spans
        with their text and renders the graph with `graph.displacy`.
    :return: a graphviz `Source` for method "dot", otherwise None.
    """
    print(f"Sentence: {graph.input}")
    if method == "dot":
        dot_fn = "dot_example.dot"
        # Close the file before it is read back below, so the written content
        # is guaranteed to be flushed to disk.
        with open(dot_fn, "w") as dot_file:
            graph.dot(dot_file)
        # Imported lazily: graphviz is only needed for this method.
        from graphviz import Source
        return Source.from_file(dot_fn)
    if method == "tikz":
        tikz_fn = "tikz_example.tex"
        with open(tikz_fn, "w") as tikz_file:
            graph.tikz(tikz_file)    # write tikz latex file
        # Not rendered in the notebook: %load_ext tikzmagic is not working.
        return None

    # Split the sentence once instead of once per token.
    tokens = graph.input.split(" ")

    def span_texts(spans):
        # Each span is unpacked into range(), i.e. it holds token indices.
        return [' '.join(tokens[i] for i in range(*span)) for span in spans]

    print("opinions: ", graph.opinion_spans, span_texts(graph.opinion_spans))
    print("aspects: ", graph.aspect_spans, span_texts(graph.aspect_spans))
    graph.displacy(jupyter=True, options={"compact":True, "distance":100})
            
def plot_hist_with_long_labels(array, bins=None, title=""):
    """ Display a histogram of `array` with x-axis tick labels rotated by 270 degrees.

    When `bins` is not given (or is falsy), one bin per element of `array`
    is used.
    """
    import matplotlib.pyplot as plt
    n_bins = bins if bins else len(array)
    _, axes = plt.subplots()
    axes.hist(array, bins=n_bins, color='blue', edgecolor='black')
    for label in axes.get_xticklabels():
        label.set_rotation(270)
    axes.set_title(title)

def plot_bar_with_long_labels(labels, values, title=""):
    """ Display a bar chart with rotated x-labels.

    :param labels: one label per bar, shown on the x-axis rotated by 270 degrees.
    :param values: bar heights, in the same order as `labels`.
    :param title: title of the plot.
    """
    import matplotlib.pyplot as plt
    indexes = np.arange(len(labels))
    fig, ax = plt.subplots()
    ax.bar(indexes, values, color = 'blue', edgecolor = 'black')
    # Set the ticks and their labels on this axes explicitly (not through the
    # pyplot "current axes"), and rotate the labels in the same call so the
    # rotation is applied to the final labels.
    ax.set_xticks(indexes)
    ax.set_xticklabels(labels, rotation=270)
    ax.set_title(title)

def plot_counter_with_long_labels(c: Counter):
    """ Display the counts in `c` as a bar chart, one bar per key.

    An empty counter produces an empty chart instead of raising.
    """
    # keys() and values() iterate in the same order, so labels and bar
    # heights stay aligned.
    labels = list(c.keys())
    values = list(c.values())
    plot_bar_with_long_labels(labels, values)

In [31]:
# Evaluation utils
def eval_p_r_f1(tp, fp, fn):
    """ Compute precision, recall and F1 from raw counts.

    :param tp: number of true positives.
    :param fp: number of false positives.
    :param fn: number of false negatives.
    :return: tuple (precision, recall, f1). Precision is None when
        tp + fp == 0 and recall is None when tp + fn == 0. F1 is None when
        either of them is None, and 0.0 when both are 0.0.
    """
    p = float(tp) / (tp + fp) if tp + fp > 0 else None
    r = float(tp) / (tp + fn) if tp + fn > 0 else None
    if p is None or r is None:
        f1 = None
    elif p + r == 0:
        # Both scores are defined but zero: report 0.0 rather than None, and
        # avoid dividing by zero in the harmonic mean.
        f1 = 0.0
    else:
        f1 = (2*p*r)/(p+r)
    return p, r, f1
def evaluate_sets(gold, predicted) -> Tuple[float, float, float]: # precision, recall, F1
    """ Compare a predicted collection against a gold one, as sets.

    :param gold: iterable of hashable gold items (duplicates are ignored).
    :param predicted: iterable of hashable predicted items (duplicates are ignored).
    :return: whatever `eval_p_r_f1` returns for the resulting
        true-positive, false-positive and false-negative counts.
    """
    # Materialize each input exactly once, so one-shot iterables such as
    # generators are not already exhausted when the later counts are computed.
    gold_items, predicted_items = set(gold), set(predicted)
    tp = len(gold_items & predicted_items)
    fp = len(predicted_items - gold_items)
    fn = len(gold_items - predicted_items)
    return eval_p_r_f1(tp, fp, fn)
def pretty_eval(p,r,f1):
    """ Format precision, recall and F1 as a one-line string.

    Each score is shown with 2 significant digits. A score that is None
    (undefined) is shown as "N/A" instead of raising.
    """
    def fmt(score):
        return "N/A" if score is None else f"{score:.2}"
    return f"P: {fmt(p)}   R: {fmt(r)}   F1: {fmt(f1)}"


In [24]:
# Select the experiment (model run) whose predictions are analyzed.
# Alternative run id: "bert-amtl-AT_Mon_Dec_28_00:09:59"
exp_id = "li-biafpatt-T10-L11_Fri_Jan_01_11:27:38"
formalism = "dm"
exp_base_dir = f"{libert_dir}/logs/{exp_id}/{formalism}"

# Source and target domains of the cross-domain setting.
src_domain = "laptops"
tgt_domain = "restaurants"
dataset = f"{src_domain}_to_{tgt_domain}"
dataset_dir = f"{exp_base_dir}/{dataset}"
predictions_fn = f"{dataset_dir}/prediction-test.txt"

# The predictions file holds one tab-separated row per line, with groups of
# rows separated by a blank line. `blocks` is a list of such groups, each a
# list of rows, each row a list of its tab-separated fields.
with open(predictions_fn) as fin:
    raw_predictions = fin.read().strip()
blocks = [
    [row.split("\t") for row in chunk.split("\n")]
    for chunk in raw_predictions.split("\n\n")
]
"""
blocks = 
[[['great', 'B-OP', 'O'],
  ['taste', 'O', 'O']],
 [['service', 'B-OP', 'B-ASP'],
  ['-', 'O', 'O'],
  ['friendly', 'B-OP', 'B-OP'],
  ['and', 'O', 'O'],
  ['attentive', 'B-OP', 'B-OP'],
  ['.', 'O', 'O']]
 ...]
"""

"\nblocks = \n[[['great', 'B-OP', 'O'],\n  ['taste', 'O', 'O']],\n [['service', 'B-OP', 'B-ASP'],\n  ['-', 'O', 'O'],\n  ['friendly', 'B-OP', 'B-OP'],\n  ['and', 'O', 'O'],\n  ['attentive', 'B-OP', 'B-OP'],\n  ['.', 'O', 'O']]\n ...]\n"

In [33]:
# Token-level error analysis. For every label (e.g. "OP", "ASP") collect the
# (sentence, token index) pairs that are true positives, false positives and
# false negatives, ignoring the B/I difference of the BIO tags.
TPs, FPs, FNs = defaultdict(list), defaultdict(list), defaultdict(list)


def lbl(bio_tag):
    """ Return the label part of a BIO tag ("B-OP" -> "OP"); tags without "-" map to "O". """
    return bio_tag.split("-")[1] if "-" in bio_tag else "O"


for sent in blocks:
    # Each row of a block is unpacked as (token, predicted tag, gold tag).
    tokens, pred_lbls, gold_lbls = zip(*sent)
    sent_str = ' '.join(tokens)
    for i, (pred_tag, gold_tag) in enumerate(zip(pred_lbls, gold_lbls)):
        pred, gold = lbl(pred_tag), lbl(gold_tag)
        if pred == gold:
            # Agreement on "O" is not counted at all.
            if gold != "O":
                TPs[gold].append((sent_str, i))
            continue
        # Disagreement: a missed gold label and/or a spurious predicted label.
        if gold != "O":
            FNs[gold].append((sent_str, i))
        if pred != "O":
            FPs[pred].append((sent_str, i))

# One summary line per label; the prefixes are padded to align the counts.
for prefix, label in (("OP:  ", "OP"), ("ASP: ", "ASP")):
    n_tp, n_fp, n_fn = len(TPs[label]), len(FPs[label]), len(FNs[label])
    print(f'{prefix}{n_tp} TPs,  {n_fp} FPs,  {n_fn} FNs;\t {pretty_eval(*eval_p_r_f1(n_tp, n_fp, n_fn))}')

OP:  1136 TPs,  463 FPs,  463 FNs;	 P: 0.71   R: 0.71   F1: 0.71
ASP: 83 TPs,  44 FPs,  2110 FNs;	 P: 0.65   R: 0.038   F1: 0.072
