In [None]:
### **Inference**
import os  
import sys 
import argparse  
import subprocess 
import pickle as pkl 
import collections 
import math 
import time 
from datetime import datetime
import matplotlib.pyplot as plt

import cv2 
import numpy as np 
import taichi as ti 
import torch 
import torch.nn as nn 
import torchvision 
import rclpy 
from rclpy.node import Node 

from threading import Thread
from typing import Tuple, Sequence, Dict, Union, Optional, Callable 

from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 
from diffusers.training_utils import EMAModel 
from diffusers.optimization import get_scheduler 

from datasets.data_utils import * 
from model.noise_pred_net import * 
from model.visual_encoder import * 
from config import * 
from foam_env import *

# Timestamp of the moment this module is loaded, formatted as MMDD_HHMMSS.
# It is computed once and reused to name this run's output files.
timestamp = f"{datetime.now():%m%d_%H%M%S}"

# Define a function to save ROS bag
def save_rosbag(dir_path: str, filename: str = "", topics: Optional[Sequence[str]] = None) -> subprocess.Popen:
    """Start ``ros2 bag record`` as a background child process.

    Args:
        dir_path: Directory under which the bag output is written.
        filename: Output name, joined onto ``dir_path`` and passed to ``-o``.
        topics: Topic names to record. ``None`` selects the default topic
            list defined below. Any sequence of strings is accepted.

    Returns:
        The ``subprocess.Popen`` handle of the running recorder. The caller
        owns the process and is responsible for stopping and reaping it.

    Raises:
        OSError: If the recorder cannot be spawned (for example ``ros2`` is
            not on ``PATH``). The error is printed and re-raised.
    """
    if topics is None:
        topics = [
            '/autohand_node/cmd_autohand',
            '/diff/xarm/move_joint_cmd',
            '/autohand_node/state_autohand',
            '/xarm/joint_states',
            '/camera/color/image_raw/compressed'
        ]

    # Unpack instead of `list + topics` so tuples and other sequences work too.
    cmd = ["ros2", "bag", "record", "-o", os.path.join(dir_path, filename), *topics]
    print(f"Executing command: {cmd}")  # Debug: Print the command being executed

    try:
        # NOTE(review): stdout is piped but nothing in this file reads it; a
        # chatty child could block once the pipe buffer fills. The PIPE is kept
        # so existing callers that read `proc.stdout` keep working; callers
        # should drain it (e.g. via `communicate()`) when stopping the process.
        return subprocess.Popen(cmd, stdout=subprocess.PIPE)
    except OSError as e:
        # Popen signals spawn failures with OSError (FileNotFoundError,
        # PermissionError, ...); it never raises CalledProcessError.
        print(f"ROS bag recording failed with error: {e}")
        raise
    except Exception as e:
        print(f"An unexpected error occurred: {e}")  # Print unexpected error messages
        raise
    
