In [1]:
!wget r2d2.fit.vutbr.cz/checkpoints/nq-open/reranker_roberta-base_2021-02-25-17-27_athena19_HIT@25_0.8299645997487725.ckpt.zip


--2021-04-18 10:18:36--  http://r2d2.fit.vutbr.cz/checkpoints/nq-open/reranker_roberta-base_2021-02-25-17-27_athena19_HIT@25_0.8299645997487725.ckpt.zip
Resolving r2d2.fit.vutbr.cz (r2d2.fit.vutbr.cz)... 147.229.13.167, 2001:67c:1220:80c::93e5:da7
Connecting to r2d2.fit.vutbr.cz (r2d2.fit.vutbr.cz)|147.229.13.167|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 462347964 (441M) [application/zip]
Saving to: ‘reranker_roberta-base_2021-02-25-17-27_athena19_HIT@25_0.8299645997487725.ckpt.zip’


2021-04-18 10:19:21 (9.88 MB/s) - ‘reranker_roberta-base_2021-02-25-17-27_athena19_HIT@25_0.8299645997487725.ckpt.zip’ saved [462347964/462347964]



In [2]:
!unzip reranker_roberta-base_2021-02-25-17-27_athena19_HIT@25_0.8299645997487725.ckpt.zip

Archive:  reranker_roberta-base_2021-02-25-17-27_athena19_HIT@25_0.8299645997487725.ckpt.zip
  inflating: reranker_roberta-base_2021-02-25-17-27_athena19_HIT@25_0.8299645997487725.ckpt  


In [None]:
!pip install -q transformers==4.3.3 torchtext==0.4.0

In [1]:
import torchtext
import torch


class RerankerDataset(torchtext.data.Dataset):

    def __init__(self, data, query_builder, passages_per_query=1, numerized=False, **kwargs):
        self.query_builder = query_builder
        self.passages_per_query = passages_per_query
        self.numerized = numerized

        fields = self.prepare_fields(1.)
        examples = self.get_example_list(data, fields)

        super().__init__(examples, fields, **kwargs)

    def get_example_list(self, data, fields):
        question = data["question"]
        passages = data["passages"]

        if not self.numerized:
            question = self.query_builder.tokenize_and_convert_to_ids(question)
            passages = [(self.query_builder.tokenize_and_convert_to_ids(item[0]), self.query_builder.tokenize_and_convert_to_ids(item[1])) for item in passages]

        max_length = self.query_builder.max_seq_length if self.query_builder.max_seq_length else self.query_builder.tokenizer.model_max_length
        max_length-= self.query_builder.num_special_tokens_to_add
        max_length-= len(question)

        query_length = 0
        query_passages = []
        examples = []
        for (t, p) in passages:

            if query_length + len(t) + len(p) + 2 > max_length or len(query_passages) >= self.passages_per_query:
                features = self.query_builder(question, query_passages, self.numerized)
                examples.append(
                    torchtext.data.Example.fromlist(
                        [
                            features["input_ids"],
                            features["attention_mask"],
                        ], 
                        fields
                    )
                )
                query_passages = []
                query_length = 0

            query_passages.append((t, p))
            query_length += len(t) + len(p) + 2

        features = self.query_builder(question, query_passages, self.numerized)
        examples.append(
            torchtext.data.Example.fromlist(
                [
                    features["input_ids"],
                    features["attention_mask"],
                ], 
                fields
            )
        )

        return examples

    @staticmethod
    def prepare_fields(pad_t):
        return [
            ("input_ids", torchtext.data.Field(use_vocab=False, batch_first=True, sequential=True, pad_token=pad_t)),
            ("attention_mask", torchtext.data.Field(use_vocab=False, batch_first=True, sequential=True, pad_token=0.)),
            #("score_mask", torchtext.data.Field(use_vocab=False, batch_first=True, sequential=True, pad_token=float("-Inf")))
        ]

    @classmethod
    def download(cls, root, check=None):
        raise NotImplementedError


    def filter_examples(self, field_names):
        raise NotImplementedError

In [2]:
import random
import time
import traceback
import math
import os
import sys
import logging
import transformers
import torch
import torchtext
import tqdm

from collections import Counter
from torch.utils.data import DataLoader, RandomSampler


LOGGER = logging.getLogger(__name__)
SEED = 1601640139674    # seed for deterministic shuffle of passages on longformer input


class RerankerFramework(object):
    """ Passage reranker trainner """
    def __init__(self, device, config, train_dataloader=None, val_dataloader=None):
        self.LOGGER = logging.getLogger(self.__class__.__name__)

        self.device = device
        self.config = config
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader

    def train(self,
              model,
              save_ckpt=None,
              num_epoch=5,
              learning_rate=1e-5,
              batch_size=1,
              iter_size=16,
              warmup_proportion=0.1,
              weight_decay_rate=0.01,
              no_decay=['bias', 'gamma', 'beta', 'LayerNorm.weight'],
              fp16=False,
              criterion=None
            ):
        # Add trainig configuration       
        self.config["training"] = {}
        self.config["training"]["num_epoch"] = num_epoch
        self.config["training"]["lr"] = learning_rate
        self.config["training"]["train_batch_size"] = batch_size
        self.config["training"]["iter_size"] = iter_size
        self.config["training"]["warmup_proportion"] = warmup_proportion
        self.config["training"]["weight_decay_rate"] = weight_decay_rate
        self.config["training"]["no_decay"] = no_decay
        self.config["training"]["fp16"] = fp16
        self.config["training"]["criterion"] = criterion

        self.LOGGER.info("Start training...")

        param_optimizer = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay_rate': weight_decay_rate},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.0}
        ]
        optimizer = transformers.AdamW(
            optimizer_grouped_parameters, 
            lr=learning_rate,
            correct_bias=False
        )

        if not criterion:
            criterion = torch.nn.CrossEntropyLoss()

        num_training_steps = int(len(self.train_dataloader.dataset) / (iter_size) * num_epoch)
        num_warmup_steps = int(num_training_steps * warmup_proportion)
        scheduler = transformers.get_linear_schedule_with_warmup(
            optimizer, 
            num_warmup_steps=num_warmup_steps, 
            num_training_steps=num_training_steps
        )

        start_time = time.time()

        self.iter = 0

        try:
            self.best_val_accuracy = -math.inf
            for epoch in range(1, num_epoch+1):
                LOGGER.info(f"Epoch {epoch} started.")

                self.train_epoch(model, optimizer, scheduler, criterion, epoch, iter_size, batch_size, fp16, save_ckpt)

            metrics = self.validate(model, self.val_dataloader, criterion)

            for key, value in metrics.items():
                LOGGER.info("Validation after '%i' iterations.", self.iter)
                LOGGER.info(f"{key}: {value:.4f}")

            if metrics["HIT@25"] > self.best_val_accuracy:
                LOGGER.info("Best checkpoint.")
                self.best_val_accuracy = metrics["HIT@25"]

            if save_ckpt:
                self.save_model(model, self.config, optimizer, scheduler,
                                save_ckpt+f"_HIT@25_{metrics['HIT@25']}.ckpt")

        except KeyboardInterrupt:
            LOGGER.info('Exit from training early.')
        except:
            LOGGER.exception("An exception was thrown during training: ")
        finally:
            LOGGER.info('Finished after {:0.2f} minutes.'.format((time.time() - start_time) / 60))

    def train_epoch(self, model, optimizer, scheduler, criterion, epoch, 
                    iter_size, batch_size, fp16, save_ckpt):
        model.train()

        train_loss = 0
        train_right = 0

        total_preds = []
        total_labels = []

        postfix = {"loss": 0., "accuracy": 0., "skip": 0}
        iter_ = tqdm.tqdm(enumerate(self.train_dataloader, 1), desc="[TRAIN]", total=len(self.train_dataloader.dataset) // self.train_dataloader.batch_size, postfix=postfix)

        optimizer.zero_grad()

        for it, batch in iter_:
            update = False
            try:
                data = {key: values.to(self.device) for key, values in batch.items()}
                logits = model(data)

                loss = criterion(logits, data["labels"]) / iter_size

                pred = torch.argmax(logits, dim=1)
                right = torch.mean((data["labels"].view(-1) == pred.view(-1)).float(), 0)

                train_loss += loss.item()
                train_right += right.item()

                postfix.update({"loss": "{:.6f}".format(train_loss / it), "accuracy": train_right / it})
                iter_.set_postfix(postfix)

                total_preds += list(pred.cpu().numpy())
                total_labels += list(data["labels"].cpu().numpy())

                loss.backward()

                if it % iter_size == 0:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    update = True
                    self.iter += 1

                    if self.iter % 5000 == 0:
                        metrics = self.validate(model, self.val_dataloader, criterion)

                        for key, value in metrics.items():
                            LOGGER.info(f"Validation {key} after {self.iter} iteration: {value:.4f}")

                        if metrics["HIT@25"] > self.best_val_accuracy:
                            LOGGER.info("Best checkpoint.")
                            self.best_val_accuracy = metrics["HIT@25"]

                        if save_ckpt:
                            self.save_checkpoint(model, config, optmizer, scheduler,
                                                 save_ckpt+f"_HIT@25_{metrics['HIT@25']}.ckpt")

            except RuntimeError as e:
                if "CUDA out of memory." in str(e):
                    logging.debug(f"Allocated memory befor: {torch.cuda.memory_allocated(0)}")
                    torch.cuda.empty_cache()
                    logging.debug(f"Allocated memory after: {torch.cuda.memory_allocated(0)}")
                    logging.error(e)
                    tb = traceback.format_exc()
                    logging.error(tb)
                    postfix["skip"] += 1
                    iter_.set_postfix(postfix)
                else:
                    raise e
  
        if not update:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        print("")

        LOGGER.debug("Statistics in train time.")
        LOGGER.debug("Histogram of predicted passage: %s", str(Counter(total_preds)))
        LOGGER.debug("Histogram of labels: %s", str(Counter(total_labels)))

        LOGGER.info('Epoch is ended, samples: {0:5} | loss: {1:2.6f}, accuracy: {2:3.2f}%'.format(len(self.train_dataloader), train_loss / len(self.train_dataloader), 100 * train_right / len(self.train_dataloader)))
        return {
            "accuracy": train_right / len(self.train_dataloader)
        }

    def _update_parameters(self, optimizer, scheduler, dataloader, it, iter_size, train_loss, train_right):
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        sys.stdout.write('[TRAIN] step: {0:5}/{1:5} | loss: {2:2.6f}, accuracy: {3:3.2f}%'.format(it//iter_size, len(dataloader.dataset)//iter_size, train_loss / it, 100 * train_right / it) +'\r')
        sys.stdout.flush()

    @torch.no_grad()
    def validate(self, model, dataloader, criterion):
        model.eval()

        hits_k = [1, 2, 5, 10, 25, 50, dataloader.dataset.passages_in_batch]
        hits_sum = [0 for _ in hits_k]

        iter_ = tqdm.tqdm(enumerate(dataloader, 1), desc="[EVAL]", total=len(dataloader))

        for it, data in iter_:
            batch = {key: data[key].to(self.device) for key in ["input_ids", "attention_mask"]}

            batch_scores = model(batch)
            batch_scores = batch_scores.view(-1)
            batch_scores = batch_scores[batch_scores != float("-Inf")]

            top_k = batch_scores.shape[0]

            top_k_indices = torch.topk(batch_scores, top_k)[1].tolist()

            hit_rank = -1
            for hit_idx, seq_idx in enumerate(top_k_indices):
                if seq_idx in data["hits"]:
                    hit_rank = hit_idx
                    break

            for i, k in enumerate(hits_k):
                hits_sum[i]+= 1 if -1 < hit_rank < k else 0

        return {
            f"HIT@{key}": value/it for key, value in zip(hits_k, hits_sum)
        }

    @classmethod
    @torch.no_grad()
    def infer_longformer(cls, model, query_builder, question, support, return_top=20, 
                         max_passage_batch=20, top_k_from_retriever=5, numerized=False, 
                         batch_size=1, device=None):

        model.eval()
        if not numerized:
            question = query_builder.tokenize_and_convert_to_ids(question)
            support = [(query_builder.tokenize_and_convert_to_ids(title), query_builder.tokenize_and_convert_to_ids(context)) for title, context in support]

        indeces = list(range(top_k_from_retriever, len(support)))

        random.seed(SEED)
        random.shuffle(indeces)

        support = [support[idx] for idx in indeces]

        scores = []

        dataset = RerankerDataset({"question": question, "passages": support},
                query_builder, passages_per_query=max_passage_batch, 
                numerized = True)
        data_loader = torchtext.data.BucketIterator(
                dataset, batch_size=batch_size, shuffle=False, sort=False, 
                repeat=False)

        for batch in data_loader:
            batch = {key: getattr(batch, key).to(device) for key in ["input_ids", "attention_mask"]}

            batch_scores = model(batch)
            batch_scores = batch_scores.view(-1)
            batch_scores = batch_scores[batch_scores != float("-Inf")]

            scores.append(batch_scores)

        scores = torch.cat(scores).unsqueeze(0)


        top_k_indeces = torch.topk(scores, return_top-top_k_from_retriever)[1][0]
        indeces = [indeces[idx] for idx in top_k_indeces]
        scores = [scores[0][idx].item() for idx in top_k_indeces]

        if top_k_from_retriever > 0:
            indeces = list(range(top_k_from_retriever)) + indeces
            scores = top_k_from_retriever*[max(scores)] + scores

        return {
            "indeces": indeces,
            "scores": scores
        }

    @classmethod
    def save_model(cls, model, config, path):
        LOGGER.info(f"Save checkpoint '{path}'.")
        dict_to_save = {}
        dict_to_save["model"]  = model.state_dict()
        dict_to_save["config"] = config

        torch.save(dict_to_save, path)

    @classmethod
    def load_model(cls, path, device):
        if os.path.isfile(path):
            model = torch.load(path, map_location=device)
            LOGGER.info(f"Successfully loaded checkpoint '{path}'")
            return model["model"], model["config"]
        else:
            raise Exception(f"No checkpoint found at '{path}'")

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [4]:
reranker_model = "reranker_roberta-base_2021-02-25-17-27_athena19_HIT@25_0.8299645997487725.ckpt"
model_state_dict, model_config = RerankerFramework.load_model(reranker_model, device)

In [5]:
reranker_type = model_config["reranker_model_type"]
config = model_config["encoder_config"]
encoder = model_config["encoder"]

In [8]:
from transformers import RobertaConfig
config_dict = config.to_dict()
config_dict['use_cache'] = False
config = RobertaConfig.from_dict(config_dict)

In [9]:
config

RobertaConfig {
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.3.3",
  "type_vocab_size": 1,
  "use_cache": false,
  "vocab_size": 50265
}

In [10]:
from transformers import AutoTokenizer, AutoConfig, AutoModel
tokenizer = AutoTokenizer.from_pretrained(encoder)
encoder = AutoModel.from_config(config)

In [11]:
class BaselineRerankerQueryBuilder(object):

    def __init__(self, tokenizer, max_seq_length):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.start_context_token_id = self.tokenizer.convert_tokens_to_ids("madeupword0000")
        self.start_title_token_id = self.tokenizer.convert_tokens_to_ids("madeupword0001")

    def tokenize_and_convert_to_ids(self, text):
        tokens = self.tokenizer.tokenize(text)
        return self.tokenizer.convert_tokens_to_ids(tokens)

    @property
    def num_special_tokens_to_add(self):
        return self.tokenizer.num_special_tokens_to_add(pair=True)

    def __call__(self, question, passages, numerized=False):
        if not numerized:
            question = self.tokenize_and_convert_to_ids(question)
            passages = [(self.tokenize_and_convert_to_ids(item[0]), self.tokenize_and_convert_to_ids(item[1])) for item in passages]

        cls = self.tokenizer.convert_tokens_to_ids([self.tokenizer.bos_token])
        sep = self.tokenizer.convert_tokens_to_ids([self.tokenizer.sep_token])
        eos = self.tokenizer.convert_tokens_to_ids([self.tokenizer.eos_token])

        input_ids_list = []

        for passage in passages:
            input_ids = cls + question + sep + sep
            input_ids.extend([self.start_title_token_id] + passage[0])
            input_ids.extend([self.start_context_token_id] + passage[1] + eos)

            if len(input_ids) > self.max_seq_length:
                input_ids = input_ids[:self.max_seq_length-1] + eos

            input_ids_list.append(input_ids)
    
        seq_len = max(map(len, input_ids_list))

        input_ids_tensor = torch.ones(len(input_ids_list), seq_len).long()
        attention_mask_tensor = torch.zeros(len(input_ids_list), seq_len).long()

        for batch_index, input_ids in enumerate(input_ids_list):

            for sequence_index, value in enumerate(input_ids):
                input_ids_tensor[batch_index][sequence_index] = value
                attention_mask_tensor[batch_index][sequence_index] = 1.

        features = {
            "input_ids": input_ids_tensor,
            "attention_mask": attention_mask_tensor
        }

        return features

In [12]:
query_builder = BaselineRerankerQueryBuilder(tokenizer, model_config["max_length"])


In [13]:
class BaselineReranker(torch.nn.Module):
    """ Baseline passage reranker used in the paper. """

    def __init__(self, config, encoder):
        super().__init__()

        self.config = config
        self.encoder = encoder
        self.vt = torch.nn.Linear(config.hidden_size, 1, bias=False)

        self.init_weights(type(self.encoder))

    def init_weights(self, clz):
        """ Applies model's weight initialization to all non-pretrained parameters of this model"""
        for ch in self.children():
            if issubclass(ch.__class__, torch.nn.Module) and not issubclass(ch.__class__, PreTrainedModel):
                ch.apply(lambda module: clz._init_weights(self.encoder, module))

    def forward(self, batch):
        """
        The input looks like:
        [CLS] Q [SEP] <t> title <c> context [EOS]
        """

        inputs = {
            "input_ids": batch["input_ids"],
            "attention_mask": batch["attention_mask"]
        }

        outputs = self.encoder(**inputs)[1]

        scores = self.vt(outputs)
        scores = scores.view(1,-1)

        return scores

In [14]:
from transformers import PreTrainedModel
model = BaselineReranker(
                config,
                encoder)

In [15]:
model.load_state_dict(model_state_dict, strict=False)
model = model.to(device)
model.eval()

BaselineReranker(
  (encoder): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm(

In [16]:
question = "How is the PM of India?"
passages = ["I am PM of india", "Narendra Modi PM of  india", "I love cricket"]
k_top = 3
batch_size = 2

In [17]:
scores = []
for i in range(0, k_top, batch_size):
    passages_sublist = passages[i:i+batch_size]
    batch = query_builder(question, passages_sublist, False)
    batch = {key: batch[key].to(device) for key in ["input_ids", "attention_mask"]}
    batch_scores = model(batch)
    batch_scores = batch_scores.view(-1)
    scores.append(batch_scores)

In [19]:
scores = torch.cat(scores).unsqueeze(0)
top_k_indeces = torch.topk(scores, k_top)[1][0]

In [20]:
top_k_indeces

tensor([0, 2, 1])