In [1]:
import torch
from torch import nn
from torch.utils import data
from torch.cuda.amp import autocast, GradScaler

import numpy as np
from tqdm.notebook import tqdm

import sys
import datetime

sys.path.append('../code')
from dataset import get_data, MaskedDataset, make_vocab, read_json

from transformers import (
    AdamW, get_linear_schedule_with_warmup
)

from models import MaskedRDModel

In [2]:
# Fetch the pretrained uncased BERT tokenizer through torch.hub
# (served from the local hub cache when one is present).
tokenizer = torch.hub.load(
    'huggingface/pytorch-transformers',
    'tokenizer',
    'bert-base-uncased',
)

Using cache found in /home/ubuntu/.cache/torch/hub/huggingface_pytorch-transformers_master


In [3]:
# Load every word-definition split in one call (word2vec=False).
d = get_data(
    '../wantwords-english-baseline/data',
    word2vec=False,
)

Loading data...
Training data: 675715 word-def pairs
Dev data: 75873 word-def pairs
Test data: 1200 word-def pairs


In [4]:
# Unpack the six splits returned by get_data, in the order it provides them.
(
    train_data,
    train_data_def,
    dev_data,
    test_data_seen,
    test_data_unseen,
    test_data_desc,
) = d

In [6]:
# Number of [MASK] slots reserved for a target word; target token sequences
# are padded/truncated to this length.
mask_size = 5
target_matrix, target2idx, idx2target = make_vocab(
    d,
    tokenizer,
    mask_size=mask_size,
)

In [7]:
# Start from the pretrained BERT weights, then configure the reverse-dictionary
# head for the target vocabulary built above.
model = MaskedRDModel.from_pretrained('bert-base-uncased')
model.initialize(
    mask_size=mask_size,
    multilabel=True,
    ww_vocab_size=len(target2idx),
    pos_weight=5,
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing MaskedRDModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing MaskedRDModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MaskedRDModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
# Sanity check of the vocabulary structures for the word 'book':
#   target2idx    : target word  -> target index
#   target_matrix : target index -> BPE id sequence, padded/truncated to mask_size
#   idx2target    : target index -> target word
(
    target2idx['book'],
    target_matrix[target2idx['book']],
    idx2target[target2idx['book']],
)

(16187, tensor([2338,  103,  103,  103,  103]), 'book')

In [9]:
# WordNet relations per word, and the relation categories passed to the datasets.
wn_data = read_json('../data/wn_data.json')
wn_categories = [
    'synonyms',
    'hyponyms',
    'hypernyms',
    'related_forms',
]

In [10]:
# Training set: both training splits concatenated.
train_dataset = MaskedDataset(
    train_data + train_data_def,
    tokenizer,
    target2idx,
    wn_data=wn_data,
    wn_categories=wn_categories,
    mask_size=mask_size,
)

In [11]:
# Build the dev set and the three test sets with identical dataset settings;
# the generator is consumed in order, so construction order matches the names.
dev_dataset, test_dataset_seen, test_dataset_unseen, test_dataset_desc = (
    MaskedDataset(
        split,
        tokenizer,
        target2idx,
        wn_data=wn_data,
        wn_categories=wn_categories,
        mask_size=mask_size,
    )
    for split in (dev_data, test_data_seen, test_data_unseen, test_data_desc)
)

In [12]:
# Inspect one dev example: the words named by the indices of its last (sparse)
# field, followed by the word for its second field.
index = 1593

(
    [
        idx2target[x]
        for x in dev_dataset[index][-1].coalesce().indices().squeeze(0)
    ],
    idx2target[dev_dataset[index][1]],
)

(['classic',
  'authorized',
  'importance',
  'authoritative',
  'classical',
  'definitive',
  'important'],
 'authoritative')

In [13]:
batch_size = 40
num_workers = 0

# Settings shared by every DataLoader; only `shuffle` differs per loader.
loader_params = {
    'pin_memory': False,
    'batch_size': batch_size,
    'num_workers': num_workers,
    'collate_fn': train_dataset.collate_fn
}

# Only the training loader is shuffled. Evaluation loaders keep a fixed order so
# that batch composition (and therefore the batch-averaged validation metrics)
# is identical from run to run.
train_loader = data.DataLoader(train_dataset, shuffle=True, **loader_params)
dev_loader = data.DataLoader(dev_dataset, shuffle=False, **loader_params)
test_loader_seen = data.DataLoader(test_dataset_seen, shuffle=False, **loader_params)
test_loader_unseen = data.DataLoader(test_dataset_unseen, shuffle=False, **loader_params)
test_loader_desc = data.DataLoader(test_dataset_desc, shuffle=False, **loader_params)

In [14]:
# Starting from epoch 2
# weight_gt = 10
# epochs = 9
# lr = 1e-5 * 0.905
# optim = AdamW(model.parameters(), lr=lr)
# scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=1, 
#                                             num_training_steps=len(train_loader) * epochs)
# epoch = 0
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [15]:
epochs = 10

lr = 1e-5
optim = AdamW(model.parameters(), lr=lr)

warmup_duration = 0.05 # portion of the first epoch spent on lr warmup
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=len(train_loader) * warmup_duration, 
                                            num_training_steps=len(train_loader) * epochs)

