# Distillation for Text Summarization with Large Language Models

### Notebook details

This sample demonstrates how to train (distill) a student model from a teacher model for the text summarization task.

We will use Meta Llama 3.1 405B Instruct as the teacher model and Meta Llama 3.1 8B Instruct as the student model.

**Note:**

- Distillation is only available in the **West US 3** region.
- Distillation inputs should be in the chat completion format:

  > {"messages": [ \
  > &nbsp;&nbsp;{"role": "system", "content": "Instructions for summarization"}, \
  > &nbsp;&nbsp;{"role": "user", "content": "Text to summarize"} \
  > ]}

- The Meta Llama 3.1 405B Instruct model can only be used as a teacher model.
- The Meta Llama 3.1 8B Instruct model can only be used as a student (target) model.
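Before uploading, it can help to check that every JSONL line matches the chat completion format described above. The following is a minimal sketch; the helper name `validate_cc_record` and the rules it enforces (exactly one system message followed by one user message, both with string content) are assumptions drawn from the example format, not an official AzureML schema check.

```python
import json


def validate_cc_record(line: str) -> bool:
    """Return True if a JSONL line looks like {"messages": [system, user]}.

    Hypothetical helper: the rules mirror the example format in the notes
    and are not an official AzureML validator.
    """
    try:
        record = json.loads(line)
    except json.JSONDecodeError:
        return False
    messages = record.get("messages") if isinstance(record, dict) else None
    if not isinstance(messages, list) or len(messages) != 2:
        return False
    # Expect a system message first, then the user message with the text to summarize
    for message, role in zip(messages, ["system", "user"]):
        if not isinstance(message, dict):
            return False
        if message.get("role") != role or not isinstance(message.get("content"), str):
            return False
    return True


good = json.dumps({"messages": [
    {"role": "system", "content": "Instructions for summarization"},
    {"role": "user", "content": "Text to summarize"},
]})
print(validate_cc_record(good))  # True
print(validate_cc_record('{"messages": []}'))  # False
```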

**Prerequisites:**

- Subscribe your project to the Meta Llama 3.1 405B Instruct and Meta Llama 3.1 8B Instruct model offerings. See [how to subscribe your project to the model offering in MS Learn](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-serverless?tabs=azure-ai-studio#subscribe-your-project-to-the-model-offering).


### Install SDK v2 and verify the imports


In [None]:
# %pip install requests
# %pip install datasets
# %pip install azure-ai-ml
# %pip install azure-identity
# %pip install tqdm


import base64
import json
import os
import uuid

import datasets
import requests
from azure.ai.ml import Input, MLClient
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.dsl import pipeline
from azure.ai.ml.entities import Data, ServerlessEndpoint
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from tqdm import tqdm

### Downloading the dataset

For this task we will use the [griffin/chain_of_density](https://huggingface.co/datasets/griffin/chain_of_density) dataset. This dataset consists of 1000 news articles, which we will use to train the student model and test the resulting endpoint.

We will begin by downloading the dataset and converting it to the chat completion format. AzureML requires both a training and a validation dataset for distillation, and we also want to hold out some samples to test the distilled model, so we split the data into three parts:

| Split      | Size |
| ---------- | ---- |
| Train      | 500  |
| Validation | 400  |
| Test       | 100  |


In [None]:
# Download the dataset
dataset = datasets.load_dataset("griffin/chain_of_density", name="unannotated")["train"]

dataset = dataset.shuffle(seed=41)

train_n = 500
valid_n = 400
test_n = 100

train_dataset = dataset.select(range(train_n))
valid_dataset = dataset.select(range(train_n, train_n + valid_n))
test_dataset = dataset.select(range(train_n + valid_n, train_n + valid_n + test_n))

# Keep assertions to guard against upstream data changes
assert len(train_dataset) == train_n
assert len(valid_dataset) == valid_n
assert len(test_dataset) == test_n
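Because the three splits are taken as consecutive index ranges from the same shuffled dataset, they cannot overlap, as long as the dataset has at least 1000 rows. A small sketch of the boundary arithmetic, using a hypothetical `split_ranges` helper that mirrors the `select(range(...))` calls in the cell above:

```python
def split_ranges(sizes):
    """Return consecutive, non-overlapping index ranges for the given split sizes."""
    ranges = []
    start = 0
    for size in sizes:
        ranges.append(range(start, start + size))
        start += size
    return ranges


train_r, valid_r, test_r = split_ranges([500, 400, 100])
print(train_r, valid_r, test_r)
# range(0, 500) range(500, 900) range(900, 1000)

# No index appears in more than one split
assert not set(train_r) & set(valid_r)
assert not set(valid_r) & set(test_r)
assert not set(train_r) & set(test_r)
```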

In [None]:
root_dir = "./data/"
os.makedirs(root_dir, exist_ok=True)

src_train_data_file_name = root_dir + "src_train_" + str(train_n) + ".txt"
src_valid_data_file_name = root_dir + "src_valid_" + str(valid_n) + ".txt"
src_test_data_file_name = root_dir + "src_test_" + str(test_n) + ".txt"


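def write_to_rawfile(data, fname):
    with open(fname, "w") as file:
        for row in data:
            # One article per line: join any embedded line breaks with spaces
            # so each article stays on a single line when read back
            file.write(" ".join(row["article"].splitlines()) + "\n")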


# Write raw data files
write_to_rawfile(train_dataset, src_train_data_file_name)
write_to_rawfile(valid_dataset, src_valid_data_file_name)
write_to_rawfile(test_dataset, src_test_data_file_name)
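The raw files store one article per line, which only round-trips correctly if no article contains a line break. The sketch below illustrates the problem with an in-memory example and shows one alternative: serializing each article as a JSON string per line, so embedded newlines are escaped. The article text is made up for illustration.

```python
import io
import json

articles = ["First paragraph.\nSecond paragraph.", "A one-line article."]

# Plain-text approach: one article per line
raw = io.StringIO()
for article in articles:
    raw.write(article + "\n")
raw_lines = raw.getvalue().splitlines()
print(len(raw_lines))  # 3: the first article was split in two

# JSON-per-line approach: json.dumps escapes "\n" inside the string
jsonl = io.StringIO()
for article in articles:
    jsonl.write(json.dumps(article) + "\n")
restored = [json.loads(line) for line in jsonl.getvalue().splitlines()]
print(restored == articles)  # True
```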

In [None]:
# Prepare data in chat completion format
cc_train_data_file_name = root_dir + "cc_train_" + str(train_n) + ".jsonl"
cc_valid_data_file_name = root_dir + "cc_valid_" + str(valid_n) + ".jsonl"
cc_test_data_file_name = root_dir + "cc_test_" + str(test_n) + ".jsonl"

SYSTEM_PROMPT = "You will generate a concise, entity-dense summary of the given article. Only generate the summary text. Do not exceed 80 words."


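def prepare_in_cc_format(filename):
    """Wrap each article (one per line) in a system + user chat completion record."""
    records = []
    with open(filename) as f:
        for line in f:
            records.append(
                {
                    "messages": [
                        {"role": "system", "content": SYSTEM_PROMPT},
                        # Strip the trailing newline so it is not sent as part of the article
                        {"role": "user", "content": f"Article: {line.strip()}"},
                    ]
                }
            )
    return records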


train_cc_data = prepare_in_cc_format(src_train_data_file_name)
valid_cc_data = prepare_in_cc_format(src_valid_data_file_name)
test_cc_data = prepare_in_cc_format(src_test_data_file_name)


def write_to_ccfile(fname, data):
    with open(fname, "w") as file:
        for row in data:
            file.write(json.dumps(row) + "\n")


write_to_ccfile(cc_train_data_file_name, train_cc_data)
write_to_ccfile(cc_valid_data_file_name, valid_cc_data)
write_to_ccfile(cc_test_data_file_name, test_cc_data)

### Creating the connection to AzureML and obtaining registries

From the [Azure portal](https://portal.azure.com/#home), obtain your subscription ID, resource group name, and workspace name. A workspace in the **West US 3** region is required. Fill in these values in the cell below; we will use them to connect to the workspace.

We will also connect to the registries that hold the Meta models (azureml-meta) and the distillation pipeline component (azureml).


In [None]:
SUBSCRIPTION_ID = "<your-subscription-id>"
RESOURCE_GROUP_NAME = "<your-resource-group-name>"
WORKSPACE_NAME = "<your-westus3-workspace-name>"

In [None]:
# Connect to Azure and get workspace and registries
try:
    credential = DefaultAzureCredential()
    # Fail early if the default credential chain cannot obtain a token
    credential.get_token("https://management.azure.com/.default")
except Exception:
    # Fall back to interactive browser sign-in
    credential = InteractiveBrowserCredential()

# Client for the workspace used for the experiments
workspace_client = MLClient(
    credential,
    subscription_id=SUBSCRIPTION_ID,
    resource_group_name=RESOURCE_GROUP_NAME,
    workspace_name=WORKSPACE_NAME,
)

# Registry which holds the Meta models
model_registry_client = MLClient(credential, registry_name="azureml-meta")

# Registry which holds the distillation component
distillation_registry_client = MLClient(credential, registry_name="maas-test-registry")

### Upload the prepared data to your workspace

We will upload the train and validation files created above to our workspace, where the distillation pipeline will consume them. If a data asset with the same name already exists in the workspace, a new version of it is created; the latest version is used by default.


In [None]:
# Upload the data to Azure workspace
train_data_name = "distillation-summ-exp-train"
valid_data_name = "distillation-summ-exp-valid"

train_data = workspace_client.data.create_or_update(
    Data(
        path=cc_train_data_file_name,
        type=AssetTypes.URI_FILE,
        description="Training dataset",
        name=train_data_name,
    )
)

valid_data = workspace_client.data.create_or_update(
    Data(
        path=cc_valid_data_file_name,
        type=AssetTypes.URI_FILE,
        description="Validation dataset",
        name=valid_data_name,
    )
)

### Chain of density

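Chain of density is a summarization prompting technique in which the model repeatedly rewrites its summary to add missing salient entities without making it longer. When it is enabled, Azure replaces the teacher's system prompt with an enhanced chain-of-density prompt. This produces higher-quality outputs from the teacher, thus improving the accuracy of the distilled model.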
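Azure's enhanced prompt is not shown in this notebook. For intuition only, the sketch below builds a prompt in the style of the original Chain of Density paper, where the model returns a JSON list of progressively denser summaries, and then keeps the last (densest) one. The prompt wording, the `Missing_Entities`/`Denser_Summary` keys, and both helper names are assumptions, not the prompt Azure uses. Producing several rewrites in one response is presumably also why the pipeline below raises `teacher_model_max_new_tokens` to 1024 when chain of density is enabled.

```python
import json


def build_cod_prompt(article: str, steps: int = 5) -> str:
    """Hypothetical Chain-of-Density-style prompt (wording is illustrative)."""
    return (
        f"Article: {article}\n\n"
        "You will generate increasingly concise, entity-dense summaries of the article above.\n"
        f"Repeat the following 2 steps {steps} times.\n"
        "Step 1. Identify 1-3 informative entities from the article that are missing "
        "from the previously generated summary.\n"
        "Step 2. Write a new, denser summary of identical length that covers every entity "
        "and detail from the previous summary plus the missing entities.\n"
        f"Answer in JSON: a list (length {steps}) of dictionaries whose keys are "
        '"Missing_Entities" and "Denser_Summary".'
    )


def extract_final_summary(model_output: str) -> str:
    """Return the last (densest) summary from a JSON list of rewrite steps."""
    steps = json.loads(model_output)
    return steps[-1]["Denser_Summary"]


# Made-up model output with two rewrite steps
fake_output = json.dumps([
    {"Missing_Entities": "NASA", "Denser_Summary": "NASA launched a probe."},
    {"Missing_Entities": "Mars; 2024", "Denser_Summary": "NASA launched a Mars probe in 2024."},
])
print(extract_final_summary(fake_output))  # NASA launched a Mars probe in 2024.
```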


In [None]:
# Set this to False to use your own system prompt for the teacher instead
ENABLE_CHAIN_OF_DENSITY = True

### Configure distillation params

The parameters most commonly changed are `teacher_model_max_new_tokens`, `teacher_model_temperature`, and `teacher_model_top_p`. Meta-Llama-3.1-8B, Meta-Llama-3.1-8B-Instruct, Meta-Llama-3.1-70B, and Meta-Llama-3.1-70B-Instruct are supported as student models.
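For intuition about the two sampling parameters, the sketch below applies temperature scaling and top-p (nucleus) filtering to a toy next-token distribution. This is the standard textbook formulation, not the serving code behind the teacher endpoint, and the logits are made up. With `teacher_model_temperature=0` (treated here as greedy decoding) and `teacher_model_top_p=1`, the teacher always picks its most likely token, which keeps its outputs as reproducible as possible.

```python
import math


def next_token_distribution(logits, temperature, top_p):
    """Apply temperature scaling and top-p filtering to a dict of token -> logit."""
    if temperature == 0:
        # Temperature 0 is conventionally treated as greedy decoding
        best = max(logits, key=logits.get)
        return {best: 1.0}
    # Softmax over temperature-scaled logits (subtract the max for numerical stability)
    scaled = {tok: logit / temperature for tok, logit in logits.items()}
    max_logit = max(scaled.values())
    exps = {tok: math.exp(v - max_logit) for tok, v in scaled.items()}
    total = sum(exps.values())
    probs = sorted(((tok, e / total) for tok, e in exps.items()),
                   key=lambda kv: kv[1], reverse=True)
    # Keep the smallest set of most likely tokens whose cumulative probability reaches top_p
    kept, cumulative = [], 0.0
    for tok, p in probs:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break
    # Renormalize over the kept tokens
    norm = sum(p for _, p in kept)
    return {tok: p / norm for tok, p in kept}


toy_logits = {"Paris": 3.0, "Lyon": 1.0, "Nice": 0.5}
print(next_token_distribution(toy_logits, temperature=0, top_p=1))  # {'Paris': 1.0}
print(next_token_distribution(toy_logits, temperature=1.0, top_p=0.5))  # {'Paris': 1.0}
```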


In [None]:
# Configure distillation teacher and student models, obtain the distillation component
TEACHER_MODEL_ENDPOINT_NAME = "Meta-Llama-3-1-405B-Instruct-efu"
STUDENT_MODEL_NAME = "Meta-Llama-3.1-8B-Instruct"
STUDENT_MODEL_VERSION = 1

student_model = model_registry_client.models.get(
    STUDENT_MODEL_NAME, STUDENT_MODEL_VERSION
)

DISTILLATION_PIPELINE_NAME = "oss_distillation_pipeline"
distillation_pipeline_component = distillation_registry_client.components.get(
    name=DISTILLATION_PIPELINE_NAME, version="0.0.6.testv7"
)


@pipeline
def distillation_pipeline(
    teacher_model_endpoint_name,
    system_properties,
    input_finetune_model,
    registered_model_name,
    train_file_data_asset,
    valid_file_data_asset,
):
    oss_distillation = distillation_pipeline_component(
        teacher_model_endpoint_name=teacher_model_endpoint_name,
        teacher_model_max_new_tokens=(1024 if ENABLE_CHAIN_OF_DENSITY else 256),
        teacher_model_temperature=0,
        teacher_model_top_p=1,
        enable_chain_of_density=ENABLE_CHAIN_OF_DENSITY,
        train_file_path=train_file_data_asset,
        validation_file_path=valid_file_data_asset,
        # Finetune
        mlflow_model_path=input_finetune_model,
        model_asset_id=student_model.id,
        registered_model_name=registered_model_name,
        system_properties=system_properties,
        # Hyperparameters
        learning_rate=2e-5,
        per_device_train_batch_size=1,
        num_train_epochs=5,
        data_generation_task_type="SUMMARIZATION",
    )

    return {"output_model": oss_distillation.outputs.output_model}

In [None]:
# Configure system properties for the job; these help make the model appear in AI Studio
system_properties = {
    "finetune_oss": "True",
    "PipelineType": "Finetune",
    "azureml.PipelineType": "Finetune",
    "model_asset_id": student_model.id,
}

json_str = json.dumps(system_properties).replace(" ", "")

system_properties_b64_encoded = base64.b64encode(json_str.encode("utf-8")).decode(
    "utf-8"
)
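The pipeline receives the system properties as a base64-encoded JSON string. One way to confirm the encoding is what you expect is to decode it back; the sketch below does this with a standalone copy of the dictionary, where the `model_asset_id` value is a made-up placeholder. It builds compact JSON with `separators=(",", ":")`; for this dictionary that is identical to removing spaces with `.replace(" ", "")`, and unlike the latter it leaves any spaces inside values intact.

```python
import base64
import json

system_properties = {
    "finetune_oss": "True",
    "PipelineType": "Finetune",
    "azureml.PipelineType": "Finetune",
    # Placeholder value for illustration only
    "model_asset_id": "azureml://registries/example/models/example-model/versions/1",
}

compact_json = json.dumps(system_properties, separators=(",", ":"))
encoded = base64.b64encode(compact_json.encode("utf-8")).decode("utf-8")

# Decoding should give back the original dictionary
decoded = json.loads(base64.b64decode(encoded).decode("utf-8"))
print(decoded == system_properties)  # True
print(compact_json == json.dumps(system_properties).replace(" ", ""))  # True
```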

In [None]:
# Configure distillation parameters
short_id = str(uuid.uuid4())[:8]
train_file_path_input = Input(type="uri_file", path=train_data.path)
valid_file_path_input = Input(type="uri_file", path=valid_data.path)
input_finetune_model = Input(type="mlflow_model", path=student_model.id)
experiment_name = "distillation-summarization"
# Do not use underscores in the model name; they are not supported
registered_model_name = "my-summ-model-" + short_id

distillation_job = distillation_pipeline(
    teacher_model_endpoint_name=TEACHER_MODEL_ENDPOINT_NAME,
    system_properties=system_properties_b64_encoded,
    input_finetune_model=input_finetune_model,
    registered_model_name=registered_model_name,
    train_file_data_asset=train_file_path_input,
    valid_file_data_asset=valid_file_path_input,
)

distillation_job.display_name = "distillation-summarization"
distillation_job.experiment_name = experiment_name
distillation_job.settings.default_compute_type = "serverless"
distillation_job.continue_on_step_failure = False
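The registered model name must be unique per run and, per the comment in the cell above, must not contain underscores. A minimal sketch of a hypothetical helper that builds a name in the same `my-summ-model-<8 hex chars>` shape; replacing dots as well as underscores with hyphens is an assumption, since the exact set of characters AzureML accepts is not specified here.

```python
import re
import uuid


def make_registered_model_name(prefix: str) -> str:
    """Hypothetical helper: unique, underscore-free model name."""
    short_id = uuid.uuid4().hex[:8]
    # Replace underscores and dots (assumed unsupported) with hyphens
    safe_prefix = re.sub(r"[_.]", "-", prefix)
    return f"{safe_prefix}-{short_id}"


name = make_registered_model_name("my_summ.model")
print(name)  # the 8-character suffix is random on every call
assert "_" not in name and "." not in name
```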

In [None]:
# Submit pipeline job to workspace
dst_job = workspace_client.jobs.create_or_update(distillation_job)
workspace_client.jobs.stream(dst_job.name)

### Create a serverless endpoint to consume the model

Our distillation job is now complete. We will deploy the distilled model as a serverless endpoint.


In [None]:
# Build the asset ID of the model registered by the distillation job.
# The version is 1 because registered_model_name is unique to this run.
rg_model_vs = 1

workspace = workspace_client.workspaces.get(WORKSPACE_NAME)
rg_model_asset_id = (
    "azureml://locations/"
    f"{workspace.location}"
    "/workspaces/"
    f"{workspace._workspace_id}"
    "/models/"
    f"{registered_model_name}"
    "/versions/"
    f"{rg_model_vs}"
)

# Create serverless endpoint - names must be unique
serverless_endpoint_name = "my-endpoint-" + short_id

serverless_endpoint = ServerlessEndpoint(
    name=serverless_endpoint_name,
    model_id=rg_model_asset_id,
)

created_endpoint = workspace_client.serverless_endpoints.begin_create_or_update(
    serverless_endpoint
).result()
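The model asset ID in the cell above is assembled by string concatenation. The same format, pulled out into a hypothetical helper and filled with made-up location, workspace ID, and model name values, makes each component easier to see:

```python
def workspace_model_asset_id(location: str, workspace_id: str,
                             model_name: str, version: int) -> str:
    """Build an azureml:// model asset ID in the same format as the cell above."""
    return (
        f"azureml://locations/{location}"
        f"/workspaces/{workspace_id}"
        f"/models/{model_name}"
        f"/versions/{version}"
    )


asset_id = workspace_model_asset_id(
    location="westus3",
    workspace_id="00000000-0000-0000-0000-000000000000",  # placeholder
    model_name="my-summ-model-1a2b3c4d",  # placeholder
    version=1,
)
print(asset_id)
# azureml://locations/westus3/workspaces/00000000-0000-0000-0000-000000000000/models/my-summ-model-1a2b3c4d/versions/1
```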

### Running inference on the test file _(optional)_

We will run the test dataset through the distilled endpoint, save the generated summaries to a file, and then evaluate them.


In [None]:
DISTILLED_ENDPOINT_URL = created_endpoint.scoring_uri + "/v1/chat/completions"
DISTILLED_ENDPOINT_KEY = workspace_client.serverless_endpoints.get_keys(
    created_endpoint.name
).primary_key


op_test_data_file_name = root_dir + "op_test_" + str(test_n) + ".jsonl"

auth_headers = {
    "Content-Type": "application/json",
    "Authorization": "Bearer " + DISTILLED_ENDPOINT_KEY,
}


def get_summary_from_llm(messages):
    """Send one chat completion request to the distilled endpoint and return the summary text."""
    data = {
        "messages": messages,
        "temperature": 0,
        "top_p": 1.0,
        "max_new_tokens": 256,
    }
    response = requests.post(
        url=DISTILLED_ENDPOINT_URL, headers=auth_headers, data=json.dumps(data)
    )

    try:
        response_dict = json.loads(response.text)
        res = response_dict["choices"][0]["message"]["content"]
    except Exception as e:
        # Log the raw response and mark the sample so it can be skipped during evaluation
        print(response.text, e)
        res = "error"
    return res


summaries = []
for item in tqdm(test_cc_data):
    summary = get_summary_from_llm(item["messages"])
    summaries.append({"article": item["messages"][1]["content"], "summary": summary})

with open(op_test_data_file_name, "w") as file:
    for summary in summaries:
        file.write(json.dumps(summary) + "\n")
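The request body and the response parsing follow the chat completion shape used by `get_summary_from_llm`. The sketch below reproduces that logic offline with a hand-written response, so the parsing can be checked without calling the endpoint. The helper names and the sample response are illustrative assumptions; a real response carries more fields.

```python
import json


def build_chat_request(messages, max_new_tokens=256):
    """Build the same JSON body that get_summary_from_llm sends."""
    return json.dumps({
        "messages": messages,
        "temperature": 0,
        "top_p": 1.0,
        "max_new_tokens": max_new_tokens,
    })


def parse_chat_response(response_text):
    """Extract the assistant message content, or "error" if the shape is unexpected."""
    try:
        return json.loads(response_text)["choices"][0]["message"]["content"]
    except (json.JSONDecodeError, KeyError, IndexError, TypeError):
        return "error"


sample_response = json.dumps({
    "choices": [{"message": {"role": "assistant", "content": "A short summary."}}]
})
print(parse_chat_response(sample_response))  # A short summary.
print(parse_chat_response('{"error": "rate limited"}'))  # error
```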

### Evaluating the model output

We will score the output with a modified version of the entity-density metric. We extract the named entities from each summary and keep only the relevant ones, that is, those that also appear in the original article. The number of relevant entities is then divided by the number of tokens in the summary, which penalizes overly verbose outputs.
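In symbols, with $E(s)$ the list of named-entity mentions spaCy finds in summary $s$ (repeated mentions counted each time), $E(a)$ the entities found in article $a$, and $T(s)$ the NLTK tokens of $s$, the per-sample score and the reported mean over the $N$ evaluated samples are:

$$
\text{density}(a, s) = \frac{\left|\{\, e \in E(s) : e \in E(a) \,\}\right|}{|T(s)|},
\qquad
\overline{\text{density}} = \frac{1}{N} \sum_{i=1}^{N} \text{density}(a_i, s_i)
$$

When `SKIP_ERROR` is set, samples whose summary is `"error"` are excluded from $N$.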


In [None]:
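# %pip install spacy
# %pip install nltk
# !python -m spacy download en_core_web_sm

import nltk
import spacy

# Tokenizer data used by nltk.word_tokenize
nltk.download("punkt")
nltk.download("punkt_tab")
# Small English pipeline used for named-entity recognition
nlp = spacy.load("en_core_web_sm")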


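def calculate_density(article, summary):
    """Return (relevant entity mentions per summary token, number of summary tokens)."""
    tokens = nltk.word_tokenize(summary)
    num_tokens = len(tokens)

    article_ents = {ent.text for ent in nlp(article).ents}
    doc_ents = [ent.text for ent in nlp(summary).ents]
    # Keep only summary entities that also appear in the article
    common_ents = [ent for ent in doc_ents if ent in article_ents]
    num_entities = len(common_ents)

    # Guard against empty summaries
    density = num_entities / num_tokens if num_tokens else 0.0
    return density, num_tokens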


SKIP_ERROR = True
density_sum = 0
token_sum = 0
sample_cnt = 0
for row in summaries:
    article = row["article"]
    summary = row["summary"]
    if summary == "error" and SKIP_ERROR:
        continue
    density, num_tokens = calculate_density(article=article, summary=summary)
    density_sum += density
    token_sum += num_tokens
    sample_cnt += 1

if sample_cnt == 0:
    raise ValueError("No successful summaries to evaluate")

average_density = round(density_sum / sample_cnt, 5)
average_token_length = round(token_sum / sample_cnt, 5)
print(
    f"The mean density is {average_density} with an average token length of {average_token_length} across {sample_cnt} samples"
)
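To see the arithmetic without spaCy or NLTK, the sketch below reimplements the scoring loop with pre-extracted entity lists and whitespace tokens. The helper name, the toy entities and summaries, and the whitespace tokenizer are assumptions for illustration only, so the numbers will not match the real metric on real text.

```python
def density_from_entities(article_ents, summary_ents, summary_tokens):
    """Count summary entity mentions that also occur in the article, per summary token."""
    if not summary_tokens:
        return 0.0
    article_set = set(article_ents)
    common = [ent for ent in summary_ents if ent in article_set]
    return len(common) / len(summary_tokens)


samples = [
    # (article entities, summary entities, summary text)
    (["NASA", "Mars", "2024"], ["NASA", "Mars", "Tuesday"], "NASA sent a probe to Mars on Tuesday ."),
    (["Apple", "Tim Cook"], ["Apple"], "Apple reported record revenue ."),
    ([], [], "error"),  # a failed request, skipped below
]

SKIP_ERROR = True
densities = []
for article_ents, summary_ents, summary in samples:
    if summary == "error" and SKIP_ERROR:
        continue
    densities.append(density_from_entities(article_ents, summary_ents, summary.split()))

print([round(d, 3) for d in densities])  # [0.222, 0.2]
print(round(sum(densities) / len(densities), 5))  # 0.21111
```

The first summary has 9 tokens and 2 relevant entities ("Tuesday" does not appear in the article), giving 2/9; the second has 5 tokens and 1 relevant entity, giving 1/5.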