In [87]:
from datetime import datetime
from os.path import join
from os import listdir
import json
import re #for camel case conversion
from collections import Counter

from sklearn.metrics import classification_report
import pandas as pd

import altair as alt
alt.renderers.enable('default')

RendererRegistry.enable('default')

In [88]:
def sherlock_case(s):
    s = re.sub(r"(_|-)+", " ", s).title().replace(" ", "")
#     s = ''.join([s[0].lower(), s[1:]])
#     s = ''.join(map(lambda x: x if x.islower() else " "+x, s))
    return s.lower()

In [103]:
og_true_types = list(map(sherlock_case, pd.read_parquet("../../results/true_types/gittables_benchmark.parquet").values.flatten()))
true_types = list(map(sherlock_case, pd.read_parquet("../../results/true_types/gittables_benchmark_reannotated.parquet").values.flatten()))
prediction_sherlock = list(map(sherlock_case, pd.read_parquet("../../results/predictions/sherlock_gittables_benchmark_reannotated.parquet").values.flatten()))
prediction_sato = list(map(sherlock_case, pd.read_parquet("../../results/predictions/sato_gittables_benchmark_reannotated.parquet").values.flatten()))
# we cam actually use the prediction of non-reannotated since they are the same, for the predictions only the column/table values are
# taken into consideration, so only the true types differ, the predictions remain the same for both.
og_prediction_sherlock = list(map(sherlock_case, pd.read_parquet("../../results/predictions/sherlock_gittables_benchmark.parquet").values.flatten()))
og_prediction_sato = list(map(sherlock_case, pd.read_parquet("../../results/predictions/sato_gittables_benchmark.parquet").values.flatten()))

In [98]:
type_freq_df = pd.DataFrame(true_types, columns=['type'])
type_freq_df = pd.DataFrame(type_freq_df['type'].value_counts())
type_freq_df.columns = ['count']
type_freq_df.index.name = 'type'

In [99]:
alt.Chart(type_freq_df.reset_index()).mark_bar(size=15).encode(
    x = alt.X('type:O',
              title = 'Semantic Types',
              sort=alt.EncodingSortField(
                field="count",  
                order="descending")),
    y = alt.Y('count', title='Number of Samples')    
)

In [100]:
print(len(prediction_sherlock))
print(len(prediction_sato))
print(len(true_types))


801
801
801


In [77]:
# print(prediction_sato)
# print(prediction_sherlock)
# print(true_types)
# for idx, i in enumerate(true_types):
#     if type(i) != str:
#         print(idx)
# print(true_types[407])
# print(prediction_sherlock[407])

In [104]:
print(classification_report(og_true_types, og_prediction_sato))

                precision    recall  f1-score   support

       address       1.00      1.00      1.00         1
   affiliation       0.00      0.00      0.00         0
           age       0.00      0.00      0.00         1
         album       0.00      0.00      0.00         0
          area       0.00      0.00      0.00         0
        artist       0.00      0.00      0.00         0
         brand       0.00      0.00      0.00         0
      capacity       0.00      0.00      0.00         1
      category       0.01      0.20      0.02         5
          city       0.71      1.00      0.83         5
         class       0.33      0.02      0.03        64
classification       0.00      0.00      0.00         1
          code       0.71      0.29      0.42        17
    collection       0.00      0.00      0.00         0
       command       0.00      0.00      0.00         0
       company       0.00      0.00      0.00         2
     component       0.00      0.00      0.00  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [102]:
print(classification_report(true_types, prediction_sato))

                precision    recall  f1-score   support

       address       1.00      0.50      0.67         2
   affiliation       0.00      0.00      0.00         0
           age       0.00      0.00      0.00         0
         album       0.00      0.00      0.00         0
          area       0.00      0.00      0.00         0
        artist       0.00      0.00      0.00         0
         brand       0.00      0.00      0.00         0
      capacity       0.00      0.00      0.00         1
      category       0.01      0.17      0.02         6
          city       0.71      1.00      0.83         5
         class       0.33      0.02      0.03        64
classification       0.00      0.00      0.00         1
          code       0.71      0.33      0.45        15
    collection       0.00      0.00      0.00         0
       command       0.00      0.00      0.00         0
       company       0.00      0.00      0.00         4
     component       0.00      0.00      0.00  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [105]:
og_class_report = classification_report(og_true_types, og_prediction_sato, output_dict=True)
class_report = classification_report(true_types, prediction_sato, output_dict=True)
class_report = dict(list(class_report.items()))
class_report_diff = {}
#write class_report_diff in a for loop
i = 0
for k, v in class_report.items():
    if k == 'accuracy':
        class_report_diff[k] = v
        continue
    if v['f1-score'] != og_class_report[k]['f1-score']:
        class_report_diff[k] = v
print(class_report_diff)
# class_report_df = pd.DataFrame.from_dict(class_report)
# class_report_diff_df = pd.DataFrame.from_dict(class_report_diff)
# class_report_df.to_csv('csv_report_sato_benchmark_re-annotated.csv')
# class_report_diff_df.to_csv('csv_report_sato_benchmark_re-annotated_diff.csv')

0.6666666666666666
1.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.022988505747126436
0.023255813953488372
0.8333333333333333
0.8333333333333333
0.029850746268656716
0.029850746268656716
0.0
0.0
0.4545454545454545
0.4166666666666667
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.888888888888889
0.75
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.588235294117647
0.5714285714285714
0.0
0.0
0.0
0.0
1.0
1.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.37540453074433655
0.6699029126213591
0.07692307692307691
0.07692307692307691
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.07929515418502202
0.015037593984962407
0.21428571428571427
0.07692307692307693
0.4324324324324324
0.4324324324324324
0.0
0.0
0.0
0.0
0.0
0.0
0.2465753424657534
0.2523364485981308
0.25
0.25
0.9902912621359222
0.9902912621359222
0.13749454620653906
0.1394034327409373
0.29662733545106446
0.4022421591988421
{'address': {'precision': 1.0, 'recall': 0.5, 'f1-score': 0.6666666666666666, 'support': 2}, 'categ

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [20]:
# print(classification_report(true_types, prediction_sherlock))

In [32]:
og_class_report = classification_report(og_true_types, og_prediction_sherlock, output_dict=True)
class_report = classification_report(true_types, prediction_sherlock, output_dict=True)
class_report = dict(list(class_report.items()))
class_report_df = pd.DataFrame.from_dict(class_report)
class_report_df.to_csv('csv_report_sherlock_benchmark_re-annotated.csv')

In [10]:
report_sherlock = classification_report(true_types, prediction_sherlock, output_dict=True)
report_sherlock_df_input = {k: list(v.values()) for k, v in list(report_sherlock.items())[:-3]} #last 3 are total f1/macro/weigthed, these are not needed
report_sherlock_df = pd.DataFrame.from_dict(report_sherlock_df_input, orient='index', columns=['precision', 'recall', 'f1-score', 'support']).sort_values(by='f1-score', ascending = False)

report_sato = classification_report(true_types, prediction_sato, output_dict=True)
report_sato_df_input = {k: list(v.values()) for k, v in list(report_sato.items())[:-3]} #last 3 are total f1/macro/weigthed, these are not needed
report_sato_df = pd.DataFrame.from_dict(report_sato_df_input, orient='index', columns=['precision', 'recall', 'f1-score', 'support']).sort_values(by='f1-score', ascending = False)


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [11]:
combined_report_df = pd.merge(report_sato_df, report_sherlock_df, left_index=True, right_index=True)
combined_report_df.index.names = ['type']
combined_report_df.columns = ['precision_sato', 'recall_sato', 'f1-score_sato', 'support_sato', 'precision_sherlock', 'recall_sherlock', 'f1-score_sherlock', 'support_sherlock']

In [12]:
combined_report_df

Unnamed: 0_level_0,precision_sato,recall_sato,f1-score_sato,support_sato,precision_sherlock,recall_sherlock,f1-score_sherlock,support_sherlock
type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
gender,1.0,1.0,1.0,1,0.0,0.0,0.0,1
year,0.980769,1.0,0.990291,102,1.0,0.911765,0.953846,102
country,1.0,0.8,0.888889,5,1.0,0.8,0.888889,5
city,0.714286,1.0,0.833333,5,0.833333,1.0,0.909091,5
address,1.0,0.5,0.666667,2,0.014388,1.0,0.028369,2
description,0.490196,0.735294,0.588235,34,0.448276,0.382353,0.412698,34
code,0.714286,0.333333,0.454545,15,0.8,0.266667,0.4,15
status,0.347826,0.571429,0.432432,14,0.588235,0.714286,0.645161,14
name,0.284314,0.552381,0.375405,105,0.7,0.133333,0.224,105
weight,0.153846,0.666667,0.25,3,0.166667,0.666667,0.266667,3


In [13]:
mismatches_sherlock = list()
mismatches_sherlock_idx = list()
mismatches_sato = list()
mismatches_sato_idx = list()

print_count = 0
for idx, true_type in enumerate(true_types):
    predicted_type_sherlock = prediction_sherlock[idx]
    predicted_type_sato = prediction_sato[idx]

    if true_type != predicted_type_sherlock:
        mismatches_sherlock.append(true_type)
        mismatches_sherlock_idx.append(idx)

    if true_type != predicted_type_sato:
        mismatches_sato.append(true_type)
        mismatches_sato_idx.append(idx)
        
        # zoom in to specific errors
        # if true_type in ('state') and print_count <= 6:
        #     print_count += 1
        #     print(f'Expected "{true_type}" but predicted "{predicted_type}"')
        #     print(f'{data[idx]}\n')
        

mismatch_sherlock_class_count = Counter(mismatches_sherlock)
print(mismatch_sherlock_class_count.most_common()[:10])

mismatches_sato_class_count = Counter(mismatches_sato)
print(mismatches_sato_class_count.most_common()[:10])

[('rank', 97), ('name', 91), ('class', 60), ('type', 58), ('species', 26), ('description', 21), ('state', 19), ('depth', 13), ('code', 11), ('duration', 11)]
[('type', 119), ('species', 112), ('rank', 100), ('class', 63), ('name', 47), ('state', 18), ('depth', 13), ('duration', 11), ('code', 10), ('description', 9)]


In [14]:
mismatch_sherlock_freq_df = pd.DataFrame(mismatches_sherlock, columns=['type'])
mismatch_sherlock_freq_df = pd.DataFrame(mismatch_sherlock_freq_df['type'].value_counts())
mismatch_sherlock_freq_df.columns = ['count']
mismatch_sherlock_freq_df.index.name = 'type'

mismatch_sato_freq_df = pd.DataFrame(mismatches_sato, columns=['type'])
mismatch_sato_freq_df = pd.DataFrame(mismatch_sato_freq_df['type'].value_counts())
mismatch_sato_freq_df.columns = ['count']
mismatch_sato_freq_df.index.name = 'type'

type_freq_df = pd.DataFrame(true_types, columns=['type'])
type_freq_df = pd.DataFrame(type_freq_df['type'].value_counts())
type_freq_df.columns = ['count']
type_freq_df.index.name = 'type'

In [15]:
combined_mismatch_freq_df = pd.merge(mismatch_sherlock_freq_df, mismatch_sato_freq_df, left_index=True, right_index=True, how='outer')
combined_mismatch_freq_df = pd.merge(combined_mismatch_freq_df, type_freq_df, left_index=True, right_index=True, how='outer')
combined_mismatch_freq_df.columns=['sherlock_mismatch_freq', 'sato_mismatch_freq', 'true_type_freq']

In [16]:
alt.Chart(combined_mismatch_freq_df.reset_index()).transform_fold(
      ['sherlock_mismatch_freq', 'sato_mismatch_freq', 'true_type_freq'],
      as_=['column', 'value']
    ).mark_bar(size=15).encode(
    column=alt.Column('type:O', sort=alt.EncodingSortField(
                field="sato_mismatch_freq",  
                order="descending")),
    x = alt.X('column:N',
              title = '',
              ),
    y=alt.Y('value:Q'),
    color='column:N'
)

In [17]:
# print(combined_mismatch_freq_df)

In [18]:
alt.Chart(combined_report_df.reset_index()).transform_fold(
      ['precision_sato', 'recall_sato', 'f1-score_sato', 'precision_sherlock', 'recall_sherlock', 'f1-score_sherlock'],
      as_=['column', 'value']
    ).mark_bar(size=15).encode(
    column=alt.Column('type:O', sort=alt.EncodingSortField(
                field="support_sherlock",  
                order="descending")),
    x = alt.X('column:N',
              title = '',
              ),
    y=alt.Y('value:Q'),
    color='column:N'
)