In [11]:
import torch
from torch import nn
from torch.utils import data
from torch.cuda.amp import autocast, GradScaler

import numpy as np
from tqdm.notebook import tqdm

import sys
import datetime

sys.path.append('../code')
from dataset import get_data, make_vocab, WantWordsDataset as WWDataset

from transformers import (
    AdamW, get_linear_schedule_with_warmup
)

from models import SentenceBERTForRD

In [14]:
d = get_data('../wantwords-english-baseline/data', word2vec=False)

Loading data...
Training data: 675715 word-def pairs
Dev data: 75873 word-def pairs
Test data: 1200 word-def pairs


In [15]:
train_data, train_data_def, dev_data, test_data_seen, \
    test_data_unseen, test_data_desc = d

In [16]:
target2idx, idx2target = make_vocab(d, None)

In [5]:
# target2idx maps target words to indices
# target_matrix maps target indices to bpe sequences, padded/truncated to mask_size
target2idx['book'], idx2target[target2idx['book']]

(16187, 'book')

In [15]:
# can freeze for (part of) first epoch or so and then unfreeze to train the whole model
model = SentenceBERTForRD('distilbert-base-nli-stsb-mean-tokens', 
                          len(target2idx), freeze_sbert=False, criterion=nn.CrossEntropyLoss())

In [28]:
T = model.sbert.tokenizer

In [16]:
train_dataset = WWDataset(train_data + train_data_def, T, target2idx)
dev_dataset = WWDataset(dev_data, T, target2idx)
test_dataset_seen = WWDataset(test_data_seen, T, target2idx)
test_dataset_unseen = WWDataset(test_data_unseen, T, target2idx)
test_dataset_desc = WWDataset(test_data_desc, T, target2idx)

In [17]:
batch_size = 55
num_workers = 4

loader_params = {
    'pin_memory': False,
    'batch_size': batch_size,
    'num_workers': num_workers,
    'collate_fn': train_dataset.collate_fn
}

train_loader = data.DataLoader(train_dataset, **{'shuffle': True, **loader_params})
dev_loader = data.DataLoader(dev_dataset, **{'shuffle': True, **loader_params})
test_loader_seen = data.DataLoader(test_dataset_seen, **{'shuffle': False, **loader_params})
test_loader_unseen = data.DataLoader(test_dataset_unseen, **{'shuffle': False, **loader_params})
test_loader_desc = data.DataLoader(test_dataset_desc, **{'shuffle': False, **loader_params})

In [None]:
epochs = 10

lr = 2e-5
optim = AdamW(model.parameters(), lr=lr)

warmup_duration = 0.01 # portion of the first epoch spent on lr warmup
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=len(train_loader) * warmup_duration, 
                                            num_training_steps=len(train_loader) * epochs)

epoch = 0

# scaler = GradScaler()

In [25]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [10]:
import wandb

wandb.init(project='reverse-dictionary', entity='reverse-dict')

config = wandb.config
config.learning_rate = lr
config.epochs = epochs
config.batch_size = batch_size
config.optimizer = type(optim).__name__
config.scheduler = type(scheduler).__name__
config.warmup_duration = warmup_duration

