In [5]:
import pandas as pd
import numpy as np
from config import * 
import numpy as np
import matplotlib as mpl
from matplotlib.colors import ListedColormap, BoundaryNorm
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
import matplotlib.pyplot as plt

In [6]:
def get_phase(row):
    """Classify one (averaged) result row into a teacher/student phase name.

    Thresholds: accuracies >= 0.99 count as "fits"; a student test accuracy
    <= 0.1 counts as "cannot fit" when the teacher itself does not fit.

    Args:
        row: mapping (e.g. a pandas row) with keys 'teacher_train_acc',
            'student_train_acc' and 'student_test_acc'.

    Returns:
        One of 'structure', 'perfect_mem', 'imperfect_mem',
        'teacher_cannot_fit', 'teacher_cannot_fit_student_cannot_fit'.
    """
    # `not (x >= t)` (rather than `x < t`) keeps NaN accuracies in the "does not fit" branches.
    if not (row['teacher_train_acc'] >= 0.99):
        student_fails = row['student_test_acc'] <= 0.1
        return 'teacher_cannot_fit_student_cannot_fit' if student_fails else 'teacher_cannot_fit'
    if not (row['student_train_acc'] >= 0.99):
        return 'imperfect_mem'
    return 'structure' if row['student_test_acc'] >= 0.99 else 'perfect_mem'

In [12]:
# Load the (unshuffled) one-hidden-layer sweep, keep only runs whose state is 'finished',
# and keep an untouched copy so the next cell can be re-run idempotently.
project = f"{wandb_username}/fig_3_one_hidden_layer"
df = load_result_table(project=project).loc[lambda runs_df: runs_df.state == 'finished']
df_orig = df.copy()

In [13]:
# Append the intra-class-shuffled sweep (finished runs only) to the unshuffled one.
project = f"{wandb_username}/fig_3_one_hidden_layer_shuffled"
df = pd.concat(
    [df_orig, load_result_table(project=project).loc[lambda runs_df: runs_df.state == 'finished']],
    ignore_index=True,
)

In [8]:
# LaTeX snippets used in panel titles. Raw strings keep "\m" from being parsed as an
# (invalid) escape sequence, which warns on every run in recent Python versions.
commands = {
    'Dtrain': r'$\mathcal{D}_{train}^S$',
    'Dtest': r'$\mathcal{D}_{test}^S$',
    'Dstar': r'$D_*^T$',
}

In [11]:
# ==== Figure: one-hidden-layer overview ====
# Top row: phase diagram (A) and teacher/student accuracy heatmaps over (alpha, rho) (B.1-B.3).
# Bottom row: training curves of selected configurations at rho = 0.65 (C.1-C.4).
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

constraints = {
    'student_temperature': 20.0,
    'data_num_labels': 100,
    'data_input_dim': 1000,
    'student_num_epochs': 10000,
    'student_hidden_dim': 500,
    'student_shuffle_input_intra_class': False,
    'teacher_activation': 'relu',
    'student_activation': 'relu',
}

a = filter(df, constraints)
a = a.replace('NaN', None)
a = a[a.data_num_samples_per_class != 100]
print(sorted(a.data_num_samples_per_class.unique()))
# Quick sanity check of the sweep: list every low-cardinality column.
for c in a.columns:
    try:
        if len(a[c].unique()) < 10:
            print(c, a[c].unique())
    except (TypeError, AttributeError):
        # TypeError: unhashable cell values (e.g. lists); AttributeError: duplicated column
        # labels turn a[c] into a DataFrame. Skip such columns without hiding other errors.
        pass

# phase name -> colour index into `phase_cmap` below
phase_map = {
    'structure': 0,
    'perfect_mem': 1,
    'imperfect_mem': 4,
    'teacher_cannot_fit': 2,
    'teacher_cannot_fit_student_cannot_fit': 3
}

phase_map_name = {
    'structure': 'generalize teacher',
    'perfect_mem': 'perfect logit mem.',
    'imperfect_mem': 'imperfect logit mem.',
    'teacher_cannot_fit': 'teacher cannot mem.',
    'teacher_cannot_fit_student_cannot_fit': 'teacher cannot mem. student cannot mem.'
}

