In [1]:
from transformers import AutoModelWithLMHead, AutoTokenizer
import torch
import re
from gensim.test.utils import datapath
from gensim.models.fasttext import *
import gensim.downloader as api
import os

In [17]:
# https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz
cap_path = datapath(os.getcwd() + "/data/cc.en.300.bin.gz")
fbkv = load_facebook_vectors(cap_path)

In [35]:
w2v = api.load("word2vec-google-news-300") 

In [3]:
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
model = AutoModelWithLMHead.from_pretrained("roberta-large")

In [68]:
def find_top(sentence, k, min_sim): 
    print("Input: ", sentence)
    smatch = re.search("(\w+)\?", sentence)
    target = None
    
    if (smatch):
        target = re.sub("\?", "", smatch.group(0))
        target = target.strip()
        
    sequence = re.sub("(\w+)?\?", tokenizer.mask_token, sentence)
    print(tokenizer.tokenize(sequence))
    
    input = tokenizer.encode(sequence, return_tensors="pt")
    mask_token_index = torch.where(input == tokenizer.mask_token_id)[1]

    token_logits = model(input)[0]
    mask_token_logits = token_logits[0, mask_token_index, :]

    topk = torch.topk(mask_token_logits, tokenizer.vocab_size, dim=1)
    top_tokens = list(zip(topk.indices[0].tolist(), topk.values[0].tolist()))

    total = 0
    for token, value in top_tokens:
        word = tokenizer.decode([token]).strip()
        sim = None
        if (target != None):
            sim = fbkv.similarity(target, word)
            try:
                sim2 = w2v.similarity(target, word)
            except:
                sim2 = -1
            if (sim < min_sim and sim2 < min_sim):
                continue
        print(word, " ", value, "; fasttext: ", sim, "; w2v: ", sim2)
        total += 1
        if (total > k):
            break

        
    if (target != None):
        vec = tokenizer.encode(target, return_tensors="pt")[0]
        if (len(vec) == 3):
            tk = vec[1].item()
            pos = None
            score = None
            
            for e, (t, v) in enumerate(top_tokens):
                if (t == tk):
                    pos = e
                    score = v
                    break
            print("Original word position: ", e, "; score: ", score)
        else: 
            if (len(vec) > 3):
                print("Original word is more then 1 token")
                print(tokenizer.tokenize(target))
            else:
                print("Original word wasn't found")
    print("===================")
    
def do_find(s, min_sim = 0):
    find_top(s, 20, min_sim) 
    
base_samples = [
    "what's the local weather? forecast",
    "chance of rain? tomorrow in Moscow",
    "chance of rain? tomorrow",
    "set the lights? on in the entire house",
    "set the lights on in the entire house?",
    "turn the lights off in the guest? bedroom",
    "turn the lights off in the guest bedroom?"
]

## Random wiki how to articles
articles = [
    "how to use a sharpening steel?",
    "how to identify? and treat liver shunts in cats",
    "how to identify and treat? liver shunts in cats",
    "how to identify and treat liver? shunts in cats",
    "use enamel paint? on a stove",
    "escape a sinking? ship",
    "escape a sinking ship?",
    "reduce the effect? of macular degeneration",
    "recognize sign? of an abusive person",
    "recognize signs of an abusive person?",
    "make wine? vinegar",
    "prepare boneless skinless chicken? thighs",
    "what can you eat with type 2 diabetes?"
]

In [65]:
for s in base_samples:
    do_find(s)

Input:  what's the local weather? forecast
['what', "'s", 'Ġthe', 'Ġlocal', '<mask>', 'Ġforecast']
weather   58.944068908691406 ; fasttext:  0.99999976 ; w2v:  0.9999998
temperature   55.825252532958984 ; fasttext:  0.45059466 ; w2v:  0.38615134
snow   55.3741455078125 ; fasttext:  0.49174437 ; w2v:  0.5175463
traffic   54.93745422363281 ; fasttext:  0.32146013 ; w2v:  0.27994886
rainfall   54.863487243652344 ; fasttext:  0.49745676 ; w2v:  0.47409654
rain   54.681884765625 ; fasttext:  0.5836918 ; w2v:  0.5686597
school   54.495479583740234 ; fasttext:  0.16067486 ; w2v:  0.08011741
weekend   54.44187545776367 ; fasttext:  0.3462723 ; w2v:  0.26656434
precipitation   54.293479919433594 ; fasttext:  0.4120771 ; w2v:  0.4670227
driving   54.13661193847656 ; fasttext:  0.2243191 ; w2v:  0.18876596
travel   53.818111419677734 ; fasttext:  0.2561077 ; w2v:  0.23630022
fire   53.77357864379883 ; fasttext:  0.25399584 ; w2v:  0.23497717
radar   53.732852935791016 ; fasttext:  0.31953615 ; w2

