[Original Post](https://towardsdatascience.com/ocr-typo-detection-9dd6e396ecac)

In [None]:
# download dataset
!wget https://zenodo.org/record/3515403/files/ICDAR2019-POCR-ground-truth.zip?download=1
!mv 'ICDAR2019-POCR-ground-truth.zip?download=1' IDCAR2019-POCR-ground-truth.zip''
!unzip IDCAR2019-POCR-ground-truth.zip
!rm IDCAR2019-POCR-ground-truth.zip 

In [None]:
# import all our dependencies
!pip install transformers
!pip install pyyaml==5.4.1 # reverting to this for plotly express to work; 6.0.0 is not compatbile at the moment
import numpy as np
import pandas as pd
import re
import plotly.express as px
import pickle
from transformers import BertTokenizer
from transformers import BertForTokenClassification
import torch
import nltk.data
nltk.download('punkt')
import glob
import os
from os import path
import editdistance
import random
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import transformers
from transformers import BertTokenizer, BertConfig, BertModel, AdamW, get_linear_schedule_with_warmup

from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm, trange

torch.__version__
transformers.__version__
import matplotlib.pyplot as plt
%matplotlib inline


import seaborn as sns

In [None]:
"""
Each file contains three lines: the OCR'd text, the aligned OCR'd text and the aligned GS (gold standard) text.
The aligned OCR'd and GS texts always have the same length; missing characters are marked with the '@' sign.

Folder structure:
    - training_18M_without_Finnish
        - Languages: EN/DE, etc.
            - EN1, DE1, etc.
                - 0.txt, 1.txt, etc.
                    - 3 lines (shape (3, len_each_text)), in the order below:

OCR_toInput: Raw OCR output
OCR_aligned: OCR output with '@' inserted wherever the length differs from the GS, so that both have the same length
GS_aligned: Ground truth (gold standard), aligned with OCR_aligned
"""
# load the dataset
# Inspect a single sample file (EN/EN1/0.txt) to see its three-line layout.
file = path.join(r"./training_18M_without_Finnish/EN", "EN1", "0.txt")

with open(file, "r", encoding="utf-8") as f:
    raw_text = f.readlines()

# Show the first 100 characters of every line (this is the cell's displayed output).
[s[:100] for s in raw_text]

In [None]:
# use NLTK package for sentence tokenization (english tokenizer)
# Punkt model fetched earlier by nltk.download('punkt'); extract_dataset uses its
# span_tokenize() to get (start, end) character offsets of each sentence.
# NOTE(review): the English model is applied to every language in the dataset.
sentence_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

In [None]:
"""
Split sentences into subtokens only at places where a space appears at the same position in both
the OCR and GS sentences. This avoids the situation where an incorrectly split OCR'd word
also splits the corresponding GS word.
"""

def get_common_space_ids(sentence, gs_sentence):
    """Return the positions of spaces shared by the OCR and GS sentences.

    The two sentences are aligned character-by-character, so a space found at
    the same index in both marks a word boundary that is safe to split on.

    Args:
        sentence: aligned OCR'd sentence.
        gs_sentence: aligned gold-standard sentence.

    Returns:
        Ascending list of the shared space indices, followed by
        ``len(sentence)`` as the closing boundary of the last word.
    """
    gs_spaces = {match.start() for match in re.finditer(" ", gs_sentence)}
    shared = [
        match.start()
        for match in re.finditer(" ", sentence)
        if match.start() in gs_spaces
    ]
    shared.append(len(sentence))
    return shared

In [None]:
# Quick check of get_common_space_ids on the first 100 characters of the sample file.
# NOTE: these lines still carry the 14-character structure prefix that extract_dataset strips.
sample_sentence = [s[:100] for s in raw_text]
sample_ocr = sample_sentence[1] # ocr with @
sample_gt = sample_sentence[2] # gs
get_common_space_ids(sample_ocr, sample_gt)

In [None]:
"""
Create the ground truth by traversing all files and splitting sentences into
subtokens (variable-length words). Every subtoken receives a corresponding binary label,
where 1 means that it contains a typo.

Also collect the aligned OCR'd words and GS words for further analysis
(aligned_words, aligned_gs_words).
"""
def extract_dataset(files, ocr_last_breakline_remove=False):
    """Split aligned OCR/GS files into per-sentence word lists with typo labels.

    Args:
        files: paths of dataset files; each must contain exactly 3 lines
            (raw OCR, aligned OCR, aligned GS), and the first 14 characters
            of each line are a structure tag that is discarded.
        ocr_last_breakline_remove: if True, also drop the last character of
            the aligned OCR line (presumably a trailing line break) so that
            its length matches the aligned GS line.

    Returns:
        Tuple ``(aligned_words, words, aligned_gs_words, labels)``. Each is a
        list with one entry per sentence (over all files, in order), and each
        entry is a list with one item per word:
          - aligned_words: OCR words as aligned (may contain '@'),
          - words: OCR words with '@' removed,
          - aligned_gs_words: GS text over the same character span,
          - labels: 1 if the aligned OCR word differs from the GS word, else 0.

    Raises:
        AssertionError: if a file does not have 3 lines, or the aligned
            OCR and GS texts (or sentences) differ in length.
    """

    # Global accumulators; each holds one list of words per sentence
    words = [] # OCR subtokens with '@' removed
    labels = [] # labels; 1 means the subtoken has a typo, 0 means no typo
    aligned_words = [] # OCR subtokens as aligned (may contain '@')
    aligned_gs_words = [] # GS text over the same character spans

    print(f'There are {len(files):,} files in total')

    # loop through all the files that we have
    for file in files:
        print(file)

        # open file and parse raw text
        with open(file, "r", encoding="utf-8") as f:
            raw_text = f.readlines()

        # expected layout: raw OCR, aligned OCR, aligned GS (one per line)
        assert len(raw_text) == 3

        # Omit the first 14 characters of each line, which hold the structure tag.
        # Optionally drop the last character of the aligned OCR line as well.
        if ocr_last_breakline_remove:
            aligned_ocr = raw_text[1][14:-1]
        else:
            aligned_ocr = raw_text[1][14:]
        aligned_gs = raw_text[2][14:]

        # length of the aligned texts should be the same
        assert len(aligned_ocr) == len(aligned_gs)

        # per-file accumulators; added to the global lists once the file is done
        file_aligned_words = []
        file_words = []
        file_aligned_gs_words = []
        file_labels = []

        # Sentence spans are computed on the aligned OCR text and reused for the
        # GS text; this works because both texts are aligned character-by-character
        sentence_spans = sentence_tokenizer.span_tokenize(aligned_ocr)

        # each span is a (start_index, end_index) pair
        for sentence_start, sentence_end in sentence_spans:

            # get corresponding sentences in aligned ocr and gs
            sentence = aligned_ocr[sentence_start: sentence_end]
            gs_sentence = aligned_gs[sentence_start: sentence_end]
            # these should be the same length
            assert len(gs_sentence) == len(sentence)

            # per-sentence accumulators
            sentence_aligned_words = []
            sentence_words = []
            sentence_aligned_gs_words = []
            sentence_labels = []

            # Word boundaries are the spaces present at the same index in both OCR
            # and GS (plus a final len(sentence) boundary). A "word" may therefore
            # span several OCR tokens when the GS does not split at the same place.
            new_ocr_space_ids = get_common_space_ids(sentence, gs_sentence)

            word_start = 0
            # walk from one shared boundary to the next
            for space_id in new_ocr_space_ids:
                word = sentence[word_start: space_id] # OCR text up to the boundary

                # Adjacent boundaries (a run of shared spaces) give an empty word; skip it
                if len(word) == 0:
                    word_start = space_id + 1 
                    continue

                # remove the '@' alignment symbols
                trimmed_word = word.replace("@", "")
                gs_word = gs_sentence[word_start: space_id] # GS text over the same span
                label = int(word != gs_word) # 1 if the aligned OCR word differs from the GS word

                # append word to sentence 
                sentence_labels.append(label)
                sentence_aligned_words.append(word)
                sentence_words.append(trimmed_word)
                sentence_aligned_gs_words.append(gs_word) 

                # next word starts right after this boundary's space
                word_start = space_id + 1

            # append sentences to file  
            file_aligned_words.append(sentence_aligned_words)
            file_words.append(sentence_words)
            file_aligned_gs_words.append(sentence_aligned_gs_words)
            file_labels.append(sentence_labels)

        # extend the global lists with this file's sentences
        aligned_words.extend(file_aligned_words)
        words.extend(file_words)
        aligned_gs_words.extend(file_aligned_gs_words)
        labels.extend(file_labels)

    return aligned_words, words, aligned_gs_words, labels



In [None]:
# Collect every .txt file of every language folder in the training split and
# extract per-sentence words/labels (aligned OCR line kept as is: flag False).
train_files = glob.glob(path.join(r"./training_18M_without_Finnish/*", "*", "*.txt"))
aligned_words, words, aligned_gs_words, labels = extract_dataset(train_files, False)

In [None]:
"""
After some exploration, we realize that some pieces of text are too noisy for learning:
they contain too many typos, unrecognized words or false alarms, and need to be eliminated.
To decide whether a sentence is suitable, it is enough to measure the edit distance between the OCR'd and
GS sentences. If it is very large, the sentence is bad material to learn from.
"""

# One row per sentence; each column holds that sentence's list of aligned OCR / GS words.
sent_stat = pd.DataFrame({"ocr_sentence": aligned_words, "gs_sentence": aligned_gs_words})
sent_stat.head()

In [None]:
def compute_sent_edit_distance(x):
    """Compute sentence edit distance normalized by the length of the sentence.

    Args:
        x: row-like mapping with lists of words under ``'ocr_sentence'`` and
            ``'gs_sentence'`` (a row of ``sent_stat`` when used with
            ``DataFrame.apply(..., axis=1)``).

    Returns:
        Levenshtein distance between the joined OCR and GS words divided by
        the longer of the two joined strings, so the result lies in [0, 1].
        Returns 0.0 when both joined strings are empty, instead of raising
        ZeroDivisionError.
    """
    # Words are joined without separators, so spaces are not compared.
    ocr_sent = "".join(x['ocr_sentence'])
    gs_sent = "".join(x['gs_sentence'])
    longest = max(len(ocr_sent), len(gs_sent))
    if longest == 0:
        # nothing to compare: treat an empty sentence as a perfect match
        return 0.0
    return editdistance.distance(ocr_sent, gs_sent) / longest

# Row-wise normalized edit distance for every training sentence.
sent_stat["sent_edit_distance"] = sent_stat.apply(compute_sent_edit_distance, axis=1)

In [None]:
# majority have edit distance less than 0.4
# Histogram of the per-sentence normalized edit distance.
sent_stat["sent_edit_distance"].hist()

In [None]:
# see which proportion of sentences fall into good and bad with threshold 0.4
# Sentences with distance <= threshold are kept for training; the constant is reused below.
MAXIMUM_AVERAGE_EDIT_DISTANCE_RATE = 0.4
total_sent = sent_stat.shape[0]
good_sent = (sent_stat["sent_edit_distance"] <= MAXIMUM_AVERAGE_EDIT_DISTANCE_RATE).sum()
good_sent_ratio = good_sent / total_sent
print("good sentences: %s\ntotal sentences: %s\ngood sentences ratio: %s" % (good_sent, total_sent, good_sent_ratio))

In [None]:
# some examples of good sentences
# good_sentences_stat is also used later to select the training sentences by index.
good_sentences_stat = sent_stat[sent_stat["sent_edit_distance"] <= MAXIMUM_AVERAGE_EDIT_DISTANCE_RATE]
# random window of 5 consecutive rows (randint raises ValueError if there are no good sentences)
i = random.randint(0, good_sentences_stat.shape[0] - 1)

good_sentences_stat[i: i + 5].sort_values("sent_edit_distance", ascending=False)

In [None]:
# look at bad sentences
# Sentences above the threshold, which are excluded from training.
noisy_sentences_stat = sent_stat[sent_stat["sent_edit_distance"] > MAXIMUM_AVERAGE_EDIT_DISTANCE_RATE]
# random window of 5 consecutive rows (randint raises ValueError if there are no noisy sentences)
i = random.randint(0, noisy_sentences_stat.shape[0] - 1)

noisy_sentences_stat[i: i + 5].sort_values("sent_edit_distance", ascending=False)

In [None]:
# Keep only the training sentences within the edit-distance threshold. sent_stat has a
# default RangeIndex, so its index labels are positions into `words`/`labels`; select them
# directly instead of round-tripping through an object ndarray.
good_sentence_ids = good_sentences_stat.index.tolist()
words = [words[i] for i in good_sentence_ids]
labels = [labels[i] for i in good_sentence_ids]
# pickle.dump(words, open("train_ed_filtered_words.pickle", "wb"))
# pickle.dump(labels, open("train_ed_filtered_labels.pickle", "wb"))

# do the same for the test set; for these files the last aligned-OCR character
# (presumably a trailing line break) is dropped so the aligned lengths match
test_files = glob.glob(path.join(r"./evaluation_4M_without_Finnish/*", "*", "*.txt"))
test_aligned_words, test_words, test_aligned_gs_words, test_labels = extract_dataset(test_files, True)

test_sent_stat = pd.DataFrame({"ocr_sentence": test_aligned_words, "gs_sentence": test_aligned_gs_words})
test_sent_stat["sent_edit_distance"] = test_sent_stat.apply(compute_sent_edit_distance, axis=1)
test_sent_stat["sent_edit_distance"].hist()

# The count, the ratio and the filter all use the same `<=` threshold as the training split.
test_good_mask = test_sent_stat["sent_edit_distance"] <= MAXIMUM_AVERAGE_EDIT_DISTANCE_RATE
test_good_sent = test_good_mask.sum()
test_total_sent = test_sent_stat.shape[0]
print("good sentences: %s\ntotal sentences: %s\ngood sentences ratio: %s" % (test_good_sent, test_total_sent, test_good_sent / test_total_sent))
test_good_sentences_stat = test_sent_stat[test_good_mask]

test_good_sentence_ids = test_good_sentences_stat.index.tolist()
eval_words = [test_words[i] for i in test_good_sentence_ids]
eval_labels = [test_labels[i] for i in test_good_sentence_ids]
# pickle.dump(eval_words, open("test_ed_filtered_words.pickle", "wb"))
# pickle.dump(eval_labels, open("test_ed_filtered_labels.pickle", "wb"))

In [None]:
"""
Training the Model!
Utilise pretrained multilingual BERT: the BERT output of each sub-token is fed into
convolutional layers and fully-connected layers to be classified. Model predictions
for sub-tokens are merged into token-level predictions: if at least one sub-token of a
token is predicted to be erroneous, then the token is erroneous.
"""
# check how many typos in the training dataset; approximately 50%. A balanced dataset
# which is good for training a model
# Counts are per word (one label per word) over the filtered training sentences.
pos_labels_count = sum(sum(sent_labels) for sent_labels in labels)
total_labels_count = sum(len(sent_labels) for sent_labels in labels)
print(f"typo words: {pos_labels_count}, total: {total_labels_count}, typo_rate: {pos_labels_count / total_labels_count}")


In [None]:
"""
We have 10 languages in our dataset, so the most suitable BERT tokenizer and
pretrained model is bert-base-multilingual-cased.
"""
# WordPiece tokenizer matching the BERT checkpoint loaded in CNNModel; it is also
# saved next to each model checkpoint during training.
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")

In [None]:
def tokenize_and_preserve_labels(sentence, text_labels):
    """Split each word into BERT subtokens and repeat its label for every subtoken.

    Uses the module-level ``tokenizer``. The sequence is wrapped in
    ``[CLS]``/``[SEP]``, both labelled 0.

    Args:
        sentence: list of words.
        text_labels: per-word labels, parallel to ``sentence``.

    Returns:
        ``(subtokens, subtoken_labels)``: two lists of equal length.
    """
    subtokens = ["[CLS]"]
    subtoken_labels = [0]
    for word, word_label in zip(sentence, text_labels):
        pieces = tokenizer.tokenize(word)
        subtokens += pieces
        # every piece of a word inherits the word's label
        subtoken_labels += [word_label] * len(pieces)
    subtokens.append("[SEP]")
    subtoken_labels.append(0)
    return subtokens, subtoken_labels

In [None]:
"""
Convert data into tensor format; must adjust length of sequences 
"""
def truncate_or_pad(arr, max_sequence_length, pad_value=None):
    """Truncate or right-pad ``arr`` to exactly ``max_sequence_length`` items.

    Args:
        arr: sequence to fit (token ids, per-subtoken labels, ...); lists,
            tuples and 1-D arrays are accepted.
        max_sequence_length: target length.
        pad_value: value appended as padding. Defaults to the module-level
            tokenizer's pad token id (the original behaviour). Pass an
            explicit value (e.g. 0 for the "no typo" label) when padding
            something other than token ids.

    Returns:
        A new list of length ``max_sequence_length``.
    """
    if pad_value is None:
        pad_value = tokenizer.pad_token_id
    # list() so that non-list sequences can be concatenated with the padding
    return list(arr[:max_sequence_length]) + [pad_value] * (max_sequence_length - len(arr))


In [None]:
"""
Prepare dataset by:
1. Tokenize words using word piece tokenization from BERT
2. Convert Tokens into ids
3. align sequence lengths to max sequence length
4. Create attention masks, where 0 means padding
5. Convert token ids, labels and attention masks into tensors
"""
def prepare_dataset(words, labels, max_sequence_length=100):
    """Extract inputs, tags and masks tensors from the dataset.

    Args:
        words: list of sentences, each a list of word strings.
        labels: per-word 0/1 typo labels, parallel to ``words``.
        max_sequence_length: subtokens per row after truncation/padding,
            including [CLS]/[SEP] ([SEP] is cut off when a sentence is truncated).

    Returns:
        ``(inputs, tags, masks)``: three ``torch.long`` tensors of shape
        ``(len(words), max_sequence_length)``. ``masks`` is 1 for real
        tokens and 0 for padding.
    """
    if len(words) == 0:
        # zip(*...) below cannot unpack an empty dataset; return empty tensors instead
        empty = torch.zeros((0, max_sequence_length), dtype=torch.long)
        return empty, empty.clone(), empty.clone()

    # tokenize words and labels; add CLS to beginning and SEP to end
    tokenized_texts_and_labels = [tokenize_and_preserve_labels(sent, labs) for sent, labs in zip(words, labels)]

    # separate to tokenized texts and tokenized labels
    tokenized_texts, tokenized_labels = zip(*tokenized_texts_and_labels)

    pad_id = tokenizer.pad_token_id
    # convert tokens to ids and truncate/pad with the tokenizer's pad id
    input_ids = np.array([truncate_or_pad(tokenizer.convert_tokens_to_ids(txt), max_sequence_length) for txt in tokenized_texts], dtype='long')
    # labels are padded with 0 ("no typo"); padded positions are masked out of the loss anyway
    subtoken_labels = np.array([
        list(sentence_labels[:max_sequence_length]) + [0] * (max_sequence_length - len(sentence_labels))
        for sentence_labels in tokenized_labels
    ], dtype='long')

    # attention mask: compare against the same pad id that was used for padding (vectorized)
    attention_masks = (input_ids != pad_id).astype('long')

    # convert token ids, labels, masks into tensors
    inputs = torch.tensor(input_ids, dtype=torch.long)
    tags = torch.tensor(subtoken_labels, dtype=torch.long)
    masks = torch.tensor(attention_masks, dtype=torch.long)

    return inputs, tags, masks


In [None]:
# Subtokens per row (including [CLS]/[SEP]) for both splits; longer sentences are truncated.
# The filtered evaluation split doubles as the validation set.
MAX_SEQUENCE_LENGTH = 100
tr_inputs, tr_tags, tr_masks = prepare_dataset(words, labels, MAX_SEQUENCE_LENGTH)
val_inputs, val_tags, val_masks = prepare_dataset(eval_words, eval_labels, MAX_SEQUENCE_LENGTH)

In [None]:
"""
In our model, we use the pretrained BERT embeddings model bert-base-multilingual-cased.
BERT embeddings are supplied to conv layers with 4 kernel sizes (2, 3, 4, 5), each
with 32 filters, followed by max pooling; both conv and max pooling have stride=1,
which has the effect of exchanging information within n-grams.
Finally, the outputs of all kernels are concatenated and fed to a linear layer to produce the final logits for binary
classification of every sub-token, a technique equivalent to image segmentation.
"""
class CNNModel(nn.Module):
    """BERT encoder followed by four parallel Conv1d + MaxPool1d branches and a linear head.

    Produces ``num_labels`` (= 2: no typo / typo) logits for every input position.

    Args:
        hidden_dropout_prob: dropout probability applied to the concatenated
            branch features right before the classifier.
    """

    def __init__(self, hidden_dropout_prob):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-multilingual-cased")
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.kernel_1, self.kernel_2, self.kernel_3, self.kernel_4 = 2, 3, 4, 5
        self.embedding_dim = 768  # must equal the BERT hidden size (768 for bert-base)
        self.out_size = 32        # filters per branch
        self.num_labels = 2

        # Each branch keeps the sequence length: for conv and pool alike,
        # L_out = L_in + 2 * padding - kernel_size + 1 (stride 1). Even kernels pad
        # the conv by k // 2 (L + 1) and the pool by k // 2 - 1 (back to L).
        self.conv_1 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.out_size,
                                kernel_size=self.kernel_1, stride=1, padding=1)
        self.pool_1 = nn.MaxPool1d(kernel_size=self.kernel_1, stride=1)

        self.conv_2 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.out_size,
                                kernel_size=self.kernel_2, stride=1, padding=1)
        self.pool_2 = nn.MaxPool1d(kernel_size=self.kernel_2, stride=1, padding=1)

        self.conv_3 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.out_size,
                                kernel_size=self.kernel_3, stride=1, padding=2)
        self.pool_3 = nn.MaxPool1d(kernel_size=self.kernel_3, stride=1, padding=1)

        self.conv_4 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.out_size,
                                kernel_size=self.kernel_4, stride=1, padding=2)
        self.pool_4 = nn.MaxPool1d(kernel_size=self.kernel_4, stride=1, padding=2)

        self.classifier = nn.Linear(in_features=4 * self.out_size, out_features=self.num_labels)

    def forward(self, input, attention_mask):
        """Return logits of shape (batch_size, sequence_length, num_labels)."""
        hidden_states = self.bert(input, attention_mask)[0]   # batch, seq_len, hidden
        channels_first = hidden_states.permute(0, 2, 1)       # batch, hidden, seq_len (Conv1d layout)
        branches = (
            (self.conv_1, self.pool_1),
            (self.conv_2, self.pool_2),
            (self.conv_3, self.pool_3),
            (self.conv_4, self.pool_4),
        )
        pooled = [pool(torch.relu(conv(channels_first))) for conv, pool in branches]
        features = torch.cat(pooled, 1).permute(0, 2, 1)      # batch, seq_len, 4 * out_size
        return self.classifier(self.dropout(features))