grouped = a.groupby(['student_train_frac', 'data_num_samples_per_class'])
# The std must be taken over the individual runs, i.e. before `a` is replaced by the
# per-group mean (afterwards every group holds a single row and the std is all NaN).
a_std = grouped.std(numeric_only=True).reset_index()
a = grouped.mean(numeric_only=True).reset_index()
a['phase'] = a.apply(get_phase, axis=1)
a['alpha'] = a['data_num_samples_per_class'] * constraints['data_num_labels'] / constraints['data_input_dim']
a['rho'] = a['student_train_frac']

inv_phase_map = {v: k for k, v in phase_map.items()}
a['phase'] = a['phase'].map(phase_map)

X = sorted(a['alpha'].unique())
Y = sorted(a['rho'].unique())

fig, axes = plt.subplots(2, 4, figsize=(22/2*0.8+22/2*0.8/3, 4*2*0.8))
axes = axes.flatten()
orig_axes = axes  # flat array of all 8 panels (this global name is also referenced by later cells)
for i in range(5, 8):
    axes[i].sharex(axes[4])
    axes[i].sharey(axes[4])
plt.subplots_adjust(wspace=0.04)


def _grid_values(frame, column, rows, cols):
    """Pivot `column` of `frame` into a (rho, alpha) array ordered by `rows` x `cols`."""
    return frame.pivot(index='rho', columns='alpha', values=column).reindex(index=rows, columns=cols).values


def _format_grid_axis(ax):
    """Apply the shared (alpha, rho) labels and limits of the top-row panels."""
    ax.set_xlabel(r"$\alpha$")
    ax.set_ylabel(r"$\rho$")
    ax.set_xlim(0, 240)
    ax.set_ylim(0, 1)
    ax.set_aspect(240)  # x-span 240 vs. y-span 1 -> square panels


# ==== (A) phase diagram ====
ax = axes[0]
phase_cmap = ListedColormap(["#2ca02c", "#1f77b4", "#ff7f0e", "#d62728", "#a3c1da"])
norm = BoundaryNorm([0, 1, 2, 3, 4, 5], phase_cmap.N)
ax.pcolor(X, Y, _grid_values(a, 'phase', Y, X), cmap=phase_cmap, norm=norm, shading='auto')
_format_grid_axis(ax)
ax.set_title("phenomena")

# ==== (B.1-B.3) accuracy heatmaps ====
for ax, metric in zip(axes[1:4], ['teacher_train_acc', 'student_train_acc', 'student_test_acc']):
    pc = ax.pcolor(X, Y, _grid_values(a, metric, Y, X), cmap='Spectral', vmin=0, vmax=1, shading='auto')
    ax.set_title(f"{metric_styles[metric]['label']}")
    _format_grid_axis(ax)
# All three heatmaps share cmap and vmin/vmax, so one vertical colorbar inset in the
# lower-left corner of panel B.1 describes them all.
cax = inset_axes(axes[1], width="8%", height="35%", loc="lower left")
fig.colorbar(pc, cax=cax)

# ==== (C.1-C.4) training curves at rho = 0.65 ====
constraints = {
    'student_temperature': 20.0,
    'data_num_labels': 100,
    'data_input_dim': 1000,
    'student_num_epochs': 10000,
    'student_hidden_dim': 500,
    'student_train_frac': 0.65,
    'student_shuffle_input_intra_class': False,
    'teacher_activation': 'relu',
    'student_activation': 'relu',
}

# The run ids in `df` come from the fig_3_one_hidden_layer sweeps loaded above, so their
# histories have to be fetched from that project (the loss-curve cell does the same).
project = f"{wandb_username}/fig_3_one_hidden_layer"

a = filter(df, constraints)
a = a.replace('NaN', None)
a = a[a.teacher_early_stopping != True].copy()
# String placeholders in these summary columns are treated as missing values.
for col in ['student_test_acc', 'student_test_acc_orig', 'teacher_train_acc']:
    a[col] = a[col].apply(lambda x: x if not isinstance(x, str) else None)

num_samples_per_class = [250, 500, 1000, 2000]
curve_axes = orig_axes[4:]
titles = [
    f'S memorizes {commands["Dtrain"]}\n{metric_styles["student_test_acc"]["label"]}>{metric_styles["student_test_acc_orig"]["label"]}',
    f'S approximates {commands["Dtrain"]}\n{metric_styles["student_test_acc"]["label"]}>{metric_styles["student_test_acc_orig"]["label"]}',
    f'S learns T\n{metric_styles["student_test_acc"]["label"]}=100%',
    f'{metric_styles["teacher_train_acc"]["label"]}< 100%\n{metric_styles["student_train_acc"]["label"]} ~ {metric_styles["student_test_acc"]["label"]}'
]
circle_bbox = dict(facecolor='white', edgecolor='black', boxstyle='circle,pad=0.3')


