In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

import os
import cv2
import uuid
import datetime
import numpy as np
import compress_pickle as cpkl

from ss_baselines.av_nav.config import get_config
from ss_baselines.savi.config.default import get_config as get_savi_config
from ss_baselines.common.env_utils import construct_envs
from ss_baselines.common.environments import get_env_class
from ss_baselines.common.utils import plot_top_down_map

# Helper / tools
from soundspaces.mp3d_utils import CATEGORY_INDEX_MAPPING
def get_category_name(idx, mapping=None):
    """Return the category name whose index equals ``idx``.

    Args:
        idx: Category index to look up (e.g. the result of an ``argmax``).
        mapping: Optional ``{category_name: index}`` dict to search. When
            omitted, the module-level ``CATEGORY_INDEX_MAPPING`` is used.

    Returns:
        The first key of ``mapping`` whose value equals ``idx``.

    Raises:
        AssertionError: If no entry of ``mapping`` has the index ``idx``.
    """
    if mapping is None:
        mapping = CATEGORY_INDEX_MAPPING

    for name, category_idx in mapping.items():
        if category_idx == idx:
            return name

    # Validity is derived from the mapping itself rather than a hard-coded
    # 0..20 range, and an unknown index fails loudly instead of returning None.
    raise AssertionError(f"Invalid category index number: {idx}")

def get_current_ep_category_label(obs_dict):
    """Return the category name for the observation dict ``obs_dict``.

    The position of the largest entry of ``obs_dict["category"]`` is used as
    the category index and resolved through ``get_category_name``.
    """
    category_idx = obs_dict["category"].argmax()
    return get_category_name(category_idx)

In [None]:
from pprint import pprint

DATASET_DIR_PATH = "SAVI_Oracle_Dataset_v0"
# DATASET_DIR_PATH = "SAVI_Oracle_Dataset_v0_10K" # Smaller scale dataset for tests

# Load the dataset statistics file and pretty-print its content.
dataset_stats_filepath = f"{DATASET_DIR_PATH}/dataset_statistics.bz2"
with open(dataset_stats_filepath, "rb") as stats_file:
    r__dataset_stats = cpkl.load(stats_file)

pprint(r__dataset_stats)

### Plotting some stats about the dataset

In [None]:
# Frequency of category counts
# Sanity check: the per-category episode counts must add up to the total episode count.
total_episodes = np.sum(list(r__dataset_stats["category_counts"].values()))
assert total_episodes == r__dataset_stats["total_episodes"], \
    "Category counts not matching episode counts."

# Fraction of episodes per target category.
category_probs = {k: v / total_episodes for k, v in r__dataset_stats["category_counts"].items()}
# pprint(category_probs)

fig, ax = plt.subplots(1,1, figsize=(6,3), dpi=200)
x_labels = list(category_probs.keys())
x_heights = list(category_probs.values())
# One bar per category present in the statistics. The bar positions are derived
# from the labels instead of a fixed count of 21, so that x, heights and labels
# always have matching lengths.
x = list(range(len(x_labels)))
ax.bar(x=x, height=x_heights, tick_label=x_labels)
ax.set_xticklabels(x_labels, rotation=45, ha="right")
fig.suptitle(f"Dset: {DATASET_DIR_PATH} | # steps: {r__dataset_stats['total_steps']} | # eps: {r__dataset_stats['total_episodes']}", fontsize=7)
fig.show()

In [None]:
# Frequency of scenes in the dataset
scene_counts = r__dataset_stats["scene_counts"]
total_episodes = np.sum(list(scene_counts.values()))
assert total_episodes == r__dataset_stats["total_episodes"], \
    "Scene counts not matching episode counts."
# Fraction of episodes per scene.
scene_probs = {scene: count / total_episodes for scene, count in scene_counts.items()}

# Labels, heights and bar positions all follow the iteration order of scene_probs.
x_labels = list(scene_probs.keys())
x_heights = list(scene_probs.values())
n_scenes = len(x_labels)
x = list(range(n_scenes))

fig, ax = plt.subplots(1,1, figsize=(6 * 3, 6), dpi=200)
ax.bar(x=x, height=x_heights, tick_label=x_labels)
ax.set_xticklabels(x_labels, rotation=45, ha="right")
fig.suptitle(f"Dset: {DATASET_DIR_PATH} | # steps: {r__dataset_stats['total_steps']} | # eps: {r__dataset_stats['total_episodes']}", fontsize=18)
fig.show()

In [None]:
# # Frequency of the episode lengths
# all_ep_lengths = []
# ep_lengths_dict = {}
# for ep_filename in os.listdir(DATASET_DIR_PATH):
#     if ep_filename == "dataset_statistics.bz2":
#         continue
#     ep_filepath = f"{DATASET_DIR_PATH}/{ep_filename}"
#     with open(ep_filepath, "rb") as f:
#         edd = cpkl.load(f)
    
#     ep_length = edd["ep_length"]
#     if ep_length not in list(ep_lengths_dict.keys()):
#         ep_lengths_dict[ep_length] = 1
#     else:
#         ep_lengths_dict[ep_length] += 1
#     all_ep_lengths.append(ep_length)

