# Visualization of F1 score as a function of span separation distance

In [1]:
import h5py
import numpy as np
import bokeh
from keras.models import load_model
from sklearn.metrics import f1_score
from bokeh.plotting import figure, output_file, show, output_notebook
from tensorflow import Graph, Session
from keras_self_attention import SeqWeightedAttention

Using TensorFlow backend.


In [2]:
test_data_joshi = "/home/users/pkahardipraja/data/bert_coref/data-bert-base-128/test_span_reps_joshi.128_1.h5"
test_data_baseline = "/home/users/pkahardipraja/data/bert_coref/data-bert-base-128/test_span_reps_baseline.128_1.h5"

# Row layout used throughout this notebook: every column except the last two
# is a feature, the second-to-last column is the span distance and the last
# column is the label.
with h5py.File(test_data_joshi, 'r') as f:
    # `Dataset.value` is deprecated and was removed in h5py 3.0; `[()]` reads
    # the whole dataset into memory. Indexing with `f[...]` (instead of
    # `f.get(...)`) raises a clear KeyError if the dataset is missing.
    test_data = f['span_representations'][()]

# The array is fully in memory, so the file does not need to stay open here.
x_test = test_data[:, :-2]
y_test = test_data[:, -1].astype(int)
x_dist = test_data[:, -2].astype(int)

with h5py.File(test_data_baseline, 'r') as f:
    test_data_base = f['span_representations'][()]

x_test_base = test_data_base[:, :-2]
y_test_base = test_data_base[:, -1].astype(int)



In [3]:
# Use 20 distance buckets (0-25, 26-50, ..., 476-500) so that each bucket contains more than 50 examples.
def get_output(model_path, data, bucket_dist=25, max_dist=500):
    """Evaluate a saved Keras model on `data`, one distance bucket at a time.

    Each row of `data` is read as: all columns except the last two are the
    model features, the second-to-last column is the distance used for
    bucketing, and the last column is the label.

    Buckets are the inclusive ranges [0, bucket_dist],
    [bucket_dist + 1, 2 * bucket_dist], ... up to `max_dist`.

    Args:
        model_path: Path passed to `load_model`.
        data: 2-D array laid out as described above.
        bucket_dist: Width of each distance bucket (default 25).
        max_dist: Largest upper bound of a bucket (default 500, i.e. 20
            buckets with the default width).

    Returns:
        A tuple `(acc_list, f1_list, example_list)` with one entry per bucket:
        the second value returned by `model.evaluate`, the F1 score of the
        rounded predictions, and the number of examples in the bucket. Buckets
        without examples get NaN for both scores and 0 for the count.

    Raises:
        ValueError: If `bucket_dist` is not positive.
    """
    if bucket_dist <= 0:
        # A non-positive width would never advance the bucket bounds.
        raise ValueError("bucket_dist must be positive, got {}".format(bucket_dist))

    # The distance column is the same for every bucket, so convert it once.
    token_dist = data[:, -2].astype(int)
    acc_list = []
    f1_list = []
    example_list = []

    # Load the model into its own graph/session so that several models can be
    # evaluated one after another without sharing state.
    graph = Graph()
    with graph.as_default():
        session = Session()
        try:
            with session.as_default():
                model = load_model(model_path)

                lower_dist = 0
                upper_dist = bucket_dist
                while upper_dist <= max_dist:
                    in_bucket = (token_dist >= lower_dist) & (token_dist <= upper_dist)
                    filtered_data = data[in_bucket]
                    n_examples = filtered_data.shape[0]

                    if n_examples == 0:
                        # Nothing to evaluate; keep the lists aligned with the
                        # bucket index instead of calling Keras with
                        # batch_size=0.
                        acc_list.append(float("nan"))
                        f1_list.append(float("nan"))
                    else:
                        features = filtered_data[:, :-2]
                        label = filtered_data[:, -1].astype(int)
                        # Evaluate the whole bucket as a single batch.
                        _, score = model.evaluate(features, label, batch_size=n_examples)
                        prediction = np.asarray(model.predict(features)).round()
                        acc_list.append(score)
                        f1_list.append(f1_score(label, prediction))

                    example_list.append(n_examples)
                    lower_dist = upper_dist + 1
                    upper_dist += bucket_dist
        finally:
            # Release the resources held by this model's session.
            session.close()
    return acc_list, f1_list, example_list

