# Writing Assistant Demo Website

This Google Colab notebook hosts the demo website and uses Colab's GPU resources. To run the site, follow these steps:

1. Change the runtime type to GPU: Runtime -> Change runtime type -> GPU
2. Run all cells and wait until the website is shown within the output of the last cell.
3. Right-click on the last cell to open the menu and select "View output fullscreen".

In [None]:
# load model and config from public folder https://drive.google.com/drive/folders/1XOWxF54WWqIHRbWQ0YUDf7qDt3LbIrsE?usp=sharing
! echo "Prepare download of T5 model..."
! mkdir t5_model_2
%cd t5_model_2
! echo "download config of model"
! gdown "https://drive.google.com/uc?id=1RrgKJAXK3lujRKSfCgo7CA2bqvOQUhup"
! echo "download model"
! gdown "https://drive.google.com/uc?id=1RXLJgNpjBKMmc5-F781NZBmxRV9WJ2xM"
%cd ..

# set path to T5 model
best_model_path = "/content/t5_model_2/"

# install libraries
! echo "installing libraries..."
!pip install -q sentencepiece
!pip install -q torch==1.4.0
!pip install -q transformers==2.9.0
!pip install -q pytorch_lightning==0.7.5
!pip install -q jupyter-dash dash_bootstrap_components

# Install localtunnel
!npm install -g localtunnel

In [None]:
## prepare for model execution

import torch
from transformers import T5ForConditionalGeneration,T5Tokenizer

def set_seed(seed):
  """Seed PyTorch's random number generator with ``seed``.

  Only ``torch.manual_seed`` is called. No separate CUDA seeding is done
  here; the PyTorch reproducibility notes describe ``torch.manual_seed`` as
  seeding all devices (verify for the pinned torch version).
  """
  torch.manual_seed(seed)



# Fix the RNG seed before anything else in this cell runs.
set_seed(42)

# Weights come from the downloaded folder; the tokenizer is the stock
# 't5-large' one fetched by name.
model = T5ForConditionalGeneration.from_pretrained(best_model_path)
tokenizer = T5Tokenizer.from_pretrained('t5-large')

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
print ("device ",device)
model = model.to(device)

In [None]:
import nltk
# Fetch the 'punkt' resource; presumably this is what
# nltk.tokenize.sent_tokenize (called in process_input) needs.
# NOTE(review): newer NLTK releases may look for 'punkt_tab' instead - confirm
# against the NLTK version installed in the runtime.
nltk.download('punkt')

def infer(sentence="", n_sentences=1, max_len=256):
  """Generate paraphrases of ``sentence`` by sampling from the T5 model.

  Relies on the notebook-level ``tokenizer``, ``model`` and ``device``.

  Args:
    sentence: Text to paraphrase. ``None``, empty or whitespace-only input
      returns an empty list without invoking the model.
    n_sentences: Number of sequences requested from ``model.generate``.
    max_len: Value passed to ``model.generate`` as ``max_length``
      (256 by default, the value that was previously hard-coded).

  Returns:
    A list of decoded candidates in generation order. Exact duplicates and
    candidates equal to the input (ignoring case) are dropped, so the list
    can be shorter than ``n_sentences`` or empty.
  """
  # Nothing to paraphrase: skip tokenization and generation entirely.
  if not sentence or not sentence.strip():
    return []

  text = "paraphrase: " + sentence + " </s>"

  # NOTE(review): no max_length is passed to the tokenizer, so the padded
  # width is whatever the tokenizer defaults to - confirm this is intended.
  encoding = tokenizer.encode_plus(text, pad_to_max_length=True, return_tensors="pt")
  input_ids = encoding["input_ids"].to(device)
  attention_masks = encoding["attention_mask"].to(device)

  # Sampled generation (do_sample=True) with top_k=120 and top_p=0.95.
  sampled_outputs = model.generate(
      input_ids=input_ids, attention_mask=attention_masks,
      do_sample=True,
      max_length=max_len,
      top_k=120,
      top_p=0.95,
      early_stopping=True,
      num_return_sequences=n_sentences
  )

  final_outputs = []
  for sampled_output in sampled_outputs:
    sent = tokenizer.decode(sampled_output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    # Drop echoes of the input (case-insensitive) and repeated candidates.
    if sent.lower() != sentence.lower() and sent not in final_outputs:
      final_outputs.append(sent)
  return final_outputs

def process_input(input_text, n_sentences):
    """Paraphrase ``input_text`` sentence by sentence.

    The text is split with ``nltk.tokenize.sent_tokenize``; every sentence is
    passed to ``infer`` together with ``n_sentences``.

    Returns:
        One paragraph per input sentence, separated by blank lines. Within a
        paragraph the candidates are joined by " || ". Falsy input (``None``
        or an empty string) yields "".
    """
    if input_text:
        paragraphs = [
            " || ".join(infer(part, n_sentences))
            for part in nltk.tokenize.sent_tokenize(input_text)
        ]
        return "\n\n".join(paragraphs)
    return ""


In [None]:
# build website

#import dash
from jupyter_dash import JupyterDash
import dash_core_components as dcc
import dash_html_components as html
from dash_html_components.Button import Button
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output, State
import plotly.express as px

import pandas as pd

# Create the Dash app, styled with the Bootstrap theme shipped by
# dash_bootstrap_components.
app = JupyterDash(
    __name__,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
)

# NOTE(review): `df` is not referenced anywhere else in the visible notebook;
# this download looks like leftover tutorial code - confirm before removing.
df = pd.read_csv(
    'https://raw.githubusercontent.com/plotly/datasets/master/gapminderDataFiveYear.csv'
)

##########


# Page layout, top to bottom: header banner, candidate-count slider,
# source/target text areas side by side, and the trigger button.
app.layout = dbc.Container(
    fluid=False,
    children=[
        # --- Header banner -------------------------------------------------
        dbc.Jumbotron([
            html.H1("Writing Assistant", className="display-3"),
            html.P(
                "Apply text style transfer models to  "
                "improve the scientificity.",
                className="lead",
            ),
            html.Hr(className="my-2"),
            html.P(
                "Researched and developed by Master Research Project group 11 "
                "at Maastricht University."
            ),
            html.P(
                dbc.Button(
                    "Learn more",
                    color="primary",
                    href="https://www.maastrichtuniversity.nl/meta/415976/research-project-dsdm-1",
                    target="new_tab",
                ),
                className="lead",
            ),
        ]),
        # --- Candidates per sentence (marks at 1..5, initial value 2) -------
        dcc.Slider(
            id="n_sentences",
            min=1,
            max=5,
            step=None,
            value=2,
            marks={n: str(n) for n in range(1, 6)},
        ),
        # --- Input (left) and output (right) text areas ---------------------
        dbc.Row([
            dbc.Col([
                dbc.Textarea(
                    id="source-text",
                    placeholder="Add your non-scientific text here...",
                    style={"margin-top": "15px", "height": "30vh"},
                ),
            ]),
            dbc.Col([
                dbc.Textarea(
                    id="target-text",
                    placeholder="Different output sentences divided by ||",
                    style={"margin-top": "15px", "height": "30vh"},
                ),
            ]),
        ]),
        # --- Button whose clicks trigger the callback -----------------------
        dbc.Row([
            dbc.Col([
                dbc.Button(
                    "Translate",
                    id="button-translate",
                    className="mr-2",
                    style={"margin-top": "15px"},
                ),
            ]),
        ]),
    ],
)



@app.callback(
    Output('target-text', 'value'),
    Input('button-translate', 'n_clicks'),
    State('source-text', 'value'),
    State('n_sentences', 'value'),
)
def update_output_div(n_clicks, input_value, n_sentences):
    """Fill the target text area with paraphrases of the source text.

    Args:
        n_clicks: Click count of the "Translate" button. It is the callback's
            only Input; its value is not used.
        input_value: Current content of the source text area (State).
        n_sentences: Current slider value (State), forwarded to process_input.

    Returns:
        The string produced by ``process_input``.
    """
    return process_input(input_value, n_sentences)



In [None]:
# Start the Dash server. The call passes mode='external', not an inline mode.
# NOTE(review): the instructions at the top of the notebook say the site is
# shown in the output of the last cell - confirm which behavior is intended.
app.run_server(mode='external')

In [None]:
# Expose local port 8050 through localtunnel (`lt`, installed with npm in the
# setup cell). NOTE(review): 8050 is presumably the port the Dash server
# started above listens on - confirm.
!lt --port 8050