# I. Loading/Plot functions

In [3]:
DATASETS = {
    'a': 'dogs',
    'b': 'medical-leaf',
    'c': 'texture-dtd',
    'd': 'birds',
    'e': 'AWA',
    'f': 'plt-net',
    'g': 'resisc',
    'h': 'plt-doc',
    'i': 'airplanes',
    '_': 'ALL'
}

color_dict = {
        'erm': 'blue',
        'jtt': 'red',
        'suby': '#00CC96',  # green
        'subg': '#ff7f0e',  # orange
        'rwy': '#00CC96',
        'rwg': '#ff7f0e',
        'dro': '#DEA0FD'  # purple
    }
    

In [4]:
import wandb
import pandas as pd
import plotly.graph_objects as go
import numpy as np

# Initialize wandb
wandb.login()

# Set your entity and project
entity_name = "aureliengauffre"  # e.g., your username or team name
project_name = "SMA_all_2_best"

# Initialize the wandb API ||client
api = wandb.Api()

# Fetch all runs from the specified project
runs = api.runs(f"{entity_name}/{project_name}")

# Create an empty list to hold data for each run
data = []

# Loop through runs and ext/ract the data you're interested in
for run in runs:
    # Extract both summary metrics and config (hyperparameters) for each run
    run_data = {
        "name": run.name,
        "summary_metrics": run.summary._json_dict,
        "config": run.config,
        # Add any other attributes you're interested in here
    }
    data.append(run_data)

# Convert the list of data to a pandas DataFrame
df = pd.DataFrame(data)

# For summary metrics and config (hyperparameters), expand them into separate columns
df_summary = pd.json_normalize(df['summary_metrics'])
df_config = pd.json_normalize(df['config'])
df = pd.concat([df.drop(['summary_metrics', 'config'], axis=1), df_summary, df_config], axis=1)
df = df.iloc[:, 1:]  # Drop the first column containing the name of the runs (since there is also another name column)

# Now we have a DataFrame `df` with all runs, their summary metrics, and hyperparameters
# print(df.head())  # Print the first few rows of the DataFrame


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: aureliengauffre (use `wandb login --relogin` to force relogin)


In [5]:
df['name']

0            plt-doc
1            plt-doc
2            plt-net
3            plt-net
4            plt-net
            ...     
3561       airplanes
3562       airplanes
3563    medical-leaf
3564       airplanes
3565       airplanes
Name: name, Length: 3566, dtype: object

In [6]:
# Initialize a Plotly figure
def plot_graph_old(df, x_axis, y_axis, dataset_name, methods=None):
    # Default methods if not provided
    
    if methods is None:
        methods = ['erm', 'jtt', 'suby', 'subg', 'rwy', 'rwg', 'dro']
    dashed_methods = ['rwy', 'rwg', 'dro']
    
    if dataset_name is not None:
        df_dataset = df[(df['name'] == dataset_name) ]#& (df['mu'] >= .1) ]
    else :
        df_dataset = df 
    fig = go.Figure() # Initialize a Plotly figure
    for method in methods:
        df_method = df_dataset[df_dataset['method'] == method]
        # Group by x_axis and calculate the mean of y_axis
        df_avg = df_method.groupby(x_axis)[y_axis].mean().reset_index()
        
        # Determine line style based on whether method is in dashed_methods
        line_style = 'dash' if method in dashed_methods else 'solid'
        
        # Add a line to the plot for the current method
        color = color_dict.get(method, 'grey')
        fig.add_trace(go.Scatter(x=df_avg[x_axis], y=df_avg[y_axis], mode='lines+markers',
                                 name=method, line=dict(dash=line_style, color=color)))
    
    # Update the layout
    fig.update_layout(title=f'{y_axis} vs {x_axis} , dataset={dataset_name}',
                      xaxis_title=x_axis,
                      yaxis_title=y_axis,
                      legend_title='Method',
                      width=800,  # Width of the figure in pixels
                      height=600   # Height of the figure in pixels
                     )

    # Show the figure
    fig.show()

