# Learning Goals

## Optimizing Hugging Face Models with Supervised Fine-Tuning (SFT) in NeMo 2.0

NeMo 2.0 now allows users to perform SFT and PEFT with Hugging Face (HF) LLMs. This notebook demonstrates how to use SFT to make a Hugging Face LLM perform better on a specific task. NeMo 2.0 uses HF's auto classes to download and load HF transformer models. It then wraps these models as PyTorch Lightning modules so that NeMo 2.0 can run tasks such as SFT and PEFT on them.

[AutoModel](https://huggingface.co/docs/transformers/en/model_doc/auto) is a generic model class. When you create it with the `from_pretrained()` class method, it is instantiated as one of the library's concrete model classes. HF provides many AutoModel classes, and each one covers a specific group of transformer architectures. `AutoModel` itself loads only the base transformer model, which converts embeddings into hidden states. A task-specific class such as `AutoModelForCausalLM` adds a causal language modeling head on top of that base model.
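The dispatch idea behind the auto classes can be sketched as a small registry keyed on the config's `model_type`. The `Toy*` classes below are hypothetical stand-ins and are not Hugging Face's implementation. In the real library, `AutoModelForCausalLM.from_pretrained(model_name)` reads the checkpoint's config and performs a comparable lookup for you.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for architecture-specific model classes.
class ToyLlamaForCausalLM:
    def __init__(self, config):
        self.config = config


class ToyGPT2ForCausalLM:
    def __init__(self, config):
        self.config = config


@dataclass
class ToyConfig:
    model_type: str  # e.g. "llama" or "gpt2", as stored in a checkpoint's config


class ToyAutoModelForCausalLM:
    """Minimal sketch of the auto-class pattern: pick a concrete class from config.model_type."""

    _registry = {
        "llama": ToyLlamaForCausalLM,
        "gpt2": ToyGPT2ForCausalLM,
    }

    @classmethod
    def from_config(cls, config):
        try:
            model_cls = cls._registry[config.model_type]
        except KeyError:
            raise ValueError(f"Unrecognized model_type: {config.model_type!r}") from None
        return model_cls(config)


model = ToyAutoModelForCausalLM.from_config(ToyConfig(model_type="llama"))
print(type(model).__name__)  # ToyLlamaForCausalLM
```

With this design, callers never need to name the architecture-specific class. Adding support for a new architecture only requires a new registry entry.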

In this notebook, we will focus on the models that can be loaded using the HF's `AutoModelForCausalLM` class.
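Models loaded with `AutoModelForCausalLM` are trained with a next-token objective. Each position predicts the token that follows it, and the loss is the average negative log-likelihood of those targets. The pure-Python sketch below illustrates this pairing and loss using hypothetical token IDs and probabilities. In practice, the HF model computes the loss from logits and shifts the labels internally when you pass `labels`.

```python
import math


def next_token_pairs(token_ids):
    """Pair each position's token with the token that follows it (the causal LM target)."""
    return list(zip(token_ids[:-1], token_ids[1:]))


def mean_nll(target_probs):
    """Average negative log-likelihood of the probabilities assigned to the correct next tokens."""
    return sum(-math.log(p) for p in target_probs) / len(target_probs)


# Hypothetical token IDs for a five-token sequence.
tokens = [101, 7, 42, 9, 102]
pairs = next_token_pairs(tokens)
# pairs == [(101, 7), (7, 42), (42, 9), (9, 102)]

# Hypothetical probabilities the model assigns to each correct next token.
loss = mean_nll([0.5, 0.25, 0.5, 1.0])
# loss ≈ log(2) ≈ 0.693
```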

## Data
We will use the [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) dataset, a reading comprehension dataset consisting of question-answer pairs.
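To turn a SQuAD record into an SFT example, a common approach is to concatenate a prompt (context plus question) with the answer. The prompt tokens are then masked out of the loss so that the model is trained only to produce the answer. The sketch below assumes the record layout used by the Hugging Face `datasets` version of SQuAD. The prompt template and the whitespace tokenizer are hypothetical, and NeMo's `SquadDataModule` may format examples differently. Masking with `-100` relies on PyTorch's cross-entropy default `ignore_index`.

```python
IGNORE_INDEX = -100  # PyTorch's CrossEntropyLoss skips this label value by default


def format_squad_example(record):
    """Split a SQuAD-style record into (prompt, answer). The template is hypothetical."""
    prompt = f"Context: {record['context']} Question: {record['question']} Answer:"
    answer = " " + record["answers"]["text"][0]
    return prompt, answer


def make_toy_tokenizer():
    """Hypothetical whitespace tokenizer that assigns an ID to each new word."""
    vocab = {}

    def tokenize(text):
        return [vocab.setdefault(word, len(vocab)) for word in text.split()]

    return tokenize


def build_sft_example(prompt, answer, tokenize):
    """Concatenate prompt and answer; mask prompt positions so only answer tokens are scored."""
    prompt_ids = tokenize(prompt)
    answer_ids = tokenize(answer)
    input_ids = prompt_ids + answer_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + answer_ids
    return input_ids, labels


record = {
    "id": "example-0",
    "title": "Example",
    "context": "The Eiffel Tower is in Paris.",
    "question": "Where is the Eiffel Tower?",
    "answers": {"text": ["Paris"], "answer_start": [23]},
}
prompt, answer = format_squad_example(record)
input_ids, labels = build_sft_example(prompt, answer, make_toy_tokenizer())
# 14 prompt tokens are masked; only the final answer token keeps its label.
```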

## Step 1. Import Modules and Prepare the Dataset

```python
import tempfile
from functools import partial

import fiddle as fdl
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader

from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning.pytorch.callbacks import JitConfig, JitTransform
```

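We will use the `SquadDataModule` that NeMo 2.0 provides. This data module extends `FineTuningDataModule`, so it has access to the existing data-handling logic, including packed sequences. The `SquadDataModuleWithPthDataloader` class subclasses it and overrides `_create_dataloader` to return a plain PyTorch `DataLoader`.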
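Packed sequences combine several short examples into one fixed-length row, which reduces the compute spent on padding. The sketch below illustrates the idea with greedy first-fit-decreasing bin packing over sequence lengths. The algorithm choice and the `first_fit_pack` helper are assumptions for illustration, not necessarily what `FineTuningDataModule` does internally.

```python
def first_fit_pack(lengths, max_len):
    """Greedy first-fit-decreasing packing; returns bins of example indices."""
    # Place the longest examples first; this usually leaves less wasted space.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins, remaining = [], []
    for i in order:
        if lengths[i] > max_len:
            raise ValueError(f"example {i} (length {lengths[i]}) exceeds max_len={max_len}")
        # Put the example into the first bin with enough free space.
        for b, space in enumerate(remaining):
            if lengths[i] <= space:
                bins[b].append(i)
                remaining[b] -= lengths[i]
                break
        else:
            # No existing bin fits, so open a new one.
            bins.append([i])
            remaining.append(max_len - lengths[i])
    return bins


# Hypothetical tokenized lengths for six examples, packed into 512-token rows.
lengths = [300, 200, 150, 120, 100, 60]
bins = first_fit_pack(lengths, max_len=512)
# bins == [[0, 1], [2, 3, 4, 5]]: two packed rows instead of six padded ones.
```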

```python
class SquadDataModuleWithPthDataloader(llm.SquadDataModule):
    """SQuAD data module that returns a plain PyTorch DataLoader."""

    def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader:
        # Build a standard torch.utils.data.DataLoader that batches by micro batch size.
        return DataLoader(
            dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
            collate_fn=dataset.collate_fn,
            batch_size=self.micro_batch_size,
            **kwargs,
        )


def squad(tokenizer, mbs=1, gbs=2) -> pl.LightningDataModule:
    """Instantiate a SquadDataModuleWithPthDataloader and return it.

    Args:
        tokenizer (AutoTokenizer): the tokenizer to use.
        mbs (int): micro batch size.
        gbs (int): global batch size.

    Returns:
        pl.LightningDataModule: the data module to train with.
    """
    return SquadDataModuleWithPthDataloader(
        tokenizer=tokenizer,
        seq_length=512,
        micro_batch_size=mbs,
        global_batch_size=gbs,
        num_workers=0,
        dataset_kwargs={
            "sanity_check_dist_workers": False,
            "get_attention_mask_from_fusion": True,
        },
    )
```
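In Megatron-style training, which NeMo follows, the global batch size is conventionally the product of three factors: the micro batch size, the data-parallel size, and the number of gradient-accumulation micro-batches per step. The hypothetical helper below sketches that arithmetic under the assumption that data parallelism spans all `devices`. With the notebook's defaults (`mbs=1`, `gbs=devices`), it works out to one micro-batch per rank per optimizer step. That matches `accumulate_grad_batches=1` in the trainer.

```python
def grad_accumulation_steps(global_batch_size, micro_batch_size, data_parallel_size):
    """Number of micro-batches each data-parallel rank processes per optimizer step."""
    per_step = micro_batch_size * data_parallel_size
    if global_batch_size % per_step != 0:
        raise ValueError(
            "global_batch_size must be divisible by micro_batch_size * data_parallel_size"
        )
    return global_batch_size // per_step


# Notebook defaults on one GPU: gbs = devices = 1, mbs = 1.
print(grad_accumulation_steps(global_batch_size=1, micro_batch_size=1, data_parallel_size=1))  # 1
# A hypothetical larger run: gbs = 8 on 2 GPUs with mbs = 1 needs 4 accumulation steps.
print(grad_accumulation_steps(global_batch_size=8, micro_batch_size=1, data_parallel_size=2))  # 4
```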

Now we set a few variables, including the HF model name, the maximum number of training steps, and the number of GPUs.

```python
model_name = "meta-llama/Llama-3.2-1B"  # HF model name; a local path to a downloaded model also works.
strategy = "auto"  # Distributed training strategy, e.g. "auto", "ddp", or "fsdp2".
devices = 1  # Number of GPUs.
max_steps = 100  # Number of optimizer steps in the training loop.
accelerator = "gpu"
wandb_project = None  # Set to a W&B project name to enable Weights & Biases logging.
use_torch_jit = False  # Set to True to enable torch JIT compilation via JitTransform.
ckpt_folder = "/opt/checkpoints/automodel_experiments/"  # Directory for saved checkpoints.
```
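As a quick sanity check on these settings: the training cell passes `gbs=devices`, so on one GPU each optimizer step consumes a single example. The checkpoint callback uses `every_n_train_steps=max_steps // 2`. The hypothetical helper below sketches that arithmetic. It models saves as multiples of the interval and ignores any extra saves the callback might make, such as a "last" checkpoint.

```python
def training_budget(max_steps, global_batch_size, checkpoint_every):
    """Return (examples seen, steps at which interval checkpoints are saved)."""
    samples_seen = max_steps * global_batch_size
    checkpoint_steps = list(range(checkpoint_every, max_steps + 1, checkpoint_every))
    return samples_seen, checkpoint_steps


# Notebook settings: max_steps = 100, devices = 1, so gbs = 1.
samples, ckpts = training_budget(max_steps=100, global_batch_size=1, checkpoint_every=100 // 2)
print(samples, ckpts)  # 100 [50, 100]
```

This makes clear that 100 steps is a short demonstration run that sees only a small slice of the training set.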

```python
# Optional Weights & Biases logger.
wandb = None
if wandb_project is not None:
    model_tag = '_'.join(model_name.split('/')[-2:])
    wandb = WandbLogger(
        project=wandb_project,
        name=f'{model_tag}_dev{devices}_strat_{strategy}',
    )

# Optional JIT compilation callback.
callbacks = []
if use_torch_jit:
    jit_config = JitConfig(use_torch=True, torch_kwargs={'dynamic': False}, use_thunder=False)
    callbacks = [JitTransform(jit_config)]

# Save a checkpoint every max_steps // 2 steps (the midpoint and the end of training).
callbacks.append(
    nl.ModelCheckpoint(
        every_n_train_steps=max_steps // 2,
        dirpath=ckpt_folder,
    )
)

# Replace the "fsdp2" string with NeMo's FSDP2 strategy object.
if strategy == 'fsdp2':
    strategy = nl.FSDP2Strategy(data_parallel_size=devices, tensor_parallel_size=1)

if __name__ == '__main__':
    llm.api.finetune(
        model=llm.HFAutoModelForCausalLM(model_name=model_name),
        data=squad(llm.HFAutoModelForCausalLM.configure_tokenizer(model_name), gbs=devices),
        trainer=nl.Trainer(
            devices=devices,
            max_steps=max_steps,
            accelerator=accelerator,
            strategy=strategy,
            log_every_n_steps=1,
            limit_val_batches=0.0,
            num_sanity_val_steps=0,
            accumulate_grad_batches=1,
            gradient_clip_val=1.0,
            use_distributed_sampler=False,
            logger=wandb,
            callbacks=callbacks,
            precision="bf16",
        ),
        optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
        log=None,
    )
```
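The trainer's `gradient_clip_val=1.0` clips gradients by norm, which is Lightning's default clipping algorithm. The pure-Python sketch below shows the idea on a flat list of gradient values. If the combined L2 norm exceeds `max_norm`, every gradient is scaled by the same factor. PyTorch's `torch.nn.utils.clip_grad_norm_` applies this over tensors and adds a small epsilon to the denominator, so its results differ very slightly.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients so their combined L2 norm is at most max_norm.

    Returns the (possibly scaled) gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm or total_norm == 0.0:
        return list(grads), total_norm
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm


clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
# norm == 5.0; clipped ≈ [0.6, 0.8], whose norm is 1.0
```

Scaling all gradients by one shared factor preserves the direction of the update and only limits its size.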