def _stack_truncated(histories):
    """Truncate every history Series to the shortest one and stack them into a (runs, steps) array."""
    min_len = min(len(h) for h in histories)
    return np.array([h.values[:min_len] for h in histories])


j = 0
for ax, num_samples, title in zip(curve_axes, num_samples_per_class, titles):
    b = a[a['data_num_samples_per_class'] == num_samples][:5]  # at most 5 runs per panel
    if len(b) == 0:
        print('skip')
        continue

    ax.set_title(title)
    j += 1
    alpha_b = b['data_num_samples_per_class'] * constraints['data_num_labels'] / constraints['data_input_dim']
    # Circle the selected configuration on the three heatmaps ...
    for panel in (1, 2, 3):
        orig_axes[panel].scatter(alpha_b, b['student_train_frac'], edgecolors='white', s=100, facecolors='none')
    # ... number it on the phase diagram (once per distinct point) ...
    for alpha_val, rho_val in sorted(set(zip(alpha_b, b['student_train_frac']))):
        orig_axes[0].text(alpha_val, rho_val, j, ha='center', va='center', fontsize=8,
                          color='black', bbox=dict(circle_bbox))
    # ... and in the corner of the matching curve panel.
    ax.text(0.1, 0.9, j, ha='left', va='top', fontsize=8, color='black',
            transform=ax.transAxes, bbox=dict(circle_bbox))

    all_train_accs, all_test_accs = [], []
    # Iterate over the runs that actually exist (there may be fewer than 5).
    for _, run_data in b.iterrows():
        history = api.run(f"{project}/{run_data.run_id}").history()
        test_acc = history["test_acc"].dropna()
        train_acc = history["train_acc"].dropna()
        all_train_accs.append(train_acc)
        all_test_accs.append(test_acc)
        ax.plot(test_acc.values, color=metric_styles['student_test_acc']['color'], lw=0.5)
        ax.plot(train_acc.values, color=metric_styles['student_train_acc']['color'], lw=0.5)

    # Reference lines use the summary values of the last selected run.
    ref = b.iloc[-1]
    ax.axhline(ref.student_test_acc_orig, color=metric_styles['student_test_acc_orig']['color'],
               linestyle=metric_styles['student_test_acc_orig']['linestyle'],
               label=metric_styles['student_test_acc_orig']['label'])
    ax.axhline(ref.teacher_train_acc, color=metric_styles['teacher_train_acc']['color'],
               linestyle=metric_styles['teacher_train_acc']['linestyle'],
               label=metric_styles['teacher_train_acc']['label'])
    ax.axhline(0, color='black', lw=0.5)

    ax.text(
        0.95, 0.95,
        f"$n = {num_samples * constraints['data_num_labels']:,}$",
        ha='right', va='top', fontsize=10, color='black',
        bbox=dict(facecolor='white', alpha=0.7, edgecolor='black', boxstyle='round,pad=0.3'),
        transform=ax.transAxes,
    )

    train_stack = _stack_truncated(all_train_accs)
    train_mean, train_std = train_stack.mean(axis=0), train_stack.std(axis=0)
    ax.plot(train_mean, color=metric_styles['student_train_acc']['color'], lw=2,
            label=metric_styles['student_train_acc']['label'],
            linestyle=metric_styles['student_train_acc']['linestyle'])
    ax.fill_between(range(train_stack.shape[1]), train_mean - train_std, train_mean + train_std,
                    color=metric_styles['student_train_acc']['color'], alpha=0.2)

    test_stack = _stack_truncated(all_test_accs)
    ax.plot(test_stack.mean(axis=0), color=metric_styles['student_test_acc']['color'], lw=2,
            label=metric_styles['student_test_acc']['label'],
            linestyle=metric_styles['student_test_acc']['linestyle'])

    ax.set_xlabel('$t$')
    ax.set_xlim(0, 200)

curve_axes[0].set_ylabel('accuracy')
curve_axes[0].legend()
for panel_ax, name in zip(orig_axes, ["A", "B.1", "B.2", "B.3", "C.1", "C.2", "C.3", "C.4"]):
    panel_ax.text(0.05, 1.12 if "C" not in name else 1.22, f'({name})', transform=panel_ax.transAxes,
                  fontsize=12, va='top', ha='right', fontweight='bold')

