# Text Denoising

Inspired by "Neural Networks for Text Correction and Completion in Keyboard Decoding" by Shaona Ghosh and Per Ola Kristensson. https://arxiv.org/pdf/1709.06429.pdf

In [1]:
from collections import defaultdict
import json
import os
import random
import string

In [2]:
from gluoncv.data.batchify import Tuple, Stack, Append, Pad
import gluonnlp as nlp
import mxboard
import mxnet as mx
from mxnet import gluon, autograd
import numpy as np
import re
from tqdm import tqdm

In [4]:
from ocr.utils.encoder_decoder import get_transformer_encoder_decoder, Denoiser, encode_char, decode_char, LabelSmoothing, SoftmaxCEMaskedLoss

In [5]:
# Run on the GPU when MXNet reports at least one, otherwise on the CPU.
if mx.context.num_gpus() > 0:
    ctx = mx.gpu()
else:
    ctx = mx.cpu()
ctx

gpu(0)

## Data

In [6]:
# Create dataset/typo (and its parent) in one call. `exist_ok=True` keeps the
# cell re-runnable and avoids the check-then-create race of isdir + makedirs.
os.makedirs('dataset/typo', exist_ok=True)

In [7]:
# Path of the training corpus, and the download that places it there.
# The download call is the cell's last expression, so the notebook displays
# the path it returns.
text_filepath = 'dataset/typo/all.txt'
mx.test_utils.download('https://s3.us-east-2.amazonaws.com/gluon-ocr/models/all.txt', dirname='dataset/typo')

'dataset/typo/all.txt'

In [9]:
# Character vocabulary: the four special tokens come first, followed by the
# space, ASCII letters, digits and punctuation.
ALPHABET = ['<UNK>', '<PAD>', '<BOS>', '<EOS>'] + [' ', *string.ascii_letters, *string.digits, *string.punctuation]
# Reverse lookup: {'<UNK>': 0, '<PAD>': 1, '<BOS>': 2, '<EOS>': 3, ' ': 4, 'a': 5, ...}
ALPHABET_INDEX = dict(zip(ALPHABET, range(len(ALPHABET))))

# Ids of the special tokens; they equal the tokens' positions in ALPHABET.
UNK, PAD, BOS, EOS = 0, 1, 2, 3

FEATURE_LEN = 150  # max-length in characters for one document
MAX_LEN_SENTENCE = 150
BATCH_SIZE = 64  # number of documents per batch
NUM_WORKERS = 8  # number of workers used in the data loading

# Moses tokenizer / detokenizer pair. NoisyTextDataset uses them to split a
# line into tokens before adding noise and to join the tokens back afterwards.
moses_detokenizer = nlp.data.SacreMosesDetokenizer()
moses_tokenizer = nlp.data.SacreMosesTokenizer()

### Generic Dataset

