In [1]:
#@title # Setting up the environment { vertical-output: true, display-mode: "form" }

###################
#####  SETUP  #####
###################

# NOTE: the Google Drive mount below is commented out, so this message is
# printed even though nothing is mounted.
print("Mounting google drive.. ")
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)

# Project root. It becomes the working directory below and is later used as
# the prefix for the checkpoint folder. The trailing `#@param` marker is a
# Colab form annotation and must stay on the same line as the assignment.
PROJECT_PATH = "./" #@param {type:"string"}

# Change the working directory so that every relative path in the notebook
# is resolved against PROJECT_PATH.
print("Navigating to the project folder.. ")
import os
os.chdir(PROJECT_PATH)

# The installs are disabled; the message is still printed. The pinned
# versions record the environment the notebook was written against.
# NOTE(review): if the torch line is ever re-enabled, quote the requirement
# ("torch>=1.7.0,!=1.8.0"); unquoted, a shell would treat `>=` as a redirect.
print("Installing dependencies... ")
# !pip install -q sentencepiece
# !pip install -q torch>=1.7.0,!=1.8.0
# !pip install -q transformers==4.16.2
# !pip install -q pytorch-lightning==1.5.10
# !pip install -q swifter 
# !pip install -q evaluate 
# !pip install -q bert-score 

# List the working directory as a quick check that PROJECT_PATH is correct.
print("Found the following files:", os.listdir())


import matplotlib.pyplot as plt
from IPython.display import display

# import Utils.helperFunctions as helperFunctions
# import Utils.dialogue_utils as dialogue_utils

###################
##### CONFIGS #####
###################

import random
import torch



print("Runtime info:- ")
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Mounting google drive.. 
Navigating to the project folder.. 
Installing dependencies... 
Found the following files: ['.ipynb_checkpoints', 'Book', 'cuda_path.txt', 'google_drive_local_runtime_cc.bat', 'Graduation-Project', 'Old', 'speech-to-text', 'streamlit_interface.py', 'T5 Checkpoints', 't5_on_tpu.ipynb']
Runtime info:- 
Thu Jul  7 17:22:10 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 511.79       Driver Version: 511.79       CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0  On |                  N/A |
| N/A   59C    P8    10W /  N/A |   1659MiB /  6144MiB |     22%      Default |
|

# Code Setup

In [2]:
#@title Imports

import os
import re
import gc
import time
import glob
import random
import functools
import multiprocessing
from typing import List, Tuple, Dict, Callable

# parallelize the apply function of pandas
# import swifter

# to show the full dialogues in the dataframes
import pandas as pd
pd.set_option('max_colwidth', 1000)

# import spacy

import torch
import numpy as np

from transformers import (
    T5ForConditionalGeneration,
    PreTrainedTokenizer,
    T5TokenizerFast as T5Tokenizer,
)



from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelWithLMHead, AutoTokenizer

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.progress import TQDMProgressBar

import matplotlib.pyplot as plt

plt.style.use("seaborn")
SEED = 512 #@param {type:"integer"}

