In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from transformers import TrainingArguments, Trainer
from trl import SFTTrainer


# Load GSM8K ("main" configuration) with its predefined train/test splits.
dataset = load_dataset(path="gsm8k", name="main")

# Define a function to format the prompt for training
def format_prompt(examples):
    # Format each example as a prompt-response pair
    # The model will be trained to generate the 'answer' given the 'question'
    prompts = []
    for i in range(len(examples['question'])):
        prompt = f"Question: {examples['question'][i].strip()}\nAnswer: {examples['answer'][i].strip()}"
        prompts.append(prompt)
    return {"text": prompts}

# Apply the formatting function to the dataset
formatted_dataset = dataset.map(format_prompt, batched=True)

# GSM8K already provides "train" and "test" splits, so no splitting is done here.
# NOTE: the "test" split is also what the evaluation cell scores, so treat the
# in-training eval loss as monitoring only, not as a model-selection signal.
train_dataset = formatted_dataset["train"]
eval_dataset = formatted_dataset["test"]

model_name = "Ashed00/SmolMath-SFT-CoT-AQuA"

# Final model directory. The evaluation cell reloads the model and tokenizer from
# "./SmolMath-SFT-gsm8k", so this path must match it exactly.
final_model_dir = "./SmolMath-SFT-gsm8k"

# Load the tokenizer and base model; fall back to EOS as the pad token when the
# tokenizer defines none, since batching requires a pad token.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

# Move the model to the selected device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Using device for training: {device}")

# Define training arguments
training_args = TrainingArguments(
    output_dir="./SmolMath-SFT-gsm8k-training", # Intermediate checkpoints (separate from final_model_dir)
    num_train_epochs=3, # Number of training epochs
    per_device_train_batch_size=8, # Batch size per device during training
    save_steps=10_000, # Save a checkpoint every X update steps
    save_total_limit=2, # Keep at most this many checkpoints
    logging_dir="./logs", # Directory for storing logs
    logging_steps=200,
    learning_rate=1e-5,
    weight_decay=0.01,
    eval_strategy="steps", # Evaluate every eval_steps update steps
    eval_steps=900,
)


# Use SFTTrainer for supervised fine-tuning; format_prompt put the training
# strings in the "text" column.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
    args=training_args,
)

# Start training
trainer.train()

# Save the final model (and the tokenizer passed as processing_class) to the
# directory the evaluation cell loads from.
trainer.save_model(final_model_dir)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import re
from tqdm import tqdm # for a progress bar

# Load the fine-tuned model and tokenizer
model_path = "./SmolMath-SFT-gsm8k"
# Decoder-only models must be LEFT-padded for batched generation: with right
# padding, shorter prompts would be continued after their pad tokens. Set it
# explicitly instead of relying on whatever padding_side was saved with the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(model_path)

# Ensure a pad token is set for batched generation. For causal LMs the EOS token
# is commonly reused as the pad token during batch inference.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Move the model to the selected device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval() # Set model to evaluation mode
print(f"Using device for evaluation: {device}")

# Load the GSM8K test dataset
dataset = load_dataset("gsm8k", "main")
eval_dataset = dataset["test"]

# Function to extract the answer from the model's generation using regex
def extract_answer_from_generation(generation_text):
    """Return the number following the first '####' in a generation, or None.

    Accepts optional whitespace and an optional '$' after '####', a leading '-',
    thousands separators and a decimal part. Commas are stripped before
    conversion so "#### 1,500" yields 1500.0 (not 1.0, which would wrongly
    compare equal to any other "1,xxx" answer).

    Args:
        generation_text: decoded model output.

    Returns:
        The extracted answer as a float, or None if no '#### <number>' is present.
    """
    match = re.search(r"####\s*\$?\s*(-?\d[\d,]*(?:\.\d+)?)", generation_text)
    if match is None:
        return None
    # The captured group is always digits/commas with an optional '.digits' tail,
    # so float() cannot fail once the commas are removed.
    return float(match.group(1).replace(",", ""))