# First epoch to run; the training cell iterates range(epoch, epochs).
epoch = 0

# Passed as `weight_gt=` to every model(...) call in the evaluation and training
# cells. It was previously assigned only inside commented-out code, so a fresh
# kernel failed with NameError.
# NOTE(review): 10 is the value from the commented-out resume settings in the
# previous cell -- confirm it is the value intended for a run from scratch.
weight_gt = 10

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Created for mixed-precision training; the scaler calls in the training loop
# are currently commented out, so it is not used yet.
scaler = GradScaler()

In [16]:
# Weights & Biases experiment tracking: start a run and record the
# hyperparameters chosen in the cells above.
import wandb

wandb.init(project='reverse-dictionary', entity='reverse-dict')

# Each attribute assigned on wandb.config is stored with the run.
config = wandb.config
config.learning_rate = lr
config.epochs = epochs
config.batch_size = batch_size
# Class names only, not the full optimizer/scheduler state.
config.optimizer = type(optim).__name__
config.scheduler = type(scheduler).__name__
# config.warmup_duration = warmup_duration

# wandb.watch(model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mreverse-dict[0m (use `wandb login --relogin` to force relogin)


In [17]:
# Keep the target token matrix on the same device as the model inputs.
target_matrix = target_matrix.to(device=device)

In [18]:
# Move the model parameters to the selected device.
model = model.to(device=device)

In [19]:
def evaluate(pred, gt, test=False):
    """Score ranked predictions against ground-truth target indices.

    Args:
        pred: one ranking per example (a 2-D tensor or an iterable of 1-D
            tensors) holding candidate target indices ordered best-first.
        gt: ground-truth target indices, aligned with ``pred``.
        test: when True, additionally compute the rank of every ground-truth
            target and return raw hit counts instead of fractions.

    Returns:
        ``test=False``: ``(acc1, acc10, acc100)`` as fractions of ``len(pred)``
        (all ``0.0`` when ``pred`` is empty).
        ``test=True``: ``(acc1, acc10, acc100, pred_rank)`` where the first
        three are hit counts and ``pred_rank`` is a float32 tensor with the
        0-based rank of each target, capped at 1000. A target that does not
        occur in its ranking is assigned rank 1000.
    """
    acc1 = acc10 = acc100 = 0
    n = len(pred)
    pred_rank = []
    for p, word in zip(pred, gt):
        if test:
            # nonzero(as_tuple=True) returns one index tensor per dimension, so
            # the number of matches is the size of that tensor -- the length of
            # the tuple itself is always the number of dimensions.
            positions = (p == word).nonzero(as_tuple=True)[-1]
            if positions.numel() != 0:
                pred_rank.append(min(int(positions[0]), 1000))
            else:
                pred_rank.append(1000)
        # The top-1 / top-10 / top-100 hits are nested, so test the widest first.
        if word in p[:100]:
            acc100 += 1
            if word in p[:10]:
                acc10 += 1
                if word == p[0]:
                    acc1 += 1
    if test:
        pred_rank = torch.tensor(pred_rank, dtype=torch.float32)
        return (acc1, acc10, acc100, pred_rank)
    if n == 0:
        # Nothing to score: avoid dividing by zero.
        return 0.0, 0.0, 0.0
    return acc1/n, acc10/n, acc100/n