def seed_everything(seed=SEED):
    """Seed every random number generator used by the notebook.

    Seeds Python's hash seed, `random`, NumPy and PyTorch (CPU and all CUDA
    devices) and switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed (int, optional): seed value. Defaults to the notebook-wide SEED.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def reset_environment(reset_seed=True, seed=SEED):
    """Free cached memory and optionally re-seed the run.

    Args:
        reset_seed (bool, optional): when True, re-seed through
            `pl.seed_everything`. Defaults to True.
        seed (int, optional): seed passed to `pl.seed_everything`.
            Defaults to the notebook-wide SEED.
    """
    # Collect garbage first so tensors that are no longer referenced are
    # returned to the CUDA caching allocator before the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
    if reset_seed:
        # Use the caller-supplied seed (it used to be ignored in favour of SEED).
        pl.seed_everything(seed, workers=True)


print("Setting the project seed.. ")
seed_everything()

print("Done")


Setting the project seed.. 
Done


In [3]:
#@title PytorchDataset
class PyTorchDataset(Dataset):
    """Dataset that tokenizes one (source_text, target_text) pair per item.

    Rows are encoded lazily in `__getitem__`: both texts are padded/truncated
    to a fixed length and returned as flattened tensors.
    """

    # Value written over padding positions of the target ids. Hugging Face
    # seq2seq models skip label positions set to -100 when computing the loss.
    IGNORE_INDEX = -100

    def __init__(
        self,
        data: pd.DataFrame,
        tokenizer: PreTrainedTokenizer,
        source_max_token_len: int = 512,
        target_max_token_len: int = 512,
    ):
        """
        initiates a PyTorch Dataset Module for input data
        Args:
            data (pd.DataFrame): input pandas dataframe. Dataframe must have 2 column --> "source_text" and "target_text"
            tokenizer (PreTrainedTokenizer): a PreTrainedTokenizer (T5Tokenizer, MT5Tokenizer, or ByT5Tokenizer)
            source_max_token_len (int, optional): max token length of source text. Defaults to 512.
            target_max_token_len (int, optional): max token length of target text. Defaults to 512.
        """
        self.tokenizer = tokenizer
        self.data = data
        self.source_max_token_len = source_max_token_len
        self.target_max_token_len = target_max_token_len

    def __len__(self):
        """ returns length of data """
        return len(self.data)

    def _encode(self, text, max_length: int):
        """ tokenizes `text`, padded/truncated to exactly `max_length` tokens """
        return self.tokenizer(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

    def __getitem__(self, index: int):
        """
        returns dictionary of input tensors to feed into T5/MT5 model
        Args:
            index (int): positional row index into the dataframe
        Returns:
            dict: "source_text_input_ids", "source_text_attention_mask",
            "labels" and "labels_attention_mask", each a flattened tensor
        """
        data_row = self.data.iloc[index]

        source_text_encoding = self._encode(
            data_row["source_text"], self.source_max_token_len
        )
        target_text_encoding = self._encode(
            data_row["target_text"], self.target_max_token_len
        )

        # Mask padding in the labels using the tokenizer's own pad id instead
        # of a hard-coded 0 (0 is kept as the fallback, matching T5).
        pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = 0

        # Work on a copy so the tokenizer output itself is left untouched.
        labels = target_text_encoding["input_ids"].clone()
        labels[labels == pad_token_id] = self.IGNORE_INDEX

        return {
            "source_text_input_ids": source_text_encoding["input_ids"].flatten(),
            "source_text_attention_mask": source_text_encoding["attention_mask"].flatten(),
            "labels": labels.flatten(),
            "labels_attention_mask": target_text_encoding["attention_mask"].flatten(),
        }

In [4]:
#@title LightningDataModule
class LightningDataModule(pl.LightningDataModule):
    """ Lightning data module wrapping the train / test / validation dataframes """

    def __init__(
        self,
        train_df: pd.DataFrame,
        test_df: pd.DataFrame,
        eval_df: pd.DataFrame,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = 4,
        source_max_token_len: int = 512,
        target_max_token_len: int = 512,
        num_workers: int = 2,
    ):
        """
        stores the dataframes and loader settings; datasets are built in `setup`
        Args:
            train_df (pd.DataFrame): training dataframe. Dataframe must contain 2 columns --> "source_text" & "target_text"
            test_df (pd.DataFrame): test dataframe. Dataframe must contain 2 columns --> "source_text" & "target_text"
            eval_df (pd.DataFrame): validation dataframe. Dataframe must contain 2 columns --> "source_text" & "target_text"
            tokenizer (PreTrainedTokenizer): PreTrainedTokenizer (T5Tokenizer, MT5Tokenizer, or ByT5Tokenizer)
            batch_size (int, optional): batch size. Defaults to 4.
            source_max_token_len (int, optional): max token length of source text. Defaults to 512.
            target_max_token_len (int, optional): max token length of target text. Defaults to 512.
            num_workers (int, optional): worker processes per dataloader. Defaults to 2.
        """
        super().__init__()

        self.train_df = train_df
        self.test_df = test_df
        self.eval_df = eval_df
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.source_max_token_len = source_max_token_len
        self.target_max_token_len = target_max_token_len
        self.num_workers = num_workers

    def _make_dataset(self, frame):
        """ wraps one dataframe in a PyTorchDataset using the shared settings """
        return PyTorchDataset(
            frame,
            self.tokenizer,
            self.source_max_token_len,
            self.target_max_token_len,
        )

    def _make_loader(self, dataset, shuffle):
        """ builds a DataLoader over `dataset` with the module's batch size / workers """
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def setup(self, stage=None):
        """ builds the three datasets (same order as before: train, test, val) """
        self.train_dataset = self._make_dataset(self.train_df)
        self.test_dataset = self._make_dataset(self.test_df)
        self.val_dataset = self._make_dataset(self.eval_df)

    def train_dataloader(self):
        """ training dataloader (shuffled) """
        return self._make_loader(self.train_dataset, True)

    def test_dataloader(self):
        """ test dataloader (original order) """
        return self._make_loader(self.test_dataset, False)

    def val_dataloader(self):
        """ validation dataloader (original order) """
        return self._make_loader(self.val_dataset, False)


In [5]:
# @title LightningModel

