# Entity extraction using Fasttext and LSTM

## Import everything important

In [2]:
import joblib
import torch
import torch.nn as nn
import transformers
import nltk

import numpy as np
import pandas as pd

from sklearn import preprocessing
from sklearn import model_selection

# from gensim.models import FastText as ft
from gensim.models import fasttext as ft
from torch.utils.tensorboard import SummaryWriter
from nltk.tokenize import word_tokenize

from tqdm import tqdm

## Some config

In [3]:
MAX_LEN = 128
TRAIN_BATCH_SIZE = 128 #
VALID_BATCH_SIZE = 128 #
EPOCHS = 5
EMBED_DIM = 300

MODEL_PATH = "./state_dict.pt"
TRAINING_FILE = 'ner_dataset.csv' #

In [4]:
device = (
    "cuda:0"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
) #
device

'cuda:0'

https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz

In [5]:
# fasttext_model = ft.load_fasttext_format("cc.en.300.bin")
fasttext_model = ft.load_facebook_vectors("cc.en.300.bin")

## Dataset

In [6]:
class EntityDataset:
    def __init__(self, texts, tags):
        self.texts = texts
        self.tags = tags
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, item):
        text = self.texts[item]
        tags = np.array(self.tags[item])
    
        # ids = [fasttext_model.wv[s] for s in text]
        ids = [fasttext_model[s] for s in text]
        
# реализуем паддинг: если текст меньше MAX_LEN, 
# забиваем матрицу нулями так, чтобы матрицы были одинакового размера для каждого текста   
# возвращаем матрицу с нулями, список тегов с нулями, и маску -- разметку, какие элементы 
# являются словами, а какие -- пустые поля (нули). Маска нужна, чтобы правильно считать лосс
# то есть, не учитывать в нем "пустые" части с нулями
        
        if len(ids) < MAX_LEN:
            ids_pad = np.array(ids + [[0]*len(ids[0])]*(MAX_LEN - len(ids)))
            tags_pad = list(tags) + [0]*(MAX_LEN - len(tags))
            mask = [1]*len(ids) + [0]*(MAX_LEN - len(ids))
        else:
            ids_pad = np.array(ids[:MAX_LEN])
            tags_pad = list(tags)[:MAX_LEN]
            mask = [1]*MAX_LEN
            
        return (torch.tensor(ids_pad, dtype=torch.float32),
                torch.tensor(tags_pad, dtype=torch.long),
                torch.tensor(mask, dtype=torch.long))

## Training and evaluation functions

In [7]:
def loss_fn(output, target, mask, num_labels):
    lfn = nn.CrossEntropyLoss()
    active_loss = mask.view(-1) == 1
    active_logits = output.view(-1, num_labels)
    active_labels = torch.where(
        active_loss,
        target.view(-1),
        torch.tensor(lfn.ignore_index).type_as(target)
    )
    loss = lfn(active_logits, active_labels)
    return loss

In [8]:
def acc_stat(pred, target, mask):
    mask = mask.bool()
    pred = torch.masked_select(pred, mask)
    target = torch.masked_select(target, mask)
    correct = torch.tensor(torch.eq(pred, target).sum().item(),dtype=torch.float32) # сколько элементов угадано корректно
    total = torch.tensor(len(pred), dtype=torch.float32) # сколько элементов было всего, не считая "пустых" с нулями
    return correct, total

пример того как должно работать

In [9]:
acc_stat(torch.tensor([1,2,3,4,0,0,0,0]), torch.tensor([1,2,3,4,5,5,5,5]), torch.tensor([1,1,1,1,0,0,0,0]))

(tensor(4.), tensor(4.))

In [10]:
acc_stat(torch.tensor([1,2,3,4,0,0,0,0]), torch.tensor([1,2,3,4,5,5,5,5]), torch.tensor([0,0,0,0,1,1,1,1]))

(tensor(0.), tensor(4.))

## Loss function and model

In [11]:
class EntityModel(nn.Module):
    def __init__(self, output_size, embedding_dim, hidden_dim, n_layers, drop_prob=0.5, bidirectional=False):
        super().__init__()
        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers, dropout=drop_prob,
                            batch_first=True, bidirectional=bidirectional)
        self.dropout = nn.Dropout(drop_prob)
        self.fc = nn.Linear(hidden_dim, output_size)
    
    def forward(
        self, 
        embeds,
        hidden
    ):
        lstm_out, hidden = self.lstm(embeds, hidden)
        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim)
        
# здесь пропустим через дропаут и линейный слой
        dropout_out = self.dropout(lstm_out)
        out = self.fc(dropout_out)
        
        return out, hidden
    
    def init_hidden(self, batch_size):
        num_directions = 2 if self.lstm.bidirectional else 1
        print("num_directions", num_directions)
        h_zeros = torch.zeros(self.n_layers * num_directions,
                              batch_size, self.hidden_dim,
                              dtype=torch.float32, device=device)
        c_zeros = torch.zeros(self.n_layers * num_directions,
                              batch_size, self.hidden_dim,
                              dtype=torch.float32, device=device)

        return (h_zeros, c_zeros)

## Data processing

In [12]:
def process_data(data_path):
    df = pd.read_csv(data_path, encoding="latin-1")
    df.loc[:, "Sentence #"] = df["Sentence #"].fillna(method="ffill")

    enc_tag = preprocessing.LabelEncoder()

    # df.loc[:, "Tag"] = enc_tag.fit_transform(df["Tag"])
    df[df.columns[df.columns.get_loc("Tag")]] = enc_tag.fit_transform(df["Tag"])

    sentences = df.groupby("Sentence #")["Word"].apply(list).values
    tag = df.groupby("Sentence #")["Tag"].apply(list).values
    return sentences, tag, enc_tag

## Training

In [13]:
sentences, tag, enc_tag = process_data(TRAINING_FILE)

In [14]:
from sklearn.model_selection import train_test_split #

meta_data = {
    "enc_tag": enc_tag
}

joblib.dump(meta_data, "meta.bin")

num_tag = len(list(enc_tag.classes_))

(
    train_sentences,
    test_sentences,
    train_tag,
    test_tag
) = train_test_split(sentences, tag, test_size=0.2) # делим на трейн и тест с помощью train_test_split

In [15]:
train_dataset = EntityDataset(
    texts=train_sentences, tags=train_tag
)

train_data_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=TRAIN_BATCH_SIZE, num_workers=0, #
    shuffle=True, drop_last=True
)

valid_dataset = EntityDataset(
    texts=test_sentences, tags=test_tag
)

valid_data_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=VALID_BATCH_SIZE, num_workers=0, #
    shuffle=False, drop_last=True
)

In [16]:
def eval_model(model, valid_data_loader):
    h = model.init_hidden(VALID_BATCH_SIZE)
    losses = []
    
    correct_sum, total_sum = 0, 0
    
    for inputs, labels, mask in valid_data_loader:
        h = tuple([each.data for each in h])
        # отправим inputs, labels и mask на GPU
        if torch.cuda.is_available():
            inputs = inputs.cuda()
            labels = labels.cuda()
            mask = mask.cuda()
            
        model.zero_grad()
        output, h = model(inputs, h)
        loss = loss_fn(output, labels.flatten(), mask, num_tag)
        losses.append(loss.item())
        
        correct, total = acc_stat(torch.argmax(output, dim=-1).flatten(), labels.flatten(), mask.flatten())
        correct_sum += correct
        total_sum += total
    return losses, correct_sum / total_sum

In [17]:
hidden_dim = 512
n_layers = 2

model = EntityModel(num_tag, EMBED_DIM, hidden_dim, n_layers, drop_prob=0.5, bidirectional=False)
model.to(device)

lr=0.005
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [18]:
counter = 0
print_every = 10
clip = 5
valid_loss_min = np.Inf
writer = SummaryWriter('logs')