In [20]:
def test(loader, name, log=False):
    """Evaluate the global ``model`` on ``loader`` and report its metrics.

    Uses the notebook globals ``model``, ``device``, ``target_matrix``,
    ``weight_gt``, ``train_dataset`` (for its pad id) and ``evaluate``.

    Args:
        loader: iterable of ``(x, y, wn_ids)`` batches.
        name: prefix used for the printed labels and the returned keys.
        log: not used by this function.

    Returns:
        dict with ``{name}_test_loss`` (summed batch losses divided by the
        number of batches), ``{name}_test_acc1/10/100`` (hits divided by the
        number of examples seen), ``{name}_test_rank_median`` and
        ``{name}_test_rank_variance``. The latter is the square root of
        ``torch.var`` over the ranks, i.e. a standard deviation despite its name.
    """
    refresh_every = 3  # batches between progress-bar description updates
    model.eval()
    loss_sum = 0.0
    hits1 = hits10 = hits100 = 0.0
    n_examples = 0
    ranks = []
    with torch.no_grad(), tqdm(total=len(loader)) as pbar:
        for step, (x, y, wn_ids) in enumerate(loader):
            if step != 0 and step % refresh_every == 0:
                running_loss = loss_sum / step
                pbar.set_description(f'Test Loss: {running_loss}')

            x = x.to(device)
            attention_mask = (x != train_dataset.pad_id)
            y = y.to(device)
            wn_ids = wn_ids.to_dense().to(device).float()

            loss, out = model(input_ids=x, attention_mask=attention_mask,
                              target_matrix=target_matrix, ground_truth=y,
                              wn_ids=wn_ids, weight_gt=weight_gt)

            loss_sum += loss.detach()

            pbar.update(1)

            # Full descending sort of the scores: every target gets a rank.
            _, indices = torch.sort(out, descending=True)

            batch_hits1, batch_hits10, batch_hits100, batch_ranks = evaluate(indices, y, test=True)
            hits1 += batch_hits1
            hits10 += batch_hits10
            hits100 += batch_hits100
            n_examples += len(x)
            ranks.extend(batch_ranks)

            del x, y, out, loss
            if step % 20 == 0:
                torch.cuda.empty_cache()

    loss_sum /= len(loader)
    hits1 /= n_examples
    hits10 /= n_examples
    hits100 /= n_examples
    ranks = torch.tensor(ranks)
    median = torch.median(ranks)
    var = torch.var(ranks)**0.5

    print(f'{name}_test_loss:', loss_sum)
    print(f'{name}_test_acc1:', hits1)
    print(f'{name}_test_acc10:', hits10)
    print(f'{name}_test_acc100:', hits100)
    print(f'{name}_test_rank_median:', median)
    print(f'{name}_test_rank_variance', var)

    return {
        f'{name}_test_loss': loss_sum,
        f'{name}_test_acc1': hits1,
        f'{name}_test_acc10': hits10,
        f'{name}_test_acc100': hits100,
        f'{name}_test_rank_median': median,
        f'{name}_test_rank_variance': var,
    }
    

In [21]:
# Main loop: one pass over the training data, a checkpoint, a validation pass
# and the three test evaluations per epoch, all logged to wandb.
inc = 10  # batches between progress-bar description updates
losses = []
print('Training beginning!')

# Starts at the current value of `epoch`, so the cell can resume a run.
for epoch in range(epoch, epochs):
    # Training (full pass over train_loader)
    model.train()
    train_loss = 0.0
    with tqdm(total=len(train_loader)) as pbar:
        for i, (x, y, wn_ids) in enumerate(train_loader):
            if i % inc == 0 and i != 0:
                display_loss = train_loss / i
                pbar.set_description(f'Epoch {epoch+1}, Train Loss: {display_loss}')

            optim.zero_grad()

            x = x.to(device)
            attention_mask = (x != train_dataset.pad_id)
            y = y.to(device)
            wn_ids = wn_ids.to_dense().to(device).float()
            
            loss, out = model(input_ids=x, attention_mask=attention_mask, 
                              target_matrix=target_matrix, ground_truth=y, 
                              wn_ids=wn_ids, weight_gt=weight_gt)
            