wandb.watch(model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mreverse-dict[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.29 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


[<wandb.wandb_torch.TorchGraph at 0x7fb0a7ec24d0>]

In [19]:
model = model.to(device)

In [20]:
def evaluate(pred, gt, test=False):
    acc1 = acc10 = acc100 = 0
    n = len(pred)
    pred_rank = []
    for p, word in zip(pred, gt):
        if test:
            loc = (p == word).nonzero(as_tuple=True)
            if len(loc) != 0:
                pred_rank.append(min(loc[-1], 1000))
            else:
                pred_rank.append(1000)
        if word in p[:100]:
            acc100 += 1
            if word in p[:10]:
                acc10 += 1
                if word == p[0]:
                    acc1 += 1
    if test:
        pred_rank = torch.tensor(pred_rank, dtype=torch.float32)
        return (acc1, acc10, acc100, pred_rank)
    else:
        return acc1/n, acc10/n, acc100/n

In [21]:
def test(loader, name):
    inc = 3
    model.eval()
    test_loss = 0.0
    test_acc1 = test_acc10 = test_acc100 = 0.0
    total_seen = 0
    all_pred = []
    with torch.no_grad():
        with tqdm(total=len(loader)) as pbar:
            for i, ((x, attention_mask), y) in enumerate(loader):
                if i % inc == 0 and i != 0:
                    display_loss = test_loss / i
                    pbar.set_description(f'Test Loss: {display_loss}')

                x = x.to(device)
                attention_mask = attention_mask.to(device)
                y = y.to(device)

#                 with autocast():
                loss, out = model(input_ids=x, attention_mask=attention_mask,
                                  ground_truth=y)

                test_loss += loss.detach()

                pbar.update(1)

                result, indices = torch.sort(out, descending=True)
                
                b = len(x)
                acc1, acc10, acc100, pred_rank = evaluate(indices, y, test=True)
                test_acc1 += acc1
                test_acc10 += acc10
                test_acc100 += acc100
                total_seen += b
                all_pred.extend(pred_rank)
                
                del x, y, out, loss
                if i % 20 == 0:
                    torch.cuda.empty_cache()
    
    test_loss /= len(loader)
    test_acc1 /= total_seen
    test_acc10 /= total_seen
    test_acc100 /= total_seen
    all_pred = torch.tensor(all_pred)
    median = torch.median(all_pred)
    var = torch.var(all_pred)**0.5
    
    print(f'{name}_test_loss:', test_loss)
    print(f'{name}_test_acc1:', test_acc1)
    print(f'{name}_test_acc10:', test_acc10)
    print(f'{name}_test_acc100:', test_acc100)
    print(f'{name}_test_rank_median:', median)
    print(f'{name}_test_rank_variance', var)
    
    return ({
        f'{name}_test_loss': test_loss,
        f'{name}_test_acc1': test_acc1,
        f'{name}_test_acc10': test_acc10,
        f'{name}_test_acc100': test_acc100,
        f'{name}_test_rank_median': median,
        f'{name}_test_rank_variance': var
    })
    

In [22]:
inc = 10
losses = []

for epoch in range(epoch, epochs):
    # Training
    model.train()
    train_loss = 0.0
    # Train on subset of training data to save time
    with tqdm(total=len(train_loader)) as pbar:
        for i, ((x, attention_mask), y) in enumerate(train_loader):
            if i % inc == 0 and i != 0:
                display_loss = train_loss / i
                pbar.set_description(f'Epoch {epoch+1}, Train Loss: {train_loss / i}')

            optim.zero_grad()

            x = x.to(device)
            attention_mask = attention_mask.to(device)
            y = y.to(device)
            
            loss, out = model(input_ids=x, attention_mask=attention_mask, 
                              ground_truth=y)

#             scaler.scale(loss).backward()
            loss.backward()
            
#             scaler.unscale_(optim)
            nn.utils.clip_grad_value_(model.parameters(), 5)
            
#             scaler.step(optim)
            optim.step()
#             scaler.update()
            
            train_loss += loss.detach()
            
            scheduler.step()
            
            pbar.update(1)
            
            del x, y, out, loss, attention_mask
            
    model_name = type(model).__name__
    filename = f'../trained_models/{model_name} Epoch {epoch+1} at {datetime.datetime.now()}'.replace(' ', '_')
    with open(filename, 'wb+') as f:
        torch.save({'state_dict': model.state_dict()}, f)
    
    # Validation
    model.eval()
    val_loss = 0.0
    val_acc1, val_acc10, val_acc100 = 0.0, 0.0, 0.0
    with torch.no_grad():
        with tqdm(total=len(dev_loader)) as pbar:
            for i, ((x, attention_mask), y) in enumerate(dev_loader):
                if i % inc == 0 and i != 0:
                    display_loss = val_loss / i
                    pbar.set_description(f'Epoch {epoch+1}, Val Loss: {val_loss / i}')

                x = x.to(device)
                attention_mask = attention_mask.to(device)
                y = y.to(device)

#                 with autocast():
                loss, out = model(input_ids=x, attention_mask=attention_mask,
                                  ground_truth=y)

                val_loss += loss.detach()

                pbar.update(1)                
                
                result, indices = torch.topk(out, k=100, dim=-1, largest=True, sorted=True)
                
                acc1, acc10, acc100 = evaluate(indices, y)
                val_acc1 += acc1
                val_acc10 += acc10
                val_acc100 += acc100

                del x, y, out, loss
    
    wandb.log({
        'train_loss': train_loss / len(train_loader),
        'val_loss': val_loss / len(dev_loader),
        'val_acc1': val_acc1 / len(dev_loader),
        'val_acc10': val_acc10 / len(dev_loader),
        'val_acc100': val_acc100 / len(dev_loader),
        **test(test_loader_seen, 'seen'),
        **test(test_loader_unseen, 'unseen'),
        **test(test_loader_desc, 'desc')
    })
    

  0%|          | 0/12286 [00:00<?, ?it/s]

  0%|          | 0/1380 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seen_test_loss: tensor(9.3957, device='cuda:0')
seen_test_acc1: 0.012
seen_test_acc10: 0.074
seen_test_acc100: 0.25
seen_test_rank_median: tensor(852.)
seen_test_rank_variance tensor(426.3570)


  0%|          | 0/10 [00:00<?, ?it/s]

unseen_test_loss: tensor(13.0325, device='cuda:0')
unseen_test_acc1: 0.0
unseen_test_acc10: 0.002
unseen_test_acc100: 0.004
unseen_test_rank_median: tensor(1000.)
unseen_test_rank_variance tensor(67.4732)


  0%|          | 0/4 [00:00<?, ?it/s]

desc_test_loss: tensor(6.2205, device='cuda:0')
desc_test_acc1: 0.29
desc_test_acc10: 0.555
desc_test_acc100: 0.745
desc_test_rank_median: tensor(7.)
desc_test_rank_variance tensor(349.5893)


  0%|          | 0/12286 [00:00<?, ?it/s]

  0%|          | 0/1380 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seen_test_loss: tensor(8.1269, device='cuda:0')
seen_test_acc1: 0.06
seen_test_acc10: 0.216
seen_test_acc100: 0.468
seen_test_rank_median: tensor(140.)
seen_test_rank_variance tensor(419.7760)


  0%|          | 0/10 [00:00<?, ?it/s]

unseen_test_loss: tensor(13.9106, device='cuda:0')
unseen_test_acc1: 0.0
unseen_test_acc10: 0.004
unseen_test_acc100: 0.006
unseen_test_rank_median: tensor(1000.)
unseen_test_rank_variance tensor(75.1580)


  0%|          | 0/4 [00:00<?, ?it/s]

desc_test_loss: tensor(4.8081, device='cuda:0')
desc_test_acc1: 0.375
desc_test_acc10: 0.7
desc_test_acc100: 0.82
desc_test_rank_median: tensor(2.)
desc_test_rank_variance tensor(299.2062)


  0%|          | 0/12286 [00:00<?, ?it/s]

  0%|          | 0/1380 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seen_test_loss: tensor(7.0747, device='cuda:0')
seen_test_acc1: 0.128
seen_test_acc10: 0.366
seen_test_acc100: 0.616
seen_test_rank_median: tensor(36.)
seen_test_rank_variance tensor(383.9724)


  0%|          | 0/10 [00:00<?, ?it/s]

unseen_test_loss: tensor(14.1173, device='cuda:0')
unseen_test_acc1: 0.002
unseen_test_acc10: 0.004
unseen_test_acc100: 0.006
unseen_test_rank_median: tensor(1000.)
unseen_test_rank_variance tensor(76.6629)


  0%|          | 0/4 [00:00<?, ?it/s]

desc_test_loss: tensor(4.2866, device='cuda:0')
desc_test_acc1: 0.375
desc_test_acc10: 0.7
desc_test_acc100: 0.845
desc_test_rank_median: tensor(1.)
desc_test_rank_variance tensor(245.0346)


  0%|          | 0/12286 [00:00<?, ?it/s]

  0%|          | 0/1380 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seen_test_loss: tensor(6.1792, device='cuda:0')
seen_test_acc1: 0.206
seen_test_acc10: 0.498
seen_test_acc100: 0.718
seen_test_rank_median: tensor(10.)
seen_test_rank_variance tensor(358.8513)


  0%|          | 0/10 [00:00<?, ?it/s]

unseen_test_loss: tensor(14.3497, device='cuda:0')
unseen_test_acc1: 0.002
unseen_test_acc10: 0.006
unseen_test_acc100: 0.006
unseen_test_rank_median: tensor(1000.)
unseen_test_rank_variance tensor(77.2013)


  0%|          | 0/4 [00:00<?, ?it/s]

desc_test_loss: tensor(4.1387, device='cuda:0')
desc_test_acc1: 0.36
desc_test_acc10: 0.695
desc_test_acc100: 0.85
desc_test_rank_median: tensor(2.)
desc_test_rank_variance tensor(235.2277)


  0%|          | 0/12286 [00:00<?, ?it/s]

  0%|          | 0/1380 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seen_test_loss: tensor(5.5666, device='cuda:0')
seen_test_acc1: 0.292
seen_test_acc10: 0.584
seen_test_acc100: 0.768
seen_test_rank_median: tensor(4.)
seen_test_rank_variance tensor(355.4846)


  0%|          | 0/10 [00:00<?, ?it/s]

unseen_test_loss: tensor(14.4899, device='cuda:0')
unseen_test_acc1: 0.004
unseen_test_acc10: 0.006
unseen_test_acc100: 0.006
unseen_test_rank_median: tensor(1000.)
unseen_test_rank_variance tensor(81.6484)


  0%|          | 0/4 [00:00<?, ?it/s]

desc_test_loss: tensor(4.0460, device='cuda:0')
desc_test_acc1: 0.38
desc_test_acc10: 0.675
desc_test_acc100: 0.865
desc_test_rank_median: tensor(2.)
desc_test_rank_variance tensor(239.9710)


  0%|          | 0/12286 [00:00<?, ?it/s]

  0%|          | 0/1380 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

seen_test_loss: tensor(5.1292, device='cuda:0')
seen_test_acc1: 0.33
seen_test_acc10: 0.636
seen_test_acc100: 0.798
seen_test_rank_median: tensor(3.)
seen_test_rank_variance tensor(349.5005)


  0%|          | 0/10 [00:00<?, ?it/s]

unseen_test_loss: tensor(14.6690, device='cuda:0')
unseen_test_acc1: 0.004
unseen_test_acc10: 0.006
unseen_test_acc100: 0.006
unseen_test_rank_median: tensor(1000.)
unseen_test_rank_variance tensor(82.8029)


  0%|          | 0/4 [00:00<?, ?it/s]

desc_test_loss: tensor(4.0401, device='cuda:0')
desc_test_acc1: 0.36
desc_test_acc10: 0.675
desc_test_acc100: 0.86
desc_test_rank_median: tensor(2.)
desc_test_rank_variance tensor(234.6072)


  0%|          | 0/12286 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [21]:
def getPredFromDesc(model, desc : str, top_n=10):
    desc = T(desc, return_tensors='pt', padding=True)
    x = desc['input_ids'].to(device)
    attention_mask = desc['attention_mask'].to(device)
    out = model(input_ids=x, attention_mask=attention_mask)
    result, indices = torch.topk(out, k=top_n, dim=-1, largest=True, sorted=True)
    
    indices = indices[0]
    return [idx2target[i] for i in indices], indices
    

In [28]:
test(test_loader_seen, 'seen')

  0%|          | 0/10 [00:00<?, ?it/s]

seen_test_loss: tensor(5.1246, device='cuda:0')
seen_test_acc1: 0.326
seen_test_acc10: 0.64
seen_test_acc100: 0.804
seen_test_rank_median: tensor(3.)
seen_test_rank_variance tensor(347.5329)


{'seen_test_loss': tensor(5.1246, device='cuda:0'),
 'seen_test_acc1': 0.326,
 'seen_test_acc10': 0.64,
 'seen_test_acc100': 0.804,
 'seen_test_rank_median': tensor(3.),
 'seen_test_rank_variance': tensor(347.5329)}

In [29]:
test(test_loader_unseen, 'unseen')

  0%|          | 0/10 [00:00<?, ?it/s]

unseen_test_loss: tensor(14.7876, device='cuda:0')
unseen_test_acc1: 0.004
unseen_test_acc10: 0.006
unseen_test_acc100: 0.006
unseen_test_rank_median: tensor(1000.)
unseen_test_rank_variance tensor(82.8673)


{'unseen_test_loss': tensor(14.7876, device='cuda:0'),
 'unseen_test_acc1': 0.004,
 'unseen_test_acc10': 0.006,
 'unseen_test_acc100': 0.006,
 'unseen_test_rank_median': tensor(1000.),
 'unseen_test_rank_variance': tensor(82.8673)}

In [30]:
test(test_loader_desc, 'desc')

  0%|          | 0/4 [00:00<?, ?it/s]

desc_test_loss: tensor(4.0884, device='cuda:0')
desc_test_acc1: 0.365
desc_test_acc10: 0.675
desc_test_acc100: 0.86
desc_test_rank_median: tensor(2.)
desc_test_rank_variance tensor(234.0795)


{'desc_test_loss': tensor(4.0884, device='cuda:0'),
 'desc_test_acc1': 0.365,
 'desc_test_acc10': 0.675,
 'desc_test_acc100': 0.86,
 'desc_test_rank_median': tensor(2.),
 'desc_test_rank_variance': tensor(234.0795)}

In [24]:
getPredFromDesc(model, 'an inhabitant of a cold country', 100)

(['glacial',
  'arctic',
  'icy',
  'iceman',
  'frozen',
  'winters',
  'husky',
  'nordic',
  'boreal',
  'barbarian',
  'frigid',
  'hun',
  'mountaineer',
  'winter',
  'freezing',
  'cooler',
  'snowbird',
  'cold',
  'snowplow',
  'thaw',
  'malamute',
  'coldly',
  'renegade',
  'frosty',
  'afghan',
  'bushman',
  'barbarous',
  'icehouse',
  'tundra',
  'taiga',
  'tobogganing',
  'nippy',
  'chinook',
  'coolie',
  'rustic',
  'mink',
  'philistine',
  'refrigerant',
  'refrigerate',
  'cynic',
  'cool',
  'frosted',
  'glaciation',
  'frost',
  'snowy',
  'glaciated',
  'coolly',
  'gypsy',
  'frigidity',
  'starve',
  'landsman',
  'thawing',
  'troglodyte',
  'refrigerating',
  'chiller',
  'nazi',
  'nonnative',
  'skater',
  'foe',
  'provincial',
  'frostbite',
  'neve',
  'caribou',
  'chills',
  'peasant',
  'cooling',
  'clown',
  'celsius',
  'popsicle',
  'agnostic',
  'glade',
  'snowboarding',
  'slavic',
  'overwinter',
  'asia',
  'weasel',
  'polar',
  'barbar

In [25]:
getPredFromDesc(model, 'employee at a circus', 100)

(['clown',
  'butcher',
  'circus',
  'knackered',
  'jobber',
  'butchers',
  'mummer',
  'stag',
  'monger',
  'ham',
  'gaffer',
  'ringer',
  'mie',
  'cowherd',
  'huntsman',
  'knacker',
  'demonstrator',
  'juggler',
  'masquerading',
  'ranger',
  'comedian',
  'actor',
  'picket',
  'outlaw',
  'strongman',
  'buffo',
  'stunting',
  'carnival',
  'jock',
  'publican',
  'bandsman',
  'don',
  'vamp',
  'roustabout',
  'valet',
  'trooper',
  'fagot',
  'bullock',
  'beadle',
  'matador',
  'garnishee',
  'buffoon',
  'man',
  'ramrod',
  'reeve',
  'csd',
  'usher',
  'shifter',
  'adventurer',
  'larrikin',
  'tinker',
  'swagman',
  'jeweler',
  'roper',
  'ape',
  'gladiator',
  'tenter',
  'dof',
  'lackey',
  'rustler',
  'player',
  'teaser',
  'knight',
  'theatergoer',
  'pedant',
  'prostitute',
  'knave',
  'job',
  'cad',
  'flaunt',
  'artist',
  'pantomime',
  'hams',
  'mastering',
  'cadet',
  'handyman',
  'apprentice',
  'amateur',
  'stockman',
  'skinner',


In [26]:
getPredFromDesc(model, 'a road on which cars can go fast', 100)

(['riad',
  'drive',
  'speedway',
  'railroad',
  'railway',
  'lanes',
  'runway',
  'bus',
  'highway',
  'race',
  'rush',
  'nus',
  'driveway',
  'fast',
  'rushing',
  'gating',
  'bua',
  'accelerator',
  'freeway',
  'tailgate',
  'gate',
  'road',
  'car',
  'scuttling',
  'street',
  'tracked',
  'clearance',
  'route',
  'pad',
  'clip',
  'turnpike',
  'hacking',
  'trailer',
  'psd',
  'highball',
  'track',
  'rushes',
  'blitzed',
  'tracks',
  'gangway',
  'fasting',
  'causeway',
  'lock',
  'swift',
  'mileage',
  'hack',
  'fugitive',
  'slipping',
  'hitchhiking',
  'backstop',
  'routing',
  'turnstile',
  'highroad',
  'flash',
  'parking',
  'shuttling',
  'blockade',
  'corridor',
  'avenue',
  'cab',
  'trap',
  'door',
  'carpool',
  'flight',
  'shuttle',
  'thoroughfare',
  'bootlegging',
  'career',
  'taxi',
  'jump',
  'hijack',
  'hurry',
  'drove',
  'gas',
  'motor',
  'convertible',
  'steals',
  'drift',
  'expedited',
  'fleet',
  'trapped',
  'gun

In [27]:
getPredFromDesc(model, 'something you use to measure your temperature', 100)

(['thermometer',
  'thermometry',
  'isotherm',
  'thermostat',
  'calorimeter',
  'pyrometer',
  'fahrenheit',
  'refrigerant',
  'thermopile',
  'cooler',
  'centigrade',
  'freezer',
  'gasometer',
  'celsius',
  'chronograph',
  'heating',
  'superheating',
  'dynamometer',
  'temperance',
  'meter',
  'thermoregulation',
  'manometer',
  'calorific',
  'mileage',
  'barometer',
  'percentile',
  'heat',
  'thermocouple',
  'metering',
  'endothermic',
  'het',
  'anemometer',
  'ergometer',
  'thermography',
  'antarctica',
  'refrigeration',
  'comparator',
  'ph',
  'hygrometer',
  'measurement',
  'hotness',
  'coldness',
  'stp',
  'calibration',
  'joule',
  'cryogenics',
  'timepiece',
  'heater',
  'radiometer',
  'climate',
  'measurer',
  'potentiometer',
  'cryogenic',
  'tonometry',
  'calorie',
  'coefficient',
  'thermally',
  'titer',
  'gauge',
  'conduction',
  'seismometer',
  'dosimeter',
  'accelerometer',
  'equidistant',
  'thermotherapy',
  'calibrated',
  'b

In [32]:
getPredFromDesc(model, 'a large house that a rich person lives in', 100)

(['mansion',
  'house',
  'hall',
  'villa',
  'palace',
  'stateroom',
  'apartment',
  'hotel',
  'homestead',
  'domiciled',
  'tenement',
  'lodge',
  'roof',
  'lofting',
  'casa',
  'housing',
  'saloon',
  'cottage',
  'penthouse',
  'residence',
  'roofs',
  'cabinet',
  'householder',
  'home',
  'inn',
  'castle',
  'maisonette',
  'condominium',
  'manor',
  'seraglio',
  'summerhouse',
  'nesting',
  'placed',
  'ibn',
  'townhouse',
  'parkour',
  'cribbed',
  'demesne',
  'palatial',
  'landlord',
  'dwelling',
  'divan',
  'loge',
  'parlor',
  'court',
  'stage',
  'nested',
  'lodging',
  'cabin',
  'chateau',
  'building',
  'nest',
  'chamber',
  'cloakroom',
  'edifice',
  'housemate',
  'rooms',
  'stacked',
  'residency',
  'buttery',
  'guesthouse',
  'studio',
  'rotunda',
  'pension',
  'stack',
  'floor',
  'houseguest',
  'treasury',
  'kitty',
  'hovel',
  'lobbying',
  'pavilion',
  'household',
  'condo',
  'tenancy',
  'roundhouse',
  'stacks',
  'feature

In [33]:
getPredFromDesc(model, 'the opposite of being happy', 100)

(['unhappy',
  'happiness',
  'sadness',
  'misery',
  'vanity',
  'misfortune',
  'bashfulness',
  'complacence',
  'felicitate',
  'evil',
  'nothingness',
  'depression',
  'unhappiness',
  'absurd',
  'emptiness',
  'disappointment',
  'frailty',
  'pessimism',
  'felicity',
  'diffidence',
  'inconvenience',
  'melancholic',
  'egoism',
  'indisposition',
  'irrationality',
  'penury',
  'folly',
  'mischief',
  'languor',
  'blithe',
  'gravity',
  'hopeless',
  'unreality',
  'atheism',
  'joviality',
  'unreasoning',
  'luck',
  'adversity',
  'egotism',
  'discontent',
  'disconsolate',
  'nihilism',
  'insanity',
  'pitiable',
  'unkindly',
  'spleen',
  'transience',
  'arrogance',
  'magnanimity',
  'affliction',
  'gaiety',
  'demerit',
  'infatuation',
  'modesty',
  'ugliness',
  'uneasiness',
  'impotence',
  'disinterest',
  'disaffection',
  'malevolence',
  'sterility',
  'flatness',
  'envy',
  'deficiency',
  'goodness',
  'deadness',
  'regrets',
  'leniency',
  '

In [35]:
getPredFromDesc(model, 'a mammal that lives in water', 100) # decent results

(['amphibian',
  'amphibious',
  'swim',
  'porpoise',
  'plankton',
  'aqueous',
  'swimmer',
  'subaquatic',
  'cetacean',
  'gill',
  'puffer',
  'dipper',
  'swimming',
  'benthos',
  'sucker',
  'monkfish',
  'pelagic',
  'underwater',
  'hydrate',
  'smelt',
  'prawn',
  'submersible',
  'brine',
  'salamander',
  'ocean',
  'limnology',
  'beluga',
  'leviathan',
  'groundfish',
  'laver',
  'leeching',
  'catfish',
  'goby',
  'porgy',
  'neptune',
  'pinniped',
  'diver',
  'hydraulic',
  'lifeguard',
  'hydrosphere',
  'beaver',
  'skater',
  'drowned',
  'bathe',
  'sunfish',
  'dowse',
  'mermaid',
  'mussel',
  'tadpole',
  'dogfish',
  'awash',
  'sheepshead',
  'hardhead',
  'grouper',
  'planktonic',
  'anglerfish',
  'jellyfish',
  'bloodsucker',
  'afloat',
  'vessel',
  'paddling',
  'moray',
  'sea',
  'darter',
  'loach',
  'cataract',
  'shovelnose',
  'mullet',
  'swordfish',
  'laker',
  'butterfish',
  'nautilus',
  'steamer',
  'lake',
  'submarine',
  'supern

In [36]:
getPredFromDesc(model, 'a mammal that lives in ocean', 100) # bad results

(['oceanic',
  'ocean',
  'pelagic',
  'neptune',
  'sea',
  'benthos',
  'oceanography',
  'marine',
  'leviathan',
  'transoceanic',
  'seagoing',
  'nautilus',
  'coaster',
  'seabird',
  'shearwater',
  'beluga',
  'beachcomber',
  'gam',
  'underwater',
  'oceangoing',
  'prawn',
  'narwhal',
  'monkfish',
  'kingfish',
  'bowhead',
  'subaquatic',
  'porpoise',
  'transatlantic',
  'shovelnose',
  'seashell',
  'plankton',
  'halibut',
  'kelp',
  'seaweed',
  'tsunami',
  'cetacean',
  'seahorse',
  'hake',
  'seaboard',
  'mussel',
  'archipelago',
  'blackfish',
  'gopher',
  'brine',
  'reefer',
  'jason',
  'iceberg',
  'main',
  'anchovy',
  'buoy',
  'submarine',
  'seafaring',
  'hardhead',
  'brachiopod',
  'manatee',
  'atlantic',
  'mermaid',
  'baleen',
  'groundfish',
  'planktonic',
  'orion',
  'amphibian',
  'seaward',
  'offshore',
  'oversea',
  'cancer',
  'beagle',
  'chiton',
  'tide',
  'dogfish',
  'serval',
  'hydrosphere',
  'moray',
  'lobster',
  'crust

In [63]:
getPredFromDesc(model, 'a person who is very knowledgeable about many subjects', 100)

(['scholar',
  'pedant',
  'student',
  'expert',
  'professor',
  'intellectual',
  'literate',
  'adept',
  'polyglot',
  'bookworm',
  'lector',
  'educator',
  'thinker',
  'monitor',
  'teacher',
  'master',
  'virtuoso',
  'highbrow',
  'reader',
  'proficient',
  'mastering',
  'doctoring',
  'knowledgeable',
  'skeptic',
  'amateur',
  'connoisseur',
  'pedagogue',
  'technocrat',
  'doctored',
  'savant',
  'scientist',
  'bookman',
  'philosopher',
  'erudition',
  'conversant',
  'doctor',
  'gnostic',
  'cpa',
  'appreciator',
  'humanist',
  'collector',
  'pundit',
  'learning',
  'lettered',
  'mandarin',
  'science',
  'informant',
  'generalist',
  'artist',
  'scholastic',
  'teller',
  'learner',
  'informer',
  'tea',
  'technologist',
  'inquisitor',
  'educated',
  'lore',
  'fellow',
  'orator',
  'university',
  'specialist',
  'technician',
  'sage',
  'brains',
  'stargazer',
  'scholarship',
  'intellect',
  'study',
  'exegete',
  'empiric',
  'observer',
  

In [24]:
model_name = type(model).__name__
filename = f'../trained_models/{model_name} Epoch {epoch+1} at {datetime.datetime.now()}'.replace(' ', '_')
with open(filename, 'wb+') as f:
    torch.save({'state_dict': model.state_dict()}, f)

In [35]:
T("you are not helpless", return_tensors='pt', padding=True)

{'input_ids': tensor([[  101,  2017,  2024,  2025, 13346,   102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [54]:
train_words = {train_dataset[i][1] for i in range(len(train_dataset))}
dev_words = {dev_dataset[i][1] for i in range(len(dev_dataset))}

In [57]:
len(dev_words), len(train_words)

(4998, 44996)

In [3]:
import json
def read_json(path):
    with open(path) as f:
        return json.load(f)

In [4]:
wn_data = read_json('../data/wn_data.json')

In [10]:
wn_data['unhappy']

{'synonyms': ['distressed', 'dysphoric'],
 'antonyms': ['euphoric', 'happy'],
 'related_forms': ['unhappiness', 'dysphoria'],
 'hyponyms': [],
 'hypernyms': []}

In [31]:
model = SentenceBERTForRD('distilbert-base-nli-stsb-mean-tokens', 
                          len(target2idx), freeze_sbert=False, criterion=nn.CrossEntropyLoss())

In [30]:
state_dict = torch.load('../trained_models/SentenceBERTForRD_Epoch_3_at_2021-05-05_19:58:01.613062')['state_dict']

In [32]:
model.load_state_dict(state_dict)

<All keys matched successfully>

In [33]:
model = model.to(device)

In [22]:
queries = [
    'a type of tree',
    'the opposite of being happy',
    'employee at a circus',
    'a road on which cars can go quickly without stopping',
    'a very intelligent person',
    'a very smart person',
    'something you use to measure your temperature',
    'a dark time of day',
    'medieval social hierarchy where peasants and vassals served lords',
    'to help someone else learn',
    'when someone you trust does something that breaks your trust',
    'deep learning'
]

In [29]:
from pprint import pprint

for q in queries:
    print(f'Results for {q}')
    pprint(getPredFromDesc(model, q, 100))
    print()

Results for a type of tree
(['tree',
  'treed',
  'eucalypt',
  'arboreal',
  'barking',
  'birch',
  'fir',
  'hemlock',
  'leatherwood',
  'butternut',
  'stocks',
  'aspen',
  'softwood',
  'brushwood',
  'durian',
  'sapling',
  'forester',
  'sylvan',
  'broadleaf',
  'oak',
  'redwood',
  'tope',
  'nutmeg',
  'carambola',
  'deodar',
  'balsa',
  'logwood',
  'longan',
  'bole',
  'swede',
  'rambutan',
  'trunked',
  'burl',
  'blackwood',
  'loquat',
  'trunks',
  'coppice',
  'oaken',
  'toon',
  'sugarbush',
  'stocked',
  'haw',
  'nox',
  'forest',
  'palm',
  'dudgeon',
  'chinquapin',
  'pecker',
  'roe',
  'timberland',
  'brand',
  'mast',
  'gumming',
  'knot',
  'flag',
  'stick',
  'sticks',
  'sycamore',
  'dogwood',
  'cedar',
  'pollard',
  'twiggy',
  'bau',
  'catalpa',
  'liana',
  'hazel',
  'kauri',
  'twig',
  'limb',
  'sapodilla',
  'bough',
  'locust',
  'fig',
  'teak',
  'cypress',
  'mahogany',
  'spurring',
  'atm',
  'pawpaw',
  'superior',
  'bayin

(['bribing',
  'traitor',
  'betray',
  'credits',
  'surety',
  'confide',
  'hedge',
  'perfidy',
  'trust',
  'trusty',
  'cheated',
  'entrust',
  'judas',
  'credited',
  'treason',
  'cheat',
  'embezzling',
  'fiduciary',
  'confidant',
  'venturing',
  'frailty',
  'compromising',
  'bet',
  'bribe',
  'depositing',
  'misplaced',
  'trusting',
  'subscribe',
  'mistrust',
  'compromise',
  'confusion',
  'jinx',
  'fraud',
  'treachery',
  'unfaithful',
  'pawning',
  'credit',
  'confessor',
  'mislead',
  'selling',
  'friend',
  'mate',
  'confidence',
  'infidelity',
  'confidential',
  'pawn',
  'lose',
  'guard',
  'suicide',
  'depositional',
  'distrust',
  'fear',
  'reliance',
  'accredit',
  'biter',
  'hazardous',
  'venture',
  'flaw',
  'responsible',
  'jilt',
  'fob',
  'betrayer',
  'credibility',
  'pledge',
  'trespassing',
  'bets',
  'margin',
  'fooling',
  'betting',
  'stakeholder',
  'treacherous',
  'informant',
  'hedger',
  'heartbreaker',
  'comfor

In [34]:
from pprint import pprint

for q in queries:
    print(f'Results for {q}')
    pprint(getPredFromDesc(model, q, 100))
    print()

Results for a type of tree
(['tree',
  'treed',
  'nox',
  'canes',
  'pine',
  'fir',
  'maple',
  'arboreal',
  'ironwood',
  'mahogany',
  'hardwood',
  'spruce',
  'apple',
  'oak',
  'hickory',
  'forest',
  'cedar',
  'larch',
  'grove',
  'log',
  'eucalyptus',
  'sandalwood',
  'plu',
  'bamboo',
  'redwood',
  'cypress',
  'birch',
  'stocks',
  'sycamore',
  'bush',
  'locust',
  'stocked',
  'knotting',
  'limes',
  'barking',
  'toon',
  'evergreen',
  'stick',
  'acacia',
  'chestnut',
  'basil',
  'knot',
  'nutmeg',
  'mallee',
  'cottonwood',
  'liming',
  'lime',
  'clove',
  'hemlock',
  'mulberry',
  'leatherwood',
  'hazel',
  'box',
  'mangrove',
  'timbered',
  'lumber',
  'dock',
  'baying',
  'leaf',
  'sylvan',
  'crip',
  'linden',
  'pollard',
  'butternut',
  'dogwood',
  'brushwood',
  'papaw',
  'standard',
  'fig',
  'bau',
  'ebony',
  'sticks',
  'forester',
  'reed',
  'shrub',
  'straw',
  'foliate',
  'grass',
  'softwood',
  'plum',
  'trunks',
  'p

  'aif',
  'aided',
  'learn',
  'improve',
  'afford',
  'doctor',
  'nourished',
  'doctored',
  'teach',
  'avail',
  'study',
  'nourish',
  'assisted',
  'availing',
  'recommend',
  'doctoring',
  'supporting',
  'cover',
  'benefit',
  'covers',
  'disciple',
  'read',
  'upgrade',
  'save',
  'saved',
  'give',
  'retrieve',
  'achieve',
  'nurture',
  'support',
  'profits',
  'training',
  'tutor',
  'follow',
  'second',
  'skillfulness',
  'cured',
  'seconds',
  'foster',
  'lesson',
  'vetter',
  'salving',
  'mend',
  'master',
  'minister',
  'edified',
  'recovered',
  'overhauling',
  'practicing',
  'skill',
  'do',
  'redeem',
  'exercise',
  'heal',
  'art',
  'supported',
  'exploited',
  'recover',
  'understand',
  'instruct',
  'discover',
  'use',
  'get',
  'upgraded',
  'selling',
  'experimenting',
  'provide',
  'salve',
  'cure',
  'mending',
  'con',
  'hear',
  'proving',
  'mastering',
  'assistance',
  'school',
  'promote',
  'spell',
  'seeking',
  