In [1]:
%%time
import sys

# Put the project checkout at the front of the module search path, but only once,
# so re-running this cell does not pile up duplicate entries.
# NOTE(review): presumably this is what makes the project-local imports below
# (data_master, mlflow_utils, features) resolvable -- confirm.
PROJECT_ROOT = r'G:\PythonProjects\WineRecognition2'
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import os
import json
import torch
import numpy as np
import pandas as pd
import sklearn_crfsuite
import eli5
from sklearn.model_selection import train_test_split

from data_master import DataGenerator, DataLoader, count_unk_foreach_tag, compute_model_confidence
from mlflow_utils import log_mlflow_on_train
from features import features

Wall time: 1.95 s


In [2]:
%%time
# Run configuration. Every constant in this cell is also copied into `run_params` for logging.
# Text dataset that is read, split on newlines and turned into sentences.
DATASET_PATH = r"G:\PythonProjects\WineRecognition2\data\text\Halliday_Wine_AU-only_completed_rows-complex.txt"
# Serialized PyTorch model (loaded with torch.load); its embedding and lstm layers supply features.
LSTM_MODEL_PATH = r"G:\PythonProjects\WineRecognition2\artifacts\train\BiLSTM_CRF_10112021_030733\model\data\model.pth"
# JSON file holding the word -> index mapping used to encode sentences for the LSTM.
VOCAB_PATH = r"G:\PythonProjects\WineRecognition2\data\vocabs\Words_Halliday_Wine_AU.json"
# When False, vocabulary keys and input words are lower-cased and the unknown token is 'unk'
# (otherwise 'UNK').
CASE_SENSITIVE_VOCAB = False
# Location handed to DataLoader.load_frequency_dictionary.
DICTIONARY_PATH = r"G:\PythonProjects\WineRecognition2\data\dictionaries\Dict-byword_Halliday_Winesearcher_Wine_AU-only_completed_rows"
# JSON file whose ['keys']['all'] entry is passed to count_unk_foreach_tag.
DATAINFO_PATH = 'G:/PythonProjects/WineRecognition2/data_info.json'
# Device the LSTM model and its input tensors are moved to.
DEVICE = 'cuda'
# Not used by the computation in this notebook; only recorded in `run_params`.
MODEL_NAME = "CRF_with_LSTM_features"
# The next five values are passed to sklearn_crfsuite.CRF as
# algorithm / c1 / c2 / max_iterations / all_possible_transitions.
ALGORITHM = 'lbfgs'
C1 = 0.1
C2 = 0.1
MAX_ITERATIONS = 5
ALL_POSSIBLE_TRANSITIONS = True
# Passed to train_test_split as test_size.
TEST_SIZE = 0.2
# The remaining values are only recorded in `run_params` (keys 'runname', 'output_dir', 'start_time').
RUN_NAME = 'Train-LSTMfeatures'
OUTPUT_DIR = r"G:\PythonProjects\WineRecognition2\artifacts\train\test_exp"
# NOTE(review): empty here -- presumably filled in by whatever launches the notebook; confirm.
START_TIME = ''

Wall time: 0 ns


In [3]:
%%time
# Load the serialized LSTM tagger. `map_location` remaps the stored tensors onto DEVICE while
# unpickling, so a checkpoint that was saved on a different device can still be loaded;
# `.eval()` puts the module into evaluation mode for feature extraction.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load trusted checkpoints.
model = torch.load(LSTM_MODEL_PATH, map_location=DEVICE).to(DEVICE).eval()
freq_dict = DataLoader.load_frequency_dictionary(DICTIONARY_PATH)
with open(VOCAB_PATH, encoding='utf-8') as file:
    word_to_ix = json.load(file)
if not CASE_SENSITIVE_VOCAB:
    # Lower-case every key; if two keys differ only by case, the one that comes later wins.
    word_to_ix = {word.lower(): index for word, index in word_to_ix.items()}

Wall time: 1.65 s


In [4]:
%%time
# getting features from lstm

