In [1]:
%%time
import sys
# Put the project checkout at the front of the import path so the
# project-local modules imported further down in this cell can be resolved.
# The guard keeps re-running the cell from adding duplicate entries.
PROJECT_ROOT = r'G:\PythonProjects\WineRecognition2'
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import os
import json
import torch
import numpy as np
import pandas as pd
import sklearn_crfsuite
import eli5
from sklearn.model_selection import train_test_split

from data_master import DataGenerator, DataLoader, count_unk_foreach_tag, compute_model_confidence
from nn.utils import CustomDataset, generate_tag_to_ix
from mlflow_utils import log_mlflow_on_train
from features.features import sent2features

  from pyarrow import HadoopFileSystem


Wall time: 2.17 s


In [2]:
%%time
# --- Input artefacts --------------------------------------------------------
DATASET_PATH = r'G:\PythonProjects\WineRecognition2\data\text\halliday_winesearcher_menu_gen_samplesv2\Halliday_WineSearcher_MenuGenSamples.txt'
DICTIONARY_PATH = r"G:\PythonProjects\WineRecognition2\data\dictionaries\Dict-byword_Halliday_Winesearcher_Wine_AU-only_completed_rows"
VOCAB_PATH = 'G:/PythonProjects/WineRecognition2/data/vocabs/Words_Halliday_Wine_AU_WORD_NUMS.json'
DATAINFO_PATH = 'G:/PythonProjects/WineRecognition2/data_info.json'
LSTM_MODEL_PATH = 'G:/PythonProjects/WineRecognition2/artifacts/train/BiLSTM_CRF_17022022_185854/model/data/model.pth'

# --- Vocabulary / preprocessing switches (forwarded to CustomDataset) -------
CASE_SENSITIVE_VOCAB = False
USE_NUM2WORDS = True

# --- Runtime ----------------------------------------------------------------
# Device the LSTM model and its input tensors are moved to.
DEVICE = 'cuda'

# --- CRF hyper-parameters (forwarded to sklearn_crfsuite.CRF) ---------------
ALGORITHM = 'lbfgs'
C1 = 0.1  # L1 regularisation coefficient
C2 = 0.1  # L2 regularisation coefficient
MAX_ITERATIONS = 5
ALL_POSSIBLE_TRANSITIONS = True

# --- Evaluation -------------------------------------------------------------
# Fraction of the sentences held out for validation by train_test_split.
TEST_SIZE = 0.2

# --- Experiment tracking (only forwarded through run_params) ----------------
MODEL_NAME = "CRF_with_LSTM_features"
RUN_NAME = 'Train-LSTMfeatures'
OUTPUT_DIR = r"G:\PythonProjects\WineRecognition2\artifacts\train\test_exp"
START_TIME = ''  # left empty in this notebook

Wall time: 0 ns


In [3]:
%%time
# Restore the pickled LSTM model and switch it to eval mode.
# map_location remaps the stored tensors onto DEVICE while unpickling, so a
# checkpoint that was saved on a GPU can still be opened when DEVICE is 'cpu'
# (a bare torch.load would try to restore onto the original device first).
# NOTE(security): torch.load unpickles arbitrary Python objects -- only point
# LSTM_MODEL_PATH at checkpoints from a trusted source.
lstm_model = torch.load(LSTM_MODEL_PATH, map_location=DEVICE).to(DEVICE).eval()
freq_dict = DataLoader.load_frequency_dictionary(DICTIONARY_PATH, to_lowercase=True)

# Word -> index vocabulary used to encode sentences for the LSTM.
with open(VOCAB_PATH, encoding='utf-8') as file:
    word_to_ix = json.load(file)

# Explicit encoding (same as for the vocabulary file) so the result does not
# depend on the platform's default locale encoding.
with open(DATAINFO_PATH, encoding='utf-8') as file:
    keys = json.load(file)['keys']['all']

# Tag <-> index lookup tables built from the tag names listed in data_info.json.
tag_to_ix = generate_tag_to_ix(keys)
ix_to_tag = {value: key for key, value in tag_to_ix.items()}

tag_to_ix

Wall time: 1.71 s


