# I. Loading/Plot functions

In [23]:
# Short key -> dataset name. Iterating over DATASETS.values() yields the
# pseudo-dataset "ALL" first, followed by each individual dataset.
DATASETS = {
    "_": "ALL",
    "a": "dogs",
    "b": "medical-leaf",
    "c": "texture-dtd",
    "d": "birds",
    "e": "AWA",
    "f": "plt-net",
    "g": "resisc",
    "h": "plt-doc",
    "i": "airplanes",
}

In [24]:
import wandb
import pandas as pd
import plotly.graph_objects as go
import numpy as np

# Authenticate against Weights & Biases
wandb.login()

# Entity (user or team name) and project that hold the runs
entity_name = "aureliengauffre"
project_name = "SMA_all_2_best"

# Public API client, then every run of the project
api = wandb.Api()
runs = api.runs(f"{entity_name}/{project_name}")

# One record per run: its name, its summary metrics and its config (hyperparameters)
data = [
    {
        "name": run.name,
        "summary_metrics": run.summary._json_dict,
        "config": run.config,
    }
    for run in runs
]

# One row per run; the two dict-valued columns are flattened just below
df = pd.DataFrame(data)

# Expand the nested summary / config dicts into one column per key
df_summary = pd.json_normalize(df['summary_metrics'])
df_config = pd.json_normalize(df['config'])
df = pd.concat(
    [df.drop(columns=['summary_metrics', 'config']), df_summary, df_config],
    axis=1,
)

# Drop the leading run-name column by position; the remaining 'name' column
# is the one coming from the expanded summary/config columns.
df = df.iloc[:, 1:]

# Now we have a DataFrame `df` with all runs, their summary metrics, and hyperparameters
# print(df.head())  # Print the first few rows of the DataFrame


In [25]:
# Display the 'name' column of the runs table (the output below shows dataset names such as 'AWA').
df['name']

0                AWA
1                AWA
2                AWA
3                AWA
4                AWA
            ...     
1330       airplanes
1331       airplanes
1332    medical-leaf
1333       airplanes
1334       airplanes
Name: name, Length: 1335, dtype: object

