In [39]:
from IPython.display import HTML

HTML('''<script>
code_show=true; 
function code_toggle() {
 if (code_show){
 $('div.input').hide();
 } else {
 $('div.input').show();
 }
 code_show = !code_show
} 
$( document ).ready(code_toggle);
</script>
<form action="javascript:code_toggle()"><input type="submit" value="Toggle code"></form>''')

### Imports

In [1]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from simplet5 import SimpleT5
from transformers import T5Config


from collections import Counter
import numpy as np
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

import random
import fasttext
from tqdm.notebook import tqdm
import ipywidgets as widgets
import fitz

import math
import json

DATASET = "./Dataset/"

Global seed set to 42


Detect which device (CPU/GPU) to use.

In [2]:
seed = 0
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed) 
torch.cuda.manual_seed_all(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda


In [3]:
def contain_let(string):
    return any(char.isalpha() for char in string)

In [4]:
def contain_num(string):
    return any(char.isdigit() for char in string)

In [5]:
def contain_special(string, allowed):
    '''
    allowed is a list containing allowed symbols to pass detection
    '''
    return any(not(char.isalpha() or char.isdigit()) and (char not in allowed) for char in string)

In [6]:
def cleanLine(line, text=True):
    "Text parameter is to indicate whether the line is from text or abstract"

    symbols = ["'", "’"]
    stop_words = list(ENGLISH_STOP_WORDS)
    
    clean_line = line.lower()
    clean_line = clean_line.strip()
    
    # fix apostrophes in line by removing apostrophe with no following alphabet character
    clean_line = clean_line.replace("' ", " ")
    if clean_line and clean_line[-1] == "'":
        clean_line = clean_line[0:len(clean_line)-1]
    # fix apostrophes in line by removing space before single quote
    clean_line = clean_line.replace(" '", "'")
    #remove punctuation

    
    # clean line = clean line remove forms
    words = clean_line.split()

    #  remove forms
    words = [x.replace(x, "") if (contain_let(x) and contain_num(x))
             or contain_special(x, symbols)
             else x for x in words]
    # remove empty strings
    words = filter(None, words)

    # stop words from sklearn, remove stop words
    if text:
        words = [x for x in words if not x in stop_words]

    # combine the items into 1 string
    clean_line = ' '.join(words)
    return clean_line

# FastText

## Model for querying

In [36]:
# model = fasttext.train_unsupervised('data/fil9')
# model.save_model("result/fil9.bin")
fastmodel = fasttext.load_model("fastText_full_NoStop.bin")
# EXPECTS TEXT FILE FORMAT WHERE 1 TEXT PER ROW
# IMPORTANT, NEED TO REDESIGN CLEANING PROCESS 



# Model Finalisation and Evaluation

## Inference Script

## Upload Button

In [8]:
TEXT_COLOUR = "Black"
doc_len = 0
def get_colour(text):
    return r'\(\color{'+TEXT_COLOUR+'} {' + text  + '}\)'

In [9]:
def load_pdf(file):
    file_content = list(file)[0]
    content = file_content["content"].tobytes()
    return fitz.open(stream=content, filetype="pdf")

In [10]:
uploader = widgets.FileUpload(multiple=False, accept='.pdf')
def on_upload(change):
    #ch = change.value.values()
    # to obtain information regarding the pdf when the pdf is uploaded, buttonless design, will impact perf
    global doc_len
    print(uploader.value[0].content)
    doc_len = len(load_pdf(uploader.value))-1
    
    uploader.value.clear()
    uploader._counter=1

#uploader.observe(on_upload, 'value')
display(uploader)

FileUpload(value=(), accept='.pdf', description='Upload')

## Abstract Generation tab

In [11]:
radio = widgets.RadioButtons(
    options=['Base'],
    description=get_colour("Model:"),
    disabled=False
)

In [12]:
# Max Length
length = widgets.IntText(
    value=512,
    description=get_colour("Max length"),
    disabled=False
)

In [13]:
# Beam number
beam = widgets.IntText(
    value=5,
    min=2,
    description=get_colour("Beam Number"),
    disabled=False
)

In [14]:
# Tok K
topk = widgets.IntText(
    value=50,
    description=get_colour("TopK"),
    disabled=False
)

In [15]:
# Repetition penalty
repetition = widgets.FloatText(
    value=2.5,
    description=get_colour("Repetition Pen"),
    disabled=False
)

In [16]:
# Length penalty
length_pen = widgets.FloatText(
    value=1.0,
    description=get_colour("Length Pen"),
    disabled=False
)

In [17]:
# Temperature
temperature = widgets.BoundedFloatText(
    value=1,
    min=0.1,
    max=1.0,
    step=0.1,
    description=get_colour("Temperature"),
    disabled=False
)

In [18]:
# PDF start and end
pdf_range = widgets.IntRangeSlider(
    value=[0, doc_len],
    min=0,
    max=doc_len,
    step=1,

    description=get_colour("Page Range"),
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)

In [19]:
button = widgets.Button(
    description='Click me',
    disabled=False,
    button_style='info', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Click me',
    icon='check' # (FontAwesome names without the `fa-` prefix)
)
#button.style.button_color="green"
out = widgets.Output(layout={'border': '1px solid black'})

def on_button_clicked(b):
    with out:
        out.clear_output()
        #print(i," Button clicked.")

button.on_click(on_button_clicked)

In [20]:
widget = [radio, length, beam, topk, repetition, length_pen, temperature, pdf_range, button]
grid = widgets.GridBox(widget, layout=widgets.Layout(grid_template_columns="repeat(2, 300px)"))

## Relevance Query tab

In [21]:
text_box = widgets.Text(
    value='',
    placeholder='Type something',
    description=get_colour("Query"),
    disabled=False   
)

In [22]:
pdf_range_q = widgets.IntRangeSlider(
    value=[0, doc_len],
    min=0,
    max=doc_len,
    step=1,

    description=get_colour("Page Range"),
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)

In [23]:
button_q = widgets.Button(
    description='Click me',
    disabled=False,
    button_style='info', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Click me',
    icon='check' # (FontAwesome names without the `fa-` prefix)
)
#button.style.button_color="green"
out_q = widgets.Output(layout={'border': '1px solid black'})

def on_button_clicked(b):
    with out_q:
        out_q.clear_output()
        print("hi")
        #print(j," Button clicked.")

button_q.on_click(on_button_clicked)

In [24]:
widget_q = [text_box, pdf_range_q, button_q]
grid_q = widgets.GridBox(widget_q, layout=widgets.Layout(grid_template_columns="repeat(1, 100px)"))

# -----------------------

# Final combined tab

In [25]:
grids = [grid, grid_q]
tabs = widgets.Tab()
tabs.children=grids
tabs_titles = ["Abstract Generation", "Relevance Query"]
for i, title in enumerate(tabs_titles):
    tabs.set_title(i, title)
display(tabs, out_q)

Tab(children=(GridBox(children=(RadioButtons(description='\\(\\color{Black} {Model:}\\)', options=('Base',), v…

Output(layout=Layout(border_bottom='1px solid black', border_left='1px solid black', border_right='1px solid b…

# -----------------------

# Inference

In [26]:
def readPDF(file, start_page=0, end_page=-1):
    '''
    Will accept input of widget pdf file
    start_page indicates starting page to start reading default=first page
    end_page indicate last page to read default=last page
    '''
    doc = load_pdf(file)
    if end_page == -1:
        end_page = len(doc)-1
    
    # extracting text from page
    doc_text = ""
    for page in doc.pages(start_page, end_page):
        
        text = page.get_text("text")
        text = text.split('\n')
        text = " ".join(text)
        doc_text += text 
    doc_text = doc_text.strip()
    doc.close()
    return doc_text

In [27]:
def Process(raw_data, stop_word=True):
    return cleanLine(raw_data, stop_word)

In [28]:
def GetModel(path):
    if path=="Base":
        path = "outputs/best"
    model = SimpleT5()
    model.load_model(model_type="t5", model_dir=path, use_gpu=True)
    return model

In [263]:
text = readPDF(uploader.value)
clean_text = Process(text)
model = GetModel(radio.value)
size = length.value
temp = 0.0#temperature.value
beams = beam.value
top_k = topk.value
repetition_penalty = repetition.value
length_penalty = length_pen.value
abstract = Infer(model, clean_text, size, temp, beams, top_k, repetition_penalty, length_penalty)
abstract

'paper introduces new metric distance tween text approach leverages recent sults mikolov et introduced novel model learns semantically tionships word mover’s distance measures dissimilarity text document weighted point cloud embeddings document distances compute exact provably neighbor search lower bound distances yields state art results wmd outperforms baselines'

In [262]:
def Infer(model, data, max_size, temperature, beams, top_k, repetition_penalty, length_penalty):
    model_config = model.model.config.to_dict()
    # switch to beam search
    if temperature == 0.0:
        model_config["do_sample"] = False
    else:
        model_config["do_sample"] = True
    model_config["temperature"] = temperature #0-1, >0 for more random
    config = T5Config(**model_config)
    model.model.config = config

    
    output = model.predict(data, max_length=max_size, num_beams=beams, top_k=top_k, 
                            repetition_penalty=repetition_penalty,length_penalty=length_penalty)
    
    return output[0]

In [None]:
raw_data = readPDF(uploader.value, start_page=0, end_page=-1) # Obtain 1 long string
data = Process(raw_data) # Clean string and convert to word_ids
output = Infer(best_model, data, max_size=100, temperature=0.5)

# Relevance Query

In [31]:
# get user input
user_input = text_box.value
user_input

'4'

In [32]:
# get input pdf as string
raw_data = readPDF(uploader.value, start_page=0, end_page=-1) # Obtain 1 long string
data = Process(raw_data)

In [425]:
from bisect import insort
# loop through pdf with context window
# update highest score and update topk score
def get_similarity(query, data, window=None, k=5):
    '''
    query: the terms to be searched
    data: the data to be compared with the query
    window: the size for each similarity comparison, default = query length
    k: relates to how many good hits you want, the higher the k, requires more good similarity scores
       default = 5
    '''
    query_vector = fastmodel.get_sentence_vector(query)
    data = "machine and learning asdilasdasdasd1234567890abcdefghi"
    if window == None:
        window = len(query)

    if window > len(data):
        window = len(data)
    print(window)
    best = 0
    topk = [0] * k
    print(data)
    for head in range(0, len(data), window):
        print(head)
        tail = (head+window)
        print(tail)

        chunk = data[head:tail]
        print([chunk])
        chunk_vector = fastmodel.get_sentence_vector(chunk)
        temp_similarity = get_cosine(query_vector, chunk_vector)
        print(temp_similarity)
        if temp_similarity > best:
            best = temp_similarity
    
        if temp_similarity > topk[0]:
            # keep the list size to k elements
            # insort method allows for more efficient insertion and sorting
            insort(topk, temp_similarity)
            topk = topk[-k:]
    topk.reverse()
    return best, topk

In [426]:
from numpy.linalg import norm
def get_cosine(query, key):
    return np.dot(query,key)/(norm(query)*norm(key))

In [427]:
# calculate mixed score with alpha
def get_mix_score(best, topk, alpha=0.5):
    return (alpha * best) + ((1-alpha) * np.average(topk))
    

In [432]:
best, top = get_similarity(query="machine learning", data=None, window=25)
print(best)
print(top)

25
machine and learning asdilasdasdasd1234567890abcdefghi
0
25
['machine and learning asdi']
0.84383076
25
50
['lasdasdasd1234567890abcde']
0.3307465
50
75
['fghi']
0.35131034
0.84383076
[0.84383076, 0.35131034, 0.3307465, 0, 0]


In [359]:
get_mix_score(best, top)

0.5459473729133606

In [360]:
# Ouput score
output.print(mix_score)

NameError: name 'output' is not defined