# Entity extraction using Fasttext and LSTM

## Import everything important

In [7]:
import joblib
import torch
import torch.nn as nn
import transformers
import nltk

import numpy as np
import pandas as pd

from sklearn import preprocessing
from sklearn import model_selection

from gensim.models import fasttext as ft

from torch.utils.tensorboard import SummaryWriter
import numpy as np
from nltk.tokenize import word_tokenize

from tqdm import tqdm

## Some config

In [8]:
MAX_LEN = 128
TRAIN_BATCH_SIZE = 128 #
VALID_BATCH_SIZE = 128 #
EPOCHS = 5
EMBED_DIM = 300

MODEL_PATH = "./state_dict.pt"
TRAINING_FILE = 'ner_dataset.csv' #

In [9]:
device = (
    "cuda:0"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
) #
device

'cuda:0'

In [11]:
fasttext_model = ft.load_facebook_vectors("../cc.en.300.bin")

## Dataset

In [12]:
class EntityDataset:
    def __init__(self, texts, tags):
        self.texts = texts
        self.tags = tags
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, item):
        text = self.texts[item]
        tags = np.array(self.tags[item])
    
        ids = [fasttext_model[s] for s in text if str(s) != 'nan']
        
# реализуем паддинг: если текст меньше MAX_LEN, 
# забиваем матрицу нулями так, чтобы матрицы были одинакового размера для каждого текста   
# возвращаем матрицу с нулями, список тегов с нулями, и маску -- разметку, какие элементы 
# являются словами, а какие -- пустые поля (нули). Маска нужна, чтобы правильно считать лосс
# то есть, не учитывать в нем "пустые" части с нулями
        
        if len(ids) < MAX_LEN:
            ids_pad = np.array(ids + [[0]*len(ids[0])]*(MAX_LEN - len(ids)))
            tags_pad = list(tags) + [0]*(MAX_LEN - len(tags))
            mask = [1]*len(ids) + [0]*(MAX_LEN - len(ids))
        else:
            ids_pad = np.array(ids[:MAX_LEN])
            tags_pad = list(tags)[:MAX_LEN]
            mask = [1]*MAX_LEN
            
        return (torch.tensor(ids_pad, dtype=torch.float32),
                torch.tensor(tags_pad, dtype=torch.long),
                torch.tensor(mask, dtype=torch.long))

## Training and evaluation functions

In [13]:
def loss_fn(output, target, mask, num_labels):
    lfn = nn.CrossEntropyLoss()
    active_loss = mask.view(-1) == 1
    active_logits = output.view(-1, num_labels)
    active_labels = torch.where(
        active_loss,
        target.view(-1),
        torch.tensor(lfn.ignore_index).type_as(target)
    )
    loss = lfn(active_logits, active_labels)
    return loss

In [14]:
def acc_stat(pred, target, mask):
    mask = mask.bool()
    pred = torch.masked_select(pred, mask)
    target = torch.masked_select(target, mask)
    correct = torch.tensor(torch.eq(pred, target).sum().item(),dtype=torch.float32) # сколько элементов угадано корректно
    total = torch.tensor(len(pred), dtype=torch.float32) # сколько элементов было всего, не считая "пустых" с нулями
    return correct, total

пример того как должно работать

In [15]:
acc_stat(torch.tensor([1,2,3,4,0,0,0,0]), torch.tensor([1,2,3,4,5,5,5,5]), torch.tensor([1,1,1,1,0,0,0,0]))

(tensor(4.), tensor(4.))

In [16]:
acc_stat(torch.tensor([1,2,3,4,0,0,0,0]), torch.tensor([1,2,3,4,5,5,5,5]), torch.tensor([0,0,0,0,1,1,1,1]))

(tensor(0.), tensor(4.))

## Loss function and model

In [17]:
class EntityModel(nn.Module):
    def __init__(self, output_size, embedding_dim, hidden_dim, n_layers, drop_prob=0.5, bidirectional=False):
        super().__init__()
        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers, dropout=drop_prob,
                            batch_first=True, bidirectional=bidirectional)
        self.dropout = nn.Dropout(drop_prob)
        self.fc = nn.Linear(hidden_dim * (2 if bidirectional else 1), output_size)
    
    def forward(
        self, 
        embeds,
        hidden
    ):
        lstm_out, hidden = self.lstm(embeds, hidden)
        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim * (2 if self.bidirectional else 1))  # Учитываем bidirectional в размерности
        
# здесь пропустим через дропаут и линейный слой
        dropout_out = self.dropout(lstm_out)
        out = self.fc(dropout_out)
        
        return out, hidden
    
    def init_hidden(self, batch_size):
        num_directions = 2 if self.lstm.bidirectional else 1
        print("num_directions", num_directions)
        h_zeros = torch.zeros(self.n_layers * num_directions,
                              batch_size, self.hidden_dim,
                              dtype=torch.float32, device=device)
        c_zeros = torch.zeros(self.n_layers * num_directions,
                              batch_size, self.hidden_dim,
                              dtype=torch.float32, device=device)

        return (h_zeros, c_zeros)

## Data processing

In [18]:
def process_data(data_path):
    df = pd.read_csv(data_path, encoding="latin-1")
    df.loc[:, "Sentence #"] = df["Sentence #"].fillna(method="ffill")

    enc_tag = preprocessing.LabelEncoder()

    df[df.columns[df.columns.get_loc("Tag")]] = enc_tag.fit_transform(df["Tag"])

    sentences = df.groupby("Sentence #")["Word"].apply(list).values
    tag = df.groupby("Sentence #")["Tag"].apply(list).values
    return sentences, tag, enc_tag

## Training

In [19]:
sentences, tag, enc_tag = process_data(TRAINING_FILE)

In [20]:
from sklearn.model_selection import train_test_split #

meta_data = {
    "enc_tag": enc_tag
}

joblib.dump(meta_data, "meta.bin")

num_tag = len(list(enc_tag.classes_))

(
    train_sentences,
    test_sentences,
    train_tag,
    test_tag
) = train_test_split(sentences, tag, test_size=0.2) # делим на трейн и тест с помощью train_test_split

In [21]:
train_dataset = EntityDataset(
    texts=train_sentences, tags=train_tag
)

train_data_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=TRAIN_BATCH_SIZE, num_workers=0, #
    shuffle=True, drop_last=True
)

valid_dataset = EntityDataset(
    texts=test_sentences, tags=test_tag
)

valid_data_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=VALID_BATCH_SIZE, num_workers=0, #
    shuffle=False, drop_last=True
)

In [22]:
def eval_model(model, valid_data_loader):
    h = model.init_hidden(VALID_BATCH_SIZE)
    losses = []
    
    correct_sum, total_sum = 0, 0
    
    for inputs, labels, mask in valid_data_loader:
        h = tuple([each.data for each in h])
        # отправим inputs, labels и mask на GPU
        if torch.cuda.is_available():
            inputs = inputs.cuda()
            labels = labels.cuda()
            mask = mask.cuda()
            
        model.zero_grad()
        output, h = model(inputs, h)
        loss = loss_fn(output, labels.flatten(), mask, num_tag)
        losses.append(loss.item())
        
        correct, total = acc_stat(torch.argmax(output, dim=-1).flatten(), labels.flatten(), mask.flatten())
        correct_sum += correct
        total_sum += total
    return losses, correct_sum / total_sum

In [23]:
import gc
hidden_dim = 512
n_layers = 2

model = EntityModel(num_tag, EMBED_DIM, hidden_dim, n_layers, drop_prob=0.5, bidirectional=True)
model.to(device)

lr=0.005
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

scaler = torch.cuda.amp.GradScaler()
torch.cuda.empty_cache()
gc.collect()

0

In [None]:
counter = 0
print_every = 10
clip = 5
valid_loss_min = np.Inf
writer = SummaryWriter('logs')


