In [1]:
import wandb
# Shared W&B public-API client; get_data() below queries runs through it.
api = wandb.Api()
import pandas as pd
import sys

# Raise pandas' display limits (columns, rows, line width) to 100 each so
# larger DataFrames are printed with less truncation in the notebook.
pd.set_option('display.max_columns', 100)
pd.set_option('display.max_rows', 100)
pd.set_option('display.width', 100)    

# Per-algorithm W&B metric keys and row offsets.
# Layout of each entry:
#   (train_key, reachable_key, unreachable_key,
#    reachable_relative_step, unreachable_relative_step)
# The relative steps are row offsets, counted from the history row holding the
# best training score, at which the evaluation values are read.
# NOTE(review): presumably the offsets differ because some algorithms log
# their evaluation metrics in separate history rows - confirm.
_ALGO_METRICS = {
    "CQL": (
        "Train Performance",
        "Reachable Performance",
        "Unreachable Performance",
        0, 0,
    ),
    "BC": (
        "Cumulative Reward.Train",
        "Cumulative Reward.Test_reachable",
        "Cumulative Reward.Test_unreachable",
        0, 0,
    ),
    "DT": (
        "Cumulative Reward.Train",
        "Cumulative Reward.Test_reachable",
        "Cumulative Reward.Test_unreachable",
        0, 0,
    ),
    "SUNRISE": (
        "train/reward_mean",
        "eval_reachable/reward_mean",
        "eval_unreachable/reward_mean",
        -2, -1,
    ),
    "IQL": (
        "Eval Reward",
        "Test Cumulative Reachable Reward",
        "Test Cumulative Unreachable Reward",
        0, 1,
    ),
    "SAC-N": (
        "train/reward_mean",
        "eval_reachable/reward_mean",
        "eval_unreachable/reward_mean",
        -2, -1,
    ),
}


def _value_at(column, position):
    """Return ``column`` at integer ``position``, or NaN when out of range."""
    # An explicit bounds check matters here: a negative position passed to
    # ``iloc`` would silently wrap around to the end of the history.
    if 0 <= position < len(column):
        return column.iloc[position]
    return float("nan")


def get_data(group, dataset):
    """Collect per-run scores for one algorithm group on one dataset.

    Every run of the W&B project ``gold-ai/ALL_MIXED`` whose config matches
    ``group`` and ``dataset`` is inspected. For each run the history row with
    the highest training score is located, and the reachable / unreachable
    evaluation scores are read at that algorithm's row offset from it.

    Args:
        group: Algorithm name matched against ``config.group``; must be a key
            of ``_ALGO_METRICS`` if any run matches.
        dataset: Dataset identifier matched against ``config.dataset``.

    Returns:
        Dict with the keys ``'train'``, ``'reachable'`` and ``'unreachable'``,
        each a list with one entry per run that had a usable history. An
        evaluation entry is NaN when its offset row lies outside the history.
        Runs with an empty history, missing metric columns or no training
        score at all are reported on stdout and skipped.

    Raises:
        ValueError: If matching runs exist but ``group`` has no entry in
            ``_ALGO_METRICS``.
    """
    performances = {'train': [], 'reachable': [], 'unreachable': []}
    spec = _ALGO_METRICS.get(group)

    # Project is specified by <entity/project-name>
    filtered_runs = api.runs(
        "gold-ai/ALL_MIXED",
        filters={"config.group": group, "config.dataset": dataset},
    )

    for run in filtered_runs:
        if spec is None:
            raise ValueError(
                f"Unknown group {group!r}; expected one of {sorted(_ALGO_METRICS)}"
            )
        (train_data, reachable_data, unreachable_data,
         reachable_relative_step, unreachable_relative_step) = spec

        history = run.history(samples=100000, pandas=True)

        # Check emptiness BEFORE selecting columns: an empty history has no
        # metric columns, so selecting them first would raise KeyError.
        if history.empty:
            print("Empty history for group: ", group, " and dataset: ", dataset)
            continue

        columns = ['_step', train_data, reachable_data, unreachable_data]
        missing = [name for name in columns if name not in history.columns]
        if missing:
            print("Skipping run ", run.name, " - missing metrics: ", missing)
            continue

        # Positional index so that the label returned by idxmax() is also a
        # valid position for iloc.
        history = history[columns].reset_index(drop=True)

        train_scores = history[train_data]
        if not train_scores.notna().any():
            print("Skipping run ", run.name, " - no training scores logged")
            continue

        best_row = int(train_scores.idxmax())
        performances['train'].append(train_scores.iloc[best_row])
        performances['reachable'].append(
            _value_at(history[reachable_data], best_row + reachable_relative_step)
        )
        performances['unreachable'].append(
            _value_at(history[unreachable_data], best_row + unreachable_relative_step)
        )

    return performances


