In [39]:
from IPython.display import HTML

# Render a "Toggle code" button together with the script behind it.
# `code_toggle` hides every `div.input` element while `code_show` is true and shows
# them otherwise, then flips the flag. It is registered to run once when the document
# is ready (so inputs start hidden) and again on every form submit.
# NOTE(review): the script uses `$`, so it presumably needs jQuery to be present on
# the page that renders the notebook - confirm for the target front end.
HTML('''<script>
code_show=true; 
function code_toggle() {
 if (code_show){
 $('div.input').hide();
 } else {
 $('div.input').show();
 }
 code_show = !code_show
} 
$( document ).ready(code_toggle);
</script>
<form action="javascript:code_toggle()"><input type="submit" value="Toggle code"></form>''')

### Imports

In [40]:
import json
import math
import random
import string
from collections import Counter

import fasttext
import fitz
import ipywidgets as widgets
import numpy as np
import torch
import torch.nn as nn
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from tqdm.notebook import tqdm

# Path prefix of the dataset directory, relative to the current working directory.
DATASET = "./Dataset/"

Seed the random number generators for reproducibility, then detect which device (CPU/GPU) to use.

In [41]:
# Use one fixed seed for Python's and PyTorch's random number generators
# (CPU and every CUDA device) so that repeated runs draw the same numbers.
seed = 0
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.cuda.manual_seed(seed)

# Prefer the GPU when PyTorch reports one as available, otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: cuda


In [42]:
def contain_let(string):
    """Return True if `string` holds at least one alphabetic character, else False."""
    for char in string:
        if char.isalpha():
            return True
    return False

In [43]:
def contain_num(string):
    """Return True if `string` holds at least one digit character, else False."""
    for char in string:
        if char.isdigit():
            return True
    return False

In [44]:
def cleanLine(line, text=True):
    """Normalise one raw line into lowercase tokens separated by single spaces.

    Processing steps, in order:
      1. lowercase the line and strip surrounding whitespace;
      2. drop an apostrophe that is followed by a space or ends the line;
      3. remove the space in front of an apostrophe (attaching it to the previous word);
      4. replace every character other than a-z, 0-9, "'" and "’" with a space;
      5. drop tokens that contain both letters and digits;
      6. when `text` is true, drop sklearn's English stop words.

    Args:
        line: the raw string to clean.
        text: True when the line comes from the body text (stop words are removed),
            False when it comes from an abstract (stop words are kept).

    Returns:
        The cleaned line as a single string.
    """
    # A set gives O(1) membership tests for the per-character filter below.
    valid_chars = set(string.ascii_lowercase + string.digits + "'’")

    clean_line = line.lower().strip()

    # Remove apostrophes that have no following character.
    clean_line = clean_line.replace("' ", " ")
    if clean_line.endswith("'"):
        clean_line = clean_line[:-1]
    # Remove the space before a single quote.
    clean_line = clean_line.replace(" '", "'")

    # Replace punctuation and every other unsupported character with a space.
    clean_line = "".join(ch if ch in valid_chars else " " for ch in clean_line)

    # Drop tokens that mix letters and digits.
    words = [
        word for word in clean_line.split()
        if not (contain_let(word) and contain_num(word))
    ]

    # ENGLISH_STOP_WORDS is already a set-like collection; test membership directly
    # instead of copying it into a list on every call.
    if text:
        words = [word for word in words if word not in ENGLISH_STOP_WORDS]

    return ' '.join(words)

In [45]:
def concatParagraph(paragraph, text):
    """Clean every line of a paragraph and join the results with single spaces.

    Args:
        paragraph: iterable of raw line strings.
        text: forwarded to cleanLine; True for body text (stop words removed),
            False for abstract lines (stop words kept).

    Returns:
        The cleaned paragraph as one stripped string.
    """
    # Clean each line exactly once and honour `text`. A preliminary pass with the
    # default text=True would strip stop words even when the caller asked to keep them.
    cleaned_lines = [cleanLine(line, text) for line in paragraph]
    return " ".join(cleaned_lines).strip()

