In [1]:
import torch
from torch import nn
# import torchtext
# Should also first install pytorch-transformers (aka transformers)
# See here https://pytorch.org/hub/huggingface_pytorch-transformers/
# and here https://huggingface.co/transformers/
# You might also have to manually pip install sentencepiece
from torch.cuda.amp import autocast, GradScaler

import numpy as np

from typing import List
from tqdm import tqdm
import sys
import datetime

sys.path.append('../code')
from dataset import get_data, WantWordsDataset as WWData

from transformers import (
    AdamW, get_linear_schedule_with_warmup, 
    EncoderDecoderModel, BertGenerationEncoder, BertGenerationDecoder, BertTokenizer
)

import gc

In [2]:
# Download vocabulary from S3 and cache.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')    

In [3]:
# Download model and configuration from S3 and cache.
enc_dec = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer

In [4]:
enc_dec

EncoderDecoderModel(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [5]:
tokens = tokenizer(["a hot and dry place", "something you eat after dinner"], return_tensors='pt', padding=True)
print('input', tokens)
ground_truth = tokenizer(["desert", "dessert"], return_tensors='pt', padding=True)
print('ground truth', ground_truth)

input_ids = tokens['input_ids']
print(input_ids.shape)
attention_mask = tokens['attention_mask']
out = enc_dec(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=input_ids)

print('out', out)
logits = out['logits']
print(logits.shape)

best = logits.argmax(-1)
print(best.shape)
tokenizer.decode(best[0])

print(type(tokens))

input {'input_ids': tensor([[ 101, 1037, 2980, 1998, 4318, 2173,  102],
        [ 101, 2242, 2017, 4521, 2044, 4596,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1]])}
ground truth {'input_ids': tensor([[  101,  5532,   102],
        [  101, 18064,   102]]), 'token_type_ids': tensor([[0, 0, 0],
        [0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1],
        [1, 1, 1]])}
torch.Size([2, 7])
out Seq2SeqLMOutput(loss=None, logits=tensor([[[ -9.1125,  -9.8064,  -9.7444,  ...,  -6.9838,  -6.4828,  -5.1133],
         [ -8.6599,  -9.0282,  -9.0602,  ...,  -8.1490,  -7.1236,  -2.8740],
         [-10.6954, -10.9401, -10.9348,  ...,  -9.0845,  -7.7530,  -4.8472],
         ...,
         [-11.4635, -11.7717, -11.7954,  ..., -10.1804,  -8.9493,  -5.1916],
         [-11.4487, -11.5727, -11.6773,  ...,  -9.4174,  -8.9801,  -3.8844],
         [-10.3390, -10.6331, -10.5138,  ...,  

In [6]:
criterion = nn.CrossEntropyLoss()

gt_lens = torch.sum(ground_truth['attention_mask'], dim=-1) - 2
gt_input_ids = ground_truth['input_ids']
Y = [gt_input_ids[i][1:1+gt_lens[i]] for i in range(len(gt_lens))]
X = [logits[i][1:1+gt_lens[i]] for i in range(len(gt_lens))]

print(X)
print(Y)

batch_loss = sum(criterion(x, y) for x, y in zip(X, Y))
batch_loss / len(gt_lens)

[tensor([[-8.6599, -9.0282, -9.0602,  ..., -8.1490, -7.1236, -2.8740]],
       grad_fn=<SliceBackward>), tensor([[-9.2955, -9.6788, -9.4659,  ..., -9.0749, -6.9601, -6.7503]],
       grad_fn=<SliceBackward>)]
[tensor([5532]), tensor([18064])]


tensor(15.3459, grad_fn=<DivBackward0>)

In [7]:
class BertEncDec(nn.Module):
    def __init__(self, enc_dec, criterion):
        super(BertEncDec, self).__init__()
        self.enc_dec = enc_dec # the BERT encoder/decoder
        self.criterion = criterion
    
    def _init_weights(self):
        nn.init.xavier_normal_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)
        
    def forward(self, x , y):
        '''Where x, y are BatchEncodings returned by a tokenizer object'''
        input_ids, attention_mask = x['input_ids'], x['attention_mask']
        batch_size = len(input_ids)
        
        out = self.enc_dec(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=input_ids)
        
        logits = out['logits']
        gt = y['input_ids']
        gt_lens = torch.sum(y['attention_mask'], dim=-1) - 2 # subtract 2 to account for start/end tokens
        Y = (gt[i][1:1+gt_lens[i]] for i in range(batch_size))
        X = (logits[i][1:1+gt_lens[i]] for i in range(batch_size))
        batch_loss = sum(self.criterion(x, y) for x, y in zip(X, Y))
        
        return logits, batch_loss / batch_size

In [8]:
d, word2vec = get_data('../wantwords-english-baseline/data')

Loading data...
word2vec: 75099 vectors
Training data: 675715 word-def pairs
Dev data: 75873 word-def pairs
Test data: 1200 word-def pairs


In [9]:
train_data, train_data_def, dev_data, test_data_seen, \
    test_data_unseen, test_data_desc = d

train_dataset = WWData(train_data + train_data_def, word2vec, 300, tokenizer)
dev_dataset = WWData(dev_data, word2vec, 300, tokenizer)
# Three distinct test sets
test_dataset_seen = WWData(test_data_seen, word2vec, 300, tokenizer)
test_dataset_unseen = WWData(test_data_unseen, word2vec, 300, tokenizer)
test_dataset_desc = WWData(test_data_desc, word2vec, 300, tokenizer)

In [10]:
batch_size = 16
num_workers = 4

make_loader = lambda dataset, shuffle: \
                torch.utils.data.DataLoader(
                        dataset, shuffle=shuffle, pin_memory=False,
                        batch_size=batch_size, num_workers=num_workers,
                        collate_fn=lambda x: dataset.collate_fn(x, word2vec=False))

train_loader = make_loader(train_dataset, True)
print(f'Train loader: {len(train_loader)}')
dev_loader = make_loader(dev_dataset, True)
print(f'Dev loader: {len(dev_loader)}')

test_loader_seen = make_loader(test_dataset_seen, False)
print(f'Test loader (seen): {len(test_loader_seen)}')
test_loader_unseen = make_loader(test_dataset_unseen, False)
print(f'Test loader (unseen): {len(test_loader_unseen)}')
test_loader_desc = make_loader(test_dataset_desc, False)
print(f'Test loader (descriptions): {len(test_loader_desc)}')

Train loader: 42233
Dev loader: 4743
Test loader (seen): 32
Test loader (unseen): 32
Test loader (descriptions): 13


In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [12]:
# model = torch.load('../trained_models/bert_baseline_wwdata.pt')
criterion = nn.CrossEntropyLoss()
model = BertEncDec(enc_dec, criterion)
model = model.to(device)
model

BertEncDec(
  (enc_dec): EncoderDecoderModel(
    (encoder): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
     

In [13]:
epochs = 5

lr = 5e-5
optim = AdamW(model.parameters(), lr=lr)
scheduler = get_linear_schedule_with_warmup(optim, 
                                            num_warmup_steps=(len(train_loader) // 10), 
                                            num_training_steps=(epochs * len(train_loader)))

epoch = 0

In [14]:
scaler = GradScaler()

In [15]:
import wandb

wandb.init(project='reverse-dictionary', entity='reverse-dict')

config = wandb.config
config.learning_rate = lr
config.epochs = epochs
config.batch_size = batch_size
config.optimizer = type(optim).__name__
config.scheulder = type(scheduler).__name__

# wandb.watch(model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mreverse-dict[0m (use `wandb login --relogin` to force relogin)


In [16]:
def evaluate(pred, gt, test=False):
    acc1 = acc10 = acc100 = 0
    n = len(pred)
    pred_rank = []
    for p, word in zip(pred, gt):
        if test:
            loc = (p == word).nonzero(as_tuple=True)
            if len(loc) != 0:
                pred_rank.append(max(loc[-1], 1000))
            else:
                pred_rank.append(1000)
        if word in p[:100]:
            acc100 += 1
            if word in p[:10]:
                acc10 += 1
                if word == p[0]:
                    acc1 += 1
    if test:
        pred_rank = torch.tensor(pred_rank, dtype=torch.float32)
        return (acc1/n, acc10/n, acc100/n, 
                torch.median(pred_rank), torch.sqrt(torch.var(pred_rank)))
    else:
        return acc1/n, acc10/n, acc100/n

In [17]:
# embeds = word2vec.embeddings.detach().clone().T.to(device)

In [18]:
inc = 10
losses = []

for epoch in range(epoch, epochs):
    model.train()
    train_loss = 0.0
    length = len(train_loader)
    # Train on subset of training data to save time
    with tqdm(total=len(train_loader)) as pbar:
        for i, (x, y) in enumerate(train_loader):
            if i % inc == 0 and i != 0:
                display_loss = train_loss / i
                pbar.set_description(f'Epoch {epoch+1}, Train Loss: {train_loss / i}')

                
            if i == length // 4 or i == length // 2 or i == 3 * length // 4:
                model_name = type(model).__name__
                if i == length // 4:
                    frac = '.25'
                elif i == length // 2:
                    frac = '.5'
                else:
                    frac = '.75'
                filename = f'../trained_models/{model_name} Epoch {epoch+1}{frac} at {datetime.datetime.now()}'.replace(' ', '_')
                with open(filename, 'wb+') as f:
                    torch.save(model, f)

            optim.zero_grad()
            x['input_ids'] = x['input_ids'].to(device)
            x['attention_mask'] = x['attention_mask'].to(device)
            y['input_ids'] = y['input_ids'].to(device)
                
            with autocast():
                out, loss = model(x, y)
                
            scaler.scale(loss).backward()
            scaler.step(optim)
            scaler.update()
            
            train_loss += loss.detach()
            
            scheduler.step()
            
            pbar.update(1)
            
            del x, y, out, loss
            if i % 20 == 0:
                torch.cuda.empty_cache()
    wandb.log({'train_loss': train_loss/(len(train_loader)//2)})

    model_name = type(model).__name__
    filename = f'../trained_models/{model_name} Epoch {epoch+1} at {datetime.datetime.now()}'.replace(' ', '_')
    with open(filename, 'wb+') as f:
        torch.save(model, f)
    
    model.eval()
    val_loss = 0.0
    val_acc1, val_acc10, val_acc100 = 0.0, 0.0, 0.0
    with torch.no_grad():
        with tqdm(total=len(dev_loader)) as pbar:
            for i, (x, y) in enumerate(dev_loader):
                if i % inc == 0 and i != 0:
                    display_loss = val_loss / i
                    pbar.set_description(f'Epoch {epoch+1}, Val Loss: {val_loss / i}')

                x['input_ids'] = x['input_ids'].to(device)
                x['attention_mask'] = x['attention_mask'].to(device)
                y['input_ids'] = y['input_ids'].to(device)
                
                with autocast():
                    out, loss = model(x, y)

                val_loss += loss.detach()

                pbar.update(1)
                
                
#                 result, indices = torch.sort(probs, descending=True)
                
#                 acc1, acc10, acc100 = evaluate(indices, y_inds)
#                 val_acc1 += acc1
#                 val_acc10 += acc10
#                 val_acc100 += acc100

                del x, y, out, loss
                if i % 20 == 0:
                    torch.cuda.empty_cache()
        
#     dev_length = len(dev_loader)
#     wandb.log({'val_loss': val_loss/dev_length})    
#     wandb.log({'val_acc1': val_acc1/dev_length})
#     wandb.log({'val_acc10': val_acc10/dev_length})
#     wandb.log({'val_acc100': val_acc100/dev_length})

Epoch 1, Train Loss: 5.447841644287109:  69%|██████▊   | 28996/42233 [1:44:25<47:11,  4.67it/s]   IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Epoch 1, Train Loss: 5.2306671142578125: 100%|██████████| 42233/42233 [2:32:00<00:00,  4.63it/s] 
Epoch 1, Val Loss: 9.999505996704102: 100%|██████████| 4743/4743 [03:31<00:00, 22.41it/s] 
Epoch 2, Train Loss: 3.683689594268799: 100%|██████████| 42233/42233 [2:32:04<00:00,  4.63it/s]   
Epoch 2, Val Loss: 10.356205940246582: 100%|██████████| 4743/4743 [03:32<00:00, 22.33it/s]
Epoch 3, Train Loss: 2.8184902667999268: 100%|██████████| 42233/42233 [2:32:10<00:00,  4.63it/s]  
Epoch 3, Val Loss: 10.520895004272461: 100%|██████████| 4743/4743 [03:32<00:00, 22.29it/s]
Epoch 4,

KeyboardInterrupt: 

In [23]:
# Informally test the model
model.eval()
x, y = train_dataset.collate_fn([("a type of gun", ''),
                                 ("native of cold country", ""),
                                 ("someone who owns land", "")], False)
# there seem to be a lot of gun-related entries in the dictionary...
x['input_ids'] = x['input_ids'].to(device)
x['attention_mask'] = x['attention_mask'].to(device)
y['input_ids'] = y['input_ids'].to(device)

out, _ = model(x, y)

# Get most likely words
best = out.argmax(dim=-1)

for k in range(len(x)):
    print(tokenizer.decode(best[k]))

gun gun gun gun gun gun
rustic rustic rustic rustic rustic rustic
land land land land land land


In [73]:
def test(loader, name):
    inc = 3
    model.eval()
    test_loss = 0.0
    test_acc1 = test_acc10 = test_acc100 = test_rank_median = test_rank_variance = 0.0
    with torch.no_grad():
        with tqdm(total=len(loader)) as pbar:
            for i, (x, (y,y_inds)) in enumerate(loader):
                if i % inc == 0 and i != 0:
                    display_loss = test_loss / i
                    pbar.set_description(f'Test Loss: {display_loss}')

                x, attention_mask = x.input_ids.to(device), x.attention_mask.to(device)
                y = y.to(device)

                with autocast():
                    out = model(x, attention_mask)

                    loss = 1 - criterion(out, y).sum() / len(x)

                    probs = out.mm(embeds)

                test_loss += loss.detach()

                pbar.update(1)

                result, indices = torch.sort(probs, descending=True)
                
                b = len(x)
                acc1, acc10, acc100, rank_median, rank_variance = evaluate(indices, y_inds, test=True)
                test_acc1 += (acc1 * b / batch_size)
                test_acc10 += (acc10 * b / batch_size)
                test_acc100 += (acc100 * b / batch_size)
                test_rank_median += (rank_median * b / batch_size)
                test_rank_variance += (rank_variance * b / batch_size)

                del x, y, out, loss
                if i % 20 == 0:
                    torch.cuda.empty_cache()
    
    test_length = len(loader)
    
    print(f'{name}_test_loss:', test_loss / test_length)
    print(f'{name}_test_acc1:', test_acc1 / test_length)
    print(f'{name}_test_acc10:', test_acc10 / test_length)
    print(f'{name}_test_acc100:', test_acc100 / test_length)
    print(f'{name}_test_rank_median:', test_rank_median / test_length)
    print(f'{name}_test_rank_variance:', test_rank_variance / test_length)
    
    wandb.log({
        f'{name}_test_loss': test_loss / test_length,
        f'{name}_test_acc1': test_acc1 / test_length,
        f'{name}_test_acc10': test_acc10 / test_length,
        f'{name}_test_acc100': test_acc100 / test_length,
        f'{name}_test_rank_median': test_rank_median / test_length,
        f'{name}_test_rank_variance': test_rank_variance / test_length
    })

In [74]:
test(test_loader_unseen, 'unseen')

Test Loss: 0.48293179273605347: 100%|██████████| 16/16 [00:00<00:00, 17.67it/s]

unseen_test_loss: tensor(0.4835, device='cuda:0')
unseen_test_acc1: 0.03125
unseen_test_acc10: 0.181640625
unseen_test_acc100: 0.39453125
unseen_test_rank_median: tensor(1001.)
unseen_test_rank_variance: tensor(7662.0938)





In [75]:
test(test_loader_seen, 'seen')

Test Loss: 0.4254843294620514: 100%|██████████| 16/16 [00:00<00:00, 17.70it/s] 

seen_test_loss: tensor(0.4242, device='cuda:0')
seen_test_acc1: 0.0859375
seen_test_acc10: 0.31640625
seen_test_acc100: 0.533203125
seen_test_rank_median: tensor(976.5625)
seen_test_rank_variance: tensor(5600.6313)





In [77]:
test(test_loader_desc, 'description')

Test Loss: 0.3972916305065155: 100%|██████████| 7/7 [00:00<00:00, 12.21it/s]

description_test_loss: tensor(0.4196, device='cuda:0')
description_test_acc1: 0.0625
description_test_acc10: 0.2544642857142857
description_test_acc100: 0.5178571428571429
description_test_rank_median: tensor(892.8571)
description_test_rank_variance: tensor(5165.1597)





In [21]:
gc.collect()

20