# # Histogram of the episode lengths. Note that this is very time-consuming: the lengths were not logged during data collection, so every episode file has to be loaded.
# fig, ax = plt.subplots(1,1, figsize=(6 * 3, 6), dpi=200)
# ax.hist(all_ep_lengths, bins=60)

# Extracting trajectories for RSA

### A. C categories, for a given category: N trajs for M rooms

In [None]:
# from IPython.display import clear_output

# # Start by reading all the episodes in the dataset directory
# M = 3 # number of scenes / rooms, for one category
# N = 4 # number of trajs. per scenes / rooms, for one category
# CATEGORIES_OF_INTEREST = [
#     "chair",
#     "picture",
#     # "table",
#     # "cushion",
#     # "cabinet",
#     # "plant"
# ]
# C = len(CATEGORIES_OF_INTEREST)

# trajs_scenes_cat = {
#     k: {} for k in CATEGORIES_OF_INTEREST
# }

# n_selected_trajs = 0

# ep_filenames = os.listdir(DATASET_DIR_PATH)
# if "dataset_statistics.bz2" in ep_filenames:
#     ep_filenames.remove("dataset_statistics.bz2")
# ep_filenames_iterator = iter(ep_filenames)

# while n_selected_trajs < C * N * M:
#     ep_filename = next(ep_filenames_iterator)

#     ep_filepath = f"{DATASET_DIR_PATH}/{ep_filename}"
#     with open(ep_filepath, "rb") as f:
#         edd = cpkl.load(f)

#     ep_length = edd["ep_length"]
#     ep_category = edd["category_name"]
#     ep_scene = edd["scene_id"]

#     # Skip if the category does not match
#     if ep_category not in CATEGORIES_OF_INTEREST:
#         continue

#     if ep_scene not in trajs_scenes_cat[ep_category].keys():
#         # First time seeing the scene: add it to the dict, along with the new traj.
#         if len(trajs_scenes_cat[ep_category]) < M:
#             # Only add it if we don't have enough scenes yet.
#             trajs_scenes_cat[ep_category][ep_scene] = [
#                 {
#                     "ep_filename": ep_filename,
#                     "edd": edd
#                 }
#             ]
#             n_selected_trajs += 1
#     else:
#         # The scene was already seen once; check if we need more, and append accordingly
#         if len(trajs_scenes_cat[ep_category][ep_scene]) < N:
#             trajs_scenes_cat[ep_category][ep_scene].append({
#                 "ep_filename": ep_filename,
#                 "edd": edd
#             })
#             n_selected_trajs += 1

#     if n_selected_trajs < N * M:
#         clear_output(wait=True)
    
#     print("### --------------------------------------------------- ###")
#     print(f"### # selected traj: {n_selected_trajs} for \"{ep_category}\"")
#     for k, v in trajs_scenes_cat[ep_category].items():
#         print(f"\t{k}: {len(v)}")
#     print("### --------------------------------------------------- ###")
#     print("")

### B. Similar to A, but make sure we use the same rooms for each category

In [None]:
# from IPython.display import clear_output

# # Start by reading all the episodes in the dataset directory
# M = 3 # number of scenes / rooms, for one category
# N = 2 # number of trajs. per scenes / rooms, for one category
# CATEGORIES_OF_INTEREST = [
#     "chair",
#     "picture",
#     # "table",
#     # "cushion",
#     # "cabinet",
#     # "plant"
# ]
# C = len(CATEGORIES_OF_INTEREST)

# trajs_scenes_cat = {
#     k: {} for k in CATEGORIES_OF_INTEREST
# }

# n_selected_trajs = 0

# ep_filenames = os.listdir(DATASET_DIR_PATH)
# if "dataset_statistics.bz2" in ep_filenames:
#     ep_filenames.remove("dataset_statistics.bz2")
# ep_filenames_iterator = iter(ep_filenames)

# scenes_of_interest = [] # To make sure we have the same scenes for each category

# while n_selected_trajs < C * N * M:
#     ep_filename = next(ep_filenames_iterator)

#     ep_filepath = f"{DATASET_DIR_PATH}/{ep_filename}"
#     with open(ep_filepath, "rb") as f:
#         edd = cpkl.load(f)

#     ep_length = edd["ep_length"]
#     ep_category = edd["category_name"]
#     ep_scene = edd["scene_id"]

#     # Skip if the category does not match
#     if ep_category not in CATEGORIES_OF_INTEREST:
#         continue

#     # Track which scenes' trajectories will be saved.
#     # We want the same scenes for each category
#     if len(scenes_of_interest) < M and (ep_scene not in scenes_of_interest):
#         scenes_of_interest.append(ep_scene)
    
#     if ep_scene not in scenes_of_interest:
#         continue
    
