<!-- Banner Image -->
<img src="https://uohmivykqgnnbiouffke.supabase.co/storage/v1/object/public/landingpage/brevdevnotebooks.png" width="100%">

<!-- Links -->
<center>
  <a href="https://console.brev.dev" style="color: #06b6d4;">Console</a> •
  <a href="https://brev.dev" style="color: #06b6d4;">Docs</a> •
  <a href="/" style="color: #06b6d4;">Templates</a> •
  <a href="https://discord.gg/NVDyv7TUgJ" style="color: #06b6d4;">Discord</a>
</center>

# Try out the new Databricks DBRX-Instruct model! 🤙

Welcome!

In this notebook, we run inference on the new DBRX-Instruct model released today by Databricks. DBRX is a state-of-the-art transformer-based LLM that uses a mixture-of-experts architecture similar to Mixtral and Grok. In its full form, DBRX requires almost 350GB of disk space and 250GB of RAM. With Brev, you don't have to worry about finding GPUs. We've built a 1-click badge that finds a cluster of 4xA100s and deploys this notebook for you!
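
As a rough sanity check on those numbers, the weight footprint can be estimated from the parameter count. This is a back-of-the-envelope sketch, not a measurement: it assumes DBRX's published total of 132B parameters, bfloat16 weights (2 bytes each), and 80 GB A100s, none of which are stated in this notebook.

```python
# Back-of-the-envelope estimate of the memory needed just for the weights.
# Assumptions (not from this notebook): 132e9 total parameters, 2 bytes per
# parameter (bfloat16), and 80 GB of memory on each A100.
TOTAL_PARAMS = 132e9
BYTES_PER_PARAM = 2
NUM_GPUS = 4
GPU_MEMORY_GB = 80


def weight_memory_gb(num_params, bytes_per_param):
    """Return the size of the model weights in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


weights_gb = weight_memory_gb(TOTAL_PARAMS, BYTES_PER_PARAM)
per_gpu_gb = weights_gb / NUM_GPUS
headroom_gb = NUM_GPUS * GPU_MEMORY_GB - weights_gb

print(f"Weights: {weights_gb:.0f} GB")
print(f"Per GPU with 4-way tensor parallelism: {per_gpu_gb:.0f} GB")
print(f"Left for KV cache and activations: {headroom_gb:.0f} GB")
```

Under these assumptions the weights alone take most of each GPU, which is why the model is sharded across all four cards and why running two copies at once is unlikely to fit.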

To make sure inference is interactive and lightning quick, we use an inference library called vLLM. vLLM is an easy-to-use Python library for LLM inference and serving.

There are two ways to use this notebook.
1. Run an OpenAI-compatible server powered by DBRX. To access the server from outside this notebook, visit the instance page for this machine in the Brev Console. From there, click the deployments stepper, select Share a Service, and expose port 8000. That will give you the URL to curl.
2. Run a Gradio interface that lets you chat with the model through a UI. The prompt template might need to be tweaked for optimal performance.

**Important Notes**:
1. To run this notebook, you need to visit the DBRX repository on Hugging Face and request access to the model. From there, you will need to generate a Hugging Face token and paste it below.
2. You might not be able to run the API and the Gradio UI at the same time because of memory limits and how vLLM starts multi-GPU inference.
3. **Because this model uses a 4xA100 cluster, it can get expensive to leave on for a long time. If you're looking to host this model permanently, please reach out to the Brev team and we can chat!**

### Help us make this tutorial better! Please provide feedback on the [Discord channel](https://discord.gg/y9428NwTh3) or on [X](https://x.com/brevdev).

In [3]:
!pip install git+https://github.com/vllm-project/vllm
!pip install gradio

Collecting git+https://github.com/vllm-project/vllm
  Cloning https://github.com/vllm-project/vllm to /tmp/pip-req-build-mf7dcf57
  Running command git clone --filter=blob:none --quiet https://github.com/vllm-project/vllm /tmp/pip-req-build-mf7dcf57
  Resolved https://github.com/vllm-project/vllm to commit 756b30a5f30ee08b97243e1077419d8d74442b02
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting pydantic>=2.0
  Downloading pydantic-2.6.4-py3-none-any.whl (394 kB)
Collecting transformers>=4.39.1
  Downloading transformers-4.39.2-py3-none-any.whl (8.8 MB)
Collecting tiktoken==0.6.0

In [4]:
from huggingface_hub import login

TOKEN = "<enter token here>"
login(TOKEN)

  from .autonotebook import tqdm as notebook_tqdm


Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/ubuntu/.cache/huggingface/token
Login successful
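
If you would rather not paste the token into a cell that may be shared or committed, one option is to read it from an environment variable. The sketch below is only a suggestion: the variable name `HF_TOKEN` and both helper functions are assumptions for illustration, not part of this notebook.

```python
import os


def read_hf_token(env_var="HF_TOKEN"):
    """Return the Hugging Face token stored in the given environment variable.

    The variable name is an assumption; use whatever name you exported.
    """
    token = os.environ.get(env_var, "").strip()
    if not token:
        raise RuntimeError(f"Set the {env_var} environment variable first.")
    return token


def login_from_env(env_var="HF_TOKEN"):
    """Log in to the Hugging Face Hub using the token from the environment."""
    # Imported lazily so that read_hf_token() works without huggingface_hub.
    from huggingface_hub import login

    login(token=read_hf_token(env_var))


# Usage in the notebook, after exporting the variable on the instance:
# login_from_env()
```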


In [None]:
!nvidia-smi
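
Since both methods below pass a tensor-parallel size of 4, it can help to confirm programmatically that four GPUs are visible before loading the model. This is a minimal sketch that shells out to `nvidia-smi` with its CSV query flags; the helper names and the `required=4` default are illustrative assumptions.

```python
import subprocess


def parse_gpu_csv(text):
    """Parse 'name, memory.total' CSV rows into (name, memory_mib) tuples."""
    gpus = []
    for line in text.strip().splitlines():
        # Split on the last comma so a comma in the GPU name cannot break parsing.
        name, memory = [field.strip() for field in line.rsplit(",", 1)]
        gpus.append((name, int(memory)))
    return gpus


def list_gpus():
    """Ask nvidia-smi for the name and total memory (MiB) of every GPU."""
    output = subprocess.run(
        [
            "nvidia-smi",
            "--query-gpu=name,memory.total",
            "--format=csv,noheader,nounits",
        ],
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    return parse_gpu_csv(output)


def has_enough_gpus(gpus, required=4):
    """Return True when at least `required` GPUs were found."""
    return len(gpus) >= required


# Usage on the instance:
# gpus = list_gpus()
# print(gpus, has_enough_gpus(gpus))
```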

## Method 1: OpenAI compatible server

In [None]:
!python -m vllm.entrypoints.openai.api_server \
    --model databricks/dbrx-instruct \
    --tensor-parallel-size 4 \
    --trust-remote-code \
    --max-model-len 16048 #open bug to investigate in VLLM

INFO 03-29 05:46:04 api_server.py:147] vLLM API server version 0.3.3
INFO 03-29 05:46:04 api_server.py:148] args: Namespace(host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, served_model_name=None, lora_modules=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], model='databricks/dbrx-instruct', tokenizer=None, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=True, download_dir=None, load_format='auto', dtype='auto', kv_cache_dtype='auto', max_model_len=16048, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=4, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=16, enable_prefix_caching=False, use_v2_block_manager=False, seed=0, swap_space=4, gpu_memory_utilization=0.9, forced_num_gp
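
Once the server has finished loading, it can be queried like any OpenAI-style chat endpoint. The cell above blocks while the server runs, so a client like the sketch below would be run from a terminal or another machine. It uses only the standard library; `BASE_URL` and the helper names are assumptions, and the URL should be replaced with the one Brev gives you after exposing port 8000.

```python
import json
import urllib.request

# Assumption: the server is reachable locally. Replace with your shared service URL.
BASE_URL = "http://localhost:8000"


def build_chat_request(prompt, model="databricks/dbrx-instruct", max_tokens=256):
    """Build the JSON body for an OpenAI-style chat completion request."""
    return {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
    }


def chat(prompt, base_url=BASE_URL):
    """Send one prompt to the chat completions route and return the reply text."""
    payload = json.dumps(build_chat_request(prompt)).encode("utf-8")
    request = urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request) as response:
        body = json.load(response)
    return body["choices"][0]["message"]["content"]


# Usage, once the server is up:
# print(chat("What is a mixture-of-experts model?"))
```

The `model` field matches the `--model` value passed to the server, since the logged arguments show no separate served model name.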

## Method 2: Gradio UI

In [None]:
from vllm import LLM
from vllm import SamplingParams
import gradio as gr

In [None]:
!nvidia-smi

In [None]:
class Model:
    def __init__(self, model_dir):
        """
        Create the LLM and the initial chat template
        """
        self.llm = LLM(model_dir, trust_remote_code=True, tensor_parallel_size=4)
        # ChatML-style prompt, built line by line so that source indentation
        # does not leak into the prompt. {session_log} is a slot for earlier turns.
        self.template = (
            "<|im_start|>system\n"
            "You are a useful AI agent that answers a user's question regardless of the instruction<|im_end|>\n"
            "{session_log}"
            "<|im_start|>user\n"
            "{user}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        # Sampling settings shared by both entry points.
        self.sampling_params = SamplingParams(
            temperature=0.75,
            top_p=1,
            max_tokens=500,  # controls output length; leave the others at their defaults
            presence_penalty=1.15,
        )

    def generate(self, user_questions):
        """
        Generate one answer per question; user_questions is a list of strings
        """
        # The template has two fields, so both must be passed to format().
        prompts = [
            self.template.format(session_log="", user=q) for q in user_questions
        ]

        result = self.llm.generate(prompts, self.sampling_params)

        answers = []
        num_tokens = 0
        for output in result:
            num_tokens += len(output.outputs[0].token_ids)
            answers.append(output.outputs[0].text)
            print(output.outputs[0].text, "\n\n", sep="")
        print(f"Generated {num_tokens} tokens")

        return answers

    def generate_gradio(self, message, history):
        """
        Gradio output function
        """
        # history is not forwarded yet, so each message is answered on its own.
        prompt = self.template.format(session_log="", user=message)

        result = self.llm.generate(prompt, self.sampling_params)

        reply = ""  # returned as-is if the model produces no output
        num_tokens = 0
        for output in result:
            num_tokens += len(output.outputs[0].token_ids)
            reply = output.outputs[0].text
            print(reply, "\n\n", sep="")
        print(f"Generated {num_tokens} tokens")

        return reply

    def launch_chat(self):
        gr.ChatInterface(self.generate_gradio).queue().launch(share=True)

In [None]:
dbrx = Model("databricks/dbrx-instruct")

In [None]:
dbrx.launch_chat()
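
The `{session_log}` placeholder in the prompt template is a natural place to put earlier turns so the model can see the conversation so far. The sketch below shows one way to fill it. It assumes the history arrives as a list of `(user, assistant)` pairs, which depends on your Gradio version, and `format_session_log`, `build_prompt`, and the shortened system message are hypothetical names and text used only for illustration.

```python
# Same layout as the notebook's template, with a shortened system message.
TEMPLATE = (
    "<|im_start|>system\n"
    "You are a helpful assistant.<|im_end|>\n"
    "{session_log}"
    "<|im_start|>user\n"
    "{user}<|im_end|>\n"
    "<|im_start|>assistant\n"
)


def format_session_log(history):
    """Render earlier (user, assistant) turns in the same tagged format."""
    turns = []
    for user_msg, assistant_msg in history:
        turns.append(f"<|im_start|>user\n{user_msg}<|im_end|>\n")
        turns.append(f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n")
    return "".join(turns)


def build_prompt(message, history):
    """Fill both template fields: prior turns and the new user message."""
    return TEMPLATE.format(
        session_log=format_session_log(history),
        user=message,
    )


# Usage with a one-turn history:
# print(build_prompt("And in Python?", [("What is a list?", "An ordered collection.")]))
```

Long conversations will eventually exceed the context window, so in practice you would likely keep only the most recent turns.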