Vocabulary size = 30K

In [29]:
import numpy as np
class Freq_Regularization():
    '''
    Frequency-based chooser between the two token classes (0 = UQA, 1 = LM).
    '''

    def predict_class(self, list_token_classes):
        '''
        Sample the class of the next question token.

        UQA class = 0
        LM class = 1

        Input:
            - list_token_classes: classes (0/1) of the tokens generated so far
        Output:
            - 0 or 1, drawn with np.random.choice. With an empty history both classes are
              equally likely. Otherwise class 0 is drawn with probability equal to the share
              of 1s in the history and class 1 with the remaining probability, so the class
              that has been used less so far is the more likely pick.
        '''
        history_size = len(list_token_classes)
        if history_size == 0:
            weights = [0.5, 0.5]
        else:
            # share of class-1 (LM) entries; this is the weight given to class 0 (UQA)
            share_of_ones = sum(list_token_classes)/history_size
            weights = [share_of_ones, 1 - share_of_ones]

        drawn = np.random.choice(2, 1, p=weights)
        return drawn[0]


In [30]:
from BERTQG.token_generation import load_model, generate_token
#from Regularization_module import Regularizer_Discriminator

class QuestionGeneration():
    '''
    Generates a question token by token. At every step a regularization module decides
    whether the next token is taken from the LM model or from the UQA model.
    '''

    def __init__(self, bert_model, lm_gq, uqa_qg, regularization=None):
        '''
        Load the two question-generation models.

        Input:
            - bert_model: first argument passed to load_model for both models
            - lm_gq: second load_model argument for the LM model (the misspelt parameter
              name is kept so existing keyword callers keep working)
            - uqa_qg: second load_model argument for the UQA model
            - regularization: currently unused; a Freq_Regularization is always created
        '''
        # The body used to read `lm_qg`, which is not a parameter of this method and so
        # resolved to whatever global carried that name. Use the actual argument.
        self.lm_qg, _, _ = load_model(bert_model, lm_gq)
        # tokenizer and device are taken from the UQA load and shared by both models
        self.uqa_qg, self.tokenizer, self.device = load_model(bert_model, uqa_qg)
        self.regularization = Freq_Regularization()

    def generate_question(self, context: str, ans: str, ans_start: int, list_question_tokens: list = None) -> str:
        '''
        Generate a complete question for the given answer span.

        Input:
            - context
            - ans
            - ans_start
            - list_question_tokens: accepted for backward compatibility but not used;
              generation always starts from an empty question
        Output:
            - The generated question: all generated tokens joined by single spaces, without
              the final [SEP] token
        '''
        # contains the tokens of the generated question
        list_question_tokens = []
        # contains the classes of each token of the gen. question. Same len as list_question_tokens
        # 0 = UQA, 1 = LM
        list_token_classes = []
        qi = None

        # generation finished when [SEP] is created
        # NOTE(review): there is no upper bound on the number of steps; the loop only ends
        # once a model emits [SEP].
        while not self.finished_generation(qi):
            question_text = " ".join(list_question_tokens)

            # Get token class to use
            qi_class = self.regularization.predict_class(list_token_classes)

            # Query only the selected model: the other model's token was previously computed
            # and then discarded (assumes generate_token has no side effects - TODO confirm).
            model = self.lm_qg if qi_class == 1 else self.uqa_qg
            qi, _, _ = generate_token(model, self.tokenizer, self.device, context, question_text, ans, ans_start)

            list_question_tokens.append(qi)
            list_token_classes.append(qi_class)

            # for testing
            print(list_question_tokens)
            print(["UQA" if cls == 0 else "LM" for cls in list_token_classes])
            print("\n")

        return " ".join(list_question_tokens[:-1]) # without [SEP]

    def finished_generation(self, question_token):
        '''
        Return True when question_token is the end-of-question marker '[SEP]'.
        '''
        return question_token == '[SEP]'

    def __convert2squad(self, context: str, answer: str, ans_start: int, question: str) -> dict:
        '''
        Create a SQuAD instance
        Inputs:
            - context: paragraph
            - answer
            - ans_start
            - question: might not be the full question (we are generating questions token by token)
        Returns:
            - squad instance: a dict with a single paragraph holding a single question/answer pair
        '''
        qa = {'answers': [{'answer_start': ans_start, 'text': answer}],
              'question': question,
              'id': 0}
        paragraph = {'title': 'title',
                     'context': context,
                     'qas': [qa]}
        squad = {'data': [], 'version': '1.0'}
        squad['data'].append({'paragraphs': [paragraph]})
        return squad

In [31]:
# Locations of the base BERT model directory and of the LM / UQA question-generation weight files.
bert_model = 'BERTQG/models/bert-base-uncased'
lm_qg = 'BERTQG/models/lm_10k_QG/pytorch_model.bin'
uqa_qg = 'BERTQG/models/uqa_10k_QG/pytorch_model.bin'
#regul_model = "./discri_model.bin"# model path here
# Build the generator; its constructor loads both models via load_model.
QG = QuestionGeneration(bert_model, lm_qg, uqa_qg)

04/29/2020 00:57:33 - INFO - BERTQG.tokenization -   loading vocabulary file BERTQG/models/bert-base-uncased/vocab.txt
04/29/2020 00:57:33 - INFO - BERTQG.modeling -   loading archive file BERTQG/models/bert-base-uncased
04/29/2020 00:57:33 - INFO - BERTQG.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 3,
  "vocab_size": 30522
}



