In [None]:
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from matplotlib.colors import LogNorm
from config import *

# Shared x-axis (alpha) limits applied to every panel in this notebook.
x_min, x_max = 0.2, 3.5


def _load_results(path, input_dim):
    """Load a pickled list of result records and tag each with its input dimension.

    Parameters
    ----------
    path : str
        Location of the pickle file; it is expected to contain an iterable of dicts.
    input_dim : int
        Value stored under the ``'data_input_dim'`` key of every returned record.

    Returns
    -------
    list of dict
        A shallow copy of every record with ``'data_input_dim'`` added.
    """
    # NOTE(review): pickle.load can execute arbitrary code -- only point this at trusted files.
    # The context manager closes the handle, which a bare open(...) inside pickle.load did not.
    with open(path, "rb") as fh:
        records = pickle.load(fh)
    return [{**record, 'data_input_dim': input_dim} for record in records]


# (input dimension, file-name suffix) for every result file that is merged.
# The loop variable is intentionally the notebook-level `d`; it ends at 1600,
# which is the dimension the code below (and later cells) works with.
results = []
for d, suffix in [(200, ""), (1600, ""), (1600, "_v2")]:
    results += _load_results(f"neurips_code/results_linear_with_mse_d={d}{suffix}.pkl", d)

df = pd.DataFrame(results)
# Expose the plain train/test accuracies under explicit "student_*" names.
df['student_train_acc'] = df['train_acc']
df['student_test_acc'] = df['test_acc']
# Keep the un-aggregated rows; later cells need them for per-run error bars.
df_orig = df.copy()

# Keep only runs where student_shuffle_input_intra_class is False, then average
# all numeric columns per (rho, samples-per-class, input-dim) grid point.
df = df[df.student_shuffle_input_intra_class == False]
df = df.groupby(['student_train_frac','data_num_samples_per_class','data_input_dim']).mean(numeric_only=True).reset_index()
# Aggregated grid restricted to the current input dimension d.
b = df[df.data_input_dim == d]


def to_grid(z,a,d):
    """Reshape long-format results into the three arrays a pcolor-style plot needs.

    Parameters
    ----------
    z : str
        Name of the column of ``a`` whose values fill the grid.
    a : pandas.DataFrame
        Frame with columns ``'student_train_frac'``, ``'data_num_samples_per_class'``
        and ``z``; each (row, column) pair must be unique for the pivot to work.
    d : number
        Divisor used to rescale the column axis.

    Returns
    -------
    tuple
        ``(X, Y, Z)`` where ``X`` is ``data_num_samples_per_class * 2 / d``,
        ``Y`` holds the ``student_train_frac`` values and ``Z`` is the 2-D array
        of ``z`` values with rows following ``Y`` and columns following ``X``.
    """
    pivoted = a.pivot(
        index='student_train_frac',
        columns='data_num_samples_per_class',
        values=z,
    )
    # Column labels are samples-per-class counts; rescale them by 2 / d.
    rescaled_columns = pivoted.columns.to_numpy() * 2 / d
    return rescaled_columns, pivoted.index.to_numpy(), pivoted.to_numpy()


# 2x3 overview figure for input dimension d: one line plot for the teacher and
# one (alpha, rho) heat map per remaining metric.
fig, axes = plt.subplots(2, 3, figsize=(10, 6),sharex=True)
axes = axes.flatten()
metrics = ['teacher_train_acc','student_train_acc', 'student_test_acc', 'match_teacher_test_acc',  'train_mse','test_mse',]

for ax, metric in zip(axes, metrics):
    if metric == 'teacher_train_acc':
        # This panel is a 1-D curve: average over rho for each sample count.
        # It plots the teacher column named by `metric`; previously the student's
        # `train_acc` column was plotted under the teacher title.
        teacher_curve = b.groupby('data_num_samples_per_class')['teacher_train_acc'].mean().reset_index()
        ax.plot(teacher_curve.data_num_samples_per_class * 2 / d, teacher_curve.teacher_train_acc, c='black')
    else:
        X, Y, Z = to_grid(metric, b, d)
        if metric in ('train_mse', 'test_mse'):
            # Log colour scale for the MSE panels. nanmin/nanmax ignore grid cells
            # that are missing from the pivot, which would otherwise make the
            # normalisation limits NaN.
            pcm = ax.pcolor(X, Y, Z, cmap='Greens_r', norm=LogNorm(vmin=np.nanmin(Z), vmax=np.nanmax(Z)), shading='auto')
        else:
            # Accuracy-type metrics share a fixed [0, 1] colour range.
            pcm = ax.pcolor(X, Y, Z, cmap='Spectral', vmin=0, vmax=1, shading='auto')
        fig.colorbar(pcm, ax=ax)
    ax.set_xlabel(r'$\alpha$')
    ax.set_ylabel(r'$\rho$')
    ax.set_title(metric_styles[metric]['label'])
