In [None]:
import numpy as np
import matplotlib
import json
import pandas as pd
import os
from collections import defaultdict
from nltk.tokenize import word_tokenize


In [None]:
def get_num_tokens(text):
    """Return the number of tokens `word_tokenize` produces for `text`."""
    tokens = word_tokenize(text)
    return len(tokens)

In [None]:
# Build metadata table for generating dataset statistics.
#
# Walks <data_path>/<split>/<task_dir>/<trial_dir>/traj_data.json and records one
# row per trial. `stats_dict` maps column name -> list of per-trial values and is
# converted into the `stats_df` DataFrame at the end of the cell.
splits = ["train", "valid_seen", "valid_unseen"]
# Fields encoded, in this order, in the "-"-separated task directory name.
task_fields = ["task_type", "focus_object", "base_object", "dest_object", "scene"]
data_path = "../tars/alfred/data/json_2.1.0"
stats_dict = defaultdict(list)

for split in splits:
    split_path = os.path.join(data_path, split)
    task_dirs = os.listdir(split_path)
    print("{} ({} param sets)".format(split, len(task_dirs)))
    for i, task_dir in enumerate(task_dirs):
        if i % 500 == 0:
            # Coarse progress indicator.
            print(i)
        task_values = task_dir.split("-")
        if len(task_values) < len(task_fields):
            # The directory name does not carry every task field. Report it and
            # skip it instead of failing with an IndexError part-way through a
            # row, which would leave the columns with unequal lengths.
            print(task_values)
            continue

        task_path = os.path.join(split_path, task_dir)
        for trial_dir in os.listdir(task_path):
            traj_data_path = os.path.join(task_path, trial_dir, "traj_data.json")
            # Context manager so the file handle is released for every trial.
            with open(traj_data_path) as traj_data_file:
                traj_data = json.load(traj_data_file)

            # One entry per annotation ("directive") of this trial.
            anns = traj_data["turk_annotations"]["anns"]
            num_steps_list = [len(directive["high_descs"]) for directive in anns]
            num_step_tokens_list = [
                sum(get_num_tokens(desc) for desc in directive["high_descs"])
                for directive in anns
            ]
            num_task_tokens_list = [get_num_tokens(directive["task_desc"]) for directive in anns]

            # A low-level action whose discrete-action args contain a "mask" key is
            # counted as an interaction; every other action is counted as navigation.
            low_actions = traj_data["plan"]["low_actions"]
            interact_count = sum(
                "mask" in action["discrete_action"]["args"] for action in low_actions
            )
            nav_count = len(low_actions) - interact_count

            # Assemble the whole row first and append it only once every value is
            # available, so all columns always stay the same length.
            row = {"split": split, "task_id": trial_dir}
            row.update(zip(task_fields, task_values))
            # Language statistics are averaged over the trial's annotations.
            row["steps"] = np.mean(num_steps_list)
            row["total_steps_toks"] = np.mean(num_step_tokens_list)
            row["task_toks"] = np.mean(num_task_tokens_list)
            row["images"] = len(traj_data["images"])
            row["actions"] = len(low_actions)
            row["nav_actions"] = nav_count
            row["interact_actions"] = interact_count

            for column, value in row.items():
                stats_dict[column].append(value)


stats_df = pd.DataFrame(stats_dict)
    

In [None]:
# Derive some additional columns: each new column is the element-wise ratio
# of two existing columns (new column -> (numerator, denominator)).
ratio_columns = {
    "toks/step": ("total_steps_toks", "steps"),
    "actions/step": ("actions", "steps"),
    "images/action": ("images", "actions"),
    "nav/interact": ("nav_actions", "interact_actions"),
}
for ratio_name, (numerator, denominator) in ratio_columns.items():
    stats_df[ratio_name] = stats_df[numerator] / stats_df[denominator]

# Round the whole table to two decimals.
stats_df = stats_df.round(2)


In [None]:
# Show the per-trial statistics table as this cell's output.
stats_df

In [None]:
# Plot frequency of a categorical field. Useful for task_type and maybe objects.
def plot_freq(col, **kwargs):
    """Draw a grouped bar chart of the relative frequency of `col` per split.

    For each of the train / valid_seen / valid_unseen splits, the values of
    `stats_df[col]` are counted and normalised so they sum to 1 within that
    split. The three series are placed side by side and plotted as bars.

    Args:
        col: Name of the column of the module-level `stats_df` to summarise.
        **kwargs: Forwarded unchanged to `DataFrame.plot.bar` (e.g. `figsize`).
    """
    per_split_freqs = {}
    for split_name in ("train", "valid_seen", "valid_unseen"):
        in_split = stats_df["split"] == split_name
        per_split_freqs[split_name] = stats_df.loc[in_split, col].value_counts(normalize=True)

    # Dict keys become the column labels, one column per split.
    freq_table = pd.concat(per_split_freqs, axis=1)
    freq_table.plot.bar(xlabel=col, ylabel="Relative Frequency", **kwargs)

# Plot histogram of a quantitative field.
def plot_hist(col, **kwargs):
    """Draw one histogram of `col` per split, with a shared x axis.

    Args:
        col: Name of the numeric column of the module-level `stats_df` to plot.
        **kwargs: Forwarded unchanged to `DataFrame.hist` (e.g. `figsize`, `bins`).
    """
    axes = stats_df.hist(col, by="split", sharex=True, **kwargs)
    # pandas can hand back a single Axes instead of an ndarray (e.g. when only
    # one split is present); normalise to a flat array so both cases work.
    for ax in np.asarray(axes).reshape(-1):
        # Presumably needed because sharex=True suppresses x tick labels on some
        # subplots; turn them back on so every panel is readable on its own.
        ax.tick_params(axis="x", which="both", labelbottom=True)
        ax.set_xlabel(col, visible=True)
        ax.set_ylabel("Frequency")

In [None]:
# Relative frequency of each task type, compared across the three splits.
plot_freq("task_type", figsize=(8, 4))

In [None]:
# Relative frequency of each focus object, compared across the three splits.
plot_freq("focus_object", figsize=(20, 4))

In [None]:
# Relative frequency of each destination object, compared across the three splits.
plot_freq("dest_object", figsize=(12, 4))

In [None]:
# Per-split histogram of "steps" (mean number of high-level descriptions per annotation).
plot_hist("steps", figsize=(12, 6))

In [None]:
# Per-split histogram of "actions" (number of low-level actions in the plan), using 16 bins.
plot_hist("actions", figsize=(12, 6), bins=16)