In [1]:
!pip install transformers
!python -m nltk.downloader punkt
!pip install sentencepiece

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/ed/d5/f4157a376b8a79489a76ce6cfe147f4f3be1e029b7144fa7b8432e8acb26/transformers-4.4.2-py3-none-any.whl (2.0MB)
[K     |████████████████████████████████| 2.0MB 21.1MB/s 
[?25hCollecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/08/cd/342e584ee544d044fb573ae697404ce22ede086c9e87ce5960772084cad0/sacremoses-0.0.44.tar.gz (862kB)
[K     |████████████████████████████████| 870kB 52.7MB/s 
Collecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/71/23/2ddc317b2121117bf34dd00f5b0de194158f2a44ee2bf5e47c7166878a97/tokenizers-0.10.1-cp37-cp37m-manylinux2010_x86_64.whl (3.2MB)
[K     |████████████████████████████████| 3.2MB 49.2MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.44-cp37-none-any.whl size=886084 sha256=1b5d

In [8]:
import itertools

from nltk import sent_tokenize

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

PIPELINE_SETTINGS = {
    #"model": "valhalla/t5-base-qg-hl",
    "model": "mrm8488/t5-base-finetuned-question-generation-ap",
    "ans_model": ["valhalla/t5-base-qa-qg-hl", "valhalla/t5-small-qa-qg-hl"]
}

class QGPipeline:

    def __init__(self, pipeline_settings: dict = PIPELINE_SETTINGS, use_cuda: bool = True) :

        self.model = AutoModelForSeq2SeqLM.from_pretrained(pipeline_settings['model'])
        self.tokenizer = AutoTokenizer.from_pretrained(pipeline_settings['model'], use_fast=False)

        self.ans_model = []
        self.ans_tokenizer = []
        for i in range(len(pipeline_settings['ans_model'])) :
            self.ans_model.append(AutoModelForSeq2SeqLM.from_pretrained(pipeline_settings['ans_model'][i]))
            self.ans_tokenizer.append(AutoTokenizer.from_pretrained(pipeline_settings['ans_model'][i], use_fast=False))

        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        self.model.to(self.device)

        for i in range(len(self.ans_model)) :
            if self.ans_model[i] is not self.model:
                self.ans_model[i].to(self.device)

    def __call__(self, text : str):
        input_text = " ".join(text.split())
        answers = self._extract_answers(input_text)
        # print(sents)
        # print(answers)
        if len(answers) == 0:
          return []

        questions = self._generate_questions(answers, input_text)
        question_answers_list = []
        for question, answer in zip(questions, answers) :
            question_answers_list.append({'question': question, 'answer': answer})
        return question_answers_list

    def _extract_answers(self, context):
        inputs = self._prepare_inputs_for_ans_extraction(context)
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        answers = []
        for i in range(len(self.ans_model)) :
            outs = self.ans_model[i].generate(
                input_ids=inputs['input_ids'].to(self.device), 
                attention_mask=inputs['attention_mask'].to(self.device), 
                max_length=32,
            )
            
            dec = [self.ans_tokenizer[i].decode(ids, skip_special_tokens=False) for ids in outs]
            decoded_output = [item.split('<sep>') for item in dec]
            decoded_output = [i[0] for i in decoded_output]
            answers.extend(decoded_output)
        
        for i in range(len(answers)) :
            answers[i] = answers[i].replace("<pad> ", "")

        answers = list(set(answers))
        return answers
 
    def _prepare_inputs_for_ans_extraction(self, text):
        sents = sent_tokenize(text)

        inputs = []
        for i in range(len(sents)):
            source_text = "extract answers:"
            for j, sent in enumerate(sents):
                if i == j:
                    sent = "<hl> %s <hl>" % sent
                source_text = "%s %s" % (source_text, sent)
                source_text = source_text.strip()
            
            source_text = source_text + " </s>"
            inputs.append(source_text)

        return inputs
  
    def _tokenize(self, inputs, padding=True, truncation=True, add_special_tokens=True, max_length=512):
        inputs = self.ans_tokenizer[0].batch_encode_plus(
            inputs, 
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            pad_to_max_length=padding,
            return_tensors="pt"
        )
        return inputs
    
    def _generate_questions(self, answers, context):
        questions = []
        for answer in answers :
            input_text = "answer: %s  context: %s </s>" % (answer, context)
            inputs = self._tokenize([input_text], padding=True, truncation=True)
        
            outs = self.model.generate(
                input_ids=inputs['input_ids'].to(self.device), 
                attention_mask=inputs['attention_mask'].to(self.device), 
                max_length=64,
                num_beams=4,
            )
            questions.extend([self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs])

        for i in range(len(questions)) :
            questions[i] = questions[i].replace("question: ", "")

        return questions

In [9]:
text2 = " (1) The Internal Committee or, as the case may be, the Local Committee, may, \
before initiating an inquiry under section 11 and at the request of the aggrieved woman take steps to settle \
the matter between her and the respondent through conciliation: \
Provided that no monetary settlement shall be made as a basis of conciliation. \
(2) Where settlement has been arrived at under sub-section (1), the Internal Committee or the Local \
Committee, as the case may be, shall record the settlement so arrived and forward the same to the \
employer or the District Officer to take action as specified in the recommendation. \
(3) The Internal Committee or the Local Committee, as the case may be, shall provide the copies of \
the settlement as recorded under sub-section (2) to the aggrieved woman and the respondent. \
(4) Where a settlement is arrived at under sub-section (1), no further inquiry shall be conducted by the \
Internal Committee or the Local Committee, as the case may be. "

qg = QGPipeline()
qg(text2)

The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
  f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated eos tokens being added."


[{'answer': 'monetary settlement ',
  'question': 'What shall not be made as a basis of conciliation?'},
 {'answer': 'no further inquiry ',
  'question': 'What happens if a settlement is reached under sub-section (1)?'},
 {'answer': 'record the settlement so arrived and forward the same to the employer or the District Officer ',
  'question': 'Where a settlement has been reached under sub-section (1), the Internal Committee or the Local Committee shall do what?'},
 {'answer': 'the employer or the District Officer ',
  'question': 'Who shall the settlement be sent to?'},
 {'answer': 'The Internal Committee or the Local Committee ',
  'question': 'Who may take steps to settle the matter between the aggrieved woman and the respondent?'}]