In [1]:
import torch
from torch import nn
from torch.utils import data
from torch.cuda.amp import autocast, GradScaler

import numpy as np
from tqdm.notebook import tqdm

import sys
import datetime

sys.path.append('../code')
from dataset import get_data, MaskedDataset, make_vocab, read_json

from transformers import (
    AdamW, get_linear_schedule_with_warmup
)

from models import MaskedRDModel

In [2]:
# Load the 'bert-base-uncased' tokenizer through the torch.hub entry point of
# huggingface/pytorch-transformers (cached locally after the first download).
tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-uncased')

Using cache found in /home/ubuntu/.cache/torch/hub/huggingface_pytorch-transformers_master


In [3]:
# Load all word/definition splits from the baseline data directory.
# NOTE(review): word2vec=False presumably skips loading word2vec vectors — confirm in dataset.get_data.
d = get_data('../wantwords-english-baseline/data', word2vec=False)

Loading data...
Training data: 675715 word-def pairs
Dev data: 75873 word-def pairs
Test data: 1200 word-def pairs


In [4]:
# Unpack the six splits returned by get_data, in the order it returns them.
train_data, train_data_def, dev_data, test_data_seen, \
    test_data_unseen, test_data_desc = d

In [5]:
# mask_size: fixed length that every target word's token sequence is padded/truncated to.
mask_size = 5
# Build the target vocabulary: the per-target token matrix plus the word<->index lookups.
target_matrix, target2idx, idx2target = make_vocab(d, tokenizer, mask_size=mask_size)

In [6]:
# Load the pretrained 'bert-base-uncased' weights into MaskedRDModel, then pass the
# task settings (mask length, multilabel mode, target-vocabulary size, pos_weight)
# to its initialize() method.
model = MaskedRDModel.from_pretrained('bert-base-uncased')
model.initialize(mask_size=mask_size, multilabel=True, ww_vocab_size=len(target2idx), pos_weight=10)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing MaskedRDModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing MaskedRDModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MaskedRDModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
# target2idx maps a target word to its integer index; idx2target is the inverse lookup.
# target_matrix maps a target index to that word's tokenizer id sequence,
# padded/truncated to mask_size (in the recorded output the padding id is 103).
target2idx['book'], target_matrix[target2idx['book']], idx2target[target2idx['book']]

(16187, tensor([2338,  103,  103,  103,  103]), 'book')

In [8]:
# WordNet data loaded from JSON, and the relation categories handed to MaskedDataset.
wn_data = read_json('../data/wn_data.json')
wn_categories = ['synonyms', 'hyponyms', 'hypernyms', 'related_forms']

In [9]:
# Training set: the concatenation of train_data and train_data_def.
train_dataset = MaskedDataset(train_data + train_data_def, tokenizer, target2idx, wn_data=wn_data, wn_categories=wn_categories, mask_size=mask_size)

In [10]:
# Keyword arguments shared by every evaluation split; only the data differs.
eval_dataset_kwargs = dict(wn_data=wn_data, wn_categories=wn_categories, mask_size=mask_size)

# Validation split plus the three test splits (seen words, unseen words, descriptions).
dev_dataset = MaskedDataset(dev_data, tokenizer, target2idx, **eval_dataset_kwargs)
test_dataset_seen = MaskedDataset(test_data_seen, tokenizer, target2idx, **eval_dataset_kwargs)
test_dataset_unseen = MaskedDataset(test_data_unseen, tokenizer, target2idx, **eval_dataset_kwargs)
test_dataset_desc = MaskedDataset(test_data_desc, tokenizer, target2idx, **eval_dataset_kwargs)

In [11]:
# Spot-check one dev example: decode the indices stored in the sparse last field
# (presumably the WordNet-related targets — confirm in MaskedDataset) and the
# target index in the second field back into words.
index = 1593

[idx2target[x] for x in dev_dataset[index][-1].coalesce().indices().squeeze(0)], idx2target[dev_dataset[index][1]]

(['classic',
  'authorized',
  'importance',
  'authoritative',
  'classical',
  'definitive',
  'important'],
 'authoritative')

In [12]:
batch_size = 32
num_workers = 0

# Settings shared by every DataLoader below; batches are assembled by the
# dataset's own collate_fn.
loader_params = {
    'pin_memory': False,
    'batch_size': batch_size,
    'num_workers': num_workers,
    'collate_fn': train_dataset.collate_fn
}

# Only the training data is shuffled. Validation and test loaders keep a fixed
# order so that evaluation passes are deterministic and comparable across epochs.
train_loader = data.DataLoader(train_dataset, shuffle=True, **loader_params)
dev_loader = data.DataLoader(dev_dataset, shuffle=False, **loader_params)
test_loader_seen = data.DataLoader(test_dataset_seen, shuffle=False, **loader_params)
test_loader_unseen = data.DataLoader(test_dataset_unseen, shuffle=False, **loader_params)
test_loader_desc = data.DataLoader(test_dataset_desc, shuffle=False, **loader_params)

In [13]:
# Starting from epoch 2
# weight_gt = 10
# epochs = 9
# lr = 1e-5 * 0.905
# optim = AdamW(model.parameters(), lr=lr)
# scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=1, 
#                                             num_training_steps=len(train_loader) * epochs)
# epoch = 0
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
# Optimisation setup: AdamW with a linear warmup/decay learning-rate schedule.
epochs = 10

lr = 3e-5
optim = AdamW(model.parameters(), lr=lr)

warmup_duration = 0.05 # portion of the first epoch spent on lr warmup
# Warmup lasts warmup_duration * len(train_loader) steps (a float; it is not rounded here);
# the schedule's total horizon is `epochs` full passes over train_loader.
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=len(train_loader) * warmup_duration, 
                                            num_training_steps=len(train_loader) * epochs)

# Index of the first epoch to run; the training loop iterates starting from this value.
epoch = 0

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# AMP gradient scaler. The scaler.* calls in the training loop are commented out,
# so it is created here but currently unused.
scaler = GradScaler()

In [15]:
# Passed as `weight_gt` to every model(...) call in the train, validation and test loops.
weight_gt = 25

In [16]:
# Experiment tracking: start a Weights & Biases run and record the hyperparameters
# chosen in the cells above.
import wandb

wandb.init(project='reverse-dictionary', entity='reverse-dict')

config = wandb.config
config.learning_rate = lr
config.epochs = epochs
config.batch_size = batch_size
# Only the class names of the optimizer and scheduler are stored.
config.optimizer = type(optim).__name__
config.scheduler = type(scheduler).__name__
# config.warmup_duration = warmup_duration

# wandb.watch(model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mreverse-dict[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.29 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [17]:
# Move the target token matrix to the compute device; it is passed to every model(...) call.
target_matrix = target_matrix.to(device)

In [18]:
# Move the model parameters to the same device as the inputs and target_matrix.
model = model.to(device)