class LightningModel(pl.LightningModule):
    """ PyTorch Lightning wrapper that fine-tunes a seq2seq model and saves it every epoch """

    def __init__(
        self,
        tokenizer,
        model,
        checkpoint_name: str,
        output_dir: str,
        save_only_last_epoch: bool = False,
        learning_rate: float = 0.0001,
    ):
        """
        initiates a PyTorch Lightning Model
        Args:
            tokenizer : T5/MT5/ByT5 tokenizer
            model : T5/MT5/ByT5 model
            checkpoint_name (str): experiment name; checkpoints go to "<output_dir>/<checkpoint_name>/version_<n>/"
            output_dir (str): output directory to save model checkpoints
            save_only_last_epoch (bool, optional): If True, save just the last epoch else models are saved for every epoch
            learning_rate (float, optional): AdamW learning rate. Defaults to 0.0001.
        """
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.output_dir = output_dir
        self.average_training_loss = None
        self.average_validation_loss = None
        self.save_only_last_epoch = save_only_last_epoch
        self.learning_rate = learning_rate
        self.checkpoint_name = checkpoint_name

        self.checkpoints_dir = f"{self.output_dir}/{self.checkpoint_name}/"

        # Every instantiation gets a fresh "version_<n>" sub-folder.
        self.experiment_version = self._next_experiment_version(self.checkpoints_dir)

        self.checkpoints_dir += f"version_{self.experiment_version}/"

    @staticmethod
    def _next_experiment_version(checkpoints_dir: str) -> int:
        """
        returns one more than the highest existing "version_<n>" entry (0 when none exists)

        Entries whose suffix is not a plain integer (e.g. "version_2_old") are
        skipped instead of raising ValueError, and the directory is glob-escaped
        so checkpoint names containing "[", "*" or "?" are matched literally.
        """
        latest = -1
        for candidate in glob.glob(glob.escape(checkpoints_dir) + "/version_*"):
            match = re.search(r"version_(\d+)$", candidate)
            if match:
                latest = max(latest, int(match.group(1)))
        return latest + 1

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        """ forward step; returns (loss, logits) of the wrapped model """
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )

        return output.loss, output.logits

    def _batch_loss(self, batch):
        """ runs the forward pass on one batch produced by PyTorchDataset and returns the loss """
        loss, _ = self(
            input_ids=batch["source_text_input_ids"],
            attention_mask=batch["source_text_attention_mask"],
            decoder_attention_mask=batch["labels_attention_mask"],
            labels=batch["labels"],
        )
        return loss

    # NOTE(review): the second parameter of the three step hooks is named
    # `batch_size`, but Lightning presumably passes the batch index there.
    # It is unused; the name is kept so the signatures stay unchanged.
    def training_step(self, batch, batch_size):
        """ training step: logs "train_loss" per step and per epoch """
        loss = self._batch_loss(batch)
        self.log(
            "train_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True
        )
        return loss

    def validation_step(self, batch, batch_size):
        """ validation step: logs "val_loss" per step and per epoch """
        loss = self._batch_loss(batch)
        self.log(
            "val_loss", loss, prog_bar=True, logger=True, on_epoch=True, on_step=True
        )
        return loss

    def test_step(self, batch, batch_size):
        """ test step: logs "test_loss" """
        loss = self._batch_loss(batch)
        self.log("test_loss", loss, prog_bar=True, logger=True,)
        return loss

    def configure_optimizers(self):
        """ configure optimizers: AdamW over all parameters with the configured learning rate """
        return AdamW(
            self.parameters(),
            lr=self.learning_rate
        )

    def training_epoch_end(self, training_step_outputs):
        """ save tokenizer and model on epoch end """
        self.average_training_loss = np.round(
            torch.mean(torch.stack([x["loss"] for x in training_step_outputs])).item(),
            4,
        )

        # When only the last epoch is wanted, skip every earlier epoch.
        if self.save_only_last_epoch and self.current_epoch != self.trainer.max_epochs - 1:
            return

        # The folder name encodes the epoch and both average losses, e.g.
        # ".../version_2/-epoch-9-tloss-1.5577-vloss-1.8296".
        path = self.checkpoints_dir \
              +f"-epoch-{self.current_epoch}" \
              +f"-tloss-{str(self.average_training_loss)}" \
              +f"-vloss-{str(self.average_validation_loss)}"

        self.tokenizer.save_pretrained(path)
        self.model.save_pretrained(path)

    def validation_epoch_end(self, validation_step_outputs):
        """ stores the mean validation loss of the epoch (rounded to 4 decimals) """
        _loss = [x.cpu() for x in validation_step_outputs]
        self.average_validation_loss = np.round(
            torch.mean(torch.stack(_loss)).item(),
            4,
        )


