
# MoE Metrics Data Analysis Notebook

How to run this notebook:
- Select the dataset used in the experiments you want to analyze in cell 0 (the cell after this markdown cell).
- Optionally add tags for inclusion or exclusion
- Run all or individual cells.
- If you already created a csv export you can skip the first cell in section 'Expert Activation Heatmaps' (takes a long time to execute)

Dependencies:
- aimStack
- SNS (Seaborn)
- Matplotlib
- PyPlot
- Numpy
- Pandas
- SciPy
- rliable

In [None]:
"""
Select the dataset for which to run analysis
"""
from typing import Literal

# Dataset whose experiments are analysed: mnist, cifar10 or cinic10.
dataset: Literal['mnist', 'cifar10', 'cinic10'] = 'cifar10'

# Tag filters for experiments; an empty list switches the respective filter off.
include_tags = list()  # tags of experiments to include (empty = include all)
exclude_tags = list()  # tags of experiments to exclude (empty = exclude none)

# Number of bootstrap repetitions for confidence intervals.
bootstrap_reps = 1_000
# Some calculations that compare different architectures need an equal amount of runs
# for each architecture.
last_runs_to_consider = 10

# The (relative) path to the aim repository.
repo_path = '../..'

In [None]:
"""
Common imports and definitions
This cell contains common imports, definitions and functions for subsequent cells
"""
import seaborn as sns
import pandas as pd
from aim import Repo
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
import warnings
from collections import defaultdict
# Use Python's 'default' warning action for all warnings raised from here on.
warnings.filterwarnings('default')

# Seeded legacy NumPy generator.
# NOTE(review): not referenced by any cell visible in this part of the notebook — confirm it is used further down.
RAND_STATE = np.random.RandomState(42)

def common():
    """Open the aim repository and build the display-name lookup tables.

    Returns:
        tuple: ``(repo, models_friendly_names, column_friendly_names)`` where ``repo``
        is the ``aim.Repo`` opened at the module-level ``repo_path`` and the two dicts
        map internal model tags / column keys to the labels used in tables.
    """
    from aim import Repo

    # Column key -> table header text.
    column_labels = dict([
        ("moe_type", "MoE Type"),
        ("step", "Step"),
        ("epoch", "Epoch"),
        ("accuracy", "Accuracy"),
        ("profiling_memory_usage", "Memory Usage"),
        ("profiling_total_cpu_time", "Total CPU Time"),
        ("loss", "Loss"),
        ("macs", "MACs"),
        ("expert_activations", "Expert Activations"),
    ])

    # Model tag -> friendly model name (table display only). Insertion order matters:
    # the colour palette below this function is zipped against these keys.
    model_labels = dict([
        ('softmoe', 'SoftMoE'),
        ('noisytopk', 'Top-k'),
        ('expert_choice', 'Expert Choice'),
        ('ffn', 'FFN'),
        ('roundrobin', 'Round Robin'),
        ('expert_segmentation', 'Expert Segmentation'),
        ('switch_transformer', 'Switch Transformer'),
    ])

    return Repo(repo_path), model_labels, column_labels

repo, models_friendly_names, column_friendly_names = common()

# Colour-blind friendly palette; COLORS_IDX selects one palette slot per architecture,
# in the key order of models_friendly_names.
COLORS = sns.color_palette('colorblind')
COLORS_IDX = [0, 3, 4, 2, 1, 7, 8]
MOE_ARCH_COLOR_DICT = {
    arch: COLORS[palette_idx]
    for arch, palette_idx in zip(models_friendly_names, COLORS_IDX)
}
sns.set_style("white")

# Marker that subsequent cells check to verify this cell has been executed.
COMMON_CELLS = True

The next cell queries the aim repo for metrics of interest, then builds two dataframes:
- accuracy_df : used in subsequent analysis steps
- accuracy_df_styled : based on accuracy_df and intended for direct export

In [None]:
"""
Mean, Max and Min metrics data query
"""
# Later statements rely on names created by the first two code cells (configuration and
# common definitions), so stop with a readable message when those cells were skipped.
try:
    COMMON_CELLS
except NameError as exc:
    raise RuntimeError(
        "Run the configuration cell and the 'Common imports and definitions' cell before this one."
    ) from exc

seq = repo.query_metrics(f"metric.name in ['loss', 'accuracy', 'profiling_total_cpu_time', 'profiling_memory_usage'] \
                         and metric.context.subset == 'val' or ((metric.name == 'macs' or metric.name.startswith('expert_')) and metric.context.subset == 'train') \
                         ")

all_rows = []
df_activations = defaultdict(list)
wanted_tags = set(include_tags)
unwanted_tags = set(exclude_tags)

for metric in seq:
    run_tags = set(metric.run.props.tags)
    if dataset not in run_tags or unwanted_tags & run_tags:
        continue
    # An empty include_tags list keeps every run; otherwise one matching tag is enough.
    if wanted_tags and not wanted_tags & run_tags:
        continue

    # MoE type = first known architecture key found in the run tags ("" when none matches,
    # which downstream cells filter out).
    model = next((name for name in models_friendly_names if name in run_tags), "")
    run_hash = metric.run.hash

    for x in metric.data:
        step = x[0]
        data = x[1]
        epoch = data[1]
        value = data[0]

        if metric.name.startswith('expert_'):
            # Gather the values of all 'expert_*' metrics of one run/step/epoch in one list.
            df_activations[run_hash, step, epoch, model].append(value)
        else:
            # NOTE(review): every non-expert metric is labelled "val" here, including 'macs',
            # which the query selects from the 'train' subset — confirm this is intended.
            all_rows.append({
                "hash": run_hash,
                "moe_type": model,
                "context": "val",
                "step": step,
                "epoch": epoch,
                "metric_name": metric.name,
                "value": value
            })

for (run_hash, step, epoch, model), values in df_activations.items():
    all_rows.append({
        "hash": run_hash,
        "moe_type": model,
        "context": "train",
        "step": step,
        "epoch": epoch,
        "metric_name": "expert_activations",
        "value": values
    })

if not all_rows:
    raise RuntimeError(f"No metrics found for dataset '{dataset}' with the current tag filters.")

df = pd.DataFrame(all_rows)

# One row per run/context/step/epoch with one column per metric name.
accuracy_df = df.pivot_table(
    index=['hash', 'moe_type', 'context', 'step', 'epoch'],
    columns='metric_name',
    values='value',
    aggfunc='first'
).reset_index()

accuracy_df.columns.name = None

print(f"Shape: {accuracy_df.shape}")
print(f"Columns: {list(accuracy_df.columns)}")

column_formats = {
    "hash": "{:s}",
    "moe_type": lambda x: models_friendly_names.get(x, x),
    "context": "{:s}",
    "step": "{:.0f}",
    "epoch": "{:.0f}",
    "accuracy": "{:.3f}",
    "profiling_memory_usage": "{:.0f}",
    "profiling_total_cpu_time": "{:.2f}",
    "macs": "{:.0f}",
    # Activations are lists (or NaN), which a float format spec cannot render.
    "expert_activations": str
}
hidden_columns = ['hash', 'context']
# Labels are derived from the actual column order so each header matches its column.
visible_labels = [
    column_friendly_names.get(col, col)
    for col in accuracy_df.columns
    if col not in hidden_columns
]
accuracy_df_styled = (
    accuracy_df.style
    .format(column_formats)
    .hide(axis=0)
    .hide(subset=hidden_columns, axis=1)
    .relabel_index(visible_labels, axis=1)
)
accuracy_df

## Group metrics and calculate mean/min/max
The next two cells group the metrics a) by MoE and epoch and b) only by MoE for exporting a summary table.
### Group by MoE and Epoch

In [None]:
"""
Mean, Max and Min metrics table generation
"""
# Later statements rely on names created by the first two code cells (configuration and
# common definitions), so stop with a readable message when those cells were skipped.
try:
    COMMON_CELLS
except NameError as exc:
    raise RuntimeError(
        "Run the configuration cell and the 'Common imports and definitions' cell before this one."
    ) from exc

