In [None]:
import json
import numpy as np
import time
import os 
import psutil 
from collections import OrderedDict

import torch
from torch.utils.data import DataLoader
 
import robomimic.utils.train_utils as TrainUtils
import robomimic.utils.torch_utils as TorchUtils
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.file_utils as FileUtils
from robomimic.config import config_factory
from robomimic.algo import algo_factory, RolloutPolicy
from robomimic.utils.log_utils import PrintLogger, DataLogger, flush_warnings
import imageio 

import matplotlib.pyplot as plt

from robomimic.envs.wrappers import EnvWrapper
from copy import deepcopy
import textwrap
import numpy as np
from collections import deque

np.set_printoptions(precision=3, suppress=True)

In [None]:
# Path to the trained policy checkpoint that the following cells restore from.
ckpt_path = "/home/franka_deoxys/data_franka/lift_blue/policy/trans_epoch_140_lift_blue.pth"
# Fail fast and name the missing file instead of raising a bare AssertionError.
assert os.path.exists(ckpt_path), "checkpoint not found: {}".format(ckpt_path)

In [None]:
from copy import deepcopy
def rollout(policy, env, horizon, render=False, video_writer=None, video_skip=5, camera_names=None):
    """
    Helper function to carry out rollouts. Supports on-screen rendering, off-screen rendering to a video, 
    and returns the rollout statistics.
    Args:
        policy (instance of RolloutPolicy): policy loaded from a checkpoint
        env (instance of EnvBase): env loaded from a checkpoint or demonstration metadata
        horizon (int): maximum horizon for the rollout
        render (bool): whether to render rollout on-screen
        video_writer (imageio writer): if provided, use to write rollout to video
        video_skip (int): how often to write video frames
        camera_names (list): determines which camera(s) are used for rendering. Pass more than
            one to output a video with multiple camera views concatenated horizontally.
            Required (non-empty) whenever @video_writer is given.
    Returns:
        stats (dict): some statistics for the rollout - such as return, horizon, and task success
    Raises:
        ValueError: if @video_writer is given without any @camera_names
    """
    # Catch the misconfiguration up front instead of failing mid-rollout while
    # iterating over a missing camera list.
    if video_writer is not None and not camera_names:
        raise ValueError("camera_names must be provided when video_writer is set")

    policy.start_episode()
    obs = env.reset()
    state_dict = env.get_state()

    # hack that is necessary for robosuite tasks for deterministic action playback
    obs = env.reset_to(state_dict)

    video_count = 0  # video frame counter
    total_reward = 0.
    # Defaults keep the stats below well-defined when the loop body never runs
    # (horizon <= 0) or a rollout exception fires before the first success check.
    success = False
    steps_started = 0
    try:
        for step_i in range(horizon):
            steps_started = step_i + 1
            act = policy(ob=obs)
            next_obs, r, done, _ = env.step(act)
            total_reward += r
            success = env.is_success()["task"]

            # on-screen visualization, using the first camera when one is named
            if render:
                if camera_names:
                    env.render(mode="human", camera_name=camera_names[0])
                else:
                    env.render(mode="human")

            if video_writer is not None:
                if video_count % video_skip == 0:
                    video_img = []
                    for cam_name in camera_names:
                        video_img.append(env.render(mode="rgb_array", height=512, width=512, camera_name=cam_name))
                    video_img = np.concatenate(video_img, axis=1) # concatenate horizontally
                    video_writer.append_data(video_img)
                video_count += 1

            if done or success:
                break

            # update for next iter
            obs = deepcopy(next_obs)

    except env.rollout_exceptions as e:
        print("WARNING: got rollout exception {}".format(e))
    stats = dict(Return=total_reward, Horizon=steps_started, Success_Rate=float(success))

    return stats

In [None]:
# Pick the torch device, preferring CUDA when it is available.
device = TorchUtils.get_torch_device(try_to_use_cuda=True)

# restore policy; the call also returns the loaded checkpoint dict, which is
# reused below instead of reading the checkpoint file from disk a second time
policy, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_path=ckpt_path, device=device, verbose=False)
# algo name and config from model dict
algo_name, _ = FileUtils.algo_name_from_checkpoint(ckpt_dict=ckpt_dict)
config, _ = FileUtils.config_from_checkpoint(algo_name=algo_name, ckpt_dict=ckpt_dict, verbose=False)