plt.tight_layout()
plt.savefig(FIGURE_DIR / "1-hidden-layer_overview.pdf", dpi=300)
plt.show()


In [18]:


# ==== Cross-entropy loss curves of the selected runs (rho = 0.65) ====
fig, axes = plt.subplots(1, 4, figsize=(22/2*0.8+22/2*0.8/3+1, 4*2*0.8/2), sharey=True)
axes = axes.flatten()

constraints = {
    'student_temperature': 20.0,
    'data_num_labels': 100,
    'data_input_dim': 1000,
    'student_num_epochs': 10000,
    'student_hidden_dim': 500,
    'student_train_frac': 0.65,
    'student_shuffle_input_intra_class': False,
    'teacher_activation': 'relu',
    'student_activation': 'relu',
}

project = f"{wandb_username}/fig_3_one_hidden_layer"

a = filter(df, constraints)
a = a.replace('NaN', None)
a = a[a.teacher_early_stopping != True].copy()
# String placeholders in these summary columns are treated as missing values.
for col in ['student_test_acc', 'student_test_acc_orig', 'teacher_train_acc']:
    a[col] = a[col].apply(lambda x: x if not isinstance(x, str) else None)

num_samples_per_class = [250, 500, 1000, 2000]
titles = [
    f'S memorizes {commands["Dtrain"]}\n{metric_styles["student_test_acc"]["label"]}>{metric_styles["student_test_acc_orig"]["label"]}',
    f'S approximates {commands["Dtrain"]}\n{metric_styles["student_test_acc"]["label"]}>{metric_styles["student_test_acc_orig"]["label"]}',
    f'S learns T\n{metric_styles["student_test_acc"]["label"]}=100%',
    f'{metric_styles["teacher_train_acc"]["label"]}< 100%\n{metric_styles["student_train_acc"]["label"]} ~ {metric_styles["student_test_acc"]["label"]}'
]

for ax, num_samples, title in zip(axes, num_samples_per_class, titles):
    b = a[a['data_num_samples_per_class'] == num_samples][:5]  # at most 5 runs per panel
    if len(b) == 0:
        print('skip')
        continue
    ax.set_title(title)

    train_losses, test_losses = [], []
    # Iterate over the runs that actually exist (there may be fewer than 5).
    for _, run_data in b.iterrows():
        history = api.run(f"{project}/{run_data.run_id}").history()
        train_loss = history["train_loss"].dropna()
        test_loss = history["test_loss"].dropna()
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        # Individual runs; loss curves reuse the train/test accuracy colours.
        ax.plot(test_loss.values, color=metric_styles['student_test_acc']['color'], lw=1)
        ax.plot(train_loss.values, color=metric_styles['student_train_acc']['color'], lw=1)

    # Mean (+/- std for train) over runs, truncated to the shortest history.
    min_len = min(len(x) for x in train_losses)
    train_stack = np.array([x.values[:min_len] for x in train_losses])
    train_mean, train_std = train_stack.mean(axis=0), train_stack.std(axis=0)
    ax.plot(train_mean, color=metric_styles['student_train_acc']['color'], lw=2,
            label=metric_styles['student_train_acc']['label'].replace('acc', 'loss'),
            linestyle=metric_styles['student_train_acc']['linestyle'])
    ax.fill_between(range(min_len), train_mean - train_std, train_mean + train_std,
                    color=metric_styles['student_train_acc']['color'], alpha=0.2)

    min_len = min(len(x) for x in test_losses)
    test_stack = np.array([x.values[:min_len] for x in test_losses])
    ax.plot(test_stack.mean(axis=0), color=metric_styles['student_test_acc']['color'], lw=2,
            label=metric_styles['student_test_acc']['label'].replace('acc', 'loss'),
            linestyle=metric_styles['student_test_acc']['linestyle'])

    ax.set_xlabel('$t$')
    ax.set_xlim(0, 200)

