In [None]:

import copy
import os.path
import pickle
from os.path import join

import cv2
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from PIL import ImageDraw
from matplotlib import cm

# Expose the openpi helper scripts through PYTHONPATH.
# NOTE(review): this is a machine-specific absolute path; consider making it configurable.
# NOTE(review): PYTHONPATH is normally consumed at interpreter start-up, so this presumably
# only affects processes spawned from this kernel, not the kernel's own sys.path - confirm.
script_path = "/home/shady/code/openpi/scripts"
# os.environ.get avoids the KeyError that `+=` raises when PYTHONPATH is not set at all.
_current_pythonpath = os.environ.get("PYTHONPATH", "")
if script_path not in _current_pythonpath.split(os.pathsep):
    os.environ["PYTHONPATH"] = (
        f"{_current_pythonpath}{os.pathsep}{script_path}" if _current_pythonpath else script_path
    )
import jax.experimental
import jax.numpy as jnp
from PIL import Image

from openpi import EXP_DATA_PATH
from openpi.models.pi0 import Pi0

# Ensure the helper scripts directory is on PYTHONPATH before importing from `scripts`.
# Safe when PYTHONPATH is unset (no KeyError) and idempotent, so re-running the cell
# does not append the same entry again.
script_path = "/home/shady/code/openpi/scripts"
_pythonpath_before_scripts = os.environ.get("PYTHONPATH", "")
if script_path not in _pythonpath_before_scripts.split(os.pathsep):
    os.environ["PYTHONPATH"] = (
        f"{_pythonpath_before_scripts}{os.pathsep}{script_path}"
        if _pythonpath_before_scripts
        else script_path
    )
from scripts.text_latent import Checkpoint, Args, _model, _config, maybe_download, create_dataloader


def restore_img(img) -> "Image.Image":
    """Convert an image with values nominally in [-1, 1] into an 8-bit PIL image.

    The input is shifted and scaled with ``(img + 1) / 2 * 255`` and cast to
    ``uint8``. Values that fall outside [0, 255] after scaling are saturated
    rather than wrapped around by the integer cast.

    Args:
        img: Array-like image supporting ``img + 1``.

    Returns:
        The PIL image built from the ``uint8`` pixel array.
    """
    float_img = np.asarray(img + 1, dtype=np.float64) / 2 * 255
    # Clip first: casting e.g. 280.5 or -12.75 straight to uint8 does not saturate.
    int_img = np.clip(float_img, 0, 255).astype(np.uint8)
    return Image.fromarray(int_img)


def normalize(vectors, eps=1e-12):
    """Scale vectors to unit L2 norm along the last axis.

    Args:
        vectors: Array whose last axis holds the vector components.
        eps: Lower bound applied to the norm before dividing, so an all-zero
            vector maps to zeros instead of NaN. Vectors whose norm exceeds
            ``eps`` are unaffected.

    Returns:
        Array of the same shape with each last-axis vector divided by its norm.
    """
    norms = jnp.linalg.norm(vectors, axis=-1, keepdims=True)
    return vectors / jnp.maximum(norms, eps)


def patch_idx_to_hw(patch_idx, num_patch=16, patch_size=14):
    """Map a flat patch index to the pixel offset of that patch.

    The index is interpreted row-major over a grid with ``num_patch`` patches
    per row; the grid row and column are then scaled by ``patch_size``.

    Returns:
        ``(h, w)`` as Python ints: vertical and horizontal pixel offsets.
    """
    row, col = divmod(patch_idx, num_patch)
    return int(row * patch_size), int(col * patch_size)


def draw_box_on_image(image, top_left_corner, box_width=14, box_height=14, color='red', thickness=1):
    """Outline a rectangle on ``image`` in place and return the same image.

    Args:
        image: Image object accepted by ``ImageDraw.Draw``; it is drawn on directly.
        top_left_corner: Indexable ``(top, left)`` pair, i.e. row first, column second.
        box_width: Horizontal extent of the box.
        box_height: Vertical extent of the box.
        color: Outline colour.
        thickness: Outline width.

    Returns:
        The ``image`` that was passed in.
    """
    top, left = top_left_corner[0], top_left_corner[1]
    # The box is handed to the drawer as (left, top, right, bottom).
    ImageDraw.Draw(image).rectangle(
        (left, top, left + box_width, top + box_height),
        outline=color,
        width=thickness,
    )
    return image