#             scaler.scale(loss).backward()
            loss.backward()
            
#             scaler.unscale_(optim)
            # Clip the global gradient norm to 5 before the optimizer step.
            nn.utils.clip_grad_norm_(model.parameters(), 5)
            
#             scaler.step(optim)
            optim.step()
#             scaler.update()
            
            train_loss += loss.detach()
            
            # The learning-rate schedule advances once per batch.
            scheduler.step()
            
            pbar.update(1)
            
            del x, y, out, loss, attention_mask
            
    # Checkpoint the whole model object after every epoch.
    model_name = type(model).__name__
    filename = f'../trained_models/{model_name} Epoch {epoch+1} at {datetime.datetime.now()}'.replace(' ', '_')
    with open(filename, 'wb+') as f:
        torch.save(model, f)
    
    # Validation
    model.eval()
    val_loss = 0.0
    val_acc1, val_acc10, val_acc100 = 0.0, 0.0, 0.0
    try:
        with torch.no_grad():
            with tqdm(total=len(dev_loader)) as pbar:
                for i, (x, y, wn_ids) in enumerate(dev_loader):
                    if i % inc == 0 and i != 0:
                        display_loss = val_loss / i
                        pbar.set_description(f'Epoch {epoch+1}, Val Loss: {display_loss}')

                    x = x.to(device)
                    attention_mask = (x != train_dataset.pad_id)
                    y = y.to(device)
                    wn_ids = wn_ids.to_dense().to(device).float()

    #                 with autocast():
                    loss, out = model(input_ids=x, attention_mask=attention_mask, 
                                      target_matrix=target_matrix, ground_truth=y,
                                      wn_ids=wn_ids, weight_gt=weight_gt)

                    val_loss += loss.detach()

                    pbar.update(1)                

                    # Only the 100 best candidates are needed for acc@1/10/100.
                    result, indices = torch.topk(out, k=100, dim=-1, largest=True, sorted=True)

                    # evaluate() returns per-batch fractions here; they are
                    # averaged over the number of batches when logged below.
                    acc1, acc10, acc100 = evaluate(indices, y)
                    val_acc1 += acc1
                    val_acc10 += acc10
                    val_acc100 += acc100

                    del x, y, out, loss
    except Exception as err:
        # Validation is best-effort and must not kill a long training run, but
        # the cause has to be visible. Catching Exception (not a bare except)
        # lets KeyboardInterrupt / SystemExit still stop the cell.
        print('Error encountered, aborting validation!')
        print(f'{type(err).__name__}: {err}')
    
    # If validation was aborted, the val_* values cover only the batches
    # processed before the error.
    wandb.log({
        'train_loss': train_loss / len(train_loader),
        'val_loss': val_loss / len(dev_loader),
        'val_acc1': val_acc1 / len(dev_loader),
        'val_acc10': val_acc10 / len(dev_loader),
        'val_acc100': val_acc100 / len(dev_loader),
        **test(test_loader_seen, 'seen'),
        **test(test_loader_unseen, 'unseen'),
        **test(test_loader_desc, 'desc')
    })
    

Training beginning!


  0%|          | 0/16893 [00:00<?, ?it/s]

  0%|          | 0/1897 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

seen_test_loss: tensor(345.4165, device='cuda:0')
seen_test_acc1: 0.126
seen_test_acc10: 0.384
seen_test_acc100: 0.598
seen_test_rank_median: tensor(38.)
seen_test_rank_variance tensor(402.9168)


  0%|          | 0/13 [00:00<?, ?it/s]

unseen_test_loss: tensor(429.6845, device='cuda:0')
unseen_test_acc1: 0.134
unseen_test_acc10: 0.308
unseen_test_acc100: 0.502
unseen_test_rank_median: tensor(88.)
unseen_test_rank_variance tensor(421.5570)


  0%|          | 0/5 [00:00<?, ?it/s]