In [None]:
# Point the restored config at the local HDF5 dataset: unlock the config,
# overwrite the dataset path, then lock it again.
config.unlock()
config.train.data="/home/franka_deoxys/data_franka/lift_blue/lift_blue_imgs30.hdf5"
config.lock() 
# Display the dataset path now stored in the config.
config.train.data

In [None]:
class FrameStackWrapper(EnvWrapper):
    """
    Wrapper for frame stacking observations during rollouts. The agent
    receives a sequence of past observations instead of a single observation
    when it calls @env.reset, @env.reset_to, or @env.step in the rollout loop.

    Besides stacking, every observation is augmented in place with two extra
    keys (see @update_obs): "timesteps" and "actions".
    """
    def __init__(self, env, num_frames):
        """
        Args:
            env (EnvBase instance): The environment to wrap.
            num_frames (int): number of past observations (including current observation)
                to stack together. Must be greater than 1 (otherwise this wrapper would
                be a no-op).
        """
        assert num_frames > 1, "error: FrameStackWrapper must have num_frames > 1 but got num_frames of {}".format(num_frames)

        super(FrameStackWrapper, self).__init__(env=env)
        self.num_frames = num_frames

        # keep track of last @num_frames observations for each obs key
        self.obs_history = None

        # snapshot slot used by @cache_obs_history / @uncache_obs_history; defined
        # here so the attribute always exists on the wrapper itself
        self.obs_history_cache = None

        # step counter written into obs["timesteps"]; reset to 0 on every reset
        self.timestep = 0

    def _get_initial_obs_history(self, init_obs):
        """
        Helper method to get observation history from the initial observation, by
        repeating it.

        Args:
            init_obs (dict): initial observation, one array per key

        Returns:
            obs_history (dict): a deque for each observation key, with an extra
                leading dimension of 1 for each key (for easy concatenation later)
        """
        obs_history = {}
        for k in init_obs:
            obs_history[k] = deque(
                [init_obs[k][None] for _ in range(self.num_frames)], 
                maxlen=self.num_frames,
            )
        return obs_history

    def _get_stacked_obs_from_history(self):
        """
        Helper method to convert internal variable @self.obs_history to a 
        stacked observation where each key is a numpy array with leading dimension
        @self.num_frames.
        """
        # concatenate all frames per key so we return a numpy array per key
        return { k : np.concatenate(self.obs_history[k], axis=0) for k in self.obs_history }

    def _restart_history(self, obs):
        """
        Shared tail of @reset and @reset_to: zero the step counter, add the extra
        keys to @obs, rebuild the history from it and return the stacked observation.
        """
        self.timestep = 0  # always zero regardless of timestep type
        self.update_obs(obs, reset=True)
        self.obs_history = self._get_initial_obs_history(init_obs=obs)
        return self._get_stacked_obs_from_history()

    def cache_obs_history(self):
        """Store a deep copy of the current observation history for a later restore."""
        self.obs_history_cache = deepcopy(self.obs_history)

    def uncache_obs_history(self):
        """
        Restore the history saved by @cache_obs_history and clear the cache slot.

        Raises:
            RuntimeError: if there is no cached history to restore (restoring
                would otherwise silently wipe the live history)
        """
        if self.obs_history_cache is None:
            raise RuntimeError("uncache_obs_history called without a prior cache_obs_history")
        self.obs_history = self.obs_history_cache
        self.obs_history_cache = None

    def reset(self):
        """
        Modify to return frame stacked observation which is @self.num_frames copies of 
        the initial observation.

        Returns:
            obs_stacked (dict): each observation key in original observation now has
                leading shape @self.num_frames and consists of the previous @self.num_frames
                observations
        """
        obs = self.env.reset()
        return self._restart_history(obs)

    def reset_to(self, state):
        """
        Modify to return frame stacked observation which is @self.num_frames copies of 
        the initial observation.

        Args:
            state: state to reset the wrapped environment to (passed through unchanged)

        Returns:
            obs_stacked (dict): each observation key in original observation now has
                leading shape @self.num_frames and consists of the previous @self.num_frames
                observations
        """
        obs = self.env.reset_to(state)
        return self._restart_history(obs)

    def step(self, action):
        """
        Modify to update the internal frame history and return frame stacked observation,
        which will have leading dimension @self.num_frames for each key.

        Args:
            action (np.array): action to take

        Returns:
            obs_stacked (dict): each observation key in original observation now has
                leading shape @self.num_frames and consists of the previous @self.num_frames
                observations
            reward (float): reward for this step
            done (bool): whether the task is done
            info (dict): extra information

        Raises:
            RuntimeError: if called before @reset / @reset_to built a history
        """
        # check before touching the wrapped env so a misuse does not advance it
        if self.obs_history is None:
            raise RuntimeError("FrameStackWrapper.step called before reset or reset_to")

        obs, r, done, info = self.env.step(action)
        self.update_obs(obs, action=action, reset=False)
        # update frame history
        for k in obs:
            # make sure to have leading dim of 1 for easy concatenation
            self.obs_history[k].append(obs[k][None])
        obs_ret = self._get_stacked_obs_from_history()
        return obs_ret, r, done, info

    def update_obs(self, obs, action=None, reset=False):
        """
        Add the "timesteps" and "actions" entries to @obs in place.

        Args:
            obs (dict): observation to augment (mutated)
            action (np.array): action that produced @obs; read only when @reset is False
            reset (bool): True for the observation that follows a reset, in which case
                "actions" is all zeros
        """
        # NOTE(review): the counter is written before it is incremented, so the
        # observation from the first step carries timestep 0, same as the reset
        # observation. Left as-is; confirm this matches what the policy expects.
        obs["timesteps"] = np.array([self.timestep])
        
        if reset:
            obs["actions"] = np.zeros(self.env.action_dimension)
        else:
            self.timestep += 1
            obs["actions"] = action[: self.env.action_dimension]

    def _to_string(self):
        """Info to pretty print."""
        return "num_frames={}".format(self.num_frames)

