##### Prerequisites

In [None]:
%%capture 

!pip install torch==1.12.1+cu113
!pip install transformers==4.21.0

#### Imports 

In [3]:
import logging
from itertools import chain
from typing import Tuple

import torch
import transformers
from torch.nn import functional as F
from transformers import AutoTokenizer

In [4]:
# Ask PyTorch's CUDA caching allocator to release unused cached GPU memory before the
# model is loaded (mainly useful when this cell is re-run in an already-used kernel).
torch.cuda.empty_cache()

##### Setup logging 

In [5]:
# Notebook-wide logger; DEBUG level so every message below is emitted.
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
# logging.getLogger returns the same object on every call, so re-running this cell would
# otherwise stack another StreamHandler and print every message multiple times.
if not any(type(handler) is logging.StreamHandler for handler in logger.handlers):
    logger.addHandler(logging.StreamHandler())

##### Log versions of dependencies 

In [6]:
# Record the library versions actually imported by the kernel; compare them with the
# pins in the install cell to catch an install that silently did not take effect.
logger.info(f'[Using transformers version: {transformers.__version__}]')
logger.info(f'[Using torch version: {torch.__version__}]')

[Using transformers version: 4.21.0]
[Using torch version: 1.13.1+cu117]


#### Setup essentials 

#### Load GPT-Neo tokenizer

In [7]:
# Load the tokenizer for the 'EleutherAI/gpt-neo-125M' checkpoint and log its
# configuration (vocabulary size, special tokens, padding/truncation sides).
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
logger.info(tokenizer)

PreTrainedTokenizerFast(name_or_path='EleutherAI/gpt-neo-125M', vocab_size=50257, model_max_len=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True)})


In [8]:
# Extra tokens registered on top of the base GPT-Neo vocabulary:
#   - a dedicated beginning-of-sequence marker (replaces the default bos token), and
#   - speaker tags plus pad/mask markers used by the dialogue format.
# Keep the list order unchanged: the tokenize() output shown later in this notebook has
# <|startoftext|>=50257, <|speaker-1|>=50258, <|speaker-2|>=50259, i.e. ids follow this order.
# NOTE(review): presumably this must match the tokens used when the model was trained — confirm.
special_tokens = {
    'bos_token': '<|startoftext|>',
    'additional_special_tokens': [
        '<|speaker-1|>',
        '<|speaker-2|>',
        '<|pad|>',
        '<|mask|>',
    ],
}

In [9]:
# Register the special tokens with the tokenizer. The return value (per the Hugging Face
# API, the number of tokens added) is not needed, so it is assigned to `_` to keep the
# cell from displaying it.
_ = tokenizer.add_special_tokens(special_tokens)
# Token -> id mapping including the tokens just added; used below to look up special-token
# ids and to size the model's embedding matrix.
vocab = tokenizer.get_vocab()

In [10]:
# Log the tokenizer again to confirm the new bos token and additional special tokens
# were registered.
logger.info(tokenizer)

