In [2]:
from typing import Any, Dict, List

import mlflow, shutil, os

class MLFlowModel(mlflow.pyfunc.PythonModel):
    """MLflow ``pyfunc`` model that wraps a Hugging Face ``transformers`` pipeline."""

    def load_context(self, context: Any):
        """Build the ``transformers`` pipeline and store it on ``self.pipeline``.

        Args:
            context: supplied by the caller; not read by this method
        """

        # Heavy dependencies are imported here rather than at module level,
        # so they are only needed once the model context is loaded.
        import torch
        from transformers import pipeline

        # Loading options (fixed for this deployment).
        model_name = "h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2"
        load_in_8bit = False
        load_sharded = False
        use_fast = False

        model_kwargs = {}
        if load_in_8bit:
            from transformers import BitsAndBytesConfig

            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=3.0,
            )

        # Either let the library place the weights ("auto") or pin the whole
        # model to the first CUDA device.
        device_map = "auto" if load_sharded else {"": "cuda:0"}

        # NOTE(review): trust_remote_code=True allows code shipped with the
        # model repository to be executed while loading - confirm this is intended.
        self.pipeline = pipeline(
            model=model_name,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            use_fast=use_fast,
            device_map=device_map,
            model_kwargs=model_kwargs,
        )

    def predict(self, context: Any, model_input: Dict) -> List:
        """Run the pipeline on the ``input`` entries and return JSON.

        Args:
            context: instance containing artifacts that the model can use
                     to perform inference (not read by this method)
            model_input: mapping whose ``items()`` yields (name, column)
                         pairs where each column exposes ``.values``;
                         needs to contain an 'input' entry

        Returns:
            Single element list holding a JSON string. The JSON object maps
            every key of the first result dict of each pipeline output to
            the list of values collected across all outputs.
        """

        import json
        from collections import defaultdict

        # Flatten every column into a plain Python list keyed by its name.
        columns = defaultdict(list)
        for name, column in model_input.items():
            columns[name].extend(column.values)

        # NOTE(review): with do_sample=False the temperature setting may have
        # no effect on generation - confirm against the transformers docs.
        generation_kwargs = {
            "min_new_tokens": 2,
            "max_new_tokens": 256,
            "do_sample": False,
            "num_beams": 2,
            "temperature": 0.3,
            "repetition_penalty": 1.2,
            "renormalize_logits": True,
        }
        outputs = self.pipeline(columns["input"], **generation_kwargs)

        # Only the first result of each output entry is kept.
        collected = defaultdict(list)
        for result in outputs:
            for field, value in result[0].items():
                collected[field].append(value)

        return [json.dumps(collected)]
    
def setup_mlflow(output_path: str) -> str:
    """Save the model in MLflow pyfunc format and create a ZIP archive of it.

    Any existing ``model.mlflow`` directory under ``output_path`` is removed
    first, so repeated calls start from a clean state. The saved directory is
    then archived next to it as ``model.mlflow.zip``.

    Args:
        output_path: directory under which ``model.mlflow`` is written

    Returns:
        Path of the saved model directory (``<output_path>/model.mlflow``).
    """

    mlflow_model = MLFlowModel()

    # Signature: one string column "input" in, one string column "output" out.
    input_schema = mlflow.types.Schema(
        [mlflow.types.ColSpec(name="input", type=mlflow.types.DataType.string)]
    )

    output_schema = mlflow.types.Schema(
        [mlflow.types.ColSpec(name="output", type=mlflow.types.DataType.string)]
    )

    signature = mlflow.models.signature.ModelSignature(
        inputs=input_schema, outputs=output_schema
    )

    model_path = os.path.join(output_path, "model.mlflow")
    # Remove any previous export before saving; a missing directory is not an error.
    shutil.rmtree(model_path, ignore_errors=True)
    mlflow.pyfunc.save_model(
        path=model_path,
        python_model=mlflow_model,
        signature=signature,
        pip_requirements=[
            "torch==2.0.0",
            "transformers==4.28.1",
            "accelerate==0.18.0",
            "sentencepiece==0.1.96",
            "bitsandbytes==0.38.1"
        ]
    )

    # Archive the contents of the saved directory as "<model_path>.zip".
    shutil.make_archive(model_path, "zip", model_path)

    # Callers assign the result of this function, so hand back the model
    # directory instead of implicitly returning None.
    return model_path

# Write model.mlflow and its ZIP archive into the current working directory.
model_path = setup_mlflow(output_path=".")

In [3]:
# Load the model saved under "model.mlflow" through the generic pyfunc loader.
mlflow_model = mlflow.pyfunc.load_model("model.mlflow")

prompts = [
    "How are you?",
    "What is the capital of the US?"
]

# Named ``model_input`` rather than ``input`` so the builtin ``input()`` is
# not shadowed for the rest of the kernel session.
model_input = {
    "input": prompts
}

ret = mlflow_model.predict(model_input)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [4]:

import json

# ``ret[0]`` is a JSON string; print each entry under "generated_text" on its own line.
decoded = json.loads(ret[0])
for r in decoded["generated_text"]:
    print(r)

I'm doing well. How about you?
The capital of the United States of America is Washington, D.C.
