### Step 3: Prune the fine-tuned teacher model to create a pruned student
In this step, we will explore two methods to prune the fine-tuned teacher model: depth pruning and width pruning. Refer to the [README.md](./README.md) to decide which pruning technique(s) you would like to explore. We use [NeMo Run](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo-2.0/quickstart.html) to run the pruning recipe. For usage details of the pruning recipe or the alternative command-line script, please refer to the [pruning docs](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/pruning/pruning.html).

Let's define the common recipe setup for depth or width pruning.

In [None]:
import nemo_run as run
from nemo.collections import llm
from nemo.collections.llm.modelopt import PruningConfig
from nemo.collections.llm.modelopt.recipes import prune_recipe

In [None]:
# Set path(s) if different:
ROOT_DIR = "/workspace"
TEACHER_MODEL_PATH = f"{ROOT_DIR}/Llama-3.1-8B-nemo-ft/checkpoints/best"
SEQ_LENGTH = 8192
DATA_PATH = f"{ROOT_DIR}/wikitext-data"
DATA_PATHS = {
    "train": [1.0, f"{DATA_PATH}/wikitext_tokenized_train_text_document"],
    "validation": [f"{DATA_PATH}/wikitext_tokenized_val_text_document"],
    "test": [f"{DATA_PATH}/wikitext_tokenized_test_text_document"],
}
INDEX_MAPPING_DIR = f"{DATA_PATH}/index_mappings"

# Change these to accommodate resources:
DEVICES = 8
NODES = 1
TENSOR_PARALLEL_SIZE = 1  # Pruning only supports a tensor parallel size of 1
PIPELINE_PARALLEL_SIZE = DEVICES
MICRO_BATCH_SIZE = 4

# Reducing this number speeds up pruning, at the risk of a slightly worse pruned model.
# Not used if pruning_config.drop_layers is set.
NUM_TRAIN_SAMPLES = 1024


def configure_recipe(pruning_config: PruningConfig, save_path: str):
    # Define the recipe
    recipe = prune_recipe(
        nemo_checkpoint=TEACHER_MODEL_PATH,
        save_path=save_path,
    )

    # Replace the default dataset (SQuAD) with the tokenized WikiText data
    recipe.data = run.Config(
        llm.PreTrainingDataModule,
        paths=DATA_PATHS,
        index_mapping_dir=INDEX_MAPPING_DIR,
        seq_length=SEQ_LENGTH,
        micro_batch_size=MICRO_BATCH_SIZE,
        global_batch_size=MICRO_BATCH_SIZE,  # Global batch size has no effect on pruning
    )

    recipe.devices = DEVICES
    recipe.num_nodes = NODES
    recipe.tp_size = TENSOR_PARALLEL_SIZE
    recipe.pp_size = PIPELINE_PARALLEL_SIZE
    recipe.legacy_ckpt = True  # For compatibility with newer versions of TransformerEngine
    recipe.num_train_samples = NUM_TRAIN_SAMPLES

    # Copy only the fields set explicitly; a None value keeps the recipe's default
    for k, v in pruning_config.__dict__.items():
        if v is not None:
            setattr(recipe.pruning_config, k, v)

    return recipe
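
With the default settings above, all 8 GPUs form a single pipeline (`PIPELINE_PARALLEL_SIZE = 8`, `TENSOR_PARALLEL_SIZE = 1`), which leaves exactly one data-parallel replica. The sketch below makes that arithmetic explicit so you can sanity-check a different `DEVICES`/`NODES` choice. The helper names are hypothetical, and the step count assumes each forward step consumes `MICRO_BATCH_SIZE` samples per data-parallel replica, which may not match exactly how the recipe iterates over `NUM_TRAIN_SAMPLES`.

```python
def data_parallel_size(devices: int, nodes: int, tp: int, pp: int) -> int:
    """Data-parallel replicas left over once TP x PP model shards are placed."""
    world_size = devices * nodes
    if world_size % (tp * pp) != 0:
        raise ValueError(f"world size {world_size} is not divisible by TP*PP={tp * pp}")
    return world_size // (tp * pp)


def forward_steps(num_samples: int, micro_batch_size: int, dp: int) -> int:
    """Ceiling of num_samples / (micro_batch_size * dp), under the assumed batching."""
    samples_per_step = micro_batch_size * dp
    return -(-num_samples // samples_per_step)


# Values mirroring DEVICES, NODES, TP/PP sizes, NUM_TRAIN_SAMPLES and MICRO_BATCH_SIZE
dp = data_parallel_size(devices=8, nodes=1, tp=1, pp=8)
steps = forward_steps(num_samples=1024, micro_batch_size=4, dp=dp)
layers_per_stage = 32 // 8  # Llama 3.1 8B's 32 layers split over 8 pipeline stages
print(dp, steps, layers_per_stage)  # 1 256 4
```

If you change `DEVICES` or `PIPELINE_PARALLEL_SIZE`, the `ValueError` flags layouts that do not divide the available GPUs evenly.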

#### Step 3a: Using depth pruning
To depth-prune, we will remove layers 16-31 (1-indexed) from the fine-tuned teacher model, keeping layers 1-15 and 32, which reduces the model from 32 to 16 layers. Per the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) and [tech report](https://arxiv.org/pdf/2408.11796), removing this contiguous block of layers (16 to 31), just before the final layer, yields the best overall results.
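
The explicit `drop_layers` list in the next cell is the same as `range(16, 32)`. This short sketch checks which layers survive, assuming the teacher has Llama 3.1 8B's 32 layers and that `drop_layers` uses 1-indexed layer numbers, as the description above (keeping 1-15 and 32) implies:

```python
NUM_TEACHER_LAYERS = 32  # Llama 3.1 8B has 32 transformer layers

# 1-indexed layers 16..31, equivalent to the explicit drop_layers list
drop_layers = list(range(16, 32))
kept_layers = [i for i in range(1, NUM_TEACHER_LAYERS + 1) if i not in drop_layers]

print(len(kept_layers))  # 16
print(kept_layers)       # layers 1 through 15, followed by 32
```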

In [None]:
pruning_config = PruningConfig(
    # To drop specific layers
    drop_layers=[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
    # To drop layers automatically based on the cosine similarity
    # target_num_layers=16,
)
save_path = f"{ROOT_DIR}/Llama-3.1-8B-nemo-ft-depth-pruned"

recipe = configure_recipe(pruning_config=pruning_config, save_path=save_path)
print(recipe)

env_vars = {
    "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
    "NCCL_NVLS_ENABLE": "0",  # Disable NVLink SHARP to save memory
}
executor = run.LocalExecutor(ntasks_per_node=recipe.devices, launcher="torchrun", env_vars=env_vars)

In [None]:
run.run(recipe, executor=executor, name="depth_pruning")

Running this cell saves the depth-pruned model to `save_path`, which with the default `ROOT_DIR` is `/workspace/Llama-3.1-8B-nemo-ft-depth-pruned`.
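
The commented-out `target_num_layers=16` option in the depth-pruning config chooses layers automatically based on cosine similarity instead of using a fixed list. The sketch below illustrates one common form of that idea: score each layer as 1 minus the mean per-token cosine similarity between its input and output hidden states, so a layer whose output barely changes direction is a good candidate to drop. It is a pure-Python illustration with hypothetical helper names and toy activations, not ModelOpt's implementation, which may collect and aggregate scores differently.

```python
import math
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]  # shape: (num_tokens, hidden_size)


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / max(norm, 1e-12)  # guard against zero vectors


def block_importance(h_in: Matrix, h_out: Matrix) -> float:
    """1 - mean per-token cosine similarity between a layer's input and output."""
    sims = [cosine(a, b) for a, b in zip(h_in, h_out)]
    return 1.0 - sum(sims) / len(sims)


def layers_to_drop(hidden_states: List[Matrix], target_num_layers: int) -> List[int]:
    """hidden_states[i] is the input to layer i + 1; the last entry is the final output.

    Returns the 1-indexed layers with the lowest importance scores.
    """
    num_layers = len(hidden_states) - 1
    scores = [block_importance(hidden_states[i], hidden_states[i + 1]) for i in range(num_layers)]
    least_important = sorted(range(num_layers), key=lambda i: scores[i])  # stable on ties
    return sorted(i + 1 for i in least_important[: num_layers - target_num_layers])


# Toy 4-layer model, 2 tokens, hidden size 2: layers 2 and 3 leave their input unchanged
h0 = [[1.0, 0.0], [0.0, 1.0]]
h1 = [[0.0, 1.0], [1.0, 0.0]]  # layer 1: each output token is orthogonal to its input
h2 = h1                        # layer 2: identity
h3 = h1                        # layer 3: identity
h4 = [[1.0, 1.0], [1.0, 1.0]]  # layer 4: each token turns by 45 degrees
print(layers_to_drop([h0, h1, h2, h3, h4], target_num_layers=2))  # [2, 3]
```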

#### Step 3b: Using width pruning
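To width-prune, we will trim `ffn_hidden_size` from 14336 to 9216 and `hidden_size` from 4096 to 3072, keeping all 32 layers. We can also trim `num_attention_heads` and `num_query_groups` if needed; leaving `target_num_attention_heads` and `target_num_query_groups` as `None` keeps them unchanged.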
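
To see roughly how much each option shrinks the model, a back-of-the-envelope parameter count helps. The sketch below assumes the standard Llama 3.1 8B architecture (32 layers, 32 attention heads, 8 query groups, head dimension 128, a SwiGLU MLP with three projection matrices, no biases, RMSNorm weights only, untied input and output embeddings, and a 128256-token vocabulary), and assumes that width pruning leaves the head dimension unchanged. The function name `llama_param_count` is hypothetical.

```python
def llama_param_count(
    num_layers: int,
    hidden_size: int,
    ffn_hidden_size: int,
    num_attention_heads: int = 32,
    num_query_groups: int = 8,
    head_dim: int = 128,
    vocab_size: int = 128256,
) -> int:
    """Approximate parameter count of a Llama-3.1-style decoder (assumptions above)."""
    q_proj = hidden_size * num_attention_heads * head_dim
    kv_proj = 2 * hidden_size * num_query_groups * head_dim
    o_proj = num_attention_heads * head_dim * hidden_size
    mlp = 3 * hidden_size * ffn_hidden_size  # gate, up, and down projections (SwiGLU)
    norms = 2 * hidden_size                  # two RMSNorm weight vectors per layer
    per_layer = q_proj + kv_proj + o_proj + mlp + norms
    embeddings = 2 * vocab_size * hidden_size  # untied input embedding and output layer
    final_norm = hidden_size
    return num_layers * per_layer + embeddings + final_norm


teacher = llama_param_count(num_layers=32, hidden_size=4096, ffn_hidden_size=14336)
depth_pruned = llama_param_count(num_layers=16, hidden_size=4096, ffn_hidden_size=14336)
width_pruned = llama_param_count(num_layers=32, hidden_size=3072, ffn_hidden_size=9216)
for name, n in [("teacher", teacher), ("depth-pruned", depth_pruned), ("width-pruned", width_pruned)]:
    print(f"{name}: {n / 1e9:.2f}B")  # 8.03B, 4.54B, 4.51B
```

Under these assumptions, both pruned students land at roughly 4.5B parameters, down from about 8.0B for the teacher.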

In [None]:
pruning_config = PruningConfig(
    target_ffn_hidden_size=9216,
    target_hidden_size=3072,
    target_num_attention_heads=None,
    target_num_query_groups=None,
)
save_path = f"{ROOT_DIR}/Llama-3.1-8B-nemo-ft-width-pruned"

recipe = configure_recipe(pruning_config=pruning_config, save_path=save_path)
print(recipe)

env_vars = {
    "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
    "NCCL_NVLS_ENABLE": "0",  # Disable NVLink SHARP to save memory
}
executor = run.LocalExecutor(ntasks_per_node=recipe.devices, launcher="torchrun", env_vars=env_vars)

In [None]:
run.run(recipe, executor=executor, name="width_pruning")

Running this cell saves the width-pruned model to `save_path`, which with the default `ROOT_DIR` is `/workspace/Llama-3.1-8B-nemo-ft-width-pruned`.

Now that we have the depth- and width-pruned models, in the next step we can distill knowledge from the fine-tuned teacher model into each of them.