# This notebook enables training and testing of Sherlock.
The procedure is:
- Load train, val, test datasets (should be preprocessed)
- Initialize model using the "pretrained" model or by training one from scratch.
- Evaluate and analyse the model predictions.

In [1]:
# Export a fixed hash seed into the environment.
# NOTE(review): CPython reads PYTHONHASHSEED at interpreter start-up, so this
# presumably only affects processes started afterwards, not the running kernel
# - confirm this gives the intended determinism.
%env PYTHONHASHSEED=13
# Load the autoreload extension and set it to mode 2 (reload all modules
# before executing code), so edits to project modules are picked up.
%load_ext autoreload
%autoreload 2

env: PYTHONHASHSEED=13


In [2]:
# ID under which the retrained model is stored and reloaded.
# Further down, predictions can also be made with the original model by
# switching this ID to "sherlock".
model_id = "retrained_sherlock"

In [3]:
from ast import literal_eval
from collections import Counter
from datetime import datetime

import numpy as np
import pandas as pd

from sklearn.metrics import f1_score, classification_report

from sherlock.deploy.model import SherlockModel

## Load datasets for training, validation, testing

In [4]:
# Time how long it takes to read the training features and labels.
start = datetime.now()
print(f'Started at {start}')

X_train = pd.read_parquet('../data/data/processed/train.parquet')

# Flatten the label frame to 1-D and lower-case every label in a single pass.
y_train = np.array(
    [
        label.lower()
        for label in pd.read_parquet('../data/data/raw/train_labels.parquet').values.flatten()
    ]
)

print(f'Load data (train) process took {datetime.now() - start} seconds.')

Started at 2022-11-28 06:29:28.593947
Load data (train) process took 0:00:00.604680 seconds.


In [5]:
# Number of distinct labels present in the training labels.
np.unique(y_train).size

35

In [6]:
# Collapse the per-column dtypes into a set so only the distinct ones are shown.
print('Distinct types for columns in the Dataframe (should be all float32):')
print({*X_train.dtypes})

Distinct types for columns in the Dataframe (should be all float32):
{dtype('float32')}


In [7]:
# Time how long it takes to read the validation features and labels.
start = datetime.now()
print(f'Started at {start}')

X_validation = pd.read_parquet('../data/data/processed/validation.parquet')

# Flatten the label frame to 1-D and lower-case every label in a single pass.
y_validation = np.array(
    [
        label.lower()
        for label in pd.read_parquet('../data/data/raw/val_labels.parquet').values.flatten()
    ]
)

print(f'Load data (validation) process took {datetime.now() - start} seconds.')

Started at 2022-11-28 06:29:29.366814
Load data (validation) process took 0:00:00.116432 seconds.


In [8]:
# Time how long it takes to read the test features and labels.
start = datetime.now()
print(f'Started at {start}')

X_test = pd.read_parquet('../data/data/processed/test.parquet')
y_test = pd.read_parquet('../data/data/raw/test_labels.parquet').values.flatten()

# Lower-case the labels, as is done for the train and validation labels.
y_test = np.array([label.lower() for label in y_test])

# Read the clock once so the reported finish time and the elapsed time agree;
# two separate datetime.now() calls yield slightly different instants.
finished = datetime.now()
print(f'Finished at {finished}, took {finished - start} seconds')

Started at 2022-11-28 06:29:29.525791
Finished at 2022-11-28 06:29:29.674027, took 0:00:00.148245 seconds


## Initialize the model
Two options:
- Load Sherlock model with pretrained weights
- Fit Sherlock model from scratch

In [9]:
# ID of the model that the next cell tries to load and, failing that, trains.
model_id = 'retrained_sherlock'

In [10]:
model = SherlockModel()
try:
    # Reuse a previously stored model for `model_id` when one can be loaded.
    model.initialize_model_from_json(with_weights=True, model_id=model_id)
except Exception as load_error:
    # Catch `Exception` rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit are not swallowed and silently turned into
    # a full retraining run. The reason for the fallback is reported.
    print(f'Could not load model "{model_id}" ({load_error!r}); training from scratch.')

    start = datetime.now()
    print(f'Started at {start}')
    # Model will be stored with ID `model_id`
    model.fit(X_train, y_train, X_validation, y_validation, model_id=model_id)
    # Persist the weights before announcing success, so the message below is
    # only printed once the model has actually been saved.
    model.store_weights(model_id=model_id)

    # Read the clock once so the finish time and the elapsed time agree.
    finished = datetime.now()
    print('Trained and saved new model.')
    print(f'Finished at {finished}, took {finished - start} seconds')

Started at 2022-11-28 06:29:29.746098


2022-11-28 06:29:30.011197: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-28 06:29:30.014086: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
  super(Adam, self).__init__(name, **kwargs)


Epoch 1/10000


W1128 06:29:31.058537 46912499975424 ag_logging.py:142] AutoGraph could not transform <function Model.make_train_function.<locals>.train_function at 0x2aab55bbb9d0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Constant'


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Constant'