In [6]:
#@title SimpleT5 class
class SimpleT5:
    """ Custom SimpleT5 class: thin wrapper for fine-tuning and querying a T5 model """

    def __init__(self,
        checkpoint_name: str,
        output_dir: str,
        learning_rate: float,
    ) -> None:
        """
        initiates SimpleT5 class
        Args:
            checkpoint_name (str): experiment name, forwarded to LightningModel
            output_dir (str): directory under which checkpoints are written
            learning_rate (float): learning rate used when `train` is called
        """
        self.learning_rate = learning_rate
        self.output_dir = output_dir
        self.checkpoint_name = checkpoint_name
        self.trainer = None
        self.callbacks = [TQDMProgressBar(refresh_rate=5)]

    def from_pretrained(self, model_name="t5-small") -> None:
        """
        loads T5/MT5 Model model for training/finetuning
        Args:
            model_name (str, optional): exact model architecture name, e.g. "t5-base" or "t5-large". Defaults to "t5-small".
        """
        self.tokenizer = T5Tokenizer.from_pretrained(f"{model_name}")
        self.model = T5ForConditionalGeneration.from_pretrained(
            f"{model_name}", return_dict=True
        )

    def train(
        self,
        train_df: pd.DataFrame,
        test_df: pd.DataFrame,
        eval_df: pd.DataFrame,
        source_max_token_len: int = 512,
        target_max_token_len: int = 512,
        batch_size: int = 8,
        max_epochs: int = 5,
        use_gpu: bool = True,
        early_stopping_patience_epochs: int = 0,  # 0 to disable early stopping feature
        precision=32,
        logger="default",
        dataloader_num_workers: int = 2,
        save_only_last_epoch: bool = False,
    ):
        """
        trains T5/MT5 model on custom dataset
        Args:
            train_df (pd.DataFrame): training dataframe. Dataframe must have 2 column --> "source_text" and "target_text"
            test_df (pd.DataFrame): test dataframe. Dataframe must have 2 column --> "source_text" and "target_text"
            eval_df (pd.DataFrame): validation dataframe. Dataframe must have 2 column --> "source_text" and "target_text"
            source_max_token_len (int, optional): max token length of source text. Defaults to 512.
            target_max_token_len (int, optional): max token length of target text. Defaults to 512.
            batch_size (int, optional): batch size. Defaults to 8.
            max_epochs (int, optional): max number of epochs. Defaults to 5.
            use_gpu (bool, optional): if True, model uses gpu for training. Defaults to True.
            early_stopping_patience_epochs (int, optional): monitors val_loss on epoch end and stops training, if val_loss does not improve after the specified number of epochs. set 0 to disable early stopping. Defaults to 0 (disabled)
            precision (int, optional): sets precision training - Double precision (64), full precision (32) or half precision (16). Defaults to 32.
            logger (pytorch_lightning.loggers) : any logger supported by PyTorch Lightning. Defaults to "default". If "default", pytorch lightning default logger is used.
            dataloader_num_workers (int, optional): number of workers in train/test/val dataloader
            save_only_last_epoch (bool, optional): If True, saves only the last epoch else models are saved at every epoch
        """
        self.data_module = LightningDataModule(
            train_df,
            test_df,
            eval_df,
            self.tokenizer,
            batch_size=batch_size,
            source_max_token_len=source_max_token_len,
            target_max_token_len=target_max_token_len,
            num_workers=dataloader_num_workers,
        )

        self.T5Model = LightningModel(
            tokenizer=self.tokenizer,
            model=self.model,
            checkpoint_name=self.checkpoint_name,
            output_dir=self.output_dir,
            save_only_last_epoch=save_only_last_epoch,
            # forward the configured rate; it used to be dropped, so the
            # LightningModel default was always used
            learning_rate=self.learning_rate,
        )

        # Per-run callback list: appending to self.callbacks directly would
        # stack one more EarlyStopping callback on every call to train().
        callbacks = list(self.callbacks)

        # add callbacks for early stopping
        if early_stopping_patience_epochs > 0:
            early_stop_callback = EarlyStopping(
                monitor="val_loss",
                min_delta=0.01,
                patience=early_stopping_patience_epochs,
                verbose=True,
                mode="min",
            )
            callbacks.append(early_stop_callback)

        # add gpu support
        gpus = 1 if use_gpu else 0

        # add logger
        loggers = True if logger == "default" else logger

        # prepare trainer
        self.trainer = pl.Trainer(
            logger=loggers,
            callbacks=callbacks,
            max_epochs=max_epochs,
            gpus=gpus,
            precision=precision,
            log_every_n_steps=1,
        )

        # fit trainer
        self.trainer.fit(self.T5Model, self.data_module)

    def load_model(
        self,
        model_dir: str = "outputs",
        use_gpu: bool = False,
        ):
        """
        loads a checkpoint for inferencing/prediction
        Args:
            model_dir (str, optional): path to model directory. Defaults to "outputs".
            use_gpu (bool, optional): if True, model uses gpu for inferencing/prediction. Defaults to False.
        Raises:
            RuntimeError: if use_gpu is True but CUDA is not available
        """
        self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
        self.tokenizer = T5Tokenizer.from_pretrained(f"{model_dir}")

        if use_gpu:
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            else:
                # raising a plain string is itself a TypeError in Python 3
                raise RuntimeError(
                    "exception ---> no gpu found. set use_gpu=False, to use CPU"
                )
        else:
            self.device = torch.device("cpu")

        self.model = self.model.to(self.device)

    def _inference_device(self):
        """ device for input tensors: the one chosen by load_model, else the model's own device """
        device = getattr(self, "device", None)
        return device if device is not None else self.model.device

    def predict(
        self,
        source_text: str,
        max_length: int = 512,
        num_return_sequences: int = 1,
        num_beams: int = 2,
        top_k: int = 50,
        top_p: float = 0.95,
        do_sample: bool = True,
        repetition_penalty: float = 2.5,
        length_penalty: float = 1.0,
        early_stopping: bool = True,
        temperature: float=1.0,
        skip_special_tokens: bool = True,
        clean_up_tokenization_spaces: bool = True,
    ):
        """
        generates prediction for T5/MT5 model
        Args:
            source_text (str): any text for generating predictions
            max_length (int, optional): max token length of prediction. Defaults to 512.
            num_return_sequences (int, optional): number of predictions to be returned. Defaults to 1.
            num_beams (int, optional): number of beams. Defaults to 2.
            top_k (int, optional): Defaults to 50.
            top_p (float, optional): Defaults to 0.95.
            do_sample (bool, optional): accepted for compatibility; NOT forwarded to generate(). Defaults to True.
            repetition_penalty (float, optional): Defaults to 2.5.
            length_penalty (float, optional): Defaults to 1.0.
            early_stopping (bool, optional): Defaults to True.
            temperature (float, optional): accepted for compatibility; NOT forwarded to generate(). Defaults to 1.0.
            skip_special_tokens (bool, optional): Defaults to True.
            clean_up_tokenization_spaces (bool, optional): Defaults to True.
        Returns:
            tuple: (list[str] of decoded predictions, the "sequences_scores" entry of the generate() output)
        """
        # NOTE(review): do_sample and temperature are deliberately left out of
        # the generate() call to keep existing outputs unchanged; forwarding
        # them would switch generation to sampling.
        input_ids = self.tokenizer.encode(
            source_text, return_tensors="pt", add_special_tokens=True
        )
        input_ids = input_ids.to(self._inference_device())
        output = self.model.generate(
            input_ids=input_ids,
            num_beams=num_beams,
            max_length=max_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            top_p=top_p,
            top_k=top_k,
            num_return_sequences=num_return_sequences,
            return_dict_in_generate=True,
            output_scores=True,
        )
            # forced_eos_token_id=self.tokenizer.eos_token
        generated_ids = output['sequences']
        preds = [
            self.tokenizer.decode(
                g,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            )
            for g in generated_ids
        ]
        return preds, output['sequences_scores']

    def predict_multiple(
        self,
        source_text: List[str],
        max_length: int = 256,
        num_return_sequences: int = 1,
        num_beams: int = 2,
        top_k: int = 50,
        top_p: float = 0.95,
        do_sample: bool = True,
        repetition_penalty: float = 2.5,
        length_penalty: float = 1.0,
        early_stopping: bool = True,
        skip_special_tokens: bool = True,
        clean_up_tokenization_spaces: bool = True,
    ):
        """
        generates predictions for a batch of texts
        Args:
            source_text (List[str]): texts to generate predictions for
            max_length (int, optional): used both as the padded input length and as the max generated length. Defaults to 256.
            do_sample (bool, optional): accepted for compatibility; NOT forwarded to generate().
            (remaining arguments as in `predict`)
        Returns:
            list[str]: decoded predictions, num_return_sequences per input text
        """
        input_ids = self.tokenizer(
            source_text,
            return_tensors="pt",
            add_special_tokens=True,
            padding='max_length',
            max_length=max_length,
            truncation=True
        )['input_ids']

        input_ids = input_ids.to(self._inference_device())

        generated_samples = self.model.generate(
            input_ids=input_ids,
            num_beams=num_beams,
            max_length=max_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            top_p=top_p,
            top_k=top_k,
            num_return_sequences=num_return_sequences,
        )

        return [self.tokenizer.decode(
                g,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces)
            for g in generated_samples
        ]

    def batch_predict(self, batch_size: int, sequences: List[str], **kwargs):
        """
        runs `predict_multiple` over `sequences` in chunks of `batch_size`
        Args:
            batch_size (int): number of texts per call; must be >= 1
            sequences (List[str]): texts to generate predictions for
            **kwargs: forwarded unchanged to `predict_multiple`
        Returns:
            list[str]: concatenated predictions in input order ([] for no input)
        Raises:
            ValueError: if batch_size is smaller than 1
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        output = []
        # Stepping over the real length yields the short final chunk on its
        # own and never issues a call with an empty batch (the old code always
        # made a trailing call, which broke when the length was an exact
        # multiple of batch_size).
        for start in range(0, len(sequences), batch_size):
            output += self.predict_multiple(
                sequences[start:start + batch_size],
                **kwargs)
        return output

In [7]:
from colorama import Style, Back, Fore

In [8]:
#@title
    # @ staticmethod
    # def on_char_input(change):
    #     def generate_random_response(prefix):
    #         length = random.randrange(10, 30)
    #         vocab = 'abcdefghijklmnopqrstuvwxyz '
    #         return prefix + ''.join(random.choice(vocab) for _ in range(length))

    #     if ((change['type'] == 'change') and
    #         (change['name'] == 'value')):

    #         widget = change['owner']
    #         prefix = widget.get_interact_value().lower()
    #         widget.options = [generate_random_response(prefix) for _ in range(3)]
    #         widget.options = ["I want to order a pizza", "I want to order a large pepperoni pizza"]
    #         # w.placeholder = w.options[0]
    #         widget.ensure_option = False

from ipywidgets import CallbackDispatcher, register, Text

from ipywidgets import trait_types 
from traitlets import Unicode, Bool, Int, Container


@register
class Combobox(Text):
    """Single line textbox widget with a dropdown and autocompletion.

    Extends `Text` with two synced traits: `options` (the dropdown entries)
    and `ensure_option` (whether the value must be one of the options).
    """
    # NOTE(review): presumably these name the front-end model/view classes
    # that render the widget -- confirm against the installed ipywidgets.
    _model_name = Unicode('ComboboxModel').tag(sync=True)
    _view_name = Unicode('ComboboxView').tag(sync=True)

    # Tuple of strings offered in the dropdown; tagged sync=True.
    options = trait_types.TypedTuple(
        trait=Unicode(),
        help="Dropdown options for the combobox"
    ).tag(sync=True)

    # Defaults to False, i.e. free text that is not in `options` is accepted.
    ensure_option = Bool(
        False,
        help='If set, ensure value is in options. Implies continuous_update=False.'
    ).tag(sync=True)


#@title Interface Setup
import ipywidgets as widgets
from IPython.display import clear_output
import random
from functools import partial

class TestingInterface():
    """Small ipywidgets chat UI for trying the completion model interactively.

    Layout: a reset button, a context box, a read-only history box and a
    response box. A completion suggestion is printed while the user types in
    the response box; submitting the box appends the message to the history,
    alternating between "Person1: " and "Person2: ".
    """

    def __init__(self, model, **kwargs):
        """
        Args:
            model: object exposing `tokenizer.eos_token` and
                `predict(text, **kwargs) -> (suggestions, scores)`
            **kwargs: generation options forwarded to `model.predict`
        """
        self.model = model
        self.kwargs = kwargs

        self.context_widget = widgets.Textarea(rows=1)
        self.context_widget.observe(self.on_context_changed)

        self.history_widget = widgets.Textarea(
            placeholder='previous messages will appear here..',
            rows=10
        )
        self.history_widget.disabled = True
        self.reset_btn = widgets.Button(description='Reset')
        self.reset_btn.on_click(self.reset)

        self.response_widget = Combobox(
            options=[],
            value="",
            placeholder='Enter your message here',
            ensure_option=False,
            continuous_update=True,
            rows=3
        )
        # The observer is registered for every trait change; on_char_input
        # itself ignores everything except changes of `value`.
        self.response_widget.observe(
            self.on_char_input
        )
            # self.debouncer
        self.response_widget.on_submit(self.on_response_submission)

        self.grid_placeholder = widgets.Label()

        # two-column grid: label on the left, widget on the right
        self.grid = [
            self.grid_placeholder, self.reset_btn,
            widgets.Label('Context:'), self.context_widget,
            widgets.Label('History:'), self.history_widget,
            widgets.Label('Response:'), self.response_widget,
        ]

        self.layout_style = """repeat(2, 100px)"""

        self.ui = widgets.GridBox(
            self.grid,
            layout=widgets.Layout(grid_template_columns=self.layout_style)
        )

        # starts at "Person2: " so that the first submitted message is Person1's
        self.prev_speaker = "Person2: "
        self.last_call = time.time()

    @staticmethod
    def _is_value_change(change):
        """ True when `change` reports a change of the widget's `value` trait """
        return change['type'] == 'change' and change['name'] == 'value'

    @property
    def current_speaker(self):
        """Speaker prefix for the next message.

        NOTE: reading this property has a side effect -- it flips the stored
        speaker, so two consecutive reads return different values.
        """
        if self.prev_speaker == 'Person2: ':
            self.prev_speaker = 'Person1: '
        else:
            self.prev_speaker = 'Person2: '

        return self.prev_speaker

    def debouncer(self, change):
        """ forwards value changes to on_char_input only if more than 0.1 s passed since the previous one """
        if self._is_value_change(change):
            elapsed = time.time() - self.last_call
            self.last_call = time.time()

            if (elapsed) > 0.1:
                self.on_char_input(change)

    def on_char_input(self, change):
        """Print a completion suggestion for the text typed so far.

        Does nothing unless `change` is a change of `value`, and nothing while
        the user is in the middle of a word (text not ending with a space).
        An empty box does trigger a suggestion.
        """
        # Skip notifications about other traits (options, placeholder, ...)
        # so they never trigger a model call.
        if not self._is_value_change(change):
            return

        widget = change['owner']
        prefix = widget.get_interact_value().lower()
        if prefix and prefix[-1]!= " ":
            return
        eos_token = self.model.tokenizer.eos_token
        # NOTE(review): the task prefix is spelled "compelete: " here, while the
        # notebook's get_suggestions helper sends "comeplete: " -- confirm which
        # spelling the checkpoint was trained with.
        # NOTE(review): the context widget's text is not part of the prompt.
        options, scores = self.model.predict(
            "compelete: "
            + self.history_widget.value.replace("\n", eos_token)
            + eos_token
            + prefix,
            **self.kwargs,
        )
        # "\r" rewrites the previous suggestion on the same output line
        print(f'\rSuggestion ({scores[0]:0.3f}):' + Fore.LIGHTGREEN_EX,
              options[0],
              end=Style.RESET_ALL,
              flush=True
        )

        widget.ensure_option = False

    def on_response_submission(self, widget):
        """ moves a non-empty submitted message into the history and clears the response box """
        response = widget.value
        if response:
            widget.value = ""
            widget.options = []
            self.history_widget.value += "\n" + self.current_speaker + response
            self.history_widget.value = self.history_widget.value.strip()

    @staticmethod
    def on_context_changed(change):
        """ placeholder observer for the context box; currently does nothing """
        return None

    def reset(self, reset_btn):
        """ clears all text boxes and redraws the UI (click handler of the reset button) """
        self.context_widget.value = ''
        self.history_widget.value = ''

        self.response_widget.value = ''
        self.response_widget.options = []

        clear_output(wait=True)
        self.display()

    def display(self):
        """ renders the widget grid in the current cell """
        display(self.ui)


