In [1]:
import math
import json

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import transformers
from tqdm import tqdm

from src import data

In [2]:
# Checkpoint name; it is stored on the config and later read back via
# config.model_type when the pretrained backbone is loaded.
model_type = 'distilbert-base-uncased'

config = data.BaseConfig(
    model_type=model_type,
    batch_size=256,
    max_candidates=20,
)

In [3]:
# Create the dataset helper from the shared config. The output recorded for
# this cell shows a "Building index" progress bar, i.e. construction took
# roughly a minute in that run.
dataset_preparer = data.DatasetPreparer(config=config)

Building index: 100%|██████████| 103/103 [01:07<00:00,  1.52it/s]


In [4]:
# load_data returns two indexable collections of loaders: plain ones and ones
# that also carry candidates. Index 0 is used as train, index 1 as validation.
dataloaders, dataloaders_with_candidates = dataset_preparer.load_data(as_data_loader=True)

train_loader, valid_loader = dataloaders[0], dataloaders[1]

train_loader_with_candidates, valid_loader_with_candidates = (
    dataloaders_with_candidates[0],
    dataloaders_with_candidates[1],
)

In [5]:
class GlobalMaskedPooling(nn.Module):
    """Pool a padded sequence of vectors into a single vector per sample.

    Padded positions are overwritten with a neutral value before pooling
    (0 for ``'mean'``, -10000 for ``'max'``) so that they do not contribute
    to the result.

    Args:
        pooling_type: ``'mean'`` or ``'max'``.
        dim: dimension of ``x`` (and of ``pad_mask``) that is reduced.
        length_scaling: if true, the pooled vector is additionally divided by
            a factor derived from the number of non-padded positions.
        square: only used with ``length_scaling``. Despite the name, a true
            value divides by ``sqrt(length)``; a false value divides by
            ``length``.

    Raises:
        ValueError: if ``pooling_type`` is neither ``'mean'`` nor ``'max'``.
    """

    def __init__(self, pooling_type='mean', dim=1, length_scaling=True, square=True):
        super().__init__()

        # Validate first so an unsupported type never yields a half-built module.
        if pooling_type not in ['mean', 'max']:
            raise ValueError('Available types: mean, max')

        self.pooling_type = pooling_type
        self.dim = dim
        self.length_scaling = length_scaling
        self.square = square

        # Padded positions must not win a max, and must add nothing to a mean.
        self.mask_value = -10000. if self.pooling_type == 'max' else 0.

    def forward(self, x, pad_mask):
        """Pool ``x`` over ``self.dim`` while ignoring padded positions.

        Args:
            x: tensor to pool. With the default ``dim=1`` this is expected to
                be ``(batch, seq, features)``.
            pad_mask: mask with one entry per position of ``x`` (``x`` without
                its last dimension); true / non-zero marks a real token.

        Returns:
            ``x`` reduced over ``self.dim``, in the dtype of ``x``. Samples
            with no real token get a scaling of 1, so they come out as the
            pooled mask value (0 for mean, -10000 for max).
        """
        # `~` is only defined for boolean/integer tensors; accept float masks too.
        pad_mask = pad_mask.bool()
        lengths = pad_mask.sum(self.dim).float()

        x = x.masked_fill((~pad_mask).unsqueeze(-1), self.mask_value)

        if self.pooling_type == 'mean':
            # mean() divides by the padded length; rescale to the true length.
            scaling = x.size(self.dim) / lengths
            pooled = x.mean(self.dim)
        else:
            # One factor per sample, created on the same device as the input.
            scaling = torch.ones_like(lengths)
            # max(dim) returns (values, indices); only the values are pooled.
            pooled = x.max(self.dim).values

        if self.length_scaling:
            lengths_factor = lengths ** 0.5 if self.square else lengths
            scaling = scaling / lengths_factor

        # Empty sequences produce inf above; neutralise them.
        scaling = scaling.masked_fill(lengths == 0, 1.).unsqueeze(-1)

        return pooled * scaling.to(pooled.dtype)

    def extra_repr(self) -> str:
        return f'pooling_type="{self.pooling_type}"'