model.train()
for i in range(EPOCHS):
    h = model.init_hidden(TRAIN_BATCH_SIZE)
    
    correct_sum, total_sum = 0, 0
    
    for inputs, labels, mask in tqdm(train_data_loader): #
        counter += 1
        h = tuple([e.data for e in h])

        inputs = inputs.to(device)
        labels = labels.to(device)
        mask = mask.to(device)
        model.zero_grad()

        with torch.cuda.amp.autocast(enabled=True):
            output, h = model(inputs, h)

        loss = loss_fn(output, labels.flatten(), mask, num_tag) # вызываем функцию для подсчета лосса
        correct, total = acc_stat(torch.argmax(output, dim=-1).flatten(), labels.flatten(), mask.flatten())# вызываем функцию acc_stat
        correct_sum += correct
        total_sum += total

        nn.utils.clip_grad_norm_(model.parameters(), clip)
        # градиентный спуск
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if counter % print_every == 0:
            model.eval()
            val_losses, val_acc = eval_model(model, valid_data_loader)
            model.train()
            
            val_loss = np.mean(val_losses)
            writer.add_scalar('train/loss', loss.item(), counter)
            writer.add_scalar('val/loss', val_loss, counter)
            writer.add_scalar('train/acc', correct_sum / total_sum, counter)
            writer.add_scalar('val/acc', val_acc, counter)

            print("Epoch: {}/{}...".format(i+1, EPOCHS),
                  "Step: {}...".format(counter),
                  "Loss: {:.6f}...".format(loss.item()),
                  "Val Loss: {:.6f}".format(val_loss),
                  "Train Acc: {:.6f}".format(correct_sum / total_sum),
                  "Val Acc: {:.6f}".format(val_acc))
                
            if np.mean(val_losses) <= valid_loss_min:
                torch.save(model.state_dict(), MODEL_PATH)
                print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min,np.mean(val_losses)))
                valid_loss_min = np.mean(val_losses)

2023-11-26 10:09:10.619059: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-11-26 10:09:10.663076: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


num_directions 2


  3%|▎         | 9/299 [00:02<01:33,  3.10it/s]

num_directions 2


  3%|▎         | 10/299 [00:25<34:40,  7.20s/it]

Epoch: 1/5... Step: 10... Loss: 1.259766... Val Loss: 0.789167 Train Acc: 0.634344 Val Acc: 0.847605
Validation loss decreased (inf --> 0.789167).  Saving model ...


  6%|▋         | 19/299 [00:28<02:47,  1.67it/s]

num_directions 2


  7%|▋         | 20/299 [00:51<33:27,  7.19s/it]

Epoch: 1/5... Step: 20... Loss: 0.949219... Val Loss: 0.795295 Train Acc: 0.739861 Val Acc: 0.847605


 10%|▉         | 29/299 [00:53<02:42,  1.66it/s]

num_directions 2


 10%|█         | 30/299 [01:16<32:41,  7.29s/it]

Epoch: 1/5... Step: 30... Loss: 0.872070... Val Loss: 0.838794 Train Acc: 0.776051 Val Acc: 0.847605


 13%|█▎        | 39/299 [01:19<02:36,  1.66it/s]

num_directions 2


 13%|█▎        | 40/299 [01:42<31:21,  7.26s/it]

Epoch: 1/5... Step: 40... Loss: 0.873047... Val Loss: 0.797717 Train Acc: 0.792511 Val Acc: 0.847605


 16%|█▋        | 49/299 [01:45<02:30,  1.66it/s]

num_directions 2


 17%|█▋        | 50/299 [02:08<29:56,  7.22s/it]

Epoch: 1/5... Step: 50... Loss: 0.738281... Val Loss: 0.773464 Train Acc: 0.803903 Val Acc: 0.847605
Validation loss decreased (0.789167 --> 0.773464).  Saving model ...


 20%|█▉        | 59/299 [02:11<02:23,  1.67it/s]

num_directions 2


 20%|██        | 60/299 [02:34<29:10,  7.33s/it]

Epoch: 1/5... Step: 60... Loss: 0.809082... Val Loss: 0.759939 Train Acc: 0.811135 Val Acc: 0.847605
Validation loss decreased (0.773464 --> 0.759939).  Saving model ...


 23%|██▎       | 69/299 [02:37<02:23,  1.60it/s]

num_directions 2


 23%|██▎       | 70/299 [03:00<28:47,  7.54s/it]

Epoch: 1/5... Step: 70... Loss: 0.743164... Val Loss: 0.757215 Train Acc: 0.816309 Val Acc: 0.847605
Validation loss decreased (0.759939 --> 0.757215).  Saving model ...


 26%|██▋       | 79/299 [03:03<02:20,  1.57it/s]

num_directions 2


 27%|██▋       | 80/299 [03:27<27:38,  7.57s/it]

Epoch: 1/5... Step: 80... Loss: 0.754395... Val Loss: 0.758812 Train Acc: 0.820570 Val Acc: 0.847605


 30%|██▉       | 89/299 [03:30<02:11,  1.59it/s]

num_directions 2


 30%|███       | 90/299 [03:54<26:08,  7.51s/it]

Epoch: 1/5... Step: 90... Loss: 0.868164... Val Loss: 0.757948 Train Acc: 0.822433 Val Acc: 0.847605


 33%|███▎      | 99/299 [03:57<02:06,  1.58it/s]

num_directions 2


 33%|███▎      | 100/299 [04:21<25:02,  7.55s/it]

Epoch: 1/5... Step: 100... Loss: 0.746582... Val Loss: 0.755398 Train Acc: 0.824839 Val Acc: 0.847605
Validation loss decreased (0.757215 --> 0.755398).  Saving model ...


 36%|███▋      | 109/299 [04:24<02:00,  1.58it/s]

num_directions 2


 37%|███▋      | 110/299 [04:47<23:36,  7.50s/it]

Epoch: 1/5... Step: 110... Loss: 0.796875... Val Loss: 0.756041 Train Acc: 0.826509 Val Acc: 0.847605


 40%|███▉      | 119/299 [04:50<01:53,  1.59it/s]

num_directions 2


 40%|████      | 120/299 [05:14<22:23,  7.51s/it]

Epoch: 1/5... Step: 120... Loss: 0.741211... Val Loss: 0.756392 Train Acc: 0.828481 Val Acc: 0.847605


 43%|████▎     | 129/299 [05:17<01:47,  1.57it/s]

num_directions 2


 43%|████▎     | 130/299 [05:41<21:10,  7.52s/it]

Epoch: 1/5... Step: 130... Loss: 0.750488... Val Loss: 0.757277 Train Acc: 0.830462 Val Acc: 0.847605


 46%|████▋     | 139/299 [05:44<01:40,  1.60it/s]

num_directions 2


 47%|████▋     | 140/299 [06:07<19:59,  7.54s/it]

Epoch: 1/5... Step: 140... Loss: 0.843750... Val Loss: 0.757745 Train Acc: 0.831608 Val Acc: 0.847605


 50%|████▉     | 149/299 [06:10<01:35,  1.58it/s]

num_directions 2


 50%|█████     | 150/299 [06:33<18:10,  7.32s/it]

Epoch: 1/5... Step: 150... Loss: 0.741699... Val Loss: 0.759785 Train Acc: 0.832363 Val Acc: 0.847605


 53%|█████▎    | 159/299 [06:36<01:24,  1.65it/s]

num_directions 2


 54%|█████▎    | 160/299 [06:59<16:46,  7.24s/it]

Epoch: 1/5... Step: 160... Loss: 0.775879... Val Loss: 0.757621 Train Acc: 0.833183 Val Acc: 0.847605


 57%|█████▋    | 169/299 [07:02<01:18,  1.66it/s]

num_directions 2


 57%|█████▋    | 170/299 [07:25<15:35,  7.25s/it]

Epoch: 1/5... Step: 170... Loss: 0.751465... Val Loss: 0.759750 Train Acc: 0.834146 Val Acc: 0.847605


 60%|█████▉    | 179/299 [07:27<01:12,  1.66it/s]

num_directions 2


 60%|██████    | 180/299 [07:50<14:15,  7.19s/it]

Epoch: 1/5... Step: 180... Loss: 0.778809... Val Loss: 0.766641 Train Acc: 0.834827 Val Acc: 0.847605


 63%|██████▎   | 189/299 [07:53<01:05,  1.68it/s]

num_directions 2


 64%|██████▎   | 190/299 [08:16<13:07,  7.23s/it]

Epoch: 1/5... Step: 190... Loss: 0.705078... Val Loss: 0.757341 Train Acc: 0.835286 Val Acc: 0.847605


 67%|██████▋   | 199/299 [08:19<01:00,  1.64it/s]

