##### Prerequisites

In [None]:
%%capture

# NOTE: %%capture hides all install output, including errors; drop it when debugging.
# "+cu113" torch builds are not published on PyPI - they come from the PyTorch wheel index.
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
%pip install transformers==4.21.0
%pip install ipywidgets==8.0.4

#### Imports 

In [2]:
from transformers import AutoTokenizer
from torch.nn import functional as F
from itertools import chain
import transformers 
import logging
import torch

In [3]:
# Release cached-but-unused GPU memory held by PyTorch's caching allocator
# (e.g. left over from an earlier run in this kernel) before the model is loaded.
torch.cuda.empty_cache()

##### Setup logging 

In [4]:
# Notebook logger: emit everything from DEBUG upwards through a stream handler.
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
# Re-running this cell must not attach a second handler; each extra handler
# would print every log message once more.
if not any(isinstance(handler, logging.StreamHandler) for handler in logger.handlers):
    logger.addHandler(logging.StreamHandler())

##### Log versions of dependencies 

In [5]:
# Record the exact library versions in the log so results can be reproduced.
logger.info(f'[Using transformers version: {transformers.__version__}]')
logger.info(f'[Using torch version: {torch.__version__}]')

[Using transformers version: 4.21.0]
[Using torch version: 1.12.1+cu113]


#### Load GPT-Neo tokenizer

In [6]:
# Load the pretrained GPT-Neo 125M tokenizer (downloaded from the Hugging Face Hub
# or read from the local cache) and log its configuration.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
logger.info(tokenizer)

PreTrainedTokenizerFast(name_or_path='EleutherAI/gpt-neo-125M', vocab_size=50257, model_max_len=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True)})


In [7]:
# Extra tokens for dialogue formatting: a dedicated <|startoftext|> bos token
# (the base tokenizer reuses <|endoftext|> as bos, see the output above) and
# speaker tags marking who is talking. '<|pad|>' and '<|mask|>' are only
# registered as additional special tokens here; they are NOT assigned to
# tokenizer.pad_token / tokenizer.mask_token.
special_tokens = {
    'bos_token': '<|startoftext|>',
    'additional_special_tokens': ['<|speaker-1|>', '<|speaker-2|>', '<|pad|>', '<|mask|>']
}

In [8]:
# Register the new tokens (the returned number of added tokens is discarded),
# then snapshot the token->id mapping, which now includes them.
_ = tokenizer.add_special_tokens(special_tokens)
vocab = tokenizer.get_vocab()

In [9]:
# Confirm that the new bos token and additional special tokens were registered.
logger.info(tokenizer)

PreTrainedTokenizerFast(name_or_path='EleutherAI/gpt-neo-125M', vocab_size=50257, model_max_len=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|startoftext|>', 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'additional_special_tokens': ['<|speaker-1|>', '<|speaker-2|>', '<|pad|>', '<|mask|>']})


#### Load fine-tuned GPT-Neo model 

In [10]:
# Load the fine-tuned checkpoint (presumably written by the 02-train notebook -
# confirm the relative path) and size its token embeddings to the tokenizer's
# vocabulary, which now includes the added special tokens.
model = transformers.AutoModelForCausalLM.from_pretrained('./../02-train/model')
model.resize_token_embeddings(len(vocab))
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
logger.info(next(model.parameters()).device)

cuda:0


In [11]:
# Switch to inference mode (disables dropout); the returned module is discarded.
_ = model.eval()

#### Evaluate model: special-token ids and a nucleus-sampling decoder

In [12]:
# Ids of the special tokens used to build chat prompts and to detect the end of a reply.
bos_id = vocab['<|startoftext|>']
eos_id = vocab['<|endoftext|>']
speaker_1_id = vocab['<|speaker-1|>']
speaker_2_id = vocab['<|speaker-2|>']
mask = vocab['<|mask|>']  # id of <|mask|>; only logged below, not used by the chat code

In [13]:
# Log the resolved special-token ids as a quick sanity check.
logger.info(f'bos_id = {bos_id}')
logger.info(f'eos_id = {eos_id}')
logger.info(f'speaker_1_id = {speaker_1_id}')
logger.info(f'speaker_2_id = {speaker_2_id}')
logger.info(f'mask = {mask}')

bos_id = 50257
eos_id = 50256
speaker_1_id = 50258
speaker_2_id = 50259
mask = 50261


In [15]:
# Total sequence length (prompt + generated reply) at which generation stops:
# passed as max_length to model.generate and used as the loop bound in nucleus_sampling.
MAX_LEN = 128
# Nucleus (top-p) threshold; 1 keeps the whole distribution, i.e. plain sampling.
TOP_P = 1

In [16]:
def nucleus_sampling(input_ids, token_type_ids, input_len) -> list:
    """Generate a reply token-by-token with nucleus (top-p) sampling.

    Hand-written alternative to ``model.generate(do_sample=True, top_p=TOP_P)``.
    At every step the whole sequence is fed through ``model`` again (no
    key/value cache). The most likely tokens are kept until their cumulative
    probability exceeds ``TOP_P``, and the next token is drawn from that
    renormalised distribution.

    Args:
        input_ids: prompt token ids, a LongTensor of shape (1, input_len).
        token_type_ids: speaker ids aligned with ``input_ids`` (same shape).
        input_len: number of prompt tokens; generation stops once the total
            sequence length reaches ``MAX_LEN``.

    Returns:
        List of generated token ids (prompt excluded). The last element is
        ``eos_id`` if the model emitted it before reaching ``MAX_LEN``.
    """
    output_ids = []

    # Inference only: avoid building an autograd graph over the growing sequence.
    with torch.no_grad():
        for i in range(input_len, MAX_LEN):
            # The sequence currently has i tokens, so position i - 1 predicts the next one.
            logits = model(input_ids=input_ids, token_type_ids=token_type_ids)[0][:, i - 1]  # (1, vocab)
            probs = F.softmax(logits, dim=-1)

            sorted_probs, sorted_idxs = torch.sort(probs, descending=True)
            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)

            # Remove tokens once the cumulative mass exceeds TOP_P. Shifting the mask
            # right by one keeps the token that crosses the threshold, and the single
            # most likely token is always kept.
            idx_remove = cumsum_probs > TOP_P
            idx_remove[:, 1:] = idx_remove[:, :-1].clone()
            idx_remove[:, 0] = False

            sorted_probs[idx_remove] = 0.0
            sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
            # Scatter back to vocabulary order; zeros_like matches probs' dtype and device
            # (a float32 buffer would break scatter_ for half-precision models).
            probs = torch.zeros_like(probs).scatter_(-1, sorted_idxs, sorted_probs)

            idx = torch.multinomial(probs, num_samples=1)  # (1, 1)
            idx_item = idx.item()
            output_ids.append(idx_item)

            if idx_item == eos_id:
                break

            # Extend the context; generated tokens are typed as speaker 2 (the bot).
            input_ids = torch.cat((input_ids, idx), dim=-1)
            token_type_ids = torch.cat((token_type_ids, torch.full_like(idx, speaker_2_id)), dim=-1)

    return output_ids