In [None]:
# Initialise robomimic's observation utilities from the restored config.
ObsUtils.initialize_obs_utils_with_config(config)

# Read environment metadata and observation shape metadata from the dataset
# that config.train.data points to.
env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=config.train.data)
shape_meta = FileUtils.get_shape_metadata_from_dataset(
    dataset_path=config.train.data,
    all_obs_keys=config.all_obs_keys,
    verbose=True
)

 

In [None]:
# Build the base (un-wrapped) environment from the dataset metadata. Without
# this, `env0` is undefined on a fresh kernel and every later cell fails.
env0 = EnvUtils.create_env_from_metadata(
    env_meta=env_meta,
    render=False,
    render_offscreen=config.experiment.render_video,
    use_image_obs=shape_meta["use_images"],
)

# env = EnvUtils.wrap_env_from_config(env0, config=config) # apply environment wrapper, if applicable
# Wrap it so the policy receives the last `frame_stack` observations at every step.
env = FrameStackWrapper(env0, num_frames=config.train.frame_stack)

In [None]:
# Run 5 rollouts of up to 500 steps each on the frame-stacked env, collecting
# the return of each one.
rs=[]
for i in range(5):
    stats=rollout(policy, env, 500)
    rs.append(stats['Return'])
    print('i=',i,' stats=', stats)
# Mean return over the rollouts (displayed as the cell output).
np.mean(rs)

### one episode

In [None]:
# Start a fresh episode on the frame-stacked env: notify the policy, reset the
# env, and capture the state it was reset into.
policy.start_episode()
obs = env.reset()
state_dict = env.get_state()

# hack that is necessary for robosuite tasks for deterministic action playback
obs = env.reset_to(state_dict)
# Maximum number of steps for the manual loop in the next cells.
horizon = 500

In [None]:
obs['agentview_image'].shape

In [None]:
# Manual rollout on the wrapped env: query the policy, step, and stop on
# `done` or task success, otherwise carry the new observation forward.
for step_i in range(horizon):
    act = policy(ob=obs)
    next_obs, r, done, _ = env.step(act)
    success = env.is_success()["task"]
    if done or success:
        break 
    # update for next iter
    obs = deepcopy(next_obs) 

# Report the final success flag and the index of the last step executed.
print('success: ', success, step_i)

### one episode on the original unwrapped env

In [None]:
def stacked_get_init(init_obs, num_frames):
    """
    Build a frame-stack history from a single initial observation.

    Each key's history is a deque (maxlen @num_frames) filled with @num_frames
    copies of the initial frame, every frame carrying a leading dimension of 1.

    Args:
        init_obs (dict): initial observation, one array per key
        num_frames (int): number of frames to stack

    Returns:
        obs_history (dict): per-key deque of frames
        obs (dict): per-key array with the frames concatenated along axis 0
    """
    obs_history = {}
    obs = {}
    for key in init_obs:
        # add the leading frame axis once and repeat that frame
        frame = init_obs[key][None]
        obs_history[key] = deque([frame] * num_frames, maxlen=num_frames)
        obs[key] = np.concatenate(obs_history[key], axis=0)
    return obs_history, obs

