<a href="https://colab.research.google.com/github/ThomasWarford/falcon_chatbot/blob/experimental/falcon_4bit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip -q install transformers accelerate langchain einops bitsandbytes accelerate gradio

In [None]:
import torch

from typing import List # Used in StopGenerationCriteria class

from transformers import (
    BitsAndBytesConfig, # for setting 4-bit precision
    AutoModelForCausalLM, # loading model
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    pipeline
)

In [None]:

# bitsandbytes settings passed to `from_pretrained` through its
# `quantization_config` argument when the model is loaded.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # normal-float 4-bit
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

In [38]:
# Hub id of the checkpoint used for both the model and its tokenizer.
model_id = "vilsonrodrigues/falcon-7b-instruct-sharded"

# Load the causal LM with the 4-bit quantization settings defined earlier.
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", quantization_config=quantization_config
)

# The tokenizer comes from the same checkpoint so its ids match the model.
tokenizer = AutoTokenizer.from_pretrained(model_id)

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

In [None]:
# from https://www.mlexpert.io/prompt-engineering/chatbot-with-local-llm-using-langchain#conversation-chain

class StopGenerationCriteria(StoppingCriteria):
    """Stop generation once the sequence ends with one of the stop sequences.

    Args:
        tokens: Stop sequences, each given as a list of token strings
            (e.g. ``["User", ":"]``).
        tokenizer: Tokenizer used to convert every token string to its id.
        device: Device on which the stop-sequence id tensors are created;
            it has to match the device of the ``input_ids`` seen in
            ``__call__`` for the comparison to work.
    """

    def __init__(
        self, tokens: List[List[str]], tokenizer: AutoTokenizer, device: torch.device
    ):
        # One 1-D LongTensor of token ids per stop sequence.
        self.stop_token_ids = [
            torch.tensor(
                tokenizer.convert_tokens_to_ids(sequence),
                dtype=torch.long,
                device=device,
            )
            for sequence in tokens
        ]

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        """Return True when the first sequence ends with any stop sequence.

        Only ``input_ids[0]`` is inspected, so the check covers the first
        sequence of the batch; ``scores`` and ``kwargs`` are unused.
        """
        sequence = input_ids[0]
        for stop_ids in self.stop_token_ids:
            stop_len = len(stop_ids)
            # An empty stop sequence, or one longer than what has been produced
            # so far, can never match. Without this guard the slice below has
            # the wrong length and `torch.eq` either raises or broadcasts.
            if stop_len == 0 or sequence.shape[0] < stop_len:
                continue
            if torch.eq(sequence[-stop_len:], stop_ids).all():
                return True
        return False

# Speaker labels that mark the start of the next turn; generation halts as
# soon as the output ends with either pair of tokens.
stop_tokens = [
    ["User", ":"],
    ["AI", ":"],
]

# Create the stop-id tensors on the model's device.
stopping_criteria = StoppingCriteriaList(
    [
        StopGenerationCriteria(stop_tokens, tokenizer, model_4bit.device),
    ]
)

In [53]:
# Text-generation pipeline around the quantized model; LangChain drives it
# through the `HuggingFacePipeline` wrapper created below.
llm_pipeline = pipeline(
        "text-generation",
        model=model_4bit,
        tokenizer=tokenizer,
        use_cache=True,
        device_map="auto",
        # Limit only the newly generated tokens. `max_length` also counts the
        # prompt, and the prompt contains the whole conversation history, so a
        # fixed `max_length` leaves less and less room for the reply each turn
        # and eventually none at all.
        max_new_tokens=256,
        do_sample=True,
        top_k=1, # always take most likely token (for repeatability)
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        # Reuse EOS as the padding token id.
        pad_token_id=tokenizer.eos_token_id,
        stopping_criteria=stopping_criteria,
)

from langchain.llms import HuggingFacePipeline

# LangChain-compatible LLM backed by the pipeline above.
llm = HuggingFacePipeline(pipeline=llm_pipeline)