desc_test_loss: tensor(603.5402, device='cuda:0')
desc_test_acc1: 0.25
desc_test_acc10: 0.74
desc_test_acc100: 0.965
desc_test_rank_median: tensor(2.)
desc_test_rank_variance tensor(43.4395)


  0%|          | 0/16893 [00:00<?, ?it/s]

  0%|          | 0/1897 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

seen_test_loss: tensor(302.7310, device='cuda:0')
seen_test_acc1: 0.15
seen_test_acc10: 0.46
seen_test_acc100: 0.698
seen_test_rank_median: tensor(12.)
seen_test_rank_variance tensor(349.5500)


  0%|          | 0/13 [00:00<?, ?it/s]

unseen_test_loss: tensor(404.0241, device='cuda:0')
unseen_test_acc1: 0.13
unseen_test_acc10: 0.326
unseen_test_acc100: 0.564
unseen_test_rank_median: tensor(52.)
unseen_test_rank_variance tensor(402.2183)


  0%|          | 0/5 [00:00<?, ?it/s]

desc_test_loss: tensor(566.8254, device='cuda:0')
desc_test_acc1: 0.275
desc_test_acc10: 0.715
desc_test_acc100: 0.95
desc_test_rank_median: tensor(2.)
desc_test_rank_variance tensor(64.9332)


  0%|          | 0/16893 [00:00<?, ?it/s]

  0%|          | 0/1897 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

seen_test_loss: tensor(273.3678, device='cuda:0')
seen_test_acc1: 0.166
seen_test_acc10: 0.53
seen_test_acc100: 0.758
seen_test_rank_median: tensor(7.)
seen_test_rank_variance tensor(316.4328)


  0%|          | 0/13 [00:00<?, ?it/s]

unseen_test_loss: tensor(388.8712, device='cuda:0')
unseen_test_acc1: 0.118
unseen_test_acc10: 0.352
unseen_test_acc100: 0.602
unseen_test_rank_median: tensor(37.)
unseen_test_rank_variance tensor(391.6554)


  0%|          | 0/5 [00:00<?, ?it/s]

desc_test_loss: tensor(538.9657, device='cuda:0')
desc_test_acc1: 0.265
desc_test_acc10: 0.765
desc_test_acc100: 0.94
desc_test_rank_median: tensor(3.)
desc_test_rank_variance tensor(86.7461)


  0%|          | 0/16893 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

  0%|          | 0/1897 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

  0%|          | 0/13 [00:00<?, ?it/s]

seen_test_loss: tensor(230.8112, device='cuda:0')
seen_test_acc1: 0.226
seen_test_acc10: 0.64
seen_test_acc100: 0.854
seen_test_rank_median: tensor(3.)
seen_test_rank_variance tensor(261.0100)


  0%|          | 0/13 [00:00<?, ?it/s]

unseen_test_loss: tensor(376.1676, device='cuda:0')
unseen_test_acc1: 0.112
unseen_test_acc10: 0.364
unseen_test_acc100: 0.648
unseen_test_rank_median: tensor(25.)
unseen_test_rank_variance tensor(384.5897)


  0%|          | 0/5 [00:00<?, ?it/s]

desc_test_loss: tensor(515.3072, device='cuda:0')
desc_test_acc1: 0.265
desc_test_acc10: 0.715
desc_test_acc100: 0.91
desc_test_rank_median: tensor(3.)
desc_test_rank_variance tensor(118.9035)


  0%|          | 0/16893 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



  0%|          | 0/1897 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

  0%|          | 0/1897 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

seen_test_loss: tensor(215.3256, device='cuda:0')
seen_test_acc1: 0.252
seen_test_acc10: 0.716
seen_test_acc100: 0.89
seen_test_rank_median: tensor(3.)
seen_test_rank_variance tensor(248.2053)


  0%|          | 0/13 [00:00<?, ?it/s]

