### Step 2: Fine-tune the teacher on the dataset

We fine-tune the unpruned model (the teacher) on our dataset to correct for the distribution shift between our dataset and the original dataset the model was trained on. We use [NeMo Run](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo-2.0/quickstart.html) to run the fine-tuning recipe.

According to the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) and [tech report](https://arxiv.org/pdf/2408.11796), experiments showed that without correcting for this distribution shift, the teacher provides suboptimal guidance on the dataset during distillation.

> `NOTE:` For this demonstration, the training run is capped at `STEPS` steps, and validation runs every `VAL_INTERVAL` steps. Adjust the fine-tuning recipe to suit your model and dataset.
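With the demonstration values set in the configuration cell of this step (`STEPS = 30`, `VAL_INTERVAL = 10`, `GLOBAL_BATCH_SIZE = 128`), a quick calculation shows how many validation passes the capped run makes and how many samples it consumes. This is a plain-Python sketch, not part of the NeMo recipe. It assumes zero-based step indices and that checkpoints are written at the same points as validation (the recipe sets `every_n_train_steps = VAL_INTERVAL`); both assumptions are consistent with the `step=29` and `consumed_samples=3840.0` in the checkpoint name at the end of this step.

```python
# Demonstration values copied from the configuration cell in this step.
STEPS = 30
GLOBAL_BATCH_SIZE = 128
VAL_INTERVAL = 10


def validation_steps(max_steps: int, val_interval: int) -> list[int]:
    """Zero-based indices of the training steps after which validation runs.

    Assumption: validation (and checkpointing, since every_n_train_steps is
    set to the same interval) happens after every `val_interval`-th step.
    """
    return [s for s in range(max_steps) if (s + 1) % val_interval == 0]


val_steps = validation_steps(STEPS, VAL_INTERVAL)
# Each optimizer step consumes one global batch of samples.
consumed_samples = STEPS * GLOBAL_BATCH_SIZE

print(f"validation/checkpoint after steps: {val_steps}")
print(f"samples consumed by the end of training: {consumed_samples}")
```

For a real run, raise `STEPS` well beyond 30; `VAL_INTERVAL` then controls how many checkpoints you get to choose from.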

In [None]:
import nemo_run as run
from nemo.collections import llm

Let's define the recipe and the executor that runs it. We use the `torchrun` launcher here, but you can also use a Slurm executor for multi-node runs.
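The resource settings in the next cell have to be consistent with the batch sizes. In Megatron-style training, the GPUs are split into tensor-, pipeline-, and data-parallel groups, and each data-parallel rank accumulates gradients over several micro-batches to reach the global batch. Below is a minimal sketch of that arithmetic, assuming a context-parallel size of 1. The helper name `batch_layout` is hypothetical and not a NeMo API.

```python
# Resource and batch settings copied from the configuration cell in this step.
DEVICES = 8
NODES = 1
TENSOR_PARALLEL_SIZE = DEVICES
PIPELINE_PARALLEL_SIZE = NODES
MICRO_BATCH_SIZE = 4
GLOBAL_BATCH_SIZE = 128


def batch_layout(devices, nodes, tp, pp, micro_batch, global_batch):
    """Hypothetical helper: return (data_parallel_size, grad_accum_steps).

    Assumes world_size = tp * pp * dp, i.e. no context parallelism.
    """
    world_size = devices * nodes
    if world_size % (tp * pp) != 0:
        raise ValueError("world size must be divisible by TP * PP")
    dp = world_size // (tp * pp)
    # Samples processed in one forward/backward pass across all data-parallel ranks.
    per_pass = micro_batch * dp
    if global_batch % per_pass != 0:
        raise ValueError("global batch must be divisible by micro batch * DP")
    return dp, global_batch // per_pass


dp, accum = batch_layout(DEVICES, NODES, TENSOR_PARALLEL_SIZE,
                         PIPELINE_PARALLEL_SIZE, MICRO_BATCH_SIZE, GLOBAL_BATCH_SIZE)
print(f"data-parallel size: {dp}, gradient-accumulation steps: {accum}")
```

With `TENSOR_PARALLEL_SIZE = DEVICES` and `PIPELINE_PARALLEL_SIZE = NODES`, the data-parallel size stays at 1 no matter how many nodes you add, so each optimizer step accumulates 32 micro-batches. A smaller tensor-parallel size (for example 4 on one 8-GPU node) would double the data-parallel size and halve the accumulation steps.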

In [None]:
# Set path(s) if different:
ROOT_DIR = "/workspace"
MODEL_PATH = f"{ROOT_DIR}/Llama-3.1-8B-nemo"
SEQ_LENGTH = 8192
EXP_NAME = "Llama-3.1-8B-nemo-ft"
EXP_DIR = f"{ROOT_DIR}/{EXP_NAME}"
DATA_PATH = f"{ROOT_DIR}/wikitext-data"
DATA_PATHS = {
    "train": [1.0, f"{DATA_PATH}/wikitext_tokenized_train_text_document"],
    "validation": [f"{DATA_PATH}/wikitext_tokenized_val_text_document"],
    "test": [f"{DATA_PATH}/wikitext_tokenized_test_text_document"],
}
INDEX_MAPPING_DIR = f"{DATA_PATH}/index_mappings"

# Change these to accommodate resources:
DEVICES = 8
NODES = 1
TENSOR_PARALLEL_SIZE = DEVICES
PIPELINE_PARALLEL_SIZE = NODES
MICRO_BATCH_SIZE = 4

# Change the fine-tuning recipe for your model and dataset (below values for demonstration purposes):
STEPS = 30
GLOBAL_BATCH_SIZE = 128
LR = 1e-4
MIN_LR = 1e-5
WARMUP_STEPS = 2
LOG_INTERVAL = 1
VAL_INTERVAL = 10
NUM_VAL_BATCHES = 5


def configure_recipe():
    # Define the recipe
    recipe = llm.llama31_8b.finetune_recipe(
        num_nodes=NODES,
        num_gpus_per_node=DEVICES,
        peft_scheme=None,  # Full finetuning
        seq_length=SEQ_LENGTH,
    )
    recipe.resume.restore_config.path = MODEL_PATH
    recipe.log.explicit_log_dir = EXP_DIR
    recipe.log.ckpt.every_n_train_steps = VAL_INTERVAL

    # Replace the dataset (the recipe default is SQuAD)
    recipe.data = run.Config(
        llm.PreTrainingDataModule,
        paths=DATA_PATHS,
        index_mapping_dir=INDEX_MAPPING_DIR,
        seq_length=SEQ_LENGTH,
        micro_batch_size=MICRO_BATCH_SIZE,
        global_batch_size=GLOBAL_BATCH_SIZE,
    )

    # Set the training parameters if you don't want to use the recipe defaults
    recipe.trainer.max_steps = STEPS
    recipe.trainer.log_every_n_steps = LOG_INTERVAL
    recipe.trainer.val_check_interval = VAL_INTERVAL
    recipe.trainer.limit_val_batches = NUM_VAL_BATCHES
    recipe.trainer.strategy.tensor_model_parallel_size = TENSOR_PARALLEL_SIZE
    recipe.trainer.strategy.pipeline_model_parallel_size = PIPELINE_PARALLEL_SIZE
    recipe.trainer.strategy.sequence_parallel = TENSOR_PARALLEL_SIZE > 1
    recipe.optim.config.lr = LR
    recipe.optim.lr_scheduler.warmup_steps = WARMUP_STEPS
    recipe.optim.lr_scheduler.min_lr = MIN_LR

    return recipe


recipe = configure_recipe()
print(recipe)
env_vars = {
    "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",  # Disable caching NCCL communication buffer memory
    "NCCL_NVLS_ENABLE": "0",  # Disable NVLink SHARP to save memory
}
executor = run.LocalExecutor(ntasks_per_node=recipe.trainer.devices, launcher="torchrun", env_vars=env_vars)

Let's run the recipe. This is expected to take at least 20 minutes on 8x 80GB H100 GPUs (the time varies with GPU type and recipe settings).

In [None]:
run.run(recipe, executor=executor, name=EXP_NAME)

This saves the top-k fine-tuned teacher checkpoints to `/workspace/Llama-3.1-8B-nemo-ft/checkpoints/{model_name}--{val_loss:.2f}-{step}-{consumed_samples}`. Let's rename the one with the lowest `val_loss` to `/workspace/Llama-3.1-8B-nemo-ft/checkpoints/best` so it's easier to find; we'll use it later.
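Instead of copying the directory name by hand, you could select the checkpoint with the lowest `val_loss` programmatically. This is a minimal sketch that assumes the directory names contain `val_loss=<number>` as in the example path used in this step. The function name `find_best_checkpoint` is hypothetical, and the tie-break against a `-last` suffix is a precaution in case a duplicate "last" copy is saved alongside the top-k checkpoints.

```python
import re
from pathlib import Path

# Matches e.g. "model_name=0--val_loss=2.03-step=29-consumed_samples=3840.0"
VAL_LOSS_RE = re.compile(r"val_loss=(\d+(?:\.\d+)?)")


def find_best_checkpoint(ckpt_dir: str) -> Path:
    """Hypothetical helper: return the checkpoint directory with the lowest val_loss."""
    candidates = []
    for path in Path(ckpt_dir).iterdir():
        match = VAL_LOSS_RE.search(path.name)
        if path.is_dir() and match:
            # Sort key: lowest loss first; on a tie, prefer a name without "-last".
            candidates.append((float(match.group(1)), path.name.endswith("-last"), path))
    if not candidates:
        raise FileNotFoundError(f"no val_loss checkpoints found in {ckpt_dir}")
    return min(candidates)[2]
```

In the notebook, `find_best_checkpoint(f"{EXP_DIR}/checkpoints")` returns the source path for the rename, which you can then perform with `Path.rename` or the `mv` command in the next cell.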

In [None]:
# NOTE: Replace the source directory name with the lowest-val_loss checkpoint from your own run
!mv "{EXP_DIR}/checkpoints/model_name=0--val_loss=2.03-step=29-consumed_samples=3840.0" "{EXP_DIR}/checkpoints/best"