In [3]:
import wandb
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import concurrent.futures
from tqdm.notebook import tqdm


# Where the runs live on Weights & Biases
entity = "pletctj6"
project_name = "contrastive_exploration_reward_max"

# Authenticate with W&B
wandb.login(timeout=1024)

# Query the project's runs through the public API
api = wandb.Api()
runs = api.runs(f"{entity}/{project_name}")

# Dict that the download cell fills with the data extracted from each run
experiments_data = {}

### Downloading the data

* config: the settings the run was launched with (experiment name, environment, seed, ...)
* history: time evolution of all the data recorded during the run, as columns of a pandas DataFrame
* summary: last recorded sample of each metric

In [4]:
from envs.config_env import config as config_env

def process_run(run,
                metrics = {
                            "config": ["exp_name", "env_id", "seed", "keep_extrinsic_reward", "beta_ratio", "use_sigmoid", "feature_extractor"],
                            "history": ["specific/episodic_return", "specific/coverage", "specific/shanon_entropy", "global_step"],
                            "summary": ["specific/episodic_return", "specific/coverage", "specific/shanon_entropy", "global_step"]
                },
                config_env=config_env):
    """Extract the requested config, summary and history values from one run.

    Parameters
    ----------
    run : object exposing ``name``, ``config``, ``summary`` and ``history()``
        (the items yielded by ``api.runs(...)``).
    metrics : dict with the keys "config", "history" and "summary", each
        mapping to the list of names to extract. The default dict is only
        read, never mutated.
    config_env : mapping from environment id to an entry that has a
        "type_id" field.

    Returns
    -------
    dict or None
        ``{'exp_name', 'env_name', 'type_id', 'seed', 'data'}`` where ``data``
        holds 'summary_metrics', 'history_metrics', 'config_metrics' and the
        raw 'config'. ``None`` (after printing a message) when the run lacks a
        requested history/summary metric or its environment is not listed in
        ``config_env``.
    """
    skip_message = f"Skipping run {run.name} because it doesn't have the necessary data."

    # Check the run state
    # if run.state != "finished":
    #     # print(f"Skipping run {run.name} because it is not finished.")
    #     return None

    ##### CONFIGURATION #####
    config = run.config
    config_metrics = {}
    for key in metrics['config']:
        try:
            # .get() yields None for a config key the run does not define,
            # so a missing config key does not skip the run.
            config_metrics[key] = config.get(key)
        except Exception:
            # Narrowed from a bare except so that KeyboardInterrupt and
            # SystemExit are no longer swallowed.
            print(skip_message)
            return None

    ##### ENVIRONMENT TYPE #####
    # An environment that is absent from config_env used to raise a KeyError
    # that escaped the function; such runs are now skipped like any other
    # run with missing data.
    env_id = config.get('env_id')
    try:
        type_id = config_env[env_id]['type_id']
    except (KeyError, TypeError):
        print(f"Skipping run {run.name} because its env_id {env_id!r} is not described in config_env.")
        return None

    ##### SUMMARY #####
    # Checked before the history so that runs that will be skipped anyway do
    # not trigger a history download.
    summary = run.summary
    summary_metrics = {}
    for key in metrics['summary']:
        if key not in summary:
            print(skip_message)
            return None
        summary_metrics[key] = summary[key]

    ##### HISTORY #####
    # NOTE(review): history() is called with its default arguments; the
    # "Check data" cell shows a 500-row series, so the history looks
    # sub-sampled - confirm against the W&B API if full resolution matters.
    history = run.history()
    history_metrics = {}
    for key in metrics['history']:
        if key not in history.columns:
            print(skip_message)
            return None
        history_metrics[key] = history[key]

    return {
        'exp_name': config_metrics['exp_name'],
        'env_name': config_metrics['env_id'],
        'type_id': type_id,
        'seed': config_metrics['seed'],
        'data': {
            'summary_metrics': summary_metrics,
            'history_metrics': history_metrics,
            'config_metrics': config_metrics,
            'config': config
        }
    }

experiments_data = {}
max_workers = 2
# Process the runs in parallel with a ThreadPoolExecutor.
# max_workers sets the number of threads (for example 4).
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
    futures = {executor.submit(process_run, run) for run in runs}
    for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing runs"):
        result = future.result()
        if result is None:
            # process_run returned nothing for this run
            continue

        exp_name = result['exp_name']
        env_name = result['env_name']
        type_id = result['type_id']
        seed = result['seed']
        data = result['data']

        # Stored as experiments_data[exp_name][type_id][env_name][seed].
        # setdefault creates the missing levels and leaves an already stored
        # seed untouched, so the first result for a given seed is the one kept.
        (experiments_data
            .setdefault(exp_name, {})
            .setdefault(type_id, {})
            .setdefault(env_name, {})
            .setdefault(seed, data))

Processing runs:   0%|          | 0/623 [00:00<?, ?it/s]

