In [1]:
from IPython.display import HTML

# Render a "Toggle code" button. The embedded script hides every `div.input`
# element once the document is ready and flips their visibility on each click.
# NOTE(review): the script relies on a global `$` (jQuery) being provided by the
# notebook front-end — confirm it is available in the front-end you use.
HTML('''<script>
code_show=true; 
function code_toggle() {
 if (code_show){
 $('div.input').hide();
 } else {
 $('div.input').show();
 }
 code_show = !code_show
} 
$( document ).ready(code_toggle);
</script>
<form action="javascript:code_toggle()"><input type="submit" value="Toggle code"></form>''')

In [2]:
import torch
from simplet5 import SimpleT5
from transformers import T5Config

import numpy as np
from numpy.linalg import norm
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

import random
from fasttext.FastText import _FastText as fasttext
import ipywidgets as widgets
from IPython.display import display, clear_output
import fitz
from bisect import insort

import math
import json

# Path of the fastText binary model loaded into `fastmodel` below.
fasttext_path = "fastText_full_NoStop.bin"
# Directory of the T5 checkpoint that GetModel loads when "Base" is selected.
base_model_path = "outputs/best"

# Seed Python's and torch's random number generators (CPU and all CUDA devices).
seed = 0
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed) 
torch.cuda.manual_seed_all(seed)
# Preferred torch device; not referenced by the other cells visible in this notebook.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# fastText model used by get_similarity to embed the query and the text chunks.
fastmodel = fasttext(fasttext_path)

# Colour name inserted into widget labels by get_colour.
TEXT_COLOUR = "Black"
# Index of the last page of the uploaded PDF; overwritten by on_upload.
doc_len = 0
def get_colour(text):
    """Wrap ``text`` in an inline ``\\color`` math directive using ``TEXT_COLOUR``.

    Parameters:
        text: label string to colour.

    Returns:
        The string ``\\(\\color{<TEXT_COLOUR>} {<text>}\\)``.
    """
    # Both backslash-bearing pieces are raw strings: a plain '}\)' literal is an
    # invalid escape sequence and triggers a warning on modern Python versions.
    return r'\(\color{' + TEXT_COLOUR + '} {' + text + r'}\)'

Global seed set to 42


In [9]:
# Snapshot the packages installed in the current environment into requirements.txt.
# NOTE(review): this cell's execution count (9) is out of sequence with the rest
# of the notebook; it is not needed for the GUI to work.
pip freeze > requirements.txt

Note: you may need to restart the kernel to use updated packages.




In [3]:
def load_pdf(file):
    """Open the first uploaded file as an in-memory PDF document.

    Parameters:
        file: iterable of upload entries (the FileUpload widget value); the
            first entry must map "content" to an object exposing ``tobytes()``.

    Returns:
        The document object returned by ``fitz.open``. The caller is
        responsible for closing it.
    """
    uploads = tuple(file)
    pdf_bytes = uploads[0]["content"].tobytes()
    return fitz.open(filetype="pdf", stream=pdf_bytes)

In [4]:
# Single-file upload widget restricted to .pdf; its value is read by on_upload,
# abstract_generate and relevance_query.
uploader = widgets.FileUpload(multiple=False, accept='.pdf')
def on_upload(change):
    """Handle a change of the upload widget's value and (re)build the GUI.

    Stores the index of the last page of the uploaded PDF in the global
    ``doc_len`` (used as the upper bound of the page-range sliders), then
    clears the cell output, redisplays the uploader and calls ``start_GUI``.

    Parameters:
        change: the change notification passed by ``observe``; unused.
    """
    global doc_len
    if not uploader.value:
        # Nothing has been uploaded (or the value was emptied): there is no
        # document to inspect, so leave the current GUI untouched.
        return

    # Open the document only to count its pages, and always release it again.
    doc = load_pdf(uploader.value)
    try:
        doc_len = len(doc) - 1
    finally:
        doc.close()

    # Best-effort clean-up kept from the original code; it is allowed to fail
    # silently, but no longer swallows KeyboardInterrupt/SystemExit.
    try:
        uploader.value[1].clear()
    except Exception:
        pass

    clear_output()
    display(uploader)
    start_GUI()

