# Text Denoising

Inspired by "Neural Networks for Text Correction and Completion in Keyboard Decoding" by Shaona Ghosh and Per Ola Kristensson. https://arxiv.org/pdf/1709.06429.pdf

In [1]:
from collections import defaultdict
import json
import os
import random
import string

In [2]:
from gluoncv.data.batchify import Tuple, Stack, Append, Pad
import gluonnlp as nlp
import mxboard
import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon import HybridBlock
from mxnet.gluon.loss import SoftmaxCELoss
import numpy as np
import re
from tqdm import tqdm

In [3]:
from utils.encoder_decoder import get_transformer_encoder_decoder, Denoiser, encode_char, decode_char

In [4]:
ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()
ctx

gpu(0)

## Data

In [5]:
if not os.path.isdir('dataset'):
    os.makedirs('dataset')

In [6]:
typo_filepath = 'dataset/typo/typo-corpus-r1.txt'
text_filepath = 'dataset/typo/alicewonder.txt'
mx.test_utils.download('http://luululu.com/tweet/typo-corpus-r1.txt', dirname='dataset/typo')
mx.test_utils.download('http://textfiles.com/etext/FICTION/alicewonder.txt', dirname='dataset/typo')
# This needs to be made available somewhere
text_filepath = 'dataset/typo/all.txt'

In [7]:
sw = mxboard.SummaryWriter(logdir='logs', flush_secs=1)

In [8]:
ALPHABET = ['<UNK>', '<PAD>', '<BOS>', '<EOS>']+list(' ' + string.ascii_letters + string.digits + string.punctuation)
ALPHABET_INDEX = {letter: index for index, letter in enumerate(ALPHABET)} # { a: 0, b: 1, etc}
FEATURE_LEN = 150 # max-length in characters for one document
NUM_WORKERS = 8 # number of workers used in the data loading
BATCH_SIZE = 64 # number of documents per batch
MAX_LEN_SENTENCE = 150
PAD = 1
BOS = 2
EOS = 3
UNK = 0

#### Build the vocabulary for the target

In [9]:
text_lines = [l.replace('\n','').replace('`','"').replace('--',' -- ').strip() for l in open(text_filepath, 'r', encoding='Latin-1').readlines() if l != "''"]

In [10]:
use_words = False

In [11]:
#!python -m spacy download en

In [12]:
full_text = ' '.join(text_lines)
tokenizer = nlp.data.transforms.SpacyTokenizer()
tokens = tokenizer(full_text[:999999])
counter = nlp.data.Counter(tokens)

In [13]:
vocab = nlp.Vocab(counter, unknown_token='<UNK>', padding_token='<PAD>',
                  bos_token='<BOS>', eos_token='<EOS>', min_freq=1)

In [14]:
if use_words:
    import nltk
    nltk.download('perluniprops')
    nltk.download('nonbreaking_prefixes')
    glove_embed = nlp.embedding.create('glove', source='glove.6B.50d')
    
    tokenizer = nlp.data.transforms.NLTKMosesTokenizer()
    tokens = tokenizer(' '.join(text_lines))
    counter = nlp.data.Counter(tokens)
    vocab = nlp.Vocab(counter, unknown_token='<UNK>', padding_token='<PAD>',
                      bos_token='<BOS>', eos_token='<EOS>', min_freq=1)
    vocab.set_embedding(glove_embed)