# Define the DiffPolicyInfer class
class DiffPolicyInfer:
    """Drive a trained diffusion policy against the ``FoamEnv`` robot environment.

    The constructor initialises rclpy, builds the vision encoder and the noise
    prediction network, loads their weights from a checkpoint, starts spinning
    the environment node on a background thread and sends the initial pose.
    ``step()`` then predicts a short sequence of actions from the most recent
    observations and ``excute()`` applies one action and records the resulting
    observation. Call ``finish()`` once at the end to release everything.

    The model dimensions (``action_dim``, ``obs_dim``, ``obs_horizon``,
    ``pred_horizon``, ``action_horizon``) and the normalisation ``stats`` are
    module-level names expected to come from the file's star imports.
    """

    # Seconds to wait for the bag recorder to exit after terminate() before
    # it is force-killed.
    ROSBAG_STOP_TIMEOUT_S = 5.0

    def __init__(self, file_name: str, save_flag: bool, record_flag: bool, ckpt_path: str, save_path: str):
        """Set up networks, checkpoint, environment and optional bag recording.

        Args:
            file_name: Checkpoint base name; ``{ckpt_path}/{file_name}.pt`` is loaded.
            save_flag: If true, ``finish()`` attempts to stop a running recording.
            record_flag: If true, ROS bag recording starts at the end of construction.
            ckpt_path: Directory containing the checkpoint file.
            save_path: Directory passed to ``save_rosbag`` for bag output.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
            Exception: Any other initialisation error is printed and re-raised
                after rclpy has been shut down (if it was started here).
        """
        # Track whether *this* constructor started rclpy, so the error path
        # never tries to shut down a context that was not initialised (which
        # would raise and hide the original error).
        rclpy_started = False
        # Recording state is defined up front so start_saving()/end_saving()
        # are always safe to call.
        self.in_record = False
        self.rosbag_p = None
        try:
            rclpy.init()
            rclpy_started = True

            self.num_motors = 23
            # Initial pose: 16 hand values followed by 7 arm values (23 total).
            self.init_hand_pos = np.array([0.0, 0.0, 0.0285, -0.0295, 0.3258, 0.0, 0.0, 0.0, 0.1020, 0.0, 0.0, -0.0707, 0.0701, 0.0274, 0.0, 0.0143])
            self.init_xarm_pos = np.array([-0.0199, -0.1503, 0.0307, 1.3668, 0.0905, 1.4956, -1.7334])
            self.init_pos = np.concatenate([self.init_hand_pos, self.init_xarm_pos])

            vision_encoder = replace_bn_with_gn(get_resnet('resnet18'))
            noise_pred_net = ConditionalUnet1D(
                input_dim=action_dim,
                global_cond_dim=obs_dim * obs_horizon)
            self.nets = nn.ModuleDict({
                'vision_encoder': vision_encoder,
                'noise_pred_net': noise_pred_net})
            print('Visual encoder and noise prediction network Initialized.')

            self.num_diffusion_iters = 80
            self.noise_scheduler = DDPMScheduler(
                num_train_timesteps=self.num_diffusion_iters,
                beta_schedule='squaredcos_cap_v2',
                clip_sample=True,
                prediction_type='epsilon'
            )
            print('Noise scheduler Initialized.')

            # Prefer the GPU, but stay usable on machines without CUDA.
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.nets.to(self.device)

            self.ckpt_path = f"{ckpt_path}/{file_name}.pt"
            print(f"Full ckpt_path: {self.ckpt_path}")
            if not os.path.isfile(self.ckpt_path):
                raise FileNotFoundError("Checkpoint file not found!")

            # Map tensors straight onto the device the networks live on.
            state_dict = torch.load(self.ckpt_path, map_location=self.device)
            # `ema_nets` is an alias of `nets`: the checkpoint weights are
            # loaded directly into the same modules.
            self.ema_nets = self.nets
            self.ema_nets.load_state_dict(state_dict['model_state_dict'])
            # Inference only: disable any train-time behaviour (dropout,
            # batch-norm statistic updates) the modules may contain.
            self.ema_nets.eval()
            print('Pretrained weights loaded.')

            self.cur_img = None
            self.cur_action = None

            self.foam_env = FoamEnv(enable_foam=True, enable_camera=True, reset_foam=True)
            # Spin the node in the background so it keeps being serviced while
            # the main thread runs inference. Daemon: an unhandled error in the
            # main thread must not leave the process hanging on this thread.
            self._spin_thread = Thread(target=rclpy.spin, args=(self.foam_env,), daemon=True)
            self._spin_thread.start()
            # Presumably gives the node time to connect before the first command.
            time.sleep(1.0)

            init_action = self.init_pos
            print("init_action:", init_action)
            # The observation is a dict; this class reads its 'image' and
            # 'agent_pos' entries. Per the original debug notes: 'image' is
            # (240, 320, 3) and 'agent_pos' is (23,).
            init_obs = self.foam_env.step(init_action)

            self.cur_img = init_obs['image']
            self.cur_action = init_action

            # Pre-fill the observation history with the first observation so
            # step() always sees `obs_horizon` entries.
            self.obs_deque = collections.deque([init_obs] * obs_horizon, maxlen=obs_horizon)

            self.save_flag = save_flag
            self.record_flag = record_flag
            # Always defined so start_saving() can also be called manually.
            self.save_file_name = f"summer_norm/rosbags/{timestamp}"
            self.save_path = save_path
            if self.record_flag:
                print("Starting ROS bag recording...")
                self.start_saving(self.save_file_name)

        except Exception as e:
            print(f"Initialization failed: {e}")
            if rclpy_started and rclpy.ok():
                rclpy.shutdown()
            raise

    def finish(self):
        """Stop recording, reset the environment and shut rclpy down.

        Errors while stopping or resetting are printed, not raised. rclpy is
        shut down in all cases, and calling this method twice is harmless.
        """
        try:
            if self.save_flag or self.record_flag or self.in_record:
                print("Ending ROS bag recording...")
                self.end_saving()
            self.foam_env.reset()
            time.sleep(1.0)
        except Exception as e:
            print(f"Error during finish: {e}")
        finally:
            # Guarded: shutting down an already shut-down context raises.
            if rclpy.ok():
                rclpy.shutdown()

    def start_saving(self, file_name: str):
        """Start ROS bag recording to ``{save_path}/{file_name}.bag``.

        Does nothing if a recording is already in progress.
        """
        if self.in_record:
            return
        print(f"Starting ROS bag recording at {self.save_path}/{file_name}.bag")
        self.rosbag_p = save_rosbag(self.save_path, file_name + ".bag")
        self.in_record = True

    def end_saving(self):
        """Stop the running ROS bag recording, if any, and reap the process."""
        if not self.in_record:
            return
        print("Terminating ROS bag recording...")
        recorder = self.rosbag_p
        recorder.terminate()
        try:
            # communicate() waits for exit and also drains/closes the stdout
            # pipe, so the child is reaped instead of lingering as a zombie.
            recorder.communicate(timeout=self.ROSBAG_STOP_TIMEOUT_S)
        except subprocess.TimeoutExpired:
            # The recorder ignored the terminate request: force it down.
            recorder.kill()
            recorder.communicate()
        self.in_record = False

    def step(self) -> np.ndarray:
        """Run one round of inference and return the next actions to execute.

        Process:
        1. Stack the images and agent positions held in the observation deque.
        2. Normalise the images and encode them with the vision encoder.
        3. Start from Gaussian noise and run the reverse diffusion loop,
           conditioned on the flattened observation features.
        4. Un-normalise the result and slice out the actions to execute.

        Returns:
            np.ndarray: ``action_horizon`` consecutive rows of the predicted
            sequence, starting at index ``obs_horizon - 1``; each row is one
            action of ``action_dim`` values.
        """
        B = 1  # batch size
        history = list(self.obs_deque)
        images = np.stack([obs['image'] for obs in history])
        images = np.moveaxis(images, -1, 1)  # channels-last -> channels-first
        agent_poses = np.stack([obs['agent_pos'] for obs in history])

        nimages = normalize_images(images)
        nimages = torch.from_numpy(nimages).to(self.device, dtype=torch.float32)
        # NOTE(review): agent positions are fed to the network un-normalised
        # (the normalisation call was disabled by the author) while actions are
        # un-normalised below - confirm this matches how the model was trained.
        nagent_poses = torch.from_numpy(agent_poses).to(self.device, dtype=torch.float32)

        with torch.no_grad():
            image_features = self.ema_nets['vision_encoder'](nimages)
            obs_features = torch.cat([image_features, nagent_poses], dim=-1)
            # Flatten the whole observation history into one conditioning row.
            obs_cond = obs_features.unsqueeze(0).flatten(start_dim=1)

            # Reverse diffusion starts from pure noise.
            naction = torch.randn((B, pred_horizon, action_dim), device=self.device)

            self.noise_scheduler.set_timesteps(self.num_diffusion_iters)
            for k in self.noise_scheduler.timesteps:
                noise_pred = self.ema_nets['noise_pred_net'](
                    sample=naction,
                    timestep=k,
                    global_cond=obs_cond
                )
                # One denoising step: x_k -> x_{k-1}.
                naction = self.noise_scheduler.step(
                    model_output=noise_pred,
                    timestep=k,
                    sample=naction
                ).prev_sample

        naction = naction.detach().to('cpu').numpy()[0]  # drop the batch axis
        action_pred = unnormalize_data(naction, stats=stats)

        start = obs_horizon - 1
        end = start + action_horizon
        actions = action_pred[start:end, :]
        print("actions:", actions)
        return actions

    def excute(self, action):
        """Send one action to the environment and record the new observation.

        The misspelled method name is kept because callers use it.

        Parameters:
        action (np.ndarray): The action to be executed.
        """
        obs = self.foam_env.step(action)
        self.obs_deque.append(obs)  # the oldest entry drops out (maxlen)
        self.cur_img = obs['image']
        # Track the position reported by the environment rather than the
        # commanded action.
        self.cur_action = obs['agent_pos']