In [3]:
# for group in tqdm(groups):
#     performances[group] = get_data(group, dataset)

### plot mixed

In [21]:
import matplotlib.pyplot as plt
import plotly.io as pio
import numpy as np
import plotly.graph_objects as go
import os

def _mean_and_stderr(values):
    """Return (mean, standard error) of ``values`` ignoring NaNs.

    Both results are NaN when no valid value is present, which avoids the
    division by zero (and NumPy warnings) of a plain nanstd / sqrt(0).
    """
    values = np.asarray(values, dtype=float)
    valid = values[~np.isnan(values)]
    if valid.size == 0:
        return float("nan"), float("nan")
    # Population standard deviation (ddof=0), matching np.nanstd's default.
    return valid.mean(), valid.std() / np.sqrt(valid.size)


def plot_mixed(performances, dataset, groups, title, wandb_name):
    """Draw and save a bar chart of train / reachable / unreachable rewards.

    For every algorithm in ``groups`` three bars are drawn (training,
    reachable, unreachable), each showing the NaN-ignoring mean over runs
    with the standard error as error bar.

    Args:
        performances: Mapping ``performances[group][metric]`` to a list of
            per-run values, with ``metric`` in ``'train'``, ``'reachable'``,
            ``'unreachable'``.
        dataset: Dataset identifier. Not used by the plot itself; kept for
            the (commented-out) W&B logging and for interface compatibility.
        groups: Algorithm names, in the order they appear on the x axis.
        title: Figure title (rendered bold).
        wandb_name: Base file name for the written image / JSON / HTML files.

    Side effects:
        Writes ``plots/<wandb_name>.png``, ``plots/json/<wandb_name>.json``
        and ``plots/html/<wandb_name>.html``, creating the directories if
        they do not exist.
    """
    # (metric key, legend label, bar colour, x shift in units of width/1.3)
    bar_specs = [
        ("train", "Training", "#1f77b4", -1),
        ("reachable", "Reachable", "red", 0),
        ("unreachable", "Unreachable", "orange", 1),
    ]

    x = np.arange(len(groups))  # the label locations
    width = 0.25  # the width of the bars

    fig = go.Figure()
    for metric, label, colour, shift in bar_specs:
        means = []
        stderrs = []
        for group in groups:
            mean, stderr = _mean_and_stderr(performances[group][metric])
            means.append(mean)
            stderrs.append(stderr)

        fig.add_trace(go.Bar(
            x=x + shift * width / 1.3,
            y=means,
            name=label,
            error_y=dict(type='data', array=stderrs),
            marker_color=colour,
        ))

    # Axis title fonts are set through title.font: the older `titlefont`
    # shortcut is deprecated and rejected by recent plotly releases.
    fig.update_layout(
        title=dict(text="<b>"+title+"</b>", x=0.5, xanchor='center', y=0.95, font=dict(size=16)),
        xaxis=dict(
            title=dict(text="Algorithm", font=dict(size=15)),
            tickvals=x,
            ticktext=groups,
            tickfont=dict(size=14),
        ),
        yaxis=dict(
            title=dict(text='Mean Reward', font=dict(size=15), standoff=1),
            tickfont=dict(size=14),
            range=[0, 1.01],
        ),
        barmode='overlay',
        bargap=0,  # gap between bars of adjacent location coordinates.
        bargroupgap=0.6,  # gap between bars of the same location coordinate.
        legend=dict(
            # title="<b>Environment</b>",
            orientation="h",
            yanchor="top",
            font=dict(size=12),
            y=1.1,
            xanchor="center",
            bgcolor='rgba(0,0,0,0)',
            x=0.5
        ),
        width=550, height=350,
        margin={'t': 60, 'l': 60, 'b': 0, 'r': 25}
    )

    # Log the plot to wandb
    # wandb.init(project="plots", entity="gold-ai", config={"group": "plots", "dataset": dataset}, name=wandb_name)
    # wandb.log({"plot": fig})
    # wandb.finish()

    # makedirs(exist_ok=True) also creates the parent "plots" directory and
    # does not fail if the directories already exist.
    os.makedirs("plots/json", exist_ok=True)
    os.makedirs("plots/html", exist_ok=True)

    fig.write_image(f"plots/{wandb_name}.png")
    pio.write_json(fig, f"plots/json/{wandb_name}.json")
    fig.write_html(f"plots/html/{wandb_name}.html")




