## Chatbot using RevDict and Llama 2 Chat Experiments

In [1]:
import string
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time

[nltk_data] Downloading package punkt to /home/CE/julians/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
  from .autonotebook import tqdm as notebook_tqdm


### Load model

In [5]:
id = 6
model_name = 'meta-llama/Llama-2-7b-chat-hf'
access_token = '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'
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=access_token, device_map={'':id})

Loading checkpoint shards: 100%|██████████| 2/2 [00:12<00:00,  6.17s/it]


In [6]:
# check current device
if next(model.parameters()).is_cuda:
    print("Model is on GPU")
    print(f"Model is on Cuda {torch.cuda.current_device()}")
else:
    print("Model is on CPU")


Model is on GPU
Model is on Cuda 0


### Chatbot 

In [8]:
def get_current_prompt(message, chat_history, system_prompt):
    ''' 
    This function creates the current prompt by considering the latest user input (message), the chat history and the system prompt
    '''
    texts = [f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "]

    for user_input, response in chat_history:
        texts.append(f"{user_input.strip()} [/INST] {response.strip()} </s><s> [INST] ")
    
    texts.append(f"{message.strip()} [/INST]")

    current_prompt = " ".join(texts)

    return current_prompt

def get_response(prompt):
    ''' 
    This function returns the llama 2 response based on the current prompt.
    '''
    inputs = tokenizer([prompt], return_tensors='pt').to("cuda:6")
    output = model.generate(**inputs, max_new_tokens=20, temperature=.75, early_stopping=True, 
            )
    chatbot_response = tokenizer.decode(output[:, inputs['input_ids'].shape[-1]:][0], skip_special_tokens=True).strip()

    # cut off last sentence if it doesn't terminate with a punctuation symbol 
    sentences = sent_tokenize(chatbot_response) 
    if len(sentences) > 1 and sentences[-1][-1] not in string.punctuation:
        sentences.pop()
    chatbot_response = ' '.join(sentences)

    return chatbot_response

def main(system_prompt):
    '''
    This function is called to initialize the conversation.
    '''
    history = []

    while True:
        user_input = input()
        print(f"User: {user_input}")
        if user_input in ["exit", "thank you", "bye"]:
            break

        prompt = get_current_prompt(user_input, history, system_prompt)
        response = get_response(prompt)
        print(f"Chatbot: {response}")
        time.sleep(3)
        history = history + [(user_input, response)]

In [9]:
# define a system prompt
# this is just a draft; the target 'instrument' will be replaced by the output of the reverse dictionary. Optionally, also the description is included  
system_prompt = """As an assistance chatbot, your task is to help the user find a word they are looking for.

You know the word the user is looking for. The word is 'instrument'

You help the user guess this word, so you are like a teacher. Never say this word to the user unless they guessed it. Always give short answers."""

In [7]:
# example user input
user_input = "I need help to find the word."
prompt = get_current_prompt(user_input, [], system_prompt)
response = get_response(prompt)
response

"Great! Let me help you find the word you're thinking of."

In [10]:
# start the conversation
# NB: there is some bug within my loop that causes the output not the be printed immediately
main(system_prompt)

User: 
Chatbot: Great! I'm here to help you find the word you're thinking of.
User: I need help to find the word.
Chatbot: Sure, I'd be happy to help!
User: Can you provide synonyms of my target?
Chatbot: Of course! The word you are looking for is "instrument".
User: exit