### Interactive Chat 

In [36]:
# Typing this at the "You:" prompt ends the chat session and discards its history.
RESET_PROMPT = 'reset'
# Maximum number of most recent turns (user and bot) fed back to the model as context.
MAX_TURNS = 4

In [37]:
def _build_model_inputs(query_history):
    """Flatten the dialogue history into model inputs for the bot's next turn.

    Args:
        query_history: non-empty list of turns; each turn is a list of token ids
            whose first element is its speaker id (speaker_1_id or speaker_2_id).

    Returns:
        ``(input_ids, token_type_ids)``, two equal-length lists of ints:
        bos + all turns + speaker_2_id (the cue for the bot to reply). Every
        token of a turn is typed with that turn's speaker id; the leading bos
        token takes the first turn's speaker id.
    """
    input_ids = [bos_id]
    token_type_ids = [query_history[0][0]]
    for turn in query_history:
        input_ids.extend(turn)
        token_type_ids.extend([turn[0]] * len(turn))
    input_ids.append(speaker_2_id)
    token_type_ids.append(speaker_2_id)
    return input_ids, token_type_ids


def chat(use_nucleus_sampling=False):
    """Run an interactive chat with the fine-tuned model until RESET_PROMPT is typed.

    User turns are prefixed with <|speaker-1|> and bot turns with <|speaker-2|>.
    Only the most recent MAX_TURNS turns are kept as context, and older turns
    are also dropped when the prompt would leave no room to generate within
    MAX_LEN tokens. The history is local to this call, so typing RESET_PROMPT
    both ends the session and clears the memory.

    Args:
        use_nucleus_sampling: if True, decode with the hand-written
            ``nucleus_sampling`` function instead of ``model.generate``.
            Defaults to False (``model.generate``).
    """
    logger.info('[Entering chat session ...]')
    logger.info(f'To quit the conversation and reset memory, please type "{RESET_PROMPT}"')

    query_history = []

    while True:
        utterance = input('You: ')

        # Exit on the reset prompt; stray surrounding whitespace is ignored.
        if utterance.strip() == RESET_PROMPT:
            logger.info('[Exiting chat session]')
            break

        # A user turn is the speaker-1 id followed by the encoded utterance.
        query_history.append([speaker_1_id] + tokenizer.encode(utterance))

        # Keep only the most recent MAX_TURNS turns as context.
        if len(query_history) > MAX_TURNS:
            query_history = query_history[len(query_history) - MAX_TURNS:]

        # MAX_LEN bounds prompt + reply together, so drop the oldest turns until the
        # prompt (bos + turns + speaker-2 cue) leaves room for at least one new token.
        while len(query_history) > 1 and sum(map(len, query_history)) + 2 >= MAX_LEN:
            query_history.pop(0)

        input_ids, token_type_ids = _build_model_inputs(query_history)
        input_len = len(input_ids)

        # Batch of one, on the model's device.
        input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(device)
        token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(device)

        if use_nucleus_sampling:
            # Returns only the newly generated ids.
            output_ids = nucleus_sampling(input_ids, token_type_ids, input_len)
        else:
            output_ids = model.generate(input_ids=input_ids,
                                        token_type_ids=token_type_ids,
                                        pad_token_id=eos_id,
                                        do_sample=True,
                                        top_p=TOP_P,
                                        max_length=MAX_LEN)
            # generate() echoes the prompt; keep only the continuation.
            output_ids = output_ids[0].tolist()[input_len:]

        # Special tokens (eos, speaker tags) are dropped from the displayed reply.
        response = tokenizer.decode(output_ids, skip_special_tokens=True)
        print(f'Bot: {response}')

        # Store the reply as a speaker-2 turn so the next prompt includes it.
        query_history.append([speaker_2_id] + tokenizer.encode(response))

In [196]:
# Interactive: keeps prompting for user input until RESET_PROMPT is entered.
chat()

[Entering chat session ...]
To quit the conversation and reset memory, please type "reset"


You:  hola


Bot: hello how are you doing today


You:  nothing much just chilling at home


Bot: huh?


You:  yes


Bot: that sounds like you


You:  lol


Bot: i think i might have a few myself


You:  few of what ?!


Bot: and a lot of them were the last of the four little ones


You:  i don't understand


Bot: i am a lot of yourself


You:  reset


[Exiting chat session]