# Example usage (assuming df is your DataFrame):
# plot_graph(df, 'K', 'mean_gr


dataset_name = 'AWA' #73sports
y_axis = 'mean_grp_acc_te' #'mean_grp_acc_te'
x_axis = 'K'
plot_graph(df, x_axis, y_axis, dataset_name, methods=None)


In [62]:
import plotly.graph_objects as go


def normalize_df(df,metric = 'best_acc_te'):
    # Copy the DataFrame to avoid modifying the original data
    result_df = df.copy()
    
    # Group by 'name', 'K', 'mu', and 'init_seed' and calculate the minimum and maximum accuracy
    grouped = df.groupby(['name', 'K', 'mu', 'init_seed'])[metric]
    min_acc = grouped.transform(np.min)
    max_acc = grouped.transform(np.max)
    
    # Apply the Min-Max normalization formula
    result_df[metric] = (df[metric] - min_acc) / (max_acc - min_acc)
    
    return result_df




def plot_graph(df, x_axis, y_axis, dataset_name, error_bars=None, normalize = False, methods=None,):
    print(dataset_name)
    if methods is None or methods == 'ALL':
        methods = ['erm', 'jtt', 'suby', 'subg', 'rwy', 'rwg', 'dro']
    dashed_methods = ['rwy', 'rwg', 'dro']
    color_dict = {
            'erm': 'blue',
            'jtt': 'red',
            'suby': '#00CC96',  # green
            'subg': '#ff7f0e',  # orange
            'rwy': '#00CC96',
            'rwg': '#ff7f0e',
            'dro': '#DEA0FD'  # purple
        }
    if dataset_name is None or dataset_name == 'ALL': 
        df_dataset = df
    else:
        df_dataset = df[(df['name'] == dataset_name)]
    
    if normalize :
        df_dataset = normalize_df(df_dataset,metric=y_axis)
        
    fig = go.Figure()  # Initialize a Plotly figure
    for method in methods:
        df_method = df_dataset[df_dataset['method'] == method]
        # Group by x_axis and calculate the mean, standard deviation, and count (for standard error calculation)
        stats = df_method.groupby(x_axis)[y_axis].agg(['mean', 'std', 'count']).reset_index()
        
        # Calculate standard error (SEM)
        stats['sem'] = stats['std'] / np.sqrt(stats['count'])
        #print(method, stats) # print the number on which we average, interesting !
        # Determine line style based on whether method is in dashed_methods
        line_style = 'dash' if method in dashed_methods else 'solid'
        
        # Add a line with error bars to the plot for the current method
        color = color_dict.get(method, 'grey')
        if error_bars :
            fig.add_trace(go.Scatter(x=stats[x_axis], y=stats['mean'], mode='lines+markers',
                                 name=method, line=dict(dash=line_style, color=color),
                                 error_y=dict(type='data', array=stats['sem'], visible=True)))
        else :
            fig.add_trace(go.Scatter(x=stats[x_axis], y=stats['mean'], mode='lines+markers',
                                 name=method, line=dict(dash=line_style, color=color)))
    
    # Update the layout
    fig.update_layout(title=f'{y_axis} vs {x_axis}, dataset={dataset_name}',
                      xaxis_title=x_axis,
                      yaxis_title=y_axis,
                      legend_title='Method',
                      width=800,  # Width of the figure in pixels
                      height=800  # Height of the figure in pixels
                     )

    # Show the figure
    
    fig.show()