model.train()
for i in range(EPOCHS):
    h = model.init_hidden(TRAIN_BATCH_SIZE)
    
    correct_sum, total_sum = 0, 0
    
    for inputs, labels, mask in tqdm(train_data_loader): #
        counter += 1
        h = tuple([e.data for e in h])

        inputs = inputs.to(device)
        labels = labels.to(device)
        mask = mask.to(device)
        model.zero_grad()
        output, h = model(inputs, h)
        loss = loss_fn(output, labels.flatten(), mask, num_tag) # вызываем функцию для подсчета лосса
        loss.backward() # и делаем обратное распространение ошибки
        correct, total = acc_stat(torch.argmax(output, dim=-1).flatten(), labels.flatten(), mask.flatten()) # вызываем функцию acc_stat
        correct_sum += correct
        total_sum += total

        nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step() # градиентный спуск
        
        if counter % print_every == 0:
            model.eval()
            val_losses, val_acc = eval_model(model, valid_data_loader)
            model.train()
            
            val_loss = np.mean(val_losses)
            writer.add_scalar('train/loss', loss.item(), counter)
            writer.add_scalar('val/loss', val_loss, counter)
            writer.add_scalar('train/acc', correct_sum / total_sum, counter)
            writer.add_scalar('val/acc', val_acc, counter)

            print("Epoch: {}/{}...".format(i+1, EPOCHS),
                  "Step: {}...".format(counter),
                  "Loss: {:.6f}...".format(loss.item()),
                  "Val Loss: {:.6f}".format(val_loss),
                  "Train Acc: {:.6f}".format(correct_sum / total_sum),
                  "Val Acc: {:.6f}".format(val_acc))
                
            if np.mean(val_losses) <= valid_loss_min:
                torch.save(model.state_dict(), MODEL_PATH)
                print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min,np.mean(val_losses)))
                valid_loss_min = np.mean(val_losses)

num_directions 1


  3%|▎         | 9/299 [00:05<02:47,  1.74it/s]

num_directions 1


  3%|▎         | 10/299 [00:34<45:45,  9.50s/it]

Epoch: 1/5... Step: 10... Loss: 0.884185... Val Loss: 0.766111 Train Acc: 0.766869 Val Acc: 0.846548
Validation loss decreased (inf --> 0.766111).  Saving model ...


  6%|▋         | 19/299 [00:39<04:13,  1.10it/s]

num_directions 1


  7%|▋         | 20/299 [01:08<43:00,  9.25s/it]

Epoch: 1/5... Step: 20... Loss: 0.824187... Val Loss: 0.803389 Train Acc: 0.805924 Val Acc: 0.846548


 10%|▉         | 29/299 [01:13<04:00,  1.12it/s]

num_directions 1


 10%|█         | 30/299 [01:42<41:58,  9.36s/it]

Epoch: 1/5... Step: 30... Loss: 0.744897... Val Loss: 0.768151 Train Acc: 0.819299 Val Acc: 0.846548


 13%|█▎        | 39/299 [01:47<04:00,  1.08it/s]

num_directions 1


 13%|█▎        | 40/299 [02:17<41:00,  9.50s/it]

Epoch: 1/5... Step: 40... Loss: 0.753772... Val Loss: 0.766618 Train Acc: 0.826577 Val Acc: 0.846548


 16%|█▋        | 49/299 [02:22<03:51,  1.08it/s]

num_directions 1


 17%|█▋        | 50/299 [02:51<38:35,  9.30s/it]

Epoch: 1/5... Step: 50... Loss: 0.694289... Val Loss: 0.767170 Train Acc: 0.830635 Val Acc: 0.846548


 20%|█▉        | 59/299 [02:56<03:38,  1.10it/s]

num_directions 1


 20%|██        | 60/299 [03:26<38:36,  9.69s/it]

Epoch: 1/5... Step: 60... Loss: 0.737378... Val Loss: 0.764608 Train Acc: 0.833759 Val Acc: 0.846548
Validation loss decreased (0.766111 --> 0.764608).  Saving model ...


 23%|██▎       | 69/299 [03:31<03:42,  1.04it/s]

num_directions 1


 23%|██▎       | 70/299 [04:01<36:46,  9.64s/it]

Epoch: 1/5... Step: 70... Loss: 0.722528... Val Loss: 0.762232 Train Acc: 0.836501 Val Acc: 0.846548
Validation loss decreased (0.764608 --> 0.762232).  Saving model ...


 26%|██▋       | 79/299 [04:06<03:26,  1.06it/s]

num_directions 1


 27%|██▋       | 80/299 [04:36<35:34,  9.75s/it]

Epoch: 1/5... Step: 80... Loss: 0.797924... Val Loss: 0.760259 Train Acc: 0.838112 Val Acc: 0.846548
Validation loss decreased (0.762232 --> 0.760259).  Saving model ...


 30%|██▉       | 89/299 [04:42<03:15,  1.07it/s]

num_directions 1


 30%|███       | 90/299 [05:12<33:38,  9.66s/it]

Epoch: 1/5... Step: 90... Loss: 0.782418... Val Loss: 0.759452 Train Acc: 0.838764 Val Acc: 0.846548
Validation loss decreased (0.760259 --> 0.759452).  Saving model ...


 33%|███▎      | 99/299 [05:17<03:06,  1.07it/s]

num_directions 1


 33%|███▎      | 100/299 [05:46<31:07,  9.39s/it]

Epoch: 1/5... Step: 100... Loss: 0.788584... Val Loss: 0.762351 Train Acc: 0.839498 Val Acc: 0.846548


 36%|███▋      | 109/299 [05:51<02:54,  1.09it/s]

num_directions 1


 37%|███▋      | 110/299 [06:21<30:18,  9.62s/it]

Epoch: 1/5... Step: 110... Loss: 0.767350... Val Loss: 0.765425 Train Acc: 0.839863 Val Acc: 0.846548


 40%|███▉      | 119/299 [06:26<02:48,  1.07it/s]

num_directions 1


 40%|████      | 120/299 [06:56<28:35,  9.58s/it]

Epoch: 1/5... Step: 120... Loss: 0.737316... Val Loss: 0.759310 Train Acc: 0.840509 Val Acc: 0.846548
Validation loss decreased (0.759452 --> 0.759310).  Saving model ...


 43%|████▎     | 129/299 [07:01<02:37,  1.08it/s]

num_directions 1


 43%|████▎     | 130/299 [07:31<27:25,  9.74s/it]

Epoch: 1/5... Step: 130... Loss: 0.722497... Val Loss: 0.762860 Train Acc: 0.840915 Val Acc: 0.846548


 46%|████▋     | 139/299 [07:36<02:32,  1.05it/s]

num_directions 1


 47%|████▋     | 140/299 [08:06<25:22,  9.57s/it]

Epoch: 1/5... Step: 140... Loss: 0.701707... Val Loss: 0.763215 Train Acc: 0.841853 Val Acc: 0.846548


 50%|████▉     | 149/299 [08:11<02:21,  1.06it/s]

num_directions 1


 50%|█████     | 150/299 [08:41<23:51,  9.61s/it]

Epoch: 1/5... Step: 150... Loss: 0.748727... Val Loss: 0.762323 Train Acc: 0.842075 Val Acc: 0.846548


 53%|█████▎    | 159/299 [08:47<02:18,  1.01it/s]

num_directions 1


 54%|█████▎    | 160/299 [09:17<22:25,  9.68s/it]

Epoch: 1/5... Step: 160... Loss: 0.794355... Val Loss: 0.760412 Train Acc: 0.842494 Val Acc: 0.846548


 57%|█████▋    | 169/299 [09:22<02:00,  1.08it/s]

num_directions 1


 57%|█████▋    | 170/299 [09:51<20:33,  9.56s/it]

Epoch: 1/5... Step: 170... Loss: 0.740157... Val Loss: 0.761041 Train Acc: 0.842798 Val Acc: 0.846548


 60%|█████▉    | 179/299 [09:57<01:52,  1.07it/s]

num_directions 1


 60%|██████    | 180/299 [10:26<18:53,  9.53s/it]

Epoch: 1/5... Step: 180... Loss: 0.743759... Val Loss: 0.760837 Train Acc: 0.842719 Val Acc: 0.846548


 63%|██████▎   | 189/299 [10:31<01:40,  1.09it/s]

num_directions 1


 64%|██████▎   | 190/299 [11:00<16:47,  9.24s/it]

Epoch: 1/5... Step: 190... Loss: 0.837027... Val Loss: 0.762841 Train Acc: 0.842603 Val Acc: 0.846548


 67%|██████▋   | 199/299 [11:05<01:30,  1.11it/s]

num_directions 1


 67%|██████▋   | 200/299 [11:33<15:17,  9.27s/it]

Epoch: 1/5... Step: 200... Loss: 0.798396... Val Loss: 0.766254 Train Acc: 0.842740 Val Acc: 0.846548


 70%|██████▉   | 209/299 [11:38<01:21,  1.11it/s]

