In [None]:
%%time
import sys

# Make the project's local packages (data_master, mlflow_utils, features) importable by
# putting the repository root first on sys.path. The membership check makes re-running
# this cell a no-op instead of stacking duplicate entries.
# NOTE(review): absolute machine-specific path — adjust to your checkout.
PROJECT_ROOT = r'G:\PythonProjects\WineRecognition2'
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import os
import json
import torch
import numpy as np
import pandas as pd
import sklearn_crfsuite
import eli5
from sklearn.model_selection import train_test_split

from data_master import DataGenerator, DataLoader
from mlflow_utils import log_mlflow_on_train
from features import features

In [None]:
%%time
# --- Inputs: labelled dataset, frequency dictionary, vocabulary -------------------------
DATASET_PATH = r"G:\PythonProjects\WineRecognition2\data\text\Halliday_Wine_AU-only_completed_rows-complex.txt"
DICTIONARY_PATH = r"G:\PythonProjects\WineRecognition2\data\dictionaries\Dict-byword_Halliday_Winesearcher_Wine_AU-only_completed_rows"
VOCAB_PATH = r"G:\PythonProjects\WineRecognition2\data\vocabs\Words_Halliday_Wine_AU.json"
# When False, vocabulary keys and tokens are lower-cased before lookup.
CASE_SENSITIVE_VOCAB = False

# --- Pretrained LSTM used as a feature extractor ------------------------------------------
LSTM_MODEL_PATH = r"G:\PythonProjects\WineRecognition2\artifacts\train\BiLSTM_CRF_10112021_030733\model\data\model.pth"
DEVICE = 'cuda'

# --- CRF hyper-parameters (forwarded to sklearn_crfsuite.CRF) -----------------------------
MODEL_NAME = "CRF_with_LSTM_features"
ALGORITHM = 'lbfgs'
C1 = 0.1
C2 = 0.1
MAX_ITERATIONS = 5
ALL_POSSIBLE_TRANSITIONS = True

# --- Evaluation split ----------------------------------------------------------------------
TEST_SIZE = 0.2  # fraction of sentences held out for validation

# --- Run bookkeeping (logged with the run) ---------------------------------------------------
RUN_NAME = 'Train-LSTMfeatures'
OUTPUT_DIR = r"G:\PythonProjects\WineRecognition2\artifacts\train\test_exp"
START_TIME = ''

In [None]:
%%time
# torch.load unpickles the whole saved module, so only point it at trusted checkpoints.
# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved on a GPU can
# still be loaded when DEVICE is 'cpu'.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects pickled
# nn.Module objects; if loading fails there, pass weights_only=False (trusted file only).
model = torch.load(LSTM_MODEL_PATH, map_location=DEVICE).to(DEVICE).eval()
freq_dict = DataLoader.load_frequency_dictionary(DICTIONARY_PATH)
with open(VOCAB_PATH, 'r', encoding='utf-8') as file:
    word_to_ix = json.load(file)
if not CASE_SENSITIVE_VOCAB:
    # Lower-case the vocabulary keys so they match the lower-cased tokens used for lookup.
    # If two keys differ only by case, the one iterated last keeps its index.
    word_to_ix = {word.lower(): index for word, index in word_to_ix.items()}

In [None]:
%%time
# Feature extraction helpers: reuse the pretrained LSTM's hidden outputs as CRF features.

def get_lstm_features(model, x):
    """Run ``x`` through ``model.embedding`` and then ``model.lstm``; return the LSTM output.

    Only the first element of the pair returned by ``model.lstm`` is kept; the second
    (the recurrent state for ``torch.nn.LSTM``) is discarded. Any layers that follow the
    LSTM in ``model`` are not applied.

    NOTE(review): callers pass a (1, seq_len) index tensor and squeeze dim 0 afterwards,
    which presumes the LSTM was built with batch_first=True — confirm against the model.
    """
    sequence_out, _ = model.lstm(model.embedding(x))
    return sequence_out

def features_with_keys(sentence):
    """Convert per-token feature vectors into dicts keyed ``'A0'``, ``'A1'``, ...

    ``sentence`` is iterated row by row (one row per token). Each row's values are
    enumerated into ``{f'A{position}': value}``, and one dict per row is returned in order.
    The loop variable deliberately avoids the name ``features``, which this notebook
    imports at the top for the hand-crafted feature module.
    """
    keyed_rows = []
    for token_vector in sentence:
        keyed_rows.append({f'A{position}': value for position, value in enumerate(token_vector)})
    return keyed_rows

# Out-of-vocabulary sentinel, in the same case as the (possibly lower-cased) vocab keys.
unk = 'UNK' if CASE_SENSITIVE_VOCAB else 'unk'

