In [1]:
import torch
from torch import nn
from torch.utils import data
from torch.cuda.amp import autocast, GradScaler

import numpy as np
from tqdm.notebook import tqdm

import sys
import datetime

sys.path.append('../code')
from dataset import get_data, make_vocab, WantWordsDataset as WWDataset

from transformers import (
    AdamW, get_linear_schedule_with_warmup
)

from models import SentenceBERTForRD

In [2]:
d = get_data('../wantwords-english-baseline/data', word2vec=False)

Loading data...
Training data: 675715 word-def pairs
Dev data: 75873 word-def pairs
Test data: 1200 word-def pairs


In [3]:
train_data, train_data_def, dev_data, test_data_seen, \
    test_data_unseen, test_data_desc = d

In [4]:
target2idx, idx2target = make_vocab(d, None)

In [5]:
# target2idx maps target words to indices
# target_matrix maps target indices to bpe sequences, padded/truncated to mask_size
target2idx['book'], idx2target[target2idx['book']]

(16187, 'book')

In [6]:
# can freeze for (part of) first epoch or so and then unfreeze to train the whole model
model = SentenceBERTForRD('distilbert-base-nli-stsb-mean-tokens', 
                          len(target2idx), freeze_sbert=True, criterion=nn.CrossEntropyLoss())

In [7]:
T = model.sbert.tokenizer
train_dataset = WWDataset(train_data + train_data_def, T, target2idx)
dev_dataset = WWDataset(dev_data, T, target2idx)
test_dataset_seen = WWDataset(test_data_seen, T, target2idx)
test_dataset_unseen = WWDataset(test_data_unseen, T, target2idx)
test_dataset_desc = WWDataset(test_data_desc, T, target2idx)

In [8]:
batch_size = 128
num_workers = 4

loader_params = {
    'pin_memory': False,
    'batch_size': batch_size,
    'num_workers': num_workers,
    'collate_fn': train_dataset.collate_fn
}

train_loader = data.DataLoader(train_dataset, **{'shuffle': True, **loader_params})
dev_loader = data.DataLoader(dev_dataset, **{'shuffle': True, **loader_params})
test_loader_seen = data.DataLoader(test_dataset_seen, **{'shuffle': False, **loader_params})
test_loader_unseen = data.DataLoader(test_dataset_unseen, **{'shuffle': False, **loader_params})
test_loader_desc = data.DataLoader(test_dataset_desc, **{'shuffle': False, **loader_params})

In [9]:
epochs = 10

lr = 2e-5
optim = AdamW(model.parameters(), lr=lr)

warmup_duration = 0.01 # portion of the first epoch spent on lr warmup
scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=len(train_loader) * warmup_duration, 
                                            num_training_steps=len(train_loader) * epochs)

epoch = 0

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# scaler = GradScaler()

In [10]:
import wandb

wandb.init(project='reverse-dictionary', entity='reverse-dict')

config = wandb.config
config.learning_rate = lr
config.epochs = epochs
config.batch_size = batch_size
config.optimizer = type(optim).__name__
config.scheduler = type(scheduler).__name__
config.warmup_duration = warmup_duration

