In [1]:
import requests
import math

API_URL = "http://api.netspeak.org/netspeak3/search?query=%s"

class NetSpeak:
    """Small client for the Netspeak search endpoint configured in ``API_URL``.

    Every query passed to :meth:`search` is cached in ``self.dictionary`` so
    repeated look-ups of the same string do not hit the network again.
    """

    # Seconds to wait for an HTTP response before ``requests`` raises.
    REQUEST_TIMEOUT = 30

    def __init__(self):
        self.headers = {'User-Agent': 'Mozilla/5.0 (compatible; MSIE 5.5; Windows NT)'}
        self.page = None
        # Cache: raw query string -> list of (text, float) result tuples.
        self.dictionary = {}

    def __getPageContent(self, url):
        """Return the response body of a GET request to ``url`` as text."""
        return requests.get(url, headers=self.headers,
                            timeout=self.REQUEST_TIMEOUT).text

    def __rolling(self, url, maxfreq=None):
        """Fetch every result page for ``url`` and return the combined rows.

        Each response line is split on tabs; the third field becomes the text
        and the second field (as ``float``) becomes the score of a result
        tuple.  The integer score of a page's last row is sent as
        ``&maxfreq=`` for the next request.  Paging stops when a page is
        empty or when its last score equals the ``maxfreq`` that was just
        requested (no progress); the rows of that final page are dropped,
        exactly as before.

        Args:
            url: Fully formatted search URL.
            maxfreq: Optional upper frequency bound for the first request.

        Returns:
            list of ``(text, score)`` tuples, possibly empty.
        """
        collected = []
        while True:
            # ``is None`` (not truthiness) so that a bound of 0 is still sent;
            # otherwise a page ending at frequency 0 would restart from page 1.
            if maxfreq is None:
                page_url = url
            else:
                page_url = url + "&maxfreq=%s" % maxfreq

            webdata = self.__getPageContent(page_url)
            if not webdata:
                break

            rows = [raw.split('\t') for raw in webdata.splitlines()]
            # Blank or truncated lines have fewer than three fields; skip them
            # instead of failing with IndexError.
            page = [(row[2], float(row[1])) for row in rows if len(row) >= 3]
            if not page:
                break

            lastFreq = int(page[-1][1])
            if lastFreq == maxfreq:
                break

            collected.extend(page)
            maxfreq = lastFreq
        return collected

    def search(self, query):
        """Run ``query`` against the API and return its cached result rows.

        The query is lower-cased and split on whitespace.  A token containing
        ``|`` is rewritten to ``[+a+b+]`` (its alternatives joined by ``+``),
        a lone ``*`` token becomes ``?``, and all tokens are joined with ``+``
        before being substituted into ``API_URL``.

        Args:
            query: Space separated query string.

        Returns:
            list of ``(text, score)`` tuples as produced by the paging helper.
        """
        if query in self.dictionary:
            return self.dictionary[query]

        new_query = []
        for token in query.lower().split():
            if '|' in token:
                new_query.append('[+{0}+]'.format('+'.join(token.split('|'))))
            elif token == '*':
                new_query.append('?')
            else:
                new_query.append(token)

        # Tokens come from ``split()`` and therefore never contain spaces.
        url = API_URL % '+'.join(new_query)
        self.dictionary[query] = self.__rolling(url)
        return self.dictionary[query]

SE = NetSpeak() # singleton

In [2]:
# Build ``Confuse``: first tab-separated field of each line -> stripped
# remainder of the line, read from the confusables file.
Confuse = {}
with open('lab4.confusables.txt','r') as confusables_file:
    # The context manager closes the handle even if parsing fails below.
    confuse_word = confusables_file.readlines()
for line in confuse_word:
    if not line.strip():
        # Tolerate blank lines (e.g. a trailing newline at end of file).
        continue
    w, c = line.split('\t', 1)
    Confuse[w] = c.strip()

In [3]:
def get_trigrams(tokens):
    """Return every run of three consecutive items of ``tokens``.

    Each window is produced by slicing, so it has the same sequence type as
    ``tokens``.  Inputs with fewer than three items give an empty list.
    """
    windows = []
    last_start = len(tokens) - 2
    for start in range(last_start):
        windows.append(tokens[start:start + 3])
    return windows

