In [None]:
from datasets import load_dataset

In [None]:
# Load the BatsResearch/planetarium dataset and print its structure
# (available splits and columns) as a quick sanity check.
dataset = load_dataset(path="BatsResearch/planetarium")
print(dataset)

In [None]:
# Split the shuffled training set into an SFT portion (20%) and a GRPO
# portion (80%). The test split is kept as-is.
SPLIT_SEED = 42     # fixed seed so the shuffle, and therefore the split, is reproducible
SFT_FRACTION = 0.2  # share of the training set used for SFT; the remainder goes to GRPO

test_data = dataset["test"]
train_data = dataset["train"].shuffle(seed=SPLIT_SEED)

# Compute split sizes; GRPO receives every row that SFT does not take.
total_train_size = len(train_data)
sft_size = int(SFT_FRACTION * total_train_size)
grpo_size = total_train_size - sft_size

# The first `sft_size` shuffled rows go to SFT, the rest to GRPO.
sft_data = train_data.select(range(sft_size))
grpo_data = train_data.select(range(sft_size, total_train_size))

# The two parts must partition the training set exactly.
assert len(sft_data) == sft_size
assert len(grpo_data) == grpo_size
assert len(sft_data) + len(grpo_data) == total_train_size, "SFT/GRPO split lost or duplicated rows"

# Print confirmation
print(f"SFT: {len(sft_data)} samples")
print(f"GRPO: {len(grpo_data)} samples")
print(f"Test: {len(test_data)} samples (unchanged)")

## Configuration for SFT fine-tuning on the 20% split of the training set

In [None]:
# Log in to the Hugging Face Hub.
# The token is read from the HF_TOKEN environment variable so that no
# credential is ever stored in the notebook. When the variable is unset,
# login() is called without a token and asks for one interactively.
import os

from huggingface_hub import login

hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    login()

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model

# Base checkpoint to fine-tune (change this to experiment with another model).
model_name = "google/gemma-2b-it" 

# Tokenizer and model are loaded from the same checkpoint; device placement
# and dtype are left to the library ("auto").
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="auto",
)

# LoRA adapter settings (the notebook states these match the paper).
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,               # LoRA rank
    lora_alpha=32,      # scaling factor
    lora_dropout=0.05,  # dropout applied inside the adapter
    bias="none",
)

# Wrap the base model with the LoRA adapter.
model = get_peft_model(model, lora_config)


In [None]:
# Hyperparameters for the SFT stage.
#
# Notes on arguments that are NOT accepted by TrainingArguments:
#   * Adam betas / epsilon are set through adam_beta1, adam_beta2 and
#     adam_epsilon (there is no `betas=` or `eps=` keyword).
#   * The 1500-token maximum sequence length is not a training argument; it
#     has to be enforced when the data is tokenized
#     (e.g. tokenizer(..., truncation=True, max_length=1500)).
#   * Per-epoch evaluation is not enabled because no evaluation dataset is
#     given to the Trainer; requesting it without one makes the Trainer fail.
sft_training_args = TrainingArguments(
    output_dir="./sft_model",
    per_device_train_batch_size=1,  # Batch size = 1
    gradient_accumulation_steps=1,  # No accumulation: effective batch size stays 1
    learning_rate=2e-5,             # Learning rate = 2e-5
    optim="adamw_torch",            # Optimizer = AdamW (PyTorch implementation)
    adam_beta1=0.9,                 # Betas = (0.9, 0.999)
    adam_beta2=0.999,
    adam_epsilon=1e-8,              # Epsilon = 1e-8
    weight_decay=0.01,              # Weight decay = 0.01
    num_train_epochs=3,             # TODO: tune the number of epochs
    logging_steps=50,
    save_strategy="epoch",          # Checkpoint at the end of every epoch...
    save_total_limit=2,             # ...keeping only the two most recent ones
    report_to="none",
    bf16=True,                      # bfloat16 mixed precision
)

In [None]:
# Train the SFT model.
#
# The Trainer needs token ids, not raw text columns, so each example is first
# turned into one training sequence: the natural-language description followed
# by the target PDDL problem and an end-of-sequence token.
from transformers import DataCollatorForLanguageModeling

SFT_MAX_SEQ_LENGTH = 1500  # maximum number of tokens per training sequence


def tokenize_sft_example(example):
    """Tokenize one dataset row into a single prompt + target sequence.

    Args:
        example: A dataset row with "natural_language" and "problem_pddl"
            text fields.

    Returns:
        The tokenizer output (input ids and attention mask), truncated to
        SFT_MAX_SEQ_LENGTH tokens.
    """
    eos = tokenizer.eos_token or ""
    text = example["natural_language"] + "\n" + example["problem_pddl"] + eos
    return tokenizer(text, truncation=True, max_length=SFT_MAX_SEQ_LENGTH)


