### Step 4: Distill knowledge from teacher into pruned students
In this step, we distill the depth-pruned and width-pruned models using knowledge distillation. We use [NeMo Run](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo-2.0/quickstart.html) to run the distillation recipe. For usage details of the distillation recipe, or for an alternative command-line script, refer to the [distillation docs](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/distillation/distillation.html).
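
As a rough intuition for what distillation optimizes, the sketch below computes a logit-based distillation loss for a single token position. The loss the NeMo recipe actually uses is not shown in this notebook, so this assumes a plain forward KL divergence between the teacher and student output distributions; `softmax`, `kd_loss`, and `temperature` are illustrative names, not part of the NeMo API.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(teacher_logits, student_logits, temperature=1.0):
    """Forward KL divergence KL(teacher || student) for one token position.

    Assumption: this is a generic illustration, not the recipe's exact loss.
    """
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


teacher = [2.0, 1.0, 0.1]
print(kd_loss(teacher, teacher))          # identical distributions give 0.0
print(kd_loss(teacher, [0.1, 1.0, 2.0]))  # mismatched distributions give a positive value
```

The loss is zero when the student reproduces the teacher's distribution and grows as the two diverge, which is the signal that pulls the pruned student back toward the teacher's behavior.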

Let's define the recipe setup shared by the depth-pruned and width-pruned distillation runs.

> `NOTE:` For this demonstration, the training run is capped at `STEPS` steps, and validation is carried out every `VAL_INTERVAL` steps. Adjust the recipe for your model and dataset.

In [None]:
import nemo_run as run
from nemo.collections import llm
from nemo.collections.llm.modelopt.recipes import distillation_recipe
from nemo.lightning.pytorch.strategies.utils import RestoreConfig

In [None]:
# Set path(s) if different:
ROOT_DIR = "/workspace"
TEACHER_MODEL_PATH = f"{ROOT_DIR}/Llama-3.1-8B-nemo-ft/checkpoints/best"
SEQ_LENGTH = 8192
DATA_PATH = f"{ROOT_DIR}/wikitext-data"
DATA_PATHS = {
    "train": [1.0, f"{DATA_PATH}/wikitext_tokenized_train_text_document"],
    "validation": [f"{DATA_PATH}/wikitext_tokenized_test_text_document"],
    "test": [f"{DATA_PATH}/wikitext_tokenized_val_text_document"],
}
INDEX_MAPPING_DIR = f"{DATA_PATH}/index_mappings"

# Change these to accommodate resources:
DEVICES = 8
NODES = 1
TENSOR_PARALLEL_SIZE = DEVICES
PIPELINE_PARALLEL_SIZE = NODES
MICRO_BATCH_SIZE = 4

# Change the distillation settings for your model and dataset (the values below are for demonstration purposes):
STEPS = 30
GLOBAL_BATCH_SIZE = 128
LR = 1e-4
MIN_LR = 1e-5
WARMUP_STEPS = 2
LOG_INTERVAL = 1
VAL_INTERVAL = 10
NUM_VAL_BATCHES = 5


def configure_recipe(student_model_path, exp_dir, exp_name):
    # Define the recipe
    recipe = distillation_recipe(
        student_model_path=student_model_path,
        teacher_model_path=TEACHER_MODEL_PATH,
        name=exp_name,
        num_nodes=NODES,
        num_gpus_per_node=DEVICES,
    )
    recipe.resume.restore_config = run.Config(
        RestoreConfig,
        path=student_model_path,
    )
    recipe.log.explicit_log_dir = exp_dir
    recipe.log.ckpt.every_n_train_steps = VAL_INTERVAL
    del recipe.log.ckpt.train_time_interval

    # Replace the dataset (the recipe default is the SQuAD dataset)
    recipe.data = run.Config(
        llm.PreTrainingDataModule,
        paths=DATA_PATHS,
        index_mapping_dir=INDEX_MAPPING_DIR,
        seq_length=SEQ_LENGTH,
        micro_batch_size=MICRO_BATCH_SIZE,
        global_batch_size=GLOBAL_BATCH_SIZE,
    )

    # Set the training parameters if you don't want to use the recipe defaults
    recipe.trainer.max_steps = STEPS
    recipe.trainer.log_every_n_steps = LOG_INTERVAL
    recipe.trainer.val_check_interval = VAL_INTERVAL
    recipe.trainer.limit_val_batches = NUM_VAL_BATCHES
    recipe.trainer.strategy.tensor_model_parallel_size = TENSOR_PARALLEL_SIZE
    recipe.trainer.strategy.pipeline_model_parallel_size = PIPELINE_PARALLEL_SIZE
    recipe.trainer.strategy.sequence_parallel = TENSOR_PARALLEL_SIZE > 1
    recipe.optim.config.lr = LR
    recipe.optim.lr_scheduler.warmup_steps = WARMUP_STEPS
    recipe.optim.lr_scheduler.min_lr = MIN_LR

    return recipe
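
Before launching, it can help to check that the parallelism and batch-size settings are mutually consistent. The sketch below assumes the usual Megatron-style relationship, where the data-parallel size is the world size divided by the product of the tensor- and pipeline-parallel sizes, and the global batch size must be a multiple of the micro-batch size times the data-parallel size. The helper `batch_plan` is hypothetical and not part of NeMo; the token counts are upper bounds that assume every sequence is filled to `SEQ_LENGTH`.

```python
def batch_plan(devices, nodes, tp, pp, micro_batch, global_batch, seq_length, steps):
    """Derive data-parallel size, gradient accumulation, and token budget.

    Hypothetical helper for sanity-checking settings; not a NeMo function.
    """
    world_size = devices * nodes
    model_parallel = tp * pp
    if world_size % model_parallel != 0:
        raise ValueError("world size must be divisible by tp * pp")
    data_parallel = world_size // model_parallel
    samples_per_pass = micro_batch * data_parallel
    if global_batch % samples_per_pass != 0:
        raise ValueError("global batch must be divisible by micro_batch * data_parallel")
    return {
        "data_parallel_size": data_parallel,
        "grad_accum_steps": global_batch // samples_per_pass,
        "tokens_per_step": global_batch * seq_length,
        "total_tokens": global_batch * seq_length * steps,
    }


# Values copied from the constants defined above.
plan = batch_plan(devices=8, nodes=1, tp=8, pp=1,
                  micro_batch=4, global_batch=128, seq_length=8192, steps=30)
print(plan)
# {'data_parallel_size': 1, 'grad_accum_steps': 32,
#  'tokens_per_step': 1048576, 'total_tokens': 31457280}
```

With the demonstration values, all 8 GPUs are used for tensor parallelism, so there is a single data-parallel replica that accumulates gradients over 32 micro-batches per optimizer step.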

#### Step 4a: Distilling depth-pruned student
When distilling knowledge from the teacher into the depth-pruned model, `student_model_path` is `/workspace/Llama-3.1-8B-nemo-ft-depth-pruned`, as produced by the depth-pruning step in the [pruning](./03_pruning.ipynb) notebook.

In [None]:
student_model_path = f"{ROOT_DIR}/Llama-3.1-8B-nemo-ft-depth-pruned"
exp_name = "Llama-3.1-8B-nemo-ft-depth-distilled"
exp_dir = f"{ROOT_DIR}/{exp_name}"

recipe = configure_recipe(student_model_path, exp_dir, exp_name)
print(recipe)

env_vars = {
    "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
    "NCCL_NVLS_ENABLE": "0",  # Disable NVLink SHARP to save memory
}
executor = run.LocalExecutor(ntasks_per_node=recipe.trainer.devices, launcher="torchrun", env_vars=env_vars)

In [None]:
run.run(recipe, executor=executor, name=exp_name)

This creates the final distilled model at a path like `/workspace/Llama-3.1-8B-nemo-ft-depth-distilled/checkpoints/{model_name}--{val_loss:.2f}-{step}-{consumed_samples}`; the exact path depends on your distillation run. The corresponding TensorBoard logs are saved at `/workspace/Llama-3.1-8B-nemo-ft-depth-distilled/tb_logs`.
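
Because the checkpoint name depends on the run, one way to find it programmatically is to pick the most recently modified entry under `checkpoints/`. This is a minimal sketch that assumes each checkpoint is saved as a subdirectory; `latest_checkpoint` is a hypothetical helper, and the directory names in the demonstration are placeholders rather than real NeMo checkpoint names.

```python
import os
import tempfile
from pathlib import Path


def latest_checkpoint(exp_dir):
    """Return the most recently modified subdirectory of <exp_dir>/checkpoints, or None."""
    ckpt_root = Path(exp_dir) / "checkpoints"
    if not ckpt_root.is_dir():
        return None
    candidates = [p for p in ckpt_root.iterdir() if p.is_dir()]
    return max(candidates, key=lambda p: p.stat().st_mtime, default=None)


# Demonstration on a throwaway directory with placeholder checkpoint names.
with tempfile.TemporaryDirectory() as tmp:
    for age, name in enumerate(["ckpt-step-10", "ckpt-step-20", "ckpt-step-30"]):
        path = Path(tmp) / "checkpoints" / name
        path.mkdir(parents=True)
        os.utime(path, (1_000 + age, 1_000 + age))  # force a known modification order
    print(latest_checkpoint(tmp).name)  # ckpt-step-30
```

After a real run, calling `latest_checkpoint(exp_dir)` with the `exp_dir` defined above would return the newest checkpoint directory, or `None` if the run has not written one yet.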

> `NOTE:` This run takes at least 35 minutes (depending on your GPU) to produce the final distilled model.

Here is an image of the validation loss over 30 steps of running distillation:

<img src="./imgs/val_loss_depth_pruned_student_distillation.png" width="400px" alt="Validation Loss plot when using the Depth-pruned model as the student">

#### Step 4b: Distilling width-pruned student
When distilling knowledge from the teacher into the width-pruned model, `student_model_path` is `/workspace/Llama-3.1-8B-nemo-ft-width-pruned`, as produced by the width-pruning step in the [pruning](./03_pruning.ipynb) notebook.

In [None]:
student_model_path = f"{ROOT_DIR}/Llama-3.1-8B-nemo-ft-width-pruned"
exp_name = "Llama-3.1-8B-nemo-ft-width-distilled"
exp_dir = f"{ROOT_DIR}/{exp_name}"

recipe = configure_recipe(student_model_path, exp_dir, exp_name)
print(recipe)

env_vars = {
    "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
    "NCCL_NVLS_ENABLE": "0",  # Disable NVLink SHARP to save memory
}
executor = run.LocalExecutor(ntasks_per_node=recipe.trainer.devices, launcher="torchrun", env_vars=env_vars)

In [None]:
run.run(recipe, executor=executor, name=exp_name)

This creates the final distilled model at a path like `/workspace/Llama-3.1-8B-nemo-ft-width-distilled/checkpoints/{model_name}--{val_loss:.2f}-{step}-{consumed_samples}`; the exact path depends on your distillation run. The corresponding TensorBoard logs are saved at `/workspace/Llama-3.1-8B-nemo-ft-width-distilled/tb_logs`.

> `NOTE:` This run takes at least 35 minutes (depending on your GPU) to produce the final distilled model.

Here is an image of the validation loss over 30 steps of running distillation:

<img src="./imgs/val_loss_width_pruned_student_distillation.png" width="400px" alt="Validation Loss plot when using the width-pruned model as the student">
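
Steps 4a and 4b differ only in the pruning variant embedded in the paths and experiment name. If you want to avoid repeating those strings, a small helper such as the hypothetical `experiment_paths` below can derive them. It only builds strings following the naming used in this notebook and does not call NeMo.

```python
def experiment_paths(variant, root_dir="/workspace", base_name="Llama-3.1-8B-nemo-ft"):
    """Build the student path, experiment name, and experiment directory.

    Hypothetical helper; `variant` is expected to be "depth" or "width".
    """
    exp_name = f"{base_name}-{variant}-distilled"
    return {
        "student_model_path": f"{root_dir}/{base_name}-{variant}-pruned",
        "exp_name": exp_name,
        "exp_dir": f"{root_dir}/{exp_name}",
    }


for variant in ("depth", "width"):
    print(experiment_paths(variant))
```

The returned values can be passed to `configure_recipe(student_model_path, exp_dir, exp_name)` in place of the hand-written strings.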
