In [1]:
import torch
from sklearn.preprocessing import LabelEncoder
import os
from os.path import join
from datetime import datetime
from os import listdir
import re #for camel case conversion
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

os.environ['LDA_name'] = 'num-directstr_thr-0_tn-400'

import pandas as pd
import numpy as np
from extract.feature_extraction.topic_features_LDA import extract_topic_features
from extract.feature_extraction.sherlock_features import extract_sherlock_features
from utils import get_valid_types
from model import models_sherlock
from model.torchcrf import CRF

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  dtype=np.int):
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  method='lar', copy_X=True, eps=np.finfo(np.float).eps,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  method='lar', copy_X=True, eps=np.finfo(np.float).eps,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_Gram=True, verbose=0,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.float).eps, copy_X=True, fit_path=True,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  eps=np.finfo(np.floa

In [2]:
def camel_case(s):
  """Convert a snake_case or kebab-case string to lowerCamelCase.

  Runs of underscores/hyphens are treated as word separators; every word is
  title-cased, the words are joined, and the first character is lower-cased.
  Note that ``str.title()`` lower-cases the remainder of each word, so an
  input that is already camel-cased (e.g. ``'teamName'``) comes back as
  ``'teamname'``.

  Args:
    s: The string to convert.

  Returns:
    The camel-cased string. An empty or separator-only input yields ``''``.
  """
  joined = re.sub(r"(_|-)+", " ", s).title().replace(" ", "")
  # Slice instead of indexing so an empty result does not raise IndexError.
  return joined[:1].lower() + joined[1:]

In [3]:
# The set of semantic types is selected through the TYPENAME environment
# variable, which must be set before this cell runs.
TYPENAME = os.environ['TYPENAME']
valid_types = get_valid_types(TYPENAME)
print(valid_types)

# Encoder fitted on the valid type names; used later to turn predicted
# class ids back into type names.
label_enc = LabelEncoder().fit(valid_types)

# Inference configuration.
MAX_COL_COUNT = 10
topic_dim = 400
pre_trained_loc = './pretrained_sato'

# Runs on CPU; the commented lines below auto-select CUDA when available.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print("PyTorch device={}".format(device))
device = torch.device('cpu')

['address', 'affiliate', 'affiliation', 'age', 'album', 'area', 'artist', 'birthDate', 'birthPlace', 'brand', 'capacity', 'category', 'city', 'class', 'classification', 'club', 'code', 'collection', 'command', 'company', 'component', 'continent', 'country', 'county', 'creator', 'credit', 'currency', 'day', 'depth', 'description', 'director', 'duration', 'education', 'elevation', 'family', 'fileSize', 'format', 'gender', 'genre', 'grades', 'isbn', 'industry', 'jockey', 'language', 'location', 'manufacturer', 'name', 'nationality', 'notes', 'operator', 'order', 'organisation', 'origin', 'owner', 'person', 'plays', 'position', 'product', 'publisher', 'range', 'rank', 'ranking', 'region', 'religion', 'requirement', 'result', 'sales', 'service', 'sex', 'species', 'state', 'status', 'symbol', 'team', 'teamName', 'type', 'weight', 'year']


In [4]:
# Column names of every Sherlock feature group, read from
# <BASEPATH>/configs/feature_groups/<group>_col.tsv.
sherlock_feature_groups = ['char', 'word', 'par', 'rest']
feature_group_cols = {}
for f_g in sherlock_feature_groups:
    group_table = pd.read_csv(
        join(os.environ['BASEPATH'], 'configs', 'feature_groups', "{}_col.tsv".format(f_g)),
        sep='\t',
        header=None,
        index_col=0,
    )
    # With header=None the fields are labelled 0, 1, ...; field 0 becomes the
    # index, so the second field keeps the label 1.
    feature_group_cols[f_g] = list(group_table[1])

In [5]:
def pad_vec(x):
    """Pad ``x`` at the end so that it has ``topic_dim`` entries.

    Nothing is added in front; the appended entries are filled with
    ``1/topic_dim``.  Assumes ``x`` is one-dimensional and no longer than
    ``topic_dim`` -- TODO confirm against the topic feature extractor.
    """
    missing = topic_dim - len(x)
    return np.pad(x, (0, missing), 'constant', constant_values=(0.0, 1/topic_dim))

## Load models

In [6]:
# Column classifier (its output is used as CRF emissions) and the CRF layer.
classifier = models_sherlock.build_sherlock(
    sherlock_feature_groups,
    num_classes=len(valid_types),
    topic_dim=topic_dim,
    dropout_ratio=0.35,
)
#classifier.load_state_dict(torch.load(join(pre_trained_loc, 'sherlock_None.pt'), map_location=device))
model = CRF(len(valid_types), batch_first=True).to(device)
#model.load_state_dict(torch.load(join(pre_trained_loc, 'model.pt'), map_location=device))

# A single checkpoint file holds the state dicts of both modules.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
loaded_params = torch.load(join(pre_trained_loc, 'model.pt'), map_location=device)
classifier.load_state_dict(loaded_params['col_classifier'])
model.load_state_dict(loaded_params['CRF_model'])

# Put both modules in evaluation mode for inference.
classifier.eval()
model.eval()

CRF(num_tags=78)

In [7]:
def extract(df):
    """Build the padded model inputs for a single table.

    Args:
        df: pandas DataFrame holding the table; every DataFrame column is one
            column to be typed.  At most ``MAX_COL_COUNT`` columns are
            supported.

    Returns:
        A 3-tuple ``(feature_dic, labels, mask)``:
          * ``feature_dic`` maps ``'topic'`` and every key of
            ``feature_group_cols`` to a ``torch.FloatTensor`` whose rows are
            padded with zeros (presumably one row per table column, giving
            ``MAX_COL_COUNT`` rows in total -- confirm against the Sherlock
            feature extractor).
          * ``labels`` is a placeholder ``np.zeros(MAX_COL_COUNT)``.
          * ``mask`` is a ``torch.uint8`` tensor of length ``MAX_COL_COUNT``
            with 1 for each real column followed by 0 for the padding.

    Raises:
        ValueError: if ``df`` has more than ``MAX_COL_COUNT`` columns.
    """
    n = df.shape[1]
    if n > MAX_COL_COUNT:
        # Fail fast with a readable message; otherwise the padding below
        # would fail with numpy's "negative dimensions" ValueError only after
        # the feature extraction has already run.
        raise ValueError(
            "table has {} columns, but at most {} are supported".format(n, MAX_COL_COUNT))
    n_pad = MAX_COL_COUNT - n

    df_dic = {'df': df, 'locator': 'None', 'dataset_id': 'None'}
    feature_dic = {}

    # topic vectors: one table-level vector repeated for each real column,
    # followed by all-zero rows for the padding columns
    topic_features = extract_topic_features(df_dic)
    topic_vec = pad_vec(topic_features.loc[0, 'table_topic'])
    feature_dic['topic'] = torch.FloatTensor(
        np.vstack((np.tile(topic_vec, (n, 1)), np.zeros((n_pad, topic_dim)))))

    # sherlock vectors: one tensor per feature group, NaNs replaced by 0
    sherlock_features = extract_sherlock_features(df_dic)
    for f_g in feature_group_cols:
        temp = sherlock_features[feature_group_cols[f_g]].to_numpy()
        temp = np.vstack((temp, np.zeros((n_pad, temp.shape[1])))).astype('float')
        temp = np.nan_to_num(temp)
        feature_dic[f_g] = torch.FloatTensor(temp)

    # dictionary of features, labels, masks
    return feature_dic, np.zeros(MAX_COL_COUNT), torch.tensor([1]*n + [0]*n_pad, dtype=torch.uint8)

In [8]:
def evaluate(df):
    """Predict semantic types for the columns of one table.

    Features are extracted with ``extract``, scored by ``classifier`` and
    decoded by the CRF ``model``; the decoded class ids are mapped back to
    type names with ``label_enc``.

    Args:
        df: pandas DataFrame holding the table (see ``extract``).

    Returns:
        The result of ``label_enc.inverse_transform`` on the decoded tag
        sequence (presumably one type name per real column -- confirm against
        the CRF ``decode`` implementation).
    """
    # The placeholder labels returned by extract() are not needed here.
    feature_dic, _, mask = extract(df)

    # Pure inference: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        emissions = classifier(feature_dic).view(1, MAX_COL_COUNT, -1)
        mask = mask.view(1, MAX_COL_COUNT)
        pred = model.decode(emissions, mask)[0]

    return label_enc.inverse_transform(pred)

## Load GitTables

In [9]:
# Directory holding the benchmark CSV files that are loaded below.
# Uncomment the dataset you want to predict types for.
path_data = '../../gittables_benchmark/non_reannotated'
# path_data = '../../gittables_benchmark/reannotated'

In [10]:
tables = {}
table_true_types = {}
# col_ids = []
filepaths = [join(path_data, f) for f in listdir(path_data) if f.endswith('.csv')]

# Go over the tables in the dataset
for fp in filepaths:
    # Table id = file name without directory and extension (independent of
    # the platform's path separator).
    table_id = os.path.splitext(os.path.basename(fp))[0]
    # The first CSV row carries the true type of every column.  It is read as
    # a plain data row (header=None) so pandas does not rewrite repeated
    # labels; [1:] drops the cell that belongs to the index column.
    table_header = pd.read_csv(fp, header=None, nrows=1).values[0][1:]
    table_df = pd.read_csv(fp, index_col=0)
    tables[table_id] = table_df
    table_true_types[table_id] = table_header

In [11]:
# Number of tables that were loaded.
print(len(tables))
# print(tables)

477


In [12]:
predicted_types = {}
col_count = 0
# Iterate over `tables` in its insertion order.  That is the same order in
# which `table_true_types` was filled, so flattening both dicts later keeps
# true and predicted labels aligned.  (The previous `sorted(d.iteritems())`
# referenced an undefined name and a Python 2-only method.)
for table_id, table_df in tables.items():
    print(table_id)
#     print(table_df)
#     print(len(table_df.columns))
#     col_count += len(table_df.columns)
    predicted_types[table_id] = evaluate(table_df)

GitTables_1548


  regex = re.compile(pat, flags=flags)
  regex = re.compile(pat, flags=flags)


GitTables_1946
GitTables_1578
GitTables_2803
GitTables_1544
GitTables_1865
GitTables_2712
GitTables_2928
GitTables_1561
GitTables_2004
GitTables_1799
GitTables_1898
GitTables_1577
GitTables_2745
GitTables_2853
GitTables_2821
GitTables_2324
GitTables_1830
GitTables_2309
GitTables_2637
GitTables_1636
GitTables_1593
GitTables_1637
GitTables_2663
GitTables_2508
GitTables_1542
GitTables_1770
GitTables_2317
GitTables_1973
GitTables_1651
GitTables_1966
GitTables_1563
GitTables_1809
GitTables_1704
GitTables_2010
GitTables_2961
GitTables_2903
GitTables_1611
GitTables_1614
GitTables_2891
GitTables_1974
GitTables_2151
GitTables_1585
GitTables_2111
GitTables_2752
GitTables_1989
GitTables_1994
GitTables_1950
GitTables_2102
GitTables_2022
GitTables_2240
GitTables_2347
GitTables_2658
GitTables_2446
GitTables_2402
GitTables_1617
GitTables_2088
GitTables_2301
GitTables_1920
GitTables_2322
GitTables_2516
GitTables_2213
GitTables_1816
GitTables_1706
GitTables_2660
GitTables_2837
GitTables_1783
GitTables_

In [13]:
# print(col_count)

In [14]:
# for v in table_true_types.values():
#     print(v)

In [15]:
# Flatten the per-table label arrays into two parallel lists.  Both lists are
# built from the same sequence of table ids, so true and predicted labels stay
# aligned even if the two dicts were filled in a different order; a table
# without predictions raises KeyError instead of silently shifting labels.
table_true_types_array = [_type for table_id in table_true_types for _type in table_true_types[table_id]]
predicted_types_array = [_type for table_id in table_true_types for _type in predicted_types[table_id]]
# print(table_true_types_array)
print(classification_report(table_true_types_array, predicted_types_array))

                precision    recall  f1-score   support

       address       1.00      1.00      1.00         1
   affiliation       0.00      0.00      0.00         0
           age       0.00      0.00      0.00         1
         album       0.00      0.00      0.00         0
          area       0.00      0.00      0.00         0
        artist       0.00      0.00      0.00         0
         brand       0.00      0.00      0.00         0
      capacity       0.00      0.00      0.00         1
      category       0.01      0.20      0.02         5
          city       0.71      1.00      0.83         5
         class       0.33      0.02      0.03        64
classification       0.00      0.00      0.00         1
          code       0.71      0.29      0.42        17
    collection       0.00      0.00      0.00         0
       command       0.00      0.00      0.00         0
       company       0.00      0.00      0.00         2
     component       0.00      0.00      0.00  

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)