W1128 06:29:36.107420 46912499975424 ag_logging.py:142] AutoGraph could not transform <function Model.make_test_function.<locals>.test_function at 0x2aad1bfb9e50> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Constant'


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Constant'
Epoch 2/10000
Epoch 3/10000
Epoch 4/10000
Epoch 5/10000
Epoch 6/10000
Epoch 7/10000
Epoch 8/10000
Epoch 9/10000
Epoch 10/10000
Epoch 11/10000
Epoch 12/10000
Epoch 13/10000
Epoch 14/10000
Epoch 15/10000
Epoch 16/10000
Epoch 17/10000
Epoch 18/10000
Epoch 19/10000
Epoch 20/10000
Epoch 21/10000
Epoch 22/10000
Epoch 23/10000
Epoch 24/10000
Epoch 25/10000
Epoch 26/10000
Epoch 27/10000
Epoch 28/10000
Epoch 29/10000
Trained and saved new model.
Finished at 2022-11-28 06:31:25.928454, took 0:01:56.182366 seconds


In [11]:
# Start from an empty placeholder; the prediction cell below overwrites it.
predicted_labels = list()

### Make prediction

In [12]:
# Predict labels for X_test and lower-case them, matching the lower-cased y_test.
predicted_labels = np.array(
    [label.lower() for label in model.predict(X_test, model_id)]
)

W1128 06:31:27.972228 46912499975424 ag_logging.py:142] AutoGraph could not transform <function Model.make_predict_function.<locals>.predict_function at 0x2aad1ad248b0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Constant'


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Constant'


In [13]:
size = len(y_test)

print(f'prediction count {len(predicted_labels)}, type = {type(predicted_labels)}')

# Should be fully deterministic too.
# Weighted F1 over the first `size` expected / predicted labels; as the last
# expression of the cell, its value is displayed.
f1_score(
    y_true=y_test[:size],
    y_pred=predicted_labels[:size],
    average="weighted",
)

prediction count 9932, type = <class 'numpy.ndarray'>


0.8968930455827515

In [14]:
# If using the original model, model_id should be replaced with "sherlock"
#model_id = "sherlock"
# NOTE(review): allow_pickle=True lets np.load unpickle objects stored in the
# file; only load class files that come from a trusted source.
classes = np.load(f"../model_files/classes_{model_id}.npy", allow_pickle=True)

report = classification_report(y_test, predicted_labels, output_dict=True)

# Keep only the report entries that are score dicts with an 'f1-score' and whose
# key is one of the model's classes.
class_scores = [
    (label, scores)
    for label, scores in report.items()
    if isinstance(scores, dict) and 'f1-score' in scores and label in classes
]

# Order the entries from highest to lowest f1-score.
class_scores.sort(key=lambda entry: entry[1]['f1-score'], reverse=True)

### Top 5 Types

In [15]:
# Header row of the tab-separated score table.
print("\t\tf1-score\tprecision\trecall\t\tsupport")

# `class_scores` is sorted by descending f1-score, so the first five entries
# are the best-scoring types.
for label, scores in class_scores[:5]:
    # Labels of eight or more characters get one tab, shorter ones get two.
    padding = '\t' if len(label) >= 8 else '\t\t'

    print(f"{label}{padding}{scores['f1-score']:.3f}\t\t{scores['precision']:.3f}\t\t{scores['recall']:.3f}\t\t{scores['support']}")

		f1-score	precision	recall		support
guuid		1.000		1.000		1.000		120
birth date	0.993		0.986		1.000		72
industry	0.982		0.969		0.995		444
currency	0.975		0.983		0.967		61
sex		0.968		0.969		0.967		450


### Bottom 5 Types

In [16]:
# Header row of the tab-separated score table.
print("\t\tf1-score\tprecision\trecall\t\tsupport")

# `class_scores` is sorted by descending f1-score, so the last five entries are
# the worst-scoring types. `[-5:]` is used instead of `[len - 5:len]` because the
# latter produces a negative start index when there are fewer than five entries,
# which silently drops the leading ones instead of showing all of them.
for label, scores in class_scores[-5:]:
    # Labels of eight or more characters get one tab, shorter ones get two.
    padding = '\t' if len(label) >= 8 else '\t\t'

    print(f"{label}{padding}{scores['f1-score']:.3f}\t\t{scores['precision']:.3f}\t\t{scores['recall']:.3f}\t\t{scores['support']}")

		f1-score	precision	recall		support
brand		0.761		0.805		0.721		86
rank		0.757		0.765		0.749		447
range		0.739		0.829		0.667		87
person		0.649		0.787		0.552		87
sales		0.533		0.741		0.417		48


### All Scores

