In [None]:
import numpy as np
import matplotlib
import json
import pandas as pd
import os
from collections import defaultdict
from nltk.tokenize import word_tokenize


In [None]:
def get_num_tokens(text):
    """Return how many word tokens NLTK's ``word_tokenize`` produces for ``text``.

    The token list itself is discarded; only its length is reported.
    """
    tokens = word_tokenize(text)
    return len(tokens)


In [None]:
# Dataset splits to summarise and the root of the json_2.1.0 trajectory tree.
splits = ["train", "valid_seen", "valid_unseen"]
# Task directory names are "-"-separated; these fields are read positionally
# from the first five parts of that name.
task_fields = ["task_type", "focus_object", "base_object", "dest_object", "scene"]
data_path = "../tars/alfred/data/json_2.1.0"


def _list_subdirs(path):
    """Return the sorted names of the sub-directories of ``path``.

    Sorting makes the row order of ``stats_df`` reproducible (``os.listdir``
    returns entries in arbitrary order), and skipping plain files avoids
    crashing on stray entries that are not trial/task directories.
    """
    return sorted(
        name for name in os.listdir(path)
        if os.path.isdir(os.path.join(path, name))
    )


def _rounded_mean(values):
    """Mean of ``values`` rounded to 2 decimals; NaN (without a warning) if empty."""
    if not values:
        return float("nan")
    return round(np.mean(values), 2)


def _trial_stats(traj_data):
    """Compute the per-trial statistic columns from a loaded ``traj_data.json``.

    Language statistics are averaged over every entry of
    ``traj_data["turk_annotations"]["anns"]``:
      - ``steps``: number of ``high_descs`` per annotation,
      - ``total_steps_toks``: total tokens across all ``high_descs``,
      - ``task_toks``: tokens in ``task_desc``.
    Action statistics come from ``traj_data["plan"]["low_actions"]``: an
    action whose ``discrete_action["args"]`` contains a ``"mask"`` key is
    counted as an interaction, every other action as navigation.

    Returns a dict whose key order matches the DataFrame column order.
    """
    anns = traj_data["turk_annotations"]["anns"]
    low_actions = traj_data["plan"]["low_actions"]
    interact_count = sum(
        1 for action in low_actions if "mask" in action["discrete_action"]["args"]
    )
    return {
        "steps": _rounded_mean([len(ann["high_descs"]) for ann in anns]),
        "total_steps_toks": _rounded_mean(
            [sum(get_num_tokens(desc) for desc in ann["high_descs"]) for ann in anns]
        ),
        "task_toks": _rounded_mean([get_num_tokens(ann["task_desc"]) for ann in anns]),
        "images": len(traj_data["images"]),
        "actions": len(low_actions),
        "nav_actions": len(low_actions) - interact_count,
        "interact_actions": interact_count,
    }


# Column name -> list of values, one entry per trial.
stats_dict = defaultdict(list)

for split in splits:
    split_path = os.path.join(data_path, split)
    task_dirs = _list_subdirs(split_path)
    print("{} ({} param sets)".format(split, len(task_dirs)))
    for i, task_dir in enumerate(task_dirs):
        if i % 500 == 0:
            print(i)  # coarse progress indicator
        task_values = task_dir.split("-")
        if len(task_values) < len(task_fields):
            # The positional field lookup below would raise IndexError on a
            # name with too few parts, so report the directory and skip it.
            print("skipping malformed task dir:", task_values)
            continue

        task_path = os.path.join(split_path, task_dir)
        for trial_dir in _list_subdirs(task_path):
            # Load before appending anything so a read/parse failure cannot
            # leave the columns with mismatched lengths; `with` closes the file.
            traj_path = os.path.join(task_path, trial_dir, "traj_data.json")
            with open(traj_path, encoding="utf-8") as traj_data_file:
                traj_data = json.load(traj_data_file)

            stats_dict["split"].append(split)
            stats_dict["task_id"].append(trial_dir)
            for field, value in zip(task_fields, task_values):
                stats_dict[field].append(value)
            for column, value in _trial_stats(traj_data).items():
                stats_dict[column].append(value)

stats_df = pd.DataFrame(stats_dict)

In [None]:
# Preview only the first rows: the frame has one row per trial across all
# splits, so displaying it whole mostly produces noise.
stats_df.head()