In [1]:
# Install dependencies

# !pip install PySimpleGUI
# !pip install transformers
# !pip install PyPDF2

In [2]:
#Import required packages
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer
import torch
import numpy as np

# Single source of truth for the checkpoint: the model and tokenizer must be
# loaded from the same name, so it is defined once here.
MODEL_NAME = 'bert-large-uncased-whole-word-masking-finetuned-squad'

# create bert model for question answering
model = BertForQuestionAnswering.from_pretrained(MODEL_NAME)
# define tokenizer for bert (must match the model's vocabulary)
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

In [3]:
def _merge_wordpieces(tokens, start, end):
    """Join WordPiece tokens[start..end] (inclusive) into readable text.

    Tokens beginning with '##' continue the previous word (the tokenizer splits
    words that are not in its vocab into sub-word pieces), so they are glued on
    without a space; every other token is preceded by a single space.
    """
    answer = tokens[start]
    for token in tokens[start + 1:end + 1]:
        if token.startswith('##'):
            answer += token[2:]
        else:
            answer += ' ' + token
    return answer


def bert_qa(question, context, max_len=500, qa_model=None, qa_tokenizer=None):
    """Extract the answer span for `question` from `context` with a BERT QA model.

    Args:
        question: Question text.
        context: Passage to search for the answer (e.g. text extracted from a PDF).
        max_len: Maximum number of tokens (question + context + special tokens)
            passed to the model; longer inputs are truncated.
        qa_model: Question-answering model to call; defaults to the module-level
            `model`.
        qa_tokenizer: Tokenizer matching `qa_model`; defaults to the module-level
            `tokenizer`.

    Returns:
        Tuple ``(answer_start_index, answer_end_index, start_token_score,
        end_token_score, answer)``. The scores are rounded to 2 decimals. When no
        valid span inside the passage is found, ``answer`` is a "Sorry, ..."
        fallback message instead of passage text.
    """
    if qa_model is None:
        qa_model = model
    if qa_tokenizer is None:
        qa_tokenizer = tokenizer

    # Tokenize question and passage as a pair; the tokenizer adds the special
    # tokens: [CLS] question [SEP] context [SEP]
    input_ids = qa_tokenizer.encode(question, context, max_length=max_len, truncation=True)

    # Everything up to and including the first [SEP] belongs to the question.
    sep_index = input_ids.index(qa_tokenizer.sep_token_id)
    len_question = sep_index + 1
    len_context = len(input_ids) - len_question

    # Segment ids: 0 for question tokens, 1 for context tokens.
    segment_ids = [0] * len_question + [1] * len_context

    tokens = qa_tokenizer.convert_ids_to_tokens(input_ids)

    # One forward pass yields both start and end logits (outputs[0] / outputs[1]);
    # no_grad skips building an autograd graph that inference never uses.
    with torch.no_grad():
        outputs = qa_model(torch.tensor([input_ids]), token_type_ids=torch.tensor([segment_ids]))
    start_token_scores = outputs[0].cpu().numpy().flatten()
    end_token_scores = outputs[1].cpu().numpy().flatten()

    # Start and end index of the answer are the highest-scoring positions.
    answer_start_index = np.argmax(start_token_scores)
    answer_end_index = np.argmax(end_token_scores)

    start_token_score = np.round(start_token_scores[answer_start_index], 2)
    end_token_score = np.round(end_token_scores[answer_end_index], 2)

    answer = _merge_wordpieces(tokens, answer_start_index, answer_end_index)

    # Reject spans that are not a real answer from the passage: starting in the
    # question segment ([CLS], question words or the first [SEP]), a negative
    # start score, only the final [SEP], or an inverted span.
    if (answer_start_index < len_question) or (start_token_score < 0) \
            or (answer == '[SEP]') or (answer_end_index < answer_start_index):
        answer = "Sorry, Couldn't find answer in given pdf. Please try again!"

    return (answer_start_index, answer_end_index, start_token_score, end_token_score, answer)


In [4]:
# import required libraries
import PySimpleGUI as sg # For GUI
from PyPDF2 import PdfReader # TO read pdf

# Build the PySimpleGUI layout one row at a time (each row is a list of
# elements, top to bottom): file picker, question entry, submit button, and an
# sg.Output pane where the answer printed by the event loop is shown.
pick_label_row = [sg.Text('Select a PDF file:')]
pick_file_row = [sg.Input(key='file'), sg.FileBrowse(file_types=(("PDF files", "*.pdf"),))]
question_label_row = [sg.Text('Enter question:')]
question_entry_row = [sg.InputText(key='question')]
submit_row = [sg.Button('Submit')]
answer_label_row = [sg.Text('Answer:')]
answer_output_row = [sg.Output(size=(60, 10))]

layout = [
    pick_label_row,
    pick_file_row,
    question_label_row,
    question_entry_row,
    submit_row,
    answer_label_row,
    answer_output_row,
]

# Instantiate the application window from the layout above.
window = sg.Window('Question-Answering System using BERT', layout)

def _extract_pdf_text(file_path):
    """Return the text of every page of the PDF at `file_path`, pages joined by newlines.

    The file is opened in a `with` block so it is closed even if PdfReader or
    text extraction raises.
    """
    with open(file_path, 'rb') as pdf_file:
        pdf_reader = PdfReader(pdf_file)
        # `or ""` guards pages that yield no text; the newline separator stops the
        # last word of one page from fusing with the first word of the next.
        return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)


# Event loop to process events and get input values. try/finally guarantees
# the window is closed even if an unexpected error escapes the loop.
try:
    while True:
        event, values = window.read()

        # Exit if the window is closed
        if event == sg.WIN_CLOSED:
            break

        if event != 'Submit':
            continue

        file_path = values['file']
        question = values['question'].strip()

        # Validate inputs up front instead of crashing on open('') or asking BERT
        # an empty question.
        if not file_path:
            print("Please select a PDF file.")
            continue
        if not question:
            print("Please enter a question.")
            continue

        # Report missing/unreadable/malformed files in the output pane rather than
        # letting the exception terminate the event loop (PyPDF2 raises its own
        # error types for malformed PDFs, hence the broad catch).
        try:
            text = _extract_pdf_text(file_path)
        except Exception as exc:
            print(f"Could not read PDF: {exc}")
            continue

        if not text.strip():
            print("No extractable text found in the PDF.")
            continue

        _, _, _, _, ans = bert_qa(question, text)
        print(ans)
finally:
    # Close the window
    window.close()
