### Imports

In [1]:
import warnings 
warnings.filterwarnings('ignore')
import sys
import os
from tqdm import tqdm
# Restrict JAX to the CPU backend; set before the imports below, which presumably load JAX.
os.environ["JAX_PLATFORMS"] = "cpu"
from notebooks.icvf_helper import *
import matplotlib.pyplot as plt
import matplotlib.animation as ani
from IPython.display import HTML


2023-10-27 18:31:38.165788: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-27 18:31:38.165872: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-27 18:31:38.165892: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


### Utility Functions and Visualizations

In [2]:
def get_simple_tk6_trajs(num_from_each=5):
    """Load the leading trajectories of the three "simple" ToyKitchen6 task files.

    Args:
        num_from_each: how many entries to keep from the start of each file.

    Returns:
        A single array made by concatenating the kept entries of every file,
        in the order the files are listed below.
    """
    task_files = (
        'gs://multi-robot-bucket/data/robonetv2/toykitchen_simple_0530/toykitchen6/put_cucumber_in_orange_pot_simple/train/out.npy',
        'gs://multi-robot-bucket/data/robonetv2/toykitchen_simple_0530/toykitchen6/put_sweet_potato_on_plate_simple/train/out.npy',
        'gs://multi-robot-bucket/data/robonetv2/toykitchen_simple_0530/toykitchen6/take_croissant_out_of_colander_simple/train/out.npy',
    )

    def _load_head(path):
        # allow_pickle=True unpickles arbitrary objects stored in the file,
        # so this must only ever be pointed at trusted data.
        with tf.io.gfile.GFile(path, 'rb') as handle:
            return np.load(handle, allow_pickle=True)[:num_from_each]

    return np.concatenate([_load_head(path) for path in task_files])

def get_discounted_rtg(rews, gamma=0.99):
    """Compute the discounted reward-to-go for every step of a reward sequence.

    The recursion is rtg[t] = rews[t] + gamma * rtg[t + 1], with the value
    past the final step taken to be 0.

    Args:
        rews: sequence of per-step rewards.
        gamma: discount factor.

    Returns:
        A list with one reward-to-go per entry of `rews`, in the same order.
    """
    backward = []
    running = 0
    # Walk from the last step to the first so each step can reuse its successor.
    for step in range(len(rews) - 1, -1, -1):
        running = rews[step] + gamma * running
        backward.append(running)
    backward.reverse()
    return backward

def get_agent(encoder_checkpoint):
    """Build a learner from a pretrained encoder checkpoint.

    Args:
        encoder_checkpoint: path handed to `prep_learner` as
            `pretrained_encoder_path`.

    Returns:
        Whatever `prep_learner` returns for a "resnetv2-50-1" encoder.
    """
    learner_kwargs = dict(
        pretrained_encoder_path=encoder_checkpoint,
        encoder_type="resnetv2-50-1",
    )
    return prep_learner(**learner_kwargs)

def get_values(unreplicated_agent, trajs, goal_traj_idxs, lim=20):
    """Evaluate the agent's value metrics on the first `lim` trajectories.

    Args:
        unreplicated_agent: agent passed through to `get_value_metrics`.
        trajs: sequence of trajectories; only `trajs[:lim]` is evaluated.
        goal_traj_idxs: per-trajectory goal indices; entry `i` is passed to
            `build_batch` as `custom_goal_traj_idxs` for trajectory `i`.
        lim: maximum number of trajectories to evaluate.

    Returns:
        A tuple `(trajs, all_values)`. `trajs` is the input returned as-is
        (NOT truncated to `lim`), while `all_values` holds one
        `get_value_metrics` result per evaluated trajectory.
    """
    # Import under a private alias. The notebook's import cell binds the name
    # `tqdm` to the class (`from tqdm import tqdm`), so the previous
    # `tqdm.tqdm(...)` call only worked when a later star-import happened to
    # rebind `tqdm` to the module. The alias works in either case.
    from tqdm import tqdm as _progress_bar

    all_values = []
    for i, traj in enumerate(_progress_bar(trajs[:lim])):
        batch = build_batch(target_traj=traj, goal='custom', custom_goal_traj_idxs=goal_traj_idxs[i])
        all_values.append(get_value_metrics(unreplicated_agent, batch))
    return trajs, all_values

