In [None]:
# Re-import edited Python modules automatically before each cell executes,
# so changes to the project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from dataclasses import asdict, dataclass

from areal.api.agent_args import AgentGRPOConfig, load_expr_config

# Parse the experiment config the same way the CLI entry point would.
# NOTE(review): the YAML path is relative to the kernel's working directory.
config: AgentGRPOConfig
args = [
    "--config",
    "AReaL/examples/lite/configs/sokoban_grpo.yaml",
]
config, _ = load_expr_config(args, AgentGRPOConfig)

In [None]:
# Display the fully resolved experiment config.
config

In [None]:
from areal.utils.network import find_free_ports

# Fixed ports: one for the local SGLang server, one for the torch.distributed rendezvous.
SGLANG_PORT, MASTER_PORT = 11451, 14514
SGLANG_HOST = "127.0.0.1"

# Environment read by the inference/train engines. This notebook runs as a
# single process, i.e. rank 0 of a world of size 1.
import os

os.environ.update(
    {
        "AREAL_LLM_SERVER_ADDRS": f"{SGLANG_HOST}:{SGLANG_PORT}",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(MASTER_PORT),
        "RANK": str(0),
        "WORLD_SIZE": str(1),
        "TOKENIZERS_PARALLELISM": "true",
        "LOCAL_RANK": str(0),
    }
)

In [None]:
import subprocess
import sys

# Launch the SGLang server
from areal.api.cli_args import SGLangConfig
from areal.utils.network import find_free_ports

# More verbose server logging. decode_log_interval presumably sets how often
# (in decode steps) SGLang logs decode stats — TODO confirm.
config.sglang.log_level = "info"
config.sglang.decode_log_interval = 10
# Build the server launch command. base_gpu_id=1 presumably puts the server on
# GPU 1, keeping GPU 0 (LOCAL_RANK=0) for the FSDP actor — so >= 2 GPUs needed.
sglang_cmd = SGLangConfig.build_cmd(
    config.sglang,
    tp_size=1,
    base_gpu_id=1,
    host=SGLANG_HOST,
    port=SGLANG_PORT,
)
# shell=True runs sglang_cmd through the shell (assumes build_cmd returns a
# single command string — verify). sglang_process.pid is therefore the shell's
# pid and the server runs as its descendant, so cleanup must kill the whole tree.
# Server output is forwarded to the notebook's stdout/stderr.
# NOTE(review): nothing here waits for the server to become ready.
sglang_process = subprocess.Popen(
    sglang_cmd,
    shell=True,
    stdout=sys.stdout,
    stderr=sys.stderr,
)

In [None]:
import asyncio
import functools
import os
import time
import uuid

import colorama
import torch
from tensordict import TensorDict
from transformers import AutoTokenizer, PreTrainedTokenizerFast

from areal.api.cli_args import GenerationHyperparameters
from areal.api.engine_api import InferenceEngine
from areal.api.io_struct import (
    AllocationMode,
    FinetuneSpec,
    LLMRequest,
    WeightUpdateMeta,
)
from areal.api.workflow_api import RolloutWorkflow
from areal.engine.ppo.actor import FSDPPPOActor
from areal.engine.sglang_remote import RemoteSGLangEngine
from areal.utils.data import concat_padded_tensors
from areal.utils.device import log_gpu_stats

# Tokenizer handed to the rollout workflows below.
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
# Some tokenizers define no pad token; reuse EOS so padding has a valid id.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
# Display the environment configs that the training dataset is built from.
config.envs

In [None]:
from torchdata.stateful_dataloader import StatefulDataLoader
from areal.dataset.multi_env_dataset import build_env_dataset
# Single-process run: this notebook is rank 0 of a world of size 1.
rank, world_size = 0, 1

train_dataset = build_env_dataset(
    config.envs,
    split="train",
    base_seed=config.seed,
    rank=rank,
    world_size=world_size,
)

# Each rank loads its share of the global batch. collate_fn is the identity,
# so a "batch" is simply the list of raw dataset items (no tensor collation).
dataloader = StatefulDataLoader(
    train_dataset,
    batch_size=config.train_dataset.batch_size // world_size,
    shuffle=config.train_dataset.shuffle,
    num_workers=config.train_dataset.num_workers,
    collate_fn=lambda items: items,
    drop_last=config.train_dataset.drop_last,
)


from itertools import cycle


def _cycle_dataloader(loader):
    """Yield batches from ``loader`` forever, re-iterating it on every pass.

    ``itertools.cycle`` caches the first pass and replays that copy, so every
    later epoch has the same batch order (defeating ``shuffle=True``) and the
    whole epoch stays in memory. Asking the loader for a fresh iterator each
    pass reshuffles per epoch and keeps nothing cached. Like ``itertools.cycle``,
    this stops instead of spinning forever if the loader yields nothing.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            return


# Endless stream of training batches used by the cells below.
data_generator = _cycle_dataloader(dataloader)

# Training-length spec passed to actor.initialize. dataset_size is approximated
# as (#batches * batch_size); with drop_last=False a partial final batch is
# counted as full.
ft_spec = FinetuneSpec(
    total_train_epochs=config.total_train_epochs,
    dataset_size=len(dataloader) * config.train_dataset.batch_size,
    train_batch_size=config.train_dataset.batch_size,
)

# Pull one batch to inspect its structure. Note this consumes a batch from
# data_generator.
x = next(data_generator)


In [None]:
# Inspect the first item of the sampled batch and its keys
# (items are presumably dicts, given the .keys() call).
print(x[0])
print(x[0].keys())

In [None]:
# NOTE(review): presumably turns on extra token-alignment debug checks/logging
# inside AReaL — confirm against the code that reads this variable.
os.environ["AREAL_DEBUG_TOKEN_ALIGN"] = "1"

In [None]:
# initialize inference engine
from areal.engine.sglang_remote import RemoteSGLangEngine
from areal.workflow.multi_turn_agent_env_workflow import MultiTurnAgentEnvWorkflow
# Smoke test: push a tiny batch through the multi-turn env workflow, then
# tear the engine down again.
rollout = RemoteSGLangEngine(config.rollout)
rollout.initialize(None, None)
try:
    # Presumably 3 samples per prompt, <= 512 new tokens per generation and
    # at most 3 turns; dump_dir presumably receives trajectory dumps — verify.
    workflow = MultiTurnAgentEnvWorkflow(
        gconfig=GenerationHyperparameters(n_samples=3,max_new_tokens=512),
        tokenizer=tokenizer,
        max_turns=3,
        dump_dir="./test"
    )
    # Only the first two items of the next batch, to keep the test fast.
    sample_data = next(data_generator)[:2]
    res = rollout.rollout_batch(sample_data, workflow=workflow)
    print(res)
finally:
    # Release the engine even if the rollout fails.
    rollout.destroy()

In [None]:
# Synchronous GRPO loop: each step collects a rollout batch, runs one PPO update
# on the FSDP actor, then pushes the new weights to the SGLang server over NCCL.
workflow = MultiTurnAgentEnvWorkflow(
    gconfig=GenerationHyperparameters(n_samples=3, max_new_tokens=512),
    tokenizer=tokenizer,
    max_turns=3,
    dump_dir="./test",
)
actor = FSDPPPOActor(config=config.actor)
actor.initialize(None, ft_spec)

rollout = RemoteSGLangEngine(config.rollout)
rollout.initialize(None, None)

weight_update_meta = WeightUpdateMeta.from_fsdp_nccl(
    AllocationMode.from_str("sglang.d1p1t1+d1p1t1"), actor
)

# The first `warmup_steps` iterations are excluded from the timing list.
warmup_steps = 1
times = []
for global_step in range(5):
    if global_step >= warmup_steps:
        tik = time.perf_counter()
    batch = rollout.rollout_batch(next(data_generator), workflow=workflow)
    batch = batch.to(actor.device)

    # Log-probs under the current actor weights, stored as "prox_logp"
    # (presumably the proximal-policy log-probs used by the PPO loss).
    logp = actor.compute_logp(batch)
    batch["prox_logp"] = logp

    # Return value is unused, so advantages are presumably written into `batch`.
    actor.compute_advantages(batch)

    stats = actor.ppo_update(batch)
    actor.step_lr_scheduler()

    # Pause generation while weights are transferred. The server-side receive
    # is started first (returns a future) and awaited only after the trainer
    # side uploads, presumably because both must join the same NCCL transfer.
    rollout.pause()
    try:
        future = rollout.update_weights(weight_update_meta)
        actor.upload_weights(weight_update_meta)
        future.result()
        torch.cuda.synchronize()
    finally:
        # Always resume, so a failed sync does not leave the engine paused.
        rollout.resume()

    actor.set_version(global_step + 1)
    rollout.set_version(global_step + 1)
    if global_step >= warmup_steps:
        times.append(time.perf_counter() - tik)
print(times)

In [None]:
# Same GRPO loop as the synchronous run, but rollouts come from
# rollout.prepare_batch, which reads from the dataloader directly.
workflow = MultiTurnAgentEnvWorkflow(
    gconfig=GenerationHyperparameters(n_samples=3, max_new_tokens=512),
    tokenizer=tokenizer,
    max_turns=3,
    dump_dir="./test",
)
actor = FSDPPPOActor(config=config.actor)
actor.initialize(None, ft_spec)

rollout = RemoteSGLangEngine(config.rollout)
rollout.initialize(None, None)

weight_update_meta = WeightUpdateMeta.from_fsdp_nccl(
    AllocationMode.from_str("sglang.d1p1t1+d1p1t1"), actor
)
# NOTE(review): a distinct NCCL group name, presumably so this weight-update
# group does not clash with one already set up on the same server by an
# earlier run — confirm.
weight_update_meta.nccl_group_name = "group2"

# The first `warmup_steps` iterations are excluded from the timing list.
warmup_steps = 1
times = []
for global_step in range(5):
    if global_step >= warmup_steps:
        tik = time.perf_counter()
    # Consumes from `dataloader` itself (not `data_generator`); presumably it
    # may keep rollouts in flight across steps — verify in RemoteSGLangEngine.
    batch = rollout.prepare_batch(dataloader, workflow=workflow)
    batch = batch.to(actor.device)

    # Log-probs under the current actor weights, stored as "prox_logp"
    # (presumably the proximal-policy log-probs used by the PPO loss).
    logp = actor.compute_logp(batch)
    batch["prox_logp"] = logp

    # Return value is unused, so advantages are presumably written into `batch`.
    actor.compute_advantages(batch)

    stats = actor.ppo_update(batch)
    actor.step_lr_scheduler()

    # Pause generation while weights are transferred; the server-side receive
    # future is awaited only after the trainer-side upload has been issued.
    rollout.pause()
    try:
        future = rollout.update_weights(weight_update_meta)
        actor.upload_weights(weight_update_meta)
        future.result()
        torch.cuda.synchronize()
    finally:
        # Always resume, so a failed sync does not leave the engine paused.
        rollout.resume()

    actor.set_version(global_step + 1)
    rollout.set_version(global_step + 1)
    if global_step >= warmup_steps:
        times.append(time.perf_counter() - tik)
print(times)

In [None]:
import signal as signal_module

import psutil


def terminate_process_and_children(pid: int, signal=None):
    """Send ``signal`` to process ``pid`` and to every one of its descendants.

    Args:
        pid: Root process id of the tree to signal.
        signal: A signal number/``signal.Signals`` member, or its name as a
            string (e.g. ``"SIGTERM"``). Defaults to ``SIGKILL``.

    Descendants are signalled before the root, each exactly once and with the
    same ``signal``. Processes that have already exited are skipped silently.
    """
    if signal is None:
        signal = signal_module.SIGKILL
    if isinstance(signal, str):
        signal = getattr(signal_module, signal)
    try:
        parent = psutil.Process(pid)
        # Snapshot the whole tree up front so grandchildren are still reached
        # even if an intermediate process exits and they get reparented.
        descendants = parent.children(recursive=True)
    except psutil.NoSuchProcess:
        return
    # Reverse the snapshot so deeper processes tend to be signalled before
    # their ancestors.
    for child in reversed(descendants):
        try:
            child.send_signal(signal)
        except psutil.NoSuchProcess:
            pass  # already gone; keep signalling the rest
    try:
        parent.send_signal(signal)
    except psutil.NoSuchProcess:
        pass


# Kill the SGLang server. The Popen used shell=True, so sglang_process.pid is
# the shell's pid and the server runs beneath it; the whole tree is signalled
# (SIGKILL by default).
terminate_process_and_children(sglang_process.pid)

In [None]:
# Tear down the most recently created rollout engine client.
rollout.destroy()

In [None]:
# Release the most recently created FSDP actor.
actor.destroy()