# Distillation with Large Language Models
 
### Notebook details
 
This sample demonstrates how to train a student model with the help of a teacher model, producing a distilled model.
 
We use Meta Llama 3.1 405B Instruct as the teacher model and Meta Llama 3.1 8B Instruct as the student model.
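The managed `oss_distillation_pipeline` component used later runs distillation end to end, and its internals are not shown in this notebook. Conceptually, the teacher answers each training prompt and the student is fine-tuned on those answers. The sketch below illustrates only that assumed data flow. `add_teacher_responses` is a hypothetical helper, and the stub `teacher_fn` stands in for a real call to the teacher endpoint.

```python
from typing import Callable, Dict, List

Message = Dict[str, str]
Record = Dict[str, List[Message]]


def add_teacher_responses(
    records: List[Record], teacher_fn: Callable[[List[Message]], str]
) -> List[Record]:
    """Return copies of single-turn records with the teacher's answer appended as an assistant turn."""
    labeled = []
    for record in records:
        messages = list(record["messages"])  # shallow copy so the input records are not mutated
        answer = teacher_fn(messages)  # hypothetical stand-in for a call to the teacher endpoint
        labeled.append({"messages": messages + [{"role": "assistant", "content": answer}]})
    return labeled


# Usage with a stub teacher that always answers "neutral" (illustration only).
prompts = [
    {
        "messages": [
            {"role": "system", "content": "Answer with one NLI label."},
            {"role": "user", "content": "Premise: <premise>\nHypothesis: <hypothesis>"},
        ]
    }
]
distill_data = add_teacher_responses(prompts, lambda messages: "neutral")
print(distill_data[0]["messages"][-1])  # {'role': 'assistant', 'content': 'neutral'}
```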
 
**Note:**
 
- The distillation offering is only available in the **West US 3** region.
- Distillation should only be used with the single-turn chat completion format.
- The Meta Llama 3.1 405B Instruct model can only be used as a teacher model.
- The Meta Llama 3.1 8B Instruct model can only be used as a student (target) model.
- Distillation is currently supported only for the Natural Language Inference (NLI) task, a standard benchmarking task for Natural Language Understanding.
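Because distillation only supports the single-turn chat completion format, you may want to check each training record before uploading it. This sketch assumes "single turn" means an optional system message followed by exactly one user message, which is the shape the data-preparation cells in this notebook produce. `is_single_turn_chat` is a hypothetical helper and is not part of any Azure SDK.

```python
from typing import Any, Dict


def is_single_turn_chat(record: Dict[str, Any]) -> bool:
    """Return True if the record holds an optional system message plus exactly one user message."""
    messages = record.get("messages")
    if not isinstance(messages, list) or not messages:
        return False
    roles = [m.get("role") for m in messages if isinstance(m, dict)]
    if len(roles) != len(messages):
        return False  # at least one entry was not a message dict
    return roles in (["user"], ["system", "user"])


ok = {"messages": [{"role": "system", "content": "s"}, {"role": "user", "content": "u"}]}
multi = {
    "messages": [
        {"role": "user", "content": "a"},
        {"role": "assistant", "content": "b"},
        {"role": "user", "content": "c"},
    ]
}
print(is_single_turn_chat(ok), is_single_turn_chat(multi))  # True False
```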

**Prerequisites:**
- Subscribe your project to the Meta Llama 3.1 405B Instruct and Meta Llama 3.1 8B Instruct model offerings; see [how to subscribe your project to the model offering in MS Learn](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-serverless?tabs=azure-ai-studio#subscribe-your-project-to-the-model-offering)

## Install the SDK v2

In [None]:
%pip install azure-ai-ml
%pip install azure-identity
%pip install azure-core
%pip install azure-ai-inference

%pip install mlflow
%pip install azureml-mlflow
%pip install datasets

## Import the required libraries

In [None]:
# import required libraries

import base64
import json
import os
import uuid

from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage
from azure.ai.ml import Input, MLClient
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.dsl import pipeline
from azure.ai.ml.entities import Data, ServerlessEndpoint
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from azure.core.exceptions import ResourceNotFoundError

## Prerequisites

An AI Studio project in **West US 3** is required. Follow [this](https://learn.microsoft.com/azure/ai-studio/how-to/fine-tune-model-llama?tabs=llama-two%2Cchatcompletion#prerequisites) document to set up your AI Studio project.

## AI Studio project settings

Update the following cell with the details of the AI Studio project you just created.

In [None]:
SUBSCRIPTION_ID = "<SUBSCRIPTION_ID>"
RESOURCE_GROUP = "<RESOURCE_GROUP>"
WORKSPACE_NAME = "<AML_WORKSPACE_NAME>"

## Configure credential

We use `DefaultAzureCredential` to access the workspace. It should handle most Azure SDK authentication scenarios; if it cannot obtain a token, the cell below falls back to `InteractiveBrowserCredential`.

In [None]:
try:
    credential = DefaultAzureCredential()
    # Check if given credential can get token successfully.
    credential.get_token("https://management.azure.com/.default")
except Exception:
    # Fall back to InteractiveBrowserCredential if DefaultAzureCredential does not work
    credential = InteractiveBrowserCredential()

## Get handle to AI Studio project

In [None]:
ml_client = MLClient(credential, SUBSCRIPTION_ID, RESOURCE_GROUP, WORKSPACE_NAME)

ai_project = ml_client._workspaces.get(ml_client.workspace_name)
ai_project._workspace_id

## Pick a teacher model

We support **Meta-Llama-3.1-405B-Instruct** as the teacher model. 

In [None]:
# We will reuse or create a serverless endpoint
TEACHER_MODEL_NAME = "Meta-Llama-3.1-405B-Instruct"
TEACHER_MODEL_ENDPOINT_NAME = "Meta-Llama-3-1-405B-Instruct-vum"

mlclient_azureml_meta = MLClient(credential, registry_name="azureml-meta")
try:
    ml_client.serverless_endpoints.get(TEACHER_MODEL_ENDPOINT_NAME)
except ResourceNotFoundError:
    # create the endpoint
    teacher_model_id = (
        "azureml://registries/azureml-meta/models/Meta-Llama-3.1-405B-Instruct"
    )
    teacher_endpoint = ServerlessEndpoint(
        name=TEACHER_MODEL_ENDPOINT_NAME,
        model_id=teacher_model_id,
    )
    ml_client.serverless_endpoints.begin_create_or_update(teacher_endpoint).result()

## Pick a student model

We use **Meta-Llama-3.1-8B-Instruct** as the student model. Only chat completion models that are available for pay-as-you-go (PayGo) fine-tuning in Azure AI Studio are supported.

In [None]:
STUDENT_MODEL_NAME = "Meta-Llama-3.1-8B-Instruct"
STUDENT_MODEL_VERSION = 1

# retrieve student model from model registry
student_model = mlclient_azureml_meta.models.get(
    STUDENT_MODEL_NAME, version=STUDENT_MODEL_VERSION
)

print(
    "\n\nUsing model name: {0}, version: {1}, id: {2} for fine tuning".format(
        student_model.name, student_model.version, student_model.id
    )
)

## Download the dataset from HuggingFace repo

For this example, we download and use the ConjNLI dataset (https://huggingface.co/datasets/cestwc/conjnli) from Hugging Face.

In [None]:
from datasets import load_dataset

from abc import ABC


class InputDataset(ABC):
    def __init__(self):
        super().__init__()
        (
            self.train_data_file_name,
            self.test_data_file_name,
            self.eval_data_file_name,
        ) = (None, None, None)


class NLIHuggingFaceInputDataset(InputDataset):
    """
    Loads the HuggingFace dataset
    """

    def __init__(self):
        super().__init__()

    def load_hf_dataset(
        self,
        dataset_name,
        train_sample_size=10,
        val_sample_size=10,
        test_sample_size=10,
        train_split_name="train",
        val_split_name="validation",
        test_split_name="test",
    ):
        full_dataset = load_dataset(dataset_name)

        if val_split_name is not None:
            train_data = full_dataset[train_split_name].select(range(train_sample_size))
            val_data = full_dataset[val_split_name].select(range(val_sample_size))
            test_data = full_dataset[test_split_name].select(range(test_sample_size))
        else:
            train_val_data = full_dataset[train_split_name].select(
                range(train_sample_size + val_sample_size)
            )
            train_data = train_val_data.select(range(train_sample_size))
            val_data = train_val_data.select(
                range(train_sample_size, train_sample_size + val_sample_size)
            )
            test_data = full_dataset[test_split_name].select(range(test_sample_size))

        return train_data, val_data, test_data
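When `val_split_name` is `None`, `load_hf_dataset` takes the first `train_sample_size + val_sample_size` rows of the train split and cuts them into two contiguous, non-overlapping slices. The sketch below applies the same index arithmetic to a plain Python list. `split_train_val` is a hypothetical helper used only for illustration, and its explicit size check is an addition of this sketch.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_train_val(rows: Sequence[T], train_size: int, val_size: int) -> Tuple[List[T], List[T]]:
    """Mirror the fallback split: the first train_size rows, then the next val_size rows."""
    if train_size + val_size > len(rows):
        raise ValueError("Not enough rows for the requested train and validation sizes")
    head = list(rows[: train_size + val_size])
    return head[:train_size], head[train_size:]


train_rows, val_rows = split_train_val(list(range(10)), train_size=6, val_size=3)
print(train_rows, val_rows)  # [0, 1, 2, 3, 4, 5] [6, 7, 8]
```

Because both slices come from the same split, the validation rows are never seen during training, but they are also not drawn from a separately curated validation split.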

In [None]:
# Define the train and validation sample sizes here. The test split is not used, so its size keeps the default.
train_sample_size = 512
val_sample_size = 256

# Sample notebook using the dataset: https://huggingface.co/datasets/cestwc/conjnli
dataset_name = "cestwc/conjnli"
input_dataset = NLIHuggingFaceInputDataset()

# Note: train_split_name and test_split_name can vary by dataset. They are passed as arguments in load_hf_dataset.
# If val_split_name is None, the below function will split the train set to create the specified sized validation set.
train, val, _ = input_dataset.load_hf_dataset(
    dataset_name=dataset_name,
    train_sample_size=train_sample_size,
    val_sample_size=val_sample_size,
    train_split_name="adversarial",
    val_split_name=None,
)

print(f"Number of training samples: {len(train)}")
print(f"Number of validation samples: {len(val)}")

In [None]:
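# Create the local output folder; exist_ok avoids an error if it already exists
os.makedirs("data", exist_ok=True)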

In [None]:
train_data_path = "data/train_conjnli_512.jsonl"
valid_data_path = "data/valid_conjnli_256.jsonl"

SYSTEM_PROMPT = "You are a helpful assistant. Your output should only be one of the three labels: 'entailment', 'contradiction', or 'neutral'."
USER_PROMPT_PREFIX = "Given the following two texts, your task is to determine the logical relationship between them. The first text is the 'premise' and the second text is the 'hypothesis'. The relationship should be labeled as one of the following: 'entailment' if the premise entails the hypothesis, 'contradiction' if the premise contradicts the hypothesis, or 'neutral' if the premise neither entails nor contradicts the hypothesis.\n\nPremise: "


def write_nli_jsonl(rows, path):
    """Write each premise/hypothesis pair as a single-turn chat record, one JSON object per line."""
    with open(path, "w") as f:
        for row in rows:
            data = {
                "messages": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {
                        "role": "user",
                        "content": USER_PROMPT_PREFIX
                        + row["premise"]
                        + "\nHypothesis: "
                        + row["hypothesis"],
                    },
                ]
            }
            f.write(json.dumps(data) + "\n")


write_nli_jsonl(train, train_data_path)
write_nli_jsonl(val, valid_data_path)
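Before registering the files, a quick sanity check can confirm that every line parses as JSON and has the expected system and user messages. `count_valid_records` is a hypothetical helper that uses only the standard library. To stay self-contained, the demo writes its own small temporary file instead of reading the notebook's data.

```python
import json
import os
import tempfile


def count_valid_records(path: str) -> int:
    """Count JSONL lines that parse and contain a system message followed by a user message."""
    count = 0
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            record = json.loads(line)  # raises json.JSONDecodeError on a malformed line
            roles = [m["role"] for m in record["messages"]]
            if roles != ["system", "user"]:
                raise ValueError(f"Line {line_no}: unexpected roles {roles}")
            count += 1
    return count


# Self-contained demo on a temporary file; in the notebook, pass train_data_path instead.
sample = {"messages": [{"role": "system", "content": "s"}, {"role": "user", "content": "u"}]}
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False, encoding="utf-8") as tmp:
    for _ in range(3):
        tmp.write(json.dumps(sample) + "\n")
n = count_valid_records(tmp.name)
os.remove(tmp.name)
print(n)  # 3
```

Applied to `train_data_path` and `valid_data_path`, the counts should match `train_sample_size` and `val_sample_size`.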

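## Prepare data inputs

Register the training and validation JSONL files as data assets in the AI Studio project and build their workspace asset IDs.
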
In [None]:
train_data = None
train_data_name = "nli_train_70"

train_data = ml_client.data.create_or_update(
    Data(
        path=train_data_path,
        type=AssetTypes.URI_FILE,
        description="Training dataset",
        name=train_data_name,
    )
)

train_data_asset_id = f"azureml://locations/{ai_project.location}/workspaces/{ai_project._workspace_id}/data/{train_data.name}/versions/{train_data.version}"
train_data_asset_id

In [None]:
valid_data = None
valid_data_name = "nli_valid_70"

valid_data = ml_client.data.create_or_update(
    Data(
        path=valid_data_path,
        type=AssetTypes.URI_FILE,
        description="Validation dataset",
        name=valid_data_name,
    )
)

valid_data_asset_id = f"azureml://locations/{ai_project.location}/workspaces/{ai_project._workspace_id}/data/{valid_data.name}/versions/{valid_data.version}"
valid_data_asset_id
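The training and validation asset IDs above share one URI pattern with the registered-model ID built later in this notebook: `azureml://locations/<location>/workspaces/<workspace_id>/<kind>/<name>/versions/<version>`. A small helper can keep that formatting in one place. `workspace_asset_id` is a hypothetical name, and the location, workspace ID, and version in the example are placeholders.

```python
def workspace_asset_id(location: str, workspace_id: str, kind: str, name: str, version: str) -> str:
    """Build a workspace-scoped asset ID in the format used by this notebook.

    kind is the asset collection segment, for example "data" or "models".
    """
    return (
        f"azureml://locations/{location}/workspaces/{workspace_id}"
        f"/{kind}/{name}/versions/{version}"
    )


# Placeholder values; in the notebook use ai_project.location, ai_project._workspace_id,
# and the name/version returned by ml_client.data.create_or_update(...).
asset_id = workspace_asset_id("westus3", "<WORKSPACE_ID>", "data", "nli_train_70", "1")
print(asset_id)
```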

## Distillation strategy settings

We provide the option to use Chain of Thought (CoT) reasoning during distillation. CoT leverages the teacher model's step-by-step reasoning ability to generate more accurate labels.
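How the pipeline prompts the teacher and parses its chain-of-thought output is internal to the `oss_distillation_pipeline` component and is not shown here. A common pattern is to let the model reason step by step and then take the last label it mentions as its final answer. The `extract_final_label` helper below illustrates that idea. It is a hypothetical sketch, not the pipeline's implementation.

```python
import re
from typing import Optional

# Matches any of the three NLI labels as a whole word, case-insensitively.
_LABEL_RE = re.compile(r"\b(entailment|contradiction|neutral)\b", re.IGNORECASE)


def extract_final_label(cot_output: str) -> Optional[str]:
    """Return the last NLI label mentioned in a step-by-step answer, or None if there is none."""
    matches = _LABEL_RE.findall(cot_output)
    return matches[-1].lower() if matches else None


reasoning = (
    "Step 1: Compare the premise and the hypothesis.\n"
    "Step 2: At first it looks neutral, but the hypothesis follows directly.\n"
    "Answer: entailment"
)
print(extract_final_label(reasoning))  # entailment
```

Taking the last match rather than the first avoids picking up labels that the reasoning mentions and then rejects.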

In [None]:
ENABLE_CHAIN_OF_THOUGHT = "True"

## Configure distillation

In [None]:
mlclient_azureml = MLClient(credential, registry_name="azureml")

In [None]:
distillation_pipeline_name = "oss_distillation_pipeline"
distillation_pipeline_component = mlclient_azureml.components.get(
    name=distillation_pipeline_name
)

In [None]:
@pipeline
def distillation_pipeline(
    teacher_model_endpoint_name: str,
    enable_chain_of_thought: str,
    system_properties: str,
    input_finetune_model: Input,
    train_file_path: Input,
    validation_file_path: Input = None,
):
    oss_distillation = distillation_pipeline_component(
        teacher_model_endpoint_name=teacher_model_endpoint_name,
        enable_chain_of_thought=enable_chain_of_thought,
        train_file_path=train_file_path,
        validation_file_path=validation_file_path,
        # Finetune
        mlflow_model_path=input_finetune_model,
        model_asset_id=student_model.id,
        system_properties=system_properties,
        ## hyperparams
        learning_rate=0.00002,
        per_device_train_batch_size=1,
        num_train_epochs=3,
        data_generation_task_type="NLI",
    )

    return {"output_model": oss_distillation.outputs.output_model}

In [None]:
system_properties = {
    "finetune_oss": "True",
    "model_asset_id": student_model.id,
    "PipelineType": "Finetune",
    "azureml.PipelineType": "Finetune",
    "azureml.ModelName": student_model.name,
    "azureml.original_model_id": student_model.id,
    "azureml.trainingData.assetId": train_data_asset_id,
}

json_str = json.dumps(system_properties).replace(" ", "")

system_properties_b64_encoded = base64.b64encode(json_str.encode("utf-8")).decode(
    "utf-8"
)
print(f"System properties => {system_properties_b64_encoded}")
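The pipeline's `system_properties` input is declared as a string, so the dictionary is serialized to compact JSON and then Base64-encoded. Decoding reverses the two steps. The round trip below uses only the standard library and a shortened placeholder dictionary instead of the real student model and data asset IDs.

```python
import base64
import json

# Placeholder properties; the notebook also includes the student model ID and training data asset ID.
props = {"finetune_oss": "True", "PipelineType": "Finetune"}

# Encode as the notebook does: compact JSON -> UTF-8 bytes -> Base64 text.
encoded = base64.b64encode(json.dumps(props).replace(" ", "").encode("utf-8")).decode("utf-8")

# Decode: Base64 text -> UTF-8 bytes -> JSON -> dict.
decoded = json.loads(base64.b64decode(encoded).decode("utf-8"))
print(decoded == props)  # True
```

Note that `.replace(" ", "")` also strips spaces inside values. `json.dumps(props, separators=(",", ":"))` produces the same compact JSON without modifying the values.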

In [None]:
train_file_path_input = Input(type="uri_file", path=train_data.path)
validation_file_path_input = Input(type="uri_file", path=valid_data.path)
input_finetune_model = Input(type="mlflow_model", path=student_model.id)
experiment_name = f"distillation-{TEACHER_MODEL_NAME}".replace(".", "-")

finetuning_job = distillation_pipeline(
    teacher_model_endpoint_name=TEACHER_MODEL_ENDPOINT_NAME,
    enable_chain_of_thought=ENABLE_CHAIN_OF_THOUGHT,
    system_properties=system_properties_b64_encoded,
    input_finetune_model=input_finetune_model,
    train_file_path=train_file_path_input,
    validation_file_path=validation_file_path_input,
)

finetuning_job.properties.update(system_properties)
print(f"job property: {finetuning_job.properties}")

# Optional: run as the user identity (requires `from azure.ai.ml.entities import UserIdentityConfiguration`)
# finetuning_job.identity = UserIdentityConfiguration()
finetuning_job.display_name = f"finetune-{student_model.name}"
finetuning_job.experiment_name = experiment_name
finetuning_job.settings.default_compute_type = "serverless"
finetuning_job.continue_on_step_failure = False
# finetuning_job.settings.force_rerun = True

## Submit pipeline job

In [None]:
# Submit pipeline job to workspace
ft_job = ml_client.jobs.create_or_update(finetuning_job)
print(f"Submitted job, progress available at {ft_job.studio_url}")

## Create a serverless endpoint to consume the model (optional)

In [None]:
# Wait for the job to complete
ml_client.jobs.stream(ft_job.name)
registered_model_name = ml_client.jobs.get(ft_job.name).properties[
    "registered_ft_model_name"
]

In [None]:
# Build the asset ID of the latest version of the registered distilled model
rg_model_vs = ml_client.models.get(registered_model_name, label="latest").version

rg_model_asset_id = (
    "azureml://locations/"
    f"{ai_project.location}"
    "/workspaces/"
    f"{ai_project._workspace_id}"
    "/models/"
    f"{registered_model_name}"
    "/versions/"
    f"{rg_model_vs}"
)

In [None]:
# Create a serverless endpoint. Endpoint names must be unique, so we append the last 9 characters of the registered model name.
short_id = registered_model_name[-9:]
serverless_endpoint_name = "my-endpoint-" + short_id

serverless_endpoint = ServerlessEndpoint(
    name=serverless_endpoint_name,
    model_id=rg_model_asset_id,
)

created_endpoint = ml_client.serverless_endpoints.begin_create_or_update(
    serverless_endpoint
).result()

## Sample inference against the deployed endpoint (optional)

In [None]:
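url = created_endpoint.scoring_uri
key = ml_client.serverless_endpoints.get_keys(created_endpoint.name).primary_key
model = ChatCompletionsClient(
    endpoint=url,
    credential=AzureKeyCredential(key),
)

# Query the distilled model with the same NLI prompt format used for training,
# taking the premise/hypothesis pair from the first validation sample.
sample = val[0]
response = model.complete(
    messages=[
        SystemMessage(
            content="You are a helpful assistant. Your output should only be one of the three labels: 'entailment', 'contradiction', or 'neutral'."
        ),
        UserMessage(
            content="Given the following two texts, your task is to determine the logical relationship between them. The first text is the 'premise' and the second text is the 'hypothesis'. The relationship should be labeled as one of the following: 'entailment' if the premise entails the hypothesis, 'contradiction' if the premise contradicts the hypothesis, or 'neutral' if the premise neither entails nor contradicts the hypothesis.\n\nPremise: "
            + sample["premise"]
            + "\nHypothesis: "
            + sample["hypothesis"]
        ),
    ],
)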

print(response.choices[0].message.content)
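The distilled model is trained to reply with a single NLI label, but generated text can still carry extra whitespace, quotes, capitalization, or a trailing period. A small normalization step makes the reply easier to compare with the dataset's labels. `normalize_label` is a hypothetical helper and is not part of `azure-ai-inference`.

```python
from typing import Optional

NLI_LABELS = {"entailment", "contradiction", "neutral"}


def normalize_label(text: str) -> Optional[str]:
    """Map a raw model reply to one of the three NLI labels, or None if it is not one."""
    # Trim whitespace, then surrounding quotes and periods, then lowercase.
    cleaned = text.strip().strip("'\".").strip().lower()
    return cleaned if cleaned in NLI_LABELS else None


print(normalize_label("  'Entailment'. "))  # entailment
print(normalize_label("I think neutral"))  # None
```

In the notebook, you would pass `response.choices[0].message.content`. A `None` result flags a reply that did not follow the expected single-label format.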

## Cleanup endpoints created (optional)

Endpoint deployments are chargeable and incur costs on the subscription. Optionally, clean up the endpoints after you finish experimenting.

In [None]:
_ = ml_client.serverless_endpoints.begin_delete(TEACHER_MODEL_ENDPOINT_NAME)
_ = ml_client.serverless_endpoints.begin_delete(serverless_endpoint_name)