In [6]:
import numpy as np
import pandas as pd

In [7]:
import random

In [8]:
from tqdm import tqdm

In [9]:
import re
import os

In [152]:
import time

In [10]:
from collections import Counter

In [11]:
from sklearn.model_selection import train_test_split

#### PyTorch

In [104]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, Subset
from torch.utils.data.sampler import SequentialSampler
from torch.utils.data import DataLoader

import torch.optim as optim

In [179]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [13]:
import torchinfo

#### Embeddings

In [14]:
# Should be used by default. Shows best results on intrinsic evaluations.
# Model was trained on large corpus of an literature (~150GB).

# !wget https://storage.yandexcloud.net/natasha-navec/packs/navec_hudlit_v1_12B_500K_300d_100q.tar

In [15]:
from navec import Navec

#### Metrics

In [16]:
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score

In [17]:
from Levenshtein import distance as levenshtein_distance

#### Visualisation

In [18]:
import matplotlib
import matplotlib.pyplot as plt

In [19]:
import scienceplots

plt.style.use('science')
%config InlineBackend.figure_format = 'retina'

lables_fs = 16
ticks_fs = 12

In [20]:
import seaborn as sns

## Load data

In [21]:
prepared_dir = '../data/prepared'
filename_csv = '02_punct_pushkin.csv'

In [22]:
# load saved dataset
data_df = pd.read_csv(os.path.join(prepared_dir, filename_csv), index_col=0)
data_df.shape

(4456, 3)

In [23]:
pd.options.display.max_colwidth = 150
data_df.sample(5)

Unnamed: 0,input,input_lemma,new_target
2373,там они были схвачены 16 казаками и выданы победителю который отослал их скованных в уфу,там они быть схватить 16 казак и выдать победитель который отослать они сковать в уфа,S S S S S S S S C S S S S S F
4447,как досадно подумал алексей,как досадно подумать алексей,S C S F
3798,я сам отвечал граф с видом чрезвычайно расстроенным а простреленная картина есть памятник последней нашей встречи,я сам отвечать граф с вид чрезвычайно расстроить а прострелить картина быть памятник последний наш встреча,S C S S S S S C S S S S S S S F
4067,кто в минуту гнева не требовал от них роковой книги дабы вписать в оную свою бесполезную жалобу на притеснение грубость и неисправность,кто в минута гнев не требовать от они роковой книга дабы вписать в оную свой бесполезный жалоба на притеснение грубость и неисправность,C S S C S S S S S C S S S S S S S S C S S F
2350,в берде найдено осьмнадцать пушек семнадцать бочек медных денег 13 и множество хлеба,в берда найти осьмнадцать пушка семнадцать бочка медный деньга 13 и множество хлеб,S S S S C S S S S S S S F


## Pretrained Embeddings