# Read the dataset with an explicit encoding (the same one used for the vocab file) and
# close the handle deterministically. The sentences are materialised as a list because
# they are iterated here and then split again by train_test_split in a later cell.
with open(DATASET_PATH, 'r', encoding='utf-8') as dataset_file:
    sents = list(DataGenerator.generate_sents2(dataset_file.read().split('\n')))
filename = os.path.splitext(os.path.split(DATASET_PATH)[-1])[0]
metadata = {'features': [], 'labels': []}

with torch.no_grad():  # inference only: no autograd graph for the LSTM forward passes
    for x_sent, y_sent in sents:
        # Hand-crafted features are computed on the original-case tokens.
        # NOTE(review): sent2features also receives the gold labels (y_sent); confirm it
        # does not derive features from them, or the CRF would be trained on the answer.
        our_features = features.sent2features(list(zip(x_sent, y_sent)), freq_dict)

        if not CASE_SENSITIVE_VOCAB:
            x_sent = [word.lower() for word in x_sent]

        x_tensor = torch.tensor(
            [word_to_ix[word] if word in word_to_ix else word_to_ix[unk] for word in x_sent],
            dtype=torch.int64,
        ).to(DEVICE)

        # Add a batch dimension of 1 for the LSTM, then drop it again before keying.
        lstm_output = get_lstm_features(model, x_tensor.unsqueeze(0))
        final_features = features_with_keys(lstm_output.squeeze(0).cpu().numpy())
        # Merge the hand-crafted features into each token's LSTM feature dict.
        for i in range(len(x_sent)):
            final_features[i].update(our_features[i])
        metadata['features'].append(final_features)
        metadata['labels'].append(y_sent)

labels = metadata['labels']
x = metadata['features']
# train_test_split in the next cell requires these three sequences to stay aligned.
assert len(x) == len(labels) == len(sents)
len(x), len(labels)

In [None]:
%%time
# Fixed seed so the train/validation partition, and therefore the metrics logged to
# MLflow, is the same on every Restart & Run All. Without it, each run evaluates on a
# different split. NOTE(review): consider also recording this seed in run_params.
SPLIT_RANDOM_STATE = 42
X_train, X_val, y_train, y_val, train_sents, val_sents = train_test_split(
    x, labels, sents, test_size=TEST_SIZE, random_state=SPLIT_RANDOM_STATE
)
len(X_train), len(X_val), len(y_train), len(y_val), len(train_sents), len(val_sents)

In [None]:
%%time
# Fit a linear-chain CRF on the merged LSTM + hand-crafted token features, then render the
# 30 strongest transition/state weights.
# NOTE(review): this rebinds `model` (previously the LSTM); later cells use it as the CRF.
crf_hyperparams = dict(
    algorithm=ALGORITHM,
    c1=C1,
    c2=C2,
    max_iterations=MAX_ITERATIONS,
    all_possible_transitions=ALL_POSSIBLE_TRANSITIONS,
)
model = sklearn_crfsuite.CRF(**crf_hyperparams)
model.fit(X_train, y_train)
eli5.show_weights(model, top=(30, 30))

In [None]:
%%time
# Predict tag sequences for the held-out sentences with the CRF fitted in the previous
# cell (`model` was rebound to the CRF there; it no longer refers to the LSTM).
y_pred = model.predict(X_val)

In [None]:
%%time
# For each validation sentence: a list of (token, gold tag, predicted tag) triples.
test_eval = [
    list(zip(tokens, gold_tags, y_pred[sent_no]))
    for sent_no, (tokens, gold_tags) in enumerate(val_sents)
]

In [None]:
%%time
# Run configuration handed to log_mlflow_on_train; keys and order match the original dict.
run_params = dict(
    dataset_path=DATASET_PATH,
    lstm_model_path=LSTM_MODEL_PATH,
    vocab_path=VOCAB_PATH,
    case_sensitive_vocab=CASE_SENSITIVE_VOCAB,
    dictionary_path=DICTIONARY_PATH,
    device=DEVICE,
    model_name=MODEL_NAME,
    algorithm=ALGORITHM,
    c1=C1,
    c2=C2,
    max_iterations=MAX_ITERATIONS,
    all_possible_transitions=ALL_POSSIBLE_TRANSITIONS,
    test_size=TEST_SIZE,
    runname=RUN_NAME,
    start_time=START_TIME,
    output_dir=OUTPUT_DIR,
)

In [None]:
%%time
# Pass the run configuration, the fitted CRF, the validation gold tags, the predictions and
# the per-token (token, gold, predicted) triples to the project's MLflow logging helper.
# NOTE(review): what gets logged (params/metrics/artifacts) is decided inside
# mlflow_utils.log_mlflow_on_train — confirm there.
log_mlflow_on_train(
    run_params=run_params,
    model=model,
    y_true=y_val,
    y_pred=y_pred,
    test_eval=test_eval
)