In [1]:
import pickle

In [2]:
ls pkls

data_list_.pkl0    data_list_.pkl2    minibatches_.pkl0  minibatches_.pkl2
data_list_.pkl1    data_list_.pkl3    minibatches_.pkl1  minibatches_.pkl3


In [5]:
minibatch = []
data_list = []
for i in range(4):

    with open(f'pkls/minibatches_.pkl{i}', 'rb') as f:
        minibatch.append(pickle.load(f))

    with open(f'pkls/data_list_.pkl{i}', 'rb') as f:
        minibatch.append(pickle.load(f))

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [7]:
import torch

In [10]:
q = torch.rand(6,4,236, 128)
cos = torch.rand(6,4,118, 128)

q * cos

RuntimeError: The size of tensor a (236) must match the size of tensor b (118) at non-singleton dimension 2

### actor

In [None]:
from collections import defaultdict
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions, get_model_state_dict
)
from transformers import AutoModelForCausalLM
from RL2.workers import Worker
from RL2.utils.sequences import count_total
from RL2.utils.ring_attn import update_params_of_ring_attn
from RL2.utils.functions import (
    compute_logsumexp,
    gather_action_logits,
    compute_entropy
)
from RL2.utils.algorithms import (
    compute_approx_kl, compute_surrogate_loss
)
from RL2.utils.offloading import load_model_to_device
from RL2.utils.logging import (
    progress_bar,
    time_logger,
    gather_and_reduce,
    rank0_log
)


class Actor(Worker):

    def __init__(self, config, train: bool):
        super().__init__(config, train)
        
        if config.use_liger_kernel:
            assert config.tp_size == 1, \
                "Liger kernel is not compatible with tensor parallelism."
            from liger_kernel.transformers import AutoLigerKernelForCausalLM
            model_cls = AutoLigerKernelForCausalLM
        else:
            model_cls = AutoModelForCausalLM

        self.model = model_cls.from_pretrained(
            config.model_name,
            trust_remote_code=True,
            attn_implementation="flash_attention_2"
        )

        self.prepare_model_optimizer()

    def forward(self, minibatch, return_entropy=False):
        update_params_of_ring_attn(
            minibatch["cu_seqlens"], self.device_mesh["sp"]
        )

        logits = self.model(
            input_ids=minibatch["states"],
            position_ids=minibatch["position_ids"],
            use_cache=False
        ).logits.to(torch.float32) / getattr(
            self.config, "temperature", 1.0
        )
        # bfloat16 is unstable for the subsequent `logsumexp` operation.
        # See https://github.com/OpenRLHF/OpenRLHF/pull/634.
        
        logsumexp = compute_logsumexp(logits, self.device_mesh["tp"])
        action_logits = gather_action_logits(
            logits,
            minibatch["actions"],
            self.device_mesh["tp"]
        )
        logps = (action_logits - logsumexp) * minibatch["action_mask"]
        
        if return_entropy:
            entropy = compute_entropy(
                logits, logsumexp, self.device_mesh["tp"]
            ) * minibatch["action_mask"]
            return logps, entropy
        else:
            return logps

    @time_logger("compute_logps")
    @torch.no_grad()
    def compute_logps(self, data_list, step):
        load_model_to_device(self, torch.cuda.current_device())
        
        len_data_list = len(data_list) if isinstance(data_list, list) else None
        
        minibatches = self.scatter_and_pack_data_list(data_list)

        import pickle

        if data_list is not None:
                
            with open(f"data_list_.pkl{dist.get_rank()}", "wb") as f:
                pickle.dump(data_list, f)
            
        with open(f"minibatches_.pkl{dist.get_rank()}", "wb") as f:
            pickle.dump(minibatches, f)
        
        print(f' RANK {dist.get_rank()} before scatter {len_data_list} mini_batches {len(minibatches)}')
        prefix = "old" if self.train else "ref"

        self.model.eval()
        for minibatch in progress_bar(
            minibatches, desc=f"Compute {prefix} logps"
        ):
            minibatch[f"{prefix}_logps"] = self.forward(minibatch)
        
        if not self.train:
            load_model_to_device(self, "cpu")
        return self.unpack_and_gather_data_list(minibatches) 
    
    @time_logger("update_actor")
    def update(self, data_list, step: int):
        
        if step < self.config.freeze_steps:
            load_model_to_device(self, "cpu")
            return
        load_model_to_device(self, torch.cuda.current_device())
        batches = self.scatter_and_pack_data_list(data_list, True)

        self.model.train()
        tbar = progress_bar(
            total=sum([len(batch) for batch in batches]),
            desc="Update actor"
        )
        metrics = defaultdict(list)
        for batch in batches:
            
            total_actions = count_total(batch, "action_mask", self.device_mesh)
            total_sequences = count_total(batch, "eos_mask", self.device_mesh)
            metric = defaultdict(list)
            for minibatch in batch:

                logps, entropy = self.forward(minibatch, return_entropy=True)
                surrogate_loss, clip_ratio = compute_surrogate_loss(
                    self, logps, minibatch, total_actions, total_sequences
                )
                entropy_loss = - entropy.sum() / total_actions
                loss = surrogate_loss + self.config.entropy.coef * entropy_loss

                if self.config.kl.coef > 0 and self.config.kl.type == "loss":
                    kl_loss = compute_approx_kl(
                        logps,
                        minibatch["ref_logps"],
                        self.config.kl.loss_estimator
                    ).sum() / total_actions
                    loss = loss + self.config.kl.coef * kl_loss

                self.backward(loss)

                tbar.update()
                metric["actor/entropy_loss"].append(entropy_loss.item())
                metric["actor/loss"].append(loss.item())
                metric["actor/clip_ratio"].append(clip_ratio.item())

            grad_norm = self.optimizer_step()

            for k, v in metric.items():
                metrics[k].append(
                    gather_and_reduce(v, self.device_mesh)
                )
            metrics["actor/grad_norm"].append(grad_norm)

        rank0_log(metrics, step)

        options = StateDictOptions(full_state_dict=False, cpu_offload=True)
        state_dict = get_model_state_dict(
            self.model, options=options
        )
        load_model_to_device(self, "cpu")
        return state_dict