# Inference

## Manual Testing

In [11]:
#@title
from pprint import pprint

# Free cached memory and re-seed before loading the model.
reset_environment()

# Folder with the fine-tuned checkpoints, relative to the project root.
LOCAL_DIR = PROJECT_PATH + "T5 Checkpoints/"

# Each of the two names below used to be assigned twice in a row; only the
# second assignment ever took effect. The overwritten path value was:
#   LOCAL_DIR + "t5-v1_1-base_check_points/"
#             + "t5-v1_1-base_BatchSize-16_N-Splits-4_DatasetSize-large_Topic-Food&Drink/"
#             + "version_2/-epoch-9-tloss-1.5577-vloss-1.8296"
best_checkpoint_name = "t5-v1_1-base_BatchSize-16_N-Splits-4_DatasetSize-large_Topic-Food&Drink_version_2"
best_checkpoint_path = LOCAL_DIR + "-epoch-9-tloss-1.5577-vloss-1.8296"


def get_model_from_disk(checkpoint_path, model_name, use_gpu=True):
    """Build a SimpleT5 wrapper and load a saved checkpoint into it.

    Args:
        checkpoint_path: directory passed to `SimpleT5.load_model`
        model_name: used as the wrapper's checkpoint name
        use_gpu (bool, optional): forwarded to `load_model`. Defaults to True.
    Returns:
        SimpleT5: wrapper with model and tokenizer loaded
    """
    # output_dir and learning_rate only matter for training, which never
    # happens on this wrapper.
    wrapper = SimpleT5(checkpoint_name=model_name, output_dir="", learning_rate=1e-4)
    wrapper.load_model(model_dir=checkpoint_path, use_gpu=use_gpu)
    return wrapper


# The checkpoint whose suggestions are shown under "Good Examples".
top_model = get_model_from_disk(
    checkpoint_path=best_checkpoint_path,
    model_name=best_checkpoint_name,
    use_gpu=True,
)


# Default generation options for the manual tests below.
kwargs = dict(
    num_return_sequences=3,
    num_beams=3,
    top_p=0.95,
    top_k=100,
    temperature=1.0,
)


def get_suggestions(model, prompt, return_scores=False, kwargs=kwargs):
    """Ask `model` to complete `prompt`.

    Args:
        model: object exposing `predict(text, **kwargs) -> (suggestions, scores)`
        prompt (str): text to complete; the task prefix is prepended here
        return_scores (bool, optional): return the results instead of printing them
        kwargs (dict, optional): generation options forwarded to `model.predict`
    Returns:
        (suggestions, scores) when `return_scores` is True; otherwise None,
        after printing one suggestion per line.
    """
    # NOTE(review): the task prefix is spelled "comeplete: " -- presumably it
    # has to match the prefix used at training time; confirm before changing.
    result = model.predict("comeplete: " + prompt, **kwargs)
    suggestions, scores = result
    if not return_scores:
        print("\n".join(suggestions))
        return None

    return suggestions, scores

Global seed set to 512


## Good Examples

