# Japanese Stable LM Demo
<a target="_blank" href="https://colab.research.google.com/github/mkshing/notebooks/blob/main/stabilityai_japanese_stablelm.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

This Colab notebook lets you interact with the following models from the Japanese Stable LM series:

- [Japanese Stable LM Instruct Alpha 7B](https://huggingface.co/stabilityai/japanese-stablelm-instruct-alpha-7b)
- [Japanese Stable LM 3B-4E1T Instruct](https://huggingface.co/stabilityai/japanese-stablelm-3b-4e1t-instruct)
- [Japanese Stable LM Instruct Gamma 7B](https://huggingface.co/stabilityai/japanese-stablelm-instruct-gamma-7b)

In [None]:
# @title **Setup**
# Print the status of the GPU (if any) attached to this runtime.
!nvidia-smi
# Install the packages this notebook depends on. transformers and gradio are
# imported directly by later cells; the remaining packages are presumably
# runtime dependencies of the models/tokenizers -- TODO confirm which are needed.
!pip install transformers sentencepiece gradio ftfy 'accelerate>=0.12.0' 'bitsandbytes>=0.31.5' einops

In [None]:
# @title Login HuggingFace
# Interactive Hugging Face Hub login (asks for an access token).
# NOTE(review): presumably needed because some of the model repositories require
# an authenticated account -- confirm against the model cards.
!huggingface-cli login

In [None]:
# @title Load model
import torch
from transformers import AutoTokenizer, LlamaTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Informational only: report whether a CUDA device is visible to PyTorch.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")

model_id = "stabilityai/japanese-stablelm-3b-4e1t-instruct" # @param ["stabilityai/japanese-stablelm-instruct-alpha-7b", "stabilityai/japanese-stablelm-instruct-gamma-7b", "stabilityai/japanese-stablelm-3b-4e1t-instruct"]

# device_map="auto" delegates weight placement to accelerate, so the model is
# already on the right device(s) when from_pretrained returns.
model_kwargs = {"trust_remote_code": True, "device_map": "auto", "low_cpu_mem_usage": True, "torch_dtype": "auto"}

if "alpha" in model_id:
  # The alpha checkpoint is paired with the NovelAI tokenizer rather than one
  # stored under the model id itself.
  tokenizer = LlamaTokenizer.from_pretrained("novelai/nerdstash-tokenizer-v1", additional_special_tokens=['▁▁'])
else:
  tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
# Do not call .to(device) here: the model was dispatched by device_map="auto",
# and moving it again is redundant on a single device and conflicts with
# accelerate's placement when weights are sharded or offloaded.
model = model.eval()

In [None]:
# @title **Do the Run!**
# @markdown Try Japanese Stable LM in chat-like UI.
# @markdown <br>**Remark:** this is single-turn inference, i.e., previous contexts are ignored.
import gradio as gr


def build_prompt(user_query, inputs="", sep="\n\n### "):
    """Assemble the instruction-style prompt sent to the model.

    The prompt is a fixed Japanese system message followed by labelled
    sections, each introduced by ``sep``: the instruction (``user_query``),
    an optional input section (only when ``inputs`` is truthy), and an empty
    response section left for the model to complete.

    Args:
        user_query: Instruction text placed in the instruction section.
        inputs: Optional extra context placed in the input section.
        sep: String inserted before every section label.

    Returns:
        The complete prompt string.
    """
    header = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"

    # (label, body) pairs in the order they appear in the prompt.
    sections = [("指示", ": \n" + user_query)]
    if inputs:
        sections.append(("入力", ": \n" + inputs))
    sections.append(("応答", ": \n"))

    return header + "".join(sep + label + body for label, body in sections)

@torch.no_grad()
def base_inference_func(prompt, max_new_tokens=128, top_p=0.95, repetition_penalty=1.):
  """Sample a completion for ``prompt`` from the globally loaded model.

  Args:
    prompt: Full prompt text; encoded as-is, without adding special tokens.
    max_new_tokens: Upper bound on the number of newly generated tokens.
    top_p: Nucleus-sampling threshold forwarded to ``generate``.
    repetition_penalty: Repetition penalty forwarded to ``generate``.

  Returns:
    The decoded continuation only (the prompt tokens are sliced off), with
    special tokens skipped and surrounding whitespace stripped.
  """
  # UI sliders can deliver whole numbers as ``int`` (e.g. 2 instead of 2.0),
  # and transformers' repetition-penalty processor rejects non-float values.
  # ``None`` is passed through so generate() can still apply its own default.
  if repetition_penalty is not None:
    repetition_penalty = float(repetition_penalty)

  input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
  output_ids = model.generate(
      input_ids.to(model.device),
      do_sample=True,
      max_new_tokens=max_new_tokens,
      top_p=top_p,
      temperature=1,
      repetition_penalty=repetition_penalty,
  )

  # generate() returns prompt + continuation; keep only the new tokens.
  prompt_length = input_ids.size(1)
  new_token_ids = output_ids.tolist()[0][prompt_length:]
  return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()


def inference_func(message, chat_history, additional_prompt, max_new_tokens=128, top_p=0.95, repetition_penalty=1.):
  """Chat callback: answer ``message`` and record the exchange.

  The prompt is built from ``message`` (instruction) and ``additional_prompt``
  (optional input section); earlier turns in ``chat_history`` are not fed to
  the model.

  Args:
    message: The user's message.
    chat_history: List of ``(user, bot)`` pairs; appended to in place.
    additional_prompt: Optional extra input text for the prompt.
    max_new_tokens: Forwarded to ``base_inference_func``.
    top_p: Forwarded to ``base_inference_func``.
    repetition_penalty: Forwarded to ``base_inference_func``.

  Returns:
    A tuple ``("", chat_history)``: an empty string and the updated history.
  """
  reply = base_inference_func(
      build_prompt(user_query=message, inputs=additional_prompt),
      max_new_tokens,
      top_p,
      repetition_penalty,
  )
  chat_history.append((message, reply))
  return "", chat_history


# Build the Gradio UI: a chat layout for instruct models, otherwise a plain
# prompt -> completion layout. Generation settings live in a collapsed panel.
with gr.Blocks() as demo:
  with gr.Accordion("Configs", open=False):
      if "instruct" in model_id:
        # Passed to inference_func as `additional_prompt`, i.e. the optional
        # input section of the prompt.
        instruction = gr.Textbox(label="instruction",)
      max_new_tokens = gr.Number(value=128, label="max_new_tokens", precision=0)
      top_p = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="top_p")
      # Minimum is 0.1, not 0.0: the repetition penalty must be strictly
      # positive, so 0 would make generation fail.
      repetition_penalty = gr.Slider(0.1, 5.0, value=1.1, step=0.1, label="repetition_penalty")

  if "instruct" in model_id:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")
    # On submit: clear the textbox (first output) and refresh the chat log.
    msg.submit(inference_func, [msg, chatbot, instruction, max_new_tokens, top_p, repetition_penalty], [msg, chatbot])
    # Reset the chat log without going through the queue.
    clear.click(lambda: None, None, chatbot, queue=False)
  else:
    with gr.Row():
      with gr.Column():
        prompt = gr.Textbox(label="prompt")
        # The button caption is the positional `value`, not `label`.
        button = gr.Button("submit")
      with gr.Column():
        out = gr.Textbox(label="generated")
    button.click(base_inference_func, [prompt, max_new_tokens, top_p, repetition_penalty], out)

if __name__ == "__main__":
    # share=True requests a public link; debug/show_error surface errors.
    demo.launch(debug=True, share=True, show_error=True)