num_directions 2


 67%|██████▋   | 200/299 [08:41<11:57,  7.25s/it]

Epoch: 1/5... Step: 200... Loss: 0.741211... Val Loss: 0.758323 Train Acc: 0.835955 Val Acc: 0.847605


 70%|██████▉   | 209/299 [08:44<00:54,  1.66it/s]

num_directions 2


 70%|███████   | 210/299 [09:07<10:41,  7.21s/it]

Epoch: 1/5... Step: 210... Loss: 0.763184... Val Loss: 0.759839 Train Acc: 0.836296 Val Acc: 0.847605


 73%|███████▎  | 219/299 [09:10<00:47,  1.67it/s]

num_directions 2


 74%|███████▎  | 220/299 [09:32<09:26,  7.17s/it]

Epoch: 1/5... Step: 220... Loss: 0.800293... Val Loss: 0.759400 Train Acc: 0.836722 Val Acc: 0.847605


 77%|███████▋  | 229/299 [09:35<00:42,  1.66it/s]

num_directions 2


 77%|███████▋  | 230/299 [09:58<08:16,  7.20s/it]

Epoch: 1/5... Step: 230... Loss: 0.808594... Val Loss: 0.758985 Train Acc: 0.836967 Val Acc: 0.847605


 80%|███████▉  | 239/299 [10:01<00:36,  1.66it/s]

num_directions 2


 80%|████████  | 240/299 [10:23<07:04,  7.19s/it]

Epoch: 1/5... Step: 240... Loss: 0.686035... Val Loss: 0.758724 Train Acc: 0.837463 Val Acc: 0.847605


 83%|████████▎ | 249/299 [10:26<00:30,  1.67it/s]

num_directions 2


 84%|████████▎ | 250/299 [10:49<05:51,  7.17s/it]

Epoch: 1/5... Step: 250... Loss: 0.732910... Val Loss: 0.758380 Train Acc: 0.837920 Val Acc: 0.847605


 87%|████████▋ | 259/299 [10:52<00:24,  1.67it/s]

num_directions 2


 87%|████████▋ | 260/299 [11:14<04:38,  7.14s/it]

Epoch: 1/5... Step: 260... Loss: 0.833496... Val Loss: 0.755470 Train Acc: 0.838148 Val Acc: 0.847605


 90%|████████▉ | 269/299 [11:17<00:17,  1.68it/s]

num_directions 2


 90%|█████████ | 270/299 [11:40<03:28,  7.19s/it]

Epoch: 1/5... Step: 270... Loss: 0.701660... Val Loss: 0.751327 Train Acc: 0.838631 Val Acc: 0.847605
Validation loss decreased (0.755398 --> 0.751327).  Saving model ...


 93%|█████████▎| 279/299 [11:42<00:11,  1.67it/s]

num_directions 2


 94%|█████████▎| 280/299 [12:05<02:18,  7.31s/it]

Epoch: 1/5... Step: 280... Loss: 0.782715... Val Loss: 0.743390 Train Acc: 0.838990 Val Acc: 0.847605
Validation loss decreased (0.751327 --> 0.743390).  Saving model ...


 97%|█████████▋| 289/299 [12:08<00:06,  1.63it/s]

num_directions 2


 97%|█████████▋| 290/299 [12:31<01:05,  7.31s/it]

Epoch: 1/5... Step: 290... Loss: 0.724609... Val Loss: 0.732039 Train Acc: 0.839312 Val Acc: 0.847605
Validation loss decreased (0.743390 --> 0.732039).  Saving model ...


100%|██████████| 299/299 [12:34<00:00,  2.52s/it]


num_directions 2


  0%|          | 0/299 [00:00<?, ?it/s]

num_directions 2


  0%|          | 1/299 [00:22<1:53:51, 22.93s/it]

Epoch: 2/5... Step: 300... Loss: 0.760254... Val Loss: 0.728026 Train Acc: 0.839681 Val Acc: 0.847605
Validation loss decreased (0.732039 --> 0.728026).  Saving model ...


  3%|▎         | 10/299 [00:25<02:54,  1.66it/s] 

num_directions 2


  4%|▎         | 11/299 [00:48<35:15,  7.34s/it]

Epoch: 2/5... Step: 310... Loss: 0.727539... Val Loss: 0.719634 Train Acc: 0.844051 Val Acc: 0.847605
Validation loss decreased (0.728026 --> 0.719634).  Saving model ...


  7%|▋         | 20/299 [00:51<02:46,  1.68it/s]

num_directions 2


  7%|▋         | 21/299 [01:14<33:34,  7.25s/it]

Epoch: 2/5... Step: 320... Loss: 0.787109... Val Loss: 0.699623 Train Acc: 0.845970 Val Acc: 0.847605
Validation loss decreased (0.719634 --> 0.699623).  Saving model ...


 10%|█         | 30/299 [01:17<02:42,  1.66it/s]

num_directions 2


 10%|█         | 31/299 [01:39<32:28,  7.27s/it]

Epoch: 2/5... Step: 330... Loss: 0.682129... Val Loss: 0.675351 Train Acc: 0.845197 Val Acc: 0.847605
Validation loss decreased (0.699623 --> 0.675351).  Saving model ...


 13%|█▎        | 40/299 [01:42<02:36,  1.65it/s]

num_directions 2


 14%|█▎        | 41/299 [02:05<31:07,  7.24s/it]

Epoch: 2/5... Step: 340... Loss: 0.644531... Val Loss: 0.650680 Train Acc: 0.845249 Val Acc: 0.847605
Validation loss decreased (0.675351 --> 0.650680).  Saving model ...


 17%|█▋        | 50/299 [02:08<02:29,  1.66it/s]

num_directions 2


 17%|█▋        | 51/299 [02:30<29:48,  7.21s/it]

Epoch: 2/5... Step: 350... Loss: 0.636719... Val Loss: 0.606958 Train Acc: 0.845520 Val Acc: 0.847605
Validation loss decreased (0.650680 --> 0.606958).  Saving model ...


 20%|██        | 60/299 [02:33<02:24,  1.66it/s]

num_directions 2


 20%|██        | 61/299 [02:56<28:36,  7.21s/it]

Epoch: 2/5... Step: 360... Loss: 0.632324... Val Loss: 0.568987 Train Acc: 0.845487 Val Acc: 0.847605
Validation loss decreased (0.606958 --> 0.568987).  Saving model ...


 23%|██▎       | 70/299 [02:59<02:16,  1.68it/s]

num_directions 2


 24%|██▎       | 71/299 [03:21<26:38,  7.01s/it]

Epoch: 2/5... Step: 370... Loss: 0.541992... Val Loss: 0.531456 Train Acc: 0.845523 Val Acc: 0.847605
Validation loss decreased (0.568987 --> 0.531456).  Saving model ...


 27%|██▋       | 80/299 [03:24<02:08,  1.71it/s]

num_directions 2


 27%|██▋       | 81/299 [03:46<25:29,  7.02s/it]

Epoch: 2/5... Step: 380... Loss: 0.508789... Val Loss: 0.506125 Train Acc: 0.845801 Val Acc: 0.847605
Validation loss decreased (0.531456 --> 0.506125).  Saving model ...


 30%|███       | 90/299 [03:49<02:01,  1.72it/s]

num_directions 2


 30%|███       | 91/299 [04:11<24:15,  7.00s/it]

Epoch: 2/5... Step: 390... Loss: 0.511719... Val Loss: 0.482517 Train Acc: 0.846238 Val Acc: 0.847605
Validation loss decreased (0.506125 --> 0.482517).  Saving model ...


 33%|███▎      | 100/299 [04:13<01:56,  1.71it/s]

num_directions 2


 34%|███▍      | 101/299 [04:35<23:08,  7.01s/it]

Epoch: 2/5... Step: 400... Loss: 0.521973... Val Loss: 0.470526 Train Acc: 0.846297 Val Acc: 0.847605
Validation loss decreased (0.482517 --> 0.470526).  Saving model ...


 37%|███▋      | 110/299 [04:38<01:50,  1.71it/s]

num_directions 2


 37%|███▋      | 111/299 [05:00<21:58,  7.01s/it]