num_directions 1


 70%|███████   | 210/299 [12:07<13:39,  9.21s/it]

Epoch: 1/5... Step: 210... Loss: 0.808348... Val Loss: 0.761936 Train Acc: 0.842917 Val Acc: 0.846548


 73%|███████▎  | 219/299 [12:12<01:12,  1.10it/s]

num_directions 1


 74%|███████▎  | 220/299 [12:41<12:23,  9.42s/it]

Epoch: 1/5... Step: 220... Loss: 0.749248... Val Loss: 0.761247 Train Acc: 0.843019 Val Acc: 0.846548


 77%|███████▋  | 229/299 [12:46<01:04,  1.09it/s]

num_directions 1


 77%|███████▋  | 230/299 [13:16<10:46,  9.37s/it]

Epoch: 1/5... Step: 230... Loss: 0.773529... Val Loss: 0.762270 Train Acc: 0.843191 Val Acc: 0.846548


 80%|███████▉  | 239/299 [13:21<00:54,  1.10it/s]

num_directions 1


 80%|████████  | 240/299 [13:50<09:23,  9.56s/it]

Epoch: 1/5... Step: 240... Loss: 0.843500... Val Loss: 0.758984 Train Acc: 0.843343 Val Acc: 0.846548
Validation loss decreased (0.759310 --> 0.758984).  Saving model ...


 83%|████████▎ | 249/299 [13:55<00:45,  1.10it/s]

num_directions 1


 84%|████████▎ | 250/299 [14:24<07:43,  9.46s/it]

Epoch: 1/5... Step: 250... Loss: 0.782399... Val Loss: 0.752465 Train Acc: 0.843286 Val Acc: 0.846548
Validation loss decreased (0.758984 --> 0.752465).  Saving model ...


 87%|████████▋ | 259/299 [14:30<00:36,  1.09it/s]

num_directions 1


 87%|████████▋ | 260/299 [14:59<06:05,  9.36s/it]

Epoch: 1/5... Step: 260... Loss: 0.666156... Val Loss: 0.741250 Train Acc: 0.843666 Val Acc: 0.846548
Validation loss decreased (0.752465 --> 0.741250).  Saving model ...


 90%|████████▉ | 269/299 [15:04<00:27,  1.10it/s]

num_directions 1


 90%|█████████ | 270/299 [15:34<04:41,  9.72s/it]

Epoch: 1/5... Step: 270... Loss: 0.658285... Val Loss: 0.704249 Train Acc: 0.843966 Val Acc: 0.846548
Validation loss decreased (0.741250 --> 0.704249).  Saving model ...


 93%|█████████▎| 279/299 [15:39<00:18,  1.09it/s]

num_directions 1


 94%|█████████▎| 280/299 [16:08<02:59,  9.46s/it]

Epoch: 1/5... Step: 280... Loss: 0.655716... Val Loss: 0.647761 Train Acc: 0.843858 Val Acc: 0.846548
Validation loss decreased (0.704249 --> 0.647761).  Saving model ...


 97%|█████████▋| 289/299 [16:13<00:09,  1.08it/s]

num_directions 1


 97%|█████████▋| 290/299 [16:43<01:25,  9.52s/it]

Epoch: 1/5... Step: 290... Loss: 0.639154... Val Loss: 0.573106 Train Acc: 0.843902 Val Acc: 0.846548
Validation loss decreased (0.647761 --> 0.573106).  Saving model ...


100%|██████████| 299/299 [16:48<00:00,  3.37s/it]


num_directions 1


  0%|          | 0/299 [00:00<?, ?it/s]

num_directions 1


  0%|          | 1/299 [00:29<2:28:08, 29.83s/it]

Epoch: 2/5... Step: 300... Loss: 0.625279... Val Loss: 0.537562 Train Acc: 0.841162 Val Acc: 0.846548
Validation loss decreased (0.573106 --> 0.537562).  Saving model ...


  3%|▎         | 10/299 [00:34<04:28,  1.08it/s] 

num_directions 1


  4%|▎         | 11/299 [01:03<45:18,  9.44s/it]

Epoch: 2/5... Step: 310... Loss: 0.520204... Val Loss: 0.483966 Train Acc: 0.843931 Val Acc: 0.846548
Validation loss decreased (0.537562 --> 0.483966).  Saving model ...


  7%|▋         | 20/299 [01:08<04:21,  1.07it/s]

num_directions 1


  7%|▋         | 21/299 [01:37<43:26,  9.38s/it]

Epoch: 2/5... Step: 320... Loss: 0.557893... Val Loss: 0.456755 Train Acc: 0.844353 Val Acc: 0.846548
Validation loss decreased (0.483966 --> 0.456755).  Saving model ...


 10%|█         | 30/299 [01:42<04:07,  1.09it/s]

num_directions 1


 10%|█         | 31/299 [02:12<42:03,  9.42s/it]

Epoch: 2/5... Step: 330... Loss: 0.475179... Val Loss: 0.443997 Train Acc: 0.846331 Val Acc: 0.867792
Validation loss decreased (0.456755 --> 0.443997).  Saving model ...


 13%|█▎        | 40/299 [02:18<04:20,  1.00s/it]

num_directions 1


 14%|█▎        | 41/299 [02:47<41:25,  9.63s/it]

Epoch: 2/5... Step: 340... Loss: 0.472202... Val Loss: 0.440705 Train Acc: 0.848795 Val Acc: 0.868936
Validation loss decreased (0.443997 --> 0.440705).  Saving model ...


 17%|█▋        | 50/299 [02:53<03:55,  1.06it/s]

num_directions 1


 17%|█▋        | 51/299 [03:22<39:40,  9.60s/it]

Epoch: 2/5... Step: 350... Loss: 0.412432... Val Loss: 0.439720 Train Acc: 0.851144 Val Acc: 0.868582
Validation loss decreased (0.440705 --> 0.439720).  Saving model ...


 20%|██        | 60/299 [03:28<03:46,  1.05it/s]

num_directions 1


 20%|██        | 61/299 [03:58<38:39,  9.75s/it]

Epoch: 2/5... Step: 360... Loss: 0.487246... Val Loss: 0.424506 Train Acc: 0.853228 Val Acc: 0.869900
Validation loss decreased (0.439720 --> 0.424506).  Saving model ...


 23%|██▎       | 70/299 [04:03<03:41,  1.03it/s]

num_directions 1


 24%|██▎       | 71/299 [04:34<37:14,  9.80s/it]

Epoch: 2/5... Step: 370... Loss: 0.419353... Val Loss: 0.416661 Train Acc: 0.855274 Val Acc: 0.870258
Validation loss decreased (0.424506 --> 0.416661).  Saving model ...


 27%|██▋       | 80/299 [04:39<03:32,  1.03it/s]

num_directions 1


 27%|██▋       | 81/299 [05:10<36:20, 10.00s/it]

Epoch: 2/5... Step: 380... Loss: 0.414250... Val Loss: 0.405706 Train Acc: 0.857515 Val Acc: 0.870156
Validation loss decreased (0.416661 --> 0.405706).  Saving model ...


 30%|███       | 90/299 [05:15<03:16,  1.06it/s]

num_directions 1


 30%|███       | 91/299 [05:45<33:39,  9.71s/it]

Epoch: 2/5... Step: 390... Loss: 0.430643... Val Loss: 0.396411 Train Acc: 0.859347 Val Acc: 0.885461
Validation loss decreased (0.405706 --> 0.396411).  Saving model ...


 33%|███▎      | 100/299 [05:51<03:07,  1.06it/s]

num_directions 1


 34%|███▍      | 101/299 [06:20<31:32,  9.56s/it]

Epoch: 2/5... Step: 400... Loss: 0.425927... Val Loss: 0.383692 Train Acc: 0.861483 Val Acc: 0.885994
Validation loss decreased (0.396411 --> 0.383692).  Saving model ...


 37%|███▋      | 110/299 [06:25<02:54,  1.09it/s]

num_directions 1


 37%|███▋      | 111/299 [06:54<28:59,  9.25s/it]

Epoch: 2/5... Step: 410... Loss: 0.419762... Val Loss: 0.374988 Train Acc: 0.863261 Val Acc: 0.886207
Validation loss decreased (0.383692 --> 0.374988).  Saving model ...


 40%|████      | 120/299 [06:59<02:39,  1.12it/s]

num_directions 1


 40%|████      | 121/299 [07:28<28:15,  9.53s/it]