In [None]:
get_suggestions(top_model, "I want to eat")

a pizza with my family tonight.
a pizza.
a pizza with my friends.


In [None]:
get_suggestions(top_model, "I want to drink")

a cup of coffee.
a cup of coffee, please.
a cup of coffee. It tastes good.


In [None]:
get_suggestions(top_model, "I want to order a")

burger for pickup at the local Burger King.
burger for pickup at the nearest Burger King.
burger for pickup at the local Burger King


In [None]:
get_suggestions(top_model, "May I order a")

iced coffee from starbucks
burger for me?
iced coffee from Starbucks


In [None]:
get_suggestions(top_model, "What kind of pizza toppings do you have?")

a variety of pizza toppings. I've got pepperoni, sausage, bacon, and pineapple.
a variety of pizza toppings. I've got pepperoni, bacon, sausage, and pineapple.
a variety of pizza toppings. I've got pepperoni, sausage, and bacon.


In [None]:
get_suggestions(top_model, "I want to order a large pizza with")

ham and pineapple
ham and pineapple on it
ham and pineapple.


In [None]:
get_suggestions(top_model, "I want to order a double cheese burger with")

ketchup and onions
ketchup and mustard
ketchup and onions.


## Bad Examples

In [None]:
# Second checkpoint, loaded for the comparison shown under "Bad Examples".
bad_checkpoint_name = "t5v1.1-base-small-dataset"
bad_checkpoint_path = LOCAL_DIR + "simplet5-epoch-9-train-loss-2.1007-val-loss-2.4144/"
bad_model = get_model_from_disk(
    checkpoint_path=bad_checkpoint_path,
    model_name=bad_checkpoint_name,
    use_gpu=True,
)