{'Add_TradeName': 0,
 'Add_Brand': 1,
 'Add_KeyWordTrue': 2,
 'Add_KeyWordFalse': 3,
 'Add_GrapeVarieties': 4,
 'Add_GeoIndication': 5,
 'Add_WineType': 6,
 'Add_BottleSize': 7,
 'Add_Sweetness': 8,
 'Add_WineColor': 9,
 'Add_ClosureType': 10,
 'Add_Certificate': 11,
 'Add_Vintage': 12,
 'Add_Price': 13,
 'Punctuation': 14,
 'Other': 15}

In [4]:
# Read the dataset through a context manager so the file handle is closed
# deterministically instead of being left open until garbage collection.
with open(DATASET_PATH, encoding='utf-8') as dataset_file:
    sents = DataGenerator.generate_sents2(dataset_file.read().split('\n'))

# Debug switch: uncomment to work on a 300-sentence subsample.
# sents = sents[:300]
dataset = CustomDataset(
    sents,
    tag_to_ix,
    word_to_ix,
    case_sensitive=CASE_SENSITIVE_VOCAB,
    prepare_dataset=False,
    convert_nums2words=USE_NUM2WORDS,
)

In [5]:
def features_with_keys(sentence):
    """Convert per-token feature vectors into dicts keyed 'A0', 'A1', ...

    ``sentence`` is an iterable with one feature sequence per token. The
    result is a list with one dict per token that maps ``'A<position>'`` to
    the value found at that position of the token's feature sequence.
    """
    keyed_tokens = []
    for token_vector in sentence:
        keyed = {}
        for position, value in enumerate(token_vector):
            keyed[f'A{position}'] = value
        keyed_tokens.append(keyed)
    return keyed_tokens

def compute_features(x_sent, y_sent):
    """Build the per-token CRF feature dicts for one sentence.

    Each token gets the LSTM features of the pre-trained model (keyed
    'A0', 'A1', ... by features_with_keys) merged with the hand-crafted
    features produced by sent2features. On a key collision the hand-crafted
    value wins, because it is applied last via dict.update.

    Relies on the notebook-level objects ``dataset``, ``lstm_model``,
    ``freq_dict`` and ``DEVICE``.

    Args:
        x_sent: Token sequence of the sentence.
        y_sent: Tag sequence of the sentence; it is zipped with ``x_sent``
            to form the input of sent2features.

    Returns:
        A list of feature dicts, one per LSTM output position.
    """
    our_features = sent2features(list(zip(x_sent, y_sent)), freq_dict)
    x_tensor = torch.tensor(dataset.sentence_to_indices(x_sent), dtype=torch.int64)

    # Inference only: disable autograd here instead of relying on the call
    # site. This function is also invoked lazily from a generator, where a
    # surrounding torch.no_grad() block at definition time is no longer active.
    with torch.no_grad():
        lstm_output = lstm_model.get_lstm_features(x_tensor.to(DEVICE).unsqueeze(0))

    # Drop the batch dimension that unsqueeze(0) added and move to NumPy.
    final_features = features_with_keys(lstm_output.squeeze(0).detach().cpu().numpy())

    for i in range(len(x_sent)):
        final_features[i].update(our_features[i])

    return final_features

In [6]:
%%time
# Lazy stream of index tensors, one per sentence. It is a generator, so it
# can be iterated exactly once; re-run this cell before consuming it again.
X_tensors = (
    torch.tensor(dataset.sentence_to_indices(tokens), dtype=torch.int64)
    for tokens, _tags in dataset.raw_data()
)
# Gold tag sequences, in the same order as dataset.raw_data().
labels = [tags for _tokens, tags in dataset.raw_data()]
len(labels)

Wall time: 0 ns


300

In [7]:
%%time
# Fixed seed so the train/validation partition (and therefore every metric
# computed further down) is the same on each Restart & Run All.
SPLIT_SEED = 42
train_sents, val_sents = train_test_split(
    dataset.raw_data(), test_size=TEST_SIZE, random_state=SPLIT_SEED
)
len(train_sents), len(val_sents)

Wall time: 0 ns


(240, 60)

In [15]:
%%time
# Gold tag sequences need no model call, so they are collected directly.
y_train = [tags for _, tags in train_sents]
y_val = [tags for _, tags in val_sents]

# Training features stay a lazy, single-pass generator: compute_features only
# runs when the consumer iterates it, and the stream is exhausted afterwards.
X_train = (compute_features(tokens, tags) for tokens, tags in train_sents)

