In [None]:
!wget https://dl.fbaipublicfiles.com/glue/data/QQP-clean.zip
!unzip -qq QQP-clean.zip

In [None]:
!wget http://nlp.stanford.edu/data/glove.6B.zip
!unzip -qq glove.6B.zip

In [None]:
import math
import re
import string
import time
from collections import Counter
from typing import Dict, List, Tuple, Union, Callable

import nltk
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

# Fetch the NLTK tokenizer resources; presumably required by nltk.word_tokenize,
# which Solution.simple_preproc calls.
nltk.download('punkt')
nltk.download('punkt_tab')

# Raw QQP splits loaded straight from the unzipped archive.
# NOTE(review): `test` and `train` are not referenced again in the visible
# notebook (Solution.get_glue_df reads the files itself) -- confirm they are needed.
test = pd.read_csv('/content/QQP/test.tsv', sep='\t')
train = pd.read_csv('/content/QQP/train.tsv', sep='\t')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


In [None]:
# NOTE(review): these module-level paths point at /data/..., while the
# Solution(...) call further down passes /content/... paths explicitly.
# Nothing in the visible notebook reads these two names -- confirm before relying on them.
glue_qqp_dir = '/data/QQP/'
glove_path = '/data/glove.6B.50d.txt'


class GaussianKernel(torch.nn.Module):
    """Element-wise RBF kernel ``exp(-(x - mu)^2 / (2 * sigma^2))``.

    Args:
        mu: centre of the kernel.
        sigma: width of the kernel.
    """

    def __init__(self, mu: float = 1., sigma: float = 1.):
        super().__init__()
        self.mu = mu
        self.sigma = sigma

    def forward(self, x):
        """Apply the kernel to every element of ``x`` and return a tensor of the same shape."""
        squared_offset = torch.pow(x - self.mu, 2)
        scale = 2 * self.sigma ** 2
        return torch.exp(-squared_offset / scale)