def get_lstm_features(model, x):
    """Run ``x`` through ``model.embedding`` and then ``model.lstm``.

    ``model.lstm`` is expected to return a pair; only its first element is
    returned, the second is discarded.
    """
    embedded = model.embedding(x)
    lstm_output, _ = model.lstm(embedded)
    return lstm_output

def features_with_keys(sentence):
    """Convert each token's feature vector into a dict keyed 'A0', 'A1', ... by position.

    Returns one dict per element of ``sentence``, in the same order.
    """
    keyed_sentence = []
    for token_vector in sentence:
        keyed_sentence.append(
            {f'A{position}': value for position, value in enumerate(token_vector)}
        )
    return keyed_sentence

# Token that stands in for out-of-vocabulary words; its spelling follows the vocab casing.
unk = 'UNK' if CASE_SENSITIVE_VOCAB else 'unk'
# Looked up once up front: a vocabulary without the unknown token fails here, before any
# model work is done, instead of at the first out-of-vocabulary word.
unk_index = word_to_ix[unk]
X_tensors = []

# Read the dataset through a context manager so the file handle is closed deterministically
# (a bare open(...).read() leaves closing to the garbage collector).
with open(DATASET_PATH, encoding='utf-8') as dataset_file:
    dataset_lines = dataset_file.read().split('\n')

sents = DataGenerator.generate_sents2(dataset_lines)
filename = os.path.splitext(os.path.split(DATASET_PATH)[-1])[0]
metadata = {'features': [], 'labels': []}

# Inference only: no autograd graph is needed for the LSTM forward passes.
with torch.no_grad():
    for x_sent, y_sent in sents:
        # Hand-crafted features are computed from the words before any lower-casing.
        our_features = features.sent2features(list(zip(x_sent, y_sent)), freq_dict)

        if not CASE_SENSITIVE_VOCAB:
            x_sent = [word.lower() for word in x_sent]

        # Encode the sentence as vocabulary indices, falling back to the unknown-token index.
        x_tensor = torch.tensor(
            [word_to_ix.get(word, unk_index) for word in x_sent],
            dtype=torch.int64
        )
        X_tensors.append(x_tensor)

        # unsqueeze(0) / squeeze(0) add and remove a leading dimension around the model call.
        final_features = get_lstm_features(model, x_tensor.to(DEVICE).unsqueeze(0))
        final_features = features_with_keys(final_features.squeeze(0).detach().cpu().numpy())
        # Merge the hand-crafted features into the per-token LSTM feature dicts.
        for i in range(len(x_sent)):
            final_features[i].update(our_features[i])
        metadata['features'].append(final_features)
        metadata['labels'].append(y_sent)

labels = metadata['labels']
x = metadata['features']
len(x), len(labels)

Wall time: 26.4 s


(100, 100)

In [5]:
%%time
# Fixed seed so the train/validation partition (and every metric derived from it) is the same
# on each Restart & Run All; without random_state the split changes on every run.
SPLIT_SEED = 42
# x, labels and sents are split in one call so the three views stay aligned sentence by sentence.
X_train, X_val, y_train, y_val, train_sents, val_sents = train_test_split(
    x, labels, sents, test_size=TEST_SIZE, random_state=SPLIT_SEED
)
len(X_train), len(X_val), len(y_train), len(y_val), len(train_sents), len(val_sents)

Wall time: 0 ns


(80, 20, 80, 20, 80, 20)

In [6]:
%%time
# CRF hyper-parameters, all taken from the configuration constants.
crf_settings = {
    'algorithm': ALGORITHM,
    'c1': C1,
    'c2': C2,
    'max_iterations': MAX_ITERATIONS,
    'all_possible_transitions': ALL_POSSIBLE_TRANSITIONS,
}
# NOTE(review): from here on `model` is the CRF; it replaces the LSTM that was bound to the
# same name when the checkpoint was loaded, so the feature-extraction cell cannot be re-run
# after this one without reloading the LSTM.
model = sklearn_crfsuite.CRF(**crf_settings)
model.fit(X_train, y_train)
# Last expression of the cell: renders the learned weights (top 30 positive / 30 negative).
eli5.show_weights(model, top=(30, 30))

Wall time: 263 ms




