Vocabulary size: 30K tokens.

In [None]:
class QuestionGeneration:
    '''
    Generate a question token by token by combining two question generators.

    At every step the next-token probabilities of an LM question generator
    (``lm_qg``) and a UQA question generator (``uqa_qg``) are merged by a
    user-supplied ``regularization`` callable. That callable also reports the
    class of the chosen token (0 = LM, 1 = UQA). Decoding is greedy (argmax).
    '''

    def __init__(self, lm_gq, uqa_qg, regularization, idx2str=None):
        '''
        Input:
            - lm_gq: LM question generator exposing
              ``forward(context, ans, question_tokens)`` -> next-token probs.
              The parameter keeps its original spelling so keyword callers keep
              working; it is stored as ``self.lm_qg``.
            - uqa_qg: UQA question generator with the same ``forward`` interface.
            - regularization: callable
              ``(lm_probs, uqa_probs, question_tokens, token_classes)``
              -> ``(probs, token_class)``.
            - idx2str: optional vocabulary-index -> token-string lookup (list or
              dict). When None, the module-level ``vocab_idx2str`` is used at
              decode time.
        '''
        # The parameter is spelled `lm_gq`; the original assigned an undefined `lm_qg`.
        self.lm_qg = lm_gq
        self.uqa_qg = uqa_qg
        self.regularization = regularization
        self.idx2str = idx2str

    def generate_token(self, context: str, ans: str, list_question_tokens: list = None,
                       max_new_tokens: int = 100) -> list:
        '''
        Greedily generate question tokens until a terminating token is produced.

        Input:
            - context: passed unchanged to both generators' ``forward``.
            - ans: passed unchanged to both generators' ``forward``.
            - list_question_tokens: the history of generated question tokens, as a
              list of (q_token, class) pairs where class is 0 for an LM token and
              1 for a UQA token. It is not modified. None or [] starts a new question.
            - max_new_tokens: maximum number of tokens generated by this call. It
              guards against generators that never emit '?' or 'EOS'. None means
              no limit.
        Output:
            - The whole question as a list of (q_token, class) pairs: the given
              history followed by the new tokens. Generation stops right after
              the first '?' or 'EOS' token (which is kept), or at max_new_tokens.
              If the history already ends with such a token, nothing is generated.
        '''
        history = list_question_tokens or []
        # Generators and regularization receive tokens and classes as parallel lists.
        tokens = [tok for tok, _ in history]
        classes = [cls for _, cls in history]

        last_token = tokens[-1] if tokens else None
        n_generated = 0
        while not self.finished_generation(last_token):
            if max_new_tokens is not None and n_generated >= max_new_tokens:
                break
            # Probabilities of the next question token from each generator.
            lm_qi_probs = self.lm_qg.forward(context, ans, tokens)
            uqa_qi_probs = self.uqa_qg.forward(context, ans, tokens)

            # Merge both distributions; also get the class of the chosen token.
            qi_probs, qi_class = self.regularization(lm_qi_probs, uqa_qi_probs,
                                                     tokens, classes)

            # Stop checks must see the decoded string, not the index tensor.
            last_token = self._decode(qi_probs)
            tokens.append(last_token)
            classes.append(qi_class)
            n_generated += 1

        return list(zip(tokens, classes))

    def _decode(self, qi_probs):
        '''Return the token string at the highest-probability vocabulary index.'''
        # dim=-1 takes the argmax along the last axis, so both a 1-D vector and a
        # single-row 2-D tensor work. .item() needs exactly one element (one
        # sequence) and yields a plain int usable for list or dict lookup.
        idx = int(torch.argmax(qi_probs, dim=-1).item())
        table = self.idx2str if self.idx2str is not None else vocab_idx2str
        return table[idx]

    def finished_generation(self, question_token):
        '''Return True if question_token terminates a question ('?' or 'EOS').'''
        return question_token in ('?', 'EOS')