axes[0].set_ylabel('cross entropy loss')
axes[0].legend()
plt.savefig(FIGURE_DIR / "1-hidden-layer_losses.pdf", dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# ==== Accuracy curves: intra-class shuffled vs. unshuffled inputs (rho = 0.65) ====
fig, axes = plt.subplots(1, 4, figsize=(22/2*0.8+22/2*0.8/3+1, 4*2*0.8/2), sharey=True)
axes = axes.flatten()

num_samples_per_class = [250, 500, 1000, 2000]
titles = [
    f'S memorizes {commands["Dtrain"]}\n{metric_styles["student_test_acc"]["label"]}>{metric_styles["student_test_acc_orig"]["label"]}',
    f'S approximates {commands["Dtrain"]}\n{metric_styles["student_test_acc"]["label"]}>{metric_styles["student_test_acc_orig"]["label"]}',
    f'S learns T\n{metric_styles["student_test_acc"]["label"]}=100%',
    f'{metric_styles["teacher_train_acc"]["label"]}< 100%\n{metric_styles["student_train_acc"]["label"]} ~ {metric_styles["student_test_acc"]["label"]}'
]


def _select_runs(constraints):
    """Return rows of `df` matching `constraints`, excluding early-stopped teachers.

    String placeholders in the accuracy summary columns are replaced by None.
    """
    sel = filter(df, constraints).replace('NaN', None)
    sel = sel[sel.teacher_early_stopping != True].copy()
    for col in ['student_test_acc', 'student_test_acc_orig', 'teacher_train_acc']:
        sel[col] = sel[col].apply(lambda x: x if not isinstance(x, str) else None)
    return sel


# ---- shuffled inputs: train accuracy of (at most) the first 2 runs per n ----
# NOTE(review): unlike the other cells there is no 'student_num_epochs' constraint here — confirm intended.
constraints = {
    'student_temperature': 20.0,
    'data_num_labels': 100,
    'data_input_dim': 1000,
    'student_hidden_dim': 500,
    'student_train_frac': 0.65,
    'student_shuffle_input_intra_class': True,
    'teacher_activation': 'relu',
    'student_activation': 'relu',
}
project = f"{wandb_username}/fig_3_one_hidden_layer_shuffled"
a = _select_runs(constraints)

for ax, num_samples, title in zip(axes, num_samples_per_class, titles):
    b = a[a['data_num_samples_per_class'] == num_samples][:5]
    if len(b) == 0:
        print('skip')
        continue
    ax.set_title(title)

    shuffled_train_accs = []
    for _, run_data in b.head(2).iterrows():
        history = api.run(f"{project}/{run_data.run_id}").history()
        train_acc = history["train_acc"].dropna()
        shuffled_train_accs.append(train_acc)
        ax.plot(train_acc.values, color='tab:red', lw=1)

    min_len = min(len(x) for x in shuffled_train_accs)
    train_stack = np.array([x.values[:min_len] for x in shuffled_train_accs])
    train_mean, train_std = train_stack.mean(axis=0), train_stack.std(axis=0)
    # NOTE(review): the shuffled runs' *train* accuracy is drawn with the
    # 'student_shuffle_test_acc' label/linestyle — confirm this is the intended metric.
    ax.plot(train_mean, color='tab:red', lw=2, label=metric_styles['student_shuffle_test_acc']['label'],
            linestyle=metric_styles['student_shuffle_test_acc']['linestyle'])
    # The std band uses the same colour as the mean curve it surrounds.
    ax.fill_between(range(min_len), train_mean - train_std, train_mean + train_std, color='tab:red', alpha=0.2)

    ax.set_xlabel('$t$')
    ax.set_xlim(0, 200)

# ---- unshuffled inputs: train and test accuracy of up to 5 runs per n ----
constraints = {
    'student_temperature': 20.0,
    'data_num_labels': 100,
    'data_input_dim': 1000,
    'student_hidden_dim': 500,
    'student_train_frac': 0.65,
    'student_shuffle_input_intra_class': False,
    'teacher_activation': 'relu',
    'student_activation': 'relu',
}
project = f"{wandb_username}/fig_3_one_hidden_layer"
a = _select_runs(constraints)

for ax, num_samples, title in zip(axes, num_samples_per_class, titles):
    b = a[a['data_num_samples_per_class'] == num_samples][:5]
    if len(b) == 0:
        print('skip')
        continue
    ax.set_title(title)

    all_train_accs, all_test_accs = [], []
    # Iterate over the runs that actually exist (there may be fewer than 5).
    for _, run_data in b.iterrows():
        history = api.run(f"{project}/{run_data.run_id}").history()
        test_acc = history["test_acc"].dropna()
        train_acc = history["train_acc"].dropna()
        all_train_accs.append(train_acc)
        all_test_accs.append(test_acc)
        ax.plot(test_acc.values, color=metric_styles['student_test_acc']['color'], lw=1)
        ax.plot(train_acc.values, color=metric_styles['student_train_acc']['color'], lw=1)

    min_len = min(len(x) for x in all_train_accs)
    train_stack = np.array([x.values[:min_len] for x in all_train_accs])
    train_mean, train_std = train_stack.mean(axis=0), train_stack.std(axis=0)
    ax.plot(train_mean, color=metric_styles['student_train_acc']['color'], lw=2,
            label=metric_styles['student_train_acc']['label'],
            linestyle=metric_styles['student_train_acc']['linestyle'])
    ax.fill_between(range(min_len), train_mean - train_std, train_mean + train_std,
                    color=metric_styles['student_train_acc']['color'], alpha=0.2)

    min_len = min(len(x) for x in all_test_accs)
    test_stack = np.array([x.values[:min_len] for x in all_test_accs])
    ax.plot(test_stack.mean(axis=0), color=metric_styles['student_test_acc']['color'], lw=2,
            label=metric_styles['student_test_acc']['label'],
            linestyle=metric_styles['student_test_acc']['linestyle'])

    ax.set_xlabel('$t$')
    ax.set_xlim(0, 200)

axes[0].set_ylabel('accuracy')
axes[0].legend()
plt.savefig(FIGURE_DIR / "1-hidden-layer_accuracy_shuffle.pdf", dpi=300, bbox_inches='tight')
plt.show()


In [25]:
# ==== Accuracy as a function of rho for two dataset sizes ====
constraints = {
    'student_temperature': 20.0,
    'data_num_labels': 100,
    'data_input_dim': 1000,
    'student_num_epochs': 10000,
    'student_hidden_dim': 500,
    'student_shuffle_input_intra_class': False,
    'teacher_activation': 'relu',
    'student_activation': 'relu',
    'teacher_num_epochs': 1000,
}

a = filter(df, constraints)
a = a.replace('NaN', None)
# String placeholders in these summary columns are treated as missing values.
for col in ['student_test_acc', 'student_test_acc_orig', 'teacher_train_acc']:
    a[col] = a[col].apply(lambda x: x if not isinstance(x, str) else None)

fig, axes = plt.subplots(1, 2, figsize=(6*0.97, 3*0.97), sharey=True, sharex=True)
num_samples_per_class = [250, 1000]
metrics = ['student_test_acc', 'student_train_acc', 'student_test_acc_orig', 'teacher_train_acc']
for ax, num_samples in zip(axes, num_samples_per_class):
    b = a[a['data_num_samples_per_class'] == num_samples]
    # use at most 10 runs per student_train_frac when more are available
    # (drop=True: do not carry the old index along as a numeric column)
    b = b.groupby(['student_train_frac']).head(10).reset_index(drop=True)
    # error bars show the standard deviation across runs (not the standard error)
    b_std = b.groupby(['student_train_frac']).std(numeric_only=True).reset_index()
    b_mean = b.groupby(['student_train_frac']).mean(numeric_only=True).reset_index()

    for metric in metrics:
        ax.errorbar(
            b_mean['student_train_frac'],
            b_mean[metric],
            yerr=b_std[metric],
            label=metric_styles[metric]['label'],
            color=metric_styles[metric]['color'],
            linestyle=metric_styles[metric]['linestyle'],
            capsize=2,
            lw=2
        )

    ax.set_title(r"$n " + f" = {num_samples * constraints['data_num_labels']:,}$")
    ax.set_ylim(-0.05, 1.05)
    ax.set_xlabel(r'$\rho$')
axes[0].set_ylabel('accuracy')
axes[0].legend(ncols=2)
for ax, name in zip(axes, "AB"):
    ax.text(0.05, 1.12, f'({name})', transform=ax.transAxes, fontsize=12, va='top', ha='right', fontweight='bold')
plt.tight_layout()
plt.savefig(FIGURE_DIR / "1-hidden-layer-rho-curves.pdf", dpi=300, bbox_inches='tight')
plt.show()

In [39]:
# ==== Student test accuracy vs. rho for different student hidden sizes p ====
constraints = {
    'student_temperature': 20.0,
    'data_num_labels': 100,
    'data_input_dim': 1000,
    'student_num_epochs': 10000,
    'data_num_samples_per_class': 250,
    'student_shuffle_input_intra_class': False,
    'teacher_activation': 'relu',
    'student_activation': 'relu',
    'teacher_hidden_dim': 500,
}

k = filter(df, constraints)
k = k.replace('NaN', None)

# Colour each hidden size on a log scale; `sm` drives the colorbar.
hidden_dims = k['student_hidden_dim'].unique()
norm = mpl.colors.LogNorm(vmin=min(hidden_dims), vmax=max(hidden_dims))
cmap = plt.get_cmap('rainbow')
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])

