In [8]:
!wget https://dl.fbaipublicfiles.com/glue/data/QQP-clean.zip -P ./data/QQP/ && cd ./data/QQP; unzip -j QQP-clean.zip -d ./  
    

--2021-07-18 10:58:42--  https://dl.fbaipublicfiles.com/glue/data/QQP-clean.zip
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 104.22.74.142, 172.67.9.4, 104.22.75.142, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|104.22.74.142|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 41696084 (40M) [application/zip]
Saving to: ‘./data/QQP/QQP-clean.zip’


2021-07-18 10:58:48 (7.21 MB/s) - ‘./data/QQP/QQP-clean.zip’ saved [41696084/41696084]

Archive:  QQP-clean.zip
  inflating: ./train.tsv             
  inflating: ./dev.tsv               
  inflating: ./test.tsv              


In [6]:
!curl -L0 http://nlp.stanford.edu/data/glove.6B.zip --output ./data/glove.6B.zip && cd ./data/; unzip -j glove.6B.zip -d ./  

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0   308    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0   345    0     0    0     0      0      0 --:--:--  0:00:01 --:--:--     0
100  822M  100  822M    0     0  4484k      0  0:03:07  0:03:07 --:--:-- 4258k3  0:02:03 2853k3:07  0:01:04  0:02:03 2814k 305M    0     0  4406k      0  0:03:11  0:01:11  0:02:00 3870k    0  4683k      0  0:02:59  0:01:51  0:01:08 5272k  4521k      0  0:03:06  0:02:45  0:00:21 5512k62M    0     0  4541k      0  0:03:05  0:02:52  0:00:13 5070k21k      0  0:03:06  0:02:55  0:00:11 4137k
Archive:  glove.6B.zip
  inflating: ./glove.6B.50d.txt      
  inflating: ./glove.6B.100d.txt     
  inflating: ./glove.6B.200d.txt     
  inflating: ./glove.6B.300d.txt     


In [1]:
import sys
sys.path.insert(0,'/Users/a17194839/Library/Python/3.7/lib/python/site-packages')

In [3]:
import re
import string
from collections import Counter, defaultdict
from typing import Dict, List, Tuple, Union, Callable

from itertools import permutations

import nltk
import numpy as np
import math
import pandas as pd
import torch
import torch.nn.functional as F

import pickle

In [4]:
def compute_gain(y_value: float, gain_scheme: str) -> float:
    if gain_scheme == "const":
        return y_value
    elif gain_scheme == "exp2":
        return 2 ** y_value - 1

def dcg_k(ys_true: torch.Tensor, ys_pred: torch.Tensor, k: int) -> float:
    _, indices = torch.sort(ys_pred, descending=True)
    sorted_true = ys_true[indices][:k].numpy()
    gain = compute_gain(sorted_true, gain_scheme="exp2")
    discount = [math.log2(float(x)) for x in range(2, len(sorted_true) + 2)]
    discounted_gain = float((gain / discount).sum())
    return discounted_gain

In [5]:
glue_qqp_dir = './data/QQP'
glove_path = './data/glove.6B.50d.txt'

In [6]:
class GaussianKernel(torch.nn.Module):
    def __init__(self, mu: float = 1., sigma: float = 1., requires_grad = False):
        super().__init__()
        mu_ = np.array(mu)
        sigma_ = np.array(sigma)
        self.requires_grad = requires_grad
        self.mu = torch.nn.Parameter(torch.Tensor(mu_), requires_grad=self.requires_grad)
        self.sigma = torch.nn.Parameter(torch.Tensor(sigma_), requires_grad=self.requires_grad)

    def forward(self, x):
        adj = x - self.mu
        return torch.exp(-0.5 * adj * adj / self.sigma / self.sigma)        
        
        
# class GaussianKernel(torch.nn.Module):
#     def __init__(self, mu: float = 1., sigma: float = 1.):
#         super().__init__()
#         self.mu = mu
#         self.sigma = sigma

#     def forward(self, x):
#         # допишите ваш код здесь 
#         x = (x - self.mu) ** 2 / (2 * self.sigma ** 2)
#         x = torch.exp( - x)
#         return x
        
        
class KNRM(torch.nn.Module):
    def __init__(self, embedding_matrix: np.ndarray, freeze_embeddings: bool = False, kernel_num: int = 21,
                 sigma: float = 0.1, exact_sigma: float = 0.001,
                 out_layers: List[int] = [10, 5]):
        super().__init__()
        self.embeddings = torch.nn.Embedding.from_pretrained(
            torch.FloatTensor(embedding_matrix),
            freeze=freeze_embeddings,
            padding_idx=0
        )

        self.kernel_num = kernel_num
        self.sigma = sigma
        self.exact_sigma = exact_sigma
        self.out_layers = out_layers

        self.kernels = self._get_kernels_layers()

        self.mlp = self._get_mlp()

        self.out_activation = torch.nn.Sigmoid()

    def _get_kernels_layers(self) -> torch.nn.ModuleList:

        mus = [1.0]
        if self.kernel_num > 1:
            bin_size = 2.0 / (self.kernel_num - 1)  
            mus.append(1 - bin_size / 2)
            for i in range(1, self.kernel_num - 1):
                mus.append(mus[i] - bin_size)
        mus = list(reversed(mus))
        sigmas = [self.sigma] * (self.kernel_num - 1) + [self.exact_sigma]  
        
        gausskern_lst = [(GaussianKernel(mu,sigma)) for mu, sigma in zip(mus, sigmas)]
        kernels = torch.nn.ModuleList(gausskern_lst)
        return kernels

    def _get_mlp(self) -> torch.nn.Sequential:        
        if self.out_layers:
            output = []
            hidden_sizes = [self.kernel_num] + self.out_layers + [1]
            for i, hidden in enumerate(hidden_sizes[1:],1):
                output.append(torch.nn.ReLU())
                output.append(torch.nn.Linear(hidden_sizes[i-1], hidden))
        else:
            output = [torch.nn.Linear(self.kernel_num, 1)]
        return torch.nn.Sequential(*output)

    def forward(self, input_1: Dict[str, torch.Tensor], input_2: Dict[str, torch.Tensor]) -> torch.FloatTensor:
        logits_1 = self.predict(input_1)
        logits_2 = self.predict(input_2)

        logits_diff = logits_1 - logits_2

        out = self.out_activation(logits_diff)
        return out

    def _get_matching_matrix(self, query: torch.Tensor, doc: torch.Tensor) -> torch.FloatTensor:
        query = self.embeddings(query)
        doc = self.embeddings(doc)
        query = query / (query.norm(p=2, dim=-1, keepdim=True) + 1e-16)
        doc = doc / (doc.norm(p=2, dim=-1, keepdim=True) + 1e-16)
        return torch.bmm(query, doc.transpose(-1, -2))
    