def get_output_base(model_path, data, bert_type, max_span_width=30, bucket_dist=25, max_dist=500):
    """Evaluate a saved two-input Keras model on `data`, per distance bucket.

    Each row of `data` is read as: all columns except the last two are the
    flattened features, the second-to-last column is the distance used for
    bucketing, and the last column is the label. The features are split in
    half and each half is reshaped to `(n, max_span_width, embed_dim)`; the
    two halves are fed to the model as `[parent_span_emb, child_span_emb]`.

    Buckets are the inclusive ranges [0, bucket_dist],
    [bucket_dist + 1, 2 * bucket_dist], ... up to `max_dist`.

    Args:
        model_path: Path passed to `load_model` (loaded with the
            `SeqWeightedAttention` custom objects).
        data: 2-D array laid out as described above.
        bert_type: "bert_base" (embedding size 768) or "bert_large"
            (embedding size 1024).
        max_span_width: Number of embedding vectors per span (default 30).
        bucket_dist: Width of each distance bucket (default 25).
        max_dist: Largest upper bound of a bucket (default 500).

    Returns:
        A tuple `(acc_list, f1_list, example_list)` with one entry per bucket:
        the second value returned by `model.evaluate`, the F1 score of the
        rounded predictions, and the number of examples in the bucket. Buckets
        without examples get NaN for both scores and 0 for the count.

    Raises:
        ValueError: If `bert_type` is not recognised or `bucket_dist` is not
            positive.
    """
    embed_dims = {"bert_base": 768, "bert_large": 1024}
    if bert_type not in embed_dims:
        # Fail early with a clear message instead of hitting an unbound
        # `embed_dim` deep inside the evaluation loop.
        raise ValueError(
            "Unknown bert_type {!r}; expected one of {}".format(bert_type, sorted(embed_dims))
        )
    embed_dim = embed_dims[bert_type]

    if bucket_dist <= 0:
        # A non-positive width would never advance the bucket bounds.
        raise ValueError("bucket_dist must be positive, got {}".format(bucket_dist))

    # The distance column is the same for every bucket, so convert it once.
    token_dist = data[:, -2].astype(int)
    span_size = max_span_width * embed_dim
    acc_list = []
    f1_list = []
    example_list = []

    # Load the model into its own graph/session so that several models can be
    # evaluated one after another without sharing state.
    graph = Graph()
    with graph.as_default():
        session = Session()
        try:
            with session.as_default():
                model = load_model(model_path, custom_objects=SeqWeightedAttention.get_custom_objects())

                lower_dist = 0
                upper_dist = bucket_dist
                while upper_dist <= max_dist:
                    in_bucket = (token_dist >= lower_dist) & (token_dist <= upper_dist)
                    filtered_data = data[in_bucket]
                    n_examples = filtered_data.shape[0]

                    if n_examples == 0:
                        # Nothing to evaluate; keep the lists aligned with the
                        # bucket index instead of calling Keras with
                        # batch_size=0.
                        acc_list.append(float("nan"))
                        f1_list.append(float("nan"))
                    else:
                        features = filtered_data[:, :-2]
                        # First half of the features is the parent span, the
                        # remainder is the child span.
                        parent_span_emb = features[:, :span_size].reshape(n_examples, max_span_width, embed_dim)
                        child_span_emb = features[:, span_size:].reshape(n_examples, max_span_width, embed_dim)
                        model_inputs = [parent_span_emb, child_span_emb]

                        label = filtered_data[:, -1].astype(int)
                        # Evaluate the whole bucket as a single batch.
                        _, score = model.evaluate(model_inputs, label, batch_size=n_examples)
                        prediction = np.asarray(model.predict(model_inputs)).round()
                        acc_list.append(score)
                        f1_list.append(f1_score(label, prediction))

                    example_list.append(n_examples)
                    lower_dist = upper_dist + 1
                    upper_dist += bucket_dist
        finally:
            # Release the resources held by this model's session.
            session.close()
    return acc_list, f1_list, example_list

In [4]:
# Saved models evaluated on the 128-token BERT-base test data.
model_path = "/project/kahardipraja/coref/bert_for_coreference_resolution/models/test_bert_base_joshi_128.h5"
model_path_base_kernel_3 = "/project/kahardipraja/coref/bert_for_coreference_resolution/models/test_bert_base_baseline_128_kernel_3.h5"
model_path_base_kernel_5 = "/project/kahardipraja/coref/bert_for_coreference_resolution/models/test_bert_base_baseline_128_kernel_5.h5"
model_path_baseline = "/project/kahardipraja/coref/bert_for_coreference_resolution/models/test_bert_base_baseline_128.h5"