In [None]:
# activate GPU if available
# `device` is used by the training loop and by inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# prepare data loader for training & evaluation
# Each batch yields (input_ids, attention_mask, labels), in the order the training loop unpacks them.
BATCH_SIZE = 32
train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)
train_sampler = RandomSampler(train_data)  # shuffle training sentences every epoch
train_dataloader = DataLoader(train_data, pin_memory=True, num_workers=2, sampler=train_sampler, batch_size=BATCH_SIZE)

valid_data = TensorDataset(val_inputs, val_masks, val_tags)
valid_sampler = SequentialSampler(valid_data)  # fixed order for validation
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=BATCH_SIZE)

In [None]:
# initialize model and send it to the available device
# 0.1 is the dropout probability applied before the classifier.
model = CNNModel(0.1)
model.to(device)

In [None]:
"""
Use the AdamW optimizer with weight decay 0.01 for non-bias parameters.
Biases are not multiplied by the input data, so it does not make sense
to apply regularization to them.
"""
param_optimizer = list(model.named_parameters())
# AdamW reads each group's `weight_decay` entry; any other key (e.g. `weight_decay_rate`)
# is stored but ignored, which would leave every group at the optimizer's default decay.
# As documented above: no decay for bias parameters, 0.01 for all other parameters.
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if "bias" in n], 'weight_decay': 0.0},
    {'params': [p for n, p in param_optimizer if "bias" not in n], 'weight_decay': 0.01},
]

optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=3e-5,
    eps=1e-8
)

In [None]:
"""
The BERT authors suggested training for no more than 3 epochs, with a linear
decay of the learning rate; our model also always overfits by the 3rd epoch.
Gradient clipping is also used to prevent exploding gradients.
"""
epochs = 3
max_grad_norm = 1.0

# Total number of training steps is number of batches * number of epochs.
total_steps = len(train_dataloader) * epochs

# Create the linear learning rate scheduler (no warmup).
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)
# Store the average loss and accuracy after each epoch so we can plot them.
loss_values, validation_loss_values = [], []
training_acc, val_acc = [], []
best_valid_loss = float('inf')
best_valid_acc = 0

best_acc_output_dir = "cnnmodel_best_acc_all"
best_loss_output_dir = "cnnmodel_best_loss_all"

# Create output directory if needed
os.makedirs(best_acc_output_dir, exist_ok=True)
os.makedirs(best_loss_output_dir, exist_ok=True)

print("Best validation accuracy model location: %s" % best_acc_output_dir)
print("Best validation loss model location: %s" % best_loss_output_dir)
# Validation outputs per epoch: one list per epoch with one entry per sentence
# (predictions, labels and attention masks stay aligned row-for-row).
all_predictions = []
all_true_labels = []
all_masks = []