In [15]:
class NoisyTextDataset(mx.gluon.data.Dataset):
    def __init__(self, text_filepath, typo_filepath, substitute_costs_filepath='models/substitute_probs.json', 
                 replace_weight=0, insert_weight=1, delete_weight=1, glue_prob=0.05, substitute_weight=2,
                 max_replace=0.3,
                 is_train=True, split=0.9,
                ):
        self.max_replace = max_replace
        self.replace_weight = 0 #replace_prob  #Ignore typo dataset
        self.substitute_threshold = float(substitute_weight) / (insert_weight + delete_weight + substitute_weight)
        self.insert_threshold = self.substitute_threshold + float(insert_weight) / (insert_weight + delete_weight + substitute_weight)
        self.delete_threshold = self.insert_threshold + float(delete_weight) / (insert_weight + delete_weight + substitute_weight)
        self.glue_prob = glue_prob
        self.typo_dict = self._process_typo(typo_filepath)
        self.substitute_dict = json.load(open(substitute_costs_filepath,'r'))
        self.split = split
        self.text = self._process_text(text_filepath, is_train)

        
    def _process_text(self, filename, is_train):
        with open(filename, 'r', encoding='Latin-1') as f:
            text = []
            for line in f.readlines():
                if line != '':
                    text.append(line.strip())
            
            split_index = int(self.split*len(text))
            if is_train:
                text = text[:split_index]
            else:
                text = text[split_index:]
        return text

    def _process_typo(self, filename):
        """
        This function loads the typo dataset and generate the 
        probability distribution of typos for each valid word
        """
        typo_dict = defaultdict(lambda : defaultdict(float))
        with open(filename, 'r') as f:
            lines = f.readlines()
            for line in lines:
                typo, correct = line.split('\t')[0:2]
                typo_dict[correct][typo] += 1
        for _, correct_word in typo_dict.items():
            total = 0
            for _, count in correct_word.items():
                total += count
            previous_value = 0.
            for wrong_word in correct_word:
                correct_word[wrong_word] = correct_word[wrong_word] / total + previous_value
                previous_value = correct_word[wrong_word]
        return typo_dict
    
    def _transform_line(self, line):
        """
        replace words that are in the typo dataset with a typo
        with a probability `self.replace_proba`
        """
        output = []
        
        processed_line = self._pre_process_line(line)
        
        # We get randomly the index of the modifications
        num_chars = len(''.join(processed_line))
        if num_chars:
            index_modifications = np.random.choice(num_chars, random.randint(0, int(self.max_replace*num_chars)), replace=False)
            substitute_letters = []
            insert_letters = []
            delete_letters = []
            # We randomly assign these indices to modifications based on precalculated thresholds
            for index in index_modifications:
                draw = random.random()
                if draw < self.substitute_threshold:
                    substitute_letters.append(index)
                    continue
                if draw < self.insert_threshold:
                    insert_letters.append(index)
                    continue
                else:
                    delete_letters.append(index)
                            
        
        j = 0
        for i, word in enumerate(processed_line):
            if word != '' and word not in string.punctuation:
                len_word = len(word)
                ###########################
                #          IGNORED        #
                ###########################
                #if word.lower() in self.typo_dict: 
                #    # Replace word with a typo based on probability distribution
                #    if random.random() < self.replace_prob:
                #        draw = random.random()
                #        for typo, value in self.typo_dict[word].items():
                #            if draw < value:
                #                word = self._match_caps(word, typo)
                #                break
                    
                # Replace letter with substitute based on probability distribution
                word_ = []
                k = j
                for letter in word:
                    if k in substitute_letters and letter in self.substitute_dict:
                        draw = random.random()
                        for replace, prob in self.substitute_dict[letter].items():
                            if draw < prob:
                                letter = replace
                                break
                    word_.append(letter)
                    k += 1
                word = ''.join(word_)
                                
                # Insert random letter
                k = j
                word_ = []
                for letter in word:
                    if k in insert_letters:
                        word_.append(ALPHABET[random.randint(4, len(ALPHABET)-1)])
                    word_.append(letter)
                    k += 1
                word = ''.join(word_)
                
                # Delete random letter
                k = j
                word_ = []
                for letter in word:
                    if k not in delete_letters:
                        word_.append(letter)
                    k += 1
                word = ''.join(word_)
                    
                output.append(word)
            else:
                output.append(word)
            j += len(word)

        output_ = [""]*len(output)
        j = 0
        for i, word in enumerate(output):
            output_[j] += word
            if random.random() > self.glue_prob:
                j += 1
        
        line = self._post_process_line(output_)
        return line.strip()
    
    def _pre_process_line(self, line):
        line = line.replace('\n','').replace('`',"'").replace('--',' -- ')
        for char in string.punctuation:
            if char in line:
                line = line.replace(char, ' '+char+' ')
        return line.split(' ')
        
    def _post_process_line(self, words):
        output = ' '.join(words)
        for char in string.punctuation:
            output = output.replace(' '+char+' ', char)
        return output
    
    def _match_caps(self, original, typo):
        if original.isupper():
            return typo.upper()
        elif original.istitle():
            return typo.capitalize()
        else:
            return typo
    
    def __getitem__(self, idx):
        line = self.text[idx]
        line_typo = self._transform_line(line)
        return line_typo, line

    def __len__(self):
        return len(self.text)