def plot_graph_all(df, x_axis, y_axis, error_bars=None, normalize=False, methods=None):
    """Unlike the original plot_graph function which calculates and plots error bars
    based on individual datasets directly, this version computes the standard error within each dataset first 
    and then averages these errors across all datasets for each method. This approach provides a generalized 
    view of method performance and variability across different datasets."""
    
    if methods is None or methods == 'ALL':
        methods = ['erm', 'jtt', 'suby', 'subg', 'rwy', 'rwg', 'dro']
    dashed_methods = ['rwy', 'rwg', 'dro']
    
    if normalize:
        df = normalize_df(df, metric=y_axis)

    fig = go.Figure()  # Initialize a Plotly figure
    all_stats = pd.DataFrame()

    # Process each method separately
    for method in methods:
        df_method = df[df['method'] == method]

        # Group data first by 'name' and then by x_axis and compute statistics
        grouped = df_method.groupby(['name', x_axis])
        stats = grouped[y_axis].agg(['mean', 'std', 'count']).reset_index()

        # Calculate standard error within each dataset
        stats['sem'] = stats['std'] / np.sqrt(stats['count'])
        
        # Append results to the all_stats DataFrame for later display
        stats['method'] = method
        all_stats = pd.concat([all_stats, stats], ignore_index=True)
        
        # Now group by x_axis and calculate the mean of the means and the mean of the SEMs
        final_stats = stats.groupby(x_axis).agg({'mean': 'mean', 'sem': 'mean'}).reset_index()

        # Determine line style based on whether method is in dashed_methods
        line_style = 'dash' if method in dashed_methods else 'solid'
        
        # Add a line with error bars to the plot for the current method
        color = color_dict.get(method, 'grey')
        if error_bars:
            fig.add_trace(go.Scatter(x=final_stats[x_axis], y=final_stats['mean'], mode='lines+markers',
                                     name=method, line=dict(dash=line_style, color=color),
                                     error_y=dict(type='data', array=final_stats['sem'], visible=True)))
        else:
            fig.add_trace(go.Scatter(x=final_stats[x_axis], y=final_stats['mean'], mode='lines+markers',
                                     name=method, line=dict(dash=line_style, color=color)))

    # Update the layout
    fig.update_layout(title=f'{y_axis} vs {x_axis}, Average on all dataset (mean and std)',
                      xaxis_title=x_axis,
                      yaxis_title=y_axis,
                      legend_title='Method',
                      width=800,  # Width of the figure in pixels
                      height=800  # Height of the figure in pixels
                     )

    # Show the figure
    fig.show()

    # Display the all_stats DataFrame
    return all_stats.pivot_table(index=[x_axis, 'name'], columns='method', values=['count', 'sem'], aggfunc='first')



# II. Analysis : K

Here are the **plot parameters** to be played with :

In [63]:
ERROR_BARS = True # Wether to plot the error bars
NORMALIZE = False

To analyse the impact of K, we the **value of mu is fixed**:

In [72]:
df_fix_mu = df[df['mu']==.05] # Currently mu in [.05,.1,.2,.4]

df_fix_mu_K2 = df[df['K']==2] 
df_fix_mu_K4 = df[df['K']==4] 
df_fix_mu_K8 = df[df['K']==8] 
df_fix_mu_K12 = df[df['K']==12] 

## a. best_acc_te

In [73]:
x_axis = 'K'
y_axis = 'best_acc_te' #'mean_grp_acc_te'
plot_graph_all(df_fix_mu, x_axis, y_axis,error_bars=ERROR_BARS, normalize=NORMALIZE)


Unnamed: 0_level_0,Unnamed: 1_level_0,count,count,count,count,count,count,count,sem,sem,sem,sem,sem,sem,sem
Unnamed: 0_level_1,method,dro,erm,jtt,rwg,rwy,subg,suby,dro,erm,jtt,rwg,rwy,subg,suby
K,name,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2
2,AWA,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.004747,0.004054,0.007132,0.006483,0.01045,0.001766,0.005871
2,airplanes,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.026757,0.022423,0.031813,0.033598,0.013061,0.010423,0.010402
2,birds,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.009868,0.008473,0.015478,0.010897,0.009312,0.009032,0.017364
2,dogs,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.035725,0.018371,0.020373,0.018918,0.007196,0.017807,0.015543
2,medical-leaf,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.016197,0.01392,0.031995,0.009443,0.01464,0.036265,0.029514
2,plt-doc,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.030264,0.016465,0.024901,0.016794,0.007456,0.01677,0.010886
2,plt-net,5.0,5.0,10.0,5.0,5.0,10.0,10.0,0.010854,0.001717,0.006165,0.009055,0.012012,0.003042,0.004424
2,resisc,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.002964,0.013566,0.007819,0.017274,0.007656,0.022949,0.010064
2,texture-dtd,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.011744,0.01927,0.01044,0.012917,0.00916,0.045728,0.021916
4,AWA,5.0,5.0,5.0,5.0,5.0,10.0,5.0,0.00319,0.003885,0.005705,0.005085,0.005435,0.004109,0.002976