# Validation features are materialised right away, with autograd disabled.
with torch.no_grad():
    X_val = [compute_features(tokens, tags) for tokens, tags in val_sents]

Wall time: 10.1 s


In [16]:
%%time
crf_model = sklearn_crfsuite.CRF(
    algorithm=ALGORITHM,
    c1=C1,
    c2=C2,
    max_iterations=MAX_ITERATIONS,
    all_possible_transitions=ALL_POSSIBLE_TRANSITIONS,
)

# X_train is a lazy generator, so the LSTM features are actually computed
# while fit() iterates it -- hence autograd is switched off around the fit.
with torch.no_grad():
    crf_model.fit(X_train, y_train)

# Show the learned transition weights and the top 30 positive / negative
# state features per tag.
eli5.show_weights(crf_model, top=(30, 30))

Wall time: 45.7 s


From \ To,Add_BottleSize,Add_Brand,Add_ClosureType,Add_GeoIndication,Add_GrapeVarieties,Add_KeyWordFalse,Add_KeyWordTrue,Add_Price,Add_Sweetness,Add_TradeName,Add_Vintage,Add_WineColor,Add_WineType,Punctuation
Add_BottleSize,0.862,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Add_Brand,-0.099,1.72,0.0,0.0,0.076,0.0,0.0,0.0,0.0,-0.588,-0.078,0.0,0.0,0.0
Add_ClosureType,0.0,0.0,0.648,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Add_GeoIndication,0.0,0.0,0.0,0.495,0.0,0.092,-0.331,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Add_GrapeVarieties,0.0,0.0,0.0,0.001,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Add_KeyWordFalse,0.183,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Add_KeyWordTrue,-0.221,-0.451,0.0,0.0,0.0,0.0,1.043,0.0,0.0,0.0,0.0,0.0,0.0,0.266
Add_Price,-0.063,0.0,0.0,0.0,0.0,0.0,0.0,0.95,0.0,0.0,-0.402,0.0,0.0,0.0
Add_Sweetness,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Add_TradeName,0.0,0.701,0.0,0.0,0.0,0.0,-0.117,0.0,0.0,0.81,-0.008,0.0,0.0,0.336

