In [None]:
import torch
import numpy as np
import cv2
from torchvision import transforms
from einops import rearrange
from llava_pythia.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava_pythia.mm_utils import tokenizer_image_token, get_model_name_from_path
from llava_pythia.model.builder import load_pretrained_model
from llava_pythia.conversation import conv_templates
from llava_pythia.model.language_model.pythia.llava_pythia import LlavaPythiaConfig
import robosuite as suite
from robosuite.controllers import load_composite_controller_config


def get_image(ts, camera_names, rand_crop_resize=False):
    """Stack the requested camera images into one scaled CUDA tensor.

    Args:
        ts: Object whose ``observation['images']`` maps a camera name to an
            ``(H, W, C)`` array.
        camera_names: Cameras to read; their order is the order along the
            camera axis of the result.
        rand_crop_resize: When true, keep the central 95% of every image and
            resize it back to the original height and width. Despite the
            name, the crop window is fixed, not random.

    Returns:
        Float32 tensor of shape ``(1, num_cameras, C, H, W)`` on the CUDA
        device, with pixel values divided by 255.
    """
    channel_first = []
    for cam in camera_names:
        channel_first.append(rearrange(ts.observation['images'][cam], 'h w c -> c h w'))

    # presumably the inputs are 8-bit images, so this maps them to [0, 1]
    scaled = np.stack(channel_first) / 255.0
    img_tensor = torch.from_numpy(scaled).float().cuda().unsqueeze(0)
    if not rand_crop_resize:
        return img_tensor

    print('rand crop resize is used!')
    ratio = 0.95
    h, w = img_tensor.shape[-2:]
    dh = int(h * (1 - ratio) / 2)
    dw = int(w * (1 - ratio) / 2)
    # Drop the leading batch axis while resizing, then restore it.
    cropped = img_tensor[..., dh:h - dh, dw:w - dw].squeeze(0)
    resize = transforms.Resize((h, w), antialias=True)
    return resize(cropped).unsqueeze(0)


def convert_actions(pred_action, verbose=True):
    """Convert a 10-D action to the 7-D axis-angle action layout.

    A 10-D action is laid out as ``[x, y, z, rot6d(6), gripper]`` and is
    converted to ``[x, y, z, axis_angle(3), gripper]``. Inputs whose last
    dimension is not 10 are returned unchanged (as a NumPy array).

    Args:
        pred_action: ``torch.Tensor``, ``np.ndarray`` or array-like whose last
            dimension holds the action. Leading batch dimensions are kept.
        verbose: When true (default), print the decomposed 10-D action.

    Returns:
        ``np.ndarray`` with last dimension 7 for 10-D input, otherwise the
        input converted to a NumPy array.
    """
    import torch.nn.functional as F

    def rotation_6d_to_matrix(d6):
        # Gram-Schmidt on the two 3-vectors; the resulting unit vectors are
        # stacked as the rows of the rotation matrix.
        a1 = F.normalize(d6[..., 0:3], dim=-1)
        a2 = d6[..., 3:6]
        b2 = F.normalize(a2 - (a1 * a2).sum(-1, keepdim=True) * a1, dim=-1)
        b3 = torch.cross(a1, b2, dim=-1)
        return torch.stack([a1, b2, b3], dim=-2)  # (..., 3, 3)

    def matrix_to_axis_angle(R):
        # The skew-symmetric part of R equals 2*sin(theta)*axis.
        skew = torch.stack(
            [
                R[..., 2, 1] - R[..., 1, 2],
                R[..., 0, 2] - R[..., 2, 0],
                R[..., 1, 0] - R[..., 0, 1],
            ],
            dim=-1,
        )
        sin_theta = 0.5 * skew.norm(dim=-1)
        cos_theta = 0.5 * (R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2] - 1.0)
        # atan2 stays well conditioned where acos loses precision (near 0, pi).
        theta = torch.atan2(sin_theta, cos_theta)
        axis = F.normalize(skew, dim=-1)

        # Near a half turn the skew part vanishes, so normalizing it yields a
        # zero axis and the rotation would be silently dropped. Recover the
        # axis from the symmetric part instead:
        #   (R + R^T) / 2 = cos(theta) * I + (1 - cos(theta)) * axis axis^T
        near_pi = (sin_theta < 1e-3) & (cos_theta < 0)
        if bool(near_pi.any()):
            eye = torch.eye(3, dtype=R.dtype, device=R.device)
            cos_b = cos_theta[..., None, None]
            sym = 0.5 * (R + R.transpose(-1, -2))
            outer = (sym - cos_b * eye) / (1.0 - cos_b).clamp_min(1e-8)
            # Use the column with the largest diagonal entry for stability.
            diag = torch.diagonal(outer, dim1=-2, dim2=-1)
            idx = diag.argmax(dim=-1)
            col = torch.gather(
                outer, -1, idx[..., None, None].expand(*idx.shape, 3, 1)
            ).squeeze(-1)
            alt_axis = F.normalize(col, dim=-1)
            # Keep the sign consistent with the (tiny) skew part when it is
            # non-zero; at exactly pi either sign describes the same rotation.
            flip = (alt_axis * skew).sum(-1, keepdim=True) < 0
            alt_axis = torch.where(flip, -alt_axis, alt_axis)
            axis = torch.where(near_pi.unsqueeze(-1), alt_axis, axis)

        return axis * theta.unsqueeze(-1)

    if not torch.is_tensor(pred_action):
        pred_action = torch.from_numpy(np.asarray(pred_action)).float()

    if pred_action.shape[-1] != 10:
        return pred_action.detach().cpu().numpy()

    pos = pred_action[..., :3]
    rot6d = pred_action[..., 3:9]
    gripper = pred_action[..., 9:]

    axis_angle = matrix_to_axis_angle(rotation_6d_to_matrix(rot6d))
    if verbose:
        print("🔍 10D Action:", pred_action)
        print("📦 Position:", pos)
        print("🌀 rot6d:", rot6d)
        print("✊ Gripper:", gripper)
    # detach() so tensors that still carry autograd history can be exported.
    return torch.cat([pos, axis_angle, gripper], dim=-1).detach().cpu().numpy()
        


