In [None]:
from torch import argmax, no_grad, tensor
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelWithLMHead,
    AutoTokenizer,
    BertForQuestionAnswering,
    BertTokenizer,
)

In [None]:
# Checkpoint names live in one place so they are easy to swap.
QG_CHECKPOINT = "mrm8488/t5-base-finetuned-question-generation-ap"
QA_CHECKPOINT = 'bert-large-uncased-whole-word-masking-finetuned-squad'

# Question generation model (T5, an encoder-decoder). AutoModelForSeq2SeqLM is the
# supported auto class for this family; AutoModelWithLMHead is deprecated.
tokenizer_q = AutoTokenizer.from_pretrained(QG_CHECKPOINT)
model_q = AutoModelForSeq2SeqLM.from_pretrained(QG_CHECKPOINT)

# Extractive question answering model (BERT span prediction via start/end logits).
model_a = BertForQuestionAnswering.from_pretrained(QA_CHECKPOINT)
tokenizer_a = BertTokenizer.from_pretrained(QA_CHECKPOINT)

In [None]:
def get_questions(context, max_length=64):
    """Generate one question per sentence of ``context`` with the T5 QG model.

    The context is split naively on '.'. Each non-empty, whitespace-stripped
    fragment is sent to ``model_q`` with an empty answer hint.

    Args:
        context: Free text to generate questions from.
        max_length: Maximum generated length, forwarded to ``model_q.generate``.

    Returns:
        A list of question strings, one per non-empty sentence, in input order.
        Returns an empty list for empty or whitespace-only input.
    """
    questions = []
    # Keep the final fragment as well: text that does not end in '.' still has a
    # last sentence. Blank fragments (from '..' or a trailing '.') are skipped.
    sentences = [s.strip() for s in context.split('.') if s.strip()]
    for sentence in sentences:
        input_text = "answer: %s  context: %s </s>" % ('', sentence)
        features = tokenizer_q([input_text], return_tensors='pt')
        output = model_q.generate(input_ids=features['input_ids'],
                                  attention_mask=features['attention_mask'],
                                  max_length=max_length)
        # skip_special_tokens removes <pad>/</s> regardless of exact spacing.
        # The generated text is expected to start with "question:", which is
        # stripped when present.
        question = tokenizer_q.decode(output[0], skip_special_tokens=True).strip()
        if question.startswith('question:'):
            question = question[len('question:'):].strip()
        questions.append(question)
    return questions

In [None]:
def answer_question(question, context, max_length=512):
    """Extract the answer span for ``question`` from ``context`` with the BERT QA model.

    Args:
        question: Question text, encoded as the first segment.
        context: Passage to search, encoded as the second segment.
        max_length: Token budget for the encoded question/context pair. Only the
            context is truncated to fit. 512 is the usual BERT input limit
            (TODO confirm if a different checkpoint is used).

    Returns:
        The predicted answer as a string, with WordPiece "##" continuations
        merged back into whole words. Returns '' when no context tokens are
        left after encoding (e.g. an empty context).
    """
    # Truncate only the context so long passages cannot overflow the model input.
    input_ids = tokenizer_a.encode(question, context,
                                   truncation='only_second',
                                   max_length=max_length)
    sep_index = input_ids.index(tokenizer_a.sep_token_id)
    # Segment 0: everything up to and including the first [SEP]; segment 1: the rest.
    segment_ids = [0] * (sep_index + 1) + [1] * (len(input_ids) - sep_index - 1)

    # Context tokens lie after the first [SEP] and before the final token
    # (assumes BERT pair encoding ends with a trailing [SEP]).
    first_ctx = sep_index + 1
    last_ctx = len(input_ids) - 2
    if last_ctx < first_ctx:
        return ''

    with no_grad():  # inference only: skip building the autograd graph
        outputs = model_a(tensor([input_ids]),
                          token_type_ids=tensor([segment_ids]),
                          return_dict=True)
    # Batch of one, so take row 0 to get per-token scores.
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    # Restrict the span to context tokens and force end >= start, so the answer
    # can never be [CLS]/[SEP] or a piece of the question.
    answer_start = first_ctx + int(argmax(start_logits[first_ctx:last_ctx + 1]))
    answer_end = answer_start + int(argmax(end_logits[answer_start:last_ctx + 1]))

    words = []
    for token in tokenizer_a.convert_ids_to_tokens(input_ids[answer_start:answer_end + 1]):
        # "##" marks a WordPiece continuation of the previous word.
        if token.startswith('##') and words:
            words[-1] += token[2:]
        else:
            words.append(token)
    return ' '.join(words)

# Testing

## Generate Questions

In [None]:
context_q = input("Enter the Context to Generate Questions : \n\n\t")
print("\nGenerated Questions ", end='\n\n')
# dict.fromkeys drops duplicates but keeps first-seen order; list(set(...)) would
# list the questions in a different order on each run (str hashing is randomized).
question_q = list(dict.fromkeys(get_questions(context_q)))
for i, qn in enumerate(question_q, start=1):
    print('%d) %s' % (i, qn))

## Answering the Question

In [None]:
# Prompt for a passage and then a question (in that order), and print the
# extracted answer indented by one space.
context_a, question_a = (
    input(prompt)
    for prompt in (
        "Enter the Context to Answer the Question : \n\n",
        "\nEnter the Question to Answer from the Context : \n\n",
    )
)
print("\nGenerated Answer", end='\n\n')
answer_a = answer_question(question=question_a, context=context_a)
print(' ' + answer_a)

## Generate Questions and Answers

In [None]:
context_q_a = input("Enter the Context to Generate Questions and Answers : \n\n\t")
print("\nGenerated Question and Answers", end='\n\n')
# Order-preserving de-duplication; list(set(...)) would change the order run to run.
question_q_a = list(dict.fromkeys(get_questions(context_q_a)))
answer_q_a = [answer_question(qn, context_q_a) for qn in question_q_a]
for i, (qn, ans) in enumerate(zip(question_q_a, answer_q_a), start=1):
    print('Qn.%d  %s' % (i, qn))
    print('Ans : ' + ans, end='\n\n')