In [75]:
from langchain.schema import BaseOutputParser
import regex as re

class CleanupOutputParser(BaseOutputParser):
    """Output parser that removes speaker labels left in the generated text."""

    def parse(self, text: str) -> str:
        """Return ``text`` without newline-prefixed speaker labels, stripped.

        The labels are removed one after another, in the order User, Human,
        AI, and surrounding whitespace is trimmed from the result.
        """
        cleaned = text
        for label_pattern in (r"\nUser:", r"\nHuman:", r"\nAI:"):
            cleaned = re.sub(label_pattern, "", cleaned)
        return cleaned.strip()

    @property
    def _type(self) -> str:
        """Identifier string for this parser type."""
        return "output_parser"

In [76]:
from langchain import PromptTemplate
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory


# General-purpose chat prompt. It is not used by the chain below; kept as an
# alternative to the tech-support prompt.
chatbot_template = """The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.

Current conversation:
{history}
Human: {input}
AI:"""

# Prompt used by the chain: a help bot for students connecting to the internet.
# Turns are labelled "User:" and "AI:".
tech_support_template = """The following is a friendly conversation between a user and an AI helpbot. The AI is tasked with helping the user, who lives in halls at university, connect to the internet. If the AI needs more information to solve the problem, it will ask the user.

Current conversation:
{history}
User: {input}
AI:"""

PROMPT = PromptTemplate(input_variables=["history", "input"], template=tech_support_template)
CONVERSATION = ConversationChain(
    prompt=PROMPT,
    llm=llm,
    output_parser=CleanupOutputParser(),
    verbose=False,
    # The memory renders each stored turn as "<prefix>: <content>", so the
    # prefixes must not carry their own colon ("AI:" would show up as "AI::").
    # They also have to match the labels used in `tech_support_template`,
    # otherwise {history} says "Human:" while the prompt says "User:".
    memory=ConversationBufferMemory(human_prefix="User", ai_prefix="AI"),
)

In [None]:
# CONVERSATION.memory.dict()["chat_memory"]["messages"][-1]["content"]

In [None]:
# a = CONVERSATION.memory.chat_memory.messages.pop()
# a

In [78]:
import gradio as gr

def get_response(message, history):
    """Answer ``message`` after re-syncing the chain's memory with the chat UI.

    The UI can drop turns (undo) or wipe the transcript (clear) without the
    chain's memory knowing, so the memory is reconciled first.

    Args:
        message: The new user message.
        history: The transcript currently shown in the UI, as a sequence of
            ``[user_message, bot_reply]`` pairs (empty for a fresh or cleared
            chat).

    Returns:
        The chain's reply to ``message``.
    """
    if not history:
        # Nothing on screen: new session or the clear button has been pressed.
        CONVERSATION.memory.clear()
    else:
        last_reply = history[-1][1]
        messages = CONVERSATION.memory.chat_memory.messages
        # If the newest reply on screen is not the newest one in memory, undo
        # has been pressed: drop trailing human/AI exchanges until they agree.
        # Stopping once memory is empty keeps an unmatched transcript from
        # raising IndexError.
        while messages and last_reply not in messages[-1].content:
            del messages[-2:]  # remove the AI message and its human message

    return CONVERSATION.predict(input=message)

# Build the chat UI around `get_response`, then start the Gradio server.
demo = gr.ChatInterface(
    get_response,
    chatbot=gr.Chatbot(height=300),
    textbox=gr.Textbox(placeholder="Support chat", container=False, scale=7),
    title="Helpbot",
    description="Enter your issue",
    theme="soft",
    retry_btn=None,  # no retry button
    # undo_btn / clear_btn are left at their defaults.
)
demo.launch()

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Note: opening Chrome Inspector may crash demo inside Colab notebooks.

To create a public link, set `share=True` in `launch()`.


<IPython.core.display.Javascript object>

