In [1]:
import math
import json

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import transformers
from tqdm import tqdm

from src import data

In [2]:
data_type = 'text'

config = data.BaseConfig(
    data_type=data_type,
    batch_size=256,
    max_candidates=20,  # candidates per question in addition to the single true answer
)

In [3]:
# Wrap the config in the project's DatasetPreparer; its load_data() builds the loaders next.
dataset_preparer = data.DatasetPreparer(config=config)

In [4]:
dataloaders, dataloaders_with_candidates = dataset_preparer.load_data(as_data_loader=True)

# Both collections are indexed the same way: 0 -> training split, 1 -> validation split.
train_loader, valid_loader = dataloaders[0], dataloaders[1]
train_loader_with_candidates, valid_loader_with_candidates = (
    dataloaders_with_candidates[0],
    dataloaders_with_candidates[1],
)

In [5]:
def score_candidates(question_embeddings, candidates_embeddings):
    """Score each question against its own candidate answers by dot product.

    Args:
        question_embeddings: 2-D tensor of shape ``(batch, dim)``, one row per
            question.
        candidates_embeddings: 2-D tensor of shape ``(batch * k, dim)`` holding
            the ``k`` candidates of every question back to back, i.e. rows
            ``[i * k, (i + 1) * k)`` belong to question ``i``. This grouping is
            what the ``view`` below assumes -- confirm the data loader emits
            candidates in that order.

    Returns:
        Detached CPU tensor of shape ``(batch, k)``; entry ``[i, j]`` is the raw
        (unnormalised) dot product of question ``i`` with its ``j``-th
        candidate.

    Raises:
        ValueError: if ``question_embeddings`` is not 2-D or is empty, if the
            two inputs have different embedding widths, or if the number of
            candidate rows is not a multiple of the number of questions.
    """
    if question_embeddings.dim() != 2:
        raise ValueError(
            "question_embeddings must be 2-D (batch, dim), got shape "
            f"{tuple(question_embeddings.shape)}")

    num_questions, question_dim = question_embeddings.size()
    candidates_batch_size, model_dim = candidates_embeddings.size()

    # Fail early with a readable message instead of a ZeroDivisionError or an
    # opaque view()/bmm() shape error further down.
    if num_questions == 0:
        raise ValueError("question_embeddings must contain at least one question")
    if question_dim != model_dim:
        raise ValueError(
            f"embedding widths differ: questions have dim {question_dim}, "
            f"candidates have dim {model_dim}")
    if candidates_batch_size % num_questions != 0:
        raise ValueError(
            f"{candidates_batch_size} candidate rows cannot be split evenly "
            f"across {num_questions} questions")

    candidates_per_sample = candidates_batch_size // num_questions

    # (batch * k, dim) -> (batch, k, dim): group the flat candidate rows per question.
    candidates_embeddings = candidates_embeddings.view(num_questions,
                                                       candidates_per_sample,
                                                       model_dim)

    # (batch, 1, dim) @ (batch, dim, k) -> (batch, 1, k) -> (batch, k)
    similarity_matrix = torch.bmm(question_embeddings.unsqueeze(1),
                                  candidates_embeddings.transpose(1, 2)).squeeze(dim=1)

    # Detach from any autograd graph and move the scores to CPU.
    return similarity_matrix.detach().cpu()

In [6]:
# Take a single validation batch for inspection. next(iter(...)) avoids the
# half-finished tqdm progress bar that a `for ...: break` loop leaves behind.
question, candidates = next(iter(valid_loader_with_candidates))

  0%|          | 0/976 [00:00<?, ?it/s]


In [7]:
# Inspect the question strings of the sampled batch.
question

['how long have you been a doctor ?',
 'i am actually a cancer survivor which is why i decided to become a life coach',
 'that sounds much more fun than taking care of old people like me',
 "i'll need the carbs to have energy for basketball !",
 "that's very mice . this is the first time i spoke in 3 months",
 'hello , how are you doing today ?',
 'i am still an amateur so i hardly win haha',
 'yep , i was visiting my sis in spain last month , it is nice']

In [8]:
# Candidate layout check: every question contributes its single true answer
# plus max_candidates (= 20) other candidates, so
#   len(candidates) == len(question) * (1 + max_candidates)
# e.g. 8 questions * 21 = 168 for this batch.
len(candidates)

168

In [9]:
# Recall accumulator from the project `data` module; the evaluation loop resets
# it and feeds it one similarity matrix per batch.
recall = data.Recall()

In [10]:
# Evaluate retrieval on the validation split: embed questions and candidates,
# score them, and accumulate recall. recall.reset() runs before the loop,
# presumably clearing earlier results so the cell can be re-run.
recall.reset()

for question, candidates in tqdm(valid_loader_with_candidates):
    
    # TODO (exercise): embed the list of question strings into a 2-D
    # (batch, dim) array/tensor -- score_candidates treats it that way.
    question_embeddings = <YOUR CODE HERE>
    # TODO (exercise): embed the flat candidate list into a
    # (batch * candidates_per_question, dim) array; score_candidates assumes the
    # rows are grouped per question, in the same order as `question`.
    candidates_embeddings = <YOUR CODE HERE>
        
    # NOTE(review): only the scoring runs under no_grad; if the encoder above is
    # a torch model, its forward pass should likely go inside this block too.
    # torch.tensor() always copies and warns when given an existing tensor --
    # prefer torch.as_tensor() if the encoder already returns tensors.
    with torch.no_grad():
        similarity_matrix = score_candidates(torch.tensor(question_embeddings),
                                             torch.tensor(candidates_embeddings))
    
    # Feed this batch's (batch, candidates) scores to the Recall accumulator.
    recall.add(similarity_matrix)

In [11]:
# Display the recall metrics accumulated over the validation loop.
recall.metrics

In [12]:
# Print the Recall accumulator's text messages.
print(recall.messages)