In [14]:
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import plotly.express as px

In [114]:
def collectVersionData(env, version):
  """Load every ``result.json`` found one level below a version directory.

  The directory scanned is ``./.checkpoints/<env>/<version>``. Its entries are
  visited in sorted name order; each entry that contains a ``result.json``
  file contributes one parsed JSON object to the returned list.

  Args:
    env: Name of the environment sub-directory under ``./.checkpoints``.
    version: Name of the version sub-directory under ``env``.

  Returns:
    A list of the decoded JSON documents, ordered by folder name.

  Raises:
    FileNotFoundError: If the version directory does not exist.
    json.JSONDecodeError: If a ``result.json`` is not valid JSON.
  """
  BASE_DIR = "./.checkpoints"
  results_dir = os.path.join(BASE_DIR, env, version)

  data = []
  for folder in sorted(os.listdir(results_dir)):
    result_path = os.path.join(results_dir, folder, 'result.json')
    # isfile() is False both for folders without a result and for stray
    # plain files in results_dir, so neither case raises.
    if not os.path.isfile(result_path):
      continue
    # Context manager closes the handle instead of leaking one per folder.
    with open(result_path) as result_file:
      data.append(json.load(result_file))

  return data

In [None]:
def getAllFromDict(keys, listOfDicts):
  """Pivot a list of dicts into a dict of lists.

  Args:
    keys: Keys to extract; each must be present in every dict.
    listOfDicts: Sequence of dicts to read from.

  Returns:
    A dict mapping each requested key to the list of values found under that
    key, in the order the dicts appear in ``listOfDicts``.

  Raises:
    KeyError: If any dict lacks one of the requested keys.
  """
  return {
    name: [entry[name] for entry in listOfDicts]
    for name in keys
  }

In [177]:
def collectData(env, versions, stats=()):
  """Flatten the per-iteration results of several versions into one DataFrame.

  For every version, ``collectVersionData(env, version)`` is called and each
  result it returns becomes one row. Every row carries ``epoch`` (taken from
  ``training_iteration``), ``environment`` and ``version``; the remaining
  columns depend on which stat groups are requested.

  Args:
    env: Environment name, forwarded to ``collectVersionData`` and stored in
      the ``environment`` column.
    versions: Iterable of version names to load, in output order.
    stats: Collection of group names selecting extra columns. Recognised
      names are ``'SAMPLER'``, ``'DQN_LEARNER'``, ``'PPO_LEARNER'``,
      ``'TIME'`` and ``'PERF'``; unknown names are ignored.

  Returns:
    A DataFrame with one row per result, indexed 1..N across all versions.
    An empty DataFrame is returned when there are no results.

  Raises:
    KeyError: If a result lacks a field belonging to a *requested* group
      (or lacks ``training_iteration``).
  """
  SAMPLER_KEYS = [
    'episode_reward_max',
    'episode_reward_min',
    'episode_reward_mean',
    "episode_len_mean"
    ]
  DQN_LEARNER_KEYS = [
    'mean_q',
    'mean_td_error'
  ]
  PPO_LEARNER_KEYS =[
    "cur_lr",
    "total_loss"
  ]
  TIME_KEYS = [
    "time_this_iter_s",
    "episodes_total"
  ]
  PERF_KEYS = [
    "cpu_util_percent",
    "ram_util_percent"
  ]

  def _learner_stats(epoch):
    return epoch['info']['learner']['default_policy']['learner_stats']

  # (group name, keys to copy, locator of the sub-dict holding those keys).
  # List order fixes the column order. Locators run lazily so a result that
  # lacks, say, learner info only fails when a learner group is requested.
  stat_groups = [
    ('SAMPLER', SAMPLER_KEYS, lambda epoch: epoch['sampler_results']),
    ('DQN_LEARNER', DQN_LEARNER_KEYS, _learner_stats),
    ('PPO_LEARNER', PPO_LEARNER_KEYS, _learner_stats),
    ('TIME', TIME_KEYS, lambda epoch: epoch),
    ('PERF', PERF_KEYS, lambda epoch: epoch['perf']),
  ]
  requested = [(keys, locate) for name, keys, locate in stat_groups if name in stats]

  # Accumulate plain dicts and build the frame once; concatenating a
  # one-row frame per iteration copies the whole history every time.
  records = []
  for version in versions:
    for epoch in collectVersionData(env, version):
      row = {
        'epoch': epoch["training_iteration"],
        'environment': env,
        'version': version
      }
      for keys, locate in requested:
        source = locate(epoch)
        for key in keys:
          row[key] = source[key]
      records.append(row)

  if not records:
    return pd.DataFrame()

  # Keep the 1-based running index the row-by-row version produced.
  return pd.DataFrame(records, index=pd.RangeIndex(1, len(records) + 1))
  