In that implementation we will use [`navec`](https://github.com/natasha/navec#evaluation) library of pretrained word embeddings for Russian language.

In [37]:
navec_path = 'navec_hudlit_v1_12B_500K_300d_100q.tar'
navec_embed = Navec.load(navec_path)

In [42]:
EMBED_SIZE = 300

### Example

In [36]:
ex_sent = data_df.loc[1321]['input'] # data_df.sample(1)['input'].item()
ex_sent

'мы поцеловались горячо искренно и таким образом все было между нами решено'

In [60]:
sent_embed_tensor = None

for word in ex_sent.split():
    word_embed = navec_embed[word]
    
    assert len(word_embed) == EMBED_SIZE
    assert isinstance(word_embed, np.ndarray)

    word_embed_tensor = torch.tensor(word_embed)
    word_embed_tensor = word_embed_tensor.unsqueeze(dim=0)

    # size of word embedding: 1 * embed_size
    assert word_embed_tensor.size() == (1, EMBED_SIZE)

    if sent_embed_tensor is None:
        sent_embed_tensor = word_embed_tensor
    else:
        sent_embed_tensor = torch.cat(
            (sent_embed_tensor, word_embed_tensor), 
            dim=0
        )

# size of sentence tensor: n_words * embed_size
assert sent_embed_tensor.size() == (len(ex_sent.split()), EMBED_SIZE)

In [94]:
sent_embed_tensor.size()[0]

12

In [92]:
navec_embed['<pad>'][:10]  # padding: all zeros!

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [82]:
'чушь-чепуха' in navec_embed

False

## Dataset

**In that approach** we will use _original_ (not lemmatized) sentences as input (`input` column)

In [87]:
IGNORE_ID = -1  # token id to ignore

# punctuation vocab
PUNC_2_ID = {'S': 0, 'C': 1, 'F':2}
ID_2_PUNC = {v: k for k, v in PUNC_2_ID.items()}

In [88]:
PAD = '<pad>'
UNK = '<unk>'

In [99]:
class PuncDataset(Dataset):
    """Custom Dataset for punctuation prediction"""

    def __init__(self, df, sent_col, target_col, embed):
        self.sentences = df[sent_col]  # all sentences
        self.targets = df[target_col]  # all targets

        self.embed = embed  # navec embedding

    def __len__(self):
        """Return number of sentences"""
        return len(self.sentences)

    def __getitem__(self, index):
        """Return one Tensor pair of (input id sequence, punc id sequence)"""
        sentence = self.sentences[index]
        target = self.targets[index]
        
        word_id_seq, punc_id_seq = self._preprocess(sentence, target)
        return word_id_seq, punc_id_seq

    def _preprocess(self, sentence, target):
        """Convert txt sequence to word-id-seq and punc-id-seq"""
        # INPUT
        input_tensor = None
        
        for word in sentence.split():
            if word in self.embed:  # if word in vocab
                word_embed = navec_embed[word]
            else:
                word_embed = navec_embed[UNK]
    
            assert len(word_embed) == EMBED_SIZE
            assert isinstance(word_embed, np.ndarray)
        
            word_embed_tensor = torch.tensor(word_embed)  # size: [embed_size]
            word_embed_tensor = word_embed_tensor.unsqueeze(dim=0)  # size: [1, embed_size]
        
            assert word_embed_tensor.size() == (1, EMBED_SIZE)
        
            if input_tensor is None:
                input_tensor = word_embed_tensor
            else:
                input_tensor = torch.cat(
                    (input_tensor, word_embed_tensor), 
                    dim=0
                )
        # size: [len_sent, embed_size]
        assert input_tensor.size() == (len(sentence.split()), EMBED_SIZE)
        
        # OUTPUT
        output_seq = []
        for punc in target.split():
            output_seq.append(PUNC_2_ID.get(punc))

        assert input_tensor.size()[0] == len(output_seq)

        return input_tensor, torch.LongTensor(output_seq)

In [100]:
data_ds = PuncDataset(
    df=data_df, 
    sent_col='input', 
    target_col='new_target',
    embed=navec_embed
)

print(len(data_ds))  # dataset length

4456


In [101]:
data_ds[1111]  # example of dataset element

(tensor([[ 0.0693, -0.7137,  0.1012,  ..., -0.5967, -0.1622,  0.1446],
         [-0.5378, -0.6289, -0.3964,  ..., -0.4420, -0.1768,  0.5108],
         [-0.5378, -0.6289, -0.3964,  ..., -0.2754,  0.0252,  0.3528],
         ...,
         [-0.0273, -0.1946, -0.1695,  ..., -0.0571,  0.1050,  0.3030],
         [-0.2140,  0.2572, -0.2415,  ...,  0.4958,  0.5133,  0.0219],
         [-0.0460,  0.0078,  0.3391,  ...,  0.0026, -0.0379, -0.0747]]),
 tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]))

### Batches