class llava_pythia_act_policy:
    """Inference wrapper that loads a LLaVA-Pythia policy and builds its inputs.

    ``policy_config`` is read for the keys ``model_path``, ``model_base``,
    ``enable_lora`` and ``conv_mode``.
    """

    def __init__(self, policy_config, data_args=None):
        """Store the configuration and load the model immediately."""
        self.policy_config = policy_config
        self.data_args = data_args
        self._load_policy()

    def _load_policy(self):
        """Load tokenizer, model, image processor and model config from disk."""
        # The base model is only passed when LoRA weights are in use.
        base = self.policy_config["model_base"] if self.policy_config['enable_lora'] else None
        path = self.policy_config["model_path"]
        name = get_model_name_from_path(path)
        self.tokenizer, self.policy, self.image_processor, self.context_len = load_pretrained_model(
            path, base, name, False, False
        )
        # The config is read from the parent directory of ``model_path``.
        self.config = LlavaPythiaConfig.from_pretrained('/'.join(path.split('/')[:-1]), trust_remote_code=True)

    def _expand2square(self, imgs, bg_color):
        """Pad a batch of images to a square canvas filled with ``bg_color``.

        Args:
            imgs: Tensor of shape ``(b, c, h, w)``.
            bg_color: Per-channel fill value (length ``c``) or a scalar.

        Returns:
            Tensor of shape ``(b, size, size, c)`` (channel-last) with
            ``size = max(h, w)``, the image centred along its shorter side,
            on the same device and with the same dtype as ``imgs``.
        """
        b, c, h, w = imgs.shape
        size = max(h, w)
        # Build the canvas on the input's device; this avoids a device -> CPU
        # -> device round trip on every call. float32 is the working dtype.
        fill = torch.as_tensor(bg_color, dtype=torch.float32, device=imgs.device)
        canvas = fill.expand(b, size, size, c).clone()
        top = (size - h) // 2
        left = (size - w) // 2
        canvas[:, top:top + h, left:left + w, :] = imgs.permute(0, 2, 3, 1).to(torch.float32)
        return canvas.to(dtype=imgs.dtype)

    def process_batch_to_llava(self, curr_image, robo_state, raw_lang):
        """Assemble the keyword arguments for one forward pass of the policy.

        Args:
            curr_image: Image tensor; a leading batch axis of a 5-D tensor is
                squeezed away, and the result is split into two halves along
                dim 0 (first half -> ``images``, second half -> ``images_r``).
            robo_state: Robot state tensor; a 1-D state gets a batch axis.
            raw_lang: Language instruction appended after the image token.

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``images``,
            ``images_r`` and ``states``, all on the policy's device.
        """
        self.conv = conv_templates[self.policy_config['conv_mode']].copy()
        device = self.policy.device

        curr_image = curr_image.squeeze(0) if curr_image.dim() == 5 else curr_image
        img1, img2 = torch.chunk(curr_image, 2, dim=0)

        # Batch the state and place it on the policy's device/dtype. Before,
        # the moved copy was computed but never returned, so a CPU state
        # would have reached the model unchanged.
        states = robo_state.unsqueeze(0) if robo_state.dim() == 1 else robo_state
        states = states.to(device, dtype=torch.float32)
        print("✅ DEBUG: states shape:", states.shape)

        def prep(img):
            # Pad with the processor's mean colour, then let the processor
            # normalize; rescaling and centre-cropping are disabled.
            img = self._expand2square(img, tuple(self.image_processor.image_mean))
            pixels = self.image_processor.preprocess(
                img, return_tensors='pt', do_normalize=True, do_rescale=False, do_center_crop=False
            )['pixel_values']
            return pixels.float().to(device)

        image_tensor, image_tensor_r = prep(img1), prep(img2)

        if self.policy.config.mm_use_im_start_end:
            prompt = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + raw_lang
        else:
            prompt = DEFAULT_IMAGE_TOKEN + '\n' + raw_lang

        self.conv.append_message(self.conv.roles[0], prompt)
        self.conv.append_message(self.conv.roles[1], None)
        prompt = self.conv.get_prompt() + " <|endoftext|>"

        # Use the policy's device rather than a hard-coded .cuda() so every
        # input lands where the model lives.
        input_ids = tokenizer_image_token(
            prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).to(device).long()
        attn_mask = input_ids.ne(self.tokenizer.pad_token_id).long()

        return dict(input_ids=input_ids, attention_mask=attn_mask, images=image_tensor,
                    images_r=image_tensor_r, states=states)