In [16]:
# Sanity check: the flattened true and predicted label lists should have the
# same length.
print(len(table_true_types_array))
print(len(predicted_types_array))

801
801


In [17]:
# Write out data for further investigation.
# path_out_true_types is only used by the commented-out export below.
path_out_true_types = '../../combined/results/true_types'
path_out_predictions = '../../combined/results/predictions'

# tt_df = pd.DataFrame(columns=['type'], data=table_true_types_array)
# tt_df.to_parquet(join(path_out_true_types, 'gittables_benchmark.parquet'))

# Pick the code for the dataset you chose to predict.
# Original benchmark: one row per predicted column type, in a single 'type' column.
pred_df = pd.DataFrame(columns=['type'], data=predicted_types_array)
pred_df.to_parquet(join(path_out_predictions, 'sato_gittables_benchmark.parquet'))
# Reannotated benchmark.
# NOTE(review): `predicted_labels` is not defined in the visible notebook;
# `predicted_types_array` is presumably what is meant -- confirm before enabling.
# pred_df = pd.DataFrame(columns=['type'], data=predicted_labels)
# pred_df.to_parquet(join(path_out_predictions, 'sato_gittables_benchmark_reannotated.parquet'))

In [18]:
# Show the true types of every table next to the types predicted for it.
for table_id, types in table_true_types.items():
    print(f'true type: {types}', f'sato: {predicted_types[table_id]}', sep='\n')

