In [2]:

from datasets import load_dataset
ds = load_dataset('openai/gsm8k', 'main')
dataset = ds['train']

In [3]:
SYSTEM_PROMPT = """
You are an expert mathematical reasoning assistant.  
Your task is to solve word problems by reasoning step by step before providing the final answer.  
A part of the reasoning has been provided. You need to build on that and provide the final answer.

Follow this format:  
#### reasoning  
(Provide a detailed step-by-step solution, showing intermediate calculations.)  

#### answer  
(State the final numerical answer clearly.)  
"""  

def extract_final_answer(text):
    if "####" not in text:
        return None
    return text.split("####")[1].strip()
def get_prompts(example):
    return {
        'messages': [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"{example["question"]}\n #### Reasoning : {example['answer'].split('####')[0].strip()}"},
            {"role": "assistant", "content": f"#### Reasoning : {example['answer'].split('####')[0].strip()}\n#### Answer : {extract_final_answer(example['answer'])}"}
        ]
        }
dataset = dataset.map(get_prompts)
dataset.remove_columns(['question', 'answer'])

Dataset({
    features: ['messages'],
    num_rows: 7473
})

In [4]:
dataset[0]

{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
 'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72',
 'messages': [{'content': '\nYou are an expert mathematical reasoning assistant.  \nYour task is to solve word problems by reasoning step by step before providing the final answer.  \nA part of the reasoning has been provided. You need to build on that and provide the final answer.\n\nFollow this format:  \n#### reasoning  \n(Provide a detailed step-by-step solution, showing intermediate calculations.)  \n\n#### answer  \n(State the final numerical answer clearly.)  \n',
   'role': 'system'},
  {'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\n #### Reasoning

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
model_name = 'OpenLLM-France/Lucie-7B-Instruct'
output_dir = '/mnt/disk/sft-gsm8k'
ft_model_name = 'Lucie-7B-arithmetic'
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)


tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=128,
    lora_alpha=32,
    lora_dropout=0.2,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
)

model = get_peft_model(model, lora_config)

model.print_trainable_parameters()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

trainable params: 310,378,496 || all params: 7,017,336,832 || trainable%: 4.4230


In [6]:
trainer = SFTTrainer(
    model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=SFTConfig(output_dir=output_dir, save_steps=100000, packing=True, logging_steps=50, run_name="'sft-gsm8k", report_to=["wandb"]),
)

trainer.train()

  trainer = SFTTrainer(
[34m[1mwandb[0m: Currently logged in as: [33mhtagourti[0m ([33mhtagourti-linagora[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Step,Training Loss
50,0.6978
100,0.4533
150,0.423
200,0.4076
250,0.402
300,0.3886
350,0.3932
400,0.3853
450,0.3668
500,0.3588


TrainOutput(global_step=1233, training_loss=0.37438613925633735, metrics={'train_runtime': 2323.8956, 'train_samples_per_second': 4.237, 'train_steps_per_second': 0.531, 'total_flos': 4.083937167770911e+17, 'train_loss': 0.37438613925633735, 'epoch': 3.0})