In [133]:
#!/usr/bin/env python
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Read and analyze data

In [49]:
from collections import Counter, defaultdict
from utils.config import tag2pos_table, level_table
from utils.EGP import Egp
from utils.EVP import Evp
from utils.BNC import Bnc
import kenlm

import os

# KenLM model used by `lm` to score candidate continuations. The location can be
# overridden with the KENLM_MODEL_PATH environment variable; the default is the
# original machine-specific path, so existing setups keep working.
KENLM_MODEL_PATH = os.environ.get('KENLM_MODEL_PATH', '/home/nlplab/jjc/gec/lm/coca.prune.bin')
model = kenlm.Model(KENLM_MODEL_PATH)

In [3]:
def lm(last_sent, ngram):
    """Score ``ngram`` as a continuation of ``last_sent`` with the KenLM ``model``.

    The two strings are joined by a single space and scored with a
    beginning-of-sentence marker but without an end-of-sentence marker,
    because the candidate extends an unfinished sentence. The raw,
    length-unnormalised value from ``model.score`` is returned.
    """
    return model.score(' '.join((last_sent, ngram)), bos=True, eos=False)


def level_score(ngram):
    """Return the mean level value of the whitespace-separated tokens in ``ngram``.

    Each token's level is looked up with ``Evp.get_level``. A falsy level
    (no entry) counts as 0; any other level is converted to a number through
    ``level_table``. ``suggest_patterns`` ranks ngrams by this value, highest first.

    Returns:
        The average over all tokens, or ``0.0`` for an empty / whitespace-only
        ``ngram`` (which previously raised ZeroDivisionError).
    """
    tokens = ngram.split()
    if not tokens:
        return 0.0
    total = sum(level_table[level] if level else 0
                for level in map(Evp.get_level, tokens))
    return total / len(tokens)


def normalize_tag(tag):
    """Translate a POS ``tag`` through ``tag2pos_table`` and append a period.

    Tags that have no entry in ``tag2pos_table`` are passed through unchanged.
    """
    if tag not in tag2pos_table:
        return tag
    return tag2pos_table[tag] + '.'


In [126]:
def suggest_patterns(related_patterns, text_sent, headword, pos, top_k=1,
                     *, excluded_nos=(9, 12), max_ngram_len=7,
                     lm_shortlist=10, level_shortlist=5):
    """Rank grammar rules whose example ngrams could follow ``text_sent``.

    Patterns are grouped by rule number. For each rule, the ngrams stored in
    ``Bnc.ngram_groups[(pattern, no)][headword]`` are filtered, scored with
    ``lm`` (fit to the context) and ``level_score``. The ``lm_shortlist`` best-LM
    ngrams are kept, and the ``level_shortlist`` highest-level ones among them are
    reported.

    Args:
        related_patterns: iterable of ``(pattern, no)`` pairs.
        text_sent: sentence prefix used as LM context.
        headword: word whose ngrams are looked up.
        pos: POS tag of ``headword``; reported through ``normalize_tag``.
        top_k: currently unused; kept only for interface compatibility.
        excluded_nos: rule numbers never suggested (formerly hard-coded 9 and 12).
        max_ngram_len: ngrams with this many tokens or more are skipped.
        lm_shortlist: how many ngrams survive the LM ranking.
        level_shortlist: how many of those are finally reported.

    Returns:
        One dict per rule number with at least one usable ngram, sorted by the
        mean LM score of its reported ngrams, highest first.
    """
    excluded = set(excluded_nos)

    # Drop excluded rules and patterns that have no ngrams for this headword,
    # then group the remaining patterns by rule number.
    pattern_groups = defaultdict(list)
    for ptn in related_patterns:
        if ptn[1] in excluded or headword not in Bnc.ngram_groups[ptn]:
            continue
        pattern, no = ptn
        pattern_groups[no].append(pattern)

    targets = []
    for no, patterns in pattern_groups.items():
        # dict.fromkeys removes duplicates while keeping first-seen order: one
        # ngram may be matched by several patterns of the same rule and must not
        # be scored, averaged or suggested more than once.
        candidates = dict.fromkeys(
            ng for pattern in patterns for ng in Bnc.ngram_groups[(pattern, no)][headword])
        scored = [(ng, level_score(ng), lm(text_sent, ng))
                  for ng in candidates
                  if len(ng.split()) < max_ngram_len and len(Bnc.sentences[ng]) > 1]
        if not scored:
            continue

        # Best LM fit first, then prefer higher-level ngrams among that shortlist.
        best_lm = sorted(scored, key=lambda s: s[2], reverse=True)[:lm_shortlist]
        chosen = sorted(best_lm, key=lambda s: s[1], reverse=True)[:level_shortlist]
        targets.append({'no': no, 'level': Egp.get_level(no), 'pos': normalize_tag(pos),
                        'lm': sum(s[2] for s in chosen) / len(chosen),
                        'pattern': Egp.get_norm_pattern(no), 'ngrams': [s[0] for s in chosen],
                        'category': Egp.get_category(no), 'subcategory': Egp.get_subcategory(no),
                        'statement': Egp.get_statement(no)})

    targets.sort(key=lambda t: t['lm'], reverse=True)
    return targets


def suggest_sentences(ngram):
    """Return at most the first three example sentences stored for ``ngram`` in ``Bnc.sentences``."""
    examples = Bnc.sentences[ngram]
    return examples[:3]


In [127]:
# from utils.grammar import iterate_all_patterns

def auto_suggest(parse_sent, gets):
    """Suggest grammar patterns that could follow the last word of ``parse_sent``.

    Args:
        parse_sent: parsed sentence, indexable/sliceable as a sequence of tokens
            exposing ``.text``, ``.tag_`` and ``.is_punct``, with slices exposing
            ``.text`` (looks like a spaCy ``Doc`` — TODO confirm).
        gets: iterable of dicts with at least ``'indices'``, ``'level'`` and
            ``'no'`` keys describing patterns already found in the sentence.

    Returns:
        ``(get, suggestions)``. ``get`` is the highest-level entry of ``gets``
        covering the headword, with ``'pattern'`` added in place, or ``None``.
        ``suggestions`` is ``{'patterns': [...], 'collocations': []}``.
    """
    if len(parse_sent) == 0:
        # Nothing to anchor a suggestion on (previously an IndexError).
        return None, {'patterns': [], 'collocations': []}

    # The headword is the last token, or the one before it when the sentence
    # ends with punctuation.
    last_index = len(parse_sent) - 1
    if parse_sent[last_index].is_punct and len(parse_sent) > 1:
        last_index -= 1
    last_word = parse_sent[last_index]
    headword, pos = last_word.text, last_word.tag_

    # LM context is the text of every token before the headword. The old
    # `text.rsplit(' ', 1)[0]` returned the headword itself for one-word input
    # and kept it when punctuation was space-separated, duplicating it in `lm`.
    text_sent = parse_sent[:last_index].text

    # Only already-detected patterns that cover the headword are relevant.
    gets = [g for g in gets if last_index in g['indices']]
    get = None

    related_patterns = Bnc.pattern_groups[pos].union(Bnc.pattern_groups[headword])
    if gets:
        # Report the highest-level detected pattern and only suggest rules of a
        # strictly higher level with a different rule number.
        get = max(gets, key=lambda g: level_table[g['level']])
        get['pattern'] = Egp.get_norm_pattern(get['no'])
        current_level, current_no = level_table[get['level']], get['no']
        related_patterns = [ptn for ptn in related_patterns
                            if level_table[Egp.get_level(ptn[1])] > current_level
                            and current_no != ptn[1]]

    patterns = suggest_patterns(related_patterns, text_sent, headword, pos)

    return get, {'patterns': patterns, 'collocations': []}


In [140]:
# Manual smoke test: run the suggestion pipeline on one sentence and print the
# vocabulary recommendations for its headword.
from utils.parser import nlp
from utils.grammar import iterate_all_patterns

parse = nlp('Those students are friendly')
# iterate_all_patterns(parse) presumably yields the patterns already matched in
# `parse` (the `gets` argument of auto_suggest) — TODO confirm.
# NOTE(review): `get` and `sugs` are only consumed by the commented-out code below.
get, sugs = auto_suggest(parse, iterate_all_patterns(parse))
from pprint import pprint
# pprint(sugs)
from utils.vocabulary import recommend_vocabs
print(recommend_vocabs('friendly'))
# NOTE(review): suggest_patterns builds dicts with an 'ngrams' list, not an
# 'ngram' key, so the commented loop below would raise KeyError if re-enabled.
# for sug in sugs['patterns']:
#     if sug['level'].startswith('A'): continue
#     print(sug['level'], sug['ngram'], str(round(sug['lm'],2)) + ' \\\\', sep=' & ')

[('freindly', None), ('Friendly', None), ('unfriendly', 'B1'), ('friendlier', None), ('friendliest', None), ('hospitable', 'C1'), ('anybody_Petney', None), ('friendliness', 'B2'), ('congenial', None), ('detecting_gastrointestinal_disorders', None), ('Petney_observed', None), ('sociable', 'B1'), ('Jowzjan_province_Darzab_district', None), ('easygoing', 'B1'), ('oriented', None), ('hourslong_confrontation', None), ('welcoming', None), ('pleasant', 'A2'), ('staff_hrogers', None), ('approachable', None), ('unintimidating', None), ('friendy', None), ('WTOP_strives', None), ('Maya_Derkovic', None), ('courteous', 'C2'), ('cordial', None), ('accommodating', None), ('cheerful', 'B1'), ('Luttrell_Auction', None), ('convivial', None), ('orientated', None), ('www.playthq.com', None), ('amiable', None), ('lively', 'B1'), ('loving', None), ('chatty', 'C1'), ('unstuffy', None), ('personable', None), ('warm', 'A1'), ('respectful', 'C1'), ('neighborly', None), ('hireable', None), ('considerate', 'C1'),

In [151]:
import pickle

# Load the pruned recommendation statistics.
# pickle.load can execute arbitrary code: only open files produced by this project.
with open('recommend.prune.pickle', 'rb') as handle:
    info = pickle.load(handle)

counts, ngrams, sentences = (info[key] for key in ('counts', 'ngrams', 'sentences'))

# Sanity check used earlier (the 8098545 in the saved output below is stale from it):
# print(sum(counts.values()))

8098545


In [None]:
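# NOTE: earlier draft of auto_suggest, kept for reference only and never executed.
# It is out of date: it uses names not defined in this notebook (pattern_groups,
# ngram_groups, normalize_pattern), indexes lm() results as tuples although lm()
# now returns a float, and reads a leftover loop variable `no` when building each dict.
# def auto_suggest(headword, pos, last_sent):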
#     related_patterns = pattern_groups[headword] + pattern_groups[pos] # get headword matched POS and first word
#     related_patterns = [related_pat for related_pat in related_patterns if headword in ngram_groups[related_pat]]

#     # merge same number rule
#     total = Counter()
#     for pattern, no in related_patterns: 
#         for ngram in ngram_groups[(pattern, no)][headword]:
#             total[no] += ngrams[(pattern, no)][ngram]
#     top_k_keys = dict(total.most_common(3))

#     # get top k patterns
#     target_patterns = filter(lambda key: key[1] in top_k_keys, related_patterns)

#     # normalize patterns
#     target_norm_patterns = defaultdict(list)
#     for pattern, no in target_patterns:
#         target_norm_patterns[normalize_pattern(headword, pattern)].append((pattern, no))
        
#     # retrieve ngrams example
#     total = 0
#     target_ngrams = []
#     for norm_pattern in target_norm_patterns: # (pattern, no)
#         scores = [lm(last_sent, ng) for pattern, no in target_norm_patterns[norm_pattern] 
#                   for ng in ngram_groups[(pattern, no)][headword]] # get ngram in given patterns
#         scores = sorted(scores, key=lambda x: x[1], reverse=True)[:3]

#         t_ngrams = [ng[0] for ng in scores]
#         avg = 1 / abs(sum([s[1] for s in scores]) / len(scores))
#         total += avg        

#         target_ngrams.append({
#             'pattern': norm_pattern, 'pos': normalize_tag(pos),
#             'no': no, 'level': Egp.get_level(no), 
#             # 'count': counts[(pattern, no)], 
#             'lm': avg,
#             'category': Egp.get_category(no), 'subcategory': Egp.get_subcategory(no),
#             'ngrams': t_ngrams, 
#             'sentence': sentences[t_ngrams[0]][0] })

#     # get means and sorted
#     for ng in target_ngrams: 
#         ng['lm'] = ng['lm'] / total
        
#     return target_ngrams
