In [None]:
#@title # Setting up the environment { vertical-output: true, display-mode: "form" }

###################
#####  SETUP  #####
###################

# Mount the user's Google Drive inside the Colab VM so the project files are
# reachable under /content/drive. `force_remount=True` is passed so the mount
# is redone even when the cell is re-run in the same session.
print("Mounting google drive.. ")
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Absolute path of the project root on the mounted drive (Colab form field).
PROJECT_PATH = "/content/drive/MyDrive/TWM/" #@param {type:"string"}

# Make the project root the working directory: every relative path used
# further down in this cell is resolved against it.
print("Navigating to the project folder.. ")
import os
os.chdir(PROJECT_PATH)

import sys
# Folder (relative to PROJECT_PATH) that holds packages installed onto the drive.
DEPENDENCIES_PATH = "Dependencies" #@param {type:"string"}
# One-off install step, kept for reference; presumably `transformers` was
# installed into DEPENDENCIES_PATH this way - confirm before relying on it.
# !pip install --target=$DEPENDENCIES_PATH transformers
# Insert at position 0 so packages found in the dependencies folder are
# picked up before any same-named package already on sys.path.
sys.path.insert(0, os.path.join(PROJECT_PATH, DEPENDENCIES_PATH))

# Sanity check: list the contents of the current (project) directory.
print("Found the following files:", os.listdir())

# importing dependencies
import spacy
import torch
from transformers import AutoTokenizer, AutoModel

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
from IPython.display import display

import Utils.helperFunctions as helperFunctions
import Utils.dialogue_utils as dialogue_utils

###################
##### CONFIGS #####
###################

print("Navigating to the data directory..")
# Relative path (Colab form field): it is resolved against the current working
# directory, which this cell set to PROJECT_PATH earlier via os.chdir.
DATA_DIRECTORY = "DataEngineering/FinalDataset/small" #@param {type:"string"}
# From here on the working directory is the data directory, so any relative
# path used after this line is relative to DATA_DIRECTORY, not PROJECT_PATH.
os.chdir(DATA_DIRECTORY)
# Sanity check: list the files available in the data directory.
print("Found the following files:", os.listdir())