def plot_dynamic_value_traj(img_traj, values, goal_idxs, reward_func=None):
    """Animate a value curve alongside the current frame and its goal frame.

    Args:
        img_traj: sequence of images, one per frame of the trajectory.
        values: one value per frame; stacked with the frame indices, so it
            must have the same length as `img_traj`.
        goal_idxs: for every frame, the index into `img_traj` of the goal
            image to show for that frame. Each distinct index is also marked
            with a red vertical line on the curve panels.
        reward_func: optional sequence of rewards. When given, a fourth panel
            plots the rewards together with their discounted reward-to-go.

    Returns:
        An `IPython.display.HTML` object embedding the animation as an HTML5
        video.
    """
    n_frames = len(img_traj)
    # goal_idxs carries one entry per frame, so the same index usually repeats
    # many times; mark each distinct goal once instead of stacking lines.
    marker_idxs = sorted(set(goal_idxs))

    if reward_func is not None:
        # Four panels: value curve, current frame, goal frame, reward curve.
        fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(14, 3.5))
        ax4.set_title('Reward Trajectory')
        ax4.set_xlabel('Index')
        ax4.set_ylabel('Returns')
        for idx in marker_idxs:
            ax4.axvline(x=idx, color='r')
        ax4.set_xlim(0, n_frames - 1)
        ax4.plot(reward_func, label='Reward')
        ax4.plot(get_discounted_rtg(reward_func), label='TD Targets (RTG)')
        ax4.legend()
    else:
        # Three panels: value curve, current frame, goal frame.
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 3.5))

    # The value curve starts empty and is extended frame by frame in `update`.
    line = ax1.plot([], [])[0]
    ax1.set_title('Value Trajectory')
    ax1.set_xlabel('Index')
    ax1.set_ylabel('Value')
    ax1.set_xlim(0, n_frames - 1)
    ax1.set_ylim(-15, 15)
    for idx in marker_idxs:
        ax1.axvline(x=idx, color='r')

    def _show_image(ax, image, title):
        # ax.clear() resets the title and axis state, so both are re-applied
        # on every frame; otherwise the titles vanish after the first frame.
        ax.clear()
        ax.axis('off')
        ax.set_title(title)
        ax.imshow(image)

    def update(num, line, data):
        """Draw animation frame `num`: extend the curve and swap both images."""
        line.set_data(data[..., :num])
        _show_image(ax2, img_traj[num], 'Image Trajectory')
        _show_image(ax3, img_traj[goal_idxs[num]], 'Goal Image')

    # Row 0 holds the frame indices, row 1 the values.
    data = np.array([range(n_frames), values])
    a = ani.FuncAnimation(fig, update, frames=n_frames, fargs=(line, data))
    video = HTML(a.to_html5_video())
    # The video is fully rendered at this point; close the figure so the
    # notebook does not also emit a static copy and figures do not pile up.
    plt.close(fig)
    return video

### Load Model Checkpoint

In [3]:
%%capture
# Checkpoint paths keyed by a short run label.
# NOTE(review): this is an absolute, user-specific NFS path, so the notebook
# only runs where that mount exists — consider making it configurable.
checkpoints = {
    'no-tk6': '/nfs/nfs2/users/nrdas/experiment_output/test_icvf_resnetv2-50-1_multilinear_all_but_tk6_20231024_235013/checkpoint500000'
}
# Judging by the run name ("all_but_tk6"), this model was trained without
# ToyKitchen6 data — TODO confirm.
encoder_checkpoint = checkpoints['no-tk6']
# Agent reused by every evaluation cell below.
unreplicated_agent = get_agent(encoder_checkpoint)

In [10]:
def compute_subgoals_offline(final_value_traj, mode='uniform', goal_density=6):
    """Choose subgoal frame indices for a trajectory.

    In 'uniform' mode only the length of `final_value_traj` is used: every
    `goal_density`-th frame after frame 0 becomes a subgoal, and the final
    frame is always the last subgoal.

    Args:
        final_value_traj: per-frame values of the trajectory.
        mode: 'uniform' is the only implemented strategy.
        goal_density: spacing, in frames, between consecutive subgoals.

    Returns:
        A strictly increasing list of frame indices ending at the last frame
        (empty for an empty trajectory).

    Raises:
        NotImplementedError: for any mode other than 'uniform'.
        ValueError: if `goal_density` is smaller than 1.
    """
    if mode != 'uniform':
        # Covers the planned-but-unwritten 'gradient' mode as well.
        raise NotImplementedError(f"subgoal mode {mode!r} is not implemented")
    if goal_density < 1:
        raise ValueError("goal_density must be at least 1")

    n_frames = len(final_value_traj)
    if n_frames == 0:
        return []

    last_idx = n_frames - 1
    # Start at goal_density rather than 0: the first frame is never a subgoal.
    idxs = list(range(goal_density, n_frames, goal_density))
    # Append the final frame only when the uniform grid did not already land
    # on it; appending unconditionally used to duplicate the last subgoal.
    if not idxs or idxs[-1] != last_idx:
        idxs.append(last_idx)
    return idxs

