In [0]:
%pip install databricks-agents transformers mlflow torch
%restart_python

In [0]:
import json
import os
import uuid
import warnings
from typing import Any, Generator, Literal, Optional

import mlflow
import torch
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import ChatAgentMessage, ChatAgentResponse
# NOTE(review): ChatAgent.predict exchanges mlflow.types.agent objects; confirm the
# names below still exist in mlflow.types.llm for the installed MLflow version.
from mlflow.types.llm import (
    ChatChunk,
    ChatMessage,
    ChatResponse,
    ChatContext,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

class Qwen_Pyfunc(ChatAgent):
    """MLflow ``ChatAgent`` that answers chat requests with a Qwen2.5 instruct model.

    ``__init__`` only records configuration. The tokenizer and weights are loaded
    by ``load_context``, which MLflow calls when it loads the logged model. If
    ``load_context`` was never called, they are loaded lazily on the first
    ``predict``. Keeping construction cheap also keeps the instance serialized by
    ``mlflow.pyfunc.log_model`` small.

    Args:
        model_name: Hugging Face Hub id used when no bundled weights are available.
        max_new_tokens: default generation budget per reply; a request can
            override it via ``custom_inputs["max_new_tokens"]``.
    """

    def __init__(
        self,
        model_name: str = "Qwen/Qwen2.5-1.5B-Instruct",
        max_new_tokens: int = 512,
    ):
        self.model_name = model_name
        self.max_new_tokens = max_new_tokens
        # Populated by _load(); None means "not loaded yet".
        self.tokenizer = None
        self.model = None

    def _load(self, source: str) -> None:
        """Load tokenizer and causal-LM weights from ``source`` (Hub id or local dir)."""
        self.tokenizer = AutoTokenizer.from_pretrained(source)
        # dtype selection is a model option; the tokenizer has no use for it.
        model = AutoModelForCausalLM.from_pretrained(source, torch_dtype="auto")
        # Explicit placement instead of device_map="auto": the latter needs the
        # `accelerate` package, which is not among the logged pip requirements.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model.to(device)

    def load_context(self, context):
        """Load weights when MLflow loads the model (or manually with ``context=None``).

        Bundled weights under ``context.artifacts["model"]/qwen`` are used if such
        an artifact exists. Otherwise, including when ``context`` is None for local
        testing, ``self.model_name`` is fetched from the Hugging Face Hub.
        """
        artifacts = getattr(context, "artifacts", None) or {}
        if "model" in artifacts:
            model_dir = os.path.join(artifacts["model"], "qwen")
            try:
                self._load(model_dir)
                return
            except OSError as exc:
                # Keep the best-effort fallback, but make it visible instead of silent.
                warnings.warn(
                    f"Could not load bundled weights from {model_dir!r} ({exc}); "
                    f"falling back to {self.model_name!r} from the Hugging Face Hub."
                )
        self._load(self.model_name)

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Generate a single assistant reply to the conversation in ``messages``.

        Args:
            messages: conversation so far. Each item must support
                ``model_dump_compat``. Dict inputs such as
                ``{"messages": [...]}`` presumably rely on ChatAgent's input
                conversion; confirm against the installed MLflow.
            context: accepted for interface compatibility; unused.
            custom_inputs: optional; ``max_new_tokens`` overrides the instance default.

        Returns:
            ChatAgentResponse holding one assistant message with a fresh UUID id.
        """
        if self.model is None or self.tokenizer is None:
            # Allow direct use without an explicit load_context() call.
            self.load_context(None)

        chat = [m.model_dump_compat(exclude_none=True) for m in messages]
        max_new_tokens = int(
            (custom_inputs or {}).get("max_new_tokens", self.max_new_tokens)
        )

        # Render the conversation with the model's chat template and open an assistant turn.
        prompt = self.tokenizer.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=False,
        )
        model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)

        generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + completion; keep only the newly generated tokens.
        prompt_len = model_inputs.input_ids.shape[1]
        reply = self.tokenizer.batch_decode(
            generated_ids[:, prompt_len:], skip_special_tokens=True
        )[0]

        return ChatAgentResponse(
            messages=[
                ChatAgentMessage(id=str(uuid.uuid4()), role="assistant", content=reply)
            ]
        )

In [0]:
# Smoke-test the agent locally before logging it. input_example is reused as the
# MLflow input example in the logging cell.
input_example = {'messages': [
    {'role': 'user', 'content': 'What color is the sky?'}
]}

model = Qwen_Pyfunc()
# context=None: no MLflow artifacts exist yet, so weights come from the Hugging Face Hub.
model.load_context(context=None)
smoke_response = model.predict(input_example)

# Fail loudly here rather than after the model has been registered and deployed.
smoke_reply = smoke_response.messages[-1]
assert smoke_reply.role == "assistant" and smoke_reply.content, "agent returned no assistant text"
smoke_response

In [0]:
# Log the agent to MLflow and register it in Unity Catalog (Databricks).
import mlflow
import os
from importlib.metadata import version

os.environ["TOKENIZERS_PARALLELISM"] = "false"

mlflow.set_tracking_uri('databricks')
mlflow.set_registry_uri('databricks-uc')

# Pin the serving environment to the versions this notebook was tested with.
# Unpinned requirements let the endpoint build resolve newer, possibly
# incompatible releases. Local version labels (e.g. "2.3.1+cu121" for torch)
# are stripped because PyPI cannot resolve them.
pip_requirements = [
    f"{pkg}=={version(pkg).split('+')[0]}"
    for pkg in ("transformers", "mlflow", "torch")
]

with mlflow.start_run():
    # Keep the returned ModelInfo so the registered model's details are inspectable.
    model_info = mlflow.pyfunc.log_model(
        python_model=Qwen_Pyfunc(),
        artifact_path="model",
        registered_model_name='shm.default.qwen25-1p5b-instruct',
        input_example=input_example,
        pip_requirements=pip_requirements,
    )

In [0]:

from mlflow.deployments import get_deploy_client

# Serve the registered Qwen agent on a Databricks Model Serving endpoint.
# NOTE(review): entity_version is pinned by hand; bump it after re-registering the model.
# NOTE(review): create_endpoint presumably errors if the endpoint already exists,
# so this cell is likely not safe to re-run as-is; confirm.
ENDPOINT_NAME = "shm_qwen_a100"

served_entity = dict(
    name=ENDPOINT_NAME,
    entity_name="shm.default.qwen25-1p5b-instruct",
    entity_version="4",
    workload_type="GPU_LARGE",
    workload_size="Small",
    scale_to_zero_enabled=True,
)

client = get_deploy_client("databricks")
endpoint = client.create_endpoint(
    name=ENDPOINT_NAME,
    config={"served_entities": [served_entity]},
)