fig, ax = plt.subplots(figsize=(5, 3))
metric = 'student_test_acc'
# Legend proxy (black) for the hidden-size-coloured curves.
ax.plot([], [], color='black', label=metric_styles[metric]['label'])
print(sorted(k.student_hidden_dim.unique()))
for student_hidden_dim, group in k.groupby('student_hidden_dim'):
    # at most 5 runs per student_train_frac; drop=True avoids an extra numeric 'index' column
    b = group.groupby(['student_train_frac']).head(5).reset_index(drop=True)
    b_sem = b.groupby(['student_train_frac']).sem(numeric_only=True).reset_index()
    b_mean = b.groupby(['student_train_frac']).mean(numeric_only=True).reset_index()

    ax.errorbar(
        b_mean['student_train_frac'],
        b_mean[metric],
        yerr=b_sem[metric],
        color=cmap(norm(student_hidden_dim)),
        linestyle=metric_styles[metric]['linestyle'],
        capsize=2,
        lw=1.5 if student_hidden_dim == 500 else 1.0,  # highlight p = 500
    )

    if student_hidden_dim == 500:
        ax.plot(
            b_mean['student_train_frac'],
            b_mean['student_test_acc_orig'],
            label=metric_styles['student_test_acc_orig']['label'],
            color=metric_styles['student_test_acc_orig']['color'],
            linestyle='dotted'
        )

