In [1]:
from config import *
import numpy as np
import matplotlib.pyplot as plt


In [2]:
# Download one row per wandb run of the temperature-sweep project.
df = load_result_table(
    project=f'{wandb_username}/neurips2025-apdx_1_mlr_temperature_c=10',
)
# Independent copy of the table as it was downloaded.
df_orig = df.copy()

In [3]:
# Derived sample-complexity columns, all built from the same three config columns.
n_per_class = df['data_num_samples_per_class']
num_labels = df['data_num_labels']
input_dim = df['data_input_dim']

# (samples per class * labels) / (input dim * labels)
df['alpha'] = n_per_class * num_labels / (input_dim * num_labels)
# Same ratio, with the label count entering through log / sqrt instead.
df['alpha_log'] = n_per_class * np.log(num_labels) / input_dim
df['alpha_sqrt'] = n_per_class * np.sqrt(num_labels) / input_dim

In [4]:
# Number of labels; used below to filter on `data_num_labels` and in figure titles / file names.
c = 10

In [5]:

# 2x3 grid of accuracy heatmaps over (alpha, student temperature).
# Top row: runs without within-class shuffling; bottom row: runs with it.
# Columns: one panel per metric listed in `properties`.
fig, orig_axes = plt.subplots(2, 3, figsize=(12, 8), sharex=True, sharey=True)

properties = ['teacher_train_acc', 'student_train_acc', 'student_test_acc']

# (value of `student_shuffle_input_intra_class`, y-axis label) for each row, top to bottom.
row_settings = [
    (False, 'standard setting\n' + r'temperature $\tau$'),
    (True, 'within-class soft label shuffling\n' + r'temperature $\tau$'),
]

for axes, (shuffle_within_class, row_label) in zip(orig_axes, row_settings):
    constraints = {
        'student_train_frac': 0.8,
        'student_shuffle_input_intra_class': shuffle_within_class,
        'data_num_labels': c,
    }

    # Missing values are replaced by 0 *before* averaging, so a run without a
    # logged metric pulls the mean towards 0 instead of being skipped.
    a = filter(df, constraints).copy()
    a = a.fillna(0)

    # Average all numeric columns over runs sharing the same (alpha, temperature).
    grouped_df = a.groupby(['alpha', 'student_temperature']).mean(numeric_only=True).reset_index()

    alphas = np.sort(grouped_df['alpha'].unique())
    temps = np.sort(grouped_df['student_temperature'].unique())
    A, T = np.meshgrid(alphas, temps)

    for ax, property_name in zip(axes, properties):
        # Rows = temperatures, columns = alphas, aligned with the meshgrid above.
        Z = grouped_df.pivot(
            index='student_temperature', columns='alpha', values=property_name
        ).reindex(index=temps, columns=alphas)
        # Every panel shares vmin/vmax, so a single colorbar is valid for all of them.
        c_ = ax.pcolormesh(A, T, Z.values, cmap='Spectral', shading='auto', vmin=0, vmax=1)

        ax.set_title(metric_styles[property_name]['label'])
        ax.set_xlabel(r'$\alpha$')
        ax.set_yscale('log')

    axes[0].set_ylabel(row_label)

# Lay out the subplot grid first: the colorbar axes added below is placed
# manually and is not managed by tight_layout, which would otherwise warn
# about incompatible Axes.
plt.tight_layout()

# Colorbar in its own axes just right of the figure area (figure coordinates).
cax = fig.add_axes([1.0, 0.15, 0.02, 0.7])
fig.colorbar(c_, cax=cax, label='accuracy')

# bbox_inches='tight' keeps the colorbar, which sits at x >= 1.0, inside the saved image.
plt.savefig(FIGURE_DIR / f'figure1c_temperature_grid_d=1000_c={c}.png', dpi=300, bbox_inches='tight')
plt.show()


In [6]:
# For every student temperature, extract the alpha values at which the mean
# accuracies reach 1.0, separately for the standard and the shuffled runs.
constraints = {
    'student_train_frac': 0.8,
    'data_num_labels': c,
}

# Non-inplace fillna: `filter` returns a slice of `df` (or `df` itself when no
# constraint applies), so filling in place would warn or silently mutate `df`.
a = filter(df, constraints).fillna(0)

grouped_df = a.groupby(
    ['alpha', 'student_temperature', 'student_shuffle_input_intra_class']
).mean(numeric_only=True).reset_index()

# Row masks over `grouped_df`, built once and combined with `&` below so each
# selection is a single, index-aligned boolean lookup.
is_standard = grouped_df['student_shuffle_input_intra_class'] == False
is_shuffled = grouped_df['student_shuffle_input_intra_class'] == True

# NOTE(review): the thresholds read `test_acc` / `train_acc`, while the heatmap
# cell plots `student_test_acc` / `student_train_acc` — confirm these are the
# intended columns.
threshold_results = []

for temp in grouped_df['student_temperature'].unique():
    at_temp = grouped_df['student_temperature'] == temp

    standard_runs = grouped_df[at_temp & is_standard]
    perfect_test_alphas = standard_runs.loc[standard_runs['test_acc'] >= 1.0, 'alpha']
    perfect_teacher_alphas = standard_runs.loc[standard_runs['teacher_train_acc'] >= 1.0, 'alpha']

    shuffled_runs = grouped_df[at_temp & is_shuffled]
    perfect_shuffled_train_alphas = shuffled_runs.loc[shuffled_runs['train_acc'] >= 1.0, 'alpha']

    # min()/max() of an empty selection is NaN, i.e. "threshold never reached".
    threshold_results.append({
        'student_temperature': temp,
        'alpha_S_id_start': perfect_test_alphas.min(),
        'alpha_S_id_end': perfect_test_alphas.max(),
        'alpha_S_logit': perfect_shuffled_train_alphas.max(),
        'alpha_T_label': perfect_teacher_alphas.max(),
    })

threshold_df = pd.DataFrame(threshold_results)


In [9]:
from pathlib import Path
import wandb
import pandas as pd
from tqdm.notebook import tqdm

# Output directory for saved figures; created (with parents) if it does not exist.
FIGURE_DIR = Path("figures/neurips2025")
FIGURE_DIR.mkdir(parents=True, exist_ok=True)

# wandb entity used as the prefix of the project path passed to load_result_table.
wandb_username = "feeds"

# Colour / legend-name pairs for the alpha threshold curves.
alpha_t_label_color = "#00FF00"
alpha_t_label_name = r'$\alpha^T_{label}$'

alpha_s_id_color = "#0000FF"
alpha_s_id_name = r'$\alpha^S_{id}$'

alpha_s_label_color = "#FF00FF"
alpha_s_label_name = r'$\alpha^S_{label}$'

# NOTE(review): `_up` is paired with \leq and `_low` with \geq, and both use the
# same colour — confirm this is intended.
alpha_s_id_color_up = 'blue'
alpha_s_id_color_low = 'blue'
alpha_s_id_name_up = r'$\alpha^S_{\leq id}$'
alpha_s_id_name_low = r'$\alpha^S_{\geq id}$'

alpha_s_shuffle_label_color = "#FF0000"
alpha_s_shuffle_label_name = r'$\alpha^{S-shuffle}_{label}$'

# Plot styling per metric, keyed by metric name. Each entry holds a LaTeX 'label'
# (used as the subplot title in the heatmap grid); most also define 'color' and
# 'linestyle'. The two mse entries only provide a label.
metric_styles = {
    'student_test_acc': {
        'color': 'tab:orange',
        'linestyle': '-',
        'label': r'$\mathrm{acc}^{\mathrm{S}}_{\mathrm{test}}$',
    },
    'student_train_acc': {
        'color': 'tab:purple',
        'linestyle': '--',
        'label': r'$\mathrm{acc}^{\mathrm{S}}_{\mathrm{train}}$',
    },
    'student_test_acc_orig': {
        'color': 'tab:green',
        'linestyle': '-',
        'label': r'$\mathrm{acc}^{\mathrm{S}}_{\mathrm{val}}$',
    },
    'teacher_train_acc': {
        'color': 'tab:blue',
        'linestyle': '--',
        'label': r'$\mathrm{acc}^{\mathrm{T}}_{\star}$',
    },
    'teacher_val_acc': {
        'color': 'tab:green',
        'linestyle': '--',
        'label': r'$\mathrm{acc}^{\mathrm{T}}_{val}$',
    },
    'student_val_acc': {
        'color': 'tab:green',
        'linestyle': '--',
        'label': r'$\mathrm{acc}^{\mathrm{S}}_{val}$',
    },
    'student_test_acc_zeros': {
        'color': 'tab:green',
        'linestyle': 'dashed',
        'label': r'$\mathrm{acc}^{\mathrm{S}}_{c=0}$',
    },
    'student_shuffle_test_acc': {
        'color': 'tab:red',
        'linestyle': 'dashed',
        'label': r'$\mathrm{acc}^{\mathrm{S-shuffle}}_{test}$',
    },
    'match_teacher_test_acc': {
        'color': 'tab:blue',
        'linestyle': 'dashed',
        'label': r'$\mathrm{acc}^{\mathrm{S}}_{match-T}$',
    },
    # Loss-type metrics: label only, no colour / linestyle defined.
    'train_mse': {
        'label': r'$\mathrm{mse}(f^*(\mathcal{D}^S_{train}))$',
    },
    'test_mse': {
        'label': r'$\mathrm{mse}(f^*(\mathcal{D}^S_{test}))$',
    }
}

# Shared wandb API client used by load_result_table.
api = wandb.Api()


