In [2]:
from config import *
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

In [3]:
# Fetch the logged runs of this W&B project and keep only rows whose `state` is 'finished'.
# NOTE(review): the project is named 'fig_1a_linear_regression' while the figure below is
# saved as multiclass logistic regression — confirm this is the intended sweep.
df = load_result_table(project=f"{wandb_username}/fig_1a_linear_regression").loc[
    lambda runs: runs["state"] == "finished"
]
# Independent copy of the finished runs (presumably kept so later cells can reset `df`).
df_orig = df.copy()

In [None]:
# Test accuracy of the student as a function of alpha (x-axis, derived from samples per
# class) and rho = student_train_frac (y-axis), with one panel per student temperature tau.
# Only runs matching `constraints` are plotted.
constraints = {
    'data_num_labels': 2,
    'data_input_dim': 1000,
}
# Temperatures to plot, one panel each, left to right.
TEMPERATURES = (0.5, 1.0, 10.0)


def _accuracy_mesh(runs, num_labels, input_dim):
    """Build the (X, Y, Z) arrays that `pcolormesh` needs for one temperature panel.

    Args:
        runs: DataFrame of runs for a single temperature. It must contain
            'student_train_frac', 'data_num_samples_per_class' and 'student_test_acc'.
        num_labels: number of classes (value of 'data_num_labels').
        input_dim: input dimension (value of 'data_input_dim').

    Returns:
        X: meshgrid of alpha = samples_per_class * num_labels / input_dim / 2.
        Y: meshgrid of the sorted student_train_frac values.
        Z: mean student_test_acc on that grid, NaN where a combination was not run.
    """
    # Average all runs that share a (train_frac, samples_per_class) cell,
    # presumably repeated seeds.
    mean_runs = (
        runs.groupby(['student_train_frac', 'data_num_samples_per_class'])
        .mean(numeric_only=True)
        .reset_index()
    )
    x_vals = sorted(mean_runs['data_num_samples_per_class'].unique())
    y_vals = sorted(mean_runs['student_train_frac'].unique())

    # NOTE(review): the trailing `/ 2` in alpha is not explained here — confirm it
    # matches the definition of alpha used in the paper.
    alphas = np.array(x_vals) * num_labels / input_dim / 2
    X, Y = np.meshgrid(alphas, y_vals)

    Z = (
        mean_runs.pivot(
            index='student_train_frac',
            columns='data_num_samples_per_class',
            values='student_test_acc',
        )
        .loc[y_vals, x_vals]
        .values
    )
    return X, Y, Z


fig = plt.figure(figsize=(12, 3))
# One equal-width column per temperature, plus a narrow last column for the colorbar.
gs = gridspec.GridSpec(
    1, len(TEMPERATURES) + 1,
    width_ratios=[1] * len(TEMPERATURES) + [0.05],
    wspace=0.3,
)

axes = [fig.add_subplot(gs[k]) for k in range(len(TEMPERATURES))]
cbar_ax = fig.add_subplot(gs[len(TEMPERATURES)])

# `filter` here is presumably the project helper brought in by `from config import *`
# (it shadows the builtin); it must return a DataFrame, since `.replace` is called on it.
a = filter(df, constraints)
# Turn the literal string 'NaN' into a real missing value. Use np.nan rather than None:
# on pandas < 1.4, `replace(x, None)` forward-fills instead of inserting None.
# infer_objects() then gives numeric dtypes back to columns that held only numbers
# plus 'NaN' strings, so mean(numeric_only=True) below does not drop them.
a = a.replace('NaN', np.nan).infer_objects()

acc_cmap = plt.get_cmap('Spectral')
# One fixed colour scale for every panel, so a single colorbar describes all of them.
acc_norm = plt.Normalize(vmin=0, vmax=1)

for ax, temp in zip(axes, TEMPERATURES):
    runs_at_temp = a[a['student_temperature'] == temp]
    if runs_at_temp.empty:
        # Keep the panel in its slot (rather than shifting later panels left) and flag it.
        ax.text(0.5, 0.5, "no data", ha='center', va='center', transform=ax.transAxes)
    else:
        X, Y, Z_test = _accuracy_mesh(
            runs_at_temp, constraints['data_num_labels'], constraints['data_input_dim']
        )
        pc = ax.pcolormesh(X, Y, Z_test, cmap=acc_cmap, norm=acc_norm, shading='auto')
    ax.set_title(r"$\tau"+f"={temp}$")
    ax.set_xlabel(r"$\alpha$")
    ax.set_xlim(0.05, 1.5)

axes[0].set_ylabel(r"$\rho$")
# Build the colorbar from the shared norm/cmap so it does not depend on any panel having data.
fig.colorbar(plt.cm.ScalarMappable(norm=acc_norm, cmap=acc_cmap), cax=cbar_ax, label="Test Accuracy")
fig.savefig(FIGURE_DIR / "multiclass_logistic_regression_different_temp.pdf", dpi=300, bbox_inches='tight')