In [70]:
NORMALIZE

False

In [74]:
x_axis = 'K'
y_axis = 'best_acc_te' #'mean_grp_acc_te'
for dataset_name in DATASETS.values():
    plot_graph(df_fix_mu, x_axis, y_axis, dataset_name,error_bars=ERROR_BARS, normalize=NORMALIZE)

dogs


medical-leaf


texture-dtd


birds


AWA


plt-net


resisc


plt-doc


airplanes


ALL


## b. worst_acc_te

In [75]:
x_axis = 'K'
y_axis = 'worst_grp_acc_te' #'mean_grp_acc_te'
plot_graph_all(df_fix_mu, x_axis, y_axis,error_bars=False, normalize=NORMALIZE)

Unnamed: 0_level_0,Unnamed: 1_level_0,count,count,count,count,count,count,count,sem,sem,sem,sem,sem,sem,sem
Unnamed: 0_level_1,method,dro,erm,jtt,rwg,rwy,subg,suby,dro,erm,jtt,rwg,rwy,subg,suby
K,name,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2
2,AWA,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.285351,7.81577,4.474268,8.141688,4.872016,1.386377,6.253393
2,airplanes,5.0,5.0,5.0,5.0,5.0,5.0,5.0,9.015927,4.307786,0.927458,6.662022,5.712818,4.630114,1.496619
2,birds,5.0,5.0,5.0,5.0,5.0,5.0,5.0,5.909254,14.278238,2.968569,10.256449,7.664581,2.671972,5.154132
2,dogs,5.0,5.0,5.0,5.0,5.0,5.0,5.0,1.269381,2.21881,1.538159,2.567449,1.059739,4.139578,1.644408
2,medical-leaf,5.0,5.0,5.0,5.0,5.0,5.0,5.0,1.176471,1.176471,1.440876,1.440876,2.352941,5.3464,4.822336
2,plt-doc,5.0,5.0,5.0,5.0,5.0,5.0,5.0,2.113973,0.827737,2.847008,5.062968,2.510024,2.963089,0.437409
2,plt-net,5.0,5.0,10.0,5.0,5.0,10.0,10.0,2.684134,0.995766,2.001805,2.996579,5.101449,1.226934,1.131341
2,resisc,5.0,5.0,5.0,5.0,5.0,5.0,5.0,1.452186,1.26527,1.385051,0.968904,2.220294,7.880341,2.270785
2,texture-dtd,5.0,5.0,5.0,5.0,5.0,5.0,5.0,8.888889,3.664141,3.685139,7.954345,6.04765,5.241101,2.421611
4,AWA,5.0,5.0,5.0,5.0,5.0,10.0,5.0,4.160248,3.901398,6.215691,4.692563,3.96046,1.488967,4.094905


In [13]:
x_axis = 'K'
y_axis = 'worst_grp_acc_te' #'mean_grp_acc_te'
for dataset_name in DATASETS.values():
    plot_graph(df_fix_mu, x_axis, y_axis, dataset_name,error_bars=ERROR_BARS, methods=None)

dogs


medical-leaf


texture-dtd


birds


AWA


plt-net


resisc


plt-doc


airplanes


ALL


## c. relative_acc

In [76]:
x_axis = 'K'
y_axis = 'relative_grp_acc_te' #'mean_grp_acc_te'
plot_graph_all(df_fix_mu, x_axis, y_axis,error_bars=False, normalize=NORMALIZE)