# Call on_upload whenever the widget's `value` trait changes, then show the
# upload button; the rest of the GUI is built by on_upload after an upload.
uploader.observe(on_upload, 'value')
display(uploader)

FileUpload(value=(), accept='.pdf', description='Upload')

### READ ME
1) Run every cell with Cell -> Run All <br>
2) To begin, upload an academic paper as a .pdf file <br>
3) To change the fastText or generation model, toggle the code on and edit the corresponding path
<br>
<br>
__Abstract Generation__
- Model - Default is T5 Transformer
- Maxlength - maximum length of generated abstract
- Beam Number - the larger the beam size, the wider the search (more resource intensive)
- TopK - how many of the highest-probability tokens to retain; a higher value may offer better performance
- Repetition Penalty - determines how high the penalty is for repeat tokens
- Length Penalty - determines the penalty for increasing abstract length
- Temperature - determines how random generation is, default = 1
- Page Range - page numbers of input pdf to select, default = 0 - max page

(Temperature only has an effect when Beam Number is 0)
<br>
<br>

__Relevance Query__
- Query - the term to search for relevance
- Page Range - page numbers of input pdf to select, default = 0 - max page
- Window size - the size, in characters, of each chunk of text compared against the query
- K - how many of the top similarity scores are retained
- Alpha - the weight of the top similarity score versus the average of the top K similarity scores, default = 0.5
- Alpha = 1 uses only the top similarity score, Alpha = 0 uses only the average of the top K similarity scores