SEED = 511 #@param {type:"integer"}
print("Setting the project seed.. ")
def seed_everything(seed=SEED):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices), and switches cuDNN to deterministic
    mode with auto-tuning disabled.

    Args:
        seed: Integer seed applied to every generator. Defaults to the
            notebook-level ``SEED`` form value.
    """
    # Imported locally: this setup cell does not import `random` at the top
    # (a later cell does), so on a fresh kernel the name would be undefined
    # when this function is called right below.
    import random

    # Only affects Python processes started after this point; the hash seed
    # of the already-running interpreter cannot be changed at runtime.
    os.environ['PYTHONHASHSEED'] = str(seed)

    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade some speed for repeatable cuDNN kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything()

MATPLOTLIB_STYLE = "seaborn" #@param {type:"string"}
# Matplotlib 3.6 renamed its bundled seaborn styles from "seaborn*" to
# "seaborn-v0_8*" and later releases dropped the old names, so the bare
# "seaborn" style can fail on a current Colab runtime. Fall back to the
# renamed style only when the requested one is not available; any other
# value (including paths to custom style files) is passed through untouched.
_style_name = MATPLOTLIB_STYLE
if _style_name not in plt.style.available:
    _renamed_style = _style_name.replace("seaborn", "seaborn-v0_8", 1)
    if _renamed_style in plt.style.available:
        _style_name = _renamed_style
plt.style.use(_style_name)

print("Done")

Mounting google drive.. 
Mounted at /content/drive
Navigating to the project folder.. 
Found the following files: ['DataEngineering', 'FineTuning', 'TrainingFromScratch', 'Utils', 'Dependencies', 'Testing Interface.ipynb']
Navigating to the data directory..
Found the following files: ['train.csv', 'test.csv', 'dev.csv']
Setting the project seed.. 
Done


In [None]:
# The tokenizer and the encoder are loaded from the same checkpoint, so the
# identifier is named once and reused for both calls.
TOD_BERT_CHECKPOINT = "TODBERT/TOD-BERT-JNT-V1"

tokenizer = AutoTokenizer.from_pretrained(TOD_BERT_CHECKPOINT)
tod_bert = AutoModel.from_pretrained(TOD_BERT_CHECKPOINT)

Downloading:   0%|          | 0.00/39.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/32.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at TODBERT/TOD-BERT-JNT-V1 were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [1]:
#@title Preparations { vertical-output: true, display-mode: "form" }

import ipywidgets as widgets
from IPython.display import clear_output
import random

class TestingInterface():
    """Small ipywidgets form for typing a two-person dialogue by hand.

    The form has a free-text context box, a read-only history box, a
    response box with auto-generated suggestions and a reset button.
    Submitted responses are appended to the history, alternating between
    the labels "Person1: " and "Person2: ".
    """

    # Value `prev_speaker` holds before anyone has spoken; it makes the first
    # submitted message come out labelled "Person1: ".
    _INITIAL_PREV_SPEAKER = "Person2: "

    def __init__(self):
        """Create the widgets, wire up their callbacks and build the grid."""
        self.context_widget = widgets.Textarea(rows=1)
        # Only value changes matter to the handler, so subscribe to that
        # trait alone instead of to every trait change.
        self.context_widget.observe(self.on_context_changed, names='value')

        # The history is display-only: the user cannot edit it directly.
        self.history_widget = widgets.Textarea(
            placeholder='previous messages will appear here..',
            rows=10,
            disabled=True,
        )

        self.reset_btn = widgets.Button(description='Reset')
        self.reset_btn.on_click(self.reset)

        # NOTE(review): `rows` is a Textarea option; confirm that Combobox
        # actually honours it, otherwise it has no effect here.
        self.response_widget = widgets.Combobox(
            options=[],
            value="",
            placeholder='Enter your message here',
            ensure_option=False,
            rows=3
        )
        self.response_widget.observe(self.on_char_input, names='value')
        self.response_widget.on_submit(self.on_response_submission)

        # Empty label that fills the grid cell to the left of the reset button.
        self.grid_placeholder = widgets.Label()

        # Children are laid out row by row in a two-column grid:
        # (label, widget) pairs.
        self.grid = [
            self.grid_placeholder, self.reset_btn,
            widgets.Label('Context:'), self.context_widget,
            widgets.Label('History:'), self.history_widget,
            widgets.Label('Response:'), self.response_widget,
        ]

        self.layout_style = """repeat(2, 80px)"""

        self.ui = widgets.GridBox(
            self.grid,
            layout=widgets.Layout(grid_template_columns=self.layout_style)
        )

        self.prev_speaker = self._INITIAL_PREV_SPEAKER

    @property
    def current_speaker(self):
        """Advance the turn and return the label of the speaker now talking.

        This property has a side effect: every read flips ``prev_speaker``
        between "Person1: " and "Person2: ", so it must be read exactly once
        per submitted message.
        """
        if self.prev_speaker == 'Person2: ':
            self.prev_speaker = 'Person1: '
        else:
            self.prev_speaker = 'Person2: '

        return self.prev_speaker

    @staticmethod
    def on_char_input(change):
        """Refresh the response suggestions whenever the typed text changes.

        Args:
            change: Trait-change dict delivered by ``observe``; the keys
                'type', 'name' and 'owner' are read.
        """
        def generate_random_response(prefix):
            # Placeholder suggestion: the typed prefix followed by 10-29
            # random lowercase letters / spaces.
            length = random.randrange(10, 30)
            vocab = 'abcdefghijklmnopqrstuvwxyz '
            return prefix + ''.join(random.choice(vocab) for _ in range(length))

        # Guard kept so the handler stays safe if it is registered (or
        # called) for changes of other traits as well.
        if change['type'] != 'change' or change['name'] != 'value':
            return

        widget = change['owner']
        prefix = widget.get_interact_value().lower()
        widget.options = [generate_random_response(prefix) for _ in range(3)]
        # Keep free text allowed: the user is not forced to pick a suggestion.
        widget.ensure_option = False

    def on_response_submission(self, widget):
        """Append the submitted response to the history and clear the input.

        Empty submissions are ignored and do not consume a speaker turn.

        Args:
            widget: The response widget whose value was submitted.
        """
        response = widget.value
        if not response:
            return

        widget.value = ""
        widget.options = []

        # `current_speaker` flips the turn, so it is read exactly once here.
        new_line = self.current_speaker + response
        # strip() drops the leading newline produced when the history is empty.
        self.history_widget.value = (
            self.history_widget.value + '\n' + new_line
        ).strip()

    @staticmethod
    def on_context_changed(change):
        """Read the new context text when the context box changes.

        The text is currently only fetched, not used.

        Args:
            change: Trait-change dict delivered by ``observe``.
        """
        widget = change['owner']
        if ((change['type'] == 'change') and
            (change['name'] == 'value')):
            context_text = widget.get_interact_value()

    def reset(self, reset_btn):
        """Clear every field, restart the speaker turns and redraw the form.

        Args:
            reset_btn: The button that triggered the callback (unused).
        """
        self.context_widget.value = ''
        self.history_widget.value = ''

        self.response_widget.value = ''
        self.response_widget.options = []

        # Restart the alternation too; otherwise the first message of the
        # fresh dialogue could be attributed to "Person2: ".
        self.prev_speaker = self._INITIAL_PREV_SPEAKER

        clear_output(wait=True)
        self.display()

    def display(self):
        """Render the widget grid as the output of the current cell."""
        display(self.ui)


In [2]:
#@title #Tod-Bert Model inference { vertical-output: true, display-mode: "form" }

# Build the widget form (TestingInterface is defined in the "Preparations"
# cell, which must have been run first) and render it as this cell's output.
interface = TestingInterface()
interface.display()

GridBox(children=(Label(value=''), Button(description='Reset', style=ButtonStyle()), Label(value='Context:'), …