# Model Customization Demo (Distillation)

## Model Distillation with Amazon SageMaker Training Jobs


In [None]:
#! pip install transformers datasets "sagemaker>=2.190.0" --upgrade --quiet
#! pip install transformers boto3 "sagemaker-core==1.0.41" "datasets==4.0.0" "sagemaker>=2.190.0" --upgrade --quiet

In [None]:
import sagemaker
from datasets import load_dataset
import pandas as pd
from transformers import AutoTokenizer
import boto3
import os

# Create the SageMaker session used by the rest of the notebook and read its
# default S3 bucket name and default bucket prefix.
sagemaker_session = sagemaker.Session()
bucket_name = sagemaker_session.default_bucket()
# Checked for truthiness later when the model artifact path is built.
default_prefix = sagemaker_session.default_bucket_prefix

In [None]:
from huggingface_hub import login
# Set the Hugging Face access token here before running this cell; it is
# passed to `login` below and later copied into the training job environment.
# NOTE(review): the value is an empty placeholder - avoid committing a real token.
os.environ['hf_token']=""
login(os.environ['hf_token'])


# Weights & Biases API key; also an empty placeholder that is later copied
# into the training job environment.
os.environ['WANDB_API_KEY'] = ""

In [None]:
from datasets import load_dataset

# Load the 'mlabonne/FineTome-100k' dataset; its "train" split is uploaded to S3 further down.
dataset = load_dataset('mlabonne/FineTome-100k')

In [None]:
dataset

In [None]:
# Upload the training split to the SageMaker session's default bucket.
s3_prefix = "datasets/distllation_training_job"
input_path = f"s3://{bucket_name}/{s3_prefix}"
sft_dataset_s3_path = f"{input_path}/train/dataset.json"

# NOTE(review): an earlier comment said only 20 records would be used for the
# workshop, but the entire "train" split is written here; no subsetting is
# applied in this cell.
dataset["train"].to_json(sft_dataset_s3_path, orient="records")
# ds_train_pref["train"].to_json(f"{input_path}/pref/dataset.json", orient="records")
# perf_dataset_s3_path = f"{input_path}/pref/dataset.json"

# Link to the uploaded prefix in the S3 console for the session's region.
console_url = (
    f"https://s3.console.aws.amazon.com/s3/buckets/{bucket_name}/"
    f"?region={sagemaker_session.boto_region_name}&prefix={s3_prefix}/"
)

print("Training data uploaded to:")
print(sft_dataset_s3_path)
print(console_url)


# ModelTrainer API

In [None]:
from sagemaker.config import load_sagemaker_config

In [None]:
# Load the SageMaker SDK configuration; `configs` is not referenced again in the visible cells.
configs = load_sagemaker_config()

In [None]:
from sagemaker.modules.train import ModelTrainer
from sagemaker.modules.configs import Compute, SourceCode, InputData, StoppingCondition, CheckpointConfig

# MLflow tracking settings (the experiment name is currently not forwarded).
tracking_server_arn = "arn:aws:sagemaker:us-east-1:783764584149:mlflow-tracking-server/test"
mlflow_experiment_name = 'distillation'

# Environment variables handed to the training job: networking/NCCL settings,
# the credentials set earlier in the notebook, the dataset location on S3 and
# the MLflow tracking server ARN.
env = {
    "FI_PROVIDER": "efa",
    "NCCL_PROTO": "simple",
    "NCCL_SOCKET_IFNAME": "eth0",
    "NCCL_IB_DISABLE": "1",
    "NCCL_DEBUG": "WARN",
    "HF_token": os.environ['hf_token'],
    "WANDB_API_KEY": os.environ['WANDB_API_KEY'],
    "data_location": sft_dataset_s3_path,
    "MLFLOW_TRACKING_ARN": tracking_server_arn,
    # "MLFLOW_EXPERIMENT_NAME": mlflow_experiment_name,
}

# Compute resources requested for the training job.
compute = Compute(
    instance_type="ml.p4de.24xlarge",  # alternative: "ml.g6.48xlarge"
    instance_count=1,
    volume_size_in_gb=500,
    keep_alive_period_in_seconds=3600,
)

In [None]:
# Training container image, resolved against the region of the current session.
ecr_registry = f"658645717510.dkr.ecr.{sagemaker_session.boto_session.region_name}.amazonaws.com"
image_uri = f"{ecr_registry}/smdistributed-modelparallel:2.4.1-gpu-py311-cu121"

image_uri