In [180]:
def plotResultsVsEpoch(data, env, versions, keys=()):
  """Show one plotly line chart per key, with ``epoch`` on the x axis.

  Lines are coloured by the ``version`` column of ``data``.

  Args:
    data: Frame passed to ``px.line``; the call references the columns
      ``epoch`` and ``version`` plus every entry of ``keys``.
    env: Environment name, shown in each figure title.
    versions: Accepted for call-site symmetry; not used for filtering.
    keys: Column names to plot, one figure each.
  """
  for key in keys:
    # Title each figure so it can be told apart from the others.
    fig = px.line(
      data,
      x='epoch',
      y=key,
      color='version',
      title=f'{key} vs. epoch ({env})',
    )
    fig.show()
    
def plotResultsVsEpisode(data,env,versions, keys=()):
  """Show one plotly line chart per key, with ``episodes_total`` on the x axis.

  Lines are coloured by the ``version`` column of ``data``.

  Args:
    data: Frame passed to ``px.line``; the call references the columns
      ``episodes_total`` and ``version`` plus every entry of ``keys``.
    env: Environment name, shown in each figure title.
    versions: Accepted for call-site symmetry; not used for filtering.
    keys: Column names to plot, one figure each.
  """
  for key in keys:
    # Title each figure so it can be told apart from the others.
    fig = px.line(
      data,
      x='episodes_total',
      y=key,
      color='version',
      title=f'{key} vs. episodes_total ({env})',
    )
    fig.show()



In [178]:
# Environment folder and version folders (under ./.checkpoints) to compare.
ENV = 'CustomOffWorldDockerMonolithDiscreteSim-v0'
VERSIONS = ['depth_only_v2', 'RGB_only_v1', 'rgbd_v1']

# One row per training iteration, with sampler, PPO learner, timing and
# machine-utilisation columns.
data = collectData(
    ENV,
    VERSIONS,
    stats=['SAMPLER', 'PPO_LEARNER', 'TIME', 'PERF'],
)
data

Unnamed: 0,epoch,environment,version,episode_reward_max,episode_reward_min,episode_reward_mean,episode_len_mean,cur_lr,total_loss,time_this_iter_s,episodes_total,cpu_util_percent,ram_util_percent
1,1,CustomOffWorldDockerMonolithDiscreteSim-v0,depth_only_v2,1.0,0.0,0.16667,36.62500,0.00005,-0.00128,198.378625,24,84.88028,77.04613
2,2,CustomOffWorldDockerMonolithDiscreteSim-v0,depth_only_v2,1.0,0.0,0.21739,40.04348,0.00005,-0.00070,201.156595,46,84.80488,79.42648
3,3,CustomOffWorldDockerMonolithDiscreteSim-v0,depth_only_v2,1.0,0.0,0.13636,44.18182,0.00005,-0.01210,279.636126,66,91.60576,89.78471
4,4,CustomOffWorldDockerMonolithDiscreteSim-v0,depth_only_v2,1.0,0.0,0.15730,44.56180,0.00005,0.00074,289.799797,89,91.59807,90.52126
5,5,CustomOffWorldDockerMonolithDiscreteSim-v0,depth_only_v2,1.0,0.0,0.18000,43.23000,0.00005,0.00121,285.488732,113,91.21744,89.59582
...,...,...,...,...,...,...,...,...,...,...,...,...,...
195,96,CustomOffWorldDockerMonolithDiscreteSim-v0,rgbd_v1,1.0,0.0,0.56000,46.58000,0.00005,-0.01660,208.451929,1601,85.15034,70.92416
196,97,CustomOffWorldDockerMonolithDiscreteSim-v0,rgbd_v1,1.0,0.0,0.49000,48.44000,0.00005,-0.02461,210.436869,1623,84.96833,71.22333
197,98,CustomOffWorldDockerMonolithDiscreteSim-v0,rgbd_v1,1.0,0.0,0.49000,46.40000,0.00005,-0.02717,207.933633,1638,85.19226,71.35926
198,99,CustomOffWorldDockerMonolithDiscreteSim-v0,rgbd_v1,1.0,0.0,0.42000,52.26000,0.00005,-0.03035,206.753668,1653,85.05831,71.77966


In [181]:

# Mean episode reward per version, first against epoch, then against total episodes.
plotResultsVsEpoch(data, ENV, VERSIONS, keys=['episode_reward_mean'])
plotResultsVsEpisode(data, ENV, VERSIONS, keys=['episode_reward_mean'])