# Per-bucket accuracy, F1 and example counts for each model. The first model
# takes the flat feature matrix; the other three take the split span inputs.
(acc_list,
 f1_list,
 example_list) = get_output(model_path, test_data)

(acc_list_base_kernel_3,
 f1_list_base_kernel_3,
 example_list_base_kernel_3) = get_output_base(
    model_path_base_kernel_3, test_data_base, "bert_base")

(acc_list_base_kernel_5,
 f1_list_base_kernel_5,
 example_list_base_kernel_5) = get_output_base(
    model_path_base_kernel_5, test_data_base, "bert_base")

(acc_list_baseline,
 f1_list_baseline,
 example_list_baseline) = get_output_base(
    model_path_baseline, test_data_base, "bert_base")

W0408 10:19:38.950574 139664974870272 deprecation.py:323] From /home/users/pkahardipraja/.local/lib/python3.5/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
W0408 10:19:45.250853 139664974870272 deprecation_wrapper.py:119] From /home/users/pkahardipraja/.local/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.





In [6]:
# Plot F1 rather than accuracy: it is more informative here because it reflects
# both type I errors (false positives) and type II errors (false negatives).
FONT_SIZE="13pt"
x = list(range(20))  # one x position per distance bucket
output_notebook()
crs = []
# One entry per model: [per-bucket F1 scores, legend label, marker, colour].
col_keys = [[f1_list, "BERT-base c2f", "diamond", "green"],
                     [f1_list_base_kernel_3, "BERT-base + CNN (K=3)", "triangle", "blue"],
                     [f1_list_base_kernel_5, "BERT-base + CNN (K=5)", "circle", "red"],
                     [f1_list_baseline, "BERT-base baseline", "square", "purple"]
                    ]
p = figure(x_axis_label='Distance between mention pair (wordpiece tokens)', y_axis_label='F1 Score', y_range=[0.5, 1.05], width=900, height=400)

for scores, legend, marker, color in col_keys:
    p.line(x, scores, color=color)
    cr = p.scatter(x=x, y=scores, size=10, hover_fill_color="Gray", legend_label=legend, marker=marker, color=color)
    # Keep the scatter renderers so the hover tool only reacts to the markers.
    crs.append(cr)

# "@x" is the bucket index, so label it as such (it was mislabelled "Height").
tooltips = [("Bucket", "@x"),("F1 score", "@y{0.00}"),]
p.add_tools(bokeh.models.HoverTool(tooltips=tooltips, renderers=crs))
# Only label every second bucket to keep the axis readable.
p.xaxis.ticker = x[::2]
p.yaxis.major_label_text_font_size = FONT_SIZE
p.xaxis.major_label_text_font_size = FONT_SIZE
p.xaxis.axis_label_text_font_size = FONT_SIZE
p.yaxis.axis_label_text_font_size = FONT_SIZE
p.legend.orientation = "horizontal"
p.legend.label_text_font_size = '10pt'
p.legend.location = "bottom_right"
# Show each labelled tick as its distance range instead of the bucket index.
p.xaxis.major_label_overrides = {
    0: '0-25',
    2: '51-75',
    4: '101-125',
    6: '151-175',
    8: '201-225',
    10: '251-275',
    12: '301-325',
    14: '351-375',
    16: '401-425',
    18: '451-475',
}
show(p)

In [7]:
test_data_joshi_large = "/home/users/pkahardipraja/data/bert_coref/data-bert-large-384/test_span_reps_joshi.384_1.h5"
test_data_baseline_large = "/home/users/pkahardipraja/data/bert_coref/data-bert-large-384/test_span_reps_baseline.384_1.h5"

# Row layout used throughout this notebook: every column except the last two
# is a feature, the second-to-last column is the span distance and the last
# column is the label.
with h5py.File(test_data_joshi_large, 'r') as f:
    # `Dataset.value` is deprecated and was removed in h5py 3.0; `[()]` reads
    # the whole dataset into memory. Indexing with `f[...]` (instead of
    # `f.get(...)`) raises a clear KeyError if the dataset is missing.
    test_data_large = f['span_representations'][()]