Load BERTQG/models/lm_10k_QG/pytorch_model.bin


04/29/2020 00:57:35 - INFO - BERTQG.tokenization -   loading vocabulary file BERTQG/models/bert-base-uncased/vocab.txt
04/29/2020 00:57:35 - INFO - BERTQG.modeling -   loading archive file BERTQG/models/bert-base-uncased
04/29/2020 00:57:35 - INFO - BERTQG.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 3,
  "vocab_size": 30522
}



Load BERTQG/models/uqa_10k_QG/pytorch_model.bin


In [32]:
import json
# Load the training set from disk. The encoding is given explicitly so that the result
# does not depend on the platform's default text encoding (JSON files are UTF-8).
with open('/workspace/ml-workspace/emnlp_qg/datasets/squad_train_evenly_distributed.json', 'r', encoding='utf-8') as f:
    squad = json.load(f)

In [33]:
# context = squad['data'][0]['paragraphs'][0]['context']
# ans = squad['data'][0]['paragraphs'][0]['qas'][0]['answers'][0]['text']
# ans_start = squad['data'][0]['paragraphs'][0]['qas'][0]['answers'][0]['answer_start']

In [34]:
# Hand-written example: a paragraph, a reference question and the answer span to ask about.
context = "The city has a proud history of theatre. Stephen Kemble of the famous Kemble family successfully managed the original Theatre Royal, Newcastle for fifteen years (1791–1806). He brought members of his famous acting family such as Sarah Siddons and John Kemble out of London to Newcastle. Stephen Kemble guided the theatre through many celebrated seasons. The original Theatre Royal in Newcastle was opened on 21 January 1788 and was located on Mosley Street. It was demolished to make way for Grey Street, where its replacement was built."
question_text = "when did the original theatre royal open ?"
ans = "1788"
# str.index raises ValueError when the answer does not occur in the context. Unlike the
# former str.find + assert pair, this check is not stripped when Python runs with -O.
ans_start = context.index(ans)

In [35]:
# Generate a question for the example answer span; the call prints the token list and the
# per-token class list after every generation step and returns the final question string.
QG.generate_question(context, ans, ans_start)

['when']
['UQA']


['when', 'did']
['UQA', 'LM']


['when', 'did', 'the']
['UQA', 'LM', 'LM']


['when', 'did', 'the', 'original']
['UQA', 'LM', 'LM', 'LM']


['when', 'did', 'the', 'original', 'theatre']
['UQA', 'LM', 'LM', 'LM', 'UQA']


['when', 'did', 'the', 'original', 'theatre', 'royal']
['UQA', 'LM', 'LM', 'LM', 'UQA', 'UQA']


['when', 'did', 'the', 'original', 'theatre', 'royal', 'open']
['UQA', 'LM', 'LM', 'LM', 'UQA', 'UQA', 'LM']


['when', 'did', 'the', 'original', 'theatre', 'royal', 'open', 'on']
['UQA', 'LM', 'LM', 'LM', 'UQA', 'UQA', 'LM', 'UQA']


['when', 'did', 'the', 'original', 'theatre', 'royal', 'open', 'on', '21']
['UQA', 'LM', 'LM', 'LM', 'UQA', 'UQA', 'LM', 'UQA', 'LM']


['when', 'did', 'the', 'original', 'theatre', 'royal', 'open', 'on', '21', 'january']
['UQA', 'LM', 'LM', 'LM', 'UQA', 'UQA', 'LM', 'UQA', 'LM', 'LM']


['when', 'did', 'the', 'original', 'theatre', 'royal', 'open', 'on', '21', 'january', '?']
['UQA', 'LM', 'LM', 'LM', 'UQA', 'UQA', 'LM', 'U

'when did the original theatre royal open on 21 january ?'

# Testing

## Vocabulary

In [9]:
import urllib.request
import numpy as np
# Download a word list (one entry per line) to use as a stand-in vocabulary.
word_url = "http://svnweb.freebsd.org/csrg/share/dict/words?view=co&content-type=text/plain"
# The response is used as a context manager so the HTTP connection is closed once the
# body has been read, instead of being left open.
with urllib.request.urlopen(word_url) as response:
    long_txt = response.read().decode()
# One row per word: reshape(-1,1) turns the flat list of lines into an (n_words, 1) array.
VOCAB_IDX2_STR = np.array(long_txt.splitlines()).reshape(-1,1)

In [10]:
# Inspect the vocabulary array's shape: (number of words, 1).
VOCAB_IDX2_STR.shape

(25487, 1)

In [23]:
import torch
import torch.nn as nn

In [24]:
# Softmax applied along dimension 0, i.e. across the entries of a 1-D score vector.
softmax = nn.Softmax(dim=0)

In [29]:
# Synthetic token distribution: one normally distributed score (mean 4, std 0.5) per
# vocabulary row, turned into probabilities by the softmax.
qi_probs = softmax(torch.empty(VOCAB_IDX2_STR.shape[0]).normal_(mean=4,std=0.5))

In [30]:
# Index of the most probable vocabulary row.
qi = torch.argmax(qi_probs, dim=0)

In [33]:
# Map the selected index back to its word; .item() unwraps the single-element row into a str.
VOCAB_IDX2_STR[qi].item()

'Valparaiso'