From \ To,Add_BottleSize,Add_Brand,Add_ClosureType,Add_GeoIndication,Add_GrapeVarieties,Add_KeyWordTrue,Add_Price,Add_Sweetness,Add_TradeName,Add_Vintage,Add_WineColor,Add_WineType,Punctuation
Add_BottleSize,-0.048,-0.029,0.025,-0.027,-0.038,-0.025,0.066,0.024,-0.043,0.001,0.047,0.008,-0.022
Add_Brand,-0.033,0.144,-0.046,0.01,0.055,0.122,-0.036,-0.035,-0.194,-0.024,-0.018,-0.04,0.0
Add_ClosureType,-0.003,-0.033,0.196,-0.032,-0.054,-0.026,0.007,-0.025,-0.046,0.016,0.017,0.029,-0.024
Add_GeoIndication,-0.019,-0.062,-0.032,0.185,0.083,-0.033,-0.037,0.021,-0.072,-0.027,0.016,-0.017,-0.028
Add_GrapeVarieties,0.027,-0.09,-0.003,0.017,0.056,-0.002,-0.042,0.088,-0.078,-0.026,0.012,0.022,0.067
Add_KeyWordTrue,-0.0,-0.026,-0.004,0.035,0.034,0.158,-0.022,-0.02,-0.039,-0.016,-0.021,-0.026,-0.014
Add_Price,-0.026,-0.025,-0.044,-0.021,-0.035,-0.015,-0.031,-0.022,-0.039,-0.021,-0.025,-0.029,-0.015
Add_Sweetness,0.026,-0.027,0.003,-0.029,-0.041,-0.024,0.069,-0.034,-0.052,0.019,0.088,0.056,-0.021
Add_TradeName,-0.035,0.114,-0.039,-0.079,-0.107,-0.033,-0.037,-0.033,0.375,-0.026,-0.036,-0.037,0.13
Add_Vintage,-0.032,-0.032,-0.032,-0.022,-0.027,-0.021,0.266,-0.023,0.231,-0.026,-0.029,-0.033,-0.021

Weight?,Feature,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,Unnamed: 10_level_0,Unnamed: 11_level_0,Unnamed: 12_level_0
Weight?,Feature,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
Weight?,Feature,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2
Weight?,Feature,Unnamed: 2_level_3,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3
Weight?,Feature,Unnamed: 2_level_4,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4,Unnamed: 8_level_4,Unnamed: 9_level_4,Unnamed: 10_level_4,Unnamed: 11_level_4,Unnamed: 12_level_4
Weight?,Feature,Unnamed: 2_level_5,Unnamed: 3_level_5,Unnamed: 4_level_5,Unnamed: 5_level_5,Unnamed: 6_level_5,Unnamed: 7_level_5,Unnamed: 8_level_5,Unnamed: 9_level_5,Unnamed: 10_level_5,Unnamed: 11_level_5,Unnamed: 12_level_5
Weight?,Feature,Unnamed: 2_level_6,Unnamed: 3_level_6,Unnamed: 4_level_6,Unnamed: 5_level_6,Unnamed: 6_level_6,Unnamed: 7_level_6,Unnamed: 8_level_6,Unnamed: 9_level_6,Unnamed: 10_level_6,Unnamed: 11_level_6,Unnamed: 12_level_6
Weight?,Feature,Unnamed: 2_level_7,Unnamed: 3_level_7,Unnamed: 4_level_7,Unnamed: 5_level_7,Unnamed: 6_level_7,Unnamed: 7_level_7,Unnamed: 8_level_7,Unnamed: 9_level_7,Unnamed: 10_level_7,Unnamed: 11_level_7,Unnamed: 12_level_7
Weight?,Feature,Unnamed: 2_level_8,Unnamed: 3_level_8,Unnamed: 4_level_8,Unnamed: 5_level_8,Unnamed: 6_level_8,Unnamed: 7_level_8,Unnamed: 8_level_8,Unnamed: 9_level_8,Unnamed: 10_level_8,Unnamed: 11_level_8,Unnamed: 12_level_8
Weight?,Feature,Unnamed: 2_level_9,Unnamed: 3_level_9,Unnamed: 4_level_9,Unnamed: 5_level_9,Unnamed: 6_level_9,Unnamed: 7_level_9,Unnamed: 8_level_9,Unnamed: 9_level_9,Unnamed: 10_level_9,Unnamed: 11_level_9,Unnamed: 12_level_9
Weight?,Feature,Unnamed: 2_level_10,Unnamed: 3_level_10,Unnamed: 4_level_10,Unnamed: 5_level_10,Unnamed: 6_level_10,Unnamed: 7_level_10,Unnamed: 8_level_10,Unnamed: 9_level_10,Unnamed: 10_level_10,Unnamed: 11_level_10,Unnamed: 12_level_10
Weight?,Feature,Unnamed: 2_level_11,Unnamed: 3_level_11,Unnamed: 4_level_11,Unnamed: 5_level_11,Unnamed: 6_level_11,Unnamed: 7_level_11,Unnamed: 8_level_11,Unnamed: 9_level_11,Unnamed: 10_level_11,Unnamed: 11_level_11,Unnamed: 12_level_11
Weight?,Feature,Unnamed: 2_level_12,Unnamed: 3_level_12,Unnamed: 4_level_12,Unnamed: 5_level_12,Unnamed: 6_level_12,Unnamed: 7_level_12,Unnamed: 8_level_12,Unnamed: 9_level_12,Unnamed: 10_level_12,Unnamed: 11_level_12,Unnamed: 12_level_12
+0.421,A8,,,,,,,,,,,
+0.360,A73,,,,,,,,,,,
+0.339,A6,,,,,,,,,,,
+0.296,A124,,,,,,,,,,,
+0.291,A93,,,,,,,,,,,
+0.287,A108,,,,,,,,,,,
+0.286,A105,,,,,,,,,,,
+0.285,Add_BottleSize,,,,,,,,,,,
+0.285,word.lower():750.0,,,,,,,,,,,
+0.264,word[-3:]:0.0,,,,,,,,,,,