Epoch: 2/5... Step: 420... Loss: 0.395762... Val Loss: 0.368768 Train Acc: 0.864768 Val Acc: 0.887045
Validation loss decreased (0.374988 --> 0.368768).  Saving model ...


 43%|████▎     | 130/299 [07:34<02:39,  1.06it/s]

num_directions 1


 44%|████▍     | 131/299 [08:04<27:23,  9.79s/it]

Epoch: 2/5... Step: 430... Loss: 0.319314... Val Loss: 0.362796 Train Acc: 0.866551 Val Acc: 0.889405
Validation loss decreased (0.368768 --> 0.362796).  Saving model ...


 47%|████▋     | 140/299 [08:09<02:29,  1.06it/s]

num_directions 1


 47%|████▋     | 141/299 [08:41<26:28, 10.06s/it]

Epoch: 2/5... Step: 440... Loss: 0.413472... Val Loss: 0.356740 Train Acc: 0.867887 Val Acc: 0.889259
Validation loss decreased (0.362796 --> 0.356740).  Saving model ...


 50%|█████     | 150/299 [08:46<02:21,  1.06it/s]

num_directions 1


 51%|█████     | 151/299 [09:15<23:09,  9.39s/it]

Epoch: 2/5... Step: 450... Loss: 0.371833... Val Loss: 0.348167 Train Acc: 0.869138 Val Acc: 0.890814
Validation loss decreased (0.356740 --> 0.348167).  Saving model ...


 54%|█████▎    | 160/299 [09:20<02:08,  1.08it/s]

num_directions 1


 54%|█████▍    | 161/299 [09:49<21:38,  9.41s/it]

Epoch: 2/5... Step: 460... Loss: 0.311151... Val Loss: 0.347537 Train Acc: 0.870372 Val Acc: 0.892302
Validation loss decreased (0.348167 --> 0.347537).  Saving model ...


 57%|█████▋    | 170/299 [09:54<01:59,  1.08it/s]

num_directions 1


 57%|█████▋    | 171/299 [10:23<19:57,  9.35s/it]

Epoch: 2/5... Step: 470... Loss: 0.348414... Val Loss: 0.332836 Train Acc: 0.871625 Val Acc: 0.895281
Validation loss decreased (0.347537 --> 0.332836).  Saving model ...


 60%|██████    | 180/299 [10:28<01:48,  1.09it/s]

num_directions 1


 61%|██████    | 181/299 [11:00<19:42, 10.03s/it]

Epoch: 2/5... Step: 480... Loss: 0.346009... Val Loss: 0.319724 Train Acc: 0.872669 Val Acc: 0.900790
Validation loss decreased (0.332836 --> 0.319724).  Saving model ...


 64%|██████▎   | 190/299 [11:05<01:42,  1.07it/s]

num_directions 1


 64%|██████▍   | 191/299 [11:34<16:53,  9.38s/it]

Epoch: 2/5... Step: 490... Loss: 0.333327... Val Loss: 0.315229 Train Acc: 0.873753 Val Acc: 0.902214
Validation loss decreased (0.319724 --> 0.315229).  Saving model ...


 67%|██████▋   | 200/299 [11:39<01:31,  1.08it/s]

num_directions 1


 67%|██████▋   | 201/299 [12:08<15:19,  9.38s/it]

Epoch: 2/5... Step: 500... Loss: 0.353344... Val Loss: 0.299814 Train Acc: 0.875178 Val Acc: 0.910891
Validation loss decreased (0.315229 --> 0.299814).  Saving model ...


 70%|███████   | 210/299 [12:13<01:21,  1.09it/s]

num_directions 1


 71%|███████   | 211/299 [12:44<14:35,  9.95s/it]

Epoch: 2/5... Step: 510... Loss: 0.324409... Val Loss: 0.291334 Train Acc: 0.876705 Val Acc: 0.915401
Validation loss decreased (0.299814 --> 0.291334).  Saving model ...


 74%|███████▎  | 220/299 [12:49<01:16,  1.03it/s]

num_directions 1


 74%|███████▍  | 221/299 [13:20<12:44,  9.80s/it]

Epoch: 2/5... Step: 520... Loss: 0.278729... Val Loss: 0.277841 Train Acc: 0.878144 Val Acc: 0.919219
Validation loss decreased (0.291334 --> 0.277841).  Saving model ...


 77%|███████▋  | 230/299 [13:25<01:05,  1.05it/s]

num_directions 1


 77%|███████▋  | 231/299 [13:56<11:16,  9.95s/it]

Epoch: 2/5... Step: 530... Loss: 0.325833... Val Loss: 0.267562 Train Acc: 0.879601 Val Acc: 0.921283
Validation loss decreased (0.277841 --> 0.267562).  Saving model ...


 80%|████████  | 240/299 [14:01<00:56,  1.04it/s]

num_directions 1


 81%|████████  | 241/299 [14:32<09:28,  9.80s/it]

Epoch: 2/5... Step: 540... Loss: 0.258254... Val Loss: 0.259219 Train Acc: 0.881383 Val Acc: 0.922606
Validation loss decreased (0.267562 --> 0.259219).  Saving model ...


 84%|████████▎ | 250/299 [14:37<00:46,  1.07it/s]

num_directions 1


 84%|████████▍ | 251/299 [15:07<07:46,  9.72s/it]

Epoch: 2/5... Step: 550... Loss: 0.269749... Val Loss: 0.249279 Train Acc: 0.883030 Val Acc: 0.925469
Validation loss decreased (0.259219 --> 0.249279).  Saving model ...


 87%|████████▋ | 260/299 [15:12<00:37,  1.05it/s]

num_directions 1


 87%|████████▋ | 261/299 [15:43<06:12,  9.81s/it]

Epoch: 2/5... Step: 560... Loss: 0.265384... Val Loss: 0.241433 Train Acc: 0.884545 Val Acc: 0.927809
Validation loss decreased (0.249279 --> 0.241433).  Saving model ...


 90%|█████████ | 270/299 [15:48<00:27,  1.06it/s]

num_directions 1


 91%|█████████ | 271/299 [16:18<04:32,  9.73s/it]

Epoch: 2/5... Step: 570... Loss: 0.262672... Val Loss: 0.237375 Train Acc: 0.886028 Val Acc: 0.933414
Validation loss decreased (0.241433 --> 0.237375).  Saving model ...


 94%|█████████▎| 280/299 [16:23<00:17,  1.07it/s]

num_directions 1


 94%|█████████▍| 281/299 [16:54<02:55,  9.77s/it]

Epoch: 2/5... Step: 580... Loss: 0.255261... Val Loss: 0.224036 Train Acc: 0.887421 Val Acc: 0.934771
Validation loss decreased (0.237375 --> 0.224036).  Saving model ...


 97%|█████████▋| 290/299 [16:59<00:08,  1.04it/s]

num_directions 1


 97%|█████████▋| 291/299 [17:31<01:22, 10.28s/it]

Epoch: 2/5... Step: 590... Loss: 0.254666... Val Loss: 0.217352 Train Acc: 0.888785 Val Acc: 0.943520
Validation loss decreased (0.224036 --> 0.217352).  Saving model ...


100%|██████████| 299/299 [17:36<00:00,  3.53s/it]


num_directions 1


  0%|          | 1/299 [00:00<03:05,  1.60it/s]

num_directions 1


  1%|          | 2/299 [00:33<1:37:16, 19.65s/it]

Epoch: 3/5... Step: 600... Loss: 0.278075... Val Loss: 0.208542 Train Acc: 0.924823 Val Acc: 0.943709
Validation loss decreased (0.217352 --> 0.208542).  Saving model ...


  4%|▎         | 11/299 [00:39<04:49,  1.00s/it] 

num_directions 1


  4%|▍         | 12/299 [01:11<49:48, 10.41s/it]

Epoch: 3/5... Step: 610... Loss: 0.214375... Val Loss: 0.203299 Train Acc: 0.933811 Val Acc: 0.944034
Validation loss decreased (0.208542 --> 0.203299).  Saving model ...


  7%|▋         | 21/299 [01:16<04:36,  1.01it/s]

num_directions 1


  7%|▋         | 22/299 [01:48<46:58, 10.18s/it]

Epoch: 3/5... Step: 620... Loss: 0.218709... Val Loss: 0.201801 Train Acc: 0.936205 Val Acc: 0.945017
Validation loss decreased (0.203299 --> 0.201801).  Saving model ...


 10%|█         | 31/299 [01:53<04:30,  1.01s/it]