#         q = query.numpy().tolist()
#         d = doc.numpy().tolist()

#         a_emb, b_emb = torch.Tensor(self.embeddings.weight[q]), torch.Tensor(self.embeddings.weight[d])
#         norm_a = a_emb.norm(p=2, dim=1)
#         norm_b = b_emb.norm(p=2, dim=1)

#         sim = a_emb.matmul(b_emb.t())/norm_a/norm_b
#         matching_matrix = sim.nan_to_num()
#         return matching_matrix

    def _apply_kernels(self, matching_matrix: torch.FloatTensor) -> torch.FloatTensor:
        KM = []
        for kernel in self.kernels:
            # shape = [B]
            K = torch.log1p(kernel(matching_matrix).sum(dim=-1)).sum(dim=-1)
            KM.append(K)
                
        # shape = [B, K]
        kernels_out = torch.stack(KM, dim=1)
        return kernels_out

    def predict(self, inputs: Dict[str, torch.Tensor]) -> torch.FloatTensor:
        # shape = [Batch, Left, Embedding], [Batch, Right, Embedding]
        query, doc = inputs['query'], inputs['document']
        
        # shape = [Batch, Left, Right]
        matching_matrix = self._get_matching_matrix(query, doc)
        # shape = [Batch, Kernels]
        kernels_out = self._apply_kernels(matching_matrix)
        # shape = [Batch]
        out = self.mlp(kernels_out)
        return out


In [7]:
class RankingDataset(torch.utils.data.Dataset):
    def __init__(self, index_pairs_or_triplets: List[List[Union[str, float]]],
                 idx_to_text_mapping: Dict[str, str], vocab: Dict[str, int], oov_val: int,
                 preproc_func: Callable, max_len: int = 30):
        self.index_pairs_or_triplets = index_pairs_or_triplets
        self.idx_to_text_mapping = idx_to_text_mapping
        self.vocab = vocab
        self.oov_val = oov_val
        self.preproc_func = preproc_func
        self.max_len = max_len

    def __len__(self):
        return len(self.index_pairs_or_triplets)

    def _tokenized_text_to_index(self, tokenized_text: List[str]) -> List[int]:
        return [self.vocab.get(item, self.oov_val) for item in tokenized_text[:self.max_len]]

    def _convert_text_idx_to_token_idxs(self, idx: int) -> List[int]:
        tokenized_text = self.preproc_func(self.idx_to_text_mapping[idx])
        return self._tokenized_text_to_index(tokenized_text)

    def __getitem__(self, idx: int):
        pass

class TrainTripletsDataset(RankingDataset):
    def __getitem__(self, idx):
        query, document0, document1, label = self.index_pairs_or_triplets[idx]
        q_tokens = self._convert_text_idx_to_token_idxs(query)
        d_tokens0 = self._convert_text_idx_to_token_idxs(document0)
        d_tokens1 = self._convert_text_idx_to_token_idxs(document1)
        left_elem = {'query': q_tokens, 'document': d_tokens0}
        right_elem = {'query': q_tokens, 'document': d_tokens1}
        
        return left_elem, right_elem, label


class ValPairsDataset(RankingDataset):
    def __getitem__(self, idx):
        query, document, label = self.index_pairs_or_triplets[idx]
        q_tokens = self._convert_text_idx_to_token_idxs(query)
        d_tokens = self._convert_text_idx_to_token_idxs(document)
        qd_dct = {'query': q_tokens, 'document': d_tokens}
        
        return qd_dct, label
       

In [8]:
def collate_fn(batch_objs: List[Union[Dict[str, torch.Tensor], torch.FloatTensor]]):
    max_len_q1 = -1
    max_len_d1 = -1
    max_len_q2 = -1
    max_len_d2 = -1

    is_triplets = False
    for elem in batch_objs:
        if len(elem) == 3:
            left_elem, right_elem, label = elem
            is_triplets = True
        else:
            left_elem, label = elem

        max_len_q1 = max(len(left_elem['query']), max_len_q1)
        max_len_d1 = max(len(left_elem['document']), max_len_d1)
        if len(elem) == 3:
            max_len_q2 = max(len(right_elem['query']), max_len_q2)
            max_len_d2 = max(len(right_elem['document']), max_len_d2)

    q1s = []
    d1s = []
    q2s = []
    d2s = []
    labels = []

    for elem in batch_objs:
        if is_triplets:
            left_elem, right_elem, label = elem
        else:
            left_elem, label = elem

        pad_len1 = max_len_q1 - len(left_elem['query'])
        pad_len2 = max_len_d1 - len(left_elem['document'])
        if is_triplets:
            pad_len3 = max_len_q2 - len(right_elem['query'])
            pad_len4 = max_len_d2 - len(right_elem['document'])

        q1s.append(left_elem['query'] + [0] * pad_len1)
        d1s.append(left_elem['document'] + [0] * pad_len2)
        if is_triplets:
            q2s.append(right_elem['query'] + [0] * pad_len3)
            d2s.append(right_elem['document'] + [0] * pad_len4)
        labels.append([label])
    q1s = torch.LongTensor(q1s)
    d1s = torch.LongTensor(d1s)
    if is_triplets:
        q2s = torch.LongTensor(q2s)
        d2s = torch.LongTensor(d2s)
    labels = torch.FloatTensor(labels)

    ret_left = {'query': q1s, 'document': d1s}
    if is_triplets:
        ret_right = {'query': q2s, 'document': d2s}
        return ret_left, ret_right, labels
    else:
        return ret_left, labels