In [5]:
def start_GUI():
    """Build and display the two-tab control panel for the uploaded PDF.

    The "Abstract Generation" tab holds the generation settings and a
    "Generate" button; the "Relevance Query" tab holds the query settings and
    a "Search" button. Each button prints its results into its own output
    area below the tabs.

    The input widgets are published as module-level globals because
    ``abstract_generate`` and ``relevance_query`` read their ``.value``
    attributes directly. Page sliders span ``0..doc_len``.
    """
    global radio, length, beam, topk, repetition, length_pen, temperature, pdf_range
    global text_box, pdf_range_q, window, k, alpha

    def make_page_slider():
        """Return a page-range slider covering page 0 up to ``doc_len``."""
        return widgets.IntRangeSlider(
            value=[0, doc_len],
            min=0,
            max=doc_len,
            step=1,
            description=get_colour("Page Range"),
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='d',
        )

    def make_action_button(label):
        """Return an 'info'-styled button whose caption and tooltip are ``label``."""
        return widgets.Button(
            description=label,
            disabled=False,
            button_style='info',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip=label,
            icon='check'  # (FontAwesome names without the `fa-` prefix)
        )

    # ---- Abstract generation controls ----
    radio = widgets.RadioButtons(
        options=['Base'],
        description=get_colour("Model:"),
        disabled=False
    )
    # Max length
    length = widgets.IntText(
        value=512,
        description=get_colour("Max length"),
        disabled=False
    )
    # Beam number. IntText is unbounded, so the former `min=2` argument was not
    # a widget trait and has been dropped; 0 stays enterable, which is the
    # value Infer uses to switch sampling on.
    beam = widgets.IntText(
        value=5,
        description=get_colour("Beam Number"),
        disabled=False
    )
    # Top K
    topk = widgets.IntText(
        value=50,
        description=get_colour("TopK"),
        disabled=False
    )
    # Repetition penalty
    repetition = widgets.FloatText(
        value=2.5,
        description=get_colour("Repetition Pen"),
        disabled=False
    )
    # Length penalty
    length_pen = widgets.FloatText(
        value=1.0,
        description=get_colour("Length Pen"),
        disabled=False
    )
    # Temperature
    temperature = widgets.BoundedFloatText(
        value=1.0,
        min=1.0,
        step=0.1,
        description=get_colour("Temperature"),
        disabled=False
    )
    # PDF start and end pages for generation
    pdf_range = make_page_slider()
    button = make_action_button('Generate')
    out = widgets.Output(layout={'border': '1px solid black'})

    def on_generate_clicked(b):
        """Generate an abstract and print it into the generation output area."""
        with out:
            out.clear_output()
            abstract = abstract_generate()
            print("abstract: ", abstract)
    button.on_click(on_generate_clicked)

    # ---- Relevance query controls ----
    text_box = widgets.Text(
        value='',
        placeholder='Type something',
        description=get_colour("Query"),
        disabled=False,
        continuous_update=True  # was misspelt `continous_update`
    )
    # PDF start and end pages for the query
    pdf_range_q = make_page_slider()
    # Window
    window = widgets.IntText(
        value=10,
        description=get_colour("Window"),
        disabled=False
    )
    # K
    k = widgets.IntText(
        value=5,
        description=get_colour("K"),
        disabled=False
    )
    # Alpha
    alpha = widgets.BoundedFloatText(
        value=0.5,
        min=0.0,
        max=1.0,
        step=0.05,
        description=get_colour("Alpha"),
        disabled=False
    )
    button_q = make_action_button('Search')
    out_q = widgets.Output(layout={'border': '1px solid black'})

    def on_search_clicked(b):
        """Run the relevance query and print its scores into the query output area."""
        with out_q:
            out_q.clear_output()
            best, top, quote, mixed = relevance_query()
            print("Best score: ", best)
            print("Top scores: ", top)
            print("Best result: ", quote)
            print("Mixed score: ", mixed)
    button_q.on_click(on_search_clicked)

    # ---- Layout: one grid per tab, outputs shown underneath ----
    widget = [radio, length, beam, topk, repetition, length_pen, temperature, pdf_range, button]
    grid = widgets.GridBox(widget, layout=widgets.Layout(grid_template_columns="repeat(2, 300px)"))
    widget_q = [text_box, pdf_range_q, window, k, alpha, button_q]
    grid_q = widgets.GridBox(widget_q, layout=widgets.Layout(grid_template_columns="repeat(1, 100px)"))
    tabs = widgets.Tab()
    tabs.children = [grid, grid_q]
    for i, title in enumerate(["Abstract Generation", "Relevance Query"]):
        tabs.set_title(i, title)
    display(tabs, out, out_q)

In [6]:

def contain_let(string):
    """Return True when at least one character of ``string`` is alphabetic."""
    for char in string:
        if char.isalpha():
            return True
    return False
def contain_num(string):
    """Return True when at least one character of ``string`` is a digit."""
    for char in string:
        if char.isdigit():
            return True
    return False
def contain_special(string, allowed):
    '''
    Return True when ``string`` holds a character that is neither a letter,
    nor a digit, nor one of the symbols listed in ``allowed``.

    allowed is a list containing allowed symbols to pass detection
    '''
    every_char_ok = all(
        char.isalpha() or char.isdigit() or char in allowed
        for char in string
    )
    return not every_char_ok