true type: ['type']
sato: ['name']
true type: ['type']
sato: ['category']
true type: ['type']
sato: ['category']
true type: ['name' 'rank' 'year']
sato: ['species' 'species' 'year']
true type: ['class' 'description']
sato: ['category' 'description']
true type: ['class']
sato: ['company']
true type: ['name']
sato: ['position']
true type: ['code' 'name']
sato: ['rank' 'description']
true type: ['type']
sato: ['name']
true type: ['name' 'rank' 'species' 'year']
sato: ['name' 'species' 'notes' 'year']
true type: ['type']
sato: ['category']
true type: ['type']
sato: ['category']
true type: ['type']
sato: ['category']
true type: ['description' 'name' 'type']
sato: ['name' 'name' 'type']
true type: ['name' 'rank' 'species' 'year']
sato: ['name' 'species' 'notes' 'year']
true type: ['year']
sato: ['year']
true type: ['name']
sato: ['position']
true type: ['state' 'type']
sato: ['type' 'type']
true type: ['name' 'rank' 'year']
sato: ['name' 'species' 'year']
true type: ['name' 'rank' 'species' 

sato: ['category']
true type: ['state' 'type']
sato: ['type' 'type']
true type: ['class']
sato: ['name']
true type: ['type']
sato: ['name']
true type: ['name' 'rank' 'year']
sato: ['species' 'species' 'year']
true type: ['class' 'description']
sato: ['brand' 'description']
true type: ['name']
sato: ['position']
true type: ['type']
sato: ['category']
true type: ['state' 'type']
sato: ['type' 'type']
true type: ['depth']
sato: ['weight']
true type: ['name']
sato: ['name']
true type: ['name' 'rank' 'species' 'year']
sato: ['name' 'species' 'notes' 'year']
true type: ['type']
sato: ['type']
true type: ['name' 'type']
sato: ['name' 'status']
true type: ['name' 'rank' 'year']
sato: ['name' 'species' 'year']
true type: ['type']
sato: ['category']
true type: ['name']
sato: ['name']
true type: ['address' 'city' 'county' 'name' 'state']
sato: ['address' 'city' 'city' 'name' 'state']
true type: ['duration']
sato: ['position']
true type: ['name' 'type']
sato: ['name' 'status']
true type: ['name' '

In [19]:
# Point path_data at the non-reannotated benchmark for the cells below.
path_data = '../../gittables_benchmark/non_reannotated'

In [20]:
tables = {}
table_true_types = {}
# col_ids = []
filepaths2 = [join(path_data, f) for f in listdir(path_data) if f.endswith('.csv')]

# Go over the tables in the dataset, keeping only those with >= 2 columns
for fp in filepaths2:
    # Table id = file name without directory and extension (independent of
    # the platform's path separator).
    table_id = os.path.splitext(os.path.basename(fp))[0]
    # The first CSV row carries the true type of every column.  It is read as
    # a plain data row (header=None) so pandas does not rewrite repeated
    # labels; [1:] drops the cell that belongs to the index column.
    table_header = pd.read_csv(fp, header=None, nrows=1).values[0][1:]
    table_df = pd.read_csv(fp, index_col=0)
    if len(table_df.columns) >= 2:
#         print(table_header)
#         print(table_df)
        tables[table_id] = table_df
        table_true_types[table_id] = table_header

In [21]:
# Number of tables kept by the preceding loading cell.
print(len(tables))

178


In [22]:
# Predicted column types per table id, in the iteration order of `tables`.
predicted_types = {
    table_id: evaluate(table_df)
    for table_id, table_df in tables.items()
}

In [23]:
# For every table: the true types on one line, the predicted types on the next.
for table_id, types in table_true_types.items():
    print(types, predicted_types[table_id], sep='\n')

['name' 'rank' 'year']
['species' 'species' 'year']
['class' 'description']
['category' 'description']
['code' 'name']
['rank' 'description']
['name' 'rank' 'species' 'year']
['name' 'species' 'notes' 'year']
['description' 'name' 'type']
['name' 'name' 'type']
['name' 'rank' 'species' 'year']
['name' 'species' 'notes' 'year']
['state' 'type']
['type' 'type']
['name' 'rank' 'year']
['name' 'species' 'year']
['name' 'rank' 'species' 'year']
['type' 'species' 'notes' 'year']
['component' 'product' 'status']
['type' 'status' 'status']
['name' 'rank' 'species' 'year']
['name' 'species' 'notes' 'year']
['origin' 'type']
['product' 'type']
['name' 'rank' 'species' 'year']
['name' 'species' 'notes' 'year']
['state' 'type']
['type' 'type']
['name' 'rank' 'year']
['name' 'species' 'year']
['class' 'description']
['type' 'publisher']
['name' 'rank' 'year']
['name' 'species' 'year']
['name' 'rank' 'year']
['name' 'species' 'year']
['code' 'language' 'name']
['name' 'position' 'name']
['class' 'de

In [24]:
# Flatten the per-table label arrays into two parallel lists.  Both lists are
# built from the same sequence of table ids, so true and predicted labels stay
# aligned even if the two dicts were filled in a different order; a table
# without predictions raises KeyError instead of silently shifting labels.
table_true_types_array = [_type for table_id in table_true_types for _type in table_true_types[table_id]]
predicted_types_array = [_type for table_id in table_true_types for _type in predicted_types[table_id]]
# print(table_true_types_array)
print(classification_report(table_true_types_array, predicted_types_array))

              precision    recall  f1-score   support

     address       1.00      1.00      1.00         1
 affiliation       0.00      0.00      0.00         0
         age       0.00      0.00      0.00         1
       album       0.00      0.00      0.00         0
        area       0.00      0.00      0.00         0
      artist       0.00      0.00      0.00         0
       brand       0.00      0.00      0.00         0
    category       0.20      0.25      0.22         4
        city       0.80      1.00      0.89         4
       class       0.00      0.00      0.00        19
        code       0.80      0.25      0.38        16
     company       0.00      0.00      0.00         2
   component       0.00      0.00      0.00         3
     country       0.00      0.00      0.00         1
      county       0.00      0.00      0.00         1
     creator       0.00      0.00      0.00         0
      credit       0.00      0.00      0.00         0
 description       0.74    

  'precision', 'predicted', average, warn_for)
  'recall', 'true', average, warn_for)