In [17]:
class Solution:
    def __init__(self, glue_qqp_dir: str, glove_vectors_path: str,
                 min_token_occurancies: int = 1,
                 random_seed: int = 0,
                 embed_size = 50,
                 emb_rand_uni_bound: float = 0.2,
                 freeze_knrm_embeddings: bool = True,
                 knrm_kernel_num: int = 30,
                 knrm_out_mlp: List[int] = [15, 7],
                 dataloader_bs: int = 1024,
                 train_lr: float = 0.01,
                 change_train_loader_ep: int = 10
                 ):
        
        self.gain_scheme = 'exp2'
        self.glue_qqp_dir = glue_qqp_dir
        self.glove_vectors_path = glove_vectors_path
        self.glue_train_df = self.get_glue_df('train')
        self.glue_dev_df = self.get_glue_df('dev')
        
        self.create_test_triples = self.sample_data_for_train_iter(self.glue_train_df)
        self.dev_pairs_for_ndcg = self.create_val_pairs(self.glue_dev_df)
        
        self.min_token_occurancies = min_token_occurancies
        self.all_tokens = self.get_all_tokens(
            [self.glue_train_df, self.glue_dev_df], self.min_token_occurancies)

        self.random_seed = random_seed
        self.embed_size = embed_size
        self.emb_rand_uni_bound = emb_rand_uni_bound
        self.freeze_knrm_embeddings = freeze_knrm_embeddings
        self.knrm_kernel_num = knrm_kernel_num
        self.knrm_out_mlp = knrm_out_mlp
        self.dataloader_bs = dataloader_bs
        self.train_lr = train_lr
        self.change_train_loader_ep = change_train_loader_ep

        self.model, self.vocab, self.unk_words = self.build_knrm_model()
        
        self.idx_to_text_mapping_train = self.get_idx_to_text_mapping(
            self.glue_train_df)
        
        self.idx_to_text_mapping_dev = self.get_idx_to_text_mapping(
            self.glue_dev_df)
        
        self.val_dataset = ValPairsDataset(self.dev_pairs_for_ndcg, 
                                           self.idx_to_text_mapping_dev, 
                                           vocab=self.vocab, 
                                           oov_val=self.vocab['OOV'], 
                                           preproc_func=self.simple_preproc)
        
        self.val_dataloader = torch.utils.data.DataLoader(
                                                           self.val_dataset, 
                                                           batch_size=self.dataloader_bs, 
                                                           num_workers=0, 
                                                           collate_fn=collate_fn, 
                                                           shuffle=False)

#         self.test_dataset = TrainTripletsDataset(self.create_test_triples, 
#                                                  self.idx_to_text_mapping_train, 
#                                                  vocab=self.vocab, 
#                                                  oov_val=self.vocab['OOV'], 
#                                                  preproc_func=self.simple_preproc)
        
#         self.test_dataloader = torch.utils.data.DataLoader(
#                                                            self.test_dataset, 
#                                                            batch_size=self.dataloader_bs, 
#                                                            num_workers=0, 
#                                                            collate_fn=collate_fn, 
#                                                            shuffle=False)        
        
    def get_glue_df(self, partition_type: str) -> pd.DataFrame:
        assert partition_type in ['dev', 'train']
        glue_df = pd.read_csv(
            self.glue_qqp_dir + f'/{partition_type}.tsv', sep='\t', dtype=object)
        glue_df = glue_df.dropna(axis=0, how='any').reset_index(drop=True)
        glue_df_fin = pd.DataFrame({
            'id_left': glue_df['qid1'],
            'id_right': glue_df['qid2'],
            'text_left': glue_df['question1'],
            'text_right': glue_df['question2'],
            'label': glue_df['is_duplicate'].astype(int)
        })
        return glue_df_fin

    def hadle_punctuation(self, inp_str: str) -> str:
        regex = re.compile('[%s]' % re.escape(string.punctuation))
        out = regex.sub('', inp_str)
        return out

    def simple_preproc(self, inp_str: str) -> List[str]:
        rem_puct_str = self.hadle_punctuation(inp_str).lower()       
        return nltk.word_tokenize(rem_puct_str)
    
    def _filter_rare_words(self, vocab: Dict[str, int], min_occurancies: int) -> Dict[str, int]:
        filter_dct = dict([(k,v) for k,v in vocab.items() if v >= min_occurancies])
        return filter_dct
    
    def get_all_tokens(self, list_of_df: List[pd.DataFrame], min_occurancies: int) -> List[str]:
        all_texts = []
        fin_cnt = Counter()
        for df in list_of_df:
            all_texts += list(df.text_left)
            all_texts += list(df.text_right)
            counter = Counter(self.simple_preproc(" ".join(list(set(all_texts)))))
            fin_cnt.update(counter)
        token_cnt = self._filter_rare_words(fin_cnt, min_occurancies)
        return list(token_cnt.keys())
        