In [11]:
class NoisyTextDataset(mx.gluon.data.Dataset):
    """Dataset of `(noisy_line, clean_line)` string pairs for training a text denoiser.

    Every clean line is corrupted on the fly with random character
    substitutions, insertions and deletions; neighbouring tokens are also
    occasionally glued together.

    Parameters
    ----------
    text_filepath : str, optional
        Text file with one document per line; read when `data_type == 'corpus'`.
    substitute_costs_filepath : str
        JSON file loaded into `self.substitute_dict`.
        NOTE(review): presumably maps letter -> {replacement: cumulative
        probability}; confirm against the file.
    insert_weight, delete_weight, substitute_weight : number
        Relative frequencies of the three kinds of character edit.
    glue_prob : float
        Probability of merging a token with the one that follows it.
    max_replace : float
        Upper bound on the fraction of a line's characters that get edited.
    is_train : bool
        Keep the first `split` fraction of the lines (True) or the rest (False).
    split : float
        Train/test split point as a fraction of the number of lines.
    data_type : str
        'corpus' (lines read from `text_filepath`) or 'GBW' (lines built from
        `gbw_corpus`).
    gbw_corpus : indexable, optional
        Corpus of token sequences used when `data_type == 'GBW'`.
    """

    def __init__(self,
                 text_filepath=None,
                 substitute_costs_filepath='models/substitute_probs.json',
                 insert_weight=1,
                 delete_weight=1,
                 glue_prob=0.05,
                 substitute_weight=2,
                 max_replace=0.3,
                 is_train=True,
                 split=0.9,
                 data_type='corpus',
                 gbw_corpus=None
                ):
        self.max_replace = max_replace
        self.replace_weight = 0 #replace_prob  #Ignore typo dataset
        # Cumulative thresholds on [0, 1): a uniform draw below the first one is
        # a substitution, below the second an insertion, otherwise a deletion.
        total_weight = float(insert_weight + delete_weight + substitute_weight)
        self.substitute_threshold = substitute_weight / total_weight
        self.insert_threshold = self.substitute_threshold + insert_weight / total_weight
        self.delete_threshold = self.insert_threshold + delete_weight / total_weight
        self.glue_prob = glue_prob
        # Context manager so the JSON file handle is closed right after loading.
        with open(substitute_costs_filepath, 'r') as f:
            self.substitute_dict = json.load(f)
        self.split = split
        self.data_type = data_type
        if self.data_type == 'corpus':
            self.text = self._process_text(text_filepath, is_train)
        elif self.data_type == 'GBW':
            self.gbw_corpus = gbw_corpus

    def _process_text(self, filename, is_train):
        """Read `filename` and return the train or test portion of its non-blank lines."""
        text = []
        with open(filename, 'r', encoding='Latin-1') as f:
            for line in f:
                stripped = line.strip()
                # Lines read from a file keep their newline, so blank lines
                # have to be detected after stripping.
                if stripped:
                    text.append(stripped)

        split_index = int(self.split*len(text))
        if is_train:
            return text[:split_index]
        return text[split_index:]

    def _transform_line(self, line):
        """
        Return a noisy copy of `line`.

        Up to `self.max_replace` of the characters are picked at random and
        each is substituted, preceded by a random inserted character, or
        deleted. Consecutive tokens are then glued together with probability
        `self.glue_prob`.
        """
        processed_line = self._pre_process_line(line)

        # Sets give O(1) membership tests and exist even for an empty line.
        substitute_letters = set()
        insert_letters = set()
        delete_letters = set()

        # We get randomly the index of the modifications
        num_chars = len(''.join(processed_line))
        if num_chars:
            num_modifications = random.randint(0, int(self.max_replace*num_chars))
            index_modifications = np.random.choice(num_chars, num_modifications, replace=False)
            # We randomly assign these indices to modifications based on precalculated thresholds
            for index in index_modifications:
                index = int(index)
                draw = random.random()
                if draw < self.substitute_threshold:
                    substitute_letters.add(index)
                elif draw < self.insert_threshold:
                    insert_letters.add(index)
                else:
                    delete_letters.add(index)

        output = []
        j = 0  # index of the current word's first character in the joined original tokens
        for word in processed_line:
            if word != '' and word not in string.punctuation:
                noisy_word = []
                # A single pass over the ORIGINAL letters keeps every drawn index
                # aligned with the character it was drawn for, even after
                # earlier insertions or deletions changed the word length.
                for k, letter in enumerate(word, start=j):
                    if k in substitute_letters and letter in self.substitute_dict:
                        draw = random.random()
                        for replace, prob in self.substitute_dict[letter].items():
                            if draw < prob:
                                letter = replace
                                break
                    if k in insert_letters:
                        # Insert a random non-special character before this letter
                        noisy_word.append(ALPHABET[random.randint(4, len(ALPHABET)-1)])
                    if k not in delete_letters:
                        noisy_word.append(letter)
                output.append(''.join(noisy_word))
            else:
                output.append(word)
            # Advance by the original word length: indices refer to the clean text.
            j += len(word)

        # Glue neighbouring tokens: a token stays in the current slot unless the
        # draw exceeds `glue_prob`, in which case the next token starts a new slot.
        output_ = [""]*len(output)
        j = 0
        for word in output:
            output_[j] += word
            if random.random() > self.glue_prob:
                j += 1

        line = self._post_process_line(output_)
        return line.strip()

    def _pre_process_line(self, line):
        """Normalise newlines, backticks and double dashes, then tokenize with Moses."""
        line = line.replace('\n','').replace('`',"'").replace('--',' -- ')
        return moses_tokenizer(line)

    def _post_process_line(self, words):
        """Detokenize `words` with Moses and join the result with single spaces."""
        output = ' '.join(moses_detokenizer(words))
        return output

    def _match_caps(self, original, typo):
        """Return `typo` with the capitalisation pattern (upper / title) of `original`."""
        if original.isupper():
            return typo.upper()
        elif original.istitle():
            return typo.capitalize()
        else:
            return typo

    def __getitem__(self, idx):
        """Return `(noisy_line, clean_line)` for sample `idx`."""
        if self.data_type == 'GBW':
            # The last token is dropped before detokenizing.
            # NOTE(review): presumably the '<EOS>' marker added by the stream; confirm.
            tokens = moses_detokenizer(self.gbw_corpus[idx][:-1])
            if len(tokens) > 6:
                # Keep a random contiguous span of the sentence.
                # NOTE(review): `end` may equal `start`, which yields an empty line.
                start = random.randint(0, len(tokens)-3)
                end = random.randint(start, len(tokens))
                tokens = tokens[start:end]
            line = ' '.join(tokens)
        else:
            line = self.text[idx]
        line_typo = self._transform_line(line)
        return line_typo, line

    def __len__(self):
        """Number of samples in the underlying corpus."""
        if self.data_type == 'GBW':
            return len(self.gbw_corpus)
        else:
            return len(self.text)

