### Homework 5: Question search engine

Remember week01, where you used GloVe embeddings to find related questions? That was... cute, but far from state of the art. It's time to really solve this task using context-aware embeddings.

__Warning:__ this task assumes you have seen `seminar.ipynb`!

In [None]:
# %pip install --upgrade transformers datasets accelerate deepspeed

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import datasets
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score
device = 'cuda'

In [2]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

### Load data and model

In [3]:
qqp = datasets.load_dataset('SetFit/qqp')
print('\n')
print("Sample[0]:", qqp['train'][0])
print("Sample[3]:", qqp['train'][3])

Repo card metadata block was not found. Setting CardData to empty.




Sample[0]: {'text1': 'How is the life of a math student? Could you describe your own experiences?', 'text2': 'Which level of prepration is enough for the exam jlpt5?', 'label': 0, 'idx': 0, 'label_text': 'not duplicate'}
Sample[3]: {'text1': 'What can one do after MBBS?', 'text2': 'What do i do after my MBBS ?', 'label': 1, 'idx': 3, 'label_text': 'duplicate'}


In [70]:
model_name = "gchhablani/bert-base-cased-finetuned-qqp"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)

### Tokenize the data

In [71]:
MAX_LENGTH = 128
def preprocess_function(examples):
    result = tokenizer(
        examples['text1'], examples['text2'],
        padding='max_length', max_length=MAX_LENGTH, truncation=True
    )
    result['label'] = examples['label']
    return result

qqp_preprocessed = qqp.map(preprocess_function, batched=True)

In [72]:
print(repr(qqp_preprocessed['train'][0]['input_ids'])[:100], "...")

[101, 1731, 1110, 1103, 1297, 1104, 170, 12523, 2377, 136, 7426, 1128, 5594, 1240, 1319, 5758, 136,  ...


### Task 1: evaluation (1 point)

We picked a model trained on QQP more or less at random, but is it any good?

One way to find out is to measure its validation accuracy, which is what you will implement next.

Here's the interface to help you do that:

In [73]:
val_set = qqp_preprocessed['validation']
val_loader = torch.utils.data.DataLoader(
    val_set, batch_size=1, shuffle=False, collate_fn=transformers.default_data_collator
)

In [76]:
for batch in val_loader:
    break  # take the first batch just to inspect its format
print("Sample batch:", batch)

with torch.no_grad():
    predicted = model(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        token_type_ids=batch['token_type_ids']
    )

print('\nPrediction (probs):', torch.softmax(predicted.logits, dim=1).data.numpy())

Sample batch: {'labels': tensor([0]), 'idx': tensor([0]), 'input_ids': tensor([[  101,  2009,  1132,  2170,   118,  4038,  1177,  2712,   136,   102,
          2009,  1132,  1117, 10224,  4724,  1177,  2712,   136,   102,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,   
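The cell above turns logits into probabilities with `torch.softmax`. For accuracy only the predicted class matters, and softmax preserves the order of values within each row, so `argmax` can be applied to the logits directly (as the evaluation code below does). A small NumPy sketch with made-up logits, including the usual max-subtraction that keeps `exp` from overflowing:

```python
import numpy as np

def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax: subtract the row max before exponentiating."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

# Hypothetical logits for 3 question pairs (class 0 = not duplicate, class 1 = duplicate)
logits = np.array([[2.0, -1.0],
                   [-0.5, 3.0],
                   [1000.0, 999.0]])  # a naive np.exp(1000.0) would overflow to inf

probs = softmax(logits, axis=1)
# Softmax is monotonic within a row, so the predicted class is identical either way
assert (probs.argmax(axis=1) == logits.argmax(axis=1)).all()
print(probs.round(3))
```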

__Your task__ is to measure the validation accuracy of your model.
Doing so naively may take several hours. Please make sure you use the following optimizations:

- run the model on GPU under `torch.no_grad()`
- use a batch size larger than 1
- speed up the data loader with `num_workers > 1`
- (optional) use [mixed precision](https://pytorch.org/docs/stable/notes/amp_examples.html)
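Why batching matters shows up directly in the progress bars: a `DataLoader` without `drop_last` yields ceil(N / batch_size) batches, and tqdm's `it/s` counts batches, not samples. A small arithmetic sketch using the size of the QQP validation split (40,430 pairs); the helper names are hypothetical:

```python
import math

def num_batches(num_samples: int, batch_size: int) -> int:
    """Batches yielded by a DataLoader with drop_last=False (the default)."""
    return math.ceil(num_samples / batch_size)

def samples_per_second(batches_per_second: float, batch_size: int) -> float:
    """Convert tqdm's it/s (batches per second) to samples per second.

    Ignores the smaller final batch, which is negligible for thousands of batches.
    """
    return batches_per_second * batch_size

VAL_SIZE = 40430  # QQP validation pairs

print(num_batches(VAL_SIZE, 1))   # 40430 forward passes with batch_size=1
print(num_batches(VAL_SIZE, 16))  # 2527
print(samples_per_second(18.80, 16))  # about 300 samples/s at 18.80 it/s
```

With batch size 1 the model would need 16 times as many forward passes as with batch size 16, each of them far from saturating the GPU.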


In [77]:
model = model.to(device)

In [78]:
val_loader = torch.utils.data.DataLoader(
    val_set, batch_size=16, shuffle=False, collate_fn=transformers.default_data_collator, num_workers=6
)
preds = []
gts = []
model.eval()  # disable dropout (from_pretrained already returns the model in eval mode)
with torch.no_grad():  # no autograd graph: less memory and faster inference
    for batch in tqdm(val_loader):
        batch = {k: v.to(device) for k, v in batch.items()}
        predicted = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            token_type_ids=batch['token_type_ids']
        )
        preds.extend(torch.argmax(predicted.logits, dim=-1).cpu().numpy())
        gts.extend(batch['labels'].cpu().numpy())
accuracy = accuracy_score(gts, preds)
print(f'Accuracy is {accuracy}')

100%|██████████| 2527/2527 [02:14<00:00, 18.80it/s]

Accuracy is 0.9083848627256987





In [79]:
assert 0.9 < accuracy < 0.91

### Task 2: train the model (4 points)

For this task, you have two options:

__Option A:__ fine-tune your own model. You are free to choose any model __except for the original BERT.__ We recommend [DeBERTa-v3](https://huggingface.co/microsoft/deberta-v3-base). Better yet, choose the best model based on public benchmarks (e.g. [GLUE](https://gluebenchmark.com/)).

You can write the training code manually or use transformers.Trainer (see [this example](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification)). Please make sure that your model's accuracy is at least __comparable__ with the above example for BERT.


__Option B:__ compare at least 3 pre-finetuned models (in addition to the above BERT model). For each model, report (1) its accuracy, (2) its speed, measured in samples per second on your hardware, and (3) its size in megabytes. Please take care to compare models in equal settings, e.g. on the same CPU / GPU. Compile your results into a table and write a short report (about half a page in addition to the table) summarizing your findings.
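For Option B, size and speed are easy to measure inconsistently. A framework-agnostic sketch, assuming you can wrap each model in a `predict_batch(batch)` callable (the helper names are hypothetical): size is parameter count times bytes per parameter, and throughput is timed with `time.perf_counter` after a warm-up batch. In PyTorch the exact byte count of the weights is `sum(p.numel() * p.element_size() for p in model.parameters())`; for a CUDA model, call `torch.cuda.synchronize()` inside `predict_batch` (or move the outputs to CPU) so asynchronous kernels finish before the timer stops.

```python
import time
from typing import Callable, Iterable, Sequence

def model_size_megabytes(num_params: int, bytes_per_param: int = 4) -> float:
    """Approximate weight size in MiB (2**20 bytes); float32 = 4 bytes, float16 = 2."""
    return num_params * bytes_per_param / 2**20

def measure_throughput(predict_batch: Callable[[Sequence], object],
                       batches: Iterable[Sequence],
                       warmup_batches: int = 1) -> float:
    """Samples per second of `predict_batch` over `batches`.

    The first `warmup_batches` batches run untimed, so one-off costs
    (e.g. CUDA initialization) do not skew the result.
    """
    batches = list(batches)
    timed = batches[warmup_batches:]
    if not timed:
        raise ValueError("need more batches than warmup_batches")
    for batch in batches[:warmup_batches]:
        predict_batch(batch)
    num_samples = sum(len(batch) for batch in timed)
    start = time.perf_counter()
    for batch in timed:
        predict_batch(batch)
    elapsed = time.perf_counter() - start
    return num_samples / elapsed

# Usage with a dummy predictor taking ~1 ms per batch of 4 samples
rate = measure_throughput(lambda batch: time.sleep(0.001), [[0] * 4] * 5)
print(f"{rate:.0f} samples/s")
```

Keeping the batch size, sequence length and precision fixed across models matters as much as the timing code itself.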

In [4]:
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

In [5]:
model_name = "microsoft/deberta-v3-base"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name, output_hidden_states=True)

Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['pooler.dense.bias', 'classifier.weight', 'classifier.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
MAX_LENGTH = 128
def preprocess_function(examples):
    result = tokenizer(
        examples['text1'], examples['text2'],
        padding='max_length', max_length=MAX_LENGTH, truncation=True
    )
    result['label'] = examples['label']
    return result

qqp_preprocessed = qqp.map(preprocess_function, batched=True)

In [7]:
val_set = qqp_preprocessed['validation']
train_set = qqp_preprocessed['train']

In [8]:
model = model.to(device)

In [9]:
val_loader = torch.utils.data.DataLoader(
    val_set, batch_size=8, shuffle=False, collate_fn=transformers.default_data_collator, num_workers=6
)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=8, shuffle=True, collate_fn=transformers.default_data_collator, num_workers=6
)

In [10]:
def validation_one_epoch(val_loader, model):
    preds = []
    gts = []
    model.eval()
    for i, batch in enumerate(tqdm(val_loader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            predicted = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                token_type_ids=batch['token_type_ids']
            )
            preds.extend(torch.argmax(predicted.logits, dim=-1).data.cpu().numpy())
            gts.extend(batch['labels'].data.cpu().numpy())
    accuracy = accuracy_score(gts, preds)
    return accuracy

def train_one_epoch(train_loader, model, criterion, opt):
    epoch_losses = []
    model.train(True)
    for i, batch in enumerate(tqdm(train_loader)):
        opt.zero_grad()
        batch = {k: v.to(device) for k, v in batch.items()}

        predicted = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            token_type_ids=batch['token_type_ids']
        )
        loss = criterion(predicted.logits, batch['labels'])
        loss.backward()
        opt.step()
        if i % 1000 == 0:
            print(f'{loss.item()}')
        epoch_losses.append(loss.item())
    return epoch_losses


In [11]:
starting_accuracy = validation_one_epoch(val_loader, model)
print(f'Starting accuracy is {starting_accuracy}')


100%|██████████| 5054/5054 [03:15<00:00, 25.84it/s]

Starting accuracy is 0.60563937670047





In [14]:
opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()
num_epoch = 2

In [13]:
train_losses = []
val_accuracies = []
for epoch in range(num_epoch):
    print(f'Epoch {epoch} started')
    epoch_losses = train_one_epoch(train_loader, model, criterion, opt)
    train_losses.extend(epoch_losses)

    epoch_accuracy = validation_one_epoch(val_loader, model)
    val_accuracies.append(epoch_accuracy)
    print(f'Epoch train loss: {np.mean(epoch_losses)}, Epoch accuracy: {epoch_accuracy}')

Epoch 0 started


  0%|          | 2/45481 [00:00<2:52:14,  4.40it/s]

0.6778844594955444


  2%|▏         | 1002/45481 [02:34<1:54:17,  6.49it/s]

0.250326931476593


  4%|▍         | 2002/45481 [05:08<1:51:38,  6.49it/s]

0.5246270298957825


  7%|▋         | 3002/45481 [07:42<1:49:08,  6.49it/s]

0.3130256235599518


  9%|▉         | 4002/45481 [10:17<1:46:40,  6.48it/s]

0.26272690296173096


 11%|█         | 5002/45481 [12:51<1:44:08,  6.48it/s]

0.33838358521461487


 13%|█▎        | 6002/45481 [15:25<1:41:41,  6.47it/s]

0.17657075822353363


 15%|█▌        | 7002/45481 [17:59<1:39:06,  6.47it/s]

0.42277398705482483


 18%|█▊        | 8002/45481 [20:34<1:36:19,  6.49it/s]

0.5386254787445068


 20%|█▉        | 9002/45481 [23:08<1:33:48,  6.48it/s]

0.10299867391586304


 22%|██▏       | 10002/45481 [25:42<1:31:17,  6.48it/s]

0.08231203258037567


 24%|██▍       | 11002/45481 [28:17<1:28:40,  6.48it/s]

0.17699724435806274


 26%|██▋       | 12002/45481 [30:51<1:26:12,  6.47it/s]

0.48658809065818787


 29%|██▊       | 13002/45481 [33:25<1:23:28,  6.48it/s]

0.40795469284057617


 31%|███       | 14002/45481 [35:59<1:20:55,  6.48it/s]

0.13514763116836548


 33%|███▎      | 15002/45481 [38:34<1:18:23,  6.48it/s]

0.26768332719802856


 35%|███▌      | 16002/45481 [41:08<1:15:46,  6.48it/s]

0.18784402310848236


 37%|███▋      | 17002/45481 [43:42<1:13:12,  6.48it/s]

0.4410229027271271


 40%|███▉      | 18002/45481 [46:17<1:10:44,  6.47it/s]

0.4075922966003418


 42%|████▏     | 19002/45481 [48:51<1:08:08,  6.48it/s]

0.25600239634513855


 44%|████▍     | 20002/45481 [51:25<1:05:30,  6.48it/s]

0.2456217110157013


 46%|████▌     | 21002/45481 [54:00<1:02:53,  6.49it/s]

0.258957177400589


 48%|████▊     | 22002/45481 [56:34<1:00:23,  6.48it/s]

0.2600681483745575


 51%|█████     | 23002/45481 [59:08<57:46,  6.49it/s]  

0.5859765410423279


 53%|█████▎    | 24002/45481 [1:01:43<55:11,  6.49it/s]

0.31740301847457886


 55%|█████▍    | 25002/45481 [1:04:17<52:39,  6.48it/s]

0.3922359347343445


 57%|█████▋    | 26002/45481 [1:06:51<50:06,  6.48it/s]

0.05739535391330719


 59%|█████▉    | 27002/45481 [1:09:25<47:30,  6.48it/s]

0.2150469720363617


 62%|██████▏   | 28002/45481 [1:12:00<44:58,  6.48it/s]

0.26889294385910034


 64%|██████▍   | 29002/45481 [1:14:34<42:24,  6.48it/s]

0.0619502067565918


 66%|██████▌   | 30002/45481 [1:17:08<39:51,  6.47it/s]

0.13554883003234863


 68%|██████▊   | 31002/45481 [1:19:43<37:12,  6.49it/s]

0.11509271711111069


 70%|███████   | 32002/45481 [1:22:17<34:40,  6.48it/s]

0.15009307861328125


 73%|███████▎  | 33002/45481 [1:24:51<32:03,  6.49it/s]

0.17915162444114685


 75%|███████▍  | 34002/45481 [1:27:25<29:33,  6.47it/s]

0.0785939171910286


 77%|███████▋  | 35002/45481 [1:30:00<26:59,  6.47it/s]

0.03335019201040268


 79%|███████▉  | 36002/45481 [1:32:35<24:23,  6.47it/s]

0.4088817238807678


 81%|████████▏ | 37002/45481 [1:35:09<21:51,  6.47it/s]

0.07834145426750183


 84%|████████▎ | 38002/45481 [1:37:44<19:16,  6.47it/s]

0.3349553942680359


 86%|████████▌ | 39002/45481 [1:40:18<16:41,  6.47it/s]

0.05331047251820564


 88%|████████▊ | 40002/45481 [1:42:53<14:06,  6.47it/s]

0.06731503456830978


 90%|█████████ | 41002/45481 [1:45:27<11:32,  6.47it/s]

0.34748369455337524


 92%|█████████▏| 42002/45481 [1:48:02<08:57,  6.47it/s]

0.08727667480707169


 95%|█████████▍| 43002/45481 [1:50:37<06:23,  6.47it/s]

0.43100181221961975


 97%|█████████▋| 44002/45481 [1:53:11<03:48,  6.47it/s]

0.23592473566532135


 99%|█████████▉| 45002/45481 [1:55:46<01:14,  6.47it/s]

0.10003998875617981


100%|██████████| 45481/45481 [1:57:00<00:00,  6.48it/s]
100%|██████████| 5054/5054 [03:18<00:00, 25.48it/s]


Epoch train loss: 0.25100776760864, Epoch accuracy: 0.9081869898590156
Epoch 1 started


  0%|          | 2/45481 [00:00<2:51:20,  4.42it/s]

0.04739739000797272


  2%|▏         | 1002/45481 [02:35<1:54:40,  6.46it/s]

0.2804887294769287


  4%|▍         | 2002/45481 [05:09<1:51:59,  6.47it/s]

0.3800426721572876


  7%|▋         | 3002/45481 [07:44<1:49:20,  6.47it/s]

0.15224817395210266


  9%|▉         | 4002/45481 [10:18<1:46:55,  6.47it/s]

0.2926713228225708


 11%|█         | 5002/45481 [12:53<1:44:17,  6.47it/s]

0.5758025646209717


 13%|█▎        | 6002/45481 [15:27<1:41:38,  6.47it/s]

0.29146647453308105


 15%|█▌        | 7002/45481 [18:02<1:39:08,  6.47it/s]

0.200648695230484


 18%|█▊        | 8002/45481 [20:36<1:36:32,  6.47it/s]

0.08809169381856918


 20%|█▉        | 9002/45481 [23:11<1:33:59,  6.47it/s]

0.26753944158554077


 22%|██▏       | 10002/45481 [25:46<1:31:24,  6.47it/s]

0.1755686104297638


 24%|██▍       | 11002/45481 [28:20<1:28:45,  6.47it/s]

0.04821682348847389


 26%|██▋       | 12002/45481 [30:55<1:26:17,  6.47it/s]

0.07421061396598816


 29%|██▊       | 13002/45481 [33:29<1:23:40,  6.47it/s]

0.034907083958387375


 31%|███       | 14002/45481 [36:04<1:21:02,  6.47it/s]

0.23901613056659698


 33%|███▎      | 15002/45481 [38:38<1:18:30,  6.47it/s]

0.180892676115036


 35%|███▌      | 16002/45481 [41:13<1:15:59,  6.47it/s]

0.09233801066875458


 37%|███▋      | 17002/45481 [43:47<1:13:20,  6.47it/s]

0.010177480056881905


 40%|███▉      | 18002/45481 [46:22<1:10:43,  6.48it/s]

0.2488073855638504


 42%|████▏     | 19002/45481 [48:56<1:08:13,  6.47it/s]

0.014398403465747833


 44%|████▍     | 20002/45481 [51:31<1:05:35,  6.47it/s]

0.1291915327310562


 46%|████▌     | 21002/45481 [54:05<1:03:01,  6.47it/s]

0.008896240033209324


 48%|████▊     | 22002/45481 [56:40<1:00:23,  6.48it/s]

0.1317199468612671


 51%|█████     | 23002/45481 [59:14<57:52,  6.47it/s]  

0.011544324457645416


 53%|█████▎    | 24002/45481 [1:01:48<55:17,  6.47it/s]

0.2528868317604065


 55%|█████▍    | 25002/45481 [1:04:23<52:43,  6.47it/s]

0.041766587644815445


 57%|█████▋    | 26002/45481 [1:06:57<50:08,  6.47it/s]

0.05726592242717743


 59%|█████▉    | 27002/45481 [1:09:32<47:36,  6.47it/s]

0.3711647391319275


 62%|██████▏   | 28002/45481 [1:12:06<44:58,  6.48it/s]

0.053027600049972534


 64%|██████▍   | 29002/45481 [1:14:41<42:27,  6.47it/s]

0.047402963042259216


 66%|██████▌   | 30002/45481 [1:17:15<39:51,  6.47it/s]

0.010797766037285328


 68%|██████▊   | 31002/45481 [1:19:50<37:19,  6.47it/s]

0.5991066694259644


 70%|███████   | 32002/45481 [1:22:24<34:41,  6.47it/s]

0.35314521193504333


 73%|███████▎  | 33002/45481 [1:24:59<32:08,  6.47it/s]

0.13320046663284302


 75%|███████▍  | 34002/45481 [1:27:33<29:32,  6.48it/s]

0.12324168533086777


 77%|███████▋  | 35002/45481 [1:30:08<26:58,  6.48it/s]

0.2863939702510834


 79%|███████▉  | 36002/45481 [1:32:42<24:25,  6.47it/s]

0.18892861902713776


 81%|████████▏ | 37002/45481 [1:35:17<21:50,  6.47it/s]

0.032635629177093506


 84%|████████▎ | 38002/45481 [1:37:51<19:15,  6.47it/s]

0.07609133422374725


 86%|████████▌ | 39002/45481 [1:40:26<16:41,  6.47it/s]

0.42234665155410767


 88%|████████▊ | 40002/45481 [1:43:00<14:06,  6.47it/s]

0.020260415971279144


 90%|█████████ | 41002/45481 [1:45:35<11:31,  6.47it/s]

0.029123537242412567


 92%|█████████▏| 42002/45481 [1:48:09<08:57,  6.47it/s]

0.592880368232727


 95%|█████████▍| 43002/45481 [1:50:44<06:23,  6.47it/s]

0.022202998399734497


 97%|█████████▋| 44002/45481 [1:53:18<03:48,  6.47it/s]

0.24287566542625427


 99%|█████████▉| 45002/45481 [1:55:53<01:14,  6.47it/s]

0.2909305691719055


100%|██████████| 45481/45481 [1:57:07<00:00,  6.47it/s]
100%|██████████| 5054/5054 [03:18<00:00, 25.47it/s]

Epoch train loss: 0.17162563806994696, Epoch accuracy: 0.9194410091516201
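The per-step losses printed every 1000 steps jump between below 0.01 and about 0.68, so single values say little about progress; the epoch means (0.251, then 0.172) are more informative. To inspect the full `train_losses` curve, a trailing moving average is a common choice. A NumPy sketch (the window size is an arbitrary assumption):

```python
import numpy as np

def moving_average(values, window: int = 1000) -> np.ndarray:
    """Trailing mean over `window` consecutive steps; only full windows are kept."""
    values = np.asarray(values, dtype=float)
    if window > len(values):
        raise ValueError("window is longer than the sequence")
    # cumulative sum with a leading 0, so each window sum is a difference of two entries
    cumsum = np.cumsum(np.insert(values, 0, 0.0))
    return (cumsum[window:] - cumsum[:-window]) / window

# e.g. smoothed = moving_average(train_losses, window=1000), then plot `smoothed`
print(moving_average([1, 2, 3, 4], window=2))
```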





In [14]:
# torch.save(model.state_dict(), 'deberta_2_epochs.pt')

In [15]:
model.load_state_dict(torch.load('deberta_2_epochs.pt'))

<All keys matched successfully>

In [16]:
final_accuracy = validation_one_epoch(val_loader, model)
print(f'Final accuracy is {final_accuracy}')

100%|██████████| 5054/5054 [03:16<00:00, 25.68it/s]

Final accuracy is 0.9194410091516201





### Task 3: try the full pipeline (1 point)

Finally, it is time to use your model to find duplicate questions.
Please implement a function that takes a question and finds the top-5 potential duplicates in the training set. For now, it is fine if your function is slow, as long as it yields correct results.

Showcase how your function works with at least 5 examples.
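Whatever scoring is used, the last step is picking the 5 highest scores. A minimal NumPy sketch, assuming the scores are cosine similarities between precomputed vectors; with the fine-tuned cross-encoder one could instead score every (query, candidate) pair and rank by P(duplicate), and `top_k` works the same on those probabilities. Function names are hypothetical. `np.argpartition` finds the top k in linear time, and `np.argsort` then orders only those k:

```python
import numpy as np

def cosine_scores(query: np.ndarray, matrix: np.ndarray) -> np.ndarray:
    """Cosine similarity between a query vector (d,) and each row of matrix (n, d).

    Assumes no all-zero vectors (their norm would be 0).
    """
    query = query / np.linalg.norm(query)
    matrix = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
    return matrix @ query

def top_k(scores: np.ndarray, k: int = 5) -> np.ndarray:
    """Indices of the k highest scores, sorted from most to least similar."""
    k = min(k, len(scores))
    # argpartition moves the k largest scores (k smallest of -scores) to the front, unordered
    candidates = np.argpartition(-scores, k - 1)[:k]
    # fully sort only those k candidates
    return candidates[np.argsort(-scores[candidates])]

vectors = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]])
scores = cosine_scores(np.array([1.0, 0.1]), vectors)
print(top_k(scores, k=2))
```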

In [50]:
MAX_LENGTH = 128
def preprocess_function_single(examples):
    # tokenize each question on its own (no second segment)
    return tokenizer(
        examples['text1'],
        padding='max_length', max_length=MAX_LENGTH, truncation=True
    )

qqp_preprocessed_single = qqp.map(preprocess_function_single, batched=True, remove_columns=['text2'])

Map:   0%|          | 0/363846 [00:00<?, ? examples/s]

Map:   0%|          | 0/40430 [00:00<?, ? examples/s]

Map:   0%|          | 0/390965 [00:00<?, ? examples/s]

In [51]:
train_set_single = qqp_preprocessed_single['train']

In [56]:
non_shuffled_train_loader = torch.utils.data.DataLoader(
    train_set_single, batch_size=8, shuffle=False, collate_fn=transformers.default_data_collator, num_workers=6
)

In [57]:
def get_phrase_embedding(question, model, tokenizer):
    # encode a single question exactly like the indexed training questions
    batch = tokenizer(question, padding='max_length', max_length=MAX_LENGTH, truncation=True)
    batch = {k: torch.as_tensor([v]).to(device) for k, v in batch.items()}
    with torch.no_grad():
        predicted = model(**batch)
    # last hidden layer without [CLS], summed over all remaining positions (padding included)
    embs = predicted['hidden_states'][-1][:, 1:, :]
    vector = embs.cpu().numpy().sum(axis=1)[0]
    return vector
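`get_phrase_embedding` sums the last hidden state over every position after `[CLS]`, and with `padding='max_length'` most of those 127 positions are padding for a typical short question, which can wash out the differences between questions. Pooling only over positions where `attention_mask == 1` avoids this. A NumPy sketch of masked mean pooling (the function name is hypothetical); it removes the padding effect but does not by itself turn a cross-encoder into a sentence-embedding model:

```python
import numpy as np

def masked_mean_pool(hidden_states: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors over real tokens only.

    hidden_states: (batch, seq_len, dim) last-layer outputs
    attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding
    """
    mask = attention_mask[..., None].astype(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(axis=1)                   # padding contributes 0
    counts = np.clip(mask.sum(axis=1), 1.0, None)                 # avoid division by zero
    return summed / counts

h = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # last position is padding
m = np.array([[1, 1, 0]])
print(masked_mean_pool(h, m))
```

In PyTorch the same expression is `(h * m.unsqueeze(-1)).sum(1) / m.sum(1, keepdim=True).clamp(min=1)`. For cosine similarity the division only rescales each vector, so the key change is the masking.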

In [58]:
def find_nearest_nearpy(question, model, tokenizer):
    """
    Given a question, return its nearest training questions from the global nearpy `engine`,
    sorted from most to least similar. With the Engine defaults, similarity is cosine distance
    and at most 10 candidates from the query's LSH bucket are returned.
    Row indices are looked up in the global `train_set_single`.
    """
    query_emb = get_phrase_embedding(question, model, tokenizer)
    # engine.neighbours returns (vector, data, distance) tuples; data is 'data_<row index>'
    top_k_indexes = np.array([int(x[1].split('_')[1]) for x in engine.neighbours(query_emb)])
    top_k_neighbours = train_set_single[top_k_indexes]
    return {k: v for k, v in top_k_neighbours.items() if k in ['text1', 'text2']}

In [59]:

from nearpy import Engine
from nearpy.hashes import RandomBinaryProjections

dimension = 768

rbp = RandomBinaryProjections('rbp', 10)

engine = Engine(dimension, lshashes=[rbp])

with torch.no_grad():
    model.eval()
    for index, batch in enumerate(tqdm(non_shuffled_train_loader)):
        batch = {k: v.to(device) for k, v in batch.items() if k!='idx'}
        predicted = model(**batch)
        embs = predicted['hidden_states'][-1][:, 1:, :]
        n = embs.shape[0]

        all_embs = embs.cpu().numpy().sum(axis=1)
        for i in range(n):
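            # global row index = batch number * loader batch size + position in the batch
            # (len(batch) would count the keys of the batch dict, not the examples)
            engine.store_vector(all_embs[i], 'data_%d' % (index * non_shuffled_train_loader.batch_size + i))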

100%|██████████| 45481/45481 [30:15<00:00, 25.05it/s]
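The `RandomBinaryProjections('rbp', 10)` hash gives each vector a 10-bit key recording which side of 10 random hyperplanes it lies on, and a query is compared only with vectors that share its key. A simplified NumPy sketch of that idea (the class `TinyLSHIndex` is hypothetical and is not nearpy's code). With 10 bits there are 1024 buckets, so the ~364k training questions average about 355 per bucket; a true neighbour that falls on the other side of even one hyperplane is never seen, which is the price of the speed-up.

```python
from collections import defaultdict

import numpy as np

class TinyLSHIndex:
    """Toy random-hyperplane LSH index (illustration only)."""

    def __init__(self, dim: int, num_bits: int = 10, seed: int = 0):
        rng = np.random.default_rng(seed)
        # one random Gaussian normal vector (hyperplane through the origin) per bit
        self.planes = rng.standard_normal((num_bits, dim))
        self.buckets = defaultdict(list)

    def _key(self, vector: np.ndarray) -> str:
        # bit j is '1' if the vector lies on the positive side of hyperplane j
        return "".join("1" if x > 0 else "0" for x in self.planes @ vector)

    def store(self, vector: np.ndarray, item_id: str) -> None:
        self.buckets[self._key(vector)].append(item_id)

    def candidates(self, query: np.ndarray) -> list:
        # only items sharing the query's key; .get avoids creating empty buckets
        return list(self.buckets.get(self._key(query), []))

index = TinyLSHIndex(dim=4, num_bits=3, seed=0)
v = np.array([1.0, 2.0, 3.0, 4.0])
index.store(v, "a")
index.store(2.0 * v, "b")  # same direction, so same key
print(index.candidates(v))
```

Scaling a vector by a positive constant never changes its key, so, like cosine distance, this hash only looks at direction.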


In [64]:
q_r = find_nearest_nearpy('What is it like to be in prison?', model, tokenizer)
q_r

{'text1': ['What is the best way to take a picture with a phone?',
  'Do u believe in aliens?',
  'What is the best gay Asian dating sites or apps?',
  'Which is the best campus ambassador program running?',
  'What will be the effect of banning 500 and 1000 notes on stock markets in India?',
  'Can we download movies from torrents?',
  'What heart rate should I target for in a 5 K?',
  'Can a poor man fall in love with a girl from a rich family?',
  'How is Kafka different from typical JMS message brokers like IBM MQ, Active MQ, etc.?',
  'What are the best and worst things about public transit in Enping, Guangdong, China? How could it be improved?']}

In [66]:
q_r = find_nearest_nearpy('What is the mass of the Earth?', model, tokenizer)
q_r

{'text1': ['I am an NRI and recently purchased a flat in India and other assets … tax form required?',
  'How long does 5 mg of Klonopin stay in your system?',
  'How do I perform meditation techniques such as the Sudarshan Kriya, Vipassana, Isha Kriya, and Transcendental meditation?',
  'How do I start from the scraps in an online business?',
  'What are your three wishes?',
  'How do I get into University of Pennsylvania Jerome Fisher M&T Program?',
  "Why don't people learn from their mistakes?",
  "What's new about iOS 10?",
  'What are the best qualities of an engineer?',
  'Is the world getting worse?']}

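It looks like I made a mistake somewhere, because the neighbours above are not related to the queries at all.
What I did: took the last hidden state for all tokens except `[CLS]`, summed them, and used the nearpy ANN engine to search over the resulting vectors. Likely problems with this approach:

- the neighbours above were produced with item IDs computed as `index*len(batch) + i`, but `len(batch)` is the number of keys in the batch dict, not the batch size of 8, so almost every stored ID points to the wrong training question;
- with `padding='max_length'` the sum runs over all 127 positions after `[CLS]`, most of which are padding for a short question;
- the model was fine-tuned as a cross-encoder (both questions in one input), so its hidden states for a lone question were not trained to form a similarity space; the classifier's duplicate probability for a (query, candidate) pair is what it was actually trained to produce.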

__Bonus:__ for bonus points, try to find a way to run the function faster than just passing over all questions in a loop. For instance, you can form a short list of potential candidates using a cheaper method, and then run your transformer on that short list. If you opted for this solution, please keep both the original implementation and the optimized one, and briefly explain the difference.
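The shortlist idea, sketched in plain Python: a cheap scorer (for example word overlap, the GloVe vectors from week01, or an ANN index) narrows all candidates to a shortlist, and the expensive scorer, for example the fine-tuned model's duplicate probability for a (query, candidate) pair, runs only on that shortlist. All names here are hypothetical. The trade-off is that a true duplicate the cheap stage misses can never be recovered.

```python
import heapq
from typing import Callable, List, Sequence

def shortlist_then_rerank(query: str,
                          candidates: Sequence[str],
                          cheap_score: Callable[[str, str], float],
                          expensive_score: Callable[[str, str], float],
                          shortlist_size: int = 100,
                          k: int = 5) -> List[str]:
    """Two-stage retrieval: cheap scoring of everything, expensive re-ranking of the best few."""
    # Stage 1: cheap scorer over all candidates, keep the best `shortlist_size`
    shortlist = heapq.nlargest(shortlist_size, candidates, key=lambda c: cheap_score(query, c))
    # Stage 2: expensive scorer only on the shortlist, return the top k (best first)
    return heapq.nlargest(k, shortlist, key=lambda c: expensive_score(query, c))

# Toy scores: "c" would win the expensive stage but is dropped by the cheap one
cheap = {"a": 0.9, "b": 0.8, "c": 0.1, "d": 0.7}
expensive = {"a": 0.2, "b": 0.95, "c": 1.0, "d": 0.5}
print(shortlist_then_rerank("q", ["a", "b", "c", "d"],
                            lambda q, c: cheap[c], lambda q, c: expensive[c],
                            shortlist_size=3, k=2))
```

The expensive scorer is called `shortlist_size` times instead of once per training question, which is where the speed-up comes from.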