### rollout.py

In [None]:
from omegaconf import OmegaConf
import os
import json
import asyncio
import importlib
from collections import defaultdict
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.utils import MultiprocessingSerializer
from sglang.srt.model_executor.model_runner import LocalSerializedTensor
from tqdm.asyncio import tqdm
import wandb
from RL2.workers import Worker
from RL2.datasets import tokenize_messages
from RL2.utils.comm import split_and_scatter_list, gather_and_concat_list
from RL2.utils.logging import time_logger, gather_and_log


class Rollout(Worker):

    def __init__(self, config):
        super().__init__(config, None)
        
        self.prepare_environment_variables()
        if self.device_mesh["tp"].get_local_rank() == 0:
            self.prepare_environment()

            os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
            self.llm = Engine(
                model_path=config.model_name,
                dtype="bfloat16",
                tp_size=self.device_mesh["tp"].size(),
                mem_fraction_static=config.gpu_memory_utilization,
                enable_memory_saver=True,
                port=30000 + dist.get_rank()
            )
        
            self.train_sampling_params = OmegaConf.to_container(
                config.train_sampling_params
            )
            self.test_sampling_params = OmegaConf.to_container(
                config.test_sampling_params
            )

        dist.barrier()

    def prepare_device_mesh(self):

        world_size = dist.get_world_size()
        assert world_size % self.config.tp_size == 0, \
            f"World_size {world_size} must be divisible by tp_size {self.config.tp_size}."
        self.dp_size = world_size // self.config.tp_size
        self.device_mesh = dist.device_mesh.init_device_mesh(
            "cpu",
            mesh_dim_names=("dp", "tp"),
            mesh_shape=(self.dp_size, self.config.tp_size)
        )

    def prepare_environment_variables(self):

        if "TORCHELASTIC_USE_AGENT_STORE" in os.environ.keys():
            del os.environ["TORCHELASTIC_USE_AGENT_STORE"]
        monkey_patch_torch_reductions()
        cuda_visible_devices = self.device_mesh["tp"].size() * [None]
        dist.all_gather_object(
            cuda_visible_devices,
            os.environ["LOCAL_RANK"],
            self.device_mesh["tp"].get_group()
        )
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(cuda_visible_devices)

    def prepare_environment(self):

        spec = importlib.util.spec_from_file_location(
            "custom_module", self.config.env_path
        )
        self.env = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(self.env)
        
    async def rollout(self, ex, train):

        messages, answer = ex["messages"], ex["answer"]
        metric = defaultdict(list)
        for turn in range(self.config.max_turns):

            if self.config.apply_chat_template:
                prompt = self.tokenizer.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                    tokenize=False
                )
            else:
                prompt = "".join([
                    msg["content"] for msg in messages
                ])
            
            response = await self.llm.async_generate(
                prompt,
                sampling_params=self.train_sampling_params
                if train else self.test_sampling_params
            )

            meta_info = response["meta_info"]
            metric["response_length"].append(meta_info["completion_tokens"])
            metric["length_clip_ratio"].append(
                meta_info["finish_reason"]["type"] == "length"
            )

            # Current SGLang engine will generate sequence longer than 
            # `max_new_tokens`.
            # TODO (P1): Check whether all configurations are properly set 
            # and whether the bug has been fixed in the latest version.
            content = self.tokenizer.decode(
                self.tokenizer.encode(
                    response["text"], add_special_tokens=False
                )[:meta_info["completion_tokens"]]
            )
            messages.append(
                {"role": "assistant", "content": content}
            )

            # Do not invoke tools in the last turn.
            if turn + 1 == self.config.max_turns:
                break

            env_messages = self.env.interact(messages)
            # Terminate if no tool is invoked.
            if len(env_messages) == 0:
                break

            messages.extend(env_messages)

        reward = self.env.reward_fn(messages, answer)

        ex = tokenize_messages(
            self.tokenizer,
            messages,
            self.config.apply_chat_template
        )
        ex.update({
            "rewards": torch.FloatTensor((ex["states"].shape[-1] - 1) * [0] + [reward]),
            "eos_mask": torch.LongTensor((ex["states"].shape[-1] - 1) * [0] + [1])
        })

        metric["n_turns"].append(turn + 1)
        metric["rewards"].append(reward)
        metric["trajectory_length"].append(len(ex["states"]))

        return ex, messages, metric

    @time_logger("rollout")
    def __call__(self, data_list, train: bool, step: int):

        # The data is distributed from rank 0 before each worker operation
        # and gathered before the next operation, which facilitates to do
        # model-agnostic operations, e.g., computing advantages, globally 
        # and guarantees the load balancing across all model computations.
        if self.device_mesh["tp"].get_local_rank() == 0:

            data_list = split_and_scatter_list(
                data_list, self.device_mesh["dp"]
            )
            # print(f' rank {dist.get_rank()} rollout data list { len(data_list) if isinstance(data_list, list) else None} model weights wq, ')
            loop = asyncio.get_event_loop()
            outputs = loop.run_until_complete(
                tqdm.gather(
                    *(self.rollout(ex, train) for ex in data_list),
                    desc="Rollout", position=1, leave=False,
                    disable=(dist.get_rank() != 0)
                )
            )
            if train:
                # If test, llm will soon be called again. See `Trainer.train`.
                self.llm.release_memory_occupation()

        dist.barrier()

        if self.device_mesh["tp"].get_local_rank() == 0:

            data_list, all_messages, metrics = map(list, zip(*outputs))

            if dist.get_rank() == 0:
                tqdm.write(json.dumps(all_messages[0], indent=4))

            suffix = "train" if train else "test"
            metrics = {
                f"{k}/{suffix}": sum([metric[k] for metric in metrics], [])
                for k in metrics[0].keys()
            }
            gather_and_log(metrics, self.device_mesh["dp"], step)

            if not train:
                return

            data_list = gather_and_concat_list(
                data_list, self.device_mesh["dp"]
            )

            print(f' rank {dist.get_rank()} after collecting data list { len(data_list) if isinstance(data_list, list) else None} model weights wq, ')
            if dist.get_rank() == 0:
                if not self.config.dynamic_filtering:
                    return data_list

                rewards = torch.FloatTensor(
                    [ex["rewards"].sum() for ex in data_list]
                ).view(-1, self.config.responses_per_prompt)
                are_filtered = (rewards.std(-1) == 0).tolist()
                wandb.log({
                    "dynamic_filtering_ratio": sum(are_filtered) / len(are_filtered)
                }, step=step)
                return sum([
                    data_list[idx * self.config.responses_per_prompt:(idx + 1) * self.config.responses_per_prompt]
                    for idx, is_filtered in enumerate(are_filtered)
                    if not is_filtered
                ], [])
        
    @time_logger("update_rollout")
    def update(self, state_dict, step):

        torch.cuda.empty_cache()
        # or llm.resume_memory_occupation() may OOM
        if self.device_mesh["tp"].get_local_rank() == 0:
            self.llm.resume_memory_occupation()
        
        for idx, (name, tensor) in enumerate(state_dict.items()):
            tensor = tensor.to(torch.cuda.current_device())
            serialized_tensor = MultiprocessingSerializer.serialize(
                tensor.full_tensor() if isinstance(tensor, DTensor) else tensor
            )
            serialized_tensors = [
                None for _ in range(self.device_mesh["tp"].size())
            ] if self.device_mesh["tp"].get_local_rank() == 0 else None
            dist.gather_object(
                serialized_tensor,
                serialized_tensors,
                group_dst=0,
                group=self.device_mesh["tp"].get_group(),
            )
            if self.device_mesh["tp"].get_local_rank() == 0:
                self.llm.update_weights_from_tensor(
                    named_tensors=[(
                        name, LocalSerializedTensor(values=serialized_tensors)
                    )],
                    flush_cache=(idx == len(state_dict) - 1)
                )
        dist.barrier()

