In [1]:
import numpy as np

import torch
import torch.nn as nn

import minari

from datasets import Dataset
from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from trl import GRPOConfig, GRPOTrainer

  from .autonotebook import tqdm as notebook_tqdm


[2025-06-16 09:46:49,508] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/mnt/data_2/abenechehab/micromamba/envs/rlft4rl/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/mnt/data_2/abenechehab/micromamba/envs/rlft4rl/compiler_compat/ld: cannot find -lcufile: No such file or directory
collect2: error: ld returned 1 exit status


INFO 06-16 09:46:50 __init__.py:183] Automatically detected platform cuda.


2025-06-16 09:46:50,885	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


### Model

In [2]:
class MLPConfig(PretrainedConfig):
    model_type = "halfcheetah-mlp"

    def __init__(self, input_dim=17, output_dim=6, hidden_sizes=[64, 64], **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_sizes = hidden_sizes

class MLPPolicy(PreTrainedModel):
    config_class = MLPConfig  # enable AutoModel support
    base_model_prefix = "halfcheetah-mlp"

    def __init__(self, config: MLPConfig):
        super().__init__(config)
        layers = []
        dims = [config.input_dim] + config.hidden_sizes + [config.output_dim]
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i < len(dims) - 2:
                layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)
        self.mean = nn.Sequential(nn.Linear(dims[-1], config.output_dim), nn.ReLU())
        self.log_std = nn.Sequential(nn.Linear(dims[-1], config.output_dim), nn.ReLU())
        # For stochastic policy: a trainable log_std parameter per action
        # self.log_std = nn.Parameter(torch.zeros(config.output_dim))

    def forward(self, state, num_logits_to_keep=None):
        # state shape: (batch_size, 17)
        if state.shape[-1] == self.config.output_dim:
            return state
        elif state.shape[-1] == self.config.input_dim:
            state = state.to(self.net[0].weight.device)
            common = self.net(state)  # (batch_size, 6)
            mean = self.mean(common)
            std = torch.exp(self.log_std(common))
            # dist = torch.distributions.Normal(mean, std)
            # action = dist.rsample()     
            # return action
            return mean, std
        else:    
            raise ValueError(
                f"Expected input dimension {self.config.input_dim}, but got {state.shape[-1]}"
            )


    def generate(self, inputs=None, num_return_sequences=1, **kwargs):
        """
        Generate action samples given a batch of inputs (states).

        Args:
            inputs (torch.Tensor): input tensor of shape [batch_size, obs_dim]
            num_return_sequences (int): number of action samples per input

        Returns:
            torch.Tensor: tensor of shape [batch_size * num_return_sequences, action_dim]
        """

        mean, std = self.forward(kwargs["input_ids"].float())
        dist = torch.distributions.Normal(mean, std)

        actions = []
        for _ in range(num_return_sequences):
            sampled = dist.sample()
            actions.append(sampled)

        # shape: [num_return_sequences, batch_size, action_dim] â†’ [batch_size * num_return_sequences, action_dim]
        all_actions = torch.cat(actions, dim=0)
        return all_actions
    

# Register the configuration
AutoConfig.register("halfcheetah-mlp", MLPConfig)
# Register the model
AutoModel.register(MLPConfig, MLPPolicy)

def create_mlp_model(obs_dim, action_dim, hidden_dims=None, **kwargs):
    config = MLPConfig(
        obs_dim=obs_dim, action_dim=action_dim, hidden_dims=hidden_dims, **kwargs
    )
    return MLPPolicy(config)


def load_mlp_model(model_path):
    config = MLPConfig.from_pretrained(model_path)
    return MLPPolicy.from_pretrained(model_path, config=config)

### Tokenizer (dummy)

In [3]:
from typing import Optional, Tuple
from transformers import PreTrainedTokenizer


class MLPTokenizer(PreTrainedTokenizer):
    def __init__(self, obs_dim, **kwargs):
        self.obs_dim = obs_dim
        self.eos_token = 0
        self.eos_token_id = 0
        self.pad_token = 0
        self.pad_token_id = 0
        super().__init__(**kwargs)

    def _tokenize(self, text):
        # extract observation from str and make it array
        obs_list = text.replace("[", "").replace("]", "").split(",")
        obs_list = [float(obs) for obs in obs_list if obs.strip()]
        if len(obs_list) != self.obs_dim:
            raise ValueError(f"Expected {self.obs_dim} observations, got {len(obs_list)}")
        # Convert each element to its own numpy array
        return [str(elem) for elem in obs_list]

    def _convert_token_to_id(self, token):
        return np.array(token).astype(np.float32)

    def get_vocab(self):
        return {}

    @property
    def vocab_size(self):
        return 1  # Dummy value
    
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
        return ()


### Data

In [4]:
dataset_id = "mujoco/halfcheetah/medium-v0"
seed = 7
dataset_size = 5

dataset = minari.load_dataset(dataset_id, download=True)
dataset.set_seed(seed=seed)