Weight?,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0,Unnamed: 13_level_0
Weight?,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
Weight?,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
Weight?,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3
Weight?,Feature,Unnamed: 2_level_4,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4,Unnamed: 8_level_4,Unnamed: 9_level_4,Unnamed: 10_level_4,Unnamed: 11_level_4,Unnamed: 12_level_4,Unnamed: 13_level_4
Weight?,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5,Unnamed: 4_level_5,Unnamed: 5_level_5,Unnamed: 6_level_5,Unnamed: 7_level_5,Unnamed: 8_level_5,Unnamed: 9_level_5,Unnamed: 10_level_5,Unnamed: 11_level_5,Unnamed: 12_level_5,Unnamed: 13_level_5
Weight?,Feature,Unnamed: 2_level_6,Unnamed: 3_level_6,Unnamed: 4_level_6,Unnamed: 5_level_6,Unnamed: 6_level_6,Unnamed: 7_level_6,Unnamed: 8_level_6,Unnamed: 9_level_6,Unnamed: 10_level_6,Unnamed: 11_level_6,Unnamed: 12_level_6,Unnamed: 13_level_6
Weight?,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7,Unnamed: 4_level_7,Unnamed: 5_level_7,Unnamed: 6_level_7,Unnamed: 7_level_7,Unnamed: 8_level_7,Unnamed: 9_level_7,Unnamed: 10_level_7,Unnamed: 11_level_7,Unnamed: 12_level_7,Unnamed: 13_level_7
Weight?,Feature,Unnamed: 2_level_8,Unnamed: 3_level_8,Unnamed: 4_level_8,Unnamed: 5_level_8,Unnamed: 6_level_8,Unnamed: 7_level_8,Unnamed: 8_level_8,Unnamed: 9_level_8,Unnamed: 10_level_8,Unnamed: 11_level_8,Unnamed: 12_level_8,Unnamed: 13_level_8
Weight?,Feature,Unnamed: 2_level_9,Unnamed: 3_level_9,Unnamed: 4_level_9,Unnamed: 5_level_9,Unnamed: 6_level_9,Unnamed: 7_level_9,Unnamed: 8_level_9,Unnamed: 9_level_9,Unnamed: 10_level_9,Unnamed: 11_level_9,Unnamed: 12_level_9,Unnamed: 13_level_9
Weight?,Feature,Unnamed: 2_level_10,Unnamed: 3_level_10,Unnamed: 4_level_10,Unnamed: 5_level_10,Unnamed: 6_level_10,Unnamed: 7_level_10,Unnamed: 8_level_10,Unnamed: 9_level_10,Unnamed: 10_level_10,Unnamed: 11_level_10,Unnamed: 12_level_10,Unnamed: 13_level_10
Weight?,Feature,Unnamed: 2_level_11,Unnamed: 3_level_11,Unnamed: 4_level_11,Unnamed: 5_level_11,Unnamed: 6_level_11,Unnamed: 7_level_11,Unnamed: 8_level_11,Unnamed: 9_level_11,Unnamed: 10_level_11,Unnamed: 11_level_11,Unnamed: 12_level_11,Unnamed: 13_level_11
Weight?,Feature,Unnamed: 2_level_12,Unnamed: 3_level_12,Unnamed: 4_level_12,Unnamed: 5_level_12,Unnamed: 6_level_12,Unnamed: 7_level_12,Unnamed: 8_level_12,Unnamed: 9_level_12,Unnamed: 10_level_12,Unnamed: 11_level_12,Unnamed: 12_level_12,Unnamed: 13_level_12
Weight?,Feature,Unnamed: 2_level_13,Unnamed: 3_level_13,Unnamed: 4_level_13,Unnamed: 5_level_13,Unnamed: 6_level_13,Unnamed: 7_level_13,Unnamed: 8_level_13,Unnamed: 9_level_13,Unnamed: 10_level_13,Unnamed: 11_level_13,Unnamed: 12_level_13,Unnamed: 13_level_13
+0.944,A15,,,,,,,,,,,,
+0.847,A56,,,,,,,,,,,,
+0.774,A70,,,,,,,,,,,,
+0.721,A111,,,,,,,,,,,,
+0.712,A2,,,,,,,,,,,,
+0.703,A17,,,,,,,,,,,,
+0.629,A6,,,,,,,,,,,,
+0.591,A95,,,,,,,,,,,,
+0.578,A37,,,,,,,,,,,,
+0.558,A72,,,,,,,,,,,,

Weight?,Feature
+0.944,A15
+0.847,A56
+0.774,A70
+0.721,A111
+0.712,A2
+0.703,A17
+0.629,A6
+0.591,A95
+0.578,A37
+0.558,A72

Weight?,Feature
+1.322,A20
+1.185,A17
+1.031,A100
+1.001,A36
+0.983,A10
+0.957,A82
+0.909,A51
+0.853,A18
+0.798,A66
+0.747,Add_Brand

Weight?,Feature
+1.033,Add_ClosureType
+0.810,A85
+0.800,A90
+0.659,A74
+0.559,A35
+0.549,A13
+0.516,A126
+0.467,A54
+0.390,A103
+0.329,A81

Weight?,Feature
+1.516,Add_GeoIndication
+0.964,A118
+0.943,A85
+0.757,A68
+0.752,A108
+0.635,A19
+0.588,A77
+0.587,A70
+0.573,A27
+0.519,+1:BGram.Add_GeoIndication

Weight?,Feature
+1.393,Add_GrapeVarieties
+1.006,A104
+0.948,A36
+0.787,A103
+0.743,A35
+0.716,A105
+0.641,A116
+0.625,A92
+0.596,A102
+0.589,A33

Weight?,Feature
+1.186,word[-3:]:ies
+1.186,word.lower():series
+0.872,Add_KeyWordFalse
+0.679,A6
+0.524,A120
+0.458,A99
+0.456,A1
+0.359,A116
+0.332,A28
+0.316,A44

Weight?,Feature
+1.104,A28
+0.959,+1:BGram.Add_KeyWordTrue
+0.955,A86
+0.924,A46
+0.873,A98
+0.804,A4
+0.756,A95
+0.751,A66
+0.704,A62
+0.625,A6

Weight?,Feature
+1.354,A46
+0.981,A54
+0.957,A62
+0.952,A122
+0.739,A29
+0.690,A45
+0.681,A0
+0.651,A30
+0.622,A50
+0.585,A111