In [19]:
def evaluate(pred, gt, test=False):
    """Score ranked predictions against ground-truth target indices.

    Args:
        pred: iterable of 1-D tensors; each row lists candidate target indices
            ordered from most to least likely.
        gt: iterable of ground-truth target indices, aligned with ``pred``.
        test: when True, return raw hit counts plus each sample's rank;
            when False, return hit rates normalised by ``len(pred)``.

    Returns:
        test=True:  ``(acc1, acc10, acc100, pred_rank)`` where the first three
            are integer hit counts and ``pred_rank`` is a float32 tensor with
            the 0-based rank of each ground truth, capped at 1000. 1000 is also
            recorded when the ground truth does not occur in the row.
        test=False: ``(acc1 / n, acc10 / n, acc100 / n)``; all zeros when
            ``pred`` is empty.
    """
    acc1 = acc10 = acc100 = 0
    n = len(pred)
    pred_rank = []
    for p, word in zip(pred, gt):
        # nonzero(as_tuple=True) always yields a 1-tuple for a 1-D row, so the
        # "not found" case has to be detected on the index tensor itself.
        hits = (p == word).nonzero(as_tuple=True)[0]
        rank = int(hits[0]) if hits.numel() > 0 else None
        if test:
            pred_rank.append(1000 if rank is None else min(rank, 1000))
        # Hits are nested: top-1 implies top-10 implies top-100.
        if rank is not None and rank < 100:
            acc100 += 1
            if rank < 10:
                acc10 += 1
                if rank == 0:
                    acc1 += 1
    if test:
        pred_rank = torch.tensor(pred_rank, dtype=torch.float32)
        return (acc1, acc10, acc100, pred_rank)
    if n == 0:
        # Nothing to score: report zero rates instead of dividing by zero.
        return 0.0, 0.0, 0.0
    return acc1/n, acc10/n, acc100/n

In [20]:
def test(loader, name, log=False):
    """Evaluate the global ``model`` on ``loader`` and report summary metrics.

    Puts the model in eval mode, runs every batch under ``torch.no_grad()``,
    sorts all target scores in descending order and scores the resulting
    ranking with ``evaluate(..., test=True)``.

    Args:
        loader: iterable yielding ``(x, y, wn_ids)`` batches; ``wn_ids`` must
            support ``to_dense()``.
        name: prefix used for the printed lines and the returned metric keys.
        log: accepted but not used by this function.

    Returns:
        dict with keys ``{name}_test_loss``, ``{name}_test_acc1``,
        ``{name}_test_acc10``, ``{name}_test_acc100``,
        ``{name}_test_rank_median`` and ``{name}_test_rank_variance``.
        The last entry holds the square root of ``torch.var`` over the
        per-sample ranks, i.e. a standard deviation rather than a variance.
    """
    refresh_every = 3  # batches between progress-bar description updates
    model.eval()
    loss_sum = 0.0
    hits1 = hits10 = hits100 = 0.0
    samples_seen = 0
    ranks = []
    with torch.no_grad(), tqdm(total=len(loader)) as pbar:
        for step, (x, y, wn_ids) in enumerate(loader):
            if step != 0 and step % refresh_every == 0:
                display_loss = loss_sum / step
                pbar.set_description(f'Test Loss: {display_loss}')

            x = x.to(device)
            y = y.to(device)
            # Attend to every position that is not padding.
            attention_mask = (x != train_dataset.pad_id)
            wn_ids = wn_ids.to_dense().to(device).float()

            loss, out = model(input_ids=x, attention_mask=attention_mask,
                              target_matrix=target_matrix, ground_truth=y,
                              wn_ids=wn_ids, weight_gt=weight_gt)

            loss_sum += loss.detach()
            pbar.update(1)

            # Full descending sort so that every target receives a rank.
            _, ranking = torch.sort(out, descending=True)

            batch_hits1, batch_hits10, batch_hits100, batch_ranks = evaluate(ranking, y, test=True)
            hits1 += batch_hits1
            hits10 += batch_hits10
            hits100 += batch_hits100
            samples_seen += len(x)
            ranks.extend(batch_ranks)

            # Drop references to the batch tensors and periodically release cached GPU memory.
            del x, y, out, loss
            if step % 20 == 0:
                torch.cuda.empty_cache()

    # Loss is averaged per batch; accuracies are averaged per sample.
    test_loss = loss_sum / len(loader)
    test_acc1 = hits1 / samples_seen
    test_acc10 = hits10 / samples_seen
    test_acc100 = hits100 / samples_seen
    rank_tensor = torch.tensor(ranks)
    median = torch.median(rank_tensor)
    var = torch.var(rank_tensor) ** 0.5

    print(f'{name}_test_loss:', test_loss)
    print(f'{name}_test_acc1:', test_acc1)
    print(f'{name}_test_acc10:', test_acc10)
    print(f'{name}_test_acc100:', test_acc100)
    print(f'{name}_test_rank_median:', median)
    print(f'{name}_test_rank_variance', var)

    return {
        f'{name}_test_loss': test_loss,
        f'{name}_test_acc1': test_acc1,
        f'{name}_test_acc10': test_acc10,
        f'{name}_test_acc100': test_acc100,
        f'{name}_test_rank_median': median,
        f'{name}_test_rank_variance': var,
    }
    

In [None]:
inc = 10  # refresh the progress-bar description every `inc` batches
losses = []
print('Training beginning!')

# Restart from a lower learning rate. The optimizer already had a scheduler
# attached earlier in the notebook, which stored `initial_lr` on every param
# group; PyTorch schedulers take their base rate from `initial_lr` when it is
# present, so it must be overwritten together with `lr` or the new value is ignored.
for p in optim.param_groups:
    p['lr'] = 1e-5
    p['initial_lr'] = 1e-5

warmup_duration = 0.05 # portion of the first epoch spent on lr warmup
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=len(train_loader) * warmup_duration,
                                            num_training_steps=len(train_loader) * epochs)

