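A simple Chinese text-correction example on Google Colab: an ELECTRA discriminator flags tokens that look wrong, and a BERT masked language model proposes replacements for them.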

In [None]:
# Mount Google Drive; the project directory used by the next cell lives on it.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [1]:
%tensorflow_version 2.x

import os
import warnings
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras

# Work from the project folder on the mounted Drive so `mymodels` and the
# relative 'models/...' paths used later resolve. An absolute path keeps this
# cell re-runnable: a relative './drive/...' fails once cwd has already moved.
PROJECT_DIR = '/content/drive/My Drive/Python/Research/bert'
os.chdir(PROJECT_DIR)
# Silence library warnings for a cleaner demo output.
# NOTE(review): this hides *all* warnings, including TF deprecation notices.
warnings.filterwarnings('ignore')

import mymodels as mm

In [2]:
class Corrector(keras.layers.Layer):
  """ELECTRA discriminator used as a per-token error detector.

  The project's ``mm.BERT`` encoder is followed by two Dense layers named after
  the checkpoint's ``discriminator_predictions`` scope; the second one maps each
  position to a single sigmoid score (higher = more likely a wrong token).

  Args:
    config: path to the model JSON config, forwarded to ``mm.BERT``.
    model: model-type string forwarded to ``mm.BERT`` (``'electra'`` below).
    dim: width of the first head layer; presumably must equal the encoder's
      hidden size (768 for the base checkpoint used here) -- TODO confirm.
  """

  def __init__(self, config, model, dim):
    super(Corrector, self).__init__()
    self.bert = mm.BERT(config, model)
    # Names match the checkpoint scopes so `loading` can find the head weights.
    self.name1 = 'discriminator_predictions/dense'
    self.name2 = 'discriminator_predictions/dense_1'
    self.dense1 = keras.layers.Dense(dim, activation=mm.gelu_activating, name=self.name1)
    self.dense2 = keras.layers.Dense(1, activation='sigmoid', name=self.name2)

  def loading(self, ckpt):
    """Load encoder and head weights from the TF checkpoint path ``ckpt``.

    The encoder loads itself; the head's kernels/biases are then read from the
    checkpoint by variable name and assigned in one batch.
    """
    self.bert.loading(ckpt)
    # One dummy forward pass so the Dense layers build (create) their variables.
    _ = self.propagating(tf.ones((2, 2)), tf.zeros((2, 2)), tf.zeros((2, 2)), False)
    # Take the head variables from the head layers themselves rather than
    # `self.weights[-4:]`: Keras lists non-trainable weights last, so any
    # non-trainable encoder variable would silently be picked up instead.
    head = self.dense1.weights + self.dense2.weights
    # Variable names carry a ':<index>' suffix that checkpoint keys do not.
    names = [w.name.rsplit(':', 1)[0] for w in head]
    values = [tf.train.load_variable(ckpt, n) for n in names]
    keras.backend.batch_set_value(list(zip(head, values)))

  def propagating(self, text, segment, mask, training=False):
    """Return per-token detector scores (sigmoid outputs, last dimension 1).

    Args:
      text, segment, mask: token ids, segment ids and input mask, forwarded
        to ``self.bert.propagating``.
      training: forwarded to ``self.bert.propagating``.
    """
    # NOTE(review): the positional False is an mm.BERT option whose meaning is
    # not visible here -- confirm.
    x1 = self.bert.propagating(text, segment, mask, False, training)
    return self.dense2(self.dense1(x1))


# Shared tokenizer (ELECTRA Chinese vocab), the ELECTRA discriminator that
# detects suspicious tokens, and a BERT masked LM that proposes replacements.
ELECTRA_DIR = 'models/electra_base_ch/'
BERT_DIR = 'models/bert_base_ch/'

tokenizer_1 = mm.Tokenizer()
tokenizer_1.loading(ELECTRA_DIR + 'vocab.txt')

# dim=768: presumably the ELECTRA-base hidden size -- TODO read from the config.
cor_1 = Corrector(ELECTRA_DIR + 'electra_config.json', 'electra', dim=768)
cor_1.loading(ELECTRA_DIR + 'electra_base')

mlm_1 = mm.MLM(BERT_DIR + 'bert_config.json', 'bert')
mlm_1.loading(BERT_DIR + 'bert_model.ckpt')

In [3]:
class TextCorrector(object):
  """Detect-then-correct pipeline for (mainly Chinese) sentences.

  A detector (``corrector``) scores every token; tokens scoring above
  ``threshold`` are replaced by ``[MASK]`` and refilled with the masked
  language model's (``mlm``) top prediction.

  Args:
    mlm: exposes ``propagating(text, segment, mask)``; assumed to return
      logits shaped (1, seq_len, vocab_size) for a batch of one.
    corrector: exposes ``propagating(text, segment, mask)`` returning one
      score per position for a batch of one.
    tokenizer: exposes ``vocab`` (token -> id mapping whose key order follows
      the ids) and ``encoding(text_a, text_b, maxlen)``, presumably returning
      ``(token_ids, segment_ids, input_mask)``.
    maxlen: sequence length passed to ``tokenizer.encoding``.
    threshold: detector score above which a token is flagged (default 0.5).
  """

  def __init__(self, mlm, corrector, tokenizer, maxlen, threshold=0.5):
    self.mlm = mlm
    self.cor = corrector
    self.tokenizer = tokenizer
    # id -> token lookup; relies on the vocab mapping preserving id order.
    self.vocab = list(self.tokenizer.vocab.keys())
    self.maxlen = maxlen
    self.threshold = threshold
    # Special-token ids come from the vocab instead of hard-coded numbers.
    self.sep_id = self.vocab.index('[SEP]')
    self.mask_id = self.vocab.index('[MASK]')

  def _predict_ids(self, ids, segment, mask):
    """Run the MLM on one id sequence and return the argmax id per position."""
    logits = self.mlm.propagating(np.array([ids]), np.array([segment]), np.array([mask]))
    return np.argmax(logits, axis=-1)[0]

  def checking(self, sentence, byone):
    """Score each token of ``sentence`` and propose replacements for flagged ones.

    Masks are placed directly in the id sequence (no detokenize/re-tokenize
    round-trip), so positions always stay aligned with the original tokens.

    Args:
      sentence: input text.
      byone: if True, flagged tokens are masked and refilled one at a time,
        left to right, each prediction seeing the earlier corrections; if
        False, all flagged tokens are masked together and filled in one pass.

    Returns:
      ``(scores, tokens, corrected)``: per-token detector scores, the original
      tokens, and the tokens after correction (unflagged tokens unchanged).
      [CLS], [SEP] and padding are excluded.
    """
    ids, segment, mask = self.tokenizer.encoding(sentence, None, self.maxlen)
    ids = list(ids)
    raw = self.cor.propagating(np.array([ids]), np.array([segment]), np.array([mask]))
    # Flatten the batch-of-one output to one score per sequence position.
    position_scores = np.asarray(raw).reshape(-1)

    # Real tokens sit between [CLS] (position 0) and the first [SEP]; list
    # index i corresponds to sequence position i + 1.
    token_ids = ids[1:ids.index(self.sep_id)]
    scores = [position_scores[i + 1] for i in range(len(token_ids))]
    tokens = [self.vocab[t] for t in token_ids]
    flagged = [i for i, s in enumerate(scores) if s > self.threshold]

    corrected = list(tokens)
    if not flagged:
      return scores, tokens, corrected

    if not byone:
      masked = list(ids)
      for i in flagged:
        masked[i + 1] = self.mask_id
      preds = self._predict_ids(masked, segment, mask)
      for i in flagged:
        corrected[i] = self.vocab[preds[i + 1]]
      return scores, tokens, corrected

    working = list(ids)
    for i in flagged:
      working[i + 1] = self.mask_id
      preds = self._predict_ids(working, segment, mask)
      # Feed the prediction back so later positions see this correction.
      working[i + 1] = int(preds[i + 1])
      corrected[i] = self.vocab[working[i + 1]]
    return scores, tokens, corrected

  def correcting(self, sentence, byone=False):
    """Print ``sentence`` with flagged tokens in brackets, then the corrected text.

    Args:
      sentence: input text.
      byone: forwarded to ``checking``.
    """
    scores, tokens, corrected = self.checking(sentence, byone)

    def _mark(seq):
      # Bracket every token whose detector score exceeds the threshold.
      return ''.join('[' + tok + ']' if s > self.threshold else tok
                     for s, tok in zip(scores, seq))

    print(_mark(tokens))
    print(_mark(corrected))

In [4]:
# Demo: a sentence with deliberate wrong characters, corrected one flagged
# token at a time so each fill-in sees the earlier corrections.
sample_1 = '今天天气真差，阳光明魅，风和日立，天朗气青，非常适合外出履行。'
corrector_1 = TextCorrector(mlm_1, cor_1, tokenizer_1, maxlen=64)
corrector_1.correcting(sample_1, byone=True)

今天天气真[差]，阳光明[魅]，风和日[立]，天朗气[青]，非常适合外出[履]行。
今天天气真[好]，阳光明[媚]，风和日[丽]，天朗气[清]，非常适合外出[旅]行。