Skipping run DMCS-Ball-in-cup-v0__ngu_ppo__2 because it doesn't have the necessary data.
Skipping run Reacher-v4__icm_ppo__5 because it doesn't have the necessary data.
Skipping run Reacher-v4__icm_ppo__3 because it doesn't have the necessary data.
Skipping run Reacher-v4__icm_ppo__2 because it doesn't have the necessary data.
Skipping run Reacher-v4__icm_ppo__4 because it doesn't have the necessary data.
Skipping run Reacher-v4__icm_ppo__1 because it doesn't have the necessary data.
Skipping run Reacher-v4__rnd_ppo__5 because it doesn't have the necessary data.
Skipping run Reacher-v4__rnd_ppo__2 because it doesn't have the necessary data.
Skipping run Reacher-v4__rnd_ppo__4 because it doesn't have the necessary data.
Skipping run Reacher-v4__rnd_ppo__3 because it doesn't have the necessary data.
Skipping run Reacher-v4__rnd_ppo__1 because it doesn't have the necessary data.
Skipping run DMCS-Cart-k-Pole-v0__icm_ppo__1 because it doesn't have the necessary data.Skipping run HumanoidSt



Skipping run Reacher-v4__icm_ppo__3 because it doesn't have the necessary data.
Skipping run Reacher-v4__icm_ppo__5 because it doesn't have the necessary data.
Skipping run Reacher-v4__icm_ppo__2 because it doesn't have the necessary data.
Skipping run Reacher-v4__icm_ppo__4 because it doesn't have the necessary data.
Skipping run Reacher-v4__icm_ppo__1 because it doesn't have the necessary data.
Skipping run DMCS-Cart-k-Pole-v0__icm_ppo__1 because it doesn't have the necessary data.
Skipping run DMCS-Cart-k-Pole-v0__aux_ppo__1 because it doesn't have the necessary data.Skipping run Reacher-v4__aux_ppo__3 because it doesn't have the necessary data.
Skipping run Reacher-v4__aux_ppo__1 because it doesn't have the necessary data.
Skipping run Reacher-v4__aux_ppo__4 because it doesn't have the necessary data.
Skipping run Reacher-v4__aux_ppo__2 because it doesn't have the necessary data.
Skipping run Reacher-v4__aux_ppo__5 because it doesn't have the necessary data.




Skipping run Reacher-v4__apt_ppo__5 because it doesn't have the necessary data.Skipping run DMCS-Cart-k-Pole-v0__aux_ppo__1 because it doesn't have the necessary data.
Skipping run Reacher-v4__apt_ppo__4 because it doesn't have the necessary data.
Skipping run Reacher-v4__apt_ppo__1 because it doesn't have the necessary data.
Skipping run Reacher-v4__apt_ppo__2 because it doesn't have the necessary data.
Skipping run Reacher-v4__apt_ppo__3 because it doesn't have the necessary data.




Skipping run Reacher-v4__apt_ppo__5 because it doesn't have the necessary data.
Skipping run Reacher-v4__apt_ppo__4 because it doesn't have the necessary data.
Skipping run Reacher-v4__apt_ppo__1 because it doesn't have the necessary data.
Skipping run Reacher-v4__apt_ppo__2 because it doesn't have the necessary data.
Skipping run Reacher-v4__apt_ppo__3 because it doesn't have the necessary data.
Skipping run DMCS-Finger-v0__apt_ppo__2 because it doesn't have the necessary data.
Skipping run DMCS-Finger-v0__apt_ppo__1 because it doesn't have the necessary data.
Skipping run DMCS-Finger-v0__apt_ppo__3 because it doesn't have the necessary data.
Skipping run DMCS-Finger-v0__apt_ppo__4 because it doesn't have the necessary data.Skipping run DMCS-Finger-v0__apt_ppo__5 because it doesn't have the necessary data.
Skipping run DMCS-Ball-in-cup-v0__apt_ppo__2 because it doesn't have the necessary data.Skipping run DMCS-Ball-in-cup-v0__apt_ppo__5 because it doesn't have the necessary data.
Skip

#### Check data

In [30]:
# Display one downloaded time series as a sanity check: the episodic return
# of experiment 'apt_ppo', type 'mujoco', environment 'HalfCheetah-v3', seed 3.
# The recorded output of this cell contains NaN entries.
experiments_data['apt_ppo']['mujoco']['HalfCheetah-v3'][3]['history_metrics']['specific/episodic_return']
# Optional check (disabled): does the series contain any NaN?
# # check nan in data
# is_nan = experiments_data['apt_ppo']['mujoco']['HalfCheetah-v3'][3]['history_metrics']['specific/episodic_return'].isnull().values.any()
# print(is_nan)

0             NaN
1             NaN
2     -644.028198
3     -485.806061
4     -561.165405
          ...    
495           NaN
496   -481.006317
497           NaN
498   -460.569122
499   -488.707458
Name: specific/episodic_return, Length: 500, dtype: float64