def cleanLine(line, text=True):
    """Normalise a line of text into a space-separated string of kept words.

    Steps: lower-case and strip the line; drop apostrophes that are followed
    by a space or end the line; remove a space that precedes an apostrophe;
    discard every word that mixes letters with digits or that contains a
    character other than letters, digits and the apostrophes ' and ’;
    optionally discard English stop words.

    Parameters:
        line: the raw string to clean.
        text: whether the line is from text (True, stop words are removed)
            or from an abstract (False, stop words are kept).

    Returns:
        The cleaned words joined by single spaces.
    """
    allowed_symbols = ["'", "’"]

    clean_line = line.lower().strip()

    # Fix apostrophes: remove an apostrophe with no following alphabet
    # character (before a space, or at the very end of the line) ...
    clean_line = clean_line.replace("' ", " ")
    if clean_line.endswith("'"):
        clean_line = clean_line[:-1]
    # ... and remove the space before a single quote.
    clean_line = clean_line.replace(" '", "'")

    # Drop "forms": tokens mixing letters and digits, or holding other symbols.
    words = [
        word for word in clean_line.split()
        if not ((contain_let(word) and contain_num(word))
                or contain_special(word, allowed_symbols))
    ]

    # Test membership against the stop-word set itself instead of copying it
    # into a list on every call and scanning that list once per word.
    if text:
        words = [word for word in words if word not in ENGLISH_STOP_WORDS]

    return ' '.join(words)

def readPDF(file, start_page=0, end_page=-1):
    '''
    Extract the text of a page range of the uploaded PDF as one long string.

    Will accept input of widget pdf file
    start_page indicates starting page to start reading default=first page
    end_page indicate last page to read (inclusive) default=last page

    Line breaks inside each page are replaced by spaces, the page texts are
    concatenated in order, and the result is stripped of surrounding
    whitespace.
    '''
    doc = load_pdf(file)
    try:
        if end_page == -1:
            end_page = len(doc) - 1

        # Document.pages() takes its bounds like range(), i.e. the stop value
        # is exclusive. end_page is the last page to read, so one is added;
        # otherwise the final page is dropped and a one-page PDF yields no text.
        page_texts = [
            " ".join(page.get_text("text").split('\n'))
            for page in doc.pages(start_page, end_page + 1)
        ]
    finally:
        # Release the document even when extraction fails part-way through.
        doc.close()
    return "".join(page_texts).strip()

def Process(raw_data, stop_word=True):
    """Clean ``raw_data`` with ``cleanLine``; ``stop_word`` toggles stop-word removal."""
    cleaned = cleanLine(raw_data, text=stop_word)
    return cleaned

def GetModel(path):
    """Load a T5 model through SimpleT5 and return the wrapper.

    Parameters:
        path: model directory, or the literal "Base" to load the checkpoint
            at ``base_model_path``.

    Returns:
        The ``SimpleT5`` instance after ``load_model`` has been called.
    """
    if path == "Base":
        path = base_model_path
    model = SimpleT5()
    # Request the GPU only when one is present, mirroring the CPU fallback the
    # notebook uses for `device`, instead of always passing use_gpu=True.
    model.load_model(model_type="t5", model_dir=path, use_gpu=torch.cuda.is_available())
    return model

def Infer(model, data, max_size, temperature, beams, top_k, repetition_penalty, length_penalty):
    """Reconfigure the model for this request and return its first prediction.

    The wrapped model's config is rebuilt as a fresh ``T5Config`` carrying the
    requested ``temperature`` and a ``do_sample`` flag that is True only when
    ``beams`` is 0 (any other beam count selects beam search). ``data`` is
    then passed to ``model.predict`` with the remaining settings.

    Returns:
        The first element of the sequence returned by ``model.predict``.
    """
    settings = model.model.config.to_dict()
    # Sampling is enabled exactly when no beams are requested.
    settings["do_sample"] = (beams == 0)
    settings["temperature"] = temperature
    model.model.config = T5Config(**settings)

    predictions = model.predict(
        data,
        max_length=max_size,
        num_beams=beams,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
    )
    return predictions[0]

def abstract_generate():
    """Generate an abstract for the uploaded PDF from the current widget values.

    Reads the generation settings from the module-level widgets, extracts and
    cleans the text of the selected page range, loads the selected model and
    returns the abstract produced by ``Infer``. A slider end value of 0 is
    forwarded to ``readPDF`` as -1.
    """
    first_page, last_page = pdf_range.value[0], pdf_range.value[1]
    if last_page == 0:
        last_page = -1

    raw_text = readPDF(uploader.value, start_page=first_page, end_page=last_page)
    clean_text = Process(raw_text)
    summariser = GetModel(radio.value)
    return Infer(
        summariser,
        clean_text,
        length.value,
        temperature.value,
        beam.value,
        topk.value,
        repetition.value,
        length_pen.value,
    )

