In [205]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [206]:
def read_json(path):
    f = open(path)
    return json.loads(f.read())

In [270]:
# Dataset stats plots

dataset = read_json("../data/stats/pick_and_place_human_demos.json")
print(dataset.keys())

dict_keys(['episode_length', 'action_frequency'])


In [271]:
def plot_histogram(dataset, key, path, title, y_label, x_label):
    fig_dims = (8, 7)
    sns.set(font_scale = 1)
    fig, ax = plt.subplots(figsize=fig_dims)

    plot = sns.histplot(
        dataset[key],
        bins=35,
        color='#99CCFF',
        ax=ax,
    )
    for i, bar in enumerate(plot.patches):
        bar.set_linewidth(0.5)
        bar.set_edgecolor("w")
        # bar.set_color('#5899DB')

    plt.gca().set(title="")
    plot.set_xlabel(x_label, fontsize=15)
    plot.set_ylabel(y_label, fontsize=15)

    ax.spines['bottom'].set_color('0.1')
    ax.spines['top'].set_color('1')
    ax.spines['right'].set_color('0.1')
    ax.spines['left'].set_color('0.1')

    plt.savefig(path, dpi=400, pad_inches=0.1, transparent=True)
    plt.clf()
    fig.clf()


def plot_barplot(dataset, key, path, title, y_label, x_label):
    colors = ['#5899DB', '#808080',  '#67AB9F', '#F19C99', '#FFB366', '#99CCFF', '#D49555',  '#97D077',  '#A680B8']
    keys = list(dataset[key].keys())
    values = list(map(int, list(dataset[key].values())))
    fig_dims = (7, 7)
#     if "object" in path:
#         fig_dims = (6, 6)
    fig, ax = plt.subplots(figsize=fig_dims)
    
    sns.set(font_scale = 1)
    
    plot = sns.barplot(x=keys, y=values)
    for i, bar in enumerate(ax.patches):
        bar.set_color(colors[i])
    
    for item in plot.get_xticklabels():
        item.set_rotation(45)

    ticks = [0, 200000, 400000, 600000, 800000, 1000000, 1200000, 1400000]
    print(max(values))
    print(values)
    
    if max(values) > 1000000:
        print("lim 1")
        ticks = [0, 200000, 400000, 600000, 800000, 1000000, 1200000, 1400000]
        plot.set_ylim(0, 1450000)

    if max(values) > 2000000:
        print("lim 1")
        ticks = [0, 1000000, 2000000, 3000000, 4000000, 5000000]
        plot.set_ylim(0, 5150000)
    

    if max(values) < 1000000:
        print("lim 2")
        ticks = [0, 200000, 400000, 600000, 800000, 1000000]
        plot.set_ylim(0, 905000)

    if max(values) < 700000:
        print("lim 3")
        ticks = [0, 50000, 100000, 150000, 200000, 250000]
        plot.set_ylim(0, 270000)
    
    plot.set_yticklabels(ticks, size=10)

    ax.spines['bottom'].set_color('0.1')
    ax.spines['top'].set_color('1')
    ax.spines['right'].set_color('0.1')
    ax.spines['left'].set_color('0.1')

    plt.gca().set(title="")
    plot.set_xlabel(x_label, fontsize=15)
    plot.set_ylabel(y_label, fontsize=15)
    plt.savefig(path, dpi=400, bbox_inches="tight", pad_inches=0.1, transparent=True)
    plt.clf()
    fig.clf()

In [272]:
plot_histogram(
    dataset,
    key="episode_length",
    path="figures/pick_and_place_episode_length_distrib.jpg",
    title="Episode length histogram",
    y_label="Number of episodes",
    x_label="Episode length"
)

<Figure size 576x504 with 0 Axes>

In [273]:
plot_barplot(
    dataset,
    key="action_frequency",
    path="figures/pick_and_place_action_distrib.jpg",
    title="Action frequency histogram",
    x_label="Actions",
    y_label="Num actions"
)

4211570
[23942, 1628141, 300040, 2168225, 2290425, 4211570, 372666, 29159, 142972]
lim 1
lim 1




<Figure size 504x504 with 0 Axes>

In [274]:
paths = ["objectnav_human_demos.json", "objectnav_s_path.json", "pick_and_place_s_path.json"]

for path in paths:
    output_path = "figures/{}_episode_length_distrib.jpg".format(path.split(".")[0])
    file_path = "../data/stats/{}".format(path)
    dataset = read_json(file_path)
    if path == "objectnav_s_path.json":
        dataset["action_frequency"].update({
            "LOOK_DOWN": 0,
            "LOOK_UP": 0
        })

    plot_histogram(
        dataset,
        key="episode_length",
        path=output_path,
        title="Episode length histogram",
        y_label="",
        x_label="Episode length"
    )
    output_path = "figures/{}_action_distrib.jpg".format(path.split(".")[0])
    print(dataset["action_frequency"].keys())
    plot_barplot(
        dataset,
        key="action_frequency",
        path=output_path,
        title="Action frequency histogram",
        x_label="Actions",
        y_label=""
    )

dict_keys(['STOP', 'TURN_RIGHT', 'MOVE_FORWARD', 'TURN_LEFT', 'LOOK_DOWN', 'LOOK_UP'])
737833
[8618, 187160, 737833, 182100, 6454, 5098]
lim 2




dict_keys(['STOP', 'TURN_RIGHT', 'MOVE_FORWARD', 'TURN_LEFT', 'LOOK_DOWN', 'LOOK_UP'])
229159
[7378, 54949, 229159, 56687, 0, 0]
lim 2
lim 3




dict_keys(['TURN_LEFT', 'MOVE_FORWARD', 'TURN_RIGHT', 'LOOK_DOWN', 'GRAB_RELEASE', 'LOOK_UP', 'STOP'])
1297121
[717347, 1297121, 665223, 271569, 19700, 119976, 9251]
lim 1




<Figure size 576x504 with 0 Axes>

<Figure size 504x504 with 0 Axes>

<Figure size 576x504 with 0 Axes>

<Figure size 504x504 with 0 Axes>

<Figure size 576x504 with 0 Axes>

<Figure size 504x504 with 0 Axes>