In [23]:
from tqdm.notebook import trange, tqdm
# Algorithm groups to compare; each name is passed to get_data() as `group`.
groups = ["BC", "DT", "CQL", "IQL", "SUNRISE", "SAC-N"]
# Dataset identifiers to plot, one figure per entry. Each is a ", "-separated
# list of level codes; generate_title() maps the codes to readable labels.
datasets_interested_in = ['0, 25', '50', '0, 25, 50', '0, 25, 50, 75', '0', '0, 50', '25', '25, 50', '0, 75', '0, 0, 0, 0, 0', '0, 0, 0, 25, 50', '0, 0, 0, 50, 75', '0, 50, 100', '0, 25, 50, 75, 100']
# datasets_interested_in = ["0, 25, 50, 75, 100"]

def generate_title(dataset):
    """Turn a dataset identifier into a readable plot title.

    ``dataset`` is split on ``", "``; every recognised level code is replaced
    by its label and the labels are joined with ``+``. Codes without a label
    are dropped, so an identifier with no known code yields ``""``.
    """
    label_for_code = {
        "0": "Expert",
        "25": "Advanced",
        "50": "Medium",
        "75": "Basic",
        "100": "Random",
    }
    codes = dataset.split(", ")
    known_labels = (label_for_code[code] for code in codes if code in label_for_code)
    return "+".join(known_labels)

# Produce one figure per dataset: fetch the scores of every algorithm group
# via get_data(), then hand them to plot_mixed().
# Note: `performances`, `dataset` and `title` are notebook globals, so after
# the loop they still hold the values of the LAST dataset processed.
for dataset in tqdm(datasets_interested_in, desc="Datasets", position=0):
    tqdm.write(f"Dataset: {dataset}")
    # Rebuilt per dataset; shape is performances[group][metric] -> list of values.
    performances = {}
    title = generate_title(dataset)
    for group in tqdm(groups, desc="Groups", position=1, leave=False):
        tqdm.write(f"Group: {group}")
        performances[group] = get_data(group, dataset)
    # The output file name combines the raw identifier and the readable title.
    plot_mixed(performances, dataset, groups, title, f"{dataset}-{title}")


Datasets:   0%|          | 0/14 [00:00<?, ?it/s]

Dataset: 0, 25


Groups:   0%|          | 0/6 [00:00<?, ?it/s]

Group: BC
Group: DT
Group: CQL
Group: IQL
Group: SUNRISE
Group: SAC-N
Dataset: 50


Groups:   0%|          | 0/6 [00:00<?, ?it/s]

Group: BC
Group: DT
Group: CQL
Group: IQL
Group: SUNRISE
Group: SAC-N
Dataset: 0, 25, 50


Groups:   0%|          | 0/6 [00:00<?, ?it/s]

Group: BC
Group: DT
Group: CQL
Group: IQL
Group: SUNRISE
Group: SAC-N
Dataset: 0, 25, 50, 75


Groups:   0%|          | 0/6 [00:00<?, ?it/s]

Group: BC
Group: DT
Group: CQL
Group: IQL
Group: SUNRISE
Group: SAC-N
Dataset: 0


Groups:   0%|          | 0/6 [00:00<?, ?it/s]

