In [1]:
import mlflow
import transformers

class MyModel(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper around a fine-tuned, 4-bit-quantized causal LM.

    The merged checkpoint and tokenizer are read at load time from
    ``<ARTIFACTS_ROOT>/<project_id>/version0/<FT_MODEL_NAME>``; the
    ``context.artifacts`` mapping passed by MLflow is not used.
    """

    # Directory whose first (sorted) sub-directory is taken as the project id.
    ARTIFACTS_ROOT = "/artifacts/mlflow"
    # Sub-directory (under <project>/version0) holding the merged model + tokenizer.
    FT_MODEL_NAME = "final_merged_checkpoint"
    # Sub-directory (under <project>/version0) used as the Hugging Face cache_dir.
    MODEL_CACHE = "llama2-model-cache"
    # Default cap on generated tokens per prompt.
    MAX_NEW_TOKENS = 200
    # Returned (as a plain string) when no usable prompt is supplied.
    MISSING_PROMPT_MESSAGE = 'Please provide a prompt.'

    def load_context(self, context):
        """Load the 4-bit-quantized model and its tokenizer from the artifacts directory.

        Args:
            context: MLflow ``PythonModelContext``; not read by this method.

        Raises:
            FileNotFoundError: if ``ARTIFACTS_ROOT`` contains no project directory.
        """
        import os
        import torch
        from transformers import (
            AutoModelForCausalLM,
            AutoTokenizer,
            BitsAndBytesConfig
        )

        # os.listdir order is arbitrary; sort so the chosen project is deterministic.
        project_dirs = sorted(os.listdir(self.ARTIFACTS_ROOT))
        if not project_dirs:
            raise FileNotFoundError(f"No project directory found under {self.ARTIFACTS_ROOT}")
        self.project_id = project_dirs[0]
        project_root = f"{self.ARTIFACTS_ROOT}/{self.project_id}/version0"
        model_tokenizer_path = f"{project_root}/{self.FT_MODEL_NAME}"
        cache_dir = f"{project_root}/{self.MODEL_CACHE}/"

        # NF4 4-bit weights with float16 compute.
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False,
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_tokenizer_path,
            cache_dir=cache_dir,
            quantization_config=quant_config,
            device_map="auto",
        )
        self.model.config.use_cache = False
        self.model.config.pretraining_tp = 1

        # NOTE(review): trust_remote_code executes code shipped with the checkpoint;
        # acceptable only because the path is a local artifact — confirm it is trusted.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_tokenizer_path,
            cache_dir=cache_dir,
            trust_remote_code=True,
        )
        # Reuse EOS as the pad token so generate() has a pad_token_id
        # (presumably the checkpoint ships without a dedicated pad token).
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

    def predict(self, context, model_input, params=None):
        """Generate a completion for each prompt in ``model_input``.

        Args:
            context: MLflow ``PythonModelContext``; unused.
            model_input: Anything indexable by ``"prompt"`` — typically a pandas
                DataFrame for this column-based signature, or a dict.
            params: Optional dict; ``"max_new_tokens"`` overrides ``MAX_NEW_TOKENS``.
                NOTE: the logged signature declares no params, so serving requests
                only reach this if the ParamSchema is extended.

        Returns:
            ``{'text_from_llm': str}`` for a single prompt,
            ``{'text_from_llm': [str, ...]}`` for several, or the plain string
            ``'Please provide a prompt.'`` when any prompt is missing or blank.
        """
        prompts = self._extract_prompts(model_input)
        if not prompts or any(p is None for p in prompts):
            return self.MISSING_PROMPT_MESSAGE

        max_new_tokens = (params or {}).get("max_new_tokens", self.MAX_NEW_TOKENS)
        results = [self._generate(p, max_new_tokens) for p in prompts]
        return {'text_from_llm': results[0] if len(results) == 1 else results}

    @staticmethod
    def _extract_prompts(model_input):
        """Normalize ``model_input["prompt"]`` to a list; None/NaN/blank entries become None."""
        raw = model_input["prompt"]
        # A DataFrame column is a Series: flatten it instead of formatting its repr.
        if hasattr(raw, "tolist"):
            raw = raw.tolist()
        items = list(raw) if isinstance(raw, (list, tuple)) else [raw]

        prompts = []
        for item in items:
            # NaN is the only value unequal to itself; pandas uses it for missing cells.
            missing = item is None or (isinstance(item, float) and item != item)
            if missing or not str(item).strip():
                prompts.append(None)
            else:
                prompts.append(str(item))
        return prompts

    def _generate(self, prompt, max_new_tokens):
        """Generate a completion for one prompt and return only the newly generated text."""
        text = f"<s>[INST] {prompt} [/INST]"

        # Use the model's own device (that of its first parameter) instead of
        # assuming "cuda:0"; device_map="auto" decides the placement.
        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)

        generation_config = transformers.GenerationConfig(
            pad_token_id=self.tokenizer.pad_token_id,
            max_new_tokens=max_new_tokens,
        )
        outputs = self.model.generate(**inputs, generation_config=generation_config)

        # Decode only the tokens after the prompt rather than string-replacing the
        # echoed prompt, which fails whenever decoding doesn't reproduce it exactly.
        prompt_length = inputs["input_ids"].shape[1]
        return self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)


In [2]:
import pandas as pd
import numpy as np
import mlflow
from mlflow.models.signature import ModelSignature
from mlflow.types import DataType, Schema, ColSpec, ParamSchema, ParamSpec

# Define input and output schema
# Model signature: one string column "prompt" in, one string column
# "text_from_llm" out, and no inference-time parameters.
input_schema = Schema([ColSpec(DataType.string, "prompt")])
output_schema = Schema([ColSpec(DataType.string, "text_from_llm")])
parameters = ParamSchema([])

signature = ModelSignature(
    inputs=input_schema,
    outputs=output_schema,
    params=parameters,
)

# Example payload passed as `input_example` when the model is logged.
input_example = pd.DataFrame({"prompt": ["What is machine learning?"]})

In [6]:
from mlflow.exceptions import MlflowException

# Get-or-create the registered model that new versions are attached to.
client = mlflow.MlflowClient()
model_name = "llama2-guanaco-sft-bryan"
registered_model = None
try:
    registered_model = client.create_registered_model(model_name)
except MlflowException as create_err:
    # Usually the model already exists, so fetch it instead. If that also fails,
    # surface the original error rather than masking it (the old bare `except:`
    # also swallowed auth/network failures and KeyboardInterrupt).
    try:
        registered_model = client.get_registered_model(model_name)
    except MlflowException as get_err:
        raise create_err from get_err

In [7]:
import mlflow
from IPython.display import display

# Training run whose params/metrics are copied onto the registration run below.
run_id = "6d55536a32874c708ed1743d72c92927"
r = mlflow.get_run(run_id)
# Only a cell's last expression is rendered, so show the params explicitly;
# the metrics remain the cell's output.
display(r.data.params)
r.data.metrics

{'loss': 1.3946,
 'learning_rate': 0.0002,
 'epoch': 1.0,
 'train_runtime': 2060.1701,
 'train_samples_per_second': 0.485,
 'train_steps_per_second': 0.061,
 'total_flos': 8971066649149440.0,
 'train_loss': 1.3884320373535157}

In [8]:
import peft
import trl
import torch
import transformers


from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository

# Get the current base version of torch that is installed, without specific version modifiers
from importlib.metadata import PackageNotFoundError, version as dist_version

# Base version of the installed torch, without the local build suffix (e.g. "+cu118").
torch_version = torch.__version__.split("+")[0]


def _pinned_requirement(dist_name):
    """Return ``"<dist_name>==<installed version>"``, or the bare name if not installed here.

    Used for packages MyModel needs at serving time that this kernel may not import.
    """
    try:
        return f"{dist_name}=={dist_version(dist_name)}"
    except PackageNotFoundError:
        return dist_name


# Start a new MLflow run that copies the training run's params/metrics and logs the
# MyModel wrapper with its signature and input example. The signature's ParamSchema
# is empty, so no parameters can be overridden at inference time.
with mlflow.start_run() as run:
    mlflow.log_params(r.data.params)
    mlflow.log_metrics(r.data.metrics)

    model_info = mlflow.pyfunc.log_model(
        model_name,
        python_model=MyModel(),
        # Bundles /mnt/code with the model under the key "snapshot".
        # NOTE(review): MyModel.load_context never reads context.artifacts — it loads the
        # checkpoint from /artifacts/mlflow — so this mapping is not what makes loading work.
        artifacts={"snapshot": '/mnt/code'},
        pip_requirements=[
            f"torch=={torch_version}",
            f"transformers=={transformers.__version__}",
            f"peft=={peft.__version__}",
            f"trl=={trl.__version__}",
            # Needed by load_context: the 4-bit BitsAndBytesConfig requires bitsandbytes,
            # and from_pretrained(device_map="auto") requires accelerate.
            _pinned_requirement("bitsandbytes"),
            _pinned_requirement("accelerate"),
            "einops",
            "sentencepiece",
        ],
        input_example=input_example,
        signature=signature,
    )
    runs_uri = model_info.model_uri
    print(runs_uri)

    # Register the logged model as a new version of `model_name`.
    model_src = RunsArtifactRepository.get_underlying_uri(runs_uri)
    mv = client.create_model_version(model_name, model_src, run.info.run_id)
    print("Name: {}".format(mv.name))
    print("Version: {}".format(mv.version))
    print("Description: {}".format(mv.description))
    print("Status: {}".format(mv.status))
    print("Stage: {}".format(mv.current_stage))

Downloading artifacts:   0%|          | 0/25 [00:00<?, ?it/s]

2024/11/28 07:41:02 INFO mlflow.store.artifact.artifact_repo: The progress bar can be disabled by setting the environment variable MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR to false
2024/11/28 07:41:05 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation. Model name: llama2-guanaco-sft-bryan, version 1


runs:/defb75b81e63454dbf7f72630bbd073a/llama2-guanaco-sft-bryan
Name: llama2-guanaco-sft-bryan
Version: 1
Description: 
Status: READY
Stage: None


In [9]:
import os
# Mirror MyModel.load_context's path resolution so the checkpoint location can be inspected.
# sorted() makes the project choice deterministic if /artifacts/mlflow has several
# entries (os.listdir returns them in arbitrary order).
project_dirs = sorted(os.listdir('/artifacts/mlflow'))
if not project_dirs:
    raise FileNotFoundError("No project directory found under /artifacts/mlflow")
project_id = project_dirs[0]
prefix = f"{project_id}/version0"
ft_model_name = "final_merged_checkpoint"
model_cache = "llama2-model-cache"
model_tokenizer_path = f"/artifacts/mlflow/{prefix}/{ft_model_name}"

In [10]:
# Display the resolved checkpoint/tokenizer directory (the cell's only output).
model_tokenizer_path

'/artifacts/mlflow/66e3174c384e5c755cbae633/version0/final_merged_checkpoint'