In [None]:
# Base S3 location in the default bucket under which training checkpoints are stored.
checkpoint_s3_path = f"s3://{bucket_name}/distillation-checkpoints/checkpoints"
checkpoint_s3_path

In [None]:
# Base name for the training jobs launched by this notebook. SageMaker training
# job names may contain only alphanumeric characters and hyphens, so the "0.6B"
# model size is written as "0-6B" (a "." would be rejected at job creation).
job_prefix = "model-distillation-Qwen3-8B-0-6B"

In [None]:
# Hyperparameters passed to the training entry script.
hyperparameters = {
    # SageMaker mounts each input channel at /opt/ml/input/data/<channel_name>;
    # the only channel this notebook defines is named "training_dataset".
    "dataset_path": "/opt/ml/input/data/training_dataset",
    "model_dir": "/opt/ml/model",
    "MLFLOW_TRACKING_ARN": tracking_server_arn,
    #"MLFLOW_EXPERIMENT_NAME": mlflow_experiment_name
}

In [None]:
# Training code location: ./scripts is the source directory, run_kd.sh is the
# entry script and requirements.txt is passed as the requirements file.
source_code = SourceCode(
    source_dir="./scripts",
    requirements="requirements.txt",
    entry_script="run_kd.sh",
)

In [None]:
# Upper bound on the job's runtime: 90000 seconds (25 hours).
stopping_condition = StoppingCondition(max_runtime_in_seconds=90000)

# Checkpoints go to a sub-folder of the checkpoint location named after the job prefix.
checkpoint_config = CheckpointConfig(s3_uri=f"{checkpoint_s3_path}/{job_prefix}")

# Assemble the trainer from the image, compute, code, environment and
# hyperparameter settings defined in the previous cells.
model_trainer = ModelTrainer(
    base_job_name=job_prefix,
    training_image=image_uri,
    source_code=source_code,
    compute=compute,
    environment=env,
    hyperparameters=hyperparameters,
    stopping_condition=stopping_condition,
    checkpoint_config=checkpoint_config,
)

In [None]:
sft_dataset_s3_path

In [None]:
# Input channel named "training_dataset" whose data source is the dataset JSON uploaded to S3.
training_data = InputData(
    channel_name="training_dataset",
    data_source=sft_dataset_s3_path,
)

In [None]:
# Launch the training job with the single input channel; wait=True is passed,
# which presumably blocks this cell until the job finishes.
model_trainer.train(input_data_config=[training_data], wait=True)

In [None]:
def get_last_job_name(job_name_prefix):
    """Return the name of the most recently created completed job for a prefix.

    Queries the SageMaker Search API for training jobs whose name contains
    ``job_name_prefix`` and whose status is "Completed", sorted by creation
    time in descending order, and returns the first result whose name actually
    starts with the prefix. Result pages are fetched until a match is found or
    no further page is available.

    Args:
        job_name_prefix: Prefix the training job name must start with.

    Returns:
        The name of the first matching training job.

    Raises:
        ValueError: If no completed training job with that prefix is found.
    """
    client = boto3.client('sagemaker')

    # The request is the same for every page apart from the pagination token.
    request = {
        'Resource': 'TrainingJob',
        'SearchExpression': {
            'Filters': [
                {'Name': 'TrainingJobName', 'Operator': 'Contains', 'Value': job_name_prefix},
                {'Name': 'TrainingJobStatus', 'Operator': 'Equals', 'Value': "Completed"},
            ]
        },
        'SortBy': 'CreationTime',
        'SortOrder': 'Descending',
        'MaxResults': 100,
    }

    while True:
        page = client.search(**request)

        # The "Contains" filter can also match names that merely include the
        # prefix somewhere, so keep only names that really start with it.
        page_names = [hit['TrainingJob']['TrainingJobName'] for hit in page['Results']]
        candidates = [name for name in page_names if name.startswith(job_name_prefix)]
        if candidates:
            return candidates[0]

        token = page.get('NextToken')
        if not token:
            raise ValueError(f"No completed training jobs found starting with prefix '{job_name_prefix}'")
        request['NextToken'] = token

In [None]:
#job_prefix = 

In [None]:
# Look up the most recent completed training job whose name starts with the job prefix.
job_name = get_last_job_name(job_prefix)

job_name

In [None]:
# S3 URI of the model archive written by `job_name`; the default bucket prefix
# is included in the path only when the session defines one.
prefix_segment = f"{default_prefix}/" if default_prefix else ""
model_data = f"s3://{bucket_name}/{prefix_segment}{job_prefix}/{job_name}/output/model.tar.gz"


In [None]:
model_data