In [None]:
reset_environment()

Global seed set to 512


In [None]:
get_suggestions(bad_model, "I want to eat")

food in the centre of town.
food in San Francisco, California.
food in San Francisco.


In [None]:
get_suggestions(bad_model, "I want to drink")

a drink with my friends tonight.
drink a lot of coffee.
a drink.


In [None]:
get_suggestions(bad_model, "I want to order a")

restaurant in San Francisco, California.
restaurant that serves Mexican food in the centre of town
restaurant that serves Mexican food.


In [None]:
get_suggestions(bad_model, "May I order a")

cab to take me from JFK International Airport to the nearest Starbucks?
cab to take me from JFK International Airport to the Hilton Hotel?
restaurant in San Francisco, California?


In [None]:
get_suggestions(bad_model, "What kind of pizza toppings do you have?")

What kind of crust do you want?
What kind of crust?
I have pepperoni and sausage.


In [None]:
get_suggestions(bad_model, "I want to order a large pizza with")

a side of french fries and pepperoni.
a side of french fries, please.
a side of french fries.


In [None]:
get_suggestions(bad_model, "I want to order a double cheese burger with")

a side of fries and french fries.
a side of fries and french fries, please.
a side of french fries and a side of fries.


## Live Testing

In [13]:
#@title # T5 Model inference { vertical-output: true, display-mode: "form" }
# Generation options for the live demo (a single suggestion per keystroke).
live_generation_options = {
    "num_return_sequences": 1,
    "num_beams": 3,
    "top_k": 100,
    "top_p": 0.94,
}
interface = TestingInterface(top_model, **live_generation_options)

# Render the widget grid, then leave blank lines so the suggestion printed by
# the interface appears below the widgets.
interface.display()
print("\n\n")

GridBox(children=(Label(value=''), Button(description='Reset', style=ButtonStyle()), Label(value='Context:'), …

Suggestion (-0.477):[92m 'm looking for a place to dine in the centre of town that serves italian food[0m