class KNRM(torch.nn.Module):
    """Kernel-based neural ranking model (K-NRM) on top of fixed word embeddings.

    A query/document pair is scored by building the cosine-similarity matrix
    between their token embeddings, pooling it with a bank of Gaussian kernels
    and feeding the pooled features to a small MLP.

    Args:
        embedding_matrix: array of shape ``[vocab_size, emb_dim]``; row 0 is
            used as the padding index.
        freeze_embeddings: when True the embedding weights are not trained.
        kernel_num: total number of kernels, including the exact-match kernel
            centred at 1. Must be at least 1.
        sigma: width of the soft-match kernels.
        exact_sigma: width of the exact-match kernel.
        out_layers: hidden layer sizes of the output MLP.
    """

    def __init__(self, embedding_matrix: np.ndarray, freeze_embeddings: bool, kernel_num: int = 21,
                 sigma: float = 0.1, exact_sigma: float = 0.001,
                 out_layers: List[int] = [10, 5]):
        super().__init__()
        self.embeddings = torch.nn.Embedding.from_pretrained(
            torch.FloatTensor(embedding_matrix),
            freeze=freeze_embeddings,
            padding_idx=0
        )
        # Kept for interface compatibility; the matching matrix below is computed by hand.
        self.cosine_sim = torch.nn.CosineSimilarity(dim=3, eps=1e-6)
        self.kernel_num = kernel_num
        self.sigma = sigma
        self.exact_sigma = exact_sigma
        # Copy so that neither the shared default list nor the caller's list is aliased.
        self.out_layers = list(out_layers)

        self.kernels = self._get_kernels_layers()

        self.mlp = self._get_mlp()

        self.out_activation = torch.nn.Sigmoid()

    @staticmethod
    def _get_kernels_mu(kernel_num: int) -> np.ndarray:
        """Return exactly ``kernel_num`` kernel centres.

        The first ``kernel_num - 1`` centres are evenly spaced inside (-1, 1);
        the last one is 1.0 and belongs to the exact-match kernel.

        Raises:
            ValueError: if ``kernel_num`` is smaller than 1.
        """
        if kernel_num < 1:
            raise ValueError(f'kernel_num must be >= 1, got {kernel_num}')
        if kernel_num == 1:
            return np.array([1.0])
        step = 1 / (kernel_num - 1)
        # linspace takes the number of points explicitly, so the result length
        # cannot drift by one due to floating-point rounding of a step size.
        soft_centers = np.linspace(-1 + step, 1 - step, kernel_num - 1)
        return np.append(soft_centers, 1.0)

    def _get_kernels_layers(self) -> torch.nn.ModuleList:
        """Build the kernel bank: soft-match kernels followed by the exact-match kernel."""
        centers = self._get_kernels_mu(self.kernel_num)
        kernels = torch.nn.ModuleList()
        for mu in centers[:-1]:
            kernels.append(GaussianKernel(float(mu), self.sigma))
        # The final centre is always 1 and uses the narrow exact-match width.
        kernels.append(GaussianKernel(1, self.exact_sigma))
        return kernels

    def _get_mlp(self) -> torch.nn.Sequential:
        """Build the output MLP: a ReLU precedes every Linear layer; the last Linear has one output.

        The layer order is kept as-is because the positions of the Linear
        layers determine the keys of ``self.mlp.state_dict()``.
        """
        layers = []
        in_features = self.kernel_num
        for out_features in self.out_layers:
            layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Linear(in_features, out_features))
            in_features = out_features

        layers.append(torch.nn.ReLU())
        layers.append(torch.nn.Linear(in_features, 1))

        return torch.nn.Sequential(*layers)

    def forward(self, input_1: Dict[str, torch.Tensor], input_2: Dict[str, torch.Tensor]) -> torch.FloatTensor:
        """Return ``sigmoid(score(input_1) - score(input_2))`` for a batch of pairs of inputs."""
        logits_1 = self.predict(input_1)
        logits_2 = self.predict(input_2)

        logits_diff = logits_1 - logits_2

        out = self.out_activation(logits_diff)
        return out

    def _get_matching_matrix(self, query: torch.Tensor, doc: torch.Tensor) -> torch.FloatTensor:
        """Cosine similarity between every query token and every document token.

        Args:
            query: token indices, shape ``[Batch, Left]``.
            doc: token indices, shape ``[Batch, Right]``.

        Returns:
            Tensor of shape ``[Batch, Left, Right]``.
        """
        alpha = 1e-6  # keeps the norms strictly positive (padding rows are all zeros)
        query = self.embeddings(query)
        doc = self.embeddings(doc)
        numerator = (query.unsqueeze(dim=2) * doc.unsqueeze(dim=1)).sum(dim=3)
        query_norm = torch.sqrt(torch.sum(query * query, dim=2) + alpha).unsqueeze(dim=2)
        doc_norm = torch.sqrt(torch.sum(doc * doc, dim=2) + alpha).unsqueeze(dim=1)
        return numerator / (query_norm * doc_norm)

    def _apply_kernels(self, matching_matrix: torch.FloatTensor) -> torch.FloatTensor:
        """Pool the matching matrix with every kernel.

        For each kernel: sum over document tokens, ``log1p``, then sum over
        query tokens.

        Returns:
            Tensor of shape ``[Batch, Kernels]``.
        """
        pooled = []
        for kernel in self.kernels:
            # shape = [Batch]
            soft_tf = torch.log1p(kernel(matching_matrix).sum(dim=-1)).sum(dim=-1)
            pooled.append(soft_tf)

        # shape = [Batch, Kernels]
        return torch.stack(pooled, dim=1)

    def predict(self, inputs: Dict[str, torch.Tensor]) -> torch.FloatTensor:
        """Score a batch given as a dict with ``'query'`` and ``'document'`` index tensors.

        Returns:
            Raw (pre-sigmoid) scores of shape ``[Batch, 1]``.
        """
        # shape = [Batch, Left], [Batch, Right]
        query, doc = inputs['query'], inputs['document']

        # shape = [Batch, Left, Right]
        matching_matrix = self._get_matching_matrix(query, doc)
        # shape = [Batch, Kernels]
        kernels_out = self._apply_kernels(matching_matrix)
        # shape = [Batch, 1]
        out = self.mlp(kernels_out)
        return out