In [16]:
def encode_char(text, src=True):
    encoded = np.ones(FEATURE_LEN, dtype='float32') * PAD
    text = text[:FEATURE_LEN-2]
    i = 0
    if not src:
        encoded[0] = BOS
        i = 1
    for letter in text:
        if letter in ALPHABET_INDEX:
            encoded[i] = ALPHABET_INDEX[letter]
        i += 1
    encoded[i] = EOS
    return encoded, np.array([i+1]).astype('float32')

def encode_word(text, src=True):
    tokens = tokenizer(text)
    indices = vocab[tokens]
    indices += [vocab['<EOS>']]
    indices = [vocab['<BOS>']]+indices
    return indices, np.array([len(indices)]).astype('float32')

def transform(data, label):
    src, src_valid_length = encode_char(data, src=True)
    tgt, tgt_valid_length = encode_char(label, src=False)
    return src, src_valid_length, tgt, tgt_valid_length, data, label

In [17]:
dataset_train = NoisyTextDataset(text_filepath=text_filepath, typo_filepath=typo_filepath, glue_prob=0.2, is_train=True).transform(transform)
dataset_test = NoisyTextDataset(text_filepath=text_filepath, typo_filepath=typo_filepath, glue_prob=0.2, is_train=False).transform(transform)

In [18]:
dataset_train[random.randint(0, len(dataset_train)-1)]