# We need dummy entries because styler.relabel doesn't properly handle levels in multi column indexes
column_friendly_names_2 = {
    "moe_type": "MoE Type",
    "epoch": "Epoch",
    "accuracy": "Accuracy",
    "profiling_memory_usage": "Memory Usage",
    "profiling_total_cpu_time": "Total CPU Time",
    "placeholder_1": "Memory Usage",
    "placeholder_2": "placeholder",
    "placeholder_3": "placeholder",
    "placeholder_4": "Total CPU Time",
    "placeholder_5": "placeholder",
    "placeholder_6": "placeholder",
    "macs": "MACs",
    "loss": "Loss",
    "placeholder_7": "placeholder",
    "placeholder_8": "placeholder",
}

# Per architecture and epoch: mean/min/max of each metric (only the max for MACs).
accuracy_grouped = accuracy_df.groupby(["moe_type", "epoch"]).agg({
    "accuracy": ["mean", "min", "max"],
    "profiling_memory_usage": ["mean", "min", "max"],
    "profiling_total_cpu_time": ["mean", "min", "max"],
    "macs": ["max"],
    "loss": ["mean", "min", "max"],
}).reset_index()

# Keys are the (metric, statistic) column tuples produced by the aggregation above.
formatter = {
    ("moe_type", ""): lambda x: models_friendly_names.get(x, x),
    ("epoch", ""): "{:.0f}",
    ("accuracy", "mean"): "{:.3f}",
    ("accuracy", "min"): "{:.3f}",
    ("accuracy", "max"): "{:.3f}",
    ("profiling_memory_usage", "mean"): "{:.0f}",
    ("profiling_memory_usage", "min"): "{:.0f}",
    ("profiling_memory_usage", "max"): "{:.0f}",
    ("profiling_total_cpu_time", "mean"): "{:.2f}",
    ("profiling_total_cpu_time", "min"): "{:.2f}",
    ("profiling_total_cpu_time", "max"): "{:.2f}",
    # MACs are aggregated with "max" only, so that is the column that exists.
    ("macs", "max"): "{:.0f}",
    ("loss", "mean"): "{:.3f}",
    ("loss", "min"): "{:.3f}",
    ("loss", "max"): "{:.3f}"
}
accuracy_grouped_styled = (
    accuracy_grouped.style
    .format(formatter)
    .hide(axis=0)
    .relabel_index(labels=list(column_friendly_names_2.values()), axis=1, level=0)
)
# to_latex writes to the given path and returns None, so there is nothing to print.
accuracy_grouped_styled.to_latex(f"../../report/appendices/model_accuracy_progression_table_{dataset}.tex", environment="longtblr")
accuracy_grouped

### Group by MoE only

In [None]:
"""
Group by MoE architecture
"""
# Later statements rely on names created by the first two code cells (configuration and
# common definitions), so stop with a readable message when those cells were skipped.
try:
    COMMON_CELLS
except NameError as exc:
    raise RuntimeError(
        "Run the configuration cell and the 'Common imports and definitions' cell before this one."
    ) from exc

# Only group by moe type for the summary table, so the "epoch" label is dropped.
column_friendly_names_3 = {
    key: label
    for key, label in column_friendly_names_2.items()
    if key != "epoch"
}

# Same statistics as the per-epoch table, aggregated over all epochs of an architecture.
accuracy_grouped_2 = accuracy_df.groupby("moe_type").agg({
    "accuracy": ["mean", "min", "max"],
    "profiling_memory_usage": ["mean", "min", "max"],
    "profiling_total_cpu_time": ["mean", "min", "max"],
    "macs": ["max"],
    "loss": ["mean", "min", "max"],
}).reset_index()
accuracy_grouped_2_styled = (
    accuracy_grouped_2.style
    .format(formatter)
    .hide(axis=0)
    .relabel_index(labels=list(column_friendly_names_3.values()), axis=1, level=0)
)
accuracy_grouped_2_styled

## Accuracy and Loss Charts

This cell plots and saves accuracy and loss charts. For this, the min, max and mean (based on the exponential moving average) are picked/calculated.

In [None]:
"""
Accuracy and Loss Charts
References: 
- https://stackoverflow.com/questions/76317946/calculate-exponential-moving-average-using-pandas-dataframe
- https://www.geeksforgeeks.org/pandas/how-to-calculate-moving-average-in-a-pandas-dataframe/
"""
# Later statements rely on names created by the first two code cells (configuration and
# common definitions), so stop with a readable message when those cells were skipped.
try:
    COMMON_CELLS
except NameError as exc:
    raise RuntimeError(
        "Run the configuration cell and the 'Common imports and definitions' cell before this one."
    ) from exc

def exponential_moving_average(data, alpha=0.3):
    """Smooth a series with an exponential moving average.

    Missing values are dropped first; the first remaining value seeds the average and
    every later value is blended in as ``alpha * value + (1 - alpha) * previous``.

    Args:
        data: pandas Series of numeric values (may contain NaN).
        alpha: Weight of the newest observation, between 0 and 1.

    Returns:
        pandas Series of smoothed values indexed like the non-missing entries of
        ``data``. An empty input is returned unchanged; an all-NaN input yields an
        empty Series.
    """
    if len(data) == 0:
        return data
    cleaned = data.dropna()
    if cleaned.empty:
        # Nothing left to smooth (every entry was NaN).
        return cleaned

    values = cleaned.tolist()
    smoothed = [values[0]]
    for value in values[1:]:
        smoothed.append(alpha * value + (1 - alpha) * smoothed[-1])
    return pd.Series(smoothed, index=cleaned.index)