class RankingDataset(torch.utils.data.Dataset):
    """Base dataset that turns text ids into lists of vocabulary indices.

    Subclasses implement ``__getitem__`` for their own record layout.

    Args:
        index_pairs_or_triplets: records whose leading elements are text ids
            and whose last element is a label.
        idx_to_text_mapping: maps a text id to its raw text.
        vocab: maps a token to its index.
        oov_val: index used for tokens missing from ``vocab``.
        preproc_func: callable turning raw text into a list of tokens.
        max_len: maximum number of tokens kept per text; longer texts are
            truncated.
    """

    def __init__(self, index_pairs_or_triplets: List[List[Union[str, float]]],
                 idx_to_text_mapping: Dict[str, str], vocab: Dict[str, int], oov_val: int,
                 preproc_func: Callable, max_len: int = 30):
        self.index_pairs_or_triplets = index_pairs_or_triplets
        self.idx_to_text_mapping = idx_to_text_mapping
        self.vocab = vocab
        self.oov_val = oov_val
        self.preproc_func = preproc_func
        self.max_len = max_len

    def __len__(self):
        return len(self.index_pairs_or_triplets)

    def _tokenized_text_to_index(self, tokenized_text: List[str]) -> List[int]:
        """Map tokens to vocabulary indices, using ``oov_val`` for unknown tokens."""
        return [self.vocab.get(t, self.oov_val) for t in tokenized_text]

    def _convert_text_idx_to_token_idxs(self, idx: Union[str, int]) -> List[int]:
        """Look up the text for ``idx``, tokenize it and return at most ``max_len`` indices."""
        text = self.idx_to_text_mapping[idx]
        preproc_text = self.preproc_func(text)
        # Truncate before the vocabulary lookup so over-long texts stay cheap.
        return self._tokenized_text_to_index(preproc_text[:self.max_len])

    def __getitem__(self, idx: int):
        # Returning None here would only fail later, inside the collate function.
        raise NotImplementedError('Subclasses of RankingDataset must implement __getitem__')


class TrainTripletsDataset(RankingDataset):
    """Training records of the form ``[query_id, doc_id_1, doc_id_2, target]``."""

    def __getitem__(self, idx):
        """Return ``(left, right, label)``: the same query paired with each of the two documents."""
        record = self.index_pairs_or_triplets[idx]
        query_id, first_doc_id, second_doc_id = record[0], record[1], record[2]

        left_elem = {
            'query': self._convert_text_idx_to_token_idxs(query_id),
            'document': self._convert_text_idx_to_token_idxs(first_doc_id),
        }
        right_elem = {
            'query': self._convert_text_idx_to_token_idxs(query_id),
            'document': self._convert_text_idx_to_token_idxs(second_doc_id),
        }
        label = torch.tensor(record[3]).float()
        return left_elem, right_elem, label



class ValPairsDataset(RankingDataset):
    """Validation records of the form ``[query_id, doc_id, relevance]``."""

    def __getitem__(self, idx):
        """Return ``(pair, label)`` where ``pair`` holds the query and document token indices."""
        record = self.index_pairs_or_triplets[idx]
        pair = {
            'query': self._convert_text_idx_to_token_idxs(record[0]),
            'document': self._convert_text_idx_to_token_idxs(record[1]),
        }
        return pair, torch.tensor(record[2])


def collate_fn(batch_objs: List[Union[Dict[str, torch.Tensor], torch.FloatTensor]]):
    """Pad a batch of pair or triplet samples with zeros and stack it into tensors.

    Each sample is either ``(left, label)`` or ``(left, right, label)``, where
    ``left``/``right`` are dicts with ``'query'`` and ``'document'`` lists of
    token indices. Every field is right-padded with 0 up to the longest
    sequence of that field in the batch.

    Returns:
        ``(left, labels)`` for pairs or ``(left, right, labels)`` for triplets;
        ``left``/``right`` are dicts of LongTensors and ``labels`` is a
        FloatTensor of shape ``[Batch, 1]``.
    """
    is_triplets = any(len(sample) == 3 for sample in batch_objs)
    if is_triplets:
        unpacked = [(left, right, label) for left, right, label in batch_objs]
    else:
        unpacked = [(left, None, label) for left, label in batch_objs]

    def pad_and_stack(sequences):
        # -1 mirrors the starting value used when the batch is empty.
        longest = max((len(seq) for seq in sequences), default=-1)
        return torch.LongTensor([seq + [0] * (longest - len(seq)) for seq in sequences])

    ret_left = {
        'query': pad_and_stack([left['query'] for left, _, _ in unpacked]),
        'document': pad_and_stack([left['document'] for left, _, _ in unpacked]),
    }
    labels = torch.FloatTensor([[label] for _, _, label in unpacked])

    if not is_triplets:
        return ret_left, labels

    ret_right = {
        'query': pad_and_stack([right['query'] for _, right, _ in unpacked]),
        'document': pad_and_stack([right['document'] for _, right, _ in unpacked]),
    }
    return ret_left, ret_right, labels


class Solution:
    def __init__(self, glue_qqp_dir: str, glove_vectors_path: str,
                 min_token_occurancies: int = 1,
                 random_seed: int = 0,
                 emb_rand_uni_bound: float = 0.2,
                 freeze_knrm_embeddings: bool = True,
                 knrm_kernel_num: int = 21,
                 knrm_out_mlp: List[int] = [],
                 dataloader_bs: int = 1024,
                 train_lr: float = 0.001,
                 change_train_loader_ep: int = 10
                 ):
        """Load the data, build the vocabulary and model, and prepare the validation loader.

        Args:
            glue_qqp_dir: directory containing ``train.tsv`` and ``dev.tsv``.
            glove_vectors_path: path to the GloVe text file.
            min_token_occurancies: minimum count for a token to enter the vocabulary.
            random_seed: seed passed to the embedding initialisation and to torch.
            emb_rand_uni_bound: bound passed to the random embedding initialisation.
            freeze_knrm_embeddings: whether the embedding layer is frozen.
            knrm_kernel_num: number of KNRM kernels.
            knrm_out_mlp: hidden layer sizes of the KNRM output MLP.
            dataloader_bs: batch size for the train and validation loaders.
            train_lr: learning rate used by ``train``.
            change_train_loader_ep: training triplets are resampled every this many epochs.

        The elapsed time (whole seconds) of every stage is printed as it completes.
        """
        stage_started = time.time()

        def report(stage: str) -> None:
            # Print the seconds spent since the previous report and restart the clock.
            nonlocal stage_started
            finished = time.time()
            print(stage, round(finished - stage_started))
            stage_started = finished

        self.glue_qqp_dir = glue_qqp_dir
        self.glove_vectors_path = glove_vectors_path
        self.glue_train_df = self.get_glue_df('train')
        report("self.get_glue_df('train')")
        self.glue_dev_df = self.get_glue_df('dev')
        report("self.get_glue_df('dev')")
        self.dev_pairs_for_ndcg = self.create_val_pairs(self.glue_dev_df)
        report("self.create_val_pairs")
        self.min_token_occurancies = min_token_occurancies
        self.all_tokens = self.get_all_tokens(
            [self.glue_train_df, self.glue_dev_df], self.min_token_occurancies)
        report("self.get_all_tokens")

        self.random_seed = random_seed
        self.emb_rand_uni_bound = emb_rand_uni_bound
        self.freeze_knrm_embeddings = freeze_knrm_embeddings
        self.knrm_kernel_num = knrm_kernel_num
        # Copy so the shared default list is never aliased by the instance.
        self.knrm_out_mlp = list(knrm_out_mlp)
        self.dataloader_bs = dataloader_bs
        self.train_lr = train_lr
        self.change_train_loader_ep = change_train_loader_ep

        self.model, self.vocab, self.unk_words = self.build_knrm_model()
        report("self.build_knrm_model")
        self.idx_to_text_mapping_train = self.get_idx_to_text_mapping(
            self.glue_train_df)
        report("self.get_idx_to_text_mapping")
        self.idx_to_text_mapping_dev = self.get_idx_to_text_mapping(
            self.glue_dev_df)
        report("self.idx_to_text_mapping_dev")

        self.val_dataset = ValPairsDataset(self.dev_pairs_for_ndcg,
              self.idx_to_text_mapping_dev,
              vocab=self.vocab, oov_val=self.vocab['OOV'],
              preproc_func=self.simple_preproc)
        self.val_dataloader = torch.utils.data.DataLoader(
            self.val_dataset, batch_size=self.dataloader_bs, num_workers=0,
            collate_fn=collate_fn, shuffle=False)

    def get_glue_df(self, partition_type: str) -> pd.DataFrame:
        assert partition_type in ['dev', 'train']
        glue_df = pd.read_csv(
            self.glue_qqp_dir + f'/{partition_type}.tsv', sep='\t', on_bad_lines='skip', dtype=object)
        glue_df = glue_df.dropna(axis=0, how='any').reset_index(drop=True)
        glue_df_fin = pd.DataFrame({
            'id_left': glue_df['qid1'],
            'id_right': glue_df['qid2'],
            'text_left': glue_df['question1'],
            'text_right': glue_df['question2'],
            'label': glue_df['is_duplicate'].astype(int)
        })
        return glue_df_fin

    def handle_punctuation(self, inp_str: str) -> str:
        inp_str = re.sub(r"""[!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~]""", ' ', inp_str)
        return inp_str

    def simple_preproc(self, inp_str: str) -> List[str]:
        """Strip punctuation, lowercase the text and split it into tokens with NLTK."""
        without_punctuation = self.handle_punctuation(inp_str)
        return nltk.word_tokenize(without_punctuation.lower())

    def _filter_rare_words(self, vocab: Dict[str, int], min_occurancies: int) -> Dict[str, int]:
        return {k: v for k, v in vocab.items() if v >= min_occurancies}

    def get_all_tokens(self, list_of_df: List[pd.DataFrame], min_occurancies: int) -> List[str]:
        all_text = pd.concat(list_of_df, ignore_index=True)
        all_text = pd.concat([all_text['text_left'], all_text['text_right']]).drop_duplicates()
        token_counts = Counter(token for sentence in all_text for token in self.simple_preproc(sentence))
        keys = self._filter_rare_words(token_counts, min_occurancies).keys()
        return list(keys)

    def _read_glove_embeddings(self, file_path: str) -> Dict[str, List[str]]:
        with open(file_path) as f:
            emb = f.readlines()
            embeddings = {}
            for row in emb:
                row = row.split()
                embeddings.update({row[0]: row[1:]})
        return embeddings

    def create_glove_emb_from_file(self, file_path: str, inner_keys: List[str],
                                   random_seed: int, rand_uni_bound: float
                                   ) -> Tuple[np.ndarray, Dict[str, int], List[str]]:
        embs = self._read_glove_embeddings(file_path)
        print('glove_emb', len(embs))
        dim = len(list(embs.values())[0])
        pad_token = np.zeros(dim)
        oov_token = np.random.uniform(-0.2, 0.2, dim)

        words = list(embs.keys())
        print('inner_keys', len(words))
        unk_words = list(set(inner_keys).difference(set(words)))
        inner_keys = ['PAD'] + ['OOV'] + inner_keys
        matrix =  np.random.uniform(-rand_uni_bound, rand_uni_bound, (len(inner_keys), dim))
        matrix[0] = pad_token
        matrix[1] = oov_token
        for n, w in enumerate(inner_keys[2:]):
            if w in embs:
                matrix[n + 2] = embs[w]
        matrix = np.array(matrix).astype(float)
        word2token = {k: n for n, k in enumerate(inner_keys)}
        return matrix, word2token, unk_words

    def build_knrm_model(self) -> Tuple[torch.nn.Module, Dict[str, int], List[str]]:
        """Create the embedding matrix from GloVe and wrap it in a freshly seeded KNRM.

        Returns:
            ``(model, vocab, unk_words)`` where ``vocab`` maps tokens to
            embedding rows and ``unk_words`` are the tokens absent from GloVe.
        """
        embedding_matrix, vocab, unk_words = self.create_glove_emb_from_file(
            self.glove_vectors_path,
            self.all_tokens,
            self.random_seed,
            self.emb_rand_uni_bound,
        )
        # Seed torch right before the model is built so its initial weights are reproducible.
        torch.manual_seed(self.random_seed)
        model = KNRM(
            embedding_matrix,
            freeze_embeddings=self.freeze_knrm_embeddings,
            out_layers=self.knrm_out_mlp,
            kernel_num=self.knrm_kernel_num,
        )
        return model, vocab, unk_words

    def sample_data_for_train_iter(self, inp_df: pd.DataFrame, fill_top_to: int = 5,
                                   min_group_size: int = 2, seed: int = 0) -> List[List[Union[str, float]]]:
        inp_df_select = inp_df[['id_left', 'id_right', 'label']]
        tripl_df = inp_df_select.merge(inp_df_select.sample(frac=0.8), on='id_left')
        tripl_df = tripl_df[tripl_df['id_right_x'] != tripl_df['id_right_y']]
        tripl_df['label_1'] = (tripl_df['label_x'] - tripl_df['label_y']) > 0
        # Смотрим сколько раз есть 1 в таргете
        inf_df_group_sizes = tripl_df.groupby('id_left')['label_1'].sum()
        del tripl_df['label_1']
        # Берем индексы вопросов, которые больше порога
        glue_train_leftids_to_use = list(
            inf_df_group_sizes[inf_df_group_sizes >= min_group_size].index)
        # Оставляем только вопросы, прошедшие порог и группируем
        groups = tripl_df[tripl_df.id_left.isin(
            glue_train_leftids_to_use)].groupby('id_left')

        out_pairs = []

        np.random.seed(2)


        all_ids = set(tripl_df['id_right_x']).union(set(tripl_df['id_right_y'])).union(set(tripl_df['id_left']))

        # Итерируемся по получившемуся датасету
        # Имеем в виду, что первый документ не может быть нерелевантнее второго
        # (label_x == 0 & label_y == 1 такого быть не должно)
        for id_left, group in groups:
            # id_left - запрос, id_right_x - первый кандидат, id_right_y - второй кандидат
            # ID первого кандидата, являющегося более РЕЛЕВАНТНЫМ, чем второй
            ones_filter = (group.label_x > 0) & (group.label_y == 0)
            ones_ids_right = group[ones_filter].id_right_x.values  # label == 1
            ones_ids_left = group[ones_filter].id_right_y.values  # label == 0

            # ID первого кандидата, являющегося НЕРЕЛЕВАНТНЫМ, как и второй
            zeros_filter = (group.label_x == 0) & (group.label_y == 0)
            zeroes_ids_right = group[zeros_filter].id_right_x.values  # label == 0
            zeroes_ids_left = group[zeros_filter].id_right_y.values  # label == 0

            sum_len = len(ones_ids_right) + len(zeroes_ids_right)
            # Считаем сколько не достает до максимального числа примеров
            num_pad_items = max(0, fill_top_to - sum_len)
            if num_pad_items > 0:
                # Рандомно выбираем НЕРЕЛЕВАНТНЫХ из общего множества ID, которые не встречаются в этом примере
                cur_chosen = set(ones_ids_right).union(set(ones_ids_left)).union({id_left})
                pad_sample = np.random.choice(
                    list(all_ids - cur_chosen), num_pad_items * 2, replace=False).tolist()
                pad_sample_right = [i for n, i in enumerate(pad_sample) if n % 2 == 0]
                pad_sample_left = [i for n, i in enumerate(pad_sample) if n % 2 == 1]
            else:
                pad_sample_right = []
                pad_sample_left = []
            # Формируем итоговые список
            # 1 - дубликат; 0.5 - похож но не дубликат, 0.5 - вообще мимо (pad_sample)
            for i, j in zip(ones_ids_right, ones_ids_left):
                out_pairs.append([id_left, i, j, 1])
            for i, j in zip(zeroes_ids_right, zeroes_ids_left):
                out_pairs.append([id_left, i, j, 0.5])
            for i, j in zip(pad_sample_right, pad_sample_left):
                out_pairs.append([id_left, i, j, 0.5])
        return out_pairs

    def create_val_pairs(self, inp_df: pd.DataFrame, fill_top_to: int = 15,
                         min_group_size: int = 2, seed: int = 0) -> List[List[Union[str, float]]]:
        # Берем только нужные столбцы
        inp_df_select = inp_df[['id_left', 'id_right', 'label']]
        # Смотрим сколько раз встречается левый вопрос
        inf_df_group_sizes = inp_df_select.groupby('id_left').size()
        # Берем индексы вопросов, которые больше порога
        glue_dev_leftids_to_use = list(
            inf_df_group_sizes[inf_df_group_sizes >= min_group_size].index)
        # Оставляем только вопросы, прошедшие порог и группируем
        groups = inp_df_select[inp_df_select.id_left.isin(
            glue_dev_leftids_to_use)].groupby('id_left')

        all_ids = set(inp_df['id_left']).union(set(inp_df['id_right']))

        out_pairs = []

        np.random.seed(seed)

        # Итерируемся по получившемуся датасету
        for id_left, group in groups:
            # ID ПРАВОГО запроса, который является дубликатом ЛЕВОГО
            ones_ids = group[group.label > 0].id_right.values
            # ID ПРАВОГО запроса, который НЕ является дубликатом ЛЕВОГО
            zeroes_ids = group[group.label == 0].id_right.values
            sum_len = len(ones_ids) + len(zeroes_ids)
            # Считаем сколько не достает до максимального числа примеров
            num_pad_items = max(0, fill_top_to - sum_len)
            if num_pad_items > 0:
                # Рандомно выбираем из общего множества ID, которые не встречаются в этом примере
                cur_chosen = set(ones_ids).union(
                    set(zeroes_ids)).union({id_left})
                pad_sample = np.random.choice(
                    list(all_ids - cur_chosen), num_pad_items, replace=False).tolist()
            else:
                pad_sample = []
            # Формируем итоговые список
            # 2 - дубликат; 1 - похож но не дубликат, 2 - вообще мимо
            for i in ones_ids:
                out_pairs.append([id_left, i, 2])
            for i in zeroes_ids:
                out_pairs.append([id_left, i, 1])
            for i in pad_sample:
                out_pairs.append([id_left, i, 0])
        return out_pairs

    def get_idx_to_text_mapping(self, inp_df: pd.DataFrame) -> Dict[str, str]:
        left_dict = (
            inp_df
            [['id_left', 'text_left']]
            .drop_duplicates()
            .set_index('id_left')
            ['text_left']
            .to_dict()
        )
        right_dict = (
            inp_df
            [['id_right', 'text_right']]
            .drop_duplicates()
            .set_index('id_right')
            ['text_right']
            .to_dict()
        )
        left_dict.update(right_dict)
        return left_dict

    def ndcg_k(self, ys_true: np.array, ys_pred: np.array, ndcg_top_k: int = 10) -> float:
        ideal_dcg = self.dcg(ys_true, ys_true, ndcg_top_k)
        pred_dcg = self.dcg(ys_true, ys_pred, ndcg_top_k)
        return (pred_dcg / ideal_dcg).item() if ideal_dcg != 0 else 0

    def dcg(self, ys_true: np.array, ys_pred: np.array, ndcg_top_k: int):
        argsort = np.argsort(ys_pred, axis=0)[::-1]
        ys_true_sorted = ys_true[argsort]
        ret = 0
        for i, l in enumerate(ys_true_sorted[:ndcg_top_k], 1):
            ret += (2 ** l - 1) / math.log2(1 + i)
        return ret


    def valid(self, model: torch.nn.Module, val_dataloader: torch.utils.data.DataLoader) -> float:
        with torch.no_grad():
            labels_and_groups = val_dataloader.dataset.index_pairs_or_triplets
            labels_and_groups = pd.DataFrame(labels_and_groups, columns=['left_id', 'right_id', 'rel'])

            all_preds = []
            for batch in (val_dataloader):
                inp_1, y = batch
                preds = model.predict(inp_1)
                preds_np = preds.detach().numpy()
                all_preds.append(preds_np)
            all_preds = np.concatenate(all_preds, axis=0)
            labels_and_groups['preds'] = all_preds

            ndcgs = []
            for cur_id in labels_and_groups.left_id.unique():
                cur_df = labels_and_groups[labels_and_groups.left_id == cur_id]
                ndcg = self.ndcg_k(cur_df.rel.values.reshape(-1), cur_df.preds.values.reshape(-1))
                if np.isnan(ndcg):
                    ndcgs.append(0)
                else:
                    ndcgs.append(ndcg)
            return np.mean(ndcgs)

    def train(self, n_epochs: int):
        """Train the model with SGD and BCE loss, printing the loss per batch and NDCG per epoch.

        The validation NDCG is printed once before training. Training triplets
        and their loader are rebuilt every ``change_train_loader_ep`` epochs,
        starting with the first epoch.
        """
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.train_lr)
        loss_fn = torch.nn.BCELoss()

        baseline_ndcg = self.valid(self.model, self.val_dataloader)
        print(f'Baseline NDCG {baseline_ndcg}')

        for epoch_idx in range(n_epochs):
            resample_now = epoch_idx % self.change_train_loader_ep == 0
            if resample_now:
                self.train_pairs_for_ndcg = self.sample_data_for_train_iter(self.glue_train_df)
                self.train_dataset = TrainTripletsDataset(
                    self.train_pairs_for_ndcg,
                    self.idx_to_text_mapping_train,
                    vocab=self.vocab,
                    oov_val=self.vocab['OOV'],
                    preproc_func=self.simple_preproc,
                )
                self.train_dataloader = torch.utils.data.DataLoader(
                    self.train_dataset,
                    batch_size=self.dataloader_bs,
                    num_workers=0,
                    collate_fn=collate_fn,
                    shuffle=True,
                )

            for batch in self.train_dataloader:
                left_inputs, right_inputs, targets = batch[0], batch[1], batch[2]
                pair_probs = self.model(left_inputs, right_inputs)
                batch_loss = loss_fn(pair_probs, targets)
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
                print(f'Loss {batch_loss.item()}')

            epoch_ndcg = self.valid(self.model, self.val_dataloader)
            print(f'Epoch {epoch_idx + 1}: mean NDCG {epoch_ndcg}')

In [None]:
%%time
# Build the whole pipeline: reads the QQP train/dev splits and the GloVe file,
# builds the vocabulary and the KNRM model, and prepares the validation loader.
s = Solution(glue_qqp_dir = '/content/QQP', glove_vectors_path = '/content/glove.6B.50d.txt', knrm_out_mlp=[], freeze_knrm_embeddings=True, train_lr=0.01)

self.get_glue_df('train') 3
self.get_glue_df('dev') 1
self.create_val_pairs 107
self.get_all_tokens 0
glove_emb 400000
inner_keys 400000
self.build_knrm_model 83
self.get_idx_to_text_mapping 1
self.idx_to_text_mapping_dev 0
CPU times: user 3min 7s, sys: 2.54 s, total: 3min 10s
Wall time: 3min 14s


In [None]:
%%time
# Train for 9 epochs; the loss is printed per batch and the mean validation NDCG per epoch.
s.train(9)

Baseline NDCG 0.411420366205633
Loss 1.041214942932129
Loss 1.1276235580444336
Loss 0.7976537942886353
Loss 0.784345805644989
Loss 0.7953973412513733
Loss 0.8058797717094421
Loss 0.7783726453781128
Loss 0.756111741065979
Loss 0.7875634431838989
Loss 0.77696692943573
Loss 0.7584382891654968
Loss 0.7275822758674622
Loss 0.7607614398002625
Loss 0.7257606387138367
Loss 0.7532147169113159
Loss 0.7221930623054504
Loss 0.7237700819969177
Loss 0.7482706904411316
Loss 0.7351915836334229
Loss 0.7341530323028564
Loss 0.7251180410385132
Loss 0.724047064781189
Loss 0.7236614227294922
Epoch 1: mean NDCG 0.5260306691098524
Loss 0.7088751196861267
Loss 0.7389498949050903
Loss 0.7080919146537781
Loss 0.7210667729377747
Loss 0.731951892375946
Loss 0.7185767889022827
Loss 0.7334063053131104
Loss 0.7173170447349548
Loss 0.7064782381057739
Loss 0.7313122153282166
Loss 0.7249554395675659
Loss 0.7044662237167358
Loss 0.7269113659858704
Loss 0.7142452001571655
Loss 0.8230769038200378
Loss 0.7308399081230164
L

In [None]:
# Persist the embedding layer's weights (state_dict only, not the module object).
torch.save(s.model.embeddings.state_dict(), 'embeddings.bin')

In [None]:
import json

# Persist the token -> embedding-row mapping built by Solution.
with open('vocab.json', "w") as file:
    json.dump(s.vocab, file)

In [None]:
 # Persist the weights of the KNRM output MLP (state_dict only).
 torch.save(s.model.mlp.state_dict(), 'knrm_mlp.bin')