Weight?,Feature
+0.727,Add_Sweetness
+0.727,word[-3:]:dry
+0.727,word.lower():dry
+0.616,A90
+0.590,A27
+0.566,A49
+0.535,A118
+0.475,A95
+0.449,A72
+0.435,A89

Weight?,Feature
+1.205,A48
+1.030,A27
+1.007,A11
+0.997,A41
+0.814,A96
+0.780,A26
+0.767,A73
+0.698,A65
+0.689,Add_TradeName
+0.665,A120

Weight?,Feature
+0.981,A15
+0.824,A1
+0.766,+1:word.lower():thousand
+0.743,A8
+0.705,A49
+0.631,A71
+0.562,A3
+0.548,A41
+0.536,A39
+0.520,A5

Weight?,Feature
+1.165,Add_WineColor
+0.747,A103
+0.656,A56
+0.616,A68
+0.604,A70
+0.591,A87
+0.508,A102
+0.499,A23
+0.435,A92
+0.435,A52

Weight?,Feature
+0.923,A63
+0.693,Add_WineType
+0.693,word[-3:]:ill
+0.693,word.lower():still
+0.618,A33
+0.466,A44
+0.438,A1
+0.434,A48
+0.430,A38
+0.403,A89

Weight?,Feature
+0.901,A74
+0.874,A39
+0.787,A94
+0.597,A101
+0.586,A52
+0.564,A45
+0.554,A122
+0.482,A14
+0.430,A21
+0.412,A114


In [17]:
%%time
# X_val is an already materialised list of feature dicts and sklearn-crfsuite
# runs no torch code, so no autograd guard is needed around these calls.
y_pred = crf_model.predict(X_val)
marginals = crf_model.predict_marginals(X_val)

Wall time: 90 ms


In [20]:
# X_tensors is a single-pass generator: this call consumes it, so the cell
# that defines X_tensors must be re-run before executing this one again.
unk_foreach_tag = count_unk_foreach_tag(
    X_tensors,
    labels,
    keys,
    dataset.word_to_ix[dataset.unk],  # vocabulary index of the unknown-word token
)

In [21]:
# Confidence values derived from the CRF marginals of the validation set
# (exact granularity is defined by compute_model_confidence in data_master);
# their mean is reported as 'models_confidence' in run_params.
confs = compute_model_confidence(marginals)
# Probability table built from the same marginals together with the
# validation sentences; forwarded as 'prob_table' in run_params.
prob_table = DataGenerator.generate_probability_table(marginals, val_sents)

In [22]:
%%time
# One entry per validation sentence: a list of (token, gold tag, predicted tag)
# triples. y_pred holds one prediction sequence per element of val_sents.
test_eval = [
    list(zip(tokens, gold_tags, predicted_tags))
    for (tokens, gold_tags), predicted_tags in zip(val_sents, y_pred)
]

Wall time: 0 ns


In [23]:
%%time
# Everything that describes this run, gathered in one mapping for logging.
run_params = dict(
    # inputs
    dataset_path=DATASET_PATH,
    lstm_model_path=LSTM_MODEL_PATH,
    vocab_path=VOCAB_PATH,
    case_sensitive_vocab=CASE_SENSITIVE_VOCAB,
    dictionary_path=DICTIONARY_PATH,
    datainfo_path=DATAINFO_PATH,
    device=DEVICE,
    # model and CRF hyper-parameters
    model_name=MODEL_NAME,
    algorithm=ALGORITHM,
    c1=C1,
    c2=C2,
    max_iterations=MAX_ITERATIONS,
    all_possible_transitions=ALL_POSSIBLE_TRANSITIONS,
    test_size=TEST_SIZE,
    # run bookkeeping
    runname=RUN_NAME,
    start_time=START_TIME,
    output_dir=OUTPUT_DIR,
    # derived diagnostics
    models_confidence=np.mean(confs),
    unk_foreach_tag=json.dumps(unk_foreach_tag),
    prob_table=prob_table,
    use_num2words=USE_NUM2WORDS,
)

Wall time: 0 ns


In [None]:
%%time
# Hand the run configuration, the fitted CRF, the gold and predicted validation
# tags and the per-token comparison to the project's logging helper.
# NOTE(review): what exactly gets logged is defined in mlflow_utils, which is
# not visible from this notebook.
log_mlflow_on_train(
    run_params=run_params,
    model=crf_model,
    y_true=y_val,
    y_pred=y_pred,
    test_eval=test_eval
)