In [2]:
from models.action_decoder import FlowMatchingActionExpert

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torch
import numpy as np

# ============================================================
# 0. Configuration: dummy sizes (mirroring the QwenVLAUnified V2 setup)
# ============================================================
# Batch / action layout. HORIZON corresponds to QwenVLAUnified.horizon and
# each action step is (dx, dy, dz, dRx, dRy, dRz, gripper).
BATCH_SIZE, HORIZON, ACTION_DIM = 2, 8, 7

# Number of dummy vision tokens -- arbitrary, unrelated to the real VLM token
# count. No sensor-token length (SS) is defined: in the real pipeline the
# sensor and robot state arrive merged into a single vector.
SV = 16

# Feature widths. 2048 assumes a Qwen2.5-VL-3B hidden size of ~2048
# (vl_hidden_size, matching the code defaults), which feeds both
# image_feature_dim and text_guidance_dim. The sensor width is the combined
# sensor_output_dim(1024) + robot_state_output_dim(1024) (combined_sensor_dim).
IMAGE_FEAT_DIM = TEXT_GUIDE_DIM = SENSOR_FEAT_DIM = 2048

HIDDEN_DIM = 1024  # hidden_dim passed to QwenVLAUnified

device = "cpu"  # switch to "cuda" if needed


# ============================================================
# 1. Build the model instance (FlowMatchingActionExpert V2)
#    Settings are chosen to match the ones QwenVLAUnified uses.
# ============================================================
# Seed BEFORE construction: standard PyTorch layers draw their initial
# weights from the global RNG, so seeding only later (as for the dummy inputs
# below) would leave the weights -- and hence the printed loss and sampled
# actions -- different on every run.
torch.manual_seed(0)

flow_model = FlowMatchingActionExpert(
    image_feature_dim = IMAGE_FEAT_DIM,
    text_guidance_dim = TEXT_GUIDE_DIM,
    sensor_dim        = SENSOR_FEAT_DIM,
    action_dim        = ACTION_DIM,
    horizon           = HORIZON,
    hidden_dim        = HIDDEN_DIM,
    nhead             = 8,
    num_decoder_layers= 4,    # Unified uses 4 layers by default
    time_embed_dim    = 256,
    dropout           = 0.1,
    sigma_min         = 1e-4,
).to(device)

# Include the sensor token in the decoder memory (Unified V2 default: True)
flow_model.use_sensor_tokens = True


# ============================================================
# 2. Dummy inputs
#    - actions: GT action sequence (for the flow loss) -> (B,H,7)
#    - context_features: VLM image tokens -> (B,Sv,2048)
#    - guidance_vector: text guidance -> (B,2048)
#    - sensor_features: merged sensor+robot vector -> (B,2048)
# ============================================================
# Re-seed so the dummy inputs do not depend on how many random draws model
# construction consumed (keeps them independent of the model architecture).
torch.manual_seed(0)

actions = torch.randn(BATCH_SIZE, HORIZON, ACTION_DIM, device=device)
context_features = torch.randn(BATCH_SIZE, SV, IMAGE_FEAT_DIM, device=device)
guidance_vector  = torch.randn(BATCH_SIZE, TEXT_GUIDE_DIM, device=device)

# A 2-D (B, SENSOR_FEAT_DIM) vector, as in the real Unified pipeline
sensor_features  = torch.randn(BATCH_SIZE, SENSOR_FEAT_DIM, device=device)

print("=== Dummy Input Shapes ===")
print(f"actions          : {tuple(actions.shape)}   (B,H,A)")
print(f"context_features : {tuple(context_features.shape)}   (B,Sv,D_img)")
print(f"guidance_vector  : {tuple(guidance_vector.shape)}   (B,D_text)")
print(f"sensor_features  : {tuple(sensor_features.shape)}   (B,D_sensor)")
print()