# Drop the original text columns so only model inputs remain.
sft_tokenized = sft_data.map(tokenize_sft_example, remove_columns=sft_data.column_names)

sft_trainer = Trainer(
    model=model,
    args=sft_training_args,
    train_dataset=sft_tokenized,
    tokenizer=tokenizer,
    # Causal-LM collator (mlm=False): pads each batch and derives the labels
    # from the input ids.
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)
sft_trainer.train()


In [None]:
# Persist the fine-tuned model and its tokenizer into the same directory so
# both can be reloaded from one path.
sft_output_dir = "./sft_model"
sft_trainer.save_model(sft_output_dir)
tokenizer.save_pretrained(sft_output_dir)


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

# Reload the SFT checkpoint with a value head attached, together with its
# tokenizer.
# NOTE(review): `ppo_model` is not referenced by any later cell visible in
# this notebook; confirm whether this load is still needed.
sft_model_path = "./sft_model"

tokenizer = AutoTokenizer.from_pretrained(sft_model_path)

ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    sft_model_path,
    torch_dtype="auto",
    device_map="auto",
)


## Reward function: scores each generated PDDL problem against its ground truth

In [None]:
import os
import sys

import planetarium.evaluate
import planetarium

def custom_reward_function(completions, ground_truths=None, domain_str=None, **kwargs):
    """Score generated PDDL problems against their ground-truth problems.

    Each completion is compared with its ground truth via
    ``planetarium.evaluate``, whose result is unpacked as the three flags
    ``(parseable, solvable, equivalent)``. The best flag that holds decides
    the reward:

        equivalent -> 1.0, solvable -> 0.5, parseable -> 0.2, otherwise 0.0

    Args:
        completions: Generated PDDL strings, one per sample.
        ground_truths: Reference PDDL strings aligned with ``completions``.
            May be omitted when the references are supplied through the
            ``ground_truth`` keyword instead. A trainer that forwards dataset
            columns as keyword arguments uses the column name, so both
            spellings are accepted.
        domain_str: Accepted for interface compatibility; currently unused.
        **kwargs: Any further keyword arguments supplied by the caller
            (for example ``prompts``). Only ``ground_truth`` is read.

    Returns:
        A list of float rewards, one per completion, in input order.

    Raises:
        ValueError: If no ground truths were supplied, or if their number
            differs from the number of completions.
    """
    if ground_truths is None:
        ground_truths = kwargs.get("ground_truth")
    if ground_truths is None:
        raise ValueError(
            "custom_reward_function needs the reference PDDL via 'ground_truths' "
            "or 'ground_truth'"
        )

    completions = list(completions)
    ground_truths = list(ground_truths)
    if len(completions) != len(ground_truths):
        # zip() would silently drop the unmatched tail and return too few rewards.
        raise ValueError(
            f"got {len(completions)} completions but {len(ground_truths)} ground truths"
        )

    rewards = []
    for generated_pddl, ground_truth_pddl in zip(completions, ground_truths):
        parseable, solvable, equivalent = planetarium.evaluate(ground_truth_pddl, generated_pddl)
        if equivalent:
            reward = 1.0  # Correct PDDL -> highest reward
        elif solvable:
            reward = 0.5  # Solvable but not equivalent -> intermediate reward
        elif parseable:
            reward = 0.2  # Only parseable -> low reward
        else:
            reward = 0.0  # Not even parseable
        rewards.append(reward)

    return rewards


In [None]:
## grpo trainer from huggingface

from trl import GRPOConfig, GRPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

# Load the fine-tuned SFT model (previously saved at ./sft_model)
sft_model_path = "./sft_model"
model = AutoModelForCausalLM.from_pretrained(sft_model_path, device_map="auto", torch_dtype="auto")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(sft_model_path)

# GRPO training configuration.
# GRPOConfig has no `max_length` argument: the prompt and the generated
# completion are limited separately through max_prompt_length and
# max_completion_length.
# NOTE(review): 512 completion tokens may be too short for large PDDL
# problems; tune against the target lengths in the dataset.
grpo_training_args = GRPOConfig(
    output_dir="./grpo_model",
    learning_rate=1e-6,
    logging_steps=10,
    per_device_train_batch_size=16,  # Process multiple samples at once
    gradient_accumulation_steps=2,
    max_prompt_length=512,           # Upper bound on prompt tokens
    max_completion_length=512,       # Upper bound on generated tokens
    num_generations=8,               # Completions sampled per prompt
)