def stacked_add_new(obs_history, new_obs):
    """
    Push a new observation into an existing frame-stack history.

    Keys of @new_obs whose name contains 'timesteps' or 'actions' are ignored;
    every other key is appended (with a leading dimension of 1) to its deque,
    which is mutated in place.

    Args:
        obs_history (dict): per-key deque of frames, as built by stacked_get_init
        new_obs (dict): latest observation, one array per key

    Returns:
        obs_history (dict): the same (mutated) history object
        obs (dict): per-key array with the frames concatenated along axis 0
    """
    ignored_tags = ('timesteps', 'actions')
    for key in new_obs:
        if any(tag in key for tag in ignored_tags):
            continue
        obs_history[key].append(new_obs[key][None])

    stacked = {}
    for key, frames in obs_history.items():
        stacked[key] = np.concatenate(frames, axis=0)
    return obs_history, stacked

In [None]:
# Start a fresh episode directly on the un-wrapped env.
policy.start_episode()
obs = env0.reset()
state_dict = env0.get_state()

# Per-step raw observations and policy actions recorded during the rollout.
all_obs=[]
all_actions=[]

# hack that is necessary for robosuite tasks for deterministic action playback
init_obs = env0.reset_to(state_dict)
horizon = 500
# Use the same stack length as the wrapped rollout so the manually stacked
# observations match what the policy receives there.
num_frames = config.train.frame_stack

all_obs.append(init_obs)

init_obs['agentview_image'].shape

In [None]:
obs_history, obs = stacked_get_init(init_obs, num_frames)
obs['agentview_image'].shape

In [None]:
# Manual rollout on the un-wrapped env, stacking frames by hand and recording
# every action and every intermediate observation.
for step_i in range(horizon):
    act = policy(ob=obs)
    all_actions.append(act)

    next_obs, r, done, info =  env0.step(act)

    # Query success on the env that is actually being stepped in this loop,
    # rather than going through the separate wrapper object.
    success = env0.is_success()["task"]
    if done or success:
        break 

    # The observation of the terminating step is deliberately not recorded, so
    # after an early stop len(all_obs) == len(all_actions).
    all_obs.append(next_obs)
    obs_history, obs = stacked_add_new(obs_history, next_obs)
     
print('success: ', success, step_i)

In [None]:
all_actions=np.stack(all_actions)
all_actions.shape  , len(all_obs)

In [None]:
# Collect the recorded observations into one array per key, with time as the
# leading axis.
all_obss={key: [] for key in all_obs[0].keys()}

for step_obs in all_obs:
    for key in step_obs.keys():
        value = step_obs[key]
        if 'image' in key:
            # Store images with axes reversed. The transposed array is kept in a
            # local only, so `all_obs` is not modified and re-running this cell
            # gives the same result.
            value = value.transpose(2, 1, 0)
        all_obss[key].append(value[None])

for key in all_obss.keys():
    all_obss[key]=np.concatenate(all_obss[key], axis=0)

In [None]:
for key in all_obss.keys():
    print(key, all_obss[key].shape)

### inference test on the rollout

In [None]:
obss=all_obss
action_org = all_actions
T = action_org.shape[0]

action_org.shape , T 

In [None]:
# Take the first recorded observation of the trajectory.
obs_0={k:obss[k][0] for k in obss.keys()}
for key in obs_0.keys():
    if 'image' in key:
        # transpose(2, 1, 0) is its own inverse, so when `obss` is the `all_obss`
        # built above this undoes the axis reversal applied when it was stored.
        obs_0[key]=obs_0[key].transpose(2, 1, 0)

# Build the initial frame stack by repeating the first observation.
init_obs = obs_0
obs_history, obs = stacked_get_init(init_obs, num_frames)
obs['agentview_image'].shape

In [None]:
# Replay the recorded observations through the policy and collect the action
# it predicts at every step.
action_pred=[]
for t in range(T):
    next_obs={k:obss[k][t] for k in obss.keys()}
    for key in next_obs.keys():
        if 'image' in key:
            # reverse the image axes back (same self-inverse transpose used when storing)
            next_obs[key]=next_obs[key].transpose(2, 1, 0)

    # The first step rebuilds the frame stack from scratch; later steps slide it.
    if t==0:
        obs_history, obs = stacked_get_init(next_obs, num_frames)
    else:
        obs_history, obs = stacked_add_new(obs_history, next_obs)
    act = policy(ob=obs)
    action_pred.append(act)