Group: BC
Group: DT
Group: CQL
Group: IQL
Group: SUNRISE
Group: SAC-N
Dataset: 0, 50


Groups:   0%|          | 0/6 [00:00<?, ?it/s]

Group: BC
Group: DT
Group: CQL
Group: IQL
Group: SUNRISE
Group: SAC-N
Dataset: 25


Groups:   0%|          | 0/6 [00:00<?, ?it/s]

Group: BC
Group: DT
Group: CQL
Group: IQL
Group: SUNRISE
Group: SAC-N
Dataset: 25, 50


Groups:   0%|          | 0/6 [00:00<?, ?it/s]

Group: BC
Group: DT
Group: CQL
Group: IQL
Group: SUNRISE
Group: SAC-N
Dataset: 0, 75


Groups:   0%|          | 0/6 [00:00<?, ?it/s]

Group: BC
Group: DT
Group: CQL
Group: IQL
Group: SUNRISE
Group: SAC-N
Dataset: 0, 0, 0, 0, 0


Groups:   0%|          | 0/6 [00:00<?, ?it/s]

Group: BC
Group: DT
Group: CQL
Group: IQL
Group: SUNRISE
Group: SAC-N
Dataset: 0, 0, 0, 25, 50


Groups:   0%|          | 0/6 [00:00<?, ?it/s]

Group: BC
Group: DT
Group: CQL
Group: IQL
Group: SUNRISE
Group: SAC-N
Dataset: 0, 0, 0, 50, 75


Groups:   0%|          | 0/6 [00:00<?, ?it/s]

Group: BC
Group: DT
Group: CQL
Group: IQL
Group: SUNRISE
Group: SAC-N
Dataset: 0, 50, 100


Groups:   0%|          | 0/6 [00:00<?, ?it/s]

Group: BC
Group: DT
Group: CQL
Group: IQL
Group: SUNRISE
Group: SAC-N
Dataset: 0, 25, 50, 75, 100


Groups:   0%|          | 0/6 [00:00<?, ?it/s]

Group: BC
Group: DT
Group: CQL
Group: IQL
Group: SUNRISE
Group: SAC-N


### plot for presentation

In [59]:
import matplotlib.pyplot as plt
import plotly.io as pio
import numpy as np
import plotly.graph_objects as go
import os
from tqdm.notebook import tqdm
from plotly.colors import sample_colorscale, sequential