(array([27., 19., 18., 73., 24.,  4.,  5., 15.,  9.,  4., 24., 12.,  9.,
        23.,  9.,  4., 24., 12., 13., 18., 11., 23., 78.,  4., 23., 19.,
         4., 12.,  9., 24., 11.,  4., 17.,  9., 78.,  5., 73., 16., 16.,
         4.,  8., 16., 16.,  4., 12.,  9., 22.,  3.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.], dtype=float32),
 array([48.], dtype=float32),
 array([ 2., 27., 19., 18., 73., 24.,  4., 24.,  5., 15.,  9.,  4., 24.,
        12.,  9., 23.,  9.,  4., 2

#### Validation data being the IAM Dataset prediction

In [19]:
data = json.load(open('dataset/typo/finetuning.json','r'))
data_ = []
for label, modified in data:
    if label.strip() != modified.strip():
        data_.append([label, modified])
val_dataset_ft = gluon.data.ArrayDataset(list(list(zip(*data_))[1]), list(list(zip(*data_))[0])).transform(transform)

In [24]:
val_dataset_ft[random.randint(0, len(val_dataset_ft)-1)]

(array([24., 12.,  5., 18.,  4.,  5.,  4.,  7., 19., 25., 20., 16.,  9.,
         4., 19., 10.,  4., 24., 16.,  9., 20., 12., 19., 18.,  9.,  4.,
         7.,  5., 16., 16., 23.,  4., 12.,  9.,  4., 27.,  5., 23.,  3.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
         1.,  1.,  1.,  1.,  1.,  1.,  1.], dtype=float32),
 array([39.], dtype=float32),
 array([ 2., 24., 12.,  5., 18.,  4.,  5.,  4.,  7., 19., 25., 20., 16.,
         9.,  4., 19., 10.,  4., 2

#### Helper functions

In [25]:
def batchify_list(elem):
    output = []
    for e in elem:
        output.append(elem)
    return output
    
batchify = Tuple(Stack(), Stack(), Stack(), Stack(), batchify_list, batchify_list)
batchify_word = Tuple(Stack(), Stack(), Pad(), Stack(), batchify_list, batchify_list)

In [26]:
def decode_char(text):
    output = []
    for val in text:
        if val == EOS:
            break
        elif val == PAD or val == BOS:
            continue
        output.append(ALPHABET[int(val)])
    return "".join(output)


detokenizer = nlp.data.NLTKMosesDetokenizer()
def decode_word(indices):
    return detokenizer([vocab.idx_to_token[int(i)] for i in indices], return_str=True).replace('<PAD>','')

In [27]:
train_data = gluon.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, last_batch='rollover', batchify_fn=batchify, num_workers=5)
test_data = gluon.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False, last_batch='rollover', batchify_fn=batchify, num_workers=5)
val_data_ft = gluon.data.DataLoader(val_dataset_ft, batch_size=BATCH_SIZE, shuffle=True, last_batch='keep', batchify_fn=batchify, num_workers=0)

## Network

In [28]:
class SoftmaxCEMaskedLoss(SoftmaxCELoss):
    """Wrapper of the SoftmaxCELoss that supports valid_length as the input
    """
    def hybrid_forward(self, F, pred, label, valid_length): # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        F
        pred : Symbol or NDArray
            Shape (batch_size, length, V)
        label : Symbol or NDArray
            Shape (batch_size, length)
        valid_length : Symbol or NDArray
            Shape (batch_size, )
        Returns
        -------
        loss : Symbol or NDArray
            Shape (batch_size,)
        """
        if self._sparse_label:
            sample_weight = F.cast(F.expand_dims(F.ones_like(label), axis=-1), dtype=np.float32)
        else:
            sample_weight = F.ones_like(label)
        sample_weight = F.SequenceMask(sample_weight,
                                       sequence_length=valid_length,
                                       use_sequence_length=True,
                                       axis=1)
        return super(SoftmaxCEMaskedLoss, self).hybrid_forward(F, pred, label, sample_weight)

# pylint: disable=unused-argument
class _SmoothingWithDim(mx.operator.CustomOp):
    def __init__(self, epsilon=0.1, axis=-1):
        super(_SmoothingWithDim, self).__init__(True)
        self._epsilon = epsilon
        self._axis = axis

    def forward(self, is_train, req, in_data, out_data, aux):
        inputs = in_data[0]
        outputs = ((1 - self._epsilon) * inputs) + (self._epsilon / float(inputs.shape[self._axis]))
        self.assign(out_data[0], req[0], outputs)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        self.assign(in_grad[0], req[0], (1 - self._epsilon) * out_grad[0])


@mx.operator.register('_smoothing_with_dim')
class _SmoothingWithDimProp(mx.operator.CustomOpProp):
    def __init__(self, epsilon=0.1, axis=-1):
        super(_SmoothingWithDimProp, self).__init__(True)
        self._epsilon = float(epsilon)
        self._axis = int(axis)

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        output_shape = data_shape
        return (data_shape,), (output_shape,), ()

    def declare_backward_dependency(self, out_grad, in_data, out_data):
        return out_grad

    def create_operator(self, ctx, in_shapes, in_dtypes):
        #  create and return the CustomOp class.
        return _SmoothingWithDim(self._epsilon, self._axis)
# pylint: enable=unused-argument


class LabelSmoothing(HybridBlock):
    """Applies label smoothing. See https://arxiv.org/abs/1512.00567.
    Parameters
    ----------
    axis : int, default -1
        The axis to smooth.
    epsilon : float, default 0.1
        The epsilon parameter in label smoothing
    sparse_label : bool, default True
        Whether input is an integer array instead of one hot array.
    units : int or None
        Vocabulary size. If units is not given, it will be inferred from the input.
    prefix : str, default 'rnn_'
        Prefix for name of `Block`s
        (and name of weight if params is `None`).
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
    """
    def __init__(self, axis=-1, epsilon=0.1, units=None,
                 sparse_label=True, prefix=None, params=None):
        super(LabelSmoothing, self).__init__(prefix=prefix, params=params)
        self._axis = axis
        self._epsilon = epsilon
        self._sparse_label = sparse_label
        self._units = units

    def hybrid_forward(self, F, inputs, units=None): # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        F
        inputs : Symbol or NDArray
            Shape (batch_size, length) or (batch_size, length, V)
        units : int or None
        Returns
        -------
        smoothed_label : Symbol or NDArray
            Shape (batch_size, length, V)
        """
        if self._sparse_label:
            assert units is not None or self._units is not None, \
                'units needs to be given in function call or ' \
                'instance initialization when sparse_label is False'
            if units is None:
                units = self._units
            inputs = F.one_hot(inputs, depth=units)
        if units is None and self._units is None:
            return F.Custom(inputs, epsilon=self._epsilon, axis=self._axis,
                            op_type='_smoothing_with_dim')
        else:
            if units is None:
                units = self._units
            return ((1 - self._epsilon) * inputs) + (self._epsilon / units)

In [29]:
denoiser = Denoiser(alphabet_size=len(ALPHABET), max_src_length=FEATURE_LEN, max_tgt_length=FEATURE_LEN, num_heads=16, embed_size=256, num_layers=2)

In [30]:
denoiser.load_parameters('model_checkpoint/denoiser_highhead.params', ctx=ctx)

In [None]:
denoiser.initialize(mx.init.Xavier(), ctx)
#denoiser.tgt_embedding[0].params.reset_ctx(ctx)

In [31]:
output_dim = len(ALPHABET)
#output_dim = len(vocab)

In [32]:
label_smoothing = LabelSmoothing(epsilon=0.01, units=output_dim)
loss_function_test = SoftmaxCEMaskedLoss(sparse_label=True)
loss_function = SoftmaxCEMaskedLoss(sparse_label=False)
trainer = gluon.Trainer(denoiser.collect_params(), 'adam', {'learning_rate':0.0001})

In [33]:
def evaluate(denoiser, iterator):
    loss = 0
    for i, (src, src_valid_length, tgt, tgt_valid_length, typo, label) in enumerate(iterator):
        src = src.as_in_context(ctx)
        tgt = tgt.as_in_context(ctx)
        src_valid_length = src_valid_length.as_in_context(ctx).squeeze()
        tgt_valid_length = tgt_valid_length.as_in_context(ctx).squeeze()
        output = denoiser(src, tgt[:,:-1], src_valid_length, tgt_valid_length-1)
        ls = loss_function_test(output, tgt[:,1:], tgt_valid_length).mean()
        loss += ls.asscalar()
    print("[Test Typo     ] {}".format(decode_char(src[0].asnumpy())))
    print("[Test Predicted] {}".format(decode_char(output[0].asnumpy().argmax(axis=1))))
    print("[Test Correct  ] {}".format(decode_char(tgt[0].asnumpy())))
    return loss / (i+1)

#### Finetuning text being the IAM Dataset train text

In [34]:
dataset_train_ft = NoisyTextDataset(text_filepath='dataset/typo/text_train.txt', typo_filepath=typo_filepath, is_train=True, split=1.0).transform(transform)
train_data_ft = gluon.data.DataLoader(dataset_train_ft, batch_size=BATCH_SIZE, shuffle=True, last_batch='rollover', batchify_fn=batchify, num_workers=5)

In [35]:
evaluate(denoiser, val_data_ft)

[Test Typo     ] rwes. Iom JIan Bawley. Does that mean anything
[Test Predicted] ras. I m J n Bawley. Does that mean anything
[Test Correct  ] was. I'm Ian Bawley. Does that mean anything


0.10538775101304054

In [36]:
def train_epoch(denoiser, epoch, train_iterator, test_iterator):
    loss = 0.
    for i, (src, src_valid_length, tgt, tgt_valid_length, typo, label) in enumerate(train_iterator):
        src = src.as_in_context(ctx)
        tgt = tgt.as_in_context(ctx)
        src_valid_length = src_valid_length.as_in_context(ctx).squeeze()
        tgt_valid_length = tgt_valid_length.as_in_context(ctx).squeeze()
        with autograd.record():
            output = denoiser(src, tgt[:,:-1], src_valid_length, tgt_valid_length-1)
            smoothed_label = label_smoothing(tgt[:,1:])
            ls = loss_function(output, smoothed_label, tgt_valid_length).mean()
        ls.backward()
        trainer.step(src.shape[0])
        loss += ls.asscalar()
        
        if i % 300 == 0:
            val_loss = evaluate(denoiser, test_iterator)
            sw.add_scalar(tag='Val_Loss_it', value={key:val_loss}, global_step=i+e*len(train_iterator))
            sw.add_scalar(tag='Train_Loss_it', value={key:loss/(i+1)}, global_step=i+e*len(train_iterator))
            print("[Iteration {}   ] {}".format(i, loss / (i+1)))
            print("[Train Typo     ] {}".format(decode_char(src[0].asnumpy())))
            print("[Train Predicted] {}".format(decode_char(output[0].asnumpy().argmax(axis=1))))
            print("[Train Correct  ] {}".format(decode_char(tgt[0].asnumpy())))
            print()
            sw.flush()

    test_loss = evaluate(denoiser, test_iterator)
    print("Epoch [{}], Train Loss {:.4f}, Test Loss {:.4f}".format(e, loss/(i+1), test_loss))
    sw.add_scalar(tag='Train_Loss', value={key:loss/(i+1)}, global_step=e)
    sw.add_scalar(tag='Test_Loss', value={key:test_loss}, global_step=e)
    print()
    return test_loss

## Training the network

In [39]:
epochs = 10
key = 'big_dataset_high_head_0.0001'
best_test_loss = 10e20

In [None]:
for e in range(epochs):
    test_loss = train_epoch(denoiser, e, train_data, val_data_ft)
    if test_loss < best_test_loss:
        denoiser.save_parameters('model_checkpoint/denoiser_highhead_2.params')
        best_test_loss = test_loss

Fine-tuning

In [None]:
epochs = 1
key = 'fine_tuning'
for e in range(epochs):
    test_loss = train_epoch(denoiser, e, train_data_ft, val_data_ft)
    if test_loss < best_test_loss:
        denoiser.save_parameters('model_checkpoint/denoiser_ft.params')
        best_test_loss = test_loss

In [None]:
ideas:
    - per word loss
    - tokenize better the words

## Manual Testing

In [94]:
scorer = nlp.model.BeamSearchScorer(alpha=0, K=5, from_logits=False)

In [95]:
eos_id = EOS
beam_sampler = nlp.model.BeamSearchSampler(beam_size=30,
                                           decoder=denoiser.decode_logprob,
                                           eos_id=eos_id,
                                           scorer=scorer,
                                           max_length=150)

In [96]:
def scorer(data, scores, step):
    
    return 1

In [97]:
def generate_sequences(sampler, inputs, begin_states):
    samples, scores, valid_lengths = sampler(inputs, begin_states)
    samples = samples[0].asnumpy()
    scores = scores[0].asnumpy()
    valid_lengths = valid_lengths[0].asnumpy()
    print('Generation Result:')
    for sample in samples:
        print(decode_char(sample))

In [None]:
sentence = "of fact I'd ashed him last night to depurise"

In [131]:
src_seq, src_valid_length = encode_char(sentence)
src_seq = mx.nd.array([src_seq], ctx=ctx)
src_valid_length = mx.nd.array(src_valid_length, ctx=ctx)
encoder_outputs, _ = denoiser.encode(src_seq, valid_length=src_valid_length)
states = denoiser.decoder.init_state_from_encoder(encoder_outputs, 
                                                  encoder_valid_length=src_valid_length)
inputs = mx.nd.full(shape=(1,), ctx=src_seq.context, dtype=np.float32, val=BOS)
generate_sequences(beam_sampler, inputs, states,)

Generation Result:
of fact I'd asked him last night to deparite
of fact I'd asked him last night to deparise
of fact I'd asked him last night to deperite
of fact I'd asked him last night to deprivise
of fact I'd asked him last night to deprise
of fact I'd asked him last night to deposite
of fact I'd asked him last night to depurite
of fact I'd asked him last night to depurise
of fact I'd asked him last night to depirite
of fact I'd asked him last night to be purise
of fact I'd asked him last night to deprince
of fact I'd asked him last night to deparing
of fact I'd asked him last night to deprite
of fact I'd asked him last night to depinise
of fact I'd asked him last night to dephrise
of fact I'd asked him last night to deperise
of fact I'd asked him last night to depinite
of fact I'd asked him last night to deprive
of fact I'd asked him last night to depenise
of fact I'd asked him last night to dephrite
of fact I'd asked him last night to depenite
of fact I'd asked him last night to d