# NOTE(review): the loop runs up to `epochs + 10` epochs, but the schedule above
# decays to zero after `epochs` epochs — confirm which horizon is intended.
for epoch in range(epoch, epochs + 10):
    # Training: one full pass over train_loader.
    model.train()
    train_loss = 0.0
    with tqdm(total=len(train_loader)) as pbar:
        for i, (x, y, wn_ids) in enumerate(train_loader):
            if i % inc == 0 and i != 0:
                display_loss = train_loss / i
                pbar.set_description(f'Epoch {epoch+1}, Train Loss: {display_loss}')

            optim.zero_grad()

            x = x.to(device)
            # Attend to every position that is not padding.
            attention_mask = (x != train_dataset.pad_id)
            y = y.to(device)
            wn_ids = wn_ids.to_dense().to(device).float()

            loss, out = model(input_ids=x, attention_mask=attention_mask,
                              target_matrix=target_matrix, ground_truth=y,
                              wn_ids=wn_ids, weight_gt=weight_gt)

            # AMP (scaler.scale / unscale_ / step / update) is disabled here;
            # a plain full-precision backward pass and optimizer step are used.
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5)
            optim.step()

            train_loss += loss.detach()

            scheduler.step()

            pbar.update(1)

            del x, y, out, loss, attention_mask

    # Checkpoint the whole model object after every epoch.
    model_name = type(model).__name__
    filename = f'../trained_models/{model_name} Epoch {epoch+1} at {datetime.datetime.now()}'.replace(' ', '_')
    with open(filename, 'wb+') as f:
        torch.save(model, f)

    # Validation
    model.eval()
    val_loss = 0.0
    val_acc1, val_acc10, val_acc100 = 0.0, 0.0, 0.0
    val_batches = 0  # batches fully processed; used to normalise the metrics
    try:
        with torch.no_grad():
            with tqdm(total=len(dev_loader)) as pbar:
                for i, (x, y, wn_ids) in enumerate(dev_loader):
                    if i % inc == 0 and i != 0:
                        display_loss = val_loss / i
                        pbar.set_description(f'Epoch {epoch+1}, Val Loss: {display_loss}')

                    x = x.to(device)
                    attention_mask = (x != train_dataset.pad_id)
                    y = y.to(device)
                    wn_ids = wn_ids.to_dense().to(device).float()

                    loss, out = model(input_ids=x, attention_mask=attention_mask,
                                      target_matrix=target_matrix, ground_truth=y,
                                      wn_ids=wn_ids, weight_gt=weight_gt)

                    val_loss += loss.detach()

                    pbar.update(1)

                    # Only the top 100 candidates are needed for acc@1/10/100.
                    result, indices = torch.topk(out, k=100, dim=-1, largest=True, sorted=True)

                    acc1, acc10, acc100 = evaluate(indices, y)
                    val_acc1 += acc1
                    val_acc10 += acc10
                    val_acc100 += acc100
                    val_batches += 1

                    del x, y, out, loss
    except Exception as err:
        # Validation is best-effort: report the failure and carry on with the
        # test passes. KeyboardInterrupt is deliberately not caught, so the
        # run can still be stopped by hand.
        print('Error encountered, aborting validation!')
        print(f'{type(err).__name__}: {err}')

    # Average over the batches that actually ran, so a partially completed
    # validation pass is not diluted by batches that were never evaluated.
    val_denominator = max(val_batches, 1)

    wandb.log({
        'train_loss': train_loss / len(train_loader),
        'val_loss': val_loss / val_denominator,
        'val_acc1': val_acc1 / val_denominator,
        'val_acc10': val_acc10 / val_denominator,
        'val_acc100': val_acc100 / val_denominator,
        **test(test_loader_seen, 'seen'),
        **test(test_loader_unseen, 'unseen'),
        **test(test_loader_desc, 'desc')
    })
    

Training beginning!


  0%|          | 0/21117 [00:00<?, ?it/s]

In [28]:
def getPredFromDesc(model, desc : str, mask_size=5, top_n=10):
    """Return the highest-scoring target words for a free-text description.

    The input sequence is built as [CLS] + ``mask_size`` x [MASK] + [SEP] +
    description tokens, using the special-token ids stored on ``train_dataset``.

    Args:
        model: model called without ``ground_truth``; it must return one row
            of scores over the target vocabulary.
        desc: description text, tokenised with the global ``tokenizer``.
        mask_size: number of [MASK] positions placed before the description.
        top_n: number of candidates to return; clamped to the number of
            available targets.

    Returns:
        ``(words, indices, probs)``: the candidate words, their target indices
        (1-D tensor) and the sigmoid of their scores (1-D tensor), all ordered
        from best to worst.
    """
    desc = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(desc))
    cls_id, mask_id, sep_id, pad_id = train_dataset.cls_id, train_dataset.mask_id, train_dataset.sep_id, train_dataset.pad_id
    desc_ids = [cls_id] + [mask_id] * mask_size + [sep_id] + desc
    x = torch.tensor(desc_ids).unsqueeze(0).to(device)
    attention_mask = (x != pad_id)
    # Pure inference: skip autograd bookkeeping so that no graph is kept alive.
    with torch.no_grad():
        out = model(input_ids=x, attention_mask=attention_mask, target_matrix=target_matrix)
        # topk raises if k exceeds the number of scores, so clamp the request.
        k = min(top_n, out.size(-1))
        result, indices = torch.topk(out, k=k, dim=-1, largest=True, sorted=True)
        probs = torch.sigmoid(result).squeeze(0)

    indices = indices[0]
    return [idx2target[i] for i in indices], indices, probs
    

In [95]:
# Restore a whole model object written by the training loop (torch.save(model, f)).
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints that come from a trusted source.
model = torch.load('../trained_models/MaskedRDModel_Epoch_1_at_2021-05-06_16:31:58.058170')

In [152]:
# Switch to evaluation mode; the trailing `None` keeps the cell from echoing
# the value returned by model.eval().
model.eval()
None

In [None]:
# Hand-written sanity-check queries as (description, expected word) pairs;
# None is used where no single expected word is given.
queries = [
    ("employee at a circus", "clown"),
    ("a type of tree", None), # type
    ("the opposite of being happy", None), # type
    ("a road on which cars can go fast", "highway"),
    ("a very intelligent person", "genius"),
    ("a very smart person", "genius"),
    ("something you use to measure your temperature", "thermometer"),
    ("a dark time of day", "night"),
    ("medieval social hierarchy where peasants and vassals served lords", "feudalism"),
    ("very cute", "adorable"),
    # NOTE(review): ("") is a bare empty string, not a (description, expected)
    # tuple — it looks like an unfinished entry.
    ("")
]

In [96]:
# Query from the `queries` list (expected word: "clown"); list the top-100 words with their scores.
words, idx, probs = getPredFromDesc(model, 'employee at a circus', 5, 100)
[(word, prob.item()) for word, prob in zip(words, probs)]