def plot_for_presentation(performances, datasets, groups, title):
    """Plot reachable vs. unreachable mean rewards for up to three datasets.

    Each dataset gets its own bar position and colour scale; within a
    dataset the reachable bar uses the lighter shade (0.25 on the scale)
    and the unreachable bar the darker one (0.75).

    Args:
        performances: Mapping ``performances[group][dataset][metric]`` to a
            list of per-run values, with ``metric`` in ``'reachable'``,
            ``'unreachable'``.
        datasets: Dataset identifiers to compare. Only three bar slots
            exist, so at most the first three entries are plotted.
        groups: Algorithm names, in the order they appear on the x axis.
        title: Figure title (rendered bold); also used as base file name.

    Side effects:
        Writes ``presentation_plots/<title>.png``,
        ``presentation_plots/json/<title>.json`` and
        ``presentation_plots/html/<title>.html`` (creating the directories
        if needed) and displays the figure.
    """
    x = np.arange(len(groups))  # the label locations
    width = 0.25  # the width of the bars

    # One (x offset, colour scale) pair per dataset slot.
    slots = [
        (-width / 1.3, sequential.Mint),
        (0, sequential.Purp),
        (width / 1.3, sequential.Peach),
    ]
    # (metric key, position on the colour scale)
    metric_shades = [("reachable", 0.25), ("unreachable", 0.75)]

    fig = go.Figure()
    for dataset, (offset, color_scale) in zip(datasets, slots):
        for metric, shade in metric_shades:
            means = []
            for group in groups:
                values = np.asarray(performances[group][dataset][metric], dtype=float)
                valid = values[~np.isnan(values)]
                # NaN (not a warning) when a group has no valid value.
                means.append(valid.mean() if valid.size else float("nan"))

            fig.add_trace(go.Bar(
                x=x + offset,
                y=means,
                # Legend entry is the dataset this bar belongs to. (The
                # previous code indexed a leaked global string instead.)
                name=f'{dataset}',
                marker_color=sample_colorscale(color_scale, shade)[0],
            ))

    # Axis title fonts are set through title.font: the older `titlefont`
    # shortcut is deprecated and rejected by recent plotly releases.
    fig.update_layout(
        title=dict(text="<b>"+title+"</b>", x=0.8, y=0.95, font=dict(size=16)),
        xaxis=dict(
            title=dict(text="Algorithm", font=dict(size=15)),
            tickvals=x,
            ticktext=groups,
            tickfont=dict(size=14),
        ),
        yaxis=dict(
            title=dict(text='Mean Reward', font=dict(size=15), standoff=1),
            tickfont=dict(size=14),
        ),
        barmode='overlay',
        bargap=0,  # gap between bars of adjacent location coordinates.
        bargroupgap=0.6,  # gap between bars of the same location coordinate.
        # legend=dict(
        #     # title="<b>Environment</b>",
        #     orientation="h",
        #     yanchor="top",
        #     font=dict(size=12),
        #     y=1.1,
        #     xanchor="center",
        #     bgcolor='rgba(0,0,0,0)',
        #     x=0.5
        # ),
        width=550, height=350,
        margin={'t': 60, 'l': 60, 'b': 0, 'r': 25}
    )

    # Log the plot to wandb
    # wandb.init(project="plots", entity="gold-ai", config={"group": "plots", "dataset": dataset}, name=wandb_name)
    # wandb.log({"plot": fig})
    # wandb.finish()

    # makedirs(exist_ok=True) also creates the parent directory and does not
    # fail if the directories already exist.
    os.makedirs("presentation_plots/json", exist_ok=True)
    os.makedirs("presentation_plots/html", exist_ok=True)

    fig.write_image(f"presentation_plots/{title}.png")
    pio.write_json(fig, f"presentation_plots/json/{title}.json")
    fig.write_html(f"presentation_plots/html/{title}.html")

    fig.show()



In [51]:
# groups = ["BC", "DT", "CQL", "IQL", "SUNRISE", "SAC-N"]
# datasets = ["0", "0, 50", "0, 25, 50, 75, 100"]
# performances = {}
# for group in tqdm(groups, desc="Groups", position=0):
#     performances[group] = {}
#     for dataset in tqdm(datasets, desc="Datasets", position=1):
#         performances[group][dataset] = {}
#         data = get_data(group, dataset)
#         performances[group][dataset]["reachable"] = data["reachable"]
#         performances[group][dataset]["unreachable"] = data["unreachable"]


Groups:   0%|          | 0/6 [00:00<?, ?it/s]

Datasets:   0%|          | 0/3 [00:00<?, ?it/s]

Datasets:   0%|          | 0/3 [00:00<?, ?it/s]

Datasets:   0%|          | 0/3 [00:00<?, ?it/s]

Datasets:   0%|          | 0/3 [00:00<?, ?it/s]

Datasets:   0%|          | 0/3 [00:00<?, ?it/s]

Datasets:   0%|          | 0/3 [00:00<?, ?it/s]

In [60]:
# NOTE(review): plot_for_presentation indexes performances[group][dataset][metric],
# i.e. the nested structure built by the commented-out cell above together
# with `datasets`. The plot_mixed driver loop leaves `performances` in a
# different shape (performances[group][metric]) - confirm which cell ran last
# before re-executing this one.
plot_for_presentation(performances, datasets, groups, "Expert vs Medium vs All")

### read figures and change name

In [13]:
# import plotly.graph_objs as go
# import plotly.io as pio

# fig = pio.read_json('plots/json/0, 25, 50, 75, 100.json')
# fig.update_layout(
#     title='Updated Title',
#     xaxis=dict(title='Updated X-axis Title'),
#     yaxis=dict(title='Updated Y-axis Title')
# )
# fig.show()

### plot expert