In [46]:
def concatPaper(paper, text):
    """Clean every paragraph of a paper and join them into one stripped string.

    `paper` is an iterable of paragraphs; each paragraph and the `text` flag are
    handed to concatParagraph unchanged.
    """
    cleaned_paragraphs = [concatParagraph(paragraph, text) for paragraph in paper]
    return " ".join(cleaned_paragraphs).strip()

In [47]:
class Vocabulary(object):
    """Two-way lookup table between words and integer IDs.

    IDs 0-3 are reserved for the special tokens '<start>', '<end>', '<pad>' and
    '<unk>'; every newly added word receives the next free ID.
    """

    def __init__(self):
        special_tokens = ('<start>', '<end>', '<pad>', '<unk>')
        # Both directions of the mapping start out holding only the special tokens.
        self.word2idx = {token: token_id for token_id, token in enumerate(special_tokens)}
        self.idx2word = {token_id: token for token_id, token in enumerate(special_tokens)}
        # Next ID to hand out.
        self.idx = len(special_tokens)

    def add_word(self, word):
        """Register `word` under the next free ID; known words are left untouched."""
        if word in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.idx = new_id + 1

    def __call__(self, word):
        """Return the ID of `word`, or the ID of '<unk>' when the word is unknown."""
        try:
            return self.word2idx[word]
        except KeyError:
            return self.word2idx['<unk>']

    ## added function for utility
    def get_word(self,index):
        """Return the word stored under `index` (KeyError for an unknown index)."""
        return self.idx2word[index]

    def __len__(self):
        """Number of entries in the vocabulary, special tokens included."""
        return len(self.word2idx)


# FastText

## Model for querying

In [49]:
# One-off training step, kept for reference: train unsupervised fastText embeddings
# on 'data/fil9' and save the result to "result/fil9.bin".
# model = fasttext.train_unsupervised('data/fil9')
# model.save_model("result/fil9.bin")
# Load the saved fastText model from disk.
model = fasttext.load_model("result/fil9.bin")
# NOTE: expects a text file with one text per row.
# TODO (important): the cleaning process needs to be redesigned.

# Model Finalisation and Evaluation

## Inference Script

In [None]:
# Load the trained model used for inference.
best_model_path = "best_decoder20230719.pth"
# TODO: import the model class from a separate module to reduce cell clutter.
# `Model`, `text_vocab` and `abstract_vocab` must already be defined when this cell runs.
best_model = Model(len(text_vocab),len(abstract_vocab),2,256)
# Inference only: switch layers such as dropout and batch norm to evaluation behaviour.
best_model.eval()
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine as well.
best_model.load_state_dict(torch.load(best_model_path, map_location=device))

## Upload Button

In [None]:
# HTML font tag prepended to widget descriptions.
TEXT_COLOUR = "<font color='Black'>"
# Index of the last page of the uploaded PDF (page count - 1); stays 0 until the
# upload callback overwrites it.
doc_len = 0

In [None]:
def load_pdf(file):
    """Open the first uploaded file as a PDF document.

    `file` is an iterable of upload records; the "content" entry of the first record
    is handed to fitz as an in-memory PDF stream. Returns the opened fitz document,
    which the caller is responsible for closing.
    """
    first_upload = list(file)[0]
    return fitz.open(stream=first_upload["content"], filetype="pdf")

In [None]:
# Upload widget restricted to a single PDF file.
uploader = widgets.FileUpload(multiple=False, accept='.pdf')


def on_upload(change):
    """Record the last page index of a freshly uploaded PDF in the global `doc_len`.

    Runs whenever the uploader's `value` changes (buttonless design), so the PDF is
    opened on every upload; this costs some performance.
    """
    global doc_len
    # Nothing to inspect when the value is empty (for example after it was cleared).
    if not uploader.value:
        return

    doc = load_pdf(uploader.value.values())
    try:
        doc_len = len(doc) - 1
    finally:
        # Always release the document; it was opened only to count its pages.
        doc.close()
    # NOTE(review): the page-range sliders read doc_len only when they are constructed,
    # so they do not pick up this new value - confirm whether they should be refreshed.

    # Reset the widget's stored upload and its private file counter.
    uploader.value.clear()
    uploader._counter=1