Weight?,Feature
+0.421,A8
+0.360,A73
+0.339,A6
+0.296,A124
+0.291,A93
+0.287,A108
+0.286,A105
+0.285,Add_BottleSize
+0.285,word.lower():750.0
+0.264,word[-3:]:0.0

Weight?,Feature
+0.481,A66
+0.412,A104
+0.357,A37
+0.351,A40
+0.350,A5
+0.348,A58
+0.335,A28
+0.312,A46
+0.306,A14
+0.295,A13

Weight?,Feature
+0.444,A125
+0.412,A121
+0.376,A89
+0.368,A7
+0.346,Add_ClosureType
+0.283,A63
+0.264,A91
+0.252,A112
+0.247,A116
+0.234,A102

Weight?,Feature
+0.505,A67
+0.422,Add_GeoIndication
+0.379,A38
+0.344,A11
+0.337,A25
+0.332,A81
+0.308,A32
+0.292,A92
+0.272,A86
+0.261,A63

Weight?,Feature
+0.526,Add_GrapeVarieties
+0.459,A23
+0.416,A100
+0.393,A74
+0.351,A1
+0.347,A109
+0.329,A77
+0.303,A13
+0.286,A126
+0.284,A37

Weight?,Feature
+0.303,A119
+0.293,A11
+0.274,A88
+0.258,A0
+0.239,A10
+0.234,A26
+0.233,A123
+0.232,A92
+0.218,A71
+0.214,A86

Weight?,Feature
+0.666,Add_Price
+0.609,isNumber(word)
+0.459,A14
+0.449,A57
+0.390,A65
+0.386,A124
+0.379,A31
+0.364,A78
+0.363,A0
+0.338,A119

Weight?,Feature
+0.390,word[-3:]:Dry
+0.390,Add_Sweetness
+0.390,word.lower():dry
+0.357,A108
+0.352,A119
+0.322,A113
+0.279,A9
+0.270,A56
+0.267,A67
+0.266,A10