Epoch: 2/5... Step: 410... Loss: 0.506348... Val Loss: 0.458237 Train Acc: 0.846727 Val Acc: 0.847605
Validation loss decreased (0.470526 --> 0.458237).  Saving model ...


 40%|████      | 120/299 [05:03<01:44,  1.72it/s]

num_directions 2


 40%|████      | 121/299 [05:25<20:46,  7.00s/it]

Epoch: 2/5... Step: 420... Loss: 0.513184... Val Loss: 0.455173 Train Acc: 0.846369 Val Acc: 0.847605
Validation loss decreased (0.458237 --> 0.455173).  Saving model ...


 43%|████▎     | 130/299 [05:28<01:38,  1.71it/s]

num_directions 2


 44%|████▍     | 131/299 [05:50<19:37,  7.01s/it]

Epoch: 2/5... Step: 430... Loss: 0.506348... Val Loss: 0.450697 Train Acc: 0.846532 Val Acc: 0.847605
Validation loss decreased (0.455173 --> 0.450697).  Saving model ...


 47%|████▋     | 140/299 [05:53<01:32,  1.71it/s]

num_directions 2


 47%|████▋     | 141/299 [06:16<19:43,  7.49s/it]

Epoch: 2/5... Step: 440... Loss: 0.447754... Val Loss: 0.446057 Train Acc: 0.846716 Val Acc: 0.847605
Validation loss decreased (0.450697 --> 0.446057).  Saving model ...


 50%|█████     | 150/299 [06:19<01:32,  1.62it/s]

num_directions 2


 51%|█████     | 151/299 [06:42<18:07,  7.35s/it]

Epoch: 2/5... Step: 450... Loss: 0.451660... Val Loss: 0.441455 Train Acc: 0.846719 Val Acc: 0.847605
Validation loss decreased (0.446057 --> 0.441455).  Saving model ...


 54%|█████▎    | 160/299 [06:45<01:25,  1.63it/s]

num_directions 2


 54%|█████▍    | 161/299 [07:08<16:51,  7.33s/it]

Epoch: 2/5... Step: 460... Loss: 0.435547... Val Loss: 0.436573 Train Acc: 0.847048 Val Acc: 0.847605
Validation loss decreased (0.441455 --> 0.436573).  Saving model ...


 57%|█████▋    | 170/299 [07:11<01:19,  1.62it/s]

num_directions 2


 57%|█████▋    | 171/299 [07:34<15:40,  7.35s/it]

Epoch: 2/5... Step: 470... Loss: 0.467285... Val Loss: 0.433872 Train Acc: 0.847450 Val Acc: 0.870511
Validation loss decreased (0.436573 --> 0.433872).  Saving model ...


 60%|██████    | 180/299 [07:37<01:13,  1.63it/s]

num_directions 2


 61%|██████    | 181/299 [08:00<14:16,  7.26s/it]

Epoch: 2/5... Step: 480... Loss: 0.468994... Val Loss: 0.436752 Train Acc: 0.847768 Val Acc: 0.870366


 64%|██████▎   | 190/299 [08:03<01:06,  1.64it/s]

num_directions 2


 64%|██████▍   | 191/299 [08:26<13:16,  7.37s/it]

Epoch: 2/5... Step: 490... Loss: 0.401123... Val Loss: 0.433114 Train Acc: 0.848240 Val Acc: 0.871846
Validation loss decreased (0.433872 --> 0.433114).  Saving model ...


 67%|██████▋   | 200/299 [08:29<01:01,  1.61it/s]

num_directions 2


 67%|██████▋   | 201/299 [08:52<11:57,  7.32s/it]

Epoch: 2/5... Step: 500... Loss: 0.453613... Val Loss: 0.426969 Train Acc: 0.848539 Val Acc: 0.872140
Validation loss decreased (0.433114 --> 0.426969).  Saving model ...


 70%|███████   | 210/299 [08:55<00:54,  1.63it/s]

num_directions 2


 71%|███████   | 211/299 [09:18<10:46,  7.35s/it]

Epoch: 2/5... Step: 510... Loss: 0.436523... Val Loss: 0.422104 Train Acc: 0.849255 Val Acc: 0.872304
Validation loss decreased (0.426969 --> 0.422104).  Saving model ...


 74%|███████▎  | 220/299 [09:21<00:48,  1.62it/s]

num_directions 2


 74%|███████▍  | 221/299 [09:44<09:25,  7.24s/it]

Epoch: 2/5... Step: 520... Loss: 0.479248... Val Loss: 0.417605 Train Acc: 0.849464 Val Acc: 0.872160
Validation loss decreased (0.422104 --> 0.417605).  Saving model ...


 77%|███████▋  | 230/299 [09:47<00:41,  1.67it/s]

num_directions 2


 77%|███████▋  | 231/299 [10:09<07:59,  7.05s/it]

Epoch: 2/5... Step: 530... Loss: 0.432373... Val Loss: 0.416509 Train Acc: 0.850353 Val Acc: 0.874141
Validation loss decreased (0.417605 --> 0.416509).  Saving model ...


 80%|████████  | 240/299 [10:12<00:34,  1.70it/s]

num_directions 2


 81%|████████  | 241/299 [10:34<06:49,  7.05s/it]

Epoch: 2/5... Step: 540... Loss: 0.424561... Val Loss: 0.414359 Train Acc: 0.851044 Val Acc: 0.871706
Validation loss decreased (0.416509 --> 0.414359).  Saving model ...


 84%|████████▎ | 250/299 [10:37<00:28,  1.70it/s]

num_directions 2


 84%|████████▍ | 251/299 [10:59<05:36,  7.02s/it]

Epoch: 2/5... Step: 550... Loss: 0.426270... Val Loss: 0.405923 Train Acc: 0.851716 Val Acc: 0.873606
Validation loss decreased (0.414359 --> 0.405923).  Saving model ...


 87%|████████▋ | 260/299 [11:02<00:22,  1.71it/s]

num_directions 2


 87%|████████▋ | 261/299 [11:24<04:26,  7.01s/it]

Epoch: 2/5... Step: 560... Loss: 0.439941... Val Loss: 0.399166 Train Acc: 0.852499 Val Acc: 0.873211
Validation loss decreased (0.405923 --> 0.399166).  Saving model ...


 90%|█████████ | 270/299 [11:27<00:16,  1.71it/s]

num_directions 2


 91%|█████████ | 271/299 [11:49<03:17,  7.07s/it]

Epoch: 2/5... Step: 570... Loss: 0.414795... Val Loss: 0.382269 Train Acc: 0.853126 Val Acc: 0.874966
Validation loss decreased (0.399166 --> 0.382269).  Saving model ...


 94%|█████████▎| 280/299 [11:52<00:11,  1.65it/s]

num_directions 2


 94%|█████████▍| 281/299 [12:15<02:11,  7.31s/it]

Epoch: 2/5... Step: 580... Loss: 0.382080... Val Loss: 0.383540 Train Acc: 0.854108 Val Acc: 0.876441


 97%|█████████▋| 290/299 [12:18<00:05,  1.63it/s]

num_directions 2


 97%|█████████▋| 291/299 [12:41<00:58,  7.36s/it]

Epoch: 2/5... Step: 590... Loss: 0.408691... Val Loss: 0.362906 Train Acc: 0.854759 Val Acc: 0.878702
Validation loss decreased (0.382269 --> 0.362906).  Saving model ...


100%|██████████| 299/299 [12:43<00:00,  2.56s/it]


num_directions 2


  0%|          | 1/299 [00:00<01:41,  2.95it/s]

num_directions 2


  1%|          | 2/299 [00:23<1:08:12, 13.78s/it]

Epoch: 3/5... Step: 600... Loss: 0.375244... Val Loss: 0.358139 Train Acc: 0.877543 Val Acc: 0.882656
Validation loss decreased (0.362906 --> 0.358139).  Saving model ...


  4%|▎         | 11/299 [00:26<02:56,  1.63it/s] 

num_directions 2


  4%|▍         | 12/299 [00:49<35:46,  7.48s/it]