#         fin_cnt = Counter()    
#         for df in list_of_df:
#             tokens_union = np.concatenate(df['text_left'].apply(lambda x: self.simple_preproc(x)).values).ravel().tolist() + \
#                            np.concatenate(df['text_right'].apply(lambda x: self.simple_preproc(x)).values).ravel().tolist()
#             fin_cnt.update(Counter(tokens_union))               
#         filtered_dct = self._filter_rare_words(dict(fin_cnt), min_occurancies)           
#         return [k for k, _ in filtered_dct.items()]
                
    def _read_glove_embeddings(self, file_path: str) -> Dict[str, List[str]]:
        res = dict([tuple(line.split(" ", 1)) for line in open(file_path, 'r')]) 
        fin = dict([(k, v.split(" ")) for k,v in res.items() if k not in string.punctuation])
        return fin

    def create_glove_emb_from_file(self, file_path: str, inner_keys: List[str],
                                   random_seed: int, rand_uni_bound: float
                                   ) -> Tuple[np.ndarray, Dict[str, int], List[str]]:

        word_embeddings = self._read_glove_embeddings(file_path)

        unk_words = list(set(inner_keys) - set(word_embeddings.keys()))
        known_words = list(set(inner_keys) & set(word_embeddings.keys()))
        emb_size = len(word_embeddings[known_words[0]])
        emb_array = np.zeros((len(inner_keys) + 2, emb_size))
        
        word2ind = {"PAD": 0, "OOV" : 1}
        unk_embedding = np.random.uniform(-rand_uni_bound, rand_uni_bound, emb_size)
        emb_array[1, :] = unk_embedding
        for index, word in enumerate(inner_keys, 2):
            emb_array[index, :] = word_embeddings.get(word, unk_embedding)
            word2ind[word] = index                
        unk_words += ["PAD", "OOV"]
        return emb_array, word2ind, unk_words        
        

    def build_knrm_model(self) -> Tuple[torch.nn.Module, Dict[str, int], List[str]]:
        emb_matrix, vocab, unk_words = self.create_glove_emb_from_file(
            self.glove_vectors_path, self.all_tokens, self.random_seed, self.emb_rand_uni_bound)
        torch.manual_seed(self.random_seed)
        knrm = KNRM(emb_matrix, freeze_embeddings=self.freeze_knrm_embeddings,
                    out_layers=self.knrm_out_mlp, kernel_num=self.knrm_kernel_num)
        return knrm, vocab, unk_words


    def create_val_pairs(self, 
                         inp_df: pd.DataFrame, 
                         fill_top_to: int = 15,
                         min_group_size: int = 2, 
                         seed: int = 0) -> List[List[Union[str, float]]]:
        
        inp_df_select = inp_df[['id_left', 'id_right', 'label']]
        inf_df_group_sizes = inp_df_select.groupby('id_left').size()
        glue_dev_leftids_to_use = list(
            inf_df_group_sizes[inf_df_group_sizes >= min_group_size].index)
        groups = inp_df_select[inp_df_select.id_left.isin(
            glue_dev_leftids_to_use)].groupby('id_left')

        all_ids = set(inp_df['id_left']).union(set(inp_df['id_right']))

        out_pairs = []

        np.random.seed(seed)

        for id_left, group in groups:
            ones_ids = group[group.label > 0].id_right.values
            zeroes_ids = group[group.label == 0].id_right.values
            sum_len = len(ones_ids) + len(zeroes_ids)
            num_pad_items = max(0, fill_top_to - sum_len)
            if num_pad_items > 0:
                cur_chosen = set(ones_ids).union(
                    set(zeroes_ids)).union({id_left})
                pad_sample = np.random.choice(
                    list(all_ids - cur_chosen), num_pad_items, replace=False).tolist()
            else:
                pad_sample = []
            for i in ones_ids:
                out_pairs.append([id_left, i, 2])
            for i in zeroes_ids:
                out_pairs.append([id_left, i, 1])
            for i in pad_sample:
                out_pairs.append([id_left, i, 0])
        return out_pairs
 
    def sample_data_for_train_iter( self, 
                                    inp_df: pd.DataFrame, 
#                                     fill_top_to: int = 15,
#                                     groupby_q: str = 'id_left',
#                                     min_group_size: int = 2,
#                                     perm_min_samples: int = 5,
#                                     inp_frac_neg: float = 0.1,
#                                     inp_frac_pos: float = 0.1,                            
#                                     perm_frac_neg: float = 0.9,
#                                     perm_frac_pos: float = 0.9,
                                    seed: int = 0) -> List[List[Union[str, float]]]:

        np.random.seed(seed)

        groups = inp_df[['id_left', 'id_right', 'label']].groupby('id_left')
        pairs_w_labels = []
        np.random.seed(seed)
        all_right_ids = inp_df.id_right.values
        for id_left, group in groups:
            labels = group.label.unique()
            if len(labels) > 1:
                for label in labels:
                    same_label_samples = group[group.label ==
                                               label].id_right.values
                    if label == 0 and len(same_label_samples) > 1:
                        sample = np.random.choice(same_label_samples, 2, replace=False)
                        pairs_w_labels.append([id_left, sample[0], sample[1], 0.5])
                    elif label == 1:
                        less_label_samples = group[group.label < label].id_right.values
                        pos_sample = np.random.choice(same_label_samples, 1, replace=False)
                        if len(less_label_samples) > 0:
                            neg_sample = np.random.choice(less_label_samples, 1, replace=False)
                        else:
                            neg_sample = np.random.choice(all_right_ids, 1, replace=False)
                        pairs_w_labels.append([id_left, pos_sample[0], neg_sample[0], 1])
                        
        return pairs_w_labels        
        
            
#         inp_fracs = {0: inp_frac_neg, 1:  inp_frac_pos}     
#         inp_sampled_df = pd.concat([dff.sample(frac=inp_fracs.get(i), 
#                                                random_state = seed) for i,dff in inp_df.groupby('label')])
# #         inp_sampled_df = inp_df.sample(frac=0.15)
    