In [None]:
# Load the averaged text latents for each libero goal task.
text_latent_path = join(EXP_DATA_PATH, "pi0")
libero_goal_task_id = list(range(10, 40))
text_latents = {}
# The loop variable is deliberately not called `id`, which would shadow the builtin
# `id()` for every later cell of the notebook.
for goal_task_id in libero_goal_task_id:
    file_path = join(text_latent_path, f"avg_states_{goal_task_id}_{goal_task_id + 1}_frame_0_119.pkl")
    # NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
    with open(file_path, "rb") as f:
        data = pickle.load(f)
    text_latents[goal_task_id] = data["hidden_states_avg"]

# Resolve the training/data configuration and the checkpoint for the policy.
policy = "pi0"
args = Args()
args.policy = Checkpoint(config=f"{policy}_libero", dir=f"s3://openpi-assets/checkpoints/{policy}_libero")
train_config = _config.get_config(args.policy.config)
data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
ckpt_dir = maybe_download(args.policy.dir)

# Load the model with bfloat16 parameters and jit-compile its prefix embedding.
model: Pi0 = train_config.model.load(_model.restore_params(ckpt_dir / "params", dtype=jnp.bfloat16))
embed_prefix_jit = jax.jit(model.embed_prefix)


def embed_prefix(observation):
    """Preprocess ``observation`` with ``train=False`` and run the jitted prefix embedding.

    Returns whatever ``embed_prefix_jit`` returns for the preprocessed observation.
    """
    # NOTE(review): the leading ``None`` is presumably the RNG argument, unused when not training - confirm.
    preprocessed = _model.preprocess_observation(None, observation, train=False)
    return embed_prefix_jit(preprocessed)


In [None]:
def decode(obs):
    """Decode token ids to text with the tokenizer of the current ``data_loader``.

    ``obs`` is either a list of token ids or an observation, in which case the
    first row of its ``tokenized_prompt`` is decoded.
    """
    token_ids = obs if isinstance(obs, list) else obs.tokenized_prompt[0].tolist()
    # `data_loader` is looked up at call time, i.e. the loader of the task being processed.
    tokenizer = data_loader.torch_loader.dataset._transform.transforms[-1].tokenizer._tokenizer
    return tokenizer.decode(token_ids)


# One entry per task: (observations, prompt length, decoded prompt tokens, number of observations).
all_obs = []
episode_to_use_for_collection = 1
for task_id in range(10, 40):
    task_range = (task_id, task_id + 1)
    data_loader, dataset_meta = create_dataloader(
        train_config, data_config, ckpt_dir, task_range, 1, episode_to_use_for_collection, True
    )
    batches = list(data_loader)
    obs = [_model.Observation.from_dict(element) for element in batches]
    base_images = [frame.images["base_0_rgb"][0] for frame in obs]
    print(f"Task {task_id}: {dataset_meta.tasks[task_id]}")

    # Decode the prompt token by token and keep only the positions marked valid by the mask.
    decoded = [decode([token.item()]) for token in obs[0].tokenized_prompt[0]]
    num_valid_tokens = obs[0].tokenized_prompt_mask.sum()
    prompt = decoded[:num_valid_tokens]
    print(f"Tokenized prompt: {prompt}")
    # NOTE(review): this prints the decoded tokens again rather than indices - confirm the intent.
    print(f"Token index: {prompt[:num_valid_tokens]}")

    prompt_len = len(prompt)
    all_obs.append((obs, prompt_len, prompt, len(obs)))