In [4]:
def detect_where(tm):
    """Find the trigram of ``tm`` that scores lowest in the search engine.

    Every trigram is joined with spaces and looked up through ``SE.search``;
    its score is the second element of the first result row, or 0 when the
    search returns nothing.

    Args:
        tm: List of tokens of one sentence.

    Returns:
        ``(index, score, trigram)`` for the lowest scoring trigram.  When the
        same trigram occurs at several positions, the last position is
        reported.

    Raises:
        ValueError: if ``tm`` has fewer than three tokens (no trigram exists).
    """
    trigrams = get_trigrams(tm)
    if not trigrams:
        raise ValueError(
            "detect_where needs at least three tokens, got %d" % len(tm))

    tri_tmp = []
    for index, tri in enumerate(trigrams):
        res = SE.search(' '.join(tri))
        score = res[0][1] if res else 0
        tri_tmp.append((index, score, tri))

    # ``min`` returns the first entry with the lowest score ...
    minn = min(tri_tmp, key=lambda x: x[1])[2]
    # ... but the reported entry is the last one carrying that same trigram.
    detect_sentence = next(
        entry for entry in reversed(tri_tmp) if entry[2] == minn)
    return detect_sentence

def find_the_best(tm,start):
    """Pick the single-word replacement that maximises the trigram score.

    For each of the three positions ``start`` .. ``start + 2`` the candidate
    words are the first elements of the items returned by ``correction`` for
    the current word, plus the word itself when it is a key of ``Confuse``.
    Each candidate sentence is scored as the product, over all of its
    trigrams, of the top ``SE.search`` score (0 when the search is empty).

    NOTE(review): ``correction`` is not defined in the visible part of this
    notebook, and the name ``find_the_best`` is bound again by a later cell.

    Returns:
        ``(sentence_tokens, original_word, replacement, candidates, score)``
        for the first highest-scoring candidate, or
        ``(None, None, None, None, -inf)`` when no candidate exists.
    """
    best = (None, None, None, None, -math.inf)

    for position in range(start, start + 3):
        original_word = tm[position]
        candidate = [suggestion[0] for suggestion in correction(original_word)]
        if original_word in Confuse.keys():
            candidate.append(original_word)

        for replacement in candidate:
            combine = tm[:position] + [replacement] + tm[position + 1:]

            score = 1.0
            for tri in get_trigrams(combine):
                found = SE.search(' '.join(tri))
                if found:
                    score *= found[0][1]
                else:
                    score *= 0

            # Strict comparison: on ties the earlier candidate is kept.
            if score > best[-1]:
                best = (combine, original_word, replacement, candidate, score)

    return best

In [7]:
import onmt
import onmt.io
import onmt.translate
import onmt.ModelConstructor
from collections import namedtuple
from itertools import count
# Load the model.
# ``Opt`` is a lightweight stand-in for the option object that
# ``load_test_model`` expects; only the four fields below are supplied.
Opt = namedtuple('Opt', ['model', 'data_type', 'reuse_copy_attn', "gpu"])
opt = Opt(
    model="ch-OpenNMT-py/ch-merge-model/demo_model_acc_91.31_ppl_1.70_e13.pt",
    data_type="text",
    reuse_copy_attn=False,
    gpu=0,  # presumably the GPU device index -- TODO confirm
)
fields, model, model_opt = onmt.ModelConstructor.load_test_model(
    opt, {"reuse_copy_attn": False})

Loading model parameters.


In [8]:
def ch_OpenNMT_generate_candidate(detect_sentence_arr):
    """Translate each entry of ``detect_sentence_arr`` and collect the n-best outputs.

    Every entry is written to ``text.txt`` as one line, with each of its
    items followed by a space (for a string entry that means one character
    per token).  The file is then translated with the module-level ``model``
    and ``fields`` using beam size 20, keeping the 10 best predictions.

    Args:
        detect_sentence_arr: Iterable of entries; each entry is itself
            iterated to produce the tokens of one source line.

    Returns:
        dict mapping the source tokens of each line, concatenated without
        spaces, to a list of prediction strings (tokens joined by spaces).
    """
    n_best = 10
    beam_size = 20

    # Context manager guarantees the file is flushed and closed before the
    # dataset builder reads it back.
    with open('text.txt','w') as f:
        for i in detect_sentence_arr:
            for ii in i:
                f.write(ii+' ')
            f.write('\n')

    data = onmt.io.build_dataset(fields, "text", "text.txt", None, use_filter_pred=False)
    data_iter = onmt.io.OrderedIterator(
        dataset=data, device=0,
        batch_size=1, train=False, sort=False,
        sort_within_batch=True, shuffle=False)

    # Build the scorer / translator / builder exactly once.  The earlier
    # duplicate set (beam 10, n_best 5) was overwritten before ever being used.
    scorer = onmt.translate.GNMTGlobalScorer(None,
                                             None,
                                             None,
                                             None)
    translator = onmt.translate.Translator(model, fields,
                                           beam_size=beam_size,
                                           n_best=n_best,
                                           global_scorer=scorer,
                                           cuda=True)
    builder = onmt.translate.TranslationBuilder(
        data, translator.fields,
        n_best, False, None)

    ch_candidate = {}
    for batch in data_iter:
        batch_data = translator.translate_batch(batch, data)
        translations = builder.from_batch(batch_data)
        for trans in translations:
            # Key and predictions are taken from the same translation, so the
            # mapping stays consistent (identical to before for batch_size=1).
            n_best_preds = [" ".join(pred) for pred in trans.pred_sents[:n_best]]
            ch_candidate[' '.join(trans.src_raw).replace(' ','')] = n_best_preds

    return ch_candidate