for ax in axes:
    ax.set_xlim(x_min,x_max)
# The first panel is the teacher curve, so its y-axis is an accuracy, not rho.
axes[0].set_ylabel('accuracy')
plt.tight_layout()
plt.savefig(FIGURE_DIR / f'logistic_regression_phase_diag_linear_with_mse_d={d}.pdf',bbox_inches='tight')
plt.show()


In [None]:
# Estimate alpha_t_label: the alpha at which the mean teacher train accuracy
# crosses 0.99, taken as the midpoint between the largest alpha that is still
# >= 0.99 and the smallest alpha that is already < 0.99.
# The same constant is used for the row filter and for the alpha rescaling, so
# the result no longer depends on whatever the notebook-level `d` currently is.
label_dim = 1600
a = df[df.data_input_dim == label_dim].groupby(['data_num_samples_per_class']).mean(numeric_only=True).reset_index()
alpha_t_label_min = a[a.teacher_train_acc >= 0.99].data_num_samples_per_class.max() * 2 / label_dim
alpha_t_label_max = a[a.teacher_train_acc < 0.99].data_num_samples_per_class.min() * 2 / label_dim
alpha_t_label= (alpha_t_label_min + alpha_t_label_max) / 2

In [None]:
# Estimate alpha_s_label: the alpha at which the mean student train accuracy
# crosses 0.99, for the rows whose student_train_frac is (numerically) rho_ref.
# Midpoint between the largest alpha still >= 0.99 and the smallest alpha < 0.99.
# The same constant is used for the row filter and for the alpha rescaling, so
# the result no longer depends on whatever the notebook-level `d` currently is.
label_dim = 1600
rho_ref = 0.80306122  # student_train_frac value selected for this estimate
a = df[np.logical_and(df.data_input_dim == label_dim,np.isclose(df.student_train_frac,rho_ref))].groupby(['data_num_samples_per_class']).mean(numeric_only=True).reset_index()
alpha_s_label_min = a[a.student_train_acc >= 0.99].data_num_samples_per_class.max() * 2 / label_dim
alpha_s_label_max = a[a.student_train_acc < 0.99].data_num_samples_per_class.min() * 2 / label_dim
alpha_s_label= (alpha_s_label_min + alpha_s_label_max) / 2

In [None]:
# Estimate alpha_s_id: the smallest alpha at which the mean student test
# accuracy reaches 0.99, for the rows whose student_train_frac is
# (numerically) rho_ref.
# The same constant is used for the row filter and for the alpha rescaling, so
# the result no longer depends on whatever the notebook-level `d` currently is.
label_dim = 1600
rho_ref = 0.80306122  # student_train_frac value selected for this estimate
a = df[np.logical_and(df.data_input_dim == label_dim,np.isclose(df.student_train_frac,rho_ref))].groupby(['data_num_samples_per_class']).mean(numeric_only=True).reset_index()
alpha_s_id = a[a.student_test_acc >= 0.99].data_num_samples_per_class.min() * 2 / label_dim

print(alpha_t_label, alpha_s_label, alpha_s_id)

In [None]:
from enum import IntEnum

class Phases(IntEnum):
    """Integer phase codes returned by ``get_phase``.

    The integer values 0-4 line up with the colour order of the
    ``ListedColormap`` / ``BoundaryNorm`` bounds used for the phase diagram
    later in this file, so renumbering a member also changes its colour.
    """
    # train_acc >= 0.99 but test_acc < 0.55
    NADA = 4
    # train_acc >= 0.99 and test_acc >= 0.99
    GENERALIZATION = 0
    # train_acc >= 0.99 and 0.55 <= test_acc < 0.99
    GEN_LEAKAGE = 1
    # train_acc < 0.99 and match_teacher_test_acc < 0.99
    MEM_FAIL_DONT_LEARN_TEACHER = 2
    # train_acc < 0.99 and match_teacher_test_acc >= 0.99
    MEM_FAIL_LEARN_TEACHER = 3