#         inp_df_select = inp_sampled_df[['id_left', 'id_right', 'label']]
#         inf_df_group_sizes = inp_df_select.groupby(groupby_q).size()
#         glue_test_leftids_to_use = list(
#             inf_df_group_sizes[inf_df_group_sizes >= min_group_size].index)
#         groups = inp_df_select[inp_df_select.id_left.isin(
#             glue_test_leftids_to_use)].groupby(groupby_q)

#         next_id = list(set(inp_df_select.columns.values) - set({groupby_q, 'label'}))[0] 

#         out_pairs = []

#         for idx, group in groups:
#             ids = group[next_id].values
#             perm = list(permutations(ids, 2))
#             triples_lst = []
#             if len(perm) > 1:      
#                 for pair in perm:
#                     label_0, label_1 = group[group[next_id].isin(pair)].label.values
# #                     diff    = ((label_0 >= label_1)&(label_1 != 0)).astype(int)
#                     if label_0 == 0:
#                         triples_lst.append(list([idx, pair[0], pair[1], 0]))
#                     else:
#                         triples_lst.append(list([idx, pair[0], pair[1], 1]))
                        
#                 triples_df = pd.DataFrame(triples_lst, columns=[groupby_q, next_id+'0', next_id+'1', 'label'])
                
#                 if len(perm) >= perm_min_samples:
#                     perm_fracs = {0: perm_frac_neg, 1: perm_frac_pos}
#                     sampled_df = pd.concat([dff.sample(frac=perm_fracs.get(i), 
#                                                        random_state = seed) for i,dff in triples_df.groupby('label')])
#                     out_pairs.extend(sampled_df.values.tolist()) 
#                 else:
#                     out_pairs.extend(triples_df.values.tolist())
            
#         return out_pairs    

    
    def get_idx_to_text_mapping(self, inp_df: pd.DataFrame) -> Dict[str, str]:
        left_dict = (
            inp_df
            [['id_left', 'text_left']]
            .drop_duplicates()
            .set_index('id_left')
            ['text_left']
            .to_dict()
        )
        right_dict = (
            inp_df
            [['id_right', 'text_right']]
            .drop_duplicates()
            .set_index('id_right')
            ['text_right']
            .to_dict()
        )
        left_dict.update(right_dict)
        return left_dict
     

    def ndcg_k(self, ys_true: np.array, ys_pred: np.array, ndcg_top_k: int = 10) -> float:
      
        discounted_dsg = dcg_k(torch.Tensor(ys_true), torch.Tensor(ys_pred), ndcg_top_k)
        ideal_dcg = dcg_k(torch.Tensor(ys_true), torch.Tensor(ys_true), ndcg_top_k)
        
        if ideal_dcg != 0:
            ndcg = discounted_dsg / ideal_dcg
        else:
            ndcg = 0
        return ndcg
        
        return current_dcg / ideal_dcg
        
    def compute_gain_diff(self, y_true, gain_scheme):
        if gain_scheme == "exp2":
            gain_diff = torch.pow(2.0, y_true) - torch.pow(2.0, y_true.t())
        elif gain_scheme == "diff":
            gain_diff = y_true - y_true.t()
        else:
            raise ValueError(f"{gain_scheme} method not supported")
        return gain_diff 


    def valid(self, model: torch.nn.Module, val_dataloader: torch.utils.data.DataLoader) -> float:
        labels_and_groups = val_dataloader.dataset.index_pairs_or_triplets
        labels_and_groups = pd.DataFrame(labels_and_groups, columns=['left_id', 'right_id', 'rel'])
        
        all_preds = []
        for batch in (val_dataloader):
            inp_1, y = batch
            preds = model.predict(inp_1)
            preds_np = preds.detach().numpy()
            all_preds.append(preds_np)
        all_preds = np.concatenate(all_preds, axis=0)
        labels_and_groups['preds'] = all_preds
        
        ndcgs = []
        for cur_id in labels_and_groups.left_id.unique():
            cur_df = labels_and_groups[labels_and_groups.left_id == cur_id]
            ndcg = self.ndcg_k(cur_df.rel.values.reshape(-1), cur_df.preds.values.reshape(-1))
            if np.isnan(ndcg):
                ndcgs.append(0)
            else:
                ndcgs.append(ndcg)
        return np.mean(ndcgs)

    
    def train(self, n_epochs: int):
        opt = torch.optim.SGD(self.model.parameters(), lr=self.train_lr)
        criterion = torch.nn.BCELoss()
        ndcg = 0
        for epoch in range(n_epochs):
            epoch_loss = 0
            if epoch % self.change_train_loader_ep == 0:
                current_subset = self.sample_data_for_train_iter(inp_df = self.glue_train_df, 
                                                                 seed = epoch)
                train_dataset = TrainTripletsDataset(current_subset,
                                                     self.idx_to_text_mapping_train, 
                                                     vocab=self.vocab, 
                                                     oov_val=self.vocab['OOV'],
                                                     preproc_func=self.simple_preproc)
                train_dataloader = torch.utils.data.DataLoader(train_dataset, 
                                                               batch_size=self.dataloader_bs,
                                                               num_workers=0, 
                                                               collate_fn=collate_fn, 
                                                               shuffle=True)
            for batch in train_dataloader:
                inp_1, inp_2, y = batch
                preds = self.model(inp_1, inp_2)
                batch_loss = criterion(preds, y)
                batch_loss.backward()
                opt.step()
                epoch_loss += batch_loss.item()
            if epoch > 5:
                ndcg = self.valid(self.model, self.val_dataloader)
                print("epoch: {} ndcg: {}".format(epoch, ndcg))
            if ndcg > 0.950:
                break
    

In [18]:
%%time
sol = Solution(glue_qqp_dir = glue_qqp_dir, 
               glove_vectors_path = glove_path)

CPU times: user 1min 44s, sys: 1.89 s, total: 1min 46s
Wall time: 1min 46s


In [19]:
%%time
sol.train(150)