Epoch: 3/5... Step: 610... Loss: 0.402588... Val Loss: 0.344400 Train Acc: 0.880618 Val Acc: 0.886325
Validation loss decreased (0.358139 --> 0.344400).  Saving model ...


  7%|▋         | 21/299 [00:52<02:51,  1.62it/s]

num_directions 2


  7%|▋         | 22/299 [01:15<33:51,  7.33s/it]

Epoch: 3/5... Step: 620... Loss: 0.371582... Val Loss: 0.325932 Train Acc: 0.883729 Val Acc: 0.903440
Validation loss decreased (0.344400 --> 0.325932).  Saving model ...


 10%|█         | 31/299 [01:18<02:43,  1.64it/s]

num_directions 2


 11%|█         | 32/299 [01:41<32:35,  7.32s/it]

Epoch: 3/5... Step: 630... Loss: 0.357666... Val Loss: 0.302375 Train Acc: 0.887102 Val Acc: 0.908300
Validation loss decreased (0.325932 --> 0.302375).  Saving model ...


 14%|█▎        | 41/299 [01:44<02:39,  1.62it/s]

num_directions 2


 14%|█▍        | 42/299 [02:07<31:12,  7.29s/it]

Epoch: 3/5... Step: 640... Loss: 0.309326... Val Loss: 0.287971 Train Acc: 0.890505 Val Acc: 0.918338
Validation loss decreased (0.302375 --> 0.287971).  Saving model ...


 17%|█▋        | 51/299 [02:10<02:31,  1.64it/s]

num_directions 2


 17%|█▋        | 52/299 [02:33<30:08,  7.32s/it]

Epoch: 3/5... Step: 650... Loss: 0.327148... Val Loss: 0.264489 Train Acc: 0.894126 Val Acc: 0.922098
Validation loss decreased (0.287971 --> 0.264489).  Saving model ...


 20%|██        | 61/299 [02:36<02:21,  1.69it/s]

num_directions 2


 21%|██        | 62/299 [02:58<27:41,  7.01s/it]

Epoch: 3/5... Step: 660... Loss: 0.239380... Val Loss: 0.251749 Train Acc: 0.897678 Val Acc: 0.930608
Validation loss decreased (0.264489 --> 0.251749).  Saving model ...


 24%|██▎       | 71/299 [03:01<02:13,  1.71it/s]

num_directions 2


 24%|██▍       | 72/299 [03:23<26:31,  7.01s/it]

Epoch: 3/5... Step: 670... Loss: 0.246704... Val Loss: 0.229101 Train Acc: 0.900756 Val Acc: 0.936576
Validation loss decreased (0.251749 --> 0.229101).  Saving model ...


 27%|██▋       | 81/299 [03:25<02:07,  1.70it/s]

num_directions 2


 27%|██▋       | 82/299 [03:48<25:33,  7.07s/it]

Epoch: 3/5... Step: 680... Loss: 0.231201... Val Loss: 0.216131 Train Acc: 0.903753 Val Acc: 0.939228
Validation loss decreased (0.229101 --> 0.216131).  Saving model ...


 30%|███       | 91/299 [03:51<02:02,  1.70it/s]

num_directions 2


 31%|███       | 92/299 [04:13<24:16,  7.03s/it]

Epoch: 3/5... Step: 690... Loss: 0.260498... Val Loss: 0.199283 Train Acc: 0.906693 Val Acc: 0.945476
Validation loss decreased (0.216131 --> 0.199283).  Saving model ...


 34%|███▍      | 101/299 [04:15<01:55,  1.71it/s]

num_directions 2


 34%|███▍      | 102/299 [04:38<23:05,  7.03s/it]

Epoch: 3/5... Step: 700... Loss: 0.223145... Val Loss: 0.187113 Train Acc: 0.910272 Val Acc: 0.950655
Validation loss decreased (0.199283 --> 0.187113).  Saving model ...


 37%|███▋      | 111/299 [04:40<01:50,  1.70it/s]

num_directions 2


 37%|███▋      | 112/299 [05:02<21:49,  7.00s/it]

Epoch: 3/5... Step: 710... Loss: 0.233521... Val Loss: 0.177401 Train Acc: 0.913184 Val Acc: 0.952535
Validation loss decreased (0.187113 --> 0.177401).  Saving model ...


 40%|████      | 121/299 [05:05<01:43,  1.71it/s]

num_directions 2


 41%|████      | 122/299 [05:27<20:45,  7.04s/it]

Epoch: 3/5... Step: 720... Loss: 0.185669... Val Loss: 0.168076 Train Acc: 0.915546 Val Acc: 0.953914
Validation loss decreased (0.177401 --> 0.168076).  Saving model ...


 44%|████▍     | 131/299 [05:30<01:37,  1.72it/s]

num_directions 2


 44%|████▍     | 132/299 [05:52<19:28,  7.00s/it]

Epoch: 3/5... Step: 730... Loss: 0.208008... Val Loss: 0.166336 Train Acc: 0.917849 Val Acc: 0.955027
Validation loss decreased (0.168076 --> 0.166336).  Saving model ...


 47%|████▋     | 141/299 [05:55<01:32,  1.71it/s]

num_directions 2


 47%|████▋     | 142/299 [06:17<18:17,  6.99s/it]

Epoch: 3/5... Step: 740... Loss: 0.180054... Val Loss: 0.165732 Train Acc: 0.920275 Val Acc: 0.955374
Validation loss decreased (0.166336 --> 0.165732).  Saving model ...


 51%|█████     | 151/299 [06:20<01:26,  1.71it/s]

num_directions 2


 51%|█████     | 152/299 [06:42<17:08,  6.99s/it]

Epoch: 3/5... Step: 750... Loss: 0.186646... Val Loss: 0.157205 Train Acc: 0.922201 Val Acc: 0.957120
Validation loss decreased (0.165732 --> 0.157205).  Saving model ...


 54%|█████▍    | 161/299 [06:44<01:20,  1.72it/s]

num_directions 2


 54%|█████▍    | 162/299 [07:06<15:58,  7.00s/it]

Epoch: 3/5... Step: 760... Loss: 0.164307... Val Loss: 0.156683 Train Acc: 0.924002 Val Acc: 0.955452
Validation loss decreased (0.157205 --> 0.156683).  Saving model ...


 57%|█████▋    | 171/299 [07:09<01:14,  1.71it/s]

num_directions 2


 58%|█████▊    | 172/299 [07:31<14:44,  6.96s/it]

Epoch: 3/5... Step: 770... Loss: 0.208862... Val Loss: 0.158105 Train Acc: 0.925632 Val Acc: 0.955100


 61%|██████    | 181/299 [07:34<01:08,  1.72it/s]

num_directions 2


 61%|██████    | 182/299 [07:56<13:37,  6.99s/it]

Epoch: 3/5... Step: 780... Loss: 0.150635... Val Loss: 0.150798 Train Acc: 0.927175 Val Acc: 0.957491
Validation loss decreased (0.156683 --> 0.150798).  Saving model ...


 64%|██████▍   | 191/299 [07:59<01:02,  1.72it/s]

num_directions 2


 64%|██████▍   | 192/299 [08:21<12:30,  7.01s/it]

Epoch: 3/5... Step: 790... Loss: 0.183350... Val Loss: 0.147285 Train Acc: 0.928616 Val Acc: 0.959145
Validation loss decreased (0.150798 --> 0.147285).  Saving model ...


 67%|██████▋   | 201/299 [08:23<00:57,  1.72it/s]

num_directions 2


 68%|██████▊   | 202/299 [08:45<11:17,  6.99s/it]

Epoch: 3/5... Step: 800... Loss: 0.193237... Val Loss: 0.142122 Train Acc: 0.929865 Val Acc: 0.960225
Validation loss decreased (0.147285 --> 0.142122).  Saving model ...


 71%|███████   | 211/299 [08:48<00:51,  1.71it/s]

num_directions 2


 71%|███████   | 212/299 [09:10<10:06,  6.97s/it]

Epoch: 3/5... Step: 810... Loss: 0.174316... Val Loss: 0.145896 Train Acc: 0.931091 Val Acc: 0.959135


 74%|███████▍  | 221/299 [09:13<00:45,  1.71it/s]

num_directions 2


 74%|███████▍  | 222/299 [09:35<08:58,  6.99s/it]