Unnamed: 0_level_0,Unnamed: 1_level_0,count,count,count,count,count,count,count,sem,sem,sem,sem,sem,sem,sem
Unnamed: 0_level_1,method,dro,erm,jtt,rwg,rwy,subg,suby,dro,erm,jtt,rwg,rwy,subg,suby
K,name,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2
2,AWA,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.023009,0.032909,0.019411,0.038138,0.025972,0.012687,0.027103
2,airplanes,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.053216,0.02826,0.007513,0.077661,0.047595,0.084258,0.026313
2,birds,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.023893,0.081425,0.008835,0.048309,0.038195,0.017649,0.034727
2,dogs,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.016514,0.011094,0.01787,0.011521,0.014715,0.067297,0.015656
2,medical-leaf,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.019066,0.012334,0.044998,0.018169,0.028794,0.066516,0.044626
2,plt-doc,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.027198,0.035198,0.034602,0.062536,0.052096,0.108599,0.023398
2,plt-net,5.0,5.0,10.0,5.0,5.0,10.0,10.0,0.009724,0.005037,0.009984,0.012773,0.038393,0.018111,0.009785
2,resisc,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.020655,0.035743,0.091196,0.039923,0.014712,0.109069,0.01983
2,texture-dtd,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.036957,0.024296,0.019642,0.042898,0.02686,0.111077,0.029593
4,AWA,5.0,5.0,5.0,5.0,5.0,10.0,5.0,0.007931,0.009687,0.006736,0.006937,0.01113,0.004982,0.005812


In [14]:
x_axis = 'K'
y_axis = 'relative_grp_acc_te' #'mean_grp_acc_te'
for dataset_name in DATASETS.values():
    plot_graph(df_fix_mu, x_axis, y_axis, dataset_name,error_bars=ERROR_BARS, methods=None)

dogs


medical-leaf


texture-dtd


birds


AWA


plt-net


resisc


plt-doc


airplanes


ALL


## d. minor_group_acc

In [15]:
x_axis = 'K'
y_axis = 'minor_grp_acc_te' #'mean_grp_acc_te'
for dataset_name in DATASETS.values():
    plot_graph(df_fix_mu, x_axis, y_axis, dataset_name,error_bars=ERROR_BARS, methods=None)

dogs


medical-leaf


texture-dtd


birds


AWA


plt-net


resisc


plt-doc


airplanes


ALL


# III. Analysis : mu

Here are the **plot parameters** to be played with :

In [16]:
ERROR_BARS = True #Wether to plot the error bars
NORMALIZE = False

To analyse the impact of mu, we the **value of K is fixed**:

In [17]:
df_fix_K = df[df['K']==4] # Currently K in [2,4,8,12]


## a. best_acc_te

In [18]:
x_axis = 'mu'
y_axis = 'best_acc_te' #'mean_grp_acc_te'
for dataset_name in DATASETS.values():
    plot_graph(df_fix_K, x_axis, y_axis, dataset_name,error_bars=ERROR_BARS, normalize=NORMALIZE)

dogs


medical-leaf


texture-dtd


birds


AWA


plt-net


resisc


plt-doc


airplanes


ALL


## b. worst_acc_te

In [19]:
x_axis = 'K'
y_axis = 'worst_grp_acc_te' #'mean_grp_acc_te'
plot_graph_all(df_fix_mu, x_axis, y_axis,error_bars=ERROR_BARS, normalize=NORMALIZE)