In [11]:
def find_the_best(tm, start, candidates=None):
    """Pick the single-word replacement that maximises the trigram score.

    For each position ``start`` .. ``start + 2`` (clamped to the sentence
    length) the candidate words are the ones stored for that word in
    ``candidates``, plus the word itself when it is a key of ``Confuse``.
    A candidate sentence is scored as the product, over all of its trigrams,
    of the top ``SE.search`` score (0 when the search returns nothing).

    NOTE(review): because the score is a product, a single trigram with no
    search hit makes the whole score 0, so ties at 0 are resolved in favour of
    the first candidate tried.

    Args:
        tm: List of tokens of the sentence.
        start: Index of the first token of the suspicious trigram.
        candidates: Optional mapping ``word -> list of replacement words``.
            Defaults to the module-level ``ch_candidate``.

    Returns:
        ``(sentence_tokens, original_word, replacement, candidates, score)``
        for the first highest-scoring candidate, or
        ``(None, None, None, None, -inf)`` when no candidate exists.
    """
    if candidates is None:
        candidates = ch_candidate

    best = (None, None, None, None, -math.inf)
    for i in range(start, min(start + 3, len(tm))):
        # Index with ``tm`` (the argument), not a global, and copy the list so
        # appending below never mutates the shared candidate mapping.
        candidate = list(candidates.get(tm[i], ()))
        if tm[i] in Confuse:
            candidate.append(tm[i])

        for cancan in candidate:
            count = 1.0
            combine = tm[:i] + [cancan] + tm[i+1:]

            for tri in get_trigrams(combine):
                res = SE.search(' '.join(tri))
                count *= res[0][1] if res else 0

            # Strict comparison: on ties the earlier candidate is kept.
            if count > best[-1]:
                best = (combine, tm[i], cancan, candidate, count)

    return best

In [12]:
# Split the dataset into erroneous and correct sentences.
# Each line holds two tab-separated columns: the first goes to
# ``False_sentence`` and the second to ``Correct_sentence`` (both stripped
# and lower-cased).
with open('lab4.test.1.txt','r') as test_file:
    # The context manager closes the handle once the lines are read.
    line = test_file.readlines()
Correct_sentence = []
False_sentence = []
for sentence in line:
    columns = sentence.split('\t')
    if len(columns) < 2:
        # Blank or malformed line without a second column: skip it so the two
        # lists stay aligned.
        continue
    False_sentence.append(columns[0].strip().lower())
    Correct_sentence.append(columns[1].strip().lower())
# Only the first 20 pairs are used for evaluation.
test_Correct=Correct_sentence[:20]
test_False = False_sentence[:20]

In [None]:
# Evaluate on the test pairs: locate the weakest trigram, generate
# replacement candidates for its words, pick the best rewrite and count how
# often it equals the reference sentence.
hits = 0
for i,line in enumerate (test_False):
    # ``word`` and ``ch_candidate`` are deliberately module-level names:
    # ``find_the_best`` reads them as globals.
    word = line.split(' ')
    combine, wrong, right, candidate = None, None, None, []
    if len(word) >= 3:
        detect_sentence = detect_where(word)
        start = detect_sentence[0]
        ch_candidate = ch_OpenNMT_generate_candidate(detect_sentence[2])
        combine ,wrong ,right ,candidate ,_ =find_the_best(word,start)
    if combine is None:
        # Sentence too short for a trigram, or no candidate was produced:
        # keep the sentence unchanged instead of failing on ``join(None)``.
        combine = word
    combine = ' '.join(combine).strip()
    if combine == test_Correct[i]:
        hits+=1

    print("Error:" +  str(wrong))
    print("Candidates:", candidate)
    print("Correction:", right)
    print(test_False[i], "->", combine )
    print("hits =", hits)
    print()

average src size 4.666666666666667 3


  align_vectors = self.sm(align.view(batch*targetL, sourceL))
  input = module(input)