In [None]:
# Reshape the GRPO split into prompt / ground-truth pairs.
print(grpo_data)

from datasets import Dataset

# Column names matter here:
#   * GRPOTrainer reads its inputs from a column named "prompt".
#   * Every other column is forwarded to the reward function as a keyword
#     argument named after the column, so the reference column is called
#     "ground_truths" to match custom_reward_function's parameter.
formatted_grpo_data = Dataset.from_dict({
    "prompt": grpo_data["natural_language"],      # Input NL description
    "ground_truths": grpo_data["problem_pddl"],   # Correct PDDL output
})

print(formatted_grpo_data)

In [None]:
# Train the model with GRPO, starting from the SFT checkpoint.

# Initialize trainer. `sft_model_path` points at the saved SFT model; the
# trainer loads the model and its tokenizer from that path.
grpo_trainer = GRPOTrainer(
    model=sft_model_path,
    reward_funcs=custom_reward_function,
    args=grpo_training_args,
    train_dataset=formatted_grpo_data,
)

# Train with GRPO
grpo_trainer.train()

# Save the trained model and the tokenizer the trainer used
# (exposed by the trainer as `processing_class`).
grpo_output_dir = "./grpo_finetuned_model"
grpo_trainer.model.save_pretrained(grpo_output_dir)
grpo_trainer.processing_class.save_pretrained(grpo_output_dir)


## Model results: evaluation on the held-out test set

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the fine-tuned GRPO model and its tokenizer for inference.
# device_map="auto" / torch_dtype="auto" mirror the other model loads in this
# notebook, so the model is placed on the available accelerator instead of
# staying on the CPU.
model_path = "./grpo_finetuned_model"
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Generate predictions
def generate_pddl(prompts, batch_size=8, max_new_tokens=1500):
    """Generate one completion per prompt with the global model and tokenizer.

    Prompts are processed in mini-batches so the whole dataset is never
    tokenized and generated in a single forward pass. Only the newly generated
    tokens are decoded, so the returned strings do not repeat the prompt.

    Args:
        prompts: Iterable of prompt strings.
        batch_size: Number of prompts generated per batch (must be >= 1).
        max_new_tokens: Maximum number of tokens generated per prompt; the
            prompt itself does not count towards this limit.

    Returns:
        A list of generated strings, one per prompt, in input order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    prompts = list(prompts)
    generations = []

    # Pad on the left while generating so every prompt ends directly before
    # the first generated token; the previous setting is restored afterwards.
    previous_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        for start in range(0, len(prompts), batch_size):
            batch = prompts[start:start + batch_size]
            # Send inputs to whatever device the model lives on rather than
            # assuming CUDA.
            inputs = tokenizer(
                batch, return_tensors="pt", padding=True, truncation=True
            ).to(model.device)
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
            # The output sequences start with the (padded) prompt; keep only
            # what was generated after it.
            prompt_length = inputs["input_ids"].shape[1]
            new_tokens = outputs[:, prompt_length:]
            generations.extend(tokenizer.batch_decode(new_tokens, skip_special_tokens=True))
    finally:
        tokenizer.padding_side = previous_padding_side

    return generations

# Apply to test set.
# The prompts are the natural-language descriptions, stored in the
# "natural_language" column (the same column used to build the GRPO prompts).
# NOTE(review): assumes the test split has the same columns as the train
# split; confirm against the printed dataset structure.
generated_pddls = generate_pddl(test_data["natural_language"])



In [None]:
# Evaluate test set performance: count how many generated problems are
# parseable, solvable, and equivalent to the reference.
parseable_count = 0
solvable_count = 0
correct_count = 0

# The reference PDDL lives in the "problem_pddl" column (the same column used
# as ground truth when building the GRPO data).
# NOTE(review): assumes the test split has the same columns as the train split.
for generated_pddl, ground_truth_pddl in zip(generated_pddls, test_data["problem_pddl"]):
    parseable, solvable, equivalent = planetarium.evaluate(ground_truth_pddl, generated_pddl)
    parseable_count += 1 if parseable else 0
    solvable_count += 1 if solvable else 0
    correct_count += 1 if equivalent else 0

# Compute percentages; an empty test set reports 0% instead of dividing by zero.
total = len(test_data)
percent_per_sample = 100 / total if total else 0.0
parseable_pct = parseable_count * percent_per_sample
solvable_pct = solvable_count * percent_per_sample
correct_pct = correct_count * percent_per_sample

# Print results
print(f"Parseable: {parseable_pct:.2f}%")
print(f"Solvable: {solvable_pct:.2f}%")
print(f"Correct PDDL: {correct_pct:.2f}%")
