In [None]:
# Project-wide settings and helpers. The star import stays first so that the
# explicit module imports below always win if config re-exports the same names.
from config import *
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


In [None]:
# Result table that every later cell works from.
# NOTE(review): this cell used to load
# f'{wandb_username}/fig_1c_temperature_grid_3' into `df` first and then
# overwrite it with the project below before anything read it, so that first
# load was dropped as dead work. Change the project name here to plot the
# other sweep.
df = load_result_table(project=f'{wandb_username}/distillation__main-fig_4a-fig_10')

# Snapshot of the table as loaded; later cells add derived columns to `df`.
df_orig = df.copy()

In [None]:
# Derived sample-ratio columns, built from the run-configuration columns.
samples_per_class = df['data_num_samples_per_class']
num_labels = df['data_num_labels']
input_dim = df['data_input_dim']

# alpha: (samples per class * labels) divided by (input dim * labels).
df['alpha'] = samples_per_class * num_labels / (input_dim * num_labels)
# Variants that weight samples-per-class by log(labels) or sqrt(labels).
df['alpha_log'] = samples_per_class * np.log(num_labels) / input_dim
df['alpha_sqrt'] = samples_per_class * np.sqrt(num_labels) / input_dim

In [None]:
# Number of classes to analyse: matched against the `data_num_labels` column
# when selecting runs, and used in figure titles and output file names.
c = 2

In [None]:
def filter(df, constraints):
    """Keep only the rows of ``df`` that match every applicable constraint.

    Parameters
    ----------
    df : DataFrame to select rows from.
    constraints : mapping of column name -> required value. Keys that are not
        columns of ``df`` are silently skipped.

    Returns
    -------
    The rows whose value equals the required one for each applicable column.
    When no constraint applies, ``df`` itself is returned (not a copy).

    Note: this helper shadows Python's builtin ``filter`` in the notebook.
    """
    applicable = {
        column: wanted
        for column, wanted in constraints.items()
        if column in df.columns
    }
    selected = df
    for column, wanted in applicable.items():
        selected = selected[selected[column] == wanted]
    return selected

In [None]:

# Accuracy heat maps over (alpha, temperature): one row of panels per setting
# (standard / within-class shuffling) and one column per metric.
fig, orig_axes = plt.subplots(2, 3, figsize=(12, 8), sharex=True, sharey=True)

properties = ['teacher_train_acc', 'student_train_acc', 'student_test_acc']

# One entry per subplot row:
# (required value of `student_shuffle_input_intra_class`, first line of the y-label).
row_settings = [
    (False, 'standard setting\n'),
    (True, 'within-class soft label shuffling\n'),
]

for axes, (shuffle_intra_class, row_label) in zip(orig_axes, row_settings):
    constraints = {
        'student_train_frac': 0.8,
        'student_shuffle_input_intra_class': shuffle_intra_class,
        'data_num_labels': c,
    }

    # Missing values are replaced by 0 before averaging, so a run with a
    # missing metric pulls the mean down instead of being skipped.
    # NOTE(review): confirm that this is the intended treatment.
    a = filter(df, constraints).fillna(0)

    # Average the numeric columns over all runs sharing (alpha, temperature).
    grouped_df = a.groupby(['alpha', 'student_temperature']).mean(numeric_only=True).reset_index()

    alphas = np.sort(grouped_df['alpha'].unique())
    temps = np.sort(grouped_df['student_temperature'].unique())
    A, T = np.meshgrid(alphas, temps)

    for ax, property_name in zip(axes, properties):
        # Rows = temperatures, columns = alphas, aligned with the mesh above.
        Z = grouped_df.pivot(
            index='student_temperature', columns='alpha', values=property_name
        ).reindex(index=temps, columns=alphas)
        c_ = ax.pcolormesh(A, T, Z.values, cmap='Spectral', shading='auto', vmin=0, vmax=1)

        ax.set_title(metric_styles[property_name]['label'])
        ax.set_xlabel(r'$\alpha$')
        ax.set_yscale('log')

    axes[0].set_ylabel(row_label + r'temperature $\tau$')

# Lay out the subplot grid before adding the colorbar: axes created with
# add_axes are not managed by tight_layout and would trigger its
# "not compatible" warning if they already existed at this point.
fig.tight_layout()

# Every panel shares cmap/vmin/vmax, so the last mesh stands in for all of them.
# The colorbar sits just outside the right figure edge; bbox_inches='tight'
# below makes sure it is included in the saved image.
cax = fig.add_axes([1.0, 0.15, 0.02, 0.7])
fig.colorbar(c_, cax=cax, label='accuracy')