uploader.observe(on_upload, 'value')

## Abstract Generation tab

In [None]:
# Radio buttons labelled "Model:" offering four named options.
# NOTE(review): presumably selects which trained model variant is used for
# generation - no code in this section reads `radio.value` yet.
radio = widgets.RadioButtons(
    options=['Base', '512Dim', '1024Dim', 'Base+No Stop word'],
    description=TEXT_COLOUR+'Model:',
    disabled=False
)

In [None]:
# Max length: integer input labelled "Max length:", initial value 100.
# NOTE(review): presumably the maximum length of the generated abstract - confirm
# against the code that reads `length.value`.
length = widgets.IntText(
    value=100,
    description=TEXT_COLOUR+'Max length:',
    disabled=False
)

In [None]:
# Temperature: float input bounded to [0, 1.0] in steps of 0.1, initial value 0.5.
# NOTE(review): the lower bound allows exactly 0 - confirm the model handles a
# temperature of 0.
temperature = widgets.BoundedFloatText(
    value=0.5,
    min=0,
    max=1.0,
    step=0.1,
    description=TEXT_COLOUR+'Temperature:',
    disabled=False
)

In [None]:
# PDF start and end page for the Abstract Generation tab.
# The slider bounds come from `doc_len` as it is when this cell runs; the widget is
# not rebuilt if `doc_len` changes later.
pdf_range = widgets.IntRangeSlider(
    value=[0, doc_len],
    min=0,
    max=doc_len,
    step=1,

    description=TEXT_COLOUR + "Page Range:",
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)

In [None]:
# Action button for the Abstract Generation tab.
button = widgets.Button(
    description='Click me',
    disabled=False,
    button_style='info', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Click me',
    icon='check' # (FontAwesome names without the `fa-` prefix)
)
#button.style.button_color="green"
# Bordered output area that the click handler writes into.
out = widgets.Output(layout={'border': '1px solid black'})

def on_button_clicked(b):
    """Click handler for `button`: currently only clears the `out` area."""
    with out:
        out.clear_output()
        #print(i," Button clicked.")

button.on_click(on_button_clicked)

In [None]:
# Controls of the Abstract Generation tab, laid out in a single 100px-wide grid column.
widget = [radio, length, pdf_range, temperature, button]
grid = widgets.GridBox(widget, layout=widgets.Layout(grid_template_columns="repeat(1, 100px)"))

## Relevance Query tab

In [None]:
# Free-text input labelled "Query:" for the Relevance Query tab; starts empty.
text_box = widgets.Text(
    value='',
    placeholder='Type something',
    description=TEXT_COLOUR+'Query:',
    disabled=False
)

In [None]:
# Page-range slider for the Relevance Query tab.
# The slider bounds come from `doc_len` as it is when this cell runs; the widget is
# not rebuilt if `doc_len` changes later.
pdf_range_q = widgets.IntRangeSlider(
    value=[0, doc_len],
    min=0,
    max=doc_len,
    step=1,

    description=TEXT_COLOUR + "Page Range:",
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)

In [None]:
# Action button for the Relevance Query tab.
button_q = widgets.Button(
    description='Click me',
    disabled=False,
    button_style='info', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Click me',
    icon='check' # (FontAwesome names without the `fa-` prefix)
)
# Bordered output area that the click handler writes into.
out_q = widgets.Output(layout={'border': '1px solid black'})


# Named distinctly so it does not shadow the `on_button_clicked` handler that is
# already registered for the Abstract Generation button.
def on_button_q_clicked(b):
    """Click handler for `button_q`: clears `out_q`, then prints a placeholder."""
    with out_q:
        out_q.clear_output()
        # Placeholder output until the relevance query is implemented.
        print("hi")


