In [1]:
from pattern.en import *
from nltk.corpus import wordnet as wn
from nltk.stem.wordnet import WordNetLemmatizer
from nltk import word_tokenize, pos_tag
from collections import defaultdict
import random

In [2]:
def get_wordnet_pos(treebank_tag):

    if treebank_tag.startswith('J'):
        return wordnet.ADJ
    elif treebank_tag.startswith('V'):
        return wordnet.VERB
    elif treebank_tag.startswith('N'):
        return wordnet.NOUN
    elif treebank_tag.startswith('R'):
        return wordnet.ADV
    else:
        return ''

In [3]:
def tag_map(pos):
    _pos = pos[0]
    if (_pos == "J"): return wn.ADJ
    elif (_pos == "V"): return wn.VERB
    elif (_pos == "R"): return wn.ADV
    else: return wn.NOUN
    
# input word_tokenize of sentence
def preprocessing(tokens, debug = False):
    tokenList = []
    tagList = dict()
    lemmaList = dict()
    
    # lemmatize
    _lemma = WordNetLemmatizer()
    for token, tag in pos_tag(tokens):
        if (debug):
            print("token: {}, tag: {}".format(token, tag))
        tokenList.append(token)
        tagList[token] = tag

        lemma = _lemma.lemmatize(token, tag_map(tag))
        lemmaList[token] = lemma
        if (debug):
            print(token, "=>", lemma)
    return tokenList, tagList, lemmaList

def get_name_pos_from_syn(syn):
    name = syn.name().split('.')[0]
    pos = syn.name().split('.')[1]
    return name, pos

In [4]:
def make_synDict(tokenList, tagList, lemmaList, debug = False, thres = -1):
    synDict = dict()
    
    # make synset using lemma
    for token in tokenList:
        lemma = lemmaList[token]
        synDict[lemma] = wn.synsets(str(lemma))

    # get synset with only same postag
    for token in tokenList:
        lemma = lemmaList[token]
        for val in synDict[lemma]:
            name = val.name().split('.')[0]
            pos = val.name().split('.')[1]
            if (tag_map(tagList[token]) != pos):
                synDict[lemma].remove(val)
                if (debug):
                    print("[REMOVE] pos: {}, tokenpos: {} / token: {}, synName: {}".format(pos, tagList[token], lemma, name))
            if (tag_map(tagList[token]) == pos and debug):
                print("[NOT REMOVE] pos: {}, tokenpos: {} / token: {}, synName: {}".format(pos, tagList[token], lemma, name))
            if (len(synDict[lemma]) == 0):
                del(synDict[lemma])
                
    if (thres > 0):   
        print("Apply threshold")
        for token in tokenList:
            lemma = lemmaList[token]
            if (len(synDict[lemma]) >= thres):
                del(synDict[lemma])
        if (debug): 
            print(synDict)
            
    return synDict

def get_tense(syn, original_pos):
    
    # VB: Verb, base form 
    # VBD: Verb, past tense
    # VBG: Verb, gerund or present participle
    # VBN: Verb, past participle
    # VBP: Verb, non-3rd person singular present
    # VBZ: Verb, 3rd person singular present
    
    name, pos = get_name_pos_from_syn(syn)
    _tense = "present"            # INFINITIVE, PRESENT, PAST, FUTURE
    _person = 1                # 1, 2, 3, or None
    _number = "singular"       # SG, PL
    _mood = "indicative"       # INDICATIVE, IMPERATIVE, CONDITIONAL, SUBJUNCTIVE
    _aspect = "imperfective"   # IMPERFECTIVE, PERFECTIVE, PROGRESSIVE
    
    if (original_pos == "VBD"):
        _tense = "past"
    elif (original_pos == "VBG"):
        _aspect = "progressive"
    elif (original_pos == "VBN"):
        _tense = "past"
        _aspect = "progressive"
    elif (original_pos == "VBZ"):
        _person = 3
    
    return conjugate(name,
             tense = _tense,
             person = _person,
             number = _number,
             mood = _mood,
             aspect = _aspect, 
             negated = False)
    
    
def make_hypernymDict(synDict, tokenList, tagList, lemmaList, debug=False, thres = -1):
    for token in tokenList:
        lemma = lemmaList[token]
        for syn in synDict[lemma]:
            hyper = syn.hyponyms()
            if ((thres > 0) and (len(hyper) > thres)):
                synDict[lemma].remove(syn)
            elif (len(hyper) == 0):
                synDict[lemma].remove(syn)
            else:
                #print("lemma {} - hyper of {} is {}".format(lemma, syn, hyper))
                continue
                
    _key = list(synDict.keys())
    if (len(_key) > 1):
        rand1 = random.randint(1, len(_key))
    else:
        rand1 = 1
    # print("rand1: ", rand1)
    keyElem = _key[rand1-1]
    randList = synDict[keyElem]
    if (len(randList) > 1):
        rand2 = random.randint(1, len(randList))
    else:
        rand2 = 1
    # print("rand2: ", rand2)
    # print("randList: ", randList)
    randElem = randList[rand2-1]
    
    # get original word [_token] and original tag [original_tag]
    _index = (list(lemmaList.values())).index(keyElem)
    _token = tokenList[_index]
    original_tag = tagList[_token]
    
    name, pos = get_name_pos_from_syn(randElem)
    renamed = ""
    if (pos == "v"):
        # need to take care of tense
        _renamed = get_tense(randElem, original_tag)
    else:
        _renamed = name
        
    print("change {} to {}".format(_token, _renamed))
    
    return (_token, _renamed)

def main(sen):
    _debug = False
    tokens = word_tokenize(sen)
    tokenList, tagList, lemmaList = preprocessing(tokens, debug=_debug)
    synDict = make_synDict(tokenList, tagList, lemmaList, debug=False)
    hypernymDict = make_hypernymDict(synDict, tokenList, tagList, lemmaList, debug=False, thres = 10)
main("I am planning to have a trip to LA")

In [5]:
def main(sen): 
    debug = False 
    tokens = word_tokenize(sen)
    tokenList, tagList, lemmaList = preprocessing(tokens, debug=debug)
    synDict = make_synDict(tokenList, tagList, lemmaList, debug=False) 
    (ori_word, nxt_word) = make_hypernymDict(synDict, tokenList, tagList, lemmaList, debug=False, thres = 10) 
    retVal = sen.replace(ori_word, nxt_word)
    return retVal
    
for i in range(10):
    print(main("I am planning to have a trip to LA"))


RuntimeError: generator raised StopIteration