In [181]:
def collate_fn(batch):
    """
   Process one mini-batch samples, such as sorting and padding.
    Args:
        batch: a list of (sentence tensor, targets tensor)
    Returns:
        input_padded_tensor
        output_padded_tensor
        lengths
    """
    # sort a list by sequence length (descending order) to use pack_padded_sequence
    batch.sort(key=lambda x: len(x[0]), reverse=True)

    # seperate inputs and labels
    input_tensors, label_seqs = zip(*batch)
    # padding
    lengths = [len(seq) for seq in label_seqs]
    bs = len(label_seqs)
    max_sent_len = max(lengths)
    
        # shape: [batch_size, max_sent_len, embed_size]
    input_padded_tensor = torch.zeros(bs, max_sent_len, EMBED_SIZE)  # zeros are the padding for navec!
        # shape: [batch_size, max_sent_len]
    output_padded_tensor = torch.zeros(bs, max_sent_len).fill_(IGNORE_ID).long()
    
    for i, (input_tensor, output_seq) in enumerate(zip(input_tensors, label_seqs)):
        end = lengths[i]
        input_padded_tensor[i, :end, :] = input_tensor[:end, :]
        output_padded_tensor[i, :end] = output_seq[:end]
        
    return input_padded_tensor, output_padded_tensor, torch.IntTensor(lengths)

In [182]:
data_loader = DataLoader(
    data_ds,
    batch_size=10, 
    drop_last=False,
    collate_fn=collate_fn,  # custom collate function6 defined above!
    num_workers=0
)

In [183]:
next(iter(data_loader))[0].size()

torch.Size([10, 24, 300])

## Model

In [242]:
class LstmPunctuator(nn.Module):
    def __init__(
        self,
        hidden_size, num_layers, bidirectional,
        num_class, dropout=0.0
    ):
        super(LstmPunctuator, self).__init__()
        
        # Hyper-parameters
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_class = num_class
        self.dropout = dropout
        
        # Components
        self.lstm = nn.LSTM(
            EMBED_SIZE, hidden_size, num_layers,
            batch_first=True,
            bidirectional=bool(bidirectional),
            dropout=self.dropout
        )
        fc_in_dim = hidden_size * 2 if bidirectional else hidden_size
        self.fc = nn.Linear(fc_in_dim, num_class)

    def forward(self, padded_input, input_lengths):
        """
        Args:
            padded_input: [bs, max_sent_len, EMBED_SIZE]
            input_lengths: [bs]
        Returns:
            score: [bs, max_sent_len, num_classes]
        """
        # LSTM Layers
        total_length = padded_input.size(1)  # get the max sequence length
        packed_input = pack_padded_sequence(
            padded_input, input_lengths,
            batch_first=True
        )
        packed_output, _ = self.lstm(packed_input)
        output, _ = pad_packed_sequence(
            packed_output,
            batch_first=True,
            total_length=total_length
        )
        # Output Layer
        score = self.fc(output)
        return score

In [403]:
class GruPunctuator(nn.Module):
    def __init__(
        self,
        hidden_size, num_layers, bidirectional,
        num_class, dropout=0.0
    ):
        super(GruPunctuator, self).__init__()
        
        # Hyper-parameters
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_class = num_class
        self.dropout = dropout
        self.gru_h = None
        
        # Components
        self.gru = nn.GRU(
            EMBED_SIZE, hidden_size, num_layers,
            batch_first=True,
            bidirectional=bool(bidirectional),
            dropout=self.dropout
        )
        fc_in_dim = hidden_size * 2 if bidirectional else hidden_size
        self.fc = nn.Linear(fc_in_dim, num_class)

    def forward(self, padded_input, input_lengths):
        """
        Args:
            padded_input: [bs, max_sent_len, EMBED_SIZE]
            input_lengths: [bs]
        Returns:
            score: [bs, max_sent_len, num_classes]
        """
        # LSTM Layers
        total_length = padded_input.size(1)  # get the max sequence length
        packed_input = pack_padded_sequence(
            padded_input, input_lengths,
            batch_first=True
        )

        self.reset_hidden(padded_input.size(0))
        packed_output, gru_h = self.gru(packed_input, self.gru_h)
        # self.gru_h = gru_h.detach()
        
        output, _ = pad_packed_sequence(
            packed_output,
            batch_first=True,
            total_length=total_length
        )
        # Output Layer
        score = self.fc(output)
        return score

    def reset_hidden(self, batch_size):
        size_0 = self.num_layers * 2 if self.bidirectional else self.num_layers
        self.gru_h = torch.zeros(size_0 , batch_size, self.hidden_size)  # .to(device)