class RobosuiteDeployEnv:
    """Thin wrapper around a robosuite Panda environment used for policy rollouts.

    The most recent observation dict is cached in ``self.obs`` and refreshed
    by ``reset`` and ``step``.
    """

    # (position, quaternion) written into the MuJoCo model for each camera.
    _CAMERA_POSES = {
        "sideview": ([0.4, 0.8, 1.0], [0.653, 0.271, -0.653, 0.271]),
        "frontview": ([-0.4, -0.8, 1.0], [0.653, -0.271, 0.653, 0.271]),
    }

    def __init__(self, env_name="Lift", cameras=("sideview", "frontview"), control_freq=20):
        """Create the environment with offscreen 320x240 camera observations.

        Args:
            env_name: robosuite environment name passed to ``suite.make``.
            cameras: Camera names to include in the observations.
            control_freq: Control frequency passed to ``suite.make``.
        """
        controller_config = load_composite_controller_config(robot="Panda")
        self.env = suite.make(env_name=env_name, robots="Panda", controller_configs=controller_config,
                              has_renderer=False, has_offscreen_renderer=True, render_camera=None,
                              use_object_obs=True, use_camera_obs=True, control_freq=control_freq,
                              camera_names=list(cameras), camera_heights=240, camera_widths=320)
        self.obs = None
        self.reset()

    def _apply_camera_poses(self):
        """Write the fixed camera poses into the current simulation model."""
        # Always fetch the sim from the env: a reset may have replaced it.
        sim = self.env.sim
        for cam, (pos, quat) in self._CAMERA_POSES.items():
            cam_id = sim.model.camera_name2id(cam)
            sim.model.cam_pos[cam_id] = pos
            sim.model.cam_quat[cam_id] = quat
        sim.forward()

    def get_observation(self):
        """Return ``(ts, robot_state)`` built from the cached observation.

        ``ts.observation['images']`` maps each environment camera name to its
        ``'<camera>_image'`` entry; ``robot_state`` is the end-effector
        position concatenated with the end-effector quaternion.
        """
        ts = type("Timestep", (), {})()
        ts.observation = {'images': {cam: self.obs[f'{cam}_image'] for cam in self.env.camera_names}}
        eef_pos = self.obs['robot0_eef_pos']
        eef_quat = self.obs['robot0_eef_quat']
        robot_state = np.concatenate([eef_pos, eef_quat])
        return ts, robot_state

    def reset(self):
        """Reset the environment, re-apply camera poses and return the observation."""
        self.env.reset()
        # NOTE(review): robosuite's default hard reset appears to rebuild the
        # simulation model, which would discard edits made to ``sim.model``
        # before the reset - confirm against the installed robosuite version.
        # Re-applying after every reset is harmless either way.
        self._apply_camera_poses()
        # Re-render so the cached images reflect the poses just written.
        self.obs = self.env._get_observations(force_update=True)
        return self.obs

    def step(self, action):
        """Step the environment and cache the new observation."""
        self.obs, reward, done, info = self.env.step(action)
        return self.obs, reward, done, info

    def render_cameras(self, cameras=("sideview", "frontview"), width=320, height=240):
        """Render each camera offscreen and return the frames as a list.

        The last axis of every frame is reversed (channel order is flipped)
        before it is returned.
        """
        self.env.sim.forward()
        return [self.env.sim.render(camera_name=c, width=width, height=height, depth=False, mode="offscreen")[..., ::-1] for c in cameras]


