In [1]:
from transformers import AutoModelWithLMHead, AutoTokenizer
import torch

In [2]:
import re

def find_top(sentence, model, tokenizer, k):
    """Mask the word marked with a trailing '?' and print the model's guesses.

    The word immediately followed by a question mark in ``sentence`` is taken
    as the "original word". Every ``<word>?`` (or bare ``?``) occurrence is
    replaced by the tokenizer's mask token, the model is run once, and the
    predictions for the first mask position are printed.

    Args:
        sentence: Input text with one word marked by a trailing ``?``.
        model: Callable that takes the encoded ids and returns a sequence
            whose first element holds the per-position vocabulary logits.
        tokenizer: Object providing ``mask_token``, ``mask_token_id``,
            ``vocab_size``, ``tokenize``, ``encode`` and ``decode``.
        k: Number of top predictions to print.

    Returns:
        None. All results are written to stdout.
    """
    print("Input: ", sentence)

    # Raw strings keep the regex escapes from being treated as (invalid)
    # Python string escapes.
    smatch = re.search(r"(\w+)\?", sentence)
    target = smatch.group(1) if smatch else None

    sequence = re.sub(r"(\w+)?\?", tokenizer.mask_token, sentence)
    print(tokenizer.tokenize(sequence))

    input_ids = tokenizer.encode(sequence, return_tensors="pt")
    mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]
    if mask_token_index.numel() == 0:
        # Without a mask position there is nothing to rank; indexing the
        # empty top-k result below would raise IndexError.
        print("No mask token found in the input")
        print("===================")
        return

    # Inference only: no autograd graph is needed for ranking.
    with torch.no_grad():
        token_logits = model(input_ids)[0]
    mask_token_logits = token_logits[0, mask_token_index, :]

    # Never ask topk for more entries than the logits actually contain.
    n_ranked = min(tokenizer.vocab_size, mask_token_logits.size(-1))
    topk = torch.topk(mask_token_logits, n_ranked, dim=1)
    # Only the first mask position is reported.
    top_tokens = list(zip(topk.indices[0].tolist(), topk.values[0].tolist()))

    for token, value in top_tokens[:k]:
        print(tokenizer.decode([token]), " ", value)

    if target is not None:
        vec = tokenizer.encode(target, return_tensors="pt")[0]
        # Three ids is read as <start special> + one word token + <end special>.
        if len(vec) == 3:
            tk = vec[1].item()
            pos = None
            score = None

            for rank, (t, v) in enumerate(top_tokens):
                if t == tk:
                    pos = rank
                    score = v
                    break

            if pos is None:
                # Previously the loop counter was printed here, which reported
                # the last rank even though the token had not been found.
                print("Original word is not among the ranked tokens")
            else:
                print("Original word position: ", pos, "; score: ", score)
        elif len(vec) > 3:
            print("Original word is more then 1 token")
            print(tokenizer.tokenize(target))
        else:
            print("Original word wasn't found")
    print("===================")
    
def do_find(s):
    """Run find_top on ``s`` with the notebook-level model and tokenizer.

    Uses whatever ``model`` and ``tokenizer`` globals are bound when it is
    called, and always prints the ten best predictions.
    """
    find_top(sentence=s, model=model, tokenizer=tokenizer, k=10)
    
# Short command-style utterances; the word followed by "?" is the one that
# find_top masks out and tries to recover.
base_samples = [
    # weather
    "what's the local weather? forecast",
    "chance of rain? tomorrow in moscow",
    "chance of rain? tomorrow",
    # lighting
    "set the lights? on in the entire house",
    "set the lights on in the entire house?",
    "turn the lights off in the guest? bedroom",
    "turn the lights off in the guest bedroom?",
]

# Random wikiHow-style article titles; as with base_samples, the word
# followed by "?" is the one to be masked.
articles = [
    "how to use a sharpening steel?",
    "how to identify? and treat liver shunts in cats",
    "how to identify and treat? liver shunts in cats",
    "how to identify and treat liver? shunts in cats",
    "how to identify and treat liver shunt? in cats",
    "use enamel? paint on a stove",
    "use enamel paint? on a stove",
    "how to etch? wood",
    "escape a sinking? ship",
    "escape a sinking ship?",
    "how to choose a cat? bed",
]

