Vocabulary size = 30K

In [1]:
import numpy as np

from BERTQG.token_generation import load_model, generate_token
from Regularization_module import Regularizer_Discriminator

class QuestionGeneration():
    '''
    Generate a question token by token from two question-generation models.

    At every step both generators (LM and UQA) propose the next token and the
    regularization discriminator decides which of the two proposals is kept.
    '''

    def __init__(self, bert_model, lm_gq, uqa_qg, regularization):
        '''
        Load both generators and the regularization discriminator.
        Input:
            - bert_model: first argument handed to load_model for both generators
            - lm_gq: second argument handed to load_model for the LM generator
              (the spelling of this parameter is kept for backward compatibility)
            - uqa_qg: second argument handed to load_model for the UQA generator
            - regularization: argument handed to Regularizer_Discriminator
        '''
        # Use the actual parameter (`lm_gq`); reading `lm_qg` here would be a
        # NameError, or silently pick up an unrelated global of that name.
        self.lm_qg, _, _ = load_model(bert_model, lm_gq)
        # Keep the tokenizer and the device on the instance: generate_question
        # needs both for every generate_token call.
        self.uqa_qg, self.tokenizer, self.device = load_model(bert_model, uqa_qg)
        self.regularization = Regularizer_Discriminator(regularization)

    def generate_question(self, context: str, ans: str, ans_start: int,
                          list_question_tokens: list, max_tokens=None) -> str:
        '''
        Input:
            - context
            - ans
            - ans_start
            - list_question_tokens: accepted for interface compatibility only.
              Its content is ignored and generation always starts from an
              empty question.
            - max_tokens: optional upper bound on the number of generated
              tokens. None (the default) means "no bound": generation only
              stops once '[SEP]' has been produced.
        Output:
            - The generated question: its tokens joined by single spaces. The
              closing '[SEP]' token, when generated, is part of the string.
        '''
        question_tokens = []
        qi = None

        # generation is finished when [SEP] is created
        while not self.finished_generation(qi):
            # Optional safety net against a generator that never emits [SEP].
            if max_tokens is not None and len(question_tokens) >= max_tokens:
                break

            question_text = " ".join(question_tokens)

            # Ask both generators for their proposal of the i-th question token.
            # The returned index and probabilities are not used here.
            lm_qi_token, _, _ = generate_token(
                self.lm_qg, self.tokenizer, self.device,
                context, question_text, ans, ans_start)
            uqa_qi_token, _, _ = generate_token(
                self.uqa_qg, self.tokenizer, self.device,
                context, question_text, ans, ans_start)

            # Let the discriminator choose between the two proposals.
            jsonobj = self.__convert2squad(context, ans, ans_start, question_text)
            # Per the original author's note qi_probs is a [batchsize, 2]
            # ndarray -- TODO confirm against Regularizer_Discriminator.
            qi_probs = self.regularization.predict_prob(jsonobj)
            # A single example is scored, so only the first row matters.
            # Reducing it to a plain int keeps the comparison below well
            # defined. 1 -> keep the LM token, anything else -> UQA token.
            # NOTE(review): argmin (not argmax) is kept from the original;
            # confirm it matches what predict_prob's columns mean.
            qi_class = int(np.argmin(qi_probs, axis=1)[0])

            qi = lm_qi_token if qi_class == 1 else uqa_qi_token
            question_tokens.append(qi)

        return " ".join(question_tokens)

    def finished_generation(self, question_token):
        '''
        Return True when question_token is the '[SEP]' end marker.
        '''
        return question_token == '[SEP]'

    def __convert2squad(self, context: str, answer: str, ans_start: int,  question: str) -> dict:
        '''
        Create a SQuAD instance
        Inputs:
            - context: paragraph
            - answer
            - ans_start
            - question: might not be the full question (we are generating questions token by token)
        Returns:
            - squad instance: a dict with a 'version' entry and a 'data' list
              holding one paragraph with a single question/answer pair
        '''
        # NOTE(review): official SQuAD files keep 'title' next to 'paragraphs';
        # here it sits inside the paragraph entry. Left unchanged because the
        # schema expected by the discriminator is not visible from this class.
        squad = {'data': [], 'version': '1.0'}
        squad['data'].append({'paragraphs': [{'title': 'title',
                                              'context': context,
                                              'qas': [{'answers': [{'answer_start': ans_start, 'text': answer}],
                                                       'question': question,
                                                       'id': 0}]}]})
        return squad

In [2]:
# Locations of the models handed to QuestionGeneration: the discriminator
# used for regularization, the two question generators, and the BERT model
# both generators are loaded with.
regul_model = "./discri_model.bin"  # model path here
uqa_qg = 'BERTQG/models/uqa_50k_QG/pytorch_model_60000.bin'
lm_qg = 'BERTQG/models/lm_10k_QG/pytorch_model.bin'
bert_model = 'BERTQG/models/bert-base-uncased'

# Build the question generator (this loads the models).
QG = QuestionGeneration(
    bert_model,
    lm_qg,
    uqa_qg,
    regul_model,
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=361.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [1]:
import json
# Load the dataset into `squad`. The encoding is stated explicitly so that
# decoding does not depend on the platform's default locale encoding.
with open('../datasets/squad_NER_ans_full_wo_questions_evenly_distributed.json', 'r', encoding='utf-8') as f:
    squad = json.load(f)

# Testing

## Vocabulary

In [9]:
import urllib.request
import numpy as np
# Download a plain-text word list (one word per line) to act as a toy vocabulary.
word_url = "http://svnweb.freebsd.org/csrg/share/dict/words?view=co&content-type=text/plain"
# The context manager closes the HTTP response once the body has been read.
with urllib.request.urlopen(word_url) as response:
    long_txt = response.read().decode()
# Column vector of words: row i holds the string for vocabulary index i.
VOCAB_IDX2_STR = np.array(long_txt.splitlines()).reshape(-1,1)

In [10]:
# Display the dimensions of the vocabulary array (rows, columns).
VOCAB_IDX2_STR.shape

(25487, 1)

In [23]:
import torch
import torch.nn as nn

In [24]:
# Softmax module normalising along dimension 0 (the only one of a 1-D tensor).
softmax = torch.nn.Softmax(dim=0)

In [29]:
# Fake next-token distribution: one random score per vocabulary row, drawn
# from a normal distribution (mean 4, std 0.5), then passed through softmax.
qi_probs = softmax(
    torch.empty(len(VOCAB_IDX2_STR)).normal_(mean=4, std=0.5)
)

In [30]:
# Index of the highest-probability vocabulary entry.
qi = qi_probs.argmax(dim=0)

In [33]:
# Look up the word stored at index qi; .item() unwraps the selected
# one-element row into a plain Python value for display.
VOCAB_IDX2_STR[qi].item()

'Valparaiso'