if __name__ == '__main__':
    # Script entry point: build the inference driver, run a fixed number of
    # predict/execute rounds, and log every executed action to text files.
    # Everything lives under the guard so importing this module never starts
    # the robot loop.
    file_name = "tennis_nw"
    save_flag = False
    record_flag = False
    ckpt_path = "/home/foamlab/nw/save"
    save_path = "/home/foamlab/nw/save"

    infer = DiffPolicyInfer(
        file_name=file_name,
        save_flag=save_flag,
        record_flag=record_flag,
        ckpt_path=ckpt_path,
        save_path=save_path
    )

    # Main loop
    action = []
    simplified_action = []      # executed actions rounded to 4 decimals
    all_xarm_joints = []        # last 7 values of every executed action
    all_foamhand_joints = []    # first 16 values of every executed action
    total_action_time = 20      # number of inference rounds to run
    time_cnt = 0

    try:
        while time_cnt < total_action_time:
            action = infer.step()
            print("# ⬆steps ", time_cnt)

            try:
                for step_action in action:
                    infer.excute(step_action)
                    simplified_action.append(np.round(step_action, 4))

                    # Extract and save the last 7 dimensions (XArm joints)
                    all_xarm_joints.append(step_action[-7:])

                    # Extract and save the first 16 dimensions (FoamHand joints)
                    all_foamhand_joints.append(step_action[:16])
            finally:
                # Flush the logs once per round (also when a round is cut short
                # by an error) instead of rewriting both files after every
                # single action inside the control loop.
                if all_xarm_joints:
                    np.savetxt(f'{timestamp}_xarm_joints.txt', all_xarm_joints, fmt='%.4f')
                    np.savetxt(f'{timestamp}_hand_joints.txt', all_foamhand_joints, fmt='%.4f')

            time_cnt += 1
    finally:
        # Always release resources: stops any recording, resets the
        # environment and shuts rclpy down so the process can exit.
        infer.finish()