# Phase classification function
def get_phase(train_acc, test_acc, match_teacher_test_acc, train_mse, test_mse):
    """Map one grid point's metrics to a ``Phases`` member.

    Parameters
    ----------
    train_acc, test_acc : float
        Accuracies compared against the 0.99 (and, for ``test_acc``, 0.55) thresholds.
    match_teacher_test_acc : float
        Only consulted when ``train_acc`` is below 0.99.
    train_mse, test_mse : float
        Accepted for signature compatibility; not used in the decision.

    Returns
    -------
    Phases
        The phase selected by the threshold rules below.
    """
    reaches_train_threshold = train_acc >= 0.99

    # Training accuracy below threshold: split on agreement with the teacher.
    if not reaches_train_threshold:
        if match_teacher_test_acc >= 0.99:
            return Phases.MEM_FAIL_LEARN_TEACHER
        return Phases.MEM_FAIL_DONT_LEARN_TEACHER

    # Training accuracy at threshold: split on test accuracy, highest band first.
    if test_acc >= 0.99:
        return Phases.GENERALIZATION
    if test_acc >= 0.55:
        return Phases.GEN_LEAKAGE
    return Phases.NADA

# Aggregate the dimension-d results onto the (rho, samples-per-class) grid.
a = df[df.data_input_dim==d].groupby(['student_train_frac', 'data_num_samples_per_class']).mean(numeric_only=True).reset_index()
# Apply the phase classification to every grid point
a['phase'] = a.apply(lambda row: get_phase(
    row['train_acc'],
    row['test_acc'],
    row['match_teacher_test_acc'],
    row['train_mse'],
    row['test_mse']
), axis=1)

# Compute alpha and rho
a['alpha'] = a['data_num_samples_per_class'] * 2 / d  # 2 classes; d is the current input dimension
a['rho'] = a['student_train_frac']

# Colour per phase (NADA has no entry here; its grey is defined in the colormap below)
phase_colors = {
    Phases.GENERALIZATION: "#2ca02c",             # green
    Phases.GEN_LEAKAGE: "#1f77b4",                # blue
    Phases.MEM_FAIL_DONT_LEARN_TEACHER: "#d62728", # red
    Phases.MEM_FAIL_LEARN_TEACHER: "#ff7f0e",      # orange
}

# Human-readable phase names. The "(Learn Teacher)" suffix belongs to
# MEM_FAIL_LEARN_TEACHER; the two "Mem Fails" labels were previously swapped.
phase_labels = {
    Phases.GENERALIZATION: "Generalization",
    Phases.GEN_LEAKAGE: "Gen Leakage",
    Phases.MEM_FAIL_DONT_LEARN_TEACHER: "Mem Fails",
    Phases.MEM_FAIL_LEARN_TEACHER: "Mem Fails (Learn Teacher)",
}

from matplotlib.colors import ListedColormap, BoundaryNorm

# Create 2D grid of phases: rows are rho, columns are alpha
phase_grid = a.pivot(index='rho', columns='alpha', values='phase')
X = phase_grid.columns.values
Y = phase_grid.index.values
Z = phase_grid.values

# Colormap whose i-th colour is used for the phase with integer value i
cmap = ListedColormap([
    '#2ca02c',   # GENERALIZATION (green)
    '#1f77b4',   # GEN_LEAKAGE (blue)
    '#d62728',   # MEM_FAIL_DONT_LEARN_TEACHER (red)
    '#ff7f0e',   # MEM_FAIL_LEARN_TEACHER (orange)
    '#7f7f7f',   # NADA (grey)
])
# One unit-wide bin per phase value 0..4
bounds = [0, 1, 2, 3, 4, 5]
norm = BoundaryNorm(bounds, cmap.N)

# Plot: a 1x4 figure; the last axis holds the phase diagram, the first three
# are filled in further down in this cell.
fig, axes = plt.subplots(1,4,figsize=(22/2*0.8+22/2*0.8/3, 4*2*0.8/2))
ax = axes[-1]
pcm = ax.pcolor(X, Y, Z, cmap=cmap, norm=norm, shading='auto')
ax.axvline(alpha_t_label,color=alpha_t_label_color, linestyle='--')
ax.set_xlabel(r'$\alpha$')
ax.set_ylabel(r'$\rho$')

