In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import ipywidgets as widgets
from IPython.display import display

<figure>
  <img src="the-simpsons.png" alt="The Simpsons"/>
  <figcaption>(c) 20th Century Fox Television</figcaption>
</figure>

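# Chat with the Simpsons

Start a conversation with the Simpsons chatbot! Type a message in the **Input** box and press Enter; your message and the bot's reply are appended to the transcript above it. Click **Reset** to clear the transcript and start a fresh conversation.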

In [None]:
# Hugging Face Hub id of the DialoGPT-medium Simpsons checkpoint. The model and
# its tokenizer are loaded from the same id so their vocabularies match; loading
# may download the files on first run.
MODEL_NAME = "arampacha/DialoGPT-medium-simpsons"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [25]:
# Read-only transcript that fills the chat box; the conversation is appended to it.
output_text = widgets.Textarea(
    placeholder="Start dialog",
    disabled=True,
    layout=widgets.Layout(height='100%', width='99%'),
)
# Single-line prompt shown below the transcript.
input_text = widgets.Text(
    description="Input:",
    layout=widgets.Layout(width='99%'),
)
# Button that clears the conversation; its click handler is attached in the next cell.
reset_btn = widgets.Button(description="Reset")

In [28]:
def _reset(click):
    """Reset the chat: empty the transcript and the input box, drop the history.

    Click handler for ``reset_btn``; ``click`` receives the clicked button and is
    unused. ``cache`` is the module-level conversation dict created in the cell
    that defines the submit handler, so that cell must have run before a click.
    """
    for widget in (output_text, input_text):
        widget.value = ""
    cache.clear()


reset_btn.on_click(_reset)

In [29]:
# Conversation state shared with the Reset handler: "step" counts completed
# exchanges and "chat_history_ids" holds the token ids of the dialogue so far.
cache = {}

# Total token budget (prompt + reply) passed to generate().
MAX_LENGTH = 1000
# Longest prompt fed to the model. Older tokens are dropped beyond this so the
# reply always has at least MAX_LENGTH - MAX_PROMPT_TOKENS tokens of room and
# the sequence stays well below the model's context size.
MAX_PROMPT_TOKENS = 900

# Re-running this cell creates a new function object. Unregister the previous
# handler first; otherwise each submit would fire both, and the second would
# answer an empty message because the first already cleared the input box.
if "write_to_output" in globals():
    input_text.on_submit(write_to_output, remove=True)


def write_to_output(sender):
    """Send the submitted text to the bot and append both turns to the transcript.

    Submit handler for ``input_text``; ``sender`` is the Text widget and is not
    used directly. Blank submissions are ignored. Reads and updates the
    module-level ``cache`` so the dialogue continues across turns until Reset
    clears it.
    """
    user_input = input_text.value
    input_text.value = ""
    if not user_input.strip():
        return

    # Each turn is terminated by EOS, which DialoGPT uses as the turn separator.
    user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
    step = cache.get("step", 0)
    if step > 0:
        bot_input_ids = torch.cat([cache["chat_history_ids"], user_input_ids], dim=-1)
    else:
        bot_input_ids = user_input_ids
    # Keep only the most recent tokens so long conversations don't exhaust the
    # generation budget (which would yield empty replies) or the model context.
    bot_input_ids = bot_input_ids[:, -MAX_PROMPT_TOKENS:]
    output_text.value += f">> User:     {user_input}\n"

    chat_history_ids = model.generate(bot_input_ids,
                                      attention_mask=torch.ones_like(bot_input_ids),
                                      max_length=MAX_LENGTH,
                                      pad_token_id=tokenizer.eos_token_id,
                                      do_sample=True,
                                      top_p=0.9)
    # generate() returns prompt + continuation; decode only the new tokens.
    bot_output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
    output_text.value += f"SimpsonsBot: {bot_output}\n"

    cache["step"] = step + 1
    cache["chat_history_ids"] = chat_history_ids


input_text.on_submit(write_to_output)

In [31]:
display(widgets.VBox([output_text, input_text], layout={'height': '300px', 'width':'500px'}))

VBox(children=(Textarea(value='', disabled=True, layout=Layout(height='100%', width='99%'), placeholder='Start…

In [32]:
display(reset_btn)

Button(description='Reset', style=ButtonStyle())