Weight?,Feature
+0.562,A62
+0.481,A50
+0.462,A71
+0.448,Add_TradeName
+0.443,A44
+0.423,A25
+0.413,A83
+0.400,A95
+0.397,A36
+0.382,A102

Weight?,Feature
+0.556,Add_Vintage
+0.502,A118
+0.471,isNumber(word)
+0.453,A94
+0.431,A85
+0.395,A100
+0.390,A49
+0.388,A9
+0.345,A119
+0.328,A56

Weight?,Feature
+0.386,Add_WineColor
+0.379,A69
+0.341,A73
+0.293,A7
+0.288,A100
+0.286,A89
+0.277,Add_Brand
+0.274,A9
+0.265,A93
+0.254,Add_TradeName

Weight?,Feature
+0.407,A121
+0.352,A55
+0.332,A88
+0.326,A61
+0.321,A93
+0.320,word.lower():still
+0.320,word[-3:]:ill
+0.320,Add_WineType
+0.308,A75
+0.298,A39

Weight?,Feature
+0.373,A11
+0.347,A1
+0.326,A124
+0.319,A98
+0.316,A4
+0.313,A113
+0.300,A120
+0.294,A77
+0.291,A95
+0.274,A114


In [7]:
%%time
# Marginals for the validation sentences; used for the confidence score and probability table.
marginals = model.predict_marginals(X_val)
# Predicted tag sequences for the validation sentences; compared against y_val when logging.
y_pred = model.predict(X_val)

Wall time: 10 ms


In [8]:
# Read the data-info JSON with an explicit encoding, like the other file reads in this
# notebook, so the result does not depend on the platform's default encoding.
with open(DATAINFO_PATH, encoding='utf-8') as file:
    data_info = json.load(file)

# Arguments: the encoded sentences, their tag sequences, the data-info ['keys']['all'] entry,
# and the vocabulary index of the unknown-word token.
unk_foreach_tag = count_unk_foreach_tag(X_tensors, labels, data_info['keys']['all'], word_to_ix[unk])

In [11]:
# Confidence values derived from the CRF marginals; their mean is logged as 'models_confidence'.
confs = compute_model_confidence(marginals)
# Table built from the marginals and the validation sentences; logged as 'prob_table'.
prob_table = DataGenerator.generate_probability_table(marginals, val_sents)

In [14]:
%%time
# The pairing below needs exactly one prediction per validation sentence.
assert len(y_pred) == len(val_sents), "y_pred and val_sents must have the same length"
# For every validation sentence: a list of (word, true tag, predicted tag) triples.
test_eval = [
    list(zip(sentence, tags, predicted_tags))
    for (sentence, tags), predicted_tags in zip(val_sents, y_pred)
]

Wall time: 999 µs


In [15]:
%%time
# Everything that describes this run, gathered into one mapping for the logging call.
run_params = dict(
    # Input artefacts and preprocessing switches.
    dataset_path=DATASET_PATH,
    lstm_model_path=LSTM_MODEL_PATH,
    vocab_path=VOCAB_PATH,
    case_sensitive_vocab=CASE_SENSITIVE_VOCAB,
    dictionary_path=DICTIONARY_PATH,
    datainfo_path=DATAINFO_PATH,
    device=DEVICE,
    # CRF configuration.
    model_name=MODEL_NAME,
    algorithm=ALGORITHM,
    c1=C1,
    c2=C2,
    max_iterations=MAX_ITERATIONS,
    all_possible_transitions=ALL_POSSIBLE_TRANSITIONS,
    test_size=TEST_SIZE,
    # Run bookkeeping.
    runname=RUN_NAME,
    start_time=START_TIME,
    output_dir=OUTPUT_DIR,
    # Results computed in the cells above.
    models_confidence=np.mean(confs),
    unk_foreach_tag=json.dumps(unk_foreach_tag),
    prob_table=prob_table,
)

Wall time: 0 ns


In [None]:
%%time
# Pass the run parameters, the fitted CRF, and the validation labels/predictions to
# log_mlflow_on_train (imported from mlflow_utils).
log_mlflow_on_train(
    run_params=run_params, model=model, y_true=y_val, y_pred=y_pred, test_eval=test_eval
)