# Architectures to plot; runs without a recognised architecture tag (empty name) are skipped.
models = [model for model in accuracy_grouped['moe_type'].unique() if model]
for model in models:
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    individual_runs = accuracy_df[accuracy_df['moe_type'] == model]

    # Background: one light-gray smoothed accuracy curve per individual run.
    for hash_val in individual_runs['hash'].unique():
        run_data = individual_runs[individual_runs['hash'] == hash_val].sort_values('epoch')

        # Runs without any accuracy value have nothing to smooth or draw.
        if run_data['accuracy'].notna().any():
            smoothed_accuracy = exponential_moving_average(run_data['accuracy'], alpha=0.3)
            sns.lineplot(x=run_data['epoch'], y=smoothed_accuracy, color='lightgray', linewidth=2.0, ax=ax)

    subset = accuracy_grouped[accuracy_grouped['moe_type'] == model].sort_values('epoch')

    if len(subset) > 0:
        # Smooth aggregate data
        smoothed_mean = exponential_moving_average(subset['accuracy']['mean'], alpha=0.3)
        smoothed_min = exponential_moving_average(subset['accuracy']['min'], alpha=0.3)
        smoothed_max = exponential_moving_average(subset['accuracy']['max'], alpha=0.3)
        loss_mean = exponential_moving_average(subset['loss']['mean'], alpha=0.3)

        # Accuracy Mean line
        sns.lineplot(x=subset['epoch'], y=smoothed_mean, label='Accuracy Mean', linewidth=3, ax=ax)

        # Accuracy Min/max lines
        sns.lineplot(x=subset['epoch'], y=smoothed_min, color='gray', linestyle='--', alpha=0.8, linewidth=2.0, label='Accuracy Min', ax=ax)
        sns.lineplot(x=subset['epoch'], y=smoothed_max, color='gray', linestyle=':', alpha=0.8, linewidth=2.0, label='Accuracy Max', ax=ax)

        # Loss Mean line
        sns.lineplot(x=subset['epoch'], y=loss_mean, label='Loss Mean', linewidth=3, ax=ax)

    ax.set_xlabel('Epoch', fontsize=12)
    # NOTE(review): the same axis also carries the loss curve (hence the 0-2.5 range) while
    # the label only says 'Accuracy' — confirm the label is intended.
    ax.set_ylabel('Accuracy', fontsize=12)
    ax.set_ylim(0.0, 2.5)

    ax.legend(fontsize=12, loc='lower right', fancybox=True)
    sns.despine(ax=ax)

    fig.tight_layout()
    safe_filename = model.replace(' ', '_').lower()
    fig.savefig(f'../../report/charts/accuracy_{safe_filename}_{dataset}.png', dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()
    # Close this figure explicitly so figures do not pile up across the loop.
    plt.close(fig)

## Hardware Usage Charts

The next two cells plot and save the hardware usage charts.

### Memory Usage

In [None]:
"""
Plot Memory Usage
References:
- https://seaborn.pydata.org/generated/seaborn.lineplot.html
- https://seaborn.pydata.org/examples/timeseries_facets.html
- https://stackoverflow.com/questions/62667158/how-do-i-increase-the-line-thickness-for-sns-lineplot
"""
# Later statements rely on names created by the first two code cells (configuration and
# common definitions), so stop with a readable message when those cells were skipped.
try:
    COMMON_CELLS
except NameError as exc:
    raise RuntimeError(
        "Run the configuration cell and the 'Common imports and definitions' cell before this one."
    ) from exc

# Long-format rows (one per run and epoch), grouped in the order of `models`.
memory_data = [
    {
        'epoch': row['epoch'],
        'memory_usage': row['profiling_memory_usage'],
        'moe_type': model,
        'friendly_name': models_friendly_names.get(model, model),
        'hash': row['hash']
    }
    for model in models
    for _, row in accuracy_df[accuracy_df['moe_type'] == model].iterrows()
]

memory_df = pd.DataFrame(memory_data)

# Calculate mean for each model/epoch combination
mean_data = memory_df.groupby(['friendly_name', 'epoch'])['memory_usage'].mean().reset_index()
mean_data['line_type'] = 'Mean'

g = sns.FacetGrid(memory_df, col='friendly_name', col_wrap=3, height=4, aspect=1.2, sharey=False)

# Plot individual runs
g.map_dataframe(sns.lineplot, x='epoch', y='memory_usage', units='hash', estimator=None, alpha=0.6, linewidth=1.5, marker='o', markersize=3)

# Overlay mean lines to make it stand out
def add_mean_line(data, **kwargs):
    """Draw the per-epoch mean memory usage of one facet in its architecture colour."""
    color = MOE_ARCH_COLOR_DICT.get(data.moe_type.iloc[0], 'gray')
    mean_by_epoch = data.groupby('epoch')['memory_usage'].mean()
    sns.lineplot(mean_by_epoch, label='Mean', linewidth=3, color=color)

g.map_dataframe(add_mean_line)

# Set individual run colors to gray
for ax in g.axes.flatten():
    for line in ax.lines[:-1]:  # Set line colors for all lines except the mean line
        line.set_color('gray')

g.set_axis_labels('Epoch', 'Memory (MB)')
g.set_titles('{col_name}')

axes = g.axes.flatten()
legend_lines = axes[0].lines
if legend_lines:
    # Pair the labels with explicit handles: the first line is an individual run and the
    # last line is the mean overlay. A labels-only legend call would attach 'Mean' to the
    # second individual run instead.
    axes[0].legend(handles=[legend_lines[0], legend_lines[-1]], labels=['Individual Runs', 'Mean'], loc='center right')

for ax in axes:
    ax.grid(True, alpha=0.3)

sns.despine(left=True, bottom=True)
plt.tight_layout()
plt.savefig(f'../../report/charts/memory_usage_{dataset}.png', dpi=300, bbox_inches='tight', facecolor='white')
plt.show()

### CPU Time

In [None]:
"""
Plot CPU Time
References:
- https://seaborn.pydata.org/generated/seaborn.lineplot.html
- https://seaborn.pydata.org/examples/timeseries_facets.html
- https://stackoverflow.com/questions/62667158/how-do-i-increase-the-line-thickness-for-sns-lineplot
"""
# Later statements rely on names created by the first two code cells (configuration and
# common definitions), so stop with a readable message when those cells were skipped.
try:
    COMMON_CELLS
except NameError as exc:
    raise RuntimeError(
        "Run the configuration cell and the 'Common imports and definitions' cell before this one."
    ) from exc

# Long-format rows (one per run and epoch), grouped in the order of `models`.
cpu_data = [
    {
        'epoch': row['epoch'],
        'cpu_usage': row['profiling_total_cpu_time'],
        'moe_type': model,
        'friendly_name': models_friendly_names.get(model, model),
        'hash': row['hash']
    }
    for model in models
    for _, row in accuracy_df[accuracy_df['moe_type'] == model].iterrows()
]

cpu_df = pd.DataFrame(cpu_data)

# Calculate mean
cpu_mean_data = cpu_df.groupby(['friendly_name', 'epoch'])['cpu_usage'].mean().reset_index()
cpu_mean_data['line_type'] = 'Mean'

g = sns.FacetGrid(cpu_df, col='friendly_name', col_wrap=3, height=4, aspect=1.2, sharey=False)

# Plot individual runs
g.map_dataframe(sns.lineplot, x='epoch', y='cpu_usage', units='hash', estimator=None, alpha=0.6, linewidth=1.5, marker='o', markersize=3)

# Overlay mean lines to make it stand out
def add_cpu_mean_line(data, **kwargs):
    """Draw the per-epoch mean CPU time of one facet in its architecture colour."""
    color = MOE_ARCH_COLOR_DICT.get(data.moe_type.iloc[0], 'gray')
    mean_by_epoch = data.groupby('epoch')['cpu_usage'].mean()
    sns.lineplot(mean_by_epoch, label='Mean', linewidth=3, color=color)

g.map_dataframe(add_cpu_mean_line)

# Set individual run colors to gray
for ax in g.axes.flatten():
    for line in ax.lines[:-1]:  # Set line colors for all lines except the mean line
        line.set_color('gray')

g.set_axis_labels('Epoch', 'CPU Time (sec.)')
g.set_titles('{col_name}')

axes = g.axes.flatten()
legend_lines = axes[0].lines
if legend_lines:
    # Pair the labels with explicit handles: the first line is an individual run and the
    # last line is the mean overlay. A labels-only legend call would attach 'Mean' to the
    # second individual run instead.
    axes[0].legend(handles=[legend_lines[0], legend_lines[-1]], labels=['Individual Runs', 'Mean'], loc='center right')

for ax in axes:
    ax.grid(True, alpha=0.3)

sns.despine(left=True, bottom=True)
plt.tight_layout()
plt.savefig(f'../../report/charts/cpu_usage_{dataset}.png', dpi=300, bbox_inches='tight', facecolor='white')
plt.show()

## Expert Activation Heatmaps

The following cells plot and save heatmaps for each model's expert activations. For this, a separate query against the aim repo is run. Since the query can take a long time to execute, the results are saved in the working directory as a csv file for subsequent use and experimentation.

The expert segmentation architecture's data is plotted separately because it has twice as many experts, hence a single plot for all architectures would be harder to read.

In [None]:
"""
Query for Expert Activations
Note: THIS IS A VERY SLOW OPERATION - Hence why the data is written to disk for subsequent usage
"""
# Later statements rely on names created by the first two code cells (configuration and
# common definitions), so stop with a readable message when those cells were skipped.
try:
    COMMON_CELLS
except NameError as exc:
    raise RuntimeError(
        "Run the configuration cell and the 'Common imports and definitions' cell before this one."
    ) from exc

from aim.storage.context import Context

wanted_tags = set(include_tags)
unwanted_tags = set(exclude_tags)

# Rows are collected in a list and turned into a frame once; growing a DataFrame with
# pd.concat inside the loop copies all previous rows on every iteration.
ac_rows = []
for t in repo.iter_runs():
    run_tags = t.props.tags
    tag_set = set(run_tags)
    # The export file is named after the dataset, so only keep that dataset's runs and
    # apply the same tag filters as the metrics query cell.
    if dataset not in tag_set or unwanted_tags & tag_set:
        continue
    if wanted_tags and not wanted_tags & tag_set:
        continue

    for metric in t.metrics():
        if metric.name.startswith('expert_'):
            ac_rows.append({
                'hash': t.hash,
                'tags': run_tags,
                'name': metric.name,
                'values': metric.values.values_list(),
                'context': metric.context.to_dict(),
            })

ac_df = pd.DataFrame(ac_rows, columns=['hash', 'tags', 'name', 'values', 'context'])
ac_df.to_csv(f'expert_activations_{dataset}.csv', index=False)


In [None]:
"""
Plot Expert Activations Heatmap
References:
- https://www.sciencedirect.com/topics/engineering/shannon-entropy
- https://stackoverflow.com/questions/49973537/shannons-entropy-on-an-array-containing-zeros
"""
# Later statements rely on names created by the first two code cells (configuration and
# common definitions), so stop with a readable message when those cells were skipped.
try:
    COMMON_CELLS
except NameError as exc:
    raise RuntimeError(
        "Run the configuration cell and the 'Common imports and definitions' cell before this one."
    ) from exc

import ast
from scipy.stats import entropy

# Calculate Shannon entropy for each model (column)
def shannon_entropy_normalized(activations):
    """Return the Shannon entropy of an activation distribution, normalised to [0, 1].

    Args:
        activations: Activation counts as a list, tuple, NumPy array or pandas Series.
            Scalars (for example a NaN placeholder) and ``None`` are treated as
            "no data".

    Returns:
        float: Entropy (base 2) of the positive entries divided by ``log2(n)``, where
        ``n`` is the number of positive entries. ``0.0`` when there is no data or at
        most one positive entry.
    """
    # Scalars and None carry no distribution.
    if np.ndim(activations) == 0:
        return 0.0

    values = np.asarray(activations, dtype=float).ravel()
    # Remove zero values as they do not contribute to entropy (NaN entries are dropped
    # as well because NaN > 0 is False).
    non_zero_values = values[values > 0]
    if len(non_zero_values) <= 1:
        return 0.0

    probabilities = non_zero_values / non_zero_values.sum()
    shannon_ent = entropy(probabilities, base=2)

    # Normalize by max possible entropy (log2(n) where n is number of non-zero experts)
    # See report for formula and references
    max_entropy = np.log2(len(non_zero_values))
    return float(shannon_ent / max_entropy)

# The export cell writes the csv without an index column, so read it the same way
# ('hash' stays a regular column).
df = pd.read_csv(f'expert_activations_{dataset}.csv')

# tags/context were written as Python literals; parse them without executing code.
tags_parsed = df['tags'].apply(ast.literal_eval)
context_parsed = df['context'].apply(ast.literal_eval)

is_val = context_parsed.apply(lambda ctx: ctx.get('subset') == 'val').astype(bool)
# Keep only runs of the selected dataset that carry none of the excluded tags, matching
# the filter of the metrics query cell.
is_selected_run = tags_parsed.apply(
    lambda tags: dataset in tags and not set(exclude_tags) & set(tags)
).astype(bool)
df_eval = df[is_val & is_selected_run].copy()

# Add model name from tags (None when no known architecture tag is present, so that the
# dropna below actually removes such runs)
df_eval['model'] = tags_parsed.loc[df_eval.index].apply(
    lambda tags: next((tag for tag in tags if tag in models_friendly_names), None)
)
df_eval = df_eval.dropna(subset=['model'])
df_eval['values'] = df_eval['values'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
# One row per recorded value; empty value lists explode to NaN and are dropped.
df_eval = df_eval.explode('values').dropna(subset=['values'])
df_eval['values'] = df_eval['values'].astype(int)

# Maximum value per run and metric, then the mean of those maxima across runs.
df_grouped_before_aggregate = df_eval.groupby(['model', 'name', 'hash'])['values'].max().reset_index()
df_grouped = df_grouped_before_aggregate.groupby(['model', 'name'])['values'].mean().reset_index()

# Pivot
df_pivot = df_grouped.pivot(index='name', columns='model', values='values')
df_pivot = df_pivot.fillna(0)
df_pivot = df_pivot.sort_index(axis=0, ascending=False)

sns.set_theme(style="whitegrid")
cmap = sns.color_palette("viridis", as_cmap=True)

for i in range(2): # Create a separate expert_segmentation heatmap. expert_segmentation experiments have 16, instead of 8 experts, which makes a single heatmap containing all experts harder to read.
    if i == 0:
        df_pivot_filtered = df_pivot.drop(columns=['expert_segmentation', 'ffn'], errors='ignore')
        chart_path = f'../../report/charts/expert_activations_heatmap_{dataset}.png'
    else:
        if 'expert_segmentation' not in df_pivot.columns:
            continue
        df_pivot_filtered = df_pivot.loc[:, ['expert_segmentation']]
        chart_path = f'../../report/charts/expert_segmentation_heatmap_{dataset}.png'

    # Drop experts that were never activated by any of the plotted models.
    df_pivot_filtered = df_pivot_filtered[df_pivot_filtered.sum(axis=1) > 0]
    if df_pivot_filtered.empty:
        continue

    # Entropy for each model; the column is handed over as a NumPy array.
    entropy_row = df_pivot_filtered.apply(lambda column: shannon_entropy_normalized(column.to_numpy()), axis=0)

    # Calculate relative weights
    df_proportions = df_pivot_filtered.div(df_pivot_filtered.sum(axis=0), axis=1)
    weight_data = [
        {
            'name': expert_idx,
            'model': model,
            'relative_weight': df_proportions.loc[expert_idx, model],
            'activation_count': df_pivot_filtered.loc[expert_idx, model],
            'model_entropy': entropy_row[model]
        }
        for expert_idx in df_pivot_filtered.index
        for model in df_pivot_filtered.columns
    ]

    df_weight_long = pd.DataFrame(weight_data) # Relplot accepts only long format data

    g = sns.relplot(data=df_weight_long,
                    x="model",
                    y="name",
                    hue="relative_weight",
                    size="relative_weight",
                    palette=cmap,
                    sizes=(100, 1000),
                    height=8,
                    kind="scatter",
                    aspect=1.2)

    # Remove the SNS legend and instead use the Matplotlib legend because labelspacing does not seem to work in SNS
    g.despine(left=True, bottom=True)
    g.set_axis_labels("Model", "Expert")
    g.legend.remove()
    plt.legend(labelspacing=1.8, loc='center right', fancybox=False, bbox_to_anchor=(1.35, 0.5), title="Relative Weight", title_fontsize='12', fontsize='11', frameon=False)
    plt.tight_layout()

    plt.savefig(chart_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

## Shannon Equiprobability

### Aggregate Shannon Equiprobability

This cell calculates the Shannon Equiprobability for each model after the final training run.

In [None]:
"""
Print Shannon Equiprobability Table
"""
# Later statements rely on names created by the first two code cells (configuration and
# common definitions), so stop with a readable message when those cells were skipped.
try:
    COMMON_CELLS
except NameError as exc:
    raise RuntimeError(
        "Run the configuration cell and the 'Common imports and definitions' cell before this one."
    ) from exc


print(f"Shannon Entropy {dataset}")
print("=" * 50)
for model in df_pivot.columns:
    activations = df_pivot[model]
    # Hand over a NumPy array: shannon_entropy_normalized only evaluates lists and arrays.
    ent = shannon_entropy_normalized(activations.to_numpy())
    total = activations.sum()
    active_experts = (activations > 0).sum()
    print(f"{model:20s}: {ent:.3f} (Total: {total:.0f}, Total activated experts: {active_experts})")

### Accuracy vs Final Shannon Equiprobability

This cell calculates and plots the equiprobability and relates it to the model's accuracy at various stages of training (see code comments for how to configure this).

In [None]:
"""
Plot Equiprobability and Accuracy 
Plotted for each run and epoch for models with weighted activations
References: 
- https://realpython.com/numpy-scipy-pandas-correlation-python/
"""

# Later statements rely on names created by the first two code cells (configuration and
# common definitions), so stop with a readable message when those cells were skipped.
try:
    COMMON_CELLS
except NameError as exc:
    raise RuntimeError(
        "Run the configuration cell and the 'Common imports and definitions' cell before this one."
    ) from exc

from scipy.stats import pearsonr, spearmanr

shannon_epoch = 20   # Use Shannon from this epoch
accuracy_epoch = 20  # Use accuracy from this epoch

table = accuracy_df[['hash', 'epoch', 'accuracy', 'expert_activations']].copy()
table['shannon'] = table['expert_activations'].apply(shannon_entropy_normalized)
table = table.drop(columns=['expert_activations'])
# Collapse the separate rows of one run/epoch into one (first non-missing value each).
table = table.groupby(['hash', 'epoch']).agg({'accuracy': 'first', 'shannon': 'first'}).reset_index()

# Entropy values of exactly 0 or 1 are excluded.
table_shannon = table.loc[
    (table['epoch'] == shannon_epoch) & (table['shannon'] != 0) & (table['shannon'] != 1),
    ['hash', 'shannon']
]
table_accuracy = table.loc[table['epoch'] == accuracy_epoch, ['hash', 'accuracy']]

# Pair both measurements per run. Stacking the two selections on top of each other would
# list every run twice when both epochs are equal and thereby distort the p-values.
table = table_shannon.merge(table_accuracy, on='hash').dropna(subset=['shannon', 'accuracy'])

def correlation_text(corr):
    """Describe the strength of a correlation coefficient in words.

    Args:
        corr: Correlation coefficient (anything ``float()`` accepts); the sign is ignored.

    Returns:
        str: "negligible" (< 0.1), "weak" (< 0.3), "moderate" (< 0.5), "strong" (< 0.7),
        "very strong" otherwise, or "undefined" when the coefficient is NaN.
    """
    abs_corr = abs(float(corr))
    # NaN compares False against every bound and would otherwise be called "very strong".
    if np.isnan(abs_corr):
        return "undefined"

    strength_bounds = (
        (0.1, "negligible"),
        (0.3, "weak"),
        (0.5, "moderate"),
        (0.7, "strong"),
    )
    for upper_bound, label in strength_bounds:
        if abs_corr < upper_bound:
            return label
    return "very strong"

# Value ranges of both variables that enter the correlation.
shannon_col = table['shannon']
accuracy_col = table['accuracy']
print(f"Shannon entropy range: {shannon_col.min():.4f} to {shannon_col.max():.4f}")
print(f"Accuracy range: {accuracy_col.min():.4f} to {accuracy_col.max():.4f}")

shannon_vals = shannon_col.astype(float).to_numpy()
accuracy_vals = accuracy_col.astype(float).to_numpy()

# Linear (Pearson) and rank-based (Spearman) correlation with their p-values.
pearson_corr, pearson_p = pearsonr(shannon_vals, accuracy_vals)
spearman_corr, spearman_p = spearmanr(shannon_vals, accuracy_vals)

print("\nCorrelation:")
print("=" * 50)
print(f"Pearson correlation: {pearson_corr:.4f} (p-value: {pearson_p:.4f})")
print(f"Spearman correlation: {spearman_corr:.4f} (p-value: {spearman_p:.4f})")

print(f"Pearson correlation strength: {correlation_text(pearson_corr)}")
print(f"Spearman correlation strength: {correlation_text(spearman_corr)}")

# Scatter plot of accuracy over entropy.
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(shannon_vals, accuracy_vals, alpha=0.6, s=30)
ax.set_xlabel('Shannon Entropy')
ax.set_ylabel('Accuracy')
ax.set_title(f'Pearson r = {pearson_corr:.3f}, Spearman r = {spearman_corr:.3f}')
ax.grid(True, alpha=0.3)

# A least-squares trend line is only drawn when the correlation is not negligible.
if abs(pearson_corr) > 0.1:
    z = np.polyfit(shannon_vals, accuracy_vals, 1)
    p = np.poly1d(z)
    ax.plot(shannon_vals, p(shannon_vals), "r--", alpha=0.8, linewidth=2, label=f'Trend line (r={pearson_corr:.3f})')
    ax.legend()

fig.tight_layout()
save_path = f'../../report/charts/accuracy_vs_shannon_entropy_{dataset}_2020.png'
fig.savefig(save_path, dpi=300, bbox_inches='tight', transparent=False)
plt.show()


## Stratified Bootstrap CI

This cell calculates the stratified bootstrap confidence intervals for every model's accuracies.

Note: The code in the next three cells is taken from Agarwal et al. (2022) with minimal adjustments.

In [None]:
"""
Prepare Score Distributions
References: 
- https://seaborn.pydata.org/generated/seaborn.boxplot.html
- https://seaborn.pydata.org/generated/seaborn.violinplot.html
- https://stat20.berkeley.edu/fall-2024/3-generalization/09-bootstrapping/notes.html
- https://arxiv.org/pdf/2108.13264#page=5.73
- https://colab.research.google.com/drive/1a0pSD-1tWhMmeJeeoyZM1A-HCW3yf1xR?usp=sharing#scrollTo=LnNa7O9xCuKA 
"""

# NOTE(review): is_empty is not used by any cell visible in this part of the notebook and
# makes networkx a hard dependency of this cell — confirm whether the import is needed.
from networkx import is_empty


# Later statements rely on names created by the first two code cells (configuration and
# common definitions), so stop with a readable message when those cells were skipped.
try:
    COMMON_CELLS
except NameError as exc:
    raise RuntimeError(
        "Run the configuration cell and the 'Common imports and definitions' cell before this one."
    ) from exc

expected_epochs = 20  # Only runs with exactly this many accuracy values are included

accuracy_val_end = accuracy_df[accuracy_df['epoch'].isin(range(21))]
scores = (
    accuracy_val_end[accuracy_val_end['context'] == 'val']
    .sort_values(by=['hash', 'epoch'])
    .drop(columns=['context', 'step', 'epoch', 'loss', 'profiling_total_cpu_time', 'profiling_memory_usage'])
    .dropna(subset=['accuracy'])
)

# Architecture -> float array with one row per complete run and one column per epoch.
score_data_dict = {}
for key in accuracy_val_end['moe_type'].unique():
    arch_scores = scores[scores['moe_type'] == key]
    nda = [
        run_scores['accuracy'].values  # runs x tasks
        for _, run_scores in arch_scores.groupby('hash')
        if len(run_scores) == expected_epochs  # Only include runs that completed all 20 epochs
    ]
    score_data_dict[key] = np.array(nda).astype(float)

score_data_dict



In [None]:
"""
Plot Interval Estimates
References: 
- https://seaborn.pydata.org/generated/seaborn.boxplot.html
- https://seaborn.pydata.org/generated/seaborn.violinplot.html
- https://stat20.berkeley.edu/fall-2024/3-generalization/09-bootstrapping/notes.html
- https://arxiv.org/pdf/2108.13264#page=5.73
- https://colab.research.google.com/drive/1a0pSD-1tWhMmeJeeoyZM1A-HCW3yf1xR?usp=sharing#scrollTo=LnNa7O9xCuKA 
"""

# Later statements rely on names created by the first two code cells (configuration and
# common definitions), so stop with a readable message when those cells were skipped.
try:
    COMMON_CELLS
except NameError as exc:
    raise RuntimeError(
        "Run the configuration cell and the 'Common imports and definitions' cell before this one."
    ) from exc

from rliable import library as rly
from rliable import metrics
from rliable import plot_utils
from matplotlib.ticker import MaxNLocator

# Matplotlib params
from matplotlib import rcParams
from matplotlib import rc

# Global Matplotlib settings: automatic legend placement and font type 42 for PDF/PS output.
rcParams.update({
    'legend.loc': 'best',
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

# Render text without LaTeX.
rc('text', usetex=False)

def set_axes(ax, xlim, ylim, xlabel, ylabel):
    """Set the limits and the labels (with a label padding of 14) of both axes of ``ax``."""
    for apply_limit, limit in ((ax.set_xlim, xlim), (ax.set_ylim, ylim)):
        apply_limit(limit)
    for apply_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        apply_label(text, labelpad=14)
 
def set_ticks(ax, xticks, xticklabels, yticks, yticklabels):
    """Set tick positions and tick labels for the x axis first, then for the y axis."""
    axis_settings = (
        (ax.set_xticks, xticks),
        (ax.set_xticklabels, xticklabels),
        (ax.set_yticks, yticks),
        (ax.set_yticklabels, yticklabels),
    )
    for setter, value in axis_settings:
        setter(value)

def decorate_axis(ax, wrect=10, hrect=10, labelsize='large'):
    """Style an axis: hide the right/top spines, thicken and offset the left/bottom ones.

    Args:
        ax: Axis to decorate.
        wrect: Outward offset of the bottom spine.
        hrect: Outward offset of the left spine.
        labelsize: Tick label size passed to ``tick_params``.
    """
    # Hide the right and top spines
    for hidden_side in ('right', 'top'):
        ax.spines[hidden_side].set_visible(False)
    for kept_side in ('left', 'bottom'):
        ax.spines[kept_side].set_linewidth(2)
    ax.tick_params(length=0.1, width=0.1, labelsize=labelsize)
    for kept_side, offset in (('left', hrect), ('bottom', wrect)):
        ax.spines[kept_side].set_position(('outward', offset))

def score_normalization(res_dict, min_scores, max_scores):
    """Min-max normalise each entry of ``res_dict`` with the bounds stored under the same key.

    Returns:
        dict: key -> ``(scores - min) / (max - min)``.
    """
    return {
        game: (scores - min_scores[game]) / (max_scores[game] - min_scores[game])
        for game, scores in res_dict.items()
    }

def convert_to_matrix(score_dict):
    """Stack the dict's arrays along axis 1, ordered by sorted key."""
    columns = [score_dict[key] for key in sorted(score_dict)]
    return np.stack(columns, axis=1)

def plot_score_hist(score_matrix, bins=20, figsize=(28, 14), fontsize='xx-large', N=6, extra_row=1, names=None):
    """Plot one score histogram (with KDE) per column of ``score_matrix`` on a grid.

    Args:
        score_matrix: 2-D array; every column gets its own histogram.
        bins: Number of histogram bins.
        figsize: Figure size.
        fontsize: Font size of titles and y labels.
        N: Number of grid columns.
        extra_row: Rows added on top of ``num_tasks // N``.
        names: Titles of the histograms, one per column; defaults to the column indices.

    Returns:
        The created Matplotlib figure.
    """
    num_tasks = score_matrix.shape[1]
    if names is None:
        # Fall back to the column position when no titles were supplied.
        names = [str(task_idx) for task_idx in range(num_tasks)]
    N1 = (num_tasks // N) + extra_row
    # squeeze=False keeps `ax` two-dimensional even for a grid with a single row or column.
    fig, ax = plt.subplots(nrows=N1, ncols=N, figsize=figsize, squeeze=False)
    for i in range(N):
        for j in range(N1):
            idx = j * N + i
            cell = ax[j, i]
            if idx < num_tasks:
                cell.set_title(names[idx], fontsize=fontsize)
                sns.histplot(score_matrix[:, idx], bins=bins, ax=cell, kde=True)
            else:
                # Grid cells beyond the last task stay blank.
                cell.axis('off')
            decorate_axis(cell, wrect=5, hrect=5, labelsize='xx-large')
            cell.xaxis.set_major_locator(plt.MaxNLocator(4))
            # Only the first grid column shows the y label.
            if idx % N == 0:
                cell.set_ylabel('Count', size=fontsize)
            else:
                cell.yaxis.label.set_visible(False)
            cell.grid(axis='y', alpha=0.1)
    return fig

# Short alias for rliable's stratified bootstrap class.
StratifiedBootstrap = rly.StratifiedBootstrap

# Aggregate functions: thin wrappers around the rliable.metrics aggregators.
IQM = lambda x: metrics.aggregate_iqm(x) # Interquartile Mean
OG = lambda x: metrics.aggregate_optimality_gap(x, 0.99) # Optimality Gap (second argument fixed at 0.99)
MEAN = lambda x: metrics.aggregate_mean(x)
MEDIAN = lambda x: metrics.aggregate_median(x)

# Keep at most `last_runs_to_consider` entries per architecture.
# NOTE(review): the slice takes the FIRST entries along axis 0 of each array, despite the
# "last runs" wording of the config name — confirm the intended ordering.
moe_arch_score_dict = {key: val[:last_runs_to_consider] for key, val in score_data_dict.items()}

def subsample_scores(score_dict, n=5, replace=False):
    """Draw ``n`` random entries from every array in ``score_dict``.

    The number of available samples is taken from the first entry of the dict; indices
    are drawn independently per key with ``np.random.choice``.
    """
    first_key = list(score_dict.keys())[0]
    total_samples = len(score_dict[first_key])
    subsampled_dict = {}
    for game in score_dict:
        picked = np.random.choice(range(total_samples), size=n, replace=replace)
        subsampled_dict[game] = score_dict[game][picked]
    return subsampled_dict

def subsample_scores_mat(score_mat, num_samples=5, replace=False):
  """Subsample rows of a 2D score matrix independently for every column.

  Args:
    score_mat: 2D array; rows are indexed by the drawn sample indices.
    num_samples: Number of rows to draw per column.
    replace: Whether to draw with replacement.
  Returns:
    A float array of shape `(num_samples, score_mat.shape[1])`.
  """
  total_samples, num_games = score_mat.shape
  result = np.empty((num_samples, num_games))
  for column, column_scores in enumerate(score_mat.T):
    # A fresh set of row indices is drawn for each column.
    picked_rows = np.random.choice(total_samples, size=num_samples, replace=replace)
    result[:, column] = column_scores[picked_rows]
  return result

def subsample_seeds(score_mat, num_samples=5, replace=False):
  """Pick `num_samples` random entries along the first axis of `score_mat`.

  Args:
    score_mat: Array whose first axis is subsampled.
    num_samples: Number of entries to draw.
    replace: Whether to draw with replacement.
  Returns:
    `score_mat` indexed with the drawn indices along its first axis.
  """
  num_rows = score_mat.shape[0]
  chosen_rows = np.random.choice(num_rows, size=num_samples, replace=replace)
  return score_mat[chosen_rows]

def batch_subsample_seeds(score_mat, num_samples=5, batch_size=100, replace=False):
  """Lazily yield `batch_size` random subsamples along the first axis of `score_mat`.

  All index sets are drawn up front when the function is called; only the
  indexing of `score_mat` is deferred to iteration time.

  Args:
    score_mat: Array whose first axis is subsampled.
    num_samples: Number of entries per subsample.
    batch_size: Number of subsamples to produce.
    replace: Whether each subsample is drawn with replacement.
  Returns:
    A generator over `batch_size` subsampled arrays.
  """
  num_rows = score_mat.shape[0]
  batch_indices = []
  for _ in range(batch_size):
    batch_indices.append(np.random.choice(num_rows, size=num_samples, replace=replace))
  return (score_mat[row_ids] for row_ids in batch_indices)

def subsample_scores_mat_with_replacement(score_mat, num_samples=5):
  """Resample every column of a 2D score matrix with replacement.

  Args:
    score_mat: 2D array to resample.
    num_samples: Number of rows in the result.
  Returns:
    An array of shape `(num_samples, score_mat.shape[1])` in which each
    column holds values drawn from the same column of `score_mat`.
  """
  total_samples, num_games = score_mat.shape
  row_indices = np.random.choice(total_samples, size=(num_samples, num_games), replace=True)
  # Every output row addresses columns 0..num_games-1 in order.
  col_indices = np.tile(np.arange(num_games), (num_samples, 1))
  return score_mat[row_indices, col_indices]

# Default subsample sizes; used by `calculate_aggregate_varying_sizes` when
# its `sizes` argument is None.
SIZES = [3, 5, 10, 25, 50, 100]

def calc_aggregate_fn(score_data, num_samples=5, total_n=20000, aggregate_fn=MEDIAN, replace=False):
  """Evaluate `aggregate_fn` on `total_n` random subsamples of `score_data`.

  Args:
    score_data: Array handed to `batch_subsample_seeds` for subsampling.
    num_samples: Number of entries per subsample.
    total_n: Number of subsamples to draw and aggregate.
    aggregate_fn: Callable applied to each subsample.
    replace: Whether each subsample is drawn with replacement.
  Returns:
    A numpy array with one aggregate per subsample.
  """
  batches = batch_subsample_seeds(score_data, num_samples, batch_size=total_n, replace=replace)
  return np.array(list(map(aggregate_fn, batches)))

def calculate_aggregate_varying_sizes(score_matrix, aggregate_fn, total_n=20000, sizes=None, replace=False):
  """Run `calc_aggregate_fn` once per subsample size and collect the results.

  Prints the mean of the aggregates for every size as a side effect.

  Args:
    score_matrix: Array handed to `calc_aggregate_fn` for subsampling.
    aggregate_fn: Callable applied to each subsample.
    total_n: Number of subsamples drawn per size.
    sizes: Iterable of subsample sizes; defaults to `SIZES` when None.
    replace: Whether subsamples are drawn with replacement.
  Returns:
    A dict mapping each size to the array returned by `calc_aggregate_fn`.
  """
  agg_dict = {}
  if sizes is None:
    sizes = SIZES
  for size in sizes:
    # Key the result by the loop variable; the name `n` is not defined in
    # this function.
    agg_dict[size] = calc_aggregate_fn(score_matrix, num_samples=size, aggregate_fn=aggregate_fn, total_n=total_n, replace=replace)
    print('Mean Aggregate: {}'.format(np.mean(agg_dict[size])))
  return agg_dict

def CI(bootstrap_dist, stat_val=None, alpha=0.05, is_pivotal=False):
    """
    Get the bootstrap confidence interval for a given distribution.
    Args:
      bootstrap_dist: numpy array of bootstrap results.
      stat_val: The overall statistic that this method is attempting to
        calculate error bars for. Required when `is_pivotal` is true.
        Default is None.
      alpha: The alpha value for the confidence intervals.
      is_pivotal: if true, use the pivotal (reverse percentile) method.
        If false, use the percentile method.
    Returns:
      (low, high): The lower and upper limit of the (1 - `alpha`) x 100% CI.
      val: `stat_val` for the pivotal method, otherwise the median of the
        bootstrap distribution.
    Raises:
      AssertionError: If `is_pivotal` is true and `stat_val` is None.
    """
    # Adapted from https://pypi.org/project/bootstrapped
    if is_pivotal:
        # The message must be a single expression: a string literal on its own
        # line after the assert would be a separate no-op statement.
        assert stat_val is not None, (
            'Please pass the statistic for a pivotal confidence interval')
        # Reflect the percentiles around the point estimate.
        low = 2 * stat_val - np.percentile(bootstrap_dist, 100 * (1 - alpha / 2.))
        val = stat_val
        high = 2 * stat_val - np.percentile(bootstrap_dist, 100 * (alpha / 2.))
    else:
        low = np.percentile(bootstrap_dist, 100 * (alpha / 2.))
        val = np.percentile(bootstrap_dist, 50)
        high = np.percentile(bootstrap_dist, 100 * (1 - alpha / 2.))
    return (low, high), val

def aggregate_func(x):
  """Stack the aggregates in the order Median, IQM, Mean, Optimality Gap."""
  return np.array([MEDIAN(x), IQM(x), MEAN(x), OG(x)])


aggregate_scores, aggregate_interval_estimates = rly.get_interval_estimates(
    moe_arch_score_dict,
    aggregate_func,
    task_bootstrap=False,
    reps=bootstrap_reps,
    random_state=RAND_STATE,
)

algorithms = list(moe_arch_score_dict)
# The metric names follow the order of the array returned by `aggregate_func`.
fig, axes = plot_utils.plot_interval_estimates(
    aggregate_scores,
    aggregate_interval_estimates,
    metric_names=['Median', 'IQM', 'Mean', 'Optimality Gap'],
    algorithms=algorithms,
    colors=MOE_ARCH_COLOR_DICT,
    xlabel_y_coordinate=-0.16,
    xlabel='Normalized Score',
)

plt.savefig(f'../../report/charts/bootstrap_interval_estimates_{dataset}.png', dpi=300, bbox_inches='tight')

## Probability of Improvement

Note: The following code is taken from Agarwal et al. (2022) with minimal adjustments

References:
- https://stat20.berkeley.edu/fall-2024/3-generalization/09-bootstrapping/notes.html
- https://arxiv.org/pdf/2108.13264#page=5.73
- https://colab.research.google.com/drive/1a0pSD-1tWhMmeJeeoyZM1A-HCW3yf1xR?usp=sharing#scrollTo=LnNa7O9xCuKA 

In [None]:
"""
Plot Probability of Improvement
References: 
- https://seaborn.pydata.org/generated/seaborn.boxplot.html
- https://seaborn.pydata.org/generated/seaborn.violinplot.html
- https://stat20.berkeley.edu/fall-2024/3-generalization/09-bootstrapping/notes.html
- https://arxiv.org/pdf/2108.13264#page=5.73
- https://colab.research.google.com/drive/1a0pSD-1tWhMmeJeeoyZM1A-HCW3yf1xR?usp=sharing#scrollTo=LnNa7O9xCuKA 
"""

try:
    if COMMON_CELLS:
        pass
except NameError:
    from IPython.display import Javascript
    # NOTE(review): the Javascript object is created but never passed to
    # display(), so this fallback does not appear to trigger anything - confirm.
    Javascript("Jupyter.notebook.execute_cells([0,1])")

# One figure per algorithm X: for every other algorithm Y, draw the interval
# estimate of P(X > Y) as a horizontal bar and the point estimate as a tick.
for algorithm in algorithms:
    # Keys have the form 'X:Y'; values are the score pairs to compare.
    all_pairs = {}
    for alg in algorithms:
        if alg == algorithm:
            continue
        pair_name = f'{algorithm}:{alg}'
        all_pairs[pair_name] = (moe_arch_score_dict[algorithm], moe_arch_score_dict[alg])

    probabilities, probability_cis = rly.get_interval_estimates(all_pairs, metrics.probability_of_improvement, reps=bootstrap_reps, random_state=RAND_STATE)

    fig, ax = plt.subplots(figsize=(4, 3))
    h = 0.6  # bar height
    algorithm_labels = []

    for i, (alg_pair, prob) in enumerate(probabilities.items()):
        _, alg1 = alg_pair.split(':')
        algorithm_labels.append(alg1)
        (l, u) = probability_cis[alg_pair]
        # Bar spans the interval estimate, the vertical line marks the point estimate.
        ax.barh(y=i, width=u-l, height=h, left=l, color=MOE_ARCH_COLOR_DICT[alg1], alpha=0.75)
        ax.vlines(x=prob, ymin=i-7.5 * h/16, ymax=i+(6*h/16), color='k', alpha=0.85)

    # Label the bars once, after all of them have been drawn.
    ax.set_yticks(range(len(algorithm_labels)))
    ax.set_yticklabels(algorithm_labels)

    ax.set_title(fr'P({algorithm} > $Y$)', size='xx-large')
    plot_utils._annotate_and_decorate_axis(ax, labelsize='xx-large', ticklabelsize='xx-large')
    ax.set_ylabel(r'Algorithm $Y$', size='xx-large')
    # Use the locator through `plt`, which is imported in the common cell.
    ax.xaxis.set_major_locator(plt.MaxNLocator(4))
    fig.subplots_adjust(wspace=0.25, hspace=0.45)
    
    fig.savefig(f'../../report/charts/probability_of_improvement_{algorithm}_{dataset}.png', dpi=300, bbox_inches='tight')

## Hardware Usage

This cell calculates the hardware usage and relates it to each model's accuracy.

MACs, memory usage, CPU time and accuracy are each plotted on their own vertical bar. For each experiment run, a line is drawn connecting the respective points on the bars.

In [None]:
"""
Plot Hardware Usage
Reference: The visual presentation of the data in this section is heavily inspired by the aimStack UI's metrics explorer, though no code was copied.
"""

try:
    if COMMON_CELLS:
        pass
except NameError:
    from IPython.display import Javascript
    # NOTE(review): the Javascript object is created but never passed to
    # display(), so this fallback does not appear to trigger anything - confirm.
    Javascript("Jupyter.notebook.execute_cells([0,1])")

from sklearn.preprocessing import MinMaxScaler

# Collect one record per row of `accuracy_df`, adding the display name of the model.
parallel_data = []
for _, row in accuracy_df.iterrows():
    model = row['moe_type']
    friendly_name = models_friendly_names.get(model, model)
    
    parallel_data.append({
        'model': friendly_name,
        'moe_type': model,
        'hash': row['hash'],
        'memory_usage': row['profiling_memory_usage'],
        'cpu_time': row['profiling_total_cpu_time'],
        'epoch': row['epoch'],
        'accuracy': row['accuracy'],
        'macs': row['macs']
    })

parallel_df = pd.DataFrame(parallel_data)
# Forward-fill missing macs values as they are only tracked for the first epoch.
# `DataFrame.ffill()` replaces the deprecated `fillna(method='ffill')`.
parallel_df = parallel_df.ffill()

# Filter to final epoch only
parallel_df_final = parallel_df[parallel_df['epoch'] == parallel_df['epoch'].max()].copy()

# Set up x positions for metrics
x_positions = [0, 1, 2, 3]
x_labels = ['MACs', 'Memory Usage\n(MB)', 'CPU Time\n(seconds)', 'Accuracy']
# Column order matches `x_positions` / `x_labels`
metrics_to_plot = ['macs', 'memory_usage', 'cpu_time', 'accuracy']
models = [model for model in parallel_df['moe_type'].unique() if model]
n_models = len(models)

# Grid dimensions
cols = 2
rows = (n_models + cols - 1) // cols  # Ceiling division

# squeeze=False always yields a 2D array of axes, also for a single row
fig, axes = plt.subplots(rows, cols, figsize=(15, 4 * rows), squeeze=False)
axes = axes.flatten()

# Plot each model in its own subplot
for idx, model in enumerate(models):
    ax = axes[idx]

    model_data = parallel_df_final[parallel_df_final['moe_type'] == model].copy()

    # Remember the raw value range of this model's runs before scaling; the
    # axis annotations below must describe the same range the scaler maps to
    # [0, 1], otherwise the labels do not match the plotted lines.
    raw_min = model_data[metrics_to_plot].min()
    raw_max = model_data[metrics_to_plot].max()

    # Normalize to make visualizations nicer (min-max scaling per model)
    scaler = MinMaxScaler()
    model_data[metrics_to_plot] = scaler.fit_transform(model_data[metrics_to_plot])

    color = MOE_ARCH_COLOR_DICT.get(model, 'gray')
    friendly_name = models_friendly_names.get(model, model)
    
    # Plot individual runs for this model
    for _, row in model_data.iterrows():
        y_values = [row['macs'], row['memory_usage'], row['cpu_time'], row['accuracy']]
        ax.plot(x_positions, y_values, color='lightgray')
    
    # Calculate and plot mean line
    mean_values = [
        model_data['macs'].mean(),
        model_data['memory_usage'].mean(), 
        model_data['cpu_time'].mean(),
        model_data['accuracy'].mean()
    ]

    ax.plot(x_positions, mean_values, 'o--', color=color, linewidth=3, markersize=8, alpha=1.0)

    # Customize subplot
    ax.set_xticks(x_positions)
    ax.set_xticklabels(x_labels, fontsize=10)
    ax.set_title(friendly_name, fontsize=12, y=1.05)
    ax.grid(True, alpha=0.3, axis='y')
    ax.set_ylim(0, 1)
    
    # Add vertical reference lines at each axis
    for x_pos in x_positions:
        ax.axvline(x=x_pos, color='lightgray', alpha=0.5, linestyle='-', linewidth=0.8)
    
    # Annotate every vertical axis with the raw minimum (y=0) and maximum (y=1)
    macs_min = raw_min['macs']
    macs_max = raw_max['macs']
    ax.text(-0.15, 0, f'{macs_min:.0f}', transform=ax.transData, ha='right', va='center', fontsize=8, color='gray')
    ax.text(-0.15, 1, f'{macs_max:.0f}', transform=ax.transData, ha='right', va='center', fontsize=8, color='gray')
    
    memory_min = raw_min['memory_usage']
    memory_max = raw_max['memory_usage']
    ax.text(1, 0, f'{memory_min:.0f}', transform=ax.transData, ha='center', va='top', fontsize=8, color='gray', rotation=0)
    ax.text(1, 1, f'{memory_max:.0f}', transform=ax.transData, ha='center', va='bottom', fontsize=8, color='gray', rotation=0)

    cpu_min = raw_min['cpu_time']
    cpu_max = raw_max['cpu_time']
    ax.text(2.15, 0, f'{cpu_min:.1f}', transform=ax.transData, ha='left', va='center', fontsize=8, color='gray')
    ax.text(2.15, 1, f'{cpu_max:.1f}', transform=ax.transData, ha='left', va='center', fontsize=8, color='gray')

    accuracy_min = raw_min['accuracy']
    accuracy_max = raw_max['accuracy']
    ax.text(3.15, 0, f'{accuracy_min:.1f}', transform=ax.transData, ha='left', va='center', fontsize=8, color='gray')
    ax.text(3.15, 1, f'{accuracy_max:.1f}', transform=ax.transData, ha='left', va='center', fontsize=8, color='gray')

    # Add intermediate scale markers
    for min_val, max_val, pos in [
        (macs_min, macs_max, 0),
        (memory_min, memory_max, 1),
        (cpu_min, cpu_max, 2),
        (accuracy_min, accuracy_max, 3)
    ]:
        # Add 25%, 50%, 75% markers
        for y_pos in [0.25, 0.5, 0.75]:
            # Raw value that corresponds to this normalized height
            val = min_val + (max_val - min_val) * y_pos
            
            # Small tick marks
            ax.plot([pos - 0.02, pos + 0.02], [y_pos, y_pos], color='gray', alpha=0.6, linewidth=0.5)

            # We'll show a value label only in the vertical middle
            if y_pos == 0.5:
                if pos == 0:  # MACs - left side
                    ax.text(pos - 0.08, y_pos, f'{val:.0f}', ha='right', va='center', fontsize=7, color='gray')
                elif pos == 1:  # Memory - below
                    ax.text(pos, y_pos - 0.05, f'{val:.0f}', ha='center', va='top', fontsize=7, color='gray')
                elif pos == 2:  # CPU - centered slightly right of the axis
                    ax.text(pos + 0.08, y_pos, f'{val:.1f}', ha='center', va='center', fontsize=7, color='gray')
                elif pos == 3:  # Accuracy - right side
                    ax.text(pos + 0.08, y_pos, f'{val:.1f}', ha='left', va='center', fontsize=7, color='gray')

    # The normalized y ticks carry no meaning; hide their labels once per subplot
    ax.set_yticklabels([])

# Hide unused subplots
for idx in range(n_models, len(axes)): 
    axes[idx].set_visible(False)

legend_elements = [
    Line2D([0], [0], color='gray', linewidth=1.5, alpha=0.6, label='Individual Runs'),
    Line2D([0], [0], color='black', linewidth=3, linestyle='--', label='Mean')
]
fig.legend(handles=legend_elements, loc='center right', bbox_to_anchor=(0.95, 1.0), fontsize=11)

sns.despine(left=True, bottom=True, fig=fig)
plt.tight_layout()
plt.subplots_adjust(right=0.95)  

plt.savefig(f'../../report/charts/hardware_usage_comparison_{dataset}.png', dpi=300, bbox_inches='tight', facecolor='white')