def gen_idxs(subgoal_idxs, trajlen):
    """Map every frame to the first listed subgoal that is at or after it.

    Args:
        subgoal_idxs: subgoal frame indices, scanned in the given order.
        trajlen: number of frames in the trajectory.

    Returns:
        A list of subgoal indices, one per frame in order. Frames for which no
        subgoal is at or after them contribute no entry, so the result can be
        shorter than `trajlen`.
    """
    per_frame = []
    for frame in range(trajlen):
        upcoming = next((goal for goal in subgoal_idxs if goal >= frame), None)
        if upcoming is not None:
            per_frame.append(upcoming)
    return per_frame

def _value_gap(target_value, start_value):
    """Return |target - start|, or 1 when the two values coincide."""
    gap = np.abs(target_value - start_value)
    # A zero gap would make the normalised progress term divide by zero;
    # falling back to 1 leaves the raw value difference unnormalised instead.
    return gap if gap != 0 else 1


def reward_design(subgoal_idxs, icvf_traj, mode='progress', alpha=1, beta=1, gamma=2):
    """Turn a per-frame value trajectory into per-transition rewards.

    Modes:
        'vanilla':  the reward of transition i is the value at frame i, with
                    `gamma` added to the last transition. `subgoal_idxs` is
                    not used.
        'progress': the reward is `alpha` times the change in value across the
                    transition, divided by the absolute value gap of the
                    current subgoal segment. `beta` is added on a transition
                    that starts at a subgoal frame, `gamma` on the transition
                    that lands on the final subgoal.
        'optimal':  sparse rewards only: `beta` on a transition that starts at
                    a subgoal frame, `gamma` on the one that lands on the
                    final subgoal, 0 elsewhere.

    NOTE(review): in 'progress' and 'optimal' the subgoal bonus takes
    precedence. If a subgoal frame sits immediately before the final subgoal,
    that transition gets `beta` and `gamma` is never paid. Confirm that this
    is intended.

    Args:
        subgoal_idxs: increasing subgoal frame indices; the last entry is
            treated as the final goal.
        icvf_traj: per-frame values. It is never modified.
        mode: 'vanilla', 'progress' or 'optimal'.
        alpha: scale of the value-progress term ('progress' only).
        beta: bonus for a transition that starts at a subgoal frame.
        gamma: bonus for reaching the end of the trajectory / final subgoal.

    Returns:
        A list with one reward per transition, i.e. `len(icvf_traj) - 1`
        entries.

    Raises:
        NotImplementedError: for an unknown `mode`.
    """
    if mode == 'vanilla':
        # Copy into a fresh list: slicing a NumPy array yields a view, so the
        # in-place bonus below used to leak into the caller's array.
        rews = list(icvf_traj[:-1])
        if rews:
            rews[-1] += gamma
    elif mode == 'progress':
        rews = []
        n_goals = len(subgoal_idxs)
        final_goal = subgoal_idxs[-1]
        curr_goal_idx = 0
        curr_norm_factor = _value_gap(icvf_traj[subgoal_idxs[0]], icvf_traj[1])
        for i in range(len(icvf_traj) - 1):
            value_change = icvf_traj[i + 1] - icvf_traj[i]
            if curr_goal_idx < n_goals and i == subgoal_idxs[curr_goal_idx]:
                # This transition leaves a subgoal frame: move on to the next
                # segment and renormalise against its value gap. When no
                # subgoal is left, keep the previous normaliser instead of
                # indexing past the end of subgoal_idxs.
                curr_goal_idx += 1
                if curr_goal_idx < n_goals:
                    curr_norm_factor = _value_gap(
                        icvf_traj[subgoal_idxs[curr_goal_idx]], icvf_traj[i + 1])
                r = beta
                r += alpha * value_change / curr_norm_factor
            elif i + 1 == final_goal:
                r = gamma
                r += alpha * value_change / curr_norm_factor
            else:
                r = alpha * value_change / curr_norm_factor
            rews.append(r)
    elif mode == 'optimal':
        rews = []
        final_goal = subgoal_idxs[-1]
        for i in range(len(icvf_traj) - 1):
            r = 0
            if i in subgoal_idxs:
                r += beta
            elif i + 1 == final_goal:
                r += gamma
            rews.append(r)
    else:
        raise NotImplementedError(f"reward mode {mode!r} is not implemented")
    return rews

### Run model on trajectories from ToyKitchen6 (held-out from model). Use final goal

