# Distillation Math with Large Language Models
 
### Notebook details
 
This sample demonstrates how to distill a teacher model into a student model: the teacher model generates labels for the training data, and the student model is fine-tuned on those labels to produce the distilled model.
 
We will use Meta Llama 3.1 405B Instruct as the teacher model and Meta Llama 3.1 8B Instruct as the student model.
 
**Note:**

- The distillation offering is only available in the **West US 3** region.
- Distillation should only be used for the single-turn chat completion format.
- The Meta Llama 3.1 405B Instruct model can only be used as a teacher model.
- The Meta Llama 3.1 8B Instruct can only be used as a student (target) model.

**Prerequisites:**
- Subscribe to Meta Llama 3.1 405B Instruct and Meta Llama 3.1 8B Instruct. See [how to subscribe your project to the model offering in MS Learn](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-serverless?tabs=azure-ai-studio#subscribe-your-project-to-the-model-offering).

# 1. Connect to Azure Machine Learning Workspace

The [workspace](https://docs.microsoft.com/en-us/azure/machine-learning/concept-workspace) is the top-level resource for Azure Machine Learning, providing a centralized place to work with all the artifacts you create when you use Azure Machine Learning. In this section we will connect to the workspace in which the job will be run.

## 1.1. Install the SDK v2

In [None]:
%pip install azure-ai-ml
%pip install azure-identity

%pip install mlflow
%pip install azureml-mlflow
%pip install datasets

## 1.2. Import the required libraries

In [None]:
# import required libraries

import json

from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential

from azure.ai.ml import MLClient, Input, Output
from azure.ai.ml.constants import AssetTypes, DataGenerationTaskType, DataGenerationType
from azure.ai.ml.model_customization import distillation, EndpointRequestSettings, PromptSettings
from azure.ai.ml.entities import Data, ServerlessConnection

## 1.3. Configure workspace details and get a handle to the workspace

To connect to a workspace, we need identifier parameters - a subscription, resource group and workspace name. We will use these details in the `MLClient` from `azure.ai.ml` to get a handle to the required workspace. We use the [default azure authentication](https://docs.microsoft.com/en-us/python/api/azure-identity/azure.identity.defaultazurecredential?view=azure-python) for this tutorial. Check the [configuration notebook](../../configuration.ipynb) for more details on how to configure credentials and connect to a workspace.


### 1.3.1 Prerequisites

An AI Studio project in **West US 3** is required. Follow [this document](https://learn.microsoft.com/azure/ai-studio/how-to/fine-tune-model-llama?tabs=llama-two%2Cchatcompletion#prerequisites) to set up your AI Studio project.

### 1.3.2 AI Studio project settings

Update the following cell with the details of the AI Studio project you just created.

In [None]:
SUBSCRIPTION_ID = "<SUBSCRIPTION_ID>"
RESOURCE_GROUP = "<RESOURCE_GROUP>"
WORKSPACE_NAME = "<AML_WORKSPACE_NAME>"

In [None]:
try:
    credential = DefaultAzureCredential()
    # Check that the credential can obtain a token.
    credential.get_token("https://management.azure.com/.default")
except Exception:
    # Fall back to InteractiveBrowserCredential if DefaultAzureCredential does not work.
    credential = InteractiveBrowserCredential()

### 1.3.3 Get handle to AI Studio project

In [None]:
ml_client = MLClient(credential, SUBSCRIPTION_ID, RESOURCE_GROUP, WORKSPACE_NAME)

ai_project = ml_client._workspaces.get(ml_client.workspace_name)
ai_project._workspace_id

## 2. Data

### 2.1 Download the dataset from HuggingFace repo

For our example, we download and use the MultiArith dataset (https://huggingface.co/datasets/ChilleD/MultiArith) from HuggingFace.

In [None]:
from datasets import load_dataset

from abc import ABC


class InputDataset(ABC):
    def __init__(self):
        super().__init__()
        (
            self.train_data_file_name,
            self.test_data_file_name,
            self.eval_data_file_name,
        ) = (None, None, None)


class NLIHuggingFaceInputDataset(InputDataset):
    """
    Loads the HuggingFace dataset
    """

    def __init__(self):
        super().__init__()

    def load_hf_dataset(
        self,
        dataset_name,
        train_sample_size=10,
        val_sample_size=10,
        test_sample_size=10,
        train_split_name="train",
        val_split_name="validation",
        test_split_name="test",
    ):
        full_dataset = load_dataset(dataset_name)

        if val_split_name is not None:
            train_data = full_dataset[train_split_name].select(range(train_sample_size))
            val_data = full_dataset[val_split_name].select(range(val_sample_size))
            test_data = full_dataset[test_split_name].select(range(test_sample_size))
        else:
            train_val_data = full_dataset[train_split_name].select(
                range(train_sample_size + val_sample_size)
            )
            train_data = train_val_data.select(range(train_sample_size))
            val_data = train_val_data.select(
                range(train_sample_size, train_sample_size + val_sample_size)
            )
            test_data = full_dataset[test_split_name].select(range(test_sample_size))

        return train_data, val_data, test_data

In [None]:
# Define the train and validation sample sizes here. The dataset has no validation split,
# so we hold out 10% of the training data for validation (a 90-10 split).
# Note: for the math task, training and validation data must each contain at least 40 entries.
train_sample_size = 378
val_sample_size = 42

# Sample notebook using the dataset: https://huggingface.co/datasets/ChilleD/MultiArith
dataset_name = "ChilleD/MultiArith"
input_dataset = NLIHuggingFaceInputDataset()

# Note: train_split_name and test_split_name can vary by dataset. They are passed as arguments in load_hf_dataset.
# If val_split_name is None, the below function will split the train set to create the specified sized validation set.
train, val, _ = input_dataset.load_hf_dataset(
    dataset_name=dataset_name,
    train_sample_size=train_sample_size,
    val_sample_size=val_sample_size,
    train_split_name="train",
    val_split_name=None,
)

print("Len of train data sample is " + str(len(train)))
print("Len of validation data sample is " + str(len(val)))

In [None]:
! mkdir -p data

### 2.2 Prepare data to submit for inferencing
The data has now been downloaded and, because this dataset has no validation split, part of the training split has been set aside for validation. In this section we format the data to match what an inferencing request expects. We also add a system prompt that tells the teacher model what kind of labels to generate.
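
As a rough illustration of the target format, the sketch below builds one single-turn record (a system message followed by a user message) and checks that a JSONL file contains only records of that shape. The helpers `build_record` and `validate_jsonl` are hypothetical and not part of the Azure ML SDK, the demo questions are made up, and the 40-row minimum simply mirrors the note in this notebook about the math task.

```python
import json
import os
import tempfile


def build_record(system_prompt, question):
    """Return one single-turn chat record: a system message plus a user message."""
    return {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Question: {question}"},
        ]
    }


def validate_jsonl(path, min_rows=40):
    """Check that every line is a system+user record and return the row count."""
    count = 0
    with open(path, encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            record = json.loads(line)
            roles = [message["role"] for message in record["messages"]]
            if roles != ["system", "user"]:
                raise ValueError(f"line {line_number}: unexpected roles {roles}")
            count += 1
    if count < min_rows:
        raise ValueError(f"only {count} rows, expected at least {min_rows}")
    return count


# Usage with a throwaway file; the prompt and questions are placeholders.
demo_path = os.path.join(tempfile.mkdtemp(), "demo.jsonl")
with open(demo_path, "w", encoding="utf-8") as f:
    for i in range(40):
        record = build_record("Answer with an integer.", f"What is {i} + 1?")
        f.write(json.dumps(record) + "\n")

print(validate_jsonl(demo_path))
```

A check like this could be run on the generated training and validation files before they are uploaded.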

In [None]:
train_data_path = "data/train_multiarith_378.jsonl"
valid_data_path = "data/valid_multiarith_42.jsonl"

system_prompt = "You are an AI assistant that only provides numerical answer to the given math question. \
Do not include reasoning, calculations, answer unit, mathematical operators (+, -, *, /, =), or any other extra words \
in your response. Please ensure your response is solely an integer that answers the question. If the answer is negative, \
include the negative sign; otherwise, do not use any sign."

user_prompt_template = "Question: {question}"

def write_jsonl(rows, path):
    """Write one single-turn chat record per row, overwriting any existing file."""
    # Open once in write mode so that re-running this cell does not append duplicate rows.
    with open(path, "w") as f:
        for row in rows:
            data = {
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {
                        "role": "user",
                        "content": user_prompt_template.format(
                            question=row["question"]
                        ),
                    },
                ]
            }
            f.write(json.dumps(data) + "\n")


write_jsonl(train, train_data_path)
write_jsonl(val, valid_data_path)

### 2.3 Create Data Input

In [None]:
# Training data defined locally, with local data to be uploaded
train_data = Input(type=AssetTypes.URI_FILE, path=train_data_path)

# If training data was registered to workspace already, navigate to the Data tab, select the data to use and use the 'Named asset URI'
# train_data = "azureml:math_train_multi_arith:1"

In [None]:
# Validation data defined locally, with local data to be uploaded
valid_data = Input(type=AssetTypes.URI_FILE, path=valid_data_path)

# If validation data was registered to workspace already, navigate to the Data tab, select the data to use and use the 'Named asset URI'
# valid_data = "azureml:math_valid_multi_arith:1"

## 3. Configure and Run the Distillation Job
In this section we will configure and run a Distillation job.

### 3.1 Configure the job through the distillation() factory function

#### distillation() parameters:

The `distillation()` factory function lets you configure a Distillation job for the label generation task in the most common scenarios, using the following properties.

- `experiment_name` - The name of the Experiment. An Experiment is like a folder with multiple runs in Azure ML Workspace that should be related to the same logical machine learning experiment.
- `data_generation_type` - The type of data generation to perform. Valid options are 'label_generation'.
- `data_generation_task_type` - The kind of data to generate. Valid options include 'NLI', 'NLI_QA', 'CONVERSATION', 'MATH', and 'SUMMARIZATION'.
- `teacher_model_endpoint_connection` - A ServerlessConnection geared towards a MaaS endpoint. Requires the name of the endpoint, the endpoint URL, and the API key for the endpoint.
- `student_model` - The student model to train with the synthetic data generated from the teacher model.
- `training_data` - The data to be used for training.
- `validation_data` - The data to be used for validation.
- `name` - The name of the Job/Run. This is an optional property. If not specified, a random name will be generated.


##### Teacher Model Connection
Select the teacher model to use. This requires a MaaS endpoint.
Supported teacher models:
1. Meta-Llama-3.1-405B-Instruct


In [None]:
teacher_model_endpoint_name = "Llama-3-1-405B-Instruct-vum"
teacher_model_endpoint_url = "https://Meta-Llama-3-1-405B-Instruct-vum.westus3.models.ai.azure.com/chat/completions"
teacher_model_api_key = "EXAMPLE_API_KEY"

#### Student Model
Select the student model to use. Supported student models:
1. Meta-Llama-3.1-8B-Instruct

In [None]:
# The model id
student_model = "azureml://registries/azureml-meta/models/Meta-Llama-3.1-8B-Instruct/versions/2"

In [None]:
distillation_job = distillation(
    experiment_name="llama-distillation",
    data_generation_type=DataGenerationType.LABEL_GENERATION,
    data_generation_task_type=DataGenerationTaskType.MATH,
    teacher_model_endpoint_connection=ServerlessConnection(
        name=teacher_model_endpoint_name,
        endpoint=teacher_model_endpoint_url,
        api_key=teacher_model_api_key
    ),
    student_model=student_model,
    training_data=train_data,
    validation_data=valid_data,
    outputs={"registered_model": Output(type="mlflow_model", name="llama-distilled")}
)

### 3.2 Configure the distillation settings

#### set_teacher_model_settings() function parameters:
This optional method configures the settings applied to inference requests submitted to the teacher model endpoint.

- `inference_parameters` - Inference parameters that are applied to inferencing requests. These inference parameters are aligned with parameters allowed by vLLM. Currently, the inference parameters that are used by distillation are 'max_tokens', 'temperature', 'top_p', 'frequency_penalty', 'presence_penalty', and 'stop'.

- `endpoint_request_settings` - An EndpointRequestSettings object that adds settings for the inferencing requests sent to the endpoint. Valid endpoint settings include 'min_endpoint_success_ratio' and 'request_batch_size'.
    - `min_endpoint_success_ratio` - The minimum ratio of successful to total inferencing requests needed for data generation to be considered successful. The job will not proceed if the ratio of successful to total inferencing requests falls below this value. Should be between 0 and 1, inclusive. Defaults to 0.7.
    - `request_batch_size` - The number of inferencing requests to send at once to the teacher model endpoint. Defaults to 10.
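
To make the two endpoint settings concrete, here is a minimal sketch of how a success-ratio gate and request batching could behave. It illustrates the descriptions above and is not the service's actual implementation; `generation_succeeded` and `batches` are hypothetical helpers, and the request counts are made-up examples.

```python
def generation_succeeded(successful, total, min_endpoint_success_ratio=0.7):
    """Return True when successful/total meets the minimum ratio."""
    if not 0 <= min_endpoint_success_ratio <= 1:
        raise ValueError("min_endpoint_success_ratio must be between 0 and 1")
    if total == 0:
        # Assumption: with no requests at all, nothing was generated.
        return False
    return successful / total >= min_endpoint_success_ratio


def batches(requests, request_batch_size=10):
    """Split a list of requests into chunks of at most request_batch_size."""
    return [
        requests[i : i + request_batch_size]
        for i in range(0, len(requests), request_batch_size)
    ]


# 300 of 378 requests succeeding is about 0.79, which clears the 0.7 default.
print(generation_succeeded(300, 378))
# 250 of 378 is about 0.66, which falls below it.
print(generation_succeeded(250, 378))
# 378 requests in batches of 10 give 37 full batches plus one batch of 8.
print(len(batches(list(range(378)))))
```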


#### set_prompt_settings() function parameters:
This is an optional configuration method to configure the settings for the system prompt used for the teacher model.

- `prompt_settings` - A PromptSettings object that adds settings that determine what system prompt to use for the teacher model. Valid prompt settings for the `MATH` task include 'enable_chain_of_thought'.
    - `enable_chain_of_thought` - The option to leverage Chain of Thought (CoT) reasoning for distillation. CoT leverages step by step reasoning ability of the teacher model to generate more accurate labels.


#### set_finetuning_settings() function parameters:
This is an optional configuration method to configure the settings for finetuning the student model.

- `hyperparameters` - The hyperparameters to use for finetuning.

In [None]:
# Optional settings to use for inferencing requests
distillation_job.set_teacher_model_settings(
    inference_parameters={
        "max_tokens": 200,
        "temperature": 0.8
    },
    endpoint_request_settings=EndpointRequestSettings(
        min_endpoint_success_ratio=0.7,
        request_batch_size=10
    )
)

# Optional settings to use for the system prompt
distillation_job.set_prompt_settings(
    prompt_settings=PromptSettings(
        enable_chain_of_thought=True
    )
)

# Optional settings to use for finetuning the student model
distillation_job.set_finetuning_settings(
    hyperparameters={
        "learning_rate": "0.00002",
        "per_device_train_batch_size": "1",
        "num_train_epochs": "3"
    }
)

### 3.3 Submit the Job
Using the `MLClient` created earlier, we will now submit this job to the workspace.

In [None]:
created_job = ml_client.jobs.create_or_update(distillation_job)

#### Wait Until the Distillation Job Finishes

In [None]:
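print(created_job.name)

# Stream the job logs; this call blocks until the job reaches a terminal state.
ml_client.jobs.stream(created_job.name)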

## 4. Consuming the Distilled Model

Once the above job completes, you should be able to deploy the model and use it for inferencing. To deploy this model, do the following:

1. Go to AI Studio
2. Navigate to the Fine-tuning tab on the left menu
3. In the list of models, click on the model created by the distillation job
4. This should take you to the details page where you can see the model attributes and other details
5. Click on the Deploy button on top of the page
6. Follow the steps to deploy the model
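
Once the model is deployed, requests to it would typically mirror the single-turn format used for training. The sketch below only builds a request body and parses an integer reply; it makes no network call. `build_request` and `parse_integer_answer` are hypothetical helpers, `SYSTEM_PROMPT` is a short stand-in for the system prompt defined earlier in this notebook, and the `temperature` value is a placeholder. The endpoint URL and authentication header depend on your deployment and are not shown.

```python
import json
import re

# Stand-in for the system prompt used when preparing the training data.
SYSTEM_PROMPT = "Please ensure your response is solely an integer that answers the question."


def build_request(question, system_prompt=SYSTEM_PROMPT):
    """Build a single-turn chat completion request body."""
    return {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"Question: {question}"},
        ],
        "max_tokens": 200,  # same value as the teacher settings in this notebook
        "temperature": 0,   # placeholder value, chosen for illustration
    }


def parse_integer_answer(text):
    """Return the reply as an int if it is solely an integer, else None."""
    match = re.fullmatch(r"-?\d+", text.strip())
    return int(match.group()) if match else None


body = build_request("Tom had 5 apples and ate 2. How many are left?")
print(json.dumps(body, indent=2))
print(parse_integer_answer(" 3\n"))
```

Returning `None` for anything that is not a bare integer makes it easy to count replies that drift from the requested format.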