examples = []
for i, ep in enumerate(dataset):
    for t in range(ep.observations.shape[0] - 1):
        examples.append(
            {
                "prompt": f"{ep.observations[t].tolist()}",
                "state": ep.observations[t].tolist(),
                "action": ep.actions[t].tolist(),
                "reward": float(ep.rewards[t]),
            }
        )
        if i==0 and t==0:
            input_dim = len(ep.observations[t])
            output_dim = len(ep.actions[t])
    if i >= dataset_size:
        break

hf_dataset = Dataset.from_list(examples)

### Custom GRPO Trainer

In [5]:
from typing import Callable, Optional, Union
from datasets import Dataset, IterableDataset

from torch.distributions import Normal

from transformers import (
    PreTrainedModel,
    PreTrainedTokenizerBase,
    TrainerCallback,
)

from transformers.utils import is_peft_available

from trl.models.utils import unwrap_model_for_generation


if is_peft_available():
    from peft import PeftConfig

# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]


In [6]:
class CustomGRPOTrainer(GRPOTrainer):
    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        reward_funcs: Union[RewardFunc, list[RewardFunc]],
        args: GRPOConfig = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[
            Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]
        ] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        reward_processing_classes: Optional[
            Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]
        ] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[
            Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]
        ] = (None, None),
        peft_config: Optional["PeftConfig"] = None,
    ):
        super().__init__(
            model=model,
            reward_funcs=reward_funcs,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            reward_processing_classes=reward_processing_classes,
            callbacks=callbacks,
            optimizers=optimizers,
            peft_config=peft_config
        )
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")

        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        # prompts_text = [
        #     maybe_apply_chat_template(example, self.processing_class)["prompt"]
        #     for example in inputs
        # ]
        prompt_inputs = self.processing_class(
            prompts, #prompts_text,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)

        if self.max_prompt_length is not None:
            prompt_inputs["input_ids"] = prompt_inputs["input_ids"][
                :, -self.max_prompt_length :
            ]
            prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][
                :, -self.max_prompt_length :
            ]

        # Generate actions using the model
        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
            generated_actions = unwrapped_model.generate(
                **prompt_inputs, num_return_sequences=self.num_generations
            )

        # Get the distribution parameters (mean, std) for the generated actions
        def get_gaussian_params_and_logprobs(model, input_ids, actions):
            """
            Get Gaussian distribution parameters and log probabilities for actions.

            Args:
                model: The policy model
                input_ids: Input states (observations)
                actions: Generated actions to compute log probs for

            Returns:
                mean, std, log_probs
            """
            mean, std = model(input_ids.float())  # Get distribution parameters
            # print(f"mean {mean.shape}, std {std.shape}")
            dist = Normal(mean, std)

            # raise ValueError("test")
            # Reshape actions to match batch size if needed
            
            # Handle num_generations > 1 case
            actions = actions.view(-1, mean.shape[0], mean.shape[1])
            # actions = actions.squeeze(1)  # Remove middle dimension if it's 1

            log_probs = dist.log_prob(actions).sum(dim=-1)  # Sum over action dimensions
            return mean, std, log_probs

        # Get current policy distribution parameters and log probabilities
        current_mean, current_std, current_log_probs = get_gaussian_params_and_logprobs(
            model, prompt_inputs["input_ids"], generated_actions
        )

        # Get reference policy distribution parameters and log probabilities
        # with torch.inference_mode():
        if self.ref_model is not None:
            # print(f"using ref model: {self.ref_model}")
            ref_mean, ref_std, ref_log_probs = get_gaussian_params_and_logprobs(
                self.ref_model, prompt_inputs["input_ids"], generated_actions
            )
        else:
            # print("not using ref model")
            with self.accelerator.unwrap_model(model).disable_adapter():
                ref_mean, ref_std, ref_log_probs = get_gaussian_params_and_logprobs(
                    model, prompt_inputs["input_ids"], generated_actions
                )
        ref_log_probs = ref_log_probs.clone()
        ref_mean = ref_mean.clone()
        ref_std = ref_std.clone()

        kl_div = torch.distributions.kl.kl_divergence(
                Normal(current_mean, current_std),
                Normal(ref_mean, ref_std),
        )
        
        # raise ValueError("kl_div shape: ", kl_div.shape)
        kl_div = kl_div.sum(dim=-1)

        # Decode the generated actions (convert back to list format for reward computation)
        if generated_actions.dim() == 1:
            generated_actions = generated_actions.unsqueeze(0)
        completions = [action.cpu().numpy().tolist() for action in generated_actions]

        # Compute the rewards
        prompts_repeated = [
            prompt for prompt in prompts for _ in range(self.num_generations)
        ]

        rewards_per_func = torch.zeros(
            len(prompts_repeated), len(self.reward_funcs), device=device
        )
        for i, (reward_func, reward_processing_class) in enumerate(
            zip(self.reward_funcs, self.reward_processing_classes)
        ):
            # Handle function-based rewards
            reward_kwargs = {
                key: []
                for key in inputs[0].keys()
                if key not in ["prompt", "completion"]
            }
            for key in reward_kwargs:
                for example in inputs:
                    reward_kwargs[key].extend([example[key]] * self.num_generations)
            output_reward_func = reward_func(
                prompts=prompts_repeated, completions=completions, **reward_kwargs
            )
            rewards_per_func[:, i] = torch.tensor(
                output_reward_func, dtype=torch.float32, device=device
            )

        # Sum the rewards from all reward functions
        rewards = rewards_per_func.sum(dim=1)

        # Compute grouped-wise rewards (group by original prompt)
        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)

        # Normalize the rewards to compute advantages
        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
            self.num_generations, dim=0
        )
        std_grouped_rewards = std_grouped_rewards.repeat_interleave(
            self.num_generations, dim=0
        )
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        # Compute the policy gradient loss
        # For continuous actions, we use the log probability ratio directly
        ratio = torch.exp(current_log_probs - ref_log_probs.detach())
        # raise ValueError("ratio shape: ", ratio.shape, " advantages shape: ", advantages.shape, " kl_div shape: ", kl_div.shape)
        policy_loss = -(ratio * advantages.reshape(ratio.shape) - self.beta * kl_div)
        loss = policy_loss.mean()

        # Log the metrics
        action_dim = generated_actions.shape[-1] if generated_actions.dim() > 1 else 1
        self._metrics["action_dimension"].append(action_dim)

        reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(
            self.accelerator.gather_for_metrics(rewards).mean().item()
        )
        self._metrics["reward_std"].append(
            self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item()
        )
        self._metrics["kl"].append(
            self.accelerator.gather_for_metrics(kl_div).mean().item()
        )

        return loss