### Learning Curve

#### Change name

In [14]:
def traductor_exp(exp_name:str):
    """Translate a raw experiment name into a short upper-case label.

    The name is split on "_". When the words contain "v1" (checked first) or
    "v2" together with "kl" (checked first) or "lipshitz", the label is
    "V1KL", "V1W", "V2KL" or "V2W". In every other case the label is the
    first word of the name in upper case, e.g. "rnd_ppo" -> "RND".

    A versioned name without "kl"/"lipshitz" used to raise UnboundLocalError;
    it now falls back to the first word as well.
    """
    list_words = exp_name.split("_")

    # Fallback label: the first word of the name
    name = list_words[0]

    if 'v1' in list_words:
        version = "V1"
    elif 'v2' in list_words:
        version = "V2"
    else:
        version = None

    if version is not None:
        if 'kl' in list_words:
            name = version + "KL"
        elif 'lipshitz' in list_words:
            name = version + "W"

    return name.upper()


##### List of algorithms, environment types and environments

In [20]:
# Algorithms: top-level keys of experiments_data
list_algos = list(experiments_data)
# Environment types recorded for the first algorithm
list_type = list(experiments_data[list_algos[0]])
# Environments of the first algorithm, gathered across all of its types
list_env = [
    env
    for family in list_type
    for env in experiments_data[list_algos[0]][family]
]

print(list_algos)
print(list_type)
print(list_env)

['ppo', 'rnd_ppo', 'ngu_ppo', 'icm_ppo', 'aux_ppo', 'apt_ppo']
['robotics', 'dmcs', 'mujoco']
['FetchSlide-v2', 'FetchReach-v1', 'FetchPush-v2', 'DMCS-Ball-in-cup-v0', 'DMCS-Cart-k-Pole-v0', 'DMCS-Acrobot-v0', 'DMCS-Finger-v0', 'DMCS-Fish-v0', 'HalfCheetah-v3', 'Hopper-v3', 'Ant-v3', 'Walker2d-v3', 'Humanoid-v3', 'HumanoidStandup-v4', 'Reacher-v4', 'Swimmer-v3']


In [None]:
import pandas as pd
import numpy as np
import os
from tabulate import tabulate

def concat_time_serie(experiments_data,
                        keys=['specific/coverage', 'specific/coverage_mu', 'specific/shanon_entropy', 'specific/shannon_entropy_mu'],
                        type_id_default = None,
                        env_name_default = None):
    """Collect the history series of the selected runs, one DataFrame per metric.

    Parameters
    ----------
    experiments_data : nested dict indexed as
        ``[exp_name][type_id][env_name][seed]``; each leaf must have a
        'history_metrics' mapping from metric name to its series.
    keys : metric names to collect.
    type_id_default : container of type ids to keep, or None to keep all.
    env_name_default : container of environment names to keep, or None to
        keep all.

    Returns
    -------
    dict
        Maps every entry of ``keys`` to a DataFrame with the columns
        'exp_name' (translated with ``traductor_exp``), 'env_name', 'seed'
        and the metric itself (one row per run, the cell holding the run's
        whole series). The DataFrame is empty when no selected run provides
        the metric.

    Notes
    -----
    A metric missing from a run used to raise KeyError, and so did a
    selection that matched no run. Missing metrics are now skipped and
    reported once through ``warnings.warn``.
    """
    import warnings

    # Pre-create every bucket so that a metric without data still yields an
    # (empty) DataFrame instead of a KeyError.
    key_data = {key: [] for key in keys}
    missing_keys = set()

    for exp_name, runs_by_type in experiments_data.items():
        for type_id, runs_by_env in runs_by_type.items():
            if type_id_default is not None and type_id not in type_id_default:
                continue
            for env_name, runs_by_seed in runs_by_env.items():
                if env_name_default is not None and env_name not in env_name_default:
                    continue
                for seed, run_data in runs_by_seed.items():
                    history_metrics = run_data['history_metrics']
                    for key in keys:
                        if key not in history_metrics:
                            missing_keys.add(key)
                            continue
                        key_data[key].append({
                            'exp_name': traductor_exp(exp_name),
                            'env_name': env_name,
                            'seed': seed,
                            key: history_metrics[key]
                        })

    if missing_keys:
        warnings.warn(
            "concat_time_serie: metric(s) absent from at least one run were skipped: "
            + ", ".join(sorted(missing_keys))
        )

    return {key: pd.DataFrame(key_data[key]) for key in keys}

# Build one DataFrame per metric, restricted to the environment type `type_id`
# and to the environments the first algorithm has for that type.
# NOTE(review): `type_id` is not assigned in this cell. In the visible part of
# the notebook it is only set inside the run-download loop, so it holds the
# type of the last run processed there - confirm this is the intended type.
key_df = concat_time_serie(
    experiments_data,
    type_id_default=[type_id],
    env_name_default=list(experiments_data[list_algos[0]][type_id].keys()),
)