# ============================================================
# 3. Shape-debugging forward pass for FlowMatchingActionExpert
# ============================================================
def debug_flow_forward(
    model: FlowMatchingActionExpert,
    x_t: torch.Tensor,
    t: torch.Tensor,
    context_features: torch.Tensor,
    guidance_vector: torch.Tensor,
    sensor_features: torch.Tensor,
):
    """Re-run the action expert's forward pass by hand, printing every shape.

    The steps are meant to follow ``FlowMatchingActionExpert.forward`` almost
    exactly. The same submodules are called in the same order, and a shape
    printout follows each stage.

    Args:
        model: Action expert whose submodules are invoked directly.
        x_t: Bridge state; must be 3-D (labelled (B, H, A) in the printouts).
        t: Flow time (labelled (B,) with values in [0, 1] in the printouts).
        context_features: Vision tokens; must be 3-D (B, Sv, D_img).
        guidance_vector: Text guidance (labelled (B, D_text)).
        sensor_features: Sensor features, either 2-D (B, D) -- expanded to a
            single token -- or already tokenized (B, Ss, D). Must not be None
            when ``model.use_sensor_tokens`` is set.

    Returns:
        ``model.output_head`` applied to the decoder output (the predicted
        velocity).
    """
    rule = "--------------------------------------"

    print("======== [DEBUG FLOW FORWARD] ========")
    print(f"x_t               : {tuple(x_t.shape)}   (B,H,A)")
    print(f"t                 : {tuple(t.shape)}     (B,)  scalar in [0,1]")
    print(f"context_features  : {tuple(context_features.shape)}   (B,Sv,D_img)")
    print(f"guidance_vector   : {tuple(guidance_vector.shape)}    (B,D_text)")
    print(f"sensor_features   : {tuple(sensor_features.shape)}    (B,D_sensor)")
    print(rule)

    # Unpacking also enforces that x_t is 3-D.
    _, horizon, _ = x_t.shape

    # --- Stage 1: conditioning vector = text guidance + time embedding ---
    t_sin = model.sinusoidal_time_embedding(t)
    print(f"time_raw (sinusoidal) : {tuple(t_sin.shape)}   (B,time_embed_dim={model.time_embed_dim})")

    t_vec = model.time_norm(model.time_mlp(t_sin))
    print(f"t_embed (after MLP+LN): {tuple(t_vec.shape)}   (B,hidden_dim={model.hidden_dim})")

    g_vec = model.guidance_norm(model.text_guidance_proj(guidance_vector))
    print(f"guidance_embed        : {tuple(g_vec.shape)}   (B,hidden_dim)")

    cond = g_vec + t_vec
    print(f"cond_embed (guide+time): {tuple(cond.shape)}   (B,hidden_dim)")
    print(rule)

    # --- Stage 2: decoder query tokens + learned positions ---
    queries = model.action_embed(x_t)
    print(f"tgt (after action_embed): {tuple(queries.shape)}   (B,H,hidden_dim)")

    queries = queries + model.tgt_pos[:, :horizon]  # (1,H,D) broadcast over batch
    print(f"tgt (after tgt_pos)      : {tuple(queries.shape)}   (B,H,hidden_dim)")

    # --- Stage 3: memory = vision tokens (+ optional sensor token) ---
    assert context_features.dim() == 3, "context_features는 (B,Sv,D_img) 형태여야 함"
    vis = model.context_proj(context_features)
    print(f"vision_mem (proj)       : {tuple(vis.shape)}   (B,Sv,hidden_dim)")

    # Token-type id 0 marks vision tokens.
    vis = vis + model.token_type_embed.weight[0].view(1, 1, -1)
    print(f"vision_mem (+type 0)    : {tuple(vis.shape)}")

    if not model.use_sensor_tokens:
        vis = vis + model.mem_pos[:, :vis.size(1)]
        memory = vis
    else:
        assert sensor_features is not None, "use_sensor_tokens=True이면 sensor_features 필요"

        # A 2-D (B, D) sensor vector becomes a single (B, 1, D) token.
        sens = model.sensor_proj(sensor_features)
        if sensor_features.dim() == 2:
            sens = sens.unsqueeze(1)
        print(f"sensor_tok (after proj): {tuple(sens.shape)}")

        # Token-type id 1 marks sensor tokens.
        sens = sens + model.token_type_embed.weight[1].view(1, 1, -1)
        print(f"sensor_tok (+type 1)   : {tuple(sens.shape)}")

        n_sens = sens.size(1)
        vis = vis + model.mem_pos[:, :vis.size(1)]
        # Sensor tokens get positions only if mem_pos is long enough.
        if n_sens <= model.mem_pos.size(1):
            sens = sens + model.mem_pos[:, :n_sens]
        print(f"vision_mem (+mem_pos)  : {tuple(vis.shape)}")
        print(f"sensor_tok (+mem_pos)  : {tuple(sens.shape)}")

        memory = torch.cat([vis, sens], dim=1)  # (B, Sv+Ss, D)

    print(f"memory (vision+sensor) : {tuple(memory.shape)}   (B,S,hidden_dim)")
    print(rule)

    # --- Stage 4: inject the conditioning vector into every query ---
    x = queries + cond.unsqueeze(1)
    print(f"conditioned_tgt        : {tuple(x.shape)}")

    # --- Stage 5: optional causal self-attention mask (True = blocked) ---
    causal = None
    if model.causal_self_attn:
        steps = x.size(1)
        causal = torch.triu(torch.ones(steps, steps, device=x.device), diagonal=1).bool()
        print(f"tgt_mask              : {tuple(causal.shape)}   (H,H) causal")

    # --- Stage 6: stack of ModulatedDecoderLayers ---
    for idx, dec_layer in enumerate(model.mod_layers):
        print(f"\n--- ModulatedDecoderLayer {idx} ---")
        print(f" input x     : {tuple(x.shape)}")
        x = dec_layer(x, memory, cond, tgt_mask=causal, memory_key_padding_mask=None)
        print(f" output x    : {tuple(x.shape)}")

    print("\nDecoder output         :", tuple(x.shape))

    # --- Stage 7: velocity head ---
    velocity = model.output_head(x)
    print("velocity (final output):", tuple(velocity.shape), "  (B,H,A)")
    print("=========================================\n")

    return velocity


