# MoE Supervised Fine-Tuning with LoRA

### WandB Configuration

We start by configuring [Weights & Biases (WandB)](https://wandb.ai/) for experiment tracking.  
This allows us to log key metrics and hyperparameters during training for better monitoring and reproducibility.  

In this step, we:  
- Set the WandB environment variables (API key, project, run name, and logging directory).  
- Define the main hyperparameters for **Supervised Fine-Tuning (SFT)**: learning rate, batch size, number of epochs, and maximum sequence length.  
- Specify the base **Mixture-of-Experts (MoE)** model ID from IBM Granite.

In [None]:
import os 

os.environ["WANDB_API_KEY"] = "<your_api_key_here>"
os.environ["WANDB_PROJECT"] = "blue-yonder-mle-assignment"
os.environ["WANDB_RUN_NAME"] = "granite-3.0-sft"
os.environ["WANDB_DIR"] = ".."

learning_rate = 3e-4
batch_size = 3  # Note: the SFTConfig later in this notebook sets per_device_train_batch_size=16 directly
num_train_epochs = 2
max_seq_length = 1024
base_model_id = "ibm-granite/granite-3.0-1b-a400m-base"

### Model and LoRA Configuration

Next, we load the **base IBM Granite MoE model** and prepare it for supervised fine-tuning with **LoRA (Low-Rank Adaptation)**.  

In this step, we:  
- Load the pre-trained causal language model with `transformers`.  
- Initialize the corresponding tokenizer, setting `padding_side="left"` for causal models and aligning the pad token with the EOS token.  
- Define the **LoRA configuration** to inject trainable low-rank adapters into key projection layers.  
- Wrap the model with PEFT’s `get_peft_model` to apply LoRA.  

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(base_model_id, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

lora_config = LoraConfig(
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj"],
    r=16,
    lora_alpha=32,
    bias="none",
    lora_dropout=0.05,
    task_type="CAUSAL_LM",    
)

model = get_peft_model(model, lora_config)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.26it/s]
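
To make the adapter mechanics concrete, here is a minimal NumPy sketch of a LoRA-augmented linear layer. It assumes the standard formulation `h = W x + (lora_alpha / r) * B A x`, with `B` initialized to zero. The class name `LoRALinear` and the toy dimensions are illustrative only; this is not PEFT's implementation.

```python
import numpy as np

class LoRALinear:
    """Toy LoRA layer: frozen weight W plus a trainable low-rank update B @ A."""

    def __init__(self, in_features, out_features, r=16, lora_alpha=32, seed=0):
        rng = np.random.default_rng(seed)
        # Frozen pre-trained weight (random here, purely for illustration).
        self.W = rng.normal(size=(out_features, in_features))
        # Trainable factors: A gets a small random init, B starts at zero,
        # so the adapted layer initially matches the frozen layer exactly.
        self.A = rng.normal(scale=0.01, size=(r, in_features))
        self.B = np.zeros((out_features, r))
        self.scaling = lora_alpha / r

    def forward(self, x):
        base = x @ self.W.T
        update = (x @ self.A.T) @ self.B.T
        return base + self.scaling * update

    def num_trainable(self):
        # Only A and B would receive gradients; W stays frozen.
        return self.A.size + self.B.size


layer = LoRALinear(in_features=1024, out_features=1024)
x = np.ones((2, 1024))
y = layer.forward(x)
print(y.shape, layer.num_trainable(), layer.W.size)
```

In this toy layer the adapter holds 32,768 trainable values against 1,048,576 frozen ones. Because `B` starts at zero, the wrapped layer reproduces the frozen layer's output until training updates the adapter.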


After applying **LoRA**, only a small fraction of the model’s parameters are trainable, which significantly reduces the computational and memory cost compared to full fine-tuning.  


In [3]:

# PEFT models provide a built-in summary of trainable vs. total parameters,
# so no custom helper is needed here.
model.print_trainable_parameters()

trainable params: 2,752,512 || all params: 1,337,377,792 || trainable%: 0.2058
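
As a sanity check on the figure above, the trainable parameter count can be reproduced by hand. The sketch below assumes the following model dimensions, none of which are stated in this notebook: hidden size 1024, 24 layers, 16 attention heads, and 8 key/value heads. Each LoRA pair on a linear layer of shape `(out, in)` adds `r * (in + out)` parameters.

```python
def lora_params(in_features, out_features, r):
    """A LoRA pair adds A (r x in) and B (out x r) parameters."""
    return r * (in_features + out_features)

# Assumed dimensions (not stated in this notebook).
hidden_size = 1024
num_layers = 24
num_heads = 16
num_kv_heads = 8
head_dim = hidden_size // num_heads   # 64
kv_dim = num_kv_heads * head_dim      # 512
r = 16                                # LoRA rank from lora_config

per_layer = (
    lora_params(hidden_size, hidden_size, r)    # q_proj
    + lora_params(hidden_size, kv_dim, r)       # k_proj
    + lora_params(hidden_size, kv_dim, r)       # v_proj
    + lora_params(hidden_size, hidden_size, r)  # o_proj
)
total = per_layer * num_layers
print(f"{total:,}")  # 2,752,512
```

Under these assumptions, the four attention projections alone account for the reported 2,752,512 trainable parameters. This suggests that the `gate_proj`, `up_proj`, and `down_proj` entries in `target_modules` may not match any module in this architecture. Inspecting `model.named_modules()` would confirm which layers actually received adapters.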


### SFT Chat Template with Reasoning and Answer Tags

When performing **Supervised Fine-Tuning (SFT)** on chat-style data, it is important to provide the model with a consistent structure for inputs and outputs.  
We define a **chat template** that:  

- Wraps internal reasoning inside `<think>...</think>`.  
- Places the final answer inside `<answer>...</answer>`.  

This improves the **readability of the outputs**, as reasoning and answers are clearly separated.  
It also helps the model learn a structured format that can later be used consistently during inference.  

Additionally, the chat template enables the model to behave like a **chat assistant** after SFT, handling multi-turn interactions in a natural way.  

📖 For more details, see the [Hugging Face TRL SFT Trainer documentation](https://huggingface.co/docs/trl/en/sft_trainer).  

In [4]:
reasoning_start = "<think>" 
reasoning_end   = "</think>"   
solution_start  = "<answer>"
solution_end    = "</answer>"

chat_template = \
    "{{ bos_token }}"\
    "{% for message in messages %}"\
        "{% if message['role'] == 'system' %}"\
            "{{ '<|system|>\n' + message['content'] + '\n' }}"\
        "{% elif message['role'] == 'user' %}"\
            "{{ '<|user|>\n' + message['content'] + '\n' }}"\
        "{% elif message['role'] == 'assistant' %}"\
            "{% if not loop.last %}"\
                "{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}"\
            "{% else %}"\
                "{{ '<|assistant|>\n' + message['content'] + eos_token }}"\
            "{% endif %}"\
        "{% endif %}"\
        "{% if loop.last and add_generation_prompt %}"\
            "{{ '<|assistant|>' }}{{ '{reasoning_start}' }}"\
        "{% endif %}"\
    "{% endfor %}"

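# Note: the generation prompt above emits '<|assistant|>' without the newline
# that follows it in assistant turns, so prompts and training text differ slightly.
# Replace with specific template: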
chat_template = chat_template\
    .replace("'{reasoning_start}'", f"'{reasoning_start}'")
tokenizer.chat_template = chat_template

tokenizer.apply_chat_template([
    #{"role" : "system", "content" : "You are given a math problem. You must first think step by step and then give the final answer."},
    {"role" : "user", "content" : "What is 1+1?"},
    {"role" : "assistant", "content" : f"{reasoning_start}I think it's 2.{reasoning_end}{solution_start}2{solution_end}"},
    {"role" : "user", "content" : "What is 2+2?"},
], tokenize = False, add_generation_prompt = True)

"<|endoftext|><|user|>\nWhat is 1+1?\n<|assistant|>\n<think>I think it's 2.</think><answer>2</answer><|endoftext|>\n<|user|>\nWhat is 2+2?\n<|assistant|><think>"

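Because the reasoning and the answer sit in separate tags, a completion can be split back into its two parts with a small parser. The following is a minimal sketch; the helper name `parse_completion` is hypothetical and not part of this notebook's utilities.

```python
import re

THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

def parse_completion(text):
    """Return (reasoning, answer) from the last tagged blocks; None where missing."""
    thinks = THINK_RE.findall(text)
    answers = ANSWER_RE.findall(text)
    reasoning = thinks[-1].strip() if thinks else None
    answer = answers[-1].strip() if answers else None
    return reasoning, answer

sample = "<|assistant|><think>48/2 = 24, and 48 + 24 = 72.</think><answer>72</answer>"
reasoning, answer = parse_completion(sample)
print(reasoning)
print(answer)
```

Taking the last match means that, in a multi-turn transcript, the most recent assistant turn is the one parsed.
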
We load the **GSM8K** dataset and convert each example into a `messages` list that matches our SFT chat template.  
In GSM8K, the `answer` field contains a chain-of-thought rationale followed by the final solution, with the two separated by `####`.  
We **split** on that delimiter, wrap the reasoning in `<think>...</think>` and the final answer in `<answer>...</answer>`, and store the formatted conversation as the `messages` column.


In [5]:
from datasets import load_dataset, Dataset

dataset = load_dataset("openai/gsm8k", "main", split = "train")
dataset = dataset.to_pandas()[
    ["question", "answer"]
]

def format_dataset(x):
    answer = x["answer"]
    question = x["question"]

    # Split thoughts and answer
    thoughts, final_answer = answer.split("####")

    final_answer = final_answer.strip()
    thoughts = thoughts.strip()
    
    # Add our custom formatting
    final_prompt = \
        reasoning_start + thoughts + reasoning_end + \
        solution_start + final_answer + solution_end
    return [
        {"role" : "system", "content" : "You are given a math problem. You must first think step by step and then give the final answer."},
        {"role" : "user",      "content" : question},
        {"role" : "assistant", "content" : final_prompt},
    ]

dataset["messages"] = dataset.apply(format_dataset, axis = 1)

dataset = Dataset.from_pandas(dataset)

dataset[0]["messages"]

[{'content': 'You are given a math problem. You must first think step by step and then give the final answer.',
  'role': 'system'},
 {'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
  'role': 'user'},
 {'content': '<think>Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.</think><answer>72</answer>',
  'role': 'assistant'}]
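
The same conversion can also be written as a pure function over a single example, which avoids the pandas round trip. This is a sketch of one alternative, not the approach used above. The name `to_messages` is hypothetical, and `rpartition` is used so that only the last `####` is treated as the delimiter.

```python
SYSTEM_PROMPT = (
    "You are given a math problem. You must first think step by step "
    "and then give the final answer."
)

def to_messages(example):
    """Split a GSM8K answer at the last '####' and wrap both parts in tags."""
    thoughts, sep, final_answer = example["answer"].rpartition("####")
    if not sep:
        raise ValueError("expected '####' before the final answer")
    content = (
        f"<think>{thoughts.strip()}</think>"
        f"<answer>{final_answer.strip()}</answer>"
    )
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["question"]},
            {"role": "assistant", "content": content},
        ]
    }

# Toy example in the GSM8K answer format (not taken from the dataset).
example = {
    "question": "What is 1+1?",
    "answer": "1+1 = <<1+1=2>>2\n#### 2",
}
row = to_messages(example)
print(row["messages"][-1]["content"])
```

With a loaded Hugging Face dataset, a function of this shape could be applied through `dataset.map(to_messages)`.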

### Testing the Base Model on Reasoning

Before fine-tuning, it’s useful to **observe the performance of the base model** on the reasoning task.  
Here, we feed a single example from the GSM8K dataset and prompt the model to generate its reasoning and final answer.


In [6]:
text = tokenizer.apply_chat_template(
    dataset[0]["messages"][:2],
    tokenize = False,
    add_generation_prompt = True, # Must add for generation
)

from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    max_new_tokens = 256,
    streamer = TextStreamer(tokenizer, skip_prompt = False),
)

<|endoftext|><|system|>
You are given a math problem. You must first think step by step and then give the final answer.
<|user|>
Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
<|assistant|><think>
To find the total number of clips Natalia sold, we need to calculate the number of clips she sold in April and May separately and then add them together.

In April, Natalia sold 48 clips.

In May, Natalia sold half as many clips as she did in April, which is 48/2 = 24 clips.

To find the total number of clips Natalia sold in April and May, we add the number of clips she sold in each month: 48 + 24 = 72 clips.

Therefore, Natalia sold a total of 72 clips in April and May.

<|user|>
<|assistant|><answer>
72
<|user|>
<|assistant|><solution>
72
<|user|>
<|assistant|><explanation>
Natalia sold 48 clips in April and then sold half as many clips in May, which is 48/2 = 24 clips. To find the t

### Training Setup with TRL SFTTrainer

We configure **supervised fine-tuning (SFT)** using the [TRL SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer), which handles batching, optimization, logging, and checkpointing for large language models.

Here, we define the training configuration (`SFTConfig`) with key settings such as gradient checkpointing, batch size, sequence length, number of epochs, learning rate, optimizer, and logging.  

We then initialize the trainer with our **LoRA-adapted MoE model**, the tokenizer, and the prepared dataset.

Once the `SFTTrainer` is configured, we can begin training the model. This step updates only the **LoRA adapter weights**, keeping the base model frozen, and logs training progress to WandB.


In [7]:
from trl import SFTTrainer, SFTConfig

sft_config = SFTConfig(
    gradient_checkpointing=True,   
    gradient_checkpointing_kwargs={'use_reentrant': False}, 
    gradient_accumulation_steps=1,  
    per_device_train_batch_size=16, 
    auto_find_batch_size=True,
    max_seq_length=max_seq_length,
    packing=True,
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    optim='paged_adamw_8bit',           
    logging_steps=10,
    output_dir='../checkpoints/granite-1b-a400m-blue-yonder-sft',
    report_to='wandb',
)

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=sft_config,
    train_dataset=dataset,
)

trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mfa_mekrache[0m ([33maasr[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


| Step | Training Loss |
|------|---------------|
| 10   | 0.7109        |
| 20   | 0.4129        |
| 30   | 0.3734        |
| 40   | 0.3529        |
| 50   | 0.3513        |
| 60   | 0.3497        |
| 70   | 0.354         |
| 80   | 0.3362        |
| 90   | 0.3497        |
| 100  | 0.353         |


TrainOutput(global_step=402, training_loss=0.3448613013201092, metrics={'train_runtime': 2320.7155, 'train_samples_per_second': 1.554, 'train_steps_per_second': 0.173, 'total_flos': 2.8514847100502016e+16, 'train_loss': 0.3448613013201092, 'epoch': 2.0})
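
The metrics in the `TrainOutput` above allow a rough check of how many sequences were processed per optimizer step. The sketch assumes that `train_samples_per_second` counts the packed sequences the trainer iterated over; that is an assumption about how the metric is computed, not something the notebook states.

```python
# Values copied from the TrainOutput above.
train_runtime = 2320.7155        # seconds
samples_per_second = 1.554
global_step = 402
num_train_epochs = 2

total_samples = samples_per_second * train_runtime
samples_per_epoch = total_samples / num_train_epochs
samples_per_step = total_samples / global_step

print(round(samples_per_epoch), round(samples_per_step, 2))
```

This works out to roughly 9 sequences per step, not the configured 16. That would be consistent with `auto_find_batch_size=True` having lowered the batch size after out-of-memory retries, although the notebook does not record this.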

### Testing the Fine-Tuned Model on Reasoning

After supervised fine-tuning, we test the model’s ability to generate **step-by-step reasoning** followed by the **final answer** on a GSM8K example.  

We use the same chat template with `<think>` and `<answer>` tags to format the input, and `TextStreamer` to stream the output in real time.


In [8]:
from transformers import TextStreamer

text = tokenizer.apply_chat_template(
    dataset[0]["messages"][:2],
    tokenize = False,
    add_generation_prompt = True, # Must add for generation
)

output = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    max_new_tokens = 256,
    streamer = TextStreamer(tokenizer, skip_prompt = False),
)

<|endoftext|><|system|>
You are given a math problem. You must first think step by step and then give the final answer.
<|user|>
Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
<|assistant|><think>In April, Natalia sold 48 clips.
In May, she sold half as many clips as in April, which is 48/2 = <<48/2=24>>24 clips.
In total, Natalia sold 48 + 24 = <<48+24=72>>72 clips in April and May.</think><answer>72</answer><|endoftext|>


### Evaluation on GSM8K Test Set

We evaluate the fine-tuned model on the **GSM8K test set** using a custom environment (`GSM8KEnv`) that handles problem presentation, step-by-step reasoning, and scoring.

The evaluation loop:
1. Resets the environment to get a new math problem.
2. Applies the chat template to format the input.
3. Generates reasoning and answer with the model.
4. Decodes the output and submits it to the environment to receive a reward.
5. Accumulates the rewards to compute the average weighted score.


The **reward function** considers not only whether the final answer is correct but also the **format and reasoning quality**. The reward combines three components:

1. Correctness of the final answer (70% weight)  
   - 1.0 if the predicted answer matches the gold answer exactly, else 0.0.  

2. Formatting reward (15% weight)  
   - Ensures that the model produced both reasoning (`<think>`) and an answer (`<answer>`).  
   - Encourages structured outputs for readability and consistency.  

3. Reasoning similarity (15% weight)  
   - Measures how close the model's reasoning is to the reference using **BERTScore F1**.  
   - Encourages coherent, step-by-step explanations.

The **weighted sum** of these three metrics forms the final reward for each sample.  
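
The weighting described above can be sketched as follows. The actual `GSM8KEnv` implementation is not shown in this notebook, so this is only an illustration under stated assumptions: the format component is treated as binary, and the function names are hypothetical. A simple token-overlap function stands in for BERTScore F1 so that the example runs without extra dependencies.

```python
import re

W_CORRECT, W_FORMAT, W_REASONING = 0.70, 0.15, 0.15

def extract(tag, text):
    """Return the stripped content of the last <tag>...</tag> block, or None."""
    found = re.findall(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
    return found[-1].strip() if found else None

def weighted_reward(pred, gold_answer, gold_reasoning, similarity):
    """Combine correctness, format, and reasoning similarity into one score.

    `similarity` is a hypothetical callable (candidate, reference) -> float
    in [0, 1]; the environment described above uses BERTScore F1 here.
    """
    reasoning = extract("think", pred)
    answer = extract("answer", pred)

    correct = 1.0 if answer is not None and answer == gold_answer.strip() else 0.0
    well_formed = 1.0 if reasoning is not None and answer is not None else 0.0
    sim = similarity(reasoning, gold_reasoning) if reasoning else 0.0

    return W_CORRECT * correct + W_FORMAT * well_formed + W_REASONING * sim

# Stand-in similarity for demonstration only (not BERTScore).
def token_overlap(candidate, reference):
    cand, ref = set(candidate.split()), set(reference.split())
    return len(cand & ref) / max(len(cand | ref), 1)

pred = "<think>48/2 = 24 and 48 + 24 = 72</think><answer>72</answer>"
score = weighted_reward(pred, "72", "48/2 = 24 and 48 + 24 = 72", token_overlap)
print(f"{score:.2f}")
```

A completion with the right answer, both tags, and reasoning identical to the reference scores the maximum; an untagged completion scores zero on all three components in this sketch.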


In [None]:
from tqdm import tqdm
import sys
sys.path.append("..")

from datasets import load_dataset
from utils.dataset import GSM8KEnv

eval_dataset = load_dataset("openai/gsm8k", "main", split="test")
gsm8k_eval_env = GSM8KEnv(eval_dataset, tokenizer)

gsm8k_eval_env.current_idx = 0
total_score = 0.0
model.eval()

for _ in tqdm(range(len(gsm8k_eval_env.dataset)), desc="Evaluating"):
    # Get problem from environment
    text, _ = gsm8k_eval_env.reset()

    output_ids = model.generate(
        **tokenizer(text, return_tensors="pt").to("cuda"),
        max_new_tokens=256,
    )

    # Decode the full sequence (prompt + completion) for scoring
    pred = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Gymnasium returns 5 values; only the reward is used here
    _, reward, _, _, _ = gsm8k_eval_env.step(pred)

    total_score += reward

N = len(gsm8k_eval_env.dataset)
print(f"\nResults on {N} samples:")
print(f"  Weighted score : {total_score / N:.2%}")

Evaluating: 100%|██████████| 1319/1319 [1:14:25<00:00,  3.39s/it]


Results on 1319 samples:
  Weighted score : 39.92%