epoch: 6 ndcg: 0.5341467504074768
epoch: 7 ndcg: 0.5477597276543643
epoch: 8 ndcg: 0.7101372570389655
epoch: 9 ndcg: 0.7892089136771846
epoch: 10 ndcg: 0.9100668548371349
epoch: 11 ndcg: 0.8955673357311789
epoch: 12 ndcg: 0.8919433338673058
epoch: 13 ndcg: 0.9278066723314595
epoch: 14 ndcg: 0.9040322747393125
epoch: 15 ndcg: 0.8838207987188396
epoch: 16 ndcg: 0.8476004573701311
epoch: 17 ndcg: 0.8415805832838761
epoch: 18 ndcg: 0.8942976559708666
epoch: 19 ndcg: 0.8829057468894735
epoch: 20 ndcg: 0.9413478188337606
epoch: 21 ndcg: 0.9386434862613925
epoch: 22 ndcg: 0.9261453526861442
epoch: 23 ndcg: 0.9135855712921562
epoch: 24 ndcg: 0.05069075787896105
epoch: 25 ndcg: 0.9029567457832243
epoch: 26 ndcg: 0.040637576983257796
epoch: 27 ndcg: 0.9422096206844152
epoch: 28 ndcg: 0.6588493255564072
epoch: 29 ndcg: 0.5217631565948775
epoch: 30 ndcg: 0.8144786334990184
epoch: 31 ndcg: 0.9052019563596706
epoch: 32 ndcg: 0.9815772987206142
CPU times: user 12min 33s, sys: 50.8 s, total: 13min 24s

In [44]:
mlp_layer = sol.model.mlp

In [45]:
mlp_layer.state_dict()