cbar = fig.colorbar(sm, ax=ax)
cbar.set_label(r'student hidden size $p$', rotation=270, labelpad=15)

ax.set_xlabel(r'$\rho$')
ax.set_ylabel('accuracy')
ax.axhline(0, color='black', lw=0.5)
ax.legend()
plt.tight_layout()
plt.savefig(FIGURE_DIR / "1-hidden-layer-mem-more-hidden-layers.pdf", dpi=300, bbox_inches='tight')
plt.show()


In [27]:
# ==== Accuracy vs. n/d for different numbers of classes c (rho = 0.45) ====
constraints = {
    'student_temperature': 20.0,
    'data_input_dim': 1000,
    'student_num_epochs': 10000,
    'student_hidden_dim': 500,
    'student_train_frac': 0.45,
    'student_shuffle_input_intra_class': False,
    'teacher_activation': 'relu',
    'student_activation': 'relu',
}

k = filter(df, constraints)
k = k.replace('NaN', None)

# Cycle through the palette so that more class counts than colours does not raise IndexError.
palette = list(c_muted.values())
colors = {c: palette[i % len(palette)] for i, c in enumerate(k['data_num_labels'].unique())}

# Explicit figure: do not rely on the backend having closed the previous one.
fig, ax = plt.subplots()
for data_num_labels, group in k.groupby('data_num_labels'):
    group_mean = group.groupby(['data_num_samples_per_class']).mean(numeric_only=True).reset_index()
    # alpha * c = n / d, with n = samples_per_class * c
    x = group_mean['data_num_samples_per_class'] * group_mean['data_num_labels'] / group_mean['data_input_dim']
    ax.plot(x, group_mean['student_test_acc'], linestyle='solid', marker='x', c=colors[data_num_labels], label=data_num_labels)
    ax.plot(x, group_mean['student_train_acc'], linestyle='dashed', marker='x', c=colors[data_num_labels])
    ax.plot(x, group_mean['teacher_train_acc'], c='grey', marker='x')

ax.set_xlabel(r'$\alpha \cdot c = n/d$')
ax.set_ylabel('accuracy')
# Legend proxies for the line styles.
ax.plot([], [], linestyle='dashed', c='black', label=metric_styles['student_train_acc']['label'])
ax.plot([], [], c='black', label=metric_styles['student_test_acc']['label'])
ax.plot([], [], c='grey', label=metric_styles['teacher_train_acc']['label'])
ax.legend(title='classes $c$')
plt.savefig(FIGURE_DIR / "1-hidden-layer-mem-more-classes.pdf", dpi=300, bbox_inches='tight')
plt.show()