Unnamed: 0_level_0,Unnamed: 1_level_0,count,count,count,count,count,count,count,sem,sem,sem,sem,sem,sem,sem
Unnamed: 0_level_1,method,dro,erm,jtt,rwg,rwy,subg,suby,dro,erm,jtt,rwg,rwy,subg,suby
K,name,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2
2,AWA,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.092018,0.132032,0.127909,0.109882,0.094693,0.0,0.126117
2,airplanes,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.153852,0.07589,0.013539,0.105813,0.091734,0.080541,0.0
2,birds,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.148002,0.181808,0.132811,0.194167,0.178965,0.147394,0.140133
2,dogs,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.024087,0.03069,0.025268,0.049911,0.027393,0.0,0.016629
2,medical-leaf,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.028583,0.01447,0.034319,0.030321,0.036364,0.0,0.104876
2,plt-doc,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.035873,0.026982,0.051883,0.078977,0.035866,0.0,0.007857
2,plt-net,5.0,5.0,10.0,5.0,5.0,10.0,10.0,0.057312,0.026696,0.032402,0.058839,0.112816,0.020325,0.022271
2,resisc,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.126936,0.02956,0.043077,0.142499,0.112168,0.1785,0.09784
2,texture-dtd,5.0,5.0,5.0,5.0,5.0,5.0,5.0,0.139816,0.055556,0.088696,0.113709,0.079679,0.112279,0.036679
4,AWA,5.0,5.0,5.0,5.0,5.0,10.0,5.0,0.108243,0.077551,0.086022,0.117733,0.065932,0.021874,0.069323


In [20]:
x_axis = 'mu'
y_axis = 'worst_grp_acc_te' #'mean_grp_acc_te'
for dataset_name in DATASETS.values():
    plot_graph(df_fix_K, x_axis, y_axis, dataset_name,error_bars=ERROR_BARS, methods=None)

dogs


medical-leaf


texture-dtd


birds


AWA


plt-net


resisc


plt-doc


airplanes


ALL


## c. relative_acc

In [21]:
x_axis = 'mu'
y_axis = 'relative_grp_acc_te' #'mean_grp_acc_te'
for dataset_name in DATASETS.values():
    plot_graph(df_fix_K, x_axis, y_axis, dataset_name,error_bars=ERROR_BARS, methods=None)

dogs


medical-leaf


texture-dtd


birds


AWA


plt-net


resisc


plt-doc


airplanes


ALL


# IV. Examples 

## a. Dataset difficulty :

In [22]:
def plot_violin(df, metric, category):
    """
    Creates and displays a violin plot for the given DataFrame.
    Parameters:
    - df: pandas DataFrame containing the data to plot.
    - metric: str, the name of the column in df containing the values to plot.
    - category: str, the name of the column in df representing different categories to separate the violins.

    """
    fig = go.Figure()
    categories = df[category].unique()
    for cat in categories:
        cat_data = df[df[category] == cat][metric]
        fig.add_trace(go.Violin(y=cat_data, name=cat, box_visible=True, meanline_visible=True))

    fig.update_layout(title=f"{metric} by {category} for dataset",
                      yaxis_title=metric,
                      legend_title=category)

    # Show the figure
    fig.show()

In [23]:
plot_violin(df_fix_mu,'best_acc_te','method')

In [24]:
plot_violin(df_fix_mu_K2,'worst_grp_acc_te','method')

In [25]:
plot_violin(df_fix_mu_K12,'worst_grp_acc_te','method')

In [26]:
plot_violin(df_fix_mu,'worst_grp_acc_te','method')

In [27]:
plot_violin(df_fix_mu,'best_acc_te','name')

In [53]:
plot_violin(df_fix_mu_K12,'best_acc_te','name')

In [29]:
plot_violin(df_fix_mu,'worst_grp_acc_te','name')

In [30]:
plot_violin(df_fix_mu_K12,'best_acc_te','name')

In [54]:
df_1 = df_fix_mu[df_fix_mu['name'] == 'dogs']
df_1 = df_1[df_1['K'] == 12]
df_1 = df_1[df_1['method'] == 'erm']
df_1['best_acc_te']

2844    0.827342
2847    0.820479
2848    0.813617
2849    0.828649
2850    0.828105
Name: best_acc_te, dtype: float64

In [56]:
df_2 = df_fix_mu[df_fix_mu['name'] == 'medical-leaf']
df_2 = df_2[df_2['K'] == 12]
df_2 = df_2[df_2['method'] == 'erm']
df_2['best_acc_te']

2516    0.910802
2517    0.883642
2518    0.916049
2519    0.899691
2520    0.904630
Name: best_acc_te, dtype: float64

In [33]:
df_2['best_acc_te'].std()/(5)**.5 # We check the computation from previous error bar in graphs

0.013566206179819724