In [None]:
!pip install transformers
!python -m nltk.downloader punkt
!pip install sentencepiece

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/d8/b2/57495b5309f09fa501866e225c84532d1fd89536ea62406b2181933fb418/transformers-4.5.1-py3-none-any.whl (2.1MB)
[K     |████████████████████████████████| 2.1MB 15.9MB/s 
Collecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d2672c02ff3c3fb4c688ca46467a/tokenizers-0.10.2-cp37-cp37m-manylinux2010_x86_64.whl (3.3MB)
[K     |████████████████████████████████| 3.3MB 55.2MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)
[K     |████████████████████████████████| 901kB 48.3MB/s 
Installing collected packages: tokenizers, sacremoses, transformers
Successfully installed sacremoses-0.0.45 tokenizers-0.10.2 transformers-4.5.1
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk

In [None]:
import itertools

from nltk import sent_tokenize

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Checkpoint names handed to `from_pretrained` by QGPipeline.
#   model     - checkpoint used to generate a question from an (answer, context) pair
#               (alternative tried earlier: "valhalla/t5-base-qg-hl")
#   ans_model - list of checkpoints used to propose candidate answers
PIPELINE_SETTINGS = dict(
    model="mrm8488/t5-base-finetuned-question-generation-ap",
    ans_model=["valhalla/t5-base-qa-qg-hl"],
)

class QGPipeline:
    """Two-stage question generation over a passage of text.

    Stage 1 highlights each sentence of the passage in turn and asks every
    answer-extraction model for a candidate answer. Stage 2 feeds each distinct
    answer, together with the whole passage, to the question-generation model.

    Calling an instance returns a list of ``{'question': str, 'answer': str}``
    dicts, one per distinct extracted answer, in the order the answers were
    first produced.
    """

    def __init__(self, pipeline_settings: dict = None, use_cuda: bool = True):
        """Load every model/tokenizer pair and move the models to the target device.

        Args:
            pipeline_settings: dict with a ``'model'`` checkpoint name (question
                generation) and an ``'ans_model'`` list of checkpoint names
                (answer extraction). ``None`` means the module-level
                ``PIPELINE_SETTINGS``, looked up at call time.
            use_cuda: run on the GPU when one is available; otherwise (or when
                False) everything stays on the CPU.
        """
        if pipeline_settings is None:
            pipeline_settings = PIPELINE_SETTINGS

        self.model = AutoModelForSeq2SeqLM.from_pretrained(pipeline_settings['model'])
        self.tokenizer = AutoTokenizer.from_pretrained(pipeline_settings['model'], use_fast=False)

        # Parallel lists: ans_tokenizer[k] belongs to ans_model[k].
        self.ans_model = []
        self.ans_tokenizer = []
        for checkpoint in pipeline_settings['ans_model']:
            self.ans_model.append(AutoModelForSeq2SeqLM.from_pretrained(checkpoint))
            self.ans_tokenizer.append(AutoTokenizer.from_pretrained(checkpoint, use_fast=False))

        self.device = "cuda" if use_cuda and torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        for ans_model in self.ans_model:
            ans_model.to(self.device)

    def __call__(self, text: str):
        """Generate question/answer pairs for ``text``.

        Returns:
            A list of ``{'question': ..., 'answer': ...}`` dicts; empty when the
            text is blank or no answer could be extracted.
        """
        # Collapse newlines and runs of whitespace into single spaces.
        input_text = " ".join(text.split())
        if not input_text:
            return []

        answers = self._extract_answers(input_text)
        if not answers:
            return []

        questions = self._generate_questions(answers, input_text)
        return [
            {'question': question, 'answer': answer}
            for question, answer in zip(questions, answers)
        ]

    def _extract_answers(self, context):
        """Return the distinct, non-empty answers proposed for ``context``.

        Every answer model sees one prompt per sentence. Only the text before the
        first ``<sep>`` of each generated sequence is kept. Order is first-seen,
        so repeated runs yield the same ordering.
        """
        prompts = self._prepare_inputs_for_ans_extraction(context)
        if not prompts:
            return []

        answers = []
        for ans_model, ans_tokenizer in zip(self.ans_model, self.ans_tokenizer):
            # Encode with the tokenizer that belongs to this particular model.
            inputs = self._tokenize(prompts, padding=True, truncation=True, tokenizer=ans_tokenizer)
            outs = ans_model.generate(
                input_ids=inputs['input_ids'].to(self.device),
                attention_mask=inputs['attention_mask'].to(self.device),
                max_length=32,
            )
            for ids in outs:
                # Special tokens are kept while decoding so that "<sep>" is still
                # present to split on; _clean_answer removes the rest.
                answer = self._clean_answer(ans_tokenizer.decode(ids, skip_special_tokens=False))
                if answer:
                    answers.append(answer)

        # dict.fromkeys de-duplicates while keeping insertion order (a set would not).
        return list(dict.fromkeys(answers))

    @staticmethod
    def _clean_answer(decoded):
        """Reduce one decoded sequence to its first answer.

        Takes the text before the first ``<sep>``, removes ``<pad>`` and ``</s>``
        markers and normalises whitespace. Returns ``''`` when nothing is left.
        """
        first = decoded.split('<sep>')[0]
        for special in ("<pad>", "</s>"):
            first = first.replace(special, " ")
        return " ".join(first.split())

    def _prepare_inputs_for_ans_extraction(self, text):
        """Build one ``extract answers:`` prompt per sentence of ``text``.

        In prompt ``i`` the i-th sentence is wrapped in ``<hl> ... <hl>`` and all
        other sentences are left as they are. No end-of-sequence marker is
        appended by hand: the tokenizer warns ("This sequence already has </s>")
        when it finds one already in the text.
        """
        sents = sent_tokenize(text)

        inputs = []
        for i in range(len(sents)):
            parts = [
                "<hl> %s <hl>" % sent if i == j else sent
                for j, sent in enumerate(sents)
            ]
            inputs.append(("extract answers: " + " ".join(parts)).strip())

        return inputs

    def _tokenize(self, inputs, padding=True, truncation=True, add_special_tokens=True, max_length=512,
                  tokenizer=None):
        """Encode a batch of strings as PyTorch tensors.

        Args:
            inputs: list of strings to encode.
            padding: pad to the longest sequence in the batch when true. Padded
                positions are excluded through the returned ``attention_mask``.
            truncation: forwarded to the tokenizer.
            add_special_tokens: forwarded to the tokenizer.
            max_length: forwarded to the tokenizer as the truncation limit.
            tokenizer: tokenizer to use; defaults to ``self.ans_tokenizer[0]``.

        Returns:
            The tokenizer's encoding, which the callers index by ``'input_ids'``
            and ``'attention_mask'``.
        """
        if tokenizer is None:
            tokenizer = self.ans_tokenizer[0]
        return tokenizer(
            list(inputs),
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="longest" if padding else False,
            return_tensors="pt",
        )

    def _generate_questions(self, answers, context):
        """Generate one question per answer, in the same order as ``answers``."""
        questions = []
        for answer in answers:
            prompt = "answer: %s  context: %s" % (answer, context)
            # The question model is paired with its own tokenizer, not an answer one.
            inputs = self._tokenize([prompt], padding=True, truncation=True, tokenizer=self.tokenizer)

            outs = self.model.generate(
                input_ids=inputs['input_ids'].to(self.device),
                attention_mask=inputs['attention_mask'].to(self.device),
                max_length=64,
                num_beams=4,
            )
            for ids in outs:
                question = self.tokenizer.decode(ids, skip_special_tokens=True)
                # Drop the "question: " label from the decoded text.
                questions.append(question.replace("question: ", ""))

        return questions

In [None]:
# Sample passage: four numbered sub-sections about settling a matter through conciliation.
# Adjacent string literals are joined by the parser, so this is a single one-line
# string that begins and ends with a space.
text2 = (
    " (1) The Internal Committee or, as the case may be, the Local Committee, may, "
    "before initiating an inquiry under section 11 and at the request of the aggrieved woman take steps to settle "
    "the matter between her and the respondent through conciliation: "
    "Provided that no monetary settlement shall be made as a basis of conciliation. "
    "(2) Where settlement has been arrived at under sub-section (1), the Internal Committee or the Local "
    "Committee, as the case may be, shall record the settlement so arrived and forward the same to the "
    "employer or the District Officer to take action as specified in the recommendation. "
    "(3) The Internal Committee or the Local Committee, as the case may be, shall provide the copies of "
    "the settlement as recorded under sub-section (2) to the aggrieved woman and the respondent. "
    "(4) Where a settlement is arrived at under sub-section (1), no further inquiry shall be conducted by the "
    "Internal Committee or the Local Committee, as the case may be. "
)

# Build the pipeline with its default settings, then show the generated pairs
# as this cell's output.
qg = QGPipeline()
qg(text2)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1230.0, style=ProgressStyle(description…

The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.





HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1187795641.0, style=ProgressStyle(descr…




The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=791656.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1786.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=25.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=629.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=891612585.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=791656.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=31.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=65.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=90.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=656.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=242013376.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=791656.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=31.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=65.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=90.0, style=ProgressStyle(description_w…




  f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated eos tokens being added."


Dec :  ['<pad> monetary settlement <sep> </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>', '<pad> record the settlement so arrived and forward the same to the employer or the District Officer <sep> </s>', '<pad> The Internal Committee or the Local Committee <sep> </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>', '<pad> no further inquiry <sep> </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>']
Decoded output 1 :  [['<pad> monetary settlement ', ' </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>'], ['<pad> record the settlement so arrived and forward the same to the employer or the District Officer ', ' </s>'], ['<pad> The Internal Committee or the Local Committee ', ' </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>'], ['<pad> no further inquiry ', ' </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>']]
Decoded output 2 :  ['<pad> m

[{'answer': 'the employer or the District Officer ',
  'question': 'Who shall the settlement be sent to?'},
 {'answer': 'monetary settlement ',
  'question': 'What shall not be made as a basis of conciliation?'},
 {'answer': 'record the settlement so arrived and forward the same to the employer or the District Officer ',
  'question': 'Where a settlement has been reached under sub-section (1), the Internal Committee or the Local Committee shall do what?'},
 {'answer': 'no further inquiry ',
  'question': 'What happens if a settlement is reached under sub-section (1)?'},
 {'answer': 'The Internal Committee or the Local Committee ',
  'question': 'Who may take steps to settle the matter between the aggrieved woman and the respondent?'}]