### Reward fn

In [13]:
def reward_fn(completions, state, action, reward, **kwargs):
    rewards = []
    for i, _ in enumerate(completions):
        rewards.append(reward[i])
    return rewards

def BC_reward_fn(completions, state, action, reward, **kwargs):
    rewards = []
    for i, _ in enumerate(completions):
        rewards.append(-np.linalg.norm(np.array(completions[i]) - np.array(action[i]), ord=2))
    return rewards


### Training

In [None]:
# Configure training arguments
training_args = GRPOConfig(
    output_dir="../../models/halfcheetah-mlp-grpo",
    learning_rate=1e-4,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    num_train_epochs=5,
    beta=0.01,  # KL penalty coefficient
    num_generations=4,  # number of actions to sample per state (for group advantage)
    max_steps=10000,
)

# Create your model
use_llm = False
if not use_llm:
    model = create_mlp_model(
        obs_dim=input_dim, action_dim=output_dim, hidden_dims=[256, 256]
    )
    # Create tokenizer (if required)
    tokenizer = MLPTokenizer(obs_dim=input_dim)
else:
    llm = "Qwen/Qwen3-0.6B"
    model = AutoModelForCausalLM.from_pretrained(llm)
    tokenizer = AutoTokenizer.from_pretrained(llm)

trainer = CustomGRPOTrainer(
    model=model,  # our MLP wrapped as PreTrainedModel
    processing_class=tokenizer,  # dummy tokenizer
    train_dataset=hf_dataset,  # dataset of states (and optionally a 'prompt' col)
    reward_funcs=BC_reward_fn,  # a function returning list of rewards
    args=training_args,
)


In [16]:
trainer.train()

Step,Training Loss
500,-0.1448
1000,-0.1487
1500,-0.1472
2000,-0.1472
2500,-0.1469
3000,-0.1501
3500,-0.1509
4000,-0.1543
4500,-0.1527
5000,-0.155


TrainOutput(global_step=10000, training_loss=-0.15066619873046874, metrics={'train_runtime': 809.6634, 'train_samples_per_second': 395.226, 'train_steps_per_second': 12.351, 'total_flos': 0.0, 'train_loss': -0.15066619873046874})

In [11]:
%debug

> [32m/tmp/ipykernel_19357/4179726797.py[39m([92m5[39m)[36mreward_fn[39m[34m()[39m
[32m      2[39m     rewards = []
[32m      3[39m     [38;5;28;01mfor[39;00m i, _ [38;5;28;01min[39;00m enumerate(completions):
[32m      4[39m         rewards.append(reward[i])
[32m----> 5[39m     [38;5;28;01mraise[39;00m NotImplementedError([33m"reward"[39m)
[32m      6[39m     [38;5;28;01mreturn[39;00m rewards

64
15.539128748929897
15.539128748929897
[15.539128748929897, 15.539128748929897, 15.539128748929897, 15.539128748929897, -0.7794300646493412]
*** AttributeError: 'list' object has no attribute 'shape'
64
15.539128748929897
[[0.9796957969665527, 0.6177719831466675, 0.9426043033599854, 0.976677417755127, -0.7610951066017151, 0.5502623319625854], [0.9796957969665527, 0.6177719831466675, 0.9426043033599854, 0.976677417755127, -0.7610951066017151, 0.5502623319625854], [0.9796957969665527, 0.6177719831466675, 0.9426043033599854, 0.976677417755127, -0.7610951066017151, 0.55