unseen_test_loss: tensor(374.9537, device='cuda:0')
unseen_test_acc1: 0.136
unseen_test_acc10: 0.38
unseen_test_acc100: 0.664
unseen_test_rank_median: tensor(24.)
unseen_test_rank_variance tensor(379.3346)


  0%|          | 0/5 [00:00<?, ?it/s]

desc_test_loss: tensor(518.9099, device='cuda:0')
desc_test_acc1: 0.3
desc_test_acc10: 0.665
desc_test_acc100: 0.885
desc_test_rank_median: tensor(3.)
desc_test_rank_variance tensor(145.4935)


  0%|          | 0/16893 [00:00<?, ?it/s]

  0%|          | 0/1897 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

seen_test_loss: tensor(214.9702, device='cuda:0')
seen_test_acc1: 0.25
seen_test_acc10: 0.716
seen_test_acc100: 0.894
seen_test_rank_median: tensor(3.)
seen_test_rank_variance tensor(247.2444)


  0%|          | 0/13 [00:00<?, ?it/s]

unseen_test_loss: tensor(375.8186, device='cuda:0')
unseen_test_acc1: 0.134
unseen_test_acc10: 0.384
unseen_test_acc100: 0.656
unseen_test_rank_median: tensor(23.)
unseen_test_rank_variance tensor(379.3580)


  0%|          | 0/5 [00:00<?, ?it/s]

desc_test_loss: tensor(518.6593, device='cuda:0')
desc_test_acc1: 0.305
desc_test_acc10: 0.665
desc_test_acc100: 0.895
desc_test_rank_median: tensor(3.)
desc_test_rank_variance tensor(149.9286)


In [31]:
def getPredFromDesc(model, desc : str, mask_size=5, top_n=10):
    """Return the ``top_n`` highest-scoring target words for a description.

    Uses the notebook globals ``tokenizer``, ``train_dataset`` (special token
    ids), ``device``, ``target_matrix`` and ``idx2target``.

    Args:
        model: the model to query; it is called without a ground truth.
        desc: free-text description of the word being looked for.
        mask_size: number of mask tokens placed before the description.
        top_n: number of candidates to return.

    Returns:
        ``(words, indices, scores)``: the candidate words best-first, their
        target indices (tensor) and the sigmoid of their model scores (tensor).
    """
    desc = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(desc))
    cls_id, mask_id, sep_id, pad_id = train_dataset.cls_id, train_dataset.mask_id, train_dataset.sep_id, train_dataset.pad_id
    # Input layout: [CLS] [MASK]*mask_size [SEP] description-tokens
    desc_ids = [cls_id] + [mask_id] * mask_size + [sep_id] + desc
    x = torch.tensor(desc_ids).unsqueeze(0).to(device)
    attention_mask = (x != pad_id)

    # Inference only: run in eval mode and without autograd so results do not
    # depend on the mode the model was left in and no graph is kept alive. The
    # previous train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            out = model(input_ids=x, attention_mask=attention_mask, target_matrix=target_matrix)
    finally:
        model.train(was_training)

    result, indices = torch.topk(out, k=top_n, dim=-1, largest=True, sorted=True)

    indices = indices[0]
    return [idx2target[i] for i in indices], indices, torch.sigmoid(result).squeeze(0)
    

In [36]:
# Top-100 candidates for the query, each paired with its score.
words, idx, probs = getPredFromDesc(model, 'employee at a circus', mask_size=5, top_n=100)
list(zip(words, probs.tolist()))

