In [1]:
import gradio as gr
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer,
T5Tokenizer,
T5ForConditionalGeneration,
DataCollatorWithPadding)

from peft import (get_peft_config, 
get_peft_model, 
PromptTuningInit, 
PromptTuningConfig, 
TaskType, 
PeftType,
PeftModel,
PeftConfig)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Directory holding the trained PEFT adapter; its adapter config records which
# base checkpoint the adapter was trained on.
pre_trained_model_path = "pre-trained-model"
config = PeftConfig.from_pretrained(pre_trained_model_path)

# Take the base checkpoint from the adapter config instead of hard-coding it, so the
# tokenizer and the model weights always come from the same checkpoint (a hard-coded
# name could silently drift from the one the adapter was trained against).
base_model = config.base_model_name_or_path
tokenizer = T5Tokenizer.from_pretrained(base_model)

# Load the base seq2seq model, then attach the trained PEFT adapter on top of it.
model = T5ForConditionalGeneration.from_pretrained(base_model)
model = PeftModel.from_pretrained(model, pre_trained_model_path)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
def predict(message, history, prompt, max_new_tokens=300):
    """Generate the bot's reply for one turn of the Gradio chat.

    Gradio's ``ChatInterface`` calls this with the new user message, the prior
    history, and then the value of each ``additional_inputs`` component (here the
    system-prompt textbox) as ``prompt``.

    Args:
        message: The user's latest chat message.
        history: Earlier turns as ``[user_text, bot_text]`` pairs (pair-list
            history format, as the original indexing assumes — TODO confirm against
            the installed Gradio version). ``None`` is treated as no history.
        prompt: System prompt prepended to the latest message; may be empty.
        max_new_tokens: Upper bound on the number of tokens generated for the reply.

    Returns:
        The decoded reply text with special tokens removed.
    """
    # Separate the system prompt from the user's text with a space. Plain
    # concatenation glued the prompt's final "." onto the first word of the message.
    prompt = (prompt or "").strip()
    if prompt:
        message = f"{prompt} {message}"

    # Render the conversation as a "<human>/<bot>" transcript that ends with an open
    # "<bot>:" slot for the model to fill. A bot entry can be None (e.g. a pending
    # reply), which would otherwise raise TypeError on string concatenation.
    turns = list(history or []) + [[message, ""]]
    transcript = "".join(
        f"\n<human>:{item[0]}\n<bot>:{item[1] or ''}" for item in turns
    )

    # Forward every tokenizer output (input_ids and attention_mask) to generate.
    # NOTE(review): there is no truncation, so very long conversations keep growing
    # the encoder input.
    model_inputs = tokenizer([transcript], return_tensors="pt")
    response = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(response[0], skip_special_tokens=True)

In [7]:
# Shut down Gradio apps started by earlier runs of this cell. Otherwise every re-run
# leaves the old server alive and the new one is bumped to the next free port.
gr.close_all()

# Editable system prompt. It is listed in `additional_inputs`, so its value reaches
# `predict` as the `prompt` argument.
prompt = gr.Textbox("Let's first prepare relevant information and make a plan. Execute the plan, ensuring accurate numerical calculation and logical consistency. Present the answer step-by-step.", label="System Prompt")

# NOTE(review): retry_btn / undo_btn / clear_btn are Gradio 4.x ChatInterface
# arguments and were removed in Gradio 5 — pin gradio<5 or migrate.
demo = gr.ChatInterface(
    predict,
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(placeholder="Go on, ask me all your math doubts", container=False, scale=7),
    title="Math AId",
    description="Math",
    theme="soft",
    retry_btn=None,
    undo_btn="Delete Previous",
    clear_btn="Clear",
    additional_inputs=[prompt],
)
# Keep a handle (`demo`) so the app can also be closed individually via demo.close().
demo.launch(share=False)

Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.




In [5]:
# Shut down every Gradio app launched from this kernel (such as the chat UI above) so
# its local server stops.
gr.close_all()