In [9]:
import numpy as np
import matplotlib.pyplot as plt
from config import *
import matplotlib.colors as mcolors


In [2]:
# Pull the W&B result table for the alpha / label-capacity sweep and keep only
# runs whose state is 'finished'.
alpha_capacity_project = f"{wandb_username}/fig_1c_alpha_label_capacity"
df = load_result_table(project=alpha_capacity_project)
df = df.loc[df['state'].eq('finished')]
df_orig = df.copy()

In [3]:
# Second sweep (class scaling): keep finished runs only, then stack them under
# the first sweep with a fresh 0..N-1 index.
df3 = load_result_table(
    project=f"{wandb_username}/fig_1c_linear_regression_class_scaling"
)
df3 = df3.loc[df3['state'].eq('finished')]
df = pd.concat([df, df3], ignore_index=True)

In [4]:
# Snapshot of the combined table before derived columns are added
# (replaces the single-sweep snapshot taken right after loading).
df_orig = df.copy(deep=True)

In [5]:
# alpha = P / (d * c), with P = samples_per_class * c the total sample count
# (presumably the number of weights of a d -> c linear map — confirm).
# The other columns rescale the same load by a function of the class count c:
#   alpha_c = alpha * c, alpha_logc = alpha * log(c), alpha_sqrtc = alpha * sqrt(c).
df['alpha'] = (df['data_num_samples_per_class'] * df['data_num_labels']) / (df['data_input_dim'] * df['data_num_labels'])
df['alpha_c'] = (df['data_num_samples_per_class'] * df['data_num_labels']) / df['data_input_dim']
df['alpha_logc'] = (df['data_num_samples_per_class'] * np.log(df['data_num_labels'])) / df['data_input_dim']
df['alpha_sqrtc'] = (df['data_num_samples_per_class'] * np.sqrt(df['data_num_labels'])) / df['data_input_dim']


In [None]:

def truncate_colormap(cmap, minval=0.0, maxval=0.7, n=100):
    """Build a colormap from only the [minval, maxval] slice of ``cmap``.

    Parameters
    ----------
    cmap : matplotlib colormap
        Source colormap; it is called on an array of positions and its
        ``.name`` is used to label the result.
    minval, maxval : float
        Start and end positions (fractions of the source colormap) to keep.
    n : int
        Number of evenly spaced colours sampled between ``minval`` and ``maxval``.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        New colormap named ``trunc(<name>,<minval>,<maxval>)`` (2 decimals).
    """
    label = "trunc({},{:.2f},{:.2f})".format(cmap.name, minval, maxval)
    sampled_colors = cmap(np.linspace(minval, maxval, n))
    return mcolors.LinearSegmentedColormap.from_list(label, sampled_colors)

In [10]:
# Fixed student configuration for the class-scaling sweep (unshuffled inputs).
constraints = dict(
    student_train_frac=0.55,
    student_temperature=10.0,
    student_shuffle_input_intra_class=False,
)

# `filter` here is the project helper (presumably from config), not the builtin.
a = filter(df, constraints).replace('NaN', None).sort_values(by='alpha')

cmap = truncate_colormap(plt.get_cmap('PuBu'), 0.3, 1.0)

# Colour each class count by its rank among the distinct counts: rank / n_distinct.
# The largest count therefore gets (k-1)/k, not 1.0.
n_distinct_labels = len(a.data_num_labels.unique())
colors = {
    label: cmap(rank / n_distinct_labels)
    for rank, label in enumerate(sorted(a.data_num_labels.unique()))
}

In [None]:

def _plot_class_scaling_panel(ax, frame, *, shuffle, cmap_name, x_col, y_col,
                              annotate_labels=()):
    """Plot ``y_col`` against ``x_col`` on ``ax``, one curve per class count c.

    Rows are restricted to the fixed student setting (train fraction 0.55,
    temperature 10.0) with the requested intra-class input shuffling flag.
    Runs sharing an ``x_col`` value are averaged. Each curve is coloured by c
    on a truncated version of the ``cmap_name`` colormap.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    frame : pandas.DataFrame
        Results table; must already contain ``alpha`` and ``x_col``.
    shuffle : bool
        Required value of ``student_shuffle_input_intra_class``.
    cmap_name : str
        Matplotlib colormap name for this panel.
    x_col, y_col : str
        Columns used for the x axis and for the plotted metric.
    annotate_labels : iterable
        Class counts whose curve gets a ``c=<value>`` label at its last point.

    Returns
    -------
    list
        Sorted distinct ``data_num_labels`` values present after filtering
        (empty if no rows match, in which case nothing is drawn).
    """
    constraints = {
        'student_train_frac': 0.55,
        'student_temperature': 10.0,
        'student_shuffle_input_intra_class': shuffle,
    }
    # `filter` is the project helper (presumably from config), not the builtin.
    subset = filter(frame, constraints)
    subset = subset.replace('NaN', None).sort_values(by='alpha')

    labels = sorted(subset.data_num_labels.unique())
    if not labels:
        return labels

    cmap = truncate_colormap(plt.get_cmap(cmap_name), 0.4, 1.0)
    norm = plt.Normalize(vmin=min(labels), vmax=max(labels))

    for c, group in subset.groupby('data_num_labels'):
        color = cmap(norm(c))
        # Cast the metric actually plotted (not always train_acc) on a new frame,
        # so mean(numeric_only=True) keeps it and no SettingWithCopy warning fires.
        group = group.assign(**{y_col: group[y_col].astype(float)})
        curve = group.groupby(x_col).mean(numeric_only=True).reset_index()
        ax.plot(curve[x_col], curve[y_col], marker='x', c=color)
        if c in annotate_labels:
            # Anchor the label on the last point of the curve that was drawn.
            ax.text(curve[x_col].iloc[-1] + 0.01, curve[y_col].iloc[-1] + 0.01,
                    f'c={c}', fontsize=10, color=color, ha='left', va='bottom')
    return labels


fig, axes = plt.subplots(1, 4, figsize=(12, 2.5), sharey=True)
fig.subplots_adjust(hspace=0.4, wspace=0.3)
axes = axes.flatten()

# One spec per panel, left to right (B.1 ... B.4).
panel_specs = [
    dict(cmap_name='Blues', shuffle=False, x_col='alpha_logc', y_col='teacher_train_acc',
         annotate_labels=(2, 5), xlabel=r'$\alpha \log c$', title_key='teacher_train_acc',
         xlim=(0, 2), vline=0.92, vline_color=alpha_t_label_color,
         vline_label=alpha_t_label_name),
    dict(cmap_name='Purples', shuffle=False, x_col='alpha_logc', y_col='train_acc',
         xlabel=r'$\alpha \log{c}$', title_key='student_train_acc',
         xlim=(0, 2), vline=0.92, vline_color=alpha_s_label_color,
         vline_label=alpha_s_label_name),
    dict(cmap_name='Oranges', shuffle=False, x_col='alpha_c', y_col='test_acc',
         xlabel=r'$\alpha c$', title_key='student_test_acc',
         xlim=(0, 10), vline=2.0, vline_color=alpha_s_id_color,
         vline_label=alpha_s_id_name),
    # NOTE(review): this panel plots train_acc under the 'student_shuffle_test_acc'
    # title — confirm which metric is intended for the shuffled-input runs.
    dict(cmap_name='Reds', shuffle=True, x_col='alpha_sqrtc', y_col='train_acc',
         xlabel=r'$\alpha \sqrt{c}$', title_key='student_shuffle_test_acc',
         xlim=(0, 2), vline=0.66, vline_color=alpha_s_shuffle_label_color,
         vline_label=alpha_s_shuffle_label_name),
]

all_labels = set()
for ax, spec in zip(axes, panel_specs):
    all_labels.update(_plot_class_scaling_panel(
        ax, df,
        shuffle=spec['shuffle'],
        cmap_name=spec['cmap_name'],
        x_col=spec['x_col'],
        y_col=spec['y_col'],
        annotate_labels=spec.get('annotate_labels', ()),
    ))
    ax.set_xlabel(spec['xlabel'])
    ax.set_title(metric_styles[spec['title_key']]['label'])
    ax.set_xlim(*spec['xlim'])
    # Dashed vertical reference line; it is the only labelled artist, so the
    # legend shows just this line.
    ax.axvline(spec['vline'], color=spec['vline_color'], linestyle='--',
               label=spec['vline_label'])
    ax.legend()

axes[0].set_ylabel('accuracy')
for ax, name in zip(axes, ["B.1", "B.2", "B.3", "B.4"]):
    ax.text(0.05, 1.12, f'({name})', transform=ax.transAxes, fontsize=12,
            va='top', ha='right', fontweight='bold')

# Neutral grey colorbar spanning every class count plotted in any panel
# (each panel uses its own hue, so the bar only conveys the c scale).
unique_labels = sorted(all_labels)
shared_norm = plt.Normalize(vmin=min(unique_labels), vmax=max(unique_labels))
sm = plt.cm.ScalarMappable(cmap=plt.get_cmap('Greys'), norm=shared_norm)
sm.set_array([])

cbar = fig.colorbar(sm, ax=axes, orientation='vertical', pad=0.02, aspect=30)
cbar.set_label('$c$')
fig.savefig(FIGURE_DIR / 'class_scaling_combined_long.pdf', dpi=300, bbox_inches='tight')
plt.show()