def load_result_table(project, samples=2000, load_hist=False):
    """Collect all runs of a wandb project into one DataFrame (one row per run).

    Each row merges the run's summary values, its config entries and the
    identifying fields ``name``, ``entity``, ``project``, ``state`` and
    ``run_id``. On a key collision config entries override summary values,
    and the identifying fields override both.

    Args:
        project: Project path in the form ``<entity>/<project-name>``.
        samples: Number of history samples requested per run; only used when
            ``load_hist`` is true.
        load_hist: When true, also fetch each run's history and store six
            logged series in ``*_hist`` columns. A series the run never
            logged is stored as ``None`` rather than aborting the load.

    Returns:
        pandas.DataFrame with one row per run.
    """
    # wandb history column -> key under which that series is stored in the row.
    history_columns = {
        'test_acc': 'test_accuracy_hist',
        'train_acc': 'train_accuracy_hist',
        'train_loss': 'train_loss_hist',
        'test_loss': 'test_loss_hist',
        'test_acc_orig': 'test_accuracy_orig_hist',
        'test_acc_zeros': 'test_accuracy_zeros_hist',
    }

    runs = api.runs(project)

    summary_list = []
    for run in tqdm(runs):
        res = {
            **run.summary._json_dict,
            **run.config,
            'name': run.name,
            'entity': run.entity,
            'project': run.project,
            # Public attribute instead of the private `run._state`.
            'state': run.state,
            'run_id': run.id,
        }

        if load_hist:
            history = run.history(samples=samples)
            for column, key in history_columns.items():
                # .get() yields None when the run has no such logged column.
                res[key] = history.get(column)

        summary_list.append(res)

    return pd.DataFrame(summary_list)

def filter(df, constraints):
    """Keep only the rows of ``df`` whose columns equal the requested values.

    ``constraints`` maps column name -> required value. Keys that are not
    columns of ``df`` are ignored. When no constraint applies, ``df`` itself
    is returned (not a copy).

    Note: this name shadows the built-in ``filter``.
    """
    applicable = (
        (column, wanted)
        for column, wanted in constraints.items()
        if column in df.columns
    )
    selected = df
    for column, wanted in applicable:
        selected = selected[selected[column] == wanted]
    return selected

def print_df(a):
    """Print each column of ``a`` that has fewer than 10 distinct values.

    One line is printed per such column: the column name followed by its
    unique values. Columns that cannot be inspected (for example because
    their cells are unhashable, which makes ``unique()`` fail) are skipped.
    """
    for column in a.columns:
        try:
            values = a[column].unique()
            if len(values) < 10:
                print(column, values)
        except Exception:
            # Best-effort overview: skip columns that cannot be summarised.
            # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
            continue

In [13]:

# One entry per curve:
# (threshold_df column, colour, marker, linestyle, linewidth, markersize, label)
# NOTE(review): the `alpha_S_id_end` curve uses the alpha^S_label colour and name — confirm intended.
# NOTE(review): labels are attached to the lines but no legend is drawn in this cell.
thresholds = [
    ('alpha_T_label', alpha_t_label_color, None, '-', 2, 6, alpha_t_label_name),
    ('alpha_S_id_start', alpha_s_id_color, None, '-', 2, 6, alpha_s_id_name),
    ('alpha_S_id_end', alpha_s_label_color, None, '--', 2, 6, alpha_s_label_name),
    ('alpha_S_logit', alpha_s_shuffle_label_color, None, '-.', 2, 6, alpha_s_shuffle_label_name),
]

# Order rows by temperature so each curve is drawn monotonically along the y-axis.
threshold_df = threshold_df.sort_values(by='student_temperature')

fig, ax = plt.subplots(figsize=(3, 3))

style_keys = ('color', 'marker', 'linestyle', 'linewidth', 'markersize', 'label')
for spec in thresholds:
    column, line_style = spec[0], dict(zip(style_keys, spec[1:]))
    # Thresholds (alpha) on x, temperature on y.
    ax.plot(threshold_df[column], threshold_df['student_temperature'], **line_style)

ax.set_xlabel(r'$\alpha = n/(dc)$')
ax.set_ylabel(r'$\tau$')
ax.set_title(f'$c={c}$')

# Panel letter placed above the top-left corner of the axes; only defined for c=2 and c=10.
panel_tag = {2: '(A)', 10: '(B)'}.get(c)
if panel_tag is not None:
    ax.text(0.05, 1.12, panel_tag, transform=ax.transAxes, fontsize=12,
            va='top', ha='right', fontweight='bold')

fig.tight_layout()
ax.set_yscale('log')
ax.set_xlim(0.05, 0.5)
ax.grid()
fig.savefig(FIGURE_DIR / f'multinomial_regression_thresholds_alpha_temp_{c=}.pdf', dpi=300)
plt.show()
