In [None]:
%load_ext autoreload
%autoreload 2

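# 1. Change how the model is loaded: use base + LoRA form and train only the LoRA weights.
# 2. The dataset must return pixel values in addition to prompt_ids.
# 3. When computing logits in forward, the action ids must also be converted to embeddings.
# 4. Decide the stream length of a trajectory.
# 5. First take a batch of winner trajectories, then run a batched forward pass to obtain the loser trajectories.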

In [None]:
from trl import DPOConfig, DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor
from peft import PeftModel, PeftConfig
from src.data_process import TrajectoryDataset
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
import sys
sys.path.append("../..")
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union
import draccus
import numpy as np
import tqdm
from experiments.robot.robot_utils import get_model
from libero.libero import benchmark
import wandb
from experiments.robot.libero.libero_utils import (
    get_libero_dummy_action,
    get_libero_env,
    get_libero_image,
    quat2axisangle,
    save_rollout_video_CoA,
)
from experiments.robot.openvla_utils import get_processor, get_input
from experiments.robot.robot_utils import (
    DATE_TIME,
    get_action,
    get_CoA,
    get_image_resize_size,
    get_model,
    get_vla_via_lora,
    invert_gripper_action,
    normalize_gripper_action,
    set_seed_everywhere,
)

from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union
import torch
from torch.utils.data import Dataset, IterableDataset
from transformers import PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin
from trl import DPOConfig, DPOTrainer
from typing import Union

In [None]:
@dataclass
class GenerateConfig:
    """Settings shared by model loading, the LIBERO environment and DPO training.

    Caveat: every path default built with `os.path.join(root_dir, ...)` is
    evaluated once, while the class body executes, from the class-level
    `root_dir` default. Passing `root_dir=...` to the constructor therefore does
    NOT re-derive those paths; override each path explicitly when relocating.
    """

    # fmt: off
    vla_path: str = "openvla/openvla-7b" 
    # root_dir: str = "/hdd/zijianwang"
    root_dir: str = "/mnt/sda/home/zijianwang"


    #################################################################################################################
    # LoRA parameters
    #################################################################################################################
    use_lora: bool = True
    lora_rank: int = 48
    lora_dropout: float = 0.0
    
    #################################################################################################################
    # Model-specific parameters
    #################################################################################################################
    model_family: str = "openvla"                    # Model family

    dataset_name: str = "libero_10_no_noops"

    pretrained_checkpoint: Union[str, Path] = os.path.join(root_dir, "openvla/FT_res/openvla-7b-finetuned-libero-10+libero_10_no_noops+b4+lr-0.0005+lora-r48+dropout-0.0--image_aug--2025-07-18_19-26-25")     # Pretrained checkpoint path
    lora_path: str = os.path.join(root_dir, "openvla/adapter_tmp_dir/openvla-7b-finetuned-libero-10+libero_10_no_noops+b4+lr-0.0005+lora-r48+dropout-0.0--image_aug--2025-07-18_19-26-25")
    base_vla_path: str = os.path.join(root_dir, "HF_CACHE/openvla-7b-finetuned-libero-10")

    winner_trajectory_path: str = os.path.join(root_dir, "openvla/vla-scripts/DPO/winner_trajectory")

    adapter_tmp_dir: str = os.path.join(root_dir, "openvla/adapter_tmp_dir")
    run_root_dir: str = os.path.join(root_dir, "openvla/DPO_res")

    #################################################################################################################
    load_in_8bit: bool = False                       # (For OpenVLA only) Load with 8-bit quantization
    load_in_4bit: bool = False                       # (For OpenVLA only) Load with 4-bit quantization
    center_crop: bool = True                         # Center crop? (if trained w/ random crop image aug)
    #################################################################################################################
    # Training parameters
    #################################################################################################################
    batch_size: int = 4
    grad_accumulation_steps: int = 1
    learning_rate: float = 0.0005
    max_steps: int = 10000
    dpo_beta: float = 0.1
    #################################################################################################################
    # LIBERO environment-specific parameters
    #################################################################################################################
    task_suite_name: str = "libero_10"          # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
    num_steps_wait: int = 10                         # Number of steps to wait for objects to stabilize in sim
    num_trials_per_task: int = 50                    # Number of rollouts per task
    #################################################################################################################
    # Utils
    #################################################################################################################
    run_id_note: Optional[str] = None                # Extra note to add in run ID for logging
    local_log_dir: str = "./experiments/logs"        # Local directory for eval logs

    use_wandb: bool = False                          # Whether to also log results in Weights & Biases
    # The effective defaults are declared once here; previously placeholder values were
    # declared and then silently overridden by bare assignments further down the class body.
    wandb_project: str = "openvla_CoA_DPO"           # Name of W&B project to log to
    wandb_entity: str = "15652388600"                # Name of entity to log under

    seed: int = 7                                    # Random Seed (for reproducibility)

    device: str = "cuda:2"

    # Key used to look up action un-normalization statistics. Annotated so that it is a
    # real dataclass field (settable via the constructor, shown in repr); declared last
    # so the positional order of the pre-existing fields is unchanged.
    unnorm_key: str = task_suite_name

    # fmt: on

In [None]:
def setup_model_and_config(cfg: GenerateConfig):
    """Validate the configuration, seed the RNGs, and load the model.

    Args:
        cfg: Run configuration. `cfg.unnorm_key` is overwritten with
            `cfg.task_suite_name` as a side effect.

    Returns:
        The object returned by `get_model(cfg)`.

    Raises:
        AssertionError: If no checkpoint is configured, if an `image_aug`
            checkpoint is used without `center_crop`, or if both 8-bit and
            4-bit quantization are requested.
    """
    assert cfg.pretrained_checkpoint is not None, "cfg.pretrained_checkpoint must not be None!"
    # `pretrained_checkpoint` is typed `Union[str, Path]`; a substring test against a
    # `Path` raises TypeError, so the check is done on its string form.
    if "image_aug" in str(cfg.pretrained_checkpoint):
        assert cfg.center_crop, "Expecting `center_crop==True` because model was trained with image augmentations!"
    assert not (cfg.load_in_8bit and cfg.load_in_4bit), "Cannot use both 8-bit and 4-bit quantization!"

    # Set random seed
    set_seed_everywhere(cfg.seed)

    # The un-normalization key defaults to the task suite name; callers may refine it later.
    cfg.unnorm_key = cfg.task_suite_name

    # Load model
    model = get_model(cfg)

    return model

def setup_logging_and_environment(cfg: GenerateConfig, model):
    """Prepare the processor, the local log file and the LIBERO task suite.

    Args:
        cfg: Run configuration. For the "openvla" family `cfg.unnorm_key` may be
            rewritten to its `"_no_noops"` variant when only that key exists in
            `model.norm_stats`.
        model: Loaded model; only `model.norm_stats` is read here.

    Returns:
        Tuple `(processor, log_file, task_suite, num_tasks_in_suite, resize_size)`.
        `processor` is None for non-"openvla" families. `log_file` is an open,
        writable text file that the caller is responsible for closing.

    Raises:
        AssertionError: If the un-normalization key is missing from `model.norm_stats`.
    """
    # [OpenVLA] Check that the model contains the action un-normalization key
    if cfg.model_family == "openvla":
        # In some cases, the key must be manually modified (e.g. after training on a modified version of the dataset
        # with the suffix "_no_noops" in the dataset name)
        if cfg.unnorm_key not in model.norm_stats and f"{cfg.unnorm_key}_no_noops" in model.norm_stats:
            cfg.unnorm_key = f"{cfg.unnorm_key}_no_noops"
        assert cfg.unnorm_key in model.norm_stats, f"Action un-norm key {cfg.unnorm_key} not found in VLA `norm_stats`!"

    # [OpenVLA] Get Hugging Face processor
    processor = None
    if cfg.model_family == "openvla":
        processor = get_processor(cfg)

    # Initialize local logging
    run_id = f"EVAL-{cfg.task_suite_name}-{cfg.model_family}-{DATE_TIME}"
    if cfg.run_id_note is not None:
        run_id += f"--{cfg.run_id_note}"
    os.makedirs(cfg.local_log_dir, exist_ok=True)
    local_log_filepath = os.path.join(cfg.local_log_dir, run_id + ".txt")
    log_file = open(local_log_filepath, "w")
    try:
        print(f"Logging to local log file: {local_log_filepath}")

        # Initialize Weights & Biases logging as well
        if cfg.use_wandb:
            wandb.init(
                entity=cfg.wandb_entity,
                project=cfg.wandb_project,
                name=run_id,
            )

        # Initialize LIBERO task suite
        benchmark_dict = benchmark.get_benchmark_dict()
        task_suite = benchmark_dict[cfg.task_suite_name]()
        num_tasks_in_suite = task_suite.n_tasks
        print(f"Task suite: {cfg.task_suite_name}")
        log_file.write(f"Task suite: {cfg.task_suite_name}\n")

        # Get expected image dimensions
        resize_size = get_image_resize_size(cfg)
    except BaseException:
        # On failure the caller never receives the handle, so it must be closed here
        # to avoid leaking the open file.
        log_file.close()
        raise

    return processor, log_file, task_suite, num_tasks_in_suite, resize_size

In [None]:
"""Main function to run the OpenVLA LIBERO inference demo."""
print("[*] Starting OpenVLA LIBERO Inference Demo")

# Initialize configuration
# Only `device` is overridden; every other field keeps its GenerateConfig default.
model_cfg = GenerateConfig(device = "cuda:3")

# Setup model and configuration
print("[*] Loading model and setting up configuration...")
# Previous loading path, kept for reference:
# model = setup_model_and_config(cfg)
# NOTE(review): because setup_model_and_config is bypassed, its config assertions,
# its set_seed_everywhere(cfg.seed) call and its unnorm_key assignment do not run
# in this notebook flow — confirm that get_vla_via_lora covers what is needed.

# Policy model used for training. Presumably built from the base checkpoint plus
# a LoRA adapter (the helper lives in experiments.robot.robot_utils) — TODO confirm.
model = get_vla_via_lora(model_cfg)

In [None]:
# Setup logging and environment
# Side effects: opens a log file under model_cfg.local_log_dir (returned as `log_file`,
# never closed in this notebook) and may rewrite model_cfg.unnorm_key to its
# "_no_noops" variant. `processor` and `resize_size` are consumed by later cells.
print("[*] Setting up logging and environment...")
processor, log_file, task_suite, num_tasks_in_suite, resize_size = setup_logging_and_environment(model_cfg, model)

In [None]:
# Reference model for DPO: loaded through get_model with a separate config whose
# device is "cuda:0", so it does not share a device string with the policy config.
# train_dpo later puts this model in eval mode and freezes its parameters.
ref_config = GenerateConfig(device = "cuda:0")
print(ref_config.device)
ref_model = get_model(ref_config)

In [None]:
# Create dataset instance
# Built from the winner trajectories on disk; the policy `model` is passed in as well
# (presumably so the dataset can generate the rejected/loser side — TODO confirm in
# src.data_process.TrajectoryDataset).
dataset = TrajectoryDataset(model_cfg, model_cfg.winner_trajectory_path, model_cfg.task_suite_name, processor, device = model_cfg.device, model = model, img_size = resize_size, stream_length = 3)

# The dataset only returns e.g. "prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6].
# The attention masks are produced by DataCollatorForPreference; labels are produced in DPOTrainer.concatenated_forward.

In [None]:
from trl.trainer.dpo_trainer import DataCollatorForPreference
from torch.utils.data import DataLoader

# Reuse the `model_cfg` created when the policy model was loaded. Re-instantiating
# GenerateConfig() here would silently reset `device` to its default and drop any
# `unnorm_key` correction applied by setup_logging_and_environment.
data_collator = DataCollatorForPreference(pad_token_id = processor.tokenizer.pad_token_id)
train_dataloader = DataLoader(
    dataset,
    batch_size=model_cfg.batch_size,
    shuffle=True,
    collate_fn=data_collator,  # must return batches in the prompt/chosen/rejected format described above
)

In [None]:
# Pull a single batch for inspection. `batch` stays bound at notebook scope after the
# `break` and is reused by the cells below.
for batch in train_dataloader:
    print(batch['chosen_input_ids'])
    break

In [None]:
# Inspect the collated batch: the available keys and the shape of the chosen-completion ids.
print(batch.keys())
print(batch['chosen_input_ids'].shape)

In [None]:
import torch

# Smoke test: one gradient-free forward pass over the prompt tokens of the inspected batch.
# The attention mask must belong to the same tensor as `input_ids`; pairing the prompt ids
# with the *chosen* mask mixes masks of different sequences (and generally different lengths).
# Inputs are moved to the model's device first (a no-op when they are already there).
# NOTE(review): pixel_values is passed in whatever dtype the collator produced — confirm it
# matches the dtype the model expects.
model_device = next(model.parameters()).device
with torch.no_grad():
    pred_test = model(
        input_ids=batch["prompt_input_ids"].to(model_device),
        attention_mask=batch["prompt_attention_mask"].to(model_device),
        pixel_values=batch["pixel_values"].to(model_device),
    )

print(pred_test.keys())

In [None]:
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm

def compute_log_probs(model, prompt_input_ids, completion_input_ids, prompt_attention_mask, completion_attention_mask, model_device=None):
    """Sum the log-probabilities the model assigns to each completion token.

    The prompt and completion are concatenated along dim 1 and run through the
    model once; the logits at positions `prompt_len - 1 .. -2` (the ones that
    predict the completion tokens) are scored against `completion_input_ids`.

    Gradient handling: a model in eval mode is always run under
    `torch.no_grad()`. A model in training mode is run under the caller's
    ambient grad mode, so wrapping the call in `torch.no_grad()` is respected.

    NOTE(review): only `input_ids` / `attention_mask` are forwarded; no
    `pixel_values` reach the model from here — confirm that text-only scoring
    is intended.

    Args:
        model: Callable module returning an object with a `.logits` attribute
            indexed as (batch, sequence, vocab).
        prompt_input_ids: Integer tensor (batch, prompt_len), prompt_len >= 1.
        completion_input_ids: Integer tensor (batch, completion_len).
        prompt_attention_mask: Mask matching `prompt_input_ids`.
        completion_attention_mask: Mask matching `completion_input_ids`;
            positions with 0 do not contribute to the sum.
        model_device: Device to move the inputs to. Defaults to the device of
            the model's first parameter.

    Returns:
        Tensor of shape (batch,) with the summed completion log-probability.

    Raises:
        ValueError: If the prompt is empty, since no logit would then predict
            the first completion token.
    """
    # Make sure the inputs live on the model's device.
    if model_device is None:
        model_device = next(model.parameters()).device

    prompt_input_ids = prompt_input_ids.to(model_device)
    completion_input_ids = completion_input_ids.to(model_device)
    prompt_attention_mask = prompt_attention_mask.to(model_device)
    completion_attention_mask = completion_attention_mask.to(model_device)

    prompt_len = prompt_input_ids.shape[1]
    if prompt_len == 0:
        raise ValueError("compute_log_probs requires a non-empty prompt to score the first completion token")

    # Concatenate prompt and completion.
    input_ids = torch.cat([prompt_input_ids, completion_input_ids], dim=1)
    attention_mask = torch.cat([prompt_attention_mask, completion_attention_mask], dim=1)

    # Forward pass. Never force-enable gradients: doing so would override a
    # surrounding `torch.no_grad()` and build an autograd graph the caller did not ask for.
    if model.training:
        logits = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits
    else:
        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits

    # The logit at position t predicts token t + 1, hence the shift by one.
    completion_logits = logits[:, prompt_len-1:-1, :]
    completion_labels = completion_input_ids

    # Per-token log-probability of the label that actually occurred.
    log_probs = F.log_softmax(completion_logits, dim=-1)
    selected_log_probs = torch.gather(log_probs, dim=-1, index=completion_labels.unsqueeze(-1)).squeeze(-1)
    # Sum over the non-padding tokens only.
    sequence_log_probs = (selected_log_probs * completion_attention_mask).sum(dim=-1)

    return sequence_log_probs

def dpo_loss(policy_chosen_logps, policy_rejected_logps, 
             ref_chosen_logps, ref_rejected_logps, beta=0.1, target_device=None):
    """Compute the DPO objective from sequence log-probabilities.

    All four inputs are first gathered on one device (`target_device`, or the
    device of `policy_chosen_logps` when it is None), so the policy and the
    reference model may live on different devices.

    Returns:
        A 3-tuple `(loss, chosen_rewards, rejected_rewards)`: the mean of
        `-logsigmoid(beta * (chosen_logratio - rejected_logratio))`, and the
        detached per-example rewards `beta * logratio` for monitoring.
    """
    device = policy_chosen_logps.device if target_device is None else target_device

    # Gather every operand on the same device before doing arithmetic.
    pol_chosen, pol_rejected, ref_chosen, ref_rejected = (
        logps.to(device)
        for logps in (policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps)
    )

    # Log-ratio of policy to reference for each side of the preference pair.
    chosen_logratios = pol_chosen - ref_chosen
    rejected_logratios = pol_rejected - ref_rejected

    # Per-example loss: -log(sigmoid(beta * (chosen_logratios - rejected_logratios)))
    per_example_loss = -F.logsigmoid(beta * (chosen_logratios - rejected_logratios))

    # Rewards are reported for monitoring only, hence detached from the graph.
    return (
        per_example_loss.mean(),
        beta * chosen_logratios.detach(),
        beta * rejected_logratios.detach(),
    )

def move_batch_to_device(batch, device):
    """Return a new dict in which every tensor value of `batch` is moved to `device`.

    Non-tensor values are carried over unchanged; `batch` itself is not modified.
    """
    return {
        name: (item.to(device) if isinstance(item, torch.Tensor) else item)
        for name, item in batch.items()
    }

# Training loop
def train_dpo(model, ref_model, train_dataloader, cfg, if_not_demo = False):
    """Fine-tune `model` with the DPO objective against a frozen `ref_model`.

    The two models may live on different devices: each batch is copied to the
    device of the model that consumes it, and the four log-probability tensors
    are gathered on the policy device before the loss is computed.

    Training makes a single pass over `train_dataloader` and stops early once
    `cfg.max_steps` optimizer updates have been applied. Gradients are
    accumulated over `cfg.grad_accumulation_steps` micro-batches per update;
    micro-batches left over at the end of the dataloader are discarded.

    NOTE(review): only the token ids and attention masks of each batch are
    passed to `compute_log_probs`; any `pixel_values` in the batch are unused
    here — confirm that this is intended.

    Args:
        model: Policy being trained; must provide `parameters()`, `train()`
            and `save_pretrained()`.
        ref_model: Reference policy. It is put in eval mode and all of its
            parameters are frozen in place.
        train_dataloader: Iterable of dict batches with the keys
            `prompt_input_ids`, `prompt_attention_mask`, `chosen_input_ids`,
            `chosen_attention_mask`, `rejected_input_ids` and
            `rejected_attention_mask`.
        cfg: Config object; reads `vla_path`, `dataset_name`, `batch_size`,
            `grad_accumulation_steps`, `learning_rate`, `use_lora`, `lora_rank`,
            `lora_dropout`, `run_id_note`, `adapter_tmp_dir`, `max_steps`,
            `dpo_beta` and, when logging, `wandb_entity` / `wandb_project`.
        if_not_demo: When True, a Weights & Biases run is started and metrics
            are logged once per optimizer update.

    Raises:
        ValueError: If `cfg.grad_accumulation_steps` is smaller than 1.
    """
    import datetime
    from collections import deque

    accum_steps = cfg.grad_accumulation_steps
    if accum_steps < 1:
        raise ValueError(f"grad_accumulation_steps must be >= 1, got {accum_steps}")

    # Configure Unique Experiment ID & Log Directory
    exp_id = (
        f"{cfg.vla_path.split('/')[-1]}+{cfg.dataset_name}"
        f"+b{cfg.batch_size * cfg.grad_accumulation_steps}"
        f"+lr-{cfg.learning_rate}"
    )
    if cfg.use_lora:
        exp_id += f"+lora-r{cfg.lora_rank}+dropout-{cfg.lora_dropout}"
    exp_id += f"--{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    if cfg.run_id_note is not None:
        exp_id += f"--{cfg.run_id_note}"

    adapter_dir = os.path.join(cfg.adapter_tmp_dir, exp_id)
    if if_not_demo:
        wandb.init(entity=cfg.wandb_entity, project=cfg.wandb_project, name=f"dpo+{exp_id}")

    # Devices the two models live on.
    policy_device = next(model.parameters()).device
    ref_device = next(ref_model.parameters()).device

    print(f"Policy model device: {policy_device}")
    print(f"Reference model device: {ref_device}")

    # Optimize only the parameters that can actually receive gradients
    # (e.g. the adapter weights when the base model is frozen).
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = AdamW(trainable_params, lr=cfg.learning_rate)

    # The reference model must never be updated.
    ref_model.eval()
    for param in ref_model.parameters():
        param.requires_grad = False

    # Deque to store recent train metrics (used for computing smoothened metrics for gradient accumulation)
    recent_losses = deque(maxlen=accum_steps)
    recent_accuracies = deque(maxlen=accum_steps)
    recent_rewards_margin = deque(maxlen=accum_steps)
    recent_chosen_logps = deque(maxlen=accum_steps)
    recent_rejected_logps = deque(maxlen=accum_steps)
    recent_chosen_rewards = deque(maxlen=accum_steps)
    recent_rejected_rewards = deque(maxlen=accum_steps)

    completed_steps = 0  # number of optimizer updates applied so far

    with tqdm(total = cfg.max_steps, leave=False) as progress:
        model.train()
        optimizer.zero_grad()

        for batch_idx, batch in enumerate(train_dataloader):
            # Copy the batch to the policy model's device.
            policy_batch = move_batch_to_device(batch, policy_device)

            # 1. Policy log-probabilities (on policy_device, with gradients).
            policy_chosen_logps = compute_log_probs(
                model,
                policy_batch['prompt_input_ids'],
                policy_batch['chosen_input_ids'],
                policy_batch['prompt_attention_mask'],
                policy_batch['chosen_attention_mask'],
                model_device = policy_device
            )
            policy_rejected_logps = compute_log_probs(
                model,
                policy_batch['prompt_input_ids'],
                policy_batch['rejected_input_ids'],
                policy_batch['prompt_attention_mask'],
                policy_batch['rejected_attention_mask'],
                model_device = policy_device
            )

            # 2. Reference log-probabilities (on ref_device, no gradients).
            with torch.no_grad():
                ref_batch = move_batch_to_device(batch, ref_device)

                ref_chosen_logps = compute_log_probs(
                    ref_model,
                    ref_batch['prompt_input_ids'],
                    ref_batch['chosen_input_ids'],
                    ref_batch['prompt_attention_mask'],
                    ref_batch['chosen_attention_mask'],
                    model_device = ref_device
                )
                ref_rejected_logps = compute_log_probs(
                    ref_model,
                    ref_batch['prompt_input_ids'],
                    ref_batch['rejected_input_ids'],
                    ref_batch['prompt_attention_mask'],
                    ref_batch['rejected_attention_mask'],
                    model_device = ref_device
                )

            # 3. DPO loss (on policy_device).
            loss, chosen_rewards, rejected_rewards = dpo_loss(
                policy_chosen_logps, policy_rejected_logps,
                ref_chosen_logps, ref_rejected_logps,
                beta = cfg.dpo_beta, target_device = policy_device
            )

            # 4. Backward pass. The loss is scaled so the accumulated gradient is the
            #    mean over the micro-batches of one optimizer update.
            (loss / accum_steps).backward()

            # 5. Per-micro-batch metrics.
            accuracy = (chosen_rewards > rejected_rewards).float().mean()
            reward_margin = (chosen_rewards - rejected_rewards).mean()

            recent_losses.append(loss.item())
            recent_accuracies.append(accuracy.item())
            recent_rewards_margin.append(reward_margin.item())
            recent_chosen_logps.append(policy_chosen_logps.mean().item())
            recent_rejected_logps.append(policy_rejected_logps.mean().item())
            recent_chosen_rewards.append(chosen_rewards.mean().item())
            recent_rejected_rewards.append(rejected_rewards.mean().item())

            # Everything below runs once per optimizer update, i.e. only on the last
            # micro-batch of each accumulation window.
            if (batch_idx + 1) % accum_steps != 0:
                continue

            # 6. Optimizer update (with gradient clipping).
            torch.nn.utils.clip_grad_norm_(trainable_params, max_norm = 1.0)
            optimizer.step()
            optimizer.zero_grad()
            progress.update()

            gradient_step_idx = batch_idx // accum_steps
            completed_steps = gradient_step_idx + 1

            # Compute smoothened train metrics
            #   =>> Equal to current step metrics when not using gradient accumulation
            #   =>> Otherwise, equal to the average of metrics observed over micro-batches used for gradient accumulation
            smoothened_loss = sum(recent_losses) / len(recent_losses)
            smoothened_accuracy = sum(recent_accuracies) / len(recent_accuracies)
            smoothened_rewards_margin = sum(recent_rewards_margin) / len(recent_rewards_margin)
            smoothened_chosen_logps = sum(recent_chosen_logps) / len(recent_chosen_logps)
            smoothened_rejected_logps = sum(recent_rejected_logps) / len(recent_rejected_logps)
            smoothened_chosen_rewards = sum(recent_chosen_rewards) / len(recent_chosen_rewards)
            smoothened_rejected_rewards = sum(recent_rejected_rewards) / len(recent_rejected_rewards)

            if if_not_demo:
                wandb.log({
                    "loss": smoothened_loss,
                    "accuracy": smoothened_accuracy,
                    "reward_margin": smoothened_rewards_margin,
                    "chosen_logps": smoothened_chosen_logps,
                    "rejected_logps": smoothened_rejected_logps,
                    "chosen_rewards": smoothened_chosen_rewards,
                    "rejected_rewards": smoothened_rejected_rewards
                },
                step=gradient_step_idx,
                )

            # Periodic checkpoint plus CUDA cache cleanup.
            if gradient_step_idx % 10 == 0:
                model.save_pretrained(adapter_dir)
                print(f"Saved adapter to {adapter_dir}, batch_idx: {batch_idx}")

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            # Stop once the configured number of optimizer updates has been applied.
            if completed_steps >= cfg.max_steps:
                break

    # Persist the final weights unless the last update was already checkpointed above.
    if completed_steps > 0 and (completed_steps - 1) % 10 != 0:
        model.save_pretrained(adapter_dir)
        print(f"Saved final adapter to {adapter_dir}, gradient steps: {completed_steps}")
    


In [None]:
# Usage example, covering the case where the two models sit on different devices.
# `model`, `ref_model` and `train_dataloader` come from the earlier notebook cells.
if __name__ == "__main__":
    # Example: keep ref_model on the CPU to save GPU memory.
    
    # Or, on a multi-GPU machine, place the models on different GPUs:
    # model = model.to('cuda:0')      # policy model on GPU 0
    # ref_model = ref_model.to('cuda:1')  # reference model on GPU 1

    # A fresh config for training; fields not listed here fall back to the GenerateConfig defaults.
    model_cfg = GenerateConfig(max_steps = 10000, grad_accumulation_steps = 1, dpo_beta = 0.1)
    # Start training (if_not_demo = False: no Weights & Biases logging).
    train_dpo(model, ref_model, train_dataloader, model_cfg, if_not_demo = False)

# ---------------