In [6]:
class Encoder(nn.Module):
    """Sentence encoder: backbone model, masked pooling, then normalisation."""

    def __init__(self, bert):
        """Store the backbone and build the pooling layer.

        Args:
            bert: callable backbone; it is invoked as ``bert(token_ids, pad_mask)``
                and the first element of its result is pooled.
        """
        super().__init__()

        self.bert = bert
        self.pooling = GlobalMaskedPooling(length_scaling=False, square=False)

    def forward(self, token_ids, pad_mask):
        """Encode a batch of token id sequences into normalised vectors.

        Args:
            token_ids: token ids passed to the backbone as first argument.
            pad_mask: mask passed to the backbone unchanged and, converted to
                bool, to the pooling layer.

        Returns:
            The pooled vectors after ``F.normalize`` with its default arguments.
        """
        token_states = self.bert(token_ids, pad_mask)[0]
        sentence_vectors = self.pooling(token_states, pad_mask.bool())
        return F.normalize(sentence_vectors)

In [7]:
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [8]:
bert = transformers.AutoModel.from_pretrained(config.model_type)

# Module.to() and Module.eval() both return the module itself, so moving the
# encoder to the device and switching it to evaluation mode can be chained.
encoder = Encoder(bert=bert).to(device).eval()

recall = data.Recall()

In [9]:
def score_candidates(question_embeddings, candidates_embeddings):
    """Score every question against its own group of candidates.

    The candidates are a flat 2-D tensor; consecutive rows are grouped so that
    each question gets ``rows // number_of_questions`` candidates, and each
    score is the dot product of a question with one of its candidates.

    Args:
        question_embeddings: tensor of shape ``(questions, dim)``.
        candidates_embeddings: tensor of shape ``(questions * k, dim)``.

    Returns:
        A detached CPU tensor of shape ``(questions, k)``.
    """
    total_candidates, hidden_size = candidates_embeddings.size()
    num_questions = question_embeddings.size(0)
    per_question = total_candidates // num_questions

    # (questions * k, dim) -> (questions, k, dim)
    grouped = candidates_embeddings.view(num_questions, per_question, hidden_size)

    # (questions, 1, dim) x (questions, dim, k) -> (questions, 1, k) -> (questions, k)
    queries = question_embeddings.unsqueeze(1)
    scores = torch.bmm(queries, grouped.transpose(1, 2)).squeeze(dim=1)

    return scores.detach().cpu()

In [None]:
# Start from a clean metric state so re-running this cell does not double count.
recall.reset()

# Each batch carries (question, response) pairs of token ids, positions and
# token types. The encoder is only called with token ids and a padding mask,
# so positions and token types are deliberately not copied to the device.
for token_ids, positions, token_types in tqdm(valid_loader_with_candidates):

    question_token_ids = token_ids[0].to(device)
    response_token_ids = token_ids[1].to(device)

    # The comparison runs on tensors that already live on `device`, so the
    # resulting masks are on `device` as well; no further transfer is needed.
    question_pad_mask = (question_token_ids != config.pad_index).float()
    response_pad_mask = (response_token_ids != config.pad_index).float()

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        question_embeddings = encoder(question_token_ids, question_pad_mask)
        candidates_embeddings = encoder(response_token_ids, response_pad_mask)

        similarity_matrix = score_candidates(question_embeddings, candidates_embeddings)

    recall.add(similarity_matrix)

  5%|▍         | 45/976 [02:05<43:50,  2.83s/it] 

In [None]:
# Bare last expression so the notebook displays the value.
# NOTE(review): presumably the metrics aggregated from the recall.add(...) calls
# of the evaluation loop — confirm against src.data.Recall.
recall.metrics

In [None]:
# NOTE(review): presumably a human-readable report from the Recall helper —
# confirm the type and content of `messages` in src.data.Recall.
print(recall.messages)