button_q.on_click(on_button_q_clicked)

In [None]:
# Controls of the Relevance Query tab, laid out in a single 100px-wide grid column.
widget_q = [text_box, pdf_range_q, button_q]
grid_q = widgets.GridBox(widget_q, layout=widgets.Layout(grid_template_columns="repeat(1, 100px)"))

# -----------------------

# Final combined tab

In [None]:
# One tab per control grid, titled in the same order as `grids`.
grids = [grid, grid_q]
tabs = widgets.Tab(children=grids)
tabs_titles = ["Abstract Generation", "Relevance Query"]
for i, title in enumerate(tabs_titles):
    tabs.set_title(i, title)
# Show both output areas: `out` receives the Abstract Generation button's output and
# `out_q` the Relevance Query button's. Without `out`, anything written by the first
# handler would never be visible.
display(tabs, out, out_q)

# -----------------------

In [None]:
def readPDF(file, start_page=0, end_page=-1):
    '''
    Extract the text of a page range from an uploaded PDF as one string.

    file: uploaded-file value of the upload widget (handed to load_pdf)
    start_page: zero-based index of the first page to read, default = first page
    end_page: zero-based index of the last page to read (inclusive),
              default -1 = last page of the document

    Returns the text of the selected pages with newlines replaced by spaces and
    surrounding whitespace stripped.
    '''
    doc = load_pdf(file)
    try:
        if end_page == -1:
            end_page = len(doc) - 1

        # Document.pages() follows range() semantics, so its stop argument is
        # exclusive; add 1 so that end_page itself is read as well.
        page_texts = [
            " ".join(page.get_text("text").split('\n'))
            for page in doc.pages(start_page, end_page + 1)
        ]
    finally:
        # Release the document even if text extraction fails.
        doc.close()
    return "".join(page_texts).strip()

In [None]:
def Process(raw_data, stop_word=True):
    """Clean a raw string and convert it to word IDs.

    `stop_word` is passed to cleanLine as its `text` flag, so True (the default)
    removes stop words. The cleaned string is then tokenised against `text_vocab`
    and the result of `tokenise` is returned.
    """
    return tokenise(text_vocab, cleanLine(raw_data, stop_word))

In [None]:
def Infer(model, data, max_size=100, temperature=0.5):
    """Run the model on `data` and decode its output into a sentence.

    Args:
        model: callable invoked as model(data, max_len=..., temperature=...).
        data: model input, passed through unchanged.
        max_size: forwarded to the model as `max_len`.
        temperature: forwarded to the model as `temperature`.

    Returns:
        The decoded words joined by spaces, stripped. Decoding stops at the first
        '<end>' token; '<pad>' tokens contribute only a space.

    Assumes the model output has one step per index along dimension 1 and a batch
    of exactly one sequence (each step is reduced to a single token id) - TODO confirm.
    """
    # IDs of the special tokens as assigned in Vocabulary.__init__.
    END_ID, PAD_ID = 1, 2

    # Pure inference: no gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        decoder_outputs = model(data, max_len=max_size, temperature=temperature)

    words = []
    for step_output in torch.unbind(decoder_outputs, 1):
        token_id = step_output.argmax(1).item()
        if token_id == END_ID:
            break
        # A pad token yields an empty piece, i.e. just the separating space.
        words.append("" if token_id == PAD_ID else abstract_vocab.get_word(token_id))
    return " ".join(words).strip()

In [None]:
# End-to-end run: read the uploaded PDF, clean and tokenise it, then generate text.
# NOTE(review): this needs a PDF to be present in `uploader.value` when the cell
# runs; the upload callback clears that value, so confirm a file is still available.
raw_data = readPDF(uploader.value.values(), start_page=0, end_page=-1) # Obtain 1 long string
data = Process(raw_data) # Clean string and convert to word_ids
output = Infer(best_model, data, max_size=100, temperature=0.5)