In [1]:
import json
import os
import random
import sys

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# Make the project root importable so the local `nn` and `data_master` packages resolve.
# The location is taken from the WINE_RECOGNITION_ROOT environment variable when set,
# so the notebook runs on other machines; the default is the original checkout path.
PROJECT_ROOT = os.environ.get('WINE_RECOGNITION_ROOT', r'G:\PythonProjects\WineRecognition2')
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
from nn.utils import CustomDataset, train, plot_losses, generate_tag_to_ix
from nn.model import BiLSTM_CRF
from nn.mlflow_utils import log_mlflow_on_train
from data_master import DataGenerator

In [2]:
# Run configuration: every value a reader might want to tune lives here.
MODEL_NAME = 'BiLSTM_CRF'
RUN_NAME = ''
START_TIME = ''
OUTPUT_DIR = ''
DATASET_PATH = '../data/text/Halliday_Wine_AU-only_completed_rows-complex.txt'
VOCAB_PATH = '../data/vocabs/Words_Halliday_Wine_AU.json'
DATAINFO_PATH = '../data_info.json'
# Fall back to CPU so the notebook still runs on machines without a CUDA device.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 128
EMBEDDING_DIM = 256
HIDDEN_DIM = 64
NUM_EPOCHS = 3
LEARNING_RATE = 0.01
WEIGHT_DECAY = 0.0001
TEST_SIZE = 0.2

# Seed every RNG used downstream so Restart & Run All reproduces the same run:
# train_test_split without random_state draws from numpy's global RNG, while
# DataLoader shuffling and model weight initialisation draw from torch's RNG.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Parse the raw dataset file (one entry per line) into sentences.
# The file is decoded as UTF-8 explicitly, matching the vocabulary cell, instead of the
# platform's locale default (not UTF-8 on typical Windows setups), which would
# mis-decode or reject non-ASCII characters.
with open(DATASET_PATH, encoding='utf-8') as file:
    sents = DataGenerator.generate_sents2(file.read().split('\n'))
len(sents)

123761

In [4]:
# Hold out TEST_SIZE of the sentences for validation.
# No random_state is passed, so the split follows numpy's global RNG state.
train_data, val_data = train_test_split(sents, test_size=TEST_SIZE)
tuple(len(part) for part in (train_data, val_data))

(99008, 24753)

In [5]:
# Build the tag -> index mapping from every key listed under ['keys']['all'] in the
# data-info file, plus an extra 'Punctuation' tag.
# JSON is decoded as UTF-8 (the encoding RFC 8259 requires), consistent with the
# vocabulary cell, rather than the platform's locale default.
with open(DATAINFO_PATH, encoding='utf-8') as file:
    tag_to_ix = generate_tag_to_ix(json.load(file)['keys']['all'] + ['Punctuation'])
tag_to_ix

{'Add_TradeName': 0,
 'Add_Brand': 1,
 'Add_KeyWordTrue': 2,
 'Add_KeyWordFalse': 3,
 'Add_GrapeVarieties': 4,
 'Add_GeoIndication': 5,
 'Add_WineType': 6,
 'Add_BottleSize': 7,
 'Add_Sweetness': 8,
 'Add_WineColor': 9,
 'Add_ClosureType': 10,
 'Add_Certificate': 11,
 'Add_Vintage': 12,
 'Add_Price': 13,
 'Punctuation': 14}

In [6]:
# Word -> index vocabulary; its size is later passed to the model as vocab_size.
with open(VOCAB_PATH, encoding='utf-8') as vocab_file:
    word_to_ix = json.load(vocab_file)
len(word_to_ix)

12138

In [7]:
# Wrap both splits as datasets encoded with the same tag and word vocabularies.
train_dataset, val_dataset = (
    CustomDataset(split, tag_to_ix, word_to_ix) for split in (train_data, val_data)
)

In [8]:
# Training batches are reshuffled every epoch and the trailing partial batch is dropped.
# Validation keeps the original order and every sample, so decoded predictions line up
# index-for-index with val_data.
dataloaders = dict(
    train=DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True),
    val=DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False),
)

