<a href="https://colab.research.google.com/github/AnDDoanf/LLM-repo/blob/master/smartLLMChain_mistral7b_llama2_langchain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install langchain langchain_community transformers langchain-huggingface bitsandbytes accelerate

In [2]:
# Use a pipeline as a high-level helper
from transformers import pipeline, AutoTokenizer
from langchain_huggingface import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnableSequence
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import transformers
from torch import cuda, bfloat16
import torch

def build_mll(model_name, prompt):
    device = torch.device('cuda')
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    bnb_config = BitsAndBytesConfig(load_in_4bit=True,
                                    bnb_4bit_use_double_quant=True,
                                    bnb_4bit_quant_type="nf4",
                                    bnb_4bit_compute_dtype=bfloat16,
                                    )

    model = AutoModelForCausalLM.from_pretrained(model_name,
                                                 quantization_config=bnb_config,
                                                )
    text_generation_pipeline = pipeline(model=model,
                                        tokenizer=tokenizer,
                                        task="text-generation",
                                        temperature=0.2,
                                        repetition_penalty=1.1,
                                        return_full_text=True,
                                        do_sample=True,
                                        max_new_tokens=1000,
                                        )

    mll = HuggingFacePipeline(pipeline=text_generation_pipeline)

    return RunnableSequence(prompt|mll)

In [3]:
# Define Constructors

def ideation_constructor(model_names):
    prompt_template = """
    ### [INST]
    Instruction: Provide the information needed to response the request.
    Here is context to help:

    {context}

    ### REQUEST:
    {question}

    [/INST]
    """

    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template=prompt_template,
    )
    ideation_chains = [build_mll(model_name, prompt) for model_name in model_names]
    return ideation_chains

def critique_constructor(model_name):
    critique_template = PromptTemplate.from_template(
        """
        Here are some ideas:

        {ideas}

        Critique each step in both ideas and select the best idea.
        """
    )
    critique_chain = build_mll(model_name, critique_template)
    return critique_chain

def resolve_constructor(model_name):
    resolve_template = PromptTemplate.from_template(
        """
        Here is the best idea:

        {best_idea}

        Improve upon this best idea and provide a final version.
        """
    )
    resolve_sequence = build_mll(model_name, resolve_template)
    return resolve_sequence

class SmartChainLLM():
  def __init__(self):
    self.model_names = ["meta-llama/Llama-2-7b-chat-hf", "mistralai/Mistral-7B-Instruct-v0.3"]

  def construct_chains(self) -> None:
    self.ideation_chains = ideation_constructor(self.model_names)
    self.critique_chain = critique_constructor(self.model_names[0])
    self.resolve_sequence = resolve_constructor(self.model_names[0])

  def ideation_step(self, user_input):
    return [chain.invoke(user_input) for chain in self.ideation_chains]

  def critique_step(self, ideas):
      return self.critique_chain.invoke({'ideas':ideas})

  def resolve_step(self, best_idea):
      return self.resolve_sequence.invoke({'best_idea':best_idea})

  def run(self):
    with torch.no_grad():
      while True:
        question = input("Enter your request: ")
        if question == "exit" or question == "quit":
          break
        context = input("Enter your context to help the bot: ")
        user_input = {'question':question, 'context':context}
        ideas = self.ideation_step(user_input)
        print(ideas)
        best_idea = self.critique_step(ideas)
        print(best_idea)
        final_output = self.resolve_step(best_idea)
        print(final_output)

In [4]:
smart_chain = SmartChainLLM()
smart_chain.construct_chains()
# smart_chain.run()

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
smart_chain.run()

Enter your request: How to perform best on internship?
Enter your context to help the bot: I'm doing internship on NLP and LLM
["\n    ### [INST]\n    Instruction: Provide the information needed to response the request.\n    Here is context to help:\n\n    I'm doing internship on NLP and LLM\n\n    ### REQUEST:\n    How to perform best on internship?\n\n    [/INST]\n    1. Set clear goals for yourself: Before starting your internship, define what you want to achieve during this experience. This could be related to specific skills or knowledge you want to gain, projects you want to work on, or even networking opportunities you want to take advantage of. Having clear goals will help you stay focused and motivated throughout the internship.\n2. Be proactive: Don't wait for tasks to come to you - actively seek out opportunities to learn and contribute. Show enthusiasm and initiative by offering to help with projects or tasks that align with your goals.\n3. Learn from mistakes: Recognize th