def _prepare_frame(frame, frame_idx, size):
    """Convert one rendered frame to a ``uint8`` 3-channel image of ``size``.

    Args:
        frame: Array of shape ``(H, W, C)`` as returned by the renderer.
        frame_idx: Index used only in the log messages.
        size: Target ``(width, height)``.

    Returns:
        The converted frame, or ``None`` when the frame is unusable
        (NaN/Inf values, unsupported channel count, or OpenCV failure).
    """
    width, height = size
    print(f"📷 Frame {frame_idx}: shape={frame.shape}, dtype={frame.dtype}, min={frame.min()}, max={frame.max()}")

    if np.isnan(frame).any() or np.isinf(frame).any():
        print(f"🚫 Frame {frame_idx} contains NaN or Inf, skipping.")
        return None

    # Float frames are treated as [0, 1] and scaled to 8-bit.
    if frame.dtype in [np.float32, np.float64]:
        frame = (np.clip(frame, 0.0, 1.0) * 255.0).astype(np.uint8)
    elif frame.dtype != np.uint8:
        print(f"⚠️ 예상치 못한 dtype: {frame.dtype}, 변환 시도")
        frame = frame.astype(np.uint8)

    channels = frame.shape[-1]
    if channels == 4:
        frame = frame[..., :3]  # drop the fourth channel
    elif channels == 1:
        frame = np.repeat(frame, 3, axis=-1)  # replicate the single channel
    elif channels != 3:
        print(f"❌ 알 수 없는 채널 수: {channels}, 건너뜀")
        return None

    # Rendered frames may be reversed views; give OpenCV a plain buffer.
    frame = np.ascontiguousarray(frame)
    if frame.shape[:2] != (height, width):
        print(f"📏 해상도 변환: {frame.shape[:2]} → ({height}, {width})")
        frame = cv2.resize(frame, (width, height))

    # NOTE(review): RobosuiteDeployEnv.render_cameras in this file already
    # reverses the channel axis, so this swap may undo it and leave the video
    # with red/blue exchanged - confirm the renderer's channel order.
    try:
        return cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    except cv2.error as e:
        print(f"❌ cvtColor 실패: {e}")
        return None