OrderedDict([('1.weight',
              tensor([[-1.3675e-03,  9.7939e-02, -1.5025e-01, -1.3421e-01, -6.9578e-02,
                        5.1119e-02,  1.1802e-03,  1.5375e-01, -2.2322e-03,  6.3431e-02,
                       -5.8350e-02, -9.0235e-02, -2.9208e-01, -2.8554e-01, -2.5503e-01,
                       -1.5083e-01, -3.1750e-02,  6.9409e-02, -1.1293e-01, -4.6203e-02,
                        1.0435e-01,  1.9258e-01,  7.3940e-03,  1.8652e-01,  2.4456e-02,
                        7.1679e-02,  2.0797e-01, -1.4159e-01, -1.0320e-01, -5.9219e-02],
                      [-7.1165e-02,  1.5776e-01, -1.1824e-01, -8.3621e-02, -1.2641e-01,
                       -1.6893e-01, -1.0427e-01,  1.5859e-01,  8.2196e-02,  8.7204e-02,
                       -6.6780e-04, -1.2414e-01, -1.4524e-02, -2.1770e-01, -1.7867e-01,
                       -1.4649e-01,  5.4514e-02,  4.6886e-02, -1.3491e-01, -5.8005e-02,
                        6.6114e-02,  1.3351e-01,  2.9996e-02, -1.1296e-02,  9.0604e-02,
     

In [46]:
torch.save(mlp_layer,open("./pkl/mlp_pretr.tor","wb"))

In [49]:
pretrained_emb_krnm = sol.model.embeddings.state_dict()

In [53]:
pretrained_emb_krnm

OrderedDict([('weight',
              tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
                      [-0.0022, -0.0369,  0.0055,  ..., -0.1204, -0.0544, -0.0585],
                      [ 0.4532,  0.0598, -0.1058,  ...,  0.5324, -0.2510,  0.6255],
                      ...,
                      [-0.0022, -0.0369,  0.0055,  ..., -0.1204, -0.0544, -0.0585],
                      [-0.5138, -0.9359, -0.1946,  ...,  0.2532,  0.1363, -0.0305],
                      [-0.3993, -0.5018, -0.8642,  ...,  0.4815, -0.2657, -0.2634]]))])

In [52]:
torch.save(pretrained_emb_krnm, open("./pkl/pretr_emb_krnm_50.tor","wb"))

In [51]:
with open("./pkl/pretr_emb_krnm_50", "wb") as f:
    pickle.dump(pretrained_emb_krnm,f)

In [40]:
# with open("./pkl/pretr_emb_krnm_50", "rb") as fin:
#     emb_50 = pickle.load(fin)

## Tests

In [9]:
emb_matrix, vocab, unk_words = sol.create_glove_emb_from_file(
            sol.glove_vectors_path, sol.all_tokens, sol.random_seed, sol.emb_rand_uni_bound)

In [11]:
with open("./pkl/vocab", "wb") as f:
    pickle.dump(vocab,f)

In [11]:
sol.val_dataset[10]

({'query': [19, 115, 51, 7915, 142, 171, 6597, 44, 136],
  'document': [137, 812, 587, 43, 17, 2215, 19864, 8065, 595, 19889]},
 0)

In [12]:
for item in sol.val_dataset:
    print(item)
    break

({'query': [19, 115, 51, 7915, 142, 171, 6597, 44, 136], 'document': [19, 115, 51, 11926, 301, 142, 171, 6597, 44, 136]}, 2)


In [13]:
for item in sol.val_dataloader:
    res, label = item
    break

In [14]:
res['document'].shape

torch.Size([1024, 30])

In [15]:
knrm = sol.model

In [16]:
mmat = knrm._get_matching_matrix(res['query'], res['document'])
mmat.shape

torch.Size([1024, 24, 30])

In [17]:
KM = []
for kernel in sol.model.kernels:
    # shape = [B]
    K = torch.log1p(kernel(mmat).sum(dim=-1)).sum(dim=-1)
    KM.append(K)

# shape = [B, K]
kernels_out = torch.stack(KM, dim=1)
kernels_out.shape

torch.Size([1024, 22])

In [10]:
sol.model._get_mlp()

Sequential(
  (0): ReLU()
  (1): Linear(in_features=22, out_features=10, bias=True)
  (2): ReLU()
  (3): Linear(in_features=10, out_features=5, bias=True)
  (4): ReLU()
  (5): Linear(in_features=5, out_features=1, bias=True)
)

In [19]:
inp_df = sol.glue_train_df
min_group_size = 2

In [20]:
inp_df.head()

Unnamed: 0,id_left,id_right,text_left,text_right,label
0,213221,213222,How is the life of a math student? Could you d...,Which level of prepration is enough for the ex...,0
1,536040,536041,How do I control my horny emotions?,How do you control your horniness?,1
2,364011,490273,What causes stool color to change to yellow?,What can cause stool to come out as little balls?,0
3,155721,7256,What can one do after MBBS?,What do i do after my MBBS ?,1
4,279958,279959,Where can I find a power outlet for my laptop ...,"Would a second airport in Sydney, Australia be...",0


In [421]:
inp_df.shape

(363846, 5)

In [454]:
groupby_q = 'id_right'

inp_df_select = inp_df[['id_left', 'id_right', 'label']]
inf_df_group_sizes = inp_df_select.groupby(groupby_q).size()
glue_test_leftids_to_use = list(
    inf_df_group_sizes[inf_df_group_sizes >= min_group_size].index)
groups = inp_df_select[inp_df_select.id_left.isin(
    glue_test_leftids_to_use)].groupby(groupby_q)

In [481]:
inp_df.label.value_counts()

0    229468
1    134378
Name: label, dtype: int64

In [479]:
groupby_q = 'id_left'
perm_min_samples = 10

inp_df_select = inp_df[['id_left', 'id_right', 'label']]
inf_df_group_sizes = inp_df_select.groupby(groupby_q).size()
glue_test_leftids_to_use = list(
    inf_df_group_sizes[inf_df_group_sizes >= min_group_size].index)
groups = inp_df_select[inp_df_select.id_left.isin(
    glue_test_leftids_to_use)].groupby(groupby_q)

next_id = list(set(inp_df_select.columns.values) - set({groupby_q, 'label'}))[0] 

out_pairs = []

np.random.seed(0)

for idx, group in tqdm(groups, total = len(groups)):
    id_gr = set({str(idx)})
    ids = group[next_id].values
    perm = list(permutations(ids, 2))
    if len(perm) > 0:
        triples_lst = \
        [list(id_gr.union({id0, id1}))+[(group[group[next_id]==id0].label.values[0] >= 
                                         group[group[next_id]==id1].label.values[0]
                                        ).astype(int)
                                       ]

         for id0, id1 in perm
        ]

        triples_df = pd.DataFrame(triples_lst, columns=[groupby_q, next_id+'0', next_id+'1', 'label'])

        if len(perm) >= perm_min_samples:
            fracs = {0: 0.1, 1: 0.5}
        else:
            fracs = {0: 1., 1: 1.}
        sampled_df = pd.concat([dff.sample(frac=fracs.get(i), random_state = 0) for i,dff in triples_df.groupby('label')])

        out_pairs.extend(sampled_df.values.tolist())   

100%|██████████| 8275/8275 [00:52<00:00, 157.49it/s]


In [358]:
groups.get_group('10061')

Unnamed: 0,id_left,id_right,label
74599,10061,296775,1
161359,10061,14160,1
213407,10061,10062,1
227129,10061,38068,1
240820,10061,38069,1
275839,10061,135394,1
283645,10061,280645,0
287616,10061,76075,1
312024,10061,105493,1
317904,10061,72773,1


In [356]:
groups.get_group('100021')['label'].values.tolist()

[1, 1, 1, 1, 1, 1]

In [98]:
groups.groups

{'100001': [72169, 281245], '10001': [109028, 222911], '100014': [166599, 255056], '100016': [47531, 227668], '100017': [312234, 333088], '100021': [4489, 51524, 117562, 169432, 192293, 211110], '100022': [71771, 134299, 325283, 359495], '100024': [77008, 131059, 276451], '100028': [102332, 158540, 205807, 264483, 287050, 324542], '100041': [56880, 105325, 166916], '100048': [157076, 191857, 193508], '100068': [46372, 95341, 247219, 341451], '10008': [162671, 344368], '100084': [104374, 162216, 290276], '10009': [17486, 124792, 328372, 337960], '100091': [26535, 108070], '10010': [32927, 346526], '100104': [227867, 326091], '10011': [17702, 80430, 90966, 91742, 284918], '10013': [86060, 126705, 325087], '100134': [18002, 347315], '100136': [44476, 111770, 318911, 359155], '100149': [127622, 203513], '100153': [243359, 300529, 357390], '100154': [47938, 110460], '100160': [79695, 85570, 120549], '100162': [136875, 256736], '10017': [19970, 47004, 51892], '100172': [129669, 330482, 33048

In [370]:
gr = groups.get_group('10061')

In [372]:
gr.head()

Unnamed: 0,id_left,id_right,label
74599,10061,296775,1
161359,10061,14160,1
213407,10061,10062,1
227129,10061,38068,1
240820,10061,38069,1


In [314]:
sol.dev_pairs_for_ndcg[:5]

[['100141', '75743', 2],
 ['100141', '100142', 2],
 ['100141', '264410', 0],
 ['100141', '275077', 0],
 ['100141', '367828', 0]]

In [151]:
query_tok = torch.Tensor([[22, 45, 11, 15, 0, 0],
                          [22, 45, 11, 15, 0, 0]]).type(torch.int)
doc_tok = torch.Tensor([[43, 21, 9, 15, 0, 0],
                        [43, 21, 9, 15, 0, 0]]).type(torch.int)

In [132]:
BAT, A, B = query_tok.shape[0], query_tok.shape[1], doc_tok.shape[1]
assert doc_tok.shape[0] == BAT

In [137]:
def remove_padding(sim, query_tok, doc_tok, BAT, A, B):
    nul = torch.zeros_like(sim)
    sim = torch.where(query_tok.reshape(BAT, A, 1).expand(BAT, A, B) == self.padding, nul, sim)
    sim = torch.where(doc_tok.reshape(BAT, 1, B).expand(BAT, A, B) == self.padding, nul, sim)
    return sim

def exact_match_matrix(query_tok, doc_tok, BAT, A, B):
    sim = (query_tok.reshape(BAT, A, 1).expand(BAT, A, B) == doc_tok.reshape(BAT, 1, B).expand(BAT, A, B)).float()
    sim = remove_padding(sim, query_tok, doc_tok, BAT, A, B)
    return sim

def cosine_similarity_matrix(query_tok, doc_tok, BAT, A, B):
    a_emb, b_emb = torch.Tensor(emb_matrix[query_tok]), torch.Tensor(emb_matrix[doc_tok])
    a_denom = a_emb.norm(p=2, dim=2).reshape(BAT, A, 1).expand(BAT, A, B) + 1e-9  # avoid 0div
    b_denom = b_emb.norm(p=2, dim=2).reshape(BAT, 1, B).expand(BAT, A, B) + 1e-9  # avoid 0div
    perm = b_emb.permute(0, 2, 1)
    sim = a_emb.bmm(perm) / (a_denom * b_denom)
    sim = remove_padding(sim, query_tok, doc_tok, BAT, A, B)
    return sim

In [272]:
q = query_tok[0].numpy().tolist()
d = doc_tok[0].numpy().tolist()

a_emb, b_emb = torch.Tensor(emb_matrix[q]), torch.Tensor(emb_matrix[d])
norm_a = a_emb.norm(p=2, dim=1)
norm_b = b_emb.norm(p=2, dim=1)

sim = a_emb.matmul(b_emb.t())/norm_a/norm_b
sim = sim.nan_to_num()

sim

tensor([[0.8262, 0.5856, 0.3818, 0.6570, 0.0000, 0.0000],
        [0.4679, 0.2739, 0.4120, 0.4533, 0.0000, 0.0000],
        [0.9873, 0.6609, 0.3957, 0.6451, 0.0000, 0.0000],
        [0.4632, 0.6222, 0.3524, 1.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

In [302]:
krnm = KNRM(emb_matrix, freeze_embeddings = False)

In [303]:
matching_matrix = krnm._get_matching_matrix(query_tok[0], doc_tok[0])
# shape = [Batch, Kernels]
kernels_out = krnm._apply_kernels(matching_matrix)
# shape = [Batch]
out = krnm.mlp(kernels_out)

[tensor(5.0523e-19, grad_fn=<SumBackward1>), tensor(4.0939e-15, grad_fn=<SumBackward1>), tensor(1.2204e-11, grad_fn=<SumBackward1>), tensor(1.3383e-08, grad_fn=<SumBackward1>), tensor(5.3992e-06, grad_fn=<SumBackward1>), tensor(0.0008, grad_fn=<SumBackward1>), tensor(0.0435, grad_fn=<SumBackward1>), tensor(0.8049, grad_fn=<SumBackward1>), tensor(4.1637, grad_fn=<SumBackward1>), tensor(7.7497, grad_fn=<SumBackward1>), tensor(7.7836, grad_fn=<SumBackward1>), tensor(4.5836, grad_fn=<SumBackward1>), tensor(2.7560, grad_fn=<SumBackward1>), tensor(3.6246, grad_fn=<SumBackward1>), tensor(4.0500, grad_fn=<SumBackward1>), tensor(3.8926, grad_fn=<SumBackward1>), tensor(3.3240, grad_fn=<SumBackward1>), tensor(2.2559, grad_fn=<SumBackward1>), tensor(1.6252, grad_fn=<SumBackward1>), tensor(1.6986, grad_fn=<SumBackward1>), tensor(8.2863, grad_fn=<SumBackward1>)]


In [304]:
krnm.kernels

ModuleList(
  (0): GaussianKernel()
  (1): GaussianKernel()
  (2): GaussianKernel()
  (3): GaussianKernel()
  (4): GaussianKernel()
  (5): GaussianKernel()
  (6): GaussianKernel()
  (7): GaussianKernel()
  (8): GaussianKernel()
  (9): GaussianKernel()
  (10): GaussianKernel()
  (11): GaussianKernel()
  (12): GaussianKernel()
  (13): GaussianKernel()
  (14): GaussianKernel()
  (15): GaussianKernel()
  (16): GaussianKernel()
  (17): GaussianKernel()
  (18): GaussianKernel()
  (19): GaussianKernel()
  (20): GaussianKernel()
)

In [299]:
kernels_out

tensor([5.0523e-19, 4.0939e-15, 1.2204e-11, 1.3383e-08, 5.3992e-06, 8.0124e-04,
        4.3541e-02, 8.0487e-01, 4.1637e+00, 7.7497e+00, 7.7836e+00, 4.5836e+00,
        2.7560e+00, 3.6246e+00, 4.0500e+00, 3.8926e+00, 3.3240e+00, 2.2559e+00,
        1.6252e+00, 1.6986e+00, 8.2863e+00], grad_fn=<StackBackward>)

In [142]:
res = dict([tuple(re.sub("\n","",line).split(" ", 1)) for line in open(glove_path, 'r')])
fin = dict([(k, v.split(" ")) for k,v in res.items() if k not in string.punctuation])

In [10]:
res = torch.empty(3,10).uniform_(-0.2,0.2).type(torch.FloatTensor)

In [135]:
vec = np.random.uniform(-0.2,0.2, size=(10,))
["{:8.6f}".format(elem) for elem in vec]

['-0.078511',
 '-0.059738',
 '-0.126076',
 '0.127243',
 '-0.151578',
 '-0.120386',
 '-0.054364',
 '-0.058489',
 '0.168536',
 '-0.010920']