In [28]:
# ==== Accuracy vs. alpha for student hidden sizes 400/500/600 (teacher p = 500, rho = 0.45) ====
# Explicit figure: do not rely on the backend having closed the previous one.
fig, ax = plt.subplots()
for s_hidden, color in zip([400, 500, 600], ['tab:orange', 'tab:blue', 'tab:purple']):
    constraints = {
        'student_temperature': 20.0,
        'data_num_labels': 100,
        'data_input_dim': 1000,
        'student_num_epochs': 10000,
        'student_hidden_dim': s_hidden,
        'teacher_hidden_dim': 500,
        'student_train_frac': 0.45,
        'student_shuffle_input_intra_class': False,
        'teacher_activation': 'relu',
        'student_activation': 'relu',
    }

    a = filter(df, constraints)
    a = a.replace('NaN', None)

    a = a.groupby(['data_num_samples_per_class']).mean(numeric_only=True).reset_index()
    # alpha = n / (d c) = samples_per_class / d
    x = a['data_num_samples_per_class'] / a['data_input_dim']
    ax.plot(x, a['student_test_acc'], marker='x', c=color)
    ax.plot(x, a['student_train_acc'], linestyle='dashed', marker='x', c=color)
    ax.plot(x, a['teacher_train_acc'], c='grey', marker='x')
    ax.plot([], [], c=color, label=s_hidden)  # legend entry for this hidden size

ax.set_xlabel(r'$\alpha = n/(dc)$')
ax.set_ylabel('accuracy')
# Legend proxies for the line styles.
ax.plot([], [], c='black', linestyle='dashed', label=metric_styles['student_train_acc']['label'])
ax.plot([], [], c='black', label=metric_styles['student_test_acc']['label'])
ax.plot([], [], c='grey', label=metric_styles['teacher_train_acc']['label'])
ax.legend(title='hidden size $p^S$', ncol=2)

plt.savefig(FIGURE_DIR / "1-hidden-layer-student_hidden_layers.pdf", dpi=300, bbox_inches='tight')
plt.show()

In [36]:

# ==== Accuracy vs. alpha / p for matched teacher/student hidden sizes (p^T = p^S, rho = 0.45) ====
constraints = {
    'student_temperature': 20.0,
    'data_num_labels': 100,
    'data_input_dim': 1000,
    'student_num_epochs': 10000,
    'student_train_frac': 0.45,
    'student_shuffle_input_intra_class': False,
    'teacher_activation': 'relu',
    'student_activation': 'relu',
}

k = filter(df, constraints)
k = k.replace('NaN', None)
k = k[k.student_hidden_dim == k.teacher_hidden_dim]

# Cycle through the palette so that more hidden sizes than colours does not raise IndexError.
palette = list(c_muted.values())
colors = {c: palette[i % len(palette)] for i, c in enumerate(k['student_hidden_dim'].unique())}

# Explicit figure: do not rely on the backend having closed the previous one.
fig, ax = plt.subplots()
# `k` already has teacher_hidden_dim == student_hidden_dim, so every group is non-empty and matched.
for student_hidden_dim, group in k.groupby('student_hidden_dim'):
    group_mean = group.groupby(['data_num_samples_per_class']).mean(numeric_only=True).reset_index()
    # alpha / p^S = n / (d c p^S) = samples_per_class / (d p^S)
    x = group_mean['data_num_samples_per_class'] / group_mean['data_input_dim'] / group_mean['student_hidden_dim']
    ax.plot(x, group_mean['student_test_acc'], marker='x', c=colors[student_hidden_dim], label=student_hidden_dim)
    ax.plot(x, group_mean['student_train_acc'], linestyle='dashed', marker='x', c=colors[student_hidden_dim])
    ax.plot(x, group_mean['teacher_train_acc'], c='grey', marker='x')

ax.set_xlabel(r'$\alpha / p^S = n/(dcp^S)$')
ax.set_ylabel('accuracy')
# Legend proxies for the line styles.
ax.plot([], [], linestyle='dashed', c='black', label=metric_styles['student_train_acc']['label'])
ax.plot([], [], c='black', label=metric_styles['student_test_acc']['label'])
ax.plot([], [], c='grey', label=metric_styles['teacher_train_acc']['label'])
ax.legend(title='hidden $p^T = p^S$')
ax.set_xlim(0, 0.0075)
plt.savefig(FIGURE_DIR / "1-hidden-layer-hidden_s_is_hidden_t.pdf", dpi=300, bbox_inches='tight')
plt.show()