wandb.watch(model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mreverse-dict[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.29 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


[<wandb.wandb_torch.TorchGraph at 0x7fb0a7ec24d0>]

In [11]:
model = model.to(device)

In [12]:
def evaluate(pred, gt, test=False):
    acc1 = acc10 = acc100 = 0
    n = len(pred)
    pred_rank = []
    for p, word in zip(pred, gt):
        if test:
            loc = (p == word).nonzero(as_tuple=True)
            if len(loc) != 0:
                pred_rank.append(min(loc[-1], 1000))
            else:
                pred_rank.append(1000)
        if word in p[:100]:
            acc100 += 1
            if word in p[:10]:
                acc10 += 1
                if word == p[0]:
                    acc1 += 1
    if test:
        pred_rank = torch.tensor(pred_rank, dtype=torch.float32)
        return (acc1, acc10, acc100, pred_rank)
    else:
        return acc1/n, acc10/n, acc100/n

In [13]:
def test(loader, name):
    inc = 3
    model.eval()
    test_loss = 0.0
    test_acc1 = test_acc10 = test_acc100 = 0.0
    total_seen = 0
    all_pred = []
    with torch.no_grad():
        with tqdm(total=len(loader)) as pbar:
            for i, ((x, attention_mask), y) in enumerate(loader):
                if i % inc == 0 and i != 0:
                    display_loss = test_loss / i
                    pbar.set_description(f'Test Loss: {display_loss}')

                x = x.to(device)
                attention_mask = attention_mask.to(device)
                y = y.to(device)

#                 with autocast():
                loss, out = model(input_ids=x, attention_mask=attention_mask,
                                  ground_truth=y)

                test_loss += loss.detach()

                pbar.update(1)

                result, indices = torch.sort(out, descending=True)
                
                b = len(x)
                acc1, acc10, acc100, pred_rank = evaluate(indices, y, test=True)
                test_acc1 += acc1
                test_acc10 += acc10
                test_acc100 += acc100
                total_seen += b
                all_pred.extend(pred_rank)
                
                del x, y, out, loss
                if i % 20 == 0:
                    torch.cuda.empty_cache()
    
    test_loss /= len(loader)
    test_acc1 /= total_seen
    test_acc10 /= total_seen
    test_acc100 /= total_seen
    all_pred = torch.tensor(all_pred)
    median = torch.median(all_pred)
    var = torch.var(all_pred)**0.5
    
    print(f'{name}_test_loss:', test_loss)
    print(f'{name}_test_acc1:', test_acc1)
    print(f'{name}_test_acc10:', test_acc10)
    print(f'{name}_test_acc100:', test_acc100)
    print(f'{name}_test_rank_median:', median)
    print(f'{name}_test_rank_variance', var)
    
    return ({
        f'{name}_test_loss': test_loss,
        f'{name}_test_acc1': test_acc1,
        f'{name}_test_acc10': test_acc10,
        f'{name}_test_acc100': test_acc100,
        f'{name}_test_rank_median': median,
        f'{name}_test_rank_variance': var
    })
    

In [14]:
inc = 10
losses = []

for epoch in range(epoch, epochs):
    # Training
    model.train()
    train_loss = 0.0
    # Train on subset of training data to save time
    with tqdm(total=len(train_loader)) as pbar:
        for i, ((x, attention_mask), y) in enumerate(train_loader):
            if i % inc == 0 and i != 0:
                display_loss = train_loss / i
                pbar.set_description(f'Epoch {epoch+1}, Train Loss: {train_loss / i}')

            optim.zero_grad()

            x = x.to(device)
            attention_mask = attention_mask.to(device)
            y = y.to(device)
            
            loss, out = model(input_ids=x, attention_mask=attention_mask, 
                              ground_truth=y)

#             scaler.scale(loss).backward()
            loss.backward()
            
#             scaler.unscale_(optim)
            nn.utils.clip_grad_value_(model.parameters(), 5)
            
#             scaler.step(optim)
            optim.step()
#             scaler.update()
            
            train_loss += loss.detach()
            
            scheduler.step()
            
            pbar.update(1)
            
            del x, y, out, loss, attention_mask
            
    model_name = type(model).__name__
    filename = f'../trained_models/{model_name} Epoch {epoch+1} at {datetime.datetime.now()}'.replace(' ', '_')
    with open(filename, 'wb+') as f:
        torch.save({'state_dict': model.state_dict()}, f)
    
    # Validation
    model.eval()
    val_loss = 0.0
    val_acc1, val_acc10, val_acc100 = 0.0, 0.0, 0.0
    with torch.no_grad():
        with tqdm(total=len(dev_loader)) as pbar:
            for i, ((x, attention_mask), y) in enumerate(dev_loader):
                if i % inc == 0 and i != 0:
                    display_loss = val_loss / i
                    pbar.set_description(f'Epoch {epoch+1}, Val Loss: {val_loss / i}')

                x = x.to(device)
                attention_mask = attention_mask.to(device)
                y = y.to(device)

#                 with autocast():
                loss, out = model(input_ids=x, attention_mask=attention_mask,
                                  ground_truth=y)

                val_loss += loss.detach()

                pbar.update(1)                
                
                result, indices = torch.topk(out, k=100, dim=-1, largest=True, sorted=True)
                
                acc1, acc10, acc100 = evaluate(indices, y)
                val_acc1 += acc1
                val_acc10 += acc10
                val_acc100 += acc100

                del x, y, out, loss
    
    wandb.log({
        'train_loss': train_loss / len(train_loader),
        'val_loss': val_loss / len(dev_loader),
        'val_acc1': val_acc1 / len(dev_loader),
        'val_acc10': val_acc10 / len(dev_loader),
        'val_acc100': val_acc100 / len(dev_loader),
        **test(test_loader_seen, 'seen'),
        **test(test_loader_unseen, 'unseen'),
        **test(test_loader_desc, 'desc')
    })
    

  0%|          | 0/5280 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



  0%|          | 0/4 [00:00<?, ?it/s]

seen_test_loss: tensor(5.4102, device='cuda:0')
seen_test_acc1: 0.292
seen_test_acc10: 0.556
seen_test_acc100: 0.766
seen_test_rank_median: tensor(5.)
seen_test_rank_variance tensor(357.2637)


  0%|          | 0/4 [00:00<?, ?it/s]

unseen_test_loss: tensor(15.1948, device='cuda:0')
unseen_test_acc1: 0.004
unseen_test_acc10: 0.004
unseen_test_acc100: 0.006
unseen_test_rank_median: tensor(1000.)
unseen_test_rank_variance tensor(84.4202)


  0%|          | 0/2 [00:00<?, ?it/s]

desc_test_loss: tensor(4.2677, device='cuda:0')
desc_test_acc1: 0.29
desc_test_acc10: 0.585
desc_test_acc100: 0.85
desc_test_rank_median: tensor(5.)
desc_test_rank_variance tensor(212.4294)


  0%|          | 0/5280 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



  0%|          | 0/4 [00:00<?, ?it/s]

seen_test_loss: tensor(4.5583, device='cuda:0')
seen_test_acc1: 0.422
seen_test_acc10: 0.67
seen_test_acc100: 0.824
seen_test_rank_median: tensor(1.)
seen_test_rank_variance tensor(345.5203)


  0%|          | 0/4 [00:00<?, ?it/s]

unseen_test_loss: tensor(16.1992, device='cuda:0')
unseen_test_acc1: 0.002
unseen_test_acc10: 0.004
unseen_test_acc100: 0.008
unseen_test_rank_median: tensor(1000.)
unseen_test_rank_variance tensor(87.3857)


  0%|          | 0/2 [00:00<?, ?it/s]

desc_test_loss: tensor(4.1939, device='cuda:0')
desc_test_acc1: 0.255
desc_test_acc10: 0.575
desc_test_acc100: 0.835
desc_test_rank_median: tensor(6.)
desc_test_rank_variance tensor(224.9331)


  0%|          | 0/5280 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



  0%|          | 0/593 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

seen_test_loss: tensor(4.1312, device='cuda:0')
seen_test_acc1: 0.46
seen_test_acc10: 0.736
seen_test_acc100: 0.844
seen_test_rank_median: tensor(1.)
seen_test_rank_variance tensor(345.2461)


  0%|          | 0/4 [00:00<?, ?it/s]

unseen_test_loss: tensor(17.0107, device='cuda:0')
unseen_test_acc1: 0.004
unseen_test_acc10: 0.004
unseen_test_acc100: 0.008
unseen_test_rank_median: tensor(1000.)
unseen_test_rank_variance tensor(88.4190)


  0%|          | 0/2 [00:00<?, ?it/s]

desc_test_loss: tensor(4.2017, device='cuda:0')
desc_test_acc1: 0.26
desc_test_acc10: 0.555
desc_test_acc100: 0.845
desc_test_rank_median: tensor(6.)
desc_test_rank_variance tensor(219.3234)


  0%|          | 0/5280 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



  0%|          | 0/593 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

seen_test_loss: tensor(3.8241, device='cuda:0')
seen_test_acc1: 0.508
seen_test_acc10: 0.8
seen_test_acc100: 0.848
seen_test_rank_median: tensor(0.)
seen_test_rank_variance tensor(341.1228)


  0%|          | 0/4 [00:00<?, ?it/s]

unseen_test_loss: tensor(17.3610, device='cuda:0')
unseen_test_acc1: 0.004
unseen_test_acc10: 0.006
unseen_test_acc100: 0.008
unseen_test_rank_median: tensor(1000.)
unseen_test_rank_variance tensor(88.6852)


  0%|          | 0/2 [00:00<?, ?it/s]

desc_test_loss: tensor(4.2648, device='cuda:0')
desc_test_acc1: 0.265
desc_test_acc10: 0.56
desc_test_acc100: 0.85
desc_test_rank_median: tensor(7.)
desc_test_rank_variance tensor(235.8689)


  0%|          | 0/5280 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



  0%|          | 0/593 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

In [42]:
def getPredFromDesc(model, desc : str, top_n=10):
    desc = T(desc, return_tensors='pt', padding=True)
    x = desc['input_ids'].to(device)
    attention_mask = desc['attention_mask'].to(device)
    out = model(input_ids=x, attention_mask=attention_mask)
    result, indices = torch.topk(out, k=top_n, dim=-1, largest=True, sorted=True)
    
    indices = indices[0]
    return [idx2target[i] for i in indices], indices
    

In [18]:
test(test_loader_seen, 'seen') # epoch 1

  0%|          | 0/4 [00:00<?, ?it/s]

seen_test_loss: tensor(4.9566, device='cuda:0')
seen_test_acc1: 0.324
seen_test_acc10: 0.628
seen_test_acc100: 0.792
seen_test_rank_median: tensor(3.)
seen_test_rank_variance tensor(350.0158)


{'seen_test_loss': tensor(4.9566, device='cuda:0'),
 'seen_test_acc1': 0.324,
 'seen_test_acc10': 0.628,
 'seen_test_acc100': 0.792,
 'seen_test_rank_median': tensor(3.),
 'seen_test_rank_variance': tensor(350.0158)}

In [19]:
test(test_loader_unseen, 'unseen') # epoch 1

  0%|          | 0/4 [00:00<?, ?it/s]

unseen_test_loss: tensor(16.1984, device='cuda:0')
unseen_test_acc1: 0.002
unseen_test_acc10: 0.002
unseen_test_acc100: 0.004
unseen_test_rank_median: tensor(1000.)
unseen_test_rank_variance tensor(79.8172)


{'unseen_test_loss': tensor(16.1984, device='cuda:0'),
 'unseen_test_acc1': 0.002,
 'unseen_test_acc10': 0.002,
 'unseen_test_acc100': 0.004,
 'unseen_test_rank_median': tensor(1000.),
 'unseen_test_rank_variance': tensor(79.8172)}

In [20]:
test(test_loader_desc, 'desc') # epoch 1

  0%|          | 0/2 [00:00<?, ?it/s]

desc_test_loss: tensor(4.5469, device='cuda:0')
desc_test_acc1: 0.205
desc_test_acc10: 0.51
desc_test_acc100: 0.82
desc_test_rank_median: tensor(9.)
desc_test_rank_variance tensor(238.9242)


{'desc_test_loss': tensor(4.5469, device='cuda:0'),
 'desc_test_acc1': 0.205,
 'desc_test_acc10': 0.51,
 'desc_test_acc100': 0.82,
 'desc_test_rank_median': tensor(9.),
 'desc_test_rank_variance': tensor(238.9242)}

In [22]:
test(dev_loader, 'desc') # epoch 1

  0%|          | 0/593 [00:00<?, ?it/s]

desc_test_loss: tensor(16.4504, device='cuda:0')
desc_test_acc1: 0.001410251341056766
desc_test_acc10: 0.003980335560739657
desc_test_acc100: 0.005693725040528251
desc_test_rank_median: tensor(1000.)
desc_test_rank_variance tensor(76.4954)


{'desc_test_loss': tensor(16.4504, device='cuda:0'),
 'desc_test_acc1': 0.001410251341056766,
 'desc_test_acc10': 0.003980335560739657,
 'desc_test_acc100': 0.005693725040528251,
 'desc_test_rank_median': tensor(1000.),
 'desc_test_rank_variance': tensor(76.4954)}

In [45]:
getPredFromDesc(model, 'an inhabitant of a cold country', 100)

(['neve',
  'highlander',
  'home',
  'glacial',
  'countryman',
  'winters',
  'frontiersman',
  'lee',
  'mansfield',
  'sylvan',
  'montane',
  'snowbird',
  'alien',
  'overwinter',
  'frontier',
  'winter',
  'icehouse',
  'clime',
  'provincial',
  'in',
  'pathan',
  'mountaineer',
  'european',
  'glaciated',
  'siberian',
  'climate',
  'interning',
  'iceland',
  'forester',
  'icecap',
  'district',
  'elkhound',
  'bohemian',
  'denizen',
  'himalayan',
  'neighbor',
  'cold',
  'internal',
  'thaw',
  'familiar',
  'outstation',
  'afghan',
  'snowfall',
  'chilled',
  'interior',
  'refrigerated',
  'cooler',
  'villager',
  'hole',
  'grot',
  'igloo',
  'taiga',
  'gypsy',
  'refrigerant',
  'frosty',
  'hibernate',
  'iceman',
  'clown',
  'snowcap',
  'province',
  'outerwear',
  'den',
  'malamute',
  'swiss',
  'hun',
  'dweller',
  'interglacial',
  'immigrant',
  'communist',
  'visitor',
  'bavarian',
  'haft',
  'logan',
  'frozen',
  'inhabitant',
  'icing',
  

In [47]:
getPredFromDesc(model, 'employee at a circus', 100)

(['bughouse',
  'sideshow',
  'museology',
  'clown',
  'mountebank',
  'impresario',
  'butchers',
  'fairground',
  'loge',
  'pageant',
  'shew',
  'shop',
  'canteen',
  'shopping',
  'panopticon',
  'cad',
  'trapeze',
  'butcher',
  'professor',
  'dpa',
  'gypsy',
  'carnival',
  'exhibition',
  'stunt',
  'circus',
  'bazaar',
  'stalls',
  'hostel',
  'stunting',
  'snark',
  'museum',
  'usher',
  'nox',
  'garter',
  'sport',
  'parquet',
  'company',
  'pavillion',
  'pavilion',
  'upstage',
  'goat',
  'serai',
  'barnstorm',
  'hosiery',
  'roustabout',
  'situation',
  'bellboy',
  'troupe',
  'gripping',
  'carrousel',
  'fooling',
  'laboratory',
  'commissary',
  'barrack',
  'conservatory',
  'publican',
  'illusionist',
  'circle',
  'spectator',
  'mansfield',
  'expo',
  'staged',
  'showgrounds',
  'town',
  'amphitheater',
  'acy',
  'buffalo',
  'spectacular',
  'edmonton',
  'theater',
  'trapped',
  'theatergoer',
  'school',
  'vault',
  'lamasery',
  'gymna

In [48]:
getPredFromDesc(model, 'a road on which cars can go fast', 100)

(['fasting',
  'mobile',
  'roll',
  'motorcade',
  'drive',
  'speedster',
  'csr',
  'flash',
  'streamed',
  'fast',
  'hitchhiking',
  'crash',
  'roadster',
  'tracks',
  'cab',
  'speedway',
  'rally',
  'detroit',
  'van',
  'race',
  'hollywood',
  'street',
  'driveway',
  'combustion',
  'concourse',
  'slick',
  'carpool',
  'brisk',
  'rushing',
  'rolled',
  'cruiser',
  'compact',
  'trotting',
  'car',
  'runs',
  'chase',
  'riad',
  'tailgate',
  'swarming',
  'trucks',
  'rattle',
  'driving',
  'rallying',
  'carouse',
  'taxiing',
  'lurch',
  'shuttle',
  'bowled',
  'zip',
  'film',
  'rushes',
  'highway',
  'jeep',
  'speedo',
  'run',
  'fleet',
  'dribble',
  'highball',
  'snowplow',
  'bua',
  'mush',
  'corner',
  'clatter',
  'cannonball',
  'pickup',
  'bike',
  'close',
  'camber',
  'rattling',
  'ruck',
  'banish',
  'inside',
  'cam',
  'tipple',
  'go',
  'quickstep',
  'motorized',
  'boxcar',
  'flicking',
  'fender',
  'alert',
  'whip',
  'hasten

In [24]:
model_name = type(model).__name__
filename = f'../trained_models/{model_name} Epoch {epoch+1} at {datetime.datetime.now()}'.replace(' ', '_')
with open(filename, 'wb+') as f:
    torch.save({'state_dict': model.state_dict()}, f)

In [35]:
T("you are not helpless", return_tensors='pt', padding=True)

{'input_ids': tensor([[  101,  2017,  2024,  2025, 13346,   102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}