# NOTE: these rebind x_test / y_test / x_dist, which the BERT-base loading
# cell also assigns.
x_test = test_data_large[:, :-2]
y_test = test_data_large[:, -1].astype(int)
x_dist = test_data_large[:, -2].astype(int)

with h5py.File(test_data_baseline_large, 'r') as f:
    test_data_base_large = f['span_representations'][()]

x_test_base = test_data_base_large[:, :-2]
y_test_base = test_data_base_large[:, -1].astype(int)



In [8]:
# Saved models evaluated on the 384-token BERT-large test data.
model_path_large = "/project/kahardipraja/coref/bert_for_coreference_resolution/models/test_bert_large_joshi_384.h5"
model_path_base_kernel_3_large = "/project/kahardipraja/coref/bert_for_coreference_resolution/models/test_bert_large_baseline_384_kernel_3.h5"
model_path_base_kernel_5_large = "/project/kahardipraja/coref/bert_for_coreference_resolution/models/test_bert_large_baseline_384_kernel_5.h5"
model_path_baseline_large = "/project/kahardipraja/coref/bert_for_coreference_resolution/models/test_bert_large_baseline_384.h5"

# Per-bucket accuracy, F1 and example counts for each model. The first model
# takes the flat feature matrix; the other three take the split span inputs.
(acc_list_large,
 f1_list_large,
 example_list_large) = get_output(model_path_large, test_data_large)

(acc_list_base_kernel_3_large,
 f1_list_base_kernel_3_large,
 example_list_base_kernel_3_large) = get_output_base(
    model_path_base_kernel_3_large, test_data_base_large, "bert_large")

(acc_list_base_kernel_5_large,
 f1_list_base_kernel_5_large,
 example_list_base_kernel_5_large) = get_output_base(
    model_path_base_kernel_5_large, test_data_base_large, "bert_large")

(acc_list_baseline_large,
 f1_list_baseline_large,
 example_list_baseline_large) = get_output_base(
    model_path_baseline_large, test_data_base_large, "bert_large")





In [9]:
# Same F1-by-distance plot as for BERT-base, now for the BERT-large models.
# Relies on `x` and `FONT_SIZE` defined in the BERT-base plotting cell.
output_notebook()
crs = []
# One entry per model: [per-bucket F1 scores, legend label, marker, colour].
col_keys2 = [[f1_list_large, "BERT-large c2f", "diamond", "green"],
                     [f1_list_base_kernel_3_large, "BERT-large + CNN (K=3)", "triangle", "blue"],
                     [f1_list_base_kernel_5_large, "BERT-large + CNN (K=5)", "circle", "red"],
                     [f1_list_baseline_large, "BERT-large baseline", "square", "purple"]
                    ]
p = figure(x_axis_label='Distance between mention pair (wordpiece tokens)', y_axis_label='F1 Score', y_range=[0.5, 1.05], width=900, height=400)

for scores, legend, marker, color in col_keys2:
    p.line(x, scores, color=color)
    cr = p.scatter(x=x, y=scores, size=10, hover_fill_color="Gray", legend_label=legend, marker=marker, color=color)
    # Keep the scatter renderers so the hover tool only reacts to the markers.
    crs.append(cr)

# "@x" is the bucket index, so label it as such (it was mislabelled "Height").
tooltips = [("Bucket", "@x"),("F1 score", "@y{0.00}"),]
p.add_tools(bokeh.models.HoverTool(tooltips=tooltips, renderers=crs))
# Only label every second bucket to keep the axis readable.
p.xaxis.ticker = [i for i in range(20) if i % 2 == 0]
p.yaxis.major_label_text_font_size = FONT_SIZE
p.xaxis.major_label_text_font_size = FONT_SIZE
p.xaxis.axis_label_text_font_size = FONT_SIZE
p.yaxis.axis_label_text_font_size = FONT_SIZE
p.legend.orientation = "horizontal"
p.legend.label_text_font_size = '10pt'
p.legend.location = "bottom_right"
# Show each labelled tick as its distance range instead of the bucket index.
p.xaxis.major_label_overrides = {
    0: '0-25',
    2: '51-75',
    4: '101-125',
    6: '151-175',
    8: '201-225',
    10: '251-275',
    12: '301-325',
    14: '351-375',
    16: '401-425',
    18: '451-475',
}
show(p)