# Stack the per-step actions along a new leading (time) axis.
action_pred = np.stack(action_pred)

In [None]:
action_pred.shape

In [None]:
plt.plot(action_org, color='blue')
plt.plot(action_pred, color='red')
plt.show()

### Inference on offline trajectory

In [None]:
# Load the training / validation datasets described by the config, restricted
# to the observation keys recorded in the shape metadata.
trainset, validset = TrainUtils.load_data_for_training(
    config, obs_keys=shape_meta["all_obs_keys"])
# Number of demonstrations in the training set.
len(trainset.demos)

In [None]:
# Choose which demonstration to replay. The first assignment is overwritten
# immediately, so only 'demo_33' takes effect.
demo_name = 'demo_1'
demo_name = 'demo_33'
# Read the demo group straight from the dataset's open HDF5 file.
demo=trainset.hdf5_file['data'][demo_name]
action_org=demo['actions']
obss = demo['obs']
# Trajectory length, taken from the first dimension of the actions.
T = action_org.shape[0]

action_org.shape , T 

In [None]:
# Take the first observation of the selected demonstration.
obs_0={k:obss[k][0] for k in obss.keys()}
for key in obs_0.keys():
    if 'image' in key:
        # Channel-last -> channel-first while keeping height and width in order
        # (H, W, C) -> (C, H, W). The previous (2, 1, 0) permutation also swapped
        # H and W, i.e. spatially transposed the image, and disagreed with the
        # (0, 3, 1, 2) conversion used by get_trans_obs in this notebook.
        # NOTE(review): no value scaling is applied here - confirm whether the
        # policy expects raw dataset pixel values or normalised ones.
        obs_0[key]=obs_0[key].transpose(2, 0, 1)

# Build the initial frame stack by repeating the first observation.
init_obs = obs_0
obs_history, obs = stacked_get_init(init_obs, num_frames)
obs['agentview_image'].shape

In [None]:
# Replay the demonstration's observations through the policy and collect the
# action it predicts at every step.
action_pred=[]
for t in range(T):
    next_obs={k:obss[k][t] for k in obss.keys()}
    for key in next_obs.keys():
        if 'image' in key:
            # (H, W, C) -> (C, H, W). The previous (2, 1, 0) permutation swapped
            # H and W as well, feeding the policy spatially transposed images.
            # NOTE(review): no value scaling is applied here - confirm whether
            # the policy expects raw dataset pixel values or normalised ones.
            next_obs[key]=next_obs[key].transpose(2, 0, 1)

    # The first step rebuilds the frame stack from scratch; later steps slide it.
    if t==0:
        obs_history, obs = stacked_get_init(next_obs, num_frames)
    else:
        obs_history, obs = stacked_add_new(obs_history, next_obs)
    act = policy(ob=obs)
    action_pred.append(act)
# Stack the per-step actions along a new leading (time) axis.
action_pred = np.stack(action_pred)

In [None]:
action_pred.shape

In [None]:
plt.plot(action_org, color='blue')
plt.plot(action_pred, color='red')
plt.show()

### transformer inference analysis

In [None]:
policy.start_episode()
obs = env.reset()
state_dict = env.get_state()

In [None]:
# hack that is necessary for robosuite tasks for deterministic action playback
obs = env.reset_to(state_dict)

results = {}
video_count = 0  # video frame counter
total_reward = 0.

In [None]:
obs.keys(), obs['agentview_image'].shape, obs['robot0_eef_pos'].shape

In [None]:
obs['robot0_eef_pos']

In [None]:
act = policy(ob=obs)
act.shape

In [None]:
for k in obs.keys():
    print(k, obs[k].shape)

In [None]:
obs['actions']

In [None]:
obs['robot0_eef_pos']

In [None]:
act = policy(ob=obs)
act 

In [None]:
obs, r, done, _ = env.step(act)
obs['robot0_eef_pos']

In [None]:
act = policy(ob=obs)
act 

In [None]:
obs, r, done, _ = env.step(act)
obs['robot0_eef_pos']

In [None]:
act = policy(ob=obs)
obs, r, done, _ = env.step(act)
obs['robot0_eef_pos']

In [None]:
act = policy(ob=obs)
obs, r, done, _ = env.step(act)
obs['robot0_eef_pos']