def get_cosine(query, key):
    """Return the cosine similarity between the vectors ``query`` and ``key``.

    Returns 0.0 when either vector has zero norm; the quotient is undefined
    there and would otherwise evaluate to NaN with a division warning.
    """
    denominator = norm(query) * norm(key)
    if denominator == 0:
        return 0.0
    return np.dot(query, key) / denominator

# calculate mixed score with alpha
def get_mix_score(best, topk, alpha=0.5):
    """Blend the best score with the mean of ``topk``.

    ``alpha`` weights ``best`` and ``1 - alpha`` weights the average of
    ``topk``.
    """
    mean_top = np.average(topk)
    weighted_best = alpha * best
    weighted_mean = (1 - alpha) * mean_top
    return weighted_best + weighted_mean

# loop through pdf with context window
# update highest score and update topk score
def get_similarity(query, data, window=None, k=5):
    '''
    Slide over ``data`` in consecutive, non-overlapping chunks and score each
    chunk against ``query`` by cosine similarity of their fastText sentence
    vectors.

    query: the terms to be searched
    data: the data (a string) to be compared with the query
    window: the size, in characters, of each chunk compared with the query,
            default = query length; capped at the length of ``data``
    k: relates to how many good hits you want, the higher the k, requires more good similarity scores
       default = 5

    Returns a tuple ``(best, topk, best_quote)``: the highest similarity
    found, the k highest similarities in descending order (padded with 0
    when fewer than k chunks score above 0), and the chunk that achieved
    ``best``. Empty ``data`` yields ``(0, [0] * k, "")``.

    Raises ValueError when ``k`` or the effective window is smaller than 1.
    '''
    if k < 1:
        # With k == 0 the score list is empty and indexing it fails below.
        raise ValueError("k must be a positive integer")

    query_vector = fastmodel.get_sentence_vector(query)

    if window is None:
        window = len(query)

    if not data:
        # Nothing to compare against (e.g. the pages contained no text);
        # a zero window would otherwise become an illegal range() step.
        return 0, [0] * k, ""

    window = min(window, len(data))
    if window < 1:
        raise ValueError("window must be a positive integer")

    best = 0
    best_quote = ""
    topk = [0] * k  # kept sorted ascending, so topk[0] is the weakest kept score

    for head in range(0, len(data), window):
        chunk = data[head:head + window]
        chunk_vector = fastmodel.get_sentence_vector(chunk)
        similarity = get_cosine(query_vector, chunk_vector)

        if similarity > best:
            best = similarity
            best_quote = chunk

        if similarity > topk[0]:
            # insort keeps the list ordered; trimming keeps only the k largest.
            insort(topk, similarity)
            topk = topk[-k:]

    topk.reverse()
    return best, topk, best_quote

def relevance_query():
    """Score the uploaded PDF against the query typed into the GUI.

    Reads the query, page range, window size, k and alpha from the
    module-level widgets, extracts and cleans the selected pages, and returns
    ``(best, top, quote, mixed_score)`` where the first three come from
    ``get_similarity`` and the last from ``get_mix_score``. A slider end
    value of 0 is forwarded to ``readPDF`` as -1.
    """
    first_page, last_page = pdf_range_q.value[0], pdf_range_q.value[1]
    if last_page == 0:
        last_page = -1

    # readPDF returns the selected pages as one long string.
    raw_data = readPDF(uploader.value, start_page=first_page, end_page=last_page)
    data = Process(raw_data)

    best, top, quote = get_similarity(
        query=text_box.value,
        data=data,
        window=window.value,
        k=k.value,
    )
    return best, top, quote, get_mix_score(best, top, alpha.value)