[('circus', 0.993190348148346),
 ('clown', 0.884685218334198),
 ('trouper', 0.6588714122772217),
 ('performer', 0.6159197092056274),
 ('showman', 0.5994367003440857),
 ('clowning', 0.5536590814590454),
 ('bullhorn', 0.47201165556907654),
 ('troupe', 0.3969727158546448),
 ('comedian', 0.35517677664756775),
 ('busker', 0.32383280992507935),
 ('clownish', 0.2810075283050537),
 ('cager', 0.25143513083457947),
 ('host', 0.25134822726249695),
 ('punster', 0.23134440183639526),
 ('soul', 0.2298755645751953),
 ('equestrian', 0.2218306064605713),
 ('monkey', 0.20849783718585968),
 ('bullfighter', 0.20050320029258728),
 ('jobber', 0.1975749135017395),
 ('picador', 0.19386501610279083),
 ('supervillain', 0.18457858264446259),
 ('show', 0.18442580103874207),
 ('journeyman', 0.1825648993253708),
 ('nonaggression', 0.1777178943157196),
 ('prizefighter', 0.17524780333042145),
 ('exhibitor', 0.1655634045600891),
 ('exhibition', 0.1646496206521988),
 ('comic', 0.16271696984767914),
 ('squealer', 0.1534

In [37]:
# Top-100 candidates for the query, each paired with its score.
words, idx, probs = getPredFromDesc(model, 'type of tree', mask_size=5, top_n=100)
list(zip(words, probs.tolist()))

[('tree', 0.8987012505531311),
 ('teakwood', 0.848529040813446),
 ('wood', 0.8431754112243652),
 ('chestnut', 0.8073810935020447),
 ('sycamore', 0.724624514579773),
 ('treed', 0.6613193154335022),
 ('satinwood', 0.6488152742385864),
 ('linden', 0.6279645562171936),
 ('ebony', 0.570404589176178),
 ('beechwood', 0.5560390949249268),
 ('mahogany', 0.5412676930427551),
 ('conifer', 0.5391260981559753),
 ('boxwood', 0.5335785150527954),
 ('maple', 0.529134213924408),
 ('rosewood', 0.5242615938186646),
 ('dogwood', 0.5204063653945923),
 ('brazilwood', 0.5168402791023254),
 ('cottonwood', 0.49642056226730347),
 ('bayberry', 0.491065114736557),
 ('tea', 0.47932055592536926),
 ('balata', 0.4710242748260498),
 ('logwood', 0.4409780502319336),
 ('nee', 0.43937623500823975),
 ('lime', 0.4367316961288452),
 ('cedarwood', 0.42682862281799316),
 ('fir', 0.42423343658447266),
 ('balsa', 0.42195358872413635),
 ('oak', 0.4087488055229187),
 ('casuarina', 0.4005316495895386),
 ('hazel', 0.381239563226699

In [38]:
# Top-100 candidates for the query, each paired with its score.
words, idx, probs = getPredFromDesc(model, 'a road on which cars can go fast', mask_size=5, top_n=100)
list(zip(words, probs.tolist()))

[('road', 0.9543246030807495),
 ('drive', 0.9403229355812073),
 ('raceway', 0.9393033981323242),
 ('parkway', 0.8832072615623474),
 ('chase', 0.8816617727279663),
 ('trackway', 0.8809894323348999),
 ('velodrome', 0.8687973022460938),
 ('beltway', 0.8651279807090759),
 ('driveway', 0.8606955409049988),
 ('speedway', 0.8336283564567566),
 ('racetrack', 0.7930748462677002),
 ('driveline', 0.7808491587638855),
 ('hunt', 0.7629326581954956),
 ('route', 0.7609122395515442),
 ('track', 0.7595197558403015),
 ('belt', 0.7323867082595825),
 ('runway', 0.7022448182106018),
 ('control', 0.6142128109931946),
 ('crossway', 0.6007115840911865),
 ('move', 0.5898234844207764),
 ('travel', 0.5869883894920349),
 ('street', 0.5799174308776855),
 ('racecourse', 0.5694281458854675),
 ('canal', 0.5614708662033081),
 ('path', 0.5611274838447571),
 ('highway', 0.5501547455787659),
 ('pass', 0.5337620973587036),
 ('crosscut', 0.5330564975738525),
 ('fast', 0.5162973999977112),
 ('railroad', 0.495160847902298),


In [39]:
# Top-100 candidates for the query, each paired with its score.
words, idx, probs = getPredFromDesc(model, 'a very intelligent person', mask_size=5, top_n=100)
list(zip(words, probs.tolist()))

[('genius', 0.9945294260978699),
 ('einstein', 0.99406498670578),
 ('brain', 0.9902623891830444),
 ('brainpower', 0.9856745600700378),
 ('intelligence', 0.9709501266479492),
 ('brainiac', 0.966468334197998),
 ('brains', 0.9625489115715027),
 ('intellect', 0.9441092610359192),
 ('brainstorming', 0.9407050013542175),
 ('brainy', 0.9381992220878601),
 ('intellectual', 0.9175571203231812),
 ('psychic', 0.9143380522727966),
 ('psyche', 0.8716870546340942),
 ('mindfulness', 0.8639110922813416),
 ('mind', 0.8611631989479065),
 ('brainstem', 0.8580717444419861),
 ('brilliance', 0.8215032815933228),
 ('minder', 0.8173473477363586),
 ('brilliant', 0.8104645609855652),
 ('mastermind', 0.8029227256774902),
 ('braincase', 0.7732986807823181),
 ('mental', 0.7660771608352661),
 ('intelligent', 0.7650362253189087),
 ('visionary', 0.7345557808876038),
 ('subconscious', 0.678047776222229),
 ('mentalist', 0.6644124984741211),
 ('intellectualism', 0.6626483201980591),
 ('expert', 0.6022065281867981),
 ('s

In [42]:
# Top-100 candidates for the query, each paired with its score.
words, idx, probs = getPredFromDesc(model, 'the opposite of being happy', mask_size=5, top_n=100)
list(zip(words, probs.tolist()))

[('unhappiness', 0.9850491881370544),
 ('sadness', 0.8976619839668274),
 ('disconcert', 0.8223000168800354),
 ('unhappy', 0.7643705010414124),
 ('discomfiture', 0.717978298664093),
 ('ambivalence', 0.7091021537780762),
 ('sorrow', 0.5153690576553345),
 ('lethargic', 0.5139889121055603),
 ('endearment', 0.5011042356491089),
 ('dreariness', 0.4468076527118683),
 ('unhinge', 0.4360183775424957),
 ('uneasiness', 0.4352346360683441),
 ('craziness', 0.42641308903694153),
 ('emotion', 0.4042612314224243),
 ('aphorism', 0.38845741748809814),
 ('misery', 0.3688901662826538),
 ('discontent', 0.3537975549697876),
 ('misanthrope', 0.31808826327323914),
 ('discomfit', 0.3101854920387268),
 ('complacence', 0.31006255745887756),
 ('antipathetic', 0.3036043047904968),
 ('unbelief', 0.3019541800022125),
 ('desensitisation', 0.2978937029838562),
 ('upset', 0.2975330352783203),
 ('unvarying', 0.2957783341407776),
 ('sorrowing', 0.29450318217277527),
 ('sorrowful', 0.28617167472839355),
 ('feeling', 0.271

In [62]:
# Top-100 candidates for the query, each paired with its score.
words, idx, probs = getPredFromDesc(model, 'something you use to measure your temperature', mask_size=5, top_n=100)
list(zip(words, probs.tolist()))

[('thermometer', 0.998386025428772),
 ('calorimeter', 0.9806400537490845),
 ('thermometry', 0.8985686302185059),
 ('thermocouple', 0.8366416692733765),
 ('thermostat', 0.7714673280715942),
 ('gasometer', 0.7669250965118408),
 ('fahrenheit', 0.6841531991958618),
 ('seismometer', 0.6831289529800415),
 ('speedometer', 0.556801438331604),
 ('barometer', 0.5467906594276428),
 ('reflectometer', 0.5422741174697876),
 ('magnetometer', 0.5342686176300049),
 ('thermopile', 0.5173691511154175),
 ('temperature', 0.5084152817726135),
 ('tachometer', 0.48013073205947876),
 ('refrigerate', 0.4386438727378845),
 ('thermotherapy', 0.41434165835380554),
 ('thermoset', 0.2984980642795563),
 ('potboiler', 0.2958795130252838),
 ('hygrometer', 0.24375738203525543),
 ('thermography', 0.24371981620788574),
 ('setup', 0.2351464331150055),
 ('spirometer', 0.2349872887134552),
 ('radiometer', 0.2307460904121399),
 ('pyrometer', 0.23054927587509155),
 ('hypothermic', 0.22909153997898102),
 ('ergometer', 0.2165631