# Function to extract the ground truth answer
def extract_ground_truth_answer(answer_text):
    """Return the reference answer from a GSM8K solution string, or None.

    GSM8K solutions end with a "#### <answer>" line. An optional '$', a leading
    '-', thousands separators and a decimal part are accepted; commas are
    stripped before conversion so "#### 1,200" yields 1200.0 rather than 1.0.

    Args:
        answer_text: the dataset's 'answer' field.

    Returns:
        The reference answer as a float, or None if no '#### <number>' is present.
    """
    match = re.search(r"####\s*\$?\s*(-?\d[\d,]*(?:\.\d+)?)", answer_text)
    if match is None:
        return None
    # Only digits, commas and an optional '.digits' tail are captured, so this
    # conversion cannot fail after removing the commas.
    return float(match.group(1).replace(",", ""))

# Evaluation parameters
batch_size = 16 # Adjust batch size based on your GPU memory
max_new_tokens = 256 # Adjust as needed based on expected answer length

correct_predictions = 0
# Every example with a parsable reference counts toward the total. A generation
# without an extractable "#### <number>" is scored as WRONG rather than skipped;
# skipping it would inflate accuracy for a model that fails to emit a final answer.
total_predictions = 0
unparsed_predictions = 0

print("Starting batched evaluation...")

# Create batches manually
for batch_start in tqdm(range(0, len(eval_dataset), batch_size)):
    # Slicing the dataset yields a dict mapping column name -> list of values.
    batch = eval_dataset[batch_start:batch_start + batch_size]
    questions = batch["question"]
    ground_truth_answer_texts = batch["answer"]

    # Same "Question: ...\nAnswer:" prefix as the training format, stopping where the answer begins.
    prompts = [f"Question: {q.strip()}\nAnswer:" for q in questions]

    # Tokenize the batch of prompts (padded to a common length)
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device)

    # Greedy decoding keeps the metric deterministic regardless of the checkpoint's generation config.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # All prompts share the padded length and each output row starts with its prompt,
    # so the newly generated tokens begin at the same column for every row.
    prompt_len = inputs["input_ids"].shape[1]
    generated_texts = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)

    for generated_text, answer_text in zip(generated_texts, ground_truth_answer_texts):
        ground_truth_answer = extract_ground_truth_answer(answer_text)
        if ground_truth_answer is None:
            continue # No reference answer to score against
        total_predictions += 1

        predicted_answer = extract_answer_from_generation(generated_text)
        if predicted_answer is None:
            unparsed_predictions += 1 # Counted in the total, never correct
        elif abs(predicted_answer - ground_truth_answer) < 1e-6: # Small tolerance for float comparison
            correct_predictions += 1


if total_predictions > 0:
    accuracy = correct_predictions / total_predictions
    print("\nEvaluation Complete:")
    print(f"Correct Predictions: {correct_predictions}")
    print(f"Total Predictions: {total_predictions}")
    print(f"Unparsable Predictions (counted as wrong): {unparsed_predictions}")
    print(f"Accuracy: {accuracy:.4f}")
else:
    print("No predictions were made or no answers could be extracted for comparison.")

In [3]:
# Peek at the first GSM8K test example to check the "#### <answer>" line the answer parsers search for.
eval_dataset[0]

{'question': "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
 'answer': 'Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.\n#### 18'}

In [8]:
# Prompt in the same "Question: ...\nAnswer:" form the evaluation loop uses. There is no
# trailing space: in training the space precedes the first answer word, and a dangling
# space here would typically be tokenized on its own.
input_text = "Question: A canteen requires 62 kgs of wheat for 6 days. How many kgs of wheat will it require for 60 days?\nAnswer:"
# Call the tokenizer directly so the attention mask is passed to generate() explicitly;
# pad and EOS may share an id, so the mask cannot be inferred reliably.
inputs = tokenizer(input_text, return_tensors="pt").to(device)

# Greedy generation. Sampling-only settings (top_k, temperature) are omitted because
# they have no effect when do_sample=False.
output = model.generate(**inputs, max_new_tokens=256, num_return_sequences=1, do_sample=False, pad_token_id=tokenizer.eos_token_id)

# Decode the prompt together with the generated continuation
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)



In [None]:
# Show the decoded output of the demo generation (prompt plus greedy continuation).
print(generated_text)

In [None]:
# Push the model and its tokenizer to the SAME Hub repository so the checkpoint
# can be loaded with a single from_pretrained(repo_id) call.
hub_repo_id = "SmolMath-SFT-gsm8k"
model.push_to_hub(hub_repo_id)
tokenizer.push_to_hub(hub_repo_id)