#     # Check whether the scene is already tracked for this category
#     if ep_scene not in trajs_scenes_cat[ep_category].keys():
#         # First time seeing the scene: add it to the dict, along with the new traj.
#         if len(trajs_scenes_cat[ep_category]) < M:
#             # Only add it if we don't have enough scenes yet.
#             trajs_scenes_cat[ep_category][ep_scene] = [
#                 {
#                     "ep_filename": ep_filename,
#                     "edd": edd
#                 }
#             ]
#             n_selected_trajs += 1
#     else:
#         # The scene was already seen once; check if we need more, and append accordingly
#         if len(trajs_scenes_cat[ep_category][ep_scene]) < N:
#             trajs_scenes_cat[ep_category][ep_scene].append({
#                 "ep_filename": ep_filename,
#                 "edd": edd
#             })
#             n_selected_trajs += 1

#     if n_selected_trajs < N * M:
#         clear_output(wait=True)
    
#     print("### --------------------------------------------------- ###")
#     print(f"### # selected traj: {n_selected_trajs} for \"{ep_category}\"")
#     for k, v in trajs_scenes_cat[ep_category].items():
#         print(f"\t{k}: {len(v)}")
#     print("### --------------------------------------------------- ###")
#     print("")

In [None]:
# # Saving the filtered trajectories data
# # trajs_scenes_cat["chair"] # Check the content
# C = len(CATEGORIES_OF_INTEREST)
# analysis_trajs_filename = f"analysis_trajs_C_{C}_M_{M}_N_{N}.bz2"; print(analysis_trajs_filename)
# # Uncomment for actual saving
# with open(analysis_trajs_filename, "wb") as f:
#     cpkl.dump(trajs_scenes_cat, f)

## Investigating filtered scenes from the dataset

In [None]:
# File holding the filtered trajectories that the following cells load and inspect.
# NOTE(review): the C / M / N values in the name presumably follow the
# "analysis_trajs_C_{C}_M_{M}_N_{N}.bz2" pattern of the (commented-out) saving cell — confirm.
analysis_trajs_filename = "analysis_trajs_C_6_M_5_N_5.bz2"

In [None]:
# Load the filtered trajectories data from disk.
with open(analysis_trajs_filename, "rb") as trajs_file:
    analysis_trajs_dict = cpkl.load(trajs_file)

In [None]:
# Structured as follows:
# - category_name
#   - scene_id
#       [{}, {}, {} ...] with N items (episode data); each item is a dict
#       read below through its "ep_filename" and "edd" keys.

# Display the top-level keys of the loaded dict.
analysis_trajs_dict.keys()

In [None]:
## Inspecting the analysis data trajectories
# For each trajectory: plot the top-down maps of its first NN and last NN steps,
# followed by its last RGB observation. At most MAX_PLOTTED_TRAJS figures are drawn.
NN = 4  # Number of steps taken from the start and from the end of an episode
MAX_PLOTTED_TRAJS = 10  # Upper bound on the number of figures produced by this cell
TICKS_OFF = dict(axis="both", which="both", bottom=False, top=False, left=False, labelleft=False, labelbottom=False)

n_plotted_trajs = 0
for cat_name, cat_scenes_trajs in analysis_trajs_dict.items():
    for scene_id, scene_trajs in cat_scenes_trajs.items():
        for traj_dict in scene_trajs:
            fig, axes = plt.subplots(1, NN * 2 + 1, figsize=((NN * 2 + 1) *  6, 6), dpi=200)
            fig.set_facecolor("white")

            ep_filename = traj_dict["ep_filename"]
            edd = traj_dict["edd"]

            truncated_obs_list = edd["obs_list"]["rgb"][:NN] + edd["obs_list"]["rgb"][-NN:]
            truncated_top_down_map_list = edd["info_list"][:NN] + edd["info_list"][-NN:]
            # Plot NN * 2 steps + top down map
            # for t, rgb_obs in enumerate(truncated_obs_list):
            #     axes[t].imshow(rgb_obs)
            #     axes[t].tick_params(axis="both", which="both", bottom=False, top=False, left=False, labelleft=False, labelbottom=False)
            # top_down_map_img = plot_top_down_map(edd["info_list"][1])
            # axes[-1].imshow(top_down_map_img)
            # axes[-1].tick_params(axis="both", which="both", bottom=False, top=False, left=False, labelleft=False, labelbottom=False)

            # Plot NN * 2 top_down_map + one rgb_obs
            for t, info_dict in enumerate(truncated_top_down_map_list):
                top_down_map_img = plot_top_down_map(info_dict)
                axes[t].imshow(top_down_map_img)
                axes[t].tick_params(**TICKS_OFF)
            axes[-1].tick_params(**TICKS_OFF)
            axes[-1].imshow(truncated_obs_list[-1])

            fig.suptitle(f"Scene: {scene_id} | Target: {cat_name} | Ep. Len.: {edd['ep_length']}", fontsize=32)

            n_plotted_trajs += 1

            if n_plotted_trajs >= MAX_PLOTTED_TRAJS:
                break

        if n_plotted_trajs >= MAX_PLOTTED_TRAJS:
            break

    # `break` only leaves the innermost loop, so the category loop needs its own
    # check; without it every remaining category would still draw one extra figure.
    if n_plotted_trajs >= MAX_PLOTTED_TRAJS:
        break