In [294]:
# Stand-alone cell: bar chart of train / reachable / unreachable rewards per
# algorithm, built from the notebook global `performances`
# (shape: performances[group][metric] -> list of per-run values).
# NOTE(review): the title says "Expert Dataset", but `performances` holds
# whatever dataset was loaded last by an earlier cell - confirm before use.
wandb.finish()
import plotly.graph_objs as go
import plotly.io as pio
import numpy as np

groups = ["BC", "DT", "CQL", "IQL", "SUNRISE", "SAC-N"]
metrics = ['train', 'reachable', 'unreachable']
title = "Performance on the Expert Dataset"

mean_values = {metric: [] for metric in metrics}
stderr_values = {metric: [] for metric in metrics}

for group in groups:
    for metric in metrics:
        values = np.asarray(performances[group][metric], dtype=float)
        valid = values[~np.isnan(values)]
        if valid.size:
            mean = valid.mean()
            # Population std (ddof=0), as np.nanstd computes by default.
            stderr = valid.std() / np.sqrt(valid.size)
        else:
            # No valid run: report NaN instead of dividing by zero.
            mean = stderr = float("nan")
        mean_values[metric].append(mean)
        stderr_values[metric].append(stderr)

# Create bars with error bars using Plotly
fig = go.Figure()

# Define patterns for each metric
# (only used by the commented-out "paper style" variant below)
metric_patterns = {
    'train': '',
    'reachable': 'x',
    'unreachable': '.'
}

# Define colors for each group
# (only used by the commented-out "paper style" variant below)
group_colors = {
    'CQL': 'blue',
    'DT': 'green',
    'BC': 'purple',
    'IQL': 'orange',
    'SUNRISE': 'red',
    'SAC-N': 'cyan'
}

x = np.arange(len(groups))  # the label locations
width = 0.25  # the width of the bars

# (metric key, legend label, bar colour, x shift in units of width/1.3)
for metric, label, colour, shift in [
    ('train', 'Training', '#1f77b4', -1),
    ('reachable', 'Reachable', 'red', 0),
    ('unreachable', 'Unreachable', 'orange', 1),
]:
    fig.add_trace(go.Bar(
        x=x + shift * width / 1.3,
        y=mean_values[metric],
        name=label,
        error_y=dict(type='data', array=stderr_values[metric]),
        marker_color=colour
    ))

## for doing bar plots like in the paper
# show_legend = True
# for i, group in enumerate(groups):
#     fig.add_trace(go.Bar(
#         x=[x[i] - width/1.7], 
#         y=[mean_values['train'][i]], 
#         name=f'Train' if show_legend else None, 
#         error_y=dict(type='data', array=[stderr_values['train'][i]]),
#         marker_color=group_colors[group],
#         marker_pattern_shape=metric_patterns['train'],
#         showlegend=show_legend
#     ))
#     fig.add_trace(go.Bar(
#         x=[x[i]], 
#         y=[mean_values['reachable'][i]], 
#         name=f'Reachable' if show_legend else None, 
#         error_y=dict(type='data', array=[stderr_values['reachable'][i]]),
#         marker_color=group_colors[group],
#         marker_pattern_shape=metric_patterns['reachable'],
#         showlegend=show_legend
#     ))
#     fig.add_trace(go.Bar(
#         x=[x[i] + width/1.7], 
#         y=[mean_values['unreachable'][i]], 
#         name=f'Unreachable' if show_legend else None, 
#         error_y=dict(type='data', array=[stderr_values['unreachable'][i]]),
#         marker_color=group_colors[group],
#         marker_pattern_shape=metric_patterns['unreachable'],
#         showlegend=show_legend
#     ))

#     show_legend = False

# Update layout
# Axis title fonts are set through title.font: the older `titlefont` shortcut
# is deprecated and rejected by recent plotly releases.
fig.update_layout(
    title=dict(text="<b>"+title+"</b>", x=0.8, y=0.95, font=dict(size=16)),
    xaxis=dict(
        title=dict(text="Algorithm", font=dict(size=15)),
        tickvals=x,
        ticktext=groups,
        tickfont=dict(size=14)
    ),
    yaxis=dict(
        title=dict(text='Mean Reward', font=dict(size=15), standoff=1),
        tickfont=dict(size=14)
    ),
    barmode='overlay',
    bargap=0, # gap between bars of adjacent location coordinates.
    bargroupgap= 0.6, # gap between bars of the same location coordinate.
    legend=dict(
        # title="<b>Environment</b>",
        orientation="h",
        yanchor="top",
        font=dict(size=12),
        y=1.1,
        xanchor="center",
        bgcolor='rgba(0,0,0,0)',
        x=0.5
    ),
    width=550, height=350,
    margin={'t':60,'l':60,'b':0,'r':25}
)