master   62.25900650024414 ; fasttext:  0.25583786 ; w2v:  0.16522017
guest   60.62987518310547 ; fasttext:  1.0 ; w2v:  1.0000001
second   59.89073181152344 ; fasttext:  0.20992765 ; w2v:  0.09038264
spare   59.19243240356445 ; fasttext:  0.17714785 ; w2v:  0.113498464
third   58.96836471557617 ; fasttext:  0.23888935 ; w2v:  0.09191485
main   58.719661712646484 ; fasttext:  0.24381165 ; w2v:  0.0968393
back   58.608848571777344 ; fasttext:  0.11738129 ; w2v:  0.014101536
upstairs   58.48353576660156 ; fasttext:  0.30337965 ; w2v:  0.20656306
other   58.32860565185547 ; fasttext:  0.19010572 ; w2v:  0.11675016
front   58.00963592529297 ; fasttext:  0.18248263 ; w2v:  0.086054556
downstairs   57.79798889160156 ; fasttext:  0.2840525 ; w2v:  0.20741828
middle   57.3287353515625 ; fasttext:  0.072875775 ; w2v:  0.0066892
first   57.30056381225586 ; fasttext:  0.2574252 ; w2v:  0.08624094
extra   56.92780685424805 ; fasttext:  0.2632902 ; w2v:  0.10322451
Master   56.860107421875 ; fastte

In [66]:
for a in articles:
    do_find(a)

Input:  how to use a sharpening steel?
['how', 'Ġto', 'Ġuse', 'Ġa', 'Ġsharp', 'ening', '<mask>']
tool   55.64176940917969 ; fasttext:  0.1867103 ; w2v:  0.04131889
knife   55.36676788330078 ; fasttext:  0.32689506 ; w2v:  0.10359444
pencil   55.1890869140625 ; fasttext:  0.26298997 ; w2v:  0.12676606
iron   54.96471405029297 ; fasttext:  0.5903258 ; w2v:  0.44558045
stone   54.86848449707031 ; fasttext:  0.42118368 ; w2v:  0.32327506
wheel   54.341529846191406 ; fasttext:  0.28017086 ; w2v:  0.15702476
pen   53.8603401184082 ; fasttext:  0.21977136 ; w2v:  0.06327921
device   53.61422348022461 ; fasttext:  0.100269385 ; w2v:  0.049095023
brush   53.129844665527344 ; fasttext:  0.20714068 ; w2v:  0.06368443
blade   52.88595199584961 ; fasttext:  0.39472598 ; w2v:  0.23887269
tip   52.83150100708008 ; fasttext:  0.1761272 ; w2v:  0.07401559
stick   52.60443878173828 ; fasttext:  0.21086262 ; w2v:  0.078297704
pin   52.417808532714844 ; fasttext:  0.21666886 ; w2v:  0.10574301
board   52.

cruise   58.42720413208008 ; fasttext:  0.32779014 ; w2v:  0.28956673
cargo   56.786293029785156 ; fasttext:  0.26759654 ; w2v:  0.11580957
pirate   56.477928161621094 ; fasttext:  0.19924934 ; w2v:  0.20721014
sinking   56.103729248046875 ; fasttext:  1.0 ; w2v:  1.0000001
container   55.70797348022461 ; fasttext:  0.169901 ; w2v:  0.07815184
small   55.65782165527344 ; fasttext:  0.09459712 ; w2v:  0.06347525
tall   55.656585693359375 ; fasttext:  0.116673246 ; w2v:  0.106152534
ghost   55.45624542236328 ; fasttext:  0.15300737 ; w2v:  0.1765039
passenger   55.334571838378906 ; fasttext:  0.18724741 ; w2v:  0.07284105
new   55.228729248046875 ; fasttext:  0.11614084 ; w2v:  0.06528825
space   54.90544128417969 ; fasttext:  0.06106523 ; w2v:  0.0638573
rocket   54.855770111083984 ; fasttext:  0.18235245 ; w2v:  0.1791914
mother   54.80230712890625 ; fasttext:  0.13208514 ; w2v:  0.04660174
large   54.76909637451172 ; fasttext:  0.15687037 ; w2v:  0.063319
sailing   54.56751251220703 ;

KeyError: "word '?' not in vocabulary"

In [69]:
for s in base_samples:
    do_find(s, 0.5)

Input:  what's the local weather? forecast
['what', "'s", 'Ġthe', 'Ġlocal', '<mask>', 'Ġforecast']
weather   58.944068908691406 ; fasttext:  0.99999976 ; w2v:  0.9999998
snow   55.3741455078125 ; fasttext:  0.49174437 ; w2v:  0.5175463
rain   54.681884765625 ; fasttext:  0.5836918 ; w2v:  0.5686597
storm   53.713104248046875 ; fasttext:  0.49685508 ; w2v:  0.5032116
winter   53.38654708862305 ; fasttext:  0.55396944 ; w2v:  0.54917294
forecast   53.28466796875 ; fasttext:  0.5524446 ; w2v:  0.3627208
temperatures   52.94639205932617 ; fasttext:  0.52381134 ; w2v:  0.5156578
climate   52.55876922607422 ; fasttext:  0.52157056 ; w2v:  0.48332137
Weather   52.347877502441406 ; fasttext:  0.63699436 ; w2v:  0.6667285
storms   51.34629821777344 ; fasttext:  0.5273381 ; w2v:  0.5282772
rainy   51.1673698425293 ; fasttext:  0.58184874 ; w2v:  0.5066526
cold   50.968055725097656 ; fasttext:  0.5236649 ; w2v:  0.47726867
rains   50.079490661621094 ; fasttext:  0.5327639 ; w2v:  0.52753514
sunny