In [5]:
# Load the leading trajectories from each ToyKitchen6 task file.
trajectories = get_simple_tk6_trajs()
# Number of frames in each trajectory.
trajlens = [len(t['observations']) for t in trajectories]
# One goal index per frame; -1 is presumably read as "last frame of the
# trajectory", i.e. the final goal — TODO confirm against build_batch.
goal_t_idxs = [[-1]*l for l in trajlens]
# lim=1: only the first trajectory is evaluated, so all_values has a single
# entry, while trajs comes back as the full, untruncated input.
trajs, all_values = get_values(
    unreplicated_agent=unreplicated_agent, 
    trajs=trajectories,
    goal_traj_idxs=goal_t_idxs,
    lim=1)

100%|██████████| 1/1 [00:34<00:00, 34.05s/it]


### Current System for Shaping

In [6]:
# Animate the first trajectory with the 'vanilla' reward design, using the
# final frame as the goal for every step.
idx = 0
t, v = trajs[idx], all_values[idx]
end_idx = len(t['observations']) - 1
# One image per observation, taken from its 'images0' entry.
t_img = [obs['images0'] for obs in t['observations']]
plot_dynamic_value_traj(
    t_img,
    list(v['v']),
    goal_idxs=goal_t_idxs[idx],
    reward_func=reward_design([end_idx] * len(v['v']), list(v['v']), mode='vanilla'),
)

### Run model on trajectories from ToyKitchen6 (held-out from model). Use custom goal indices

In [7]:
# Pick subgoal indices for every evaluated trajectory. all_values has one entry
# per evaluated trajectory (lim=1 above), so subgoals has a single entry too.
# NOTE: all_values is rebound by the get_values call below, so re-running this
# cell feeds subgoal-conditioned values in here. In 'uniform' mode only the
# trajectory length is used, so the chosen indices do not change.
subgoals = [compute_subgoals_offline(v['v']) for v in all_values] # using final goal values
# Expand the subgoals into one goal index per frame; zip stops at the shorter
# input, so only trajectories that have subgoals are covered.
subgoal_trajidxs = [gen_idxs(s, len(t['observations'])) for s, t in zip(subgoals, trajs)]
print("using goal-idxs for 0th as", subgoals[0])
# Re-evaluate the same trajectory, now conditioned on the per-frame subgoals.
# This overwrites trajs and all_values from the final-goal run.
trajs, all_values = get_values(
    unreplicated_agent=unreplicated_agent, 
    trajs=trajectories,
    goal_traj_idxs=subgoal_trajidxs,
    lim=1)

using goal-idxs for 0th as [6, 12, 18, 22]


100%|██████████| 1/1 [00:01<00:00,  1.42s/it]


### ICVF Subgoal Values and Rewards (ICVF subgoal TD)

In [17]:
# Animate the first trajectory with subgoal-conditioned values and the default
# ('progress') reward design.
idx = 0
t, v = trajs[idx], all_values[idx]
# TODO: add optimal overlay
# One image per observation, taken from its 'images0' entry.
t_img = [obs['images0'] for obs in t['observations']]
plot_dynamic_value_traj(
    t_img,
    list(v['v']),
    goal_idxs=subgoal_trajidxs[idx],
    reward_func=reward_design(subgoals[idx], list(v['v']), gamma=2),
)

### Raw Subgoal Values and Rewards (optimal subgoal TD)

In [9]:
# Relies on idx, v and t_img left behind by the previous plotting cell.
# NOTE(review): with beta=0 no subgoal bonus is paid, so for subgoals such as
# [6, 12, 18, 22] these rewards equal the sparse final-goal rewards of the next
# section — confirm that this is intended.
opt_rews = reward_design(subgoals[idx], list(v['v']), mode='optimal', beta=0)
# Rewards are per transition (one fewer than frames); the trailing 0 pads the
# reward-to-go to one entry per value in v['v'].
opt_values = get_discounted_rtg(opt_rews + [0])
plot_dynamic_value_traj(t_img, list(opt_values), goal_idxs=subgoal_trajidxs[idx], reward_func=opt_rews)

### Raw Sparse-Reward-based Values (optimal TD)

In [10]:
# Relies on v and t_img left behind by an earlier plotting cell.
# The only subgoal is the final frame, so the sole non-zero reward is gamma on
# the transition that lands on it.
end_idx = len(v['v']) - 1
opt_rews = reward_design([end_idx], list(v['v']), mode='optimal')
# Rewards are per transition (one fewer than frames); the trailing 0 pads the
# reward-to-go to one entry per value in v['v'].
opt_values = get_discounted_rtg(opt_rews + [0])
plot_dynamic_value_traj(t_img, list(opt_values), goal_idxs=[end_idx]*len(v['v']), reward_func=opt_rews)

In [None]:
# TODO: try a reward of -1 at every step followed by +10 at the final step