# Log the plot to wandb
# wandb.init(project="plots", entity="gold-ai", config={"group": "plots", "dataset": dataset})
# wandb.log({"plot": fig})
# wandb.finish()

# Display the plot
fig.show()

### Orphaned fragment of unknown origin (appears to be the tail of a `fig.update_layout(...)` call; incomplete and not runnable as-is)

In [None]:
x=0.05),
    barmode='overlay',
    bargap=0, # gap between bars of adjacent location coordinates.
    bargroupgap= 0.6, # gap between bars of the same location coordinate.
    legend=dict(
        # title="<b>Environment</b>",
        orientation="h",
        yanchor="top",
        font=dict(size=11),
        y=1.1,
        xanchor="center",
        bgcolor='rgba(0,0,0,0)',
        x=0.5
    ),
    width=550, height=350,
    margin={'t':60,'l':20,'b':7,'r':25}
)

# Log the plot to wandb
# wandb.init(project="plots", entity="gold-ai", config={"group": "plots", "dataset": dataset})
# wandb.log({"plot": fig})
# wandb.finish()

# Display the plot
fig.show()

### Plotting with matplotlib (does not integrate well with wandb)

In [None]:
# from enum import Enum
# import matplotlib.pyplot as plt
# import numpy as np
# wandb.finish()

# class Group(float, Enum):
#     CQL = 1
#     DT = 2
#     BC = 3
#     IQL = 4
#     SUNRISE = 5
#     SAC_N = 6

# group_enum = {group: Group[group] for group in groups}

# metrics = ['train', 'reachable', 'unreachable']

# mean_values = {metric: [] for metric in metrics}
# stderr_values = {metric: [] for metric in metrics}

# for group in groups:
#     for metric in metrics:
#         values = np.array(performances[group][metric])
#         mean = np.nanmean(values)
#         stderr = np.nanstd(values) / np.sqrt(len(values[~np.isnan(values)]))
#         mean_values[metric].append(mean)
#         stderr_values[metric].append(stderr)

# x = np.arange(len(groups))  # the label locations
# width = 0.25  # the width of the bars

# fig, ax = plt.subplots(figsize=(7, 4))

# # Create bars with error bars
# rects1 = ax.bar(x - width/3, mean_values['train'], width, yerr=stderr_values['train'], label='Train', color='cyan')
# rects2 = ax.bar(x, mean_values['reachable'], width, yerr=stderr_values['reachable'], label='Reachable', color='red')
# rects3 = ax.bar(x + width/3, mean_values['unreachable'], width, yerr=stderr_values['unreachable'], label='Unreachable', color='orange')

# # Add some text for labels, title and custom x-axis tick labels, etc.
# ax.set_xlabel('Group')
# ax.set_ylabel('Scores')
# ax.set_title('Performance Metrics by Group')
# ax.set_xticks(x)
# ax.set_xticklabels(group_enum)
# ax.legend()

# # Function to add labels on bars
# def add_labels(rects):
#     for rect in rects:
#         height = rect.get_height()
#         if not np.isnan(height):
#             ax.annotate(f'{height:.2f}',
#                         xy=(rect.get_x() + rect.get_width() / 2, height),
#                         xytext=(15, 0),  # 3 points vertical offset
#                         textcoords="offset points",
#                         ha='center', va='bottom')

# add_labels(rects1)
# add_labels(rects2)
# add_labels(rects3)

# fig.tight_layout()
# plt.show()

# wandb.init(project="plots", entity="gold-ai", config={"group": "plots", "dataset": dataset})
# wandb.log({"plot": plt})

# wandb.finish()