In [None]:
from config import *
import torch
from pathlib import Path
import pandas as pd
import wandb
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from src.style import c_muted

In [None]:
# Load the run table for the one-hidden-layer ablation project.
# NOTE(review): `load_result_table` and `wandb_username` are not defined in
# this notebook; presumably they come from `from config import *` -- confirm.
ablation_project = f'{wandb_username}/neurips2025-one_hidden_layer-ablations'
df = load_result_table(ablation_project, samples=4000)

In [None]:
# Keep a copy of the loaded table under a separate name. Nothing in the
# visible cells reads `df_orig`; presumably it is a pristine reference for
# interactive work in case `df` gets modified -- confirm before removing.
df_orig = df.copy()

In [None]:
# Two-panel figure saved as 'logit_ablations.pdf'.
#   (A) test accuracy vs. alpha = N/d for the 'cut_tail' soft-label treatment
#       at several fractions, against the untreated baseline.
#   (B) 'test_acc' and 'test_acc_zeros' vs. alpha when either the
#       ['remove_labels', [0]] filtering or the ['remove_logits', [0]]
#       treatment is applied, against the untreated baseline.

# Run configuration shared by every curve in both panels.
shared_constraints = {
    'student_temperature': 20.0,
    'data_num_labels': 100,
    'data_input_dim': 1000,
    'student_shuffle_input_intra_class': False,
    'teacher_activation': 'relu',
    'student_activation': 'relu',
}


def _select_runs(table, constraints, treatment=None, filtering=None):
    """Return the runs matching `constraints` and the given soft-label settings.

    Parameters
    ----------
    table : the result table to select from.
    constraints : dict passed through to `filter(table, constraints)`.
    treatment, filtering : the wanted value of the
        'student_soft_label_treatment' / 'student_soft_label_filtering'
        column. ``None`` keeps only rows where that column is missing; any
        other value keeps rows whose cell has the same ``str()`` form.

    Returns
    -------
    The selected rows.
    """
    # NOTE(review): `filter` is called with (table, dict), so it must be a
    # project helper shadowing the builtin (presumably via
    # `from config import *`) -- confirm.
    # NOTE(review): an explicit `value=None` in `replace` is only honoured by
    # pandas >= 1.4; older versions pad-fill instead -- confirm the version.
    runs = filter(table, constraints).replace('NaN', None)

    wanted_by_column = {
        'student_soft_label_treatment': treatment,
        'student_soft_label_filtering': filtering,
    }
    for column, wanted in wanted_by_column.items():
        if wanted is None:
            mask = runs[column].isna()
        else:
            # Compare string forms, as the cells are matched against str() of
            # a list-valued setting such as ['cut_tail', 0.01].
            mask = runs[column].apply(str) == str(wanted)
        runs = runs[mask]
    return runs


def _plot_metric(ax, runs, metric, **errorbar_kwargs):
    """Draw mean +/- SEM of `metric` against alpha = N/d on `ax`.

    Rows are grouped by 'data_num_samples_per_class'; alpha is the group key
    divided by the group's mean 'data_input_dim'. Extra keyword arguments are
    forwarded to `ax.errorbar` (cap size is fixed at 3).
    """
    per_size = runs.groupby('data_num_samples_per_class')
    mean = per_size.mean(numeric_only=True).reset_index()
    sem = per_size.sem(numeric_only=True).reset_index()
    alpha = mean.data_num_samples_per_class / mean.data_input_dim
    ax.errorbar(alpha, mean[metric], yerr=sem[metric], capsize=3, **errorbar_kwargs)


fig, axes = plt.subplots(1, 2, figsize=(10, 3.5), sharex=True, sharey=True)
# Bind each panel explicitly instead of relying on pyplot's current axes.
ax_tail, ax_removal = axes

# --- Panel A: zeroing out the tail of the soft labels ------------------------
baseline_runs = _select_runs(df, shared_constraints)
_plot_metric(ax_tail, baseline_runs, 'test_acc',
             label='normal', color='black', linestyle='dotted')

for fraction in (0.01, 0.1, 0.5):
    tail_runs = _select_runs(df, shared_constraints, treatment=['cut_tail', fraction])
    # round() rather than int(): 0.1 * 100 is 10.000000000000002 in floating
    # point, and truncation would mislabel fractions that land just below an
    # integer.
    _plot_metric(ax_tail, tail_runs, 'test_acc', label=str(round(fraction * 100)))

ax_tail.set_xlabel(r'$\alpha = N/d$')
ax_tail.set_ylabel(metric_styles['student_test_acc']['label'])
ax_tail.legend(title='zero-out smallest\n$k$ soft label values')

# --- Panel B: removing a class from the data vs. from the soft labels --------
_plot_metric(ax_removal, baseline_runs, 'test_acc_zeros',
             label='normal', color='black', linestyle='dotted')

ablations = (
    ('tab:blue', {'filtering': ['remove_labels', [0]]}),
    ('tab:orange', {'treatment': ['remove_logits', [0]]}),
)
for color, selection in ablations:
    ablated_runs = _select_runs(df, shared_constraints, **selection)
    # Dotted: 'test_acc_zeros'; solid: 'test_acc' (same colour per ablation).
    _plot_metric(ax_removal, ablated_runs, 'test_acc_zeros', linestyle='dotted', color=color)
    _plot_metric(ax_removal, ablated_runs, 'test_acc', color=color)

# Empty proxy lines: the legend explains colour (which ablation) and line
# style (which metric) separately instead of listing every curve.
ax_removal.plot([], [], label='remove class $c$ samples from training', color='tab:blue')
ax_removal.plot([], [], label='remove class $c$ soft label entry', color='tab:orange')
ax_removal.plot([], [], label=metric_styles['student_test_acc']['label'], color='black')
ax_removal.plot([], [], label=metric_styles['student_test_acc_zeros']['label'],
                color='black', linestyle='dotted')
ax_removal.set_ylabel('accuracy')
ax_removal.set_xlabel(r'$\alpha = N/d$')
ax_removal.legend()

# Panel letters; (A) sits further left than (B).
for panel_ax, letter, x_offset in zip(axes, 'AB', (-0.1, -0.05)):
    panel_ax.text(x_offset, 0.95, f'({letter})', transform=panel_ax.transAxes,
                  fontsize=12, va='top', ha='right', fontweight='bold')

fig.savefig(FIGURE_DIR / 'logit_ablations.pdf', bbox_inches='tight', dpi=300)