fig.savefig(FIGURE_DIR / f'figure1c_temperature_grid_d=1000_c={c}.png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Build `threshold_df`: for every student temperature, the alpha values at
# which the averaged accuracies reach 1.0.
constraints = {
    'student_train_frac': 0.8,
    'data_num_labels': c,
}

# fillna returns a new frame, so neither the filtered selection nor `df`
# itself is ever modified in place.
a = filter(df, constraints).fillna(0)

grouped_df = (
    a.groupby(['alpha', 'student_temperature', 'student_shuffle_input_intra_class'])
    .mean(numeric_only=True)
    .reset_index()
)

# Row masks for the two settings, built once on the full grouped frame and
# combined with the temperature mask via `&` (avoids chained boolean indexing
# with a mask that has to be re-aligned).
is_standard = grouped_df['student_shuffle_input_intra_class'] == False  # noqa: E712
is_shuffled = grouped_df['student_shuffle_input_intra_class'] == True  # noqa: E712

threshold_results = []

for temp in grouped_df['student_temperature'].unique():
    at_temp = grouped_df['student_temperature'] == temp
    standard = grouped_df[at_temp & is_standard]
    shuffled = grouped_df[at_temp & is_shuffled]

    # NOTE(review): this cell reads 'test_acc' / 'train_acc', while the heat-map
    # cell reads 'student_test_acc' / 'student_train_acc' — confirm that these
    # are the intended columns.
    perfect_test_alphas = standard.loc[standard['test_acc'] >= 1.0, 'alpha']
    perfect_teacher_alphas = standard.loc[standard['teacher_train_acc'] >= 1.0, 'alpha']
    perfect_shuffled_train_alphas = shuffled.loc[shuffled['train_acc'] >= 1.0, 'alpha']

    # min()/max() of an empty selection is NaN, i.e. "threshold never reached".
    threshold_results.append({
        'student_temperature': temp,
        'alpha_S_id_start': perfect_test_alphas.min(),
        'alpha_S_id_end': perfect_test_alphas.max(),
        'alpha_S_logit': perfect_shuffled_train_alphas.max(),
        'alpha_T_label': perfect_teacher_alphas.max(),
    })

threshold_df = pd.DataFrame(threshold_results)


In [None]:

# Plot each alpha threshold as a curve against the student temperature.
# Tuple layout:
# (threshold_df column, color, marker, linestyle, linewidth, markersize, label)
thresholds = [
    ('alpha_T_label', alpha_t_label_color, None, '-', 2, 6, alpha_t_label_name),
    ('alpha_S_id_start', alpha_s_id_color, None, '-', 2, 6, alpha_s_id_name),
    ('alpha_S_id_end', alpha_s_label_color, None, '--', 2, 6, alpha_s_label_name),
    ('alpha_S_logit', alpha_s_shuffle_label_color, None, '-.', 2, 6, alpha_s_shuffle_label_name),
]

# Sort by temperature so each curve is drawn in temperature order.
threshold_df = threshold_df.sort_values(by='student_temperature')

fig, ax = plt.subplots(figsize=(3, 3))

# Labels are attached to the lines, but no legend is drawn in this cell.
for threshold, color, marker, linestyle, linewidth, markersize, label in thresholds:
    ax.plot(
        threshold_df[threshold],
        threshold_df['student_temperature'],
        label=label,
        color=color,
        marker=marker,
        linestyle=linestyle,
        linewidth=linewidth,
        markersize=markersize,
    )

ax.set_xlabel(r'$\alpha = n/(dc)$')
ax.set_ylabel(r'$\tau$')
ax.set_title(f'$c={c}$')

# Panel letter shown above the top-left corner; only defined for these c values.
panel_letters = {2: '(A)', 10: '(B)'}
if c in panel_letters:
    ax.text(
        0.05, 1.12, panel_letters[c],
        transform=ax.transAxes, fontsize=12, va='top', ha='right', fontweight='bold',
    )

ax.set_yscale('log')
ax.set_xlim(0.3, 1.2)
ax.set_ylim(10**-0.5, 10**3.5)
ax.grid()

# Run tight_layout only after scale and limits are final, so the padding is
# computed from the tick labels that are actually rendered.
fig.tight_layout()
# `{c=}` expands to e.g. "c=2" in the file name.
fig.savefig(FIGURE_DIR / f'multinomial_regression_thresholds_alpha_temp_{c=}.pdf', dpi=300)
plt.show()