In [9]:
# Build the tagger on DEVICE; the vocabulary's 'PAD' index is passed as padding_idx.
vocab_size = len(word_to_ix)
model = BiLSTM_CRF(
    vocab_size,
    len(tag_to_ix),
    EMBEDDING_DIM,
    HIDDEN_DIM,
    padding_idx=word_to_ix['PAD'],
).to(DEVICE)
optimizer = Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

In [10]:
# Run NUM_EPOCHS of training; `train` hands back the model and the losses it recorded.
# NOTE(review): the logged train/val losses are negative and keep falling without bound
# (train ≈ -10.6k -> -31.5k -> -52.1k). If the loss is meant to be the CRF negative
# log-likelihood it must be >= 0, so this suggests a flipped sign in `train` or
# BiLSTM_CRF (i.e. the log-likelihood itself is being minimised). Verify before
# trusting the trained model or the logged metrics.
model, losses = train(
    model,
    optimizer,
    dataloaders,
    DEVICE,
    NUM_EPOCHS,
    OUTPUT_DIR,
    tqdm,
)

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch: 1, train_loss: -10611.389903408215, val_loss: -21248.693521997335
Epoch: 2, train_loss: -31493.361210710245, val_loss: -41983.60658809033
Epoch: 3, train_loss: -52092.255842962186, val_loss: -62628.21545873227


In [12]:
# Gold tag sequences: the second element of every validation sample.
y_val_true = [sample[1] for sample in val_data]

In [24]:
# Decode the validation split with the trained model and map tag indices back to names.
# The 'val' DataLoader does not shuffle, so predictions come back in val_data order.
# Invert tag_to_ix explicitly instead of assuming its values are exactly 0..n-1 in
# insertion order.
ix_to_tag = {ix: tag for tag, ix in tag_to_ix.items()}
val_pred_ix = []
model.eval()
with torch.no_grad():
    # Gold tags from the loader are not needed here. They get a named throwaway rather
    # than `_`, because `_` is IPython's last-output variable in a notebook.
    for x_batch, _tags_batch, mask_batch in dataloaders['val']:
        # NOTE(review): assumes the forward pass returns one sequence of tag indices per
        # sentence (CRF decoding) — confirm against nn.model.BiLSTM_CRF.
        val_pred_ix.extend(model(x_batch.to(DEVICE), mask_batch.to(DEVICE)))
# int() accepts both Python ints and 0-d tensors as index values.
y_val_pred = [[ix_to_tag[int(ix)] for ix in sentence] for sentence in val_pred_ix]

In [40]:
# Per-sentence lists of (word, true tag, predicted tag) triples, used for error analysis
# in the MLflow logging cell. Each inner zip stops at the shortest of the three sequences.
# Predictions must align one-to-one with val_data; fail loudly here rather than later.
assert len(y_val_pred) == len(val_data), 'number of predictions does not match val_data'
test_eval = [
    list(zip(sentence, true_tags, pred_tags))
    for (sentence, true_tags), pred_tags in zip(val_data, y_val_pred)
]

In [63]:
# Hyper-parameters and run metadata recorded alongside the MLflow run.
run_params = dict(
    model_name=MODEL_NAME,
    run_name=RUN_NAME,
    start_time=START_TIME,
    output_dir=OUTPUT_DIR,
    dataset_path=DATASET_PATH,
    vocab_path=VOCAB_PATH,
    datainfo_path=DATAINFO_PATH,
    device=DEVICE,
    batch_size=BATCH_SIZE,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    vocab_size=vocab_size,
    tags=', '.join(tag_to_ix),
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    test_size=TEST_SIZE,
)

In [None]:
# Push the run parameters, trained model, loss history and validation results to MLflow.
# Class names are passed in tag_to_ix insertion order.
log_mlflow_on_train(
    run_params=run_params,
    model=model,
    classes=list(tag_to_ix.keys()),
    losses=losses,
    y_true=y_val_true,
    y_pred=y_val_pred,
    test_eval=test_eval,
)