def eval_bc(policy, deploy_env, policy_config, save_episode=True, num_rollouts=1,
            raw_lang=None, n_steps=50, fps=20, camera_names=("sideview", "frontview"),
            video_path="rollout2.mp4", frame_size=(640, 480)):
    """Roll out the policy in the environment and optionally record a video.

    Args:
        policy: Wrapper exposing ``policy`` (the model) and
            ``process_batch_to_llava``.
        deploy_env: Environment exposing ``reset``, ``get_observation``,
            ``step`` and ``render_cameras``.
        policy_config: Not used by this function; kept for call compatibility.
        save_episode: When true, write the collected frames to ``video_path``.
        num_rollouts: Number of episodes; all episodes share one video.
        raw_lang: Language instruction; must not be ``None``.
        n_steps: Maximum number of steps per episode.
        fps: Frame rate of the written video.
        camera_names: Cameras used for both policy input and recording.
        video_path: Output file for the recording.
        frame_size: ``(width, height)`` of each camera tile in the video.
    """
    print("🔍 Combine layer 확인:")
    print(policy.policy.embed_out.combine)
    assert raw_lang is not None, "raw_lang must be provided"

    width, height = frame_size
    all_frames = []

    for _ in range(num_rollouts):
        deploy_env.reset()

        for _ in range(n_steps):
            ts, robot_state = deploy_env.get_observation()
            robot_state = robot_state[:7]
            robot_tensor = torch.from_numpy(robot_state).float().to(policy.policy.device)

            print("🤖 robot_state shape:", robot_state.shape)
            print("🤖 robot_state:", robot_state)

            image_tensor = get_image(ts, camera_names)
            batch = policy.process_batch_to_llava(image_tensor, robot_tensor, raw_lang)

            print("✅ batch keys:", batch.keys())
            if 'states' in batch:
                print("✅ batch['states'] shape:", batch['states'].shape)
            else:
                print("❌ 'states' not found in batch!")

            # Only the forward pass needs gradients disabled.
            with torch.no_grad():
                all_actions = policy.policy(**batch, eval=True)
            action = convert_actions(all_actions[0][0].detach().cpu().numpy())
            _, _, done, _ = deploy_env.step(action)

            raw_frames = deploy_env.render_cameras(cameras=camera_names, width=width, height=height)
            prepared = [_prepare_frame(f, idx, frame_size) for idx, f in enumerate(raw_frames)]
            cleaned_frames = [f for f in prepared if f is not None]

            # A step is recorded only when every camera produced a frame: a
            # narrower concatenation would not match the writer's frame size.
            if cleaned_frames and len(cleaned_frames) == len(raw_frames):
                all_frames.append(np.concatenate(cleaned_frames, axis=1))
            elif cleaned_frames:
                print(f"🚨 only {len(cleaned_frames)}/{len(raw_frames)} camera frames usable, step not recorded")

            if done:
                break

    if save_episode and all_frames:
        h, w, _ = all_frames[0].shape
        out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
        try:
            if not out.isOpened():
                print(f"🚨 could not open {video_path} for writing")
                return
            for f in all_frames:
                out.write(f)
        finally:
            out.release()
        # Report the path actually written (the old message named another file).
        print(f"🎥 {video_path} 저장 완료")

[1m[32m[robosuite INFO] [0mLoading controller configuration from: /home/parkjeongsu/anaconda3/envs/tinysuite/lib/python3.10/site-packages/robosuite/controllers/config/robots/default_panda.json (composite_controller_factory.py:121)
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


load llaVA-Pythia MLLM!!!
combine layer: Linear(in_features=519, out_features=512, bias=True)
number of parameters: 7.283150e+07


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


{'device_map': 'cuda', 'torch_dtype': torch.float32}
🔍 Combine layer 확인:
Linear(in_features=519, out_features=512, bias=True)
🤖 robot_state shape: (7,)
🤖 robot_state: [-0.09860907  0.00286695  1.02428365  0.99609503 -0.01026402  0.08766377
 -0.00209969]
✅ DEBUG: states shape: torch.Size([1, 7])
✅ batch keys: dict_keys(['input_ids', 'attention_mask', 'images', 'images_r', 'states'])
✅ batch['states'] shape: torch.Size([1, 7])
🔍 global_cond shape (forward): torch.Size([1, 1195, 512])
🔍 global_cond shape (forward): torch.Size([1, 1195, 512])
🔍 global_cond shape (forward): torch.Size([1, 1195, 512])
🔍 global_cond shape (forward): torch.Size([1, 1195, 512])
🔍 global_cond shape (forward): torch.Size([1, 1195, 512])
🔍 global_cond shape (forward): torch.Size([1, 1195, 512])
🔍 global_cond shape (forward): torch.Size([1, 1195, 512])
🔍 global_cond shape (forward): torch.Size([1, 1195, 512])
🔍 global_cond shape (forward): torch.Size([1, 1195, 512])
🔍 global_cond shape (forward): torch.Size([1, 119

  from pkg_resources import resource_filename
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


🎥 rollout_imageio2.mp4 저장 완료
