## Fine-tuning Mistral-7B in Azure

This Jupyter notebook describes the process of fine-tuning the Mistral-7B model, available in Azure Machine Learning's system registry, on Azure GPU compute, and then deploying the fine-tuned model to a managed online endpoint.

### Step 1: Configuring the Environment

Install the required Python packages (Azure ML SDK, Azure Identity, MLflow and Requests):
```
pip install azure-ai-ml azure-identity mlflow azureml-mlflow requests
```

In [None]:
# Import required packages
import uuid
import time
import requests
from azure.ai.ml import MLClient
from azure.ai.ml.entities import Model
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.entities import Data
from azure.ai.ml.finetuning import (
    FineTuningTaskType,
    create_finetuning_job
)
from azure.identity import (
    DefaultAzureCredential,
    InteractiveBrowserCredential,
)
from azure.ai.ml.entities import (
    ManagedOnlineEndpoint,
    ManagedOnlineDeployment,
    ProbeSettings,
    OnlineRequestSettings,
)

In [None]:
# Set required variable values -- replace every "<...>" placeholder before running.

# Workspace / registry coordinates and the base model to fine-tune
subscription_id = "<your_subscription_id>"
resource_group = "<your_resource_group>"
workspace_name = "<your_workspace_name>"
model_registry = "<your_model_registry_name>"
model_name = "<your_model_name>"

# Data asset names, the local files used to create them if missing, and their version
training_dataset_name = "<your_dataset_name>"
training_dataset_file = "<your_dataset_file_name>"
validation_dataset_name = "<your_validation_dataset_name>"
validation_dataset_file = "<your_validation_dataset_file_name>"
dataset_version = "<your_dataset_version>"

# Fine-tuning job and online endpoint settings
job_name = "<your_job_name>"
job_compute = "<your_job_compute>"
endpoint_name = "<your_endpoint_name>"
endpoint_SKU = "<your_endpoint_SKU>"

# First 8 hex digits of a random UUID; appended to job and output model names
# so that repeated runs do not collide.
guid = uuid.uuid4().hex[:8]

In [None]:
# Authenticate with Default Azure Credentials, or fall back to Interactive Browser Credentials.
# A management-plane token is requested up front so that a non-working default
# credential is detected here (triggering the fallback) instead of on the first
# MLClient call.
try:
    credential = DefaultAzureCredential()
    credential.get_token("https://management.azure.com/.default")
except Exception as ex:  # broad on purpose: any failure should trigger the browser fallback
    # Report why the fallback happened instead of switching credentials silently.
    print(
        f"DefaultAzureCredential failed ({type(ex).__name__}: {ex}); "
        "falling back to InteractiveBrowserCredential."
    )
    credential = InteractiveBrowserCredential()

In [None]:
# Initialise AML workspace client. This client is used for the workspace-scoped
# operations in this notebook: data assets, jobs, models and online endpoints.
workspace_ml_client = MLClient(
    credential=credential,
    subscription_id=subscription_id,
    resource_group_name=resource_group,
    workspace_name=workspace_name,
)

In [None]:
# Initialise AML registry client, scoped to the registry that hosts the base model.
registry_ml_client = MLClient(credential=credential, registry_name=model_registry)

### Step 2: Defining the Source Model

In [None]:
# Retrieve model details from AML Registry: the version labelled "latest" of
# `model_name` becomes the base model for fine-tuning.
model_to_finetune = registry_ml_client.models.get(name=model_name, label="latest")

# Show which exact model/version will be fine-tuned
print(f"Model name: {model_to_finetune.name}")
print(f"Model version: {model_to_finetune.version}")
print(f"Model ID: {model_to_finetune.id}")

In [None]:
# Check supported compute SKUs listed in the model's "finetune-recommended-sku"
# property; the bare name on the last line makes Jupyter display the value.
finetune_recommended_skus = model_to_finetune.properties["finetune-recommended-sku"]
finetune_recommended_skus

### Step 3: Preparing Training and Validation Datasets

In [None]:
# Initialise training dataset: reuse the workspace data asset with this name and
# version if it already exists, otherwise register the local file as a new
# URI_FILE data asset.
try:
    train_data_asset = workspace_ml_client.data.get(
        name=training_dataset_name,
        version=dataset_version,
    )
except Exception:
    # A bare `except:` would also swallow KeyboardInterrupt/SystemExit, making the
    # cell impossible to interrupt cleanly.
    # NOTE(review): catching azure.core.exceptions.ResourceNotFoundError instead
    # would be narrower still, so auth/network errors are not treated as
    # "dataset missing".
    print("Creating dataset..\n")
    train_data = Data(
        path=f"./{training_dataset_file}",
        type=AssetTypes.URI_FILE,
        description="Training dataset",
        name=training_dataset_name,
        version=dataset_version,
    )
    train_data_asset = workspace_ml_client.data.create_or_update(train_data)