In [26]:
def plot_graph_old(df, x_axis, y_axis, dataset_name, methods=None):
    """Plot the mean of ``y_axis`` against ``x_axis`` with one line per method.

    Args:
        df: DataFrame of runs; the columns 'name', 'method', ``x_axis`` and
            ``y_axis`` are read.
        x_axis: Column used for the x axis and as the grouping key.
        y_axis: Column averaged within each ``x_axis`` value.
        dataset_name: Value of the 'name' column to keep; ``None`` keeps all rows.
        methods: Methods to draw; ``None`` selects the seven default methods.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    method_list = (
        ['erm', 'jtt', 'suby', 'subg', 'rwy', 'rwg', 'dro']
        if methods is None
        else methods
    )
    # Methods drawn with a dashed line; all others are solid.
    dashed = ['rwy', 'rwg', 'dro']
    # Dashed methods reuse the colour of their solid counterpart.
    palette = {
        'erm': 'blue',
        'jtt': 'red',
        'suby': '#00CC96',  # green
        'subg': '#ff7f0e',  # orange
        'rwy': '#00CC96',
        'rwg': '#ff7f0e',
        'dro': '#DEA0FD',  # purple
    }

    # Without a dataset name every row is used, otherwise only that dataset's rows.
    subset = df if dataset_name is None else df[(df['name'] == dataset_name)]

    figure = go.Figure()
    for method_name in method_list:
        per_method = subset[subset['method'] == method_name]
        # Average y_axis over all runs sharing the same x_axis value.
        averaged = per_method.groupby(x_axis)[y_axis].mean().reset_index()
        figure.add_trace(
            go.Scatter(
                x=averaged[x_axis],
                y=averaged[y_axis],
                mode='lines+markers',
                name=method_name,
                line=dict(
                    dash='dash' if method_name in dashed else 'solid',
                    color=palette.get(method_name, 'grey'),  # grey for unknown methods
                ),
            )
        )

    figure.update_layout(
        title=f'{y_axis} vs {x_axis} , dataset={dataset_name}',
        xaxis_title=x_axis,
        yaxis_title=y_axis,
        legend_title='Method',
        width=800,   # pixels
        height=600,  # pixels
    )
    figure.show()

# Example usage: mean test group accuracy as a function of K on one dataset.
dataset_name = 'AWA' #73sports
y_axis = 'mean_grp_acc_te'
x_axis = 'K'
# Call plot_graph_old, the function defined in this cell. plot_graph is only
# defined in the next cell, so calling it here raises NameError on a fresh
# Restart & Run All.
plot_graph_old(df, x_axis, y_axis, dataset_name, methods=None)


In [27]:
import plotly.graph_objects as go


def normalize_df(df, metric='best_acc_te'):
    """Min-max normalize ``metric`` within each (name, K, mu, init_seed) group.

    Every value is mapped to ``(value - group_min) / (group_max - group_min)``,
    where the min and max are taken over the rows sharing the same 'name',
    'K', 'mu' and 'init_seed'.

    Args:
        df: DataFrame with the columns 'name', 'K', 'mu', 'init_seed' and
            ``metric``. It is not modified.
        metric: Name of the column to normalize.

    Returns:
        A copy of ``df`` whose ``metric`` column holds the normalized values.
        A group whose values are all equal has ``group_max == group_min`` and
        therefore yields NaN (0/0) for its rows.
    """
    group_keys = ['name', 'K', 'mu', 'init_seed']
    grouped = df.groupby(group_keys)[metric]

    # Use the string aliases rather than np.min / np.max: passing NumPy
    # callables to transform() is deprecated in recent pandas and emits a
    # FutureWarning, while 'min' / 'max' select the built-in group reductions.
    min_acc = grouped.transform('min')
    max_acc = grouped.transform('max')

    # Work on a copy so the caller's DataFrame keeps its raw values.
    result_df = df.copy()
    result_df[metric] = (df[metric] - min_acc) / (max_acc - min_acc)
    return result_df




def plot_graph(df, x_axis, y_axis, dataset_name, error_bars=None, normalize = False, methods=None,):
    """Plot the mean of ``y_axis`` against ``x_axis``, one line per method.

    Args:
        df: DataFrame of runs; the columns 'name', 'method', ``x_axis`` and
            ``y_axis`` are read.
        x_axis: Column used for the x axis and as the grouping key.
        y_axis: Column that is averaged within each ``x_axis`` value.
        dataset_name: Value of the 'name' column to keep; ``None`` or 'ALL'
            keeps every row.
        error_bars: When truthy, draw the standard error of the mean as
            vertical error bars.
        normalize: When truthy, pass the selected rows through
            ``normalize_df`` (with ``metric=y_axis``) before aggregating.
        methods: Methods to draw; ``None`` or 'ALL' selects the seven default
            methods.

    Returns:
        None. The dataset name is printed and the figure is displayed with
        ``fig.show()``.
    """
    print(dataset_name)

    use_default_methods = methods is None or methods == 'ALL'
    if use_default_methods:
        methods = ['erm', 'jtt', 'suby', 'subg', 'rwy', 'rwg', 'dro']
    # Methods drawn with a dashed line; all others are solid.
    dashed = ['rwy', 'rwg', 'dro']
    # Dashed methods reuse the colour of their solid counterpart.
    palette = {
        'erm': 'blue',
        'jtt': 'red',
        'suby': '#00CC96',  # green
        'subg': '#ff7f0e',  # orange
        'rwy': '#00CC96',
        'rwg': '#ff7f0e',
        'dro': '#DEA0FD',  # purple
    }

    keep_all_rows = dataset_name is None or dataset_name == 'ALL'
    subset = df if keep_all_rows else df[(df['name'] == dataset_name)]
    if normalize:
        subset = normalize_df(subset, metric=y_axis)

    figure = go.Figure()
    for method_name in methods:
        per_method = subset[subset['method'] == method_name]
        # Mean, standard deviation and number of runs for each x_axis value.
        summary = per_method.groupby(x_axis)[y_axis].agg(['mean', 'std', 'count']).reset_index()
        # Standard error of the mean.
        summary['sem'] = summary['std'] / np.sqrt(summary['count'])

        trace_options = dict(
            x=summary[x_axis],
            y=summary['mean'],
            mode='lines+markers',
            name=method_name,
            line=dict(
                dash='dash' if method_name in dashed else 'solid',
                color=palette.get(method_name, 'grey'),  # grey for unknown methods
            ),
        )
        # The error bars are only attached on request.
        if error_bars:
            trace_options['error_y'] = dict(type='data', array=summary['sem'], visible=True)
        figure.add_trace(go.Scatter(**trace_options))

    figure.update_layout(
        title=f'{y_axis} vs {x_axis}, dataset={dataset_name}',
        xaxis_title=x_axis,
        yaxis_title=y_axis,
        legend_title='Method',
        width=800,   # pixels
        height=800,  # pixels
    )
    figure.show()



# II. Analysis : K

Here are the **plot parameters** to be played with :

In [56]:
# NOTE(review): df_fix_mu is only assigned in a later cell (the `df[df['mu']==.1]` filter),
# so this display cell fails on a fresh Restart & Run All unless that cell has run first.
df_fix_mu


Unnamed: 0,lr,_step,_timestamp,mean_grp_acc_te,mean_grp_acc_va,worst_grp_acc_te,best_mean_group_acc_va,best_acc_te,minor_grp_acc_va,relative_grp_acc_te,...,split_seed,config_name,hparams_seed,weight_decay,wandb_project,slurm_partition,n_eval_init_seed,slurm_output_dir,num_hparams_seeds,eval_every_n_epochs
0,,,,,,,,,,,...,0,configAemu1-3.yaml,0,0.001,SMA_all_2,,5,slurm_outputs,20,4
1,0.000570,19.0,1.713254e+09,93.665728,93.097147,78.553616,0.930971,0.933728,92.900104,0.983747,...,0,configAemu1-3.yaml,0,0.001,SMA_all_2,,5,slurm_outputs,20,4
2,0.000570,19.0,1.713253e+09,92.978616,92.615149,77.022654,0.933772,0.934517,92.362559,0.981919,...,0,configAemu1-3.yaml,0,0.001,SMA_all_2,,5,slurm_outputs,20,4
3,0.000570,19.0,1.713253e+09,93.593806,93.167243,80.906149,0.931672,0.935158,92.905043,0.982699,...,0,configAemu1-3.yaml,0,0.001,SMA_all_2,,5,slurm_outputs,20,4
4,0.000570,19.0,1.713252e+09,93.480665,93.018381,77.669903,0.930184,0.933974,92.740018,0.980630,...,0,configAemu1-3.yaml,0,0.001,SMA_all_2,,5,slurm_outputs,20,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1330,0.001118,19.0,1.712853e+09,76.461039,77.933074,57.142857,0.801916,0.800595,64.381786,0.757967,...,0,configAimu1-2.yaml,0,0.001,SMA_all_2,,5,slurm_outputs,20,4
1331,0.002647,19.0,1.712853e+09,81.843157,82.533220,50.000000,0.825332,0.812500,65.066440,0.641372,...,0,configAimu1-1.yaml,0,0.001,SMA_all_2,,5,slurm_outputs,20,4
1332,0.002283,19.0,1.712853e+09,77.424483,82.954545,52.941176,0.830492,0.833333,65.909091,0.548490,...,0,configAbmu1-0.yaml,0,0.001,SMA_all_2,,5,slurm_outputs,20,4
1333,0.001118,19.0,1.712853e+09,79.270729,79.253768,59.090909,0.807871,0.833333,75.923675,0.939505,...,0,configAimu1-2.yaml,0,0.001,SMA_all_2,,5,slurm_outputs,20,4


In [53]:
ERROR_BARS = True # Whether to plot the error bars
NORMALIZE = False # Forwarded as `normalize` to plot_graph by the best_acc_te cell of this section

To analyse the impact of K, the **value of mu is fixed**:

In [54]:
# Keep only the runs whose mu equals 0.1, so that K is the quantity varied in this section's plots.
df_fix_mu = df[df['mu']==.1] # Currently mu in [.05,.1,.2,.4]

## a. best_acc_te

In [55]:
# One figure per entry of DATASETS ('ALL' first, then each dataset):
# best_acc_te against K on the fixed-mu subset, with the section's ERROR_BARS / NORMALIZE settings.
x_axis = 'K'
y_axis = 'best_acc_te' #'mean_grp_acc_te'
for dataset_name in DATASETS.values():
    plot_graph(df_fix_mu, x_axis, y_axis, dataset_name,error_bars=ERROR_BARS, normalize=NORMALIZE)

ALL


dogs


medical-leaf


texture-dtd


birds


AWA


plt-net


resisc


plt-doc


airplanes


## b. worst_acc_te

In [49]:
# One figure per entry of DATASETS: worst_grp_acc_te against K on the fixed-mu subset.
# `normalize` is not passed here, so plot_graph uses its default (False).
x_axis = 'K'
y_axis = 'worst_grp_acc_te' #'mean_grp_acc_te'
for dataset_name in DATASETS.values():
    plot_graph(df_fix_mu, x_axis, y_axis, dataset_name,error_bars=ERROR_BARS, methods=None)

ALL


dogs


medical-leaf


texture-dtd


birds


AWA


plt-net


resisc


plt-doc


airplanes


## c. relative_acc

In [50]:
# One figure per entry of DATASETS: relative_grp_acc_te against K on the fixed-mu subset.
# `normalize` is not passed here, so plot_graph uses its default (False).
x_axis = 'K'
y_axis = 'relative_grp_acc_te' #'mean_grp_acc_te'
for dataset_name in DATASETS.values():
    plot_graph(df_fix_mu, x_axis, y_axis, dataset_name,error_bars=ERROR_BARS, methods=None)

ALL


dogs


medical-leaf


texture-dtd


birds


AWA


plt-net


resisc


plt-doc


airplanes


## d. minor_group_acc

In [51]:
# One figure per entry of DATASETS: minor_grp_acc_te against K on the fixed-mu subset.
# `normalize` is not passed here, so plot_graph uses its default (False).
# NOTE(review): the truncated DataFrame preview only shows 'minor_grp_acc_va';
# confirm that a 'minor_grp_acc_te' column exists among the hidden columns.
x_axis = 'K'
y_axis = 'minor_grp_acc_te' #'mean_grp_acc_te'
for dataset_name in DATASETS.values():
    plot_graph(df_fix_mu, x_axis, y_axis, dataset_name,error_bars=ERROR_BARS, methods=None)

ALL


dogs


medical-leaf


texture-dtd


birds


AWA


plt-net


resisc


plt-doc


airplanes


# III. Analysis : mu

Here are the **plot parameters** to be played with :

In [52]:
ERROR_BARS = True # Whether to plot the error bars
NORMALIZE = True # Forwarded as `normalize` to plot_graph by the best_acc_te cell of this section

To analyse the impact of mu, the **value of K is fixed**:

In [40]:
# Keep only the runs whose K equals 4, so that mu is the quantity varied in this section's plots.
df_fix_K = df[df['K']==4] # Currently K in [2,4,8,12]


## a. best_acc_te

In [41]:
# One figure per entry of DATASETS ('ALL' first, then each dataset):
# best_acc_te against mu on the fixed-K subset, with the section's ERROR_BARS / NORMALIZE settings.
x_axis = 'mu'
y_axis = 'best_acc_te' #'mean_grp_acc_te'
for dataset_name in DATASETS.values():
    plot_graph(df_fix_K, x_axis, y_axis, dataset_name,error_bars=ERROR_BARS, normalize=NORMALIZE)

ALL


dogs


medical-leaf


texture-dtd


birds


AWA


plt-net


resisc


plt-doc


airplanes


## b. worst_acc_te

In [42]:
# One figure per entry of DATASETS: worst_grp_acc_te against mu on the fixed-K subset.
# `normalize` is not passed here, so plot_graph uses its default (False).
x_axis = 'mu'
y_axis = 'worst_grp_acc_te' #'mean_grp_acc_te'
for dataset_name in DATASETS.values():
    plot_graph(df_fix_K, x_axis, y_axis, dataset_name,error_bars=ERROR_BARS, methods=None)

ALL


dogs


medical-leaf


texture-dtd


birds


AWA


plt-net


resisc


plt-doc


airplanes


## c. relative_acc

In [22]:
# One figure per entry of DATASETS: relative_grp_acc_te against mu on the fixed-K subset.
# `normalize` is not passed here, so plot_graph uses its default (False).
x_axis = 'mu'
y_axis = 'relative_grp_acc_te' #'mean_grp_acc_te'
for dataset_name in DATASETS.values():
    plot_graph(df_fix_K, x_axis, y_axis, dataset_name,error_bars=ERROR_BARS, methods=None)

ALL


dogs


medical-leaf


texture-dtd


birds


AWA


plt-net


resisc


plt-doc


airplanes


In [61]:
IV. Examples 

Unnamed: 0,K,mean,std,count,sem
0,2,0.844723,0.11193,295,0.006517
1,4,0.830381,0.140285,244,0.008981
2,8,0.795595,0.143057,217,0.009711
3,12,0.883456,0.023724,27,0.004566