In [None]:
# test
def get_max_sim_patch(task_start, task_id, all_obs, obs_idx=0):
    """Overlay a text-latent / image-patch similarity heatmap on one frame.

    The cosine similarity between the task's summed text latent and the first
    256 prefix embeddings of the chosen frame is computed, reduced with a max
    over the leading axis of the similarity (skipping its first row), laid out
    as a 16x16 grid, min-max normalised, upsampled to the frame size and
    alpha-blended onto the frame with the jet colormap.

    Args:
        task_start: Task id stored at ``all_obs[0]``.
        task_id: Task to visualise; also the key into the global ``text_latents``.
        all_obs: Per-task tuples whose first two items are the list of
            observations and the prompt length.
        obs_idx: Index of the frame within the task's observations.

    Returns:
        PIL image of the frame with the heatmap blended on top.

    Raises:
        IndexError: If ``task_id`` lies outside the range covered by ``all_obs``.
    """
    task_offset = task_id - task_start
    # Reject negative offsets explicitly; they would otherwise wrap to another task.
    if not 0 <= task_offset < len(all_obs):
        raise IndexError(
            f"task_id {task_id} is outside [{task_start}, {task_start + len(all_obs) - 1}]"
        )
    frames, prompt_len = all_obs[task_offset][:2]
    # Copy only the frame that is used instead of deep-copying the whole episode per call.
    observation = copy.deepcopy(frames[obs_idx])

    embeddings, _, _ = embed_prefix(observation)
    image_embedding = embeddings[0, :256]
    # NOTE(review): the slice presumably skips three 256-token image streams plus the first
    # three prompt tokens and stops at the end of the prompt - confirm against the prefix layout.
    text_latent = text_latents[task_id][:, 256 * 3 + 3:256 * 3 + prompt_len].sum(-2)

    similarity = jnp.dot(normalize(text_latent), normalize(image_embedding).T)
    # Keep float32 precision without hard-coding the size of the leading axis.
    similarity = similarity.astype(jnp.float32)
    patch_similarity = jnp.max(similarity[1:], axis=0)

    # Map the frame from [-1, 1] to uint8, saturating values that fall outside the range.
    float_img = np.asarray(observation.images["base_0_rgb"][0] + 1, dtype=np.float64) / 2 * 255
    original_image = np.clip(float_img, 0, 255).astype(np.uint8)

    similarity_map = np.asarray(patch_similarity).reshape(16, 16)
    map_min = similarity_map.min()
    value_range = similarity_map.max() - map_min
    if value_range > 0:
        norm_attention_map = (similarity_map - map_min) / value_range
    else:
        # A constant (or NaN) map has no contrast; avoid dividing by zero.
        norm_attention_map = np.zeros_like(similarity_map)

    # cv2.resize takes (width, height); match the frame so addWeighted sees equal sizes.
    height, width = original_image.shape[:2]
    resized_attention_map = cv2.resize(norm_attention_map, (width, height), interpolation=cv2.INTER_LINEAR)
    heatmap_rgba = cm.jet(resized_attention_map)
    # Dropping the alpha channel leaves RGB, the same channel order as the frame.
    heatmap_rgb = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8)
    alpha = 0.3
    blended_image = cv2.addWeighted(heatmap_rgb, alpha, original_image, 1 - alpha, 0)
    return Image.fromarray(blended_image)


def plot_dynamics(task_id, task_start=10, frame_stride=5):
    """Show an interactive slider over similarity heatmaps sampled across a task's frames.

    Args:
        task_id: Task to visualise.
        task_start: Task id stored at ``all_obs[0]``.
        frame_stride: Step between the frames that are rendered.

    Raises:
        IndexError: If ``task_id`` lies outside the range covered by ``all_obs``.
        ValueError: If the task has no frames to render.
    """
    task_offset = task_id - task_start
    # Reject negative offsets explicitly; they would otherwise wrap to another task.
    if not 0 <= task_offset < len(all_obs):
        raise IndexError(
            f"task_id {task_id} is outside [{task_start}, {task_start + len(all_obs) - 1}]"
        )
    task_obs = all_obs[task_offset]
    print(f"Task {task_id}: {task_obs[2]}")

    # The last tuple item is the number of observations recorded for the task.
    images = [
        get_max_sim_patch(task_start, task_id, all_obs, obs_idx=t)
        for t in range(0, task_obs[-1], frame_stride)
    ]
    if not images:
        raise ValueError(f"task {task_id} has no frames to display")

    def show_image(index):
        """Render the heatmap overlay selected by the slider."""
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.imshow(images[index])
        ax.axis('off')
        plt.show()

    slider = widgets.IntSlider(value=0, min=0, max=len(images) - 1, step=1, description='Image:')
    widgets.interact(show_image, index=slider)

In [None]:
# Interactive similarity-heatmap viewer for task 22.
plot_dynamics(22)

In [None]:
# Interactive similarity-heatmap viewer for task 29.
plot_dynamics(29)

In [None]:
# Interactive similarity-heatmap viewer for task 23.
plot_dynamics(23)

In [None]:
# Interactive similarity-heatmap viewer for task 27.
plot_dynamics(27)