<h2>Importing Libraries</h2>

In [1]:
import os
import warnings

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, GenerationConfig
from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel, prepare_model_for_kbit_training
from trl import SFTTrainer
# Suppress every warning raised from this point on, to keep notebook output clean.
# NOTE(review): a blanket "ignore" also hides diagnostics that matter (e.g. library warnings
# about generation settings that have no effect); consider filtering specific
# categories/modules instead.
warnings.filterwarnings("ignore")

  torch.utils._pytree._register_pytree_node(


<h2>Loading the Original & PEFT Models</h2>

In [2]:
# Base (not fine-tuned) model checkpoint on the Hugging Face Hub, stored as bf16 shards.
model_name = "ybelkada/falcon-7b-sharded-bf16"

# 4-bit quantization: NF4 weight format, double (nested) quantization of the quantization
# constants, and float16 as the dtype used for computation.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# device_map="auto" lets the library place layers on the available device(s);
# trust_remote_code permits custom modelling code shipped with the checkpoint, if any.
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=bnb_config, device_map="auto", trust_remote_code=True
)
# NOTE(review): this base `model`/`tokenizer` pair is not used by the chatbot cells below,
# which only call `peft_model`; skipping this load would free GPU memory.

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Reuse the EOS token for padding (presumably no dedicated pad token is defined — verify).
tokenizer.pad_token = tokenizer.eos_token

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [3]:
# Loading PEFT model
# Directory containing the trained adapter, read by PeftConfig/PeftModel.from_pretrained.
# Set the PEFT_MODEL environment variable to point elsewhere instead of editing this
# machine-specific absolute path.
PEFT_MODEL = os.environ.get(
    "PEFT_MODEL",
    "/teamspace/studios/this_studio/falcon-7b-sharded-bf16-finetuned-treccast",
)

# The adapter config records which base checkpoint the adapter was trained on.
config = PeftConfig.from_pretrained(PEFT_MODEL)
peft_base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    quantization_config=bnb_config,  # same 4-bit settings as the original model cell
    device_map="auto",
    trust_remote_code=True,
)

# Attach the trained adapter weights to the quantized base model.
peft_model = PeftModel.from_pretrained(peft_base_model, PEFT_MODEL)

# Load the tokenizer the same way as the original model's tokenizer (trust_remote_code=True).
peft_tokenizer = AutoTokenizer.from_pretrained(
    config.base_model_name_or_path, trust_remote_code=True
)
# Reuse the EOS token for padding, matching the original tokenizer setup.
peft_tokenizer.pad_token = peft_tokenizer.eos_token

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

<h2>Creating a Gradio Chatbot with Our Fine-Tuned Model</h2>

In [5]:
import gradio as gr

# Initialize PEFT model and tokenizer
#def init_model_and_tokenizer(PEFT_MODEL):
    #tokenizer = AutoTokenizer.from_pretrained(PEFT_MODEL)
    #model = AutoModelForCausalLM.from_pretrained(PEFT_MODEL, torch_dtype=torch.bfloat16).to('cuda:0' if torch.cuda.is_available() else 'cpu')
    #return model, tokenizer

