In [1]:
import torch
from torch import nn
# import torchtext
# Should also first install pytorch-transformers (aka transformers)
# See here https://pytorch.org/hub/huggingface_pytorch-transformers/
# and here https://huggingface.co/transformers/
# You might also have to manually pip install sentencepiece
from torch.cuda.amp import autocast, GradScaler

import numpy as np

from typing import List
from tqdm import tqdm
import sys
sys.path.append('../code')

from dataset import get_data, WantWordsDataset as WWData
import datetime

from transformers import AdamW, get_linear_schedule_with_warmup

In [5]:
# Load the 'bert-base-uncased' tokenizer through torch.hub, using the
# 'tokenizer' entry point of the huggingface/pytorch-transformers hub repo.
# The tokenizer is handed to every WantWordsDataset built further down.
tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-uncased')    

Downloading: "https://github.com/huggingface/pytorch-transformers/archive/master.zip" to /home/ubuntu/.cache/torch/hub/master.zip


In [3]:
# Load the pretrained 'bert-base-uncased' encoder through torch.hub, using the
# 'model' entry point of the huggingface/pytorch-transformers hub repo.
# It is only needed when a fresh BertBaseline is constructed instead of
# loading a saved checkpoint.
bert = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')    

Using cache found in /home/ubuntu/.cache/torch/hub/huggingface_pytorch-transformers_master