else:
    # Only reached when the lookup succeeded.
    print(f"Dataset {training_dataset_name} already exists! Will re-use it.")

In [None]:
# Check training dataset details. The original printed the whole asset object
# under a "Dataset name:" label; print the actual name plus the version and ID
# (the ID is what the fine-tuning job consumes).
print(f"Dataset name: {train_data_asset.name}")
print(f"Dataset version: {train_data_asset.version}")
print(f"Dataset ID: {train_data_asset.id}")

In [None]:
# Initialise validation dataset: reuse the workspace data asset with this name and
# version if it already exists, otherwise register the local file as a new
# URI_FILE data asset.
try:
    val_data_asset = workspace_ml_client.data.get(
        name=validation_dataset_name,
        version=dataset_version,
    )
except Exception:
    # A bare `except:` would also swallow KeyboardInterrupt/SystemExit, making the
    # cell impossible to interrupt cleanly.
    # NOTE(review): catching azure.core.exceptions.ResourceNotFoundError instead
    # would be narrower still, so auth/network errors are not treated as
    # "dataset missing".
    print("Creating dataset..\n")
    val_data = Data(
        path=f"./{validation_dataset_file}",
        type=AssetTypes.URI_FILE,
        description="Validation dataset",
        name=validation_dataset_name,
        version=dataset_version,
    )
    val_data_asset = workspace_ml_client.data.create_or_update(val_data)
else:
    # Only reached when the lookup succeeded.
    print(f"Dataset {validation_dataset_name} already exists! Will re-use it.")

In [None]:
# Check validation dataset details. The original printed the whole asset object
# under a "Dataset name:" label; print the actual name plus the version and ID
# (the ID is what the fine-tuning job consumes).
print(f"Dataset name: {val_data_asset.name}")
print(f"Dataset version: {val_data_asset.version}")
print(f"Dataset ID: {val_data_asset.id}")

### Step 4: Fine-tuning the Model

In [None]:
# Define fine-tuning job: a TEXT_COMPLETION fine-tune of the registry base model on
# the training/validation data assets, run on `job_compute`. The `guid` suffix
# keeps the job name and the output model name prefix unique per run.
job_run_name = f"{job_name}-{guid}"

finetuning_job = create_finetuning_job(
    name=job_run_name,
    display_name=job_run_name,
    experiment_name=f"Finetuning-{model_name}",
    model=model_to_finetune.id,
    task=FineTuningTaskType.TEXT_COMPLETION,
    training_data=train_data_asset.id,
    validation_data=val_data_asset.id,
    output_model_name_prefix=f"{model_name}-finetuned-{guid}",
    compute=job_compute,
    # Optional alternative kept from the original notebook (NOTE(review): confirm
    # its interaction with `compute` in the create_finetuning_job docs):
    # instance_types=["Standard_ND96amsr_A100_v4", "Standard_E4s_v3"],
    hyperparameters={
        # Hyperparameter values are passed as strings.
        "per_device_train_batch_size": "1",
        "learning_rate": "0.00002",
        "num_train_epochs": "1",
    },
)

In [None]:
# Submit fine-tuning job. The job is then fetched back by name; being the cell's
# last expression, its current state is displayed by Jupyter.
created_job = workspace_ml_client.jobs.create_or_update(finetuning_job)

workspace_ml_client.jobs.get(created_job.name)

In [None]:
# Monitor fine-tuning job status: poll every 30 seconds until the job reaches a
# terminal state. If it did not complete, stop here so the following cells do not
# go on to look up and deploy a model that this job never produced.
terminal_statuses = {"Failed", "Completed", "Canceled"}
poll_interval_seconds = 30

while True:
    status = workspace_ml_client.jobs.get(created_job.name).status

    if status in terminal_statuses:
        print("Job has finished with status: {0}".format(status))
        break

    print(f"Job run is in progress. Checking again in {poll_interval_seconds} seconds..")
    time.sleep(poll_interval_seconds)

if status != "Completed":
    raise RuntimeError(
        f"Fine-tuning job {created_job.name} ended with status {status}; "
        "check the job logs before continuing."
    )

In [None]:
# Verify fine-tuning job output: look up the "registered_model" entry in the
# submitted job's outputs and show it.
job_outputs = created_job.outputs
registered_model_output = job_outputs["registered_model"]

print(f"Finetuning job's output: {registered_model_output}")

In [None]:
# Check registered model: find the model registered by *this* fine-tuning job.
# Its name starts with the `output_model_name_prefix` passed to
# create_finetuning_job, f"{model_name}-finetuned-{guid}". Matching on `model_name`
# alone could also hit the base model or a model from an earlier run, and the
# last such match would silently win.
finetuned_model_prefix = f"{model_name}-finetuned-{guid}"