Epoch: 3/5... Step: 820... Loss: 0.185669... Val Loss: 0.144870 Train Acc: 0.932128 Val Acc: 0.960234


 77%|███████▋  | 231/299 [09:38<00:39,  1.71it/s]

num_directions 2


 78%|███████▊  | 232/299 [10:00<07:50,  7.03s/it]

Epoch: 3/5... Step: 830... Loss: 0.157227... Val Loss: 0.141650 Train Acc: 0.933224 Val Acc: 0.960114
Validation loss decreased (0.142122 --> 0.141650).  Saving model ...


 81%|████████  | 241/299 [10:03<00:33,  1.71it/s]

num_directions 2


 81%|████████  | 242/299 [10:25<06:39,  7.01s/it]

Epoch: 3/5... Step: 840... Loss: 0.148926... Val Loss: 0.137242 Train Acc: 0.934109 Val Acc: 0.961112
Validation loss decreased (0.141650 --> 0.137242).  Saving model ...


 84%|████████▍ | 251/299 [10:27<00:27,  1.71it/s]

num_directions 2


 84%|████████▍ | 252/299 [10:50<05:31,  7.04s/it]

Epoch: 3/5... Step: 850... Loss: 0.125732... Val Loss: 0.135709 Train Acc: 0.935087 Val Acc: 0.961690
Validation loss decreased (0.137242 --> 0.135709).  Saving model ...


 87%|████████▋ | 261/299 [10:52<00:22,  1.71it/s]

num_directions 2


 88%|████████▊ | 262/299 [11:15<04:23,  7.12s/it]

Epoch: 3/5... Step: 860... Loss: 0.125854... Val Loss: 0.136082 Train Acc: 0.935962 Val Acc: 0.961541


 91%|█████████ | 271/299 [11:18<00:16,  1.68it/s]

num_directions 2


 91%|█████████ | 272/299 [11:40<03:12,  7.11s/it]

Epoch: 3/5... Step: 870... Loss: 0.145020... Val Loss: 0.134252 Train Acc: 0.936716 Val Acc: 0.961165
Validation loss decreased (0.135709 --> 0.134252).  Saving model ...


 94%|█████████▍| 281/299 [11:43<00:10,  1.67it/s]

num_directions 2


 94%|█████████▍| 282/299 [12:05<01:59,  7.03s/it]

Epoch: 3/5... Step: 880... Loss: 0.173462... Val Loss: 0.135039 Train Acc: 0.937443 Val Acc: 0.960755


 97%|█████████▋| 291/299 [12:08<00:04,  1.70it/s]

num_directions 2


 98%|█████████▊| 292/299 [12:30<00:49,  7.06s/it]

Epoch: 3/5... Step: 890... Loss: 0.156128... Val Loss: 0.129218 Train Acc: 0.938137 Val Acc: 0.962486
Validation loss decreased (0.134252 --> 0.129218).  Saving model ...


100%|██████████| 299/299 [12:32<00:00,  2.52s/it]


num_directions 2


  1%|          | 2/299 [00:00<01:32,  3.20it/s]

num_directions 2


  1%|          | 3/299 [00:22<50:48, 10.30s/it]

Epoch: 4/5... Step: 900... Loss: 0.151978... Val Loss: 0.129628 Train Acc: 0.961185 Val Acc: 0.962751


  4%|▍         | 12/299 [00:25<02:48,  1.70it/s]

num_directions 2


  4%|▍         | 13/299 [00:47<34:04,  7.15s/it]

Epoch: 4/5... Step: 910... Loss: 0.142700... Val Loss: 0.128633 Train Acc: 0.959081 Val Acc: 0.962028
Validation loss decreased (0.129218 --> 0.128633).  Saving model ...


  7%|▋         | 22/299 [00:50<02:43,  1.69it/s]

num_directions 2


  8%|▊         | 23/299 [01:13<32:36,  7.09s/it]

Epoch: 4/5... Step: 920... Loss: 0.137451... Val Loss: 0.132808 Train Acc: 0.960508 Val Acc: 0.961256


 11%|█         | 32/299 [01:15<02:40,  1.66it/s]

num_directions 2


 11%|█         | 33/299 [01:38<31:30,  7.11s/it]

Epoch: 4/5... Step: 930... Loss: 0.149292... Val Loss: 0.128901 Train Acc: 0.961027 Val Acc: 0.962785


 14%|█▍        | 42/299 [01:41<02:31,  1.70it/s]

num_directions 2


 14%|█▍        | 43/299 [02:03<29:59,  7.03s/it]

Epoch: 4/5... Step: 940... Loss: 0.170532... Val Loss: 0.126183 Train Acc: 0.960386 Val Acc: 0.963170
Validation loss decreased (0.128633 --> 0.126183).  Saving model ...


 17%|█▋        | 52/299 [02:06<02:24,  1.71it/s]

num_directions 2


 18%|█▊        | 53/299 [02:28<28:49,  7.03s/it]

Epoch: 4/5... Step: 950... Loss: 0.141602... Val Loss: 0.125638 Train Acc: 0.960328 Val Acc: 0.963272
Validation loss decreased (0.126183 --> 0.125638).  Saving model ...


 21%|██        | 62/299 [02:30<02:18,  1.71it/s]

num_directions 2


 21%|██        | 63/299 [02:53<27:56,  7.10s/it]

Epoch: 4/5... Step: 960... Loss: 0.141724... Val Loss: 0.124661 Train Acc: 0.960084 Val Acc: 0.963291
Validation loss decreased (0.125638 --> 0.124661).  Saving model ...


 24%|██▍       | 72/299 [02:56<02:13,  1.70it/s]

num_directions 2


 24%|██▍       | 73/299 [03:18<26:26,  7.02s/it]

Epoch: 4/5... Step: 970... Loss: 0.125366... Val Loss: 0.124075 Train Acc: 0.960320 Val Acc: 0.963932
Validation loss decreased (0.124661 --> 0.124075).  Saving model ...


 27%|██▋       | 82/299 [03:20<02:06,  1.71it/s]

num_directions 2


 28%|██▊       | 83/299 [03:43<25:45,  7.15s/it]

Epoch: 4/5... Step: 980... Loss: 0.128052... Val Loss: 0.122032 Train Acc: 0.959894 Val Acc: 0.964120
Validation loss decreased (0.124075 --> 0.122032).  Saving model ...


 31%|███       | 92/299 [03:46<02:05,  1.66it/s]

num_directions 2


 31%|███       | 93/299 [04:08<24:05,  7.02s/it]

Epoch: 4/5... Step: 990... Loss: 0.129883... Val Loss: 0.122336 Train Acc: 0.959862 Val Acc: 0.964014


 34%|███▍      | 102/299 [04:11<01:54,  1.72it/s]

num_directions 2


 34%|███▍      | 103/299 [04:33<22:50,  6.99s/it]

Epoch: 4/5... Step: 1000... Loss: 0.129761... Val Loss: 0.125245 Train Acc: 0.960208 Val Acc: 0.963069


 37%|███▋      | 112/299 [04:36<01:50,  1.69it/s]

num_directions 2


 38%|███▊      | 113/299 [04:58<22:34,  7.28s/it]

Epoch: 4/5... Step: 1010... Loss: 0.124878... Val Loss: 0.120972 Train Acc: 0.960183 Val Acc: 0.964188
Validation loss decreased (0.122032 --> 0.120972).  Saving model ...


 41%|████      | 122/299 [05:01<01:47,  1.65it/s]

num_directions 2


 41%|████      | 123/299 [05:25<22:21,  7.62s/it]

Epoch: 4/5... Step: 1020... Loss: 0.136841... Val Loss: 0.119544 Train Acc: 0.960302 Val Acc: 0.964540
Validation loss decreased (0.120972 --> 0.119544).  Saving model ...


 44%|████▍     | 132/299 [05:29<01:46,  1.56it/s]

num_directions 2


 44%|████▍     | 133/299 [05:53<21:38,  7.82s/it]

Epoch: 4/5... Step: 1030... Loss: 0.142944... Val Loss: 0.119374 Train Acc: 0.960361 Val Acc: 0.964607
Validation loss decreased (0.119544 --> 0.119374).  Saving model ...


 47%|████▋     | 142/299 [05:56<01:43,  1.51it/s]