num_directions 1


 11%|█         | 32/299 [02:28<49:21, 11.09s/it]

Epoch: 3/5... Step: 630... Loss: 0.203943... Val Loss: 0.196055 Train Acc: 0.937944 Val Acc: 0.947522
Validation loss decreased (0.201801 --> 0.196055).  Saving model ...


 14%|█▎        | 41/299 [02:33<04:25,  1.03s/it]

num_directions 1


 14%|█▍        | 42/299 [03:04<42:21,  9.89s/it]

Epoch: 3/5... Step: 640... Loss: 0.200908... Val Loss: 0.190894 Train Acc: 0.938596 Val Acc: 0.948123
Validation loss decreased (0.196055 --> 0.190894).  Saving model ...


 17%|█▋        | 51/299 [03:09<03:59,  1.04it/s]

num_directions 1


 17%|█▋        | 52/299 [03:39<39:56,  9.70s/it]

Epoch: 3/5... Step: 650... Loss: 0.182049... Val Loss: 0.191727 Train Acc: 0.939098 Val Acc: 0.949276


 20%|██        | 61/299 [03:44<03:42,  1.07it/s]

num_directions 1


 21%|██        | 62/299 [04:15<38:34,  9.76s/it]

Epoch: 3/5... Step: 660... Loss: 0.192146... Val Loss: 0.185733 Train Acc: 0.939502 Val Acc: 0.949373
Validation loss decreased (0.190894 --> 0.185733).  Saving model ...


 24%|██▎       | 71/299 [04:20<03:34,  1.06it/s]

num_directions 1


 24%|██▍       | 72/299 [04:50<37:03,  9.79s/it]

Epoch: 3/5... Step: 670... Loss: 0.194377... Val Loss: 0.184420 Train Acc: 0.940298 Val Acc: 0.949605
Validation loss decreased (0.185733 --> 0.184420).  Saving model ...


 27%|██▋       | 81/299 [04:56<03:26,  1.05it/s]

num_directions 1


 27%|██▋       | 82/299 [05:26<35:21,  9.78s/it]

Epoch: 3/5... Step: 680... Loss: 0.222178... Val Loss: 0.180483 Train Acc: 0.940821 Val Acc: 0.951267
Validation loss decreased (0.184420 --> 0.180483).  Saving model ...


 30%|███       | 91/299 [05:31<03:16,  1.06it/s]

num_directions 1


 31%|███       | 92/299 [06:01<33:27,  9.70s/it]

Epoch: 3/5... Step: 690... Loss: 0.220496... Val Loss: 0.181102 Train Acc: 0.941315 Val Acc: 0.951078


 34%|███▍      | 101/299 [06:06<03:08,  1.05it/s]

num_directions 1


 34%|███▍      | 102/299 [06:36<31:14,  9.51s/it]

Epoch: 3/5... Step: 700... Loss: 0.194880... Val Loss: 0.180485 Train Acc: 0.941679 Val Acc: 0.951732


 37%|███▋      | 111/299 [06:41<02:51,  1.09it/s]

num_directions 1


 37%|███▋      | 112/299 [07:10<29:10,  9.36s/it]

Epoch: 3/5... Step: 710... Loss: 0.218361... Val Loss: 0.175282 Train Acc: 0.942189 Val Acc: 0.951325
Validation loss decreased (0.180483 --> 0.175282).  Saving model ...


 40%|████      | 121/299 [07:15<02:41,  1.10it/s]

num_directions 1


 41%|████      | 122/299 [07:44<27:10,  9.21s/it]

Epoch: 3/5... Step: 720... Loss: 0.226485... Val Loss: 0.173089 Train Acc: 0.942588 Val Acc: 0.952667
Validation loss decreased (0.175282 --> 0.173089).  Saving model ...


 44%|████▍     | 131/299 [07:49<02:33,  1.09it/s]

num_directions 1


 44%|████▍     | 132/299 [08:17<25:46,  9.26s/it]

Epoch: 3/5... Step: 730... Loss: 0.195914... Val Loss: 0.170802 Train Acc: 0.942982 Val Acc: 0.953442
Validation loss decreased (0.173089 --> 0.170802).  Saving model ...


 47%|████▋     | 141/299 [08:22<02:21,  1.11it/s]

num_directions 1


 47%|████▋     | 142/299 [08:51<24:20,  9.30s/it]

Epoch: 3/5... Step: 740... Loss: 0.152528... Val Loss: 0.169956 Train Acc: 0.943328 Val Acc: 0.953040
Validation loss decreased (0.170802 --> 0.169956).  Saving model ...


 51%|█████     | 151/299 [08:56<02:14,  1.10it/s]

num_directions 1


 51%|█████     | 152/299 [09:25<22:33,  9.21s/it]

Epoch: 3/5... Step: 750... Loss: 0.212233... Val Loss: 0.167555 Train Acc: 0.943430 Val Acc: 0.953588
Validation loss decreased (0.169956 --> 0.167555).  Saving model ...


 54%|█████▍    | 161/299 [09:30<02:04,  1.11it/s]

num_directions 1


 54%|█████▍    | 162/299 [09:59<21:11,  9.28s/it]

Epoch: 3/5... Step: 760... Loss: 0.176617... Val Loss: 0.166411 Train Acc: 0.943868 Val Acc: 0.954033
Validation loss decreased (0.167555 --> 0.166411).  Saving model ...


 57%|█████▋    | 171/299 [10:04<01:55,  1.11it/s]

num_directions 1


 58%|█████▊    | 172/299 [10:32<19:27,  9.19s/it]

Epoch: 3/5... Step: 770... Loss: 0.200446... Val Loss: 0.163571 Train Acc: 0.944132 Val Acc: 0.954348
Validation loss decreased (0.166411 --> 0.163571).  Saving model ...


 61%|██████    | 181/299 [10:37<01:46,  1.11it/s]

num_directions 1


 61%|██████    | 182/299 [11:07<18:18,  9.39s/it]

Epoch: 3/5... Step: 780... Loss: 0.169606... Val Loss: 0.165239 Train Acc: 0.944630 Val Acc: 0.954154


 64%|██████▍   | 191/299 [11:12<01:39,  1.08it/s]

num_directions 1


 64%|██████▍   | 192/299 [11:43<17:47,  9.98s/it]

Epoch: 3/5... Step: 790... Loss: 0.176097... Val Loss: 0.160780 Train Acc: 0.944811 Val Acc: 0.953898
Validation loss decreased (0.163571 --> 0.160780).  Saving model ...


 67%|██████▋   | 201/299 [11:48<01:36,  1.01it/s]

num_directions 1


 68%|██████▊   | 202/299 [12:20<16:43, 10.35s/it]

Epoch: 3/5... Step: 800... Loss: 0.149494... Val Loss: 0.160344 Train Acc: 0.945014 Val Acc: 0.954523
Validation loss decreased (0.160780 --> 0.160344).  Saving model ...


 71%|███████   | 211/299 [12:26<01:28,  1.01s/it]

num_directions 1


 71%|███████   | 212/299 [12:57<14:24,  9.93s/it]

Epoch: 3/5... Step: 810... Loss: 0.180275... Val Loss: 0.158768 Train Acc: 0.945340 Val Acc: 0.954217
Validation loss decreased (0.160344 --> 0.158768).  Saving model ...


 74%|███████▍  | 221/299 [13:02<01:13,  1.05it/s]

num_directions 1


 74%|███████▍  | 222/299 [13:31<12:07,  9.45s/it]

Epoch: 3/5... Step: 820... Loss: 0.201860... Val Loss: 0.160071 Train Acc: 0.945630 Val Acc: 0.954169


 77%|███████▋  | 231/299 [13:36<01:01,  1.10it/s]

num_directions 1


 78%|███████▊  | 232/299 [14:08<11:07,  9.96s/it]

Epoch: 3/5... Step: 830... Loss: 0.167013... Val Loss: 0.156519 Train Acc: 0.945910 Val Acc: 0.955109
Validation loss decreased (0.158768 --> 0.156519).  Saving model ...


 81%|████████  | 241/299 [14:13<00:54,  1.06it/s]

num_directions 1


 81%|████████  | 242/299 [14:42<08:57,  9.43s/it]