# ============================================================
# 4. Full forward + loss check, including the OT-CFM bridge
# ============================================================
with torch.no_grad():
    # (1) Scale-normalize the GT actions, then build the flow bridge.
    actions_n = actions / flow_model.action_scale  # (B,H,A)
    print("actions_n (normalized) :", tuple(actions_n.shape))

    x_t, u_t, t_scalar = flow_model.flow.compute_flow_and_target(actions_n)
    print("x_t (bridge state)     :", tuple(x_t.shape), "   (B,H,A)")
    print("u_t (target velocity)  :", tuple(u_t.shape), "   (B,H,A)")
    print("t_scalar               :", tuple(t_scalar.shape), "   (B,) scalar in [0,1]")
    print()

    # (2) Shape-printing forward pass.
    v_pred = debug_flow_forward(
        flow_model, x_t, t_scalar, context_features, guidance_vector, sensor_features
    )

    # (3) Loss in the same form as the real loss computation: a per-sample
    #     weight lam = mean(max(1 - t, 0)), floored at flow_model.min_lambda,
    #     scales the element-wise Smooth-L1 between predicted and target
    #     velocity.
    lam = (
        (1.0 - t_scalar)
        .clamp(min=0.0)
        .view(t_scalar.size(0), -1)
        .mean(dim=1)
        .view(-1, 1, 1)
        .clamp(min=flow_model.min_lambda)
    )

    loss_per_elem = torch.nn.functional.smooth_l1_loss(v_pred, u_t, reduction='none')
    weighted = lam * loss_per_elem
    loss = weighted.mean(dim=(1, 2)).mean()

    print("v_pred shape           :", tuple(v_pred.shape))
    print("Smooth L1 loss (scalar):", float(loss))


# ============================================================
# 5. Sampling path (shape check)
# ============================================================
# nn.Module instances start in training mode. Without this, the dropout=0.1
# configured above would stay active while sampling and perturb the samples.
# Whether `sample` switches modes itself isn't visible here; setting eval
# explicitly is harmless either way. The previous mode is restored afterwards,
# so re-running the loss cell sees the same mode as before.
_was_training = flow_model.training
flow_model.eval()
try:
    with torch.no_grad():
        sampled_actions = flow_model.sample(
            context_features=context_features,
            guidance_vector=guidance_vector,
            sensor_features=sensor_features,
            batch_size=BATCH_SIZE,
            num_steps=6,
            method='rk4',
        )
finally:
    flow_model.train(_was_training)

print("\n=== Sampling ===")
print("sampled_actions :", tuple(sampled_actions.shape), "   (B,H,A)")


✅ FlowMatchingActionExpert V2 (Cross-Attention + ModulatedDecoder) 초기화 완료
   4개의 ModulatedDecoderLayer 사용
=== Dummy Input Shapes ===
actions          : (2, 8, 7)   (B,H,A)
context_features : (2, 16, 2048)   (B,Sv,D_img)
guidance_vector  : (2, 2048)   (B,D_text)
sensor_features  : (2, 2048)   (B,D_sensor)

actions_n (normalized) : (2, 8, 7)
x_t (bridge state)     : (2, 8, 7)    (B,H,A)
u_t (target velocity)  : (2, 8, 7)    (B,H,A)
t_scalar               : (2,)    (B,) scalar in [0,1]

x_t               : (2, 8, 7)   (B,H,A)
t                 : (2,)     (B,)  scalar in [0,1]
context_features  : (2, 16, 2048)   (B,Sv,D_img)
guidance_vector   : (2, 2048)    (B,D_text)
sensor_features   : (2, 2048)    (B,D_sensor)
--------------------------------------
time_raw (sinusoidal) : (2, 256)   (B,time_embed_dim=256)
t_embed (after MLP+LN): (2, 1024)   (B,hidden_dim=1024)
guidance_embed        : (2, 1024)   (B,hidden_dim)
cond_embed (guide+time): (2, 1024)   (B,hidden_dim)
-------------------------