num_directions 2


 48%|████▊     | 143/299 [06:20<19:24,  7.46s/it]

Epoch: 4/5... Step: 1040... Loss: 0.138062... Val Loss: 0.118559 Train Acc: 0.960391 Val Acc: 0.964573
Validation loss decreased (0.119374 --> 0.118559).  Saving model ...


 51%|█████     | 152/299 [06:23<01:30,  1.62it/s]

num_directions 2


 51%|█████     | 153/299 [06:46<17:52,  7.34s/it]

Epoch: 4/5... Step: 1050... Loss: 0.086853... Val Loss: 0.119341 Train Acc: 0.960705 Val Acc: 0.964781


 54%|█████▍    | 162/299 [06:49<01:25,  1.61it/s]

num_directions 2


 55%|█████▍    | 163/299 [07:12<16:48,  7.41s/it]

Epoch: 4/5... Step: 1060... Loss: 0.122559... Val Loss: 0.117104 Train Acc: 0.960994 Val Acc: 0.965099
Validation loss decreased (0.118559 --> 0.117104).  Saving model ...


 58%|█████▊    | 172/299 [07:15<01:20,  1.57it/s]

num_directions 2


 58%|█████▊    | 173/299 [07:40<16:14,  7.74s/it]

Epoch: 4/5... Step: 1070... Loss: 0.159790... Val Loss: 0.116067 Train Acc: 0.960950 Val Acc: 0.964747
Validation loss decreased (0.117104 --> 0.116067).  Saving model ...


 61%|██████    | 182/299 [07:43<01:15,  1.54it/s]

num_directions 2


 61%|██████    | 183/299 [08:06<14:14,  7.36s/it]

Epoch: 4/5... Step: 1080... Loss: 0.104675... Val Loss: 0.117495 Train Acc: 0.961152 Val Acc: 0.964887


 64%|██████▍   | 192/299 [08:09<01:08,  1.57it/s]

num_directions 2


 65%|██████▍   | 193/299 [08:33<13:40,  7.74s/it]

Epoch: 4/5... Step: 1090... Loss: 0.145630... Val Loss: 0.117524 Train Acc: 0.961309 Val Acc: 0.964752


 68%|██████▊   | 202/299 [08:36<01:03,  1.53it/s]

num_directions 2


 68%|██████▊   | 203/299 [09:00<11:56,  7.46s/it]

Epoch: 4/5... Step: 1100... Loss: 0.172241... Val Loss: 0.114827 Train Acc: 0.961358 Val Acc: 0.965142
Validation loss decreased (0.116067 --> 0.114827).  Saving model ...


 71%|███████   | 212/299 [09:03<00:54,  1.59it/s]

num_directions 2


 71%|███████   | 213/299 [09:26<10:39,  7.44s/it]

Epoch: 4/5... Step: 1110... Loss: 0.115845... Val Loss: 0.113647 Train Acc: 0.961448 Val Acc: 0.965673
Validation loss decreased (0.114827 --> 0.113647).  Saving model ...


 74%|███████▍  | 222/299 [09:29<00:47,  1.62it/s]

num_directions 2


 75%|███████▍  | 223/299 [09:52<09:17,  7.34s/it]

Epoch: 4/5... Step: 1120... Loss: 0.117004... Val Loss: 0.115429 Train Acc: 0.961630 Val Acc: 0.965215


 78%|███████▊  | 232/299 [09:55<00:42,  1.58it/s]

num_directions 2


 78%|███████▊  | 233/299 [10:18<08:06,  7.37s/it]

Epoch: 4/5... Step: 1130... Loss: 0.107422... Val Loss: 0.115785 Train Acc: 0.961719 Val Acc: 0.965128


 81%|████████  | 242/299 [10:21<00:34,  1.63it/s]

num_directions 2


 81%|████████▏ | 243/299 [10:44<06:53,  7.38s/it]

Epoch: 4/5... Step: 1140... Loss: 0.122131... Val Loss: 0.117792 Train Acc: 0.961649 Val Acc: 0.965003


 84%|████████▍ | 252/299 [10:47<00:28,  1.64it/s]

num_directions 2


 85%|████████▍ | 253/299 [11:10<05:37,  7.34s/it]

Epoch: 4/5... Step: 1150... Loss: 0.158203... Val Loss: 0.114512 Train Acc: 0.961620 Val Acc: 0.965162


 88%|████████▊ | 262/299 [11:13<00:22,  1.64it/s]

num_directions 2


 88%|████████▊ | 263/299 [11:37<04:26,  7.41s/it]

Epoch: 4/5... Step: 1160... Loss: 0.163086... Val Loss: 0.111834 Train Acc: 0.961639 Val Acc: 0.965928
Validation loss decreased (0.113647 --> 0.111834).  Saving model ...


 91%|█████████ | 272/299 [11:40<00:16,  1.61it/s]

num_directions 2


 91%|█████████▏| 273/299 [12:03<03:10,  7.34s/it]

Epoch: 4/5... Step: 1170... Loss: 0.117981... Val Loss: 0.112088 Train Acc: 0.961650 Val Acc: 0.965750


 94%|█████████▍| 282/299 [12:06<00:10,  1.56it/s]

num_directions 2


 95%|█████████▍| 283/299 [12:30<02:02,  7.65s/it]

Epoch: 4/5... Step: 1180... Loss: 0.105896... Val Loss: 0.113030 Train Acc: 0.961763 Val Acc: 0.965735


 98%|█████████▊| 292/299 [12:33<00:04,  1.55it/s]

num_directions 2


 98%|█████████▊| 293/299 [12:57<00:45,  7.52s/it]

Epoch: 4/5... Step: 1190... Loss: 0.119019... Val Loss: 0.111398 Train Acc: 0.961890 Val Acc: 0.965894
Validation loss decreased (0.111834 --> 0.111398).  Saving model ...


100%|██████████| 299/299 [12:59<00:00,  2.61s/it]


num_directions 2


  1%|          | 3/299 [00:00<01:36,  3.08it/s]

num_directions 2


  1%|▏         | 4/299 [00:23<45:17,  9.21s/it]

Epoch: 5/5... Step: 1200... Loss: 0.116333... Val Loss: 0.112111 Train Acc: 0.966583 Val Acc: 0.965562


  4%|▍         | 13/299 [00:26<02:51,  1.67it/s]

num_directions 2


  5%|▍         | 14/299 [00:49<35:11,  7.41s/it]

Epoch: 5/5... Step: 1210... Loss: 0.148315... Val Loss: 0.112961 Train Acc: 0.965685 Val Acc: 0.965716


  8%|▊         | 23/299 [00:52<02:49,  1.63it/s]

num_directions 2


  8%|▊         | 24/299 [01:15<33:43,  7.36s/it]

Epoch: 5/5... Step: 1220... Loss: 0.109131... Val Loss: 0.109784 Train Acc: 0.966003 Val Acc: 0.966314
Validation loss decreased (0.111398 --> 0.109784).  Saving model ...


 11%|█         | 33/299 [01:19<02:44,  1.62it/s]

num_directions 2


 11%|█▏        | 34/299 [01:42<32:22,  7.33s/it]

Epoch: 5/5... Step: 1230... Loss: 0.151489... Val Loss: 0.113363 Train Acc: 0.965159 Val Acc: 0.965186


 14%|█▍        | 43/299 [01:45<02:40,  1.60it/s]

num_directions 2


 15%|█▍        | 44/299 [02:08<31:27,  7.40s/it]

Epoch: 5/5... Step: 1240... Loss: 0.120239... Val Loss: 0.114880 Train Acc: 0.965196 Val Acc: 0.965128


 18%|█▊        | 53/299 [02:11<02:32,  1.62it/s]

num_directions 2


 18%|█▊        | 54/299 [02:34<30:29,  7.47s/it]

Epoch: 5/5... Step: 1250... Loss: 0.095459... Val Loss: 0.110308 Train Acc: 0.964879 Val Acc: 0.966073


 21%|██        | 63/299 [02:37<02:27,  1.60it/s]

num_directions 2


 21%|██▏       | 64/299 [03:01<29:26,  7.52s/it]