In [None]:
act = policy(ob=obs)
obs, r, done, _ = env.step(act)
obs['robot0_eef_pos']

### inference on training trajectory

In [None]:
trainset, validset = TrainUtils.load_data_for_training(
    config, obs_keys=shape_meta["all_obs_keys"])
len(trainset.demos)

In [None]:
demo=trainset.hdf5_file['data']['demo_1']
demo.keys()

In [None]:
action_org=demo['actions']
action_org.shape

In [None]:
obss=demo['obs']
N=obss['robot0_eye_in_hand_image'].shape[0]
N, obss.keys()

In [None]:
def get_trans_obs(obs10, t, num_frames=10, obs_source=None):
    """
    Build the stacked observation window that ends at timestep @t.

    Args:
        obs10 (dict or None): window returned by the previous call, or None to
            start a new window (only valid at t == 0)
        t (int): index of the timestep to read from the observation source
        num_frames (int): window length used when a new window is started
        obs_source (mapping or None): per-key observation sequences indexed by
            time; defaults to the notebook-level `obss`

    Returns:
        dict: per-key window converted with TensorUtils.to_tensor. A new window
            is the frame at @t repeated @num_frames times; otherwise the oldest
            frame of @obs10 is dropped and the frame at @t appended.

    Raises:
        ValueError: if @obs10 is None while @t > 0
    """
    if obs_source is None:
        # fall back to the notebook-level trajectory for existing callers
        obs_source = obss

    if obs10 is None and t > 0:
        raise ValueError('please provide obs10 for t>0')

    # Read the frame at time t, adding a leading axis of length 1.
    obs_t = {k: np.expand_dims(obs_source[k][t], axis=0) for k in obs_source.keys()}
    for key in obs_t.keys():
        if 'image' in key:
            # (1, H, W, C) -> (1, C, H, W); assumes channel-last source images - TODO confirm
            obs_t[key] = obs_t[key].transpose(0, 3, 1, 2)

    if obs10 is None:
        # start of a trajectory: the initial frame repeated num_frames times
        window = {k: np.repeat(obs_t[k], num_frames, axis=0) for k in obs_t.keys()}
    else:
        # slide the window: drop the oldest frame and append the newest one.
        # obs10 holds the converted output of the previous call; np.concatenate
        # is relied on to accept those values alongside numpy arrays.
        window = {k: np.concatenate([obs10[k][1:], obs_t[k]], axis=0) for k in obs10.keys()}

    return TensorUtils.to_tensor(window)

In [None]:
t=0
# obs_t={k:np.expand_dims(obss[k][t], axis=0) for k in obss.keys()}
# obs10={k:np.repeat(obs_t[k], 10, axis=0) for k in obs_t.keys()}        #initial repeated 10 times

obs10=get_trans_obs(None, 0)
for key in obs10.keys():
    print(key, obs10[key].shape) 
  

In [None]:
obs10['robot0_eef_pos']

In [None]:
obs10=get_trans_obs(obs10, 1)

In [None]:
obs10['robot0_eef_pos']

In [None]:
obs10=get_trans_obs(obs10, 2)
obs10['robot0_eef_pos']

In [None]:
obs10=get_trans_obs(obs10, 3)
obs10['robot0_eef_pos']

In [None]:
act = policy(ob=obs10)

In [None]:
act

In [None]:
# Replay every timestep of the demo through the policy using a sliding window.
obss=demo['obs']
N=obss['robot0_eye_in_hand_image'].shape[0] 

action_pred=[]
# Start with no window; get_trans_obs builds the initial one at t == 0. (The
# earlier warm-up call whose result was immediately discarded is removed.)
obs10=None
for t in range(N):
    obs10=get_trans_obs(obs10, t)
    act = policy(ob=obs10)
    action_pred.append(act)

# One row per timestep.
action_pred=np.vstack(action_pred)
action_pred.shape

In [None]:
# obs10['robot0_eef_pos']
action_pred[0]

In [None]:
plt.plot(action_org, color='blue')
plt.plot(action_pred, color='red')
plt.show()

In [None]:
# plot the actions: one figure per action dimension, dataset vs. predicted

import matplotlib.pyplot as plt

# Derive the number of channels from the predictions instead of hard-coding it.
for i in range(action_pred.shape[1]):
    plt.plot(action_org[:,i], label='org')
    plt.plot(action_pred[:,i], label='pred')
    plt.title('action dim {}'.format(i))
    plt.legend()
    plt.show()