### ppo.yml

In [None]:
data:
  train_data_path: null
  test_data_path: null
  prompts_per_rollout: null
  responses_per_prompt: null
  
actor:
  model_name: null
  tokenizer_name: ${actor.model_name}
  use_liger_kernel: false
  gradient_checkpointing: true
  ddp_size: 1
  tp_size: 2
  sp_size: 1
  max_length_per_device: null
  max_inference_length_per_device: ${actor.max_length_per_device}
  temperature: ${rollout.train_sampling_params.temperature}
  update_per_rollout: 3
  clip: 0.2
  agg_mode: all_token_mean
  lr: 1e-6
  weight_decay: 1e-2
  max_grad_norm: 1.0
  scheduler: constant
  warmup_ratio: 0.1
  freeze_steps: 0
  offload_model: true
  offload_optimizer: true

  kl:
    coef: 0.0
    type: null # `reward` or `loss`
    reward_estimator: k1
    loss_estimator: k2
    # `k1`, `k2` or `k3`. See http://joschu.net/blog/kl-approx.html.

  entropy:
    coef: 0.0

rollout:
  model_name: ${actor.model_name}
  tokenizer_name: ${rollout.model_name}
  tp_size: 2
  gpu_memory_utilization: 0.5
  responses_per_prompt: ${data.responses_per_prompt}
  apply_chat_template: true
  train_sampling_params:
    temperature: 1.0
    max_new_tokens: null
  test_sampling_params:
    temperature: 0.0
    max_new_tokens: ${rollout.train_sampling_params.max_new_tokens}
  max_turns: 1
  env_path: null
  dynamic_filtering: false