Epoch: 5/5... Step: 1260... Loss: 0.089905... Val Loss: 0.111718 Train Acc: 0.964758 Val Acc: 0.965499


 24%|██▍       | 73/299 [03:04<02:21,  1.59it/s]

num_directions 2


 25%|██▍       | 74/299 [03:28<28:11,  7.52s/it]

Epoch: 5/5... Step: 1270... Loss: 0.133179... Val Loss: 0.109646 Train Acc: 0.964215 Val Acc: 0.965981
Validation loss decreased (0.109784 --> 0.109646).  Saving model ...


 28%|██▊       | 83/299 [03:31<02:12,  1.63it/s]

num_directions 2


 28%|██▊       | 84/299 [03:54<26:29,  7.40s/it]

Epoch: 5/5... Step: 1280... Loss: 0.109314... Val Loss: 0.109390 Train Acc: 0.964555 Val Acc: 0.966753
Validation loss decreased (0.109646 --> 0.109390).  Saving model ...


 31%|███       | 93/299 [03:57<02:06,  1.63it/s]

num_directions 2


 31%|███▏      | 94/299 [04:20<24:53,  7.29s/it]

Epoch: 5/5... Step: 1290... Loss: 0.102600... Val Loss: 0.107700 Train Acc: 0.964851 Val Acc: 0.967023
Validation loss decreased (0.109390 --> 0.107700).  Saving model ...


 34%|███▍      | 103/299 [04:23<02:04,  1.57it/s]

num_directions 2


 35%|███▍      | 104/299 [04:46<24:11,  7.44s/it]

Epoch: 5/5... Step: 1300... Loss: 0.121704... Val Loss: 0.106936 Train Acc: 0.964766 Val Acc: 0.967003
Validation loss decreased (0.107700 --> 0.106936).  Saving model ...


 38%|███▊      | 113/299 [04:49<01:58,  1.57it/s]

num_directions 2


 38%|███▊      | 114/299 [05:12<22:36,  7.33s/it]

Epoch: 5/5... Step: 1310... Loss: 0.096252... Val Loss: 0.110331 Train Acc: 0.965102 Val Acc: 0.966478


 41%|████      | 123/299 [05:15<01:52,  1.57it/s]

num_directions 2


 41%|████▏     | 124/299 [05:38<21:37,  7.42s/it]

Epoch: 5/5... Step: 1320... Loss: 0.159058... Val Loss: 0.108426 Train Acc: 0.965136 Val Acc: 0.966738


 44%|████▍     | 133/299 [05:41<01:43,  1.61it/s]

num_directions 2


 45%|████▍     | 134/299 [06:04<20:16,  7.37s/it]

Epoch: 5/5... Step: 1330... Loss: 0.150879... Val Loss: 0.107754 Train Acc: 0.965080 Val Acc: 0.967018


 48%|████▊     | 143/299 [06:08<01:38,  1.59it/s]

num_directions 2


 48%|████▊     | 144/299 [06:31<18:57,  7.34s/it]

Epoch: 5/5... Step: 1340... Loss: 0.107544... Val Loss: 0.107543 Train Acc: 0.965278 Val Acc: 0.966251


 51%|█████     | 153/299 [06:34<01:29,  1.63it/s]

num_directions 2


 52%|█████▏    | 154/299 [06:56<17:35,  7.28s/it]

Epoch: 5/5... Step: 1350... Loss: 0.101807... Val Loss: 0.108726 Train Acc: 0.965354 Val Acc: 0.966994


 55%|█████▍    | 163/299 [06:59<01:22,  1.65it/s]

num_directions 2


 55%|█████▍    | 164/299 [07:23<16:43,  7.43s/it]

Epoch: 5/5... Step: 1360... Loss: 0.136963... Val Loss: 0.105894 Train Acc: 0.965407 Val Acc: 0.967153
Validation loss decreased (0.106936 --> 0.105894).  Saving model ...


 58%|█████▊    | 173/299 [07:26<01:17,  1.63it/s]

num_directions 2


 58%|█████▊    | 174/299 [07:49<15:33,  7.47s/it]

Epoch: 5/5... Step: 1370... Loss: 0.110352... Val Loss: 0.105685 Train Acc: 0.965387 Val Acc: 0.967336
Validation loss decreased (0.105894 --> 0.105685).  Saving model ...


 61%|██████    | 183/299 [07:52<01:12,  1.61it/s]

num_directions 2


 62%|██████▏   | 184/299 [08:16<14:26,  7.53s/it]

Epoch: 5/5... Step: 1380... Loss: 0.102905... Val Loss: 0.104921 Train Acc: 0.965392 Val Acc: 0.967548
Validation loss decreased (0.105685 --> 0.104921).  Saving model ...


 65%|██████▍   | 193/299 [08:19<01:06,  1.60it/s]

num_directions 2


 65%|██████▍   | 194/299 [08:42<12:59,  7.43s/it]

Epoch: 5/5... Step: 1390... Loss: 0.093872... Val Loss: 0.106896 Train Acc: 0.965390 Val Acc: 0.967182


 68%|██████▊   | 203/299 [08:45<00:58,  1.63it/s]

num_directions 2


 68%|██████▊   | 204/299 [09:09<11:48,  7.45s/it]

Epoch: 5/5... Step: 1400... Loss: 0.118713... Val Loss: 0.105706 Train Acc: 0.965388 Val Acc: 0.967355


 71%|███████   | 213/299 [09:12<00:53,  1.61it/s]

num_directions 2


 72%|███████▏  | 214/299 [09:35<10:33,  7.46s/it]

Epoch: 5/5... Step: 1410... Loss: 0.099426... Val Loss: 0.110506 Train Acc: 0.965244 Val Acc: 0.966478


 75%|███████▍  | 223/299 [09:38<00:46,  1.62it/s]

num_directions 2


 75%|███████▍  | 224/299 [10:01<09:15,  7.41s/it]

Epoch: 5/5... Step: 1420... Loss: 0.116211... Val Loss: 0.106906 Train Acc: 0.965291 Val Acc: 0.967326


 78%|███████▊  | 233/299 [10:04<00:41,  1.59it/s]

num_directions 2


 78%|███████▊  | 234/299 [10:28<08:01,  7.41s/it]

Epoch: 5/5... Step: 1430... Loss: 0.119568... Val Loss: 0.105652 Train Acc: 0.965279 Val Acc: 0.967254


 81%|████████▏ | 243/299 [10:31<00:34,  1.61it/s]

num_directions 2


 82%|████████▏ | 244/299 [10:54<06:54,  7.53s/it]

Epoch: 5/5... Step: 1440... Loss: 0.090454... Val Loss: 0.103773 Train Acc: 0.965234 Val Acc: 0.967485
Validation loss decreased (0.104921 --> 0.103773).  Saving model ...


 85%|████████▍ | 253/299 [10:57<00:28,  1.60it/s]

num_directions 2


 85%|████████▍ | 254/299 [11:20<05:33,  7.41s/it]

Epoch: 5/5... Step: 1450... Loss: 0.101562... Val Loss: 0.104134 Train Acc: 0.965291 Val Acc: 0.967408


 88%|████████▊ | 263/299 [11:23<00:22,  1.62it/s]

num_directions 2


 88%|████████▊ | 264/299 [11:47<04:20,  7.45s/it]

Epoch: 5/5... Step: 1460... Loss: 0.146118... Val Loss: 0.103211 Train Acc: 0.965264 Val Acc: 0.967992
Validation loss decreased (0.103773 --> 0.103211).  Saving model ...


 91%|█████████▏| 273/299 [11:50<00:16,  1.62it/s]

num_directions 2


## Inference

In [None]:
meta_data = joblib.load("meta.bin")
enc_tag = meta_data["enc_tag"]

num_tag = len(list(enc_tag.classes_))

text = """
Natasha is traveling to New York
"""

device = torch.device("cuda")
model.to(device)

with torch.no_grad():
# так как это инференс, выключаем расчет градиентов:
    inputs = torch.tensor([fasttext_model.wv[s] for s in word_tokenize(text)], dtype=torch.float32)
    inputs = inputs.unsqueeze(0).to(device)
    h = model.init_hidden(1)
    tag, h = model(inputs, h)

    print(
        enc_tag.inverse_transform(
            tag.argmax(-1).cpu().numpy().reshape(-1)
        )
    )