[('circus', 0.994208812713623),
 ('clown', 0.9391021728515625),
 ('clowning', 0.7502662539482117),
 ('stampede', 0.6639121174812317),
 ('busker', 0.6302651762962341),
 ('barker', 0.6152220368385315),
 ('somebody', 0.6016197800636292),
 ('showman', 0.591995358467102),
 ('jockey', 0.5682175159454346),
 ('ringer', 0.5543749928474426),
 ('ride', 0.512154221534729),
 ('scouter', 0.5062361359596252),
 ('performer', 0.491067111492157),
 ('handler', 0.4897576868534088),
 ('bulldozer', 0.45466405153274536),
 ('bullfighter', 0.4532380700111389),
 ('valet', 0.44002220034599304),
 ('bullhorn', 0.42381030321121216),
 ('theatergoer', 0.4092405438423157),
 ('exhibitor', 0.38998040556907654),
 ('gypsy', 0.3890130817890167),
 ('clownish', 0.3867891728878021),
 ('rider', 0.38242051005363464),
 ('crowder', 0.3685600459575653),
 ('flapper', 0.35859018564224243),
 ('comedian', 0.35737788677215576),
 ('equestrian', 0.3531621992588043),
 ('show', 0.34194135665893555),
 ('pincher', 0.3402089476585388),
 ('man

In [103]:
# Query from the `queries` list (no single expected word); list the top-100 words with their scores.
words, idx, probs = getPredFromDesc(model, 'a type of tree', 5, 100)
[(word, prob.item()) for word, prob in zip(words, probs)]

[('linden', 0.8909744024276733),
 ('tree', 0.8873805403709412),
 ('stocked', 0.8616869449615479),
 ('lime', 0.8452991247177124),
 ('teakwood', 0.8133090138435364),
 ('chestnut', 0.7823934555053711),
 ('maple', 0.7769330143928528),
 ('plum', 0.7707264423370361),
 ('wood', 0.7617564797401428),
 ('shrub', 0.7613511681556702),
 ('hop', 0.760560154914856),
 ('cottonwood', 0.7386062145233154),
 ('cypress', 0.7183142900466919),
 ('oak', 0.7154815793037415),
 ('spruce', 0.6588332056999207),
 ('pinewood', 0.6537127494812012),
 ('olive', 0.6344695687294006),
 ('brazilwood', 0.6344084739685059),
 ('ebony', 0.6217503547668457),
 ('boxwood', 0.6095532178878784),
 ('mahogany', 0.5943478941917419),
 ('treed', 0.5928893685340881),
 ('coffee', 0.5916092395782471),
 ('barking', 0.5861811637878418),
 ('tea', 0.5833937525749207),
 ('evergreen', 0.5778738260269165),
 ('hardwood', 0.5618163347244263),
 ('eucalyptus', 0.5394592881202698),
 ('cork', 0.5355424284934998),
 ('oakleaf', 0.53113853931427),
 ('plum

In [203]:
# Longer rewording of the "road on which cars can go fast" query (expected there: "highway").
words, idx, probs = getPredFromDesc(model, 'a road on which cars can go quickly without stopping', 5, 100)
[(word, prob.item()) for word, prob in zip(words, probs)]

[('superhighway', 0.9432564377784729),
 ('freeway', 0.831933319568634),
 ('expressway', 0.8186946511268616),
 ('road', 0.7954754829406738),
 ('highway', 0.7796876430511475),
 ('speed', 0.7761233448982239),
 ('haste', 0.7701517343521118),
 ('move', 0.7681809067726135),
 ('race', 0.7569608092308044),
 ('speedway', 0.7484490275382996),
 ('travel', 0.7474307417869568),
 ('drive', 0.7294769287109375),
 ('parkway', 0.7202438712120056),
 ('beltway', 0.7067565321922302),
 ('roadster', 0.7067355513572693),
 ('hurry', 0.691335916519165),
 ('motor', 0.6897876858711243),
 ('maneuver', 0.6633190512657166),
 ('zip', 0.6555811762809753),
 ('shoot', 0.6417109966278076),
 ('street', 0.6287841796875),
 ('driveway', 0.6164090037345886),
 ('runway', 0.6096982955932617),
 ('taxiway', 0.6056353449821472),
 ('route', 0.6033530235290527),
 ('straightway', 0.6019049286842346),
 ('lanes', 0.600609540939331),
 ('roads', 0.5971937775611877),
 ('flyway', 0.5926243662834167),
 ('chase', 0.5788533687591553),
 ('auto

In [193]:
# Query from the `queries` list (expected word: "genius"); list the top-100 words with their scores.
words, idx, probs = getPredFromDesc(model, 'a very intelligent person', 5, 100)
[(word, prob.item()) for word, prob in zip(words, probs)]

[('genius', 0.9899966716766357),
 ('brains', 0.9687216281890869),
 ('brain', 0.9671897888183594),
 ('intellect', 0.9648787975311279),
 ('einstein', 0.9598327279090881),
 ('brainstem', 0.9438502788543701),
 ('intelligence', 0.9433559775352478),
 ('brilliant', 0.9343078136444092),
 ('intelligent', 0.9325506091117859),
 ('brainstorming', 0.9114322066307068),
 ('brainy', 0.8861953020095825),
 ('brainpower', 0.8857043981552124),
 ('somebody', 0.858609676361084),
 ('intellectual', 0.852242648601532),
 ('brainiac', 0.84941565990448),
 ('braincase', 0.8284827470779419),
 ('smart', 0.8238560557365417),
 ('person', 0.8067717552185059),
 ('brilliance', 0.8000427484512329),
 ('wit', 0.7959452867507935),
 ('expert', 0.7734217643737793),
 ('brainchild', 0.7488033175468445),
 ('cypher', 0.7206608653068542),
 ('smartness', 0.7182350158691406),
 ('psychic', 0.7109935283660889),
 ('spark', 0.7096468806266785),
 ('brainwash', 0.6831423044204712),
 ('soul', 0.6703240275382996),
 ('expertness', 0.650392055

In [106]:
# Same expected word as the previous query ("genius") with "smart" in place of "intelligent",
# to see how sensitive the ranking is to the wording.
words, idx, probs = getPredFromDesc(model, 'a very smart person', 5, 100)
[(word, prob.item()) for word, prob in zip(words, probs)]

[('smart', 0.992626965045929),
 ('smarty', 0.9777527451515198),
 ('smartness', 0.976097583770752),
 ('smarting', 0.9745392799377441),
 ('smarts', 0.9617751836776733),
 ('shoot', 0.9205918908119202),
 ('hothead', 0.9074836373329163),
 ('hotdog', 0.9038079977035522),
 ('flash', 0.871448814868927),
 ('hotshot', 0.8312650322914124),
 ('hotchpotch', 0.7926220893859863),
 ('shot', 0.7819179892539978),
 ('flasher', 0.7720305323600769),
 ('sharpshooter', 0.7085450291633606),
 ('smartly', 0.7082189321517944),
 ('dart', 0.7061830163002014),
 ('darts', 0.7058963775634766),
 ('flashy', 0.688761830329895),
 ('blast', 0.6872258186340332),
 ('sharp', 0.6670240759849548),
 ('hot', 0.6640612483024597),
 ('hotpot', 0.5960732102394104),
 ('ache', 0.5757908225059509),
 ('hotfoot', 0.5692092776298523),
 ('genius', 0.5527622103691101),
 ('darter', 0.5455978512763977),
 ('hunger', 0.540242075920105),
 ('torpedo', 0.5259228348731995),
 ('blaster', 0.5232740044593811),
 ('bullet', 0.5100688934326172),
 ('burn'

In [197]:
# Query from the `queries` list (no single expected word); list the top-100 words with their scores.
words, idx, probs = getPredFromDesc(model, 'the opposite of being happy', 5, 100)
[(word, prob.item()) for word, prob in zip(words, probs)]

[('unhappiness', 0.9564859867095947),
 ('happiness', 0.9440613389015198),
 ('sadness', 0.9340919256210327),
 ('misbehavior', 0.926218569278717),
 ('misanthropy', 0.9002848267555237),
 ('unhappy', 0.8268434405326843),
 ('misery', 0.771747350692749),
 ('happy', 0.7376412749290466),
 ('misanthrope', 0.7179532051086426),
 ('sorrow', 0.6878830194473267),
 ('sorrowing', 0.6854086518287659),
 ('aggravation', 0.6803992390632629),
 ('loneliness', 0.6780133247375488),
 ('discontent', 0.5803792476654053),
 ('complacency', 0.5688847899436951),
 ('sadomasochism', 0.5650791525840759),
 ('lethargic', 0.5404712557792664),
 ('misfortunate', 0.5287728905677795),
 ('sorrowful', 0.5277866721153259),
 ('discoloration', 0.5162428617477417),
 ('miscarry', 0.5112969279289246),
 ('complacence', 0.5023670196533203),
 ('joyousness', 0.4918251931667328),
 ('uneasiness', 0.48316457867622375),
 ('inactivity', 0.47422292828559875),
 ('selfishness', 0.4702393412590027),
 ('moodiness', 0.4695213735103607),
 ('disconce

In [196]:
# Query from the `queries` list (expected word: "thermometer"); list the top-100 words with their scores.
words, idx, probs = getPredFromDesc(model, 'something you use to measure your temperature', 5, 100)
[(word, prob.item()) for word, prob in zip(words, probs)]

[('thermometer', 0.9870902895927429),
 ('calorimeter', 0.9651489853858948),
 ('temperature', 0.9647976756095886),
 ('heat', 0.8819329738616943),
 ('speedometer', 0.8643578886985779),
 ('pyrometer', 0.8432449698448181),
 ('fahrenheit', 0.8431944847106934),
 ('magnetometer', 0.7803845405578613),
 ('heater', 0.7791886329650879),
 ('refrigerate', 0.7783988118171692),
 ('seismometer', 0.7512844204902649),
 ('gasometer', 0.6834492683410645),
 ('isotherm', 0.6559324860572815),
 ('hygrometer', 0.6509665846824646),
 ('densitometer', 0.644676685333252),
 ('barometer', 0.6405194401741028),
 ('refractor', 0.6307488679885864),
 ('heating', 0.6181195974349976),
 ('aether', 0.5048888921737671),
 ('gage', 0.5039142966270447),
 ('interferometer', 0.4987775385379791),
 ('rheometer', 0.49652099609375),
 ('cold', 0.4597109258174896),
 ('quicklime', 0.44842100143432617),
 ('spirometer', 0.4475388824939728),
 ('tachometer', 0.44367653131484985),
 ('refrigerant', 0.4419322609901428),
 ('isoelectric', 0.43211

In [102]:
# Query from the `queries` list (expected word: "night"); list the top-100 words with their scores.
words, idx, probs = getPredFromDesc(model, 'a dark time of day', 5, 100)
[(word, prob.item()) for word, prob in zip(words, probs)]

[('dusk', 0.9902498126029968),
 ('night', 0.9843398332595825),
 ('twilight', 0.983363151550293),
 ('dau', 0.970775306224823),
 ('darkness', 0.968502938747406),
 ('nightfall', 0.9646450281143188),
 ('day', 0.9296711683273315),
 ('nightly', 0.9182392954826355),
 ('nighttime', 0.9147279858589172),
 ('hour', 0.8999171257019043),
 ('moon', 0.8985822796821594),
 ('nightgown', 0.8807250261306763),
 ('da', 0.8677366375923157),
 ('dusky', 0.8567270636558533),
 ('eve', 0.8546727299690247),
 ('tide', 0.8473078608512878),
 ('daylight', 0.8041820526123047),
 ('period', 0.7880057692527771),
 ('evening', 0.7875092029571533),
 ('forenoon', 0.7712737917900085),
 ('none', 0.7663474678993225),
 ('dak', 0.7646946907043457),
 ('daytime', 0.7641853094100952),
 ('tomorrow', 0.7562503814697266),
 ('midnight', 0.7519458532333374),
 ('morrow', 0.7462226152420044),
 ('dace', 0.7396567463874817),
 ('dah', 0.7387454509735107),
 ('daunt', 0.7087877988815308),
 ('morning', 0.6945293545722961),
 ('afternoon', 0.69389

In [281]:
# Query from the `queries` list (expected word: "feudalism"); list the top-100 words with their scores.
words, idx, probs = getPredFromDesc(model, 'medieval social hierarchy where peasants and vassals served lords', 5, 100)
[(word, prob.item()) for word, prob in zip(words, probs)]

[('oligarchy', 0.9340060353279114),
 ('aristocracy', 0.9111836552619934),
 ('nobility', 0.8552895188331604),
 ('government', 0.8252381682395935),
 ('order', 0.8251005411148071),
 ('feudalism', 0.8089253902435303),
 ('celibacy', 0.7787441611289978),
 ('homage', 0.7786471843719482),
 ('dukedom', 0.761428952217102),
 ('familly', 0.7329322099685669),
 ('hierarchy', 0.7318239808082581),
 ('autocracy', 0.7199180722236633),
 ('peerage', 0.7192471027374268),
 ('lords', 0.7177726626396179),
 ('noblesse', 0.70628422498703),
 ('commune', 0.6935475468635559),
 ('obedience', 0.6930590271949768),
 ('lordship', 0.6775382161140442),
 ('court', 0.65690678358078),
 ('bureaucracy', 0.6550346612930298),
 ('principality', 0.6500912308692932),
 ('domain', 0.6483148336410522),
 ('fiefdom', 0.636286735534668),
 ('feudal', 0.6148019433021545),
 ('inquisition', 0.6092568039894104),
 ('chapter', 0.6073670983314514),
 ('heterosis', 0.6024513840675354),
 ('regency', 0.5921065807342529),
 ('lord', 0.585498154163360

In [254]:
# Ad-hoc query (not in the `queries` list): a two-word technical term rather than a definition.
words, idx, probs = getPredFromDesc(model, 'deep learning', 5, 100)
[(word, prob.item()) for word, prob in zip(words, probs)]

[('study', 0.9601989388465881),
 ('learning', 0.910089910030365),
 ('deep', 0.9032861590385437),
 ('indoctrination', 0.8831983208656311),
 ('education', 0.8816927075386047),
 ('learn', 0.8657744526863098),
 ('catechumen', 0.848861575126648),
 ('deepness', 0.8403034210205078),
 ('analysis', 0.8245221972465515),
 ('science', 0.8152496814727783),
 ('catechism', 0.8136394619941711),
 ('instructive', 0.8025436401367188),
 ('drill', 0.8013206124305725),
 ('lyceum', 0.8002776503562927),
 ('graphology', 0.7923668026924133),
 ('depth', 0.7837069630622864),
 ('concentration', 0.7835208773612976),
 ('deepening', 0.7703989744186401),
 ('instruct', 0.7556796669960022),
 ('training', 0.7527759075164795),
 ('instruction', 0.7248058915138245),
 ('teach', 0.7106432914733887),
 ('exercise', 0.7100815773010254),
 ('catechol', 0.6653569936752319),
 ('profound', 0.6645328402519226),
 ('analyst', 0.6405267715454102),
 ('studying', 0.6402386426925659),
 ('cryptology', 0.635059118270874),
 ('seminary', 0.6283

In [204]:
# Ad-hoc query (not in the `queries` list): a long, conversational description.
words, idx, probs = getPredFromDesc(model, 'when somebody gives something to you and afterwards you have it', 5, 100)
[(word, prob.item()) for word, prob in zip(words, probs)]

[('give', 0.9956091046333313),
 ('gift', 0.9803848266601562),
 ('giving', 0.9600266218185425),
 ('take', 0.9423472285270691),
 ('giver', 0.9335689544677734),
 ('treat', 0.9269536733627319),
 ('deal', 0.9230650067329407),
 ('donate', 0.9195953607559204),
 ('have', 0.9164474010467529),
 ('use', 0.8981136679649353),
 ('gifting', 0.8838661313056946),
 ('receive', 0.8785077929496765),
 ('administer', 0.876017153263092),
 ('delivery', 0.8690775632858276),
 ('drink', 0.8679701685905457),
 ('change', 0.8615391850471497),
 ('present', 0.8581041097640991),
 ('handing', 0.8442950248718262),
 ('make', 0.8227622509002686),
 ('distribute', 0.8207599520683289),
 ('supply', 0.8196119666099548),
 ('deliver', 0.8173937797546387),
 ('initiate', 0.7992436289787292),
 ('accept', 0.7902272939682007),
 ('offer', 0.7816919684410095),
 ('share', 0.7715913653373718),
 ('get', 0.7598951458930969),
 ('payment', 0.7582589387893677),
 ('treats', 0.7529175281524658),
 ('communicate', 0.7516933679580688),
 ('provide'

In [207]:
# Ad-hoc query (not in the `queries` list): a situation described without naming the concept.
words, idx, probs = getPredFromDesc(model, 'when someone you trust does something that breaks your trust', 5, 100)
[(word, prob.item()) for word, prob in zip(words, probs)]

[('trust', 0.9927996397018433),
 ('trusty', 0.9561113119125366),
 ('betrayal', 0.9532467722892761),
 ('cheat', 0.8850262761116028),
 ('betray', 0.8843281865119934),
 ('trustworthy', 0.871504545211792),
 ('faith', 0.8544620871543884),
 ('gamble', 0.8512130975723267),
 ('deception', 0.8467381596565247),
 ('compromised', 0.8375467658042908),
 ('distrust', 0.8366779685020447),
 ('cheated', 0.8210803866386414),
 ('venture', 0.8131632208824158),
 ('trusting', 0.7965644001960754),
 ('fuck', 0.7916842699050903),
 ('trustor', 0.7729158997535706),
 ('trustful', 0.7722510695457458),
 ('job', 0.7710670232772827),
 ('confidence', 0.7629819512367249),
 ('manipulate', 0.7540766596794128),
 ('misconduct', 0.7507216334342957),
 ('fraud', 0.7440340518951416),
 ('false', 0.7437525987625122),
 ('pretend', 0.7307678461074829),
 ('judas', 0.7197650074958801),
 ('confide', 0.7180337309837341),
 ('believe', 0.7108274698257446),
 ('bets', 0.7078320384025574),
 ('pawn', 0.7063281536102295),
 ('rely', 0.70381033

In [210]:
# Ad-hoc query (not in the `queries` list): a verb described by its purpose.
words, idx, probs = getPredFromDesc(model, 'to help someone else learn', 5, 100)
[(word, prob.item()) for word, prob in zip(words, probs)]

[('teach', 0.9946306943893433),
 ('educate', 0.9796843528747559),
 ('learn', 0.9754385948181152),
 ('tutor', 0.9596310257911682),
 ('train', 0.956894040107727),
 ('prepare', 0.9506540298461914),
 ('education', 0.9429842829704285),
 ('training', 0.927601158618927),
 ('trainer', 0.9044510722160339),
 ('instruction', 0.8917723298072815),
 ('trainee', 0.8876506090164185),
 ('develop', 0.8800643682479858),
 ('help', 0.8763608932495117),
 ('drill', 0.8746880292892456),
 ('instruct', 0.8685868978500366),
 ('tutoring', 0.8648293018341064),
 ('aid', 0.8603531122207642),
 ('inform', 0.8338075280189514),
 ('preparation', 0.8322434425354004),
 ('coach', 0.8235621452331543),
 ('groom', 0.8183148503303528),
 ('trained', 0.8138930797576904),
 ('instructive', 0.8135185837745667),
 ('improve', 0.8075883984565735),
 ('school', 0.8013655543327332),
 ('mentor', 0.798274576663971),
 ('support', 0.7881101369857788),
 ('tutorship', 0.7843430638313293),
 ('doctor', 0.7829017043113708),
 ('initiate', 0.7770995

In [279]:
# Ad-hoc query (not in the `queries` list): a verb whose description contains its own stem.
words, idx, probs = getPredFromDesc(model, 'to make something normal', 5, 100)
[(word, prob.item()) for word, prob in zip(words, probs)]

[('normalize', 0.9985143542289734),
 ('normalization', 0.9941756725311279),
 ('normal', 0.9729980230331421),
 ('regularize', 0.9526902437210083),
 ('modify', 0.9470736384391785),
 ('naturalise', 0.9366539120674133),
 ('change', 0.9286817908287048),
 ('amend', 0.9210201501846313),
 ('naturalize', 0.9168502688407898),
 ('secularize', 0.8821982741355896),
 ('alter', 0.8583032488822937),
 ('civilize', 0.8575515747070312),
 ('normalcy', 0.8489350080490112),
 ('regularization', 0.8364447951316833),
 ('modernize', 0.8240464329719543),
 ('acclimatize', 0.8122901916503906),
 ('normale', 0.7942690849304199),
 ('moralise', 0.7790758013725281),
 ('sublimate', 0.768255352973938),
 ('assimilate', 0.7592564225196838),
 ('standardize', 0.750598132610321),
 ('naturalization', 0.736865222454071),
 ('normality', 0.7347490191459656),
 ('specialize', 0.7294057607650757),
 ('improve', 0.682942807674408),
 ('rationalize', 0.6667715907096863),
 ('secularization', 0.6553975939750671),
 ('acclimate', 0.63347911

In [133]:
# Archived console output from earlier training runs, pasted into a bare string
# literal so it is kept with the notebook without being executed.
# The trailing `None` makes the cell's last expression None, so Jupyter does not
# echo the whole string back as the cell output.
#
# Reading notes (taken from the log text itself):
# - "Training beginning!" appears twice: before epoch 1 and again after a partial
#   epoch 7 (4%), i.e. presumably a restart -- TODO confirm whether it resumed
#   from a checkpoint.
# - The epoch 4 validation/test lines, all epoch 5 lines and the epoch 6 train
#   line are absent; only "IOPub message rate exceeded" notices appear there, so
#   that output was presumably dropped by the notebook server.
# - The epoch 10 train line stops at 84% and no epoch 10 val line is present; the
#   last metrics block is presumably the epoch 10 evaluation -- verify.
# - Among the logged epochs, val loss is lowest at epoch 3 (1403.36) and rises
#   afterwards while train loss keeps falling.
'''
Training beginning!

Epoch 1, Train Loss: 1291.271240234375: 100%
21117/21117 [1:51:20<00:00, 2.56it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Epoch 1, Val Loss: 1454.054443359375: 100%
2372/2372 [04:24<00:00, 8.54it/s]
Test Loss: 687.0896606445312: 100%
16/16 [00:01<00:00, 9.53it/s]

seen_test_loss: tensor(672.2285, device='cuda:0')
seen_test_acc1: 0.162
seen_test_acc10: 0.512
seen_test_acc100: 0.764
seen_test_rank_median: tensor(8.)
seen_test_rank_variance tensor(310.8351)

Test Loss: 987.5115356445312: 100%
16/16 [00:01<00:00, 11.16it/s]

unseen_test_loss: tensor(980.4160, device='cuda:0')
unseen_test_acc1: 0.112
unseen_test_acc10: 0.33
unseen_test_acc100: 0.59
unseen_test_rank_median: tensor(47.)
unseen_test_rank_variance tensor(401.8403)

Test Loss: 1049.1865234375: 100%
7/7 [00:00<00:00, 14.02it/s]

desc_test_loss: tensor(980.0026, device='cuda:0')
desc_test_acc1: 0.275
desc_test_acc10: 0.725
desc_test_acc100: 0.925
desc_test_rank_median: tensor(3.)
desc_test_rank_variance tensor(70.2536)

Epoch 2, Train Loss: 879.2266235351562: 100%
21117/21117 [1:51:54<00:00, 3.18it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Epoch 2, Val Loss: 1426.1895751953125: 100%
2372/2372 [04:24<00:00, 10.73it/s]
Test Loss: 488.4723205566406: 100%
16/16 [00:01<00:00, 9.56it/s]

seen_test_loss: tensor(472.8967, device='cuda:0')
seen_test_acc1: 0.214
seen_test_acc10: 0.698
seen_test_acc100: 0.894
seen_test_rank_median: tensor(3.)
seen_test_rank_variance tensor(215.4854)

Test Loss: 947.445556640625: 100%
16/16 [00:01<00:00, 11.18it/s]

unseen_test_loss: tensor(940.1313, device='cuda:0')
unseen_test_acc1: 0.102
unseen_test_acc10: 0.362
unseen_test_acc100: 0.636
unseen_test_rank_median: tensor(30.)
unseen_test_rank_variance tensor(380.8903)

Test Loss: 944.8580322265625: 100%
7/7 [00:00<00:00, 13.99it/s]

desc_test_loss: tensor(878.2668, device='cuda:0')
desc_test_acc1: 0.23
desc_test_acc10: 0.715
desc_test_acc100: 0.91
desc_test_rank_median: tensor(3.)
desc_test_rank_variance tensor(125.0875)

Epoch 3, Train Loss: 686.5811157226562: 100%
21117/21117 [1:52:03<00:00, 2.94it/s]
Epoch 3, Val Loss: 1403.363525390625: 100%
2372/2372 [04:23<00:00, 9.02it/s]
Test Loss: 431.358642578125: 100%
16/16 [00:01<00:00, 9.43it/s]

seen_test_loss: tensor(417.4239, device='cuda:0')
seen_test_acc1: 0.296
seen_test_acc10: 0.772
seen_test_acc100: 0.92
seen_test_rank_median: tensor(2.)
seen_test_rank_variance tensor(197.4330)

Test Loss: 913.62841796875: 100%
16/16 [00:01<00:00, 11.11it/s]

unseen_test_loss: tensor(902.2078, device='cuda:0')
unseen_test_acc1: 0.112
unseen_test_acc10: 0.382
unseen_test_acc100: 0.654
unseen_test_rank_median: tensor(22.)
unseen_test_rank_variance tensor(364.8770)

Test Loss: 940.71826171875: 100%
7/7 [00:00<00:00, 14.08it/s]

desc_test_loss: tensor(870.6017, device='cuda:0')
desc_test_acc1: 0.26
desc_test_acc10: 0.675
desc_test_acc100: 0.905
desc_test_rank_median: tensor(4.)
desc_test_rank_variance tensor(185.5576)

Epoch 4, Train Loss: 560.3533935546875: 100%
21117/21117 [1:51:45<00:00, 3.08it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Epoch 6, Val Loss: 1457.1527099609375: 100%
2372/2372 [04:23<00:00, 8.24it/s]
Test Loss: 349.925048828125: 100%
16/16 [00:01<00:00, 9.44it/s]

seen_test_loss: tensor(335.5642, device='cuda:0')
seen_test_acc1: 0.362
seen_test_acc10: 0.86
seen_test_acc100: 0.95
seen_test_rank_median: tensor(1.)
seen_test_rank_variance tensor(190.8087)

Test Loss: 951.2265014648438: 100%
16/16 [00:01<00:00, 11.15it/s]

unseen_test_loss: tensor(932.3807, device='cuda:0')
unseen_test_acc1: 0.112
unseen_test_acc10: 0.434
unseen_test_acc100: 0.696
unseen_test_rank_median: tensor(16.)
unseen_test_rank_variance tensor(358.9769)

Test Loss: 1029.6146240234375: 100%
7/7 [00:00<00:00, 14.20it/s]

desc_test_loss: tensor(948.0621, device='cuda:0')
desc_test_acc1: 0.215
desc_test_acc10: 0.63
desc_test_acc100: 0.87
desc_test_rank_median: tensor(5.)
desc_test_rank_variance tensor(235.2347)

Epoch 7, Train Loss: 355.93682861328125: 4%
902/21117 [04:48<2:05:44, 2.68it/s]

Training beginning!

Epoch 7, Train Loss: 354.0775146484375: 100%
21117/21117 [1:51:25<00:00, 3.13it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

wandb: 500 encountered ({"errors":[{"message":"Error 1040: Too many connections","path":["project"]}],"data":{"project":null}}), retrying request
wandb: Network error resolved after 0:00:56.213574, resuming normal operation.

Epoch 7, Val Loss: 1486.6312255859375: 100%
2372/2372 [04:23<00:00, 8.76it/s]
Test Loss: 338.9664611816406: 100%
16/16 [00:01<00:00, 9.56it/s]

seen_test_loss: tensor(324.8030, device='cuda:0')
seen_test_acc1: 0.378
seen_test_acc10: 0.866
seen_test_acc100: 0.948
seen_test_rank_median: tensor(1.)
seen_test_rank_variance tensor(174.6115)

Test Loss: 961.43212890625: 100%
16/16 [00:01<00:00, 11.15it/s]

unseen_test_loss: tensor(949.6530, device='cuda:0')
unseen_test_acc1: 0.11
unseen_test_acc10: 0.45
unseen_test_acc100: 0.71
unseen_test_rank_median: tensor(14.)
unseen_test_rank_variance tensor(355.9951)

Test Loss: 1055.6123046875: 100%
7/7 [00:00<00:00, 14.06it/s]

desc_test_loss: tensor(986.2967, device='cuda:0')
desc_test_acc1: 0.225
desc_test_acc10: 0.625
desc_test_acc100: 0.885
desc_test_rank_median: tensor(4.)
desc_test_rank_variance tensor(241.3334)

Epoch 8, Train Loss: 317.3240051269531: 100%
21117/21117 [1:51:37<00:00, 3.36it/s]
Epoch 8, Val Loss: 1510.5469970703125: 100%
2372/2372 [04:21<00:00, 8.90it/s]
Test Loss: 326.8212585449219: 100%
16/16 [00:01<00:00, 9.56it/s]

seen_test_loss: tensor(312.5106, device='cuda:0')
seen_test_acc1: 0.388
seen_test_acc10: 0.876
seen_test_acc100: 0.95
seen_test_rank_median: tensor(1.)
seen_test_rank_variance tensor(169.8056)

Test Loss: 977.635498046875: 100%
16/16 [00:01<00:00, 11.21it/s]

unseen_test_loss: tensor(968.0921, device='cuda:0')
unseen_test_acc1: 0.12
unseen_test_acc10: 0.448
unseen_test_acc100: 0.714
unseen_test_rank_median: tensor(14.)
unseen_test_rank_variance tensor(355.5157)

Test Loss: 1053.174072265625: 100%
7/7 [00:00<00:00, 14.09it/s]

desc_test_loss: tensor(979.2820, device='cuda:0')
desc_test_acc1: 0.21
desc_test_acc10: 0.63
desc_test_acc100: 0.88
desc_test_rank_median: tensor(4.)
desc_test_rank_variance tensor(249.2989)

Epoch 9, Train Loss: 290.9953308105469: 100%
21117/21117 [1:51:28<00:00, 2.83it/s]
Epoch 9, Val Loss: 1540.5982666015625: 100%
2372/2372 [04:21<00:00, 9.60it/s]
Test Loss: 324.8987121582031: 100%
16/16 [00:01<00:00, 9.66it/s]

seen_test_loss: tensor(310.3402, device='cuda:0')
seen_test_acc1: 0.4
seen_test_acc10: 0.888
seen_test_acc100: 0.95
seen_test_rank_median: tensor(1.)
seen_test_rank_variance tensor(176.8679)

Test Loss: 988.954345703125: 100%
16/16 [00:01<00:00, 11.23it/s]

unseen_test_loss: tensor(977.1419, device='cuda:0')
unseen_test_acc1: 0.114
unseen_test_acc10: 0.454
unseen_test_acc100: 0.722
unseen_test_rank_median: tensor(13.)
unseen_test_rank_variance tensor(359.3176)

Test Loss: 1111.536376953125: 100%
7/7 [00:00<00:00, 14.21it/s]

desc_test_loss: tensor(1031.1094, device='cuda:0')
desc_test_acc1: 0.17
desc_test_acc10: 0.62
desc_test_acc100: 0.88
desc_test_rank_median: tensor(4.)
desc_test_rank_variance tensor(260.9289)

Epoch 10, Train Loss: 273.5261535644531: 84%
17674/21117 [1:33:12<17:46, 3.23it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Test Loss: 324.5528869628906: 100%
16/16 [00:01<00:00, 9.67it/s]

seen_test_loss: tensor(309.9997, device='cuda:0')
seen_test_acc1: 0.384
seen_test_acc10: 0.892
seen_test_acc100: 0.95
seen_test_rank_median: tensor(1.)
seen_test_rank_variance tensor(175.3214)

Test Loss: 991.7783813476562: 100%
16/16 [00:01<00:00, 11.24it/s]

unseen_test_loss: tensor(979.3575, device='cuda:0')
unseen_test_acc1: 0.124
unseen_test_acc10: 0.444
unseen_test_acc100: 0.726
unseen_test_rank_median: tensor(13.)
unseen_test_rank_variance tensor(358.8594)

Test Loss: 1110.7630615234375: 100%
7/7 [00:00<00:00, 14.23it/s]

desc_test_loss: tensor(1029.1157, device='cuda:0')
desc_test_acc1: 0.205
desc_test_acc10: 0.64
desc_test_acc100: 0.875
desc_test_rank_median: tensor(5.)
desc_test_rank_variance tensor(262.6233)

'''
None

In [58]:
from transformers import pipeline

In [59]:
# Fill-mask pipeline used for the quick masked-LM sanity check in the next cell.
# The checkpoint is named explicitly instead of relying on the library default:
# pipeline('fill-mask') with no model loads whatever default the installed
# transformers version ships (distilroberta-base for this task), and the prompt
# below is written with RoBERTa's `<mask>` token, which a BERT checkpoint
# (mask token `[MASK]`) would not recognise. Pinning keeps the cell reproducible.
m = pipeline('fill-mask', model='distilroberta-base')

In [63]:
# Ask the fill-mask pipeline for its 100 highest-scoring fillers for the masked
# subject. Each result in the recorded output is a dict with 'sequence',
# 'score', 'token' and 'token_str'.
# NOTE(review): the literal `<mask>` has to match the mask token of whichever
# checkpoint `m` loaded -- it does in the recorded output; re-check if the
# pipeline's model ever changes.
m("<mask> is a programmer", top_k=100)

[{'sequence': 'He is a programmer',
  'score': 0.017467252910137177,
  'token': 894,
  'token_str': 'He'},
 {'sequence': 'Daniel is a programmer',
  'score': 0.008494171313941479,
  'token': 18322,
  'token_str': 'Daniel'},
 {'sequence': ' who is a programmer',
  'score': 0.006517456378787756,
  'token': 54,
  'token_str': ' who'},
 {'sequence': 'Who is a programmer',
  'score': 0.006088791415095329,
  'token': 12375,
  'token_str': 'Who'},
 {'sequence': 'David is a programmer',
  'score': 0.005976406391710043,
  'token': 8773,
  'token_str': 'David'},
 {'sequence': 'James is a programmer',
  'score': 0.005649610422551632,
  'token': 18031,
  'token_str': 'James'},
 {'sequence': 'Craig is a programmer',
  'score': 0.005570814944803715,
  'token': 39230,
  'token_str': 'Craig'},
 {'sequence': 'Smith is a programmer',
  'score': 0.005305597558617592,
  'token': 14124,
  'token_str': 'Smith'},
 {'sequence': 'Cook is a programmer',
  'score': 0.0049401517026126385,
  'token': 32963,
  'tok

In [239]:
# Display one raw example from the `test_data_unseen` split; in the recorded
# output this is the entry for 'oust' (definition: 'remove and replace').
# NOTE(review): 442 is a bare magic index, presumably hand-picked -- confirm why
# this example matters or look it up by word instead.
test_data_unseen[442]

{'word': 'oust',
 'lexnames': ['verb.social'],
 'root_affix': [],
 'sememes': ['dismiss', 'expel'],
 'definitions': 'remove and replace'}

In [266]:
# Linear scan of the training set: print every entry whose headword is exactly
# 'bert' (the recorded output shows a single match, a given-name definition).
for entry in train_data:
    if entry['word'] == 'bert':
        print(entry)

{'word': 'bert', 'lexnames': [], 'root_affix': [], 'sememes': [], 'definitions': 'a diminutive form of male given names containing the element bert also used as a formal given name'}


{'synonyms': [],
 'antonyms': [],
 'related_forms': [],
 'hyponyms': [],
 'hypernyms': []}