Epoch: 3/5... Step: 840... Loss: 0.183140... Val Loss: 0.156430 Train Acc: 0.946095 Val Acc: 0.955491
Validation loss decreased (0.156519 --> 0.156430).  Saving model ...


 84%|████████▍ | 251/299 [14:47<00:44,  1.08it/s]

num_directions 1


 84%|████████▍ | 252/299 [15:18<07:47,  9.94s/it]

Epoch: 3/5... Step: 850... Loss: 0.176247... Val Loss: 0.156198 Train Acc: 0.946444 Val Acc: 0.955375
Validation loss decreased (0.156430 --> 0.156198).  Saving model ...


 87%|████████▋ | 261/299 [15:23<00:37,  1.01it/s]

num_directions 1


 88%|████████▊ | 262/299 [15:55<06:12, 10.06s/it]

Epoch: 3/5... Step: 860... Loss: 0.163440... Val Loss: 0.155260 Train Acc: 0.946699 Val Acc: 0.954891
Validation loss decreased (0.156198 --> 0.155260).  Saving model ...


 91%|█████████ | 271/299 [16:00<00:26,  1.04it/s]

num_directions 1


 91%|█████████ | 272/299 [16:29<04:15,  9.48s/it]

Epoch: 3/5... Step: 870... Loss: 0.166914... Val Loss: 0.154892 Train Acc: 0.946867 Val Acc: 0.954639
Validation loss decreased (0.155260 --> 0.154892).  Saving model ...


 94%|█████████▍| 281/299 [16:34<00:16,  1.09it/s]

num_directions 1


 94%|█████████▍| 282/299 [17:04<02:41,  9.48s/it]

Epoch: 3/5... Step: 880... Loss: 0.155265... Val Loss: 0.157582 Train Acc: 0.947127 Val Acc: 0.955051


 97%|█████████▋| 291/299 [17:09<00:07,  1.09it/s]

num_directions 1


 98%|█████████▊| 292/299 [17:39<01:06,  9.53s/it]

Epoch: 3/5... Step: 890... Loss: 0.199504... Val Loss: 0.153395 Train Acc: 0.947261 Val Acc: 0.955211
Validation loss decreased (0.154892 --> 0.153395).  Saving model ...


100%|██████████| 299/299 [17:43<00:00,  3.56s/it]


num_directions 1


  1%|          | 2/299 [00:01<02:49,  1.75it/s]

num_directions 1


  1%|          | 3/299 [00:31<1:09:17, 14.05s/it]

Epoch: 4/5... Step: 900... Loss: 0.178170... Val Loss: 0.151237 Train Acc: 0.951978 Val Acc: 0.955787
Validation loss decreased (0.153395 --> 0.151237).  Saving model ...


  4%|▍         | 12/299 [00:36<04:27,  1.07it/s] 

num_directions 1


  4%|▍         | 13/299 [01:06<46:27,  9.75s/it]

Epoch: 4/5... Step: 910... Loss: 0.185551... Val Loss: 0.151599 Train Acc: 0.951492 Val Acc: 0.955816


  7%|▋         | 22/299 [01:11<04:19,  1.07it/s]

num_directions 1


  8%|▊         | 23/299 [01:40<43:04,  9.36s/it]

Epoch: 4/5... Step: 920... Loss: 0.170245... Val Loss: 0.150815 Train Acc: 0.951346 Val Acc: 0.956083
Validation loss decreased (0.151237 --> 0.150815).  Saving model ...


 11%|█         | 32/299 [01:45<04:07,  1.08it/s]

num_directions 1


 11%|█         | 33/299 [02:14<41:30,  9.36s/it]

Epoch: 4/5... Step: 930... Loss: 0.146847... Val Loss: 0.153069 Train Acc: 0.952404 Val Acc: 0.955559


 14%|█▍        | 42/299 [02:19<03:55,  1.09it/s]

num_directions 1


 14%|█▍        | 43/299 [02:49<40:44,  9.55s/it]

Epoch: 4/5... Step: 940... Loss: 0.174674... Val Loss: 0.150409 Train Acc: 0.952595 Val Acc: 0.956141
Validation loss decreased (0.150815 --> 0.150409).  Saving model ...


 17%|█▋        | 52/299 [02:54<03:46,  1.09it/s]

num_directions 1


 18%|█▊        | 53/299 [03:23<38:20,  9.35s/it]

Epoch: 4/5... Step: 950... Loss: 0.158567... Val Loss: 0.152818 Train Acc: 0.952692 Val Acc: 0.954852


 21%|██        | 62/299 [03:28<03:37,  1.09it/s]

num_directions 1


 21%|██        | 63/299 [03:58<37:12,  9.46s/it]

Epoch: 4/5... Step: 960... Loss: 0.155842... Val Loss: 0.148581 Train Acc: 0.952688 Val Acc: 0.956238
Validation loss decreased (0.150409 --> 0.148581).  Saving model ...


 24%|██▍       | 72/299 [04:03<03:28,  1.09it/s]

num_directions 1


 24%|██▍       | 73/299 [04:33<36:19,  9.65s/it]

Epoch: 4/5... Step: 970... Loss: 0.128716... Val Loss: 0.150797 Train Acc: 0.953295 Val Acc: 0.956863


 27%|██▋       | 82/299 [04:38<03:19,  1.09it/s]

num_directions 1


 28%|██▊       | 83/299 [05:07<33:58,  9.44s/it]

Epoch: 4/5... Step: 980... Loss: 0.164560... Val Loss: 0.147605 Train Acc: 0.953753 Val Acc: 0.956693
Validation loss decreased (0.148581 --> 0.147605).  Saving model ...


 31%|███       | 92/299 [05:12<03:11,  1.08it/s]

num_directions 1


 31%|███       | 93/299 [05:42<32:31,  9.47s/it]

Epoch: 4/5... Step: 990... Loss: 0.174319... Val Loss: 0.147679 Train Acc: 0.953715 Val Acc: 0.956625


 34%|███▍      | 102/299 [05:47<03:01,  1.08it/s]

num_directions 1


 34%|███▍      | 103/299 [06:17<31:39,  9.69s/it]

Epoch: 4/5... Step: 1000... Loss: 0.160095... Val Loss: 0.145996 Train Acc: 0.953747 Val Acc: 0.957144
Validation loss decreased (0.147605 --> 0.145996).  Saving model ...


 37%|███▋      | 112/299 [06:22<02:58,  1.05it/s]

num_directions 1


 38%|███▊      | 113/299 [06:53<30:22,  9.80s/it]

Epoch: 4/5... Step: 1010... Loss: 0.154619... Val Loss: 0.144942 Train Acc: 0.953754 Val Acc: 0.957512
Validation loss decreased (0.145996 --> 0.144942).  Saving model ...


 41%|████      | 122/299 [06:58<02:49,  1.04it/s]

num_directions 1


 41%|████      | 123/299 [07:29<28:56,  9.87s/it]

Epoch: 4/5... Step: 1020... Loss: 0.147270... Val Loss: 0.144506 Train Acc: 0.953935 Val Acc: 0.957478
Validation loss decreased (0.144942 --> 0.144506).  Saving model ...


 44%|████▍     | 132/299 [07:34<02:40,  1.04it/s]

num_directions 1


 44%|████▍     | 133/299 [08:05<27:53, 10.08s/it]

Epoch: 4/5... Step: 1030... Loss: 0.147503... Val Loss: 0.148016 Train Acc: 0.954000 Val Acc: 0.956514


 47%|████▋     | 142/299 [08:11<02:35,  1.01it/s]

num_directions 1


 48%|████▊     | 143/299 [08:42<25:56,  9.98s/it]

Epoch: 4/5... Step: 1040... Loss: 0.157316... Val Loss: 0.143430 Train Acc: 0.954089 Val Acc: 0.957366
Validation loss decreased (0.144506 --> 0.143430).  Saving model ...


 51%|█████     | 152/299 [08:47<02:18,  1.06it/s]

num_directions 1


 51%|█████     | 153/299 [09:17<23:21,  9.60s/it]

Epoch: 4/5... Step: 1050... Loss: 0.152920... Val Loss: 0.142699 Train Acc: 0.954131 Val Acc: 0.957100
Validation loss decreased (0.143430 --> 0.142699).  Saving model ...


 54%|█████▍    | 162/299 [09:22<02:08,  1.07it/s]

num_directions 1


 55%|█████▍    | 163/299 [09:52<21:58,  9.70s/it]

Epoch: 4/5... Step: 1060... Loss: 0.137240... Val Loss: 0.142344 Train Acc: 0.954314 Val Acc: 0.957507
Validation loss decreased (0.142699 --> 0.142344).  Saving model ...


 58%|█████▊    | 172/299 [09:57<02:00,  1.05it/s]

num_directions 1


 58%|█████▊    | 173/299 [10:28<20:32,  9.78s/it]

Epoch: 4/5... Step: 1070... Loss: 0.143793... Val Loss: 0.143588 Train Acc: 0.954315 Val Acc: 0.957085


 61%|██████    | 182/299 [10:33<01:51,  1.05it/s]

num_directions 1


 61%|██████    | 183/299 [11:03<18:29,  9.56s/it]

Epoch: 4/5... Step: 1080... Loss: 0.128265... Val Loss: 0.144566 Train Acc: 0.954313 Val Acc: 0.957245


 64%|██████▍   | 192/299 [11:08<01:38,  1.09it/s]

num_directions 1


 65%|██████▍   | 193/299 [11:37<16:31,  9.35s/it]

Epoch: 4/5... Step: 1090... Loss: 0.161126... Val Loss: 0.141608 Train Acc: 0.954244 Val Acc: 0.957739
Validation loss decreased (0.142344 --> 0.141608).  Saving model ...


 68%|██████▊   | 202/299 [11:42<01:32,  1.05it/s]

num_directions 1


 68%|██████▊   | 203/299 [12:12<15:29,  9.68s/it]

Epoch: 4/5... Step: 1100... Loss: 0.153502... Val Loss: 0.140003 Train Acc: 0.954321 Val Acc: 0.957812
Validation loss decreased (0.141608 --> 0.140003).  Saving model ...


 71%|███████   | 212/299 [12:17<01:23,  1.04it/s]

num_directions 1


 71%|███████   | 213/299 [12:49<14:25, 10.06s/it]

Epoch: 4/5... Step: 1110... Loss: 0.152605... Val Loss: 0.139581 Train Acc: 0.954440 Val Acc: 0.957817
Validation loss decreased (0.140003 --> 0.139581).  Saving model ...


 74%|███████▍  | 222/299 [12:54<01:13,  1.05it/s]

num_directions 1


 75%|███████▍  | 223/299 [13:24<12:16,  9.69s/it]

Epoch: 4/5... Step: 1120... Loss: 0.150399... Val Loss: 0.140635 Train Acc: 0.954608 Val Acc: 0.957575


 78%|███████▊  | 232/299 [13:29<01:02,  1.07it/s]

num_directions 1


 78%|███████▊  | 233/299 [13:58<10:22,  9.44s/it]

Epoch: 4/5... Step: 1130... Loss: 0.131598... Val Loss: 0.144113 Train Acc: 0.954570 Val Acc: 0.956809


 81%|████████  | 242/299 [14:04<00:54,  1.05it/s]

num_directions 1


 81%|████████▏ | 243/299 [14:35<09:21, 10.03s/it]

Epoch: 4/5... Step: 1140... Loss: 0.174596... Val Loss: 0.140527 Train Acc: 0.954641 Val Acc: 0.958127


 84%|████████▍ | 252/299 [14:40<00:45,  1.04it/s]

num_directions 1


 85%|████████▍ | 253/299 [15:10<07:28,  9.74s/it]

Epoch: 4/5... Step: 1150... Loss: 0.127369... Val Loss: 0.142455 Train Acc: 0.954704 Val Acc: 0.957851


 88%|████████▊ | 262/299 [15:16<00:34,  1.06it/s]

num_directions 1


 88%|████████▊ | 263/299 [15:46<05:47,  9.66s/it]

Epoch: 4/5... Step: 1160... Loss: 0.145951... Val Loss: 0.141188 Train Acc: 0.954762 Val Acc: 0.957890


 91%|█████████ | 272/299 [15:51<00:25,  1.06it/s]

num_directions 1


 91%|█████████▏| 273/299 [16:21<04:12,  9.70s/it]

Epoch: 4/5... Step: 1170... Loss: 0.188504... Val Loss: 0.141294 Train Acc: 0.954774 Val Acc: 0.957841


 94%|█████████▍| 282/299 [16:26<00:15,  1.07it/s]

num_directions 1


 95%|█████████▍| 283/299 [16:56<02:34,  9.68s/it]

Epoch: 4/5... Step: 1180... Loss: 0.175855... Val Loss: 0.141463 Train Acc: 0.954800 Val Acc: 0.957885


 98%|█████████▊| 292/299 [17:01<00:06,  1.07it/s]

num_directions 1


 98%|█████████▊| 293/299 [17:31<00:57,  9.59s/it]

Epoch: 4/5... Step: 1190... Loss: 0.115247... Val Loss: 0.139020 Train Acc: 0.954877 Val Acc: 0.958345
Validation loss decreased (0.139581 --> 0.139020).  Saving model ...


100%|██████████| 299/299 [17:35<00:00,  3.53s/it]


num_directions 1


  1%|          | 3/299 [00:01<02:46,  1.78it/s]

num_directions 1


  1%|▏         | 4/299 [00:31<58:55, 11.98s/it]

Epoch: 5/5... Step: 1200... Loss: 0.175332... Val Loss: 0.136589 Train Acc: 0.956909 Val Acc: 0.958830
Validation loss decreased (0.139020 --> 0.136589).  Saving model ...


  4%|▍         | 13/299 [00:36<04:30,  1.06it/s]

num_directions 1


  5%|▍         | 14/299 [01:06<46:26,  9.78s/it]

Epoch: 5/5... Step: 1210... Loss: 0.169176... Val Loss: 0.137355 Train Acc: 0.956734 Val Acc: 0.958127


  8%|▊         | 23/299 [01:12<04:23,  1.05it/s]

num_directions 1


  8%|▊         | 24/299 [01:42<44:39,  9.74s/it]

Epoch: 5/5... Step: 1220... Loss: 0.151267... Val Loss: 0.139722 Train Acc: 0.956574 Val Acc: 0.958011


 11%|█         | 33/299 [01:47<04:13,  1.05it/s]

num_directions 1


 11%|█▏        | 34/299 [02:18<44:00,  9.96s/it]

Epoch: 5/5... Step: 1230... Loss: 0.128257... Val Loss: 0.138706 Train Acc: 0.956884 Val Acc: 0.958326


 14%|█▍        | 43/299 [02:23<04:05,  1.04it/s]

num_directions 1


 15%|█▍        | 44/299 [02:54<42:10,  9.92s/it]

Epoch: 5/5... Step: 1240... Loss: 0.138746... Val Loss: 0.136813 Train Acc: 0.957171 Val Acc: 0.958360


 18%|█▊        | 53/299 [03:00<03:57,  1.04it/s]

num_directions 1


 18%|█▊        | 54/299 [03:30<40:23,  9.89s/it]

Epoch: 5/5... Step: 1250... Loss: 0.193137... Val Loss: 0.135482 Train Acc: 0.956644 Val Acc: 0.958660
Validation loss decreased (0.136589 --> 0.135482).  Saving model ...


 21%|██        | 63/299 [03:35<03:44,  1.05it/s]

num_directions 1


 21%|██▏       | 64/299 [04:06<38:02,  9.71s/it]

Epoch: 5/5... Step: 1260... Loss: 0.129727... Val Loss: 0.135797 Train Acc: 0.956974 Val Acc: 0.959266


 24%|██▍       | 73/299 [04:11<03:31,  1.07it/s]

num_directions 1


 25%|██▍       | 74/299 [04:42<37:15,  9.94s/it]

Epoch: 5/5... Step: 1270... Loss: 0.117309... Val Loss: 0.134334 Train Acc: 0.957372 Val Acc: 0.959251
Validation loss decreased (0.135482 --> 0.134334).  Saving model ...


 28%|██▊       | 83/299 [04:47<03:28,  1.03it/s]

num_directions 1


 28%|██▊       | 84/299 [05:18<36:14, 10.11s/it]

Epoch: 5/5... Step: 1280... Loss: 0.132435... Val Loss: 0.135414 Train Acc: 0.957483 Val Acc: 0.958699


 31%|███       | 93/299 [05:24<03:20,  1.03it/s]

num_directions 1


 31%|███▏      | 94/299 [05:54<33:37,  9.84s/it]

Epoch: 5/5... Step: 1290... Loss: 0.168406... Val Loss: 0.134215 Train Acc: 0.957325 Val Acc: 0.958951
Validation loss decreased (0.134334 --> 0.134215).  Saving model ...


 34%|███▍      | 103/299 [06:00<03:08,  1.04it/s]

num_directions 1


 35%|███▍      | 104/299 [06:30<31:33,  9.71s/it]

Epoch: 5/5... Step: 1300... Loss: 0.148332... Val Loss: 0.135545 Train Acc: 0.957392 Val Acc: 0.958786


 38%|███▊      | 113/299 [06:35<02:53,  1.07it/s]

num_directions 1


 38%|███▊      | 114/299 [07:06<30:42,  9.96s/it]

Epoch: 5/5... Step: 1310... Loss: 0.143814... Val Loss: 0.137626 Train Acc: 0.957295 Val Acc: 0.959091


 41%|████      | 123/299 [07:11<02:47,  1.05it/s]

num_directions 1


 41%|████▏     | 124/299 [07:42<28:46,  9.87s/it]

Epoch: 5/5... Step: 1320... Loss: 0.146819... Val Loss: 0.137258 Train Acc: 0.956973 Val Acc: 0.958975


 44%|████▍     | 133/299 [07:47<02:39,  1.04it/s]

num_directions 1


 45%|████▍     | 134/299 [08:17<27:00,  9.82s/it]

Epoch: 5/5... Step: 1330... Loss: 0.150645... Val Loss: 0.134719 Train Acc: 0.957015 Val Acc: 0.958505


 48%|████▊     | 143/299 [08:23<02:28,  1.05it/s]

num_directions 1


 48%|████▊     | 144/299 [08:53<25:24,  9.83s/it]

Epoch: 5/5... Step: 1340... Loss: 0.163162... Val Loss: 0.135050 Train Acc: 0.956971 Val Acc: 0.958607


 51%|█████     | 153/299 [08:59<02:20,  1.04it/s]

num_directions 1


 52%|█████▏    | 154/299 [09:30<24:08,  9.99s/it]

Epoch: 5/5... Step: 1350... Loss: 0.145833... Val Loss: 0.134200 Train Acc: 0.957094 Val Acc: 0.959120
Validation loss decreased (0.134215 --> 0.134200).  Saving model ...


 55%|█████▍    | 163/299 [09:35<02:11,  1.04it/s]

num_directions 1


 55%|█████▍    | 164/299 [10:05<21:58,  9.77s/it]

Epoch: 5/5... Step: 1360... Loss: 0.146663... Val Loss: 0.131837 Train Acc: 0.957108 Val Acc: 0.959440
Validation loss decreased (0.134200 --> 0.131837).  Saving model ...


 58%|█████▊    | 173/299 [10:10<01:58,  1.06it/s]

num_directions 1


 58%|█████▊    | 174/299 [10:40<19:52,  9.54s/it]

Epoch: 5/5... Step: 1370... Loss: 0.144638... Val Loss: 0.132947 Train Acc: 0.957145 Val Acc: 0.958965


 61%|██████    | 183/299 [10:45<01:48,  1.07it/s]

num_directions 1


 62%|██████▏   | 184/299 [11:15<18:24,  9.60s/it]

Epoch: 5/5... Step: 1380... Loss: 0.152796... Val Loss: 0.135529 Train Acc: 0.957136 Val Acc: 0.958863


 65%|██████▍   | 193/299 [11:20<01:38,  1.08it/s]

num_directions 1


 65%|██████▍   | 194/299 [11:50<16:42,  9.55s/it]

Epoch: 5/5... Step: 1390... Loss: 0.146846... Val Loss: 0.135007 Train Acc: 0.957285 Val Acc: 0.958747


 68%|██████▊   | 203/299 [11:55<01:27,  1.09it/s]

num_directions 1


 68%|██████▊   | 204/299 [12:24<15:04,  9.52s/it]

Epoch: 5/5... Step: 1400... Loss: 0.139885... Val Loss: 0.133494 Train Acc: 0.957214 Val Acc: 0.959314


 71%|███████   | 213/299 [12:30<01:21,  1.06it/s]

num_directions 1


 72%|███████▏  | 214/299 [13:00<13:38,  9.62s/it]

Epoch: 5/5... Step: 1410... Loss: 0.123967... Val Loss: 0.134138 Train Acc: 0.957158 Val Acc: 0.958549


 75%|███████▍  | 223/299 [13:05<01:13,  1.04it/s]

num_directions 1


 75%|███████▍  | 224/299 [13:36<12:40, 10.14s/it]

Epoch: 5/5... Step: 1420... Loss: 0.146178... Val Loss: 0.131404 Train Acc: 0.957033 Val Acc: 0.959789
Validation loss decreased (0.131837 --> 0.131404).  Saving model ...


 78%|███████▊  | 233/299 [13:42<01:03,  1.04it/s]

num_directions 1


 78%|███████▊  | 234/299 [14:13<10:48,  9.98s/it]

Epoch: 5/5... Step: 1430... Loss: 0.103013... Val Loss: 0.133723 Train Acc: 0.957063 Val Acc: 0.959338


 81%|████████▏ | 243/299 [14:18<00:53,  1.05it/s]

num_directions 1


 82%|████████▏ | 244/299 [14:48<09:01,  9.85s/it]

Epoch: 5/5... Step: 1440... Loss: 0.143341... Val Loss: 0.130834 Train Acc: 0.957127 Val Acc: 0.959881
Validation loss decreased (0.131404 --> 0.130834).  Saving model ...


 85%|████████▍ | 253/299 [14:54<00:43,  1.05it/s]

num_directions 1


 85%|████████▍ | 254/299 [15:26<07:41, 10.25s/it]

Epoch: 5/5... Step: 1450... Loss: 0.119487... Val Loss: 0.131383 Train Acc: 0.957187 Val Acc: 0.959682


 88%|████████▊ | 263/299 [15:31<00:36,  1.01s/it]

num_directions 1


 88%|████████▊ | 264/299 [16:02<05:46,  9.91s/it]

Epoch: 5/5... Step: 1460... Loss: 0.129265... Val Loss: 0.130719 Train Acc: 0.957195 Val Acc: 0.959769
Validation loss decreased (0.130834 --> 0.130719).  Saving model ...


 91%|█████████▏| 273/299 [16:07<00:25,  1.04it/s]

num_directions 1


 92%|█████████▏| 274/299 [16:37<04:04,  9.78s/it]

Epoch: 5/5... Step: 1470... Loss: 0.132637... Val Loss: 0.131133 Train Acc: 0.957193 Val Acc: 0.959503


 95%|█████████▍| 283/299 [16:43<00:15,  1.06it/s]

num_directions 1


 95%|█████████▍| 284/299 [17:13<02:26,  9.76s/it]

Epoch: 5/5... Step: 1480... Loss: 0.137535... Val Loss: 0.131571 Train Acc: 0.957187 Val Acc: 0.959348


 98%|█████████▊| 293/299 [17:18<00:05,  1.06it/s]

num_directions 1


 98%|█████████▊| 294/299 [17:49<00:49,  9.99s/it]

Epoch: 5/5... Step: 1490... Loss: 0.112944... Val Loss: 0.130492 Train Acc: 0.957220 Val Acc: 0.959828
Validation loss decreased (0.130719 --> 0.130492).  Saving model ...


100%|██████████| 299/299 [17:52<00:00,  3.59s/it]


## Inference

In [20]:
meta_data = joblib.load("meta.bin")
enc_tag = meta_data["enc_tag"]

num_tag = len(list(enc_tag.classes_))

text = """
Natasha is traveling to New York
"""

device = torch.device("cuda")
model.to(device)

# так как это инференс, выключаем расчет градиентов:
with torch.no_grad():
    # inputs = torch.tensor([fasttext_model.wv[s] for s in word_tokenize(text)], dtype=torch.float32)
    inputs = torch.tensor([fasttext_model[s] for s in word_tokenize(text)], dtype=torch.float32)
    inputs = inputs.unsqueeze(0).to(device)
    h = model.init_hidden(1)
    tag, h = model(inputs, h)

    print(
        enc_tag.inverse_transform(
            tag.argmax(-1).cpu().numpy().reshape(-1)
        )
    )

  inputs = torch.tensor([fasttext_model[s] for s in word_tokenize(text)], dtype=torch.float32)


num_directions 1
['B-per' 'O' 'O' 'O' 'B-geo' 'I-geo']