matching_models = [
    model
    for model in workspace_ml_client.models.list()
    if model.name.startswith(finetuned_model_prefix)
]

if not matching_models:
    # Fail here rather than with a NameError (or a stale `registered_model` from a
    # previous run) in the deployment cells below.
    raise RuntimeError(
        f"No model whose name starts with '{finetuned_model_prefix}' is registered "
        "in the workspace."
    )

for model in matching_models:
    print(f"Registered fine-tuned model name: {model.name}")

# Keep the original "last match wins" selection among the matches.
registered_model = matching_models[-1]

### Step 5: Deploying the Fine-tuned Model to an Online Endpoint

In [None]:
# Create online endpoint with key-based auth, then block until the
# create/update operation finishes.
endpoint = ManagedOnlineEndpoint(
    name=endpoint_name,
    description=f"Online endpoint for {registered_model.name}",
    auth_mode="key",
)

endpoint_poller = workspace_ml_client.begin_create_or_update(endpoint)
endpoint_poller.wait()

In [None]:
# Check supported inference SKUs listed in the base model's
# "inference-recommended-sku" property; Jupyter displays the last expression.
inference_recommended_skus = model_to_finetune.properties["inference-recommended-sku"]
inference_recommended_skus

In [None]:
# Get deployable model: the latest version of the fine-tuned model in the workspace.
deploy_model = workspace_ml_client.models.get(
    name=registered_model.name,
    version=registered_model.latest_version,
)

print(f"Deployable model name: {deploy_model.name}")

In [None]:
# Create online deployment "finetunedmodel" on the endpoint: one instance of
# `endpoint_SKU`, a liveness probe with initial_delay=600 (presumably to give the
# model time to load before probing), and a 90000 ms request timeout.
liveness_probe_settings = ProbeSettings(initial_delay=600)
deployment_request_settings = OnlineRequestSettings(request_timeout_ms=90000)

ft_deployment = ManagedOnlineDeployment(
    name="finetunedmodel",
    endpoint_name=endpoint_name,
    model=deploy_model.id,
    instance_type=endpoint_SKU,
    instance_count=1,
    liveness_probe=liveness_probe_settings,
    request_settings=deployment_request_settings,
)

# Block until the deployment operation completes.
workspace_ml_client.online_deployments.begin_create_or_update(ft_deployment).wait()

In [None]:
# Allocate traffic to deployment: route 100% of requests to "finetunedmodel" and
# push the updated endpoint. `.result()` is the cell's last expression, so the
# updated endpoint is displayed.
endpoint.traffic = {"finetunedmodel": 100}
workspace_ml_client.begin_create_or_update(endpoint).result()

In [None]:
# Get endpoint auth key: the primary key of the endpoint created with auth_mode="key".
endpoint_keys = workspace_ml_client.online_endpoints.get_keys(endpoint_name)
auth_key = endpoint_keys.primary_key

In [None]:
# Get endpoint URL: fetch the endpoint and read its scoring URI.
my_endpoint = workspace_ml_client.online_endpoints.get(name=endpoint_name)
scoring_uri = my_endpoint.scoring_uri

print(f"Endpoint URL: {scoring_uri}")

In [None]:
# Test deployed model: send a completion request (a dialog-summary prompt) to the
# endpoint's /completions route, which is derived from the /score URI.
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {auth_key}",
}
url = scoring_uri.replace("/score", "/completions")
prompt = "Summarize the dialog.\n<dialog>: Edward: Rachel, at what time is the meeting..\r\nRachel: At 2pm..\r\nEdward: Ok, see you then\n<summary>: "
payload = {
    "prompt": prompt,
    "temperature": 0,
    "max_tokens": 200,
}

# requests applies no timeout by default, so a stalled endpoint would hang the
# cell forever. The deployment allows 90 s per request (request_timeout_ms=90000),
# so wait somewhat longer than that on the client side.
response = requests.post(url, json=payload, headers=headers, timeout=120)

# Surface HTTP errors (e.g. auth or routing failures) with the response body,
# instead of a JSON decode error here or a KeyError in the next cell.
if not response.ok:
    raise RuntimeError(
        f"Scoring request failed with HTTP {response.status_code}: {response.text}"
    )

print(f"Response: {response.json()}")

In [None]:
# Beautify response: print the prompt, the first completion's text, and the
# token usage reported by the endpoint.
structured_response = response.json()
divider = "----------------"

print(divider)
print(f"Prompt used: {prompt}\n")
print(f"Model's response: {structured_response['choices'][0]['text']}")
print(divider)

usage = structured_response["usage"]
print(f"Prompt token count: {usage['prompt_tokens']}")
print(f"Response token count: {usage['completion_tokens']}")
print(f"Total token count: {usage['total_tokens']}")