In [17]:
# Text report for all classes, with scores shown to three decimals.
print(classification_report(y_true=y_test, y_pred=predicted_labels, digits=3))

                 precision    recall  f1-score   support

        address      0.912     0.971     0.941       450
            age      0.892     0.963     0.926       455
           area      0.900     0.785     0.839       298
     birth date      0.986     1.000     0.993        72
    birth place      0.982     0.873     0.924        63
          brand      0.805     0.721     0.761        86
           city      0.879     0.912     0.895       445
      continent      0.811     0.882     0.845        34
        country      0.917     0.941     0.929       456
         county      0.939     0.966     0.952       444
       currency      0.983     0.967     0.975        61
            day      0.894     0.884     0.889       456
       duration      0.918     0.940     0.929       450
          guuid      1.000     1.000     1.000       120
       industry      0.969     0.995     0.982       444
       language      0.938     0.955     0.946       221
       location      0.918    

## Review errors

In [18]:
size = len(y_test)
mismatches = []

# Expected labels whose individual errors are printed below. This has to be a
# real tuple: `('address')` is just the string 'address', which would turn the
# `in` check into a substring test instead of a membership test.
zoom_labels = ('address',)

for idx, (expected, predicted) in enumerate(zip(y_test[:size], predicted_labels[:size])):
    if expected == predicted:
        continue

    # Record the expected label of every wrong prediction.
    mismatches.append(expected)

    # zoom in to specific errors. Use the index in the next step
    if expected in zoom_labels:
        print(f'[{idx}] expected "{expected}" but predicted "{predicted}"')

f1 = f1_score(y_test[:size], predicted_labels[:size], average="weighted")
print(f'Total mismatches: {len(mismatches)} (F1 score: {f1})')

# Count the mismatches per expected label.
data = Counter(mismatches)
data.most_common()   # Returns all unique items and their counts

[33] expected "address" but predicted "name"
[100] expected "address" but predicted "day"
[101] expected "address" but predicted "area"
[103] expected "address" but predicted "type"
[121] expected "address" but predicted "range"
[135] expected "address" but predicted "county"
[181] expected "address" but predicted "type"
[277] expected "address" but predicted "language"
[338] expected "address" but predicted "name"
[368] expected "address" but predicted "name"
[375] expected "address" but predicted "duration"
[402] expected "address" but predicted "product"
[408] expected "address" but predicted "duration"
Total mismatches: 1011 (F1 score: 0.8968930455827515)


[('rank', 112),
 ('location', 104),
 ('region', 75),
 ('name', 73),
 ('area', 64),
 ('day', 53),
 ('product', 46),
 ('city', 39),
 ('person', 39),
 ('status', 35),
 ('type', 34),
 ('range', 29),
 ('sales', 28),
 ('state', 28),
 ('country', 27),
 ('duration', 27),
 ('year', 26),
 ('brand', 24),
 ('manufacturer', 21),
 ('age', 17),
 ('county', 15),
 ('order', 15),
 ('sex', 15),
 ('address', 13),
 ('nationality', 11),
 ('language', 10),
 ('symbol', 9),
 ('birth place', 8),
 ('tax_id', 6),
 ('continent', 4),
 ('currency', 2),
 ('industry', 2)]

In [19]:
# Raw test values, used below to inspect individual samples by position.
test_samples = pd.read_parquet(path='../data/data/raw/test_values.parquet')

In [20]:
# Position of the test sample to inspect (e.g. an index printed by the error review).
idx = 57

original = test_samples.iloc[idx]
# Parse every cell of the selected row with literal_eval
# (presumably each cell holds the string form of a Python literal).
converted = [literal_eval(cell) for cell in original]

print(f'Predicted "{predicted_labels[idx]}", actual label "{y_test[idx]}". Actual values:\n{converted}')

Predicted "address", actual label "address". Actual values:
[['Gillette, WY', 'Gillette, WY', '1720 W Warlow Drive, Gillette, WY', '2675 Ledoux Avenue, Gillette, WY', '2501 Ledoux Avenue, Gillette, WY', '4500 Running W Drive, Gillette, WY', '4500 Running W Drive, Gillette, WY']]


In [21]:
# Display all (label, scores) entries, ordered from highest to lowest f1-score.
class_scores

[('guuid', {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 120}),
 ('birth date',
  {'precision': 0.9863013698630136,
   'recall': 1.0,
   'f1-score': 0.993103448275862,
   'support': 72}),
 ('industry',
  {'precision': 0.9692982456140351,
   'recall': 0.9954954954954955,
   'f1-score': 0.9822222222222222,
   'support': 444}),
 ('currency',
  {'precision': 0.9833333333333333,
   'recall': 0.9672131147540983,
   'f1-score': 0.9752066115702478,
   'support': 61}),
 ('sex',
  {'precision': 0.9688195991091314,
   'recall': 0.9666666666666667,
   'f1-score': 0.967741935483871,
   'support': 450}),
 ('symbol',
  {'precision': 0.9548872180451128,
   'recall': 0.9657794676806084,
   'f1-score': 0.9603024574669187,
   'support': 263}),
 ('year',
  {'precision': 0.9681818181818181,
   'recall': 0.9424778761061947,
   'f1-score': 0.9551569506726458,
   'support': 452}),
 ('county',
  {'precision': 0.9387308533916849,
   'recall': 0.9662162162162162,
   'f1-score': 0.9522752497225305