In [113]:
import logging
import requests
from pprint import pprint
from time import time
from gector.gec_model import GecBERTModel

import spacy
from utils.helpers import *

# English spaCy pipeline used only to split sentences into tokens.
# "en" is a spaCy v2 shortcut link; spaCy v3 removed shortcut links and raises
# OSError (E941) for it, so fall back to the full package name in that case.
# NOTE(review): assumes en_core_web_sm is installed when running on spaCy v3.
try:
    nlp = spacy.load("en")
except OSError:
    nlp = spacy.load("en_core_web_sm")

In [114]:
# Build a single (non-ensembled) RoBERTa-based GECToR corrector from one
# checkpoint plus its output-label vocabulary directory.
# NOTE(review): the exact meaning of the two 0.1 thresholds is defined inside
# GecBERTModel — confirm there before tuning.
model = GecBERTModel(
    model_name="roberta",
    model_paths=["./pretrain/roberta_1_gector.th"],
    is_ensemble=False,
    vocab_path="./data/output_vocabulary",
    min_error_probability=0.1,
    min_probability=0.1,
)

In [3]:
import os

# Directory holding the parallel source/target files (by filename, presumably
# the BEA-2019 W&I+LOCNESS ABCN dev split). Set GECTOR_DATASET_DIR to point at
# a local copy instead of editing this machine-specific absolute default.
DATASET_DIR = os.environ.get('GECTOR_DATASET_DIR', '/home/citao/github/gector/dataset')
source_file = os.path.join(DATASET_DIR, 'wil.ABCN.dev.gold.bea19.0.source')
target_file = os.path.join(DATASET_DIR, 'wil.ABCN.dev.gold.bea19.0.target')

In [111]:
def _tokenize(text):
    """Return the token strings of ``text`` produced by the module-level spaCy pipeline ``nlp``."""
    return [token.text for token in nlp(text)]


def compare(source, target):
    """Run the GEC model on ``source`` and print a per-iteration correction trace.

    ``source`` and ``target`` are tokenized with the module-level spaCy pipeline
    ``nlp``. The source tokens are sent as a one-sentence batch to
    ``model.handle_batch(..., debug=True)``. For every correction iteration,
    the function prints:

    - the sentence before the iteration;
    - the sentence-level error probability;
    - a table of token / predicted edit label / probability;
    - the sentence after the iteration.

    It then prints the model's final output next to the tokenized ``target``
    so the two can be compared.

    Args:
        source: sentence to be corrected.
        target: reference correction of ``source``.

    Returns:
        None; everything is written to stdout.
    """
    source_tokens = _tokenize(source)
    target_tokens = _tokenize(target)

    # One-sentence batch. With debug=True handle_batch returns the extra
    # per-iteration traces unpacked here; the trailing count is not used.
    batch = [source_tokens]
    preds, probs, idxs_batch, inter_pred_batch, error_probs, _cnt = model.handle_batch(
        batch, config={}, debug=True)

    # Round for display. Nesting is probs[sentence][iteration][position] and
    # error_probs[sentence][iteration]. New lists are built so the model's
    # returned structures are not overwritten in place.
    probs = [[[round(p, 5) for p in iter_probs] for iter_probs in sent_probs]
             for sent_probs in probs]
    error_probs = [[round(p, 5) for p in sent_error_probs]
                   for sent_error_probs in error_probs]

    target_text = ' '.join(target_tokens)

    for src_tokens, pred_tokens, sent_error_probs, sent_idxs, sent_probs, sent_iter_preds in zip(
            batch, preds, error_probs, idxs_batch, probs, inter_pred_batch):
        # states[0] is the input sentence; states[k] is the sentence after
        # iteration k. Each is prefixed with '__START__' so positions line up
        # with the per-position edit labels.
        states = [['__START__'] + tokens for tokens in [src_tokens] + sent_iter_preds]
        for it, (err_prob, label_idxs, label_probs, _iter_pred) in enumerate(
                zip(sent_error_probs, sent_idxs, sent_probs, sent_iter_preds)):
            labels = [model.vocab.get_token_from_index(label_idx, namespace='labels')
                      for label_idx in label_idxs]
            print('<Iteration {}>'.format(it + 1))
            print('\n#### Before ####\n{}\n'.format(' '.join(states[it][1:])))
            print('[Sentence Error Probability] ', err_prob)
            print('[Correction Prediction]')
            print('-' * 70)
            for token, label, proba in zip(states[it], labels, label_probs):
                print(token.ljust(30), label.ljust(30), proba)
            print('-' * 70)
            print('\n#### After #####\n{}\n'.format(' '.join(states[it + 1][1:])))
        # Final model output next to the gold reference.
        print('#### Prediction ####\n{}\n'.format(' '.join(pred_tokens)))
        print('#### Target ####\n{}\n'.format(target_text))
        print('=' * 70)

In [112]:
# Worked example: an article/capitalisation error plus a missing comma before "but".
compare(
    source="The rich people will buy a car but the poor people always need to use a bus or taxi .",
    target="Rich people will buy a car , but poor people always need to use a bus or taxi .",
)


<Iteration 1>

#### Before ####
The rich people will buy a car but the poor people always need to use a bus or taxi .

[Sentence Error Probability]  0.69337
[Correction Predtion]
----------------------------------------------------------------------
__START__                      $KEEP                          0.99857
The                            $KEEP                          0.77089
rich                           $TRANSFORM_CASE_CAPITAL        0.51208
people                         $KEEP                          0.99815
will                           $KEEP                          0.77467
buy                            $KEEP                          0.99937
a                              $KEEP                          0.99871
car                            $APPEND_,                      0.74655
but                            $KEEP                          0.99894
the                            $KEEP                          0.7024
poor                           $KEEP               

20