In [18]:
class BertBaseline(nn.Module):
    """BERT encoder followed by a small MLP that regresses a word embedding.

    The representation of the first token of the encoder's last hidden state
    is passed through a two-layer MLP to produce a vector of size
    ``embedding_dim``.

    Args:
        embedding_dim: size of the target word-embedding space.
        encoder: the BERT model; called as ``encoder(x, attention_mask)`` and
            expected to return a mapping containing ``'last_hidden_state'``.
        dropout: dropout probability used before each linear layer.
    """

    def __init__(self, embedding_dim, encoder, dropout=0.1):
        super(BertBaseline, self).__init__()
        self.encoder = encoder # the BERT model
        # Use the encoder's own hidden size when it advertises one; fall back
        # to 768 (the value previously hard-coded here) otherwise.
        hidden_size = getattr(getattr(encoder, 'config', None), 'hidden_size', 768)
        self.decoder = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, embedding_dim)
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise the decoder's linear layers and zero their biases.

        Only the freshly created decoder is touched; the pretrained encoder
        weights are left as they are.
        """
        for layer in self.decoder:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, x : torch.Tensor, attention_mask : torch.Tensor):
        """Encode ``x`` and project the first-token state to an embedding.

        Args:
            x: token ids passed to the encoder as its first argument.
            attention_mask: mask passed to the encoder as its second argument.

        Returns:
            Tensor whose last dimension is ``embedding_dim`` (one row per
            input sequence).
        """
        out = self.encoder(x, attention_mask)
        # Position 0 along the sequence axis is the first token
        # (presumably [CLS] for a BERT tokenizer).
        encoded = out['last_hidden_state'][:, 0, :]
        return self.decoder(encoded)

In [3]:
# Load the data splits (`d`, unpacked in the next cell) and the word2vec
# embeddings from the WantWords English baseline data directory.
d, word2vec = get_data('../wantwords-english-baseline/data')

Loading data...
word2vec: 75099 vectors
Training data: 675715 word-def pairs
Dev data: 75873 word-def pairs
Test data: 1200 word-def pairs


In [6]:
(train_data, train_data_def, dev_data,
 test_data_seen, test_data_unseen, test_data_desc) = d


def _as_dataset(pairs):
    """Wrap a list of word/definition pairs in a WantWords dataset."""
    return WWData(pairs, word2vec, 300, tokenizer)


# Training uses both training lists concatenated together.
train_dataset = _as_dataset(train_data + train_data_def)
dev_dataset = _as_dataset(dev_data)

# Three distinct test sets: seen words, unseen words, human descriptions.
test_dataset_seen = _as_dataset(test_data_seen)
test_dataset_unseen = _as_dataset(test_data_unseen)
test_dataset_desc = _as_dataset(test_data_desc)

In [7]:
batch_size = 32
num_workers = 4


def make_loader(dataset, shuffle):
    """Build a DataLoader over ``dataset`` using its own ``collate_fn``.

    Args:
        dataset: dataset exposing a ``collate_fn`` attribute.
        shuffle: whether batches are drawn in random order. This flag is now
            honoured; previously every loader was shuffled regardless.

    Returns:
        A ``torch.utils.data.DataLoader`` with the notebook-wide
        ``batch_size`` and ``num_workers``.
    """
    return torch.utils.data.DataLoader(
        dataset, shuffle=shuffle, pin_memory=False,
        batch_size=batch_size, num_workers=num_workers,
        collate_fn=dataset.collate_fn)


train_loader = make_loader(train_dataset, True)
print(f'Train loader: {len(train_loader)}')
dev_loader = make_loader(dev_dataset, True)
print(f'Dev loader: {len(dev_loader)}')

# Test loaders keep dataset order so repeated evaluations see identical batches.
test_loader_seen = make_loader(test_dataset_seen, False)
print(f'Test loader (seen): {len(test_loader_seen)}')
test_loader_unseen = make_loader(test_dataset_unseen, False)
print(f'Test loader (unseen): {len(test_loader_unseen)}')
test_loader_desc = make_loader(test_dataset_desc, False)
print(f'Test loader (descriptions): {len(test_loader_desc)}')

Train loader: 21117
Dev loader: 2372
Test loader (seen): 16
Test loader (unseen): 16
Test loader (descriptions): 7


In [8]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [19]:
dropout = 0.1
# Resume from a saved checkpoint. The file is a pickle of the whole module, so
# the BertBaseline class must already be defined in this kernel.
# NOTE(security): torch.load unpickles arbitrary objects; only load checkpoints
# you produced yourself or otherwise trust.
# map_location lets a checkpoint saved on a GPU be opened on a CPU-only machine.
model = torch.load('../trained_models/BertBaseline_Epoch_10_at_2021-04-11_00:55:19.961444',
                   map_location=device)
# To train from scratch instead, build a fresh model:
# model = BertBaseline(300, bert, dropout=dropout)
model = model.to(device)
model

BertBaseline(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [27]:
# Cosine similarity between predicted and target embeddings; the training
# loop turns it into a loss as 1 - mean similarity.
criterion = nn.CosineSimilarity()

epochs = 5

lr = 5e-5
weight_decay = 1e-6
optim = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# The training loop only visits half of the training batches per epoch, so the
# schedule length is derived from that same number (and from `epochs`, rather
# than a literal 5) to make the learning rate reach zero on the last step.
steps_per_epoch = len(train_loader) // 2
scheduler = get_linear_schedule_with_warmup(
    optim,
    num_warmup_steps=steps_per_epoch,  # warm up over the first (half-)epoch
    num_training_steps=epochs * steps_per_epoch)

# Index of the next epoch to run; the training loop resumes from this value.
epoch = 0

In [12]:
# Gradient scaler for mixed-precision training; the training loop uses it to
# scale the loss before backward() and to step the optimizer.
scaler = GradScaler()

In [13]:
import wandb

# Start a Weights & Biases run and record the hyperparameters of this session.
wandb.init(project='reverse-dictionary', entity='reverse-dict')

config = wandb.config
config.learning_rate = lr
config.epochs = epochs
config.batch_size = batch_size
config.optimizer = type(optim).__name__
config.scheduler = type(scheduler).__name__  # key was previously misspelled 'scheulder'
config.dropout = dropout

# wandb.watch(model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mreverse-dict[0m (use `wandb login --relogin` to force relogin)


In [35]:
def evaluate(pred, gt, test=False):
    """Compute top-1/10/100 accuracy (and optionally ranks) for one batch.

    Args:
        pred: per-example rankings of vocabulary indices, best candidate
            first (one row per example).
        gt: the target vocabulary index for each example, aligned with
            ``pred``.
        test: when True, return raw hit counts plus the rank of every target
            instead of batch-averaged accuracies.

    Returns:
        If ``test`` is False: ``(acc1, acc10, acc100)`` as fractions of the
        batch (all 0.0 for an empty batch).
        If ``test`` is True: ``(hits1, hits10, hits100, ranks)`` where the
        hits are counts and ``ranks`` is a float32 tensor holding, for each
        example, the 0-based position of the target in its ranking, capped
        at 1000; targets absent from the ranking also get 1000.
    """
    acc1 = acc10 = acc100 = 0
    n = len(pred)
    pred_rank = []
    for p, word in zip(pred, gt):
        if test:
            # nonzero(as_tuple=True) returns a tuple with one index tensor per
            # dimension, so the emptiness check has to look at the tensor
            # itself, not at the tuple's length.
            positions = (p == word).nonzero(as_tuple=True)[-1]
            if positions.numel() != 0:
                pred_rank.append(min(int(positions[0]), 1000))
            else:
                pred_rank.append(1000)
        # Nested checks: a top-1 hit is necessarily a top-10 and top-100 hit.
        if word in p[:100]:
            acc100 += 1
            if word in p[:10]:
                acc10 += 1
                if word == p[0]:
                    acc1 += 1
    if test:
        pred_rank = torch.tensor(pred_rank, dtype=torch.float32)
        return (acc1, acc10, acc100, pred_rank)
    if n == 0:
        # Nothing to average over; avoid dividing by zero.
        return 0.0, 0.0, 0.0
    return acc1/n, acc10/n, acc100/n

In [10]:
# Detached, transposed copy of the word2vec embedding matrix on the compute
# device. Later cells right-multiply model outputs by it (`out.mm(embeds)`) to
# score every vocabulary word — presumably shape (embedding_dim, vocab_size).
embeds = word2vec.embeddings.detach().clone().T.to(device)

In [16]:
inc = 10  # refresh the progress-bar description every `inc` batches
losses = []

# Train on a subset (half) of the training batches each epoch to save time.
steps_per_epoch = len(train_loader) // 2

# Resumes from the current value of `epoch`, so re-running this cell after an
# interruption continues instead of starting over.
for epoch in range(epoch, epochs):
    # ---------------------------------------------------------------- train
    model.train()
    train_loss = 0.0
    with tqdm(total=steps_per_epoch) as pbar:
        for i, (x, (y, _)) in zip(range(steps_per_epoch), train_loader):
            if i % inc == 0 and i != 0:
                pbar.set_description(f'Epoch {epoch+1}, Train Loss: {train_loss / i}')

            optim.zero_grad()

            x, attention_mask = x.input_ids.to(device), x.attention_mask.to(device)
            y = y.to(device)

            with autocast():
                out = model(x, attention_mask)
                # Loss is 1 - mean cosine similarity over the batch.
                loss = 1 - criterion(out, y).sum() / len(x)

            # Mixed-precision step: scale the loss, then let the scaler drive
            # the optimizer.
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()

            train_loss += loss.detach()

            scheduler.step()

            pbar.update(1)

            # Release per-batch tensors promptly to keep GPU memory flat.
            del x, y, out, loss
            if i % 20 == 0:
                torch.cuda.empty_cache()

    # ----------------------------------------------------------- checkpoint
    model_name = type(model).__name__
    filename = f'../trained_models/{model_name} Epoch {epoch+1} at {datetime.datetime.now()}'.replace(' ', '_')
    with open(filename, 'wb') as f:
        torch.save(model, f)

    # ------------------------------------------------------------- validate
    model.eval()
    val_loss = 0.0
    val_acc1, val_acc10, val_acc100 = 0.0, 0.0, 0.0
    val_seen = 0
    with torch.no_grad():
        with tqdm(total=len(dev_loader)) as pbar:
            for i, (x, (y, y_inds)) in enumerate(dev_loader):
                if i % inc == 0 and i != 0:
                    pbar.set_description(f'Epoch {epoch+1}, Val Loss: {val_loss / i}')

                x, attention_mask = x.input_ids.to(device), x.attention_mask.to(device)
                y = y.to(device)

                with autocast():
                    out = model(x, attention_mask)

                    loss = 1 - criterion(out, y).sum() / len(x)

                    # Score every vocabulary word against the prediction.
                    probs = out.mm(embeds)

                val_loss += loss.detach()

                pbar.update(1)

                _, indices = torch.sort(probs, descending=True)

                # evaluate() returns per-batch fractions; weight them by the
                # batch size so a smaller final batch does not skew the
                # per-example averages (same convention as test()).
                b = len(x)
                acc1, acc10, acc100 = evaluate(indices, y_inds)
                val_acc1 += acc1 * b
                val_acc10 += acc10 * b
                val_acc100 += acc100 * b
                val_seen += b

                del x, y, out, loss, probs, indices
                if i % 20 == 0:
                    torch.cuda.empty_cache()

    dev_length = len(dev_loader)
    val_acc1 /= max(val_seen, 1)
    val_acc10 /= max(val_seen, 1)
    val_acc100 /= max(val_seen, 1)

    # One log call per epoch so all metrics share the same wandb step.
    wandb.log({
        'train_loss': float(train_loss) / max(steps_per_epoch, 1),
        'val_loss': float(val_loss) / max(dev_length, 1),
        'val_acc1': val_acc1,
        'val_acc10': val_acc10,
        'val_acc100': val_acc100,
    })

Epoch 1, Train Loss: 0.6412608623504639: 100%|██████████| 10558/10558 [22:50<00:00,  7.71it/s]
Epoch 1, Val Loss: 0.5484607219696045: 100%|██████████| 2639/2639 [02:03<00:00, 21.39it/s]
Epoch 2, Train Loss: 0.5477613806724548: 100%|██████████| 10558/10558 [23:06<00:00,  7.61it/s]
Epoch 2, Val Loss: 0.4984854459762573: 100%|██████████| 2639/2639 [02:03<00:00, 21.37it/s] 
Epoch 3, Train Loss: 0.5111475586891174: 100%|██████████| 10558/10558 [23:06<00:00,  7.61it/s]
Epoch 3, Val Loss: 0.4665115475654602: 100%|██████████| 2639/2639 [02:04<00:00, 21.21it/s] 
Epoch 4, Train Loss: 0.48678988218307495: 100%|██████████| 10558/10558 [23:02<00:00,  7.64it/s]
Epoch 4, Val Loss: 0.4433624744415283: 100%|██████████| 2639/2639 [02:04<00:00, 21.26it/s] 
Epoch 5, Train Loss: 0.4698292315006256: 100%|██████████| 10558/10558 [23:02<00:00,  7.64it/s] 
Epoch 5, Val Loss: 0.433894544839859: 100%|██████████| 2639/2639 [02:03<00:00, 21.40it/s]  


In [24]:
# Informally test the model on a few hand-written descriptions.
model.eval()
x, y = train_dataset.collate_fn([("a type of gun", ''),
                                 ("native of cold country", ""),
                                 ("person who knows a lot", "")])
# there seem to be a lot of gun-related entries in the dictionary...

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    out = model(x.input_ids.to(device), x.attention_mask.to(device))

    # Get most likely words: score every vocabulary embedding against each
    # predicted vector and rank them best-first.
    logits = out @ embeds
    result, indices = torch.sort(logits, descending=True)

# Print the ten best candidates for each query, one query per paragraph.
for row in indices:
    for j in row[:10]:
        print(train_dataset.embeddings.itos[int(j)])
    print()

suplex
musket
flintlock
saucepan
bucktail
harpoon
olp
bobber
awl
chokeslam

steppes
steppe
taiga
savannahs
yeti
tundra
pikas
wanderer
frontiersman
pika

mensch
thinker
snob
polemicist
philosopher
tinkerer
plagiarist
polymath
ignoramus
autodidact



In [38]:
def test(loader, name):
    """Evaluate the global ``model`` on ``loader`` and print summary metrics.

    Prints the mean per-batch loss, per-example top-1/10/100 accuracy, the
    median rank of the target word and the spread of the ranks.

    Args:
        loader: iterable of ``(x, (y, y_inds))`` batches, where ``x`` exposes
            ``input_ids`` and ``attention_mask``.
        name: prefix used for the printed metric names.

    Returns:
        None; results are printed.
    """
    inc = 3  # refresh the progress-bar description every `inc` batches
    model.eval()
    test_loss = 0.0
    test_acc1 = test_acc10 = test_acc100 = 0.0
    total_seen = 0
    rank_chunks = []
    with torch.no_grad():
        with tqdm(total=len(loader)) as pbar:
            for i, (x, (y, y_inds)) in enumerate(loader):
                if i % inc == 0 and i != 0:
                    display_loss = test_loss / i
                    pbar.set_description(f'Test Loss: {display_loss}')

                x, attention_mask = x.input_ids.to(device), x.attention_mask.to(device)
                y = y.to(device)

                with autocast():
                    out = model(x, attention_mask)

                    # Loss is 1 - mean cosine similarity over the batch.
                    loss = 1 - criterion(out, y).sum() / len(x)

                    # Score every vocabulary word against the prediction.
                    probs = out.mm(embeds)

                test_loss += loss.detach()

                pbar.update(1)

                _, indices = torch.sort(probs, descending=True)

                b = len(x)
                # With test=True, evaluate() returns hit counts and a tensor
                # of per-example ranks rather than batch averages.
                acc1, acc10, acc100, pred_rank = evaluate(indices, y_inds, test=True)
                test_acc1 += acc1
                test_acc10 += acc10
                test_acc100 += acc100
                total_seen += b
                rank_chunks.append(pred_rank.reshape(-1))

                del x, y, out, loss, probs, indices
                if i % 20 == 0:
                    torch.cuda.empty_cache()

    if total_seen == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        print(f'{name}: no examples to evaluate')
        return

    test_loss /= len(loader)
    test_acc1 /= total_seen
    test_acc10 /= total_seen
    test_acc100 /= total_seen
    all_pred = torch.cat(rank_chunks)
    median = torch.median(all_pred)
    # NOTE: despite the "variance" label below, this is the square root of the
    # variance (i.e. the standard deviation) of the ranks. The label is kept
    # so printed results stay comparable with earlier runs.
    var = torch.var(all_pred)**0.5

    print(f'{name}_test_loss:', test_loss)
    print(f'{name}_test_acc1:', test_acc1)
    print(f'{name}_test_acc10:', test_acc10)
    print(f'{name}_test_acc100:', test_acc100)
    print(f'{name}_test_rank_median:', median)
    print(f'{name}_test_rank_variance', var)

In [39]:
# Evaluate on the "unseen" test split.
test(test_loader_unseen, 'unseen')

Test Loss: 0.4879764914512634: 100%|██████████| 16/16 [00:01<00:00, 15.87it/s] 

unseen_test_loss: tensor(0.4880, device='cuda:0')
unseen_test_acc1: 0.028
unseen_test_acc10: 0.188
unseen_test_acc100: 0.39
unseen_test_rank_median: tensor(211.)
unseen_test_rank_variance tensor(439.5819)





In [40]:
# Evaluate on the "seen" test split.
test(test_loader_seen, 'seen')

Test Loss: 0.39735373854637146: 100%|██████████| 16/16 [00:01<00:00, 14.76it/s]

seen_test_loss: tensor(0.3959, device='cuda:0')
seen_test_acc1: 0.14
seen_test_acc10: 0.39
seen_test_acc100: 0.612
seen_test_rank_median: tensor(34.)
seen_test_rank_variance tensor(387.9576)





In [41]:
# Evaluate on the descriptions test split.
test(test_loader_desc, 'description')

Test Loss: 0.4008331894874573: 100%|██████████| 7/7 [00:00<00:00, 11.10it/s]

description_test_loss: tensor(0.4055, device='cuda:0')
description_test_acc1: 0.085
description_test_acc10: 0.33
description_test_acc100: 0.62
description_test_rank_median: tensor(38.)
description_test_rank_variance tensor(369.0333)





In [16]:
# Display the module tree of the current model.
model

BertBaseline(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr