To start this Jupyter Dash app, run all the cells below in order. Then click the **temporary** URL printed at the end of the last cell to open the app.

In [1]:
!pip install -q jupyter-dash==0.3.0rc1 dash-bootstrap-components transformers

[K     |████████████████████████████████| 45 kB 1.2 MB/s 
[K     |████████████████████████████████| 209 kB 7.8 MB/s 
[K     |████████████████████████████████| 3.8 MB 39.1 MB/s 
[K     |████████████████████████████████| 9.6 MB 20.8 MB/s 
[K     |████████████████████████████████| 895 kB 47.5 MB/s 
[K     |████████████████████████████████| 6.5 MB 32.5 MB/s 
[K     |████████████████████████████████| 67 kB 3.6 MB/s 
[K     |████████████████████████████████| 596 kB 37.6 MB/s 
[K     |████████████████████████████████| 357 kB 37.0 MB/s 
[?25h  Building wheel for retrying (setup.py) ... [?25l[?25hdone


In [4]:
import time

import dash
import dash_html_components as html
import dash_core_components as dcc
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output, State
from jupyter_dash import JupyterDash
from transformers import AutoModelForCausalLM, AutoModelWithLMHead, AutoTokenizer
import torch

The dash_html_components package is deprecated. Please replace
`import dash_html_components as html` with `from dash import html`
  after removing the cwd from sys.path.
The dash_core_components package is deprecated. Please replace
`import dash_core_components as dcc` with `from dash import dcc`
  """


In [5]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

print("Start loading model...")
name = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(name)
# AutoModelWithLMHead is deprecated in transformers; AutoModelForCausalLM is its
# replacement for decoder-only language models such as DialoGPT.
model = AutoModelForCausalLM.from_pretrained(name)

# On GPU, cast the weights to FP16 for faster inference. In every case move the
# model to the selected device and switch it to eval mode (disables dropout).
if device == "cuda":
    model = model.half()
model = model.to(device)
model.eval()

print("Done.")

Device: cuda
Start loading model...


Downloading:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/642 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/0.99M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]



Downloading:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

Done.


In [6]:
def textbox(text, box="other"):
    """Render one chat message as a rounded Bootstrap card.

    Args:
        text: content placed in the card body.
        box: "self" for a right-aligned card in the primary colour with
            inverted text, or "other" for a left-aligned light card.

    Returns:
        A ``dbc.Card`` holding ``text``.

    Raises:
        ValueError: if ``box`` is neither "self" nor "other".
    """
    if box == "self":
        left, right, card_color, invert = "auto", 0, "primary", True
    elif box == "other":
        left, right, card_color, invert = 0, "auto", "light", False
    else:
        raise ValueError("Incorrect option for `box`.")

    # An "auto" margin on one side pushes the bubble to the opposite edge.
    bubble_style = {
        "max-width": "55%",
        "width": "max-content",
        "padding": "10px 15px",
        "border-radius": "25px",
        "margin-left": left,
        "margin-right": right,
    }
    return dbc.Card(text, style=bubble_style, body=True, color=card_color, inverse=invert)

In [2]:
# Downgrade dash-bootstrap-components to a pre-1.0 release (0.13.1 in the log below).
# NOTE(review): this cell ran second (In [2]) but sits below the cells that import and
# use dbc, so a top-to-bottom Run All installs it only after dbc is already imported.
# The pin belongs in the first install cell.
!pip install "dash-bootstrap-components<1"

Collecting dash-bootstrap-components<1
  Downloading dash_bootstrap_components-0.13.1-py3-none-any.whl (197 kB)
[?25l[K     |█▋                              | 10 kB 21.8 MB/s eta 0:00:01[K     |███▎                            | 20 kB 9.6 MB/s eta 0:00:01[K     |█████                           | 30 kB 8.2 MB/s eta 0:00:01[K     |██████▋                         | 40 kB 7.5 MB/s eta 0:00:01[K     |████████▎                       | 51 kB 3.6 MB/s eta 0:00:01[K     |██████████                      | 61 kB 4.3 MB/s eta 0:00:01[K     |███████████▋                    | 71 kB 4.5 MB/s eta 0:00:01[K     |█████████████▎                  | 81 kB 4.6 MB/s eta 0:00:01[K     |███████████████                 | 92 kB 5.1 MB/s eta 0:00:01[K     |████████████████▋               | 102 kB 4.3 MB/s eta 0:00:01[K     |██████████████████▎             | 112 kB 4.3 MB/s eta 0:00:01[K     |████████████████████            | 122 kB 4.3 MB/s eta 0:00:01[K     |█████████████████████▋      

In [7]:
# Scrollable chat transcript; its children are filled by the update_display callback.
conversation = html.Div(
    style={
        "width": "80%",
        "max-width": "800px",
        "height": "70vh",
        "margin": "auto",
        "overflow-y": "auto",
    },
    id="display-conversation",
)

# dash-bootstrap-components 1.0 (Bootstrap 5) removed InputGroupAddon; buttons go
# directly inside an InputGroup. Wrap the button only when the addon class exists,
# so this cell works with both the 0.x and 1.x releases.
submit_button = dbc.Button("Submit", id="submit")
if hasattr(dbc, "InputGroupAddon"):
    submit_button = dbc.InputGroupAddon(submit_button, addon_type="append")

# Text box plus Submit button; both trigger the run_chatbot callback.
controls = dbc.InputGroup(
    style={"width": "80%", "max-width": "800px", "margin": "auto"},
    children=[
        dbc.Input(id="user-input", placeholder="Write to the chatbot...", type="text"),
        submit_button,
    ],
)


# Define app
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
server = app.server  # the underlying Flask server


# Define Layout: title, conversation store, transcript and input controls.
app.layout = dbc.Container(
    fluid=True,
    children=[
        html.H1("Dash Chatbot (with DialoGPT)"),
        html.Hr(),
        # Conversation text; run_chatbot terminates every turn with the tokenizer's EOS token.
        dcc.Store(id="store-conversation", data=""),
        conversation,
        controls,
    ],
)

In [8]:
@app.callback(
    Output("display-conversation", "children"),
    [Input("store-conversation", "data")],
)
def update_display(chat_history):
    """Render the stored conversation as alternating chat bubbles.

    Turns in ``chat_history`` are separated by ``tokenizer.eos_token``.
    Even-indexed turns are the user's (``box="self"``) and odd-indexed turns
    the bot's (``box="other"``). Text after the last EOS token is not shown.
    """
    turns = chat_history.split(tokenizer.eos_token)
    bubbles = []
    # The final split piece is whatever follows the last EOS, so skip it.
    for position, message in enumerate(turns[:-1]):
        side = "other" if position % 2 else "self"
        bubbles.append(textbox(message, box=side))
    return bubbles


# DialoGPT's maximum sequence length (GPT-2 config n_positions = 1024):
# the prompt plus the generated reply must fit in it.
MAX_MODEL_TOKENS = 1024
# Tokens reserved for each bot reply; the prompt is truncated to the remainder.
MAX_REPLY_TOKENS = 200


@app.callback(
    [Output("store-conversation", "data"), Output("user-input", "value")],
    [Input("submit", "n_clicks"), Input("user-input", "n_submit")],
    [State("user-input", "value"), State("store-conversation", "data")],
)
def run_chatbot(n_clicks, n_submit, user_input, chat_history):
    """Append the user's message and the model's reply to the stored conversation.

    Fires when the Submit button is clicked or Enter is pressed in the input box.

    Args:
        n_clicks: click count of the Submit button (unset until first click).
        n_submit: Enter-key count of the input box (unset until first submit).
        user_input: current text of the input box.
        chat_history: stored conversation; every turn ends with ``tokenizer.eos_token``.

    Returns:
        ``(new_chat_history, "")``; the empty string clears the input box.
    """
    # Neither trigger has fired yet (Dash's initial call on page load): start empty.
    # Both counters must be checked; otherwise pressing Enter before ever
    # clicking the button could reset the conversation.
    if not n_clicks and not n_submit:
        return "", ""

    if user_input is None or user_input == "":
        return chat_history, ""

    # The full conversation stays as text; the new user turn is EOS-terminated.
    history_text = (chat_history or "") + user_input + tokenizer.eos_token

    # Feed the model only the most recent tokens so prompt + reply fits in its
    # context window; an untruncated long chat would overflow it.
    input_ids = tokenizer.encode(history_text, return_tensors="pt")
    input_ids = input_ids[:, -(MAX_MODEL_TOKENS - MAX_REPLY_TOKENS):].to(device)
    prompt_len = input_ids.shape[-1]

    output_ids = model.generate(
        input_ids,
        # No padding in a single unpadded sequence, so attend to every token.
        attention_mask=torch.ones_like(input_ids),
        max_length=prompt_len + MAX_REPLY_TOKENS,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens, so earlier turns are not rewritten
    # by an encode/decode round trip. Then always close the reply with EOS so
    # update_display keeps alternating user/bot turns, even when generation
    # stopped at the length limit.
    reply = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
    return history_text + reply + tokenizer.eos_token, ""

In [3]:
# Pin dash to 2.0.0 (the log below shows it replacing dash 2.3.0).
# NOTE(review): this cell ran third (In [3]) but sits below the cells that import dash,
# so a top-to-bottom Run All would install it only after the newer dash was already
# imported; the already-imported module stays in memory until the kernel restarts.
# The pin belongs in the first install cell.
!pip install dash==2.0.0

Collecting dash==2.0.0
  Downloading dash-2.0.0-py3-none-any.whl (7.3 MB)
[K     |████████████████████████████████| 7.3 MB 3.8 MB/s 
Installing collected packages: dash
  Attempting uninstall: dash
    Found existing installation: dash 2.3.0
    Uninstalling dash-2.3.0:
      Successfully uninstalled dash-2.3.0
Successfully installed dash-2.0.0


In [9]:
# Start the app. With mode="external", JupyterDash serves it outside the notebook
# and prints the link shown as "Dash app running on:" in the output below.
if __name__ == "__main__":
    app.run_server(mode="external")

Dash app running on:


<IPython.core.display.Javascript object>