loss_fct = nn.CrossEntropyLoss()


def _masked_loss(logits, labels, mask):
    """Cross-entropy over positions with mask == 1; other positions get ignore_index."""
    active_labels = torch.where(
        mask.view(-1) == 1, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    return loss_fct(logits.view(-1, logits.size(-1)), active_labels)


def _collect(logits, labels, mask, predictions, true_labels, masks):
    """Append this batch's per-sentence predicted ids, labels and masks as CPU lists."""
    predictions.extend(np.argmax(logits.detach().cpu().numpy(), axis=2).tolist())
    true_labels.extend(labels.cpu().numpy().tolist())
    masks.extend(mask.cpu().numpy().tolist())


def _masked_accuracy(predictions, true_labels, masks):
    """Subtoken accuracy over every position whose attention mask is non-zero."""
    y_pred, y_true = [], []
    for pred_row, label_row, mask_row in zip(predictions, true_labels, masks):
        for p_i, l_i, a_i in zip(pred_row, label_row, mask_row):
            if a_i:
                y_pred.append(p_i)
                y_true.append(l_i)
    return accuracy_score(y_true, y_pred)


def _save_checkpoint(output_dir, epoch):
    """Save the model/optimizer state to `output_dir`/model, plus the tokenizer."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        }, path.join(output_dir, "model"))
    tokenizer.save_pretrained(output_dir)


for epoch in range(epochs):
    print(f"Epoch:{epoch}")
    # ========================================
    #               Training
    # ========================================
    model.train()
    total_loss = 0
    # Collected for every batch (on CPU) so accuracy covers the whole epoch,
    # and no epoch-long list of GPU tensors is kept alive.
    predictions, true_labels, masks = [], [], []
    for batch in tqdm(train_dataloader, position=0, leave=True):
        b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)

        # Always clear any previously calculated gradients before performing a backward pass.
        model.zero_grad()

        logits = model(b_input_ids, attention_mask=b_input_mask)
        loss = _masked_loss(logits, b_labels, b_input_mask)
        _collect(logits, b_labels, b_input_mask, predictions, true_labels, masks)

        loss.backward()
        total_loss += loss.item()

        # Clip the gradient norm to help prevent exploding gradients.
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)

        optimizer.step()
        scheduler.step()

    # Calculate the average loss over the training data.
    avg_train_loss = total_loss / len(train_dataloader)
    print("\nTraining loss: {}".format(avg_train_loss))

    acc = _masked_accuracy(predictions, true_labels, masks)
    print("Training Accuracy: {}".format(acc))
    training_acc.append(acc)

    # Store the loss value for plotting the learning curve.
    loss_values.append(avg_train_loss)

    # ========================================
    #               Validation
    # ========================================
    model.eval()
    eval_loss = 0
    predictions, true_labels, masks = [], [], []
    for batch in valid_dataloader:
        b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)

        with torch.no_grad():
            logits = model(b_input_ids, b_input_mask)
            eval_loss += _masked_loss(logits, b_labels, b_input_mask).item()
        _collect(logits, b_labels, b_input_mask, predictions, true_labels, masks)

    eval_loss = eval_loss / len(valid_dataloader)
    validation_loss_values.append(eval_loss)
    print("Validation loss: {}".format(eval_loss))

    acc = _masked_accuracy(predictions, true_labels, masks)
    all_predictions.append(predictions)
    all_true_labels.append(true_labels)
    all_masks.append(masks)

    print("Validation Accuracy: {}".format(acc))
    val_acc.append(acc)

    if eval_loss < best_valid_loss:
        best_valid_loss = eval_loss
        _save_checkpoint(best_loss_output_dir, epoch)

    if acc > best_valid_acc:
        best_valid_acc = acc
        _save_checkpoint(best_acc_output_dir, epoch)

In [None]:
# Use plot styling from seaborn.
sns.set(style='darkgrid')

# Increase the plot size and font size.
sns.set(font_scale=1.5)
plt.rcParams["figure.figsize"] = (12,6)

# Plot the learning curve: one point per epoch from the training loop's loss lists.
plt.plot(loss_values, 'b-o', label="training loss")
plt.plot(validation_loss_values, 'r-o', label="validation loss")

# Label the plot.
plt.title("Learning curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

plt.show()

In [None]:
"""
Test Inference
"""
# Load the best-validation-loss checkpoint and the tokenizer saved with it.
output_dir = "cnnmodel_best_loss_all"
tokenizer1 = BertTokenizer.from_pretrained(output_dir)

model1 = CNNModel(0.1)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU
# can also be loaded on a CPU-only machine.
checkpoint = torch.load(path.join(output_dir, "model"), map_location=device)
model1.load_state_dict(checkpoint['model_state_dict'])
# Move the model to the selected device (GPU if available) and disable dropout.
model1.to(device)
model1.eval()

In [None]:
# Sample OCR'd text to check; words like "impreillon" and "ryes" look like OCR errors.
test_sentence = """
Eroxena, Does my soul make an impreillon upon my ryes ?
"""

In [None]:
def inference(test_sentence, max_sequence_length=100):
    """Print a typo label for every word-level token of ``test_sentence``.

    Uses the module-level ``tokenizer1``, ``model1`` and ``device``.

    Args:
        test_sentence: raw text to check.
        max_sequence_length: length the encoded input is truncated/padded to
            (100, the training length, by default).

    Prints one ``<label>\\t<token>`` line per token, including [CLS]/[SEP]
    (1 = predicted typo, 0 = no typo). WordPiece pieces (``##...``) are merged
    into the previous token, which is labelled a typo if any of its pieces is.
    """
    # Encode with special tokens; truncate so long inputs cannot exceed the padded length.
    tokenized_sentence = tokenizer1.encode(test_sentence)[:max_sequence_length]
    n_tokens = len(tokenized_sentence)
    n_pad = max_sequence_length - n_tokens
    # Integer mask (1 = real token, 0 = padding), like the masks used in training.
    attention_mask = [1] * n_tokens + [0] * n_pad
    padded_ids = tokenized_sentence + [tokenizer1.pad_token_id] * n_pad

    # Build tensors directly on the selected device (works on CPU as well as GPU).
    input_ids = torch.tensor([padded_ids], dtype=torch.long, device=device)
    attention_mask = torch.tensor([attention_mask], dtype=torch.long, device=device)

    # get model output
    with torch.no_grad():
        logits = model1(input_ids, attention_mask)
    label_indices = np.argmax(logits.cpu().numpy(), axis=2)

    # Only the real (non-padding) positions are converted and reported.
    tokens = tokenizer1.convert_ids_to_tokens(tokenized_sentence)

    # join WordPiece continuation pieces back into words
    new_tokens, new_labels = [], []
    for token, label_idx in zip(tokens, label_indices[0][:n_tokens]):
        if token.startswith("##") and new_tokens:
            new_tokens[-1] = new_tokens[-1] + token[2:]  # drop the '##' marker
            new_labels[-1] = max(new_labels[-1], label_idx)  # typo if any piece is a typo
        else:
            new_labels.append(label_idx)
            new_tokens.append(token)

    for token, label in zip(new_tokens, new_labels):
        print("{}\t{}".format(label, token))

# Print the per-token typo predictions for the sample sentence.
inference(test_sentence)