PreTrainedTokenizerFast(name_or_path='EleutherAI/gpt-neo-125M', vocab_size=50257, model_max_len=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|startoftext|>', 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'additional_special_tokens': ['<|speaker-1|>', '<|speaker-2|>', '<|pad|>', '<|mask|>']})


#### Load model 

In [11]:
# Load the fine-tuned causal LM from the training step's output directory
# (path is relative to this notebook's working directory).
model = transformers.AutoModelForCausalLM.from_pretrained('./../02-train/model')
# Make sure the embedding matrix has one row per tokenizer entry, including the special
# tokens added above; otherwise their ids would index past the end of the embeddings.
model.resize_token_embeddings(len(vocab))
# Prefer the GPU, but fall back to CPU so the notebook still runs on a host without CUDA
# instead of failing at model.to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Log where the weights actually ended up.
logger.info(next(model.parameters()).device)

cuda:0


In [12]:
# Switch the model to evaluation mode (e.g. disables dropout) for inference.
# The result is assigned to `_` so the cell does not print the full module repr.
_ = model.eval()

#### Evaluate model 

In [13]:
# Resolve the ids of the special tokens once so the tokenize/generate helpers can use
# plain integers. `mask` holds the id of '<|mask|>'; it is not used by the cells visible here.
(bos_id, eos_id, speaker_1_id, speaker_2_id, mask) = (
    vocab[token]
    for token in (
        '<|startoftext|>',
        '<|endoftext|>',
        '<|speaker-1|>',
        '<|speaker-2|>',
        '<|mask|>',
    )
)

In [14]:
def tokenize(query: str) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Encode a single user utterance as a model prompt.

    The prompt layout is::

        <|startoftext|> <|speaker-1|> <query tokens...> <|speaker-2|>

    The trailing ``<|speaker-2|>`` tag (not an end-of-sequence token) cues the model to
    continue with the second speaker's reply.

    Args:
        query: Raw text of the first speaker's utterance.

    Returns:
        A tuple ``(input_ids, token_type_ids, input_len)`` where

        * ``input_ids`` is an int64 tensor of shape ``(1, input_len)`` on ``device``
          holding the prompt token ids,
        * ``token_type_ids`` is an int64 tensor of the same shape that labels every
          position with a speaker id: ``speaker_1_id`` for the start token and the
          whole first-speaker turn, ``speaker_2_id`` for the final tag,
        * ``input_len`` is the number of tokens in the prompt.
    """
    # First-speaker turn: speaker tag followed by the encoded query text.
    turn_ids = [speaker_1_id] + tokenizer.encode(query)

    # Full prompt: start-of-text marker, the turn, then the second speaker's tag.
    prompt_ids = [bos_id] + turn_ids + [speaker_2_id]

    # One type id per prompt position; the "+ 1" covers the start-of-text marker, which
    # is attributed to the first speaker.
    type_ids = [speaker_1_id] * (len(turn_ids) + 1) + [speaker_2_id]

    input_len = len(prompt_ids)

    # Wrap in an outer list to add the batch dimension (batch size 1).
    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    token_type_ids = torch.tensor([type_ids], dtype=torch.long, device=device)

    return input_ids, token_type_ids, input_len

In [15]:
# Sanity check: inspect the prompt ids, the per-position speaker ids and the prompt
# length produced for a short utterance.
tokenize('i hate my life')

(tensor([[50257, 50258,    72,  5465,   616,  1204, 50259]], device='cuda:0'),
 tensor([[50258, 50258, 50258, 50258, 50258, 50258, 50259]], device='cuda:0'),
 7)

In [22]:
# Upper bound on the total sequence length (prompt + generated tokens) used by generate().
MAX_LEN = 64
# Nucleus (top-p) sampling threshold: at each step, sample only from the smallest set of
# most-likely tokens whose cumulative probability exceeds this value.
top_p = 0.9

# generate() samples with torch.multinomial; seed PyTorch's RNG so that
# Restart Kernel -> Run All reproduces the same responses.
SEED = 42
_ = torch.manual_seed(SEED)

In [23]:
def _top_p_filter(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Restrict each row of ``probs`` to its top-p nucleus and renormalize it.

    Args:
        probs: Probabilities of shape ``(batch, vocab)``.
        p: Cumulative-probability threshold.

    Returns:
        A tensor like ``probs`` in which tokens outside the nucleus have probability 0
        and the remaining probabilities of each row sum to 1.
    """
    sorted_probs, sorted_idxs = torch.sort(probs, descending=True)
    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)

    # Flag positions where the running mass already exceeds p, then shift the mask right
    # by one so the token that crosses the threshold is kept, and never drop the top token.
    idx_remove = cumsum_probs > p
    idx_remove[:, 1:] = idx_remove[:, :-1].clone()
    idx_remove[:, 0] = False

    sorted_probs[idx_remove] = 0.0
    sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True)

    # Undo the sort: put each filtered probability back at its vocabulary index.
    # zeros_like keeps the dtype/device of the model output.
    return torch.zeros_like(probs).scatter_(-1, sorted_idxs, sorted_probs)


def generate(query: str) -> str:
    """Sample a reply to ``query`` from the model using nucleus (top-p) sampling.

    The prompt is built by ``tokenize``; tokens are then sampled one at a time and
    appended (tagged as speaker 2) until the end-of-text token is sampled or the total
    sequence length reaches ``MAX_LEN``. Because ``MAX_LEN`` bounds prompt plus reply, a
    prompt of ``MAX_LEN`` or more tokens yields an empty string.

    Args:
        query: The first speaker's utterance.

    Returns:
        The decoded reply with special tokens removed and surrounding whitespace stripped.
    """
    input_ids, token_type_ids, input_len = tokenize(query)

    logger.debug('input_ids: %s', input_ids)
    logger.debug('token_type_ids: %s', token_type_ids)
    logger.debug('input_len: %s', input_len)

    output_ids = []
    # Every generated position belongs to speaker 2; build its type-id tensor once.
    reply_type_id = torch.tensor([[speaker_2_id]], dtype=torch.long, device=device)

    # Pure inference: disable autograd so no graph is recorded for each forward pass.
    with torch.no_grad():
        for _step in range(input_len, MAX_LEN):
            # Logits for the last position only; the whole sequence is re-encoded each step.
            logits = model(input_ids=input_ids, token_type_ids=token_type_ids)[0][:, -1]
            probs = _top_p_filter(F.softmax(logits, dim=-1), top_p)

            idx = torch.multinomial(probs, num_samples=1)  # shape (1, 1)
            idx_item = idx.item()
            output_ids.append(idx_item)
            if idx_item == eos_id:
                break

            input_ids = torch.cat((input_ids, idx), dim=-1)
            token_type_ids = torch.cat((token_type_ids, reply_type_id), dim=-1)

    # skip_special_tokens drops the end-of-text token if it was sampled.
    return tokenizer.decode(output_ids, skip_special_tokens=True).strip()

In [26]:
# Try the model on a sample prompt. The reply is sampled (torch.multinomial inside
# generate), so it can differ between runs unless the random state is fixed.
generate('tell me about yourself ?')

tensor([[50257, 50258, 33331,   502,   546,  3511,  5633, 50259]],
       device='cuda:0')
tensor([[50258, 50258, 50258, 50258, 50258, 50258, 50258, 50259]],
       device='cuda:0')
8


'my bad today is just chilling for now'