## Training and Evaluationg loops

In [195]:
def train_fn(model, data_loader, loss_func, optimizer,
             device='cpu', show_process=False):
    '''
    Function to train `model`
    Args:
        model: torch.nn.Module - Neural Network
        data_loader: torch.utils.data.DataLoader - loader (by batches) for the train dataset
        loss_func - loss function
        optimizer: torch.optim
        device: str - device to computate on
        show_process: bool - flag to show (or not) a progress bar
    Returns:
        mean loss by batches
    '''
    model.train()  # activate 'train' mode of a model
    train_loss = []  # to store loss for each batch

    for X, y, y_lengths in tqdm(data_loader, total=len(data_loader),
                                desc='train', position=0,
                                leave=True, disable=not show_process):  # [X, y] - batch
        X, y = X.to(device), y.to(device)
        y_lengths = y_lengths.to(device)
        
        optimizer.zero_grad()

        y_hat = model(X, y_lengths)  # size: [bs, max_sent_length, num_classes]

        y_hat = y_hat.view(-1, y_hat.size(-1))
        loss = loss_func(y_hat, y.view(-1))  # loss calculation for the batch
        
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        optimizer.step()

        train_loss.append(loss.item())  # accumulate losses for batches

    return np.mean(train_loss)  # return mean loss of the epoch

In [196]:
def validate_fn(model, data_loader, loss_func,
                device='cpu', show_process=False):
    '''
    Function to train `model`
    Args:
        model: torch.nn.Module - Neural Network
        data_loader: torch.utils.data.DataLoader - loader (by batches) for the validation dataset
        loss_func - loss function
        device: str - device to computate on
        show_process: bool - flag to show (or not) a progress bar
    Returns:
          mean loss by batches
    '''
    model.eval()  # activate 'eval' mode of a model
    val_loss = []  # to store loss for each batch

    for X, y, y_lengths in tqdm(data_loader, total=len(data_loader),
                                desc='validation', position=0,
                                leave=True, disable=not show_process):  # [X, y] - batch
        X, y = X.to(device), y.to(device)
        y_lengths = y_lengths.to(device)

        with torch.no_grad():
            y_hat = model(X, y_lengths)  # size: [bs, max_sent_length, num_classes]

            y_hat = y_hat.view(-1, y_hat.size(-1))
            loss = loss_func(y_hat, y.view(-1))  # loss calculation for the batch

        val_loss.append(loss.item())  # accumulate losses for batches

    return np.mean(val_loss)

## Model training

### DataLoaders

In [357]:
splitting_random_state = 78
test_ratio = 0.25

train_bs = 15
val_bs = 10

In [358]:
# data splitting
train_df, test_df = train_test_split(
    data_df, 
    test_size=test_ratio, 
    random_state=splitting_random_state
)

train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

In [359]:
# datasets
train_ds = PuncDataset(
    df=train_df, 
    sent_col='input', 
    target_col='new_target',
    embed=navec_embed
)

test_ds = PuncDataset(
    df=test_df, 
    sent_col='input', 
    target_col='new_target',
    embed=navec_embed
)

In [360]:
train_loader = DataLoader(
    train_ds,
    batch_size=train_bs, 
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=0
)

val_loader = DataLoader(
    test_ds,
    batch_size=val_bs, 
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=0
)

### LSTM Model

In [361]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [373]:
# model parameters
hidden_sz = 32
num_layers = 2
bidir = 1
dropout = 0.2

In [362]:
# MODEL
model = LstmPunctuator(
    hidden_size=hidden_sz, num_layers=num_layers, bidirectional=bidir,
    dropout=dropout,
    num_class=len(PUNC_2_ID)
).to(device)

# criterion
loss_func = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_ID)
# optimizer
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.0
)
# scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.5,  # default: 0.1
    patience=2,  # default: 10
)

In [363]:
n_epochs = 20
print_each = 1

start_time = time.time()
prev_val_loss = 100
for epoch in range(n_epochs):
    start_epoch_time = time.time()
    if (epoch == 0) or ((epoch + 1) % print_each == 0):
        print(f'Epoch #{epoch + 1}: ', end='')

    # torch.manual_seed(48)  # for reproducibility
    mean_train_loss = train_fn(model, train_loader, loss_func,
                               optimizer,
                               device=device,
                               show_process=False
                              )  # train the model
    mean_val_loss = validate_fn(model, val_loader, loss_func,
                                device=device,
                                show_process=False
                               )  # evaluate the model
    
    if (epoch == 0) or ((epoch + 1) % print_each == 0):
        log_info = (f'\ttrain - {mean_train_loss:.6f}; ' +
                    f'\tval - {mean_val_loss:.6f}' + 
                    f'\t\ttime - {(time.time() - start_time):.3f} s'
                   )
        print(log_info)

    if prev_val_loss < mean_val_loss:
            break
    prev_val_loss = mean_val_loss

    if scheduler:
        scheduler.step(mean_val_loss)

Epoch #1: 	train - 0.522074; 	val - 0.371870		time - 3.032 s
Epoch #2: 	train - 0.353646; 	val - 0.326968		time - 5.998 s
Epoch #3: 	train - 0.308811; 	val - 0.296496		time - 8.954 s
Epoch #4: 	train - 0.275858; 	val - 0.278774		time - 11.877 s
Epoch #5: 	train - 0.248955; 	val - 0.268314		time - 14.804 s
Epoch #6: 	train - 0.225805; 	val - 0.261436		time - 17.758 s
Epoch #7: 	train - 0.205653; 	val - 0.259821		time - 20.778 s
Epoch #8: 	train - 0.187178; 	val - 0.260794		time - 23.706 s


In [372]:
# saving the model
lstm_model_path = f'serialized/lstm-model_hidden-{hidden_sz}_layers-{num_layers}_bidir-{bidir}_dropout-{dropout:.2f}_celoss-{mean_val_loss:.3f}.pth'
torch.save(model.state_dict(), lstm_model_path)

#### LSTM model metrics

In [374]:
model_test = LstmPunctuator(
    hidden_size=hidden_sz, num_layers=num_layers, bidirectional=bidir,
    dropout=dropout,
    num_class=len(PUNC_2_ID)
).to(device)

# LOAD MODEL
model_test.load_state_dict(torch.load(lstm_model_path))

<All keys matched successfully>

In [389]:
END_PUNC = ['F']
INTR_PUNC = ['S', 'C']

NAMES_PUNC = {
    'S': 'space (` `)',
    'C': 'comma (`,`)',
    'F': 'end of sent',
}

CLASSES = sorted(END_PUNC + INTR_PUNC)  # alphabetic order

In [396]:
def get_predictions_df(model, test_df):
    test_ds = PuncDataset(
        df=test_df, 
        sent_col='input', 
        target_col='new_target',
        embed=navec_embed
    )  # test dataset
    test_loader = DataLoader(
        test_ds,
        batch_size=1, 
        drop_last=False,
        collate_fn=collate_fn,
        num_workers=0
    )  # test DataLoader

    all_test_targets = []  # by markers
    all_test_preds = []
    
    model.eval()
    
    for i, (data) in enumerate(test_loader):
        padded_input, padded_target, input_lengths = data
        all_test_targets.append(' '.join([ID_2_PUNC[ix.item()] for ix in padded_target[0]]))
        
        pred = model(padded_input, input_lengths)
        pred = torch.argmax(pred.view(-1, pred.size(-1)), dim=1)
        all_test_preds.append(' '.join([ID_2_PUNC[ix.item()] for ix in pred]))

        assert len(pred) == len(padded_target[0])

    # DataFrame with results
    target_vs_pred_df = pd.DataFrame()

    target_vs_pred_df['target'] = all_test_targets
    target_vs_pred_df['predicted'] = all_test_preds

    return target_vs_pred_df


def return_separate_punct(target_vs_pred_df):
    test_all_punc_target = []  # list of all punctuation
    test_all_punc_preds = []
    
    for target_this, predicted_this in zip(target_vs_pred_df['target'], target_vs_pred_df['predicted']):
        test_all_punc_target.extend(target_this.split(' '))
        test_all_punc_preds.extend(predicted_this.split(' '))
    
    assert len(test_all_punc_target) == len(test_all_punc_preds)
    
    return test_all_punc_target, test_all_punc_preds


def get_all_metrics(model, test_df):
    target_vs_pred_df = get_predictions_df(model, test_df)
    test_all_punc_target, test_all_punc_preds = return_separate_punct(target_vs_pred_df)

    cm = confusion_matrix(test_all_punc_target, test_all_punc_preds)
    # precision = TP / (TP + FP)
    precision = precision_score(test_all_punc_target, test_all_punc_preds, average=None, zero_division=np.nan)
    # recall = TP / (TP + FN)
    recall = recall_score(test_all_punc_target, test_all_punc_preds, average=None, zero_division=np.nan)
    # f1 = 2TP / (2TP + FP + FN)
    f1 = f1_score(test_all_punc_target, test_all_punc_preds, average=None)

    # PRINT
    metrics_names = ['Precision', 'Recall', 'F1 score']
    metrics = {'Precision': precision, 'Recall': recall, 'F1 score': f1}
    col_w = 18
    
    print(' ' * col_w + '|' + ''.join([f"{NAMES_PUNC[token] + (col_w - len(NAMES_PUNC[token])) * ' '}|" for token in CLASSES]))  # header
    print(''.join(['-' * col_w + '|' for _ in range(len(CLASSES) + 1)]) )
    for ind, metric_name in enumerate(metrics_names):
        row = f"{metric_name + (col_w - len(metric_name)) * ' '}|"
        for score in metrics[metric_name]:
            score_str = f'{score:.6f}'
            row += f"{score_str + (col_w - len(score_str)) * ' '}|"
        print(row)

    # Levenshtein distance
    print('\nLevenshtein distance:')
    target_vs_pred_df['levenshtein'] = target_vs_pred_df.apply(
        lambda row: levenshtein_distance(row.target, row.predicted),
        axis = 1
    )
    print(f"\tMean: {target_vs_pred_df['levenshtein'].mean()}")
    print(f"\tMIN : {target_vs_pred_df['levenshtein'].min()}")
    print(f"\tMAX : {target_vs_pred_df['levenshtein'].max()}\n")

In [397]:
%%time
get_all_metrics(model_test, test_df)

                  |comma (`,`)       |end of sent       |space (` `)       |
------------------|------------------|------------------|------------------|
Precision         |0.687129          |1.000000          |0.918454          |
Recall            |0.601647          |0.999102          |0.942514          |
F1 score          |0.641553          |0.999551          |0.930329          |

Levenshtein distance:
	Mean: 1.3913824057450628
	MIN : 0
	MAX : 13

CPU times: user 1.33 s, sys: 3.96 ms, total: 1.33 s
Wall time: 1.33 s


**All metrics increased (except of small decreasing for `end of sentence` prediction) in contrast to metrics of X-Punctuator baseline!**

### GRU Model

In [448]:
train_bs = 15
val_bs = 10

In [449]:
train_loader = DataLoader(
    train_ds,
    batch_size=train_bs, 
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=0
)

val_loader = DataLoader(
    test_ds,
    batch_size=val_bs, 
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=0
)

In [450]:
# model parameters
hidden_sz = 32
num_layers = 2
bidir = 1
dropout = 0.2

In [459]:
# MODEL
model = GruPunctuator(
    hidden_size=hidden_sz, num_layers=num_layers, bidirectional=bidir,
    dropout=dropout,
    num_class=len(PUNC_2_ID)
).to(device)

# criterion
loss_func = torch.nn.CrossEntropyLoss(ignore_index=IGNORE_ID)
# optimizer
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.0
)
# scheduler
scheduler = None

In [460]:
n_epochs = 20
print_each = 1

val_loss_th = 0.26

start_time = time.time()
prev_val_loss = 100
for epoch in range(n_epochs):
    start_epoch_time = time.time()
    if (epoch == 0) or ((epoch + 1) % print_each == 0):
        print(f'Epoch #{epoch + 1}: ', end='')

    # torch.manual_seed(48)  # for reproducibility
    mean_train_loss = train_fn(model, train_loader, loss_func,
                               optimizer,
                               device=device,
                               show_process=False
                              )  # train the model
    mean_val_loss = validate_fn(model, val_loader, loss_func,
                                device=device,
                                show_process=False
                               )  # evaluate the model
    
    if (epoch == 0) or ((epoch + 1) % print_each == 0):
        log_info = (f'\ttrain - {mean_train_loss:.6f}; ' +
                    f'\tval - {mean_val_loss:.6f}' + 
                    f'\t\ttime - {(time.time() - start_time):.3f} s'
                   )
        print(log_info)

    if prev_val_loss < mean_val_loss or mean_val_loss < val_loss_th:
            break
    prev_val_loss = mean_val_loss

    if scheduler:
        scheduler.step(mean_val_loss)

Epoch #1: 	train - 0.490619; 	val - 0.353615		time - 2.854 s
Epoch #2: 	train - 0.329946; 	val - 0.302494		time - 5.676 s
Epoch #3: 	train - 0.288781; 	val - 0.282932		time - 8.491 s
Epoch #4: 	train - 0.266346; 	val - 0.270868		time - 11.322 s
Epoch #5: 	train - 0.242723; 	val - 0.264213		time - 14.153 s
Epoch #6: 	train - 0.224372; 	val - 0.261755		time - 16.976 s
Epoch #7: 	train - 0.208505; 	val - 0.259991		time - 19.792 s


**There is noo significant difference with LSTM results...**

In [461]:
# saving the model
gru_model_path = f'serialized/gru-model_hidden-{hidden_sz}_layers-{num_layers}_bidir-{bidir}_dropout-{dropout:.2f}_celoss-{mean_val_loss:.3f}.pth'
torch.save(model.state_dict(), gru_model_path)

#### GRU model metrics

In [462]:
model_test = GruPunctuator(
    hidden_size=hidden_sz, num_layers=num_layers, bidirectional=bidir,
    dropout=dropout,
    num_class=len(PUNC_2_ID)
).to(device)

# LOAD MODEL
model_test.load_state_dict(torch.load(gru_model_path))

<All keys matched successfully>

In [463]:
%%time
get_all_metrics(model_test, test_df)

                  |comma (`,`)       |end of sent       |space (` `)       |
------------------|------------------|------------------|------------------|
Precision         |0.723112          |1.000000          |0.909720          |
Recall            |0.547898          |1.000000          |0.955976          |
F1 score          |0.623428          |1.000000          |0.932275          |

Levenshtein distance:
	Mean: 1.3680430879712746
	MIN : 0
	MAX : 10

CPU times: user 1.45 s, sys: 39.8 ms, total: 1.49 s
Wall time: 1.59 s


Better identification of `end of sentence` token! No bugs!