In [3]:
# Bind the global tokenizer/model pair that do_find() reads to the
# "bert-large-uncased" checkpoint (the recorded output shows a download).
# NOTE(review): newer transformers releases reportedly deprecate
# AutoModelWithLMHead in favour of task-specific Auto classes - confirm
# against the installed version before upgrading.
tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased")
model = AutoModelWithLMHead.from_pretrained("bert-large-uncased")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1344997306.0, style=ProgressStyle(descr…




In [4]:
# Print the top predictions for each base sample using the tokenizer/model
# currently bound at notebook level (bert-large-uncased when run in order).
for s in base_samples:
    do_find(s)

Input:  what's the local weather? forecast
['what', "'", 's', 'the', 'local', '[MASK]', 'forecast']
weather   13.933197975158691
wind   9.824950218200684
storm   8.210029602050781
climate   7.648484230041504
tornado   7.5318193435668945
traffic   7.494179725646973
business   7.40327787399292
news   7.227436542510986
flood   7.180978298187256
heat   7.1706132888793945
Original word position:  0 ; score:  13.933197975158691
Input:  chance of rain? tomorrow in moscow
['chance', 'of', '[MASK]', 'tomorrow', 'in', 'moscow']
meeting   9.252975463867188
success   8.11396312713623
working   7.349595069885254
victory   7.223593235015869
landing   7.137297630310059
flying   6.959140777587891
winning   6.743556976318359
freedom   6.722757816314697
action   6.516355037689209
that   6.507873058319092
Original word position:  205 ; score:  3.3014800548553467
Input:  chance of rain? tomorrow
['chance', 'of', '[MASK]', 'tomorrow']
winning   8.949573516845703
success   7.771795272827148
returning   7.74

In [5]:
# Same evaluation over the article titles, still with the tokenizer/model
# currently bound at notebook level (bert-large-uncased when run in order).
for a in articles:
    do_find(a)

Input:  how to use a sharpening steel?
['how', 'to', 'use', 'a', 'sharpe', '##ning', '[MASK]']
.   26.010730743408203
;   24.789464950561523
?   19.370759963989258
!   16.240190505981445
|   11.801448822021484
...   10.177217483520508
,   9.878094673156738
:   9.558690071105957
-   9.409708976745605
।   7.7624125480651855
Original word position:  8281 ; score:  -3.2654049396514893
Input:  how to identify? and treat liver shunts in cats
['how', 'to', '[MASK]', 'and', 'treat', 'liver', 'shu', '##nts', 'in', 'cats']
identify   12.187110900878906
detect   11.675525665283203
prevent   11.412120819091797
find   11.294748306274414
make   11.01986312866211
locate   10.963879585266113
remove   10.463669776916504
prepare   10.146129608154297
measure   9.821012496948242
repair   9.662485122680664
Original word position:  0 ; score:  12.187110900878906
Input:  how to identify and treat? liver shunts in cats
['how', 'to', 'identify', 'and', '[MASK]', 'liver', 'shu', '##nts', 'in', 'cats']
treat   1

In [6]:
# Rebind the global tokenizer/model pair used by do_find() to the
# "roberta-large" checkpoint; the previously loaded pair is replaced.
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
model = AutoModelWithLMHead.from_pretrained("roberta-large")

In [7]:
# Re-run the base samples with the tokenizer/model currently bound at
# notebook level (roberta-large when run in order).
for s in base_samples:
    do_find(s)

Input:  what's the local weather? forecast
['what', "'s", 'Ġthe', 'Ġlocal', '<mask>', 'Ġforecast']
 weather   58.944068908691406
 temperature   55.825252532958984
 snow   55.3741455078125
 traffic   54.93745422363281
 rainfall   54.863487243652344
 rain   54.681884765625
 school   54.495479583740234
 weekend   54.44187545776367
 precipitation   54.293479919433594
 driving   54.13661193847656
Original word position:  0 ; score:  58.944068908691406
Input:  chance of rain? tomorrow in moscow
['chance', 'Ġof', '<mask>', 'Ġtomorrow', 'Ġin', 'Ġmos', 'cow']
 bulls   56.93559265136719
 action   56.55562210083008
 fireworks   56.390655517578125
 beef   56.36457443237305
 cattle   55.886962890625
 thousands   55.69841003417969
 events   55.575355529785156
 cows   55.51642608642578
 Beef   55.48601150512695
 wrestling   55.44188690185547
Original word position:  63 ; score:  54.1391487121582
Input:  chance of rain? tomorrow
['chance', 'Ġof', '<mask>', 'Ġtomorrow']
 rain   59.147621154785156
 snow

In [8]:
# Re-run the article titles with the tokenizer/model currently bound at
# notebook level (roberta-large when run in order).
for a in articles:
    do_find(a)

Input:  how to use a sharpening steel?
['how', 'Ġto', 'Ġuse', 'Ġa', 'Ġsharp', 'ening', '<mask>']
 tool   55.64176940917969
 knife   55.36676788330078
 pencil   55.1890869140625
 iron   54.96471405029297
 stone   54.86848449707031
 wheel   54.341529846191406
 pen   53.8603401184082
 device   53.61422348022461
 brush   53.129844665527344
 blade   52.88595199584961
Original word position:  16 ; score:  52.06224060058594
Input:  how to identify? and treat liver shunts in cats
['how', 'Ġto', '<mask>', 'Ġand', 'Ġtreat', 'Ġliver', 'Ġshun', 'ts', 'Ġin', 'Ġcats']
 diagnose   61.9696159362793
 spot   60.435420989990234
 prevent   60.22199249267578
 identify   59.7053337097168
 detect   59.45398712158203
 find   59.16080856323242
 recognize   58.65917205810547
 monitor   57.187862396240234
 catch   57.06170654296875
 locate   56.8363151550293
Original word position:  3 ; score:  59.7053337097168
Input:  how to identify and treat? liver shunts in cats
['how', 'Ġto', 'Ġidentify', 'Ġand', '<mask>', 

In [9]:
# Rebind the global tokenizer/model pair used by do_find() to the
# "flaubert-large-cased" checkpoint (the recorded output shows downloads).
# NOTE(review): FlauBERT is presumably a French-language model, so the
# English samples below may be out of domain for it - confirm.
tokenizer = AutoTokenizer.from_pretrained("flaubert-large-cased")
model = AutoModelWithLMHead.from_pretrained("flaubert-large-cased")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1030.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1561415.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=895731.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1493194721.0, style=ProgressStyle(descr…




In [10]:
# Re-run the base samples with the tokenizer/model currently bound at
# notebook level (flaubert-large-cased when run in order).
for s in base_samples:
    do_find(s)

Input:  what's the local weather? forecast
['wh', 'at', "'</w>", 's</w>', 'the</w>', 'local</w>', '<special1>', 'for', 'ec', 'ast</w>']
change   16.84689712524414
time   15.88947582244873
for   15.521727561950684
if   15.50067138671875
next   15.128596305847168
"   14.749134063720703
second   14.658276557922363
re   14.613921165466309
in   14.404990196228027
,   14.119577407836914
Original word is more then 1 token
['we', 'ather</w>']
Input:  chance of rain? tomorrow in moscow
['chance</w>', 'of</w>', '<special1>', 'tom', 'or', 'row</w>', 'in</w>', 'mos', 'co', 'w</w>']
for   17.209707260131836
to   16.50912857055664
in   16.4116268157959
bury   16.273530960083008
end   16.26807403564453
be   16.241287231445312
play   16.057323455810547
even   15.000792503356934
three   14.698225021362305
the   14.340210914611816
Original word position:  86 ; score:  10.30911922454834
Input:  chance of rain? tomorrow
['chance</w>', 'of</w>', '<special1>', 'tom', 'or', 'row</w>']
for   17.60697364807129

In [11]:
# Re-run the article titles with the tokenizer/model currently bound at
# notebook level (flaubert-large-cased when run in order).
for a in articles:
    do_find(a)

Input:  how to use a sharpening steel?
['how</w>', 'to</w>', 'use</w>', 'a</w>', 'shar', 'pen', 'ing</w>', '<special1>']
for   15.021530151367188
body   12.83782958984375
,   12.724711418151855
fruit   11.820119857788086
to   11.552226066589355
of   11.362137794494629
ver   11.32625961303711
in   11.12248706817627
and   11.064708709716797
support   11.035231590270996
Original word is more then 1 token
['ste', 'el</w>']
Input:  how to identify? and treat liver shunts in cats
['how</w>', 'to</w>', '<special1>', 'and</w>', 'tre', 'at</w>', 'li', 'ver</w>', 'sh', 'un', 'ts</w>', 'in</w>', 'cats</w>']
sex   17.300159454345703
use   14.001256942749023
cure   13.8759765625
lave   13.508421897888184
girls   12.985353469848633
and   12.805550575256348
fine   12.574772834777832
div   12.547651290893555
stream   12.534297943115234
pr   12.533191680908203
Original word is more then 1 token
['identi', 'fy</w>']
Input:  how to identify and treat? liver shunts in cats
['how</w>', 'to</w>', 'identi', 