In [14]:
def encode_char(text, src=True):
    """Encode `text` as a fixed-length array of character ids.

    The text is clipped to FEATURE_LEN-2 characters. Target sequences
    (`src=False`) start with BOS; every sequence ends with EOS and is padded
    with PAD up to FEATURE_LEN. Characters missing from ALPHABET_INDEX keep
    the PAD value at their position.

    Returns the float32 id array and a one-element float32 array holding the
    valid length (EOS included).
    """
    encoded = np.full(FEATURE_LEN, PAD, dtype='float32')
    clipped = text[:FEATURE_LEN-2]

    start = 0
    if not src:
        encoded[0] = BOS
        start = 1

    for offset, letter in enumerate(clipped):
        char_id = ALPHABET_INDEX.get(letter)
        if char_id is not None:
            encoded[start + offset] = char_id

    end = start + len(clipped)
    encoded[end] = EOS
    return encoded, np.array([end+1]).astype('float32')

def encode_word(text, src=True):
    """Encode `text` as word ids wrapped in `<BOS>` ... `<EOS>`.

    Returns the list of ids and a one-element float32 array with its length.
    `src` is accepted but not used.

    NOTE(review): relies on module-level `tokenizer` and `vocab`, neither of
    which is defined in the visible part of this notebook -- confirm they
    exist before calling.
    """
    tokens = tokenizer(text)
    indices = vocab[tokens]
    indices += [vocab['<EOS>']]
    indices = [vocab['<BOS>']]+indices
    return indices, np.array([len(indices)]).astype('float32')

def transform(data, label):
    """Encode a `(noisy, clean)` string pair for the network.

    Returns `(src, src_valid_length, tgt, tgt_valid_length, data, label)`;
    the raw strings are passed through unchanged.
    """
    source = encode_char(data, src=True)
    target = encode_char(label, src=False)
    return source[0], source[1], target[0], target[1], data, label

def decode_char(text):
    """Turn a sequence of character ids back into a string.

    Decoding stops at the first EOS; PAD and BOS ids are skipped.
    """
    chars = []
    for val in text:
        if val == EOS:
            break
        if val != PAD and val != BOS:
            chars.append(ALPHABET[int(val)])
    return "".join(chars)


def decode_word(indices):
    """Map word ids back to tokens, detokenize them and drop '<PAD>' markers.

    NOTE(review): relies on module-level `detokenizer` and `vocab`, neither of
    which is defined in the visible part of this notebook -- confirm they
    exist before calling.
    """
    return detokenizer([vocab.idx_to_token[int(i)] for i in indices], return_str=True).replace('<PAD>','')