guest   60.62987518310547 ; fasttext:  1.0 ; w2v:  1.0000001
guests   54.34364700317383 ; fasttext:  0.6829934 ; w2v:  0.6226556
Guest   53.81261444091797 ; fasttext:  0.68728846 ; w2v:  0.6610845
Guest   47.012081146240234 ; fasttext:  0.68728846 ; w2v:  0.6610845
Original word position:  1 ; score:  60.62987518310547
Input:  turn the lights off in the guest bedroom?
['turn', 'Ġthe', 'Ġlights', 'Ġoff', 'Ġin', 'Ġthe', 'Ġguest', '<mask>']
room   62.96192169189453 ; fasttext:  0.60131574 ; w2v:  0.51099974
bedroom   61.55071258544922 ; fasttext:  1.0000002 ; w2v:  1.0000005
bathroom   60.759788513183594 ; fasttext:  0.6976226 ; w2v:  0.6174972
room   60.061004638671875 ; fasttext:  0.60131574 ; w2v:  0.51099974
bedrooms   59.70532989501953 ; fasttext:  0.761325 ; w2v:  0.78388095
bath   58.98833465576172 ; fasttext:  0.59171534 ; w2v:  0.54950047
house   58.139278411865234 ; fasttext:  0.56778806 ; w2v:  0.6496936
house   58.10982131958008 ; fasttext:  0.56778806 ; w2v:  0.6496936
bed   

In [70]:
for a in articles:
    do_find(a, 0.5)

Input:  how to use a sharpening steel?
['how', 'Ġto', 'Ġuse', 'Ġa', 'Ġsharp', 'ening', '<mask>']
iron   54.96471405029297 ; fasttext:  0.5903258 ; w2v:  0.44558045
steel   52.06224060058594 ; fasttext:  0.9999999 ; w2v:  0.99999976
metal   50.21138381958008 ; fasttext:  0.6613991 ; w2v:  0.63669586
plastic   48.058616638183594 ; fasttext:  0.5261128 ; w2v:  0.40624353
aluminum   47.98210525512695 ; fasttext:  0.70130825 ; w2v:  0.72432727
alloy   47.36555862426758 ; fasttext:  0.5877134 ; w2v:  0.5962639
copper   47.35080337524414 ; fasttext:  0.551749 ; w2v:  0.48506638
rubber   47.315185546875 ; fasttext:  0.50396174 ; w2v:  0.44299135
wood   46.582176208496094 ; fasttext:  0.5171267 ; w2v:  0.4854269
cement   46.39852523803711 ; fasttext:  0.46326405 ; w2v:  0.51855695
chrome   46.024742126464844 ; fasttext:  0.50509864 ; w2v:  0.40925384
titanium   45.84963607788086 ; fasttext:  0.6194425 ; w2v:  0.58357024
ceramic   45.57798385620117 ; fasttext:  0.5117519 ; w2v:  0.41449147
brass

effects   56.85041809082031 ; fasttext:  0.75341296 ; w2v:  0.6430684
impact   55.53343200683594 ; fasttext:  0.63831997 ; w2v:  0.65279025
effect   54.533267974853516 ; fasttext:  1.0000001 ; w2v:  1.0000004
impacts   53.56023406982422 ; fasttext:  0.5320576 ; w2v:  0.5113333
affect   50.765506744384766 ; fasttext:  0.5557193 ; w2v:  0.4310482
Effect   47.64755630493164 ; fasttext:  0.53610176 ; w2v:  0.3906251
effects   46.42105484008789 ; fasttext:  0.75341296 ; w2v:  0.6430684
ripple   46.08658981323242 ; fasttext:  0.5440018 ; w2v:  0.31508175
adverse   44.356441497802734 ; fasttext:  0.5487535 ; w2v:  0.35791332
effect   43.71855163574219 ; fasttext:  1.0000001 ; w2v:  1.0000004
impact   43.21967315673828 ; fasttext:  0.63831997 ; w2v:  0.65279025
detrimental   42.203277587890625 ; fasttext:  0.5952004 ; w2v:  0.39297584
ffect   41.44746017456055 ; fasttext:  0.5003516 ; w2v:  0.21939628
Effect   40.840518951416016 ; fasttext:  0.53610176 ; w2v:  0.3906251
Original word position: