# MSMARCO Model Performance on the ELI5 Dataset

# End-to-End Agent Interaction

## Contributions
1. A system that takes an input query, identifies its intent, and either generates a dialogue response or, where its existing knowledge allows, answers the question.
2. Our finetuned models, in various sizes and trained on various datasets: question-answering models that answer reasonably well from whatever context they are given, and intent-identification models for both the classification and masked language modelling settings.
3. An open-source codebase for easily finetuning any of the existing models (ours or Hugging Face's) on your own custom dataset, which is useful if you have domain-specific data.

## End-to-End Pipeline
1. Classify whether the user query calls for a generic dialogue response or is a question (did they say "hello" or "how do I get my TCard?")
2. If generic dialogue, use the generic dialogue model to generate a response
3. If a question, identify the intent of the question
4. Use a SentenceTransformer to identify which context is most similar to the question (therefore most likely to contain the answer)
5. Input the selected context and the user query into the question-answering model
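The routing in the steps above can be sketched as a single function that takes each stage as a callable. This is a minimal sketch, not the notebook's code: `answer_query` and its parameters are hypothetical names, step 3 (intent identification) is omitted as it is in this notebook, and in practice the callables would wrap the classifier, DialoGPT, retriever and QA model defined further down.

```python
from typing import Callable, List, Tuple


def answer_query(
    query: str,
    classify: Callable[[str], Tuple[str, float]],
    chat: Callable[[str], str],
    retrieve: Callable[[str, int], List[str]],
    answer: Callable[[str, str], str],
    context_size: int = 3,
) -> str:
    """Hypothetical router for the pipeline steps (intent identification omitted)."""
    label, _confidence = classify(query)      # step 1: dialogue vs. question
    if label == "general":
        return chat(query)                    # step 2: generic dialogue response
    passages = retrieve(query, context_size)  # step 4: most similar contexts
    context = " ".join(passages)
    return answer(context, query)             # step 5: question answering
```

Passing the stages in as callables keeps the routing logic testable with simple stubs, independent of any model weights.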

## In This Notebook
- **Binary Classification**: Performed by our intent identification model, finetuned for binary classification
- Dialogue: Performed by Microsoft's pretrained DialoGPT (via Hugging Face)
- *Intent Identification*: (not present in this notebook)
- Context Retrieval: Performed by a pretrained Sentence Transformer that identifies which context is the most semantically similar to the input query.
- **Answer Generation**: Our finetuned GPT2 for QA

## Future Steps
- We can further improve the pipeline by finetuning the Dialogue and Context Retrieval steps.
    - DialoGPT was trained on Reddit data, so it is prone to informal dialogue and Star Wars memes.
    - The SentenceTransformer can be finetuned on domain-specific data to improve how accurately it selects the best context. It is currently the biggest bottleneck of the end-to-end agent: when it retrieves a bad context, answer quality suffers regardless of the QA model.
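One common objective for finetuning a bi-encoder such as the SentenceTransformer, assuming you have (question, relevant passage) pairs from your domain, is a contrastive loss with in-batch negatives: each question should score its own passage higher than every other passage in the batch. sentence-transformers provides this as `MultipleNegativesRankingLoss`. The plain-Python sketch below only illustrates the computation; `cosine` and `in_batch_negatives_loss` are hypothetical helpers, and the similarity scale of 20 is an assumed value for illustration.

```python
import math
from typing import List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def in_batch_negatives_loss(
    queries: List[Sequence[float]],
    positives: List[Sequence[float]],
    scale: float = 20.0,
) -> float:
    """Mean cross-entropy in which positives[i] is the correct passage for queries[i]
    and the other passages in the batch serve as negatives."""
    total = 0.0
    for i, q in enumerate(queries):
        logits = [scale * cosine(q, p) for p in positives]
        m = max(logits)  # subtract the max before exponentiating, for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_sum_exp - logits[i]  # equals -log softmax(logits)[i]
    return total / len(queries)
```

The loss is near zero when every question embedding points at its own passage and grows when it points at another question's passage, which is what pushes the retriever toward better contexts.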

In [1]:
from argparse import Namespace
import json
import pickle
import sys
import time
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    BertForSequenceClassification, BertTokenizer,
    GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
)
from sentence_transformers import SentenceTransformer, CrossEncoder
from sentence_transformers import util as sentenceutils

sys.path.insert(0, '../answer_generation')
# Project helpers. PATH, add_special_tokens and sample_sequence are not defined in this
# notebook and are assumed to come from this module.
from utils import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LABEL_COLUMNS = ['general', 'request']

In [2]:
# Load tokenizers.
print('Loading tokenizers')
berttokenizer = BertTokenizer.from_pretrained('bert-base-cased')
#gpt2tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

Loading tokenizers


In [3]:
class BinaryClassifier:
    """
    Classify if it's general dialogue or a request.
    """
    def __init__(self):
        chk_path = PATH+"/checkpoints/binary_classification/binary_classification.pth"
        checkpoint = torch.load(chk_path, map_location=torch.device('cpu'))

        self.model = BertForSequenceClassification.from_pretrained(
            "bert-base-cased", # Use the 12-layer BERT model, with a cased vocab.
            num_labels = 2
        )
        self.model.load_state_dict(checkpoint)
        self.model.eval()
        self.model.to(device)
        self.softmax_layer = torch.nn.Softmax(dim=-1)  # normalize over the two class logits
        self.LABEL_COLUMNS = ['general', 'request']
    
    def tokenize_sentences(self, sentences):
        # Tokenize all of the sentences and map the tokens to their word IDs.
        input_ids = []
        attention_masks = []

        # For every sentence...
        for sent in sentences:
            # encode_plus will:
            #   (1) Tokenize the sentence.
            #   (2) Prepend the [CLS] token to the start.
            #   (3) Append the [SEP] token to the end.
            #   (4) Map tokens to their IDs.
            #   (5) Pad or truncate the sentence to max_length
            #   (6) Create attention masks for [PAD] tokens.
            encoded_dict = berttokenizer.encode_plus(
                                sent,                      # Sentence to encode.
                                add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                                max_length = 320,          # Pad & truncate all sentences.
                                padding = 'max_length',
                                truncation = True,         # Truncate explicitly instead of relying on the default.
                                return_attention_mask = True,   # Construct attn. masks.
                                return_tensors = 'pt',     # Return pytorch tensors.
                        )

            # Add the encoded sentence to the list.
            input_ids.append(encoded_dict['input_ids'])

            # And its attention mask (simply differentiates padding from non-padding).
            attention_masks.append(encoded_dict['attention_mask'])

        # Convert the lists into tensors.
        input_ids = torch.cat(input_ids, dim=0)
        attention_masks = torch.cat(attention_masks, dim=0)
        return input_ids,attention_masks
    
    def classify(self, query, printout=True):
        # Tokenize the input sentence to be compatible with BERT inputs.
        token_ids, attention_masks = self.tokenize_sentences([query])

        # Get a tensor of probabilities for the sentence being general dialogue or a request.
        with torch.no_grad():
            model_outputs = self.model(token_ids.to(device), token_type_ids=None, attention_mask=attention_masks.to(device))
        result = self.softmax_layer(model_outputs[0])

        # Identify which output node has the higher probability and what that probability is.
        prediction = torch.argmax(result).item()
        confidence = torch.max(result).item()
        if printout:
            print("The class is: " + self.LABEL_COLUMNS[prediction] + " with {:.2f}% confident".format(confidence*100))

        return self.LABEL_COLUMNS[prediction], confidence
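The decision in `classify` is a softmax over the two class logits followed by an argmax, with the winning probability reported as the confidence. A framework-free sketch of that step, using hypothetical names `softmax` and `predict_label`:

```python
import math
from typing import List, Sequence, Tuple

LABELS = ["general", "request"]  # same order as LABEL_COLUMNS in the notebook


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # shift by the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits: Sequence[float], labels: Sequence[str] = LABELS) -> Tuple[str, float]:
    probs = softmax(logits)
    best = max(range(len(probs)), key=lambda i: probs[i])  # argmax
    return labels[best], probs[best]                       # label and its probability
```

This also explains the many "100.00%" lines in the outputs further down: logits of `[12.0, 0.0]` already give a probability of about 0.999994, which the `{:.2f}%` format rounds to 100.00%.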

In [4]:
class DialogueGeneration:
    """
    Generic Dialogue Generation
    
    NOTE: DialoGPT is prone to making Star Wars references.
    Chat history is not implemented, so each call is a single-turn exchange.
    """
    def __init__(self, modelname="microsoft/DialoGPT-medium"):
        self.tokenizer = AutoTokenizer.from_pretrained(modelname)
        self.model = AutoModelForCausalLM.from_pretrained(modelname)
        
    def generate(self, query, printout=True):
        # Format the input: DialoGPT expects each turn to end with the EOS token.
        # No chat history is kept, so the prompt is just the current user turn.
        bot_input_ids = self.tokenizer.encode(query + self.tokenizer.eos_token, return_tensors='pt')

        # Generate a response without tracking gradients.
        with torch.no_grad():
            output_ids = self.model.generate(bot_input_ids, max_length=1000, pad_token_id=self.tokenizer.eos_token_id)

        # Decode only the newly generated tokens (everything after the prompt).
        response = self.tokenizer.decode(output_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

        if printout:
            print("DialoGPT Response:", response)
        return response
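Since chat history is not implemented, here is a minimal sketch of how it could be kept, assuming the DialoGPT convention that every turn (user or bot) is followed by the EOS token. `ChatHistory` is a hypothetical helper that works on plain token-ID lists, and the 512-token budget is an arbitrary assumption.

```python
from typing import List


class ChatHistory:
    """Hypothetical helper that keeps a DialoGPT-style history in which every turn ends with EOS."""

    def __init__(self, eos_token_id: int, max_tokens: int = 512):
        self.eos_token_id = eos_token_id
        self.max_tokens = max_tokens
        self.turns: List[List[int]] = []

    def add_turn(self, turn_ids: List[int]) -> None:
        # Store one turn (user or bot) followed by its EOS separator.
        self.turns.append(list(turn_ids) + [self.eos_token_id])
        # Drop the oldest whole turns until the history fits the token budget.
        while len(self.turns) > 1 and sum(len(t) for t in self.turns) > self.max_tokens:
            self.turns.pop(0)

    def prompt(self) -> List[int]:
        # Flatten the remaining turns into one token sequence for the model.
        return [tok for turn in self.turns for tok in turn]
```

With the real model, one would add `self.tokenizer.encode(query)` as a turn, pass `torch.tensor([history.prompt()])` to `self.model.generate`, and then add the generated reply tokens (minus their trailing EOS) as the next turn. Dropping whole turns rather than raw tokens keeps the prompt from starting mid-sentence.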

## Semantic Search & Re-Ranker

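The semantic search step performs the initial passage retrieval with a bi-encoder: the query is embedded once and scored by dot product against precomputed passage embeddings, which is fast enough to search the whole corpus. The retrieved candidates are then re-ranked with a cross-encoder, which reads each (query, passage) pair jointly; this is more accurate but too slow to run over every passage. Both encoders are pretrained, and both stages are implemented in the same function below.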
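The two-stage retrieve-then-rerank logic can be sketched without any model: stage 1 ranks all passages by dot product with the query embedding and keeps `top_k`, stage 2 rescores only those candidates and keeps `context_size`. Here `retrieve_and_rerank` is a hypothetical function, and `rerank_score` is a stand-in for the cross-encoder that maps a corpus index to a score.

```python
from typing import Callable, List, Sequence


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def retrieve_and_rerank(
    query_emb: Sequence[float],
    corpus_embs: List[Sequence[float]],
    rerank_score: Callable[[int], float],
    top_k: int = 50,
    context_size: int = 3,
) -> List[int]:
    """Return corpus indices of the best context_size passages."""
    # Stage 1 (bi-encoder): dot-product score against precomputed passage embeddings, keep top_k.
    ranked = sorted(range(len(corpus_embs)), key=lambda i: dot(query_emb, corpus_embs[i]), reverse=True)
    candidates = ranked[:top_k]
    # Stage 2 (cross-encoder stand-in): rescore only the candidates, keep the best context_size.
    reranked = sorted(candidates, key=rerank_score, reverse=True)
    return reranked[:context_size]
```

A passage the re-ranker would score highly can still be missed if the bi-encoder does not place it in the first `top_k`, which is why retrieval quality bounds the whole pipeline.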

In [5]:
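from io import StringIO

with open('/ssd003/projects/aieng/conversational_ai/demo/data/Eli5/Eli5_reranked/eli5_train_reranked.json', 'r') as f:
    eli5 = json.load(f)
# json.load returns a JSON-encoded string here; wrap it in StringIO because passing
# a literal JSON string to pd.read_json is deprecated in recent pandas versions.
eli5 = pd.read_json(StringIO(eli5), orient='records')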

In [6]:
# Keep the first passage listed for each question.
passages = [entry[0]['text'] for entry in eli5['passages']]

In [7]:
# load encoders 
bi_encoder = SentenceTransformer('msmarco-bert-base-dot-v5')
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2',default_activation_function=nn.Sigmoid())

In [8]:
# load corpus embeddings 
with open('/ssd003/projects/aieng/conversational_ai/demo/data/Eli5/biencoder_embeddings/msmarco-bert-base-dot-v5.pickle', 'rb') as pkl:
    corpus_embeddings = pickle.load(pkl)

In [9]:
def search_and_rank(query, context_size=3, printout=True, top_k=50):
    # ------ PASSAGE RETRIEVAL (bi-encoder) ------
    start_time = time.time()
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = sentenceutils.semantic_search(question_embedding, corpus_embeddings, top_k=top_k, score_function=sentenceutils.dot_score)
    hits = hits[0]  # Get the hits for the first (and only) query
    end_time = time.time()

    if printout:
        print("Input question:", query)
        print("\n-------------------------\n")
        print("Top {} passages (after {:.3f} seconds):".format(len(hits), end_time - start_time))

    for hit in hits:
        hit['passage'] = passages[hit['corpus_id']]
        if printout: print("\t{:.3f}\t{}".format(hit['score'], hit['passage']))

    # ------ RE-RANKER (cross-encoder) ------
    # Score every (query, passage) pair jointly.
    cross_inp = [[query, hit['passage']] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score

    # Sort by cross-encoder score and keep the best context_size passages.
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
    if printout:
        print("\n-------------------------\n")
        print("Top-{} Cross-Encoder Re-ranker hits".format(context_size))

    for hit in hits[:context_size]:
        if printout: print("\t{:.3f}\t{}".format(hit['cross-score'], hit['passage'].replace("\n", " ")))
        hit['context'] = hit['passage']
    return hits[:context_size]

## Answer Generation using GPT-2

In [10]:
class AnswerGeneration:
    """
    Answer Generation
    """
    def __init__(self):
        self.args = Namespace(
            # fill in where you have stored the checkpoint information and file
            model_checkpoint_dir=PATH+ "/ssd003/projects/aieng/conversational_ai/demo/checkpoints/marco_gpt2medium/",
            model_checkpoint_file=PATH+ "/ssd003/projects/aieng/conversational_ai/demo/checkpoints/marco_gpt2medium/checkpoint_epoch5_step587435.pth",
            no_sample=True,
            max_length=100,
            min_length=1,
            seed=39,
            temperature=0.7,
            top_k=100,
            top_p=0.,  # I recommend setting this to 0 so it's less likely to say "I don't know"
            device=("cuda" if torch.cuda.is_available() else "cpu"),
            force_answer=False  # discard any "I don't know"s and take the next best prediction
        )
        
        # Initializing GPT2 Tokenizer
        self.tokenizer = GPT2Tokenizer.from_pretrained(self.args.model_checkpoint_dir)

        # Initializing pretrained model
        config = GPT2Config.from_json_file(self.args.model_checkpoint_dir + 'config.json')
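        # Load onto the CPU first so GPU-saved checkpoints also open on CPU-only machines.
        state_dict = torch.load(self.args.model_checkpoint_file, map_location='cpu')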
        if 'model' in state_dict:
            state_dict = state_dict['model']
        self.model = GPT2LMHeadModel.from_pretrained(self.args.model_checkpoint_file, config=config,
                                                     state_dict=state_dict)
        self.model.to(device)
        self.model.eval()

        # add our special tokens to the model
        add_special_tokens(self.model, self.tokenizer)
    
    def generate_answer(self, context, question, printout=True):
        if not isinstance(context, list):
            context = [context]
        query = [self.tokenizer.encode('<speaker1>' + question)]
        with torch.no_grad():
            out_ids = sample_sequence(context, query, self.tokenizer, self.model, self.args)
        response = self.tokenizer.decode(out_ids, skip_special_tokens=True)
        if printout:
            print("Answer:", response)
        return response
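The `Namespace` arguments (`no_sample`, `temperature`, `top_k`, `top_p`) are consumed by `sample_sequence` from `utils`, which this notebook does not show. They resemble the arguments of Hugging Face's transfer-learning-conv-ai interaction script, where, as we understand it, logits are divided by `temperature`, passed through top-k and then top-p (nucleus) filtering, and the next token is either sampled or, when `no_sample` is set, chosen greedily. Assuming `utils.sample_sequence` behaves similarly, the filtering step can be sketched as below; `top_filtering` here is an illustrative reimplementation, not the function from `utils`.

```python
import math
from typing import List


def top_filtering(logits: List[float], top_k: int = 0, top_p: float = 0.0) -> List[float]:
    """Copy of logits with filtered-out tokens set to -inf; 0 disables either filter."""
    filtered = list(logits)
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    if top_k > 0:
        # Top-k: keep only the k highest-scoring tokens.
        for i in order[top_k:]:
            filtered[i] = -math.inf
    if top_p > 0.0:
        # Top-p (nucleus): keep the smallest prefix of tokens whose probability mass
        # reaches top_p, computed over the tokens that survived top-k.
        m = filtered[order[0]]
        exps = [math.exp(filtered[i] - m) for i in order]  # exp(-inf) == 0.0
        total = sum(exps)
        cumulative = 0.0
        for rank, i in enumerate(order):
            if cumulative >= top_p:
                filtered[i] = -math.inf
            cumulative += exps[rank] / total
    return filtered
```

Both filters always keep the single highest-scoring token, so under greedy decoding they do not change which token is picked; they only matter when sampling.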

In [11]:
print("Loading Models")
binaryclass = BinaryClassifier()
dialogue = DialogueGeneration()
print("Loading MSMarco Model")
genanswers = AnswerGeneration()

Loading Models


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

Loading MSMarco Model


Some weights of the model checkpoint at /ssd003/projects/aieng/conversational_ai/demo/checkpoints/marco_gpt2medium/checkpoint_epoch5_step587435.pth were not used when initializing GPT2LMHeadModel: ['multiple_choice_head.summary.bias', 'multiple_choice_head.summary.weight']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Available Categories

1. Biology: 7829
2. Chemistry: 1662
3. Technology: 3751
4. Economics: 1644
5. Physics: 2496
6. Mathematics: 431
7. Psychology: 79

## Dataset Stats
- Train size: 12524
- Validation size: 2684
- Test size: 2684
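The category counts and split sizes can be cross-checked with a little arithmetic. The numbers below are copied from the category and split lists; the seven category counts sum to 17,892, the same as train + validation + test, which suggests the categories cover the whole dataset and the splits are roughly 70/15/15.

```python
category_counts = {
    "Biology": 7829, "Chemistry": 1662, "Technology": 3751, "Economics": 1644,
    "Physics": 2496, "Mathematics": 431, "Psychology": 79,
}
split_sizes = {"train": 12524, "validation": 2684, "test": 2684}

total = sum(category_counts.values())
assert total == sum(split_sizes.values())  # both are 17,892

for name, size in split_sizes.items():
    print(f"{name}: {size} ({size / total:.1%})")
# train: 12524 (70.0%)
# validation: 2684 (15.0%)
# test: 2684 (15.0%)
```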

## Sample Questions

Q: What affects continental drift?

Q: Following the passing of the Thirteenth Amendment, were there any cases of slave-owners attempting to continue the practice illegally?

## MSMarco Model

In [14]:
while True:
    input_query = input('>>>')
    print("Input:", input_query)
    
    query_type, type_confidence = binaryclass.classify(input_query)
    if query_type == 'general':
        reply = dialogue.generate(input_query)
    else:
        context_size = 3
        contexts = search_and_rank(input_query, context_size, printout=False)
        # Join with spaces so the last sentence of one passage does not run into the next.
        context = " ".join(c['passage'] for c in contexts)
        print("Passage:", context)
        print("First Passage Score:", contexts[0]['score'])
        print('--------------------------')
        reply = genanswers.generate_answer(context, input_query)
    print('\n------------------------------')

>>>Hi
Input: Hi
The class is: general with 100.00% confident
DialoGPT Response: Hi! :D

------------------------------
>>>how are you?
Input: how are you?
The class is: general with 100.00% confident
DialoGPT Response: I'm good, how are you?

------------------------------
>>>I have a question
Input: I have a question
The class is: general with 100.00% confident
DialoGPT Response: What is it?

------------------------------
>>>What affects continental drift?
Input: What affects continental drift?
The class is: request with 100.00% confident
Passage: The theory of plate tectonics demonstrates that the continents of the Earth are moving across the surface at the rate of a few centimeters per year. This is expected to continue, causing the plates to relocate and collide. Continental drift is facilitated by two factors: the energy generation within the planet and the presence of a hydrosphere. With the loss of either of these, continental drift will come to a halt. The production of heat t

KeyboardInterrupt: Interrupted by user

## Thanks for your attention.
## Any Questions?