# Custom LLM chain for PEFT model
class CustomLLM:
    """Minimal "LLM chain": turns a prompt string into a chatbot reply using a causal LM.

    Args:
        peft_model: model exposing a Hugging Face style ``generate(**kwargs)`` method.
        peft_tokenizer: tokenizer matching ``peft_model``; its EOS id is used both as the
            padding id and as the stop token during generation.
    """

    def __init__(self, peft_model, peft_tokenizer):
        self.peft_model = peft_model
        self.peft_tokenizer = peft_tokenizer
        # Device the encoded prompt is moved to before generation.
        # NOTE(review): the model is loaded with device_map="auto"; this assumes its input
        # embeddings live on cuda:0 (or CPU) — confirm for multi-GPU setups.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

    def run(
        self,
        prompt,
        max_new_tokens=256,
        temperature=0.6,
        top_p=0.7,
        repetition_penalty=1.2,
        do_sample=True,
    ):
        """Generate a reply to ``prompt`` and return only the newly generated text.

        Args:
            prompt: full prompt string fed to the model.
            max_new_tokens: upper bound on the number of generated tokens.
            temperature: sampling temperature; only used when ``do_sample`` is True.
            top_p: nucleus-sampling threshold; only used when ``do_sample`` is True.
            repetition_penalty: penalty discouraging repeated tokens.
            do_sample: sample from the model's distribution (True) or decode greedily (False).

        Returns:
            The decoded continuation (prompt excluded), cleaned by ``post_process_chat``.
        """
        encoding = self.peft_tokenizer(prompt, return_tensors="pt").to(self.device)
        output_ids = self.peft_model.generate(
            input_ids=encoding.input_ids,
            # The attention mask is a model input and must be given to generate() itself;
            # stored as a GenerationConfig attribute (as before) it was silently ignored.
            attention_mask=encoding.attention_mask,
            max_new_tokens=max_new_tokens,
            # Without do_sample=True, temperature/top_p are ignored and decoding is greedy.
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            num_return_sequences=1,
            pad_token_id=self.peft_tokenizer.eos_token_id,
            eos_token_id=self.peft_tokenizer.eos_token_id,
        )
        # generate() returns prompt + continuation; drop the echoed prompt tokens so the
        # reply does not repeat "User: ...\nBot:".
        prompt_length = encoding.input_ids.shape[-1]
        reply = self.peft_tokenizer.decode(
            output_ids[0][prompt_length:], skip_special_tokens=True
        )
        return post_process_chat(reply)


# Post-processing function for chatbot response
def post_process_chat(bot_message):
    """Return ``bot_message`` with leading and trailing whitespace removed."""
    return bot_message.strip()

# Function to handle chat interaction and maintain history
def chat_with_history(query, chat_history, llm=None):
    """Answer ``query`` with the fine-tuned model and append the exchange to the history.

    Args:
        query: the user's message.
        chat_history: list of ``(user_message, bot_message)`` pairs, the format a
            ``gr.Chatbot`` renders; ``None`` is treated as an empty history.
        llm: object with a ``run(prompt) -> str`` method; defaults to the notebook-level
            ``llm_chain``.

    Returns:
        ``(history, history)``: the updated history twice, kept for callers that unpack
        two values.
    """
    if llm is None:
        llm = llm_chain
    history = list(chat_history or [])
    # Only the current query is sent to the model; earlier turns are kept for display.
    prompt = f"User: {query}\nBot:"
    bot_message = llm.run(prompt)
    # One tuple per exchange: gr.Chatbot shows element 0 as the user bubble and
    # element 1 as the bot bubble (the old ("User", q), ("Bot", a) rows rendered wrongly).
    history.append((query, bot_message))
    return history, history


# Initialize LLM chain
llm_chain = CustomLLM(peft_model, peft_tokenizer)

# Initialize Gradio interface
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    query = gr.Textbox(label="Type your query here", placeholder="Enter your query...")
    clear = gr.Button("Clear Chat History")

    def respond(message, history):
        """Generate a reply to ``message``, clear the textbox and refresh the chat window.

        Returns:
            ``("", updated_history)``: an empty string for the textbox and the new
            chat history for the chatbot.
        """
        # The Chatbot component passes its current value here; guard against None.
        updated_history, _ = chat_with_history(message, history or [])
        return "", updated_history

    # Outputs: the textbox (cleared) and the chatbot (updated). The history lives in the
    # Chatbot component itself, so no separate notebook-level list is needed.
    query.submit(respond, inputs=[query, chatbot], outputs=[query, chatbot])
    clear.click(lambda: [], None, chatbot, queue=False)

demo.launch()


Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