ref_actor:
  model_name: ${actor.model_name}
  tokenizer_name: ${ref_actor.model_name}
  use_liger_kernel: ${actor.use_liger_kernel}
  ddp_size: ${actor.ddp_size}
  tp_size: ${actor.tp_size}
  sp_size: ${actor.sp_size}
  max_inference_length_per_device: ${actor.max_length_per_device}
  temperature: ${rollout.train_sampling_params.temperature}
  offload_model: ${actor.offload_model}

critic:
  model_name: ${actor.model_name}
  tokenizer_name: ${critic.model_name}
  gradient_checkpointing: ${actor.gradient_checkpointing}
  ddp_size: ${actor.ddp_size}
  tp_size: ${actor.tp_size}
  sp_size: ${actor.sp_size}
  max_length_per_device: ${actor.max_length_per_device}
  max_inference_length_per_device: ${critic.max_length_per_device}
  update_per_rollout: 12
  clip: 0.5
  agg_mode: ${actor.agg_mode}
  lr: 5e-6
  weight_decay: ${actor.weight_decay}
  max_grad_norm: ${actor.max_grad_norm}
  scheduler: ${actor.scheduler}
  warmup_ratio: ${actor.warmup_ratio}
  offload_model: ${actor.offload_model}
  offload_optimizer: ${actor.offload_optimizer}

adv:
  estimator: reinforce # `reinforce` or `gae`
  gamma: 1.0
  lamda: 1.0
  global_norm: false
  norm_var: false
  
trainer:
  project: null
  experiment_name: null
  load_ckpt_from: null
  n_epochs: 1
  test_freq: null
  save_dir: ckpts/${trainer.experiment_name}
  save_freq: null
  use_wandb: false
  