# NOTE(review): mask / XX / YY are not referenced again in this cell; they are
# kept in case later cells rely on them.
mask = np.logical_or(Z == Phases.GENERALIZATION, Z == Phases.MEM_FAIL_LEARN_TEACHER).astype(int)

XX, YY = np.meshgrid(X, Y)

ax.set_title('phenomenology')


from matplotlib.colors import LogNorm


# Per-run (un-aggregated) rows for both input dimensions at the selected rho.
# np.isclose is used for both rho values so the selection does not hinge on an
# exact floating-point match.
a = df_orig[df_orig.data_input_dim.isin([200,1600])]
a = a[np.logical_or(np.isclose(a.student_train_frac, 0.8),np.isclose(a.student_train_frac,0.80306122))]
a = a[a.student_shuffle_input_intra_class == False]

metrics = [ 'teacher_train_acc', 'student_train_acc', 'student_test_acc',  ]
labels = ['teacher memorization acc', 'student acc on train\n memorization set', 'student acc on test\n memorization set',  ]

# The loop variable is deliberately not called `d`: reusing that name would
# rebind the notebook-level input dimension that other cells read.
for input_dim, g in a.groupby('data_input_dim'):

    grouped = g.groupby('data_num_samples_per_class')

    # Mean and standard error of the mean per sample count.
    mean_df = grouped.mean(numeric_only=True).reset_index()
    sem_df = grouped.sem(numeric_only=True).reset_index()

    x = mean_df['data_num_samples_per_class'] * 2 / input_dim

    # zip stops at the shortest input, so only the first three axes are drawn on.
    for ax, metric, label in zip(axes, metrics, labels):
        y = mean_df[metric]
        yerr = sem_df[metric]
        if input_dim > 1000:
            # Larger dimension: plot every 4th point, in the metric's own colour.
            ax.errorbar(x[::4], y[::4], yerr=yerr[::4], capsize=3, label=input_dim, c=metric_styles[metric]['color'])
        else:
            ax.errorbar(x, y, yerr=yerr, capsize=3, label=input_dim, c='grey')
        ax.set_xlabel(r'$\alpha$')
        ax.set_title(metric_styles[metric]['label'])

        if 'mse' in metric:
            ax.set_yscale('log')
        else:
            ax.set_ylim(0.5,1.05)

for ax in axes[:3]:
    ax.legend(title='$d$')

# Mark the estimated thresholds on the corresponding accuracy panels.
axes[0].axvline(alpha_t_label, color=alpha_t_label_color, linestyle='--')
axes[1].axvline(alpha_s_label, color=alpha_s_label_color, linestyle='--')
axes[2].axvline(alpha_s_id,color=alpha_s_id_color, linestyle='--',)

axes[0].set_ylabel('accuracy')

# Panel tags in the top-left corner of each axis.
for ax, name in zip(axes,["A.1","A.2","A.3","B"]):
    ax.text(0.05, 1.12, f'({name})', transform=ax.transAxes, fontsize=12, va='top', ha='right',fontweight='bold')


# Legend of the last panel (axes[3] is axes[-1]). Empty plots act as legend
# proxies. The legend is built before the curve below is drawn, so that curve's
# own label is not listed; the proxy in the same colour and line style stands
# in for it.
axes[-1].plot([],[],label=alpha_t_label_name, color=alpha_t_label_color, linestyle='--')

axes[3].axvline(alpha_s_label, color=alpha_s_label_color, linestyle='--', label=alpha_s_label_name)
axes[-1].plot([],[],label=alpha_s_id_name, color=alpha_s_id_color, linestyle='--')
axes[-1].legend()
y = np.linspace(1,2, 1000)
y = y[y != 1]  # Drop the endpoint at 1; 1 / y has no division-by-zero risk on [1, 2]
f_y = 1 / y
# Drawn with x = y and height 1 / y, i.e. the curve rho = 1 / alpha.
axes[-1].plot(y, f_y, color=alpha_s_id_color, linestyle='--', label=r'$\frac{1}{\rho}$')

axes[-1].set_ylim(0.05,0.95)
for ax in axes:
    ax.set_xlim(x_min,x_max)
plt.tight_layout()
plt.savefig(FIGURE_DIR / 'logistic_regression_phase_diagram_c=2.pdf', bbox_inches='tight')
plt.show()