In [92]:
# Train / test portions of the downloaded corpus; tokens are glued with probability 0.2.
corpus_kwargs = dict(text_filepath=text_filepath, glue_prob=0.2)
dataset_train = NoisyTextDataset(is_train=True, **corpus_kwargs).transform(transform)
dataset_test = NoisyTextDataset(is_train=False, **corpus_kwargs).transform(transform)

# Finetuning on the text from the IAM dataset
dataset_train_ft = NoisyTextDataset(
    text_filepath='dataset/typo/text_train.txt',
    is_train=True,
    split=1.0,
).transform(transform)

In [17]:
# Display one random training sample:
# (src, src_valid_length, tgt, tgt_valid_length, noisy text, clean text)
dataset_train[random.randint(0, len(dataset_train)-1)]

(array([57., 58., 59., 82., 57., 57., 61.,  4., 31., 18.,  8.,  4., 24.,
        12.,  9.,  4., 42., 45., 48., 34.,  4., 23., 20.,  5., 15.,  9.,
         4., 23., 25.,  8.,  8.,  9., 18., 16., 11.,  4., 25.,  7., 24.,
        19.,  4., 63., 43., 19., 23.,  9., 23., 78.,  4.,  5., 18.,  8.,
         4., 25., 18., 24., 19.,  4., 31.,  5., 22., 19., 18., 78.,  4.,
         5., 18.,  8.,  3.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.], dtype=float32),
 array([69.], dtype=float32),
 array([ 2., 57., 58., 59., 82., 57., 57., 61.,  4., 31., 18.,  8.,  4.,
        24., 12.,  9.,  4., 42., 4

### Validation data: predictions on the IAM dataset

In [18]:
# Load the (label, modified) pairs; the context manager closes the file handle.
with open('dataset/typo/validating.json', 'r') as f:
    data = json.load(f)

# Keep only the pairs where the modified text actually differs from the label.
data_ = [[label, modified] for label, modified in data if label.strip() != modified.strip()]

# The modified text is the network input, the label is the target.
val_dataset_ft = gluon.data.ArrayDataset(
    [modified for _, modified in data_],
    [label for label, _ in data_],
).transform(transform)

In [19]:
# Display one random validation sample (same tuple layout as `transform` returns).
val_dataset_ft[random.randint(0, len(val_dataset_ft)-1)]

(array([17., 19., 22.,  9.,  4., 24., 12.,  5., 24.,  4., 24., 27.,  9.,
        18., 24., 29., 79., 16., 19., 25., 22.,  4., 12., 19., 25., 23.,
         4.,  9.,  5., 22., 16., 13.,  9., 22., 80.,  4., 49., 13., 18.,
         7.,  9., 80.,  4., 24., 12.,  9., 18.,  3.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.], dtype=float32),
 array([48.], dtype=float32),
 array([ 2., 17., 19., 22.,  9.,  4., 24., 12.,  5., 18.,  4., 24., 27.,
         9., 18., 24., 29., 79., 1

### Training on GBW

In [20]:
# Training segment of the GBW corpus stream: empty samples skipped, no BOS token,
# '<EOS>' used as the end-of-sentence token.
gbw_stream = nlp.data.GBWStream(segment='train', skip_empty=True, bos=None, eos='<EOS>')

In [21]:
# Build a dataset from the first corpus yielded by the stream only (note the
# `break`), so a sample can be inspected without iterating the whole stream.
for e, corpus in enumerate(gbw_stream):
    dataset_gbw = NoisyTextDataset(gbw_corpus=corpus, data_type='GBW').transform(transform)
    break

In [22]:
# Display one GBW sample (same tuple layout as `transform` returns).
dataset_gbw[6]

(array([20., 22., 19., 23., 20.,  9.,  7., 24., 23.,  4., 24., 12., 22.,
        19., 25., 11., 12.,  4., 59., 57., 58., 57., 78.,  4., 27., 13.,
        24., 12.,  4.,  5.,  4., 23., 13., 11., 18., 13., 10., 13.,  7.,
         5., 18., 24.,  4.,  9.,  7., 19., 18., 19., 17., 13.,  7.,  3.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.], dtype=float32),
 array([52.], dtype=float32),
 array([ 2., 20., 22., 19., 23., 20.,  9.,  7., 24., 23.,  4., 24., 12.,
        22., 19., 25., 11., 12.,  

#### DataLoaders

In [23]:
def batchify_list(elem):
    """Collect the per-sample values of one field into a plain list.

    Used for fields (such as raw strings) that should be batched as they are.
    Returns a new list with one entry per sample, in the original order.
    """
    # Each sample contributes itself exactly once; the previous loop appended
    # the whole batch once per sample.
    return list(elem)
    
# One batchify function per field of a sample, in the order returned by `transform`:
# (src, src_valid_length, tgt, tgt_valid_length, noisy text, clean text).
# `batchify_word` differs only in using Pad() for the third field.
batchify = Tuple(Stack(), Stack(), Stack(), Stack(), batchify_list, batchify_list)
batchify_word = Tuple(Stack(), Stack(), Pad(), Stack(), batchify_list, batchify_list)

In [None]:
# Settings shared by every loader below.
loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=True, batchify_fn=batchify)

train_data = gluon.data.DataLoader(dataset_train, last_batch='rollover', num_workers=5, **loader_kwargs)
test_data = gluon.data.DataLoader(dataset_test, last_batch='keep', num_workers=5, **loader_kwargs)
val_data_ft = gluon.data.DataLoader(val_dataset_ft, last_batch='keep', num_workers=0, **loader_kwargs)
train_data_ft = gluon.data.DataLoader(dataset_train_ft, last_batch='rollover', num_workers=5, **loader_kwargs)

## Helper functions for training

In [82]:
def evaluate(net, iterator):
    """Return the mean per-batch loss of `net` over `iterator`.

    Also prints the noisy input, the beam-search prediction and the expected
    text for the first sample of the last batch.

    Raises
    ------
    ValueError
        If `iterator` yields no batch (there would be nothing to average or print).
    """
    total_loss = 0
    num_batches = 0
    for src, src_valid_length, tgt, tgt_valid_length, typo, label in iterator:
        src = src.as_in_context(ctx)
        tgt = tgt.as_in_context(ctx)
        src_valid_length = src_valid_length.as_in_context(ctx).squeeze()
        tgt_valid_length = tgt_valid_length.as_in_context(ctx).squeeze()
        # Teacher forcing: the decoder reads tgt without its last position and
        # is scored against tgt shifted left by one.
        output = net(src, tgt[:,:-1], src_valid_length, tgt_valid_length-1)
        # NOTE(review): the loss is given tgt_valid_length although the shifted
        # target holds one valid token fewer -- confirm against SoftmaxCEMaskedLoss.
        ls = loss_function_test(output, tgt[:,1:], tgt_valid_length).mean()
        total_loss += ls.asscalar()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate() received an empty iterator")

    print("[Test Typo     ] {}".format(decode_char(src[0].asnumpy())))
    print("[Test Predicted] {}".format(get_sentence(net, decode_char(src[0].asnumpy()))))
    print("[Test Correct  ] {}".format(decode_char(tgt[0].asnumpy())))
    return total_loss / num_batches

In [94]:
def run_epoch(net, epoch, train_iterator, test_iterator):
    """Train `net` for one pass over `train_iterator` and return the test loss.

    Every `send_every_n` iterations the network is evaluated on
    `test_iterator`, and running losses plus one decoded sample are logged to
    mxboard (`sw`) and stdout. Uses the module-level `trainer`,
    `loss_function`, `label_smoothing`, `sw`, `key` and `ctx`.

    Parameters
    ----------
    net : network to train.
    epoch : int
        Index of the current epoch; used for the logging step counters.
    train_iterator, test_iterator : iterables of batches laid out as
        (src, src_valid_length, tgt, tgt_valid_length, typo, label).
    """
    loss = 0.
    num_batches = len(train_iterator)
    for i, (src, src_valid_length, tgt, tgt_valid_length, typo, label) in enumerate(train_iterator):
        src = src.as_in_context(ctx)
        tgt = tgt.as_in_context(ctx)
        src_valid_length = src_valid_length.as_in_context(ctx).squeeze()
        tgt_valid_length = tgt_valid_length.as_in_context(ctx).squeeze()

        with autograd.record():
            # Teacher forcing: feed tgt without its last position, score against tgt shifted by one.
            output = net(src, tgt[:,:-1], src_valid_length, tgt_valid_length-1)
            smoothed_label = label_smoothing(tgt[:,1:])
            ls = loss_function(output, smoothed_label, tgt_valid_length).mean()

        ls.backward()
        trainer.step(src.shape[0])
        loss += ls.asscalar()

        if i % send_every_n == 0:
            val_loss = evaluate(net, test_iterator)
            # Global step derived from the `epoch` argument rather than a notebook global.
            global_step = i + epoch*num_batches
            sw.add_scalar(tag='Val_Loss_it', value={key:val_loss}, global_step=global_step)
            sw.add_scalar(tag='Train_Loss_it', value={key:loss/(i+1)}, global_step=global_step)
            print("[Iteration {} Train] {}".format(i, loss / (i+1)))
            print("[Iteration {} Test ] {}".format(i, val_loss))
            print("[Train Typo        ] {}".format(decode_char(src[0].asnumpy())))
            print("[Train Predicted   ] {}".format(decode_char(output[0].asnumpy().argmax(axis=1))))
            print("[Train Correct     ] {}".format(decode_char(tgt[0].asnumpy())))
            print()
            sw.flush()

    # Evaluate the network that was actually trained (the `net` argument).
    test_loss = evaluate(net, test_iterator)
    print("Epoch [{}], Train Loss {:.4f}, Test Loss {:.4f}".format(epoch, loss/(i+1), test_loss))
    sw.add_scalar(tag='Train_Loss', value={key:loss/(i+1)}, global_step=epoch)
    sw.add_scalar(tag='Test_Loss', value={key:test_loss}, global_step=epoch)
    print()
    return test_loss

## Network

In [44]:
# Network size
num_heads = 16
embed_size = 256
num_layers = 2

# Optimisation
epochs = 1
learning_rate = 0.00004
send_every_n = 50  # validate and log every N training iterations

# Bookkeeping: name used for log tags and the checkpoint file, and the best
# test loss seen so far (starts at a value any real loss will beat).
key = 'language_denoising'
best_test_loss = 10e20

In [42]:
log_dir = './logs/text_denoising'
checkpoint_dir = "model_checkpoint"
checkpoint_name = key+".params"
# Make sure the checkpoint directory exists so that the parameter saves done
# after each training epoch do not fail on a fresh checkout.
os.makedirs(checkpoint_dir, exist_ok=True)
# mxboard writer used by the training helpers to log losses.
sw = mxboard.SummaryWriter(logdir=log_dir, flush_secs=1)

In [26]:
# Character-level denoising network: vocabulary of len(ALPHABET) ids, source and
# target sequences of at most FEATURE_LEN positions. Weights use Xavier
# initialisation on `ctx`.
net = Denoiser(alphabet_size=len(ALPHABET), max_src_length=FEATURE_LEN, max_tgt_length=FEATURE_LEN, num_heads=num_heads, embed_size=embed_size, num_layers=num_layers)
net.initialize(mx.init.Xavier(), ctx)

In [27]:
# Resume from an existing checkpoint when one is present.
checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
if os.path.isfile(checkpoint_path):
    net.load_parameters(checkpoint_path, ctx=ctx)

Preparing the loss

In [28]:
# One output unit per entry of the character vocabulary.
output_dim = len(ALPHABET)

In [29]:
# Training uses smoothed (dense) labels, hence sparse_label=False; evaluation
# scores against the raw id targets, hence sparse_label=True.
label_smoothing = LabelSmoothing(epsilon=0.002, units=output_dim)
loss_function_test = SoftmaxCEMaskedLoss(sparse_label=True)
loss_function = SoftmaxCEMaskedLoss(sparse_label=False)

In [46]:
# Adam optimiser over all parameters of the network.
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate':learning_rate})

## Training the network

Training on the public novel dataset

In [None]:
# Train on the corpus loaders; keep the parameters of the best epoch so far.
for e in range(epochs):
    test_loss = run_epoch(net, e, train_data, val_data_ft)
    if test_loss < best_test_loss:
        print("Saving network, previous best test loss {:.6f}, current test loss {:.6f}".format(best_test_loss, test_loss))
        # Save the network that was trained (`net`); `denoiser` is not defined in this notebook.
        net.save_parameters(os.path.join(checkpoint_dir, checkpoint_name))
        best_test_loss = test_loss

Training on the GBW dataset

In [None]:
# One training pass per corpus yielded by the GBW stream; keep the best parameters.
for e, corpus in enumerate(gbw_stream):
    dataset_gbw = NoisyTextDataset(gbw_corpus=corpus, data_type='GBW').transform(transform)
    train_data_gbw = gluon.data.DataLoader(dataset_gbw, batch_size=BATCH_SIZE, shuffle=True, last_batch='discard', batchify_fn=batchify, num_workers=5)
    # `run_epoch` is the training helper defined in this notebook (`train_epoch` does not exist).
    test_loss = run_epoch(net, e, train_data_gbw, val_data_ft)
    if test_loss < best_test_loss:
        print("Saving network, previous best test loss {:.6f}, current test loss {:.6f}".format(best_test_loss, test_loss))
        # Save the network that was trained (`net`); `denoiser` is not defined in this notebook.
        net.save_parameters(os.path.join(checkpoint_dir, checkpoint_name))
        best_test_loss = test_loss

Fine-tuning on the IAM dataset text

In [None]:
# Fine-tune on the IAM text loader; keep the parameters of the best epoch so far.
for e in range(epochs):
    # `run_epoch` is the training helper defined in this notebook (`train_epoch` does not exist).
    test_loss = run_epoch(net, e, train_data_ft, val_data_ft)
    if test_loss < best_test_loss:
        print("Saving network, previous best test loss {:.6f}, current test loss {:.6f}".format(best_test_loss, test_loss))
        # Save the network that was trained (`net`); `denoiser` is not defined in this notebook.
        net.save_parameters(os.path.join(checkpoint_dir, checkpoint_name))
        best_test_loss = test_loss

## Manual Testing

In [79]:
def get_sentence(net, sentence):
    """Denoise `sentence` with `net` using beam search and return the decoded string.

    The sentence is encoded as a source sequence, run through the encoder,
    and decoded with a beam of size 5 starting from BOS.
    """
    scorer = nlp.model.BeamSearchScorer(alpha=0, K=2, from_logits=False)
    beam_sampler = nlp.model.BeamSearchSampler(beam_size=5,
                                           decoder=net.decode_logprob,
                                           eos_id=EOS,
                                           scorer=scorer,
                                           max_length=150)
    # Encode the input as a batch of one source sequence.
    src_seq, src_valid_length = encode_char(sentence)
    src_seq = mx.nd.array([src_seq], ctx=ctx)
    src_valid_length = mx.nd.array(src_valid_length, ctx=ctx)
    encoder_outputs, _ = net.encode(src_seq, valid_length=src_valid_length)
    states = net.decoder.init_state_from_encoder(encoder_outputs,
                                                 encoder_valid_length=src_valid_length)
    # Decoding starts from a single BOS token.
    inputs = mx.nd.full(shape=(1,), ctx=src_seq.context, dtype=np.float32, val=BOS)
    # Scores and valid lengths are not needed: decode_char stops at EOS by itself.
    samples, _, _ = beam_sampler(inputs, states)
    # First batch entry, first beam.
    # NOTE(review): presumably the highest-scoring beam -- confirm in the sampler docs.
    best_beam = samples[0].asnumpy()[0]
    return decode_char(best_beam)

In [None]:
# Test input for the denoiser; the misspelling ("eror") is presumably intentional.
sentence = "This sentence contains an eror"

In [None]:
# Run the trained network on the test sentence and display the decoded result.
get_sentence(net, sentence)

## Appendix (maybe useful later)

### Create text file with all vocab words

In [None]:
# Load the pretrained GBW language model together with its vocabulary.
model, vocab = nlp.model.big_rnn_lm_2048_512(dataset_name='gbw', pretrained=True, ctx=mx.cpu())
# Write one token per line. A separate name is used for the text so that
# `vocab` stays the vocabulary object that encode_word / decode_word index.
vocab_text = '\n'.join(vocab.idx_to_token)
with open('dataset/typo/vocab.txt', 'w') as f:
    f.write(vocab_text)